sglang 0.4.7__py3-none-any.whl → 0.4.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +7 -0
- sglang/bench_one_batch.py +8 -6
- sglang/bench_serving.py +1 -1
- sglang/lang/interpreter.py +40 -1
- sglang/lang/ir.py +27 -0
- sglang/math_utils.py +8 -0
- sglang/srt/_custom_ops.py +2 -2
- sglang/srt/code_completion_parser.py +2 -44
- sglang/srt/configs/model_config.py +6 -0
- sglang/srt/constants.py +3 -0
- sglang/srt/conversation.py +19 -3
- sglang/srt/custom_op.py +5 -1
- sglang/srt/disaggregation/base/__init__.py +1 -1
- sglang/srt/disaggregation/base/conn.py +25 -11
- sglang/srt/disaggregation/common/__init__.py +5 -1
- sglang/srt/disaggregation/common/utils.py +42 -0
- sglang/srt/disaggregation/decode.py +211 -72
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
- sglang/srt/disaggregation/fake/__init__.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +15 -9
- sglang/srt/disaggregation/mini_lb.py +34 -4
- sglang/srt/disaggregation/mooncake/__init__.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +30 -29
- sglang/srt/disaggregation/nixl/__init__.py +6 -1
- sglang/srt/disaggregation/nixl/conn.py +17 -12
- sglang/srt/disaggregation/prefill.py +144 -55
- sglang/srt/disaggregation/utils.py +155 -123
- sglang/srt/distributed/parallel_state.py +12 -4
- sglang/srt/entrypoints/engine.py +37 -29
- sglang/srt/entrypoints/http_server.py +153 -72
- sglang/srt/entrypoints/http_server_engine.py +0 -3
- sglang/srt/entrypoints/openai/__init__.py +0 -0
- sglang/srt/{openai_api → entrypoints/openai}/protocol.py +84 -10
- sglang/srt/entrypoints/openai/serving_base.py +149 -0
- sglang/srt/entrypoints/openai/serving_chat.py +921 -0
- sglang/srt/entrypoints/openai/serving_completions.py +424 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
- sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
- sglang/srt/entrypoints/openai/serving_score.py +61 -0
- sglang/srt/entrypoints/openai/usage_processor.py +81 -0
- sglang/srt/entrypoints/openai/utils.py +72 -0
- sglang/srt/eplb_simulator/__init__.py +1 -0
- sglang/srt/eplb_simulator/reader.py +51 -0
- sglang/srt/function_call/base_format_detector.py +7 -4
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +64 -10
- sglang/srt/function_call/function_call_parser.py +6 -6
- sglang/srt/function_call/llama32_detector.py +1 -1
- sglang/srt/function_call/mistral_detector.py +1 -1
- sglang/srt/function_call/pythonic_detector.py +1 -1
- sglang/srt/function_call/qwen25_detector.py +1 -1
- sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
- sglang/srt/layers/activation.py +40 -3
- sglang/srt/layers/attention/aiter_backend.py +20 -4
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +39 -15
- sglang/srt/layers/attention/flashattention_backend.py +71 -72
- sglang/srt/layers/attention/flashinfer_backend.py +10 -8
- sglang/srt/layers/attention/flashinfer_mla_backend.py +29 -28
- sglang/srt/layers/attention/flashmla_backend.py +7 -12
- sglang/srt/layers/attention/tbo_backend.py +3 -3
- sglang/srt/layers/attention/triton_backend.py +138 -130
- sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
- sglang/srt/layers/attention/vision.py +51 -24
- sglang/srt/layers/communicator.py +28 -10
- sglang/srt/layers/dp_attention.py +11 -2
- sglang/srt/layers/layernorm.py +29 -2
- sglang/srt/layers/linear.py +0 -4
- sglang/srt/layers/logits_processor.py +2 -14
- sglang/srt/layers/moe/ep_moe/kernels.py +165 -7
- sglang/srt/layers/moe/ep_moe/layer.py +249 -33
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -4
- sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
- sglang/srt/layers/moe/topk.py +107 -12
- sglang/srt/layers/pooler.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
- sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
- sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
- sglang/srt/layers/quantization/fp8.py +25 -17
- sglang/srt/layers/quantization/fp8_kernel.py +44 -15
- sglang/srt/layers/quantization/fp8_utils.py +87 -22
- sglang/srt/layers/quantization/modelopt_quant.py +62 -8
- sglang/srt/layers/quantization/utils.py +5 -2
- sglang/srt/layers/radix_attention.py +2 -3
- sglang/srt/layers/rotary_embedding.py +42 -2
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/lora_manager.py +249 -105
- sglang/srt/lora/mem_pool.py +53 -50
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +33 -14
- sglang/srt/managers/io_struct.py +31 -10
- sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
- sglang/srt/managers/multimodal_processors/vila.py +85 -0
- sglang/srt/managers/schedule_batch.py +79 -37
- sglang/srt/managers/schedule_policy.py +70 -56
- sglang/srt/managers/scheduler.py +220 -79
- sglang/srt/managers/template_manager.py +226 -0
- sglang/srt/managers/tokenizer_manager.py +40 -10
- sglang/srt/managers/tp_worker.py +12 -2
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
- sglang/srt/mem_cache/base_prefix_cache.py +52 -8
- sglang/srt/mem_cache/chunk_cache.py +11 -15
- sglang/srt/mem_cache/hiradix_cache.py +38 -25
- sglang/srt/mem_cache/memory_pool.py +213 -505
- sglang/srt/mem_cache/memory_pool_host.py +380 -0
- sglang/srt/mem_cache/radix_cache.py +56 -28
- sglang/srt/model_executor/cuda_graph_runner.py +198 -100
- sglang/srt/model_executor/forward_batch_info.py +32 -10
- sglang/srt/model_executor/model_runner.py +28 -12
- sglang/srt/model_loader/loader.py +16 -2
- sglang/srt/model_loader/weight_utils.py +11 -2
- sglang/srt/models/bert.py +113 -13
- sglang/srt/models/deepseek_nextn.py +29 -27
- sglang/srt/models/deepseek_v2.py +213 -173
- sglang/srt/models/glm4.py +312 -0
- sglang/srt/models/internvl.py +46 -102
- sglang/srt/models/mimo_mtp.py +2 -18
- sglang/srt/models/roberta.py +117 -9
- sglang/srt/models/vila.py +305 -0
- sglang/srt/reasoning_parser.py +21 -11
- sglang/srt/sampling/sampling_batch_info.py +24 -0
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +351 -238
- sglang/srt/speculative/build_eagle_tree.py +1 -1
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -9
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +130 -14
- sglang/srt/speculative/eagle_utils.py +468 -116
- sglang/srt/speculative/eagle_worker.py +258 -84
- sglang/srt/torch_memory_saver_adapter.py +19 -15
- sglang/srt/two_batch_overlap.py +4 -2
- sglang/srt/utils.py +235 -11
- sglang/test/attention/test_prefix_chunk_info.py +2 -0
- sglang/test/runners.py +38 -3
- sglang/test/test_block_fp8.py +1 -0
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
- sglang/test/test_block_fp8_ep.py +2 -0
- sglang/test/test_utils.py +4 -1
- sglang/utils.py +9 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/METADATA +8 -14
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/RECORD +150 -128
- sglang/srt/entrypoints/verl_engine.py +0 -179
- sglang/srt/openai_api/adapter.py +0 -1990
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
sglang/srt/managers/template_manager.py
ADDED

@@ -0,0 +1,226 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""
+Centralized template management for chat templates and completion templates.
+
+This module provides a unified interface for managing both chat conversation templates
+and code completion templates, eliminating global state and improving modularity.
+"""
+
+import json
+import logging
+import os
+from typing import Optional
+
+from sglang.srt.code_completion_parser import (
+    CompletionTemplate,
+    FimPosition,
+    completion_template_exists,
+    register_completion_template,
+)
+from sglang.srt.conversation import (
+    Conversation,
+    SeparatorStyle,
+    chat_template_exists,
+    get_conv_template_by_model_path,
+    register_conv_template,
+)
+from sglang.srt.jinja_template_utils import detect_jinja_template_content_format
+
+logger = logging.getLogger(__name__)
+
+
+class TemplateManager:
+    """
+    Centralized manager for chat and completion templates.
+
+    This class encapsulates all template-related state and operations,
+    eliminating the need for global variables and providing a clean
+    interface for template management.
+    """
+
+    def __init__(self):
+        self._chat_template_name: Optional[str] = None
+        self._completion_template_name: Optional[str] = None
+        self._jinja_template_content_format: Optional[str] = None
+
+    @property
+    def chat_template_name(self) -> Optional[str]:
+        """Get the current chat template name."""
+        return self._chat_template_name
+
+    @property
+    def completion_template_name(self) -> Optional[str]:
+        """Get the current completion template name."""
+        return self._completion_template_name
+
+    @property
+    def jinja_template_content_format(self) -> Optional[str]:
+        """Get the detected template content format ('string' or 'openai' or None)."""
+        return self._jinja_template_content_format
+
+    def load_chat_template(
+        self, tokenizer_manager, chat_template_arg: str, model_path: str
+    ) -> None:
+        """
+        Load a chat template from various sources.
+
+        Args:
+            tokenizer_manager: The tokenizer manager instance
+            chat_template_arg: Template name or file path
+            model_path: Path to the model
+        """
+        logger.info(f"Loading chat template: {chat_template_arg}")
+
+        if not chat_template_exists(chat_template_arg):
+            if not os.path.exists(chat_template_arg):
+                raise RuntimeError(
+                    f"Chat template {chat_template_arg} is not a built-in template name "
+                    "or a valid chat template file path."
+                )
+
+            if chat_template_arg.endswith(".jinja"):
+                self._load_jinja_template(tokenizer_manager, chat_template_arg)
+            else:
+                self._load_json_chat_template(chat_template_arg)
+        else:
+            self._chat_template_name = chat_template_arg
+
+    def guess_chat_template_from_model_path(self, model_path: str) -> None:
+        """
+        Infer chat template name from model path.
+
+        Args:
+            model_path: Path to the model
+        """
+        template_name = get_conv_template_by_model_path(model_path)
+        if template_name is not None:
+            logger.info(f"Inferred chat template from model path: {template_name}")
+            self._chat_template_name = template_name
+
+    def load_completion_template(self, completion_template_arg: str) -> None:
+        """
+        Load completion template for code completion.
+
+        Args:
+            completion_template_arg: Template name or file path
+        """
+        logger.info(f"Loading completion template: {completion_template_arg}")
+
+        if not completion_template_exists(completion_template_arg):
+            if not os.path.exists(completion_template_arg):
+                raise RuntimeError(
+                    f"Completion template {completion_template_arg} is not a built-in template name "
+                    "or a valid completion template file path."
+                )
+
+            self._load_json_completion_template(completion_template_arg)
+        else:
+            self._completion_template_name = completion_template_arg
+
+    def initialize_templates(
+        self,
+        tokenizer_manager,
+        model_path: str,
+        chat_template: Optional[str] = None,
+        completion_template: Optional[str] = None,
+    ) -> None:
+        """
+        Initialize all templates based on provided configuration.
+
+        Args:
+            tokenizer_manager: The tokenizer manager instance
+            model_path: Path to the model
+            chat_template: Optional chat template name/path
+            completion_template: Optional completion template name/path
+        """
+        # Load chat template
+        if chat_template:
+            self.load_chat_template(tokenizer_manager, chat_template, model_path)
+        else:
+            self.guess_chat_template_from_model_path(model_path)
+
+        # Load completion template
+        if completion_template:
+            self.load_completion_template(completion_template)
+
+    def _load_jinja_template(self, tokenizer_manager, template_path: str) -> None:
+        """Load a Jinja template file."""
+        with open(template_path, "r") as f:
+            chat_template = "".join(f.readlines()).strip("\n")
+        tokenizer_manager.tokenizer.chat_template = chat_template.replace("\\n", "\n")
+        self._chat_template_name = None
+        # Detect content format from the loaded template
+        self._jinja_template_content_format = detect_jinja_template_content_format(
+            chat_template
+        )
+        logger.info(
+            f"Detected chat template content format: {self._jinja_template_content_format}"
+        )
+
+    def _load_json_chat_template(self, template_path: str) -> None:
+        """Load a JSON chat template file."""
+        assert template_path.endswith(
+            ".json"
+        ), "unrecognized format of chat template file"
+
+        with open(template_path, "r") as filep:
+            template = json.load(filep)
+            try:
+                sep_style = SeparatorStyle[template["sep_style"]]
+            except KeyError:
+                raise ValueError(
+                    f"Unknown separator style: {template['sep_style']}"
+                ) from None
+
+        register_conv_template(
+            Conversation(
+                name=template["name"],
+                system_template=template["system"] + "\n{system_message}",
+                system_message=template.get("system_message", ""),
+                roles=(template["user"], template["assistant"]),
+                sep_style=sep_style,
+                sep=template.get("sep", "\n"),
+                stop_str=template["stop_str"],
+            ),
+            override=True,
+        )
+        self._chat_template_name = template["name"]
+
+    def _load_json_completion_template(self, template_path: str) -> None:
+        """Load a JSON completion template file."""
+        assert template_path.endswith(
+            ".json"
+        ), "unrecognized format of completion template file"
+
+        with open(template_path, "r") as filep:
+            template = json.load(filep)
+            try:
+                fim_position = FimPosition[template["fim_position"]]
+            except KeyError:
+                raise ValueError(
+                    f"Unknown fim position: {template['fim_position']}"
+                ) from None
+
+        register_completion_template(
+            CompletionTemplate(
+                name=template["name"],
+                fim_begin_token=template["fim_begin_token"],
+                fim_middle_token=template["fim_middle_token"],
+                fim_end_token=template["fim_end_token"],
+                fim_position=fim_position,
+            ),
+            override=True,
+        )
+        self._completion_template_name = template["name"]
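The new `TemplateManager` centralizes what used to be module-level template globals. A minimal usage sketch under stated assumptions: `"llama-3"` is assumed to be a registered built-in conversation template name, and with a built-in name the tokenizer manager is never touched, so `None` is safe here (a `.jinja` path would require a real tokenizer manager).

```python
from sglang.srt.managers.template_manager import TemplateManager

manager = TemplateManager()

# Built-in name path: only sets internal state, no tokenizer access.
manager.initialize_templates(
    tokenizer_manager=None,
    model_path="meta-llama/Llama-3.1-8B-Instruct",  # hypothetical model path
    chat_template="llama-3",  # assumed to be a registered built-in name
)

print(manager.chat_template_name)             # "llama-3"
print(manager.jinja_template_content_format)  # None unless a .jinja file was loaded
```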
sglang/srt/managers/tokenizer_manager.py
CHANGED

@@ -418,6 +418,20 @@ class TokenizerManager:
 
         obj.normalize_batch_and_arguments()
 
+        if isinstance(obj, GenerateReqInput):
+            return_hidden_states = obj.return_hidden_states
+            has_return_hidden_states = return_hidden_states == True or (
+                isinstance(return_hidden_states, list) and any(return_hidden_states)
+            )
+            if (
+                not self.server_args.enable_return_hidden_states
+                and has_return_hidden_states
+            ):
+                raise ValueError(
+                    "return_hidden_states=True requires the server to be started "
+                    "with --enable-return-hidden-states (ServerArgs.enable_return_hidden_states)."
+                )
+
         if self.log_requests:
             max_length, skip_names, _ = self.log_request_metadata
             logger.info(
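The hunk above gates hidden-state returns behind a server flag: a request asks for hidden states either with a bare `True` or with a per-item list containing any `True`. A standalone restatement of the predicate (the free functions here are illustrative; in the diff this logic lives inside `TokenizerManager`):

```python
from typing import List, Union


def wants_hidden_states(return_hidden_states: Union[bool, List[bool]]) -> bool:
    # Mirrors the diff's check: bare True, or any True in a batch list.
    return return_hidden_states is True or (
        isinstance(return_hidden_states, list) and any(return_hidden_states)
    )


def validate(enable_return_hidden_states: bool, return_hidden_states) -> None:
    # Hypothetical helper; the real check reads self.server_args.
    if not enable_return_hidden_states and wants_hidden_states(return_hidden_states):
        raise ValueError(
            "return_hidden_states=True requires the server to be started "
            "with --enable-return-hidden-states."
        )


validate(True, [False, True])  # accepted: server flag is on
validate(False, False)         # accepted: nothing requested
# validate(False, True)        # would raise ValueError
```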
@@ -445,6 +459,10 @@ class TokenizerManager:
         # Tokenize
         input_embeds = None
         input_text = obj.text
+        token_type_ids = None
+        is_cross_encoder_request = (
+            isinstance(obj, EmbeddingReqInput) and obj.is_cross_encoder_request
+        )
         if obj.input_embeds is not None:
             if not self.server_args.disable_radix_cache:
                 raise ValueError(

@@ -463,7 +481,14 @@ class TokenizerManager:
                     "accept text prompts. Please provide input_ids or re-initialize "
                     "the engine with skip_tokenizer_init=False."
                 )
-            input_ids = self.tokenizer.encode(input_text)
+            encoded = self.tokenizer(
+                input_text, return_token_type_ids=is_cross_encoder_request
+            )
+
+            input_ids = encoded["input_ids"]
+            if is_cross_encoder_request:
+                input_ids = encoded["input_ids"][0]
+                token_type_ids = encoded.get("token_type_ids", [None])[0]
 
         if self.mm_processor and obj.contains_mm_input():
             image_inputs = await self.mm_processor.process_mm_data_async(

@@ -479,7 +504,7 @@ class TokenizerManager:
 
         self._validate_token_len(obj, input_ids)
         return self._create_tokenized_object(
-            obj, input_text, input_ids, input_embeds, image_inputs
+            obj, input_text, input_ids, input_embeds, image_inputs, token_type_ids
         )
 
     def _validate_token_len(

@@ -518,6 +543,7 @@ class TokenizerManager:
         input_ids: List[int],
         input_embeds: Optional[Union[List[float], None]] = None,
         image_inputs: Optional[Dict] = None,
+        token_type_ids: Optional[List[int]] = None,
     ) -> Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput]:
         """Create a tokenized request object from common parameters."""
 

@@ -578,6 +604,7 @@ class TokenizerManager:
                 input_text,
                 input_ids,
                 image_inputs,
+                token_type_ids,
                 sampling_params,
             )
 

@@ -1031,12 +1058,7 @@ class TokenizerManager:
                    "lora_path",
                ]
            )
-            out_skip_names = set(
-                [
-                    "text",
-                    "output_ids",
-                ]
-            )
+            out_skip_names = set(["text", "output_ids", "embedding"])
        elif self.log_requests_level == 1:
            max_length = 2048
        elif self.log_requests_level == 2:

@@ -1113,13 +1135,21 @@ class TokenizerManager:
            remain_num_req = len(self.rid_to_state)
 
            if self.health_check_failed:
-                # if health check failed,
+                # if health check failed, exit immediately
                logger.error(
                    "Signal SIGTERM received while health check failed. Exiting... remaining number of requests: %d",
                    remain_num_req,
                )
                break
 
+            elif get_bool_env_var("SGL_FORCE_SHUTDOWN"):
+                # if force shutdown flag set, exit immediately
+                logger.error(
+                    "Signal SIGTERM received while force shutdown flag set. Force exiting... remaining number of requests: %d",
+                    remain_num_req,
+                )
+                break
+
            logger.info(
                f"Gracefully exiting... remaining number of requests {remain_num_req}"
            )

@@ -1196,7 +1226,7 @@ class TokenizerManager:
                    state.last_output_offset = len(state.output_ids)
                else:
                    state.output_ids.extend(recv_obj.output_ids[i])
-                    output_token_ids = state.output_ids
+                    output_token_ids = state.output_ids.copy()
 
                out_dict = {
                    "output_ids": output_token_ids,
sglang/srt/managers/tp_worker.py
CHANGED
@@ -35,7 +35,8 @@ from sglang.srt.managers.io_struct import (
     UpdateWeightsFromTensorReqInput,
 )
 from sglang.srt.managers.schedule_batch import ModelWorkerBatch, global_server_args_dict
-from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
+from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
+from sglang.srt.mem_cache.memory_pool import ReqToTokenPool
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_executor.model_runner import ModelRunner
 from sglang.srt.server_args import ServerArgs

@@ -57,7 +58,7 @@ class TpModelWorker:
         nccl_port: int,
         is_draft_worker: bool = False,
         req_to_token_pool: Optional[ReqToTokenPool] = None,
-        token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
+        token_to_kv_pool_allocator: Optional[BaseTokenToKVPoolAllocator] = None,
     ):
         # Parse args
         self.tp_size = server_args.tp_size

@@ -147,6 +148,15 @@ class TpModelWorker:
         # A reference make this class has the same member as TpModelWorkerClient
         self.worker = self
 
+        self.hicache_layer_transfer_counter = None
+
+    def register_hicache_layer_transfer_counter(self, counter):
+        self.hicache_layer_transfer_counter = counter
+
+    def set_hicache_consumer(self, consumer_index):
+        if self.hicache_layer_transfer_counter is not None:
+            self.hicache_layer_transfer_counter.set_consumer(consumer_index)
+
     def get_worker_info(self):
         return (
             self.max_total_num_tokens,

sglang/srt/managers/tp_worker_overlap_thread.py
CHANGED

@@ -88,6 +88,15 @@ class TpModelWorkerClient:
         if self.device == "cpu":
             self.scheduler_stream.synchronize = lambda: None  # No-op for CPU
 
+        self.hicache_layer_transfer_counter = None
+
+    def register_hicache_layer_transfer_counter(self, counter):
+        self.hicache_layer_transfer_counter = counter
+
+    def set_hicache_consumer(self, consumer_index):
+        if self.hicache_layer_transfer_counter is not None:
+            self.hicache_layer_transfer_counter.set_consumer(consumer_index)
+
     def get_worker_info(self):
         return self.worker.get_worker_info()
 

@@ -146,6 +155,8 @@ class TpModelWorkerClient:
         input_ids = model_worker_batch.input_ids
         resolve_future_token_ids(input_ids, self.future_token_ids_map)
 
+        # update the consumer index of hicache to the running batch
+        self.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
         # Run forward
         logits_output, next_token_ids, can_run_cuda_graph = (
             self.worker.forward_batch_generation(
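Both workers gain the same three-member hook: a counter slot, a registration method, and a guarded `set_consumer` call per batch. The counter type itself does not appear in this diff; a minimal stand-in satisfying just the interface the diff uses might look like the sketch below (the class and its field are hypothetical):

```python
class LayerTransferCounter:
    """Hypothetical stand-in for the hierarchical-cache (hicache) layer
    transfer counter; only set_consumer, the method the diff calls, is modeled."""

    def __init__(self):
        self.consumer_index = -1

    def set_consumer(self, consumer_index: int):
        # Record which scheduled batch currently consumes host->device
        # KV-cache layer transfers.
        self.consumer_index = consumer_index


# Usage mirroring the diff: register once, then update per model_worker_batch.
counter = LayerTransferCounter()
# worker.register_hicache_layer_transfer_counter(counter)
# worker.set_hicache_consumer(model_worker_batch.hicache_consumer_index)
```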
sglang/srt/mem_cache/{paged_allocator.py → allocator.py}
RENAMED

@@ -1,3 +1,5 @@
+from __future__ import annotations
+
 """
 Copyright 2025 SGLang Team
 Licensed under the Apache License, Version 2.0 (the "License");

@@ -17,13 +19,132 @@ limitations under the License.
 Page-aligned memory pool.
 """
 
+import abc
+from typing import TYPE_CHECKING
+
 import torch
 import triton
 import triton.language as tl
 
-from sglang.srt.mem_cache.memory_pool import KVCache
 from sglang.srt.utils import get_bool_env_var, next_power_of_2
 
+if TYPE_CHECKING:
+    from sglang.srt.mem_cache.memory_pool import KVCache
+
+
+class BaseTokenToKVPoolAllocator(abc.ABC):
+    @abc.abstractmethod
+    def __init__(
+        self,
+        size: int,
+        page_size: int,
+        dtype: torch.dtype,
+        device: str,
+        kvcache: KVCache,
+    ):
+        self.size = size
+        self.page_size = page_size
+        self.dtype = dtype
+        self.device = device
+        self._kvcache = kvcache
+
+        self.free_pages = None
+        self.is_not_in_free_group = True
+        self.free_group = []
+
+    def debug_print(self) -> str:
+        return ""
+
+    def available_size(self):
+        return len(self.free_pages) * self.page_size
+
+    def get_kvcache(self):
+        return self._kvcache
+
+    def restore_state(self, free_pages):
+        self.free_pages = free_pages
+
+    def backup_state(self):
+        return self.free_pages
+
+    def free_group_begin(self):
+        self.is_not_in_free_group = False
+        self.free_group = []
+
+    def free_group_end(self):
+        self.is_not_in_free_group = True
+        if self.free_group:
+            self.free(torch.cat(self.free_group))
+
+    def get_cpu_copy(self, *args, **kwargs):
+        # FIXME: reuse the get_cpu_copy after paged allocator is implemented
+        raise NotImplementedError()
+
+    def load_cpu_copy(self, *args, **kwargs):
+        # FIXME: reuse the load_cpu_copy after paged allocator is implemented
+        raise NotImplementedError()
+
+    def alloc_extend(self, *args, **kwargs):
+        raise NotImplementedError("alloc_extend is only for paged allocator")
+
+    def alloc_decode(self, *args, **kwargs):
+        raise NotImplementedError("alloc_decode is only for paged allocator")
+
+    @abc.abstractmethod
+    def clear(self):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def alloc(self, need_size: int):
+        raise NotImplementedError()
+
+    @abc.abstractmethod
+    def free(self, free_index: torch.Tensor):
+        raise NotImplementedError()
+
+
+class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
+    """An allocator managing the indices to kv cache data."""
+
+    def __init__(self, size: int, dtype: torch.dtype, device: str, kvcache: KVCache):
+        super().__init__(size, 1, dtype, device, kvcache)
+        self.clear()
+
+    def clear(self):
+        # The padded slot 0 is used for writing dummy outputs from padded tokens.
+        self.free_pages = torch.arange(
+            1, self.size + 1, dtype=torch.int64, device=self.device
+        )
+        self.is_not_in_free_group = True
+        self.free_group = []
+
+    def available_size(self):
+        # To avoid minor "len(free_pages) * 1" overhead
+        return len(self.free_pages)
+
+    def alloc(self, need_size: int):
+        if need_size > len(self.free_pages):
+            return None
+
+        select_index = self.free_pages[:need_size]
+        self.free_pages = self.free_pages[need_size:]
+        return select_index
+
+    def free(self, free_index: torch.Tensor):
+        if free_index.numel() == 0:
+            return
+
+        if self.is_not_in_free_group:
+            self.free_pages = torch.cat((self.free_pages, free_index))
+        else:
+            self.free_group.append(free_index)
+
+    def get_cpu_copy(self, indices):
+        return self._kvcache.get_cpu_copy(indices)
+
+    def load_cpu_copy(self, kv_cache_cpu, indices):
+        return self._kvcache.load_cpu_copy(kv_cache_cpu, indices)
+
 
 @triton.jit
 def alloc_extend_kernel(

@@ -154,7 +275,7 @@ def alloc_decode_kernel(
     tl.store(out_indices + pid, page * page_size)
 
 
-class PagedTokenToKVPoolAllocator:
+class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
     """
     An allocator managing the indices to kv cache data.
 

@@ -172,26 +293,11 @@ class PagedTokenToKVPoolAllocator:
         device: str,
         kvcache: KVCache,
     ):
-
-        self.dtype = dtype
-        self.device = device
-        self.page_size = page_size
+        super().__init__(size, page_size, dtype, device, kvcache)
         self.num_pages = size // page_size
-
-        self.free_pages = None
-        self.is_not_in_free_group = True
-        self.free_group = []
-        self.clear()
         self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
-
-        self._kvcache = kvcache
         self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)
-
-    def available_size(self):
-        return len(self.free_pages) * self.page_size
-
-    def get_kvcache(self):
-        return self._kvcache
+        self.clear()
 
     def alloc(self, need_size: int):
         # page-aligned allocation, returning contiguous indices of pages

@@ -298,21 +404,6 @@ class PagedTokenToKVPoolAllocator:
         if self.debug_mode:
             assert len(torch.unique(self.free_pages)) == len(self.free_pages)
 
-    def free_group_begin(self):
-        self.is_not_in_free_group = False
-        self.free_group = []
-
-    def free_group_end(self):
-        self.is_not_in_free_group = True
-        if self.free_group:
-            self.free(torch.cat(self.free_group))
-
-    def backup_state(self):
-        return self.free_pages
-
-    def restore_state(self, free_pages):
-        self.free_pages = free_pages
-
     def clear(self):
         # The padded slot 0 is used for writing dummy outputs from padded tokens.
         self.free_pages = torch.arange(
sglang/srt/mem_cache/base_prefix_cache.py
CHANGED

@@ -1,5 +1,31 @@
 from abc import ABC, abstractmethod
-from typing import Any, List, Tuple
+from typing import TYPE_CHECKING, Any, List, NamedTuple, Tuple
+
+import torch
+
+if TYPE_CHECKING:
+    from sglang.srt.managers.schedule_batch import Req
+else:
+    Req = Any  # Placeholder for Req type when not type checking
+
+
+class MatchResult(NamedTuple):
+    """Result of a prefix match operation.
+
+    Attributes:
+        device_indices  : Indices of the KV cache on the device matched by common prefix.
+        last_device_node: The last TreeNode on the device that was matched.
+        last_host_node  : The last TreeNode on the host that was matched.
+                          Note that if HiCache is not enabled,
+                          this **must** be the same as `last_device_node`.
+        host_hit_length : Length of the KV cache hit on the host, if applicable.
+                          0 if HiCache is not enabled.
+    """
+
+    device_indices: torch.Tensor
+    last_device_node: Any
+    last_host_node: Any
+    host_hit_length: int = 0
 
 
 class BasePrefixCache(ABC):

@@ -10,19 +36,15 @@ class BasePrefixCache(ABC):
         pass
 
     @abstractmethod
-    def match_prefix(self, **kwargs) -> Tuple[List[int], int]:
+    def match_prefix(self, key: List[int], **kwargs) -> MatchResult:
         pass
 
     @abstractmethod
-    def insert(self, **kwargs):
+    def cache_finished_req(self, req: Req, **kwargs):
         pass
 
     @abstractmethod
-    def cache_finished_req(self, **kwargs):
-        pass
-
-    @abstractmethod
-    def cache_unfinished_req(self, **kwargs):
+    def cache_unfinished_req(self, req: Req, **kwargs):
         pass
 
     @abstractmethod

@@ -49,5 +71,27 @@ class BasePrefixCache(ABC):
     def pretty_print(self):
         raise NotImplementedError()
 
+    def init_load_back(
+        self,
+        last_host_node: Any,
+        host_hit_length: int,
+    ) -> Tuple[torch.Tensor, Any]:
+        """
+        Preparing KV cache loading from host to device.
+        """
+        raise NotImplementedError()
+
+    def ready_to_load_host_cache(self) -> Any:
+        """
+        Notify the cache controller to start the KV cache loading
+        """
+        raise NotImplementedError()
+
+    def check_hicache_events(self) -> Any:
+        """
+        Check HiCache related activities to update radix tree and synchronize across TP workers if needed
+        """
+        raise NotImplementedError()
+
     def take_events(self):
         return []
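`match_prefix` now returns a structured `MatchResult` instead of a bare tuple, keeping the HiCache-aware fields (`last_host_node`, `host_hit_length`) explicit. A sketch of what a non-HiCache implementation returns; the `node` placeholder stands in for a radix-tree TreeNode and is illustrative only:

```python
import torch

from sglang.srt.mem_cache.base_prefix_cache import MatchResult

node = None  # a real cache returns the matched radix-tree node here

result = MatchResult(
    device_indices=torch.tensor([101, 102, 103], dtype=torch.int64),
    last_device_node=node,
    last_host_node=node,  # must equal last_device_node when HiCache is off
)

# NamedTuple fields read like attributes; host_hit_length defaults to 0.
print(result.device_indices.tolist())  # [101, 102, 103]
print(result.host_hit_length)          # 0
```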
|