sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (180) hide show
  1. sglang/bench_offline_throughput.py +10 -8
  2. sglang/bench_one_batch.py +7 -6
  3. sglang/bench_one_batch_server.py +157 -21
  4. sglang/bench_serving.py +137 -59
  5. sglang/compile_deep_gemm.py +5 -5
  6. sglang/eval/loogle_eval.py +157 -0
  7. sglang/lang/chat_template.py +78 -78
  8. sglang/lang/tracer.py +1 -1
  9. sglang/srt/code_completion_parser.py +1 -1
  10. sglang/srt/configs/deepseekvl2.py +2 -2
  11. sglang/srt/configs/model_config.py +40 -28
  12. sglang/srt/constrained/base_grammar_backend.py +55 -72
  13. sglang/srt/constrained/llguidance_backend.py +25 -21
  14. sglang/srt/constrained/outlines_backend.py +27 -26
  15. sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
  16. sglang/srt/constrained/xgrammar_backend.py +69 -43
  17. sglang/srt/conversation.py +49 -44
  18. sglang/srt/disaggregation/base/conn.py +1 -0
  19. sglang/srt/disaggregation/decode.py +129 -135
  20. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
  21. sglang/srt/disaggregation/fake/conn.py +3 -13
  22. sglang/srt/disaggregation/kv_events.py +357 -0
  23. sglang/srt/disaggregation/mini_lb.py +57 -24
  24. sglang/srt/disaggregation/mooncake/conn.py +238 -122
  25. sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
  26. sglang/srt/disaggregation/nixl/conn.py +10 -19
  27. sglang/srt/disaggregation/prefill.py +132 -47
  28. sglang/srt/disaggregation/utils.py +123 -6
  29. sglang/srt/distributed/utils.py +3 -3
  30. sglang/srt/entrypoints/EngineBase.py +5 -0
  31. sglang/srt/entrypoints/engine.py +44 -9
  32. sglang/srt/entrypoints/http_server.py +23 -6
  33. sglang/srt/entrypoints/http_server_engine.py +5 -2
  34. sglang/srt/function_call/base_format_detector.py +250 -0
  35. sglang/srt/function_call/core_types.py +34 -0
  36. sglang/srt/function_call/deepseekv3_detector.py +157 -0
  37. sglang/srt/function_call/ebnf_composer.py +234 -0
  38. sglang/srt/function_call/function_call_parser.py +175 -0
  39. sglang/srt/function_call/llama32_detector.py +74 -0
  40. sglang/srt/function_call/mistral_detector.py +84 -0
  41. sglang/srt/function_call/pythonic_detector.py +163 -0
  42. sglang/srt/function_call/qwen25_detector.py +67 -0
  43. sglang/srt/function_call/utils.py +35 -0
  44. sglang/srt/hf_transformers_utils.py +46 -7
  45. sglang/srt/layers/attention/aiter_backend.py +513 -0
  46. sglang/srt/layers/attention/flashattention_backend.py +64 -18
  47. sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
  48. sglang/srt/layers/attention/flashmla_backend.py +340 -78
  49. sglang/srt/layers/attention/triton_backend.py +3 -0
  50. sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
  51. sglang/srt/layers/attention/utils.py +6 -4
  52. sglang/srt/layers/attention/vision.py +1 -1
  53. sglang/srt/layers/communicator.py +451 -0
  54. sglang/srt/layers/dp_attention.py +61 -21
  55. sglang/srt/layers/layernorm.py +1 -1
  56. sglang/srt/layers/logits_processor.py +46 -11
  57. sglang/srt/layers/moe/cutlass_moe.py +207 -0
  58. sglang/srt/layers/moe/ep_moe/kernels.py +34 -12
  59. sglang/srt/layers/moe/ep_moe/layer.py +105 -51
  60. sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
  61. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
  62. sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
  63. sglang/srt/layers/moe/topk.py +67 -10
  64. sglang/srt/layers/multimodal.py +70 -0
  65. sglang/srt/layers/quantization/__init__.py +8 -3
  66. sglang/srt/layers/quantization/blockwise_int8.py +2 -2
  67. sglang/srt/layers/quantization/deep_gemm.py +77 -74
  68. sglang/srt/layers/quantization/fp8.py +92 -2
  69. sglang/srt/layers/quantization/fp8_kernel.py +3 -3
  70. sglang/srt/layers/quantization/fp8_utils.py +6 -0
  71. sglang/srt/layers/quantization/gptq.py +298 -6
  72. sglang/srt/layers/quantization/int8_kernel.py +20 -7
  73. sglang/srt/layers/quantization/qoq.py +244 -0
  74. sglang/srt/layers/sampler.py +0 -4
  75. sglang/srt/layers/vocab_parallel_embedding.py +18 -7
  76. sglang/srt/lora/lora_manager.py +2 -4
  77. sglang/srt/lora/mem_pool.py +4 -4
  78. sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
  79. sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
  80. sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
  81. sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
  82. sglang/srt/lora/utils.py +1 -1
  83. sglang/srt/managers/data_parallel_controller.py +3 -3
  84. sglang/srt/managers/deepseek_eplb.py +278 -0
  85. sglang/srt/managers/detokenizer_manager.py +21 -8
  86. sglang/srt/managers/eplb_manager.py +55 -0
  87. sglang/srt/managers/expert_distribution.py +704 -56
  88. sglang/srt/managers/expert_location.py +394 -0
  89. sglang/srt/managers/expert_location_dispatch.py +91 -0
  90. sglang/srt/managers/io_struct.py +19 -4
  91. sglang/srt/managers/mm_utils.py +294 -140
  92. sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
  93. sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
  94. sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
  95. sglang/srt/managers/multimodal_processors/internvl.py +14 -5
  96. sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
  97. sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
  98. sglang/srt/managers/multimodal_processors/llava.py +46 -0
  99. sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
  100. sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
  101. sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
  102. sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
  103. sglang/srt/managers/schedule_batch.py +122 -42
  104. sglang/srt/managers/schedule_policy.py +1 -5
  105. sglang/srt/managers/scheduler.py +205 -138
  106. sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
  107. sglang/srt/managers/session_controller.py +1 -1
  108. sglang/srt/managers/tokenizer_manager.py +232 -58
  109. sglang/srt/managers/tp_worker.py +12 -9
  110. sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
  111. sglang/srt/mem_cache/base_prefix_cache.py +3 -0
  112. sglang/srt/mem_cache/chunk_cache.py +3 -1
  113. sglang/srt/mem_cache/hiradix_cache.py +4 -4
  114. sglang/srt/mem_cache/memory_pool.py +76 -52
  115. sglang/srt/mem_cache/multimodal_cache.py +45 -0
  116. sglang/srt/mem_cache/radix_cache.py +58 -5
  117. sglang/srt/metrics/collector.py +314 -39
  118. sglang/srt/mm_utils.py +10 -0
  119. sglang/srt/model_executor/cuda_graph_runner.py +29 -19
  120. sglang/srt/model_executor/expert_location_updater.py +422 -0
  121. sglang/srt/model_executor/forward_batch_info.py +5 -1
  122. sglang/srt/model_executor/model_runner.py +163 -68
  123. sglang/srt/model_loader/loader.py +10 -6
  124. sglang/srt/models/clip.py +5 -1
  125. sglang/srt/models/deepseek_janus_pro.py +2 -2
  126. sglang/srt/models/deepseek_v2.py +308 -351
  127. sglang/srt/models/exaone.py +8 -3
  128. sglang/srt/models/gemma3_mm.py +70 -33
  129. sglang/srt/models/llama.py +2 -0
  130. sglang/srt/models/llama4.py +15 -8
  131. sglang/srt/models/llava.py +258 -7
  132. sglang/srt/models/mimo_mtp.py +220 -0
  133. sglang/srt/models/minicpmo.py +5 -12
  134. sglang/srt/models/mistral.py +71 -1
  135. sglang/srt/models/mixtral.py +98 -34
  136. sglang/srt/models/mllama.py +3 -3
  137. sglang/srt/models/pixtral.py +467 -0
  138. sglang/srt/models/qwen2.py +95 -26
  139. sglang/srt/models/qwen2_5_vl.py +8 -0
  140. sglang/srt/models/qwen2_moe.py +330 -60
  141. sglang/srt/models/qwen2_vl.py +6 -0
  142. sglang/srt/models/qwen3.py +52 -10
  143. sglang/srt/models/qwen3_moe.py +411 -48
  144. sglang/srt/models/roberta.py +1 -1
  145. sglang/srt/models/siglip.py +294 -0
  146. sglang/srt/models/torch_native_llama.py +1 -1
  147. sglang/srt/openai_api/adapter.py +58 -20
  148. sglang/srt/openai_api/protocol.py +6 -8
  149. sglang/srt/operations.py +154 -0
  150. sglang/srt/operations_strategy.py +31 -0
  151. sglang/srt/reasoning_parser.py +3 -3
  152. sglang/srt/sampling/custom_logit_processor.py +18 -3
  153. sglang/srt/sampling/sampling_batch_info.py +4 -56
  154. sglang/srt/sampling/sampling_params.py +2 -2
  155. sglang/srt/server_args.py +162 -22
  156. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
  157. sglang/srt/speculative/eagle_utils.py +138 -7
  158. sglang/srt/speculative/eagle_worker.py +69 -21
  159. sglang/srt/utils.py +74 -17
  160. sglang/test/few_shot_gsm8k.py +2 -2
  161. sglang/test/few_shot_gsm8k_engine.py +2 -2
  162. sglang/test/run_eval.py +2 -2
  163. sglang/test/runners.py +8 -1
  164. sglang/test/send_one.py +13 -3
  165. sglang/test/simple_eval_common.py +1 -1
  166. sglang/test/simple_eval_humaneval.py +1 -1
  167. sglang/test/test_cutlass_moe.py +278 -0
  168. sglang/test/test_programs.py +5 -5
  169. sglang/test/test_utils.py +55 -14
  170. sglang/utils.py +3 -3
  171. sglang/version.py +1 -1
  172. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +23 -13
  173. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +178 -149
  174. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
  175. sglang/srt/function_call_parser.py +0 -858
  176. sglang/srt/platforms/interface.py +0 -371
  177. /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
  178. /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
  179. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
  180. {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,357 @@
1
+ """
2
+ Copyright 2025 SGLang Team
3
+ Licensed under the Apache License, Version 2.0 (the "License");
4
+ you may not use this file except in compliance with the License.
5
+ You may obtain a copy of the License at
6
+
7
+ http://www.apache.org/licenses/LICENSE-2.0
8
+
9
+ Unless required by applicable law or agreed to in writing, software
10
+ distributed under the License is distributed on an "AS IS" BASIS,
11
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ See the License for the specific language governing permissions and
13
+ limitations under the License.
14
+ """
15
+
16
+ """
17
+ KV caching events
18
+ """
19
+
20
+ import atexit
21
+ import logging
22
+ import queue
23
+ import threading
24
+ import time
25
+ from abc import ABC, abstractmethod
26
+ from collections import deque
27
+ from itertools import count
28
+ from queue import Queue
29
+ from typing import Any, Callable, Optional, Union
30
+
31
+ import msgspec
32
+ import zmq
33
+ from pydantic import BaseModel
34
+
35
+ logger = logging.getLogger(__name__)
36
+
37
+
38
class EventBatch(
    msgspec.Struct,
    gc=False,  # type: ignore[call-arg]
    omit_defaults=True,  # type: ignore[call-arg]
    array_like=True,  # type: ignore[call-arg]
):
    """Timestamped container holding a list of events.

    The struct is declared with ``array_like=True``, so the declaration order
    of the fields below is part of the encoded representation and must not be
    changed.
    """

    # Timestamp of the batch; the clock and unit are chosen by the producer
    # (not visible in this module).
    ts: float
    # Events carried by this batch.
    events: list[Any]
46
+
47
+
48
class KVCacheEvent(
    msgspec.Struct,
    tag=True,
    gc=False,  # type: ignore[call-arg]
    omit_defaults=True,  # type: ignore[call-arg]
    array_like=True,  # type: ignore[call-arg]
):
    """Common base type for every KV cache event.

    ``tag=True`` asks msgspec to tag each subclass, which is what allows a
    union of the concrete event types to be distinguished from one another.
    """
56
+
57
+
58
class BlockStored(KVCacheEvent):
    """KV cache event describing stored blocks.

    Field order matters: the base struct is array-like, so fields are encoded
    by position.
    """

    # Hashes of the blocks this event refers to.
    block_hashes: list[int]
    # Hash of the parent block, or None when there is none.
    parent_block_hash: Optional[int]
    # Token ids associated with the blocks.
    token_ids: list[int]
    # Block size; presumably counted in tokens -- TODO confirm with producer.
    block_size: int
    # LoRA id, or None when not applicable.
    lora_id: Optional[int]
64
+
65
+
66
class BlockRemoved(KVCacheEvent):
    """KV cache event describing removed blocks."""

    # Hashes of the blocks this event refers to.
    block_hashes: list[int]
68
+
69
+
70
class AllBlocksCleared(KVCacheEvent):
    """KV cache event that carries no fields of its own."""
72
+
73
+
74
class KVEventBatch(EventBatch):
    """``EventBatch`` whose events are restricted to the KV cache event types."""

    # Narrows the base-class ``events`` field to the concrete KV event union.
    events: list[Union[BlockStored, BlockRemoved, AllBlocksCleared]]
76
+
77
+
78
class EventPublisher(ABC):
    """Abstract interface of a lightweight publisher for ``EventBatch`` objects."""

    @abstractmethod
    def publish(self, events: EventBatch) -> None:
        """Emit the given batch, preserving order.

        Concrete publishers are expected to provide at-least-once delivery
        and monotonic ordering (for example by attaching sequence numbers).
        """

    @abstractmethod
    def shutdown(self) -> None:
        """Stop the publisher."""
92
+
93
+
94
class NullEventPublisher(EventPublisher):
    """Publisher that does nothing; used as the default when publishing is off."""

    def publish(self, events) -> None:
        """Discard ``events``."""
        return None

    def shutdown(self) -> None:
        """Nothing to release."""
        return None
102
+
103
+
104
class ZmqEventPublisher(EventPublisher):
    """Reliable PUB/ROUTER publisher with an in-memory replay buffer.

    Spawns a separate thread to handle publishing from a queue.

    Parameters
    ----------
    endpoint:
        PUB address. Use ``tcp://*:5557`` to bind or ``tcp://host:5557`` to
        connect.
    replay_endpoint:
        Optional ROUTER address for replay requests. When given, subscribers can
        request missed batches by sending the starting sequence number as an
        8-byte big-endian integer.
    buffer_steps:
        Number of past batches to keep for replay.
    hwm:
        ZeroMQ high-water-mark for PUB socket.
    max_queue_size:
        Maximum number of events to buffer in memory.
    topic:
        Topic to publish events to.
    """

    # Upper bound (seconds) used both for draining the queue and for joining
    # the publisher thread during shutdown.
    SHUTDOWN_TIMEOUT: float = 1.0
    # Sequence value (-1 as a signed 8-byte big-endian integer) that terminates
    # a replay response.
    END_SEQ = (-1).to_bytes(8, "big", signed=True)

    def __init__(
        self,
        endpoint: str = "tcp://*:5557",
        replay_endpoint: Optional[str] = None,
        buffer_steps: int = 10_000,
        hwm: int = 100_000,
        max_queue_size: int = 100_000,
        topic: str = "",
    ) -> None:
        # Storage: pending batches (None is the shutdown sentinel) and the
        # bounded history of (sequence number, encoded payload) for replay.
        self._event_queue = Queue[Optional[EventBatch]](maxsize=max_queue_size)
        self._buffer = deque[tuple[int, bytes]](maxlen=buffer_steps)

        # ZMQ sockets
        self._ctx = zmq.Context.instance()
        self._pub: Optional[zmq.Socket] = None
        self._replay: Optional[zmq.Socket] = None
        self._endpoint = endpoint
        self._replay_endpoint = replay_endpoint
        self._hwm = hwm
        self._socket_setup()

        # Payload
        self._seq_gen = count()
        self._topic_bytes = topic.encode("utf-8")

        # Thread
        self._running = True
        logger.info("Starting ZMQ publisher thread")

        self._thread = threading.Thread(
            target=self._publisher_thread, daemon=True, name="zmq-publisher"
        )
        self._thread.start()

        atexit.register(self.shutdown)

    def publish(self, events: EventBatch) -> None:
        """Queue ``events`` for the background thread.

        Blocks while the internal queue is full.

        Raises
        ------
        RuntimeError
            If the publisher has already been shut down.
        """
        if not self._running:
            raise RuntimeError("Publisher is closed")
        self._event_queue.put(events)

    def shutdown(self) -> None:
        """Stop the publisher thread and clean up resources.

        Safe to call more than once: the method is also registered with
        ``atexit``, so an explicit call is normally followed by a second one
        at interpreter exit, which must be a no-op.
        """
        if not self._running:
            # Already shut down. Without this guard a second call would
            # enqueue a sentinel nobody consumes and then wait out the timeout.
            return
        self._running = False

        try:
            self._event_queue.put_nowait(None)
        except queue.Full:
            # The sentinel only speeds up the wake-up. The thread's loop
            # condition also ends once ``_running`` is False and the queue
            # has drained, so a full queue must not abort the cleanup below.
            pass

        deadline = time.monotonic() + self.SHUTDOWN_TIMEOUT
        pending_items = not self._event_queue.empty()
        while pending_items and time.monotonic() < deadline:
            time.sleep(0.1)
            pending_items = not self._event_queue.empty()

        if pending_items:
            logger.warning(
                "Warning: Queue still has %s items after %s seconds timeout",
                self._event_queue.qsize(),
                self.SHUTDOWN_TIMEOUT,
            )

        if self._thread.is_alive():
            self._thread.join(timeout=self.SHUTDOWN_TIMEOUT)

        # Clean up ZMQ resources. The context is deliberately not terminated;
        # other sockets may use it. The replay socket is closed even if
        # closing the PUB socket raises.
        try:
            if self._pub is not None:
                self._pub.close(linger=0)
        finally:
            if self._replay is not None:
                self._replay.close(linger=0)

    def _socket_setup(self) -> None:
        """Initialize sockets
        https://pyzmq.readthedocs.io/en/v19.0.0/morethanbindings.html#thread-safety
        """
        if self._pub is None:
            self._pub = self._ctx.socket(zmq.PUB)
            self._pub.set_hwm(self._hwm)
            # Heuristic: bind if wildcard / * present, else connect.
            # bind stable, connect volatile convention
            should_bind = (
                "*" in self._endpoint
                or "::" in self._endpoint
                or self._endpoint.startswith(("ipc://", "inproc://"))
            )
            if should_bind:
                self._pub.bind(self._endpoint)
            else:
                self._pub.connect(self._endpoint)

        # Set up replay socket: use ROUTER
        #   1) handles multiple REQ clients (identities)
        #   2) lets us send back one request -> many replies (streamed events)
        #   3) works in our non-blocking poll loop alongside PUB
        if self._replay_endpoint is not None:
            self._replay = self._ctx.socket(zmq.ROUTER)
            self._replay.bind(self._replay_endpoint)

    def _publisher_thread(self) -> None:
        """Background thread that processes the event queue."""
        self._pack = msgspec.msgpack.Encoder()

        assert self._pub is not None  # narrows type for mypy

        while self._running or self._event_queue.qsize() > 0:
            # --- replay (non-critical) ---------------------------------
            if self._replay is not None and self._replay.poll(0):
                try:
                    self._service_replay()
                except Exception as e:
                    logger.exception("Error in replay: %s", e)

            # --- main queue (critical) ---------------------------------
            try:
                event = self._event_queue.get(timeout=0.1)
            except queue.Empty:
                continue

            if event is None:
                # Sentinel received, exit thread
                self._event_queue.task_done()
                break

            try:
                seq = next(self._seq_gen)

                payload = self._pack.encode(event)
                seq_bytes = seq.to_bytes(8, "big")
                self._pub.send_multipart((self._topic_bytes, seq_bytes, payload))

                self._buffer.append((seq, payload))
            except Exception as e:
                # Publishing failed; back-off a bit to avoid a tight error loop
                logger.exception("Error in publisher thread: %s", e)
                time.sleep(0.1)
            finally:
                # Exactly one task_done() per successful get(), also when the
                # publish failed, so the queue's unfinished-task count stays
                # accurate.
                self._event_queue.task_done()

    def _service_replay(self) -> None:
        """If a replay request is waiting, send buffered batches."""
        assert self._replay is not None  # narrows type for mypy

        frame = self._replay.recv_multipart()
        if len(frame) != 3:
            logger.warning("Invalid replay request: %s", frame)
            return
        client_id, _, start_seq_bytes = frame
        start_seq = int.from_bytes(start_seq_bytes, "big")

        for seq, buf in self._buffer:
            if seq >= start_seq:
                # [identity, empty_delim, seq_bytes, payload]
                # (identity, empty_delim) are stripped off by the router
                # receiving payload is (seq_bytes, payload)
                self._replay.send_multipart(
                    (client_id, b"", seq.to_bytes(8, "big"), buf)
                )
        # Send end of sequence marker
        # receiving payload is (-1, b"")
        self._replay.send_multipart((client_id, b"", self.END_SEQ, b""))
290
+
291
+
292
class KVEventsConfig(BaseModel):
    """Configuration for KV event publishing."""

    # Which publisher implementation publishes the kv events.
    # Can be "null", "zmq".
    publisher: str = "null"

    # ZMQ endpoint on which kv events are published.
    endpoint: str = "tcp://*:5557"

    # ZMQ endpoint serving replay requests for kv events (optional).
    replay_endpoint: Optional[str] = None

    # Number of steps cached for the replay endpoint; only events from the
    # last N steps are kept available for replay.
    buffer_steps: int = 10_000

    # ZMQ high water mark for the event publisher. After queueing N events,
    # events will start dropping if the consumer is not keeping up.
    hwm: int = 100_000

    # Maximum number of events queued while waiting to be published.
    max_queue_size: int = 100_000

    # Topic used by the event publisher; consumers subscribe to this topic
    # to receive events.
    topic: str = ""

    @classmethod
    def from_cli(cls, cli_value: str) -> "KVEventsConfig":
        """Build the config by validating the JSON text given on the CLI."""
        parsed = KVEventsConfig.model_validate_json(cli_value)
        return parsed
330
+
331
+
332
class EventPublisherFactory:
    """Builds ``EventPublisher`` instances from a name-to-constructor registry."""

    # Publisher name -> callable returning an EventPublisher.
    _registry: dict[str, Callable[..., EventPublisher]] = {
        "null": NullEventPublisher,
        "zmq": ZmqEventPublisher,
    }

    @classmethod
    def register_publisher(cls, name: str, ctor: Callable[..., EventPublisher]) -> None:
        """Add a publisher constructor under ``name``.

        Raises
        ------
        KeyError
            If ``name`` is already registered.
        """
        if name in cls._registry:
            raise KeyError(f"publisher '{name}' already registered")
        cls._registry[name] = ctor

    @classmethod
    def create(cls, config: Optional[str]) -> EventPublisher:
        """Create a publisher from a JSON configuration string.

        Parameters
        ----------
        config:
            JSON text accepted by ``KVEventsConfig.from_cli``. ``None`` or an
            empty string yields a ``NullEventPublisher``.

        Raises
        ------
        ValueError
            If the configured publisher name is not registered.
        """
        if not config:
            return NullEventPublisher()

        parsed = KVEventsConfig.from_cli(config)
        options = parsed.model_dump()

        kind = options.pop("publisher", "null")
        try:
            constructor = cls._registry[kind]
        except KeyError as exc:
            raise ValueError(f"Unknown event publisher '{kind}'") from exc

        if constructor is NullEventPublisher:
            # NullEventPublisher takes no constructor arguments; forwarding
            # the remaining options (endpoint, hwm, ...) would raise TypeError.
            return NullEventPublisher()
        return constructor(**options)
@@ -73,11 +73,27 @@ class MiniLoadBalancer:
73
73
  session.post(f"{prefill_server}/{endpoint}", json=modified_request),
74
74
  session.post(f"{decode_server}/{endpoint}", json=modified_request),
75
75
  ]
