sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_offline_throughput.py +10 -8
- sglang/bench_one_batch.py +7 -6
- sglang/bench_one_batch_server.py +157 -21
- sglang/bench_serving.py +137 -59
- sglang/compile_deep_gemm.py +5 -5
- sglang/eval/loogle_eval.py +157 -0
- sglang/lang/chat_template.py +78 -78
- sglang/lang/tracer.py +1 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +2 -2
- sglang/srt/configs/model_config.py +40 -28
- sglang/srt/constrained/base_grammar_backend.py +55 -72
- sglang/srt/constrained/llguidance_backend.py +25 -21
- sglang/srt/constrained/outlines_backend.py +27 -26
- sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -43
- sglang/srt/conversation.py +49 -44
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +129 -135
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
- sglang/srt/disaggregation/fake/conn.py +3 -13
- sglang/srt/disaggregation/kv_events.py +357 -0
- sglang/srt/disaggregation/mini_lb.py +57 -24
- sglang/srt/disaggregation/mooncake/conn.py +238 -122
- sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
- sglang/srt/disaggregation/nixl/conn.py +10 -19
- sglang/srt/disaggregation/prefill.py +132 -47
- sglang/srt/disaggregation/utils.py +123 -6
- sglang/srt/distributed/utils.py +3 -3
- sglang/srt/entrypoints/EngineBase.py +5 -0
- sglang/srt/entrypoints/engine.py +44 -9
- sglang/srt/entrypoints/http_server.py +23 -6
- sglang/srt/entrypoints/http_server_engine.py +5 -2
- sglang/srt/function_call/base_format_detector.py +250 -0
- sglang/srt/function_call/core_types.py +34 -0
- sglang/srt/function_call/deepseekv3_detector.py +157 -0
- sglang/srt/function_call/ebnf_composer.py +234 -0
- sglang/srt/function_call/function_call_parser.py +175 -0
- sglang/srt/function_call/llama32_detector.py +74 -0
- sglang/srt/function_call/mistral_detector.py +84 -0
- sglang/srt/function_call/pythonic_detector.py +163 -0
- sglang/srt/function_call/qwen25_detector.py +67 -0
- sglang/srt/function_call/utils.py +35 -0
- sglang/srt/hf_transformers_utils.py +46 -7
- sglang/srt/layers/attention/aiter_backend.py +513 -0
- sglang/srt/layers/attention/flashattention_backend.py +64 -18
- sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
- sglang/srt/layers/attention/flashmla_backend.py +340 -78
- sglang/srt/layers/attention/triton_backend.py +3 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
- sglang/srt/layers/attention/utils.py +6 -4
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/communicator.py +451 -0
- sglang/srt/layers/dp_attention.py +61 -21
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/logits_processor.py +46 -11
- sglang/srt/layers/moe/cutlass_moe.py +207 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +34 -12
- sglang/srt/layers/moe/ep_moe/layer.py +105 -51
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
- sglang/srt/layers/moe/topk.py +67 -10
- sglang/srt/layers/multimodal.py +70 -0
- sglang/srt/layers/quantization/__init__.py +8 -3
- sglang/srt/layers/quantization/blockwise_int8.py +2 -2
- sglang/srt/layers/quantization/deep_gemm.py +77 -74
- sglang/srt/layers/quantization/fp8.py +92 -2
- sglang/srt/layers/quantization/fp8_kernel.py +3 -3
- sglang/srt/layers/quantization/fp8_utils.py +6 -0
- sglang/srt/layers/quantization/gptq.py +298 -6
- sglang/srt/layers/quantization/int8_kernel.py +20 -7
- sglang/srt/layers/quantization/qoq.py +244 -0
- sglang/srt/layers/sampler.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +18 -7
- sglang/srt/lora/lora_manager.py +2 -4
- sglang/srt/lora/mem_pool.py +4 -4
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +3 -3
- sglang/srt/managers/deepseek_eplb.py +278 -0
- sglang/srt/managers/detokenizer_manager.py +21 -8
- sglang/srt/managers/eplb_manager.py +55 -0
- sglang/srt/managers/expert_distribution.py +704 -56
- sglang/srt/managers/expert_location.py +394 -0
- sglang/srt/managers/expert_location_dispatch.py +91 -0
- sglang/srt/managers/io_struct.py +19 -4
- sglang/srt/managers/mm_utils.py +294 -140
- sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
- sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
- sglang/srt/managers/multimodal_processors/internvl.py +14 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
- sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
- sglang/srt/managers/multimodal_processors/llava.py +46 -0
- sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
- sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
- sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
- sglang/srt/managers/schedule_batch.py +122 -42
- sglang/srt/managers/schedule_policy.py +1 -5
- sglang/srt/managers/scheduler.py +205 -138
- sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +232 -58
- sglang/srt/managers/tp_worker.py +12 -9
- sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
- sglang/srt/mem_cache/base_prefix_cache.py +3 -0
- sglang/srt/mem_cache/chunk_cache.py +3 -1
- sglang/srt/mem_cache/hiradix_cache.py +4 -4
- sglang/srt/mem_cache/memory_pool.py +76 -52
- sglang/srt/mem_cache/multimodal_cache.py +45 -0
- sglang/srt/mem_cache/radix_cache.py +58 -5
- sglang/srt/metrics/collector.py +314 -39
- sglang/srt/mm_utils.py +10 -0
- sglang/srt/model_executor/cuda_graph_runner.py +29 -19
- sglang/srt/model_executor/expert_location_updater.py +422 -0
- sglang/srt/model_executor/forward_batch_info.py +5 -1
- sglang/srt/model_executor/model_runner.py +163 -68
- sglang/srt/model_loader/loader.py +10 -6
- sglang/srt/models/clip.py +5 -1
- sglang/srt/models/deepseek_janus_pro.py +2 -2
- sglang/srt/models/deepseek_v2.py +308 -351
- sglang/srt/models/exaone.py +8 -3
- sglang/srt/models/gemma3_mm.py +70 -33
- sglang/srt/models/llama.py +2 -0
- sglang/srt/models/llama4.py +15 -8
- sglang/srt/models/llava.py +258 -7
- sglang/srt/models/mimo_mtp.py +220 -0
- sglang/srt/models/minicpmo.py +5 -12
- sglang/srt/models/mistral.py +71 -1
- sglang/srt/models/mixtral.py +98 -34
- sglang/srt/models/mllama.py +3 -3
- sglang/srt/models/pixtral.py +467 -0
- sglang/srt/models/qwen2.py +95 -26
- sglang/srt/models/qwen2_5_vl.py +8 -0
- sglang/srt/models/qwen2_moe.py +330 -60
- sglang/srt/models/qwen2_vl.py +6 -0
- sglang/srt/models/qwen3.py +52 -10
- sglang/srt/models/qwen3_moe.py +411 -48
- sglang/srt/models/roberta.py +1 -1
- sglang/srt/models/siglip.py +294 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/openai_api/adapter.py +58 -20
- sglang/srt/openai_api/protocol.py +6 -8
- sglang/srt/operations.py +154 -0
- sglang/srt/operations_strategy.py +31 -0
- sglang/srt/reasoning_parser.py +3 -3
- sglang/srt/sampling/custom_logit_processor.py +18 -3
- sglang/srt/sampling/sampling_batch_info.py +4 -56
- sglang/srt/sampling/sampling_params.py +2 -2
- sglang/srt/server_args.py +162 -22
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +138 -7
- sglang/srt/speculative/eagle_worker.py +69 -21
- sglang/srt/utils.py +74 -17
- sglang/test/few_shot_gsm8k.py +2 -2
- sglang/test/few_shot_gsm8k_engine.py +2 -2
- sglang/test/run_eval.py +2 -2
- sglang/test/runners.py +8 -1
- sglang/test/send_one.py +13 -3
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +1 -1
- sglang/test/test_cutlass_moe.py +278 -0
- sglang/test/test_programs.py +5 -5
- sglang/test/test_utils.py +55 -14
- sglang/utils.py +3 -3
- sglang/version.py +1 -1
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +23 -13
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +178 -149
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
- sglang/srt/function_call_parser.py +0 -858
- sglang/srt/platforms/interface.py +0 -371
- /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
- /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,357 @@
|
|
1
|
+
"""
|
2
|
+
Copyright 2025 SGLang Team
|
3
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
you may not use this file except in compliance with the License.
|
5
|
+
You may obtain a copy of the License at
|
6
|
+
|
7
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
Unless required by applicable law or agreed to in writing, software
|
10
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
See the License for the specific language governing permissions and
|
13
|
+
limitations under the License.
|
14
|
+
"""
|
15
|
+
|
16
|
+
"""
|
17
|
+
KV caching events
|
18
|
+
"""
|
19
|
+
|
20
|
+
import atexit
|
21
|
+
import logging
|
22
|
+
import queue
|
23
|
+
import threading
|
24
|
+
import time
|
25
|
+
from abc import ABC, abstractmethod
|
26
|
+
from collections import deque
|
27
|
+
from itertools import count
|
28
|
+
from queue import Queue
|
29
|
+
from typing import Any, Callable, Optional, Union
|
30
|
+
|
31
|
+
import msgspec
|
32
|
+
import zmq
|
33
|
+
from pydantic import BaseModel
|
34
|
+
|
35
|
+
# Module-level logger used by the publishers below (e.g. publisher-thread and replay errors).
logger = logging.getLogger(__name__)
|
36
|
+
|
37
|
+
|
38
|
+
class EventBatch(
    msgspec.Struct,
    # Encoded as an array rather than a map: field order below is part of the wire format.
    array_like=True,  # type: ignore[call-arg]
    omit_defaults=True,  # type: ignore[call-arg]
    gc=False,  # type: ignore[call-arg]
):
    """Timestamped batch of events; concrete batch types (e.g. KVEventBatch) narrow ``events``."""

    # Batch timestamp — presumably seconds since the epoch; TODO confirm with the producer.
    ts: float
    # Events carried by this batch; subclasses narrow the element type.
    events: list[Any]
