sglang 0.5.1.post3__py3-none-any.whl → 0.5.2rc0__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their public registry. It is provided for informational purposes only.
- sglang/bench_one_batch.py +3 -0
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +12 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/launch_lb.py +0 -13
- sglang/srt/disaggregation/mini_lb.py +33 -8
- sglang/srt/disaggregation/prefill.py +1 -1
- sglang/srt/distributed/parallel_state.py +24 -14
- sglang/srt/entrypoints/engine.py +19 -12
- sglang/srt/entrypoints/http_server.py +174 -34
- sglang/srt/entrypoints/openai/protocol.py +60 -0
- sglang/srt/eplb/eplb_manager.py +26 -2
- sglang/srt/eplb/expert_distribution.py +29 -2
- sglang/srt/hf_transformers_utils.py +10 -0
- sglang/srt/layers/activation.py +12 -0
- sglang/srt/layers/attention/ascend_backend.py +240 -109
- sglang/srt/layers/attention/hybrid_attn_backend.py +53 -21
- sglang/srt/layers/attention/trtllm_mla_backend.py +25 -10
- sglang/srt/layers/layernorm.py +28 -3
- sglang/srt/layers/linear.py +3 -2
- sglang/srt/layers/logits_processor.py +1 -1
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +12 -6
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/topk.py +35 -12
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +1 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
- sglang/srt/layers/quantization/modelopt_quant.py +7 -0
- sglang/srt/layers/quantization/mxfp4.py +9 -4
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w8a8_int8.py +7 -3
- sglang/srt/layers/rotary_embedding.py +28 -1
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/managers/cache_controller.py +62 -96
- sglang/srt/managers/detokenizer_manager.py +43 -2
- sglang/srt/managers/io_struct.py +27 -0
- sglang/srt/managers/mm_utils.py +5 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +591 -0
- sglang/srt/managers/scheduler.py +36 -2
- sglang/srt/managers/scheduler_output_processor_mixin.py +20 -18
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/tokenizer_manager.py +86 -39
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +20 -3
- sglang/srt/mem_cache/hiradix_cache.py +75 -68
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +4 -0
- sglang/srt/mem_cache/memory_pool_host.py +2 -4
- sglang/srt/mem_cache/radix_cache.py +5 -4
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +33 -7
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +2 -1
- sglang/srt/mem_cache/swa_radix_cache.py +1 -1
- sglang/srt/model_executor/model_runner.py +5 -4
- sglang/srt/model_loader/loader.py +15 -24
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/models/deepseek_v2.py +26 -10
- sglang/srt/models/gpt_oss.py +0 -14
- sglang/srt/models/llama_eagle3.py +4 -0
- sglang/srt/models/longcat_flash.py +1015 -0
- sglang/srt/models/longcat_flash_nextn.py +691 -0
- sglang/srt/models/qwen2.py +26 -3
- sglang/srt/models/qwen2_5_vl.py +65 -41
- sglang/srt/models/qwen2_moe.py +22 -2
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/server_args.py +112 -55
- sglang/srt/speculative/eagle_worker.py +28 -8
- sglang/srt/utils.py +14 -0
- sglang/test/attention/test_trtllm_mla_backend.py +12 -3
- sglang/version.py +1 -1
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc0.dist-info}/METADATA +5 -5
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc0.dist-info}/RECORD +83 -78
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc0.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc0.dist-info}/top_level.txt +0 -0
sglang/srt/managers/multi_tokenizer_mixin.py
ADDED
@@ -0,0 +1,591 @@
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+"""MultiTokenizerMixin is a class that provides nesscary methods for MultiTokenizerManager and DetokenizerManager."""
+import asyncio
+import dataclasses
+import json
+import logging
+import multiprocessing as multiprocessing
+import os
+import sys
+import threading
+from multiprocessing import shared_memory
+from typing import Dict
+
+import zmq
+import zmq.asyncio
+
+from sglang.srt.disaggregation.utils import DisaggregationMode, TransferBackend
+from sglang.srt.managers.io_struct import (
+    BatchEmbeddingOut,
+    BatchMultimodalOut,
+    BatchStrOut,
+    BatchTokenIDOut,
+    MultiTokenizerRegisterReq,
+    MultiTokenizerWarpper,
+)
+from sglang.srt.managers.tokenizer_manager import TokenizerManager, _Communicator
+from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.utils import (
+    get_worker_ids_from_req_rids,
+    get_zmq_socket,
+    kill_process_tree,
+)
+from sglang.utils import get_exception_traceback
+
+logger = logging.getLogger(__name__)
+
+
+class MultiTokenizerMixin:
+    """Mixin class for MultiTokenizerManager and DetokenizerManager"""
+
+    def create_sockets_mapping(self):
+        if not hasattr(self, "tokenizer_mapping"):
+            self.tokenizer_mapping = {}
+        # Create ZMQ context if needed
+        if not hasattr(self, "_zmq_context"):
+            self._zmq_context = zmq.Context()
+
+    def init_tokenizer_mapping(
+        self, recv_obj: MultiTokenizerRegisterReq, worker_id: str
+    ):
+        """init tokenizer mapping from register request"""
+        ipc_name = recv_obj.ipc_name
+        worker_id_int = int(worker_id)
+
+        if worker_id_int not in self.tokenizer_mapping:
+            socket = get_zmq_socket(self._zmq_context, zmq.PUSH, ipc_name, False)
+            self.tokenizer_mapping[worker_id_int] = socket
+            self.tokenizer_mapping[worker_id_int].send_pyobj(recv_obj)
+            return True
+        else:
+            return False
+
+    def register_tokenizer_ipc(self, recv_obj, worker_id):
+        if worker_id not in self.tokenizer_mapping:
+            # register the worker if not already done
+            if isinstance(recv_obj, MultiTokenizerRegisterReq):
+                return self.init_tokenizer_mapping(recv_obj, worker_id)
+            else:
+                logger.error(
+                    f"Worker {worker_id} not registered and not found in tokenizer mapping . "
+                    "Please ensure the worker is registered correctly."
+                )
+        return False
+
+    def _handle_output_by_index(self, output, i):
+        """NOTE: A maintainable method is better here."""
+        if isinstance(output, BatchTokenIDOut):
+            new_output = BatchTokenIDOut(
+                rids=[output.rids[i]],
+                finished_reasons=(
+                    [output.finished_reasons[i]]
+                    if len(output.finished_reasons) > i
+                    else None
+                ),
+                decoded_texts=(
+                    [output.decoded_texts[i]] if len(output.decoded_texts) > i else None
+                ),
+                decode_ids=(
+                    [output.decode_ids[i]] if len(output.decode_ids) > i else None
+                ),
+                read_offsets=(
+                    [output.read_offsets[i]] if len(output.read_offsets) > i else None
+                ),
+                output_ids=(
+                    [output.output_ids[i]]
+                    if output.output_ids and len(output.output_ids) > i
+                    else None
+                ),
+                skip_special_tokens=(
+                    [output.skip_special_tokens[i]]
+                    if len(output.skip_special_tokens) > i
+                    else None
+                ),
+                spaces_between_special_tokens=(
+                    [output.spaces_between_special_tokens[i]]
+                    if len(output.spaces_between_special_tokens) > i
+                    else None
+                ),
+                no_stop_trim=(
+                    [output.no_stop_trim[i]] if len(output.no_stop_trim) > i else None
+                ),
+                prompt_tokens=(
+                    [output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None
+                ),
+                completion_tokens=(
+                    [output.completion_tokens[i]]
+                    if len(output.completion_tokens) > i
+                    else None
+                ),
+                cached_tokens=(
+                    [output.cached_tokens[i]] if len(output.cached_tokens) > i else None
+                ),
+                spec_verify_ct=(
+                    [output.spec_verify_ct[i]]
+                    if len(output.spec_verify_ct) > i
+                    else None
+                ),
+                input_token_logprobs_val=(
+                    [output.input_token_logprobs_val[i]]
+                    if output.input_token_logprobs_val
+                    else None
+                ),
+                input_token_logprobs_idx=(
+                    [output.input_token_logprobs_idx[i]]
+                    if output.input_token_logprobs_idx
+                    else None
+                ),
+                output_token_logprobs_val=(
+                    [output.output_token_logprobs_val[i]]
+                    if output.output_token_logprobs_val
+                    else None
+                ),
+                output_token_logprobs_idx=(
+                    [output.output_token_logprobs_idx[i]]
+                    if output.output_token_logprobs_idx
+                    else None
+                ),
+                input_top_logprobs_val=(
+                    [output.input_top_logprobs_val[i]]
+                    if output.input_top_logprobs_val
+                    else None
+                ),
+                input_top_logprobs_idx=(
+                    [output.input_top_logprobs_idx[i]]
+                    if output.input_top_logprobs_idx
+                    else None
+                ),
+                output_top_logprobs_val=(
+                    [output.output_top_logprobs_val[i]]
+                    if output.output_top_logprobs_val
+                    else None
+                ),
+                output_top_logprobs_idx=(
+                    [output.output_top_logprobs_idx[i]]
+                    if output.output_top_logprobs_idx
+                    else None
+                ),
+                input_token_ids_logprobs_val=(
+                    [output.input_token_ids_logprobs_val[i]]
+                    if output.input_token_ids_logprobs_val
+                    else None
+                ),
+                input_token_ids_logprobs_idx=(
+                    [output.input_token_ids_logprobs_idx[i]]
+                    if output.input_token_ids_logprobs_idx
+                    else None
+                ),
+                output_token_ids_logprobs_val=(
+                    [output.output_token_ids_logprobs_val[i]]
+                    if output.output_token_ids_logprobs_val
+                    else None
+                ),
+                output_token_ids_logprobs_idx=(
+                    [output.output_token_ids_logprobs_idx[i]]
+                    if output.output_token_ids_logprobs_idx
+                    else None
+                ),
+                output_hidden_states=(
+                    [output.output_hidden_states[i]]
+                    if output.output_hidden_states
+                    else None
+                ),
+            )
+        elif isinstance(output, BatchEmbeddingOut):
+            new_output = BatchEmbeddingOut(
+                rids=[output.rids[i]],
+                finished_reasons=(
+                    [output.finished_reasons[i]]
+                    if len(output.finished_reasons) > i
+                    else None
+                ),
+                embeddings=(
+                    [output.embeddings[i]] if len(output.embeddings) > i else None
+                ),
+                prompt_tokens=(
+                    [output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None
+                ),
+                cached_tokens=(
+                    [output.cached_tokens[i]] if len(output.cached_tokens) > i else None
+                ),
+            )
+        elif isinstance(output, BatchStrOut):
+            new_output = BatchStrOut(
+                rids=[output.rids[i]],
+                finished_reasons=(
+                    [output.finished_reasons[i]]
+                    if len(output.finished_reasons) > i
+                    else None
+                ),
+                output_strs=(
+                    [output.output_strs[i]] if len(output.output_strs) > i else None
+                ),
+                output_ids=(
+                    [output.output_ids[i]]
+                    if output.output_ids and len(output.output_ids) > i
+                    else None
+                ),
+                prompt_tokens=(
+                    [output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None
+                ),
+                completion_tokens=(
+                    [output.completion_tokens[i]]
+                    if len(output.completion_tokens) > i
+                    else None
+                ),
+                cached_tokens=(
+                    [output.cached_tokens[i]] if len(output.cached_tokens) > i else None
+                ),
+                spec_verify_ct=(
+                    [output.spec_verify_ct[i]]
+                    if len(output.spec_verify_ct) > i
+                    else None
+                ),
+                input_token_logprobs_val=(
+                    [output.input_token_logprobs_val[i]]
+                    if output.input_token_logprobs_val
+                    else None
+                ),
+                input_token_logprobs_idx=(
+                    [output.input_token_logprobs_idx[i]]
+                    if output.input_token_logprobs_idx
+                    else None
+                ),
+                output_token_logprobs_val=(
+                    [output.output_token_logprobs_val[i]]
+                    if output.output_token_logprobs_val
+                    else None
+                ),
+                output_token_logprobs_idx=(
+                    [output.output_token_logprobs_idx[i]]
+                    if output.output_token_logprobs_idx
+                    else None
+                ),
+                input_top_logprobs_val=(
+                    [output.input_top_logprobs_val[i]]
+                    if output.input_top_logprobs_val
+                    else None
+                ),
+                input_top_logprobs_idx=(
+                    [output.input_top_logprobs_idx[i]]
+                    if output.input_top_logprobs_idx
+                    else None
+                ),
+                output_top_logprobs_val=(
+                    [output.output_top_logprobs_val[i]]
+                    if output.output_top_logprobs_val
+                    else None
+                ),
+                output_top_logprobs_idx=(
+                    [output.output_top_logprobs_idx[i]]
+                    if output.output_top_logprobs_idx
+                    else None
+                ),
+                input_token_ids_logprobs_val=(
+                    [output.input_token_ids_logprobs_val[i]]
+                    if output.input_token_ids_logprobs_val
+                    else None
+                ),
+                input_token_ids_logprobs_idx=(
+                    [output.input_token_ids_logprobs_idx[i]]
+                    if output.input_token_ids_logprobs_idx
+                    else None
+                ),
+                output_token_ids_logprobs_val=(
+                    [output.output_token_ids_logprobs_val[i]]
+                    if output.output_token_ids_logprobs_val
+                    else None
+                ),
+                output_token_ids_logprobs_idx=(
+                    [output.output_token_ids_logprobs_idx[i]]
+                    if output.output_token_ids_logprobs_idx
+                    else None
+                ),
+                output_hidden_states=(
+                    [output.output_hidden_states[i]]
+                    if output.output_hidden_states
+                    else None
+                ),
+            )
+        elif isinstance(output, BatchMultimodalOut):
+            new_output = BatchMultimodalOut(
+                rids=[output.rids[i]],
+                finished_reasons=(
+                    [output.finished_reasons[i]]
+                    if len(output.finished_reasons) > i
+                    else None
+                ),
+                outputs=([output.outputs[i]] if len(output.outputs) > i else None),
+                prompt_tokens=(
+                    [output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None
+                ),
+                completion_tokens=(
+                    [output.completion_tokens[i]]
+                    if len(output.completion_tokens) > i
+                    else None
+                ),
+                cached_tokens=(
+                    [output.cached_tokens[i]] if len(output.cached_tokens) > i else None
+                ),
+            )
+        else:
+            new_output = output
+        return new_output
+
+    def clear_tokenizer_mapping(self):
+        if hasattr(self, "tokenizer_mapping"):
+            for socket in self.tokenizer_mapping.values():
+                try:
+                    socket.close()
+                except Exception as e:
+                    logger.warning(f"Failed to close socket: {e}")
+            self.tokenizer_mapping.clear()
+
+
+class MultiTokenizerRouter(TokenizerManager, MultiTokenizerMixin):
+    """A router to receive requests from MultiTokenizerManager"""
+
+    def __init__(
+        self,
+        server_args: ServerArgs,
+        port_args: PortArgs,
+    ):
+        self.server_args = server_args
+        context = zmq.asyncio.Context(3)
+        self.recv_from_detokenizer = get_zmq_socket(
+            context, zmq.PULL, port_args.tokenizer_ipc_name, True
+        )
+        self.send_to_scheduler = get_zmq_socket(
+            context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
+        )
+        self.receive_from_worker = get_zmq_socket(
+            context, zmq.PULL, port_args.tokenizer_worker_ipc_name, True
+        )
+        self._loop = asyncio.new_event_loop()
+        self._thread = threading.Thread(target=self._run_loop, daemon=True)
+        self._thread.start()
+        self._task = asyncio.run_coroutine_threadsafe(
+            self.router_worker_obj(), self._loop
+        )
+        # Start handle_loop simultaneously
+        self._handle_task = asyncio.run_coroutine_threadsafe(
+            print_exception_wrapper(self.handle_loop), self._loop
+        )
+        self.init_disaggregation()
+
+    def _run_loop(self):
+        self._loop.run_forever()
+
+    async def router_worker_obj(self):
+        while True:
+            recv_obj = await self.receive_from_worker.recv_pyobj()
+            await self.send_to_scheduler.send_pyobj(recv_obj)
+
+    async def handle_loop(self):
+        # special reqs will recv from scheduler, need to route to right worker
+        self.create_sockets_mapping()
+        while True:
+            recv_obj = await self.recv_from_detokenizer.recv_pyobj()
+            await self._distribute_result_to_workers(recv_obj)
+
+    async def _distribute_result_to_workers(self, recv_obj):
+        """Distribute result to corresponding workers based on rid"""
+        if isinstance(recv_obj, MultiTokenizerWarpper):
+            worker_ids = [recv_obj.worker_id]
+            recv_obj = recv_obj.obj
+        else:
+            worker_ids = get_worker_ids_from_req_rids(recv_obj.rids)
+
+        if len(worker_ids) == 0:
+            logger.error(f"Cannot find worker_id from rids {recv_obj.rids}")
+            return
+
+        # Distribute result to each worker
+        for i, worker_id in enumerate(worker_ids):
+            if isinstance(recv_obj, MultiTokenizerRegisterReq):
+                if self.register_tokenizer_ipc(recv_obj, worker_id):
+                    logger.info(
+                        f"MultiTokenizerRouter Created ZMQ socket for worker {worker_id}"
+                    )
+                continue
+            else:
+                if worker_id not in self.tokenizer_mapping:
+                    logger.error(
+                        f"Tokenizer Worker ID {worker_id} not registered. Check if the server Process {worker_id} is alive"
+                    )
+                    continue
+            new_recv_obj = self._handle_output_by_index(recv_obj, i)
+            self.tokenizer_mapping[worker_id].send_pyobj(new_recv_obj)
+
+
+class MultiTokenizerManager(TokenizerManager, MultiTokenizerMixin):
+    """Multi Process Tokenizer Manager that tokenizes the text."""
+
+    def __init__(
+        self,
+        server_args: ServerArgs,
+        port_args: PortArgs,
+    ):
+        # prevent init prefill bootstrapserver again
+        disaggregation_mode = server_args.disaggregation_mode
+        server_args.disaggregation_mode = "null"
+        super().__init__(server_args, port_args)
+
+        self.worker_id = os.getpid()
+        self.tokenizer_ipc_name = port_args.tokenizer_ipc_name
+
+        # For PD disaggregtion
+        self.server_args.disaggregation_mode = disaggregation_mode
+        self.disaggregation_mode = DisaggregationMode(
+            self.server_args.disaggregation_mode
+        )
+        self.disaggregation_transfer_backend = TransferBackend(
+            self.server_args.disaggregation_transfer_backend
+        )
+        # Communicator
+        self.register_multi_tokenizer_communicator = _Communicator(
+            self.send_to_scheduler, 2
+        )
+        self._result_dispatcher._mapping.append(
+            (
+                MultiTokenizerRegisterReq,
+                self.register_multi_tokenizer_communicator.handle_recv,
+            )
+        )
+
+    async def register_to_main_tokenizer_manager(self):
+        """Register this worker to the main TokenizerManager"""
+        # create a handle loop to receive messages from the main TokenizerManager
+        self.auto_create_handle_loop()
+        req = MultiTokenizerRegisterReq(rids=[f"{self.worker_id}_register"])
+        req.ipc_name = self.tokenizer_ipc_name
+        _Communicator.enable_multi_tokenizer = True
+        await self.register_multi_tokenizer_communicator(req)
+
+
+async def print_exception_wrapper(func):
+    """
+    Sometimes an asyncio function does not print exception.
+    We do another wrapper to handle the exception.
+    """
+    try:
+        await func()
+    except Exception:
+        traceback = get_exception_traceback()
+        logger.error(f"MultiTokenizerRouter hit an exception: {traceback}")
+        if hasattr(func, "__self__") and isinstance(
+            func.__self__, MultiTokenizerRouter
+        ):
+            func.__self__.dump_requests_before_crash()
+        kill_process_tree(os.getpid(), include_parent=True)
+        sys.exit(1)
+
+
+def serialize_port_args(port_args: PortArgs) -> dict:
+    """Serialize PortArgs into a shareable dictionary"""
+    return {
+        "tokenizer_ipc_name": port_args.tokenizer_ipc_name,
+        "scheduler_input_ipc_name": port_args.scheduler_input_ipc_name,
+        "detokenizer_ipc_name": port_args.detokenizer_ipc_name,
+        "nccl_port": port_args.nccl_port,
+        "rpc_ipc_name": port_args.rpc_ipc_name,
+        "metrics_ipc_name": port_args.metrics_ipc_name,
+        "tokenizer_worker_ipc_name": port_args.tokenizer_worker_ipc_name,
+    }
+
+
+def deserialize_data(port_args: dict, server_args: dict):
+    """Deserialize data from shared dictionaries"""
+    return PortArgs(**port_args), ServerArgs(**server_args)
+
+
+def serialize_server_args(server_args: ServerArgs) -> dict:
+    """Serialize ServerArgs into a shareable dictionary"""
+    return dataclasses.asdict(server_args)
+
+
+def serialize_scheduler_info(scheduler_info: Dict) -> dict:
+    """Serialize scheduler_info into a shareable dictionary"""
+    return scheduler_info
+
+
+def deserialize_scheduler_info(data: dict) -> Dict:
+    """Deserialize scheduler_info from a shared dictionary"""
+    return data
+
+
+def write_to_shared_memory(data: dict, name: str) -> shared_memory.SharedMemory:
+    """Write data to shared memory"""
+    serialized = json.dumps(data).encode("utf-8")
+    size = len(serialized)
+    try:
+        # Try to open existing shared memory
+        shm = shared_memory.SharedMemory(name=name)
+        # If size is insufficient, close and recreate
+        if shm.size < size:
+            shm.close()
+            shm.unlink()
+            shm = shared_memory.SharedMemory(create=True, size=size, name=name)
+    except FileNotFoundError:
+        # If not present, create new shared memory
+        shm = shared_memory.SharedMemory(create=True, size=size, name=name)
+
+    shm.buf[:size] = serialized
+    return shm
+
+
+def read_from_shared_memory(name: str) -> dict:
+    """Read data from shared memory"""
+    try:
+        shm = shared_memory.SharedMemory(name=name)
+        data = json.loads(bytes(shm.buf).decode("utf-8"))
+        shm.close()
+        return data
+    except FileNotFoundError:
+        raise FileNotFoundError(f"Shared memory {name} not found")
+
+
+def get_main_process_id() -> int:
+    """Get the main process ID"""
+    return multiprocessing.current_process()._parent_pid
+
+
+def write_data_for_multi_tokenizer(
+    port_args: PortArgs, server_args: ServerArgs, scheduler_info: Dict
+):
+    """Write args information to share memory for multi-tokenizer"""
+    # get main process ID
+    main_pid = get_main_process_id()
+    current_pid = os.getpid()
+    logger.info(f"main process ID: {main_pid}, current process ID: {current_pid}")
+
+    # Write port_args to shared memory
+    port_args_shm = write_to_shared_memory(
+        serialize_port_args(port_args), f"port_args_{current_pid}"
+    )
+    # Write server_args to shared memory
+    server_args_shm = write_to_shared_memory(
+        serialize_server_args(server_args), f"server_args_{current_pid}"
+    )
+    # Write scheduler_info to shared memory
+    scheduler_info_shm = write_to_shared_memory(
+        serialize_scheduler_info(scheduler_info), f"scheduler_info_{current_pid}"
+    )
+
+    port_args_shm.close()
+    server_args_shm.close()
+    scheduler_info_shm.close()
+
+    return port_args_shm, server_args_shm, scheduler_info_shm
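The helper functions at the bottom of this new module hand PortArgs, ServerArgs, and scheduler_info from the launching process to the per-worker tokenizer processes by JSON-encoding them into named shared-memory blocks keyed by PID. Below is a minimal, self-contained sketch of that round-trip using only the standard library; the function names and the sample payload are illustrative stand-ins, not the sglang API.

import json
import os
from multiprocessing import shared_memory


def publish_args(data: dict, name: str) -> shared_memory.SharedMemory:
    """Serialize `data` as JSON and copy it into a named shared-memory block."""
    payload = json.dumps(data).encode("utf-8")
    shm = shared_memory.SharedMemory(create=True, size=len(payload), name=name)
    shm.buf[: len(payload)] = payload
    return shm  # caller keeps the handle so the block stays alive


def load_args(name: str) -> dict:
    """Attach to the named block and decode the JSON payload."""
    shm = shared_memory.SharedMemory(name=name)
    # Strip trailing NULs in case the OS rounded the block up to a page size.
    raw = bytes(shm.buf).decode("utf-8").rstrip("\x00")
    shm.close()
    return json.loads(raw)


if __name__ == "__main__":
    name = f"port_args_{os.getpid()}"  # per-PID name, mirroring the diff above
    handle = publish_args({"tokenizer_ipc_name": "ipc:///tmp/tok"}, name)
    print(load_args(name))  # in sglang this read happens in another process
    handle.close()
    handle.unlink()  # free the segment once all readers are done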
sglang/srt/managers/scheduler.py
CHANGED
@@ -69,6 +69,8 @@ from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchTokenizedEmbeddingReqInput,
     BatchTokenizedGenerateReqInput,
+    ClearHiCacheReqInput,
+    ClearHiCacheReqOutput,
     CloseSessionReqInput,
     ExpertDistributionReq,
     ExpertDistributionReqOutput,
@@ -82,6 +84,8 @@ from sglang.srt.managers.io_struct import (
     InitWeightsUpdateGroupReqInput,
     LoadLoRAAdapterReqInput,
     LoadLoRAAdapterReqOutput,
+    MultiTokenizerRegisterReq,
+    MultiTokenizerWarpper,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
@@ -255,7 +259,6 @@ class Scheduler(
         # Init inter-process communication
         context = zmq.Context(2)
         self.idle_sleeper = None
-
         if self.pp_rank == 0 and self.attn_tp_rank == 0:
             self.recv_from_tokenizer = get_zmq_socket(
                 context, zmq.PULL, port_args.scheduler_input_ipc_name, False
@@ -515,6 +518,7 @@ class Scheduler(
                 (BatchTokenizedGenerateReqInput, self.handle_batch_generate_request),
                 (BatchTokenizedEmbeddingReqInput, self.handle_batch_embedding_request),
                 (FlushCacheReqInput, self.flush_cache_wrapped),
+                (ClearHiCacheReqInput, self.clear_hicache_storage_wrapped),
                 (AbortReq, self.abort_request),
                 (OpenSessionReqInput, self.open_session),
                 (CloseSessionReqInput, self.close_session),
@@ -537,6 +541,7 @@ class Scheduler(
                 (ExpertDistributionReq, self.expert_distribution_handle),
                 (LoadLoRAAdapterReqInput, self.load_lora_adapter),
                 (UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
+                (MultiTokenizerRegisterReq, self.register_multi_tokenizer),
             ]
         )

@@ -1098,6 +1103,17 @@ class Scheduler(
                     )
                     self.send_to_tokenizer.send_pyobj(abort_req)
                     continue
+
+            # If it is a MultiTokenizerWarpper, unwrap it and handle the inner request.
+            if isinstance(recv_req, MultiTokenizerWarpper):
+                worker_id = recv_req.worker_id
+                recv_req = recv_req.obj
+                output = self._request_dispatcher(recv_req)
+                if output is not None:
+                    output = MultiTokenizerWarpper(worker_id, output)
+                    self.send_to_tokenizer.send_pyobj(output)
+                continue
+
             output = self._request_dispatcher(recv_req)
             if output is not None:
                 if isinstance(output, RpcReqOutput):
@@ -1503,7 +1519,7 @@ class Scheduler(
             # Move the chunked request out of the batch so that we can merge
             # only finished requests to running_batch.
             chunked_req_to_exclude.add(self.chunked_req)
-            self.tree_cache.cache_unfinished_req(self.chunked_req)
+            self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True)
             # chunked request keeps its rid but will get a new req_pool_idx
             self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
         if self.last_batch and self.last_batch.forward_mode.is_extend():
@@ -2207,6 +2223,16 @@ class Scheduler(
         success = self.flush_cache()
         return FlushCacheReqOutput(success=success)

+    def clear_hicache_storage_wrapped(self, recv_req: ClearHiCacheReqInput):
+        if self.enable_hierarchical_cache:
+            self.tree_cache.clear_storage_backend()
+            logger.info("Hierarchical cache cleared successfully!")
+            if_success = True
+        else:
+            logging.warning("Hierarchical cache is not enabled.")
+            if_success = False
+        return ClearHiCacheReqOutput(success=if_success)
+
     def flush_cache(self):
         """Flush the memory pool and cache."""
         if (
@@ -2378,6 +2404,10 @@ class Scheduler(
                 # We still need to send something back to TokenizerManager to clean up the state.
                 req = self.waiting_queue.pop(i)
                 self.send_to_tokenizer.send_pyobj(AbortReq(req.rid))
+                # For disaggregation decode mode, the request in the waiting queue has KV cache allocated.
+                if self.disaggregation_mode == DisaggregationMode.DECODE:
+                    self.tree_cache.cache_finished_req(req)
+
                 logger.debug(f"Abort queued request. {req.rid=}")

         # Delete the requests in the grammar queue
@@ -2457,6 +2487,10 @@ class Scheduler(
         result = self.tp_worker.unload_lora_adapter(recv_req)
         return result

+    def register_multi_tokenizer(self, recv_req: MultiTokenizerRegisterReq):
+        self.send_to_detokenizer.send_pyobj(recv_req)
+        return recv_req
+
     def slow_down(self, recv_req: SlowDownReqInput):
         t = recv_req.forward_sleep_time
         if t is not None and t <= 0: