sglang 0.4.10.post2__py3-none-any.whl → 0.5.0rc1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +8 -3
- sglang/bench_one_batch.py +119 -17
- sglang/lang/chat_template.py +18 -0
- sglang/srt/bench_utils.py +137 -0
- sglang/srt/configs/model_config.py +42 -7
- sglang/srt/conversation.py +9 -5
- sglang/srt/disaggregation/base/conn.py +5 -2
- sglang/srt/disaggregation/decode.py +14 -4
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +3 -0
- sglang/srt/disaggregation/mooncake/conn.py +286 -160
- sglang/srt/disaggregation/mooncake/transfer_engine.py +29 -0
- sglang/srt/disaggregation/prefill.py +2 -0
- sglang/srt/distributed/parallel_state.py +15 -11
- sglang/srt/entrypoints/context.py +227 -0
- sglang/srt/entrypoints/engine.py +15 -9
- sglang/srt/entrypoints/harmony_utils.py +372 -0
- sglang/srt/entrypoints/http_server.py +74 -4
- sglang/srt/entrypoints/openai/protocol.py +218 -1
- sglang/srt/entrypoints/openai/serving_chat.py +41 -11
- sglang/srt/entrypoints/openai/serving_responses.py +1273 -0
- sglang/srt/entrypoints/openai/tool_server.py +175 -0
- sglang/srt/entrypoints/tool.py +87 -0
- sglang/srt/eplb/expert_location.py +5 -1
- sglang/srt/function_call/ebnf_composer.py +1 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +331 -0
- sglang/srt/function_call/kimik2_detector.py +3 -3
- sglang/srt/function_call/qwen3_coder_detector.py +219 -9
- sglang/srt/hf_transformers_utils.py +30 -3
- sglang/srt/jinja_template_utils.py +14 -1
- sglang/srt/layers/attention/aiter_backend.py +375 -115
- sglang/srt/layers/attention/ascend_backend.py +3 -0
- sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1700 -0
- sglang/srt/layers/attention/flashattention_backend.py +18 -0
- sglang/srt/layers/attention/flashinfer_backend.py +52 -13
- sglang/srt/layers/attention/hybrid_attn_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +85 -14
- sglang/srt/layers/attention/triton_ops/decode_attention.py +17 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +143 -98
- sglang/srt/layers/attention/trtllm_mha_backend.py +332 -0
- sglang/srt/layers/attention/trtllm_mla_backend.py +119 -22
- sglang/srt/layers/attention/vision.py +22 -6
- sglang/srt/layers/attention/wave_backend.py +627 -0
- sglang/srt/layers/attention/wave_ops/decode_attention.py +186 -0
- sglang/srt/layers/attention/wave_ops/extend_attention.py +149 -0
- sglang/srt/layers/attention/wave_ops/prefill_attention.py +79 -0
- sglang/srt/layers/communicator.py +29 -14
- sglang/srt/layers/dp_attention.py +12 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +4 -4
- sglang/srt/layers/linear.py +3 -7
- sglang/srt/layers/moe/cutlass_moe.py +12 -3
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -5
- sglang/srt/layers/moe/ep_moe/kernels.py +43 -0
- sglang/srt/layers/moe/ep_moe/layer.py +135 -73
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=384,N=256,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +101 -12
- sglang/srt/layers/moe/fused_moe_triton/layer.py +412 -33
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +188 -3
- sglang/srt/layers/moe/token_dispatcher/deepep.py +61 -24
- sglang/srt/layers/moe/topk.py +16 -4
- sglang/srt/layers/moe/utils.py +16 -0
- sglang/srt/layers/quantization/__init__.py +27 -3
- sglang/srt/layers/quantization/fp4.py +557 -0
- sglang/srt/layers/quantization/fp8.py +3 -6
- sglang/srt/layers/quantization/fp8_kernel.py +277 -0
- sglang/srt/layers/quantization/fp8_utils.py +51 -10
- sglang/srt/layers/quantization/modelopt_quant.py +258 -68
- sglang/srt/layers/quantization/mxfp4.py +654 -0
- sglang/srt/layers/quantization/mxfp4_tensor.py +133 -0
- sglang/srt/layers/quantization/quark/schemes/__init__.py +6 -0
- sglang/srt/layers/quantization/quark/schemes/quark_scheme.py +55 -0
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +118 -0
- sglang/srt/layers/quantization/quark/utils.py +107 -0
- sglang/srt/layers/quantization/unquant.py +60 -6
- sglang/srt/layers/quantization/w4afp8.py +21 -12
- sglang/srt/layers/quantization/w8a8_int8.py +48 -34
- sglang/srt/layers/rotary_embedding.py +506 -3
- sglang/srt/layers/utils.py +9 -0
- sglang/srt/layers/vocab_parallel_embedding.py +8 -3
- sglang/srt/lora/backend/base_backend.py +3 -23
- sglang/srt/lora/layers.py +60 -114
- sglang/srt/lora/lora.py +17 -62
- sglang/srt/lora/lora_manager.py +82 -62
- sglang/srt/lora/lora_registry.py +23 -11
- sglang/srt/lora/mem_pool.py +63 -68
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/utils.py +25 -58
- sglang/srt/managers/cache_controller.py +75 -58
- sglang/srt/managers/detokenizer_manager.py +1 -1
- sglang/srt/managers/io_struct.py +20 -8
- sglang/srt/managers/mm_utils.py +6 -13
- sglang/srt/managers/multimodal_processor.py +1 -1
- sglang/srt/managers/schedule_batch.py +61 -25
- sglang/srt/managers/schedule_policy.py +6 -6
- sglang/srt/managers/scheduler.py +41 -19
- sglang/srt/managers/scheduler_output_processor_mixin.py +1 -2
- sglang/srt/managers/scheduler_profiler_mixin.py +28 -8
- sglang/srt/managers/scheduler_recv_skipper.py +37 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +6 -0
- sglang/srt/managers/template_manager.py +35 -1
- sglang/srt/managers/tokenizer_manager.py +47 -30
- sglang/srt/managers/tp_worker.py +3 -0
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -0
- sglang/srt/mem_cache/allocator.py +61 -87
- sglang/srt/mem_cache/hicache_storage.py +1 -1
- sglang/srt/mem_cache/hiradix_cache.py +80 -22
- sglang/srt/mem_cache/lora_radix_cache.py +421 -0
- sglang/srt/mem_cache/memory_pool_host.py +34 -36
- sglang/srt/mem_cache/multimodal_cache.py +33 -13
- sglang/srt/mem_cache/radix_cache.py +2 -5
- sglang/srt/mem_cache/storage/hf3fs/client_hf3fs.py +2 -2
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +443 -0
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +139 -67
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +6 -9
- sglang/srt/model_executor/cuda_graph_runner.py +29 -9
- sglang/srt/model_executor/forward_batch_info.py +61 -19
- sglang/srt/model_executor/model_runner.py +148 -37
- sglang/srt/model_loader/loader.py +18 -6
- sglang/srt/model_loader/weight_utils.py +10 -0
- sglang/srt/models/bailing_moe.py +425 -0
- sglang/srt/models/deepseek_v2.py +137 -59
- sglang/srt/models/ernie4.py +426 -0
- sglang/srt/models/ernie4_eagle.py +203 -0
- sglang/srt/models/gemma2.py +0 -34
- sglang/srt/models/gemma3n_mm.py +38 -0
- sglang/srt/models/glm4.py +6 -0
- sglang/srt/models/glm4_moe.py +28 -16
- sglang/srt/models/glm4v.py +589 -0
- sglang/srt/models/glm4v_moe.py +400 -0
- sglang/srt/models/gpt_oss.py +1251 -0
- sglang/srt/models/granite.py +0 -25
- sglang/srt/models/llama.py +0 -25
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/qwen2.py +6 -0
- sglang/srt/models/qwen2_5_vl.py +7 -3
- sglang/srt/models/qwen2_audio.py +10 -9
- sglang/srt/models/qwen2_moe.py +6 -0
- sglang/srt/models/qwen3.py +0 -24
- sglang/srt/models/qwen3_moe.py +32 -6
- sglang/srt/models/registry.py +1 -1
- sglang/srt/models/step3_vl.py +9 -0
- sglang/srt/models/torch_native_llama.py +0 -24
- sglang/srt/models/transformers.py +2 -5
- sglang/srt/multimodal/processors/base_processor.py +23 -13
- sglang/srt/multimodal/processors/glm4v.py +132 -0
- sglang/srt/multimodal/processors/qwen_audio.py +4 -2
- sglang/srt/multimodal/processors/step3_vl.py +3 -1
- sglang/srt/reasoning_parser.py +332 -37
- sglang/srt/server_args.py +186 -75
- sglang/srt/speculative/eagle_worker.py +16 -0
- sglang/srt/two_batch_overlap.py +169 -9
- sglang/srt/utils.py +41 -5
- sglang/srt/weight_sync/tensor_bucket.py +106 -0
- sglang/test/attention/test_trtllm_mla_backend.py +186 -36
- sglang/test/doc_patch.py +59 -0
- sglang/test/few_shot_gsm8k.py +1 -1
- sglang/test/few_shot_gsm8k_engine.py +1 -1
- sglang/test/run_eval.py +4 -1
- sglang/test/runners.py +2 -2
- sglang/test/simple_eval_common.py +6 -0
- sglang/test/simple_eval_gpqa.py +2 -0
- sglang/test/test_fp4_moe.py +118 -36
- sglang/test/test_utils.py +1 -1
- sglang/utils.py +1 -1
- sglang/version.py +1 -1
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/METADATA +36 -38
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/RECORD +174 -141
- sglang/srt/lora/backend/flashinfer_backend.py +0 -131
- /sglang/{api.py → lang/api.py} +0 -0
- /sglang/{lang/backend → srt/layers/quantization/quark}/__init__.py +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/WHEEL +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.10.post2.dist-info → sglang-0.5.0rc1.dist-info}/top_level.txt +0 -0
@@ -43,12 +43,14 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
|
|
43
43
|
dtype: torch.dtype,
|
44
44
|
device: str,
|
45
45
|
kvcache: KVCache,
|
46
|
+
need_sort: bool,
|
46
47
|
):
|
47
48
|
self.size = size
|
48
49
|
self.page_size = page_size
|
49
50
|
self.dtype = dtype
|
50
51
|
self.device = device
|
51
52
|
self._kvcache = kvcache
|
53
|
+
self.need_sort = need_sort
|
52
54
|
|
53
55
|
self.free_pages = None
|
54
56
|
self.release_pages = None
|
@@ -79,6 +81,9 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
|
|
79
81
|
if self.free_group:
|
80
82
|
self.free(torch.cat(self.free_group))
|
81
83
|
|
84
|
+
def estimated_num_new_pages(self, bs, extend_num_tokens):
|
85
|
+
return bs * ((extend_num_tokens + self.page_size - 1) // self.page_size)
|
86
|
+
|
82
87
|
def merge_and_sort_free(self):
|
83
88
|
if len(self.release_pages) > 0:
|
84
89
|
self.free_pages = torch.cat((self.free_pages, self.release_pages))
|
@@ -117,8 +122,15 @@ class BaseTokenToKVPoolAllocator(abc.ABC):
|
|
117
122
|
class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
118
123
|
"""An allocator managing the indices to kv cache data."""
|
119
124
|
|
120
|
-
def __init__(
|
121
|
-
|
125
|
+
def __init__(
|
126
|
+
self,
|
127
|
+
size: int,
|
128
|
+
dtype: torch.dtype,
|
129
|
+
device: str,
|
130
|
+
kvcache: KVCache,
|
131
|
+
need_sort: bool,
|
132
|
+
):
|
133
|
+
super().__init__(size, 1, dtype, device, kvcache, need_sort)
|
122
134
|
self.clear()
|
123
135
|
|
124
136
|
def clear(self):
|
@@ -135,7 +147,7 @@ class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
|
135
147
|
return len(self.free_pages) + len(self.release_pages)
|
136
148
|
|
137
149
|
def alloc(self, need_size: int):
|
138
|
-
if need_size > len(self.free_pages):
|
150
|
+
if self.need_sort and need_size > len(self.free_pages):
|
139
151
|
self.merge_and_sort_free()
|
140
152
|
if need_size > len(self.free_pages):
|
141
153
|
return None
|
@@ -149,7 +161,10 @@ class TokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
|
149
161
|
return
|
150
162
|
|
151
163
|
if self.is_not_in_free_group:
|
152
|
-
|
164
|
+
if self.need_sort:
|
165
|
+
self.release_pages = torch.cat((self.release_pages, free_index))
|
166
|
+
else:
|
167
|
+
self.free_pages = torch.cat((self.free_pages, free_index))
|
153
168
|
else:
|
154
169
|
self.free_group.append(free_index)
|
155
170
|
|
@@ -170,8 +185,9 @@ class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
|
170
185
|
dtype: torch.dtype,
|
171
186
|
device: str,
|
172
187
|
kvcache: SWAKVPool,
|
188
|
+
need_sort: bool,
|
173
189
|
):
|
174
|
-
super().__init__(size, 1, dtype, device, kvcache)
|
190
|
+
super().__init__(size, 1, dtype, device, kvcache, need_sort)
|
175
191
|
assert isinstance(kvcache, SWAKVPool)
|
176
192
|
self._size_full = size
|
177
193
|
self._size_swa = size_swa
|
@@ -180,12 +196,14 @@ class SWATokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
|
180
196
|
dtype,
|
181
197
|
device,
|
182
198
|
kvcache.full_kv_pool,
|
199
|
+
need_sort,
|
183
200
|
)
|
184
201
|
self.swa_attn_allocator = TokenToKVPoolAllocator(
|
185
202
|
size_swa,
|
186
203
|
dtype,
|
187
204
|
device,
|
188
205
|
kvcache.swa_kv_pool,
|
206
|
+
need_sort,
|
189
207
|
)
|
190
208
|
self.full_to_swa_index_mapping = torch.empty(
|
191
209
|
size + size_swa + 1,
|
@@ -418,8 +436,9 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
|
418
436
|
dtype: torch.dtype,
|
419
437
|
device: str,
|
420
438
|
kvcache: KVCache,
|
439
|
+
need_sort: bool,
|
421
440
|
):
|
422
|
-
super().__init__(size, page_size, dtype, device, kvcache)
|
441
|
+
super().__init__(size, page_size, dtype, device, kvcache, need_sort)
|
423
442
|
self.num_pages = size // page_size
|
424
443
|
self.debug_mode = get_bool_env_var("SGLANG_DEBUG_MEMORY_POOL")
|
425
444
|
self.ret_values = torch.empty((), dtype=torch.int64, device=self.device)
|
@@ -433,7 +452,7 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
|
433
452
|
), "The allocation size should be page-aligned"
|
434
453
|
|
435
454
|
num_pages = need_size // self.page_size
|
436
|
-
if num_pages > len(self.free_pages):
|
455
|
+
if self.need_sort and num_pages > len(self.free_pages):
|
437
456
|
self.merge_and_sort_free()
|
438
457
|
if num_pages > len(self.free_pages):
|
439
458
|
return None
|
@@ -460,18 +479,12 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
|
460
479
|
(last_loc + 1) % self.page_size == prefix_lens % self.page_size
|
461
480
|
)
|
462
481
|
|
463
|
-
|
464
|
-
|
465
|
-
|
466
|
-
|
467
|
-
)
|
468
|
-
.sum()
|
469
|
-
.item()
|
470
|
-
)
|
471
|
-
if estimated_num_new_pages > len(self.free_pages):
|
482
|
+
bs = len(prefix_lens)
|
483
|
+
if self.need_sort and self.estimated_num_new_pages(bs, extend_num_tokens) > len(
|
484
|
+
self.free_pages
|
485
|
+
):
|
472
486
|
self.merge_and_sort_free()
|
473
487
|
|
474
|
-
bs = len(prefix_lens)
|
475
488
|
out_indices = torch.empty(
|
476
489
|
(extend_num_tokens,), dtype=torch.int64, device=self.device
|
477
490
|
)
|
@@ -508,18 +521,12 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
|
508
521
|
(last_loc + 2) % self.page_size == seq_lens % self.page_size
|
509
522
|
)
|
510
523
|
|
511
|
-
|
512
|
-
|
513
|
-
|
514
|
-
|
515
|
-
)
|
516
|
-
.sum()
|
517
|
-
.item()
|
518
|
-
)
|
519
|
-
if estimated_num_new_pages > len(self.free_pages):
|
524
|
+
bs = len(seq_lens)
|
525
|
+
if self.need_sort and self.estimated_num_new_pages(bs, 1) > len(
|
526
|
+
self.free_pages
|
527
|
+
):
|
520
528
|
self.merge_and_sort_free()
|
521
529
|
|
522
|
-
bs = len(seq_lens)
|
523
530
|
out_indices = torch.empty((bs,), dtype=torch.int64, device=self.device)
|
524
531
|
alloc_decode_kernel[(bs,)](
|
525
532
|
seq_lens,
|
@@ -547,7 +554,10 @@ class PagedTokenToKVPoolAllocator(BaseTokenToKVPoolAllocator):
|
|
547
554
|
|
548
555
|
if self.is_not_in_free_group:
|
549
556
|
free_page_indices = torch.unique(free_index // self.page_size)
|
550
|
-
|
557
|
+
if self.need_sort:
|
558
|
+
self.release_pages = torch.cat((free_page_indices, self.release_pages))
|
559
|
+
else:
|
560
|
+
self.free_pages = torch.cat((free_page_indices, self.free_pages))
|
551
561
|
else:
|
552
562
|
self.free_group.append(free_index)
|
553
563
|
|
@@ -622,27 +632,6 @@ def alloc_extend_kernel_ascend(
|
|
622
632
|
out_indices[end_pos[i] - num3 : end_pos[i]] = (
|
623
633
|
free_pages[end_new_pages[i] - 1] * page_size + pos_in_page[:num3]
|
624
634
|
).view(-1)
|
625
|
-
return num_new_pages
|
626
|
-
|
627
|
-
|
628
|
-
def alloc_decode_kernel_ascend(
|
629
|
-
seq_lens,
|
630
|
-
last_loc,
|
631
|
-
free_pages,
|
632
|
-
out_indices,
|
633
|
-
page_size,
|
634
|
-
):
|
635
|
-
num_new_pages = (seq_lens + page_size - 1) // page_size - (
|
636
|
-
seq_lens - 1 + page_size - 1
|
637
|
-
) // page_size
|
638
|
-
end_new_pages = torch.cumsum(num_new_pages, 0)
|
639
|
-
start_new_pages = end_new_pages - num_new_pages
|
640
|
-
for i in range(len(seq_lens)):
|
641
|
-
if num_new_pages[i]:
|
642
|
-
out_indices[i] = free_pages[start_new_pages[i]] * page_size
|
643
|
-
else:
|
644
|
-
out_indices[i] = last_loc[i] + 1
|
645
|
-
return num_new_pages
|
646
635
|
|
647
636
|
|
648
637
|
class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
|
@@ -654,9 +643,9 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
|
|
654
643
|
dtype: torch.dtype,
|
655
644
|
device: str,
|
656
645
|
kvcache: KVCache,
|
646
|
+
need_sort: bool,
|
657
647
|
):
|
658
|
-
super().__init__(size, page_size, dtype, device, kvcache)
|
659
|
-
self.ret_values = torch.empty((), dtype=torch.int32, device=self.device)
|
648
|
+
super().__init__(size, page_size, dtype, device, kvcache, need_sort)
|
660
649
|
|
661
650
|
def alloc_extend(
|
662
651
|
self,
|
@@ -678,15 +667,17 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
|
|
678
667
|
.sum()
|
679
668
|
.item()
|
680
669
|
)
|
681
|
-
if estimated_num_new_pages > len(self.free_pages):
|
670
|
+
if self.need_sort and estimated_num_new_pages > len(self.free_pages):
|
682
671
|
self.merge_and_sort_free()
|
683
672
|
|
684
|
-
|
673
|
+
if estimated_num_new_pages > len(self.free_pages):
|
674
|
+
return None
|
675
|
+
|
685
676
|
out_indices = torch.empty(
|
686
677
|
(extend_num_tokens,), dtype=torch.int32, device=self.device
|
687
678
|
)
|
688
679
|
|
689
|
-
|
680
|
+
alloc_extend_kernel_ascend(
|
690
681
|
prefix_lens,
|
691
682
|
seq_lens,
|
692
683
|
last_loc,
|
@@ -699,11 +690,7 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
|
|
699
690
|
if self.debug_mode:
|
700
691
|
assert len(torch.unique(out_indices)) == len(out_indices)
|
701
692
|
|
702
|
-
|
703
|
-
if num_new_pages > len(self.free_pages):
|
704
|
-
return None
|
705
|
-
|
706
|
-
self.free_pages = self.free_pages[num_new_pages:]
|
693
|
+
self.free_pages = self.free_pages[estimated_num_new_pages:]
|
707
694
|
return out_indices
|
708
695
|
|
709
696
|
def alloc_decode(
|
@@ -716,39 +703,26 @@ class AscendPagedTokenToKVPoolAllocator(PagedTokenToKVPoolAllocator):
|
|
716
703
|
(last_loc + 2) % self.page_size == seq_lens % self.page_size
|
717
704
|
)
|
718
705
|
|
719
|
-
|
720
|
-
|
721
|
-
|
722
|
-
|
723
|
-
)
|
724
|
-
.sum()
|
725
|
-
.item()
|
726
|
-
)
|
727
|
-
if estimated_num_new_pages > len(self.free_pages):
|
706
|
+
need_new_pages = (seq_lens % self.page_size == 1).int()
|
707
|
+
num_new_pages = need_new_pages.sum().item()
|
708
|
+
|
709
|
+
if num_new_pages > len(self.free_pages):
|
728
710
|
self.merge_and_sort_free()
|
729
711
|
|
730
|
-
|
731
|
-
|
712
|
+
if num_new_pages > len(self.free_pages):
|
713
|
+
return None
|
732
714
|
|
733
|
-
|
734
|
-
|
735
|
-
|
736
|
-
|
737
|
-
|
738
|
-
self.
|
739
|
-
|
715
|
+
end_new_pages = torch.cumsum(need_new_pages, 0)
|
716
|
+
start_new_pages = end_new_pages - need_new_pages
|
717
|
+
if num_new_pages == 0:
|
718
|
+
out_indices = last_loc + 1
|
719
|
+
else:
|
720
|
+
out_indices = (last_loc + 1) * (1 - need_new_pages) + self.free_pages[
|
721
|
+
start_new_pages
|
722
|
+
] * self.page_size * need_new_pages
|
740
723
|
|
741
724
|
if self.debug_mode:
|
742
725
|
assert len(torch.unique(out_indices)) == len(out_indices)
|
743
726
|
|
744
|
-
num_new_pages = self.ret_values.sum()
|
745
|
-
if num_new_pages > len(self.free_pages):
|
746
|
-
return None
|
747
|
-
|
748
727
|
self.free_pages = self.free_pages[num_new_pages:]
|
749
|
-
return out_indices
|
750
|
-
|
751
|
-
def clear(self):
|
752
|
-
super().clear()
|
753
|
-
self.free_pages = self.free_pages.to(torch.int32)
|
754
|
-
self.release_pages = self.release_pages.to(torch.int32)
|
728
|
+
return out_indices.int()
|
@@ -2,11 +2,12 @@ import heapq
|
|
2
2
|
import logging
|
3
3
|
import threading
|
4
4
|
import time
|
5
|
+
from queue import Queue
|
5
6
|
from typing import List, Optional
|
6
7
|
|
7
8
|
import torch
|
8
9
|
|
9
|
-
from sglang.srt.managers.cache_controller import HiCacheController
|
10
|
+
from sglang.srt.managers.cache_controller import HiCacheController, PrefetchOperation
|
10
11
|
from sglang.srt.mem_cache.allocator import BaseTokenToKVPoolAllocator
|
11
12
|
from sglang.srt.mem_cache.base_prefix_cache import MatchResult
|
12
13
|
from sglang.srt.mem_cache.memory_pool import (
|
@@ -37,6 +38,7 @@ class HiRadixCache(RadixCache):
|
|
37
38
|
hicache_io_backend: str,
|
38
39
|
hicache_mem_layout: str,
|
39
40
|
hicache_storage_backend: Optional[str] = None,
|
41
|
+
hicache_storage_prefetch_policy: Optional[str] = "best_effort",
|
40
42
|
):
|
41
43
|
|
42
44
|
if hicache_io_backend == "direct":
|
@@ -69,8 +71,10 @@ class HiRadixCache(RadixCache):
|
|
69
71
|
self.tp_group = tp_cache_group
|
70
72
|
self.tp_world_size = torch.distributed.get_world_size(group=self.tp_group)
|
71
73
|
self.enable_storage = hicache_storage_backend is not None
|
72
|
-
# todo: customizable storage prefetch threshold
|
74
|
+
# todo: customizable storage prefetch threshold and timeout
|
73
75
|
self.prefetch_threshold = 256
|
76
|
+
self.prefetch_timeout = 3 # seconds
|
77
|
+
self.prefetch_stop_policy = hicache_storage_prefetch_policy
|
74
78
|
|
75
79
|
self.load_cache_event = threading.Event()
|
76
80
|
self.cache_controller = HiCacheController(
|
@@ -142,7 +146,7 @@ class HiRadixCache(RadixCache):
|
|
142
146
|
|
143
147
|
def write_backup_storage(self, node: TreeNode):
|
144
148
|
operation_id = self.cache_controller.write_storage(
|
145
|
-
node.host_value, node.key, node.
|
149
|
+
node.host_value, node.key, node.hash_value
|
146
150
|
)
|
147
151
|
self.ongoing_backup[operation_id] = node
|
148
152
|
node.protect_host()
|
@@ -385,9 +389,10 @@ class HiRadixCache(RadixCache):
|
|
385
389
|
for _ in range(queue_size.item()):
|
386
390
|
req_id = self.cache_controller.prefetch_revoke_queue.get()
|
387
391
|
if req_id in self.ongoing_prefetch:
|
388
|
-
last_host_node,
|
392
|
+
last_host_node, token_ids, _, _ = self.ongoing_prefetch[req_id]
|
389
393
|
last_host_node.release_host()
|
390
394
|
del self.ongoing_prefetch[req_id]
|
395
|
+
self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
|
391
396
|
else:
|
392
397
|
# the revoked operation already got terminated
|
393
398
|
pass
|
@@ -404,25 +409,56 @@ class HiRadixCache(RadixCache):
|
|
404
409
|
group=self.tp_group,
|
405
410
|
)
|
406
411
|
for _ in range(queue_size.item()):
|
407
|
-
ack_id,
|
408
|
-
self.cache_controller.ack_backup_queue.get()
|
409
|
-
)
|
412
|
+
ack_id, completed_tokens = self.cache_controller.ack_backup_queue.get()
|
410
413
|
host_node = self.ongoing_backup[ack_id]
|
411
|
-
|
412
|
-
|
413
|
-
|
414
|
-
|
415
|
-
|
416
|
-
|
417
|
-
|
418
|
-
|
414
|
+
|
415
|
+
if completed_tokens > 0:
|
416
|
+
if completed_tokens < len(host_node.key):
|
417
|
+
# backup is only partially successful, split the node
|
418
|
+
new_node = self._split_node(
|
419
|
+
host_node.key, host_node, completed_tokens
|
420
|
+
)
|
421
|
+
new_node.backuped_storage = True
|
422
|
+
else:
|
423
|
+
host_node.backuped_storage = True
|
419
424
|
host_node.release_host()
|
420
425
|
del self.ongoing_backup[ack_id]
|
421
426
|
|
422
|
-
def
|
427
|
+
def can_terminate_prefetch(self, operation: PrefetchOperation):
|
428
|
+
can_terminate = True
|
429
|
+
|
430
|
+
if self.prefetch_stop_policy == "best_effort":
|
431
|
+
return can_terminate
|
432
|
+
|
433
|
+
completed = (
|
434
|
+
operation.completed_tokens == len(operation.hash_value) * self.page_size
|
435
|
+
)
|
436
|
+
|
437
|
+
if self.prefetch_stop_policy == "wait_complete":
|
438
|
+
can_terminate = completed
|
439
|
+
elif self.prefetch_stop_policy == "timeout":
|
440
|
+
can_terminate = completed or (
|
441
|
+
time.monotonic() - operation.start_time > self.prefetch_timeout
|
442
|
+
)
|
443
|
+
else:
|
444
|
+
# unknown prefetch stop policy, just return True
|
445
|
+
return True
|
446
|
+
|
447
|
+
if self.tp_world_size > 1:
|
448
|
+
can_terminate = torch.tensor(can_terminate, dtype=torch.int)
|
449
|
+
torch.distributed.all_reduce(
|
450
|
+
can_terminate,
|
451
|
+
op=torch.distributed.ReduceOp.MIN,
|
452
|
+
group=self.tp_group,
|
453
|
+
)
|
454
|
+
can_terminate = bool(can_terminate.item())
|
455
|
+
|
456
|
+
return can_terminate
|
457
|
+
|
458
|
+
def check_prefetch_progress(self, req_id: str) -> bool:
|
423
459
|
if req_id not in self.ongoing_prefetch:
|
424
460
|
# there is no ongoing prefetch for this request or it has been revoked
|
425
|
-
return
|
461
|
+
return True
|
426
462
|
|
427
463
|
# todo: more policies for prefetch progress such as timeout
|
428
464
|
# the current policy is to prefetch with best effort and terminate when queuing is over
|
@@ -430,13 +466,20 @@ class HiRadixCache(RadixCache):
|
|
430
466
|
req_id
|
431
467
|
]
|
432
468
|
|
469
|
+
if operation.host_indices is None:
|
470
|
+
# prefetch has not been issued due to insufficient host memory
|
471
|
+
return True
|
472
|
+
|
473
|
+
if not self.can_terminate_prefetch(operation):
|
474
|
+
return False
|
475
|
+
|
433
476
|
completed_tokens, hash_value = self.cache_controller.terminate_prefetch(
|
434
477
|
operation
|
435
478
|
)
|
436
479
|
logger.debug(f"Prefetch {req_id} completed with {completed_tokens} tokens")
|
437
480
|
|
438
481
|
min_completed_tokens = completed_tokens
|
439
|
-
if self.tp_world_size > 1:
|
482
|
+
if self.tp_world_size > 1 and self.prefetch_stop_policy != "wait_complete":
|
440
483
|
# synchrnoize TP workers to make the same update to hiradix cache
|
441
484
|
completed_tokens_tensor = torch.tensor(
|
442
485
|
min_completed_tokens, dtype=torch.int
|
@@ -464,6 +507,9 @@ class HiRadixCache(RadixCache):
|
|
464
507
|
)
|
465
508
|
last_host_node.release_host()
|
466
509
|
del self.ongoing_prefetch[req_id]
|
510
|
+
self.cache_controller.prefetch_tokens_occupied -= len(token_ids)
|
511
|
+
|
512
|
+
return True
|
467
513
|
|
468
514
|
def match_prefix(self, key: List[int], **kwargs):
|
469
515
|
empty_value = torch.empty((0,), dtype=torch.int64, device=self.device)
|
@@ -518,10 +564,6 @@ class HiRadixCache(RadixCache):
|
|
518
564
|
if host_indices is None:
|
519
565
|
self.evict_host(prefetch_length)
|
520
566
|
host_indices = self.cache_controller.mem_pool_host.alloc(prefetch_length)
|
521
|
-
if host_indices is None:
|
522
|
-
last_host_node.release_host()
|
523
|
-
# no sufficient host memory to prefetch
|
524
|
-
return
|
525
567
|
operation = self.cache_controller.prefetch(
|
526
568
|
req_id, host_indices, new_input_tokens, last_hash
|
527
569
|
)
|
@@ -531,6 +573,7 @@ class HiRadixCache(RadixCache):
|
|
531
573
|
host_indices,
|
532
574
|
operation,
|
533
575
|
)
|
576
|
+
self.cache_controller.prefetch_tokens_occupied += len(new_input_tokens)
|
534
577
|
|
535
578
|
def _insert_helper_host(self, node: TreeNode, key: List, host_value, hash_value):
|
536
579
|
node.last_access_time = time.monotonic()
|
@@ -669,6 +712,21 @@ class HiRadixCache(RadixCache):
|
|
669
712
|
node.children[child_key] = new_node
|
670
713
|
self.evictable_size_ += len(value)
|
671
714
|
|
715
|
+
if self.enable_storage:
|
716
|
+
last_hash = node.get_last_hash_value()
|
717
|
+
assert (node == self.root_node) or (
|
718
|
+
last_hash is not None
|
719
|
+
), "Parent node must have a hash value with storage enabled"
|
720
|
+
new_node.hash_value = []
|
721
|
+
for idx in range(0, len(key), self.page_size):
|
722
|
+
new_node.hash_value.append(
|
723
|
+
self.cache_controller.get_hash_str(
|
724
|
+
key[idx : idx + self.page_size],
|
725
|
+
prior_hash=last_hash,
|
726
|
+
)
|
727
|
+
)
|
728
|
+
last_hash = new_node.hash_value[-1]
|
729
|
+
|
672
730
|
if self.cache_controller.write_policy != "write_back":
|
673
731
|
self.inc_hit_count(new_node)
|
674
732
|
return total_prefix_length
|