|
46
|
+
|
47
|
+
|
48
|
+
class KVCacheEvent(
    msgspec.Struct,
    # Array encoding: subclass field order is part of the wire format.
    array_like=True,  # type: ignore[call-arg]
    omit_defaults=True,  # type: ignore[call-arg]
    gc=False,  # type: ignore[call-arg]
    # Tagged struct so the Union of subclasses in KVEventBatch can be told apart when decoding.
    tag=True,
):
    """Base class for all KV cache-related events"""
|
56
|
+
|
57
|
+
|
58
|
+
class BlockStored(KVCacheEvent):
    """Event emitted for KV cache blocks that were stored."""

    # Hashes of the stored blocks.
    block_hashes: list[int]
    # Hash of the block preceding these ones, if any (None presumably means no parent — confirm).
    parent_block_hash: Optional[int]
    # Token ids covered by the stored blocks.
    token_ids: list[int]
    # Number of tokens per block.
    block_size: int
    # LoRA adapter id associated with the blocks, if any.
    lora_id: Optional[int]
|
64
|
+
|
65
|
+
|
66
|
+
class BlockRemoved(KVCacheEvent):
    """Event emitted for KV cache blocks that were removed."""

    # Hashes of the removed blocks.
    block_hashes: list[int]
|
68
|
+
|
69
|
+
|
70
|
+
class AllBlocksCleared(KVCacheEvent):
    """Event signalling that all KV cache blocks were cleared; carries no payload."""

    pass
|
72
|
+
|
73
|
+
|
74
|
+
class KVEventBatch(EventBatch):
    """EventBatch whose events are restricted to the KV cache event types."""

    # Tagged union: each element is decoded back to its concrete KVCacheEvent subclass.
    events: list[Union[BlockStored, BlockRemoved, AllBlocksCleared]]
|
76
|
+
|
77
|
+
|
78
|
+
class EventPublisher(ABC):
    """Abstract interface for anything that emits ``EventBatch`` objects."""

    @abstractmethod
    def publish(self, events: EventBatch) -> None:
        """Send one batch of events.

        Concrete publishers are expected to deliver each batch at least once
        and to keep batches in monotonically increasing order, for instance by
        attaching a sequence number to every batch.
        """
        ...

    @abstractmethod
    def shutdown(self) -> None:
        """Stop publishing and release any resources held by the publisher."""
        ...
|
92
|
+
|
93
|
+
|
94
|
+
class NullEventPublisher(EventPublisher):
    """Publisher that discards everything; the default when event publishing is disabled."""

    def publish(self, events) -> None:
        """Ignore ``events``."""
        return None

    def shutdown(self) -> None:
        """Nothing to release."""
        return None
|
102
|
+
|
103
|
+
|
104
|
+
class ZmqEventPublisher(EventPublisher):
    """Reliable PUB/ROUTER publisher with an in-memory replay buffer.

    Spawns a separate thread to handle publishing from a queue.

    Parameters
    ----------
    endpoint:
        PUB address. Use ``tcp://*:5557`` to bind or ``tcp://host:5557`` to
        connect.
    replay_endpoint:
        Optional ROUTER address for replay requests. When given, subscribers can
        request missed batches by sending the starting sequence number as an
        8-byte big-endian integer.
    buffer_steps:
        Number of past batches to keep for replay.
    hwm:
        ZeroMQ high-water-mark for PUB socket.
    max_queue_size:
        Maximum number of events to buffer in memory.
    topic:
        Topic to publish events to.
    """

    SHUTDOWN_TIMEOUT: float = 1.0
    # Sequence number sent to replay clients to mark the end of a replay stream.
    END_SEQ = (-1).to_bytes(8, "big", signed=True)

    def __init__(
        self,
        endpoint: str = "tcp://*:5557",
        replay_endpoint: Optional[str] = None,
        buffer_steps: int = 10_000,
        hwm: int = 100_000,
        max_queue_size: int = 100_000,
        topic: str = "",
    ) -> None:
        # Storage
        self._event_queue = Queue[Optional[EventBatch]](maxsize=max_queue_size)
        # Bounded history of (seq, encoded payload) used to serve replay requests.
        self._buffer = deque[tuple[int, bytes]](maxlen=buffer_steps)

        # ZMQ sockets
        self._ctx = zmq.Context.instance()
        self._pub: Optional[zmq.Socket] = None
        self._replay: Optional[zmq.Socket] = None
        self._endpoint = endpoint
        self._replay_endpoint = replay_endpoint
        self._hwm = hwm
        self._socket_setup()

        # Payload
        self._seq_gen = count()
        self._topic_bytes = topic.encode("utf-8")

        # Thread
        self._running = True
        # Guards against running shutdown twice (explicit call + atexit hook).
        self._closed = False
        logger.info("Starting ZMQ publisher thread")

        self._thread = threading.Thread(
            target=self._publisher_thread, daemon=True, name="zmq-publisher"
        )
        self._thread.start()

        atexit.register(self.shutdown)

    def publish(self, events: EventBatch) -> None:
        """Enqueue ``events`` for the publisher thread.

        Blocks while the queue is full.

        Raises:
            RuntimeError: if the publisher has been shut down.
        """
        if not self._running:
            raise RuntimeError("Publisher is closed")
        self._event_queue.put(events)

    def shutdown(self) -> None:
        """Stop the publisher thread and clean up resources.

        Idempotent: it is registered with ``atexit`` and may also be called
        explicitly.
        """
        if self._closed:
            return
        self._closed = True
        self._running = False

        try:
            # Sentinel wakes the thread so it exits after the items queued before it.
            self._event_queue.put_nowait(None)
        except queue.Full:
            # With _running cleared the thread still exits once the queue drains,
            # so a missing sentinel must not abort socket cleanup.
            logger.warning("Event queue is full; shutdown sentinel not enqueued")

        start = time.monotonic()
        pending_items = True
        while (
            pending_items
            and self._thread.is_alive()
            and (time.monotonic() - start < self.SHUTDOWN_TIMEOUT)
        ):
            pending_items = not self._event_queue.empty()
            if pending_items:
                time.sleep(0.1)

        if pending_items and not self._event_queue.empty():
            logger.warning(
                "Warning: Queue still has %s items after %s seconds timeout",
                self._event_queue.qsize(),
                self.SHUTDOWN_TIMEOUT,
            )

        if self._thread.is_alive():
            self._thread.join(timeout=self.SHUTDOWN_TIMEOUT)

        # Clean up ZMQ resources. The shared context is deliberately not
        # terminated: other sockets may still be using it.
        if self._pub is not None:
            self._pub.close(linger=0)
        if self._replay is not None:
            self._replay.close(linger=0)

    def _socket_setup(self) -> None:
        """Initialize sockets
        https://pyzmq.readthedocs.io/en/v19.0.0/morethanbindings.html#thread-safety
        """
        if self._pub is None:
            self._pub = self._ctx.socket(zmq.PUB)
            self._pub.set_hwm(self._hwm)
            # Heuristic: bind if wildcard / * present, else connect.
            # bind stable, connect volatile convention
            if (
                "*" in self._endpoint
                or "::" in self._endpoint
                or self._endpoint.startswith("ipc://")
                or self._endpoint.startswith("inproc://")
            ):
                self._pub.bind(self._endpoint)
            else:
                self._pub.connect(self._endpoint)

        # Set up replay socket: use ROUTER
        # 1) handles multiple REQ clients (identities)
        # 2) lets us send back one request → many replies (streamed events)
        # 3) works in our non-blocking poll loop alongside PUB
        if self._replay_endpoint is not None:
            self._replay = self._ctx.socket(zmq.ROUTER)
            self._replay.bind(self._replay_endpoint)

    def _publisher_thread(self) -> None:
        """Background thread that processes the event queue."""
        self._pack = msgspec.msgpack.Encoder()

        assert self._pub is not None  # narrows type for mypy

        while self._running or self._event_queue.qsize() > 0:
            # --- replay (non-critical) ---------------------------------
            if self._replay is not None and self._replay.poll(0):
                try:
                    self._service_replay()
                except Exception as e:
                    logger.exception("Error in replay: %s", e)

            # --- main queue (critical) ---------------------------------
            try:
                event = self._event_queue.get(timeout=0.1)
            except queue.Empty:
                continue

            try:
                if event is None:
                    break  # Sentinel received, exit thread

                seq = next(self._seq_gen)

                payload = self._pack.encode(event)
                seq_bytes = seq.to_bytes(8, "big")
                self._pub.send_multipart((self._topic_bytes, seq_bytes, payload))

                self._buffer.append((seq, payload))
            except Exception as e:
                # Publishing failed; back-off a bit to avoid a tight error loop
                logger.exception("Error in publisher thread: %s", e)
                time.sleep(0.1)
            finally:
                # Account for every dequeued item (including failures and the sentinel).
                self._event_queue.task_done()

    def _service_replay(self) -> None:
        """Receive one pending replay request and stream matching buffered batches."""
        assert self._replay is not None  # narrows type for mypy

        frame = self._replay.recv_multipart()
        if len(frame) != 3:
            logger.warning("Invalid replay request: %s", frame)
            return
        client_id, _, start_seq_bytes = frame
        start_seq = int.from_bytes(start_seq_bytes, "big")

        for seq, buf in self._buffer:
            if seq >= start_seq:
                # [identity, empty_delim, seq_bytes, payload]
                # (identity, empty_delim) are stripped off by the router
                # receiving payload is (seq_bytes, payload)
                self._replay.send_multipart(
                    (client_id, b"", seq.to_bytes(8, "big"), buf)
                )
        # Send end of sequence marker
        # receiving payload is (-1, b"")
        self._replay.send_multipart((client_id, b"", self.END_SEQ, b""))
|
290
|
+
|
291
|
+
|
292
|
+
class KVEventsConfig(BaseModel):
    """Configuration for KV event publishing."""

    publisher: str = "null"
    """The publisher to use for publishing kv events. Can be "null", "zmq".
    """

    endpoint: str = "tcp://*:5557"
    """The zmq endpoint to use for publishing kv events.
    """

    replay_endpoint: Optional[str] = None
    """The zmq endpoint to use for replaying kv events.
    """

    buffer_steps: int = 10_000
    """The number of steps to cache for replay endpoint. Will only save
    events from the last N steps for the replay endpoint.
    """

    hwm: int = 100_000
    """The zmq high water mark for the event publisher. After queueing N events,
    events will start dropping if the consumer is not keeping up.
    """

    max_queue_size: int = 100_000
    """The maximum number of events to queue while waiting for publishing.
    """

    topic: str = ""
    """The topic to use for the event publisher. Consumers can subscribe to
    this topic to receive events.
    """

    @classmethod
    def from_cli(cls, cli_value: str) -> "KVEventsConfig":
        """Parse the CLI value (a JSON object string) into a config instance.

        Uses ``cls`` so that subclasses parse into their own type.

        Raises:
            pydantic.ValidationError: if ``cli_value`` is not valid JSON for this model.
        """
        return cls.model_validate_json(cli_value)
|
330
|
+
|
331
|
+
|
332
|
+
class EventPublisherFactory:
    """Registry-backed factory that builds EventPublisher instances by name."""

    _registry: dict[str, Callable[..., EventPublisher]] = {
        "null": NullEventPublisher,
        "zmq": ZmqEventPublisher,
    }

    @classmethod
    def register_publisher(cls, name: str, ctor: Callable[..., EventPublisher]) -> None:
        """Register ``ctor`` under ``name``.

        Raises:
            KeyError: if ``name`` is already registered.
        """
        if name in cls._registry:
            raise KeyError(f"publisher '{name}' already registered")
        cls._registry[name] = ctor

    @classmethod
    def create(cls, config: Optional[str]) -> EventPublisher:
        """Create a publisher from a JSON-encoded KVEventsConfig.

        Args:
            config: JSON string parsed by ``KVEventsConfig.from_cli``. An empty
                string or None yields a NullEventPublisher.

        Returns:
            The publisher selected by the config's ``publisher`` field,
            constructed with the remaining config fields as keyword arguments.

        Raises:
            ValueError: if the configured publisher name is not registered.
        """
        if not config:
            return NullEventPublisher()
        parsed = KVEventsConfig.from_cli(config)
        config_dict = parsed.model_dump()

        kind = config_dict.pop("publisher", "null")
        if kind == "null":
            # NullEventPublisher accepts no constructor arguments; forwarding the
            # remaining config fields (endpoint, hwm, ...) would raise TypeError.
            return NullEventPublisher()
        try:
            constructor = cls._registry[kind]
        except KeyError as exc:
            raise ValueError(f"Unknown event publisher '{kind}'") from exc
        return constructor(**config_dict)
