sglang 0.2.10__py3-none-any.whl → 0.2.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (89)
  1. sglang/__init__.py +8 -0
  2. sglang/api.py +10 -2
  3. sglang/bench_latency.py +151 -40
  4. sglang/bench_serving.py +46 -22
  5. sglang/check_env.py +24 -2
  6. sglang/global_config.py +0 -1
  7. sglang/lang/backend/base_backend.py +3 -1
  8. sglang/lang/backend/openai.py +8 -3
  9. sglang/lang/backend/runtime_endpoint.py +46 -29
  10. sglang/lang/choices.py +164 -0
  11. sglang/lang/compiler.py +2 -2
  12. sglang/lang/interpreter.py +6 -13
  13. sglang/lang/ir.py +14 -5
  14. sglang/srt/constrained/base_tool_cache.py +1 -1
  15. sglang/srt/constrained/fsm_cache.py +12 -2
  16. sglang/srt/layers/activation.py +33 -0
  17. sglang/srt/layers/{token_attention.py → decode_attention.py} +9 -5
  18. sglang/srt/layers/extend_attention.py +6 -1
  19. sglang/srt/layers/layernorm.py +65 -0
  20. sglang/srt/layers/logits_processor.py +6 -1
  21. sglang/srt/layers/pooler.py +50 -0
  22. sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py} +5 -0
  23. sglang/srt/layers/radix_attention.py +4 -7
  24. sglang/srt/managers/detokenizer_manager.py +31 -9
  25. sglang/srt/managers/io_struct.py +63 -0
  26. sglang/srt/managers/policy_scheduler.py +173 -25
  27. sglang/srt/managers/schedule_batch.py +174 -380
  28. sglang/srt/managers/tokenizer_manager.py +197 -112
  29. sglang/srt/managers/tp_worker.py +299 -364
  30. sglang/srt/mem_cache/{base_cache.py → base_prefix_cache.py} +9 -4
  31. sglang/srt/mem_cache/chunk_cache.py +43 -20
  32. sglang/srt/mem_cache/memory_pool.py +10 -15
  33. sglang/srt/mem_cache/radix_cache.py +74 -40
  34. sglang/srt/model_executor/cuda_graph_runner.py +27 -12
  35. sglang/srt/model_executor/forward_batch_info.py +319 -0
  36. sglang/srt/model_executor/model_runner.py +30 -47
  37. sglang/srt/models/chatglm.py +1 -1
  38. sglang/srt/models/commandr.py +1 -1
  39. sglang/srt/models/dbrx.py +1 -1
  40. sglang/srt/models/deepseek.py +1 -1
  41. sglang/srt/models/deepseek_v2.py +1 -1
  42. sglang/srt/models/gemma.py +1 -1
  43. sglang/srt/models/gemma2.py +1 -2
  44. sglang/srt/models/gpt_bigcode.py +1 -1
  45. sglang/srt/models/grok.py +1 -1
  46. sglang/srt/models/internlm2.py +3 -8
  47. sglang/srt/models/llama2.py +5 -5
  48. sglang/srt/models/llama_classification.py +1 -1
  49. sglang/srt/models/llama_embedding.py +88 -0
  50. sglang/srt/models/llava.py +1 -2
  51. sglang/srt/models/llavavid.py +1 -2
  52. sglang/srt/models/minicpm.py +1 -1
  53. sglang/srt/models/mixtral.py +1 -1
  54. sglang/srt/models/mixtral_quant.py +1 -1
  55. sglang/srt/models/qwen.py +1 -1
  56. sglang/srt/models/qwen2.py +1 -1
  57. sglang/srt/models/qwen2_moe.py +1 -12
  58. sglang/srt/models/stablelm.py +1 -1
  59. sglang/srt/openai_api/adapter.py +189 -39
  60. sglang/srt/openai_api/protocol.py +43 -1
  61. sglang/srt/sampling/penaltylib/__init__.py +13 -0
  62. sglang/srt/sampling/penaltylib/orchestrator.py +357 -0
  63. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +80 -0
  64. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +105 -0
  65. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +79 -0
  66. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +83 -0
  67. sglang/srt/sampling_params.py +31 -4
  68. sglang/srt/server.py +93 -21
  69. sglang/srt/server_args.py +30 -19
  70. sglang/srt/utils.py +31 -13
  71. sglang/test/run_eval.py +10 -1
  72. sglang/test/runners.py +63 -63
  73. sglang/test/simple_eval_humaneval.py +2 -8
  74. sglang/test/simple_eval_mgsm.py +203 -0
  75. sglang/test/srt/sampling/penaltylib/utils.py +337 -0
  76. sglang/test/test_layernorm.py +60 -0
  77. sglang/test/test_programs.py +4 -2
  78. sglang/test/test_utils.py +21 -3
  79. sglang/utils.py +0 -1
  80. sglang/version.py +1 -1
  81. {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/METADATA +50 -31
  82. sglang-0.2.12.dist-info/RECORD +112 -0
  83. sglang/srt/layers/linear.py +0 -884
  84. sglang/srt/layers/quantization/__init__.py +0 -64
  85. sglang/srt/layers/quantization/fp8.py +0 -677
  86. sglang-0.2.10.dist-info/RECORD +0 -100
  87. {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/LICENSE +0 -0
  88. {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/WHEEL +0 -0
  89. {sglang-0.2.10.dist-info → sglang-0.2.12.dist-info}/top_level.txt +0 -0
@@ -17,35 +17,40 @@ limitations under the License.
 
 import logging
 import multiprocessing
+import os
 import pickle
 import time
 import warnings
-from typing import List, Optional
+from typing import Any, List, Optional, Union
 
 import torch
+import torch.distributed
 import torch.distributed as dist
 
 from sglang.global_config import global_config
 from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
+from sglang.srt.layers.logits_processor import LogitProcessorOutput
 from sglang.srt.managers.io_struct import (
     AbortReq,
+    BatchEmbeddingOut,
     BatchTokenIDOut,
     FlushCacheReq,
+    TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
 )
-from sglang.srt.managers.policy_scheduler import PolicyScheduler
+from sglang.srt.managers.policy_scheduler import PolicyScheduler, PrefillAdder
 from sglang.srt.managers.schedule_batch import (
     FINISH_ABORT,
     BaseFinishReason,
-    Batch,
-    ForwardMode,
     Req,
+    ScheduleBatch,
 )
 from sglang.srt.mem_cache.chunk_cache import ChunkCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.model_config import ModelConfig
+from sglang.srt.model_executor.forward_batch_info import ForwardMode
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.utils import (
@@ -59,6 +64,9 @@ from sglang.utils import get_exception_traceback
 logger = logging.getLogger(__name__)
 
 
+crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
+
+
 class ModelTpServer:
     def __init__(
         self,
@@ -98,26 +106,24 @@ class ModelTpServer:
             nccl_port=nccl_port,
             server_args=server_args,
         )
-
-        if is_multimodal_model(server_args.model_path):
-            self.processor = get_processor(
-                server_args.tokenizer_path,
-                tokenizer_mode=server_args.tokenizer_mode,
-                trust_remote_code=server_args.trust_remote_code,
-            )
-            self.tokenizer = self.processor.tokenizer
+        if server_args.skip_tokenizer_init:
+            self.tokenizer = self.processor = None
         else:
-            self.tokenizer = get_tokenizer(
-                server_args.tokenizer_path,
-                tokenizer_mode=server_args.tokenizer_mode,
-                trust_remote_code=server_args.trust_remote_code,
-            )
+            if is_multimodal_model(server_args.model_path):
+                self.processor = get_processor(
+                    server_args.tokenizer_path,
+                    tokenizer_mode=server_args.tokenizer_mode,
+                    trust_remote_code=server_args.trust_remote_code,
+                )
+                self.tokenizer = self.processor.tokenizer
+            else:
+                self.tokenizer = get_tokenizer(
+                    server_args.tokenizer_path,
+                    tokenizer_mode=server_args.tokenizer_mode,
+                    trust_remote_code=server_args.trust_remote_code,
+                )
         self.max_total_num_tokens = self.model_runner.max_total_num_tokens
-        self.max_prefill_tokens = (
-            16384
-            if server_args.max_prefill_tokens is None
-            else server_args.max_prefill_tokens
-        )
+        self.max_prefill_tokens = server_args.max_prefill_tokens
         self.max_running_requests = min(
             (
                 self.max_total_num_tokens // 2
@@ -160,19 +166,13 @@ class ModelTpServer:
                 disable=server_args.disable_radix_cache,
             )
         self.tree_cache_metrics = {"total": 0, "hit": 0}
-        self.scheduler = PolicyScheduler(
-            self.schedule_policy,
-            self.max_running_requests,
-            self.max_prefill_tokens,
-            self.max_total_num_tokens,
-            self.tree_cache,
-        )
+        self.scheduler = PolicyScheduler(self.schedule_policy, self.tree_cache)
         self.req_to_token_pool = self.model_runner.req_to_token_pool
         self.token_to_kv_pool = self.model_runner.token_to_kv_pool
 
         # Init running status
         self.waiting_queue: List[Req] = []
-        self.running_batch: Batch = None
+        self.running_batch: ScheduleBatch = None
         self.out_pyobjs = []
         self.decode_forward_ct = 0
         self.stream_interval = server_args.stream_interval
@@ -180,13 +180,15 @@ class ModelTpServer:
         self.last_stats_tic = time.time()
 
         # Init the FSM cache for constrained generation
-        self.regex_fsm_cache = FSMCache(
-            server_args.tokenizer_path,
-            {
-                "tokenizer_mode": server_args.tokenizer_mode,
-                "trust_remote_code": server_args.trust_remote_code,
-            },
-        )
+        if not server_args.skip_tokenizer_init:
+            self.regex_fsm_cache = FSMCache(
+                server_args.tokenizer_path,
+                {
+                    "tokenizer_mode": server_args.tokenizer_mode,
+                    "trust_remote_code": server_args.trust_remote_code,
+                },
+                skip_tokenizer_init=server_args.skip_tokenizer_init,
+            )
         self.jump_forward_cache = JumpForwardCache()
 
         # Init new token estimation
@@ -200,13 +202,14 @@ class ModelTpServer:
         )
         self.new_token_ratio = self.min_new_token_ratio
         self.new_token_ratio_decay = global_config.new_token_ratio_decay
-        self.new_token_ratio_recovery = global_config.new_token_ratio_recovery
 
-    def exposed_step(self, recv_reqs):
+    def exposed_step(self, recv_reqs: List):
         try:
             # Recv requests
             for recv_req in recv_reqs:
-                if isinstance(recv_req, TokenizedGenerateReqInput):
+                if isinstance(
+                    recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
+                ):
                     self.handle_generate_request(recv_req)
                 elif isinstance(recv_req, FlushCacheReq):
                     self.flush_cache()
@@ -233,8 +236,6 @@ class ModelTpServer:
         if new_batch is not None:
             # Run a new prefill batch
             self.forward_prefill_batch(new_batch)
-            self.cache_filled_batch(new_batch)
-            self.filter_out_inflight(new_batch)
 
             if not new_batch.is_empty():
                 if self.running_batch is None:
@@ -251,7 +252,7 @@ class ModelTpServer:
 
                     # Print stats
                     if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
-                        self.print_stats()
+                        self.print_decode_stats()
 
                     if self.running_batch.is_empty():
                         self.running_batch = None
@@ -263,7 +264,7 @@ class ModelTpServer:
                 self.check_memory()
                 self.new_token_ratio = global_config.init_new_token_ratio
 
-    def print_stats(self):
+    def print_decode_stats(self):
         num_used = self.max_total_num_tokens - (
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
         )
@@ -289,52 +290,55 @@ class ModelTpServer:
                 f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
                 "KV cache pool leak detected!"
             )
+            exit(1) if crash_on_warning else None
 
-        if self.req_to_token_pool.can_use_mem_size != self.req_to_token_pool.size:
+        if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
             warnings.warn(
                 "Warning: "
-                f"available req slots={self.req_to_token_pool.can_use_mem_size}, "
+                f"available req slots={len(self.req_to_token_pool.free_slots)}, "
                 f"total slots={self.req_to_token_pool.size}\n"
                 "Memory pool leak detected!"
             )
+            exit(1) if crash_on_warning else None
 
     def handle_generate_request(
         self,
-        recv_req: TokenizedGenerateReqInput,
+        recv_req: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
     ):
         req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
-        req.pixel_values = recv_req.pixel_values
-        if req.pixel_values is not None:
-            req.pad_value = [
-                (recv_req.image_hash) % self.model_config.vocab_size,
-                (recv_req.image_hash >> 16) % self.model_config.vocab_size,
-                (recv_req.image_hash >> 32) % self.model_config.vocab_size,
-                (recv_req.image_hash >> 64) % self.model_config.vocab_size,
-            ]
-            req.image_size = recv_req.image_size
-            (
-                req.origin_input_ids,
-                req.image_offset,
-            ) = self.model_runner.model.pad_input_ids(
-                req.origin_input_ids_unpadded,
-                req.pad_value,
-                req.pixel_values.shape,
-                req.image_size,
-            )
-        req.sampling_params = recv_req.sampling_params
-        req.return_logprob = recv_req.return_logprob
-        req.logprob_start_len = recv_req.logprob_start_len
-        req.top_logprobs_num = recv_req.top_logprobs_num
-        req.stream = recv_req.stream
         req.tokenizer = self.tokenizer
-
-        # Init regex fsm
-        if req.sampling_params.regex is not None:
-            req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
-            if not self.disable_regex_jump_forward:
-                req.jump_forward_map = self.jump_forward_cache.query(
-                    req.sampling_params.regex
+        req.sampling_params = recv_req.sampling_params
+        if self.model_runner.is_generation:
+            req.pixel_values = recv_req.pixel_values
+            if req.pixel_values is not None:
+                req.pad_value = [
+                    (recv_req.image_hash) % self.model_config.vocab_size,
+                    (recv_req.image_hash >> 16) % self.model_config.vocab_size,
+                    (recv_req.image_hash >> 32) % self.model_config.vocab_size,
+                    (recv_req.image_hash >> 64) % self.model_config.vocab_size,
+                ]
+                req.image_size = recv_req.image_size
+                (
+                    req.origin_input_ids,
+                    req.image_offset,
+                ) = self.model_runner.model.pad_input_ids(
+                    req.origin_input_ids_unpadded,
+                    req.pad_value,
+                    req.pixel_values.shape,
+                    req.image_size,
                 )
+            req.return_logprob = recv_req.return_logprob
+            req.logprob_start_len = recv_req.logprob_start_len
+            req.top_logprobs_num = recv_req.top_logprobs_num
+            req.stream = recv_req.stream
+
+            # Init regex fsm
+            if req.sampling_params.regex is not None:
+                req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
+                if not self.disable_regex_jump_forward:
+                    req.jump_forward_map = self.jump_forward_cache.query(
+                        req.sampling_params.regex
+                    )
 
         # Truncate prompts that are too long
         if len(req.origin_input_ids) >= self.max_req_input_len:
@@ -343,189 +347,91 @@ class ModelTpServer:
                 "the max context length. Truncated!!!"
             )
             req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
-        req.sampling_params.max_new_tokens = min(
-            (
-                req.sampling_params.max_new_tokens
-                if req.sampling_params.max_new_tokens is not None
-                else 1 << 30
-            ),
-            self.max_req_input_len - 1 - len(req.origin_input_ids),
-        )
+
+        if self.model_runner.is_generation:
+            req.sampling_params.max_new_tokens = min(
+                (
+                    req.sampling_params.max_new_tokens
+                    if req.sampling_params.max_new_tokens is not None
+                    else 1 << 30
+                ),
+                self.max_req_input_len - 1 - len(req.origin_input_ids),
+            )
+
         self.waiting_queue.append(req)
 
-    def get_new_prefill_batch(self) -> Optional[Batch]:
-        # TODO(lsyin): organize this function
+    def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
        running_bs = (
             len(self.running_batch.reqs) if self.running_batch is not None else 0
         )
         if running_bs >= self.max_running_requests:
-            return
-
-        # Compute matched prefix length
-        for req in self.waiting_queue:
-            req.input_ids = req.origin_input_ids + req.output_ids
-            prefix_indices, last_node = self.tree_cache.match_prefix(
-                rid=req.rid,
-                key=req.input_ids,
-            )
-            if req.return_logprob:
-                prefix_indices = prefix_indices[: req.logprob_start_len]
-            req.extend_input_len = len(req.input_ids) - len(prefix_indices)
-            req.prefix_indices = prefix_indices
-            req.last_node = last_node
+            return None
 
         # Get priority queue
-        self.waiting_queue = self.scheduler.get_priority_queue(self.waiting_queue)
+        prefix_computed = self.scheduler.calc_priority(self.waiting_queue)
 
-        # Add requests if there is available space
-        can_run_list = []
-        new_batch_total_tokens = 0
-        new_batch_input_tokens = 0
-
-        available_size = (
-            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
+        adder = PrefillAdder(
+            self.tree_cache,
+            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
+            self.max_prefill_tokens,
+            self.chunked_prefill_size,
         )
-        if self.running_batch:
-            available_size -= sum(
-                [
-                    (r.sampling_params.max_new_tokens - len(r.output_ids))
-                    * self.new_token_ratio
-                    for r in self.running_batch.reqs
-                ]
-            )
 
-        # Handle the current inflight request
-        take_inflight = 0
-        if self.current_inflight_req:
-            take_inflight = 1
-            r = self.current_inflight_req
-            r.input_ids = r.origin_input_ids + r.output_ids
-            truncated = (
-                len(r.input_ids) - len(r.prefix_indices) > self.chunked_prefill_size
+        if self.running_batch is not None:
+            adder.remove_running_tokens(self.running_batch, self.new_token_ratio)
+
+        has_inflight = self.current_inflight_req is not None
+        if self.current_inflight_req is not None:
+            self.current_inflight_req.init_next_round_input(
+                None if prefix_computed else self.tree_cache
             )
-            r.extend_input_len = min(
-                len(r.input_ids) - len(r.prefix_indices), self.chunked_prefill_size
+            self.current_inflight_req = adder.add_inflight_req(
+                self.current_inflight_req
             )
-            r.input_ids = r.input_ids[: len(r.prefix_indices) + r.extend_input_len]
-            can_run_list.append(r)
-
-            if not truncated:
-                # Finish inflight
-                self.current_inflight_req = None
-                new_batch_total_tokens += (
-                    r.extend_input_len + r.sampling_params.max_new_tokens
-                )
-                new_batch_input_tokens += r.extend_input_len
-            else:
-                new_batch_total_tokens += r.extend_input_len
-                new_batch_input_tokens += r.extend_input_len
 
         for req in self.waiting_queue:
-            if req.return_logprob and req.normalized_prompt_logprob is None:
-                # Need at least two tokens to compute normalized logprob
-                if req.extend_input_len < 2:
-                    delta = 2 - req.extend_input_len
-                    req.extend_input_len += delta
-                    req.prefix_indices = req.prefix_indices[:-delta]
-                    if req.image_offset is not None:
-                        req.image_offset += delta
-            if req.extend_input_len == 0 and req.sampling_params.max_new_tokens > 0:
-                # Need at least one token to compute logits
-                req.extend_input_len = 1
-                req.prefix_indices = req.prefix_indices[:-1]
-                if req.image_offset is not None:
-                    req.image_offset += 1
-
+            req.init_next_round_input(None if prefix_computed else self.tree_cache)
+            res = adder.add_one_req(req)
             if (
-                req.extend_input_len
-                + req.sampling_params.max_new_tokens
-                + new_batch_total_tokens
-                < available_size
-                and (
-                    req.extend_input_len + new_batch_input_tokens
-                    <= self.max_prefill_tokens
-                    or len(can_run_list) == 0
-                )
+                not res
+                or adder.no_remaining_tokens()
+                or running_bs + len(adder.can_run_list) >= self.max_running_requests
             ):
-                delta = self.tree_cache.inc_lock_ref(req.last_node)
-                available_size += delta
-
-                if not (
-                    req.extend_input_len
-                    + req.sampling_params.max_new_tokens
-                    + new_batch_total_tokens
-                    < available_size
-                ):
-                    # Undo locking
-                    delta = self.tree_cache.dec_lock_ref(req.last_node)
-                    available_size += delta
-                    break
-                else:
-                    # Add this request to the running batch
-                    if (
-                        self.chunked_prefill_size is None
-                        or (
-                            new_batch_input_tokens + req.extend_input_len
-                            <= self.chunked_prefill_size
-                        )
-                        or (
-                            req.return_logprob and req.normalized_prompt_logprob is None
-                        )
-                    ):
-                        can_run_list.append(req)
-                        new_batch_total_tokens += (
-                            req.extend_input_len + req.sampling_params.max_new_tokens
-                        )
-                        new_batch_input_tokens += req.extend_input_len
-                    else:
-                        trunc_len = self.chunked_prefill_size - new_batch_input_tokens
-
-                        if trunc_len <= 0:
-                            # Undo locking
-                            delta = self.tree_cache.dec_lock_ref(req.last_node)
-                            available_size += delta
-                            break
-
-                        req.extend_input_len = trunc_len
-                        req.input_ids = req.input_ids[
-                            : len(req.prefix_indices) + req.extend_input_len
-                        ]
-                        can_run_list.append(req)
-                        self.current_inflight_req = req
-                        new_batch_input_tokens += req.extend_input_len
-                        new_batch_total_tokens += req.extend_input_len
-                        break
-            else:
                 break
 
-            if running_bs + len(can_run_list) >= self.max_running_requests:
-                break
+        can_run_list = adder.can_run_list
+
+        if adder.new_inflight_req is not None:
+            assert self.current_inflight_req is None
+            self.current_inflight_req = adder.new_inflight_req
 
         if len(can_run_list) == 0:
             return None
 
         # Print stats
         if self.tp_rank == 0:
-            hit_tokens = sum(len(x.prefix_indices) for x in can_run_list)
-            self.tree_cache_metrics["total"] += (
-                hit_tokens + new_batch_input_tokens
-            ) / 10**9
-            self.tree_cache_metrics["hit"] += hit_tokens / 10**9
-            tree_cache_hit_rate = (
-                self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
-            )
+            if isinstance(self.tree_cache, RadixCache):
+                self.tree_cache_metrics["total"] += (
+                    adder.log_input_tokens + adder.log_hit_tokens
+                ) / 10**9
+                self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9
+                tree_cache_hit_rate = (
+                    self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
+                )
+            else:
+                tree_cache_hit_rate = 0.0
             logger.info(
                 f"[gpu={self.gpu_id}] Prefill batch. "
                 f"#new-seq: {len(can_run_list)}, "
-                f"#new-token: {new_batch_input_tokens}, "
-                f"#cached-token: {hit_tokens}, "
+                f"#new-token: {adder.log_input_tokens}, "
+                f"#cached-token: {adder.log_hit_tokens}, "
                 f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
                 f"#running-req: {running_bs}, "
-                f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + take_inflight}"
+                f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
             )
 
         # Return the new batch
-        new_batch = Batch.init_new(
+        new_batch = ScheduleBatch.init_new(
             can_run_list,
             self.req_to_token_pool,
             self.token_to_kv_pool,
@@ -534,47 +440,94 @@ class ModelTpServer:
         self.waiting_queue = [x for x in self.waiting_queue if x not in can_run_list]
         return new_batch
 
-    def forward_prefill_batch(self, batch: Batch):
+    def forward_prefill_batch(self, batch: ScheduleBatch):
         # Build batch tensors
         batch.prepare_for_extend(
             self.model_config.vocab_size, self.int_token_logit_bias
         )
 
-        # Forward and sample the next tokens
-        if batch.extend_num_tokens != 0:
-            output = self.model_runner.forward(batch, ForwardMode.EXTEND)
-            next_token_ids = batch.sample(output.next_token_logits)
-
-            # Move logprobs to cpu
-            if output.next_token_logprobs is not None:
-                output.next_token_logprobs = output.next_token_logprobs[
-                    torch.arange(len(next_token_ids), device=next_token_ids.device),
-                    next_token_ids,
-                ].tolist()
-                output.input_token_logprobs = output.input_token_logprobs.tolist()
-                output.normalized_prompt_logprobs = (
-                    output.normalized_prompt_logprobs.tolist()
-                )
+        if self.model_runner.is_generation:
+            # Forward and sample the next tokens
+            if batch.extend_num_tokens != 0:
+                output = self.model_runner.forward(batch, ForwardMode.EXTEND)
+                next_token_ids = batch.sample(output.next_token_logits)
+
+                # Move logprobs to cpu
+                if output.next_token_logprobs is not None:
+                    output.next_token_logprobs = output.next_token_logprobs[
+                        torch.arange(len(next_token_ids), device=next_token_ids.device),
+                        next_token_ids,
+                    ].tolist()
+                    output.input_token_logprobs = output.input_token_logprobs.tolist()
+                    output.normalized_prompt_logprobs = (
+                        output.normalized_prompt_logprobs.tolist()
+                    )
 
-            next_token_ids = next_token_ids.tolist()
-        else:
-            next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
+                next_token_ids = next_token_ids.tolist()
+            else:
+                if self.tokenizer is None:
+                    next_token_ids = []
+                    for req in batch.reqs:
+                        next_token_ids.append(
+                            next(iter(req.sampling_params.stop_token_ids))
+                        )
+                else:
+                    next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
+
+            # Check finish conditions
+            pt = 0
+            for i, req in enumerate(batch.reqs):
+                if req is not self.current_inflight_req:
+                    # Inflight reqs' prefill is not finished
+                    req.completion_tokens_wo_jump_forward += 1
+                    req.output_ids.append(next_token_ids[i])
+                    req.check_finished()
+
+                if req.finished():
+                    self.tree_cache.cache_finished_req(req)
+                else:
+                    self.tree_cache.cache_unfinished_req(req)
 
-        # Check finish conditions
-        pt = 0
-        for i, req in enumerate(batch.reqs):
-            if req is not self.current_inflight_req:
-                req.completion_tokens_wo_jump_forward += 1
-                req.output_ids.append(next_token_ids[i])
-                req.check_finished()
+                if req is self.current_inflight_req:
+                    # Inflight request would get a new req idx
+                    self.req_to_token_pool.free(req.req_pool_idx)
 
-            if req.return_logprob:
-                self.add_logprob_return_values(i, req, pt, next_token_ids, output)
-                pt += req.extend_input_len
+                if req.return_logprob:
+                    self.add_logprob_return_values(i, req, pt, next_token_ids, output)
+                    pt += req.extend_input_len
+        else:
+            assert batch.extend_num_tokens != 0
+            output = self.model_runner.forward(batch, ForwardMode.EXTEND)
+            embeddings = output.embeddings.tolist()
+
+            # Check finish conditions
+            for i, req in enumerate(batch.reqs):
+                req.embedding = embeddings[i]
+                if req is not self.current_inflight_req:
+                    # Inflight reqs' prefill is not finished
+                    # dummy output token for embedding models
+                    req.output_ids.append(0)
+                    req.check_finished()
+
+                if req.finished():
+                    self.tree_cache.cache_finished_req(req)
+                else:
+                    self.tree_cache.cache_unfinished_req(req)
+
+                if req is self.current_inflight_req:
+                    # Inflight request would get a new req idx
+                    self.req_to_token_pool.free(req.req_pool_idx)
 
         self.handle_finished_requests(batch)
 
-    def add_logprob_return_values(self, i, req, pt, next_token_ids, output):
+    def add_logprob_return_values(
+        self,
+        i,
+        req: Req,
+        pt: int,
+        next_token_ids: List[int],
+        output: LogitProcessorOutput,
+    ):
         if req.normalized_prompt_logprob is None:
             req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
 
@@ -583,12 +536,12 @@ class ModelTpServer:
             req.input_token_logprobs = list(
                 zip(
                     output.input_token_logprobs[pt : pt + req.extend_input_len - 1],
-                    req.input_ids[-req.extend_input_len + 1 :],
+                    req.fill_ids[-req.extend_input_len + 1 :],
                 )
             )
             if req.logprob_start_len == 0:
                 req.input_token_logprobs = [
-                    (None, req.input_ids[0])
+                    (None, req.fill_ids[0])
                 ] + req.input_token_logprobs
 
         if req.last_update_decode_tokens != 0:
@@ -602,7 +555,7 @@ class ModelTpServer:
                             + req.extend_input_len
                            - 1
                         ],
-                        req.input_ids[-req.last_update_decode_tokens + 1 :],
+                        req.fill_ids[-req.last_update_decode_tokens + 1 :],
                     )
                 )
             )
@@ -623,24 +576,7 @@ class ModelTpServer:
                 )
             req.output_top_logprobs.append(output.output_top_logprobs[i])
 
-    def cache_filled_batch(self, batch: Batch):
-        req_pool_indices_cpu = batch.req_pool_indices.cpu().numpy()
-        for i, req in enumerate(batch.reqs):
-            new_prefix_indices, new_last_node = self.tree_cache.cache_req(
-                rid=req.rid,
-                token_ids=tuple(req.input_ids),
-                last_uncached_pos=len(req.prefix_indices),
-                req_pool_idx=req_pool_indices_cpu[i],
-                del_in_memory_pool=False,
-                old_last_node=req.last_node,
-            )
-            req.prefix_indices, req.last_node = new_prefix_indices, new_last_node
-
-            if req is self.current_inflight_req:
-                # inflight request would get a new req idx
-                self.req_to_token_pool.free(int(req_pool_indices_cpu[i]))
-
-    def forward_decode_batch(self, batch: Batch):
+    def forward_decode_batch(self, batch: ScheduleBatch):
         # Check if decode out of memory
         if not batch.check_decode_mem():
             old_ratio = self.new_token_ratio
@@ -690,6 +626,9 @@ class ModelTpServer:
             req.output_ids.append(next_token_id)
             req.check_finished()
 
+            if req.finished():
+                self.tree_cache.cache_finished_req(req)
+
             if req.return_logprob:
                 req.output_token_logprobs.append(
                     (next_token_logprobs[i], next_token_id)
@@ -699,22 +638,23 @@ class ModelTpServer:
 
         self.handle_finished_requests(batch)
 
-    def handle_finished_requests(self, batch: Batch):
+    def handle_finished_requests(self, batch: ScheduleBatch):
         output_rids = []
-        output_vids = []
-        decoded_texts = []
-        output_read_ids = []
-        output_read_offsets = []
-        output_skip_special_tokens = []
-        output_spaces_between_special_tokens = []
         output_meta_info = []
         output_finished_reason: List[BaseFinishReason] = []
-        finished_indices = []
+        if self.model_runner.is_generation:
+            output_vids = []
+            decoded_texts = []
+            output_read_ids = []
+            output_read_offsets = []
+            output_skip_special_tokens = []
+            output_spaces_between_special_tokens = []
+        else:  # for embedding model
+            output_embeddings = []
         unfinished_indices = []
+
         for i, req in enumerate(batch.reqs):
-            if req.finished():
-                finished_indices.append(i)
-            else:
+            if not req.finished() and req is not self.current_inflight_req:
                 unfinished_indices.append(i)
 
             if req.finished() or (
@@ -727,86 +667,75 @@ class ModelTpServer:
                 )
             ):
                 output_rids.append(req.rid)
-                output_vids.append(req.vid)
-                decoded_texts.append(req.decoded_text)
-                read_ids, read_offset = req.init_incremental_detokenize()
-                output_read_ids.append(read_ids)
-                output_read_offsets.append(read_offset)
-                output_skip_special_tokens.append(
-                    req.sampling_params.skip_special_tokens
-                )
-                output_spaces_between_special_tokens.append(
-                    req.sampling_params.spaces_between_special_tokens
-                )
-
-                meta_info = {
-                    "prompt_tokens": len(req.origin_input_ids),
-                    "completion_tokens": len(req.output_ids),
-                    "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
-                    "finish_reason": str(req.finished_reason),
-                }
-                if req.return_logprob:
-                    (
-                        meta_info["input_token_logprobs"],
-                        meta_info["output_token_logprobs"],
-                        meta_info["input_top_logprobs"],
-                        meta_info["output_top_logprobs"],
-                        meta_info["normalized_prompt_logprob"],
-                    ) = (
-                        req.input_token_logprobs,
-                        req.output_token_logprobs,
-                        req.input_top_logprobs,
-                        req.output_top_logprobs,
-                        req.normalized_prompt_logprob,
-                    )
-                output_meta_info.append(meta_info)
                 output_finished_reason.append(req.finished_reason)
+                if self.model_runner.is_generation:
+                    output_vids.append(req.vid)
+                    decoded_texts.append(req.decoded_text)
+                    read_ids, read_offset = req.init_incremental_detokenize()
+                    output_read_ids.append(read_ids)
+                    output_read_offsets.append(read_offset)
+                    output_skip_special_tokens.append(
+                        req.sampling_params.skip_special_tokens
+                    )
+                    output_spaces_between_special_tokens.append(
+                        req.sampling_params.spaces_between_special_tokens
+                    )
+
+                    meta_info = {
+                        "prompt_tokens": len(req.origin_input_ids),
+                        "completion_tokens": len(req.output_ids),
+                        "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
+                        "finish_reason": str(req.finished_reason),
+                    }
+                    if req.return_logprob:
+                        (
+                            meta_info["input_token_logprobs"],
+                            meta_info["output_token_logprobs"],
+                            meta_info["input_top_logprobs"],
+                            meta_info["output_top_logprobs"],
+                            meta_info["normalized_prompt_logprob"],
+                        ) = (
+                            req.input_token_logprobs,
+                            req.output_token_logprobs,
+                            req.input_top_logprobs,
+                            req.output_top_logprobs,
+                            req.normalized_prompt_logprob,
+                        )
+                    output_meta_info.append(meta_info)
+                else:  # for embedding model
+                    output_embeddings.append(req.embedding)
+                    meta_info = {
+                        "prompt_tokens": len(req.origin_input_ids),
+                    }
+                    output_meta_info.append(meta_info)
 
         # Send to detokenizer
         if output_rids:
-            self.out_pyobjs.append(
-                BatchTokenIDOut(
-                    output_rids,
-                    output_vids,
-                    decoded_texts,
-                    output_read_ids,
-                    output_read_offsets,
-                    output_skip_special_tokens,
-                    output_spaces_between_special_tokens,
-                    output_meta_info,
-                    output_finished_reason,
+            if self.model_runner.is_generation:
+                self.out_pyobjs.append(
+                    BatchTokenIDOut(
+                        output_rids,
+                        output_vids,
+                        decoded_texts,
+                        output_read_ids,
+                        output_read_offsets,
+                        output_skip_special_tokens,
+                        output_spaces_between_special_tokens,
+                        output_meta_info,
+                        output_finished_reason,
+                    )
                 )
-            )
-
-        # Remove finished reqs
-        if finished_indices:
-            # Update radix cache
-            req_pool_indices_cpu = batch.req_pool_indices.tolist()
-            for i in finished_indices:
-                req = batch.reqs[i]
-                self.tree_cache.cache_req(
-                    rid=req.rid,
-                    token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
-                    last_uncached_pos=len(req.prefix_indices),
-                    req_pool_idx=req_pool_indices_cpu[i],
+            else:  # for embedding model
+                self.out_pyobjs.append(
+                    BatchEmbeddingOut(
+                        output_rids,
+                        output_embeddings,
+                        output_meta_info,
+                        output_finished_reason,
+                    )
                 )
 
-                self.tree_cache.dec_lock_ref(req.last_node)
-
-        # Update batch tensors
-        if unfinished_indices:
-            batch.filter_batch(unfinished_indices)
-        else:
-            batch.reqs = []
-
-    def filter_out_inflight(self, batch: Batch):
-        # TODO(lsyin): reduce the overhead, make a special version for this
-        if self.current_inflight_req is None:
-            return
-
-        to_remove = batch.reqs.index(self.current_inflight_req)
-        unfinished_indices = [i for i in range(len(batch.reqs)) if i != to_remove]
-
+        # Remove finished reqs: update batch tensors
         batch.filter_batch(unfinished_indices)
 
     def flush_cache(self):
@@ -873,7 +802,11 @@ def run_tp_server(
 
 
 def launch_tp_servers(
-    gpu_ids, tp_rank_range, server_args, nccl_port, model_overide_args
+    gpu_ids: List[int],
+    tp_rank_range: List[int],
+    server_args: ServerArgs,
+    nccl_port: int,
+    model_overide_args: dict,
 ):
     """Launch multiple tensor parallel servers."""
     procs = []
@@ -888,7 +821,9 @@ def launch_tp_servers(
     return procs
 
 
-def broadcast_recv_input(data, rank, dist_group):
+def broadcast_recv_input(
+    data: Any, rank: int, dist_group: torch.distributed.ProcessGroup
+):
     """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
 
     if rank == 0: