sglang 0.4.5.post2__py3-none-any.whl → 0.4.6__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (99)
  1. sglang/bench_one_batch.py +19 -3
  2. sglang/bench_serving.py +8 -8
  3. sglang/compile_deep_gemm.py +177 -0
  4. sglang/lang/backend/openai.py +5 -1
  5. sglang/lang/backend/runtime_endpoint.py +5 -1
  6. sglang/srt/code_completion_parser.py +1 -1
  7. sglang/srt/configs/deepseekvl2.py +1 -1
  8. sglang/srt/configs/model_config.py +11 -2
  9. sglang/srt/constrained/llguidance_backend.py +78 -61
  10. sglang/srt/constrained/xgrammar_backend.py +1 -0
  11. sglang/srt/conversation.py +34 -1
  12. sglang/srt/disaggregation/decode.py +96 -5
  13. sglang/srt/disaggregation/mini_lb.py +113 -15
  14. sglang/srt/disaggregation/mooncake/conn.py +199 -32
  15. sglang/srt/disaggregation/nixl/__init__.py +1 -0
  16. sglang/srt/disaggregation/nixl/conn.py +622 -0
  17. sglang/srt/disaggregation/prefill.py +119 -20
  18. sglang/srt/disaggregation/utils.py +17 -0
  19. sglang/srt/entrypoints/engine.py +4 -0
  20. sglang/srt/entrypoints/http_server.py +11 -9
  21. sglang/srt/function_call_parser.py +132 -0
  22. sglang/srt/layers/activation.py +2 -2
  23. sglang/srt/layers/attention/base_attn_backend.py +3 -0
  24. sglang/srt/layers/attention/flashattention_backend.py +809 -160
  25. sglang/srt/layers/attention/flashmla_backend.py +8 -11
  26. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
  27. sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -5
  28. sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
  29. sglang/srt/layers/attention/vision.py +2 -0
  30. sglang/srt/layers/dp_attention.py +1 -1
  31. sglang/srt/layers/layernorm.py +42 -5
  32. sglang/srt/layers/logits_processor.py +2 -2
  33. sglang/srt/layers/moe/ep_moe/layer.py +2 -0
  34. sglang/srt/layers/moe/fused_moe_native.py +2 -4
  35. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +41 -41
  36. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  37. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +18 -15
  38. sglang/srt/layers/pooler.py +6 -0
  39. sglang/srt/layers/quantization/awq.py +5 -1
  40. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  41. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  42. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
  43. sglang/srt/layers/quantization/deep_gemm.py +385 -0
  44. sglang/srt/layers/quantization/fp8_kernel.py +7 -38
  45. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  46. sglang/srt/layers/quantization/gptq.py +13 -7
  47. sglang/srt/layers/quantization/int8_kernel.py +32 -1
  48. sglang/srt/layers/quantization/modelopt_quant.py +2 -2
  49. sglang/srt/layers/quantization/w8a8_int8.py +3 -3
  50. sglang/srt/layers/radix_attention.py +13 -3
  51. sglang/srt/layers/rotary_embedding.py +176 -132
  52. sglang/srt/layers/sampler.py +2 -2
  53. sglang/srt/managers/data_parallel_controller.py +17 -4
  54. sglang/srt/managers/io_struct.py +21 -3
  55. sglang/srt/managers/mm_utils.py +85 -28
  56. sglang/srt/managers/multimodal_processors/base_processor.py +14 -1
  57. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +9 -2
  58. sglang/srt/managers/multimodal_processors/gemma3.py +2 -5
  59. sglang/srt/managers/multimodal_processors/janus_pro.py +2 -2
  60. sglang/srt/managers/multimodal_processors/minicpm.py +4 -3
  61. sglang/srt/managers/multimodal_processors/qwen_vl.py +38 -13
  62. sglang/srt/managers/schedule_batch.py +42 -12
  63. sglang/srt/managers/scheduler.py +47 -26
  64. sglang/srt/managers/tokenizer_manager.py +120 -30
  65. sglang/srt/managers/tp_worker.py +1 -0
  66. sglang/srt/mem_cache/hiradix_cache.py +40 -32
  67. sglang/srt/mem_cache/memory_pool.py +118 -13
  68. sglang/srt/model_executor/cuda_graph_runner.py +16 -10
  69. sglang/srt/model_executor/forward_batch_info.py +51 -95
  70. sglang/srt/model_executor/model_runner.py +29 -27
  71. sglang/srt/models/deepseek.py +12 -2
  72. sglang/srt/models/deepseek_nextn.py +101 -6
  73. sglang/srt/models/deepseek_v2.py +153 -76
  74. sglang/srt/models/deepseek_vl2.py +9 -4
  75. sglang/srt/models/gemma3_causal.py +1 -1
  76. sglang/srt/models/llama4.py +0 -1
  77. sglang/srt/models/minicpm3.py +2 -2
  78. sglang/srt/models/minicpmo.py +22 -7
  79. sglang/srt/models/mllama4.py +2 -2
  80. sglang/srt/models/qwen2_5_vl.py +3 -6
  81. sglang/srt/models/qwen2_vl.py +3 -7
  82. sglang/srt/models/roberta.py +178 -0
  83. sglang/srt/openai_api/adapter.py +87 -10
  84. sglang/srt/openai_api/protocol.py +6 -1
  85. sglang/srt/server_args.py +65 -60
  86. sglang/srt/speculative/build_eagle_tree.py +2 -2
  87. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  88. sglang/srt/speculative/eagle_utils.py +2 -2
  89. sglang/srt/speculative/eagle_worker.py +2 -7
  90. sglang/srt/torch_memory_saver_adapter.py +10 -1
  91. sglang/srt/utils.py +48 -6
  92. sglang/test/runners.py +6 -13
  93. sglang/test/test_utils.py +39 -19
  94. sglang/version.py +1 -1
  95. {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/METADATA +6 -7
  96. {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/RECORD +99 -92
  97. {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/WHEEL +1 -1
  98. {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/licenses/LICENSE +0 -0
  99. {sglang-0.4.5.post2.dist-info → sglang-0.4.6.dist-info}/top_level.txt +0 -0
sglang/srt/managers/schedule_batch.py

@@ -285,6 +285,7 @@ class MultimodalInputs:
     num_image_tokens: Optional[int] = None

     # QWen2-VL related
+    mrope_positions: Optional[torch.Tensor] = None
     mrope_position_delta: Optional[torch.Tensor] = None

     # image
@@ -310,16 +311,12 @@ class MultimodalInputs:
         assert isinstance(ret.mm_items, list)
         ret.mm_items = [item for item in ret.mm_items if item.is_valid()]

-        assert len(ret.mm_items) != 0
-
-        # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
-        # Please note that if the `input_ids` is later used in the model forward,
-        # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
-        # errors in cuda kernels. See also llava.py for example.
         for item in ret.mm_items:
             item.set_pad_value()

         optional_args = [
+            "mrope_positions",
+            "mrope_position_delta",
             "im_token_id",
             "im_start_id",
             "im_end_id",
@@ -350,11 +347,6 @@ class MultimodalInputs:
         merge image inputs when requests are being merged
         """

-        # Use image hash as fake token_ids. We use this as the key for prefix matching in the radix cache.
-        # Please note that if the `input_ids` is later used in the model forward,
-        # you also need to clamp the values within the range of [0, vocab_size) to avoid out-of-bound
-        # errors in cuda kernels. See also llava.py for example.
-
         # args needed to be merged
         optional_args = [
             "mm_items",
@@ -364,6 +356,30 @@ class MultimodalInputs:
             self_arg = getattr(self, arg, None)
             if self_arg is not None:
                 setattr(self, arg, self_arg + getattr(other, arg))
+
+        mrope_positions = self.mrope_positions
+        if mrope_positions is not None:
+            if other.mrope_positions is None:
+                self.mrope_positions = mrope_positions
+            else:
+                self.mrope_positions = torch.cat(
+                    [self.mrope_positions, other.mrope_positions], dim=1
+                )
+
+        mrope_position_delta = self.mrope_position_delta
+        if mrope_position_delta is not None:
+            if other.mrope_position_delta is None:
+                self.mrope_position_delta = mrope_position_delta
+            else:
+                self.mrope_position_delta = torch.cat(
+                    [self.mrope_position_delta, other.mrope_position_delta], dim=0
+                )
+
+        for key, val in other.__dict__.items():
+            if "_id" in key:
+                # set token_ids
+                if getattr(self, key, None) is None:
+                    setattr(self, key, getattr(other, key, None))
         # other args would be kept intact

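Note on the merge logic added above: `mrope_positions` is concatenated along dim=1 (extending the sequence dimension) while `mrope_position_delta` is concatenated along dim=0 (stacking one delta row per request). A minimal sketch with hypothetical tensors, assuming the usual (3, seq_len) M-RoPE position layout; the shapes here are illustrative only:

import torch

# Hypothetical example tensors: (3, seq_len) positions, one delta row per request.
a_pos = torch.zeros(3, 5)    # positions of request A (5 tokens)
b_pos = torch.zeros(3, 7)    # positions of request B (7 tokens)
a_delta = torch.zeros(1, 1)
b_delta = torch.zeros(1, 1)

merged_pos = torch.cat([a_pos, b_pos], dim=1)        # -> (3, 12): sequences appended
merged_delta = torch.cat([a_delta, b_delta], dim=0)  # -> (2, 1): one delta per request

print(merged_pos.shape, merged_delta.shape)
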
@@ -388,6 +404,7 @@ class Req:
         return_hidden_states: bool = False,
         eos_token_ids: Optional[Set[int]] = None,
         bootstrap_host: Optional[str] = None,
+        bootstrap_port: Optional[int] = None,
         bootstrap_room: Optional[int] = None,
     ):
         # Input and output info
@@ -523,6 +540,7 @@ class Req:

         # For disaggregation
         self.bootstrap_host: str = bootstrap_host
+        self.bootstrap_port: Optional[int] = bootstrap_port
         self.bootstrap_room: Optional[int] = bootstrap_room
         self.disagg_kv_sender: Optional[BaseKVSender] = None

@@ -539,6 +557,11 @@ class Req:
         # The first output_id transferred from prefill instance.
         self.transferred_output_id: Optional[int] = None

+        # For overlap schedule, we delay the kv transfer until `process_batch_result_disagg_prefill` rather than `process_prefill_chunk` in non-overlap
+        # This is because kv is not ready in `process_prefill_chunk`.
+        # We use `tmp_end_idx` to store the end index of the kv cache to send.
+        self.tmp_end_idx: int = -1
+
     @property
     def seqlen(self):
         return len(self.origin_input_ids) + len(self.output_ids)
@@ -571,6 +594,14 @@ class Req:
             self.prefix_indices, self.last_node = tree_cache.match_prefix(
                 rid=self.rid, key=self.adjust_max_prefix_ids()
             )
+        elif enable_hierarchical_cache:
+            # in case last_node is evicted during scheduling, we need to update the prefix_indices
+            while self.last_node.evicted:
+                self.prefix_indices = self.prefix_indices[
+                    : -len(self.last_node.host_value)
+                ]
+                self.last_node = self.last_node.parent
+
         self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)

     def adjust_max_prefix_ids(self):
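The `elif enable_hierarchical_cache` branch above walks back up the radix tree when the matched node was evicted to host memory during scheduling, trimming the corresponding tail of `prefix_indices`. A minimal sketch of that walk-back with toy objects; the `Node` class below is hypothetical and only mirrors the `evicted`, `host_value`, and `parent` fields used in the diff:

from dataclasses import dataclass, field
from typing import List, Optional

@dataclass
class Node:
    evicted: bool
    host_value: List[int] = field(default_factory=list)
    parent: Optional["Node"] = None

root = Node(evicted=False)
mid = Node(evicted=True, host_value=[101, 102], parent=root)       # 2 tokens moved to host
leaf = Node(evicted=True, host_value=[103, 104, 105], parent=mid)  # 3 tokens moved to host

prefix_indices = list(range(10))  # 10 cached token slots matched earlier
last_node = leaf
while last_node.evicted:
    prefix_indices = prefix_indices[: -len(last_node.host_value)]
    last_node = last_node.parent

print(prefix_indices)  # [0, 1, 2, 3, 4] -- the 5 evicted tail slots were dropped
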
@@ -1437,7 +1468,6 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
         if self.model_config.is_encoder_decoder:
             self.encoder_lens = torch.cat([self.encoder_lens, other.encoder_lens])
             self.encoder_lens_cpu.extend(other.encoder_lens_cpu)
-
         self.req_pool_indices = torch.cat(
             [self.req_pool_indices, other.req_pool_indices]
         )
sglang/srt/managers/scheduler.py

@@ -60,7 +60,8 @@ from sglang.srt.managers.io_struct import (
     CloseSessionReqInput,
     ExpertDistributionReq,
     ExpertDistributionReqOutput,
-    FlushCacheReq,
+    FlushCacheReqInput,
+    FlushCacheReqOutput,
     GetInternalStateReq,
     GetInternalStateReqOutput,
     GetWeightsByNameReqInput,
@@ -402,7 +403,7 @@ class Scheduler(
            [
                (TokenizedGenerateReqInput, self.handle_generate_request),
                (TokenizedEmbeddingReqInput, self.handle_embedding_request),
-               (FlushCacheReq, self.flush_cache_wrapped),
+               (FlushCacheReqInput, self.flush_cache_wrapped),
                (AbortReq, self.abort_request),
                (OpenSessionReqInput, self.open_session),
                (CloseSessionReqInput, self.close_session),
@@ -488,6 +489,8 @@ class Scheduler(
                 tp_cache_group=self.tp_cpu_group,
                 page_size=self.page_size,
                 hicache_ratio=server_args.hicache_ratio,
+                hicache_size=server_args.hicache_size,
+                hicache_write_policy=server_args.hicache_write_policy,
             )
         else:
             self.tree_cache = RadixCache(
@@ -575,6 +578,10 @@ class Scheduler(
                bootstrap_port=self.server_args.disaggregation_bootstrap_port,
                transfer_backend=self.transfer_backend,
            )
+
+           # Metric for pre-allocation
+           self.num_tokens_pre_allocated = 0
+
        elif self.disaggregation_mode == DisaggregationMode.PREFILL:
            # *2 for the headroom.
            buffer_size = self.max_running_requests * 2
@@ -590,7 +597,7 @@ class Scheduler(
            )
            metadata_buffers = [output_id_buffer]

-           self.disagg_prefill_pending_queue = PrefillBootstrapQueue(
+           self.disagg_prefill_bootstrap_queue = PrefillBootstrapQueue(
                token_to_kv_pool=self.token_to_kv_pool_allocator.get_kvcache(),
                req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
                metadata_buffers=metadata_buffers,
@@ -784,6 +791,7 @@ class Scheduler(
                return_hidden_states=recv_req.return_hidden_states,
                eos_token_ids=self.model_config.hf_eos_token_id,
                bootstrap_host=recv_req.bootstrap_host,
+               bootstrap_port=recv_req.bootstrap_port,
                bootstrap_room=recv_req.bootstrap_room,
            )
            req.tokenizer = self.tokenizer
@@ -898,7 +906,7 @@ class Scheduler(
    def _add_request_to_queue(self, req: Req):
        req.queue_time_start = time.time()
        if self.disaggregation_mode == DisaggregationMode.PREFILL:
-           self.disagg_prefill_pending_queue.add(req)
+           self.disagg_prefill_bootstrap_queue.add(req)
        elif self.disaggregation_mode == DisaggregationMode.DECODE:
            self.disagg_decode_prealloc_queue.add(req)
        else:
@@ -988,8 +996,15 @@ class Scheduler(
            f"#cached-token: {adder.log_hit_tokens}, "
            f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
            f"#running-req: {running_bs}, "
-           f"#queue-req: {len(self.waiting_queue)}, "
        )
+
+       if self.disaggregation_mode == DisaggregationMode.PREFILL:
+           f += f"#unbootstrapped-req: {len(self.disagg_prefill_bootstrap_queue.queue)}, "
+           f += f"#queue-req: {len(self.waiting_queue)}, "
+           f += f"#transferring-req: {len(self.disagg_prefill_inflight_queue)} "
+       else:
+           f += f"#queue-req: {len(self.waiting_queue)}"
+
        logger.info(f)

        if self.enable_metrics:
@@ -1025,15 +1040,14 @@ class Scheduler(
            gap_latency / self.server_args.decode_log_interval
        )

+       msg = (
+           f"Decode batch. "
+           f"#running-req: {num_running_reqs}, "
+           f"#token: {num_used}, "
+           f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
+       )
+
        if self.spec_algorithm.is_none():
-           msg = (
-               f"Decode batch. "
-               f"#running-req: {num_running_reqs}, "
-               f"#token: {num_used}, "
-               f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
-               f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
-               f"#queue-req: {len(self.waiting_queue)}, "
-           )
            spec_accept_length = 0
        else:
            spec_accept_length = (
@@ -1042,15 +1056,15 @@
            self.cum_spec_accept_length += self.spec_num_total_accepted_tokens
            self.cum_spec_accept_count += self.spec_num_total_forward_ct
            self.spec_num_total_accepted_tokens = self.spec_num_total_forward_ct = 0
-           msg = (
-               f"Decode batch. "
-               f"#running-req: {num_running_reqs}, "
-               f"#token: {num_used}, "
-               f"token usage: {num_used / self.max_total_num_tokens:.2f}, "
-               f"accept len: {spec_accept_length:.2f}, "
-               f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
-               f"#queue-req: {len(self.waiting_queue)}, "
-           )
+           msg += f"accept len: {spec_accept_length:.2f}, "
+
+       if self.disaggregation_mode == DisaggregationMode.DECODE:
+           msg += f"pre-allocated usage: {self.num_tokens_pre_allocated / self.max_total_num_tokens:.2f}, "
+
+       msg += (
+           f"gen throughput (token/s): {self.last_gen_throughput:.2f}, "
+           f"#queue-req: {len(self.waiting_queue)}"
+       )

        logger.info(msg)
        if self.enable_metrics:
@@ -1596,8 +1610,9 @@ class Scheduler(
        time.sleep(5)
        self.parent_process.send_signal(signal.SIGQUIT)

-   def flush_cache_wrapped(self, recv_req: FlushCacheReq):
-       self.flush_cache()
+   def flush_cache_wrapped(self, recv_req: FlushCacheReqInput):
+       success = self.flush_cache()
+       return FlushCacheReqOutput(success=success)

    def flush_cache(self):
        """Flush the memory pool and cache."""
@@ -2010,9 +2025,15 @@ def run_scheduler_process(
            else:
                scheduler.event_loop_normal()
        elif disaggregation_mode == DisaggregationMode.PREFILL:
-           scheduler.event_loop_normal_disagg_prefill()
+           if scheduler.enable_overlap:
+               scheduler.event_loop_overlap_disagg_prefill()
+           else:
+               scheduler.event_loop_normal_disagg_prefill()
        elif disaggregation_mode == DisaggregationMode.DECODE:
-           scheduler.event_loop_normal_disagg_decode()
+           if scheduler.enable_overlap:
+               scheduler.event_loop_overlap_disagg_decode()
+           else:
+               scheduler.event_loop_normal_disagg_decode()

    except Exception:
        traceback = get_exception_traceback()
sglang/srt/managers/tokenizer_manager.py

@@ -66,7 +66,8 @@ from sglang.srt.managers.io_struct import (
    EmbeddingReqInput,
    ExpertDistributionReq,
    ExpertDistributionReqOutput,
-   FlushCacheReq,
+   FlushCacheReqInput,
+   FlushCacheReqOutput,
    GenerateReqInput,
    GetInternalStateReq,
    GetInternalStateReqOutput,
@@ -264,6 +265,9 @@ class TokenizerManager:
        self.resume_memory_occupation_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
+       self.flush_cache_communicator = _Communicator(
+           self.send_to_scheduler, server_args.dp_size
+       )
        self.start_profile_communicator = _Communicator(
            self.send_to_scheduler, server_args.dp_size
        )
@@ -314,6 +318,10 @@ class TokenizerManager:
                (
                    ResumeMemoryOccupationReqOutput,
                    self.resume_memory_occupation_communicator.handle_recv,
                ),
+               (
+                   FlushCacheReqOutput,
+                   self.flush_cache_communicator.handle_recv,
+               ),
                (
                    ProfileReqOutput,
                    self.start_profile_communicator.handle_recv,
@@ -411,42 +419,67 @@ class TokenizerManager:
            input_ids = self.tokenizer.encode(input_text)

        image_inputs: Dict = await self.mm_processor.process_mm_data_async(
-           obj.image_data, input_text or input_ids, obj, self.max_req_input_len
+           image_data=obj.image_data,
+           input_text=input_text or input_ids,
+           request_obj=obj,
+           max_req_input_len=self.max_req_input_len,
        )
        if image_inputs and "input_ids" in image_inputs:
            input_ids = image_inputs["input_ids"]
-       if self.is_generation:
-           return_logprob = obj.return_logprob
-           logprob_start_len = obj.logprob_start_len
-           top_logprobs_num = obj.top_logprobs_num
-           token_ids_logprob = obj.token_ids_logprob
-           session_params = (
-               SessionParams(**obj.session_params) if obj.session_params else None
-           )
+
+       self._validate_token_len(obj, input_ids)
+       return self._create_tokenized_object(
+           obj, input_text, input_ids, input_embeds, image_inputs
+       )
+
+   def _validate_token_len(
+       self, obj: Union[GenerateReqInput, EmbeddingReqInput], input_ids: List[int]
+   ) -> None:
+       """Validates that the input token count and the requested token count doesn't exceed the model's context length."""

        input_token_num = len(input_ids) if input_ids is not None else 0
+       # Check if input alone exceeds context length
        if input_token_num >= self.context_len:
            raise ValueError(
                f"The input ({input_token_num} tokens) is longer than the "
                f"model's context length ({self.context_len} tokens)."
            )

+       # Check total tokens (input + max_new_tokens)
+       max_new_tokens = obj.sampling_params.get("max_new_tokens")
        if (
-           obj.sampling_params.get("max_new_tokens") is not None
-           and obj.sampling_params.get("max_new_tokens") + input_token_num
-           >= self.context_len
+           max_new_tokens is not None
+           and (max_new_tokens + input_token_num) >= self.context_len
        ):
-           raise ValueError(
+           total_tokens = max_new_tokens + input_token_num
+           error_msg = (
                f"Requested token count exceeds the model's maximum context length "
-               f"of {self.context_len} tokens. You requested a total of "
-               f"{obj.sampling_params.get('max_new_tokens') + input_token_num} "
+               f"of {self.context_len} tokens. You requested a total of {total_tokens} "
                f"tokens: {input_token_num} tokens from the input messages and "
-               f"{obj.sampling_params.get('max_new_tokens')} tokens for the "
-               f"completion. Please reduce the number of tokens in the input "
-               f"messages or the completion to fit within the limit."
+               f"{max_new_tokens} tokens for the completion. Please reduce the number "
+               f"of tokens in the input messages or the completion to fit within the limit."
+           )
+           raise ValueError(error_msg)
+
+   def _create_tokenized_object(
+       self,
+       obj: Union[GenerateReqInput, EmbeddingReqInput],
+       input_text: str,
+       input_ids: List[int],
+       input_embeds: Optional[Union[List[float], None]] = None,
+       image_inputs: Optional[Dict] = None,
+   ) -> Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]:
+       """Create a tokenized request object from common parameters."""
+
+       if self.is_generation:
+           return_logprob = obj.return_logprob
+           logprob_start_len = obj.logprob_start_len
+           top_logprobs_num = obj.top_logprobs_num
+           token_ids_logprob = obj.token_ids_logprob
+           session_params = (
+               SessionParams(**obj.session_params) if obj.session_params else None
            )

-       # Parse sampling parameters
        sampling_params = SamplingParams(**obj.sampling_params)
        sampling_params.normalize(self.tokenizer)
        sampling_params.verify()
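The refactored `_validate_token_len` above enforces two bounds: the prompt alone must be shorter than the context window, and prompt plus `max_new_tokens` must also fit. A standalone sketch of the same arithmetic; the function name and numbers below are illustrative, not part of the package:

from typing import Optional

def check_request_fits(num_input_tokens: int, max_new_tokens: Optional[int], context_len: int) -> None:
    # The prompt alone must be shorter than the context window.
    if num_input_tokens >= context_len:
        raise ValueError(f"input ({num_input_tokens} tokens) >= context length ({context_len})")
    # The prompt plus the requested completion must also fit.
    if max_new_tokens is not None and num_input_tokens + max_new_tokens >= context_len:
        raise ValueError(
            f"requested total of {num_input_tokens + max_new_tokens} tokens "
            f">= context length ({context_len})"
        )

check_request_fits(3000, 512, 4096)    # fine: 3512 < 4096
# check_request_fits(3800, 512, 4096)  # would raise: 4312 >= 4096
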
@@ -465,6 +498,7 @@ class TokenizerManager:
                token_ids_logprob,
                obj.stream,
                bootstrap_host=obj.bootstrap_host,
+               bootstrap_port=obj.bootstrap_port,
                bootstrap_room=obj.bootstrap_room,
                lora_path=obj.lora_path,
                input_embeds=input_embeds,
@@ -483,6 +517,50 @@ class TokenizerManager:

        return tokenized_obj

+   async def _batch_tokenize_and_process(
+       self, batch_size: int, obj: Union[GenerateReqInput, EmbeddingReqInput]
+   ) -> List[Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]]:
+       """Handle batch tokenization for text inputs only."""
+       logger.debug(f"Starting batch tokenization for {batch_size} text requests")
+
+       # Collect requests and texts
+       requests = [obj[i] for i in range(batch_size)]
+       texts = [req.text for req in requests]
+
+       # Batch tokenize all texts
+       encoded = self.tokenizer(texts)
+       input_ids_list = encoded["input_ids"]
+
+       # Process all requests
+       tokenized_objs = []
+       for i, req in enumerate(requests):
+           self._validate_token_len(obj[i], input_ids_list[i])
+           tokenized_objs.append(
+               self._create_tokenized_object(
+                   req, req.text, input_ids_list[i], None, None
+               )
+           )
+       logger.debug(f"Completed batch processing for {batch_size} requests")
+       return tokenized_objs
+
+   def _validate_batch_tokenization_constraints(
+       self, batch_size: int, obj: Union[GenerateReqInput, EmbeddingReqInput]
+   ) -> None:
+       """Validate constraints for batch tokenization processing."""
+       for i in range(batch_size):
+           if self.is_generation and obj[i].image_data:
+               raise ValueError(
+                   "For image input processing do not set `enable_tokenizer_batch_encode`."
+               )
+           if obj[i].input_ids is not None:
+               raise ValueError(
+                   "Batch tokenization is not needed for pre-tokenized input_ids. Do not set `enable_tokenizer_batch_encode`."
+               )
+           if obj[i].input_embeds is not None:
+               raise ValueError(
+                   "Batch tokenization is not needed for input_embeds. Do not set `enable_tokenizer_batch_encode`."
+               )
+
    def _send_one_request(
        self,
        obj: Union[GenerateReqInput, EmbeddingReqInput],
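The new `_batch_tokenize_and_process` relies on the tokenizer being callable on a list of texts and returning per-request `input_ids`, which is standard Hugging Face tokenizer behavior. A small sketch of that batch call outside of SGLang; the model name here is just an example:

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2")  # any HF tokenizer with the same interface
texts = ["Hello world", "Disaggregated prefill and decode"]

encoded = tokenizer(texts)             # one pass over the whole batch
input_ids_list = encoded["input_ids"]  # list of token-id lists, one per request

for text, ids in zip(texts, input_ids_list):
    print(len(ids), "tokens for:", text)
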
@@ -560,14 +638,27 @@ class TokenizerManager:

        generators = []
        rids = []
+
        if getattr(obj, "parallel_sample_num", 1) == 1:
-           # Send all requests
-           for i in range(batch_size):
-               tmp_obj = obj[i]
-               tokenized_obj = await self._tokenize_one_request(tmp_obj)
-               self._send_one_request(tmp_obj, tokenized_obj, created_time)
-               generators.append(self._wait_one_response(tmp_obj, request))
-               rids.append(tmp_obj.rid)
+           if self.server_args.enable_tokenizer_batch_encode:
+               # Validate batch tokenization constraints
+               self._validate_batch_tokenization_constraints(batch_size, obj)
+
+               tokenized_objs = await self._batch_tokenize_and_process(batch_size, obj)
+
+               for i, tokenized_obj in enumerate(tokenized_objs):
+                   tmp_obj = obj[i]
+                   self._send_one_request(tmp_obj, tokenized_obj, created_time)
+                   generators.append(self._wait_one_response(tmp_obj, request))
+                   rids.append(tmp_obj.rid)
+           else:
+               # Sequential tokenization and processing
+               for i in range(batch_size):
+                   tmp_obj = obj[i]
+                   tokenized_obj = await self._tokenize_one_request(tmp_obj)
+                   self._send_one_request(tmp_obj, tokenized_obj, created_time)
+                   generators.append(self._wait_one_response(tmp_obj, request))
+                   rids.append(tmp_obj.rid)
        else:
            # FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
            if batch_size > 128:
@@ -628,9 +719,8 @@ class TokenizerManager:
            except StopAsyncIteration:
                pass

-   def flush_cache(self):
-       req = FlushCacheReq()
-       self.send_to_scheduler.send_pyobj(req)
+   async def flush_cache(self) -> FlushCacheReqOutput:
+       return (await self.flush_cache_communicator(FlushCacheReqInput()))[0]

    def abort_request(self, rid: str):
        if rid not in self.rid_to_state:
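With the change above, `TokenizerManager.flush_cache` is now a coroutine that round-trips a `FlushCacheReqInput` through the new communicator and returns a `FlushCacheReqOutput` carrying a `success` flag (set by the scheduler's `flush_cache_wrapped`). A hedged sketch of how a caller holding a `TokenizerManager` instance might consume it; the `tokenizer_manager` variable is assumed to exist:

import asyncio

async def flush_and_report(tokenizer_manager) -> bool:
    # flush_cache() is now awaitable and returns the scheduler's FlushCacheReqOutput.
    result = await tokenizer_manager.flush_cache()
    print("cache flushed:", result.success)
    return result.success

# asyncio.run(flush_and_report(tokenizer_manager))  # requires a running TokenizerManager
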
sglang/srt/managers/tp_worker.py

@@ -116,6 +116,7 @@ class TpModelWorker:
            ),
            self.model_runner.req_to_token_pool.size,
        )
+       assert self.max_running_requests > 0, "max_running_request is zero"
        self.max_req_len = min(
            self.model_config.context_len - 1,
            self.max_total_num_tokens - 1,