sglang 0.1.14__py3-none-any.whl → 0.1.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81) hide show
  1. sglang/__init__.py +59 -2
  2. sglang/api.py +40 -11
  3. sglang/backend/anthropic.py +17 -3
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +160 -12
  6. sglang/backend/runtime_endpoint.py +62 -27
  7. sglang/backend/vertexai.py +1 -0
  8. sglang/bench_latency.py +320 -0
  9. sglang/global_config.py +24 -3
  10. sglang/lang/chat_template.py +122 -6
  11. sglang/lang/compiler.py +2 -2
  12. sglang/lang/interpreter.py +206 -98
  13. sglang/lang/ir.py +98 -34
  14. sglang/lang/tracer.py +6 -4
  15. sglang/launch_server.py +4 -1
  16. sglang/launch_server_llavavid.py +32 -0
  17. sglang/srt/constrained/__init__.py +14 -6
  18. sglang/srt/constrained/fsm_cache.py +9 -2
  19. sglang/srt/constrained/jump_forward.py +113 -24
  20. sglang/srt/conversation.py +4 -2
  21. sglang/srt/flush_cache.py +18 -0
  22. sglang/srt/hf_transformers_utils.py +144 -3
  23. sglang/srt/layers/context_flashattention_nopad.py +1 -0
  24. sglang/srt/layers/extend_attention.py +20 -1
  25. sglang/srt/layers/fused_moe.py +596 -0
  26. sglang/srt/layers/logits_processor.py +190 -61
  27. sglang/srt/layers/radix_attention.py +62 -53
  28. sglang/srt/layers/token_attention.py +21 -9
  29. sglang/srt/managers/controller/cuda_graph_runner.py +196 -0
  30. sglang/srt/managers/controller/dp_worker.py +113 -0
  31. sglang/srt/managers/controller/infer_batch.py +908 -0
  32. sglang/srt/managers/controller/manager_multi.py +195 -0
  33. sglang/srt/managers/controller/manager_single.py +177 -0
  34. sglang/srt/managers/controller/model_runner.py +359 -0
  35. sglang/srt/managers/{router → controller}/radix_cache.py +102 -53
  36. sglang/srt/managers/controller/schedule_heuristic.py +65 -0
  37. sglang/srt/managers/controller/tp_worker.py +813 -0
  38. sglang/srt/managers/detokenizer_manager.py +42 -40
  39. sglang/srt/managers/io_struct.py +44 -10
  40. sglang/srt/managers/tokenizer_manager.py +224 -82
  41. sglang/srt/memory_pool.py +52 -59
  42. sglang/srt/model_config.py +97 -2
  43. sglang/srt/models/chatglm.py +399 -0
  44. sglang/srt/models/commandr.py +369 -0
  45. sglang/srt/models/dbrx.py +406 -0
  46. sglang/srt/models/gemma.py +34 -38
  47. sglang/srt/models/gemma2.py +436 -0
  48. sglang/srt/models/grok.py +738 -0
  49. sglang/srt/models/llama2.py +47 -37
  50. sglang/srt/models/llama_classification.py +107 -0
  51. sglang/srt/models/llava.py +92 -27
  52. sglang/srt/models/llavavid.py +298 -0
  53. sglang/srt/models/minicpm.py +366 -0
  54. sglang/srt/models/mixtral.py +302 -127
  55. sglang/srt/models/mixtral_quant.py +372 -0
  56. sglang/srt/models/qwen.py +40 -35
  57. sglang/srt/models/qwen2.py +33 -36
  58. sglang/srt/models/qwen2_moe.py +473 -0
  59. sglang/srt/models/stablelm.py +33 -39
  60. sglang/srt/models/yivl.py +19 -26
  61. sglang/srt/openai_api_adapter.py +411 -0
  62. sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +44 -19
  63. sglang/srt/sampling_params.py +2 -0
  64. sglang/srt/server.py +197 -481
  65. sglang/srt/server_args.py +190 -74
  66. sglang/srt/utils.py +460 -95
  67. sglang/test/test_programs.py +73 -10
  68. sglang/test/test_utils.py +226 -7
  69. sglang/utils.py +97 -27
  70. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/METADATA +74 -45
  71. sglang-0.1.21.dist-info/RECORD +82 -0
  72. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/WHEEL +1 -1
  73. sglang/srt/backend_config.py +0 -13
  74. sglang/srt/managers/router/infer_batch.py +0 -503
  75. sglang/srt/managers/router/manager.py +0 -79
  76. sglang/srt/managers/router/model_rpc.py +0 -686
  77. sglang/srt/managers/router/model_runner.py +0 -514
  78. sglang/srt/managers/router/scheduler.py +0 -70
  79. sglang-0.1.14.dist-info/RECORD +0 -64
  80. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/LICENSE +0 -0
  81. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,596 @@
