sglang 0.4.7__py3-none-any.whl → 0.4.8__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -0
- sglang/api.py +7 -0
- sglang/bench_one_batch.py +8 -6
- sglang/bench_serving.py +1 -1
- sglang/lang/interpreter.py +40 -1
- sglang/lang/ir.py +27 -0
- sglang/math_utils.py +8 -0
- sglang/srt/_custom_ops.py +2 -2
- sglang/srt/code_completion_parser.py +2 -44
- sglang/srt/configs/model_config.py +6 -0
- sglang/srt/constants.py +3 -0
- sglang/srt/conversation.py +19 -3
- sglang/srt/custom_op.py +5 -1
- sglang/srt/disaggregation/base/__init__.py +1 -1
- sglang/srt/disaggregation/base/conn.py +25 -11
- sglang/srt/disaggregation/common/__init__.py +5 -1
- sglang/srt/disaggregation/common/utils.py +42 -0
- sglang/srt/disaggregation/decode.py +211 -72
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
- sglang/srt/disaggregation/fake/__init__.py +1 -1
- sglang/srt/disaggregation/fake/conn.py +15 -9
- sglang/srt/disaggregation/mini_lb.py +34 -4
- sglang/srt/disaggregation/mooncake/__init__.py +1 -1
- sglang/srt/disaggregation/mooncake/conn.py +30 -29
- sglang/srt/disaggregation/nixl/__init__.py +6 -1
- sglang/srt/disaggregation/nixl/conn.py +17 -12
- sglang/srt/disaggregation/prefill.py +144 -55
- sglang/srt/disaggregation/utils.py +155 -123
- sglang/srt/distributed/parallel_state.py +12 -4
- sglang/srt/entrypoints/engine.py +37 -29
- sglang/srt/entrypoints/http_server.py +153 -72
- sglang/srt/entrypoints/http_server_engine.py +0 -3
- sglang/srt/entrypoints/openai/__init__.py +0 -0
- sglang/srt/{openai_api → entrypoints/openai}/protocol.py +84 -10
- sglang/srt/entrypoints/openai/serving_base.py +149 -0
- sglang/srt/entrypoints/openai/serving_chat.py +921 -0
- sglang/srt/entrypoints/openai/serving_completions.py +424 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
- sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
- sglang/srt/entrypoints/openai/serving_score.py +61 -0
- sglang/srt/entrypoints/openai/usage_processor.py +81 -0
- sglang/srt/entrypoints/openai/utils.py +72 -0
- sglang/srt/eplb_simulator/__init__.py +1 -0
- sglang/srt/eplb_simulator/reader.py +51 -0
- sglang/srt/function_call/base_format_detector.py +7 -4
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +64 -10
- sglang/srt/function_call/function_call_parser.py +6 -6
- sglang/srt/function_call/llama32_detector.py +1 -1
- sglang/srt/function_call/mistral_detector.py +1 -1
- sglang/srt/function_call/pythonic_detector.py +1 -1
- sglang/srt/function_call/qwen25_detector.py +1 -1
- sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
- sglang/srt/layers/activation.py +40 -3
- sglang/srt/layers/attention/aiter_backend.py +20 -4
- sglang/srt/layers/attention/base_attn_backend.py +1 -1
- sglang/srt/layers/attention/cutlass_mla_backend.py +39 -15
- sglang/srt/layers/attention/flashattention_backend.py +71 -72
- sglang/srt/layers/attention/flashinfer_backend.py +10 -8
- sglang/srt/layers/attention/flashinfer_mla_backend.py +29 -28
- sglang/srt/layers/attention/flashmla_backend.py +7 -12
- sglang/srt/layers/attention/tbo_backend.py +3 -3
- sglang/srt/layers/attention/triton_backend.py +138 -130
- sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
- sglang/srt/layers/attention/vision.py +51 -24
- sglang/srt/layers/communicator.py +28 -10
- sglang/srt/layers/dp_attention.py +11 -2
- sglang/srt/layers/layernorm.py +29 -2
- sglang/srt/layers/linear.py +0 -4
- sglang/srt/layers/logits_processor.py +2 -14
- sglang/srt/layers/moe/ep_moe/kernels.py +165 -7
- sglang/srt/layers/moe/ep_moe/layer.py +249 -33
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -4
- sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
- sglang/srt/layers/moe/topk.py +107 -12
- sglang/srt/layers/pooler.py +56 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
- sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
- sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
- sglang/srt/layers/quantization/fp8.py +25 -17
- sglang/srt/layers/quantization/fp8_kernel.py +44 -15
- sglang/srt/layers/quantization/fp8_utils.py +87 -22
- sglang/srt/layers/quantization/modelopt_quant.py +62 -8
- sglang/srt/layers/quantization/utils.py +5 -2
- sglang/srt/layers/radix_attention.py +2 -3
- sglang/srt/layers/rotary_embedding.py +42 -2
- sglang/srt/layers/sampler.py +1 -1
- sglang/srt/lora/lora_manager.py +249 -105
- sglang/srt/lora/mem_pool.py +53 -50
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/cache_controller.py +33 -14
- sglang/srt/managers/io_struct.py +31 -10
- sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
- sglang/srt/managers/multimodal_processors/vila.py +85 -0
- sglang/srt/managers/schedule_batch.py +79 -37
- sglang/srt/managers/schedule_policy.py +70 -56
- sglang/srt/managers/scheduler.py +220 -79
- sglang/srt/managers/template_manager.py +226 -0
- sglang/srt/managers/tokenizer_manager.py +40 -10
- sglang/srt/managers/tp_worker.py +12 -2
- sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
- sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
- sglang/srt/mem_cache/base_prefix_cache.py +52 -8
- sglang/srt/mem_cache/chunk_cache.py +11 -15
- sglang/srt/mem_cache/hiradix_cache.py +38 -25
- sglang/srt/mem_cache/memory_pool.py +213 -505
- sglang/srt/mem_cache/memory_pool_host.py +380 -0
- sglang/srt/mem_cache/radix_cache.py +56 -28
- sglang/srt/model_executor/cuda_graph_runner.py +198 -100
- sglang/srt/model_executor/forward_batch_info.py +32 -10
- sglang/srt/model_executor/model_runner.py +28 -12
- sglang/srt/model_loader/loader.py +16 -2
- sglang/srt/model_loader/weight_utils.py +11 -2
- sglang/srt/models/bert.py +113 -13
- sglang/srt/models/deepseek_nextn.py +29 -27
- sglang/srt/models/deepseek_v2.py +213 -173
- sglang/srt/models/glm4.py +312 -0
- sglang/srt/models/internvl.py +46 -102
- sglang/srt/models/mimo_mtp.py +2 -18
- sglang/srt/models/roberta.py +117 -9
- sglang/srt/models/vila.py +305 -0
- sglang/srt/reasoning_parser.py +21 -11
- sglang/srt/sampling/sampling_batch_info.py +24 -0
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server_args.py +351 -238
- sglang/srt/speculative/build_eagle_tree.py +1 -1
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -9
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +130 -14
- sglang/srt/speculative/eagle_utils.py +468 -116
- sglang/srt/speculative/eagle_worker.py +258 -84
- sglang/srt/torch_memory_saver_adapter.py +19 -15
- sglang/srt/two_batch_overlap.py +4 -2
- sglang/srt/utils.py +235 -11
- sglang/test/attention/test_prefix_chunk_info.py +2 -0
- sglang/test/runners.py +38 -3
- sglang/test/test_block_fp8.py +1 -0
- sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
- sglang/test/test_block_fp8_ep.py +2 -0
- sglang/test/test_utils.py +4 -1
- sglang/utils.py +9 -0
- sglang/version.py +1 -1
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/METADATA +8 -14
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/RECORD +150 -128
- sglang/srt/entrypoints/verl_engine.py +0 -179
- sglang/srt/openai_api/adapter.py +0 -1990
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
@@ -18,6 +18,10 @@ from fastapi.responses import ORJSONResponse, Response, StreamingResponse
|
|
18
18
|
|
19
19
|
from sglang.srt.disaggregation.utils import PDRegistryRequest
|
20
20
|
|
21
|
+
AIOHTTP_STREAM_READ_CHUNK_SIZE = (
|
22
|
+
1024 * 64
|
23
|
+
) # 64KB, to prevent aiohttp's "Chunk too big" error
|
24
|
+
|
21
25
|
|
22
26
|
def setup_logger():
|
23
27
|
logger = logging.getLogger("pdlb")
|
@@ -154,7 +158,9 @@ class MiniLoadBalancer:
|
|
154
158
|
else:
|
155
159
|
yield chunk
|
156
160
|
else:
|
157
|
-
async for chunk in decode_response.content
|
161
|
+
async for chunk in decode_response.content.iter_chunked(
|
162
|
+
AIOHTTP_STREAM_READ_CHUNK_SIZE
|
163
|
+
):
|
158
164
|
yield chunk
|
159
165
|
|
160
166
|
return StreamingResponse(
|
@@ -212,15 +218,39 @@ async def get_server_info():
|
|
212
218
|
)
|
213
219
|
prefill_infos = []
|
214
220
|
decode_infos = []
|
221
|
+
all_internal_states = []
|
222
|
+
|
215
223
|
async with aiohttp.ClientSession() as session:
|
216
224
|
for server in chain(prefill_servers):
|
217
225
|
server_info = await session.get(f"{server}/get_server_info")
|
218
226
|
prefill_infos.append(await server_info.json())
|
219
227
|
for server in chain(decode_servers):
|
220
228
|
server_info = await session.get(f"{server}/get_server_info")
|
221
|
-
|
222
|
-
|
223
|
-
|
229
|
+
info_json = await server_info.json()
|
230
|
+
decode_infos.append(info_json)
|
231
|
+
# Extract internal_states from decode servers
|
232
|
+
if "internal_states" in info_json:
|
233
|
+
all_internal_states.extend(info_json["internal_states"])
|
234
|
+
|
235
|
+
# Return format expected by bench_one_batch_server.py
|
236
|
+
if all_internal_states:
|
237
|
+
return {
|
238
|
+
"internal_states": all_internal_states,
|
239
|
+
"prefill": prefill_infos,
|
240
|
+
"decode": decode_infos,
|
241
|
+
}
|
242
|
+
else:
|
243
|
+
# Fallback with dummy data if no internal states found
|
244
|
+
return {
|
245
|
+
"internal_states": [
|
246
|
+
{
|
247
|
+
"last_gen_throughput": 0.0,
|
248
|
+
"avg_spec_accept_length": None,
|
249
|
+
}
|
250
|
+
],
|
251
|
+
"prefill": prefill_infos,
|
252
|
+
"decode": decode_infos,
|
253
|
+
}
|
224
254
|
|
225
255
|
|
226
256
|
@app.get("/get_model_info")
|
@@ -28,19 +28,14 @@ from sglang.srt.disaggregation.base.conn import (
|
|
28
28
|
KVArgs,
|
29
29
|
KVPoll,
|
30
30
|
)
|
31
|
-
from sglang.srt.disaggregation.
|
32
|
-
from sglang.srt.disaggregation.utils import (
|
33
|
-
DisaggregationMode,
|
31
|
+
from sglang.srt.disaggregation.common.utils import (
|
34
32
|
FastQueue,
|
35
33
|
group_concurrent_contiguous,
|
36
34
|
)
|
35
|
+
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
|
36
|
+
from sglang.srt.disaggregation.utils import DisaggregationMode
|
37
37
|
from sglang.srt.server_args import ServerArgs
|
38
|
-
from sglang.srt.utils import
|
39
|
-
get_free_port,
|
40
|
-
get_int_env_var,
|
41
|
-
get_ip,
|
42
|
-
get_local_ip_by_remote,
|
43
|
-
)
|
38
|
+
from sglang.srt.utils import get_free_port, get_int_env_var, get_ip, get_local_ip_auto
|
44
39
|
|
45
40
|
logger = logging.getLogger(__name__)
|
46
41
|
|
@@ -59,7 +54,7 @@ class KVTransferError(Exception):
|
|
59
54
|
@dataclasses.dataclass
|
60
55
|
class TransferKVChunk:
|
61
56
|
room: int
|
62
|
-
prefill_kv_indices: npt.NDArray[np.
|
57
|
+
prefill_kv_indices: npt.NDArray[np.int32]
|
63
58
|
index_slice: slice
|
64
59
|
is_last: bool
|
65
60
|
prefill_aux_index: Optional[int]
|
@@ -72,7 +67,7 @@ class TransferInfo:
|
|
72
67
|
endpoint: str
|
73
68
|
dst_port: int
|
74
69
|
mooncake_session_id: str
|
75
|
-
dst_kv_indices: npt.NDArray[np.
|
70
|
+
dst_kv_indices: npt.NDArray[np.int32]
|
76
71
|
dst_aux_index: int
|
77
72
|
required_dst_info_num: int
|
78
73
|
is_dummy: bool
|
@@ -81,10 +76,10 @@ class TransferInfo:
|
|
81
76
|
def from_zmq(cls, msg: List[bytes]):
|
82
77
|
if msg[4] == b"" and msg[5] == b"":
|
83
78
|
is_dummy = True
|
84
|
-
dst_kv_indices = np.array([], dtype=np.
|
79
|
+
dst_kv_indices = np.array([], dtype=np.int32)
|
85
80
|
dst_aux_index = None
|
86
81
|
else:
|
87
|
-
dst_kv_indices = np.frombuffer(msg[4], dtype=np.
|
82
|
+
dst_kv_indices = np.frombuffer(msg[4], dtype=np.int32)
|
88
83
|
dst_aux_index = int(msg[5].decode("ascii"))
|
89
84
|
is_dummy = False
|
90
85
|
return cls(
|
@@ -130,8 +125,9 @@ class MooncakeKVManager(BaseKVManager):
|
|
130
125
|
is_mla_backend: Optional[bool] = False,
|
131
126
|
):
|
132
127
|
self.kv_args = args
|
128
|
+
self.local_ip = get_local_ip_auto()
|
133
129
|
self.engine = MooncakeTransferEngine(
|
134
|
-
hostname=
|
130
|
+
hostname=self.local_ip,
|
135
131
|
gpu_id=self.kv_args.gpu_id,
|
136
132
|
ib_device=self.kv_args.ib_device,
|
137
133
|
)
|
@@ -233,9 +229,9 @@ class MooncakeKVManager(BaseKVManager):
|
|
233
229
|
def send_kvcache(
|
234
230
|
self,
|
235
231
|
mooncake_session_id: str,
|
236
|
-
prefill_kv_indices: npt.NDArray[np.
|
232
|
+
prefill_kv_indices: npt.NDArray[np.int32],
|
237
233
|
dst_kv_ptrs: list[int],
|
238
|
-
dst_kv_indices: npt.NDArray[np.
|
234
|
+
dst_kv_indices: npt.NDArray[np.int32],
|
239
235
|
executor: concurrent.futures.ThreadPoolExecutor,
|
240
236
|
):
|
241
237
|
# Group by indices
|
@@ -432,7 +428,7 @@ class MooncakeKVManager(BaseKVManager):
|
|
432
428
|
|
433
429
|
def start_prefill_thread(self):
|
434
430
|
self.rank_port = get_free_port()
|
435
|
-
self.server_socket.bind(f"tcp://{
|
431
|
+
self.server_socket.bind(f"tcp://{self.local_ip}:{self.rank_port}")
|
436
432
|
|
437
433
|
def bootstrap_thread():
|
438
434
|
"""This thread recvs pre-alloc notification from the decode engine"""
|
@@ -471,7 +467,7 @@ class MooncakeKVManager(BaseKVManager):
|
|
471
467
|
|
472
468
|
def start_decode_thread(self):
|
473
469
|
self.rank_port = get_free_port()
|
474
|
-
self.server_socket.bind(f"tcp://{
|
470
|
+
self.server_socket.bind(f"tcp://{self.local_ip}:{self.rank_port}")
|
475
471
|
|
476
472
|
def decode_thread():
|
477
473
|
while True:
|
@@ -545,7 +541,7 @@ class MooncakeKVManager(BaseKVManager):
|
|
545
541
|
def add_transfer_request(
|
546
542
|
self,
|
547
543
|
bootstrap_room: int,
|
548
|
-
kv_indices: npt.NDArray[np.
|
544
|
+
kv_indices: npt.NDArray[np.int32],
|
549
545
|
index_slice: slice,
|
550
546
|
is_last: bool,
|
551
547
|
aux_index: Optional[int] = None,
|
@@ -620,7 +616,7 @@ class MooncakeKVManager(BaseKVManager):
|
|
620
616
|
"role": "Prefill",
|
621
617
|
"tp_size": self.tp_size,
|
622
618
|
"dp_size": self.dp_size,
|
623
|
-
"rank_ip":
|
619
|
+
"rank_ip": self.local_ip,
|
624
620
|
"rank_port": self.rank_port,
|
625
621
|
"engine_rank": self.kv_args.engine_rank,
|
626
622
|
}
|
@@ -677,7 +673,12 @@ class MooncakeKVManager(BaseKVManager):
|
|
677
673
|
class MooncakeKVSender(BaseKVSender):
|
678
674
|
|
679
675
|
def __init__(
|
680
|
-
self,
|
676
|
+
self,
|
677
|
+
mgr: MooncakeKVManager,
|
678
|
+
bootstrap_addr: str,
|
679
|
+
bootstrap_room: int,
|
680
|
+
dest_tp_ranks: List[int],
|
681
|
+
pp_rank: int,
|
681
682
|
):
|
682
683
|
self.kv_mgr = mgr
|
683
684
|
self.bootstrap_room = bootstrap_room
|
@@ -696,7 +697,7 @@ class MooncakeKVSender(BaseKVSender):
|
|
696
697
|
|
697
698
|
def send(
|
698
699
|
self,
|
699
|
-
kv_indices: npt.NDArray[np.
|
700
|
+
kv_indices: npt.NDArray[np.int32],
|
700
701
|
):
|
701
702
|
index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices))
|
702
703
|
self.curr_idx += len(kv_indices)
|
@@ -741,12 +742,12 @@ class MooncakeKVSender(BaseKVSender):
|
|
741
742
|
self.kv_mgr.request_status.pop(self.bootstrap_room)
|
742
743
|
|
743
744
|
def failure_exception(self):
|
744
|
-
self.clear()
|
745
|
-
|
746
745
|
# Explicitly set the status to failure since this request has failed in another rank
|
747
746
|
if self.conclude_state is None:
|
748
747
|
self.conclude_state = KVPoll.Failed
|
749
748
|
|
749
|
+
self.clear()
|
750
|
+
|
750
751
|
with self.kv_mgr.failure_lock:
|
751
752
|
failure_reason = self.kv_mgr.failure_records.pop(
|
752
753
|
self.bootstrap_room, "Failed due to an unknown reason from another rank"
|
@@ -948,7 +949,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|
948
949
|
sock.send_multipart(
|
949
950
|
[
|
950
951
|
"None".encode("ascii"),
|
951
|
-
|
952
|
+
self.kv_mgr.local_ip.encode("ascii"),
|
952
953
|
str(self.kv_mgr.rank_port).encode("ascii"),
|
953
954
|
self.session_id.encode("ascii"),
|
954
955
|
packed_kv_data_ptrs,
|
@@ -966,7 +967,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|
966
967
|
cls._socket_locks[endpoint] = threading.Lock()
|
967
968
|
return cls._socket_cache[endpoint], cls._socket_locks[endpoint]
|
968
969
|
|
969
|
-
def init(self, kv_indices: npt.NDArray[np.
|
970
|
+
def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
|
970
971
|
for bootstrap_info in self.bootstrap_infos:
|
971
972
|
self.prefill_server_url = (
|
972
973
|
f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
|
@@ -978,7 +979,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|
978
979
|
sock.send_multipart(
|
979
980
|
[
|
980
981
|
str(self.bootstrap_room).encode("ascii"),
|
981
|
-
|
982
|
+
self.kv_mgr.local_ip.encode("ascii"),
|
982
983
|
str(self.kv_mgr.rank_port).encode("ascii"),
|
983
984
|
self.session_id.encode("ascii"),
|
984
985
|
kv_indices.tobytes() if not is_dummy else b"",
|
@@ -1002,12 +1003,12 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|
1002
1003
|
self.kv_mgr.request_status.pop(self.bootstrap_room)
|
1003
1004
|
|
1004
1005
|
def failure_exception(self):
|
1005
|
-
self.clear()
|
1006
|
-
|
1007
1006
|
# Explicitly set the status to failure since this request has failed in another rank
|
1008
1007
|
if self.conclude_state is None:
|
1009
1008
|
self.conclude_state = KVPoll.Failed
|
1010
1009
|
|
1010
|
+
self.clear()
|
1011
|
+
|
1011
1012
|
with self.kv_mgr.failure_lock:
|
1012
1013
|
failure_reason = self.kv_mgr.failure_records.pop(
|
1013
1014
|
self.bootstrap_room, "Failed due to an unknown reason from another rank"
|
@@ -24,10 +24,8 @@ from sglang.srt.disaggregation.common.conn import (
|
|
24
24
|
CommonKVManager,
|
25
25
|
CommonKVReceiver,
|
26
26
|
)
|
27
|
-
from sglang.srt.disaggregation.utils import
|
28
|
-
|
29
|
-
group_concurrent_contiguous,
|
30
|
-
)
|
27
|
+
from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous
|
28
|
+
from sglang.srt.disaggregation.utils import DisaggregationMode
|
31
29
|
from sglang.srt.server_args import ServerArgs
|
32
30
|
from sglang.srt.utils import get_local_ip_by_remote
|
33
31
|
|
@@ -46,7 +44,7 @@ class TransferInfo:
|
|
46
44
|
agent_metadata: bytes
|
47
45
|
agent_name: str
|
48
46
|
dst_kv_ptrs: list[int]
|
49
|
-
dst_kv_indices: npt.NDArray[np.
|
47
|
+
dst_kv_indices: npt.NDArray[np.int32]
|
50
48
|
dst_aux_ptrs: list[int]
|
51
49
|
dst_aux_index: int
|
52
50
|
dst_gpu_id: int
|
@@ -64,7 +62,7 @@ class TransferInfo:
|
|
64
62
|
agent_metadata=msg[3],
|
65
63
|
agent_name=msg[4].decode("ascii"),
|
66
64
|
dst_kv_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
|
67
|
-
dst_kv_indices=np.frombuffer(msg[6], dtype=np.
|
65
|
+
dst_kv_indices=np.frombuffer(msg[6], dtype=np.int32),
|
68
66
|
dst_aux_ptrs=list(struct.unpack(f"{len(msg[7])//8}Q", msg[7])),
|
69
67
|
dst_aux_index=int(msg[8].decode("ascii")),
|
70
68
|
dst_gpu_id=int(msg[9].decode("ascii")),
|
@@ -164,9 +162,9 @@ class NixlKVManager(CommonKVManager):
|
|
164
162
|
def send_kvcache(
|
165
163
|
self,
|
166
164
|
peer_name: str,
|
167
|
-
prefill_kv_indices: npt.NDArray[np.
|
165
|
+
prefill_kv_indices: npt.NDArray[np.int32],
|
168
166
|
dst_kv_ptrs: list[int],
|
169
|
-
dst_kv_indices: npt.NDArray[np.
|
167
|
+
dst_kv_indices: npt.NDArray[np.int32],
|
170
168
|
dst_gpu_id: int,
|
171
169
|
notif: str,
|
172
170
|
):
|
@@ -248,7 +246,7 @@ class NixlKVManager(CommonKVManager):
|
|
248
246
|
def add_transfer_request(
|
249
247
|
self,
|
250
248
|
bootstrap_room: int,
|
251
|
-
kv_indices: npt.NDArray[np.
|
249
|
+
kv_indices: npt.NDArray[np.int32],
|
252
250
|
index_slice: slice,
|
253
251
|
is_last: bool,
|
254
252
|
chunk_id: int,
|
@@ -350,7 +348,14 @@ class NixlKVManager(CommonKVManager):
|
|
350
348
|
|
351
349
|
class NixlKVSender(BaseKVSender):
|
352
350
|
|
353
|
-
def __init__(
|
351
|
+
def __init__(
|
352
|
+
self,
|
353
|
+
mgr: NixlKVManager,
|
354
|
+
bootstrap_addr: str,
|
355
|
+
bootstrap_room: int,
|
356
|
+
dest_tp_ranks: List[int],
|
357
|
+
pp_rank: int,
|
358
|
+
):
|
354
359
|
self.kv_mgr = mgr
|
355
360
|
self.bootstrap_room = bootstrap_room
|
356
361
|
self.aux_index = None
|
@@ -368,7 +373,7 @@ class NixlKVSender(BaseKVSender):
|
|
368
373
|
|
369
374
|
def send(
|
370
375
|
self,
|
371
|
-
kv_indices: npt.NDArray[np.
|
376
|
+
kv_indices: npt.NDArray[np.int32],
|
372
377
|
):
|
373
378
|
index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices))
|
374
379
|
self.curr_idx += len(kv_indices)
|
@@ -412,7 +417,7 @@ class NixlKVReceiver(CommonKVReceiver):
|
|
412
417
|
self.started_transfer = False
|
413
418
|
super().__init__(mgr, bootstrap_addr, bootstrap_room, data_parallel_rank)
|
414
419
|
|
415
|
-
def init(self, kv_indices: npt.NDArray[np.
|
420
|
+
def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
|
416
421
|
for bootstrap_info in self.bootstrap_infos:
|
417
422
|
self.prefill_server_url = (
|
418
423
|
f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
|