sglang 0.2.11__py3-none-any.whl → 0.2.13__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (85)
  1. sglang/api.py +7 -1
  2. sglang/bench_latency.py +9 -6
  3. sglang/bench_serving.py +46 -22
  4. sglang/global_config.py +1 -1
  5. sglang/lang/backend/runtime_endpoint.py +60 -49
  6. sglang/lang/compiler.py +2 -2
  7. sglang/lang/interpreter.py +4 -2
  8. sglang/lang/ir.py +16 -7
  9. sglang/srt/constrained/base_tool_cache.py +1 -1
  10. sglang/srt/constrained/fsm_cache.py +12 -2
  11. sglang/srt/constrained/jump_forward.py +13 -2
  12. sglang/srt/layers/activation.py +32 -0
  13. sglang/srt/layers/{token_attention.py → decode_attention.py} +9 -5
  14. sglang/srt/layers/extend_attention.py +9 -2
  15. sglang/srt/layers/fused_moe/__init__.py +1 -0
  16. sglang/srt/layers/{fused_moe.py → fused_moe/fused_moe.py} +165 -108
  17. sglang/srt/layers/fused_moe/layer.py +587 -0
  18. sglang/srt/layers/layernorm.py +65 -0
  19. sglang/srt/layers/logits_processor.py +7 -2
  20. sglang/srt/layers/pooler.py +50 -0
  21. sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py} +5 -0
  22. sglang/srt/layers/radix_attention.py +40 -16
  23. sglang/srt/managers/detokenizer_manager.py +31 -9
  24. sglang/srt/managers/io_struct.py +63 -0
  25. sglang/srt/managers/policy_scheduler.py +173 -25
  26. sglang/srt/managers/schedule_batch.py +115 -97
  27. sglang/srt/managers/tokenizer_manager.py +194 -112
  28. sglang/srt/managers/tp_worker.py +290 -359
  29. sglang/srt/mem_cache/{base_cache.py → base_prefix_cache.py} +9 -4
  30. sglang/srt/mem_cache/chunk_cache.py +43 -20
  31. sglang/srt/mem_cache/memory_pool.py +2 -2
  32. sglang/srt/mem_cache/radix_cache.py +74 -40
  33. sglang/srt/model_executor/cuda_graph_runner.py +71 -25
  34. sglang/srt/model_executor/forward_batch_info.py +293 -156
  35. sglang/srt/model_executor/model_runner.py +77 -57
  36. sglang/srt/models/chatglm.py +2 -2
  37. sglang/srt/models/commandr.py +1 -1
  38. sglang/srt/models/deepseek.py +2 -2
  39. sglang/srt/models/deepseek_v2.py +7 -6
  40. sglang/srt/models/gemma.py +1 -1
  41. sglang/srt/models/gemma2.py +11 -6
  42. sglang/srt/models/grok.py +50 -396
  43. sglang/srt/models/internlm2.py +2 -7
  44. sglang/srt/models/llama2.py +4 -4
  45. sglang/srt/models/llama_embedding.py +88 -0
  46. sglang/srt/models/minicpm.py +2 -2
  47. sglang/srt/models/mixtral.py +56 -254
  48. sglang/srt/models/mixtral_quant.py +1 -4
  49. sglang/srt/models/qwen.py +2 -2
  50. sglang/srt/models/qwen2.py +2 -2
  51. sglang/srt/models/qwen2_moe.py +2 -13
  52. sglang/srt/models/stablelm.py +1 -1
  53. sglang/srt/openai_api/adapter.py +187 -48
  54. sglang/srt/openai_api/protocol.py +37 -1
  55. sglang/srt/sampling/penaltylib/__init__.py +13 -0
  56. sglang/srt/sampling/penaltylib/orchestrator.py +357 -0
  57. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +80 -0
  58. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +105 -0
  59. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +79 -0
  60. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +83 -0
  61. sglang/srt/sampling_params.py +31 -8
  62. sglang/srt/server.py +91 -29
  63. sglang/srt/server_args.py +32 -19
  64. sglang/srt/utils.py +32 -15
  65. sglang/test/run_eval.py +10 -1
  66. sglang/test/runners.py +81 -73
  67. sglang/test/simple_eval_humaneval.py +2 -8
  68. sglang/test/simple_eval_mgsm.py +203 -0
  69. sglang/test/srt/sampling/penaltylib/utils.py +337 -0
  70. sglang/test/test_layernorm.py +60 -0
  71. sglang/test/test_programs.py +36 -7
  72. sglang/test/test_utils.py +24 -2
  73. sglang/utils.py +0 -1
  74. sglang/version.py +1 -1
  75. {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/METADATA +33 -16
  76. sglang-0.2.13.dist-info/RECORD +112 -0
  77. {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/WHEEL +1 -1
  78. sglang/srt/layers/linear.py +0 -884
  79. sglang/srt/layers/quantization/__init__.py +0 -64
  80. sglang/srt/layers/quantization/fp8.py +0 -677
  81. sglang/srt/model_loader/model_loader.py +0 -292
  82. sglang/srt/model_loader/utils.py +0 -275
  83. sglang-0.2.11.dist-info/RECORD +0 -102
  84. {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/LICENSE +0 -0
  85. {sglang-0.2.11.dist-info → sglang-0.2.13.dist-info}/top_level.txt +0 -0

sglang/srt/layers/pooler.py (new file)
@@ -0,0 +1,50 @@
+ # adapted from
+ # https://github.com/vllm-project/vllm/blob/82a1b1a82b1fbb454c82a9ef95730b929c9b270c/vllm/model_executor/layers/pooler.py
+
+ from dataclasses import dataclass
+ from enum import IntEnum
+
+ import torch
+ import torch.nn as nn
+
+ from sglang.srt.model_executor.model_runner import InputMetadata
+
+
+ class PoolingType(IntEnum):
+     LAST = 0
+
+
+ @dataclass
+ class EmbeddingPoolerOutput:
+     embeddings: torch.Tensor
+
+
+ class Pooler(nn.Module):
+     """A layer that pools specific information from hidden states.
+     This layer does the following:
+     1. Extracts specific tokens or aggregates data based on pooling method.
+     2. Normalizes output if specified.
+     3. Returns structured results as `PoolerOutput`.
+     Attributes:
+         pooling_type: The type of pooling to use (LAST, AVERAGE, MAX).
+         normalize: Whether to normalize the pooled data.
+     """
+
+     def __init__(self, pooling_type: PoolingType, normalize: bool):
+         super().__init__()
+         self.pooling_type = pooling_type
+         self.normalize = normalize
+
+     def forward(
+         self, hidden_states: torch.Tensor, input_metadata: InputMetadata
+     ) -> EmbeddingPoolerOutput:
+         if self.pooling_type == PoolingType.LAST:
+             last_token_indices = torch.cumsum(input_metadata.extend_seq_lens, dim=0) - 1
+             pooled_data = hidden_states[last_token_indices]
+         else:
+             raise ValueError(f"Invalid pooling type: {self.pooling_type}")
+
+         if self.normalize:
+             pooled_data = nn.functional.normalize(pooled_data, p=2, dim=1)
+
+         return EmbeddingPoolerOutput(embeddings=pooled_data)
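
The LAST pooling above reduces a batch of packed sequences to one vector per sequence. Below is a minimal standalone sketch of the same computation, written without importing sglang so it runs anywhere; the sequence lengths and hidden size are made-up values.

# Standalone sketch of what Pooler(pooling_type=PoolingType.LAST, normalize=True)
# computes; it mirrors the forward() above with illustrative shapes.
import torch
import torch.nn as nn

# Two packed sequences of lengths 3 and 2 produce 5 hidden vectors of width 8.
extend_seq_lens = torch.tensor([3, 2])
hidden_states = torch.randn(5, 8)

# LAST pooling: index the hidden state of each sequence's final token.
last_token_indices = torch.cumsum(extend_seq_lens, dim=0) - 1  # tensor([2, 4])
pooled = hidden_states[last_token_indices]

# Optional L2 normalization, applied when normalize=True.
pooled = nn.functional.normalize(pooled, p=2, dim=1)
print(pooled.shape)  # torch.Size([2, 8])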

sglang/srt/layers/prefill_attention.py (renamed from context_flashattention_nopad.py)
@@ -13,6 +13,11 @@ See the License for the specific language governing permissions and
  limitations under the License.
  """

+ """
+ Memory-efficient attention for prefill.
+ It supports page size = 1.
+ """
+
  # Adapted from
  # https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1
  import torch

sglang/srt/layers/radix_attention.py
@@ -15,13 +15,15 @@ limitations under the License.

  """Radix attention."""

+ from typing import Optional
+
  import torch
  from flashinfer.cascade import merge_state
  from torch import nn

  from sglang.global_config import global_config
+ from sglang.srt.layers.decode_attention import decode_attention_fwd
  from sglang.srt.layers.extend_attention import extend_attention_fwd
- from sglang.srt.layers.token_attention import token_attention_fwd
  from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
  from sglang.srt.model_executor.model_runner import global_server_args_dict

@@ -34,6 +36,7 @@ class RadixAttention(nn.Module):
          scaling: float,
          num_kv_heads: int,
          layer_id: int,
+         sliding_window_size: Optional[int] = None,
          logit_cap: int = -1,
          v_head_dim: int = -1,
      ):
@@ -46,6 +49,7 @@ class RadixAttention(nn.Module):
          self.v_head_dim = v_head_dim if v_head_dim != -1 else head_dim
          self.scaling = scaling
          self.layer_id = layer_id
+         self.sliding_window_size = sliding_window_size if sliding_window_size else -1

          if (
              not global_server_args_dict.get("disable_flashinfer", False)
@@ -95,7 +99,7 @@ class RadixAttention(nn.Module):
          o = torch.empty_like(q)
          self.store_kv_cache(k, v, input_metadata)

-         token_attention_fwd(
+         decode_attention_fwd(
              q.view(-1, self.tp_q_head_num, self.qk_head_dim),
              input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
              input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
@@ -113,14 +117,25 @@ class RadixAttention(nn.Module):
          return o

      def extend_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
+         # Using two wrappers is unnecessary in the current PR, but they are prepared for future PRs.
+         prefill_wrapper_paged = input_metadata.flashinfer_prefill_wrapper_paged
+         if self.sliding_window_size != -1:
+             prefill_wrapper_paged = prefill_wrapper_paged[0]
+         else:
+             if isinstance(prefill_wrapper_paged, list):
+                 prefill_wrapper_paged = prefill_wrapper_paged[1]
+
          if not input_metadata.flashinfer_use_ragged:
-             self.store_kv_cache(k, v, input_metadata)
+             if k is not None:
+                 assert v is not None
+                 self.store_kv_cache(k, v, input_metadata)

-             o = input_metadata.flashinfer_prefill_wrapper_paged.forward(
+             o = prefill_wrapper_paged.forward(
                  q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
                  input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
                  causal=True,
                  sm_scale=self.scaling,
+                 window_left=self.sliding_window_size,
                  logits_soft_cap=self.logit_cap,
              )
          else:
@@ -138,14 +153,12 @@ class RadixAttention(nn.Module):
              if input_metadata.extend_no_prefix:
                  o = o1
              else:
-                 o2, s2 = (
-                     input_metadata.flashinfer_prefill_wrapper_paged.forward_return_lse(
-                         q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
-                         input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
-                         causal=False,
-                         sm_scale=self.scaling,
-                         logits_soft_cap=self.logit_cap,
-                     )
+                 o2, s2 = prefill_wrapper_paged.forward_return_lse(
+                     q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
+                     input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
+                     causal=False,
+                     sm_scale=self.scaling,
+                     logits_soft_cap=self.logit_cap,
                  )

              o, _ = merge_state(o1, s1, o2, s2)
@@ -158,9 +171,18 @@ class RadixAttention(nn.Module):
          return o.view(-1, self.tp_q_head_num * self.head_dim)

      def decode_forward_flashinfer(self, q, k, v, input_metadata: InputMetadata):
-         self.store_kv_cache(k, v, input_metadata)
+         decode_wrapper = input_metadata.flashinfer_decode_wrapper
+         if self.sliding_window_size != -1:
+             decode_wrapper = decode_wrapper[0]
+         else:
+             if isinstance(decode_wrapper, list):
+                 decode_wrapper = decode_wrapper[1]
+
+         if k is not None:
+             assert v is not None
+             self.store_kv_cache(k, v, input_metadata)

-         o = input_metadata.flashinfer_decode_wrapper.forward(
+         o = decode_wrapper.forward(
              q.contiguous().view(-1, self.tp_q_head_num, self.head_dim),
              input_metadata.token_to_kv_pool.get_kv_buffer(self.layer_id),
              sm_scale=self.scaling,
@@ -170,8 +192,10 @@ class RadixAttention(nn.Module):
          return o.view(-1, self.tp_q_head_num * self.head_dim)

      def forward(self, q, k, v, input_metadata: InputMetadata):
-         k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
-         v = v.view(-1, self.tp_v_head_num, self.v_head_dim)
+         if k is not None:
+             assert v is not None
+             k = k.view(-1, self.tp_k_head_num, self.qk_head_dim)
+             v = v.view(-1, self.tp_v_head_num, self.v_head_dim)

          if input_metadata.forward_mode == ForwardMode.EXTEND:
              return self.extend_forward(q, k, v, input_metadata)
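
The new sliding_window_size argument is stored as -1 when unset; when set, it selects the first of the paired flashinfer wrappers and is forwarded to the prefill wrapper as window_left. Below is a hedged sketch of how a model layer might opt in; the window size, head counts, and the num_heads/head_dim keyword names are assumptions taken from the surrounding constructor, not from this hunk.

# Illustrative only: a model layer opting into sliding-window attention.
# Only sliding_window_size (and the parameters visible in the hunk above)
# are taken from this diff; the concrete numbers are made up.
from typing import Optional

from sglang.srt.layers.radix_attention import RadixAttention

def build_attention(layer_id: int, use_sliding_window: bool) -> RadixAttention:
    # None is stored internally as -1 ("no sliding window"); a positive value
    # is later passed to flashinfer's prefill wrapper as window_left.
    window: Optional[int] = 4096 if use_sliding_window else None
    return RadixAttention(
        num_heads=32,
        head_dim=128,
        scaling=128**-0.5,
        num_kv_heads=8,
        layer_id=layer_id,
        sliding_window_size=window,
    )

This is the kind of per-layer switch a model that alternates windowed and global attention would need, which is consistent with the gemma2.py changes listed in this release.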

sglang/srt/managers/detokenizer_manager.py
@@ -25,10 +25,14 @@ import zmq
  import zmq.asyncio

  from sglang.srt.hf_transformers_utils import get_tokenizer
- from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
+ from sglang.srt.managers.io_struct import (
+     BatchEmbeddingOut,
+     BatchStrOut,
+     BatchTokenIDOut,
+ )
  from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR
  from sglang.srt.server_args import PortArgs, ServerArgs
- from sglang.utils import find_printable_text, get_exception_traceback, graceful_registry
+ from sglang.utils import find_printable_text, get_exception_traceback

  asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())

@@ -55,20 +59,40 @@ class DetokenizerManager:
          self.send_to_tokenizer = context.socket(zmq.PUSH)
          self.send_to_tokenizer.connect(f"tcp://127.0.0.1:{port_args.tokenizer_port}")

-         self.tokenizer = get_tokenizer(
-             server_args.tokenizer_path,
-             tokenizer_mode=server_args.tokenizer_mode,
-             trust_remote_code=server_args.trust_remote_code,
-         )
+         if server_args.skip_tokenizer_init:
+             self.tokenizer = None
+         else:
+             self.tokenizer = get_tokenizer(
+                 server_args.tokenizer_path,
+                 tokenizer_mode=server_args.tokenizer_mode,
+                 trust_remote_code=server_args.trust_remote_code,
+             )

          self.decode_status = {}

      async def handle_loop(self):
          while True:
              recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj()
+
+             if isinstance(recv_obj, BatchEmbeddingOut):
+                 self.send_to_tokenizer.send_pyobj(
+                     BatchEmbeddingOut(
+                         rids=recv_obj.rids,
+                         embeddings=recv_obj.embeddings,
+                         meta_info=recv_obj.meta_info,
+                         finished_reason=recv_obj.finished_reason,
+                     )
+                 )
+                 continue
+
              assert isinstance(recv_obj, BatchTokenIDOut)
              bs = len(recv_obj.rids)

+             if self.tokenizer is None:
+                 # Send BatchTokenIDOut if no tokenizer init'ed.
+                 self.send_to_tokenizer.send_pyobj(recv_obj)
+                 continue
+
              # Initialize decode status
              read_ids, surr_ids = [], []
              for i in range(bs):
@@ -140,8 +164,6 @@ def start_detokenizer_process(
      port_args: PortArgs,
      pipe_writer,
  ):
-     graceful_registry(inspect.currentframe().f_code.co_name)
-
      try:
          manager = DetokenizerManager(server_args, port_args)
      except Exception:

sglang/srt/managers/io_struct.py
@@ -22,6 +22,8 @@ import uuid
  from dataclasses import dataclass
  from typing import Dict, List, Optional, Union

+ import torch
+
  from sglang.srt.managers.schedule_batch import BaseFinishReason
  from sglang.srt.sampling_params import SamplingParams

@@ -166,6 +168,59 @@ class TokenizedGenerateReqInput:
      stream: bool


+ @dataclass
+ class EmbeddingReqInput:
+     # The input prompt. It can be a single prompt or a batch of prompts.
+     text: Optional[Union[List[str], str]] = None
+     # The token ids for text; one can either specify text or input_ids.
+     input_ids: Optional[Union[List[List[int]], List[int]]] = None
+     # The request id.
+     rid: Optional[Union[List[str], str]] = None
+     # Dummy sampling params for compatibility
+     sampling_params: Union[List[Dict], Dict] = None
+
+     def post_init(self):
+         if (self.text is None and self.input_ids is None) or (
+             self.text is not None and self.input_ids is not None
+         ):
+             raise ValueError("Either text or input_ids should be provided.")
+
+         if self.text is not None:
+             is_single = isinstance(self.text, str)
+         else:
+             is_single = isinstance(self.input_ids[0], int)
+         self.is_single = is_single
+
+         if is_single:
+             if self.rid is None:
+                 self.rid = uuid.uuid4().hex
+             if self.sampling_params is None:
+                 self.sampling_params = {}
+             self.sampling_params["max_new_tokens"] = 1
+         else:
+             # support select operation
+             self.batch_size = (
+                 len(self.text) if self.text is not None else len(self.input_ids)
+             )
+             if self.rid is None:
+                 self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
+             else:
+                 if not isinstance(self.rid, list):
+                     raise ValueError("The rid should be a list.")
+             if self.sampling_params is None:
+                 self.sampling_params = [{}] * self.batch_size
+             for i in range(self.batch_size):
+                 self.sampling_params[i]["max_new_tokens"] = 1
+
+
+ @dataclass
+ class TokenizedEmbeddingReqInput:
+     rid: str
+     input_text: str
+     input_ids: List[int]
+     sampling_params: SamplingParams
+
+
  @dataclass
  class BatchTokenIDOut:
      rids: List[str]
@@ -187,6 +242,14 @@ class BatchStrOut:
      finished_reason: List[BaseFinishReason]


+ @dataclass
+ class BatchEmbeddingOut:
+     rids: List[str]
+     embeddings: List[List[float]]
+     meta_info: List[Dict]
+     finished_reason: List[BaseFinishReason]
+
+
  @dataclass
  class FlushCacheReq:
      pass
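
EmbeddingReqInput above normalizes both single-prompt and batched requests in post_init(), which the caller is expected to invoke explicitly (it is not a dataclass __post_init__). A short sketch of the single-prompt path, runnable once the package is installed; the prompt text is an arbitrary example:

# Sketch of the single-prompt path of EmbeddingReqInput.post_init().
from sglang.srt.managers.io_struct import EmbeddingReqInput

req = EmbeddingReqInput(text="The capital of France is Paris.")
req.post_init()  # invoked explicitly by the server-side code path

assert req.is_single
assert isinstance(req.rid, str)  # auto-generated uuid4 hex
# Dummy sampling params are filled in for compatibility with generation requests.
assert req.sampling_params == {"max_new_tokens": 1}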

sglang/srt/managers/policy_scheduler.py
@@ -15,44 +15,54 @@ limitations under the License.

  """Request policy scheduler"""

+ import os
  import random
  from collections import defaultdict
+ from contextlib import contextmanager
+ from typing import Dict, List, Optional
+
+ from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
+ from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
+ from sglang.srt.mem_cache.radix_cache import TreeNode
+
+ # Clip the estimation of max_new_tokens for the request whose max_new_tokens is very large.
+ # This can prevent the server from being too conservative.
+ # Note that this only clips the estimation in the scheduler but does not change the stop
+ # condition. The request can still generate tokens until it hits the unclipped max_new_tokens.
+ CLIP_MAX_NEW_TOKENS = int(os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS", "4096"))


  class PolicyScheduler:
-     def __init__(
-         self,
-         policy,
-         max_running_seqs,
-         max_prefill_num_tokens,
-         max_total_num_tokens,
-         tree_cache,
-     ):
-         if tree_cache.disable and policy == "lpm":
-             # LMP is meaningless when the tree cache is disabled.
+     def __init__(self, policy: str, tree_cache: BasePrefixCache):
+         if tree_cache.disable and policy in ["lpm", "dfs-weight"]:
+             # LPM and DFS-weight are meaningless when the tree cache is disabled.
              policy = "fcfs"

          self.policy = policy
-         self.max_running_seqs = max_running_seqs
-         self.max_prefill_num_tokens = max_prefill_num_tokens
-         self.max_total_num_tokens = max_total_num_tokens
          self.tree_cache = tree_cache

-     def get_priority_queue(self, waiting_queue):
+     def calc_priority(self, waiting_queue: List[Req]):
+         # Compute matched prefix length
+         prefix_computed = False
+         if self.policy in ["lpm", "dfs-weight"]:
+             for r in waiting_queue:
+                 # NOTE: the prefix_indices must always be aligned with last_node
+                 r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
+                     rid=r.rid, key=r.adjust_max_prefix_ids()
+                 )
+             prefix_computed = True
+
          if self.policy == "lpm":
-             # longest prefix match
+             # Longest Prefix Match
              waiting_queue.sort(key=lambda x: -len(x.prefix_indices))
-             return waiting_queue
          elif self.policy == "fcfs":
              # first come first serve
-             return waiting_queue
+             pass
          elif self.policy == "lof":
              # longest output first
              waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
-             return waiting_queue
          elif self.policy == "random":
              random.shuffle(waiting_queue)
-             return waiting_queue
          elif self.policy == "dfs-weight":
              last_node_to_reqs = defaultdict(list)
              for req in waiting_queue:
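
The CLIP_MAX_NEW_TOKENS comment introduced in the hunk above only clips the scheduler's token-budget estimate, not the actual stop condition. A small worked example with invented request values:

# Worked example of the clipped scheduler estimate (values are illustrative).
import os

CLIP_MAX_NEW_TOKENS = int(os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS", "4096"))

requested_max_new_tokens = 32768  # what the request asked for
already_generated = 1000          # len(req.output_ids) so far

# The scheduler budgets only the clipped remainder ...
estimated_remaining = min(
    requested_max_new_tokens - already_generated, CLIP_MAX_NEW_TOKENS
)
print(estimated_remaining)  # 4096

# ... while the stop condition still honors the unclipped 32768,
# so generation itself is not cut short.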
@@ -63,23 +73,161 @@ class PolicyScheduler:
                  node_to_weight[node] = len(last_node_to_reqs[node])
              self.calc_weight(self.tree_cache.root_node, node_to_weight)

-             q = []
+             waiting_queue.clear()
              self.get_dfs_priority(
-                 self.tree_cache.root_node, node_to_weight, last_node_to_reqs, q
+                 self.tree_cache.root_node,
+                 node_to_weight,
+                 last_node_to_reqs,
+                 waiting_queue,
              )
-             assert len(q) == len(waiting_queue)
-             return q
          else:
              raise ValueError(f"Unknown schedule_policy: {self.policy}")

-     def calc_weight(self, cur_node, node_to_weight):
+         return prefix_computed
+
+     def calc_weight(self, cur_node: TreeNode, node_to_weight: Dict):
          for child in cur_node.children.values():
              self.calc_weight(child, node_to_weight)
              node_to_weight[cur_node] += node_to_weight[child]

-     def get_dfs_priority(self, cur_node, node_to_priority, last_node_to_reqs, q):
+     def get_dfs_priority(
+         self,
+         cur_node: TreeNode,
+         node_to_priority: Dict,
+         last_node_to_reqs: Dict,
+         q: List,
+     ):
          childs = [child for child in cur_node.children.values()]
          childs.sort(key=lambda x: -node_to_priority[x])
          for child in childs:
              self.get_dfs_priority(child, node_to_priority, last_node_to_reqs, q)
          q.extend(last_node_to_reqs[cur_node])
+
+
+ class PrefillAdder:
+     def __init__(
+         self,
+         tree_cache: BasePrefixCache,
+         rem_total_tokens: int,
+         rem_input_tokens: int,
+         rem_chunk_tokens: Optional[int],
+     ):
+         self.tree_cache = tree_cache
+         self.rem_total_tokens = rem_total_tokens
+         self.rem_input_tokens = rem_input_tokens
+         self.rem_chunk_tokens = rem_chunk_tokens
+
+         self.can_run_list = []
+         self.new_inflight_req = None
+         self.log_hit_tokens = 0
+         self.log_input_tokens = 0
+
+     def no_remaining_tokens(self):
+         return (
+             self.rem_total_tokens <= 0
+             or self.rem_input_tokens <= 0
+             or (
+                 self.rem_chunk_tokens <= 0
+                 if self.rem_chunk_tokens is not None
+                 else False
+             )
+         )
+
+     def remove_running_tokens(
+         self, running_batch: ScheduleBatch, new_token_ratio: float
+     ):
+         self.rem_total_tokens -= sum(
+             [
+                 min(
+                     (r.sampling_params.max_new_tokens - len(r.output_ids)),
+                     CLIP_MAX_NEW_TOKENS,
+                 )
+                 * new_token_ratio
+                 for r in running_batch.reqs
+             ]
+         )
+
+     def _prefill_one_req(
+         self, prefix_len: int, extend_input_len: int, max_new_tokens: int
+     ):
+         self.rem_total_tokens -= extend_input_len + max_new_tokens
+         self.rem_input_tokens -= extend_input_len
+         if self.rem_chunk_tokens is not None:
+             self.rem_chunk_tokens -= extend_input_len
+
+         self.log_hit_tokens += prefix_len
+         self.log_input_tokens += extend_input_len
+
+     def add_inflight_req(self, req: Req):
+         truncated = req.extend_input_len > self.rem_chunk_tokens
+         req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
+         req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
+         self.can_run_list.append(req)
+
+         self._prefill_one_req(
+             len(req.prefix_indices),
+             req.extend_input_len,
+             (
+                 min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS)
+                 if not truncated
+                 else 0
+             ),
+         )
+
+         # Return if chunked prefill not finished
+         return req if truncated else None
+
+     @contextmanager
+     def _lock_node(self, last_node: TreeNode):
+         try:
+             delta = self.tree_cache.inc_lock_ref(last_node)
+             self.rem_total_tokens += delta
+             yield None
+         finally:
+             delta = self.tree_cache.dec_lock_ref(last_node)
+             self.rem_total_tokens += delta
+
+     def add_one_req(self, req: Req):
+         total_tokens = req.extend_input_len + min(
+             req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS
+         )
+         input_tokens = req.extend_input_len
+         prefix_len = len(req.prefix_indices)
+
+         if total_tokens >= self.rem_total_tokens:
+             return False
+
+         if input_tokens > self.rem_input_tokens and len(self.can_run_list) != 0:
+             return False
+
+         with self._lock_node(req.last_node):
+             if total_tokens > self.rem_total_tokens:
+                 return False
+
+             if (
+                 self.rem_chunk_tokens is None
+                 or input_tokens <= self.rem_chunk_tokens
+                 or (req.return_logprob and req.normalized_prompt_logprob is None)
+             ):
+                 # Non-chunked prefill
+                 self.can_run_list.append(req)
+                 self.tree_cache.inc_lock_ref(req.last_node)
+                 self._prefill_one_req(
+                     prefix_len,
+                     input_tokens,
+                     min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS),
+                 )
+             else:
+                 # Chunked prefill
+                 trunc_len = self.rem_chunk_tokens
+                 if trunc_len == 0:
+                     return False
+
+                 req.extend_input_len = trunc_len
+                 req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
+                 self.can_run_list.append(req)
+                 self.new_inflight_req = req
+                 self.tree_cache.inc_lock_ref(req.last_node)
+                 self._prefill_one_req(prefix_len, trunc_len, 0)
+
+         return True
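
Taken together, calc_priority orders the waiting queue and PrefillAdder admits requests against the remaining token budgets. The sketch below shows one plausible way a scheduling loop could combine them; the surrounding objects (waiting_queue, tree_cache, the budget numbers) are stand-ins, and the real tp_worker loop differs in detail.

# Illustrative scheduling-loop sketch, not the actual tp_worker implementation.
from sglang.srt.managers.policy_scheduler import PolicyScheduler, PrefillAdder

def schedule_prefill(waiting_queue, tree_cache, rem_total, rem_input, chunk_size):
    scheduler = PolicyScheduler("lpm", tree_cache)
    scheduler.calc_priority(waiting_queue)  # sorts in place; may match prefixes

    adder = PrefillAdder(
        tree_cache,
        rem_total_tokens=rem_total,
        rem_input_tokens=rem_input,
        rem_chunk_tokens=chunk_size,  # None disables chunked prefill
    )
    for req in waiting_queue:
        if adder.no_remaining_tokens() or not adder.add_one_req(req):
            break

    # can_run_list holds the admitted requests; new_inflight_req is set when
    # the last request was only partially prefilled (chunked prefill).
    return adder.can_run_list, adder.new_inflight_req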