sglang 0.4.4__py3-none-any.whl → 0.4.4.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +6 -0
- sglang/bench_one_batch.py +1 -1
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +3 -1
- sglang/check_env.py +3 -4
- sglang/lang/backend/openai.py +18 -5
- sglang/lang/chat_template.py +28 -7
- sglang/lang/interpreter.py +7 -3
- sglang/lang/ir.py +10 -0
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/code_completion_parser.py +174 -0
- sglang/srt/configs/__init__.py +2 -6
- sglang/srt/configs/deepseekvl2.py +667 -0
- sglang/srt/configs/janus_pro.py +3 -4
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +63 -11
- sglang/srt/configs/utils.py +25 -0
- sglang/srt/connector/__init__.py +51 -0
- sglang/srt/connector/base_connector.py +112 -0
- sglang/srt/connector/redis.py +85 -0
- sglang/srt/connector/s3.py +122 -0
- sglang/srt/connector/serde/__init__.py +31 -0
- sglang/srt/connector/serde/safe_serde.py +29 -0
- sglang/srt/connector/serde/serde.py +43 -0
- sglang/srt/connector/utils.py +35 -0
- sglang/srt/conversation.py +88 -0
- sglang/srt/disaggregation/conn.py +81 -0
- sglang/srt/disaggregation/decode.py +495 -0
- sglang/srt/disaggregation/mini_lb.py +285 -0
- sglang/srt/disaggregation/prefill.py +249 -0
- sglang/srt/disaggregation/utils.py +44 -0
- sglang/srt/distributed/parallel_state.py +10 -3
- sglang/srt/entrypoints/engine.py +55 -5
- sglang/srt/entrypoints/http_server.py +71 -12
- sglang/srt/function_call_parser.py +164 -54
- sglang/srt/hf_transformers_utils.py +28 -3
- sglang/srt/layers/activation.py +4 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +295 -0
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashmla_backend.py +284 -0
- sglang/srt/layers/attention/triton_backend.py +171 -38
- sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
- sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
- sglang/srt/layers/attention/utils.py +53 -0
- sglang/srt/layers/attention/vision.py +9 -28
- sglang/srt/layers/dp_attention.py +62 -23
- sglang/srt/layers/elementwise.py +411 -0
- sglang/srt/layers/layernorm.py +24 -2
- sglang/srt/layers/linear.py +17 -5
- sglang/srt/layers/logits_processor.py +26 -7
- sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
- sglang/srt/layers/moe/ep_moe/layer.py +273 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
- sglang/srt/layers/moe/router.py +342 -0
- sglang/srt/layers/moe/topk.py +31 -18
- sglang/srt/layers/parameter.py +1 -1
- sglang/srt/layers/quantization/__init__.py +184 -126
- sglang/srt/layers/quantization/base_config.py +5 -0
- sglang/srt/layers/quantization/blockwise_int8.py +1 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
- sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
- sglang/srt/layers/quantization/fp8.py +76 -34
- sglang/srt/layers/quantization/fp8_kernel.py +24 -8
- sglang/srt/layers/quantization/fp8_utils.py +284 -28
- sglang/srt/layers/quantization/gptq.py +36 -9
- sglang/srt/layers/quantization/kv_cache.py +98 -0
- sglang/srt/layers/quantization/modelopt_quant.py +9 -7
- sglang/srt/layers/quantization/utils.py +153 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
- sglang/srt/layers/rotary_embedding.py +66 -87
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/layers.py +68 -0
- sglang/srt/lora/lora.py +2 -22
- sglang/srt/lora/lora_manager.py +47 -23
- sglang/srt/lora/mem_pool.py +110 -51
- sglang/srt/lora/utils.py +12 -1
- sglang/srt/managers/cache_controller.py +4 -5
- sglang/srt/managers/data_parallel_controller.py +31 -9
- sglang/srt/managers/expert_distribution.py +81 -0
- sglang/srt/managers/io_struct.py +39 -3
- sglang/srt/managers/mm_utils.py +373 -0
- sglang/srt/managers/multimodal_processor.py +68 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
- sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
- sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
- sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
- sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
- sglang/srt/managers/schedule_batch.py +134 -31
- sglang/srt/managers/scheduler.py +325 -38
- sglang/srt/managers/scheduler_output_processor_mixin.py +4 -1
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +59 -23
- sglang/srt/managers/tp_worker.py +1 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
- sglang/srt/managers/utils.py +6 -1
- sglang/srt/mem_cache/hiradix_cache.py +27 -8
- sglang/srt/mem_cache/memory_pool.py +258 -98
- sglang/srt/mem_cache/paged_allocator.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +4 -4
- sglang/srt/model_executor/cuda_graph_runner.py +85 -28
- sglang/srt/model_executor/forward_batch_info.py +81 -15
- sglang/srt/model_executor/model_runner.py +70 -6
- sglang/srt/model_loader/loader.py +160 -2
- sglang/srt/model_loader/weight_utils.py +45 -0
- sglang/srt/models/deepseek_janus_pro.py +29 -86
- sglang/srt/models/deepseek_nextn.py +22 -10
- sglang/srt/models/deepseek_v2.py +326 -192
- sglang/srt/models/deepseek_vl2.py +358 -0
- sglang/srt/models/gemma3_causal.py +684 -0
- sglang/srt/models/gemma3_mm.py +462 -0
- sglang/srt/models/grok.py +374 -119
- sglang/srt/models/llama.py +47 -7
- sglang/srt/models/llama_eagle.py +1 -0
- sglang/srt/models/llama_eagle3.py +196 -0
- sglang/srt/models/llava.py +3 -3
- sglang/srt/models/llavavid.py +3 -3
- sglang/srt/models/minicpmo.py +1995 -0
- sglang/srt/models/minicpmv.py +62 -137
- sglang/srt/models/mllama.py +4 -4
- sglang/srt/models/phi3_small.py +1 -1
- sglang/srt/models/qwen2.py +3 -0
- sglang/srt/models/qwen2_5_vl.py +68 -146
- sglang/srt/models/qwen2_classification.py +75 -0
- sglang/srt/models/qwen2_moe.py +9 -1
- sglang/srt/models/qwen2_vl.py +25 -63
- sglang/srt/openai_api/adapter.py +145 -47
- sglang/srt/openai_api/protocol.py +23 -2
- sglang/srt/sampling/sampling_batch_info.py +1 -1
- sglang/srt/sampling/sampling_params.py +6 -6
- sglang/srt/server_args.py +104 -14
- sglang/srt/speculative/build_eagle_tree.py +7 -347
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
- sglang/srt/speculative/eagle_utils.py +208 -252
- sglang/srt/speculative/eagle_worker.py +139 -53
- sglang/srt/speculative/spec_info.py +6 -1
- sglang/srt/torch_memory_saver_adapter.py +22 -0
- sglang/srt/utils.py +182 -21
- sglang/test/__init__.py +0 -0
- sglang/test/attention/__init__.py +0 -0
- sglang/test/attention/test_flashattn_backend.py +312 -0
- sglang/test/runners.py +2 -0
- sglang/test/test_activation.py +2 -1
- sglang/test/test_block_fp8.py +5 -4
- sglang/test/test_block_fp8_ep.py +2 -1
- sglang/test/test_dynamic_grad_mode.py +58 -0
- sglang/test/test_layernorm.py +3 -2
- sglang/test/test_utils.py +55 -4
- sglang/utils.py +31 -0
- sglang/version.py +1 -1
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/METADATA +12 -8
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/RECORD +171 -125
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/WHEEL +1 -1
- sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
- sglang/srt/managers/image_processor.py +0 -55
- sglang/srt/managers/image_processors/base_image_processor.py +0 -219
- sglang/srt/managers/image_processors/minicpmv.py +0 -86
- sglang/srt/managers/multi_modality_padding.py +0 -134
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info/licenses}/LICENSE +0 -0
- {sglang-0.4.4.dist-info → sglang-0.4.4.post2.dist-info}/top_level.txt +0 -0
@@ -16,7 +16,6 @@
|
|
16
16
|
import asyncio
|
17
17
|
import copy
|
18
18
|
import dataclasses
|
19
|
-
import json
|
20
19
|
import logging
|
21
20
|
import os
|
22
21
|
import pickle
|
@@ -49,11 +48,9 @@ from fastapi import BackgroundTasks
|
|
49
48
|
|
50
49
|
from sglang.srt.aio_rwlock import RWLock
|
51
50
|
from sglang.srt.configs.model_config import ModelConfig
|
51
|
+
from sglang.srt.disaggregation.conn import KVBootstrapServer
|
52
|
+
from sglang.srt.disaggregation.utils import DisaggregationMode
|
52
53
|
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
53
|
-
from sglang.srt.managers.image_processor import (
|
54
|
-
get_dummy_image_processor,
|
55
|
-
get_image_processor,
|
56
|
-
)
|
57
54
|
from sglang.srt.managers.io_struct import (
|
58
55
|
AbortReq,
|
59
56
|
BatchEmbeddingOut,
|
@@ -63,6 +60,8 @@ from sglang.srt.managers.io_struct import (
|
|
63
60
|
CloseSessionReqInput,
|
64
61
|
ConfigureLoggingReq,
|
65
62
|
EmbeddingReqInput,
|
63
|
+
ExpertDistributionReq,
|
64
|
+
ExpertDistributionReqOutput,
|
66
65
|
FlushCacheReq,
|
67
66
|
GenerateReqInput,
|
68
67
|
GetInternalStateReq,
|
@@ -91,6 +90,11 @@ from sglang.srt.managers.io_struct import (
|
|
91
90
|
UpdateWeightsFromTensorReqInput,
|
92
91
|
UpdateWeightsFromTensorReqOutput,
|
93
92
|
)
|
93
|
+
from sglang.srt.managers.multimodal_processor import (
|
94
|
+
get_dummy_processor,
|
95
|
+
get_mm_processor,
|
96
|
+
import_processors,
|
97
|
+
)
|
94
98
|
from sglang.srt.metrics.collector import TokenizerMetricsCollector
|
95
99
|
from sglang.srt.sampling.sampling_params import SamplingParams
|
96
100
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
@@ -168,27 +172,33 @@ class TokenizerManager:
|
|
168
172
|
self.context_len = self.model_config.context_len
|
169
173
|
self.image_token_id = self.model_config.image_token_id
|
170
174
|
|
171
|
-
|
172
|
-
|
175
|
+
if self.model_config.is_multimodal:
|
176
|
+
import_processors()
|
177
|
+
_processor = get_processor(
|
178
|
+
server_args.tokenizer_path,
|
179
|
+
tokenizer_mode=server_args.tokenizer_mode,
|
180
|
+
trust_remote_code=server_args.trust_remote_code,
|
181
|
+
revision=server_args.revision,
|
182
|
+
)
|
173
183
|
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
184
|
+
# We want to parallelize the image pre-processing so we create an executor for it
|
185
|
+
# We create mm_processor for any skip_tokenizer_init to make sure we still encode
|
186
|
+
# images even with skip_tokenizer_init=False.
|
187
|
+
self.mm_processor = get_mm_processor(
|
188
|
+
self.model_config.hf_config, server_args, _processor
|
189
|
+
)
|
190
|
+
|
191
|
+
if server_args.skip_tokenizer_init:
|
192
|
+
self.tokenizer = self.processor = None
|
193
|
+
else:
|
194
|
+
self.processor = _processor
|
185
195
|
self.tokenizer = self.processor.tokenizer
|
186
196
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
197
|
+
else:
|
198
|
+
self.mm_processor = get_dummy_processor()
|
187
199
|
|
188
|
-
|
189
|
-
self.
|
190
|
-
self.model_config.hf_config, server_args, self.processor
|
191
|
-
)
|
200
|
+
if server_args.skip_tokenizer_init:
|
201
|
+
self.tokenizer = self.processor = None
|
192
202
|
else:
|
193
203
|
self.tokenizer = get_tokenizer(
|
194
204
|
server_args.tokenizer_path,
|
@@ -255,6 +265,9 @@ class TokenizerManager:
|
|
255
265
|
self.get_internal_state_communicator = _Communicator(
|
256
266
|
self.send_to_scheduler, server_args.dp_size
|
257
267
|
)
|
268
|
+
self.expert_distribution_communicator = _Communicator(
|
269
|
+
self.send_to_scheduler, server_args.dp_size
|
270
|
+
)
|
258
271
|
|
259
272
|
self._result_dispatcher = TypeBasedDispatcher(
|
260
273
|
[
|
@@ -304,10 +317,24 @@ class TokenizerManager:
|
|
304
317
|
GetInternalStateReqOutput,
|
305
318
|
self.get_internal_state_communicator.handle_recv,
|
306
319
|
),
|
320
|
+
(
|
321
|
+
ExpertDistributionReqOutput,
|
322
|
+
self.expert_distribution_communicator.handle_recv,
|
323
|
+
),
|
307
324
|
(HealthCheckOutput, lambda x: None),
|
308
325
|
]
|
309
326
|
)
|
310
327
|
|
328
|
+
self.disaggregation_mode = DisaggregationMode(
|
329
|
+
self.server_args.disaggregation_mode
|
330
|
+
)
|
331
|
+
# for disaggregtion, start kv boostrap server on prefill
|
332
|
+
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
333
|
+
# only start bootstrap server on prefill tm
|
334
|
+
self.bootstrap_server = KVBootstrapServer(
|
335
|
+
self.server_args.disaggregation_bootstrap_port
|
336
|
+
)
|
337
|
+
|
311
338
|
async def generate_request(
|
312
339
|
self,
|
313
340
|
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
@@ -372,7 +399,7 @@ class TokenizerManager:
|
|
372
399
|
)
|
373
400
|
input_ids = self.tokenizer.encode(input_text)
|
374
401
|
|
375
|
-
image_inputs: Dict = await self.
|
402
|
+
image_inputs: Dict = await self.mm_processor.process_mm_data_async(
|
376
403
|
obj.image_data, input_text or input_ids, obj, self.max_req_input_len
|
377
404
|
)
|
378
405
|
if image_inputs and "input_ids" in image_inputs:
|
@@ -620,6 +647,15 @@ class TokenizerManager:
|
|
620
647
|
req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
|
621
648
|
self.send_to_scheduler.send_pyobj(req)
|
622
649
|
|
650
|
+
async def start_expert_distribution_record(self):
|
651
|
+
await self.expert_distribution_communicator(ExpertDistributionReq.START_RECORD)
|
652
|
+
|
653
|
+
async def stop_expert_distribution_record(self):
|
654
|
+
await self.expert_distribution_communicator(ExpertDistributionReq.STOP_RECORD)
|
655
|
+
|
656
|
+
async def dump_expert_distribution_record(self):
|
657
|
+
await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)
|
658
|
+
|
623
659
|
async def update_weights_from_disk(
|
624
660
|
self,
|
625
661
|
obj: UpdateWeightFromDiskReqInput,
|
sglang/srt/managers/tp_worker.py
CHANGED
@@ -214,7 +214,7 @@ class TpModelWorker:
|
|
214
214
|
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
|
215
215
|
success, message = self.model_runner.update_weights_from_tensor(
|
216
216
|
named_tensors=MultiprocessingSerializer.deserialize(
|
217
|
-
recv_req.serialized_named_tensors
|
217
|
+
recv_req.serialized_named_tensors[self.tp_rank]
|
218
218
|
),
|
219
219
|
load_format=recv_req.load_format,
|
220
220
|
)
|
@@ -33,7 +33,7 @@ from sglang.srt.managers.io_struct import (
|
|
33
33
|
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
|
34
34
|
from sglang.srt.managers.tp_worker import TpModelWorker
|
35
35
|
from sglang.srt.server_args import ServerArgs
|
36
|
-
from sglang.srt.utils import get_compiler_backend
|
36
|
+
from sglang.srt.utils import DynamicGradMode, get_compiler_backend
|
37
37
|
from sglang.utils import get_exception_traceback
|
38
38
|
|
39
39
|
logger = logging.getLogger(__name__)
|
@@ -69,7 +69,7 @@ class TpModelWorkerClient:
|
|
69
69
|
self.future_token_ids_ct = 0
|
70
70
|
self.future_token_ids_limit = self.max_running_requests * 3
|
71
71
|
self.future_token_ids_map = torch.empty(
|
72
|
-
(self.max_running_requests * 5,), dtype=torch.
|
72
|
+
(self.max_running_requests * 5,), dtype=torch.int64, device=self.device
|
73
73
|
)
|
74
74
|
|
75
75
|
# Launch threads
|
@@ -115,7 +115,7 @@ class TpModelWorkerClient:
|
|
115
115
|
logger.error(f"TpModelWorkerClient hit an exception: {traceback}")
|
116
116
|
self.parent_process.send_signal(signal.SIGQUIT)
|
117
117
|
|
118
|
-
@
|
118
|
+
@DynamicGradMode()
|
119
119
|
def forward_thread_func_(self):
|
120
120
|
batch_pt = 0
|
121
121
|
batch_lists = [None] * 2
|
sglang/srt/managers/utils.py
CHANGED
@@ -1,6 +1,11 @@
|
|
1
|
+
import json
|
1
2
|
import logging
|
3
|
+
import time
|
4
|
+
from collections import defaultdict
|
2
5
|
from http import HTTPStatus
|
3
|
-
from typing import Optional
|
6
|
+
from typing import Dict, List, Optional, Tuple
|
7
|
+
|
8
|
+
import torch
|
4
9
|
|
5
10
|
from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
|
6
11
|
|
@@ -8,7 +8,10 @@ import torch
|
|
8
8
|
|
9
9
|
from sglang.srt.managers.cache_controller import HiCacheController
|
10
10
|
from sglang.srt.mem_cache.memory_pool import (
|
11
|
+
MHATokenToKVPool,
|
11
12
|
MHATokenToKVPoolHost,
|
13
|
+
MLATokenToKVPool,
|
14
|
+
MLATokenToKVPoolHost,
|
12
15
|
ReqToTokenPool,
|
13
16
|
TokenToKVPoolAllocator,
|
14
17
|
)
|
@@ -25,11 +28,27 @@ class HiRadixCache(RadixCache):
|
|
25
28
|
req_to_token_pool: ReqToTokenPool,
|
26
29
|
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
27
30
|
tp_cache_group: torch.distributed.ProcessGroup,
|
31
|
+
page_size: int,
|
32
|
+
hicache_ratio: float,
|
28
33
|
):
|
29
|
-
|
30
|
-
|
31
|
-
|
34
|
+
if page_size != 1:
|
35
|
+
raise ValueError(
|
36
|
+
"Page size larger than 1 is not yet supported in HiRadixCache."
|
37
|
+
)
|
38
|
+
self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
|
39
|
+
if isinstance(self.kv_cache, MHATokenToKVPool):
|
40
|
+
self.token_to_kv_pool_host = MHATokenToKVPoolHost(
|
41
|
+
self.kv_cache, hicache_ratio
|
42
|
+
)
|
43
|
+
elif isinstance(self.kv_cache, MLATokenToKVPool):
|
44
|
+
self.token_to_kv_pool_host = MLATokenToKVPoolHost(
|
45
|
+
self.kv_cache, hicache_ratio
|
46
|
+
)
|
47
|
+
else:
|
48
|
+
raise ValueError(f"Only MHA and MLA supports swap kv_cache to host.")
|
49
|
+
|
32
50
|
self.tp_group = tp_cache_group
|
51
|
+
self.page_size = page_size
|
33
52
|
|
34
53
|
self.load_cache_event = threading.Event()
|
35
54
|
self.cache_controller = HiCacheController(
|
@@ -45,7 +64,9 @@ class HiRadixCache(RadixCache):
|
|
45
64
|
# todo: dynamically adjust the threshold
|
46
65
|
self.write_through_threshold = 1
|
47
66
|
self.load_back_threshold = 10
|
48
|
-
super().__init__(
|
67
|
+
super().__init__(
|
68
|
+
req_to_token_pool, token_to_kv_pool_allocator, self.page_size, disable=False
|
69
|
+
)
|
49
70
|
|
50
71
|
def reset(self):
|
51
72
|
TreeNode.counter = 0
|
@@ -287,9 +308,9 @@ class HiRadixCache(RadixCache):
|
|
287
308
|
|
288
309
|
value, last_node = self._match_prefix_helper(self.root_node, key)
|
289
310
|
if value:
|
290
|
-
value = torch.
|
311
|
+
value = torch.cat(value)
|
291
312
|
else:
|
292
|
-
value = torch.tensor([], dtype=torch.
|
313
|
+
value = torch.tensor([], dtype=torch.int64)
|
293
314
|
|
294
315
|
last_node_global = last_node
|
295
316
|
while last_node.evicted:
|
@@ -309,13 +330,11 @@ class HiRadixCache(RadixCache):
|
|
309
330
|
prefix_len = _key_match(child.key, key)
|
310
331
|
if prefix_len < len(child.key):
|
311
332
|
new_node = self._split_node(child.key, child, prefix_len)
|
312
|
-
self.inc_hit_count(new_node)
|
313
333
|
if not new_node.evicted:
|
314
334
|
value.append(new_node.value)
|
315
335
|
node = new_node
|
316
336
|
break
|
317
337
|
else:
|
318
|
-
self.inc_hit_count(child)
|
319
338
|
if not child.evicted:
|
320
339
|
value.append(child.value)
|
321
340
|
node = child
|