sglang 0.4.0.post1__py3-none-any.whl → 0.4.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (74)
  1. sglang/bench_offline_throughput.py +6 -6
  2. sglang/bench_one_batch.py +1 -0
  3. sglang/bench_serving.py +9 -1
  4. sglang/check_env.py +140 -48
  5. sglang/lang/backend/runtime_endpoint.py +1 -0
  6. sglang/lang/chat_template.py +32 -0
  7. sglang/llama3_eval.py +316 -0
  8. sglang/srt/aio_rwlock.py +100 -0
  9. sglang/srt/configs/model_config.py +8 -1
  10. sglang/srt/constrained/xgrammar_backend.py +4 -1
  11. sglang/srt/layers/attention/flashinfer_backend.py +51 -5
  12. sglang/srt/layers/attention/triton_backend.py +16 -25
  13. sglang/srt/layers/attention/triton_ops/decode_attention.py +305 -350
  14. sglang/srt/layers/linear.py +20 -2
  15. sglang/srt/layers/logits_processor.py +133 -95
  16. sglang/srt/layers/{ep_moe → moe/ep_moe}/layer.py +18 -39
  17. sglang/srt/layers/moe/fused_moe_native.py +46 -0
  18. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/__init__.py +3 -7
  19. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/fused_moe.py +174 -119
  20. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/layer.py +17 -49
  21. sglang/srt/layers/moe/topk.py +191 -0
  22. sglang/srt/layers/quantization/__init__.py +5 -50
  23. sglang/srt/layers/quantization/fp8.py +221 -36
  24. sglang/srt/layers/quantization/fp8_kernel.py +278 -0
  25. sglang/srt/layers/quantization/fp8_utils.py +90 -1
  26. sglang/srt/layers/radix_attention.py +8 -1
  27. sglang/srt/layers/sampler.py +27 -5
  28. sglang/srt/layers/torchao_utils.py +31 -0
  29. sglang/srt/managers/detokenizer_manager.py +37 -17
  30. sglang/srt/managers/io_struct.py +39 -10
  31. sglang/srt/managers/schedule_batch.py +54 -34
  32. sglang/srt/managers/schedule_policy.py +64 -5
  33. sglang/srt/managers/scheduler.py +171 -136
  34. sglang/srt/managers/tokenizer_manager.py +184 -133
  35. sglang/srt/mem_cache/base_prefix_cache.py +2 -2
  36. sglang/srt/mem_cache/chunk_cache.py +2 -2
  37. sglang/srt/mem_cache/memory_pool.py +15 -8
  38. sglang/srt/mem_cache/radix_cache.py +12 -2
  39. sglang/srt/model_executor/cuda_graph_runner.py +25 -11
  40. sglang/srt/model_executor/model_runner.py +28 -14
  41. sglang/srt/model_parallel.py +66 -5
  42. sglang/srt/models/dbrx.py +1 -1
  43. sglang/srt/models/deepseek.py +1 -1
  44. sglang/srt/models/deepseek_v2.py +67 -18
  45. sglang/srt/models/gemma2.py +34 -0
  46. sglang/srt/models/gemma2_reward.py +0 -1
  47. sglang/srt/models/granite.py +517 -0
  48. sglang/srt/models/grok.py +73 -9
  49. sglang/srt/models/llama.py +22 -0
  50. sglang/srt/models/llama_classification.py +11 -23
  51. sglang/srt/models/llama_reward.py +0 -2
  52. sglang/srt/models/llava.py +37 -14
  53. sglang/srt/models/mixtral.py +2 -2
  54. sglang/srt/models/olmoe.py +1 -1
  55. sglang/srt/models/qwen2.py +20 -0
  56. sglang/srt/models/qwen2_moe.py +1 -1
  57. sglang/srt/models/xverse_moe.py +1 -1
  58. sglang/srt/openai_api/adapter.py +8 -0
  59. sglang/srt/openai_api/protocol.py +9 -4
  60. sglang/srt/server.py +2 -1
  61. sglang/srt/server_args.py +19 -9
  62. sglang/srt/utils.py +40 -54
  63. sglang/test/test_block_fp8.py +341 -0
  64. sglang/test/test_utils.py +3 -2
  65. sglang/utils.py +10 -3
  66. sglang/version.py +1 -1
  67. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/METADATA +12 -7
  68. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/RECORD +73 -67
  69. sglang/srt/layers/fused_moe_patch.py +0 -133
  70. /sglang/srt/layers/{ep_moe → moe/ep_moe}/__init__.py +0 -0
  71. /sglang/srt/layers/{ep_moe → moe/ep_moe}/kernels.py +0 -0
  72. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/LICENSE +0 -0
  73. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/WHEEL +0 -0
  74. {sglang-0.4.0.post1.dist-info → sglang-0.4.1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/fp8_kernel.py (new file)
@@ -0,0 +1,278 @@
+ from typing import List, Tuple
+
+ import torch
+ import triton
+ import triton.language as tl
+
+
+ @triton.jit
+ def _per_token_group_quant_fp8(
+     # Pointers to inputs and output
+     y_ptr,
+     y_q_ptr,
+     y_s_ptr,
+     # Stride of input
+     y_stride,
+     # Columns of input
+     N,
+     # Epsilon to avoid division by zero
+     eps,
+     # Information for float8
+     fp8_min,
+     fp8_max,
+     # Meta-parameters
+     BLOCK: tl.constexpr,
+ ):
+     """A Triton-accelerated function to perform per-token-group quantization on a
+     tensor.
+
+     This function converts the tensor values into float8 values.
+     """
+     # Map the program id to the row of X and Y it should compute.
+     g_id = tl.program_id(0)
+     y_ptr += g_id * y_stride
+     y_q_ptr += g_id * y_stride
+     y_s_ptr += g_id
+
+     cols = tl.arange(0, BLOCK)  # N <= BLOCK
+     mask = cols < N
+
+     y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
+     # Quant
+     _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
+     y_s = _absmax / fp8_max
+     y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+
+     tl.store(y_q_ptr + cols, y_q, mask=mask)
+     tl.store(y_s_ptr, y_s)
+
+
+ def per_token_group_quant_fp8(
+     x: torch.Tensor,
+     group_size: int,
+     eps: float = 1e-10,
+     dtype: torch.dtype = torch.float8_e4m3fn,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     """Function to perform per-token-group quantization on an input tensor `x`.
+
+     It converts the tensor values into signed float8 values and returns the
+     quantized tensor along with the scaling factor used for quantization.
+
+     Args:
+         x: The input tensor with ndim >= 2.
+         group_size: The group size used for quantization.
+         eps: The minimum value to avoid division by zero.
+         dtype: The dtype of the output tensor. Note that only `torch.float8_e4m3fn` is supported for now.
+
+     Returns:
+         Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
+     """
+     assert (
+         x.shape[-1] % group_size == 0
+     ), "the last dimension of `x` must be divisible by `group_size`"
+     assert x.is_contiguous(), "`x` is not contiguous"
+
+     finfo = torch.finfo(dtype)
+     fp8_min = finfo.min
+     fp8_max = finfo.max
+
+     x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+     M = x.numel() // group_size
+     N = group_size
+     x_s = torch.empty(
+         x.shape[:-1] + (x.shape[-1] // group_size,),
+         device=x.device,
+         dtype=torch.float32,
+     )
+
+     BLOCK = triton.next_power_of_2(N)
+     # heuristics for number of warps
+     num_warps = min(max(BLOCK // 256, 1), 8)
+     num_stages = 1
+     _per_token_group_quant_fp8[(M,)](
+         x,
+         x_q,
+         x_s,
+         group_size,
+         N,
+         eps,
+         fp8_min=fp8_min,
+         fp8_max=fp8_max,
+         BLOCK=BLOCK,
+         num_warps=num_warps,
+         num_stages=num_stages,
+     )
+
+     return x_q, x_s
+
+
+ @triton.jit
+ def _w8a8_block_fp8_matmul(
+     # Pointers to inputs and output
+     A,
+     B,
+     C,
+     As,
+     Bs,
+     # Shape for matmul
+     M,
+     N,
+     K,
+     # Block size for block-wise quantization
+     group_n,
+     group_k,
+     # Stride for inputs and output
+     stride_am,
+     stride_ak,
+     stride_bk,
+     stride_bn,
+     stride_cm,
+     stride_cn,
+     stride_As_m,
+     stride_As_k,
+     stride_Bs_k,
+     stride_Bs_n,
+     # Meta-parameters
+     BLOCK_SIZE_M: tl.constexpr,
+     BLOCK_SIZE_N: tl.constexpr,
+     BLOCK_SIZE_K: tl.constexpr,
+     GROUP_SIZE_M: tl.constexpr,
+ ):
+     """Triton-accelerated function used to perform linear operations (dot
+     product) on input tensors `A` and `B` with block-wise quantization, and store the result in output
+     tensor `C`.
+     """
+
+     pid = tl.program_id(axis=0)
+     num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+     num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+     num_pid_in_group = GROUP_SIZE_M * num_pid_n
+     group_id = pid // num_pid_in_group
+     first_pid_m = group_id * GROUP_SIZE_M
+     group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+     pid_m = first_pid_m + (pid % group_size_m)
+     pid_n = (pid % num_pid_in_group) // group_size_m
+
+     offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
+     offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+     offs_k = tl.arange(0, BLOCK_SIZE_K)
+     a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
+     b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+
+     As_ptrs = As + offs_am * stride_As_m
+     offs_bsn = offs_bn // group_n
+     Bs_ptrs = Bs + offs_bsn * stride_Bs_n
+
+     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+         a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
+         b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+
+         k_start = k * BLOCK_SIZE_K
+         offs_ks = k_start // group_k
+         a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
+         b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
+
+         accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
+         a_ptrs += BLOCK_SIZE_K * stride_ak
+         b_ptrs += BLOCK_SIZE_K * stride_bk
+
+     if C.dtype.element_ty == tl.bfloat16:
+         c = accumulator.to(tl.bfloat16)
+     elif C.dtype.element_ty == tl.float16:
+         c = accumulator.to(tl.float16)
+     else:
+         c = accumulator.to(tl.float32)
+
+     offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+     c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+     c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+     tl.store(c_ptrs, c, mask=c_mask)
+
+
+ def w8a8_block_fp8_matmul(
+     A: torch.Tensor,
+     B: torch.Tensor,
+     As: torch.Tensor,
+     Bs: torch.Tensor,
+     block_size: List[int],
+     output_dtype: torch.dtype = torch.float16,
+ ) -> torch.Tensor:
+     """This function performs matrix multiplication with block-wise quantization.
+
+     It takes two input tensors `A` and `B` with scales `As` and `Bs`.
+     The output is returned in the specified `output_dtype`.
+
+     Args:
+         A: The input tensor, e.g., activation.
+         B: The input tensor, e.g., weight.
+         As: The per-token-group quantization scale for `A`.
+         Bs: The per-block quantization scale for `B`.
+         block_size: The block size for per-block quantization. It should be 2-dim, e.g., [128, 128].
+         output_dtype: The dtype of the returned tensor.
+
+     Returns:
+         torch.Tensor: The result of matmul.
+     """
+     assert len(block_size) == 2
+     block_n, block_k = block_size[0], block_size[1]
+
+     assert A.shape[-1] == B.shape[-1]
+     assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
+     assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
+     M = A.numel() // A.shape[-1]
+
+     assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
+     N, K = B.shape
+     assert triton.cdiv(N, block_n) == Bs.shape[0]
+     assert triton.cdiv(K, block_k) == Bs.shape[1]
+
+     C_shape = A.shape[:-1] + (N,)
+     C = A.new_empty(C_shape, dtype=output_dtype)
+
+     # TODO(HandH1998):
+     # BLOCK_SIZE_M, BLOCK_SIZE_K, BLOCK_SIZE_N can be optimized.
+     # BLOCK_SIZE_K must be divisible by block_k
+     # BLOCK_SIZE_N and BLOCK_SIZE_M have no requirements
+     BLOCK_SIZE_M = 128
+     if M < BLOCK_SIZE_M:
+         BLOCK_SIZE_M = triton.next_power_of_2(M)
+         BLOCK_SIZE_M = max(BLOCK_SIZE_M, 16)
+     BLOCK_SIZE_K = block_k
+     assert block_k % BLOCK_SIZE_K == 0
+     BLOCK_SIZE_N = block_n
+
+     def grid(META):
+         return (
+             triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
+         )
+
+     _w8a8_block_fp8_matmul[grid](
+         A,
+         B,
+         C,
+         As,
+         Bs,
+         M,
+         N,
+         K,
+         block_n,
+         block_k,
+         A.stride(-2),
+         A.stride(-1),
+         B.stride(1),
+         B.stride(0),
+         C.stride(-2),
+         C.stride(-1),
+         As.stride(-2),
+         As.stride(-1),
+         Bs.stride(1),
+         Bs.stride(0),
+         BLOCK_SIZE_M=BLOCK_SIZE_M,
+         BLOCK_SIZE_N=BLOCK_SIZE_N,
+         BLOCK_SIZE_K=BLOCK_SIZE_K,
+         GROUP_SIZE_M=8,
+     )
+
+     return C
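For orientation, here is a minimal usage sketch of the two kernels added above. It is not part of the diff; the shapes, the random test data, and the all-ones block scales are illustrative, and it assumes a CUDA GPU with Triton and torch.float8_e4m3fn support.

    import torch

    from sglang.srt.layers.quantization.fp8_kernel import (
        per_token_group_quant_fp8,
        w8a8_block_fp8_matmul,
    )

    # Illustrative sizes: 4 tokens, K=256 input features, N=512 output features,
    # block_size = [block_n, block_k] = [128, 128].
    block_size = [128, 128]
    fmax = torch.finfo(torch.float8_e4m3fn).max

    x = torch.randn(4, 256, device="cuda", dtype=torch.float16)  # activation (M, K)

    # Pretend weight: already fp8, with one scale per (128, 128) block (all scales = 1.0 here).
    w_q = torch.randn(512, 256, device="cuda").clamp(-fmax, fmax).to(torch.float8_e4m3fn)  # (N, K)
    w_s = torch.ones(512 // 128, 256 // 128, device="cuda", dtype=torch.float32)

    # Quantize the activation per token, in groups of block_k = 128 channels.
    x_q, x_s = per_token_group_quant_fp8(x, group_size=block_size[1])  # x_s has shape (4, 2)

    # Block-wise fp8 matmul: (4, 256) @ (512, 256)^T -> (4, 512) in fp16.
    y = w8a8_block_fp8_matmul(x_q, w_q, x_s, w_s, block_size, output_dtype=torch.float16)
    print(y.shape)  # torch.Size([4, 512])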
sglang/srt/layers/quantization/fp8_utils.py
@@ -1,6 +1,12 @@
- from typing import Optional, Tuple
+ from typing import List, Optional, Tuple

  import torch
+ from vllm.model_executor.parameter import RowvLLMParameter, _ColumnvLLMParameter
+
+ from sglang.srt.layers.quantization.fp8_kernel import (
+     per_token_group_quant_fp8,
+     w8a8_block_fp8_matmul,
+ )


  def normalize_e4m3fn_to_e4m3fnuz(
@@ -25,3 +31,86 @@ def normalize_e4m3fn_to_e4m3fnuz(
      if input_scale is not None:
          input_scale = input_scale * 2.0
      return weight, weight_scale, input_scale
+
+
+ def apply_w8a8_block_fp8_linear(
+     input: torch.Tensor,
+     weight: torch.Tensor,
+     block_size: List[int],
+     weight_scale: torch.Tensor,
+     input_scale: Optional[torch.Tensor] = None,
+     bias: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+     assert input_scale is None
+     # View input as 2D matrix for fp8 methods
+     input_2d = input.view(-1, input.shape[-1])
+     output_shape = [*input.shape[:-1], weight.shape[0]]
+
+     q_input, x_scale = per_token_group_quant_fp8(input_2d, block_size[1])
+     output = w8a8_block_fp8_matmul(
+         q_input, weight, x_scale, weight_scale, block_size, output_dtype=input.dtype
+     )
+
+     if bias is not None:
+         output = output + bias
+     return output.to(dtype=input.dtype).view(*output_shape)
+
+
+ def input_to_float8(
+     x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     """This function quantizes input values to float8 values with tensor-wise quantization."""
+     finfo = torch.finfo(dtype)
+     min_val, max_val = x.aminmax()
+     amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
+     scale = finfo.max / amax
+     x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
+     return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
+
+
+ def block_quant_to_tensor_quant(
+     x_q_block: torch.Tensor,
+     x_s: torch.Tensor,
+     block_size: List[int],
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     """This function converts block-wise quantization to tensor-wise quantization.
+     The inputs are block-wise quantization tensor `x_q_block`, block-wise quantization scale
+     and the block size.
+     The outputs are tensor-wise quantization tensor and tensor-wise quantization scale.
+     Note only float8 is supported for now.
+     """
+     block_n, block_k = block_size[0], block_size[1]
+     n, k = x_q_block.shape
+     n_tiles = (n + block_n - 1) // block_n
+     k_tiles = (k + block_k - 1) // block_k
+     assert n_tiles == x_s.shape[0]
+     assert k_tiles == x_s.shape[1]
+
+     x_dq_block = x_q_block.to(torch.float32)
+
+     x_dq_block_tiles = [
+         [
+             x_dq_block[
+                 j * block_n : min((j + 1) * block_n, n),
+                 i * block_k : min((i + 1) * block_k, k),
+             ]
+             for i in range(k_tiles)
+         ]
+         for j in range(n_tiles)
+     ]
+
+     for i in range(k_tiles):
+         for j in range(n_tiles):
+             x_dq_block_tiles[j][i][:, :] = x_dq_block_tiles[j][i] * x_s[j][i]
+
+     x_q_tensor, scale = input_to_float8(x_dq_block, dtype=x_q_block.dtype)
+     return x_q_tensor, scale
+
+
+ class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter):
+     """
+     Parameter class for weight scales loaded for weights with
+     block-wise quantization. Uses both column and row parallelism.
+     """
+
+     pass
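A rough, self-contained illustration of how the new fp8_utils helpers compose. It is not part of the diff; the shapes and the uniform block scales are made up, and it assumes a PyTorch build where torch.float8_e4m3fn tensors can be created and cast on the current device.

    import torch

    from sglang.srt.layers.quantization.fp8_utils import (
        block_quant_to_tensor_quant,
        input_to_float8,
    )

    block_size = [128, 128]

    # Tensor-wise quantization of a random matrix.
    w = torch.randn(256, 256)
    w_q, w_scale = input_to_float8(w)   # fp8 values plus one scalar scale
    print(w_q.dtype, w_scale.shape)     # torch.float8_e4m3fn, torch.Size([])

    # Treat w_q as a block-quantized weight with a (2, 2) grid of per-block scales,
    # then collapse it back to a single tensor-wise scale.
    block_scales = torch.full((256 // 128, 256 // 128), w_scale.item(), dtype=torch.float32)
    w_q_tensor, tensor_scale = block_quant_to_tensor_quant(w_q, block_scales, block_size)
    print(w_q_tensor.dtype, tensor_scale)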
sglang/srt/layers/radix_attention.py
@@ -48,7 +48,14 @@ class RadixAttention(nn.Module):
          self.sliding_window_size = sliding_window_size or -1
          self.is_cross_attention = is_cross_attention

-     def forward(self, q, k, v, forward_batch: ForwardBatch, save_kv_cache=True):
+     def forward(
+         self,
+         q,
+         k,
+         v,
+         forward_batch: ForwardBatch,
+         save_kv_cache: bool = True,
+     ):
          if k is not None:
              # For cross-layer sharing, kv can be None
              assert v is not None
sglang/srt/layers/sampler.py
@@ -51,7 +51,6 @@ class Sampler(nn.Module):
          # Post process logits
          logits.div_(sampling_info.temperatures)
          probs = torch.softmax(logits, dim=-1)
-         logits = None
          del logits

          if global_server_args_dict["sampling_backend"] == "flashinfer":
@@ -84,6 +83,7 @@ class Sampler(nn.Module):
                  sampling_info.top_ks,
                  sampling_info.top_ps,
                  sampling_info.min_ps,
+                 sampling_info.need_min_p_sampling,
              )
          else:
              raise ValueError(
@@ -98,20 +98,42 @@ def top_k_top_p_min_p_sampling_from_probs_torch(
      top_ks: torch.Tensor,
      top_ps: torch.Tensor,
      min_ps: torch.Tensor,
+     need_min_p_sampling: bool,
  ):
      """A top-k, top-p and min-p sampling implementation with native pytorch operations."""
      probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
      probs_sum = torch.cumsum(probs_sort, dim=-1)
-     min_p_thresholds = probs_sort[:, 0] * min_ps
-     probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
      probs_sort[
          torch.arange(0, probs.shape[-1], device=probs.device).view(1, -1)
          >= top_ks.view(-1, 1)
      ] = 0.0
-     probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
-     probs_sort.div_(probs_sort.max(dim=-1, keepdim=True)[0])
+     probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
+
+     if need_min_p_sampling:
+         min_p_thresholds = probs_sort[:, 0] * min_ps
+         probs_sort[probs_sort < min_p_thresholds.view(-1, 1)] = 0.0
+
      sampled_index = torch.multinomial(probs_sort, num_samples=1)
      # int32 range is enough to represent the token ids
      probs_idx = probs_idx.to(torch.int32)
      batch_next_token_ids = torch.gather(probs_idx, dim=1, index=sampled_index).view(-1)
      return batch_next_token_ids
+
+
+ def top_p_normalize_probs(
+     probs: torch.Tensor,
+     top_ps: torch.Tensor,
+ ):
+     if global_server_args_dict["sampling_backend"] == "flashinfer":
+         return top_p_renorm_prob(probs, top_ps)
+     elif global_server_args_dict["sampling_backend"] == "pytorch":
+         # See also top_k_top_p_min_p_sampling_from_probs_torch
+         probs_sort, probs_idx = probs.sort(dim=-1, descending=True)
+         probs_sum = torch.cumsum(probs_sort, dim=-1)
+         probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0
+         probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
+         return torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort)
+     else:
+         raise ValueError(
+             f"Invalid sampling backend: {global_server_args_dict['sampling_backend']}"
+         )
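The PyTorch branch of top_p_normalize_probs above can be traced by hand on a tiny distribution. The numbers below are illustrative only; the snippet repeats the same tensor operations outside sglang so it can run on its own.

    import torch

    # One request with probabilities over a 4-token vocabulary and top_p = 0.6.
    probs = torch.tensor([[0.1, 0.4, 0.3, 0.2]])
    top_ps = torch.tensor([0.6])

    probs_sort, probs_idx = probs.sort(dim=-1, descending=True)      # [0.4, 0.3, 0.2, 0.1]
    probs_sum = torch.cumsum(probs_sort, dim=-1)                      # [0.4, 0.7, 0.9, 1.0]
    # Drop tokens whose preceding cumulative mass already exceeds top_p.
    probs_sort[(probs_sum - probs_sort) > top_ps.view(-1, 1)] = 0.0   # keeps 0.4 and 0.3
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))             # renormalize over the kept tokens
    renormed = torch.zeros_like(probs_sort).scatter_(-1, probs_idx, probs_sort)
    print(renormed)  # approximately [[0.0, 0.571, 0.429, 0.0]], back in the original token order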
sglang/srt/layers/torchao_utils.py
@@ -2,8 +2,14 @@
  Common utilities for torchao.
  """

+ import logging
+ import os
+ import pwd
+
  import torch

+ logger = logging.getLogger(__name__)
+

  def apply_torchao_config_to_model(
      model: torch.nn.Module, torchao_config: str, filter_fn=None
@@ -47,6 +53,31 @@ def apply_torchao_config_to_model(
              256,
          ], f"int4wo groupsize needs to be one of [32, 64, 128, 256] but got {group_size}"
          quantize_(model, int4_weight_only(group_size=group_size), filter_fn=filter_fn)
+     elif "gemlite" in torchao_config:
+         # gemlite-<packing_bitwidth>-<bit_width>-<group_size> or
+         # gemlite-<bit_width>-<group_size> (packing_bitwidth defaults to 32)
+         from gemlite.core import GemLiteLinearTriton
+         from torchao.quantization import gemlite_uintx_weight_only
+
+         _quant_args = torchao_config.split("-")
+         bit_width = int(_quant_args[-2])
+         group_size = None if _quant_args[-1] == "None" else int(_quant_args[-1])
+
+         try:
+             packing_bitwidth = int(_quant_args[-3])
+         except (ValueError, IndexError):
+             # if only 2 inputs found or conversion fails, use default value
+             packing_bitwidth = 32
+
+         quantize_(
+             model, gemlite_uintx_weight_only(group_size, bit_width, packing_bitwidth)
+         )
+
+         # try to load gemlite kernel config
+         GemLiteLinearTriton.load_config(
+             f"/tmp/{pwd.getpwuid(os.getuid()).pw_gecos}_gemlite.json"
+         )
+
      elif "fp8wo" in torchao_config:
          # this requires newer hardware
          # [rank0]: AssertionError: fp8e4nv data type is not supported on CUDA arch < 89
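The "gemlite-..." string format accepted above can be summarized with a small standalone parser. The helper name parse_gemlite_config is hypothetical and only mirrors the split logic in the branch above; it is not part of the package.

    def parse_gemlite_config(torchao_config: str):
        """Parse 'gemlite-<packing_bitwidth>-<bit_width>-<group_size>' or
        'gemlite-<bit_width>-<group_size>' (hypothetical helper mirroring the branch above)."""
        parts = torchao_config.split("-")
        bit_width = int(parts[-2])
        group_size = None if parts[-1] == "None" else int(parts[-1])
        try:
            packing_bitwidth = int(parts[-3])
        except (ValueError, IndexError):
            packing_bitwidth = 32  # default when only two numeric fields are given
        return packing_bitwidth, bit_width, group_size

    print(parse_gemlite_config("gemlite-4-64"))      # (32, 4, 64)
    print(parse_gemlite_config("gemlite-8-4-None"))  # (8, 4, None)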
sglang/srt/managers/detokenizer_manager.py
@@ -17,9 +17,10 @@ import dataclasses
  import logging
  import signal
  from collections import OrderedDict
- from typing import List, Union
+ from typing import Dict, List, Union

  import psutil
+ import setproctitle
  import zmq

  from sglang.srt.hf_transformers_utils import get_tokenizer
@@ -28,7 +29,6 @@ from sglang.srt.managers.io_struct import (
      BatchStrOut,
      BatchTokenIDOut,
  )
- from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR, FINISH_MATCHED_TOKEN
  from sglang.srt.server_args import PortArgs, ServerArgs
  from sglang.srt.utils import configure_logger, get_zmq_socket
  from sglang.utils import find_printable_text, get_exception_traceback
@@ -75,17 +75,25 @@ class DetokenizerManager:

          self.decode_status = LimitedCapacityDict()

-     def trim_eos(self, output: Union[str, List[int]], finished_reason, no_stop_trim):
-         if no_stop_trim:
+     def trim_matched_stop(
+         self, output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool
+     ):
+         if no_stop_trim or not finished_reason:
+             return output
+
+         matched = finished_reason.get("matched", None)
+         if not matched:
              return output

-         # Trim stop str. TODO(lmzheng): handle the case where multiple stop strs are hit
-         if isinstance(finished_reason, FINISH_MATCHED_STR) and isinstance(output, str):
-             pos = output.find(finished_reason.matched)
+         # TODO(lmzheng): handle the case where multiple stop strs are hit
+
+         # Trim stop str.
+         if isinstance(matched, str) and isinstance(output, str):
+             pos = output.find(matched)
              return output[:pos] if pos != -1 else output
-         if isinstance(finished_reason, FINISH_MATCHED_TOKEN) and isinstance(
-             output, list
-         ):
+
+         # Trim stop token.
+         if isinstance(matched, int) and isinstance(output, list):
              assert len(output) > 0
              return output[:-1]
          return output
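A short sketch of what the new dict-based trim_matched_stop does with the two kinds of "matched" values. This is a standalone reimplementation for illustration only, not the DetokenizerManager method itself.

    from typing import Dict, List, Union

    def trim_matched_stop(output: Union[str, List[int]], finished_reason: Dict, no_stop_trim: bool):
        # Same branching as DetokenizerManager.trim_matched_stop, minus `self`.
        if no_stop_trim or not finished_reason:
            return output
        matched = finished_reason.get("matched", None)
        if not matched:
            return output
        if isinstance(matched, str) and isinstance(output, str):
            pos = output.find(matched)
            return output[:pos] if pos != -1 else output
        if isinstance(matched, int) and isinstance(output, list):
            return output[:-1]
        return output

    # Stop string hit: everything from the matched string onward is trimmed.
    print(trim_matched_stop("Hello world<|eot|>", {"matched": "<|eot|>"}, no_stop_trim=False))  # 'Hello world'
    # Stop token id hit: the last token id is dropped.
    print(trim_matched_stop([11, 22, 33], {"matched": 33}, no_stop_trim=False))                 # [11, 22]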
@@ -124,9 +132,9 @@ class DetokenizerManager:
                  s.decode_ids = recv_obj.decode_ids[i]

              read_ids.append(
-                 self.trim_eos(
+                 self.trim_matched_stop(
                      s.decode_ids[s.surr_offset :],
-                     recv_obj.finished_reason[i],
+                     recv_obj.finished_reasons[i],
                      recv_obj.no_stop_trim[i],
                  )
              )
@@ -149,7 +157,7 @@ class DetokenizerManager:
          for i in range(bs):
              s = self.decode_status[recv_obj.rids[i]]
              new_text = read_texts[i][len(surr_texts[i]) :]
-             if recv_obj.finished_reason[i] is None:
+             if recv_obj.finished_reasons[i] is None:
                  # Streaming chunk: update the decode status
                  if len(new_text) > 0 and not new_text.endswith("�"):
                      s.decoded_text = s.decoded_text + new_text
@@ -160,9 +168,9 @@ class DetokenizerManager:
                  new_text = find_printable_text(new_text)

              output_strs.append(
-                 self.trim_eos(
+                 self.trim_matched_stop(
                      s.decoded_text + new_text,
-                     recv_obj.finished_reason[i],
+                     recv_obj.finished_reasons[i],
                      recv_obj.no_stop_trim[i],
                  )
              )
@@ -170,9 +178,20 @@ class DetokenizerManager:
          self.send_to_tokenizer.send_pyobj(
              BatchStrOut(
                  rids=recv_obj.rids,
+                 finished_reasons=recv_obj.finished_reasons,
                  output_strs=output_strs,
-                 meta_info=recv_obj.meta_info,
-                 finished_reason=recv_obj.finished_reason,
+                 prompt_tokens=recv_obj.prompt_tokens,
+                 completion_tokens=recv_obj.completion_tokens,
+                 cached_tokens=recv_obj.cached_tokens,
+                 input_token_logprobs_val=recv_obj.input_token_logprobs_val,
+                 input_token_logprobs_idx=recv_obj.input_token_logprobs_idx,
+                 output_token_logprobs_val=recv_obj.output_token_logprobs_val,
+                 output_token_logprobs_idx=recv_obj.output_token_logprobs_idx,
+                 input_top_logprobs_val=recv_obj.input_top_logprobs_val,
+                 input_top_logprobs_idx=recv_obj.input_top_logprobs_idx,
+                 output_top_logprobs_val=recv_obj.output_top_logprobs_val,
+                 output_top_logprobs_idx=recv_obj.output_top_logprobs_idx,
+                 normalized_prompt_logprob=recv_obj.normalized_prompt_logprob,
              )
          )

@@ -194,6 +213,7 @@ def run_detokenizer_process(
      server_args: ServerArgs,
      port_args: PortArgs,
  ):
+     setproctitle.setproctitle("sglang::detokenizer")
      configure_logger(server_args)
      parent_process = psutil.Process().parent()

sglang/srt/managers/io_struct.py
@@ -308,6 +308,9 @@ class TokenizedEmbeddingReqInput:
  class BatchTokenIDOut:
      # The request id
      rids: List[str]
+     # The finish reason
+     finished_reasons: List[BaseFinishReason]
+     # For incremental decoding
      # The version id to sync decode status with in detokenizer_manager
      vids: List[int]
      decoded_texts: List[str]
@@ -315,35 +318,61 @@ class BatchTokenIDOut:
      read_offsets: List[int]
      # Only used when `--skip-tokenizer-init`
      output_ids: Optional[List[int]]
+     # Detokenization configs
      skip_special_tokens: List[bool]
      spaces_between_special_tokens: List[bool]
-     meta_info: List[Dict]
-     finished_reason: List[BaseFinishReason]
      no_stop_trim: List[bool]
+     # Token counts
+     prompt_tokens: List[int]
+     completion_tokens: List[int]
+     cached_tokens: List[int]
+     # Logprobs
+     input_token_logprobs_val: List[float]
+     input_token_logprobs_idx: List[int]
+     output_token_logprobs_val: List[float]
+     output_token_logprobs_idx: List[int]
+     input_top_logprobs_val: List[List]
+     input_top_logprobs_idx: List[List]
+     output_top_logprobs_val: List[List]
+     output_top_logprobs_idx: List[List]
+     normalized_prompt_logprob: List[float]


  @dataclass
  class BatchStrOut:
      # The request id
      rids: List[str]
+     # The finish reason
+     finished_reasons: List[dict]
      # The output decoded strings
      output_strs: List[str]
-     # The meta info
-     meta_info: List[Dict]
-     # The finish reason
-     finished_reason: List[BaseFinishReason]
+
+     # Token counts
+     prompt_tokens: List[int]
+     completion_tokens: List[int]
+     cached_tokens: List[int]
+     # Logprobs
+     input_token_logprobs_val: List[float]
+     input_token_logprobs_idx: List[int]
+     output_token_logprobs_val: List[float]
+     output_token_logprobs_idx: List[int]
+     input_top_logprobs_val: List[List]
+     input_top_logprobs_idx: List[List]
+     output_top_logprobs_val: List[List]
+     output_top_logprobs_idx: List[List]
+     normalized_prompt_logprob: List[float]


  @dataclass
  class BatchEmbeddingOut:
      # The request id
      rids: List[str]
+     # The finish reason
+     finished_reasons: List[BaseFinishReason]
      # The output embedding
      embeddings: List[List[float]]
-     # The meta info
-     meta_info: List[Dict]
-     # The finish reason
-     finished_reason: List[BaseFinishReason]
+     # Token counts
+     prompt_tokens: List[int]


  @dataclass