|
@@ -73,11 +73,27 @@ class MiniLoadBalancer:
|
|
73
73
|
session.post(f"{prefill_server}/{endpoint}", json=modified_request),
|
74
74
|
session.post(f"{decode_server}/{endpoint}", json=modified_request),
|
75
75
|
]
|
76
|
+
|
76
77
|
# Wait for both responses to complete. Prefill should end first.
|
77
|
-
|
78
|
+
prefill_response, decode_response = await asyncio.gather(*tasks)
|
79
|
+
|
80
|
+
if "return_logprob" in modified_request:
|
81
|
+
|
82
|
+
prefill_json = await prefill_response.json()
|
83
|
+
ret_json = await decode_response.json()
|
84
|
+
|
85
|
+
# merge `meta_info.input_token_logprobs` from prefill to decode
|
86
|
+
if "meta_info" in ret_json:
|
87
|
+
if "input_token_logprobs" in ret_json["meta_info"]:
|
88
|
+
ret_json["meta_info"]["input_token_logprobs"] = (
|
89
|
+
prefill_json["meta_info"]["input_token_logprobs"]
|
90
|
+
+ ret_json["meta_info"]["input_token_logprobs"]
|
91
|
+
)
|
92
|
+
else:
|
93
|
+
ret_json = await decode_response.json()
|
78
94
|
|
79
95
|
return ORJSONResponse(
|
80
|
-
content=
|
96
|
+
content=ret_json,
|
81
97
|
status_code=decode_response.status,
|
82
98
|
)
|
83
99
|
|
@@ -92,30 +108,47 @@ class MiniLoadBalancer:
|
|
92
108
|
total=3600
|
93
109
|
) # Add timeout for request reliability
|
94
110
|
) as session:
|
95
|
-
|
96
|
-
|
97
|
-
|
98
|
-
|
99
|
-
|
100
|
-
|
101
|
-
|
102
|
-
|
103
|
-
|
104
|
-
]
|
105
|
-
|
106
|
-
|
111
|
+
# Create the tasks for both prefill and decode requests
|
112
|
+
tasks = [
|
113
|
+
session.post(f"{prefill_server}/generate", json=modified_request),
|
114
|
+
session.post(f"{decode_server}/generate", json=modified_request),
|
115
|
+
]
|
116
|
+
# Wait for both responses to complete. Since this is streaming, they return immediately.
|
117
|
+
prefill_response, decode_response = await asyncio.gather(*tasks)
|
118
|
+
|
119
|
+
if modified_request.get("return_logprob", False):
|
120
|
+
prefill_chunks = []
|
121
|
+
async for chunk in prefill_response.content:
|
122
|
+
prefill_chunks.append(chunk)
|
123
|
+
|
124
|
+
first_prefill_chunk = (
|
125
|
+
prefill_chunks[0].decode("utf-8")[5:].strip("\n")
|
126
|
+
)
|
127
|
+
first_prefill_chunk_json = orjson.loads(first_prefill_chunk)
|
128
|
+
|
129
|
+
async for chunk in decode_response.content:
|
130
|
+
# Note: This is inefficient
|
131
|
+
# merge prefill input_token_logprobs, output_token_logprobs to decode
|
132
|
+
decoded_chunk = chunk.decode("utf-8")
|
133
|
+
if (
|
134
|
+
decoded_chunk
|
135
|
+
and decoded_chunk.startswith("data:")
|
136
|
+
and "[DONE]" not in decoded_chunk
|
137
|
+
):
|
138
|
+
ret_json = orjson.loads(decoded_chunk[5:].strip("\n"))
|
139
|
+
ret_json["meta_info"]["input_token_logprobs"] = (
|
140
|
+
first_prefill_chunk_json["meta_info"][
|
141
|
+
"input_token_logprobs"
|
142
|
+
]
|
143
|
+
+ ret_json["meta_info"]["input_token_logprobs"]
|
144
|
+
)
|
145
|
+
|
146
|
+
yield b"data: " + orjson.dumps(ret_json) + b"\n\n"
|
147
|
+
else:
|
148
|
+
yield chunk
|
149
|
+
else:
|
107
150
|
async for chunk in decode_response.content:
|
108
151
|
yield chunk
|
109
|
-
except Exception as e:
|
110
|
-
error_msg = {
|
111
|
-
"error": {"message": f"Stream processing error: {str(e)}"}
|
112
|
-
}
|
113
|
-
yield b"data: " + orjson.dumps(
|
114
|
-
error_msg, option=orjson.OPT_NON_STR_KEYS
|
115
|
-
) + b"\n\n"
|
116
|
-
finally:
|
117
|
-
if prefill_response is not None:
|
118
|
-
await prefill_response.release()
|
119
152
|
|
120
153
|
return StreamingResponse(
|
121
154
|
stream_results(),
|