sglang 0.4.6.post4__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (130) hide show
  1. sglang/bench_offline_throughput.py +6 -6
  2. sglang/bench_one_batch.py +5 -4
  3. sglang/bench_one_batch_server.py +23 -15
  4. sglang/bench_serving.py +133 -57
  5. sglang/compile_deep_gemm.py +4 -4
  6. sglang/srt/configs/model_config.py +39 -28
  7. sglang/srt/conversation.py +1 -1
  8. sglang/srt/disaggregation/decode.py +122 -133
  9. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
  10. sglang/srt/disaggregation/fake/conn.py +3 -13
  11. sglang/srt/disaggregation/kv_events.py +357 -0
  12. sglang/srt/disaggregation/mini_lb.py +57 -24
  13. sglang/srt/disaggregation/mooncake/conn.py +11 -2
  14. sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
  15. sglang/srt/disaggregation/nixl/conn.py +9 -19
  16. sglang/srt/disaggregation/prefill.py +126 -44
  17. sglang/srt/disaggregation/utils.py +116 -5
  18. sglang/srt/distributed/utils.py +3 -3
  19. sglang/srt/entrypoints/EngineBase.py +5 -0
  20. sglang/srt/entrypoints/engine.py +28 -8
  21. sglang/srt/entrypoints/http_server.py +6 -4
  22. sglang/srt/entrypoints/http_server_engine.py +5 -2
  23. sglang/srt/function_call/base_format_detector.py +250 -0
  24. sglang/srt/function_call/core_types.py +34 -0
  25. sglang/srt/function_call/deepseekv3_detector.py +157 -0
  26. sglang/srt/function_call/ebnf_composer.py +234 -0
  27. sglang/srt/function_call/function_call_parser.py +175 -0
  28. sglang/srt/function_call/llama32_detector.py +74 -0
  29. sglang/srt/function_call/mistral_detector.py +84 -0
  30. sglang/srt/function_call/pythonic_detector.py +163 -0
  31. sglang/srt/function_call/qwen25_detector.py +67 -0
  32. sglang/srt/function_call/utils.py +35 -0
  33. sglang/srt/hf_transformers_utils.py +46 -7
  34. sglang/srt/layers/attention/aiter_backend.py +513 -0
  35. sglang/srt/layers/attention/flashattention_backend.py +63 -17
  36. sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
  37. sglang/srt/layers/attention/flashmla_backend.py +340 -78
  38. sglang/srt/layers/attention/triton_backend.py +3 -0
  39. sglang/srt/layers/attention/utils.py +2 -2
  40. sglang/srt/layers/attention/vision.py +1 -1
  41. sglang/srt/layers/communicator.py +451 -0
  42. sglang/srt/layers/dp_attention.py +0 -10
  43. sglang/srt/layers/moe/cutlass_moe.py +207 -0
  44. sglang/srt/layers/moe/ep_moe/kernels.py +33 -11
  45. sglang/srt/layers/moe/ep_moe/layer.py +104 -50
  46. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
  47. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
  48. sglang/srt/layers/moe/topk.py +66 -9
  49. sglang/srt/layers/multimodal.py +70 -0
  50. sglang/srt/layers/quantization/__init__.py +7 -2
  51. sglang/srt/layers/quantization/deep_gemm.py +5 -3
  52. sglang/srt/layers/quantization/fp8.py +90 -0
  53. sglang/srt/layers/quantization/fp8_utils.py +6 -0
  54. sglang/srt/layers/quantization/gptq.py +298 -6
  55. sglang/srt/layers/quantization/int8_kernel.py +18 -5
  56. sglang/srt/layers/quantization/qoq.py +244 -0
  57. sglang/srt/lora/lora_manager.py +1 -3
  58. sglang/srt/managers/deepseek_eplb.py +278 -0
  59. sglang/srt/managers/eplb_manager.py +55 -0
  60. sglang/srt/managers/expert_distribution.py +704 -56
  61. sglang/srt/managers/expert_location.py +394 -0
  62. sglang/srt/managers/expert_location_dispatch.py +91 -0
  63. sglang/srt/managers/io_struct.py +16 -3
  64. sglang/srt/managers/mm_utils.py +293 -139
  65. sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
  66. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
  67. sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
  68. sglang/srt/managers/multimodal_processors/internvl.py +14 -5
  69. sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
  70. sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
  71. sglang/srt/managers/multimodal_processors/llava.py +3 -3
  72. sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
  73. sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
  74. sglang/srt/managers/multimodal_processors/pixtral.py +9 -9
  75. sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
  76. sglang/srt/managers/schedule_batch.py +49 -21
  77. sglang/srt/managers/schedule_policy.py +4 -5
  78. sglang/srt/managers/scheduler.py +92 -50
  79. sglang/srt/managers/session_controller.py +1 -1
  80. sglang/srt/managers/tokenizer_manager.py +99 -24
  81. sglang/srt/mem_cache/base_prefix_cache.py +3 -0
  82. sglang/srt/mem_cache/chunk_cache.py +3 -1
  83. sglang/srt/mem_cache/hiradix_cache.py +4 -4
  84. sglang/srt/mem_cache/memory_pool.py +74 -52
  85. sglang/srt/mem_cache/multimodal_cache.py +45 -0
  86. sglang/srt/mem_cache/radix_cache.py +58 -5
  87. sglang/srt/metrics/collector.py +2 -2
  88. sglang/srt/mm_utils.py +10 -0
  89. sglang/srt/model_executor/cuda_graph_runner.py +20 -9
  90. sglang/srt/model_executor/expert_location_updater.py +422 -0
  91. sglang/srt/model_executor/forward_batch_info.py +4 -0
  92. sglang/srt/model_executor/model_runner.py +144 -54
  93. sglang/srt/model_loader/loader.py +10 -6
  94. sglang/srt/models/clip.py +5 -1
  95. sglang/srt/models/deepseek_v2.py +297 -343
  96. sglang/srt/models/exaone.py +8 -3
  97. sglang/srt/models/gemma3_mm.py +70 -33
  98. sglang/srt/models/llama4.py +10 -2
  99. sglang/srt/models/llava.py +26 -18
  100. sglang/srt/models/mimo_mtp.py +220 -0
  101. sglang/srt/models/minicpmo.py +5 -12
  102. sglang/srt/models/mistral.py +71 -1
  103. sglang/srt/models/mllama.py +3 -3
  104. sglang/srt/models/qwen2.py +95 -26
  105. sglang/srt/models/qwen2_5_vl.py +8 -0
  106. sglang/srt/models/qwen2_moe.py +330 -60
  107. sglang/srt/models/qwen2_vl.py +6 -0
  108. sglang/srt/models/qwen3.py +52 -10
  109. sglang/srt/models/qwen3_moe.py +411 -48
  110. sglang/srt/models/siglip.py +294 -0
  111. sglang/srt/openai_api/adapter.py +28 -16
  112. sglang/srt/openai_api/protocol.py +6 -0
  113. sglang/srt/operations.py +154 -0
  114. sglang/srt/operations_strategy.py +31 -0
  115. sglang/srt/server_args.py +134 -24
  116. sglang/srt/speculative/eagle_utils.py +131 -0
  117. sglang/srt/speculative/eagle_worker.py +47 -2
  118. sglang/srt/utils.py +68 -12
  119. sglang/test/test_cutlass_moe.py +278 -0
  120. sglang/test/test_utils.py +2 -36
  121. sglang/utils.py +2 -2
  122. sglang/version.py +1 -1
  123. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +20 -11
  124. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +128 -102
  125. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
  126. sglang/srt/function_call_parser.py +0 -858
  127. sglang/srt/platforms/interface.py +0 -371
  128. /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
  129. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
  130. {sglang-0.4.6.post4.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,142 @@
1
+ from __future__ import annotations
2
+
3
+ import logging
4
+ from typing import TYPE_CHECKING
5
+
6
+ import torch
7
+
8
+ from sglang.srt.model_executor.forward_batch_info import CaptureHiddenMode, ForwardMode
9
+ from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
10
+
11
+ logger = logging.getLogger(__name__)
12
+
13
+ if TYPE_CHECKING:
14
+ from sglang.srt.configs.model_config import ModelConfig
15
+ from sglang.srt.managers.schedule_batch import ScheduleBatch
16
+ from sglang.srt.server_args import ServerArgs
17
+
18
+
19
+ class ScheduleBatchDisaggregationDecodeMixin:
20
+
21
+ def prepare_for_prebuilt_extend(self: ScheduleBatch):
22
+ """
23
+ Prepare a prebuilt extend by populate metadata
24
+ Adapted from .prepare_for_extend().
25
+ """
26
+
27
+ self.forward_mode = ForwardMode.EXTEND
28
+ reqs = self.reqs
29
+ input_ids = [r.fill_ids[len(r.prefix_indices) :] for r in reqs]
30
+ extend_num_tokens = sum(len(ids) for ids in input_ids)
31
+ seq_lens = []
32
+ pre_lens = []
33
+ req_pool_indices = []
34
+
35
+ # Pre-calculate total size
36
+ total_size = sum(req.extend_input_len for req in reqs)
37
+ out_cache_loc = torch.empty(total_size, dtype=torch.int64, device=self.device)
38
+
39
+ # Fill the tensor in one pass
40
+ offset = 0
41
+ for i, req in enumerate(reqs):
42
+ req_pool_indices.append(req.req_pool_idx)
43
+
44
+ chunk = self.req_to_token_pool.req_to_token[req.req_pool_idx][
45
+ : req.extend_input_len
46
+ ]
47
+ assert (
48
+ offset + req.extend_input_len <= total_size
49
+ ), f"Exceeds total size: offset={offset}, req.extend_input_len={req.extend_input_len}, total_size={total_size}"
50
+ out_cache_loc[offset : offset + req.extend_input_len] = chunk
51
+ offset += req.extend_input_len
52
+
53
+ pre_len = len(req.prefix_indices)
54
+ seq_len = len(req.origin_input_ids) + max(0, len(req.output_ids) - 1)
55
+ seq_lens.append(seq_len)
56
+ if len(req.output_ids) == 0:
57
+ assert (
58
+ seq_len - pre_len == req.extend_input_len
59
+ ), f"seq_len={seq_len}, pre_len={pre_len}, req.extend_input_len={req.extend_input_len}"
60
+
61
+ req.cached_tokens += pre_len - req.already_computed
62
+ req.already_computed = seq_len
63
+ req.is_retracted = False
64
+ pre_lens.append(pre_len)
65
+ req.extend_logprob_start_len = 0
66
+
67
+ extend_input_logprob_token_ids = None
68
+
69
+ # Set fields
70
+ self.input_ids = torch.tensor(
71
+ sum(input_ids, []), dtype=torch.int32, device=self.device
72
+ )
73
+ self.req_pool_indices = torch.tensor(
74
+ req_pool_indices, dtype=torch.int64, device=self.device
75
+ )
76
+ self.seq_lens = torch.tensor(seq_lens, dtype=torch.int64, device=self.device)
77
+ self.out_cache_loc = out_cache_loc
78
+ self.seq_lens_sum = sum(seq_lens)
79
+
80
+ if self.return_logprob:
81
+ self.top_logprobs_nums = [r.top_logprobs_num for r in reqs]
82
+ self.token_ids_logprobs = [r.token_ids_logprob for r in reqs]
83
+
84
+ self.extend_num_tokens = extend_num_tokens
85
+ self.prefix_lens = [len(r.prefix_indices) for r in reqs]
86
+ self.extend_lens = [r.extend_input_len for r in reqs]
87
+ self.extend_logprob_start_lens = [r.extend_logprob_start_len for r in reqs]
88
+ self.extend_input_logprob_token_ids = extend_input_logprob_token_ids
89
+
90
+ # Build sampling info
91
+ self.sampling_info = SamplingBatchInfo.from_schedule_batch(
92
+ self,
93
+ self.model_config.vocab_size,
94
+ )
95
+
96
+ def process_prebuilt_extend(
97
+ self: ScheduleBatch, server_args: ServerArgs, model_config: ModelConfig
98
+ ):
99
+ """Assign the buffered last input id to schedule batch"""
100
+ self.output_ids = []
101
+ for req in self.reqs:
102
+ self.output_ids.append(req.output_ids[-1])
103
+ self.tree_cache.cache_unfinished_req(req)
104
+ if req.grammar is not None:
105
+ req.grammar.accept_token(req.output_ids[-1])
106
+ req.grammar.finished = req.finished()
107
+ self.output_ids = torch.tensor(self.output_ids, device=self.device)
108
+
109
+ # Simulate the eagle run. We add mock data to hidden states for the
110
+ # ease of implementation now meaning the first token will have acc rate
111
+ # of 0.
112
+ if not self.spec_algorithm.is_none():
113
+
114
+ b = len(self.reqs)
115
+ topk_p = torch.arange(
116
+ b * server_args.speculative_eagle_topk,
117
+ 0,
118
+ -1,
119
+ device=self.device,
120
+ dtype=torch.float32,
121
+ )
122
+ topk_p = topk_p.reshape(b, server_args.speculative_eagle_topk)
123
+ topk_p /= b * server_args.speculative_eagle_topk
124
+ topk_index = torch.arange(
125
+ b * server_args.speculative_eagle_topk, device=self.device
126
+ )
127
+ topk_index = topk_index.reshape(b, server_args.speculative_eagle_topk)
128
+
129
+ # local import to avoid circular import
130
+ from sglang.srt.speculative.eagle_utils import EagleDraftInput
131
+
132
+ spec_info = EagleDraftInput(
133
+ topk_p=topk_p,
134
+ topk_index=topk_index,
135
+ hidden_states=torch.ones(
136
+ (b, model_config.hidden_size), device=self.device
137
+ ),
138
+ verified_id=self.output_ids,
139
+ )
140
+ spec_info.prepare_for_extend(self)
141
+ spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
142
+ self.spec_info = spec_info
@@ -33,28 +33,18 @@ class FakeKVSender(BaseKVSender):
33
33
  self,
34
34
  kv_indices: list[int],
35
35
  aux_index: Optional[int] = None,
36
- dest_ranks: Optional[list[int]] = None,
37
36
  ):
