sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (245)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +10 -1
  3. sglang/bench_serving.py +251 -26
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/internvl.py +6 -0
  7. sglang/srt/configs/longcat_flash.py +104 -0
  8. sglang/srt/configs/model_config.py +37 -7
  9. sglang/srt/configs/qwen3_next.py +326 -0
  10. sglang/srt/connector/__init__.py +1 -1
  11. sglang/srt/connector/base_connector.py +1 -2
  12. sglang/srt/connector/redis.py +2 -2
  13. sglang/srt/connector/serde/__init__.py +1 -1
  14. sglang/srt/connector/serde/safe_serde.py +4 -3
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/ascend/conn.py +75 -0
  21. sglang/srt/disaggregation/base/conn.py +1 -1
  22. sglang/srt/disaggregation/common/conn.py +15 -12
  23. sglang/srt/disaggregation/decode.py +6 -4
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -420
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +6 -4
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +94 -58
  31. sglang/srt/entrypoints/engine.py +34 -14
  32. sglang/srt/entrypoints/http_server.py +172 -47
  33. sglang/srt/entrypoints/openai/protocol.py +63 -3
  34. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  35. sglang/srt/entrypoints/openai/serving_chat.py +34 -19
  36. sglang/srt/entrypoints/openai/serving_completions.py +10 -4
  37. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  38. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  39. sglang/srt/eplb/eplb_manager.py +28 -4
  40. sglang/srt/eplb/expert_distribution.py +55 -15
  41. sglang/srt/eplb/expert_location.py +8 -3
  42. sglang/srt/eplb/expert_location_updater.py +1 -1
  43. sglang/srt/function_call/ebnf_composer.py +11 -9
  44. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  45. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  46. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  47. sglang/srt/hf_transformers_utils.py +12 -0
  48. sglang/srt/layers/activation.py +44 -9
  49. sglang/srt/layers/attention/aiter_backend.py +93 -68
  50. sglang/srt/layers/attention/ascend_backend.py +250 -112
  51. sglang/srt/layers/attention/fla/chunk.py +242 -0
  52. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  53. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  54. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  55. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  56. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  57. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  58. sglang/srt/layers/attention/fla/index.py +37 -0
  59. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  60. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  61. sglang/srt/layers/attention/fla/op.py +66 -0
  62. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  63. sglang/srt/layers/attention/fla/utils.py +331 -0
  64. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  65. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  66. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  67. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  68. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  69. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  70. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  71. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  72. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  73. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  74. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  75. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  76. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  77. sglang/srt/layers/communicator.py +45 -7
  78. sglang/srt/layers/layernorm.py +54 -12
  79. sglang/srt/layers/logits_processor.py +10 -3
  80. sglang/srt/layers/moe/__init__.py +2 -1
  81. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  82. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  83. sglang/srt/layers/moe/ep_moe/layer.py +110 -49
  84. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  85. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  86. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  88. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
  89. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  90. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  94. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  95. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  96. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  97. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  98. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  99. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  100. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  101. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  102. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  103. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  104. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  105. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  106. sglang/srt/layers/moe/topk.py +43 -12
  107. sglang/srt/layers/moe/utils.py +6 -5
  108. sglang/srt/layers/quantization/awq.py +19 -7
  109. sglang/srt/layers/quantization/base_config.py +11 -6
  110. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  111. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  112. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  113. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
  114. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
  115. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  116. sglang/srt/layers/quantization/fp8.py +76 -47
  117. sglang/srt/layers/quantization/fp8_utils.py +43 -29
  118. sglang/srt/layers/quantization/gptq.py +25 -17
  119. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  120. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  121. sglang/srt/layers/quantization/mxfp4.py +77 -45
  122. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  123. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  124. sglang/srt/layers/quantization/quark/utils.py +97 -0
  125. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  126. sglang/srt/layers/quantization/unquant.py +135 -47
  127. sglang/srt/layers/quantization/utils.py +13 -0
  128. sglang/srt/layers/quantization/w4afp8.py +60 -42
  129. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  130. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  131. sglang/srt/layers/rocm_linear_utils.py +44 -0
  132. sglang/srt/layers/rotary_embedding.py +28 -19
  133. sglang/srt/layers/sampler.py +29 -5
  134. sglang/srt/lora/backend/base_backend.py +50 -8
  135. sglang/srt/lora/backend/triton_backend.py +90 -2
  136. sglang/srt/lora/layers.py +32 -0
  137. sglang/srt/lora/lora.py +4 -1
  138. sglang/srt/lora/lora_manager.py +35 -112
  139. sglang/srt/lora/mem_pool.py +24 -10
  140. sglang/srt/lora/utils.py +18 -9
  141. sglang/srt/managers/cache_controller.py +242 -278
  142. sglang/srt/managers/data_parallel_controller.py +30 -15
  143. sglang/srt/managers/detokenizer_manager.py +13 -2
  144. sglang/srt/managers/disagg_service.py +46 -0
  145. sglang/srt/managers/io_struct.py +160 -11
  146. sglang/srt/managers/mm_utils.py +6 -1
  147. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  148. sglang/srt/managers/schedule_batch.py +27 -44
  149. sglang/srt/managers/schedule_policy.py +4 -3
  150. sglang/srt/managers/scheduler.py +90 -115
  151. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  152. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  153. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  154. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  155. sglang/srt/managers/template_manager.py +3 -3
  156. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  157. sglang/srt/managers/tokenizer_manager.py +41 -477
  158. sglang/srt/managers/tp_worker.py +16 -4
  159. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  160. sglang/srt/mem_cache/allocator.py +1 -1
  161. sglang/srt/mem_cache/chunk_cache.py +1 -1
  162. sglang/srt/mem_cache/hicache_storage.py +24 -22
  163. sglang/srt/mem_cache/hiradix_cache.py +184 -101
  164. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  165. sglang/srt/mem_cache/memory_pool.py +324 -41
  166. sglang/srt/mem_cache/memory_pool_host.py +25 -18
  167. sglang/srt/mem_cache/radix_cache.py +5 -6
  168. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  169. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  170. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  171. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  172. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
  173. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  174. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  175. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
  176. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  177. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  178. sglang/srt/metrics/collector.py +484 -63
  179. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  180. sglang/srt/metrics/utils.py +48 -0
  181. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  182. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  183. sglang/srt/model_executor/forward_batch_info.py +72 -18
  184. sglang/srt/model_executor/model_runner.py +189 -31
  185. sglang/srt/model_loader/__init__.py +9 -3
  186. sglang/srt/model_loader/loader.py +33 -28
  187. sglang/srt/model_loader/utils.py +12 -0
  188. sglang/srt/model_loader/weight_utils.py +2 -1
  189. sglang/srt/models/deepseek_v2.py +311 -50
  190. sglang/srt/models/gemma3n_mm.py +1 -1
  191. sglang/srt/models/glm4_moe.py +10 -1
  192. sglang/srt/models/glm4v.py +4 -2
  193. sglang/srt/models/gpt_oss.py +5 -18
  194. sglang/srt/models/internvl.py +28 -0
  195. sglang/srt/models/llama4.py +9 -0
  196. sglang/srt/models/llama_eagle3.py +17 -0
  197. sglang/srt/models/longcat_flash.py +1026 -0
  198. sglang/srt/models/longcat_flash_nextn.py +699 -0
  199. sglang/srt/models/minicpmv.py +165 -3
  200. sglang/srt/models/mllama4.py +25 -0
  201. sglang/srt/models/opt.py +637 -0
  202. sglang/srt/models/qwen2.py +33 -3
  203. sglang/srt/models/qwen2_5_vl.py +90 -42
  204. sglang/srt/models/qwen2_moe.py +79 -14
  205. sglang/srt/models/qwen3.py +8 -2
  206. sglang/srt/models/qwen3_moe.py +39 -8
  207. sglang/srt/models/qwen3_next.py +1039 -0
  208. sglang/srt/models/qwen3_next_mtp.py +109 -0
  209. sglang/srt/models/torch_native_llama.py +1 -1
  210. sglang/srt/models/transformers.py +1 -1
  211. sglang/srt/multimodal/processors/base_processor.py +4 -2
  212. sglang/srt/multimodal/processors/glm4v.py +9 -9
  213. sglang/srt/multimodal/processors/internvl.py +141 -129
  214. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  215. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  216. sglang/srt/sampling/sampling_batch_info.py +18 -15
  217. sglang/srt/server_args.py +297 -79
  218. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  219. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  220. sglang/srt/speculative/eagle_worker.py +216 -120
  221. sglang/srt/speculative/spec_info.py +5 -0
  222. sglang/srt/speculative/standalone_worker.py +109 -0
  223. sglang/srt/utils.py +37 -2
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  226. sglang/test/few_shot_gsm8k.py +1 -0
  227. sglang/test/runners.py +4 -0
  228. sglang/test/test_cutlass_moe.py +24 -6
  229. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  230. sglang/test/test_disaggregation_utils.py +66 -0
  231. sglang/test/test_utils.py +25 -1
  232. sglang/utils.py +5 -0
  233. sglang/version.py +1 -1
  234. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
  235. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
  236. sglang/srt/disaggregation/launch_lb.py +0 -131
  237. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  238. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  239. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  240. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  241. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  242. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  243. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  244. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  245. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tokenizer_manager.py

@@ -31,18 +31,7 @@ from contextlib import nullcontext
 from datetime import datetime
 from enum import Enum
 from http import HTTPStatus
-from typing import (
-    Any,
-    Awaitable,
-    Deque,
-    Dict,
-    Generic,
-    List,
-    Optional,
-    Tuple,
-    TypeVar,
-    Union,
-)
+from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union
 
 import fastapi
 import torch
@@ -53,18 +42,14 @@ from fastapi import BackgroundTasks
 
 from sglang.srt.aio_rwlock import RWLock
 from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.disaggregation.utils import (
-    DisaggregationMode,
-    KVClassType,
-    TransferBackend,
-    get_kv_class,
-)
+from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.hf_transformers_utils import (
     get_processor,
     get_tokenizer,
     get_tokenizer_from_processor,
 )
 from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry
+from sglang.srt.managers.disagg_service import start_disagg_service
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchEmbeddingOut,
@@ -76,51 +61,23 @@ from sglang.srt.managers.io_struct import (
     CloseSessionReqInput,
     ConfigureLoggingReq,
     EmbeddingReqInput,
-    ExpertDistributionReq,
-    ExpertDistributionReqOutput,
-    FlushCacheReqInput,
-    FlushCacheReqOutput,
     FreezeGCReq,
     GenerateReqInput,
-    GetInternalStateReq,
-    GetInternalStateReqOutput,
-    GetWeightsByNameReqInput,
-    GetWeightsByNameReqOutput,
     HealthCheckOutput,
-    InitWeightsUpdateGroupReqInput,
-    InitWeightsUpdateGroupReqOutput,
-    LoadLoRAAdapterReqInput,
-    LoadLoRAAdapterReqOutput,
-    LoRAUpdateResult,
+    MultiTokenizerWrapper,
     OpenSessionReqInput,
     OpenSessionReqOutput,
-    ProfileReq,
-    ProfileReqOutput,
-    ProfileReqType,
-    ReleaseMemoryOccupationReqInput,
-    ReleaseMemoryOccupationReqOutput,
-    ResumeMemoryOccupationReqInput,
-    ResumeMemoryOccupationReqOutput,
     SessionParams,
-    SetInternalStateReq,
-    SetInternalStateReqOutput,
-    SlowDownReqInput,
-    SlowDownReqOutput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
-    UnloadLoRAAdapterReqInput,
-    UnloadLoRAAdapterReqOutput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightFromDiskReqOutput,
-    UpdateWeightsFromDistributedReqInput,
-    UpdateWeightsFromDistributedReqOutput,
-    UpdateWeightsFromTensorReqInput,
-    UpdateWeightsFromTensorReqOutput,
 )
 from sglang.srt.managers.mm_utils import TensorTransportMode
 from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
 from sglang.srt.managers.scheduler import is_health_check_generate_req
 from sglang.srt.managers.scheduler_input_blocker import input_blocker_guard_region
+from sglang.srt.managers.tokenizer_communicator_mixin import TokenizerCommunicatorMixin
 from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
@@ -129,6 +86,7 @@ from sglang.srt.utils import (
     dataclass_to_string_truncated,
     freeze_gc,
     get_bool_env_var,
+    get_origin_rid,
     get_zmq_socket,
     kill_process_tree,
 )
@@ -176,7 +134,7 @@ class ReqState:
     output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)
 
 
-class TokenizerManager:
+class TokenizerManager(TokenizerCommunicatorMixin):
     """TokenizerManager is a process that tokenizes the text."""
 
     def __init__(
@@ -264,9 +222,15 @@ class TokenizerManager:
         self.recv_from_detokenizer = get_zmq_socket(
             context, zmq.PULL, port_args.tokenizer_ipc_name, True
         )
-        self.send_to_scheduler = get_zmq_socket(
-            context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
-        )
+        if self.server_args.tokenizer_worker_num > 1:
+            # Use tokenizer_worker_ipc_name in multi-tokenizer mode
+            self.send_to_scheduler = get_zmq_socket(
+                context, zmq.PUSH, port_args.tokenizer_worker_ipc_name, False
+            )
+        else:
+            self.send_to_scheduler = get_zmq_socket(
+                context, zmq.PUSH, port_args.scheduler_input_ipc_name, True
+            )
 
         # Request states
         self.no_create_loop = False
@@ -309,36 +273,10 @@ class TokenizerManager:
         # LoRA updates and inference to overlap.
         self.lora_update_lock = asyncio.Lock()
 
-        # For PD disaggregtion
         self.disaggregation_mode = DisaggregationMode(
             self.server_args.disaggregation_mode
         )
-        self.disaggregation_transfer_backend = TransferBackend(
-            self.server_args.disaggregation_transfer_backend
-        )
-        # Start kv boostrap server on prefill
-        if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            # only start bootstrap server on prefill tm
-            kv_bootstrap_server_class = get_kv_class(
-                self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
-            )
-            self.bootstrap_server = kv_bootstrap_server_class(
-                self.server_args.disaggregation_bootstrap_port
-            )
-            is_create_store = (
-                self.server_args.node_rank == 0
-                and self.server_args.disaggregation_transfer_backend == "ascend"
-            )
-            if is_create_store:
-                try:
-                    from mf_adapter import create_config_store
-
-                    ascend_url = os.getenv("ASCEND_MF_STORE_URL")
-                    create_config_store(ascend_url)
-                except Exception as e:
-                    error_message = f"Failed create mf store, invalid ascend_url."
-                    error_message += f" With exception {e}"
-                    raise error_message
+        self.bootstrap_server = start_disagg_service(self.server_args)
 
         # For load balancing
         self.current_load = 0
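
Note: the inline bootstrap setup removed above is collapsed into start_disagg_service from the new sglang/srt/managers/disagg_service.py (+46 lines), which is not part of this excerpt. A plausible sketch of such a factory, reconstructed purely from the deleted code (the function body is an assumption; the imports are exactly the ones dropped from tokenizer_manager.py in the earlier import hunk):

# Sketch only: reconstructed from the removed code above; the real
# sglang.srt.managers.disagg_service.start_disagg_service may differ.
from typing import Optional

from sglang.srt.disaggregation.utils import (
    DisaggregationMode,
    KVClassType,
    TransferBackend,
    get_kv_class,
)


def start_disagg_service(server_args) -> Optional[object]:
    """Start the KV bootstrap server only on the prefill side of PD disaggregation."""
    if DisaggregationMode(server_args.disaggregation_mode) != DisaggregationMode.PREFILL:
        return None
    backend = TransferBackend(server_args.disaggregation_transfer_backend)
    bootstrap_server_class = get_kv_class(backend, KVClassType.BOOTSTRAP_SERVER)
    return bootstrap_server_class(server_args.disaggregation_bootstrap_port)

Factoring this out moves the prefill-only bootstrap (and, presumably, the Ascend-specific store setup) out of TokenizerManager.__init__ and next to the rest of the disaggregation plumbing.
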
@@ -347,6 +285,7 @@ class TokenizerManager:
         # Metrics
         if self.enable_metrics:
             self.metrics_collector = TokenizerMetricsCollector(
+                server_args=server_args,
                 labels={
                     "model_name": self.server_args.served_model_name,
                     # TODO: Add lora name/path in the future,
@@ -361,47 +300,6 @@ class TokenizerManager:
         if self.server_args.gc_warning_threshold_secs > 0.0:
             configure_gc_warning(self.server_args.gc_warning_threshold_secs)
 
-        # Communicators
-        self.init_weights_update_group_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.update_weights_from_distributed_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.update_weights_from_tensor_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.get_weights_by_name_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.release_memory_occupation_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.resume_memory_occupation_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.slow_down_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.flush_cache_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.profile_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.get_internal_state_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.set_internal_state_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.expert_distribution_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.update_lora_adapter_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-
         self._result_dispatcher = TypeBasedDispatcher(
             [
                 (
@@ -419,66 +317,16 @@ class TokenizerManager:
                     UpdateWeightFromDiskReqOutput,
                     self._handle_update_weights_from_disk_req_output,
                 ),
-                (
-                    InitWeightsUpdateGroupReqOutput,
-                    self.init_weights_update_group_communicator.handle_recv,
-                ),
-                (
-                    UpdateWeightsFromDistributedReqOutput,
-                    self.update_weights_from_distributed_communicator.handle_recv,
-                ),
-                (
-                    UpdateWeightsFromTensorReqOutput,
-                    self.update_weights_from_tensor_communicator.handle_recv,
-                ),
-                (
-                    GetWeightsByNameReqOutput,
-                    self.get_weights_by_name_communicator.handle_recv,
-                ),
-                (
-                    ReleaseMemoryOccupationReqOutput,
-                    self.release_memory_occupation_communicator.handle_recv,
-                ),
-                (
-                    ResumeMemoryOccupationReqOutput,
-                    self.resume_memory_occupation_communicator.handle_recv,
-                ),
-                (
-                    SlowDownReqOutput,
-                    self.slow_down_communicator.handle_recv,
-                ),
-                (
-                    FlushCacheReqOutput,
-                    self.flush_cache_communicator.handle_recv,
-                ),
-                (
-                    ProfileReqOutput,
-                    self.profile_communicator.handle_recv,
-                ),
                 (
                     FreezeGCReq,
                     lambda x: None,
                 ), # For handling case when scheduler skips detokenizer and forwards back to the tokenizer manager, we ignore it.
-                (
-                    GetInternalStateReqOutput,
-                    self.get_internal_state_communicator.handle_recv,
-                ),
-                (
-                    SetInternalStateReqOutput,
-                    self.set_internal_state_communicator.handle_recv,
-                ),
-                (
-                    ExpertDistributionReqOutput,
-                    self.expert_distribution_communicator.handle_recv,
-                ),
-                (
-                    LoRAUpdateResult,
-                    self.update_lora_adapter_communicator.handle_recv,
-                ),
                 (HealthCheckOutput, lambda x: None),
             ]
         )
 
+        self.init_communicators(server_args)
+
     async def generate_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
@@ -488,6 +336,15 @@ class TokenizerManager:
         self.auto_create_handle_loop()
         obj.normalize_batch_and_arguments()
 
+        if self.server_args.tokenizer_worker_num > 1:
+            # Modify rid, add worker_id
+            if isinstance(obj.rid, list):
+                # If it's an array, add worker_id prefix to each element
+                obj.rid = [f"{self.worker_id}_{rid}" for rid in obj.rid]
+            else:
+                # If it's a single value, add worker_id prefix
+                obj.rid = f"{self.worker_id}_{obj.rid}"
+
         if self.log_requests:
             max_length, skip_names, _ = self.log_request_metadata
             logger.info(
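
In multi-tokenizer mode each request id is prefixed with the worker id before it is sent to the shared scheduler, and get_origin_rid (added to the sglang.srt.utils imports above) strips it again when results come back, so clients still see the rid they submitted. A minimal round-trip sketch, under the assumption that the helper simply drops everything up to the first underscore (the real get_origin_rid is not shown in this diff):

# Round-trip sketch of the worker-id rid prefixing shown above.
# Assumption: get_origin_rid drops the "{worker_id}_" prefix; the actual
# sglang.srt.utils.get_origin_rid is not part of this excerpt.
def add_worker_prefix(rid: str, worker_id) -> str:
    return f"{worker_id}_{rid}"


def get_origin_rid_sketch(rid: str) -> str:
    return rid.split("_", 1)[-1]


rid = "8f2c1e0a-demo"
prefixed = add_worker_prefix(rid, worker_id=41235)
assert prefixed == "41235_8f2c1e0a-demo"
assert get_origin_rid_sketch(prefixed) == rid
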
@@ -985,9 +842,6 @@ class TokenizerManager:
         except StopAsyncIteration:
             pass
 
-    async def flush_cache(self) -> FlushCacheReqOutput:
-        return (await self.flush_cache_communicator(FlushCacheReqInput()))[0]
-
     def abort_request(self, rid: str = "", abort_all: bool = False):
         if not abort_all and rid not in self.rid_to_state:
             return
@@ -997,55 +851,6 @@ class TokenizerManager:
         if self.enable_metrics:
             self.metrics_collector.observe_one_aborted_request()
 
-    async def start_profile(
-        self,
-        output_dir: Optional[str] = None,
-        start_step: Optional[int] = None,
-        num_steps: Optional[int] = None,
-        activities: Optional[List[str]] = None,
-        with_stack: Optional[bool] = None,
-        record_shapes: Optional[bool] = None,
-        profile_by_stage: bool = False,
-    ):
-        self.auto_create_handle_loop()
-        env_with_stack: bool = get_bool_env_var("SGLANG_PROFILE_WITH_STACK", "true")
-        with_stack = False if with_stack is False or env_with_stack is False else True
-        req = ProfileReq(
-            type=ProfileReqType.START_PROFILE,
-            output_dir=output_dir,
-            start_step=start_step,
-            num_steps=num_steps,
-            activities=activities,
-            with_stack=with_stack,
-            record_shapes=record_shapes,
-            profile_by_stage=profile_by_stage,
-            profile_id=str(time.time()),
-        )
-        return await self._execute_profile(req)
-
-    async def stop_profile(self):
-        self.auto_create_handle_loop()
-        req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
-        return await self._execute_profile(req)
-
-    async def _execute_profile(self, req: ProfileReq):
-        result = (await self.profile_communicator(req))[0]
-        if not result.success:
-            raise RuntimeError(result.message)
-        return result
-
-    async def start_expert_distribution_record(self):
-        self.auto_create_handle_loop()
-        await self.expert_distribution_communicator(ExpertDistributionReq.START_RECORD)
-
-    async def stop_expert_distribution_record(self):
-        self.auto_create_handle_loop()
-        await self.expert_distribution_communicator(ExpertDistributionReq.STOP_RECORD)
-
-    async def dump_expert_distribution_record(self):
-        self.auto_create_handle_loop()
-        await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)
-
     async def pause_generation(self):
         async with self.is_pause_cond:
             self.is_pause = True
@@ -1080,6 +885,8 @@ class TokenizerManager:
     async def _wait_for_model_update_from_disk(
         self, obj: UpdateWeightFromDiskReqInput
     ) -> Tuple[bool, str]:
+        if self.server_args.tokenizer_worker_num > 1:
+            obj = MultiTokenizerWrapper(self.worker_id, obj)
         self.send_to_scheduler.send_pyobj(obj)
         self.model_update_result = asyncio.Future()
         if self.server_args.dp_size == 1:
@@ -1104,191 +911,6 @@ class TokenizerManager:
         all_paused_requests = [r.num_paused_requests for r in result]
         return all_success, all_message, all_paused_requests
 
-    async def init_weights_update_group(
-        self,
-        obj: InitWeightsUpdateGroupReqInput,
-        request: Optional[fastapi.Request] = None,
-    ) -> Tuple[bool, str]:
-        self.auto_create_handle_loop()
-        assert (
-            self.server_args.dp_size == 1
-        ), "dp_size must be 1 for init parameter update group"
-        result = (await self.init_weights_update_group_communicator(obj))[0]
-        return result.success, result.message
-
-    async def update_weights_from_distributed(
-        self,
-        obj: UpdateWeightsFromDistributedReqInput,
-        request: Optional[fastapi.Request] = None,
-    ) -> Tuple[bool, str]:
-        self.auto_create_handle_loop()
-        assert (
-            self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
-        ), "dp_size must be 1 or dp attention must be enabled for update weights from distributed"
-
-        if obj.abort_all_requests:
-            self.abort_request(abort_all=True)
-
-        # This means that weight sync
-        # cannot run while requests are in progress.
-        async with self.model_update_lock.writer_lock:
-            result = (await self.update_weights_from_distributed_communicator(obj))[0]
-            return result.success, result.message
-
-    async def update_weights_from_tensor(
-        self,
-        obj: UpdateWeightsFromTensorReqInput,
-        request: Optional[fastapi.Request] = None,
-    ) -> Tuple[bool, str]:
-        self.auto_create_handle_loop()
-        assert (
-            self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
-        ), "dp_size must be 1 or dp attention must be enabled for update weights from tensor"
-
-        if obj.abort_all_requests:
-            self.abort_request(abort_all=True)
-
-        # This means that weight sync
-        # cannot run while requests are in progress.
-        async with self.model_update_lock.writer_lock:
-            result = (await self.update_weights_from_tensor_communicator(obj))[0]
-            return result.success, result.message
-
-    async def load_lora_adapter(
-        self,
-        obj: LoadLoRAAdapterReqInput,
-        _: Optional[fastapi.Request] = None,
-    ) -> LoadLoRAAdapterReqOutput:
-        self.auto_create_handle_loop()
-
-        try:
-            if not self.server_args.enable_lora:
-                raise ValueError(
-                    "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
-                )
-
-            # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
-            # with dp_size > 1.
-            assert (
-                self.server_args.dp_size == 1
-            ), "dp_size must be 1 for dynamic lora loading"
-            logger.info(
-                "Start load Lora adapter. Lora name=%s, path=%s",
-                obj.lora_name,
-                obj.lora_path,
-            )
-
-            async with self.lora_update_lock:
-                if (
-                    self.server_args.max_loaded_loras is not None
-                    and self.lora_registry.num_registered_loras
-                    >= self.server_args.max_loaded_loras
-                ):
-                    raise ValueError(
-                        f"Cannot load LoRA adapter {obj.lora_name} at path {obj.lora_path}. "
-                        f"Maximum number of loaded LoRA adapters is {self.server_args.max_loaded_loras}. "
-                        "Please unload some LoRA adapters before loading new ones."
-                    )
-
-                # Generate new uniquely identifiable LoRARef object.
-                new_adapter = LoRARef(
-                    lora_name=obj.lora_name,
-                    lora_path=obj.lora_path,
-                    pinned=obj.pinned,
-                )
-
-                # Trigger the actual loading operation at the backend processes.
-                obj.lora_id = new_adapter.lora_id
-                result = (await self.update_lora_adapter_communicator(obj))[0]
-
-                # Register the LoRA adapter only after loading is successful.
-                if result.success:
-                    await self.lora_registry.register(new_adapter)
-
-                return result
-        except ValueError as e:
-            return LoadLoRAAdapterReqOutput(
-                success=False,
-                error_message=str(e),
-            )
-
-    async def unload_lora_adapter(
-        self,
-        obj: UnloadLoRAAdapterReqInput,
-        _: Optional[fastapi.Request] = None,
-    ) -> UnloadLoRAAdapterReqOutput:
-        self.auto_create_handle_loop()
-
-        try:
-            if not self.server_args.enable_lora:
-                raise ValueError(
-                    "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
-                )
-
-            assert (
-                obj.lora_name is not None
-            ), "lora_name must be provided to unload LoRA adapter"
-
-            # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
-            # with dp_size > 1.
-            assert (
-                self.server_args.dp_size == 1
-            ), "dp_size must be 1 for dynamic lora loading"
-            logger.info(
-                "Start unload Lora adapter. Lora name=%s",
-                obj.lora_name,
-            )
-
-            async with self.lora_update_lock:
-                # Unregister the LoRA adapter from the registry to stop new requests for this adapter
-                # from being started.
-                lora_id = await self.lora_registry.unregister(obj.lora_name)
-                obj.lora_id = lora_id
-
-                # Initiate the actual unloading operation at the backend processes only after all
-                # ongoing requests using this LoRA adapter are finished.
-                await self.lora_registry.wait_for_unload(lora_id)
-                result = (await self.update_lora_adapter_communicator(obj))[0]
-
-                return result
-        except ValueError as e:
-            return UnloadLoRAAdapterReqOutput(success=False, error_message=str(e))
-
-    async def get_weights_by_name(
-        self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
-    ):
-        self.auto_create_handle_loop()
-        results = await self.get_weights_by_name_communicator(obj)
-        all_parameters = [r.parameter for r in results]
-        if self.server_args.dp_size == 1:
-            return all_parameters[0]
-        else:
-            return all_parameters
-
-    async def release_memory_occupation(
-        self,
-        obj: ReleaseMemoryOccupationReqInput,
-        request: Optional[fastapi.Request] = None,
-    ):
-        self.auto_create_handle_loop()
-        await self.release_memory_occupation_communicator(obj)
-
-    async def resume_memory_occupation(
-        self,
-        obj: ResumeMemoryOccupationReqInput,
-        request: Optional[fastapi.Request] = None,
-    ):
-        self.auto_create_handle_loop()
-        await self.resume_memory_occupation_communicator(obj)
-
-    async def slow_down(
-        self,
-        obj: SlowDownReqInput,
-        request: Optional[fastapi.Request] = None,
-    ):
-        self.auto_create_handle_loop()
-        await self.slow_down_communicator(obj)
-
     async def open_session(
         self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
     ):
@@ -1299,6 +921,8 @@ class TokenizerManager:
         elif obj.session_id in self.session_futures:
             return None
 
+        if self.server_args.tokenizer_worker_num > 1:
+            obj = MultiTokenizerWrapper(self.worker_id, obj)
         self.send_to_scheduler.send_pyobj(obj)
 
         self.session_futures[obj.session_id] = asyncio.Future()
@@ -1311,30 +935,6 @@ class TokenizerManager:
     ):
         await self.send_to_scheduler.send_pyobj(obj)
 
-    async def get_internal_state(self) -> List[Dict[Any, Any]]:
-        req = GetInternalStateReq()
-        responses: List[GetInternalStateReqOutput] = (
-            await self.get_internal_state_communicator(req)
-        )
-        # Many DP ranks
-        return [res.internal_state for res in responses]
-
-    async def set_internal_state(
-        self, obj: SetInternalStateReq
-    ) -> SetInternalStateReqOutput:
-        responses: List[SetInternalStateReqOutput] = (
-            await self.set_internal_state_communicator(obj)
-        )
-        return [res.internal_state for res in responses]
-
-    async def get_load(self) -> dict:
-        # TODO(lsyin): fake load report server
-        if not self.current_load_lock.locked():
-            async with self.current_load_lock:
-                internal_state = await self.get_internal_state()
-                self.current_load = internal_state[0]["load"]
-        return {"load": self.current_load}
-
     def get_log_request_metadata(self):
         max_length = None
         skip_names = None
@@ -1576,7 +1176,6 @@ class TokenizerManager:
 
     async def handle_loop(self):
         """The event loop that handles requests"""
-
         while True:
             recv_obj = await self.recv_from_detokenizer.recv_pyobj()
             self._result_dispatcher(recv_obj)
@@ -1596,9 +1195,12 @@ class TokenizerManager:
                 )
                 continue
 
+            origin_rid = rid
+            if self.server_args.tokenizer_worker_num > 1:
+                origin_rid = get_origin_rid(rid)
             # Build meta_info and return value
             meta_info = {
-                "id": rid,
+                "id": origin_rid,
                 "finish_reason": recv_obj.finished_reasons[i],
                 "prompt_tokens": recv_obj.prompt_tokens[i],
                 "weight_version": self.server_args.weight_version,
@@ -1904,6 +1506,9 @@ class TokenizerManager:
         if is_health_check_generate_req(recv_obj):
             return
         state = self.rid_to_state[recv_obj.rid]
+        origin_rid = recv_obj.rid
+        if self.server_args.tokenizer_worker_num > 1:
+            origin_rid = get_origin_rid(origin_rid)
         state.finished = True
         if recv_obj.finished_reason:
             out = {
@@ -1916,7 +1521,7 @@ class TokenizerManager:
             out = {
                 "text": "",
                 "meta_info": {
-                    "id": recv_obj.rid,
+                    "id": origin_rid,
                     "finish_reason": {
                         "type": "abort",
                         "message": "Abort before prefill",
@@ -2096,47 +1701,6 @@ class SignalHandler:
         kill_process_tree(os.getpid())
 
 
-T = TypeVar("T")
-
-
-class _Communicator(Generic[T]):
-    """Note: The communicator now only run up to 1 in-flight request at any time."""
-
-    def __init__(self, sender, fan_out: int):
-        self._sender = sender
-        self._fan_out = fan_out
-        self._result_event: Optional[asyncio.Event] = None
-        self._result_values: Optional[List[T]] = None
-        self._ready_queue: Deque[asyncio.Future] = deque()
-
-    async def __call__(self, obj):
-        ready_event = asyncio.Event()
-        if self._result_event is not None or len(self._ready_queue) > 0:
-            self._ready_queue.append(ready_event)
-            await ready_event.wait()
-            assert self._result_event is None
-            assert self._result_values is None
-
-        if obj:
-            self._sender.send_pyobj(obj)
-
-        self._result_event = asyncio.Event()
-        self._result_values = []
-        await self._result_event.wait()
-        result_values = self._result_values
-        self._result_event = self._result_values = None
-
-        if len(self._ready_queue) > 0:
-            self._ready_queue.popleft().set()
-
-        return result_values
-
-    def handle_recv(self, recv_obj: T):
-        self._result_values.append(recv_obj)
-        if len(self._result_values) == self._fan_out:
-            self._result_event.set()
-
-
 # Note: request abort handling logic
 # We should handle all of the following cases correctly.
 #
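
The _Communicator removed above is the fan-out/fan-in primitive behind the communicator-based methods that moved into TokenizerCommunicatorMixin: each call sends one request object toward the scheduler(s) and then waits until handle_recv has collected fan_out replies (one per data-parallel rank), allowing at most one in-flight request at a time. A self-contained toy that exercises that contract (MiniCommunicator mirrors the deleted class; DummySender and the driver are illustrative assumptions, not sglang code):

import asyncio
from collections import deque
from typing import Deque, List, Optional


class MiniCommunicator:
    """Fan-out/fan-in helper: one in-flight call, resolved after fan_out replies."""

    def __init__(self, sender, fan_out: int):
        self._sender = sender
        self._fan_out = fan_out
        self._result_event: Optional[asyncio.Event] = None
        self._result_values: Optional[List[object]] = None
        self._ready_queue: Deque[asyncio.Event] = deque()

    async def __call__(self, obj):
        ready_event = asyncio.Event()
        if self._result_event is not None or self._ready_queue:
            # Another call is in flight; wait for our turn.
            self._ready_queue.append(ready_event)
            await ready_event.wait()
        if obj is not None:
            self._sender.send_pyobj(obj)
        self._result_event = asyncio.Event()
        self._result_values = []
        await self._result_event.wait()  # set by handle_recv after fan_out replies
        values = self._result_values
        self._result_event = self._result_values = None
        if self._ready_queue:
            self._ready_queue.popleft().set()
        return values

    def handle_recv(self, recv_obj):
        self._result_values.append(recv_obj)
        if len(self._result_values) == self._fan_out:
            self._result_event.set()


class DummySender:
    def send_pyobj(self, obj):
        print("sent to scheduler:", obj)


async def main():
    comm = MiniCommunicator(DummySender(), fan_out=2)
    call = asyncio.create_task(comm("flush_cache"))
    await asyncio.sleep(0)  # let the call send its request and start waiting
    comm.handle_recv("ok from dp rank 0")  # in sglang these arrive via the result dispatcher
    comm.handle_recv("ok from dp rank 1")
    print(await call)  # ['ok from dp rank 0', 'ok from dp rank 1']


asyncio.run(main())

The fan_out == dp_size contract is also why the removed TokenizerManager wrappers indexed the returned list with [0] when only a single reply is expected.
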