sglang 0.1.16__py3-none-any.whl → 0.1.18__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (68)
  1. sglang/__init__.py +3 -1
  2. sglang/api.py +7 -7
  3. sglang/backend/anthropic.py +1 -1
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +158 -11
  6. sglang/backend/runtime_endpoint.py +18 -10
  7. sglang/bench_latency.py +299 -0
  8. sglang/global_config.py +12 -2
  9. sglang/lang/compiler.py +2 -2
  10. sglang/lang/interpreter.py +114 -67
  11. sglang/lang/ir.py +28 -3
  12. sglang/launch_server.py +4 -1
  13. sglang/launch_server_llavavid.py +2 -1
  14. sglang/srt/constrained/__init__.py +13 -6
  15. sglang/srt/constrained/fsm_cache.py +8 -2
  16. sglang/srt/constrained/jump_forward.py +113 -25
  17. sglang/srt/conversation.py +2 -0
  18. sglang/srt/flush_cache.py +3 -1
  19. sglang/srt/hf_transformers_utils.py +130 -1
  20. sglang/srt/layers/extend_attention.py +17 -0
  21. sglang/srt/layers/fused_moe.py +582 -0
  22. sglang/srt/layers/logits_processor.py +65 -32
  23. sglang/srt/layers/radix_attention.py +41 -7
  24. sglang/srt/layers/token_attention.py +16 -1
  25. sglang/srt/managers/controller/dp_worker.py +113 -0
  26. sglang/srt/managers/{router → controller}/infer_batch.py +242 -100
  27. sglang/srt/managers/controller/manager_multi.py +191 -0
  28. sglang/srt/managers/{router/manager.py → controller/manager_single.py} +34 -14
  29. sglang/srt/managers/{router → controller}/model_runner.py +262 -158
  30. sglang/srt/managers/{router → controller}/radix_cache.py +11 -1
  31. sglang/srt/managers/{router/scheduler.py → controller/schedule_heuristic.py} +9 -7
  32. sglang/srt/managers/{router/model_rpc.py → controller/tp_worker.py} +298 -267
  33. sglang/srt/managers/detokenizer_manager.py +42 -46
  34. sglang/srt/managers/io_struct.py +22 -12
  35. sglang/srt/managers/tokenizer_manager.py +151 -87
  36. sglang/srt/model_config.py +83 -5
  37. sglang/srt/models/chatglm.py +399 -0
  38. sglang/srt/models/commandr.py +10 -13
  39. sglang/srt/models/dbrx.py +9 -15
  40. sglang/srt/models/gemma.py +12 -15
  41. sglang/srt/models/grok.py +738 -0
  42. sglang/srt/models/llama2.py +26 -15
  43. sglang/srt/models/llama_classification.py +104 -0
  44. sglang/srt/models/llava.py +86 -19
  45. sglang/srt/models/llavavid.py +11 -20
  46. sglang/srt/models/mixtral.py +282 -103
  47. sglang/srt/models/mixtral_quant.py +372 -0
  48. sglang/srt/models/qwen.py +9 -13
  49. sglang/srt/models/qwen2.py +11 -13
  50. sglang/srt/models/stablelm.py +9 -15
  51. sglang/srt/models/yivl.py +17 -22
  52. sglang/srt/openai_api_adapter.py +150 -95
  53. sglang/srt/openai_protocol.py +11 -2
  54. sglang/srt/server.py +124 -48
  55. sglang/srt/server_args.py +128 -48
  56. sglang/srt/utils.py +234 -67
  57. sglang/test/test_programs.py +65 -3
  58. sglang/test/test_utils.py +32 -1
  59. sglang/utils.py +23 -4
  60. {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/METADATA +40 -27
  61. sglang-0.1.18.dist-info/RECORD +78 -0
  62. {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/WHEEL +1 -1
  63. sglang/srt/backend_config.py +0 -13
  64. sglang/srt/models/dbrx_config.py +0 -281
  65. sglang/srt/weight_utils.py +0 -417
  66. sglang-0.1.16.dist-info/RECORD +0 -72
  67. {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/LICENSE +0 -0
  68. {sglang-0.1.16.dist-info → sglang-0.1.18.dist-info}/top_level.txt +0 -0
sglang/srt/managers/{router/model_rpc.py → controller/tp_worker.py}
@@ -1,61 +1,68 @@
+ """A tensor parallel worker."""
+
  import asyncio
  import logging
- import multiprocessing
  import time
  import warnings
  from concurrent.futures import ThreadPoolExecutor
- from typing import Any, Dict, List, Optional, Tuple, Union
+ from typing import List, Optional

  import rpyc
  import torch
  from rpyc.utils.classic import obtain
- from rpyc.utils.server import ThreadedServer
-
- try:
-     from vllm.logger import _default_handler as vllm_default_logger
- except ImportError:
-     from vllm.logger import logger as vllm_default_logger

+ from sglang.global_config import global_config
  from sglang.srt.constrained.fsm_cache import FSMCache
  from sglang.srt.constrained.jump_forward import JumpForwardCache
  from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
+ from sglang.srt.managers.controller.infer_batch import (
+     FINISH_ABORT,
+     BaseFinishReason,
+     Batch,
+     ForwardMode,
+     Req,
+ )
+ from sglang.srt.managers.controller.model_runner import ModelRunner
+ from sglang.srt.managers.controller.radix_cache import RadixCache
+ from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic
  from sglang.srt.managers.io_struct import (
+     AbortReq,
      BatchTokenIDOut,
      FlushCacheReq,
      TokenizedGenerateReqInput,
  )
- from sglang.srt.managers.router.infer_batch import Batch, ForwardMode, Req, FinishReason
- from sglang.srt.managers.router.model_runner import ModelRunner
- from sglang.srt.managers.router.radix_cache import RadixCache
- from sglang.srt.managers.router.scheduler import Scheduler
  from sglang.srt.model_config import ModelConfig
- from sglang.srt.server_args import PortArgs, ServerArgs
+ from sglang.srt.server_args import ModelPortArgs, ServerArgs
  from sglang.srt.utils import (
-     get_exception_traceback,
      get_int_token_logit_bias,
      is_multimodal_model,
      set_random_seed,
+     start_rpyc_service_process,
+     connect_rpyc_service,
+     suppress_other_loggers,
  )
+ from sglang.utils import get_exception_traceback

-
- logger = logging.getLogger("model_rpc")
- vllm_default_logger.setLevel(logging.WARN)
- logging.getLogger("vllm.utils").setLevel(logging.WARN)
+ logger = logging.getLogger("srt.tp_worker")


- class ModelRpcServer:
+ class ModelTpServer:
      def __init__(
          self,
+         gpu_id: int,
          tp_rank: int,
          server_args: ServerArgs,
-         port_args: PortArgs,
-         model_overide_args: Optional[dict] = None,
+         model_port_args: ModelPortArgs,
+         model_overide_args,
      ):
-         server_args, port_args = [obtain(x) for x in [server_args, port_args]]
+         server_args, model_port_args = obtain(server_args), obtain(model_port_args)
+         suppress_other_loggers()

          # Copy arguments
+         self.gpu_id = gpu_id
          self.tp_rank = tp_rank
          self.tp_size = server_args.tp_size
+         self.dp_size = server_args.dp_size
          self.schedule_heuristic = server_args.schedule_heuristic
          self.disable_regex_jump_forward = server_args.disable_regex_jump_forward

@@ -66,23 +73,16 @@ class ModelRpcServer:
              context_length=server_args.context_length,
              model_overide_args=model_overide_args,
          )
-
-         # For model end global settings
-         server_args_dict = {
-             "enable_flashinfer": server_args.enable_flashinfer,
-             "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
-         }
-
          self.model_runner = ModelRunner(
              model_config=self.model_config,
              mem_fraction_static=server_args.mem_fraction_static,
+             gpu_id=gpu_id,
              tp_rank=tp_rank,
              tp_size=server_args.tp_size,
-             nccl_port=port_args.nccl_port,
-             load_format=server_args.load_format,
-             trust_remote_code=server_args.trust_remote_code,
-             server_args_dict=server_args_dict,
+             nccl_port=model_port_args.nccl_port,
+             server_args=server_args,
          )
+
          if is_multimodal_model(server_args.model_path):
              self.processor = get_processor(
                  server_args.tokenizer_path,
@@ -96,28 +96,34 @@ class ModelRpcServer:
                  tokenizer_mode=server_args.tokenizer_mode,
                  trust_remote_code=server_args.trust_remote_code,
              )
-         self.max_total_num_token = self.model_runner.max_total_num_token
-         self.max_num_running_seq = self.max_total_num_token // 2
-         self.max_prefill_num_token = max(
-             self.model_config.context_len,
-             (
-                 self.max_total_num_token // 6
-                 if server_args.max_prefill_num_token is None
-                 else server_args.max_prefill_num_token
-             ),
+         self.max_total_num_tokens = self.model_runner.max_total_num_tokens
+         self.max_prefill_tokens = (
+             4096
+             if server_args.max_prefill_tokens is None
+             else server_args.max_prefill_tokens
+         )
+         self.max_running_requests = (
+             self.max_total_num_tokens // 2
+             if server_args.max_running_requests is None
+             else server_args.max_running_requests
          )
          self.int_token_logit_bias = torch.tensor(
              get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
          )
          set_random_seed(server_args.random_seed)
+
+         # Print info
          logger.info(
-             f"Rank {self.tp_rank}: "
-             f"max_total_num_token={self.max_total_num_token}, "
-             f"max_prefill_num_token={self.max_prefill_num_token}, "
-             f"context_len={self.model_config.context_len}, "
+             f"[gpu_id={self.gpu_id}] "
+             f"max_total_num_tokens={self.max_total_num_tokens}, "
+             f"max_prefill_tokens={self.max_prefill_tokens}, "
+             f"context_len={self.model_config.context_len}"
          )
          if self.tp_rank == 0:
-             logger.info(f"server_args: {server_args.print_mode_args()}")
+             logger.info(
+                 f"[gpu_id={self.gpu_id}] "
+                 f"server_args: {server_args.print_mode_args()}"
+             )

          # Init cache
          self.tree_cache = RadixCache(
@@ -126,11 +132,11 @@ class ModelRpcServer:
              disable=server_args.disable_radix_cache,
          )
          self.tree_cache_metrics = {"total": 0, "hit": 0}
-         self.scheduler = Scheduler(
+         self.scheduler = ScheduleHeuristic(
              self.schedule_heuristic,
-             self.max_num_running_seq,
-             self.max_prefill_num_token,
-             self.max_total_num_token,
+             self.max_running_requests,
+             self.max_prefill_tokens,
+             self.max_total_num_tokens,
              self.tree_cache,
          )
          self.req_to_token_pool = self.model_runner.req_to_token_pool
@@ -156,30 +162,23 @@ class ModelRpcServer:
          self.jump_forward_cache = JumpForwardCache()

          # Init new token estimation
-         self.new_token_ratio = min(0.4 * server_args.schedule_conservativeness, 1.0)
-         self.min_new_token_ratio = min(0.2 * server_args.schedule_conservativeness, 1.0)
-         self.new_token_ratio_step = (0.0001, 0.05)  # (down, up)
-
-     def flush_cache(self):
-         if len(self.forward_queue) == 0 and (
-             self.running_batch is None or len(self.running_batch.reqs) == 0
-         ):
-             self.tree_cache.reset()
-             self.tree_cache_metrics = {"total": 0, "hit": 0}
-             self.regex_fsm_cache.reset()
-             self.req_to_token_pool.clear()
-             self.token_to_kv_pool.clear()
-             torch.cuda.empty_cache()
-             logger.info("Cache flushed successfully!")
-         else:
-             warnings.warn(
-                 f"Cache not flushed because there are pending requests. "
-                 f"#queue-req: {len(self.forward_queue)}, "
-                 f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
-             )
+         assert (
+             server_args.schedule_conservativeness >= 0
+         ), "Invalid schedule_conservativeness"
+         self.new_token_ratio = min(
+             global_config.base_new_token_ratio * server_args.schedule_conservativeness,
+             1.0,
+         )
+         self.min_new_token_ratio = min(
+             global_config.base_min_new_token_ratio
+             * server_args.schedule_conservativeness,
+             1.0,
+         )
+         self.new_token_ratio_decay = global_config.new_token_ratio_decay
+         self.new_token_ratio_recovery = global_config.new_token_ratio_recovery

      def exposed_step(self, recv_reqs):
-         if self.tp_size != 1:
+         if self.tp_size * self.dp_size != 1:
              recv_reqs = obtain(recv_reqs)

          try:
@@ -189,13 +188,16 @@ class ModelRpcServer:
                      self.handle_generate_request(recv_req)
                  elif isinstance(recv_req, FlushCacheReq):
                      self.flush_cache()
+                 elif isinstance(recv_req, AbortReq):
+                     self.abort_request(recv_req)
                  else:
                      raise ValueError(f"Invalid request: {recv_req}")

              # Forward
              self.forward_step()
          except Exception:
-             logger.error("Exception in ModelRpcClient:\n" + get_exception_traceback())
+             logger.error("Exception in ModelTpServer:\n" + get_exception_traceback())
+             raise

          # Return results
          ret = self.out_pyobjs
@@ -207,9 +209,8 @@ class ModelRpcServer:
          new_batch = self.get_new_fill_batch()

          if new_batch is not None:
-             # Run new fill batch
+             # Run a new fill batch
              self.forward_fill_batch(new_batch)
-
              self.cache_filled_batch(new_batch)

              if not new_batch.is_empty():
@@ -225,39 +226,43 @@ class ModelRpcServer:
                      self.num_generated_tokens += len(self.running_batch.reqs)
                      self.forward_decode_batch(self.running_batch)

-                     if self.running_batch.is_empty():
-                         self.running_batch = None
-                         break
-
-                     if self.out_pyobjs and self.running_batch.reqs[0].stream:
-                         break
-
-                     if self.running_batch is not None and self.tp_rank == 0:
+                     # Print stats
+                     if self.tp_rank == 0:
                          if self.decode_forward_ct % 40 == 0:
-                             num_used = self.max_total_num_token - (
+                             num_used = self.max_total_num_tokens - (
                                  self.token_to_kv_pool.available_size()
                                  + self.tree_cache.evictable_size()
                              )
-                             throuhgput = self.num_generated_tokens / (time.time() - self.last_stats_tic)
+                             throughput = self.num_generated_tokens / (
+                                 time.time() - self.last_stats_tic
+                             )
                              self.num_generated_tokens = 0
                              self.last_stats_tic = time.time()
                              logger.info(
+                                 f"[gpu_id={self.gpu_id}] Decode batch. "
                                  f"#running-req: {len(self.running_batch.reqs)}, "
                                  f"#token: {num_used}, "
-                                 f"token usage: {num_used / self.max_total_num_token:.2f}, "
-                                 f"gen throughput (token/s): {throuhgput:.2f}, "
+                                 f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+                                 f"gen throughput (token/s): {throughput:.2f}, "
                                  f"#queue-req: {len(self.forward_queue)}"
                              )
+
+                     if self.running_batch.is_empty():
+                         self.running_batch = None
+                         break
+
+                     if self.out_pyobjs and self.running_batch.has_stream():
+                         break
          else:
-             # check the available size
+             # Check the available size
              available_size = (
                  self.token_to_kv_pool.available_size()
                  + self.tree_cache.evictable_size()
              )
-             if available_size != self.max_total_num_token:
+             if available_size != self.max_total_num_tokens:
                  warnings.warn(
                      "Warning: "
-                     f"available_size={available_size}, max_total_num_token={self.max_total_num_token}\n"
+                     f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
                      "KV cache pool leak detected!"
                  )

@@ -275,8 +280,14 @@ class ModelRpcServer:
                  (recv_req.image_hash >> 64) % self.model_config.vocab_size,
              ]
              req.image_size = recv_req.image_size
-             req.input_ids, req.image_offset = self.model_runner.model.pad_input_ids(
-                 req.input_ids, req.pad_value, req.pixel_values.shape, req.image_size
+             (
+                 req.origin_input_ids,
+                 req.image_offset,
+             ) = self.model_runner.model.pad_input_ids(
+                 req.origin_input_ids_unpadded,
+                 req.pad_value,
+                 req.pixel_values.shape,
+                 req.image_size,
              )
          req.sampling_params = recv_req.sampling_params
          req.return_logprob = recv_req.return_logprob
@@ -293,23 +304,25 @@ class ModelRpcServer:
                  req.sampling_params.regex
              )

-         # Truncate long prompts
-         req.input_ids = req.input_ids[: self.model_config.context_len - 1]
+         # Truncate prompts that are too long
+         req.origin_input_ids = req.origin_input_ids[: self.model_config.context_len - 1]
          req.sampling_params.max_new_tokens = min(
              req.sampling_params.max_new_tokens,
-             self.model_config.context_len - 1 - len(req.input_ids),
-             self.max_total_num_token - 128 - len(req.input_ids),
+             self.model_config.context_len - 1 - len(req.origin_input_ids),
+             self.max_total_num_tokens - 128 - len(req.origin_input_ids),
          )
          self.forward_queue.append(req)

-     def get_new_fill_batch(self):
+     def get_new_fill_batch(self) -> Optional[Batch]:
          if (
              self.running_batch is not None
-             and len(self.running_batch.reqs) > self.max_num_running_seq
+             and len(self.running_batch.reqs) > self.max_running_requests
          ):
              return None

+         # Compute matched prefix length
          for req in self.forward_queue:
+             req.input_ids = req.origin_input_ids + req.output_ids
              prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
              if req.return_logprob:
                  prefix_indices = prefix_indices[: req.logprob_start_len]
@@ -337,7 +350,7 @@ class ModelRpcServer:
          )

          for req in self.forward_queue:
-             if req.return_logprob:
+             if req.return_logprob and req.normalized_prompt_logprob is None:
                  # Need at least two tokens to compute normalized logprob
                  if req.extend_input_len < 2:
                      delta = 2 - req.extend_input_len
@@ -355,8 +368,9 @@ class ModelRpcServer:
              if (
                  req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
                  < available_size
-                 and req.extend_input_len + new_batch_input_tokens
-                 < self.max_prefill_num_token
+                 and (req.extend_input_len + new_batch_input_tokens
+                 <= self.max_prefill_tokens
+                 or len(can_run_list) == 0)
              ):
                  delta = self.tree_cache.inc_lock_ref(req.last_node)
                  available_size += delta
@@ -381,6 +395,7 @@ class ModelRpcServer:
          if len(can_run_list) == 0:
              return None

+         # Print stats
          if self.tp_rank == 0:
              running_req = (
                  0 if self.running_batch is None else len(self.running_batch.reqs)
@@ -394,20 +409,22 @@ class ModelRpcServer:
                  self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
              )
              logger.info(
-                 f"new fill batch. #seq: {len(can_run_list)}. "
-                 f"#cached_token: {hit_tokens}. "
-                 f"#new_token: {new_batch_input_tokens}. "
-                 f"#remaining_req: {len(self.forward_queue) - len(can_run_list)}. "
-                 f"#running_req: {running_req}. "
-                 f"tree_cache_hit_rate: {100.0 * tree_cache_hit_rate:.2f}%."
+                 f"[gpu_id={self.gpu_id}] Prefill batch. "
+                 f"#new-seq: {len(can_run_list)}, "
+                 f"#new-token: {new_batch_input_tokens}, "
+                 f"#cached-token: {hit_tokens}, "
+                 f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
+                 f"#running-req: {running_req}, "
+                 f"#queue-req: {len(self.forward_queue) - len(can_run_list)}"
              )
-             #logger.debug(
+             # logger.debug(
              #     f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
              #     f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
              #     f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
              #     f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
-             #)
+             # )

+         # Return the new batch
          new_batch = Batch.init_new(
              can_run_list,
              self.req_to_token_pool,
@@ -423,73 +440,91 @@ class ModelRpcServer:
              self.model_config.vocab_size, self.int_token_logit_bias
          )

+         # Forward and sample the next tokens
          if batch.extend_num_tokens != 0:
-             # Forward
-             logits, (
-                 prefill_token_logprobs,
-                 normalized_prompt_logprobs,
-                 prefill_top_logprobs,
-                 decode_top_logprobs,
-                 last_logprobs,
-             ) = self.model_runner.forward(batch, ForwardMode.EXTEND)
-             if prefill_token_logprobs is not None:
-                 prefill_token_logprobs = prefill_token_logprobs.tolist()
-                 normalized_prompt_logprobs = normalized_prompt_logprobs.tolist()
-
-             next_token_ids, _ = batch.sample(logits)
-
-             # Only transfer the selected logprobs of the next token to CPU to reduce overhead.
-             if last_logprobs is not None:
-                 last_token_logprobs = (
-                     last_logprobs[
-                         torch.arange(len(batch.reqs), device=next_token_ids.device),
-                         next_token_ids].tolist()
-                 )
+             output = self.model_runner.forward(batch, ForwardMode.EXTEND)
+             next_token_ids, _ = batch.sample(output.next_token_logits)
+
+             # Move logprobs to cpu
+             if output.next_token_logprobs is not None:
+                 output.next_token_logprobs = output.next_token_logprobs[
+                     torch.arange(len(next_token_ids), device=next_token_ids.device),
+                     next_token_ids,
+                 ].tolist()
+                 output.prefill_token_logprobs = output.prefill_token_logprobs.tolist()
+                 output.normalized_prompt_logprobs = output.normalized_prompt_logprobs.tolist()

              next_token_ids = next_token_ids.tolist()
          else:
              next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)

-         # Check finish condition
+         # Check finish conditions
          pt = 0
          for i, req in enumerate(batch.reqs):
              req.completion_tokens_wo_jump_forward += 1
-             req.output_ids = [next_token_ids[i]]
+             req.output_ids.append(next_token_ids[i])
              req.check_finished()

              if req.return_logprob:
-                 req.normalized_prompt_logprob = normalized_prompt_logprobs[i]
+                 self.add_logprob_return_values(i, req, pt, next_token_ids, output)
+                 pt += req.extend_input_len
+
+         self.handle_finished_requests(batch)
+
+     def add_logprob_return_values(self, i, req, pt, next_token_ids, output):
+         if req.normalized_prompt_logprob is None:
+             req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]

-                 # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
-                 req.prefill_token_logprobs = list(
+         if req.prefill_token_logprobs is None:
+             # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
+             req.prefill_token_logprobs = list(
+                 zip(
+                     output.prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
+                     req.input_ids[-req.extend_input_len + 1 :],
+                 )
+             )
+             if req.logprob_start_len == 0:
+                 req.prefill_token_logprobs = [
+                     (None, req.input_ids[0])
+                 ] + req.prefill_token_logprobs
+
+         if req.last_update_decode_tokens != 0:
+             req.decode_token_logprobs.extend(
+                 list(
                      zip(
-                         prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
-                         req.input_ids[-req.extend_input_len + 1 :],
+                         output.prefill_token_logprobs[
+                             pt
+                             + req.extend_input_len
+                             - req.last_update_decode_tokens : pt
+                             + req.extend_input_len
+                             - 1
+                         ],
+                         req.input_ids[-req.last_update_decode_tokens + 1 :],
                      )
                  )
-                 if req.logprob_start_len == 0:
-                     req.prefill_token_logprobs = [
-                         (None, req.input_ids[0])
-                     ] + req.prefill_token_logprobs
-                 req.decode_token_logprobs = [
-                     (last_token_logprobs[i], next_token_ids[i])
-                 ]
+             )
+
+         req.decode_token_logprobs.append(
+             (output.next_token_logprobs[i], next_token_ids[i])
+         )

-                 if req.top_logprobs_num > 0:
-                     req.prefill_top_logprobs = prefill_top_logprobs[i]
+         if req.top_logprobs_num > 0:
+             if req.prefill_top_logprobs is None:
+                 req.prefill_top_logprobs = output.prefill_top_logprobs[i]
                  if req.logprob_start_len == 0:
                      req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
-                     req.decode_top_logprobs = [decode_top_logprobs[i]]

-                 pt += req.extend_input_len
-
-         self.handle_finished_requests(batch)
+             if req.last_update_decode_tokens != 0:
+                 req.decode_top_logprobs.extend(
+                     output.prefill_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
+                 )
+             req.decode_top_logprobs.append(output.decode_top_logprobs[i])

      def cache_filled_batch(self, batch: Batch):
-         req_pool_indices_cpu = batch.req_pool_indices.cpu().tolist()
+         req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
          for i, req in enumerate(batch.reqs):
              new_prefix_indices, new_last_node = self.tree_cache.cache_req(
-                 token_ids=tuple(req.input_ids + req.output_ids)[:-1],
+                 token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
                  last_uncached_pos=len(req.prefix_indices),
                  req_pool_idx=req_pool_indices_cpu[i],
                  del_in_memory_pool=False,
@@ -498,10 +533,10 @@ class ModelRpcServer:
              req.prefix_indices, req.last_node = new_prefix_indices, new_last_node

      def forward_decode_batch(self, batch: Batch):
-         # check if decode out of memory
+         # Check if decode out of memory
          if not batch.check_decode_mem():
              old_ratio = self.new_token_ratio
-             self.new_token_ratio = min(old_ratio + self.new_token_ratio_step[1], 1.0)
+             self.new_token_ratio = min(old_ratio + self.new_token_ratio_recovery, 1.0)

              retracted_reqs = batch.retract_decode()
              logger.info(
@@ -512,27 +547,13 @@ class ModelRpcServer:
              self.forward_queue.extend(retracted_reqs)
          else:
              self.new_token_ratio = max(
-                 self.new_token_ratio - self.new_token_ratio_step[0],
+                 self.new_token_ratio - self.new_token_ratio_decay,
                  self.min_new_token_ratio,
              )

          if not self.disable_regex_jump_forward:
-             # check for jump-forward
-             jump_forward_reqs = batch.check_for_jump_forward()
-
-             # check for image jump-forward
-             for req in jump_forward_reqs:
-                 if req.pixel_values is not None:
-                     (
-                         req.input_ids,
-                         req.image_offset,
-                     ) = self.model_runner.model.pad_input_ids(
-                         req.input_ids,
-                         req.pad_value,
-                         req.pixel_values.shape,
-                         req.image_size,
-                     )
-
+             # Check for jump-forward
+             jump_forward_reqs = batch.check_for_jump_forward(self.model_runner)
              self.forward_queue.extend(jump_forward_reqs)
              if batch.is_empty():
                  return
@@ -541,23 +562,19 @@ class ModelRpcServer:
          self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
          batch.prepare_for_decode()

-         # Forward
-         logits, (
-             _,
-             _,
-             _,
-             decode_top_logprobs,
-             last_logprobs,
-         ) = self.model_runner.forward(batch, ForwardMode.DECODE)
-         next_token_ids, _ = batch.sample(logits)
-         next_token_ids = next_token_ids.tolist()
+         # Forward and sample the next tokens
+         output = self.model_runner.forward(batch, ForwardMode.DECODE)
+         next_token_ids, _ = batch.sample(output.next_token_logits)

-         # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
-         if last_logprobs is not None:
-             new_token_logprobs = last_logprobs[
-                 torch.arange(len(batch.reqs)), next_token_ids
+         # Move logprobs to cpu
+         if output.next_token_logprobs is not None:
+             next_token_logprobs = output.next_token_logprobs[
+                 torch.arange(len(next_token_ids), device=next_token_ids.device),
+                 next_token_ids,
              ].tolist()

+         next_token_ids = next_token_ids.tolist()
+
          # Check finish condition
          for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
              req.completion_tokens_wo_jump_forward += 1
@@ -565,31 +582,30 @@ class ModelRpcServer:
              req.check_finished()

              if req.return_logprob:
-                 req.decode_token_logprobs.append((new_token_logprobs[i], next_token_id))
-
-                 if req.top_logprobs_num > 0:
-                     req.decode_top_logprobs.append(decode_top_logprobs[i])
+                 req.decode_token_logprobs.append((next_token_logprobs[i], next_token_id))
+                 if req.top_logprobs_num > 0:
+                     req.decode_top_logprobs.append(output.decode_top_logprobs[i])

          self.handle_finished_requests(batch)

      def handle_finished_requests(self, batch: Batch):
          output_rids = []
-         output_tokens = []
-         output_and_jump_forward_strs = []
-         output_hit_stop_str = []
+         decoded_texts = []
+         surr_output_ids = []
+         read_output_ids = []
          output_skip_special_tokens = []
          output_spaces_between_special_tokens = []
          output_meta_info = []
-         output_finished = []
+         output_finished_reason: List[BaseFinishReason] = []
          finished_indices = []
          unfinished_indices = []
          for i, req in enumerate(batch.reqs):
-             if req.finished:
+             if req.finished():
                  finished_indices.append(i)
              else:
                  unfinished_indices.append(i)

-             if req.finished or (
+             if req.finished() or (
                  (
                      req.stream
                      and (
@@ -599,9 +615,10 @@ class ModelRpcServer:
                  )
              ):
                  output_rids.append(req.rid)
-                 output_tokens.append(req.output_ids)
-                 output_and_jump_forward_strs.append(req.output_and_jump_forward_str)
-                 output_hit_stop_str.append(req.hit_stop_str)
+                 decoded_texts.append(req.decoded_text)
+                 surr_ids, read_ids, _ = req.init_detokenize_incrementally()
+                 surr_output_ids.append(surr_ids)
+                 read_output_ids.append(read_ids)
                  output_skip_special_tokens.append(
                      req.sampling_params.skip_special_tokens
                  )
@@ -610,13 +627,10 @@ class ModelRpcServer:
                  )

                  meta_info = {
-                     "prompt_tokens": req.prompt_tokens,
-                     "completion_tokens": len(req.input_ids)
-                     + len(req.output_ids)
-                     - req.prompt_tokens,
+                     "prompt_tokens": len(req.origin_input_ids),
+                     "completion_tokens": len(req.output_ids),
                      "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
-                     "finish_reason": FinishReason.to_str(req.finish_reason),
-                     "hit_stop_str": req.hit_stop_str,
+                     "finish_reason": str(req.finished_reason),
                  }
                  if req.return_logprob:
                      (
@@ -633,20 +647,20 @@ class ModelRpcServer:
                          req.normalized_prompt_logprob,
                      )
                  output_meta_info.append(meta_info)
-                 output_finished.append(req.finished)
+                 output_finished_reason.append(req.finished_reason)

          # Send to detokenizer
          if output_rids:
              self.out_pyobjs.append(
                  BatchTokenIDOut(
                      output_rids,
-                     output_tokens,
-                     output_and_jump_forward_strs,
-                     output_hit_stop_str,
+                     decoded_texts,
+                     surr_output_ids,
+                     read_output_ids,
                      output_skip_special_tokens,
                      output_spaces_between_special_tokens,
                      output_meta_info,
-                     output_finished,
+                     output_finished_reason,
                  )
              )

@@ -657,7 +671,7 @@ class ModelRpcServer:
          for i in finished_indices:
              req = batch.reqs[i]
              self.tree_cache.cache_req(
-                 token_ids=tuple(req.input_ids + req.output_ids)[:-1],
+                 token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
                  last_uncached_pos=len(req.prefix_indices),
                  req_pool_idx=req_pool_indices_cpu[i],
              )
@@ -670,21 +684,67 @@ class ModelRpcServer:
          else:
              batch.reqs = []

+     def flush_cache(self):
+         if len(self.forward_queue) == 0 and (
+             self.running_batch is None or len(self.running_batch.reqs) == 0
+         ):
+             self.tree_cache.reset()
+             self.tree_cache_metrics = {"total": 0, "hit": 0}
+             self.regex_fsm_cache.reset()
+             self.req_to_token_pool.clear()
+             self.token_to_kv_pool.clear()
+             torch.cuda.empty_cache()
+             logger.info("Cache flushed successfully!")
+         else:
+             warnings.warn(
+                 f"Cache not flushed because there are pending requests. "
+                 f"#queue-req: {len(self.forward_queue)}, "
+                 f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
+             )
+
+     def abort_request(self, recv_req):
+         # Delete requests in the waiting queue
+         to_del = None
+         for i, req in enumerate(self.forward_queue):
+             if req.rid == recv_req.rid:
+                 to_del = i
+                 break
+
+         if to_del is not None:
+             del self.forward_queue[to_del]
+
+         # Delete requests in the running batch
+         if self.running_batch:
+             for req in self.running_batch.reqs:
+                 if req.rid == recv_req.rid:
+                     req.finished_reason = FINISH_ABORT()
+                     break

- class ModelRpcService(rpyc.Service):
-     exposed_ModelRpcServer = ModelRpcServer

+ class ModelTpService(rpyc.Service):
+     exposed_ModelTpServer = ModelTpServer

- class ModelRpcClient:
+
+ class ModelTpClient:
      def __init__(
-         self, server_args: ServerArgs, port_args: PortArgs, model_overide_args
+         self,
+         gpu_ids: List[int],
+         server_args: ServerArgs,
+         model_port_args: ModelPortArgs,
+         model_overide_args,
      ):
-         tp_size = server_args.tp_size
+         server_args, model_port_args = obtain(server_args), obtain(model_port_args)
+         self.tp_size = server_args.tp_size

-         if tp_size == 1:
+         if self.tp_size * server_args.dp_size == 1:
              # Init model
-             self.model_server = ModelRpcService().exposed_ModelRpcServer(
-                 0, server_args, port_args, model_overide_args
+             assert len(gpu_ids) == 1
+             self.model_server = ModelTpService().exposed_ModelTpServer(
+                 0,
+                 gpu_ids[0],
+                 server_args,
+                 model_port_args,
+                 model_overide_args,
              )

              # Wrap functions
@@ -696,19 +756,31 @@ class ModelRpcClient:

              self.step = async_wrap(self.model_server.exposed_step)
          else:
-             with ThreadPoolExecutor(tp_size) as executor:
+             with ThreadPoolExecutor(self.tp_size) as executor:
                  # Launch model processes
-                 rets = executor.map(start_model_process, port_args.model_rpc_ports)
-                 self.remote_services = [x[0] for x in rets]
-                 self.procs = [x[1] for x in rets]
+                 if server_args.nnodes == 1:
+                     self.procs = list(executor.map(
+                         lambda args: start_rpyc_service_process(*args),
+                         [(ModelTpService, p) for p in model_port_args.model_tp_ports],
+                     ))
+                     addrs = [("localhost", p) for p in model_port_args.model_tp_ports]
+                 else:
+                     addrs = [(ip, port) for ip, port in zip(model_port_args.model_tp_ips, model_port_args.model_tp_ports)]
+
+                 self.model_services = list(executor.map(
+                     lambda args: connect_rpyc_service(*args), addrs))

                  # Init model
                  def init_model(i):
-                     return self.remote_services[i].ModelRpcServer(
-                         i, server_args, port_args, model_overide_args
+                     return self.model_services[i].ModelTpServer(
+                         gpu_ids[i],
+                         i,
+                         server_args,
+                         model_port_args,
+                         model_overide_args,
                      )

-                 self.model_servers = executor.map(init_model, range(tp_size))
+                 self.model_servers = list(executor.map(init_model, range(self.tp_size)))

                  # Wrap functions
                  def async_wrap(func_name):
@@ -722,44 +794,3 @@ class ModelRpcClient:
                  return _func

              self.step = async_wrap("step")
-
-
- def _init_service(port):
-     t = ThreadedServer(
-         ModelRpcService(),
-         port=port,
-         protocol_config={
-             "allow_public_attrs": True,
-             "allow_pickle": True,
-             "sync_request_timeout": 1800,
-         },
-     )
-     t.start()
-
-
- def start_model_process(port):
-     proc = multiprocessing.Process(target=_init_service, args=(port,))
-     proc.start()
-     time.sleep(1)
-
-     repeat_count = 0
-     while repeat_count < 20:
-         try:
-             con = rpyc.connect(
-                 "localhost",
-                 port,
-                 config={
-                     "allow_public_attrs": True,
-                     "allow_pickle": True,
-                     "sync_request_timeout": 1800,
-                 },
-             )
-             break
-         except ConnectionRefusedError:
-             time.sleep(1)
-         repeat_count += 1
-     if repeat_count == 20:
-         raise RuntimeError("init rpc env error!")
-
-     assert proc.is_alive()
-     return con.root, proc
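For reference, the removed `_init_service` / `start_model_process` helpers above implement the launch-then-connect pattern that 0.1.18 factors out into `start_rpyc_service_process` and `connect_rpyc_service` (imported from `sglang/srt/utils.py` at the top of the new tp_worker.py): spawn a child process that serves an rpyc service on a port, then retry `rpyc.connect` until the server accepts a connection. Below is a minimal standalone sketch of that pattern; the `EchoService` class, port number, and helper names are illustrative only and are not part of the sglang API.

import multiprocessing
import time

import rpyc
from rpyc.utils.server import ThreadedServer


class EchoService(rpyc.Service):
    # Methods prefixed with "exposed_" are callable remotely via conn.root.
    def exposed_echo(self, x):
        return x


def _serve(port):
    # ThreadedServer.start() blocks forever, so it runs in a child process.
    ThreadedServer(EchoService(), port=port,
                   protocol_config={"allow_pickle": True}).start()


def start_and_connect(port, retries=20):
    # Launch the service process, then poll the port until the server is up.
    proc = multiprocessing.Process(target=_serve, args=(port,))
    proc.start()
    for _ in range(retries):
        try:
            conn = rpyc.connect("localhost", port,
                                config={"allow_pickle": True})
            return conn, proc
        except ConnectionRefusedError:
            time.sleep(1)
    raise RuntimeError("rpyc service did not come up")


if __name__ == "__main__":
    conn, proc = start_and_connect(18861)
    print(conn.root.echo("hello"))  # prints "hello"
    proc.terminate()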