38
37
  logger.info(
39
- f"FakeKVSender init with kv_indices: {kv_indices}, aux_index: {aux_index}, dest_ranks: {dest_ranks}"
38
+ f"FakeKVSender init with kv_indices: {kv_indices}, aux_index: {aux_index}"
40
39
  )
41
40
  pass
42
41
 
43
42
  def send(
44
43
  self,
45
44
  kv_indices: npt.NDArray[np.int64],
46
- index_slice: slice,
47
- is_last: bool,
48
45
  ):
49
- logger.info(
50
- f"FakeKVSender send with kv_indices: {kv_indices}, index_slice: {index_slice}, is_last: {is_last}"
51
- )
52
- if is_last:
53
- self.has_sent = True
54
- logger.info(f"FakeKVSender send success")
55
- else:
56
- self.has_sent = False
57
- logger.info(f"FakeKVSender send fake transferring")
46
+ self.has_sent = True
47
+ logger.info(f"FakeKVSender send with kv_indices: {kv_indices}")
58
48
 
59
49
  def failure_exception(self):
60
50
  raise Exception("Fake KVSender Exception")
@@ -0,0 +1,357 @@
1
+ """
2
+ Copyright 2025 SGLang Team
3
+ Licensed under the Apache License, Version 2.0 (the "License");
4
+ you may not use this file except in compliance with the License.
5
+ You may obtain a copy of the License at
6
+
7
+ http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ Unless required by applicable law or agreed to in writing, software
10
+ distributed under the License is distributed on an "AS IS" BASIS,
11
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ See the License for the specific language governing permissions and
13
+ limitations under the License.
14
+ """
15
+
16
+ """
17
+ KV caching events
18
+ """
19
+
20
+ import atexit
21
+ import logging
22
+ import queue
23
+ import threading
24
+ import time
25
+ from abc import ABC, abstractmethod
26
+ from collections import deque
27
+ from itertools import count
28
+ from queue import Queue
29
+ from typing import Any, Callable, Optional, Union
30
+
31
+ import msgspec
32
+ import zmq
33
+ from pydantic import BaseModel
34
+
35
+ logger = logging.getLogger(__name__)
36
+
37
+
38
+ class EventBatch(
39
+ msgspec.Struct,
40
+ array_like=True, # type: ignore[call-arg]
41
+ omit_defaults=True, # type: ignore[call-arg]
42
+ gc=False, # type: ignore[call-arg]
43
+ ):
44
+ ts: float
45
+ events: list[Any]
46
+
47
+
48
+ class KVCacheEvent(
49
+ msgspec.Struct,
50
+ array_like=True, # type: ignore[call-arg]
51
+ omit_defaults=True, # type: ignore[call-arg]
52
+ gc=False, # type: ignore[call-arg]
53
+ tag=True,
54
+ ):
55
+ """Base class for all KV cache-related events"""
56
+
57
+
58
+ class BlockStored(KVCacheEvent):
59
+ block_hashes: list[int]
60
+ parent_block_hash: Optional[int]
61
+ token_ids: list[int]
62
+ block_size: int
63
+ lora_id: Optional[int]
64
+
65
+
66
+ class BlockRemoved(KVCacheEvent):
67
+ block_hashes: list[int]
68
+
69
+
70
+ class AllBlocksCleared(KVCacheEvent):
71
+ pass
72
+
73
+
74
+ class KVEventBatch(EventBatch):
75
+ events: list[Union[BlockStored, BlockRemoved, AllBlocksCleared]]
76
+
77
+
78
+ class EventPublisher(ABC):
79
+ """Lightweight publisher for EventBatch batches."""
80
+
81
+ @abstractmethod
82
+ def publish(self, events: EventBatch) -> None:
83
+ """Emit events in order.
84
+
85
+ Implementations should guarantee at-least-once delivery and
86
+ monotonic ordering (e.g., via sequence numbers).
87
+ """
88
+
89
+ @abstractmethod
90
+ def shutdown(self) -> None:
91
+ """Shutdown the publisher."""
92
+
93
+
94
+ class NullEventPublisher(EventPublisher):
95
+ """No-op implementation (default when disabled)."""
96
+
97
+ def publish(self, events) -> None:
98
+ return
99
+
100
+ def shutdown(self) -> None:
101
+ return
102
+
103
+
104
+ class ZmqEventPublisher(EventPublisher):
105
+ """Reliable PUB/ROUTER publisher with an in-memory replay buffer.
106
+
107
+ Spawns a separate thread to handle publishing from a queue.
108
+
109
+ Parameters
110
+ ----------
111
+ endpoint:
112
+ PUB address. Use ``tcp://*:5557`` to bind or ``tcp://host:5557`` to
113
+ connect.
114
+ replay_endpoint:
115
+ Optional ROUTER address for replay requests. When given, subscribers can
116
+ request missed batches by sending the starting sequence number as an
117
+ 8-byte big-endian integer.
118
+ buffer_steps:
119
+ Number of past batches to keep for replay.
120
+ hwm:
121
+ ZeroMQ high-water-mark for PUB socket.
122
+ max_queue_size:
123
+ Maximum number of events to buffer in memory.
124
+ topic:
125
+ Topic to publish events to.
126
+ """
127
+
128
+ SHUTDOWN_TIMEOUT: float = 1.0
129
+ END_SEQ = (-1).to_bytes(8, "big", signed=True)
130
+
131
+ def __init__(
132
+ self,
133
+ endpoint: str = "tcp://*:5557",
134
+ replay_endpoint: Optional[str] = None,
135
+ buffer_steps: int = 10_000,
136
+ hwm: int = 100_000,
137
+ max_queue_size: int = 100_000,
138
+ topic: str = "",
139
+ ) -> None:
140
+ # Storage
141
+ self._event_queue = Queue[Optional[EventBatch]](maxsize=max_queue_size)
142
+ self._buffer = deque[tuple[int, bytes]](maxlen=buffer_steps)
143
+
144
+ # ZMQ sockets
145
+ self._ctx = zmq.Context.instance()
146
+ self._pub: Optional[zmq.Socket] = None
147
+ self._replay: Optional[zmq.Socket] = None
148
+ self._endpoint = endpoint
149
+ self._replay_endpoint = replay_endpoint
150
+ self._hwm = hwm
151
+ self._socket_setup()
152
+
153
+ # Payload
154
+ self._seq_gen = count()
155
+ self._topic_bytes = topic.encode("utf-8")
156
+
157
+ # Thread
158
+ self._running = True
159
+ logger.info("Starting ZMQ publisher thread")
160
+
161
+ self._thread = threading.Thread(
162
+ target=self._publisher_thread, daemon=True, name="zmq-publisher"
163
+ )
164
+ self._thread.start()
165
+
166
+ atexit.register(self.shutdown)
167
+
168
+ def publish(self, events: EventBatch) -> None:
169
+ if not self._running:
170
+ raise RuntimeError("Publisher is closed")
171
+ self._event_queue.put(events)
172
+
173
+ def shutdown(self) -> None:
174
+ """Stop the publisher thread and clean up resources."""
175
+ self._running = False
176
+ self._event_queue.put_nowait(None)
177
+
178
+ start = time.time()
179
+ pending_items = True
180
+ while pending_items and (time.time() - start < self.SHUTDOWN_TIMEOUT):
181
+ pending_items = not self._event_queue.empty()
182
+ if pending_items:
183
+ time.sleep(0.1)
184
+
185
+ if pending_items:
186
+ logger.warning(
187
+ "Warning: Queue still has %s items after %s seconds timeout",
188
+ self._event_queue.qsize(),
189
+ self.SHUTDOWN_TIMEOUT,
190
+ )
191
+
192
+ if self._thread.is_alive():
193
+ self._thread.join(timeout=self.SHUTDOWN_TIMEOUT)
194
+
195
+ # Clean up ZMQ resources
196
+ try:
197
+ if self._pub is not None:
198
+ self._pub.close(linger=0)
199
+ if self._replay is not None:
200
+ self._replay.close(linger=0)
201
+ finally:
202
+ pass # Do not terminate context; other sockets may use it
203
+
204
+ def _socket_setup(self) -> None:
205
+ """Initialize sockets
206
+ https://pyzmq.readthedocs.io/en/v19.0.0/morethanbindings.html#thread-safety
207
+ """
208
+ if self._pub is None:
209
+ self._pub = self._ctx.socket(zmq.PUB)
210
+ self._pub.set_hwm(self._hwm)
211
+ # Heuristic: bind if wildcard / * present, else connect.
212
+ # bind stable, connect volatile convention
213
+ if (
214
+ "*" in self._endpoint
215
+ or "::" in self._endpoint
216
+ or self._endpoint.startswith("ipc://")
217
+ or self._endpoint.startswith("inproc://")
218
+ ):
219
+ self._pub.bind(self._endpoint)
220
+ else:
221
+ self._pub.connect(self._endpoint)
222
+
223
+ # Set up replay socket: use ROUTER
224
+ # 1) handles multiple REQ clients (identities)
225
+ # 2) lets us send back one request → many replies (streamed events)
226
+ # 3) works in our non‑blocking poll loop alongside PUB
227
+ if self._replay_endpoint is not None:
228
+ self._replay = self._ctx.socket(zmq.ROUTER)
229
+ self._replay.bind(self._replay_endpoint)
230
+
231
+ def _publisher_thread(self) -> None:
232
+ """Background thread that processes the event queue."""
233
+ self._pack = msgspec.msgpack.Encoder()
234
+
235
+ assert self._pub is not None # narrows type for mypy
236
+
237
+ while self._running or self._event_queue.qsize() > 0:
238
+ # --- replay (non-critical) ---------------------------------
239
+ if self._replay is not None and self._replay.poll(0):
240
+ try:
241
+ self._service_replay()
242
+ except Exception as e:
243
+ logger.exception("Error in replay: %s", e)
244
+
245
+ # --- main queue (critical) ---------------------------------
246
+ try:
247
+ event = self._event_queue.get(timeout=0.1)
248
+ if event is None:
249
+ break # Sentinel received, exit thread
250
+ except queue.Empty:
251
+ continue
252
+
253
+ try:
254
+ seq = next(self._seq_gen)
255
+
256
+ payload = self._pack.encode(event)
257
+ seq_bytes = seq.to_bytes(8, "big")
258
+ self._pub.send_multipart((self._topic_bytes, seq_bytes, payload))
259
+
260
+ self._buffer.append((seq, payload))
261
+ self._event_queue.task_done()
262
+
263
+ except Exception as e:
264
+ # Publishing failed; back-off a bit to avoid a tight error loop
265
+ logger.exception("Error in publisher thread: %s", e)
266
+ time.sleep(0.1)
267
+
268
+ def _service_replay(self) -> None:
269
+ """If a replay request is waiting, send buffered batches."""
270
+ assert self._replay is not None # narrows type for mypy
271
+
272
+ frame = self._replay.recv_multipart()
273
+ if len(frame) != 3:
274
+ logger.warning("Invalid replay request: %s", frame)
275
+ return
276
+ client_id, _, start_seq_bytes = frame
277
+ start_seq = int.from_bytes(start_seq_bytes, "big")
278
+
279
+ for seq, buf in self._buffer:
280
+ if seq >= start_seq:
281
+ # [identity, empty_delim, seq_bytes, payload]
282
+ # (identity, empty_delim) are stripped off by the router
283
+ # receiving payload is (seq_bytes, payload)
284
+ self._replay.send_multipart(
285
+ (client_id, b"", seq.to_bytes(8, "big"), buf)
286
+ )
287
+ # Send end of sequence marker
288
+ # receiving payload is (-1, b""")
289
+ self._replay.send_multipart((client_id, b"", self.END_SEQ, b""))
290
+
291
+
292
+ class KVEventsConfig(BaseModel):
293
+ """Configuration for KV event publishing."""
294
+
295
+ publisher: str = "null"
296
+ """The publisher to use for publishing kv events. Can be "null", "zmq".
297
+ """
298
+
299
+ endpoint: str = "tcp://*:5557"
300
+ """The zmq endpoint to use for publishing kv events.
301
+ """
302
+
303
+ replay_endpoint: Optional[str] = None
304
+ """The zmq endpoint to use for replaying kv events.
305
+ """
306
+
307
+ buffer_steps: int = 10_000
308
+ """The number of steps to cache for replay endpoint. Will only save
309
+ events from the last N steps for the replay endpoint.
310
+ """
311
+
312
+ hwm: int = 100_000
313
+ """The zmq high water mark for the event publisher. After queueing N events,
314
+ events will start dropping if the consumer is not keeping up.
315
+ """
316
+
317
+ max_queue_size: int = 100_000
318
+ """The maximum number of events to queue while waiting for publishing.
319
+ """
320
+
321
+ topic: str = ""
322
+ """The topic to use for the event publisher. Consumers can subscribe to
323
+ this topic to receive events.
324
+ """
325
+
326
+ @classmethod
327
+ def from_cli(cls, cli_value: str) -> "KVEventsConfig":
328
+ """Parse the CLI value for the event publisher config."""
329
+ return KVEventsConfig.model_validate_json(cli_value)
330
+
331
+
332
+ class EventPublisherFactory:
333
+ _registry: dict[str, Callable[..., EventPublisher]] = {
334
+ "null": NullEventPublisher,
335
+ "zmq": ZmqEventPublisher,
336
+ }
337
+
338
+ @classmethod
339
+ def register_publisher(cls, name: str, ctor: Callable[..., EventPublisher]) -> None:
340
+ if name in cls._registry:
341
+ raise KeyError(f"publisher '{name}' already registered")
342
+ cls._registry[name] = ctor
343
+
344
+ @classmethod
345
+ def create(cls, config: Optional[str]) -> EventPublisher:
346
+ """Create publisher from a config mapping."""
347
+ if not config:
348
+ return NullEventPublisher()
349
+ config = KVEventsConfig.from_cli(config)
350
+ config_dict = config.model_dump()
351
+
352
+ kind = config_dict.pop("publisher", "null")
353
+ try:
354
+ constructor = cls._registry[kind]
355
+ except KeyError as exc:
356
+ raise ValueError(f"Unknown event publisher '{kind}'") from exc
357
+ return constructor(**config_dict)
@@ -73,11 +73,27 @@ class MiniLoadBalancer:
73
73
  session.post(f"{prefill_server}/{endpoint}", json=modified_request),
