sglang 0.1.16__py3-none-any.whl → 0.1.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. sglang/__init__.py +3 -1
  2. sglang/api.py +3 -3
  3. sglang/backend/anthropic.py +1 -1
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +148 -12
  6. sglang/backend/runtime_endpoint.py +18 -10
  7. sglang/global_config.py +8 -1
  8. sglang/lang/interpreter.py +114 -67
  9. sglang/lang/ir.py +17 -2
  10. sglang/srt/constrained/fsm_cache.py +3 -0
  11. sglang/srt/flush_cache.py +1 -1
  12. sglang/srt/hf_transformers_utils.py +75 -1
  13. sglang/srt/layers/extend_attention.py +17 -0
  14. sglang/srt/layers/fused_moe.py +485 -0
  15. sglang/srt/layers/logits_processor.py +12 -7
  16. sglang/srt/layers/radix_attention.py +10 -3
  17. sglang/srt/layers/token_attention.py +16 -1
  18. sglang/srt/managers/controller/dp_worker.py +110 -0
  19. sglang/srt/managers/controller/infer_batch.py +619 -0
  20. sglang/srt/managers/controller/manager_multi.py +191 -0
  21. sglang/srt/managers/controller/manager_single.py +97 -0
  22. sglang/srt/managers/controller/model_runner.py +462 -0
  23. sglang/srt/managers/controller/radix_cache.py +267 -0
  24. sglang/srt/managers/controller/schedule_heuristic.py +59 -0
  25. sglang/srt/managers/controller/tp_worker.py +791 -0
  26. sglang/srt/managers/detokenizer_manager.py +45 -45
  27. sglang/srt/managers/io_struct.py +15 -11
  28. sglang/srt/managers/router/infer_batch.py +103 -59
  29. sglang/srt/managers/router/manager.py +1 -1
  30. sglang/srt/managers/router/model_rpc.py +175 -122
  31. sglang/srt/managers/router/model_runner.py +91 -104
  32. sglang/srt/managers/router/radix_cache.py +7 -1
  33. sglang/srt/managers/router/scheduler.py +6 -6
  34. sglang/srt/managers/tokenizer_manager.py +152 -89
  35. sglang/srt/model_config.py +4 -5
  36. sglang/srt/models/commandr.py +10 -13
  37. sglang/srt/models/dbrx.py +9 -15
  38. sglang/srt/models/gemma.py +8 -15
  39. sglang/srt/models/grok.py +671 -0
  40. sglang/srt/models/llama2.py +19 -15
  41. sglang/srt/models/llava.py +84 -20
  42. sglang/srt/models/llavavid.py +11 -20
  43. sglang/srt/models/mixtral.py +248 -118
  44. sglang/srt/models/mixtral_quant.py +373 -0
  45. sglang/srt/models/qwen.py +9 -13
  46. sglang/srt/models/qwen2.py +11 -13
  47. sglang/srt/models/stablelm.py +9 -15
  48. sglang/srt/models/yivl.py +17 -22
  49. sglang/srt/openai_api_adapter.py +140 -95
  50. sglang/srt/openai_protocol.py +10 -1
  51. sglang/srt/server.py +77 -42
  52. sglang/srt/server_args.py +51 -6
  53. sglang/srt/utils.py +124 -66
  54. sglang/test/test_programs.py +44 -0
  55. sglang/test/test_utils.py +32 -1
  56. sglang/utils.py +22 -4
  57. {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/METADATA +15 -9
  58. sglang-0.1.17.dist-info/RECORD +81 -0
  59. sglang/srt/backend_config.py +0 -13
  60. sglang/srt/models/dbrx_config.py +0 -281
  61. sglang/srt/weight_utils.py +0 -417
  62. sglang-0.1.16.dist-info/RECORD +0 -72
  63. {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/LICENSE +0 -0
  64. {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/WHEEL +0 -0
  65. {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/top_level.txt +0 -0
@@ -4,7 +4,7 @@ import multiprocessing
  import time
  import warnings
  from concurrent.futures import ThreadPoolExecutor
- from typing import Any, Dict, List, Optional, Tuple, Union
+ from typing import List, Optional
 
  import rpyc
  import torch
@@ -16,31 +16,33 @@ try:
  except ImportError:
  from vllm.logger import logger as vllm_default_logger
 
+ from sglang.global_config import global_config
  from sglang.srt.constrained.fsm_cache import FSMCache
  from sglang.srt.constrained.jump_forward import JumpForwardCache
  from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
  from sglang.srt.managers.io_struct import (
+ AbortReq,
  BatchTokenIDOut,
  FlushCacheReq,
  TokenizedGenerateReqInput,
  )
- from sglang.srt.managers.router.infer_batch import Batch, ForwardMode, Req, FinishReason
+ from sglang.srt.managers.router.infer_batch import Batch, FinishReason, ForwardMode, Req
  from sglang.srt.managers.router.model_runner import ModelRunner
  from sglang.srt.managers.router.radix_cache import RadixCache
  from sglang.srt.managers.router.scheduler import Scheduler
  from sglang.srt.model_config import ModelConfig
  from sglang.srt.server_args import PortArgs, ServerArgs
  from sglang.srt.utils import (
- get_exception_traceback,
  get_int_token_logit_bias,
  is_multimodal_model,
  set_random_seed,
  )
-
+ from sglang.utils import get_exception_traceback
 
  logger = logging.getLogger("model_rpc")
  vllm_default_logger.setLevel(logging.WARN)
  logging.getLogger("vllm.utils").setLevel(logging.WARN)
+ logging.getLogger("vllm.selector").setLevel(logging.WARN)
 
 
  class ModelRpcServer:
@@ -68,20 +70,13 @@ class ModelRpcServer:
  )
 
  # For model end global settings
- server_args_dict = {
- "enable_flashinfer": server_args.enable_flashinfer,
- "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
- }
-
  self.model_runner = ModelRunner(
  model_config=self.model_config,
  mem_fraction_static=server_args.mem_fraction_static,
  tp_rank=tp_rank,
  tp_size=server_args.tp_size,
  nccl_port=port_args.nccl_port,
- load_format=server_args.load_format,
- trust_remote_code=server_args.trust_remote_code,
- server_args_dict=server_args_dict,
+ server_args=server_args,
  )
  if is_multimodal_model(server_args.model_path):
  self.processor = get_processor(
@@ -96,24 +91,27 @@ class ModelRpcServer:
  tokenizer_mode=server_args.tokenizer_mode,
  trust_remote_code=server_args.trust_remote_code,
  )
- self.max_total_num_token = self.model_runner.max_total_num_token
- self.max_num_running_seq = self.max_total_num_token // 2
- self.max_prefill_num_token = max(
+ self.max_total_num_tokens = self.model_runner.max_total_num_tokens
+ self.max_prefill_tokens = max(
  self.model_config.context_len,
  (
- self.max_total_num_token // 6
- if server_args.max_prefill_num_token is None
- else server_args.max_prefill_num_token
+ self.max_total_num_tokens // 6
+ if server_args.max_prefill_tokens is None
+ else server_args.max_prefill_tokens
  ),
  )
+ self.max_running_requests = (self.max_total_num_tokens // 2
+ if server_args.max_running_requests is None else server_args.max_running_requests)
+
  self.int_token_logit_bias = torch.tensor(
  get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
  )
  set_random_seed(server_args.random_seed)
- logger.info(
- f"Rank {self.tp_rank}: "
- f"max_total_num_token={self.max_total_num_token}, "
- f"max_prefill_num_token={self.max_prefill_num_token}, "
+
+ # Print info
+ logger.info(f"[rank={self.tp_rank}] "
+ f"max_total_num_tokens={self.max_total_num_tokens}, "
+ f"max_prefill_tokens={self.max_prefill_tokens}, "
  f"context_len={self.model_config.context_len}, "
  )
  if self.tp_rank == 0:
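The hunk above renames the capacity limits (max_total_num_token → max_total_num_tokens, max_prefill_num_token → max_prefill_tokens) and replaces the hard-coded cap of max_total_num_token // 2 running sequences with a configurable max_running_requests taken from the server arguments. A minimal sketch of setting these limits when constructing ServerArgs directly; the field names come from this diff, while the model path and numbers are illustrative assumptions, not sglang defaults:

    # Sketch only: field names are taken from this diff; the values and the
    # model path are illustrative assumptions.
    from sglang.srt.server_args import ServerArgs

    server_args = ServerArgs(
        model_path="meta-llama/Llama-2-7b-chat-hf",  # hypothetical example model
        max_prefill_tokens=8192,        # replaces max_prefill_num_token
        max_running_requests=64,        # new cap on concurrently running requests
        schedule_conservativeness=1.0,
    )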
@@ -128,9 +126,9 @@ class ModelRpcServer:
  self.tree_cache_metrics = {"total": 0, "hit": 0}
  self.scheduler = Scheduler(
  self.schedule_heuristic,
- self.max_num_running_seq,
- self.max_prefill_num_token,
- self.max_total_num_token,
+ self.max_running_requests,
+ self.max_prefill_tokens,
+ self.max_total_num_tokens,
  self.tree_cache,
  )
  self.req_to_token_pool = self.model_runner.req_to_token_pool
@@ -156,27 +154,20 @@ class ModelRpcServer:
  self.jump_forward_cache = JumpForwardCache()
 
  # Init new token estimation
- self.new_token_ratio = min(0.4 * server_args.schedule_conservativeness, 1.0)
- self.min_new_token_ratio = min(0.2 * server_args.schedule_conservativeness, 1.0)
- self.new_token_ratio_step = (0.0001, 0.05) # (down, up)
-
- def flush_cache(self):
- if len(self.forward_queue) == 0 and (
- self.running_batch is None or len(self.running_batch.reqs) == 0
- ):
- self.tree_cache.reset()
- self.tree_cache_metrics = {"total": 0, "hit": 0}
- self.regex_fsm_cache.reset()
- self.req_to_token_pool.clear()
- self.token_to_kv_pool.clear()
- torch.cuda.empty_cache()
- logger.info("Cache flushed successfully!")
- else:
- warnings.warn(
- f"Cache not flushed because there are pending requests. "
- f"#queue-req: {len(self.forward_queue)}, "
- f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
- )
+ assert (
+ server_args.schedule_conservativeness >= 0
+ ), "Invalid schedule_conservativeness"
+ self.new_token_ratio = min(
+ global_config.base_new_token_ratio * server_args.schedule_conservativeness,
+ 1.0,
+ )
+ self.min_new_token_ratio = min(
+ global_config.base_min_new_token_ratio
+ * server_args.schedule_conservativeness,
+ 1.0,
+ )
+ self.new_token_ratio_decay = global_config.new_token_ratio_decay
+ self.new_token_ratio_recovery = global_config.new_token_ratio_recovery
 
  def exposed_step(self, recv_reqs):
  if self.tp_size != 1:
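The constructor now validates schedule_conservativeness and reads its base ratios from sglang.global_config instead of hard-coding 0.4, 0.2, and the (0.0001, 0.05) down/up step. A minimal sketch of what the corresponding global_config entries might look like; the attribute names come from this hunk, and the values simply mirror the constants it removes rather than anything read from the 0.1.17 wheel itself:

    # Assumed shape of the new global_config knobs (names from this hunk,
    # values copied from the constants deleted above).
    class GlobalConfig:
        base_new_token_ratio = 0.4        # initial per-request new-token estimate
        base_min_new_token_ratio = 0.2    # floor the estimate may decay toward
        new_token_ratio_decay = 0.0001    # step down after a successful decode pass
        new_token_ratio_recovery = 0.05   # step up after a decode OOM retraction

    global_config = GlobalConfig()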
@@ -189,6 +180,8 @@ class ModelRpcServer:
  self.handle_generate_request(recv_req)
  elif isinstance(recv_req, FlushCacheReq):
  self.flush_cache()
+ elif isinstance(recv_req, AbortReq):
+ self.abort_request(recv_req)
  else:
  raise ValueError(f"Invalid request: {recv_req}")
 
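exposed_step now also dispatches AbortReq, the new message imported from sglang.srt.managers.io_struct in the first hunk. A rough sketch of the shape such a request could take and of the dispatch it triggers; the dataclass below is inferred from the single recv_req.rid access in abort_request later in this diff, not copied from io_struct.py:

    from dataclasses import dataclass

    @dataclass
    class AbortReq:
        # abort_request() in this file only reads the request id.
        rid: str

    def dispatch(server, recv_reqs):
        # Simplified mirror of the branch added to exposed_step above.
        for recv_req in recv_reqs:
            if isinstance(recv_req, AbortReq):
                server.abort_request(recv_req)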
@@ -207,9 +200,8 @@ class ModelRpcServer:
  new_batch = self.get_new_fill_batch()
 
  if new_batch is not None:
- # Run new fill batch
+ # Run a new fill batch
  self.forward_fill_batch(new_batch)
-
  self.cache_filled_batch(new_batch)
 
  if not new_batch.is_empty():
@@ -225,39 +217,42 @@ class ModelRpcServer:
  self.num_generated_tokens += len(self.running_batch.reqs)
  self.forward_decode_batch(self.running_batch)
 
- if self.running_batch.is_empty():
- self.running_batch = None
- break
-
- if self.out_pyobjs and self.running_batch.reqs[0].stream:
- break
-
- if self.running_batch is not None and self.tp_rank == 0:
+ # Print stats
+ if self.tp_rank == 0:
  if self.decode_forward_ct % 40 == 0:
- num_used = self.max_total_num_token - (
+ num_used = self.max_total_num_tokens - (
  self.token_to_kv_pool.available_size()
  + self.tree_cache.evictable_size()
  )
- throuhgput = self.num_generated_tokens / (time.time() - self.last_stats_tic)
+ throuhgput = self.num_generated_tokens / (
+ time.time() - self.last_stats_tic
+ )
  self.num_generated_tokens = 0
  self.last_stats_tic = time.time()
  logger.info(
  f"#running-req: {len(self.running_batch.reqs)}, "
  f"#token: {num_used}, "
- f"token usage: {num_used / self.max_total_num_token:.2f}, "
+ f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
  f"gen throughput (token/s): {throuhgput:.2f}, "
  f"#queue-req: {len(self.forward_queue)}"
  )
+
+ if self.running_batch.is_empty():
+ self.running_batch = None
+ break
+
+ if self.out_pyobjs and self.running_batch.reqs[0].stream:
+ break
  else:
- # check the available size
+ # Check the available size
  available_size = (
  self.token_to_kv_pool.available_size()
  + self.tree_cache.evictable_size()
  )
- if available_size != self.max_total_num_token:
+ if available_size != self.max_total_num_tokens:
  warnings.warn(
  "Warning: "
- f"available_size={available_size}, max_total_num_token={self.max_total_num_token}\n"
+ f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
  "KV cache pool leak detected!"
  )
 
@@ -275,8 +270,13 @@ class ModelRpcServer:
  (recv_req.image_hash >> 64) % self.model_config.vocab_size,
  ]
  req.image_size = recv_req.image_size
- req.input_ids, req.image_offset = self.model_runner.model.pad_input_ids(
- req.input_ids, req.pad_value, req.pixel_values.shape, req.image_size
+ req.origin_input_ids, req.image_offset = (
+ self.model_runner.model.pad_input_ids(
+ req.origin_input_ids_unpadded,
+ req.pad_value,
+ req.pixel_values.shape,
+ req.image_size,
+ )
  )
  req.sampling_params = recv_req.sampling_params
  req.return_logprob = recv_req.return_logprob
@@ -293,23 +293,28 @@ class ModelRpcServer:
  req.sampling_params.regex
  )
 
- # Truncate long prompts
- req.input_ids = req.input_ids[: self.model_config.context_len - 1]
+ # Truncate prompts that are too long
+ req.origin_input_ids = req.origin_input_ids[: self.model_config.context_len - 1]
  req.sampling_params.max_new_tokens = min(
  req.sampling_params.max_new_tokens,
- self.model_config.context_len - 1 - len(req.input_ids),
- self.max_total_num_token - 128 - len(req.input_ids),
+ self.model_config.context_len - 1 - len(req.origin_input_ids),
+ self.max_total_num_tokens - 128 - len(req.origin_input_ids),
  )
  self.forward_queue.append(req)
 
  def get_new_fill_batch(self):
  if (
  self.running_batch is not None
- and len(self.running_batch.reqs) > self.max_num_running_seq
+ and len(self.running_batch.reqs) > self.max_running_requests
  ):
  return None
 
+ # Compute matched prefix length
  for req in self.forward_queue:
+ assert (
+ len(req.output_ids) == 0
+ ), "The output ids should be empty when prefilling"
+ req.input_ids = req.origin_input_ids + req.prev_output_ids
  prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
  if req.return_logprob:
  prefix_indices = prefix_indices[: req.logprob_start_len]
@@ -337,7 +342,7 @@ class ModelRpcServer:
  )
 
  for req in self.forward_queue:
- if req.return_logprob:
+ if req.return_logprob and req.normalized_prompt_logprob is None:
  # Need at least two tokens to compute normalized logprob
  if req.extend_input_len < 2:
  delta = 2 - req.extend_input_len
@@ -356,7 +361,7 @@ class ModelRpcServer:
  req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
  < available_size
  and req.extend_input_len + new_batch_input_tokens
- < self.max_prefill_num_token
+ < self.max_prefill_tokens
  ):
  delta = self.tree_cache.inc_lock_ref(req.last_node)
  available_size += delta
@@ -381,6 +386,7 @@ class ModelRpcServer:
  if len(can_run_list) == 0:
  return None
 
+ # Print stats
  if self.tp_rank == 0:
  running_req = (
  0 if self.running_batch is None else len(self.running_batch.reqs)
@@ -401,13 +407,14 @@ class ModelRpcServer:
  f"#running_req: {running_req}. "
  f"tree_cache_hit_rate: {100.0 * tree_cache_hit_rate:.2f}%."
  )
- #logger.debug(
+ # logger.debug(
  # f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
  # f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
  # f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
  # f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
- #)
+ # )
 
+ # Return the new batch
  new_batch = Batch.init_new(
  can_run_list,
  self.req_to_token_pool,
@@ -440,11 +447,10 @@ class ModelRpcServer:
 
  # Only transfer the selected logprobs of the next token to CPU to reduce overhead.
  if last_logprobs is not None:
- last_token_logprobs = (
- last_logprobs[
- torch.arange(len(batch.reqs), device=next_token_ids.device),
- next_token_ids].tolist()
- )
+ last_token_logprobs = last_logprobs[
+ torch.arange(len(batch.reqs), device=next_token_ids.device),
+ next_token_ids,
+ ].tolist()
 
  next_token_ids = next_token_ids.tolist()
  else:
@@ -458,35 +464,60 @@ class ModelRpcServer:
  req.check_finished()
 
  if req.return_logprob:
- req.normalized_prompt_logprob = normalized_prompt_logprobs[i]
-
- # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
- req.prefill_token_logprobs = list(
- zip(
- prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
- req.input_ids[-req.extend_input_len + 1 :],
+ if req.normalized_prompt_logprob is None:
+ req.normalized_prompt_logprob = normalized_prompt_logprobs[i]
+
+ if req.prefill_token_logprobs is None:
+ # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
+ req.prefill_token_logprobs = list(
+ zip(
+ prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
+ req.input_ids[-req.extend_input_len + 1 :],
+ )
  )
- )
- if req.logprob_start_len == 0:
- req.prefill_token_logprobs = [
- (None, req.input_ids[0])
- ] + req.prefill_token_logprobs
- req.decode_token_logprobs = [
+ if req.logprob_start_len == 0:
+ req.prefill_token_logprobs = [
+ (None, req.input_ids[0])
+ ] + req.prefill_token_logprobs
+
+ if req.last_update_decode_tokens != 0:
+ req.decode_token_logprobs.extend(
+ list(
+ zip(
+ prefill_token_logprobs[
+ pt
+ + req.extend_input_len
+ - req.last_update_decode_tokens : pt
+ + req.extend_input_len
+ - 1
+ ],
+ req.input_ids[-req.last_update_decode_tokens + 1 :],
+ )
+ )
+ )
+
+ req.decode_token_logprobs.append(
  (last_token_logprobs[i], next_token_ids[i])
- ]
+ )
 
  if req.top_logprobs_num > 0:
- req.prefill_top_logprobs = prefill_top_logprobs[i]
- if req.logprob_start_len == 0:
- req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
- req.decode_top_logprobs = [decode_top_logprobs[i]]
+ if req.prefill_top_logprobs is None:
+ req.prefill_top_logprobs = prefill_top_logprobs[i]
+ if req.logprob_start_len == 0:
+ req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
+
+ if req.last_update_decode_tokens != 0:
+ req.decode_top_logprobs.extend(
+ prefill_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
+ )
+ req.decode_top_logprobs.append(decode_top_logprobs[i])
 
  pt += req.extend_input_len
 
  self.handle_finished_requests(batch)
 
  def cache_filled_batch(self, batch: Batch):
- req_pool_indices_cpu = batch.req_pool_indices.cpu().tolist()
+ req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
  for i, req in enumerate(batch.reqs):
  new_prefix_indices, new_last_node = self.tree_cache.cache_req(
  token_ids=tuple(req.input_ids + req.output_ids)[:-1],
@@ -501,7 +532,7 @@ class ModelRpcServer:
  # check if decode out of memory
  if not batch.check_decode_mem():
  old_ratio = self.new_token_ratio
- self.new_token_ratio = min(old_ratio + self.new_token_ratio_step[1], 1.0)
+ self.new_token_ratio = min(old_ratio + self.new_token_ratio_recovery, 1.0)
 
  retracted_reqs = batch.retract_decode()
  logger.info(
@@ -512,26 +543,13 @@ class ModelRpcServer:
  self.forward_queue.extend(retracted_reqs)
  else:
  self.new_token_ratio = max(
- self.new_token_ratio - self.new_token_ratio_step[0],
+ self.new_token_ratio - self.new_token_ratio_decay,
  self.min_new_token_ratio,
  )
 
  if not self.disable_regex_jump_forward:
  # check for jump-forward
- jump_forward_reqs = batch.check_for_jump_forward()
-
- # check for image jump-forward
- for req in jump_forward_reqs:
- if req.pixel_values is not None:
- (
- req.input_ids,
- req.image_offset,
- ) = self.model_runner.model.pad_input_ids(
- req.input_ids,
- req.pad_value,
- req.pixel_values.shape,
- req.image_size,
- )
+ jump_forward_reqs = batch.check_for_jump_forward(self.model_runner)
 
  self.forward_queue.extend(jump_forward_reqs)
  if batch.is_empty():
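Taken together with the constructor hunk, these two hunks define how the scheduler's new-token estimate adapts at runtime: whenever a decode step runs out of KV-cache room the batch retracts requests and the ratio is bumped by new_token_ratio_recovery, while every successful step decays it by new_token_ratio_decay down to min_new_token_ratio. A compact restatement of that update rule (a sketch, not code shipped in the package):

    def update_new_token_ratio(ratio, retracted, min_ratio, decay, recovery):
        # Mirrors the update performed around batch.check_decode_mem() above.
        if retracted:
            # Decode ran out of memory: estimate more tokens per request again.
            return min(ratio + recovery, 1.0)
        # Decoding succeeded: relax the estimate toward its floor.
        return max(ratio - decay, min_ratio)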
@@ -574,8 +592,8 @@ class ModelRpcServer:
 
  def handle_finished_requests(self, batch: Batch):
  output_rids = []
+ prev_output_strs = []
  output_tokens = []
- output_and_jump_forward_strs = []
  output_hit_stop_str = []
  output_skip_special_tokens = []
  output_spaces_between_special_tokens = []
@@ -599,8 +617,8 @@ class ModelRpcServer:
  )
  ):
  output_rids.append(req.rid)
+ prev_output_strs.append(req.prev_output_str)
  output_tokens.append(req.output_ids)
- output_and_jump_forward_strs.append(req.output_and_jump_forward_str)
  output_hit_stop_str.append(req.hit_stop_str)
  output_skip_special_tokens.append(
  req.sampling_params.skip_special_tokens
@@ -610,10 +628,8 @@ class ModelRpcServer:
  )
 
  meta_info = {
- "prompt_tokens": req.prompt_tokens,
- "completion_tokens": len(req.input_ids)
- + len(req.output_ids)
- - req.prompt_tokens,
+ "prompt_tokens": len(req.origin_input_ids),
+ "completion_tokens": len(req.prev_output_ids) + len(req.output_ids),
  "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
  "finish_reason": FinishReason.to_str(req.finish_reason),
  "hit_stop_str": req.hit_stop_str,
@@ -640,8 +656,8 @@ class ModelRpcServer:
  self.out_pyobjs.append(
  BatchTokenIDOut(
  output_rids,
+ prev_output_strs,
  output_tokens,
- output_and_jump_forward_strs,
  output_hit_stop_str,
  output_skip_special_tokens,
  output_spaces_between_special_tokens,
@@ -670,6 +686,43 @@ class ModelRpcServer:
  else:
  batch.reqs = []
 
+ def flush_cache(self):
+ if len(self.forward_queue) == 0 and (
+ self.running_batch is None or len(self.running_batch.reqs) == 0
+ ):
+ self.tree_cache.reset()
+ self.tree_cache_metrics = {"total": 0, "hit": 0}
+ self.regex_fsm_cache.reset()
+ self.req_to_token_pool.clear()
+ self.token_to_kv_pool.clear()
+ torch.cuda.empty_cache()
+ logger.info("Cache flushed successfully!")
+ else:
+ warnings.warn(
+ f"Cache not flushed because there are pending requests. "
+ f"#queue-req: {len(self.forward_queue)}, "
+ f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
+ )
+
+ def abort_request(self, recv_req):
+ # Delete requests in the waiting queue
+ to_del = None
+ for i, req in enumerate(self.forward_queue):
+ if req.rid == recv_req.rid:
+ to_del = i
+ break
+
+ if to_del is not None:
+ del self.forward_queue[to_del]
+
+ # Delete requests in the running batch
+ if self.running_batch:
+ for req in self.running_batch.reqs:
+ if req.rid == recv_req.rid:
+ req.finished = True
+ req.finish_reason = FinishReason.ABORT
+ break
+
 
  class ModelRpcService(rpyc.Service):
  exposed_ModelRpcServer = ModelRpcServer
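flush_cache keeps its behavior but moves next to the newly added abort_request; both are reached through exposed_step via FlushCacheReq and AbortReq. Assuming the HTTP server forwards a /flush_cache route to this handler (suggested by the sglang/srt/flush_cache.py entry in the file list, though the route itself is not shown in this diff), clearing an idle server's caches could look like:

    import requests

    # Assumed endpoint name and default port; neither appears in this diff.
    resp = requests.get("http://127.0.0.1:30000/flush_cache")
    print(resp.status_code)  # the cache is only flushed when nothing is queued or running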
@@ -731,7 +784,7 @@ def _init_service(port):
  protocol_config={
  "allow_public_attrs": True,
  "allow_pickle": True,
- "sync_request_timeout": 1800,
+ "sync_request_timeout": 3600,
  },
  )
  t.start()
@@ -751,7 +804,7 @@ def start_model_process(port):
  config={
  "allow_public_attrs": True,
  "allow_pickle": True,
- "sync_request_timeout": 1800,
+ "sync_request_timeout": 3600,
  },
  )
  break