sglang 0.5.1.post3__py3-none-any.whl → 0.5.2rc1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (93)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/srt/configs/__init__.py +2 -0
  3. sglang/srt/configs/longcat_flash.py +104 -0
  4. sglang/srt/configs/model_config.py +14 -1
  5. sglang/srt/connector/__init__.py +1 -1
  6. sglang/srt/connector/base_connector.py +1 -2
  7. sglang/srt/connector/redis.py +2 -2
  8. sglang/srt/connector/serde/__init__.py +1 -1
  9. sglang/srt/connector/serde/safe_serde.py +4 -3
  10. sglang/srt/disaggregation/ascend/conn.py +75 -0
  11. sglang/srt/disaggregation/launch_lb.py +0 -13
  12. sglang/srt/disaggregation/mini_lb.py +33 -8
  13. sglang/srt/disaggregation/prefill.py +1 -1
  14. sglang/srt/distributed/parallel_state.py +27 -15
  15. sglang/srt/entrypoints/engine.py +19 -12
  16. sglang/srt/entrypoints/http_server.py +174 -34
  17. sglang/srt/entrypoints/openai/protocol.py +60 -0
  18. sglang/srt/eplb/eplb_manager.py +26 -2
  19. sglang/srt/eplb/expert_distribution.py +29 -2
  20. sglang/srt/hf_transformers_utils.py +10 -0
  21. sglang/srt/layers/activation.py +12 -0
  22. sglang/srt/layers/attention/ascend_backend.py +240 -109
  23. sglang/srt/layers/attention/hybrid_attn_backend.py +53 -21
  24. sglang/srt/layers/attention/trtllm_mla_backend.py +25 -10
  25. sglang/srt/layers/layernorm.py +28 -3
  26. sglang/srt/layers/linear.py +3 -2
  27. sglang/srt/layers/logits_processor.py +1 -1
  28. sglang/srt/layers/moe/cutlass_w4a8_moe.py +1 -9
  29. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  30. sglang/srt/layers/moe/ep_moe/layer.py +14 -13
  31. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  32. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  33. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -1048
  34. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  35. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +796 -0
  36. sglang/srt/layers/moe/fused_moe_triton/layer.py +5 -2
  37. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  38. sglang/srt/layers/moe/topk.py +35 -12
  39. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
  40. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
  41. sglang/srt/layers/quantization/modelopt_quant.py +7 -0
  42. sglang/srt/layers/quantization/mxfp4.py +9 -4
  43. sglang/srt/layers/quantization/utils.py +13 -0
  44. sglang/srt/layers/quantization/w4afp8.py +30 -25
  45. sglang/srt/layers/quantization/w8a8_int8.py +7 -3
  46. sglang/srt/layers/rotary_embedding.py +28 -1
  47. sglang/srt/layers/sampler.py +29 -5
  48. sglang/srt/managers/cache_controller.py +62 -96
  49. sglang/srt/managers/detokenizer_manager.py +9 -2
  50. sglang/srt/managers/io_struct.py +27 -0
  51. sglang/srt/managers/mm_utils.py +5 -1
  52. sglang/srt/managers/multi_tokenizer_mixin.py +629 -0
  53. sglang/srt/managers/scheduler.py +39 -2
  54. sglang/srt/managers/scheduler_output_processor_mixin.py +20 -18
  55. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  56. sglang/srt/managers/tokenizer_manager.py +86 -39
  57. sglang/srt/mem_cache/chunk_cache.py +1 -1
  58. sglang/srt/mem_cache/hicache_storage.py +20 -3
  59. sglang/srt/mem_cache/hiradix_cache.py +94 -71
  60. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  61. sglang/srt/mem_cache/memory_pool.py +4 -0
  62. sglang/srt/mem_cache/memory_pool_host.py +4 -4
  63. sglang/srt/mem_cache/radix_cache.py +5 -4
  64. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  65. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  66. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -9
  67. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +2 -1
  68. sglang/srt/mem_cache/swa_radix_cache.py +1 -1
  69. sglang/srt/model_executor/model_runner.py +5 -4
  70. sglang/srt/model_loader/loader.py +15 -24
  71. sglang/srt/model_loader/utils.py +12 -0
  72. sglang/srt/models/deepseek_v2.py +31 -10
  73. sglang/srt/models/gpt_oss.py +5 -18
  74. sglang/srt/models/llama_eagle3.py +4 -0
  75. sglang/srt/models/longcat_flash.py +1026 -0
  76. sglang/srt/models/longcat_flash_nextn.py +699 -0
  77. sglang/srt/models/qwen2.py +26 -3
  78. sglang/srt/models/qwen2_5_vl.py +65 -41
  79. sglang/srt/models/qwen2_moe.py +22 -2
  80. sglang/srt/models/transformers.py +1 -1
  81. sglang/srt/multimodal/processors/base_processor.py +4 -2
  82. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  83. sglang/srt/server_args.py +112 -55
  84. sglang/srt/speculative/eagle_worker.py +28 -8
  85. sglang/srt/utils.py +4 -0
  86. sglang/test/attention/test_trtllm_mla_backend.py +12 -3
  87. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  88. sglang/version.py +1 -1
  89. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/METADATA +5 -5
  90. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/RECORD +93 -85
  91. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/WHEEL +0 -0
  92. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/licenses/LICENSE +0 -0
  93. {sglang-0.5.1.post3.dist-info → sglang-0.5.2rc1.dist-info}/top_level.txt +0 -0
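sglang/srt/managers/multi_tokenizer_mixin.py (new file, +629 -0)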
@@ -0,0 +1,629 @@
+ # Copyright 2023-2024 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ """MultiTokenizerMixin provides the necessary methods for MultiTokenizerManager and DetokenizerManager."""
+ import asyncio
+ import dataclasses
+ import json
+ import logging
+ import multiprocessing as multiprocessing
+ import os
+ import sys
+ import threading
+ from multiprocessing import shared_memory
+ from typing import Dict
+
+ import zmq
+ import zmq.asyncio
+
+ from sglang.srt.disaggregation.utils import DisaggregationMode, TransferBackend
+ from sglang.srt.managers.io_struct import (
+     BatchEmbeddingOut,
+     BatchMultimodalOut,
+     BatchStrOut,
+     BatchTokenIDOut,
+     MultiTokenizerRegisterReq,
+     MultiTokenizerWarpper,
+ )
+ from sglang.srt.managers.tokenizer_manager import TokenizerManager, _Communicator
+ from sglang.srt.server_args import PortArgs, ServerArgs
+ from sglang.srt.utils import get_zmq_socket, kill_process_tree
+ from sglang.utils import get_exception_traceback
+
+ logger = logging.getLogger(__name__)
+
+
+ class MultiTokenizerMixin:
+     """Mixin class for MultiTokenizerManager and DetokenizerManager"""
+
+     def create_sockets_mapping(self):
+         if not hasattr(self, "tokenizer_mapping"):
+             self.tokenizer_mapping = {}
+         # Create ZMQ context if needed
+         if not hasattr(self, "_zmq_context"):
+             self._zmq_context = zmq.Context()
+
+     def init_tokenizer_mapping(
+         self, recv_obj: MultiTokenizerRegisterReq, worker_id: str
+     ):
+         """init tokenizer mapping from register request"""
+         ipc_name = recv_obj.ipc_name
+         worker_id_int = int(worker_id)
+
+         if worker_id_int not in self.tokenizer_mapping:
+             socket = get_zmq_socket(self._zmq_context, zmq.PUSH, ipc_name, False)
+             self.tokenizer_mapping[worker_id_int] = socket
+             self.tokenizer_mapping[worker_id_int].send_pyobj(recv_obj)
+             return True
+         else:
+             return False
+
+     def register_tokenizer_ipc(self, recv_obj, worker_id):
+         if worker_id not in self.tokenizer_mapping:
+             # register the worker if not already done
+             if isinstance(recv_obj, MultiTokenizerRegisterReq):
+                 return self.init_tokenizer_mapping(recv_obj, worker_id)
+             else:
+                 logger.error(
+                     f"Worker {worker_id} not registered and not found in tokenizer mapping. "
+                     "Please ensure the worker is registered correctly."
+                 )
+         return False
+
+     def _handle_output_by_index(self, output, i):
+         """NOTE: A maintainable method is better here."""
+         if isinstance(output, BatchTokenIDOut):
+             new_output = BatchTokenIDOut(
+                 rids=[output.rids[i]],
+                 finished_reasons=(
+                     [output.finished_reasons[i]]
+                     if len(output.finished_reasons) > i
+                     else None
+                 ),
+                 decoded_texts=(
+                     [output.decoded_texts[i]] if len(output.decoded_texts) > i else None
+                 ),
+                 decode_ids=(
+                     [output.decode_ids[i]] if len(output.decode_ids) > i else None
+                 ),
+                 read_offsets=(
+                     [output.read_offsets[i]] if len(output.read_offsets) > i else None
+                 ),
+                 output_ids=(
+                     [output.output_ids[i]]
+                     if output.output_ids and len(output.output_ids) > i
+                     else None
+                 ),
+                 skip_special_tokens=(
+                     [output.skip_special_tokens[i]]
+                     if len(output.skip_special_tokens) > i
+                     else None
+                 ),
+                 spaces_between_special_tokens=(
+                     [output.spaces_between_special_tokens[i]]
+                     if len(output.spaces_between_special_tokens) > i
+                     else None
+                 ),
+                 no_stop_trim=(
+                     [output.no_stop_trim[i]] if len(output.no_stop_trim) > i else None
+                 ),
+                 prompt_tokens=(
+                     [output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None
+                 ),
+                 completion_tokens=(
+                     [output.completion_tokens[i]]
+                     if len(output.completion_tokens) > i
+                     else None
+                 ),
+                 cached_tokens=(
+                     [output.cached_tokens[i]] if len(output.cached_tokens) > i else None
+                 ),
+                 spec_verify_ct=(
+                     [output.spec_verify_ct[i]]
+                     if len(output.spec_verify_ct) > i
+                     else None
+                 ),
+                 input_token_logprobs_val=(
+                     [output.input_token_logprobs_val[i]]
+                     if output.input_token_logprobs_val
+                     else None
+                 ),
+                 input_token_logprobs_idx=(
+                     [output.input_token_logprobs_idx[i]]
+                     if output.input_token_logprobs_idx
+                     else None
+                 ),
+                 output_token_logprobs_val=(
+                     [output.output_token_logprobs_val[i]]
+                     if output.output_token_logprobs_val
+                     else None
+                 ),
+                 output_token_logprobs_idx=(
+                     [output.output_token_logprobs_idx[i]]
+                     if output.output_token_logprobs_idx
+                     else None
+                 ),
+                 input_top_logprobs_val=(
+                     [output.input_top_logprobs_val[i]]
+                     if output.input_top_logprobs_val
+                     else None
+                 ),
+                 input_top_logprobs_idx=(
+                     [output.input_top_logprobs_idx[i]]
+                     if output.input_top_logprobs_idx
+                     else None
+                 ),
+                 output_top_logprobs_val=(
+                     [output.output_top_logprobs_val[i]]
+                     if output.output_top_logprobs_val
+                     else None
+                 ),
+                 output_top_logprobs_idx=(
+                     [output.output_top_logprobs_idx[i]]
+                     if output.output_top_logprobs_idx
+                     else None
+                 ),
+                 input_token_ids_logprobs_val=(
+                     [output.input_token_ids_logprobs_val[i]]
+                     if output.input_token_ids_logprobs_val
+                     else None
+                 ),
+                 input_token_ids_logprobs_idx=(
+                     [output.input_token_ids_logprobs_idx[i]]
+                     if output.input_token_ids_logprobs_idx
+                     else None
+                 ),
+                 output_token_ids_logprobs_val=(
+                     [output.output_token_ids_logprobs_val[i]]
+                     if output.output_token_ids_logprobs_val
+                     else None
+                 ),
+                 output_token_ids_logprobs_idx=(
+                     [output.output_token_ids_logprobs_idx[i]]
+                     if output.output_token_ids_logprobs_idx
+                     else None
+                 ),
+                 output_hidden_states=(
+                     [output.output_hidden_states[i]]
+                     if output.output_hidden_states
+                     else None
+                 ),
+             )
+         elif isinstance(output, BatchEmbeddingOut):
+             new_output = BatchEmbeddingOut(
+                 rids=[output.rids[i]],
+                 finished_reasons=(
+                     [output.finished_reasons[i]]
+                     if len(output.finished_reasons) > i
+                     else None
+                 ),
+                 embeddings=(
+                     [output.embeddings[i]] if len(output.embeddings) > i else None
+                 ),
+                 prompt_tokens=(
+                     [output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None
+                 ),
+                 cached_tokens=(
+                     [output.cached_tokens[i]] if len(output.cached_tokens) > i else None
+                 ),
+             )
+         elif isinstance(output, BatchStrOut):
+             new_output = BatchStrOut(
+                 rids=[output.rids[i]],
+                 finished_reasons=(
+                     [output.finished_reasons[i]]
+                     if len(output.finished_reasons) > i
+                     else None
+                 ),
+                 output_strs=(
+                     [output.output_strs[i]] if len(output.output_strs) > i else None
+                 ),
+                 output_ids=(
+                     [output.output_ids[i]]
+                     if output.output_ids and len(output.output_ids) > i
+                     else None
+                 ),
+                 prompt_tokens=(
+                     [output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None
+                 ),
+                 completion_tokens=(
+                     [output.completion_tokens[i]]
+                     if len(output.completion_tokens) > i
+                     else None
+                 ),
+                 cached_tokens=(
+                     [output.cached_tokens[i]] if len(output.cached_tokens) > i else None
+                 ),
+                 spec_verify_ct=(
+                     [output.spec_verify_ct[i]]
+                     if len(output.spec_verify_ct) > i
+                     else None
+                 ),
+                 input_token_logprobs_val=(
+                     [output.input_token_logprobs_val[i]]
+                     if output.input_token_logprobs_val
+                     else None
+                 ),
+                 input_token_logprobs_idx=(
+                     [output.input_token_logprobs_idx[i]]
+                     if output.input_token_logprobs_idx
+                     else None
+                 ),
+                 output_token_logprobs_val=(
+                     [output.output_token_logprobs_val[i]]
+                     if output.output_token_logprobs_val
+                     else None
+                 ),
+                 output_token_logprobs_idx=(
+                     [output.output_token_logprobs_idx[i]]
+                     if output.output_token_logprobs_idx
+                     else None
+                 ),
+                 input_top_logprobs_val=(
+                     [output.input_top_logprobs_val[i]]
+                     if output.input_top_logprobs_val
+                     else None
+                 ),
+                 input_top_logprobs_idx=(
+                     [output.input_top_logprobs_idx[i]]
+                     if output.input_top_logprobs_idx
+                     else None
+                 ),
+                 output_top_logprobs_val=(
+                     [output.output_top_logprobs_val[i]]
+                     if output.output_top_logprobs_val
+                     else None
+                 ),
+                 output_top_logprobs_idx=(
+                     [output.output_top_logprobs_idx[i]]
+                     if output.output_top_logprobs_idx
+                     else None
+                 ),
+                 input_token_ids_logprobs_val=(
+                     [output.input_token_ids_logprobs_val[i]]
+                     if output.input_token_ids_logprobs_val
+                     else None
+                 ),
+                 input_token_ids_logprobs_idx=(
+                     [output.input_token_ids_logprobs_idx[i]]
+                     if output.input_token_ids_logprobs_idx
+                     else None
+                 ),
+                 output_token_ids_logprobs_val=(
+                     [output.output_token_ids_logprobs_val[i]]
+                     if output.output_token_ids_logprobs_val
+                     else None
+                 ),
+                 output_token_ids_logprobs_idx=(
+                     [output.output_token_ids_logprobs_idx[i]]
+                     if output.output_token_ids_logprobs_idx
+                     else None
+                 ),
+                 output_hidden_states=(
+                     [output.output_hidden_states[i]]
+                     if output.output_hidden_states
+                     else None
+                 ),
+             )
+         elif isinstance(output, BatchMultimodalOut):
+             new_output = BatchMultimodalOut(
+                 rids=[output.rids[i]],
+                 finished_reasons=(
+                     [output.finished_reasons[i]]
+                     if len(output.finished_reasons) > i
+                     else None
+                 ),
+                 outputs=([output.outputs[i]] if len(output.outputs) > i else None),
+                 prompt_tokens=(
+                     [output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None
+                 ),
+                 completion_tokens=(
+                     [output.completion_tokens[i]]
+                     if len(output.completion_tokens) > i
+                     else None
+                 ),
+                 cached_tokens=(
+                     [output.cached_tokens[i]] if len(output.cached_tokens) > i else None
+                 ),
+             )
+         else:
+             new_output = output
+         return new_output
+
+     def get_worker_ids_from_req_rids(self, rids):
+         if isinstance(rids, list):
+             worker_ids = [int(rid.split("_")[0]) for rid in rids]
+         elif isinstance(rids, str):
+             worker_ids = [int(rids.split("_")[0])]
+         else:
+             worker_ids = []
+         return worker_ids
+
+     def multi_tokenizer_manager_event_loop(self):
+         """The event loop that handles requests, for multi tokenizer manager mode only"""
+         self.create_sockets_mapping()
+         while True:
+             recv_obj = self.recv_from_scheduler.recv_pyobj()
+             output = self._request_dispatcher(recv_obj)
+             if output is None:
+                 continue
+             # Extract worker_id from rid
+             if isinstance(recv_obj.rids, list):
+                 worker_ids = self.get_worker_ids_from_req_rids(recv_obj.rids)
+             else:
+                 raise RuntimeError(
+                     f"for tokenizer_worker_num > 1, recv_obj.rids must be a list"
+                 )
+
+             # Send data using the corresponding socket
+             for i, worker_id in enumerate(worker_ids):
+                 if isinstance(recv_obj, MultiTokenizerRegisterReq):
+                     if self.register_tokenizer_ipc(recv_obj, worker_id):
+                         logger.info(
+                             f"DetokenizerManager Created ZMQ socket for worker {worker_id}"
+                         )
+                     continue
+                 else:
+                     if worker_id not in self.tokenizer_mapping:
+                         logger.error(
+                             f"Tokenizer Worker ID {worker_id} not registered. Check if the server Process {worker_id} is alive"
+                         )
+                         continue
+                     new_output = self._handle_output_by_index(output, i)
+                     self.tokenizer_mapping[worker_id].send_pyobj(new_output)
+
+     def clear_tokenizer_mapping(self):
+         if hasattr(self, "tokenizer_mapping"):
+             for socket in self.tokenizer_mapping.values():
+                 try:
+                     socket.close()
+                 except Exception as e:
+                     logger.warning(f"Failed to close socket: {e}")
+             self.tokenizer_mapping.clear()
+
+
+ class MultiTokenizerRouter(TokenizerManager, MultiTokenizerMixin):
+     """A router to receive requests from MultiTokenizerManager"""
+
+     def __init__(
+         self,
+         server_args: ServerArgs,
+         port_args: PortArgs,
+     ):
+         self.server_args = server_args
+         context = zmq.asyncio.Context(3)
+         self.recv_from_detokenizer = get_zmq_socket(
+             context, zmq.PULL, port_args.tokenizer_ipc_name, True
+         )
+         self.send_to_scheduler = get_zmq_socket(
+             context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
+         )
+         self.receive_from_worker = get_zmq_socket(
+             context, zmq.PULL, port_args.tokenizer_worker_ipc_name, True
+         )
+         self._loop = asyncio.new_event_loop()
+         self._thread = threading.Thread(target=self._run_loop, daemon=True)
+         self._thread.start()
+         self._task = asyncio.run_coroutine_threadsafe(
+             self.router_worker_obj(), self._loop
+         )
+         # Start handle_loop simultaneously
+         self._handle_task = asyncio.run_coroutine_threadsafe(
+             print_exception_wrapper(self.handle_loop), self._loop
+         )
+         self.init_disaggregation()
+
+     def _run_loop(self):
+         self._loop.run_forever()
+
+     async def router_worker_obj(self):
+         while True:
+             recv_obj = await self.receive_from_worker.recv_pyobj()
+             await self.send_to_scheduler.send_pyobj(recv_obj)
+
+     async def handle_loop(self):
+         # special reqs will recv from scheduler, need to route to right worker
+         self.create_sockets_mapping()
+         while True:
+             recv_obj = await self.recv_from_detokenizer.recv_pyobj()
+             await self._distribute_result_to_workers(recv_obj)
+
+     async def _distribute_result_to_workers(self, recv_obj):
+         """Distribute result to corresponding workers based on rid"""
+         if isinstance(recv_obj, MultiTokenizerWarpper):
+             worker_ids = [recv_obj.worker_id]
+             recv_obj = recv_obj.obj
+         else:
+             worker_ids = self.get_worker_ids_from_req_rids(recv_obj.rids)
+
+         if len(worker_ids) == 0:
+             logger.error(f"Cannot find worker_id from rids {recv_obj.rids}")
+             return
+
+         # Distribute result to each worker
+         for i, worker_id in enumerate(worker_ids):
+             if isinstance(recv_obj, MultiTokenizerRegisterReq):
+                 if self.register_tokenizer_ipc(recv_obj, worker_id):
+                     logger.info(
+                         f"MultiTokenizerRouter Created ZMQ socket for worker {worker_id}"
+                     )
+                 continue
+             else:
+                 if worker_id not in self.tokenizer_mapping:
+                     logger.error(
+                         f"Tokenizer Worker ID {worker_id} not registered. Check if the server Process {worker_id} is alive"
+                     )
+                     continue
+                 new_recv_obj = self._handle_output_by_index(recv_obj, i)
+                 self.tokenizer_mapping[worker_id].send_pyobj(new_recv_obj)
+
+
+ class MultiTokenizerManager(TokenizerManager, MultiTokenizerMixin):
+     """Multi Process Tokenizer Manager that tokenizes the text."""
+
+     def __init__(
+         self,
+         server_args: ServerArgs,
+         port_args: PortArgs,
+     ):
+         # Prevent initializing the prefill bootstrap server again
+         disaggregation_mode = server_args.disaggregation_mode
+         server_args.disaggregation_mode = "null"
+         super().__init__(server_args, port_args)
+
+         self.worker_id = os.getpid()
+         self.tokenizer_ipc_name = port_args.tokenizer_ipc_name
+
+         # For PD disaggregation
+         self.server_args.disaggregation_mode = disaggregation_mode
+         self.disaggregation_mode = DisaggregationMode(
+             self.server_args.disaggregation_mode
+         )
+         self.disaggregation_transfer_backend = TransferBackend(
+             self.server_args.disaggregation_transfer_backend
+         )
+         # Communicator
+         self.register_multi_tokenizer_communicator = _Communicator(
+             self.send_to_scheduler, 2
+         )
+         self._result_dispatcher._mapping.append(
+             (
+                 MultiTokenizerRegisterReq,
+                 self.register_multi_tokenizer_communicator.handle_recv,
+             )
+         )
+
+     async def register_to_main_tokenizer_manager(self):
+         """Register this worker to the main TokenizerManager"""
+         # create a handle loop to receive messages from the main TokenizerManager
+         self.auto_create_handle_loop()
+         req = MultiTokenizerRegisterReq(rids=[f"{self.worker_id}_register"])
+         req.ipc_name = self.tokenizer_ipc_name
+         _Communicator.enable_multi_tokenizer = True
+         await self.register_multi_tokenizer_communicator(req)
+
+
+ async def print_exception_wrapper(func):
+     """
+     Sometimes an asyncio function does not print its exception.
+     We add another wrapper to handle the exception.
+     """
+     try:
+         await func()
+     except Exception:
+         traceback = get_exception_traceback()
+         logger.error(f"MultiTokenizerRouter hit an exception: {traceback}")
+         if hasattr(func, "__self__") and isinstance(
+             func.__self__, MultiTokenizerRouter
+         ):
+             func.__self__.dump_requests_before_crash()
+         kill_process_tree(os.getpid(), include_parent=True)
+         sys.exit(1)
+
+
+ def serialize_port_args(port_args: PortArgs) -> dict:
+     """Serialize PortArgs into a shareable dictionary"""
+     return {
+         "tokenizer_ipc_name": port_args.tokenizer_ipc_name,
+         "scheduler_input_ipc_name": port_args.scheduler_input_ipc_name,
+         "detokenizer_ipc_name": port_args.detokenizer_ipc_name,
+         "nccl_port": port_args.nccl_port,
+         "rpc_ipc_name": port_args.rpc_ipc_name,
+         "metrics_ipc_name": port_args.metrics_ipc_name,
+         "tokenizer_worker_ipc_name": port_args.tokenizer_worker_ipc_name,
+     }
+
+
+ def deserialize_data(port_args: dict, server_args: dict):
+     """Deserialize data from shared dictionaries"""
+     return PortArgs(**port_args), ServerArgs(**server_args)
+
+
+ def serialize_server_args(server_args: ServerArgs) -> dict:
+     """Serialize ServerArgs into a shareable dictionary"""
+     return dataclasses.asdict(server_args)
+
+
+ def serialize_scheduler_info(scheduler_info: Dict) -> dict:
+     """Serialize scheduler_info into a shareable dictionary"""
+     return scheduler_info
+
+
+ def deserialize_scheduler_info(data: dict) -> Dict:
+     """Deserialize scheduler_info from a shared dictionary"""
+     return data
+
+
+ def write_to_shared_memory(data: dict, name: str) -> shared_memory.SharedMemory:
+     """Write data to shared memory"""
+     serialized = json.dumps(data).encode("utf-8")
+     size = len(serialized)
+     try:
+         # Try to open existing shared memory
+         shm = shared_memory.SharedMemory(name=name)
+         # If size is insufficient, close and recreate
+         if shm.size < size:
+             shm.close()
+             shm.unlink()
+             shm = shared_memory.SharedMemory(create=True, size=size, name=name)
+     except FileNotFoundError:
+         # If not present, create new shared memory
+         shm = shared_memory.SharedMemory(create=True, size=size, name=name)
+
+     shm.buf[:size] = serialized
+     return shm
+
+
+ def read_from_shared_memory(name: str) -> dict:
+     """Read data from shared memory"""
+     try:
+         shm = shared_memory.SharedMemory(name=name)
+         data = json.loads(bytes(shm.buf).decode("utf-8"))
+         shm.close()
+         return data
+     except FileNotFoundError:
+         raise FileNotFoundError(f"Shared memory {name} not found")
+
+
+ def get_main_process_id() -> int:
+     """Get the main process ID"""
+     return multiprocessing.current_process()._parent_pid
+
+
+ def write_data_for_multi_tokenizer(
+     port_args: PortArgs, server_args: ServerArgs, scheduler_info: Dict
+ ):
+     """Write args information to shared memory for multi-tokenizer"""
+     # get main process ID
+     main_pid = get_main_process_id()
+     current_pid = os.getpid()
+     logger.info(f"main process ID: {main_pid}, current process ID: {current_pid}")
+
+     # Write port_args to shared memory
+     port_args_shm = write_to_shared_memory(
+         serialize_port_args(port_args), f"port_args_{current_pid}"
+     )
+     # Write server_args to shared memory
+     server_args_shm = write_to_shared_memory(
+         serialize_server_args(server_args), f"server_args_{current_pid}"
+     )
+     # Write scheduler_info to shared memory
+     scheduler_info_shm = write_to_shared_memory(
+         serialize_scheduler_info(scheduler_info), f"scheduler_info_{current_pid}"
+     )
+
+     port_args_shm.close()
+     server_args_shm.close()
+     scheduler_info_shm.close()
+
+     return port_args_shm, server_args_shm, scheduler_info_shm
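
For orientation, the launcher-to-worker handoff in this new module is plain multiprocessing.shared_memory plus JSON, with segment names keyed by PID. Below is a minimal, self-contained sketch of the same round trip; the demo_ name is illustrative and not part of the package, and the NUL-strip on the read side is a small robustness tweak over the module's own read path for platforms that pad segments to page size.

import json
import os
from multiprocessing import shared_memory

def write_json_shm(data: dict, name: str) -> shared_memory.SharedMemory:
    # JSON-encode and copy into a named segment, mirroring write_to_shared_memory.
    payload = json.dumps(data).encode("utf-8")
    try:
        shm = shared_memory.SharedMemory(name=name)
        if shm.size < len(payload):  # existing segment too small: recreate it
            shm.close()
            shm.unlink()
            shm = shared_memory.SharedMemory(create=True, size=len(payload), name=name)
    except FileNotFoundError:
        shm = shared_memory.SharedMemory(create=True, size=len(payload), name=name)
    shm.buf[: len(payload)] = payload
    return shm

def read_json_shm(name: str) -> dict:
    # Decode the segment as JSON, mirroring read_from_shared_memory.
    shm = shared_memory.SharedMemory(name=name)
    raw = bytes(shm.buf).rstrip(b"\x00")  # tolerate page-size padding
    shm.close()
    return json.loads(raw.decode("utf-8"))

if __name__ == "__main__":
    name = f"demo_port_args_{os.getpid()}"  # the module keys segments by PID
    shm = write_json_shm({"tokenizer_ipc_name": "ipc:///tmp/demo"}, name)
    print(read_json_shm(name))  # -> {'tokenizer_ipc_name': 'ipc:///tmp/demo'}
    shm.close()
    shm.unlink()  # release the segment once readers are done

The same PID convention drives result routing: each rid is prefixed with the owning worker's PID (e.g. f"{os.getpid()}_register" in register_to_main_tokenizer_manager), and get_worker_ids_from_req_rids recovers the destination with int(rid.split("_")[0]).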