sglang 0.4.7__py3-none-any.whl → 0.4.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (152)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +7 -0
  3. sglang/bench_one_batch.py +8 -6
  4. sglang/bench_serving.py +1 -1
  5. sglang/lang/interpreter.py +40 -1
  6. sglang/lang/ir.py +27 -0
  7. sglang/math_utils.py +8 -0
  8. sglang/srt/_custom_ops.py +2 -2
  9. sglang/srt/code_completion_parser.py +2 -44
  10. sglang/srt/configs/model_config.py +6 -0
  11. sglang/srt/constants.py +3 -0
  12. sglang/srt/conversation.py +19 -3
  13. sglang/srt/custom_op.py +5 -1
  14. sglang/srt/disaggregation/base/__init__.py +1 -1
  15. sglang/srt/disaggregation/base/conn.py +25 -11
  16. sglang/srt/disaggregation/common/__init__.py +5 -1
  17. sglang/srt/disaggregation/common/utils.py +42 -0
  18. sglang/srt/disaggregation/decode.py +211 -72
  19. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  20. sglang/srt/disaggregation/fake/__init__.py +1 -1
  21. sglang/srt/disaggregation/fake/conn.py +15 -9
  22. sglang/srt/disaggregation/mini_lb.py +34 -4
  23. sglang/srt/disaggregation/mooncake/__init__.py +1 -1
  24. sglang/srt/disaggregation/mooncake/conn.py +30 -29
  25. sglang/srt/disaggregation/nixl/__init__.py +6 -1
  26. sglang/srt/disaggregation/nixl/conn.py +17 -12
  27. sglang/srt/disaggregation/prefill.py +144 -55
  28. sglang/srt/disaggregation/utils.py +155 -123
  29. sglang/srt/distributed/parallel_state.py +12 -4
  30. sglang/srt/entrypoints/engine.py +37 -29
  31. sglang/srt/entrypoints/http_server.py +153 -72
  32. sglang/srt/entrypoints/http_server_engine.py +0 -3
  33. sglang/srt/entrypoints/openai/__init__.py +0 -0
  34. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +84 -10
  35. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  36. sglang/srt/entrypoints/openai/serving_chat.py +921 -0
  37. sglang/srt/entrypoints/openai/serving_completions.py +424 -0
  38. sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
  39. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  40. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  41. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  42. sglang/srt/entrypoints/openai/utils.py +72 -0
  43. sglang/srt/eplb_simulator/__init__.py +1 -0
  44. sglang/srt/eplb_simulator/reader.py +51 -0
  45. sglang/srt/function_call/base_format_detector.py +7 -4
  46. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  47. sglang/srt/function_call/ebnf_composer.py +64 -10
  48. sglang/srt/function_call/function_call_parser.py +6 -6
  49. sglang/srt/function_call/llama32_detector.py +1 -1
  50. sglang/srt/function_call/mistral_detector.py +1 -1
  51. sglang/srt/function_call/pythonic_detector.py +1 -1
  52. sglang/srt/function_call/qwen25_detector.py +1 -1
  53. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  54. sglang/srt/layers/activation.py +40 -3
  55. sglang/srt/layers/attention/aiter_backend.py +20 -4
  56. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  57. sglang/srt/layers/attention/cutlass_mla_backend.py +39 -15
  58. sglang/srt/layers/attention/flashattention_backend.py +71 -72
  59. sglang/srt/layers/attention/flashinfer_backend.py +10 -8
  60. sglang/srt/layers/attention/flashinfer_mla_backend.py +29 -28
  61. sglang/srt/layers/attention/flashmla_backend.py +7 -12
  62. sglang/srt/layers/attention/tbo_backend.py +3 -3
  63. sglang/srt/layers/attention/triton_backend.py +138 -130
  64. sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
  65. sglang/srt/layers/attention/vision.py +51 -24
  66. sglang/srt/layers/communicator.py +28 -10
  67. sglang/srt/layers/dp_attention.py +11 -2
  68. sglang/srt/layers/layernorm.py +29 -2
  69. sglang/srt/layers/linear.py +0 -4
  70. sglang/srt/layers/logits_processor.py +2 -14
  71. sglang/srt/layers/moe/ep_moe/kernels.py +165 -7
  72. sglang/srt/layers/moe/ep_moe/layer.py +249 -33
  73. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
  74. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -4
  76. sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
  77. sglang/srt/layers/moe/topk.py +107 -12
  78. sglang/srt/layers/pooler.py +56 -0
  79. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  80. sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
  81. sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
  82. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
  83. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
  84. sglang/srt/layers/quantization/fp8.py +25 -17
  85. sglang/srt/layers/quantization/fp8_kernel.py +44 -15
  86. sglang/srt/layers/quantization/fp8_utils.py +87 -22
  87. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  88. sglang/srt/layers/quantization/utils.py +5 -2
  89. sglang/srt/layers/radix_attention.py +2 -3
  90. sglang/srt/layers/rotary_embedding.py +42 -2
  91. sglang/srt/layers/sampler.py +1 -1
  92. sglang/srt/lora/lora_manager.py +249 -105
  93. sglang/srt/lora/mem_pool.py +53 -50
  94. sglang/srt/lora/utils.py +1 -1
  95. sglang/srt/managers/cache_controller.py +33 -14
  96. sglang/srt/managers/io_struct.py +31 -10
  97. sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
  98. sglang/srt/managers/multimodal_processors/vila.py +85 -0
  99. sglang/srt/managers/schedule_batch.py +79 -37
  100. sglang/srt/managers/schedule_policy.py +70 -56
  101. sglang/srt/managers/scheduler.py +220 -79
  102. sglang/srt/managers/template_manager.py +226 -0
  103. sglang/srt/managers/tokenizer_manager.py +40 -10
  104. sglang/srt/managers/tp_worker.py +12 -2
  105. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  106. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  107. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  108. sglang/srt/mem_cache/chunk_cache.py +11 -15
  109. sglang/srt/mem_cache/hiradix_cache.py +38 -25
  110. sglang/srt/mem_cache/memory_pool.py +213 -505
  111. sglang/srt/mem_cache/memory_pool_host.py +380 -0
  112. sglang/srt/mem_cache/radix_cache.py +56 -28
  113. sglang/srt/model_executor/cuda_graph_runner.py +198 -100
  114. sglang/srt/model_executor/forward_batch_info.py +32 -10
  115. sglang/srt/model_executor/model_runner.py +28 -12
  116. sglang/srt/model_loader/loader.py +16 -2
  117. sglang/srt/model_loader/weight_utils.py +11 -2
  118. sglang/srt/models/bert.py +113 -13
  119. sglang/srt/models/deepseek_nextn.py +29 -27
  120. sglang/srt/models/deepseek_v2.py +213 -173
  121. sglang/srt/models/glm4.py +312 -0
  122. sglang/srt/models/internvl.py +46 -102
  123. sglang/srt/models/mimo_mtp.py +2 -18
  124. sglang/srt/models/roberta.py +117 -9
  125. sglang/srt/models/vila.py +305 -0
  126. sglang/srt/reasoning_parser.py +21 -11
  127. sglang/srt/sampling/sampling_batch_info.py +24 -0
  128. sglang/srt/sampling/sampling_params.py +2 -0
  129. sglang/srt/server_args.py +351 -238
  130. sglang/srt/speculative/build_eagle_tree.py +1 -1
  131. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -9
  132. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +130 -14
  133. sglang/srt/speculative/eagle_utils.py +468 -116
  134. sglang/srt/speculative/eagle_worker.py +258 -84
  135. sglang/srt/torch_memory_saver_adapter.py +19 -15
  136. sglang/srt/two_batch_overlap.py +4 -2
  137. sglang/srt/utils.py +235 -11
  138. sglang/test/attention/test_prefix_chunk_info.py +2 -0
  139. sglang/test/runners.py +38 -3
  140. sglang/test/test_block_fp8.py +1 -0
  141. sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
  142. sglang/test/test_block_fp8_ep.py +2 -0
  143. sglang/test/test_utils.py +4 -1
  144. sglang/utils.py +9 -0
  145. sglang/version.py +1 -1
  146. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/METADATA +8 -14
  147. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/RECORD +150 -128
  148. sglang/srt/entrypoints/verl_engine.py +0 -179
  149. sglang/srt/openai_api/adapter.py +0 -1990
  150. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
  151. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
  152. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
sglang/srt/managers/schedule_batch.py

@@ -38,7 +38,7 @@ import logging
  import threading
  from enum import Enum, auto
  from http import HTTPStatus
- from typing import TYPE_CHECKING, List, Optional, Set, Tuple, Union
+ from typing import TYPE_CHECKING, Any, List, Optional, Set, Tuple, Union
 
  import numpy as np
  import torch
@@ -54,9 +54,10 @@ from sglang.srt.disaggregation.decode_schedule_batch_mixin import (
  )
  from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_rank
  from sglang.srt.layers.multimodal import gpu_tensor_hash
+ from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
  from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
  from sglang.srt.mem_cache.chunk_cache import ChunkCache
- from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
+ from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
  from sglang.srt.metrics.collector import TimeStats
  from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
  from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
@@ -72,32 +73,35 @@ INIT_INCREMENTAL_DETOKENIZATION_OFFSET = 5
 
  GLOBAL_SERVER_ARGS_KEYS = [
  "attention_backend",
+ "mm_attention_backend",
  "debug_tensor_dump_inject",
  "debug_tensor_dump_output_folder",
  "chunked_prefill_size",
- "deepep_mode",
  "device",
  "disable_chunked_prefix_cache",
  "disable_radix_cache",
- "enable_deepep_moe",
  "enable_dp_attention",
  "enable_two_batch_overlap",
  "enable_dp_lm_head",
+ "enable_deepep_moe",
+ "deepep_mode",
  "enable_ep_moe",
+ "enable_flashinfer_moe",
+ "moe_dense_tp_size",
+ "ep_dispatch_algorithm",
  "deepep_config",
+ "ep_num_redundant_experts",
  "enable_nan_detection",
  "flashinfer_mla_disable_ragged",
  "max_micro_batch_size",
- "moe_dense_tp_size",
- "ep_dispatch_algorithm",
  "disable_shared_experts_fusion",
  "sampling_backend",
  "speculative_accept_threshold_acc",
  "speculative_accept_threshold_single",
  "torchao_config",
  "triton_attention_reduce_in_fp32",
- "ep_num_redundant_experts",
- "mm_attention_backend",
+ "num_reserved_decode_tokens",
+ "weight_loader_disable_mmap",
  ]
 
  # Put some global args for easy access
@@ -435,7 +439,7 @@ class Req:
  self,
  rid: str,
  origin_input_text: str,
- origin_input_ids: Tuple[int],
+ origin_input_ids: List[int],
  sampling_params: SamplingParams,
  return_logprob: bool = False,
  top_logprobs_num: int = 0,
@@ -444,6 +448,7 @@ class Req:
  origin_input_ids_unpadded: Optional[Tuple[int]] = None,
  lora_path: Optional[str] = None,
  input_embeds: Optional[List[List[float]]] = None,
+ token_type_ids: List[int] = None,
  session_id: Optional[str] = None,
  custom_logit_processor: Optional[str] = None,
  return_hidden_states: bool = False,
@@ -465,10 +470,13 @@ class Req:
  # Each decode stage's output ids
  self.output_ids = []
  # fill_ids = origin_input_ids + output_ids. Updated if chunked.
- self.fill_ids = None
+ self.fill_ids = []
  self.session_id = session_id
  self.input_embeds = input_embeds
 
+ # for corss-endoder model
+ self.token_type_ids = token_type_ids
+
  # Sampling info
  if isinstance(sampling_params.custom_params, dict):
  sampling_params = copy.copy(sampling_params)
@@ -514,13 +522,14 @@ class Req:
 
  # Prefix info
  # The indices to kv cache for the shared prefix.
- self.prefix_indices = []
+ self.prefix_indices: torch.Tensor = []
  # Number of tokens to run prefill.
  self.extend_input_len = 0
  # The relative logprob_start_len in an extend batch
  self.extend_logprob_start_len = 0
- self.last_node = None
- self.last_node_global = None
+ self.last_node: Any = None
+ self.last_host_node: Any = None
+ self.host_hit_length = 0
 
  # Whether or not if it is chunked. It increments whenever
  # it is chunked, and decrement whenever chunked request is
@@ -578,6 +587,7 @@ class Req:
  self.output_token_ids_logprobs_idx
  ) = None
  self.hidden_states: List[List[float]] = []
+ self.hidden_states_tensor = None # Note: use tensor instead of list to transfer hidden_states when PD + MTP
 
  # Embedding (return values)
  self.embedding = None
@@ -639,29 +649,17 @@ class Req:
  def init_next_round_input(
  self,
  tree_cache: Optional[BasePrefixCache] = None,
- enable_hierarchical_cache=False,
  ):
  self.fill_ids = self.origin_input_ids + self.output_ids
  if tree_cache is not None:
- # tree cache is None if the prefix is not computed with tree cache.
- if enable_hierarchical_cache:
- self.prefix_indices, self.last_node, self.last_node_global = (
- tree_cache.match_prefix(
- key=self.adjust_max_prefix_ids(), include_evicted=True
- )
- )
- else:
- self.prefix_indices, self.last_node = tree_cache.match_prefix(
- rid=self.rid, key=self.adjust_max_prefix_ids()
- )
- elif enable_hierarchical_cache:
- # in case last_node is evicted during scheduling, we need to update the prefix_indices
- while self.last_node.evicted:
- self.prefix_indices = self.prefix_indices[
- : -len(self.last_node.host_value)
- ]
- self.last_node = self.last_node.parent
-
+ (
+ self.prefix_indices,
+ self.last_node,
+ self.last_host_node,
+ self.host_hit_length,
+ ) = tree_cache.match_prefix(
+ key=self.adjust_max_prefix_ids(),
+ )
  self.extend_input_len = len(self.fill_ids) - len(self.prefix_indices)
 
  def adjust_max_prefix_ids(self):
@@ -791,6 +789,7 @@ class Req:
  self.multimodal_inputs = None
  self.grammar = None
  self.origin_input_ids = [0] # set it to one token to skip the long prefill
+ self.return_logprob = False
  self.finished_reason = FINISH_ABORT(
  error_msg, HTTPStatus.BAD_REQUEST, "BadRequestError"
  )
@@ -815,7 +814,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  # Request, memory pool, and cache
  reqs: List[Req]
  req_to_token_pool: ReqToTokenPool = None
- token_to_kv_pool_allocator: TokenToKVPoolAllocator = None
+ token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator = None
  tree_cache: BasePrefixCache = None
 
  # Batch configs
@@ -840,6 +839,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  # Batched arguments to model runner
  input_ids: torch.Tensor = None # shape: [b], int64
  input_embeds: torch.Tensor = None # shape: [b, hidden_size], float32
+ token_type_ids: torch.Tensor = None # shape: [b], int64
  req_pool_indices: torch.Tensor = None # shape: [b], int64
  seq_lens: torch.Tensor = None # shape: [b], int64
  # The output locations of the KV cache
@@ -856,6 +856,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  global_num_tokens: Optional[List[int]] = None
  global_num_tokens_for_logprob: Optional[List[int]] = None
  can_run_dp_cuda_graph: bool = False
+ is_extend_in_batch: bool = False
  tbo_split_seq_index: Optional[int] = None
  global_forward_mode: Optional[ForwardMode] = None
 
@@ -902,12 +903,15 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  # Whether to return hidden states
  return_hidden_states: bool = False
 
+ # hicache pointer for synchronizing data loading from CPU to GPU
+ hicache_consumer_index: int = 0
+
  @classmethod
  def init_new(
  cls,
  reqs: List[Req],
  req_to_token_pool: ReqToTokenPool,
- token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+ token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
  tree_cache: BasePrefixCache,
  model_config: ModelConfig,
  enable_overlap: bool,
@@ -1141,6 +1145,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  prefix_lens = [len(r.prefix_indices) for r in reqs]
  extend_lens = [r.extend_input_len for r in reqs]
 
+ token_type_ids = [
+ r.token_type_ids for r in reqs if r.token_type_ids is not None
+ ]
+
  req_pool_indices_tensor = torch.tensor(req_pool_indices, dtype=torch.int64).to(
  self.device, non_blocking=True
  )
@@ -1153,6 +1161,13 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  prefix_lens_tensor = torch.tensor(
  prefix_lens, dtype=torch.int64, device=self.device
  )
+
+ token_type_ids_tensor = None
+ if len(token_type_ids) > 0:
+ token_type_ids_tensor = torch.tensor(
+ sum(token_type_ids, []), dtype=torch.int64
+ ).to(self.device, non_blocking=True)
+
  extend_lens_tensor = seq_lens_tensor - prefix_lens_tensor
 
  # Copy prefix and do some basic check
@@ -1268,6 +1283,7 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  self.device, non_blocking=True
  )
  self.multimodal_inputs = multimodal_inputs
+ self.token_type_ids = token_type_ids_tensor
  self.seq_lens_sum = sum(seq_lens)
 
  if self.return_logprob:
@@ -1347,7 +1363,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  return len(self.reqs)
  # In the decoding phase, the length of a request's KV cache should be
  # the total length of the request minus 1
- return sum(1 for req in self.reqs if (req.seqlen - 1) % page_size == 0)
+ return (
+ sum(1 for req in self.reqs if req.seqlen % page_size == 0)
+ if self.enable_overlap
+ else sum(1 for req in self.reqs if (req.seqlen - 1) % page_size == 0)
+ )
 
  def check_decode_mem(self, buf_multiplier=1):
  tokens_required = (
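For readers of the hunk above, a minimal standalone sketch of the page-boundary arithmetic it relies on (not part of the 0.4.8 diff; needs_new_page and kv_len are hypothetical names used only for illustration):

def needs_new_page(kv_len: int, page_size: int) -> bool:
    # The next decoded token opens a fresh page exactly when the current KV length
    # sits on a page boundary (0, page_size, 2 * page_size, ...).
    return kv_len % page_size == 0

page_size = 16
assert needs_new_page(16, page_size)      # page 0 is full, so the next token starts page 1
assert not needs_new_page(17, page_size)  # page 1 still has free slots

In the new code, kv_len corresponds to (req.seqlen - 1) on the default path and to req.seqlen when self.enable_overlap is set.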
@@ -1414,6 +1434,11 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  req = self.reqs[idx]
  retracted_reqs.append(req)
 
+ if server_args.disaggregation_mode == "decode":
+ req.offload_kv_cache(
+ self.req_to_token_pool, self.token_to_kv_pool_allocator
+ )
+
  if isinstance(self.tree_cache, ChunkCache):
  # ChunkCache does not have eviction
  token_indices = self.req_to_token_pool.req_to_token[
@@ -1445,6 +1470,12 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
 
  req.reset_for_retract()
 
+ if len(retracted_reqs) == 0:
+ # Corner case: only one request left
+ raise ValueError(
+ "Failed to retract any request. No space left for only one request."
+ )
+
  self.filter_batch(keep_indices=sorted_indices)
 
  # Reqs in batch are filtered
@@ -1702,8 +1733,10 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  lora_paths=[req.lora_path for req in self.reqs],
  sampling_info=self.sampling_info,
  input_embeds=self.input_embeds,
+ token_type_ids=self.token_type_ids,
  spec_algorithm=self.spec_algorithm,
  spec_info=self.spec_info,
+ hicache_consumer_index=self.hicache_consumer_index,
  capture_hidden_mode=(
  CaptureHiddenMode.FULL
  if self.return_hidden_states
@@ -1730,11 +1763,15 @@ class ScheduleBatch(ScheduleBatchDisaggregationDecodeMixin):
  decoding_reqs=self.decoding_reqs,
  spec_algorithm=self.spec_algorithm,
  enable_custom_logit_processor=self.enable_custom_logit_processor,
+ global_num_tokens=self.global_num_tokens,
+ global_num_tokens_for_logprob=self.global_num_tokens_for_logprob,
+ can_run_dp_cuda_graph=self.can_run_dp_cuda_graph,
+ is_extend_in_batch=self.is_extend_in_batch,
  )
 
  def __str__(self):
  return (
- f"ScheduleBatch(forward_mode={self.forward_mode.name}, "
+ f"ScheduleBatch(forward_mode={self.forward_mode.name if self.forward_mode else 'None'}, "
  f"#req={(len(self.reqs))})"
  )
 
@@ -1795,11 +1832,16 @@ class ModelWorkerBatch:
  # The input Embeds
  input_embeds: Optional[torch.tensor] = None
 
+ # For corss-encoder model
+ token_type_ids: Optional[torch.Tensor] = None
+
  # Speculative decoding
  spec_algorithm: SpeculativeAlgorithm = None
  spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
  # If set, the output of the batch contains the hidden states of the run.
  capture_hidden_mode: CaptureHiddenMode = None
+ spec_num_draft_tokens: Optional[int] = None
+ hicache_consumer_index: int = 0
 
  # Overlap event
  launch_done: Optional[threading.Event] = None
sglang/srt/managers/schedule_policy.py

@@ -1,3 +1,5 @@
+ from __future__ import annotations
+
  # Copyright 2023-2024 SGLang Team
  # Licensed under the Apache License, Version 2.0 (the "License");
  # you may not use this file except in compliance with the License.
@@ -18,15 +20,17 @@ import random
  from collections import defaultdict
  from contextlib import contextmanager
  from enum import Enum, auto
- from typing import Dict, List, Optional, Set, Union
+ from typing import TYPE_CHECKING, Dict, List, Optional, Set, Union
 
  import torch
 
  from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
  from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
- from sglang.srt.mem_cache.memory_pool import TokenToKVPoolAllocator
  from sglang.srt.mem_cache.radix_cache import RadixCache, TreeNode
 
+ if TYPE_CHECKING:
+ from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
+
  # Clip the estimation of max_new_tokens for the request whose max_new_tokens is very large.
  # This can prevent the server from being too conservative.
  # Note that this only clips the estimation in the scheduler but does not change the stop
@@ -51,6 +55,9 @@ IN_BATCH_PREFIX_CACHING_DEPRIORITIZE_THRESHOLD = int(
  )
 
 
+ IGNORE_EOS_RESERVE_TOKENS = 1
+
+
  class CacheAwarePolicy(Enum):
  """Scheduling policies that are aware of the tree cache."""
 
@@ -90,7 +97,7 @@ class SchedulePolicy:
  def calc_priority(self, waiting_queue: List[Req]) -> bool:
  if self.policy == CacheAgnosticPolicy.FCFS:
  # A shortcut for FCFS
- return
+ return False
 
  policy = self._determine_active_policy(waiting_queue)
 
@@ -134,7 +141,7 @@ class SchedulePolicy:
  """
  try:
  policy_enum = CacheAwarePolicy(policy)
- if tree_cache.disable:
+ if getattr(tree_cache, "disable", True):
  # If tree_cache is disabled, using CacheAgnosticPolicy policy
  return CacheAgnosticPolicy.FCFS
  return policy_enum
@@ -158,14 +165,9 @@ class SchedulePolicy:
  prefix_ids = r.adjust_max_prefix_ids()
 
  # NOTE: the prefix_indices must always be aligned with last_node
- if self.enable_hierarchical_cache:
- r.prefix_indices, r.last_node, r.last_node_global = (
- self.tree_cache.match_prefix(key=prefix_ids, include_evicted=True)
- )
- else:
- r.prefix_indices, r.last_node = self.tree_cache.match_prefix(
- rid=r.rid, key=prefix_ids
- )
+ r.prefix_indices, r.last_node, r.last_host_node, r.host_hit_length = (
+ self.tree_cache.match_prefix(rid=r.rid, key=prefix_ids)
+ )
 
  # NOTE(sang): This logic is for in-batch prefix caching;
  # If there are more than 1 request that have small matching prefix from
@@ -175,7 +177,7 @@ class SchedulePolicy:
  # threshold means we cannot use in-batch prefix caching for short prefixes.
  # It is kind of common when the engine is long running (e.g., imagine the prefix "the").
  if len(r.prefix_indices) <= IN_BATCH_PREFIX_CACHING_CHECK_THRESHOLD:
- in_batch_matching_prefixes, _ = (
+ in_batch_matching_prefixes, _, _, _ = (
  self.waiting_queue_radix_tree.match_prefix(
  rid=r.rid, key=prefix_ids
  )
@@ -268,14 +270,16 @@ class AddReqResult(Enum):
  class PrefillAdder:
  def __init__(
  self,
+ page_size: int,
  tree_cache: BasePrefixCache,
- token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+ token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
  running_batch: ScheduleBatch,
  new_token_ratio: float,
  rem_input_tokens: int,
  rem_chunk_tokens: Optional[int],
  mixed_with_decode_tokens: int = 0,
  ):
+ self.page_size = page_size
  self.tree_cache = tree_cache
  self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
  self.running_batch = running_batch
@@ -292,6 +296,7 @@ class PrefillAdder:
  self.can_run_list = []
  self.new_chunked_req = None
  self.log_hit_tokens = 0
+ # TODO(lsyin): report the real input tokens excluding page alignment
  self.log_input_tokens = 0
 
  if running_batch is not None:
@@ -322,6 +327,9 @@ class PrefillAdder:
  - self.cur_rem_token_offset
  )
 
+ def ceil_paged_tokens(self, tokens: int) -> int:
+ return -(-tokens // self.page_size) * self.page_size
+
  def budget_state(self):
  if self.rem_total_tokens <= 0 or self.cur_rem_tokens <= 0:
  return AddReqResult.NO_TOKEN
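As a reference for ceil_paged_tokens above, a free-standing version of the same ceiling-division idiom with a couple of worked values (an illustrative sketch, not package code; the diff defines it as a PrefillAdder method that reads self.page_size):

def ceil_paged_tokens(tokens: int, page_size: int) -> int:
    # -(-tokens // page_size) is integer ceiling division, so the result is the
    # token count rounded up to a whole number of pages.
    return -(-tokens // page_size) * page_size

assert ceil_paged_tokens(1, 16) == 16    # a single token still occupies one full page
assert ceil_paged_tokens(16, 16) == 16   # exact fit, no extra page
assert ceil_paged_tokens(17, 16) == 32   # one token past the boundary needs a second page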
@@ -333,9 +341,12 @@ class PrefillAdder:
 
  return AddReqResult.CONTINUE
 
- def _prefill_one_req(
+ def _update_prefill_budget(
  self, prefix_len: int, extend_input_len: int, max_new_tokens: int
  ):
+ # TODO(lsyin): check this workaround logic, which only ensures the prefill will not out of memory, and may be too conservative
+ extend_input_len = self.ceil_paged_tokens(extend_input_len)
+
  self.rem_total_token_offset += extend_input_len + max_new_tokens
  self.cur_rem_token_offset += extend_input_len
  self.rem_input_tokens -= extend_input_len
@@ -350,7 +361,7 @@ class PrefillAdder:
  req.extend_input_len = min(req.extend_input_len, self.rem_chunk_tokens)
  req.fill_ids = req.fill_ids[: len(req.prefix_indices) + req.extend_input_len]
  self.can_run_list.append(req)
- self._prefill_one_req(
+ self._update_prefill_budget(
  0,
  req.extend_input_len,
  (
@@ -372,6 +383,12 @@ class PrefillAdder:
  self.tree_cache.dec_lock_ref(last_node)
 
  def add_one_req_ignore_eos(self, req: Req, has_chunked_req: bool):
+ # Early exit if no enough tokens for the input tokens
+ if self.ceil_paged_tokens(req.extend_input_len) > min(
+ self.cur_rem_tokens, self.rem_total_tokens
+ ):
+ return AddReqResult.NO_TOKEN
+
  def add_req_state(r, insert_sort=False):
  new_token_ratio = (
  1.0 if r.sampling_params.ignore_eos else self.new_token_ratio
@@ -381,15 +398,17 @@ class PrefillAdder:
  )
  tokens_occupied = len(r.origin_input_ids) + len(r.output_ids)
 
- if tokens_left > 0:
- if not insert_sort:
- self.req_states.append((tokens_left, tokens_occupied))
- else:
- i = 0
- for i in range(len(self.req_states)):
- if tokens_left <= self.req_states[i][0]:
- break
- self.req_states.insert(i, (tokens_left, tokens_occupied))
+ if tokens_left <= 0:
+ return
+
+ if not insert_sort:
+ self.req_states.append((tokens_left, tokens_occupied))
+ else:
+ i = 0
+ for i in range(len(self.req_states)):
+ if tokens_left <= self.req_states[i][0]:
+ break
+ self.req_states.insert(i, (tokens_left, tokens_occupied))
 
  if self.req_states is None:
  self.req_states = []
@@ -406,13 +425,11 @@ class PrefillAdder:
  cur_rem_tokens = self.cur_rem_tokens - len(req.origin_input_ids)
  tokens_freed = 0
  for i, (tokens_left, tokens_occupied) in enumerate(self.req_states):
- decode_steps = (
- self.req_states[i + 1][0]
- if i + 1 < len(self.req_states)
- else tokens_left
- )
+ # tokens_left gives a reservative calculation as the last token is not stored
  bs = len(self.req_states) - i
- if cur_rem_tokens + tokens_freed - decode_steps * bs <= 0:
+ min_free_tokens = cur_rem_tokens + tokens_freed - tokens_left * bs
+ # reserve tokens for corner cases
+ if min_free_tokens <= IGNORE_EOS_RESERVE_TOKENS * bs:
  return AddReqResult.NO_TOKEN
  tokens_freed += tokens_occupied
 
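To make the budget check above concrete, a small self-contained rehearsal with made-up numbers (only the loop body mirrors the new code; cur_rem_tokens and req_states are hypothetical inputs):

IGNORE_EOS_RESERVE_TOKENS = 1

cur_rem_tokens = 100                 # free KV slots after admitting the new request
req_states = [(30, 50), (60, 40)]    # (tokens_left, tokens_occupied), sorted ascending

tokens_freed = 0
fits = True
for i, (tokens_left, tokens_occupied) in enumerate(req_states):
    bs = len(req_states) - i
    min_free_tokens = cur_rem_tokens + tokens_freed - tokens_left * bs
    # Same check as in the hunk above: keep at least IGNORE_EOS_RESERVE_TOKENS
    # free slots per still-running request.
    if min_free_tokens <= IGNORE_EOS_RESERVE_TOKENS * bs:
        fits = False
        break
    tokens_freed += tokens_occupied

# i=0: 100 + 0 - 30 * 2 = 40 > 2; i=1: 100 + 50 - 60 * 1 = 90 > 1, so the request fits.
assert fits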
@@ -422,7 +439,7 @@ class PrefillAdder:
  ):
  # Non-chunked prefill
  self.can_run_list.append(req)
- self._prefill_one_req(
+ self._update_prefill_budget(
  0,
  req.extend_input_len,
  min(req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION),
@@ -438,55 +455,52 @@ class PrefillAdder:
  req.fill_ids = req.fill_ids[:trunc_len]
  self.can_run_list.append(req)
  self.new_chunked_req = req
- self._prefill_one_req(0, trunc_len, 0)
+ self._update_prefill_budget(0, trunc_len, 0)
 
  return self.budget_state()
 
- def add_one_req(
- self, req: Req, has_chunked_req: bool, enable_hierarchical_cache: bool = False
- ):
+ def add_one_req(self, req: Req, has_chunked_req: bool):
  if req.sampling_params.ignore_eos and getattr(self.tree_cache, "disable", True):
  return self.add_one_req_ignore_eos(req, has_chunked_req)
 
  total_tokens = req.extend_input_len + min(
  req.sampling_params.max_new_tokens, CLIP_MAX_NEW_TOKENS_ESTIMATION
  )
- input_tokens = (
- -(-req.extend_input_len // self.tree_cache.page_size)
- * self.tree_cache.page_size
- )
+
+ # adjusting the input_tokens based on host_hit_length and page_size
+ real_input_tokens = req.extend_input_len - req.host_hit_length
+ real_input_tokens = self.ceil_paged_tokens(real_input_tokens)
  prefix_len = len(req.prefix_indices)
 
  if total_tokens >= self.rem_total_tokens:
  return AddReqResult.NO_TOKEN
 
- if input_tokens > self.rem_input_tokens and len(self.can_run_list) != 0:
+ if real_input_tokens >= self.rem_input_tokens and len(self.can_run_list) != 0:
  return AddReqResult.OTHER
 
  with self._lock_node(req.last_node):
- if total_tokens > self.rem_total_tokens:
+ # self.rem_total_tokens may decrease after the lock acquisition
+ if total_tokens >= self.rem_total_tokens:
  return AddReqResult.NO_TOKEN
 
- if (
- enable_hierarchical_cache
- and req.last_node_global is not None
- and req.last_node_global.evicted
- ):
- req.last_node, req.prefix_indices = self.tree_cache.init_load_back(
- req.last_node_global, req.prefix_indices
+ if req.host_hit_length > 0:
+ new_indices, req.last_node = self.tree_cache.init_load_back(
+ req.last_host_node, req.host_hit_length
  )
+ req.prefix_indices = torch.cat([req.prefix_indices, new_indices])
  req.extend_input_len = len(req.fill_ids) - len(req.prefix_indices)
- input_tokens = (
- -(-req.extend_input_len // self.tree_cache.page_size)
- * self.tree_cache.page_size
- )
  prefix_len = len(req.prefix_indices)
 
+ input_tokens = self.ceil_paged_tokens(req.extend_input_len)
+
+ if input_tokens >= self.rem_input_tokens and len(self.can_run_list) != 0:
+ return AddReqResult.OTHER
+
  if self.rem_chunk_tokens is None or input_tokens <= self.rem_chunk_tokens:
  # Non-chunked prefill
  self.can_run_list.append(req)
  self.tree_cache.inc_lock_ref(req.last_node)
- self._prefill_one_req(
+ self._update_prefill_budget(
  prefix_len,
  input_tokens,
  min(
@@ -496,7 +510,7 @@ class PrefillAdder:
  )
  else:
  # Make sure at least one page is available
- trunc_len = self.rem_chunk_tokens - self.tree_cache.page_size + 1
+ trunc_len = self.rem_chunk_tokens - self.page_size + 1
  if trunc_len <= 0:
  return AddReqResult.OTHER
 
@@ -507,6 +521,6 @@ class PrefillAdder:
  self.can_run_list.append(req)
  self.new_chunked_req = req
  self.tree_cache.inc_lock_ref(req.last_node)
- self._prefill_one_req(prefix_len, trunc_len, 0)
+ self._update_prefill_budget(prefix_len, trunc_len, 0)
 
  return self.budget_state()
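As a closing illustration of the host-cache adjustment in add_one_req above, the same arithmetic with hypothetical numbers (page_size, extend_input_len, and host_hit_length values are made up; only the formula mirrors the diff):

page_size = 16
extend_input_len = 100   # tokens this request still has to prefill
host_hit_length = 64     # prefix tokens already cached on the host, loaded back via init_load_back

real_input_tokens = extend_input_len - host_hit_length              # 36 tokens remain to compute
real_input_tokens = -(-real_input_tokens // page_size) * page_size  # rounded up to 48
assert real_input_tokens == 48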