sglang 0.4.0.post2__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (45)
  1. sglang/bench_offline_throughput.py +0 -12
  2. sglang/bench_one_batch.py +0 -12
  3. sglang/bench_serving.py +1 -0
  4. sglang/srt/aio_rwlock.py +100 -0
  5. sglang/srt/configs/model_config.py +8 -1
  6. sglang/srt/layers/attention/flashinfer_backend.py +49 -5
  7. sglang/srt/layers/linear.py +20 -2
  8. sglang/srt/layers/{ep_moe → moe/ep_moe}/layer.py +14 -39
  9. sglang/srt/layers/moe/fused_moe_native.py +46 -0
  10. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/__init__.py +3 -7
  11. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/fused_moe.py +110 -98
  12. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/layer.py +16 -48
  13. sglang/srt/layers/moe/topk.py +191 -0
  14. sglang/srt/layers/quantization/__init__.py +3 -3
  15. sglang/srt/layers/quantization/fp8.py +169 -32
  16. sglang/srt/layers/quantization/fp8_kernel.py +278 -0
  17. sglang/srt/layers/quantization/fp8_utils.py +90 -1
  18. sglang/srt/layers/torchao_utils.py +11 -15
  19. sglang/srt/managers/schedule_batch.py +16 -10
  20. sglang/srt/managers/scheduler.py +2 -2
  21. sglang/srt/managers/tokenizer_manager.py +86 -76
  22. sglang/srt/mem_cache/memory_pool.py +15 -8
  23. sglang/srt/model_executor/cuda_graph_runner.py +1 -1
  24. sglang/srt/model_executor/model_runner.py +6 -0
  25. sglang/srt/models/dbrx.py +1 -1
  26. sglang/srt/models/deepseek.py +1 -1
  27. sglang/srt/models/deepseek_v2.py +67 -18
  28. sglang/srt/models/grok.py +1 -1
  29. sglang/srt/models/mixtral.py +2 -2
  30. sglang/srt/models/olmoe.py +1 -1
  31. sglang/srt/models/qwen2_moe.py +1 -1
  32. sglang/srt/models/xverse_moe.py +1 -1
  33. sglang/srt/openai_api/adapter.py +4 -0
  34. sglang/srt/server.py +1 -0
  35. sglang/srt/utils.py +33 -44
  36. sglang/test/test_block_fp8.py +341 -0
  37. sglang/version.py +1 -1
  38. {sglang-0.4.0.post2.dist-info → sglang-0.4.1.dist-info}/METADATA +3 -3
  39. {sglang-0.4.0.post2.dist-info → sglang-0.4.1.dist-info}/RECORD +44 -40
  40. sglang/srt/layers/fused_moe_patch.py +0 -133
  41. /sglang/srt/layers/{ep_moe → moe/ep_moe}/__init__.py +0 -0
  42. /sglang/srt/layers/{ep_moe → moe/ep_moe}/kernels.py +0 -0
  43. {sglang-0.4.0.post2.dist-info → sglang-0.4.1.dist-info}/LICENSE +0 -0
  44. {sglang-0.4.0.post2.dist-info → sglang-0.4.1.dist-info}/WHEEL +0 -0
  45. {sglang-0.4.0.post2.dist-info → sglang-0.4.1.dist-info}/top_level.txt +0 -0

sglang/srt/layers/quantization/fp8_kernel.py
@@ -0,0 +1,278 @@
+ from typing import List, Tuple
+
+ import torch
+ import triton
+ import triton.language as tl
+
+
+ @triton.jit
+ def _per_token_group_quant_fp8(
+     # Pointers to inputs and output
+     y_ptr,
+     y_q_ptr,
+     y_s_ptr,
+     # Stride of input
+     y_stride,
+     # Columns of input
+     N,
+     # Avoid division by zero
+     eps,
+     # Information for float8
+     fp8_min,
+     fp8_max,
+     # Meta-parameters
+     BLOCK: tl.constexpr,
+ ):
+     """A Triton-accelerated function to perform per-token-group quantization on a
+     tensor.
+
+     This function converts the tensor values into float8 values.
+     """
+     # Map the program id to the row of X and Y it should compute.
+     g_id = tl.program_id(0)
+     y_ptr += g_id * y_stride
+     y_q_ptr += g_id * y_stride
+     y_s_ptr += g_id
+
+     cols = tl.arange(0, BLOCK)  # N <= BLOCK
+     mask = cols < N
+
+     y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
+     # Quantize
+     _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
+     y_s = _absmax / fp8_max
+     y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+
+     tl.store(y_q_ptr + cols, y_q, mask=mask)
+     tl.store(y_s_ptr, y_s)
+
+
+ def per_token_group_quant_fp8(
+     x: torch.Tensor,
+     group_size: int,
+     eps: float = 1e-10,
+     dtype: torch.dtype = torch.float8_e4m3fn,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     """Function to perform per-token-group quantization on an input tensor `x`.
+
+     It converts the tensor values into signed float8 values and returns the
+     quantized tensor along with the scaling factor used for quantization.
+
+     Args:
+         x: The input tensor with ndim >= 2.
+         group_size: The group size used for quantization.
+         eps: The minimum value to avoid division by zero.
+         dtype: The dtype of the output tensor. Note that only `torch.float8_e4m3fn` is supported for now.
+
+     Returns:
+         Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
+     """
+     assert (
+         x.shape[-1] % group_size == 0
+     ), "the last dimension of `x` must be divisible by `group_size`"
+     assert x.is_contiguous(), "`x` is not contiguous"
+
+     finfo = torch.finfo(dtype)
+     fp8_min = finfo.min
+     fp8_max = finfo.max
+
+     x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+     M = x.numel() // group_size
+     N = group_size
+     x_s = torch.empty(
+         x.shape[:-1] + (x.shape[-1] // group_size,),
+         device=x.device,
+         dtype=torch.float32,
+     )
+
+     BLOCK = triton.next_power_of_2(N)
+     # heuristics for number of warps
+     num_warps = min(max(BLOCK // 256, 1), 8)
+     num_stages = 1
+     _per_token_group_quant_fp8[(M,)](
+         x,
+         x_q,
+         x_s,
+         group_size,
+         N,
+         eps,
+         fp8_min=fp8_min,
+         fp8_max=fp8_max,
+         BLOCK=BLOCK,
+         num_warps=num_warps,
+         num_stages=num_stages,
+     )
+
+     return x_q, x_s
+
+
+ @triton.jit
+ def _w8a8_block_fp8_matmul(
+     # Pointers to inputs and output
+     A,
+     B,
+     C,
+     As,
+     Bs,
+     # Shape for matmul
+     M,
+     N,
+     K,
+     # Block size for block-wise quantization
+     group_n,
+     group_k,
+     # Stride for inputs and output
+     stride_am,
+     stride_ak,
+     stride_bk,
+     stride_bn,
+     stride_cm,
+     stride_cn,
+     stride_As_m,
+     stride_As_k,
+     stride_Bs_k,
+     stride_Bs_n,
+     # Meta-parameters
+     BLOCK_SIZE_M: tl.constexpr,
+     BLOCK_SIZE_N: tl.constexpr,
+     BLOCK_SIZE_K: tl.constexpr,
+     GROUP_SIZE_M: tl.constexpr,
+ ):
+     """Triton-accelerated function used to perform linear operations (dot
+     product) on input tensors `A` and `B` with block-wise quantization, and store the result in output
+     tensor `C`.
+     """
+
+     pid = tl.program_id(axis=0)
+     num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+     num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+     num_pid_in_group = GROUP_SIZE_M * num_pid_n
+     group_id = pid // num_pid_in_group
+     first_pid_m = group_id * GROUP_SIZE_M
+     group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+     pid_m = first_pid_m + (pid % group_size_m)
+     pid_n = (pid % num_pid_in_group) // group_size_m
+
+     offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
+     offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+     offs_k = tl.arange(0, BLOCK_SIZE_K)
+     a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
+     b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+
+     As_ptrs = As + offs_am * stride_As_m
+     offs_bsn = offs_bn // group_n
+     Bs_ptrs = Bs + offs_bsn * stride_Bs_n
+
+     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+         a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
+         b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+
+         k_start = k * BLOCK_SIZE_K
+         offs_ks = k_start // group_k
+         a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
+         b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
+
+         accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
+         a_ptrs += BLOCK_SIZE_K * stride_ak
+         b_ptrs += BLOCK_SIZE_K * stride_bk
+
+     if C.dtype.element_ty == tl.bfloat16:
+         c = accumulator.to(tl.bfloat16)
+     elif C.dtype.element_ty == tl.float16:
+         c = accumulator.to(tl.float16)
+     else:
+         c = accumulator.to(tl.float32)
+
+     offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+     c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+     c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+     tl.store(c_ptrs, c, mask=c_mask)
+
+
+ def w8a8_block_fp8_matmul(
+     A: torch.Tensor,
+     B: torch.Tensor,
+     As: torch.Tensor,
+     Bs: torch.Tensor,
+     block_size: List[int],
+     output_dtype: torch.dtype = torch.float16,
+ ) -> torch.Tensor:
+     """This function performs matrix multiplication with block-wise quantization.
+
+     It takes two input tensors `A` and `B` with scales `As` and `Bs`.
+     The output is returned in the specified `output_dtype`.
+
+     Args:
+         A: The input tensor, e.g., activation.
+         B: The input tensor, e.g., weight.
+         As: The per-token-group quantization scale for `A`.
+         Bs: The per-block quantization scale for `B`.
+         block_size: The block size for per-block quantization. It should be 2-dim, e.g., [128, 128].
+         output_dtype: The dtype of the returned tensor.
+
+     Returns:
+         torch.Tensor: The result of matmul.
+     """
+     assert len(block_size) == 2
+     block_n, block_k = block_size[0], block_size[1]
+
+     assert A.shape[-1] == B.shape[-1]
+     assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
+     assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
+     M = A.numel() // A.shape[-1]
+
+     assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
+     N, K = B.shape
+     assert triton.cdiv(N, block_n) == Bs.shape[0]
+     assert triton.cdiv(K, block_k) == Bs.shape[1]
+
+     C_shape = A.shape[:-1] + (N,)
+     C = A.new_empty(C_shape, dtype=output_dtype)
+
+     # TODO(HandH1998):
+     # BLOCK_SIZE_M, BLOCK_SIZE_K, BLOCK_SIZE_N can be optimized.
+     # BLOCK_SIZE_K must divide block_k.
+     # BLOCK_SIZE_N and BLOCK_SIZE_M have no requirements.
+     BLOCK_SIZE_M = 128
+     if M < BLOCK_SIZE_M:
+         BLOCK_SIZE_M = triton.next_power_of_2(M)
+         BLOCK_SIZE_M = max(BLOCK_SIZE_M, 16)
+     BLOCK_SIZE_K = block_k
+     assert block_k % BLOCK_SIZE_K == 0
+     BLOCK_SIZE_N = block_n
+
+     def grid(META):
+         return (
+             triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
+         )
+
+     _w8a8_block_fp8_matmul[grid](
+         A,
+         B,
+         C,
+         As,
+         Bs,
+         M,
+         N,
+         K,
+         block_n,
+         block_k,
+         A.stride(-2),
+         A.stride(-1),
+         B.stride(1),
+         B.stride(0),
+         C.stride(-2),
+         C.stride(-1),
+         As.stride(-2),
+         As.stride(-1),
+         Bs.stride(1),
+         Bs.stride(0),
+         BLOCK_SIZE_M=BLOCK_SIZE_M,
+         BLOCK_SIZE_N=BLOCK_SIZE_N,
+         BLOCK_SIZE_K=BLOCK_SIZE_K,
+         GROUP_SIZE_M=8,
+     )
+
+     return C
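
Taken together, the two new public entry points compose as follows. The snippet below is an illustrative sketch, not code from the package: the shapes, the [128, 128] block size, the CUDA device, and the unit-scale dummy weight are all assumptions made for the example; only per_token_group_quant_fp8 and w8a8_block_fp8_matmul come from the file added above.

# Illustrative only: shapes, block size, and the ad-hoc weight quantization
# below are assumptions for the sketch, not part of the sglang diff.
import torch

from sglang.srt.layers.quantization.fp8_kernel import (
    per_token_group_quant_fp8,
    w8a8_block_fp8_matmul,
)

M, K, N = 4, 512, 1024      # tokens, hidden size, output size (made up)
block_size = [128, 128]     # [block_n, block_k], as used elsewhere in the diff

x = torch.randn(M, K, dtype=torch.float16, device="cuda")

# 1) Per-token-group quantization of the activation: one fp8 scale per
#    group of block_size[1] consecutive channels in each row.
x_q, x_s = per_token_group_quant_fp8(x, group_size=block_size[1])
# x_q: (M, K) float8_e4m3fn, x_s: (M, K // 128) float32

# 2) A block-quantized fp8 weight. Here it is faked with a unit scale per
#    (128, 128) tile purely so the call below has valid inputs.
w = torch.randn(N, K, dtype=torch.float16, device="cuda").to(torch.float8_e4m3fn)
w_s = torch.ones(N // block_size[0], K // block_size[1],
                 dtype=torch.float32, device="cuda")

# 3) Block-wise W8A8 matmul: the per-group and per-block scales are applied
#    tile by tile inside the Triton kernel.
y = w8a8_block_fp8_matmul(x_q, w, x_s, w_s, block_size, output_dtype=torch.float16)
print(y.shape)  # torch.Size([4, 1024])

The per-token-group scales from step 1 line up with the block_size[1] (group_k) tiling that the matmul kernel expects along K, which is why the same block size appears in both calls.
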

sglang/srt/layers/quantization/fp8_utils.py
@@ -1,6 +1,12 @@
- from typing import Optional, Tuple
+ from typing import List, Optional, Tuple

  import torch
+ from vllm.model_executor.parameter import RowvLLMParameter, _ColumnvLLMParameter
+
+ from sglang.srt.layers.quantization.fp8_kernel import (
+     per_token_group_quant_fp8,
+     w8a8_block_fp8_matmul,
+ )


  def normalize_e4m3fn_to_e4m3fnuz(
@@ -25,3 +31,86 @@ def normalize_e4m3fn_to_e4m3fnuz(
      if input_scale is not None:
          input_scale = input_scale * 2.0
      return weight, weight_scale, input_scale
+
+
+ def apply_w8a8_block_fp8_linear(
+     input: torch.Tensor,
+     weight: torch.Tensor,
+     block_size: List[int],
+     weight_scale: torch.Tensor,
+     input_scale: Optional[torch.Tensor] = None,
+     bias: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+     assert input_scale is None
+     # View input as 2D matrix for fp8 methods
+     input_2d = input.view(-1, input.shape[-1])
+     output_shape = [*input.shape[:-1], weight.shape[0]]
+
+     q_input, x_scale = per_token_group_quant_fp8(input_2d, block_size[1])
+     output = w8a8_block_fp8_matmul(
+         q_input, weight, x_scale, weight_scale, block_size, output_dtype=input.dtype
+     )
+
+     if bias is not None:
+         output = output + bias
+     return output.to(dtype=input.dtype).view(*output_shape)
+
+
+ def input_to_float8(
+     x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     """This function quantizes input values to float8 values with tensor-wise quantization."""
+     finfo = torch.finfo(dtype)
+     min_val, max_val = x.aminmax()
+     amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
+     scale = finfo.max / amax
+     x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
+     return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
+
+
+ def block_quant_to_tensor_quant(
+     x_q_block: torch.Tensor,
+     x_s: torch.Tensor,
+     block_size: List[int],
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     """This function converts block-wise quantization to tensor-wise quantization.
+     The inputs are the block-wise quantized tensor `x_q_block`, its block-wise
+     quantization scale `x_s`, and the block size.
+     The outputs are the tensor-wise quantized tensor and the tensor-wise quantization scale.
+     Note that only float8 is supported for now.
+     """
+     block_n, block_k = block_size[0], block_size[1]
+     n, k = x_q_block.shape
+     n_tiles = (n + block_n - 1) // block_n
+     k_tiles = (k + block_k - 1) // block_k
+     assert n_tiles == x_s.shape[0]
+     assert k_tiles == x_s.shape[1]
+
+     x_dq_block = x_q_block.to(torch.float32)
+
+     x_dq_block_tiles = [
+         [
+             x_dq_block[
+                 j * block_n : min((j + 1) * block_n, n),
+                 i * block_k : min((i + 1) * block_k, k),
+             ]
+             for i in range(k_tiles)
+         ]
+         for j in range(n_tiles)
+     ]
+
+     for i in range(k_tiles):
+         for j in range(n_tiles):
+             x_dq_block_tiles[j][i][:, :] = x_dq_block_tiles[j][i] * x_s[j][i]
+
+     x_q_tensor, scale = input_to_float8(x_dq_block, dtype=x_q_block.dtype)
+     return x_q_tensor, scale
+
+
+ class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter):
+     """
+     Parameter class for weight scales loaded for weights with
+     block-wise quantization. Uses both column and row parallelism.
+     """
+
+     pass
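
For context, these helpers are typically driven through apply_w8a8_block_fp8_linear, which quantizes the activation internally and forwards everything to the Triton matmul. The sketch below is hypothetical usage, not package code: the shapes, the CUDA device, and the all-ones weight scale are invented for illustration; only the imported functions come from the file changed above.

# Sketch only: inputs are fabricated so the calls have valid shapes.
import torch

from sglang.srt.layers.quantization.fp8_utils import (
    apply_w8a8_block_fp8_linear,
    block_quant_to_tensor_quant,
)

block_size = [128, 128]
x = torch.randn(2, 8, 512, dtype=torch.bfloat16, device="cuda")       # (batch, seq, hidden)
w_q = torch.randn(1024, 512, device="cuda").to(torch.float8_e4m3fn)   # (out, in) fp8 weight
w_s = torch.ones(1024 // 128, 512 // 128, dtype=torch.float32, device="cuda")

# Block-wise W8A8 linear: per-token-group quantization of x happens inside,
# then w8a8_block_fp8_matmul is called with the block-wise weight scale.
y = apply_w8a8_block_fp8_linear(x, w_q, block_size, w_s)
print(y.shape)  # torch.Size([2, 8, 1024])

# Collapse the block-quantized weight to per-tensor quantization, e.g. for
# kernels that only understand a single tensor-wide fp8 scale.
w_q_tensor, w_scale = block_quant_to_tensor_quant(w_q, w_s, block_size)
print(w_q_tensor.dtype)  # torch.float8_e4m3fn, plus one tensor-wide scale in w_scale
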

sglang/srt/layers/torchao_utils.py
@@ -2,8 +2,14 @@
  Common utilities for torchao.
  """

+ import logging
+ import os
+ import pwd
+
  import torch

+ logger = logging.getLogger(__name__)
+

  def apply_torchao_config_to_model(
      model: torch.nn.Module, torchao_config: str, filter_fn=None
@@ -50,27 +56,17 @@ def apply_torchao_config_to_model(
      elif "gemlite" in torchao_config:
          # gemlite-<packing_bitwidth>-<bit_width>-<group_size> or
          # gemlite-<bit_width>-<group_size> (packing_bitwidth defaults to 32)
-         import os
-         import pwd
-
-         import gemlite
-         from gemlite.core import GemLiteLinearTriton, set_autotune
-
-         try:
-             from torchao.quantization import gemlite_uintx_weight_only
-         except:
-             print(
-                 f"import `gemlite_uintx_weight_only` failed, please use torchao nightly to use gemlite quantization"
-             )
-             return model
+         from gemlite.core import GemLiteLinearTriton
+         from torchao.quantization import gemlite_uintx_weight_only

          _quant_args = torchao_config.split("-")
          bit_width = int(_quant_args[-2])
          group_size = None if _quant_args[-1] == "None" else int(_quant_args[-1])
+
          try:
              packing_bitwidth = int(_quant_args[-3])
-         except:
-             # if only 2 inputs found, use default value
+         except (ValueError, IndexError):
+             # if only 2 inputs found or conversion fails, use default value
              packing_bitwidth = 32

          quantize_(
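
The switch from a bare except to except (ValueError, IndexError) is easier to see in isolation. The sketch below is a hypothetical standalone rewrite of just the parsing step (the helper name and the sample config strings are invented); the parsing logic itself mirrors the hunk above.

# Standalone illustration of the gemlite config-string parsing shown above;
# in the package this logic lives inside apply_torchao_config_to_model.
def parse_gemlite_config(torchao_config: str):
    _quant_args = torchao_config.split("-")
    bit_width = int(_quant_args[-2])
    group_size = None if _quant_args[-1] == "None" else int(_quant_args[-1])
    try:
        packing_bitwidth = int(_quant_args[-3])
    except (ValueError, IndexError):
        # if only 2 inputs found or conversion fails, use default value
        packing_bitwidth = 32
    return packing_bitwidth, bit_width, group_size

print(parse_gemlite_config("gemlite-8-4-64"))   # (8, 4, 64)
print(parse_gemlite_config("gemlite-4-64"))     # (32, 4, 64): packing defaults to 32
print(parse_gemlite_config("gemlite-4-None"))   # (32, 4, None)
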

sglang/srt/managers/schedule_batch.py
@@ -479,8 +479,22 @@ class Req:

          return True

+     def reset_for_retract(self):
+         self.prefix_indices = []
+         self.last_node = None
+         self.extend_input_len = 0
+         self.is_retracted = True
+
+         # For incremental logprobs
+         # TODO: Fix the `logprob_start_len`
+         self.last_update_decode_tokens = 0
+         self.logprob_start_len = 10**9
+
      def __repr__(self):
-         return f"rid(n={self.rid}, " f"input_ids={self.origin_input_ids}, "
+         return (
+             f"rid(n={self.rid}, "
+             f"input_ids={self.origin_input_ids}, output_ids={self.output_ids}"
+         )


  bid = 0
@@ -894,15 +908,7 @@ class ScheduleBatch:
                  )
                  residual_size = max(0, residual_size)
                  self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)
-
-             req.prefix_indices = []
-             req.last_node = None
-             req.extend_input_len = 0
-             req.is_retracted = True
-
-             # For incremental logprobs
-             req.last_update_decode_tokens = 0
-             req.logprob_start_len = 10**9
+             req.reset_for_retract()

          self.filter_batch(keep_indices=sorted_indices)


sglang/srt/managers/scheduler.py
@@ -22,7 +22,7 @@ import warnings
  from collections import deque
  from concurrent import futures
  from types import SimpleNamespace
- from typing import List, Optional
+ from typing import Callable, Dict, List, Optional, Tuple

  import psutil
  import setproctitle
@@ -260,7 +260,7 @@ class Scheduler:
          self.current_stream = torch.get_device_module(self.device).current_stream()

          # Session info
-         self.sessions = {}
+         self.sessions: Dict[str, Session] = {}

          # Init chunked prefill
          self.chunked_prefill_size = server_args.chunked_prefill_size