sglang 0.4.5.post1__py3-none-any.whl → 0.4.5.post3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (119)
  1. sglang/__init__.py +2 -4
  2. sglang/bench_one_batch.py +2 -2
  3. sglang/bench_serving.py +3 -6
  4. sglang/compile_deep_gemm.py +136 -0
  5. sglang/lang/backend/anthropic.py +0 -4
  6. sglang/lang/backend/base_backend.py +1 -1
  7. sglang/lang/backend/openai.py +6 -2
  8. sglang/lang/backend/runtime_endpoint.py +5 -1
  9. sglang/lang/backend/vertexai.py +0 -1
  10. sglang/lang/compiler.py +1 -7
  11. sglang/lang/tracer.py +3 -7
  12. sglang/srt/_custom_ops.py +0 -2
  13. sglang/srt/configs/model_config.py +4 -1
  14. sglang/srt/constrained/outlines_jump_forward.py +14 -1
  15. sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
  16. sglang/srt/constrained/xgrammar_backend.py +27 -4
  17. sglang/srt/custom_op.py +0 -62
  18. sglang/srt/disaggregation/decode.py +105 -6
  19. sglang/srt/disaggregation/mini_lb.py +74 -9
  20. sglang/srt/disaggregation/mooncake/conn.py +33 -63
  21. sglang/srt/disaggregation/mooncake/transfer_engine.py +30 -61
  22. sglang/srt/disaggregation/nixl/__init__.py +1 -0
  23. sglang/srt/disaggregation/nixl/conn.py +622 -0
  24. sglang/srt/disaggregation/prefill.py +137 -17
  25. sglang/srt/disaggregation/utils.py +32 -0
  26. sglang/srt/entrypoints/engine.py +4 -0
  27. sglang/srt/entrypoints/http_server.py +3 -7
  28. sglang/srt/entrypoints/verl_engine.py +7 -5
  29. sglang/srt/function_call_parser.py +60 -0
  30. sglang/srt/layers/activation.py +6 -8
  31. sglang/srt/layers/attention/flashattention_backend.py +883 -209
  32. sglang/srt/layers/attention/flashinfer_backend.py +5 -2
  33. sglang/srt/layers/attention/torch_native_backend.py +6 -1
  34. sglang/srt/layers/attention/triton_backend.py +6 -0
  35. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
  36. sglang/srt/layers/attention/triton_ops/extend_attention.py +18 -7
  37. sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
  38. sglang/srt/layers/dp_attention.py +1 -1
  39. sglang/srt/layers/layernorm.py +20 -5
  40. sglang/srt/layers/linear.py +17 -3
  41. sglang/srt/layers/moe/ep_moe/layer.py +17 -29
  42. sglang/srt/layers/moe/fused_moe_native.py +4 -0
  43. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +14 -19
  44. sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
  45. sglang/srt/layers/moe/topk.py +27 -30
  46. sglang/srt/layers/parameter.py +0 -2
  47. sglang/srt/layers/quantization/__init__.py +1 -0
  48. sglang/srt/layers/quantization/blockwise_int8.py +2 -0
  49. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +9 -2
  50. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +16 -44
  51. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  52. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
  53. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
  54. sglang/srt/layers/quantization/deep_gemm.py +378 -0
  55. sglang/srt/layers/quantization/fp8.py +115 -132
  56. sglang/srt/layers/quantization/fp8_kernel.py +213 -88
  57. sglang/srt/layers/quantization/fp8_utils.py +189 -264
  58. sglang/srt/layers/quantization/gptq.py +13 -7
  59. sglang/srt/layers/quantization/modelopt_quant.py +2 -2
  60. sglang/srt/layers/quantization/moe_wna16.py +2 -0
  61. sglang/srt/layers/quantization/utils.py +5 -11
  62. sglang/srt/layers/quantization/w8a8_fp8.py +2 -0
  63. sglang/srt/layers/quantization/w8a8_int8.py +7 -7
  64. sglang/srt/layers/radix_attention.py +15 -0
  65. sglang/srt/layers/rotary_embedding.py +9 -8
  66. sglang/srt/layers/sampler.py +7 -12
  67. sglang/srt/lora/backend/base_backend.py +18 -2
  68. sglang/srt/lora/backend/flashinfer_backend.py +1 -1
  69. sglang/srt/lora/backend/triton_backend.py +1 -1
  70. sglang/srt/lora/layers.py +1 -1
  71. sglang/srt/lora/lora.py +1 -1
  72. sglang/srt/lora/lora_manager.py +1 -1
  73. sglang/srt/managers/data_parallel_controller.py +7 -1
  74. sglang/srt/managers/detokenizer_manager.py +0 -1
  75. sglang/srt/managers/io_struct.py +15 -3
  76. sglang/srt/managers/mm_utils.py +4 -3
  77. sglang/srt/managers/multimodal_processor.py +0 -2
  78. sglang/srt/managers/multimodal_processors/base_processor.py +3 -2
  79. sglang/srt/managers/schedule_batch.py +15 -4
  80. sglang/srt/managers/scheduler.py +28 -77
  81. sglang/srt/managers/tokenizer_manager.py +116 -29
  82. sglang/srt/managers/tp_worker.py +1 -0
  83. sglang/srt/mem_cache/hiradix_cache.py +41 -29
  84. sglang/srt/mem_cache/memory_pool.py +38 -15
  85. sglang/srt/model_executor/cuda_graph_runner.py +15 -10
  86. sglang/srt/model_executor/model_runner.py +39 -31
  87. sglang/srt/models/bert.py +398 -0
  88. sglang/srt/models/deepseek.py +1 -1
  89. sglang/srt/models/deepseek_nextn.py +74 -70
  90. sglang/srt/models/deepseek_v2.py +292 -348
  91. sglang/srt/models/llama.py +5 -5
  92. sglang/srt/models/minicpm3.py +31 -203
  93. sglang/srt/models/minicpmo.py +17 -6
  94. sglang/srt/models/qwen2.py +4 -1
  95. sglang/srt/models/qwen2_moe.py +14 -13
  96. sglang/srt/models/qwen3.py +335 -0
  97. sglang/srt/models/qwen3_moe.py +423 -0
  98. sglang/srt/openai_api/adapter.py +71 -4
  99. sglang/srt/openai_api/protocol.py +6 -1
  100. sglang/srt/reasoning_parser.py +0 -1
  101. sglang/srt/sampling/sampling_batch_info.py +2 -3
  102. sglang/srt/server_args.py +86 -72
  103. sglang/srt/speculative/build_eagle_tree.py +2 -2
  104. sglang/srt/speculative/eagle_utils.py +2 -2
  105. sglang/srt/speculative/eagle_worker.py +6 -14
  106. sglang/srt/utils.py +62 -6
  107. sglang/test/runners.py +5 -1
  108. sglang/test/test_block_fp8.py +167 -0
  109. sglang/test/test_custom_ops.py +1 -1
  110. sglang/test/test_utils.py +3 -1
  111. sglang/version.py +1 -1
  112. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/METADATA +5 -5
  113. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/RECORD +116 -110
  114. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/WHEEL +1 -1
  115. sglang/lang/__init__.py +0 -0
  116. sglang/srt/lora/backend/__init__.py +0 -25
  117. sglang/srt/server.py +0 -18
  118. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/licenses/LICENSE +0 -0
  119. {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/top_level.txt +0 -0
@@ -21,6 +21,7 @@ Life cycle of a request in the decode server
21
21
  from __future__ import annotations
22
22
 
23
23
  import logging
24
+ from collections import deque
24
25
  from dataclasses import dataclass
25
26
  from typing import TYPE_CHECKING, List, Optional, Tuple
26
27
 
@@ -35,6 +36,7 @@ from sglang.srt.disaggregation.utils import (
35
36
  ReqToMetadataIdxAllocator,
36
37
  TransferBackend,
37
38
  get_kv_class,
39
+ kv_to_page_indices,
38
40
  poll_and_all_reduce,
39
41
  )
40
42
  from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
@@ -121,7 +123,7 @@ class DecodePreallocQueue:
121
123
  kv_args.aux_item_lens = [
122
124
  metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
123
125
  ]
124
- kv_args.ib_device = "mock-ib-device"
126
+ kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
125
127
  kv_args.gpu_id = self.scheduler.gpu_id
126
128
  kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
127
129
  kv_manager = kv_manager_class(
@@ -205,7 +207,10 @@ class DecodePreallocQueue:
205
207
  self.req_to_metadata_buffer_idx_allocator.alloc()
206
208
  )
207
209
  assert decode_req.metadata_buffer_index is not None
208
- decode_req.kv_receiver.init(kv_indices, decode_req.metadata_buffer_index)
210
+ page_indices = kv_to_page_indices(
211
+ kv_indices, self.token_to_kv_pool_allocator.page_size
212
+ )
213
+ decode_req.kv_receiver.init(page_indices, decode_req.metadata_buffer_index)
209
214
  preallocated_reqs.append(decode_req)
210
215
  indices_to_remove.add(i)
211
216
 
@@ -245,10 +250,30 @@ class DecodePreallocQueue:
245
250
  assert req_pool_indices is not None
246
251
 
247
252
  req.req_pool_idx = req_pool_indices[0]
248
- kv_loc = self.token_to_kv_pool_allocator.alloc(
249
- len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
250
- )
251
-
253
+ if self.token_to_kv_pool_allocator.page_size == 1:
254
+ kv_loc = self.token_to_kv_pool_allocator.alloc(
255
+ len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
256
+ )
257
+ else:
258
+ num_tokens = len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
259
+ kv_loc = self.token_to_kv_pool_allocator.alloc_extend(
260
+ prefix_lens=torch.tensor(
261
+ [0],
262
+ dtype=torch.int64,
263
+ device=self.token_to_kv_pool_allocator.device,
264
+ ),
265
+ seq_lens=torch.tensor(
266
+ [num_tokens],
267
+ dtype=torch.int64,
268
+ device=self.token_to_kv_pool_allocator.device,
269
+ ),
270
+ last_loc=torch.tensor(
271
+ [-1],
272
+ dtype=torch.int64,
273
+ device=self.token_to_kv_pool_allocator.device,
274
+ ),
275
+ extend_num_tokens=num_tokens,
276
+ )
252
277
  assert kv_loc is not None
253
278
 
254
279
  self.req_to_token_pool.write((req.req_pool_idx, slice(0, len(kv_loc))), kv_loc)
@@ -419,6 +444,80 @@ class ScheduleBatchDisaggregationDecodeMixin:
419
444
 
420
445
  class SchedulerDisaggregationDecodeMixin:
421
446
 
447
+ @torch.no_grad()
448
+ def event_loop_normal_disagg_decode(self):
449
+ """A normal scheduler loop for decode worker in disaggregation mode."""
450
+
451
+ while True:
452
+ recv_reqs = self.recv_requests()
453
+ self.process_input_requests(recv_reqs)
454
+ # polling and allocating kv cache
455
+ self.process_decode_queue()
456
+ batch = self.get_next_disagg_decode_batch_to_run()
457
+ self.cur_batch = batch
458
+
459
+ if batch:
460
+ # Generate fake extend output.
461
+ if batch.forward_mode.is_extend():
462
+ # Note: Logprobs should be handled on the prefill engine.
463
+ self.stream_output(batch.reqs, False)
464
+ else:
465
+ result = self.run_batch(batch)
466
+ self.process_batch_result(batch, result)
467
+
468
+ if batch is None and (
469
+ len(self.disagg_decode_transfer_queue.queue)
470
+ + len(self.disagg_decode_prealloc_queue.queue)
471
+ == 0
472
+ ):
473
+ # When the server is idle, do self-check and re-init some states
474
+ self.check_memory()
475
+ self.new_token_ratio = self.init_new_token_ratio
476
+
477
+ self.last_batch = batch
478
+
479
+ @torch.no_grad()
480
+ def event_loop_overlap_disagg_decode(self):
481
+ result_queue = deque()
482
+ self.last_batch: Optional[ScheduleBatch] = None
483
+ self.last_batch_is_extend = False # last batch is modifed in-place, so we need another variable to track if it's extend
484
+
485
+ while True:
486
+ recv_reqs = self.recv_requests()
487
+ self.process_input_requests(recv_reqs)
488
+ # polling and allocating kv cache
489
+ self.process_decode_queue()
490
+ batch = self.get_next_disagg_decode_batch_to_run()
491
+ self.cur_batch = batch
492
+ last_batch_is_extend = False
493
+
494
+ if batch:
495
+ # Generate fake extend output.
496
+ if batch.forward_mode.is_extend():
497
+ # Note: Logprobs should be handled on the prefill engine.
498
+ self.stream_output(batch.reqs, False)
499
+ last_batch_is_extend = True
500
+ else:
501
+ result = self.run_batch(batch)
502
+ result_queue.append((batch.copy(), result))
503
+
504
+ # Process the results of the previous batch but skip if the last batch is extend
505
+ if self.last_batch and not self.last_batch_is_extend:
506
+ tmp_batch, tmp_result = result_queue.popleft()
507
+ self.process_batch_result(tmp_batch, tmp_result)
508
+
509
+ if batch is None and (
510
+ len(self.disagg_decode_transfer_queue.queue)
511
+ + len(self.disagg_decode_prealloc_queue.queue)
512
+ == 0
513
+ ):
514
+ # When the server is idle, do self-check and re-init some states
515
+ self.check_memory()
516
+ self.new_token_ratio = self.init_new_token_ratio
517
+
518
+ self.last_batch = batch
519
+ self.last_batch_is_extend = last_batch_is_extend
520
+
422
521
  def get_next_disagg_decode_batch_to_run(
423
522
  self: Scheduler,
424
523
  ) -> Optional[Tuple[ScheduleBatch, bool]]:
@@ -23,13 +23,18 @@ class MiniLoadBalancer:
23
23
  return random.choice(self.prefill_servers), random.choice(self.decode_servers)
24
24
 
25
25
  async def generate(
26
- self, modified_request, prefill_server, decode_server
26
+ self, modified_request, prefill_server, decode_server, endpoint
27
27
  ) -> ORJSONResponse:
28
+ assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}"
28
29
 
