sglang 0.1.15__py3-none-any.whl → 0.1.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (69)
  1. sglang/__init__.py +5 -1
  2. sglang/api.py +8 -3
  3. sglang/backend/anthropic.py +1 -1
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +148 -12
  6. sglang/backend/runtime_endpoint.py +18 -10
  7. sglang/global_config.py +11 -1
  8. sglang/lang/chat_template.py +9 -2
  9. sglang/lang/interpreter.py +161 -81
  10. sglang/lang/ir.py +29 -11
  11. sglang/lang/tracer.py +1 -1
  12. sglang/launch_server.py +1 -2
  13. sglang/launch_server_llavavid.py +31 -0
  14. sglang/srt/constrained/fsm_cache.py +3 -0
  15. sglang/srt/flush_cache.py +16 -0
  16. sglang/srt/hf_transformers_utils.py +83 -2
  17. sglang/srt/layers/extend_attention.py +17 -0
  18. sglang/srt/layers/fused_moe.py +485 -0
  19. sglang/srt/layers/logits_processor.py +12 -7
  20. sglang/srt/layers/radix_attention.py +10 -3
  21. sglang/srt/layers/token_attention.py +16 -1
  22. sglang/srt/managers/controller/dp_worker.py +110 -0
  23. sglang/srt/managers/controller/infer_batch.py +619 -0
  24. sglang/srt/managers/controller/manager_multi.py +191 -0
  25. sglang/srt/managers/controller/manager_single.py +97 -0
  26. sglang/srt/managers/controller/model_runner.py +462 -0
  27. sglang/srt/managers/controller/radix_cache.py +267 -0
  28. sglang/srt/managers/controller/schedule_heuristic.py +59 -0
  29. sglang/srt/managers/controller/tp_worker.py +791 -0
  30. sglang/srt/managers/detokenizer_manager.py +45 -45
  31. sglang/srt/managers/io_struct.py +26 -10
  32. sglang/srt/managers/router/infer_batch.py +130 -74
  33. sglang/srt/managers/router/manager.py +7 -9
  34. sglang/srt/managers/router/model_rpc.py +224 -135
  35. sglang/srt/managers/router/model_runner.py +94 -107
  36. sglang/srt/managers/router/radix_cache.py +54 -18
  37. sglang/srt/managers/router/scheduler.py +23 -34
  38. sglang/srt/managers/tokenizer_manager.py +183 -88
  39. sglang/srt/model_config.py +5 -2
  40. sglang/srt/models/commandr.py +15 -22
  41. sglang/srt/models/dbrx.py +22 -29
  42. sglang/srt/models/gemma.py +14 -24
  43. sglang/srt/models/grok.py +671 -0
  44. sglang/srt/models/llama2.py +24 -23
  45. sglang/srt/models/llava.py +85 -25
  46. sglang/srt/models/llavavid.py +298 -0
  47. sglang/srt/models/mixtral.py +254 -130
  48. sglang/srt/models/mixtral_quant.py +373 -0
  49. sglang/srt/models/qwen.py +28 -25
  50. sglang/srt/models/qwen2.py +17 -22
  51. sglang/srt/models/stablelm.py +21 -26
  52. sglang/srt/models/yivl.py +17 -25
  53. sglang/srt/openai_api_adapter.py +140 -95
  54. sglang/srt/openai_protocol.py +10 -1
  55. sglang/srt/server.py +101 -52
  56. sglang/srt/server_args.py +59 -11
  57. sglang/srt/utils.py +242 -75
  58. sglang/test/test_programs.py +44 -0
  59. sglang/test/test_utils.py +32 -1
  60. sglang/utils.py +95 -26
  61. {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/METADATA +23 -13
  62. sglang-0.1.17.dist-info/RECORD +81 -0
  63. sglang/srt/backend_config.py +0 -13
  64. sglang/srt/models/dbrx_config.py +0 -281
  65. sglang/srt/weight_utils.py +0 -402
  66. sglang-0.1.15.dist-info/RECORD +0 -69
  67. {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/LICENSE +0 -0
  68. {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/WHEEL +0 -0
  69. {sglang-0.1.15.dist-info → sglang-0.1.17.dist-info}/top_level.txt +0 -0
sglang/srt/managers/router/model_rpc.py
@@ -4,42 +4,45 @@ import multiprocessing
 import time
 import warnings
 from concurrent.futures import ThreadPoolExecutor
-from typing import List
+from typing import List, Optional

 import rpyc
 import torch
 from rpyc.utils.classic import obtain
 from rpyc.utils.server import ThreadedServer
+
 try:
     from vllm.logger import _default_handler as vllm_default_logger
 except ImportError:
     from vllm.logger import logger as vllm_default_logger

+from sglang.global_config import global_config
 from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
 from sglang.srt.managers.io_struct import (
+    AbortReq,
     BatchTokenIDOut,
     FlushCacheReq,
     TokenizedGenerateReqInput,
 )
-from sglang.srt.managers.router.infer_batch import Batch, ForwardMode, Req
+from sglang.srt.managers.router.infer_batch import Batch, FinishReason, ForwardMode, Req
 from sglang.srt.managers.router.model_runner import ModelRunner
 from sglang.srt.managers.router.radix_cache import RadixCache
 from sglang.srt.managers.router.scheduler import Scheduler
 from sglang.srt.model_config import ModelConfig
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.utils import (
-    get_exception_traceback,
     get_int_token_logit_bias,
     is_multimodal_model,
     set_random_seed,
 )
-
+from sglang.utils import get_exception_traceback

 logger = logging.getLogger("model_rpc")
 vllm_default_logger.setLevel(logging.WARN)
 logging.getLogger("vllm.utils").setLevel(logging.WARN)
+logging.getLogger("vllm.selector").setLevel(logging.WARN)


 class ModelRpcServer:
@@ -48,6 +51,7 @@ class ModelRpcServer:
         tp_rank: int,
         server_args: ServerArgs,
         port_args: PortArgs,
+        model_overide_args: Optional[dict] = None,
     ):
         server_args, port_args = [obtain(x) for x in [server_args, port_args]]

@@ -62,23 +66,17 @@ class ModelRpcServer:
             server_args.model_path,
             server_args.trust_remote_code,
             context_length=server_args.context_length,
+            model_overide_args=model_overide_args,
         )

         # For model end global settings
-        server_args_dict = {
-            "enable_flashinfer": server_args.enable_flashinfer,
-            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
-        }
-
         self.model_runner = ModelRunner(
             model_config=self.model_config,
             mem_fraction_static=server_args.mem_fraction_static,
             tp_rank=tp_rank,
             tp_size=server_args.tp_size,
             nccl_port=port_args.nccl_port,
-            load_format=server_args.load_format,
-            trust_remote_code=server_args.trust_remote_code,
-            server_args_dict=server_args_dict,
+            server_args=server_args,
         )
         if is_multimodal_model(server_args.model_path):
             self.processor = get_processor(
@@ -93,37 +91,44 @@ class ModelRpcServer:
             tokenizer_mode=server_args.tokenizer_mode,
             trust_remote_code=server_args.trust_remote_code,
         )
-        self.max_total_num_token = self.model_runner.max_total_num_token
-        self.max_num_running_seq = self.max_total_num_token // 2
-        self.max_prefill_num_token = max(
+        self.max_total_num_tokens = self.model_runner.max_total_num_tokens
+        self.max_prefill_tokens = max(
             self.model_config.context_len,
             (
-                self.max_total_num_token // 6
-                if server_args.max_prefill_num_token is None
-                else server_args.max_prefill_num_token
+                self.max_total_num_tokens // 6
+                if server_args.max_prefill_tokens is None
+                else server_args.max_prefill_tokens
             ),
         )
+        self.max_running_requests = (self.max_total_num_tokens // 2
+            if server_args.max_running_requests is None else server_args.max_running_requests)
+
         self.int_token_logit_bias = torch.tensor(
             get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
         )
         set_random_seed(server_args.random_seed)
-        logger.info(
-            f"Rank {self.tp_rank}: "
-            f"max_total_num_token={self.max_total_num_token}, "
-            f"max_prefill_num_token={self.max_prefill_num_token}, "
+
+        # Print info
+        logger.info(f"[rank={self.tp_rank}] "
+            f"max_total_num_tokens={self.max_total_num_tokens}, "
+            f"max_prefill_tokens={self.max_prefill_tokens}, "
             f"context_len={self.model_config.context_len}, "
         )
         if self.tp_rank == 0:
             logger.info(f"server_args: {server_args.print_mode_args()}")

         # Init cache
-        self.tree_cache = RadixCache(disable=server_args.disable_radix_cache)
+        self.tree_cache = RadixCache(
+            req_to_token_pool=self.model_runner.req_to_token_pool,
+            token_to_kv_pool=self.model_runner.token_to_kv_pool,
+            disable=server_args.disable_radix_cache,
+        )
         self.tree_cache_metrics = {"total": 0, "hit": 0}
         self.scheduler = Scheduler(
             self.schedule_heuristic,
-            self.max_num_running_seq,
-            self.max_prefill_num_token,
-            self.max_total_num_token,
+            self.max_running_requests,
+            self.max_prefill_tokens,
+            self.max_total_num_tokens,
             self.tree_cache,
         )
         self.req_to_token_pool = self.model_runner.req_to_token_pool
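The hunk above also changes how the serving limits are derived: unless overridden by server arguments, `max_prefill_tokens` defaults to a sixth of the KV pool and `max_running_requests` to half of it. A standalone sketch of just that arithmetic (the helper name is invented for this illustration; the constants come from the hunk):

```python
# Standalone sketch of the default budgets derived in the hunk above.
# `derive_budgets` is a made-up helper; max_total_num_tokens would come from
# the ModelRunner's KV pool sizing.
def derive_budgets(max_total_num_tokens, context_len,
                   max_prefill_tokens=None, max_running_requests=None):
    prefill = max(
        context_len,
        max_total_num_tokens // 6 if max_prefill_tokens is None else max_prefill_tokens,
    )
    running = (max_total_num_tokens // 2
               if max_running_requests is None else max_running_requests)
    return prefill, running

# e.g. a 200k-token KV pool and a 4k context window
print(derive_budgets(200_000, 4096))  # (33333, 100000)
```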
@@ -135,6 +140,8 @@ class ModelRpcServer:
         self.out_pyobjs = []
         self.decode_forward_ct = 0
         self.stream_interval = server_args.stream_interval
+        self.num_generated_tokens = 0
+        self.last_stats_tic = time.time()

         # Init the FSM cache for constrained generation
         self.regex_fsm_cache = FSMCache(
@@ -147,27 +154,20 @@ class ModelRpcServer:
         self.jump_forward_cache = JumpForwardCache()

         # Init new token estimation
-        self.new_token_ratio = min(0.4 * server_args.schedule_conservativeness, 1.0)
-        self.min_new_token_ratio = min(0.2 * server_args.schedule_conservativeness, 1.0)
-        self.new_token_ratio_step = (0.0001, 0.05)  # (down, up)
-
-    def flush_cache(self):
-        if len(self.forward_queue) == 0 and (
-            self.running_batch is None or len(self.running_batch.reqs) == 0
-        ):
-            self.tree_cache.reset()
-            self.tree_cache_metrics = {"total": 0, "hit": 0}
-            self.regex_fsm_cache.reset()
-            self.req_to_token_pool.clear()
-            self.token_to_kv_pool.clear()
-            torch.cuda.empty_cache()
-            logger.info("Cache flushed successfully!")
-        else:
-            warnings.warn(
-                f"Cache not flushed because there are pending requests. "
-                f"#queue-req: {len(self.forward_queue)}, "
-                f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
-            )
+        assert (
+            server_args.schedule_conservativeness >= 0
+        ), "Invalid schedule_conservativeness"
+        self.new_token_ratio = min(
+            global_config.base_new_token_ratio * server_args.schedule_conservativeness,
+            1.0,
+        )
+        self.min_new_token_ratio = min(
+            global_config.base_min_new_token_ratio
+            * server_args.schedule_conservativeness,
+            1.0,
+        )
+        self.new_token_ratio_decay = global_config.new_token_ratio_decay
+        self.new_token_ratio_recovery = global_config.new_token_ratio_recovery

     def exposed_step(self, recv_reqs):
         if self.tp_size != 1:
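The retraction heuristic above now reads its base ratios from `sglang.global_config` instead of the hard-coded 0.4/0.2, and splits the old `new_token_ratio_step` pair into explicit decay and recovery terms. A toy model of that behaviour, with placeholder constants standing in for the global-config values:

```python
# Toy model of the retract/recover heuristic wired up above.
# The real base values live in sglang.global_config; the numbers here are
# placeholders chosen only to make the behaviour visible.
BASE_NEW_TOKEN_RATIO = 0.4
BASE_MIN_NEW_TOKEN_RATIO = 0.2
NEW_TOKEN_RATIO_DECAY = 0.0001
NEW_TOKEN_RATIO_RECOVERY = 0.05

class TokenRatioEstimator:
    def __init__(self, schedule_conservativeness: float = 1.0):
        assert schedule_conservativeness >= 0, "Invalid schedule_conservativeness"
        self.ratio = min(BASE_NEW_TOKEN_RATIO * schedule_conservativeness, 1.0)
        self.min_ratio = min(BASE_MIN_NEW_TOKEN_RATIO * schedule_conservativeness, 1.0)

    def on_decode_ok(self):
        # Memory was sufficient: slowly become more optimistic.
        self.ratio = max(self.ratio - NEW_TOKEN_RATIO_DECAY, self.min_ratio)

    def on_retract(self):
        # Ran out of KV cache and had to retract requests: back off quickly.
        self.ratio = min(self.ratio + NEW_TOKEN_RATIO_RECOVERY, 1.0)

est = TokenRatioEstimator()
est.on_retract()
est.on_decode_ok()
print(round(est.ratio, 4))  # 0.4499
```

Retraction pushes the ratio up quickly (be conservative about admitting new requests), while every successful decode step nudges it back down toward the minimum.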
@@ -180,6 +180,8 @@ class ModelRpcServer:
                 self.handle_generate_request(recv_req)
             elif isinstance(recv_req, FlushCacheReq):
                 self.flush_cache()
+            elif isinstance(recv_req, AbortReq):
+                self.abort_request(recv_req)
             else:
                 raise ValueError(f"Invalid request: {recv_req}")

@@ -198,8 +200,9 @@ class ModelRpcServer:
         new_batch = self.get_new_fill_batch()

         if new_batch is not None:
-            # Run new fill batch
+            # Run a new fill batch
             self.forward_fill_batch(new_batch)
+            self.cache_filled_batch(new_batch)

             if not new_batch.is_empty():
                 if self.running_batch is None:
@@ -211,37 +214,45 @@ class ModelRpcServer:
         if self.running_batch is not None:
             # Run a few decode batches continuously for reducing overhead
             for _ in range(10):
+                self.num_generated_tokens += len(self.running_batch.reqs)
                 self.forward_decode_batch(self.running_batch)

-                if self.running_batch.is_empty():
-                    self.running_batch = None
-                    break
-
-                if self.out_pyobjs and self.running_batch.reqs[0].stream:
-                    break
-
-                if self.running_batch is not None and self.tp_rank == 0:
+                # Print stats
+                if self.tp_rank == 0:
                     if self.decode_forward_ct % 40 == 0:
-                        num_used = self.max_total_num_token - (
+                        num_used = self.max_total_num_tokens - (
                             self.token_to_kv_pool.available_size()
                             + self.tree_cache.evictable_size()
                         )
+                        throuhgput = self.num_generated_tokens / (
+                            time.time() - self.last_stats_tic
+                        )
+                        self.num_generated_tokens = 0
+                        self.last_stats_tic = time.time()
                         logger.info(
                             f"#running-req: {len(self.running_batch.reqs)}, "
                             f"#token: {num_used}, "
-                            f"token usage: {num_used / self.max_total_num_token:.2f}, "
+                            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+                            f"gen throughput (token/s): {throuhgput:.2f}, "
                             f"#queue-req: {len(self.forward_queue)}"
                         )
+
+                if self.running_batch.is_empty():
+                    self.running_batch = None
+                    break
+
+                if self.out_pyobjs and self.running_batch.reqs[0].stream:
+                    break
         else:
-            # check the available size
+            # Check the available size
             available_size = (
                 self.token_to_kv_pool.available_size()
                 + self.tree_cache.evictable_size()
             )
-            if available_size != self.max_total_num_token:
+            if available_size != self.max_total_num_tokens:
                 warnings.warn(
                     "Warning: "
-                    f"available_size={available_size}, max_total_num_token={self.max_total_num_token}\n"
+                    f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
                     "KV cache pool leak detected!"
                 )

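The decode loop above now also reports a rough generation throughput, counting one token per running request per decode step and resetting the counter at each report. A minimal standalone version of that bookkeeping (class and method names are invented for the sketch):

```python
import time

# Minimal version of the throughput bookkeeping added to the decode loop above.
class ThroughputMeter:
    def __init__(self):
        self.num_generated_tokens = 0
        self.last_stats_tic = time.time()

    def count(self, num_running_reqs: int):
        # one decode step emits one token per running request
        self.num_generated_tokens += num_running_reqs

    def report(self) -> float:
        elapsed = max(time.time() - self.last_stats_tic, 1e-8)
        throughput = self.num_generated_tokens / elapsed
        self.num_generated_tokens = 0
        self.last_stats_tic = time.time()
        return throughput

meter = ThroughputMeter()
for _ in range(10):
    meter.count(num_running_reqs=8)
print(f"gen throughput (token/s): {meter.report():.2f}")
```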
@@ -259,8 +270,13 @@ class ModelRpcServer:
                 (recv_req.image_hash >> 64) % self.model_config.vocab_size,
             ]
             req.image_size = recv_req.image_size
-            req.input_ids, req.image_offset = self.model_runner.model.pad_input_ids(
-                req.input_ids, req.pad_value, req.pixel_values.shape, req.image_size
+            req.origin_input_ids, req.image_offset = (
+                self.model_runner.model.pad_input_ids(
+                    req.origin_input_ids_unpadded,
+                    req.pad_value,
+                    req.pixel_values.shape,
+                    req.image_size,
+                )
             )
         req.sampling_params = recv_req.sampling_params
         req.return_logprob = recv_req.return_logprob
@@ -277,23 +293,28 @@ class ModelRpcServer:
                 req.sampling_params.regex
             )

-        # Truncate long prompts
-        req.input_ids = req.input_ids[: self.model_config.context_len - 1]
+        # Truncate prompts that are too long
+        req.origin_input_ids = req.origin_input_ids[: self.model_config.context_len - 1]
         req.sampling_params.max_new_tokens = min(
             req.sampling_params.max_new_tokens,
-            self.model_config.context_len - 1 - len(req.input_ids),
-            self.max_total_num_token - 128 - len(req.input_ids),
+            self.model_config.context_len - 1 - len(req.origin_input_ids),
+            self.max_total_num_tokens - 128 - len(req.origin_input_ids),
         )
         self.forward_queue.append(req)

     def get_new_fill_batch(self):
         if (
             self.running_batch is not None
-            and len(self.running_batch.reqs) > self.max_num_running_seq
+            and len(self.running_batch.reqs) > self.max_running_requests
         ):
             return None

+        # Compute matched prefix length
         for req in self.forward_queue:
+            assert (
+                len(req.output_ids) == 0
+            ), "The output ids should be empty when prefilling"
+            req.input_ids = req.origin_input_ids + req.prev_output_ids
             prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
             if req.return_logprob:
                 prefix_indices = prefix_indices[: req.logprob_start_len]
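A runnable sketch of the truncation and `max_new_tokens` clamping shown above (the helper itself is hypothetical; the formulas and the 128-token reserve come from the hunk):

```python
# Sketch of the prompt truncation and max_new_tokens clamp above.
def clamp_request(origin_input_ids, max_new_tokens, context_len, max_total_num_tokens):
    origin_input_ids = origin_input_ids[: context_len - 1]
    max_new_tokens = min(
        max_new_tokens,
        context_len - 1 - len(origin_input_ids),
        max_total_num_tokens - 128 - len(origin_input_ids),
    )
    return origin_input_ids, max_new_tokens

ids, budget = clamp_request(
    list(range(5000)), 1024, context_len=4096, max_total_num_tokens=200_000
)
print(len(ids), budget)  # 4095 0 -- an over-long prompt leaves no room to generate
```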
@@ -321,7 +342,7 @@ class ModelRpcServer:
             )

         for req in self.forward_queue:
-            if req.return_logprob:
+            if req.return_logprob and req.normalized_prompt_logprob is None:
                 # Need at least two tokens to compute normalized logprob
                 if req.extend_input_len < 2:
                     delta = 2 - req.extend_input_len
@@ -340,22 +361,21 @@ class ModelRpcServer:
                 req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
                 < available_size
                 and req.extend_input_len + new_batch_input_tokens
-                < self.max_prefill_num_token
+                < self.max_prefill_tokens
             ):
-                delta = self.tree_cache.inc_ref_counter(req.last_node)
+                delta = self.tree_cache.inc_lock_ref(req.last_node)
                 available_size += delta

                 if not (
                     req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
                     < available_size
                 ):
-                    # Undo the insertion
-                    delta = self.tree_cache.dec_ref_counter(req.last_node)
+                    # Undo locking
+                    delta = self.tree_cache.dec_lock_ref(req.last_node)
                     available_size += delta
                     break
                 else:
                     # Add this request to the running batch
-                    self.token_to_kv_pool.add_refs(req.prefix_indices)
                     can_run_list.append(req)
                     new_batch_total_tokens += (
                         req.extend_input_len + req.max_new_tokens()
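The renamed `inc_lock_ref` / `dec_lock_ref` calls above pin the matched prefix in the radix cache while a request is being admitted, and the separate `token_to_kv_pool.add_refs` bookkeeping is gone. The toy below sketches only the lock-count accounting that the admission check relies on; it is not the real `RadixCache`:

```python
# Toy illustration of the lock-ref idea behind inc_lock_ref / dec_lock_ref:
# a locked node (and its ancestors) must not be evicted while a running
# request still points at it. This is NOT the real RadixCache, just a sketch.
class Node:
    def __init__(self, parent=None, num_tokens=0):
        self.parent = parent
        self.num_tokens = num_tokens
        self.lock_ref = 0

class ToyCache:
    def inc_lock_ref(self, node):
        delta = 0
        while node is not None:
            if node.lock_ref == 0:
                delta -= node.num_tokens  # these tokens stop being evictable
            node.lock_ref += 1
            node = node.parent
        return delta

    def dec_lock_ref(self, node):
        delta = 0
        while node is not None:
            node.lock_ref -= 1
            if node.lock_ref == 0:
                delta += node.num_tokens  # evictable again
            node = node.parent
        return delta

root = Node()
child = Node(parent=root, num_tokens=16)
cache = ToyCache()
print(cache.inc_lock_ref(child))  # -16: the prefix is pinned
print(cache.dec_lock_ref(child))  # +16: released
```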
@@ -366,6 +386,7 @@ class ModelRpcServer:
         if len(can_run_list) == 0:
             return None

+        # Print stats
         if self.tp_rank == 0:
             running_req = (
                 0 if self.running_batch is None else len(self.running_batch.reqs)
@@ -386,13 +407,14 @@ class ModelRpcServer:
                 f"#running_req: {running_req}. "
                 f"tree_cache_hit_rate: {100.0 * tree_cache_hit_rate:.2f}%."
             )
-            #logger.debug(
+            # logger.debug(
             #    f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
             #    f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
             #    f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
             #    f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
-            #)
+            # )

+        # Return the new batch
         new_batch = Batch.init_new(
             can_run_list,
             self.req_to_token_pool,
@@ -425,9 +447,10 @@ class ModelRpcServer:

             # Only transfer the selected logprobs of the next token to CPU to reduce overhead.
             if last_logprobs is not None:
-                last_token_logprobs = (
-                    last_logprobs[torch.arange(len(batch.reqs)), next_token_ids].tolist()
-                )
+                last_token_logprobs = last_logprobs[
+                    torch.arange(len(batch.reqs), device=next_token_ids.device),
+                    next_token_ids,
+                ].tolist()

             next_token_ids = next_token_ids.tolist()
         else:
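A minimal, self-contained reproduction of the gather above using random data; the hunk's change is building the row index on the same device as `next_token_ids` before the advanced indexing:

```python
import torch

# Self-contained reproduction of the gather above with random data.
# Shapes are made up; in the server these come from the running batch.
batch_size, vocab_size = 4, 32
last_logprobs = torch.randn(batch_size, vocab_size).log_softmax(dim=-1)
next_token_ids = torch.randint(vocab_size, (batch_size,))

# Pick, for each row, the logprob of the token that was actually sampled.
last_token_logprobs = last_logprobs[
    torch.arange(batch_size, device=next_token_ids.device),
    next_token_ids,
].tolist()
print(last_token_logprobs)
```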
@@ -441,38 +464,75 @@ class ModelRpcServer:
             req.check_finished()

             if req.return_logprob:
-                req.normalized_prompt_logprob = normalized_prompt_logprobs[i]
-
-                # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
-                req.prefill_token_logprobs = list(
-                    zip(
-                        prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
-                        req.input_ids[-req.extend_input_len + 1 :],
+                if req.normalized_prompt_logprob is None:
+                    req.normalized_prompt_logprob = normalized_prompt_logprobs[i]
+
+                if req.prefill_token_logprobs is None:
+                    # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
+                    req.prefill_token_logprobs = list(
+                        zip(
+                            prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
+                            req.input_ids[-req.extend_input_len + 1 :],
+                        )
                     )
-                )
-                if req.logprob_start_len == 0:
-                    req.prefill_token_logprobs = [
-                        (None, req.input_ids[0])
-                    ] + req.prefill_token_logprobs
-                req.decode_token_logprobs = [
+                    if req.logprob_start_len == 0:
+                        req.prefill_token_logprobs = [
+                            (None, req.input_ids[0])
+                        ] + req.prefill_token_logprobs
+
+                if req.last_update_decode_tokens != 0:
+                    req.decode_token_logprobs.extend(
+                        list(
+                            zip(
+                                prefill_token_logprobs[
+                                    pt
+                                    + req.extend_input_len
+                                    - req.last_update_decode_tokens : pt
+                                    + req.extend_input_len
+                                    - 1
+                                ],
+                                req.input_ids[-req.last_update_decode_tokens + 1 :],
+                            )
+                        )
+                    )
+
+                req.decode_token_logprobs.append(
                     (last_token_logprobs[i], next_token_ids[i])
-                ]
+                )

             if req.top_logprobs_num > 0:
-                req.prefill_top_logprobs = prefill_top_logprobs[i]
-                if req.logprob_start_len == 0:
-                    req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
-                req.decode_top_logprobs = [decode_top_logprobs[i]]
+                if req.prefill_top_logprobs is None:
+                    req.prefill_top_logprobs = prefill_top_logprobs[i]
+                    if req.logprob_start_len == 0:
+                        req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
+
+                if req.last_update_decode_tokens != 0:
+                    req.decode_top_logprobs.extend(
+                        prefill_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
+                    )
+                req.decode_top_logprobs.append(decode_top_logprobs[i])

             pt += req.extend_input_len

         self.handle_finished_requests(batch)

+    def cache_filled_batch(self, batch: Batch):
+        req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
+        for i, req in enumerate(batch.reqs):
+            new_prefix_indices, new_last_node = self.tree_cache.cache_req(
+                token_ids=tuple(req.input_ids + req.output_ids)[:-1],
+                last_uncached_pos=len(req.prefix_indices),
+                req_pool_idx=req_pool_indices_cpu[i],
+                del_in_memory_pool=False,
+                old_last_node=req.last_node,
+            )
+            req.prefix_indices, req.last_node = new_prefix_indices, new_last_node
+
     def forward_decode_batch(self, batch: Batch):
         # check if decode out of memory
         if not batch.check_decode_mem():
             old_ratio = self.new_token_ratio
-            self.new_token_ratio = min(old_ratio + self.new_token_ratio_step[1], 1.0)
+            self.new_token_ratio = min(old_ratio + self.new_token_ratio_recovery, 1.0)

             retracted_reqs = batch.retract_decode()
             logger.info(
@@ -483,26 +543,13 @@ class ModelRpcServer:
             self.forward_queue.extend(retracted_reqs)
         else:
             self.new_token_ratio = max(
-                self.new_token_ratio - self.new_token_ratio_step[0],
+                self.new_token_ratio - self.new_token_ratio_decay,
                 self.min_new_token_ratio,
             )

         if not self.disable_regex_jump_forward:
             # check for jump-forward
-            jump_forward_reqs = batch.check_for_jump_forward()
-
-            # check for image jump-forward
-            for req in jump_forward_reqs:
-                if req.pixel_values is not None:
-                    (
-                        req.input_ids,
-                        req.image_offset,
-                    ) = self.model_runner.model.pad_input_ids(
-                        req.input_ids,
-                        req.pad_value,
-                        req.pixel_values.shape,
-                        req.image_size,
-                    )
+            jump_forward_reqs = batch.check_for_jump_forward(self.model_runner)

             self.forward_queue.extend(jump_forward_reqs)
             if batch.is_empty():
@@ -545,8 +592,8 @@ class ModelRpcServer:

     def handle_finished_requests(self, batch: Batch):
         output_rids = []
+        prev_output_strs = []
         output_tokens = []
-        output_and_jump_forward_strs = []
         output_hit_stop_str = []
         output_skip_special_tokens = []
         output_spaces_between_special_tokens = []
@@ -570,8 +617,8 @@ class ModelRpcServer:
                 )
             ):
                 output_rids.append(req.rid)
+                prev_output_strs.append(req.prev_output_str)
                 output_tokens.append(req.output_ids)
-                output_and_jump_forward_strs.append(req.output_and_jump_forward_str)
                 output_hit_stop_str.append(req.hit_stop_str)
                 output_skip_special_tokens.append(
                     req.sampling_params.skip_special_tokens
@@ -581,12 +628,11 @@ class ModelRpcServer:
                 )

                 meta_info = {
-                    "prompt_tokens": req.prompt_tokens,
-                    "completion_tokens": len(req.input_ids)
-                    + len(req.output_ids)
-                    - req.prompt_tokens,
+                    "prompt_tokens": len(req.origin_input_ids),
+                    "completion_tokens": len(req.prev_output_ids) + len(req.output_ids),
                     "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
-                    "finish_reason": str(req.finish_reason),  # FIXME: convert to the correct string
+                    "finish_reason": FinishReason.to_str(req.finish_reason),
+                    "hit_stop_str": req.hit_stop_str,
                 }
                 if req.return_logprob:
                     (
@@ -610,8 +656,8 @@ class ModelRpcServer:
             self.out_pyobjs.append(
                 BatchTokenIDOut(
                     output_rids,
+                    prev_output_strs,
                     output_tokens,
-                    output_and_jump_forward_strs,
                     output_hit_stop_str,
                     output_skip_special_tokens,
                     output_spaces_between_special_tokens,
@@ -626,17 +672,13 @@ class ModelRpcServer:
             req_pool_indices_cpu = batch.req_pool_indices.tolist()
             for i in finished_indices:
                 req = batch.reqs[i]
-                req_pool_idx = req_pool_indices_cpu[i]
-                token_ids = tuple(req.input_ids + req.output_ids)
-                seq_len = len(token_ids) - 1
-                indices = self.req_to_token_pool.req_to_token[req_pool_idx, :seq_len]
-                prefix_len = self.tree_cache.insert(
-                    token_ids[:seq_len], indices.clone()
+                self.tree_cache.cache_req(
+                    token_ids=tuple(req.input_ids + req.output_ids)[:-1],
+                    last_uncached_pos=len(req.prefix_indices),
+                    req_pool_idx=req_pool_indices_cpu[i],
                 )

-                self.token_to_kv_pool.dec_refs(indices[:prefix_len])
-                self.req_to_token_pool.free(req_pool_idx)
-                self.tree_cache.dec_ref_counter(req.last_node)
+                self.tree_cache.dec_lock_ref(req.last_node)

         # Update batch tensors
         if unfinished_indices:
@@ -644,19 +686,58 @@ class ModelRpcServer:
         else:
             batch.reqs = []

+    def flush_cache(self):
+        if len(self.forward_queue) == 0 and (
+            self.running_batch is None or len(self.running_batch.reqs) == 0
+        ):
+            self.tree_cache.reset()
+            self.tree_cache_metrics = {"total": 0, "hit": 0}
+            self.regex_fsm_cache.reset()
+            self.req_to_token_pool.clear()
+            self.token_to_kv_pool.clear()
+            torch.cuda.empty_cache()
+            logger.info("Cache flushed successfully!")
+        else:
+            warnings.warn(
+                f"Cache not flushed because there are pending requests. "
+                f"#queue-req: {len(self.forward_queue)}, "
+                f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
+            )
+
+    def abort_request(self, recv_req):
+        # Delete requests in the waiting queue
+        to_del = None
+        for i, req in enumerate(self.forward_queue):
+            if req.rid == recv_req.rid:
+                to_del = i
+                break
+
+        if to_del is not None:
+            del self.forward_queue[to_del]
+
+        # Delete requests in the running batch
+        if self.running_batch:
+            for req in self.running_batch.reqs:
+                if req.rid == recv_req.rid:
+                    req.finished = True
+                    req.finish_reason = FinishReason.ABORT
+                    break
+

 class ModelRpcService(rpyc.Service):
     exposed_ModelRpcServer = ModelRpcServer


 class ModelRpcClient:
-    def __init__(self, server_args: ServerArgs, port_args: PortArgs):
+    def __init__(
+        self, server_args: ServerArgs, port_args: PortArgs, model_overide_args
+    ):
         tp_size = server_args.tp_size

         if tp_size == 1:
             # Init model
             self.model_server = ModelRpcService().exposed_ModelRpcServer(
-                0, server_args, port_args
+                0, server_args, port_args, model_overide_args
             )

             # Wrap functions
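The new `abort_request` path above either drops an unscheduled request from the waiting queue or marks a running one as finished with `FinishReason.ABORT` so the next step cleans it up. A self-contained toy of that flow, with stand-in request and worker types (the real `AbortReq` and `FinishReason` live in `io_struct` / `infer_batch`):

```python
from dataclasses import dataclass, field
from enum import Enum, auto
from typing import List, Optional

# Toy re-creation of the abort path added above, with stand-in types.
class FinishReason(Enum):
    ABORT = auto()

@dataclass
class ToyReq:
    rid: str
    finished: bool = False
    finish_reason: Optional[FinishReason] = None

@dataclass
class ToyWorker:
    forward_queue: List[ToyReq] = field(default_factory=list)
    running: List[ToyReq] = field(default_factory=list)

    def abort_request(self, rid: str) -> None:
        # Not scheduled yet: just drop it from the waiting queue.
        for i, req in enumerate(self.forward_queue):
            if req.rid == rid:
                del self.forward_queue[i]
                return
        # Already decoding: mark it finished so the next step cleans it up.
        for req in self.running:
            if req.rid == rid:
                req.finished = True
                req.finish_reason = FinishReason.ABORT
                return

worker = ToyWorker(forward_queue=[ToyReq("a")], running=[ToyReq("b")])
worker.abort_request("b")
print(worker.running[0].finished, worker.running[0].finish_reason)
# True FinishReason.ABORT
```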
@@ -677,7 +758,7 @@ class ModelRpcClient:
             # Init model
             def init_model(i):
                 return self.remote_services[i].ModelRpcServer(
-                    i, server_args, port_args
+                    i, server_args, port_args, model_overide_args
                 )

             self.model_servers = executor.map(init_model, range(tp_size))
@@ -700,7 +781,11 @@ def _init_service(port):
     t = ThreadedServer(
         ModelRpcService(),
         port=port,
-        protocol_config={"allow_pickle": True, "sync_request_timeout": 1800},
+        protocol_config={
+            "allow_public_attrs": True,
+            "allow_pickle": True,
+            "sync_request_timeout": 3600,
+        },
     )
     t.start()

@@ -716,7 +801,11 @@ def start_model_process(port):
             con = rpyc.connect(
                 "localhost",
                 port,
-                config={"allow_pickle": True, "sync_request_timeout": 1800},
+                config={
+                    "allow_public_attrs": True,
+                    "allow_pickle": True,
+                    "sync_request_timeout": 3600,
+                },
             )
             break
         except ConnectionRefusedError:
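Both rpyc endpoints now enable `allow_public_attrs` and raise the sync timeout from 1800 s to 3600 s. A minimal, self-contained rpyc pair using the same `protocol_config`; the `EchoService` and port number are illustrative only, not part of sglang:

```python
import threading
import time

import rpyc
from rpyc.utils.server import ThreadedServer

# The shared protocol settings from the hunks above.
RPYC_CONFIG = {
    "allow_public_attrs": True,
    "allow_pickle": True,
    "sync_request_timeout": 3600,
}

class EchoService(rpyc.Service):
    def exposed_echo(self, x):
        return x

server = ThreadedServer(EchoService(), port=18861, protocol_config=RPYC_CONFIG)
threading.Thread(target=server.start, daemon=True).start()
time.sleep(0.5)  # give the server a moment to bind

conn = rpyc.connect("localhost", 18861, config=RPYC_CONFIG)
print(conn.root.echo("ping"))  # simple round trip through the service
conn.close()
server.close()
```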