sglang 0.2.11__py3-none-any.whl → 0.2.12__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to their public registry, and is provided for informational purposes only.
Files changed (63)
  1. sglang/bench_latency.py +6 -4
  2. sglang/bench_serving.py +46 -22
  3. sglang/lang/compiler.py +2 -2
  4. sglang/lang/ir.py +3 -3
  5. sglang/srt/constrained/base_tool_cache.py +1 -1
  6. sglang/srt/constrained/fsm_cache.py +12 -2
  7. sglang/srt/layers/activation.py +33 -0
  8. sglang/srt/layers/{token_attention.py → decode_attention.py} +9 -5
  9. sglang/srt/layers/extend_attention.py +6 -1
  10. sglang/srt/layers/layernorm.py +65 -0
  11. sglang/srt/layers/logits_processor.py +5 -0
  12. sglang/srt/layers/pooler.py +50 -0
  13. sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py} +5 -0
  14. sglang/srt/layers/radix_attention.py +2 -2
  15. sglang/srt/managers/detokenizer_manager.py +31 -9
  16. sglang/srt/managers/io_struct.py +63 -0
  17. sglang/srt/managers/policy_scheduler.py +173 -25
  18. sglang/srt/managers/schedule_batch.py +110 -87
  19. sglang/srt/managers/tokenizer_manager.py +193 -111
  20. sglang/srt/managers/tp_worker.py +289 -352
  21. sglang/srt/mem_cache/{base_cache.py → base_prefix_cache.py} +9 -4
  22. sglang/srt/mem_cache/chunk_cache.py +43 -20
  23. sglang/srt/mem_cache/memory_pool.py +2 -2
  24. sglang/srt/mem_cache/radix_cache.py +74 -40
  25. sglang/srt/model_executor/cuda_graph_runner.py +24 -9
  26. sglang/srt/model_executor/forward_batch_info.py +168 -105
  27. sglang/srt/model_executor/model_runner.py +24 -37
  28. sglang/srt/models/gemma2.py +0 -1
  29. sglang/srt/models/internlm2.py +2 -7
  30. sglang/srt/models/llama2.py +4 -4
  31. sglang/srt/models/llama_embedding.py +88 -0
  32. sglang/srt/models/qwen2_moe.py +0 -11
  33. sglang/srt/openai_api/adapter.py +155 -27
  34. sglang/srt/openai_api/protocol.py +37 -1
  35. sglang/srt/sampling/penaltylib/__init__.py +13 -0
  36. sglang/srt/sampling/penaltylib/orchestrator.py +357 -0
  37. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +80 -0
  38. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +105 -0
  39. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +79 -0
  40. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +83 -0
  41. sglang/srt/sampling_params.py +31 -4
  42. sglang/srt/server.py +69 -15
  43. sglang/srt/server_args.py +26 -19
  44. sglang/srt/utils.py +31 -13
  45. sglang/test/run_eval.py +10 -1
  46. sglang/test/runners.py +63 -63
  47. sglang/test/simple_eval_humaneval.py +2 -8
  48. sglang/test/simple_eval_mgsm.py +203 -0
  49. sglang/test/srt/sampling/penaltylib/utils.py +337 -0
  50. sglang/test/test_layernorm.py +60 -0
  51. sglang/test/test_programs.py +4 -2
  52. sglang/test/test_utils.py +20 -2
  53. sglang/utils.py +0 -1
  54. sglang/version.py +1 -1
  55. {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/METADATA +23 -14
  56. sglang-0.2.12.dist-info/RECORD +112 -0
  57. sglang/srt/layers/linear.py +0 -884
  58. sglang/srt/layers/quantization/__init__.py +0 -64
  59. sglang/srt/layers/quantization/fp8.py +0 -677
  60. sglang-0.2.11.dist-info/RECORD +0 -102
  61. {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/LICENSE +0 -0
  62. {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/WHEEL +0 -0
  63. {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/top_level.txt +0 -0
sglang/srt/managers/io_struct.py

@@ -22,6 +22,8 @@ import uuid
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Union
 
+import torch
+
 from sglang.srt.managers.schedule_batch import BaseFinishReason
 from sglang.srt.sampling_params import SamplingParams
 
@@ -166,6 +168,59 @@ class TokenizedGenerateReqInput:
     stream: bool
 
 
+@dataclass
+class EmbeddingReqInput:
+    # The input prompt. It can be a single prompt or a batch of prompts.
+    text: Optional[Union[List[str], str]] = None
+    # The token ids for text; one can either specify text or input_ids.
+    input_ids: Optional[Union[List[List[int]], List[int]]] = None
+    # The request id.
+    rid: Optional[Union[List[str], str]] = None
+    # Dummy sampling params for compatibility
+    sampling_params: Union[List[Dict], Dict] = None
+
+    def post_init(self):
+        if (self.text is None and self.input_ids is None) or (
+            self.text is not None and self.input_ids is not None
+        ):
+            raise ValueError("Either text or input_ids should be provided.")
+
+        if self.text is not None:
+            is_single = isinstance(self.text, str)
+        else:
+            is_single = isinstance(self.input_ids[0], int)
+        self.is_single = is_single
+
+        if is_single:
+            if self.rid is None:
+                self.rid = uuid.uuid4().hex
+            if self.sampling_params is None:
+                self.sampling_params = {}
+            self.sampling_params["max_new_tokens"] = 1
+        else:
+            # support select operation
+            self.batch_size = (
+                len(self.text) if self.text is not None else len(self.input_ids)
+            )
+            if self.rid is None:
+                self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
+            else:
+                if not isinstance(self.rid, list):
+                    raise ValueError("The rid should be a list.")
+            if self.sampling_params is None:
+                self.sampling_params = [{}] * self.batch_size
+            for i in range(self.batch_size):
+                self.sampling_params[i]["max_new_tokens"] = 1
+
+
+@dataclass
+class TokenizedEmbeddingReqInput:
+    rid: str
+    input_text: str
+    input_ids: List[int]
+    sampling_params: SamplingParams
+
+
 @dataclass
 class BatchTokenIDOut:
     rids: List[str]
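EmbeddingReqInput reuses the generate-request plumbing but pins a dummy max_new_tokens of 1, since an embedding request decodes nothing. A minimal sketch, assuming only the fields shown above, of how post_init normalizes a single prompt versus a batch (the prompt strings are made up):

from sglang.srt.managers.io_struct import EmbeddingReqInput

# Single prompt: rid and sampling_params are filled in automatically.
single = EmbeddingReqInput(text="the quick brown fox")
single.post_init()
print(single.is_single)                          # True
print(single.sampling_params["max_new_tokens"])  # 1 (dummy value)

# Batch of prompts: one rid and one sampling_params dict per prompt.
batch = EmbeddingReqInput(text=["hello", "world"])
batch.post_init()
print(batch.batch_size, len(batch.rid), len(batch.sampling_params))  # 2 2 2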
@@ -187,6 +242,14 @@ class BatchStrOut:
     finished_reason: List[BaseFinishReason]
 
 
+@dataclass
+class BatchEmbeddingOut:
+    rids: List[str]
+    embeddings: List[List[float]]
+    meta_info: List[Dict]
+    finished_reason: List[BaseFinishReason]
+
+
 @dataclass
 class FlushCacheReq:
     pass
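BatchEmbeddingOut is the embedding counterpart of BatchStrOut: one pooled vector plus per-request meta info, keyed by request id. A hedged, hand-constructed example; the vectors and meta fields below are invented for illustration:

from sglang.srt.managers.io_struct import BatchEmbeddingOut

out = BatchEmbeddingOut(
    rids=["req-0", "req-1"],
    embeddings=[[0.12, -0.34, 0.56], [0.98, 0.01, -0.44]],  # made-up 3-dim vectors
    meta_info=[{"prompt_tokens": 5}, {"prompt_tokens": 7}],
    finished_reason=[None, None],  # populated by the server in practice
)
print(len(out.rids), len(out.embeddings[0]))  # 2 3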
sglang/srt/managers/policy_scheduler.py

@@ -15,44 +15,54 @@ limitations under the License.
 
 """Request policy scheduler"""
 
+import os
 import random
 from collections import defaultdict
+from contextlib import contextmanager
+from typing import Dict, List, Optional
+
+from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
+from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
+from sglang.srt.mem_cache.radix_cache import TreeNode
+
+# Clip the estimation of max_new_tokens for the request whose max_new_tokens is very large.
+# This can prevent the server from being too conservative.
+# Note that this only clips the estimation in the scheduler but does not change the stop
+# condition. The request can still generate tokens until it hits the unclipped max_new_tokens.
+CLIP_MAX_NEW_TOKENS = int(os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS", "4096"))
 
 
 class PolicyScheduler:
-    def __init__(
-        self,
-        policy,
-        max_running_seqs,
-        max_prefill_num_tokens,
-        max_total_num_tokens,
-        tree_cache,
-    ):
-        if tree_cache.disable and policy == "lpm":
-            # LMP is meaningless when the tree cache is disabled.
+    def __init__(self, policy: str, tree_cache: BasePrefixCache):
+        if tree_cache.disable and policy in ["lpm", "dfs-weight"]:
+            # LPM and DFS-weight is meaningless when the tree cache is disabled.
             policy = "fcfs"
 
         self.policy = policy
-        self.max_running_seqs = max_running_seqs
-        self.max_prefill_num_tokens = max_prefill_num_tokens
-        self.max_total_num_tokens = max_total_num_tokens
         self.tree_cache = tree_cache
 
-    def get_priority_queue(self, waiting_queue):
+    def calc_priority(self, waiting_queue: List[Req]):
+        # Compute matched prefix length
+        prefix_computed = False
+        if self.policy in ["lpm", "dfs-weight"]:
+            for r in waiting_queue:
+                # NOTE: the prefix_indices must always be aligned with last_node
+                r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
+                    rid=r.rid, key=r.adjust_max_prefix_ids()
+                )
+            prefix_computed = True
+
         if self.policy == "lpm":
-            # longest prefix match
+            # Longest Prefix Match
             waiting_queue.sort(key=lambda x: -len(x.prefix_indices))
-            return waiting_queue
         elif self.policy == "fcfs":
             # first come first serve
-            return waiting_queue
+            pass
         elif self.policy == "lof":
             # longest output first
             waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
-            return waiting_queue
         elif self.policy == "random":
             random.shuffle(waiting_queue)
-            return waiting_queue
         elif self.policy == "dfs-weight":
             last_node_to_reqs = defaultdict(list)
             for req in waiting_queue:
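CLIP_MAX_NEW_TOKENS only caps the scheduler's memory estimate; a request can still run to its real max_new_tokens. A small sketch of the reservation arithmetic, using a hypothetical helper name (estimated_decode_budget is not part of sglang):

import os

CLIP_MAX_NEW_TOKENS = int(os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS", "4096"))

def estimated_decode_budget(max_new_tokens: int, num_generated: int) -> int:
    # The scheduler reserves at most CLIP_MAX_NEW_TOKENS more tokens per request,
    # even though generation may continue past that estimate.
    return min(max_new_tokens - num_generated, CLIP_MAX_NEW_TOKENS)

print(estimated_decode_budget(32768, 100))  # 4096, not 32668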
@@ -63,23 +73,161 @@ class PolicyScheduler:
                 node_to_weight[node] = len(last_node_to_reqs[node])
             self.calc_weight(self.tree_cache.root_node, node_to_weight)
 
-            q = []
+            waiting_queue.clear()
             self.get_dfs_priority(
-                self.tree_cache.root_node, node_to_weight, last_node_to_reqs, q
+                self.tree_cache.root_node,
+                node_to_weight,
+                last_node_to_reqs,
+                waiting_queue,
             )
-            assert len(q) == len(waiting_queue)
-            return q
         else:
             raise ValueError(f"Unknown schedule_policy: {self.policy}")
 
-    def calc_weight(self, cur_node, node_to_weight):
+        return prefix_computed
+
+    def calc_weight(self, cur_node: TreeNode, node_to_weight: Dict):
         for child in cur_node.children.values():
             self.calc_weight(child, node_to_weight)
             node_to_weight[cur_node] += node_to_weight[child]
 
-    def get_dfs_priority(self, cur_node, node_to_priority, last_node_to_reqs, q):
+    def get_dfs_priority(
+        self,
+        cur_node: TreeNode,
+        node_to_priority: Dict,
+        last_node_to_reqs: Dict,
+        q: List,
+    ):
         childs = [child for child in cur_node.children.values()]
         childs.sort(key=lambda x: -node_to_priority[x])
         for child in childs:
             self.get_dfs_priority(child, node_to_priority, last_node_to_reqs, q)
         q.extend(last_node_to_reqs[cur_node])
+
+
+class PrefillAdder:
+    def __init__(
+        self,
+        tree_cache: BasePrefixCache,
+        rem_total_tokens: int,
+        rem_input_tokens: int,
+        rem_chunk_tokens: Optional[int],
+    ):
+        self.tree_cache = tree_cache
+        self.rem_total_tokens = rem_total_tokens
+        self.rem_input_tokens = rem_input_tokens
+        self.rem_chunk_tokens = rem_chunk_tokens
+
+        self.can_run_list = []
+        self.new_inflight_req = None
+        self.log_hit_tokens = 0
+        self.log_input_tokens = 0
+
+    def no_remaining_tokens(self):
+        return (
+            self.rem_total_tokens <= 0
+            or self.rem_input_tokens <= 0
+            or (
+                self.rem_chunk_tokens <= 0
+                if self.rem_chunk_tokens is not None
+                else False
+            )
+        )
+
+    def remove_running_tokens(
+        self, running_batch: ScheduleBatch, new_token_ratio: float
+    ):
+        self.rem_total_tokens -= sum(
+            [
+                min(
+                    (r.sampling_params.max_new_tokens - len(r.output_ids)),
+                    CLIP_MAX_NEW_TOKENS,
+                )
+                * new_token_ratio
+                for r in running_batch.reqs
+            ]
+        )
+
+    def _prefill_one_req(
+        self, prefix_len: int, extend_input_len: int, max_new_tokens: int
+    ):
+        self.rem_total_tokens -= extend_input_len + max_new_tokens
+        self.rem_input_tokens -= extend_input_len
+        if self.rem_chunk_tokens is not None:
+            self.rem_chunk_tokens -= extend_input_len
+
+        self.log_hit_tokens += prefix_len
+        self.log_input_tokens += extend_input_len
+
+    def add_inflight_req(self, req: Req):
+        truncated = req.extend_input_len > self.rem_chunk_tokens
+        req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
+        req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
+        self.can_run_list.append(req)
+
+        self._prefill_one_req(
+            len(req.prefix_indices),
+            req.extend_input_len,
+            (
+                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS)
+                if not truncated
+                else 0
+            ),
+        )
+
+        # Return if chunked prefill not finished
+        return req if truncated else None
+
+    @contextmanager
+    def _lock_node(self, last_node: TreeNode):
+        try:
+            delta = self.tree_cache.inc_lock_ref(last_node)
+            self.rem_total_tokens += delta
+            yield None
+        finally:
+            delta = self.tree_cache.dec_lock_ref(last_node)
+            self.rem_total_tokens += delta
+
+    def add_one_req(self, req: Req):
+        total_tokens = req.extend_input_len + min(
+            req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS
+        )
+        input_tokens = req.extend_input_len
+        prefix_len = len(req.prefix_indices)
+
+        if total_tokens >= self.rem_total_tokens:
+            return False
+
+        if input_tokens > self.rem_input_tokens and len(self.can_run_list) != 0:
+            return False
+
+        with self._lock_node(req.last_node):
+            if total_tokens > self.rem_total_tokens:
+                return False
+
+            if (
+                self.rem_chunk_tokens is None
+                or input_tokens <= self.rem_chunk_tokens
+                or (req.return_logprob and req.normalized_prompt_logprob is None)
+            ):
+                # Non-chunked prefill
+                self.can_run_list.append(req)
+                self.tree_cache.inc_lock_ref(req.last_node)
+                self._prefill_one_req(
+                    prefix_len,
+                    input_tokens,
+                    min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS),
+                )
+            else:
+                # Chunked prefill
+                trunc_len = self.rem_chunk_tokens
+                if trunc_len == 0:
+                    return False
+
+                req.extend_input_len = trunc_len
+                req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
+                self.can_run_list.append(req)
+                self.new_inflight_req = req
+                self.tree_cache.inc_lock_ref(req.last_node)
+                self._prefill_one_req(prefix_len, trunc_len, 0)
+
+        return True
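PrefillAdder tracks the remaining KV-pool, prefill-input, and chunk budgets and collects the requests that fit into one prefill batch. A minimal, self-contained sketch of a driver loop; the Dummy* classes and the budget numbers are hypothetical stand-ins for sglang's real Req and prefix-cache objects:

from sglang.srt.managers.policy_scheduler import PrefillAdder

class DummyNode:
    pass

class DummyCache:
    # Lock bookkeeping is a no-op here; the real radix cache returns token deltas.
    def inc_lock_ref(self, node):
        return 0

    def dec_lock_ref(self, node):
        return 0

class DummyParams:
    max_new_tokens = 128

class DummyReq:
    def __init__(self, n_input):
        self.extend_input_len = n_input
        self.fill_ids = list(range(n_input))
        self.prefix_indices = []
        self.last_node = DummyNode()
        self.sampling_params = DummyParams()
        self.return_logprob = False
        self.normalized_prompt_logprob = None

adder = PrefillAdder(
    tree_cache=DummyCache(),
    rem_total_tokens=4096,   # free KV-pool slots
    rem_input_tokens=2048,   # max prefill tokens per batch
    rem_chunk_tokens=None,   # chunked prefill disabled
)

for req in [DummyReq(512), DummyReq(1024), DummyReq(4096)]:
    if not adder.add_one_req(req):
        break

print(len(adder.can_run_list))  # 2: the third request no longer fits the budget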