sglang-0.1.17-py3-none-any.whl → sglang-0.1.19-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
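A file-level comparison like the one below can be reproduced locally with Python's standard library. The following is a minimal sketch, not the registry's own tooling; it assumes the two wheel archives have already been downloaded to the paths shown (a wheel is an ordinary zip file):

# Minimal sketch: list added, removed, and modified files between two wheels.
import zipfile

def changed_files(old_wheel: str, new_wheel: str) -> None:
    with zipfile.ZipFile(old_wheel) as old, zipfile.ZipFile(new_wheel) as new:
        old_names, new_names = set(old.namelist()), set(new.namelist())
        for name in sorted(old_names - new_names):
            print(f"removed   {name}")
        for name in sorted(new_names - old_names):
            print(f"added     {name}")
        for name in sorted(old_names & new_names):
            # Compare raw bytes; a real tool would also show per-file diffs.
            if old.read(name) != new.read(name):
                print(f"modified  {name}")

# Hypothetical local paths; substitute the files fetched from the registry.
changed_files("sglang-0.1.17-py3-none-any.whl", "sglang-0.1.19-py3-none-any.whl")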
Files changed (73)
  1. sglang/__init__.py +2 -2
  2. sglang/api.py +30 -4
  3. sglang/backend/litellm.py +2 -2
  4. sglang/backend/openai.py +26 -15
  5. sglang/backend/runtime_endpoint.py +18 -14
  6. sglang/bench_latency.py +317 -0
  7. sglang/global_config.py +5 -1
  8. sglang/lang/chat_template.py +41 -6
  9. sglang/lang/compiler.py +2 -2
  10. sglang/lang/interpreter.py +6 -2
  11. sglang/lang/ir.py +74 -28
  12. sglang/launch_server.py +4 -1
  13. sglang/launch_server_llavavid.py +2 -1
  14. sglang/srt/constrained/__init__.py +14 -6
  15. sglang/srt/constrained/fsm_cache.py +6 -3
  16. sglang/srt/constrained/jump_forward.py +113 -25
  17. sglang/srt/conversation.py +2 -0
  18. sglang/srt/flush_cache.py +2 -0
  19. sglang/srt/hf_transformers_utils.py +68 -9
  20. sglang/srt/layers/extend_attention.py +2 -1
  21. sglang/srt/layers/fused_moe.py +280 -169
  22. sglang/srt/layers/logits_processor.py +106 -42
  23. sglang/srt/layers/radix_attention.py +53 -29
  24. sglang/srt/layers/token_attention.py +4 -1
  25. sglang/srt/managers/controller/dp_worker.py +6 -3
  26. sglang/srt/managers/controller/infer_batch.py +144 -69
  27. sglang/srt/managers/controller/manager_multi.py +5 -5
  28. sglang/srt/managers/controller/manager_single.py +9 -4
  29. sglang/srt/managers/controller/model_runner.py +167 -55
  30. sglang/srt/managers/controller/radix_cache.py +4 -0
  31. sglang/srt/managers/controller/schedule_heuristic.py +2 -0
  32. sglang/srt/managers/controller/tp_worker.py +156 -134
  33. sglang/srt/managers/detokenizer_manager.py +19 -21
  34. sglang/srt/managers/io_struct.py +11 -5
  35. sglang/srt/managers/tokenizer_manager.py +16 -14
  36. sglang/srt/model_config.py +89 -4
  37. sglang/srt/models/chatglm.py +399 -0
  38. sglang/srt/models/commandr.py +2 -2
  39. sglang/srt/models/dbrx.py +1 -1
  40. sglang/srt/models/gemma.py +5 -1
  41. sglang/srt/models/gemma2.py +436 -0
  42. sglang/srt/models/grok.py +204 -137
  43. sglang/srt/models/llama2.py +12 -5
  44. sglang/srt/models/llama_classification.py +107 -0
  45. sglang/srt/models/llava.py +11 -8
  46. sglang/srt/models/llavavid.py +1 -1
  47. sglang/srt/models/minicpm.py +373 -0
  48. sglang/srt/models/mixtral.py +164 -115
  49. sglang/srt/models/mixtral_quant.py +0 -1
  50. sglang/srt/models/qwen.py +1 -1
  51. sglang/srt/models/qwen2.py +1 -1
  52. sglang/srt/models/qwen2_moe.py +454 -0
  53. sglang/srt/models/stablelm.py +1 -1
  54. sglang/srt/models/yivl.py +2 -2
  55. sglang/srt/openai_api_adapter.py +35 -25
  56. sglang/srt/openai_protocol.py +2 -2
  57. sglang/srt/server.py +69 -19
  58. sglang/srt/server_args.py +76 -43
  59. sglang/srt/utils.py +177 -35
  60. sglang/test/test_programs.py +28 -10
  61. sglang/utils.py +4 -3
  62. {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/METADATA +44 -31
  63. sglang-0.1.19.dist-info/RECORD +81 -0
  64. {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/WHEEL +1 -1
  65. sglang/srt/managers/router/infer_batch.py +0 -596
  66. sglang/srt/managers/router/manager.py +0 -82
  67. sglang/srt/managers/router/model_rpc.py +0 -818
  68. sglang/srt/managers/router/model_runner.py +0 -445
  69. sglang/srt/managers/router/radix_cache.py +0 -267
  70. sglang/srt/managers/router/scheduler.py +0 -59
  71. sglang-0.1.17.dist-info/RECORD +0 -81
  72. {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/LICENSE +0 -0
  73. {sglang-0.1.17.dist-info → sglang-0.1.19.dist-info}/top_level.txt +0 -0
sglang/srt/managers/router/model_rpc.py
@@ -1,818 +0,0 @@
- import asyncio
- import logging
- import multiprocessing
- import time
- import warnings
- from concurrent.futures import ThreadPoolExecutor
- from typing import List, Optional
-
- import rpyc
- import torch
- from rpyc.utils.classic import obtain
- from rpyc.utils.server import ThreadedServer
-
- try:
-     from vllm.logger import _default_handler as vllm_default_logger
- except ImportError:
-     from vllm.logger import logger as vllm_default_logger
-
- from sglang.global_config import global_config
- from sglang.srt.constrained.fsm_cache import FSMCache
- from sglang.srt.constrained.jump_forward import JumpForwardCache
- from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
- from sglang.srt.managers.io_struct import (
-     AbortReq,
-     BatchTokenIDOut,
-     FlushCacheReq,
-     TokenizedGenerateReqInput,
- )
- from sglang.srt.managers.router.infer_batch import Batch, FinishReason, ForwardMode, Req
- from sglang.srt.managers.router.model_runner import ModelRunner
- from sglang.srt.managers.router.radix_cache import RadixCache
- from sglang.srt.managers.router.scheduler import Scheduler
- from sglang.srt.model_config import ModelConfig
- from sglang.srt.server_args import PortArgs, ServerArgs
- from sglang.srt.utils import (
-     get_int_token_logit_bias,
-     is_multimodal_model,
-     set_random_seed,
- )
- from sglang.utils import get_exception_traceback
-
- logger = logging.getLogger("model_rpc")
- vllm_default_logger.setLevel(logging.WARN)
- logging.getLogger("vllm.utils").setLevel(logging.WARN)
- logging.getLogger("vllm.selector").setLevel(logging.WARN)
-
-
- class ModelRpcServer:
-     def __init__(
-         self,
-         tp_rank: int,
-         server_args: ServerArgs,
-         port_args: PortArgs,
-         model_overide_args: Optional[dict] = None,
-     ):
-         server_args, port_args = [obtain(x) for x in [server_args, port_args]]
-
-         # Copy arguments
-         self.tp_rank = tp_rank
-         self.tp_size = server_args.tp_size
-         self.schedule_heuristic = server_args.schedule_heuristic
-         self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
-
-         # Init model and tokenizer
-         self.model_config = ModelConfig(
-             server_args.model_path,
-             server_args.trust_remote_code,
-             context_length=server_args.context_length,
-             model_overide_args=model_overide_args,
-         )
-
-         # For model end global settings
-         self.model_runner = ModelRunner(
-             model_config=self.model_config,
-             mem_fraction_static=server_args.mem_fraction_static,
-             tp_rank=tp_rank,
-             tp_size=server_args.tp_size,
-             nccl_port=port_args.nccl_port,
-             server_args=server_args,
-         )
-         if is_multimodal_model(server_args.model_path):
-             self.processor = get_processor(
-                 server_args.tokenizer_path,
-                 tokenizer_mode=server_args.tokenizer_mode,
-                 trust_remote_code=server_args.trust_remote_code,
-             )
-             self.tokenizer = self.processor.tokenizer
-         else:
-             self.tokenizer = get_tokenizer(
-                 server_args.tokenizer_path,
-                 tokenizer_mode=server_args.tokenizer_mode,
-                 trust_remote_code=server_args.trust_remote_code,
-             )
-         self.max_total_num_tokens = self.model_runner.max_total_num_tokens
-         self.max_prefill_tokens = max(
-             self.model_config.context_len,
-             (
-                 self.max_total_num_tokens // 6
-                 if server_args.max_prefill_tokens is None
-                 else server_args.max_prefill_tokens
-             ),
-         )
-         self.max_running_requests = (self.max_total_num_tokens // 2
-             if server_args.max_running_requests is None else server_args.max_running_requests)
-
-         self.int_token_logit_bias = torch.tensor(
-             get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
-         )
-         set_random_seed(server_args.random_seed)
-
-         # Print info
-         logger.info(f"[rank={self.tp_rank}] "
-             f"max_total_num_tokens={self.max_total_num_tokens}, "
-             f"max_prefill_tokens={self.max_prefill_tokens}, "
-             f"context_len={self.model_config.context_len}, "
-         )
-         if self.tp_rank == 0:
-             logger.info(f"server_args: {server_args.print_mode_args()}")
-
-         # Init cache
-         self.tree_cache = RadixCache(
-             req_to_token_pool=self.model_runner.req_to_token_pool,
-             token_to_kv_pool=self.model_runner.token_to_kv_pool,
-             disable=server_args.disable_radix_cache,
-         )
-         self.tree_cache_metrics = {"total": 0, "hit": 0}
-         self.scheduler = Scheduler(
-             self.schedule_heuristic,
-             self.max_running_requests,
-             self.max_prefill_tokens,
-             self.max_total_num_tokens,
-             self.tree_cache,
-         )
-         self.req_to_token_pool = self.model_runner.req_to_token_pool
-         self.token_to_kv_pool = self.model_runner.token_to_kv_pool
-
-         # Init running status
-         self.forward_queue: List[Req] = []
-         self.running_batch: Batch = None
-         self.out_pyobjs = []
-         self.decode_forward_ct = 0
-         self.stream_interval = server_args.stream_interval
-         self.num_generated_tokens = 0
-         self.last_stats_tic = time.time()
-
-         # Init the FSM cache for constrained generation
-         self.regex_fsm_cache = FSMCache(
-             server_args.tokenizer_path,
-             {
-                 "tokenizer_mode": server_args.tokenizer_mode,
-                 "trust_remote_code": server_args.trust_remote_code,
-             },
-         )
-         self.jump_forward_cache = JumpForwardCache()
-
-         # Init new token estimation
-         assert (
-             server_args.schedule_conservativeness >= 0
-         ), "Invalid schedule_conservativeness"
-         self.new_token_ratio = min(
-             global_config.base_new_token_ratio * server_args.schedule_conservativeness,
-             1.0,
-         )
-         self.min_new_token_ratio = min(
-             global_config.base_min_new_token_ratio
-             * server_args.schedule_conservativeness,
-             1.0,
-         )
-         self.new_token_ratio_decay = global_config.new_token_ratio_decay
-         self.new_token_ratio_recovery = global_config.new_token_ratio_recovery
-
-     def exposed_step(self, recv_reqs):
-         if self.tp_size != 1:
-             recv_reqs = obtain(recv_reqs)
-
-         try:
-             # Recv requests
-             for recv_req in recv_reqs:
-                 if isinstance(recv_req, TokenizedGenerateReqInput):
-                     self.handle_generate_request(recv_req)
-                 elif isinstance(recv_req, FlushCacheReq):
-                     self.flush_cache()
-                 elif isinstance(recv_req, AbortReq):
-                     self.abort_request(recv_req)
-                 else:
-                     raise ValueError(f"Invalid request: {recv_req}")
-
-             # Forward
-             self.forward_step()
-         except Exception:
-             logger.error("Exception in ModelRpcClient:\n" + get_exception_traceback())
-
-         # Return results
-         ret = self.out_pyobjs
-         self.out_pyobjs = []
-         return ret
-
-     @torch.inference_mode()
-     def forward_step(self):
-         new_batch = self.get_new_fill_batch()
-
-         if new_batch is not None:
-             # Run a new fill batch
-             self.forward_fill_batch(new_batch)
-             self.cache_filled_batch(new_batch)
-
-             if not new_batch.is_empty():
-                 if self.running_batch is None:
-                     self.running_batch = new_batch
-                 else:
-                     self.running_batch.merge(new_batch)
-         else:
-             # Run decode batch
-             if self.running_batch is not None:
-                 # Run a few decode batches continuously for reducing overhead
-                 for _ in range(10):
-                     self.num_generated_tokens += len(self.running_batch.reqs)
-                     self.forward_decode_batch(self.running_batch)
-
-                     # Print stats
-                     if self.tp_rank == 0:
-                         if self.decode_forward_ct % 40 == 0:
-                             num_used = self.max_total_num_tokens - (
-                                 self.token_to_kv_pool.available_size()
-                                 + self.tree_cache.evictable_size()
-                             )
-                             throuhgput = self.num_generated_tokens / (
-                                 time.time() - self.last_stats_tic
-                             )
-                             self.num_generated_tokens = 0
-                             self.last_stats_tic = time.time()
-                             logger.info(
-                                 f"#running-req: {len(self.running_batch.reqs)}, "
-                                 f"#token: {num_used}, "
-                                 f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
-                                 f"gen throughput (token/s): {throuhgput:.2f}, "
-                                 f"#queue-req: {len(self.forward_queue)}"
-                             )
-
-                     if self.running_batch.is_empty():
-                         self.running_batch = None
-                         break
-
-                     if self.out_pyobjs and self.running_batch.reqs[0].stream:
-                         break
-             else:
-                 # Check the available size
-                 available_size = (
-                     self.token_to_kv_pool.available_size()
-                     + self.tree_cache.evictable_size()
-                 )
-                 if available_size != self.max_total_num_tokens:
-                     warnings.warn(
-                         "Warning: "
-                         f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
-                         "KV cache pool leak detected!"
-                     )
-
-     def handle_generate_request(
-         self,
-         recv_req: TokenizedGenerateReqInput,
-     ):
-         req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
-         req.pixel_values = recv_req.pixel_values
-         if req.pixel_values is not None:
-             req.pad_value = [
-                 (recv_req.image_hash) % self.model_config.vocab_size,
-                 (recv_req.image_hash >> 16) % self.model_config.vocab_size,
-                 (recv_req.image_hash >> 32) % self.model_config.vocab_size,
-                 (recv_req.image_hash >> 64) % self.model_config.vocab_size,
-             ]
-             req.image_size = recv_req.image_size
-             req.origin_input_ids, req.image_offset = (
-                 self.model_runner.model.pad_input_ids(
-                     req.origin_input_ids_unpadded,
-                     req.pad_value,
-                     req.pixel_values.shape,
-                     req.image_size,
-                 )
-             )
-         req.sampling_params = recv_req.sampling_params
-         req.return_logprob = recv_req.return_logprob
-         req.logprob_start_len = recv_req.logprob_start_len
-         req.top_logprobs_num = recv_req.top_logprobs_num
-         req.stream = recv_req.stream
-         req.tokenizer = self.tokenizer
-
-         # Init regex fsm
-         if req.sampling_params.regex is not None:
-             req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
-             if not self.disable_regex_jump_forward:
-                 req.jump_forward_map = self.jump_forward_cache.query(
-                     req.sampling_params.regex
-                 )
-
-         # Truncate prompts that are too long
-         req.origin_input_ids = req.origin_input_ids[: self.model_config.context_len - 1]
-         req.sampling_params.max_new_tokens = min(
-             req.sampling_params.max_new_tokens,
-             self.model_config.context_len - 1 - len(req.origin_input_ids),
-             self.max_total_num_tokens - 128 - len(req.origin_input_ids),
-         )
-         self.forward_queue.append(req)
-
-     def get_new_fill_batch(self):
-         if (
-             self.running_batch is not None
-             and len(self.running_batch.reqs) > self.max_running_requests
-         ):
-             return None
-
-         # Compute matched prefix length
-         for req in self.forward_queue:
-             assert (
-                 len(req.output_ids) == 0
-             ), "The output ids should be empty when prefilling"
-             req.input_ids = req.origin_input_ids + req.prev_output_ids
-             prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
-             if req.return_logprob:
-                 prefix_indices = prefix_indices[: req.logprob_start_len]
-             req.extend_input_len = len(req.input_ids) - len(prefix_indices)
-             req.prefix_indices = prefix_indices
-             req.last_node = last_node
-
-         # Get priority queue
-         self.forward_queue = self.scheduler.get_priority_queue(self.forward_queue)
-
-         # Add requests if there is available space
-         can_run_list = []
-         new_batch_total_tokens = 0
-         new_batch_input_tokens = 0
-
-         available_size = (
-             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
-         )
-         if self.running_batch:
-             available_size -= sum(
-                 [
-                     (r.max_new_tokens() - len(r.output_ids)) * self.new_token_ratio
-                     for r in self.running_batch.reqs
-                 ]
-             )
-
-         for req in self.forward_queue:
-             if req.return_logprob and req.normalized_prompt_logprob is None:
-                 # Need at least two tokens to compute normalized logprob
-                 if req.extend_input_len < 2:
-                     delta = 2 - req.extend_input_len
-                     req.extend_input_len += delta
-                     req.prefix_indices = req.prefix_indices[:-delta]
-                     if req.image_offset is not None:
-                         req.image_offset += delta
-             if req.extend_input_len == 0 and req.max_new_tokens() > 0:
-                 # Need at least one token to compute logits
-                 req.extend_input_len = 1
-                 req.prefix_indices = req.prefix_indices[:-1]
-                 if req.image_offset is not None:
-                     req.image_offset += 1
-
-             if (
-                 req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
-                 < available_size
-                 and req.extend_input_len + new_batch_input_tokens
-                 < self.max_prefill_tokens
-             ):
-                 delta = self.tree_cache.inc_lock_ref(req.last_node)
-                 available_size += delta
-
-                 if not (
-                     req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
-                     < available_size
-                 ):
-                     # Undo locking
-                     delta = self.tree_cache.dec_lock_ref(req.last_node)
-                     available_size += delta
-                     break
-                 else:
-                     # Add this request to the running batch
-                     can_run_list.append(req)
-                     new_batch_total_tokens += (
-                         req.extend_input_len + req.max_new_tokens()
-                     )
-                     new_batch_input_tokens += req.extend_input_len
-             else:
-                 break
-         if len(can_run_list) == 0:
-             return None
-
-         # Print stats
-         if self.tp_rank == 0:
-             running_req = (
-                 0 if self.running_batch is None else len(self.running_batch.reqs)
-             )
-             hit_tokens = sum(len(x.prefix_indices) for x in can_run_list)
-             self.tree_cache_metrics["total"] += (
-                 hit_tokens + new_batch_input_tokens
-             ) / 10**9
-             self.tree_cache_metrics["hit"] += hit_tokens / 10**9
-             tree_cache_hit_rate = (
-                 self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
-             )
-             logger.info(
-                 f"new fill batch. #seq: {len(can_run_list)}. "
-                 f"#cached_token: {hit_tokens}. "
-                 f"#new_token: {new_batch_input_tokens}. "
-                 f"#remaining_req: {len(self.forward_queue) - len(can_run_list)}. "
-                 f"#running_req: {running_req}. "
-                 f"tree_cache_hit_rate: {100.0 * tree_cache_hit_rate:.2f}%."
-             )
-             # logger.debug(
-             #     f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
-             #     f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
-             #     f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
-             #     f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
-             # )
-
-         # Return the new batch
-         new_batch = Batch.init_new(
-             can_run_list,
-             self.req_to_token_pool,
-             self.token_to_kv_pool,
-             self.tree_cache,
-         )
-         self.forward_queue = [x for x in self.forward_queue if x not in can_run_list]
-         return new_batch
-
-     def forward_fill_batch(self, batch: Batch):
-         # Build batch tensors
-         batch.prepare_for_extend(
-             self.model_config.vocab_size, self.int_token_logit_bias
-         )
-
-         if batch.extend_num_tokens != 0:
-             # Forward
-             logits, (
-                 prefill_token_logprobs,
-                 normalized_prompt_logprobs,
-                 prefill_top_logprobs,
-                 decode_top_logprobs,
-                 last_logprobs,
-             ) = self.model_runner.forward(batch, ForwardMode.EXTEND)
-             if prefill_token_logprobs is not None:
-                 prefill_token_logprobs = prefill_token_logprobs.tolist()
-                 normalized_prompt_logprobs = normalized_prompt_logprobs.tolist()
-
-             next_token_ids, _ = batch.sample(logits)
-
-             # Only transfer the selected logprobs of the next token to CPU to reduce overhead.
-             if last_logprobs is not None:
-                 last_token_logprobs = last_logprobs[
-                     torch.arange(len(batch.reqs), device=next_token_ids.device),
-                     next_token_ids,
-                 ].tolist()
-
-             next_token_ids = next_token_ids.tolist()
-         else:
-             next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
-
-         # Check finish condition
-         pt = 0
-         for i, req in enumerate(batch.reqs):
-             req.completion_tokens_wo_jump_forward += 1
-             req.output_ids = [next_token_ids[i]]
-             req.check_finished()
-
-             if req.return_logprob:
-                 if req.normalized_prompt_logprob is None:
-                     req.normalized_prompt_logprob = normalized_prompt_logprobs[i]
-
-                 if req.prefill_token_logprobs is None:
-                     # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
-                     req.prefill_token_logprobs = list(
-                         zip(
-                             prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
-                             req.input_ids[-req.extend_input_len + 1 :],
-                         )
-                     )
-                     if req.logprob_start_len == 0:
-                         req.prefill_token_logprobs = [
-                             (None, req.input_ids[0])
-                         ] + req.prefill_token_logprobs
-
-                 if req.last_update_decode_tokens != 0:
-                     req.decode_token_logprobs.extend(
-                         list(
-                             zip(
-                                 prefill_token_logprobs[
-                                     pt
-                                     + req.extend_input_len
-                                     - req.last_update_decode_tokens : pt
-                                     + req.extend_input_len
-                                     - 1
-                                 ],
-                                 req.input_ids[-req.last_update_decode_tokens + 1 :],
-                             )
-                         )
-                     )
-
-                 req.decode_token_logprobs.append(
-                     (last_token_logprobs[i], next_token_ids[i])
-                 )
-
-             if req.top_logprobs_num > 0:
-                 if req.prefill_top_logprobs is None:
-                     req.prefill_top_logprobs = prefill_top_logprobs[i]
-                     if req.logprob_start_len == 0:
-                         req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
-
-                 if req.last_update_decode_tokens != 0:
-                     req.decode_top_logprobs.extend(
-                         prefill_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
-                     )
-                 req.decode_top_logprobs.append(decode_top_logprobs[i])
-
-             pt += req.extend_input_len
-
-         self.handle_finished_requests(batch)
-
-     def cache_filled_batch(self, batch: Batch):
-         req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
-         for i, req in enumerate(batch.reqs):
-             new_prefix_indices, new_last_node = self.tree_cache.cache_req(
-                 token_ids=tuple(req.input_ids + req.output_ids)[:-1],
-                 last_uncached_pos=len(req.prefix_indices),
-                 req_pool_idx=req_pool_indices_cpu[i],
-                 del_in_memory_pool=False,
-                 old_last_node=req.last_node,
-             )
-             req.prefix_indices, req.last_node = new_prefix_indices, new_last_node
-
-     def forward_decode_batch(self, batch: Batch):
-         # check if decode out of memory
-         if not batch.check_decode_mem():
-             old_ratio = self.new_token_ratio
-             self.new_token_ratio = min(old_ratio + self.new_token_ratio_recovery, 1.0)
-
-             retracted_reqs = batch.retract_decode()
-             logger.info(
-                 "decode out of memory happened, "
-                 f"#retracted_reqs: {len(retracted_reqs)}, "
-                 f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
-             )
-             self.forward_queue.extend(retracted_reqs)
-         else:
-             self.new_token_ratio = max(
-                 self.new_token_ratio - self.new_token_ratio_decay,
-                 self.min_new_token_ratio,
-             )
-
-         if not self.disable_regex_jump_forward:
-             # check for jump-forward
-             jump_forward_reqs = batch.check_for_jump_forward(self.model_runner)
-
-             self.forward_queue.extend(jump_forward_reqs)
-             if batch.is_empty():
-                 return
-
-         # Update batch tensors
-         self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
-         batch.prepare_for_decode()
-
-         # Forward
-         logits, (
-             _,
-             _,
-             _,
-             decode_top_logprobs,
-             last_logprobs,
-         ) = self.model_runner.forward(batch, ForwardMode.DECODE)
-         next_token_ids, _ = batch.sample(logits)
-         next_token_ids = next_token_ids.tolist()
-
-         # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
-         if last_logprobs is not None:
-             new_token_logprobs = last_logprobs[
-                 torch.arange(len(batch.reqs)), next_token_ids
-             ].tolist()
-
-         # Check finish condition
-         for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
-             req.completion_tokens_wo_jump_forward += 1
-             req.output_ids.append(next_token_id)
-             req.check_finished()
-
-             if req.return_logprob:
-                 req.decode_token_logprobs.append((new_token_logprobs[i], next_token_id))
-
-             if req.top_logprobs_num > 0:
-                 req.decode_top_logprobs.append(decode_top_logprobs[i])
-
-         self.handle_finished_requests(batch)
-
-     def handle_finished_requests(self, batch: Batch):
-         output_rids = []
-         prev_output_strs = []
-         output_tokens = []
-         output_hit_stop_str = []
-         output_skip_special_tokens = []
-         output_spaces_between_special_tokens = []
-         output_meta_info = []
-         output_finished = []
-         finished_indices = []
-         unfinished_indices = []
-         for i, req in enumerate(batch.reqs):
-             if req.finished:
-                 finished_indices.append(i)
-             else:
-                 unfinished_indices.append(i)
-
-             if req.finished or (
-                 (
-                     req.stream
-                     and (
-                         self.decode_forward_ct % self.stream_interval == 0
-                         or len(req.output_ids) == 1
-                     )
-                 )
-             ):
-                 output_rids.append(req.rid)
-                 prev_output_strs.append(req.prev_output_str)
-                 output_tokens.append(req.output_ids)
-                 output_hit_stop_str.append(req.hit_stop_str)
-                 output_skip_special_tokens.append(
-                     req.sampling_params.skip_special_tokens
-                 )
-                 output_spaces_between_special_tokens.append(
-                     req.sampling_params.spaces_between_special_tokens
-                 )
-
-                 meta_info = {
-                     "prompt_tokens": len(req.origin_input_ids),
-                     "completion_tokens": len(req.prev_output_ids) + len(req.output_ids),
-                     "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
-                     "finish_reason": FinishReason.to_str(req.finish_reason),
-                     "hit_stop_str": req.hit_stop_str,
-                 }
-                 if req.return_logprob:
-                     (
-                         meta_info["prefill_token_logprobs"],
-                         meta_info["decode_token_logprobs"],
-                         meta_info["prefill_top_logprobs"],
-                         meta_info["decode_top_logprobs"],
-                         meta_info["normalized_prompt_logprob"],
-                     ) = (
-                         req.prefill_token_logprobs,
-                         req.decode_token_logprobs,
-                         req.prefill_top_logprobs,
-                         req.decode_top_logprobs,
-                         req.normalized_prompt_logprob,
-                     )
-                 output_meta_info.append(meta_info)
-                 output_finished.append(req.finished)
-
-         # Send to detokenizer
-         if output_rids:
-             self.out_pyobjs.append(
-                 BatchTokenIDOut(
-                     output_rids,
-                     prev_output_strs,
-                     output_tokens,
-                     output_hit_stop_str,
-                     output_skip_special_tokens,
-                     output_spaces_between_special_tokens,
-                     output_meta_info,
-                     output_finished,
-                 )
-             )
-
-         # Remove finished reqs
-         if finished_indices:
-             # Update radix cache
-             req_pool_indices_cpu = batch.req_pool_indices.tolist()
-             for i in finished_indices:
-                 req = batch.reqs[i]
-                 self.tree_cache.cache_req(
-                     token_ids=tuple(req.input_ids + req.output_ids)[:-1],
-                     last_uncached_pos=len(req.prefix_indices),
-                     req_pool_idx=req_pool_indices_cpu[i],
-                 )
-
-                 self.tree_cache.dec_lock_ref(req.last_node)
-
-             # Update batch tensors
-             if unfinished_indices:
-                 batch.filter_batch(unfinished_indices)
-             else:
-                 batch.reqs = []
-
-     def flush_cache(self):
-         if len(self.forward_queue) == 0 and (
-             self.running_batch is None or len(self.running_batch.reqs) == 0
-         ):
-             self.tree_cache.reset()
-             self.tree_cache_metrics = {"total": 0, "hit": 0}
-             self.regex_fsm_cache.reset()
-             self.req_to_token_pool.clear()
-             self.token_to_kv_pool.clear()
-             torch.cuda.empty_cache()
-             logger.info("Cache flushed successfully!")
-         else:
-             warnings.warn(
-                 f"Cache not flushed because there are pending requests. "
-                 f"#queue-req: {len(self.forward_queue)}, "
-                 f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
-             )
-
-     def abort_request(self, recv_req):
-         # Delete requests in the waiting queue
-         to_del = None
-         for i, req in enumerate(self.forward_queue):
-             if req.rid == recv_req.rid:
-                 to_del = i
-                 break
-
-         if to_del is not None:
-             del self.forward_queue[to_del]
-
-         # Delete requests in the running batch
-         if self.running_batch:
-             for req in self.running_batch.reqs:
-                 if req.rid == recv_req.rid:
-                     req.finished = True
-                     req.finish_reason = FinishReason.ABORT
-                     break
-
-
- class ModelRpcService(rpyc.Service):
-     exposed_ModelRpcServer = ModelRpcServer
-
-
- class ModelRpcClient:
-     def __init__(
-         self, server_args: ServerArgs, port_args: PortArgs, model_overide_args
-     ):
-         tp_size = server_args.tp_size
-
-         if tp_size == 1:
-             # Init model
-             self.model_server = ModelRpcService().exposed_ModelRpcServer(
-                 0, server_args, port_args, model_overide_args
-             )
-
-             # Wrap functions
-             def async_wrap(f):
-                 async def _func(*args, **kwargs):
-                     return f(*args, **kwargs)
-
-                 return _func
-
-             self.step = async_wrap(self.model_server.exposed_step)
-         else:
-             with ThreadPoolExecutor(tp_size) as executor:
-                 # Launch model processes
-                 rets = executor.map(start_model_process, port_args.model_rpc_ports)
-                 self.remote_services = [x[0] for x in rets]
-                 self.procs = [x[1] for x in rets]
-
-                 # Init model
-                 def init_model(i):
-                     return self.remote_services[i].ModelRpcServer(
-                         i, server_args, port_args, model_overide_args
-                     )
-
-                 self.model_servers = executor.map(init_model, range(tp_size))
-
-             # Wrap functions
-             def async_wrap(func_name):
-                 fs = [rpyc.async_(getattr(m, func_name)) for m in self.model_servers]
-
-                 async def _func(*args, **kwargs):
-                     tasks = [f(*args, **kwargs) for f in fs]
-                     await asyncio.gather(*[asyncio.to_thread(t.wait) for t in tasks])
-                     return obtain(tasks[0].value)
-
-                 return _func
-
-             self.step = async_wrap("step")
-
-
- def _init_service(port):
-     t = ThreadedServer(
-         ModelRpcService(),
-         port=port,
-         protocol_config={
-             "allow_public_attrs": True,
-             "allow_pickle": True,
-             "sync_request_timeout": 3600,
-         },
-     )
-     t.start()
-
-
- def start_model_process(port):
-     proc = multiprocessing.Process(target=_init_service, args=(port,))
-     proc.start()
-     time.sleep(1)
-
-     repeat_count = 0
-     while repeat_count < 20:
-         try:
-             con = rpyc.connect(
-                 "localhost",
-                 port,
-                 config={
-                     "allow_public_attrs": True,
-                     "allow_pickle": True,
-                     "sync_request_timeout": 3600,
-                 },
-             )
-             break
-         except ConnectionRefusedError:
-             time.sleep(1)
-             repeat_count += 1
-     if repeat_count == 20:
-         raise RuntimeError("init rpc env error!")
-
-     assert proc.is_alive()
-     return con.root, proc