sglang 0.1.15__py3-none-any.whl → 0.1.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Files changed (69)
  1. sglang/__init__.py +5 -1
  2. sglang/api.py +8 -3
  3. sglang/backend/anthropic.py +1 -1
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +148 -12
  6. sglang/backend/runtime_endpoint.py +18 -10
  7. sglang/global_config.py +11 -1
  8. sglang/lang/chat_template.py +9 -2
  9. sglang/lang/interpreter.py +161 -81
  10. sglang/lang/ir.py +29 -11
  11. sglang/lang/tracer.py +1 -1
  12. sglang/launch_server.py +1 -2
  13. sglang/launch_server_llavavid.py +31 -0
  14. sglang/srt/constrained/fsm_cache.py +3 -0
  15. sglang/srt/flush_cache.py +16 -0
  16. sglang/srt/hf_transformers_utils.py +83 -2
  17. sglang/srt/layers/extend_attention.py +17 -0
  18. sglang/srt/layers/fused_moe.py +485 -0
  19. sglang/srt/layers/logits_processor.py +12 -7
  20. sglang/srt/layers/radix_attention.py +10 -3
  21. sglang/srt/layers/token_attention.py +16 -1
  22. sglang/srt/managers/controller/dp_worker.py +110 -0
  23. sglang/srt/managers/controller/infer_batch.py +619 -0
  24. sglang/srt/managers/controller/manager_multi.py +191 -0
  25. sglang/srt/managers/controller/manager_single.py +97 -0
  26. sglang/srt/managers/controller/model_runner.py +462 -0
  27. sglang/srt/managers/controller/radix_cache.py +267 -0
  28. sglang/srt/managers/controller/schedule_heuristic.py +59 -0
  29. sglang/srt/managers/controller/tp_worker.py +791 -0
  30. sglang/srt/managers/detokenizer_manager.py +45 -45
  31. sglang/srt/managers/io_struct.py +26 -10
  32. sglang/srt/managers/router/infer_batch.py +130 -74
  33. sglang/srt/managers/router/manager.py +7 -9
  34. sglang/srt/managers/router/model_rpc.py +224 -135
  35. sglang/srt/managers/router/model_runner.py +94 -107
  36. sglang/srt/managers/router/radix_cache.py +54 -18
  37. sglang/srt/managers/router/scheduler.py +23 -34
  38. sglang/srt/managers/tokenizer_manager.py +183 -88
  39. sglang/srt/model_config.py +5 -2
  40. sglang/srt/models/commandr.py +15 -22
  41. sglang/srt/models/dbrx.py +22 -29
  42. sglang/srt/models/gemma.py +14 -24
  43. sglang/srt/models/grok.py +671 -0
  44. sglang/srt/models/llama2.py +24 -23
  45. sglang/srt/models/llava.py +85 -25
  46. sglang/srt/models/llavavid.py +298 -0
  47. sglang/srt/models/mixtral.py +254 -130
  48. sglang/srt/models/mixtral_quant.py +373 -0
  49. sglang/srt/models/qwen.py +28 -25
  50. sglang/srt/models/qwen2.py +17 -22
  51. sglang/srt/models/stablelm.py +21 -26
  52. sglang/srt/models/yivl.py +17 -25
  53. sglang/srt/openai_api_adapter.py +140 -95
  54. sglang/srt/openai_protocol.py +10 -1
  55. sglang/srt/server.py +101 -52
  56. sglang/srt/server_args.py +59 -11
  57. sglang/srt/utils.py +242 -75
  58. sglang/test/test_programs.py +44 -0
  59. sglang/test/test_utils.py +32 -1
  60. sglang/utils.py +95 -26
  61. {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/METADATA +23 -13
  62. sglang-0.1.17.dist-info/RECORD +81 -0
  63. sglang/srt/backend_config.py +0 -13
  64. sglang/srt/models/dbrx_config.py +0 -281
  65. sglang/srt/weight_utils.py +0 -402
  66. sglang-0.1.15.dist-info/RECORD +0 -69
  67. {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/LICENSE +0 -0
  68. {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/WHEEL +0 -0
  69. {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/top_level.txt +0 -0
sglang/srt/managers/detokenizer_manager.py
@@ -1,4 +1,5 @@
 import asyncio
+import inspect
 
 import uvloop
 import zmq
@@ -7,7 +8,8 @@ import zmq.asyncio
 from sglang.srt.hf_transformers_utils import get_tokenizer
 from sglang.srt.managers.io_struct import BatchStrOut, BatchTokenIDOut
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import get_exception_traceback
+from sglang.utils import get_exception_traceback, graceful_registry
+from sglang.srt.managers.controller.infer_batch import FINISH_MATCHED_STR
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
@@ -33,51 +35,47 @@ class DetokenizerManager:
 
     async def handle_loop(self):
         while True:
-            recv_obj = await self.recv_from_router.recv_pyobj()
-
-            if isinstance(recv_obj, BatchTokenIDOut):
-                output_tokens = recv_obj.output_tokens
-
-                # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
-                output_strs = self.tokenizer.batch_decode(
-                    output_tokens,
-                    skip_special_tokens=recv_obj.skip_special_tokens[0],
-                    spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[
-                        0
-                    ],
-                )
-
-                # Trim stop str
-                # TODO(lmzheng): handle the case where multiple stop strs are hit
-                for i in range(len(output_strs)):
-                    if recv_obj.hit_stop_str[i] is not None:
-                        pos = output_strs[i].find(recv_obj.hit_stop_str[i])
-                        if pos != -1:
-                            output_strs[i] = output_strs[i][:pos]
-
-                    if len(output_tokens[i]) > 0:
-                        first_token = self.tokenizer.convert_ids_to_tokens(
-                            int(output_tokens[i][0])
-                        )
-                        if not isinstance(first_token, str):
-                            first_token = first_token.decode("utf-8", errors="ignore")
-                        if first_token.startswith("▁"):
-                            output_strs[i] = " " + output_strs[i]
-
-                    output_strs[i] = (
-                        recv_obj.output_and_jump_forward_strs[i] + output_strs[i]
-                    )
-
-                self.send_to_tokenizer.send_pyobj(
-                    BatchStrOut(
-                        recv_obj.rids,
-                        output_strs,
-                        recv_obj.meta_info,
-                        recv_obj.finished,
+            recv_obj: BatchTokenIDOut = await self.recv_from_router.recv_pyobj()
+            assert isinstance(recv_obj, BatchTokenIDOut)
+
+            output_tokens = recv_obj.output_tokens
+
+            # TODO(lmzheng): handle skip_special_tokens/spaces_between_special_tokens per request
+            output_strs = self.tokenizer.batch_decode(
+                output_tokens,
+                skip_special_tokens=recv_obj.skip_special_tokens[0],
+                spaces_between_special_tokens=recv_obj.spaces_between_special_tokens[
+                    0
+                ],
+            )
+
+            # Trim stop str
+            # TODO(lmzheng): handle the case where multiple stop strs are hit
+            for i in range(len(output_strs)):
+                if len(output_tokens[i]) > 0:
+                    first_token = self.tokenizer.convert_ids_to_tokens(
+                        int(output_tokens[i][0])
                     )
+                    if not isinstance(first_token, str):
+                        first_token = first_token.decode("utf-8", errors="ignore")
+                    if first_token.startswith("▁"):
+                        output_strs[i] = " " + output_strs[i]
+
+                output_strs[i] = recv_obj.prev_output_strs[i] + output_strs[i]
+
+                if isinstance(recv_obj.finished_reason[i], FINISH_MATCHED_STR):
+                    pos = output_strs[i].find(recv_obj.finished_reason[i].matched)
+                    if pos != -1:
+                        output_strs[i] = output_strs[i][:pos]
+
+            self.send_to_tokenizer.send_pyobj(
+                BatchStrOut(
+                    rids=recv_obj.rids,
+                    output_str=output_strs,
+                    meta_info=recv_obj.meta_info,
+                    finished_reason=recv_obj.finished_reason,
                 )
-            else:
-                raise ValueError(f"Invalid object: {recv_obj}")
+            )
 
 
 def start_detokenizer_process(
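
The "▁" handling above compensates for SentencePiece-style tokenizers, which mark a word-initial space on the token itself; decoding a span in isolation tends to drop that space, so the detokenizer re-adds it before prepending prev_output_strs. A minimal sketch of the effect, assuming the transformers package and the public test tokenizer hf-internal-testing/llama-tokenizer (any SentencePiece model behaves similarly; not part of this diff):

    from transformers import AutoTokenizer

    # Illustrative tokenizer choice, not taken from the diff.
    tok = AutoTokenizer.from_pretrained("hf-internal-testing/llama-tokenizer")

    ids = tok.encode("a b", add_special_tokens=False)[1:]  # the ids for " b"
    first = tok.convert_ids_to_tokens(ids[0])              # "▁b"
    text = tok.decode(ids)                                 # typically "b": leading space dropped
    if first.startswith("▁"):
        text = " " + text                                  # restore it, as handle_loop does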
@@ -85,9 +83,11 @@ def start_detokenizer_process(
     port_args: PortArgs,
     pipe_writer,
 ):
+    graceful_registry(inspect.currentframe().f_code.co_name)
+
     try:
         manager = DetokenizerManager(server_args, port_args)
-    except Exception as e:
+    except Exception:
         pipe_writer.send(get_exception_traceback())
         raise
     pipe_writer.send("init ok")
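
graceful_registry itself is not shown in this diff; judging from the call site, it registers the calling subprocess's name with a shutdown signal handler. A plausible sketch of such a helper (an assumption, not the actual sglang.utils implementation):

    import logging
    import signal

    def graceful_registry(sub_module_name: str) -> None:
        # Hypothetical sketch: log when the process is asked to stop via SIGTERM.
        def _graceful_shutdown(signum, frame):
            logging.info(
                f"{sub_module_name} received signal {signum}, shutting down gracefully..."
            )

        signal.signal(signal.SIGTERM, _graceful_shutdown)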
sglang/srt/managers/io_struct.py
@@ -3,12 +3,15 @@ from dataclasses import dataclass
 from typing import Dict, List, Optional, Union
 
 from sglang.srt.sampling_params import SamplingParams
+from sglang.srt.managers.controller.infer_batch import BaseFinishReason
 
 
 @dataclass
 class GenerateReqInput:
     # The input prompt
-    text: Union[List[str], str]
+    text: Optional[Union[List[str], str]] = None
+    # The token ids for text; one can either specify text or input_ids
+    input_ids: Optional[Union[List[List[int]], List[int]]] = None
     # The image input
     image_data: Optional[Union[List[str], str]] = None
     # The sampling_params
@@ -25,10 +28,19 @@ class GenerateReqInput:
     return_text_in_logprobs: bool = False
     # Whether to stream output
     stream: bool = False
-    # TODO: make all parameters a Union[List[T], T] to allow for batched requests
 
     def post_init(self):
-        is_single = isinstance(self.text, str)
+
+        if (self.text is None and self.input_ids is None) or (
+            self.text is not None and self.input_ids is not None
+        ):
+            raise ValueError("Either text or input_ids should be provided.")
+
+        if self.text is not None:
+            is_single = isinstance(self.text, str)
+        else:
+            is_single = isinstance(self.input_ids[0], int)
+        self.is_single = is_single
 
         if is_single:
             if self.sampling_params is None:
@@ -42,7 +54,7 @@ class GenerateReqInput:
             if self.top_logprobs_num is None:
                 self.top_logprobs_num = 0
         else:
-            num = len(self.text)
+            num = len(self.text) if self.text is not None else len(self.input_ids)
 
             if self.image_data is None:
                 self.image_data = [None] * num
@@ -57,7 +69,8 @@ class GenerateReqInput:
             if self.rid is None:
                 self.rid = [uuid.uuid4().hex for _ in range(num)]
             else:
-                assert isinstance(self.rid, list)
+                if not isinstance(self.rid, list):
+                    raise ValueError("The rid should be a list.")
 
             if self.return_logprob is None:
                 self.return_logprob = [False] * num
@@ -93,21 +106,19 @@ class TokenizedGenerateReqInput:
 @dataclass
 class BatchTokenIDOut:
     rids: List[str]
+    prev_output_strs: List[str]
     output_tokens: List[List[int]]
-    output_and_jump_forward_strs: List[str]
-    hit_stop_str: List[Optional[str]]
     skip_special_tokens: List[bool]
     spaces_between_special_tokens: List[bool]
     meta_info: List[Dict]
-    finished: List[bool]
-
+    finished_reason: List[BaseFinishReason]
 
 @dataclass
 class BatchStrOut:
     rids: List[str]
     output_str: List[str]
     meta_info: List[Dict]
-    finished: List[bool]
+    finished_reason: List[BaseFinishReason]
 
 
 @dataclass
@@ -115,6 +126,11 @@ class FlushCacheReq:
     pass
 
 
+@dataclass
+class AbortReq:
+    rid: str
+
+
 @dataclass
 class DetokenizeReqInput:
     input_ids: List[int]
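
With these changes, GenerateReqInput accepts exactly one of text or input_ids, and post_init infers whether the request is a single prompt or a batch. A usage sketch built only from the fields visible in this diff (the token-id values are made up):

    from sglang.srt.managers.io_struct import GenerateReqInput

    # A single plain-text request.
    req = GenerateReqInput(text="The capital of France is")
    req.post_init()
    assert req.is_single

    # A batched, pre-tokenized request: a list of token-id lists.
    req = GenerateReqInput(input_ids=[[1, 450, 7483], [1, 887, 526]])
    req.post_init()
    assert not req.is_single

    # Passing both text and input_ids (or neither) raises ValueError.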
sglang/srt/managers/router/infer_batch.py
@@ -19,18 +19,32 @@ class FinishReason(IntEnum):
     EOS_TOKEN = auto()
     LENGTH = auto()
     STOP_STR = auto()
+    ABORT = auto()
+
+    @staticmethod
+    def to_str(reason):
+        if reason == FinishReason.EOS_TOKEN:
+            return None
+        elif reason == FinishReason.LENGTH:
+            return "length"
+        elif reason == FinishReason.STOP_STR:
+            return "stop"
+        elif reason == FinishReason.ABORT:
+            return "abort"
+        else:
+            return None
 
 
 class Req:
-    def __init__(self, rid, input_text, input_ids):
+    def __init__(self, rid, origin_input_text, origin_input_ids):
         self.rid = rid
-        self.input_text = input_text
-        self.input_ids = input_ids
+        self.origin_input_text = origin_input_text
+        self.origin_input_ids = origin_input_ids
+        self.origin_input_ids_unpadded = origin_input_ids  # before image padding
+        self.prev_output_str = ""
+        self.prev_output_ids = []
         self.output_ids = []
-
-        # Since jump forward may retokenize the prompt with partial outputs,
-        # we maintain the original prompt length to report the correct usage.
-        self.prompt_tokens = len(input_ids)
+        self.input_ids = None  # input_ids = origin_input_ids + prev_output_ids
 
         # The number of decoded tokens for token usage report. Note that
         # this does not include the jump forward tokens.
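
The new FinishReason.to_str gives the mapping used when a finish reason is reported as an OpenAI-style string; note that EOS_TOKEN deliberately maps to None. Directly from the code above:

    assert FinishReason.to_str(FinishReason.LENGTH) == "length"
    assert FinishReason.to_str(FinishReason.STOP_STR) == "stop"
    assert FinishReason.to_str(FinishReason.ABORT) == "abort"
    assert FinishReason.to_str(FinishReason.EOS_TOKEN) is None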
@@ -52,6 +66,7 @@ class Req:
         self.finish_reason = None
         self.hit_stop_str = None
 
+        # Prefix info
         self.extend_input_len = 0
         self.prefix_indices = []
         self.last_node = None
@@ -62,67 +77,36 @@ class Req:
         self.top_logprobs_num = 0
         self.normalized_prompt_logprob = None
         self.prefill_token_logprobs = None
-        self.decode_token_logprobs = None
         self.prefill_top_logprobs = None
-        self.decode_top_logprobs = None
+        self.decode_token_logprobs = []
+        self.decode_top_logprobs = []
+        # The tokens is prefilled but need to be considered as decode tokens
+        # and should be updated for the decode logprobs
+        self.last_update_decode_tokens = 0
 
         # Constrained decoding
         self.regex_fsm = None
         self.regex_fsm_state = 0
         self.jump_forward_map = None
-        self.output_and_jump_forward_str = ""
-
-    def max_new_tokens(self):
-        return self.sampling_params.max_new_tokens
 
-    def jump_forward_and_retokenize(self, jump_forward_str, next_state):
-        old_output_str = self.tokenizer.decode(self.output_ids)
-        # FIXME: This logic does not really solve the problem of determining whether
-        # there should be a leading space.
-        first_token = self.tokenizer.convert_ids_to_tokens(self.output_ids[0])
+    def partial_decode(self, ids):
+        first_token = self.tokenizer.convert_ids_to_tokens(ids[0])
         first_token = (
             first_token.decode() if isinstance(first_token, bytes) else first_token
         )
-        if first_token.startswith("▁"):
-            old_output_str = " " + old_output_str
-        new_input_string = (
-            self.input_text
-            + self.output_and_jump_forward_str
-            + old_output_str
-            + jump_forward_str
-        )
-        new_input_ids = self.tokenizer.encode(new_input_string)
-        if self.pixel_values is not None:
-            # NOTE: This is a hack because the old input_ids contains the image padding
-            jump_forward_tokens_len = len(self.tokenizer.encode(jump_forward_str))
-        else:
-            jump_forward_tokens_len = (
-                len(new_input_ids) - len(self.input_ids) - len(self.output_ids)
-            )
-
-        # print("=" * 100)
-        # print(f"Catch jump forward:\n{jump_forward_str}")
-        # print(self.tokenizer.convert_ids_to_tokens(self.input_ids))
-        # print(self.tokenizer.convert_ids_to_tokens(new_input_ids))
+        return (" " if first_token.startswith("▁") else "") + self.tokenizer.decode(ids)
 
-        self.input_ids = new_input_ids
-        self.output_ids = []
-        self.sampling_params.max_new_tokens = max(
-            self.sampling_params.max_new_tokens - jump_forward_tokens_len, 0
-        )
-        self.regex_fsm_state = next_state
-        self.output_and_jump_forward_str = (
-            self.output_and_jump_forward_str + old_output_str + jump_forward_str
-        )
-
-        # print(f"Output and jump forward str:\n{self.output_and_jump_forward_str}")
-        # print("*" * 100)
+    def max_new_tokens(self):
+        return self.sampling_params.max_new_tokens
 
     def check_finished(self):
         if self.finished:
             return
 
-        if len(self.output_ids) >= self.sampling_params.max_new_tokens:
+        if (
+            len(self.prev_output_ids) + len(self.output_ids)
+            >= self.sampling_params.max_new_tokens
+        ):
             self.finished = True
             self.finish_reason = FinishReason.LENGTH
             return
@@ -141,14 +125,66 @@ class Req:
         )
 
         for stop_str in self.sampling_params.stop_strs:
-            if stop_str in tail_str:
+            # FIXME: (minor) try incremental match in prev_output_str
+            if stop_str in tail_str or stop_str in self.prev_output_str:
                 self.finished = True
                 self.finish_reason = FinishReason.STOP_STR
                 self.hit_stop_str = stop_str
                 return
 
+    def jump_forward_and_retokenize(self, jump_forward_str, next_state):
+        # FIXME: This logic does not really solve the problem of determining whether
+        # there should be a leading space.
+        cur_output_str = self.partial_decode(self.output_ids)
+
+        # TODO(lsyin): apply re-tokenize only for decode tokens so that we do not need origin_input_text anymore
+        if self.origin_input_text is None:
+            # Recovering text can only use unpadded ids
+            self.origin_input_text = self.tokenizer.decode(
+                self.origin_input_ids_unpadded
+            )
+
+        all_text = (
+            self.origin_input_text
+            + self.prev_output_str
+            + cur_output_str
+            + jump_forward_str
+        )
+        all_ids = self.tokenizer.encode(all_text)
+        prompt_tokens = len(self.origin_input_ids_unpadded)
+        self.origin_input_ids = all_ids[:prompt_tokens]
+        self.origin_input_ids_unpadded = self.origin_input_ids
+        # NOTE: the output ids may not strictly correspond to the output text
+        old_prev_output_ids = self.prev_output_ids
+        self.prev_output_ids = all_ids[prompt_tokens:]
+        self.prev_output_str = self.prev_output_str + cur_output_str + jump_forward_str
+        self.output_ids = []
+
+        self.regex_fsm_state = next_state
+
+        if self.return_logprob:
+            # For fast-forward part's logprobs
+            k = 0
+            for i, old_id in enumerate(old_prev_output_ids):
+                if old_id == self.prev_output_ids[i]:
+                    k = k + 1
+                else:
+                    break
+            self.decode_token_logprobs = self.decode_token_logprobs[:k]
+            self.decode_top_logprobs = self.decode_top_logprobs[:k]
+            self.logprob_start_len = prompt_tokens + k
+            self.last_update_decode_tokens = len(self.prev_output_ids) - k
+
+        # print("=" * 100)
+        # print(f"Catch jump forward:\n{jump_forward_str}")
+        # print(self.tokenizer.convert_ids_to_tokens(self.input_ids))
+        # print(self.tokenizer.convert_ids_to_tokens(new_input_ids))
+
+        # print(f"Output and jump forward str:\n{self.output_and_jump_forward_str}")
+        # print("*" * 100)
+
     def __repr__(self):
-        return f"rid(n={self.rid}, " f"input_ids={self.input_ids}, "
+        return f"rid(n={self.rid}, " f"input_ids={self.origin_input_ids}, "
 
 
 @dataclass
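
The logprob trimming in jump_forward_and_retokenize is needed because re-encoding prompt + previous output + jump-forward text can shift token boundaries: only the longest common prefix of the old and new output ids still has valid cached logprobs. A self-contained illustration of that bookkeeping (the token ids and logprob values are made up):

    # Old decode ids vs. ids after re-tokenizing with the jump-forward string appended.
    old_prev_output_ids = [284, 318, 257]
    new_prev_output_ids = [284, 318, 262, 3290]  # boundary shifted at position 2

    k = 0
    for old_id, new_id in zip(old_prev_output_ids, new_prev_output_ids):
        if old_id != new_id:
            break
        k += 1

    # Mirrors the diff: keep the first k cached logprobs, recompute the rest.
    decode_token_logprobs = [-0.1, -0.2, -0.3][:k]            # -> [-0.1, -0.2]
    last_update_decode_tokens = len(new_prev_output_ids) - k  # -> 2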
@@ -319,6 +355,7 @@
 
     def retract_decode(self):
         sorted_indices = [i for i in range(len(self.reqs))]
+        # TODO(lsyin): improve the priority of retraction
        sorted_indices.sort(
             key=lambda i: (len(self.reqs[i].output_ids), -len(self.reqs[i].input_ids)),
             reverse=True,
@@ -332,25 +369,34 @@
             req = self.reqs[idx]
             retracted_reqs.append(req)
 
-            self.tree_cache.dec_ref_counter(req.last_node)
+            # TODO: apply more fine-grained retraction
+            last_uncached_pos = len(req.prefix_indices)
+            token_indices = self.req_to_token_pool.req_to_token[
+                req_pool_indices_cpu[idx]
+            ][last_uncached_pos : seq_lens_cpu[idx]]
+            self.token_to_kv_pool.dec_refs(token_indices)
+
+            # release the last node
+            self.tree_cache.dec_lock_ref(req.last_node)
+
+            cur_output_str = req.partial_decode(req.output_ids)
+            req.prev_output_str = req.prev_output_str + cur_output_str
+            req.prev_output_ids.extend(req.output_ids)
+
             req.prefix_indices = None
             req.last_node = None
             req.extend_input_len = 0
             req.output_ids = []
-            req.regex_fsm_state = 0
-
-            # TODO: apply more fine-grained retraction
 
-            token_indices = self.req_to_token_pool.req_to_token[
-                req_pool_indices_cpu[idx]
-            ][: seq_lens_cpu[idx]]
-            self.token_to_kv_pool.dec_refs(token_indices)
+            # For incremental logprobs
+            req.last_update_decode_tokens = 0
+            req.logprob_start_len = 10**9
 
         self.filter_batch(sorted_indices)
 
         return retracted_reqs
 
-    def check_for_jump_forward(self):
+    def check_for_jump_forward(self, model_runner):
         jump_forward_reqs = []
         filter_indices = [i for i in range(len(self.reqs))]
 
@@ -364,24 +410,34 @@
                 if len(jump_forward_str) <= 1:
                     continue
 
-                # insert the old request into tree_cache
-                token_ids_in_memory = tuple(req.input_ids + req.output_ids)[:-1]
                 if req_pool_indices_cpu is None:
                     req_pool_indices_cpu = self.req_pool_indices.tolist()
-                req_pool_idx = req_pool_indices_cpu[i]
-                indices = self.req_to_token_pool.req_to_token[
-                    req_pool_idx, : len(token_ids_in_memory)
-                ]
-                prefix_len = self.tree_cache.insert(
-                    token_ids_in_memory, indices.clone()
+
+                # insert the old request into tree_cache
+                self.tree_cache.cache_req(
+                    token_ids=tuple(req.input_ids + req.output_ids)[:-1],
+                    last_uncached_pos=len(req.prefix_indices),
+                    req_pool_idx=req_pool_indices_cpu[i],
                 )
-                self.token_to_kv_pool.dec_refs(indices[:prefix_len])
-                self.req_to_token_pool.free(req_pool_idx)
-                self.tree_cache.dec_ref_counter(req.last_node)
+
+                # unlock the last node
+                self.tree_cache.dec_lock_ref(req.last_node)
 
                 # jump-forward
                 req.jump_forward_and_retokenize(jump_forward_str, next_state)
 
+                # re-applying image padding
+                if req.pixel_values is not None:
+                    (
+                        req.origin_input_ids,
+                        req.image_offset,
+                    ) = model_runner.model.pad_input_ids(
+                        req.origin_input_ids_unpadded,
+                        req.pad_value,
+                        req.pixel_values.shape,
+                        req.image_size,
+                    )
+
                 jump_forward_reqs.append(req)
                 filter_indices.remove(i)
 
sglang/srt/managers/router/manager.py
@@ -5,10 +5,10 @@ import uvloop
 import zmq
 import zmq.asyncio
 
-from sglang.srt.backend_config import GLOBAL_BACKEND_CONFIG
+from sglang.global_config import global_config
 from sglang.srt.managers.router.model_rpc import ModelRpcClient
 from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import get_exception_traceback
+from sglang.utils import get_exception_traceback
 
 asyncio.set_event_loop_policy(uvloop.EventLoopPolicy())
 
@@ -30,7 +30,7 @@ class RouterManager:
         self.recv_reqs = []
 
         # Init some configs
-        self.extend_dependency_time = GLOBAL_BACKEND_CONFIG.extend_dependency_time
+        self.request_dependency_time = global_config.request_dependency_time
 
     async def loop_for_forward(self):
         while True:
@@ -46,9 +46,9 @@ class RouterManager:
             if len(out_pyobjs) != 0:
                 has_finished = any([obj.finished for obj in out_pyobjs])
                 if has_finished:
-                    if self.extend_dependency_time > 0:
+                    if self.request_dependency_time > 0:
                         slept = True
-                        await asyncio.sleep(self.extend_dependency_time)
+                        await asyncio.sleep(self.request_dependency_time)
 
             if not slept:
                 await asyncio.sleep(0.0006)
@@ -60,9 +60,7 @@
 
 
 def start_router_process(
-    server_args: ServerArgs,
-    port_args: PortArgs,
-    pipe_writer,
+    server_args: ServerArgs, port_args: PortArgs, pipe_writer, model_overide_args
 ):
     logging.basicConfig(
         level=getattr(logging, server_args.log_level.upper()),
@@ -70,7 +68,7 @@ def start_router_process(
     )
 
     try:
-        model_client = ModelRpcClient(server_args, port_args)
+        model_client = ModelRpcClient(server_args, port_args, model_overide_args)
         router = RouterManager(model_client, port_args)
     except Exception:
         pipe_writer.send(get_exception_traceback())