sglang 0.4.4.post3__py3-none-any.whl → 0.4.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99)
  1. sglang/bench_serving.py +49 -7
  2. sglang/lang/chat_template.py +24 -0
  3. sglang/srt/_custom_ops.py +59 -92
  4. sglang/srt/configs/model_config.py +5 -0
  5. sglang/srt/constrained/base_grammar_backend.py +5 -1
  6. sglang/srt/conversation.py +29 -4
  7. sglang/srt/custom_op.py +5 -0
  8. sglang/srt/distributed/device_communicators/custom_all_reduce.py +27 -79
  9. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  10. sglang/srt/entrypoints/engine.py +0 -5
  11. sglang/srt/layers/attention/flashattention_backend.py +678 -83
  12. sglang/srt/layers/attention/flashinfer_backend.py +5 -7
  13. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
  14. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  15. sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
  16. sglang/srt/layers/moe/ep_moe/layer.py +79 -80
  17. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
  18. sglang/srt/layers/moe/fused_moe_native.py +5 -0
  19. sglang/srt/layers/moe/fused_moe_triton/configs/E=128,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  20. sglang/srt/layers/moe/fused_moe_triton/configs/E=144,N=512,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  21. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  22. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=1024,device_name=NVIDIA_H200.json +146 -0
  23. sglang/srt/layers/moe/fused_moe_triton/configs/E=16,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  24. sglang/srt/layers/moe/fused_moe_triton/configs/E=20,N=2048,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  25. sglang/srt/layers/moe/fused_moe_triton/configs/E=24,N=1024,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  26. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
  27. sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  28. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  29. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +416 -50
  30. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  31. sglang/srt/layers/moe/topk.py +49 -3
  32. sglang/srt/layers/quantization/__init__.py +5 -1
  33. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  34. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
  35. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
  36. sglang/srt/layers/quantization/fp8.py +3 -1
  37. sglang/srt/layers/quantization/fp8_utils.py +1 -4
  38. sglang/srt/layers/quantization/moe_wna16.py +503 -0
  39. sglang/srt/layers/quantization/utils.py +1 -1
  40. sglang/srt/layers/quantization/w8a8_int8.py +2 -0
  41. sglang/srt/layers/radix_attention.py +2 -0
  42. sglang/srt/layers/rotary_embedding.py +63 -12
  43. sglang/srt/managers/cache_controller.py +34 -11
  44. sglang/srt/managers/mm_utils.py +202 -156
  45. sglang/srt/managers/multimodal_processor.py +0 -2
  46. sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
  47. sglang/srt/managers/multimodal_processors/clip.py +7 -26
  48. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
  49. sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
  50. sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
  51. sglang/srt/managers/multimodal_processors/llava.py +34 -14
  52. sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
  53. sglang/srt/managers/multimodal_processors/mlama.py +10 -23
  54. sglang/srt/managers/multimodal_processors/mllama4.py +161 -0
  55. sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
  56. sglang/srt/managers/schedule_batch.py +185 -128
  57. sglang/srt/managers/scheduler.py +4 -4
  58. sglang/srt/managers/tokenizer_manager.py +1 -1
  59. sglang/srt/managers/utils.py +1 -6
  60. sglang/srt/mem_cache/hiradix_cache.py +62 -52
  61. sglang/srt/mem_cache/memory_pool.py +72 -6
  62. sglang/srt/mem_cache/paged_allocator.py +39 -0
  63. sglang/srt/metrics/collector.py +23 -53
  64. sglang/srt/model_executor/cuda_graph_runner.py +8 -6
  65. sglang/srt/model_executor/forward_batch_info.py +10 -10
  66. sglang/srt/model_executor/model_runner.py +60 -57
  67. sglang/srt/model_loader/loader.py +8 -0
  68. sglang/srt/models/clip.py +12 -7
  69. sglang/srt/models/deepseek_janus_pro.py +10 -15
  70. sglang/srt/models/deepseek_v2.py +212 -121
  71. sglang/srt/models/deepseek_vl2.py +105 -104
  72. sglang/srt/models/gemma3_mm.py +14 -80
  73. sglang/srt/models/llama.py +16 -5
  74. sglang/srt/models/llama4.py +420 -0
  75. sglang/srt/models/llava.py +31 -19
  76. sglang/srt/models/llavavid.py +16 -7
  77. sglang/srt/models/minicpmo.py +63 -147
  78. sglang/srt/models/minicpmv.py +17 -27
  79. sglang/srt/models/mllama.py +29 -14
  80. sglang/srt/models/mllama4.py +154 -0
  81. sglang/srt/models/qwen2.py +9 -6
  82. sglang/srt/models/qwen2_5_vl.py +21 -31
  83. sglang/srt/models/qwen2_vl.py +20 -21
  84. sglang/srt/openai_api/adapter.py +18 -6
  85. sglang/srt/platforms/interface.py +371 -0
  86. sglang/srt/server_args.py +99 -14
  87. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
  88. sglang/srt/speculative/eagle_utils.py +140 -28
  89. sglang/srt/speculative/eagle_worker.py +93 -24
  90. sglang/srt/utils.py +104 -51
  91. sglang/test/test_custom_ops.py +55 -0
  92. sglang/test/test_utils.py +13 -26
  93. sglang/utils.py +2 -2
  94. sglang/version.py +1 -1
  95. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/METADATA +4 -3
  96. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/RECORD +99 -84
  97. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/WHEEL +0 -0
  98. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/licenses/LICENSE +0 -0
  99. {sglang-0.4.4.post3.dist-info → sglang-0.4.5.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/fused_moe_triton/fused_moe.py
@@ -13,11 +13,6 @@ import triton
 import triton.language as tl
 
 from sglang.srt.layers.moe.topk import select_experts
-from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
-from sglang.srt.layers.quantization.int8_kernel import (
-    per_token_group_quant_int8,
-    per_token_quant_int8,
-)
 from sglang.srt.utils import (
     direct_register_custom_op,
     get_bool_env_var,
@@ -42,9 +37,6 @@ if _is_cuda:
     from sgl_kernel import gelu_and_mul, silu_and_mul
 
     from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
-    from sglang.srt.layers.quantization.fp8_kernel import (
-        sglang_per_token_group_quant_fp8,
-    )
 else:
     from vllm import _custom_ops as vllm_ops
 
@@ -52,6 +44,257 @@ if _is_cuda or _is_hip:
     from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size
 
 
+@triton.jit
+def write_zeros_to_output(
+    c_ptr,
+    stride_cm,
+    stride_cn,
+    pid_n,
+    N,
+    offs_token,
+    token_mask,
+    BLOCK_SIZE_M,
+    BLOCK_SIZE_N,
+    compute_type,
+):
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=compute_type)
+    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
+    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
+    tl.store(c_ptrs, accumulator, mask=c_mask)
+
+
+@triton.jit
+def fused_moe_kernel_gptq_awq(
+    # Pointers to matrices
+    a_ptr,
+    b_ptr,
+    c_ptr,
+    b_scale_ptr,
+    b_zp_ptr,
+    topk_weights_ptr,
+    sorted_token_ids_ptr,
+    expert_ids_ptr,
+    num_tokens_post_padded_ptr,
+    # Matrix dimensions
+    N: tl.constexpr,
+    K: tl.constexpr,
+    EM,
+    num_valid_tokens,
+    # The stride variables represent how much to increase the ptr by when
+    # moving by 1 element in a particular dimension. E.g. `stride_am` is
+    # how much to increase `a_ptr` by to get the element one row down
+    # (A has M rows).
+    stride_am,
+    stride_ak,
+    stride_be,
+    stride_bk,
+    stride_bn,
+    stride_cm,
+    stride_cn,
+    stride_bse,
+    stride_bsk,
+    stride_bsn,
+    stride_bze,
+    stride_bzk,
+    stride_bzn,
+    group_size: tl.constexpr,
+    # Meta-parameters
+    BLOCK_SIZE_M: tl.constexpr,
+    BLOCK_SIZE_N: tl.constexpr,
+    BLOCK_SIZE_K: tl.constexpr,
+    GROUP_SIZE_M: tl.constexpr,
+    MUL_ROUTED_WEIGHT: tl.constexpr,
+    top_k: tl.constexpr,
+    compute_type: tl.constexpr,
+    has_zp: tl.constexpr,
+    use_int4_w4a16: tl.constexpr,
+    use_int8_w8a16: tl.constexpr,
+    even_Ks: tl.constexpr,
+):
+    """
+    Implements the fused computation for a Mixture of Experts (MOE) using
+    token and expert matrices.
+    Key Parameters:
+    - A: The input tensor representing tokens with shape (*, K), where '*' can
+        be any shape representing batches and K is the feature dimension of
+        each token.
+    - B: The stacked MOE weight tensor with shape (E, N, K), where E is
+        the number of experts, K is the input feature dimension, and N is
+        the output feature dimension.
+    - C: The output cache tensor with shape (M, topk, N), where M is the
+        total number of tokens post padding, topk is the number of times
+        each token is repeated, and N is the output feature dimension.
+    - sorted_token_ids: A tensor containing the sorted indices of tokens,
+        repeated topk times and arranged by the expert index they are
+        assigned to.
+    - expert_ids: A tensor containing the indices of the expert for each
+        block. It determines which expert matrix from B should be used for
+        each block in A.
+    This kernel performs the multiplication of a token by its corresponding
+    expert matrix as determined by `expert_ids`. The sorting of
+    `sorted_token_ids` by expert index and padding ensures divisibility by
+    BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
+    multiplication across different blocks processed by the same expert.
+    """
+    # -----------------------------------------------------------
+    # Map program ids `pid` to the block of C it should compute.
+    # This is done in a grouped ordering to promote L2 data reuse.
+    pid = tl.program_id(axis=0)
+    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
+    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+    num_pid_in_group = GROUP_SIZE_M * num_pid_n
+    group_id = pid // num_pid_in_group
+    first_pid_m = group_id * GROUP_SIZE_M
+    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
+    pid_n = (pid % num_pid_in_group) // group_size_m
+
+    # ----------------------------------------------------------
+    # Create pointers for the first blocks of A and B.
+    # We will advance this pointer as we move in the K direction
+    # and accumulate
+    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
+    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
+    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
+    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
+        return
+    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
+    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
+    token_mask = offs_token < num_valid_tokens
+
+    off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
+    if off_experts == -1:
+        # -----------------------------------------------------------
+        # Write back zeros to the output when the expert is not
+        # in the current expert parallel rank.
+        write_zeros_to_output(
+            c_ptr,
+            stride_cm,
+            stride_cn,
+            pid_n,
+            N,
+            offs_token,
+            token_mask,
+            BLOCK_SIZE_M,
+            BLOCK_SIZE_N,
+            compute_type,
+        )
+        return
+
+    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
+    offs_k = tl.arange(0, BLOCK_SIZE_K)
+    a_ptrs = a_ptr + (
+        offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
+    )
+
+    if use_int4_w4a16:
+        b_ptrs = (
+            b_ptr
+            + off_experts * stride_be
+            + (offs_k[:, None] // 2) * stride_bk
+            + offs_bn[None, :] * stride_bn
+        )
+        b_shifter = (offs_k[:, None] % 2) * 4
+    elif use_int8_w8a16:
+        b_ptrs = (
+            b_ptr
+            + off_experts * stride_be
+            + offs_k[:, None] * stride_bk
+            + offs_bn[None, :] * stride_bn
+        )
+
+    if not has_zp and use_int4_w4a16:
+        b_zp_num = 8
+    if not has_zp and use_int8_w8a16:
+        b_zp_num = 128
+    elif has_zp and use_int4_w4a16:
+        b_zp_shifter = (offs_bn[None, :] % 2) * 4
+
+    # -----------------------------------------------------------
+    # Iterate to compute a block of the C matrix.
+    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
+    # of fp32 values for higher accuracy.
+    # `accumulator` will be converted back to fp16 after the loop.
+    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+        # Load the next block of A and B, generate a mask by checking the
+        # K dimension.
+
+        if not even_Ks:
+            k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K
+            k_other = 0.0
+        else:
+            k_mask = None
+            k_other = None
+
+        a = tl.load(
+            a_ptrs,
+            mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
+            other=0.0,
+        )
+        b = tl.load(b_ptrs)
+        if use_int4_w4a16:
+            b = (b >> b_shifter) & 0xF
+
+        b_scale_ptrs = (
+            b_scale_ptr
+            + off_experts * stride_bse
+            + offs_bn[None, :] * stride_bsn
+            + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk
+        )
+        b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other)
+        b_scale = b_scale.to(tl.float32)
+
+        if has_zp and use_int4_w4a16:
+            offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
+            b_zp_ptrs = (
+                b_zp_ptr
+                + off_experts * stride_bze
+                + (offs_bn[None, :] // 2) * stride_bzn
+                + offs_k_true * stride_bzk
+            )
+            b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
+            b_zp = (b_zp >> b_zp_shifter) & 0xF
+            b_zp = b_zp.to(tl.float32)
+        elif has_zp and use_int8_w8a16:
+            offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
+            b_zp_ptrs = (
+                b_zp_ptr
+                + off_experts * stride_bze
+                + offs_bn[None, :] * stride_bzn
+                + offs_k_true * stride_bzk
+            )
+            b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
+            b_zp = b_zp.to(tl.float32)
+
+        # We accumulate along the K dimension.
+        if has_zp:
+            b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type)
+        else:
+            b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type)
+        accumulator = tl.dot(a, b, acc=accumulator)
+
+        # Advance the ptrs to the next K block.
+        a_ptrs += BLOCK_SIZE_K * stride_ak
+        if use_int4_w4a16:
+            b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk
+        else:
+            b_ptrs += BLOCK_SIZE_K * stride_bk
+
+    if MUL_ROUTED_WEIGHT:
+        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
+        accumulator = accumulator * moe_weight[:, None]
+
+    accumulator = accumulator.to(compute_type)
+    # -----------------------------------------------------------
+    # Write back the block of the output
+    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
+    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
+    tl.store(c_ptrs, accumulator, mask=c_mask)
+
+
 @triton.jit
 def fused_moe_kernel(
     # Pointers to matrices
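Editor's note: the new fused_moe_kernel_gptq_awq reuses the grouped program-id ordering of fused_moe_kernel and, when expert_ids holds -1 (the expert lives on another expert-parallel rank), writes zeros via write_zeros_to_output instead of computing. A host-side sketch of the grouped pid -> (pid_m, pid_n) mapping, with a hypothetical helper name and illustrative values, not taken from the package:

    import math

    def map_pid_to_tile(pid: int, EM: int, N: int,
                        BLOCK_SIZE_M: int, BLOCK_SIZE_N: int, GROUP_SIZE_M: int):
        # Mirrors the arithmetic at the top of the Triton kernel above.
        num_pid_m = math.ceil(EM / BLOCK_SIZE_M)   # tiles along the padded token dim
        num_pid_n = math.ceil(N / BLOCK_SIZE_N)    # tiles along the output features
        num_pid_in_group = GROUP_SIZE_M * num_pid_n
        group_id = pid // num_pid_in_group
        first_pid_m = group_id * GROUP_SIZE_M
        group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
        pid_m = first_pid_m + (pid % num_pid_in_group) % group_size_m
        pid_n = (pid % num_pid_in_group) // group_size_m
        return pid_m, pid_n

    # e.g. EM=256, N=512, BLOCK_SIZE_M=BLOCK_SIZE_N=64, GROUP_SIZE_M=2: consecutive
    # pids first walk two M-tiles with the same pid_n, so the same weight tile is
    # reused from L2 before moving to the next column of the output.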
@@ -152,6 +395,7 @@ def fused_moe_kernel(
         return
     offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
     offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
+    offs_token = offs_token.to(tl.int64)
     token_mask = offs_token < num_valid_tokens
 
     offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
@@ -495,6 +739,7 @@ def invoke_fused_moe_kernel(
     C: torch.Tensor,
     A_scale: Optional[torch.Tensor],
     B_scale: Optional[torch.Tensor],
+    B_zp: Optional[torch.Tensor],
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     sorted_token_ids: torch.Tensor,
@@ -507,9 +752,20 @@ def invoke_fused_moe_kernel(
     use_fp8_w8a8: bool,
     use_int8_w8a8: bool,
     use_int8_w8a16: bool,
+    use_int4_w4a16: bool,
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
 ) -> None:
+    from sglang.srt.layers.quantization.int8_kernel import (
+        per_token_group_quant_int8,
+        per_token_quant_int8,
+    )
+
+    if _is_cuda:
+        from sglang.srt.layers.quantization.fp8_kernel import (
+            sglang_per_token_group_quant_fp8,
+        )
+
     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1
 
@@ -547,8 +803,9 @@ def invoke_fused_moe_kernel(
             assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
             assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
             assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
-    elif use_int8_w8a16:
+    elif use_int8_w8a16 or use_int4_w4a16:
         assert B_scale is not None
+        assert block_shape is None or block_shape[0] == 0
     else:
         assert A_scale is None
         assert B_scale is None
@@ -564,43 +821,90 @@ def invoke_fused_moe_kernel(
     else:
         even_Ks = False
 
-    fused_moe_kernel[grid](
-        A,
-        B,
-        C,
-        A_scale,
-        B_scale,
-        topk_weights,
-        sorted_token_ids,
-        expert_ids,
-        num_tokens_post_padded,
-        B.shape[1],
-        B.shape[2] - padded_size,
-        sorted_token_ids.shape[0],
-        topk_ids.numel(),
-        A.stride(0),
-        A.stride(1),
-        B.stride(0),
-        B.stride(2),
-        B.stride(1),
-        C.stride(1),
-        C.stride(2),
-        A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
-        A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
-        B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
-        B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
-        B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
-        0 if block_shape is None else block_shape[0],
-        0 if block_shape is None else block_shape[1],
-        MUL_ROUTED_WEIGHT=mul_routed_weight,
-        top_k=top_k,
-        compute_type=compute_type,
-        use_fp8_w8a8=use_fp8_w8a8,
-        use_int8_w8a8=use_int8_w8a8,
-        use_int8_w8a16=use_int8_w8a16,
-        even_Ks=even_Ks,
-        **config,
-    )
+    if (
+        (use_int8_w8a16 or use_int4_w4a16)
+        and block_shape is not None
+        and block_shape[1] > 0
+    ):
+        assert B_scale is not None and B_scale.ndim == 3
+        assert B_zp is None or B_zp.ndim == 3
+        fused_moe_kernel_gptq_awq[grid](
+            A,
+            B,
+            C,
+            B_scale,
+            B_zp,
+            topk_weights,
+            sorted_token_ids,
+            expert_ids,
+            num_tokens_post_padded,
+            B.shape[1],
+            A.shape[1],
+            sorted_token_ids.shape[0],
+            topk_ids.numel(),
+            A.stride(0),
+            A.stride(1),
+            B.stride(0),
+            B.stride(2),
+            B.stride(1),
+            C.stride(1),
+            C.stride(2),
+            B_scale.stride(0),
+            B_scale.stride(2),
+            B_scale.stride(1),
+            B_zp.stride(0) if B_zp is not None else 0,
+            B_zp.stride(2) if B_zp is not None else 0,
+            B_zp.stride(1) if B_zp is not None else 0,
+            group_size=block_shape[1],
+            MUL_ROUTED_WEIGHT=mul_routed_weight,
+            top_k=top_k,
+            compute_type=compute_type,
+            has_zp=B_zp is not None,
+            use_int4_w4a16=use_int4_w4a16,
+            use_int8_w8a16=use_int8_w8a16,
+            even_Ks=even_Ks,
+            **config,
+        )
+
+    else:
+
+        fused_moe_kernel[grid](
+            A,
+            B,
+            C,
+            A_scale,
+            B_scale,
+            topk_weights,
+            sorted_token_ids,
+            expert_ids,
+            num_tokens_post_padded,
+            B.shape[1],
+            B.shape[2] - padded_size,
+            sorted_token_ids.shape[0],
+            topk_ids.numel(),
+            A.stride(0),
+            A.stride(1),
+            B.stride(0),
+            B.stride(2),
+            B.stride(1),
+            C.stride(1),
+            C.stride(2),
+            A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
+            A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
+            B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
+            B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
+            B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
+            0 if block_shape is None else block_shape[0],
+            0 if block_shape is None else block_shape[1],
+            MUL_ROUTED_WEIGHT=mul_routed_weight,
+            top_k=top_k,
+            compute_type=compute_type,
+            use_fp8_w8a8=use_fp8_w8a8,
+            use_int8_w8a8=use_int8_w8a8,
+            use_int8_w8a16=use_int8_w8a16,
+            even_Ks=even_Ks,
+            **config,
+        )
 
 
 def get_config_file_name(
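Editor's note: invoke_fused_moe_kernel now chooses between the two Triton kernels. A hedged restatement of the dispatch rule added above (the helper name is hypothetical, not part of the package):

    def uses_gptq_awq_kernel(use_int8_w8a16: bool, use_int4_w4a16: bool,
                             block_shape) -> bool:
        # Weight-only int8/int4 with a per-group scale goes to the GPTQ/AWQ kernel;
        # every other combination keeps using the original fused_moe_kernel.
        return (
            (use_int8_w8a16 or use_int4_w4a16)
            and block_shape is not None
            and block_shape[1] > 0  # block_shape[1] acts as the quantization group size
        )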
@@ -749,6 +1053,7 @@ def try_get_optimal_moe_config(
 def get_config_dtype_str(
     dtype: torch.dtype,
     use_int8_w8a16: Optional[bool] = False,
+    use_int4_w4a16: Optional[bool] = False,
     use_fp8_w8a8: Optional[bool] = False,
     use_int8_w8a8: Optional[bool] = False,
 ):
@@ -756,6 +1061,8 @@ def get_config_dtype_str(
         return "fp8_w8a8"
     elif use_int8_w8a8:
         return "int8_w8a8"
+    elif use_int4_w4a16:
+        return "int4_w4a16"
     elif use_int8_w8a16:
         return "int8_w8a16"
     elif dtype == torch.float:
@@ -772,11 +1079,15 @@ def inplace_fused_experts(
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
+    use_int4_w4a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
+    w1_zp: Optional[torch.Tensor] = None,
+    w2_zp: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
@@ -789,11 +1100,15 @@ def inplace_fused_experts(
         topk_ids,
         True,
         activation,
+        apply_router_weight_on_input,
         use_fp8_w8a8,
         use_int8_w8a8,
         use_int8_w8a16,
+        use_int4_w4a16,
         w1_scale,
         w2_scale,
+        w1_zp,
+        w2_zp,
         a1_scale,
         a2_scale,
         block_shape,
@@ -807,11 +1122,15 @@ def inplace_fused_experts_fake(
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
+    use_int4_w4a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
+    w1_zp: Optional[torch.Tensor] = None,
+    w2_zp: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
@@ -834,11 +1153,15 @@ def outplace_fused_experts(
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
+    use_int4_w4a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
+    w1_zp: Optional[torch.Tensor] = None,
+    w2_zp: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
@@ -852,11 +1175,15 @@ def outplace_fused_experts(
         topk_ids,
         False,
         activation,
+        apply_router_weight_on_input,
         use_fp8_w8a8,
         use_int8_w8a8,
         use_int8_w8a16,
+        use_int4_w4a16,
         w1_scale,
         w2_scale,
+        w1_zp,
+        w2_zp,
         a1_scale,
         a2_scale,
         block_shape,
@@ -871,11 +1198,15 @@ def outplace_fused_experts_fake(
     topk_weights: torch.Tensor,
     topk_ids: torch.Tensor,
     activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
+    use_int4_w4a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
+    w1_zp: Optional[torch.Tensor] = None,
+    w2_zp: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
@@ -900,11 +1231,15 @@ def fused_experts(
     topk_ids: torch.Tensor,
     inplace: bool = False,
     activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
+    use_int4_w4a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
+    w1_zp: Optional[torch.Tensor] = None,
+    w2_zp: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
@@ -919,11 +1254,15 @@ def fused_experts(
             topk_weights,
             topk_ids,
             activation,
+            apply_router_weight_on_input,
             use_fp8_w8a8,
             use_int8_w8a8,
             use_int8_w8a16,
+            use_int4_w4a16,
             w1_scale,
             w2_scale,
+            w1_zp,
+            w2_zp,
             a1_scale,
             a2_scale,
             block_shape,
@@ -937,11 +1276,15 @@ def fused_experts(
             topk_weights,
             topk_ids,
             activation,
+            apply_router_weight_on_input,
             use_fp8_w8a8,
             use_int8_w8a8,
             use_int8_w8a16,
+            use_int4_w4a16,
             w1_scale,
             w2_scale,
+            w1_zp,
+            w2_zp,
             a1_scale,
             a2_scale,
             block_shape,
@@ -957,11 +1300,15 @@ def fused_experts_impl(
     topk_ids: torch.Tensor,
     inplace: bool = False,
     activation: str = "silu",
+    apply_router_weight_on_input: bool = False,
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
+    use_int4_w4a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
+    w1_zp: Optional[torch.Tensor] = None,
+    w2_zp: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
@@ -976,7 +1323,12 @@ def fused_experts_impl(
         padded_size = 0
 
     # Check constraints.
-    assert hidden_states.shape[1] == w1.shape[2] - padded_size, "Hidden size mismatch"
+    if use_int4_w4a16:
+        assert hidden_states.shape[1] // 2 == w1.shape[2], "Hidden size mismatch"
+    else:
+        assert (
+            hidden_states.shape[1] == w1.shape[2] - padded_size
+        ), "Hidden size mismatch"
     assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
     assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
     assert w1.is_contiguous(), "Expert weights1 must be contiguous"
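Editor's note: the new branch implies that on the use_int4_w4a16 path two 4-bit weights are packed per byte along the hidden (K) dimension, so w1's last dimension is half the hidden size; the kernel unpacks them with (b >> (offs_k % 2) * 4) & 0xF. A hedged packing sketch, with hypothetical helper name, shapes, and storage dtype (not the exact sglang weight layout):

    import torch

    def pack_int4_along_k(w_int4: torch.Tensor) -> torch.Tensor:
        # w_int4: (E, N, K) with values in [0, 15]; returns (E, N, K // 2) uint8,
        # even k in the low nibble, odd k in the high nibble, matching the unpacking above.
        assert w_int4.shape[-1] % 2 == 0
        lo = w_int4[..., 0::2]
        hi = w_int4[..., 1::2]
        return (lo | (hi << 4)).to(torch.uint8)

    w1_unpacked = torch.randint(0, 16, (2, 8, 16))  # (E, 2N, K), illustrative sizes
    w1_packed = pack_int4_along_k(w1_unpacked)       # (2, 8, 8)
    # With hidden size 16, hidden_states.shape[1] // 2 == w1_packed.shape[2] == 8,
    # which is exactly what the new assert checks.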
@@ -993,6 +1345,7 @@ def fused_experts_impl(
         use_fp8_w8a8=use_fp8_w8a8,
         use_int8_w8a8=use_int8_w8a8,
         use_int8_w8a16=use_int8_w8a16,
+        use_int4_w4a16=use_int4_w4a16,
         dtype=hidden_states.dtype,
     )
 
@@ -1074,18 +1427,20 @@ def fused_experts_impl(
             intermediate_cache1,
             a1_scale,
             w1_scale,
+            w1_zp,
             curr_topk_weights,
             curr_topk_ids,
             sorted_token_ids,
             expert_ids,
             num_tokens_post_padded,
-            False,
+            apply_router_weight_on_input,
             topk_ids.shape[1],
             config,
             compute_type=compute_type,
             use_fp8_w8a8=use_fp8_w8a8,
             use_int8_w8a8=use_int8_w8a8,
             use_int8_w8a16=use_int8_w8a16,
+            use_int4_w4a16=use_int4_w4a16,
             block_shape=block_shape,
         )
         if activation == "silu":
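Editor's note: this hunk and the next one together show the effect of the new apply_router_weight_on_input flag: the routed top-k weight is folded into the first GEMM when the flag is set and into the second GEMM otherwise, so it is applied exactly once either way. A minimal sketch with a hypothetical helper name:

    def routed_weight_flags(apply_router_weight_on_input: bool):
        mul_in_gemm1 = apply_router_weight_on_input       # hidden_states @ w1
        mul_in_gemm2 = not apply_router_weight_on_input   # activation @ w2 (the old default)
        assert mul_in_gemm1 != mul_in_gemm2               # applied exactly once
        return mul_in_gemm1, mul_in_gemm2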
@@ -1111,22 +1466,24 @@ def fused_experts_impl(
             (
                 intermediate_cache3
                 if not no_combine and topk_ids.shape[1] != 1
-                else out_hidden_states[begin_chunk_idx:end_chunk_idx]
+                else out_hidden_states[begin_chunk_idx:end_chunk_idx].unsqueeze(0)
             ),
             a2_scale,
             w2_scale,
+            w2_zp,
             curr_topk_weights,
             curr_topk_ids,
             sorted_token_ids,
             expert_ids,
             num_tokens_post_padded,
-            True,
+            not apply_router_weight_on_input,
             1,
             config,
             compute_type=compute_type,
             use_fp8_w8a8=use_fp8_w8a8,
             use_int8_w8a8=use_int8_w8a8,
             use_int8_w8a16=use_int8_w8a16,
+            use_int4_w4a16=use_int4_w4a16,
             block_shape=block_shape,
         )
 
@@ -1172,8 +1529,11 @@ def fused_moe(
     use_fp8_w8a8: bool = False,
     use_int8_w8a8: bool = False,
     use_int8_w8a16: bool = False,
+    use_int4_w4a16: bool = False,
     w1_scale: Optional[torch.Tensor] = None,
     w2_scale: Optional[torch.Tensor] = None,
+    w1_zp: Optional[torch.Tensor] = None,
+    w2_zp: Optional[torch.Tensor] = None,
     a1_scale: Optional[torch.Tensor] = None,
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
@@ -1203,6 +1563,9 @@ def fused_moe(
         products for w1 and w2. Defaults to False.
     - use_int8_w8a16 (bool): If True, use fp8 arithmetic to compute the inner
         products for w1 and w2. Defaults to False.
+    - use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16
+        activation to compute the inner products for w1 and w2.
+        Defaults to False.
     - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
         w1.
     - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
@@ -1242,8 +1605,11 @@ def fused_moe(
         use_fp8_w8a8=use_fp8_w8a8,
         use_int8_w8a8=use_int8_w8a8,
         use_int8_w8a16=use_int8_w8a16,
+        use_int4_w4a16=use_int4_w4a16,
         w1_scale=w1_scale,
         w2_scale=w2_scale,
+        w1_zp=w1_zp,
+        w2_zp=w2_zp,
         a1_scale=a1_scale,
         a2_scale=a2_scale,
         block_shape=block_shape,
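Editor's note: taken together, the user-facing addition in this file is a weight-only int4 / fp16-activation path through fused_moe. A hedged usage sketch, assuming fused_moe keeps its earlier (hidden_states, w1, w2, gating_output, topk, renormalize, ...) calling convention; shapes, dtypes, the group size, and the scale layout here are illustrative assumptions, not the exact sglang storage format:

    import torch
    from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe

    E, K, N, M, topk, group = 8, 2048, 1024, 4, 2, 128
    x = torch.randn(M, K, dtype=torch.float16, device="cuda")
    gating_output = torch.randn(M, E, dtype=torch.float32, device="cuda")
    # Two int4 values packed per byte along each weight's reduction (K) dimension.
    w1 = torch.randint(0, 256, (E, 2 * N, K // 2), dtype=torch.uint8, device="cuda")
    w2 = torch.randint(0, 256, (E, K, N // 2), dtype=torch.uint8, device="cuda")
    # One scale per group of `group` input channels, per output channel, per expert.
    w1_scale = torch.rand(E, 2 * N, K // group, dtype=torch.float16, device="cuda")
    w2_scale = torch.rand(E, K, N // group, dtype=torch.float16, device="cuda")

    out = fused_moe(
        x, w1, w2, gating_output, topk,
        renormalize=True,
        use_int4_w4a16=True,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        w1_zp=None,              # no zero points: the kernel then subtracts a constant 8
        w2_zp=None,
        block_shape=[0, group],  # block_shape[1] carries the quantization group size
    )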