sglang 0.4.7__py3-none-any.whl → 0.4.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (152)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +7 -0
  3. sglang/bench_one_batch.py +8 -6
  4. sglang/bench_serving.py +1 -1
  5. sglang/lang/interpreter.py +40 -1
  6. sglang/lang/ir.py +27 -0
  7. sglang/math_utils.py +8 -0
  8. sglang/srt/_custom_ops.py +2 -2
  9. sglang/srt/code_completion_parser.py +2 -44
  10. sglang/srt/configs/model_config.py +6 -0
  11. sglang/srt/constants.py +3 -0
  12. sglang/srt/conversation.py +19 -3
  13. sglang/srt/custom_op.py +5 -1
  14. sglang/srt/disaggregation/base/__init__.py +1 -1
  15. sglang/srt/disaggregation/base/conn.py +25 -11
  16. sglang/srt/disaggregation/common/__init__.py +5 -1
  17. sglang/srt/disaggregation/common/utils.py +42 -0
  18. sglang/srt/disaggregation/decode.py +211 -72
  19. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  20. sglang/srt/disaggregation/fake/__init__.py +1 -1
  21. sglang/srt/disaggregation/fake/conn.py +15 -9
  22. sglang/srt/disaggregation/mini_lb.py +34 -4
  23. sglang/srt/disaggregation/mooncake/__init__.py +1 -1
  24. sglang/srt/disaggregation/mooncake/conn.py +30 -29
  25. sglang/srt/disaggregation/nixl/__init__.py +6 -1
  26. sglang/srt/disaggregation/nixl/conn.py +17 -12
  27. sglang/srt/disaggregation/prefill.py +144 -55
  28. sglang/srt/disaggregation/utils.py +155 -123
  29. sglang/srt/distributed/parallel_state.py +12 -4
  30. sglang/srt/entrypoints/engine.py +37 -29
  31. sglang/srt/entrypoints/http_server.py +153 -72
  32. sglang/srt/entrypoints/http_server_engine.py +0 -3
  33. sglang/srt/entrypoints/openai/__init__.py +0 -0
  34. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +84 -10
  35. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  36. sglang/srt/entrypoints/openai/serving_chat.py +921 -0
  37. sglang/srt/entrypoints/openai/serving_completions.py +424 -0
  38. sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
  39. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  40. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  41. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  42. sglang/srt/entrypoints/openai/utils.py +72 -0
  43. sglang/srt/eplb_simulator/__init__.py +1 -0
  44. sglang/srt/eplb_simulator/reader.py +51 -0
  45. sglang/srt/function_call/base_format_detector.py +7 -4
  46. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  47. sglang/srt/function_call/ebnf_composer.py +64 -10
  48. sglang/srt/function_call/function_call_parser.py +6 -6
  49. sglang/srt/function_call/llama32_detector.py +1 -1
  50. sglang/srt/function_call/mistral_detector.py +1 -1
  51. sglang/srt/function_call/pythonic_detector.py +1 -1
  52. sglang/srt/function_call/qwen25_detector.py +1 -1
  53. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  54. sglang/srt/layers/activation.py +40 -3
  55. sglang/srt/layers/attention/aiter_backend.py +20 -4
  56. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  57. sglang/srt/layers/attention/cutlass_mla_backend.py +39 -15
  58. sglang/srt/layers/attention/flashattention_backend.py +71 -72
  59. sglang/srt/layers/attention/flashinfer_backend.py +10 -8
  60. sglang/srt/layers/attention/flashinfer_mla_backend.py +29 -28
  61. sglang/srt/layers/attention/flashmla_backend.py +7 -12
  62. sglang/srt/layers/attention/tbo_backend.py +3 -3
  63. sglang/srt/layers/attention/triton_backend.py +138 -130
  64. sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
  65. sglang/srt/layers/attention/vision.py +51 -24
  66. sglang/srt/layers/communicator.py +28 -10
  67. sglang/srt/layers/dp_attention.py +11 -2
  68. sglang/srt/layers/layernorm.py +29 -2
  69. sglang/srt/layers/linear.py +0 -4
  70. sglang/srt/layers/logits_processor.py +2 -14
  71. sglang/srt/layers/moe/ep_moe/kernels.py +165 -7
  72. sglang/srt/layers/moe/ep_moe/layer.py +249 -33
  73. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
  74. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -4
  76. sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
  77. sglang/srt/layers/moe/topk.py +107 -12
  78. sglang/srt/layers/pooler.py +56 -0
  79. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  80. sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
  81. sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
  82. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
  83. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
  84. sglang/srt/layers/quantization/fp8.py +25 -17
  85. sglang/srt/layers/quantization/fp8_kernel.py +44 -15
  86. sglang/srt/layers/quantization/fp8_utils.py +87 -22
  87. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  88. sglang/srt/layers/quantization/utils.py +5 -2
  89. sglang/srt/layers/radix_attention.py +2 -3
  90. sglang/srt/layers/rotary_embedding.py +42 -2
  91. sglang/srt/layers/sampler.py +1 -1
  92. sglang/srt/lora/lora_manager.py +249 -105
  93. sglang/srt/lora/mem_pool.py +53 -50
  94. sglang/srt/lora/utils.py +1 -1
  95. sglang/srt/managers/cache_controller.py +33 -14
  96. sglang/srt/managers/io_struct.py +31 -10
  97. sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
  98. sglang/srt/managers/multimodal_processors/vila.py +85 -0
  99. sglang/srt/managers/schedule_batch.py +79 -37
  100. sglang/srt/managers/schedule_policy.py +70 -56
  101. sglang/srt/managers/scheduler.py +220 -79
  102. sglang/srt/managers/template_manager.py +226 -0
  103. sglang/srt/managers/tokenizer_manager.py +40 -10
  104. sglang/srt/managers/tp_worker.py +12 -2
  105. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  106. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  107. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  108. sglang/srt/mem_cache/chunk_cache.py +11 -15
  109. sglang/srt/mem_cache/hiradix_cache.py +38 -25
  110. sglang/srt/mem_cache/memory_pool.py +213 -505
  111. sglang/srt/mem_cache/memory_pool_host.py +380 -0
  112. sglang/srt/mem_cache/radix_cache.py +56 -28
  113. sglang/srt/model_executor/cuda_graph_runner.py +198 -100
  114. sglang/srt/model_executor/forward_batch_info.py +32 -10
  115. sglang/srt/model_executor/model_runner.py +28 -12
  116. sglang/srt/model_loader/loader.py +16 -2
  117. sglang/srt/model_loader/weight_utils.py +11 -2
  118. sglang/srt/models/bert.py +113 -13
  119. sglang/srt/models/deepseek_nextn.py +29 -27
  120. sglang/srt/models/deepseek_v2.py +213 -173
  121. sglang/srt/models/glm4.py +312 -0
  122. sglang/srt/models/internvl.py +46 -102
  123. sglang/srt/models/mimo_mtp.py +2 -18
  124. sglang/srt/models/roberta.py +117 -9
  125. sglang/srt/models/vila.py +305 -0
  126. sglang/srt/reasoning_parser.py +21 -11
  127. sglang/srt/sampling/sampling_batch_info.py +24 -0
  128. sglang/srt/sampling/sampling_params.py +2 -0
  129. sglang/srt/server_args.py +351 -238
  130. sglang/srt/speculative/build_eagle_tree.py +1 -1
  131. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -9
  132. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +130 -14
  133. sglang/srt/speculative/eagle_utils.py +468 -116
  134. sglang/srt/speculative/eagle_worker.py +258 -84
  135. sglang/srt/torch_memory_saver_adapter.py +19 -15
  136. sglang/srt/two_batch_overlap.py +4 -2
  137. sglang/srt/utils.py +235 -11
  138. sglang/test/attention/test_prefix_chunk_info.py +2 -0
  139. sglang/test/runners.py +38 -3
  140. sglang/test/test_block_fp8.py +1 -0
  141. sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
  142. sglang/test/test_block_fp8_ep.py +2 -0
  143. sglang/test/test_utils.py +4 -1
  144. sglang/utils.py +9 -0
  145. sglang/version.py +1 -1
  146. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/METADATA +8 -14
  147. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/RECORD +150 -128
  148. sglang/srt/entrypoints/verl_engine.py +0 -179
  149. sglang/srt/openai_api/adapter.py +0 -1990
  150. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
  151. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
  152. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
sglang/srt/managers/template_manager.py (new file)
@@ -0,0 +1,226 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""
+Centralized template management for chat templates and completion templates.
+
+This module provides a unified interface for managing both chat conversation templates
+and code completion templates, eliminating global state and improving modularity.
+"""
+
+import json
+import logging
+import os
+from typing import Optional
+
+from sglang.srt.code_completion_parser import (
+    CompletionTemplate,
+    FimPosition,
+    completion_template_exists,
+    register_completion_template,
+)
+from sglang.srt.conversation import (
+    Conversation,
+    SeparatorStyle,
+    chat_template_exists,
+    get_conv_template_by_model_path,
+    register_conv_template,
+)
+from sglang.srt.jinja_template_utils import detect_jinja_template_content_format
+
+logger = logging.getLogger(__name__)
+
+
+class TemplateManager:
+    """
+    Centralized manager for chat and completion templates.
+
+    This class encapsulates all template-related state and operations,
+    eliminating the need for global variables and providing a clean
+    interface for template management.
+    """
+
+    def __init__(self):
+        self._chat_template_name: Optional[str] = None
+        self._completion_template_name: Optional[str] = None
+        self._jinja_template_content_format: Optional[str] = None
+
+    @property
+    def chat_template_name(self) -> Optional[str]:
+        """Get the current chat template name."""
+        return self._chat_template_name
+
+    @property
+    def completion_template_name(self) -> Optional[str]:
+        """Get the current completion template name."""
+        return self._completion_template_name
+
+    @property
+    def jinja_template_content_format(self) -> Optional[str]:
+        """Get the detected template content format ('string' or 'openai' or None)."""
+        return self._jinja_template_content_format
+
+    def load_chat_template(
+        self, tokenizer_manager, chat_template_arg: str, model_path: str
+    ) -> None:
+        """
+        Load a chat template from various sources.
+
+        Args:
+            tokenizer_manager: The tokenizer manager instance
+            chat_template_arg: Template name or file path
+            model_path: Path to the model
+        """
+        logger.info(f"Loading chat template: {chat_template_arg}")
+
+        if not chat_template_exists(chat_template_arg):
+            if not os.path.exists(chat_template_arg):
+                raise RuntimeError(
+                    f"Chat template {chat_template_arg} is not a built-in template name "
+                    "or a valid chat template file path."
+                )
+
+            if chat_template_arg.endswith(".jinja"):
+                self._load_jinja_template(tokenizer_manager, chat_template_arg)
+            else:
+                self._load_json_chat_template(chat_template_arg)
+        else:
+            self._chat_template_name = chat_template_arg
+
+    def guess_chat_template_from_model_path(self, model_path: str) -> None:
+        """
+        Infer chat template name from model path.
+
+        Args:
+            model_path: Path to the model
+        """
+        template_name = get_conv_template_by_model_path(model_path)
+        if template_name is not None:
+            logger.info(f"Inferred chat template from model path: {template_name}")
+            self._chat_template_name = template_name
+
+    def load_completion_template(self, completion_template_arg: str) -> None:
+        """
+        Load completion template for code completion.
+
+        Args:
+            completion_template_arg: Template name or file path
+        """
+        logger.info(f"Loading completion template: {completion_template_arg}")
+
+        if not completion_template_exists(completion_template_arg):
+            if not os.path.exists(completion_template_arg):
+                raise RuntimeError(
+                    f"Completion template {completion_template_arg} is not a built-in template name "
+                    "or a valid completion template file path."
+                )
+
+            self._load_json_completion_template(completion_template_arg)
+        else:
+            self._completion_template_name = completion_template_arg
+
+    def initialize_templates(
+        self,
+        tokenizer_manager,
+        model_path: str,
+        chat_template: Optional[str] = None,
+        completion_template: Optional[str] = None,
+    ) -> None:
+        """
+        Initialize all templates based on provided configuration.
+
+        Args:
+            tokenizer_manager: The tokenizer manager instance
+            model_path: Path to the model
+            chat_template: Optional chat template name/path
+            completion_template: Optional completion template name/path
+        """
+        # Load chat template
+        if chat_template:
+            self.load_chat_template(tokenizer_manager, chat_template, model_path)
+        else:
+            self.guess_chat_template_from_model_path(model_path)
+
+        # Load completion template
+        if completion_template:
+            self.load_completion_template(completion_template)
+
+    def _load_jinja_template(self, tokenizer_manager, template_path: str) -> None:
+        """Load a Jinja template file."""
+        with open(template_path, "r") as f:
+            chat_template = "".join(f.readlines()).strip("\n")
+        tokenizer_manager.tokenizer.chat_template = chat_template.replace("\\n", "\n")
+        self._chat_template_name = None
+        # Detect content format from the loaded template
+        self._jinja_template_content_format = detect_jinja_template_content_format(
+            chat_template
+        )
+        logger.info(
+            f"Detected chat template content format: {self._jinja_template_content_format}"
+        )
+
+    def _load_json_chat_template(self, template_path: str) -> None:
+        """Load a JSON chat template file."""
+        assert template_path.endswith(
+            ".json"
+        ), "unrecognized format of chat template file"
+
+        with open(template_path, "r") as filep:
+            template = json.load(filep)
+            try:
+                sep_style = SeparatorStyle[template["sep_style"]]
+            except KeyError:
+                raise ValueError(
+                    f"Unknown separator style: {template['sep_style']}"
+                ) from None
+
+        register_conv_template(
+            Conversation(
+                name=template["name"],
+                system_template=template["system"] + "\n{system_message}",
+                system_message=template.get("system_message", ""),
+                roles=(template["user"], template["assistant"]),
+                sep_style=sep_style,
+                sep=template.get("sep", "\n"),
+                stop_str=template["stop_str"],
+            ),
+            override=True,
+        )
+        self._chat_template_name = template["name"]
+
+    def _load_json_completion_template(self, template_path: str) -> None:
+        """Load a JSON completion template file."""
+        assert template_path.endswith(
+            ".json"
+        ), "unrecognized format of completion template file"
+
+        with open(template_path, "r") as filep:
+            template = json.load(filep)
+            try:
+                fim_position = FimPosition[template["fim_position"]]
+            except KeyError:
+                raise ValueError(
+                    f"Unknown fim position: {template['fim_position']}"
+                ) from None
+
+        register_completion_template(
+            CompletionTemplate(
+                name=template["name"],
+                fim_begin_token=template["fim_begin_token"],
+                fim_middle_token=template["fim_middle_token"],
+                fim_end_token=template["fim_end_token"],
+                fim_position=fim_position,
+            ),
+            override=True,
+        )
+        self._completion_template_name = template["name"]
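Note (not part of the diff): TemplateManager replaces the module-level template globals. A minimal usage sketch, assuming `tokenizer_manager` is an initialized TokenizerManager and the model path is a placeholder; the .json schema is the one parsed by _load_json_chat_template above:

from sglang.srt.managers.template_manager import TemplateManager

template_manager = TemplateManager()
template_manager.initialize_templates(
    tokenizer_manager,                      # assumed: an initialized TokenizerManager
    model_path="meta-llama/Llama-3.1-8B-Instruct",  # assumed model path
    chat_template="my_chat_template.json",  # built-in name, .jinja path, or .json path
)
print(template_manager.chat_template_name)             # e.g. the "name" field of the JSON
print(template_manager.jinja_template_content_format)  # "string", "openai", or None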
sglang/srt/managers/tokenizer_manager.py
@@ -418,6 +418,20 @@ class TokenizerManager:
 
         obj.normalize_batch_and_arguments()
 
+        if isinstance(obj, GenerateReqInput):
+            return_hidden_states = obj.return_hidden_states
+            has_return_hidden_states = return_hidden_states == True or (
+                isinstance(return_hidden_states, list) and any(return_hidden_states)
+            )
+            if (
+                not self.server_args.enable_return_hidden_states
+                and has_return_hidden_states
+            ):
+                raise ValueError(
+                    "return_hidden_states=True requires the server to be started "
+                    "with --enable-return-hidden-states (ServerArgs.enable_return_hidden_states)."
+                )
+
         if self.log_requests:
             max_length, skip_names, _ = self.log_request_metadata
             logger.info(
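Note (not part of the diff): the new check fails fast at request time instead of deep in the model runner. Client-side, hidden states must now be enabled at launch; a minimal sketch, assuming the offline Engine API forwards return_hidden_states and using a placeholder model:

import sglang as sgl

llm = sgl.Engine(
    model_path="meta-llama/Llama-3.1-8B-Instruct",  # assumed model
    enable_return_hidden_states=True,  # maps to --enable-return-hidden-states
)
out = llm.generate(
    prompt="The capital of France is",
    sampling_params={"max_new_tokens": 8},
    return_hidden_states=True,  # would raise ValueError without the flag above
)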
@@ -445,6 +459,10 @@ class TokenizerManager:
         # Tokenize
         input_embeds = None
         input_text = obj.text
+        token_type_ids = None
+        is_cross_encoder_request = (
+            isinstance(obj, EmbeddingReqInput) and obj.is_cross_encoder_request
+        )
         if obj.input_embeds is not None:
             if not self.server_args.disable_radix_cache:
                 raise ValueError(
@@ -463,7 +481,14 @@ class TokenizerManager:
                     "accept text prompts. Please provide input_ids or re-initialize "
                     "the engine with skip_tokenizer_init=False."
                 )
-            input_ids = self.tokenizer.encode(input_text)
+            encoded = self.tokenizer(
+                input_text, return_token_type_ids=is_cross_encoder_request
+            )
+
+            input_ids = encoded["input_ids"]
+            if is_cross_encoder_request:
+                input_ids = encoded["input_ids"][0]
+                token_type_ids = encoded.get("token_type_ids", [None])[0]
 
         if self.mm_processor and obj.contains_mm_input():
             image_inputs = await self.mm_processor.process_mm_data_async(
@@ -479,7 +504,7 @@ class TokenizerManager:
 
         self._validate_token_len(obj, input_ids)
         return self._create_tokenized_object(
-            obj, input_text, input_ids, input_embeds, image_inputs
+            obj, input_text, input_ids, input_embeds, image_inputs, token_type_ids
        )
 
     def _validate_token_len(
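Note (not part of the diff): cross-encoder (reranker) requests now call the tokenizer directly instead of tokenizer.encode() so that token_type_ids can be threaded through to the model; the [0] indexing reflects that a (query, document) pair arrives as a single batch entry. A sketch of what the underlying HF tokenizer returns in this mode (transformers shown for illustration; the model name is an assumption):

from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("cross-encoder/ms-marco-MiniLM-L-6-v2")
encoded = tok(
    [("what is a panda?", "The giant panda is a bear species native to China.")],
    return_token_type_ids=True,
)
input_ids = encoded["input_ids"][0]                        # query + document, one sequence
token_type_ids = encoded.get("token_type_ids", [None])[0]  # 0s for query, 1s for document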
@@ -518,6 +543,7 @@ class TokenizerManager:
         input_ids: List[int],
         input_embeds: Optional[Union[List[float], None]] = None,
         image_inputs: Optional[Dict] = None,
+        token_type_ids: Optional[List[int]] = None,
     ) -> Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]:
         """Create a tokenized request object from common parameters."""
 
@@ -578,6 +604,7 @@ class TokenizerManager:
                 input_text,
                 input_ids,
                 image_inputs,
+                token_type_ids,
                 sampling_params,
             )
 
@@ -1031,12 +1058,7 @@ class TokenizerManager:
                     "lora_path",
                 ]
             )
-            out_skip_names = set(
-                [
-                    "text",
-                    "output_ids",
-                ]
-            )
+            out_skip_names = set(["text", "output_ids", "embedding"])
         elif self.log_requests_level == 1:
             max_length = 2048
         elif self.log_requests_level == 2:
@@ -1113,13 +1135,21 @@ class TokenizerManager:
             remain_num_req = len(self.rid_to_state)
 
             if self.health_check_failed:
-                # if health check failed, we should exit immediately
+                # if health check failed, exit immediately
                 logger.error(
                     "Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d",
                     remain_num_req,
                 )
                 break
 
+            elif get_bool_env_var("SGL_FORCE_SHUTDOWN"):
+                # if force shutdown flag set, exit immediately
+                logger.error(
+                    "Signal SIGTERM received while force shutdown flag set. Force exiting... remaining number of requests: %d",
+                    remain_num_req,
+                )
+                break
+
             logger.info(
                 f"Gracefully exiting... remaining number of requests {remain_num_req}"
             )
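Note (not part of the diff): the new branch gives operators an opt-in fast exit. With SGL_FORCE_SHUTDOWN set, SIGTERM abandons in-flight requests instead of draining them. Sketch of the opt-in (the variable name comes from the diff; accepted truthy values are whatever get_bool_env_var recognizes):

import os

os.environ["SGL_FORCE_SHUTDOWN"] = "true"  # checked via get_bool_env_var on SIGTERM
# launch the server in this environment; SIGTERM now exits without waiting
# for the remaining requests to finish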
@@ -1196,7 +1226,7 @@ class TokenizerManager:
                 state.last_output_offset = len(state.output_ids)
             else:
                 state.output_ids.extend(recv_obj.output_ids[i])
-            output_token_ids = state.output_ids
+            output_token_ids = state.output_ids.copy()
 
             out_dict = {
                 "output_ids": output_token_ids,
sglang/srt/managers/tp_worker.py
@@ -35,7 +35,8 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromTensorReqInput,
 )
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
-from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
+from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs
@@ -57,7 +58,7 @@ class TpModelWorker:
         nccl_port: int,
         is_draft_worker: bool = False,
         req_to_token_pool: Optional[ReqToTokenPool] = None,
-        token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
+        token_to_kv_pool_allocator: Optional[BaseTokenToKVPoolAllocator] = None,
     ):
         # Parse args
         self.tp_size = server_args.tp_size
@@ -147,6 +148,15 @@ class TpModelWorker:
         # A reference make this class has the same member as TpModelWorkerClient
         self.worker = self
 
+        self.hicache_layer_transfer_counter = None
+
+    def register_hicache_layer_transfer_counter(self, counter):
+        self.hicache_layer_transfer_counter = counter
+
+    def set_hicache_consumer(self, consumer_index):
+        if self.hicache_layer_transfer_counter is not None:
+            self.hicache_layer_transfer_counter.set_consumer(consumer_index)
+
     def get_worker_info(self):
         return (
             self.max_total_num_tokens,
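Note (not part of the diff): TpModelWorker (and TpModelWorkerClient below) now expose the same two hooks so the HiCache controller can tell a running batch which host-to-device layer transfer to consume. A hedged sketch of the counter contract these hooks assume; the real counter lives in cache_controller.py and is not shown in this diff:

class LayerTransferCounter:  # hypothetical stand-in for the cache_controller object
    def __init__(self):
        self.consumer_index = -1

    def set_consumer(self, consumer_index: int):
        # Record which queued host->device KV transfer the next forward should read.
        self.consumer_index = consumer_index


# `worker` is assumed to be a TpModelWorker / TpModelWorkerClient instance
worker.register_hicache_layer_transfer_counter(LayerTransferCounter())
worker.set_hicache_consumer(0)  # the scheduler passes hicache_consumer_index per batch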
sglang/srt/managers/tp_worker_overlap_thread.py
@@ -88,6 +88,15 @@ class TpModelWorkerClient:
         if self.device == "cpu":
             self.scheduler_stream.synchronize = lambda: None  # No-op for CPU
 
+        self.hicache_layer_transfer_counter = None
+
+    def register_hicache_layer_transfer_counter(self, counter):
+        self.hicache_layer_transfer_counter = counter
+
+    def set_hicache_consumer(self, consumer_index):
+        if self.hicache_layer_transfer_counter is not None:
+            self.hicache_layer_transfer_counter.set_consumer(consumer_index)
+
     def get_worker_info(self):
         return self.worker.get_worker_info()
 
@@ -146,6 +155,8 @@ class TpModelWorkerClient:
         input_ids = model_worker_batch.input_ids
         resolve_future_token_ids(input_ids, self.future_token_ids_map)
 
+        # update the consumer index of hicache to the running batch
+        self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
         # Run forward
         logits_output, next_token_ids, can_run_cuda_graph = (
             self.worker.forward_batch_generation(
sglang/srt/mem_cache/allocator.py (renamed from sglang/srt/mem_cache/paged_allocator.py)
@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 """
 Copyright 2025 SGLang Team
 Licensed under the Apache License, Version 2.0 (the "License");
@@ -17,13 +19,132 @@ limitations under the License.
 Page-aligned memory pool.
 """
 
+import abc
+from typing import TYPE_CHECKING
+
 import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.mem_cache.memory_pool import KVCache
 from sglang.srt.utils import get_bool_env_var, next_power_of_2
 
+if TYPE_CHECKING:
+    from sglang.srt.mem_cache.memory_pool import KVCache
+
+
+class BaseTokenToKVPoolAllocator(abc.ABC):
+    @abc.abstractmethod
+    def __init__(
+        self,
+        size: int,
+        page_size: int,
+        dtype: torch.dtype,
+        device: str,
+        kvcache: KVCache,
+    ):
+        self.size = size
+        self.page_size = page_size
+        self.dtype = dtype
+        self.device = device
+        self._kvcache = kvcache
+
+        self.free_pages = None
+        self.is_not_in_free_group = True
+        self.free_group = []
+
+    def debug_print(self) -> str:
+        return ""
+
+    def available_size(self):
+        return len(self.free_pages) * self.page_size
+
+    def get_kvcache(self):
+        return self._kvcache
+
+    def restore_state(self, free_pages):
+        self.free_pages = free_pages
+
+    def backup_state(self):
+        return self.free_pages
+
+    def free_group_begin(self):
+        self.is_not_in_free_group = False
+        self.free_group = []
+
+    def free_group_end(self):
+        self.is_not_in_free_group = True
+        if self.free_group:
+            self.free(torch.cat(self.free_group))
+
+    def get_cpu_copy(self, *args, **kwargs):
+        # FIXME: reuse the get_cpu_copy after paged allocator is implemented
+        raise NotImplementedError()
+
+    def load_cpu_copy(self, *args, **kwargs):
+        # FIXME: reuse the load_cpu_copy after paged allocator is implemented
+        raise NotImplementedError()
+
+    def alloc_extend(self, *args, **kwargs):
+        raise NotImplementedError("alloc_extend is only for paged allocator")
+
+    def alloc_decode(self, *args, **kwargs):
+        raise NotImplementedError("alloc_decode is only for paged allocator")
+
+    @abc.abstractmethod
+    def clear(self):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def alloc(self, need_size: int):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def free(self, free_index: torch.Tensor):
+        raise NotImplementedError()
+
+
+class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
+    """An allocator managing the indices to kv cache data."""
+
+    def __init__(self, size: int, dtype: torch.dtype, device: str, kvcache: KVCache):
+        super().__init__(size, 1, dtype, device, kvcache)
+        self.clear()
+
+    def clear(self):
+        # The padded slot 0 is used for writing dummy outputs from padded tokens.
+        self.free_pages = torch.arange(
+            1, self.size + 1, dtype=torch.int64, device=self.device
+        )
+        self.is_not_in_free_group = True
+        self.free_group = []
+
+    def available_size(self):
+        # To avoid minor "len(free_pages) * 1" overhead
+        return len(self.free_pages)
+
+    def alloc(self, need_size: int):
+        if need_size > len(self.free_pages):
+            return None
+
+        select_index = self.free_pages[:need_size]
+        self.free_pages = self.free_pages[need_size:]
+        return select_index
+
+    def free(self, free_index: torch.Tensor):
+        if free_index.numel() == 0:
+            return
+
+        if self.is_not_in_free_group:
+            self.free_pages = torch.cat((self.free_pages, free_index))
+        else:
+            self.free_group.append(free_index)
+
+    def get_cpu_copy(self, indices):
+        return self._kvcache.get_cpu_copy(indices)
+
+    def load_cpu_copy(self, kv_cache_cpu, indices):
+        return self._kvcache.load_cpu_copy(kv_cache_cpu, indices)
+
 
 @triton.jit
 def alloc_extend_kernel(
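Note (not part of the diff): a quick sketch of the token-granular allocator's lifecycle, runnable on CPU with the KV cache object stubbed out (kvcache=None is fine here because only get_cpu_copy/load_cpu_copy touch it); the module path follows the rename in file 106 of the list above:

import torch

from sglang.srt.mem_cache.allocator import TokenToKVPoolAllocator

alloc = TokenToKVPoolAllocator(size=8, dtype=torch.float16, device="cpu", kvcache=None)
idx = alloc.alloc(3)       # tensor([1, 2, 3]); slot 0 is reserved for padded tokens
alloc.free_group_begin()   # batch several frees into a single torch.cat
alloc.free(idx[:1])
alloc.free(idx[1:])
alloc.free_group_end()
assert alloc.available_size() == 8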
@@ -154,7 +275,7 @@ def alloc_decode_kernel(
     tl.store(out_indices + pid, page * page_size)
 
 
-class PagedTokenToKVPoolAllocator:
+class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
     """
     An allocator managing the indices to kv cache data.
 
@@ -172,26 +293,11 @@ class PagedTokenToKVPoolAllocator:
         device: str,
         kvcache: KVCache,
     ):
-        self.size = size
-        self.dtype = dtype
-        self.device = device
-        self.page_size = page_size
+        super().__init__(size, page_size, dtype, device, kvcache)
         self.num_pages = size // page_size
-
-        self.free_pages = None
-        self.is_not_in_free_group = True
-        self.free_group = []
-        self.clear()
         self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
-
-        self._kvcache = kvcache
         self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)
-
-    def available_size(self):
-        return len(self.free_pages) * self.page_size
-
-    def get_kvcache(self):
-        return self._kvcache
+        self.clear()
 
     def alloc(self, need_size: int):
         # page-aligned allocation, returning contiguous indices of pages
@@ -298,21 +404,6 @@ class PagedTokenToKVPoolAllocator:
         if self.debug_mode:
             assert len(torch.unique(self.free_pages)) == len(self.free_pages)
 
-    def free_group_begin(self):
-        self.is_not_in_free_group = False
-        self.free_group = []
-
-    def free_group_end(self):
-        self.is_not_in_free_group = True
-        if self.free_group:
-            self.free(torch.cat(self.free_group))
-
-    def backup_state(self):
-        return self.free_pages
-
-    def restore_state(self, free_pages):
-        self.free_pages = free_pages
-
     def clear(self):
         # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.free_pages = torch.arange(
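Note (not part of the diff): with the shared bookkeeping hoisted into BaseTokenToKVPoolAllocator, the paged subclass keeps only its page math. The units the base class now fixes for both subclasses are worth spelling out:

# Pages are the allocation unit, tokens the reporting unit (base class):
#   available_size() == len(free_pages) * page_size
# For the token allocator, page_size == 1, so the two coincide.
size, page_size = 16, 4
num_pages = size // page_size  # 4 pages, as computed in __init__ above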
sglang/srt/mem_cache/base_prefix_cache.py
@@ -1,5 +1,31 @@
 from abc import ABC, abstractmethod
-from typing import Any, List, Tuple
+from typing import TYPE_CHECKING, Any, List, NamedTuple, Tuple
+
+import torch
+
+if TYPE_CHECKING:
+    from sglang.srt.managers.schedule_batch import Req
+else:
+    Req = Any  # Placeholder for Req type when not type checking
+
+
+class MatchResult(NamedTuple):
+    """Result of a prefix match operation.
+
+    Attributes:
+        device_indices  : Indices of the KV cache on the device matched by common prefix.
+        last_device_node: The last TreeNode on the device that was matched.
+        last_host_node  : The last TreeNode on the host that was matched.
+                          Note that if HiCache is not enabled,
+                          this **must** be the same as `last_device_node`.
+        host_hit_length : Length of the KV cache hit on the host, if applicable.
+                          0 if HiCache is not enabled.
+    """
+
+    device_indices: torch.Tensor
+    last_device_node: Any
+    last_host_node: Any
+    host_hit_length: int = 0
 
 
 class BasePrefixCache(ABC):
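Note (not part of the diff): MatchResult replaces the old bare (indices, last_node) return convention with named fields, making the HiCache-related values explicit. A sketch of the documented invariant for a device-only cache (TreeNode stubbed with a plain object):

import torch

from sglang.srt.mem_cache.base_prefix_cache import MatchResult

node = object()  # stand-in for a radix TreeNode
res = MatchResult(
    device_indices=torch.empty(0, dtype=torch.int64),
    last_device_node=node,
    last_host_node=node,  # must equal last_device_node when HiCache is disabled
)
assert res.host_hit_length == 0  # default: no host-side hit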
@@ -10,19 +36,15 @@ class BasePrefixCache(ABC):
         pass
 
     @abstractmethod
-    def match_prefix(self, **kwargs) -> Tuple[List[int], int]:
+    def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
         pass
 
     @abstractmethod
-    def insert(self, **kwargs):
+    def cache_finished_req(self, req: Req, **kwargs):
         pass
 
     @abstractmethod
-    def cache_finished_req(self, **kwargs):
-        pass
-
-    @abstractmethod
-    def cache_unfinished_req(self, **kwargs):
+    def cache_unfinished_req(self, req: Req, **kwargs):
         pass
 
     @abstractmethod
@@ -49,5 +71,27 @@ class BasePrefixCache(ABC):
     def pretty_print(self):
         raise NotImplementedError()
 
+    def init_load_back(
+        self,
+        last_host_node: Any,
+        host_hit_length: int,
+    ) -> Tuple[torch.Tensor, Any]:
+        """
+        Preparing KV cache loading from host to device.
+        """
+        raise NotImplementedError()
+
+    def ready_to_load_host_cache(self) -> Any:
+        """
+        Notify the cache controller to start the KV cache loading
+        """
+        raise NotImplementedError()
+
+    def check_hicache_events(self) -> Any:
+        """
+        Check HiCache related activities to update radix tree and synchronize across TP workers if needed
+        """
+        raise NotImplementedError()
+
     def take_events(self):
         return []
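Note (not part of the diff): taken together, match_prefix now takes the token key positionally and returns a MatchResult, while the HiCache hooks (init_load_back, ready_to_load_host_cache, check_hicache_events) have default raising/no-op bodies so device-only caches need not override them. A hedged sketch of the scheduler-side flow these hooks imply; the method names are real, the control flow is an assumption, and `cache`/`token_ids` are placeholders:

res = cache.match_prefix(key=token_ids)   # cache: some BasePrefixCache implementation
if res.host_hit_length > 0:               # only nonzero when HiCache is enabled
    device_indices, last_node = cache.init_load_back(
        res.last_host_node, res.host_hit_length
    )
    cache.ready_to_load_host_cache()      # kick off the host->device KV transfer
cache.check_hicache_events()              # later: fold finished loads into the radix tree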