sglang 0.4.7__py3-none-any.whl → 0.4.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (152)
  1. sglang/__init__.py +2 -0
  2. sglang/api.py +7 -0
  3. sglang/bench_one_batch.py +8 -6
  4. sglang/bench_serving.py +1 -1
  5. sglang/lang/interpreter.py +40 -1
  6. sglang/lang/ir.py +27 -0
  7. sglang/math_utils.py +8 -0
  8. sglang/srt/_custom_ops.py +2 -2
  9. sglang/srt/code_completion_parser.py +2 -44
  10. sglang/srt/configs/model_config.py +6 -0
  11. sglang/srt/constants.py +3 -0
  12. sglang/srt/conversation.py +19 -3
  13. sglang/srt/custom_op.py +5 -1
  14. sglang/srt/disaggregation/base/__init__.py +1 -1
  15. sglang/srt/disaggregation/base/conn.py +25 -11
  16. sglang/srt/disaggregation/common/__init__.py +5 -1
  17. sglang/srt/disaggregation/common/utils.py +42 -0
  18. sglang/srt/disaggregation/decode.py +211 -72
  19. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -3
  20. sglang/srt/disaggregation/fake/__init__.py +1 -1
  21. sglang/srt/disaggregation/fake/conn.py +15 -9
  22. sglang/srt/disaggregation/mini_lb.py +34 -4
  23. sglang/srt/disaggregation/mooncake/__init__.py +1 -1
  24. sglang/srt/disaggregation/mooncake/conn.py +30 -29
  25. sglang/srt/disaggregation/nixl/__init__.py +6 -1
  26. sglang/srt/disaggregation/nixl/conn.py +17 -12
  27. sglang/srt/disaggregation/prefill.py +144 -55
  28. sglang/srt/disaggregation/utils.py +155 -123
  29. sglang/srt/distributed/parallel_state.py +12 -4
  30. sglang/srt/entrypoints/engine.py +37 -29
  31. sglang/srt/entrypoints/http_server.py +153 -72
  32. sglang/srt/entrypoints/http_server_engine.py +0 -3
  33. sglang/srt/entrypoints/openai/__init__.py +0 -0
  34. sglang/srt/{openai_api → entrypoints/openai}/protocol.py +84 -10
  35. sglang/srt/entrypoints/openai/serving_base.py +149 -0
  36. sglang/srt/entrypoints/openai/serving_chat.py +921 -0
  37. sglang/srt/entrypoints/openai/serving_completions.py +424 -0
  38. sglang/srt/entrypoints/openai/serving_embedding.py +169 -0
  39. sglang/srt/entrypoints/openai/serving_rerank.py +102 -0
  40. sglang/srt/entrypoints/openai/serving_score.py +61 -0
  41. sglang/srt/entrypoints/openai/usage_processor.py +81 -0
  42. sglang/srt/entrypoints/openai/utils.py +72 -0
  43. sglang/srt/eplb_simulator/__init__.py +1 -0
  44. sglang/srt/eplb_simulator/reader.py +51 -0
  45. sglang/srt/function_call/base_format_detector.py +7 -4
  46. sglang/srt/function_call/deepseekv3_detector.py +1 -1
  47. sglang/srt/function_call/ebnf_composer.py +64 -10
  48. sglang/srt/function_call/function_call_parser.py +6 -6
  49. sglang/srt/function_call/llama32_detector.py +1 -1
  50. sglang/srt/function_call/mistral_detector.py +1 -1
  51. sglang/srt/function_call/pythonic_detector.py +1 -1
  52. sglang/srt/function_call/qwen25_detector.py +1 -1
  53. sglang/srt/{openai_api/utils.py → jinja_template_utils.py} +6 -5
  54. sglang/srt/layers/activation.py +40 -3
  55. sglang/srt/layers/attention/aiter_backend.py +20 -4
  56. sglang/srt/layers/attention/base_attn_backend.py +1 -1
  57. sglang/srt/layers/attention/cutlass_mla_backend.py +39 -15
  58. sglang/srt/layers/attention/flashattention_backend.py +71 -72
  59. sglang/srt/layers/attention/flashinfer_backend.py +10 -8
  60. sglang/srt/layers/attention/flashinfer_mla_backend.py +29 -28
  61. sglang/srt/layers/attention/flashmla_backend.py +7 -12
  62. sglang/srt/layers/attention/tbo_backend.py +3 -3
  63. sglang/srt/layers/attention/triton_backend.py +138 -130
  64. sglang/srt/layers/attention/triton_ops/decode_attention.py +2 -7
  65. sglang/srt/layers/attention/vision.py +51 -24
  66. sglang/srt/layers/communicator.py +28 -10
  67. sglang/srt/layers/dp_attention.py +11 -2
  68. sglang/srt/layers/layernorm.py +29 -2
  69. sglang/srt/layers/linear.py +0 -4
  70. sglang/srt/layers/logits_processor.py +2 -14
  71. sglang/srt/layers/moe/ep_moe/kernels.py +165 -7
  72. sglang/srt/layers/moe/ep_moe/layer.py +249 -33
  73. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +11 -37
  74. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  75. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +7 -4
  76. sglang/srt/layers/moe/fused_moe_triton/layer.py +75 -12
  77. sglang/srt/layers/moe/topk.py +107 -12
  78. sglang/srt/layers/pooler.py +56 -0
  79. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +6 -2
  80. sglang/srt/layers/quantization/deep_gemm_wrapper/__init__.py +1 -0
  81. sglang/srt/layers/quantization/{deep_gemm.py → deep_gemm_wrapper/compile_utils.py} +23 -80
  82. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +32 -0
  83. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +110 -0
  84. sglang/srt/layers/quantization/fp8.py +25 -17
  85. sglang/srt/layers/quantization/fp8_kernel.py +44 -15
  86. sglang/srt/layers/quantization/fp8_utils.py +87 -22
  87. sglang/srt/layers/quantization/modelopt_quant.py +62 -8
  88. sglang/srt/layers/quantization/utils.py +5 -2
  89. sglang/srt/layers/radix_attention.py +2 -3
  90. sglang/srt/layers/rotary_embedding.py +42 -2
  91. sglang/srt/layers/sampler.py +1 -1
  92. sglang/srt/lora/lora_manager.py +249 -105
  93. sglang/srt/lora/mem_pool.py +53 -50
  94. sglang/srt/lora/utils.py +1 -1
  95. sglang/srt/managers/cache_controller.py +33 -14
  96. sglang/srt/managers/io_struct.py +31 -10
  97. sglang/srt/managers/multimodal_processors/base_processor.py +2 -2
  98. sglang/srt/managers/multimodal_processors/vila.py +85 -0
  99. sglang/srt/managers/schedule_batch.py +79 -37
  100. sglang/srt/managers/schedule_policy.py +70 -56
  101. sglang/srt/managers/scheduler.py +220 -79
  102. sglang/srt/managers/template_manager.py +226 -0
  103. sglang/srt/managers/tokenizer_manager.py +40 -10
  104. sglang/srt/managers/tp_worker.py +12 -2
  105. sglang/srt/managers/tp_worker_overlap_thread.py +11 -0
  106. sglang/srt/mem_cache/{paged_allocator.py → allocator.py} +125 -34
  107. sglang/srt/mem_cache/base_prefix_cache.py +52 -8
  108. sglang/srt/mem_cache/chunk_cache.py +11 -15
  109. sglang/srt/mem_cache/hiradix_cache.py +38 -25
  110. sglang/srt/mem_cache/memory_pool.py +213 -505
  111. sglang/srt/mem_cache/memory_pool_host.py +380 -0
  112. sglang/srt/mem_cache/radix_cache.py +56 -28
  113. sglang/srt/model_executor/cuda_graph_runner.py +198 -100
  114. sglang/srt/model_executor/forward_batch_info.py +32 -10
  115. sglang/srt/model_executor/model_runner.py +28 -12
  116. sglang/srt/model_loader/loader.py +16 -2
  117. sglang/srt/model_loader/weight_utils.py +11 -2
  118. sglang/srt/models/bert.py +113 -13
  119. sglang/srt/models/deepseek_nextn.py +29 -27
  120. sglang/srt/models/deepseek_v2.py +213 -173
  121. sglang/srt/models/glm4.py +312 -0
  122. sglang/srt/models/internvl.py +46 -102
  123. sglang/srt/models/mimo_mtp.py +2 -18
  124. sglang/srt/models/roberta.py +117 -9
  125. sglang/srt/models/vila.py +305 -0
  126. sglang/srt/reasoning_parser.py +21 -11
  127. sglang/srt/sampling/sampling_batch_info.py +24 -0
  128. sglang/srt/sampling/sampling_params.py +2 -0
  129. sglang/srt/server_args.py +351 -238
  130. sglang/srt/speculative/build_eagle_tree.py +1 -1
  131. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +131 -9
  132. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +130 -14
  133. sglang/srt/speculative/eagle_utils.py +468 -116
  134. sglang/srt/speculative/eagle_worker.py +258 -84
  135. sglang/srt/torch_memory_saver_adapter.py +19 -15
  136. sglang/srt/two_batch_overlap.py +4 -2
  137. sglang/srt/utils.py +235 -11
  138. sglang/test/attention/test_prefix_chunk_info.py +2 -0
  139. sglang/test/runners.py +38 -3
  140. sglang/test/test_block_fp8.py +1 -0
  141. sglang/test/test_block_fp8_deep_gemm_blackwell.py +252 -0
  142. sglang/test/test_block_fp8_ep.py +2 -0
  143. sglang/test/test_utils.py +4 -1
  144. sglang/utils.py +9 -0
  145. sglang/version.py +1 -1
  146. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/METADATA +8 -14
  147. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/RECORD +150 -128
  148. sglang/srt/entrypoints/verl_engine.py +0 -179
  149. sglang/srt/openai_api/adapter.py +0 -1990
  150. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/WHEEL +0 -0
  151. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/licenses/LICENSE +0 -0
  152. {sglang-0.4.7.dist-info → sglang-0.4.8.dist-info}/top_level.txt +0 -0
