sglang 0.4.6.post4__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +6 -6
- sglang/bench_one_batch.py +5 -4
- sglang/bench_one_batch_server.py +23 -15
- sglang/bench_serving.py +133 -57
- sglang/compile_deep_gemm.py +4 -4
- sglang/srt/configs/model_config.py +39 -28
- sglang/srt/conversation.py +1 -1
- sglang/srt/disaggregation/decode.py +122 -133
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
- sglang/srt/disaggregation/fake/conn.py +3 -13
- sglang/srt/disaggregation/kv_events.py +357 -0
- sglang/srt/disaggregation/mini_lb.py +57 -24
- sglang/srt/disaggregation/mooncake/conn.py +11 -2
- sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
- sglang/srt/disaggregation/nixl/conn.py +9 -19
- sglang/srt/disaggregation/prefill.py +126 -44
- sglang/srt/disaggregation/utils.py +116 -5
- sglang/srt/distributed/utils.py +3 -3
- sglang/srt/entrypoints/EngineBase.py +5 -0
- sglang/srt/entrypoints/engine.py +28 -8
- sglang/srt/entrypoints/http_server.py +6 -4
- sglang/srt/entrypoints/http_server_engine.py +5 -2
- sglang/srt/function_call/base_format_detector.py +250 -0
- sglang/srt/function_call/core_types.py +34 -0
- sglang/srt/function_call/deepseekv3_detector.py +157 -0
- sglang/srt/function_call/ebnf_composer.py +234 -0
- sglang/srt/function_call/function_call_parser.py +175 -0
- sglang/srt/function_call/llama32_detector.py +74 -0
- sglang/srt/function_call/mistral_detector.py +84 -0
- sglang/srt/function_call/pythonic_detector.py +163 -0
- sglang/srt/function_call/qwen25_detector.py +67 -0
- sglang/srt/function_call/utils.py +35 -0
- sglang/srt/hf_transformers_utils.py +46 -7
- sglang/srt/layers/attention/aiter_backend.py +513 -0
- sglang/srt/layers/attention/flashattention_backend.py +63 -17
- sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
- sglang/srt/layers/attention/flashmla_backend.py +340 -78
- sglang/srt/layers/attention/triton_backend.py +3 -0
- sglang/srt/layers/attention/utils.py +2 -2
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/communicator.py +451 -0
- sglang/srt/layers/dp_attention.py +0 -10
- sglang/srt/layers/moe/cutlass_moe.py +207 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +33 -11
- sglang/srt/layers/moe/ep_moe/layer.py +104 -50
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
- sglang/srt/layers/moe/topk.py +66 -9
- sglang/srt/layers/multimodal.py +70 -0
- sglang/srt/layers/quantization/__init__.py +7 -2
- sglang/srt/layers/quantization/deep_gemm.py +5 -3
- sglang/srt/layers/quantization/fp8.py +90 -0
- sglang/srt/layers/quantization/fp8_utils.py +6 -0
- sglang/srt/layers/quantization/gptq.py +298 -6
- sglang/srt/layers/quantization/int8_kernel.py +18 -5
- sglang/srt/layers/quantization/qoq.py +244 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/deepseek_eplb.py +278 -0
- sglang/srt/managers/eplb_manager.py +55 -0
- sglang/srt/managers/expert_distribution.py +704 -56
- sglang/srt/managers/expert_location.py +394 -0
- sglang/srt/managers/expert_location_dispatch.py +91 -0
- sglang/srt/managers/io_struct.py +16 -3
- sglang/srt/managers/mm_utils.py +293 -139
- sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
- sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
- sglang/srt/managers/multimodal_processors/internvl.py +14 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
- sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
- sglang/srt/managers/multimodal_processors/llava.py +3 -3
- sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
- sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +9 -9
- sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
- sglang/srt/managers/schedule_batch.py +49 -21
- sglang/srt/managers/schedule_policy.py +4 -5
- sglang/srt/managers/scheduler.py +92 -50
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +99 -24
- sglang/srt/mem_cache/base_prefix_cache.py +3 -0
- sglang/srt/mem_cache/chunk_cache.py +3 -1
- sglang/srt/mem_cache/hiradix_cache.py +4 -4
- sglang/srt/mem_cache/memory_pool.py +74 -52
- sglang/srt/mem_cache/multimodal_cache.py +45 -0
- sglang/srt/mem_cache/radix_cache.py +58 -5
- sglang/srt/metrics/collector.py +2 -2
- sglang/srt/mm_utils.py +10 -0
- sglang/srt/model_executor/cuda_graph_runner.py +20 -9
- sglang/srt/model_executor/expert_location_updater.py +422 -0
- sglang/srt/model_executor/forward_batch_info.py +4 -0
- sglang/srt/model_executor/model_runner.py +144 -54
- sglang/srt/model_loader/loader.py +10 -6
- sglang/srt/models/clip.py +5 -1
- sglang/srt/models/deepseek_v2.py +297 -343
- sglang/srt/models/exaone.py +8 -3
- sglang/srt/models/gemma3_mm.py +70 -33
- sglang/srt/models/llama4.py +10 -2
- sglang/srt/models/llava.py +26 -18
- sglang/srt/models/mimo_mtp.py +220 -0
- sglang/srt/models/minicpmo.py +5 -12
- sglang/srt/models/mistral.py +71 -1
- sglang/srt/models/mllama.py +3 -3
- sglang/srt/models/qwen2.py +95 -26
- sglang/srt/models/qwen2_5_vl.py +8 -0
- sglang/srt/models/qwen2_moe.py +330 -60
- sglang/srt/models/qwen2_vl.py +6 -0
- sglang/srt/models/qwen3.py +52 -10
- sglang/srt/models/qwen3_moe.py +411 -48
- sglang/srt/models/siglip.py +294 -0
- sglang/srt/openai_api/adapter.py +28 -16
- sglang/srt/openai_api/protocol.py +6 -0
- sglang/srt/operations.py +154 -0
- sglang/srt/operations_strategy.py +31 -0
- sglang/srt/server_args.py +134 -24
- sglang/srt/speculative/eagle_utils.py +131 -0
- sglang/srt/speculative/eagle_worker.py +47 -2
- sglang/srt/utils.py +68 -12
- sglang/test/test_cutlass_moe.py +278 -0
- sglang/test/test_utils.py +2 -36
- sglang/utils.py +2 -2
- sglang/version.py +1 -1
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +20 -11
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +128 -102
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
- sglang/srt/function_call_parser.py +0 -858
- sglang/srt/platforms/interface.py +0 -371
- /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
sglang/srt/disaggregation/decode_schedule_batch_mixin.py (new file)
@@ -0,0 +1,142 @@
+from __future__ import annotations
+
+import logging
+from typing import TYPE_CHECKING
+
+import torch
+
+from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
+from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
+
+logger = logging.getLogger(__name__)
+
+if TYPE_CHECKING:
+    from sglang.srt.configs.model_config import ModelConfig
+    from sglang.srt.managers.schedule_batch import ScheduleBatch
+    from sglang.srt.server_args import ServerArgs
+
+
+class ScheduleBatchDisaggregationDecodeMixin:
+
+    def prepare_for_prebuilt_extend(self: ScheduleBatch):
+        """
+        Prepare a prebuilt extend by populating metadata.
+        Adapted from .prepare_for_extend().
+        """
+
+        self.forward_mode = ForwardMode.EXTEND
+        reqs = self.reqs
+        input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
+        extend_num_tokens = sum(len(ids) for ids in input_ids)
+        seq_lens = []
+        pre_lens = []
+        req_pool_indices = []
+
+        # Pre-calculate total size
+        total_size = sum(req.extend_input_len for req in reqs)
+        out_cache_loc = torch.empty(total_size, dtype=torch.int64, device=self.device)
+
+        # Fill the tensor in one pass
+        offset = 0
+        for i, req in enumerate(reqs):
+            req_pool_indices.append(req.req_pool_idx)
+
+            chunk = self.req_to_token_pool.req_to_token[req.req_pool_idx][
+                : req.extend_input_len
+            ]
+            assert (
+                offset + req.extend_input_len <= total_size
+            ), f"Exceeds total size: offset={offset}, req.extend_input_len={req.extend_input_len}, total_size={total_size}"
+            out_cache_loc[offset : offset + req.extend_input_len] = chunk
+            offset += req.extend_input_len
+
+            pre_len = len(req.prefix_indices)
+            seq_len = len(req.origin_input_ids) + max(0, len(req.output_ids) - 1)
+            seq_lens.append(seq_len)
+            if len(req.output_ids) == 0:
+                assert (
+                    seq_len - pre_len == req.extend_input_len
+                ), f"seq_len={seq_len}, pre_len={pre_len}, req.extend_input_len={req.extend_input_len}"
+
+            req.cached_tokens += pre_len - req.already_computed
+            req.already_computed = seq_len
+            req.is_retracted = False
+            pre_lens.append(pre_len)
+            req.extend_logprob_start_len = 0
+
+        extend_input_logprob_token_ids = None
+
+        # Set fields
+        self.input_ids = torch.tensor(
+            sum(input_ids, []), dtype=torch.int32, device=self.device
+        )
+        self.req_pool_indices = torch.tensor(
+            req_pool_indices, dtype=torch.int64, device=self.device
+        )
+        self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device)
+        self.out_cache_loc = out_cache_loc
+        self.seq_lens_sum = sum(seq_lens)
+
+        if self.return_logprob:
+            self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
+            self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]
+
+        self.extend_num_tokens = extend_num_tokens
+        self.prefix_lens = [len(r.prefix_indices) for r in reqs]
+        self.extend_lens = [r.extend_input_len for r in reqs]
+        self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
+        self.extend_input_logprob_token_ids = extend_input_logprob_token_ids
+
+        # Build sampling info
+        self.sampling_info = SamplingBatchInfo.from_schedule_batch(
+            self,
+            self.model_config.vocab_size,
+        )
+
+    def process_prebuilt_extend(
+        self: ScheduleBatch, server_args: ServerArgs, model_config: ModelConfig
+    ):
+        """Assign the buffered last input id to the schedule batch."""
+        self.output_ids = []
+        for req in self.reqs:
+            self.output_ids.append(req.output_ids[-1])
+            self.tree_cache.cache_unfinished_req(req)
+            if req.grammar is not None:
+                req.grammar.accept_token(req.output_ids[-1])
+                req.grammar.finished = req.finished()
+        self.output_ids = torch.tensor(self.output_ids, device=self.device)
+
+        # Simulate the eagle run. We add mock data to the hidden states for
+        # ease of implementation for now, meaning the first token will have
+        # an acceptance rate of 0.
+        if not self.spec_algorithm.is_none():
+
+            b = len(self.reqs)
+            topk_p = torch.arange(
+                b * server_args.speculative_eagle_topk,
+                0,
+                -1,
+                device=self.device,
+                dtype=torch.float32,
+            )
+            topk_p = topk_p.reshape(b, server_args.speculative_eagle_topk)
+            topk_p /= b * server_args.speculative_eagle_topk
+            topk_index = torch.arange(
+                b * server_args.speculative_eagle_topk, device=self.device
+            )
+            topk_index = topk_index.reshape(b, server_args.speculative_eagle_topk)
+
+            # local import to avoid a circular import
+            from sglang.srt.speculative.eagle_utils import EagleDraftInput
+
+            spec_info = EagleDraftInput(
+                topk_p=topk_p,
+                topk_index=topk_index,
+                hidden_states=torch.ones(
+                    (b, model_config.hidden_size), device=self.device
+                ),
+                verified_id=self.output_ids,
+            )
+            spec_info.prepare_for_extend(self)
+            spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
+            self.spec_info = spec_info
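The mock topk_p above is just a descending arange normalized by b * k, so each row hands the speculative decoder a well-ordered (if fake) probability list. A standalone sketch of that construction, with illustrative values b=2 and speculative_eagle_topk k=3:

    import torch

    b, k = 2, 3  # illustrative batch size and top-k width
    topk_p = torch.arange(b * k, 0, -1, dtype=torch.float32).reshape(b, k) / (b * k)
    # tensor([[1.0000, 0.8333, 0.6667],
    #         [0.5000, 0.3333, 0.1667]])  -- each row is strictly descending
    topk_index = torch.arange(b * k).reshape(b, k)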
sglang/srt/disaggregation/fake/conn.py
@@ -33,28 +33,18 @@ class FakeKVSender(BaseKVSender):
         self,
         kv_indices: list[int],
         aux_index: Optional[int] = None,
-        dest_ranks: Optional[list[int]] = None,
     ):
         logger.info(
-            f"FakeKVSender init with kv_indices: {kv_indices}, aux_index: {aux_index}
+            f"FakeKVSender init with kv_indices: {kv_indices}, aux_index: {aux_index}"
         )
         pass

     def send(
         self,
         kv_indices: npt.NDArray[np.int64],
-        index_slice: slice,
-        is_last: bool,
     ):
-
-
-        )
-        if is_last:
-            self.has_sent = True
-            logger.info(f"FakeKVSender send success")
-        else:
-            self.has_sent = False
-            logger.info(f"FakeKVSender send fake transferring")
+        self.has_sent = True
+        logger.info(f"FakeKVSender send with kv_indices: {kv_indices}")

     def failure_exception(self):
         raise Exception("Fake KVSender Exception")
sglang/srt/disaggregation/kv_events.py (new file)
@@ -0,0 +1,357 @@
+"""
+Copyright 2025 SGLang Team
+Licensed under the Apache License, Version 2.0 (the "License");
+you may not use this file except in compliance with the License.
+You may obtain a copy of the License at
+
+    http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing, software
+distributed under the License is distributed on an "AS IS" BASIS,
+WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+See the License for the specific language governing permissions and
+limitations under the License.
+"""
+
+"""
+KV caching events
+"""
+
+import atexit
+import logging
+import queue
+import threading
+import time
+from abc import ABC, abstractmethod
+from collections import deque
+from itertools import count
+from queue import Queue
+from typing import Any, Callable, Optional, Union
+
+import msgspec
+import zmq
+from pydantic import BaseModel
+
+logger = logging.getLogger(__name__)
+
+
+class EventBatch(
+    msgspec.Struct,
+    array_like=True,  # type: ignore[call-arg]
+    omit_defaults=True,  # type: ignore[call-arg]
+    gc=False,  # type: ignore[call-arg]
+):
+    ts: float
+    events: list[Any]
+
+
+class KVCacheEvent(
+    msgspec.Struct,
+    array_like=True,  # type: ignore[call-arg]
+    omit_defaults=True,  # type: ignore[call-arg]
+    gc=False,  # type: ignore[call-arg]
+    tag=True,
+):
+    """Base class for all KV cache-related events"""
+
+
+class BlockStored(KVCacheEvent):
+    block_hashes: list[int]
+    parent_block_hash: Optional[int]
+    token_ids: list[int]
+    block_size: int
+    lora_id: Optional[int]
+
+
+class BlockRemoved(KVCacheEvent):
+    block_hashes: list[int]
+
+
+class AllBlocksCleared(KVCacheEvent):
+    pass
+
+
+class KVEventBatch(EventBatch):
+    events: list[Union[BlockStored, BlockRemoved, AllBlocksCleared]]
+
+
+class EventPublisher(ABC):
+    """Lightweight publisher for EventBatch batches."""
+
+    @abstractmethod
+    def publish(self, events: EventBatch) -> None:
+        """Emit events in order.
+
+        Implementations should guarantee at-least-once delivery and
+        monotonic ordering (e.g., via sequence numbers).
+        """
+
+    @abstractmethod
+    def shutdown(self) -> None:
+        """Shutdown the publisher."""
+
+
+class NullEventPublisher(EventPublisher):
+    """No-op implementation (default when disabled)."""
+
+    def publish(self, events) -> None:
+        return
+
+    def shutdown(self) -> None:
+        return
+
+
+class ZmqEventPublisher(EventPublisher):
+    """Reliable PUB/ROUTER publisher with an in-memory replay buffer.
+
+    Spawns a separate thread to handle publishing from a queue.
+
+    Parameters
+    ----------
+    endpoint:
+        PUB address. Use ``tcp://*:5557`` to bind or ``tcp://host:5557`` to
+        connect.
+    replay_endpoint:
+        Optional ROUTER address for replay requests. When given, subscribers can
+        request missed batches by sending the starting sequence number as an
+        8-byte big-endian integer.
+    buffer_steps:
+        Number of past batches to keep for replay.
+    hwm:
+        ZeroMQ high-water-mark for the PUB socket.
+    max_queue_size:
+        Maximum number of events to buffer in memory.
+    topic:
+        Topic to publish events to.
+    """
+
+    SHUTDOWN_TIMEOUT: float = 1.0
+    END_SEQ = (-1).to_bytes(8, "big", signed=True)
+
+    def __init__(
+        self,
+        endpoint: str = "tcp://*:5557",
+        replay_endpoint: Optional[str] = None,
+        buffer_steps: int = 10_000,
+        hwm: int = 100_000,
+        max_queue_size: int = 100_000,
+        topic: str = "",
+    ) -> None:
+        # Storage
+        self._event_queue = Queue[Optional[EventBatch]](maxsize=max_queue_size)
+        self._buffer = deque[tuple[int, bytes]](maxlen=buffer_steps)
+
+        # ZMQ sockets
+        self._ctx = zmq.Context.instance()
+        self._pub: Optional[zmq.Socket] = None
+        self._replay: Optional[zmq.Socket] = None
+        self._endpoint = endpoint
+        self._replay_endpoint = replay_endpoint
+        self._hwm = hwm
+        self._socket_setup()
+
+        # Payload
+        self._seq_gen = count()
+        self._topic_bytes = topic.encode("utf-8")
+
+        # Thread
+        self._running = True
+        logger.info("Starting ZMQ publisher thread")
+
+        self._thread = threading.Thread(
+            target=self._publisher_thread, daemon=True, name="zmq-publisher"
+        )
+        self._thread.start()
+
+        atexit.register(self.shutdown)
+
+    def publish(self, events: EventBatch) -> None:
+        if not self._running:
+            raise RuntimeError("Publisher is closed")
+        self._event_queue.put(events)
+
+    def shutdown(self) -> None:
+        """Stop the publisher thread and clean up resources."""
+        self._running = False
+        self._event_queue.put_nowait(None)
+
+        start = time.time()
+        pending_items = True
+        while pending_items and (time.time() - start < self.SHUTDOWN_TIMEOUT):
+            pending_items = not self._event_queue.empty()
+            if pending_items:
+                time.sleep(0.1)
+
+        if pending_items:
+            logger.warning(
+                "Warning: Queue still has %s items after %s seconds timeout",
+                self._event_queue.qsize(),
+                self.SHUTDOWN_TIMEOUT,
+            )
+
+        if self._thread.is_alive():
+            self._thread.join(timeout=self.SHUTDOWN_TIMEOUT)
+
+        # Clean up ZMQ resources
+        try:
+            if self._pub is not None:
+                self._pub.close(linger=0)
+            if self._replay is not None:
+                self._replay.close(linger=0)
+        finally:
+            pass  # Do not terminate context; other sockets may use it
+
+    def _socket_setup(self) -> None:
+        """Initialize sockets
+        https://pyzmq.readthedocs.io/en/v19.0.0/morethanbindings.html#thread-safety
+        """
+        if self._pub is None:
+            self._pub = self._ctx.socket(zmq.PUB)
+            self._pub.set_hwm(self._hwm)
+            # Heuristic: bind if a wildcard / * is present, else connect
+            # (bind stable, connect volatile convention).
+            if (
+                "*" in self._endpoint
+                or "::" in self._endpoint
+                or self._endpoint.startswith("ipc://")
+                or self._endpoint.startswith("inproc://")
+            ):
+                self._pub.bind(self._endpoint)
+            else:
+                self._pub.connect(self._endpoint)
+
+        # Set up the replay socket: use ROUTER because it
+        # 1) handles multiple REQ clients (identities)
+        # 2) lets us send back one request → many replies (streamed events)
+        # 3) works in our non-blocking poll loop alongside PUB
+        if self._replay_endpoint is not None:
+            self._replay = self._ctx.socket(zmq.ROUTER)
+            self._replay.bind(self._replay_endpoint)
+
+    def _publisher_thread(self) -> None:
+        """Background thread that processes the event queue."""
+        self._pack = msgspec.msgpack.Encoder()
+
+        assert self._pub is not None  # narrows type for mypy
+
+        while self._running or self._event_queue.qsize() > 0:
+            # --- replay (non-critical) ---------------------------------
+            if self._replay is not None and self._replay.poll(0):
+                try:
+                    self._service_replay()
+                except Exception as e:
+                    logger.exception("Error in replay: %s", e)
+
+            # --- main queue (critical) ---------------------------------
+            try:
+                event = self._event_queue.get(timeout=0.1)
+                if event is None:
+                    break  # Sentinel received, exit thread
+            except queue.Empty:
+                continue
+
+            try:
+                seq = next(self._seq_gen)
+
+                payload = self._pack.encode(event)
+                seq_bytes = seq.to_bytes(8, "big")
+                self._pub.send_multipart((self._topic_bytes, seq_bytes, payload))
+
+                self._buffer.append((seq, payload))
+                self._event_queue.task_done()
+
+            except Exception as e:
+                # Publishing failed; back off a bit to avoid a tight error loop
+                logger.exception("Error in publisher thread: %s", e)
+                time.sleep(0.1)
+
+    def _service_replay(self) -> None:
+        """If a replay request is waiting, send buffered batches."""
+        assert self._replay is not None  # narrows type for mypy
+
+        frame = self._replay.recv_multipart()
+        if len(frame) != 3:
+            logger.warning("Invalid replay request: %s", frame)
+            return
+        client_id, _, start_seq_bytes = frame
+        start_seq = int.from_bytes(start_seq_bytes, "big")
+
+        for seq, buf in self._buffer:
+            if seq >= start_seq:
+                # [identity, empty_delim, seq_bytes, payload]
+                # (identity, empty_delim) are stripped off by the router;
+                # the payload received by the client is (seq_bytes, payload)
+                self._replay.send_multipart(
+                    (client_id, b"", seq.to_bytes(8, "big"), buf)
+                )
+        # Send the end-of-sequence marker;
+        # the payload received by the client is (-1, b"")
+        self._replay.send_multipart((client_id, b"", self.END_SEQ, b""))
+
+
+class KVEventsConfig(BaseModel):
+    """Configuration for KV event publishing."""
+
+    publisher: str = "null"
+    """The publisher to use for publishing kv events. Can be "null", "zmq".
+    """
+
+    endpoint: str = "tcp://*:5557"
+    """The zmq endpoint to use for publishing kv events.
+    """
+
+    replay_endpoint: Optional[str] = None
+    """The zmq endpoint to use for replaying kv events.
+    """
+
+    buffer_steps: int = 10_000
+    """The number of steps to cache for the replay endpoint. Only events
+    from the last N steps are saved for the replay endpoint.
+    """
+
+    hwm: int = 100_000
+    """The zmq high-water mark for the event publisher. After queueing N events,
+    events will start dropping if the consumer is not keeping up.
+    """
+
+    max_queue_size: int = 100_000
+    """The maximum number of events to queue while waiting for publishing.
+    """
+
+    topic: str = ""
+    """The topic to use for the event publisher. Consumers can subscribe to
+    this topic to receive events.
+    """
+
+    @classmethod
+    def from_cli(cls, cli_value: str) -> "KVEventsConfig":
+        """Parse the CLI value for the event publisher config."""
+        return KVEventsConfig.model_validate_json(cli_value)
+
+
+class EventPublisherFactory:
+    _registry: dict[str, Callable[..., EventPublisher]] = {
+        "null": NullEventPublisher,
+        "zmq": ZmqEventPublisher,
+    }
+
+    @classmethod
+    def register_publisher(cls, name: str, ctor: Callable[..., EventPublisher]) -> None:
+        if name in cls._registry:
+            raise KeyError(f"publisher '{name}' already registered")
+        cls._registry[name] = ctor
+
+    @classmethod
+    def create(cls, config: Optional[str]) -> EventPublisher:
+        """Create publisher from a config mapping."""
+        if not config:
+            return NullEventPublisher()
+        config = KVEventsConfig.from_cli(config)
+        config_dict = config.model_dump()
+
+        kind = config_dict.pop("publisher", "null")
+        try:
+            constructor = cls._registry[kind]
+        except KeyError as exc:
+            raise ValueError(f"Unknown event publisher '{kind}'") from exc
+        return constructor(**config_dict)
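For consumers, the wire format is fully determined by the code above: each PUB message is a (topic, 8-byte big-endian sequence number, msgpack payload) multipart frame, and the optional ROUTER endpoint streams missed batches followed by the END_SEQ marker. A hedged subscriber sketch under those assumptions; the replay port 5558 is invented for illustration, and a DEALER socket is used on the client side because the ROUTER streams several replies per request:

    import msgspec
    import zmq

    from sglang.srt.disaggregation.kv_events import KVEventBatch

    ctx = zmq.Context.instance()
    sub = ctx.socket(zmq.SUB)
    sub.connect("tcp://localhost:5557")  # KVEventsConfig default endpoint
    sub.setsockopt(zmq.SUBSCRIBE, b"")   # default topic is ""

    decoder = msgspec.msgpack.Decoder(KVEventBatch)
    last_seq = -1

    while True:
        _topic, seq_bytes, payload = sub.recv_multipart()
        seq = int.from_bytes(seq_bytes, "big")
        if last_seq >= 0 and seq > last_seq + 1:
            # Gap detected: ask the replay ROUTER for everything we missed.
            replay = ctx.socket(zmq.DEALER)
            replay.connect("tcp://localhost:5558")  # hypothetical replay endpoint
            replay.send_multipart((b"", (last_seq + 1).to_bytes(8, "big")))
            while True:
                _empty, rseq, rpayload = replay.recv_multipart()
                if int.from_bytes(rseq, "big", signed=True) == -1:
                    break  # END_SEQ marker from the publisher
                print("replayed:", decoder.decode(rpayload))
            replay.close(linger=0)
        print(f"seq={seq}:", decoder.decode(payload))
        last_seq = seq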
sglang/srt/disaggregation/mini_lb.py
@@ -73,11 +73,27 @@ class MiniLoadBalancer:
                 session.post(f"{prefill_server}/{endpoint}", json=modified_request),
                 session.post(f"{decode_server}/{endpoint}", json=modified_request),
             ]
+
            # Wait for both responses to complete. Prefill should end first.
-
+            prefill_response, decode_response = await asyncio.gather(*tasks)
+
+            if "return_logprob" in modified_request:
+
+                prefill_json = await prefill_response.json()
+                ret_json = await decode_response.json()
+
+                # merge `meta_info.input_token_logprobs` from prefill to decode
+                if "meta_info" in ret_json:
+                    if "input_token_logprobs" in ret_json["meta_info"]:
+                        ret_json["meta_info"]["input_token_logprobs"] = (
+                            prefill_json["meta_info"]["input_token_logprobs"]
+                            + ret_json["meta_info"]["input_token_logprobs"]
+                        )
+            else:
+                ret_json = await decode_response.json()

         return ORJSONResponse(
-            content=
+            content=ret_json,
             status_code=decode_response.status,
         )

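The merge above prepends the prefill server's input-token logprobs to the decode server's, so the client receives one complete list. A toy illustration of that concatenation (the entry format is made up for the example):

    prefill_json = {"meta_info": {"input_token_logprobs": [(-0.11, 101), (-0.52, 2025)]}}
    ret_json = {"meta_info": {"input_token_logprobs": [(-0.33, 7)]}}

    ret_json["meta_info"]["input_token_logprobs"] = (
        prefill_json["meta_info"]["input_token_logprobs"]
        + ret_json["meta_info"]["input_token_logprobs"]
    )
    # -> [(-0.11, 101), (-0.52, 2025), (-0.33, 7)]: prefill's tokens first, then decode's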
sglang/srt/disaggregation/mini_lb.py
@@ -92,30 +108,47 @@ class MiniLoadBalancer:
                 total=3600
             )  # Add timeout for request reliability
         ) as session:
-
-
-
-
-
-
-
-
-
-            ]
-
-
+            # Create the tasks for both prefill and decode requests
+            tasks = [
+                session.post(f"{prefill_server}/generate", json=modified_request),
+                session.post(f"{decode_server}/generate", json=modified_request),
+            ]
+            # Wait for both responses to complete. Since this is streaming, they return immediately.
+            prefill_response, decode_response = await asyncio.gather(*tasks)
+
+            if modified_request.get("return_logprob", False):
+                prefill_chunks = []
+                async for chunk in prefill_response.content:
+                    prefill_chunks.append(chunk)
+
+                first_prefill_chunk = (
+                    prefill_chunks[0].decode("utf-8")[5:].strip("\n")
+                )
+                first_prefill_chunk_json = orjson.loads(first_prefill_chunk)
+
+                async for chunk in decode_response.content:
+                    # Note: This is inefficient
+                    # merge prefill input_token_logprobs, output_token_logprobs to decode
+                    decoded_chunk = chunk.decode("utf-8")
+                    if (
+                        decoded_chunk
+                        and decoded_chunk.startswith("data:")
+                        and "[DONE]" not in decoded_chunk
+                    ):
+                        ret_json = orjson.loads(decoded_chunk[5:].strip("\n"))
+                        ret_json["meta_info"]["input_token_logprobs"] = (
+                            first_prefill_chunk_json["meta_info"][
+                                "input_token_logprobs"
+                            ]
+                            + ret_json["meta_info"]["input_token_logprobs"]
+                        )
+
+                        yield b"data: " + orjson.dumps(ret_json) + b"\n\n"
+                    else:
+                        yield chunk
+            else:
                async for chunk in decode_response.content:
                    yield chunk
-            except Exception as e:
-                error_msg = {
-                    "error": {"message": f"Stream processing error: {str(e)}"}
-                }
-                yield b"data: " + orjson.dumps(
-                    error_msg, option=orjson.OPT_NON_STR_KEYS
-                ) + b"\n\n"
-            finally:
-                if prefill_response is not None:
-                    await prefill_response.release()

         return StreamingResponse(
             stream_results(),
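Each streamed frame is server-sent-events text of the form "data: {json}\n\n", which is why the code slices off the first five characters before orjson.loads and restores the prefix when re-emitting. A toy round-trip under that framing (the payload is invented):

    import orjson

    chunk = b'data: {"meta_info": {"input_token_logprobs": [[-0.33, 7]]}}\n\n'
    decoded = chunk.decode("utf-8")
    assert decoded.startswith("data:") and "[DONE]" not in decoded
    ret_json = orjson.loads(decoded[5:].strip("\n"))  # drop the "data:" prefix
    out = b"data: " + orjson.dumps(ret_json) + b"\n\n"  # SSE framing restored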
sglang/srt/disaggregation/mooncake/conn.py
@@ -51,6 +51,7 @@ def group_concurrent_contiguous(
     return src_groups, dst_groups


+# prefill
 @dataclasses.dataclass
 class TransferKVChunk:
     room: int
@@ -60,6 +61,7 @@ class TransferKVChunk:
     prefill_aux_index: Optional[int]


+# decode
 @dataclasses.dataclass
 class TransferInfo:
     room: int
@@ -93,6 +95,7 @@ class TransferInfo:
     )


+# decode
 @dataclasses.dataclass
 class KVArgsRegisterInfo:
     room: str
@@ -464,6 +467,8 @@ class MooncakeKVSender(BaseKVSender):
         self.aux_index = None
         self.bootstrap_server_url = bootstrap_addr
         self.session_id = self.kv_mgr.get_session_id()
+        # inner state
+        self.curr_idx = 0

     def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
         self.num_kv_indices = num_kv_indices
@@ -472,9 +477,11 @@ class MooncakeKVSender(BaseKVSender):
     def send(
         self,
         kv_indices: npt.NDArray[np.int64],
-        index_slice: slice,
-        is_last: bool,
     ):
+        index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices))
+        self.curr_idx += len(kv_indices)
+        is_last = self.curr_idx == self.num_kv_indices
+
         if not is_last:
             self.kv_mgr.add_transfer_request(
                 self.bootstrap_room, kv_indices, index_slice, False
@@ -492,6 +499,7 @@ class MooncakeKVSender(BaseKVSender):
         return self.kv_mgr.check_status(self.bootstrap_room)

     def failure_exception(self):
+        # TODO: raise a real exception
         raise Exception("Fake KVSender Exception")
@@ -719,6 +727,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
         return self.kv_mgr.check_status(self.bootstrap_room)

     def failure_exception(self):
+        # TODO: raise a real exception
         raise Exception("Fake KVReceiver Exception")
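With curr_idx tracked inside the sender, callers no longer pass index_slice or is_last; they simply stream chunks after init(). A hedged sketch of the resulting calling pattern (the sender's construction is elided, and the chunking is invented for illustration):

    import numpy as np

    # sender: a MooncakeKVSender already set up by the disaggregation manager.
    kv_indices = np.arange(10, dtype=np.int64)
    sender.init(num_kv_indices=len(kv_indices), aux_index=0)
    for chunk in np.array_split(kv_indices, 3):
        # index_slice and is_last are now derived internally from curr_idx;
        # the final chunk (curr_idx == num_kv_indices) takes the last-send path.
        sender.send(chunk)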