sglang 0.1.13__py3-none-any.whl → 0.1.15__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (57)
  1. sglang/__init__.py +55 -2
  2. sglang/api.py +3 -5
  3. sglang/backend/anthropic.py +33 -13
  4. sglang/backend/openai.py +2 -1
  5. sglang/backend/runtime_endpoint.py +18 -5
  6. sglang/backend/vertexai.py +1 -0
  7. sglang/global_config.py +1 -0
  8. sglang/lang/chat_template.py +74 -0
  9. sglang/lang/interpreter.py +40 -16
  10. sglang/lang/ir.py +1 -1
  11. sglang/lang/tracer.py +6 -4
  12. sglang/launch_server.py +2 -1
  13. sglang/srt/constrained/fsm_cache.py +15 -3
  14. sglang/srt/constrained/jump_forward.py +1 -0
  15. sglang/srt/conversation.py +2 -2
  16. sglang/srt/hf_transformers_utils.py +2 -1
  17. sglang/srt/layers/context_flashattention_nopad.py +1 -0
  18. sglang/srt/layers/extend_attention.py +1 -0
  19. sglang/srt/layers/logits_processor.py +114 -54
  20. sglang/srt/layers/radix_attention.py +2 -1
  21. sglang/srt/layers/token_attention.py +1 -0
  22. sglang/srt/managers/detokenizer_manager.py +5 -1
  23. sglang/srt/managers/io_struct.py +12 -0
  24. sglang/srt/managers/router/infer_batch.py +70 -33
  25. sglang/srt/managers/router/manager.py +7 -2
  26. sglang/srt/managers/router/model_rpc.py +116 -73
  27. sglang/srt/managers/router/model_runner.py +121 -155
  28. sglang/srt/managers/router/radix_cache.py +46 -38
  29. sglang/srt/managers/tokenizer_manager.py +56 -11
  30. sglang/srt/memory_pool.py +5 -14
  31. sglang/srt/model_config.py +7 -0
  32. sglang/srt/models/commandr.py +376 -0
  33. sglang/srt/models/dbrx.py +413 -0
  34. sglang/srt/models/dbrx_config.py +281 -0
  35. sglang/srt/models/gemma.py +22 -20
  36. sglang/srt/models/llama2.py +23 -21
  37. sglang/srt/models/llava.py +12 -10
  38. sglang/srt/models/mixtral.py +27 -25
  39. sglang/srt/models/qwen.py +23 -21
  40. sglang/srt/models/qwen2.py +23 -21
  41. sglang/srt/models/stablelm.py +292 -0
  42. sglang/srt/models/yivl.py +6 -5
  43. sglang/srt/openai_api_adapter.py +356 -0
  44. sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +36 -20
  45. sglang/srt/sampling_params.py +2 -0
  46. sglang/srt/server.py +68 -439
  47. sglang/srt/server_args.py +76 -49
  48. sglang/srt/utils.py +88 -32
  49. sglang/srt/weight_utils.py +402 -0
  50. sglang/test/test_programs.py +8 -7
  51. sglang/test/test_utils.py +196 -8
  52. {sglang-0.1.13.dist-info → sglang-0.1.15.dist-info}/METADATA +13 -15
  53. sglang-0.1.15.dist-info/RECORD +69 -0
  54. {sglang-0.1.13.dist-info → sglang-0.1.15.dist-info}/WHEEL +1 -1
  55. sglang-0.1.13.dist-info/RECORD +0 -63
  56. {sglang-0.1.13.dist-info → sglang-0.1.15.dist-info}/LICENSE +0 -0
  57. {sglang-0.1.13.dist-info → sglang-0.1.15.dist-info}/top_level.txt +0 -0
sglang/srt/layers/logits_processor.py

@@ -1,11 +1,12 @@
 import torch
-from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
 from torch import nn
-from vllm.model_executor.parallel_utils.communication_op import (
+from vllm.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_gather,
 )

+from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
+

 class LogitsProcessor(nn.Module):
     def __init__(self, config):
@@ -13,76 +14,136 @@ class LogitsProcessor(nn.Module):
         self.config = config
         self.tp_size = get_tensor_model_parallel_world_size()

-    def forward(self, input_ids, hidden_states, weight, input_metadata):
-        last_index = None
+    def _get_normalized_prompt_logprobs(
+        self, prefill_token_logprobs, input_metadata: InputMetadata
+    ):
+        logprobs_cumsum = torch.cumsum(
+            prefill_token_logprobs, dim=0, dtype=torch.float32
+        )

-        # Compute the last index (the first decode token) of each requeast
-        # if we are in prefill or extend mode.
+        start = input_metadata.extend_start_loc.clone()
+        end = start + input_metadata.extend_seq_lens - 2
+        start.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
+        end.clamp_(min=0, max=prefill_token_logprobs.shape[0] - 1)
+        sum_logp = (
+            logprobs_cumsum[end]
+            - logprobs_cumsum[start]
+            + prefill_token_logprobs[start]
+        )
+        normalized_prompt_logprobs = sum_logp / (
+            (input_metadata.extend_seq_lens - 1).clamp(min=1)
+        )
+
+        return normalized_prompt_logprobs
+
+    def _get_top_logprobs(self, all_logprobs, input_metadata: InputMetadata):
+        if input_metadata.forward_mode == ForwardMode.DECODE:
+            decode_top_logprobs = []
+            for i in range(all_logprobs.shape[0]):
+                k = input_metadata.top_logprobs_nums[i]
+                t = all_logprobs[i].topk(k)
+                v_cpu = t.values.tolist()
+                p_cpu = t.indices.tolist()
+                decode_top_logprobs.append(list(zip(v_cpu, p_cpu)))
+            return None, decode_top_logprobs
+        else:
+            prefill_top_logprobs, decode_top_logprobs = [], []
+            pt = 0
+            # NOTE: the GPU-CPU overhead can be reduced
+            extend_seq_lens_cpu = input_metadata.extend_seq_lens.cpu().numpy()
+            for i in range(len(extend_seq_lens_cpu)):
+                if extend_seq_lens_cpu[i] == 0:
+                    prefill_top_logprobs.append([])
+                    decode_top_logprobs.append([])
+                    continue
+                k = input_metadata.top_logprobs_nums[i]
+                t = all_logprobs[pt : pt + extend_seq_lens_cpu[i]].topk(k)
+                vs_cpu = t.values.tolist()
+                ps_cpu = t.indices.tolist()
+                prefill_top_logprobs.append(
+                    [list(zip(vs_cpu[j], ps_cpu[j])) for j in range(len(vs_cpu) - 1)]
+                )
+                decode_top_logprobs.append(list(zip(vs_cpu[-1], ps_cpu[-1])))
+                pt += extend_seq_lens_cpu[i]
+            return prefill_top_logprobs, decode_top_logprobs
+
+    def forward(self, input_ids, hidden_states, weight, input_metadata: InputMetadata):
+        # Get last index for next token prediction, except for DECODE mode.
+        last_index = None
         if input_metadata.forward_mode != ForwardMode.DECODE:
             last_index = (
-                torch.cumsum(
-                    input_metadata.seq_lens - input_metadata.prefix_lens,
-                    dim=0,
-                    dtype=torch.long,
-                )
+                torch.cumsum(input_metadata.extend_seq_lens, dim=0, dtype=torch.long)
                 - 1
             )

+        # Get the last hidden states and last logits
+        if input_metadata.forward_mode == ForwardMode.DECODE:
+            last_hidden = hidden_states
+        else:
+            last_hidden = hidden_states[last_index]
+
+        last_logits = torch.matmul(last_hidden, weight.T)
+        if self.tp_size > 1:
+            last_logits = tensor_model_parallel_all_gather(last_logits)
+        last_logits = last_logits[:, : self.config.vocab_size]
+
+        # Return only last_logits if logprob is not requested
         if not input_metadata.return_logprob:
-            # When logprob is not requested, only compute the last logits.
-            if input_metadata.forward_mode == ForwardMode.DECODE:
-                last_hidden = hidden_states
-            else:
-                last_hidden = hidden_states[last_index]
-                hidden_states = None
-
-            last_logits = torch.matmul(last_hidden, weight.T)
-            if self.tp_size > 1:
-                last_logits = tensor_model_parallel_all_gather(last_logits)
-            last_logits = last_logits[:, : self.config.vocab_size]
-            return last_logits, (None, None, None)
+            hidden_states = None
+            return last_logits, (None, None, None, None, None)
         else:
             # When logprob is requested, compute the logits for all tokens.
-            logits = torch.matmul(hidden_states, weight.T)
-            if self.tp_size > 1:
-                logits = tensor_model_parallel_all_gather(logits)
-            logits = logits[:, : self.config.vocab_size]
-            all_logprobs = torch.log(torch.softmax(logits.float(), dim=-1) + 1e-6)
+            if input_metadata.forward_mode == ForwardMode.DECODE:
+                all_logits = last_logits
+            else:
+                all_logits = torch.matmul(hidden_states, weight.T)
+                if self.tp_size > 1:
+                    all_logits = tensor_model_parallel_all_gather(all_logits)
+                all_logits = all_logits[:, : self.config.vocab_size]
+
+            all_logprobs = all_logits.float()
+            del all_logits
+            all_logprobs[:] = torch.nn.functional.log_softmax(all_logprobs, dim=-1)
+
+            return_top_logprob = any(x > 0 for x in input_metadata.top_logprobs_nums)
+            if return_top_logprob:
+                prefill_top_logprobs, decode_top_logprobs = self._get_top_logprobs(
+                    all_logprobs, input_metadata
+                )
+            else:
+                prefill_top_logprobs = decode_top_logprobs = None

             if input_metadata.forward_mode == ForwardMode.DECODE:
-                last_logits = logits
                 last_logprobs = all_logprobs
-                prefill_logprobs = normalized_logprobs = None
+                return last_logits, (
+                    None,
+                    None,
+                    None,
+                    decode_top_logprobs,
+                    last_logprobs,
+                )
             else:
                 # Compute the logprobs for the last token of each request.
-                last_logits = logits[last_index]
                 last_logprobs = all_logprobs[last_index]

                 # Compute the logprobs and normalized logprobs for the prefill tokens.
                 # Note that we pad a zero at the end of each sequence for easy computation.
-                prefill_logprobs = all_logprobs[
+                prefill_token_logprobs = all_logprobs[
                     torch.arange(all_logprobs.shape[0], device="cuda"),
                     torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
                 ]
-                logprobs_cumsum = torch.cumsum(
-                    prefill_logprobs, dim=0, dtype=torch.float32
-                )

-                start = input_metadata.extend_start_loc.clone()
-                end = start + input_metadata.extend_seq_lens - 2
-                start.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
-                end.clamp_(min=0, max=prefill_logprobs.shape[0] - 1)
-                sum_logp = (
-                    logprobs_cumsum[end]
-                    - logprobs_cumsum[start]
-                    + prefill_logprobs[start]
+                normalized_prompt_logprobs = self._get_normalized_prompt_logprobs(
+                    prefill_token_logprobs, input_metadata
                 )
-                normalized_logprobs = sum_logp / (
-                    (input_metadata.extend_seq_lens - 1).clamp(min=1)
+                return last_logits, (
+                    prefill_token_logprobs,
+                    normalized_prompt_logprobs,
+                    prefill_top_logprobs,
+                    decode_top_logprobs,
+                    last_logprobs,
                 )

-        return last_logits, (prefill_logprobs, normalized_logprobs, last_logprobs)
-

 if __name__ == "__main__":
     all_logprobs = torch.tensor(
@@ -93,23 +154,22 @@ if __name__ == "__main__":
     )
     seq_lens = torch.tensor([2, 0, 3, 0], dtype=torch.int32, device="cuda")
     input_ids = torch.tensor([1, 2, 3, 0, 1], dtype=torch.int32, device="cuda")
-    logprobs = torch.zeros(5, dtype=torch.float32, device="cuda")

-    logprobs = all_logprobs[
+    token_logprobs = all_logprobs[
         torch.arange(all_logprobs.shape[0], device="cuda"),
         torch.cat([input_ids[1:], torch.tensor([0], device="cuda")]),
     ]
-    logprobs_cumsum = torch.cumsum(logprobs, dim=0, dtype=torch.float32)
+    logprobs_cumsum = torch.cumsum(token_logprobs, dim=0, dtype=torch.float32)

     len_cumsum = torch.cumsum(seq_lens, dim=0)
     start = torch.cat((torch.tensor([0], device="cuda"), len_cumsum[:-1]), 0)
     end = start + seq_lens - 2
-    start.clamp_(min=0, max=logprobs.shape[0] - 1)
-    end.clamp_(min=0, max=logprobs.shape[0] - 1)
-    sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + logprobs[start]
+    start.clamp_(min=0, max=token_logprobs.shape[0] - 1)
+    end.clamp_(min=0, max=token_logprobs.shape[0] - 1)
+    sum_logp = logprobs_cumsum[end] - logprobs_cumsum[start] + token_logprobs[start]

     # assert logprobs == [2, _, 2, 4, _]
-    print("logprobs", logprobs)
+    print("token logprobs", token_logprobs)
     print("start", start)
     print("end", end)
     print("sum_logp", sum_logp)
sglang/srt/layers/radix_attention.py

@@ -1,9 +1,10 @@
 import torch
+from torch import nn
+
 from sglang.srt.layers.context_flashattention_nopad import context_attention_fwd
 from sglang.srt.layers.extend_attention import extend_attention_fwd
 from sglang.srt.layers.token_attention import token_attention_fwd
 from sglang.srt.managers.router.model_runner import ForwardMode, InputMetadata
-from torch import nn


 class RadixAttention(nn.Module):
sglang/srt/layers/token_attention.py

@@ -4,6 +4,7 @@
 import torch
 import triton
 import triton.language as tl
+
 from sglang.srt.managers.router.model_runner import global_server_args_dict
 from sglang.srt.utils import wrap_kernel_launcher

sglang/srt/managers/detokenizer_manager.py

@@ -3,6 +3,7 @@ import asyncio
 import uvloop
 import zmq
 import zmq.asyncio
+
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
 from sglang.srt.server_args import PortArgs, ServerArgs
@@ -37,10 +38,13 @@ class DetokenizerManager:
             if isinstance(recv_obj, BatchTokenIDOut):
                 output_tokens = recv_obj.output_tokens

-                # TODO(lmzheng): handle skip_special_tokens per request
+                # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
                 output_strs = self.tokenizer.batch_decode(
                     output_tokens,
                     skip_special_tokens=recv_obj.skip_special_tokens[0],
+                    spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[
+                        0
+                    ],
                 )

                 # Trim stop str
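
Note on the detokenizer_manager.py change: spaces_between_special_tokens is an existing decode option of Hugging Face (slow) tokenizers that the detokenizer now forwards from the batch output. A rough, hedged illustration of what the flag controls; the checkpoint name is only a placeholder and the exact output depends on the tokenizer.

from transformers import AutoTokenizer

# Placeholder tokenizer; any checkpoint with special tokens behaves similarly.
tok = AutoTokenizer.from_pretrained("gpt2", use_fast=False)
ids = tok("Hello world").input_ids + [tok.eos_token_id, tok.eos_token_id]

# skip_special_tokens drops special tokens entirely; spaces_between_special_tokens
# only controls whether kept special tokens are joined with spaces.
print(tok.decode(ids, skip_special_tokens=False, spaces_between_special_tokens=True))
print(tok.decode(ids, skip_special_tokens=False, spaces_between_special_tokens=False))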
sglang/srt/managers/io_struct.py

@@ -19,10 +19,13 @@ class GenerateReqInput:
     return_logprob: Optional[Union[List[bool], bool]] = None
     # The start location of the prompt for return_logprob
     logprob_start_len: Optional[Union[List[int], int]] = None
+    # The number of top logprobs to return
+    top_logprobs_num: Optional[Union[List[int], int]] = None
     # Whether to detokenize tokens in logprobs
     return_text_in_logprobs: bool = False
     # Whether to stream output
     stream: bool = False
+    # TODO: make all parameters a Union[List[T], T] to allow for batched requests

     def post_init(self):
         is_single = isinstance(self.text, str)
@@ -36,6 +39,8 @@ class GenerateReqInput:
                 self.return_logprob = False
             if self.logprob_start_len is None:
                 self.logprob_start_len = 0
+            if self.top_logprobs_num is None:
+                self.top_logprobs_num = 0
         else:
             num = len(self.text)
@@ -64,6 +69,11 @@ class GenerateReqInput:
             elif not isinstance(self.logprob_start_len, list):
                 self.logprob_start_len = [self.logprob_start_len] * num

+            if self.top_logprobs_num is None:
+                self.top_logprobs_num = [0] * num
+            elif not isinstance(self.top_logprobs_num, list):
+                self.top_logprobs_num = [self.top_logprobs_num] * num
+

 @dataclass
 class TokenizedGenerateReqInput:
@@ -76,6 +86,7 @@ class TokenizedGenerateReqInput:
     sampling_params: SamplingParams
     return_logprob: bool
     logprob_start_len: int
+    top_logprobs_num: int
     stream: bool


@@ -86,6 +97,7 @@ class BatchTokenIDOut:
     output_and_jump_forward_strs: List[str]
     hit_stop_str: List[Optional[str]]
     skip_special_tokens: List[bool]
+    spaces_between_special_tokens: List[bool]
     meta_info: List[Dict]
     finished: List[bool]

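Note on the io_struct.py change: GenerateReqInput gains a top_logprobs_num field, and post_init broadcasts a scalar value to a per-request list for batched inputs. A hypothetical request against a locally running SRT server is sketched below; the field names mirror GenerateReqInput, while the URL, prompt, and sampling values are placeholders.

import requests

response = requests.post(
    "http://127.0.0.1:30000/generate",  # placeholder address of a local server
    json={
        "text": "The capital of France is",
        "sampling_params": {"max_new_tokens": 8, "temperature": 0},
        "return_logprob": True,
        "logprob_start_len": 0,
        "top_logprobs_num": 5,  # new field in this release
    },
)
print(response.json())
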
sglang/srt/managers/router/infer_batch.py

@@ -1,22 +1,23 @@
 from dataclasses import dataclass
-from enum import Enum, auto
+from enum import IntEnum, auto
 from typing import List

 import numpy as np
 import torch
+
 from sglang.srt.managers.router.radix_cache import RadixCache
 from sglang.srt.memory_pool import ReqToTokenPool, TokenToKVPool


-class ForwardMode(Enum):
+class ForwardMode(IntEnum):
     PREFILL = auto()
     EXTEND = auto()
     DECODE = auto()


-class FinishReason(Enum):
-    LENGTH = auto()
+class FinishReason(IntEnum):
     EOS_TOKEN = auto()
+    LENGTH = auto()
     STOP_STR = auto()


@@ -30,6 +31,7 @@ class Req:
         # Since jump forward may retokenize the prompt with partial outputs,
         # we maintain the original prompt length to report the correct usage.
         self.prompt_tokens = len(input_ids)
+
         # The number of decoded tokens for token usage report. Note that
         # this does not include the jump forward tokens.
         self.completion_tokens_wo_jump_forward = 0
@@ -40,11 +42,11 @@ class Req:
         self.image_offset = 0
         self.pad_value = None

+        # Sampling parameters
         self.sampling_params = None
-        self.return_logprob = False
-        self.logprob_start_len = 0
         self.stream = False

+        # Check finish
         self.tokenizer = None
         self.finished = False
         self.finish_reason = None
@@ -54,11 +56,17 @@ class Req:
         self.prefix_indices = []
         self.last_node = None

-        self.logprob = None
-        self.token_logprob = None
-        self.normalized_logprob = None
-
-        # For constrained decoding
+        # Logprobs
+        self.return_logprob = False
+        self.logprob_start_len = 0
+        self.top_logprobs_num = 0
+        self.normalized_prompt_logprob = None
+        self.prefill_token_logprobs = None
+        self.decode_token_logprobs = None
+        self.prefill_top_logprobs = None
+        self.decode_top_logprobs = None
+
+        # Constrained decoding
         self.regex_fsm = None
         self.regex_fsm_state = 0
         self.jump_forward_map = None
@@ -159,7 +167,10 @@ class Batch:
     out_cache_loc: torch.Tensor = None
     out_cache_cont_start: torch.Tensor = None
     out_cache_cont_end: torch.Tensor = None
+
+    # for processing logprobs
     return_logprob: bool = False
+    top_logprobs_nums: List[int] = None

     # for multimodal
     pixel_values: List[torch.Tensor] = None
@@ -229,12 +240,11 @@ class Batch:
         extend_num_tokens = seq_lens.sum() - prefix_lens.sum()
         out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
         if out_cache_loc is None:
-            if not self.tree_cache.disable:
-                self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.free)
-                out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)
+            self.tree_cache.evict(extend_num_tokens, self.token_to_kv_pool.dec_refs)
+            out_cache_loc = self.token_to_kv_pool.alloc(extend_num_tokens)

             if out_cache_loc is None:
-                print("Prefill out of memory. This should nerver happen.")
+                print("Prefill out of memory. This should never happen.")
                 self.tree_cache.pretty_print()
                 exit()

@@ -245,10 +255,14 @@ class Batch:
             ] = out_cache_loc[pt : pt + extend_lens[i]]
             pt += extend_lens[i]

-        # Handle logit bias
-        logit_bias = torch.zeros((bs, vocab_size), dtype=torch.float32, device=device)
+        # Handle logit bias but only allocate when needed
+        logit_bias = None
         for i in range(bs):
             if reqs[i].sampling_params.dtype == "int":
+                if logit_bias is None:
+                    logit_bias = torch.zeros(
+                        (bs, vocab_size), dtype=torch.float32, device=device
+                    )
                 logit_bias[i] = int_token_logit_bias

         # Set fields
@@ -266,6 +280,7 @@ class Batch:
         self.position_ids_offsets = position_ids_offsets
         self.extend_num_tokens = extend_num_tokens
         self.out_cache_loc = out_cache_loc
+        self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]

         self.temperatures = torch.tensor(
             [r.sampling_params.temperature for r in reqs],
@@ -295,8 +310,8 @@ class Batch:
         if self.token_to_kv_pool.available_size() >= bs:
             return True

-        if not self.tree_cache.disable:
-            self.tree_cache.evict(bs, self.token_to_kv_pool.free)
+        self.tree_cache.evict(bs, self.token_to_kv_pool.dec_refs)
+
         if self.token_to_kv_pool.available_size() >= bs:
             return True

@@ -310,8 +325,8 @@ class Batch:
         )

         retracted_reqs = []
-        seq_lens_np = self.seq_lens.cpu().numpy()
-        req_pool_indices_np = self.req_pool_indices.cpu().numpy()
+        seq_lens_cpu = self.seq_lens.cpu().numpy()
+        req_pool_indices_cpu = self.req_pool_indices.cpu().numpy()
         while self.token_to_kv_pool.available_size() < len(self.reqs):
             idx = sorted_indices.pop()
             req = self.reqs[idx]
@@ -327,9 +342,9 @@ class Batch:
             # TODO: apply more fine-grained retraction

             token_indices = self.req_to_token_pool.req_to_token[
-                req_pool_indices_np[idx]
-            ][: seq_lens_np[idx]]
-            self.token_to_kv_pool.free(token_indices)
+                req_pool_indices_cpu[idx]
+            ][: seq_lens_cpu[idx]]
+            self.token_to_kv_pool.dec_refs(token_indices)

         self.filter_batch(sorted_indices)

@@ -352,7 +367,7 @@ class Batch:
                 # insert the old request into tree_cache
                 token_ids_in_memory = tuple(req.input_ids + req.output_ids)[:-1]
                 if req_pool_indices_cpu is None:
-                    req_pool_indices_cpu = self.req_pool_indices.cpu().tolist()
+                    req_pool_indices_cpu = self.req_pool_indices.tolist()
                 req_pool_idx = req_pool_indices_cpu[i]
                 indices = self.req_to_token_pool.req_to_token[
                     req_pool_idx, : len(token_ids_in_memory)
@@ -360,7 +375,7 @@ class Batch:
                 prefix_len = self.tree_cache.insert(
                     token_ids_in_memory, indices.clone()
                 )
-                self.token_to_kv_pool.free(indices[:prefix_len])
+                self.token_to_kv_pool.dec_refs(indices[:prefix_len])
                 self.req_to_token_pool.free(req_pool_idx)
                 self.tree_cache.dec_ref_counter(req.last_node)

@@ -391,7 +406,7 @@ class Batch:
         self.out_cache_loc = self.token_to_kv_pool.alloc(bs)

         if self.out_cache_loc is None:
-            print("Decode out of memory. This should nerver happen.")
+            print("Decode out of memory. This should never happen.")
             self.tree_cache.pretty_print()
             exit()

@@ -415,6 +430,7 @@ class Batch:
         self.prefix_lens = None
         self.position_ids_offsets = self.position_ids_offsets[new_indices]
         self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
+        self.top_logprobs_nums = [self.top_logprobs_nums[i] for i in unfinished_indices]
         self.return_logprob = any(req.return_logprob for req in self.reqs)

         for item in [
@@ -425,9 +441,12 @@ class Batch:
             "presence_penalties",
             "logit_bias",
         ]:
-            setattr(self, item, getattr(self, item)[new_indices])
+            self_val = getattr(self, item, None)
+            # logit_bias can be None
+            if self_val is not None:
+                setattr(self, item, self_val[new_indices])

-    def merge(self, other):
+    def merge(self, other: "Batch"):
         self.reqs.extend(other.reqs)

         self.req_pool_indices = torch.concat(
@@ -439,6 +458,7 @@ class Batch:
             [self.position_ids_offsets, other.position_ids_offsets]
         )
         self.out_cache_loc = self.out_cache_cont_start = self.out_cache_cont_end = None
+        self.top_logprobs_nums.extend(other.top_logprobs_nums)
         self.return_logprob = any(req.return_logprob for req in self.reqs)

         for item in [
@@ -447,17 +467,34 @@ class Batch:
             "top_ks",
             "frequency_penalties",
             "presence_penalties",
-            "logit_bias",
         ]:
-            setattr(
-                self, item, torch.concat([getattr(self, item), getattr(other, item)])
+            self_val = getattr(self, item, None)
+            other_val = getattr(other, item, None)
+            setattr(self, item, torch.concat([self_val, other_val]))
+
+        # logit_bias can be None
+        if self.logit_bias is not None or other.logit_bias is not None:
+            vocab_size = (
+                self.logit_bias.shape[1]
+                if self.logit_bias is not None
+                else other.logit_bias.shape[1]
             )
+            if self.logit_bias is None:
+                self.logit_bias = torch.zeros(
+                    (len(self.reqs), vocab_size), dtype=torch.float32, device="cuda"
+                )
+            if other.logit_bias is None:
+                other.logit_bias = torch.zeros(
+                    (len(other.reqs), vocab_size), dtype=torch.float32, device="cuda"
+                )
+            self.logit_bias = torch.concat([self.logit_bias, other.logit_bias])

     def sample(self, logits: torch.Tensor):
         # Post process logits
         logits = logits.contiguous()
         logits.div_(self.temperatures)
-        logits.add_(self.logit_bias)
+        if self.logit_bias is not None:
+            logits.add_(self.logit_bias)

         has_regex = any(req.regex_fsm is not None for req in self.reqs)
         if has_regex:
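
Note on the infer_batch.py change: logit_bias is now allocated lazily, so filter_batch and merge have to tolerate a None tensor on either side and only materialize a dense (batch_size, vocab_size) tensor when a request actually uses an integer-constrained sampling dtype. A minimal standalone sketch of the merge pattern follows; the function name and arguments are illustrative, not the Batch API.

import torch

def merge_logit_bias(a, b, a_bs, b_bs, vocab_size, device="cpu"):
    # Only materialize a dense bias tensor when at least one side has one.
    if a is None and b is None:
        return None
    if a is None:
        a = torch.zeros((a_bs, vocab_size), dtype=torch.float32, device=device)
    if b is None:
        b = torch.zeros((b_bs, vocab_size), dtype=torch.float32, device=device)
    return torch.concat([a, b])

# One batch carries a bias, the other does not.
merged = merge_logit_bias(torch.ones(2, 4), None, 2, 3, 4)
print(merged.shape)  # torch.Size([5, 4])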
sglang/srt/managers/router/manager.py

@@ -4,6 +4,7 @@ import logging
 import uvloop
 import zmq
 import zmq.asyncio
+
 from sglang.srt.backend_config import GLOBAL_BACKEND_CONFIG
 from sglang.srt.managers.router.model_rpc import ModelRpcClient
 from sglang.srt.server_args import PortArgs, ServerArgs
@@ -41,12 +42,16 @@ class RouterManager:
                 self.send_to_detokenizer.send_pyobj(obj)

             # async sleep for receiving the subsequent request and avoiding cache miss
+            slept = False
             if len(out_pyobjs) != 0:
                 has_finished = any([obj.finished for obj in out_pyobjs])
                 if has_finished:
-                    await asyncio.sleep(self.extend_dependency_time)
+                    if self.extend_dependency_time > 0:
+                        slept = True
+                        await asyncio.sleep(self.extend_dependency_time)

-            await asyncio.sleep(0.0006)
+            if not slept:
+                await asyncio.sleep(0.0006)

     async def loop_for_recv_requests(self):
         while True:
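
Note on the manager.py change: the forward loop now sleeps at most once per iteration, using the longer "extend dependency" delay only when a request finished and the delay is positive, and the short 0.6 ms sleep otherwise. A condensed sketch of that control flow; the function name and arguments are illustrative only.

import asyncio

async def sleep_after_step(out_pyobjs, extend_dependency_time):
    # Sleep the long delay only when something finished and the delay is positive;
    # otherwise fall back to a single short sleep instead of stacking both.
    slept = False
    if out_pyobjs:
        if any(obj.finished for obj in out_pyobjs) and extend_dependency_time > 0:
            slept = True
            await asyncio.sleep(extend_dependency_time)
    if not slept:
        await asyncio.sleep(0.0006)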