sglang 0.5.4__py3-none-any.whl → 0.5.4.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (88)
  1. sglang/bench_serving.py +56 -12
  2. sglang/launch_server.py +2 -0
  3. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +101 -4
  4. sglang/srt/compilation/backend.py +1 -1
  5. sglang/srt/configs/model_config.py +5 -5
  6. sglang/srt/distributed/parallel_state.py +0 -7
  7. sglang/srt/entrypoints/engine.py +18 -15
  8. sglang/srt/entrypoints/grpc_server.py +0 -1
  9. sglang/srt/entrypoints/http_server.py +75 -94
  10. sglang/srt/environ.py +16 -2
  11. sglang/srt/eplb/expert_distribution.py +30 -0
  12. sglang/srt/function_call/function_call_parser.py +2 -0
  13. sglang/srt/function_call/minimax_m2.py +367 -0
  14. sglang/srt/layers/activation.py +6 -0
  15. sglang/srt/layers/attention/flashattention_backend.py +12 -2
  16. sglang/srt/layers/attention/flashinfer_backend.py +10 -1
  17. sglang/srt/layers/attention/flashinfer_mla_backend.py +18 -10
  18. sglang/srt/layers/attention/trtllm_mla_backend.py +1 -13
  19. sglang/srt/layers/attention/utils.py +78 -0
  20. sglang/srt/layers/communicator.py +1 -0
  21. sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
  22. sglang/srt/layers/layernorm.py +19 -4
  23. sglang/srt/layers/logits_processor.py +5 -0
  24. sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
  25. sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
  26. sglang/srt/layers/moe/ep_moe/layer.py +79 -272
  27. sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
  28. sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
  29. sglang/srt/layers/moe/moe_runner/deep_gemm.py +287 -22
  30. sglang/srt/layers/moe/moe_runner/runner.py +3 -0
  31. sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
  32. sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
  33. sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
  34. sglang/srt/layers/moe/token_dispatcher/deepep.py +18 -14
  35. sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
  36. sglang/srt/layers/moe/topk.py +4 -4
  37. sglang/srt/layers/moe/utils.py +3 -4
  38. sglang/srt/layers/quantization/__init__.py +3 -5
  39. sglang/srt/layers/quantization/awq.py +0 -3
  40. sglang/srt/layers/quantization/base_config.py +7 -0
  41. sglang/srt/layers/quantization/fp8.py +68 -63
  42. sglang/srt/layers/quantization/gguf.py +566 -0
  43. sglang/srt/layers/quantization/mxfp4.py +30 -38
  44. sglang/srt/layers/quantization/unquant.py +23 -45
  45. sglang/srt/layers/quantization/w4afp8.py +38 -2
  46. sglang/srt/layers/radix_attention.py +5 -2
  47. sglang/srt/layers/rotary_embedding.py +13 -1
  48. sglang/srt/layers/sampler.py +12 -1
  49. sglang/srt/managers/io_struct.py +3 -0
  50. sglang/srt/managers/multi_tokenizer_mixin.py +17 -1
  51. sglang/srt/managers/scheduler.py +21 -15
  52. sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
  53. sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
  54. sglang/srt/managers/tokenizer_manager.py +11 -19
  55. sglang/srt/mem_cache/hicache_storage.py +7 -1
  56. sglang/srt/mem_cache/memory_pool.py +82 -0
  57. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
  58. sglang/srt/model_executor/forward_batch_info.py +44 -3
  59. sglang/srt/model_executor/model_runner.py +1 -149
  60. sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
  61. sglang/srt/models/deepseek_v2.py +147 -44
  62. sglang/srt/models/glm4_moe.py +322 -354
  63. sglang/srt/models/glm4_moe_nextn.py +4 -14
  64. sglang/srt/models/glm4v_moe.py +29 -196
  65. sglang/srt/models/minimax_m2.py +922 -0
  66. sglang/srt/models/nvila.py +355 -0
  67. sglang/srt/models/nvila_lite.py +184 -0
  68. sglang/srt/models/qwen2.py +22 -1
  69. sglang/srt/models/qwen3.py +34 -4
  70. sglang/srt/models/qwen3_moe.py +2 -4
  71. sglang/srt/multimodal/processors/base_processor.py +1 -0
  72. sglang/srt/multimodal/processors/glm4v.py +1 -1
  73. sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
  74. sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
  75. sglang/srt/parser/reasoning_parser.py +28 -1
  76. sglang/srt/server_args.py +365 -186
  77. sglang/srt/single_batch_overlap.py +2 -7
  78. sglang/srt/utils/common.py +87 -42
  79. sglang/srt/utils/hf_transformers_utils.py +7 -3
  80. sglang/test/test_deterministic.py +235 -12
  81. sglang/test/test_deterministic_utils.py +2 -1
  82. sglang/version.py +1 -1
  83. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/METADATA +7 -6
  84. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/RECORD +87 -82
  85. sglang/srt/models/vila.py +0 -306
  86. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/WHEEL +0 -0
  87. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/licenses/LICENSE +0 -0
  88. {sglang-0.5.4.dist-info → sglang-0.5.4.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/moe/moe_runner/deep_gemm.py
@@ -5,6 +5,7 @@ from typing import TYPE_CHECKING, List, Optional

  import torch

+ from sglang.srt.layers import deep_gemm_wrapper
  from sglang.srt.layers.moe.moe_runner.base import (
      MoeQuantInfo,
      MoeRunnerConfig,
@@ -15,14 +16,28 @@ from sglang.srt.layers.moe.moe_runner.base import (
      register_pre_permute,
  )
  from sglang.srt.layers.moe.utils import MoeRunnerBackend
- from sglang.srt.utils import dispose_tensor
+ from sglang.srt.utils import ceil_div, dispose_tensor, get_bool_env_var, is_hip, is_npu
+ from sglang.srt.utils.offloader import get_offloader

  if TYPE_CHECKING:
+     from sglang.srt.layers.moe.token_dispatcher.deepep import (
+         DeepEPLLCombineInput,
+         DeepEPLLDispatchOutput,
+         DeepEPNormalCombineInput,
+         DeepEPNormalDispatchOutput,
+     )
      from sglang.srt.layers.moe.token_dispatcher.standard import (
          StandardCombineInput,
          StandardDispatchOutput,
      )

+ _is_hip = is_hip()
+ _is_npu = is_npu()
+ _use_aiter = get_bool_env_var("SGLANG_USE_AITER") and _is_hip
+
+ if not (_is_npu or _is_hip):
+     from sgl_kernel import silu_and_mul
+

  # TODO(kaixih@nvidia): ideally we should merge this logic into
  # `fill_gateup_input_triton_kernel` to directly generate e8m0 scale.
@@ -40,13 +55,23 @@ def _cast_to_e8m0_with_rounding_up(x: torch.Tensor) -> torch.Tensor:
      return new_x.transpose(1, 2).contiguous().transpose(1, 2)


+ def copy_list_to_gpu_no_ce(arr: List[int]):
+     from sgl_kernel.elementwise import copy_to_gpu_no_ce
+
+     tensor_cpu = torch.tensor(arr, dtype=torch.int32, device="cpu")
+     tensor_gpu = torch.empty_like(tensor_cpu, device="cuda")
+     copy_to_gpu_no_ce(tensor_cpu, tensor_gpu)
+     return tensor_gpu
+
+
  @dataclass
  class DeepGemmRunnerInput(RunnerInput):
      hidden_states: torch.Tensor
      hidden_states_scale: torch.Tensor
-     masked_m: torch.Tensor
-     expected_m: int
      use_masked_gemm: bool
+     masked_m: Optional[torch.Tensor] = None
+     expected_m: Optional[int] = None
+     m_indices: Optional[torch.Tensor] = None

      @property
      def runner_backend(self) -> MoeRunnerBackend:
@@ -84,20 +109,100 @@ class DeepGemmRunnerCore(MoeRunnerCore):
          running_state: dict,
      ) -> DeepGemmRunnerOutput:

-         if runner_input.use_masked_gemm:
-             hidden_states = self._run_masked_gemm(
-                 runner_input,
-                 quant_info,
-                 running_state,
+         if not runner_input.use_masked_gemm:
+             hidden_states = self._run_contiguous_gemm(
+                 runner_input, quant_info, running_state
              )
          else:
-             hidden_states = self._run_contiguous_gemm(
-                 runner_input,
-                 quant_info,
-                 running_state,
+             hidden_states = self._run_masked_gemm(
+                 runner_input, quant_info, running_state
              )
          return DeepGemmRunnerOutput(hidden_states=hidden_states)

+     def _run_contiguous_gemm(
+         self,
+         runner_input: DeepGemmRunnerInput,
+         quant_info: DeepGemmMoeQuantInfo,
+         running_state: dict,
+     ) -> torch.Tensor:
+
+         from sglang.srt.layers.moe.ep_moe.kernels import tma_align_input_scale
+         from sglang.srt.layers.quantization.fp8_kernel import (
+             sglang_per_token_group_quant_fp8,
+         )
+
+         hidden_states = runner_input.hidden_states
+         hidden_states_scale = runner_input.hidden_states_scale
+         all_tokens = running_state["all_tokens"]
+         hidden_states_device = running_state["hidden_states_device"]
+         hidden_states_dtype = running_state["hidden_states_dtype"]
+         hidden_states_shape = running_state["hidden_states_shape"]
+         m_indices = runner_input.m_indices
+
+         N = quant_info.w13_weight.size(1)
+         K = hidden_states_shape[1]
+         scale_block_size = 128
+
+         w13_weight_fp8 = (
+             quant_info.w13_weight,
+             quant_info.w13_scale,
+         )
+         w2_weight_fp8 = (quant_info.w2_weight, quant_info.w2_scale)
+
+         gateup_output = torch.empty(
+             (all_tokens, N),
+             device=hidden_states_device,
+             dtype=torch.bfloat16,
+         )
+         if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+             hidden_states_scale = tma_align_input_scale(hidden_states_scale)
+         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
+             (hidden_states, hidden_states_scale),
+             w13_weight_fp8,
+             gateup_output,
+             m_indices,
+         )
+
+         dispose_tensor(hidden_states)
+         dispose_tensor(hidden_states_scale)
+
+         down_input = torch.empty(
+             (
+                 all_tokens,
+                 N // 2,
+             ),
+             device=gateup_output.device,
+             dtype=torch.bfloat16,
+         )
+         silu_and_mul(gateup_output.view(-1, N), down_input)
+         del gateup_output
+
+         down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(
+             down_input,
+             scale_block_size,
+             column_major_scales=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+             scale_tma_aligned=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+             scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+         )
+         del down_input
+
+         down_output = torch.empty(
+             (all_tokens, K),
+             device=hidden_states_device,
+             dtype=torch.bfloat16,
+         )
+         if not deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+             down_input_scale = tma_align_input_scale(down_input_scale)
+
+         deep_gemm_wrapper.grouped_gemm_nt_f8f8bf16_contig(
+             (down_input_fp8, down_input_scale),
+             w2_weight_fp8,
+             down_output,
+             m_indices,
+         )
+
+         return down_output
+
      def _run_masked_gemm(
          self,
          runner_input: DeepGemmRunnerInput,
@@ -149,6 +254,7 @@ class DeepGemmRunnerCore(MoeRunnerCore):
              expected_m,
          )
          dispose_tensor(hidden_states)
+         dispose_tensor(hidden_states_scale)

          # Act
          down_input = torch.empty(
@@ -198,18 +304,9 @@ class DeepGemmRunnerCore(MoeRunnerCore):
              masked_m,
              expected_m,
          )
-         del down_input

          return down_output

-     def _run_contiguous_gemm(
-         self,
-         runner_input: DeepGemmRunnerInput,
-         quant_info: DeepGemmMoeQuantInfo,
-         running_state: dict,
-     ) -> torch.Tensor:
-         pass
-
      @property
      def runner_backend(self) -> MoeRunnerBackend:
          return MoeRunnerBackend.DEEP_GEMM
@@ -222,6 +319,7 @@ def pre_permute_standard_to_deep_gemm(
      runner_config: MoeRunnerConfig,
      running_state: dict,
  ) -> DeepGemmRunnerInput:
+
      from sglang.srt.layers.moe.ep_moe.kernels import moe_ep_deepgemm_preprocess

      hidden_states, topk_output = dispatch_output
@@ -257,9 +355,9 @@ def pre_permute_standard_to_deep_gemm(
      return DeepGemmRunnerInput(
          hidden_states=hidden_states,
          hidden_states_scale=hidden_states_scale,
+         use_masked_gemm=True,
          masked_m=masked_m,
          expected_m=expected_m,
-         use_masked_gemm=True,
      )


@@ -302,3 +400,170 @@ def post_permute_deep_gemm_to_standard(
      return StandardCombineInput(
          hidden_states=output,
      )
+
+
+ @register_pre_permute("deepep_ll", "deep_gemm")
+ def pre_permute_deepep_ll_to_deep_gemm(
+     dispatch_output: DeepEPLLDispatchOutput,
+     quant_info: DeepGemmMoeQuantInfo,
+     runner_config: MoeRunnerConfig,
+     running_state: dict,
+ ) -> DeepGemmRunnerInput:
+
+     hidden_states, hidden_states_scale, topk_ids, topk_weights, masked_m, expected_m = (
+         dispatch_output
+     )
+
+     running_state["topk_ids"] = topk_ids
+     running_state["topk_weights"] = topk_weights
+     running_state["hidden_states_shape"] = hidden_states.shape
+     running_state["hidden_states_dtype"] = hidden_states.dtype
+     running_state["hidden_states_device"] = hidden_states.device
+
+     return DeepGemmRunnerInput(
+         hidden_states=hidden_states,
+         hidden_states_scale=hidden_states_scale,
+         use_masked_gemm=True,
+         masked_m=masked_m,
+         expected_m=expected_m,
+     )
+
+
+ @register_post_permute("deep_gemm", "deepep_ll")
+ def post_permute_deep_gemm_to_deepep_ll(
+     runner_output: DeepGemmRunnerOutput,
+     quant_info: DeepGemmMoeQuantInfo,
+     runner_config: MoeRunnerConfig,
+     running_state: dict,
+ ) -> DeepEPLLCombineInput:
+
+     from sglang.srt.layers.moe.token_dispatcher.deepep import DeepEPLLCombineInput
+
+     return DeepEPLLCombineInput(
+         hidden_states=runner_output.hidden_states,
+         topk_ids=running_state["topk_ids"],
+         topk_weights=running_state["topk_weights"],
+     )
+
+
+ @register_pre_permute("deepep_normal", "deep_gemm")
+ def pre_permute_deepep_normal_to_deep_gemm(
+     dispatch_output: DeepEPNormalDispatchOutput,
+     quant_info: DeepGemmMoeQuantInfo,
+     runner_config: MoeRunnerConfig,
+     running_state: dict,
+ ) -> DeepGemmRunnerInput:
+
+     from sglang.srt.layers.moe.ep_moe.kernels import ep_scatter
+
+     (
+         hidden_states,
+         hidden_states_scale,
+         topk_ids,
+         topk_weights,
+         num_recv_tokens_per_expert,
+     ) = dispatch_output
+     assert runner_config.activation == "silu"
+
+     all_tokens = sum(num_recv_tokens_per_expert)
+     running_state["all_tokens"] = all_tokens
+
+     K = hidden_states.shape[1]
+
+     hidden_states_shape = hidden_states.shape
+     hidden_states_device = hidden_states.device
+     hidden_states_dtype = hidden_states.dtype
+
+     running_state["hidden_states_shape"] = hidden_states_shape
+     running_state["hidden_states_device"] = hidden_states_device
+     running_state["hidden_states_dtype"] = hidden_states_dtype
+     running_state["topk_ids"] = topk_ids
+     running_state["topk_weights"] = topk_weights
+
+     input_tensor = torch.empty(
+         (all_tokens, K),
+         device=hidden_states.device,
+         dtype=hidden_states.dtype,
+     )
+     if deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0:
+         # TODO check whether need `zeros`
+         input_tensor_scale = torch.zeros(
+             (ceil_div(K // 128, 4), all_tokens),
+             device=hidden_states.device,
+             dtype=torch.int,
+         ).transpose(0, 1)
+     else:
+         input_tensor_scale = torch.empty(
+             (all_tokens, K // 128),
+             device=hidden_states.device,
+             dtype=torch.float32,
+         )
+     m_indices = torch.empty(all_tokens, device=hidden_states.device, dtype=torch.int32)
+     output_index = torch.empty_like(topk_ids)
+
+     if get_offloader().forbid_copy_engine_usage:
+         num_recv_tokens_per_expert_gpu = copy_list_to_gpu_no_ce(
+             num_recv_tokens_per_expert
+         )
+     else:
+         num_recv_tokens_per_expert_gpu = torch.tensor(
+             num_recv_tokens_per_expert,
+             dtype=torch.int32,
+             pin_memory=True,
+             device="cpu",
+         ).cuda(non_blocking=True)
+     expert_start_loc = torch.empty_like(num_recv_tokens_per_expert_gpu)
+
+     ep_scatter(
+         hidden_states,
+         hidden_states_scale,
+         topk_ids,
+         num_recv_tokens_per_expert_gpu,
+         expert_start_loc,
+         input_tensor,
+         input_tensor_scale,
+         m_indices,
+         output_index,
+         scale_ue8m0=deep_gemm_wrapper.DEEPGEMM_SCALE_UE8M0,
+     )
+     dispose_tensor(hidden_states)
+     dispose_tensor(hidden_states_scale)
+
+     running_state["output_index"] = output_index
+
+     return DeepGemmRunnerInput(
+         hidden_states=input_tensor,
+         hidden_states_scale=input_tensor_scale,
+         use_masked_gemm=False,
+         m_indices=m_indices,
+     )
+
+
+ @register_post_permute("deep_gemm", "deepep_normal")
+ def post_permute_deep_gemm_to_deepep_normal(
+     runner_output: DeepGemmRunnerOutput,
+     quant_info: DeepGemmMoeQuantInfo,
+     runner_config: MoeRunnerConfig,
+     running_state: dict,
+ ) -> DeepEPNormalCombineInput:
+
+     from sglang.srt.layers.moe.ep_moe.kernels import ep_gather
+     from sglang.srt.layers.moe.token_dispatcher.deepep import DeepEPNormalCombineInput
+
+     hidden_states = runner_output.hidden_states
+     topk_ids = running_state["topk_ids"]
+     topk_weights = running_state["topk_weights"]
+     output_index = running_state["output_index"]
+
+     gather_out = torch.empty(
+         running_state["hidden_states_shape"],
+         device=running_state["hidden_states_device"],
+         dtype=torch.bfloat16,
+     )
+     ep_gather(hidden_states, topk_ids, topk_weights, output_index, gather_out)
+
+     return DeepEPNormalCombineInput(
+         hidden_states=gather_out,
+         topk_ids=running_state["topk_ids"],
+         topk_weights=running_state["topk_weights"],
+     )
sglang/srt/layers/moe/moe_runner/runner.py
@@ -11,6 +11,7 @@ from sglang.srt.layers.moe.moe_runner.base import (
  )
  from sglang.srt.layers.moe.moe_runner.deep_gemm import DeepGemmRunnerCore
  from sglang.srt.layers.moe.moe_runner.triton import TritonRunnerCore
+ from sglang.srt.layers.moe.moe_runner.triton_kernels import TritonKernelsRunnerCore
  from sglang.srt.layers.moe.utils import get_moe_a2a_backend

  if TYPE_CHECKING:
@@ -31,6 +32,8 @@ class MoeRunner:

          if runner_backend.is_triton():
              self.runner_core = TritonRunnerCore(config)
+         elif runner_backend.is_triton_kernels():
+             self.runner_core = TritonKernelsRunnerCore(config)
          elif runner_backend.is_deep_gemm():
              self.runner_core = DeepGemmRunnerCore(config)
          else:
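Note: the new branch relies on a runner_backend.is_triton_kernels() predicate. A hedged sketch of what such a predicate could look like on the MoeRunnerBackend enum is below; the member values and method bodies are assumptions, not the actual definition in sglang/srt/layers/moe/utils.py.

# Hypothetical sketch of the backend enum and predicates that runner.py branches on.
# Member names mirror identifiers used in this diff (e.g. MoeRunnerBackend.TRITON_KERNELS);
# the real enum values and helpers in sglang may differ.
from enum import Enum

class MoeRunnerBackend(Enum):
    TRITON = "triton"
    TRITON_KERNELS = "triton_kernel"
    DEEP_GEMM = "deep_gemm"

    def is_triton(self) -> bool:
        return self is MoeRunnerBackend.TRITON

    def is_triton_kernels(self) -> bool:
        return self is MoeRunnerBackend.TRITON_KERNELS

    def is_deep_gemm(self) -> bool:
        return self is MoeRunnerBackend.DEEP_GEMM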
sglang/srt/layers/moe/moe_runner/triton_kernels.py (new file)
@@ -0,0 +1,194 @@
+ """Triton kernels MoE runner backend skeleton."""
+
+ from __future__ import annotations
+
+ from dataclasses import dataclass
+ from typing import TYPE_CHECKING, Optional
+
+ import torch
+
+ from sglang.srt.layers.moe.moe_runner.base import (
+     MoeQuantInfo,
+     MoeRunnerConfig,
+     MoeRunnerCore,
+     RunnerInput,
+     RunnerOutput,
+     register_post_permute,
+     register_pre_permute,
+ )
+ from sglang.srt.layers.moe.utils import MoeRunnerBackend
+
+ if TYPE_CHECKING:
+     from triton_kernels.matmul_ogs import PrecisionConfig
+     from triton_kernels.routing import GatherIndx, RoutingData, ScatterIndx
+
+     from sglang.srt.layers.moe.token_dispatcher.standard import (
+         StandardCombineInput,
+         StandardDispatchOutput,
+     )
+
+
+ # ---------------------------------------------------------------------------
+ # Runner IO dataclasses
+ # ---------------------------------------------------------------------------
+
+
+ @dataclass
+ class TritonKernelsRunnerInput(RunnerInput):
+     """Input bundle passed to the triton-kernels runner core."""
+
+     hidden_states: torch.Tensor
+     routing_data: "RoutingData"
+     gather_indx: "GatherIndx"
+     scatter_indx: "ScatterIndx"
+
+     @property
+     def runner_backend(self) -> MoeRunnerBackend:
+         return MoeRunnerBackend.TRITON_KERNELS
+
+
+ @dataclass
+ class TritonKernelsRunnerOutput(RunnerOutput):
+     """Output bundle returned from the triton-kernels runner core."""
+
+     hidden_states: torch.Tensor
+
+     @property
+     def runner_backend(self) -> MoeRunnerBackend:
+         return MoeRunnerBackend.TRITON_KERNELS
+
+
+ @dataclass
+ class TritonKernelsQuantInfo(MoeQuantInfo):
+     """Quantization payload consumed by the triton-kernels backend."""
+
+     w13_weight: torch.Tensor
+     w2_weight: torch.Tensor
+     w13_bias: Optional[torch.Tensor] = None
+     w2_bias: Optional[torch.Tensor] = None
+     w13_precision_config: Optional[PrecisionConfig] = None
+     w2_precision_config: Optional[PrecisionConfig] = None
+     global_num_experts: int = -1
+
+
+ # ---------------------------------------------------------------------------
+ # Runner core
+ # ---------------------------------------------------------------------------
+
+
+ class TritonKernelsRunnerCore(MoeRunnerCore):
+     """Execute MoE experts via the external triton_kernels package."""
+
+     def run(
+         self,
+         runner_input: TritonKernelsRunnerInput,
+         quant_info: TritonKernelsQuantInfo,
+         running_state: dict,
+     ) -> TritonKernelsRunnerOutput:
+         from sglang.srt.layers.moe.fused_moe_triton.triton_kernels_moe import (
+             triton_kernel_fused_experts,
+             triton_kernel_fused_experts_with_bias,
+         )
+
+         hidden_states = runner_input.hidden_states
+
+         common_kwargs = dict(
+             routing_data=runner_input.routing_data,
+             gather_indx=runner_input.gather_indx,
+             scatter_indx=None if self.config.no_combine else runner_input.scatter_indx,
+             inplace=False,
+             activation=self.config.activation,
+             apply_router_weight_on_input=self.config.apply_router_weight_on_input,
+             global_num_experts=quant_info.global_num_experts,
+         )
+
+         has_bias = quant_info.w13_bias is not None or quant_info.w2_bias is not None
+
+         if has_bias:
+             assert (
+                 quant_info.w13_bias is not None and quant_info.w2_bias is not None
+             ), "Bias execution requires both w13_bias and w2_bias"
+             output = triton_kernel_fused_experts_with_bias(
+                 hidden_states=hidden_states,
+                 w1=quant_info.w13_weight,
+                 w1_pcg=quant_info.w13_precision_config,
+                 b1=quant_info.w13_bias,
+                 w2=quant_info.w2_weight,
+                 w2_pcg=quant_info.w2_precision_config,
+                 b2=quant_info.w2_bias,
+                 gemm1_alpha=self.config.gemm1_alpha,
+                 gemm1_clamp_limit=self.config.gemm1_clamp_limit,
+                 **common_kwargs,
+             )
+         else:
+             output = triton_kernel_fused_experts(
+                 hidden_states=hidden_states,
+                 w1=quant_info.w13_weight,
+                 w2=quant_info.w2_weight,
+                 **common_kwargs,
+             )
+
+         if self.config.no_combine:
+             tokens = runner_input.hidden_states.shape[0]
+             hidden = runner_input.hidden_states.shape[-1]
+             total_rows = output.shape[0]
+             top_k = total_rows // tokens
+             output = output.view(tokens, top_k, hidden)
+
+         return TritonKernelsRunnerOutput(hidden_states=output)
+
+     @property
+     def runner_backend(self) -> MoeRunnerBackend:
+         return MoeRunnerBackend.TRITON_KERNELS
+
+
+ # ---------------------------------------------------------------------------
+ # Permute / fused hooks
+ # ---------------------------------------------------------------------------
+
+
+ @register_pre_permute("standard", "triton_kernel")
+ def pre_permute_standard_to_triton_kernels(
+     dispatch_output: "StandardDispatchOutput",
+     quant_info: TritonKernelsQuantInfo,
+     runner_config: MoeRunnerConfig,
+     running_state: dict,
+ ) -> TritonKernelsRunnerInput:
+     from sglang.srt.layers.moe.topk import TopKOutputChecker
+
+     hidden_states = dispatch_output.hidden_states
+     topk_output = dispatch_output.topk_output
+
+     assert TopKOutputChecker.format_is_triton_kernels(
+         topk_output
+     ), "Triton-kernel runner expects TritonKernelTopKOutput"
+
+     routing_data, gather_indx, scatter_indx = topk_output
+
+     return TritonKernelsRunnerInput(
+         hidden_states=hidden_states,
+         routing_data=routing_data,
+         gather_indx=gather_indx,
+         scatter_indx=scatter_indx,
+     )
+
+
+ @register_post_permute("triton_kernel", "standard")
+ def post_permute_triton_kernels_to_standard(
+     runner_output: TritonKernelsRunnerOutput,
+     quant_info: TritonKernelsQuantInfo,
+     runner_config: MoeRunnerConfig,
+     running_state: dict,
+ ) -> StandardCombineInput:
+     from sglang.srt.layers.moe.token_dispatcher.standard import StandardCombineInput
+
+     hidden_states = runner_output.hidden_states
+
+     if (
+         runner_config.routed_scaling_factor is not None
+         and runner_config.routed_scaling_factor != 1.0
+         and not runner_config.no_combine
+     ):
+         hidden_states.mul_(runner_config.routed_scaling_factor)
+
+     return StandardCombineInput(hidden_states=hidden_states)
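Note: when no_combine is set, the runner skips the scatter/combine step, so the fused-experts output keeps one row per (token, routed expert) pair and the final view(tokens, top_k, hidden) relies on that row count. A small standalone illustration of that reshape arithmetic, using made-up tensor sizes rather than anything from this module:

# Standalone sketch of the no_combine reshape at the end of TritonKernelsRunnerCore.run().
import torch

tokens, top_k, hidden = 4, 2, 8
# Without the combine step, expert outputs stay expanded: tokens * top_k rows.
output = torch.randn(tokens * top_k, hidden)

inferred_top_k = output.shape[0] // tokens          # 8 // 4 == 2
per_token = output.view(tokens, inferred_top_k, hidden)
assert per_token.shape == (tokens, top_k, hidden)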
sglang/srt/layers/moe/token_dispatcher/__init__.py
@@ -12,9 +12,9 @@ from sglang.srt.layers.moe.token_dispatcher.deepep import (
      DeepEPConfig,
      DeepEPDispatcher,
      DeepEPLLCombineInput,
-     DeepEPLLOutput,
+     DeepEPLLDispatchOutput,
      DeepEPNormalCombineInput,
-     DeepEPNormalOutput,
+     DeepEPNormalDispatchOutput,
  )
  from sglang.srt.layers.moe.token_dispatcher.mooncake import (
      MooncakeCombineInput,
@@ -44,8 +44,8 @@ __all__ = [
      "StandardCombineInput",
      "DeepEPConfig",
      "DeepEPDispatcher",
-     "DeepEPNormalOutput",
-     "DeepEPLLOutput",
+     "DeepEPNormalDispatchOutput",
+     "DeepEPLLDispatchOutput",
      "DeepEPLLCombineInput",
      "DeepEPNormalCombineInput",
  ]
sglang/srt/layers/moe/token_dispatcher/base.py
@@ -9,9 +9,9 @@ import torch
  if TYPE_CHECKING:
      from sglang.srt.layers.moe.token_dispatcher import (
          DeepEPLLCombineInput,
-         DeepEPLLOutput,
+         DeepEPLLDispatchOutput,
          DeepEPNormalCombineInput,
-         DeepEPNormalOutput,
+         DeepEPNormalDispatchOutput,
          StandardCombineInput,
          StandardDispatchOutput,
      )
@@ -28,22 +28,28 @@ class DispatchOutputChecker:
      ) -> TypeGuard[StandardDispatchOutput]:
          return dispatch_output.format.is_standard()

+     @staticmethod
+     def format_is_triton_kernels(
+         dispatch_output: DispatchOutput,
+     ) -> TypeGuard[StandardDispatchOutput]:
+         return dispatch_output.format.is_standard()
+
      @staticmethod
      def format_is_deepep_normal(
          dispatch_output: DispatchOutput,
-     ) -> TypeGuard[DeepEPNormalOutput]:
+     ) -> TypeGuard[DeepEPNormalDispatchOutput]:
          return dispatch_output.format.is_deepep_normal()

      @staticmethod
      def format_is_deepep_ll(
          dispatch_output: DispatchOutput,
-     ) -> TypeGuard[DeepEPLLOutput]:
+     ) -> TypeGuard[DeepEPLLDispatchOutput]:
          return dispatch_output.format.is_deepep_ll()

      @staticmethod
      def format_is_deepep(
          dispatch_output: DispatchOutput,
-     ) -> TypeGuard[Union[DeepEPNormalOutput, DeepEPLLOutput]]:
+     ) -> TypeGuard[Union[DeepEPNormalDispatchOutput, DeepEPLLDispatchOutput]]:
          return dispatch_output.format.is_deepep()

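Note: the DispatchOutputChecker.format_is_* helpers return typing.TypeGuard, so a successful check narrows dispatch_output to the concrete dispatch-output type for static type checkers. A minimal generic sketch of that narrowing pattern, using toy classes rather than the real sglang types:

# Toy demonstration of the TypeGuard pattern used by DispatchOutputChecker.
from typing import TypeGuard, Union

class StandardOutput:
    hidden_states: str = "dense layout"

class DeepEPOutput:
    packed: str = "expert-parallel layout"

DispatchOutput = Union[StandardOutput, DeepEPOutput]

def format_is_standard(out: DispatchOutput) -> TypeGuard[StandardOutput]:
    # Returning True tells the type checker that `out` is a StandardOutput.
    return isinstance(out, StandardOutput)

def handle(out: DispatchOutput) -> None:
    if format_is_standard(out):
        # Inside this branch, checkers narrow `out` to StandardOutput,
        # so accessing StandardOutput-only attributes type-checks.
        print(out.hidden_states)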