76
+
76
77
  # Wait for both responses to complete. Prefill should end first.
77
- _, decode_response = await asyncio.gather(*tasks)
78
+ prefill_response, decode_response = await asyncio.gather(*tasks)
79
+
80
+ if "return_logprob" in modified_request:
81
+
82
+ prefill_json = await prefill_response.json()
83
+ ret_json = await decode_response.json()
84
+
85
+ # merge `meta_info.input_token_logprobs` from prefill to decode
86
+ if "meta_info" in ret_json:
87
+ if "input_token_logprobs" in ret_json["meta_info"]:
88
+ ret_json["meta_info"]["input_token_logprobs"] = (
89
+ prefill_json["meta_info"]["input_token_logprobs"]
90
+ + ret_json["meta_info"]["input_token_logprobs"]
91
+ )
92
+ else:
93
+ ret_json = await decode_response.json()
78
94
 
79
95
  return ORJSONResponse(
80
- content=await decode_response.json(),
96
+ content=ret_json,
81
97
  status_code=decode_response.status,
82
98
  )
83
99
 
@@ -92,30 +108,47 @@ class MiniLoadBalancer:
92
108
  total=3600
93
109
  ) # Add timeout for request reliability
94
110
  ) as session:
95
- try:
96
- # Create the tasks for both prefill and decode requests
97
- tasks = [
98
- session.post(
99
- f"{prefill_server}/{endpoint}", json=modified_request
100
- ),
101
- session.post(
102
- f"{decode_server}/{endpoint}", json=modified_request
103
- ),
104
- ]
105
- # Wait for both responses to complete. Since this is streaming, they return immediately.
106
- prefill_response, decode_response = await asyncio.gather(*tasks)
111
+ # Create the tasks for both prefill and decode requests
112
+ tasks = [
113
+ session.post(f"{prefill_server}/generate", json=modified_request),
114
+ session.post(f"{decode_server}/generate", json=modified_request),
115
+ ]
116
+ # Wait for both responses to complete. Since this is streaming, they return immediately.
117
+ prefill_response, decode_response = await asyncio.gather(*tasks)
118
+
119
+ if modified_request.get("return_logprob", False):
120
+ prefill_chunks = []
121
+ async for chunk in prefill_response.content:
122
+ prefill_chunks.append(chunk)
123
+
124
+ first_prefill_chunk = (
125
+ prefill_chunks[0].decode("utf-8")[5:].strip("\n")
126
+ )
127
+ first_prefill_chunk_json = orjson.loads(first_prefill_chunk)
128
+
129
+ async for chunk in decode_response.content:
130
+ # Note: This is inefficient
131
+ # merge prefill input_token_logprobs, output_token_logprobs to decode
132
+ decoded_chunk = chunk.decode("utf-8")
133
+ if (
134
+ decoded_chunk
135
+ and decoded_chunk.startswith("data:")
136
+ and "[DONE]" not in decoded_chunk
137
+ ):
138
+ ret_json = orjson.loads(decoded_chunk[5:].strip("\n"))
139
+ ret_json["meta_info"]["input_token_logprobs"] = (
140
+ first_prefill_chunk_json["meta_info"][
141
+ "input_token_logprobs"
142
+ ]
143
+ + ret_json["meta_info"]["input_token_logprobs"]
144
+ )
145
+
146
+ yield b"data: " + orjson.dumps(ret_json) + b"\n\n"
147
+ else:
148
+ yield chunk
149
+ else:
107
150
  async for chunk in decode_response.content:
108
151
  yield chunk
109
- except Exception as e:
110
- error_msg = {
111
- "error": {"message": f"Stream processing error: {str(e)}"}
112
- }
113
- yield b"data: " + orjson.dumps(
114
- error_msg, option=orjson.OPT_NON_STR_KEYS
115
- ) + b"\n\n"
116
- finally:
117
- if prefill_response is not None:
118
- await prefill_response.release()
119
152
 
120
153
  return StreamingResponse(
121
154
  stream_results(),