sglang 0.4.9.post2__py3-none-any.whl → 0.4.9.post4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (200)
  1. sglang/bench_one_batch.py +2 -1
  2. sglang/eval/loogle_eval.py +7 -0
  3. sglang/srt/_custom_ops.py +29 -1
  4. sglang/srt/configs/deepseekvl2.py +11 -2
  5. sglang/srt/configs/internvl.py +3 -0
  6. sglang/srt/configs/janus_pro.py +3 -0
  7. sglang/srt/configs/model_config.py +10 -8
  8. sglang/srt/configs/update_config.py +3 -1
  9. sglang/srt/conversation.py +2 -1
  10. sglang/srt/custom_op.py +5 -2
  11. sglang/srt/disaggregation/common/conn.py +34 -6
  12. sglang/srt/disaggregation/decode.py +9 -1
  13. sglang/srt/disaggregation/mini_lb.py +3 -2
  14. sglang/srt/disaggregation/mooncake/conn.py +93 -76
  15. sglang/srt/disaggregation/mooncake/transfer_engine.py +4 -2
  16. sglang/srt/disaggregation/nixl/conn.py +17 -13
  17. sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -91
  18. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +96 -1
  19. sglang/srt/distributed/device_communicators/quick_all_reduce.py +273 -0
  20. sglang/srt/distributed/device_communicators/shm_broadcast.py +12 -5
  21. sglang/srt/distributed/parallel_state.py +103 -15
  22. sglang/srt/entrypoints/engine.py +31 -33
  23. sglang/srt/entrypoints/http_server.py +20 -32
  24. sglang/srt/entrypoints/openai/protocol.py +3 -3
  25. sglang/srt/entrypoints/openai/serving_chat.py +48 -6
  26. sglang/srt/eplb/expert_location_dispatch.py +1 -1
  27. sglang/srt/function_call/base_format_detector.py +74 -12
  28. sglang/srt/function_call/deepseekv3_detector.py +26 -11
  29. sglang/srt/function_call/ebnf_composer.py +95 -63
  30. sglang/srt/function_call/function_call_parser.py +4 -2
  31. sglang/srt/function_call/kimik2_detector.py +41 -16
  32. sglang/srt/function_call/llama32_detector.py +6 -3
  33. sglang/srt/function_call/mistral_detector.py +11 -3
  34. sglang/srt/function_call/pythonic_detector.py +16 -14
  35. sglang/srt/function_call/qwen25_detector.py +12 -3
  36. sglang/srt/function_call/qwen3_coder_detector.py +151 -0
  37. sglang/srt/hf_transformers_utils.py +0 -1
  38. sglang/srt/layers/activation.py +24 -3
  39. sglang/srt/layers/attention/base_attn_backend.py +3 -1
  40. sglang/srt/layers/attention/flashattention_backend.py +3 -3
  41. sglang/srt/layers/attention/flashinfer_backend.py +40 -1
  42. sglang/srt/layers/communicator.py +12 -12
  43. sglang/srt/layers/dp_attention.py +72 -24
  44. sglang/srt/layers/linear.py +13 -102
  45. sglang/srt/layers/logits_processor.py +34 -24
  46. sglang/srt/layers/moe/ep_moe/kernels.py +4 -2
  47. sglang/srt/layers/moe/ep_moe/layer.py +23 -402
  48. sglang/srt/layers/moe/fused_moe_native.py +7 -47
  49. sglang/srt/layers/moe/fused_moe_triton/__init__.py +4 -4
  50. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=320,device_name=NVIDIA_H20-3e.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  52. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=384,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=385,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  56. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +54 -263
  57. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -396
  58. sglang/srt/layers/moe/topk.py +190 -23
  59. sglang/srt/layers/quantization/__init__.py +20 -134
  60. sglang/srt/layers/quantization/awq.py +578 -11
  61. sglang/srt/layers/quantization/awq_triton.py +339 -0
  62. sglang/srt/layers/quantization/base_config.py +85 -10
  63. sglang/srt/layers/quantization/blockwise_int8.py +17 -55
  64. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +13 -11
  65. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +23 -79
  66. sglang/srt/layers/quantization/fp8.py +273 -62
  67. sglang/srt/layers/quantization/fp8_kernel.py +210 -46
  68. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  69. sglang/srt/layers/quantization/gptq.py +501 -143
  70. sglang/srt/layers/quantization/marlin_utils.py +790 -0
  71. sglang/srt/layers/quantization/modelopt_quant.py +34 -112
  72. sglang/srt/layers/quantization/moe_wna16.py +45 -49
  73. sglang/srt/layers/quantization/petit.py +252 -0
  74. sglang/srt/layers/quantization/petit_utils.py +104 -0
  75. sglang/srt/layers/quantization/qoq.py +7 -6
  76. sglang/srt/layers/quantization/scalar_type.py +352 -0
  77. sglang/srt/layers/quantization/unquant.py +422 -0
  78. sglang/srt/layers/quantization/utils.py +340 -9
  79. sglang/srt/layers/quantization/w4afp8.py +8 -4
  80. sglang/srt/layers/quantization/w8a8_fp8.py +17 -51
  81. sglang/srt/layers/quantization/w8a8_int8.py +51 -115
  82. sglang/srt/layers/radix_attention.py +5 -3
  83. sglang/srt/layers/vocab_parallel_embedding.py +1 -41
  84. sglang/srt/lora/lora.py +0 -4
  85. sglang/srt/lora/lora_manager.py +162 -164
  86. sglang/srt/lora/lora_registry.py +124 -0
  87. sglang/srt/lora/mem_pool.py +83 -35
  88. sglang/srt/lora/utils.py +12 -5
  89. sglang/srt/managers/cache_controller.py +288 -0
  90. sglang/srt/managers/io_struct.py +60 -30
  91. sglang/srt/managers/mm_utils.py +7 -8
  92. sglang/srt/managers/schedule_batch.py +163 -113
  93. sglang/srt/managers/schedule_policy.py +68 -27
  94. sglang/srt/managers/scheduler.py +256 -86
  95. sglang/srt/managers/scheduler_output_processor_mixin.py +22 -4
  96. sglang/srt/managers/tokenizer_manager.py +38 -27
  97. sglang/srt/managers/tp_worker.py +16 -4
  98. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  99. sglang/srt/mem_cache/allocator.py +74 -23
  100. sglang/srt/mem_cache/base_prefix_cache.py +14 -2
  101. sglang/srt/mem_cache/chunk_cache.py +5 -2
  102. sglang/srt/mem_cache/hicache_storage.py +168 -0
  103. sglang/srt/mem_cache/hiradix_cache.py +194 -5
  104. sglang/srt/mem_cache/memory_pool.py +16 -1
  105. sglang/srt/mem_cache/memory_pool_host.py +44 -2
  106. sglang/srt/mem_cache/radix_cache.py +26 -0
  107. sglang/srt/mem_cache/swa_radix_cache.py +1025 -0
  108. sglang/srt/metrics/collector.py +9 -0
  109. sglang/srt/model_executor/cuda_graph_runner.py +66 -31
  110. sglang/srt/model_executor/forward_batch_info.py +210 -25
  111. sglang/srt/model_executor/model_runner.py +147 -42
  112. sglang/srt/model_loader/loader.py +7 -1
  113. sglang/srt/model_loader/utils.py +4 -4
  114. sglang/srt/models/clip.py +1 -1
  115. sglang/srt/models/deepseek.py +9 -6
  116. sglang/srt/models/deepseek_janus_pro.py +1 -1
  117. sglang/srt/models/deepseek_v2.py +192 -173
  118. sglang/srt/models/deepseek_vl2.py +5 -5
  119. sglang/srt/models/gemma.py +48 -0
  120. sglang/srt/models/gemma2.py +52 -0
  121. sglang/srt/models/gemma3_causal.py +63 -0
  122. sglang/srt/models/gemma3_mm.py +1 -1
  123. sglang/srt/models/gemma3n_mm.py +2 -4
  124. sglang/srt/models/granitemoe.py +385 -0
  125. sglang/srt/models/grok.py +9 -3
  126. sglang/srt/models/hunyuan.py +63 -16
  127. sglang/srt/models/internvl.py +1 -1
  128. sglang/srt/models/kimi_vl.py +1 -1
  129. sglang/srt/models/llama.py +41 -0
  130. sglang/srt/models/llama4.py +11 -11
  131. sglang/srt/models/llava.py +2 -2
  132. sglang/srt/models/llavavid.py +1 -1
  133. sglang/srt/models/minicpm.py +0 -2
  134. sglang/srt/models/minicpmo.py +3 -7
  135. sglang/srt/models/minicpmv.py +1 -1
  136. sglang/srt/models/mistral.py +1 -1
  137. sglang/srt/models/mixtral.py +9 -2
  138. sglang/srt/models/mllama.py +3 -5
  139. sglang/srt/models/mllama4.py +13 -6
  140. sglang/srt/models/olmoe.py +8 -5
  141. sglang/srt/models/persimmon.py +330 -0
  142. sglang/srt/models/phi.py +321 -0
  143. sglang/srt/models/phi4mm.py +44 -4
  144. sglang/srt/models/phi4mm_audio.py +1260 -0
  145. sglang/srt/models/phi4mm_utils.py +1917 -0
  146. sglang/srt/models/phimoe.py +9 -3
  147. sglang/srt/models/qwen.py +37 -0
  148. sglang/srt/models/qwen2.py +41 -0
  149. sglang/srt/models/qwen2_5_vl.py +4 -4
  150. sglang/srt/models/qwen2_audio.py +1 -1
  151. sglang/srt/models/qwen2_moe.py +53 -9
  152. sglang/srt/models/qwen2_vl.py +4 -4
  153. sglang/srt/models/qwen3.py +65 -1
  154. sglang/srt/models/qwen3_moe.py +57 -24
  155. sglang/srt/models/vila.py +1 -1
  156. sglang/srt/multimodal/processors/base_processor.py +91 -97
  157. sglang/srt/multimodal/processors/clip.py +21 -19
  158. sglang/srt/multimodal/processors/deepseek_vl_v2.py +8 -26
  159. sglang/srt/multimodal/processors/gemma3.py +13 -17
  160. sglang/srt/multimodal/processors/gemma3n.py +19 -23
  161. sglang/srt/multimodal/processors/internvl.py +9 -10
  162. sglang/srt/multimodal/processors/janus_pro.py +12 -27
  163. sglang/srt/multimodal/processors/kimi_vl.py +12 -14
  164. sglang/srt/multimodal/processors/llava.py +4 -2
  165. sglang/srt/multimodal/processors/minicpm.py +35 -44
  166. sglang/srt/multimodal/processors/mlama.py +21 -18
  167. sglang/srt/multimodal/processors/mllama4.py +4 -5
  168. sglang/srt/multimodal/processors/phi4mm.py +63 -39
  169. sglang/srt/multimodal/processors/pixtral.py +14 -35
  170. sglang/srt/multimodal/processors/qwen_audio.py +65 -0
  171. sglang/srt/multimodal/processors/qwen_vl.py +16 -21
  172. sglang/srt/multimodal/processors/vila.py +14 -14
  173. sglang/srt/reasoning_parser.py +46 -4
  174. sglang/srt/sampling/sampling_batch_info.py +6 -5
  175. sglang/srt/sampling/sampling_params.py +8 -1
  176. sglang/srt/server_args.py +454 -270
  177. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +33 -28
  178. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +46 -37
  179. sglang/srt/speculative/eagle_utils.py +51 -23
  180. sglang/srt/speculative/eagle_worker.py +59 -44
  181. sglang/srt/two_batch_overlap.py +10 -5
  182. sglang/srt/utils.py +44 -69
  183. sglang/test/runners.py +14 -3
  184. sglang/test/test_activation.py +50 -1
  185. sglang/test/test_block_fp8.py +8 -3
  186. sglang/test/test_block_fp8_ep.py +1 -1
  187. sglang/test/test_custom_ops.py +12 -7
  188. sglang/test/test_cutlass_w4a8_moe.py +1 -3
  189. sglang/test/test_fp4_moe.py +1 -3
  190. sglang/test/test_marlin_moe.py +286 -0
  191. sglang/test/test_marlin_utils.py +171 -0
  192. sglang/test/test_utils.py +35 -0
  193. sglang/version.py +1 -1
  194. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/METADATA +10 -10
  195. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/RECORD +198 -175
  196. sglang/srt/layers/quantization/quant_utils.py +0 -166
  197. sglang/srt/managers/multimodal_processors/qwen_audio.py +0 -94
  198. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/WHEEL +0 -0
  199. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/licenses/LICENSE +0 -0
  200. {sglang-0.4.9.post2.dist-info → sglang-0.4.9.post4.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/awq_triton.py

@@ -0,0 +1,339 @@
+ # Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/layers/quantization/awq_triton.py
+
+ # SPDX-License-Identifier: Apache-2.0
+ # SPDX-FileCopyrightText: Copyright contributors to the vLLM project
+
+ import torch
+ import triton
+ import triton.language as tl
+
+ AWQ_TRITON_SUPPORTED_GROUP_SIZES = [-1, 32, 64, 128]
+
+
+ @triton.jit
+ def awq_dequantize_kernel(
+     qweight_ptr,  # quantized matrix
+     scales_ptr,  # scales, per group
+     zeros_ptr,  # zeros, per group
+     group_size,  # Should always be one of the supported group sizes
+     result_ptr,  # Output matrix
+     num_cols,  # input num cols in qweight
+     num_rows,  # input num rows in qweight
+     BLOCK_SIZE_X: tl.constexpr,
+     BLOCK_SIZE_Y: tl.constexpr,
+ ):
+     # Setup the pids.
+     pid_x = tl.program_id(axis=0)
+     pid_y = tl.program_id(axis=1)
+
+     # Compute offsets and masks for qweight_ptr.
+     offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y)
+     offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X)
+     offsets = num_cols * offsets_y[:, None] + offsets_x[None, :]
+
+     masks_y = offsets_y < num_rows
+     masks_x = offsets_x < num_cols
+
+     masks = masks_y[:, None] & masks_x[None, :]
+
+     # Compute offsets and masks for result output ptr.
+     result_offsets_y = pid_y * BLOCK_SIZE_Y + tl.arange(0, BLOCK_SIZE_Y)
+     result_offsets_x = pid_x * BLOCK_SIZE_X * 8 + tl.arange(0, BLOCK_SIZE_X * 8)
+     result_offsets = (
+         8 * num_cols * result_offsets_y[:, None] + result_offsets_x[None, :]
+     )
+
+     result_masks_y = result_offsets_y < num_rows
+     result_masks_x = result_offsets_x < num_cols * 8
+     result_masks = result_masks_y[:, None] & result_masks_x[None, :]
+
+     # Load the weights.
+     iweights = tl.load(qweight_ptr + offsets, masks, 0.0)
+     iweights = tl.interleave(iweights, iweights)
+     iweights = tl.interleave(iweights, iweights)
+     iweights = tl.interleave(iweights, iweights)
+
+     # Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7]
+     # that will map given indices to the correct order.
+     reverse_awq_order_tensor = (
+         (tl.arange(0, 2) * 4)[None, :] + tl.arange(0, 4)[:, None]
+     ).reshape(8)
+
+     # Use this to compute a set of shifts that can be used to unpack and
+     # reorder the values in iweights and zeros.
+     shifts = reverse_awq_order_tensor * 4
+     shifts = tl.broadcast_to(shifts[None, :], (BLOCK_SIZE_Y * BLOCK_SIZE_X, 8))
+     shifts = tl.reshape(shifts, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8))
+
+     # Unpack and reorder: shift out the correct 4-bit value and mask.
+     iweights = (iweights >> shifts) & 0xF
+
+     # Compute zero offsets and masks.
+     zero_offsets_y = pid_y * BLOCK_SIZE_Y // group_size + tl.arange(0, 1)
+     zero_offsets_x = pid_x * BLOCK_SIZE_X + tl.arange(0, BLOCK_SIZE_X)
+     zero_offsets = num_cols * zero_offsets_y[:, None] + zero_offsets_x[None, :]
+
+     zero_masks_y = zero_offsets_y < num_rows // group_size
+     zero_masks_x = zero_offsets_x < num_cols
+     zero_masks = zero_masks_y[:, None] & zero_masks_x[None, :]
+
+     # Load the zeros.
+     zeros = tl.load(zeros_ptr + zero_offsets, zero_masks, 0.0)
+     zeros = tl.interleave(zeros, zeros)
+     zeros = tl.interleave(zeros, zeros)
+     zeros = tl.interleave(zeros, zeros)
+     zeros = tl.broadcast_to(zeros, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8))
+
+     # Unpack and reorder: shift out the correct 4-bit value and mask.
+     zeros = (zeros >> shifts) & 0xF
+
+     # Compute scale offsets and masks.
+     scale_offsets_y = pid_y * BLOCK_SIZE_Y // group_size + tl.arange(0, 1)
+     scale_offsets_x = pid_x * BLOCK_SIZE_X * 8 + tl.arange(0, BLOCK_SIZE_X * 8)
+     scale_offsets = num_cols * 8 * scale_offsets_y[:, None] + scale_offsets_x[None, :]
+     scale_masks_y = scale_offsets_y < num_rows // group_size
+     scale_masks_x = scale_offsets_x < num_cols * 8
+     scale_masks = scale_masks_y[:, None] & scale_masks_x[None, :]
+
+     # Load the scales.
+     scales = tl.load(scales_ptr + scale_offsets, scale_masks, 0.0)
+     scales = tl.broadcast_to(scales, (BLOCK_SIZE_Y, BLOCK_SIZE_X * 8))
+
+     # Dequantize.
+     iweights = (iweights - zeros) * scales
+     iweights = iweights.to(result_ptr.type.element_ty)
+
+     # Finally, store.
+     tl.store(result_ptr + result_offsets, iweights, result_masks)
+
+
+ @triton.jit
+ def awq_gemm_kernel(
+     a_ptr,
+     b_ptr,
+     c_ptr,
+     zeros_ptr,
+     scales_ptr,
+     M,
+     N,
+     K,
+     group_size,
+     BLOCK_SIZE_M: tl.constexpr,
+     BLOCK_SIZE_N: tl.constexpr,
+     BLOCK_SIZE_K: tl.constexpr,
+     SPLIT_K: tl.constexpr,
+ ):
+     pid = tl.program_id(axis=0)
+     pid_z = tl.program_id(1)
+
+     # NOTE: This doesn't work in TRITON_INTERPRET=1 mode. Use below instead.
+     # num_pid_n = (N + BLOCK_SIZE_N - 1) // BLOCK_SIZE_N
+     num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+
+     pid_m = pid // num_pid_n
+     pid_n = pid % num_pid_n
+
+     accumulator_dtype = c_ptr.type.element_ty
+
+     # NOTE: This doesn't work in TRITON_INTERPRET=1 mode. Use below instead.
+     # accumulator = tl.arange(0, BLOCK_SIZE_N)
+     # accumulator = tl.broadcast_to(accumulator[None, :],
+     #                               (BLOCK_SIZE_M, BLOCK_SIZE_N))
+     # accumulator = accumulator & 0x0
+     # accumulator = accumulator.to(accumulator_dtype)
+     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=accumulator_dtype)
+
+     # Create reverse AWQ order as tensor: [0, 4, 1, 5, 2, 6, 3, 7]
+     # that will map given indices to the correct order.
+     reverse_awq_order_tensor = (
+         (tl.arange(0, 2) * 4)[None, :] + tl.arange(0, 4)[:, None]
+     ).reshape(8)
+
+     # Create the necessary shifts to use to unpack.
+     shifts = reverse_awq_order_tensor * 4
+     shifts = tl.broadcast_to(shifts[None, :], (BLOCK_SIZE_K * (BLOCK_SIZE_N // 8), 8))
+     shifts = tl.reshape(shifts, (BLOCK_SIZE_K, BLOCK_SIZE_N))
+
+     # Offsets and masks.
+     offsets_am = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+     masks_am = offsets_am < M
+
+     offsets_bn = pid_n * (BLOCK_SIZE_N // 8) + tl.arange(0, BLOCK_SIZE_N // 8)
+     masks_bn = offsets_bn < N // 8
+
+     offsets_zn = pid_n * (BLOCK_SIZE_N // 8) + tl.arange(0, BLOCK_SIZE_N // 8)
+     masks_zn = offsets_zn < N // 8
+
+     offsets_sn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+     masks_sn = offsets_sn < N
+
+     offsets_k = pid_z * BLOCK_SIZE_K + tl.arange(0, BLOCK_SIZE_K)
+     offsets_a = K * offsets_am[:, None] + offsets_k[None, :]
+     offsets_b = (N // 8) * offsets_k[:, None] + offsets_bn[None, :]
+
+     a_ptrs = a_ptr + offsets_a
+     b_ptrs = b_ptr + offsets_b
+
+     # NOTE: Use this in TRITON_INTERPRET=1 mode instead of tl.cdiv
+     # block_offset = BLOCK_SIZE_K * SPLIT_K
+     # for k in range(0, (K + block_offset - 1) // (block_offset)):
+     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K * SPLIT_K)):
+         masks_k = offsets_k < K
+         masks_a = masks_am[:, None] & masks_k[None, :]
+         a = tl.load(a_ptrs, mask=masks_a, other=0.0)
+
+         masks_b = masks_k[:, None] & masks_bn[None, :]
+         b = tl.load(b_ptrs, mask=masks_b, other=0.0)
+         b = tl.interleave(b, b)
+         b = tl.interleave(b, b)
+         b = tl.interleave(b, b)
+
+         # Dequantize b.
+         offsets_szk = (
+             BLOCK_SIZE_K * SPLIT_K * k + pid_z * BLOCK_SIZE_K
+         ) // group_size + tl.arange(0, 1)
+         offsets_z = (N // 8) * offsets_szk[:, None] + offsets_zn[None, :]
+         masks_zk = offsets_szk < K // group_size
+         masks_z = masks_zk[:, None] & masks_zn[None, :]
+         zeros_ptrs = zeros_ptr + offsets_z
+         zeros = tl.load(zeros_ptrs, mask=masks_z, other=0.0)
+         zeros = tl.interleave(zeros, zeros)
+         zeros = tl.interleave(zeros, zeros)
+         zeros = tl.interleave(zeros, zeros)
+         zeros = tl.broadcast_to(zeros, (BLOCK_SIZE_K, BLOCK_SIZE_N))
+
+         offsets_s = N * offsets_szk[:, None] + offsets_sn[None, :]
+         masks_sk = offsets_szk < K // group_size
+         masks_s = masks_sk[:, None] & masks_sn[None, :]
+         scales_ptrs = scales_ptr + offsets_s
+         scales = tl.load(scales_ptrs, mask=masks_s, other=0.0)
+         scales = tl.broadcast_to(scales, (BLOCK_SIZE_K, BLOCK_SIZE_N))
+
+         b = (b >> shifts) & 0xF
+         zeros = (zeros >> shifts) & 0xF
+         b = (b - zeros) * scales
+         b = b.to(c_ptr.type.element_ty)
+
+         # Accumulate results.
+         accumulator = tl.dot(a, b, accumulator, out_dtype=accumulator_dtype)
+
+         offsets_k += BLOCK_SIZE_K * SPLIT_K
+         a_ptrs += BLOCK_SIZE_K * SPLIT_K
+         b_ptrs += BLOCK_SIZE_K * SPLIT_K * (N // 8)
+
+     c = accumulator.to(c_ptr.type.element_ty)
+     offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+     c_ptrs = c_ptr + pid_z * N * M + N * offs_cm[:, None] + offs_cn[None, :]
+     c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+     tl.store(c_ptrs, c, mask=c_mask)
+
+
+ # qweights - [K     , M // 8], int32
+ # scales   - [K // G, M     ], float16
+ # zeros    - [K // G, M // 8], int32
+ def awq_dequantize_triton(
+     qweight: torch.Tensor,
+     scales: torch.Tensor,
+     zeros: torch.Tensor,
+     block_size_x: int = 32,
+     block_size_y: int = 32,
+ ) -> torch.Tensor:
+     K = qweight.shape[0]
+     M = scales.shape[1]
+     group_size = qweight.shape[0] // scales.shape[0]
+
+     assert K > 0 and M > 0
+     assert scales.shape[0] == K // group_size and scales.shape[1] == M
+     assert zeros.shape[0] == K // group_size and zeros.shape[1] == M // 8
+     assert group_size <= K
+     assert group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES or group_size == K
+
+     # Result tensor:
+     # number of rows = same as input tensor
+     # number of cols = 8 x input tensor num cols
+     result = torch.empty(
+         qweight.shape[0],
+         qweight.shape[1] * 8,
+         device=qweight.device,
+         dtype=scales.dtype,
+     )
+
+     Y = qweight.shape[0]  # num rows
+     X = qweight.shape[1]  # num cols
+
+     grid = lambda META: (
+         triton.cdiv(X, META["BLOCK_SIZE_X"]),
+         triton.cdiv(Y, META["BLOCK_SIZE_Y"]),
+     )
+     awq_dequantize_kernel[grid](
+         qweight,
+         scales,
+         zeros,
+         group_size,
+         result,
+         X,
+         Y,
+         BLOCK_SIZE_X=block_size_x,
+         BLOCK_SIZE_Y=block_size_y,
+     )
+
+     return result
+
+
+ # input   - [M, K]
+ # qweight - [K, N // 8]
+ # qzeros  - [K // G, N // 8]
+ # scales  - [K // G, N]
+ # split_k_iters - parallelism along K-dimension, int, power of 2.
+ def awq_gemm_triton(
+     input: torch.Tensor,
+     qweight: torch.Tensor,
+     scales: torch.Tensor,
+     qzeros: torch.Tensor,
+     split_k_iters: int,
+     block_size_m: int = 32,
+     block_size_n: int = 32,
+     block_size_k: int = 32,
+ ) -> torch.Tensor:
+     M, K = input.shape
+     N = qweight.shape[1] * 8
+     group_size = qweight.shape[0] // qzeros.shape[0]
+
+     assert N > 0 and K > 0 and M > 0
+     assert qweight.shape[0] == K and qweight.shape[1] == N // 8
+     assert qzeros.shape[0] == K // group_size and qzeros.shape[1] == N // 8
+     assert scales.shape[0] == K // group_size and scales.shape[1] == N
+     assert split_k_iters & (split_k_iters - 1) == 0 and split_k_iters != 0
+     assert split_k_iters <= 32
+     assert group_size <= K
+     assert group_size in AWQ_TRITON_SUPPORTED_GROUP_SIZES or group_size == K
+
+     grid = lambda META: (
+         triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
+         split_k_iters,
+     )
+
+     result = torch.zeros((split_k_iters, M, N), dtype=scales.dtype, device=input.device)
+
+     # A = input, B = qweight, C = result
+     # A = M x K, B = K x N, C = M x N
+     awq_gemm_kernel[grid](
+         input,
+         qweight,
+         result,
+         qzeros,
+         scales,
+         M,
+         N,
+         K,
+         group_size,
+         BLOCK_SIZE_M=block_size_m,
+         BLOCK_SIZE_N=block_size_n,
+         BLOCK_SIZE_K=block_size_k,
+         SPLIT_K=split_k_iters,
+     )
+
+     result = result.sum(0)
+
+     return result
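A usage sketch (not part of the diff) of how the new Triton AWQ helpers might be invoked; the tensor shapes follow the comments above on awq_dequantize_triton and awq_gemm_triton, while the concrete sizes (K, N, group_size, batch size) are illustrative assumptions:

    import torch

    from sglang.srt.layers.quantization.awq_triton import (
        awq_dequantize_triton,
        awq_gemm_triton,
    )

    # Illustrative AWQ shapes: 4-bit weights packed 8 per int32, group size 128.
    K, N, group_size = 4096, 4096, 128
    qweight = torch.randint(0, 2**31 - 1, (K, N // 8), dtype=torch.int32, device="cuda")
    qzeros = torch.randint(0, 2**31 - 1, (K // group_size, N // 8), dtype=torch.int32, device="cuda")
    scales = torch.rand(K // group_size, N, dtype=torch.float16, device="cuda")
    x = torch.rand(8, K, dtype=torch.float16, device="cuda")

    # Either materialize the dense weight, or run the fused split-K GEMM directly.
    w = awq_dequantize_triton(qweight, scales, qzeros)                # [K, N], float16
    y = awq_gemm_triton(x, qweight, scales, qzeros, split_k_iters=8)  # [8, N], float16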
sglang/srt/layers/quantization/base_config.py

@@ -1,12 +1,16 @@
  # Adapted from https://raw.githubusercontent.com/vllm-project/vllm/v0.5.5/vllm/model_executor/layers/quantization/base_config.py
+ from __future__ import annotations

  import inspect
  from abc import ABC, abstractmethod
- from typing import Any, Dict, List, Optional, Type
+ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Type

  import torch
  from torch import nn

+ if TYPE_CHECKING:
+     from sglang.srt.layers.moe.topk import TopKOutput
+

  class QuantizeMethodBase(ABC):
      """Base class for different quantized methods."""
@@ -18,14 +22,14 @@ class QuantizeMethodBase(ABC):
          """Create weights for a layer.

          The weights will be set as attributes of the layer."""
-         raise NotImplementedError
+         raise NotImplementedError()

      @abstractmethod
      def apply(self, layer: torch.nn.Module, *args, **kwargs) -> torch.Tensor:
          """Apply the weights in layer to the input tensor.

          Expects create_weights to have been called before on the layer."""
-         raise NotImplementedError
+         raise NotImplementedError()

      def process_weights_after_loading(self, layer: nn.Module) -> None:
          """Process the weight after loading.
@@ -35,6 +39,77 @@ class QuantizeMethodBase(ABC):
          return


+ class LinearMethodBase(QuantizeMethodBase):
+     """Base class for different (maybe quantized) linear methods."""
+
+     @abstractmethod
+     def create_weights(
+         self,
+         layer: torch.nn.Module,
+         input_size_per_partition: int,
+         output_partition_sizes: List[int],
+         input_size: int,
+         output_size: int,
+         params_dtype: torch.dtype,
+         **extra_weight_attrs,
+     ):
+         """Create weights for a linear layer.
+         The weights will be set as attributes of the layer.
+
+         Args:
+             layer: The layer that is using the LinearMethodBase factory.
+             input_size_per_partition: Size of the weight input dim on rank X.
+             output_partition_sizes: Sizes of the output dim of each logical
+                 weight on rank X. E.g., output_partition_sizes for QKVLinear
+                 is a list contains the width of Wq, Wk, Wv on rank X.
+             input_size: Size of the input dim of the weight across all ranks.
+             output_size: Size of the output dim of the weight across all ranks.
+             params_dtype: Datatype of the parameters.
+         """
+         raise NotImplementedError()
+
+     @abstractmethod
+     def apply(
+         self,
+         layer: torch.nn.Module,
+         x: torch.Tensor,
+         bias: Optional[torch.Tensor] = None,
+     ) -> torch.Tensor:
+         """Apply the weights in layer to the input tensor.
+         Expects create_weights to have been called before on the layer."""
+         raise NotImplementedError()
+
+
+ class FusedMoEMethodBase(QuantizeMethodBase):
+
+     @abstractmethod
+     def create_weights(
+         self,
+         layer: torch.nn.Module,
+         num_experts: int,
+         hidden_size: int,
+         intermediate_size: int,
+         params_dtype: torch.dtype,
+         **extra_weight_attrs,
+     ):
+         raise NotImplementedError
+
+     @abstractmethod
+     def apply(
+         self,
+         layer: torch.nn.Module,
+         x: torch.Tensor,
+         topk_output: TopKOutput,
+         *,
+         activation: str = "silu",
+         apply_router_weight_on_input: bool = False,
+         inplace: bool = True,
+         no_combine: bool = False,
+         routed_scaling_factor: Optional[float] = None,
+     ) -> torch.Tensor:
+         raise NotImplementedError
+
+
  class QuantizationConfig(ABC):
      """Base class for quantization configs."""

@@ -46,12 +121,12 @@ class QuantizationConfig(ABC):
      @abstractmethod
      def get_name(self) -> str:
          """Name of the quantization method."""
-         raise NotImplementedError
+         raise NotImplementedError()

      @abstractmethod
      def get_supported_act_dtypes(self) -> List[torch.dtype]:
          """List of supported activation dtypes."""
-         raise NotImplementedError
+         raise NotImplementedError()

      @classmethod
      @abstractmethod
@@ -62,19 +137,19 @@ class QuantizationConfig(ABC):
          This requirement is due to the custom CUDA kernels used by the
          quantization method.
          """
-         raise NotImplementedError
+         raise NotImplementedError()

      @staticmethod
      @abstractmethod
      def get_config_filenames() -> List[str]:
          """List of filenames to search for in the model directory."""
-         raise NotImplementedError
+         raise NotImplementedError()

      @classmethod
      @abstractmethod
      def from_config(cls, config: Dict[str, Any]) -> "QuantizationConfig":
          """Create a config class from the model's quantization config."""
-         raise NotImplementedError
+         raise NotImplementedError()

      @classmethod
      def override_quantization_method(cls, hf_quant_cfg, user_quant) -> Optional[str]:
@@ -117,7 +192,7 @@ class QuantizationConfig(ABC):
          The quantize method. None if the given layer doesn't support quant
          method.
          """
-         raise NotImplementedError
+         raise NotImplementedError()

      @abstractmethod
      def get_scaled_act_names(self) -> List[str]:
@@ -125,7 +200,7 @@ class QuantizationConfig(ABC):

          For now, this is only used by AWQ.
          """
-         raise NotImplementedError
+         raise NotImplementedError()


  def method_has_implemented_embedding(method_class: Type[QuantizeMethodBase]) -> bool:
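As this diff shows, base_config.py now hosts LinearMethodBase and FusedMoEMethodBase (previously imported from sglang.srt.layers.linear and sglang.srt.layers.moe.fused_moe_triton), and the MoE apply() signature takes a pre-computed TopKOutput instead of raw router logits plus routing flags. A minimal sketch (not from the diff; the class name is a hypothetical example) of a method written against the new interface:

    import torch

    from sglang.srt.layers.quantization.base_config import FusedMoEMethodBase


    class PassthroughMoEMethod(FusedMoEMethodBase):
        def create_weights(self, layer, num_experts, hidden_size,
                           intermediate_size, params_dtype, **extra_weight_attrs):
            # Real implementations register w13_weight / w2_weight tensors on `layer`.
            pass

        def apply(self, layer, x, topk_output, *, activation="silu",
                  apply_router_weight_on_input=False, inplace=True,
                  no_combine=False, routed_scaling_factor=None) -> torch.Tensor:
            # Routing arrives pre-computed in `topk_output`; implementations no longer
            # call select_experts() themselves (compare BlockInt8MoEMethod below).
            return x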
sglang/srt/layers/quantization/blockwise_int8.py

@@ -1,26 +1,29 @@
  # Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/quantization/fp8.py

+ from __future__ import annotations
+
  import logging
- from typing import Any, Callable, Dict, List, Optional
+ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional

  import torch
  from torch.nn import Module

  from sglang.srt.distributed import get_tensor_model_parallel_world_size
- from sglang.srt.layers.linear import (
-     LinearBase,
-     LinearMethodBase,
-     UnquantizedLinearMethod,
- )
  from sglang.srt.layers.parameter import BlockQuantScaleParameter, ModelWeightParameter
  from sglang.srt.layers.quantization.base_config import (
+     FusedMoEMethodBase,
+     LinearMethodBase,
      QuantizationConfig,
      QuantizeMethodBase,
  )
  from sglang.srt.layers.quantization.int8_utils import apply_w8a8_block_int8_linear
+ from sglang.srt.layers.quantization.unquant import UnquantizedLinearMethod
  from sglang.srt.layers.quantization.utils import is_layer_skipped
  from sglang.srt.utils import set_weight_attrs

+ if TYPE_CHECKING:
+     from sglang.srt.layers.moe.topk import TopKOutput
+
  ACTIVATION_SCHEMES = ["static", "dynamic"]

  logger = logging.getLogger(__name__)
@@ -78,7 +81,7 @@ class BlockInt8Config(QuantizationConfig):
          return []

      @classmethod
-     def from_config(cls, config: Dict[str, Any]) -> "BlockInt8Config":
+     def from_config(cls, config: Dict[str, Any]) -> BlockInt8Config:
          quant_method = cls.get_from_keys(config, ["quant_method"])
          is_checkpoint_int8_serialized = "int8" in quant_method
          activation_scheme = cls.get_from_keys(config, ["activation_scheme"])
@@ -93,7 +96,8 @@ class BlockInt8Config(QuantizationConfig):

      def get_quant_method(
          self, layer: torch.nn.Module, prefix: str
-     ) -> Optional["QuantizeMethodBase"]:
+     ) -> Optional[QuantizeMethodBase]:
+         from sglang.srt.layers.linear import LinearBase
          from sglang.srt.layers.moe.fused_moe_triton import FusedMoE

          if isinstance(layer, LinearBase):
@@ -230,7 +234,7 @@ class BlockInt8LinearMethod(LinearMethodBase):
          )


- class BlockInt8MoEMethod:
+ class BlockInt8MoEMethod(FusedMoEMethodBase):
      """MoE method for INT8.
      Supports loading INT8 checkpoints with static weight scale and
      dynamic activation scale.
@@ -242,25 +246,7 @@ class BlockInt8MoEMethod:
          quant_config: The quantization config.
      """

-     def __new__(cls, *args, **kwargs):
-         from sglang.srt.layers.moe.fused_moe_triton import FusedMoEMethodBase
-
-         if not hasattr(cls, "_initialized"):
-             original_init = cls.__init__
-             new_cls = type(
-                 cls.__name__,
-                 (FusedMoEMethodBase,),
-                 {
-                     "__init__": original_init,
-                     **{k: v for k, v in cls.__dict__.items() if k != "__dict__"},
-                 },
-             )
-             obj = super(new_cls, new_cls).__new__(new_cls)
-             obj.__init__(*args, **kwargs)
-             return obj
-         return super().__new__(cls)
-
-     def __init__(self, quant_config):
+     def __init__(self, quant_config: BlockInt8Config):
          self.quant_config = quant_config
          assert self.quant_config.weight_block_size is not None
          assert self.quant_config.is_checkpoint_int8_serialized
@@ -361,15 +347,8 @@ class BlockInt8MoEMethod:
          self,
          layer: torch.nn.Module,
          x: torch.Tensor,
-         router_logits: torch.Tensor,
-         top_k: int,
-         renormalize: bool,
-         use_grouped_topk: bool,
-         topk_group: Optional[int] = None,
-         num_expert_group: Optional[int] = None,
-         num_fused_shared_experts: int = 0,
-         custom_routing_function: Optional[Callable] = None,
-         correction_bias: Optional[torch.Tensor] = None,
+         topk_output: TopKOutput,
+         *,
          activation: str = "silu",
          apply_router_weight_on_input: bool = False,
          inplace: bool = True,
@@ -377,30 +356,13 @@ class BlockInt8MoEMethod:
          routed_scaling_factor: Optional[float] = None,
      ) -> torch.Tensor:
          from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
-         from sglang.srt.layers.moe.topk import select_experts
-
-         # Expert selection
-         topk_weights, topk_ids = select_experts(
-             hidden_states=x,
-             router_logits=router_logits,
-             use_grouped_topk=use_grouped_topk,
-             top_k=top_k,
-             renormalize=renormalize,
-             topk_group=topk_group,
-             num_expert_group=num_expert_group,
-             num_fused_shared_experts=num_fused_shared_experts,
-             custom_routing_function=custom_routing_function,
-             correction_bias=correction_bias,
-             routed_scaling_factor=routed_scaling_factor,
-         )

          # Expert fusion with INT8 quantization
          return fused_experts(
              x,
              layer.w13_weight,
              layer.w2_weight,
-             topk_weights=topk_weights,
-             topk_ids=topk_ids,
+             topk_output=topk_output,
              inplace=inplace,
              activation=activation,
              apply_router_weight_on_input=apply_router_weight_on_input,