sglang 0.4.5__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -4
- sglang/bench_one_batch.py +23 -2
- sglang/bench_serving.py +6 -4
- sglang/lang/backend/anthropic.py +0 -4
- sglang/lang/backend/base_backend.py +1 -1
- sglang/lang/backend/openai.py +1 -1
- sglang/lang/backend/vertexai.py +0 -1
- sglang/lang/compiler.py +1 -7
- sglang/lang/tracer.py +3 -7
- sglang/srt/_custom_ops.py +0 -2
- sglang/srt/configs/model_config.py +37 -5
- sglang/srt/constrained/base_grammar_backend.py +26 -5
- sglang/srt/constrained/llguidance_backend.py +1 -0
- sglang/srt/constrained/outlines_backend.py +1 -0
- sglang/srt/constrained/outlines_jump_forward.py +14 -1
- sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
- sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
- sglang/srt/constrained/xgrammar_backend.py +27 -4
- sglang/srt/custom_op.py +0 -62
- sglang/srt/disaggregation/base/__init__.py +8 -0
- sglang/srt/disaggregation/base/conn.py +113 -0
- sglang/srt/disaggregation/decode.py +80 -11
- sglang/srt/disaggregation/mini_lb.py +58 -123
- sglang/srt/disaggregation/mooncake/__init__.py +6 -0
- sglang/srt/disaggregation/mooncake/conn.py +585 -0
- sglang/srt/disaggregation/mooncake/transfer_engine.py +77 -0
- sglang/srt/disaggregation/prefill.py +82 -22
- sglang/srt/disaggregation/utils.py +46 -0
- sglang/srt/entrypoints/EngineBase.py +53 -0
- sglang/srt/entrypoints/engine.py +36 -8
- sglang/srt/entrypoints/http_server.py +37 -8
- sglang/srt/entrypoints/http_server_engine.py +142 -0
- sglang/srt/entrypoints/verl_engine.py +42 -13
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/activation.py +6 -8
- sglang/srt/layers/attention/flashattention_backend.py +430 -257
- sglang/srt/layers/attention/flashinfer_backend.py +18 -9
- sglang/srt/layers/attention/torch_native_backend.py +6 -1
- sglang/srt/layers/attention/triton_backend.py +6 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/dp_attention.py +2 -4
- sglang/srt/layers/elementwise.py +15 -2
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +18 -3
- sglang/srt/layers/moe/ep_moe/layer.py +15 -29
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
- sglang/srt/layers/moe/fused_moe_native.py +4 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +34 -34
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +46 -34
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/router.py +7 -1
- sglang/srt/layers/moe/topk.py +63 -45
- sglang/srt/layers/parameter.py +0 -2
- sglang/srt/layers/quantization/__init__.py +13 -5
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +12 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -77
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
- sglang/srt/layers/quantization/fp8.py +131 -136
- sglang/srt/layers/quantization/fp8_kernel.py +328 -46
- sglang/srt/layers/quantization/fp8_utils.py +206 -253
- sglang/srt/layers/quantization/kv_cache.py +43 -52
- sglang/srt/layers/quantization/modelopt_quant.py +271 -4
- sglang/srt/layers/quantization/moe_wna16.py +2 -0
- sglang/srt/layers/quantization/utils.py +5 -11
- sglang/srt/layers/quantization/w8a8_fp8.py +156 -4
- sglang/srt/layers/quantization/w8a8_int8.py +8 -7
- sglang/srt/layers/radix_attention.py +28 -1
- sglang/srt/layers/rotary_embedding.py +15 -3
- sglang/srt/layers/sampler.py +5 -10
- sglang/srt/lora/backend/base_backend.py +18 -2
- sglang/srt/lora/backend/flashinfer_backend.py +1 -1
- sglang/srt/lora/backend/triton_backend.py +1 -1
- sglang/srt/lora/layers.py +1 -1
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/lora/lora_manager.py +1 -1
- sglang/srt/managers/detokenizer_manager.py +0 -1
- sglang/srt/managers/io_struct.py +255 -97
- sglang/srt/managers/mm_utils.py +7 -5
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +117 -79
- sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
- sglang/srt/managers/multimodal_processors/mllama4.py +21 -36
- sglang/srt/managers/schedule_batch.py +64 -25
- sglang/srt/managers/scheduler.py +80 -82
- sglang/srt/managers/tokenizer_manager.py +18 -3
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +5 -1
- sglang/srt/mem_cache/memory_pool.py +21 -3
- sglang/srt/metrics/collector.py +9 -0
- sglang/srt/model_executor/cuda_graph_runner.py +9 -6
- sglang/srt/model_executor/forward_batch_info.py +234 -15
- sglang/srt/model_executor/model_runner.py +67 -35
- sglang/srt/model_loader/loader.py +31 -4
- sglang/srt/model_loader/weight_utils.py +4 -2
- sglang/srt/models/baichuan.py +2 -0
- sglang/srt/models/bert.py +398 -0
- sglang/srt/models/chatglm.py +1 -0
- sglang/srt/models/commandr.py +1 -0
- sglang/srt/models/dbrx.py +1 -0
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +74 -70
- sglang/srt/models/deepseek_v2.py +494 -366
- sglang/srt/models/exaone.py +1 -0
- sglang/srt/models/gemma.py +1 -0
- sglang/srt/models/gemma2.py +1 -0
- sglang/srt/models/gemma3_causal.py +1 -0
- sglang/srt/models/gpt2.py +1 -0
- sglang/srt/models/gpt_bigcode.py +1 -0
- sglang/srt/models/granite.py +1 -0
- sglang/srt/models/grok.py +1 -0
- sglang/srt/models/internlm2.py +1 -0
- sglang/srt/models/llama.py +6 -5
- sglang/srt/models/llama4.py +101 -34
- sglang/srt/models/minicpm.py +1 -0
- sglang/srt/models/minicpm3.py +30 -200
- sglang/srt/models/mixtral.py +1 -0
- sglang/srt/models/mixtral_quant.py +1 -0
- sglang/srt/models/mllama.py +51 -8
- sglang/srt/models/mllama4.py +102 -29
- sglang/srt/models/olmo.py +1 -0
- sglang/srt/models/olmo2.py +1 -0
- sglang/srt/models/olmoe.py +1 -0
- sglang/srt/models/phi3_small.py +1 -0
- sglang/srt/models/qwen.py +1 -0
- sglang/srt/models/qwen2.py +5 -1
- sglang/srt/models/qwen2_5_vl.py +35 -70
- sglang/srt/models/qwen2_moe.py +15 -13
- sglang/srt/models/qwen2_vl.py +27 -25
- sglang/srt/models/qwen3.py +335 -0
- sglang/srt/models/qwen3_moe.py +423 -0
- sglang/srt/models/stablelm.py +1 -0
- sglang/srt/models/xverse.py +1 -0
- sglang/srt/models/xverse_moe.py +1 -0
- sglang/srt/openai_api/adapter.py +4 -1
- sglang/srt/patch_torch.py +11 -0
- sglang/srt/reasoning_parser.py +0 -1
- sglang/srt/sampling/sampling_batch_info.py +2 -3
- sglang/srt/server_args.py +55 -19
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
- sglang/srt/speculative/eagle_utils.py +1 -11
- sglang/srt/speculative/eagle_worker.py +10 -9
- sglang/srt/utils.py +136 -10
- sglang/test/attention/test_flashattn_backend.py +259 -221
- sglang/test/attention/test_flashattn_mla_backend.py +285 -0
- sglang/test/attention/test_prefix_chunk_info.py +224 -0
- sglang/test/runners.py +5 -1
- sglang/test/test_block_fp8.py +224 -0
- sglang/test/test_custom_ops.py +1 -1
- sglang/test/test_utils.py +19 -8
- sglang/version.py +1 -1
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +15 -5
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +162 -147
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
- sglang/lang/__init__.py +0 -0
- sglang/srt/disaggregation/conn.py +0 -81
- sglang/srt/lora/backend/__init__.py +0 -25
- sglang/srt/server.py +0 -18
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,585 @@
|
|
1
|
+
from __future__ import annotations
|
2
|
+
|
3
|
+
import asyncio
|
4
|
+
import dataclasses
|
5
|
+
import logging
|
6
|
+
import queue
|
7
|
+
import socket
|
8
|
+
import struct
|
9
|
+
import threading
|
10
|
+
from functools import cache
|
11
|
+
from typing import Dict, List, Optional, Tuple, Union
|
12
|
+
|
13
|
+
import numpy as np
|
14
|
+
import numpy.typing as npt
|
15
|
+
import requests
|
16
|
+
import zmq
|
17
|
+
from aiohttp import web
|
18
|
+
|
19
|
+
from sglang.srt.disaggregation.base.conn import (
|
20
|
+
BaseKVBootstrapServer,
|
21
|
+
BaseKVManager,
|
22
|
+
BaseKVReceiver,
|
23
|
+
BaseKVSender,
|
24
|
+
KVArgs,
|
25
|
+
KVPoll,
|
26
|
+
)
|
27
|
+
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
|
28
|
+
from sglang.srt.disaggregation.utils import DisaggregationMode
|
29
|
+
from sglang.srt.server_args import ServerArgs
|
30
|
+
from sglang.srt.utils import get_free_port, get_ip, get_local_ip_by_remote
|
31
|
+
|
32
|
+
logger = logging.getLogger(__name__)
|
33
|
+
|
34
|
+
|
35
|
+
def group_concurrent_contiguous(
    src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
    """Split paired index arrays into runs that are contiguous on both sides.

    A new group starts wherever either ``src_indices`` or ``dst_indices``
    does not advance by exactly 1 from the previous position, so each
    returned ``(src, dst)`` group pair can be moved with one contiguous copy.

    Args:
        src_indices: Source slot indices.
        dst_indices: Destination slot indices, paired element-wise with
            ``src_indices``.

    Returns:
        ``(src_groups, dst_groups)``: two lists of equal length whose i-th
        entries are the i-th contiguous run (as arrays). Both lists are empty
        when the inputs are empty.

    Raises:
        ValueError: If the two inputs have different lengths.
    """
    src = np.asarray(src_indices)
    dst = np.asarray(dst_indices)
    if len(src) != len(dst):
        raise ValueError(
            f"src_indices and dst_indices differ in length: {len(src)} != {len(dst)}"
        )
    if len(src) == 0:
        return [], []

    # np.diff(...)[k] compares positions k and k+1; a break there means the
    # next group starts at k+1.
    breaks = np.flatnonzero((np.diff(src) != 1) | (np.diff(dst) != 1)) + 1
    return np.split(src, breaks), np.split(dst, breaks)
|
59
|
+
|
60
|
+
|
61
|
+
@dataclasses.dataclass
class TransferKVChunk:
    """One unit of work queued for the prefill-side transfer thread.

    Describes a slice of a request's KV cache that must be written to the
    decode rank registered under ``room``.
    """

    # Bootstrap room (request identifier) the chunk belongs to.
    room: int
    # Prefill-side KV slot indices carried by this chunk.
    prefill_kv_indices: npt.NDArray[np.int64]
    # Selects the matching entries of the receiver's dst_kv_indices.
    index_slice: slice
    # True for the final chunk of a request; that chunk also ships aux data.
    is_last: bool
    # Prefill-side aux slot; required (non-None) when is_last is True.
    prefill_aux_index: Optional[int]
|
68
|
+
|
69
|
+
|
70
|
+
@dataclasses.dataclass
class TransferInfo:
    """Decode-side destination layout for one room, as seen by prefill.

    Instances are built from the multipart bootstrap message a decode
    receiver sends to a prefill rank; see ``from_zmq`` for the frame layout.
    """

    room: int
    endpoint: str
    dst_port: int
    mooncake_session_id: str
    dst_kv_ptrs: list[int]
    dst_kv_indices: npt.NDArray[np.int64]
    dst_aux_ptrs: list[int]
    dst_aux_index: int

    @classmethod
    def from_zmq(cls, msg: List[bytes]):
        """Decode a bootstrap message into a ``TransferInfo``.

        Expected frames (text frames are ASCII):
            0 room, 1 decode IP, 2 decode listening port, 3 Mooncake session
            id, 4 packed uint64 KV buffer pointers, 5 raw int64 KV indices,
            6 packed uint64 aux buffer pointers, 7 aux index.
        """

        def as_text(frame: bytes) -> str:
            return frame.decode("ascii")

        def as_pointers(frame: bytes) -> list[int]:
            # Native-order unsigned 64-bit integers, 8 bytes each.
            count = len(frame) // 8
            return list(struct.unpack(f"{count}Q", frame))

        room_frame, endpoint_frame, port_frame, session_frame = msg[0:4]
        kv_ptr_frame, kv_index_frame, aux_ptr_frame, aux_index_frame = msg[4:8]

        return cls(
            room=int(as_text(room_frame)),
            endpoint=as_text(endpoint_frame),
            dst_port=int(as_text(port_frame)),
            mooncake_session_id=as_text(session_frame),
            dst_kv_ptrs=as_pointers(kv_ptr_frame),
            dst_kv_indices=np.frombuffer(kv_index_frame, dtype=np.int64),
            dst_aux_ptrs=as_pointers(aux_ptr_frame),
            dst_aux_index=int(as_text(aux_index_frame)),
        )
|
93
|
+
|
94
|
+
|
95
|
+
class MooncakeKVManager(BaseKVManager):
    """Per-rank owner of the Mooncake transfer engine and KV transfer state.

    In PREFILL mode it listens for bootstrap messages from decode ranks and
    writes KV/aux data to them from a background transfer thread. In DECODE
    mode it listens for the status updates prefill ranks send back.
    ``request_status`` maps a bootstrap room to its latest ``KVPoll`` state.
    """

    def __init__(
        self,
        args: KVArgs,
        disaggregation_mode: DisaggregationMode,
        server_args: ServerArgs,
    ):
        self.kv_args = args
        self.engine = MooncakeTransferEngine(
            hostname=get_local_ip_by_remote(),
            gpu_id=self.kv_args.gpu_id,
            ib_device=self.kv_args.ib_device,
        )
        self.disaggregation_mode = disaggregation_mode
        # for p/d multi node infer
        self.bootstrap_port = server_args.disaggregation_bootstrap_port
        self.dist_init_addr = server_args.dist_init_addr
        self.request_status: Dict[int, KVPoll] = {}
        self.rank_port = None
        # One ZMQ context for this manager. Outbound PUSH sockets are cached
        # per endpoint in _connect, replacing functools.cache, which kept the
        # manager alive in a global cache and built a new context per socket.
        self._zmq_ctx = zmq.Context()
        self._push_sockets = {}
        self.server_socket = self._zmq_ctx.socket(zmq.PULL)
        self.register_buffer_to_engine()
        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            self.transfer_queue = queue.Queue()
            self.transfer_infos: Dict[int, TransferInfo] = {}
            self.start_prefill_thread()
            self._register_to_bootstrap()
        elif self.disaggregation_mode == DisaggregationMode.DECODE:
            self.start_decode_thread()
            self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
        else:
            raise ValueError(
                f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
            )

    def register_buffer_to_engine(self):
        """Register every KV and aux buffer with the transfer engine."""
        for kv_data_ptr, kv_data_len in zip(
            self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens
        ):
            self.engine.register(kv_data_ptr, kv_data_len)

        for aux_data_ptr, aux_data_len in zip(
            self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
        ):
            self.engine.register(aux_data_ptr, aux_data_len)

    def _connect(self, endpoint: str):
        """Return a cached PUSH socket connected to ``endpoint``.

        Within this module it is only used from the prefill transfer thread,
        so the cache is not locked.
        """
        sock = self._push_sockets.get(endpoint)
        if sock is None:
            sock = self._zmq_ctx.socket(zmq.PUSH)
            sock.connect(endpoint)
            self._push_sockets[endpoint] = sock
        return sock

    def send_kvcache(
        self,
        mooncake_session_id: str,
        prefill_kv_indices: npt.NDArray[np.int64],
        dst_kv_ptrs: list[int],
        dst_kv_indices: npt.NDArray[np.int64],
    ):
        """Write the given KV slots of every layer to the decode buffers.

        Slots are grouped into runs that are contiguous on both sides, so each
        run is one ``transfer_sync`` call per layer.

        Returns:
            0 on success, otherwise the first non-zero engine status.
        """
        if len(prefill_kv_indices) == 0:
            # Nothing to move; also avoids grouping an empty index array.
            return 0

        # group by indices
        prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(
            prefill_kv_indices, dst_kv_indices
        )

        num_layers = len(self.kv_args.kv_data_ptrs)
        for layer_id in range(num_layers):
            src_ptr = self.kv_args.kv_data_ptrs[layer_id]
            dst_ptr = dst_kv_ptrs[layer_id]
            item_len = self.kv_args.kv_item_lens[layer_id]

            for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
                src_addr = src_ptr + int(prefill_index[0]) * item_len
                dst_addr = dst_ptr + int(decode_index[0]) * item_len
                length = item_len * len(prefill_index)

                # TODO: make async later
                status = self.engine.transfer_sync(
                    mooncake_session_id, src_addr, dst_addr, length
                )
                if status != 0:
                    return status

        return 0

    def send_aux(
        self,
        mooncake_session_id: str,
        prefill_aux_index: int,
        dst_aux_ptrs: list[int],
        dst_aux_index: int,
    ):
        """Write one aux item from the first aux buffer to the decode side.

        Returns:
            The engine status of the single ``transfer_sync`` call.
        """
        aux_item_len = self.kv_args.aux_item_lens[0]
        prefill_aux_addr = (
            self.kv_args.aux_data_ptrs[0] + prefill_aux_index * aux_item_len
        )
        decode_aux_addr = dst_aux_ptrs[0] + dst_aux_index * aux_item_len
        # TODO: mooncake transfer engine can do async transfer. Do async later
        # Not sure about the amount of aux data, maybe transfer it by zmq is more effective
        status = self.engine.transfer_sync(
            mooncake_session_id, prefill_aux_addr, decode_aux_addr, aux_item_len
        )
        return status

    def sync_status_to_decode_endpoint(self, remote: str, dst_port: int, room: int):
        """Push the current status of ``room`` to the decode rank's PULL socket.

        Any ``:port`` suffix on ``remote`` is dropped in favour of ``dst_port``.
        """
        if ":" in remote:
            remote = remote.split(":")[0]
        self._connect("tcp://" + remote + ":" + str(dst_port)).send_multipart(
            [
                str(room).encode("ascii"),
                # int() guarantees the numeric text the decode thread parses
                # with int(), even if KVPoll members stringify by name.
                str(int(self.request_status[room])).encode("ascii"),
            ]
        )

    def start_prefill_thread(self):
        """Bind the PULL socket and start the bootstrap and transfer threads."""
        self.rank_port = get_free_port()
        self.server_socket.bind(f"tcp://{get_local_ip_by_remote()}:{self.rank_port}")

        def bootstrap_thread():
            """Receive pre-alloc notifications from decode ranks.

            Records the room's TransferInfo and moves it from
            KVPoll.Bootstrapping to KVPoll.WaitingForInput.
            """
            while True:
                waiting_req_bytes = self.server_socket.recv_multipart()
                try:
                    room = waiting_req_bytes[0].decode("ascii")
                    if room == "None":
                        continue
                    room = int(room)
                except (IndexError, ValueError):
                    logger.error(
                        f"Dropping bootstrap message with invalid room: {waiting_req_bytes[:1]!r}"
                    )
                    continue

                try:
                    transfer_info = TransferInfo.from_zmq(waiting_req_bytes)
                except Exception:
                    # e.g. a missing frame or an aux index sent as "None":
                    # fail the room instead of letting this thread die.
                    logger.exception(f"Malformed bootstrap message for room {room}")
                    self.request_status[room] = KVPoll.Failed
                    continue

                self.transfer_infos[room] = transfer_info
                # NOTE: after bootstrapping we can mark the req as waiting for input
                self.request_status[room] = KVPoll.WaitingForInput

        def transfer_thread():
            # TODO: Shall we use KVPoll.Transferring state?
            while True:
                kv_chunk: TransferKVChunk = self.transfer_queue.get()
                try:
                    self._transfer_chunk(kv_chunk)
                except Exception:
                    # engine.transfer_sync raises on engine errors; without this
                    # the thread would die and every later request would hang.
                    logger.exception(
                        f"KV transfer failed for bootstrap room {kv_chunk.room}"
                    )
                    self._mark_transfer_failed(kv_chunk.room)

        # Daemon threads: both loops run forever and must not block exit.
        threading.Thread(target=bootstrap_thread, daemon=True).start()
        threading.Thread(target=transfer_thread, daemon=True).start()

    def _transfer_chunk(self, kv_chunk: TransferKVChunk):
        """Write one queued KV chunk (plus aux data if it is the last one)."""
        req = self.transfer_infos.get(kv_chunk.room)
        if req is None:
            # Either the room already failed (its info was dropped) or decode
            # never bootstrapped it; either way there is nowhere to write.
            if self.request_status.get(kv_chunk.room) != KVPoll.Failed:
                logger.error(
                    f"No bootstrap info for room {kv_chunk.room}; dropping KV chunk"
                )
                self.request_status[kv_chunk.room] = KVPoll.Failed
            return

        chunked_dst_kv_indice = req.dst_kv_indices[kv_chunk.index_slice]
        if len(chunked_dst_kv_indice) != len(kv_chunk.prefill_kv_indices):
            raise ValueError(
                f"KV index length mismatch for room {req.room}: "
                f"{len(kv_chunk.prefill_kv_indices)} prefill vs "
                f"{len(chunked_dst_kv_indice)} decode"
            )

        ret = self.send_kvcache(
            req.mooncake_session_id,
            kv_chunk.prefill_kv_indices,
            req.dst_kv_ptrs,
            chunked_dst_kv_indice,
        )
        if ret != 0:
            self._mark_transfer_failed(req.room)
            return

        if kv_chunk.is_last:
            # Only the last chunk we need to send the aux data
            ret = self.send_aux(
                req.mooncake_session_id,
                kv_chunk.prefill_aux_index,
                req.dst_aux_ptrs,
                req.dst_aux_index,
            )
            if ret != 0:
                self._mark_transfer_failed(req.room)
                return
            self.request_status[req.room] = KVPoll.Success
            self.sync_status_to_decode_endpoint(req.endpoint, req.dst_port, req.room)
            self.transfer_infos.pop(req.room, None)

    def _mark_transfer_failed(self, room: int):
        """Mark ``room`` Failed, notify its decode rank, and drop its info.

        Dropping the TransferInfo makes _transfer_chunk skip any later chunks
        of the same room instead of transferring them.
        """
        self.request_status[room] = KVPoll.Failed
        req = self.transfer_infos.pop(room, None)
        if req is None:
            return
        try:
            self.sync_status_to_decode_endpoint(req.endpoint, req.dst_port, room)
        except Exception:
            logger.exception(f"Failed to report failure of room {room} to decode")

    def start_decode_thread(self):
        """Bind the PULL socket and start the status-listener thread."""
        self.rank_port = get_free_port()
        self.server_socket.bind(f"tcp://{get_local_ip_by_remote()}:{self.rank_port}")

        def decode_thread():
            while True:
                frames = self.server_socket.recv_multipart()
                try:
                    bootstrap_room, status = frames
                    status = int(status.decode("ascii"))
                    bootstrap_room = int(bootstrap_room.decode("ascii"))
                except ValueError:
                    # Wrong frame count or non-numeric text; keep listening.
                    logger.error(f"Dropping malformed status message: {frames!r}")
                    continue
                self.request_status[bootstrap_room] = status

        threading.Thread(target=decode_thread, daemon=True).start()

    def add_transfer_request(
        self,
        bootstrap_room: int,
        kv_indices: npt.NDArray[np.int64],
        index_slice: slice,
        is_last: bool,
        aux_index: Optional[int] = None,
    ):
        """Queue one KV chunk of ``bootstrap_room`` for the transfer thread.

        Args:
            bootstrap_room: Room the chunk belongs to.
            kv_indices: Prefill-side KV slot indices of the chunk.
            index_slice: Slice into the decode side's dst_kv_indices.
            is_last: Whether this is the request's final chunk.
            aux_index: Prefill aux slot; required when ``is_last`` is True.
        """
        assert self.disaggregation_mode == DisaggregationMode.PREFILL
        assert not is_last or aux_index is not None

        # Update the status *before* enqueueing: once queued, the transfer
        # thread may finish the chunk and write Success/Failed, which a later
        # write here would overwrite. A room that already failed stays failed.
        if self.request_status.get(bootstrap_room) != KVPoll.Failed:
            self.request_status[bootstrap_room] = KVPoll.WaitingForInput
        self.transfer_queue.put(
            TransferKVChunk(
                room=bootstrap_room,
                prefill_kv_indices=kv_indices,
                index_slice=index_slice,
                is_last=is_last,
                prefill_aux_index=aux_index,
            )
        )

    def check_status(self, bootstrap_room: int):
        """Return the recorded status of ``bootstrap_room`` (KeyError if unknown)."""
        # TODO: do we really need the poll()?

        return self.request_status[bootstrap_room]

    def update_status(self, bootstrap_room: int, status: KVPoll):
        """Record ``status``, never moving an existing room to a lower value."""
        if bootstrap_room not in self.request_status:
            self.request_status[bootstrap_room] = status
        else:
            # NOTE: The prefill engine could recv bootstrapping first
            self.request_status[bootstrap_room] = max(
                self.request_status[bootstrap_room], status
            )

    def get_session_id(self):
        """Return the transfer engine's ``host:rpc_port`` session id."""
        return self.engine.get_session_id()

    def _register_to_bootstrap(self):
        """Register KVSender to bootstrap server via HTTP POST."""
        if self.dist_init_addr:
            ip_address = socket.gethostbyname(self.dist_init_addr.split(":")[0])
        else:
            ip_address = get_ip()

        bootstrap_server_url = f"{ip_address}:{self.bootstrap_port}"
        url = f"http://{bootstrap_server_url}/route"
        payload = {
            "role": "Prefill",
            "rank_ip": get_local_ip_by_remote(),
            "rank_port": self.rank_port,
            "engine_rank": self.kv_args.engine_rank,
        }

        try:
            response = requests.put(url, json=payload)
            if response.status_code == 200:
                logger.debug("Prefill successfully registered to bootstrap server.")
            else:
                logger.error(
                    f"Prefill Failed to connect to bootstrap server: {response.status_code}, {response.text}"
                )
        except Exception as e:
            logger.error(f"Prefill Failed to register to bootstrap server: {e}")
|
349
|
+
|
350
|
+
|
351
|
+
class MooncakeKVSender(BaseKVSender):
    """Prefill-side handle that pushes one request's KV cache to decode.

    All real work is delegated to the shared ``MooncakeKVManager``; this
    object only remembers the room, the aux slot and the session id.
    """

    def __init__(
        self, mgr: MooncakeKVManager, bootstrap_addr: str, bootstrap_room: int
    ):
        self.kv_mgr = mgr
        self.bootstrap_room = bootstrap_room
        # Register the room before anything else so poll() has a status.
        mgr.update_status(bootstrap_room, KVPoll.Bootstrapping)
        self.aux_index = None
        self.bootstrap_server_url = bootstrap_addr
        self.session_id = mgr.get_session_id()

    def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
        """Remember the KV length and the aux slot used by the final chunk."""
        self.num_kv_indices = num_kv_indices
        self.aux_index = aux_index

    def send(
        self,
        kv_indices: npt.NDArray[np.int64],
        index_slice: slice,
        is_last: bool,
    ):
        """Queue a KV chunk; only the final chunk carries the aux index."""
        final = bool(is_last)
        self.kv_mgr.add_transfer_request(
            self.bootstrap_room,
            kv_indices,
            index_slice,
            final,
            aux_index=self.aux_index if final else None,
        )

    def poll(self) -> KVPoll:
        """Return the manager's current status for this room."""
        return self.kv_mgr.check_status(self.bootstrap_room)

    def failure_exception(self):
        raise Exception("Fake KVSender Exception")
|
391
|
+
|
392
|
+
|
393
|
+
class MooncakeKVReceiver(BaseKVReceiver):
    """Decode-side handle that asks a prefill rank to push KV data here.

    The prefill endpoint is resolved through the bootstrap server once per
    (bootstrap address, engine rank) and memoised in the manager's
    ``connection_pool``. Outbound PUSH sockets are class-level: one per
    endpoint, shared by all receivers, each guarded by its own lock.
    """

    # Class-level ZMQ context and per-endpoint socket/lock caches shared by
    # every receiver instance; _global_lock guards cache population.
    _ctx = zmq.Context()
    _socket_cache = {}
    _socket_locks = {}
    _global_lock = threading.Lock()

    def __init__(
        self,
        mgr: MooncakeKVManager,
        bootstrap_addr: str,
        bootstrap_room: Optional[int] = None,
    ):
        self.bootstrap_room = bootstrap_room
        self.bootstrap_addr = bootstrap_addr
        self.kv_mgr = mgr
        self.session_id = mgr.get_session_id()
        mgr.update_status(bootstrap_room, KVPoll.Bootstrapping)

        engine_rank = mgr.kv_args.engine_rank
        pool = mgr.connection_pool
        # NOTE: key distinguished by bootstrap_addr and engine_rank
        bootstrap_key = f"{self.bootstrap_addr}_{engine_rank}"

        if bootstrap_key in pool:
            self.bootstrap_info = pool[bootstrap_key]
        else:
            self.bootstrap_info = self._get_bootstrap_info_from_server(engine_rank)
            if self.bootstrap_info is not None:
                pool[bootstrap_key] = self.bootstrap_info
            else:
                logger.error(
                    f"Could not fetch bootstrap info for engine rank: {engine_rank}"
                )

        assert self.bootstrap_info is not None
        mgr.update_status(bootstrap_room, KVPoll.WaitingForInput)

    def _get_bootstrap_info_from_server(self, engine_rank):
        """Fetch the bootstrap info from the bootstrap server.

        Returns the decoded JSON body on HTTP 200, otherwise logs and returns
        None (including on any request or decoding exception).
        """
        try:
            url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}"
            response = requests.get(url)
            if response.status_code != 200:
                logger.error(
                    f"Failed to get prefill server info: {response.status_code}, {response.text}"
                )
                return None
            return response.json()
        except Exception as e:
            logger.error(f"Error fetching prefill info from bootstrap: {e}")
            return None

    @classmethod
    def _connect(cls, endpoint: str):
        """Return ``(socket, lock)`` for ``endpoint``, creating them once."""
        with cls._global_lock:
            sock = cls._socket_cache.get(endpoint)
            if sock is None:
                sock = cls._ctx.socket(zmq.PUSH)
                sock.connect(endpoint)
                cls._socket_cache[endpoint] = sock
                cls._socket_locks[endpoint] = threading.Lock()
            return sock, cls._socket_locks[endpoint]

    def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
        """Send this room's destination layout to the prefill rank.

        Frame order matches ``TransferInfo.from_zmq``: room, local IP, this
        manager's listening port, Mooncake session id, packed KV pointers,
        raw KV indices, packed aux pointers, aux index.
        """
        info = self.bootstrap_info
        self.prefill_server_url = f"{info['rank_ip']}:{info['rank_port']}"
        logger.debug(
            f"Fetched bootstrap info: {self.bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
        )

        kv_args = self.kv_mgr.kv_args

        def pack_pointers(ptrs) -> bytes:
            # Native-order uint64 values laid out back to back.
            return struct.pack(f"{len(ptrs)}Q", *ptrs)

        frames = [
            str(self.bootstrap_room).encode("ascii"),
            get_local_ip_by_remote().encode("ascii"),
            str(self.kv_mgr.rank_port).encode("ascii"),
            self.session_id.encode("ascii"),
            pack_pointers(kv_args.kv_data_ptrs),
            kv_indices.tobytes(),
            pack_pointers(kv_args.aux_data_ptrs),
            str(aux_index).encode("ascii"),
        ]

        sock, lock = self._connect("tcp://" + self.prefill_server_url)
        with lock:
            sock.send_multipart(frames)

    def poll(self) -> KVPoll:
        """Return the manager's current status for this room."""
        return self.kv_mgr.check_status(self.bootstrap_room)

    def failure_exception(self):
        raise Exception("Fake KVReceiver Exception")
