sglang 0.4.5.post1__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their public registry, and is provided for informational purposes only.
Files changed (95)
  1. sglang/__init__.py +2 -4
  2. sglang/bench_one_batch.py +2 -2
  3. sglang/bench_serving.py +0 -4
  4. sglang/lang/backend/anthropic.py +0 -4
  5. sglang/lang/backend/base_backend.py +1 -1
  6. sglang/lang/backend/openai.py +1 -1
  7. sglang/lang/backend/vertexai.py +0 -1
  8. sglang/lang/compiler.py +1 -7
  9. sglang/lang/tracer.py +3 -7
  10. sglang/srt/_custom_ops.py +0 -2
  11. sglang/srt/constrained/outlines_jump_forward.py +14 -1
  12. sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
  13. sglang/srt/constrained/xgrammar_backend.py +26 -4
  14. sglang/srt/custom_op.py +0 -62
  15. sglang/srt/disaggregation/decode.py +62 -6
  16. sglang/srt/disaggregation/mini_lb.py +5 -1
  17. sglang/srt/disaggregation/mooncake/conn.py +32 -62
  18. sglang/srt/disaggregation/mooncake/transfer_engine.py +30 -61
  19. sglang/srt/disaggregation/prefill.py +40 -4
  20. sglang/srt/disaggregation/utils.py +15 -0
  21. sglang/srt/entrypoints/verl_engine.py +7 -5
  22. sglang/srt/layers/activation.py +6 -8
  23. sglang/srt/layers/attention/flashattention_backend.py +114 -71
  24. sglang/srt/layers/attention/flashinfer_backend.py +5 -2
  25. sglang/srt/layers/attention/torch_native_backend.py +6 -1
  26. sglang/srt/layers/attention/triton_backend.py +6 -0
  27. sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
  28. sglang/srt/layers/layernorm.py +1 -1
  29. sglang/srt/layers/linear.py +17 -3
  30. sglang/srt/layers/moe/ep_moe/layer.py +15 -29
  31. sglang/srt/layers/moe/fused_moe_native.py +4 -0
  32. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +14 -19
  33. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  34. sglang/srt/layers/moe/topk.py +27 -30
  35. sglang/srt/layers/parameter.py +0 -2
  36. sglang/srt/layers/quantization/__init__.py +1 -0
  37. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  38. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +8 -2
  39. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +16 -44
  40. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
  41. sglang/srt/layers/quantization/fp8.py +115 -132
  42. sglang/srt/layers/quantization/fp8_kernel.py +213 -57
  43. sglang/srt/layers/quantization/fp8_utils.py +187 -262
  44. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  45. sglang/srt/layers/quantization/utils.py +5 -11
  46. sglang/srt/layers/quantization/w8a8_fp8.py +2 -0
  47. sglang/srt/layers/quantization/w8a8_int8.py +7 -7
  48. sglang/srt/layers/radix_attention.py +15 -0
  49. sglang/srt/layers/rotary_embedding.py +3 -2
  50. sglang/srt/layers/sampler.py +5 -10
  51. sglang/srt/lora/backend/base_backend.py +18 -2
  52. sglang/srt/lora/backend/flashinfer_backend.py +1 -1
  53. sglang/srt/lora/backend/triton_backend.py +1 -1
  54. sglang/srt/lora/layers.py +1 -1
  55. sglang/srt/lora/lora.py +1 -1
  56. sglang/srt/lora/lora_manager.py +1 -1
  57. sglang/srt/managers/detokenizer_manager.py +0 -1
  58. sglang/srt/managers/io_struct.py +1 -0
  59. sglang/srt/managers/mm_utils.py +4 -3
  60. sglang/srt/managers/multimodal_processor.py +0 -2
  61. sglang/srt/managers/multimodal_processors/base_processor.py +3 -2
  62. sglang/srt/managers/schedule_batch.py +2 -4
  63. sglang/srt/managers/scheduler.py +12 -71
  64. sglang/srt/managers/tokenizer_manager.py +1 -0
  65. sglang/srt/mem_cache/hiradix_cache.py +5 -1
  66. sglang/srt/mem_cache/memory_pool.py +7 -2
  67. sglang/srt/model_executor/cuda_graph_runner.py +2 -2
  68. sglang/srt/model_executor/model_runner.py +20 -27
  69. sglang/srt/models/bert.py +398 -0
  70. sglang/srt/models/deepseek.py +1 -1
  71. sglang/srt/models/deepseek_nextn.py +74 -70
  72. sglang/srt/models/deepseek_v2.py +289 -348
  73. sglang/srt/models/llama.py +5 -5
  74. sglang/srt/models/minicpm3.py +29 -201
  75. sglang/srt/models/qwen2.py +4 -1
  76. sglang/srt/models/qwen2_moe.py +14 -13
  77. sglang/srt/models/qwen3.py +335 -0
  78. sglang/srt/models/qwen3_moe.py +423 -0
  79. sglang/srt/reasoning_parser.py +0 -1
  80. sglang/srt/sampling/sampling_batch_info.py +2 -3
  81. sglang/srt/server_args.py +34 -32
  82. sglang/srt/speculative/eagle_worker.py +4 -7
  83. sglang/srt/utils.py +16 -1
  84. sglang/test/runners.py +5 -1
  85. sglang/test/test_block_fp8.py +167 -0
  86. sglang/test/test_custom_ops.py +1 -1
  87. sglang/version.py +1 -1
  88. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +3 -3
  89. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +92 -91
  90. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
  91. sglang/lang/__init__.py +0 -0
  92. sglang/srt/lora/backend/__init__.py +0 -25
  93. sglang/srt/server.py +0 -18
  94. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
  95. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/fp8.py
@@ -8,15 +8,6 @@ import torch.nn.functional as F
 from torch.nn import Module
 from torch.nn.parameter import Parameter
 
-from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
-from sglang.srt.layers.quantization.utils import (
-    all_close_1d,
-    convert_to_channelwise,
-    is_layer_skipped,
-    per_tensor_dequantize,
-    requantize_with_max_scale,
-)
-
 try:
     from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
         apply_fp8_marlin_linear,
@@ -27,11 +18,12 @@ try:
 except ImportError:
     MARLIN_FP8_AVAILABLE = False
 
-    def apply_fp8_marlin_linear(*args, **kwargs):
-        raise ImportError("vllm is not installed")
+    def dummy_func(*args, **kwargs):
+        raise ImportError(
+            "marlin FP8 requires some operators from vllm. Please install vllm."
+        )
 
-    def prepare_fp8_layer_for_marlin(*args, **kwargs):
-        raise ImportError("vllm is not installed")
+    apply_fp8_marlin_linear = prepare_fp8_layer_for_marlin = dummy_func
 
 
 from sglang.srt.distributed import get_tensor_model_parallel_world_size
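Aside: the hunk above collapses the two per-function stubs into one shared fallback. A minimal standalone sketch of that import-fallback pattern, using hypothetical names rather than anything from the package:

# Illustrative only: when an optional dependency is missing, alias every
# entry point to one stub that raises ImportError only if it is called.
try:
    from optional_dep import fast_op_a, fast_op_b  # hypothetical dependency
except ImportError:

    def _missing_dep(*args, **kwargs):
        raise ImportError("this code path requires optional_dep; please install it")

    fast_op_a = fast_op_b = _missing_dep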
@@ -49,7 +41,10 @@ from sglang.srt.layers.quantization.base_config import (
     QuantizationConfig,
     QuantizeMethodBase,
 )
-from sglang.srt.layers.quantization.fp8_kernel import per_token_group_quant_fp8
+from sglang.srt.layers.quantization.fp8_kernel import (
+    per_token_group_quant_fp8,
+    scaled_fp8_quant,
+)
 from sglang.srt.layers.quantization.fp8_utils import (
     apply_fp8_linear,
     apply_w8a8_block_fp8_linear,
@@ -57,30 +52,35 @@ from sglang.srt.layers.quantization.fp8_utils import (
     input_to_float8,
     normalize_e4m3fn_to_e4m3fnuz,
 )
+from sglang.srt.layers.quantization.kv_cache import BaseKVCacheMethod
+from sglang.srt.layers.quantization.utils import (
+    all_close_1d,
+    convert_to_channelwise,
+    is_layer_skipped,
+    per_tensor_dequantize,
+    requantize_with_max_scale,
+)
 from sglang.srt.utils import (
     get_bool_env_var,
     is_cuda,
     is_hip,
-    permute_weight,
     print_warning_once,
     set_weight_attrs,
 )
 
-ACTIVATION_SCHEMES = ["static", "dynamic"]
-
 _is_hip = is_hip()
+_is_cuda = is_cuda()
 
 if _is_hip:
     from aiter import ActivationType
     from aiter.fused_moe_bf16_asm import asm_moe, ck_moe_2stages, ck_moe_2stages_win4
     from aiter.ops.shuffle import shuffle_weight
 
-_is_cuda = is_cuda()
+if not _is_cuda:
+    from vllm._custom_ops import scaled_fp8_quant
 
-if _is_cuda:
-    from sglang.srt.custom_op import scaled_fp8_quant as sgl_scaled_fp8_quant
-else:
-    from vllm import _custom_ops as vllm_ops
+
+ACTIVATION_SCHEMES = ["static", "dynamic"]
 
 logger = logging.getLogger(__name__)
 
@@ -243,7 +243,6 @@ class Fp8LinearMethod(LinearMethodBase):
                )
 
        layer.logical_widths = output_partition_sizes
-
        layer.input_size_per_partition = input_size_per_partition
        layer.output_size_per_partition = output_size_per_partition
        layer.orig_dtype = params_dtype
@@ -327,7 +326,9 @@ class Fp8LinearMethod(LinearMethodBase):
                layer.weight_scale_inv.data, requires_grad=False
            )
            return
+
        layer.weight = torch.nn.Parameter(layer.weight.data, requires_grad=False)
+
        # If checkpoint not serialized fp8, quantize the weights.
        if not self.quant_config.is_checkpoint_fp8_serialized:
            if self.cutlass_fp8_supported or self.use_marlin:
@@ -391,12 +392,9 @@ class Fp8LinearMethod(LinearMethodBase):
            )
 
        if self.use_marlin:
-            try:
-                prepare_fp8_layer_for_marlin(layer)
-                # Activations not quantized for marlin.
-                del layer.input_scale
-            except ImportError:
-                self.use_marlin = False
+            prepare_fp8_layer_for_marlin(layer)
+            # Activations not quantized for marlin.
+            del layer.input_scale
 
    def apply(
        self,
@@ -406,18 +404,15 @@ class Fp8LinearMethod(LinearMethodBase):
    ) -> torch.Tensor:
 
        if self.use_marlin:
-            try:
-                return apply_fp8_marlin_linear(
-                    input=x,
-                    weight=layer.weight,
-                    weight_scale=layer.weight_scale,
-                    workspace=layer.workspace,
-                    size_n=layer.output_size_per_partition,
-                    size_k=layer.input_size_per_partition,
-                    bias=bias,
-                )
-            except ImportError:
-                self.use_marlin = False
+            return apply_fp8_marlin_linear(
+                input=x,
+                weight=layer.weight,
+                weight_scale=layer.weight_scale,
+                workspace=layer.workspace,
+                size_n=layer.output_size_per_partition,
+                size_k=layer.input_size_per_partition,
+                bias=bias,
+            )
 
        if self.block_quant:
            return apply_w8a8_block_fp8_linear(
@@ -516,7 +511,7 @@ class Fp8MoEMethod:
                    )
 
        # WEIGHTS
-        if get_bool_env_var("USE_INT4_WEIGHT"):
+        if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
            # INT4 MoE weight - INT32 packed
            w13_weight = torch.nn.Parameter(
                torch.empty(
@@ -617,7 +612,7 @@ class Fp8MoEMethod:
        set_weight_attrs(w13_weight_scale, extra_weight_attrs)
        set_weight_attrs(w2_weight_scale, extra_weight_attrs)
 
-        if get_bool_env_var("USE_INT4_WEIGHT"):
+        if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
            extra_weight_attrs.update(
                {"quant_method": FusedMoeWeightScaleSupported.CHANNEL.value}
            )
@@ -649,7 +644,7 @@ class Fp8MoEMethod:
            layer.w2_input_scale = None
 
    def process_weights_after_loading(self, layer: Module) -> None:
-        if get_bool_env_var("USE_INT4_WEIGHT"):
+        if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
            self.process_weights_hip_int4(layer)
            return
 
@@ -706,20 +701,12 @@ class Fp8MoEMethod:
                requires_grad=False,
            )
            for expert in range(layer.num_experts):
-                if _is_cuda:
-                    w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
-                        sgl_scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
-                    )
-                    w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
-                        sgl_scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
-                    )
-                else:
-                    w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
-                        vllm_ops.scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
-                    )
-                    w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
-                        vllm_ops.scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
-                    )
+                w13_weight[expert, :, :], layer.w13_weight_scale[expert] = (
+                    scaled_fp8_quant(layer.w13_weight.data[expert, :, :])
+                )
+                w2_weight[expert, :, :], layer.w2_weight_scale[expert] = (
+                    scaled_fp8_quant(layer.w2_weight.data[expert, :, :])
+                )
            layer.w13_weight = torch.nn.Parameter(w13_weight, requires_grad=False)
            layer.w2_weight = torch.nn.Parameter(w2_weight, requires_grad=False)
 
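The per-expert loop above relies on `scaled_fp8_quant` returning a quantized tensor together with its scale. A rough pure-PyTorch reference of dynamic per-tensor FP8 quantization, for illustration only (the real kernels in `sglang.srt.layers.quantization.fp8_kernel` and `vllm._custom_ops` differ in implementation detail):

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

def scaled_fp8_quant_ref(w: torch.Tensor):
    """Reference-only dynamic per-tensor FP8 quantization: returns (q, scale)."""
    scale = w.abs().amax().float() / FP8_MAX  # map the largest value to the FP8 max
    scale = scale.clamp(min=torch.finfo(torch.float32).tiny)  # guard all-zero tensors
    q = (w.float() / scale).clamp(-FP8_MAX, FP8_MAX)
    return q.to(torch.float8_e4m3fn), scale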
@@ -796,18 +783,10 @@ class Fp8MoEMethod:
                        layer.w13_weight[expert_id][start : start + shard_size, :],
                        layer.w13_weight_scale[expert_id][shard_id],
                    )
-                    if _is_cuda:
-                        (
-                            layer.w13_weight[expert_id][start : start + shard_size, :],
-                            _,
-                        ) = sgl_scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
-                    else:
-                        (
-                            layer.w13_weight[expert_id][start : start + shard_size, :],
-                            _,
-                        ) = vllm_ops.scaled_fp8_quant(
-                            dq_weight, max_w13_scales[expert_id]
-                        )
+                    (
+                        layer.w13_weight[expert_id][start : start + shard_size, :],
+                        _,
+                    ) = scaled_fp8_quant(dq_weight, max_w13_scales[expert_id])
                    start += shard_size
 
            layer.w13_weight_scale = torch.nn.Parameter(
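The loop above dequantizes each w13 shard with its own scale and requantizes it under the shared per-expert maximum scale (the job of `per_tensor_dequantize` followed by `scaled_fp8_quant` with an explicit scale). A self-contained sketch of that step, again for illustration only:

import torch

FP8_MAX = torch.finfo(torch.float8_e4m3fn).max  # 448.0 for e4m3fn

def requantize_with_max_scale_ref(shard_q, shard_scales):
    """Reference-only: re-express per-shard FP8 tensors under one shared scale.

    shard_q: list of FP8 tensors; shard_scales: 1-D float tensor of their scales.
    """
    max_scale = shard_scales.max()
    out = []
    for q, s in zip(shard_q, shard_scales):
        dq = q.to(torch.float32) * s                    # per-tensor dequantize
        rq = (dq / max_scale).clamp(-FP8_MAX, FP8_MAX)  # requantize with max scale
        out.append(rq.to(torch.float8_e4m3fn))
    return out, max_scale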
@@ -913,6 +892,7 @@ class Fp8MoEMethod:
        apply_router_weight_on_input: bool = False,
        inplace: bool = True,
        no_combine: bool = False,
+        routed_scaling_factor: Optional[float] = None,
    ) -> torch.Tensor:
        from sglang.srt.layers.moe.fused_moe_triton.fused_moe import fused_experts
        from sglang.srt.layers.moe.topk import select_experts
@@ -928,43 +908,14 @@ class Fp8MoEMethod:
            num_expert_group=num_expert_group,
            custom_routing_function=custom_routing_function,
            correction_bias=correction_bias,
+            routed_scaling_factor=routed_scaling_factor,
        )
 
-        if _is_hip and get_bool_env_var("USE_INT4_WEIGHT"):
-            # TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
-            assert not no_combine, f"{no_combine=} is not supported."
-            return ck_moe_2stages_win4(
-                x,
-                layer.w13_weight,
-                layer.w2_weight,
-                topk_weights,
-                topk_ids,
-                layer.w13_weight_scale1,
-                layer.w2_weight_scale1,
-                activation=(
-                    ActivationType.Silu if activation == "silu" else ActivationType.Gelu
-                ),
-            )
-        if _is_hip and get_bool_env_var("CK_MOE"):
-            assert not no_combine, f"{no_combine=} is not supported."
-            if self.block_quant:
-                # TODO(CK_MOE): FP8 block_quant only supports 'silu' for the time-being.
-                assert (
-                    activation == "silu"
-                ), f"CK_MOE: FP8 bloack_quant {activation=} will be supported later, unset CK_MOE"
-                return asm_moe(
-                    x,
-                    layer.w13_weight,
-                    layer.w2_weight,
-                    topk_weights,
-                    topk_ids,
-                    layer.w13_weight_scale_inv,
-                    layer.w2_weight_scale_inv,
-                    block_shape=tuple(self.quant_config.weight_block_size),
-                    expert_mask=None,
-                )
-            else:
-                return ck_moe_2stages(
+        if _is_hip:
+            if get_bool_env_var("USE_INT4_WEIGHT"):
+                # TODO: add triton kernel and add check get_bool_env_var("CK_MOE")
+                assert not no_combine, f"{no_combine=} is not supported."
+                return ck_moe_2stages_win4(
                    x,
                    layer.w13_weight,
                    layer.w2_weight,
@@ -978,33 +929,65 @@ class Fp8MoEMethod:
                        else ActivationType.Gelu
                    ),
                )
-        else:
-            # Expert fusion with FP8 quantization
-            return fused_experts(
-                x,
-                layer.w13_weight,
-                layer.w2_weight,
-                topk_weights=topk_weights,
-                topk_ids=topk_ids,
-                inplace=inplace and not no_combine,
-                activation=activation,
-                apply_router_weight_on_input=apply_router_weight_on_input,
-                use_fp8_w8a8=True,
-                w1_scale=(
-                    layer.w13_weight_scale_inv
-                    if self.block_quant
-                    else layer.w13_weight_scale
-                ),
-                w2_scale=(
-                    layer.w2_weight_scale_inv
-                    if self.block_quant
-                    else layer.w2_weight_scale
-                ),
-                a1_scale=layer.w13_input_scale,
-                a2_scale=layer.w2_input_scale,
-                block_shape=self.quant_config.weight_block_size,
-                no_combine=no_combine,
-            )
+
+            if get_bool_env_var("CK_MOE"):
+                assert not no_combine, f"{no_combine=} is not supported."
+                if self.block_quant:
+                    # TODO(CK_MOE): FP8 block_quant only supports 'silu' for the time-being.
+                    assert (
+                        activation == "silu"
+                    ), f"CK_MOE: FP8 bloack_quant {activation=} will be supported later, unset CK_MOE"
+                    return asm_moe(
+                        x,
+                        layer.w13_weight,
+                        layer.w2_weight,
+                        topk_weights,
+                        topk_ids,
+                        layer.w13_weight_scale_inv,
+                        layer.w2_weight_scale_inv,
+                        block_shape=tuple(self.quant_config.weight_block_size),
+                        expert_mask=None,
+                    )
+                else:
+                    return ck_moe_2stages(
+                        x,
+                        layer.w13_weight,
+                        layer.w2_weight,
+                        topk_weights,
+                        topk_ids,
+                        layer.w13_weight_scale1,
+                        layer.w2_weight_scale1,
+                        activation=(
+                            ActivationType.Silu
+                            if activation == "silu"
+                            else ActivationType.Gelu
+                        ),
+                    )
+
+        # Expert fusion with FP8 quantization
+        return fused_experts(
+            x,
+            layer.w13_weight,
+            layer.w2_weight,
+            topk_weights=topk_weights,
+            topk_ids=topk_ids,
+            inplace=inplace and not no_combine,
+            activation=activation,
+            apply_router_weight_on_input=apply_router_weight_on_input,
+            use_fp8_w8a8=True,
+            w1_scale=(
+                layer.w13_weight_scale_inv
+                if self.block_quant
+                else layer.w13_weight_scale
+            ),
+            w2_scale=(
+                layer.w2_weight_scale_inv if self.block_quant else layer.w2_weight_scale
+            ),
+            a1_scale=layer.w13_input_scale,
+            a2_scale=layer.w2_input_scale,
+            block_shape=self.quant_config.weight_block_size,
+            no_combine=no_combine,
+        )
 
 
 class Fp8KVCacheMethod(BaseKVCacheMethod):