sglang 0.2.10__py3-none-any.whl → 0.2.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. sglang/__init__.py +8 -0
  2. sglang/api.py +10 -2
  3. sglang/bench_latency.py +151 -40
  4. sglang/bench_serving.py +46 -22
  5. sglang/check_env.py +24 -2
  6. sglang/global_config.py +0 -1
  7. sglang/lang/backend/base_backend.py +3 -1
  8. sglang/lang/backend/openai.py +8 -3
  9. sglang/lang/backend/runtime_endpoint.py +46 -29
  10. sglang/lang/choices.py +164 -0
  11. sglang/lang/compiler.py +2 -2
  12. sglang/lang/interpreter.py +6 -13
  13. sglang/lang/ir.py +14 -5
  14. sglang/srt/constrained/base_tool_cache.py +1 -1
  15. sglang/srt/constrained/fsm_cache.py +12 -2
  16. sglang/srt/layers/activation.py +33 -0
  17. sglang/srt/layers/{token_attention.py → decode_attention.py} +9 -5
  18. sglang/srt/layers/extend_attention.py +6 -1
  19. sglang/srt/layers/layernorm.py +65 -0
  20. sglang/srt/layers/logits_processor.py +6 -1
  21. sglang/srt/layers/pooler.py +50 -0
  22. sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py} +5 -0
  23. sglang/srt/layers/radix_attention.py +4 -7
  24. sglang/srt/managers/detokenizer_manager.py +31 -9
  25. sglang/srt/managers/io_struct.py +63 -0
  26. sglang/srt/managers/policy_scheduler.py +173 -25
  27. sglang/srt/managers/schedule_batch.py +174 -380
  28. sglang/srt/managers/tokenizer_manager.py +197 -112
  29. sglang/srt/managers/tp_worker.py +299 -364
  30. sglang/srt/mem_cache/{base_cache.py → base_prefix_cache.py} +9 -4
  31. sglang/srt/mem_cache/chunk_cache.py +43 -20
  32. sglang/srt/mem_cache/memory_pool.py +10 -15
  33. sglang/srt/mem_cache/radix_cache.py +74 -40
  34. sglang/srt/model_executor/cuda_graph_runner.py +27 -12
  35. sglang/srt/model_executor/forward_batch_info.py +319 -0
  36. sglang/srt/model_executor/model_runner.py +30 -47
  37. sglang/srt/models/chatglm.py +1 -1
  38. sglang/srt/models/commandr.py +1 -1
  39. sglang/srt/models/dbrx.py +1 -1
  40. sglang/srt/models/deepseek.py +1 -1
  41. sglang/srt/models/deepseek_v2.py +1 -1
  42. sglang/srt/models/gemma.py +1 -1
  43. sglang/srt/models/gemma2.py +1 -2
  44. sglang/srt/models/gpt_bigcode.py +1 -1
  45. sglang/srt/models/grok.py +1 -1
  46. sglang/srt/models/internlm2.py +3 -8
  47. sglang/srt/models/llama2.py +5 -5
  48. sglang/srt/models/llama_classification.py +1 -1
  49. sglang/srt/models/llama_embedding.py +88 -0
  50. sglang/srt/models/llava.py +1 -2
  51. sglang/srt/models/llavavid.py +1 -2
  52. sglang/srt/models/minicpm.py +1 -1
  53. sglang/srt/models/mixtral.py +1 -1
  54. sglang/srt/models/mixtral_quant.py +1 -1
  55. sglang/srt/models/qwen.py +1 -1
  56. sglang/srt/models/qwen2.py +1 -1
  57. sglang/srt/models/qwen2_moe.py +1 -12
  58. sglang/srt/models/stablelm.py +1 -1
  59. sglang/srt/openai_api/adapter.py +189 -39
  60. sglang/srt/openai_api/protocol.py +43 -1
  61. sglang/srt/sampling/penaltylib/__init__.py +13 -0
  62. sglang/srt/sampling/penaltylib/orchestrator.py +357 -0
  63. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +80 -0
  64. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +105 -0
  65. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +79 -0
  66. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +83 -0
  67. sglang/srt/sampling_params.py +31 -4
  68. sglang/srt/server.py +93 -21
  69. sglang/srt/server_args.py +30 -19
  70. sglang/srt/utils.py +31 -13
  71. sglang/test/run_eval.py +10 -1
  72. sglang/test/runners.py +63 -63
  73. sglang/test/simple_eval_humaneval.py +2 -8
  74. sglang/test/simple_eval_mgsm.py +203 -0
  75. sglang/test/srt/sampling/penaltylib/utils.py +337 -0
  76. sglang/test/test_layernorm.py +60 -0
  77. sglang/test/test_programs.py +4 -2
  78. sglang/test/test_utils.py +21 -3
  79. sglang/utils.py +0 -1
  80. sglang/version.py +1 -1
  81. {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/METADATA +50 -31
  82. sglang-0.2.12.dist-info/RECORD +112 -0
  83. sglang/srt/layers/linear.py +0 -884
  84. sglang/srt/layers/quantization/__init__.py +0 -64
  85. sglang/srt/layers/quantization/fp8.py +0 -677
  86. sglang-0.2.10.dist-info/RECORD +0 -100
  87. {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/LICENSE +0 -0
  88. {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/WHEEL +0 -0
  89. {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/top_level.txt +0 -0
sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py}
@@ -13,6 +13,11 @@ See the License for the specific language governing permissions and
 limitations under the License.
 """
 
+"""
+Memory-efficient attention for prefill.
+It supporst page size = 1.
+"""
+
 # Adapted from
 # https://github.com/ModelTC/lightllm/blob/f2a54f0912293f683bf1d1695fd12c4098a5bf82/lightllm/models/llama/triton_kernel/context_flashattention_nopad.py#L1
 import torch
sglang/srt/layers/radix_attention.py
@@ -20,13 +20,10 @@ from flashinfer.cascade import merge_state
 from torch import nn
 
 from sglang.global_config import global_config
+from sglang.srt.layers.decode_attention import decode_attention_fwd
 from sglang.srt.layers.extend_attention import extend_attention_fwd
-from sglang.srt.layers.token_attention import token_attention_fwd
-from sglang.srt.model_executor.model_runner import (
-    ForwardMode,
-    InputMetadata,
-    global_server_args_dict,
-)
+from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
+from sglang.srt.model_executor.model_runner import global_server_args_dict
 
 
 class RadixAttention(nn.Module):
@@ -98,7 +95,7 @@ class RadixAttention(nn.Module):
         o = torch.empty_like(q)
         self.store_kv_cache(k, v, input_metadata)
 
-        token_attention_fwd(
+        decode_attention_fwd(
            q.view(-1, self.tp_q_head_num, self.qk_head_dim),
            input_metadata.token_to_kv_pool.get_key_buffer(self.layer_id),
            input_metadata.token_to_kv_pool.get_value_buffer(self.layer_id),
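The decode attention kernel and the forward-pass metadata types moved in this release. A minimal sketch of the resulting import migration for downstream code, using only the paths shown in the hunks above:

# Old imports (sglang 0.2.10):
#   from sglang.srt.layers.token_attention import token_attention_fwd
#   from sglang.srt.model_executor.model_runner import ForwardMode, InputMetadata, global_server_args_dict
# New imports (sglang 0.2.12):
from sglang.srt.layers.decode_attention import decode_attention_fwd
from sglang.srt.model_executor.forward_batch_info import ForwardMode, InputMetadata
from sglang.srt.model_executor.model_runner import global_server_args_dict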
sglang/srt/managers/detokenizer_manager.py
@@ -25,10 +25,14 @@ import zmq
 import zmq.asyncio
 
 from sglang.srt.hf_transformers_utils import get_tokenizer
-from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
+from sglang.srt.managers.io_struct import (
+    BatchEmbeddingOut,
+    BatchStrOut,
+    BatchTokenIDOut,
+)
 from sglang.srt.managers.schedule_batch import FINISH_MATCHED_STR
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.utils import find_printable_text, get_exception_traceback, graceful_registry
+from sglang.utils import find_printable_text, get_exception_traceback
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
@@ -55,20 +59,40 @@ class DetokenizerManager:
         self.send_to_tokenizer = context.socket(zmq.PUSH)
         self.send_to_tokenizer.connect(f"tcp://127.0.0.1:{port_args.tokenizer_port}")
 
-        self.tokenizer = get_tokenizer(
-            server_args.tokenizer_path,
-            tokenizer_mode=server_args.tokenizer_mode,
-            trust_remote_code=server_args.trust_remote_code,
-        )
+        if server_args.skip_tokenizer_init:
+            self.tokenizer = None
+        else:
+            self.tokenizer = get_tokenizer(
+                server_args.tokenizer_path,
+                tokenizer_mode=server_args.tokenizer_mode,
+                trust_remote_code=server_args.trust_remote_code,
+            )
 
         self.decode_status = {}
 
     async def handle_loop(self):
         while True:
             recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj()
+
+            if isinstance(recv_obj, BatchEmbeddingOut):
+                self.send_to_tokenizer.send_pyobj(
+                    BatchEmbeddingOut(
+                        rids=recv_obj.rids,
+                        embeddings=recv_obj.embeddings,
+                        meta_info=recv_obj.meta_info,
+                        finished_reason=recv_obj.finished_reason,
+                    )
+                )
+                continue
+
             assert isinstance(recv_obj, BatchTokenIDOut)
             bs = len(recv_obj.rids)
 
+            if self.tokenizer is None:
+                # Send BatchTokenIDOut if no tokenizer init'ed.
+                self.send_to_tokenizer.send_pyobj(recv_obj)
+                continue
+
             # Initialize decode status
             read_ids, surr_ids = [], []
             for i in range(bs):
@@ -140,8 +164,6 @@ def start_detokenizer_process(
     port_args: PortArgs,
     pipe_writer,
 ):
-    graceful_registry(inspect.currentframe().f_code.co_name)
-
     try:
         manager = DetokenizerManager(server_args, port_args)
     except Exception:
sglang/srt/managers/io_struct.py
@@ -22,6 +22,8 @@ import uuid
 from dataclasses import dataclass
 from typing import Dict, List, Optional, Union
 
+import torch
+
 from sglang.srt.managers.schedule_batch import BaseFinishReason
 from sglang.srt.sampling_params import SamplingParams
 
@@ -166,6 +168,59 @@ class TokenizedGenerateReqInput:
     stream: bool
 
 
+@dataclass
+class EmbeddingReqInput:
+    # The input prompt. It can be a single prompt or a batch of prompts.
+    text: Optional[Union[List[str], str]] = None
+    # The token ids for text; one can either specify text or input_ids.
+    input_ids: Optional[Union[List[List[int]], List[int]]] = None
+    # The request id.
+    rid: Optional[Union[List[str], str]] = None
+    # Dummy sampling params for compatibility
+    sampling_params: Union[List[Dict], Dict] = None
+
+    def post_init(self):
+        if (self.text is None and self.input_ids is None) or (
+            self.text is not None and self.input_ids is not None
+        ):
+            raise ValueError("Either text or input_ids should be provided.")
+
+        if self.text is not None:
+            is_single = isinstance(self.text, str)
+        else:
+            is_single = isinstance(self.input_ids[0], int)
+        self.is_single = is_single
+
+        if is_single:
+            if self.rid is None:
+                self.rid = uuid.uuid4().hex
+            if self.sampling_params is None:
+                self.sampling_params = {}
+            self.sampling_params["max_new_tokens"] = 1
+        else:
+            # support select operation
+            self.batch_size = (
+                len(self.text) if self.text is not None else len(self.input_ids)
+            )
+            if self.rid is None:
+                self.rid = [uuid.uuid4().hex for _ in range(self.batch_size)]
+            else:
+                if not isinstance(self.rid, list):
+                    raise ValueError("The rid should be a list.")
+            if self.sampling_params is None:
+                self.sampling_params = [{}] * self.batch_size
+            for i in range(self.batch_size):
+                self.sampling_params[i]["max_new_tokens"] = 1
+
+
+@dataclass
+class TokenizedEmbeddingReqInput:
+    rid: str
+    input_text: str
+    input_ids: List[int]
+    sampling_params: SamplingParams
+
+
 @dataclass
 class BatchTokenIDOut:
     rids: List[str]
@@ -187,6 +242,14 @@ class BatchStrOut:
     finished_reason: List[BaseFinishReason]
 
 
+@dataclass
+class BatchEmbeddingOut:
+    rids: List[str]
+    embeddings: List[List[float]]
+    meta_info: List[Dict]
+    finished_reason: List[BaseFinishReason]
+
+
 @dataclass
 class FlushCacheReq:
     pass
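The new EmbeddingReqInput normalizes single and batched embedding requests and pins max_new_tokens to 1. A minimal usage sketch, assuming sglang 0.2.12 is installed (note that post_init is an ordinary method the caller invokes, not a dataclass __post_init__):

from sglang.srt.managers.io_struct import EmbeddingReqInput

# Single prompt: a request id is generated and dummy sampling params are attached.
single = EmbeddingReqInput(text="The capital of France is")
single.post_init()
print(single.is_single)        # True
print(single.sampling_params)  # {'max_new_tokens': 1}

# Batched prompts via token ids: one rid per prompt, one sampling dict per prompt.
batch = EmbeddingReqInput(input_ids=[[1, 2, 3], [4, 5]])
batch.post_init()
print(batch.batch_size)        # 2
print(len(batch.rid))          # 2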
sglang/srt/managers/policy_scheduler.py
@@ -15,44 +15,54 @@ limitations under the License.
 
 """Request policy scheduler"""
 
+import os
 import random
 from collections import defaultdict
+from contextlib import contextmanager
+from typing import Dict, List, Optional
+
+from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
+from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
+from sglang.srt.mem_cache.radix_cache import TreeNode
+
+# Clip the estimation of max_new_tokens for the request whose max_new_tokens is very large.
+# This can prevent the server from being too conservative.
+# Note that this only clips the estimation in the scheduler but does not change the stop
+# condition. The request can still generate tokens until it hits the unclipped max_new_tokens.
+CLIP_MAX_NEW_TOKENS = int(os.environ.get("SGLANG_CLIP_MAX_NEW_TOKENS", "4096"))
 
 
 class PolicyScheduler:
-    def __init__(
-        self,
-        policy,
-        max_running_seqs,
-        max_prefill_num_tokens,
-        max_total_num_tokens,
-        tree_cache,
-    ):
-        if tree_cache.disable and policy == "lpm":
-            # LMP is meaningless when the tree cache is disabled.
+    def __init__(self, policy: str, tree_cache: BasePrefixCache):
+        if tree_cache.disable and policy in ["lpm", "dfs-weight"]:
+            # LPM and DFS-weight is meaningless when the tree cache is disabled.
             policy = "fcfs"
 
         self.policy = policy
-        self.max_running_seqs = max_running_seqs
-        self.max_prefill_num_tokens = max_prefill_num_tokens
-        self.max_total_num_tokens = max_total_num_tokens
         self.tree_cache = tree_cache
 
-    def get_priority_queue(self, waiting_queue):
+    def calc_priority(self, waiting_queue: List[Req]):
+        # Compute matched prefix length
+        prefix_computed = False
+        if self.policy in ["lpm", "dfs-weight"]:
+            for r in waiting_queue:
+                # NOTE: the prefix_indices must always be aligned with last_node
+                r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
+                    rid=r.rid, key=r.adjust_max_prefix_ids()
+                )
+            prefix_computed = True
+
         if self.policy == "lpm":
-            # longest prefix match
+            # Longest Prefix Match
             waiting_queue.sort(key=lambda x: -len(x.prefix_indices))
-            return waiting_queue
         elif self.policy == "fcfs":
             # first come first serve
-            return waiting_queue
+            pass
         elif self.policy == "lof":
             # longest output first
             waiting_queue.sort(key=lambda x: -x.sampling_params.max_new_tokens)
-            return waiting_queue
         elif self.policy == "random":
             random.shuffle(waiting_queue)
-            return waiting_queue
         elif self.policy == "dfs-weight":
             last_node_to_reqs = defaultdict(list)
             for req in waiting_queue:
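Before the second hunk of this file below, a quick illustration: calc_priority (shown above) now reorders the waiting queue in place rather than returning a new list. A toy sketch of the "lpm" (Longest Prefix Match) ordering, using stand-in objects instead of sglang's Req class:

from types import SimpleNamespace

# Stand-ins for Req: only the field the "lpm" sort key touches is modeled here.
waiting_queue = [
    SimpleNamespace(rid="a", prefix_indices=[0, 1]),
    SimpleNamespace(rid="b", prefix_indices=[0, 1, 2, 3]),
    SimpleNamespace(rid="c", prefix_indices=[]),
]

# "lpm": requests with the longest matched radix-cache prefix run first.
waiting_queue.sort(key=lambda x: -len(x.prefix_indices))
print([r.rid for r in waiting_queue])  # ['b', 'a', 'c']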
@@ -63,23 +73,161 @@ class PolicyScheduler:
                 node_to_weight[node] = len(last_node_to_reqs[node])
             self.calc_weight(self.tree_cache.root_node, node_to_weight)
 
-            q = []
+            waiting_queue.clear()
             self.get_dfs_priority(
-                self.tree_cache.root_node, node_to_weight, last_node_to_reqs, q
+                self.tree_cache.root_node,
+                node_to_weight,
+                last_node_to_reqs,
+                waiting_queue,
             )
-            assert len(q) == len(waiting_queue)
-            return q
         else:
             raise ValueError(f"Unknown schedule_policy: {self.policy}")
 
-    def calc_weight(self, cur_node, node_to_weight):
+        return prefix_computed
+
+    def calc_weight(self, cur_node: TreeNode, node_to_weight: Dict):
         for child in cur_node.children.values():
             self.calc_weight(child, node_to_weight)
             node_to_weight[cur_node] += node_to_weight[child]
 
-    def get_dfs_priority(self, cur_node, node_to_priority, last_node_to_reqs, q):
+    def get_dfs_priority(
+        self,
+        cur_node: TreeNode,
+        node_to_priority: Dict,
+        last_node_to_reqs: Dict,
+        q: List,
+    ):
         childs = [child for child in cur_node.children.values()]
         childs.sort(key=lambda x: -node_to_priority[x])
         for child in childs:
             self.get_dfs_priority(child, node_to_priority, last_node_to_reqs, q)
         q.extend(last_node_to_reqs[cur_node])
+
+
+class PrefillAdder:
+    def __init__(
+        self,
+        tree_cache: BasePrefixCache,
+        rem_total_tokens: int,
+        rem_input_tokens: int,
+        rem_chunk_tokens: Optional[int],
+    ):
+        self.tree_cache = tree_cache
+        self.rem_total_tokens = rem_total_tokens
+        self.rem_input_tokens = rem_input_tokens
+        self.rem_chunk_tokens = rem_chunk_tokens
+
+        self.can_run_list = []
+        self.new_inflight_req = None
+        self.log_hit_tokens = 0
+        self.log_input_tokens = 0
+
+    def no_remaining_tokens(self):
+        return (
+            self.rem_total_tokens <= 0
+            or self.rem_input_tokens <= 0
+            or (
+                self.rem_chunk_tokens <= 0
+                if self.rem_chunk_tokens is not None
+                else False
+            )
+        )
+
+    def remove_running_tokens(
+        self, running_batch: ScheduleBatch, new_token_ratio: float
+    ):
+        self.rem_total_tokens -= sum(
+            [
+                min(
+                    (r.sampling_params.max_new_tokens - len(r.output_ids)),
+                    CLIP_MAX_NEW_TOKENS,
+                )
+                * new_token_ratio
+                for r in running_batch.reqs
+            ]
+        )
+
+    def _prefill_one_req(
+        self, prefix_len: int, extend_input_len: int, max_new_tokens: int
+    ):
+        self.rem_total_tokens -= extend_input_len + max_new_tokens
+        self.rem_input_tokens -= extend_input_len
+        if self.rem_chunk_tokens is not None:
+            self.rem_chunk_tokens -= extend_input_len
+
+        self.log_hit_tokens += prefix_len
+        self.log_input_tokens += extend_input_len
+
+    def add_inflight_req(self, req: Req):
+        truncated = req.extend_input_len > self.rem_chunk_tokens
+        req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
+        req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
+        self.can_run_list.append(req)
+
+        self._prefill_one_req(
+            len(req.prefix_indices),
+            req.extend_input_len,
+            (
+                min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS)
+                if not truncated
+                else 0
+            ),
+        )
+
+        # Return if chunked prefill not finished
+        return req if truncated else None
+
+    @contextmanager
+    def _lock_node(self, last_node: TreeNode):
+        try:
+            delta = self.tree_cache.inc_lock_ref(last_node)
+            self.rem_total_tokens += delta
+            yield None
+        finally:
+            delta = self.tree_cache.dec_lock_ref(last_node)
+            self.rem_total_tokens += delta
+
+    def add_one_req(self, req: Req):
+        total_tokens = req.extend_input_len + min(
+            req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS
+        )
+        input_tokens = req.extend_input_len
+        prefix_len = len(req.prefix_indices)
+
+        if total_tokens >= self.rem_total_tokens:
+            return False
+
+        if input_tokens > self.rem_input_tokens and len(self.can_run_list) != 0:
+            return False
+
+        with self._lock_node(req.last_node):
+            if total_tokens > self.rem_total_tokens:
+                return False
+
+            if (
+                self.rem_chunk_tokens is None
+                or input_tokens <= self.rem_chunk_tokens
+                or (req.return_logprob and req.normalized_prompt_logprob is None)
+            ):
+                # Non-chunked prefill
+                self.can_run_list.append(req)
+                self.tree_cache.inc_lock_ref(req.last_node)
+                self._prefill_one_req(
+                    prefix_len,
+                    input_tokens,
+                    min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS),
+                )
+            else:
+                # Chunked prefill
+                trunc_len = self.rem_chunk_tokens
+                if trunc_len == 0:
+                    return False
+
+                req.extend_input_len = trunc_len
+                req.fill_ids = req.fill_ids[: len(req.prefix_indices) + trunc_len]
+                self.can_run_list.append(req)
+                self.new_inflight_req = req
+                self.tree_cache.inc_lock_ref(req.last_node)
+                self._prefill_one_req(prefix_len, trunc_len, 0)
+
+        return True
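The comment on CLIP_MAX_NEW_TOKENS above describes how the scheduler clips its per-request token estimate. A standalone sketch of that arithmetic (a simplified re-statement, not the sglang implementation) showing the decode budget a running request reserves in PrefillAdder.remove_running_tokens:

# Default from the diff; overridable via the SGLANG_CLIP_MAX_NEW_TOKENS env var.
CLIP_MAX_NEW_TOKENS = 4096

def reserved_decode_budget(max_new_tokens: int, already_generated: int,
                           new_token_ratio: float) -> float:
    # Mirrors the per-request term inside PrefillAdder.remove_running_tokens:
    # the remaining new tokens are clipped before being scaled by new_token_ratio.
    return min(max_new_tokens - already_generated, CLIP_MAX_NEW_TOKENS) * new_token_ratio

# A request that asked for 100_000 new tokens only reserves the clipped estimate,
# so the scheduler does not become overly conservative about admitting new prefills.
print(reserved_decode_budget(100_000, 10, 0.5))  # 2048.0
print(reserved_decode_budget(256, 200, 0.5))     # 28.0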