sglang 0.5.1.post2__py3-none-any.whl → 0.5.2rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +79 -53
- sglang/bench_serving.py +186 -14
- sglang/profiler.py +0 -1
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +12 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/conversation.py +38 -5
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/launch_lb.py +0 -13
- sglang/srt/disaggregation/mini_lb.py +33 -8
- sglang/srt/disaggregation/prefill.py +1 -1
- sglang/srt/distributed/parallel_state.py +24 -14
- sglang/srt/entrypoints/engine.py +19 -12
- sglang/srt/entrypoints/http_server.py +174 -34
- sglang/srt/entrypoints/openai/protocol.py +87 -24
- sglang/srt/entrypoints/openai/serving_chat.py +50 -9
- sglang/srt/entrypoints/openai/serving_completions.py +15 -0
- sglang/srt/eplb/eplb_manager.py +26 -2
- sglang/srt/eplb/expert_distribution.py +29 -2
- sglang/srt/function_call/deepseekv31_detector.py +222 -0
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/gpt_oss_detector.py +144 -256
- sglang/srt/harmony_parser.py +588 -0
- sglang/srt/hf_transformers_utils.py +26 -7
- sglang/srt/layers/activation.py +12 -0
- sglang/srt/layers/attention/ascend_backend.py +374 -136
- sglang/srt/layers/attention/flashattention_backend.py +241 -7
- sglang/srt/layers/attention/flashinfer_backend.py +5 -2
- sglang/srt/layers/attention/flashinfer_mla_backend.py +5 -2
- sglang/srt/layers/attention/hybrid_attn_backend.py +53 -21
- sglang/srt/layers/attention/trtllm_mla_backend.py +25 -10
- sglang/srt/layers/communicator.py +1 -2
- sglang/srt/layers/layernorm.py +28 -3
- sglang/srt/layers/linear.py +3 -2
- sglang/srt/layers/logits_processor.py +1 -1
- sglang/srt/layers/moe/cutlass_moe.py +0 -8
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +13 -13
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/topk.py +35 -12
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +133 -235
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +5 -23
- sglang/srt/layers/quantization/fp8.py +2 -1
- sglang/srt/layers/quantization/fp8_kernel.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/modelopt_quant.py +7 -0
- sglang/srt/layers/quantization/mxfp4.py +25 -27
- sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w8a8_int8.py +7 -3
- sglang/srt/layers/rotary_embedding.py +28 -1
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/layers/utils.py +0 -14
- sglang/srt/managers/cache_controller.py +237 -204
- sglang/srt/managers/detokenizer_manager.py +48 -2
- sglang/srt/managers/io_struct.py +57 -0
- sglang/srt/managers/mm_utils.py +5 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +591 -0
- sglang/srt/managers/scheduler.py +94 -9
- sglang/srt/managers/scheduler_output_processor_mixin.py +20 -18
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/tokenizer_manager.py +122 -42
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +51 -23
- sglang/srt/mem_cache/hiradix_cache.py +87 -71
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +77 -14
- sglang/srt/mem_cache/memory_pool_host.py +4 -5
- sglang/srt/mem_cache/radix_cache.py +6 -4
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +38 -20
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +87 -82
- sglang/srt/mem_cache/swa_radix_cache.py +1 -1
- sglang/srt/model_executor/model_runner.py +6 -5
- sglang/srt/model_loader/loader.py +15 -24
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/models/deepseek_v2.py +38 -13
- sglang/srt/models/gpt_oss.py +2 -15
- sglang/srt/models/llama_eagle3.py +4 -0
- sglang/srt/models/longcat_flash.py +1015 -0
- sglang/srt/models/longcat_flash_nextn.py +691 -0
- sglang/srt/models/qwen2.py +26 -3
- sglang/srt/models/qwen2_5_vl.py +66 -41
- sglang/srt/models/qwen2_moe.py +22 -2
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/reasoning_parser.py +56 -300
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/server_args.py +122 -56
- sglang/srt/speculative/eagle_worker.py +28 -8
- sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
- sglang/srt/utils.py +73 -5
- sglang/test/attention/test_trtllm_mla_backend.py +12 -3
- sglang/version.py +1 -1
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/METADATA +7 -6
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/RECORD +107 -99
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,591 @@
|
|
1
|
+
# Copyright 2023-2024 SGLang Team
|
2
|
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
3
|
+
# you may not use this file except in compliance with the License.
|
4
|
+
# You may obtain a copy of the License at
|
5
|
+
#
|
6
|
+
# http://www.apache.org/licenses/LICENSE-2.0
|
7
|
+
#
|
8
|
+
# Unless required by applicable law or agreed to in writing, software
|
9
|
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
10
|
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
11
|
+
# See the License for the specific language governing permissions and
|
12
|
+
# limitations under the License.
|
13
|
+
# ==============================================================================
|
14
|
+
"""MultiTokenizerMixin is a class that provides necessary methods for MultiTokenizerManager and DetokenizerManager."""
|
15
|
+
import asyncio
|
16
|
+
import dataclasses
|
17
|
+
import json
|
18
|
+
import logging
|
19
|
+
import multiprocessing as multiprocessing
|
20
|
+
import os
|
21
|
+
import sys
|
22
|
+
import threading
|
23
|
+
from multiprocessing import shared_memory
|
24
|
+
from typing import Dict
|
25
|
+
|
26
|
+
import zmq
|
27
|
+
import zmq.asyncio
|
28
|
+
|
29
|
+
from sglang.srt.disaggregation.utils import DisaggregationMode, TransferBackend
|
30
|
+
from sglang.srt.managers.io_struct import (
|
31
|
+
BatchEmbeddingOut,
|
32
|
+
BatchMultimodalOut,
|
33
|
+
BatchStrOut,
|
34
|
+
BatchTokenIDOut,
|
35
|
+
MultiTokenizerRegisterReq,
|
36
|
+
MultiTokenizerWarpper,
|
37
|
+
)
|
38
|
+
from sglang.srt.managers.tokenizer_manager import TokenizerManager, _Communicator
|
39
|
+
from sglang.srt.server_args import PortArgs, ServerArgs
|
40
|
+
from sglang.srt.utils import (
|
41
|
+
get_worker_ids_from_req_rids,
|
42
|
+
get_zmq_socket,
|
43
|
+
kill_process_tree,
|
44
|
+
)
|
45
|
+
from sglang.utils import get_exception_traceback
|
46
|
+
|
47
|
+
logger = logging.getLogger(__name__)
|
48
|
+
|
49
|
+
|
50
|
+
class MultiTokenizerMixin:
    """Mixin class for MultiTokenizerManager and DetokenizerManager.

    It keeps a ``worker_id -> ZMQ PUSH socket`` mapping (``tokenizer_mapping``)
    and knows how to cut a single request out of a batched output object so
    that it can be forwarded to the worker that owns the request.
    """

    # Names of the per-request list fields (everything except ``rids``) that
    # are copied when one request is extracted from a batch output.  The
    # logprob / hidden-state group is shared by the token-id and string outputs.
    _LOGPROB_AND_HIDDEN_STATE_FIELDS = (
        "input_token_logprobs_val",
        "input_token_logprobs_idx",
        "output_token_logprobs_val",
        "output_token_logprobs_idx",
        "input_top_logprobs_val",
        "input_top_logprobs_idx",
        "output_top_logprobs_val",
        "output_top_logprobs_idx",
        "input_token_ids_logprobs_val",
        "input_token_ids_logprobs_idx",
        "output_token_ids_logprobs_val",
        "output_token_ids_logprobs_idx",
        "output_hidden_states",
    )
    _TOKEN_ID_OUT_FIELDS = (
        "finished_reasons",
        "decoded_texts",
        "decode_ids",
        "read_offsets",
        "output_ids",
        "skip_special_tokens",
        "spaces_between_special_tokens",
        "no_stop_trim",
        "prompt_tokens",
        "completion_tokens",
        "cached_tokens",
        "spec_verify_ct",
    ) + _LOGPROB_AND_HIDDEN_STATE_FIELDS
    _EMBEDDING_OUT_FIELDS = (
        "finished_reasons",
        "embeddings",
        "prompt_tokens",
        "cached_tokens",
    )
    _STR_OUT_FIELDS = (
        "finished_reasons",
        "output_strs",
        "output_ids",
        "prompt_tokens",
        "completion_tokens",
        "cached_tokens",
        "spec_verify_ct",
    ) + _LOGPROB_AND_HIDDEN_STATE_FIELDS
    _MULTIMODAL_OUT_FIELDS = (
        "finished_reasons",
        "outputs",
        "prompt_tokens",
        "completion_tokens",
        "cached_tokens",
    )

    def create_sockets_mapping(self):
        """Lazily create the worker-socket mapping and the ZMQ context."""
        if not hasattr(self, "tokenizer_mapping"):
            self.tokenizer_mapping = {}
        # Create ZMQ context if needed
        if not hasattr(self, "_zmq_context"):
            self._zmq_context = zmq.Context()

    def init_tokenizer_mapping(
        self, recv_obj: MultiTokenizerRegisterReq, worker_id: str
    ):
        """Init tokenizer mapping from a register request.

        Opens a PUSH socket to ``recv_obj.ipc_name``, stores it under
        ``int(worker_id)`` and echoes the register request back through it.

        Returns:
            True if a new socket was created, False if the worker id was
            already present in the mapping.
        """
        ipc_name = recv_obj.ipc_name
        worker_id_int = int(worker_id)

        if worker_id_int in self.tokenizer_mapping:
            return False

        socket = get_zmq_socket(self._zmq_context, zmq.PUSH, ipc_name, False)
        self.tokenizer_mapping[worker_id_int] = socket
        socket.send_pyobj(recv_obj)
        return True

    def register_tokenizer_ipc(self, recv_obj, worker_id):
        """Register ``worker_id`` if it is unknown and ``recv_obj`` is a register request.

        Returns:
            True only when a new socket was created for the worker; False when
            the worker is already registered or ``recv_obj`` is not a
            ``MultiTokenizerRegisterReq`` (the latter is also logged as an error).
        """
        if worker_id in self.tokenizer_mapping:
            # Already registered: nothing to create.  Return an explicit bool
            # instead of falling off the end of the function.
            return False

        # register the worker if not already done
        if isinstance(recv_obj, MultiTokenizerRegisterReq):
            return self.init_tokenizer_mapping(recv_obj, worker_id)

        logger.error(
            f"Worker {worker_id} not registered and not found in tokenizer mapping . "
            "Please ensure the worker is registered correctly."
        )
        return False

    @staticmethod
    def _select_by_index(values, i):
        """Return ``[values[i]]``, or None when ``values`` is None or has no index ``i``."""
        if values is None or len(values) <= i:
            return None
        return [values[i]]

    def _handle_output_by_index(self, output, i):
        """Build a single-request copy of a batched output.

        For ``BatchTokenIDOut``, ``BatchEmbeddingOut``, ``BatchStrOut`` and
        ``BatchMultimodalOut`` a new object of the same class is returned whose
        list fields contain only element ``i``.  A field that is None or too
        short to have an element ``i`` becomes None in the result.  Any other
        object is returned unchanged.

        Raises:
            IndexError: if ``output.rids`` has no element ``i``.
        """
        if isinstance(output, BatchTokenIDOut):
            out_cls = BatchTokenIDOut
            field_names = MultiTokenizerMixin._TOKEN_ID_OUT_FIELDS
        elif isinstance(output, BatchEmbeddingOut):
            out_cls = BatchEmbeddingOut
            field_names = MultiTokenizerMixin._EMBEDDING_OUT_FIELDS
        elif isinstance(output, BatchStrOut):
            out_cls = BatchStrOut
            field_names = MultiTokenizerMixin._STR_OUT_FIELDS
        elif isinstance(output, BatchMultimodalOut):
            out_cls = BatchMultimodalOut
            field_names = MultiTokenizerMixin._MULTIMODAL_OUT_FIELDS
        else:
            return output

        rids = [output.rids[i]]
        pick = MultiTokenizerMixin._select_by_index
        sliced_fields = {
            field_name: pick(getattr(output, field_name), i)
            for field_name in field_names
        }
        return out_cls(rids=rids, **sliced_fields)

    def clear_tokenizer_mapping(self):
        """Close every registered worker socket and empty the mapping."""
        if hasattr(self, "tokenizer_mapping"):
            for socket in self.tokenizer_mapping.values():
                try:
                    socket.close()
                except Exception as e:
                    # Best effort: keep closing the remaining sockets.
                    logger.warning(f"Failed to close socket: {e}")
            self.tokenizer_mapping.clear()
