sglang 0.4.0.post2__py3-none-any.whl → 0.4.1.post1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (53)
  1. sglang/bench_offline_throughput.py +0 -12
  2. sglang/bench_one_batch.py +0 -12
  3. sglang/bench_serving.py +11 -2
  4. sglang/lang/backend/openai.py +10 -0
  5. sglang/srt/aio_rwlock.py +100 -0
  6. sglang/srt/configs/model_config.py +8 -1
  7. sglang/srt/constrained/xgrammar_backend.py +6 -0
  8. sglang/srt/layers/attention/flashinfer_backend.py +49 -5
  9. sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -14
  10. sglang/srt/layers/linear.py +20 -2
  11. sglang/srt/layers/{ep_moe → moe/ep_moe}/layer.py +14 -39
  12. sglang/srt/layers/moe/fused_moe_native.py +46 -0
  13. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/__init__.py +3 -7
  14. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/fused_moe.py +124 -99
  15. sglang/srt/layers/{fused_moe_triton → moe/fused_moe_triton}/layer.py +16 -48
  16. sglang/srt/layers/moe/topk.py +205 -0
  17. sglang/srt/layers/quantization/__init__.py +3 -3
  18. sglang/srt/layers/quantization/fp8.py +169 -32
  19. sglang/srt/layers/quantization/fp8_kernel.py +292 -0
  20. sglang/srt/layers/quantization/fp8_utils.py +90 -1
  21. sglang/srt/layers/torchao_utils.py +11 -15
  22. sglang/srt/managers/schedule_batch.py +16 -10
  23. sglang/srt/managers/schedule_policy.py +1 -1
  24. sglang/srt/managers/scheduler.py +13 -16
  25. sglang/srt/managers/tokenizer_manager.py +130 -111
  26. sglang/srt/mem_cache/memory_pool.py +15 -8
  27. sglang/srt/model_executor/cuda_graph_runner.py +1 -1
  28. sglang/srt/model_loader/loader.py +22 -11
  29. sglang/srt/models/dbrx.py +1 -1
  30. sglang/srt/models/deepseek.py +1 -1
  31. sglang/srt/models/deepseek_v2.py +67 -18
  32. sglang/srt/models/gemma2.py +19 -0
  33. sglang/srt/models/grok.py +1 -1
  34. sglang/srt/models/llama.py +2 -2
  35. sglang/srt/models/mixtral.py +2 -2
  36. sglang/srt/models/olmoe.py +1 -1
  37. sglang/srt/models/qwen2_moe.py +1 -1
  38. sglang/srt/models/xverse_moe.py +1 -1
  39. sglang/srt/openai_api/adapter.py +23 -0
  40. sglang/srt/openai_api/protocol.py +2 -0
  41. sglang/srt/sampling/sampling_params.py +9 -2
  42. sglang/srt/server.py +21 -37
  43. sglang/srt/utils.py +33 -44
  44. sglang/test/test_block_fp8.py +341 -0
  45. sglang/version.py +1 -1
  46. {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/METADATA +4 -4
  47. {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/RECORD +52 -48
  48. sglang/srt/layers/fused_moe_patch.py +0 -133
  49. /sglang/srt/layers/{ep_moe → moe/ep_moe}/__init__.py +0 -0
  50. /sglang/srt/layers/{ep_moe → moe/ep_moe}/kernels.py +0 -0
  51. {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/LICENSE +0 -0
  52. {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/WHEEL +0 -0
  53. {sglang-0.4.0.post2.dist-info → sglang-0.4.1.post1.dist-info}/top_level.txt +0 -0
sglang/srt/layers/quantization/fp8_kernel.py (new file)
@@ -0,0 +1,292 @@
+ # Copyright 2024 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+
+ from typing import List, Tuple
+
+ import torch
+ import triton
+ import triton.language as tl
+
+
+ @triton.jit
+ def _per_token_group_quant_fp8(
+     # Pointers to inputs and output
+     y_ptr,
+     y_q_ptr,
+     y_s_ptr,
+     # Stride of input
+     y_stride,
+     # Columns of input
+     N,
+     # Avoid dividing by zero
+     eps,
+     # Information for float8
+     fp8_min,
+     fp8_max,
+     # Meta-parameters
+     BLOCK: tl.constexpr,
+ ):
+     """A Triton-accelerated function to perform per-token-group quantization on a
+     tensor.
+
+     This function converts the tensor values into float8 values.
+     """
+     # Map the program id to the row of X and Y it should compute.
+     g_id = tl.program_id(0)
+     y_ptr += g_id * y_stride
+     y_q_ptr += g_id * y_stride
+     y_s_ptr += g_id
+
+     cols = tl.arange(0, BLOCK)  # N <= BLOCK
+     mask = cols < N
+
+     y = tl.load(y_ptr + cols, mask=mask, other=0.0).to(tl.float32)
+     # Quant
+     _absmax = tl.maximum(tl.max(tl.abs(y)), eps)
+     y_s = _absmax / fp8_max
+     y_q = tl.clamp(y / y_s, fp8_min, fp8_max).to(y_q_ptr.dtype.element_ty)
+
+     tl.store(y_q_ptr + cols, y_q, mask=mask)
+     tl.store(y_s_ptr, y_s)
+
+
+ def per_token_group_quant_fp8(
+     x: torch.Tensor,
+     group_size: int,
+     eps: float = 1e-10,
+     dtype: torch.dtype = torch.float8_e4m3fn,
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     """Function to perform per-token-group quantization on an input tensor `x`.
+
+     It converts the tensor values into signed float8 values and returns the
+     quantized tensor along with the scaling factor used for quantization.
+
+     Args:
+         x: The input tensor with ndim >= 2.
+         group_size: The group size used for quantization.
+         eps: The minimum value to avoid division by zero.
+         dtype: The dtype of the output tensor. Note that only `torch.float8_e4m3fn` is supported for now.
+
+     Returns:
+         Tuple[torch.Tensor, torch.Tensor]: The quantized tensor and the scaling factor for quantization.
+     """
+     assert (
+         x.shape[-1] % group_size == 0
+     ), "the last dimension of `x` must be divisible by `group_size`"
+     assert x.is_contiguous(), "`x` is not contiguous"
+
+     finfo = torch.finfo(dtype)
+     fp8_min = finfo.min
+     fp8_max = finfo.max
+
+     x_q = torch.empty_like(x, device=x.device, dtype=dtype)
+     M = x.numel() // group_size
+     N = group_size
+     x_s = torch.empty(
+         x.shape[:-1] + (x.shape[-1] // group_size,),
+         device=x.device,
+         dtype=torch.float32,
+     )
+
+     BLOCK = triton.next_power_of_2(N)
+     # heuristics for number of warps
+     num_warps = min(max(BLOCK // 256, 1), 8)
+     num_stages = 1
+     _per_token_group_quant_fp8[(M,)](
+         x,
+         x_q,
+         x_s,
+         group_size,
+         N,
+         eps,
+         fp8_min=fp8_min,
+         fp8_max=fp8_max,
+         BLOCK=BLOCK,
+         num_warps=num_warps,
+         num_stages=num_stages,
+     )
+
+     return x_q, x_s
+
+
+ @triton.jit
+ def _w8a8_block_fp8_matmul(
+     # Pointers to inputs and output
+     A,
+     B,
+     C,
+     As,
+     Bs,
+     # Shape for matmul
+     M,
+     N,
+     K,
+     # Block size for block-wise quantization
+     group_n,
+     group_k,
+     # Stride for inputs and output
+     stride_am,
+     stride_ak,
+     stride_bk,
+     stride_bn,
+     stride_cm,
+     stride_cn,
+     stride_As_m,
+     stride_As_k,
+     stride_Bs_k,
+     stride_Bs_n,
+     # Meta-parameters
+     BLOCK_SIZE_M: tl.constexpr,
+     BLOCK_SIZE_N: tl.constexpr,
+     BLOCK_SIZE_K: tl.constexpr,
+     GROUP_SIZE_M: tl.constexpr,
+ ):
+     """Triton-accelerated function used to perform linear operations (dot
+     product) on input tensors `A` and `B` with block-wise quantization, and store the result in output
+     tensor `C`.
+     """
+
+     pid = tl.program_id(axis=0)
+     num_pid_m = tl.cdiv(M, BLOCK_SIZE_M)
+     num_pid_n = tl.cdiv(N, BLOCK_SIZE_N)
+     num_pid_in_group = GROUP_SIZE_M * num_pid_n
+     group_id = pid // num_pid_in_group
+     first_pid_m = group_id * GROUP_SIZE_M
+     group_size_m = min(num_pid_m - first_pid_m, GROUP_SIZE_M)
+     pid_m = first_pid_m + (pid % group_size_m)
+     pid_n = (pid % num_pid_in_group) // group_size_m
+
+     offs_am = (pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)) % M
+     offs_bn = (pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)) % N
+     offs_k = tl.arange(0, BLOCK_SIZE_K)
+     a_ptrs = A + (offs_am[:, None] * stride_am + offs_k[None, :] * stride_ak)
+     b_ptrs = B + (offs_k[:, None] * stride_bk + offs_bn[None, :] * stride_bn)
+
+     As_ptrs = As + offs_am * stride_As_m
+     offs_bsn = offs_bn // group_n
+     Bs_ptrs = Bs + offs_bsn * stride_Bs_n
+
+     accumulator = tl.zeros((BLOCK_SIZE_M, BLOCK_SIZE_N), dtype=tl.float32)
+     for k in range(0, tl.cdiv(K, BLOCK_SIZE_K)):
+         a = tl.load(a_ptrs, mask=offs_k[None, :] < K - k * BLOCK_SIZE_K, other=0.0)
+         b = tl.load(b_ptrs, mask=offs_k[:, None] < K - k * BLOCK_SIZE_K, other=0.0)
+
+         k_start = k * BLOCK_SIZE_K
+         offs_ks = k_start // group_k
+         a_s = tl.load(As_ptrs + offs_ks * stride_As_k)
+         b_s = tl.load(Bs_ptrs + offs_ks * stride_Bs_k)
+
+         accumulator += tl.dot(a, b) * a_s[:, None] * b_s[None, :]
+         a_ptrs += BLOCK_SIZE_K * stride_ak
+         b_ptrs += BLOCK_SIZE_K * stride_bk
+
+     if C.dtype.element_ty == tl.bfloat16:
+         c = accumulator.to(tl.bfloat16)
+     elif C.dtype.element_ty == tl.float16:
+         c = accumulator.to(tl.float16)
+     else:
+         c = accumulator.to(tl.float32)
+
+     offs_cm = pid_m * BLOCK_SIZE_M + tl.arange(0, BLOCK_SIZE_M)
+     offs_cn = pid_n * BLOCK_SIZE_N + tl.arange(0, BLOCK_SIZE_N)
+     c_ptrs = C + stride_cm * offs_cm[:, None] + stride_cn * offs_cn[None, :]
+     c_mask = (offs_cm[:, None] < M) & (offs_cn[None, :] < N)
+     tl.store(c_ptrs, c, mask=c_mask)
+
+
+ def w8a8_block_fp8_matmul(
+     A: torch.Tensor,
+     B: torch.Tensor,
+     As: torch.Tensor,
+     Bs: torch.Tensor,
+     block_size: List[int],
+     output_dtype: torch.dtype = torch.float16,
+ ) -> torch.Tensor:
+     """This function performs matrix multiplication with block-wise quantization.
+
+     It takes two input tensors `A` and `B` with scales `As` and `Bs`.
+     The output is returned in the specified `output_dtype`.
+
+     Args:
+         A: The input tensor, e.g., activation.
+         B: The input tensor, e.g., weight.
+         As: The per-token-group quantization scale for `A`.
+         Bs: The per-block quantization scale for `B`.
+         block_size: The block size for per-block quantization. It should be 2-dim, e.g., [128, 128].
+         output_dtype: The dtype of the returned tensor.
+
+     Returns:
+         torch.Tensor: The result of matmul.
+     """
+     assert len(block_size) == 2
+     block_n, block_k = block_size[0], block_size[1]
+
+     assert A.shape[-1] == B.shape[-1]
+     assert A.shape[:-1] == As.shape[:-1] and A.is_contiguous()
+     assert triton.cdiv(A.shape[-1], block_k) == As.shape[-1]
+     M = A.numel() // A.shape[-1]
+
+     assert B.ndim == 2 and B.is_contiguous() and Bs.ndim == 2
+     N, K = B.shape
+     assert triton.cdiv(N, block_n) == Bs.shape[0]
+     assert triton.cdiv(K, block_k) == Bs.shape[1]
+
+     C_shape = A.shape[:-1] + (N,)
+     C = A.new_empty(C_shape, dtype=output_dtype)
+
+     # TODO(HandH1998):
+     # BLOCK_SIZE_M, BLOCK_SIZE_K, BLOCK_SIZE_N can be optimized.
+     # BLOCK_SIZE_K must be divisible by block_k
+     # BLOCK_SIZE_N and BLOCK_SIZE_M have no requirements
+     BLOCK_SIZE_M = 128
+     if M < BLOCK_SIZE_M:
+         BLOCK_SIZE_M = triton.next_power_of_2(M)
+         BLOCK_SIZE_M = max(BLOCK_SIZE_M, 16)
+     BLOCK_SIZE_K = block_k
+     assert block_k % BLOCK_SIZE_K == 0
+     BLOCK_SIZE_N = block_n
+
+     def grid(META):
+         return (
+             triton.cdiv(M, META["BLOCK_SIZE_M"]) * triton.cdiv(N, META["BLOCK_SIZE_N"]),
+         )
+
+     _w8a8_block_fp8_matmul[grid](
+         A,
+         B,
+         C,
+         As,
+         Bs,
+         M,
+         N,
+         K,
+         block_n,
+         block_k,
+         A.stride(-2),
+         A.stride(-1),
+         B.stride(1),
+         B.stride(0),
+         C.stride(-2),
+         C.stride(-1),
+         As.stride(-2),
+         As.stride(-1),
+         Bs.stride(1),
+         Bs.stride(0),
+         BLOCK_SIZE_M=BLOCK_SIZE_M,
+         BLOCK_SIZE_N=BLOCK_SIZE_N,
+         BLOCK_SIZE_K=BLOCK_SIZE_K,
+         GROUP_SIZE_M=8,
+     )
+
+     return C
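Note: the following is a minimal usage sketch of the two public functions added above, not part of the diff. It assumes a CUDA GPU with float8_e4m3fn support and Triton installed; the weight and its scales are random placeholders rather than real quantized parameters.

    import torch

    from sglang.srt.layers.quantization.fp8_kernel import (
        per_token_group_quant_fp8,
        w8a8_block_fp8_matmul,
    )

    M, K, N = 4, 512, 1024              # 4 tokens, hidden size 512, output size 1024
    block_n, block_k = 128, 128         # block-wise quantization granularity

    x = torch.randn(M, K, device="cuda", dtype=torch.float16)

    # One fp8 value per element, one float32 scale per group of `block_k` elements.
    x_q, x_s = per_token_group_quant_fp8(x, group_size=block_k)
    # x_q: (4, 512) float8_e4m3fn, x_s: (4, 4) float32

    # Toy block-quantized weight: (N, K) fp8 plus one scale per (block_n, block_k) tile.
    w_q = torch.randn(N, K, device="cuda").to(torch.float8_e4m3fn)
    w_s = torch.rand(N // block_n, K // block_k, device="cuda", dtype=torch.float32)

    y = w8a8_block_fp8_matmul(
        x_q, w_q, x_s, w_s, block_size=[block_n, block_k], output_dtype=torch.float16
    )
    print(y.shape)  # torch.Size([4, 1024])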
sglang/srt/layers/quantization/fp8_utils.py
@@ -1,6 +1,12 @@
- from typing import Optional, Tuple
+ from typing import List, Optional, Tuple

  import torch
+ from vllm.model_executor.parameter import RowvLLMParameter, _ColumnvLLMParameter
+
+ from sglang.srt.layers.quantization.fp8_kernel import (
+     per_token_group_quant_fp8,
+     w8a8_block_fp8_matmul,
+ )


  def normalize_e4m3fn_to_e4m3fnuz(
@@ -25,3 +31,86 @@ def normalize_e4m3fn_to_e4m3fnuz(
      if input_scale is not None:
          input_scale = input_scale * 2.0
      return weight, weight_scale, input_scale
+
+
+ def apply_w8a8_block_fp8_linear(
+     input: torch.Tensor,
+     weight: torch.Tensor,
+     block_size: List[int],
+     weight_scale: torch.Tensor,
+     input_scale: Optional[torch.Tensor] = None,
+     bias: Optional[torch.Tensor] = None,
+ ) -> torch.Tensor:
+     assert input_scale is None
+     # View input as 2D matrix for fp8 methods
+     input_2d = input.view(-1, input.shape[-1])
+     output_shape = [*input.shape[:-1], weight.shape[0]]
+
+     q_input, x_scale = per_token_group_quant_fp8(input_2d, block_size[1])
+     output = w8a8_block_fp8_matmul(
+         q_input, weight, x_scale, weight_scale, block_size, output_dtype=input.dtype
+     )
+
+     if bias is not None:
+         output = output + bias
+     return output.to(dtype=input.dtype).view(*output_shape)
+
+
+ def input_to_float8(
+     x: torch.Tensor, dtype: torch.dtype = torch.float8_e4m3fn
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     """This function quantizes input values to float8 values with tensor-wise quantization."""
+     finfo = torch.finfo(dtype)
+     min_val, max_val = x.aminmax()
+     amax = torch.maximum(min_val.abs(), max_val.abs()).clamp(min=1e-12)
+     scale = finfo.max / amax
+     x_scl_sat = (x * scale).clamp(min=finfo.min, max=finfo.max)
+     return x_scl_sat.to(dtype).contiguous(), scale.float().reciprocal()
+
+
+ def block_quant_to_tensor_quant(
+     x_q_block: torch.Tensor,
+     x_s: torch.Tensor,
+     block_size: List[int],
+ ) -> Tuple[torch.Tensor, torch.Tensor]:
+     """This function converts block-wise quantization to tensor-wise quantization.
+     The inputs are the block-wise quantized tensor `x_q_block`, the block-wise
+     quantization scale, and the block size.
+     The outputs are the tensor-wise quantized tensor and the tensor-wise quantization scale.
+     Note that only float8 is supported for now.
+     """
+     block_n, block_k = block_size[0], block_size[1]
+     n, k = x_q_block.shape
+     n_tiles = (n + block_n - 1) // block_n
+     k_tiles = (k + block_k - 1) // block_k
+     assert n_tiles == x_s.shape[0]
+     assert k_tiles == x_s.shape[1]
+
+     x_dq_block = x_q_block.to(torch.float32)
+
+     x_dq_block_tiles = [
+         [
+             x_dq_block[
+                 j * block_n : min((j + 1) * block_n, n),
+                 i * block_k : min((i + 1) * block_k, k),
+             ]
+             for i in range(k_tiles)
+         ]
+         for j in range(n_tiles)
+     ]
+
+     for i in range(k_tiles):
+         for j in range(n_tiles):
+             x_dq_block_tiles[j][i][:, :] = x_dq_block_tiles[j][i] * x_s[j][i]
+
+     x_q_tensor, scale = input_to_float8(x_dq_block, dtype=x_q_block.dtype)
+     return x_q_tensor, scale
+
+
+ class BlockQuantScaleParameter(_ColumnvLLMParameter, RowvLLMParameter):
+     """
+     Parameter class for weight scales loaded for weights with
+     block-wise quantization. Uses both column and row parallelism.
+     """
+
+     pass
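Note: a minimal sketch of calling the new block-FP8 helpers directly, not part of the diff. In sglang these are invoked by the FP8 quantization method inside linear layers rather than by user code; the shapes, random weight, and random scales below are illustrative placeholders, and a CUDA GPU with float8_e4m3fn support plus Triton is assumed.

    import torch

    from sglang.srt.layers.quantization.fp8_utils import (
        apply_w8a8_block_fp8_linear,
        block_quant_to_tensor_quant,
        input_to_float8,
    )

    block_size = [128, 128]                       # [block_n, block_k]
    x = torch.randn(2, 8, 512, device="cuda", dtype=torch.bfloat16)

    # Toy block-quantized weight: (N, K) fp8 with one float32 scale per 128x128 tile.
    w_q = torch.randn(1024, 512, device="cuda").to(torch.float8_e4m3fn)
    w_s = torch.rand(1024 // 128, 512 // 128, device="cuda", dtype=torch.float32)

    y = apply_w8a8_block_fp8_linear(x, w_q, block_size, w_s)
    print(y.shape)                                # torch.Size([2, 8, 1024])

    # Tensor-wise helpers: quantize a tensor to fp8, or collapse block scales
    # into a single tensor-wide scale.
    x_fp8, x_scale = input_to_float8(torch.randn(16, 16, device="cuda"))
    w_tensor_q, w_tensor_s = block_quant_to_tensor_quant(w_q, w_s, block_size)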
sglang/srt/layers/torchao_utils.py
@@ -2,8 +2,14 @@
  Common utilities for torchao.
  """

+ import logging
+ import os
+ import pwd
+
  import torch

+ logger = logging.getLogger(__name__)
+

  def apply_torchao_config_to_model(
      model: torch.nn.Module, torchao_config: str, filter_fn=None
@@ -50,27 +56,17 @@ def apply_torchao_config_to_model(
      elif "gemlite" in torchao_config:
          # gemlite-<packing_bitwidth>-<bit_width>-<group_size> or
          # gemlite-<bit_width>-<group_size> (packing_bitwidth defaults to 32)
-         import os
-         import pwd
-
-         import gemlite
-         from gemlite.core import GemLiteLinearTriton, set_autotune
-
-         try:
-             from torchao.quantization import gemlite_uintx_weight_only
-         except:
-             print(
-                 f"import `gemlite_uintx_weight_only` failed, please use torchao nightly to use gemlite quantization"
-             )
-             return model
+         from gemlite.core import GemLiteLinearTriton
+         from torchao.quantization import gemlite_uintx_weight_only

          _quant_args = torchao_config.split("-")
          bit_width = int(_quant_args[-2])
          group_size = None if _quant_args[-1] == "None" else int(_quant_args[-1])
+
          try:
              packing_bitwidth = int(_quant_args[-3])
-         except:
-             # if only 2 inputs found, use default value
+         except (ValueError, IndexError):
+             # if only 2 inputs found or conversion fails, use default value
              packing_bitwidth = 32

          quantize_(
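For reference, the gemlite config string handled above decomposes as follows. This is a standalone sketch mirroring the parsing in the diff; the concrete config values are just examples.

    # Standalone sketch of how apply_torchao_config_to_model splits a gemlite config string.
    for torchao_config in ("gemlite-4-64", "gemlite-8-4-64"):
        _quant_args = torchao_config.split("-")
        bit_width = int(_quant_args[-2])
        group_size = None if _quant_args[-1] == "None" else int(_quant_args[-1])
        try:
            packing_bitwidth = int(_quant_args[-3])
        except (ValueError, IndexError):
            # only <bit_width>-<group_size> were given; fall back to the default
            packing_bitwidth = 32
        print(torchao_config, packing_bitwidth, bit_width, group_size)
        # gemlite-4-64   -> packing_bitwidth=32, bit_width=4, group_size=64
        # gemlite-8-4-64 -> packing_bitwidth=8,  bit_width=4, group_size=64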
sglang/srt/managers/schedule_batch.py
@@ -479,8 +479,22 @@ class Req:

          return True

+     def reset_for_retract(self):
+         self.prefix_indices = []
+         self.last_node = None
+         self.extend_input_len = 0
+         self.is_retracted = True
+
+         # For incremental logprobs
+         # TODO: Fix the `logprob_start_len`
+         self.last_update_decode_tokens = 0
+         self.logprob_start_len = 10**9
+
      def __repr__(self):
-         return f"rid(n={self.rid}, " f"input_ids={self.origin_input_ids}, "
+         return (
+             f"rid(n={self.rid}, "
+             f"input_ids={self.origin_input_ids}, output_ids={self.output_ids}"
+         )


  bid = 0
@@ -894,15 +908,7 @@ class ScheduleBatch:
              )
              residual_size = max(0, residual_size)
              self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)
-
-             req.prefix_indices = []
-             req.last_node = None
-             req.extend_input_len = 0
-             req.is_retracted = True
-
-             # For incremental logprobs
-             req.last_update_decode_tokens = 0
-             req.logprob_start_len = 10**9
+             req.reset_for_retract()

          self.filter_batch(keep_indices=sorted_indices)

sglang/srt/managers/schedule_policy.py
@@ -248,7 +248,7 @@ class PrefillAdder:
          self.can_run_list.append(req)

          self._prefill_one_req(
-             len(req.prefix_indices),
+             0,
              req.extend_input_len,
              (
                  min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION)
sglang/srt/managers/scheduler.py
@@ -22,7 +22,7 @@ import warnings
  from collections import deque
  from concurrent import futures
  from types import SimpleNamespace
- from typing import List, Optional
+ from typing import Callable, Dict, List, Optional, Tuple

  import psutil
  import setproctitle
@@ -260,7 +260,7 @@ class Scheduler:
          self.current_stream = torch.get_device_module(self.device).current_stream()

          # Session info
-         self.sessions = {}
+         self.sessions: Dict[str, Session] = {}

          # Init chunked prefill
          self.chunked_prefill_size = server_args.chunked_prefill_size
@@ -468,9 +468,6 @@ class Scheduler:
                  self.send_to_tokenizer.send_pyobj(
                      UpdateWeightFromDiskReqOutput(success, message)
                  )
-             elif isinstance(recv_req, GetWeightsByNameReqInput):
-                 parameter = self.get_weights_by_name(recv_req)
-                 self.send_to_tokenizer.send_pyobj(GetWeightsByNameReqOutput(parameter))
              elif isinstance(recv_req, InitWeightsUpdateGroupReqInput):
                  success, message = self.init_weights_update_group(recv_req)
                  self.send_to_tokenizer.send_pyobj(
@@ -565,7 +562,7 @@ class Scheduler:

          if req.logprob_start_len == -1:
              # By default, only return the logprobs for output tokens
-             req.logprob_start_len = len(recv_req.input_ids) - 1
+             req.logprob_start_len = len(req.origin_input_ids) - 1

          # Truncate prompts that are too long
          if len(req.origin_input_ids) > self.max_req_input_len:
@@ -589,12 +586,15 @@ class Scheduler:
          if (
              req.sampling_params.json_schema is not None
              or req.sampling_params.regex is not None
+             or req.sampling_params.ebnf is not None
          ):
              assert self.grammar_backend is not None
              if req.sampling_params.json_schema is not None:
                  key = ("json", req.sampling_params.json_schema)
              elif req.sampling_params.regex is not None:
                  key = ("regex", req.sampling_params.regex)
+             elif req.sampling_params.ebnf is not None:
+                 key = ("ebnf", req.sampling_params.ebnf)

              req.grammar = self.grammar_backend.get_cached_value(key)
              if not req.grammar:
@@ -629,16 +629,13 @@ class Scheduler:
          self.waiting_queue.append(req)

      def log_prefill_stats(self, adder, can_run_list, running_bs, has_being_chunked):
-         if isinstance(self.tree_cache, RadixCache):
-             self.tree_cache_metrics["total"] += (
-                 adder.log_input_tokens + adder.log_hit_tokens
-             ) / 10**9
-             self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9
-             tree_cache_hit_rate = (
-                 self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
-             )
-         else:
-             tree_cache_hit_rate = 0.0
+         self.tree_cache_metrics["total"] += (
+             adder.log_input_tokens + adder.log_hit_tokens
+         ) / 10**9
+         self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9
+         tree_cache_hit_rate = (
+             self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
+         )

          num_used = self.max_total_num_tokens - (
              self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()