|
491
|
+
|
492
|
+
|
493
|
+
class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
    """HTTP rendezvous service mapping engine ranks to prefill endpoints.

    Prefill ranks register ``{rank_ip, rank_port}`` via ``PUT /route``;
    decode ranks look it up via ``GET /route?engine_rank=N``. The aiohttp
    app runs on a private event loop in a daemon thread started by the
    constructor.
    """

    def __init__(self, port: int):
        self.port = port
        self.app = web.Application()
        # Not read anywhere in this class; kept in case other code uses it.
        self.store = dict()
        self.lock = asyncio.Lock()
        self.prefill_port_table: Dict[int, Dict[str, Union[str, int]]] = {}
        # Populated by the server thread. Initialised here so that close() and
        # the cleanup in _run_server never hit AttributeError when startup
        # fails early or close() runs before the thread gets going.
        self._loop = None
        self._runner = None
        self._setup_routes()

        # Start bootstrap server
        self.thread = threading.Thread(target=self._run_server, daemon=True)
        self.run()

    def run(self):
        self.thread.start()

    def _setup_routes(self):
        self.app.router.add_route("*", "/route", self._handle_route)

    async def _handle_route(self, request: web.Request):
        """Dispatch ``/route`` by HTTP method; anything but PUT/GET is 405."""
        method = request.method
        if method == "PUT":
            return await self._handle_route_put(request)
        elif method == "GET":
            return await self._handle_route_get(request)
        else:
            return web.Response(
                text="Method not allowed", status=405, content_type="application/json"
            )

    async def _handle_route_put(self, request: web.Request):
        """Register a prefill rank's endpoint; other roles are accepted as no-ops.

        Returns 400 when the body is not JSON or lacks/mistypes a field.
        """
        try:
            data = await request.json()
            role = data["role"]
            rank_ip = data["rank_ip"]
            rank_port = int(data["rank_port"])
            engine_rank = int(data["engine_rank"])
        except (ValueError, KeyError, TypeError) as e:
            return web.Response(text=f"Invalid registration payload: {e}", status=400)

        if role == "Prefill":
            # Guard the table with the same lock the GET handler uses.
            async with self.lock:
                self.prefill_port_table[engine_rank] = {
                    "rank_ip": rank_ip,
                    "rank_port": rank_port,
                }
            logger.debug(
                f"Registered Prefill boostrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
            )

        return web.Response(text="OK", status=200)

    async def _handle_route_get(self, request: web.Request):
        """Return the registered endpoint for ``engine_rank`` (400/404 on error)."""
        engine_rank = request.query.get("engine_rank")
        if not engine_rank:
            return web.Response(text="Missing rank", status=400)
        try:
            rank = int(engine_rank)
        except ValueError:
            return web.Response(text="Invalid rank", status=400)

        # Find corresponding prefill info
        async with self.lock:
            bootstrap_info = self.prefill_port_table.get(rank)

        if bootstrap_info is not None:
            return web.json_response(bootstrap_info, status=200)
        else:
            return web.Response(text="Not Found", status=404)

    def _run_server(self):
        """Thread target: serve the app on a private event loop until stopped."""
        try:
            # Event Loop
            self._loop = asyncio.new_event_loop()
            asyncio.set_event_loop(self._loop)

            self._runner = web.AppRunner(self.app)
            self._loop.run_until_complete(self._runner.setup())

            site = web.TCPSite(self._runner, port=self.port)
            self._loop.run_until_complete(site.start())
            self._loop.run_forever()
        except Exception as e:
            logger.exception(f"Server error: {str(e)}")
        finally:
            # Cleanup; either object may be missing if startup failed early.
            if self._loop is not None:
                try:
                    if self._runner is not None:
                        self._loop.run_until_complete(self._runner.cleanup())
                finally:
                    self._loop.close()

    def close(self):
        """Shutdown"""
        if self._loop is not None and self._loop.is_running():
            self._loop.call_soon_threadsafe(self._loop.stop)
            logger.info("Stopping server loop...")

        if self.thread.is_alive():
            self.thread.join(timeout=2)
            logger.info("Server thread stopped")

    def poll(self) -> KVPoll: ...
