sglang 0.5.1.post2__py3-none-any.whl → 0.5.2rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (107) hide show
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +79 -53
  3. sglang/bench_serving.py +186 -14
  4. sglang/profiler.py +0 -1
  5. sglang/srt/configs/__init__.py +2 -0
  6. sglang/srt/configs/longcat_flash.py +104 -0
  7. sglang/srt/configs/model_config.py +12 -0
  8. sglang/srt/connector/__init__.py +1 -1
  9. sglang/srt/connector/base_connector.py +1 -2
  10. sglang/srt/connector/redis.py +2 -2
  11. sglang/srt/connector/serde/__init__.py +1 -1
  12. sglang/srt/connector/serde/safe_serde.py +4 -3
  13. sglang/srt/conversation.py +38 -5
  14. sglang/srt/disaggregation/ascend/conn.py +75 -0
  15. sglang/srt/disaggregation/launch_lb.py +0 -13
  16. sglang/srt/disaggregation/mini_lb.py +33 -8
  17. sglang/srt/disaggregation/prefill.py +1 -1
  18. sglang/srt/distributed/parallel_state.py +24 -14
  19. sglang/srt/entrypoints/engine.py +19 -12
  20. sglang/srt/entrypoints/http_server.py +174 -34
  21. sglang/srt/entrypoints/openai/protocol.py +87 -24
  22. sglang/srt/entrypoints/openai/serving_chat.py +50 -9
  23. sglang/srt/entrypoints/openai/serving_completions.py +15 -0
  24. sglang/srt/eplb/eplb_manager.py +26 -2
  25. sglang/srt/eplb/expert_distribution.py +29 -2
  26. sglang/srt/function_call/deepseekv31_detector.py +222 -0
  27. sglang/srt/function_call/function_call_parser.py +2 -0
  28. sglang/srt/function_call/gpt_oss_detector.py +144 -256
  29. sglang/srt/harmony_parser.py +588 -0
  30. sglang/srt/hf_transformers_utils.py +26 -7
  31. sglang/srt/layers/activation.py +12 -0
  32. sglang/srt/layers/attention/ascend_backend.py +374 -136
  33. sglang/srt/layers/attention/flashattention_backend.py +241 -7
  34. sglang/srt/layers/attention/flashinfer_backend.py +5 -2
  35. sglang/srt/layers/attention/flashinfer_mla_backend.py +5 -2
  36. sglang/srt/layers/attention/hybrid_attn_backend.py +53 -21
  37. sglang/srt/layers/attention/trtllm_mla_backend.py +25 -10
  38. sglang/srt/layers/communicator.py +1 -2
  39. sglang/srt/layers/layernorm.py +28 -3
  40. sglang/srt/layers/linear.py +3 -2
  41. sglang/srt/layers/logits_processor.py +1 -1
  42. sglang/srt/layers/moe/cutlass_moe.py +0 -8
  43. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  44. sglang/srt/layers/moe/ep_moe/layer.py +13 -13
  45. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  46. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
  47. sglang/srt/layers/moe/topk.py +35 -12
  48. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +133 -235
  49. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
  50. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +5 -23
  51. sglang/srt/layers/quantization/fp8.py +2 -1
  52. sglang/srt/layers/quantization/fp8_kernel.py +2 -2
  53. sglang/srt/layers/quantization/fp8_utils.py +2 -2
  54. sglang/srt/layers/quantization/modelopt_quant.py +7 -0
  55. sglang/srt/layers/quantization/mxfp4.py +25 -27
  56. sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
  57. sglang/srt/layers/quantization/utils.py +13 -0
  58. sglang/srt/layers/quantization/w8a8_int8.py +7 -3
  59. sglang/srt/layers/rotary_embedding.py +28 -1
  60. sglang/srt/layers/sampler.py +29 -5
  61. sglang/srt/layers/utils.py +0 -14
  62. sglang/srt/managers/cache_controller.py +237 -204
  63. sglang/srt/managers/detokenizer_manager.py +48 -2
  64. sglang/srt/managers/io_struct.py +57 -0
  65. sglang/srt/managers/mm_utils.py +5 -1
  66. sglang/srt/managers/multi_tokenizer_mixin.py +591 -0
  67. sglang/srt/managers/scheduler.py +94 -9
  68. sglang/srt/managers/scheduler_output_processor_mixin.py +20 -18
  69. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  70. sglang/srt/managers/tokenizer_manager.py +122 -42
  71. sglang/srt/mem_cache/chunk_cache.py +1 -1
  72. sglang/srt/mem_cache/hicache_storage.py +51 -23
  73. sglang/srt/mem_cache/hiradix_cache.py +87 -71
  74. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  75. sglang/srt/mem_cache/memory_pool.py +77 -14
  76. sglang/srt/mem_cache/memory_pool_host.py +4 -5
  77. sglang/srt/mem_cache/radix_cache.py +6 -4
  78. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  79. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +38 -20
  80. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +87 -82
  81. sglang/srt/mem_cache/swa_radix_cache.py +1 -1
  82. sglang/srt/model_executor/model_runner.py +6 -5
  83. sglang/srt/model_loader/loader.py +15 -24
  84. sglang/srt/model_loader/utils.py +12 -0
  85. sglang/srt/models/deepseek_v2.py +38 -13
  86. sglang/srt/models/gpt_oss.py +2 -15
  87. sglang/srt/models/llama_eagle3.py +4 -0
  88. sglang/srt/models/longcat_flash.py +1015 -0
  89. sglang/srt/models/longcat_flash_nextn.py +691 -0
  90. sglang/srt/models/qwen2.py +26 -3
  91. sglang/srt/models/qwen2_5_vl.py +66 -41
  92. sglang/srt/models/qwen2_moe.py +22 -2
  93. sglang/srt/models/transformers.py +1 -1
  94. sglang/srt/multimodal/processors/base_processor.py +4 -2
  95. sglang/srt/reasoning_parser.py +56 -300
  96. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  97. sglang/srt/server_args.py +122 -56
  98. sglang/srt/speculative/eagle_worker.py +28 -8
  99. sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
  100. sglang/srt/utils.py +73 -5
  101. sglang/test/attention/test_trtllm_mla_backend.py +12 -3
  102. sglang/version.py +1 -1
  103. {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/METADATA +7 -6
  104. {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/RECORD +107 -99
  105. {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/WHEEL +0 -0
  106. {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/licenses/LICENSE +0 -0
  107. {sglang-0.5.1.post2.dist-info → sglang-0.5.2rc0.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,591 @@
1
+ # Copyright 2023-2024 SGLang Team
2
+ # Licensed under the Apache License, Version 2.0 (the "License");
3
+ # you may not use this file except in compliance with the License.
4
+ # You may obtain a copy of the License at
5
+ #
6
+ # http://www.apache.org/licenses/LICENSE-2.0
7
+ #
8
+ # Unless required by applicable law or agreed to in writing, software
9
+ # distributed under the License is distributed on an "AS IS" BASIS,
10
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
11
+ # See the License for the specific language governing permissions and
12
+ # limitations under the License.
13
+ # ==============================================================================
14
+ """MultiTokenizerMixin is a class that provides nesscary methods for MultiTokenizerManager and DetokenizerManager."""
15
+ import asyncio
16
+ import dataclasses
17
+ import json
18
+ import logging
19
+ import multiprocessing as multiprocessing
20
+ import os
21
+ import sys
22
+ import threading
23
+ from multiprocessing import shared_memory
24
+ from typing import Dict
25
+
26
+ import zmq
27
+ import zmq.asyncio
28
+
29
+ from sglang.srt.disaggregation.utils import DisaggregationMode, TransferBackend
30
+ from sglang.srt.managers.io_struct import (
31
+ BatchEmbeddingOut,
32
+ BatchMultimodalOut,
33
+ BatchStrOut,
34
+ BatchTokenIDOut,
35
+ MultiTokenizerRegisterReq,
36
+ MultiTokenizerWarpper,
37
+ )
38
+ from sglang.srt.managers.tokenizer_manager import TokenizerManager, _Communicator
39
+ from sglang.srt.server_args import PortArgs, ServerArgs
40
+ from sglang.srt.utils import (
41
+ get_worker_ids_from_req_rids,
42
+ get_zmq_socket,
43
+ kill_process_tree,
44
+ )
45
+ from sglang.utils import get_exception_traceback
46
+
47
+ logger = logging.getLogger(__name__)
48
+
49
+
50
class MultiTokenizerMixin:
    """Mixin class for MultiTokenizerManager and DetokenizerManager.

    It keeps ``self.tokenizer_mapping`` (worker id -> ZMQ socket) and knows
    how to cut a single request's slice out of a batched output object so the
    slice can be forwarded to the worker that owns the request.
    """

    # Per-request list fields shared by BatchTokenIDOut and BatchStrOut.
    _LOGPROB_AND_HIDDEN_FIELDS = (
        "input_token_logprobs_val",
        "input_token_logprobs_idx",
        "output_token_logprobs_val",
        "output_token_logprobs_idx",
        "input_top_logprobs_val",
        "input_top_logprobs_idx",
        "output_top_logprobs_val",
        "output_top_logprobs_idx",
        "input_token_ids_logprobs_val",
        "input_token_ids_logprobs_idx",
        "output_token_ids_logprobs_val",
        "output_token_ids_logprobs_idx",
        "output_hidden_states",
    )
    # Fields (besides ``rids``) copied for each supported batch output type.
    _TOKEN_ID_OUT_FIELDS = (
        "finished_reasons",
        "decoded_texts",
        "decode_ids",
        "read_offsets",
        "output_ids",
        "skip_special_tokens",
        "spaces_between_special_tokens",
        "no_stop_trim",
        "prompt_tokens",
        "completion_tokens",
        "cached_tokens",
        "spec_verify_ct",
    ) + _LOGPROB_AND_HIDDEN_FIELDS
    _EMBEDDING_OUT_FIELDS = (
        "finished_reasons",
        "embeddings",
        "prompt_tokens",
        "cached_tokens",
    )
    _STR_OUT_FIELDS = (
        "finished_reasons",
        "output_strs",
        "output_ids",
        "prompt_tokens",
        "completion_tokens",
        "cached_tokens",
        "spec_verify_ct",
    ) + _LOGPROB_AND_HIDDEN_FIELDS
    _MULTIMODAL_OUT_FIELDS = (
        "finished_reasons",
        "outputs",
        "prompt_tokens",
        "completion_tokens",
        "cached_tokens",
    )

    def create_sockets_mapping(self):
        """Create ``tokenizer_mapping`` and the ZMQ context if they do not exist yet."""
        if not hasattr(self, "tokenizer_mapping"):
            self.tokenizer_mapping = {}
        # Create ZMQ context if needed
        if not hasattr(self, "_zmq_context"):
            self._zmq_context = zmq.Context()

    def init_tokenizer_mapping(
        self, recv_obj: "MultiTokenizerRegisterReq", worker_id: str
    ):
        """Init tokenizer mapping from a register request.

        Opens a PUSH socket on ``recv_obj.ipc_name``, stores it under
        ``int(worker_id)`` and sends ``recv_obj`` through the new socket.

        Returns:
            True if a new socket was created, False if the worker id was
            already present in the mapping.
        """
        ipc_name = recv_obj.ipc_name
        worker_id_int = int(worker_id)

        if worker_id_int in self.tokenizer_mapping:
            return False

        # NOTE(review): the trailing ``False`` is presumably the "bind" flag
        # of get_zmq_socket (i.e. connect instead of bind) - confirm.
        socket = get_zmq_socket(self._zmq_context, zmq.PUSH, ipc_name, False)
        self.tokenizer_mapping[worker_id_int] = socket
        socket.send_pyobj(recv_obj)
        return True

    def register_tokenizer_ipc(self, recv_obj, worker_id):
        """Register ``worker_id`` if it is unknown and ``recv_obj`` is a register request.

        Returns:
            The result of ``init_tokenizer_mapping`` when a registration is
            attempted; False when the worker is already registered or when
            ``recv_obj`` is not a ``MultiTokenizerRegisterReq``.
        """
        if worker_id in self.tokenizer_mapping:
            return False
        # register the worker if not already done
        if isinstance(recv_obj, MultiTokenizerRegisterReq):
            return self.init_tokenizer_mapping(recv_obj, worker_id)
        logger.error(
            f"Worker {worker_id} not registered and not found in tokenizer mapping . "
            "Please ensure the worker is registered correctly."
        )
        return False

    @staticmethod
    def _pick_index(values, i):
        """Return ``[values[i]]``, or None if ``values`` is None or has no index ``i``."""
        if values is None or len(values) <= i:
            return None
        return [values[i]]

    def _handle_output_by_index(self, output, i):
        """Build a single-request copy of a batched output.

        For the four known batch output types a new object of the same type is
        created whose per-request list fields contain only element ``i``
        (or None when the source field is missing, empty or too short).
        Any other object is returned unchanged.

        Raises:
            IndexError: if ``output.rids`` has no element ``i``.
        """
        if isinstance(output, BatchTokenIDOut):
            out_cls, field_names = BatchTokenIDOut, self._TOKEN_ID_OUT_FIELDS
        elif isinstance(output, BatchEmbeddingOut):
            out_cls, field_names = BatchEmbeddingOut, self._EMBEDDING_OUT_FIELDS
        elif isinstance(output, BatchStrOut):
            out_cls, field_names = BatchStrOut, self._STR_OUT_FIELDS
        elif isinstance(output, BatchMultimodalOut):
            out_cls, field_names = BatchMultimodalOut, self._MULTIMODAL_OUT_FIELDS
        else:
            return output

        sliced_fields = {
            field_name: self._pick_index(getattr(output, field_name), i)
            for field_name in field_names
        }
        # ``rids`` is indexed unconditionally: a missing rid is a caller error.
        return out_cls(rids=[output.rids[i]], **sliced_fields)

    def clear_tokenizer_mapping(self):
        """Close every socket in ``tokenizer_mapping`` (best effort) and empty the mapping."""
        if hasattr(self, "tokenizer_mapping"):
            for socket in self.tokenizer_mapping.values():
                try:
                    socket.close()
                except Exception as e:
                    # Keep going so one bad socket does not prevent the rest
                    # from being closed.
                    logger.warning(f"Failed to close socket: {e}")
            self.tokenizer_mapping.clear()
355
+
356
+
357
class MultiTokenizerRouter(TokenizerManager, MultiTokenizerMixin):
    """A router to receive requests from MultiTokenizerManager.

    Objects received from tokenizer workers are forwarded to the scheduler,
    and objects received from the detokenizer are split per request and
    pushed to the worker that owns each request.
    """

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
    ):
        """Open the three ZMQ sockets and start both forwarding coroutines.

        The coroutines run on a dedicated event loop that lives in a daemon
        thread started here. ``TokenizerManager.__init__`` is not invoked.
        """
        self.server_args = server_args

        # NOTE(review): the argument 3 is presumably the number of ZMQ I/O
        # threads - confirm against the pyzmq Context signature.
        zmq_ctx = zmq.asyncio.Context(3)
        self.recv_from_detokenizer = get_zmq_socket(
            zmq_ctx, zmq.PULL, port_args.tokenizer_ipc_name, True
        )
        self.send_to_scheduler = get_zmq_socket(
            zmq_ctx, zmq.PUSH, port_args.scheduler_input_ipc_name, True
        )
        self.receive_from_worker = get_zmq_socket(
            zmq_ctx, zmq.PULL, port_args.tokenizer_worker_ipc_name, True
        )

        # Private event loop, driven forever by a daemon thread.
        self._loop = asyncio.new_event_loop()
        self._thread = threading.Thread(target=self._run_loop, daemon=True)
        self._thread.start()

        # worker -> scheduler forwarding
        self._task = asyncio.run_coroutine_threadsafe(
            self.router_worker_obj(), self._loop
        )
        # Start handle_loop simultaneously
        self._handle_task = asyncio.run_coroutine_threadsafe(
            print_exception_wrapper(self.handle_loop), self._loop
        )
        self.init_disaggregation()

    def _run_loop(self):
        """Thread target: run the router's private event loop until it is stopped."""
        self._loop.run_forever()

    async def router_worker_obj(self):
        """Forward every object received from a worker to the scheduler, forever."""
        while True:
            worker_msg = await self.receive_from_worker.recv_pyobj()
            await self.send_to_scheduler.send_pyobj(worker_msg)

    async def handle_loop(self):
        """Receive objects from the detokenizer and route them to workers, forever."""
        # special reqs will recv from scheduler, need to route to right worker
        self.create_sockets_mapping()
        while True:
            detok_msg = await self.recv_from_detokenizer.recv_pyobj()
            await self._distribute_result_to_workers(detok_msg)

    async def _distribute_result_to_workers(self, recv_obj):
        """Distribute result to corresponding workers based on rid.

        A ``MultiTokenizerWarpper`` carries its target worker id explicitly
        and is unwrapped first; for anything else the worker ids are derived
        from ``recv_obj.rids``.
        """
        if isinstance(recv_obj, MultiTokenizerWarpper):
            worker_ids = [recv_obj.worker_id]
            recv_obj = recv_obj.obj
        else:
            worker_ids = get_worker_ids_from_req_rids(recv_obj.rids)

        if len(worker_ids) == 0:
            logger.error(f"Cannot find worker_id from rids {recv_obj.rids}")
            return

        # The payload type does not change while iterating, so test it once.
        is_register_req = isinstance(recv_obj, MultiTokenizerRegisterReq)

        # Distribute result to each worker
        for position, worker_id in enumerate(worker_ids):
            if is_register_req:
                if self.register_tokenizer_ipc(recv_obj, worker_id):
                    logger.info(
                        f"MultiTokenizerRouter Created ZMQ socket for worker {worker_id}"
                    )
                continue

            if worker_id not in self.tokenizer_mapping:
                logger.error(
                    f"Tokenizer Worker ID {worker_id} not registered. Check if the server Process {worker_id} is alive"
                )
                continue

            per_worker_obj = self._handle_output_by_index(recv_obj, position)
            self.tokenizer_mapping[worker_id].send_pyobj(per_worker_obj)
431
+
432
+
433
class MultiTokenizerManager(TokenizerManager, MultiTokenizerMixin):
    """Multi Process Tokenizer Manager that tokenizes the text."""

    def __init__(
        self,
        server_args: ServerArgs,
        port_args: PortArgs,
    ):
        """Initialise the base TokenizerManager and the registration communicator.

        ``server_args.disaggregation_mode`` is set to ``"null"`` while the
        base initialiser runs and is restored afterwards.
        """
        # prevent init prefill bootstrapserver again
        saved_disaggregation_mode = server_args.disaggregation_mode
        server_args.disaggregation_mode = "null"
        super().__init__(server_args, port_args)

        # The process id doubles as this worker's identifier.
        self.worker_id = os.getpid()
        self.tokenizer_ipc_name = port_args.tokenizer_ipc_name

        # For PD disaggregation
        self.server_args.disaggregation_mode = saved_disaggregation_mode
        self.disaggregation_mode = DisaggregationMode(
            self.server_args.disaggregation_mode
        )
        self.disaggregation_transfer_backend = TransferBackend(
            self.server_args.disaggregation_transfer_backend
        )

        # Communicator
        # NOTE(review): the second argument (2) is presumably a fan-out
        # count expected by _Communicator - confirm.
        self.register_multi_tokenizer_communicator = _Communicator(
            self.send_to_scheduler, 2
        )
        register_route = (
            MultiTokenizerRegisterReq,
            self.register_multi_tokenizer_communicator.handle_recv,
        )
        self._result_dispatcher._mapping.append(register_route)

    async def register_to_main_tokenizer_manager(self):
        """Register this worker to the main TokenizerManager."""
        # create a handle loop to receive messages from the main TokenizerManager
        self.auto_create_handle_loop()

        register_rid = f"{self.worker_id}_register"
        register_req = MultiTokenizerRegisterReq(rids=[register_rid])
        register_req.ipc_name = self.tokenizer_ipc_name

        # Class-level switch: it is set on _Communicator itself, not on an instance.
        _Communicator.enable_multi_tokenizer = True
        await self.register_multi_tokenizer_communicator(register_req)
476
+
477
+
478
async def print_exception_wrapper(func):
    """Await ``func()`` and make sure a failure is logged before the process dies.

    Sometimes an asyncio function does not print its exception, so this
    wrapper logs the traceback itself. If ``func`` is a method bound to a
    ``MultiTokenizerRouter``, the router dumps its requests first. The
    process tree is then killed and the interpreter exits with status 1.

    Nothing is returned; the result of ``func()`` is discarded.
    """
    try:
        await func()
    except Exception:
        error_text = get_exception_traceback()
        logger.error(f"MultiTokenizerRouter hit an exception: {error_text}")

        # Plain functions have no __self__; only bound router methods qualify.
        owner = getattr(func, "__self__", None)
        if isinstance(owner, MultiTokenizerRouter):
            owner.dump_requests_before_crash()

        kill_process_tree(os.getpid(), include_parent=True)
        sys.exit(1)
494
+
495
+
496
def serialize_port_args(port_args: PortArgs) -> dict:
    """Serialize PortArgs into a shareable dictionary.

    Only the attributes listed below are copied; keys equal attribute names.
    """
    shared_attrs = (
        "tokenizer_ipc_name",
        "scheduler_input_ipc_name",
        "detokenizer_ipc_name",
        "nccl_port",
        "rpc_ipc_name",
        "metrics_ipc_name",
        "tokenizer_worker_ipc_name",
    )
    return {attr: getattr(port_args, attr) for attr in shared_attrs}
507
+
508
+
509
def deserialize_data(port_args: dict, server_args: dict):
    """Deserialize data from shared dictionaries.

    Returns:
        A ``(PortArgs, ServerArgs)`` tuple, each built by passing the
        corresponding dictionary as keyword arguments.
    """
    restored_port_args = PortArgs(**port_args)
    restored_server_args = ServerArgs(**server_args)
    return restored_port_args, restored_server_args
512
+
513
+
514
def serialize_server_args(server_args: ServerArgs) -> dict:
    """Serialize ServerArgs into a shareable dictionary via ``dataclasses.asdict``."""
    as_dict = dataclasses.asdict(server_args)
    return as_dict
517
+
518
+
519
def serialize_scheduler_info(scheduler_info: Dict) -> dict:
    """Serialize scheduler_info into a shareable dictionary.

    The mapping is handed back as-is (same object, no copy).
    """
    shareable = scheduler_info
    return shareable
522
+
523
+
524
def deserialize_scheduler_info(data: dict) -> Dict:
    """Deserialize scheduler_info from a shared dictionary.

    The mapping is handed back as-is (same object, no copy).
    """
    scheduler_info = data
    return scheduler_info
527
+
528
+
529
def write_to_shared_memory(data: dict, name: str) -> shared_memory.SharedMemory:
    """Write ``data`` as UTF-8 encoded JSON into the shared memory segment ``name``.

    An existing segment is reused when it is large enough; a segment that is
    too small is unlinked and recreated; a missing segment is created.

    Args:
        data: JSON-serializable dictionary.
        name: Name of the shared memory segment.

    Returns:
        The still-open ``SharedMemory`` handle; the caller is responsible for
        closing (and eventually unlinking) it.

    Raises:
        TypeError: if ``data`` is not JSON-serializable (from ``json.dumps``).
    """
    serialized = json.dumps(data).encode("utf-8")
    size = len(serialized)
    try:
        # Try to open existing shared memory
        shm = shared_memory.SharedMemory(name=name)
    except FileNotFoundError:
        # If not present, create new shared memory
        shm = shared_memory.SharedMemory(create=True, size=size, name=name)
    else:
        # If size is insufficient, close and recreate
        if shm.size < size:
            shm.close()
            shm.unlink()
            shm = shared_memory.SharedMemory(create=True, size=size, name=name)

    shm.buf[:size] = serialized
    # A reused (or size-rounded) segment can be larger than the payload.
    # Zero the remainder so bytes of an older, longer payload do not trail
    # the new JSON document.
    tail_len = len(shm.buf) - size
    if tail_len > 0:
        shm.buf[size:] = bytes(tail_len)
    return shm
547
+
548
+
549
def read_from_shared_memory(name: str) -> dict:
    """Read the JSON document stored at the start of shared memory segment ``name``.

    Only the leading JSON document is parsed. Whatever follows it in the
    buffer (NUL padding when the segment is larger than the payload, or
    leftovers of an older payload) is ignored.

    Args:
        name: Name of the shared memory segment.

    Returns:
        The decoded JSON value.

    Raises:
        FileNotFoundError: if no segment called ``name`` exists.
        ValueError: if the buffer does not start with valid UTF-8 JSON
            (``UnicodeDecodeError`` / ``json.JSONDecodeError``).
    """
    try:
        shm = shared_memory.SharedMemory(name=name)
    except FileNotFoundError as err:
        raise FileNotFoundError(f"Shared memory {name} not found") from err

    try:
        # Copy out of the mapping so the handle can be closed right away,
        # even if decoding fails.
        raw = bytes(shm.buf)
    finally:
        shm.close()

    text = raw.rstrip(b"\x00").decode("utf-8")
    # raw_decode stops at the end of the first document instead of failing
    # with "Extra data" on trailing bytes.
    data, _end = json.JSONDecoder().raw_decode(text)
    return data
558
+
559
+
560
def get_main_process_id() -> int:
    """Get the main process ID.

    Reads the private ``_parent_pid`` attribute of the current
    ``multiprocessing`` process object and returns it unchanged.
    """
    this_process = multiprocessing.current_process()
    return this_process._parent_pid
563
+
564
+
565
def write_data_for_multi_tokenizer(
    port_args: PortArgs, server_args: ServerArgs, scheduler_info: Dict
):
    """Write args information to shared memory for multi-tokenizer.

    Three segments are written, named ``port_args_<pid>``,
    ``server_args_<pid>`` and ``scheduler_info_<pid>``, where ``<pid>`` is the
    id of the calling process.

    Returns:
        The three ``SharedMemory`` handles (port args, server args, scheduler
        info). Each handle has already been closed when it is returned.
    """
    # get main process ID
    main_pid = get_main_process_id()
    current_pid = os.getpid()
    logger.info(f"main process ID: {main_pid}, current process ID: {current_pid}")

    # Write port_args to shared memory
    port_args_payload = serialize_port_args(port_args)
    port_args_shm = write_to_shared_memory(
        port_args_payload, f"port_args_{current_pid}"
    )
    # Write server_args to shared memory
    server_args_payload = serialize_server_args(server_args)
    server_args_shm = write_to_shared_memory(
        server_args_payload, f"server_args_{current_pid}"
    )
    # Write scheduler_info to shared memory
    scheduler_info_payload = serialize_scheduler_info(scheduler_info)
    scheduler_info_shm = write_to_shared_memory(
        scheduler_info_payload, f"scheduler_info_{current_pid}"
    )

    # Closing only drops this process's mapping; the segments are not unlinked here.
    written_segments = (port_args_shm, server_args_shm, scheduler_info_shm)
    for segment in written_segments:
        segment.close()

    return written_segments