74
74
  session.post(f"{decode_server}/{endpoint}", json=modified_request),
75
75
  ]
76
+
76
77
  # Wait for both responses to complete. Prefill should end first.
77
- _, decode_response = await asyncio.gather(*tasks)
78
+ prefill_response, decode_response = await asyncio.gather(*tasks)
79
+
80
+ if "return_logprob" in modified_request:
81
+
82
+ prefill_json = await prefill_response.json()
83
+ ret_json = await decode_response.json()
84
+
85
+ # merge `meta_info.input_token_logprobs` from prefill to decode
86
+ if "meta_info" in ret_json:
87
+ if "input_token_logprobs" in ret_json["meta_info"]:
88
+ ret_json["meta_info"]["input_token_logprobs"] = (
89
+ prefill_json["meta_info"]["input_token_logprobs"]
90
+ + ret_json["meta_info"]["input_token_logprobs"]
91
+ )
92
+ else:
93
+ ret_json = await decode_response.json()
78
94
 
79
95
  return ORJSONResponse(
80
- content=await decode_response.json(),
96
+ content=ret_json,
81
97
  status_code=decode_response.status,
82
98
  )
83
99
 
@@ -92,30 +108,47 @@ class MiniLoadBalancer:
92
108
  total=3600
93
109
  ) # Add timeout for request reliability
