sglang 0.4.5.post1__py3-none-any.whl → 0.4.5.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -4
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +3 -6
- sglang/compile_deep_gemm.py +136 -0
- sglang/lang/backend/anthropic.py +0 -4
- sglang/lang/backend/base_backend.py +1 -1
- sglang/lang/backend/openai.py +6 -2
- sglang/lang/backend/runtime_endpoint.py +5 -1
- sglang/lang/backend/vertexai.py +0 -1
- sglang/lang/compiler.py +1 -7
- sglang/lang/tracer.py +3 -7
- sglang/srt/_custom_ops.py +0 -2
- sglang/srt/configs/model_config.py +4 -1
- sglang/srt/constrained/outlines_jump_forward.py +14 -1
- sglang/srt/constrained/triton_ops/bitmask_ops.py +141 -0
- sglang/srt/constrained/xgrammar_backend.py +27 -4
- sglang/srt/custom_op.py +0 -62
- sglang/srt/disaggregation/decode.py +105 -6
- sglang/srt/disaggregation/mini_lb.py +74 -9
- sglang/srt/disaggregation/mooncake/conn.py +33 -63
- sglang/srt/disaggregation/mooncake/transfer_engine.py +30 -61
- sglang/srt/disaggregation/nixl/__init__.py +1 -0
- sglang/srt/disaggregation/nixl/conn.py +622 -0
- sglang/srt/disaggregation/prefill.py +137 -17
- sglang/srt/disaggregation/utils.py +32 -0
- sglang/srt/entrypoints/engine.py +4 -0
- sglang/srt/entrypoints/http_server.py +3 -7
- sglang/srt/entrypoints/verl_engine.py +7 -5
- sglang/srt/function_call_parser.py +60 -0
- sglang/srt/layers/activation.py +6 -8
- sglang/srt/layers/attention/flashattention_backend.py +883 -209
- sglang/srt/layers/attention/flashinfer_backend.py +5 -2
- sglang/srt/layers/attention/torch_native_backend.py +6 -1
- sglang/srt/layers/attention/triton_backend.py +6 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +5 -5
- sglang/srt/layers/attention/triton_ops/extend_attention.py +18 -7
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +7 -3
- sglang/srt/layers/dp_attention.py +1 -1
- sglang/srt/layers/layernorm.py +20 -5
- sglang/srt/layers/linear.py +17 -3
- sglang/srt/layers/moe/ep_moe/layer.py +17 -29
- sglang/srt/layers/moe/fused_moe_native.py +4 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +14 -19
- sglang/srt/layers/moe/fused_moe_triton/layer.py +7 -0
- sglang/srt/layers/moe/topk.py +27 -30
- sglang/srt/layers/parameter.py +0 -2
- sglang/srt/layers/quantization/__init__.py +1 -0
- sglang/srt/layers/quantization/blockwise_int8.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +9 -2
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +16 -44
- sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a16_fp8.py +153 -0
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +4 -7
- sglang/srt/layers/quantization/deep_gemm.py +378 -0
- sglang/srt/layers/quantization/fp8.py +115 -132
- sglang/srt/layers/quantization/fp8_kernel.py +213 -88
- sglang/srt/layers/quantization/fp8_utils.py +189 -264
- sglang/srt/layers/quantization/gptq.py +13 -7
- sglang/srt/layers/quantization/modelopt_quant.py +2 -2
- sglang/srt/layers/quantization/moe_wna16.py +2 -0
- sglang/srt/layers/quantization/utils.py +5 -11
- sglang/srt/layers/quantization/w8a8_fp8.py +2 -0
- sglang/srt/layers/quantization/w8a8_int8.py +7 -7
- sglang/srt/layers/radix_attention.py +15 -0
- sglang/srt/layers/rotary_embedding.py +9 -8
- sglang/srt/layers/sampler.py +7 -12
- sglang/srt/lora/backend/base_backend.py +18 -2
- sglang/srt/lora/backend/flashinfer_backend.py +1 -1
- sglang/srt/lora/backend/triton_backend.py +1 -1
- sglang/srt/lora/layers.py +1 -1
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/lora/lora_manager.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +7 -1
- sglang/srt/managers/detokenizer_manager.py +0 -1
- sglang/srt/managers/io_struct.py +15 -3
- sglang/srt/managers/mm_utils.py +4 -3
- sglang/srt/managers/multimodal_processor.py +0 -2
- sglang/srt/managers/multimodal_processors/base_processor.py +3 -2
- sglang/srt/managers/schedule_batch.py +15 -4
- sglang/srt/managers/scheduler.py +28 -77
- sglang/srt/managers/tokenizer_manager.py +116 -29
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/mem_cache/hiradix_cache.py +41 -29
- sglang/srt/mem_cache/memory_pool.py +38 -15
- sglang/srt/model_executor/cuda_graph_runner.py +15 -10
- sglang/srt/model_executor/model_runner.py +39 -31
- sglang/srt/models/bert.py +398 -0
- sglang/srt/models/deepseek.py +1 -1
- sglang/srt/models/deepseek_nextn.py +74 -70
- sglang/srt/models/deepseek_v2.py +292 -348
- sglang/srt/models/llama.py +5 -5
- sglang/srt/models/minicpm3.py +31 -203
- sglang/srt/models/minicpmo.py +17 -6
- sglang/srt/models/qwen2.py +4 -1
- sglang/srt/models/qwen2_moe.py +14 -13
- sglang/srt/models/qwen3.py +335 -0
- sglang/srt/models/qwen3_moe.py +423 -0
- sglang/srt/openai_api/adapter.py +71 -4
- sglang/srt/openai_api/protocol.py +6 -1
- sglang/srt/reasoning_parser.py +0 -1
- sglang/srt/sampling/sampling_batch_info.py +2 -3
- sglang/srt/server_args.py +86 -72
- sglang/srt/speculative/build_eagle_tree.py +2 -2
- sglang/srt/speculative/eagle_utils.py +2 -2
- sglang/srt/speculative/eagle_worker.py +6 -14
- sglang/srt/utils.py +62 -6
- sglang/test/runners.py +5 -1
- sglang/test/test_block_fp8.py +167 -0
- sglang/test/test_custom_ops.py +1 -1
- sglang/test/test_utils.py +3 -1
- sglang/version.py +1 -1
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/METADATA +5 -5
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/RECORD +116 -110
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/WHEEL +1 -1
- sglang/lang/__init__.py +0 -0
- sglang/srt/lora/backend/__init__.py +0 -25
- sglang/srt/server.py +0 -18
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.5.post1.dist-info → sglang-0.4.5.post3.dist-info}/top_level.txt +0 -0
@@ -21,6 +21,7 @@ Life cycle of a request in the decode server
|
|
21
21
|
from __future__ import annotations
|
22
22
|
|
23
23
|
import logging
|
24
|
+
from collections import deque
|
24
25
|
from dataclasses import dataclass
|
25
26
|
from typing import TYPE_CHECKING, List, Optional, Tuple
|
26
27
|
|
@@ -35,6 +36,7 @@ from sglang.srt.disaggregation.utils import (
|
|
35
36
|
ReqToMetadataIdxAllocator,
|
36
37
|
TransferBackend,
|
37
38
|
get_kv_class,
|
39
|
+
kv_to_page_indices,
|
38
40
|
poll_and_all_reduce,
|
39
41
|
)
|
40
42
|
from sglang.srt.mem_cache.base_prefix_cache import BasePrefixCache
|
@@ -121,7 +123,7 @@ class DecodePreallocQueue:
|
|
121
123
|
kv_args.aux_item_lens = [
|
122
124
|
metadata_buffer[0].nbytes for metadata_buffer in self.metadata_buffers
|
123
125
|
]
|
124
|
-
kv_args.ib_device =
|
126
|
+
kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
|
125
127
|
kv_args.gpu_id = self.scheduler.gpu_id
|
126
128
|
kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
|
127
129
|
kv_manager = kv_manager_class(
|
@@ -205,7 +207,10 @@ class DecodePreallocQueue:
|
|
205
207
|
self.req_to_metadata_buffer_idx_allocator.alloc()
|
206
208
|
)
|
207
209
|
assert decode_req.metadata_buffer_index is not None
|
208
|
-
|
210
|
+
page_indices = kv_to_page_indices(
|
211
|
+
kv_indices, self.token_to_kv_pool_allocator.page_size
|
212
|
+
)
|
213
|
+
decode_req.kv_receiver.init(page_indices, decode_req.metadata_buffer_index)
|
209
214
|
preallocated_reqs.append(decode_req)
|
210
215
|
indices_to_remove.add(i)
|
211
216
|
|
@@ -245,10 +250,30 @@ class DecodePreallocQueue:
|
|
245
250
|
assert req_pool_indices is not None
|
246
251
|
|
247
252
|
req.req_pool_idx = req_pool_indices[0]
|
248
|
-
|
249
|
-
|
250
|
-
|
251
|
-
|
253
|
+
if self.token_to_kv_pool_allocator.page_size == 1:
|
254
|
+
kv_loc = self.token_to_kv_pool_allocator.alloc(
|
255
|
+
len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
|
256
|
+
)
|
257
|
+
else:
|
258
|
+
num_tokens = len(req.origin_input_ids) + max(len(req.output_ids) - 1, 0)
|
259
|
+
kv_loc = self.token_to_kv_pool_allocator.alloc_extend(
|
260
|
+
prefix_lens=torch.tensor(
|
261
|
+
[0],
|
262
|
+
dtype=torch.int64,
|
263
|
+
device=self.token_to_kv_pool_allocator.device,
|
264
|
+
),
|
265
|
+
seq_lens=torch.tensor(
|
266
|
+
[num_tokens],
|
267
|
+
dtype=torch.int64,
|
268
|
+
device=self.token_to_kv_pool_allocator.device,
|
269
|
+
),
|
270
|
+
last_loc=torch.tensor(
|
271
|
+
[-1],
|
272
|
+
dtype=torch.int64,
|
273
|
+
device=self.token_to_kv_pool_allocator.device,
|
274
|
+
),
|
275
|
+
extend_num_tokens=num_tokens,
|
276
|
+
)
|
252
277
|
assert kv_loc is not None
|
253
278
|
|
254
279
|
self.req_to_token_pool.write((req.req_pool_idx, slice(0, len(kv_loc))), kv_loc)
|
@@ -419,6 +444,80 @@ class ScheduleBatchDisaggregationDecodeMixin:
|
|
419
444
|
|
420
445
|
class SchedulerDisaggregationDecodeMixin:
|
421
446
|
|
447
|
+
@torch.no_grad()
|
448
|
+
def event_loop_normal_disagg_decode(self):
|
449
|
+
"""A normal scheduler loop for decode worker in disaggregation mode."""
|
450
|
+
|
451
|
+
while True:
|
452
|
+
recv_reqs = self.recv_requests()
|
453
|
+
self.process_input_requests(recv_reqs)
|
454
|
+
# polling and allocating kv cache
|
455
|
+
self.process_decode_queue()
|
456
|
+
batch = self.get_next_disagg_decode_batch_to_run()
|
457
|
+
self.cur_batch = batch
|
458
|
+
|
459
|
+
if batch:
|
460
|
+
# Generate fake extend output.
|
461
|
+
if batch.forward_mode.is_extend():
|
462
|
+
# Note: Logprobs should be handled on the prefill engine.
|
463
|
+
self.stream_output(batch.reqs, False)
|
464
|
+
else:
|
465
|
+
result = self.run_batch(batch)
|
466
|
+
self.process_batch_result(batch, result)
|
467
|
+
|
468
|
+
if batch is None and (
|
469
|
+
len(self.disagg_decode_transfer_queue.queue)
|
470
|
+
+ len(self.disagg_decode_prealloc_queue.queue)
|
471
|
+
== 0
|
472
|
+
):
|
473
|
+
# When the server is idle, do self-check and re-init some states
|
474
|
+
self.check_memory()
|
475
|
+
self.new_token_ratio = self.init_new_token_ratio
|
476
|
+
|
477
|
+
self.last_batch = batch
|
478
|
+
|
479
|
+
@torch.no_grad()
|
480
|
+
def event_loop_overlap_disagg_decode(self):
|
481
|
+
result_queue = deque()
|
482
|
+
self.last_batch: Optional[ScheduleBatch] = None
|
483
|
+
self.last_batch_is_extend = False # last batch is modifed in-place, so we need another variable to track if it's extend
|
484
|
+
|
485
|
+
while True:
|
486
|
+
recv_reqs = self.recv_requests()
|
487
|
+
self.process_input_requests(recv_reqs)
|
488
|
+
# polling and allocating kv cache
|
489
|
+
self.process_decode_queue()
|
490
|
+
batch = self.get_next_disagg_decode_batch_to_run()
|
491
|
+
self.cur_batch = batch
|
492
|
+
last_batch_is_extend = False
|
493
|
+
|
494
|
+
if batch:
|
495
|
+
# Generate fake extend output.
|
496
|
+
if batch.forward_mode.is_extend():
|
497
|
+
# Note: Logprobs should be handled on the prefill engine.
|
498
|
+
self.stream_output(batch.reqs, False)
|
499
|
+
last_batch_is_extend = True
|
500
|
+
else:
|
501
|
+
result = self.run_batch(batch)
|
502
|
+
result_queue.append((batch.copy(), result))
|
503
|
+
|
504
|
+
# Process the results of the previous batch but skip if the last batch is extend
|
505
|
+
if self.last_batch and not self.last_batch_is_extend:
|
506
|
+
tmp_batch, tmp_result = result_queue.popleft()
|
507
|
+
self.process_batch_result(tmp_batch, tmp_result)
|
508
|
+
|
509
|
+
if batch is None and (
|
510
|
+
len(self.disagg_decode_transfer_queue.queue)
|
511
|
+
+ len(self.disagg_decode_prealloc_queue.queue)
|
512
|
+
== 0
|
513
|
+
):
|
514
|
+
# When the server is idle, do self-check and re-init some states
|
515
|
+
self.check_memory()
|
516
|
+
self.new_token_ratio = self.init_new_token_ratio
|
517
|
+
|
518
|
+
self.last_batch = batch
|
519
|
+
self.last_batch_is_extend = last_batch_is_extend
|
520
|
+
|
422
521
|
def get_next_disagg_decode_batch_to_run(
|
423
522
|
self: Scheduler,
|
424
523
|
) -> Optional[Tuple[ScheduleBatch, bool]]:
|
@@ -23,13 +23,18 @@ class MiniLoadBalancer:
|
|
23
23
|
return random.choice(self.prefill_servers), random.choice(self.decode_servers)
|
24
24
|
|
25
25
|
async def generate(
|
26
|
-
self, modified_request, prefill_server, decode_server
|
26
|
+
self, modified_request, prefill_server, decode_server, endpoint
|
27
27
|
) -> ORJSONResponse:
|
28
|
+
assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}"
|
28
29
|
|
29
|
-
async with aiohttp.ClientSession(
|
30
|
+
async with aiohttp.ClientSession(
|
31
|
+
timeout=aiohttp.ClientTimeout(
|
32
|
+
total=3600
|
33
|
+
) # Add timeout for request reliability
|
34
|
+
) as session:
|
30
35
|
tasks = [
|
31
|
-
session.post(f"{prefill_server}/
|
32
|
-
session.post(f"{decode_server}/
|
36
|
+
session.post(f"{prefill_server}/{endpoint}", json=modified_request),
|
37
|
+
session.post(f"{decode_server}/{endpoint}", json=modified_request),
|
33
38
|
]
|
34
39
|
# Wait for both responses to complete. Prefill should end first.
|
35
40
|
prefill_response, decode_response = await asyncio.gather(*tasks)
|
@@ -39,7 +44,11 @@ class MiniLoadBalancer:
|
|
39
44
|
status_code=decode_response.status,
|
40
45
|
)
|
41
46
|
|
42
|
-
async def generate_stream(
|
47
|
+
async def generate_stream(
|
48
|
+
self, modified_request, prefill_server, decode_server, endpoint="generate"
|
49
|
+
):
|
50
|
+
assert endpoint[0] != "/", f"Endpoint should not start with '/': {endpoint}"
|
51
|
+
|
43
52
|
async def stream_results():
|
44
53
|
async with aiohttp.ClientSession(
|
45
54
|
timeout=aiohttp.ClientTimeout(
|
@@ -50,10 +59,10 @@ class MiniLoadBalancer:
|
|
50
59
|
# Create the tasks for both prefill and decode requests
|
51
60
|
tasks = [
|
52
61
|
session.post(
|
53
|
-
f"{prefill_server}/
|
62
|
+
f"{prefill_server}/{endpoint}", json=modified_request
|
54
63
|
),
|
55
64
|
session.post(
|
56
|
-
f"{decode_server}/
|
65
|
+
f"{decode_server}/{endpoint}", json=modified_request
|
57
66
|
),
|
58
67
|
]
|
59
68
|
# Wait for both responses to complete. Since this is streaming, they return immediately.
|
@@ -153,6 +162,43 @@ async def get_model_info():
|
|
153
162
|
async def handle_generate_request(request_data: dict):
|
154
163
|
prefill_server, decode_server = load_balancer.select_pair()
|
155
164
|
|
165
|
+
# Parse and transform prefill_server for bootstrap data
|
166
|
+
parsed_url = urllib.parse.urlparse(prefill_server)
|
167
|
+
hostname = parsed_url.hostname
|
168
|
+
modified_request = request_data.copy()
|
169
|
+
|
170
|
+
batch_size = _get_request_batch_size(modified_request)
|
171
|
+
if batch_size is not None:
|
172
|
+
modified_request.update(
|
173
|
+
{
|
174
|
+
"bootstrap_host": [hostname] * batch_size,
|
175
|
+
"bootstrap_room": [
|
176
|
+
_generate_bootstrap_room() for _ in range(batch_size)
|
177
|
+
],
|
178
|
+
}
|
179
|
+
)
|
180
|
+
else:
|
181
|
+
modified_request.update(
|
182
|
+
{
|
183
|
+
"bootstrap_host": hostname,
|
184
|
+
"bootstrap_room": _generate_bootstrap_room(),
|
185
|
+
}
|
186
|
+
)
|
187
|
+
|
188
|
+
if request_data.get("stream", False):
|
189
|
+
return await load_balancer.generate_stream(
|
190
|
+
modified_request, prefill_server, decode_server, "generate"
|
191
|
+
)
|
192
|
+
else:
|
193
|
+
return await load_balancer.generate(
|
194
|
+
modified_request, prefill_server, decode_server, "generate"
|
195
|
+
)
|
196
|
+
|
197
|
+
|
198
|
+
@app.post("/v1/chat/completions")
|
199
|
+
async def handle_completion_request(request_data: dict):
|
200
|
+
prefill_server, decode_server = load_balancer.select_pair()
|
201
|
+
|
156
202
|
# Parse and transform prefill_server for bootstrap data
|
157
203
|
parsed_url = urllib.parse.urlparse(prefill_server)
|
158
204
|
hostname = parsed_url.hostname
|
@@ -166,14 +212,33 @@ async def handle_generate_request(request_data: dict):
|
|
166
212
|
|
167
213
|
if request_data.get("stream", False):
|
168
214
|
return await load_balancer.generate_stream(
|
169
|
-
modified_request,
|
215
|
+
modified_request,
|
216
|
+
prefill_server,
|
217
|
+
decode_server,
|
218
|
+
endpoint="v1/chat/completions",
|
170
219
|
)
|
171
220
|
else:
|
172
221
|
return await load_balancer.generate(
|
173
|
-
modified_request,
|
222
|
+
modified_request,
|
223
|
+
prefill_server,
|
224
|
+
decode_server,
|
225
|
+
endpoint="v1/chat/completions",
|
174
226
|
)
|
175
227
|
|
176
228
|
|
229
|
+
def _generate_bootstrap_room():
|
230
|
+
return random.randint(0, 2**63 - 1)
|
231
|
+
|
232
|
+
|
233
|
+
# We may utilize `GenerateReqInput`'s logic later
|
234
|
+
def _get_request_batch_size(request):
|
235
|
+
if (text := request.get("text")) is not None:
|
236
|
+
return None if isinstance(text, str) else len(text)
|
237
|
+
if (input_ids := request.get("input_ids")) is not None:
|
238
|
+
return None if isinstance(input_ids[0], int) else len(input_ids)
|
239
|
+
return None
|
240
|
+
|
241
|
+
|
177
242
|
@app.get("/v1/models")
|
178
243
|
async def get_models():
|
179
244
|
prefill_server = load_balancer.prefill_servers[0] # Get the first prefill server
|
@@ -99,8 +99,12 @@ class MooncakeKVManager(BaseKVManager):
|
|
99
99
|
disaggregation_mode: DisaggregationMode,
|
100
100
|
server_args: ServerArgs,
|
101
101
|
):
|
102
|
-
self.engine = MooncakeTransferEngine()
|
103
102
|
self.kv_args = args
|
103
|
+
self.engine = MooncakeTransferEngine(
|
104
|
+
hostname=get_local_ip_by_remote(),
|
105
|
+
gpu_id=self.kv_args.gpu_id,
|
106
|
+
ib_device=self.kv_args.ib_device,
|
107
|
+
)
|
104
108
|
self.disaggregation_mode = disaggregation_mode
|
105
109
|
# for p/d multi node infer
|
106
110
|
self.bootstrap_port = server_args.disaggregation_bootstrap_port
|
@@ -227,7 +231,7 @@ class MooncakeKVManager(BaseKVManager):
|
|
227
231
|
chunked_dst_kv_indice = req.dst_kv_indices[kv_chunk.index_slice]
|
228
232
|
assert len(chunked_dst_kv_indice) == len(
|
229
233
|
kv_chunk.prefill_kv_indices
|
230
|
-
)
|
234
|
+
), f"len(chunked_dst_kv_indice) = {len(chunked_dst_kv_indice)}, len(kv_chunk.prefill_kv_indices) = {len(kv_chunk.prefill_kv_indices)}"
|
231
235
|
|
232
236
|
ret = self.send_kvcache(
|
233
237
|
req.mooncake_session_id,
|
@@ -387,6 +391,10 @@ class MooncakeKVSender(BaseKVSender):
|
|
387
391
|
|
388
392
|
|
389
393
|
class MooncakeKVReceiver(BaseKVReceiver):
|
394
|
+
_ctx = zmq.Context()
|
395
|
+
_socket_cache = {}
|
396
|
+
_socket_locks = {}
|
397
|
+
_global_lock = threading.Lock()
|
390
398
|
|
391
399
|
def __init__(
|
392
400
|
self,
|
@@ -436,11 +444,15 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|
436
444
|
logger.error(f"Error fetching prefill info from bootstrap: {e}")
|
437
445
|
return None
|
438
446
|
|
439
|
-
@
|
440
|
-
def _connect(
|
441
|
-
|
442
|
-
|
443
|
-
|
447
|
+
@classmethod
|
448
|
+
def _connect(cls, endpoint: str):
|
449
|
+
with cls._global_lock:
|
450
|
+
if endpoint not in cls._socket_cache:
|
451
|
+
sock = cls._ctx.socket(zmq.PUSH)
|
452
|
+
sock.connect(endpoint)
|
453
|
+
cls._socket_cache[endpoint] = sock
|
454
|
+
cls._socket_locks[endpoint] = threading.Lock()
|
455
|
+
return cls._socket_cache[endpoint], cls._socket_locks[endpoint]
|
444
456
|
|
445
457
|
def init(self, kv_indices: npt.NDArray[np.int64], aux_index: Optional[int] = None):
|
446
458
|
self.prefill_server_url = (
|
@@ -456,18 +468,20 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|
456
468
|
packed_aux_data_ptrs = b"".join(
|
457
469
|
struct.pack("Q", ptr) for ptr in self.kv_mgr.kv_args.aux_data_ptrs
|
458
470
|
)
|
459
|
-
self._connect("tcp://" + self.prefill_server_url)
|
460
|
-
|
461
|
-
|
462
|
-
|
463
|
-
|
464
|
-
|
465
|
-
|
466
|
-
|
467
|
-
|
468
|
-
|
469
|
-
|
470
|
-
|
471
|
+
sock, lock = self._connect("tcp://" + self.prefill_server_url)
|
472
|
+
with lock:
|
473
|
+
sock.send_multipart(
|
474
|
+
[
|
475
|
+
str(self.bootstrap_room).encode("ascii"),
|
476
|
+
get_local_ip_by_remote().encode("ascii"),
|
477
|
+
str(self.kv_mgr.rank_port).encode("ascii"),
|
478
|
+
self.session_id.encode("ascii"),
|
479
|
+
packed_kv_data_ptrs,
|
480
|
+
kv_indices.tobytes(),
|
481
|
+
packed_aux_data_ptrs,
|
482
|
+
str(aux_index).encode("ascii"),
|
483
|
+
]
|
484
|
+
)
|
471
485
|
|
472
486
|
def poll(self) -> KVPoll:
|
473
487
|
return self.kv_mgr.check_status(self.bootstrap_room)
|
@@ -493,52 +507,8 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
|
|
493
507
|
self.thread.start()
|
494
508
|
|
495
509
|
def _setup_routes(self):
|
496
|
-
self.app.router.add_route("*", "/metadata", self._handle_metadata)
|
497
510
|
self.app.router.add_route("*", "/route", self._handle_route)
|
498
511
|
|
499
|
-
async def _handle_metadata(self, request: web.Request):
|
500
|
-
key = request.query.get("key", "")
|
501
|
-
|
502
|
-
if request.method == "GET":
|
503
|
-
return await self._handle_metadata_get(key)
|
504
|
-
elif request.method == "PUT":
|
505
|
-
return await self._handle_metadata_put(key, request)
|
506
|
-
elif request.method == "DELETE":
|
507
|
-
return await self._handle_metadata_delete(key)
|
508
|
-
return web.Response(
|
509
|
-
text="Method not allowed", status=405, content_type="application/json"
|
510
|
-
)
|
511
|
-
|
512
|
-
async def _handle_metadata_get(self, key):
|
513
|
-
async with self.lock:
|
514
|
-
value = self.store.get(key)
|
515
|
-
if value is None:
|
516
|
-
return web.Response(
|
517
|
-
text="metadata not found", status=404, content_type="application/json"
|
518
|
-
)
|
519
|
-
return web.Response(body=value, status=200, content_type="application/json")
|
520
|
-
|
521
|
-
async def _handle_metadata_put(self, key, request):
|
522
|
-
data = await request.read()
|
523
|
-
async with self.lock:
|
524
|
-
self.store[key] = data
|
525
|
-
return web.Response(
|
526
|
-
text="metadata updated", status=200, content_type="application/json"
|
527
|
-
)
|
528
|
-
|
529
|
-
async def _handle_metadata_delete(self, key):
|
530
|
-
async with self.lock:
|
531
|
-
if key not in self.store:
|
532
|
-
return web.Response(
|
533
|
-
text="metadata not found",
|
534
|
-
status=404,
|
535
|
-
content_type="application/json",
|
536
|
-
)
|
537
|
-
del self.store[key]
|
538
|
-
return web.Response(
|
539
|
-
text="metadata deleted", status=200, content_type="application/json"
|
540
|
-
)
|
541
|
-
|
542
512
|
async def _handle_route(self, request: web.Request):
|
543
513
|
method = request.method
|
544
514
|
if method == "PUT":
|
@@ -1,45 +1,14 @@
|
|
1
1
|
import json
|
2
2
|
import logging
|
3
|
-
import os
|
4
|
-
import uuid
|
5
3
|
from dataclasses import dataclass
|
4
|
+
from typing import Optional
|
6
5
|
|
7
6
|
logger = logging.getLogger(__name__)
|
8
7
|
|
9
8
|
|
10
|
-
@dataclass
|
11
|
-
class MooncakeTransferEngineConfig:
|
12
|
-
local_hostname: str
|
13
|
-
metadata_server: str
|
14
|
-
protocol: str
|
15
|
-
device_name: str
|
16
|
-
|
17
|
-
@staticmethod
|
18
|
-
def from_file(file_path: str) -> "MooncakeTransferEngineConfig":
|
19
|
-
"""Load the config from a JSON file."""
|
20
|
-
with open(file_path) as fin:
|
21
|
-
config = json.load(fin)
|
22
|
-
return MooncakeTransferEngineConfig(
|
23
|
-
local_hostname=config.get("local_hostname", None),
|
24
|
-
metadata_server=config.get("metadata_server"),
|
25
|
-
protocol=config.get("protocol", "rdma"),
|
26
|
-
device_name=config.get("device_name", ""),
|
27
|
-
)
|
28
|
-
|
29
|
-
@staticmethod
|
30
|
-
def load_from_env() -> "MooncakeTransferEngineConfig":
|
31
|
-
"""Load config from a file specified in the environment variable."""
|
32
|
-
config_file_path = os.getenv("MOONCAKE_CONFIG_PATH")
|
33
|
-
if config_file_path is None:
|
34
|
-
raise ValueError(
|
35
|
-
"The environment variable 'MOONCAKE_CONFIG_PATH' is not set."
|
36
|
-
)
|
37
|
-
return MooncakeTransferEngineConfig.from_file(config_file_path)
|
38
|
-
|
39
|
-
|
40
9
|
class MooncakeTransferEngine:
|
41
10
|
|
42
|
-
def __init__(self):
|
11
|
+
def __init__(self, hostname: str, gpu_id: int, ib_device: Optional[str] = None):
|
43
12
|
try:
|
44
13
|
from mooncake.engine import TransferEngine
|
45
14
|
except ImportError as e:
|
@@ -50,43 +19,43 @@ class MooncakeTransferEngine:
|
|
50
19
|
) from e
|
51
20
|
|
52
21
|
self.engine = TransferEngine()
|
22
|
+
self.hostname = hostname
|
23
|
+
self.gpu_id = gpu_id
|
24
|
+
self.ib_device = ib_device
|
53
25
|
|
54
|
-
try:
|
55
|
-
self.config = MooncakeTransferEngineConfig.load_from_env()
|
56
|
-
logger.info("Mooncake Configuration loaded successfully.")
|
57
|
-
except ValueError as e:
|
58
|
-
logger.error(e)
|
59
|
-
raise
|
60
|
-
except Exception as exc:
|
61
|
-
logger.error("An error occurred while loading the configuration: %s", exc)
|
62
|
-
raise
|
63
|
-
|
64
|
-
self.config = MooncakeTransferEngineConfig.load_from_env()
|
65
|
-
|
66
|
-
session_suffix = "_" + str(uuid.uuid4())
|
67
|
-
self.session_id = self.config.local_hostname + session_suffix
|
68
26
|
self.initialize(
|
69
|
-
self.
|
70
|
-
self.
|
71
|
-
self.config.protocol,
|
72
|
-
self.config.device_name,
|
27
|
+
hostname=self.hostname,
|
28
|
+
device_name=self.ib_device,
|
73
29
|
)
|
30
|
+
self.session_id = f"{self.hostname}:{self.engine.get_rpc_port()}"
|
74
31
|
|
75
32
|
def register(self, ptr, length):
|
76
|
-
self.engine.register_memory(ptr, length)
|
33
|
+
ret_value = self.engine.register_memory(ptr, length)
|
34
|
+
if ret_value != 0:
|
35
|
+
logger.error("Mooncake memory registration failed.")
|
36
|
+
raise RuntimeError("Mooncake memory registration failed.")
|
77
37
|
|
78
38
|
def deregister(self, ptr):
|
79
|
-
self.engine.unregister_memory(ptr)
|
39
|
+
ret_value = self.engine.unregister_memory(ptr)
|
40
|
+
if ret_value != 0:
|
41
|
+
logger.error("Mooncake memory deregistration failed.")
|
42
|
+
raise RuntimeError("Mooncake memory deregistration failed.")
|
80
43
|
|
81
44
|
def initialize(
|
82
45
|
self,
|
83
|
-
|
84
|
-
|
85
|
-
protocol: str,
|
86
|
-
device_name: str,
|
46
|
+
hostname: str,
|
47
|
+
device_name: Optional[str],
|
87
48
|
) -> None:
|
88
49
|
"""Initialize the mooncake instance."""
|
89
|
-
self.engine.initialize(
|
50
|
+
ret_value = self.engine.initialize(
|
51
|
+
hostname,
|
52
|
+
"P2PHANDSHAKE",
|
53
|
+
"rdma",
|
54
|
+
device_name if device_name is not None else "",
|
55
|
+
)
|
56
|
+
if ret_value != 0:
|
57
|
+
logger.error("Mooncake Transfer Engine initialization failed.")
|
58
|
+
raise RuntimeError("Mooncake Transfer Engine initialization failed.")
|
90
59
|
|
91
60
|
def transfer_sync(
|
92
61
|
self, session_id: str, buffer: int, peer_buffer_address: int, length: int
|
@@ -97,12 +66,12 @@ class MooncakeTransferEngine:
|
|
97
66
|
session_id, buffer, peer_buffer_address, length
|
98
67
|
)
|
99
68
|
if ret < 0:
|
100
|
-
logger.error("Transfer Return Error")
|
101
|
-
raise
|
69
|
+
logger.error("Mooncake Transfer Engine Return Error.")
|
70
|
+
raise RuntimeError("Mooncake Transfer Engine Return Error.")
|
102
71
|
return ret
|
103
72
|
|
104
73
|
def get_localhost(self):
|
105
|
-
return self.
|
74
|
+
return self.hostname
|
106
75
|
|
107
76
|
def get_session_id(self):
|
108
77
|
return self.session_id
|
@@ -0,0 +1 @@
|
|
1
|
+
from .conn import NixlKVBootstrapServer, NixlKVManager, NixlKVReceiver, NixlKVSender
|