|
@@ -0,0 +1,77 @@
|
|
1
|
+
import json
|
2
|
+
import logging
|
3
|
+
from dataclasses import dataclass
|
4
|
+
from typing import Optional
|
5
|
+
|
6
|
+
logger = logging.getLogger(__name__)
|
7
|
+
|
8
|
+
|
9
|
+
class MooncakeTransferEngine:
    """Thin wrapper around ``mooncake.engine.TransferEngine``.

    Initialises the engine in P2P-handshake mode over RDMA and exposes
    memory (de)registration plus synchronous writes to a peer session. Every
    engine failure is logged and raised as ``RuntimeError``.
    """

    def __init__(self, hostname: str, gpu_id: int, ib_device: Optional[str] = None):
        try:
            from mooncake.engine import TransferEngine
        except ImportError as e:
            raise ImportError(
                "Please install mooncake by following the instructions at "
                "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md "  # noqa: E501
                "to run SGLang with MooncakeTransferEngine."
            ) from e

        self.engine = TransferEngine()
        self.hostname = hostname
        self.gpu_id = gpu_id
        self.ib_device = ib_device

        self.initialize(hostname=self.hostname, device_name=self.ib_device)
        rpc_port = self.engine.get_rpc_port()
        self.session_id = f"{self.hostname}:{rpc_port}"

    @staticmethod
    def _raise_on_error(failed: bool, message: str) -> None:
        """Log ``message`` and raise it as RuntimeError when ``failed``."""
        if failed:
            logger.error(message)
            raise RuntimeError(message)

    def register(self, ptr, length):
        """Register ``length`` bytes at ``ptr`` with the engine."""
        status = self.engine.register_memory(ptr, length)
        self._raise_on_error(status != 0, "Mooncake memory registration failed.")

    def deregister(self, ptr):
        """Unregister the memory region starting at ``ptr``."""
        status = self.engine.unregister_memory(ptr)
        self._raise_on_error(status != 0, "Mooncake memory deregistration failed.")

    def initialize(
        self,
        hostname: str,
        device_name: Optional[str],
    ) -> None:
        """Initialize the mooncake instance (P2P handshake, RDMA transport)."""
        device = "" if device_name is None else device_name
        status = self.engine.initialize(hostname, "P2PHANDSHAKE", "rdma", device)
        self._raise_on_error(
            status != 0, "Mooncake Transfer Engine initialization failed."
        )

    def transfer_sync(
        self, session_id: str, buffer: int, peer_buffer_address: int, length: int
    ) -> int:
        """Synchronously write ``length`` bytes from ``buffer`` to the peer.

        Returns the engine's non-negative status; negative statuses raise.
        """
        status = self.engine.transfer_sync_write(
            session_id, buffer, peer_buffer_address, length
        )
        self._raise_on_error(status < 0, "Mooncake Transfer Engine Return Error.")
        return status

    def get_localhost(self):
        """Return the hostname the engine was initialised with."""
        return self.hostname

    def get_session_id(self):
        """Return the ``host:rpc_port`` id peers use to address this engine."""
        return self.session_id
|