|
355
|
+
|
356
|
+
|
357
|
+
class MultiTokenizerRouter(TokenizerManager, MultiTokenizerMixin):
    """A router to receive requests from MultiTokenizerManager.

    Worker requests are forwarded to the scheduler, and objects received from
    the detokenizer socket are routed back to the owning worker.
    """

    def __init__(self, server_args: ServerArgs, port_args: PortArgs):
        self.server_args = server_args

        zmq_ctx = zmq.asyncio.Context(3)
        self.recv_from_detokenizer = get_zmq_socket(
            zmq_ctx, zmq.PULL, port_args.tokenizer_ipc_name, True
        )
        self.send_to_scheduler = get_zmq_socket(
            zmq_ctx, zmq.PUSH, port_args.scheduler_input_ipc_name, True
        )
        self.receive_from_worker = get_zmq_socket(
            zmq_ctx, zmq.PULL, port_args.tokenizer_worker_ipc_name, True
        )

        # A private event loop running on a daemon thread hosts both
        # forwarding coroutines.
        self._loop = asyncio.new_event_loop()
        self._thread = threading.Thread(target=self._run_loop, daemon=True)
        self._thread.start()

        # worker -> scheduler forwarding
        worker_forwarder = self.router_worker_obj()
        self._task = asyncio.run_coroutine_threadsafe(worker_forwarder, self._loop)

        # Start handle_loop simultaneously (detokenizer -> worker routing)
        guarded_handle_loop = print_exception_wrapper(self.handle_loop)
        self._handle_task = asyncio.run_coroutine_threadsafe(
            guarded_handle_loop, self._loop
        )

        self.init_disaggregation()

    def _run_loop(self):
        """Thread target: run the router's event loop until it is stopped."""
        self._loop.run_forever()

    async def router_worker_obj(self):
        """Forward every object received from a worker to the scheduler."""
        while True:
            incoming = await self.receive_from_worker.recv_pyobj()
            await self.send_to_scheduler.send_pyobj(incoming)

    async def handle_loop(self):
        """Receive objects from the detokenizer socket and route them to workers."""
        # special reqs will recv from scheduler, need to route to right worker
        self.create_sockets_mapping()
        while True:
            incoming = await self.recv_from_detokenizer.recv_pyobj()
            await self._distribute_result_to_workers(incoming)

    async def _distribute_result_to_workers(self, recv_obj):
        """Distribute result to corresponding workers based on rid"""
        if isinstance(recv_obj, MultiTokenizerWarpper):
            # A wrapper names its single target worker explicitly.
            worker_ids, recv_obj = [recv_obj.worker_id], recv_obj.obj
        else:
            worker_ids = get_worker_ids_from_req_rids(recv_obj.rids)

        if len(worker_ids) == 0:
            logger.error(f"Cannot find worker_id from rids {recv_obj.rids}")
            return

        is_register_req = isinstance(recv_obj, MultiTokenizerRegisterReq)

        # Distribute result to each worker
        for index, target_id in enumerate(worker_ids):
            if is_register_req:
                if self.register_tokenizer_ipc(recv_obj, target_id):
                    logger.info(
                        f"MultiTokenizerRouter Created ZMQ socket for worker {target_id}"
                    )
                continue

            if target_id not in self.tokenizer_mapping:
                logger.error(
                    f"Tokenizer Worker ID {target_id} not registered. Check if the server Process {target_id} is alive"
                )
                continue

            single_output = self._handle_output_by_index(recv_obj, index)
            self.tokenizer_mapping[target_id].send_pyobj(single_output)