94
110
  ) as session:
95
- try:
96
- # Create the tasks for both prefill and decode requests
97
- tasks = [
98
- session.post(
99
- f"{prefill_server}/{endpoint}", json=modified_request
100
- ),
101
- session.post(
102
- f"{decode_server}/{endpoint}", json=modified_request
103
- ),
104
- ]
105
- # Wait for both responses to complete. Since this is streaming, they return immediately.
106
- prefill_response, decode_response = await asyncio.gather(*tasks)
111
+ # Create the tasks for both prefill and decode requests
112
+ tasks = [
113
+ session.post(f"{prefill_server}/generate", json=modified_request),
114
+ session.post(f"{decode_server}/generate", json=modified_request),
115
+ ]
116
+ # Wait for both responses to complete. Since this is streaming, they return immediately.
117
+ prefill_response, decode_response = await asyncio.gather(*tasks)
118
+
119
+ if modified_request.get("return_logprob", False):
120
+ prefill_chunks = []
121
+ async for chunk in prefill_response.content:
122
+ prefill_chunks.append(chunk)
123
+
124
+ first_prefill_chunk = (
125
+ prefill_chunks[0].decode("utf-8")[5:].strip("\n")
126
+ )
127
+ first_prefill_chunk_json = orjson.loads(first_prefill_chunk)
128
+
129
+ async for chunk in decode_response.content:
130
+ # Note: This is inefficient
131
+ # merge prefill input_token_logprobs, output_token_logprobs to decode
132
+ decoded_chunk = chunk.decode("utf-8")
133
+ if (
134
+ decoded_chunk
135
+ and decoded_chunk.startswith("data:")
136
+ and "[DONE]" not in decoded_chunk
137
+ ):
138
+ ret_json = orjson.loads(decoded_chunk[5:].strip("\n"))
139
+ ret_json["meta_info"]["input_token_logprobs"] = (
140
+ first_prefill_chunk_json["meta_info"][
141
+ "input_token_logprobs"
142
+ ]
143
+ + ret_json["meta_info"]["input_token_logprobs"]
144
+ )
145
+
146
+ yield b"data: " + orjson.dumps(ret_json) + b"\n\n"
147
+ else:
148
+ yield chunk
149
+ else:
107
150
  async for chunk in decode_response.content:
