sglang 0.5.0rc2__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +0 -6
- sglang/bench_one_batch_server.py +7 -2
- sglang/bench_serving.py +3 -3
- sglang/eval/llama3_eval.py +0 -1
- sglang/srt/configs/model_config.py +24 -9
- sglang/srt/configs/update_config.py +40 -5
- sglang/srt/constrained/xgrammar_backend.py +23 -11
- sglang/srt/conversation.py +2 -15
- sglang/srt/disaggregation/ascend/conn.py +1 -3
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +1 -1
- sglang/srt/disaggregation/launch_lb.py +7 -1
- sglang/srt/disaggregation/mini_lb.py +11 -5
- sglang/srt/disaggregation/mooncake/conn.py +141 -47
- sglang/srt/disaggregation/prefill.py +261 -5
- sglang/srt/disaggregation/utils.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/device_communicators/pynccl.py +68 -18
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
- sglang/srt/distributed/naive_distributed.py +112 -0
- sglang/srt/distributed/parallel_state.py +90 -4
- sglang/srt/entrypoints/context.py +20 -1
- sglang/srt/entrypoints/engine.py +27 -2
- sglang/srt/entrypoints/http_server.py +12 -0
- sglang/srt/entrypoints/openai/protocol.py +2 -2
- sglang/srt/entrypoints/openai/serving_chat.py +22 -6
- sglang/srt/entrypoints/openai/serving_completions.py +9 -1
- sglang/srt/entrypoints/openai/serving_responses.py +2 -2
- sglang/srt/eplb/expert_distribution.py +2 -3
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +24 -0
- sglang/srt/host_shared_memory.py +83 -0
- sglang/srt/layers/attention/ascend_backend.py +132 -22
- sglang/srt/layers/attention/flashattention_backend.py +24 -17
- sglang/srt/layers/attention/flashinfer_backend.py +11 -3
- sglang/srt/layers/attention/flashinfer_mla_backend.py +226 -76
- sglang/srt/layers/attention/triton_backend.py +85 -46
- sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
- sglang/srt/layers/attention/trtllm_mha_backend.py +390 -30
- sglang/srt/layers/attention/trtllm_mla_backend.py +39 -16
- sglang/srt/layers/attention/utils.py +94 -15
- sglang/srt/layers/attention/vision.py +40 -13
- sglang/srt/layers/attention/vision_utils.py +65 -0
- sglang/srt/layers/communicator.py +51 -3
- sglang/srt/layers/dp_attention.py +23 -4
- sglang/srt/layers/elementwise.py +94 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
- sglang/srt/layers/layernorm.py +8 -1
- sglang/srt/layers/linear.py +24 -0
- sglang/srt/layers/logits_processor.py +5 -1
- sglang/srt/layers/moe/__init__.py +31 -0
- sglang/srt/layers/moe/ep_moe/layer.py +37 -33
- sglang/srt/layers/moe/fused_moe_native.py +14 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
- sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
- sglang/srt/layers/moe/moe_runner/base.py +13 -0
- sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
- sglang/srt/layers/moe/router.py +15 -9
- sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
- sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +167 -83
- sglang/srt/layers/moe/utils.py +159 -18
- sglang/srt/layers/quantization/__init__.py +13 -14
- sglang/srt/layers/quantization/awq.py +7 -7
- sglang/srt/layers/quantization/base_config.py +2 -6
- sglang/srt/layers/quantization/blockwise_int8.py +4 -12
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -28
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
- sglang/srt/layers/quantization/fp8.py +127 -119
- sglang/srt/layers/quantization/fp8_kernel.py +195 -24
- sglang/srt/layers/quantization/fp8_utils.py +34 -9
- sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
- sglang/srt/layers/quantization/gptq.py +5 -4
- sglang/srt/layers/quantization/marlin_utils.py +11 -3
- sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
- sglang/srt/layers/quantization/modelopt_quant.py +165 -68
- sglang/srt/layers/quantization/moe_wna16.py +10 -15
- sglang/srt/layers/quantization/mxfp4.py +206 -37
- sglang/srt/layers/quantization/quark/quark.py +390 -0
- sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
- sglang/srt/layers/quantization/unquant.py +34 -70
- sglang/srt/layers/quantization/utils.py +25 -0
- sglang/srt/layers/quantization/w4afp8.py +7 -8
- sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
- sglang/srt/layers/quantization/w8a8_int8.py +5 -13
- sglang/srt/layers/radix_attention.py +6 -0
- sglang/srt/layers/rotary_embedding.py +1 -0
- sglang/srt/lora/lora_manager.py +21 -22
- sglang/srt/lora/lora_registry.py +3 -3
- sglang/srt/lora/mem_pool.py +26 -24
- sglang/srt/lora/utils.py +10 -12
- sglang/srt/managers/cache_controller.py +76 -18
- sglang/srt/managers/detokenizer_manager.py +10 -2
- sglang/srt/managers/io_struct.py +9 -0
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/schedule_batch.py +4 -9
- sglang/srt/managers/scheduler.py +25 -16
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/template_manager.py +7 -5
- sglang/srt/managers/tokenizer_manager.py +60 -21
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/managers/utils.py +59 -1
- sglang/srt/mem_cache/allocator.py +7 -5
- sglang/srt/mem_cache/allocator_ascend.py +0 -11
- sglang/srt/mem_cache/hicache_storage.py +14 -4
- sglang/srt/mem_cache/memory_pool.py +3 -3
- sglang/srt/mem_cache/memory_pool_host.py +35 -2
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
- sglang/srt/model_executor/cuda_graph_runner.py +25 -12
- sglang/srt/model_executor/forward_batch_info.py +4 -1
- sglang/srt/model_executor/model_runner.py +43 -32
- sglang/srt/model_executor/npu_graph_runner.py +94 -0
- sglang/srt/model_loader/loader.py +24 -6
- sglang/srt/models/dbrx.py +12 -6
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +3 -1
- sglang/srt/models/deepseek_v2.py +224 -223
- sglang/srt/models/ernie4.py +2 -2
- sglang/srt/models/glm4_moe.py +25 -63
- sglang/srt/models/glm4v.py +52 -1
- sglang/srt/models/glm4v_moe.py +8 -11
- sglang/srt/models/gpt_oss.py +34 -74
- sglang/srt/models/granitemoe.py +0 -1
- sglang/srt/models/grok.py +376 -48
- sglang/srt/models/interns1.py +12 -47
- sglang/srt/models/internvl.py +6 -51
- sglang/srt/models/llama4.py +0 -2
- sglang/srt/models/minicpm3.py +0 -1
- sglang/srt/models/mixtral.py +0 -2
- sglang/srt/models/nemotron_nas.py +435 -0
- sglang/srt/models/olmoe.py +0 -1
- sglang/srt/models/phi4mm.py +3 -21
- sglang/srt/models/qwen2_5_vl.py +2 -0
- sglang/srt/models/qwen2_moe.py +3 -18
- sglang/srt/models/qwen3.py +2 -2
- sglang/srt/models/qwen3_classification.py +7 -1
- sglang/srt/models/qwen3_moe.py +9 -38
- sglang/srt/models/step3_vl.py +2 -1
- sglang/srt/models/xverse_moe.py +11 -5
- sglang/srt/multimodal/processors/base_processor.py +3 -3
- sglang/srt/multimodal/processors/internvl.py +7 -2
- sglang/srt/multimodal/processors/llava.py +11 -7
- sglang/srt/offloader.py +433 -0
- sglang/srt/operations.py +6 -1
- sglang/srt/reasoning_parser.py +4 -3
- sglang/srt/server_args.py +237 -104
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
- sglang/srt/speculative/eagle_utils.py +36 -13
- sglang/srt/speculative/eagle_worker.py +56 -3
- sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
- sglang/srt/two_batch_overlap.py +16 -11
- sglang/srt/utils.py +68 -70
- sglang/test/runners.py +8 -5
- sglang/test/test_block_fp8.py +5 -6
- sglang/test/test_block_fp8_ep.py +13 -19
- sglang/test/test_cutlass_moe.py +4 -6
- sglang/test/test_cutlass_w4a8_moe.py +4 -3
- sglang/test/test_fp4_moe.py +4 -3
- sglang/test/test_utils.py +7 -0
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/METADATA +7 -7
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/RECORD +179 -161
- sglang/srt/layers/quantization/fp4.py +0 -557
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
sglang/srt/mem_cache/storage/nixl/hicache_nixl.py

```diff
@@ -3,7 +3,7 @@ import logging
 import os
 import time
 import uuid
-from typing import Dict, List, Optional, Tuple, Union
+from typing import Any, Dict, List, Optional, Tuple, Union
 
 import torch
 
```
```diff
@@ -28,6 +28,8 @@ class HiCacheNixl(HiCacheStorage):
 
     def __init__(self, file_path: str = "/tmp/hicache_storage", plugin: str = "auto"):
         """Initialize NIXL storage connector."""
+        # Might be better to be unified across HiCache backends and moved to HiCacheController
+        file_path = os.getenv("SGLANG_HICACHE_NIXL_BACKEND_STORAGE_DIR", file_path)
         self.file_manager = (
             NixlFileManager(file_path)
             if plugin not in NixlBackendSelection.OBJ_PLUGINS
```
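The constructor now lets an environment variable override the storage directory. A minimal sketch of the resolution order, assuming only the standard library (the helper name is illustrative):

```python
import os

# Resolution order introduced above: the environment variable, when set,
# takes precedence over the constructor's file_path argument.
def resolve_storage_dir(file_path: str = "/tmp/hicache_storage") -> str:
    return os.getenv("SGLANG_HICACHE_NIXL_BACKEND_STORAGE_DIR", file_path)

# e.g. SGLANG_HICACHE_NIXL_BACKEND_STORAGE_DIR=/mnt/nvme/hicache python serve.py
print(resolve_storage_dir())  # "/tmp/hicache_storage" unless the variable is set
```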
```diff
@@ -44,59 +46,109 @@ class HiCacheNixl(HiCacheStorage):
 
         self.registration = NixlRegistration(self.agent)
 
+    def register_buffers(
+        self, buffers: Union[torch.Tensor, List[torch.Tensor], List[tuple]]
+    ) -> Optional[Any]:
+        """Register tensor(s) or target locations in host memory (list of addr,len tuples) with NIXL."""
+        if isinstance(buffers[0], tuple):
+            tuples = [(x[0], x[1], 0, "") for x in buffers]
+            return self.registration._register_memory(tuples, "DRAM")
+        else:
+            return self.registration._register_memory(buffers)
+
+    def register_files(
+        self, file_paths: List[str], open_file: Optional[bool] = True
+    ) -> Optional[Any]:
+        """Register files with NIXL."""
+        tuples = self.file_manager.files_to_nixl_tuples(file_paths)
+        return self.registration._register_memory(tuples, "FILE")
+
+    def register_objects(
+        self, keys: List[str], sizes: Optional[List[int]] = None
+    ) -> Optional[Any]:
+        """Register objects with NIXL."""
+        if not keys:
+            return None
+        tuples = [(0, 0, key, "") for key in keys]
+        return self.registration._register_memory(tuples, "OBJ")
+
     def _execute_transfer(
-        self, tensors: List[torch.Tensor], keys: List[str], direction: str
+        self,
+        buffers: Optional[List[torch.Tensor | tuple]],
+        keys: List[str],
+        direction: str,
     ) -> bool:
-        if len(tensors) != len(keys):
-            logger.error("Mismatch between number of tensors and files/objects")
+        if len(buffers) != len(keys):
+            logger.error("Mismatch between number of tensors/buffers and files/objects")
             return False
 
-
-
-            return False
-
-        # Get transfer tuples based on backend type
-        tensor_sizes = [tensor.element_size() * tensor.numel() for tensor in tensors]
+        # Registering file and object keys per transfer, to be updated when
+        # pre-registration for file and object is added to HiCache.
         if self.backend_selector.mem_type == "FILE":
-
-            if not
+            tuples = self.file_manager.files_to_nixl_tuples(keys)
+            if not tuples or not self.registration._register_memory(tuples, "FILE"):
                 logger.error("Failed to prepare files for transfer")
                 return False
-
-
-
-        else:
-            if not self.registration.register_objects(keys, tensors):
+        else:  # mem_type == "OBJ"
+            tuples = [(0, 0, key, "") for key in keys]
+            if not tuples or not self.registration._register_memory(tuples, "OBJ"):
                 logger.error("Failed to register objects")
                 return False
-            transfer_tuples = [(0, s, key) for s, key in zip(tensor_sizes, keys)]
 
+        # Prepare transfer descriptors
+        if isinstance(buffers[0], torch.Tensor):
+            tensor_sizes = [
+                tensor.element_size() * tensor.numel() for tensor in buffers
+            ]
+            storage_tuples = [(x[0], s, x[2]) for x, s in zip(tuples, tensor_sizes)]
+            host_descs = self.agent.get_xfer_descs(buffers)
+        elif isinstance(buffers[0], tuple):
+            storage_tuples = [(x[0], y[1], x[2]) for x, y in zip(tuples, buffers)]
+            host_descs = self.agent.get_xfer_descs(
+                [(x[0], x[1], 0) for x in buffers], "DRAM"
+            )
+        else:
+            return False
+
+        storage_descs = self.agent.get_xfer_descs(
+            storage_tuples, self.backend_selector.mem_type
+        )
+
+        if (host_descs is None) or (storage_descs is None):
+            logger.error("Failed to get transfer descriptors")
+            return False
+
+        # Initialize transfer, default assumption that tensor was registered
         try:
-
-
-
-
-
-
-            logger.error("Failed to get transfer descriptors")
+            xfer_req = self.agent.initialize_xfer(
+                direction, host_descs, storage_descs, self.agent_name
+            )
+        except Exception:
+            # Check if it was due to missing pre-registration
+            if not self.register_buffers(buffers):
+                logger.error("Failed to register tensors/buffers")
                 return False
 
-
-
-            xfer_req = self.agent.initialize_xfer(
-                direction, tensor_descs, file_descs, self.agent_name
+            try:
+                xfer_req = self.agent.initialize_xfer(
+                    direction, host_descs, storage_descs, self.agent_name
                 )
-
-            logger.error("Failed to create transfer request")
+            except Exception as e:
+                logger.error(f"Failed to create transfer request: {e}")
                 return False
 
+        # Execute transfer and wait for its completion
+        try:
             state = self.agent.transfer(xfer_req)
             while state != "DONE":
                 state = self.agent.check_xfer_state(xfer_req)
                 if state == "ERR":
+                    self.agent.release_xfer_handle(xfer_req)
                     logger.error("Transfer failed")
                     return False
-
+                time.sleep(0.0001)  # Can be changed to os.sched_yield() or parametrized
+
+            self.agent.release_xfer_handle(xfer_req)
             return True
 
         except Exception as e:
```
```diff
@@ -106,45 +158,87 @@ class HiCacheNixl(HiCacheStorage):
             logger.error(f"Traceback: {traceback.format_exc()}")
             return False
 
-    def batch_set(self, keys: List[str], values: List[torch.Tensor]) -> bool:
-        if not keys:
-            return True
-
-        if self.backend_selector.mem_type == "FILE":
-            file_paths = []
-            for key in keys:
-                tensor_path = self.file_manager.get_file_path(key)
-                if not self.file_manager.create_file(tensor_path):
-                    logger.error(f"Failed to create file {tensor_path}")
-                    return False
-                file_paths.append(tensor_path)
-            return self._execute_transfer(values, file_paths, "WRITE")
-        else:
-            return self._execute_transfer(values, keys, "WRITE")
-
-    def set(self, key: str, value: torch.Tensor) -> bool:
-        return self.batch_set([key], [value])
-
     def get(
-        self,
+        self,
+        key: str,
+        target_location: Optional[torch.Tensor | int] = None,
+        target_sizes: Optional[int] = None,
     ) -> torch.Tensor | None:
-
+        # To be removed, being compatible with the current API
+        if target_location is None:
             return None
-
+        if target_sizes:
+            result = self.batch_get([key], [target_location], [target_sizes])
+        else:
+            result = self.batch_get([key], [target_location])
         return result[0] if result else None
 
     def batch_get(
-        self,
-
+        self,
+        keys: List[str],
+        target_locations: Optional[List[torch.Tensor | int]] = None,
+        target_sizes: Optional[List[int]] = None,
+    ) -> List[torch.Tensor | None]:
         if not keys:
             return []
 
+        # To be removed, being compatible with the current API
+        if not target_locations:
+            return [None] * len(keys)
+
+        if target_sizes and (len(target_sizes) != len(target_locations)):
+            logger.error("Mismatch between number of target_locations and target_sizes")
+            return [None] * len(keys)
+        if target_sizes:
+            dest = list(zip(target_locations, target_sizes))
+        else:
+            dest = target_locations
+
         if self.backend_selector.mem_type == "FILE":
             file_paths = [self.file_manager.get_file_path(key) for key in keys]
-            success = self._execute_transfer(
+            success = self._execute_transfer(dest, file_paths, "READ")
         else:
-            success = self._execute_transfer(
-        return
+            success = self._execute_transfer(dest, keys, "READ")
+        return target_locations if success and not target_sizes else [None] * len(keys)
+
+    def set(
+        self,
+        key: str,
+        value: Optional[torch.Tensor] = None,
+        target_location: Optional[int] = None,
+        target_sizes: Optional[int] = None,
+    ) -> bool:
+        if target_location and target_sizes:
+            return self.batch_set([key], None, [target_location], [target_sizes])
+        else:
+            return self.batch_set([key], [value])
+
+    def batch_set(
+        self,
+        keys: List[str],
+        values: Optional[List[torch.Tensor]] = None,
+        target_locations: Optional[List[int]] = None,
+        target_sizes: Optional[List[int]] = None,
+    ) -> bool:
+        if not keys or (not values and (not target_locations or not target_sizes)):
+            logger.error("Keys or values were not passed")
+            return False
+
+        if not values:
+            values = list(zip(target_locations, target_sizes))
+
+        if self.backend_selector.mem_type == "FILE":
+            file_paths = []
+            for key in keys:
+                file_path = self.file_manager.get_file_path(key)
+                # New file per set, to be updated when partial writes is added to HiCache
+                if not self.file_manager.create_file(file_path):
+                    logger.error(f"Failed to create file {file_path}")
+                    return False
+                file_paths.append(file_path)
+            return self._execute_transfer(values, file_paths, "WRITE")
+        else:  # mem_type == "OBJ"
+            return self._execute_transfer(values, keys, "WRITE")
 
     def exists(self, key: str) -> bool:
         tuples = self.registration.create_query_tuples(
```
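Taken together, the two hunks above replace the tensor-only `set`/`get`/`batch_set`/`batch_get` with signatures that also accept raw host (address, length) pairs. A hedged usage sketch mirroring the new signatures (the `HiCacheNixl` constructor arguments and a working NIXL plugin are assumptions):

```python
import torch

# Hypothetical round trips against the new API shown above; construction is
# elided because it requires a NIXL plugin to be installed and selected.
# from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl
# cache = HiCacheNixl(file_path="/tmp/hicache_storage", plugin="auto")

def roundtrip_tensor(cache, key: str, value: torch.Tensor) -> torch.Tensor:
    # Tensor mode: the tensor itself is the value / target_location.
    assert cache.set(key, value)
    dst = torch.zeros_like(value, device="cpu")
    cache.get(key, dst)  # fills dst in place and returns it on success
    return dst

def roundtrip_raw(cache, key: str, value: torch.Tensor) -> torch.Tensor:
    # Raw mode: host addresses and byte lengths stand in for tensors.
    nbytes = value.numel() * value.element_size()
    assert cache.set(key, None, value.data_ptr(), nbytes)
    dst = torch.zeros_like(value, device="cpu")
    cache.get(key, dst.data_ptr(), nbytes)  # returns None; dst is still filled
    return dst
```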
sglang/srt/mem_cache/storage/nixl/nixl_utils.py

```diff
@@ -109,66 +109,35 @@ class NixlRegistration:
         return [(0, 0, key)]
 
     def _register_memory(
-        self, items, mem_type: str, desc: str
+        self,
+        items: Union[List[tuple], torch.Tensor, List[torch.Tensor]],
+        mem_type: Optional[str] = None,
     ) -> Optional[Any]:
         """Common registration logic for files, objects, and buffers.
         Args:
             items: List of tuples or tensors to register
-            mem_type: Memory type ("FILE", "OBJ")
-            desc: Description for logging
+            mem_type: Memory type ("FILE", "OBJ") or None for tensor or list of tensors
         """
-        try:
-            if not items:
-                return None
-
-            reg_descs = self.agent.get_reg_descs(items, mem_type)
-            if reg_descs is None:
-                logger.error("Failed to create registration descriptors")
-                return None
-
-            registered_memory = self.agent.register_memory(reg_descs)
-            if registered_memory:
-                return registered_memory
-            else:
-                logger.error("Failed to register with NIXL")
-                return None
-
-        except Exception as e:
-            logger.error(f"Failed to register {desc}: {e}")
+        if isinstance(items, list) and not items:
             return None
 
-    def register_buffers(
-        self, buffers: Union[torch.Tensor, List[torch.Tensor]]
-    ) -> Optional[Any]:
-        """Register tensors/buffers with NIXL."""
-        if isinstance(buffers, torch.Tensor):
-            buffers = [buffers]
-
-        if not buffers:
+        reg_descs = self.agent.get_reg_descs(items, mem_type)
+        if reg_descs is None:
+            logger.error("Failed to create registration descriptors")
             return None
 
-
-
-
-
-
-
-
-
-
-
-    ) -> Optional[Any]:
-        """Register objects with NIXL."""
-        if not keys:
+        try:
+            registered_memory = self.agent.register_memory(reg_descs)
+            return registered_memory  # Could be None in case of error
+        except Exception as e:
+            if not mem_type:
+                logger.error(f"Failed to register Tensors with NIXL: {e}")
+            else:
+                logger.error(
+                    f"Failed to register memory of type {mem_type} with NIXL: {e}"
+                )
             return None
 
-        # Create object tuples with proper sizes
-        tuples = [
-            (0, tensor.element_size() * tensor.numel() if tensor else 0, key)
-            for key, tensor in zip(keys, tensors or [None] * len(keys))
-        ]
-        return self._register_memory(tuples, "OBJ", "objects")
-
 
 class NixlFileManager:
     """Handles file system operations for NIXL."""
```
```diff
@@ -221,12 +190,9 @@ class NixlFileManager:
         return False
 
     def files_to_nixl_tuples(
-        self, file_paths: List[str], open_file: bool = True
+        self, file_paths: List[str]
     ) -> List[Tuple[int, int, int, str]]:
         """Create NIXL tuples (offset, length, fd, file_path) for given files."""
-        if not open_file:
-            return [(0, 0, 0, path) for path in file_paths]
-
         tuples = []
         for path in file_paths:
             if (fd := self.open_file(path)) is None:
```
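On the `nixl_utils.py` side, the separate `register_buffers`/`register_objects` helpers collapse into a single `_register_memory` that keys off `mem_type` (`None` meaning tensors). A rough sketch of the resulting dispatch, with `agent` standing in for the NIXL agent object (an assumption, not the exact API surface):

```python
from typing import Any, List, Optional, Union
import torch

# Sketch of the unified registration funnel from the diff above.
def register(
    agent,
    items: Union[torch.Tensor, List[torch.Tensor], List[tuple]],
    mem_type: Optional[str] = None,
) -> Optional[Any]:
    if isinstance(items, list) and not items:
        return None
    reg_descs = agent.get_reg_descs(items, mem_type)  # mem_type None => tensors
    if reg_descs is None:
        return None
    try:
        return agent.register_memory(reg_descs)  # may be None on error
    except Exception:
        return None

# Callers shape their inputs into tuples per memory type before registering:
#   files:        (offset, length, fd, file_path)  with mem_type="FILE"
#   objects:      (0, 0, key, "")                  with mem_type="OBJ"
#   raw host mem: (addr, length, 0, "")            with mem_type="DRAM"
```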
sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py

```diff
@@ -7,8 +7,11 @@ from unittest.mock import MagicMock
 
 import torch
 
-from sglang.srt.mem_cache.nixl.hicache_nixl import HiCacheNixl
-from sglang.srt.mem_cache.nixl.nixl_utils import NixlFileManager, NixlRegistration
+from sglang.srt.mem_cache.storage.nixl.hicache_nixl import HiCacheNixl
+from sglang.srt.mem_cache.storage.nixl.nixl_utils import (
+    NixlFileManager,
+    NixlRegistration,
+)
 
 
 class TestNixlUnified(unittest.TestCase):
```
```diff
@@ -88,8 +91,27 @@ class TestNixlUnified(unittest.TestCase):
 
         # Test get
         retrieved = self.hicache.get(key, dst_tensor)
+        self.verify_tensors_equal(value, dst_tensor)
         self.verify_tensors_equal(value, retrieved)
 
+        # Same test in addr,len mode with another key and dst_tensor
+        key2 = "test_key2"
+        dst_tensor2 = torch.zeros_like(value, device="cpu")
+        src_addr, src_len = value.data_ptr(), value.numel() * value.element_size()
+        dst_addr, dst_len = (
+            dst_tensor2.data_ptr(),
+            dst_tensor2.numel() * dst_tensor2.element_size(),
+        )
+
+        # Test set
+        self.assertTrue(self.hicache.set(key, None, src_addr, src_len))
+        self.assertTrue(self.hicache.exists(key))
+
+        # Test get
+        retrieved2 = self.hicache.get(key, dst_addr, dst_len)
+        self.assertTrue(retrieved2 == None)
+        self.verify_tensors_equal(value, dst_tensor2)
+
     def test_batch_set_get(self):
         """Test batch tensor set/get operations."""
         keys = ["key1", "key2", "key3"]
```
```diff
@@ -108,6 +130,23 @@ class TestNixlUnified(unittest.TestCase):
         retrieved = self.hicache.batch_get(keys, dst_tensors)
         self.verify_tensor_lists_equal(values, retrieved)
 
+        # Same test in addr,len mode with another key and dst_tensor
+        keys2 = ["key4", "key5", "key6"]
+        dst_tensors2 = [torch.zeros_like(v, device="cpu") for v in values]
+        src_addrs = [v.data_ptr() for v in values]
+        src_lens = [v.numel() * v.element_size() for v in values]
+        dst_addrs = [dt.data_ptr() for dt in dst_tensors2]
+        dst_lens = [dt.numel() * dt.element_size() for dt in dst_tensors2]
+
+        # Test batch set
+        self.assertTrue(self.hicache.batch_set(keys2, None, src_addrs, src_lens))
+        self.assertTrue(all(self.hicache.exists(key) for key in keys2))
+
+        # Test batch get
+        retrieved2 = self.hicache.batch_get(keys, dst_addrs, dst_lens)
+        self.assertTrue(all(ret == None for ret in retrieved2))
+        self.verify_tensor_lists_equal(values, dst_tensors2)
+
     def test_mixed_operations(self):
         """Test mixing single and batch operations."""
         # Test interleaved set/get operations
```
```diff
@@ -170,7 +209,7 @@ class TestNixlUnified(unittest.TestCase):
         self.file_manager.create_file(test_file)
 
         # Test tuple creation
-        tuples = self.file_manager.files_to_nixl_tuples([test_file], False)
+        tuples = self.file_manager.files_to_nixl_tuples([test_file])
         self.assertIsNotNone(tuples)
         self.assertTrue(len(tuples) > 0)
 
```
```diff
@@ -190,11 +229,11 @@ class TestNixlUnified(unittest.TestCase):
         tensor = torch.randn(10, 10)
 
         # Test buffer registration
-        self.assertIsNotNone(self.registration.register_buffers(tensor))
+        self.assertIsNotNone(self.hicache.register_buffers(tensor))
 
         # Test batch registration
         tensors = [torch.randn(5, 5) for _ in range(3)]
-        self.assertIsNotNone(self.registration.register_buffers(tensors))
+        self.assertIsNotNone(self.hicache.register_buffers(tensors))
 
     def test_register_files_with_tuples(self):
         """Test registration of files using NIXL tuples."""
```
```diff
@@ -203,8 +242,8 @@ class TestNixlUnified(unittest.TestCase):
         self.file_manager.create_file(file)
 
         # Create tuples and register
-        tuples = self.file_manager.files_to_nixl_tuples(files, False)
-        self.registration.register_files(tuples)
+        tuples = self.file_manager.files_to_nixl_tuples(files)
+        self.hicache.register_files(tuples)
 
         # Verify tuples
         self.assertEqual(len(tuples), len(files))
```
sglang/srt/model_executor/cuda_graph_runner.py

```diff
@@ -240,6 +240,8 @@ class CudaGraphRunner:
     def __init__(self, model_runner: ModelRunner):
         # Parse args
        self.model_runner = model_runner
+        self.device = model_runner.device
+        self.device_module = torch.get_device_module(self.device)
         self.graphs = {}
         self.output_buffers = {}
         self.enable_torch_compile = model_runner.server_args.enable_torch_compile
```
```diff
@@ -305,13 +307,15 @@ class CudaGraphRunner:
             self.model_runner.lora_manager.init_cuda_graph_batch_info(self.max_bs)
 
         # Graph inputs
-        with torch.device("cuda"):
+        with torch.device(self.device):
            self.input_ids = torch.zeros((self.max_num_token,), dtype=torch.int64)
            self.req_pool_indices = torch.zeros((self.max_bs,), dtype=torch.int32)
            self.seq_lens = torch.full(
                (self.max_bs,), self.seq_len_fill_value, dtype=torch.int32
            )
-            self.out_cache_loc = torch.zeros((self.max_num_token,), dtype=torch.int64)
+            self.out_cache_loc = torch.zeros(
+                (self.max_num_token,), dtype=self._cache_loc_dtype()
+            )
            self.positions = torch.zeros((self.max_num_token,), dtype=torch.int64)
            self.mrope_positions = torch.zeros((3, self.max_bs), dtype=torch.int64)
            self.num_token_non_padded = torch.zeros((1,), dtype=torch.int32)
```
```diff
@@ -366,12 +370,12 @@ class CudaGraphRunner:
                 * self.num_tokens_per_bs
             ),
             dtype=torch.bool,
-            device="cuda",
+            device=self.device,
         )
         self.next_token_logits_buffer = torch.zeros(
             (self.max_num_token, self.model_runner.model_config.vocab_size),
             dtype=torch.float,
-            device="cuda",
+            device=self.device,
         )
 
         # Capture
```
```diff
@@ -383,6 +387,9 @@ class CudaGraphRunner:
                 f"Capture cuda graph failed: {e}\n{CUDA_GRAPH_CAPTURE_FAILED_MSG}"
             )
 
+    def _cache_loc_dtype(self):
+        return torch.int64
+
     def can_run(self, forward_batch: ForwardBatch):
         if self.require_mlp_tp_gather:
             cuda_graph_bs = (
```
```diff
@@ -502,8 +509,16 @@ class CudaGraphRunner:
         )
         logger.info(log_message)
 
+    def _capture_graph(self, graph, pool, stream, run_once_fn):
+        with self.device_module.graph(graph, pool=pool, stream=stream):
+            out = run_once_fn()
+        return out
+
+    def _create_device_graph(self):
+        return torch.cuda.CUDAGraph()
+
     def capture_one_batch_size(self, bs: int, forward: Callable):
-        graph = torch.cuda.CUDAGraph()
+        graph = self._create_device_graph()
         stream = self.stream
         num_tokens = bs * self.num_tokens_per_bs
 
```
```diff
@@ -643,19 +658,17 @@ class CudaGraphRunner:
             return logits_output_or_pp_proxy_tensors
 
         for _ in range(2):
-            torch.cuda.synchronize()
+            self.device_module.synchronize()
             self.model_runner.tp_group.barrier()
-
             run_once()
 
         if get_global_graph_memory_pool() is None:
-            set_global_graph_memory_pool(torch.cuda.graph_pool_handle())
+            set_global_graph_memory_pool(self.device_module.graph_pool_handle())
             # Set graph pool id globally to be able to use symmetric memory
             set_graph_pool_id(get_global_graph_memory_pool())
-
-        with torch.cuda.graph(
-            graph, pool=get_global_graph_memory_pool(), stream=stream
-        ):
-            out = run_once()
+        out = self._capture_graph(
+            graph, get_global_graph_memory_pool(), stream, run_once
+        )
 
         return graph, out
 
```
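The `CudaGraphRunner` hunks above all serve one refactor: device identity and the graph-capture entry points are now resolved through `self.device_module` and small overridable hooks, which is presumably what the new `sglang/srt/model_executor/npu_graph_runner.py` builds on. A minimal sketch of the pattern, assuming only public PyTorch APIs (the subclass and its NPU graph type are illustrative, not taken from the diff):

```python
import torch

class GraphRunnerSketch:
    def __init__(self, device: str = "cuda"):
        self.device = device
        # torch.get_device_module("cuda") returns the torch.cuda module, etc.
        self.device_module = torch.get_device_module(device)

    def _create_device_graph(self):
        # Default backend: CUDA graphs. Subclasses override this hook.
        return torch.cuda.CUDAGraph()

    def _capture_graph(self, graph, pool, stream, run_once_fn):
        # Device-agnostic capture: the device module supplies the context manager.
        with self.device_module.graph(graph, pool=pool, stream=stream):
            return run_once_fn()

class NpuGraphRunnerSketch(GraphRunnerSketch):
    # Illustrative override; the real npu_graph_runner.py may differ.
    def _create_device_graph(self):
        return torch.npu.NPUGraph()  # assumes torch_npu is installed
```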
sglang/srt/model_executor/forward_batch_info.py

```diff
@@ -241,6 +241,9 @@ class ForwardBatch:
     prefix_chunk_num_tokens: Optional[List[int]] = None
     # KV Indices for each chunk
     prefix_chunk_kv_indices: Optional[List[torch.Tensor]] = None
+    # For MLA chunked prefix cache used in chunked prefill
+    # Tell attention backend whether lse needs to be returned
+    mha_return_lse: Optional[bool] = None
 
     # For multimodal
     mm_inputs: Optional[List[MultimodalInputs]] = None
```
```diff
@@ -649,7 +652,7 @@ class ForwardBatch:
             num_tokens = global_num_tokens[0]
 
         self.global_dp_buffer_len = buffer_len
-        set_dp_buffer_len(buffer_len, num_tokens)
+        set_dp_buffer_len(buffer_len, num_tokens, global_num_tokens)
 
         bs = self.batch_size
 
```