sglang 0.4.7.post1__py3-none-any.whl → 0.4.8.post1__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
Files changed (123)
  1. sglang/bench_one_batch.py +8 -6
  2. sglang/srt/_custom_ops.py +2 -2
  3. sglang/srt/code_completion_parser.py +2 -44
  4. sglang/srt/configs/model_config.py +1 -0
  5. sglang/srt/constants.py +3 -0
  6. sglang/srt/conversation.py +14 -3
  7. sglang/srt/custom_op.py +11 -1
  8. sglang/srt/disaggregation/base/conn.py +2 -0
  9. sglang/srt/disaggregation/decode.py +22 -28
  10. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  11. sglang/srt/disaggregation/mini_lb.py +34 -4
  12. sglang/srt/disaggregation/mooncake/conn.py +301 -64
  13. sglang/srt/disaggregation/mooncake/transfer_engine.py +31 -1
  14. sglang/srt/disaggregation/nixl/conn.py +94 -46
  15. sglang/srt/disaggregation/prefill.py +20 -15
  16. sglang/srt/disaggregation/utils.py +47 -18
  17. sglang/srt/distributed/parallel_state.py +12 -4
  18. sglang/srt/entrypoints/engine.py +27 -31
  19. sglang/srt/entrypoints/http_server.py +149 -79
  20. sglang/srt/entrypoints/http_server_engine.py +0 -3
  21. sglang/srt/entrypoints/openai/__init__.py +0 -0
  22. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +115 -34
  23. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  24. sglang/srt/entrypoints/openai/serving_chat.py +897 -0
  25. sglang/srt/entrypoints/openai/serving_completions.py +425 -0
  26. sglang/srt/entrypoints/openai/serving_embedding.py +170 -0
  27. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  28. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  29. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  30. sglang/srt/entrypoints/openai/utils.py +72 -0
  31. sglang/srt/function_call/base_format_detector.py +7 -4
  32. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  33. sglang/srt/function_call/ebnf_composer.py +64 -10
  34. sglang/srt/function_call/function_call_parser.py +6 -6
  35. sglang/srt/function_call/llama32_detector.py +1 -1
  36. sglang/srt/function_call/mistral_detector.py +1 -1
  37. sglang/srt/function_call/pythonic_detector.py +1 -1
  38. sglang/srt/function_call/qwen25_detector.py +1 -1
  39. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  40. sglang/srt/layers/activation.py +28 -3
  41. sglang/srt/layers/attention/aiter_backend.py +5 -2
  42. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  43. sglang/srt/layers/attention/cutlass_mla_backend.py +1 -0
  44. sglang/srt/layers/attention/flashattention_backend.py +43 -23
  45. sglang/srt/layers/attention/flashinfer_backend.py +9 -6
  46. sglang/srt/layers/attention/flashinfer_mla_backend.py +7 -4
  47. sglang/srt/layers/attention/flashmla_backend.py +5 -2
  48. sglang/srt/layers/attention/tbo_backend.py +3 -3
  49. sglang/srt/layers/attention/triton_backend.py +19 -11
  50. sglang/srt/layers/communicator.py +5 -5
  51. sglang/srt/layers/dp_attention.py +11 -2
  52. sglang/srt/layers/layernorm.py +44 -2
  53. sglang/srt/layers/linear.py +18 -1
  54. sglang/srt/layers/logits_processor.py +14 -5
  55. sglang/srt/layers/moe/ep_moe/kernels.py +159 -2
  56. sglang/srt/layers/moe/ep_moe/layer.py +286 -13
  57. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +19 -2
  58. sglang/srt/layers/moe/fused_moe_native.py +7 -0
  59. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  60. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +13 -2
  61. sglang/srt/layers/moe/fused_moe_triton/layer.py +148 -26
  62. sglang/srt/layers/moe/topk.py +117 -4
  63. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  64. sglang/srt/layers/quantization/fp8.py +25 -17
  65. sglang/srt/layers/quantization/fp8_utils.py +5 -4
  66. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  67. sglang/srt/layers/quantization/utils.py +5 -2
  68. sglang/srt/layers/rotary_embedding.py +144 -12
  69. sglang/srt/layers/sampler.py +1 -1
  70. sglang/srt/layers/vocab_parallel_embedding.py +14 -1
  71. sglang/srt/lora/lora_manager.py +173 -74
  72. sglang/srt/lora/mem_pool.py +49 -45
  73. sglang/srt/lora/utils.py +1 -1
  74. sglang/srt/managers/cache_controller.py +33 -15
  75. sglang/srt/managers/expert_distribution.py +21 -0
  76. sglang/srt/managers/io_struct.py +19 -14
  77. sglang/srt/managers/multimodal_processors/base_processor.py +44 -9
  78. sglang/srt/managers/multimodal_processors/gemma3n.py +97 -0
  79. sglang/srt/managers/schedule_batch.py +49 -32
  80. sglang/srt/managers/schedule_policy.py +70 -56
  81. sglang/srt/managers/scheduler.py +189 -68
  82. sglang/srt/managers/template_manager.py +226 -0
  83. sglang/srt/managers/tokenizer_manager.py +11 -8
  84. sglang/srt/managers/tp_worker.py +12 -2
  85. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  86. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  87. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  88. sglang/srt/mem_cache/chunk_cache.py +11 -16
  89. sglang/srt/mem_cache/hiradix_cache.py +34 -23
  90. sglang/srt/mem_cache/memory_pool.py +118 -114
  91. sglang/srt/mem_cache/radix_cache.py +20 -16
  92. sglang/srt/model_executor/cuda_graph_runner.py +77 -46
  93. sglang/srt/model_executor/forward_batch_info.py +18 -5
  94. sglang/srt/model_executor/model_runner.py +27 -8
  95. sglang/srt/model_loader/loader.py +50 -8
  96. sglang/srt/model_loader/weight_utils.py +100 -2
  97. sglang/srt/models/deepseek_nextn.py +35 -30
  98. sglang/srt/models/deepseek_v2.py +255 -30
  99. sglang/srt/models/gemma3n_audio.py +949 -0
  100. sglang/srt/models/gemma3n_causal.py +1009 -0
  101. sglang/srt/models/gemma3n_mm.py +511 -0
  102. sglang/srt/models/glm4.py +312 -0
  103. sglang/srt/models/hunyuan.py +771 -0
  104. sglang/srt/models/mimo_mtp.py +2 -18
  105. sglang/srt/reasoning_parser.py +21 -11
  106. sglang/srt/server_args.py +51 -9
  107. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -10
  108. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +125 -12
  109. sglang/srt/speculative/eagle_utils.py +80 -8
  110. sglang/srt/speculative/eagle_worker.py +124 -41
  111. sglang/srt/torch_memory_saver_adapter.py +19 -15
  112. sglang/srt/two_batch_overlap.py +4 -1
  113. sglang/srt/utils.py +248 -11
  114. sglang/test/test_block_fp8_ep.py +1 -0
  115. sglang/test/test_utils.py +1 -0
  116. sglang/version.py +1 -1
  117. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/METADATA +4 -10
  118. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/RECORD +121 -105
  119. sglang/srt/entrypoints/verl_engine.py +0 -179
  120. sglang/srt/openai_api/adapter.py +0 -2148
  121. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/WHEEL +0 -0
  122. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/licenses/LICENSE +0 -0
  123. {sglang-0.4.7.post1.dist-info → sglang-0.4.8.post1.dist-info}/top_level.txt +0 -0
sglang/srt/managers/template_manager.py (new file)
@@ -0,0 +1,226 @@
+ # Copyright 2023-2024 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ """
+ Centralized template management for chat templates and completion templates.
+
+ This module provides a unified interface for managing both chat conversation templates
+ and code completion templates, eliminating global state and improving modularity.
+ """
+
+ import json
+ import logging
+ import os
+ from typing import Optional
+
+ from sglang.srt.code_completion_parser import (
+     CompletionTemplate,
+     FimPosition,
+     completion_template_exists,
+     register_completion_template,
+ )
+ from sglang.srt.conversation import (
+     Conversation,
+     SeparatorStyle,
+     chat_template_exists,
+     get_conv_template_by_model_path,
+     register_conv_template,
+ )
+ from sglang.srt.jinja_template_utils import detect_jinja_template_content_format
+
+ logger = logging.getLogger(__name__)
+
+
+ class TemplateManager:
+     """
+     Centralized manager for chat and completion templates.
+
+     This class encapsulates all template-related state and operations,
+     eliminating the need for global variables and providing a clean
+     interface for template management.
+     """
+
+     def __init__(self):
+         self._chat_template_name: Optional[str] = None
+         self._completion_template_name: Optional[str] = None
+         self._jinja_template_content_format: Optional[str] = None
+
+     @property
+     def chat_template_name(self) -> Optional[str]:
+         """Get the current chat template name."""
+         return self._chat_template_name
+
+     @property
+     def completion_template_name(self) -> Optional[str]:
+         """Get the current completion template name."""
+         return self._completion_template_name
+
+     @property
+     def jinja_template_content_format(self) -> Optional[str]:
+         """Get the detected template content format ('string' or 'openai' or None)."""
+         return self._jinja_template_content_format
+
+     def load_chat_template(
+         self, tokenizer_manager, chat_template_arg: str, model_path: str
+     ) -> None:
+         """
+         Load a chat template from various sources.
+
+         Args:
+             tokenizer_manager: The tokenizer manager instance
+             chat_template_arg: Template name or file path
+             model_path: Path to the model
+         """
+         logger.info(f"Loading chat template: {chat_template_arg}")
+
+         if not chat_template_exists(chat_template_arg):
+             if not os.path.exists(chat_template_arg):
+                 raise RuntimeError(
+                     f"Chat template {chat_template_arg} is not a built-in template name "
+                     "or a valid chat template file path."
+                 )
+
+             if chat_template_arg.endswith(".jinja"):
+                 self._load_jinja_template(tokenizer_manager, chat_template_arg)
+             else:
+                 self._load_json_chat_template(chat_template_arg)
+         else:
+             self._chat_template_name = chat_template_arg
+
+     def guess_chat_template_from_model_path(self, model_path: str) -> None:
+         """
+         Infer chat template name from model path.
+
+         Args:
+             model_path: Path to the model
+         """
+         template_name = get_conv_template_by_model_path(model_path)
+         if template_name is not None:
+             logger.info(f"Inferred chat template from model path: {template_name}")
+             self._chat_template_name = template_name
+
+     def load_completion_template(self, completion_template_arg: str) -> None:
+         """
+         Load completion template for code completion.
+
+         Args:
+             completion_template_arg: Template name or file path
+         """
+         logger.info(f"Loading completion template: {completion_template_arg}")
+
+         if not completion_template_exists(completion_template_arg):
+             if not os.path.exists(completion_template_arg):
+                 raise RuntimeError(
+                     f"Completion template {completion_template_arg} is not a built-in template name "
+                     "or a valid completion template file path."
+                 )
+
+             self._load_json_completion_template(completion_template_arg)
+         else:
+             self._completion_template_name = completion_template_arg
+
+     def initialize_templates(
+         self,
+         tokenizer_manager,
+         model_path: str,
+         chat_template: Optional[str] = None,
+         completion_template: Optional[str] = None,
+     ) -> None:
+         """
+         Initialize all templates based on provided configuration.
+
+         Args:
+             tokenizer_manager: The tokenizer manager instance
+             model_path: Path to the model
+             chat_template: Optional chat template name/path
+             completion_template: Optional completion template name/path
+         """
+         # Load chat template
+         if chat_template:
+             self.load_chat_template(tokenizer_manager, chat_template, model_path)
+         else:
+             self.guess_chat_template_from_model_path(model_path)
+
+         # Load completion template
+         if completion_template:
+             self.load_completion_template(completion_template)
+
+     def _load_jinja_template(self, tokenizer_manager, template_path: str) -> None:
+         """Load a Jinja template file."""
+         with open(template_path, "r") as f:
+             chat_template = "".join(f.readlines()).strip("\n")
+         tokenizer_manager.tokenizer.chat_template = chat_template.replace("\\n", "\n")
+         self._chat_template_name = None
+         # Detect content format from the loaded template
+         self._jinja_template_content_format = detect_jinja_template_content_format(
+             chat_template
+         )
+         logger.info(
+             f"Detected chat template content format: {self._jinja_template_content_format}"
+         )
+
+     def _load_json_chat_template(self, template_path: str) -> None:
+         """Load a JSON chat template file."""
+         assert template_path.endswith(
+             ".json"
+         ), "unrecognized format of chat template file"
+
+         with open(template_path, "r") as filep:
+             template = json.load(filep)
+             try:
+                 sep_style = SeparatorStyle[template["sep_style"]]
+             except KeyError:
+                 raise ValueError(
+                     f"Unknown separator style: {template['sep_style']}"
+                 ) from None
+
+             register_conv_template(
+                 Conversation(
+                     name=template["name"],
+                     system_template=template["system"] + "\n{system_message}",
+                     system_message=template.get("system_message", ""),
+                     roles=(template["user"], template["assistant"]),
+                     sep_style=sep_style,
+                     sep=template.get("sep", "\n"),
+                     stop_str=template["stop_str"],
+                 ),
+                 override=True,
+             )
+         self._chat_template_name = template["name"]
+
+     def _load_json_completion_template(self, template_path: str) -> None:
+         """Load a JSON completion template file."""
+         assert template_path.endswith(
+             ".json"
+         ), "unrecognized format of completion template file"
+
+         with open(template_path, "r") as filep:
+             template = json.load(filep)
+             try:
+                 fim_position = FimPosition[template["fim_position"]]
+             except KeyError:
+                 raise ValueError(
+                     f"Unknown fim position: {template['fim_position']}"
+                 ) from None
+
+             register_completion_template(
+                 CompletionTemplate(
+                     name=template["name"],
+                     fim_begin_token=template["fim_begin_token"],
+                     fim_middle_token=template["fim_middle_token"],
+                     fim_end_token=template["fim_end_token"],
+                     fim_position=fim_position,
+                 ),
+                 override=True,
+             )
+         self._completion_template_name = template["name"]
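The `TemplateManager` above replaces the template globals that previously lived in `sglang/srt/openai_api/adapter.py` (removed in this release). A minimal usage sketch, assuming a stand-in object with a `tokenizer` attribute can substitute for a real `TokenizerManager` (only `_load_jinja_template` touches it) and using an illustrative model path:

```python
# Hedged sketch: exercising the new TemplateManager with a stub in place of a
# real TokenizerManager; the name-based and path-based code paths shown here
# never dereference the tokenizer.
from types import SimpleNamespace

from sglang.srt.managers.template_manager import TemplateManager

stub_tokenizer_manager = SimpleNamespace(tokenizer=SimpleNamespace(chat_template=None))

manager = TemplateManager()
manager.initialize_templates(
    stub_tokenizer_manager,
    model_path="meta-llama/Llama-3.1-8B-Instruct",  # illustrative model path
)
# With no explicit template, the name is inferred from the model path (or None).
print(manager.chat_template_name)
print(manager.completion_template_name)  # None unless completion_template was given
```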
sglang/srt/managers/tokenizer_manager.py
@@ -1058,12 +1058,7 @@ class TokenizerManager:
                      "lora_path",
                  ]
              )
-             out_skip_names = set(
-                 [
-                     "text",
-                     "output_ids",
-                 ]
-             )
+             out_skip_names = set(["text", "output_ids", "embedding"])
          elif self.log_requests_level == 1:
              max_length = 2048
          elif self.log_requests_level == 2:
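The added `"embedding"` entry keeps embedding vectors out of per-request logs at the most verbose level. A generic sketch of the skip-name idea; the record dict and field values below are illustrative, not sglang's actual log structure:

```python
# Illustrative only: drop bulky output fields before logging a request record.
out_skip_names = set(["text", "output_ids", "embedding"])

record = {
    "text": "a long generated string ...",
    "output_ids": list(range(512)),
    "embedding": [0.1] * 4096,  # newly skipped in this release
    "finish_reason": "stop",
}
loggable = {k: v for k, v in record.items() if k not in out_skip_names}
print(loggable)  # {'finish_reason': 'stop'}
```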
@@ -1140,13 +1135,21 @@ class TokenizerManager:
          remain_num_req = len(self.rid_to_state)

          if self.health_check_failed:
-             # if health check failed, we should exit immediately
+             # if health check failed, exit immediately
              logger.error(
                  "Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d",
                  remain_num_req,
              )
              break

+         elif get_bool_env_var("SGL_FORCE_SHUTDOWN"):
+             # if force shutdown flag set, exit immediately
+             logger.error(
+                 "Signal SIGTERM received while force shutdown flag set. Force exiting... remaining number of requests: %d",
+                 remain_num_req,
+             )
+             break
+
          logger.info(
              f"Gracefully exiting... remaining number of requests {remain_num_req}"
          )
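The new branch gives operators an escape hatch from the graceful drain. A minimal sketch of the boolean gate, assuming `get_bool_env_var` follows its usual sglang semantics of treating `"true"`/`"1"` (case-insensitive) as set; this is not the actual `sglang.srt.utils` source:

```python
# Sketch of the env-var gate under the assumed "true"/"1" semantics.
import os


def force_shutdown_requested() -> bool:
    return os.getenv("SGL_FORCE_SHUTDOWN", "false").lower() in ("true", "1")


os.environ["SGL_FORCE_SHUTDOWN"] = "1"
assert force_shutdown_requested()
# With the flag set, the SIGTERM watchdog above breaks out of the drain loop
# even while requests are still in flight.
```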
@@ -1223,7 +1226,7 @@ class TokenizerManager:
                  state.last_output_offset = len(state.output_ids)
              else:
                  state.output_ids.extend(recv_obj.output_ids[i])
-                 output_token_ids = state.output_ids
+                 output_token_ids = state.output_ids.copy()

          out_dict = {
              "output_ids": output_token_ids,
sglang/srt/managers/tp_worker.py
@@ -35,7 +35,8 @@ from sglang.srt.managers.io_struct import (
      UpdateWeightsFromTensorReqInput,
  )
  from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
- from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
+ from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
+ from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
  from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
  from sglang.srt.model_executor.model_runner import ModelRunner
  from sglang.srt.server_args import ServerArgs
@@ -57,7 +58,7 @@ class TpModelWorker:
          nccl_port: int,
          is_draft_worker: bool = False,
          req_to_token_pool: Optional[ReqToTokenPool] = None,
-         token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
+         token_to_kv_pool_allocator: Optional[BaseTokenToKVPoolAllocator] = None,
      ):
          # Parse args
          self.tp_size = server_args.tp_size
@@ -147,6 +148,15 @@ class TpModelWorker:
          # A reference make this class has the same member as TpModelWorkerClient
          self.worker = self

+         self.hicache_layer_transfer_counter = None
+
+     def register_hicache_layer_transfer_counter(self, counter):
+         self.hicache_layer_transfer_counter = counter
+
+     def set_hicache_consumer(self, consumer_index):
+         if self.hicache_layer_transfer_counter is not None:
+             self.hicache_layer_transfer_counter.set_consumer(consumer_index)
+
      def get_worker_info(self):
          return (
              self.max_total_num_tokens,
sglang/srt/managers/tp_worker_overlap_thread.py
@@ -88,6 +88,15 @@ class TpModelWorkerClient:
          if self.device == "cpu":
              self.scheduler_stream.synchronize = lambda: None  # No-op for CPU

+         self.hicache_layer_transfer_counter = None
+
+     def register_hicache_layer_transfer_counter(self, counter):
+         self.hicache_layer_transfer_counter = counter
+
+     def set_hicache_consumer(self, consumer_index):
+         if self.hicache_layer_transfer_counter is not None:
+             self.hicache_layer_transfer_counter.set_consumer(consumer_index)
+
      def get_worker_info(self):
          return self.worker.get_worker_info()

@@ -146,6 +155,8 @@ class TpModelWorkerClient:
              input_ids = model_worker_batch.input_ids
              resolve_future_token_ids(input_ids, self.future_token_ids_map)

+             # update the consumer index of hicache to the running batch
+             self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
              # Run forward
              logits_output, next_token_ids, can_run_cuda_graph = (
                  self.worker.forward_batch_generation(
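Both `TpModelWorker` and `TpModelWorkerClient` now expose the same register/set pair, so the scheduler can hand one counter object to whichever worker variant is in use, and the forward path publishes which batch consumes host-to-device KV transfers. A runnable sketch with a hypothetical stub counter; only the `set_consumer` method exercised above is mirrored, while the real counter's per-layer tracking is omitted:

```python
# Hypothetical stub of the hicache layer-transfer counter protocol.
class StubLayerTransferCounter:
    def __init__(self) -> None:
        self.consumer_index = -1

    def set_consumer(self, consumer_index: int) -> None:
        # Record which running batch consumes the loaded KV layers.
        self.consumer_index = consumer_index


class WorkerLike:
    """Mirrors the register/set pattern added to both worker classes."""

    def __init__(self) -> None:
        self.hicache_layer_transfer_counter = None

    def register_hicache_layer_transfer_counter(self, counter) -> None:
        self.hicache_layer_transfer_counter = counter

    def set_hicache_consumer(self, consumer_index) -> None:
        if self.hicache_layer_transfer_counter is not None:
            self.hicache_layer_transfer_counter.set_consumer(consumer_index)


worker = WorkerLike()
worker.set_hicache_consumer(0)  # safe no-op: nothing registered yet
counter = StubLayerTransferCounter()
worker.register_hicache_layer_transfer_counter(counter)
worker.set_hicache_consumer(3)
assert counter.consumer_index == 3
```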
sglang/srt/mem_cache/allocator.py (renamed from sglang/srt/mem_cache/paged_allocator.py)
@@ -1,3 +1,5 @@
+ from __future__ import annotations
+
  """
  Copyright 2025 SGLang Team
  Licensed under the Apache License, Version 2.0 (the "License");
@@ -17,13 +19,132 @@ limitations under the License.
  Page-aligned memory pool.
  """

+ import abc
+ from typing import TYPE_CHECKING
+
  import torch
  import triton
  import triton.language as tl

- from sglang.srt.mem_cache.memory_pool import KVCache
  from sglang.srt.utils import get_bool_env_var, next_power_of_2

+ if TYPE_CHECKING:
+     from sglang.srt.mem_cache.memory_pool import KVCache
+
+
+ class BaseTokenToKVPoolAllocator(abc.ABC):
+     @abc.abstractmethod
+     def __init__(
+         self,
+         size: int,
+         page_size: int,
+         dtype: torch.dtype,
+         device: str,
+         kvcache: KVCache,
+     ):
+         self.size = size
+         self.page_size = page_size
+         self.dtype = dtype
+         self.device = device
+         self._kvcache = kvcache
+
+         self.free_pages = None
+         self.is_not_in_free_group = True
+         self.free_group = []
+
+     def debug_print(self) -> str:
+         return ""
+
+     def available_size(self):
+         return len(self.free_pages) * self.page_size
+
+     def get_kvcache(self):
+         return self._kvcache
+
+     def restore_state(self, free_pages):
+         self.free_pages = free_pages
+
+     def backup_state(self):
+         return self.free_pages
+
+     def free_group_begin(self):
+         self.is_not_in_free_group = False
+         self.free_group = []
+
+     def free_group_end(self):
+         self.is_not_in_free_group = True
+         if self.free_group:
+             self.free(torch.cat(self.free_group))
+
+     def get_cpu_copy(self, *args, **kwargs):
+         # FIXME: reuse the get_cpu_copy after paged allocator is implemented
+         raise NotImplementedError()
+
+     def load_cpu_copy(self, *args, **kwargs):
+         # FIXME: reuse the load_cpu_copy after paged allocator is implemented
+         raise NotImplementedError()
+
+     def alloc_extend(self, *args, **kwargs):
+         raise NotImplementedError("alloc_extend is only for paged allocator")
+
+     def alloc_decode(self, *args, **kwargs):
+         raise NotImplementedError("alloc_decode is only for paged allocator")
+
+     @abc.abstractmethod
+     def clear(self):
+         raise NotImplementedError()
+
+     @abc.abstractmethod
+     def alloc(self, need_size: int):
+         raise NotImplementedError()
+
+     @abc.abstractmethod
+     def free(self, free_index: torch.Tensor):
+         raise NotImplementedError()
+
+
+ class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
+     """An allocator managing the indices to kv cache data."""
+
+     def __init__(self, size: int, dtype: torch.dtype, device: str, kvcache: KVCache):
+         super().__init__(size, 1, dtype, device, kvcache)
+         self.clear()
+
+     def clear(self):
+         # The padded slot 0 is used for writing dummy outputs from padded tokens.
+         self.free_pages = torch.arange(
+             1, self.size + 1, dtype=torch.int64, device=self.device
+         )
+         self.is_not_in_free_group = True
+         self.free_group = []
+
+     def available_size(self):
+         # To avoid minor "len(free_pages) * 1" overhead
+         return len(self.free_pages)
+
+     def alloc(self, need_size: int):
+         if need_size > len(self.free_pages):
+             return None
+
+         select_index = self.free_pages[:need_size]
+         self.free_pages = self.free_pages[need_size:]
+         return select_index
+
+     def free(self, free_index: torch.Tensor):
+         if free_index.numel() == 0:
+             return
+
+         if self.is_not_in_free_group:
+             self.free_pages = torch.cat((self.free_pages, free_index))
+         else:
+             self.free_group.append(free_index)
+
+     def get_cpu_copy(self, indices):
+         return self._kvcache.get_cpu_copy(indices)
+
+     def load_cpu_copy(self, kv_cache_cpu, indices):
+         return self._kvcache.load_cpu_copy(kv_cache_cpu, indices)
+

  @triton.jit
  def alloc_extend_kernel(
@@ -154,7 +275,7 @@ def alloc_decode_kernel(
      tl.store(out_indices + pid, page * page_size)


- class PagedTokenToKVPoolAllocator:
+ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
      """
      An allocator managing the indices to kv cache data.

@@ -172,26 +293,11 @@ class PagedTokenToKVPoolAllocator:
          device: str,
          kvcache: KVCache,
      ):
-         self.size = size
-         self.dtype = dtype
-         self.device = device
-         self.page_size = page_size
+         super().__init__(size, page_size, dtype, device, kvcache)
          self.num_pages = size // page_size
-
-         self.free_pages = None
-         self.is_not_in_free_group = True
-         self.free_group = []
-         self.clear()
          self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
-
-         self._kvcache = kvcache
          self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)
-
-     def available_size(self):
-         return len(self.free_pages) * self.page_size
-
-     def get_kvcache(self):
-         return self._kvcache
+         self.clear()

      def alloc(self, need_size: int):
          # page-aligned allocation, returning contiguous indices of pages
@@ -298,21 +404,6 @@ class PagedTokenToKVPoolAllocator:
          if self.debug_mode:
              assert len(torch.unique(self.free_pages)) == len(self.free_pages)

-     def free_group_begin(self):
-         self.is_not_in_free_group = False
-         self.free_group = []
-
-     def free_group_end(self):
-         self.is_not_in_free_group = True
-         if self.free_group:
-             self.free(torch.cat(self.free_group))
-
-     def backup_state(self):
-         return self.free_pages
-
-     def restore_state(self, free_pages):
-         self.free_pages = free_pages
-
      def clear(self):
          # The padded slot 0 is used for writing dummy outputs from padded tokens.
          self.free_pages = torch.arange(
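With the shared base class, `TokenToKVPoolAllocator` is simply the `page_size == 1` case and `PagedTokenToKVPoolAllocator` inherits the group-free and backup/restore machinery. A lifecycle sketch of the token-granularity allocator; passing `kvcache=None` is an assumption that holds here because only the CPU-copy helpers dereference it:

```python
# Sketch of the allocator lifecycle on CPU; kvcache=None is a stand-in since
# only get_cpu_copy/load_cpu_copy touch self._kvcache.
import torch

from sglang.srt.mem_cache.allocator import TokenToKVPoolAllocator

alloc = TokenToKVPoolAllocator(size=8, dtype=torch.float16, device="cpu", kvcache=None)
assert alloc.available_size() == 8  # slots 1..8; slot 0 is the padded dummy slot

idx = alloc.alloc(3)
assert alloc.available_size() == 5

alloc.free_group_begin()  # batch multiple frees into a single torch.cat
alloc.free(idx[:2])
alloc.free(idx[2:])
alloc.free_group_end()
assert alloc.available_size() == 8
```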
sglang/srt/mem_cache/base_prefix_cache.py
@@ -1,5 +1,31 @@
  from abc import ABC, abstractmethod
- from typing import Any, List, Tuple
+ from typing import TYPE_CHECKING, Any, List, NamedTuple, Tuple
+
+ import torch
+
+ if TYPE_CHECKING:
+     from sglang.srt.managers.schedule_batch import Req
+ else:
+     Req = Any  # Placeholder for Req type when not type checking
+
+
+ class MatchResult(NamedTuple):
+     """Result of a prefix match operation.
+
+     Attributes:
+         device_indices  : Indices of the KV cache on the device matched by common prefix.
+         last_device_node: The last TreeNode on the device that was matched.
+         last_host_node  : The last TreeNode on the host that was matched.
+                           Note that if HiCache is not enabled,
+                           this **must** be the same as `last_device_node`.
+         host_hit_length : Length of the KV cache hit on the host, if applicable.
+                           0 if HiCache is not enabled.
+     """
+
+     device_indices: torch.Tensor
+     last_device_node: Any
+     last_host_node: Any
+     host_hit_length: int = 0


  class BasePrefixCache(ABC):
@@ -10,19 +36,15 @@ class BasePrefixCache(ABC):
          pass

      @abstractmethod
-     def match_prefix(self, **kwargs) -> Tuple[List[int], int]:
+     def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
          pass

      @abstractmethod
-     def insert(self, **kwargs):
+     def cache_finished_req(self, req: Req, **kwargs):
          pass

      @abstractmethod
-     def cache_finished_req(self, **kwargs):
-         pass
-
-     @abstractmethod
-     def cache_unfinished_req(self, **kwargs):
+     def cache_unfinished_req(self, req: Req, **kwargs):
          pass

      @abstractmethod
@@ -49,5 +71,27 @@ class BasePrefixCache(ABC):
      def pretty_print(self):
          raise NotImplementedError()

+     def init_load_back(
+         self,
+         last_host_node: Any,
+         host_hit_length: int,
+     ) -> Tuple[torch.Tensor, Any]:
+         """
+         Preparing KV cache loading from host to device.
+         """
+         raise NotImplementedError()
+
+     def ready_to_load_host_cache(self) -> Any:
+         """
+         Notify the cache controller to start the KV cache loading
+         """
+         raise NotImplementedError()
+
+     def check_hicache_events(self) -> Any:
+         """
+         Check HiCache related activities to update radix tree and synchronize across TP workers if needed
+         """
+         raise NotImplementedError()
+
      def take_events(self):
          return []
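The `MatchResult` tuple replaces the old `Tuple[List[int], int]` return and encodes the HiCache invariants in one place. A small sketch of what a device-only cache (no HiCache) would return, with a plain object standing in for a radix `TreeNode`:

```python
# Sketch: a MatchResult as a non-HiCache cache would build it; `node` is a
# placeholder for the matched radix TreeNode.
import torch

from sglang.srt.mem_cache.base_prefix_cache import MatchResult

node = object()
result = MatchResult(
    device_indices=torch.tensor([5, 6, 7], dtype=torch.int64),
    last_device_node=node,
    last_host_node=node,  # must equal last_device_node when HiCache is off
)
assert result.host_hit_length == 0  # default: no host-side hit
assert result.last_host_node is result.last_device_node
```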
sglang/srt/mem_cache/chunk_cache.py
@@ -2,40 +2,38 @@ from __future__ import annotations

  """Cache for chunked prefill, used when RadixCache is disabled."""

- from typing import TYPE_CHECKING, Any, Callable, List, Tuple
+ from typing import TYPE_CHECKING, Any

  import torch

- from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
- from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
+ from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
+ from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache, MatchResult
+ from sglang.srt.mem_cache.memory_pool import ReqToTokenPool

  if TYPE_CHECKING:
      from sglang.srt.managers.schedule_batch import Req


- class ChunkCacheEntry:
-     def __init__(self, rid: str, value: torch.Tensor):
-         self.rid = rid
-         self.value = value
-
-
  class ChunkCache(BasePrefixCache):
      def __init__(
          self,
          req_to_token_pool: ReqToTokenPool,
-         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
+         token_to_kv_pool_allocator: BaseTokenToKVPoolAllocator,
          page_size: int,
      ):
          self.req_to_token_pool = req_to_token_pool
          self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
          self.page_size = page_size
-         self.disable = True

      def reset(self):
          pass

-     def match_prefix(self, **unused_kwargs) -> Tuple[List[int], int]:
-         return [], None
+     def match_prefix(self, **unused_kwargs) -> MatchResult:
+         return MatchResult(
+             device_indices=torch.empty((0,), dtype=torch.int64),
+             last_device_node=None,
+             last_host_node=None,
+         )

      def cache_finished_req(self, req: Req):
          kv_indices = self.req_to_token_pool.req_to_token[
@@ -54,9 +52,6 @@ class ChunkCache(BasePrefixCache):
          # `req.prefix_indices` will be used in `PrefillAdder::add_chunked_req` later
          req.prefix_indices = kv_indices

-     def insert(self):
-         raise NotImplementedError()
-
      def evict(self, num_tokens: int):
          pass
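Since `ChunkCache` never shares KV across requests, its `match_prefix` now reports an empty match in the new `MatchResult` shape instead of the old `([], None)` pair. A quick sketch; the `None` pool arguments are stand-ins, valid here because `__init__` only stores them:

```python
# Sketch: ChunkCache always reports an empty prefix match. The None pools are
# stand-ins; __init__ only stores the arguments without using them.
from sglang.srt.mem_cache.chunk_cache import ChunkCache

cache = ChunkCache(req_to_token_pool=None, token_to_kv_pool_allocator=None, page_size=1)
match = cache.match_prefix()
assert match.device_indices.numel() == 0
assert match.last_device_node is None and match.last_host_node is None
assert match.host_hit_length == 0
```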