sglang 0.2.7__py3-none-any.whl → 0.2.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
sglang/bench_serving.py CHANGED
@@ -21,7 +21,7 @@ import sys
21
21
  import time
22
22
  import traceback
23
23
  import warnings
24
- from argparse import ArgumentParser as FlexibleArgumentParser
24
+ from argparse import ArgumentParser
25
25
  from dataclasses import dataclass, field
26
26
  from datetime import datetime
27
27
  from typing import AsyncGenerator, List, Optional, Tuple, Union
@@ -868,14 +868,12 @@ def set_ulimit(target_soft_limit=65535):
868
868
 
869
869
 
870
870
  if __name__ == "__main__":
871
- parser = FlexibleArgumentParser(
872
- description="Benchmark the online serving throughput."
873
- )
871
+ parser = ArgumentParser(description="Benchmark the online serving throughput.")
874
872
  parser.add_argument(
875
873
  "--backend",
876
874
  type=str,
877
- required=True,
878
875
  choices=list(ASYNC_REQUEST_FUNCS.keys()),
876
+ default="sglang",
879
877
  help="Must specify a backend, depending on the LLM Inference Engine.",
880
878
  )
881
879
  parser.add_argument(
@@ -553,7 +553,8 @@ class StreamExecutor:
553
553
  "output_token_logprobs": output_token_logprobs,
554
554
  }
555
555
  self.variable_event[name].set()
556
- self.stream_var_event[name].set()
556
+ if self.stream_var_event:
557
+ self.stream_var_event[name].set()
557
558
  self.text_ += decision
558
559
 
559
560
  def _execute_variable(self, expr: SglVariable):
sglang/lang/ir.py CHANGED
@@ -99,7 +99,6 @@ class SglSamplingParams:
99
99
  "stop": self.stop or None,
100
100
  "temperature": self.temperature,
101
101
  "top_p": self.top_p,
102
- "top_k": self.top_k,
103
102
  "frequency_penalty": self.frequency_penalty,
104
103
  "presence_penalty": self.presence_penalty,
105
104
  }
@@ -13,12 +13,12 @@ See the License for the specific language governing permissions and
13
13
  limitations under the License.