29
- async with aiohttp.ClientSession() as session:
30
+ async with aiohttp.ClientSession(
31
+ timeout=aiohttp.ClientTimeout(
32
+ total=3600
33
+ ) # Add timeout for request reliability
34
+ ) as session:
30
35
  tasks = [
31
- session.post(f"{prefill_server}/generate", json=modified_request),
32
- session.post(f"{decode_server}/generate", json=modified_request),
36
+ session.post(f"{prefill_server}/{endpoint}", json=modified_request),
37
+ session.post(f"{decode_server}/{endpoint}", json=modified_request),
33
38
  ]
34
39
  # Wait for both responses to complete. Prefill should end first.
35
40
  prefill_response, decode_response = await asyncio.gather(*tasks)
@@ -39,7 +44,11 @@ class MiniLoadBalancer:
39
44
  status_code=decode_response.status,
40
45
  )
41
46
 
42
- async def generate_stream(self, modified_request, prefill_server, decode_server):
47
+ async def generate_stream(
48
+ self, modified_request, prefill_server, decode_server, endpoint="generate"
49
+ ):
50
+ assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}"
51
+
43
52
  async def stream_results():
44
53
  async with aiohttp.ClientSession(
45
54
  timeout=aiohttp.ClientTimeout(
@@ -50,10 +59,10 @@ class MiniLoadBalancer:
50
59
  # Create the tasks for both prefill and decode requests
51
60
  tasks = [
52
61
  session.post(
53
- f"{prefill_server}/generate", json=modified_request
62
+ f"{prefill_server}/{endpoint}", json=modified_request
54
63
  ),
55
64
  session.post(
56
- f"{decode_server}/generate", json=modified_request
65
+ f"{decode_server}/{endpoint}", json=modified_request
57
66
  ),
58
67
  ]
59
68
  # Wait for both responses to complete. Since this is streaming, they return immediately.
@@ -153,6 +162,43 @@ async def get_model_info():
153
162
  async def handle_generate_request(request_data: dict):
154
163
  prefill_server, decode_server = load_balancer.select_pair()
155
164
 
165
+ # Parse and transform prefill_server for bootstrap data
166
+ parsed_url = urllib.parse.urlparse(prefill_server)
167
+ hostname = parsed_url.hostname
168
+ modified_request = request_data.copy()
169
+
170
+ batch_size = _get_request_batch_size(modified_request)
171
+ if batch_size is not None:
172
+ modified_request.update(
173
+ {
174
+ "bootstrap_host": [hostname] * batch_size,
175
+ "bootstrap_room": [
176
+ _generate_bootstrap_room() for _ in range(batch_size)
177
+ ],
178
+ }
179
+ )
180
+ else:
181
+ modified_request.update(
182
+ {
183
+ "bootstrap_host": hostname,
184
+ "bootstrap_room": _generate_bootstrap_room(),
185
+ }
186
+ )
187
+
188
+ if request_data.get("stream", False):
189
+ return await load_balancer.generate_stream(
190
+ modified_request, prefill_server, decode_server, "generate"
191
+ )
192
+ else:
193
+ return await load_balancer.generate(
194
+ modified_request, prefill_server, decode_server, "generate"
195
+ )
196
+
197
+
198
+ @app.post("/v1/chat/completions")
199
+ async def handle_completion_request(request_data: dict):
200
+ prefill_server, decode_server = load_balancer.select_pair()
201
+
156
202
  # Parse and transform prefill_server for bootstrap data
157
203
  parsed_url = urllib.parse.urlparse(prefill_server)
158
204
  hostname = parsed_url.hostname
@@ -166,14 +212,33 @@ async def handle_generate_request(request_data: dict):
166
212
 
167
213
  if request_data.get("stream", False):
168
214
  return await load_balancer.generate_stream(
169
- modified_request, prefill_server, decode_server
215
+ modified_request,
216
+ prefill_server,
217
+ decode_server,
218
+ endpoint="v1/chat/completions",
170
219
  )
171
220
  else:
172
221
  return await load_balancer.generate(
173
- modified_request, prefill_server, decode_server
222
+ modified_request,
223
+ prefill_server,
224
+ decode_server,
225
+ endpoint="v1/chat/completions",
174
226
  )
175
227
 
176
228
 
229
+ def _generate_bootstrap_room():
230
+ return random.randint(0, 2**63 - 1)
231
+
232
+
233
+ # We may utilize `GenerateReqInput`'s logic later
234
+ def _get_request_batch_size(request):
235
+ if (text := request.get("text")) is not None:
236
+ return None if isinstance(text, str) else len(text)
237
+ if (input_ids := request.get("input_ids")) is not None:
238
+ return None if isinstance(input_ids[0], int) else len(input_ids)
239
+ return None
240
+
241
+
177
242
  @app.get("/v1/models")
178
243
  async def get_models():
179
244
  prefill_server = load_balancer.prefill_servers[0] # Get the first prefill server
@@ -99,8 +99,12 @@ class MooncakeKVManager(BaseKVManager):
99
99
  disaggregation_mode: DisaggregationMode,
100
100
  server_args: ServerArgs,
101
101
  ):
102
- self.engine = MooncakeTransferEngine()
103
102
  self.kv_args = args
103
+ self.engine = MooncakeTransferEngine(
104
+ hostname=get_local_ip_by_remote(),
105
+ gpu_id=self.kv_args.gpu_id,
106
+ ib_device=self.kv_args.ib_device,
107
+ )
104
108
  self.disaggregation_mode = disaggregation_mode
105
109
  # for p/d multi node infer
106
110
  self.bootstrap_port = server_args.disaggregation_bootstrap_port
@@ -227,7 +231,7 @@ class MooncakeKVManager(BaseKVManager):
227
231
  chunked_dst_kv_indice = req.dst_kv_indices[kv_chunk.index_slice]
228
232
  assert len(chunked_dst_kv_indice) == len(
229
233
  kv_chunk.prefill_kv_indices
230
- )
234
+ ), f"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}"
231
235
 