@@ -18,6 +18,10 @@ from fastapi.responses import ORJSONResponse, Response, StreamingResponse
18
18
 
19
19
  from sglang.srt.disaggregation.utils import PDRegistryRequest
20
20
 
21
+ AIOHTTP_STREAM_READ_CHUNK_SIZE = (
22
+ 1024 * 64
23
+ ) # 64KB, to prevent aiohttp's "Chunk too big" error
24
+
21
25
 
22
26
  def setup_logger():
23
27
  logger = logging.getLogger("pdlb")
@@ -154,7 +158,9 @@ class MiniLoadBalancer:
154
158
  else:
155
159
  yield chunk
156
160
  else:
157
- async for chunk in decode_response.content:
161
+ async for chunk in decode_response.content.iter_chunked(
162
+ AIOHTTP_STREAM_READ_CHUNK_SIZE
163
+ ):
158
164
  yield chunk
159
165
 
160
166
  return StreamingResponse(
@@ -212,15 +218,39 @@ async def get_server_info():
212
218
  )
213
219
  prefill_infos = []
214
220
  decode_infos = []
221
+ all_internal_states = []
222
+
215
223
  async with aiohttp.ClientSession() as session:
216
224
  for server in chain(prefill_servers):
217
225
  server_info = await session.get(f"{server}/get_server_info")
