sglang 0.4.5.post2__py3-none-any.whl → 0.4.5.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (62)
  1. sglang/bench_serving.py +3 -2
  2. sglang/compile_deep_gemm.py +136 -0
  3. sglang/lang/backend/openai.py +5 -1
  4. sglang/lang/backend/runtime_endpoint.py +5 -1
  5. sglang/srt/configs/model_config.py +4 -1
  6. sglang/srt/constrained/xgrammar_backend.py +1 -0
  7. sglang/srt/disaggregation/decode.py +43 -0
  8. sglang/srt/disaggregation/mini_lb.py +69 -8
  9. sglang/srt/disaggregation/mooncake/conn.py +1 -1
  10. sglang/srt/disaggregation/nixl/__init__.py +1 -0
  11. sglang/srt/disaggregation/nixl/conn.py +622 -0
  12. sglang/srt/disaggregation/prefill.py +100 -16
  13. sglang/srt/disaggregation/utils.py +17 -0
  14. sglang/srt/entrypoints/engine.py +4 -0
  15. sglang/srt/entrypoints/http_server.py +3 -7
  16. sglang/srt/function_call_parser.py +60 -0
  17. sglang/srt/layers/activation.py +2 -2
  18. sglang/srt/layers/attention/flashattention_backend.py +781 -150
  19. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
  20. sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -5
  21. sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
  22. sglang/srt/layers/dp_attention.py +1 -1
  23. sglang/srt/layers/layernorm.py +19 -4
  24. sglang/srt/layers/moe/ep_moe/layer.py +2 -0
  25. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  26. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  27. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
  28. sglang/srt/layers/quantization/deep_gemm.py +378 -0
  29. sglang/srt/layers/quantization/fp8_kernel.py +7 -38
  30. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  31. sglang/srt/layers/quantization/gptq.py +13 -7
  32. sglang/srt/layers/quantization/modelopt_quant.py +2 -2
  33. sglang/srt/layers/quantization/w8a8_int8.py +3 -3
  34. sglang/srt/layers/rotary_embedding.py +6 -6
  35. sglang/srt/layers/sampler.py +2 -2
  36. sglang/srt/managers/data_parallel_controller.py +7 -1
  37. sglang/srt/managers/io_struct.py +14 -3
  38. sglang/srt/managers/schedule_batch.py +13 -0
  39. sglang/srt/managers/scheduler.py +16 -6
  40. sglang/srt/managers/tokenizer_manager.py +115 -29
  41. sglang/srt/managers/tp_worker.py +1 -0
  42. sglang/srt/mem_cache/hiradix_cache.py +40 -32
  43. sglang/srt/mem_cache/memory_pool.py +31 -13
  44. sglang/srt/model_executor/cuda_graph_runner.py +13 -8
  45. sglang/srt/model_executor/model_runner.py +19 -4
  46. sglang/srt/models/deepseek_v2.py +9 -6
  47. sglang/srt/models/minicpm3.py +2 -2
  48. sglang/srt/models/minicpmo.py +17 -6
  49. sglang/srt/openai_api/adapter.py +71 -4
  50. sglang/srt/openai_api/protocol.py +6 -1
  51. sglang/srt/server_args.py +52 -40
  52. sglang/srt/speculative/build_eagle_tree.py +2 -2
  53. sglang/srt/speculative/eagle_utils.py +2 -2
  54. sglang/srt/speculative/eagle_worker.py +2 -7
  55. sglang/srt/utils.py +46 -5
  56. sglang/test/test_utils.py +3 -1
  57. sglang/version.py +1 -1
  58. {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/METADATA +3 -3
  59. {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/RECORD +62 -57
  60. {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/WHEEL +0 -0
  61. {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/licenses/LICENSE +0 -0
  62. {sglang-0.4.5.post2.dist-info → sglang-0.4.5.post3.dist-info}/top_level.txt +0 -0
@@ -3,10 +3,10 @@ import triton
 import triton.language as tl
 
 from sglang.srt.managers.schedule_batch import global_server_args_dict
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import is_cuda, is_hip
 
-is_cuda_available = torch.cuda.is_available()
-if is_cuda_available:
+_is_cuda = is_cuda()
+if _is_cuda:
     CUDA_CAPABILITY = torch.cuda.get_device_capability()
 
 _is_hip = is_hip()
@@ -1037,12 +1037,12 @@ def extend_attention_fwd(
         num_warps = 4
 
     else:
-        if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
+        if _is_cuda and CUDA_CAPABILITY[0] >= 9:
             if Lq <= 256:
                 BLOCK_M, BLOCK_N = (128, 64)
             else:
                 BLOCK_M, BLOCK_N = (32, 64)
-        elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
+        elif _is_cuda and CUDA_CAPABILITY[0] >= 8:
            if Lq <= 128:
                BLOCK_M, BLOCK_N = (128, 128)
            elif Lq <= 256:
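The two hunks above (and the matching ones in the next file) replace the ad-hoc torch.cuda.is_available() probe with the shared is_cuda() / is_hip() helpers from sglang.srt.utils, so ROCm builds no longer fall into the CUDA-only branches. As a minimal sketch, assuming the usual torch.version markers (this is an illustration, not sglang's actual implementation), such helpers can be written as:

import torch

def is_hip() -> bool:
    # ROCm builds of PyTorch set torch.version.hip; CUDA builds leave it as None.
    return torch.version.hip is not None

def is_cuda() -> bool:
    # True only for CUDA builds with a visible NVIDIA device.
    return torch.version.cuda is not None and torch.cuda.is_available()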
@@ -23,10 +23,10 @@ import triton.language as tl
 from sglang.srt.layers.attention.triton_ops.prefill_attention import (
     context_attention_fwd,
 )
-from sglang.srt.utils import is_hip
+from sglang.srt.utils import is_cuda, is_hip
 
-is_cuda_available = torch.cuda.is_available()
-if is_cuda_available:
+_is_cuda = is_cuda()
+if _is_cuda:
     CUDA_CAPABILITY = torch.cuda.get_device_capability()
 
 _is_hip = is_hip()
@@ -345,12 +345,12 @@ def extend_attention_fwd(
         num_warps = 4
 
     else:
-        if is_cuda_available and CUDA_CAPABILITY[0] >= 9:
+        if _is_cuda and CUDA_CAPABILITY[0] >= 9:
             if Lq <= 256:
                 BLOCK_M, BLOCK_N = (128, 64)
             else:
                 BLOCK_M, BLOCK_N = (32, 64)
-        elif is_cuda_available and CUDA_CAPABILITY[0] >= 8:
+        elif _is_cuda and CUDA_CAPABILITY[0] >= 8:
             # sm86/sm89 has a much smaller shared memory size (100K) than sm80 (160K)
             if CUDA_CAPABILITY[1] == 9 or CUDA_CAPABILITY[1] == 6:
                 if Lq <= 128:
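For reference, the tile selection edited here keys off torch.cuda.get_device_capability() and the head dimension Lq. A self-contained sketch of how the branch above resolves for a couple of common parts; the (9, 0) and (8, 0) tuples are what H100 and A100 report, while the final fallback value is an assumption for illustration only (the real function handles more cases than shown):

def pick_extend_attention_tile(capability, lq):
    # Mirrors the capability-gated branch in the hunk above (illustration only).
    major, _ = capability
    if major >= 9:                      # Hopper, e.g. H100 -> (9, 0)
        return (128, 64) if lq <= 256 else (32, 64)
    if major >= 8 and lq <= 128:        # Ampere, e.g. A100 -> (8, 0)
        return (128, 128)
    return (64, 64)                     # assumed conservative fallback

assert pick_extend_attention_tile((9, 0), 128) == (128, 64)
assert pick_extend_attention_tile((8, 0), 64) == (128, 128)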
@@ -22,8 +22,12 @@ import torch
 import triton
 import triton.language as tl
 
-is_cuda_available = torch.cuda.is_available()
-if is_cuda_available:
+from sglang.srt.utils import is_cuda, is_hip
+
+_is_cuda = is_cuda()
+_is_hip = is_hip()
+
+if _is_cuda or _is_hip:
     CUDA_CAPABILITY = torch.cuda.get_device_capability()
 
 
@@ -172,7 +176,7 @@ def context_attention_fwd(
         b_seq_len: [b]
         out: [b * s, head, head_dim]
     """
-    if is_cuda_available and CUDA_CAPABILITY[0] > 8:
+    if (_is_cuda or _is_hip) and CUDA_CAPABILITY[0] > 8:
         BLOCK = 128
     else:
         BLOCK = 64
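Worth noting for the (_is_cuda or _is_hip) change above: on ROCm, torch.cuda.get_device_capability() reports the GCN/CDNA architecture tuple (MI200/MI300-class parts typically report a major version of 9), so the major > 8 test now selects BLOCK = 128 on those GPUs as well. A quick, hedged check:

import torch
from sglang.srt.utils import is_cuda, is_hip

if is_cuda() or is_hip():
    major, minor = torch.cuda.get_device_capability()
    block = 128 if major > 8 else 64
    print(f"capability=({major}, {minor}) -> BLOCK={block}")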
@@ -143,7 +143,7 @@ def memcpy_triton_kernel(
     src_ptr,
     offset_ptr,
     sz_ptr,
-    offset_src,
+    offset_src: tl.constexpr,
     chunk_size,  # multiplied for offset and sz
     BLOCK_SIZE: tl.constexpr,
 ):
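The only change in this hunk is declaring offset_src as tl.constexpr. In Triton, constexpr arguments are baked in at compile time, so a kernel can branch on such a flag and each specialization contains only one side of the branch. A toy sketch of the pattern (not the sglang kernel):

import torch
import triton
import triton.language as tl

@triton.jit
def _toy_copy_kernel(dst_ptr, src_ptr, n_elements, NEGATE: tl.constexpr, BLOCK_SIZE: tl.constexpr):
    pid = tl.program_id(axis=0)
    offs = pid * BLOCK_SIZE + tl.arange(0, BLOCK_SIZE)
    mask = offs < n_elements
    x = tl.load(src_ptr + offs, mask=mask)
    if NEGATE:  # resolved at compile time because NEGATE is tl.constexpr
        x = -x
    tl.store(dst_ptr + offs, x, mask=mask)

def toy_copy(dst: torch.Tensor, src: torch.Tensor, negate: bool = False) -> None:
    n = src.numel()
    grid = (triton.cdiv(n, 1024),)
    _toy_copy_kernel[grid](dst, src, n, NEGATE=negate, BLOCK_SIZE=1024)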
@@ -20,9 +20,12 @@ import torch
 import torch.nn as nn
 
 from sglang.srt.custom_op import CustomOp
-from sglang.srt.utils import is_cuda_available
+from sglang.srt.utils import is_cuda, is_hip
 
-_is_cuda = is_cuda_available()
+logger = logging.getLogger(__name__)
+
+_is_cuda = is_cuda()
+_is_hip = is_hip()
 
 if _is_cuda:
     from sgl_kernel import (
@@ -32,8 +35,20 @@ if _is_cuda:
         rmsnorm,
     )
 
+if _is_hip:
 
-logger = logging.getLogger(__name__)
+    from aiter.ops.rmsnorm import rms_norm, rmsnorm2d_fwd_with_add
+
+    rmsnorm = rms_norm
+
+    def fused_add_rmsnorm(
+        x: torch.Tensor,
+        residual: torch.Tensor,
+        w: torch.Tensor,
+        eps: float,
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        rmsnorm2d_fwd_with_add(x, x, residual, residual, w, eps)
+        return x, residual
 
 
 class RMSNorm(CustomOp):
@@ -139,7 +154,7 @@ class Gemma3RMSNorm(nn.Module):
         return f"{tuple(self.weight.shape)}, eps={self.eps}"
 
 
-if not _is_cuda:
+if not (_is_cuda or _is_hip):
     logger.info(
         "sgl-kernel is not available on Non-NV platforms. Fallback to other kernel libraries."
     )
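The HIP branch above aliases aiter's rms_norm to the rmsnorm name the CUDA path gets from sgl_kernel, and wraps rmsnorm2d_fwd_with_add in a fused_add_rmsnorm shim with the same signature. As a reference point, here is a plain-PyTorch version of the semantics that shim is assumed to provide (the residual is updated to x + residual and the RMS-normalized result takes x's place); the real kernels operate in place:

from typing import Tuple

import torch

def fused_add_rmsnorm_ref(
    x: torch.Tensor, residual: torch.Tensor, w: torch.Tensor, eps: float
) -> Tuple[torch.Tensor, torch.Tensor]:
    # Reference semantics only, computed out of place in fp32 for clarity.
    residual = residual + x
    variance = residual.float().pow(2).mean(dim=-1, keepdim=True)
    x = (residual.float() * torch.rsqrt(variance + eps)).to(residual.dtype) * w
    return x, residual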
@@ -802,6 +802,7 @@ class DeepEPMoE(EPMoE):
         correction_bias: Optional[torch.Tensor] = None,
         custom_routing_function: Optional[Callable] = None,
         activation: str = "silu",
+        routed_scaling_factor: Optional[float] = None,
         deepep_mode: DeepEPMode = DeepEPMode.auto,
     ):
         super().__init__(
@@ -820,6 +821,7 @@ class DeepEPMoE(EPMoE):
             correction_bias,
             custom_routing_function,
             activation,
+            routed_scaling_factor,
         )
         self.deepep_mode = deepep_mode
         if self.deepep_mode.enable_low_latency():
@@ -33,6 +33,7 @@ from sglang.srt.layers.quantization.compressed_tensors.compressed_tensors_moe im
 from sglang.srt.layers.quantization.compressed_tensors.schemes import (
     CompressedTensorsScheme,
     CompressedTensorsW8A8Fp8,
+    CompressedTensorsW8A16Fp8,
 )
 from sglang.srt.layers.quantization.compressed_tensors.utils import (
     find_matched_target,
@@ -2,8 +2,10 @@
 
 from .compressed_tensors_scheme import CompressedTensorsScheme
 from .compressed_tensors_w8a8_fp8 import CompressedTensorsW8A8Fp8
+from .compressed_tensors_w8a16_fp8 import CompressedTensorsW8A16Fp8
 
 __all__ = [
     "CompressedTensorsScheme",
     "CompressedTensorsW8A8Fp8",
+    "CompressedTensorsW8A16Fp8",
 ]
@@ -0,0 +1,153 @@
+# Adapted from https://github.com/vllm-project/vllm/tree/main/vllm/model_executor/layers/quantization/compressed_tensors
+# SPDX-License-Identifier: Apache-2.0
+
+from typing import Callable, List, Optional
+
+import torch
+from compressed_tensors.quantization import QuantizationStrategy
+
+from sglang.srt.layers.parameter import (
+    ChannelQuantScaleParameter,
+    ModelWeightParameter,
+    PerTensorScaleParameter,
+)
+from sglang.srt.layers.quantization.compressed_tensors.schemes import (
+    CompressedTensorsScheme,
+)
+from sglang.srt.layers.quantization.utils import convert_to_channelwise
+
+try:
+    from vllm.model_executor.layers.quantization.utils.marlin_utils_fp8 import (
+        apply_fp8_marlin_linear,
+        prepare_fp8_layer_for_marlin,
+    )
+
+    MARLIN_FP8_AVAILABLE = True
+except ImportError:
+    MARLIN_FP8_AVAILABLE = False
+
+    def apply_fp8_marlin_linear(*args, **kwargs):
+        raise ImportError("vllm is not installed")
+
+    def prepare_fp8_layer_for_marlin(*args, **kwargs):
+        raise ImportError("vllm is not installed")
+
+
+__all__ = ["CompressedTensorsW8A16Fp8"]
+
+SUPPORTED_STRATEGIES = [QuantizationStrategy.CHANNEL, QuantizationStrategy.TENSOR]
+
+
+class CompressedTensorsW8A16Fp8(CompressedTensorsScheme):
+    def __init__(self, strategy: str, is_static_input_scheme: bool):
+        self.strategy = strategy
+        self.is_static_input_scheme = is_static_input_scheme
+
+        if not MARLIN_FP8_AVAILABLE:
+            raise ImportError(
+                "vllm is not installed. To use CompressedTensorsW8A16Fp8, please install vllm"
+            )
+
+    @classmethod
+    def get_min_capability(cls) -> int:
+        # ampere and up
+        return 80
+
+    # W8A8-Fp8 kernels support only per-tensor and per-channel cases.
+    # So if we have a fused module (QKV, MLP) with per tensor scales,
+    # we expand each scale to its shard's channels.
+    def process_weights_after_loading(self, layer) -> None:
+        if self.strategy == QuantizationStrategy.TENSOR:
+            ws_channelwise = convert_to_channelwise(
+                layer.weight_scale, layer.logical_widths
+            )
+            layer.weight_scale = torch.nn.Parameter(ws_channelwise, requires_grad=False)
+        else:
+            # required by torch.compile to be torch.nn.Parameter
+            layer.weight_scale = torch.nn.Parameter(
+                layer.weight_scale.data, requires_grad=False
+            )
+
+        # Weights must be transposed for marlin
+        layer.weight = torch.nn.Parameter(layer.weight.t(), requires_grad=False)
+
+        if self.is_static_input_scheme:
+            # required by torch.compile to be torch.nn.Parameter
+            layer.input_scale = torch.nn.Parameter(
+                layer.input_scale.data, requires_grad=False
+            )
+        prepare_fp8_layer_for_marlin(layer, strategy="channel")
+
+    def create_weights(
+        self,
+        layer: torch.nn.Module,
+        input_size: int,
+        output_partition_sizes: List[int],
+        input_size_per_partition: int,
+        params_dtype: torch.dtype,
+        weight_loader: Callable,
+        **kwargs,
+    ):
+        output_size_per_partition = sum(output_partition_sizes)
+        layer.logical_widths = output_partition_sizes
+        layer.input_size_per_partition = input_size_per_partition
+        layer.output_size_per_partition = output_size_per_partition
+        layer.orig_dtype = params_dtype
+
+        # WEIGHT
+        weight = ModelWeightParameter(
+            data=torch.empty(
+                output_size_per_partition,
+                input_size_per_partition,
+                dtype=torch.float8_e4m3fn,
+            ),
+            input_dim=1,
+            output_dim=0,
+            weight_loader=weight_loader,
+        )
+        layer.register_parameter("weight", weight)
+
+        # WEIGHT SCALE
+        if self.strategy == QuantizationStrategy.CHANNEL:
+            weight_scale = ChannelQuantScaleParameter(
+                data=torch.empty((sum(output_partition_sizes), 1), dtype=torch.float32),
+                output_dim=0,
+                weight_loader=weight_loader,
+            )
+        elif self.strategy == QuantizationStrategy.TENSOR:
+            weight_scale = PerTensorScaleParameter(
+                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
+                weight_loader=weight_loader,
+            )
+        else:
+            raise ValueError(
+                f"Unsupported weight strategy={self.strategy}, "
+                f"supported strategies are {SUPPORTED_STRATEGIES}"
+            )
+
+        weight_scale[:] = torch.finfo(torch.float32).min
+        layer.register_parameter("weight_scale", weight_scale)
+
+        # INPUT SCALE (to deal with converted checkpoints)
+        if self.is_static_input_scheme:
+            input_scale = PerTensorScaleParameter(
+                data=torch.empty(len(output_partition_sizes), dtype=torch.float32),
+                weight_loader=weight_loader,
+            )
+            layer.register_parameter("input_scale", input_scale)
+
+    def apply_weights(
+        self,
+        layer: torch.nn.Module,
+        x: torch.Tensor,
+        bias: Optional[torch.Tensor] = None,
+    ) -> torch.Tensor:
+        return apply_fp8_marlin_linear(
+            input=x,
+            weight=layer.weight,
+            weight_scale=layer.weight_scale,
+            workspace=layer.workspace,
+            size_n=layer.output_size_per_partition,
+            size_k=layer.input_size_per_partition,
+            bias=bias,
+        )
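The new CompressedTensorsW8A16Fp8 scheme keeps FP8 weights and runs activations in the original dtype through vllm's Marlin FP8 kernels. A hedged smoke test of the parameter-registration path only (it requires vllm to be importable, since the constructor raises otherwise; the module, sizes, and loader below are placeholders, and the Marlin apply path additionally needs process_weights_after_loading on a CUDA device):

import torch
from compressed_tensors.quantization import QuantizationStrategy
from sglang.srt.layers.quantization.compressed_tensors.schemes import CompressedTensorsW8A16Fp8

def _noop_weight_loader(param, loaded_weight, *args, **kwargs):
    pass  # placeholder; never invoked in this sketch

layer = torch.nn.Module()
scheme = CompressedTensorsW8A16Fp8(
    strategy=QuantizationStrategy.CHANNEL, is_static_input_scheme=False
)
scheme.create_weights(
    layer=layer,
    input_size=4096,
    output_partition_sizes=[4096],
    input_size_per_partition=4096,
    params_dtype=torch.bfloat16,
    weight_loader=_noop_weight_loader,
)
print(layer.weight.shape, layer.weight_scale.shape)  # (4096, 4096) fp8 weight, (4096, 1) per-channel scale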