sglang 0.5.2rc1__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (265)
  1. sglang/bench_one_batch_server.py +10 -1
  2. sglang/bench_serving.py +257 -29
  3. sglang/lang/interpreter.py +1 -1
  4. sglang/srt/configs/__init__.py +4 -0
  5. sglang/srt/configs/device_config.py +3 -1
  6. sglang/srt/configs/dots_vlm.py +139 -0
  7. sglang/srt/configs/internvl.py +6 -0
  8. sglang/srt/configs/load_config.py +1 -0
  9. sglang/srt/configs/model_config.py +50 -6
  10. sglang/srt/configs/qwen3_next.py +326 -0
  11. sglang/srt/connector/__init__.py +8 -1
  12. sglang/srt/connector/remote_instance.py +82 -0
  13. sglang/srt/constrained/base_grammar_backend.py +48 -12
  14. sglang/srt/constrained/llguidance_backend.py +0 -1
  15. sglang/srt/constrained/outlines_backend.py +0 -1
  16. sglang/srt/constrained/xgrammar_backend.py +28 -9
  17. sglang/srt/custom_op.py +11 -1
  18. sglang/srt/debug_utils/dump_comparator.py +81 -44
  19. sglang/srt/debug_utils/dump_loader.py +97 -0
  20. sglang/srt/debug_utils/dumper.py +11 -3
  21. sglang/srt/debug_utils/text_comparator.py +73 -11
  22. sglang/srt/disaggregation/base/conn.py +1 -1
  23. sglang/srt/disaggregation/common/conn.py +15 -12
  24. sglang/srt/disaggregation/decode.py +21 -10
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
  26. sglang/srt/disaggregation/fake/conn.py +1 -1
  27. sglang/srt/disaggregation/mini_lb.py +6 -445
  28. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  29. sglang/srt/disaggregation/nixl/conn.py +180 -16
  30. sglang/srt/disaggregation/prefill.py +5 -3
  31. sglang/srt/disaggregation/utils.py +5 -50
  32. sglang/srt/distributed/parallel_state.py +67 -43
  33. sglang/srt/entrypoints/engine.py +38 -17
  34. sglang/srt/entrypoints/grpc_request_manager.py +580 -0
  35. sglang/srt/entrypoints/grpc_server.py +680 -0
  36. sglang/srt/entrypoints/http_server.py +88 -53
  37. sglang/srt/entrypoints/openai/protocol.py +7 -4
  38. sglang/srt/entrypoints/openai/serving_base.py +46 -3
  39. sglang/srt/entrypoints/openai/serving_chat.py +39 -19
  40. sglang/srt/entrypoints/openai/serving_completions.py +15 -4
  41. sglang/srt/entrypoints/openai/serving_embedding.py +9 -4
  42. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  43. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  44. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  45. sglang/srt/eplb/eplb_manager.py +2 -2
  46. sglang/srt/eplb/expert_distribution.py +26 -13
  47. sglang/srt/eplb/expert_location.py +8 -3
  48. sglang/srt/eplb/expert_location_updater.py +1 -1
  49. sglang/srt/function_call/base_format_detector.py +3 -6
  50. sglang/srt/function_call/ebnf_composer.py +11 -9
  51. sglang/srt/function_call/function_call_parser.py +6 -0
  52. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  53. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  54. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  55. sglang/srt/grpc/__init__.py +1 -0
  56. sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
  57. sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
  58. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
  59. sglang/srt/hf_transformers_utils.py +4 -0
  60. sglang/srt/layers/activation.py +142 -9
  61. sglang/srt/layers/attention/aiter_backend.py +93 -68
  62. sglang/srt/layers/attention/ascend_backend.py +11 -4
  63. sglang/srt/layers/attention/fla/chunk.py +242 -0
  64. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  65. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  66. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  67. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  68. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  69. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  70. sglang/srt/layers/attention/fla/index.py +37 -0
  71. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  72. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  73. sglang/srt/layers/attention/fla/op.py +66 -0
  74. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  75. sglang/srt/layers/attention/fla/utils.py +331 -0
  76. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  77. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  78. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  79. sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
  80. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  81. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  82. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  83. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  84. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  85. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  86. sglang/srt/layers/attention/triton_backend.py +18 -1
  87. sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
  88. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  89. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  90. sglang/srt/layers/communicator.py +45 -7
  91. sglang/srt/layers/dp_attention.py +30 -1
  92. sglang/srt/layers/layernorm.py +32 -15
  93. sglang/srt/layers/linear.py +34 -3
  94. sglang/srt/layers/logits_processor.py +29 -10
  95. sglang/srt/layers/moe/__init__.py +2 -1
  96. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  97. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  98. sglang/srt/layers/moe/ep_moe/layer.py +182 -62
  99. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
  100. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  101. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  102. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  103. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  104. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
  105. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  106. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  107. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  108. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  113. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
  114. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  115. sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
  116. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  117. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  118. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  119. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  120. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  121. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  122. sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
  123. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  124. sglang/srt/layers/moe/topk.py +30 -9
  125. sglang/srt/layers/moe/utils.py +12 -7
  126. sglang/srt/layers/quantization/awq.py +19 -7
  127. sglang/srt/layers/quantization/base_config.py +11 -6
  128. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  129. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  130. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  131. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  132. sglang/srt/layers/quantization/fp8.py +76 -47
  133. sglang/srt/layers/quantization/fp8_utils.py +50 -31
  134. sglang/srt/layers/quantization/gptq.py +25 -17
  135. sglang/srt/layers/quantization/modelopt_quant.py +182 -49
  136. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  137. sglang/srt/layers/quantization/mxfp4.py +68 -41
  138. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  139. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  140. sglang/srt/layers/quantization/quark/utils.py +97 -0
  141. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  142. sglang/srt/layers/quantization/unquant.py +135 -47
  143. sglang/srt/layers/quantization/w4afp8.py +30 -17
  144. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  145. sglang/srt/layers/quantization/w8a8_int8.py +76 -38
  146. sglang/srt/layers/rocm_linear_utils.py +44 -0
  147. sglang/srt/layers/rotary_embedding.py +0 -18
  148. sglang/srt/layers/sampler.py +162 -18
  149. sglang/srt/lora/backend/base_backend.py +50 -8
  150. sglang/srt/lora/backend/triton_backend.py +90 -2
  151. sglang/srt/lora/layers.py +32 -0
  152. sglang/srt/lora/lora.py +4 -1
  153. sglang/srt/lora/lora_manager.py +35 -112
  154. sglang/srt/lora/mem_pool.py +24 -10
  155. sglang/srt/lora/utils.py +18 -9
  156. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  157. sglang/srt/managers/cache_controller.py +200 -199
  158. sglang/srt/managers/data_parallel_controller.py +105 -35
  159. sglang/srt/managers/detokenizer_manager.py +8 -4
  160. sglang/srt/managers/disagg_service.py +46 -0
  161. sglang/srt/managers/io_struct.py +199 -12
  162. sglang/srt/managers/mm_utils.py +1 -0
  163. sglang/srt/managers/multi_tokenizer_mixin.py +351 -397
  164. sglang/srt/managers/schedule_batch.py +77 -56
  165. sglang/srt/managers/schedule_policy.py +4 -3
  166. sglang/srt/managers/scheduler.py +191 -139
  167. sglang/srt/managers/scheduler_metrics_mixin.py +116 -9
  168. sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
  169. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  170. sglang/srt/managers/template_manager.py +3 -3
  171. sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
  172. sglang/srt/managers/tokenizer_manager.py +260 -519
  173. sglang/srt/managers/tp_worker.py +53 -4
  174. sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
  175. sglang/srt/mem_cache/allocator.py +1 -1
  176. sglang/srt/mem_cache/hicache_storage.py +18 -33
  177. sglang/srt/mem_cache/hiradix_cache.py +108 -48
  178. sglang/srt/mem_cache/memory_pool.py +347 -48
  179. sglang/srt/mem_cache/memory_pool_host.py +121 -57
  180. sglang/srt/mem_cache/radix_cache.py +0 -2
  181. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  182. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  183. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +95 -5
  184. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  185. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  186. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +81 -20
  187. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  188. sglang/srt/mem_cache/swa_radix_cache.py +0 -2
  189. sglang/srt/metrics/collector.py +502 -77
  190. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  191. sglang/srt/metrics/utils.py +48 -0
  192. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  193. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  194. sglang/srt/model_executor/forward_batch_info.py +75 -19
  195. sglang/srt/model_executor/model_runner.py +357 -30
  196. sglang/srt/model_loader/__init__.py +9 -3
  197. sglang/srt/model_loader/loader.py +128 -4
  198. sglang/srt/model_loader/weight_utils.py +2 -1
  199. sglang/srt/models/apertus.py +686 -0
  200. sglang/srt/models/bailing_moe.py +798 -218
  201. sglang/srt/models/bailing_moe_nextn.py +168 -0
  202. sglang/srt/models/deepseek_v2.py +346 -48
  203. sglang/srt/models/dots_vlm.py +174 -0
  204. sglang/srt/models/dots_vlm_vit.py +337 -0
  205. sglang/srt/models/ernie4.py +1 -1
  206. sglang/srt/models/gemma3n_mm.py +1 -1
  207. sglang/srt/models/glm4_moe.py +11 -2
  208. sglang/srt/models/glm4v.py +4 -2
  209. sglang/srt/models/glm4v_moe.py +3 -0
  210. sglang/srt/models/gpt_oss.py +1 -1
  211. sglang/srt/models/internvl.py +28 -0
  212. sglang/srt/models/llama4.py +9 -0
  213. sglang/srt/models/llama_eagle3.py +13 -0
  214. sglang/srt/models/longcat_flash.py +2 -2
  215. sglang/srt/models/minicpmv.py +165 -3
  216. sglang/srt/models/mllama4.py +25 -0
  217. sglang/srt/models/opt.py +637 -0
  218. sglang/srt/models/qwen2.py +7 -0
  219. sglang/srt/models/qwen2_5_vl.py +27 -3
  220. sglang/srt/models/qwen2_moe.py +60 -13
  221. sglang/srt/models/qwen3.py +8 -2
  222. sglang/srt/models/qwen3_moe.py +40 -9
  223. sglang/srt/models/qwen3_next.py +1042 -0
  224. sglang/srt/models/qwen3_next_mtp.py +112 -0
  225. sglang/srt/models/step3_vl.py +1 -1
  226. sglang/srt/models/torch_native_llama.py +1 -1
  227. sglang/srt/multimodal/processors/dots_vlm.py +99 -0
  228. sglang/srt/multimodal/processors/glm4v.py +9 -9
  229. sglang/srt/multimodal/processors/internvl.py +141 -129
  230. sglang/srt/multimodal/processors/qwen_vl.py +15 -5
  231. sglang/srt/offloader.py +27 -3
  232. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  233. sglang/srt/remote_instance_weight_loader_utils.py +69 -0
  234. sglang/srt/sampling/sampling_batch_info.py +18 -15
  235. sglang/srt/server_args.py +355 -37
  236. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  237. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  238. sglang/srt/speculative/eagle_utils.py +0 -2
  239. sglang/srt/speculative/eagle_worker.py +197 -112
  240. sglang/srt/speculative/spec_info.py +5 -0
  241. sglang/srt/speculative/standalone_worker.py +109 -0
  242. sglang/srt/tracing/trace.py +552 -0
  243. sglang/srt/utils.py +46 -3
  244. sglang/srt/weight_sync/utils.py +1 -1
  245. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  246. sglang/test/few_shot_gsm8k.py +1 -0
  247. sglang/test/runners.py +4 -0
  248. sglang/test/test_cutlass_moe.py +24 -6
  249. sglang/test/test_disaggregation_utils.py +66 -0
  250. sglang/test/test_fp4_moe.py +370 -1
  251. sglang/test/test_utils.py +28 -1
  252. sglang/utils.py +12 -0
  253. sglang/version.py +1 -1
  254. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
  255. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +263 -200
  256. sglang/srt/disaggregation/launch_lb.py +0 -118
  257. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  258. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  259. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  260. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  261. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  262. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  263. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
  264. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
  265. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
