sglang 0.2.11__py3-none-any.whl → 0.2.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (63)
  1. sglang/bench_latency.py +6 -4
  2. sglang/bench_serving.py +46 -22
  3. sglang/lang/compiler.py +2 -2
  4. sglang/lang/ir.py +3 -3
  5. sglang/srt/constrained/base_tool_cache.py +1 -1
  6. sglang/srt/constrained/fsm_cache.py +12 -2
  7. sglang/srt/layers/activation.py +33 -0
  8. sglang/srt/layers/{token_attention.py → decode_attention.py} +9 -5
  9. sglang/srt/layers/extend_attention.py +6 -1
  10. sglang/srt/layers/layernorm.py +65 -0
  11. sglang/srt/layers/logits_processor.py +5 -0
  12. sglang/srt/layers/pooler.py +50 -0
  13. sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py} +5 -0
  14. sglang/srt/layers/radix_attention.py +2 -2
  15. sglang/srt/managers/detokenizer_manager.py +31 -9
  16. sglang/srt/managers/io_struct.py +63 -0
  17. sglang/srt/managers/policy_scheduler.py +173 -25
  18. sglang/srt/managers/schedule_batch.py +110 -87
  19. sglang/srt/managers/tokenizer_manager.py +193 -111
  20. sglang/srt/managers/tp_worker.py +289 -352
  21. sglang/srt/mem_cache/{base_cache.py → base_prefix_cache.py} +9 -4
  22. sglang/srt/mem_cache/chunk_cache.py +43 -20
  23. sglang/srt/mem_cache/memory_pool.py +2 -2
  24. sglang/srt/mem_cache/radix_cache.py +74 -40
  25. sglang/srt/model_executor/cuda_graph_runner.py +24 -9
  26. sglang/srt/model_executor/forward_batch_info.py +168 -105
  27. sglang/srt/model_executor/model_runner.py +24 -37
  28. sglang/srt/models/gemma2.py +0 -1
  29. sglang/srt/models/internlm2.py +2 -7
  30. sglang/srt/models/llama2.py +4 -4
  31. sglang/srt/models/llama_embedding.py +88 -0
  32. sglang/srt/models/qwen2_moe.py +0 -11
  33. sglang/srt/openai_api/adapter.py +155 -27
  34. sglang/srt/openai_api/protocol.py +37 -1
  35. sglang/srt/sampling/penaltylib/__init__.py +13 -0
  36. sglang/srt/sampling/penaltylib/orchestrator.py +357 -0
  37. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +80 -0
  38. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +105 -0
  39. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +79 -0
  40. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +83 -0
  41. sglang/srt/sampling_params.py +31 -4
  42. sglang/srt/server.py +69 -15
  43. sglang/srt/server_args.py +26 -19
  44. sglang/srt/utils.py +31 -13
  45. sglang/test/run_eval.py +10 -1
  46. sglang/test/runners.py +63 -63
  47. sglang/test/simple_eval_humaneval.py +2 -8
  48. sglang/test/simple_eval_mgsm.py +203 -0
  49. sglang/test/srt/sampling/penaltylib/utils.py +337 -0
  50. sglang/test/test_layernorm.py +60 -0
  51. sglang/test/test_programs.py +4 -2
  52. sglang/test/test_utils.py +20 -2
  53. sglang/utils.py +0 -1
  54. sglang/version.py +1 -1
  55. {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/METADATA +23 -14
  56. sglang-0.2.12.dist-info/RECORD +112 -0
  57. sglang/srt/layers/linear.py +0 -884
  58. sglang/srt/layers/quantization/__init__.py +0 -64
  59. sglang/srt/layers/quantization/fp8.py +0 -677
  60. sglang-0.2.11.dist-info/RECORD +0 -102
  61. {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/LICENSE +0 -0
  62. {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/WHEEL +0 -0
  63. {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/{base_cache.py → base_prefix_cache.py}

@@ -1,4 +1,5 @@
 from abc import ABC, abstractmethod
+from typing import Callable


 class BasePrefixCache(ABC):
@@ -17,11 +18,15 @@ class BasePrefixCache(ABC):
         pass

     @abstractmethod
-    def cache_req(self, **kwargs):
+    def cache_finished_req(self, **kwargs):
         pass

     @abstractmethod
-    def evict(self, num_tokens, evict_callback):
+    def cache_unfinished_req(self, **kwargs):
+        pass
+
+    @abstractmethod
+    def evict(self, num_tokens: int, evict_callback: Callable):
         pass

     @abstractmethod
@@ -37,7 +42,7 @@ class BasePrefixCache(ABC):
         pass

     def total_size(self):
-        raise NotImplementedError
+        raise NotImplementedError()

     def pretty_print(self):
-        raise NotImplementedError
+        raise NotImplementedError()
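
The practical effect of this interface change shows up at call sites: the old cache_req(..., del_in_memory_pool=...) flag is replaced by two explicit methods. A minimal sketch of the new calling convention, assuming a scheduler that holds some BasePrefixCache implementation as tree_cache (the finished() check is illustrative, not part of this diff):

    # Hypothetical scheduler snippet; `req` follows the Req type used below.
    if req.finished():
        # Request is done: free its slot and drop the cache's reference.
        tree_cache.cache_finished_req(req)
    else:
        # Chunked prefill continues: keep the prefix cached and pinned.
        tree_cache.cache_unfinished_req(req)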
sglang/srt/mem_cache/chunk_cache.py

@@ -1,6 +1,14 @@
+from __future__ import annotations
+
 """Cache for chunked prefill, used when RadixCache is disabled."""

-from sglang.srt.mem_cache.base_cache import BasePrefixCache
+from typing import TYPE_CHECKING, Callable, List, Optional
+
+from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
+from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
+
+if TYPE_CHECKING:
+    from sglang.srt.managers.schedule_batch import Req


 class ChunkCacheEntry:
@@ -10,7 +18,9 @@ class ChunkCacheEntry:


 class ChunkCache(BasePrefixCache):
-    def __init__(self, req_to_token_pool, token_to_kv_pool):
+    def __init__(
+        self, req_to_token_pool: ReqToTokenPool, token_to_kv_pool: BaseTokenToKVPool
+    ):
         self.disable = True
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool = token_to_kv_pool
@@ -20,34 +30,47 @@ class ChunkCache(BasePrefixCache):
     def reset(self):
         self.entries = {}

-    def match_prefix(self, rid, **kwargs):
+    def match_prefix(self, rid: int, key: List[int]):
         if rid not in self.entries:
             return [], None

         entry = self.entries[rid]
-        return entry.value, entry
+        max_prefix_len = len(key)
+        return entry.value[:max_prefix_len], entry

-    def cache_req(
-        self, rid, token_ids, req_pool_idx, del_in_memory_pool=True, **kwargs
-    ):
-        indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
-        if del_in_memory_pool:
-            assert rid in self.entries
-            self.req_to_token_pool.free(req_pool_idx)
-            self.token_to_kv_pool.free(indices)
-            return
+    def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
+        if token_ids is None:
+            token_ids = (req.origin_input_ids + req.output_ids)[:-1]

-        if rid not in self.entries:
-            self.entries[rid] = ChunkCacheEntry(rid, indices)
+        kv_indices = self.req_to_token_pool.req_to_token[
+            req.req_pool_idx, : len(token_ids)
+        ]
+        self.req_to_token_pool.free(req.req_pool_idx)
+        self.token_to_kv_pool.free(kv_indices)

-        entry = self.entries[rid]
-        entry.value = indices
-        return indices, entry
+        if req.rid in self.entries:
+            del self.entries[req.rid]
+
+    def cache_unfinished_req(self, req: Req, token_ids: Optional[List[int]] = None):
+        if token_ids is None:
+            token_ids = req.fill_ids
+
+        kv_indices = self.req_to_token_pool.req_to_token[
+            req.req_pool_idx, : len(token_ids)
+        ]
+
+        if req.rid not in self.entries:
+            self.entries[req.rid] = ChunkCacheEntry(req.rid, kv_indices)
+
+        entry = self.entries[req.rid]
+        entry.value = kv_indices
+        req.prefix_indices = kv_indices
+        req.last_node = entry

     def insert(self):
-        raise NotImplementedError
+        raise NotImplementedError()

-    def evict(self, num_tokens, evict_callback):
+    def evict(self, num_tokens: int, evict_callback: Callable):
         pass

     def inc_lock_ref(self, node):
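
ChunkCache keys its entries by request id rather than by token prefix, so each chunk of a chunked prefill can reuse the KV indices written by the previous chunk. A hedged sketch of that flow, with the pool objects and Req fields assumed to match the annotations above:

    # Minimal usage sketch, assuming initialized pools and a Req `req`.
    cache = ChunkCache(req_to_token_pool, token_to_kv_pool)

    cache.cache_unfinished_req(req)   # store this chunk's KV indices under req.rid
    prefix, entry = cache.match_prefix(req.rid, req.fill_ids)  # reuse them next chunk
    cache.cache_finished_req(req)     # free both pools and drop the entry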
sglang/srt/mem_cache/memory_pool.py

@@ -16,7 +16,7 @@ limitations under the License.
 """Memory pool."""

 import logging
-from typing import List
+from typing import List, Union

 import torch

@@ -42,7 +42,7 @@ class ReqToTokenPool:

         return select_index

-    def free(self, free_index):
+    def free(self, free_index: Union[int, List[int]]):
         if isinstance(free_index, (int,)):
             self.free_slots.append(free_index)
         else:
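
The widened annotation documents behavior free() already had: it accepts either a single request slot index or a batch of them. For illustration:

    req_to_token_pool.free(3)          # return one slot to the pool
    req_to_token_pool.free([4, 5, 6])  # return several slots at once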
sglang/srt/mem_cache/radix_cache.py

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 """
 Copyright 2023-2024 SGLang Team
 Licensed under the Apache License, Version 2.0 (the "License");
@@ -20,10 +22,15 @@ The radix tree data structure for managing the KV cache.
 import heapq
 import time
 from collections import defaultdict
+from typing import TYPE_CHECKING, Callable, List, Optional

 import torch

-from sglang.srt.mem_cache.base_cache import BasePrefixCache
+from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
+from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
+
+if TYPE_CHECKING:
+    from sglang.srt.managers.schedule_batch import Req


 class TreeNode:
@@ -39,7 +46,7 @@ class TreeNode:
         return self.last_access_time < other.last_access_time


-def _key_match(key0, key1):
+def _key_match(key0: List, key1: List):
     i = 0
     for k0, k1 in zip(key0, key1):
         if k0 != k1:
@@ -49,7 +56,12 @@ def _key_match(key0, key1):


 class RadixCache(BasePrefixCache):
-    def __init__(self, req_to_token_pool, token_to_kv_pool, disable: bool = False):
+    def __init__(
+        self,
+        req_to_token_pool: ReqToTokenPool,
+        token_to_kv_pool: BaseTokenToKVPool,
+        disable: bool = False,
+    ):
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool = token_to_kv_pool
         self.disable = disable
@@ -64,7 +76,7 @@ class RadixCache(BasePrefixCache):
         self.root_node.lock_ref = 1
         self.evictable_size_ = 0

-    def match_prefix(self, key, **kwargs):
+    def match_prefix(self, key: List, **kwargs):
         if self.disable:
             return [], self.root_node

@@ -74,10 +86,10 @@ class RadixCache(BasePrefixCache):
         if value:
             value = torch.concat(value)
         else:
-            value = torch.tensor([], dtype=torch.int64)
+            value = torch.tensor([], dtype=torch.int32)
         return value, last_node[0]

-    def insert(self, key, value=None):
+    def insert(self, key: List, value=None):
         if self.disable:
             return 0

@@ -85,40 +97,54 @@ class RadixCache(BasePrefixCache):
             value = [x for x in key]
         return self._insert_helper(self.root_node, key, value)

-    def cache_req(
-        self,
-        token_ids,
-        last_uncached_pos,
-        req_pool_idx,
-        del_in_memory_pool=True,
-        old_last_node=None,
-        **kwargs,
-    ):
-        # Insert the request into radix cache
-        indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
-        new_prefix_len = self.insert(token_ids, indices.clone())
+    def cache_finished_req(self, req: Req, token_ids: Optional[List[int]] = None):
+        """Cache request when it finishes."""
+        if token_ids is None:
+            token_ids = (req.origin_input_ids + req.output_ids)[:-1]
+        kv_indices = self.req_to_token_pool.req_to_token[
+            req.req_pool_idx, : len(token_ids)
+        ]

         if self.disable:
-            if del_in_memory_pool:
-                self.token_to_kv_pool.free(indices)
-            else:
-                return torch.tensor([], dtype=torch.int64), self.root_node
+            self.token_to_kv_pool.free(kv_indices)
+            self.req_to_token_pool.free(req.req_pool_idx)
+            return

         # Radix Cache takes one ref in memory pool
-        self.token_to_kv_pool.free(indices[last_uncached_pos:new_prefix_len])
+        new_prefix_len = self.insert(token_ids, kv_indices.clone())
+        self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])

-        if del_in_memory_pool:
-            self.req_to_token_pool.free(req_pool_idx)
-        else:
-            cached_indices, new_last_node = self.match_prefix(token_ids)
-            assert len(cached_indices) == len(token_ids)
+        # Remove req slot release the cache lock
+        self.req_to_token_pool.free(req.req_pool_idx)
+        self.dec_lock_ref(req.last_node)
+
+    def cache_unfinished_req(self, req: Req, token_ids: Optional[List[int]] = None):
+        """Cache request when it is unfinished."""
+        if self.disable:
+            return

-            self.req_to_token_pool.req_to_token[
-                req_pool_idx, last_uncached_pos : len(cached_indices)
-            ] = cached_indices[last_uncached_pos:]
-            self.dec_lock_ref(old_last_node)
-            self.inc_lock_ref(new_last_node)
-            return cached_indices, new_last_node
+        if token_ids is None:
+            token_ids = req.fill_ids
+
+        kv_indices = self.req_to_token_pool.req_to_token[
+            req.req_pool_idx, : len(token_ids)
+        ]
+
+        # Radix Cache takes one ref in memory pool
+        new_prefix_len = self.insert(token_ids, kv_indices.clone())
+        self.token_to_kv_pool.free(kv_indices[len(req.prefix_indices) : new_prefix_len])
+
+        # The prefix indices could be updated, reuse it
+        new_indices, new_last_node = self.match_prefix(token_ids)
+        assert len(new_indices) == len(token_ids)
+        self.req_to_token_pool.req_to_token[
+            req.req_pool_idx, len(req.prefix_indices) : len(new_indices)
+        ] = new_indices[len(req.prefix_indices) :]
+
+        self.dec_lock_ref(req.last_node)
+        self.inc_lock_ref(new_last_node)
+        req.prefix_indices = new_indices
+        req.last_node = new_last_node

     def pretty_print(self):
         self._print_helper(self.root_node, 0)
@@ -127,7 +153,7 @@ class RadixCache(BasePrefixCache):
     def total_size(self):
         return self._total_size_helper(self.root_node)

-    def evict(self, num_tokens, evict_callback):
+    def evict(self, num_tokens: int, evict_callback: Callable):
         if self.disable:
             return

@@ -151,6 +177,9 @@ class RadixCache(BasePrefixCache):
                 heapq.heappush(leaves, x.parent)

     def inc_lock_ref(self, node: TreeNode):
+        if self.disable:
+            return 0
+
         delta = 0
         while node != self.root_node:
             if node.lock_ref == 0:
@@ -161,6 +190,9 @@ class RadixCache(BasePrefixCache):
         return delta

     def dec_lock_ref(self, node: TreeNode):
+        if self.disable:
+            return 0
+
         delta = 0
         while node != self.root_node:
             if node.lock_ref == 1:
@@ -175,7 +207,9 @@ class RadixCache(BasePrefixCache):

     ##### Internal Helper Functions #####

-    def _match_prefix_helper(self, node, key, value, last_node):
+    def _match_prefix_helper(
+        self, node: TreeNode, key: List, value, last_node: TreeNode
+    ):
         node.last_access_time = time.time()
         if len(key) == 0:
             return
@@ -192,7 +226,7 @@ class RadixCache(BasePrefixCache):
                 last_node[0] = child
                 self._match_prefix_helper(child, key[prefix_len:], value, last_node)

-    def _split_node(self, key, child: TreeNode, split_len):
+    def _split_node(self, key, child: TreeNode, split_len: int):
         # new_node -> child
         new_node = TreeNode()
         new_node.children = {key[split_len:][0]: child}
@@ -206,7 +240,7 @@ class RadixCache(BasePrefixCache):
         new_node.parent.children[key[:split_len][0]] = new_node
         return new_node

-    def _insert_helper(self, node, key, value):
+    def _insert_helper(self, node: TreeNode, key: List, value):
         node.last_access_time = time.time()
         if len(key) == 0:
             return 0
@@ -237,7 +271,7 @@ class RadixCache(BasePrefixCache):
             self.evictable_size_ += len(value)
             return 0

-    def _print_helper(self, node: TreeNode, indent):
+    def _print_helper(self, node: TreeNode, indent: int):
         for _, child in node.children.items():
             print(" " * indent, len(child.key), child.key[:10], f"r={child.lock_ref}")
             self._print_helper(child, indent=indent + 2)
@@ -249,7 +283,7 @@ class RadixCache(BasePrefixCache):
             del node.parent.children[k]
         self.evictable_size_ -= len(node.key)

-    def _total_size_helper(self, node):
+    def _total_size_helper(self, node: TreeNode):
         x = len(node.value)
         for child in node.children.values():
             x += self._total_size_helper(child)
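
RadixCache mirrors the ChunkCache flow but adds per-node reference counting: an unfinished request keeps its prefix path pinned so evict() cannot reclaim it, and cache_finished_req releases the pin. A hedged lifecycle sketch built only from the methods in this diff (pool objects and `req` are assumed initialized as annotated above):

    cache = RadixCache(req_to_token_pool, token_to_kv_pool, disable=False)

    # Before prefill: find the longest cached prefix and pin it.
    prefix, last_node = cache.match_prefix(key=req.fill_ids)
    cache.inc_lock_ref(last_node)

    # After each prefill chunk: insert new tokens, re-match, and re-pin.
    cache.cache_unfinished_req(req)

    # After the final token: insert, free the req slot, and unpin.
    cache.cache_finished_req(req)

    # Under memory pressure, unlocked leaves are reclaimed least-recently-used
    # first (TreeNode orders by last_access_time in the eviction heap).
    cache.evict(num_tokens=512, evict_callback=token_to_kv_pool.free)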
sglang/srt/model_executor/cuda_graph_runner.py

@@ -33,7 +33,7 @@ from sglang.srt.managers.schedule_batch import ScheduleBatch
 from sglang.srt.model_executor.forward_batch_info import (
     ForwardMode,
     InputMetadata,
-    init_flashinfer_args,
+    update_flashinfer_indices,
 )
 from sglang.srt.utils import monkey_patch_vllm_all_gather

@@ -71,6 +71,18 @@ def patch_model(
         tp_group.ca_comm = backup_ca_comm


+def set_torch_compile_config():
+    import torch._dynamo.config
+    import torch._inductor.config
+
+    torch._inductor.config.coordinate_descent_tuning = True
+    torch._inductor.config.triton.unique_kernel_names = True
+    torch._inductor.config.fx_graph_cache = True  # Experimental feature to reduce compilation times, will be on by default in future
+
+    # FIXME: tmp workaround
+    torch._dynamo.config.accumulated_cache_size_limit = 1024
+
+
 class CudaGraphRunner:
     def __init__(self, model_runner, max_batch_size_to_capture, use_torch_compile):
         self.model_runner = model_runner
@@ -112,6 +124,9 @@ class CudaGraphRunner:

         self.compile_bs = [1, 2, 4, 8, 16, 24, 32] if use_torch_compile else []

+        if use_torch_compile:
+            set_torch_compile_config()
+
     def can_run(self, batch_size):
         return batch_size < self.max_bs

@@ -165,7 +180,7 @@ class CudaGraphRunner:
             paged_kv_indices_buffer=self.flashinfer_kv_indices,
             paged_kv_last_page_len_buffer=self.flashinfer_kv_last_page_len[:bs],
         )
-        init_flashinfer_args(
+        update_flashinfer_indices(
             ForwardMode.DECODE,
             self.model_runner,
             req_pool_indices,
@@ -176,19 +191,19 @@ class CudaGraphRunner:

         # Run and capture
         def run_once():
-            input_metadata = InputMetadata.create(
-                self.model_runner,
+            input_metadata = InputMetadata(
                 forward_mode=ForwardMode.DECODE,
+                batch_size=bs,
                 req_pool_indices=req_pool_indices,
                 seq_lens=seq_lens,
-                prefix_lens=None,
-                position_ids_offsets=position_ids_offsets,
+                req_to_token_pool=self.model_runner.req_to_token_pool,
+                token_to_kv_pool=self.model_runner.token_to_kv_pool,
                 out_cache_loc=out_cache_loc,
                 return_logprob=False,
                 top_logprobs_nums=0,
-                skip_flashinfer_init=True,
+                positions=(seq_lens - 1).to(torch.int64),
+                flashinfer_decode_wrapper=flashinfer_decode_wrapper,
             )
-            input_metadata.flashinfer_decode_wrapper = flashinfer_decode_wrapper

             return forward(input_ids, input_metadata.positions, input_metadata)

@@ -222,7 +237,7 @@ class CudaGraphRunner:
         self.out_cache_loc[:raw_bs] = batch.out_cache_loc

         # FlashInfer inputs
-        init_flashinfer_args(
+        update_flashinfer_indices(
             ForwardMode.DECODE,
             self.model_runner,
             self.req_pool_indices[:bs],
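
The new set_torch_compile_config() gathers the inductor/dynamo settings in one place and is applied only when compilation is requested. A minimal construction sketch under that flag (model_runner is assumed to be an already-initialized model runner; the batch sizes are illustrative):

    runner = CudaGraphRunner(
        model_runner,
        max_batch_size_to_capture=32,
        use_torch_compile=True,  # triggers set_torch_compile_config()
    )

    # Decode batches below the captured maximum can replay a CUDA graph;
    # larger batches fall back to a normal forward pass.
    if runner.can_run(batch_size=8):
        ...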