sglang 0.5.0rc1__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +0 -7
- sglang/bench_one_batch_server.py +7 -2
- sglang/bench_serving.py +3 -3
- sglang/eval/llama3_eval.py +0 -1
- sglang/srt/configs/model_config.py +25 -9
- sglang/srt/configs/update_config.py +40 -5
- sglang/srt/constrained/xgrammar_backend.py +23 -11
- sglang/srt/conversation.py +2 -15
- sglang/srt/disaggregation/ascend/conn.py +1 -3
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +1 -2
- sglang/srt/disaggregation/launch_lb.py +7 -1
- sglang/srt/disaggregation/mini_lb.py +11 -5
- sglang/srt/disaggregation/mooncake/conn.py +141 -47
- sglang/srt/disaggregation/prefill.py +261 -5
- sglang/srt/disaggregation/utils.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/device_communicators/pynccl.py +68 -18
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
- sglang/srt/distributed/naive_distributed.py +112 -0
- sglang/srt/distributed/parallel_state.py +90 -4
- sglang/srt/entrypoints/context.py +20 -1
- sglang/srt/entrypoints/engine.py +29 -4
- sglang/srt/entrypoints/http_server.py +76 -0
- sglang/srt/entrypoints/openai/protocol.py +4 -2
- sglang/srt/entrypoints/openai/serving_chat.py +23 -6
- sglang/srt/entrypoints/openai/serving_completions.py +10 -1
- sglang/srt/entrypoints/openai/serving_responses.py +2 -2
- sglang/srt/eplb/expert_distribution.py +2 -3
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +24 -0
- sglang/srt/host_shared_memory.py +83 -0
- sglang/srt/layers/attention/ascend_backend.py +132 -22
- sglang/srt/layers/attention/flashattention_backend.py +24 -17
- sglang/srt/layers/attention/flashinfer_backend.py +14 -3
- sglang/srt/layers/attention/flashinfer_mla_backend.py +227 -76
- sglang/srt/layers/attention/triton_backend.py +109 -73
- sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
- sglang/srt/layers/attention/trtllm_mha_backend.py +398 -36
- sglang/srt/layers/attention/trtllm_mla_backend.py +49 -19
- sglang/srt/layers/attention/utils.py +94 -15
- sglang/srt/layers/attention/vision.py +40 -13
- sglang/srt/layers/attention/vision_utils.py +65 -0
- sglang/srt/layers/communicator.py +58 -10
- sglang/srt/layers/dp_attention.py +137 -27
- sglang/srt/layers/elementwise.py +94 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
- sglang/srt/layers/layernorm.py +8 -1
- sglang/srt/layers/linear.py +24 -0
- sglang/srt/layers/logits_processor.py +16 -18
- sglang/srt/layers/moe/__init__.py +31 -0
- sglang/srt/layers/moe/ep_moe/layer.py +37 -33
- sglang/srt/layers/moe/fused_moe_native.py +14 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
- sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
- sglang/srt/layers/moe/moe_runner/base.py +13 -0
- sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
- sglang/srt/layers/moe/router.py +15 -9
- sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
- sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +167 -83
- sglang/srt/layers/moe/utils.py +159 -18
- sglang/srt/layers/multimodal.py +156 -40
- sglang/srt/layers/quantization/__init__.py +18 -46
- sglang/srt/layers/quantization/awq.py +22 -23
- sglang/srt/layers/quantization/base_config.py +2 -6
- sglang/srt/layers/quantization/blockwise_int8.py +4 -12
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -29
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
- sglang/srt/layers/quantization/fp8.py +127 -119
- sglang/srt/layers/quantization/fp8_kernel.py +195 -24
- sglang/srt/layers/quantization/fp8_utils.py +34 -9
- sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
- sglang/srt/layers/quantization/gptq.py +17 -21
- sglang/srt/layers/quantization/marlin_utils.py +26 -8
- sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
- sglang/srt/layers/quantization/modelopt_quant.py +217 -98
- sglang/srt/layers/quantization/moe_wna16.py +10 -15
- sglang/srt/layers/quantization/mxfp4.py +222 -39
- sglang/srt/layers/quantization/quark/quark.py +390 -0
- sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
- sglang/srt/layers/quantization/unquant.py +34 -70
- sglang/srt/layers/quantization/utils.py +77 -2
- sglang/srt/layers/quantization/w4afp8.py +7 -8
- sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
- sglang/srt/layers/quantization/w8a8_int8.py +5 -13
- sglang/srt/layers/radix_attention.py +6 -0
- sglang/srt/layers/rotary_embedding.py +1 -0
- sglang/srt/layers/sampler.py +5 -2
- sglang/srt/lora/layers.py +6 -2
- sglang/srt/lora/lora_manager.py +21 -22
- sglang/srt/lora/lora_registry.py +3 -3
- sglang/srt/lora/mem_pool.py +26 -24
- sglang/srt/lora/utils.py +10 -12
- sglang/srt/managers/cache_controller.py +80 -19
- sglang/srt/managers/detokenizer_manager.py +10 -2
- sglang/srt/managers/io_struct.py +23 -0
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/schedule_batch.py +22 -48
- sglang/srt/managers/scheduler.py +28 -20
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/template_manager.py +7 -5
- sglang/srt/managers/tokenizer_manager.py +88 -39
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/managers/utils.py +59 -1
- sglang/srt/mem_cache/allocator.py +10 -157
- sglang/srt/mem_cache/allocator_ascend.py +147 -0
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +14 -4
- sglang/srt/mem_cache/memory_pool.py +3 -3
- sglang/srt/mem_cache/memory_pool_host.py +35 -2
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
- sglang/srt/model_executor/cuda_graph_runner.py +33 -33
- sglang/srt/model_executor/forward_batch_info.py +11 -10
- sglang/srt/model_executor/model_runner.py +93 -78
- sglang/srt/model_executor/npu_graph_runner.py +94 -0
- sglang/srt/model_loader/loader.py +24 -6
- sglang/srt/models/dbrx.py +12 -6
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +5 -2
- sglang/srt/models/deepseek_v2.py +226 -223
- sglang/srt/models/ernie4.py +2 -2
- sglang/srt/models/glm4_moe.py +27 -65
- sglang/srt/models/glm4_moe_nextn.py +2 -1
- sglang/srt/models/glm4v.py +52 -1
- sglang/srt/models/glm4v_moe.py +8 -11
- sglang/srt/models/gpt_oss.py +41 -76
- sglang/srt/models/granitemoe.py +0 -1
- sglang/srt/models/grok.py +376 -48
- sglang/srt/models/interns1.py +12 -47
- sglang/srt/models/internvl.py +6 -51
- sglang/srt/models/llama.py +10 -2
- sglang/srt/models/llama4.py +18 -7
- sglang/srt/models/minicpm3.py +0 -1
- sglang/srt/models/mixtral.py +0 -2
- sglang/srt/models/nemotron_nas.py +435 -0
- sglang/srt/models/olmoe.py +0 -1
- sglang/srt/models/phi4mm.py +3 -21
- sglang/srt/models/qwen2.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +2 -0
- sglang/srt/models/qwen2_moe.py +23 -23
- sglang/srt/models/qwen3.py +2 -2
- sglang/srt/models/qwen3_classification.py +84 -0
- sglang/srt/models/qwen3_moe.py +27 -43
- sglang/srt/models/step3_vl.py +8 -3
- sglang/srt/models/xverse_moe.py +11 -5
- sglang/srt/multimodal/processors/base_processor.py +3 -3
- sglang/srt/multimodal/processors/internvl.py +7 -2
- sglang/srt/multimodal/processors/llava.py +11 -7
- sglang/srt/offloader.py +433 -0
- sglang/srt/operations.py +22 -2
- sglang/srt/reasoning_parser.py +4 -3
- sglang/srt/sampling/sampling_batch_info.py +7 -4
- sglang/srt/server_args.py +264 -105
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +8 -21
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_utils.py +36 -13
- sglang/srt/speculative/eagle_worker.py +56 -3
- sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
- sglang/srt/two_batch_overlap.py +20 -19
- sglang/srt/utils.py +68 -70
- sglang/test/runners.py +8 -5
- sglang/test/test_block_fp8.py +5 -6
- sglang/test/test_block_fp8_ep.py +13 -19
- sglang/test/test_cutlass_moe.py +4 -6
- sglang/test/test_cutlass_w4a8_moe.py +4 -3
- sglang/test/test_fp4_moe.py +4 -3
- sglang/test/test_marlin_moe.py +1 -1
- sglang/test/test_marlin_utils.py +1 -1
- sglang/test/test_utils.py +7 -0
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/METADATA +11 -11
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/RECORD +201 -171
- sglang/srt/layers/quantization/fp4.py +0 -557
- sglang/srt/layers/quantization/scalar_type.py +0 -352
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/storage/nixl/hicache_nixl.py

@@ -3,7 +3,7 @@ import logging
 import os
 import time
 import uuid
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 
@@ -28,6 +28,8 @@ class HiCacheNixl(HiCacheStorage):
 
     def __init__(self, file_path: str = "/tmp/hicache_storage", plugin: str = "auto"):
         """Initialize NIXL storage connector."""
+        # Might be better to be unified across HiCache backends and moved to HiCacheController
+        file_path = os.getenv("SGLANG_HICACHE_NIXL_BACKEND_STORAGE_DIR", file_path)
         self.file_manager = (
             NixlFileManager(file_path)
             if plugin not in NixlBackendSelection.OBJ_PLUGINS
@@ -44,59 +46,109 @@ class HiCacheNixl(HiCacheStorage):
 
         self.registration = NixlRegistration(self.agent)
 
+    def register_buffers(
+        self, buffers: Union[torch.Tensor, List[torch.Tensor], List[tuple]]
+    ) -> Optional[Any]:
+        """Register tensor(s) or target locations in host memory (list of addr,len tuples) with NIXL."""
+        if isinstance(buffers[0], tuple):
+            tuples = [(x[0], x[1], 0, "") for x in buffers]
+            return self.registration._register_memory(tuples, "DRAM")
+        else:
+            return self.registration._register_memory(buffers)
+
+    def register_files(
+        self, file_paths: List[str], open_file: Optional[bool] = True
+    ) -> Optional[Any]:
+        """Register files with NIXL."""
+        tuples = self.file_manager.files_to_nixl_tuples(file_paths)
+        return self.registration._register_memory(tuples, "FILE")
+
+    def register_objects(
+        self, keys: List[str], sizes: Optional[List[int]] = None
+    ) -> Optional[Any]:
+        """Register objects with NIXL."""
+        if not keys:
+            return None
+        tuples = [(0, 0, key, "") for key in keys]
+        return self.registration._register_memory(tuples, "OBJ")
+
     def _execute_transfer(
-        self,
+        self,
+        buffers: Optional[List[torch.Tensor | tuple]],
+        keys: List[str],
+        direction: str,
     ) -> bool:
-        if len(
-            logger.error("Mismatch between number of tensors and files/objects")
+        if len(buffers) != len(keys):
+            logger.error("Mismatch between number of tensors/buffers and files/objects")
             return False
 
-
-
-            return False
-
-        # Get transfer tuples based on backend type
-        tensor_sizes = [tensor.element_size() * tensor.numel() for tensor in tensors]
+        # Registering file and object keys per transfer, to be updated when
+        # pre-registration for file and object is added to HiCache.
         if self.backend_selector.mem_type == "FILE":
-
-            if not
+            tuples = self.file_manager.files_to_nixl_tuples(keys)
+            if not tuples or not self.registration._register_memory(tuples, "FILE"):
                 logger.error("Failed to prepare files for transfer")
                 return False
-
-
-
-        else:
-            if not self.registration.register_objects(keys, tensors):
+        else:  # mem_type == "OBJ"
+            tuples = [(0, 0, key, "") for key in keys]
+            if not tuples or not self.registration._register_memory(tuples, "OBJ"):
                 logger.error("Failed to register objects")
                 return False
-        transfer_tuples = [(0, s, key) for s, key in zip(tensor_sizes, keys)]
 
+        # Prepare transfer descriptors
+        if isinstance(buffers[0], torch.Tensor):
+            tensor_sizes = [
+                tensor.element_size() * tensor.numel() for tensor in buffers
+            ]
+            storage_tuples = [(x[0], s, x[2]) for x, s in zip(tuples, tensor_sizes)]
+            host_descs = self.agent.get_xfer_descs(buffers)
+        elif isinstance(buffers[0], tuple):
+            storage_tuples = [(x[0], y[1], x[2]) for x, y in zip(tuples, buffers)]
+            host_descs = self.agent.get_xfer_descs(
+                [(x[0], x[1], 0) for x in buffers], "DRAM"
+            )
+        else:
+            return False
+
+        storage_descs = self.agent.get_xfer_descs(
+            storage_tuples, self.backend_selector.mem_type
+        )
+
+        if (host_descs is None) or (storage_descs is None):
+            logger.error("Failed to get transfer descriptors")
+            return False
+
+        # Initialize transfer, default assumption that tensor was registered
        try:
-
-
-
-
-
-
-            logger.error("Failed to
+            xfer_req = self.agent.initialize_xfer(
+                direction, host_descs, storage_descs, self.agent_name
+            )
+        except Exception:
+            # Check if it was due to missing pre-registration
+            if not self.register_buffers(buffers):
+                logger.error("Failed to register tensors/buffers")
                 return False
 
-
-
-
-                direction, tensor_descs, file_descs, self.agent_name
+            try:
+                xfer_req = self.agent.initialize_xfer(
+                    direction, host_descs, storage_descs, self.agent_name
                 )
-
-            logger.error("Failed to create transfer request")
+            except Exception as e:
+                logger.error(f"Failed to create transfer request: {e}")
                 return False
 
+        # Execute transfer and wait for its completion
+        try:
             state = self.agent.transfer(xfer_req)
             while state != "DONE":
                 state = self.agent.check_xfer_state(xfer_req)
                 if state == "ERR":
+                    self.agent.release_xfer_handle(xfer_req)
                     logger.error("Transfer failed")
                     return False
-
+                time.sleep(0.0001)  # Can be changed to os.sched_yield() or parametrized
+
+            self.agent.release_xfer_handle(xfer_req)
             return True
 
         except Exception as e:
@@ -106,45 +158,87 @@ class HiCacheNixl(HiCacheStorage):
             logger.error(f"Traceback: {traceback.format_exc()}")
             return False
 
-    def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
-        if not keys:
-            return True
-
-        if self.backend_selector.mem_type == "FILE":
-            file_paths = []
-            for key in keys:
-                tensor_path = self.file_manager.get_file_path(key)
-                if not self.file_manager.create_file(tensor_path):
-                    logger.error(f"Failed to create file {tensor_path}")
-                    return False
-                file_paths.append(tensor_path)
-            return self._execute_transfer(values, file_paths, "WRITE")
-        else:
-            return self._execute_transfer(values, keys, "WRITE")
-
-    def set(self, key: str, value: torch.Tensor) -> bool:
-        return self.batch_set([key], [value])
-
     def get(
-        self,
+        self,
+        key: str,
+        target_location: Optional[torch.Tensor | int] = None,
+        target_sizes: Optional[int] = None,
     ) -> torch.Tensor | None:
-
+        # To be removed, being compatible with the current API
+        if target_location is None:
             return None
-
+        if target_sizes:
+            result = self.batch_get([key], [target_location], [target_sizes])
+        else:
+            result = self.batch_get([key], [target_location])
         return result[0] if result else None
 
     def batch_get(
-        self,
-
+        self,
+        keys: List[str],
+        target_locations: Optional[List[torch.Tensor | int]] = None,
+        target_sizes: Optional[List[int]] = None,
+    ) -> List[torch.Tensor | None]:
         if not keys:
             return []
 
+        # To be removed, being compatible with the current API
+        if not target_locations:
+            return [None] * len(keys)
+
+        if target_sizes and (len(target_sizes) != len(target_locations)):
+            logger.error("Mismatch between number of target_locations and target_sizes")
+            return [None] * len(keys)
+        if target_sizes:
+            dest = list(zip(target_locations, target_sizes))
+        else:
+            dest = target_locations
+
         if self.backend_selector.mem_type == "FILE":
             file_paths = [self.file_manager.get_file_path(key) for key in keys]
-            success = self._execute_transfer(
+            success = self._execute_transfer(dest, file_paths, "READ")
         else:
-            success = self._execute_transfer(
-        return
+            success = self._execute_transfer(dest, keys, "READ")
+        return target_locations if success and not target_sizes else [None] * len(keys)
+
+    def set(
+        self,
+        key: str,
+        value: Optional[torch.Tensor] = None,
+        target_location: Optional[int] = None,
+        target_sizes: Optional[int] = None,
+    ) -> bool:
+        if target_location and target_sizes:
+            return self.batch_set([key], None, [target_location], [target_sizes])
+        else:
+            return self.batch_set([key], [value])
+
+    def batch_set(
+        self,
+        keys: List[str],
+        values: Optional[List[torch.Tensor]] = None,
+        target_locations: Optional[List[int]] = None,
+        target_sizes: Optional[List[int]] = None,
+    ) -> bool:
+        if not keys or (not values and (not target_locations or not target_sizes)):
+            logger.error("Keys or values were not passed")
+            return False
+
+        if not values:
+            values = list(zip(target_locations, target_sizes))
+
+        if self.backend_selector.mem_type == "FILE":
+            file_paths = []
+            for key in keys:
+                file_path = self.file_manager.get_file_path(key)
+                # New file per set, to be updated when partial writes is added to HiCache
+                if not self.file_manager.create_file(file_path):
+                    logger.error(f"Failed to create file {file_path}")
+                    return False
+                file_paths.append(file_path)
+            return self._execute_transfer(values, file_paths, "WRITE")
+        else:  # mem_type == "OBJ"
+            return self._execute_transfer(values, keys, "WRITE")
 
     def exists(self, key: str) -> bool:
         tuples = self.registration.create_query_tuples(
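Taken together, these hunks change the `HiCacheNixl` API so that `set`/`get` and their batch variants accept either tensors or raw host address/length pairs. A minimal usage sketch, mirroring the updated tests further below (the keys, shapes, and `hicache` instance are illustrative, and a working NIXL installation is assumed):

```python
import torch

hicache = HiCacheNixl(file_path="/tmp/hicache_storage", plugin="auto")

# Tensor mode: pass tensors directly.
value = torch.randn(16, 16)
assert hicache.set("k1", value)
dst = torch.zeros_like(value)
hicache.get("k1", dst)  # fills dst and returns it on success

# Address/length mode: pass data_ptr() plus byte length instead of a tensor.
addr, nbytes = value.data_ptr(), value.numel() * value.element_size()
assert hicache.set("k2", None, addr, nbytes)
dst2 = torch.zeros_like(value)
# In this mode get() returns None but still writes into the destination buffer.
hicache.get("k2", dst2.data_ptr(), dst2.numel() * dst2.element_size())
```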
sglang/srt/mem_cache/storage/nixl/nixl_utils.py

@@ -109,66 +109,35 @@ class NixlRegistration:
         return [(0, 0, key)]
 
     def _register_memory(
-        self,
+        self,
+        items: Union[List[tuple], torch.Tensor, List[torch.Tensor]],
+        mem_type: Optional[str] = None,
     ) -> Optional[Any]:
         """Common registration logic for files, objects, and buffers.
         Args:
             items: List of tuples or tensors to register
-            mem_type: Memory type ("FILE", "OBJ"
-            desc: Description for logging
+            mem_type: Memory type ("FILE", "OBJ") or None for tensor or list of tensors
         """
-
-        if not items:
-            return None
-
-        reg_descs = self.agent.get_reg_descs(items, mem_type)
-        if reg_descs is None:
-            logger.error("Failed to create registration descriptors")
-            return None
-
-        registered_memory = self.agent.register_memory(reg_descs)
-        if registered_memory:
-            return registered_memory
-        else:
-            logger.error("Failed to register with NIXL")
-            return None
-
-        except Exception as e:
-            logger.error(f"Failed to register {desc}: {e}")
+        if isinstance(items, list) and not items:
             return None
 
-
-
-
-        """Register tensors/buffers with NIXL."""
-        if isinstance(buffers, torch.Tensor):
-            buffers = [buffers]
-
-        if not buffers:
+        reg_descs = self.agent.get_reg_descs(items, mem_type)
+        if reg_descs is None:
+            logger.error("Failed to create registration descriptors")
             return None
 
-
-
-
-
-
-
-
-
-
-
-    ) -> Optional[Any]:
-        """Register objects with NIXL."""
-        if not keys:
+        try:
+            registered_memory = self.agent.register_memory(reg_descs)
+            return registered_memory  # Could be None in case of error
+        except Exception as e:
+            if not mem_type:
+                logger.error(f"Failed to register Tensors with NIXL: {e}")
+            else:
+                logger.error(
+                    f"Failed to register memory of type {mem_type} with NIXL: {e}"
+                )
             return None
 
-        # Create object tuples with proper sizes
-        tuples = [
-            (0, tensor.element_size() * tensor.numel() if tensor else 0, key)
-            for key, tensor in zip(keys, tensors or [None] * len(keys))
-        ]
-        return self._register_memory(tuples, "OBJ", "objects")
-
 
 class NixlFileManager:
     """Handles file system operations for NIXL."""
@@ -221,12 +190,9 @@ class NixlFileManager:
         return False
 
     def files_to_nixl_tuples(
-        self, file_paths: List[str]
+        self, file_paths: List[str]
     ) -> List[Tuple[int, int, int, str]]:
         """Create NIXL tuples (offset, length, fd, file_path) for given files."""
-        if not open_file:
-            return [(0, 0, 0, path) for path in file_paths]
-
         tuples = []
         for path in file_paths:
             if (fd := self.open_file(path)) is None:
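`_register_memory` is now the single registration entry point, with `mem_type` selecting the NIXL descriptor kind. A rough sketch of the three call shapes, assuming `reg` wraps a live NIXL agent (the names `agent` and `fd` are stand-ins; the tuple layouts follow the code above):

```python
import torch

reg = NixlRegistration(agent)  # agent: an initialized NIXL agent
fd = 0  # stand-in; real fds come from NixlFileManager.open_file

# mem_type=None: a tensor or list of tensors
reg._register_memory(torch.randn(4, 4))

# "FILE": (offset, length, fd, file_path) tuples
reg._register_memory([(0, 4096, fd, "/tmp/f.bin")], "FILE")

# "OBJ": (0, 0, key, "") tuples keyed by object name
reg._register_memory([(0, 0, "obj_key", "")], "OBJ")
```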
sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py

@@ -7,8 +7,11 @@ from unittest.mock import MagicMock
 
 import torch
 
-from sglang.srt.mem_cache.nixl.hicache_nixl import HiCacheNixl
-from sglang.srt.mem_cache.nixl.nixl_utils import
+from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl
+from sglang.srt.mem_cache.storage.nixl.nixl_utils import (
+    NixlFileManager,
+    NixlRegistration,
+)
 
 
 class TestNixlUnified(unittest.TestCase):
@@ -88,8 +91,27 @@ class TestNixlUnified(unittest.TestCase):
 
         # Test get
         retrieved = self.hicache.get(key, dst_tensor)
+        self.verify_tensors_equal(value, dst_tensor)
         self.verify_tensors_equal(value, retrieved)
 
+        # Same test in addr,len mode with another key and dst_tensor
+        key2 = "test_key2"
+        dst_tensor2 = torch.zeros_like(value, device="cpu")
+        src_addr, src_len = value.data_ptr(), value.numel() * value.element_size()
+        dst_addr, dst_len = (
+            dst_tensor2.data_ptr(),
+            dst_tensor2.numel() * dst_tensor2.element_size(),
+        )
+
+        # Test set
+        self.assertTrue(self.hicache.set(key, None, src_addr, src_len))
+        self.assertTrue(self.hicache.exists(key))
+
+        # Test get
+        retrieved2 = self.hicache.get(key, dst_addr, dst_len)
+        self.assertTrue(retrieved2 == None)
+        self.verify_tensors_equal(value, dst_tensor2)
+
     def test_batch_set_get(self):
         """Test batch tensor set/get operations."""
         keys = ["key1", "key2", "key3"]
@@ -108,6 +130,23 @@ class TestNixlUnified(unittest.TestCase):
         retrieved = self.hicache.batch_get(keys, dst_tensors)
         self.verify_tensor_lists_equal(values, retrieved)
 
+        # Same test in addr,len mode with another key and dst_tensor
+        keys2 = ["key4", "key5", "key6"]
+        dst_tensors2 = [torch.zeros_like(v, device="cpu") for v in values]
+        src_addrs = [v.data_ptr() for v in values]
+        src_lens = [v.numel() * v.element_size() for v in values]
+        dst_addrs = [dt.data_ptr() for dt in dst_tensors2]
+        dst_lens = [dt.numel() * dt.element_size() for dt in dst_tensors2]
+
+        # Test batch set
+        self.assertTrue(self.hicache.batch_set(keys2, None, src_addrs, src_lens))
+        self.assertTrue(all(self.hicache.exists(key) for key in keys2))
+
+        # Test batch get
+        retrieved2 = self.hicache.batch_get(keys, dst_addrs, dst_lens)
+        self.assertTrue(all(ret == None for ret in retrieved2))
+        self.verify_tensor_lists_equal(values, dst_tensors2)
+
     def test_mixed_operations(self):
         """Test mixing single and batch operations."""
         # Test interleaved set/get operations
@@ -170,7 +209,7 @@ class TestNixlUnified(unittest.TestCase):
         self.file_manager.create_file(test_file)
 
         # Test tuple creation
-        tuples = self.file_manager.files_to_nixl_tuples([test_file]
+        tuples = self.file_manager.files_to_nixl_tuples([test_file])
         self.assertIsNotNone(tuples)
         self.assertTrue(len(tuples) > 0)
 
@@ -190,11 +229,11 @@ class TestNixlUnified(unittest.TestCase):
         tensor = torch.randn(10, 10)
 
         # Test buffer registration
-        self.assertIsNotNone(self.
+        self.assertIsNotNone(self.hicache.register_buffers(tensor))
 
         # Test batch registration
         tensors = [torch.randn(5, 5) for _ in range(3)]
-        self.assertIsNotNone(self.
+        self.assertIsNotNone(self.hicache.register_buffers(tensors))
 
     def test_register_files_with_tuples(self):
         """Test registration of files using NIXL tuples."""
@@ -203,8 +242,8 @@ class TestNixlUnified(unittest.TestCase):
         self.file_manager.create_file(file)
 
         # Create tuples and register
-        tuples = self.file_manager.files_to_nixl_tuples(files
-        self.
+        tuples = self.file_manager.files_to_nixl_tuples(files)
+        self.hicache.register_files(tuples)
 
         # Verify tuples
         self.assertEqual(len(tuples), len(files))
sglang/srt/model_executor/cuda_graph_runner.py

@@ -34,9 +34,10 @@ from sglang.srt.distributed.device_communicators.pynccl_allocator import (
 )
 from sglang.srt.distributed.parallel_state import GroupCoordinator, graph_capture
 from sglang.srt.layers.dp_attention import (
-
+    DpPaddingMode,
     get_attention_tp_rank,
     get_attention_tp_size,
+    set_dp_buffer_len,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.torchao_utils import save_gemlite_cache
@@ -239,6 +240,8 @@ class CudaGraphRunner:
     def __init__(self, model_runner: ModelRunner):
         # Parse args
         self.model_runner = model_runner
+        self.device = model_runner.device
+        self.device_module = torch.get_device_module(self.device)
         self.graphs = {}
         self.output_buffers = {}
         self.enable_torch_compile = model_runner.server_args.enable_torch_compile
@@ -304,13 +307,15 @@ class CudaGraphRunner:
             self.model_runner.lora_manager.init_cuda_graph_batch_info(self.max_bs)
 
         # Graph inputs
-        with torch.device(
+        with torch.device(self.device):
             self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
             self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
             self.seq_lens = torch.full(
                 (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
             )
-            self.out_cache_loc = torch.zeros(
+            self.out_cache_loc = torch.zeros(
+                (self.max_num_token,), dtype=self._cache_loc_dtype()
+            )
             self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
             self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)
             self.num_token_non_padded = torch.zeros((1,), dtype=torch.int32)
@@ -349,30 +354,15 @@ class CudaGraphRunner:
                 self.global_num_tokens_for_logprob_gpu = torch.zeros(
                     (self.dp_size,), dtype=torch.int32
                 )
-                self.gathered_buffer = torch.zeros(
-                    (
-                        self.max_num_token * self.dp_size,
-                        self.model_runner.model_config.hidden_size,
-                    ),
-                    dtype=self.model_runner.dtype,
-                )
             else:
                 assert self.require_attn_tp_gather
                 self.global_num_tokens_gpu = torch.zeros((1,), dtype=torch.int32)
                 self.global_num_tokens_for_logprob_gpu = torch.zeros(
                     (1,), dtype=torch.int32
                 )
-                self.gathered_buffer = torch.zeros(
-                    (
-                        self.max_num_token,
-                        self.model_runner.model_config.hidden_size,
-                    ),
-                    dtype=self.model_runner.dtype,
-                )
         else:
             self.global_num_tokens_gpu = None
             self.global_num_tokens_for_logprob_gpu = None
-            self.gathered_buffer = None
 
         self.custom_mask = torch.ones(
             (
@@ -380,12 +370,12 @@ class CudaGraphRunner:
                 * self.num_tokens_per_bs
             ),
             dtype=torch.bool,
-            device=
+            device=self.device,
         )
         self.next_token_logits_buffer = torch.zeros(
             (self.max_num_token, self.model_runner.model_config.vocab_size),
             dtype=torch.float,
-            device=
+            device=self.device,
         )
 
         # Capture
@@ -397,6 +387,9 @@ class CudaGraphRunner:
                 f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
             )
 
+    def _cache_loc_dtype(self):
+        return torch.int64
+
     def can_run(self, forward_batch: ForwardBatch):
         if self.require_mlp_tp_gather:
             cuda_graph_bs = (
@@ -516,8 +509,16 @@ class CudaGraphRunner:
         )
         logger.info(log_message)
 
+    def _capture_graph(self, graph, pool, stream, run_once_fn):
+        with self.device_module.graph(graph, pool=pool, stream=stream):
+            out = run_once_fn()
+        return out
+
+    def _create_device_graph(self):
+        return torch.cuda.CUDAGraph()
+
     def capture_one_batch_size(self, bs: int, forward: Callable):
-        graph =
+        graph = self._create_device_graph()
         stream = self.stream
         num_tokens = bs * self.num_tokens_per_bs
 
@@ -556,7 +557,7 @@ class CudaGraphRunner:
                     device=input_ids.device,
                 )
             )
-
+            global_dp_buffer_len = num_tokens * self.dp_size
         elif self.require_attn_tp_gather:
             self.global_num_tokens_gpu.copy_(
                 torch.tensor(
@@ -572,9 +573,9 @@ class CudaGraphRunner:
                     device=input_ids.device,
                 )
             )
-
+            global_dp_buffer_len = num_tokens
         else:
-
+            global_dp_buffer_len = None
 
         spec_info = self.get_spec_info(num_tokens)
         if self.capture_hidden_mode != CaptureHiddenMode.FULL:
@@ -607,8 +608,8 @@ class CudaGraphRunner:
             positions=positions,
             global_num_tokens_gpu=self.global_num_tokens_gpu,
             global_num_tokens_for_logprob_gpu=self.global_num_tokens_for_logprob_gpu,
-            dp_padding_mode=
-
+            dp_padding_mode=DpPaddingMode.get_default_mode_in_cuda_graph(),
+            global_dp_buffer_len=global_dp_buffer_len,
             mrope_positions=mrope_positions,
             spec_algorithm=self.model_runner.spec_algorithm,
             spec_info=spec_info,
@@ -637,6 +638,7 @@ class CudaGraphRunner:
         def run_once():
             # Clean intermediate result cache for DP attention
             forward_batch.dp_local_start_pos = forward_batch.dp_local_num_tokens = None
+            set_dp_buffer_len(global_dp_buffer_len, num_tokens)
 
             kwargs = {}
             if (
@@ -656,19 +658,17 @@ class CudaGraphRunner:
             return logits_output_or_pp_proxy_tensors
 
         for _ in range(2):
-
+            self.device_module.synchronize()
             self.model_runner.tp_group.barrier()
-
             run_once()
 
         if get_global_graph_memory_pool() is None:
-            set_global_graph_memory_pool(
+            set_global_graph_memory_pool(self.device_module.graph_pool_handle())
             # Set graph pool id globally to be able to use symmetric memory
             set_graph_pool_id(get_global_graph_memory_pool())
-
-            graph,
-        )
-        out = run_once()
+        out = self._capture_graph(
+            graph, get_global_graph_memory_pool(), stream, run_once
+        )
 
         return graph, out
 
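The `CudaGraphRunner` hunks swap hard-coded `torch.cuda` calls for a `device_module` resolved once via `torch.get_device_module`, leaving `_create_device_graph` and `_cache_loc_dtype` as the only device-specific hooks (presumably what the new `npu_graph_runner.py` overrides). A condensed sketch of the pattern, independent of sglang internals (class and method names here are illustrative):

```python
import torch

class GraphRunnerBase:
    def __init__(self, device: str = "cuda"):
        self.device = device
        # torch.get_device_module("cuda") -> torch.cuda; other backends
        # expose their own module with the same graph/synchronize surface.
        self.device_module = torch.get_device_module(device)

    def _create_device_graph(self):
        # A subclass for another backend returns its own graph type here.
        return torch.cuda.CUDAGraph()

    def capture(self, run_once):
        graph = self._create_device_graph()
        self.device_module.synchronize()
        run_once()  # warmup outside the graph
        with self.device_module.graph(graph):
            out = run_once()  # captured, replayable invocation
        return graph, out
```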