sglang 0.4.4.post3__py3-none-any.whl → 0.4.4.post4__py3-none-any.whl

This diff shows the contents of two publicly released package versions as they appear in their public registry. It is provided for informational purposes only.
Files changed (81)
  1. sglang/bench_serving.py +49 -7
  2. sglang/srt/_custom_ops.py +59 -92
  3. sglang/srt/configs/model_config.py +1 -0
  4. sglang/srt/constrained/base_grammar_backend.py +5 -1
  5. sglang/srt/custom_op.py +5 -0
  6. sglang/srt/distributed/device_communicators/custom_all_reduce.py +27 -79
  7. sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
  8. sglang/srt/entrypoints/engine.py +0 -5
  9. sglang/srt/layers/attention/flashattention_backend.py +394 -76
  10. sglang/srt/layers/attention/flashinfer_backend.py +5 -7
  11. sglang/srt/layers/attention/flashinfer_mla_backend.py +1 -3
  12. sglang/srt/layers/attention/flashmla_backend.py +1 -1
  13. sglang/srt/layers/moe/ep_moe/kernels.py +142 -0
  14. sglang/srt/layers/moe/ep_moe/layer.py +79 -80
  15. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +382 -199
  16. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json +146 -0
  17. sglang/srt/layers/moe/fused_moe_triton/configs/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  18. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  19. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +403 -47
  20. sglang/srt/layers/moe/topk.py +49 -3
  21. sglang/srt/layers/quantization/__init__.py +4 -1
  22. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +2 -1
  23. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +34 -10
  24. sglang/srt/layers/quantization/fp8_utils.py +1 -4
  25. sglang/srt/layers/quantization/moe_wna16.py +501 -0
  26. sglang/srt/layers/quantization/utils.py +1 -1
  27. sglang/srt/layers/rotary_embedding.py +0 -12
  28. sglang/srt/managers/cache_controller.py +34 -11
  29. sglang/srt/managers/mm_utils.py +202 -156
  30. sglang/srt/managers/multimodal_processor.py +0 -2
  31. sglang/srt/managers/multimodal_processors/base_processor.py +45 -77
  32. sglang/srt/managers/multimodal_processors/clip.py +7 -26
  33. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +17 -58
  34. sglang/srt/managers/multimodal_processors/gemma3.py +12 -27
  35. sglang/srt/managers/multimodal_processors/janus_pro.py +21 -47
  36. sglang/srt/managers/multimodal_processors/llava.py +34 -14
  37. sglang/srt/managers/multimodal_processors/minicpm.py +35 -38
  38. sglang/srt/managers/multimodal_processors/mlama.py +10 -23
  39. sglang/srt/managers/multimodal_processors/qwen_vl.py +22 -45
  40. sglang/srt/managers/schedule_batch.py +185 -128
  41. sglang/srt/managers/scheduler.py +4 -4
  42. sglang/srt/managers/tokenizer_manager.py +1 -1
  43. sglang/srt/managers/utils.py +1 -6
  44. sglang/srt/mem_cache/hiradix_cache.py +62 -52
  45. sglang/srt/mem_cache/memory_pool.py +72 -6
  46. sglang/srt/mem_cache/paged_allocator.py +39 -0
  47. sglang/srt/metrics/collector.py +23 -53
  48. sglang/srt/model_executor/cuda_graph_runner.py +8 -6
  49. sglang/srt/model_executor/forward_batch_info.py +10 -10
  50. sglang/srt/model_executor/model_runner.py +59 -57
  51. sglang/srt/model_loader/loader.py +8 -0
  52. sglang/srt/models/clip.py +12 -7
  53. sglang/srt/models/deepseek_janus_pro.py +10 -15
  54. sglang/srt/models/deepseek_v2.py +212 -121
  55. sglang/srt/models/deepseek_vl2.py +105 -104
  56. sglang/srt/models/gemma3_mm.py +14 -80
  57. sglang/srt/models/llama.py +4 -1
  58. sglang/srt/models/llava.py +31 -19
  59. sglang/srt/models/llavavid.py +16 -7
  60. sglang/srt/models/minicpmo.py +63 -147
  61. sglang/srt/models/minicpmv.py +17 -27
  62. sglang/srt/models/mllama.py +29 -14
  63. sglang/srt/models/qwen2.py +9 -6
  64. sglang/srt/models/qwen2_5_vl.py +21 -31
  65. sglang/srt/models/qwen2_vl.py +20 -21
  66. sglang/srt/openai_api/adapter.py +18 -6
  67. sglang/srt/platforms/interface.py +371 -0
  68. sglang/srt/server_args.py +99 -14
  69. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -5
  70. sglang/srt/speculative/eagle_utils.py +140 -28
  71. sglang/srt/speculative/eagle_worker.py +93 -24
  72. sglang/srt/utils.py +104 -51
  73. sglang/test/test_custom_ops.py +55 -0
  74. sglang/test/test_utils.py +13 -26
  75. sglang/utils.py +2 -2
  76. sglang/version.py +1 -1
  77. {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/METADATA +4 -3
  78. {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/RECORD +81 -76
  79. {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/WHEEL +0 -0
  80. {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/licenses/LICENSE +0 -0
  81. {sglang-0.4.4.post3.dist-info → sglang-0.4.4.post4.dist-info}/top_level.txt +0 -0
The diff body below covers sglang/srt/layers/moe/fused_moe_triton/fused_moe.py (item 19 above, +403 −47).

@@ -13,11 +13,6 @@ import triton
  import triton.language as tl

  from sglang.srt.layers.moe.topk import select_experts
- from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
- from sglang.srt.layers.quantization.int8_kernel import (
-     per_token_group_quant_int8,
-     per_token_quant_int8,
- )
  from sglang.srt.utils import (
      direct_register_custom_op,
      get_bool_env_var,
@@ -42,9 +37,6 @@ if _is_cuda:
      from sgl_kernel import gelu_and_mul, silu_and_mul

      from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
-     from sglang.srt.layers.quantization.fp8_kernel import (
-         sglang_per_token_group_quant_fp8,
-     )
  else:
      from vllm import _custom_ops as vllm_ops

@@ -52,6 +44,257 @@ if _is_cuda or _is_hip:
      from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size


+ @triton.jit
+ def write_zeros_to_output(
+     c_ptr,
+     stride_cm,
+     stride_cn,
+     pid_n,
+     N,
+     offs_token,
+     token_mask,
+     BLOCK_SIZE_M,
+     BLOCK_SIZE_N,
+     compute_type,
+ ):
+     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=compute_type)
+     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+     c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
+     c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
+     tl.store(c_ptrs, accumulator, mask=c_mask)
+
+
+ @triton.jit
+ def fused_moe_kernel_gptq_awq(
+     # Pointers to matrices
+     a_ptr,
+     b_ptr,
+     c_ptr,
+     b_scale_ptr,
+     b_zp_ptr,
+     topk_weights_ptr,
+     sorted_token_ids_ptr,
+     expert_ids_ptr,
+     num_tokens_post_padded_ptr,
+     # Matrix dimensions
+     N: tl.constexpr,
+     K: tl.constexpr,
+     EM,
+     num_valid_tokens,
+     # The stride variables represent how much to increase the ptr by when
+     # moving by 1 element in a particular dimension. E.g. `stride_am` is
+     # how much to increase `a_ptr` by to get the element one row down
+     # (A has M rows).
+     stride_am,
+     stride_ak,
+     stride_be,
+     stride_bk,
+     stride_bn,
+     stride_cm,
+     stride_cn,
+     stride_bse,
+     stride_bsk,
+     stride_bsn,
+     stride_bze,
+     stride_bzk,
+     stride_bzn,
+     group_size: tl.constexpr,
+     # Meta-parameters
+     BLOCK_SIZE_M: tl.constexpr,
+     BLOCK_SIZE_N: tl.constexpr,
+     BLOCK_SIZE_K: tl.constexpr,
+     GROUP_SIZE_M: tl.constexpr,
+     MUL_ROUTED_WEIGHT: tl.constexpr,
+     top_k: tl.constexpr,
+     compute_type: tl.constexpr,
+     has_zp: tl.constexpr,
+     use_int4_w4a16: tl.constexpr,
+     use_int8_w8a16: tl.constexpr,
+     even_Ks: tl.constexpr,
+ ):
+     """
+     Implements the fused computation for a Mixture of Experts (MOE) using
+     token and expert matrices.
+     Key Parameters:
+     - A: The input tensor representing tokens with shape (*, K), where '*' can
+         be any shape representing batches and K is the feature dimension of
+         each token.
+     - B: The stacked MOE weight tensor with shape (E, N, K), where E is
+         the number of experts, K is the input feature dimension, and N is
+         the output feature dimension.
+     - C: The output cache tensor with shape (M, topk, N), where M is the
+         total number of tokens post padding, topk is the number of times
+         each token is repeated, and N is the output feature dimension.
+     - sorted_token_ids: A tensor containing the sorted indices of tokens,
+         repeated topk times and arranged by the expert index they are
+         assigned to.
+     - expert_ids: A tensor containing the indices of the expert for each
+         block. It determines which expert matrix from B should be used for
+         each block in A.
+     This kernel performs the multiplication of a token by its corresponding
+     expert matrix as determined by `expert_ids`. The sorting of
+     `sorted_token_ids` by expert index and padding ensures divisibility by
+     BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
+     multiplication across different blocks processed by the same expert.
+     """
+     # -----------------------------------------------------------
+     # Map program ids `pid` to the block of C it should compute.
+     # This is done in a grouped ordering to promote L2 data reuse.
+     pid = tl.program_id(axis=0)
+     num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
+     num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+     num_pid_in_group = GROUP_SIZE_M * num_pid_n
+     group_id = pid // num_pid_in_group
+     first_pid_m = group_id * GROUP_SIZE_M
+     group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+     pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
+     pid_n = (pid % num_pid_in_group) // group_size_m
+
+     # ----------------------------------------------------------
+     # Create pointers for the first blocks of A and B.
+     # We will advance this pointer as we move in the K direction
+     # and accumulate
+     # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
+     # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
+     num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
+     if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
+         return
+     offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M).to(tl.int64)
+     offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
+     token_mask = offs_token < num_valid_tokens
+
+     off_experts = tl.load(expert_ids_ptr + pid_m).to(tl.int64)
+     if off_experts == -1:
+         # -----------------------------------------------------------
+         # Write back zeros to the output when the expert is not
+         # in the current expert parallel rank.
+         write_zeros_to_output(
+             c_ptr,
+             stride_cm,
+             stride_cn,
+             pid_n,
+             N,
+             offs_token,
+             token_mask,
+             BLOCK_SIZE_M,
+             BLOCK_SIZE_N,
+             compute_type,
+         )
+         return
+
+     offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N).to(tl.int64)) % N
+     offs_k = tl.arange(0, BLOCK_SIZE_K)
+     a_ptrs = a_ptr + (
+         offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
+     )
+
+     if use_int4_w4a16:
+         b_ptrs = (
+             b_ptr
+             + off_experts * stride_be
+             + (offs_k[:, None] // 2) * stride_bk
+             + offs_bn[None, :] * stride_bn
+         )
+         b_shifter = (offs_k[:, None] % 2) * 4
+     elif use_int8_w8a16:
+         b_ptrs = (
+             b_ptr
+             + off_experts * stride_be
+             + offs_k[:, None] * stride_bk
+             + offs_bn[None, :] * stride_bn
+         )
+
+     if not has_zp and use_int4_w4a16:
+         b_zp_num = 8
+     if not has_zp and use_int8_w8a16:
+         b_zp_num = 128
+     elif has_zp and use_int4_w4a16:
+         b_zp_shifter = (offs_bn[None, :] % 2) * 4
+
+     # -----------------------------------------------------------
+     # Iterate to compute a block of the C matrix.
+     # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
+     # of fp32 values for higher accuracy.
+     # `accumulator` will be converted back to fp16 after the loop.
+     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+         # Load the next block of A and B, generate a mask by checking the
+         # K dimension.
+
+         if not even_Ks:
+             k_mask = offs_k[:, None] < K - k * BLOCK_SIZE_K
+             k_other = 0.0
+         else:
+             k_mask = None
+             k_other = None
+
+         a = tl.load(
+             a_ptrs,
+             mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
+             other=0.0,
+         )
+         b = tl.load(b_ptrs)
+         if use_int4_w4a16:
+             b = (b >> b_shifter) & 0xF
+
+         b_scale_ptrs = (
+             b_scale_ptr
+             + off_experts * stride_bse
+             + offs_bn[None, :] * stride_bsn
+             + ((offs_k[:, None] + BLOCK_SIZE_K * k) // group_size) * stride_bsk
+         )
+         b_scale = tl.load(b_scale_ptrs, mask=k_mask, other=k_other)
+         b_scale = b_scale.to(tl.float32)
+
+         if has_zp and use_int4_w4a16:
+             offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
+             b_zp_ptrs = (
+                 b_zp_ptr
+                 + off_experts * stride_bze
+                 + (offs_bn[None, :] // 2) * stride_bzn
+                 + offs_k_true * stride_bzk
+             )
+             b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
+             b_zp = (b_zp >> b_zp_shifter) & 0xF
+             b_zp = b_zp.to(tl.float32)
+         elif has_zp and use_int8_w8a16:
+             offs_k_true = (offs_k[:, None] + BLOCK_SIZE_K * k) // group_size
+             b_zp_ptrs = (
+                 b_zp_ptr
+                 + off_experts * stride_bze
+                 + offs_bn[None, :] * stride_bzn
+                 + offs_k_true * stride_bzk
+             )
+             b_zp = tl.load(b_zp_ptrs, mask=k_mask, other=k_other)
+             b_zp = b_zp.to(tl.float32)
+
+         # We accumulate along the K dimension.
+         if has_zp:
+             b = ((b.to(tl.float32) - b_zp) * b_scale).to(compute_type)
+         else:
+             b = ((b.to(tl.float32) - b_zp_num) * b_scale).to(compute_type)
+         accumulator = tl.dot(a, b, acc=accumulator)
+
+         # Advance the ptrs to the next K block.
+         a_ptrs += BLOCK_SIZE_K * stride_ak
+         if use_int4_w4a16:
+             b_ptrs += (BLOCK_SIZE_K // 2) * stride_bk
+         else:
+             b_ptrs += BLOCK_SIZE_K * stride_bk
+
+     if MUL_ROUTED_WEIGHT:
+         moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
+         accumulator = accumulator * moe_weight[:, None]
+
+     accumulator = accumulator.to(compute_type)
+     # -----------------------------------------------------------
+     # Write back the block of the output
+     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+     c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
+     c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
+     tl.store(c_ptrs, accumulator, mask=c_mask)
+
+
  @triton.jit
  def fused_moe_kernel(
      # Pointers to matrices
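Note on the new kernel above: fused_moe_kernel_gptq_awq dequantizes weight-only GPTQ/AWQ formats inline. For int4 (use_int4_w4a16), two weights are packed per byte along K, unpacked with `(b >> b_shifter) & 0xF`, then shifted by a per-group zero point (or the constant 8 when has_zp is false) and multiplied by the per-group scale before the tl.dot. A rough NumPy sketch of that per-column math, with illustrative names and data not taken from the wheel:

from typing import Optional
import numpy as np

def dequant_int4_column(packed: np.ndarray, scales: np.ndarray,
                        group_size: int, zeros: Optional[np.ndarray] = None):
    # packed: (K // 2,) uint8 with two 4-bit weights per byte,
    # scales: (K // group_size,) float, zeros: optional unpacked zero points.
    k = packed.shape[0] * 2
    idx = np.arange(k)
    # Even K positions sit in the low nibble, odd positions in the high nibble,
    # mirroring b_shifter = (offs_k % 2) * 4 in the kernel.
    nibbles = (packed[idx // 2] >> ((idx % 2) * 4)) & 0xF
    group = idx // group_size
    zp = 8 if zeros is None else zeros[group]   # kernel's b_zp_num = 8 fallback
    return (nibbles.astype(np.float32) - zp) * scales[group]

# Two groups of four weights each, symmetric quantization.
print(dequant_int4_column(np.array([0x21, 0x43, 0x65, 0x87], dtype=np.uint8),
                          np.array([0.5, 0.25], dtype=np.float32), group_size=4))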
@@ -152,6 +395,7 @@ def fused_moe_kernel(
          return
      offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
      offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
+     offs_token = offs_token.to(tl.int64)
      token_mask = offs_token < num_valid_tokens

      offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
@@ -495,6 +739,7 @@ def invoke_fused_moe_kernel(
      C: torch.Tensor,
      A_scale: Optional[torch.Tensor],
      B_scale: Optional[torch.Tensor],
+     B_zp: Optional[torch.Tensor],
      topk_weights: torch.Tensor,
      topk_ids: torch.Tensor,
      sorted_token_ids: torch.Tensor,
@@ -507,9 +752,20 @@ def invoke_fused_moe_kernel(
      use_fp8_w8a8: bool,
      use_int8_w8a8: bool,
      use_int8_w8a16: bool,
+     use_int4_w4a16: bool,
      block_shape: Optional[List[int]] = None,
      no_combine: bool = False,
  ) -> None:
+     from sglang.srt.layers.quantization.int8_kernel import (
+         per_token_group_quant_int8,
+         per_token_quant_int8,
+     )
+
+     if _is_cuda:
+         from sglang.srt.layers.quantization.fp8_kernel import (
+             sglang_per_token_group_quant_fp8,
+         )
+
      assert topk_weights.stride(1) == 1
      assert sorted_token_ids.stride(0) == 1

@@ -547,8 +803,9 @@ def invoke_fused_moe_kernel(
              assert triton.cdiv(A.shape[-1], block_k) == A_scale.shape[-1]
              assert triton.cdiv(B.shape[-2], block_n) == B_scale.shape[-2]
              assert triton.cdiv(B.shape[-1], block_k) == B_scale.shape[-1]
-     elif use_int8_w8a16:
+     elif use_int8_w8a16 or use_int4_w4a16:
          assert B_scale is not None
+         assert block_shape is None or block_shape[0] == 0
      else:
          assert A_scale is None
          assert B_scale is None
@@ -564,43 +821,90 @@ def invoke_fused_moe_kernel(
      else:
          even_Ks = False

-     fused_moe_kernel[grid](
-         A,
-         B,
-         C,
-         A_scale,
-         B_scale,
-         topk_weights,
-         sorted_token_ids,
-         expert_ids,
-         num_tokens_post_padded,
-         B.shape[1],
-         B.shape[2] - padded_size,
-         sorted_token_ids.shape[0],
-         topk_ids.numel(),
-         A.stride(0),
-         A.stride(1),
-         B.stride(0),
-         B.stride(2),
-         B.stride(1),
-         C.stride(1),
-         C.stride(2),
-         A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
-         A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
-         B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
-         B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
-         B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
-         0 if block_shape is None else block_shape[0],
-         0 if block_shape is None else block_shape[1],
-         MUL_ROUTED_WEIGHT=mul_routed_weight,
-         top_k=top_k,
-         compute_type=compute_type,
-         use_fp8_w8a8=use_fp8_w8a8,
-         use_int8_w8a8=use_int8_w8a8,
-         use_int8_w8a16=use_int8_w8a16,
-         even_Ks=even_Ks,
-         **config,
-     )
+     if (
+         (use_int8_w8a16 or use_int4_w4a16)
+         and block_shape is not None
+         and block_shape[1] > 0
+     ):
+         assert B_scale is not None and B_scale.ndim == 3
+         assert B_zp is None or B_zp.ndim == 3
+         fused_moe_kernel_gptq_awq[grid](
+             A,
+             B,
+             C,
+             B_scale,
+             B_zp,
+             topk_weights,
+             sorted_token_ids,
+             expert_ids,
+             num_tokens_post_padded,
+             B.shape[1],
+             A.shape[1],
+             sorted_token_ids.shape[0],
+             topk_ids.numel(),
+             A.stride(0),
+             A.stride(1),
+             B.stride(0),
+             B.stride(2),
+             B.stride(1),
+             C.stride(1),
+             C.stride(2),
+             B_scale.stride(0),
+             B_scale.stride(2),
+             B_scale.stride(1),
+             B_zp.stride(0) if B_zp is not None else 0,
+             B_zp.stride(2) if B_zp is not None else 0,
+             B_zp.stride(1) if B_zp is not None else 0,
+             group_size=block_shape[1],
+             MUL_ROUTED_WEIGHT=mul_routed_weight,
+             top_k=top_k,
+             compute_type=compute_type,
+             has_zp=B_zp is not None,
+             use_int4_w4a16=use_int4_w4a16,
+             use_int8_w8a16=use_int8_w8a16,
+             even_Ks=even_Ks,
+             **config,
+         )
+
+     else:
+
+         fused_moe_kernel[grid](
+             A,
+             B,
+             C,
+             A_scale,
+             B_scale,
+             topk_weights,
+             sorted_token_ids,
+             expert_ids,
+             num_tokens_post_padded,
+             B.shape[1],
+             B.shape[2] - padded_size,
+             sorted_token_ids.shape[0],
+             topk_ids.numel(),
+             A.stride(0),
+             A.stride(1),
+             B.stride(0),
+             B.stride(2),
+             B.stride(1),
+             C.stride(1),
+             C.stride(2),
+             A_scale.stride(0) if A_scale is not None and A_scale.ndim == 2 else 0,
+             A_scale.stride(1) if A_scale is not None and A_scale.ndim == 2 else 0,
+             B_scale.stride(0) if B_scale is not None and B_scale.ndim >= 2 else 0,
+             B_scale.stride(2) if B_scale is not None and B_scale.ndim == 3 else 0,
+             B_scale.stride(1) if B_scale is not None and B_scale.ndim >= 2 else 0,
+             0 if block_shape is None else block_shape[0],
+             0 if block_shape is None else block_shape[1],
+             MUL_ROUTED_WEIGHT=mul_routed_weight,
+             top_k=top_k,
+             compute_type=compute_type,
+             use_fp8_w8a8=use_fp8_w8a8,
+             use_int8_w8a8=use_int8_w8a8,
+             use_int8_w8a16=use_int8_w8a16,
+             even_Ks=even_Ks,
+             **config,
+         )


  def get_config_file_name(
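The rewritten dispatch in invoke_fused_moe_kernel routes weight-only int8/int4 layers with a per-group layout (block_shape = [0, group_size]) to the new GPTQ/AWQ kernel and keeps every other path on fused_moe_kernel. A condensed sketch of just that routing decision (the helper below is hypothetical; only the branch condition comes from the diff):

from typing import List, Optional

def pick_moe_kernel(use_int8_w8a16: bool, use_int4_w4a16: bool,
                    block_shape: Optional[List[int]]) -> str:
    # Weight-only int8/int4 with a positive per-group size goes to the new kernel.
    if (
        (use_int8_w8a16 or use_int4_w4a16)
        and block_shape is not None
        and block_shape[1] > 0
    ):
        return "fused_moe_kernel_gptq_awq"
    # fp8 w8a8, int8 w8a8, and unquantized inputs keep the original kernel.
    return "fused_moe_kernel"

assert pick_moe_kernel(False, True, [0, 128]) == "fused_moe_kernel_gptq_awq"
assert pick_moe_kernel(False, False, None) == "fused_moe_kernel"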
@@ -749,6 +1053,7 @@ def try_get_optimal_moe_config(
  def get_config_dtype_str(
      dtype: torch.dtype,
      use_int8_w8a16: Optional[bool] = False,
+     use_int4_w4a16: Optional[bool] = False,
      use_fp8_w8a8: Optional[bool] = False,
      use_int8_w8a8: Optional[bool] = False,
  ):
@@ -756,6 +1061,8 @@ def get_config_dtype_str(
          return "fp8_w8a8"
      elif use_int8_w8a8:
          return "int8_w8a8"
+     elif use_int4_w4a16:
+         return "int4_w4a16"
      elif use_int8_w8a16:
          return "int8_w8a16"
      elif dtype == torch.float:
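The string returned by get_config_dtype_str becomes the dtype tag in the tuned-config filenames that get_config_file_name looks up; the JSON configs added in this release (items 16–18 in the file list above) follow that pattern. An approximate reconstruction inferred from those filenames, not from the wheel's code, which may differ in detail:

from typing import List, Optional

def config_file_name(E: int, N: int, device_name: str,
                     dtype: Optional[str], block_shape: Optional[List[int]]) -> str:
    # Assemble the filename the way the shipped JSON configs are named.
    dtype_part = f",dtype={dtype}" if dtype else ""
    block_part = f",block_shape={block_shape}" if block_shape else ""
    return f"E={E},N={N},device_name={device_name}{dtype_part}{block_part}.json"

# Matches two of the configs shipped in 0.4.4.post4:
assert (
    config_file_name(257, 256, "NVIDIA_H200", "fp8_w8a8", [128, 128])
    == "E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json"
)
assert (
    config_file_name(256, 128, "NVIDIA_H20", None, [128, 128])
    == "E=256,N=128,device_name=NVIDIA_H20,block_shape=[128, 128].json"
)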
@@ -775,8 +1082,11 @@ def inplace_fused_experts(
      use_fp8_w8a8: bool = False,
      use_int8_w8a8: bool = False,
      use_int8_w8a16: bool = False,
+     use_int4_w4a16: bool = False,
      w1_scale: Optional[torch.Tensor] = None,
      w2_scale: Optional[torch.Tensor] = None,
+     w1_zp: Optional[torch.Tensor] = None,
+     w2_zp: Optional[torch.Tensor] = None,
      a1_scale: Optional[torch.Tensor] = None,
      a2_scale: Optional[torch.Tensor] = None,
      block_shape: Optional[List[int]] = None,
@@ -792,8 +1102,11 @@ def inplace_fused_experts(
          use_fp8_w8a8,
          use_int8_w8a8,
          use_int8_w8a16,
+         use_int4_w4a16,
          w1_scale,
          w2_scale,
+         w1_zp,
+         w2_zp,
          a1_scale,
          a2_scale,
          block_shape,
@@ -810,8 +1123,11 @@ def inplace_fused_experts_fake(
      use_fp8_w8a8: bool = False,
      use_int8_w8a8: bool = False,
      use_int8_w8a16: bool = False,
+     use_int4_w4a16: bool = False,
      w1_scale: Optional[torch.Tensor] = None,
      w2_scale: Optional[torch.Tensor] = None,
+     w1_zp: Optional[torch.Tensor] = None,
+     w2_zp: Optional[torch.Tensor] = None,
      a1_scale: Optional[torch.Tensor] = None,
      a2_scale: Optional[torch.Tensor] = None,
      block_shape: Optional[List[int]] = None,
@@ -837,8 +1153,11 @@ def outplace_fused_experts(
      use_fp8_w8a8: bool = False,
      use_int8_w8a8: bool = False,
      use_int8_w8a16: bool = False,
+     use_int4_w4a16: bool = False,
      w1_scale: Optional[torch.Tensor] = None,
      w2_scale: Optional[torch.Tensor] = None,
+     w1_zp: Optional[torch.Tensor] = None,
+     w2_zp: Optional[torch.Tensor] = None,
      a1_scale: Optional[torch.Tensor] = None,
      a2_scale: Optional[torch.Tensor] = None,
      block_shape: Optional[List[int]] = None,
@@ -855,8 +1174,11 @@ def outplace_fused_experts(
          use_fp8_w8a8,
          use_int8_w8a8,
          use_int8_w8a16,
+         use_int4_w4a16,
          w1_scale,
          w2_scale,
+         w1_zp,
+         w2_zp,
          a1_scale,
          a2_scale,
          block_shape,
@@ -874,8 +1196,11 @@ def outplace_fused_experts_fake(
      use_fp8_w8a8: bool = False,
      use_int8_w8a8: bool = False,
      use_int8_w8a16: bool = False,
+     use_int4_w4a16: bool = False,
      w1_scale: Optional[torch.Tensor] = None,
      w2_scale: Optional[torch.Tensor] = None,
+     w1_zp: Optional[torch.Tensor] = None,
+     w2_zp: Optional[torch.Tensor] = None,
      a1_scale: Optional[torch.Tensor] = None,
      a2_scale: Optional[torch.Tensor] = None,
      block_shape: Optional[List[int]] = None,
@@ -903,8 +1228,11 @@ def fused_experts(
      use_fp8_w8a8: bool = False,
      use_int8_w8a8: bool = False,
      use_int8_w8a16: bool = False,
+     use_int4_w4a16: bool = False,
      w1_scale: Optional[torch.Tensor] = None,
      w2_scale: Optional[torch.Tensor] = None,
+     w1_zp: Optional[torch.Tensor] = None,
+     w2_zp: Optional[torch.Tensor] = None,
      a1_scale: Optional[torch.Tensor] = None,
      a2_scale: Optional[torch.Tensor] = None,
      block_shape: Optional[List[int]] = None,
@@ -922,8 +1250,11 @@ def fused_experts(
              use_fp8_w8a8,
              use_int8_w8a8,
              use_int8_w8a16,
+             use_int4_w4a16,
              w1_scale,
              w2_scale,
+             w1_zp,
+             w2_zp,
              a1_scale,
              a2_scale,
              block_shape,
@@ -940,8 +1271,11 @@ def fused_experts(
              use_fp8_w8a8,
              use_int8_w8a8,
              use_int8_w8a16,
+             use_int4_w4a16,
              w1_scale,
              w2_scale,
+             w1_zp,
+             w2_zp,
              a1_scale,
              a2_scale,
              block_shape,
@@ -960,8 +1294,11 @@ def fused_experts_impl(
      use_fp8_w8a8: bool = False,
      use_int8_w8a8: bool = False,
      use_int8_w8a16: bool = False,
+     use_int4_w4a16: bool = False,
      w1_scale: Optional[torch.Tensor] = None,
      w2_scale: Optional[torch.Tensor] = None,
+     w1_zp: Optional[torch.Tensor] = None,
+     w2_zp: Optional[torch.Tensor] = None,
      a1_scale: Optional[torch.Tensor] = None,
      a2_scale: Optional[torch.Tensor] = None,
      block_shape: Optional[List[int]] = None,
@@ -976,7 +1313,12 @@ def fused_experts_impl(
          padded_size = 0

      # Check constraints.
-     assert hidden_states.shape[1] == w1.shape[2] - padded_size, "Hidden size mismatch"
+     if use_int4_w4a16:
+         assert hidden_states.shape[1] // 2 == w1.shape[2], "Hidden size mismatch"
+     else:
+         assert (
+             hidden_states.shape[1] == w1.shape[2] - padded_size
+         ), "Hidden size mismatch"
      assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
      assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
      assert w1.is_contiguous(), "Expert weights1 must be contiguous"
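The split assert above encodes the packed int4 layout: with two 4-bit weights per byte along K, the weight tensor's last dimension is half the activation's hidden size. A small shape-only illustration (sizes and dtype are placeholders, not from the wheel):

import torch

E, N, hidden_size = 8, 14336, 4096                              # hypothetical sizes
w1 = torch.zeros(E, N, hidden_size // 2, dtype=torch.uint8)     # packed int4, 2 per byte
hidden_states = torch.zeros(16, hidden_size, dtype=torch.bfloat16)

# Mirrors the use_int4_w4a16 branch of the new constraint check.
assert hidden_states.shape[1] // 2 == w1.shape[2], "Hidden size mismatch"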
@@ -993,6 +1335,7 @@ def fused_experts_impl(
          use_fp8_w8a8=use_fp8_w8a8,
          use_int8_w8a8=use_int8_w8a8,
          use_int8_w8a16=use_int8_w8a16,
+         use_int4_w4a16=use_int4_w4a16,
          dtype=hidden_states.dtype,
      )

@@ -1074,6 +1417,7 @@ def fused_experts_impl(
              intermediate_cache1,
              a1_scale,
              w1_scale,
+             w1_zp,
              curr_topk_weights,
              curr_topk_ids,
              sorted_token_ids,
@@ -1086,6 +1430,7 @@ def fused_experts_impl(
              use_fp8_w8a8=use_fp8_w8a8,
              use_int8_w8a8=use_int8_w8a8,
              use_int8_w8a16=use_int8_w8a16,
+             use_int4_w4a16=use_int4_w4a16,
              block_shape=block_shape,
          )
          if activation == "silu":
@@ -1115,6 +1460,7 @@ def fused_experts_impl(
              ),
              a2_scale,
              w2_scale,
+             w2_zp,
              curr_topk_weights,
              curr_topk_ids,
              sorted_token_ids,
@@ -1127,6 +1473,7 @@ def fused_experts_impl(
              use_fp8_w8a8=use_fp8_w8a8,
              use_int8_w8a8=use_int8_w8a8,
              use_int8_w8a16=use_int8_w8a16,
+             use_int4_w4a16=use_int4_w4a16,
              block_shape=block_shape,
          )

@@ -1172,8 +1519,11 @@ def fused_moe(
      use_fp8_w8a8: bool = False,
      use_int8_w8a8: bool = False,
      use_int8_w8a16: bool = False,
+     use_int4_w4a16: bool = False,
      w1_scale: Optional[torch.Tensor] = None,
      w2_scale: Optional[torch.Tensor] = None,
+     w1_zp: Optional[torch.Tensor] = None,
+     w2_zp: Optional[torch.Tensor] = None,
      a1_scale: Optional[torch.Tensor] = None,
      a2_scale: Optional[torch.Tensor] = None,
      block_shape: Optional[List[int]] = None,
@@ -1203,6 +1553,9 @@ def fused_moe(
          products for w1 and w2. Defaults to False.
      - use_int8_w8a16 (bool): If True, use fp8 arithmetic to compute the inner
          products for w1 and w2. Defaults to False.
+     - use_int4_w4a16 (bool): If True, use matmul of int4 weight and bf16/fp16
+         activation to compute the inner products for w1 and w2.
+         Defaults to False.
      - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
          w1.
      - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
@@ -1242,8 +1595,11 @@ def fused_moe(
          use_fp8_w8a8=use_fp8_w8a8,
          use_int8_w8a8=use_int8_w8a8,
          use_int8_w8a16=use_int8_w8a16,
+         use_int4_w4a16=use_int4_w4a16,
          w1_scale=w1_scale,
          w2_scale=w2_scale,
+         w1_zp=w1_zp,
+         w2_zp=w2_zp,
          a1_scale=a1_scale,
          a2_scale=a2_scale,
          block_shape=block_shape,
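Taken together, these hunks expose the weight-only int4 path through the existing fused_moe entry point. A hedged usage sketch: the keyword arguments named below appear in the diff, while the routing arguments (gating_output, topk, renormalize), tensor shapes, and dtypes are assumptions about the surrounding signature rather than content of this diff.

import torch
from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_moe

M, hidden, inter, E, top_k, group = 16, 4096, 14336, 8, 2, 128

hidden_states = torch.randn(M, hidden, dtype=torch.bfloat16, device="cuda")
gating_output = torch.randn(M, E, dtype=torch.float32, device="cuda")

# Packed int4 expert weights (two nibbles per byte along K) with per-group
# scales shaped (E, N, K // group); layouts here are illustrative only.
w1 = torch.zeros(E, 2 * inter, hidden // 2, dtype=torch.uint8, device="cuda")
w2 = torch.zeros(E, hidden, inter // 2, dtype=torch.uint8, device="cuda")
w1_scale = torch.ones(E, 2 * inter, hidden // group, dtype=torch.bfloat16, device="cuda")
w2_scale = torch.ones(E, hidden, inter // group, dtype=torch.bfloat16, device="cuda")

out = fused_moe(
    hidden_states, w1, w2, gating_output,
    topk=top_k, renormalize=True,
    use_int4_w4a16=True,
    w1_scale=w1_scale, w2_scale=w2_scale,
    w1_zp=None, w2_zp=None,       # symmetric quant: kernel falls back to zp = 8
    block_shape=[0, group],       # [0, group_size] selects the GPTQ/AWQ kernel
)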