sglang 0.4.5.post1__py3-none-any.whl → 0.4.5.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. sglang/__init__.py +2 -4
  2. sglang/bench_one_batch.py +2 -2
  3. sglang/bench_serving.py +3 -6
  4. sglang/compile_deep_gemm.py +136 -0
  5. sglang/lang/backend/anthropic.py +0 -4
  6. sglang/lang/backend/base_backend.py +1 -1
  7. sglang/lang/backend/openai.py +6 -2
  8. sglang/lang/backend/runtime_endpoint.py +5 -1
  9. sglang/lang/backend/vertexai.py +0 -1
  10. sglang/lang/compiler.py +1 -7
  11. sglang/lang/tracer.py +3 -7
  12. sglang/srt/_custom_ops.py +0 -2
  13. sglang/srt/configs/model_config.py +4 -1
  14. sglang/srt/constrained/outlines_jump_forward.py +14 -1
  15. sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
  16. sglang/srt/constrained/xgrammar_backend.py +27 -4
  17. sglang/srt/custom_op.py +0 -62
  18. sglang/srt/disaggregation/decode.py +105 -6
  19. sglang/srt/disaggregation/mini_lb.py +74 -9
  20. sglang/srt/disaggregation/mooncake/conn.py +33 -63
  21. sglang/srt/disaggregation/mooncake/transfer_engine.py +30 -61
  22. sglang/srt/disaggregation/nixl/__init__.py +1 -0
  23. sglang/srt/disaggregation/nixl/conn.py +622 -0
  24. sglang/srt/disaggregation/prefill.py +137 -17
  25. sglang/srt/disaggregation/utils.py +32 -0
  26. sglang/srt/entrypoints/engine.py +4 -0
  27. sglang/srt/entrypoints/http_server.py +3 -7
  28. sglang/srt/entrypoints/verl_engine.py +7 -5
  29. sglang/srt/function_call_parser.py +60 -0
  30. sglang/srt/layers/activation.py +6 -8
  31. sglang/srt/layers/attention/flashattention_backend.py +883 -209
  32. sglang/srt/layers/attention/flashinfer_backend.py +5 -2
  33. sglang/srt/layers/attention/torch_native_backend.py +6 -1
  34. sglang/srt/layers/attention/triton_backend.py +6 -0
  35. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
  36. sglang/srt/layers/attention/triton_ops/extend_attention.py +18 -7
  37. sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
  38. sglang/srt/layers/dp_attention.py +1 -1
  39. sglang/srt/layers/layernorm.py +20 -5
  40. sglang/srt/layers/linear.py +17 -3
  41. sglang/srt/layers/moe/ep_moe/layer.py +17 -29
  42. sglang/srt/layers/moe/fused_moe_native.py +4 -0
  43. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +14 -19
  44. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  45. sglang/srt/layers/moe/topk.py +27 -30
  46. sglang/srt/layers/parameter.py +0 -2
  47. sglang/srt/layers/quantization/__init__.py +1 -0
  48. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  49. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +9 -2
  50. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +16 -44
  51. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  52. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
  53. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
  54. sglang/srt/layers/quantization/deep_gemm.py +378 -0
  55. sglang/srt/layers/quantization/fp8.py +115 -132
  56. sglang/srt/layers/quantization/fp8_kernel.py +213 -88
  57. sglang/srt/layers/quantization/fp8_utils.py +189 -264
  58. sglang/srt/layers/quantization/gptq.py +13 -7
  59. sglang/srt/layers/quantization/modelopt_quant.py +2 -2
  60. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  61. sglang/srt/layers/quantization/utils.py +5 -11
  62. sglang/srt/layers/quantization/w8a8_fp8.py +2 -0
  63. sglang/srt/layers/quantization/w8a8_int8.py +7 -7
  64. sglang/srt/layers/radix_attention.py +15 -0
  65. sglang/srt/layers/rotary_embedding.py +9 -8
  66. sglang/srt/layers/sampler.py +7 -12
  67. sglang/srt/lora/backend/base_backend.py +18 -2
  68. sglang/srt/lora/backend/flashinfer_backend.py +1 -1
  69. sglang/srt/lora/backend/triton_backend.py +1 -1
  70. sglang/srt/lora/layers.py +1 -1
  71. sglang/srt/lora/lora.py +1 -1
  72. sglang/srt/lora/lora_manager.py +1 -1
  73. sglang/srt/managers/data_parallel_controller.py +7 -1
  74. sglang/srt/managers/detokenizer_manager.py +0 -1
  75. sglang/srt/managers/io_struct.py +15 -3
  76. sglang/srt/managers/mm_utils.py +4 -3
  77. sglang/srt/managers/multimodal_processor.py +0 -2
  78. sglang/srt/managers/multimodal_processors/base_processor.py +3 -2
  79. sglang/srt/managers/schedule_batch.py +15 -4
  80. sglang/srt/managers/scheduler.py +28 -77
  81. sglang/srt/managers/tokenizer_manager.py +116 -29
  82. sglang/srt/managers/tp_worker.py +1 -0
  83. sglang/srt/mem_cache/hiradix_cache.py +41 -29
  84. sglang/srt/mem_cache/memory_pool.py +38 -15
  85. sglang/srt/model_executor/cuda_graph_runner.py +15 -10
  86. sglang/srt/model_executor/model_runner.py +39 -31
  87. sglang/srt/models/bert.py +398 -0
  88. sglang/srt/models/deepseek.py +1 -1
  89. sglang/srt/models/deepseek_nextn.py +74 -70
  90. sglang/srt/models/deepseek_v2.py +292 -348
  91. sglang/srt/models/llama.py +5 -5
  92. sglang/srt/models/minicpm3.py +31 -203
  93. sglang/srt/models/minicpmo.py +17 -6
  94. sglang/srt/models/qwen2.py +4 -1
  95. sglang/srt/models/qwen2_moe.py +14 -13
  96. sglang/srt/models/qwen3.py +335 -0
  97. sglang/srt/models/qwen3_moe.py +423 -0
  98. sglang/srt/openai_api/adapter.py +71 -4
  99. sglang/srt/openai_api/protocol.py +6 -1
  100. sglang/srt/reasoning_parser.py +0 -1
  101. sglang/srt/sampling/sampling_batch_info.py +2 -3
  102. sglang/srt/server_args.py +86 -72
  103. sglang/srt/speculative/build_eagle_tree.py +2 -2
  104. sglang/srt/speculative/eagle_utils.py +2 -2
  105. sglang/srt/speculative/eagle_worker.py +6 -14
  106. sglang/srt/utils.py +62 -6
  107. sglang/test/runners.py +5 -1
  108. sglang/test/test_block_fp8.py +167 -0
  109. sglang/test/test_custom_ops.py +1 -1
  110. sglang/test/test_utils.py +3 -1
  111. sglang/version.py +1 -1
  112. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/METADATA +5 -5
  113. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/RECORD +116 -110
  114. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/WHEEL +1 -1
  115. sglang/lang/__init__.py +0 -0
  116. sglang/srt/lora/backend/__init__.py +0 -25
  117. sglang/srt/server.py +0 -18
  118. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/licenses/LICENSE +0 -0
  119. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/top_level.txt +0 -0
@@ -100,8 +100,11 @@ class FlashInferAttnBackend(AttentionBackend):
         self.num_wrappers = 1
         self.dispatch_reason = None

-        # Qwen2 models require higher flashinfer workspace size
-        if "Qwen2ForCausalLM" in model_runner.model_config.hf_config.architectures:
+        # Qwen2/Qwen3 models require higher flashinfer workspace size
+        if (
+            "Qwen2ForCausalLM" in model_runner.model_config.hf_config.architectures
+            or "Qwen3ForCausalLM" in model_runner.model_config.hf_config.architectures
+        ):
             global_config.flashinfer_workspace_size = 512 * 1024 * 1024

         # Allocate buffers
@@ -6,6 +6,7 @@ import torch
 from torch.nn.functional import scaled_dot_product_attention

 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
+from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch

 if TYPE_CHECKING:
@@ -202,6 +203,10 @@ class TorchNativeAttnBackend(AttentionBackend):
         q_ = q.view(-1, layer.tp_q_head_num, layer.qk_head_dim)
         o_ = o.view(-1, layer.tp_q_head_num, layer.v_head_dim)

+        causal = True
+        if layer.is_cross_attention or layer.attn_type == AttentionType.ENCODER_ONLY:
+            causal = False
+
         self._run_sdpa_forward_extend(
             q_,
             o_,
@@ -214,7 +219,7 @@ class TorchNativeAttnBackend(AttentionBackend):
             forward_batch.extend_seq_lens,
             scaling=layer.scaling,
             enable_gqa=use_gqa,
-            causal=not layer.is_cross_attention,
+            causal=causal,
         )
         return o

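
For context, the new causal flag simply switches PyTorch SDPA between causal and full attention for encoder-only or cross-attention layers. Below is a minimal standalone sketch of that behavior (hypothetical shapes, not sglang's batched extend layout):

import torch
from torch.nn.functional import scaled_dot_product_attention

def sdpa_extend(q, k, v, is_cross_attention=False, encoder_only=False):
    # Mirrors the backend logic: only decoder self-attention stays causal.
    causal = not (is_cross_attention or encoder_only)
    return scaled_dot_product_attention(q, k, v, is_causal=causal)

q = torch.randn(1, 8, 16, 64)   # (batch, heads, q_len, head_dim)
kv = torch.randn(1, 8, 16, 64)
out_decoder = sdpa_extend(q, kv, kv)                     # causal mask applied
out_encoder = sdpa_extend(q, kv, kv, encoder_only=True)  # full attention
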
@@ -10,6 +10,7 @@ import triton.language as tl
 from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
 from sglang.srt.layers.attention.utils import create_flashinfer_kv_indices_triton
 from sglang.srt.layers.dp_attention import get_attention_tp_size
+from sglang.srt.layers.radix_attention import AttentionType
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, ForwardMode
 from sglang.srt.utils import get_bool_env_var, get_device_core_count

@@ -528,6 +529,10 @@ class TritonAttnBackend(AttentionBackend):
                 layer, forward_batch.out_cache_loc, k, v
             )

+        causal = True
+        if layer.attn_type == AttentionType.ENCODER_ONLY:
+            causal = False
+
         self.extend_attention_fwd(
             q.view(-1, layer.tp_q_head_num, layer.qk_head_dim),
             k.contiguous(),
@@ -539,6 +544,7 @@
             self.forward_metadata.kv_indptr,
             self.forward_metadata.kv_indices,
             self.forward_metadata.custom_mask,
+            causal,
             self.forward_metadata.mask_indptr,
             self.forward_metadata.max_extend_len,
             layer.scaling,
@@ -3,10 +3,10 @@ import triton
 import triton.language as tl

 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import is_cuda, is_hip

-is_cuda_available = torch.cuda.is_available()
-if is_cuda_available:
+_is_cuda = is_cuda()
+if _is_cuda:
     CUDA_CAPABILITY = torch.cuda.get_device_capability()

 _is_hip = is_hip()
@@ -1037,12 +1037,12 @@ def extend_attention_fwd(
         num_warps = 4

     else:
-        if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
+        if _is_cuda and CUDA_CAPABILITY[0] >= 9:
            if Lq <= 256:
                BLOCK_M, BLOCK_N = (128, 64)
            else:
                BLOCK_M, BLOCK_N = (32, 64)
-        elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
+        elif _is_cuda and CUDA_CAPABILITY[0] >= 8:
            if Lq <= 128:
                BLOCK_M, BLOCK_N = (128, 128)
            elif Lq <= 256:
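
Several kernels above replace torch.cuda.is_available() with the is_cuda() / is_hip() helpers from sglang.srt.utils, which distinguish NVIDIA CUDA builds from ROCm builds where torch.cuda also reports available. A rough sketch of what such helpers typically check (an assumption for illustration, not the exact sglang implementation):

import torch

def is_cuda() -> bool:
    # True only for NVIDIA CUDA builds of PyTorch.
    return torch.cuda.is_available() and torch.version.cuda is not None

def is_hip() -> bool:
    # True for ROCm builds, where torch.cuda is backed by HIP.
    return torch.version.hip is not None
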
@@ -23,10 +23,10 @@ import triton.language as tl
 from sglang.srt.layers.attention.triton_ops.prefill_attention import (
     context_attention_fwd,
 )
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import is_cuda, is_hip

-is_cuda_available = torch.cuda.is_available()
-if is_cuda_available:
+_is_cuda = is_cuda()
+if _is_cuda:
     CUDA_CAPABILITY = torch.cuda.get_device_capability()

 _is_hip = is_hip()
@@ -74,6 +74,7 @@ def _fwd_kernel(
     BLOCK_M: tl.constexpr,
     BLOCK_N: tl.constexpr,
     USE_CUSTOM_MASK: tl.constexpr,
+    IS_CAUSAL: tl.constexpr,
     SKIP_PREFIX_CUSTOM_MASK: tl.constexpr,
     STORE_TRANSPOSE: tl.constexpr,
 ):
@@ -129,6 +130,7 @@ def _fwd_kernel(
     for start_n in range(0, cur_seq_len_prefix, BLOCK_N):
         start_n = tl.multiple_of(start_n, BLOCK_N)
         mask_n = (start_n + offs_n) < cur_seq_len_prefix
+
         offs_kv_loc = tl.load(
             kv_indices + cur_seq_kv_start_idx + start_n + offs_n, mask=mask_n, other=0
         )
@@ -196,7 +198,11 @@

     # stage 2: compute the triangle part

-    cur_block_m_end = tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
+    cur_block_m_end = (
+        cur_seq_len_extend
+        if not IS_CAUSAL
+        else tl.minimum(cur_seq_len_extend, (cur_block_m + 1) * BLOCK_M)
+    )
     for start_n in range(0, cur_block_m_end, BLOCK_N):
         start_n = tl.multiple_of(start_n, BLOCK_N)
         mask_n = (start_n + offs_n) < cur_block_m_end
@@ -243,12 +249,15 @@
             )
             custom_mask &= mask_m[:, None] & mask_n[None, :]
             qk = tl.where(custom_mask, qk, float("-inf"))
-        else:
+        elif IS_CAUSAL:
             mask_causual = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (
                 start_n + offs_n[None, :]
             )
             mask_causual &= mask_m[:, None] & mask_n[None, :]
             qk = tl.where(mask_causual, qk, float("-inf"))
+        else:
+            mask_non_causal = mask_m[:, None] & mask_n[None, :]
+            qk = tl.where(mask_non_causal, qk, float("-inf"))

         n_e_max = tl.maximum(tl.max(qk, 1), e_max)
         re_scale = tl.exp(e_max - n_e_max)
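
The new IS_CAUSAL branch above chooses between a triangular mask and a plain bounds mask for the extend ("triangle") stage. The following standalone sketch reproduces the two mask shapes in PyTorch for a single tile (illustrative sizes; the kernel builds the same masks with tl.arange):

import torch

BLOCK_M, BLOCK_N = 4, 4
cur_block_m, start_n = 0, 0
seq_len_extend = 3  # positions beyond this are padding

offs_m = torch.arange(BLOCK_M)
offs_n = torch.arange(BLOCK_N)
mask_m = (cur_block_m * BLOCK_M + offs_m) < seq_len_extend
mask_n = (start_n + offs_n) < seq_len_extend

# Causal: a query token may only attend to keys at or before its own position.
mask_causal = (cur_block_m * BLOCK_M + offs_m[:, None]) >= (start_n + offs_n[None, :])
mask_causal &= mask_m[:, None] & mask_n[None, :]

# Non-causal (encoder-only): only the sequence-length bounds are masked.
mask_non_causal = mask_m[:, None] & mask_n[None, :]

qk = torch.zeros(BLOCK_M, BLOCK_N)
print(torch.where(mask_causal, qk, torch.tensor(float("-inf"))))
print(torch.where(mask_non_causal, qk, torch.tensor(float("-inf"))))
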
@@ -299,6 +308,7 @@ def extend_attention_fwd(
     kv_indptr,
     kv_indices,
     custom_mask,
+    is_causal,
     mask_indptr,
     max_len_extend,
     sm_scale=None,
@@ -335,12 +345,12 @@
         num_warps = 4

     else:
-        if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
+        if _is_cuda and CUDA_CAPABILITY[0] >= 9:
            if Lq <= 256:
                BLOCK_M, BLOCK_N = (128, 64)
            else:
                BLOCK_M, BLOCK_N = (32, 64)
-        elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
+        elif _is_cuda and CUDA_CAPABILITY[0] >= 8:
            # sm86/sm89 has a much smaller shared memory size (100K) than sm80 (160K)
            if CUDA_CAPABILITY[1] == 9 or CUDA_CAPABILITY[1] == 6:
                if Lq <= 128:
@@ -411,6 +421,7 @@
         Lq=Lq,
         Lv=Lv,
         USE_CUSTOM_MASK=USE_CUSTOM_MASK,
+        IS_CAUSAL=is_causal,
         SKIP_PREFIX_CUSTOM_MASK=SKIP_PREFIX_CUSTOM_MASK,
         STORE_TRANSPOSE=_is_hip,
         num_warps=num_warps,
@@ -22,8 +22,12 @@ import torch
 import triton
 import triton.language as tl

-is_cuda_available = torch.cuda.is_available()
-if is_cuda_available:
+from sglang.srt.utils import is_cuda, is_hip
+
+_is_cuda = is_cuda()
+_is_hip = is_hip()
+
+if _is_cuda or _is_hip:
     CUDA_CAPABILITY = torch.cuda.get_device_capability()

@@ -172,7 +176,7 @@ def context_attention_fwd(
     b_seq_len: [b]
     out: [b * s, head, head_dim]
     """
-    if is_cuda_available and CUDA_CAPABILITY[0] > 8:
+    if (_is_cuda or _is_hip) and CUDA_CAPABILITY[0] > 8:
         BLOCK = 128
     else:
         BLOCK = 64
@@ -143,7 +143,7 @@ def memcpy_triton_kernel(
     src_ptr,
     offset_ptr,
     sz_ptr,
-    offset_src,
+    offset_src: tl.constexpr,
     chunk_size,  # multiplied for offset and sz
     BLOCK_SIZE: tl.constexpr,
 ):
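
Declaring offset_src as tl.constexpr makes it a compile-time specialization constant, so Triton compiles a separate kernel per value and can eliminate the unused branch. A minimal, self-contained illustration of that pattern (a toy kernel, not the actual memcpy kernel; requires a CUDA/ROCm device):

import torch
import triton
import triton.language as tl

@triton.jit
def scale_or_shift_kernel(x_ptr, out_ptr, n, USE_SCALE: tl.constexpr, BLOCK: tl.constexpr):
    offs = tl.program_id(0) * BLOCK + tl.arange(0, BLOCK)
    mask = offs < n
    x = tl.load(x_ptr + offs, mask=mask)
    if USE_SCALE:  # resolved at compile time; the other branch is dead-code eliminated
        y = x * 2.0
    else:
        y = x + 1.0
    tl.store(out_ptr + offs, y, mask=mask)

x = torch.randn(1024, device="cuda")
out = torch.empty_like(x)
grid = (triton.cdiv(x.numel(), 256),)
scale_or_shift_kernel[grid](x, out, x.numel(), USE_SCALE=True, BLOCK=256)
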
@@ -19,9 +19,13 @@ from typing import Optional, Tuple, Union
 import torch
 import torch.nn as nn

-from sglang.srt.utils import is_cuda_available
+from sglang.srt.custom_op import CustomOp
+from sglang.srt.utils import is_cuda, is_hip
+
+logger = logging.getLogger(__name__)

-_is_cuda = is_cuda_available()
+_is_cuda = is_cuda()
+_is_hip = is_hip()

 if _is_cuda:
     from sgl_kernel import (
@@ -31,9 +35,20 @@ if _is_cuda:
         rmsnorm,
     )

-from sglang.srt.custom_op import CustomOp
+if _is_hip:

-logger = logging.getLogger(__name__)
+    from aiter.ops.rmsnorm import rms_norm, rmsnorm2d_fwd_with_add
+
+    rmsnorm = rms_norm
+
+    def fused_add_rmsnorm(
+        x: torch.Tensor,
+        residual: torch.Tensor,
+        w: torch.Tensor,
+        eps: float,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        rmsnorm2d_fwd_with_add(x, x, residual, residual, w, eps)
+        return x, residual


 class RMSNorm(CustomOp):
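
On ROCm the module now maps fused_add_rmsnorm onto aiter's rmsnorm2d_fwd_with_add, which updates x and residual in place. A plain-PyTorch reference of the intended semantics (a sketch for clarity, not the fused kernel):

import torch

def fused_add_rmsnorm_ref(x, residual, w, eps):
    # residual <- x + residual; x <- RMSNorm(residual) * w
    residual = x + residual
    variance = residual.pow(2).mean(dim=-1, keepdim=True)
    x = residual * torch.rsqrt(variance + eps) * w
    return x, residual

x = torch.randn(2, 8)
residual = torch.randn(2, 8)
w = torch.ones(8)
y, new_residual = fused_add_rmsnorm_ref(x, residual, w, eps=1e-6)
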
@@ -139,7 +154,7 @@ class Gemma3RMSNorm(nn.Module):
         return f"{tuple(self.weight.shape)}, eps={self.eps}"


-if not _is_cuda:
+if not (_is_cuda or _is_hip):
     logger.info(
         "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
     )
@@ -1,5 +1,6 @@
 """Adapted from https://github.com/vllm-project/vllm/blob/v0.6.4.post1/vllm/model_executor/layers/linear.py"""

+import itertools
 import logging
 from abc import abstractmethod
 from typing import Dict, List, Optional, Tuple
@@ -61,12 +62,12 @@ def adjust_marlin_shard(param, shard_size, shard_offset):


 def adjust_bitsandbytes_4bit_shard(
-    param: Parameter, qkv_offsets: Dict[str, Tuple[int, int]], loaded_shard_id: str
+    param: Parameter, shard_offsets: Dict[str, Tuple[int, int]], loaded_shard_id: str
 ) -> Tuple[int, int]:
     """Adjust the quantization offsets and sizes for BitsAndBytes sharding."""

-    total, _ = qkv_offsets["total"]
-    orig_offset, orig_size = qkv_offsets[loaded_shard_id]
+    total, _ = shard_offsets["total"]
+    orig_offset, orig_size = shard_offsets[loaded_shard_id]

     quantized_total = param.data.shape[0]
     quantized_offset = orig_offset * quantized_total // total
@@ -573,6 +574,8 @@ class MergedColumnParallelLinear(ColumnParallelLinear):
                 shard_offsets.append((i, current_shard_offset, output_size))
                 current_shard_offset += output_size
             packed_dim = getattr(param, "packed_dim", None)
+
+            use_bitsandbytes_4bit = getattr(param, "use_bitsandbytes_4bit", False)
             for shard_id, shard_offset, shard_size in shard_offsets:
                 # Special case for Quantization.
                 # If quantized, we need to adjust the offset and size to account
@@ -585,6 +588,17 @@
                         param, shard_size, shard_offset
                     )

+                if use_bitsandbytes_4bit:
+                    index = list(itertools.accumulate([0] + self.output_sizes))
+                    orig_offsets = {
+                        str(i): (index[i], size)
+                        for i, size in enumerate(self.output_sizes)
+                    }
+                    orig_offsets["total"] = (self.output_size, 0)
+                    shard_size, shard_offset = adjust_bitsandbytes_4bit_shard(
+                        param, orig_offsets, str(shard_id)
+                    )
+
                 loaded_weight_shard = loaded_weight.narrow(
                     output_dim, shard_offset, shard_size
                 )
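
The new bitsandbytes branch above builds an offset table with itertools.accumulate and rescales it to the packed parameter's row count via adjust_bitsandbytes_4bit_shard. A small worked example of that rescaling, using a simplified re-implementation and illustrative sizes:

import itertools

output_sizes = [4096, 4096]  # e.g. two merged projection widths (illustrative)
index = list(itertools.accumulate([0] + output_sizes))
orig_offsets = {str(i): (index[i], size) for i, size in enumerate(output_sizes)}
orig_offsets["total"] = (sum(output_sizes), 0)

def adjust_4bit_shard(quantized_rows, shard_offsets, shard_id):
    # Simplified stand-in: rescale an unquantized (offset, size) pair onto the
    # packed 4-bit tensor, which has fewer rows than the full-precision weight.
    total, _ = shard_offsets["total"]
    orig_offset, orig_size = shard_offsets[shard_id]
    return (orig_size * quantized_rows // total,
            orig_offset * quantized_rows // total)

print(adjust_4bit_shard(quantized_rows=4096, shard_offsets=orig_offsets, shard_id="1"))
# -> (2048, 2048): shard "1" occupies rows [2048, 4096) of the packed tensor
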
@@ -2,6 +2,7 @@ import logging
 from typing import Callable, List, Optional, Tuple

 import torch
+from torch.nn import Module

 try:
     from deep_gemm import (
@@ -13,8 +14,6 @@ try:
 except ImportError:
     use_deep_gemm = False

-from torch.nn import Module
-
 from sglang.srt.custom_op import CustomOp
 from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
@@ -37,22 +36,17 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizeMethodBase,
 )
 from sglang.srt.layers.quantization.fp8 import Fp8Config, Fp8MoEMethod
+from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
 from sglang.srt.model_executor.forward_batch_info import ForwardMode
-from sglang.srt.utils import DeepEPMode, is_cuda, is_hip, set_weight_attrs
+from sglang.srt.utils import DeepEPMode, is_hip, set_weight_attrs

-_is_cuda = is_cuda()
+_is_hip = is_hip()

-if _is_cuda:
-    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
-else:
-    from vllm import _custom_ops as vllm_ops
+if _is_hip:
+    from vllm._custom_ops import scaled_fp8_quant

 logger = logging.getLogger(__name__)

-_is_hip = is_hip()
-
-_buffer = None
-

 class GroupedGemmRunner(torch.nn.Module):
     flashinfer_gemm_warpper = None
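
Both the CUDA and ROCm paths now import a single scaled_fp8_quant (from fp8_kernel on CUDA, from vLLM on ROCm). For orientation, dynamic per-tensor FP8 quantization amounts to roughly the following (a simplified sketch that ignores the per-token and static-scale modes of the real kernels):

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

def scaled_fp8_quant_ref(x: torch.Tensor):
    # Dynamic per-tensor scale: map max(|x|) onto the fp8 representable range.
    scale = x.abs().max().clamp(min=1e-12) / FP8_MAX
    q = (x / scale).clamp(-FP8_MAX, FP8_MAX).to(torch.float8_e4m3fn)
    return q, scale.reshape(1)

w = torch.randn(128, 256)
w_fp8, w_scale = scaled_fp8_quant_ref(w)
w_dequant = w_fp8.to(torch.float32) * w_scale  # approximate reconstruction
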
@@ -142,6 +136,7 @@ class EPMoE(torch.nn.Module):
         correction_bias: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         activation: str = "silu",
+        routed_scaling_factor: Optional[float] = None,
     ):
         super().__init__()

@@ -170,6 +165,7 @@ class EPMoE(torch.nn.Module):
         self.correction_bias = correction_bias
         self.custom_routing_function = custom_routing_function
         self.activation = activation
+        self.routed_scaling_factor = routed_scaling_factor

         if quant_config is None:
             self.quant_method: Optional[QuantizeMethodBase] = UnquantizedEPMoEMethod()
@@ -221,6 +217,7 @@
             num_expert_group=self.num_expert_group,
             correction_bias=self.correction_bias,
             custom_routing_function=self.custom_routing_function,
+            routed_scaling_factor=self.routed_scaling_factor,
         )

         reorder_topk_ids, src2dst, seg_indptr = run_moe_ep_preproess(
@@ -740,20 +737,12 @@ class Fp8EPMoEMethod(Fp8MoEMethod):
            )

            for expert in range(layer.num_experts_per_partition):
-                if _is_cuda:
-                    w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
-                        sgl_scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
-                    )
-                    w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
-                        sgl_scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
-                    )
-                else:
-                    w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
-                        vllm_ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
-                    )
-                    w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
-                        vllm_ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
-                    )
+                w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
+                    scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
+                )
+                w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
+                    scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
+                )
            layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
            return
@@ -813,6 +802,7 @@ class DeepEPMoE(EPMoE):
         correction_bias: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         activation: str = "silu",
+        routed_scaling_factor: Optional[float] = None,
         deepep_mode: DeepEPMode = DeepEPMode.auto,
     ):
         super().__init__(
@@ -831,6 +821,7 @@
             correction_bias,
             custom_routing_function,
             activation,
+            routed_scaling_factor,
         )
         self.deepep_mode = deepep_mode
         if self.deepep_mode.enable_low_latency():
@@ -986,9 +977,6 @@
     ):
         assert self.quant_method is not None
         assert self.activation == "silu"
-        assert (
-            hidden_states_fp8[0].size(0) % 4 == 0
-        ), f"TMA alignment error: {hidden_states_fp8[0].size(0)}"

         # GroupGemm-0
         num_groups, m, k = hidden_states_fp8[0].size()
@@ -26,6 +26,7 @@ def fused_moe_forward_native(
     apply_router_weight_on_input: bool = False,
     inplace: bool = True,
     no_combine: bool = False,
+    routed_scaling_factor: Optional[float] = None,
 ) -> torch.Tensor:

     if apply_router_weight_on_input:
@@ -41,6 +42,7 @@
         num_expert_group=num_expert_group,
         custom_routing_function=custom_routing_function,
         correction_bias=correction_bias,
+        routed_scaling_factor=routed_scaling_factor,
         torch_native=True,
     )

@@ -71,6 +73,7 @@ def moe_forward_native(
     custom_routing_function: Optional[Callable] = None,
     correction_bias: Optional[torch.Tensor] = None,
     activation: str = "silu",
+    routed_scaling_factor: Optional[float] = None,
 ) -> torch.Tensor:

     from sglang.srt.layers.activation import GeluAndMul, SiluAndMul
@@ -86,6 +89,7 @@
         custom_routing_function=custom_routing_function,
         correction_bias=correction_bias,
         torch_native=True,
+        routed_scaling_factor=routed_scaling_factor,
     )

     # Ref code from https://huggingface.co/deepseek-ai/DeepSeek-V2/blob/e0828e3cc0a03408724b80c3cc92c8e072db8d01/modeling_deepseek.py#L589
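
The hunks above thread routed_scaling_factor from the MoE layers down into expert selection so router weights can be rescaled centrally instead of in each model file. A minimal sketch of how such a factor enters the routing math (illustrative only, not sglang's select_experts):

import torch

def route_tokens(router_logits, top_k, routed_scaling_factor=None, renormalize=True):
    scores = torch.softmax(router_logits, dim=-1)
    topk_weights, topk_ids = torch.topk(scores, k=top_k, dim=-1)
    if renormalize:
        topk_weights = topk_weights / topk_weights.sum(dim=-1, keepdim=True)
    if routed_scaling_factor is not None:
        # Scale the per-token expert weights, e.g. DeepSeek-style routing.
        topk_weights = topk_weights * routed_scaling_factor
    return topk_weights, topk_ids

logits = torch.randn(4, 64)  # (tokens, num_experts)
w, ids = route_tokens(logits, top_k=6, routed_scaling_factor=2.5)
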
@@ -13,6 +13,7 @@ import triton
 import triton.language as tl

 from sglang.srt.layers.moe.topk import select_experts
+from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant
 from sglang.srt.utils import (
     direct_register_custom_op,
     get_bool_env_var,
@@ -22,28 +23,25 @@ from sglang.srt.utils import (
 )

 _is_hip = is_hip()
-
-
-logger = logging.getLogger(__name__)
-padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
-
-enable_moe_align_block_size_triton = bool(
-    int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
-)
-
 _is_cuda = is_cuda()

 if _is_cuda:
     from sgl_kernel import gelu_and_mul, silu_and_mul
-
-    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
 else:
     from vllm import _custom_ops as vllm_ops
+    from vllm._custom_ops import scaled_fp8_quant

 if _is_cuda or _is_hip:
     from sgl_kernel import moe_align_block_size as sgl_moe_align_block_size


+logger = logging.getLogger(__name__)
+padding_size = 128 if bool(int(os.getenv("MOE_PADDING", "0"))) else 0
+enable_moe_align_block_size_triton = bool(
+    int(os.getenv("ENABLE_MOE_ALIGN_BLOCK_SIZE_TRITON", "0"))
+)
+
+
 @triton.jit
 def write_zeros_to_output(
     c_ptr,
@@ -770,14 +768,9 @@ def invoke_fused_moe_kernel(
             # activation tensor-wise fp8 quantization, dynamic or static
             padded_size = padding_size
             # activations apply per-token quantization when weights apply per-channel quantization by default
-            if _is_cuda:
-                A, A_scale = sgl_scaled_fp8_quant(
-                    A, A_scale, use_per_token_if_dynamic=per_channel_quant
-                )
-            else:
-                A, A_scale = vllm_ops.scaled_fp8_quant(
-                    A, A_scale, use_per_token_if_dynamic=per_channel_quant
-                )
+            A, A_scale = scaled_fp8_quant(
+                A, A_scale, use_per_token_if_dynamic=per_channel_quant
+            )
         else:
             # activation block-wise fp8 quantization
             assert len(block_shape) == 2
@@ -1554,6 +1547,7 @@ def fused_moe(
     a2_scale: Optional[torch.Tensor] = None,
     block_shape: Optional[List[int]] = None,
     no_combine: bool = False,
+    routed_scaling_factor: Optional[float] = None,
 ) -> torch.Tensor:
     """
     This function computes a Mixture of Experts (MoE) layer using two sets of
@@ -1608,6 +1602,7 @@
         topk_group=topk_group,
         num_expert_group=num_expert_group,
         custom_routing_function=custom_routing_function,
+        routed_scaling_factor=routed_scaling_factor,
     )

     return fused_experts(
@@ -131,6 +131,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         return self.forward(
             x=x,
@@ -147,6 +148,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             apply_router_weight_on_input=apply_router_weight_on_input,
             inplace=inplace,
             no_combine=no_combine,
+            routed_scaling_factor=routed_scaling_factor,
         )

     def forward_cuda(
@@ -165,6 +167,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
         apply_router_weight_on_input: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ) -> torch.Tensor:
         topk_weights, topk_ids = select_experts(
             hidden_states=x,
@@ -176,6 +179,7 @@ class UnquantizedFusedMoEMethod(FusedMoEMethodBase, CustomOp):
             num_expert_group=num_expert_group,
             custom_routing_function=custom_routing_function,
             correction_bias=correction_bias,
+            routed_scaling_factor=routed_scaling_factor,
         )

         if _is_hip and get_bool_env_var("CK_MOE"):
@@ -284,6 +288,7 @@ class FusedMoE(torch.nn.Module):
         use_presharded_weights: bool = False,
         inplace: bool = True,
         no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
     ):
         super().__init__()

@@ -293,6 +298,7 @@
         self.tp_size = (
             tp_size if tp_size is not None else get_tensor_model_parallel_world_size()
         )
+        self.routed_scaling_factor = routed_scaling_factor
         self.top_k = top_k
         self.num_experts = num_experts
         assert intermediate_size % self.tp_size == 0
@@ -637,6 +643,7 @@
             correction_bias=self.correction_bias,
             activation=self.activation,
             apply_router_weight_on_input=self.apply_router_weight_on_input,
+            routed_scaling_factor=self.routed_scaling_factor,
         )

         if self.reduce_results and self.tp_size > 1: