sglang 0.1.14__py3-none-any.whl → 0.1.21__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (81)
  1. sglang/__init__.py +59 -2
  2. sglang/api.py +40 -11
  3. sglang/backend/anthropic.py +17 -3
  4. sglang/backend/litellm.py +90 -0
  5. sglang/backend/openai.py +160 -12
  6. sglang/backend/runtime_endpoint.py +62 -27
  7. sglang/backend/vertexai.py +1 -0
  8. sglang/bench_latency.py +320 -0
  9. sglang/global_config.py +24 -3
  10. sglang/lang/chat_template.py +122 -6
  11. sglang/lang/compiler.py +2 -2
  12. sglang/lang/interpreter.py +206 -98
  13. sglang/lang/ir.py +98 -34
  14. sglang/lang/tracer.py +6 -4
  15. sglang/launch_server.py +4 -1
  16. sglang/launch_server_llavavid.py +32 -0
  17. sglang/srt/constrained/__init__.py +14 -6
  18. sglang/srt/constrained/fsm_cache.py +9 -2
  19. sglang/srt/constrained/jump_forward.py +113 -24
  20. sglang/srt/conversation.py +4 -2
  21. sglang/srt/flush_cache.py +18 -0
  22. sglang/srt/hf_transformers_utils.py +144 -3
  23. sglang/srt/layers/context_flashattention_nopad.py +1 -0
  24. sglang/srt/layers/extend_attention.py +20 -1
  25. sglang/srt/layers/fused_moe.py +596 -0
  26. sglang/srt/layers/logits_processor.py +190 -61
  27. sglang/srt/layers/radix_attention.py +62 -53
  28. sglang/srt/layers/token_attention.py +21 -9
  29. sglang/srt/managers/controller/cuda_graph_runner.py +196 -0
  30. sglang/srt/managers/controller/dp_worker.py +113 -0
  31. sglang/srt/managers/controller/infer_batch.py +908 -0
  32. sglang/srt/managers/controller/manager_multi.py +195 -0
  33. sglang/srt/managers/controller/manager_single.py +177 -0
  34. sglang/srt/managers/controller/model_runner.py +359 -0
  35. sglang/srt/managers/{router → controller}/radix_cache.py +102 -53
  36. sglang/srt/managers/controller/schedule_heuristic.py +65 -0
  37. sglang/srt/managers/controller/tp_worker.py +813 -0
  38. sglang/srt/managers/detokenizer_manager.py +42 -40
  39. sglang/srt/managers/io_struct.py +44 -10
  40. sglang/srt/managers/tokenizer_manager.py +224 -82
  41. sglang/srt/memory_pool.py +52 -59
  42. sglang/srt/model_config.py +97 -2
  43. sglang/srt/models/chatglm.py +399 -0
  44. sglang/srt/models/commandr.py +369 -0
  45. sglang/srt/models/dbrx.py +406 -0
  46. sglang/srt/models/gemma.py +34 -38
  47. sglang/srt/models/gemma2.py +436 -0
  48. sglang/srt/models/grok.py +738 -0
  49. sglang/srt/models/llama2.py +47 -37
  50. sglang/srt/models/llama_classification.py +107 -0
  51. sglang/srt/models/llava.py +92 -27
  52. sglang/srt/models/llavavid.py +298 -0
  53. sglang/srt/models/minicpm.py +366 -0
  54. sglang/srt/models/mixtral.py +302 -127
  55. sglang/srt/models/mixtral_quant.py +372 -0
  56. sglang/srt/models/qwen.py +40 -35
  57. sglang/srt/models/qwen2.py +33 -36
  58. sglang/srt/models/qwen2_moe.py +473 -0
  59. sglang/srt/models/stablelm.py +33 -39
  60. sglang/srt/models/yivl.py +19 -26
  61. sglang/srt/openai_api_adapter.py +411 -0
  62. sglang/srt/{managers/openai_protocol.py → openai_protocol.py} +44 -19
  63. sglang/srt/sampling_params.py +2 -0
  64. sglang/srt/server.py +197 -481
  65. sglang/srt/server_args.py +190 -74
  66. sglang/srt/utils.py +460 -95
  67. sglang/test/test_programs.py +73 -10
  68. sglang/test/test_utils.py +226 -7
  69. sglang/utils.py +97 -27
  70. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/METADATA +74 -45
  71. sglang-0.1.21.dist-info/RECORD +82 -0
  72. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/WHEEL +1 -1
  73. sglang/srt/backend_config.py +0 -13
  74. sglang/srt/managers/router/infer_batch.py +0 -503
  75. sglang/srt/managers/router/manager.py +0 -79
  76. sglang/srt/managers/router/model_rpc.py +0 -686
  77. sglang/srt/managers/router/model_runner.py +0 -514
  78. sglang/srt/managers/router/scheduler.py +0 -70
  79. sglang-0.1.14.dist-info/RECORD +0 -64
  80. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/LICENSE +0 -0
  81. {sglang-0.1.14.dist-info → sglang-0.1.21.dist-info}/top_level.txt +0 -0
sglang/srt/managers/router/model_rpc.py
@@ -1,686 +0,0 @@
-import asyncio
-import logging
-import multiprocessing
-import time
-import warnings
-from concurrent.futures import ThreadPoolExecutor
-from typing import List
-
-import numpy as np
-import rpyc
-import torch
-from rpyc.utils.classic import obtain
-from rpyc.utils.server import ThreadedServer
-from sglang.srt.constrained.fsm_cache import FSMCache
-from sglang.srt.constrained.jump_forward import JumpForwardCache
-from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
-from sglang.srt.managers.io_struct import (
-    BatchTokenIDOut,
-    FlushCacheReq,
-    TokenizedGenerateReqInput,
-)
-from sglang.srt.managers.router.infer_batch import Batch, ForwardMode, Req
-from sglang.srt.managers.router.model_runner import ModelRunner
-from sglang.srt.managers.router.radix_cache import RadixCache
-from sglang.srt.managers.router.scheduler import Scheduler
-from sglang.srt.model_config import ModelConfig
-from sglang.srt.server_args import PortArgs, ServerArgs
-from sglang.srt.utils import (
-    get_exception_traceback,
-    get_int_token_logit_bias,
-    is_multimodal_model,
-    set_random_seed,
-)
-from vllm.logger import _default_handler as vllm_default_handler
-
-logger = logging.getLogger("model_rpc")
-
-
-class ModelRpcServer(rpyc.Service):
-    def exposed_init_model(
-        self,
-        tp_rank: int,
-        server_args: ServerArgs,
-        port_args: PortArgs,
-    ):
-        server_args, port_args = [obtain(x) for x in [server_args, port_args]]
-
-        # Copy arguments
-        self.tp_rank = tp_rank
-        self.tp_size = server_args.tp_size
-        self.schedule_heuristic = server_args.schedule_heuristic
-        self.disable_regex_jump_forward = server_args.disable_regex_jump_forward
-        vllm_default_handler.setLevel(
-            level=getattr(logging, server_args.log_level.upper())
-        )
-
-        # Init model and tokenizer
-        self.model_config = ModelConfig(
-            server_args.model_path,
-            server_args.trust_remote_code,
-            context_length=server_args.context_length,
-        )
-
-        # for model end global settings
-        server_args_dict = {
-            "enable_flashinfer": server_args.enable_flashinfer,
-            "attention_reduce_in_fp32": server_args.attention_reduce_in_fp32,
-        }
-
-        self.model_runner = ModelRunner(
-            model_config=self.model_config,
-            mem_fraction_static=server_args.mem_fraction_static,
-            tp_rank=tp_rank,
-            tp_size=server_args.tp_size,
-            nccl_port=port_args.nccl_port,
-            load_format=server_args.load_format,
-            trust_remote_code=server_args.trust_remote_code,
-            server_args_dict=server_args_dict,
-        )
-        if is_multimodal_model(server_args.model_path):
-            self.processor = get_processor(
-                server_args.tokenizer_path,
-                tokenizer_mode=server_args.tokenizer_mode,
-                trust_remote_code=server_args.trust_remote_code,
-            )
-            self.tokenizer = self.processor.tokenizer
-        else:
-            self.tokenizer = get_tokenizer(
-                server_args.tokenizer_path,
-                tokenizer_mode=server_args.tokenizer_mode,
-                trust_remote_code=server_args.trust_remote_code,
-            )
-        self.eos_token_id = self.tokenizer.eos_token_id
-        self.max_total_num_token = self.model_runner.max_total_num_token
-        self.max_num_running_seq = self.max_total_num_token // 2
-        self.max_prefill_num_token = max(
-            self.model_config.context_len,
-            (
-                self.max_total_num_token // 6
-                if server_args.max_prefill_num_token is None
-                else server_args.max_prefill_num_token
-            ),
-        )
-        self.int_token_logit_bias = torch.tensor(
-            get_int_token_logit_bias(self.tokenizer, self.model_config.vocab_size)
-        )
-        set_random_seed(server_args.random_seed)
-        logger.info(
-            f"Rank {self.tp_rank}: "
-            f"max_total_num_token={self.max_total_num_token}, "
-            f"max_prefill_num_token={self.max_prefill_num_token}, "
-            f"context_len={self.model_config.context_len}, "
-        )
-        logger.info(server_args.get_optional_modes_logging())
-
-        # Init cache
-        self.tree_cache = RadixCache(server_args.disable_radix_cache)
-        self.tree_cache_metrics = {"total": 0, "hit": 0}
-        self.scheduler = Scheduler(
-            self.schedule_heuristic,
-            self.max_num_running_seq,
-            self.max_prefill_num_token,
-            self.max_total_num_token,
-            self.tree_cache,
-        )
-        self.req_to_token_pool = self.model_runner.req_to_token_pool
-        self.token_to_kv_pool = self.model_runner.token_to_kv_pool
-
-        # Init running status
-        self.forward_queue: List[Req] = []
-        self.running_batch: Batch = None
-        self.out_pyobjs = []
-        self.decode_forward_ct = 0
-        self.stream_interval = server_args.stream_interval
-
-        # Init the FSM cache for constrained generation
-        self.regex_fsm_cache = FSMCache(
-            server_args.tokenizer_path,
-            {
-                "tokenizer_mode": server_args.tokenizer_mode,
-                "trust_remote_code": server_args.trust_remote_code,
-            },
-        )
-        self.jump_forward_cache = JumpForwardCache()
-
-        # Init new token estimation
-        self.new_token_ratio = min(0.4 * server_args.schedule_conservativeness, 1.0)
-        self.min_new_token_ratio = min(0.2 * server_args.schedule_conservativeness, 1.0)
-        self.new_token_ratio_step = (0.0001, 0.05)  # (down, up)
-
-    def flush_cache(self):
-        if len(self.forward_queue) == 0 and (
-            self.running_batch is None or len(self.running_batch.reqs) == 0
-        ):
-            self.tree_cache.reset()
-            self.tree_cache_metrics = {"total": 0, "hit": 0}
-            self.regex_fsm_cache.reset()
-            self.req_to_token_pool.clear()
-            self.token_to_kv_pool.clear()
-            torch.cuda.empty_cache()
-            logger.info("Cache flushed successfully!")
-        else:
-            warnings.warn(
-                "Cache not flushed because there are pending requests. "
-                f"#queue-req: {len(self.forward_queue)}, "
-                f"#running-req: {0 if self.running_batch is None else len(self.running_batch.reqs)}"
-            )
-
-    def exposed_step(self, recv_reqs):
-        if self.tp_size != 1:
-            recv_reqs = obtain(recv_reqs)
-
-        try:
-            # Recv requests
-            for recv_req in recv_reqs:
-                if isinstance(recv_req, TokenizedGenerateReqInput):
-                    self.handle_generate_request(recv_req)
-                elif isinstance(recv_req, FlushCacheReq):
-                    self.flush_cache()
-                else:
-                    raise ValueError(f"Invalid request: {recv_req}")
-
-            # Forward
-            self.forward_step()
-        except Exception:
-            logger.error("Exception in ModelRpcClient:\n" + get_exception_traceback())
-
-        # Return results
-        ret = self.out_pyobjs
-        self.out_pyobjs = []
-        return ret
-
-    @torch.inference_mode()
-    def forward_step(self):
-        new_batch = self.get_new_fill_batch()
-
-        if new_batch is not None:
-            # Run new fill batch
-            self.forward_fill_batch(new_batch)
-
-            if not new_batch.is_empty():
-                if self.running_batch is None:
-                    self.running_batch = new_batch
-                else:
-                    self.running_batch.merge(new_batch)
-        else:
-            # Run decode batch
-            if self.running_batch is not None:
-                # Run a few decode batches continuously for reducing overhead
-                for _ in range(10):
-                    self.forward_decode_batch(self.running_batch)
-
-                    if self.running_batch.is_empty():
-                        self.running_batch = None
-                        break
-
-                    if self.out_pyobjs and self.running_batch.reqs[0].stream:
-                        break
-
-                if self.running_batch is not None and self.tp_rank == 0:
-                    if self.decode_forward_ct % 40 == 0:
-                        num_used = self.max_total_num_token - (
-                            self.token_to_kv_pool.available_size()
-                            + self.tree_cache.evictable_size()
-                        )
-                        logger.info(
-                            f"#running-req: {len(self.running_batch.reqs)}, "
-                            f"#token: {num_used}, "
-                            f"token usage: {num_used / self.max_total_num_token:.2f}, "
-                            f"#queue-req: {len(self.forward_queue)}"
-                        )
-            else:
-                # check the available size
-                available_size = (
-                    self.token_to_kv_pool.available_size()
-                    + self.tree_cache.evictable_size()
-                )
-                if available_size != self.max_total_num_token:
-                    warnings.warn(
-                        "Warning: "
-                        f"available_size={available_size}, max_total_num_token={self.max_total_num_token}\n"
-                        "KV cache pool leak detected!"
-                    )
-
-    def handle_generate_request(
-        self,
-        recv_req: TokenizedGenerateReqInput,
-    ):
-        req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
-        req.pixel_values = recv_req.pixel_values
-        if req.pixel_values is not None:
-            req.pad_value = [
-                (recv_req.image_hash) % self.model_config.vocab_size,
-                (recv_req.image_hash >> 16) % self.model_config.vocab_size,
-                (recv_req.image_hash >> 32) % self.model_config.vocab_size,
-                (recv_req.image_hash >> 64) % self.model_config.vocab_size,
-            ]
-            req.image_size = recv_req.image_size
-            req.input_ids, req.image_offset = self.model_runner.model.pad_input_ids(
-                req.input_ids, req.pad_value, req.pixel_values.shape, req.image_size
-            )
-        req.sampling_params = recv_req.sampling_params
-        req.return_logprob = recv_req.return_logprob
-        req.logprob_start_len = recv_req.logprob_start_len
-        req.stream = recv_req.stream
-        req.tokenizer = self.tokenizer
-
-        # Init regex fsm
-        if req.sampling_params.regex is not None:
-            req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
-            if not self.disable_regex_jump_forward:
-                req.jump_forward_map = self.jump_forward_cache.query(
-                    req.sampling_params.regex
-                )
-
-        # Truncate long prompts
-        req.input_ids = req.input_ids[: self.model_config.context_len - 1]
-        req.sampling_params.max_new_tokens = min(
-            req.sampling_params.max_new_tokens,
-            self.model_config.context_len - 1 - len(req.input_ids),
-            self.max_total_num_token - 128 - len(req.input_ids),
-        )
-        self.forward_queue.append(req)
-
-    def get_new_fill_batch(self):
-        if (
-            self.running_batch is not None
-            and len(self.running_batch.reqs) > self.max_num_running_seq
-        ):
-            return None
-
-        for req in self.forward_queue:
-            prefix_indices, last_node = self.tree_cache.match_prefix(req.input_ids)
-            if req.return_logprob:
-                prefix_indices = prefix_indices[: req.logprob_start_len]
-            req.extend_input_len = len(req.input_ids) - len(prefix_indices)
-            req.prefix_indices = prefix_indices
-            req.last_node = last_node
-
-        # Get priority queue
-        self.forward_queue = self.scheduler.get_priority_queue(self.forward_queue)
-
-        # Add requests if there is available space
-        can_run_list = []
-        new_batch_total_tokens = 0
-        new_batch_input_tokens = 0
-
-        available_size = (
-            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
-        )
-        if self.running_batch:
-            available_size -= sum(
-                [
-                    (r.max_new_tokens() - len(r.output_ids)) * self.new_token_ratio
-                    for r in self.running_batch.reqs
-                ]
-            )
-
-        for req in self.forward_queue:
-            if req.return_logprob:
-                # Need at least two tokens to compute normalized logprob
-                if req.extend_input_len < 2:
-                    delta = 2 - req.extend_input_len
-                    req.extend_input_len += delta
-                    req.prefix_indices = req.prefix_indices[:-delta]
-                    if req.image_offset is not None:
-                        req.image_offset += delta
-            if req.extend_input_len == 0 and req.max_new_tokens() > 0:
-                # Need at least one token to compute logits
-                req.extend_input_len = 1
-                req.prefix_indices = req.prefix_indices[:-1]
-                if req.image_offset is not None:
-                    req.image_offset += 1
-
-            if (
-                req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
-                < available_size
-                and req.extend_input_len + new_batch_input_tokens
-                < self.max_prefill_num_token
-            ):
-                delta = self.tree_cache.inc_ref_counter(req.last_node)
-                available_size += delta
-
-                if not (
-                    req.extend_input_len + req.max_new_tokens() + new_batch_total_tokens
-                    < available_size
-                ):
-                    # Undo the insertion
-                    delta = self.tree_cache.dec_ref_counter(req.last_node)
-                    available_size += delta
-                else:
-                    # Add this request to the running batch
-                    self.token_to_kv_pool.add_refs(req.prefix_indices)
-                    can_run_list.append(req)
-                    new_batch_total_tokens += (
-                        req.extend_input_len + req.max_new_tokens()
-                    )
-                    new_batch_input_tokens += req.extend_input_len
-
-        if len(can_run_list) == 0:
-            return None
-
-        if self.tp_rank == 0:
-            running_req = (
-                0 if self.running_batch is None else len(self.running_batch.reqs)
-            )
-            hit_tokens = sum(len(x.prefix_indices) for x in can_run_list)
-            self.tree_cache_metrics["total"] += (
-                hit_tokens + new_batch_input_tokens
-            ) / 10**9
-            self.tree_cache_metrics["hit"] += hit_tokens / 10**9
-            tree_cache_hit_rate = (
-                self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
-            )
-            logger.info(
-                f"new fill batch. #seq: {len(can_run_list)}. "
-                f"#cached_token: {hit_tokens}. "
-                f"#new_token: {new_batch_input_tokens}. "
-                f"#remaining_req: {len(self.forward_queue) - len(can_run_list)}. "
-                f"#running_req: {running_req}. "
-                f"tree_cache_hit_rate: {100.0 * tree_cache_hit_rate:.2f}%."
-            )
-            logger.debug(
-                f"fsm_cache_hit_rate: {100.0 * self.regex_fsm_cache.get_cache_hit_rate():.2f}%. "
-                f"fsm_cache_avg_init_time: {self.regex_fsm_cache.get_avg_init_time():.2f}s. "
-                f"ff_cache_hit_rate: {100.0 * self.jump_forward_cache.get_cache_hit_rate():.2f}%. "
-                f"ff_cache_avg_init_time: {self.jump_forward_cache.get_avg_init_time():.2f}s. "
-            )
-
-        new_batch = Batch.init_new(
-            can_run_list,
-            self.req_to_token_pool,
-            self.token_to_kv_pool,
-            self.tree_cache,
-        )
-        self.forward_queue = [x for x in self.forward_queue if x not in can_run_list]
-        return new_batch
-
-    def forward_fill_batch(self, batch: Batch):
-        # Build batch tensors
-        batch.prepare_for_extend(
-            self.model_config.vocab_size, self.int_token_logit_bias
-        )
-
-        logprobs = None
-        if batch.extend_num_tokens != 0:
-            # Forward
-            logits, (
-                prefill_logprobs,
-                normalized_logprobs,
-                last_logprobs,
-            ) = self.model_runner.forward(
-                batch, ForwardMode.EXTEND, batch.return_logprob
-            )
-            if prefill_logprobs is not None:
-                logprobs = prefill_logprobs.cpu().tolist()
-                normalized_logprobs = normalized_logprobs.cpu().tolist()
-
-            next_token_ids, _ = batch.sample(logits)
-            next_token_ids = next_token_ids.cpu().tolist()
-        else:
-            next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
-            logits = logprobs = normalized_logprobs = last_logprobs = None
-
-        # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
-        reqs = batch.reqs
-        if last_logprobs is not None:
-            last_logprobs = (
-                last_logprobs[torch.arange(len(reqs)), next_token_ids].cpu().tolist()
-            )
-
-        # Check finish condition
-        pt = 0
-        for i, req in enumerate(reqs):
-            req.completion_tokens_wo_jump_forward += 1
-            req.output_ids = [next_token_ids[i]]
-            req.check_finished()
-
-            if logprobs is not None:
-                req.logprob = logprobs[pt : pt + req.extend_input_len - 1]
-                req.normalized_logprob = normalized_logprobs[i]
-
-                # If logprob_start_len > 0, then first logprob_start_len prompt tokens
-                # will be ignored.
-                prompt_token_len = len(req.logprob)
-                token_ids = req.input_ids[-prompt_token_len:] + [next_token_ids[i]]
-                token_logprobs = req.logprob + [last_logprobs[i]]
-                req.token_logprob = list(zip(token_ids, token_logprobs))
-                if req.logprob_start_len == 0:
-                    req.token_logprob = [(req.input_ids[0], None)] + req.token_logprob
-                pt += req.extend_input_len
-
-        self.handle_finished_requests(batch)
-
-    def forward_decode_batch(self, batch: Batch):
-        # check if decode out of memory
-        if not batch.check_decode_mem():
-            old_ratio = self.new_token_ratio
-            self.new_token_ratio = min(old_ratio + self.new_token_ratio_step[1], 1.0)
-
-            retracted_reqs = batch.retract_decode()
-            logger.info(
-                "decode out of memory happened, "
-                f"#retracted_reqs: {len(retracted_reqs)}, "
-                f"#new_token_ratio: {old_ratio:.4f} -> {self.new_token_ratio:.4f}"
-            )
-            self.forward_queue.extend(retracted_reqs)
-        else:
-            self.new_token_ratio = max(
-                self.new_token_ratio - self.new_token_ratio_step[0],
-                self.min_new_token_ratio,
-            )
-
-        if not self.disable_regex_jump_forward:
-            # check for jump-forward
-            jump_forward_reqs = batch.check_for_jump_forward()
-
-            # check for image jump-forward
-            for req in jump_forward_reqs:
-                if req.pixel_values is not None:
-                    (
-                        req.input_ids,
-                        req.image_offset,
-                    ) = self.model_runner.model.pad_input_ids(
-                        req.input_ids,
-                        req.pad_value,
-                        req.pixel_values.shape,
-                        req.image_size,
-                    )
-
-            self.forward_queue.extend(jump_forward_reqs)
-            if batch.is_empty():
-                return
-
-        # Update batch tensors
-        self.decode_forward_ct = (self.decode_forward_ct + 1) % (1 << 30)
-        batch.prepare_for_decode()
-
-        # Forward
-        logits, (_, _, last_logprobs) = self.model_runner.forward(
-            batch,
-            ForwardMode.DECODE,
-            batch.return_logprob,
-        )
-        next_token_ids, _ = batch.sample(logits)
-        next_token_ids = next_token_ids.cpu().tolist()
-
-        # Only batch transfer the selected logprobs of the next token to CPU to reduce overhead.
-        reqs = batch.reqs
-        if last_logprobs is not None:
-            last_logprobs = last_logprobs[
-                torch.arange(len(reqs)), next_token_ids
-            ].tolist()
-
-        # Check finish condition
-        for i, (req, next_tok_id) in enumerate(zip(reqs, next_token_ids)):
-            req.completion_tokens_wo_jump_forward += 1
-            req.output_ids.append(next_tok_id)
-            req.check_finished()
-
-            if last_logprobs is not None:
-                req.token_logprob.append((next_tok_id, last_logprobs[i]))
-
-        self.handle_finished_requests(batch)
-
-    def handle_finished_requests(self, batch: Batch):
-        output_rids = []
-        output_tokens = []
-        output_and_jump_forward_strs = []
-        output_hit_stop_str = []
-        output_skip_special_tokens = []
-        output_meta_info = []
-        output_finished = []
-        finished_indices = []
-        unfinished_indices = []
-        for i, req in enumerate(batch.reqs):
-            if req.finished:
-                finished_indices.append(i)
-            else:
-                unfinished_indices.append(i)
-
-            if req.finished or (
-                (
-                    req.stream
-                    and (
-                        self.decode_forward_ct % self.stream_interval == 0
-                        or len(req.output_ids) == 1
-                    )
-                )
-            ):
-                output_rids.append(req.rid)
-                output_tokens.append(req.output_ids)
-                output_and_jump_forward_strs.append(req.output_and_jump_forward_str)
-                output_hit_stop_str.append(req.hit_stop_str)
-                output_skip_special_tokens.append(
-                    req.sampling_params.skip_special_tokens
-                )
-
-                meta_info = {
-                    "prompt_tokens": req.prompt_tokens,
-                    "completion_tokens": len(req.input_ids)
-                    + len(req.output_ids)
-                    - req.prompt_tokens,
-                    "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
-                }
-                if req.return_logprob:
-                    meta_info["prompt_logprob"] = req.logprob
-                    meta_info["token_logprob"] = req.token_logprob
-                    meta_info["normalized_prompt_logprob"] = req.normalized_logprob
-                output_meta_info.append(meta_info)
-                output_finished.append(req.finished)
-
-        # Send to detokenizer
-        if output_rids:
-            self.out_pyobjs.append(
-                BatchTokenIDOut(
-                    output_rids,
-                    output_tokens,
-                    output_and_jump_forward_strs,
-                    output_hit_stop_str,
-                    output_skip_special_tokens,
-                    output_meta_info,
-                    output_finished,
-                )
-            )
-
-        # Remove finished reqs
-        if finished_indices:
-            # Update radix cache
-            req_pool_indices_cpu = batch.req_pool_indices.cpu().tolist()
-            for i in finished_indices:
-                req = batch.reqs[i]
-                req_pool_idx = req_pool_indices_cpu[i]
-                token_ids = tuple(req.input_ids + req.output_ids)
-                seq_len = len(token_ids) - 1
-                indices = self.req_to_token_pool.req_to_token[req_pool_idx, :seq_len]
-                prefix_len = self.tree_cache.insert(
-                    token_ids[:seq_len], indices.clone()
-                )
-
-                self.token_to_kv_pool.free(indices[:prefix_len])
-                self.req_to_token_pool.free(req_pool_idx)
-                self.tree_cache.dec_ref_counter(req.last_node)
-
-        # Update batch tensors
-        if unfinished_indices:
-            batch.filter_batch(unfinished_indices)
-        else:
-            batch.reqs = []
-
-
-class ModelRpcClient:
-    def __init__(self, server_args: ServerArgs, port_args: PortArgs):
-        tp_size = server_args.tp_size
-
-        if tp_size == 1:
-            # Init model
-            self.model_server = ModelRpcServer()
-            self.model_server.exposed_init_model(0, server_args, port_args)
-
-            # Wrap functions
-            def async_wrap(f):
-                async def _func(*args, **kwargs):
-                    return f(*args, **kwargs)
-
-                return _func
-
-            self.step = async_wrap(self.model_server.exposed_step)
-        else:
-            with ThreadPoolExecutor(tp_size) as executor:
-                # Launch model processes
-                rets = executor.map(start_model_process, port_args.model_rpc_ports)
-                self.model_servers = [x[0] for x in rets]
-                self.procs = [x[1] for x in rets]
-
-                # Init model
-                def init_model(i):
-                    return self.model_servers[i].init_model(i, server_args, port_args)
-
-                rets = [obtain(x) for x in executor.map(init_model, range(tp_size))]
-
-            # Wrap functions
-            def async_wrap(func_name):
-                fs = [rpyc.async_(getattr(m, func_name)) for m in self.model_servers]
-
-                async def _func(*args, **kwargs):
-                    tasks = [f(*args, **kwargs) for f in fs]
-                    await asyncio.gather(*[asyncio.to_thread(t.wait) for t in tasks])
-                    return obtain(tasks[0].value)
-
-                return _func
-
-            self.step = async_wrap("step")
-
-
-def _init_service(port):
-    t = ThreadedServer(
-        ModelRpcServer(),
-        port=port,
-        protocol_config={"allow_pickle": True, "sync_request_timeout": 1800},
-    )
-    t.start()
-
-
-def start_model_process(port):
-    proc = multiprocessing.Process(target=_init_service, args=(port,))
-    proc.start()
-    time.sleep(1)
-
-    repeat_count = 0
-    while repeat_count < 20:
-        try:
-            con = rpyc.connect(
-                "localhost",
-                port,
-                config={"allow_pickle": True, "sync_request_timeout": 1800},
-            )
-            break
-        except ConnectionRefusedError:
-            time.sleep(1)
-            repeat_count += 1
-    if repeat_count == 20:
-        raise RuntimeError("init rpc env error!")
-
-    assert proc.is_alive()
-    return con.root, proc