sglang 0.1.16__py3-none-any.whl → 0.1.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (65)
  1. sglang/__init__.py +3 -1
  2. sglang/api.py +3 -3
  3. sglang/backend/anthropic.py +1 -1
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +148 -12
  6. sglang/backend/runtime_endpoint.py +18 -10
  7. sglang/global_config.py +8 -1
  8. sglang/lang/interpreter.py +114 -67
  9. sglang/lang/ir.py +17 -2
  10. sglang/srt/constrained/fsm_cache.py +3 -0
  11. sglang/srt/flush_cache.py +1 -1
  12. sglang/srt/hf_transformers_utils.py +75 -1
  13. sglang/srt/layers/extend_attention.py +17 -0
  14. sglang/srt/layers/fused_moe.py +485 -0
  15. sglang/srt/layers/logits_processor.py +12 -7
  16. sglang/srt/layers/radix_attention.py +10 -3
  17. sglang/srt/layers/token_attention.py +16 -1
  18. sglang/srt/managers/controller/dp_worker.py +110 -0
  19. sglang/srt/managers/controller/infer_batch.py +619 -0
  20. sglang/srt/managers/controller/manager_multi.py +191 -0
  21. sglang/srt/managers/controller/manager_single.py +97 -0
  22. sglang/srt/managers/controller/model_runner.py +462 -0
  23. sglang/srt/managers/controller/radix_cache.py +267 -0
  24. sglang/srt/managers/controller/schedule_heuristic.py +59 -0
  25. sglang/srt/managers/controller/tp_worker.py +791 -0
  26. sglang/srt/managers/detokenizer_manager.py +45 -45
  27. sglang/srt/managers/io_struct.py +15 -11
  28. sglang/srt/managers/router/infer_batch.py +103 -59
  29. sglang/srt/managers/router/manager.py +1 -1
  30. sglang/srt/managers/router/model_rpc.py +175 -122
  31. sglang/srt/managers/router/model_runner.py +91 -104
  32. sglang/srt/managers/router/radix_cache.py +7 -1
  33. sglang/srt/managers/router/scheduler.py +6 -6
  34. sglang/srt/managers/tokenizer_manager.py +152 -89
  35. sglang/srt/model_config.py +4 -5
  36. sglang/srt/models/commandr.py +10 -13
  37. sglang/srt/models/dbrx.py +9 -15
  38. sglang/srt/models/gemma.py +8 -15
  39. sglang/srt/models/grok.py +671 -0
  40. sglang/srt/models/llama2.py +19 -15
  41. sglang/srt/models/llava.py +84 -20
  42. sglang/srt/models/llavavid.py +11 -20
  43. sglang/srt/models/mixtral.py +248 -118
  44. sglang/srt/models/mixtral_quant.py +373 -0
  45. sglang/srt/models/qwen.py +9 -13
  46. sglang/srt/models/qwen2.py +11 -13
  47. sglang/srt/models/stablelm.py +9 -15
  48. sglang/srt/models/yivl.py +17 -22
  49. sglang/srt/openai_api_adapter.py +140 -95
  50. sglang/srt/openai_protocol.py +10 -1
  51. sglang/srt/server.py +77 -42
  52. sglang/srt/server_args.py +51 -6
  53. sglang/srt/utils.py +124 -66
  54. sglang/test/test_programs.py +44 -0
  55. sglang/test/test_utils.py +32 -1
  56. sglang/utils.py +22 -4
  57. {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/METADATA +15 -9
  58. sglang-0.1.17.dist-info/RECORD +81 -0
  59. sglang/srt/backend_config.py +0 -13
  60. sglang/srt/models/dbrx_config.py +0 -281
  61. sglang/srt/weight_utils.py +0 -417
  62. sglang-0.1.16.dist-info/RECORD +0 -72
  63. {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/LICENSE +0 -0
  64. {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/WHEEL +0 -0
  65. {sglang-0.1.16.dist-info → sglang-0.1.17.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,791 @@
+ import asyncio
+ import logging
+ import time
+ import warnings
+ from concurrent.futures import ThreadPoolExecutor
+ from typing import List
+
+ import rpyc
+ import torch
+ from rpyc.utils.classic import obtain
+
+ from sglang.global_config import global_config
+ from sglang.srt.constrained.fsm_cache import FSMCache
+ from sglang.srt.constrained.jump_forward import JumpForwardCache
+ from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
+ from sglang.srt.managers.io_struct import (
+     AbortReq,
+     BatchTokenIDOut,
+     FlushCacheReq,
+     TokenizedGenerateReqInput,
+ )
+ from sglang.srt.managers.controller.infer_batch import BaseFinishReason, Batch, FINISH_ABORT, ForwardMode, Req
+ from sglang.srt.managers.controller.model_runner import ModelRunner
+ from sglang.srt.managers.controller.radix_cache import RadixCache
+ from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic
+ from sglang.srt.model_config import ModelConfig
+ from sglang.srt.server_args import ModelPortArgs, ServerArgs
+ from sglang.srt.utils import (
+     get_int_token_logit_bias,
+     is_multimodal_model,
+     set_random_seed,
+     start_rpyc_process,
+     suppress_other_loggers,
+ )
+ from sglang.utils import get_exception_traceback
+
+ logger = logging.getLogger("srt.tp_worker")
+
+
+ class ModelTpServer:
+     def __init__(
+         self,
+         gpu_id: int,
+         tp_rank: int,
+         server_args: ServerArgs,
+         model_port_args: ModelPortArgs,
+         model_overide_args,
+     ):
+         server_args, model_port_args = obtain(server_args), obtain(model_port_args)
+         suppress_other_loggers()
+
+         # Copy arguments
+         self.gpu_id = gpu_id
+         self.tp_rank = tp_rank
+         self.tp_size = server_args.tp_size
+         self.dp_size = server_args.dp_size
+         self.schedule_heuristic = server_args.schedule_heuristic
+         self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
+
+         # Init model and tokenizer
+         self.model_config = ModelConfig(
+             server_args.model_path,
+             server_args.trust_remote_code,
+             context_length=server_args.context_length,
+             model_overide_args=model_overide_args,
+         )
+         self.model_runner = ModelRunner(
+             model_config=self.model_config,
+             mem_fraction_static=server_args.mem_fraction_static,
+             gpu_id=gpu_id,
+             tp_rank=tp_rank,
+             tp_size=server_args.tp_size,
+             nccl_port=model_port_args.nccl_port,
+             server_args=server_args,
+         )
+
+         if is_multimodal_model(server_args.model_path):
+             self.processor = get_processor(
+                 server_args.tokenizer_path,
+                 tokenizer_mode=server_args.tokenizer_mode,
+                 trust_remote_code=server_args.trust_remote_code,
+             )
+             self.tokenizer = self.processor.tokenizer
+         else:
+             self.tokenizer = get_tokenizer(
+                 server_args.tokenizer_path,
+                 tokenizer_mode=server_args.tokenizer_mode,
+                 trust_remote_code=server_args.trust_remote_code,
+             )
+         self.max_total_num_tokens = self.model_runner.max_total_num_tokens
+         self.max_prefill_tokens = max(
+             self.model_config.context_len,
+             (
+                 min(self.max_total_num_tokens // 6, 65536)
+                 if server_args.max_prefill_tokens is None
+                 else server_args.max_prefill_tokens
+             ),
+         )
+         self.max_running_requests = (self.max_total_num_tokens // 2
+             if server_args.max_running_requests is None else server_args.max_running_requests)
+         self.int_token_logit_bias = torch.tensor(
+             get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
+         )
+         set_random_seed(server_args.random_seed)
+
+         # Print info
+         logger.info(
+             f"[gpu_id={self.gpu_id}] "
+             f"max_total_num_tokens={self.max_total_num_tokens}, "
+             f"max_prefill_tokens={self.max_prefill_tokens}, "
+             f"context_len={self.model_config.context_len}, "
+         )
+         if self.tp_rank == 0:
+             logger.info(
+                 f"[gpu_id={self.gpu_id}] "
+                 f"server_args: {server_args.print_mode_args()}"
+             )
+
+         # Init cache
+         self.tree_cache = RadixCache(
+             req_to_token_pool=self.model_runner.req_to_token_pool,
+             token_to_kv_pool=self.model_runner.token_to_kv_pool,
+             disable=server_args.disable_radix_cache,
+         )
+         self.tree_cache_metrics = {"total": 0, "hit": 0}
+         self.scheduler = ScheduleHeuristic(
+             self.schedule_heuristic,
+             self.max_running_requests,
+             self.max_prefill_tokens,
+             self.max_total_num_tokens,
+             self.tree_cache,
+         )
+         self.req_to_token_pool = self.model_runner.req_to_token_pool
+         self.token_to_kv_pool = self.model_runner.token_to_kv_pool
+
+         # Init running status
+         self.forward_queue: List[Req] = []
+         self.running_batch: Batch = None
+         self.out_pyobjs = []
+         self.decode_forward_ct = 0
+         self.stream_interval = server_args.stream_interval
+         self.num_generated_tokens = 0
+         self.last_stats_tic = time.time()
+
+         # Init the FSM cache for constrained generation
+         self.regex_fsm_cache = FSMCache(
+             server_args.tokenizer_path,
+             {
+                 "tokenizer_mode": server_args.tokenizer_mode,
+                 "trust_remote_code": server_args.trust_remote_code,
+             },
+         )
+         self.jump_forward_cache = JumpForwardCache()
+
+         # Init new token estimation
+         assert (
+             server_args.schedule_conservativeness >= 0
+         ), "Invalid schedule_conservativeness"
+         self.new_token_ratio = min(
+             global_config.base_new_token_ratio * server_args.schedule_conservativeness,
+             1.0,
+         )
+         self.min_new_token_ratio = min(
+             global_config.base_min_new_token_ratio
+             * server_args.schedule_conservativeness,
+             1.0,
+         )
+         self.new_token_ratio_decay = global_config.new_token_ratio_decay
+         self.new_token_ratio_recovery = global_config.new_token_ratio_recovery
+
+     def exposed_step(self, recv_reqs):
+         if self.tp_size * self.dp_size != 1:
+             recv_reqs = obtain(recv_reqs)
+
+         try:
+             # Recv requests
+             for recv_req in recv_reqs:
+                 if isinstance(recv_req, TokenizedGenerateReqInput):
+                     self.handle_generate_request(recv_req)
+                 elif isinstance(recv_req, FlushCacheReq):
+                     self.flush_cache()
+                 elif isinstance(recv_req, AbortReq):
+                     self.abort_request(recv_req)
+                 else:
+                     raise ValueError(f"Invalid request: {recv_req}")
+
+             # Forward
+             self.forward_step()
+         except Exception:
+             logger.error("Exception in ModelTpServer:\n" + get_exception_traceback())
+             raise
+
+         # Return results
+         ret = self.out_pyobjs
+         self.out_pyobjs = []
+         return ret
+
+     @torch.inference_mode()
+     def forward_step(self):
+         new_batch = self.get_new_fill_batch()
+
+         if new_batch is not None:
+             # Run a new fill batch
+             self.forward_fill_batch(new_batch)
+             self.cache_filled_batch(new_batch)
+
+             if not new_batch.is_empty():
+                 if self.running_batch is None:
+                     self.running_batch = new_batch
+                 else:
+                     self.running_batch.merge(new_batch)
+         else:
+             # Run decode batch
+             if self.running_batch is not None:
+                 # Run a few decode batches continuously for reducing overhead
+                 for _ in range(10):
+                     self.num_generated_tokens += len(self.running_batch.reqs)
+                     self.forward_decode_batch(self.running_batch)
+
+                     # Print stats
+                     if self.tp_rank == 0:
+                         if self.decode_forward_ct % 40 == 0:
+                             num_used = self.max_total_num_tokens - (
+                                 self.token_to_kv_pool.available_size()
+                                 + self.tree_cache.evictable_size()
+                             )
+                             throughput = self.num_generated_tokens / (
+                                 time.time() - self.last_stats_tic
+                             )
+                             self.num_generated_tokens = 0
+                             self.last_stats_tic = time.time()
+                             logger.info(
+                                 f"[gpu_id={self.gpu_id}] Decode batch. "
+                                 f"#running-req: {len(self.running_batch.reqs)}, "
+                                 f"#token: {num_used}, "
+                                 f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+                                 f"gen throughput (token/s): {throughput:.2f}, "
+                                 f"#queue-req: {len(self.forward_queue)}"
+                             )
+
+                     if self.running_batch.is_empty():
+                         self.running_batch = None
+                         break
+
+                     if self.out_pyobjs and self.running_batch.reqs[0].stream:
+                         break
+             else:
+                 # Check the available size
+                 available_size = (
+                     self.token_to_kv_pool.available_size()
+                     + self.tree_cache.evictable_size()
+                 )
+                 if available_size != self.max_total_num_tokens:
+                     warnings.warn(
+                         "Warning: "
+                         f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
+                         "KV cache pool leak detected!"
+                     )
+
+     def handle_generate_request(
+         self,
+         recv_req: TokenizedGenerateReqInput,
+     ):
+         req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
+         req.pixel_values = recv_req.pixel_values
+         if req.pixel_values is not None:
+             req.pad_value = [
+                 (recv_req.image_hash) % self.model_config.vocab_size,
+                 (recv_req.image_hash >> 16) % self.model_config.vocab_size,
+                 (recv_req.image_hash >> 32) % self.model_config.vocab_size,
+                 (recv_req.image_hash >> 64) % self.model_config.vocab_size,
+             ]
+             req.image_size = recv_req.image_size
+             req.origin_input_ids, req.image_offset = (
+                 self.model_runner.model.pad_input_ids(
+                     req.origin_input_ids_unpadded,
+                     req.pad_value,
+                     req.pixel_values.shape,
+                     req.image_size,
+                 )
+             )
+         req.sampling_params = recv_req.sampling_params
+         req.return_logprob = recv_req.return_logprob
+         req.logprob_start_len = recv_req.logprob_start_len
+         req.top_logprobs_num = recv_req.top_logprobs_num
+         req.stream = recv_req.stream
+         req.tokenizer = self.tokenizer
+
+         # Init regex fsm
+         if req.sampling_params.regex is not None:
+             req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
+             if not self.disable_regex_jump_forward:
+                 req.jump_forward_map = self.jump_forward_cache.query(
+                     req.sampling_params.regex
+                 )
+
+         # Truncate prompts that are too long
+         req.origin_input_ids = req.origin_input_ids[: self.model_config.context_len - 1]
+         req.sampling_params.max_new_tokens = min(
+             req.sampling_params.max_new_tokens,
+             self.model_config.context_len - 1 - len(req.origin_input_ids),
+             self.max_total_num_tokens - 128 - len(req.origin_input_ids),
+         )
+         self.forward_queue.append(req)
+
+     def get_new_fill_batch(self):
+         if (
+             self.running_batch is not None
+             and len(self.running_batch.reqs) > self.max_running_requests
+         ):
+             return None
+
+         # Compute matched prefix length
+         for req in self.forward_queue:
+             assert (
+                 len(req.output_ids) == 0
+             ), "The output ids should be empty when prefilling"
+             req.input_ids = req.origin_input_ids + req.prev_output_ids
+             prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
+             if req.return_logprob:
+                 prefix_indices = prefix_indices[: req.logprob_start_len]
+             req.extend_input_len = len(req.input_ids) - len(prefix_indices)
+             req.prefix_indices = prefix_indices
+             req.last_node = last_node
+
+         # Get priority queue
+         self.forward_queue = self.scheduler.get_priority_queue(self.forward_queue)
+
+         # Add requests if there is available space
+         can_run_list = []
+         new_batch_total_tokens = 0
+         new_batch_input_tokens = 0
+
+         available_size = (
+             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
+         )
+         if self.running_batch:
+             available_size -= sum(
+                 [
+                     (r.max_new_tokens() - len(r.output_ids)) * self.new_token_ratio
+                     for r in self.running_batch.reqs
+                 ]
+             )
+
+         for req in self.forward_queue:
+             if req.return_logprob and req.normalized_prompt_logprob is None:
+                 # Need at least two tokens to compute normalized logprob
+                 if req.extend_input_len < 2:
+                     delta = 2 - req.extend_input_len
+                     req.extend_input_len += delta
+                     req.prefix_indices = req.prefix_indices[:-delta]
+                     if req.image_offset is not None:
+                         req.image_offset += delta
+             if req.extend_input_len == 0 and req.max_new_tokens() > 0:
+                 # Need at least one token to compute logits
+                 req.extend_input_len = 1
+                 req.prefix_indices = req.prefix_indices[:-1]
+                 if req.image_offset is not None:
+                     req.image_offset += 1
+
+             if (
+                 req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
+                 < available_size
+                 and req.extend_input_len + new_batch_input_tokens
+                 < self.max_prefill_tokens
+             ):
+                 delta = self.tree_cache.inc_lock_ref(req.last_node)
+                 available_size += delta
+
+                 if not (
+                     req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
+                     < available_size
+                 ):
+                     # Undo locking
+                     delta = self.tree_cache.dec_lock_ref(req.last_node)
+                     available_size += delta
+                     break
+                 else:
+                     # Add this request to the running batch
+                     can_run_list.append(req)
+                     new_batch_total_tokens += (
+                         req.extend_input_len + req.max_new_tokens()
+                     )
+                     new_batch_input_tokens += req.extend_input_len
+             else:
+                 break
+         if len(can_run_list) == 0:
+             return None
+
+         # Print stats
+         if self.tp_rank == 0:
+             running_req = (
+                 0 if self.running_batch is None else len(self.running_batch.reqs)
+             )
+             hit_tokens = sum(len(x.prefix_indices) for x in can_run_list)
+             self.tree_cache_metrics["total"] += (
+                 hit_tokens + new_batch_input_tokens
+             ) / 10**9
+             self.tree_cache_metrics["hit"] += hit_tokens / 10**9
+             tree_cache_hit_rate = (
+                 self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
+             )
+             logger.info(
+                 f"[gpu_id={self.gpu_id}] Prefill batch. "
+                 f"#new-seq: {len(can_run_list)}, "
+                 f"#new-token: {new_batch_input_tokens}, "
+                 f"#cached-token: {hit_tokens}, "
+                 f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
+                 f"#running-req: {running_req}, "
+                 f"#queue-req: {len(self.forward_queue) - len(can_run_list)}"
+             )
+             # logger.debug(
+             #     f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
+             #     f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
+             #     f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
+             #     f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
+             # )
+
+         # Return the new batch
+         new_batch = Batch.init_new(
+             can_run_list,
+             self.req_to_token_pool,
+             self.token_to_kv_pool,
+             self.tree_cache,
+         )
+         self.forward_queue = [x for x in self.forward_queue if x not in can_run_list]
+         return new_batch
+
+     def forward_fill_batch(self, batch: Batch):
+         # Build batch tensors
+         batch.prepare_for_extend(
+             self.model_config.vocab_size, self.int_token_logit_bias
+         )
+
+         if batch.extend_num_tokens != 0:
+             # Forward
+             logits, (
+                 prefill_token_logprobs,
+                 normalized_prompt_logprobs,
+                 prefill_top_logprobs,
+                 decode_top_logprobs,
+                 last_logprobs,
+             ) = self.model_runner.forward(batch, ForwardMode.EXTEND)
+             if prefill_token_logprobs is not None:
+                 prefill_token_logprobs = prefill_token_logprobs.tolist()
+                 normalized_prompt_logprobs = normalized_prompt_logprobs.tolist()
+
+             next_token_ids, _ = batch.sample(logits)
+
+             # Only transfer the selected logprobs of the next token to CPU to reduce overhead.
+             if last_logprobs is not None:
+                 last_token_logprobs = last_logprobs[
+                     torch.arange(len(batch.reqs), device=next_token_ids.device),
+                     next_token_ids,
+                 ].tolist()
+
+             next_token_ids = next_token_ids.tolist()
+         else:
+             next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
+
+         # Check finish condition
+         pt = 0
+         for i, req in enumerate(batch.reqs):
+             req.completion_tokens_wo_jump_forward += 1
+             req.output_ids = [next_token_ids[i]]
+             req.check_finished()
+
+             if req.return_logprob:
+                 if req.normalized_prompt_logprob is None:
+                     req.normalized_prompt_logprob = normalized_prompt_logprobs[i]
+
+                 if req.prefill_token_logprobs is None:
+                     # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
+                     req.prefill_token_logprobs = list(
+                         zip(
+                             prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
+                             req.input_ids[-req.extend_input_len + 1 :],
+                         )
+                     )
+                     if req.logprob_start_len == 0:
+                         req.prefill_token_logprobs = [
+                             (None, req.input_ids[0])
+                         ] + req.prefill_token_logprobs
+
+                 if req.last_update_decode_tokens != 0:
+                     req.decode_token_logprobs.extend(
+                         list(
+                             zip(
+                                 prefill_token_logprobs[
+                                     pt
+                                     + req.extend_input_len
+                                     - req.last_update_decode_tokens : pt
+                                     + req.extend_input_len
+                                     - 1
+                                 ],
+                                 req.input_ids[-req.last_update_decode_tokens + 1 :],
+                             )
+                         )
+                     )
+
+                 req.decode_token_logprobs.append(
+                     (last_token_logprobs[i], next_token_ids[i])
+                 )
+
+             if req.top_logprobs_num > 0:
+                 if req.prefill_top_logprobs is None:
+                     req.prefill_top_logprobs = prefill_top_logprobs[i]
+                     if req.logprob_start_len == 0:
+                         req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
+
+                 if req.last_update_decode_tokens != 0:
+                     req.decode_top_logprobs.extend(
+                         prefill_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
+                     )
+                 req.decode_top_logprobs.append(decode_top_logprobs[i])
+
+             pt += req.extend_input_len
+
+         self.handle_finished_requests(batch)
+
+     def cache_filled_batch(self, batch: Batch):
+         req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
+         for i, req in enumerate(batch.reqs):
+             new_prefix_indices, new_last_node = self.tree_cache.cache_req(
+                 token_ids=tuple(req.input_ids + req.output_ids)[:-1],
+                 last_uncached_pos=len(req.prefix_indices),
+                 req_pool_idx=req_pool_indices_cpu[i],
+                 del_in_memory_pool=False,
+                 old_last_node=req.last_node,
+             )
+             req.prefix_indices, req.last_node = new_prefix_indices, new_last_node
+
+     def forward_decode_batch(self, batch: Batch):
+         # check if decode out of memory
+         if not batch.check_decode_mem():
+             old_ratio = self.new_token_ratio
+             self.new_token_ratio = min(old_ratio + self.new_token_ratio_recovery, 1.0)
+
+             retracted_reqs = batch.retract_decode()
+             logger.info(
+                 "decode out of memory happened, "
+                 f"#retracted_reqs: {len(retracted_reqs)}, "
+                 f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
+             )
+             self.forward_queue.extend(retracted_reqs)
+         else:
+             self.new_token_ratio = max(
+                 self.new_token_ratio - self.new_token_ratio_decay,
+                 self.min_new_token_ratio,
+             )
+
+         if not self.disable_regex_jump_forward:
+             # check for jump-forward
+             jump_forward_reqs = batch.check_for_jump_forward(self.model_runner)
+
+             self.forward_queue.extend(jump_forward_reqs)
+             if batch.is_empty():
+                 return
+
+         # Update batch tensors
+         self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
+         batch.prepare_for_decode()
+
+         # Forward
+         logits, (
+             _,
+             _,
+             _,
+             decode_top_logprobs,
+             last_logprobs,
+         ) = self.model_runner.forward(batch, ForwardMode.DECODE)
+         next_token_ids, _ = batch.sample(logits)
+         next_token_ids = next_token_ids.tolist()
+
+         # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
+         if last_logprobs is not None:
+             new_token_logprobs = last_logprobs[
+                 torch.arange(len(batch.reqs)), next_token_ids
+             ].tolist()
+
+         # Check finish condition
+         for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
+             req.completion_tokens_wo_jump_forward += 1
+             req.output_ids.append(next_token_id)
+             req.check_finished()
+
+             if req.return_logprob:
+                 req.decode_token_logprobs.append((new_token_logprobs[i], next_token_id))
+
+             if req.top_logprobs_num > 0:
+                 req.decode_top_logprobs.append(decode_top_logprobs[i])
+
+         self.handle_finished_requests(batch)
+
+     def handle_finished_requests(self, batch: Batch):
+         output_rids = []
+         prev_output_strs = []
+         output_tokens = []
+         output_skip_special_tokens = []
+         output_spaces_between_special_tokens = []
+         output_meta_info = []
+         output_finished_reason: List[BaseFinishReason] = []
+         finished_indices = []
+         unfinished_indices = []
+         for i, req in enumerate(batch.reqs):
+             if req.finished():
+                 finished_indices.append(i)
+             else:
+                 unfinished_indices.append(i)
+
+             if req.finished() or (
+                 (
+                     req.stream
+                     and (
+                         self.decode_forward_ct % self.stream_interval == 0
+                         or len(req.output_ids) == 1
+                     )
+                 )
+             ):
+                 output_rids.append(req.rid)
+                 prev_output_strs.append(req.prev_output_str)
+                 output_tokens.append(req.output_ids)
+                 output_skip_special_tokens.append(
+                     req.sampling_params.skip_special_tokens
+                 )
+                 output_spaces_between_special_tokens.append(
+                     req.sampling_params.spaces_between_special_tokens
+                 )
+
+                 meta_info = {
+                     "prompt_tokens": len(req.origin_input_ids),
+                     "completion_tokens": len(req.prev_output_ids) + len(req.output_ids),
+                     "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
+                     "finish_reason": str(req.finished_reason),
+                 }
+                 if req.return_logprob:
+                     (
+                         meta_info["prefill_token_logprobs"],
+                         meta_info["decode_token_logprobs"],
+                         meta_info["prefill_top_logprobs"],
+                         meta_info["decode_top_logprobs"],
+                         meta_info["normalized_prompt_logprob"],
+                     ) = (
+                         req.prefill_token_logprobs,
+                         req.decode_token_logprobs,
+                         req.prefill_top_logprobs,
+                         req.decode_top_logprobs,
+                         req.normalized_prompt_logprob,
+                     )
+                 output_meta_info.append(meta_info)
+                 output_finished_reason.append(req.finished_reason)
+
+         # Send to detokenizer
+         if output_rids:
+             self.out_pyobjs.append(
+                 BatchTokenIDOut(
+                     output_rids,
+                     prev_output_strs,
+                     output_tokens,
+                     output_skip_special_tokens,
+                     output_spaces_between_special_tokens,
+                     output_meta_info,
+                     output_finished_reason,
+                 )
+             )
+
+         # Remove finished reqs
+         if finished_indices:
+             # Update radix cache
+             req_pool_indices_cpu = batch.req_pool_indices.tolist()
+             for i in finished_indices:
+                 req = batch.reqs[i]
+                 self.tree_cache.cache_req(
+                     token_ids=tuple(req.input_ids + req.output_ids)[:-1],
+                     last_uncached_pos=len(req.prefix_indices),
+                     req_pool_idx=req_pool_indices_cpu[i],
+                 )
+
+                 self.tree_cache.dec_lock_ref(req.last_node)
+
+             # Update batch tensors
+             if unfinished_indices:
+                 batch.filter_batch(unfinished_indices)
+             else:
+                 batch.reqs = []
+
+     def flush_cache(self):
+         if len(self.forward_queue) == 0 and (
+             self.running_batch is None or len(self.running_batch.reqs) == 0
+         ):
+             self.tree_cache.reset()
+             self.tree_cache_metrics = {"total": 0, "hit": 0}
+             self.regex_fsm_cache.reset()
+             self.req_to_token_pool.clear()
+             self.token_to_kv_pool.clear()
+             torch.cuda.empty_cache()
+             logger.info("Cache flushed successfully!")
+         else:
+             warnings.warn(
+                 f"Cache not flushed because there are pending requests. "
+                 f"#queue-req: {len(self.forward_queue)}, "
+                 f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
+             )
+
+     def abort_request(self, recv_req):
+         # Delete requests in the waiting queue
+         to_del = None
+         for i, req in enumerate(self.forward_queue):
+             if req.rid == recv_req.rid:
+                 to_del = i
+                 break
+
+         if to_del is not None:
+             del self.forward_queue[to_del]
+
+         # Delete requests in the running batch
+         if self.running_batch:
+             for req in self.running_batch.reqs:
+                 if req.rid == recv_req.rid:
+                     req.finished_reason = FINISH_ABORT()
+                     break
+
+
+ class ModelTpService(rpyc.Service):
+     exposed_ModelTpServer = ModelTpServer
+
+
+ class ModelTpClient:
+     def __init__(
+         self,
+         gpu_ids: List[int],
+         server_args: ServerArgs,
+         model_port_args: ModelPortArgs,
+         model_overide_args,
+     ):
+         server_args, model_port_args = obtain(server_args), obtain(model_port_args)
+         self.tp_size = server_args.tp_size
+
+         if self.tp_size * server_args.dp_size == 1:
+             # Init model
+             assert len(gpu_ids) == 1
+             self.model_server = ModelTpService().exposed_ModelTpServer(
+                 0,
+                 gpu_ids[0],
+                 server_args,
+                 model_port_args,
+                 model_overide_args,
+             )
+
+             # Wrap functions
+             def async_wrap(f):
+                 async def _func(*args, **kwargs):
+                     return f(*args, **kwargs)
+
+                 return _func
+
+             self.step = async_wrap(self.model_server.exposed_step)
+         else:
+             with ThreadPoolExecutor(self.tp_size) as executor:
+                 # Launch model processes
+                 rets = executor.map(
+                     lambda args: start_rpyc_process(*args),
+                     [(ModelTpService, p) for p in model_port_args.model_tp_ports],
+                 )
+                 self.model_services = [x[0] for x in rets]
+                 self.procs = [x[1] for x in rets]
+
+                 # Init model
+                 def init_model(i):
+                     return self.model_services[i].ModelTpServer(
+                         gpu_ids[i],
+                         i,
+                         server_args,
+                         model_port_args,
+                         model_overide_args,
+                     )
+
+                 self.model_servers = executor.map(init_model, range(self.tp_size))
+
+             # Wrap functions
+             def async_wrap(func_name):
+                 fs = [rpyc.async_(getattr(m, func_name)) for m in self.model_servers]
+
+                 async def _func(*args, **kwargs):
+                     tasks = [f(*args, **kwargs) for f in fs]
+                     await asyncio.gather(*[asyncio.to_thread(t.wait) for t in tasks])
+                     return obtain(tasks[0].value)
+
+                 return _func
+
+             self.step = async_wrap("step")
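
Note: the snippet below is not part of the diff. It is a minimal usage sketch of the new ModelTpClient introduced in this release, for the single-process path only (tp_size * dp_size == 1), where step() is a thin async wrapper around ModelTpServer.exposed_step. The model path, port number, and the ServerArgs/ModelPortArgs constructor fields shown here are assumptions; consult sglang/srt/server_args.py in 0.1.17 for the authoritative definitions.

    # Illustrative sketch only; requires a GPU and downloadable model weights.
    import asyncio

    from sglang.srt.managers.controller.tp_worker import ModelTpClient
    from sglang.srt.server_args import ModelPortArgs, ServerArgs


    async def run_one_step():
        # Assumed field names for both argument dataclasses.
        server_args = ServerArgs(model_path="TinyLlama/TinyLlama-1.1B-Chat-v1.0")
        model_port_args = ModelPortArgs(nccl_port=28888, model_tp_ports=[])

        client = ModelTpClient(
            gpu_ids=[0],
            server_args=server_args,
            model_port_args=model_port_args,
            model_overide_args={},
        )

        # With tp_size * dp_size == 1, an empty request list still runs one
        # scheduling/forward iteration and returns any pending BatchTokenIDOut
        # objects accumulated by the worker.
        outputs = await client.step([])
        print(outputs)


    if __name__ == "__main__":
        asyncio.run(run_one_step())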