sglang 0.5.4__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. sglang/bench_serving.py +56 -12
  2. sglang/launch_server.py +2 -0
  3. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +101 -4
  4. sglang/srt/compilation/backend.py +1 -1
  5. sglang/srt/configs/model_config.py +5 -5
  6. sglang/srt/distributed/parallel_state.py +0 -7
  7. sglang/srt/entrypoints/engine.py +18 -15
  8. sglang/srt/entrypoints/grpc_server.py +0 -1
  9. sglang/srt/entrypoints/http_server.py +75 -94
  10. sglang/srt/environ.py +16 -2
  11. sglang/srt/eplb/expert_distribution.py +30 -0
  12. sglang/srt/function_call/function_call_parser.py +2 -0
  13. sglang/srt/function_call/minimax_m2.py +367 -0
  14. sglang/srt/layers/activation.py +6 -0
  15. sglang/srt/layers/attention/flashattention_backend.py +12 -2
  16. sglang/srt/layers/attention/flashinfer_backend.py +10 -1
  17. sglang/srt/layers/attention/flashinfer_mla_backend.py +18 -10
  18. sglang/srt/layers/attention/trtllm_mla_backend.py +1 -13
  19. sglang/srt/layers/attention/utils.py +78 -0
  20. sglang/srt/layers/communicator.py +1 -0
  21. sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
  22. sglang/srt/layers/layernorm.py +19 -4
  23. sglang/srt/layers/logits_processor.py +5 -0
  24. sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
  25. sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
  26. sglang/srt/layers/moe/ep_moe/layer.py +79 -272
  27. sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
  28. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
  29. sglang/srt/layers/moe/moe_runner/deep_gemm.py +287 -22
  30. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  31. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  32. sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
  33. sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
  34. sglang/srt/layers/moe/token_dispatcher/deepep.py +18 -14
  35. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  36. sglang/srt/layers/moe/topk.py +4 -4
  37. sglang/srt/layers/moe/utils.py +3 -4
  38. sglang/srt/layers/quantization/__init__.py +3 -5
  39. sglang/srt/layers/quantization/awq.py +0 -3
  40. sglang/srt/layers/quantization/base_config.py +7 -0
  41. sglang/srt/layers/quantization/fp8.py +68 -63
  42. sglang/srt/layers/quantization/gguf.py +566 -0
  43. sglang/srt/layers/quantization/mxfp4.py +30 -38
  44. sglang/srt/layers/quantization/unquant.py +23 -45
  45. sglang/srt/layers/quantization/w4afp8.py +38 -2
  46. sglang/srt/layers/radix_attention.py +5 -2
  47. sglang/srt/layers/rotary_embedding.py +13 -1
  48. sglang/srt/layers/sampler.py +12 -1
  49. sglang/srt/managers/io_struct.py +3 -0
  50. sglang/srt/managers/multi_tokenizer_mixin.py +17 -1
  51. sglang/srt/managers/scheduler.py +21 -15
  52. sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
  53. sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
  54. sglang/srt/managers/tokenizer_manager.py +11 -19
  55. sglang/srt/mem_cache/hicache_storage.py +7 -1
  56. sglang/srt/mem_cache/memory_pool.py +82 -0
  57. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  58. sglang/srt/model_executor/forward_batch_info.py +44 -3
  59. sglang/srt/model_executor/model_runner.py +1 -149
  60. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
  61. sglang/srt/models/deepseek_v2.py +147 -44
  62. sglang/srt/models/glm4_moe.py +322 -354
  63. sglang/srt/models/glm4_moe_nextn.py +4 -14
  64. sglang/srt/models/glm4v_moe.py +29 -196
  65. sglang/srt/models/minimax_m2.py +922 -0
  66. sglang/srt/models/nvila.py +355 -0
  67. sglang/srt/models/nvila_lite.py +184 -0
  68. sglang/srt/models/qwen2.py +22 -1
  69. sglang/srt/models/qwen3.py +34 -4
  70. sglang/srt/models/qwen3_moe.py +2 -4
  71. sglang/srt/multimodal/processors/base_processor.py +1 -0
  72. sglang/srt/multimodal/processors/glm4v.py +1 -1
  73. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  74. sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
  75. sglang/srt/parser/reasoning_parser.py +28 -1
  76. sglang/srt/server_args.py +365 -186
  77. sglang/srt/single_batch_overlap.py +2 -7
  78. sglang/srt/utils/common.py +87 -42
  79. sglang/srt/utils/hf_transformers_utils.py +7 -3
  80. sglang/test/test_deterministic.py +235 -12
  81. sglang/test/test_deterministic_utils.py +2 -1
  82. sglang/version.py +1 -1
  83. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +7 -6
  84. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +87 -82
  85. sglang/srt/models/vila.py +0 -306
  86. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  87. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  88. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/layernorm.py
@@ -73,9 +73,16 @@ class RMSNorm(CustomOp):
         hidden_size: int,
         eps: float = 1e-6,
         var_hidden_size: Optional[int] = None,
+        cast_x_before_out_mul: bool = False,
+        fp32_residual: bool = False,
+        weight_dtype: Optional = None,
+        override_orig_dtype: Optional = None,
     ) -> None:
         super().__init__()
-        self.weight = nn.Parameter(torch.ones(hidden_size))
+        self.cast_x_before_out_mul = cast_x_before_out_mul
+        self.fp32_residual = fp32_residual
+        self.override_orig_dtype = override_orig_dtype
+        self.weight = nn.Parameter(torch.ones(hidden_size, dtype=weight_dtype))
         self.variance_epsilon = eps
         self.hidden_size = hidden_size
         self.variance_size_override = (
@@ -165,11 +172,14 @@ class RMSNorm(CustomOp):
     ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]:
         if not x.is_contiguous():
             x = x.contiguous()
-        orig_dtype = x.dtype
+        orig_dtype = self.override_orig_dtype or x.dtype
         x = x.to(torch.float32)
         if residual is not None:
             x = x + residual.to(torch.float32)
-            residual = x.to(orig_dtype)
+            if self.fp32_residual:
+                residual = x.clone()
+            else:
+                residual = x.to(orig_dtype)
 
         hidden_size = x.shape[-1]
         if hidden_size != self.hidden_size:
@@ -191,7 +201,12 @@ class RMSNorm(CustomOp):
 
         variance = x_var.pow(2).mean(dim=-1, keepdim=True)
         x = x * torch.rsqrt(variance + self.variance_epsilon)
-        x = (x * self.weight).to(orig_dtype)
+
+        if self.cast_x_before_out_mul:
+            x = self.weight * x.to(orig_dtype)
+        else:
+            x = (x * self.weight).to(orig_dtype)
+
         if residual is None:
             return x
         else:
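Note: for a quick read of what the new RMSNorm flags do, the following is a minimal, unfused sketch of the changed native forward path (illustrative only; it mirrors the hunks above but omits the variance_size_override handling and the fused CUDA/Triton paths):

import torch
from typing import Optional

def rmsnorm_reference(
    x: torch.Tensor,
    weight: torch.Tensor,
    eps: float = 1e-6,
    residual: Optional[torch.Tensor] = None,
    cast_x_before_out_mul: bool = False,
    fp32_residual: bool = False,
    override_orig_dtype: Optional[torch.dtype] = None,
):
    # override_orig_dtype pins the output dtype instead of following the input dtype.
    orig_dtype = override_orig_dtype or x.dtype
    x = x.to(torch.float32)
    if residual is not None:
        x = x + residual.to(torch.float32)
        # fp32_residual keeps the residual stream in float32 rather than
        # casting it back to the original dtype.
        residual = x.clone() if fp32_residual else x.to(orig_dtype)
    variance = x.pow(2).mean(dim=-1, keepdim=True)
    x = x * torch.rsqrt(variance + eps)
    if cast_x_before_out_mul:
        # Cast down first, then multiply by the weight.
        x = weight * x.to(orig_dtype)
    else:
        x = (x * weight).to(orig_dtype)
    return x if residual is None else (x, residual)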
sglang/srt/layers/logits_processor.py
@@ -593,6 +593,11 @@ class LogitsProcessor(nn.Module):
                 None,  # bias
                 True,  # is_vnni
             )
+        elif get_global_server_args().rl_on_policy_target == "fsdp":
+            # Due to tie-weight, we may not be able to change lm_head's weight dtype
+            logits = torch.matmul(
+                hidden_states.bfloat16(), lm_head.weight.T.bfloat16()
+            )
         else:
             logits = torch.matmul(
                 hidden_states.to(lm_head.weight.dtype), lm_head.weight.T
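Note: a tiny illustration of the dtype handling in the new branch (hypothetical shapes, not the surrounding LogitsProcessor code). When rl_on_policy_target == "fsdp", both operands are cast to bfloat16 instead of casting hidden_states to the tied lm_head weight dtype:

import torch

hidden_states = torch.randn(4, 1024)       # hypothetical [num_tokens, hidden_size]
lm_head_weight = torch.randn(32000, 1024)  # tied with the embedding, so its dtype is left untouched
logits = torch.matmul(hidden_states.bfloat16(), lm_head_weight.T.bfloat16())  # bfloat16 [4, 32000]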
sglang/srt/layers/moe/cutlass_w4a8_moe.py
@@ -11,12 +11,14 @@ from sgl_kernel import (
 )
 
 from sglang.srt.layers.moe.ep_moe.kernels import (
+    deepep_ll_get_cutlass_w4a8_moe_mm_data,
     deepep_permute_triton_kernel,
     deepep_post_reorder_triton_kernel,
     deepep_run_moe_deep_preprocess,
     post_reorder_triton_kernel_for_cutlass_moe,
     pre_reorder_triton_kernel_for_cutlass_moe,
     run_moe_ep_preproess,
+    silu_and_mul_masked_post_per_tensor_quant_fwd,
 )
 
 
@@ -396,3 +398,139 @@ def cutlass_w4a8_moe_deepep_normal(
     )
 
     return output
+
+
+def cutlass_w4a8_moe_deepep_ll(
+    a: torch.Tensor,
+    w1_q: torch.Tensor,
+    w2_q: torch.Tensor,
+    w1_scale: torch.Tensor,
+    w2_scale: torch.Tensor,
+    topk_ids_: torch.Tensor,
+    masked_m: torch.Tensor,
+    a_strides1: torch.Tensor,
+    b_strides1: torch.Tensor,
+    c_strides1: torch.Tensor,
+    a_strides2: torch.Tensor,
+    b_strides2: torch.Tensor,
+    c_strides2: torch.Tensor,
+    s_strides13: torch.Tensor,
+    s_strides2: torch.Tensor,
+    expert_offsets: torch.Tensor,
+    problem_sizes1: torch.Tensor,
+    problem_sizes2: torch.Tensor,
+    a1_scale: Optional[torch.Tensor] = None,
+    a2_scale: Optional[torch.Tensor] = None,
+) -> torch.Tensor:
+    """
+    This function computes a w4a8-quantized Mixture of Experts (MoE) layer
+    using two sets of quantized weights, w1_q and w2_q, and top-k gating
+    mechanism. The matrix multiplications are implemented with CUTLASS
+    grouped gemm.
+
+    Parameters:
+    - a (torch.Tensor): The input tensor to the MoE layer.
+        Shape: [num_local_experts, num_max_dispatch_tokens_per_rank * num_ranks, K]
+    - w1_q (torch.Tensor): The first set of int4-quantized expert weights.
+        Shape: [num_experts, N * 2, K // 2]
+        (the weights are passed transposed and int4-packed)
+    - w2_q (torch.Tensor): The second set of int4-quantized expert weights.
+        Shape: [num_experts, K, N // 2]
+        (the weights are passed transposed and int4-packed)
+    - w1_scale (torch.Tensor): The fp32 scale to dequantize w1_q.
+        Shape: [num_experts, K // 512, N * 8]
+    - w2_scale (torch.Tensor): The fp32 scale to dequantize w2_q.
+        Shape: [num_experts, N // 512, K * 4]
+    - topk_weights (torch.Tensor): The weights of each token->expert mapping.
+    - a_strides1 (torch.Tensor): The input strides of the first grouped gemm.
+    - b_strides1 (torch.Tensor): The weights strides of the first grouped gemm.
+    - c_strides1 (torch.Tensor): The output strides of the first grouped gemm.
+    - a_strides2 (torch.Tensor): The input strides of the second grouped gemm.
+    - b_strides2 (torch.Tensor): The weights strides of the second grouped gemm.
+    - c_strides2 (torch.Tensor): The output strides of the second grouped gemm.
+    - s_strides13 (torch.Tensor): The input and scale strides of the first grouped gemm.
+    - s_strides2 (torch.Tensor): The scale strides of the second grouped gemm.
+    - a1_scale (Optional[torch.Tensor]): The optional fp32 scale to quantize a.
+        Shape: scalar or [1, K]
+    - a2_scale (Optional[torch.Tensor]): The optional fp32 scale to
+        quantize the intermediate result between the gemms.
+        Shape: scalar or [1, N]
+    - apply_router_weight_on_input (bool): When true, the topk weights are
+        applied directly on the inputs. This is only applicable when topk is 1.
+
+    Returns:
+    - torch.Tensor: The fp8 output tensor after applying the MoE layer.
+    """
+    assert w1_q.dtype == torch.int8
+    assert w2_q.dtype == torch.int8
+    assert a.shape[2] // 2 == w1_q.shape[2], "Hidden size mismatch w1"
+    assert w1_q.shape[2] * 2 == w2_q.shape[1], "Hidden size mismatch w2"
+    assert w1_q.shape[0] == w2_q.shape[0], "Expert number mismatch"
+    assert w1_q.shape[0] == w1_scale.shape[0], "w1 scales expert number mismatch"
+    assert w1_q.shape[0] == w2_scale.shape[0], "w2 scales expert number mismatch"
+
+    assert a_strides1.shape[0] == w1_q.shape[0], "A Strides 1 expert number mismatch"
+    assert b_strides1.shape[0] == w1_q.shape[0], "B Strides 1 expert number mismatch"
+    assert a_strides2.shape[0] == w2_q.shape[0], "A Strides 2 expert number mismatch"
+    assert b_strides2.shape[0] == w2_q.shape[0], "B Strides 2 expert number mismatch"
+    num_experts = w1_q.size(0)
+    m = a.size(1)
+    k = w1_q.size(2) * 2  # w1_q is transposed and packed
+    n = w2_q.size(2) * 2  # w2_q is transposed and packed
+    topk = topk_ids_.size(1)
+
+    device = a.device
+
+    problem_sizes1, problem_sizes2 = deepep_ll_get_cutlass_w4a8_moe_mm_data(
+        masked_m,
+        problem_sizes1,
+        problem_sizes2,
+        num_experts,
+        n,
+        k,
+    )
+
+    gateup_input = torch.empty(a.shape, dtype=torch.float8_e4m3fn, device=device)
+    sgl_per_tensor_quant_fp8(a, gateup_input, a1_scale.float(), True)
+    c1 = torch.empty((num_experts, m, n * 2), device=device, dtype=torch.bfloat16)
+    c2 = torch.empty((num_experts, m, k), device=device, dtype=torch.bfloat16)
+
+    cutlass_w4a8_moe_mm(
+        c1,
+        gateup_input,
+        w1_q,
+        a1_scale.float(),
+        w1_scale,
+        expert_offsets[:-1],
+        problem_sizes1,
+        a_strides1,
+        b_strides1,
+        c_strides1,
+        s_strides13,
+        128,
+        topk,
+    )
+
+    intermediate_q = torch.empty(
+        (num_experts, m, n), device=a.device, dtype=torch.float8_e4m3fn
+    )
+    silu_and_mul_masked_post_per_tensor_quant_fwd(
+        c1, intermediate_q, masked_m, a2_scale
+    )
+    cutlass_w4a8_moe_mm(
+        c2,
+        intermediate_q,
+        w2_q,
+        a2_scale.float(),
+        w2_scale,
+        expert_offsets[:-1],
+        problem_sizes2,
+        a_strides2,
+        b_strides2,
+        c_strides2,
+        s_strides2,
+        128,
+        topk,
+    )
+
+    return c2
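Note: for intuition about what the two CUTLASS grouped GEMMs compute per expert, here is a rough dequantized reference (illustrative only; it ignores the int4/fp8 quantization, scales, strides, and CUTLASS specifics):

import torch
import torch.nn.functional as F

def moe_deepep_ll_reference(a, w1, w2, masked_m):
    # a: [E, T_padded, K] dispatched activations
    # w1: [E, 2*N, K], w2: [E, K, N] dequantized expert weights
    # masked_m: [E] number of valid tokens per expert (the rest of T_padded is padding)
    E, T, K = a.shape
    out = torch.zeros_like(a)
    for e in range(E):
        t = int(masked_m[e])
        x = a[e, :t]                                # [t, K]
        gate, up = (x @ w1[e].T).chunk(2, dim=-1)   # first grouped GEMM -> two [t, N] halves
        h = F.silu(gate) * up                       # quantized to fp8 before the second GEMM in the real path
        out[e, :t] = h @ w2[e].T                    # second grouped GEMM -> [t, K]
    return out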
sglang/srt/layers/moe/ep_moe/kernels.py
@@ -1014,3 +1014,197 @@ def zero_experts_compute_triton(
     )
 
     return output
+
+
+@triton.jit
+def compute_problem_sizes_w4a8_kernel(
+    masked_m_ptr,
+    problem_sizes1_ptr,
+    problem_sizes2_ptr,
+    n,
+    k,
+    num_experts,
+    BLOCK_SIZE: tl.constexpr,
+):
+    pid = tl.program_id(axis=0) * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
+    mask = pid < num_experts
+    final_occurrences = tl.load(masked_m_ptr + pid, mask=mask, other=0)
+
+    ps1_idx_0 = pid * 3
+    ps1_idx_1 = ps1_idx_0 + 1
+    ps1_idx_2 = ps1_idx_0 + 2
+
+    ps2_idx_0 = pid * 3
+    ps2_idx_1 = ps2_idx_0 + 1
+    ps2_idx_2 = ps2_idx_0 + 2
+
+    ps1_mask_0 = ps1_idx_0 < num_experts * 3
+    ps1_mask_1 = ps1_idx_1 < num_experts * 3
+    ps1_mask_2 = ps1_idx_2 < num_experts * 3
+    ps2_mask_0 = ps2_idx_0 < num_experts * 3
+    ps2_mask_1 = ps2_idx_1 < num_experts * 3
+    ps2_mask_2 = ps2_idx_2 < num_experts * 3
+
+    tl.store(problem_sizes1_ptr + ps1_idx_0, 2 * n, mask=ps1_mask_0)
+    tl.store(problem_sizes1_ptr + ps1_idx_1, final_occurrences, mask=ps1_mask_1)
+    tl.store(problem_sizes1_ptr + ps1_idx_2, k, mask=ps1_mask_2)
+
+    tl.store(problem_sizes2_ptr + ps2_idx_0, k, mask=ps2_mask_0)
+    tl.store(problem_sizes2_ptr + ps2_idx_1, final_occurrences, mask=ps2_mask_1)
+    tl.store(problem_sizes2_ptr + ps2_idx_2, n, mask=ps2_mask_2)
+
+
+def compute_problem_sizes_w4a8(
+    masked_m, problem_sizes1, problem_sizes2, n, k, num_experts
+):
+    BLOCK_SIZE = 256
+    grid = lambda meta: (triton.cdiv(num_experts, meta["BLOCK_SIZE"]),)
+    compute_problem_sizes_w4a8_kernel[grid](
+        masked_m,
+        problem_sizes1,
+        problem_sizes2,
+        n,
+        k,
+        num_experts,
+        BLOCK_SIZE=BLOCK_SIZE,
+    )
+    return problem_sizes1, problem_sizes2
+
+
+def deepep_ll_get_cutlass_w4a8_moe_mm_data(
+    masked_m,
+    problem_sizes1,
+    problem_sizes2,
+    num_experts,
+    n,
+    k,
+):
+    problem_sizes1, problem_sizes2 = compute_problem_sizes_w4a8(
+        masked_m, problem_sizes1, problem_sizes2, n, k, num_experts
+    )
+    return (
+        problem_sizes1.to(torch.int32),
+        problem_sizes2.to(torch.int32),
+    )
+
+
+@triton.jit
+def _silu_and_mul_post_per_tensor_quant_kernel(
+    input_ptr,
+    stride_input_expert,
+    stride_input_token,
+    stride_input_dim,
+    output_ptr,
+    stride_output_expert,
+    stride_output_token,
+    stride_output_dim,
+    scale_ptr,
+    masked_m_ptr,
+    inner_dim,
+    fp8_max,
+    fp8_min,
+    BLOCK_N: tl.constexpr,
+    NUM_STAGE: tl.constexpr,
+):
+    """
+    Triton kernel: fused SiLU(gate) * up + per-tensor FP8 quantization.
+
+    Shape:
+        input: [E, T_padded, 2*D] -> gate: [:,:,D], up: [:,:,D]
+        output: [E, T_padded, D], dtype=float8_e4m3fn
+    """
+    expert_id = tl.program_id(2)
+    block_id_token = tl.program_id(1)
+    block_id_dim = tl.program_id(0)
+
+    num_token_blocks = tl.num_programs(1)
+
+    token_num_cur_expert = tl.load(masked_m_ptr + expert_id)
+
+    scale = 1.0 / tl.load(scale_ptr).to(tl.float32)
+
+    stride_input_expert = tl.cast(stride_input_expert, tl.int32)
+    stride_output_expert = tl.cast(stride_output_expert, tl.int32)
+    stride_input_token = tl.cast(stride_input_token, tl.int32)
+    stride_output_token = tl.cast(stride_output_token, tl.int32)
+
+    offset_d = block_id_dim * BLOCK_N + tl.arange(0, BLOCK_N)
+    mask_d = offset_d < inner_dim
+
+    # base pointers for current expert and dim block
+    input_base_offs = input_ptr + expert_id * stride_input_expert + offset_d
+    output_base_offs = output_ptr + expert_id * stride_output_expert + offset_d
+
+    for token_idx in tl.range(
+        block_id_token, token_num_cur_expert, num_token_blocks, num_stages=NUM_STAGE
+    ):
+        gate_ptr = input_base_offs + token_idx * stride_input_token
+        up_ptr = gate_ptr + inner_dim
+        gate = tl.load(gate_ptr, mask=mask_d, other=0.0).to(tl.float32)
+        up = tl.load(up_ptr, mask=mask_d, other=0.0).to(tl.float32)
+
+        # SiLU: x * sigmoid(x)
+        gate = gate / (1 + tl.exp(-gate))
+        gate = gate.to(input_ptr.dtype.element_ty)
+        gate_up = up * gate
+
+        scaled = gate_up * scale
+        output_q = tl.clamp(scaled, fp8_min, fp8_max).to(output_ptr.dtype.element_ty)
+        out_ptr = output_base_offs + token_idx * stride_output_token
+        tl.store(out_ptr, output_q, mask=mask_d)
+
+
+def silu_and_mul_masked_post_per_tensor_quant_fwd(
+    input: torch.Tensor,
+    output: torch.Tensor,
+    masked_m: torch.Tensor,
+    scale: torch.Tensor,
+) -> torch.Tensor:
+    """
+    Fused SiLU + Mul + Per-Tensor Quantization to FP8.
+
+    Args:
+        input: [expert_num, token_num_padded, 2 * inner_dim]
+        output: [expert_num, token_num_padded, inner_dim], dtype=torch.float8_e4m3fn
+        masked_m: [expert_num], actual token count for each expert
+        scale: [1] or [expert_num], quantization scale (per-tensor or per-expert)
+
+    Returns:
+        output tensor
+    """
+    assert input.is_contiguous()
+    assert output.is_contiguous()
+    assert output.dtype == torch.float8_e4m3fn
+    assert input.ndim == 3
+    assert input.shape[0] == masked_m.shape[0]
+    assert input.shape[-1] % 2 == 0
+    assert scale.numel() == 1 or scale.shape[0] == input.shape[0]
+
+    expert_num = input.shape[0]
+    # 3584
+    inner_dim = input.shape[-1] // 2
+
+    BLOCK_N = 256
+    BLOCK_M = 64 if expert_num < 4 else 32
+    NUM_STAGES = 3
+    hidden_dim_split_block_num = triton.cdiv(inner_dim, BLOCK_N)
+
+    grid = (hidden_dim_split_block_num, BLOCK_M, expert_num)
+    finfo = torch.finfo(torch.float8_e4m3fn)
+    fp8_max = finfo.max
+    fp8_min = -fp8_max
+
+    _silu_and_mul_post_per_tensor_quant_kernel[grid](
+        input,
+        *input.stride(),
+        output,
+        *output.stride(),
+        scale,
+        masked_m,
+        inner_dim,
+        fp8_max,
+        fp8_min,
+        BLOCK_N=BLOCK_N,
+        NUM_STAGE=NUM_STAGES,
+    )
+    return output
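Note: an unfused PyTorch equivalent of silu_and_mul_masked_post_per_tensor_quant_fwd, for reference only (the Triton kernel above is the actual implementation; this sketch assumes a single per-tensor scale):

import torch

def silu_and_mul_quant_reference(x, masked_m, scale):
    # x: [E, T_padded, 2*D], masked_m: [E], scale: one-element fp32 tensor
    E, T, two_d = x.shape
    D = two_d // 2
    out = torch.zeros((E, T, D), dtype=torch.float8_e4m3fn, device=x.device)
    fp8_max = torch.finfo(torch.float8_e4m3fn).max
    for e in range(E):
        t = int(masked_m[e])                                    # only the valid tokens of this expert
        gate = x[e, :t, :D].float()
        up = x[e, :t, D:].float()
        h = up * (gate * torch.sigmoid(gate))                   # SiLU(gate) * up
        q = torch.clamp(h / scale.float(), -fp8_max, fp8_max)   # per-tensor scale, clamped to the fp8 range
        out[e, :t] = q.to(torch.float8_e4m3fn)
    return out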