1
+ # Adapted from
2
+ # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/layers/fused_moe/fused_moe.py#L1
3
+ """Fused MoE kernel."""
4
+ import functools
5
+ import json
6
+ import os
7
+ from typing import Any, Dict, Optional, Tuple
8
+
9
+ import torch
10
+ import triton
11
+ import triton.language as tl
12
+ from vllm import _custom_ops as ops
13
+ from vllm.logger import init_logger
14
+
15
+ logger = init_logger(__name__)
16
+
17
+
18
@triton.jit
def fused_moe_kernel(
    # Base pointers of the tensors involved
    a_ptr,
    b_ptr,
    c_ptr,
    a_scale_ptr,
    b_scale_ptr,
    topk_weights_ptr,
    sorted_token_ids_ptr,
    expert_ids_ptr,
    num_tokens_post_padded_ptr,
    # Problem sizes
    N,
    K,
    EM,
    num_valid_tokens,
    # Strides: how far a pointer advances for a step of one element along
    # the named dimension (e.g. `stride_am` moves `a_ptr` down by one row).
    stride_am,
    stride_ak,
    stride_be,
    stride_bk,
    stride_bn,
    stride_cm,
    stride_cn,
    # Compile-time meta-parameters
    BLOCK_SIZE_M: tl.constexpr,
    BLOCK_SIZE_N: tl.constexpr,
    BLOCK_SIZE_K: tl.constexpr,
    GROUP_SIZE_M: tl.constexpr,
    MUL_ROUTED_WEIGHT: tl.constexpr,
    top_k: tl.constexpr,
    compute_type: tl.constexpr,
    use_fp8: tl.constexpr,
):
    """
    Fused Mixture-of-Experts matmul: each program multiplies one block of
    expert-sorted token rows by the weight matrix of the expert that block
    belongs to, and writes one [BLOCK_SIZE_M, BLOCK_SIZE_N] tile of C.

    Key Parameters:
    - A: Token activations, shape (*, K) with K the feature dimension.
    - B: Stacked expert weights, shape (E, N, K); E experts, N output
        features, K input features.
    - C: Output buffer, shape (M, topk, N). Its rows are addressed with the
        flattened (token, topk-slot) index read from `sorted_token_ids`.
    - sorted_token_ids: Flattened token indices (each token repeated topk
        times) grouped by assigned expert. Entries that are not smaller
        than `num_valid_tokens` are treated as padding and masked out.
    - expert_ids: One expert index per BLOCK_SIZE_M-row block, selecting
        which slice of B the block uses.
    - num_tokens_post_padded: Single-element tensor; programs whose first
        row lies at or beyond this count exit immediately.

    The per-expert padding of `sorted_token_ids` up to a multiple of
    BLOCK_SIZE_M is what allows every block to use a single expert matrix.
    """
    # -----------------------------------------------------------
    # Translate the flat program id into a (pid_m, pid_n) tile coordinate.
    # Tiles are visited in groups of GROUP_SIZE_M row-tiles; per the
    # upstream kernel this ordering is meant to improve L2 reuse.
    pid = tl.program_id(axis=0)
    tiles_m = tl.cdiv(EM, BLOCK_SIZE_M)
    tiles_n = tl.cdiv(N, BLOCK_SIZE_N)
    tiles_per_group = GROUP_SIZE_M * tiles_n
    first_tile_m = (pid // tiles_per_group) * GROUP_SIZE_M
    # The last group may hold fewer than GROUP_SIZE_M row-tiles.
    rows_in_group = min(tiles_m - first_tile_m, GROUP_SIZE_M)
    pid_in_group = pid % tiles_per_group
    pid_m = first_tile_m + (pid_in_group % rows_in_group)
    pid_n = pid_in_group // rows_in_group

    # ----------------------------------------------------------
    # Build the pointer tiles for the first K-slice of A and B:
    #   a_ptrs -> [BLOCK_SIZE_M, BLOCK_SIZE_K]
    #   b_ptrs -> [BLOCK_SIZE_K, BLOCK_SIZE_N]
    # They are advanced along K inside the accumulation loop.
    padded_token_count = tl.load(num_tokens_post_padded_ptr)
    if pid_m * BLOCK_SIZE_M >= padded_token_count:
        return
    sorted_slots = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
    token_ids = tl.load(sorted_token_ids_ptr + sorted_slots)
    row_is_valid = token_ids < num_valid_tokens

    # Column offsets wrap modulo N so B loads stay in range; the store
    # below masks the out-of-range columns instead.
    cols_b = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
    k_range = tl.arange(0, BLOCK_SIZE_K)
    # `token_ids // top_k` maps the flattened (token, slot) index back to
    # the row of A it originated from.
    a_ptrs = a_ptr + (
        token_ids[:, None] // top_k * stride_am + k_range[None, :] * stride_ak
    )

    expert = tl.load(expert_ids_ptr + pid_m)
    b_ptrs = (
        b_ptr
        + expert * stride_be
        + (k_range[:, None] * stride_bk + cols_b[None, :] * stride_bn)
    )

    if use_fp8:
        # One scale is read for A and one per expert for B.
        a_scale = tl.load(a_scale_ptr)
        b_scale = tl.load(b_scale_ptr + expert)

    # -----------------------------------------------------------
    # Accumulate the output tile in fp32; it is cast to `compute_type`
    # once the K loop has finished.
    acc = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)

    for k_step in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
        # Number of K elements not yet consumed; masks the ragged tail.
        k_left = K - k_step * BLOCK_SIZE_K
        a_tile = tl.load(
            a_ptrs,
            mask=row_is_valid[:, None] & (k_range[None, :] < k_left),
            other=0.0,
        )
        b_tile = tl.load(b_ptrs, mask=k_range[:, None] < k_left, other=0.0)
        # Accumulate along K.
        if use_fp8:
            acc = tl.dot(a_tile, b_tile, acc=acc)
        else:
            acc += tl.dot(a_tile, b_tile)
        # Step both pointer tiles to the next K-slice.
        a_ptrs += BLOCK_SIZE_K * stride_ak
        b_ptrs += BLOCK_SIZE_K * stride_bk

    if MUL_ROUTED_WEIGHT:
        # Scale every row by its routing weight (0 for padding rows).
        routed_weight = tl.load(
            topk_weights_ptr + token_ids, mask=row_is_valid, other=0
        )
        acc = acc * routed_weight[:, None]

    if use_fp8:
        acc = (acc * a_scale * b_scale).to(compute_type)
    else:
        acc = acc.to(compute_type)
    # -----------------------------------------------------------
    # Store the finished tile, skipping padding rows and columns >= N.
    cols_c = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
    c_ptrs = c_ptr + stride_cm * token_ids[:, None] + stride_cn * cols_c[None, :]
    store_mask = row_is_valid[:, None] & (cols_c[None, :] < N)
    tl.store(c_ptrs, acc, mask=store_mask)
