sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- sglang/bench_offline_throughput.py +10 -8
- sglang/bench_one_batch.py +7 -6
- sglang/bench_one_batch_server.py +157 -21
- sglang/bench_serving.py +137 -59
- sglang/compile_deep_gemm.py +5 -5
- sglang/eval/loogle_eval.py +157 -0
- sglang/lang/chat_template.py +78 -78
- sglang/lang/tracer.py +1 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +2 -2
- sglang/srt/configs/model_config.py +40 -28
- sglang/srt/constrained/base_grammar_backend.py +55 -72
- sglang/srt/constrained/llguidance_backend.py +25 -21
- sglang/srt/constrained/outlines_backend.py +27 -26
- sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -43
- sglang/srt/conversation.py +49 -44
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +129 -135
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
- sglang/srt/disaggregation/fake/conn.py +3 -13
- sglang/srt/disaggregation/kv_events.py +357 -0
- sglang/srt/disaggregation/mini_lb.py +57 -24
- sglang/srt/disaggregation/mooncake/conn.py +238 -122
- sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
- sglang/srt/disaggregation/nixl/conn.py +10 -19
- sglang/srt/disaggregation/prefill.py +132 -47
- sglang/srt/disaggregation/utils.py +123 -6
- sglang/srt/distributed/utils.py +3 -3
- sglang/srt/entrypoints/EngineBase.py +5 -0
- sglang/srt/entrypoints/engine.py +44 -9
- sglang/srt/entrypoints/http_server.py +23 -6
- sglang/srt/entrypoints/http_server_engine.py +5 -2
- sglang/srt/function_call/base_format_detector.py +250 -0
- sglang/srt/function_call/core_types.py +34 -0
- sglang/srt/function_call/deepseekv3_detector.py +157 -0
- sglang/srt/function_call/ebnf_composer.py +234 -0
- sglang/srt/function_call/function_call_parser.py +175 -0
- sglang/srt/function_call/llama32_detector.py +74 -0
- sglang/srt/function_call/mistral_detector.py +84 -0
- sglang/srt/function_call/pythonic_detector.py +163 -0
- sglang/srt/function_call/qwen25_detector.py +67 -0
- sglang/srt/function_call/utils.py +35 -0
- sglang/srt/hf_transformers_utils.py +46 -7
- sglang/srt/layers/attention/aiter_backend.py +513 -0
- sglang/srt/layers/attention/flashattention_backend.py +64 -18
- sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
- sglang/srt/layers/attention/flashmla_backend.py +340 -78
- sglang/srt/layers/attention/triton_backend.py +3 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
- sglang/srt/layers/attention/utils.py +6 -4
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/communicator.py +451 -0
- sglang/srt/layers/dp_attention.py +61 -21
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/logits_processor.py +46 -11
- sglang/srt/layers/moe/cutlass_moe.py +207 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +34 -12
- sglang/srt/layers/moe/ep_moe/layer.py +105 -51
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
- sglang/srt/layers/moe/topk.py +67 -10
- sglang/srt/layers/multimodal.py +70 -0
- sglang/srt/layers/quantization/__init__.py +8 -3
- sglang/srt/layers/quantization/blockwise_int8.py +2 -2
- sglang/srt/layers/quantization/deep_gemm.py +77 -74
- sglang/srt/layers/quantization/fp8.py +92 -2
- sglang/srt/layers/quantization/fp8_kernel.py +3 -3
- sglang/srt/layers/quantization/fp8_utils.py +6 -0
- sglang/srt/layers/quantization/gptq.py +298 -6
- sglang/srt/layers/quantization/int8_kernel.py +20 -7
- sglang/srt/layers/quantization/qoq.py +244 -0
- sglang/srt/layers/sampler.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +18 -7
- sglang/srt/lora/lora_manager.py +2 -4
- sglang/srt/lora/mem_pool.py +4 -4
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +3 -3
- sglang/srt/managers/deepseek_eplb.py +278 -0
- sglang/srt/managers/detokenizer_manager.py +21 -8
- sglang/srt/managers/eplb_manager.py +55 -0
- sglang/srt/managers/expert_distribution.py +704 -56
- sglang/srt/managers/expert_location.py +394 -0
- sglang/srt/managers/expert_location_dispatch.py +91 -0
- sglang/srt/managers/io_struct.py +19 -4
- sglang/srt/managers/mm_utils.py +294 -140
- sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
- sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
- sglang/srt/managers/multimodal_processors/internvl.py +14 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
- sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
- sglang/srt/managers/multimodal_processors/llava.py +46 -0
- sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
- sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
- sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
- sglang/srt/managers/schedule_batch.py +122 -42
- sglang/srt/managers/schedule_policy.py +1 -5
- sglang/srt/managers/scheduler.py +205 -138
- sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +232 -58
- sglang/srt/managers/tp_worker.py +12 -9
- sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
- sglang/srt/mem_cache/base_prefix_cache.py +3 -0
- sglang/srt/mem_cache/chunk_cache.py +3 -1
- sglang/srt/mem_cache/hiradix_cache.py +4 -4
- sglang/srt/mem_cache/memory_pool.py +76 -52
- sglang/srt/mem_cache/multimodal_cache.py +45 -0
- sglang/srt/mem_cache/radix_cache.py +58 -5
- sglang/srt/metrics/collector.py +314 -39
- sglang/srt/mm_utils.py +10 -0
- sglang/srt/model_executor/cuda_graph_runner.py +29 -19
- sglang/srt/model_executor/expert_location_updater.py +422 -0
- sglang/srt/model_executor/forward_batch_info.py +5 -1
- sglang/srt/model_executor/model_runner.py +163 -68
- sglang/srt/model_loader/loader.py +10 -6
- sglang/srt/models/clip.py +5 -1
- sglang/srt/models/deepseek_janus_pro.py +2 -2
- sglang/srt/models/deepseek_v2.py +308 -351
- sglang/srt/models/exaone.py +8 -3
- sglang/srt/models/gemma3_mm.py +70 -33
- sglang/srt/models/llama.py +2 -0
- sglang/srt/models/llama4.py +15 -8
- sglang/srt/models/llava.py +258 -7
- sglang/srt/models/mimo_mtp.py +220 -0
- sglang/srt/models/minicpmo.py +5 -12
- sglang/srt/models/mistral.py +71 -1
- sglang/srt/models/mixtral.py +98 -34
- sglang/srt/models/mllama.py +3 -3
- sglang/srt/models/pixtral.py +467 -0
- sglang/srt/models/qwen2.py +95 -26
- sglang/srt/models/qwen2_5_vl.py +8 -0
- sglang/srt/models/qwen2_moe.py +330 -60
- sglang/srt/models/qwen2_vl.py +6 -0
- sglang/srt/models/qwen3.py +52 -10
- sglang/srt/models/qwen3_moe.py +411 -48
- sglang/srt/models/roberta.py +1 -1
- sglang/srt/models/siglip.py +294 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/openai_api/adapter.py +58 -20
- sglang/srt/openai_api/protocol.py +6 -8
- sglang/srt/operations.py +154 -0
- sglang/srt/operations_strategy.py +31 -0
- sglang/srt/reasoning_parser.py +3 -3
- sglang/srt/sampling/custom_logit_processor.py +18 -3
- sglang/srt/sampling/sampling_batch_info.py +4 -56
- sglang/srt/sampling/sampling_params.py +2 -2
- sglang/srt/server_args.py +162 -22
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +138 -7
- sglang/srt/speculative/eagle_worker.py +69 -21
- sglang/srt/utils.py +74 -17
- sglang/test/few_shot_gsm8k.py +2 -2
- sglang/test/few_shot_gsm8k_engine.py +2 -2
- sglang/test/run_eval.py +2 -2
- sglang/test/runners.py +8 -1
- sglang/test/send_one.py +13 -3
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +1 -1
- sglang/test/test_cutlass_moe.py +278 -0
- sglang/test/test_programs.py +5 -5
- sglang/test/test_utils.py +55 -14
- sglang/utils.py +3 -3
- sglang/version.py +1 -1
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +23 -13
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +178 -149
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
- sglang/srt/function_call_parser.py +0 -858
- sglang/srt/platforms/interface.py +0 -371
- /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
- /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tp_worker.py
CHANGED
@@ -20,7 +20,7 @@ from typing import Optional, Tuple, Union
 import torch
 
 from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.distributed import get_pp_group,
+from sglang.srt.distributed import get_pp_group, get_world_group
 from sglang.srt.hf_transformers_utils import (
     get_processor,
     get_tokenizer,
@@ -183,8 +183,11 @@ class TpModelWorker:
     def forward_batch_generation(
         self,
         model_worker_batch: ModelWorkerBatch,
+        launch_done: Optional[threading.Event] = None,
         skip_sample: bool = False,
-    ) -> Tuple[
+    ) -> Tuple[
+        Union[LogitsProcessorOutput, torch.Tensor], Optional[torch.Tensor], bool
+    ]:
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
 
         pp_proxy_tensors = None
@@ -196,11 +199,11 @@ class TpModelWorker:
             )
 
         if self.pp_group.is_last_rank:
-            logits_output = self.model_runner.forward(
+            logits_output, can_run_cuda_graph = self.model_runner.forward(
                 forward_batch, pp_proxy_tensors=pp_proxy_tensors
             )
-            if
-
+            if launch_done is not None:
+                launch_done.set()
 
             if skip_sample:
                 next_token_ids = None
@@ -209,17 +212,17 @@ class TpModelWorker:
                 logits_output, model_worker_batch
             )
 
-            return logits_output, next_token_ids
+            return logits_output, next_token_ids, can_run_cuda_graph
         else:
-            pp_proxy_tensors = self.model_runner.forward(
+            pp_proxy_tensors, can_run_cuda_graph = self.model_runner.forward(
                 forward_batch,
                 pp_proxy_tensors=pp_proxy_tensors,
             )
-            return pp_proxy_tensors.tensors, None
+            return pp_proxy_tensors.tensors, None, can_run_cuda_graph
 
     def forward_batch_embedding(self, model_worker_batch: ModelWorkerBatch):
         forward_batch = ForwardBatch.init_new(model_worker_batch, self.model_runner)
-        logits_output = self.model_runner.forward(forward_batch)
+        logits_output, _ = self.model_runner.forward(forward_batch)
         embeddings = logits_output.embeddings
         return embeddings
 
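Note: `forward_batch_generation` now returns a third element, `can_run_cuda_graph`, and accepts an optional `launch_done` event, so callers that unpacked two values must be updated. A minimal sketch of the new call shape, assuming `worker` is an initialized `TpModelWorker` and `batch` is a prepared `ModelWorkerBatch` (both are placeholders, not constructed here):

    import threading

    # Sketch only: `worker` and `batch` are assumed to already exist.
    launch_done = threading.Event()
    logits_output, next_token_ids, can_run_cuda_graph = worker.forward_batch_generation(
        batch,
        launch_done=launch_done,  # set by the worker once the forward pass has been launched
    )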
sglang/srt/managers/tp_worker_overlap_thread.py
CHANGED
@@ -18,7 +18,7 @@ import logging
 import signal
 import threading
 from queue import Queue
-from typing import Optional
+from typing import Optional, Tuple
 
 import psutil
 import torch
@@ -127,10 +127,12 @@ class TpModelWorkerClient:
         batch_lists = [None] * 2
 
         while True:
-            model_worker_batch, future_token_ids_ct = self.input_queue.get()
+            model_worker_batch, future_token_ids_ct, sync_event = self.input_queue.get()
             if not model_worker_batch:
                 break
 
+            sync_event.wait()
+
             # Keep a reference of model_worker_batch by storing it into a list.
             # Otherwise, the tensor members of model_worker_batch will be released
             # by pytorch and cause CUDA illegal memory access errors.
@@ -145,8 +147,10 @@ class TpModelWorkerClient:
             resolve_future_token_ids(input_ids, self.future_token_ids_map)
 
             # Run forward
-            logits_output, next_token_ids =
-
+            logits_output, next_token_ids, can_run_cuda_graph = (
+                self.worker.forward_batch_generation(
+                    model_worker_batch, model_worker_batch.launch_done
+                )
             )
 
             # Update the future token ids map
@@ -171,14 +175,18 @@ class TpModelWorkerClient:
             next_token_ids = next_token_ids.to("cpu", non_blocking=True)
             copy_done.record()
 
-            self.output_queue.put(
+            self.output_queue.put(
+                (copy_done, logits_output, next_token_ids, can_run_cuda_graph)
+            )
 
     def resolve_last_batch_result(self, launch_done: Optional[threading.Event] = None):
         """
         This function is called to resolve the last batch result and
         wait for the current batch to be launched. Used in overlap mode.
         """
-        copy_done, logits_output, next_token_ids =
+        copy_done, logits_output, next_token_ids, can_run_cuda_graph = (
+            self.output_queue.get()
+        )
 
         if launch_done is not None:
             launch_done.wait()
@@ -193,9 +201,11 @@ class TpModelWorkerClient:
                 logits_output.input_token_logprobs.tolist()
             )
         next_token_ids = next_token_ids.tolist()
-        return logits_output, next_token_ids
+        return logits_output, next_token_ids, can_run_cuda_graph
 
-    def forward_batch_generation(
+    def forward_batch_generation(
+        self, model_worker_batch: ModelWorkerBatch
+    ) -> Tuple[None, torch.Tensor, bool]:
         # Create a new copy of sampling_info because it will be updated in-place by the scheduler for the next batch.
         sampling_info = model_worker_batch.sampling_info
         sampling_info.update_penalties()
@@ -206,10 +216,11 @@ class TpModelWorkerClient:
         )
 
         # A cuda stream sync here to avoid the cuda illegal memory access error.
-        self.
+        sync_event = torch.get_device_module(self.device).Event()
+        sync_event.record(self.scheduler_stream)
 
         # Push a new batch to the queue
-        self.input_queue.put((model_worker_batch, self.future_token_ids_ct))
+        self.input_queue.put((model_worker_batch, self.future_token_ids_ct, sync_event))
 
         # Allocate output future objects
         bs = len(model_worker_batch.seq_lens)
@@ -223,7 +234,7 @@ class TpModelWorkerClient:
         self.future_token_ids_ct = (
             self.future_token_ids_ct + bs
         ) % self.future_token_ids_limit
-        return None, future_next_token_ids
+        return None, future_next_token_ids, False
 
     def update_weights_from_disk(self, recv_req: UpdateWeightFromDiskReqInput):
         success, message = self.worker.update_weights_from_disk(recv_req)
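The `sync_event` plumbing above replaces a blocking stream synchronization on the scheduler side: the scheduler records a device event and the background forward thread waits on it before touching the batch tensors. A self-contained sketch of the same producer/consumer pattern (the names `schedule_batch`, `forward_loop`, and `work_queue` are illustrative, not sglang APIs):

    import queue

    import torch

    work_queue: queue.Queue = queue.Queue()

    def schedule_batch(batch, scheduler_stream: torch.cuda.Stream) -> None:
        # Producer: record an event on the scheduler stream instead of synchronizing it.
        sync_event = torch.cuda.Event()
        sync_event.record(scheduler_stream)
        work_queue.put((batch, sync_event))

    def forward_loop() -> None:
        while True:
            batch, sync_event = work_queue.get()
            if batch is None:
                break
            # Consumer: make the current (forward) stream wait until the producer's
            # queued kernels finish before any kernels that read `batch` are launched.
            sync_event.wait()
            # ... run the forward pass on `batch` here ...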
sglang/srt/mem_cache/chunk_cache.py
CHANGED
@@ -38,7 +38,9 @@ class ChunkCache(BasePrefixCache):
 
     def cache_finished_req(self, req: Req):
         kv_indices = self.req_to_token_pool.req_to_token[
-            req.req_pool_idx,
+            req.req_pool_idx,
+            # For decode server: if req.output_ids is empty, we want to free all req.origin_input_ids
+            : len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0),
         ]
         self.req_to_token_pool.free(req.req_pool_idx)
         self.token_to_kv_pool_allocator.free(kv_indices)
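For reference, the new upper bound of the freed slice evaluates as follows; `freed_slots` is a hypothetical helper that simply mirrors the expression above:

    def freed_slots(num_prompt_tokens: int, num_output_tokens: int) -> int:
        # Mirrors the slice bound in ChunkCache.cache_finished_req:
        # len(origin_input_ids) + max(len(output_ids) - 1, 0)
        return num_prompt_tokens + max(num_output_tokens - 1, 0)

    assert freed_slots(5, 3) == 7  # prompt KV plus all but the last output token
    assert freed_slots(5, 0) == 5  # decode-server case: the whole prompt is freed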
sglang/srt/mem_cache/hiradix_cache.py
CHANGED
@@ -335,13 +335,13 @@ class HiRadixCache(RadixCache):
         return value, last_node
 
     def _match_prefix_helper(self, node: TreeNode, key: List):
-        node.last_access_time = time.
+        node.last_access_time = time.monotonic()
         child_key = self.get_child_key_fn(key)
         value = []
 
         while len(key) > 0 and child_key in node.children.keys():
             child = node.children[child_key]
-            child.last_access_time = time.
+            child.last_access_time = time.monotonic()
             prefix_len = self.key_match_fn(child.key, key)
             if prefix_len < len(child.key):
                 new_node = self._split_node(child.key, child, prefix_len)
@@ -386,7 +386,7 @@ class HiRadixCache(RadixCache):
         return new_node
 
     def _insert_helper(self, node: TreeNode, key: List, value):
-        node.last_access_time = time.
+        node.last_access_time = time.monotonic()
         if len(key) == 0:
             return 0
 
@@ -395,7 +395,7 @@ class HiRadixCache(RadixCache):
 
         while len(key) > 0 and child_key in node.children.keys():
             node = node.children[child_key]
-            node.last_access_time = time.
+            node.last_access_time = time.monotonic()
             prefix_len = self.key_match_fn(node.key, key)
 
             if prefix_len == len(node.key):
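The switch to `time.monotonic()` for `last_access_time` keeps the LRU ordering of tree nodes stable even if the wall clock is adjusted; a small illustrative comparison (not sglang code):

    import time

    t0 = time.monotonic()
    # ... cache accesses happen here ...
    t1 = time.monotonic()
    assert t1 >= t0  # guaranteed for monotonic(); time.time() can move backwards (e.g. NTP)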
sglang/srt/mem_cache/memory_pool.py
CHANGED
@@ -38,11 +38,17 @@ import triton
 import triton.language as tl
 
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.utils import
+from sglang.srt.utils import (
+    debug_timing,
+    get_compiler_backend,
+    is_cuda,
+    next_power_of_2,
+)
 
 logger = logging.getLogger(__name__)
 
 GB = 1024 * 1024 * 1024
+_is_cuda = is_cuda()
 
 
 class ReqToTokenPool:
@@ -94,6 +100,33 @@ class ReqToTokenPool:
 
 
 class KVCache(abc.ABC):
+    @abc.abstractmethod
+    def __init__(
+        self,
+        size: int,
+        page_size: int,
+        dtype: torch.dtype,
+        layer_num: int,
+        device: str,
+        enable_memory_saver: bool,
+        start_layer: Optional[int] = None,
+        end_layer: Optional[int] = None,
+    ):
+        self.size = size
+        self.page_size = page_size
+        self.dtype = dtype
+        self.device = device
+        if dtype in (torch.float8_e5m2, torch.float8_e4m3fn):
+            # NOTE: Store as torch.uint8 because Tensor.index_put is not implemented for torch.float8_e5m2
+            self.store_dtype = torch.uint8
+        else:
+            self.store_dtype = dtype
+        self.layer_num = layer_num
+        self.start_layer = start_layer or 0
+        self.end_layer = end_layer or layer_num - 1
+        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
+            enable=enable_memory_saver
+        )
 
     @abc.abstractmethod
     def get_key_buffer(self, layer_id: int) -> torch.Tensor:
@@ -217,30 +250,24 @@ class MHATokenToKVPool(KVCache):
         start_layer: Optional[int] = None,
         end_layer: Optional[int] = None,
     ):
-
-
-
-
-
-
-
-
-
-        self.memory_saver_adapter = TorchMemorySaverAdapter.create(
-            enable=enable_memory_saver
+        super().__init__(
+            size,
+            page_size,
+            dtype,
+            layer_num,
+            device,
+            enable_memory_saver,
+            start_layer,
+            end_layer,
         )
 
         self.head_num = head_num
         self.head_dim = head_dim
-        self.layer_num = layer_num
         self._create_buffers()
-        self.start_layer = start_layer or 0
-        self.end_layer = end_layer or layer_num - 1
 
         self.layer_transfer_counter = None
-        self.capture_mode = False
         self.device_module = torch.get_device_module(self.device)
-        self.alt_stream = self.device_module.Stream()
+        self.alt_stream = self.device_module.Stream() if is_cuda else None
 
         k_size, v_size = self.get_kv_size_bytes()
         logger.info(
@@ -357,6 +384,8 @@ class MHATokenToKVPool(KVCache):
         k_scale: Optional[float] = None,
         v_scale: Optional[float] = None,
     ):
+        from sglang.srt.model_executor.cuda_graph_runner import get_is_capture_mode
+
         layer_id = layer.layer_id
         if cache_k.dtype != self.dtype:
             if k_scale is not None:
@@ -370,7 +399,7 @@ class MHATokenToKVPool(KVCache):
             cache_k = cache_k.view(self.store_dtype)
             cache_v = cache_v.view(self.store_dtype)
 
-        if
+        if get_is_capture_mode() and self.alt_stream is not None:
             # Overlap the copy of K and V cache for small batch size
             current_stream = self.device_module.current_stream()
             self.alt_stream.wait_stream(current_stream)
@@ -493,26 +522,21 @@ class MLATokenToKVPool(KVCache):
         start_layer: Optional[int] = None,
         end_layer: Optional[int] = None,
     ):
-
-
-
-
-
-
-
-
-
+        super().__init__(
+            size,
+            page_size,
+            dtype,
+            layer_num,
+            device,
+            enable_memory_saver,
+            start_layer,
+            end_layer,
+        )
+
         self.kv_lora_rank = kv_lora_rank
         self.qk_rope_head_dim = qk_rope_head_dim
-        self.layer_num = layer_num
-        self.start_layer = start_layer or 0
-        self.end_layer = end_layer or layer_num - 1
-
-        memory_saver_adapter = TorchMemorySaverAdapter.create(
-            enable=enable_memory_saver
-        )
 
-        with memory_saver_adapter.region():
+        with self.memory_saver_adapter.region():
             # The padded slot 0 is used for writing dummy outputs from padded tokens.
             self.kv_buffer = [
                 torch.zeros(
@@ -524,7 +548,6 @@ class MLATokenToKVPool(KVCache):
             ]
 
         self.layer_transfer_counter = None
-        self.page_size = page_size
 
         kv_size = self.get_kv_size_bytes()
         logger.info(
@@ -637,20 +660,18 @@ class DoubleSparseTokenToKVPool(KVCache):
         start_layer: Optional[int] = None,
         end_layer: Optional[int] = None,
     ):
-
-
-
-
-
-
-
-
-
-        memory_saver_adapter = TorchMemorySaverAdapter.create(
-            enable=enable_memory_saver
+        super().__init__(
+            size,
+            page_size,
+            dtype,
+            layer_num,
+            device,
+            enable_memory_saver,
+            start_layer,
+            end_layer,
         )
 
-        with memory_saver_adapter.region():
+        with self.memory_saver_adapter.region():
             # [size, head_num, head_dim] for each layer
             self.k_buffer = [
                 torch.zeros(
@@ -673,9 +694,6 @@ class DoubleSparseTokenToKVPool(KVCache):
                 for _ in range(layer_num)
             ]
 
-        self.start_layer = start_layer or 0
-        self.end_layer = end_layer or layer_num - 1
-
     def get_key_buffer(self, layer_id: int):
         return self.k_buffer[layer_id - self.start_layer]
 
@@ -743,7 +761,7 @@ class HostKVCache(abc.ABC):
 
     def __init__(
         self,
-        device_pool:
+        device_pool: KVCache,
         host_to_device_ratio: float,
         host_size: int,
         pin_memory: bool,
@@ -762,6 +780,8 @@ class HostKVCache(abc.ABC):
         self.size = int(device_pool.size * host_to_device_ratio)
         # Align the host memory pool size to the page size
        self.size = self.size - (self.size % self.page_size)
+        self.start_layer = device_pool.start_layer
+        self.end_layer = device_pool.end_layer
 
         assert (
             self.size > device_pool.size
@@ -913,6 +933,8 @@ class HostKVCache(abc.ABC):
 
 
 class MHATokenToKVPoolHost(HostKVCache):
+    device_pool: MHATokenToKVPool
+
     def __init__(
         self,
         device_pool: MHATokenToKVPool,
@@ -996,6 +1018,8 @@ class MHATokenToKVPoolHost(HostKVCache):
 
 
 class MLATokenToKVPoolHost(HostKVCache):
+    device_pool: MLATokenToKVPool
+
     def __init__(
         self,
         device_pool: MLATokenToKVPool,
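With the shared constructor now living on `KVCache`, a concrete pool only supplies its own buffers and calls `super().__init__`. A hedged sketch of a hypothetical subclass (`ToyTokenToKVPool` is illustrative only and omits the abstract buffer accessors the base class requires):

    import torch

    from sglang.srt.mem_cache.memory_pool import KVCache

    class ToyTokenToKVPool(KVCache):
        def __init__(self, size, page_size, dtype, layer_num, device, enable_memory_saver):
            super().__init__(
                size, page_size, dtype, layer_num, device, enable_memory_saver
            )
            # store_dtype, start_layer/end_layer, and memory_saver_adapter come from the base class.
            with self.memory_saver_adapter.region():
                self.kv_buffer = [
                    torch.zeros((size + 1, 1, 1), dtype=self.store_dtype, device=device)
                    for _ in range(layer_num)
                ]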
sglang/srt/mem_cache/multimodal_cache.py
ADDED
@@ -0,0 +1,45 @@
+from typing import Dict
+
+import torch
+
+
+class MultiModalCache:
+    """MultiModalCache is used to store vlm encoder results"""
+
+    def __init__(
+        self,
+        max_size: int,
+    ):
+        self.max_size = max_size
+        self.mm_cache: Dict[int, torch.Tensor] = {}
+        self.current_size = 0
+
+    def put(self, mm_hash: int, embedding: torch.Tensor) -> bool:
+        if mm_hash in self.mm_cache:
+            return True
+        data_size = self._get_tensor_size(embedding)
+        if self.current_size + data_size > self.max_size:
+            return False
+        self.mm_cache[mm_hash] = embedding
+        self.current_size += data_size
+        return True
+
+    def get(self, mm_hash: int) -> torch.Tensor:
+        return self.mm_cache.get(mm_hash)
+
+    def free(self, mm_hash: int) -> bool:
+        if mm_hash not in self.mm_cache:
+            return False
+        old_embedding = self.mm_cache.pop(mm_hash)
+        self.current_size -= self._get_tensor_size(old_embedding)
+        return True
+
+    def clear(self):
+        self.mm_cache.clear()
+        self.current_size = 0
+
+    def _get_tensor_size(self, embedding: torch.Tensor):
+        return embedding.element_size() * embedding.numel()
+
+    def __len__(self):
+        return len(self.mm_cache)
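The new cache is a size-bounded dict keyed by a multimodal hash. A short usage sketch (the hash value and tensor shape below are arbitrary examples):

    import torch

    from sglang.srt.mem_cache.multimodal_cache import MultiModalCache

    cache = MultiModalCache(max_size=1 << 30)   # byte budget for cached embeddings

    embedding = torch.randn(256, 4096)          # e.g. a vision-encoder output
    mm_hash = 12345                             # arbitrary example hash

    accepted = cache.put(mm_hash, embedding)    # False if it would exceed max_size
    hit = cache.get(mm_hash)                    # the tensor, or None on a miss
    cache.free(mm_hash)                         # removes the entry and reclaims its budget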
sglang/srt/mem_cache/radix_cache.py
CHANGED
@@ -27,6 +27,12 @@ from typing import TYPE_CHECKING, List, Optional, Tuple
 
 import torch
 
+from sglang.srt.disaggregation.kv_events import (
+    AllBlocksCleared,
+    BlockRemoved,
+    BlockStored,
+    KVCacheEvent,
+)
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
 from sglang.srt.mem_cache.memory_pool import ReqToTokenPool, TokenToKVPoolAllocator
@@ -45,7 +51,7 @@ class TreeNode:
         self.key = None
         self.value = None
         self.lock_ref = 0
-        self.last_access_time = time.
+        self.last_access_time = time.monotonic()
 
         self.hit_count = 0
         # indicating the node is loading KV cache from host
@@ -96,11 +102,14 @@ class RadixCache(BasePrefixCache):
         token_to_kv_pool_allocator: TokenToKVPoolAllocator,
         page_size: int,
         disable: bool = False,
+        enable_kv_cache_events: bool = False,
     ):
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
         self.page_size = page_size
         self.disable = disable
+        self.enable_kv_cache_events = enable_kv_cache_events
+        self.kv_event_queue = []
 
         if self.token_to_kv_pool_allocator:
             self.device = self.token_to_kv_pool_allocator.device
@@ -124,6 +133,7 @@ class RadixCache(BasePrefixCache):
         self.root_node.lock_ref = 1
         self.evictable_size_ = 0
         self.protected_size_ = 0
+        self._record_all_cleared_event()
 
     def match_prefix(self, key: List[int], **kwargs) -> Tuple[torch.Tensor, int]:
         """Find the matching prefix from the radix tree.
@@ -273,6 +283,8 @@ class RadixCache(BasePrefixCache):
             if len(x.parent.children) == 0:
                 heapq.heappush(leaves, x.parent)
 
+            self._record_remove_event(x)
+
     def inc_lock_ref(self, node: TreeNode):
         if self.disable:
             return 0
@@ -322,14 +334,14 @@ class RadixCache(BasePrefixCache):
     ##### Internal Helper Functions #####
 
     def _match_prefix_helper(self, node: TreeNode, key: List):
-        node.last_access_time = time.
+        node.last_access_time = time.monotonic()
 
         child_key = self.get_child_key_fn(key)
 
         value = []
         while len(key) > 0 and child_key in node.children.keys():
             child = node.children[child_key]
-            child.last_access_time = time.
+            child.last_access_time = time.monotonic()
             prefix_len = self.key_match_fn(child.key, key)
             if prefix_len < len(child.key):
                 new_node = self._split_node(child.key, child, prefix_len)
@@ -348,6 +360,7 @@ class RadixCache(BasePrefixCache):
 
     def _split_node(self, key, child: TreeNode, split_len: int):
         # new_node -> child
+        self._record_remove_event(child)
         new_node = TreeNode()
         new_node.children = {self.get_child_key_fn(key[split_len:]): child}
         new_node.parent = child.parent
@@ -358,10 +371,14 @@ class RadixCache(BasePrefixCache):
         child.key = child.key[split_len:]
         child.value = child.value[split_len:]
         new_node.parent.children[self.get_child_key_fn(key)] = new_node
+
+        self._record_store_event(new_node)
+        self._record_store_event(child)
+
         return new_node
 
     def _insert_helper(self, node: TreeNode, key: List, value):
-        node.last_access_time = time.
+        node.last_access_time = time.monotonic()
         if len(key) == 0:
             return 0
 
@@ -370,7 +387,7 @@ class RadixCache(BasePrefixCache):
         total_prefix_length = 0
         while len(key) > 0 and child_key in node.children.keys():
             node = node.children[child_key]
-            node.last_access_time = time.
+            node.last_access_time = time.monotonic()
             prefix_len = self.key_match_fn(node.key, key)
             total_prefix_length += prefix_len
             key = key[prefix_len:]
@@ -390,6 +407,7 @@ class RadixCache(BasePrefixCache):
             new_node.value = value
             node.children[child_key] = new_node
             self.evictable_size_ += len(value)
+            self._record_store_event(new_node)
         return total_prefix_length
 
     def _print_helper(self, node: TreeNode, indent: int):
@@ -442,6 +460,41 @@ class RadixCache(BasePrefixCache):
 
         return ret_list
 
+    def _record_store_event(self, node: TreeNode):
+        if self.enable_kv_cache_events:
+            block_hash = hash(tuple(node.key))
+            parent_block_hash = hash(tuple(node.parent.key))
+            self.kv_event_queue.append(
+                BlockStored(
+                    block_hashes=[block_hash],
+                    parent_block_hash=parent_block_hash,
+                    token_ids=node.key,
+                    block_size=len(node.key),
+                    lora_id=None,
+                )
+            )
+
+    def _record_remove_event(self, node: TreeNode):
+        if self.enable_kv_cache_events:
+            block_hash = hash(tuple(node.key))
+            self.kv_event_queue.append(BlockRemoved(block_hashes=[block_hash]))
+
+    def _record_all_cleared_event(self):
+        if self.enable_kv_cache_events:
+            self.kv_event_queue.append(AllBlocksCleared())
+
+    def take_events(self):
+        """Atomically takes all events and clears the queue.
+
+        Returns:
+            A list of KV cache events.
+        """
+        if not self.enable_kv_cache_events:
+            return []
+        events = self.kv_event_queue
+        self.kv_event_queue = []
+        return events
+
 
 if __name__ == "__main__":
     tree = RadixCache(None, None, page_size=1, disable=False)
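Putting the new hooks together, a consumer enables event recording at construction time and periodically drains the queue with `take_events()`. A hedged sketch; the pools and the `publish` callback are placeholders, not sglang APIs:

    from sglang.srt.mem_cache.radix_cache import RadixCache

    # Assumes req_to_token_pool, token_to_kv_pool_allocator, and publish() already exist.
    tree = RadixCache(
        req_to_token_pool,
        token_to_kv_pool_allocator,
        page_size=1,
        enable_kv_cache_events=True,
    )

    # ... insert / evict as usual ...

    for event in tree.take_events():   # atomically drains kv_event_queue
        publish(event)                 # e.g. forward BlockStored / BlockRemoved downstream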