218
226
  prefill_infos.append(await server_info.json())
219
227
  for server in chain(decode_servers):
220
228
  server_info = await session.get(f"{server}/get_server_info")
221
- decode_infos.append(await server_info.json())
222
-
223
- return {"prefill": prefill_infos, "decode": decode_infos}
229
+ info_json = await server_info.json()
230
+ decode_infos.append(info_json)
231
+ # Extract internal_states from decode servers
232
+ if "internal_states" in info_json:
233
+ all_internal_states.extend(info_json["internal_states"])
234
+
235
+ # Return format expected by bench_one_batch_server.py
236
+ if all_internal_states:
237
+ return {
238
+ "internal_states": all_internal_states,
239
+ "prefill": prefill_infos,
240
+ "decode": decode_infos,
241
+ }
242
+ else:
243
+ # Fallback with dummy data if no internal states found
244
+ return {
245
+ "internal_states": [
246
+ {
247
+ "last_gen_throughput": 0.0,
248
+ "avg_spec_accept_length": None,
249
+ }
250
+ ],
251
+ "prefill": prefill_infos,
252
+ "decode": decode_infos,
253
+ }
224
254
 
225
255
 
226
256
  @app.get("/get_model_info")
@@ -1,4 +1,4 @@
1
- from .conn import (
1
+ from sglang.srt.disaggregation.mooncake.conn import (
2
2
  MooncakeKVBootstrapServer,
3
3
  MooncakeKVManager,
4
4
  MooncakeKVReceiver,
@@ -28,19 +28,14 @@ from sglang.srt.disaggregation.base.conn import (
28
28
  KVArgs,
29
29
  KVPoll,
30
30
  )
31
- from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
32
- from sglang.srt.disaggregation.utils import (
33
- DisaggregationMode,
31
+ from sglang.srt.disaggregation.common.utils import (
34
32
  FastQueue,
35
33
  group_concurrent_contiguous,
36
34
  )
35
+ from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
36
+ from sglang.srt.disaggregation.utils import DisaggregationMode
37
37
  from sglang.srt.server_args import ServerArgs
38
- from sglang.srt.utils import (
39
- get_free_port,
40
- get_int_env_var,
41
- get_ip,
42
- get_local_ip_by_remote,
43
- )
38
+ from sglang.srt.utils import get_free_port, get_int_env_var, get_ip, get_local_ip_auto
44
39
 
45
40
  logger = logging.getLogger(__name__)
46
41
 
@@ -59,7 +54,7 @@ class KVTransferError(Exception):
59
54
  @dataclasses.dataclass
60
55
  class TransferKVChunk:
61
56
  room: int
62
- prefill_kv_indices: npt.NDArray[np.int64]
57
+ prefill_kv_indices: npt.NDArray[np.int32]
63
58
  index_slice: slice
64
59
  is_last: bool
65
60
  prefill_aux_index: Optional[int]
@@ -72,7 +67,7 @@ class TransferInfo:
72
67
  endpoint: str
73
68
  dst_port: int
74
69
  mooncake_session_id: str
75
- dst_kv_indices: npt.NDArray[np.int64]
70
+ dst_kv_indices: npt.NDArray[np.int32]
76
71
  dst_aux_index: int
77
72
  required_dst_info_num: int
78
73
  is_dummy: bool
@@ -81,10 +76,10 @@ class TransferInfo:
81
76
  def from_zmq(cls, msg: List[bytes]):
82
77
  if msg[4] == b"" and msg[5] == b"":
83
78
  is_dummy = True
84
- dst_kv_indices = np.array([], dtype=np.int64)
79
+ dst_kv_indices = np.array([], dtype=np.int32)
85
80
  dst_aux_index = None
86
81
  else:
87
- dst_kv_indices = np.frombuffer(msg[4], dtype=np.int64)
82
+ dst_kv_indices = np.frombuffer(msg[4], dtype=np.int32)
88
83
  dst_aux_index = int(msg[5].decode("ascii"))
89
84
  is_dummy = False
90
85
  return cls(
@@ -130,8 +125,9 @@ class MooncakeKVManager(BaseKVManager):
130
125
  is_mla_backend: Optional[bool] = False,
131
126
  ):
132
127
  self.kv_args = args
128
+ self.local_ip = get_local_ip_auto()
133
129
  self.engine = MooncakeTransferEngine(
134
- hostname=get_local_ip_by_remote(),
130
+ hostname=self.local_ip,
135
131
  gpu_id=self.kv_args.gpu_id,
136
132
  ib_device=self.kv_args.ib_device,
137
133
  )
@@ -233,9 +229,9 @@ class MooncakeKVManager(BaseKVManager):
233
229
  def send_kvcache(
234
230
  self,
235
231
  mooncake_session_id: str,
236
- prefill_kv_indices: npt.NDArray[np.int64],
232
+ prefill_kv_indices: npt.NDArray[np.int32],
237
233
  dst_kv_ptrs: list[int],
238
- dst_kv_indices: npt.NDArray[np.int64],
234
+ dst_kv_indices: npt.NDArray[np.int32],
239
235
  executor: concurrent.futures.ThreadPoolExecutor,
240
236
  ):
241
237
  # Group by indices
@@ -432,7 +428,7 @@ class MooncakeKVManager(BaseKVManager):
432
428
 
433
429
  def start_prefill_thread(self):
434
430
  self.rank_port = get_free_port()
435
- self.server_socket.bind(f"tcp://{get_local_ip_by_remote()}:{self.rank_port}")
431
+ self.server_socket.bind(f"tcp://{self.local_ip}:{self.rank_port}")
436
432
 
437
433
  def bootstrap_thread():
438
434
  """This thread recvs pre-alloc notification from the decode engine"""
@@ -471,7 +467,7 @@ class MooncakeKVManager(BaseKVManager):
471
467
 
472
468
  def start_decode_thread(self):
473
469
  self.rank_port = get_free_port()
474
- self.server_socket.bind(f"tcp://{get_local_ip_by_remote()}:{self.rank_port}")
470
+ self.server_socket.bind(f"tcp://{self.local_ip}:{self.rank_port}")
475
471
 
476
472
  def decode_thread():
477
473
  while True:
@@ -545,7 +541,7 @@ class MooncakeKVManager(BaseKVManager):
545
541
  def add_transfer_request(
546
542
  self,
547
543
  bootstrap_room: int,
548
- kv_indices: npt.NDArray[np.int64],
544
+ kv_indices: npt.NDArray[np.int32],
549
545
  index_slice: slice,
550
546
  is_last: bool,
551
547
  aux_index: Optional[int] = None,
@@ -620,7 +616,7 @@ class MooncakeKVManager(BaseKVManager):
620
616
  "role": "Prefill",
621
617
  "tp_size": self.tp_size,
622
618
  "dp_size": self.dp_size,
623
- "rank_ip": get_local_ip_by_remote(),
619
+ "rank_ip": self.local_ip,
624
620
  "rank_port": self.rank_port,
625
621
  "engine_rank": self.kv_args.engine_rank,
626
622
  }
@@ -677,7 +673,12 @@ class MooncakeKVManager(BaseKVManager):
677
673
  class MooncakeKVSender(BaseKVSender):
678
674
 
679
675
  def __init__(
680
- self, mgr: MooncakeKVManager, bootstrap_addr: str, bootstrap_room: int
676
+ self,
677
+ mgr: MooncakeKVManager,
678
+ bootstrap_addr: str,
679
+ bootstrap_room: int,
680
+ dest_tp_ranks: List[int],
681
+ pp_rank: int,
681
682
  ):
682
683
  self.kv_mgr = mgr
683
684
  self.bootstrap_room = bootstrap_room
@@ -696,7 +697,7 @@ class MooncakeKVSender(BaseKVSender):
696
697
 
697
698
  def send(
698
699
  self,
699
- kv_indices: npt.NDArray[np.int64],
700
+ kv_indices: npt.NDArray[np.int32],
700
701
  ):
701
702
  index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices))
702
703
  self.curr_idx += len(kv_indices)
@@ -741,12 +742,12 @@ class MooncakeKVSender(BaseKVSender):
741
742
  self.kv_mgr.request_status.pop(self.bootstrap_room)
742
743
 
743
744
  def failure_exception(self):
744
- self.clear()
745
-
746
745
  # Explicitly set the status to failure since this request has failed in another rank
747
746
  if self.conclude_state is None:
748
747
  self.conclude_state = KVPoll.Failed
749
748
 
749
+ self.clear()
750
+
750
751
  with self.kv_mgr.failure_lock:
751
752
  failure_reason = self.kv_mgr.failure_records.pop(
752
753
  self.bootstrap_room, "Failed due to an unknown reason from another rank"
@@ -948,7 +949,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
948
949
  sock.send_multipart(
949
950
  [
950
951
  "None".encode("ascii"),
951
- get_local_ip_by_remote().encode("ascii"),
952
+ self.kv_mgr.local_ip.encode("ascii"),
952
953
  str(self.kv_mgr.rank_port).encode("ascii"),
953
954
  self.session_id.encode("ascii"),
954
955
  packed_kv_data_ptrs,
@@ -966,7 +967,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
966
967
  cls._socket_locks[endpoint] = threading.Lock()
967
968
  return cls._socket_cache[endpoint], cls._socket_locks[endpoint]
968
969
 
969
- def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
970
+ def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
970
971
  for bootstrap_info in self.bootstrap_infos:
971
972
  self.prefill_server_url = (
972
973
  f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"
@@ -978,7 +979,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
978
979
  sock.send_multipart(
979
980
  [
980
981
  str(self.bootstrap_room).encode("ascii"),
981
- get_local_ip_by_remote().encode("ascii"),
982
+ self.kv_mgr.local_ip.encode("ascii"),
982
983
  str(self.kv_mgr.rank_port).encode("ascii"),
983
984
  self.session_id.encode("ascii"),
984
985
  kv_indices.tobytes() if not is_dummy else b"",
@@ -1002,12 +1003,12 @@ class MooncakeKVReceiver(BaseKVReceiver):
1002
1003
  self.kv_mgr.request_status.pop(self.bootstrap_room)
1003
1004
 
1004
1005
  def failure_exception(self):
1005
- self.clear()
1006
-
1007
1006
  # Explicitly set the status to failure since this request has failed in another rank
1008
1007
  if self.conclude_state is None:
1009
1008
  self.conclude_state = KVPoll.Failed
1010
1009
 
1010
+ self.clear()
1011
+
1011
1012
  with self.kv_mgr.failure_lock:
1012
1013
  failure_reason = self.kv_mgr.failure_records.pop(
1013
1014
  self.bootstrap_room, "Failed due to an unknown reason from another rank"
@@ -1 +1,6 @@
1
- from .conn import NixlKVBootstrapServer, NixlKVManager, NixlKVReceiver, NixlKVSender
1
+ from sglang.srt.disaggregation.nixl.conn import (
2
+ NixlKVBootstrapServer,
3
+ NixlKVManager,
4
+ NixlKVReceiver,
5
+ NixlKVSender,
6
+ )
@@ -24,10 +24,8 @@ from sglang.srt.disaggregation.common.conn import (
24
24
  CommonKVManager,
25
25
  CommonKVReceiver,
26
26
  )
27
- from sglang.srt.disaggregation.utils import (
28
- DisaggregationMode,
29
- group_concurrent_contiguous,
30
- )
27
+ from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous
28
+ from sglang.srt.disaggregation.utils import DisaggregationMode
31
29
  from sglang.srt.server_args import ServerArgs
32
30
  from sglang.srt.utils import get_local_ip_by_remote
33
31
 
@@ -46,7 +44,7 @@ class TransferInfo:
46
44
  agent_metadata: bytes
47
45
  agent_name: str
48
46
  dst_kv_ptrs: list[int]
49
- dst_kv_indices: npt.NDArray[np.int64]
47
+ dst_kv_indices: npt.NDArray[np.int32]
50
48
  dst_aux_ptrs: list[int]
51
49
  dst_aux_index: int
52
50
  dst_gpu_id: int
@@ -64,7 +62,7 @@ class TransferInfo:
64
62
  agent_metadata=msg[3],
65
63
  agent_name=msg[4].decode("ascii"),
66
64
  dst_kv_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
67
- dst_kv_indices=np.frombuffer(msg[6], dtype=np.int64),
65
+ dst_kv_indices=np.frombuffer(msg[6], dtype=np.int32),
68
66
  dst_aux_ptrs=list(struct.unpack(f"{len(msg[7])//8}Q", msg[7])),
69
67
  dst_aux_index=int(msg[8].decode("ascii")),
70
68
  dst_gpu_id=int(msg[9].decode("ascii")),
@@ -164,9 +162,9 @@ class NixlKVManager(CommonKVManager):
164
162
  def send_kvcache(
165
163
  self,
166
164
  peer_name: str,
167
- prefill_kv_indices: npt.NDArray[np.int64],
165
+ prefill_kv_indices: npt.NDArray[np.int32],
168
166
  dst_kv_ptrs: list[int],
169
- dst_kv_indices: npt.NDArray[np.int64],
167
+ dst_kv_indices: npt.NDArray[np.int32],
170
168
  dst_gpu_id: int,
171
169
  notif: str,
172
170
  ):
@@ -248,7 +246,7 @@ class NixlKVManager(CommonKVManager):
248
246
  def add_transfer_request(
249
247
  self,
250
248
  bootstrap_room: int,
251
- kv_indices: npt.NDArray[np.int64],
249
+ kv_indices: npt.NDArray[np.int32],
252
250
  index_slice: slice,
253
251
  is_last: bool,
254
252
  chunk_id: int,
@@ -350,7 +348,14 @@ class NixlKVManager(CommonKVManager):
350
348
 
351
349
  class NixlKVSender(BaseKVSender):
352
350
 
353
- def __init__(self, mgr: NixlKVManager, bootstrap_addr: str, bootstrap_room: int):
351
+ def __init__(
352
+ self,
353
+ mgr: NixlKVManager,
354
+ bootstrap_addr: str,
355
+ bootstrap_room: int,
356
+ dest_tp_ranks: List[int],
357
+ pp_rank: int,
358
+ ):
354
359
  self.kv_mgr = mgr
355
360
  self.bootstrap_room = bootstrap_room
356
361
  self.aux_index = None
@@ -368,7 +373,7 @@ class NixlKVSender(BaseKVSender):
368
373
 
369
374
  def send(
370
375
  self,
371
- kv_indices: npt.NDArray[np.int64],
376
+ kv_indices: npt.NDArray[np.int32],
372
377
  ):
373
378
  index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices))
374
379
  self.curr_idx += len(kv_indices)
@@ -412,7 +417,7 @@ class NixlKVReceiver(CommonKVReceiver):
412
417
  self.started_transfer = False
413
418
  super().__init__(mgr, bootstrap_addr, bootstrap_room, data_parallel_rank)
414
419
 
415
- def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
420
+ def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
416
421
  for bootstrap_info in self.bootstrap_infos:
417
422
  self.prefill_server_url = (
418
423
  f"{bootstrap_info['rank_ip']}:{bootstrap_info['rank_port']}"