sglang 0.1.17__py3-none-any.whl → 0.1.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (73)
  1. sglang/__init__.py +2 -2
  2. sglang/api.py +30 -4
  3. sglang/backend/litellm.py +2 -2
  4. sglang/backend/openai.py +26 -15
  5. sglang/backend/runtime_endpoint.py +18 -14
  6. sglang/bench_latency.py +317 -0
  7. sglang/global_config.py +5 -1
  8. sglang/lang/chat_template.py +41 -6
  9. sglang/lang/compiler.py +2 -2
  10. sglang/lang/interpreter.py +6 -2
  11. sglang/lang/ir.py +74 -28
  12. sglang/launch_server.py +4 -1
  13. sglang/launch_server_llavavid.py +2 -1
  14. sglang/srt/constrained/__init__.py +14 -6
  15. sglang/srt/constrained/fsm_cache.py +6 -3
  16. sglang/srt/constrained/jump_forward.py +113 -25
  17. sglang/srt/conversation.py +2 -0
  18. sglang/srt/flush_cache.py +2 -0
  19. sglang/srt/hf_transformers_utils.py +68 -9
  20. sglang/srt/layers/extend_attention.py +2 -1
  21. sglang/srt/layers/fused_moe.py +280 -169
  22. sglang/srt/layers/logits_processor.py +106 -42
  23. sglang/srt/layers/radix_attention.py +53 -29
  24. sglang/srt/layers/token_attention.py +4 -1
  25. sglang/srt/managers/controller/dp_worker.py +6 -3
  26. sglang/srt/managers/controller/infer_batch.py +144 -69
  27. sglang/srt/managers/controller/manager_multi.py +5 -5
  28. sglang/srt/managers/controller/manager_single.py +9 -4
  29. sglang/srt/managers/controller/model_runner.py +167 -55
  30. sglang/srt/managers/controller/radix_cache.py +4 -0
  31. sglang/srt/managers/controller/schedule_heuristic.py +2 -0
  32. sglang/srt/managers/controller/tp_worker.py +156 -134
  33. sglang/srt/managers/detokenizer_manager.py +19 -21
  34. sglang/srt/managers/io_struct.py +11 -5
  35. sglang/srt/managers/tokenizer_manager.py +16 -14
  36. sglang/srt/model_config.py +89 -4
  37. sglang/srt/models/chatglm.py +399 -0
  38. sglang/srt/models/commandr.py +2 -2
  39. sglang/srt/models/dbrx.py +1 -1
  40. sglang/srt/models/gemma.py +5 -1
  41. sglang/srt/models/gemma2.py +436 -0
  42. sglang/srt/models/grok.py +204 -137
  43. sglang/srt/models/llama2.py +12 -5
  44. sglang/srt/models/llama_classification.py +107 -0
  45. sglang/srt/models/llava.py +11 -8
  46. sglang/srt/models/llavavid.py +1 -1
  47. sglang/srt/models/minicpm.py +373 -0
  48. sglang/srt/models/mixtral.py +164 -115
  49. sglang/srt/models/mixtral_quant.py +0 -1
  50. sglang/srt/models/qwen.py +1 -1
  51. sglang/srt/models/qwen2.py +1 -1
  52. sglang/srt/models/qwen2_moe.py +454 -0
  53. sglang/srt/models/stablelm.py +1 -1
  54. sglang/srt/models/yivl.py +2 -2
  55. sglang/srt/openai_api_adapter.py +35 -25
  56. sglang/srt/openai_protocol.py +2 -2
  57. sglang/srt/server.py +69 -19
  58. sglang/srt/server_args.py +76 -43
  59. sglang/srt/utils.py +177 -35
  60. sglang/test/test_programs.py +28 -10
  61. sglang/utils.py +4 -3
  62. {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/METADATA +44 -31
  63. sglang-0.1.19.dist-info/RECORD +81 -0
  64. {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/WHEEL +1 -1
  65. sglang/srt/managers/router/infer_batch.py +0 -596
  66. sglang/srt/managers/router/manager.py +0 -82
  67. sglang/srt/managers/router/model_rpc.py +0 -818
  68. sglang/srt/managers/router/model_runner.py +0 -445
  69. sglang/srt/managers/router/radix_cache.py +0 -267
  70. sglang/srt/managers/router/scheduler.py +0 -59
  71. sglang-0.1.17.dist-info/RECORD +0 -81
  72. {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/LICENSE +0 -0
  73. {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/top_level.txt +0 -0
sglang/srt/layers/fused_moe.py
@@ -9,10 +9,8 @@ from typing import Any, Dict, Optional, Tuple
 import torch
 import triton
 import triton.language as tl
-
 from vllm import _custom_ops as ops
 from vllm.logger import init_logger
-from vllm.utils import is_hip
 
 logger = init_logger(__name__)
 
@@ -109,12 +107,16 @@ def fused_moe_kernel(
 
     offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
     offs_k = tl.arange(0, BLOCK_SIZE_K)
-    a_ptrs = a_ptr + (offs_token[:, None] // top_k * stride_am +
-                      offs_k[None, :] * stride_ak)
+    a_ptrs = a_ptr + (
+        offs_token[:, None] // top_k * stride_am + offs_k[None, :] * stride_ak
+    )
 
     off_experts = tl.load(expert_ids_ptr + pid_m)
-    b_ptrs = b_ptr + off_experts * stride_be + (offs_k[:, None] * stride_bk +
-                                                offs_bn[None, :] * stride_bn)
+    b_ptrs = (
+        b_ptr
+        + off_experts * stride_be
+        + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+    )
 
     if use_fp8:
         a_scale = tl.load(a_scale_ptr)
@@ -130,13 +132,12 @@ def fused_moe_kernel(
     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
         # Load the next block of A and B, generate a mask by checking the
         # K dimension.
-        a = tl.load(a_ptrs,
-                    mask=token_mask[:, None] &
-                    (offs_k[None, :] < K - k * BLOCK_SIZE_K),
-                    other=0.0)
-        b = tl.load(b_ptrs,
-                    mask=offs_k[:, None] < K - k * BLOCK_SIZE_K,
-                    other=0.0)
+        a = tl.load(
+            a_ptrs,
+            mask=token_mask[:, None] & (offs_k[None, :] < K - k * BLOCK_SIZE_K),
+            other=0.0,
+        )
+        b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
         # We accumulate along the K dimension.
         if use_fp8:
             accumulator = tl.dot(a, b, acc=accumulator)
@@ -147,9 +148,7 @@ def fused_moe_kernel(
         b_ptrs += BLOCK_SIZE_K * stride_bk
 
     if MUL_ROUTED_WEIGHT:
-        moe_weight = tl.load(topk_weights_ptr + offs_token,
-                             mask=token_mask,
-                             other=0)
+        moe_weight = tl.load(topk_weights_ptr + offs_token, mask=token_mask, other=0)
         accumulator = accumulator * moe_weight[:, None]
 
     if use_fp8:
@@ -159,15 +158,14 @@ def fused_moe_kernel(
     # -----------------------------------------------------------
     # Write back the block of the output
     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
-    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[
-        None, :]
+    c_ptrs = c_ptr + stride_cm * offs_token[:, None] + stride_cn * offs_cn[None, :]
     c_mask = token_mask[:, None] & (offs_cn[None, :] < N)
     tl.store(c_ptrs, accumulator, mask=c_mask)
 
 
 def moe_align_block_size(
-        topk_ids: torch.Tensor, block_size: int,
-        num_experts: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    topk_ids: torch.Tensor, block_size: int, num_experts: int
+) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     """
     Aligns the token distribution across experts to be compatible with block
     size for matrix multiplication.
@@ -206,32 +204,38 @@ def moe_align_block_size(
     by block_size for proper block matrix operations.
     """
     max_num_tokens_padded = topk_ids.numel() + num_experts * (block_size - 1)
-    sorted_ids = torch.empty((max_num_tokens_padded, ),
-                             dtype=torch.int32,
-                             device=topk_ids.device)
+    sorted_ids = torch.empty(
+        (max_num_tokens_padded,), dtype=torch.int32, device=topk_ids.device
+    )
     sorted_ids.fill_(topk_ids.numel())
     max_num_m_blocks = triton.cdiv(max_num_tokens_padded, block_size)
-    expert_ids = torch.empty((max_num_m_blocks, ),
-                             dtype=torch.int32,
-                             device=topk_ids.device)
-    num_tokens_post_pad = torch.empty((1),
-                                      dtype=torch.int32,
-                                      device=topk_ids.device)
-    ops.moe_align_block_size(topk_ids, num_experts, block_size, sorted_ids,
-                             expert_ids, num_tokens_post_pad)
+    expert_ids = torch.empty(
+        (max_num_m_blocks,), dtype=torch.int32, device=topk_ids.device
+    )
+    num_tokens_post_pad = torch.empty((1), dtype=torch.int32, device=topk_ids.device)
+    ops.moe_align_block_size(
+        topk_ids, num_experts, block_size, sorted_ids, expert_ids, num_tokens_post_pad
+    )
     return sorted_ids, expert_ids, num_tokens_post_pad
 
 
-def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
-                            A_scale: Optional[torch.Tensor],
-                            B_scale: Optional[torch.Tensor],
-                            topk_weights: torch.Tensor, topk_ids: torch.Tensor,
-                            sorted_token_ids: torch.Tensor,
-                            expert_ids: torch.Tensor,
-                            num_tokens_post_padded: torch.Tensor,
-                            mul_routed_weight: bool, top_k: int,
-                            config: Dict[str, Any], compute_type: tl.dtype,
-                            use_fp8: bool) -> None:
+def invoke_fused_moe_kernel(
+    A: torch.Tensor,
+    B: torch.Tensor,
+    C: torch.Tensor,
+    A_scale: Optional[torch.Tensor],
+    B_scale: Optional[torch.Tensor],
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    sorted_token_ids: torch.Tensor,
+    expert_ids: torch.Tensor,
+    num_tokens_post_padded: torch.Tensor,
+    mul_routed_weight: bool,
+    top_k: int,
+    config: Dict[str, Any],
+    compute_type: tl.dtype,
+    use_fp8: bool,
+) -> None:
     assert topk_weights.stride(1) == 1
     assert sorted_token_ids.stride(0) == 1
 
@@ -242,8 +246,10 @@ def invoke_fused_moe_kernel(A: torch.Tensor, B: torch.Tensor, C: torch.Tensor,
         A, A_scale = ops.scaled_fp8_quant(A, A_scale)
         assert B_scale is not None
 
-    grid = lambda META: (triton.cdiv(sorted_token_ids.shape[0], META[
-        'BLOCK_SIZE_M']) * triton.cdiv(B.shape[1], META['BLOCK_SIZE_N']), )
+    grid = lambda META: (
+        triton.cdiv(sorted_token_ids.shape[0], META["BLOCK_SIZE_M"])
+        * triton.cdiv(B.shape[1], META["BLOCK_SIZE_N"]),
+    )
 
     fused_moe_kernel[grid](
         A,
@@ -281,8 +287,7 @@ def get_config_file_name(E: int, N: int, dtype: Optional[str]) -> str:
 
 
 @functools.lru_cache
-def get_moe_configs(E: int, N: int,
-                    dtype: Optional[str]) -> Optional[Dict[int, Any]]:
+def get_moe_configs(E: int, N: int, dtype: Optional[str]) -> Optional[Dict[int, Any]]:
     """
     Return optimized configurations for the fused MoE kernel.
 
@@ -297,11 +302,11 @@ def get_moe_configs(E: int, N: int,
    json_file_name = get_config_file_name(E, N, dtype)
 
    config_file_path = os.path.join(
-        os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name)
+        os.path.dirname(os.path.realpath(__file__)), "configs", json_file_name
+    )
    if os.path.exists(config_file_path):
        with open(config_file_path) as f:
-            logger.info("Using configuration from %s for MoE layer.",
-                        config_file_path)
+            logger.info("Using configuration from %s for MoE layer.", config_file_path)
            # If a configuration has been found, return it
            return {int(key): val for key, val in json.load(f).items()}
 
@@ -310,6 +315,188 @@ def get_moe_configs(E: int, N: int,
     return None
 
 
+def get_default_config(
+    M: int,
+    E: int,
+    N: int,
+    K: int,
+    topk: int,
+    dtype: Optional[str],
+) -> Dict[str, int]:
+    if dtype == "float8":
+        config = {
+            "BLOCK_SIZE_M": 128,
+            "BLOCK_SIZE_N": 256,
+            "BLOCK_SIZE_K": 128,
+            "GROUP_SIZE_M": 32,
+            "num_warps": 8,
+            "num_stages": 4,
+        }
+        if M <= E:
+            config = {
+                "BLOCK_SIZE_M": 64,
+                "BLOCK_SIZE_N": 128,
+                "BLOCK_SIZE_K": 128,
+                "GROUP_SIZE_M": 1,
+                "num_warps": 4,
+                "num_stages": 4,
+            }
+    else:
+        config = {
+            "BLOCK_SIZE_M": 64,
+            "BLOCK_SIZE_N": 64,
+            "BLOCK_SIZE_K": 32,
+            "GROUP_SIZE_M": 8,
+        }
+        if M <= E:
+            config = {
+                "BLOCK_SIZE_M": 16,
+                "BLOCK_SIZE_N": 32,
+                "BLOCK_SIZE_K": 64,
+                "GROUP_SIZE_M": 1,
+            }
+    return config
+
+
+def fused_topk(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+):
+    assert hidden_states.shape[0] == gating_output.shape[0], "Number of tokens mismatch"
+
+    M, _ = hidden_states.shape
+
+    topk_weights = torch.empty(
+        M, topk, dtype=torch.float32, device=hidden_states.device
+    )
+    topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
+    token_expert_indicies = torch.empty(
+        M, topk, dtype=torch.int32, device=hidden_states.device
+    )
+    ops.topk_softmax(
+        topk_weights,
+        topk_ids,
+        token_expert_indicies,
+        gating_output.float(),  # TODO(woosuk): Optimize this.
+    )
+    del token_expert_indicies  # Not used. Will be used in the future.
+
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
+    return topk_weights, topk_ids
+
+
+def fused_experts(
+    hidden_states: torch.Tensor,
+    w1: torch.Tensor,
+    w2: torch.Tensor,
+    topk_weights: torch.Tensor,
+    topk_ids: torch.Tensor,
+    inplace: bool = False,
+    override_config: Optional[Dict[str, Any]] = None,
+    use_fp8: bool = False,
+    w1_scale: Optional[torch.Tensor] = None,
+    w2_scale: Optional[torch.Tensor] = None,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+):
+    # Check constraints.
+    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
+    assert topk_weights.shape == topk_ids.shape, "topk shape mismatch"
+    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
+    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
+    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
+    assert hidden_states.dtype in [torch.float32, torch.float16, torch.bfloat16]
+
+    M, _ = hidden_states.shape
+    E, N, _ = w1.shape
+
+    if override_config:
+        config = override_config
+    else:
+        # First try to load optimal config from the file
+        configs = get_moe_configs(E, w2.shape[2], "float8" if use_fp8 else None)
+
+        if configs:
+            # If an optimal configuration map has been found, look up the
+            # optimal config
+            config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
+        else:
+            # Else use the default config
+            config = get_default_config(
+                M, E, N, w1.shape[2], topk_ids.shape[1], "float8" if use_fp8 else None
+            )
+
+    intermediate_cache1 = torch.empty(
+        (M, topk_ids.shape[1], N),
+        device=hidden_states.device,
+        dtype=hidden_states.dtype,
+    )
+    intermediate_cache2 = torch.empty(
+        (M * topk_ids.shape[1], N // 2),
+        device=hidden_states.device,
+        dtype=hidden_states.dtype,
+    )
+    intermediate_cache3 = torch.empty(
+        (M, topk_ids.shape[1], w2.shape[1]),
+        device=hidden_states.device,
+        dtype=hidden_states.dtype,
+    )
+
+    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
+        topk_ids, config["BLOCK_SIZE_M"], E
+    )
+    compute_type = tl.bfloat16 if hidden_states.dtype == torch.bfloat16 else tl.float16
+
+    invoke_fused_moe_kernel(
+        hidden_states,
+        w1,
+        intermediate_cache1,
+        a1_scale,
+        w1_scale,
+        topk_weights,
+        topk_ids,
+        sorted_token_ids,
+        expert_ids,
+        num_tokens_post_padded,
+        False,
+        topk_ids.shape[1],
+        config,
+        compute_type=compute_type,
+        use_fp8=use_fp8,
+    )
+
+    ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+
+    invoke_fused_moe_kernel(
+        intermediate_cache2,
+        w2,
+        intermediate_cache3,
+        a2_scale,
+        w2_scale,
+        topk_weights,
+        topk_ids,
+        sorted_token_ids,
+        expert_ids,
+        num_tokens_post_padded,
+        True,
+        1,
+        config,
+        compute_type=compute_type,
+        use_fp8=use_fp8,
+    )
+
+    if inplace:
+        return torch.sum(
+            intermediate_cache3.view(*intermediate_cache3.shape),
+            dim=1,
+            out=hidden_states,
+        )
+    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape), dim=1)
+
+
 def fused_moe(
     hidden_states: torch.Tensor,
     w1: torch.Tensor,
@@ -352,134 +539,58 @@ def fused_moe(
     - torch.Tensor: The output tensor after applying the MoE layer.
     """
     # Check constraints.
-    assert hidden_states.shape[0] == gating_output.shape[0], (
-        "Number of tokens mismatch")
-    assert hidden_states.shape[1] == w1.shape[2], "Hidden size mismatch"
     assert gating_output.shape[1] == w1.shape[0], "Number of experts mismatch"
-    assert hidden_states.is_contiguous(), "Hidden_states must be contiguous"
-    assert w1.is_contiguous(), "Expert weights1 must be contiguous"
-    assert w2.is_contiguous(), "Expert weights2 must be contiguous"
-    assert hidden_states.dtype in [
-        torch.float32, torch.float16, torch.bfloat16
-    ]
-    M, _ = hidden_states.shape
-    E, N, _ = w1.shape
 
-    if is_hip():
-        # The MoE kernels are not yet supported on ROCm.
-        routing_weights = torch.softmax(gating_output,
-                                        dim=-1,
-                                        dtype=torch.float32)
-        topk_weights, topk_ids = torch.topk(routing_weights, topk, dim=-1)
-    else:
-        import vllm._moe_C as moe_kernels
-
-        topk_weights = torch.empty(M,
-                                   topk,
-                                   dtype=torch.float32,
-                                   device=hidden_states.device)
-        topk_ids = torch.empty(M,
-                               topk,
-                               dtype=torch.int32,
-                               device=hidden_states.device)
-        token_expert_indicies = torch.empty(M,
-                                            topk,
-                                            dtype=torch.int32,
-                                            device=hidden_states.device)
-        moe_kernels.topk_softmax(
-            topk_weights,
-            topk_ids,
-            token_expert_indicies,
-            gating_output.float(),  # TODO(woosuk): Optimize this.
+    if hasattr(ops, "topk_softmax"):
+        topk_weights, topk_ids = fused_topk(
+            hidden_states, gating_output, topk, renormalize
         )
-        del token_expert_indicies  # Not used. Will be used in the future.
-        if renormalize:
-            topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
-
-    if override_config:
-        config = override_config
     else:
-        # First try to load optimal config from the file
-        configs = get_moe_configs(E, w2.shape[2],
-                                  "float8" if use_fp8 else None)
+        topk_weights, topk_ids = fused_topk_v0_4_3(
+            hidden_states, gating_output, topk, renormalize
+        )
 
-        if configs:
-            # If an optimal configuration map has been found, look up the
-            # optimal config
-            config = configs[min(configs.keys(), key=lambda x: abs(x - M))]
-        else:
-            # Else use the default config
-            config = {
-                "BLOCK_SIZE_M": 128,
-                "BLOCK_SIZE_N": 64,
-                "BLOCK_SIZE_K": 128,
-                "GROUP_SIZE_M": 1,
-                "num_warps": 4,
-                "num_stages": 4
-            }
+    return fused_experts(
+        hidden_states,
+        w1,
+        w2,
+        topk_weights,
+        topk_ids,
+        inplace=inplace,
+        override_config=override_config,
+        use_fp8=use_fp8,
+        w1_scale=w1_scale,
+        w2_scale=w2_scale,
+        a1_scale=a1_scale,
+        a2_scale=a2_scale,
+    )
 
-            if M <= E:
-                config = {
-                    "BLOCK_SIZE_M": 128,
-                    "BLOCK_SIZE_N": 256,
-                    "BLOCK_SIZE_K": 128,
-                    "GROUP_SIZE_M": 16,
-                    "num_warps": 8,
-                    "num_stages": 4
-                }
-
-    intermediate_cache1 = torch.empty((M, topk_ids.shape[1], N),
-                                      device=hidden_states.device,
-                                      dtype=hidden_states.dtype)
-    intermediate_cache2 = torch.empty((M * topk_ids.shape[1], N // 2),
-                                      device=hidden_states.device,
-                                      dtype=hidden_states.dtype)
-    intermediate_cache3 = torch.empty((M, topk_ids.shape[1], w2.shape[1]),
-                                      device=hidden_states.device,
-                                      dtype=hidden_states.dtype)
 
-    sorted_token_ids, expert_ids, num_tokens_post_padded = moe_align_block_size(
-        topk_ids, config['BLOCK_SIZE_M'], E)
-    compute_type = (tl.bfloat16
-                    if hidden_states.dtype == torch.bfloat16 else tl.float16)
-
-    invoke_fused_moe_kernel(hidden_states,
-                            w1,
-                            intermediate_cache1,
-                            a1_scale,
-                            w1_scale,
-                            topk_weights,
-                            topk_ids,
-                            sorted_token_ids,
-                            expert_ids,
-                            num_tokens_post_padded,
-                            False,
-                            topk_ids.shape[1],
-                            config,
-                            compute_type=compute_type,
-                            use_fp8=use_fp8)
+def fused_topk_v0_4_3(
+    hidden_states: torch.Tensor,
+    gating_output: torch.Tensor,
+    topk: int,
+    renormalize: bool,
+):
+    import vllm._moe_C as moe_kernels
 
-    ops.gelu_and_mul(intermediate_cache2, intermediate_cache1.view(-1, N))
+    M, _ = hidden_states.shape
 
-    invoke_fused_moe_kernel(intermediate_cache2,
-                            w2,
-                            intermediate_cache3,
-                            a2_scale,
-                            w2_scale,
-                            topk_weights,
-                            topk_ids,
-                            sorted_token_ids,
-                            expert_ids,
-                            num_tokens_post_padded,
-                            True,
-                            1,
-                            config,
-                            compute_type=compute_type,
-                            use_fp8=use_fp8)
+    topk_weights = torch.empty(
+        M, topk, dtype=torch.float32, device=hidden_states.device
+    )
+    topk_ids = torch.empty(M, topk, dtype=torch.int32, device=hidden_states.device)
+    token_expert_indicies = torch.empty(
+        M, topk, dtype=torch.int32, device=hidden_states.device
+    )
+    moe_kernels.topk_softmax(
+        topk_weights,
+        topk_ids,
+        token_expert_indicies,
+        gating_output.float(),  # TODO(woosuk): Optimize this.
+    )
+    del token_expert_indicies  # Not used. Will be used in the future.
+    if renormalize:
+        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
 
-    if inplace:
-        return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
-                         dim=1,
-                         out=hidden_states)
-    return torch.sum(intermediate_cache3.view(*intermediate_cache3.shape),
-                     dim=1)
+    return topk_weights, topk_ids
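
The hunks above split the former monolithic fused_moe body into three pieces: fused_topk (softmax plus top-k routing through vllm's _custom_ops.topk_softmax), fused_experts (the two grouped GEMMs, gelu_and_mul, and the weighted reduction over the selected experts), and fused_topk_v0_4_3 as a fallback for older vllm builds where topk_softmax still lives in vllm._moe_C. The public entry point fused_moe keeps its interface and dispatches to these helpers. The following is a minimal sketch of calling the refactored entry point; the function names, keyword arguments, and module path come from this diff, while the sizes, dtype, and CUDA device are illustrative assumptions (w1 laid out as [num_experts, 2 * intermediate, hidden] and w2 as [num_experts, hidden, intermediate], matching the shape asserts in fused_experts).

    # Illustrative sketch only, not part of the diff: exercise fused_moe on random data.
    # Assumes a CUDA device and a vllm build exposing _custom_ops (topk_softmax,
    # moe_align_block_size, gelu_and_mul); sizes below are example values.
    import torch

    from sglang.srt.layers.fused_moe import fused_moe

    M, hidden, intermediate, E, top_k = 32, 4096, 14336, 8, 2

    hidden_states = torch.randn(M, hidden, dtype=torch.float16, device="cuda")
    gating_output = torch.randn(M, E, dtype=torch.float16, device="cuda")
    # w1 produces the gated FFN activations (2 * intermediate), w2 projects back to hidden.
    w1 = torch.randn(E, 2 * intermediate, hidden, dtype=torch.float16, device="cuda")
    w2 = torch.randn(E, hidden, intermediate, dtype=torch.float16, device="cuda")

    out = fused_moe(
        hidden_states,
        w1,
        w2,
        gating_output,
        topk=top_k,
        renormalize=True,
        inplace=False,
    )
    print(out.shape)  # torch.Size([32, 4096])

With inplace=True the final reduction is written into hidden_states via the out= argument of torch.sum, as shown in fused_experts above.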