sglang 0.1.14__py3-none-any.whl → 0.1.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. sglang/__init__.py +59 -2
  2. sglang/api.py +40 -11
  3. sglang/backend/anthropic.py +17 -3
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +160 -12
  6. sglang/backend/runtime_endpoint.py +62 -27
  7. sglang/backend/vertexai.py +1 -0
  8. sglang/bench_latency.py +320 -0
  9. sglang/global_config.py +24 -3
  10. sglang/lang/chat_template.py +122 -6
  11. sglang/lang/compiler.py +2 -2
  12. sglang/lang/interpreter.py +206 -98
  13. sglang/lang/ir.py +98 -34
  14. sglang/lang/tracer.py +6 -4
  15. sglang/launch_server.py +4 -1
  16. sglang/launch_server_llavavid.py +32 -0
  17. sglang/srt/constrained/__init__.py +14 -6
  18. sglang/srt/constrained/fsm_cache.py +9 -2
  19. sglang/srt/constrained/jump_forward.py +113 -24
  20. sglang/srt/conversation.py +4 -2
  21. sglang/srt/flush_cache.py +18 -0
  22. sglang/srt/hf_transformers_utils.py +144 -3
  23. sglang/srt/layers/context_flashattention_nopad.py +1 -0
  24. sglang/srt/layers/extend_attention.py +20 -1
  25. sglang/srt/layers/fused_moe.py +596 -0
  26. sglang/srt/layers/logits_processor.py +190 -61
  27. sglang/srt/layers/radix_attention.py +62 -53
  28. sglang/srt/layers/token_attention.py +21 -9
  29. sglang/srt/managers/controller/cuda_graph_runner.py +196 -0
  30. sglang/srt/managers/controller/dp_worker.py +113 -0
  31. sglang/srt/managers/controller/infer_batch.py +908 -0
  32. sglang/srt/managers/controller/manager_multi.py +195 -0
  33. sglang/srt/managers/controller/manager_single.py +177 -0
  34. sglang/srt/managers/controller/model_runner.py +359 -0
  35. sglang/srt/managers/{router → controller}/radix_cache.py +102 -53
  36. sglang/srt/managers/controller/schedule_heuristic.py +65 -0
  37. sglang/srt/managers/controller/tp_worker.py +813 -0
  38. sglang/srt/managers/detokenizer_manager.py +42 -40
  39. sglang/srt/managers/io_struct.py +44 -10
  40. sglang/srt/managers/tokenizer_manager.py +224 -82
  41. sglang/srt/memory_pool.py +52 -59
  42. sglang/srt/model_config.py +97 -2
  43. sglang/srt/models/chatglm.py +399 -0
  44. sglang/srt/models/commandr.py +369 -0
  45. sglang/srt/models/dbrx.py +406 -0
  46. sglang/srt/models/gemma.py +34 -38
  47. sglang/srt/models/gemma2.py +436 -0
  48. sglang/srt/models/grok.py +738 -0
  49. sglang/srt/models/llama2.py +47 -37
  50. sglang/srt/models/llama_classification.py +107 -0
  51. sglang/srt/models/llava.py +92 -27
  52. sglang/srt/models/llavavid.py +298 -0
  53. sglang/srt/models/minicpm.py +366 -0
  54. sglang/srt/models/mixtral.py +302 -127
  55. sglang/srt/models/mixtral_quant.py +372 -0
  56. sglang/srt/models/qwen.py +40 -35
  57. sglang/srt/models/qwen2.py +33 -36
  58. sglang/srt/models/qwen2_moe.py +473 -0
  59. sglang/srt/models/stablelm.py +33 -39
  60. sglang/srt/models/yivl.py +19 -26
  61. sglang/srt/openai_api_adapter.py +411 -0
  62. sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +44 -19
  63. sglang/srt/sampling_params.py +2 -0
  64. sglang/srt/server.py +197 -481
  65. sglang/srt/server_args.py +190 -74
  66. sglang/srt/utils.py +460 -95
  67. sglang/test/test_programs.py +73 -10
  68. sglang/test/test_utils.py +226 -7
  69. sglang/utils.py +97 -27
  70. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/METADATA +74 -45
  71. sglang-0.1.21.dist-info/RECORD +82 -0
  72. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/WHEEL +1 -1
  73. sglang/srt/backend_config.py +0 -13
  74. sglang/srt/managers/router/infer_batch.py +0 -503
  75. sglang/srt/managers/router/manager.py +0 -79
  76. sglang/srt/managers/router/model_rpc.py +0 -686
  77. sglang/srt/managers/router/model_runner.py +0 -514
  78. sglang/srt/managers/router/scheduler.py +0 -70
  79. sglang-0.1.14.dist-info/RECORD +0 -64
  80. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/LICENSE +0 -0
  81. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/top_level.txt +0 -0
sglang/srt/managers/controller/tp_worker.py (new file)
@@ -0,0 +1,813 @@
+ """A tensor parallel worker."""
+
+ import asyncio
+ import logging
+ import time
+ import warnings
+ from concurrent.futures import ThreadPoolExecutor
+ from typing import List, Optional
+
+ import rpyc
+ import torch
+ from rpyc.utils.classic import obtain
+
+ from sglang.global_config import global_config
+ from sglang.srt.constrained.fsm_cache import FSMCache
+ from sglang.srt.constrained.jump_forward import JumpForwardCache
+ from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
+ from sglang.srt.managers.controller.infer_batch import (
+     FINISH_ABORT,
+     BaseFinishReason,
+     Batch,
+     ForwardMode,
+     Req,
+ )
+ from sglang.srt.managers.controller.model_runner import ModelRunner
+ from sglang.srt.managers.controller.radix_cache import RadixCache
+ from sglang.srt.managers.controller.schedule_heuristic import ScheduleHeuristic
+ from sglang.srt.managers.io_struct import (
+     AbortReq,
+     BatchTokenIDOut,
+     FlushCacheReq,
+     TokenizedGenerateReqInput,
+ )
+ from sglang.srt.model_config import ModelConfig
+ from sglang.srt.server_args import ModelPortArgs, ServerArgs
+ from sglang.srt.utils import (
+     connect_rpyc_service,
+     get_int_token_logit_bias,
+     is_multimodal_model,
+     set_random_seed,
+     start_rpyc_service_process,
+     suppress_other_loggers,
+ )
+ from sglang.utils import get_exception_traceback
+
+ logger = logging.getLogger("srt.tp_worker")
+
+
+ class ModelTpServer:
+     def __init__(
+         self,
+         gpu_id: int,
+         tp_rank: int,
+         server_args: ServerArgs,
+         model_port_args: ModelPortArgs,
+         model_overide_args: dict,
+     ):
+         server_args, model_port_args = obtain(server_args), obtain(model_port_args)
+         suppress_other_loggers()
+
+         # Copy arguments
+         self.gpu_id = gpu_id
+         self.tp_rank = tp_rank
+         self.tp_size = server_args.tp_size
+         self.dp_size = server_args.dp_size
+         self.schedule_heuristic = server_args.schedule_heuristic
+         self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
+
+         # Init model and tokenizer
+         self.model_config = ModelConfig(
+             server_args.model_path,
+             server_args.trust_remote_code,
+             context_length=server_args.context_length,
+             model_overide_args=model_overide_args,
+         )
+         self.model_runner = ModelRunner(
+             model_config=self.model_config,
+             mem_fraction_static=server_args.mem_fraction_static,
+             gpu_id=gpu_id,
+             tp_rank=tp_rank,
+             tp_size=server_args.tp_size,
+             nccl_port=model_port_args.nccl_port,
+             server_args=server_args,
+         )
+
+         if is_multimodal_model(server_args.model_path):
+             self.processor = get_processor(
+                 server_args.tokenizer_path,
+                 tokenizer_mode=server_args.tokenizer_mode,
+                 trust_remote_code=server_args.trust_remote_code,
+             )
+             self.tokenizer = self.processor.tokenizer
+         else:
+             self.tokenizer = get_tokenizer(
+                 server_args.tokenizer_path,
+                 tokenizer_mode=server_args.tokenizer_mode,
+                 trust_remote_code=server_args.trust_remote_code,
+             )
+         self.max_total_num_tokens = self.model_runner.max_total_num_tokens
+         self.max_prefill_tokens = (
+             16384
+             if server_args.max_prefill_tokens is None
+             else server_args.max_prefill_tokens
+         )
+         self.max_running_requests = (
+             self.max_total_num_tokens // 2
+             if server_args.max_running_requests is None
+             else server_args.max_running_requests
+         )
+         self.int_token_logit_bias = torch.tensor(
+             get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
+         )
+         set_random_seed(server_args.random_seed)
+
+         # Print info
+         logger.info(
+             f"[gpu_id={self.gpu_id}] "
+             f"max_total_num_tokens={self.max_total_num_tokens}, "
+             f"max_prefill_tokens={self.max_prefill_tokens}, "
+             f"context_len={self.model_config.context_len}"
+         )
+         if self.tp_rank == 0:
+             logger.info(
+                 f"[gpu_id={self.gpu_id}] "
+                 f"server_args: {server_args.print_mode_args()}"
+             )
+
+         # Init cache
+         self.tree_cache = RadixCache(
+             req_to_token_pool=self.model_runner.req_to_token_pool,
+             token_to_kv_pool=self.model_runner.token_to_kv_pool,
+             disable=server_args.disable_radix_cache,
+         )
+         self.tree_cache_metrics = {"total": 0, "hit": 0}
+         self.scheduler = ScheduleHeuristic(
+             self.schedule_heuristic,
+             self.max_running_requests,
+             self.max_prefill_tokens,
+             self.max_total_num_tokens,
+             self.tree_cache,
+         )
+         self.req_to_token_pool = self.model_runner.req_to_token_pool
+         self.token_to_kv_pool = self.model_runner.token_to_kv_pool
+
+         # Init running status
+         self.forward_queue: List[Req] = []
+         self.running_batch: Batch = None
+         self.out_pyobjs = []
+         self.decode_forward_ct = 0
+         self.stream_interval = server_args.stream_interval
+         self.num_generated_tokens = 0
+         self.last_stats_tic = time.time()
+
+         # Init the FSM cache for constrained generation
+         self.regex_fsm_cache = FSMCache(
+             server_args.tokenizer_path,
+             {
+                 "tokenizer_mode": server_args.tokenizer_mode,
+                 "trust_remote_code": server_args.trust_remote_code,
+             },
+         )
+         self.jump_forward_cache = JumpForwardCache()
+
+         # Init new token estimation
+         assert (
+             server_args.schedule_conservativeness >= 0
+         ), "Invalid schedule_conservativeness"
+         self.new_token_ratio = min(
+             global_config.base_new_token_ratio * server_args.schedule_conservativeness,
+             1.0,
+         )
+         self.min_new_token_ratio = min(
+             global_config.base_min_new_token_ratio
+             * server_args.schedule_conservativeness,
+             1.0,
+         )
+         self.new_token_ratio_decay = global_config.new_token_ratio_decay
+         self.new_token_ratio_recovery = global_config.new_token_ratio_recovery
+
+     def exposed_step(self, recv_reqs):
+         if not isinstance(recv_reqs, list):
+             recv_reqs = obtain(recv_reqs)
+
+         try:
+             # Recv requests
+             for recv_req in recv_reqs:
+                 if isinstance(recv_req, TokenizedGenerateReqInput):
+                     self.handle_generate_request(recv_req)
+                 elif isinstance(recv_req, FlushCacheReq):
+                     self.flush_cache()
+                 elif isinstance(recv_req, AbortReq):
+                     self.abort_request(recv_req)
+                 else:
+                     raise ValueError(f"Invalid request: {recv_req}")
+
+             # Forward
+             self.forward_step()
+         except Exception:
+             logger.error("Exception in ModelTpServer:\n" + get_exception_traceback())
+             raise
+
+         # Return results
+         ret = self.out_pyobjs
+         self.out_pyobjs = []
+         return ret
+
+     @torch.inference_mode()
+     def forward_step(self):
+         new_batch = self.get_new_prefill_batch()
+
+         if new_batch is not None:
+             # Run a new prefill batch
+             self.forward_prefill_batch(new_batch)
+             self.cache_filled_batch(new_batch)
+
+             if not new_batch.is_empty():
+                 if self.running_batch is None:
+                     self.running_batch = new_batch
+                 else:
+                     self.running_batch.merge(new_batch)
+         else:
+             # Run a decode batch
+             if self.running_batch is not None:
+                 # Run a few decode batches continuously for reducing overhead
+                 for _ in range(global_config.num_continue_decode_steps):
+                     self.num_generated_tokens += len(self.running_batch.reqs)
+                     self.forward_decode_batch(self.running_batch)
+
+                     # Print stats
+                     if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
+                         num_used = self.max_total_num_tokens - (
+                             self.token_to_kv_pool.available_size()
+                             + self.tree_cache.evictable_size()
+                         )
+                         throughput = self.num_generated_tokens / (
+                             time.time() - self.last_stats_tic
+                         )
+                         self.num_generated_tokens = 0
+                         self.last_stats_tic = time.time()
+                         logger.info(
+                             f"[gpu_id={self.gpu_id}] Decode batch. "
+                             f"#running-req: {len(self.running_batch.reqs)}, "
+                             f"#token: {num_used}, "
+                             f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+                             f"gen throughput (token/s): {throughput:.2f}, "
+                             f"#queue-req: {len(self.forward_queue)}"
+                         )
+
+                     if self.running_batch.is_empty():
+                         self.running_batch = None
+                         break
+
+                     if self.out_pyobjs and self.running_batch.has_stream():
+                         break
+             else:
+                 # Check the available size
+                 available_size = (
+                     self.token_to_kv_pool.available_size()
+                     + self.tree_cache.evictable_size()
+                 )
+                 if available_size != self.max_total_num_tokens:
+                     warnings.warn(
+                         "Warning: "
+                         f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
+                         "KV cache pool leak detected!"
+                     )
+
+     def handle_generate_request(
+         self,
+         recv_req: TokenizedGenerateReqInput,
+     ):
+         req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
+         req.pixel_values = recv_req.pixel_values
+         if req.pixel_values is not None:
+             req.pad_value = [
+                 (recv_req.image_hash) % self.model_config.vocab_size,
+                 (recv_req.image_hash >> 16) % self.model_config.vocab_size,
+                 (recv_req.image_hash >> 32) % self.model_config.vocab_size,
+                 (recv_req.image_hash >> 64) % self.model_config.vocab_size,
+             ]
+             req.image_size = recv_req.image_size
+             (
+                 req.origin_input_ids,
+                 req.image_offset,
+             ) = self.model_runner.model.pad_input_ids(
+                 req.origin_input_ids_unpadded,
+                 req.pad_value,
+                 req.pixel_values.shape,
+                 req.image_size,
+             )
+         req.sampling_params = recv_req.sampling_params
+         req.return_logprob = recv_req.return_logprob
+         req.logprob_start_len = recv_req.logprob_start_len
+         req.top_logprobs_num = recv_req.top_logprobs_num
+         req.stream = recv_req.stream
+         req.tokenizer = self.tokenizer
+
+         # Init regex fsm
+         if req.sampling_params.regex is not None:
+             req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
+             if not self.disable_regex_jump_forward:
+                 req.jump_forward_map = self.jump_forward_cache.query(
+                     req.sampling_params.regex
+                 )
+
+         # Truncate prompts that are too long
+         req.origin_input_ids = req.origin_input_ids[: self.model_config.context_len - 1]
+         req.sampling_params.max_new_tokens = min(
+             req.sampling_params.max_new_tokens,
+             self.model_config.context_len - 1 - len(req.origin_input_ids),
+             self.max_total_num_tokens - 128 - len(req.origin_input_ids),
+         )
+         self.forward_queue.append(req)
+
+     def get_new_prefill_batch(self) -> Optional[Batch]:
+         running_bs = (
+             len(self.running_batch.reqs) if self.running_batch is not None else 0
+         )
+         if running_bs >= self.max_running_requests:
+             return
+
+         # Compute matched prefix length
+         for req in self.forward_queue:
+             req.input_ids = req.origin_input_ids + req.output_ids
+             prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
+             if req.return_logprob:
+                 prefix_indices = prefix_indices[: req.logprob_start_len]
+             req.extend_input_len = len(req.input_ids) - len(prefix_indices)
+             req.prefix_indices = prefix_indices
+             req.last_node = last_node
+
+         # Get priority queue
+         self.forward_queue = self.scheduler.get_priority_queue(self.forward_queue)
+
+         # Add requests if there is available space
+         can_run_list = []
+         new_batch_total_tokens = 0
+         new_batch_input_tokens = 0
+
+         available_size = (
+             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
+         )
+         if self.running_batch:
+             available_size -= sum(
+                 [
+                     (r.sampling_params.max_new_tokens - len(r.output_ids)) * self.new_token_ratio
+                     for r in self.running_batch.reqs
+                 ]
+             )
+
+         for req in self.forward_queue:
+             if req.return_logprob and req.normalized_prompt_logprob is None:
+                 # Need at least two tokens to compute normalized logprob
+                 if req.extend_input_len < 2:
+                     delta = 2 - req.extend_input_len
+                     req.extend_input_len += delta
+                     req.prefix_indices = req.prefix_indices[:-delta]
+                     if req.image_offset is not None:
+                         req.image_offset += delta
+             if req.extend_input_len == 0 and req.sampling_params.max_new_tokens > 0:
+                 # Need at least one token to compute logits
+                 req.extend_input_len = 1
+                 req.prefix_indices = req.prefix_indices[:-1]
+                 if req.image_offset is not None:
+                     req.image_offset += 1
+
+             if (
+                 req.extend_input_len + req.sampling_params.max_new_tokens + new_batch_total_tokens
+                 < available_size
+                 and (
+                     req.extend_input_len + new_batch_input_tokens
+                     <= self.max_prefill_tokens
+                     or len(can_run_list) == 0
+                 )
+             ):
+                 delta = self.tree_cache.inc_lock_ref(req.last_node)
+                 available_size += delta
+
+                 if not (
+                     req.extend_input_len + req.sampling_params.max_new_tokens + new_batch_total_tokens
+                     < available_size
+                 ):
+                     # Undo locking
+                     delta = self.tree_cache.dec_lock_ref(req.last_node)
+                     available_size += delta
+                     break
+                 else:
+                     # Add this request to the running batch
+                     can_run_list.append(req)
+                     new_batch_total_tokens += (
+                         req.extend_input_len + req.sampling_params.max_new_tokens
+                     )
+                     new_batch_input_tokens += req.extend_input_len
+             else:
+                 break
+
+             if running_bs + len(can_run_list) >= self.max_running_requests:
+                 break
+
+         if len(can_run_list) == 0:
+             return None
+
+         # Print stats
+         if self.tp_rank == 0:
+             hit_tokens = sum(len(x.prefix_indices) for x in can_run_list)
+             self.tree_cache_metrics["total"] += (
+                 hit_tokens + new_batch_input_tokens
+             ) / 10**9
+             self.tree_cache_metrics["hit"] += hit_tokens / 10**9
+             tree_cache_hit_rate = (
+                 self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
+             )
+             logger.info(
+                 f"[gpu_id={self.gpu_id}] Prefill batch. "
+                 f"#new-seq: {len(can_run_list)}, "
+                 f"#new-token: {new_batch_input_tokens}, "
+                 f"#cached-token: {hit_tokens}, "
+                 f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
+                 f"#running-req: {running_bs}, "
+                 f"#queue-req: {len(self.forward_queue) - len(can_run_list)}"
+             )
+             # logger.debug(
+             #     f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
+             #     f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
+             #     f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
+             #     f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
+             # )
+
+         # Return the new batch
+         new_batch = Batch.init_new(
+             can_run_list,
+             self.req_to_token_pool,
+             self.token_to_kv_pool,
+             self.tree_cache,
+         )
+         self.forward_queue = [x for x in self.forward_queue if x not in can_run_list]
+         return new_batch
+
+     def forward_prefill_batch(self, batch: Batch):
+         # Build batch tensors
+         batch.prepare_for_extend(
+             self.model_config.vocab_size, self.int_token_logit_bias
+         )
+
+         # Forward and sample the next tokens
+         if batch.extend_num_tokens != 0:
+             output = self.model_runner.forward(batch, ForwardMode.EXTEND)
+             next_token_ids, _ = batch.sample(output.next_token_logits)
+
+             # Move logprobs to cpu
+             if output.next_token_logprobs is not None:
+                 output.next_token_logprobs = output.next_token_logprobs[
+                     torch.arange(len(next_token_ids), device=next_token_ids.device),
+                     next_token_ids,
+                 ].tolist()
+                 output.prefill_token_logprobs = output.prefill_token_logprobs.tolist()
+                 output.normalized_prompt_logprobs = (
+                     output.normalized_prompt_logprobs.tolist()
+                 )
+
+             next_token_ids = next_token_ids.tolist()
+         else:
+             next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
+
+         # Check finish conditions
+         pt = 0
+         for i, req in enumerate(batch.reqs):
+             req.completion_tokens_wo_jump_forward += 1
+             req.output_ids.append(next_token_ids[i])
+             req.check_finished()
+
+             if req.return_logprob:
+                 self.add_logprob_return_values(i, req, pt, next_token_ids, output)
+                 pt += req.extend_input_len
+
+         self.handle_finished_requests(batch)
+
+     def add_logprob_return_values(self, i, req, pt, next_token_ids, output):
+         if req.normalized_prompt_logprob is None:
+             req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
+
+         if req.prefill_token_logprobs is None:
+             # If logprob_start_len > 0, then first logprob_start_len prompt tokens will be ignored.
+             req.prefill_token_logprobs = list(
+                 zip(
+                     output.prefill_token_logprobs[pt : pt + req.extend_input_len - 1],
+                     req.input_ids[-req.extend_input_len + 1 :],
+                 )
+             )
+             if req.logprob_start_len == 0:
+                 req.prefill_token_logprobs = [
+                     (None, req.input_ids[0])
+                 ] + req.prefill_token_logprobs
+
+         if req.last_update_decode_tokens != 0:
+             req.decode_token_logprobs.extend(
+                 list(
+                     zip(
+                         output.prefill_token_logprobs[
+                             pt
+                             + req.extend_input_len
+                             - req.last_update_decode_tokens : pt
+                             + req.extend_input_len
+                             - 1
+                         ],
+                         req.input_ids[-req.last_update_decode_tokens + 1 :],
+                     )
+                 )
+             )
+
+         req.decode_token_logprobs.append(
+             (output.next_token_logprobs[i], next_token_ids[i])
+         )
+
+         if req.top_logprobs_num > 0:
+             if req.prefill_top_logprobs is None:
+                 req.prefill_top_logprobs = output.prefill_top_logprobs[i]
+                 if req.logprob_start_len == 0:
+                     req.prefill_top_logprobs = [None] + req.prefill_top_logprobs
+
+             if req.last_update_decode_tokens != 0:
+                 req.decode_top_logprobs.extend(
+                     output.prefill_top_logprobs[i][-req.last_update_decode_tokens + 1 :]
+                 )
+             req.decode_top_logprobs.append(output.decode_top_logprobs[i])
+
+     def cache_filled_batch(self, batch: Batch):
+         req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
+         for i, req in enumerate(batch.reqs):
+             new_prefix_indices, new_last_node = self.tree_cache.cache_req(
+                 token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
+                 last_uncached_pos=len(req.prefix_indices),
+                 req_pool_idx=req_pool_indices_cpu[i],
+                 del_in_memory_pool=False,
+                 old_last_node=req.last_node,
+             )
+             req.prefix_indices, req.last_node = new_prefix_indices, new_last_node
+
+     def forward_decode_batch(self, batch: Batch):
+         # Check if decode out of memory
+         if not batch.check_decode_mem():
+             old_ratio = self.new_token_ratio
+             self.new_token_ratio = min(old_ratio + self.new_token_ratio_recovery, 1.0)
+
+             retracted_reqs = batch.retract_decode()
+             logger.info(
+                 "decode out of memory happened, "
+                 f"#retracted_reqs: {len(retracted_reqs)}, "
+                 f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
+             )
+             self.forward_queue.extend(retracted_reqs)
+         else:
+             self.new_token_ratio = max(
+                 self.new_token_ratio - self.new_token_ratio_decay,
+                 self.min_new_token_ratio,
+             )
+
+         if not self.disable_regex_jump_forward:
+             # Check for jump-forward
+             jump_forward_reqs = batch.check_for_jump_forward(self.model_runner)
+             self.forward_queue.extend(jump_forward_reqs)
+             if batch.is_empty():
+                 return
+
+         # Update batch tensors
+         self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
+         batch.prepare_for_decode()
+
+         # Forward and sample the next tokens
+         output = self.model_runner.forward(batch, ForwardMode.DECODE)
+         next_token_ids, _ = batch.sample(output.next_token_logits)
+
+         # Move logprobs to cpu
+         if output.next_token_logprobs is not None:
+             next_token_logprobs = output.next_token_logprobs[
+                 torch.arange(len(next_token_ids), device=next_token_ids.device),
+                 next_token_ids,
+             ].tolist()
+
+         next_token_ids = next_token_ids.tolist()
+
+         # Check finish condition
+         for i, (req, next_token_id) in enumerate(zip(batch.reqs, next_token_ids)):
+             req.completion_tokens_wo_jump_forward += 1
+             req.output_ids.append(next_token_id)
+             req.check_finished()
+
+             if req.return_logprob:
+                 req.decode_token_logprobs.append(
+                     (next_token_logprobs[i], next_token_id)
+                 )
+                 if req.top_logprobs_num > 0:
+                     req.decode_top_logprobs.append(output.decode_top_logprobs[i])
+
+         self.handle_finished_requests(batch)
+
+     def handle_finished_requests(self, batch: Batch):
+         output_rids = []
+         decoded_texts = []
+         surr_output_ids = []
+         read_output_ids = []
+         output_skip_special_tokens = []
+         output_spaces_between_special_tokens = []
+         output_meta_info = []
+         output_finished_reason: List[BaseFinishReason] = []
+         finished_indices = []
+         unfinished_indices = []
+         for i, req in enumerate(batch.reqs):
+             if req.finished():
+                 finished_indices.append(i)
+             else:
+                 unfinished_indices.append(i)
+
+             if req.finished() or (
+                 (
+                     req.stream
+                     and (
+                         self.decode_forward_ct % self.stream_interval == 0
+                         or len(req.output_ids) == 1
+                     )
+                 )
+             ):
+                 output_rids.append(req.rid)
+                 decoded_texts.append(req.decoded_text)
+                 surr_ids, read_ids, _ = req.init_detokenize_incrementally()
+                 surr_output_ids.append(surr_ids)
+                 read_output_ids.append(read_ids)
+                 output_skip_special_tokens.append(
+                     req.sampling_params.skip_special_tokens
+                 )
+                 output_spaces_between_special_tokens.append(
+                     req.sampling_params.spaces_between_special_tokens
+                 )
+
+                 meta_info = {
+                     "prompt_tokens": len(req.origin_input_ids),
+                     "completion_tokens": len(req.output_ids),
+                     "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
+                     "finish_reason": str(req.finished_reason),
+                 }
+                 if req.return_logprob:
+                     (
+                         meta_info["prefill_token_logprobs"],
+                         meta_info["decode_token_logprobs"],
+                         meta_info["prefill_top_logprobs"],
+                         meta_info["decode_top_logprobs"],
+                         meta_info["normalized_prompt_logprob"],
+                     ) = (
+                         req.prefill_token_logprobs,
+                         req.decode_token_logprobs,
+                         req.prefill_top_logprobs,
+                         req.decode_top_logprobs,
+                         req.normalized_prompt_logprob,
+                     )
+                 output_meta_info.append(meta_info)
+                 output_finished_reason.append(req.finished_reason)
+
+         # Send to detokenizer
+         if output_rids:
+             self.out_pyobjs.append(
+                 BatchTokenIDOut(
+                     output_rids,
+                     decoded_texts,
+                     surr_output_ids,
+                     read_output_ids,
+                     output_skip_special_tokens,
+                     output_spaces_between_special_tokens,
+                     output_meta_info,
+                     output_finished_reason,
+                 )
+             )
+
+         # Remove finished reqs
+         if finished_indices:
+             # Update radix cache
+             req_pool_indices_cpu = batch.req_pool_indices.tolist()
+             for i in finished_indices:
+                 req = batch.reqs[i]
+                 self.tree_cache.cache_req(
+                     token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
+                     last_uncached_pos=len(req.prefix_indices),
+                     req_pool_idx=req_pool_indices_cpu[i],
+                 )
+
+                 self.tree_cache.dec_lock_ref(req.last_node)
+
+             # Update batch tensors
+             if unfinished_indices:
+                 batch.filter_batch(unfinished_indices)
+             else:
+                 batch.reqs = []
+
+     def flush_cache(self):
+         if len(self.forward_queue) == 0 and (
+             self.running_batch is None or len(self.running_batch.reqs) == 0
+         ):
+             self.tree_cache.reset()
+             self.tree_cache_metrics = {"total": 0, "hit": 0}
+             self.regex_fsm_cache.reset()
+             self.req_to_token_pool.clear()
+             self.token_to_kv_pool.clear()
+             torch.cuda.empty_cache()
+             logger.info("Cache flushed successfully!")
+         else:
+             warnings.warn(
+                 f"Cache not flushed because there are pending requests. "
+                 f"#queue-req: {len(self.forward_queue)}, "
+                 f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
+             )
+
+     def abort_request(self, recv_req):
+         # Delete requests in the waiting queue
+         to_del = None
+         for i, req in enumerate(self.forward_queue):
+             if req.rid == recv_req.rid:
+                 to_del = i
+                 break
+
+         if to_del is not None:
+             del self.forward_queue[to_del]
+
+         # Delete requests in the running batch
+         if self.running_batch:
+             for req in self.running_batch.reqs:
+                 if req.rid == recv_req.rid:
+                     req.finished_reason = FINISH_ABORT()
+                     break
+
+
+ class ModelTpService(rpyc.Service):
+     exposed_ModelTpServer = ModelTpServer
+
+
+ class ModelTpClient:
+     def __init__(
+         self,
+         gpu_ids: List[int],
+         server_args: ServerArgs,
+         model_port_args: ModelPortArgs,
+         model_overide_args,
+     ):
+         server_args, model_port_args = obtain(server_args), obtain(model_port_args)
+         self.tp_size = server_args.tp_size
+
+         if self.tp_size * server_args.dp_size == 1:
+             # Init model
+             assert len(gpu_ids) == 1
+             self.model_server = ModelTpService().exposed_ModelTpServer(
+                 gpu_ids[0],
+                 0,
+                 server_args,
+                 model_port_args,
+                 model_overide_args,
+             )
+
+             # Wrap functions
+             def async_wrap(f):
+                 async def _func(*args, **kwargs):
+                     return f(*args, **kwargs)
+
+                 return _func
+
+             self.step = async_wrap(self.model_server.exposed_step)
+         else:
+             with ThreadPoolExecutor(self.tp_size) as executor:
+                 # Launch model processes
+                 if server_args.nnodes == 1:
+                     self.procs = list(
+                         executor.map(
+                             lambda args: start_rpyc_service_process(*args),
+                             [
+                                 (ModelTpService, p)
+                                 for p in model_port_args.model_tp_ports
+                             ],
+                         )
+                     )
+                     addrs = [("localhost", p) for p in model_port_args.model_tp_ports]
+                 else:
+                     addrs = [
+                         (ip, port)
+                         for ip, port in zip(
+                             model_port_args.model_tp_ips, model_port_args.model_tp_ports
+                         )
+                     ]
+
+                 self.model_services = list(
+                     executor.map(lambda args: connect_rpyc_service(*args), addrs)
+                 )
+
+                 # Init model
+                 def init_model(i):
+                     return self.model_services[i].ModelTpServer(
+                         gpu_ids[i],
+                         i,
+                         server_args,
+                         model_port_args,
+                         model_overide_args,
+                     )
+
+                 self.model_servers = list(executor.map(init_model, range(self.tp_size)))
+
+             # Wrap functions
+             def async_wrap(func_name):
+                 fs = [rpyc.async_(getattr(m, func_name)) for m in self.model_servers]
+
+                 async def _func(*args, **kwargs):
+                     tasks = [f(*args, **kwargs) for f in fs]
+                     await asyncio.gather(*[asyncio.to_thread(t.wait) for t in tasks])
+                     return obtain(tasks[0].value)
+
+                 return _func
+
+             self.step = async_wrap("step")
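
Note on the new worker interface: ModelTpServer exposes a single entry point, exposed_step, which consumes a list of tokenizer-manager requests (TokenizedGenerateReqInput, FlushCacheReq, AbortReq) and returns BatchTokenIDOut objects for the detokenizer; ModelTpClient wraps it as the async callable step, either in-process (tp_size * dp_size == 1) or over rpyc. The sketch below is only a hypothetical illustration of driving that interface, not code from this diff; the real loop lives in the new manager_single.py / manager_multi.py controllers, and recv_from_tokenizer / send_to_detokenizer are placeholder callables.

import asyncio

from sglang.srt.managers.controller.tp_worker import ModelTpClient


async def controller_loop(client: ModelTpClient, recv_from_tokenizer, send_to_detokenizer):
    """Hypothetical single-worker scheduling loop (illustration only)."""
    while True:
        # Requests queued by the tokenizer manager: TokenizedGenerateReqInput,
        # FlushCacheReq, or AbortReq objects (recv_from_tokenizer is a placeholder).
        recv_reqs = recv_from_tokenizer()

        # One scheduling step: prefill a new batch or continue decoding.
        # Returns zero or more BatchTokenIDOut objects.
        out_pyobjs = await client.step(recv_reqs)

        # Forward each output to the detokenizer manager (placeholder callable).
        for obj in out_pyobjs:
            send_to_detokenizer(obj)

        # Back off briefly when idle so the event loop stays responsive.
        if not recv_reqs and not out_pyobjs:
            await asyncio.sleep(0.001)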