@@ -13,29 +13,32 @@
13
13
  # ==============================================================================
14
14
  """MultiTokenizerMixin is a class that provides nesscary methods for MultiTokenizerManager and DetokenizerManager."""
15
15
  import asyncio
16
- import dataclasses
17
- import json
18
16
  import logging
19
17
  import multiprocessing as multiprocessing
20
18
  import os
19
+ import pickle
21
20
  import sys
22
21
  import threading
22
+ from functools import partialmethod
23
23
  from multiprocessing import shared_memory
24
- from typing import Dict
24
+ from typing import Any, Dict
25
25
 
26
+ import setproctitle
26
27
  import zmq
27
28
  import zmq.asyncio
28
29
 
29
30
  from sglang.srt.disaggregation.utils import DisaggregationMode, TransferBackend
31
+ from sglang.srt.managers.disagg_service import start_disagg_service
30
32
  from sglang.srt.managers.io_struct import (
31
33
  BatchEmbeddingOut,
32
34
  BatchMultimodalOut,
33
35
  BatchStrOut,
34
36
  BatchTokenIDOut,
35
37
  MultiTokenizerRegisterReq,
36
- MultiTokenizerWarpper,
38
+ MultiTokenizerWrapper,
37
39
  )
38
- from sglang.srt.managers.tokenizer_manager import TokenizerManager, _Communicator
40
+ from sglang.srt.managers.tokenizer_communicator_mixin import _Communicator
41
+ from sglang.srt.managers.tokenizer_manager import TokenizerManager
39
42
  from sglang.srt.server_args import PortArgs, ServerArgs
