sglang 0.4.6.post4__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130)
  1. sglang/bench_offline_throughput.py +6 -6
  2. sglang/bench_one_batch.py +5 -4
  3. sglang/bench_one_batch_server.py +23 -15
  4. sglang/bench_serving.py +133 -57
  5. sglang/compile_deep_gemm.py +4 -4
  6. sglang/srt/configs/model_config.py +39 -28
  7. sglang/srt/conversation.py +1 -1
  8. sglang/srt/disaggregation/decode.py +122 -133
  9. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
  10. sglang/srt/disaggregation/fake/conn.py +3 -13
  11. sglang/srt/disaggregation/kv_events.py +357 -0
  12. sglang/srt/disaggregation/mini_lb.py +57 -24
  13. sglang/srt/disaggregation/mooncake/conn.py +11 -2
  14. sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
  15. sglang/srt/disaggregation/nixl/conn.py +9 -19
  16. sglang/srt/disaggregation/prefill.py +126 -44
  17. sglang/srt/disaggregation/utils.py +116 -5
  18. sglang/srt/distributed/utils.py +3 -3
  19. sglang/srt/entrypoints/EngineBase.py +5 -0
  20. sglang/srt/entrypoints/engine.py +28 -8
  21. sglang/srt/entrypoints/http_server.py +6 -4
  22. sglang/srt/entrypoints/http_server_engine.py +5 -2
  23. sglang/srt/function_call/base_format_detector.py +250 -0
  24. sglang/srt/function_call/core_types.py +34 -0
  25. sglang/srt/function_call/deepseekv3_detector.py +157 -0
  26. sglang/srt/function_call/ebnf_composer.py +234 -0
  27. sglang/srt/function_call/function_call_parser.py +175 -0
  28. sglang/srt/function_call/llama32_detector.py +74 -0
  29. sglang/srt/function_call/mistral_detector.py +84 -0
  30. sglang/srt/function_call/pythonic_detector.py +163 -0
  31. sglang/srt/function_call/qwen25_detector.py +67 -0
  32. sglang/srt/function_call/utils.py +35 -0
  33. sglang/srt/hf_transformers_utils.py +46 -7
  34. sglang/srt/layers/attention/aiter_backend.py +513 -0
  35. sglang/srt/layers/attention/flashattention_backend.py +63 -17
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
  37. sglang/srt/layers/attention/flashmla_backend.py +340 -78
  38. sglang/srt/layers/attention/triton_backend.py +3 -0
  39. sglang/srt/layers/attention/utils.py +2 -2
  40. sglang/srt/layers/attention/vision.py +1 -1
  41. sglang/srt/layers/communicator.py +451 -0
  42. sglang/srt/layers/dp_attention.py +0 -10
  43. sglang/srt/layers/moe/cutlass_moe.py +207 -0
  44. sglang/srt/layers/moe/ep_moe/kernels.py +33 -11
  45. sglang/srt/layers/moe/ep_moe/layer.py +104 -50
  46. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
  47. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
  48. sglang/srt/layers/moe/topk.py +66 -9
  49. sglang/srt/layers/multimodal.py +70 -0
  50. sglang/srt/layers/quantization/__init__.py +7 -2
  51. sglang/srt/layers/quantization/deep_gemm.py +5 -3
  52. sglang/srt/layers/quantization/fp8.py +90 -0
  53. sglang/srt/layers/quantization/fp8_utils.py +6 -0
  54. sglang/srt/layers/quantization/gptq.py +298 -6
  55. sglang/srt/layers/quantization/int8_kernel.py +18 -5
  56. sglang/srt/layers/quantization/qoq.py +244 -0
  57. sglang/srt/lora/lora_manager.py +1 -3
  58. sglang/srt/managers/deepseek_eplb.py +278 -0
  59. sglang/srt/managers/eplb_manager.py +55 -0
  60. sglang/srt/managers/expert_distribution.py +704 -56
  61. sglang/srt/managers/expert_location.py +394 -0
  62. sglang/srt/managers/expert_location_dispatch.py +91 -0
  63. sglang/srt/managers/io_struct.py +16 -3
  64. sglang/srt/managers/mm_utils.py +293 -139
  65. sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
  66. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
  67. sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
  68. sglang/srt/managers/multimodal_processors/internvl.py +14 -5
  69. sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
  70. sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
  71. sglang/srt/managers/multimodal_processors/llava.py +3 -3
  72. sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
  73. sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
  74. sglang/srt/managers/multimodal_processors/pixtral.py +9 -9
  75. sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
  76. sglang/srt/managers/schedule_batch.py +49 -21
  77. sglang/srt/managers/schedule_policy.py +4 -5
  78. sglang/srt/managers/scheduler.py +92 -50
  79. sglang/srt/managers/session_controller.py +1 -1
  80. sglang/srt/managers/tokenizer_manager.py +99 -24
  81. sglang/srt/mem_cache/base_prefix_cache.py +3 -0
  82. sglang/srt/mem_cache/chunk_cache.py +3 -1
  83. sglang/srt/mem_cache/hiradix_cache.py +4 -4
  84. sglang/srt/mem_cache/memory_pool.py +74 -52
  85. sglang/srt/mem_cache/multimodal_cache.py +45 -0
  86. sglang/srt/mem_cache/radix_cache.py +58 -5
  87. sglang/srt/metrics/collector.py +2 -2
  88. sglang/srt/mm_utils.py +10 -0
  89. sglang/srt/model_executor/cuda_graph_runner.py +20 -9
  90. sglang/srt/model_executor/expert_location_updater.py +422 -0
  91. sglang/srt/model_executor/forward_batch_info.py +4 -0
  92. sglang/srt/model_executor/model_runner.py +144 -54
  93. sglang/srt/model_loader/loader.py +10 -6
  94. sglang/srt/models/clip.py +5 -1
  95. sglang/srt/models/deepseek_v2.py +297 -343
  96. sglang/srt/models/exaone.py +8 -3
  97. sglang/srt/models/gemma3_mm.py +70 -33
  98. sglang/srt/models/llama4.py +10 -2
  99. sglang/srt/models/llava.py +26 -18
  100. sglang/srt/models/mimo_mtp.py +220 -0
  101. sglang/srt/models/minicpmo.py +5 -12
  102. sglang/srt/models/mistral.py +71 -1
  103. sglang/srt/models/mllama.py +3 -3
  104. sglang/srt/models/qwen2.py +95 -26
  105. sglang/srt/models/qwen2_5_vl.py +8 -0
  106. sglang/srt/models/qwen2_moe.py +330 -60
  107. sglang/srt/models/qwen2_vl.py +6 -0
  108. sglang/srt/models/qwen3.py +52 -10
  109. sglang/srt/models/qwen3_moe.py +411 -48
  110. sglang/srt/models/siglip.py +294 -0
  111. sglang/srt/openai_api/adapter.py +28 -16
  112. sglang/srt/openai_api/protocol.py +6 -0
  113. sglang/srt/operations.py +154 -0
  114. sglang/srt/operations_strategy.py +31 -0
  115. sglang/srt/server_args.py +134 -24
  116. sglang/srt/speculative/eagle_utils.py +131 -0
  117. sglang/srt/speculative/eagle_worker.py +47 -2
  118. sglang/srt/utils.py +68 -12
  119. sglang/test/test_cutlass_moe.py +278 -0
  120. sglang/test/test_utils.py +2 -36
  121. sglang/utils.py +2 -2
  122. sglang/version.py +1 -1
  123. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +20 -11
  124. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +128 -102
  125. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
  126. sglang/srt/function_call_parser.py +0 -858
  127. sglang/srt/platforms/interface.py +0 -371
  128. /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
  129. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
  130. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0

sglang/srt/layers/moe/ep_moe/kernels.py

@@ -3,10 +3,9 @@ from typing import List, Optional
 
 import torch
 import triton
-import triton.language as tl
 
 from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
-from sglang.srt.utils import is_cuda
+from sglang.srt.utils import dispose_tensor, is_cuda
 
 logger = logging.getLogger(__name__)
 
@@ -653,12 +652,15 @@ def grouped_gemm_triton(
     scale_a: torch.Tensor = None,
     scale_b: torch.Tensor = None,
     block_shape: Optional[List[int]] = None,
+    c_dtype=None,
 ):
     assert weight_column_major == True # TODO: more
     if use_fp8_w8a8 and block_shape is None:
         assert scale_a is not None and scale_b is not None
 
     if block_shape is not None:
+        a_original = a
+
         assert len(block_shape) == 2
         block_n, block_k = block_shape[0], block_shape[1]
         a, scale_a = per_token_group_quant_fp8(a, block_k)
@@ -667,6 +669,8 @@ def grouped_gemm_triton(
         assert triton.cdiv(b.shape[-2], block_n) == scale_b.shape[-2]
         assert triton.cdiv(b.shape[-1], block_k) == scale_b.shape[-1]
 
+        dispose_tensor(a_original)
+
     # TODO: adjust config or tune kernel
     # Reduce block size to prevent L40 shared memory overflow.
     config = {
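
Note: the new `dispose_tensor(a_original)` call (together with the `dispose_tensor` import above) releases the un-quantized activation as soon as its fp8 copy exists, lowering peak memory even while other Python references to the tensor object remain alive. The helper's body is not part of this diff; a minimal sketch of what such a helper could look like, assuming it simply swaps out the storage:

import torch

def dispose_tensor_sketch(x: torch.Tensor) -> None:
    # Hypothetical stand-in for sglang.srt.utils.dispose_tensor (not shown in this
    # diff): point the tensor at an empty storage so the allocator can reuse the
    # old buffer immediately instead of waiting for every reference to drop.
    x.set_(torch.empty((0,), dtype=x.dtype, device=x.device))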
@@ -680,6 +684,10 @@ def grouped_gemm_triton(
         m_num_tiles_indptr, seg_indptr, batch_size, config["BLOCK_SIZE_M"]
     )
 
+    if c is None:
+        assert c_dtype is not None
+        c = torch.empty(a.shape[0], b.shape[1], device=a.device, dtype=c_dtype)
+
     grid = lambda META: (
         triton.cdiv(a.size(0), META["BLOCK_SIZE_M"]) + batch_size,
         triton.cdiv(b.size(1), META["BLOCK_SIZE_N"]),
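
Note: `grouped_gemm_triton` can now allocate its own output when the caller passes `c=None` together with `c_dtype`, so callers can defer the large output buffer until after the input has been quantized and disposed. A minimal sketch of the same allocate-on-demand pattern, with a plain matmul standing in for the grouped Triton kernel (names here are illustrative only):

import torch

def gemm_with_lazy_output(a, b, c=None, c_dtype=None):
    # Mirror of the new branch: allocate the output only if the caller did not.
    if c is None:
        assert c_dtype is not None, "pass c_dtype when c is not preallocated"
        c = torch.empty(a.shape[0], b.shape[1], device=a.device, dtype=c_dtype)
    torch.matmul(a, b, out=c)
    return c

out = gemm_with_lazy_output(torch.randn(4, 8), torch.randn(8, 16), c_dtype=torch.float32)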
@@ -783,19 +791,23 @@ def _fwd_kernel_ep_scatter_2(
     offset_in_s = tl.arange(0, SCALE_HIDDEN_SIZE_PAD)
     mask_s = offset_in_s < SCALE_HIDDEN_SIZE
 
-    for token_id in range(start_token_id, total_token_num, grid_num):
+    for token_id_int32 in range(start_token_id, total_token_num, grid_num):
+        token_id = token_id_int32.to(tl.int64)
         to_copy = tl.load(recv_x + token_id * recv_x_stride0 + offset_in, mask=mask)
         to_copy_s = tl.load(
             recv_x_scale + token_id * recv_x_scale_stride0 + offset_in_s, mask=mask_s
         )
 
-        for topk_index in tl.range(0, topk_num, 1, num_stages=4):
+        for topk_idx_int32 in tl.range(0, topk_num, 1, num_stages=4):
+            topk_index = topk_idx_int32.to(tl.int64)
             expert_id = tl.load(recv_topk + token_id * recv_topk_stride0 + topk_index)
             if expert_id >= 0:
-                dest_token_index = tl.atomic_add(expert_start_loc + expert_id, 1)
+                dest_token_index_int32 = tl.atomic_add(expert_start_loc + expert_id, 1)
+                dest_token_index = dest_token_index_int32.to(tl.int64)
+
                 tl.store(
                     output_index + token_id * output_index_stride0 + topk_index,
-                    dest_token_index,
+                    dest_token_index_int32,
                 )
                 output_tensor_ptr = (
                     output_tensor + dest_token_index * output_tensor_stride0
@@ -894,21 +906,31 @@ def _fwd_kernel_ep_gather(
     topk_num: tl.constexpr,
     BLOCK_D: tl.constexpr,
 ):
-    cur_block = tl.program_id(0)
-    start_cur_token = tl.program_id(1)
+    cur_block_int32 = tl.program_id(0)
+    cur_block = cur_block_int32.to(tl.int64)
+
+    start_cur_token_int32 = tl.program_id(1)
+
     grid_num = tl.num_programs(1)
 
-    for cur_token in range(start_cur_token, total_token_num, grid_num):
+    for cur_token_int32 in range(start_cur_token_int32, total_token_num, grid_num):
+        cur_token = cur_token_int32.to(tl.int64)
+
         off_d = tl.arange(0, BLOCK_D)
         accumulator = tl.zeros([BLOCK_D], dtype=tl.float32)
-        for topk_index in range(0, topk_num):
+
+        for topk_index_int32 in range(0, topk_num):
+            topk_index = topk_index_int32.to(tl.int64)
+
             expert_id = tl.load(
                 recv_topk_ids + cur_token * recv_topk_ids_stride0 + topk_index
             )
             if expert_id >= 0:
-                source_token_index = tl.load(
+                source_token_index_int32 = tl.load(
                     input_index + cur_token * input_index_stride0 + topk_index
                 )
+                source_token_index = source_token_index_int32.to(tl.int64)
+
                 acc_weight = tl.load(
                     recv_topk_weight + cur_token * recv_topk_weight_stride0 + topk_index
                 )
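
Note: the `_int32` / `.to(tl.int64)` changes in `_fwd_kernel_ep_scatter_2` and `_fwd_kernel_ep_gather` widen loop and loaded indices to 64-bit before they are multiplied by row strides; `tl.program_id` and int32 loads yield 32-bit values, and `index * stride` can overflow once a buffer spans more than 2^31 elements. A minimal self-contained sketch of the same widening pattern (a hypothetical row-copy kernel, not one from this diff):

import torch
import triton
import triton.language as tl

@triton.jit
def copy_rows_kernel(src_ptr, dst_ptr, src_stride0, dst_stride0, n_cols, BLOCK: tl.constexpr):
    row_int32 = tl.program_id(0)
    row = row_int32.to(tl.int64)  # widen before multiplying by the row stride
    cols = tl.arange(0, BLOCK)
    mask = cols < n_cols
    vals = tl.load(src_ptr + row * src_stride0 + cols, mask=mask)
    tl.store(dst_ptr + row * dst_stride0 + cols, vals, mask=mask)

x = torch.randn(4, 8, device="cuda")
y = torch.empty_like(x)
copy_rows_kernel[(x.shape[0],)](x, y, x.stride(0), y.stride(0), x.shape[1], BLOCK=8)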

sglang/srt/layers/moe/ep_moe/layer.py

@@ -5,6 +5,9 @@ import torch
 from torch.nn import Module
 
 from sglang.srt.layers.quantization.deep_gemm import _ENABLE_JIT_DEEPGEMM
+from sglang.srt.managers.expert_location import get_global_expert_location_metadata
+from sglang.srt.managers.expert_location_dispatch import ExpertLocationDispatchInfo
+from sglang.srt.managers.schedule_batch import global_server_args_dict
 
 try:
     from deep_gemm import (
@@ -40,7 +43,7 @@ from sglang.srt.layers.moe.ep_moe.kernels import (
     tma_align_input_scale,
 )
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoeWeightScaleSupported
-from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoEMethodBase
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE, FusedMoEMethodBase
 from sglang.srt.layers.moe.topk import select_experts
 from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
@@ -49,7 +52,7 @@ from sglang.srt.layers.quantization.base_config import (
 from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
 from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
-from sglang.srt.utils import DeepEPMode, is_hip, set_weight_attrs
+from sglang.srt.utils import DeepEPMode, dispose_tensor, is_hip, set_weight_attrs
 
 _is_hip = is_hip()
 
@@ -92,6 +95,7 @@ class GroupedGemmRunner(torch.nn.Module):
         scale_a: torch.Tensor = None,
         scale_b: torch.Tensor = None,
         block_shape: Optional[List[int]] = None,
+        c_dtype=None,
     ):
         if self.use_flashinfer:
             # TODO: flashinfer
@@ -119,6 +123,7 @@ class GroupedGemmRunner(torch.nn.Module):
                 scale_a,
                 scale_b,
                 block_shape=block_shape,
+                c_dtype=c_dtype,
             )
         return c
 
@@ -136,6 +141,7 @@ class EPMoE(torch.nn.Module):
         top_k: int,
         hidden_size: int,
         intermediate_size: int,
+        layer_id: int,
         params_dtype: Optional[torch.dtype] = None,
         renormalize: bool = True,
         use_grouped_topk: bool = False,
@@ -159,6 +165,7 @@ class EPMoE(torch.nn.Module):
         )
         self.tp_rank = get_tensor_model_parallel_rank()
 
+        self.layer_id = layer_id
         self.num_experts = num_experts
         assert self.num_experts % self.tp_size == 0
         self.num_experts_per_partition = self.num_experts // self.tp_size
@@ -210,6 +217,10 @@ class EPMoE(torch.nn.Module):
         self.grouped_gemm_runner = None
 
     def forward(self, hidden_states: torch.Tensor, router_logits: torch.Tensor):
+        hidden_states_shape = hidden_states.shape
+        hidden_states_dtype = hidden_states.dtype
+        hidden_states_device = hidden_states.device
+
         assert self.quant_method is not None
 
         if self.grouped_gemm_runner is None:
@@ -229,6 +240,9 @@ class EPMoE(torch.nn.Module):
             correction_bias=self.correction_bias,
             custom_routing_function=self.custom_routing_function,
             routed_scaling_factor=self.routed_scaling_factor,
+            expert_location_dispatch_info=ExpertLocationDispatchInfo.init_new(
+                layer_id=self.layer_id,
+            ),
         )
 
         reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
@@ -265,25 +279,21 @@ class EPMoE(torch.nn.Module):
             hidden_states.shape[1],
             BLOCK_SIZE=512,
         )
+        dispose_tensor(hidden_states)
 
         seg_indptr_cur_rank = seg_indptr[self.start_expert_id : self.end_expert_id + 2]
         weight_indices_cur_rank = torch.arange(
             0,
             self.num_experts_per_partition,
-            device=hidden_states.device,
+            device=hidden_states_device,
             dtype=torch.int64,
         )
         # GroupGemm-0
-        gateup_output = torch.empty(
-            gateup_input.shape[0],
-            self.w13_weight.shape[1],
-            device=hidden_states.device,
-            dtype=hidden_states.dtype,
-        )
         gateup_output = self.grouped_gemm_runner(
             a=gateup_input,
             b=self.w13_weight,
-            c=gateup_output,
+            c=None,
+            c_dtype=hidden_states_dtype,
             batch_size=self.num_experts_per_partition,
             weight_column_major=True,
             seg_indptr=seg_indptr_cur_rank,
@@ -297,6 +307,7 @@ class EPMoE(torch.nn.Module):
             ),
             block_shape=self.block_shape,
         )
+        del gateup_input
 
         # Act
         down_input = torch.empty(
@@ -306,14 +317,14 @@ class EPMoE(torch.nn.Module):
             dtype=(
                 self.fp8_dtype
                 if (self.use_fp8_w8a8 and not self.use_block_quant)
-                else hidden_states.dtype
+                else hidden_states_dtype
             ),
         )
         if self.w2_input_scale is None and not self.use_block_quant:
             self.w2_input_scale = torch.ones(
                 self.num_experts_per_partition,
                 dtype=torch.float32,
-                device=hidden_states.device,
+                device=hidden_states_device,
             )
 
         if self.activation == "silu":
@@ -340,13 +351,14 @@ class EPMoE(torch.nn.Module):
             )
         else:
             raise ValueError(f"Unsupported activation: {self.activation=}")
+        del gateup_output
 
         # GroupGemm-1
         down_output = torch.empty(
             down_input.shape[0],
             self.w2_weight.shape[1],
-            device=hidden_states.device,
-            dtype=hidden_states.dtype,
+            device=hidden_states_device,
+            dtype=hidden_states_dtype,
         )
         down_output = self.grouped_gemm_runner(
             a=down_input,
@@ -365,10 +377,13 @@ class EPMoE(torch.nn.Module):
             ),
             block_shape=self.block_shape,
         )
+        del down_input
 
         # PostReorder
-        output = torch.empty_like(hidden_states)
-        post_reorder_triton_kernel[(hidden_states.size(0),)](
+        output = torch.empty(
+            hidden_states_shape, dtype=hidden_states_dtype, device=hidden_states_device
+        )
+        post_reorder_triton_kernel[(hidden_states_shape[0],)](
             down_output,
             output,
             src2dst,
@@ -377,7 +392,7 @@ class EPMoE(torch.nn.Module):
             self.start_expert_id,
             self.end_expert_id,
             self.top_k,
-            hidden_states.size(1),
+            hidden_states_shape[1],
             BLOCK_SIZE=512,
         )
         return output
@@ -417,6 +432,28 @@ class EPMoE(torch.nn.Module):
         weight_name: str,
         shard_id: str,
         expert_id: int,
+    ) -> None:
+        physical_expert_ids = (
+            get_global_expert_location_metadata().logical_to_all_physical(
+                self.layer_id, expert_id
+            )
+        )
+        for physical_expert_id in physical_expert_ids:
+            self._weight_loader_physical(
+                param=param,
+                loaded_weight=loaded_weight,
+                weight_name=weight_name,
+                shard_id=shard_id,
+                expert_id=physical_expert_id,
+            )
+
+    def _weight_loader_physical(
+        self,
+        param: torch.nn.Parameter,
+        loaded_weight: torch.Tensor,
+        weight_name: str,
+        shard_id: str,
+        expert_id: int,
     ) -> None:
         if expert_id < self.start_expert_id or expert_id > self.end_expert_id:
             return
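
Note: `weight_loader` now fans one logical expert's checkpoint weight out to every physical slot reported by the new expert-location metadata, so a replicated logical expert is filled everywhere from a single load, while `_weight_loader_physical` keeps the old per-expert logic. A toy illustration of the fan-out, with a hard-coded mapping standing in for `get_global_expert_location_metadata()` (illustrative only, not the sglang data structure):

import torch

# Hypothetical mapping: logical expert 0 is replicated onto physical slots 0 and 4.
logical_to_all_physical = {0: [0, 4], 1: [1], 2: [2], 3: [3]}

def load_logical_expert(physical_slots: dict, logical_expert_id: int, loaded_weight: torch.Tensor) -> None:
    # Copy the same checkpoint tensor into every physical replica of the expert.
    for physical_id in logical_to_all_physical[logical_expert_id]:
        physical_slots[physical_id] = loaded_weight.clone()

slots = {}
load_logical_expert(slots, 0, torch.ones(2, 2))
assert sorted(slots) == [0, 4]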
@@ -460,7 +497,8 @@ class EPMoE(torch.nn.Module):
         # Input scales can be loaded directly and should be equal.
         if "input_scale" in weight_name:
             if (
-                param_data[expert_id] != 1
+                (shard_id == "w1" or shard_id == "w3")
+                and param_data[expert_id] != 1
                 and (param_data[expert_id] - loaded_weight).abs() > 1e-5
             ):
                 raise ValueError(
@@ -534,13 +572,10 @@ class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
         set_weight_attrs(w2_weight, extra_weight_attrs)
 
         # scale
+        layer.register_parameter("w13_input_scale", None)
+        layer.register_parameter("w13_weight_scale", None)
+
         ones_tensor = torch.ones(num_experts_per_partition, dtype=torch.float32)
-        w13_input_scale = torch.nn.Parameter(
-            ones_tensor,
-            requires_grad=False,
-        )
-        layer.register_parameter("w13_input_scale", w13_input_scale)
-        set_weight_attrs(w13_input_scale, extra_weight_attrs)
 
         w2_input_scale = torch.nn.Parameter(
             ones_tensor,
@@ -549,13 +584,6 @@ class UnquantizedEPMoEMethod(FusedMoEMethodBase, CustomOp):
         layer.register_parameter("w2_input_scale", w2_input_scale)
         set_weight_attrs(w2_input_scale, extra_weight_attrs)
 
-        w13_weight_scale = torch.nn.Parameter(
-            ones_tensor,
-            requires_grad=False,
-        )
-        layer.register_parameter("w13_weight_scale", w13_weight_scale)
-        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
-
         w2_weight_scale = torch.nn.Parameter(
             ones_tensor,
             requires_grad=False,
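
Note: the always-ones `w13_input_scale` / `w13_weight_scale` parameters are replaced with `register_parameter(..., None)`, which keeps the attribute present (so downstream code can still test it against `None`) without allocating a tensor or exposing it through `named_parameters()`. A small standalone check of that PyTorch behavior:

import torch
from torch import nn

class Demo(nn.Module):
    def __init__(self):
        super().__init__()
        # Registering None keeps the attribute addressable but stores no tensor.
        self.register_parameter("w13_input_scale", None)

m = Demo()
assert m.w13_input_scale is None
assert "w13_input_scale" not in dict(m.named_parameters())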
@@ -802,6 +830,7 @@ class DeepEPMoE(EPMoE):
         top_k: int,
         hidden_size: int,
         intermediate_size: int,
+        layer_id: int,
         params_dtype: Optional[torch.dtype] = None,
         renormalize: bool = True,
         use_grouped_topk: bool = False,
@@ -821,6 +850,7 @@ class DeepEPMoE(EPMoE):
             top_k,
             hidden_size,
             intermediate_size,
+            layer_id,
             params_dtype,
             renormalize,
             use_grouped_topk,
@@ -881,6 +911,9 @@ class DeepEPMoE(EPMoE):
         reorder_topk_ids: torch.Tensor,
         seg_indptr: torch.Tensor,
     ):
+        hidden_states_dtype = hidden_states.dtype
+        hidden_states_device = hidden_states.device
+
         assert self.quant_method is not None
         assert self.activation == "silu"
         if self.grouped_gemm_runner is None:
@@ -903,18 +936,12 @@ class DeepEPMoE(EPMoE):
         )
 
         # GroupGemm-0
-        gateup_output = torch.empty(
-            hidden_states.shape[0],
-            self.w13_weight.shape[1],
-            device=hidden_states.device,
-            dtype=hidden_states.dtype,
-        )
-
         if hidden_states.shape[0] > 0:
             gateup_output = self.grouped_gemm_runner(
                 a=hidden_states,
                 b=self.w13_weight,
-                c=gateup_output,
+                c=None,
+                c_dtype=hidden_states.dtype,
                 batch_size=self.num_experts_per_partition,
                 weight_column_major=True,
                 seg_indptr=seg_indptr,
@@ -928,6 +955,13 @@ class DeepEPMoE(EPMoE):
                 ),
                 block_shape=self.block_shape,
             )
+        else:
+            gateup_output = torch.empty(
+                hidden_states.shape[0],
+                self.w13_weight.shape[1],
+                device=hidden_states.device,
+                dtype=hidden_states.dtype,
+            )
 
         # Act
         down_input = torch.empty(
@@ -937,14 +971,14 @@ class DeepEPMoE(EPMoE):
             dtype=(
                 self.fp8_dtype
                 if (self.use_fp8_w8a8 and not self.use_block_quant)
-                else hidden_states.dtype
+                else hidden_states_dtype
             ),
         )
         if self.w2_input_scale is None and not self.use_block_quant:
             self.w2_input_scale = torch.ones(
                 self.num_experts_per_partition,
                 dtype=torch.float32,
-                device=hidden_states.device,
+                device=hidden_states_device,
             )
 
         if self.activation == "silu":
@@ -961,12 +995,14 @@ class DeepEPMoE(EPMoE):
         else:
             raise ValueError(f"Unsupported activation: {self.activation=}")
 
+        del gateup_output
+
         # GroupGemm-1
         down_output = torch.empty(
             down_input.shape[0],
             self.w2_weight.shape[1],
-            device=hidden_states.device,
-            dtype=hidden_states.dtype,
+            device=hidden_states_device,
+            dtype=hidden_states_dtype,
         )
         if down_input.shape[0] > 0:
             down_output = self.grouped_gemm_runner(
@@ -1007,11 +1043,9 @@ class DeepEPMoE(EPMoE):
         N = self.w13_weight.size(1)
         scale_block_size = 128
 
-        gather_out = torch.empty_like(
-            hidden_states_fp8,
-            device=hidden_states_fp8.device,
-            dtype=torch.bfloat16,
-        )
+        hidden_states_fp8_shape = hidden_states_fp8.shape
+        hidden_states_fp8_device = hidden_states_fp8.device
+        hidden_states_fp8_dtype = hidden_states_fp8.dtype
 
         input_tensor = [
             torch.empty(
@@ -1049,16 +1083,18 @@ class DeepEPMoE(EPMoE):
             m_indices,
             output_index,
         )
+        dispose_tensor(hidden_states_fp8)
 
         gateup_output = torch.empty(
             (all_tokens, N),
-            device=hidden_states_fp8.device,
+            device=hidden_states_fp8_device,
             dtype=torch.bfloat16,
         )
         input_tensor[1] = tma_align_input_scale(input_tensor[1])
         m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
             input_tensor, self.w13_weight_fp8, gateup_output, m_indices
         )
+        del input_tensor
         down_input = torch.empty(
             (
                 all_tokens,
@@ -1068,14 +1104,16 @@ class DeepEPMoE(EPMoE):
             dtype=torch.bfloat16,
         )
         silu_and_mul(gateup_output.view(-1, N), down_input)
+        del gateup_output
         down_output = torch.empty(
             (all_tokens, K),
-            device=hidden_states_fp8.device,
+            device=hidden_states_fp8_device,
             dtype=torch.bfloat16,
         )
         down_input_fp8, down_input_scale = sglang_per_token_group_quant_fp8(
             down_input, scale_block_size
         )
+        del down_input
         down_input_scale = tma_align_input_scale(down_input_scale)
         m_grouped_gemm_fp8_fp8_bf16_nt_contiguous(
             (down_input_fp8, down_input_scale),
@@ -1083,7 +1121,13 @@ class DeepEPMoE(EPMoE):
             down_output,
             m_indices,
         )
+        del down_input_fp8, down_input_scale
 
+        gather_out = torch.empty(
+            hidden_states_fp8_shape,
+            device=hidden_states_fp8_device,
+            dtype=torch.bfloat16,
+        )
         ep_gather(down_output, topk_idx, topk_weights, output_index, gather_out)
 
         return gather_out
@@ -1107,6 +1151,7 @@ class DeepEPMoE(EPMoE):
         m_grouped_gemm_fp8_fp8_bf16_nt_masked(
             hidden_states_fp8, self.w13_weight_fp8, gateup_output, masked_m, expected_m
         )
+        dispose_tensor(hidden_states_fp8[0])
 
         # Act
         down_input = torch.empty(
@@ -1135,6 +1180,7 @@ class DeepEPMoE(EPMoE):
             scale_block_size,
             masked_m,
         )
+        del gateup_output
 
         # GroupGemm-1
         n = self.w2_weight.size(1)
@@ -1150,3 +1196,11 @@ class DeepEPMoE(EPMoE):
         )
 
         return down_output
+
+
+def get_moe_impl_class():
+    if global_server_args_dict["enable_deepep_moe"]:
+        return DeepEPMoE
+    if global_server_args_dict["enable_ep_moe"]:
+        return EPMoE
+    return FusedMoE
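
Note: `get_moe_impl_class` gives model code one place to choose between `DeepEPMoE`, `EPMoE`, and the Triton `FusedMoE` from server flags, with DeepEP taking precedence when both expert-parallel flags are set. A self-contained sketch of the same precedence, with a plain dict standing in for `global_server_args_dict` (illustrative only, not the sglang call site):

def pick_moe_impl(server_args: dict) -> str:
    if server_args.get("enable_deepep_moe"):
        return "DeepEPMoE"
    if server_args.get("enable_ep_moe"):
        return "EPMoE"
    return "FusedMoE"

assert pick_moe_impl({"enable_deepep_moe": True, "enable_ep_moe": True}) == "DeepEPMoE"
assert pick_moe_impl({"enable_ep_moe": True}) == "EPMoE"
assert pick_moe_impl({}) == "FusedMoE"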