sglang 0.4.4.post1__py3-none-any.whl → 0.4.4.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +6 -0
- sglang/bench_one_batch.py +1 -1
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +26 -4
- sglang/check_env.py +3 -4
- sglang/lang/backend/openai.py +18 -5
- sglang/lang/chat_template.py +28 -7
- sglang/lang/interpreter.py +7 -3
- sglang/lang/ir.py +10 -0
- sglang/srt/_custom_ops.py +1 -1
- sglang/srt/code_completion_parser.py +174 -0
- sglang/srt/configs/__init__.py +2 -6
- sglang/srt/configs/deepseekvl2.py +676 -0
- sglang/srt/configs/janus_pro.py +3 -4
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +49 -8
- sglang/srt/configs/utils.py +25 -0
- sglang/srt/connector/__init__.py +51 -0
- sglang/srt/connector/base_connector.py +112 -0
- sglang/srt/connector/redis.py +85 -0
- sglang/srt/connector/s3.py +122 -0
- sglang/srt/connector/serde/__init__.py +31 -0
- sglang/srt/connector/serde/safe_serde.py +29 -0
- sglang/srt/connector/serde/serde.py +43 -0
- sglang/srt/connector/utils.py +35 -0
- sglang/srt/conversation.py +88 -0
- sglang/srt/disaggregation/conn.py +81 -0
- sglang/srt/disaggregation/decode.py +495 -0
- sglang/srt/disaggregation/mini_lb.py +285 -0
- sglang/srt/disaggregation/prefill.py +249 -0
- sglang/srt/disaggregation/utils.py +44 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/parallel_state.py +42 -8
- sglang/srt/entrypoints/engine.py +55 -5
- sglang/srt/entrypoints/http_server.py +78 -13
- sglang/srt/entrypoints/verl_engine.py +2 -0
- sglang/srt/function_call_parser.py +133 -55
- sglang/srt/hf_transformers_utils.py +28 -3
- sglang/srt/layers/activation.py +4 -2
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/flashattention_backend.py +434 -0
- sglang/srt/layers/attention/flashinfer_backend.py +1 -1
- sglang/srt/layers/attention/flashmla_backend.py +284 -0
- sglang/srt/layers/attention/triton_backend.py +171 -38
- sglang/srt/layers/attention/triton_ops/decode_attention.py +94 -31
- sglang/srt/layers/attention/triton_ops/extend_attention.py +14 -5
- sglang/srt/layers/attention/utils.py +53 -0
- sglang/srt/layers/attention/vision.py +9 -28
- sglang/srt/layers/dp_attention.py +41 -19
- sglang/srt/layers/layernorm.py +24 -2
- sglang/srt/layers/linear.py +17 -5
- sglang/srt/layers/logits_processor.py +25 -7
- sglang/srt/layers/moe/ep_moe/kernels.py +110 -11
- sglang/srt/layers/moe/ep_moe/layer.py +273 -1
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +416 -0
- sglang/srt/layers/moe/fused_moe_native.py +2 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L20,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_L40S,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=1024,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=64,N=512,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +23 -32
- sglang/srt/layers/moe/fused_moe_triton/layer.py +1 -2
- sglang/srt/layers/moe/topk.py +60 -20
- sglang/srt/layers/parameter.py +1 -1
- sglang/srt/layers/quantization/__init__.py +80 -53
- sglang/srt/layers/quantization/awq.py +200 -0
- sglang/srt/layers/quantization/base_config.py +5 -0
- sglang/srt/layers/quantization/blockwise_int8.py +1 -1
- sglang/srt/layers/quantization/compressed_tensors/__init__.py +0 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +652 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +658 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +9 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_scheme.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +162 -0
- sglang/srt/layers/quantization/compressed_tensors/utils.py +218 -0
- sglang/srt/layers/quantization/fp8.py +76 -34
- sglang/srt/layers/quantization/fp8_kernel.py +25 -8
- sglang/srt/layers/quantization/fp8_utils.py +284 -28
- sglang/srt/layers/quantization/gptq.py +36 -19
- sglang/srt/layers/quantization/kv_cache.py +98 -0
- sglang/srt/layers/quantization/modelopt_quant.py +9 -7
- sglang/srt/layers/quantization/utils.py +153 -0
- sglang/srt/layers/quantization/w8a8_fp8.py +70 -19
- sglang/srt/layers/rotary_embedding.py +78 -87
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/backend/base_backend.py +4 -4
- sglang/srt/lora/backend/flashinfer_backend.py +12 -9
- sglang/srt/lora/backend/triton_backend.py +5 -8
- sglang/srt/lora/layers.py +87 -33
- sglang/srt/lora/lora.py +2 -22
- sglang/srt/lora/lora_manager.py +67 -30
- sglang/srt/lora/mem_pool.py +117 -52
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +10 -4
- sglang/srt/lora/triton_ops/qkv_lora_b.py +8 -3
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +16 -5
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +11 -6
- sglang/srt/lora/utils.py +18 -1
- sglang/srt/managers/cache_controller.py +2 -5
- sglang/srt/managers/data_parallel_controller.py +30 -8
- sglang/srt/managers/expert_distribution.py +81 -0
- sglang/srt/managers/io_struct.py +43 -5
- sglang/srt/managers/mm_utils.py +373 -0
- sglang/srt/managers/multimodal_processor.py +68 -0
- sglang/srt/managers/multimodal_processors/base_processor.py +275 -0
- sglang/srt/managers/multimodal_processors/clip.py +63 -0
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +119 -0
- sglang/srt/managers/multimodal_processors/gemma3.py +83 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/janus_pro.py +20 -15
- sglang/srt/managers/{image_processors → multimodal_processors}/llava.py +10 -15
- sglang/srt/managers/multimodal_processors/minicpm.py +167 -0
- sglang/srt/managers/{image_processors → multimodal_processors}/mlama.py +7 -8
- sglang/srt/managers/{image_processors → multimodal_processors}/qwen_vl.py +28 -22
- sglang/srt/managers/schedule_batch.py +134 -30
- sglang/srt/managers/scheduler.py +290 -31
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +59 -24
- sglang/srt/managers/tp_worker.py +4 -1
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -3
- sglang/srt/managers/utils.py +6 -1
- sglang/srt/mem_cache/hiradix_cache.py +18 -7
- sglang/srt/mem_cache/memory_pool.py +255 -98
- sglang/srt/mem_cache/paged_allocator.py +2 -2
- sglang/srt/mem_cache/radix_cache.py +4 -4
- sglang/srt/model_executor/cuda_graph_runner.py +36 -21
- sglang/srt/model_executor/forward_batch_info.py +68 -11
- sglang/srt/model_executor/model_runner.py +75 -8
- sglang/srt/model_loader/loader.py +171 -3
- sglang/srt/model_loader/weight_utils.py +51 -3
- sglang/srt/models/clip.py +563 -0
- sglang/srt/models/deepseek_janus_pro.py +31 -88
- sglang/srt/models/deepseek_nextn.py +22 -10
- sglang/srt/models/deepseek_v2.py +329 -73
- sglang/srt/models/deepseek_vl2.py +358 -0
- sglang/srt/models/gemma3_causal.py +694 -0
- sglang/srt/models/gemma3_mm.py +468 -0
- sglang/srt/models/llama.py +47 -7
- sglang/srt/models/llama_eagle.py +1 -0
- sglang/srt/models/llama_eagle3.py +196 -0
- sglang/srt/models/llava.py +3 -3
- sglang/srt/models/llavavid.py +3 -3
- sglang/srt/models/minicpmo.py +1995 -0
- sglang/srt/models/minicpmv.py +62 -137
- sglang/srt/models/mllama.py +4 -4
- sglang/srt/models/phi3_small.py +1 -1
- sglang/srt/models/qwen2.py +3 -0
- sglang/srt/models/qwen2_5_vl.py +68 -146
- sglang/srt/models/qwen2_classification.py +75 -0
- sglang/srt/models/qwen2_moe.py +9 -1
- sglang/srt/models/qwen2_vl.py +25 -63
- sglang/srt/openai_api/adapter.py +201 -104
- sglang/srt/openai_api/protocol.py +33 -7
- sglang/srt/patch_torch.py +71 -0
- sglang/srt/sampling/sampling_batch_info.py +1 -1
- sglang/srt/sampling/sampling_params.py +6 -6
- sglang/srt/server_args.py +114 -14
- sglang/srt/speculative/build_eagle_tree.py +7 -347
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +41 -5
- sglang/srt/speculative/eagle_utils.py +208 -252
- sglang/srt/speculative/eagle_worker.py +140 -54
- sglang/srt/speculative/spec_info.py +6 -1
- sglang/srt/torch_memory_saver_adapter.py +22 -0
- sglang/srt/utils.py +215 -21
- sglang/test/__init__.py +0 -0
- sglang/test/attention/__init__.py +0 -0
- sglang/test/attention/test_flashattn_backend.py +312 -0
- sglang/test/runners.py +29 -2
- sglang/test/test_activation.py +2 -1
- sglang/test/test_block_fp8.py +5 -4
- sglang/test/test_block_fp8_ep.py +2 -1
- sglang/test/test_dynamic_grad_mode.py +58 -0
- sglang/test/test_layernorm.py +3 -2
- sglang/test/test_utils.py +56 -5
- sglang/utils.py +31 -0
- sglang/version.py +1 -1
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/METADATA +16 -8
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/RECORD +180 -132
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/WHEEL +1 -1
- sglang/srt/configs/qwen2_5_vl_config.py +0 -1006
- sglang/srt/managers/image_processor.py +0 -55
- sglang/srt/managers/image_processors/base_image_processor.py +0 -219
- sglang/srt/managers/image_processors/minicpmv.py +0 -86
- sglang/srt/managers/multi_modality_padding.py +0 -134
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info/licenses}/LICENSE +0 -0
- {sglang-0.4.4.post1.dist-info → sglang-0.4.4.post3.dist-info}/top_level.txt +0 -0
@@ -16,7 +16,6 @@
|
|
16
16
|
import asyncio
|
17
17
|
import copy
|
18
18
|
import dataclasses
|
19
|
-
import json
|
20
19
|
import logging
|
21
20
|
import os
|
22
21
|
import pickle
|
@@ -49,11 +48,9 @@ from fastapi import BackgroundTasks
|
|
49
48
|
|
50
49
|
from sglang.srt.aio_rwlock import RWLock
|
51
50
|
from sglang.srt.configs.model_config import ModelConfig
|
51
|
+
from sglang.srt.disaggregation.conn import KVBootstrapServer
|
52
|
+
from sglang.srt.disaggregation.utils import DisaggregationMode
|
52
53
|
from sglang.srt.hf_transformers_utils import get_processor, get_tokenizer
|
53
|
-
from sglang.srt.managers.image_processor import (
|
54
|
-
get_dummy_image_processor,
|
55
|
-
get_image_processor,
|
56
|
-
)
|
57
54
|
from sglang.srt.managers.io_struct import (
|
58
55
|
AbortReq,
|
59
56
|
BatchEmbeddingOut,
|
@@ -63,6 +60,8 @@ from sglang.srt.managers.io_struct import (
|
|
63
60
|
CloseSessionReqInput,
|
64
61
|
ConfigureLoggingReq,
|
65
62
|
EmbeddingReqInput,
|
63
|
+
ExpertDistributionReq,
|
64
|
+
ExpertDistributionReqOutput,
|
66
65
|
FlushCacheReq,
|
67
66
|
GenerateReqInput,
|
68
67
|
GetInternalStateReq,
|
@@ -91,6 +90,11 @@ from sglang.srt.managers.io_struct import (
|
|
91
90
|
UpdateWeightsFromTensorReqInput,
|
92
91
|
UpdateWeightsFromTensorReqOutput,
|
93
92
|
)
|
93
|
+
from sglang.srt.managers.multimodal_processor import (
|
94
|
+
get_dummy_processor,
|
95
|
+
get_mm_processor,
|
96
|
+
import_processors,
|
97
|
+
)
|
94
98
|
from sglang.srt.metrics.collector import TokenizerMetricsCollector
|
95
99
|
from sglang.srt.sampling.sampling_params import SamplingParams
|
96
100
|
from sglang.srt.server_args import PortArgs, ServerArgs
|
@@ -168,27 +172,33 @@ class TokenizerManager:
|
|
168
172
|
self.context_len = self.model_config.context_len
|
169
173
|
self.image_token_id = self.model_config.image_token_id
|
170
174
|
|
171
|
-
|
172
|
-
|
175
|
+
if self.model_config.is_multimodal:
|
176
|
+
import_processors()
|
177
|
+
_processor = get_processor(
|
178
|
+
server_args.tokenizer_path,
|
179
|
+
tokenizer_mode=server_args.tokenizer_mode,
|
180
|
+
trust_remote_code=server_args.trust_remote_code,
|
181
|
+
revision=server_args.revision,
|
182
|
+
)
|
173
183
|
|
174
|
-
|
175
|
-
|
176
|
-
|
177
|
-
|
178
|
-
|
179
|
-
|
180
|
-
|
181
|
-
|
182
|
-
|
183
|
-
|
184
|
-
|
184
|
+
# We want to parallelize the image pre-processing so we create an executor for it
|
185
|
+
# We create mm_processor for any skip_tokenizer_init to make sure we still encode
|
186
|
+
# images even with skip_tokenizer_init=False.
|
187
|
+
self.mm_processor = get_mm_processor(
|
188
|
+
self.model_config.hf_config, server_args, _processor
|
189
|
+
)
|
190
|
+
|
191
|
+
if server_args.skip_tokenizer_init:
|
192
|
+
self.tokenizer = self.processor = None
|
193
|
+
else:
|
194
|
+
self.processor = _processor
|
185
195
|
self.tokenizer = self.processor.tokenizer
|
186
196
|
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
197
|
+
else:
|
198
|
+
self.mm_processor = get_dummy_processor()
|
187
199
|
|
188
|
-
|
189
|
-
self.
|
190
|
-
self.model_config.hf_config, server_args, self.processor
|
191
|
-
)
|
200
|
+
if server_args.skip_tokenizer_init:
|
201
|
+
self.tokenizer = self.processor = None
|
192
202
|
else:
|
193
203
|
self.tokenizer = get_tokenizer(
|
194
204
|
server_args.tokenizer_path,
|
@@ -251,10 +261,12 @@ class TokenizerManager:
|
|
251
261
|
self.start_profile_communicator = _Communicator(
|
252
262
|
self.send_to_scheduler, server_args.dp_size
|
253
263
|
)
|
254
|
-
self.health_check_communitcator = _Communicator(self.send_to_scheduler, 1)
|
255
264
|
self.get_internal_state_communicator = _Communicator(
|
256
265
|
self.send_to_scheduler, server_args.dp_size
|
257
266
|
)
|
267
|
+
self.expert_distribution_communicator = _Communicator(
|
268
|
+
self.send_to_scheduler, server_args.dp_size
|
269
|
+
)
|
258
270
|
|
259
271
|
self._result_dispatcher = TypeBasedDispatcher(
|
260
272
|
[
|
@@ -304,10 +316,24 @@ class TokenizerManager:
|
|
304
316
|
GetInternalStateReqOutput,
|
305
317
|
self.get_internal_state_communicator.handle_recv,
|
306
318
|
),
|
319
|
+
(
|
320
|
+
ExpertDistributionReqOutput,
|
321
|
+
self.expert_distribution_communicator.handle_recv,
|
322
|
+
),
|
307
323
|
(HealthCheckOutput, lambda x: None),
|
308
324
|
]
|
309
325
|
)
|
310
326
|
|
327
|
+
self.disaggregation_mode = DisaggregationMode(
|
328
|
+
self.server_args.disaggregation_mode
|
329
|
+
)
|
330
|
+
# for disaggregtion, start kv boostrap server on prefill
|
331
|
+
if self.disaggregation_mode == DisaggregationMode.PREFILL:
|
332
|
+
# only start bootstrap server on prefill tm
|
333
|
+
self.bootstrap_server = KVBootstrapServer(
|
334
|
+
self.server_args.disaggregation_bootstrap_port
|
335
|
+
)
|
336
|
+
|
311
337
|
async def generate_request(
|
312
338
|
self,
|
313
339
|
obj: Union[GenerateReqInput, EmbeddingReqInput],
|
@@ -372,7 +398,7 @@ class TokenizerManager:
|
|
372
398
|
)
|
373
399
|
input_ids = self.tokenizer.encode(input_text)
|
374
400
|
|
375
|
-
image_inputs: Dict = await self.
|
401
|
+
image_inputs: Dict = await self.mm_processor.process_mm_data_async(
|
376
402
|
obj.image_data, input_text or input_ids, obj, self.max_req_input_len
|
377
403
|
)
|
378
404
|
if image_inputs and "input_ids" in image_inputs:
|
@@ -620,6 +646,15 @@ class TokenizerManager:
|
|
620
646
|
req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
|
621
647
|
self.send_to_scheduler.send_pyobj(req)
|
622
648
|
|
649
|
+
async def start_expert_distribution_record(self):
|
650
|
+
await self.expert_distribution_communicator(ExpertDistributionReq.START_RECORD)
|
651
|
+
|
652
|
+
async def stop_expert_distribution_record(self):
|
653
|
+
await self.expert_distribution_communicator(ExpertDistributionReq.STOP_RECORD)
|
654
|
+
|
655
|
+
async def dump_expert_distribution_record(self):
|
656
|
+
await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)
|
657
|
+
|
623
658
|
async def update_weights_from_disk(
|
624
659
|
self,
|
625
660
|
obj: UpdateWeightFromDiskReqInput,
|
sglang/srt/managers/tp_worker.py
CHANGED
@@ -132,6 +132,9 @@ class TpModelWorker:
|
|
132
132
|
)[0]
|
133
133
|
set_random_seed(self.random_seed)
|
134
134
|
|
135
|
+
# A reference make this class has the same member as TpModelWorkerClient
|
136
|
+
self.worker = self
|
137
|
+
|
135
138
|
def get_worker_info(self):
|
136
139
|
return (
|
137
140
|
self.max_total_num_tokens,
|
@@ -214,7 +217,7 @@ class TpModelWorker:
|
|
214
217
|
def update_weights_from_tensor(self, recv_req: UpdateWeightsFromTensorReqInput):
|
215
218
|
success, message = self.model_runner.update_weights_from_tensor(
|
216
219
|
named_tensors=MultiprocessingSerializer.deserialize(
|
217
|
-
recv_req.serialized_named_tensors
|
220
|
+
recv_req.serialized_named_tensors[self.tp_rank]
|
218
221
|
),
|
219
222
|
load_format=recv_req.load_format,
|
220
223
|
)
|
@@ -33,7 +33,7 @@ from sglang.srt.managers.io_struct import (
|
|
33
33
|
from sglang.srt.managers.schedule_batch import ModelWorkerBatch
|
34
34
|
from sglang.srt.managers.tp_worker import TpModelWorker
|
35
35
|
from sglang.srt.server_args import ServerArgs
|
36
|
-
from sglang.srt.utils import get_compiler_backend
|
36
|
+
from sglang.srt.utils import DynamicGradMode, get_compiler_backend
|
37
37
|
from sglang.utils import get_exception_traceback
|
38
38
|
|
39
39
|
logger = logging.getLogger(__name__)
|
@@ -69,7 +69,7 @@ class TpModelWorkerClient:
|
|
69
69
|
self.future_token_ids_ct = 0
|
70
70
|
self.future_token_ids_limit = self.max_running_requests * 3
|
71
71
|
self.future_token_ids_map = torch.empty(
|
72
|
-
(self.max_running_requests * 5,), dtype=torch.
|
72
|
+
(self.max_running_requests * 5,), dtype=torch.int64, device=self.device
|
73
73
|
)
|
74
74
|
|
75
75
|
# Launch threads
|
@@ -115,7 +115,7 @@ class TpModelWorkerClient:
|
|
115
115
|
logger.error(f"TpModelWorkerClient hit an exception: {traceback}")
|
116
116
|
self.parent_process.send_signal(signal.SIGQUIT)
|
117
117
|
|
118
|
-
@
|
118
|
+
@DynamicGradMode()
|
119
119
|
def forward_thread_func_(self):
|
120
120
|
batch_pt = 0
|
121
121
|
batch_lists = [None] * 2
|
sglang/srt/managers/utils.py
CHANGED
@@ -1,6 +1,11 @@
|
|
1
|
+
import json
|
1
2
|
import logging
|
3
|
+
import time
|
4
|
+
from collections import defaultdict
|
2
5
|
from http import HTTPStatus
|
3
|
-
from typing import Optional
|
6
|
+
from typing import Dict, List, Optional, Tuple
|
7
|
+
|
8
|
+
import torch
|
4
9
|
|
5
10
|
from sglang.srt.managers.schedule_batch import FINISH_ABORT, Req
|
6
11
|
|
@@ -8,7 +8,10 @@ import torch
|
|
8
8
|
|
9
9
|
from sglang.srt.managers.cache_controller import HiCacheController
|
10
10
|
from sglang.srt.mem_cache.memory_pool import (
|
11
|
+
MHATokenToKVPool,
|
11
12
|
MHATokenToKVPoolHost,
|
13
|
+
MLATokenToKVPool,
|
14
|
+
MLATokenToKVPoolHost,
|
12
15
|
ReqToTokenPool,
|
13
16
|
TokenToKVPoolAllocator,
|
14
17
|
)
|
@@ -26,14 +29,24 @@ class HiRadixCache(RadixCache):
|
|
26
29
|
token_to_kv_pool_allocator: TokenToKVPoolAllocator,
|
27
30
|
tp_cache_group: torch.distributed.ProcessGroup,
|
28
31
|
page_size: int,
|
32
|
+
hicache_ratio: float,
|
29
33
|
):
|
30
34
|
if page_size != 1:
|
31
35
|
raise ValueError(
|
32
36
|
"Page size larger than 1 is not yet supported in HiRadixCache."
|
33
37
|
)
|
34
|
-
self.
|
35
|
-
|
36
|
-
|
38
|
+
self.kv_cache = token_to_kv_pool_allocator.get_kvcache()
|
39
|
+
if isinstance(self.kv_cache, MHATokenToKVPool):
|
40
|
+
self.token_to_kv_pool_host = MHATokenToKVPoolHost(
|
41
|
+
self.kv_cache, hicache_ratio
|
42
|
+
)
|
43
|
+
elif isinstance(self.kv_cache, MLATokenToKVPool):
|
44
|
+
self.token_to_kv_pool_host = MLATokenToKVPoolHost(
|
45
|
+
self.kv_cache, hicache_ratio
|
46
|
+
)
|
47
|
+
else:
|
48
|
+
raise ValueError(f"Only MHA and MLA supports swap kv_cache to host.")
|
49
|
+
|
37
50
|
self.tp_group = tp_cache_group
|
38
51
|
self.page_size = page_size
|
39
52
|
|
@@ -295,9 +308,9 @@ class HiRadixCache(RadixCache):
|
|
295
308
|
|
296
309
|
value, last_node = self._match_prefix_helper(self.root_node, key)
|
297
310
|
if value:
|
298
|
-
value = torch.
|
311
|
+
value = torch.cat(value)
|
299
312
|
else:
|
300
|
-
value = torch.tensor([], dtype=torch.
|
313
|
+
value = torch.tensor([], dtype=torch.int64)
|
301
314
|
|
302
315
|
last_node_global = last_node
|
303
316
|
while last_node.evicted:
|
@@ -317,13 +330,11 @@ class HiRadixCache(RadixCache):
|
|
317
330
|
prefix_len = _key_match(child.key, key)
|
318
331
|
if prefix_len < len(child.key):
|
319
332
|
new_node = self._split_node(child.key, child, prefix_len)
|
320
|
-
self.inc_hit_count(new_node)
|
321
333
|
if not new_node.evicted:
|
322
334
|
value.append(new_node.value)
|
323
335
|
node = new_node
|
324
336
|
break
|
325
337
|
else:
|
326
|
-
self.inc_hit_count(child)
|
327
338
|
if not child.evicted:
|
328
339
|
value.append(child.value)
|
329
340
|
node = child
|