|
431
|
+
|
432
|
+
|
433
|
+
class MultiTokenizerManager(TokenizerManager, MultiTokenizerMixin):
    """Multi Process Tokenizer Manager that tokenizes the text."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
    ):
        # prevent init prefill bootstrapserver again
        disaggregation_mode = server_args.disaggregation_mode
        server_args.disaggregation_mode = "null"
        try:
            super().__init__(server_args, port_args)
        finally:
            # Always undo the temporary override so the caller's ServerArgs is
            # not left in "null" mode if the parent constructor raises.
            server_args.disaggregation_mode = disaggregation_mode

        # The OS process id doubles as this worker's id.
        self.worker_id = os.getpid()
        self.tokenizer_ipc_name = port_args.tokenizer_ipc_name

        # For PD disaggregation
        self.server_args.disaggregation_mode = disaggregation_mode
        self.disaggregation_mode = DisaggregationMode(
            self.server_args.disaggregation_mode
        )
        self.disaggregation_transfer_backend = TransferBackend(
            self.server_args.disaggregation_transfer_backend
        )
        # Communicator
        self.register_multi_tokenizer_communicator = _Communicator(
            self.send_to_scheduler, 2
        )
        # Route MultiTokenizerRegisterReq replies to the communicator above.
        self._result_dispatcher._mapping.append(
            (
                MultiTokenizerRegisterReq,
                self.register_multi_tokenizer_communicator.handle_recv,
            )
        )

    async def register_to_main_tokenizer_manager(self):
        """Register this worker to the main TokenizerManager"""
        # create a handle loop to receive messages from the main TokenizerManager
        self.auto_create_handle_loop()
        req = MultiTokenizerRegisterReq(rids=[f"{self.worker_id}_register"])
        req.ipc_name = self.tokenizer_ipc_name
        # NOTE: this sets a class-level flag on _Communicator, not an
        # instance attribute.
        _Communicator.enable_multi_tokenizer = True
        await self.register_multi_tokenizer_communicator(req)
|
476
|
+
|
477
|
+
|
478
|
+
async def print_exception_wrapper(func):
    """
    Await ``func()`` and make sure a failure is not lost silently.

    On any ``Exception`` the traceback is logged; if ``func`` is a bound
    method of a MultiTokenizerRouter, its ``dump_requests_before_crash()`` is
    called; then the process tree is killed and the process exits with code 1.
    """
    try:
        await func()
    except Exception:
        tb_text = get_exception_traceback()
        logger.error(f"MultiTokenizerRouter hit an exception: {tb_text}")

        # Bound methods expose their instance as ``__self__``.
        owner = getattr(func, "__self__", None)
        if isinstance(owner, MultiTokenizerRouter):
            owner.dump_requests_before_crash()

        kill_process_tree(os.getpid(), include_parent=True)
        sys.exit(1)
|
494
|
+
|
495
|
+
|
496
|
+
def serialize_port_args(port_args: PortArgs) -> dict:
    """Copy the shareable PortArgs attributes into a plain dictionary."""
    exported_fields = (
        "tokenizer_ipc_name",
        "scheduler_input_ipc_name",
        "detokenizer_ipc_name",
        "nccl_port",
        "rpc_ipc_name",
        "metrics_ipc_name",
        "tokenizer_worker_ipc_name",
    )
    return {field: getattr(port_args, field) for field in exported_fields}
|
507
|
+
|
508
|
+
|
509
|
+
def deserialize_data(port_args: dict, server_args: dict):
    """Rebuild ``(PortArgs, ServerArgs)`` from their dictionary forms.

    Each dictionary is expanded as keyword arguments to the matching
    constructor.
    """
    restored_port_args = PortArgs(**port_args)
    restored_server_args = ServerArgs(**server_args)
    return restored_port_args, restored_server_args
|
512
|
+
|
513
|
+
|
514
|
+
def serialize_server_args(server_args: ServerArgs) -> dict:
    """Convert a ServerArgs dataclass instance into a shareable dictionary."""
    as_dict = dataclasses.asdict(server_args)
    return as_dict
|
517
|
+
|
518
|
+
|
519
|
+
def serialize_scheduler_info(scheduler_info: Dict) -> dict:
    """Return ``scheduler_info`` as the dictionary to share.

    The input is passed through unchanged (same object, no copy).
    """
    shareable = scheduler_info
    return shareable
|
522
|
+
|
523
|
+
|
524
|
+
def deserialize_scheduler_info(data: dict) -> Dict:
    """Return the scheduler_info stored in a shared dictionary.

    The input is passed through unchanged (same object, no copy).
    """
    scheduler_info = data
    return scheduler_info
|
527
|
+
|
528
|
+
|
529
|
+
def write_to_shared_memory(data: dict, name: str) -> shared_memory.SharedMemory:
    """Write ``data`` as UTF-8 JSON into the shared-memory segment ``name``.

    An existing segment is reused when it is large enough; otherwise it is
    unlinked and recreated.  A missing segment is created.

    Args:
        data: JSON-serializable dictionary.
        name: Name of the shared-memory segment.

    Returns:
        The open ``SharedMemory`` handle; the caller is responsible for
        closing (and eventually unlinking) it.
    """
    serialized = json.dumps(data).encode("utf-8")
    size = len(serialized)
    try:
        # Try to open existing shared memory
        shm = shared_memory.SharedMemory(name=name)
        # If size is insufficient, close and recreate
        if shm.size < size:
            shm.close()
            shm.unlink()
            shm = shared_memory.SharedMemory(create=True, size=size, name=name)
    except FileNotFoundError:
        # If not present, create new shared memory
        shm = shared_memory.SharedMemory(create=True, size=size, name=name)

    shm.buf[:size] = serialized
    # The segment can be larger than the payload (a reused segment, or a size
    # rounded up by the platform).  Readers parse the whole buffer, so blank
    # out the tail with JSON whitespace; otherwise stale bytes from a previous
    # payload or NUL padding would make the content invalid JSON.
    tail_len = len(shm.buf) - size
    if tail_len > 0:
        shm.buf[size:] = b" " * tail_len
    return shm
|
547
|
+
|
548
|
+
|
549
|
+
def read_from_shared_memory(name: str) -> dict:
    """Read the JSON document stored in the shared-memory segment ``name``.

    Only the first JSON document in the buffer is decoded; anything after it
    (NUL padding of an oversized segment, or leftovers of an earlier, longer
    payload) is ignored.

    Args:
        name: Name of the shared-memory segment.

    Returns:
        The decoded JSON value.

    Raises:
        FileNotFoundError: if no segment called ``name`` exists.
        json.JSONDecodeError: if the buffer does not start with valid JSON.
    """
    try:
        shm = shared_memory.SharedMemory(name=name)
    except FileNotFoundError as err:
        raise FileNotFoundError(f"Shared memory {name} not found") from err

    try:
        # Copy out of the mapping so the handle can be closed before parsing.
        text = bytes(shm.buf).decode("utf-8")
    finally:
        shm.close()

    # raw_decode stops at the end of the first document instead of rejecting
    # trailing bytes the way json.loads does.
    data, _end = json.JSONDecoder().raw_decode(text.lstrip())
    return data
|
558
|
+
|
559
|
+
|
560
|
+
def get_main_process_id() -> int:
    """Return the parent pid recorded on the current multiprocessing process.

    The value is read from the process object's ``_parent_pid`` attribute.
    NOTE(review): this is presumably None when the current process was not
    started through ``multiprocessing`` - confirm against callers.
    """
    this_process = multiprocessing.current_process()
    return this_process._parent_pid
|
563
|
+
|
564
|
+
|
565
|
+
def write_data_for_multi_tokenizer(
    port_args: PortArgs, server_args: ServerArgs, scheduler_info: Dict
):
    """Publish port args, server args and scheduler info via shared memory.

    Each object is serialized and written to its own segment, named
    ``port_args_<pid>``, ``server_args_<pid>`` and ``scheduler_info_<pid>``
    with the current process id.

    Returns:
        The three ``SharedMemory`` handles in that order.  They have already
        been closed in this process when they are returned.
    """
    # get main process ID
    main_pid = get_main_process_id()
    current_pid = os.getpid()
    logger.info(f"main process ID: {main_pid}, current process ID: {current_pid}")

    # (serializer, object, segment-name prefix); processed strictly in order
    # so each object is serialized right before it is written.
    plan = (
        (serialize_port_args, port_args, "port_args"),
        (serialize_server_args, server_args, "server_args"),
        (serialize_scheduler_info, scheduler_info, "scheduler_info"),
    )
    segments = []
    for serialize, obj, prefix in plan:
        segments.append(
            write_to_shared_memory(serialize(obj), f"{prefix}_{current_pid}")
        )

    for segment in segments:
        segment.close()

    port_args_shm, server_args_shm, scheduler_info_shm = segments
    return port_args_shm, server_args_shm, scheduler_info_shm
|