sglang 0.4.5.post1__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (95)
  1. sglang/__init__.py +2 -4
  2. sglang/bench_one_batch.py +2 -2
  3. sglang/bench_serving.py +0 -4
  4. sglang/lang/backend/anthropic.py +0 -4
  5. sglang/lang/backend/base_backend.py +1 -1
  6. sglang/lang/backend/openai.py +1 -1
  7. sglang/lang/backend/vertexai.py +0 -1
  8. sglang/lang/compiler.py +1 -7
  9. sglang/lang/tracer.py +3 -7
  10. sglang/srt/_custom_ops.py +0 -2
  11. sglang/srt/constrained/outlines_jump_forward.py +14 -1
  12. sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
  13. sglang/srt/constrained/xgrammar_backend.py +26 -4
  14. sglang/srt/custom_op.py +0 -62
  15. sglang/srt/disaggregation/decode.py +62 -6
  16. sglang/srt/disaggregation/mini_lb.py +5 -1
  17. sglang/srt/disaggregation/mooncake/conn.py +32 -62
  18. sglang/srt/disaggregation/mooncake/transfer_engine.py +30 -61
  19. sglang/srt/disaggregation/prefill.py +40 -4
  20. sglang/srt/disaggregation/utils.py +15 -0
  21. sglang/srt/entrypoints/verl_engine.py +7 -5
  22. sglang/srt/layers/activation.py +6 -8
  23. sglang/srt/layers/attention/flashattention_backend.py +114 -71
  24. sglang/srt/layers/attention/flashinfer_backend.py +5 -2
  25. sglang/srt/layers/attention/torch_native_backend.py +6 -1
  26. sglang/srt/layers/attention/triton_backend.py +6 -0
  27. sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
  28. sglang/srt/layers/layernorm.py +1 -1
  29. sglang/srt/layers/linear.py +17 -3
  30. sglang/srt/layers/moe/ep_moe/layer.py +15 -29
  31. sglang/srt/layers/moe/fused_moe_native.py +4 -0
  32. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +14 -19
  33. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  34. sglang/srt/layers/moe/topk.py +27 -30
  35. sglang/srt/layers/parameter.py +0 -2
  36. sglang/srt/layers/quantization/__init__.py +1 -0
  37. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  38. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +8 -2
  39. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +16 -44
  40. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
  41. sglang/srt/layers/quantization/fp8.py +115 -132
  42. sglang/srt/layers/quantization/fp8_kernel.py +213 -57
  43. sglang/srt/layers/quantization/fp8_utils.py +187 -262
  44. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  45. sglang/srt/layers/quantization/utils.py +5 -11
  46. sglang/srt/layers/quantization/w8a8_fp8.py +2 -0
  47. sglang/srt/layers/quantization/w8a8_int8.py +7 -7
  48. sglang/srt/layers/radix_attention.py +15 -0
  49. sglang/srt/layers/rotary_embedding.py +3 -2
  50. sglang/srt/layers/sampler.py +5 -10
  51. sglang/srt/lora/backend/base_backend.py +18 -2
  52. sglang/srt/lora/backend/flashinfer_backend.py +1 -1
  53. sglang/srt/lora/backend/triton_backend.py +1 -1
  54. sglang/srt/lora/layers.py +1 -1
  55. sglang/srt/lora/lora.py +1 -1
  56. sglang/srt/lora/lora_manager.py +1 -1
  57. sglang/srt/managers/detokenizer_manager.py +0 -1
  58. sglang/srt/managers/io_struct.py +1 -0
  59. sglang/srt/managers/mm_utils.py +4 -3
  60. sglang/srt/managers/multimodal_processor.py +0 -2
  61. sglang/srt/managers/multimodal_processors/base_processor.py +3 -2
  62. sglang/srt/managers/schedule_batch.py +2 -4
  63. sglang/srt/managers/scheduler.py +12 -71
  64. sglang/srt/managers/tokenizer_manager.py +1 -0
  65. sglang/srt/mem_cache/hiradix_cache.py +5 -1
  66. sglang/srt/mem_cache/memory_pool.py +7 -2
  67. sglang/srt/model_executor/cuda_graph_runner.py +2 -2
  68. sglang/srt/model_executor/model_runner.py +20 -27
  69. sglang/srt/models/bert.py +398 -0
  70. sglang/srt/models/deepseek.py +1 -1
  71. sglang/srt/models/deepseek_nextn.py +74 -70
  72. sglang/srt/models/deepseek_v2.py +289 -348
  73. sglang/srt/models/llama.py +5 -5
  74. sglang/srt/models/minicpm3.py +29 -201
  75. sglang/srt/models/qwen2.py +4 -1
  76. sglang/srt/models/qwen2_moe.py +14 -13
  77. sglang/srt/models/qwen3.py +335 -0
  78. sglang/srt/models/qwen3_moe.py +423 -0
  79. sglang/srt/reasoning_parser.py +0 -1
  80. sglang/srt/sampling/sampling_batch_info.py +2 -3
  81. sglang/srt/server_args.py +34 -32
  82. sglang/srt/speculative/eagle_worker.py +4 -7
  83. sglang/srt/utils.py +16 -1
  84. sglang/test/runners.py +5 -1
  85. sglang/test/test_block_fp8.py +167 -0
  86. sglang/test/test_custom_ops.py +1 -1
  87. sglang/version.py +1 -1
  88. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +3 -3
  89. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +92 -91
  90. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
  91. sglang/lang/__init__.py +0 -0
  92. sglang/srt/lora/backend/__init__.py +0 -25
  93. sglang/srt/server.py +0 -18
  94. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
  95. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0
@@ -34,22 +34,31 @@ from sglang.srt.utils import (
     supports_custom_op,
 )
 
-_enable_jit_deepgemm = False
-
 _is_hip = is_hip()
-fp8_type_ = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
-
 _is_cuda = is_cuda()
+_fp8_type = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
+if _is_hip:
+    fp8_max = 224.0
+else:
+    fp8_max = torch.finfo(_fp8_type).max
+fp8_min = -fp8_max
+
+_enable_jit_deepgemm = False
+_enable_jit_deepgemm_bmm = False
 if _is_cuda:
     import deep_gemm
-    from sgl_kernel import sgl_per_token_group_quant_fp8, sgl_per_token_quant_fp8
+    from sgl_kernel import (
+        sgl_per_tensor_quant_fp8,
+        sgl_per_token_group_quant_fp8,
+        sgl_per_token_quant_fp8,
+    )
 
     sm_version = get_device_sm()
-    if sm_version == 90 and get_bool_env_var(
-        "SGL_ENABLE_JIT_DEEPGEMM", default="false"
-    ):
-        _enable_jit_deepgemm = True
-
+    if sm_version == 90:
+        if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM", default="false"):
+            _enable_jit_deepgemm = True
+        if get_bool_env_var("SGL_ENABLE_JIT_DEEPGEMM_BMM", default="false"):
+            _enable_jit_deepgemm_bmm = True
 
 logger = logging.getLogger(__name__)
 
@@ -179,7 +188,6 @@ def per_token_group_quant_fp8(
     x: torch.Tensor,
     group_size: int,
     eps: float = 1e-10,
-    dtype: torch.dtype = fp8_type_,
     column_major_scales: bool = False,
     scale_tma_aligned: bool = False,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
@@ -192,7 +200,6 @@ def per_token_group_quant_fp8(
         x: The input tenosr with ndim >= 2.
         group_size: The group size used for quantization.
        eps: The minimum to avoid dividing zero.
-        dtype: The dype of output tensor.
 
     Returns:
         Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
@@ -202,15 +209,7 @@
     ), "the last dimension of `x` cannot be divisible by `group_size`"
     assert x.is_contiguous(), "`x` is not contiguous"
 
-    finfo = torch.finfo(dtype)
-    fp8_max = finfo.max
-
-    if _is_hip:
-        fp8_max = 224.0
-
-    fp8_min = -fp8_max
-
-    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+    x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
     M = x.numel() // group_size
     N = group_size
     if column_major_scales:
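
As a quick orientation (not part of the diff): after this change, per_token_group_quant_fp8 no longer takes a dtype argument; the output dtype is fixed by the module-level _fp8_type set in the first hunk. A hedged usage sketch follows. It assumes the file shown here is sglang/srt/layers/quantization/fp8_kernel.py (consistent with the changed-files list above) and that a CUDA or ROCm device is available, since the function launches a Triton kernel; the tensor sizes are made up for illustration.

import torch

from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8

# Hypothetical input: 4 tokens with hidden size 1024, quantized in groups of 128.
x = torch.randn(4, 1024, device="cuda", dtype=torch.bfloat16)
x_q, x_s = per_token_group_quant_fp8(x, group_size=128)
# x_q keeps x's shape with dtype float8_e4m3fn (float8_e4m3fnuz on ROCm);
# x_s is float32 with shape (4, 1024 // 128) in the default row-major layout.
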
@@ -276,26 +275,36 @@ def sglang_per_token_group_quant_fp8(
     x: torch.Tensor,
     group_size: int,
     eps: float = 1e-10,
-    dtype: torch.dtype = fp8_type_,
+    column_major_scales: bool = False,
+    scale_tma_aligned: bool = False,
 ):
     assert (
         x.shape[-1] % group_size == 0
     ), "the last dimension of `x` cannot be divisible by `group_size`"
     assert x.is_contiguous(), "`x` is not contiguous"
 
-    finfo = torch.finfo(dtype)
-    fp8_max = finfo.max
-
-    fp8_min = -fp8_max
-
-    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
-    M = x.numel() // group_size
-    N = group_size
-    x_s = torch.empty(
-        x.shape[:-1] + (x.shape[-1] // group_size,),
-        device=x.device,
-        dtype=torch.float32,
-    )
+    x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
+    if column_major_scales:
+        if scale_tma_aligned:
+            # aligned to 4 * sizeof(float)
+            aligned_size = (x.shape[-2] + 3) // 4 * 4
+            x_s = torch.empty(
+                x.shape[:-2] + (x.shape[-1] // group_size, aligned_size),
+                device=x.device,
+                dtype=torch.float32,
+            ).permute(-1, -2)[: x.shape[-2], :]
+        else:
+            x_s = torch.empty(
+                (x.shape[-1] // group_size,) + x.shape[:-1],
+                device=x.device,
+                dtype=torch.float32,
+            ).permute(-1, -2)
+    else:
+        x_s = torch.empty(
+            x.shape[:-1] + (x.shape[-1] // group_size,),
+            device=x.device,
+            dtype=torch.float32,
+        )
 
     sgl_per_token_group_quant_fp8(x, x_q, x_s, group_size, eps, fp8_min, fp8_max)
 
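
To make the three scale layouts in the hunk above easier to follow, here is a shape-only sketch (not part of the diff) that mirrors the x_s allocation logic. It only allocates buffers, so it runs on CPU and never calls the sgl_kernel op; the sizes are hypothetical.

import torch

x = torch.randn(8, 512)  # (tokens, hidden)
group_size = 128
num_groups = x.shape[-1] // group_size

# Default: row-major scales of shape (tokens, num_groups).
s_row = torch.empty(x.shape[:-1] + (num_groups,), dtype=torch.float32)

# column_major_scales=True: allocated as (num_groups, tokens), exposed transposed.
s_col = torch.empty((num_groups,) + x.shape[:-1], dtype=torch.float32).permute(-1, -2)

# column_major_scales=True, scale_tma_aligned=True: the token dimension is padded up to
# a multiple of 4 floats (16 bytes), then sliced back to the real token count.
aligned = (x.shape[-2] + 3) // 4 * 4
s_tma = torch.empty(
    x.shape[:-2] + (num_groups, aligned), dtype=torch.float32
).permute(-1, -2)[: x.shape[-2], :]

print(s_row.shape, s_col.shape, s_tma.shape)  # all (8, 4); the last two are column-major in memory
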
@@ -304,7 +313,7 @@ def sglang_per_token_group_quant_fp8(
 
 def sglang_per_token_quant_fp8(
     x: torch.Tensor,
-    dtype: torch.dtype = fp8_type_,
+    dtype: torch.dtype = _fp8_type,
 ):
     assert x.is_contiguous(), "`x` is not contiguous"
 
@@ -368,7 +377,6 @@ def static_quant_fp8(
     x: torch.Tensor,
     x_s: torch.Tensor,
     repeat_scale: bool = False,
-    dtype: torch.dtype = fp8_type_,
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """Function to perform static quantization using the given scale on an input tensor `x`.
 
@@ -386,15 +394,8 @@
     """
     assert x.is_contiguous(), "`x` is not contiguous"
     assert x_s.numel() == 1, "only supports per-tensor scale"
-    finfo = torch.finfo(dtype)
-    fp8_max = finfo.max
-
-    if _is_hip:
-        fp8_max = 224.0
 
-    fp8_min = -fp8_max
-
-    x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+    x_q = torch.empty_like(x, device=x.device, dtype=_fp8_type)
     M = x.numel() // x.shape[-1]
     N = x.shape[-1]
     if repeat_scale:
@@ -896,22 +897,20 @@ def _per_tensor_quant_mla_fp8_stage2(
 
 
 def per_tensor_quant_mla_fp8(
-    x: torch.Tensor, eps: float = 1e-12, dtype: torch.dtype = torch.float8_e4m3fn
+    x: torch.Tensor, x_s_out: torch.Tensor, eps: float = 1e-12
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     """
     This function quantizes input values to float8 values with tensor-wise quantization
     and specialized for mla absorbed case.
     """
     assert x.dim() == 3, "`x` is not a 3d-tensor"
+    assert (
+        x_s_out.shape == (1,)
+        and x_s_out.dtype == torch.float32
+        and x_s_out.device == x.device
+    )
 
-    finfo = torch.finfo(dtype)
-    fp8_max = finfo.max
-    if _is_hip:
-        dtype = torch.float8_e4m3fnuz
-        fp8_max = 224.0
-
-    x_q = x.new_empty(x.size(), dtype=dtype)
-    x_s = torch.zeros((1,), dtype=torch.float32, device=x.device)
+    x_q = x.new_empty(x.size(), dtype=_fp8_type)
 
     num_head, num_seq, head_size = x.shape
     BLOCK_SIZE = triton.next_power_of_2(head_size)
@@ -919,7 +918,7 @@ def per_tensor_quant_mla_fp8(
 
     _per_tensor_quant_mla_fp8_stage1[grid](
         x,
-        x_s,
+        x_s_out,
         head_size,
         x.stride(0),
         x.stride(1),
@@ -929,15 +928,172 @@ def per_tensor_quant_mla_fp8(
     )
     _per_tensor_quant_mla_fp8_stage2[grid](
         x,
-        x_s,
+        x_s_out,
         x_q,
         num_seq,
         head_size,
         x.stride(0),
         x.stride(1),
+        fp8_min,
+        fp8_max,
+        BLOCK_SIZE,
+    )
+
+    return x_q, x_s_out
+
+
+@triton.jit
+def _per_token_group_quant_mla_deep_gemm_masked_fp8(
+    y_ptr,
+    y_q_ptr,
+    y_s_ptr,
+    masked_m_ptr,
+    group_size,
+    y_stride_b,
+    y_stride_t,
+    y_q_stride_b,
+    y_q_stride_t,
+    y_s_stride_b,
+    y_s_stride_g,
+    eps,
+    fp8_min,
+    fp8_max,
+    NUM_GROUP: tl.constexpr,
+    BLOCK: tl.constexpr,
+):
+    """A Triton-accelerated function to perform per-token-group
+    quantization on a tensor for deep_gemm grouped_gemm_masked.
+    This function converts the tensor values into float8 values.
+    y and y_q: (b, t, k)
+    y_s: (b, k//group_size, t)
+    """
+    t_id = tl.program_id(0)
+    b_id = tl.program_id(1)
+
+    y_ptr += b_id * y_stride_b + t_id * y_stride_t
+    y_q_ptr += b_id * y_q_stride_b + t_id * y_q_stride_t
+    y_s_ptr += b_id * y_s_stride_b + t_id
+
+    if t_id == 0:
+        tl.store(masked_m_ptr + b_id, tl.num_programs(0))
+
+    cols = tl.arange(0, BLOCK)  # group_size <= BLOCK
+    mask = cols < group_size
+
+    for gid in range(NUM_GROUP):
+        y = tl.load(y_ptr + gid * group_size + cols, mask=mask, other=0.0).to(
+            tl.float32
+        )
+        _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
+        y_s = _absmax / fp8_max
+        y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+
+        tl.store(y_q_ptr + gid * group_size + cols, y_q, mask=mask)
+        tl.store(y_s_ptr + gid * y_s_stride_g, y_s)
+
+
+def per_tensor_quant_mla_deep_gemm_masked_fp8(
+    x: torch.Tensor,
+    group_size: int = 128,
+    eps: float = 1e-12,
+    dtype: torch.dtype = torch.float8_e4m3fn,
+) -> Tuple[torch.Tensor, torch.Tensor]:
+    """
+    This function quantizes input values to float8 values with per-token-group-quantization
+    for deep_gemm grouped_gemm_masked and specialized for mla absorbed case.
+    """
+    assert x.dim() == 3, "`x` is not a 3d-tensor"
+
+    finfo = torch.finfo(dtype)
+    fp8_max = finfo.max
+    if _is_hip:
+        dtype = torch.float8_e4m3fnuz
+        fp8_max = 224.0
+
+    b, m, k = x.shape
+    aligned_m = (m + 255) // 256 * 256  # 256 is the max block_m of the gemm kernel
+    num_tiles_k = k // group_size
+    assert num_tiles_k * group_size == k, f"k % {group_size} must be zero"
+
+    x_q = x.new_empty((b, aligned_m, k), dtype=dtype)
+    x_s = x.new_empty((b, num_tiles_k, aligned_m), dtype=torch.float32)
+    masked_m = x.new_empty((b,), dtype=torch.int32)
+
+    BLOCK_SIZE = triton.next_power_of_2(group_size)
+    grid = (m, b)
+
+    _per_token_group_quant_mla_deep_gemm_masked_fp8[grid](
+        x,
+        x_q,
+        x_s,
+        masked_m,
+        group_size,
+        x.stride(0),
+        x.stride(1),
+        x_q.stride(0),
+        x_q.stride(1),
+        x_s.stride(0),
+        x_s.stride(1),
+        eps,
         -fp8_max,
         fp8_max,
+        num_tiles_k,
         BLOCK_SIZE,
     )
 
-    return x_q, x_s
+    return x_q, x_s.transpose(1, 2), masked_m, m, aligned_m
+
+
+def scaled_fp8_quant(
+    input: torch.Tensor,
+    scale: Optional[torch.Tensor] = None,
+    num_token_padding: Optional[int] = None,
+    use_per_token_if_dynamic: bool = False,
+) -> tuple[torch.Tensor, torch.Tensor]:
+    """
+    Quantize input tensor to FP8 (8-bit floating point) format.
+
+    Args:
+        input (torch.Tensor): Input tensor to be quantized
+        scale (Optional[torch.Tensor]): Pre-computed scaling factor for static quantization.
+            If None, scales will be computed dynamically.
+        num_token_padding (Optional[int]): If specified, pad the first dimension
+            of the output to at least this value.
+        use_per_token_if_dynamic (bool): When using dynamic scaling (scale=None),
+            determines the quantization granularity:
+            - True: compute scale per token
+            - False: compute single scale per tensor
+
+    Returns:
+        Tuple[torch.Tensor, torch.Tensor]: A tuple containing:
+            - quantized_tensor: The FP8 quantized version of input
+            - scale_tensor: The scaling factors used for quantization
+
+    Raises:
+        AssertionError: If input is not 2D or if static scale's numel != 1
+    """
+    assert input.ndim == 2, f"Expected 2D input tensor, got {input.ndim}D"
+    shape = input.shape
+    out_dtype = torch.float8_e4m3fnuz if _is_hip else torch.float8_e4m3fn
+    if num_token_padding:
+        shape = (max(num_token_padding, input.shape[0]), shape[1])
+    output = torch.empty(shape, device=input.device, dtype=out_dtype)
+
+    if scale is None:
+        # Dynamic scaling
+        if use_per_token_if_dynamic:
+            scale = torch.empty((shape[0], 1), device=input.device, dtype=torch.float32)
+            sgl_per_token_quant_fp8(input, output, scale)
+        else:
+            scale = torch.zeros(1, device=input.device, dtype=torch.float32)
+            sgl_per_tensor_quant_fp8(
+                input, output, scale, is_static=False
+            )  # False for dynamic
+    else:
+        # Static scaling
+        assert scale.numel() == 1, f"Expected scalar scale, got numel={scale.numel()}"
+        sgl_per_tensor_quant_fp8(
+            input, output, scale, is_static=True
+        )  # True for static
+
+    return output, scale
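
Finally, a hedged usage sketch (not part of the diff) for the new scaled_fp8_quant helper added at the end of this hunk. It assumes a GPU build with the sgl_kernel extension available and that the module shown here is sglang/srt/layers/quantization/fp8_kernel.py, as the changed-files list suggests; the shapes and scale value are made up for illustration.

import torch

from sglang.srt.layers.quantization.fp8_kernel import scaled_fp8_quant

x = torch.randn(16, 4096, device="cuda", dtype=torch.float16)

# Dynamic per-tensor quantization: a single float32 scale is computed on the fly.
x_q, s = scaled_fp8_quant(x)

# Dynamic per-token quantization: one scale per row, returned with shape (16, 1).
x_q_tok, s_tok = scaled_fp8_quant(x, use_per_token_if_dynamic=True)

# Static quantization with a precomputed scalar scale.
static_scale = torch.tensor([0.05], device="cuda", dtype=torch.float32)
x_q_static, _ = scaled_fp8_quant(x, scale=static_scale)

# Pad the first (token) dimension of the output to at least 32 rows.
x_q_pad, _ = scaled_fp8_quant(x, num_token_padding=32)
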