sglang 0.3.6__py3-none-any.whl → 0.3.6.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (102)
  1. sglang/__init__.py +2 -2
  2. sglang/api.py +2 -2
  3. sglang/bench_one_batch.py +2 -4
  4. sglang/bench_serving.py +75 -26
  5. sglang/lang/backend/base_backend.py +1 -1
  6. sglang/lang/backend/runtime_endpoint.py +2 -2
  7. sglang/srt/configs/model_config.py +13 -14
  8. sglang/srt/constrained/__init__.py +13 -14
  9. sglang/srt/constrained/base_grammar_backend.py +13 -15
  10. sglang/srt/constrained/outlines_backend.py +13 -15
  11. sglang/srt/constrained/outlines_jump_forward.py +13 -15
  12. sglang/srt/constrained/xgrammar_backend.py +38 -57
  13. sglang/srt/conversation.py +13 -15
  14. sglang/srt/hf_transformers_utils.py +13 -15
  15. sglang/srt/layers/activation.py +13 -13
  16. sglang/srt/layers/attention/flashinfer_backend.py +13 -6
  17. sglang/srt/layers/attention/triton_ops/decode_attention.py +51 -55
  18. sglang/srt/layers/attention/triton_ops/extend_attention.py +16 -16
  19. sglang/srt/layers/attention/triton_ops/prefill_attention.py +13 -15
  20. sglang/srt/layers/custom_op_util.py +13 -14
  21. sglang/srt/layers/fused_moe_grok/__init__.py +1 -0
  22. sglang/srt/layers/{fused_moe → fused_moe_grok}/layer.py +4 -9
  23. sglang/srt/layers/{fused_moe/patch.py → fused_moe_patch.py} +5 -0
  24. sglang/srt/layers/fused_moe_triton/__init__.py +44 -0
  25. sglang/srt/layers/fused_moe_triton/fused_moe.py +861 -0
  26. sglang/srt/layers/fused_moe_triton/layer.py +633 -0
  27. sglang/srt/layers/layernorm.py +13 -15
  28. sglang/srt/layers/logits_processor.py +13 -15
  29. sglang/srt/layers/quantization/__init__.py +77 -17
  30. sglang/srt/layers/radix_attention.py +13 -15
  31. sglang/srt/layers/rotary_embedding.py +13 -13
  32. sglang/srt/lora/lora.py +13 -14
  33. sglang/srt/lora/lora_config.py +13 -14
  34. sglang/srt/lora/lora_manager.py +22 -24
  35. sglang/srt/managers/data_parallel_controller.py +25 -19
  36. sglang/srt/managers/detokenizer_manager.py +13 -16
  37. sglang/srt/managers/io_struct.py +43 -28
  38. sglang/srt/managers/schedule_batch.py +55 -26
  39. sglang/srt/managers/schedule_policy.py +13 -15
  40. sglang/srt/managers/scheduler.py +89 -70
  41. sglang/srt/managers/session_controller.py +14 -15
  42. sglang/srt/managers/tokenizer_manager.py +29 -22
  43. sglang/srt/managers/tp_worker.py +13 -15
  44. sglang/srt/managers/tp_worker_overlap_thread.py +13 -15
  45. sglang/srt/metrics/collector.py +13 -15
  46. sglang/srt/metrics/func_timer.py +13 -15
  47. sglang/srt/mm_utils.py +13 -14
  48. sglang/srt/model_executor/cuda_graph_runner.py +20 -19
  49. sglang/srt/model_executor/forward_batch_info.py +19 -17
  50. sglang/srt/model_executor/model_runner.py +42 -30
  51. sglang/srt/models/chatglm.py +15 -16
  52. sglang/srt/models/commandr.py +15 -16
  53. sglang/srt/models/dbrx.py +15 -16
  54. sglang/srt/models/deepseek.py +15 -15
  55. sglang/srt/models/deepseek_v2.py +15 -15
  56. sglang/srt/models/exaone.py +14 -15
  57. sglang/srt/models/gemma.py +14 -14
  58. sglang/srt/models/gemma2.py +24 -19
  59. sglang/srt/models/gemma2_reward.py +13 -14
  60. sglang/srt/models/gpt_bigcode.py +14 -14
  61. sglang/srt/models/grok.py +15 -15
  62. sglang/srt/models/internlm2.py +13 -15
  63. sglang/srt/models/internlm2_reward.py +13 -14
  64. sglang/srt/models/llama.py +21 -21
  65. sglang/srt/models/llama_classification.py +13 -14
  66. sglang/srt/models/llama_reward.py +13 -14
  67. sglang/srt/models/llava.py +13 -15
  68. sglang/srt/models/llavavid.py +13 -15
  69. sglang/srt/models/minicpm.py +13 -15
  70. sglang/srt/models/minicpm3.py +13 -15
  71. sglang/srt/models/mistral.py +13 -15
  72. sglang/srt/models/mixtral.py +15 -15
  73. sglang/srt/models/mixtral_quant.py +14 -14
  74. sglang/srt/models/olmo.py +21 -19
  75. sglang/srt/models/olmoe.py +23 -20
  76. sglang/srt/models/qwen.py +14 -14
  77. sglang/srt/models/qwen2.py +22 -19
  78. sglang/srt/models/qwen2_moe.py +17 -18
  79. sglang/srt/models/stablelm.py +18 -16
  80. sglang/srt/models/torch_native_llama.py +15 -17
  81. sglang/srt/models/xverse.py +13 -14
  82. sglang/srt/models/xverse_moe.py +15 -16
  83. sglang/srt/models/yivl.py +13 -15
  84. sglang/srt/openai_api/adapter.py +13 -15
  85. sglang/srt/openai_api/protocol.py +13 -15
  86. sglang/srt/sampling/sampling_batch_info.py +4 -1
  87. sglang/srt/sampling/sampling_params.py +13 -15
  88. sglang/srt/server.py +59 -34
  89. sglang/srt/server_args.py +22 -22
  90. sglang/srt/utils.py +196 -17
  91. sglang/test/few_shot_gsm8k.py +8 -4
  92. sglang/test/runners.py +13 -14
  93. sglang/test/test_utils.py +1 -1
  94. sglang/version.py +1 -1
  95. {sglang-0.3.6.dist-info → sglang-0.3.6.post1.dist-info}/LICENSE +1 -1
  96. {sglang-0.3.6.dist-info → sglang-0.3.6.post1.dist-info}/METADATA +24 -15
  97. sglang-0.3.6.post1.dist-info/RECORD +164 -0
  98. sglang/srt/layers/fused_moe/__init__.py +0 -1
  99. sglang-0.3.6.dist-info/RECORD +0 -161
  100. /sglang/srt/layers/{fused_moe → fused_moe_grok}/fused_moe.py +0 -0
  101. {sglang-0.3.6.dist-info → sglang-0.3.6.post1.dist-info}/WHEEL +0 -0
  102. {sglang-0.3.6.dist-info → sglang-0.3.6.post1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,861 @@
1
+ # Adapted from https://github.com/vllm-project/vllm/blob/a6221a144af772fd1a68fe7e627935dc53e81738/vllm/model_executor/layers/fused_moe/fused_moe.py
2
+
3
+ """Fused MoE kernel."""
4
+
5
+ import functools
6
+ import json
7
+ import logging
8
+ import os
9
+ from typing import Any, Callable, Dict, Optional, Tuple
10
+
11
+ import torch
12
+ import triton
13
+ import triton.language as tl
14
+ from vllm import _custom_ops as ops
15
+
16
+ from sglang.srt.utils import direct_register_custom_op, get_device_name
17
+
18
+ logger = logging.getLogger(__name__)
19
+
20
+
21
@triton.jit
def fused_moe_kernel(
    # Pointers to matrices
    a_ptr,
    b_ptr,
    c_ptr,
    a_scale_ptr,
    b_scale_ptr,
    topk_weights_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_padded_ptr,
    # Matrix dimensions
    N,
    K,
    EM,
    num_valid_tokens,
    # The stride variables represent how much to increase the ptr by when
    # moving by 1 element in a particular dimension. E.g. `stride_am` is
    # how much to increase `a_ptr` by to get the element one row down
    # (A has M rows).
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    # Strides of the int8_w8a16 weight-scale tensor along the expert and
    # output-column dimensions (only read when use_int8_w8a16 is set).
    stride_bse,
    stride_bsn,
    # Meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    MUL_ROUTED_WEIGHT: tl.constexpr,
    top_k: tl.constexpr,
    compute_type: tl.constexpr,
    use_fp8_w8a8: tl.constexpr,
    use_int8_w8a16: tl.constexpr,
):
    """
    Implements the fused computation for a Mixture of Experts (MOE) using
    token and expert matrices.

    Key Parameters:
    - A: The input tensor representing tokens with shape (*, K), where '*' can
        be any shape representing batches and K is the feature dimension of
        each token.
    - B: The stacked MOE weight tensor with shape (E, N, K), where E is
        the number of experts, K is the input feature dimension, and N is
        the output feature dimension.
    - C: The output cache tensor with shape (M, topk, N), where M is the
        total number of tokens post padding, topk is the number of times
        each token is repeated, and N is the output feature dimension.
    - sorted_token_ids: A tensor containing the sorted indices of tokens,
        repeated topk times and arranged by the expert index they are
        assigned to. Entries >= num_valid_tokens are padding and are masked
        out of every load and store.
    - expert_ids: A tensor containing the indices of the expert for each
        block. It determines which expert matrix from B should be used for
        each block in A.
    - EM: length of sorted_token_ids; it sizes the M dimension of the grid.
    - top_k: row index into A is `sorted_token_id // top_k`, so A holds each
        token once even though it is routed top_k times.
    - MUL_ROUTED_WEIGHT: if set, each output row is multiplied by its
        topk_weights entry.
    - use_fp8_w8a8: A and B are multiplied directly; the result is scaled
        by a single activation scale (a_scale_ptr) and one weight scale per
        expert (b_scale_ptr + expert index).
    - use_int8_w8a16: B is cast to `compute_type` before each dot product
        and the result is scaled by a per-(expert, output column) weight scale.

    This kernel performs the multiplication of a token by its corresponding
    expert matrix as determined by `expert_ids`. The sorting of
    `sorted_token_ids` by expert index and padding ensures divisibility by
    BLOCK_SIZE_M, which is necessary to maintain consistency in block matrix
    multiplication across different blocks processed by the same expert.
    """
    # -----------------------------------------------------------
    # Map program ids `pid` to the block of C it should compute.
    # This is done in a grouped ordering to promote L2 data reuse.
    pid = tl.program_id(axis=0)
    num_pid_m = tl.cdiv(EM, BLOCK_SIZE_M)
    num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
    num_pid_in_group = GROUP_SIZE_M * num_pid_n
    group_id = pid // num_pid_in_group
    first_pid_m = group_id * GROUP_SIZE_M
    group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
    pid_m = first_pid_m + ((pid % num_pid_in_group) % group_size_m)
    pid_n = (pid % num_pid_in_group) // group_size_m

    # ----------------------------------------------------------
    # Create pointers for the first blocks of A and B.
    # We will advance this pointer as we move in the K direction
    # and accumulate
    # `a_ptrs` is a block of [BLOCK_SIZE_M, BLOCK_SIZE_K] pointers
    # `b_ptrs` is a block of [BLOCK_SIZE_K, BLOCK_SIZE_N] pointers
    num_tokens_post_padded = tl.load(num_tokens_post_padded_ptr)
    # The grid covers all EM slots; M-blocks past the actual padded length
    # have no work to do.
    if pid_m * BLOCK_SIZE_M >= num_tokens_post_padded:
        return
    offs_token_id = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    offs_token = tl.load(sorted_token_ids_ptr + offs_token_id)
    # Padding slots hold ids >= num_valid_tokens.
    token_mask = offs_token < num_valid_tokens

    # `% N` keeps B column offsets in bounds for the tail N-block; the final
    # store is masked with `offs_cn < N`, so wrapped columns are never written.
    offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    offs_k = tl.arange(0, BLOCK_SIZE_K)
    # `// top_k` maps a repeated (token, k) id back to the token's row in A.
    a_ptrs = a_ptr + (
        offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
    )

    # One expert id per M-block: every row in this block uses the same expert.
    off_experts = tl.load(expert_ids_ptr + pid_m)
    b_ptrs = (
        b_ptr
        + off_experts * stride_be
        + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
    )
    if use_int8_w8a16:
        # Per-(expert, output column) weight scales.
        b_scale_ptrs = (
            b_scale_ptr + off_experts * stride_bse + offs_bn[None, :] * stride_bsn
        )
        b_scale = tl.load(b_scale_ptrs)

    if use_fp8_w8a8:
        # Scalar activation scale and one weight scale per expert.
        a_scale = tl.load(a_scale_ptr)
        b_scale = tl.load(b_scale_ptr + off_experts)

    # -----------------------------------------------------------
    # Iterate to compute a block of the C matrix.
    # We accumulate into a `[BLOCK_SIZE_M, BLOCK_SIZE_N]` block
    # of fp32 values for higher accuracy.
    # `accumulator` is converted to `compute_type` after the loop.
    accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Load the next block of A and B, generate a mask by checking the
        # K dimension.
        a = tl.load(
            a_ptrs,
            mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
            other=0.0,
        )
        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
        # We accumulate along the K dimension.
        if use_int8_w8a16:
            # int8 weights are cast to the activation compute type first.
            accumulator = tl.dot(a, b.to(compute_type), acc=accumulator)
        elif use_fp8_w8a8:
            accumulator = tl.dot(a, b, acc=accumulator)
        else:
            accumulator += tl.dot(a, b)
        # Advance the ptrs to the next K block.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    if MUL_ROUTED_WEIGHT:
        # Scale each row by its routing (top-k) weight.
        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
        accumulator = accumulator * moe_weight[:, None]
    # Apply dequantization scales (if any) and cast to the output type.
    if use_int8_w8a16:
        accumulator = (accumulator * b_scale).to(compute_type)
    elif use_fp8_w8a8:
        accumulator = (accumulator * a_scale * b_scale).to(compute_type)
    else:
        accumulator = accumulator.to(compute_type)
    # -----------------------------------------------------------
    # Write back the block of the output
    offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    # Rows are written at their sorted-token-id position, i.e. back in the
    # original (token, k) order of C.
    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
    c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
    tl.store(c_ptrs, accumulator, mask=c_mask)
178
+
179
+
180
def moe_align_block_size(
    topk_ids: torch.Tensor, block_size: int, num_experts: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Group routed tokens by expert and pad each group to a multiple of block_size.

    Parameters:
    - topk_ids: A tensor of shape [total_tokens, top_k] holding the top-k
        expert indices selected for each token.
    - block_size: The block size used in block matrix multiplication.
    - num_experts: The total number of experts.

    Returns:
    - sorted_token_ids: Flattened (token, k) indices ordered by assigned
        expert. Unused slots hold `topk_ids.numel()`, which marks padding.
    - expert_ids: The expert index owning each block of sorted_token_ids.
    - num_tokens_post_padded: A one-element tensor with the total number of
        entries after padding (a multiple of block_size).

    The buffers are allocated for the worst case (every expert padded by
    block_size - 1). The custom op `ops.moe_align_block_size` fills them.

    Example:
    Given topk_ids = [[1, 2, 3], [0, 1, 3], [0, 2, 3], [0, 1, 2]],
    block_size = 4, and num_experts = 4:
    - There are 12 (token, k) entries and each expert receives 3 of them.
    - Flattened, topk_ids is [1, 2, 3, 0, 1, 3, 0, 2, 3, 0, 1, 2].
    - Each expert's group is padded with one sentinel 12 to reach 4 entries.
    - Grouping by expert gives
        [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
      The 12 entries are padding and are ignored by the matmul kernel.
    """
    device = topk_ids.device
    num_entries = topk_ids.numel()
    capacity = num_entries + num_experts * (block_size - 1)

    # Pre-fill with the padding sentinel; the op overwrites the real slots.
    sorted_ids = torch.full((capacity,), num_entries, dtype=torch.int32, device=device)
    num_blocks = triton.cdiv(capacity, block_size)
    expert_ids = torch.empty((num_blocks,), dtype=torch.int32, device=device)
    num_tokens_post_pad = torch.empty((1,), dtype=torch.int32, device=device)

    ops.moe_align_block_size(
        topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
    )
    return sorted_ids, expert_ids, num_tokens_post_pad
234
+
235
+
236
def invoke_fused_moe_kernel(
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    B_scale: Optional[torch.Tensor],
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    mul_routed_weight: bool,
    top_k: int,
    config: Dict[str, Any],
    compute_type: tl.dtype,
    use_fp8_w8a8: bool,
    use_int8_w8a16: bool,
) -> None:
    """Launch `fused_moe_kernel` computing C = A @ B[expert] for routed tokens.

    In the fp8_w8a8 path, A is first quantized with `ops.scaled_fp8_quant`
    and the returned scale replaces A_scale. The int8_w8a16 path requires
    B_scale. The unquantized path requires both scales to be None.

    `config` supplies the Triton tile meta-parameters (BLOCK_SIZE_M,
    BLOCK_SIZE_N, ...) and is forwarded to the kernel as keyword arguments.
    """
    assert topk_weights.stride(1) == 1
    assert sorted_token_ids.stride(0) == 1

    if use_fp8_w8a8:
        A, A_scale = ops.scaled_fp8_quant(A, A_scale)
        assert B_scale is not None
    elif use_int8_w8a16:
        assert B_scale is not None
    else:
        assert A_scale is None
        assert B_scale is None

    padded_len = sorted_token_ids.shape[0]

    def grid(meta):
        # One program per (M-block, N-block) tile.
        m_blocks = triton.cdiv(padded_len, meta["BLOCK_SIZE_M"])
        n_blocks = triton.cdiv(B.shape[1], meta["BLOCK_SIZE_N"])
        return (m_blocks * n_blocks,)

    # Channel-wise weight scales are only indexed by the kernel for int8_w8a16.
    has_channel_scales = use_int8_w8a16 and B_scale is not None
    scale_stride_e = B_scale.stride(0) if has_channel_scales else 0
    scale_stride_n = B_scale.stride(1) if has_channel_scales else 0

    fused_moe_kernel[grid](
        A,
        B,
        C,
        A_scale,
        B_scale,
        topk_weights,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        B.shape[1],
        B.shape[2],
        padded_len,
        topk_ids.numel(),
        A.stride(0),
        A.stride(1),
        B.stride(0),
        B.stride(2),
        B.stride(1),
        C.stride(1),
        C.stride(2),
        scale_stride_e,
        scale_stride_n,
        MUL_ROUTED_WEIGHT=mul_routed_weight,
        top_k=top_k,
        compute_type=compute_type,
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        **config,
    )
301
+
302
+
303
def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
    """Return the tuned-config JSON file name for this MoE shape and device.

    Format: ``E={E},N={N},device_name={device}[,dtype={dtype}].json``.
    Spaces in the device name become underscores, and the dtype field is
    omitted when ``dtype`` is falsy.
    """
    fields = [f"E={E}", f"N={N}", f"device_name={get_device_name().replace(' ', '_')}"]
    if dtype:
        fields.append(f"dtype={dtype}")
    return ",".join(fields) + ".json"
307
+
308
+
309
@functools.lru_cache
def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]:
    """
    Load tuned fused-MoE kernel configurations for (E, N, dtype), if present.

    The configs are read from ``configs/<name>.json`` next to this module,
    where ``<name>`` comes from `get_config_file_name`. The result maps batch
    sizes (JSON keys converted to int) to kernel configurations. Callers
    pick the entry whose batch size is closest to the actual one.
    Returns None, after logging a warning, when no file exists. Results are
    memoized per argument tuple.
    """
    config_dir = os.path.join(os.path.dirname(os.path.realpath(__file__)), "configs")
    config_file_path = os.path.join(config_dir, get_config_file_name(E, N, dtype))

    if not os.path.exists(config_file_path):
        # No tuned file: the caller falls back to a default configuration.
        logger.warning(
            (
                "Using default MoE config. Performance might be sub-optimal! "
                "Config file not found at %s"
            ),
            config_file_path,
        )
        return None

    with open(config_file_path) as f:
        logger.info("Using configuration from %s for MoE layer.", config_file_path)
        raw_configs = json.load(f)
    return {int(batch_size): cfg for batch_size, cfg in raw_configs.items()}
343
+
344
+
345
def get_default_config(
    M: int,
    E: int,
    N: int,
    K: int,
    topk: int,
    dtype: Optional[str],
    is_marlin: bool,
) -> Dict[str, int]:
    """Return fallback Triton tile sizes when no tuned config file exists.

    Only M, E and is_marlin affect the result. The remaining parameters are
    accepted to keep the signature uniform with the tuned-config lookup.
    Small batches (M <= E, or M <= 32 for fused marlin) get narrow tiles.
    All other batches get the larger default tiles.
    """
    small_batch = M <= E or (is_marlin and M <= 32)
    if small_batch:
        # A heuristic: fused marlin works faster with this config for small M
        return {
            "BLOCK_SIZE_M": 16,
            "BLOCK_SIZE_N": 32,
            "BLOCK_SIZE_K": 64,
            "GROUP_SIZE_M": 1,
        }
    return {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 32,
        "GROUP_SIZE_M": 8,
    }
369
+
370
+
371
def try_get_optimal_moe_config(
    w1_shape: Tuple[int, ...],
    w2_shape: Tuple[int, ...],
    top_k: int,
    dtype: Optional[str],
    M: int,
    is_marlin: bool = False,
):
    """Pick the kernel config for a batch of M tokens.

    Priority order:
    1. a user override from `fused_moe_triton.get_config()`, if truthy;
    2. the tuned config whose batch size is closest to M;
    3. `get_default_config` as the last resort.
    """
    from sglang.srt.layers.fused_moe_triton import get_config

    override_config = get_config()
    if override_config:
        return override_config

    E, _, N = w2_shape
    configs = get_moe_configs(E, N, dtype)
    if not configs:
        return get_default_config(M, E, N, w1_shape[2], top_k, dtype, is_marlin)

    # Nearest tuned batch size (first one wins on ties).
    nearest_batch = min(configs, key=lambda batch: abs(batch - M))
    return configs[nearest_batch]
397
+
398
+
399
def fused_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
):
    """Softmax-route each token to its top-k experts with the fused custom op.

    Returns (topk_weights, topk_ids): float32 weights and int32 expert ids,
    each of shape (num_tokens, topk). When `renormalize` is set, the weights
    of each token are rescaled to sum to 1.
    """
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

    num_tokens, _ = hidden_states.shape
    device = hidden_states.device

    def _alloc(dtype: torch.dtype) -> torch.Tensor:
        return torch.empty(num_tokens, topk, dtype=dtype, device=device)

    topk_weights = _alloc(torch.float32)
    topk_ids = _alloc(torch.int32)
    token_expert_indices = _alloc(torch.int32)  # Not used. Will be used in the future.

    # TODO(woosuk): Optimize this.
    ops.topk_softmax(topk_weights, topk_ids, token_expert_indices, gating_output.float())
    del token_expert_indices

    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    return topk_weights, topk_ids
429
+
430
+
431
# This is used by the Deepseek-V2 model
def grouped_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    num_expert_group: int = 0,
    topk_group: int = 0,
):
    """Top-k expert selection restricted to the best `topk_group` expert groups.

    Experts are split into `num_expert_group` contiguous groups. Each group
    is scored by its highest softmax probability. Only experts in the
    `topk_group` best groups remain eligible, and the final `topk` experts
    are chosen among them.

    Returns (topk_weights, topk_ids) as float32 / int32 tensors of shape
    (num_tokens, topk). The order within the top-k is not sorted.
    """
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

    probs = torch.softmax(gating_output, dim=-1)  # [n, e]
    n_tokens = probs.shape[0]

    # Score every group by its single strongest expert.
    best_per_group = (
        probs.view(n_tokens, num_expert_group, -1).max(dim=-1).values
    )  # [n, n_group]
    experts_per_group = probs.shape[-1] // num_expert_group

    chosen_groups = torch.topk(
        best_per_group, k=topk_group, dim=-1, sorted=False
    ).indices  # [n, top_k_group]
    group_keep = torch.zeros_like(best_per_group)  # [n, n_group]
    group_keep.scatter_(1, chosen_groups, 1)

    # Broadcast the per-group mask to every expert in that group.
    expert_keep = group_keep.repeat_interleave(experts_per_group, dim=1)  # [n, e]
    eligible = probs.masked_fill(~expert_keep.bool(), 0.0)  # [n, e]

    topk_weights, topk_ids = torch.topk(eligible, k=topk, dim=-1, sorted=False)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    return topk_weights.to(torch.float32), topk_ids.to(torch.int32)
465
+
466
+
467
def get_config_dtype_str(
    dtype: torch.dtype,
    use_int8_w8a16: Optional[bool] = False,
    use_fp8_w8a8: Optional[bool] = False,
):
    """Return the dtype tag used in tuned-config file names, or None.

    fp8_w8a8 takes precedence over int8_w8a16. Unquantized float32 gets its
    own "float32" tag. Other unquantized dtypes share the untagged configs.
    """
    if use_fp8_w8a8:
        return "fp8_w8a8"
    if use_int8_w8a16:
        return "int8_w8a16"
    # avoiding cases where kernel fails when float32 MoE
    # use fp16/bfloat16 configs
    return "float32" if dtype == torch.float else None
481
+
482
+
483
def inplace_fused_experts(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    use_fp8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
) -> None:
    """Custom-op body: run the fused experts and write into `hidden_states`.

    Returns nothing. The result lives in the mutated `hidden_states`.
    """
    fused_experts_impl(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        inplace=True,
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
    )
510
+
511
+
512
def inplace_fused_experts_fake(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    use_fp8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
) -> None:
    """Fake (shape-only) implementation of the in-place op: does nothing."""
    return None
526
+
527
+
528
# Register the in-place variant; fused_experts() calls it as
# torch.ops.sglang.inplace_fused_experts. It writes its output into
# `hidden_states`, so that argument is declared as mutated. The fake impl
# presumably serves shape propagation during tracing.
direct_register_custom_op(
    op_name="inplace_fused_experts",
    op_func=inplace_fused_experts,
    fake_impl=inplace_fused_experts_fake,
    mutates_args=["hidden_states"],
)
534
+
535
+
536
def outplace_fused_experts(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    use_fp8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Custom-op body: run the fused experts into a freshly allocated output."""
    return fused_experts_impl(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        inplace=False,
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
    )
563
+
564
+
565
def outplace_fused_experts_fake(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    use_fp8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """Fake (shape-only) implementation: an uninitialized tensor like the input."""
    placeholder = torch.empty_like(hidden_states)
    return placeholder
579
+
580
+
581
# Register the out-of-place variant; fused_experts() calls it as
# torch.ops.sglang.outplace_fused_experts. It allocates its own output, so
# no argument is declared as mutated.
direct_register_custom_op(
    op_name="outplace_fused_experts",
    op_func=outplace_fused_experts,
    fake_impl=outplace_fused_experts_fake,
    mutates_args=[],
)
587
+
588
+
589
def fused_experts(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    inplace: bool = False,
    use_fp8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
):
    """Dispatch to the registered in-place or out-of-place fused-experts op.

    With `inplace=True`, the result is written into `hidden_states`, which is
    then returned. Otherwise a new tensor from the out-of-place op is returned.
    """
    op_args = (
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        use_fp8_w8a8,
        use_int8_w8a16,
        w1_scale,
        w2_scale,
        a1_scale,
        a2_scale,
    )
    if not inplace:
        return torch.ops.sglang.outplace_fused_experts(*op_args)

    torch.ops.sglang.inplace_fused_experts(*op_args)
    return hidden_states
632
+
633
+
634
def fused_experts_impl(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    inplace: bool = False,
    use_fp8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
):
    """Run both expert GEMMs of a MoE layer and reduce over the top-k experts.

    Tokens are processed in chunks of at most CHUNK_SIZE. For each chunk:
      1. routed (token, expert) pairs are grouped into expert-aligned blocks,
      2. hidden_states @ w1[expert]        -> intermediate_cache1 (M, top_k, N),
      3. silu_and_mul halves the last dim  -> intermediate_cache2 (M*top_k, N//2),
      4. cache2 @ w2[expert] * topk_weight -> intermediate_cache3 (M, top_k, out),
      5. cache3 is summed over top_k into the output rows of the chunk.

    Args:
        hidden_states: contiguous (num_tokens, hidden_size) tensor of dtype
            float32, float16 or bfloat16.
        w1: contiguous expert weights of shape (E, N, hidden_size).
        w2: contiguous expert weights; w2.shape[1] is the output size.
        topk_weights: routing weights, same shape as topk_ids.
        topk_ids: selected expert ids; topk_ids.shape[1] is top_k.
        inplace: if True, write the result into `hidden_states`.
        use_fp8_w8a8 / use_int8_w8a16: quantized-weight modes of the kernel.
        w1_scale, w2_scale, a1_scale, a2_scale: weight / activation scales
            for the quantized modes.

    Returns:
        The output tensor (`hidden_states` itself when `inplace` is True).
    """
    # Check constraints.
    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
    assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]

    num_tokens, _ = hidden_states.shape
    E, N, _ = w1.shape
    top_k = topk_ids.shape[1]
    # We execute the fused_moe kernel in chunks to circumvent this issue:
    # https://github.com/vllm-project/vllm/issues/5938
    CHUNK_SIZE = 64 * 1024
    M = min(num_tokens, CHUNK_SIZE)
    config_dtype = get_config_dtype_str(
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        dtype=hidden_states.dtype,
    )

    get_config_func = functools.partial(
        try_get_optimal_moe_config,
        w1.shape,
        w2.shape,
        top_k,
        config_dtype,
    )

    config = get_config_func(M)

    intermediate_cache1 = torch.empty(
        (M, top_k, N),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    # Flattened over (token, k): one row per routed pair.
    intermediate_cache2 = torch.empty(
        (M * top_k, N // 2),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )
    intermediate_cache3 = torch.empty(
        (M, top_k, w2.shape[1]),
        device=hidden_states.device,
        dtype=hidden_states.dtype,
    )

    # Match the kernel's output precision to the activations. Previously
    # float32 inputs fell through to fp16, truncating results to half
    # precision before they were stored into float32 buffers.
    if hidden_states.dtype == torch.bfloat16:
        compute_type = tl.bfloat16
    elif hidden_states.dtype == torch.float16:
        compute_type = tl.float16
    else:
        compute_type = tl.float32

    if inplace:
        out_hidden_states = hidden_states
    else:
        out_hidden_states = torch.empty_like(hidden_states)

    for chunk in range((num_tokens // CHUNK_SIZE) + 1):
        begin_chunk_idx, end_chunk_idx = (
            chunk * CHUNK_SIZE,
            min((chunk + 1) * CHUNK_SIZE, num_tokens),
        )
        curr_hidden_states = hidden_states[begin_chunk_idx:end_chunk_idx]
        tokens_in_chunk, _ = curr_hidden_states.shape

        if tokens_in_chunk == 0:
            break

        if tokens_in_chunk < CHUNK_SIZE and chunk > 0:
            # Adjust the intermediate cache size and config for the last
            # chunk. Note that in most cases we only have one chunk
            # so the cache size and config are already set correctly and
            # do not need to be adjusted.
            intermediate_cache1 = intermediate_cache1[:tokens_in_chunk]
            # cache2 has one row per (token, k) pair, so it must be sliced
            # by tokens_in_chunk * top_k rather than tokens_in_chunk.
            intermediate_cache2 = intermediate_cache2[: tokens_in_chunk * top_k]
            intermediate_cache3 = intermediate_cache3[:tokens_in_chunk]
            config = get_config_func(tokens_in_chunk)

        curr_topk_ids = topk_ids[begin_chunk_idx:end_chunk_idx]
        curr_topk_weights = topk_weights[begin_chunk_idx:end_chunk_idx]

        sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
            curr_topk_ids, config["BLOCK_SIZE_M"], E
        )

        # First GEMM: no routing-weight multiplication yet.
        invoke_fused_moe_kernel(
            curr_hidden_states,
            w1,
            intermediate_cache1,
            a1_scale,
            w1_scale,
            curr_topk_weights,
            curr_topk_ids,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            False,
            top_k,
            config,
            compute_type=compute_type,
            use_fp8_w8a8=use_fp8_w8a8,
            use_int8_w8a16=use_int8_w8a16,
        )

        ops.silu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))

        # Second GEMM: cache2 rows are already per (token, k), so top_k=1
        # here; the routing weights are applied inside the kernel.
        invoke_fused_moe_kernel(
            intermediate_cache2,
            w2,
            intermediate_cache3,
            a2_scale,
            w2_scale,
            curr_topk_weights,
            curr_topk_ids,
            sorted_token_ids,
            expert_ids,
            num_tokens_post_padded,
            True,
            1,
            config,
            compute_type=compute_type,
            use_fp8_w8a8=use_fp8_w8a8,
            use_int8_w8a16=use_int8_w8a16,
        )

        # Reduce over the top_k experts straight into the output rows.
        torch.sum(
            intermediate_cache3,
            dim=1,
            out=out_hidden_states[begin_chunk_idx:end_chunk_idx],
        )
    return out_hidden_states
775
+
776
+
777
def fused_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    inplace: bool = False,
    use_grouped_topk: bool = False,
    num_expert_group: Optional[int] = None,
    topk_group: Optional[int] = None,
    custom_routing_function: Optional[Callable] = None,
    use_fp8_w8a8: bool = False,
    use_int8_w8a16: bool = False,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    This function computes a Mixture of Experts (MoE) layer using two sets of
    weights, w1 and w2, and top-k gating mechanism.

    Parameters:
    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
    - w1 (torch.Tensor): The first set of expert weights. w1.shape[0] must
        equal the number of experts in gating_output.
    - w2 (torch.Tensor): The second set of expert weights.
    - gating_output (torch.Tensor): The output of the gating operation
        (before softmax), one row per token and one column per expert.
    - topk (int): The number of top-k experts to select.
    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
    - inplace (bool): If True, perform the operation in-place.
        Defaults to False.
    - use_grouped_topk: If True, use grouped_topk instead of fused_topk.
        This takes precedence over custom_routing_function.
        note: Deepseekv2 model uses grouped_topk
    - num_expert_group: Optional[int]: additional parameter for grouped_topk;
        required when use_grouped_topk is True.
    - topk_group: Optional[int]: additional parameter for grouped_topk;
        required when use_grouped_topk is True.
    - custom_routing_function (Optional[Callable]): called as
        fn(hidden_states, gating_output, topk, renormalize) and must return
        (topk_weights, topk_ids). Used instead of fused_topk when given and
        use_grouped_topk is False.
    - use_fp8_w8a8 (bool): If True, use fp8 arithmetic to compute the inner
        products for w1 and w2. Defaults to False.
    - use_int8_w8a16 (bool): If True, w1 and w2 hold int8 weights with
        per-channel scales (w1_scale / w2_scale). They are upcast to the 16-bit
        activation dtype for the inner products. Defaults to False.
    - w1_scale (Optional[torch.Tensor]): Optional scale to be used for
        w1.
    - w2_scale (Optional[torch.Tensor]): Optional scale to be used for
        w2.
    - a1_scale (Optional[torch.Tensor]): Optional activation scale for the
        input of the first GEMM (fp8_w8a8 only).
    - a2_scale (Optional[torch.Tensor]): Optional activation scale for the
        input of the second GEMM (fp8_w8a8 only).

    Returns:
    - torch.Tensor: The output tensor after applying the MoE layer.
    """
    # Check constraints.
    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"

    if use_grouped_topk:
        assert num_expert_group is not None and topk_group is not None
        topk_weights, topk_ids = grouped_topk(
            hidden_states,
            gating_output,
            topk,
            renormalize,
            num_expert_group,
            topk_group,
        )
    elif custom_routing_function is None:
        topk_weights, topk_ids = fused_topk(
            hidden_states, gating_output, topk, renormalize
        )
    else:
        topk_weights, topk_ids = custom_routing_function(
            hidden_states, gating_output, topk, renormalize
        )

    return fused_experts(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        inplace=inplace,
        use_fp8_w8a8=use_fp8_w8a8,
        use_int8_w8a16=use_int8_w8a16,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
    )