sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +257 -29
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +50 -6
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +48 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/xgrammar_backend.py +28 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +21 -10
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +5 -3
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +24 -3
- sglang/srt/entrypoints/engine.py +38 -17
- sglang/srt/entrypoints/grpc_request_manager.py +580 -0
- sglang/srt/entrypoints/grpc_server.py +680 -0
- sglang/srt/entrypoints/http_server.py +85 -54
- sglang/srt/entrypoints/openai/protocol.py +4 -1
- sglang/srt/entrypoints/openai/serving_base.py +46 -3
- sglang/srt/entrypoints/openai/serving_chat.py +36 -16
- sglang/srt/entrypoints/openai/serving_completions.py +12 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +6 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +6 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/activation.py +142 -9
- sglang/srt/layers/attention/ascend_backend.py +11 -4
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +18 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/dp_attention.py +30 -1
- sglang/srt/layers/layernorm.py +32 -15
- sglang/srt/layers/linear.py +34 -3
- sglang/srt/layers/logits_processor.py +29 -10
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +182 -62
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +12 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +50 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +147 -47
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +64 -40
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +30 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +76 -38
- sglang/srt/layers/sampler.py +162 -18
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +158 -160
- sglang/srt/managers/data_parallel_controller.py +105 -35
- sglang/srt/managers/detokenizer_manager.py +8 -4
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +199 -12
- sglang/srt/managers/mm_utils.py +1 -0
- sglang/srt/managers/multi_tokenizer_mixin.py +350 -400
- sglang/srt/managers/schedule_batch.py +77 -56
- sglang/srt/managers/schedule_policy.py +1 -1
- sglang/srt/managers/scheduler.py +187 -39
- sglang/srt/managers/scheduler_metrics_mixin.py +4 -3
- sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
- sglang/srt/managers/tokenizer_manager.py +259 -519
- sglang/srt/managers/tp_worker.py +53 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
- sglang/srt/mem_cache/hicache_storage.py +3 -23
- sglang/srt/mem_cache/hiradix_cache.py +103 -43
- sglang/srt/mem_cache/memory_pool.py +347 -48
- sglang/srt/mem_cache/memory_pool_host.py +105 -46
- sglang/srt/mem_cache/radix_cache.py +0 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +86 -4
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +49 -7
- sglang/srt/mem_cache/swa_radix_cache.py +0 -2
- sglang/srt/metrics/collector.py +493 -76
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +59 -2
- sglang/srt/model_executor/model_runner.py +356 -29
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +128 -4
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +798 -218
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_v2.py +109 -15
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +1 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/glm4v_moe.py +3 -0
- sglang/srt/models/gpt_oss.py +1 -1
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +7 -0
- sglang/srt/models/qwen2_5_vl.py +27 -3
- sglang/srt/models/qwen2_moe.py +56 -12
- sglang/srt/models/qwen3_moe.py +1 -1
- sglang/srt/models/qwen3_next.py +1042 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/multimodal/processors/dots_vlm.py +99 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/multimodal/processors/qwen_vl.py +15 -5
- sglang/srt/offloader.py +27 -3
- sglang/srt/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +276 -35
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_utils.py +0 -2
- sglang/srt/speculative/eagle_worker.py +43 -4
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tracing/trace.py +552 -0
- sglang/srt/utils.py +34 -3
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_utils.py +28 -1
- sglang/utils.py +11 -0
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +237 -178
- sglang/srt/disaggregation/launch_lb.py +0 -118
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
sglang/srt/managers/cache_controller.py
@@ -18,7 +18,7 @@ import math
 import threading
 import time
 from queue import Empty, Full, PriorityQueue, Queue
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, List, NamedTuple, Optional, Set, Tuple
 
 import torch
 
@@ -33,6 +33,7 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.dp_attention import (
+    get_attention_dp_rank,
     get_attention_tp_rank,
     get_attention_tp_size,
     is_dp_attention_enabled,
@@ -42,39 +43,53 @@ from sglang.srt.mem_cache.memory_pool import MHATokenToKVPool, MLATokenToKVPool
 logger = logging.getLogger(__name__)
 
 
+class LayerLoadingEvent:
+    def __init__(self, num_layers: int):
+        self._num_layers = num_layers
+        self.load_events = [torch.cuda.Event() for _ in range(num_layers)]
+        self.start_event = torch.cuda.Event()  # start event on controller stream
+
+    def complete(self, layer_index: int):
+        assert 0 <= layer_index < self._num_layers
+        self.load_events[layer_index].record()
+
+    def wait(self, layer_index: int):
+        torch.cuda.current_stream().wait_event(self.load_events[layer_index])
+
+    @property
+    def finish_event(self):
+        return self.load_events[-1]
+
+
 class LayerDoneCounter:
-    def __init__(self, num_layers):
+    def __init__(self, num_layers: int):
         self.num_layers = num_layers
         # extra producer and consumer counters for overlap mode
         self.num_counters = 3
-        self.
-        self.
-        self.
-        self.consumer_index = 0
-
-    def next_producer(self):
-        return (self.producer_index + 1) % self.num_counters
+        self.events = [LayerLoadingEvent(num_layers) for _ in range(self.num_counters)]
+        self.producer_index = -1
+        self.consumer_index = -1
 
     def update_producer(self):
-        self.producer_index = self.
+        self.producer_index = (self.producer_index + 1) % self.num_counters
+        assert self.events[
+            self.producer_index
+        ].finish_event.query(), (
+            "Producer finish event should be ready before being reused."
+        )
         return self.producer_index
 
-    def set_consumer(self, index):
+    def set_consumer(self, index: int):
         self.consumer_index = index
 
-    def
-
-
-
-
-    def wait_until(self, threshold):
-        with self.conditions[self.consumer_index]:
-            while self.counters[self.consumer_index] <= threshold:
-                self.conditions[self.consumer_index].wait()
+    def wait_until(self, threshold: int):
+        if self.consumer_index < 0:
+            return
+        self.events[self.consumer_index].wait(threshold)
 
     def reset(self):
-
-
+        self.producer_index = -1
+        self.consumer_index = -1
 
 
 class CacheOperation:
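The old LayerDoneCounter synchronized producer and consumer with Python-side counters and condition variables; the new LayerLoadingEvent records one torch.cuda.Event per layer on the load stream and lets the compute stream wait on exactly the layer it needs. A minimal standalone sketch of that per-layer handoff follows (class and variable names are illustrative, not the sglang ones):

```python
import torch


class LayerReadyEvents:
    """Per-layer CUDA events: one stream records, another waits (illustrative sketch)."""

    def __init__(self, num_layers: int):
        self.events = [torch.cuda.Event() for _ in range(num_layers)]

    def complete(self, layer: int):
        # Called on the producer (load) stream after layer `layer` has been copied.
        self.events[layer].record()

    def wait(self, layer: int):
        # Called on the consumer (compute) stream before it touches layer `layer`.
        torch.cuda.current_stream().wait_event(self.events[layer])


if torch.cuda.is_available():
    num_layers = 4
    events = LayerReadyEvents(num_layers)
    load_stream = torch.cuda.Stream()
    src = [torch.randn(1024, 1024, pin_memory=True) for _ in range(num_layers)]
    dst = [torch.empty(1024, 1024, device="cuda") for _ in range(num_layers)]

    with torch.cuda.stream(load_stream):
        for i in range(num_layers):
            dst[i].copy_(src[i], non_blocking=True)  # async H2D copy of one layer
            events.complete(i)

    for i in range(num_layers):
        events.wait(i)    # compute stream stalls only until layer i has landed
        dst[i].mul_(2.0)  # stand-in for that layer's attention computation
    torch.cuda.synchronize()
```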
@@ -98,36 +113,30 @@ class CacheOperation:
         # default priority is the order of creation
         self.priority = priority if priority is not None else self.id
 
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-                    device_indices=self.device_indices[i : i + chunk_size],
-                    node_id=0,
-                )
-            )
-        # Inherit the node_ids on the final chunk
-        if split_ops:
-            split_ops[-1].node_ids = self.node_ids
+    @staticmethod
+    def merge_ops(ops: List[CacheOperation]) -> CacheOperation:
+        assert len(ops) > 0
+        if len(ops) == 1:
+            return ops[0]
+
+        host_indices = torch.cat([op.host_indices for op in ops])
+        device_indices = torch.cat([op.device_indices for op in ops])
+        node_ids = []
+        priority = min(op.priority for op in ops)
+        for op in ops:
+            node_ids.extend(op.node_ids)
+        merged_op = CacheOperation(host_indices, device_indices, -1, priority)
+        merged_op.node_ids = node_ids
+        return merged_op
+
+    def __lt__(self, other: CacheOperation):
+        return self.priority < other.priority
 
-        return split_ops
 
-
-
+class HiCacheAck(NamedTuple):
+    start_event: torch.cuda.Event
+    finish_event: torch.cuda.Event
+    node_ids: List[int]
 
 
 class TransferBuffer:
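CacheOperation.merge_ops coalesces every queued operation into a single transfer by concatenating the index tensors and collecting all node ids, so one launch covers the whole batch. A self-contained sketch of the same coalescing idea, using a simplified stand-in class rather than the real CacheOperation:

```python
from dataclasses import dataclass, field
from typing import List

import torch


@dataclass
class PendingCopy:
    """Simplified stand-in for a queued KV-cache copy (not the real CacheOperation)."""
    host_indices: torch.Tensor
    device_indices: torch.Tensor
    node_ids: List[int] = field(default_factory=list)
    priority: int = 0


def merge_pending(ops: List[PendingCopy]) -> PendingCopy:
    """Coalesce many small copies into one large one, keeping every node id."""
    assert ops
    if len(ops) == 1:
        return ops[0]
    return PendingCopy(
        host_indices=torch.cat([op.host_indices for op in ops]),
        device_indices=torch.cat([op.device_indices for op in ops]),
        node_ids=[nid for op in ops for nid in op.node_ids],
        priority=min(op.priority for op in ops),
    )


queue = [
    PendingCopy(torch.arange(0, 4), torch.arange(100, 104), node_ids=[7]),
    PendingCopy(torch.arange(4, 8), torch.arange(104, 108), node_ids=[8], priority=-1),
]
merged = merge_pending(queue)
print(merged.host_indices.tolist(), merged.node_ids, merged.priority)
# [0, 1, 2, 3, 4, 5, 6, 7] [7, 8] -1
```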
@@ -206,26 +215,25 @@ class PrefetchOperation(StorageOperation):
     ):
         self.request_id = request_id
 
-        self._done_flag = False
         self._lock = threading.Lock()
-
+        self._terminated_flag = False
         self.start_time = time.monotonic()
 
         super().__init__(host_indices, token_ids, last_hash)
 
     def increment(self, num_tokens: int):
         with self._lock:
-            if self.
+            if self._terminated_flag:
                 return False
             self.completed_tokens += num_tokens
             return True
 
-    def
+    def mark_terminate(self):
         with self._lock:
-            self.
+            self._terminated_flag = True
 
-    def
-        return self.
+    def is_terminated(self) -> bool:
+        return self._terminated_flag
 
 
 class HiCacheController:
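PrefetchOperation now keeps an explicit terminated flag under its lock: once the controller calls mark_terminate, further increment calls are rejected, so completed_tokens can no longer advance after the controller has given up on the prefetch. A minimal sketch of this flag-under-lock pattern (illustrative names, not the sglang class):

```python
import threading


class CancellableProgress:
    """Progress counter that stops accepting updates once terminated (illustrative)."""

    def __init__(self):
        self.completed_tokens = 0
        self._lock = threading.Lock()
        self._terminated = False

    def increment(self, num_tokens: int) -> bool:
        with self._lock:
            if self._terminated:
                return False  # the controller already gave up on this prefetch
            self.completed_tokens += num_tokens
            return True

    def mark_terminate(self):
        with self._lock:
            self._terminated = True


progress = CancellableProgress()
assert progress.increment(16)
progress.mark_terminate()
assert not progress.increment(16)
print(progress.completed_tokens)  # 16: the second batch was rejected
```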
@@ -236,7 +244,7 @@ class HiCacheController:
         mem_pool_host: HostKVCache,
         page_size: int,
         tp_group: torch.distributed.ProcessGroup,
-        load_cache_event: threading.Event
+        load_cache_event: threading.Event,
         write_policy: str = "write_through_selective",
         io_backend: str = "",
         storage_backend: Optional[str] = None,
@@ -340,8 +348,9 @@ class HiCacheController:
             self.page_set_func = self._3fs_zero_copy_page_set
             self.batch_exists_func = self._3fs_zero_copy_batch_exists
 
-        self.
-        self.
+        self.device = self.mem_pool_device.device
+        self.layer_num = self.mem_pool_device.layer_num
+        self.layer_done_counter = LayerDoneCounter(self.layer_num)
         self.mem_pool_device.register_layer_transfer_counter(self.layer_done_counter)
 
         if write_policy not in [
@@ -351,11 +360,11 @@ class HiCacheController:
         ]:
             raise ValueError(f"Invalid write policy: {write_policy}")
 
-        self.write_queue = PriorityQueue()
-        self.load_queue =
-
-        self.
-        self.
+        # self.write_queue = PriorityQueue[CacheOperation]()
+        self.load_queue: List[CacheOperation] = []
+        self.write_queue: List[CacheOperation] = []
+        self.ack_load_queue: List[HiCacheAck] = []
+        self.ack_write_queue: List[HiCacheAck] = []
 
         self.stop_event = threading.Event()
         self.write_buffer = TransferBuffer(self.stop_event)
@@ -366,16 +375,6 @@ class HiCacheController:
         self.write_stream = torch.cuda.Stream()
         self.load_stream = torch.cuda.Stream()
 
-        self.write_thread = threading.Thread(
-            target=self.write_thread_func_direct, daemon=True
-        )
-        self.load_thread = threading.Thread(
-            target=self.load_thread_func_layer_by_layer, daemon=True
-        )
-
-        self.write_thread.start()
-        self.load_thread.start()
-
         if self.enable_storage:
             self.prefetch_thread = threading.Thread(
                 target=self.prefetch_thread_func, daemon=True
@@ -402,9 +401,11 @@ class HiCacheController:
         if is_dp_attention_enabled():
             self.tp_rank = get_attention_tp_rank()
             self.tp_size = get_attention_tp_size()
+            self.dp_rank = get_attention_dp_rank()
         else:
             self.tp_rank = get_tensor_model_parallel_rank()
             self.tp_size = get_tensor_model_parallel_world_size()
+            self.dp_rank = 0
 
         # Currently, AscendMLAPagedTokenToKVPool is the subclass of MLATokenToKVPool.
         is_mla_backend = isinstance(self.mem_pool_device, MLATokenToKVPool)
@@ -430,15 +431,13 @@ class HiCacheController:
 
     def reset(self):
         self.stop_event.set()
-        self.write_thread.join()
-        self.load_thread.join()
 
-        self.write_queue.
-        self.load_queue.
+        self.write_queue.clear()
+        self.load_queue.clear()
         self.write_buffer.clear()
         self.load_buffer.clear()
-        self.ack_write_queue.
-        self.ack_load_queue.
+        self.ack_write_queue.clear()
+        self.ack_load_queue.clear()
         if self.enable_storage:
             self.prefetch_thread.join()
             self.backup_thread.join()
@@ -447,15 +446,7 @@ class HiCacheController:
             self.prefetch_revoke_queue.queue.clear()
             self.ack_backup_queue.queue.clear()
 
-        self.write_thread = threading.Thread(
-            target=self.write_thread_func_direct, daemon=True
-        )
-        self.load_thread = threading.Thread(
-            target=self.load_thread_func_layer_by_layer, daemon=True
-        )
         self.stop_event.clear()
-        self.write_thread.start()
-        self.load_thread.start()
 
         if self.enable_storage:
             self.prefetch_thread = threading.Thread(
@@ -471,7 +462,7 @@ class HiCacheController:
         self,
         device_indices: torch.Tensor,
         priority: Optional[int] = None,
-        node_id: int =
+        node_id: int = -1,
     ) -> Optional[torch.Tensor]:
         """
         Back up KV caches from device memory to host memory.
@@ -480,17 +471,46 @@ class HiCacheController:
         if host_indices is None:
             return None
         self.mem_pool_host.protect_write(host_indices)
-
-        self.write_queue.put(
+        self.write_queue.append(
             CacheOperation(host_indices, device_indices, node_id, priority)
         )
+        self.start_writing()
         return host_indices
 
+    def start_writing(self) -> None:
+        if len(self.write_queue) == 0:
+            return
+
+        op = CacheOperation.merge_ops(self.write_queue)
+        host_indices, device_indices = self.move_indices(op)
+        self.write_queue.clear()
+
+        start_event = torch.cuda.Event()
+        finish_event = torch.cuda.Event()
+
+        start_event.record()
+        with torch.cuda.stream(self.write_stream):
+            start_event.wait(self.write_stream)
+            self.mem_pool_host.backup_from_device_all_layer(
+                self.mem_pool_device, host_indices, device_indices, self.io_backend
+            )
+            self.mem_pool_host.complete_io(op.host_indices)
+            finish_event.record()
+        # NOTE: We must save the host indices and device indices here,
+        # this is because we need to guarantee that these tensors are
+        # still alive when the write stream is executing.
+        if host_indices.is_cuda:
+            host_indices.record_stream(self.write_stream)
+        if device_indices.is_cuda:
+            device_indices.record_stream(self.write_stream)
+
+        self.ack_write_queue.append(HiCacheAck(start_event, finish_event, op.node_ids))
+
     def load(
         self,
         host_indices: torch.Tensor,
         priority: Optional[int] = None,
-        node_id: int =
+        node_id: int = -1,
     ) -> Optional[torch.Tensor]:
         """
         Load KV caches from host memory to device memory.
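start_writing replaces the old background write thread: the device-to-host backup is issued on the dedicated write stream between a start and a finish event, and record_stream keeps the index tensors alive until that stream has consumed them. A standalone sketch of the same pattern with dummy buffers in place of the KV pools (guarded so it only runs where CUDA is available):

```python
import torch

if torch.cuda.is_available():
    write_stream = torch.cuda.Stream()

    device_buf = torch.randn(4096, 4096, device="cuda")
    host_buf = torch.empty(4096, 4096, pin_memory=True)

    start_event = torch.cuda.Event()
    finish_event = torch.cuda.Event()

    start_event.record()                    # ordering point on the current stream
    with torch.cuda.stream(write_stream):
        start_event.wait(write_stream)      # do not start before prior work is issued
        host_buf.copy_(device_buf, non_blocking=True)  # async D2H backup
        finish_event.record()

    # Keep the device tensor alive for the side stream even if this reference is
    # dropped before the copy runs (same motivation as record_stream in the diff).
    device_buf.record_stream(write_stream)

    # Later, the controller polls the finish event instead of joining a thread.
    finish_event.synchronize()
    assert torch.allclose(host_buf, device_buf.cpu())
```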
@@ -499,76 +519,42 @@ class HiCacheController:
         if device_indices is None:
             return None
         self.mem_pool_host.protect_load(host_indices)
-
-        torch.cuda.current_stream().synchronize()
-        self.load_queue.put(
+        self.load_queue.append(
             CacheOperation(host_indices, device_indices, node_id, priority)
         )
         return device_indices
 
-    def move_indices(self,
+    def move_indices(self, op: CacheOperation):
+        host_indices, device_indices = op.host_indices, op.device_indices
         # move indices to GPU if using kernels, to host if using direct indexing
         if self.io_backend == "kernel":
-
+            if not host_indices.is_cuda:
+                host_indices = host_indices.to(self.device, non_blocking=True)
+            return host_indices, device_indices
         elif self.io_backend == "direct":
-
-
-
+            if self.mem_pool_host.layout == "layer_first":
+                device_indices = device_indices.cpu()
+                host_indices, idx = host_indices.sort()
+                return host_indices, device_indices.index_select(0, idx)
+            elif self.mem_pool_host.layout == "page_first_direct":
+                return host_indices, device_indices.cpu()
         else:
             raise ValueError(f"Unsupported io backend")
 
-    def
-
-
-        """
-        torch.cuda.set_stream(self.write_stream)
-        while not self.stop_event.is_set():
-            try:
-                operation = self.write_queue.get(block=True, timeout=1)
-                host_indices, device_indices = self.move_indices(
-                    operation.host_indices, operation.device_indices
-                )
-                self.mem_pool_host.backup_from_device_all_layer(
-                    self.mem_pool_device, host_indices, device_indices, self.io_backend
-                )
-                self.write_stream.synchronize()
-                self.mem_pool_host.complete_io(operation.host_indices)
-                for node_id in operation.node_ids:
-                    if node_id != 0:
-                        self.ack_write_queue.put(node_id)
-            except Empty:
-                continue
-            except Exception as e:
-                logger.error(e)
+    def start_loading(self) -> int:
+        if len(self.load_queue) == 0:
+            return -1
 
-
-
-
-
-
-
-            self.load_cache_event.wait(timeout=1)
-            if not self.load_cache_event.is_set():
-                continue
-            self.load_cache_event.clear()
-            self.layer_done_counter.update_producer()
-
-            batch_operation = None
-            while self.load_queue.qsize() > 0:
-                op = self.load_queue.get(block=True)
-                if batch_operation is None:
-                    batch_operation = op
-                else:
-                    batch_operation.merge(op)
-            if batch_operation is None:
-                continue
+        producer_id = self.layer_done_counter.update_producer()
+        op = CacheOperation.merge_ops(self.load_queue)
+        host_indices, device_indices = self.move_indices(op)
+        self.load_queue.clear()
+        producer_event = self.layer_done_counter.events[producer_id]
+        producer_event.start_event.record()
 
-
-
-
-                batch_operation.host_indices, batch_operation.device_indices
-            )
-            for i in range(self.mem_pool_host.layer_num):
+        with torch.cuda.stream(self.load_stream):
+            producer_event.start_event.wait(self.load_stream)
+            for i in range(self.layer_num):
                 self.mem_pool_host.load_to_device_per_layer(
                     self.mem_pool_device,
                     host_indices,
@@ -576,13 +562,24 @@ class HiCacheController:
                     i,
                     self.io_backend,
                 )
-
-
-
-
-
-
-
+                producer_event.complete(i)
+            self.mem_pool_host.complete_io(op.host_indices)
+        # NOTE: We must save the host indices and device indices here,
+        # this is because we need to guarantee that these tensors are
+        # still alive when the load stream is executing.
+        if host_indices.is_cuda:
+            host_indices.record_stream(self.load_stream)
+        if device_indices.is_cuda:
+            device_indices.record_stream(self.load_stream)
+
+        self.ack_load_queue.append(
+            HiCacheAck(
+                start_event=producer_event.start_event,
+                finish_event=producer_event.finish_event,
+                node_ids=op.node_ids,
+            )
+        )
+        return producer_id
 
     def evict_device(
         self, device_indices: torch.Tensor, host_indices: torch.Tensor
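With HiCacheAck, completion is signalled by CUDA events instead of node ids pushed from a worker thread, so the caller can presumably drain ack_load_queue / ack_write_queue by polling finish_event.query() and releasing the associated nodes without blocking. A hedged sketch of such a polling loop; NodeAck, drain_acks, and release_nodes are illustrative stand-ins, not sglang APIs:

```python
from typing import List, NamedTuple

import torch


class NodeAck(NamedTuple):
    # Illustrative stand-in for HiCacheAck: completion is an event, not a queue message.
    finish_event: "torch.cuda.Event"
    node_ids: List[int]


def drain_acks(ack_queue: List[NodeAck], release_nodes) -> int:
    """Release nodes whose transfers have finished; keep the rest queued, in order."""
    finished = 0
    while ack_queue and ack_queue[0].finish_event.query():
        ack = ack_queue.pop(0)
        release_nodes(ack.node_ids)
        finished += 1
    return finished


if torch.cuda.is_available():
    stream = torch.cuda.Stream()
    event = torch.cuda.Event()
    with torch.cuda.stream(stream):
        torch.randn(1024, 1024, device="cuda").sum()  # stand-in for a KV transfer
        event.record()
    acks = [NodeAck(event, node_ids=[42, 43])]
    stream.synchronize()
    print(drain_acks(acks, release_nodes=lambda ids: print("released", ids)))
```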
@@ -625,7 +622,7 @@ class HiCacheController:
         return operation
 
     def terminate_prefetch(self, operation):
-        operation.
+        operation.mark_terminate()
         return operation.completed_tokens, operation.hash_value
 
     def append_host_mem_release(self, host_indices: torch.Tensor):
@@ -706,6 +703,7 @@ class HiCacheController:
                 operation.completed_tokens
                 != prev_completed_tokens + len(batch_hashes) * self.page_size
             ):
+                operation.mark_terminate()
                 break  # Some operations fail or operation terminated by controller
             # release pre-allocated memory
             self.append_host_mem_release(
@@ -885,7 +883,7 @@ class HiCacheController:
 
             if not self.backup_skip:
                 self._page_backup(operation)
-            self.ack_backup_queue.put(operation
+            self.ack_backup_queue.put(operation)
 
         except Empty:
             continue