232
236
  ret = self.send_kvcache(
233
237
  req.mooncake_session_id,
@@ -387,6 +391,10 @@ class MooncakeKVSender(BaseKVSender):
387
391
 
388
392
 
389
393
  class MooncakeKVReceiver(BaseKVReceiver):
394
+ _ctx = zmq.Context()
395
+ _socket_cache = {}
396
+ _socket_locks = {}
397
+ _global_lock = threading.Lock()
390
398
 
391
399
  def __init__(
392
400
  self,
@@ -436,11 +444,15 @@ class MooncakeKVReceiver(BaseKVReceiver):
436
444
  logger.error(f"Error fetching prefill info from bootstrap: {e}")
437
445
  return None
438
446
 
439
- @cache
440
- def _connect(self, endpoint: str):
441
- socket = zmq.Context().socket(zmq.PUSH)
442
- socket.connect(endpoint)
443
- return socket
447
+ @classmethod
448
+ def _connect(cls, endpoint: str):
449
+ with cls._global_lock:
450
+ if endpoint not in cls._socket_cache:
451
+ sock = cls._ctx.socket(zmq.PUSH)
452
+ sock.connect(endpoint)
453
+ cls._socket_cache[endpoint] = sock
454
+ cls._socket_locks[endpoint] = threading.Lock()
455
+ return cls._socket_cache[endpoint], cls._socket_locks[endpoint]
444
456
 
445
457
  def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
446
458
  self.prefill_server_url = (
@@ -456,18 +468,20 @@ class MooncakeKVReceiver(BaseKVReceiver):
456
468
  packed_aux_data_ptrs = b"".join(
457
469
  struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
458
470
  )
459
- self._connect("tcp://" + self.prefill_server_url).send_multipart(
460
- [
461
- str(self.bootstrap_room).encode("ascii"),
462
- get_local_ip_by_remote().encode("ascii"),
463
- str(self.kv_mgr.rank_port).encode("ascii"),
464
- self.session_id.encode("ascii"),
465
- packed_kv_data_ptrs,
466
- kv_indices.tobytes(),
467
- packed_aux_data_ptrs,
468
- str(aux_index).encode("ascii"),
469
- ]
470
- )
471
+ sock, lock = self._connect("tcp://" + self.prefill_server_url)
472
+ with lock:
473
+ sock.send_multipart(
474
+ [
475
+ str(self.bootstrap_room).encode("ascii"),
476
+ get_local_ip_by_remote().encode("ascii"),
477
+ str(self.kv_mgr.rank_port).encode("ascii"),
478
+ self.session_id.encode("ascii"),
479
+ packed_kv_data_ptrs,
480
+ kv_indices.tobytes(),
481
+ packed_aux_data_ptrs,
482
+ str(aux_index).encode("ascii"),
483
+ ]
484
+ )
471
485
 
472
486
  def poll(self) -> KVPoll:
473
487
  return self.kv_mgr.check_status(self.bootstrap_room)
@@ -493,52 +507,8 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
493
507
  self.thread.start()
494
508
 
495
509
  def _setup_routes(self):
496
- self.app.router.add_route("*", "/metadata", self._handle_metadata)
497
510
  self.app.router.add_route("*", "/route", self._handle_route)
498
511
 
499
- async def _handle_metadata(self, request: web.Request):
500
- key = request.query.get("key", "")
501
-
502
- if request.method == "GET":
503
- return await self._handle_metadata_get(key)
504
- elif request.method == "PUT":
505
- return await self._handle_metadata_put(key, request)
506
- elif request.method == "DELETE":
507
- return await self._handle_metadata_delete(key)
508
- return web.Response(
509
- text="Method not allowed", status=405, content_type="application/json"
510
- )
511
-
512
- async def _handle_metadata_get(self, key):
513
- async with self.lock:
514
- value = self.store.get(key)
515
- if value is None:
516
- return web.Response(
517
- text="metadata not found", status=404, content_type="application/json"
518
- )
519
- return web.Response(body=value, status=200, content_type="application/json")
520
-
521
- async def _handle_metadata_put(self, key, request):
522
- data = await request.read()
523
- async with self.lock:
524
- self.store[key] = data
525
- return web.Response(
526
- text="metadata updated", status=200, content_type="application/json"
527
- )
528
-
529
- async def _handle_metadata_delete(self, key):
530
- async with self.lock:
531
- if key not in self.store:
532
- return web.Response(
533
- text="metadata not found",
534
- status=404,
535
- content_type="application/json",
536
- )
537
- del self.store[key]
538
- return web.Response(
539
- text="metadata deleted", status=200, content_type="application/json"
540
- )
541
-
542
512
  async def _handle_route(self, request: web.Request):
543
513
  method = request.method
544
514
  if method == "PUT":
@@ -1,45 +1,14 @@
1
1
  import json
2
2
  import logging
3
- import os
4
- import uuid
5
3
  from dataclasses import dataclass
4
+ from typing import Optional
6
5
 
7
6
  logger = logging.getLogger(__name__)
8
7
 
9
8
 
10
- @dataclass
11
- class MooncakeTransferEngineConfig:
12
- local_hostname: str
13
- metadata_server: str
14
- protocol: str
15
- device_name: str
16
-
17
- @staticmethod
18
- def from_file(file_path: str) -> "MooncakeTransferEngineConfig":
19
- """Load the config from a JSON file."""
20
- with open(file_path) as fin:
21
- config = json.load(fin)
22
- return MooncakeTransferEngineConfig(
23
- local_hostname=config.get("local_hostname", None),
24
- metadata_server=config.get("metadata_server"),
25
- protocol=config.get("protocol", "rdma"),
26
- device_name=config.get("device_name", ""),
27
- )
28
-
29
- @staticmethod
30
- def load_from_env() -> "MooncakeTransferEngineConfig":
31
- """Load config from a file specified in the environment variable."""
32
- config_file_path = os.getenv("MOONCAKE_CONFIG_PATH")
33
- if config_file_path is None:
34
- raise ValueError(
35
- "The environment variable 'MOONCAKE_CONFIG_PATH' is not set."
36
- )
37
- return MooncakeTransferEngineConfig.from_file(config_file_path)
38
-
39
-
40
9
  class MooncakeTransferEngine:
41
10
 
42
- def __init__(self):
11
+ def __init__(self, hostname: str, gpu_id: int, ib_device: Optional[str] = None):
43
12
  try:
44
13
  from mooncake.engine import TransferEngine
45
14
  except ImportError as e:
@@ -50,43 +19,43 @@ class MooncakeTransferEngine:
50
19
  ) from e
51
20
 
52
21
  self.engine = TransferEngine()
22
+ self.hostname = hostname
23
+ self.gpu_id = gpu_id
24
+ self.ib_device = ib_device
53
25
 
54
- try:
55
- self.config = MooncakeTransferEngineConfig.load_from_env()
56
- logger.info("Mooncake Configuration loaded successfully.")
57
- except ValueError as e:
58
- logger.error(e)
59
- raise
60
- except Exception as exc:
61
- logger.error("An error occurred while loading the configuration: %s", exc)
62
- raise
63
-
64
- self.config = MooncakeTransferEngineConfig.load_from_env()
65
-
66
- session_suffix = "_" + str(uuid.uuid4())
67
- self.session_id = self.config.local_hostname + session_suffix
68
26
  self.initialize(
69
- self.session_id,
70
- self.config.metadata_server,
71
- self.config.protocol,
72
- self.config.device_name,
27
+ hostname=self.hostname,
28
+ device_name=self.ib_device,
73
29
  )
30
+ self.session_id = f"{self.hostname}:{self.engine.get_rpc_port()}"
74
31
 
75
32
  def register(self, ptr, length):
76
- self.engine.register_memory(ptr, length)
33
+ ret_value = self.engine.register_memory(ptr, length)
34
+ if ret_value != 0:
35
+ logger.error("Mooncake memory registration failed.")
36
+ raise RuntimeError("Mooncake memory registration failed.")
77
37
 
78
38
  def deregister(self, ptr):
79
- self.engine.unregister_memory(ptr)
39
+ ret_value = self.engine.unregister_memory(ptr)
40
+ if ret_value != 0:
41
+ logger.error("Mooncake memory deregistration failed.")
42
+ raise RuntimeError("Mooncake memory deregistration failed.")
80
43
 
81
44
  def initialize(
82
45
  self,
83
- local_hostname: str,
84
- metadata_server: str,
85
- protocol: str,
86
- device_name: str,
46
+ hostname: str,
47
+ device_name: Optional[str],
87
48
  ) -> None:
88
49
  """Initialize the mooncake instance."""
89
- self.engine.initialize(local_hostname, metadata_server, protocol, device_name)
50
+ ret_value = self.engine.initialize(
51
+ hostname,
52
+ "P2PHANDSHAKE",
53
+ "rdma",
54
+ device_name if device_name is not None else "",
55
+ )
56
+ if ret_value != 0:
57
+ logger.error("Mooncake Transfer Engine initialization failed.")
58
+ raise RuntimeError("Mooncake Transfer Engine initialization failed.")
90
59
 
91
60
  def transfer_sync(
92
61
  self, session_id: str, buffer: int, peer_buffer_address: int, length: int
@@ -97,12 +66,12 @@ class MooncakeTransferEngine:
97
66
  session_id, buffer, peer_buffer_address, length
98
67
  )
99
68
  if ret < 0:
100
- logger.error("Transfer Return Error")
101
- raise Exception("Transfer Return Error")
69
+ logger.error("Mooncake Transfer Engine Return Error.")
70
+ raise RuntimeError("Mooncake Transfer Engine Return Error.")
102
71
  return ret
103
72
 
104
73
  def get_localhost(self):
105
- return self.config.local_hostname
74
+ return self.hostname
106
75
 
107
76
  def get_session_id(self):
108
77
  return self.session_id
@@ -0,0 +1 @@
1
+ from .conn import NixlKVBootstrapServer, NixlKVManager, NixlKVReceiver, NixlKVSender