108
151
  yield chunk
109
- except Exception as e:
110
- error_msg = {
111
- "error": {"message": f"Stream processing error: {str(e)}"}
112
- }
113
- yield b"data: " + orjson.dumps(
114
- error_msg, option=orjson.OPT_NON_STR_KEYS
115
- ) + b"\n\n"
116
- finally:
117
- if prefill_response is not None:
118
- await prefill_response.release()
119
152
 
120
153
  return StreamingResponse(
121
154
  stream_results(),
@@ -51,6 +51,7 @@ def group_concurrent_contiguous(
51
51
  return src_groups, dst_groups
52
52
 
53
53
 
54
+ # prefill
54
55
  @dataclasses.dataclass
55
56
  class TransferKVChunk:
56
57
  room: int
@@ -60,6 +61,7 @@ class TransferKVChunk:
60
61
  prefill_aux_index: Optional[int]
61
62
 
62
63
 
64
+ # decode
63
65
  @dataclasses.dataclass
64
66
  class TransferInfo:
65
67
  room: int
@@ -93,6 +95,7 @@ class TransferInfo:
93
95
  )
94
96
 
95
97
 
98
+ # decode
96
99
  @dataclasses.dataclass
97
100
  class KVArgsRegisterInfo:
98
101
  room: str
@@ -464,6 +467,8 @@ class MooncakeKVSender(BaseKVSender):
464
467
  self.aux_index = None
465
468
  self.bootstrap_server_url = bootstrap_addr
466
469
  self.session_id = self.kv_mgr.get_session_id()
470
+ # inner state
471
+ self.curr_idx = 0
467
472
 
468
473
  def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
469
474
  self.num_kv_indices = num_kv_indices
@@ -472,9 +477,11 @@ class MooncakeKVSender(BaseKVSender):
472
477
  def send(
473
478
  self,
474
479
  kv_indices: npt.NDArray[np.int64],
475
- index_slice: slice,
476
- is_last: bool,
477
480
  ):
481
+ index_slice = slice(self.curr_idx, self.curr_idx + len(kv_indices))
482
+ self.curr_idx += len(kv_indices)
483
+ is_last = self.curr_idx == self.num_kv_indices
484
+
478
485
  if not is_last:
479
486
  self.kv_mgr.add_transfer_request(
480
487
  self.bootstrap_room, kv_indices, index_slice, False
@@ -492,6 +499,7 @@ class MooncakeKVSender(BaseKVSender):
492
499
  return self.kv_mgr.check_status(self.bootstrap_room)
493
500
 
494
501
  def failure_exception(self):
502
+ # TODO: raise a real exception
495
503
  raise Exception("Fake KVSender Exception")
496
504
 
497
505
 
@@ -719,6 +727,7 @@ class MooncakeKVReceiver(BaseKVReceiver):
719
727
  return self.kv_mgr.check_status(self.bootstrap_room)
720
728
 
721
729
  def failure_exception(self):
730
+ # TODO: raise a real exception
722
731
  raise Exception("Fake KVReceiver Exception")
723
732
 
724
733