sglang 0.4.5__py3-none-any.whl → 0.4.5.post2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (166) hide show
  1. sglang/__init__.py +2 -4
  2. sglang/bench_one_batch.py +23 -2
  3. sglang/bench_serving.py +6 -4
  4. sglang/lang/backend/anthropic.py +0 -4
  5. sglang/lang/backend/base_backend.py +1 -1
  6. sglang/lang/backend/openai.py +1 -1
  7. sglang/lang/backend/vertexai.py +0 -1
  8. sglang/lang/compiler.py +1 -7
  9. sglang/lang/tracer.py +3 -7
  10. sglang/srt/_custom_ops.py +0 -2
  11. sglang/srt/configs/model_config.py +37 -5
  12. sglang/srt/constrained/base_grammar_backend.py +26 -5
  13. sglang/srt/constrained/llguidance_backend.py +1 -0
  14. sglang/srt/constrained/outlines_backend.py +1 -0
  15. sglang/srt/constrained/outlines_jump_forward.py +14 -1
  16. sglang/srt/constrained/reasoner_grammar_backend.py +101 -0
  17. sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
  18. sglang/srt/constrained/xgrammar_backend.py +27 -4
  19. sglang/srt/custom_op.py +0 -62
  20. sglang/srt/disaggregation/base/__init__.py +8 -0
  21. sglang/srt/disaggregation/base/conn.py +113 -0
  22. sglang/srt/disaggregation/decode.py +80 -11
  23. sglang/srt/disaggregation/mini_lb.py +58 -123
  24. sglang/srt/disaggregation/mooncake/__init__.py +6 -0
  25. sglang/srt/disaggregation/mooncake/conn.py +585 -0
  26. sglang/srt/disaggregation/mooncake/transfer_engine.py +77 -0
  27. sglang/srt/disaggregation/prefill.py +82 -22
  28. sglang/srt/disaggregation/utils.py +46 -0
  29. sglang/srt/entrypoints/EngineBase.py +53 -0
  30. sglang/srt/entrypoints/engine.py +36 -8
  31. sglang/srt/entrypoints/http_server.py +37 -8
  32. sglang/srt/entrypoints/http_server_engine.py +142 -0
  33. sglang/srt/entrypoints/verl_engine.py +42 -13
  34. sglang/srt/hf_transformers_utils.py +4 -0
  35. sglang/srt/layers/activation.py +6 -8
  36. sglang/srt/layers/attention/flashattention_backend.py +430 -257
  37. sglang/srt/layers/attention/flashinfer_backend.py +18 -9
  38. sglang/srt/layers/attention/torch_native_backend.py +6 -1
  39. sglang/srt/layers/attention/triton_backend.py +6 -0
  40. sglang/srt/layers/attention/triton_ops/extend_attention.py +13 -2
  41. sglang/srt/layers/attention/vision.py +1 -1
  42. sglang/srt/layers/dp_attention.py +2 -4
  43. sglang/srt/layers/elementwise.py +15 -2
  44. sglang/srt/layers/layernorm.py +1 -1
  45. sglang/srt/layers/linear.py +18 -3
  46. sglang/srt/layers/moe/ep_moe/layer.py +15 -29
  47. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +145 -118
  48. sglang/srt/layers/moe/fused_moe_native.py +4 -0
  49. sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  50. sglang/srt/layers/moe/fused_moe_triton/configs/E=264,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  51. sglang/srt/layers/moe/fused_moe_triton/configs/{E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=264,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +34 -34
  52. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
  53. sglang/srt/layers/moe/fused_moe_triton/configs/E=272,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  54. sglang/srt/layers/moe/fused_moe_triton/configs/E=288,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
  55. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +46 -34
  56. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  57. sglang/srt/layers/moe/router.py +7 -1
  58. sglang/srt/layers/moe/topk.py +63 -45
  59. sglang/srt/layers/parameter.py +0 -2
  60. sglang/srt/layers/quantization/__init__.py +13 -5
  61. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  62. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +12 -2
  63. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -77
  64. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
  65. sglang/srt/layers/quantization/fp8.py +131 -136
  66. sglang/srt/layers/quantization/fp8_kernel.py +328 -46
  67. sglang/srt/layers/quantization/fp8_utils.py +206 -253
  68. sglang/srt/layers/quantization/kv_cache.py +43 -52
  69. sglang/srt/layers/quantization/modelopt_quant.py +271 -4
  70. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  71. sglang/srt/layers/quantization/utils.py +5 -11
  72. sglang/srt/layers/quantization/w8a8_fp8.py +156 -4
  73. sglang/srt/layers/quantization/w8a8_int8.py +8 -7
  74. sglang/srt/layers/radix_attention.py +28 -1
  75. sglang/srt/layers/rotary_embedding.py +15 -3
  76. sglang/srt/layers/sampler.py +5 -10
  77. sglang/srt/lora/backend/base_backend.py +18 -2
  78. sglang/srt/lora/backend/flashinfer_backend.py +1 -1
  79. sglang/srt/lora/backend/triton_backend.py +1 -1
  80. sglang/srt/lora/layers.py +1 -1
  81. sglang/srt/lora/lora.py +1 -1
  82. sglang/srt/lora/lora_manager.py +1 -1
  83. sglang/srt/managers/detokenizer_manager.py +0 -1
  84. sglang/srt/managers/io_struct.py +255 -97
  85. sglang/srt/managers/mm_utils.py +7 -5
  86. sglang/srt/managers/multimodal_processor.py +0 -2
  87. sglang/srt/managers/multimodal_processors/base_processor.py +117 -79
  88. sglang/srt/managers/multimodal_processors/janus_pro.py +3 -1
  89. sglang/srt/managers/multimodal_processors/mllama4.py +21 -36
  90. sglang/srt/managers/schedule_batch.py +64 -25
  91. sglang/srt/managers/scheduler.py +80 -82
  92. sglang/srt/managers/tokenizer_manager.py +18 -3
  93. sglang/srt/managers/tp_worker.py +1 -0
  94. sglang/srt/mem_cache/hiradix_cache.py +5 -1
  95. sglang/srt/mem_cache/memory_pool.py +21 -3
  96. sglang/srt/metrics/collector.py +9 -0
  97. sglang/srt/model_executor/cuda_graph_runner.py +9 -6
  98. sglang/srt/model_executor/forward_batch_info.py +234 -15
  99. sglang/srt/model_executor/model_runner.py +67 -35
  100. sglang/srt/model_loader/loader.py +31 -4
  101. sglang/srt/model_loader/weight_utils.py +4 -2
  102. sglang/srt/models/baichuan.py +2 -0
  103. sglang/srt/models/bert.py +398 -0
  104. sglang/srt/models/chatglm.py +1 -0
  105. sglang/srt/models/commandr.py +1 -0
  106. sglang/srt/models/dbrx.py +1 -0
  107. sglang/srt/models/deepseek.py +2 -1
  108. sglang/srt/models/deepseek_nextn.py +74 -70
  109. sglang/srt/models/deepseek_v2.py +494 -366
  110. sglang/srt/models/exaone.py +1 -0
  111. sglang/srt/models/gemma.py +1 -0
  112. sglang/srt/models/gemma2.py +1 -0
  113. sglang/srt/models/gemma3_causal.py +1 -0
  114. sglang/srt/models/gpt2.py +1 -0
  115. sglang/srt/models/gpt_bigcode.py +1 -0
  116. sglang/srt/models/granite.py +1 -0
  117. sglang/srt/models/grok.py +1 -0
  118. sglang/srt/models/internlm2.py +1 -0
  119. sglang/srt/models/llama.py +6 -5
  120. sglang/srt/models/llama4.py +101 -34
  121. sglang/srt/models/minicpm.py +1 -0
  122. sglang/srt/models/minicpm3.py +30 -200
  123. sglang/srt/models/mixtral.py +1 -0
  124. sglang/srt/models/mixtral_quant.py +1 -0
  125. sglang/srt/models/mllama.py +51 -8
  126. sglang/srt/models/mllama4.py +102 -29
  127. sglang/srt/models/olmo.py +1 -0
  128. sglang/srt/models/olmo2.py +1 -0
  129. sglang/srt/models/olmoe.py +1 -0
  130. sglang/srt/models/phi3_small.py +1 -0
  131. sglang/srt/models/qwen.py +1 -0
  132. sglang/srt/models/qwen2.py +5 -1
  133. sglang/srt/models/qwen2_5_vl.py +35 -70
  134. sglang/srt/models/qwen2_moe.py +15 -13
  135. sglang/srt/models/qwen2_vl.py +27 -25
  136. sglang/srt/models/qwen3.py +335 -0
  137. sglang/srt/models/qwen3_moe.py +423 -0
  138. sglang/srt/models/stablelm.py +1 -0
  139. sglang/srt/models/xverse.py +1 -0
  140. sglang/srt/models/xverse_moe.py +1 -0
  141. sglang/srt/openai_api/adapter.py +4 -1
  142. sglang/srt/patch_torch.py +11 -0
  143. sglang/srt/reasoning_parser.py +0 -1
  144. sglang/srt/sampling/sampling_batch_info.py +2 -3
  145. sglang/srt/server_args.py +55 -19
  146. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +4 -4
  147. sglang/srt/speculative/eagle_utils.py +1 -11
  148. sglang/srt/speculative/eagle_worker.py +10 -9
  149. sglang/srt/utils.py +136 -10
  150. sglang/test/attention/test_flashattn_backend.py +259 -221
  151. sglang/test/attention/test_flashattn_mla_backend.py +285 -0
  152. sglang/test/attention/test_prefix_chunk_info.py +224 -0
  153. sglang/test/runners.py +5 -1
  154. sglang/test/test_block_fp8.py +224 -0
  155. sglang/test/test_custom_ops.py +1 -1
  156. sglang/test/test_utils.py +19 -8
  157. sglang/version.py +1 -1
  158. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/METADATA +15 -5
  159. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/RECORD +162 -147
  160. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/WHEEL +1 -1
  161. sglang/lang/__init__.py +0 -0
  162. sglang/srt/disaggregation/conn.py +0 -81
  163. sglang/srt/lora/backend/__init__.py +0 -25
  164. sglang/srt/server.py +0 -18
  165. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/licenses/LICENSE +0 -0
  166. {sglang-0.4.5.dist-info → sglang-0.4.5.post2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,585 @@
1
+ from __future__ import annotations
2
+
3
+ import asyncio
4
+ import dataclasses
5
+ import logging
6
+ import queue
7
+ import socket
8
+ import struct
9
+ import threading
10
+ from functools import cache
11
+ from typing import Dict, List, Optional, Tuple, Union
12
+
13
+ import numpy as np
14
+ import numpy.typing as npt
15
+ import requests
16
+ import zmq
17
+ from aiohttp import web
18
+
19
+ from sglang.srt.disaggregation.base.conn import (
20
+ BaseKVBootstrapServer,
21
+ BaseKVManager,
22
+ BaseKVReceiver,
23
+ BaseKVSender,
24
+ KVArgs,
25
+ KVPoll,
26
+ )
27
+ from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
28
+ from sglang.srt.disaggregation.utils import DisaggregationMode
29
+ from sglang.srt.server_args import ServerArgs
30
+ from sglang.srt.utils import get_free_port, get_ip, get_local_ip_by_remote
31
+
32
+ logger = logging.getLogger(__name__)
33
+
34
+
35
def group_concurrent_contiguous(
    src_indices: npt.NDArray[np.int64], dst_indices: npt.NDArray[np.int64]
) -> Tuple[List[npt.NDArray[np.int64]], List[npt.NDArray[np.int64]]]:
    """Split two parallel index sequences into runs contiguous on BOTH sides.

    Position ``i`` stays in the current run only when
    ``src_indices[i] == src_indices[i - 1] + 1`` and
    ``dst_indices[i] == dst_indices[i - 1] + 1``; otherwise a new run starts.

    Args:
        src_indices: Source indices.
        dst_indices: Destination indices, paired element-wise with
            ``src_indices``.

    Returns:
        ``(src_groups, dst_groups)``: two lists of equal length in which the
        k-th source run and the k-th destination run have the same number of
        elements. Both lists are empty when the input is empty.
    """
    # Nothing to group. Without this guard the seeding below would raise
    # IndexError on ``src_indices[0]``.
    if len(src_indices) == 0:
        return [], []

    src_groups = []
    dst_groups = []
    current_src = [src_indices[0]]
    current_dst = [dst_indices[0]]

    for i in range(1, len(src_indices)):
        src_contiguous = src_indices[i] == src_indices[i - 1] + 1
        dst_contiguous = dst_indices[i] == dst_indices[i - 1] + 1
        if src_contiguous and dst_contiguous:
            current_src.append(src_indices[i])
            current_dst.append(dst_indices[i])
        else:
            # A gap on either side closes the current run on both sides so
            # the two group lists stay aligned.
            src_groups.append(current_src)
            dst_groups.append(current_dst)
            current_src = [src_indices[i]]
            current_dst = [dst_indices[i]]

    # Flush the run that was still open when the loop ended.
    src_groups.append(current_src)
    dst_groups.append(current_dst)

    return src_groups, dst_groups
59
+
60
+
61
@dataclasses.dataclass
class TransferKVChunk:
    """One unit of KV-cache transfer work queued on the prefill side."""

    # Bootstrap room id; the prefill manager uses it to look up the matching
    # TransferInfo and to record the request status.
    room: int
    # Source-side KV indices covered by this chunk.
    prefill_kv_indices: npt.NDArray[np.int64]
    # Slice applied to the destination index array to select the slots that
    # pair with ``prefill_kv_indices``.
    index_slice: slice
    # True for the final chunk of a request; the aux data is sent after it.
    is_last: bool
    # Source-side aux index; only read when ``is_last`` is True.
    prefill_aux_index: Optional[int]
68
+
69
+
70
@dataclasses.dataclass
class TransferInfo:
    """Destination-side details for one request, as received over ZMQ."""

    room: int
    endpoint: str
    dst_port: int
    mooncake_session_id: str
    dst_kv_ptrs: list[int]
    dst_kv_indices: npt.NDArray[np.int64]
    dst_aux_ptrs: list[int]
    dst_aux_index: int

    @classmethod
    def from_zmq(cls, msg: List[bytes]):
        """Build a ``TransferInfo`` from a multipart ZMQ message.

        Frame layout (by position): room, endpoint, dst_port, session id
        (all ASCII text), packed native ``Q`` KV pointers, raw int64 KV
        indices, packed native ``Q`` aux pointers, aux index (ASCII text).
        """

        def as_text(frame: bytes) -> str:
            return frame.decode("ascii")

        def as_u64_list(frame: bytes) -> list:
            # Each pointer occupies 8 bytes in the frame.
            count = len(frame) // 8
            return list(struct.unpack(f"{count}Q", frame))

        room = int(as_text(msg[0]))
        endpoint = as_text(msg[1])
        dst_port = int(as_text(msg[2]))
        session_id = as_text(msg[3])
        kv_ptrs = as_u64_list(msg[4])
        kv_indices = np.frombuffer(msg[5], dtype=np.int64)
        aux_ptrs = as_u64_list(msg[6])
        aux_index = int(as_text(msg[7]))

        return cls(
            room=room,
            endpoint=endpoint,
            dst_port=dst_port,
            mooncake_session_id=session_id,
            dst_kv_ptrs=kv_ptrs,
            dst_kv_indices=kv_indices,
            dst_aux_ptrs=aux_ptrs,
            dst_aux_index=aux_index,
        )
93
+
94
+
95
class MooncakeKVManager(BaseKVManager):
    """KV-cache transfer manager backed by ``MooncakeTransferEngine``.

    In PREFILL mode it binds a ZMQ PULL socket, starts a thread that receives
    destination info from decode engines and a thread that pushes queued KV
    chunks through the transfer engine, then registers itself with the
    bootstrap server. In DECODE mode it binds a ZMQ PULL socket and starts a
    thread that receives per-request status updates.
    """

    def __init__(
        self,
        args: KVArgs,
        disaggregation_mode: DisaggregationMode,
        server_args: ServerArgs,
    ):
        """Create the engine, register buffers and start the mode's threads.

        Raises:
            ValueError: if ``disaggregation_mode`` is neither PREFILL nor
                DECODE.
        """
        self.kv_args = args
        self.engine = MooncakeTransferEngine(
            hostname=get_local_ip_by_remote(),
            gpu_id=self.kv_args.gpu_id,
            ib_device=self.kv_args.ib_device,
        )
        self.disaggregation_mode = disaggregation_mode
        # for p/d multi node infer
        self.bootstrap_port = server_args.disaggregation_bootstrap_port
        self.dist_init_addr = server_args.dist_init_addr
        self.request_status: Dict[int, KVPoll] = {}
        self.rank_port = None
        self.server_socket = zmq.Context().socket(zmq.PULL)
        # Per-instance cache of PUSH sockets keyed by endpoint (see _connect).
        # Created before any worker thread starts because the transfer thread
        # reaches it through sync_status_to_decode_endpoint().
        self._push_sockets: Dict[str, zmq.Socket] = {}
        self.register_buffer_to_engine()
        if self.disaggregation_mode == DisaggregationMode.PREFILL:
            self.transfer_queue = queue.Queue()
            self.transfer_infos: Dict[int, TransferInfo] = {}
            self.start_prefill_thread()
            self._register_to_bootstrap()
        elif self.disaggregation_mode == DisaggregationMode.DECODE:
            # Bootstrap info cache read and written by MooncakeKVReceiver;
            # set up before the background thread is started.
            self.connection_pool: Dict[str, Dict[str, Union[str, int]]] = {}
            self.start_decode_thread()
        else:
            raise ValueError(
                f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
            )

    def register_buffer_to_engine(self):
        """Register every KV and aux buffer in ``kv_args`` with the engine."""
        for kv_data_ptr, kv_data_len in zip(
            self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens
        ):
            self.engine.register(kv_data_ptr, kv_data_len)

        for aux_data_ptr, aux_data_len in zip(
            self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
        ):
            self.engine.register(aux_data_ptr, aux_data_len)

    def _connect(self, endpoint: str):
        """Return a connected ZMQ PUSH socket for ``endpoint``, creating it once.

        Sockets are cached on the instance. A ``functools.cache`` decorator on
        this method would key on ``self`` and keep the manager (and its
        sockets) alive for the lifetime of the process.
        NOTE(review): within this file the only caller is the prefill
        transfer thread; the cache is not guarded by a lock.
        """
        sock = self._push_sockets.get(endpoint)
        if sock is None:
            sock = zmq.Context().socket(zmq.PUSH)
            sock.connect(endpoint)
            self._push_sockets[endpoint] = sock
        return sock

    def send_kvcache(
        self,
        mooncake_session_id: str,
        prefill_kv_indices: npt.NDArray[np.int64],
        dst_kv_ptrs: list[int],
        dst_kv_indices: npt.NDArray[np.int64],
    ):
        """Write the given KV indices to the destination, layer by layer.

        Returns:
            0 when every transfer returned 0, otherwise the first non-zero
            status returned by the engine (remaining transfers are skipped).
        """
        # group by indices
        prefill_kv_blocks, dst_kv_blocks = group_concurrent_contiguous(
            prefill_kv_indices, dst_kv_indices
        )

        num_layers = len(self.kv_args.kv_data_ptrs)
        for layer_id in range(num_layers):
            src_ptr = self.kv_args.kv_data_ptrs[layer_id]
            dst_ptr = dst_kv_ptrs[layer_id]
            item_len = self.kv_args.kv_item_lens[layer_id]

            for prefill_index, decode_index in zip(prefill_kv_blocks, dst_kv_blocks):
                # One transfer per contiguous run: start address of the run
                # on each side, length = run size * per-item length.
                src_addr = src_ptr + int(prefill_index[0]) * item_len
                dst_addr = dst_ptr + int(decode_index[0]) * item_len
                length = item_len * len(prefill_index)

                # TODO: make async later
                status = self.engine.transfer_sync(
                    mooncake_session_id, src_addr, dst_addr, length
                )
                if status != 0:
                    return status

        return 0

    def send_aux(
        self,
        mooncake_session_id: str,
        prefill_aux_index: int,
        dst_aux_ptrs: list[int],
        dst_aux_index: int,
    ):
        """Write one aux item (first aux buffer only) to the destination.

        Returns:
            The status returned by the engine's synchronous transfer.
        """
        aux_item_len = self.kv_args.aux_item_lens[0]
        prefill_aux_addr = (
            self.kv_args.aux_data_ptrs[0] + prefill_aux_index * aux_item_len
        )
        decode_aux_addr = dst_aux_ptrs[0] + dst_aux_index * aux_item_len
        # TODO: mooncake transfer engine can do async transfer. Do async later
        # Not sure about the amount of aux data, maybe transfer it by zmq is more effective
        status = self.engine.transfer_sync(
            mooncake_session_id, prefill_aux_addr, decode_aux_addr, aux_item_len
        )
        return status

    def sync_status_to_decode_endpoint(self, remote: str, dst_port: int, room: int):
        """Push the current status of ``room`` to ``remote:dst_port`` over ZMQ."""
        # Only the host part of ``remote`` is used; the port comes from dst_port.
        if ":" in remote:
            remote = remote.split(":")[0]
        self._connect("tcp://" + remote + ":" + str(dst_port)).send_multipart(
            [
                str(room).encode("ascii"),
                # The receiving side parses this frame with int().
                str(self.request_status[room]).encode("ascii"),
            ]
        )

    def start_prefill_thread(self):
        """Bind the PULL socket and start the bootstrap and transfer threads."""
        self.rank_port = get_free_port()
        self.server_socket.bind(f"tcp://{get_local_ip_by_remote()}:{self.rank_port}")

        def bootstrap_thread():
            """This thread recvs pre-alloc notification from the decode engine"""
            # KVPoll.Bootstrapping -> KVPoll.WaitingForInput
            while True:
                waiting_req_bytes = self.server_socket.recv_multipart()
                room = waiting_req_bytes[0].decode("ascii")
                # A receiver created without a bootstrap room sends the text
                # "None"; such messages are ignored.
                if room == "None":
                    continue
                room = int(room)
                self.transfer_infos[room] = TransferInfo.from_zmq(waiting_req_bytes)

                # NOTE: after bootstrapping we can mark the req as waiting for input
                self.request_status[room] = KVPoll.WaitingForInput

        def process_chunk(kv_chunk: TransferKVChunk) -> None:
            """Transfer one queued chunk and publish the resulting status."""
            req = self.transfer_infos[kv_chunk.room]
            chunked_dst_kv_indice = req.dst_kv_indices[kv_chunk.index_slice]
            assert len(chunked_dst_kv_indice) == len(kv_chunk.prefill_kv_indices)

            ret = self.send_kvcache(
                req.mooncake_session_id,
                kv_chunk.prefill_kv_indices,
                req.dst_kv_ptrs,
                chunked_dst_kv_indice,
            )
            if ret != 0:
                self.request_status[kv_chunk.room] = KVPoll.Failed
                self.sync_status_to_decode_endpoint(
                    req.endpoint, req.dst_port, req.room
                )
                return

            if kv_chunk.is_last:
                # Only the last chunk we need to send the aux data
                ret = self.send_aux(
                    req.mooncake_session_id,
                    kv_chunk.prefill_aux_index,
                    req.dst_aux_ptrs,
                    req.dst_aux_index,
                )
                self.request_status[req.room] = (
                    KVPoll.Success if ret == 0 else KVPoll.Failed
                )
                self.sync_status_to_decode_endpoint(
                    req.endpoint, req.dst_port, req.room
                )
                self.transfer_infos.pop(req.room)

        def transfer_thread():
            # TODO: Shall we use KVPoll.Transferring state?
            while True:
                try:
                    kv_chunk: TransferKVChunk = self.transfer_queue.get(timeout=0.01)
                except queue.Empty:
                    continue

                try:
                    process_chunk(kv_chunk)
                except Exception:
                    # The transfer engine raises on negative return codes, and
                    # a missing room or a length mismatch raises as well. If
                    # any of these escaped, this worker would die and every
                    # later request would hang; fail only this request.
                    logger.exception(
                        "KV transfer failed for bootstrap room %s", kv_chunk.room
                    )
                    self.request_status[kv_chunk.room] = KVPoll.Failed
                    req = self.transfer_infos.get(kv_chunk.room)
                    if req is not None:
                        self.sync_status_to_decode_endpoint(
                            req.endpoint, req.dst_port, req.room
                        )

        threading.Thread(target=bootstrap_thread).start()
        threading.Thread(target=transfer_thread).start()

    def start_decode_thread(self):
        """Bind the PULL socket and start the status-receiving thread."""
        self.rank_port = get_free_port()
        self.server_socket.bind(f"tcp://{get_local_ip_by_remote()}:{self.rank_port}")

        def decode_thread():
            while True:
                # Two frames per message: room id and status, both ASCII ints.
                (bootstrap_room, status) = self.server_socket.recv_multipart()
                status = int(status.decode("ascii"))
                bootstrap_room = int(bootstrap_room.decode("ascii"))
                self.request_status[bootstrap_room] = status

        threading.Thread(target=decode_thread).start()

    def add_transfer_request(
        self,
        bootstrap_room: int,
        kv_indices: npt.NDArray[np.int64],
        index_slice: slice,
        is_last: bool,
        aux_index: Optional[int] = None,
    ):
        """Queue one KV chunk for the transfer thread (PREFILL mode only).

        ``aux_index`` must be provided when ``is_last`` is True.
        """
        assert self.disaggregation_mode == DisaggregationMode.PREFILL
        assert not is_last or aux_index is not None

        self.transfer_queue.put(
            TransferKVChunk(
                room=bootstrap_room,
                prefill_kv_indices=kv_indices,
                index_slice=index_slice,
                is_last=is_last,
                prefill_aux_index=aux_index,
            )
        )
        self.request_status[bootstrap_room] = KVPoll.WaitingForInput

    def check_status(self, bootstrap_room: int):
        """Return the recorded status for ``bootstrap_room``.

        Raises:
            KeyError: if no status has been recorded for the room.
        """
        # TODO: do we really need the poll()?

        return self.request_status[bootstrap_room]

    def update_status(self, bootstrap_room: int, status: KVPoll):
        """Record ``status`` for the room, never lowering an existing status."""
        if bootstrap_room not in self.request_status:
            self.request_status[bootstrap_room] = status
        else:
            # NOTE: The prefill engine could recv bootstrapping first
            self.request_status[bootstrap_room] = max(
                self.request_status[bootstrap_room], status
            )

    def get_session_id(self):
        """Return the transfer engine's session id."""
        return self.engine.get_session_id()

    def _register_to_bootstrap(self):
        """Register this prefill rank with the bootstrap server via HTTP PUT.

        Failures are logged and swallowed; nothing is raised to the caller.
        """
        if self.dist_init_addr:
            ip_address = socket.gethostbyname(self.dist_init_addr.split(":")[0])
        else:
            ip_address = get_ip()

        bootstrap_server_url = f"{ip_address}:{self.bootstrap_port}"
        url = f"http://{bootstrap_server_url}/route"
        payload = {
            "role": "Prefill",
            "rank_ip": get_local_ip_by_remote(),
            "rank_port": self.rank_port,
            "engine_rank": self.kv_args.engine_rank,
        }

        try:
            response = requests.put(url, json=payload)
            if response.status_code == 200:
                logger.debug("Prefill successfully registered to bootstrap server.")
            else:
                logger.error(
                    f"Prefill Failed to connect to bootstrap server: {response.status_code}, {response.text}"
                )
        except Exception as e:
            logger.error(f"Prefill Failed to register to bootstrap server: {e}")
349
+
350
+
351
class MooncakeKVSender(BaseKVSender):
    """Prefill-side handle for sending one request's KV cache via the manager."""

    def __init__(
        self, mgr: MooncakeKVManager, bootstrap_addr: str, bootstrap_room: int
    ):
        """Bind to ``mgr`` and mark the room as bootstrapping."""
        self.kv_mgr = mgr
        self.bootstrap_room = bootstrap_room
        self.aux_index = None
        self.bootstrap_server_url = bootstrap_addr
        self.kv_mgr.update_status(bootstrap_room, KVPoll.Bootstrapping)
        self.session_id = self.kv_mgr.get_session_id()

    def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
        """Remember the number of KV indices and the aux index for this request."""
        self.num_kv_indices = num_kv_indices
        self.aux_index = aux_index

    def send(
        self,
        kv_indices: npt.NDArray[np.int64],
        index_slice: slice,
        is_last: bool,
    ):
        """Queue one chunk for transfer; the final chunk also carries the aux index."""
        if is_last:
            self.kv_mgr.add_transfer_request(
                self.bootstrap_room,
                kv_indices,
                index_slice,
                True,
                aux_index=self.aux_index,
            )
            return

        self.kv_mgr.add_transfer_request(
            self.bootstrap_room, kv_indices, index_slice, False
        )

    def poll(self) -> KVPoll:
        """Return the manager's current status for this sender's room."""
        return self.kv_mgr.check_status(self.bootstrap_room)

    def failure_exception(self):
        """Raise a generic exception representing a failed send."""
        raise Exception("Fake KVSender Exception")
391
+
392
+
393
class MooncakeKVReceiver(BaseKVReceiver):
    """Decode-side handle that tells a prefill rank where to write KV data."""

    # Shared by all receivers: one ZMQ context, plus one PUSH socket and one
    # lock per endpoint. ``_global_lock`` guards creation of cache entries.
    _ctx = zmq.Context()
    _socket_cache = {}
    _socket_locks = {}
    _global_lock = threading.Lock()

    def __init__(
        self,
        mgr: MooncakeKVManager,
        bootstrap_addr: str,
        bootstrap_room: Optional[int] = None,
    ):
        """Resolve the prefill rank's address and mark the room as waiting.

        The bootstrap info is taken from the manager's ``connection_pool``
        when present, otherwise fetched from the bootstrap server and cached.
        """
        self.bootstrap_room = bootstrap_room
        self.bootstrap_addr = bootstrap_addr
        self.kv_mgr = mgr
        self.session_id = self.kv_mgr.get_session_id()
        self.kv_mgr.update_status(bootstrap_room, KVPoll.Bootstrapping)

        # NOTE: key distinguished by bootstrap_addr and engine_rank
        bootstrap_key = f"{self.bootstrap_addr}_{self.kv_mgr.kv_args.engine_rank}"

        if bootstrap_key in self.kv_mgr.connection_pool:
            self.bootstrap_info = self.kv_mgr.connection_pool[bootstrap_key]
        else:
            self.bootstrap_info = self._get_bootstrap_info_from_server(
                self.kv_mgr.kv_args.engine_rank
            )
            if self.bootstrap_info is not None:
                self.kv_mgr.connection_pool[bootstrap_key] = self.bootstrap_info
            else:
                logger.error(
                    f"Could not fetch bootstrap info for engine rank: {self.kv_mgr.kv_args.engine_rank}"
                )

        assert self.bootstrap_info is not None
        self.kv_mgr.update_status(bootstrap_room, KVPoll.WaitingForInput)

    def _get_bootstrap_info_from_server(self, engine_rank):
        """Fetch the bootstrap info from the bootstrap server.

        Returns the decoded JSON body on HTTP 200, otherwise ``None`` (any
        error is logged).
        """
        try:
            url = f"http://{self.bootstrap_addr}/route?engine_rank={engine_rank}"
            response = requests.get(url)
            if response.status_code != 200:
                logger.error(
                    f"Failed to get prefill server info: {response.status_code}, {response.text}"
                )
                return None
            return response.json()
        except Exception as e:
            logger.error(f"Error fetching prefill info from bootstrap: {e}")
            return None

    @classmethod
    def _connect(cls, endpoint: str):
        """Return the cached ``(socket, lock)`` pair for ``endpoint``."""
        with cls._global_lock:
            if endpoint not in cls._socket_cache:
                new_sock = cls._ctx.socket(zmq.PUSH)
                new_sock.connect(endpoint)
                cls._socket_cache[endpoint] = new_sock
                cls._socket_locks[endpoint] = threading.Lock()
            return cls._socket_cache[endpoint], cls._socket_locks[endpoint]

    def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
        """Send this request's destination pointers and indices to the prefill rank."""
        rank_ip = self.bootstrap_info["rank_ip"]
        rank_port = self.bootstrap_info["rank_port"]
        self.prefill_server_url = f"{rank_ip}:{rank_port}"
        logger.debug(
            f"Fetched bootstrap info: {self.bootstrap_info} for engine rank: {self.kv_mgr.kv_args.engine_rank}"
        )

        # Each pointer is packed as one native unsigned 64-bit value.
        pack_u64 = struct.Struct("Q").pack
        packed_kv_data_ptrs = b"".join(
            map(pack_u64, self.kv_mgr.kv_args.kv_data_ptrs)
        )
        packed_aux_data_ptrs = b"".join(
            map(pack_u64, self.kv_mgr.kv_args.aux_data_ptrs)
        )

        sock, lock = self._connect("tcp://" + self.prefill_server_url)
        with lock:
            # Frame order must match TransferInfo.from_zmq on the prefill side.
            frames = [
                str(self.bootstrap_room).encode("ascii"),
                get_local_ip_by_remote().encode("ascii"),
                str(self.kv_mgr.rank_port).encode("ascii"),
                self.session_id.encode("ascii"),
                packed_kv_data_ptrs,
                kv_indices.tobytes(),
                packed_aux_data_ptrs,
                str(aux_index).encode("ascii"),
            ]
            sock.send_multipart(frames)

    def poll(self) -> KVPoll:
        """Return the manager's current status for this receiver's room."""
        return self.kv_mgr.check_status(self.bootstrap_room)

    def failure_exception(self):
        """Raise a generic exception representing a failed receive."""
        raise Exception("Fake KVReceiver Exception")
491
+
492
+
493
class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
    """HTTP registry mapping a prefill engine rank to its ip and port.

    Serves ``/route``: PUT registers a prefill rank, GET looks one up by
    ``engine_rank``. The aiohttp application runs on its own event loop in a
    daemon thread that is started from ``__init__``.
    """

    def __init__(self, port: int):
        """Configure the application and start the server thread."""
        self.port = port
        self.app = web.Application()
        self.store = dict()
        self.lock = asyncio.Lock()
        self._setup_routes()
        self.prefill_port_table: Dict[int, Dict[str, Union[str, int]]] = {}

        # Assigned by _run_server() on the server thread. Initialised here so
        # close() and the cleanup path can test them safely even when the
        # server has not (or not fully) started.
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        self._runner: Optional[web.AppRunner] = None

        # Start bootstrap server
        self.thread = threading.Thread(target=self._run_server, daemon=True)
        self.run()

    def run(self):
        """Start the server thread."""
        self.thread.start()

    def _setup_routes(self):
        """Route every HTTP method on ``/route`` to the dispatcher."""
        self.app.router.add_route("*", "/route", self._handle_route)

    async def _handle_route(self, request: web.Request):
        """Dispatch PUT and GET; answer 405 for any other method."""
        method = request.method
        if method == "PUT":
            return await self._handle_route_put(request)
        elif method == "GET":
            return await self._handle_route_get(request)
        else:
            return web.Response(
                text="Method not allowed", status=405, content_type="application/json"
            )

    async def _handle_route_put(self, request: web.Request):
        """Register a rank from a JSON body.

        Reads ``role``, ``rank_ip``, ``rank_port`` and ``engine_rank``; only
        ``role == "Prefill"`` entries are stored. Always answers 200 "OK" when
        the body parses.
        """
        data = await request.json()
        role = data["role"]
        rank_ip = data["rank_ip"]
        rank_port = int(data["rank_port"])
        engine_rank = int(data["engine_rank"])

        if role == "Prefill":
            # Take the same lock the GET handler takes when reading the table.
            async with self.lock:
                self.prefill_port_table[engine_rank] = {
                    "rank_ip": rank_ip,
                    "rank_port": rank_port,
                }
            logger.debug(
                f"Registered Prefill boostrap: {engine_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
            )

        return web.Response(text="OK", status=200)

    async def _handle_route_get(self, request: web.Request):
        """Look up a prefill rank by the ``engine_rank`` query parameter.

        Answers 400 when the parameter is missing or not an integer, 404 when
        the rank is unknown, otherwise 200 with the stored JSON entry.
        """
        engine_rank = request.query.get("engine_rank")
        if not engine_rank:
            return web.Response(text="Missing rank", status=400)

        # A non-numeric value is a client error, not a server failure.
        try:
            rank = int(engine_rank)
        except ValueError:
            return web.Response(text="Invalid rank", status=400)

        # Find corresponding prefill info
        async with self.lock:
            bootstrap_info = self.prefill_port_table.get(rank)

        if bootstrap_info is not None:
            return web.json_response(bootstrap_info, status=200)
        else:
            return web.Response(text="Not Found", status=404)

    def _run_server(self):
        """Thread target: create an event loop and serve until it is stopped."""
        try:
            # Event Loop
            self._loop = asyncio.new_event_loop()
            asyncio.set_event_loop(self._loop)

            self._runner = web.AppRunner(self.app)
            self._loop.run_until_complete(self._runner.setup())

            site = web.TCPSite(self._runner, port=self.port)
            self._loop.run_until_complete(site.start())
            self._loop.run_forever()
        except Exception as e:
            logger.error(f"Server error: {str(e)}")
        finally:
            # Cleanup. Either object may be missing if startup failed early,
            # so only touch what was actually created.
            if self._loop is not None:
                if self._runner is not None:
                    self._loop.run_until_complete(self._runner.cleanup())
                self._loop.close()

    def close(self):
        """Shutdown"""
        if self._loop is not None and self._loop.is_running():
            # The loop lives on the server thread, so ask it to stop from here.
            self._loop.call_soon_threadsafe(self._loop.stop)
            logger.info("Stopping server loop...")

        if self.thread.is_alive():
            self.thread.join(timeout=2)
            logger.info("Server thread stopped")

    def poll(self) -> KVPoll: ...
@@ -0,0 +1,77 @@
1
+ import json
2
+ import logging
3
+ from dataclasses import dataclass
4
+ from typing import Optional
5
+
6
+ logger = logging.getLogger(__name__)
7
+
8
+
9
class MooncakeTransferEngine:
    """Thin wrapper around ``mooncake.engine.TransferEngine``."""

    def __init__(self, hostname: str, gpu_id: int, ib_device: Optional[str] = None):
        """Create and initialize the underlying engine.

        Raises:
            ImportError: if the ``mooncake`` package cannot be imported.
            RuntimeError: if engine initialization returns a non-zero code.
        """
        try:
            from mooncake.engine import TransferEngine
        except ImportError as e:
            raise ImportError(
                "Please install mooncake by following the instructions at "
                "https://github.com/kvcache-ai/Mooncake/blob/main/doc/en/build.md "  # noqa: E501
                "to run SGLang with MooncakeTransferEngine."
            ) from e

        self.engine = TransferEngine()
        self.hostname = hostname
        self.gpu_id = gpu_id
        self.ib_device = ib_device

        self.initialize(
            hostname=self.hostname,
            device_name=self.ib_device,
        )
        # Session id is "<hostname>:<rpc port reported by the engine>".
        rpc_port = self.engine.get_rpc_port()
        self.session_id = f"{self.hostname}:{rpc_port}"

    def _fail(self, message: str) -> None:
        """Log ``message`` at error level and raise it as a RuntimeError."""
        logger.error(message)
        raise RuntimeError(message)

    def register(self, ptr, length):
        """Register a memory region; raise RuntimeError on a non-zero code."""
        rc = self.engine.register_memory(ptr, length)
        if rc != 0:
            self._fail("Mooncake memory registration failed.")

    def deregister(self, ptr):
        """Unregister a memory region; raise RuntimeError on a non-zero code."""
        rc = self.engine.unregister_memory(ptr)
        if rc != 0:
            self._fail("Mooncake memory deregistration failed.")

    def initialize(
        self,
        hostname: str,
        device_name: Optional[str],
    ) -> None:
        """Initialize the mooncake instance."""
        # The engine receives an empty string when no device name was given.
        device = "" if device_name is None else device_name
        rc = self.engine.initialize(
            hostname,
            "P2PHANDSHAKE",
            "rdma",
            device,
        )
        if rc != 0:
            self._fail("Mooncake Transfer Engine initialization failed.")

    def transfer_sync(
        self, session_id: str, buffer: int, peer_buffer_address: int, length: int
    ) -> int:
        """Synchronously transfer data to the specified address.

        Returns the engine's (non-negative) return code; raises RuntimeError
        when the code is negative.
        """
        rc = self.engine.transfer_sync_write(
            session_id, buffer, peer_buffer_address, length
        )
        if rc < 0:
            self._fail("Mooncake Transfer Engine Return Error.")
        return rc

    def get_localhost(self):
        """Return the hostname this engine was created with."""
        return self.hostname

    def get_session_id(self):
        """Return the session id computed at construction time."""
        return self.session_id