164
+
165
+
166
def moe_align_block_size(
    topk_ids: torch.Tensor, block_size: int, num_experts: int
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Group the (token, expert) assignments by expert and pad every expert's
    group so its length is a multiple of `block_size`.

    Parameters:
    - topk_ids: Tensor of shape [total_tokens, top_k] holding the expert
        index chosen for each token and top-k slot.
    - block_size: Row-block size of the block matrix multiplication.
    - num_experts: Total number of experts.

    Returns:
    - sorted_token_ids: Flattened assignment indices ordered by expert.
        The buffer is pre-filled with `topk_ids.numel()`, so slots that
        are not overwritten with a real index hold that out-of-range
        value and act as padding.
    - expert_ids: The expert index assigned to each block.
    - num_tokens_post_padded: Single-element tensor with the total number
        of slots after padding.

    Example (labels as in the upstream vLLM docstring):
    Given topk_ids = [[2, 3, 4], [1, 2, 4], [1, 3, 4], [1, 2, 3]],
    block_size = 4, and num_experts = 4:
    - Flattening gives 12 assignments
        [2, 3, 4, 1, 2, 4, 1, 3, 4, 1, 2, 3], three per expert.
    - Each expert group is padded by one slot to reach block_size = 4; the
        padding value is 12, one past the last real assignment index.
    - Sorted by expert this yields
        [3, 6, 9, 12, 0, 4, 10, 12, 1, 7, 11, 12, 2, 5, 8, 12].
    - Index 12 does not exist and is skipped by the matmul kernel.
    """
    device = topk_ids.device
    num_assignments = topk_ids.numel()
    # Worst case: every expert needs block_size - 1 padding slots.
    padded_capacity = num_assignments + num_experts * (block_size - 1)

    sorted_ids = torch.empty((padded_capacity,), dtype=torch.int32, device=device)
    # Any slot left untouched by the op below keeps this padding marker.
    sorted_ids.fill_(num_assignments)

    block_capacity = triton.cdiv(padded_capacity, block_size)
    expert_ids = torch.empty((block_capacity,), dtype=torch.int32, device=device)
    num_tokens_post_pad = torch.empty((1,), dtype=torch.int32, device=device)

    # The vLLM custom op writes its results into the three output buffers.
    ops.moe_align_block_size(
        topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
    )
    return sorted_ids, expert_ids, num_tokens_post_pad
220
+
221
+
222
def invoke_fused_moe_kernel(
    A: torch.Tensor,
    B: torch.Tensor,
    C: torch.Tensor,
    A_scale: Optional[torch.Tensor],
    B_scale: Optional[torch.Tensor],
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    sorted_token_ids: torch.Tensor,
    expert_ids: torch.Tensor,
    num_tokens_post_padded: torch.Tensor,
    mul_routed_weight: bool,
    top_k: int,
    config: Dict[str, Any],
    compute_type: tl.dtype,
    use_fp8: bool,
) -> None:
    """
    Launch `fused_moe_kernel` for one expert matmul; results land in `C`.

    Parameters:
    - A: Input activations; row stride and column stride are passed on.
    - B: Stacked expert weights indexed as (expert, N, K).
    - C: Output buffer; its dim-1 and dim-2 strides are handed to the kernel.
    - A_scale, B_scale: fp8 scales. Both must be None unless `use_fp8`;
        with `use_fp8`, `B_scale` is required.
    - topk_weights, topk_ids: Routing weights and expert ids per token.
    - sorted_token_ids, expert_ids, num_tokens_post_padded: Outputs of
        `moe_align_block_size`.
    - mul_routed_weight: Whether the kernel scales rows by routing weight.
    - top_k: Divisor used to map a flattened assignment index to a row of A.
    - config: Kernel launch meta-parameters, forwarded as keyword arguments.
    - compute_type: Triton dtype the accumulator is cast to before storing.
    - use_fp8: Quantize A to fp8 and apply the scales inside the kernel.
    """
    assert topk_weights.stride(1) == 1
    assert sorted_token_ids.stride(0) == 1

    if use_fp8:
        # Replaces A by its quantized form together with the scale to use.
        # NOTE(review): presumably a scale is derived when A_scale is None —
        # confirm against vLLM's scaled_fp8_quant.
        A, A_scale = ops.scaled_fp8_quant(A, A_scale)
        assert B_scale is not None
    else:
        assert A_scale is None
        assert B_scale is None

    padded_rows = sorted_token_ids.shape[0]
    out_features = B.shape[1]
    in_features = B.shape[2]

    def grid(META):
        # 1-D launch: one program per (row-block, column-block) tile.
        return (
            triton.cdiv(padded_rows, META["BLOCK_SIZE_M"])
            * triton.cdiv(out_features, META["BLOCK_SIZE_N"]),
        )

    fused_moe_kernel[grid](
        A,
        B,
        C,
        A_scale,
        B_scale,
        topk_weights,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        out_features,
        in_features,
        padded_rows,
        topk_ids.numel(),
        A.stride(0),
        A.stride(1),
        B.stride(0),
        # B is laid out (E, N, K): the K stride is dim 2, the N stride dim 1.
        B.stride(2),
        B.stride(1),
        C.stride(1),
        C.stride(2),
        MUL_ROUTED_WEIGHT=mul_routed_weight,
        top_k=top_k,
        compute_type=compute_type,
        use_fp8=use_fp8,
        **config,
    )
281
+
282
+
283
def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
    """
    Build the file name of the tuned-config JSON for a given MoE shape.

    Parameters:
    - E: Number of experts.
    - N: Size value the configuration was tuned for.
    - dtype: Optional dtype tag; when falsy no dtype part is added.

    Returns:
    - str: "E=<E>,N=<N>,device_name=<name>[,dtype=<dtype>].json", where
        <name> is the current CUDA device name with spaces replaced by
        underscores.
    """
    gpu_name = torch.cuda.get_device_name()
    device_name = gpu_name.replace(" ", "_")
    if dtype:
        dtype_selector = f",dtype={dtype}"
    else:
        dtype_selector = ""
    return f"E={E},N={N},device_name={device_name}{dtype_selector}.json"
287
+
288
+
289
@functools.lru_cache
def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]:
    """
    Load the tuned fused-MoE kernel configurations for this shape, if any.

    The file is looked up in the `configs` directory next to this module,
    under the name produced by `get_config_file_name`. Results (including
    None) are memoized per (E, N, dtype).

    Returns:
    - A dictionary mapping batch sizes (the JSON keys converted to int) to
        kernel configurations. Callers pick the entry whose batch size is
        closest to theirs.
    - None when no config file exists, meaning the default configuration
        should be used.
    """
    json_file_name = get_config_file_name(E, N, dtype)
    module_dir = os.path.dirname(os.path.realpath(__file__))
    config_file_path = os.path.join(module_dir, "configs", json_file_name)

    # No tuned file for this shape/device: caller falls back to defaults.
    if not os.path.exists(config_file_path):
        return None

    with open(config_file_path) as f:
        logger.info("Using configuration from %s for MoE layer.", config_file_path)
        raw_configs = json.load(f)
    # JSON object keys are strings; the lookup expects integer batch sizes.
    return {int(batch_size): cfg for batch_size, cfg in raw_configs.items()}
316
+
317
+
318
def get_default_config(
    M: int,
    E: int,
    N: int,
    K: int,
    topk: int,
    dtype: Optional[str],
) -> Dict[str, int]:
    """
    Return the fallback kernel configuration used when no tuned one exists.

    Only `M`, `E` and `dtype` influence the result; `N`, `K` and `topk`
    are accepted but not consulted.

    Parameters:
    - M: Number of input tokens.
    - E: Number of experts.
    - dtype: "float8" selects the fp8 presets (which also set `num_warps`
        and `num_stages`); any other value selects the generic presets.

    Returns:
    - Dict[str, int]: A fresh dictionary of kernel meta-parameters. The
        smaller-tile preset is chosen when M <= E.
    """
    few_tokens = M <= E

    if dtype == "float8":
        if few_tokens:
            return {
                "BLOCK_SIZE_M": 64,
                "BLOCK_SIZE_N": 128,
                "BLOCK_SIZE_K": 128,
                "GROUP_SIZE_M": 1,
                "num_warps": 4,
                "num_stages": 4,
            }
        return {
            "BLOCK_SIZE_M": 128,
            "BLOCK_SIZE_N": 256,
            "BLOCK_SIZE_K": 128,
            "GROUP_SIZE_M": 32,
            "num_warps": 8,
            "num_stages": 4,
        }

    if few_tokens:
        return {
            "BLOCK_SIZE_M": 16,
            "BLOCK_SIZE_N": 32,
            "BLOCK_SIZE_K": 64,
            "GROUP_SIZE_M": 1,
        }
    return {
        "BLOCK_SIZE_M": 64,
        "BLOCK_SIZE_N": 64,
        "BLOCK_SIZE_K": 32,
        "GROUP_SIZE_M": 8,
    }
359
+
360
+
361
def fused_topk(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
):
    """
    Select the top-k experts per token with vLLM's `topk_softmax` op.

    Parameters:
    - hidden_states: 2-D token activations; only the token count and the
        device are read from it.
    - gating_output: Router output, converted to float32 before the op.
    - topk: Number of experts to select per token.
    - renormalize: If True, divide each token's weights by their sum.

    Returns:
    - topk_weights: float32 tensor of shape (num_tokens, topk).
    - topk_ids: int32 tensor of shape (num_tokens, topk).
    """
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

    num_tokens, _ = hidden_states.shape
    device = hidden_states.device
    out_shape = (num_tokens, topk)

    # Output buffers for the custom op to fill.
    topk_weights = torch.empty(out_shape, dtype=torch.float32, device=device)
    topk_ids = torch.empty(out_shape, dtype=torch.int32, device=device)
    token_expert_indicies = torch.empty(out_shape, dtype=torch.int32, device=device)

    router_logits = gating_output.float()  # TODO(woosuk): Optimize this.
    ops.topk_softmax(topk_weights, topk_ids, token_expert_indicies, router_logits)
    # The op requires this buffer, but nothing downstream reads it yet.
    del token_expert_indicies

    if renormalize:
        weight_sums = topk_weights.sum(dim=-1, keepdim=True)
        topk_weights = topk_weights / weight_sums
    return topk_weights, topk_ids
389
+
390
+
391
def fused_experts(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    topk_weights: torch.Tensor,
    topk_ids: torch.Tensor,
    inplace: bool = False,
    override_config: Optional[Dict[str, Any]] = None,
    use_fp8: bool = False,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
):
    """
    Run the expert MLPs for tokens whose experts have already been chosen.

    Pipeline: first fused matmul with `w1`, `gelu_and_mul` activation,
    second fused matmul with `w2` (scaled by the routing weights), then a
    sum over the top-k dimension.

    Parameters:
    - hidden_states: Contiguous (M, hidden) tensor of dtype float32,
        float16 or bfloat16.
    - w1: Contiguous expert weights of shape (E, N, hidden).
    - w2: Contiguous expert weights; dim 1 is the output width.
    - topk_weights, topk_ids: Routing weights and expert ids; same shape.
    - inplace: If True, the summed result is written into `hidden_states`.
    - override_config: Kernel configuration to use instead of the tuned
        or default one (ignored when falsy).
    - use_fp8: Run the matmuls with fp8 quantization.
    - w1_scale, w2_scale, a1_scale, a2_scale: fp8 scales for the weights
        and for the inputs of the first and second matmul.

    Returns:
    - torch.Tensor: Per-token output, summed over the selected experts.
    """
    # Check constraints.
    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
    assert hidden_states.dtype in (torch.float32, torch.float16, torch.bfloat16)

    M, _ = hidden_states.shape
    E, N, _ = w1.shape
    experts_per_token = topk_ids.shape[1]
    dtype_tag = "float8" if use_fp8 else None

    if override_config:
        config = override_config
    else:
        # Prefer a tuned configuration file; otherwise use the defaults.
        tuned = get_moe_configs(E, w2.shape[2], dtype_tag)
        if tuned:
            # Use the entry tuned for the batch size closest to M.
            nearest_batch = min(tuned, key=lambda batch: abs(batch - M))
            config = tuned[nearest_batch]
        else:
            config = get_default_config(
                M, E, N, w1.shape[2], experts_per_token, dtype_tag
            )

    def scratch(*shape):
        # Uninitialized buffer matching the input's device and dtype.
        return torch.empty(
            shape, device=hidden_states.device, dtype=hidden_states.dtype
        )

    intermediate_cache1 = scratch(M, experts_per_token, N)
    # The activation output has half the width of its input.
    intermediate_cache2 = scratch(M * experts_per_token, N // 2)
    intermediate_cache3 = scratch(M, experts_per_token, w2.shape[1])

    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
        topk_ids, config["BLOCK_SIZE_M"], E
    )

    # NOTE(review): float32 inputs pass the dtype assert above but fall
    # through to float16 here — confirm that is intended.
    if hidden_states.dtype == torch.bfloat16:
        compute_type = tl.bfloat16
    else:
        compute_type = tl.float16

    # First matmul: hidden_states x w1, one output row per (token, expert).
    invoke_fused_moe_kernel(
        hidden_states,
        w1,
        intermediate_cache1,
        a1_scale,
        w1_scale,
        topk_weights,
        topk_ids,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        mul_routed_weight=False,
        top_k=experts_per_token,
        config=config,
        compute_type=compute_type,
        use_fp8=use_fp8,
    )

    ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))

    # Second matmul: the rows are already expanded per (token, expert), so
    # top_k is 1; routing weights are applied here.
    invoke_fused_moe_kernel(
        intermediate_cache2,
        w2,
        intermediate_cache3,
        a2_scale,
        w2_scale,
        topk_weights,
        topk_ids,
        sorted_token_ids,
        expert_ids,
        num_tokens_post_padded,
        mul_routed_weight=True,
        top_k=1,
        config=config,
        compute_type=compute_type,
        use_fp8=use_fp8,
    )

    # Combine the weighted expert outputs of each token.
    per_expert_out = intermediate_cache3.view(*intermediate_cache3.shape)
    if not inplace:
        return torch.sum(per_expert_out, dim=1)
    return torch.sum(per_expert_out, dim=1, out=hidden_states)
498
+
499
+
500
def fused_moe(
    hidden_states: torch.Tensor,
    w1: torch.Tensor,
    w2: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
    inplace: bool = False,
    override_config: Optional[Dict[str, Any]] = None,
    use_fp8: bool = False,
    w1_scale: Optional[torch.Tensor] = None,
    w2_scale: Optional[torch.Tensor] = None,
    a1_scale: Optional[torch.Tensor] = None,
    a2_scale: Optional[torch.Tensor] = None,
) -> torch.Tensor:
    """
    Compute a full Mixture of Experts (MoE) layer: top-k routing from the
    gating output followed by the two-matmul expert computation.

    Parameters:
    - hidden_states (torch.Tensor): The input tensor to the MoE layer.
    - w1 (torch.Tensor): The first set of expert weights.
    - w2 (torch.Tensor): The second set of expert weights.
    - gating_output (torch.Tensor): The output of the gating operation
        (before softmax); dim 1 must equal the number of experts in w1.
    - topk (int): The number of experts to select per token.
    - renormalize (bool): If True, renormalize the top-k weights to sum to 1.
    - inplace (bool): If True, perform the operation in-place.
        Defaults to False.
    - override_config (Optional[Dict[str, Any]]): Optional override
        for the kernel configuration.
    - use_fp8 (bool): If True, use fp8 arithmetic to compute the inner
        products for w1 and w2. Defaults to False.
    - w1_scale (Optional[torch.Tensor]): Optional scale to be used for w1.
    - w2_scale (Optional[torch.Tensor]): Optional scale to be used for w2.
    - a1_scale, a2_scale (Optional[torch.Tensor]): Optional activation
        scales, forwarded unchanged to `fused_experts`.

    Returns:
    - torch.Tensor: The output tensor after applying the MoE layer.
    """
    # Check constraints.
    assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"

    # Use the op exposed on vLLM's `_custom_ops` when present; otherwise
    # fall back to the variant that imports `vllm._moe_C` directly
    # (presumably for vLLM 0.4.3, going by the fallback's name).
    select_topk = fused_topk if hasattr(ops, "topk_softmax") else fused_topk_v0_4_3
    topk_weights, topk_ids = select_topk(
        hidden_states, gating_output, topk, renormalize
    )

    return fused_experts(
        hidden_states,
        w1,
        w2,
        topk_weights,
        topk_ids,
        inplace=inplace,
        override_config=override_config,
        use_fp8=use_fp8,
        w1_scale=w1_scale,
        w2_scale=w2_scale,
        a1_scale=a1_scale,
        a2_scale=a2_scale,
    )
567
+
568
+
569
def fused_topk_v0_4_3(
    hidden_states: torch.Tensor,
    gating_output: torch.Tensor,
    topk: int,
    renormalize: bool,
):
    """
    Select the top-k experts per token using `vllm._moe_C.topk_softmax`.

    This is the fallback used when `topk_softmax` is not exposed on vLLM's
    `_custom_ops` module; it mirrors `fused_topk`.

    Parameters:
    - hidden_states: 2-D token activations; only the token count and the
        device are read from it.
    - gating_output: Router output, converted to float32 before the op.
        Must have the same number of rows as `hidden_states`.
    - topk: Number of experts to select per token.
    - renormalize: If True, divide each token's weights by their sum.

    Returns:
    - topk_weights: float32 tensor of shape (num_tokens, topk).
    - topk_ids: int32 tensor of shape (num_tokens, topk).

    Raises:
    - AssertionError: If the token counts of the two inputs differ.
    """
    # Same guard as `fused_topk`. It runs before the lazy import so that
    # mismatched inputs fail fast and never reach the native kernel.
    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"

    # Imported lazily: this extension module is only needed on this path.
    import vllm._moe_C as moe_kernels

    M, _ = hidden_states.shape

    # Output buffers for the native op to fill.
    topk_weights = torch.empty(
        M, topk, dtype=torch.float32, device=hidden_states.device
    )
    topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
    token_expert_indicies = torch.empty(
        M, topk, dtype=torch.int32, device=hidden_states.device
    )
    moe_kernels.topk_softmax(
        topk_weights,
        topk_ids,
        token_expert_indicies,
        gating_output.float(),  # TODO(woosuk): Optimize this.
    )
    del token_expert_indicies  # Not used. Will be used in the future.
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)

    return topk_weights, topk_ids