14
14
  """
15
15
 
16
- """Base cache class."""
16
+ """Base tool cache for constrained decoding tools."""
17
17
 
18
18
  import time
19
19
 
20
20
 
21
- class BaseCache:
21
+ class BaseToolCache:
22
22
  def __init__(self, enable=True):
23
23
  self.enable = enable
24
24
  self.reset()
@@ -16,10 +16,10 @@ limitations under the License.
16
16
  """Cache for the compressed finite state machine."""
17
17
 
18
18
  from sglang.srt.constrained import RegexGuide, TransformerTokenizer
19
- from sglang.srt.constrained.base_cache import BaseCache
19
+ from sglang.srt.constrained.base_tool_cache import BaseToolCache
20
20
 
21
21
 
22
- class FSMCache(BaseCache):
22
+ class FSMCache(BaseToolCache):
23
23
  def __init__(self, tokenizer_path, tokenizer_args_dict, enable=True):
24
24
  super().__init__(enable=enable)
25
25
 
@@ -30,7 +30,7 @@ from sglang.srt.constrained import (
30
30
  make_byte_level_fsm,
31
31
  make_deterministic_fsm,
32
32
  )
33
- from sglang.srt.constrained.base_cache import BaseCache
33
+ from sglang.srt.constrained.base_tool_cache import BaseToolCache
34
34
 
35
35
  IP_REGEX = r"((25[0-5]|2[0-4]\d|[01]?\d\d?)\.){3}(25[0-5]|2[0-4]\d|[01]?\d\d?)"
36
36
 
@@ -151,7 +151,7 @@ class JumpForwardMap:
151
151
  )
152
152
 
153
153
 
154
- class JumpForwardCache(BaseCache):
154
+ class JumpForwardCache(BaseToolCache):
155
155
  def __init__(self):
156
156
  super().__init__()
157
157
 
@@ -209,7 +209,7 @@ class LogitsProcessor(nn.Module):
209
209
  all_logits = all_logits[:, : self.config.vocab_size].float()
210
210
 
211
211
  all_logprobs = all_logits
212
- del all_logits
212
+ del all_logits, hidden_states
213
213
  all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
214
214
 
215
215
  # Get the logprob of top-k tokens
@@ -28,6 +28,7 @@ from flashinfer.sampling import top_k_top_p_sampling_from_probs
28
28
  from sglang.global_config import global_config
29
29
  from sglang.srt.constrained import RegexGuide
30
30
  from sglang.srt.constrained.jump_forward import JumpForwardMap
31
+ from sglang.srt.mem_cache.chunk_cache import ChunkCache
31
32
  from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPool
32
33
  from sglang.srt.mem_cache.radix_cache import RadixCache
33
34
 
@@ -486,15 +487,33 @@ class Batch:
486
487
  req = self.reqs[idx]
487
488
  retracted_reqs.append(req)
488
489
 
489
- # TODO: apply more fine-grained retraction
490
- last_uncached_pos = len(req.prefix_indices)
491
- token_indices = self.req_to_token_pool.req_to_token[
492
- req_pool_indices_cpu[idx]
493
- ][last_uncached_pos : seq_lens_cpu[idx]]
494
- self.token_to_kv_pool.free(token_indices)
495
-
496
- # release the last node
497
- self.tree_cache.dec_lock_ref(req.last_node)
490
+ if isinstance(self.tree_cache, ChunkCache):
491
+ # ChunkCache does not have eviction
492
+ token_indices = self.req_to_token_pool.req_to_token[
493
+ req_pool_indices_cpu[idx]
494
+ ][: seq_lens_cpu[idx]]
495
+ self.token_to_kv_pool.free(token_indices)
496
+ self.req_to_token_pool.free(int(req_pool_indices_cpu[idx]))
497
+ del self.tree_cache.entries[req.rid]
498
+ else:
499
+ # TODO: apply more fine-grained retraction
500
+ last_uncached_pos = len(req.prefix_indices)
501
+ token_indices = self.req_to_token_pool.req_to_token[
502
+ req_pool_indices_cpu[idx]
503
+ ][last_uncached_pos : seq_lens_cpu[idx]]
504
+ self.token_to_kv_pool.free(token_indices)
505
+ self.req_to_token_pool.free(int(req_pool_indices_cpu[idx]))
506
+
507
+ # release the last node
508
+ self.tree_cache.dec_lock_ref(req.last_node)
509
+
510
+ # NOTE(lsyin): we should use the newly evictable memory instantly.
511
+ residual_size = (
512
+ len(sorted_indices) * global_config.retract_decode_steps
513
+ - self.token_to_kv_pool.available_size()
514
+ )
515
+ residual_size = max(0, residual_size)
516
+ self.tree_cache.evict(residual_size, self.token_to_kv_pool.free)
498
517
 
499
518
  req.prefix_indices = None
500
519
  req.last_node = None
@@ -575,6 +594,7 @@ class Batch:
575
594
  if req_pool_indices_cpu is None:
576
595
  req_pool_indices_cpu = self.req_pool_indices.tolist()
577
596
  self.tree_cache.cache_req(
597
+ rid=req.rid,
578
598
  token_ids=cur_all_ids,
579
599
  last_uncached_pos=len(req.prefix_indices),
580
600
  req_pool_idx=req_pool_indices_cpu[i],
@@ -79,6 +79,7 @@ class TokenizerManager:
79
79
  self.send_to_router.connect(f"tcp://127.0.0.1:{port_args.controller_port}")
80
80
 
81
81
  self.model_path = server_args.model_path
82
+ self.served_model_name = server_args.served_model_name
82
83
  self.hf_config = get_config(
83
84
  self.model_path,
84
85
  trust_remote_code=server_args.trust_remote_code,
@@ -43,6 +43,7 @@ from sglang.srt.managers.schedule_batch import (
43
43
  ForwardMode,
44
44
  Req,
45
45
  )
46
+ from sglang.srt.mem_cache.chunk_cache import ChunkCache
46
47
  from sglang.srt.mem_cache.radix_cache import RadixCache
47
48
  from sglang.srt.model_config import ModelConfig
48
49
  from sglang.srt.model_executor.model_runner import ModelRunner
@@ -144,11 +145,20 @@ class ModelTpServer:
144
145
  )
145
146
 
146
147
  # Init cache
147
- self.tree_cache = RadixCache(
148
- req_to_token_pool=self.model_runner.req_to_token_pool,
149
- token_to_kv_pool=self.model_runner.token_to_kv_pool,
150
- disable=server_args.disable_radix_cache,
151
- )
148
+ if (
149
+ server_args.chunked_prefill_size is not None
150
+ and server_args.disable_radix_cache
151
+ ):
152
+ self.tree_cache = ChunkCache(
153
+ req_to_token_pool=self.model_runner.req_to_token_pool,
154
+ token_to_kv_pool=self.model_runner.token_to_kv_pool,
155
+ )
156
+ else:
157
+ self.tree_cache = RadixCache(
158
+ req_to_token_pool=self.model_runner.req_to_token_pool,
159
+ token_to_kv_pool=self.model_runner.token_to_kv_pool,
160
+ disable=server_args.disable_radix_cache,
161
+ )
152
162
  self.tree_cache_metrics = {"total": 0, "hit": 0}
153
163
  self.scheduler = PolicyScheduler(
154
164
  self.schedule_policy,
@@ -280,6 +290,14 @@ class ModelTpServer:
280
290
  "KV cache pool leak detected!"
281
291
  )
282
292
 
293
+ if self.req_to_token_pool.can_use_mem_size != self.req_to_token_pool.size:
294
+ warnings.warn(
295
+ "Warning: "
296
+ f"available req slots={self.req_to_token_pool.can_use_mem_size}, "
297
+ f"total slots={self.req_to_token_pool.size}\n"
298
+ "Memory pool leak detected!"
299
+ )
300
+
283
301
  def handle_generate_request(
284
302
  self,
285
303
  recv_req: TokenizedGenerateReqInput,
@@ -346,7 +364,10 @@ class ModelTpServer:
346
364
  # Compute matched prefix length
347
365
  for req in self.waiting_queue:
348
366
  req.input_ids = req.origin_input_ids + req.output_ids
349
- prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
367
+ prefix_indices, last_node = self.tree_cache.match_prefix(
368
+ rid=req.rid,
369
+ key=req.input_ids,
370
+ )
350
371
  if req.return_logprob:
351
372
  prefix_indices = prefix_indices[: req.logprob_start_len]
352
373
  req.extend_input_len = len(req.input_ids) - len(prefix_indices)
@@ -606,6 +627,7 @@ class ModelTpServer:
606
627
  req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
607
628
  for i, req in enumerate(batch.reqs):
608
629
  new_prefix_indices, new_last_node = self.tree_cache.cache_req(
630
+ rid=req.rid,
609
631
  token_ids=tuple(req.input_ids),
610
632
  last_uncached_pos=len(req.prefix_indices),
611
633
  req_pool_idx=req_pool_indices_cpu[i],
@@ -763,6 +785,7 @@ class ModelTpServer:
763
785
  for i in finished_indices:
764
786
  req = batch.reqs[i]
765
787
  self.tree_cache.cache_req(
788
+ rid=req.rid,
766
789
  token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
767
790
  last_uncached_pos=len(req.prefix_indices),
768
791
  req_pool_idx=req_pool_indices_cpu[i],
@@ -0,0 +1,43 @@
1
+ from abc import ABC, abstractmethod
2
+
3
+
4
+ class BasePrefixCache(ABC):
5
+ """Cache can be indexed by either rid or key."""
6
+
7
+ @abstractmethod
8
+ def reset(self):
9
+ pass
10
+
11
+ @abstractmethod
12
+ def match_prefix(self, **kwargs):
13
+ pass
14
+
15
+ @abstractmethod
16
+ def insert(self, **kwargs):
17
+ pass
18
+
19
+ @abstractmethod
20
+ def cache_req(self, **kwargs):
21
+ pass
22
+
23
+ @abstractmethod
24
+ def evict(self, num_tokens, evict_callback):
25
+ pass
26
+
27
+ @abstractmethod
28
+ def inc_lock_ref(self, node):
29
+ pass
30
+
31
+ @abstractmethod
32
+ def dec_lock_ref(self, node):
33
+ pass
34
+
35
+ @abstractmethod
36
+ def evictable_size(self):
37
+ pass
38
+
39
+ def total_size(self):
40
+ raise NotImplementedError
41
+
42
+ def pretty_print(self):
43
+ raise NotImplementedError
@@ -0,0 +1,60 @@
1
+ """Cache for chunked prefill, used when RadixCache is disabled."""
2
+
3
+ from sglang.srt.mem_cache.base_cache import BasePrefixCache
4
+
5
+
6
+ class ChunkCacheEntry:
7
+ def __init__(self, rid, value):
8
+ self.rid = rid
9
+ self.value = value
10
+
11
+
12
+ class ChunkCache(BasePrefixCache):
13
+ def __init__(self, req_to_token_pool, token_to_kv_pool):
14
+ self.disable = True
15
+ self.req_to_token_pool = req_to_token_pool
16
+ self.token_to_kv_pool = token_to_kv_pool
17
+
18
+ self.reset()
19
+
20
+ def reset(self):
21
+ self.entries = {}
22
+
23
+ def match_prefix(self, rid, **kwargs):
24
+ if rid not in self.entries:
25
+ return [], None
26
+
27
+ entry = self.entries[rid]
28
+ return entry.value, entry
29
+
30
+ def cache_req(
31
+ self, rid, token_ids, req_pool_idx, del_in_memory_pool=True, **kwargs
32
+ ):
33
+ indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
34
+ if del_in_memory_pool:
35
+ assert rid in self.entries
36
+ self.req_to_token_pool.free(req_pool_idx)
37
+ self.token_to_kv_pool.free(indices)
38
+ return
39
+
40
+ if rid not in self.entries:
41
+ self.entries[rid] = ChunkCacheEntry(rid, indices)
42
+
43
+ entry = self.entries[rid]
44
+ entry.value = indices
45
+ return indices, entry
46
+
47
+ def insert(self):
48
+ raise NotImplementedError
49
+
50
+ def evict(self, num_tokens, evict_callback):
51
+ pass
52
+
53
+ def inc_lock_ref(self, node):
54
+ return 0
55
+
56
+ def dec_lock_ref(self, node):
57
+ return 0
58
+
59
+ def evictable_size(self):
60
+ return 0
@@ -23,6 +23,8 @@ from collections import defaultdict
23
23
 
24
24
  import torch
25
25
 
26
+ from sglang.srt.mem_cache.base_cache import BasePrefixCache
27
+
26
28
 
27
29
  class TreeNode:
28
30
  def __init__(self):
@@ -46,7 +48,7 @@ def _key_match(key0, key1):
46
48
  return i
47
49
 
48
50
 
49
- class RadixCache:
51
+ class RadixCache(BasePrefixCache):
50
52
  def __init__(self, req_to_token_pool, token_to_kv_pool, disable: bool = False):
51
53
  self.req_to_token_pool = req_to_token_pool
52
54
  self.token_to_kv_pool = token_to_kv_pool
@@ -62,7 +64,7 @@ class RadixCache:
62
64
  self.root_node.lock_ref = 1
63
65
  self.evictable_size_ = 0
64
66
 
65
- def match_prefix(self, key):
67
+ def match_prefix(self, key, **kwargs):
66
68
  if self.disable:
67
69
  return [], self.root_node
68
70
 
@@ -90,6 +92,7 @@ class RadixCache:
90
92
  req_pool_idx,
91
93
  del_in_memory_pool=True,
92
94
  old_last_node=None,
95
+ **kwargs,
93
96
  ):
94
97
  # Insert the request into radix cache
95
98
  indices = self.req_to_token_pool.req_to_token[req_pool_idx, : len(token_ids)]
@@ -19,6 +19,7 @@ import importlib
19
19
  import importlib.resources
20
20
  import logging
21
21
  import pkgutil
22
+ import warnings
22
23
  from functools import lru_cache
23
24
  from typing import Optional, Type
24
25
 
@@ -121,7 +122,11 @@ class ModelRunner:
121
122
 
122
123
  # Load the model and create memory pool
123
124
  self.load_model()
124
- self.init_memory_pool(total_gpu_memory, server_args.max_num_reqs)
125
+ self.init_memory_pool(
126
+ total_gpu_memory,
127
+ server_args.max_num_reqs,
128
+ server_args.max_total_tokens,
129
+ )
125
130
  self.init_cublas()
126
131
  self.init_flash_infer()
127
132
 
@@ -203,8 +208,18 @@ class ModelRunner:
203
208
  max_num_token = int(rest_memory * (1 << 30) // cell_size)
204
209
  return max_num_token
205
210
 
206
- def init_memory_pool(self, total_gpu_memory, max_num_reqs=None):
211
+ def init_memory_pool(
212
+ self, total_gpu_memory, max_num_reqs=None, max_total_tokens=None
213
+ ):
207
214
  self.max_total_num_tokens = self.profile_max_num_token(total_gpu_memory)
215
+ if max_total_tokens is not None:
216
+ if max_total_tokens > self.max_total_num_tokens:
217
+ warnings.warn(
218
+ f"max_total_tokens={max_total_tokens} is larger than the profiled value "
219
+ f"{self.max_total_num_tokens}. "
220
+ f"Use the profiled value instead."
221
+ )
222
+ self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens)
208
223
 
209
224
  if self.max_total_num_tokens <= 0:
210
225
  raise RuntimeError(
@@ -26,6 +26,11 @@ from vllm.config import CacheConfig
26
26
  from vllm.distributed import get_tensor_model_parallel_world_size
27
27
  from vllm.model_executor.layers.activation import SiluAndMul
28
28
  from vllm.model_executor.layers.layernorm import RMSNorm
29
+ from vllm.model_executor.layers.linear import (
30
+ MergedColumnParallelLinear,
31
+ QKVParallelLinear,
32
+ RowParallelLinear,
33
+ )
29
34
  from vllm.model_executor.layers.quantization.base_config import QuantizationConfig
30
35
  from vllm.model_executor.layers.rotary_embedding import get_rope
31
36
  from vllm.model_executor.layers.vocab_parallel_embedding import (
@@ -38,10 +43,6 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
38
43
  from sglang.srt.layers.radix_attention import RadixAttention
39
44
  from sglang.srt.model_executor.model_runner import InputMetadata
40
45
 
41
- MergedColumnParallelLinear = None
42
- QKVParallelLinear = None
43
- RowParallelLinear = None
44
-
45
46
 
46
47
  class LlamaMLP(nn.Module):
47
48
  def __init__(
@@ -295,23 +296,6 @@ class LlamaForCausalLM(nn.Module):
295
296
  cache_config: Optional[CacheConfig] = None,
296
297
  efficient_weight_load=False,
297
298
  ) -> None:
298
- global MergedColumnParallelLinear
299
- global QKVParallelLinear
300
- global RowParallelLinear
301
-
302
- if efficient_weight_load:
303
- from sglang.srt.layers.linear import (
304
- MergedColumnParallelLinear,
305
- QKVParallelLinear,
306
- RowParallelLinear,
307
- )
308
- else:
309
- from vllm.model_executor.layers.linear import (
310
- MergedColumnParallelLinear,
311
- QKVParallelLinear,
312
- RowParallelLinear,
313
- )
314
-
315
299
  super().__init__()
316
300
  self.config = config
317
301
  self.quant_config = quant_config