sglang 0.2.11__py3-none-any.whl → 0.2.12__py3-none-any.whl

This diff shows the contents of publicly released package versions as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registries.
Files changed (63)
  1. sglang/bench_latency.py +6 -4
  2. sglang/bench_serving.py +46 -22
  3. sglang/lang/compiler.py +2 -2
  4. sglang/lang/ir.py +3 -3
  5. sglang/srt/constrained/base_tool_cache.py +1 -1
  6. sglang/srt/constrained/fsm_cache.py +12 -2
  7. sglang/srt/layers/activation.py +33 -0
  8. sglang/srt/layers/{token_attention.py → decode_attention.py} +9 -5
  9. sglang/srt/layers/extend_attention.py +6 -1
  10. sglang/srt/layers/layernorm.py +65 -0
  11. sglang/srt/layers/logits_processor.py +5 -0
  12. sglang/srt/layers/pooler.py +50 -0
  13. sglang/srt/layers/{context_flashattention_nopad.py → prefill_attention.py} +5 -0
  14. sglang/srt/layers/radix_attention.py +2 -2
  15. sglang/srt/managers/detokenizer_manager.py +31 -9
  16. sglang/srt/managers/io_struct.py +63 -0
  17. sglang/srt/managers/policy_scheduler.py +173 -25
  18. sglang/srt/managers/schedule_batch.py +110 -87
  19. sglang/srt/managers/tokenizer_manager.py +193 -111
  20. sglang/srt/managers/tp_worker.py +289 -352
  21. sglang/srt/mem_cache/{base_cache.py → base_prefix_cache.py} +9 -4
  22. sglang/srt/mem_cache/chunk_cache.py +43 -20
  23. sglang/srt/mem_cache/memory_pool.py +2 -2
  24. sglang/srt/mem_cache/radix_cache.py +74 -40
  25. sglang/srt/model_executor/cuda_graph_runner.py +24 -9
  26. sglang/srt/model_executor/forward_batch_info.py +168 -105
  27. sglang/srt/model_executor/model_runner.py +24 -37
  28. sglang/srt/models/gemma2.py +0 -1
  29. sglang/srt/models/internlm2.py +2 -7
  30. sglang/srt/models/llama2.py +4 -4
  31. sglang/srt/models/llama_embedding.py +88 -0
  32. sglang/srt/models/qwen2_moe.py +0 -11
  33. sglang/srt/openai_api/adapter.py +155 -27
  34. sglang/srt/openai_api/protocol.py +37 -1
  35. sglang/srt/sampling/penaltylib/__init__.py +13 -0
  36. sglang/srt/sampling/penaltylib/orchestrator.py +357 -0
  37. sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +80 -0
  38. sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +105 -0
  39. sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +79 -0
  40. sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +83 -0
  41. sglang/srt/sampling_params.py +31 -4
  42. sglang/srt/server.py +69 -15
  43. sglang/srt/server_args.py +26 -19
  44. sglang/srt/utils.py +31 -13
  45. sglang/test/run_eval.py +10 -1
  46. sglang/test/runners.py +63 -63
  47. sglang/test/simple_eval_humaneval.py +2 -8
  48. sglang/test/simple_eval_mgsm.py +203 -0
  49. sglang/test/srt/sampling/penaltylib/utils.py +337 -0
  50. sglang/test/test_layernorm.py +60 -0
  51. sglang/test/test_programs.py +4 -2
  52. sglang/test/test_utils.py +20 -2
  53. sglang/utils.py +0 -1
  54. sglang/version.py +1 -1
  55. {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/METADATA +23 -14
  56. sglang-0.2.12.dist-info/RECORD +112 -0
  57. sglang/srt/layers/linear.py +0 -884
  58. sglang/srt/layers/quantization/__init__.py +0 -64
  59. sglang/srt/layers/quantization/fp8.py +0 -677
  60. sglang-0.2.11.dist-info/RECORD +0 -102
  61. {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/LICENSE +0 -0
  62. {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/WHEEL +0 -0
  63. {sglang-0.2.11.dist-info → sglang-0.2.12.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tp_worker.py
@@ -17,25 +17,30 @@ limitations under the License.
 
 import logging
 import multiprocessing
+import os
 import pickle
 import time
 import warnings
-from typing import List, Optional
+from typing import Any, List, Optional, Union
 
 import torch
+import torch.distributed
 import torch.distributed as dist
 
 from sglang.global_config import global_config
 from sglang.srt.constrained.fsm_cache import FSMCache
 from sglang.srt.constrained.jump_forward import JumpForwardCache
 from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
+from sglang.srt.layers.logits_processor import LogitProcessorOutput
 from sglang.srt.managers.io_struct import (
     AbortReq,
+    BatchEmbeddingOut,
     BatchTokenIDOut,
     FlushCacheReq,
+    TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
 )
-from sglang.srt.managers.policy_scheduler import PolicyScheduler
+from sglang.srt.managers.policy_scheduler import PolicyScheduler, PrefillAdder
 from sglang.srt.managers.schedule_batch import (
     FINISH_ABORT,
     BaseFinishReason,
@@ -59,6 +64,9 @@ from sglang.utils import get_exception_traceback
 logger = logging.getLogger(__name__)
 
 
+crash_on_warning = os.getenv("SGLANG_IS_IN_CI", "false") == "true"
+
+
 class ModelTpServer:
     def __init__(
         self,
@@ -98,26 +106,24 @@ class ModelTpServer:
             nccl_port=nccl_port,
             server_args=server_args,
         )
-
-        if is_multimodal_model(server_args.model_path):
-            self.processor = get_processor(
-                server_args.tokenizer_path,
-                tokenizer_mode=server_args.tokenizer_mode,
-                trust_remote_code=server_args.trust_remote_code,
-            )
-            self.tokenizer = self.processor.tokenizer
+        if server_args.skip_tokenizer_init:
+            self.tokenizer = self.processor = None
         else:
-            self.tokenizer = get_tokenizer(
-                server_args.tokenizer_path,
-                tokenizer_mode=server_args.tokenizer_mode,
-                trust_remote_code=server_args.trust_remote_code,
-            )
+            if is_multimodal_model(server_args.model_path):
+                self.processor = get_processor(
+                    server_args.tokenizer_path,
+                    tokenizer_mode=server_args.tokenizer_mode,
+                    trust_remote_code=server_args.trust_remote_code,
+                )
+                self.tokenizer = self.processor.tokenizer
+            else:
+                self.tokenizer = get_tokenizer(
+                    server_args.tokenizer_path,
+                    tokenizer_mode=server_args.tokenizer_mode,
+                    trust_remote_code=server_args.trust_remote_code,
+                )
         self.max_total_num_tokens = self.model_runner.max_total_num_tokens
-        self.max_prefill_tokens = (
-            16384
-            if server_args.max_prefill_tokens is None
-            else server_args.max_prefill_tokens
-        )
+        self.max_prefill_tokens = server_args.max_prefill_tokens
         self.max_running_requests = min(
             (
                 self.max_total_num_tokens // 2
@@ -160,13 +166,7 @@ class ModelTpServer:
             disable=server_args.disable_radix_cache,
         )
         self.tree_cache_metrics = {"total": 0, "hit": 0}
-        self.scheduler = PolicyScheduler(
-            self.schedule_policy,
-            self.max_running_requests,
-            self.max_prefill_tokens,
-            self.max_total_num_tokens,
-            self.tree_cache,
-        )
+        self.scheduler = PolicyScheduler(self.schedule_policy, self.tree_cache)
         self.req_to_token_pool = self.model_runner.req_to_token_pool
         self.token_to_kv_pool = self.model_runner.token_to_kv_pool
 
@@ -180,13 +180,15 @@ class ModelTpServer:
         self.last_stats_tic = time.time()
 
         # Init the FSM cache for constrained generation
-        self.regex_fsm_cache = FSMCache(
-            server_args.tokenizer_path,
-            {
-                "tokenizer_mode": server_args.tokenizer_mode,
-                "trust_remote_code": server_args.trust_remote_code,
-            },
-        )
+        if not server_args.skip_tokenizer_init:
+            self.regex_fsm_cache = FSMCache(
+                server_args.tokenizer_path,
+                {
+                    "tokenizer_mode": server_args.tokenizer_mode,
+                    "trust_remote_code": server_args.trust_remote_code,
+                },
+                skip_tokenizer_init=server_args.skip_tokenizer_init,
+            )
         self.jump_forward_cache = JumpForwardCache()
 
         # Init new token estimation
@@ -201,11 +203,13 @@ class ModelTpServer:
         self.new_token_ratio = self.min_new_token_ratio
         self.new_token_ratio_decay = global_config.new_token_ratio_decay
 
-    def exposed_step(self, recv_reqs):
+    def exposed_step(self, recv_reqs: List):
         try:
             # Recv requests
             for recv_req in recv_reqs:
-                if isinstance(recv_req, TokenizedGenerateReqInput):
+                if isinstance(
+                    recv_req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)
+                ):
                     self.handle_generate_request(recv_req)
                 elif isinstance(recv_req, FlushCacheReq):
                     self.flush_cache()
@@ -232,8 +236,6 @@ class ModelTpServer:
         if new_batch is not None:
             # Run a new prefill batch
             self.forward_prefill_batch(new_batch)
-            self.cache_filled_batch(new_batch)
-            self.filter_out_inflight(new_batch)
 
             if not new_batch.is_empty():
                 if self.running_batch is None:
@@ -250,7 +252,7 @@ class ModelTpServer:
 
                     # Print stats
                     if self.tp_rank == 0 and self.decode_forward_ct % 40 == 0:
-                        self.print_stats()
+                        self.print_decode_stats()
 
                     if self.running_batch.is_empty():
                         self.running_batch = None
@@ -262,7 +264,7 @@ class ModelTpServer:
                 self.check_memory()
                 self.new_token_ratio = global_config.init_new_token_ratio
 
-    def print_stats(self):
+    def print_decode_stats(self):
         num_used = self.max_total_num_tokens - (
             self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
         )
@@ -288,6 +290,7 @@ class ModelTpServer:
                 f"available_size={available_size}, max_total_num_tokens={self.max_total_num_tokens}\n"
                 "KV cache pool leak detected!"
             )
+            exit(1) if crash_on_warning else None
 
         if len(self.req_to_token_pool.free_slots) != self.req_to_token_pool.size:
             warnings.warn(
@@ -296,44 +299,46 @@ class ModelTpServer:
                 f"total slots={self.req_to_token_pool.size}\n"
                 "Memory pool leak detected!"
             )
+            exit(1) if crash_on_warning else None
 
     def handle_generate_request(
         self,
-        recv_req: TokenizedGenerateReqInput,
+        recv_req: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
     ):
         req = Req(recv_req.rid, recv_req.input_text, recv_req.input_ids)
-        req.pixel_values = recv_req.pixel_values
-        if req.pixel_values is not None:
-            req.pad_value = [
-                (recv_req.image_hash) % self.model_config.vocab_size,
-                (recv_req.image_hash >> 16) % self.model_config.vocab_size,
-                (recv_req.image_hash >> 32) % self.model_config.vocab_size,
-                (recv_req.image_hash >> 64) % self.model_config.vocab_size,
-            ]
-            req.image_size = recv_req.image_size
-            (
-                req.origin_input_ids,
-                req.image_offset,
-            ) = self.model_runner.model.pad_input_ids(
-                req.origin_input_ids_unpadded,
-                req.pad_value,
-                req.pixel_values.shape,
-                req.image_size,
-            )
-        req.sampling_params = recv_req.sampling_params
-        req.return_logprob = recv_req.return_logprob
-        req.logprob_start_len = recv_req.logprob_start_len
-        req.top_logprobs_num = recv_req.top_logprobs_num
-        req.stream = recv_req.stream
         req.tokenizer = self.tokenizer
-
-        # Init regex fsm
-        if req.sampling_params.regex is not None:
-            req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
-            if not self.disable_regex_jump_forward:
-                req.jump_forward_map = self.jump_forward_cache.query(
-                    req.sampling_params.regex
+        req.sampling_params = recv_req.sampling_params
+        if self.model_runner.is_generation:
+            req.pixel_values = recv_req.pixel_values
+            if req.pixel_values is not None:
+                req.pad_value = [
+                    (recv_req.image_hash) % self.model_config.vocab_size,
+                    (recv_req.image_hash >> 16) % self.model_config.vocab_size,
+                    (recv_req.image_hash >> 32) % self.model_config.vocab_size,
+                    (recv_req.image_hash >> 64) % self.model_config.vocab_size,
+                ]
+                req.image_size = recv_req.image_size
+                (
+                    req.origin_input_ids,
+                    req.image_offset,
+                ) = self.model_runner.model.pad_input_ids(
+                    req.origin_input_ids_unpadded,
+                    req.pad_value,
+                    req.pixel_values.shape,
+                    req.image_size,
                 )
+            req.return_logprob = recv_req.return_logprob
+            req.logprob_start_len = recv_req.logprob_start_len
+            req.top_logprobs_num = recv_req.top_logprobs_num
+            req.stream = recv_req.stream
+
+            # Init regex fsm
+            if req.sampling_params.regex is not None:
+                req.regex_fsm = self.regex_fsm_cache.query(req.sampling_params.regex)
+                if not self.disable_regex_jump_forward:
+                    req.jump_forward_map = self.jump_forward_cache.query(
+                        req.sampling_params.regex
+                    )
 
         # Truncate prompts that are too long
         if len(req.origin_input_ids) >= self.max_req_input_len:
@@ -342,186 +347,87 @@ class ModelTpServer:
                 "the max context length. Truncated!!!"
             )
             req.origin_input_ids = req.origin_input_ids[: self.max_req_input_len]
-        req.sampling_params.max_new_tokens = min(
-            (
-                req.sampling_params.max_new_tokens
-                if req.sampling_params.max_new_tokens is not None
-                else 1 << 30
-            ),
-            self.max_req_input_len - 1 - len(req.origin_input_ids),
-        )
+
+        if self.model_runner.is_generation:
+            req.sampling_params.max_new_tokens = min(
+                (
+                    req.sampling_params.max_new_tokens
+                    if req.sampling_params.max_new_tokens is not None
+                    else 1 << 30
+                ),
+                self.max_req_input_len - 1 - len(req.origin_input_ids),
+            )
+
         self.waiting_queue.append(req)
 
     def get_new_prefill_batch(self) -> Optional[ScheduleBatch]:
-        # TODO(lsyin): organize this function
         running_bs = (
             len(self.running_batch.reqs) if self.running_batch is not None else 0
         )
         if running_bs >= self.max_running_requests:
-            return
-
-        # Compute matched prefix length
-        for req in self.waiting_queue:
-            req.input_ids = req.origin_input_ids + req.output_ids
-            try_match_ids = req.input_ids
-            if req.return_logprob:
-                try_match_ids = req.input_ids[: req.logprob_start_len]
-            # NOTE: the prefix_indices must always be aligned with last_node
-            prefix_indices, last_node = self.tree_cache.match_prefix(
-                rid=req.rid, key=try_match_ids
-            )
-            req.extend_input_len = len(req.input_ids) - len(prefix_indices)
-            req.prefix_indices = prefix_indices
-            req.last_node = last_node
+            return None
 
         # Get priority queue
-        self.waiting_queue = self.scheduler.get_priority_queue(self.waiting_queue)
+        prefix_computed = self.scheduler.calc_priority(self.waiting_queue)
 
-        # Add requests if there is available space
-        can_run_list = []
-        new_batch_total_tokens = 0
-        new_batch_input_tokens = 0
-
-        available_size = (
-            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size()
+        adder = PrefillAdder(
+            self.tree_cache,
+            self.token_to_kv_pool.available_size() + self.tree_cache.evictable_size(),
+            self.max_prefill_tokens,
+            self.chunked_prefill_size,
         )
-        if self.running_batch:
-            available_size -= sum(
-                [
-                    (r.sampling_params.max_new_tokens - len(r.output_ids))
-                    * self.new_token_ratio
-                    for r in self.running_batch.reqs
-                ]
-            )
 
-        # Handle the current inflight request
-        take_inflight = 0
-        if self.current_inflight_req:
-            take_inflight = 1
-            r = self.current_inflight_req
-            r.input_ids = r.origin_input_ids + r.output_ids
-            truncated = (
-                len(r.input_ids) - len(r.prefix_indices) > self.chunked_prefill_size
+        if self.running_batch is not None:
+            adder.remove_running_tokens(self.running_batch, self.new_token_ratio)
+
+        has_inflight = self.current_inflight_req is not None
+        if self.current_inflight_req is not None:
+            self.current_inflight_req.init_next_round_input(
+                None if prefix_computed else self.tree_cache
             )
-            r.extend_input_len = min(
-                len(r.input_ids) - len(r.prefix_indices), self.chunked_prefill_size
+            self.current_inflight_req = adder.add_inflight_req(
+                self.current_inflight_req
             )
-            r.input_ids = r.input_ids[: len(r.prefix_indices) + r.extend_input_len]
-            can_run_list.append(r)
-
-            if not truncated:
-                # Finish inflight
-                self.current_inflight_req = None
-                new_batch_total_tokens += (
-                    r.extend_input_len + r.sampling_params.max_new_tokens
-                )
-                new_batch_input_tokens += r.extend_input_len
-            else:
-                new_batch_total_tokens += r.extend_input_len
-                new_batch_input_tokens += r.extend_input_len
 
         for req in self.waiting_queue:
-            if req.return_logprob and req.normalized_prompt_logprob is None:
-                # Need at least two tokens to compute normalized logprob
-                if req.extend_input_len < 2:
-                    delta = 2 - req.extend_input_len
-                    req.extend_input_len += delta
-                    req.prefix_indices = req.prefix_indices[:-delta]
-                    if req.image_offset is not None:
-                        req.image_offset += delta
-            if req.extend_input_len == 0 and req.sampling_params.max_new_tokens > 0:
-                # Need at least one token to compute logits
-                req.extend_input_len = 1
-                req.prefix_indices = req.prefix_indices[:-1]
-                if req.image_offset is not None:
-                    req.image_offset += 1
-
+            req.init_next_round_input(None if prefix_computed else self.tree_cache)
+            res = adder.add_one_req(req)
             if (
-                req.extend_input_len
-                + req.sampling_params.max_new_tokens
-                + new_batch_total_tokens
-                < available_size
-                and (
-                    req.extend_input_len + new_batch_input_tokens
-                    <= self.max_prefill_tokens
-                    or len(can_run_list) == 0
-                )
+                not res
+                or adder.no_remaining_tokens()
+                or running_bs + len(adder.can_run_list) >= self.max_running_requests
             ):
-                delta = self.tree_cache.inc_lock_ref(req.last_node)
-                available_size += delta
-
-                if not (
-                    req.extend_input_len
-                    + req.sampling_params.max_new_tokens
-                    + new_batch_total_tokens
-                    < available_size
-                ):
-                    # Undo locking
-                    delta = self.tree_cache.dec_lock_ref(req.last_node)
-                    available_size += delta
-                    break
-                else:
-                    # Add this request to the running batch
-                    if (
-                        self.chunked_prefill_size is None
-                        or (
-                            new_batch_input_tokens + req.extend_input_len
-                            <= self.chunked_prefill_size
-                        )
-                        or (
-                            req.return_logprob and req.normalized_prompt_logprob is None
-                        )
-                    ):
-                        can_run_list.append(req)
-                        new_batch_total_tokens += (
-                            req.extend_input_len + req.sampling_params.max_new_tokens
-                        )
-                        new_batch_input_tokens += req.extend_input_len
-                    else:
-                        trunc_len = self.chunked_prefill_size - new_batch_input_tokens
-
-                        if trunc_len <= 0:
-                            # Undo locking
-                            delta = self.tree_cache.dec_lock_ref(req.last_node)
-                            available_size += delta
-                            break
-
-                        req.extend_input_len = trunc_len
-                        req.input_ids = req.input_ids[
-                            : len(req.prefix_indices) + req.extend_input_len
-                        ]
-                        can_run_list.append(req)
-                        self.current_inflight_req = req
-                        new_batch_input_tokens += req.extend_input_len
-                        new_batch_total_tokens += req.extend_input_len
-                        break
-            else:
                 break
 
-            if running_bs + len(can_run_list) >= self.max_running_requests:
-                break
+        can_run_list = adder.can_run_list
+
+        if adder.new_inflight_req is not None:
+            assert self.current_inflight_req is None
+            self.current_inflight_req = adder.new_inflight_req
 
         if len(can_run_list) == 0:
             return None
 
         # Print stats
         if self.tp_rank == 0:
-            hit_tokens = sum(len(x.prefix_indices) for x in can_run_list)
-            self.tree_cache_metrics["total"] += (
-                hit_tokens + new_batch_input_tokens
-            ) / 10**9
-            self.tree_cache_metrics["hit"] += hit_tokens / 10**9
-            tree_cache_hit_rate = (
-                self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
-            )
+            if isinstance(self.tree_cache, RadixCache):
+                self.tree_cache_metrics["total"] += (
+                    adder.log_input_tokens + adder.log_hit_tokens
+                ) / 10**9
+                self.tree_cache_metrics["hit"] += (adder.log_hit_tokens) / 10**9
+                tree_cache_hit_rate = (
+                    self.tree_cache_metrics["hit"] / self.tree_cache_metrics["total"]
+                )
+            else:
+                tree_cache_hit_rate = 0.0
             logger.info(
                 f"[gpu={self.gpu_id}] Prefill batch. "
                 f"#new-seq: {len(can_run_list)}, "
-                f"#new-token: {new_batch_input_tokens}, "
-                f"#cached-token: {hit_tokens}, "
+                f"#new-token: {adder.log_input_tokens}, "
+                f"#cached-token: {adder.log_hit_tokens}, "
                 f"cache hit rate: {100.0 * tree_cache_hit_rate:.2f}%, "
                 f"#running-req: {running_bs}, "
-                f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + take_inflight}"
+                f"#queue-req: {len(self.waiting_queue) - len(can_run_list) + has_inflight}"
             )
 
         # Return the new batch
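Note: the hunk above swaps the hand-written admission loop in get_new_prefill_batch for the new PrefillAdder from sglang/srt/managers/policy_scheduler.py. The diff only shows the call sites (add_one_req, add_inflight_req, no_remaining_tokens, can_run_list, log_input_tokens/log_hit_tokens), so the following is a minimal, hypothetical sketch of the token-budget bookkeeping such an adder performs; SimplePrefillAdder and _Req are illustrative names, and this is not the sglang implementation (which also handles radix-cache locking and chunked prefill).

    # Illustrative sketch only; see sglang/srt/managers/policy_scheduler.py for the real PrefillAdder.
    from dataclasses import dataclass, field
    from typing import List


    @dataclass
    class _Req:
        # Hypothetical stand-in for sglang's Req: tokens left to prefill,
        # the decode budget reserved for the request, and cached prefix tokens.
        extend_input_len: int
        max_new_tokens: int
        cached_tokens: int = 0


    @dataclass
    class SimplePrefillAdder:
        rem_total_tokens: int   # KV-cache tokens still available
        rem_input_tokens: int   # prefill-token budget for this batch
        can_run_list: List[_Req] = field(default_factory=list)
        log_input_tokens: int = 0
        log_hit_tokens: int = 0

        def no_remaining_tokens(self) -> bool:
            return self.rem_total_tokens <= 0 or self.rem_input_tokens <= 0

        def add_one_req(self, req: _Req) -> bool:
            # Admit the request only if both its prefill tokens and its reserved
            # decode tokens fit inside the remaining budgets.
            needed_total = req.extend_input_len + req.max_new_tokens
            if (
                needed_total > self.rem_total_tokens
                or req.extend_input_len > self.rem_input_tokens
            ):
                return False
            self.rem_total_tokens -= needed_total
            self.rem_input_tokens -= req.extend_input_len
            self.can_run_list.append(req)
            self.log_input_tokens += req.extend_input_len
            self.log_hit_tokens += req.cached_tokens
            return True


    if __name__ == "__main__":
        adder = SimplePrefillAdder(rem_total_tokens=4096, rem_input_tokens=2048)
        waiting = [_Req(512, 128, cached_tokens=64), _Req(3000, 256), _Req(700, 64)]
        for r in waiting:
            if not adder.add_one_req(r) or adder.no_remaining_tokens():
                break
        print(len(adder.can_run_list), adder.log_input_tokens, adder.log_hit_tokens)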
@@ -540,41 +446,88 @@ class ModelTpServer:
             self.model_config.vocab_size, self.int_token_logit_bias
         )
 
-        # Forward and sample the next tokens
-        if batch.extend_num_tokens != 0:
-            output = self.model_runner.forward(batch, ForwardMode.EXTEND)
-            next_token_ids = batch.sample(output.next_token_logits)
-
-            # Move logprobs to cpu
-            if output.next_token_logprobs is not None:
-                output.next_token_logprobs = output.next_token_logprobs[
-                    torch.arange(len(next_token_ids), device=next_token_ids.device),
-                    next_token_ids,
-                ].tolist()
-                output.input_token_logprobs = output.input_token_logprobs.tolist()
-                output.normalized_prompt_logprobs = (
-                    output.normalized_prompt_logprobs.tolist()
-                )
+        if self.model_runner.is_generation:
+            # Forward and sample the next tokens
+            if batch.extend_num_tokens != 0:
+                output = self.model_runner.forward(batch, ForwardMode.EXTEND)
+                next_token_ids = batch.sample(output.next_token_logits)
+
+                # Move logprobs to cpu
+                if output.next_token_logprobs is not None:
+                    output.next_token_logprobs = output.next_token_logprobs[
+                        torch.arange(len(next_token_ids), device=next_token_ids.device),
+                        next_token_ids,
+                    ].tolist()
+                    output.input_token_logprobs = output.input_token_logprobs.tolist()
+                    output.normalized_prompt_logprobs = (
+                        output.normalized_prompt_logprobs.tolist()
+                    )
 
-            next_token_ids = next_token_ids.tolist()
-        else:
-            next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
+                next_token_ids = next_token_ids.tolist()
+            else:
+                if self.tokenizer is None:
+                    next_token_ids = []
+                    for req in batch.reqs:
+                        next_token_ids.append(
+                            next(iter(req.sampling_params.stop_token_ids))
+                        )
+                else:
+                    next_token_ids = [self.tokenizer.eos_token_id] * len(batch.reqs)
+
+            # Check finish conditions
+            pt = 0
+            for i, req in enumerate(batch.reqs):
+                if req is not self.current_inflight_req:
+                    # Inflight reqs' prefill is not finished
+                    req.completion_tokens_wo_jump_forward += 1
+                    req.output_ids.append(next_token_ids[i])
+                    req.check_finished()
+
+                if req.finished():
+                    self.tree_cache.cache_finished_req(req)
+                else:
+                    self.tree_cache.cache_unfinished_req(req)
 
-        # Check finish conditions
-        pt = 0
-        for i, req in enumerate(batch.reqs):
-            if req is not self.current_inflight_req:
-                req.completion_tokens_wo_jump_forward += 1
-                req.output_ids.append(next_token_ids[i])
-                req.check_finished()
+                if req is self.current_inflight_req:
+                    # Inflight request would get a new req idx
+                    self.req_to_token_pool.free(req.req_pool_idx)
 
-            if req.return_logprob:
-                self.add_logprob_return_values(i, req, pt, next_token_ids, output)
-                pt += req.extend_input_len
+                if req.return_logprob:
+                    self.add_logprob_return_values(i, req, pt, next_token_ids, output)
+                    pt += req.extend_input_len
+        else:
+            assert batch.extend_num_tokens != 0
+            output = self.model_runner.forward(batch, ForwardMode.EXTEND)
+            embeddings = output.embeddings.tolist()
+
+            # Check finish conditions
+            for i, req in enumerate(batch.reqs):
+                req.embedding = embeddings[i]
+                if req is not self.current_inflight_req:
+                    # Inflight reqs' prefill is not finished
+                    # dummy output token for embedding models
+                    req.output_ids.append(0)
+                    req.check_finished()
+
+                if req.finished():
+                    self.tree_cache.cache_finished_req(req)
+                else:
+                    self.tree_cache.cache_unfinished_req(req)
+
+                if req is self.current_inflight_req:
+                    # Inflight request would get a new req idx
+                    self.req_to_token_pool.free(req.req_pool_idx)
 
         self.handle_finished_requests(batch)
 
-    def add_logprob_return_values(self, i, req, pt, next_token_ids, output):
+    def add_logprob_return_values(
+        self,
+        i,
+        req: Req,
+        pt: int,
+        next_token_ids: List[int],
+        output: LogitProcessorOutput,
+    ):
         if req.normalized_prompt_logprob is None:
             req.normalized_prompt_logprob = output.normalized_prompt_logprobs[i]
 
@@ -583,12 +536,12 @@ class ModelTpServer:
         if req.input_token_logprobs is None:
             req.input_token_logprobs = list(
                 zip(
                     output.input_token_logprobs[pt : pt + req.extend_input_len - 1],
-                    req.input_ids[-req.extend_input_len + 1 :],
+                    req.fill_ids[-req.extend_input_len + 1 :],
                 )
             )
             if req.logprob_start_len == 0:
                 req.input_token_logprobs = [
-                    (None, req.input_ids[0])
+                    (None, req.fill_ids[0])
                 ] + req.input_token_logprobs
 
         if req.last_update_decode_tokens != 0:
@@ -602,7 +555,7 @@ class ModelTpServer:
                             + req.extend_input_len
                             - 1
                         ],
-                        req.input_ids[-req.last_update_decode_tokens + 1 :],
+                        req.fill_ids[-req.last_update_decode_tokens + 1 :],
                     )
                 )
             )
623
576
  )
624
577
  req.output_top_logprobs.append(output.output_top_logprobs[i])
625
578
 
626
- def cache_filled_batch(self, batch: ScheduleBatch):
627
- for i, req in enumerate(batch.reqs):
628
- new_prefix_indices, new_last_node = self.tree_cache.cache_req(
629
- rid=req.rid,
630
- token_ids=tuple(req.input_ids),
631
- last_uncached_pos=len(req.prefix_indices),
632
- req_pool_idx=req.req_pool_idx,
633
- del_in_memory_pool=False,
634
- old_last_node=req.last_node,
635
- )
636
- req.prefix_indices, req.last_node = new_prefix_indices, new_last_node
637
-
638
- if req is self.current_inflight_req:
639
- # inflight request would get a new req idx
640
- self.req_to_token_pool.free(req.req_pool_idx)
641
-
642
579
  def forward_decode_batch(self, batch: ScheduleBatch):
643
580
  # Check if decode out of memory
644
581
  if not batch.check_decode_mem():
@@ -689,6 +626,9 @@ class ModelTpServer:
689
626
  req.output_ids.append(next_token_id)
690
627
  req.check_finished()
691
628
 
629
+ if req.finished():
630
+ self.tree_cache.cache_finished_req(req)
631
+
692
632
  if req.return_logprob:
693
633
  req.output_token_logprobs.append(
694
634
  (next_token_logprobs[i], next_token_id)
@@ -700,20 +640,21 @@ class ModelTpServer:
700
640
 
701
641
  def handle_finished_requests(self, batch: ScheduleBatch):
702
642
  output_rids = []
703
- output_vids = []
704
- decoded_texts = []
705
- output_read_ids = []
706
- output_read_offsets = []
707
- output_skip_special_tokens = []
708
- output_spaces_between_special_tokens = []
709
643
  output_meta_info = []
710
644
  output_finished_reason: List[BaseFinishReason] = []
711
- finished_indices = []
645
+ if self.model_runner.is_generation:
646
+ output_vids = []
647
+ decoded_texts = []
648
+ output_read_ids = []
649
+ output_read_offsets = []
650
+ output_skip_special_tokens = []
651
+ output_spaces_between_special_tokens = []
652
+ else: # for embedding model
653
+ output_embeddings = []
712
654
  unfinished_indices = []
655
+
713
656
  for i, req in enumerate(batch.reqs):
714
- if req.finished():
715
- finished_indices.append(i)
716
- else:
657
+ if not req.finished() and req is not self.current_inflight_req:
717
658
  unfinished_indices.append(i)
718
659
 
719
660
  if req.finished() or (
@@ -726,85 +667,75 @@ class ModelTpServer:
726
667
  )
727
668
  ):
728
669
  output_rids.append(req.rid)
729
- output_vids.append(req.vid)
730
- decoded_texts.append(req.decoded_text)
731
- read_ids, read_offset = req.init_incremental_detokenize()
732
- output_read_ids.append(read_ids)
733
- output_read_offsets.append(read_offset)
734
- output_skip_special_tokens.append(
735
- req.sampling_params.skip_special_tokens
736
- )
737
- output_spaces_between_special_tokens.append(
738
- req.sampling_params.spaces_between_special_tokens
739
- )
740
-
741
- meta_info = {
742
- "prompt_tokens": len(req.origin_input_ids),
743
- "completion_tokens": len(req.output_ids),
744
- "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
745
- "finish_reason": str(req.finished_reason),
746
- }
747
- if req.return_logprob:
748
- (
749
- meta_info["input_token_logprobs"],
750
- meta_info["output_token_logprobs"],
751
- meta_info["input_top_logprobs"],
752
- meta_info["output_top_logprobs"],
753
- meta_info["normalized_prompt_logprob"],
754
- ) = (
755
- req.input_token_logprobs,
756
- req.output_token_logprobs,
757
- req.input_top_logprobs,
758
- req.output_top_logprobs,
759
- req.normalized_prompt_logprob,
760
- )
761
- output_meta_info.append(meta_info)
762
670
  output_finished_reason.append(req.finished_reason)
671
+ if self.model_runner.is_generation:
672
+ output_vids.append(req.vid)
673
+ decoded_texts.append(req.decoded_text)
674
+ read_ids, read_offset = req.init_incremental_detokenize()
675
+ output_read_ids.append(read_ids)
676
+ output_read_offsets.append(read_offset)
677
+ output_skip_special_tokens.append(
678
+ req.sampling_params.skip_special_tokens
679
+ )
680
+ output_spaces_between_special_tokens.append(
681
+ req.sampling_params.spaces_between_special_tokens
682
+ )
683
+
684
+ meta_info = {
685
+ "prompt_tokens": len(req.origin_input_ids),
686
+ "completion_tokens": len(req.output_ids),
687
+ "completion_tokens_wo_jump_forward": req.completion_tokens_wo_jump_forward,
688
+ "finish_reason": str(req.finished_reason),
689
+ }
690
+ if req.return_logprob:
691
+ (
692
+ meta_info["input_token_logprobs"],
693
+ meta_info["output_token_logprobs"],
694
+ meta_info["input_top_logprobs"],
695
+ meta_info["output_top_logprobs"],
696
+ meta_info["normalized_prompt_logprob"],
697
+ ) = (
698
+ req.input_token_logprobs,
699
+ req.output_token_logprobs,
700
+ req.input_top_logprobs,
701
+ req.output_top_logprobs,
702
+ req.normalized_prompt_logprob,
703
+ )
704
+ output_meta_info.append(meta_info)
705
+ else: # for embedding model
706
+ output_embeddings.append(req.embedding)
707
+ meta_info = {
708
+ "prompt_tokens": len(req.origin_input_ids),
709
+ }
710
+ output_meta_info.append(meta_info)
763
711
 
764
712
  # Send to detokenizer
765
713
  if output_rids:
766
- self.out_pyobjs.append(
767
- BatchTokenIDOut(
768
- output_rids,
769
- output_vids,
770
- decoded_texts,
771
- output_read_ids,
772
- output_read_offsets,
773
- output_skip_special_tokens,
774
- output_spaces_between_special_tokens,
775
- output_meta_info,
776
- output_finished_reason,
714
+ if self.model_runner.is_generation:
715
+ self.out_pyobjs.append(
716
+ BatchTokenIDOut(
717
+ output_rids,
718
+ output_vids,
719
+ decoded_texts,
720
+ output_read_ids,
721
+ output_read_offsets,
722
+ output_skip_special_tokens,
723
+ output_spaces_between_special_tokens,
724
+ output_meta_info,
725
+ output_finished_reason,
726
+ )
777
727
  )
778
- )
779
-
780
- # Remove finished reqs
781
- if finished_indices:
782
- # Update radix cache
783
- for i in finished_indices:
784
- req = batch.reqs[i]
785
- self.tree_cache.cache_req(
786
- rid=req.rid,
787
- token_ids=tuple(req.origin_input_ids + req.output_ids)[:-1],
788
- last_uncached_pos=len(req.prefix_indices),
789
- req_pool_idx=req.req_pool_idx,
728
+ else: # for embedding model
729
+ self.out_pyobjs.append(
730
+ BatchEmbeddingOut(
731
+ output_rids,
732
+ output_embeddings,
733
+ output_meta_info,
734
+ output_finished_reason,
735
+ )
790
736
  )
791
737
 
792
- self.tree_cache.dec_lock_ref(req.last_node)
793
-
794
- # Update batch tensors
795
- if unfinished_indices:
796
- batch.filter_batch(unfinished_indices)
797
- else:
798
- batch.reqs = []
799
-
800
- def filter_out_inflight(self, batch: ScheduleBatch):
801
- # TODO(lsyin): reduce the overhead, make a special version for this
802
- if self.current_inflight_req is None:
803
- return
804
-
805
- to_remove = batch.reqs.index(self.current_inflight_req)
806
- unfinished_indices = [i for i in range(len(batch.reqs)) if i != to_remove]
807
-
738
+ # Remove finished reqs: update batch tensors
808
739
  batch.filter_batch(unfinished_indices)
809
740
 
810
741
  def flush_cache(self):
@@ -871,7 +802,11 @@ def run_tp_server(
871
802
 
872
803
 
873
804
  def launch_tp_servers(
874
- gpu_ids, tp_rank_range, server_args, nccl_port, model_overide_args
805
+ gpu_ids: List[int],
806
+ tp_rank_range: List[int],
807
+ server_args: ServerArgs,
808
+ nccl_port: int,
809
+ model_overide_args: dict,
875
810
  ):
876
811
  """Launch multiple tensor parallel servers."""
877
812
  procs = []
@@ -886,7 +821,9 @@ def launch_tp_servers(
886
821
  return procs
887
822
 
888
823
 
889
- def broadcast_recv_input(data, rank, dist_group):
824
+ def broadcast_recv_input(
825
+ data: Any, rank: int, dist_group: torch.distributed.ProcessGroup
826
+ ):
890
827
  """Broadcast inputs from rank=0 to all other ranks with torch.dist backend."""
891
828
 
892
829
  if rank == 0:
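Note: the last hunk only types the broadcast_recv_input signature (data: Any, rank: int, dist_group: torch.distributed.ProcessGroup) and shows the start of the rank-0 branch. As a hedged illustration of the usual pickle-then-broadcast pattern behind helpers like this, here is a self-contained sketch; broadcast_pyobj is a made-up name, the body is not taken from sglang, and it assumes an already-initialized CPU (gloo) process group.

    # Illustrative sketch of broadcasting an arbitrary Python object with torch.distributed.
    import pickle
    from typing import Any

    import numpy as np
    import torch
    import torch.distributed as dist


    def broadcast_pyobj(data: Any, rank: int, dist_group=None) -> Any:
        """Send a picklable object from rank 0 to every other rank (hypothetical helper)."""
        if rank == 0:
            # Serialize on the sender and ship (length, payload) as two broadcasts.
            buf = np.frombuffer(pickle.dumps(data), dtype=np.uint8)
            payload = torch.from_numpy(buf.copy())
            size = torch.tensor([payload.numel()], dtype=torch.long)
            dist.broadcast(size, src=0, group=dist_group)
            dist.broadcast(payload, src=0, group=dist_group)
            return data
        else:
            # Receive the length first so the payload buffer can be allocated.
            size = torch.tensor([0], dtype=torch.long)
            dist.broadcast(size, src=0, group=dist_group)
            payload = torch.empty(int(size.item()), dtype=torch.uint8)
            dist.broadcast(payload, src=0, group=dist_group)
            return pickle.loads(payload.numpy().tobytes())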