sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff compares the contents of two publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
Files changed (245)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +10 -1
  3. sglang/bench_serving.py +251 -26
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/internvl.py +6 -0
  7. sglang/srt/configs/longcat_flash.py +104 -0
  8. sglang/srt/configs/model_config.py +37 -7
  9. sglang/srt/configs/qwen3_next.py +326 -0
  10. sglang/srt/connector/__init__.py +1 -1
  11. sglang/srt/connector/base_connector.py +1 -2
  12. sglang/srt/connector/redis.py +2 -2
  13. sglang/srt/connector/serde/__init__.py +1 -1
  14. sglang/srt/connector/serde/safe_serde.py +4 -3
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/ascend/conn.py +75 -0
  21. sglang/srt/disaggregation/base/conn.py +1 -1
  22. sglang/srt/disaggregation/common/conn.py +15 -12
  23. sglang/srt/disaggregation/decode.py +6 -4
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -420
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +6 -4
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +94 -58
  31. sglang/srt/entrypoints/engine.py +34 -14
  32. sglang/srt/entrypoints/http_server.py +172 -47
  33. sglang/srt/entrypoints/openai/protocol.py +63 -3
  34. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  35. sglang/srt/entrypoints/openai/serving_chat.py +34 -19
  36. sglang/srt/entrypoints/openai/serving_completions.py +10 -4
  37. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  38. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  39. sglang/srt/eplb/eplb_manager.py +28 -4
  40. sglang/srt/eplb/expert_distribution.py +55 -15
  41. sglang/srt/eplb/expert_location.py +8 -3
  42. sglang/srt/eplb/expert_location_updater.py +1 -1
  43. sglang/srt/function_call/ebnf_composer.py +11 -9
  44. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  45. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  46. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  47. sglang/srt/hf_transformers_utils.py +12 -0
  48. sglang/srt/layers/activation.py +44 -9
  49. sglang/srt/layers/attention/aiter_backend.py +93 -68
  50. sglang/srt/layers/attention/ascend_backend.py +250 -112
  51. sglang/srt/layers/attention/fla/chunk.py +242 -0
  52. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  53. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  54. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  55. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  56. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  57. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  58. sglang/srt/layers/attention/fla/index.py +37 -0
  59. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  60. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  61. sglang/srt/layers/attention/fla/op.py +66 -0
  62. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  63. sglang/srt/layers/attention/fla/utils.py +331 -0
  64. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  65. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  66. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  67. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  68. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  69. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  70. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  71. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  72. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  73. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  74. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  75. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  76. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  77. sglang/srt/layers/communicator.py +45 -7
  78. sglang/srt/layers/layernorm.py +54 -12
  79. sglang/srt/layers/logits_processor.py +10 -3
  80. sglang/srt/layers/moe/__init__.py +2 -1
  81. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  82. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  83. sglang/srt/layers/moe/ep_moe/layer.py +110 -49
  84. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  85. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  86. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  88. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
  89. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  90. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  94. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  95. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  96. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  97. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  98. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  99. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  100. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  101. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  102. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  103. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  104. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  105. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  106. sglang/srt/layers/moe/topk.py +43 -12
  107. sglang/srt/layers/moe/utils.py +6 -5
  108. sglang/srt/layers/quantization/awq.py +19 -7
  109. sglang/srt/layers/quantization/base_config.py +11 -6
  110. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  111. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  112. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  113. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
  114. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
  115. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  116. sglang/srt/layers/quantization/fp8.py +76 -47
  117. sglang/srt/layers/quantization/fp8_utils.py +43 -29
  118. sglang/srt/layers/quantization/gptq.py +25 -17
  119. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  120. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  121. sglang/srt/layers/quantization/mxfp4.py +77 -45
  122. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  123. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  124. sglang/srt/layers/quantization/quark/utils.py +97 -0
  125. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  126. sglang/srt/layers/quantization/unquant.py +135 -47
  127. sglang/srt/layers/quantization/utils.py +13 -0
  128. sglang/srt/layers/quantization/w4afp8.py +60 -42
  129. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  130. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  131. sglang/srt/layers/rocm_linear_utils.py +44 -0
  132. sglang/srt/layers/rotary_embedding.py +28 -19
  133. sglang/srt/layers/sampler.py +29 -5
  134. sglang/srt/lora/backend/base_backend.py +50 -8
  135. sglang/srt/lora/backend/triton_backend.py +90 -2
  136. sglang/srt/lora/layers.py +32 -0
  137. sglang/srt/lora/lora.py +4 -1
  138. sglang/srt/lora/lora_manager.py +35 -112
  139. sglang/srt/lora/mem_pool.py +24 -10
  140. sglang/srt/lora/utils.py +18 -9
  141. sglang/srt/managers/cache_controller.py +242 -278
  142. sglang/srt/managers/data_parallel_controller.py +30 -15
  143. sglang/srt/managers/detokenizer_manager.py +13 -2
  144. sglang/srt/managers/disagg_service.py +46 -0
  145. sglang/srt/managers/io_struct.py +160 -11
  146. sglang/srt/managers/mm_utils.py +6 -1
  147. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  148. sglang/srt/managers/schedule_batch.py +27 -44
  149. sglang/srt/managers/schedule_policy.py +4 -3
  150. sglang/srt/managers/scheduler.py +90 -115
  151. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  152. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  153. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  154. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  155. sglang/srt/managers/template_manager.py +3 -3
  156. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  157. sglang/srt/managers/tokenizer_manager.py +41 -477
  158. sglang/srt/managers/tp_worker.py +16 -4
  159. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  160. sglang/srt/mem_cache/allocator.py +1 -1
  161. sglang/srt/mem_cache/chunk_cache.py +1 -1
  162. sglang/srt/mem_cache/hicache_storage.py +24 -22
  163. sglang/srt/mem_cache/hiradix_cache.py +184 -101
  164. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  165. sglang/srt/mem_cache/memory_pool.py +324 -41
  166. sglang/srt/mem_cache/memory_pool_host.py +25 -18
  167. sglang/srt/mem_cache/radix_cache.py +5 -6
  168. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  169. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  170. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  171. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  172. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
  173. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  174. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  175. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
  176. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  177. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  178. sglang/srt/metrics/collector.py +484 -63
  179. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  180. sglang/srt/metrics/utils.py +48 -0
  181. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  182. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  183. sglang/srt/model_executor/forward_batch_info.py +72 -18
  184. sglang/srt/model_executor/model_runner.py +189 -31
  185. sglang/srt/model_loader/__init__.py +9 -3
  186. sglang/srt/model_loader/loader.py +33 -28
  187. sglang/srt/model_loader/utils.py +12 -0
  188. sglang/srt/model_loader/weight_utils.py +2 -1
  189. sglang/srt/models/deepseek_v2.py +311 -50
  190. sglang/srt/models/gemma3n_mm.py +1 -1
  191. sglang/srt/models/glm4_moe.py +10 -1
  192. sglang/srt/models/glm4v.py +4 -2
  193. sglang/srt/models/gpt_oss.py +5 -18
  194. sglang/srt/models/internvl.py +28 -0
  195. sglang/srt/models/llama4.py +9 -0
  196. sglang/srt/models/llama_eagle3.py +17 -0
  197. sglang/srt/models/longcat_flash.py +1026 -0
  198. sglang/srt/models/longcat_flash_nextn.py +699 -0
  199. sglang/srt/models/minicpmv.py +165 -3
  200. sglang/srt/models/mllama4.py +25 -0
  201. sglang/srt/models/opt.py +637 -0
  202. sglang/srt/models/qwen2.py +33 -3
  203. sglang/srt/models/qwen2_5_vl.py +90 -42
  204. sglang/srt/models/qwen2_moe.py +79 -14
  205. sglang/srt/models/qwen3.py +8 -2
  206. sglang/srt/models/qwen3_moe.py +39 -8
  207. sglang/srt/models/qwen3_next.py +1039 -0
  208. sglang/srt/models/qwen3_next_mtp.py +109 -0
  209. sglang/srt/models/torch_native_llama.py +1 -1
  210. sglang/srt/models/transformers.py +1 -1
  211. sglang/srt/multimodal/processors/base_processor.py +4 -2
  212. sglang/srt/multimodal/processors/glm4v.py +9 -9
  213. sglang/srt/multimodal/processors/internvl.py +141 -129
  214. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  215. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  216. sglang/srt/sampling/sampling_batch_info.py +18 -15
  217. sglang/srt/server_args.py +297 -79
  218. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  219. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  220. sglang/srt/speculative/eagle_worker.py +216 -120
  221. sglang/srt/speculative/spec_info.py +5 -0
  222. sglang/srt/speculative/standalone_worker.py +109 -0
  223. sglang/srt/utils.py +37 -2
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  226. sglang/test/few_shot_gsm8k.py +1 -0
  227. sglang/test/runners.py +4 -0
  228. sglang/test/test_cutlass_moe.py +24 -6
  229. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  230. sglang/test/test_disaggregation_utils.py +66 -0
  231. sglang/test/test_utils.py +25 -1
  232. sglang/utils.py +5 -0
  233. sglang/version.py +1 -1
  234. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
  235. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
  236. sglang/srt/disaggregation/launch_lb.py +0 -131
  237. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  238. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  239. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  240. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  241. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  242. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  243. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  244. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  245. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/multi_tokenizer_mixin.py
@@ -0,0 +1,579 @@
+ # Copyright 2023-2024 SGLang Team
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+ # ==============================================================================
+ """MultiTokenizerMixin is a class that provides nesscary methods for MultiTokenizerManager and DetokenizerManager."""
+ import asyncio
+ import logging
+ import multiprocessing as multiprocessing
+ import os
+ import pickle
+ import sys
+ import threading
+ from functools import partialmethod
+ from multiprocessing import shared_memory
+ from typing import Any, Dict
+
+ import setproctitle
+ import zmq
+ import zmq.asyncio
+
+ from sglang.srt.disaggregation.utils import DisaggregationMode, TransferBackend
+ from sglang.srt.managers.disagg_service import start_disagg_service
+ from sglang.srt.managers.io_struct import (
+     BatchEmbeddingOut,
+     BatchMultimodalOut,
+     BatchStrOut,
+     BatchTokenIDOut,
+     MultiTokenizerRegisterReq,
+     MultiTokenizerWrapper,
+ )
+ from sglang.srt.managers.tokenizer_communicator_mixin import _Communicator
+ from sglang.srt.managers.tokenizer_manager import TokenizerManager
+ from sglang.srt.server_args import PortArgs, ServerArgs
+ from sglang.srt.utils import get_zmq_socket, kill_process_tree
+ from sglang.utils import get_exception_traceback
+
+ logger = logging.getLogger(__name__)
+
+
+ class SocketMapping:
+     def __init__(self):
+         self._zmq_context = zmq.Context()
+         self._mapping: Dict[str, zmq.Socket] = {}
+
+     def clear_all_sockets(self):
+         for socket in self._mapping.values():
+             socket.close()
+         self._mapping.clear()
+
+     def register_ipc_mapping(
+         self, recv_obj: MultiTokenizerRegisterReq, worker_id: str, is_tokenizer: bool
+     ):
+         type_str = "tokenizer" if is_tokenizer else "detokenizer"
+         if worker_id in self._mapping:
+             logger.warning(
+                 f"{type_str} already registered with worker {worker_id}, skipping..."
+             )
+             return
+         logger.info(
+             f"{type_str} not registered with worker {worker_id}, registering..."
+         )
+         socket = get_zmq_socket(self._zmq_context, zmq.PUSH, recv_obj.ipc_name, False)
+         self._mapping[worker_id] = socket
+         self._mapping[worker_id].send_pyobj(recv_obj)
+
+     def send_output(self, worker_id: str, output: Any):
+         if worker_id not in self._mapping:
+             logger.error(
+                 f"worker ID {worker_id} not registered. Check if the server Process is alive"
+             )
+             return
+         self._mapping[worker_id].send_pyobj(output)
+
+
+ def _handle_output_by_index(output, i):
+     """NOTE: A maintainable method is better here."""
+     if isinstance(output, BatchTokenIDOut):
+         new_output = BatchTokenIDOut(
+             rids=[output.rids[i]],
+             finished_reasons=(
+                 [output.finished_reasons[i]]
+                 if len(output.finished_reasons) > i
+                 else None
+             ),
+             decoded_texts=(
+                 [output.decoded_texts[i]] if len(output.decoded_texts) > i else None
+             ),
+             decode_ids=([output.decode_ids[i]] if len(output.decode_ids) > i else None),
+             read_offsets=(
+                 [output.read_offsets[i]] if len(output.read_offsets) > i else None
+             ),
+             output_ids=(
+                 [output.output_ids[i]]
+                 if output.output_ids and len(output.output_ids) > i
+                 else None
+             ),
+             skip_special_tokens=(
+                 [output.skip_special_tokens[i]]
+                 if len(output.skip_special_tokens) > i
+                 else None
+             ),
+             spaces_between_special_tokens=(
+                 [output.spaces_between_special_tokens[i]]
+                 if len(output.spaces_between_special_tokens) > i
+                 else None
+             ),
+             no_stop_trim=(
+                 [output.no_stop_trim[i]] if len(output.no_stop_trim) > i else None
+             ),
+             prompt_tokens=(
+                 [output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None
+             ),
+             completion_tokens=(
+                 [output.completion_tokens[i]]
+                 if len(output.completion_tokens) > i
+                 else None
+             ),
+             cached_tokens=(
+                 [output.cached_tokens[i]] if len(output.cached_tokens) > i else None
+             ),
+             spec_verify_ct=(
+                 [output.spec_verify_ct[i]] if len(output.spec_verify_ct) > i else None
+             ),
+             input_token_logprobs_val=(
+                 [output.input_token_logprobs_val[i]]
+                 if output.input_token_logprobs_val
+                 else None
+             ),
+             input_token_logprobs_idx=(
+                 [output.input_token_logprobs_idx[i]]
+                 if output.input_token_logprobs_idx
+                 else None
+             ),
+             output_token_logprobs_val=(
+                 [output.output_token_logprobs_val[i]]
+                 if output.output_token_logprobs_val
+                 else None
+             ),
+             output_token_logprobs_idx=(
+                 [output.output_token_logprobs_idx[i]]
+                 if output.output_token_logprobs_idx
+                 else None
+             ),
+             input_top_logprobs_val=(
+                 [output.input_top_logprobs_val[i]]
+                 if output.input_top_logprobs_val
+                 else None
+             ),
+             input_top_logprobs_idx=(
+                 [output.input_top_logprobs_idx[i]]
+                 if output.input_top_logprobs_idx
+                 else None
+             ),
+             output_top_logprobs_val=(
+                 [output.output_top_logprobs_val[i]]
+                 if output.output_top_logprobs_val
+                 else None
+             ),
+             output_top_logprobs_idx=(
+                 [output.output_top_logprobs_idx[i]]
+                 if output.output_top_logprobs_idx
+                 else None
+             ),
+             input_token_ids_logprobs_val=(
+                 [output.input_token_ids_logprobs_val[i]]
+                 if output.input_token_ids_logprobs_val
+                 else None
+             ),
+             input_token_ids_logprobs_idx=(
+                 [output.input_token_ids_logprobs_idx[i]]
+                 if output.input_token_ids_logprobs_idx
+                 else None
+             ),
+             output_token_ids_logprobs_val=(
+                 [output.output_token_ids_logprobs_val[i]]
+                 if output.output_token_ids_logprobs_val
+                 else None
+             ),
+             output_token_ids_logprobs_idx=(
+                 [output.output_token_ids_logprobs_idx[i]]
+                 if output.output_token_ids_logprobs_idx
+                 else None
+             ),
+             output_hidden_states=(
+                 [output.output_hidden_states[i]]
+                 if output.output_hidden_states
+                 else None
+             ),
+             placeholder_tokens_idx=None,
+             placeholder_tokens_val=None,
+         )
+     elif isinstance(output, BatchEmbeddingOut):
+         new_output = BatchEmbeddingOut(
+             rids=[output.rids[i]],
+             finished_reasons=(
+                 [output.finished_reasons[i]]
+                 if len(output.finished_reasons) > i
+                 else None
+             ),
+             embeddings=([output.embeddings[i]] if len(output.embeddings) > i else None),
+             prompt_tokens=(
+                 [output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None
+             ),
+             cached_tokens=(
+                 [output.cached_tokens[i]] if len(output.cached_tokens) > i else None
+             ),
+             placeholder_tokens_idx=None,
+             placeholder_tokens_val=None,
+         )
+     elif isinstance(output, BatchStrOut):
+         new_output = BatchStrOut(
+             rids=[output.rids[i]],
+             finished_reasons=(
+                 [output.finished_reasons[i]]
+                 if len(output.finished_reasons) > i
+                 else None
+             ),
+             output_strs=(
+                 [output.output_strs[i]] if len(output.output_strs) > i else None
+             ),
+             output_ids=(
+                 [output.output_ids[i]]
+                 if output.output_ids and len(output.output_ids) > i
+                 else None
+             ),
+             prompt_tokens=(
+                 [output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None
+             ),
+             completion_tokens=(
+                 [output.completion_tokens[i]]
+                 if len(output.completion_tokens) > i
+                 else None
+             ),
+             cached_tokens=(
+                 [output.cached_tokens[i]] if len(output.cached_tokens) > i else None
+             ),
+             spec_verify_ct=(
+                 [output.spec_verify_ct[i]] if len(output.spec_verify_ct) > i else None
+             ),
+             input_token_logprobs_val=(
+                 [output.input_token_logprobs_val[i]]
+                 if output.input_token_logprobs_val
+                 else None
+             ),
+             input_token_logprobs_idx=(
+                 [output.input_token_logprobs_idx[i]]
+                 if output.input_token_logprobs_idx
+                 else None
+             ),
+             output_token_logprobs_val=(
+                 [output.output_token_logprobs_val[i]]
+                 if output.output_token_logprobs_val
+                 else None
+             ),
+             output_token_logprobs_idx=(
+                 [output.output_token_logprobs_idx[i]]
+                 if output.output_token_logprobs_idx
+                 else None
+             ),
+             input_top_logprobs_val=(
+                 [output.input_top_logprobs_val[i]]
+                 if output.input_top_logprobs_val
+                 else None
+             ),
+             input_top_logprobs_idx=(
+                 [output.input_top_logprobs_idx[i]]
+                 if output.input_top_logprobs_idx
+                 else None
+             ),
+             output_top_logprobs_val=(
+                 [output.output_top_logprobs_val[i]]
+                 if output.output_top_logprobs_val
+                 else None
+             ),
+             output_top_logprobs_idx=(
+                 [output.output_top_logprobs_idx[i]]
+                 if output.output_top_logprobs_idx
+                 else None
+             ),
+             input_token_ids_logprobs_val=(
+                 [output.input_token_ids_logprobs_val[i]]
+                 if output.input_token_ids_logprobs_val
+                 else None
+             ),
+             input_token_ids_logprobs_idx=(
+                 [output.input_token_ids_logprobs_idx[i]]
+                 if output.input_token_ids_logprobs_idx
+                 else None
+             ),
+             output_token_ids_logprobs_val=(
+                 [output.output_token_ids_logprobs_val[i]]
+                 if output.output_token_ids_logprobs_val
+                 else None
+             ),
+             output_token_ids_logprobs_idx=(
+                 [output.output_token_ids_logprobs_idx[i]]
+                 if output.output_token_ids_logprobs_idx
+                 else None
+             ),
+             output_hidden_states=(
+                 [output.output_hidden_states[i]]
+                 if output.output_hidden_states
+                 else None
+             ),
+             placeholder_tokens_idx=None,
+             placeholder_tokens_val=None,
+         )
+     elif isinstance(output, BatchMultimodalOut):
+         new_output = BatchMultimodalOut(
+             rids=[output.rids[i]],
+             finished_reasons=(
+                 [output.finished_reasons[i]]
+                 if len(output.finished_reasons) > i
+                 else None
+             ),
+             outputs=([output.outputs[i]] if len(output.outputs) > i else None),
+             prompt_tokens=(
+                 [output.prompt_tokens[i]] if len(output.prompt_tokens) > i else None
+             ),
+             completion_tokens=(
+                 [output.completion_tokens[i]]
+                 if len(output.completion_tokens) > i
+                 else None
+             ),
+             cached_tokens=(
+                 [output.cached_tokens[i]] if len(output.cached_tokens) > i else None
+             ),
+             placeholder_tokens_idx=None,
+             placeholder_tokens_val=None,
+         )
+     else:
+         new_output = output
+     return new_output
+
+
+ class MultiHttpWorkerDetokenizerMixin:
+     """Mixin class for MultiTokenizerManager and DetokenizerManager"""
+
+     def get_worker_ids_from_req_rids(self, rids):
+         if isinstance(rids, list):
+             worker_ids = [int(rid.split("_")[0]) for rid in rids]
+         elif isinstance(rids, str):
+             worker_ids = [int(rids.split("_")[0])]
+         else:
+             worker_ids = []
+         return worker_ids
+
+     def multi_http_worker_event_loop(self):
+         """The event loop that handles requests, for multi multi-http-worker mode"""
+         self.socket_mapping = SocketMapping()
+         while True:
+             recv_obj = self.recv_from_scheduler.recv_pyobj()
+             output = self._request_dispatcher(recv_obj)
+             if output is None:
+                 continue
+             # Extract worker_id from rid
+             if isinstance(recv_obj.rids, list):
+                 worker_ids = self.get_worker_ids_from_req_rids(recv_obj.rids)
+             else:
+                 raise RuntimeError(
+                     f"for tokenizer_worker_num > 1, recv_obj.rids must be a list"
+                 )
+
+             # Send data using the corresponding socket
+             for i, worker_id in enumerate(worker_ids):
+                 if isinstance(recv_obj, MultiTokenizerRegisterReq):
+                     self.socket_mapping.register_ipc_mapping(
+                         recv_obj, worker_id, is_tokenizer=False
+                     )
+                 else:
+                     new_output = _handle_output_by_index(output, i)
+                     self.socket_mapping.send_output(worker_id, new_output)
+
+
+ class MultiTokenizerRouter:
+     """A router to receive requests from MultiTokenizerManager"""
+
+     def __init__(
+         self,
+         server_args: ServerArgs,
+         port_args: PortArgs,
+     ):
+         self.server_args = server_args
+         context = zmq.asyncio.Context(3)
+         self.recv_from_detokenizer = get_zmq_socket(
+             context, zmq.PULL, port_args.tokenizer_ipc_name, True
+         )
+         self.send_to_scheduler = get_zmq_socket(
+             context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
+         )
+         self.receive_from_worker = get_zmq_socket(
+             context, zmq.PULL, port_args.tokenizer_worker_ipc_name, True
+         )
+         self._loop = asyncio.new_event_loop()
+         self._thread = threading.Thread(target=self._run_loop, daemon=True)
+         self._thread.start()
+         self._task = asyncio.run_coroutine_threadsafe(
+             self.router_worker_obj(), self._loop
+         )
+         # Start handle_loop simultaneously
+         self._handle_task = asyncio.run_coroutine_threadsafe(
+             print_exception_wrapper(self.handle_loop), self._loop
+         )
+         self.disaggregation_bootstrap_server = start_disagg_service(self.server_args)
+
+     def _run_loop(self):
+         self._loop.run_forever()
+
+     async def router_worker_obj(self):
+         while True:
+             recv_obj = await self.receive_from_worker.recv_pyobj()
+             await self.send_to_scheduler.send_pyobj(recv_obj)
+
+     async def handle_loop(self):
+         # special reqs will recv from scheduler, need to route to right worker
+         self.socket_mapping = SocketMapping()
+         while True:
+             recv_obj = await self.recv_from_detokenizer.recv_pyobj()
+             await self._distribute_result_to_workers(recv_obj)
+
+     async def _distribute_result_to_workers(self, recv_obj):
+         """Distribute result to corresponding workers based on rid"""
+         if isinstance(recv_obj, MultiTokenizerWrapper):
+             worker_ids = [recv_obj.worker_id]
+             recv_obj = recv_obj.obj
+         else:
+             worker_ids = self.get_worker_ids_from_req_rids(recv_obj.rids)
+
+         if len(worker_ids) == 0:
+             logger.error(f"Cannot find worker_id from rids {recv_obj.rids}")
+             return
+
+         # Distribute result to each worker
+         for i, worker_id in enumerate(worker_ids):
+             if isinstance(recv_obj, MultiTokenizerRegisterReq):
+                 self.socket_mapping.register_ipc_mapping(
+                     recv_obj, worker_id, is_tokenizer=True
+                 )
+             else:
+                 new_recv_obj = _handle_output_by_index(recv_obj, i)
+                 self.socket_mapping.send_output(worker_id, new_recv_obj)
+
+
+ class MultiTokenizerManager(TokenizerManager):
+     """Multi Process Tokenizer Manager that tokenizes the text."""
+
+     def __init__(
+         self,
+         server_args: ServerArgs,
+         port_args: PortArgs,
+     ):
+         setproctitle.setproctitle(f"sglang::tokenizer_worker:{os.getpid()}")
+         # prevent init prefill bootstrapserver again
+         disaggregation_mode = server_args.disaggregation_mode
+         server_args.disaggregation_mode = "null"
+         super().__init__(server_args, port_args)
+
+         self.worker_id = os.getpid()
+         self.tokenizer_ipc_name = port_args.tokenizer_ipc_name
+
+         # For PD disaggregtion
+         self.server_args.disaggregation_mode = disaggregation_mode
+         self.disaggregation_mode = DisaggregationMode(
+             self.server_args.disaggregation_mode
+         )
+         self.disaggregation_transfer_backend = TransferBackend(
+             self.server_args.disaggregation_transfer_backend
+         )
+         # Communicator
+         self.register_multi_tokenizer_communicator = _Communicator(
+             self.send_to_scheduler, 2
+         )
+         self._result_dispatcher._mapping.append(
+             (
+                 MultiTokenizerRegisterReq,
+                 self.register_multi_tokenizer_communicator.handle_recv,
+             )
+         )
+
+     async def register_to_main_tokenizer_manager(self):
+         """Register this worker to the main TokenizerManager"""
+         # create a handle loop to receive messages from the main TokenizerManager
+         self.auto_create_handle_loop()
+         req = MultiTokenizerRegisterReq(rids=[f"{self.worker_id}_register"])
+         req.ipc_name = self.tokenizer_ipc_name
+         _Communicator.enable_multi_tokenizer = True
+         await self.register_multi_tokenizer_communicator(req)
+
+
+ async def print_exception_wrapper(func):
+     """
+     Sometimes an asyncio function does not print exception.
+     We do another wrapper to handle the exception.
+     """
+     try:
+         await func()
+     except Exception:
+         traceback = get_exception_traceback()
+         logger.error(f"MultiTokenizerRouter hit an exception: {traceback}")
+         if hasattr(func, "__self__") and isinstance(
+             func.__self__, MultiTokenizerRouter
+         ):
+             func.__self__.dump_requests_before_crash()
+         kill_process_tree(os.getpid(), include_parent=True)
+         sys.exit(1)
+
+
+ def get_main_process_id() -> int:
+     """Get the main process ID"""
+     return multiprocessing.current_process()._parent_pid
+
+
+ def write_to_shared_memory(obj, name: str) -> shared_memory.SharedMemory:
+     """Write data to shared memory"""
+     serialized = pickle.dumps(obj)
+     size = len(serialized)
+     try:
+         # Try to open existing shared memory
+         shm = shared_memory.SharedMemory(name=name)
+         # If size is insufficient, close and recreate
+         if shm.size < size:
+             shm.close()
+             shm.unlink()
+             shm = shared_memory.SharedMemory(create=True, size=size, name=name)
+     except FileNotFoundError:
+         # If not present, create new shared memory
+         shm = shared_memory.SharedMemory(create=True, size=size, name=name)
+
+     shm.buf[:size] = serialized
+     return shm
+
+
+ def read_from_shared_memory(name: str) -> Any:
+     """Read data from shared memory"""
+     try:
+         shm = shared_memory.SharedMemory(name=name)
+         data = pickle.loads(bytes(shm.buf))
+         shm.close()
+         return data
+     except FileNotFoundError:
+         raise FileNotFoundError(f"Shared memory {name} not found")
+
+
+ def write_data_for_multi_tokenizer(
+     port_args: PortArgs, server_args: ServerArgs, scheduler_info: Dict
+ ):
+     """Write args information to share memory for multi-tokenizer"""
+     # get main process ID
+     main_pid = get_main_process_id()
+     current_pid = os.getpid()
+     logger.info(f"main process ID: {main_pid}, current process ID: {current_pid}")
+     args = (port_args, server_args, scheduler_info)
+     args_shm = write_to_shared_memory(args, f"multi_tokenizer_args_{current_pid}")
+     args_shm.close()
+
+     return args_shm
+
+
+ def monkey_patch_uvicorn_multiprocessing(timeout: float = 10):
+     """Monkey patch uvicorn multiprocessing is_alive timeout"""
+     # from default 5s -> 10s
+     try:
+         from uvicorn.supervisors.multiprocess import Process
+
+         Process.is_alive = partialmethod(Process.is_alive, timeout=timeout)
+
+     except ImportError:
+         logger.warning(
+             "uvicorn.supervisors.multiprocess not found, skipping monkey patch"
+         )
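
Reading note (not part of the diff): the helpers above suggest a handoff in which one process pickles (port_args, server_args, scheduler_info) into a shared-memory block named multi_tokenizer_args_<writer pid>, and each HTTP/tokenizer worker reads it back before constructing a MultiTokenizerManager. The sketch below is a minimal, hypothetical worker-side consumer of that handoff; it assumes the writer's PID is the worker's parent PID (get_main_process_id), and the function name launch_tokenizer_worker is illustrative only. The actual wiring presumably lives in the server entrypoints (e.g. sglang/srt/entrypoints/http_server.py, whose diff is not shown here).

    from sglang.srt.managers.multi_tokenizer_mixin import (
        MultiTokenizerManager,
        get_main_process_id,
        read_from_shared_memory,
    )

    def launch_tokenizer_worker():
        # The writer stores the args under "multi_tokenizer_args_<writer pid>";
        # from a forked worker, that PID is the parent process ID (assumption).
        main_pid = get_main_process_id()
        port_args, server_args, scheduler_info = read_from_shared_memory(
            f"multi_tokenizer_args_{main_pid}"
        )
        manager = MultiTokenizerManager(server_args, port_args)
        # Registration with the main TokenizerManager is async and would be awaited
        # from the worker's event loop:
        #     await manager.register_to_main_tokenizer_manager()
        return manager, scheduler_info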