40
43
  from sglang.srt.utils import get_zmq_socket, kill_process_tree
41
44
  from sglang.utils import get_exception_traceback
@@ -43,302 +46,304 @@ from sglang.utils import get_exception_traceback
43
46
  logger = logging.getLogger(__name__)
44
47
 
45
48
 
46
- class MultiTokenizerMixin:
47
- """Mixin class for MultiTokenizerManager and DetokenizerManager"""
49
+ class SocketMapping:
50
+ def __init__(self):
51
+ self._zmq_context = zmq.Context()
52
+ self._mapping: Dict[str, zmq.Socket] = {}
48
53
 
49
- def create_sockets_mapping(self):
50
- if not hasattr(self, "tokenizer_mapping"):
51
- self.tokenizer_mapping = {}
52
- # Create ZMQ context if needed
53
- if not hasattr(self, "_zmq_context"):
54
- self._zmq_context = zmq.Context()
54
+ def clear_all_sockets(self):
55
+ for socket in self._mapping.values():
56
+ socket.close()
57
+ self._mapping.clear()
55
58
 
56
- def init_tokenizer_mapping(
57
- self, recv_obj: MultiTokenizerRegisterReq, worker_id: str
59
+ def register_ipc_mapping(
60
+ self, recv_obj: MultiTokenizerRegisterReq, worker_id: str, is_tokenizer: bool
58
61
  ):
59
- """init tokenizer mapping from register request"""
60
- ipc_name = recv_obj.ipc_name
61
- worker_id_int = int(worker_id)
62
-
63
- if worker_id_int not in self.tokenizer_mapping:
64
- socket = get_zmq_socket(self._zmq_context, zmq.PUSH, ipc_name, False)
65
- self.tokenizer_mapping[worker_id_int] = socket
66
- self.tokenizer_mapping[worker_id_int].send_pyobj(recv_obj)
67
- return True
68
- else:
69
- return False
70
-
71
- def register_tokenizer_ipc(self, recv_obj, worker_id):
72
- if worker_id not in self.tokenizer_mapping:
73
- # register the worker if not already done
74
- if isinstance(recv_obj, MultiTokenizerRegisterReq):
75
- return self.init_tokenizer_mapping(recv_obj, worker_id)
76
- else:
77
- logger.error(
78
- f"Worker {worker_id} not registered and not found in tokenizer mapping . "
79
- "Please ensure the worker is registered correctly."
80
- )
81
- return False
82
-
83
- def _handle_output_by_index(self, output, i):
84
- """NOTE: A maintainable method is better here."""
85
- if isinstance(output, BatchTokenIDOut):
86
- new_output = BatchTokenIDOut(
87
- rids=[output.rids[i]],
88
- finished_reasons=(
89
- [output.finished_reasons[i]]
90
- if len(output.finished_reasons) > i
91
- else None
92
- ),
93
- decoded_texts=(
94
- [output.decoded_texts[i]] if len(output.decoded_texts) > i else None
95
- ),
96
- decode_ids=(
97
- [output.decode_ids[i]] if len(output.decode_ids) > i else None
98
- ),
99
- read_offsets=(
100
- [output.read_offsets[i]] if len(output.read_offsets) > i else None
101
- ),
102
- output_ids=(
103
- [output.output_ids[i]]
104
- if output.output_ids and len(output.output_ids) > i
105
- else None
106
- ),
107
- skip_special_tokens=(
108
- [output.skip_special_tokens[i]]
109
- if len(output.skip_special_tokens) > i
110
- else None
111
- ),
112
- spaces_between_special_tokens=(
113
- [output.spaces_between_special_tokens[i]]
114
- if len(output.spaces_between_special_tokens) > i
115
- else None
116
- ),
117
- no_stop_trim=(
118
- [output.no_stop_trim[i]] if len(output.no_stop_trim) > i else None
119
- ),
120
- prompt_tokens=(
121
- [output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None
122
- ),
123
- completion_tokens=(
124
- [output.completion_tokens[i]]
125
- if len(output.completion_tokens) > i
126
- else None
127
- ),
128
- cached_tokens=(
129
- [output.cached_tokens[i]] if len(output.cached_tokens) > i else None
130
- ),
131
- spec_verify_ct=(
132
- [output.spec_verify_ct[i]]
133
- if len(output.spec_verify_ct) > i
134
- else None
135
- ),
136
- input_token_logprobs_val=(
137
- [output.input_token_logprobs_val[i]]
138
- if output.input_token_logprobs_val
139
- else None
140
- ),
141
- input_token_logprobs_idx=(
142
- [output.input_token_logprobs_idx[i]]
143
- if output.input_token_logprobs_idx
144
- else None
145
- ),
146
- output_token_logprobs_val=(
147
- [output.output_token_logprobs_val[i]]
148
- if output.output_token_logprobs_val
149
- else None
150
- ),
151
- output_token_logprobs_idx=(
152
- [output.output_token_logprobs_idx[i]]
153
- if output.output_token_logprobs_idx
154
- else None
155
- ),
156
- input_top_logprobs_val=(
157
- [output.input_top_logprobs_val[i]]
158
- if output.input_top_logprobs_val
159
- else None
160
- ),
161
- input_top_logprobs_idx=(
162
- [output.input_top_logprobs_idx[i]]
163
- if output.input_top_logprobs_idx
164
- else None
165
- ),
166
- output_top_logprobs_val=(
167
- [output.output_top_logprobs_val[i]]
168
- if output.output_top_logprobs_val
169
- else None
170
- ),
171
- output_top_logprobs_idx=(
172
- [output.output_top_logprobs_idx[i]]
173
- if output.output_top_logprobs_idx
174
- else None
175
- ),
176
- input_token_ids_logprobs_val=(
177
- [output.input_token_ids_logprobs_val[i]]
178
- if output.input_token_ids_logprobs_val
179
- else None
180
- ),
181
- input_token_ids_logprobs_idx=(
182
- [output.input_token_ids_logprobs_idx[i]]
183
- if output.input_token_ids_logprobs_idx
184
- else None
185
- ),
186
- output_token_ids_logprobs_val=(
187
- [output.output_token_ids_logprobs_val[i]]
188
- if output.output_token_ids_logprobs_val
189
- else None
190
- ),
191
- output_token_ids_logprobs_idx=(
192
- [output.output_token_ids_logprobs_idx[i]]
193
- if output.output_token_ids_logprobs_idx
194
- else None
195
- ),
196
- output_hidden_states=(
197
- [output.output_hidden_states[i]]
198
- if output.output_hidden_states
199
- else None
200
- ),
201
- )
202
- elif isinstance(output, BatchEmbeddingOut):
203
- new_output = BatchEmbeddingOut(
204
- rids=[output.rids[i]],
205
- finished_reasons=(
206
- [output.finished_reasons[i]]
207
- if len(output.finished_reasons) > i
208
- else None
209
- ),
210
- embeddings=(
211
- [output.embeddings[i]] if len(output.embeddings) > i else None
212
- ),
213
- prompt_tokens=(
214
- [output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None
215
- ),
216
- cached_tokens=(
217
- [output.cached_tokens[i]] if len(output.cached_tokens) > i else None
218
- ),
62
+ type_str = "tokenizer" if is_tokenizer else "detokenizer"
63
+ if worker_id in self._mapping:
64
+ logger.warning(
65
+ f"{type_str} already registered with worker {worker_id}, skipping..."
219
66
  )
220
- elif isinstance(output, BatchStrOut):
221
- new_output = BatchStrOut(
222
- rids=[output.rids[i]],
223
- finished_reasons=(
224
- [output.finished_reasons[i]]
225
- if len(output.finished_reasons) > i
226
- else None
227
- ),
228
- output_strs=(
229
- [output.output_strs[i]] if len(output.output_strs) > i else None
230
- ),
231
- output_ids=(
232
- [output.output_ids[i]]
233
- if output.output_ids and len(output.output_ids) > i
234
- else None
235
- ),
236
- prompt_tokens=(
237
- [output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None
238
- ),
239
- completion_tokens=(
240
- [output.completion_tokens[i]]
241
- if len(output.completion_tokens) > i
242
- else None
243
- ),
244
- cached_tokens=(
245
- [output.cached_tokens[i]] if len(output.cached_tokens) > i else None
246
- ),
247
- spec_verify_ct=(
248
- [output.spec_verify_ct[i]]
249
- if len(output.spec_verify_ct) > i
250
- else None
251
- ),
252
- input_token_logprobs_val=(
253
- [output.input_token_logprobs_val[i]]
254
- if output.input_token_logprobs_val
255
- else None
256
- ),
257
- input_token_logprobs_idx=(
258
- [output.input_token_logprobs_idx[i]]
259
- if output.input_token_logprobs_idx
260
- else None
261
- ),
262
- output_token_logprobs_val=(
263
- [output.output_token_logprobs_val[i]]
264
- if output.output_token_logprobs_val
265
- else None
266
- ),
267
- output_token_logprobs_idx=(
268
- [output.output_token_logprobs_idx[i]]
269
- if output.output_token_logprobs_idx
270
- else None
271
- ),
272
- input_top_logprobs_val=(
273
- [output.input_top_logprobs_val[i]]
274
- if output.input_top_logprobs_val
275
- else None
276
- ),
277
- input_top_logprobs_idx=(
278
- [output.input_top_logprobs_idx[i]]
279
- if output.input_top_logprobs_idx
280
- else None
281
- ),
282
- output_top_logprobs_val=(
283
- [output.output_top_logprobs_val[i]]
284
- if output.output_top_logprobs_val
285
- else None
286
- ),
287
- output_top_logprobs_idx=(
288
- [output.output_top_logprobs_idx[i]]
289
- if output.output_top_logprobs_idx
290
- else None
291
- ),
292
- input_token_ids_logprobs_val=(
293
- [output.input_token_ids_logprobs_val[i]]
294
- if output.input_token_ids_logprobs_val
295
- else None
296
- ),
297
- input_token_ids_logprobs_idx=(
298
- [output.input_token_ids_logprobs_idx[i]]
299
- if output.input_token_ids_logprobs_idx
300
- else None
301
- ),
302
- output_token_ids_logprobs_val=(
303
- [output.output_token_ids_logprobs_val[i]]
304
- if output.output_token_ids_logprobs_val
305
- else None
306
- ),
307
- output_token_ids_logprobs_idx=(
308
- [output.output_token_ids_logprobs_idx[i]]
309
- if output.output_token_ids_logprobs_idx
310
- else None
311
- ),
312
- output_hidden_states=(
313
- [output.output_hidden_states[i]]
314
- if output.output_hidden_states
315
- else None
316
- ),
317
- )
318
- elif isinstance(output, BatchMultimodalOut):
319
- new_output = BatchMultimodalOut(
320
- rids=[output.rids[i]],
321
- finished_reasons=(
322
- [output.finished_reasons[i]]
323
- if len(output.finished_reasons) > i
324
- else None
325
- ),
326
- outputs=([output.outputs[i]] if len(output.outputs) > i else None),
327
- prompt_tokens=(
328
- [output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None
329
- ),
330
- completion_tokens=(
331
- [output.completion_tokens[i]]
332
- if len(output.completion_tokens) > i
333
- else None
334
- ),
335
- cached_tokens=(
336
- [output.cached_tokens[i]] if len(output.cached_tokens) > i else None
337
- ),
67
+ return
68
+ logger.info(
69
+ f"{type_str} not registered with worker {worker_id}, registering..."
70
+ )
71
+ socket = get_zmq_socket(self._zmq_context, zmq.PUSH, recv_obj.ipc_name, False)
72
+ self._mapping[worker_id] = socket
73
+ self._mapping[worker_id].send_pyobj(recv_obj)
74
+
75
+ def send_output(self, worker_id: str, output: Any):
76
+ if worker_id not in self._mapping:
77
+ logger.error(
78
+ f"worker ID {worker_id} not registered. Check if the server Process is alive"
338
79
  )
339
- else:
340
- new_output = output
341
- return new_output
80
+ return
81
+ self._mapping[worker_id].send_pyobj(output)
82
+
83
+
84
+ def _handle_output_by_index(output, i):
85
+ """NOTE: A maintainable method is better here."""
86
+ if isinstance(output, BatchTokenIDOut):
87
+ new_output = BatchTokenIDOut(
88
+ rids=[output.rids[i]],
89
+ finished_reasons=(
90
+ [output.finished_reasons[i]]
91
+ if len(output.finished_reasons) > i
92
+ else None
93
+ ),
94
+ decoded_texts=(
95
+ [output.decoded_texts[i]] if len(output.decoded_texts) > i else None
96
+ ),
97
+ decode_ids=([output.decode_ids[i]] if len(output.decode_ids) > i else None),
98
+ read_offsets=(
99
+ [output.read_offsets[i]] if len(output.read_offsets) > i else None
100
+ ),
101
+ output_ids=(
102
+ [output.output_ids[i]]
103
+ if output.output_ids and len(output.output_ids) > i
104
+ else None
105
+ ),
106
+ skip_special_tokens=(
107
+ [output.skip_special_tokens[i]]
108
+ if len(output.skip_special_tokens) > i
109
+ else None
110
+ ),
111
+ spaces_between_special_tokens=(
112
+ [output.spaces_between_special_tokens[i]]
113
+ if len(output.spaces_between_special_tokens) > i
114
+ else None
115
+ ),
116
+ no_stop_trim=(
117
+ [output.no_stop_trim[i]] if len(output.no_stop_trim) > i else None
118
+ ),
119
+ prompt_tokens=(
120
+ [output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None
121
+ ),
122
+ completion_tokens=(
123
+ [output.completion_tokens[i]]
124
+ if len(output.completion_tokens) > i
125
+ else None
126
+ ),
127
+ cached_tokens=(
128
+ [output.cached_tokens[i]] if len(output.cached_tokens) > i else None
129
+ ),
130
+ spec_verify_ct=(
131
+ [output.spec_verify_ct[i]] if len(output.spec_verify_ct) > i else None
132
+ ),
133
+ input_token_logprobs_val=(
134
+ [output.input_token_logprobs_val[i]]
135
+ if output.input_token_logprobs_val
136
+ else None
137
+ ),
138
+ input_token_logprobs_idx=(
139
+ [output.input_token_logprobs_idx[i]]
140
+ if output.input_token_logprobs_idx
141
+ else None
142
+ ),
143
+ output_token_logprobs_val=(
144
+ [output.output_token_logprobs_val[i]]
145
+ if output.output_token_logprobs_val
146
+ else None
147
+ ),
148
+ output_token_logprobs_idx=(
149
+ [output.output_token_logprobs_idx[i]]
150
+ if output.output_token_logprobs_idx
151
+ else None
152
+ ),
153
+ input_top_logprobs_val=(
154
+ [output.input_top_logprobs_val[i]]
155
+ if output.input_top_logprobs_val
156
+ else None
157
+ ),
158
+ input_top_logprobs_idx=(
159
+ [output.input_top_logprobs_idx[i]]
160
+ if output.input_top_logprobs_idx
161
+ else None
162
+ ),
163
+ output_top_logprobs_val=(
164
+ [output.output_top_logprobs_val[i]]
165
+ if output.output_top_logprobs_val
166
+ else None
167
+ ),
168
+ output_top_logprobs_idx=(
169
+ [output.output_top_logprobs_idx[i]]
170
+ if output.output_top_logprobs_idx
171
+ else None
172
+ ),
173
+ input_token_ids_logprobs_val=(
174
+ [output.input_token_ids_logprobs_val[i]]
175
+ if output.input_token_ids_logprobs_val
176
+ else None
177
+ ),
178
+ input_token_ids_logprobs_idx=(
179
+ [output.input_token_ids_logprobs_idx[i]]
180
+ if output.input_token_ids_logprobs_idx
181
+ else None
182
+ ),
183
+ output_token_ids_logprobs_val=(
184
+ [output.output_token_ids_logprobs_val[i]]
185
+ if output.output_token_ids_logprobs_val
186
+ else None
187
+ ),
188
+ output_token_ids_logprobs_idx=(
189
+ [output.output_token_ids_logprobs_idx[i]]
190
+ if output.output_token_ids_logprobs_idx
191
+ else None
192
+ ),
193
+ output_hidden_states=(
194
+ [output.output_hidden_states[i]]
195
+ if output.output_hidden_states
196
+ else None
197
+ ),
198
+ placeholder_tokens_idx=None,
199
+ placeholder_tokens_val=None,
200
+ )
201
+ elif isinstance(output, BatchEmbeddingOut):
202
+ new_output = BatchEmbeddingOut(
203
+ rids=[output.rids[i]],
204
+ finished_reasons=(
205
+ [output.finished_reasons[i]]
206
+ if len(output.finished_reasons) > i
207
+ else None
208
+ ),
209
+ embeddings=([output.embeddings[i]] if len(output.embeddings) > i else None),
210
+ prompt_tokens=(
211
+ [output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None
212
+ ),
213
+ cached_tokens=(
214
+ [output.cached_tokens[i]] if len(output.cached_tokens) > i else None
215
+ ),
216
+ placeholder_tokens_idx=None,
217
+ placeholder_tokens_val=None,
218
+ )
219
+ elif isinstance(output, BatchStrOut):
220
+ new_output = BatchStrOut(
221
+ rids=[output.rids[i]],
222
+ finished_reasons=(
223
+ [output.finished_reasons[i]]
224
+ if len(output.finished_reasons) > i
225
+ else None
226
+ ),
227
+ output_strs=(
228
+ [output.output_strs[i]] if len(output.output_strs) > i else None
229
+ ),
230
+ output_ids=(
231
+ [output.output_ids[i]]
232
+ if output.output_ids and len(output.output_ids) > i
233
+ else None
234
+ ),
235
+ prompt_tokens=(
236
+ [output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None
237
+ ),
238
+ completion_tokens=(
239
+ [output.completion_tokens[i]]
240
+ if len(output.completion_tokens) > i
241
+ else None
242
+ ),
243
+ cached_tokens=(
244
+ [output.cached_tokens[i]] if len(output.cached_tokens) > i else None
245
+ ),
246
+ spec_verify_ct=(
247
+ [output.spec_verify_ct[i]] if len(output.spec_verify_ct) > i else None
248
+ ),
249
+ input_token_logprobs_val=(
250
+ [output.input_token_logprobs_val[i]]
251
+ if output.input_token_logprobs_val
252
+ else None
253
+ ),
254
+ input_token_logprobs_idx=(
255
+ [output.input_token_logprobs_idx[i]]
256
+ if output.input_token_logprobs_idx
257
+ else None
258
+ ),
259
+ output_token_logprobs_val=(
260
+ [output.output_token_logprobs_val[i]]
261
+ if output.output_token_logprobs_val
262
+ else None
263
+ ),
264
+ output_token_logprobs_idx=(
265
+ [output.output_token_logprobs_idx[i]]
266
+ if output.output_token_logprobs_idx
267
+ else None
268
+ ),
269
+ input_top_logprobs_val=(
270
+ [output.input_top_logprobs_val[i]]
271
+ if output.input_top_logprobs_val
272
+ else None
273
+ ),
274
+ input_top_logprobs_idx=(
275
+ [output.input_top_logprobs_idx[i]]
276
+ if output.input_top_logprobs_idx
277
+ else None
278
+ ),
279
+ output_top_logprobs_val=(
280
+ [output.output_top_logprobs_val[i]]
281
+ if output.output_top_logprobs_val
282
+ else None
283
+ ),
284
+ output_top_logprobs_idx=(
285
+ [output.output_top_logprobs_idx[i]]
286
+ if output.output_top_logprobs_idx
287
+ else None
288
+ ),
289
+ input_token_ids_logprobs_val=(
290
+ [output.input_token_ids_logprobs_val[i]]
291
+ if output.input_token_ids_logprobs_val
292
+ else None
293
+ ),
294
+ input_token_ids_logprobs_idx=(
295
+ [output.input_token_ids_logprobs_idx[i]]
296
+ if output.input_token_ids_logprobs_idx
297
+ else None
298
+ ),
299
+ output_token_ids_logprobs_val=(
300
+ [output.output_token_ids_logprobs_val[i]]
301
+ if output.output_token_ids_logprobs_val
302
+ else None
303
+ ),
304
+ output_token_ids_logprobs_idx=(
305
+ [output.output_token_ids_logprobs_idx[i]]
306
+ if output.output_token_ids_logprobs_idx
307
+ else None
308
+ ),
309
+ output_hidden_states=(
310
+ [output.output_hidden_states[i]]
311
+ if output.output_hidden_states
312
+ else None
313
+ ),
314
+ placeholder_tokens_idx=None,
315
+ placeholder_tokens_val=None,
316
+ )
317
+ elif isinstance(output, BatchMultimodalOut):
318
+ new_output = BatchMultimodalOut(
319
+ rids=[output.rids[i]],
320
+ finished_reasons=(
321
+ [output.finished_reasons[i]]
322
+ if len(output.finished_reasons) > i
323
+ else None
324
+ ),
325
+ outputs=([output.outputs[i]] if len(output.outputs) > i else None),
326
+ prompt_tokens=(
327
+ [output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None
328
+ ),
329
+ completion_tokens=(
330
+ [output.completion_tokens[i]]
331
+ if len(output.completion_tokens) > i
332
+ else None
333
+ ),
334
+ cached_tokens=(
335
+ [output.cached_tokens[i]] if len(output.cached_tokens) > i else None
336
+ ),
337
+ placeholder_tokens_idx=None,
338
+ placeholder_tokens_val=None,
339
+ )
340
+ else:
341
+ new_output = output
342
+ return new_output
343
+
344
+
345
+ class MultiHttpWorkerDetokenizerMixin:
346
+ """Mixin class for MultiTokenizerManager and DetokenizerManager"""
342
347
 
343
348
  def get_worker_ids_from_req_rids(self, rids):
344
349
  if isinstance(rids, list):
@@ -349,9 +354,13 @@ class MultiTokenizerMixin:
349
354
  worker_ids = []
350
355
  return worker_ids
351
356
 
352
- def multi_tokenizer_manager_event_loop(self):
353
- """The event loop that handles requests, for multi tokenizer manager mode only"""
354
- self.create_sockets_mapping()
357
+ def maybe_clear_socket_mapping(self):
358
+ if hasattr(self, "socket_mapping"):
359
+ self.socket_mapping.clear_all_sockets()
360
+
361
+ def multi_http_worker_event_loop(self):
362
+ """The event loop that handles requests, for multi multi-http-worker mode"""
363
+ self.socket_mapping = SocketMapping()
355
364
  while True:
356
365
  recv_obj = self.recv_from_scheduler.recv_pyobj()
357
366
  output = self._request_dispatcher(recv_obj)
@@ -368,31 +377,15 @@ class MultiTokenizerMixin:
368
377
  # Send data using the corresponding socket
369
378
  for i, worker_id in enumerate(worker_ids):
370
379
  if isinstance(recv_obj, MultiTokenizerRegisterReq):
371
- if self.register_tokenizer_ipc(recv_obj, worker_id):
372
- logger.info(
373
- f"DetokenizerManager Created ZMQ socket for worker {worker_id}"
374
- )
375
- continue
380
+ self.socket_mapping.register_ipc_mapping(
381
+ recv_obj, worker_id, is_tokenizer=False
382
+ )
376
383
  else:
377
- if worker_id not in self.tokenizer_mapping:
378
- logger.error(
379
- f"Tokenizer Worker ID {worker_id} not registered. Check if the server Process {worker_id} is alive"
380
- )
381
- continue
382
- new_output = self._handle_output_by_index(output, i)
383
- self.tokenizer_mapping[worker_id].send_pyobj(new_output)
384
-
385
- def clear_tokenizer_mapping(self):
386
- if hasattr(self, "tokenizer_mapping"):
387
- for socket in self.tokenizer_mapping.values():
388
- try:
389
- socket.close()
390
- except Exception as e:
391
- logger.warning(f"Failed to close socket: {e}")
392
- self.tokenizer_mapping.clear()
393
-
394
-
395
- class MultiTokenizerRouter(TokenizerManager, MultiTokenizerMixin):
384
+ new_output = _handle_output_by_index(output, i)
385
+ self.socket_mapping.send_output(worker_id, new_output)
386
+
387
+
388
+ class MultiTokenizerRouter:
396
389
  """A router to receive requests from MultiTokenizerManager"""
397
390
 
398
391
  def __init__(
@@ -421,7 +414,7 @@ class MultiTokenizerRouter(TokenizerManager, MultiTokenizerMixin):
421
414
  self._handle_task = asyncio.run_coroutine_threadsafe(
422
415
  print_exception_wrapper(self.handle_loop), self._loop
423
416
  )
424
- self.init_disaggregation()
417
+ self.disaggregation_bootstrap_server = start_disagg_service(self.server_args)
425
418
 
426
419
  def _run_loop(self):
427
420
  self._loop.run_forever()
@@ -433,14 +426,14 @@ class MultiTokenizerRouter(TokenizerManager, MultiTokenizerMixin):
433
426
 
434
427
  async def handle_loop(self):
435
428
  # special reqs will recv from scheduler, need to route to right worker
436
- self.create_sockets_mapping()
429
+ self.socket_mapping = SocketMapping()
437
430
  while True:
438
431
  recv_obj = await self.recv_from_detokenizer.recv_pyobj()
439
432
  await self._distribute_result_to_workers(recv_obj)
440
433
 
441
434
  async def _distribute_result_to_workers(self, recv_obj):
442
435
  """Distribute result to corresponding workers based on rid"""
443
- if isinstance(recv_obj, MultiTokenizerWarpper):
436
+ if isinstance(recv_obj, MultiTokenizerWrapper):
444
437
  worker_ids = [recv_obj.worker_id]
445
438
  recv_obj = recv_obj.obj
446
439
  else:
@@ -453,22 +446,15 @@ class MultiTokenizerRouter(TokenizerManager, MultiTokenizerMixin):
453
446
  # Distribute result to each worker
454
447
  for i, worker_id in enumerate(worker_ids):
455
448
  if isinstance(recv_obj, MultiTokenizerRegisterReq):
456
- if self.register_tokenizer_ipc(recv_obj, worker_id):
457
- logger.info(
458
- f"MultiTokenizerRouter Created ZMQ socket for worker {worker_id}"
459
- )
460
- continue
449
+ self.socket_mapping.register_ipc_mapping(
450
+ recv_obj, worker_id, is_tokenizer=True
451
+ )
461
452
  else:
462
- if worker_id not in self.tokenizer_mapping:
463
- logger.error(
464
- f"Tokenizer Worker ID {worker_id} not registered. Check if the server Process {worker_id} is alive"
465
- )
466
- continue
467
- new_recv_obj = self._handle_output_by_index(recv_obj, i)
468
- self.tokenizer_mapping[worker_id].send_pyobj(new_recv_obj)
453
+ new_recv_obj = _handle_output_by_index(recv_obj, i)
454
+ self.socket_mapping.send_output(worker_id, new_recv_obj)
469
455
 
470
456
 
471
- class MultiTokenizerManager(TokenizerManager, MultiTokenizerMixin):
457
+ class MultiTokenizerManager(TokenizerManager):
472
458
  """Multi Process Tokenizer Manager that tokenizes the text."""
473
459
 
474
460
  def __init__(
@@ -476,6 +462,7 @@ class MultiTokenizerManager(TokenizerManager, MultiTokenizerMixin):
476
462
  server_args: ServerArgs,
477
463
  port_args: PortArgs,
478
464
  ):
465
+ setproctitle.setproctitle(f"sglang::tokenizer_worker:{os.getpid()}")
479
466
  # prevent init prefill bootstrapserver again
480
467
  disaggregation_mode = server_args.disaggregation_mode
481
468
  server_args.disaggregation_mode = "null"
@@ -531,42 +518,14 @@ async def print_exception_wrapper(func):
531
518
  sys.exit(1)
532
519
 
533
520
 
534
- def serialize_port_args(port_args: PortArgs) -> dict:
535
- """Serialize PortArgs into a shareable dictionary"""
536
- return {
537
- "tokenizer_ipc_name": port_args.tokenizer_ipc_name,
538
- "scheduler_input_ipc_name": port_args.scheduler_input_ipc_name,
539
- "detokenizer_ipc_name": port_args.detokenizer_ipc_name,
540
- "nccl_port": port_args.nccl_port,
541
- "rpc_ipc_name": port_args.rpc_ipc_name,
542
- "metrics_ipc_name": port_args.metrics_ipc_name,
543
- "tokenizer_worker_ipc_name": port_args.tokenizer_worker_ipc_name,
544
- }
545
-
546
-
547
- def deserialize_data(port_args: dict, server_args: dict):
548
- """Deserialize data from shared dictionaries"""
549
- return PortArgs(**port_args), ServerArgs(**server_args)
550
-
551
-
552
- def serialize_server_args(server_args: ServerArgs) -> dict:
553
- """Serialize ServerArgs into a shareable dictionary"""
554
- return dataclasses.asdict(server_args)
555
-
556
-
557
- def serialize_scheduler_info(scheduler_info: Dict) -> dict:
558
- """Serialize scheduler_info into a shareable dictionary"""
559
- return scheduler_info
560
-
561
-
562
- def deserialize_scheduler_info(data: dict) -> Dict:
563
- """Deserialize scheduler_info from a shared dictionary"""
564
- return data
521
+ def get_main_process_id() -> int:
522
+ """Get the main process ID"""
523
+ return multiprocessing.current_process()._parent_pid
565
524
 
566
525
 
567
- def write_to_shared_memory(data: dict, name: str) -> shared_memory.SharedMemory:
526
+ def write_to_shared_memory(obj, name: str) -> shared_memory.SharedMemory:
568
527
  """Write data to shared memory"""
569
- serialized = json.dumps(data).encode("utf-8")
528
+ serialized = pickle.dumps(obj)
570
529
  size = len(serialized)
571
530
  try:
572
531
  # Try to open existing shared memory
@@ -584,22 +543,17 @@ def write_to_shared_memory(data: dict, name: str) -> shared_memory.SharedMemory:
584
543
  return shm
585
544
 
586
545
 
587
- def read_from_shared_memory(name: str) -> dict:
546
+ def read_from_shared_memory(name: str) -> Any:
588
547
  """Read data from shared memory"""
589
548
  try:
590
549
  shm = shared_memory.SharedMemory(name=name)
591
- data = json.loads(bytes(shm.buf).decode("utf-8"))
550
+ data = pickle.loads(bytes(shm.buf))
592
551
  shm.close()
593
552
  return data
594
553
  except FileNotFoundError:
595
554
  raise FileNotFoundError(f"Shared memory {name} not found")
596
555
 
597
556
 
598
- def get_main_process_id() -> int:
599
- """Get the main process ID"""
600
- return multiprocessing.current_process()._parent_pid
601
-
602
-
603
557
  def write_data_for_multi_tokenizer(
604
558
  port_args: PortArgs, server_args: ServerArgs, scheduler_info: Dict
605
559
  ):
@@ -608,22 +562,22 @@ def write_data_for_multi_tokenizer(
608
562
  main_pid = get_main_process_id()
609
563
  current_pid = os.getpid()
610
564
  logger.info(f"main process ID: {main_pid}, current process ID: {current_pid}")
565
+ args = (port_args, server_args, scheduler_info)
566
+ args_shm = write_to_shared_memory(args, f"multi_tokenizer_args_{current_pid}")
567
+ args_shm.close()
611
568
 
612
- # Write port_args to shared memory
613
- port_args_shm = write_to_shared_memory(
614
- serialize_port_args(port_args), f"port_args_{current_pid}"
615
- )
616
- # Write server_args to shared memory
617
- server_args_shm = write_to_shared_memory(
618
- serialize_server_args(server_args), f"server_args_{current_pid}"
619
- )
620
- # Write scheduler_info to shared memory
621
- scheduler_info_shm = write_to_shared_memory(
622
- serialize_scheduler_info(scheduler_info), f"scheduler_info_{current_pid}"
623
- )
624
-
625
- port_args_shm.close()
626
- server_args_shm.close()
627
- scheduler_info_shm.close()
628
-
629
- return port_args_shm, server_args_shm, scheduler_info_shm
569
+ return args_shm
570
+
571
+
572
+ def monkey_patch_uvicorn_multiprocessing(timeout: float = 10):
573
+ """Monkey patch uvicorn multiprocessing is_alive timeout"""
574
+ # from default 5s -> 10s
575
+ try:
576
+ from uvicorn.supervisors.multiprocess import Process
577
+
578
+ Process.is_alive = partialmethod(Process.is_alive, timeout=timeout)
579
+
580
+ except ImportError:
581
+ logger.warning(
582
+ "uvicorn.supervisors.multiprocess not found, skipping monkey patch"
583
+ )