sglang 0.4.5.post1__py3-none-any.whl → 0.4.5.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. sglang/__init__.py +2 -4
  2. sglang/bench_one_batch.py +2 -2
  3. sglang/bench_serving.py +3 -6
  4. sglang/compile_deep_gemm.py +136 -0
  5. sglang/lang/backend/anthropic.py +0 -4
  6. sglang/lang/backend/base_backend.py +1 -1
  7. sglang/lang/backend/openai.py +6 -2
  8. sglang/lang/backend/runtime_endpoint.py +5 -1
  9. sglang/lang/backend/vertexai.py +0 -1
  10. sglang/lang/compiler.py +1 -7
  11. sglang/lang/tracer.py +3 -7
  12. sglang/srt/_custom_ops.py +0 -2
  13. sglang/srt/configs/model_config.py +4 -1
  14. sglang/srt/constrained/outlines_jump_forward.py +14 -1
  15. sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
  16. sglang/srt/constrained/xgrammar_backend.py +27 -4
  17. sglang/srt/custom_op.py +0 -62
  18. sglang/srt/disaggregation/decode.py +105 -6
  19. sglang/srt/disaggregation/mini_lb.py +74 -9
  20. sglang/srt/disaggregation/mooncake/conn.py +33 -63
  21. sglang/srt/disaggregation/mooncake/transfer_engine.py +30 -61
  22. sglang/srt/disaggregation/nixl/__init__.py +1 -0
  23. sglang/srt/disaggregation/nixl/conn.py +622 -0
  24. sglang/srt/disaggregation/prefill.py +137 -17
  25. sglang/srt/disaggregation/utils.py +32 -0
  26. sglang/srt/entrypoints/engine.py +4 -0
  27. sglang/srt/entrypoints/http_server.py +3 -7
  28. sglang/srt/entrypoints/verl_engine.py +7 -5
  29. sglang/srt/function_call_parser.py +60 -0
  30. sglang/srt/layers/activation.py +6 -8
  31. sglang/srt/layers/attention/flashattention_backend.py +883 -209
  32. sglang/srt/layers/attention/flashinfer_backend.py +5 -2
  33. sglang/srt/layers/attention/torch_native_backend.py +6 -1
  34. sglang/srt/layers/attention/triton_backend.py +6 -0
  35. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
  36. sglang/srt/layers/attention/triton_ops/extend_attention.py +18 -7
  37. sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
  38. sglang/srt/layers/dp_attention.py +1 -1
  39. sglang/srt/layers/layernorm.py +20 -5
  40. sglang/srt/layers/linear.py +17 -3
  41. sglang/srt/layers/moe/ep_moe/layer.py +17 -29
  42. sglang/srt/layers/moe/fused_moe_native.py +4 -0
  43. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +14 -19
  44. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  45. sglang/srt/layers/moe/topk.py +27 -30
  46. sglang/srt/layers/parameter.py +0 -2
  47. sglang/srt/layers/quantization/__init__.py +1 -0
  48. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  49. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +9 -2
  50. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +16 -44
  51. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  52. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
  53. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
  54. sglang/srt/layers/quantization/deep_gemm.py +378 -0
  55. sglang/srt/layers/quantization/fp8.py +115 -132
  56. sglang/srt/layers/quantization/fp8_kernel.py +213 -88
  57. sglang/srt/layers/quantization/fp8_utils.py +189 -264
  58. sglang/srt/layers/quantization/gptq.py +13 -7
  59. sglang/srt/layers/quantization/modelopt_quant.py +2 -2
  60. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  61. sglang/srt/layers/quantization/utils.py +5 -11
  62. sglang/srt/layers/quantization/w8a8_fp8.py +2 -0
  63. sglang/srt/layers/quantization/w8a8_int8.py +7 -7
  64. sglang/srt/layers/radix_attention.py +15 -0
  65. sglang/srt/layers/rotary_embedding.py +9 -8
  66. sglang/srt/layers/sampler.py +7 -12
  67. sglang/srt/lora/backend/base_backend.py +18 -2
  68. sglang/srt/lora/backend/flashinfer_backend.py +1 -1
  69. sglang/srt/lora/backend/triton_backend.py +1 -1
  70. sglang/srt/lora/layers.py +1 -1
  71. sglang/srt/lora/lora.py +1 -1
  72. sglang/srt/lora/lora_manager.py +1 -1
  73. sglang/srt/managers/data_parallel_controller.py +7 -1
  74. sglang/srt/managers/detokenizer_manager.py +0 -1
  75. sglang/srt/managers/io_struct.py +15 -3
  76. sglang/srt/managers/mm_utils.py +4 -3
  77. sglang/srt/managers/multimodal_processor.py +0 -2
  78. sglang/srt/managers/multimodal_processors/base_processor.py +3 -2
  79. sglang/srt/managers/schedule_batch.py +15 -4
  80. sglang/srt/managers/scheduler.py +28 -77
  81. sglang/srt/managers/tokenizer_manager.py +116 -29
  82. sglang/srt/managers/tp_worker.py +1 -0
  83. sglang/srt/mem_cache/hiradix_cache.py +41 -29
  84. sglang/srt/mem_cache/memory_pool.py +38 -15
  85. sglang/srt/model_executor/cuda_graph_runner.py +15 -10
  86. sglang/srt/model_executor/model_runner.py +39 -31
  87. sglang/srt/models/bert.py +398 -0
  88. sglang/srt/models/deepseek.py +1 -1
  89. sglang/srt/models/deepseek_nextn.py +74 -70
  90. sglang/srt/models/deepseek_v2.py +292 -348
  91. sglang/srt/models/llama.py +5 -5
  92. sglang/srt/models/minicpm3.py +31 -203
  93. sglang/srt/models/minicpmo.py +17 -6
  94. sglang/srt/models/qwen2.py +4 -1
  95. sglang/srt/models/qwen2_moe.py +14 -13
  96. sglang/srt/models/qwen3.py +335 -0
  97. sglang/srt/models/qwen3_moe.py +423 -0
  98. sglang/srt/openai_api/adapter.py +71 -4
  99. sglang/srt/openai_api/protocol.py +6 -1
  100. sglang/srt/reasoning_parser.py +0 -1
  101. sglang/srt/sampling/sampling_batch_info.py +2 -3
  102. sglang/srt/server_args.py +86 -72
  103. sglang/srt/speculative/build_eagle_tree.py +2 -2
  104. sglang/srt/speculative/eagle_utils.py +2 -2
  105. sglang/srt/speculative/eagle_worker.py +6 -14
  106. sglang/srt/utils.py +62 -6
  107. sglang/test/runners.py +5 -1
  108. sglang/test/test_block_fp8.py +167 -0
  109. sglang/test/test_custom_ops.py +1 -1
  110. sglang/test/test_utils.py +3 -1
  111. sglang/version.py +1 -1
  112. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/METADATA +5 -5
  113. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/RECORD +116 -110
  114. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/WHEEL +1 -1
  115. sglang/lang/__init__.py +0 -0
  116. sglang/srt/lora/backend/__init__.py +0 -25
  117. sglang/srt/server.py +0 -18
  118. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/licenses/LICENSE +0 -0
  119. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py

@@ -60,7 +60,8 @@ from sglang.srt.managers.io_struct import (
  CloseSessionReqInput,
  ExpertDistributionReq,
  ExpertDistributionReqOutput,
- FlushCacheReq,
+ FlushCacheReqInput,
+ FlushCacheReqOutput,
  GetInternalStateReq,
  GetInternalStateReqOutput,
  GetWeightsByNameReqInput,
@@ -391,6 +392,7 @@ class Scheduler(
  self.torch_profiler = None
  self.torch_profiler_output_dir: Optional[str] = None
  self.profiler_activities: Optional[List[str]] = None
+ self.profiler_id: Optional[str] = None
  self.profiler_target_forward_ct: Optional[int] = None

  # Init metrics stats
@@ -401,7 +403,7 @@ class Scheduler(
  [
  (TokenizedGenerateReqInput, self.handle_generate_request),
  (TokenizedEmbeddingReqInput, self.handle_embedding_request),
- (FlushCacheReq, self.flush_cache_wrapped),
+ (FlushCacheReqInput, self.flush_cache_wrapped),
  (AbortReq, self.abort_request),
  (OpenSessionReqInput, self.open_session),
  (CloseSessionReqInput, self.close_session),
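The flush-cache path is now a request/response pair instead of the one-way FlushCacheReq. The new classes live in sglang/srt/managers/io_struct.py, which is not shown in this diff; a minimal sketch of what they plausibly look like, inferred from the success field used by flush_cache_wrapped in a later hunk:

```python
# Hypothetical reconstruction, not copied from io_struct.py; the real classes
# may carry additional fields.
from dataclasses import dataclass


@dataclass
class FlushCacheReqInput:
    pass  # flushing needs no payload


@dataclass
class FlushCacheReqOutput:
    success: bool  # whether the scheduler actually flushed its pools
```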
@@ -484,9 +486,11 @@ class Scheduler(
  self.tree_cache = HiRadixCache(
  req_to_token_pool=self.req_to_token_pool,
  token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
- tp_cache_group=self.tp_worker.get_tp_cpu_group(),
+ tp_cache_group=self.tp_cpu_group,
  page_size=self.page_size,
  hicache_ratio=server_args.hicache_ratio,
+ hicache_size=server_args.hicache_size,
+ hicache_write_policy=server_args.hicache_write_policy,
  )
  else:
  self.tree_cache = RadixCache(
@@ -553,7 +557,7 @@ class Scheduler(

  # The decode requests polling kv cache
  self.disagg_decode_transfer_queue = DecodeTransferQueue(
- gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
+ gloo_group=self.attn_tp_cpu_group,
  req_to_metadata_buffer_idx_allocator=req_to_metadata_buffer_idx_allocator,
  metadata_buffers=metadata_buffers,
  )
@@ -568,7 +572,7 @@ class Scheduler(
  scheduler=self,
  transfer_queue=self.disagg_decode_transfer_queue,
  tree_cache=self.tree_cache,
- gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
+ gloo_group=self.attn_tp_cpu_group,
  tp_rank=self.tp_rank,
  tp_size=self.tp_size,
  bootstrap_port=self.server_args.disaggregation_bootstrap_port,
@@ -597,7 +601,7 @@ class Scheduler(
  tp_rank=self.tp_rank,
  tp_size=self.tp_size,
  bootstrap_port=self.server_args.disaggregation_bootstrap_port,
- gloo_group=self.tp_worker.get_attention_tp_cpu_group(),
+ gloo_group=self.attn_tp_cpu_group,
  transfer_backend=self.transfer_backend,
  scheduler=self,
  )
@@ -664,70 +668,6 @@ class Scheduler(

  self.last_batch = batch

- @torch.no_grad()
- def event_loop_normal_disagg_prefill(self):
- """A normal scheduler loop for prefill worker in disaggregation mode."""
-
- while True:
- recv_reqs = self.recv_requests()
- self.process_input_requests(recv_reqs)
- self.waiting_queue.extend(
- self.disagg_prefill_pending_queue.pop_bootstrapped()
- )
- self.process_prefill_chunk()
- batch = self.get_new_batch_prefill()
- self.cur_batch = batch
-
- if batch:
- result = self.run_batch(batch)
- self.process_batch_result_disagg_prefill(batch, result)
-
- if len(self.disagg_prefill_inflight_queue) > 0:
- self.process_disagg_prefill_inflight_queue()
-
- if batch is None and len(self.disagg_prefill_inflight_queue) == 0:
- self.check_memory()
- self.new_token_ratio = self.init_new_token_ratio
-
- self.last_batch = batch
- # HACK (byronhsu): reset the batch_is_full flag because we never enter update_running_batch which resets it
- # Otherwise, it hangs under high concurrency
- self.running_batch.batch_is_full = False
-
- @torch.no_grad()
- def event_loop_normal_disagg_decode(self):
- """A normal scheduler loop for decode worker in disaggregation mode."""
-
- while True:
- recv_reqs = self.recv_requests()
- self.process_input_requests(recv_reqs)
- # polling and allocating kv cache
- self.process_decode_queue()
- batch = self.get_next_disagg_decode_batch_to_run()
- self.cur_batch = batch
-
- if batch:
- # Generate fake extend output.
- if batch.forward_mode.is_extend():
- # Note: Logprobs should be handled on the prefill engine.
- self.stream_output(
- batch.reqs, [False for _ in range(len(batch.reqs))]
- )
- else:
- result = self.run_batch(batch)
- self.process_batch_result(batch, result)
-
- if batch is None and (
- len(self.disagg_decode_transfer_queue.queue)
- + len(self.disagg_decode_prealloc_queue.queue)
- == 0
- ):
- # When the server is idle, do self-check and re-init some states
- self.check_memory()
- self.new_token_ratio = self.init_new_token_ratio
-
- self.last_batch = batch
-
  def recv_requests(self) -> List[Req]:
  """Receive results at tp_rank = 0 and broadcast it to all other TP ranks."""
  if self.attn_tp_rank == 0:
@@ -1659,8 +1599,9 @@ class Scheduler(
  time.sleep(5)
  self.parent_process.send_signal(signal.SIGQUIT)

- def flush_cache_wrapped(self, recv_req: FlushCacheReq):
- self.flush_cache()
+ def flush_cache_wrapped(self, recv_req: FlushCacheReqInput):
+ success = self.flush_cache()
+ return FlushCacheReqOutput(success=success)

  def flush_cache(self):
  """Flush the memory pool and cache."""
@@ -1869,6 +1810,7 @@ class Scheduler(
  recv_req.activities,
  recv_req.with_stack,
  recv_req.record_shapes,
+ recv_req.profile_id,
  )
  else:
  return self.stop_profile()
@@ -1880,6 +1822,7 @@ class Scheduler(
  activities: Optional[List[str]],
  with_stack: Optional[bool],
  record_shapes: Optional[bool],
+ profile_id: Optional[str],
  ) -> None:
  if self.profiler_activities:
  return ProfileReqOutput(
@@ -1894,9 +1837,11 @@ class Scheduler(

  self.torch_profiler_output_dir = output_dir
  self.profiler_activities = activities
+ self.profiler_id = profile_id
  logger.info(
- "Profiling starts. Traces will be saved to: %s",
+ "Profiling starts. Traces will be saved to: %s (with id %s)",
  self.torch_profiler_output_dir,
+ self.profiler_id,
  )

  activity_map = {
@@ -1938,14 +1883,14 @@ class Scheduler(
  self.torch_profiler.export_chrome_trace(
  os.path.join(
  self.torch_profiler_output_dir,
- str(time.time()) + f"-TP-{self.tp_rank}" + ".trace.json.gz",
+ self.profiler_id + f"-TP-{self.tp_rank}" + ".trace.json.gz",
  )
  )

  if "MEM" in self.profiler_activities:
  memory_profile_path = os.path.join(
  self.torch_profiler_output_dir,
- str(time.time()) + f"-TP-{self.tp_rank}-memory" + ".pickle",
+ self.profiler_id + f"-TP-{self.tp_rank}-memory" + ".pickle",
  )
  torch.cuda.memory._dump_snapshot(memory_profile_path)
  torch.cuda.memory._record_memory_history(enabled=None)
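With profiler_id in the file names, every TP rank of one profiling run shares a common prefix (the tokenizer manager sets profile_id to str(time.time()), as shown in a later hunk). A small sketch of the resulting naming scheme; the helper name is illustrative, not part of sglang:

```python
import os


def profile_trace_paths(output_dir: str, profiler_id: str, tp_rank: int) -> "tuple[str, str]":
    # Mirrors the strings built in the hunk above: one chrome trace and one
    # memory snapshot per TP rank, all prefixed by the shared profiler_id.
    chrome_trace = os.path.join(output_dir, f"{profiler_id}-TP-{tp_rank}.trace.json.gz")
    memory_snapshot = os.path.join(output_dir, f"{profiler_id}-TP-{tp_rank}-memory.pickle")
    return chrome_trace, memory_snapshot
```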
@@ -2069,9 +2014,15 @@ def run_scheduler_process(
  else:
  scheduler.event_loop_normal()
  elif disaggregation_mode == DisaggregationMode.PREFILL:
- scheduler.event_loop_normal_disagg_prefill()
+ if scheduler.enable_overlap:
+ scheduler.event_loop_overlap_disagg_prefill()
+ else:
+ scheduler.event_loop_normal_disagg_prefill()
  elif disaggregation_mode == DisaggregationMode.DECODE:
- scheduler.event_loop_normal_disagg_decode()
+ if scheduler.enable_overlap:
+ scheduler.event_loop_overlap_disagg_decode()
+ else:
+ scheduler.event_loop_normal_disagg_decode()

  except Exception:
  traceback = get_exception_traceback()
sglang/srt/managers/tokenizer_manager.py

@@ -66,7 +66,8 @@ from sglang.srt.managers.io_struct import (
  EmbeddingReqInput,
  ExpertDistributionReq,
  ExpertDistributionReqOutput,
- FlushCacheReq,
+ FlushCacheReqInput,
+ FlushCacheReqOutput,
  GenerateReqInput,
  GetInternalStateReq,
  GetInternalStateReqOutput,
@@ -264,6 +265,9 @@ class TokenizerManager:
  self.resume_memory_occupation_communicator = _Communicator(
  self.send_to_scheduler, server_args.dp_size
  )
+ self.flush_cache_communicator = _Communicator(
+ self.send_to_scheduler, server_args.dp_size
+ )
  self.start_profile_communicator = _Communicator(
  self.send_to_scheduler, server_args.dp_size
  )
@@ -314,6 +318,10 @@ class TokenizerManager:
  ResumeMemoryOccupationReqOutput,
  self.resume_memory_occupation_communicator.handle_recv,
  ),
+ (
+ FlushCacheReqOutput,
+ self.flush_cache_communicator.handle_recv,
+ ),
  (
  ProfileReqOutput,
  self.start_profile_communicator.handle_recv,
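Flush-cache now follows the same request/response pattern as the other control operations: a _Communicator broadcasts one request to the schedulers and collects one reply per data-parallel rank. The real _Communicator is defined elsewhere in tokenizer_manager.py and is not shown in this diff; the sketch below is a simplified stand-in to illustrate the fan-out/fan-in idea, not the actual implementation:

```python
import asyncio
from typing import Any, Callable, Generic, List, TypeVar

T = TypeVar("T")


class SimpleCommunicator(Generic[T]):
    """Illustrative stand-in: send one request, then wait for `fan_in` replies."""

    def __init__(self, send_fn: Callable[[Any], None], fan_in: int):
        self._send_fn = send_fn
        self._fan_in = fan_in
        self._results: List[T] = []
        self._done = asyncio.Event()

    async def __call__(self, req: Any) -> List[T]:
        self._results.clear()
        self._done.clear()
        self._send_fn(req)          # broadcast to all schedulers
        await self._done.wait()     # block until every DP rank has answered
        return list(self._results)

    def handle_recv(self, output: T) -> None:
        # Called by the response-dispatch loop for each incoming reply.
        self._results.append(output)
        if len(self._results) == self._fan_in:
            self._done.set()
```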
@@ -415,38 +423,60 @@ class TokenizerManager:
  )
  if image_inputs and "input_ids" in image_inputs:
  input_ids = image_inputs["input_ids"]
- if self.is_generation:
- return_logprob = obj.return_logprob
- logprob_start_len = obj.logprob_start_len
- top_logprobs_num = obj.top_logprobs_num
- token_ids_logprob = obj.token_ids_logprob
- session_params = (
- SessionParams(**obj.session_params) if obj.session_params else None
- )
+
+ self._validate_token_len(obj, input_ids)
+ return self._create_tokenized_object(
+ obj, input_text, input_ids, input_embeds, image_inputs
+ )
+
+ def _validate_token_len(
+ self, obj: Union[GenerateReqInput, EmbeddingReqInput], input_ids: List[int]
+ ) -> None:
+ """Validates that the input token count and the requested token count doesn't exceed the model's context length."""

  input_token_num = len(input_ids) if input_ids is not None else 0
+ # Check if input alone exceeds context length
  if input_token_num >= self.context_len:
  raise ValueError(
  f"The input ({input_token_num} tokens) is longer than the "
  f"model's context length ({self.context_len} tokens)."
  )

+ # Check total tokens (input + max_new_tokens)
+ max_new_tokens = obj.sampling_params.get("max_new_tokens")
  if (
- obj.sampling_params.get("max_new_tokens") is not None
- and obj.sampling_params.get("max_new_tokens") + input_token_num
- >= self.context_len
+ max_new_tokens is not None
+ and (max_new_tokens + input_token_num) >= self.context_len
  ):
- raise ValueError(
+ total_tokens = max_new_tokens + input_token_num
+ error_msg = (
  f"Requested token count exceeds the model's maximum context length "
- f"of {self.context_len} tokens. You requested a total of "
- f"{obj.sampling_params.get('max_new_tokens') + input_token_num} "
+ f"of {self.context_len} tokens. You requested a total of {total_tokens} "
  f"tokens: {input_token_num} tokens from the input messages and "
- f"{obj.sampling_params.get('max_new_tokens')} tokens for the "
- f"completion. Please reduce the number of tokens in the input "
- f"messages or the completion to fit within the limit."
+ f"{max_new_tokens} tokens for the completion. Please reduce the number "
+ f"of tokens in the input messages or the completion to fit within the limit."
+ )
+ raise ValueError(error_msg)
+
+ def _create_tokenized_object(
+ self,
+ obj: Union[GenerateReqInput, EmbeddingReqInput],
+ input_text: str,
+ input_ids: List[int],
+ input_embeds: Optional[Union[List[float], None]] = None,
+ image_inputs: Optional[Dict] = None,
+ ) -> Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]:
+ """Create a tokenized request object from common parameters."""
+
+ if self.is_generation:
+ return_logprob = obj.return_logprob
+ logprob_start_len = obj.logprob_start_len
+ top_logprobs_num = obj.top_logprobs_num
+ token_ids_logprob = obj.token_ids_logprob
+ session_params = (
+ SessionParams(**obj.session_params) if obj.session_params else None
  )

- # Parse sampling parameters
  sampling_params = SamplingParams(**obj.sampling_params)
  sampling_params.normalize(self.tokenizer)
  sampling_params.verify()
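_validate_token_len keeps the two checks together: the prompt alone must fit in the context window, and prompt plus max_new_tokens must fit as well. A standalone restatement with concrete numbers; the function and argument names are illustrative, not the manager's attributes:

```python
from typing import Optional


def check_token_budget(input_tokens: int, max_new_tokens: Optional[int], context_len: int) -> None:
    # Same rule as _validate_token_len above, outside the TokenizerManager.
    if input_tokens >= context_len:
        raise ValueError(
            f"Input ({input_tokens} tokens) exceeds the context length ({context_len})."
        )
    if max_new_tokens is not None and input_tokens + max_new_tokens >= context_len:
        raise ValueError(
            f"{input_tokens} input + {max_new_tokens} new tokens exceed the "
            f"context length of {context_len}."
        )


check_token_budget(7000, 1000, 8192)    # ok: 8000 < 8192
# check_token_budget(7500, 1000, 8192)  # would raise: 8500 >= 8192
```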
@@ -483,6 +513,50 @@ class TokenizerManager:

  return tokenized_obj

+ async def _batch_tokenize_and_process(
+ self, batch_size: int, obj: Union[GenerateReqInput, EmbeddingReqInput]
+ ) -> List[Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]]:
+ """Handle batch tokenization for text inputs only."""
+ logger.debug(f"Starting batch tokenization for {batch_size} text requests")
+
+ # Collect requests and texts
+ requests = [obj[i] for i in range(batch_size)]
+ texts = [req.text for req in requests]
+
+ # Batch tokenize all texts
+ encoded = self.tokenizer(texts)
+ input_ids_list = encoded["input_ids"]
+
+ # Process all requests
+ tokenized_objs = []
+ for i, req in enumerate(requests):
+ self._validate_token_len(obj[i], input_ids_list[i])
+ tokenized_objs.append(
+ self._create_tokenized_object(
+ req, req.text, input_ids_list[i], None, None
+ )
+ )
+ logger.debug(f"Completed batch processing for {batch_size} requests")
+ return tokenized_objs
+
+ def _validate_batch_tokenization_constraints(
+ self, batch_size: int, obj: Union[GenerateReqInput, EmbeddingReqInput]
+ ) -> None:
+ """Validate constraints for batch tokenization processing."""
+ for i in range(batch_size):
+ if self.is_generation and obj[i].image_data:
+ raise ValueError(
+ "For image input processing do not set `enable_tokenizer_batch_encode`."
+ )
+ if obj[i].input_ids is not None:
+ raise ValueError(
+ "Batch tokenization is not needed for pre-tokenized input_ids. Do not set `enable_tokenizer_batch_encode`."
+ )
+ if obj[i].input_embeds is not None:
+ raise ValueError(
+ "Batch tokenization is not needed for input_embeds. Do not set `enable_tokenizer_batch_encode`."
+ )
+
  def _send_one_request(
  self,
  obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -560,14 +634,27 @@ class TokenizerManager:

  generators = []
  rids = []
+
  if getattr(obj, "parallel_sample_num", 1) == 1:
- # Send all requests
- for i in range(batch_size):
- tmp_obj = obj[i]
- tokenized_obj = await self._tokenize_one_request(tmp_obj)
- self._send_one_request(tmp_obj, tokenized_obj, created_time)
- generators.append(self._wait_one_response(tmp_obj, request))
- rids.append(tmp_obj.rid)
+ if self.server_args.enable_tokenizer_batch_encode:
+ # Validate batch tokenization constraints
+ self._validate_batch_tokenization_constraints(batch_size, obj)
+
+ tokenized_objs = await self._batch_tokenize_and_process(batch_size, obj)
+
+ for i, tokenized_obj in enumerate(tokenized_objs):
+ tmp_obj = obj[i]
+ self._send_one_request(tmp_obj, tokenized_obj, created_time)
+ generators.append(self._wait_one_response(tmp_obj, request))
+ rids.append(tmp_obj.rid)
+ else:
+ # Sequential tokenization and processing
+ for i in range(batch_size):
+ tmp_obj = obj[i]
+ tokenized_obj = await self._tokenize_one_request(tmp_obj)
+ self._send_one_request(tmp_obj, tokenized_obj, created_time)
+ generators.append(self._wait_one_response(tmp_obj, request))
+ rids.append(tmp_obj.rid)
  else:
  # FIXME: When using batch and parallel_sample_num together, the perf is not optimal.
  if batch_size > 128:
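When server_args.enable_tokenizer_batch_encode is set (presumably exposed as a --enable-tokenizer-batch-encode flag), text-only batches go through a single tokenizer call instead of one call per request. A minimal sketch of the underlying batched encode using the HuggingFace tokenizer API; the model name is only an example:

```python
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2.5-0.5B-Instruct")

texts = [
    "What is the capital of France?",
    "Write a haiku about GPUs.",
]
# One batched call returns a list of token-id lists, which is what
# _batch_tokenize_and_process iterates over.
input_ids_list = tokenizer(texts)["input_ids"]

for text, ids in zip(texts, input_ids_list):
    print(f"{len(ids):3d} tokens  <- {text!r}")
```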
@@ -628,9 +715,8 @@ class TokenizerManager:
  except StopAsyncIteration:
  pass

- def flush_cache(self):
- req = FlushCacheReq()
- self.send_to_scheduler.send_pyobj(req)
+ async def flush_cache(self) -> FlushCacheReqOutput:
+ return (await self.flush_cache_communicator(FlushCacheReqInput()))[0]

  def abort_request(self, rid: str):
  if rid not in self.rid_to_state:
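Because flush_cache now round-trips through the scheduler instead of being fire-and-forget, callers can await the result and see whether the flush actually happened (the scheduler typically declines to flush while requests are still in flight). A hedged usage sketch; the handler name and messages are made up:

```python
# Hypothetical caller that already holds a TokenizerManager reference,
# e.g. an HTTP endpoint handler; not taken verbatim from sglang.
async def handle_flush_cache(tokenizer_manager) -> str:
    result = await tokenizer_manager.flush_cache()
    if result.success:
        return "Cache flushed."
    return "Cache not flushed; try again once no requests are running."
```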
@@ -650,6 +736,7 @@ class TokenizerManager:
  output_dir=output_dir,
  num_steps=num_steps,
  activities=activities,
+ profile_id=str(time.time()),
  )
  result = (await self.start_profile_communicator(req))[0]
  if not result.success:
sglang/srt/managers/tp_worker.py

@@ -116,6 +116,7 @@ class TpModelWorker:
  ),
  self.model_runner.req_to_token_pool.size,
  )
+ assert self.max_running_requests > 0, "max_running_request is zero"
  self.max_req_len = min(
  self.model_config.context_len - 1,
  self.max_total_num_tokens - 1,
sglang/srt/mem_cache/hiradix_cache.py

@@ -29,15 +29,17 @@ class HiRadixCache(RadixCache):
  tp_cache_group: torch.distributed.ProcessGroup,
  page_size: int,
  hicache_ratio: float,
+ hicache_size: int,
+ hicache_write_policy: str,
  ):
  self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
  if isinstance(self.kv_cache, MHATokenToKVPool):
  self.token_to_kv_pool_host = MHATokenToKVPoolHost(
- self.kv_cache, hicache_ratio, page_size
+ self.kv_cache, hicache_ratio, hicache_size, page_size
  )
  elif isinstance(self.kv_cache, MLATokenToKVPool):
  self.token_to_kv_pool_host = MLATokenToKVPoolHost(
- self.kv_cache, hicache_ratio, page_size
+ self.kv_cache, hicache_ratio, hicache_size, page_size
  )
  else:
  raise ValueError(f"HiRadixCache only supports MHA and MLA yet")
@@ -50,6 +52,7 @@ class HiRadixCache(RadixCache):
  self.token_to_kv_pool_host,
  page_size,
  load_cache_event=self.load_cache_event,
+ write_policy=hicache_write_policy,
  )

  # record the nodes with ongoing write through
@@ -57,7 +60,9 @@ class HiRadixCache(RadixCache):
  # record the node segments with ongoing load back
  self.ongoing_load_back = {}
  # todo: dynamically adjust the threshold
- self.write_through_threshold = 1
+ self.write_through_threshold = (
+ 1 if hicache_write_policy == "write_through" else 3
+ )
  self.load_back_threshold = 10
  super().__init__(
  req_to_token_pool, token_to_kv_pool_allocator, page_size, disable=False
@@ -76,7 +81,7 @@ class HiRadixCache(RadixCache):
  height += 1
  return height

- def write_backup(self, node: TreeNode):
+ def write_backup(self, node: TreeNode, write_back=False):
  host_indices = self.cache_controller.write(
  device_indices=node.value,
  node_id=node.id,
@@ -90,21 +95,29 @@ class HiRadixCache(RadixCache):
  if host_indices is not None:
  node.host_value = host_indices
  self.ongoing_write_through[node.id] = node
- self.inc_lock_ref(node)
+ if not write_back:
+ # no need to lock nodes if write back
+ self.inc_lock_ref(node)
  else:
- return None
+ return 0

  return len(host_indices)

  def inc_hit_count(self, node: TreeNode):
- if self.cache_controller.write_policy != "write_through_selective":
+ if node.backuped or self.cache_controller.write_policy == "write_back":
  return
  node.hit_count += 1
- if node.host_value is None and node.hit_count > self.write_through_threshold:
+ if node.hit_count >= self.write_through_threshold:
  self.write_backup(node)
  node.hit_count = 0

- def writing_check(self):
+ def writing_check(self, write_back=False):
+ if write_back:
+ # blocking till all write back complete
+ while len(self.ongoing_write_through) > 0:
+ ack_id = self.cache_controller.ack_write_queue.get()
+ del self.ongoing_write_through[ack_id]
+ return
  queue_size = torch.tensor(
  self.cache_controller.ack_write_queue.qsize(), dtype=torch.int
  )
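Write-through is now driven by hit counts rather than policy-specific branches: nodes that already have a host copy (node.backuped) and the write_back policy are skipped, and everything else is backed up once it reaches write_through_threshold hits (1 for write_through, 3 otherwise, per the threshold hunk above). A simplified restatement of that decision, with TreeNode reduced to the two fields that matter here:

```python
from dataclasses import dataclass


@dataclass
class Node:
    hit_count: int = 0
    backuped: bool = False  # stand-in for TreeNode.backuped: a host copy already exists


def threshold_for(policy: str) -> int:
    # Mirrors the earlier hunk: write_through backs up on the first hit, other policies wait.
    return 1 if policy == "write_through" else 3


def should_write_through(node: Node, policy: str) -> bool:
    """Return True when the node should be queued for host write-through now."""
    if node.backuped or policy == "write_back":
        return False  # already on host, or host copies are only made at eviction time
    node.hit_count += 1
    if node.hit_count >= threshold_for(policy):
        node.hit_count = 0
        return True
    return False
```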
@@ -143,28 +156,25 @@ class HiRadixCache(RadixCache):
  heapq.heapify(leaves)

  num_evicted = 0
- pending_nodes = []
+ write_back_nodes = []
  while num_evicted < num_tokens and len(leaves):
  x = heapq.heappop(leaves)

  if x.lock_ref > 0:
  continue

- if x.host_value is None:
+ if not x.backuped:
  if self.cache_controller.write_policy == "write_back":
- num_evicted += self.write_backup(x)
- elif self.cache_controller.write_policy == "write_through_selective":
- num_evicted += self._evict_write_through_selective(x)
+ # write to host if the node is not backuped
+ num_evicted += self.write_backup(x, write_back=True)
+ write_back_nodes.append(x)
  else:
- assert (
- self.cache_controller.write_policy != "write_through"
- ), "write_through should be inclusive"
- raise NotImplementedError
+ num_evicted += self._evict_regular(x)
  else:
- num_evicted += self._evict_write_through(x)
+ num_evicted += self._evict_backuped(x)

  for child in x.parent.children.values():
- if child in pending_nodes:
+ if child in write_back_nodes:
  continue
  if not child.evicted:
  break
@@ -173,12 +183,12 @@ class HiRadixCache(RadixCache):
  heapq.heappush(leaves, x.parent)

  if self.cache_controller.write_policy == "write_back":
- # blocking till all write back complete
- while len(self.ongoing_write_through) > 0:
- self.writing_check()
- time.sleep(0.1)
+ self.writing_check(write_back=True)
+ for node in write_back_nodes:
+ assert node.backuped
+ self._evict_backuped(node)

- def _evict_write_through(self, node: TreeNode):
+ def _evict_backuped(self, node: TreeNode):
  # evict a node already written to host
  num_evicted = self.cache_controller.evict_device(node.value, node.host_value)
  assert num_evicted > 0
@@ -186,7 +196,7 @@ class HiRadixCache(RadixCache):
  node.value = None
  return num_evicted

- def _evict_write_through_selective(self, node: TreeNode):
+ def _evict_regular(self, node: TreeNode):
  # evict a node not initiated write to host
  self.cache_controller.mem_pool_device_allocator.free(node.value)
  num_evicted = len(node.value)
@@ -335,11 +345,13 @@ class HiRadixCache(RadixCache):
  prefix_len = self.key_match_fn(child.key, key)
  if prefix_len < len(child.key):
  new_node = self._split_node(child.key, child, prefix_len)
+ self.inc_hit_count(new_node)
  if not new_node.evicted:
  value.append(new_node.value)
  node = new_node
  break
  else:
+ self.inc_hit_count(child)
  if not child.evicted:
  value.append(child.value)
  node = child
@@ -365,7 +377,7 @@ class HiRadixCache(RadixCache):
  else:
  new_node.value = child.value[:split_len]
  child.value = child.value[split_len:]
- if child.host_value is not None:
+ if child.backuped:
  new_node.host_value = child.host_value[:split_len]
  child.host_value = child.host_value[split_len:]
  child.parent = new_node
@@ -422,8 +434,8 @@ class HiRadixCache(RadixCache):
  node.children[child_key] = new_node
  self.evictable_size_ += len(value)

- if self.cache_controller.write_policy == "write_through":
- self.write_backup(new_node)
+ if self.cache_controller.write_policy != "write_back":
+ self.inc_hit_count(new_node)
  return total_prefix_length

  def _collect_leaves_device(self):