sglang 0.5.2rc1__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (265)
  1. sglang/bench_one_batch_server.py +10 -1
  2. sglang/bench_serving.py +257 -29
  3. sglang/lang/interpreter.py +1 -1
  4. sglang/srt/configs/__init__.py +4 -0
  5. sglang/srt/configs/device_config.py +3 -1
  6. sglang/srt/configs/dots_vlm.py +139 -0
  7. sglang/srt/configs/internvl.py +6 -0
  8. sglang/srt/configs/load_config.py +1 -0
  9. sglang/srt/configs/model_config.py +50 -6
  10. sglang/srt/configs/qwen3_next.py +326 -0
  11. sglang/srt/connector/__init__.py +8 -1
  12. sglang/srt/connector/remote_instance.py +82 -0
  13. sglang/srt/constrained/base_grammar_backend.py +48 -12
  14. sglang/srt/constrained/llguidance_backend.py +0 -1
  15. sglang/srt/constrained/outlines_backend.py +0 -1
  16. sglang/srt/constrained/xgrammar_backend.py +28 -9
  17. sglang/srt/custom_op.py +11 -1
  18. sglang/srt/debug_utils/dump_comparator.py +81 -44
  19. sglang/srt/debug_utils/dump_loader.py +97 -0
  20. sglang/srt/debug_utils/dumper.py +11 -3
  21. sglang/srt/debug_utils/text_comparator.py +73 -11
  22. sglang/srt/disaggregation/base/conn.py +1 -1
  23. sglang/srt/disaggregation/common/conn.py +15 -12
  24. sglang/srt/disaggregation/decode.py +21 -10
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
  26. sglang/srt/disaggregation/fake/conn.py +1 -1
  27. sglang/srt/disaggregation/mini_lb.py +6 -445
  28. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  29. sglang/srt/disaggregation/nixl/conn.py +180 -16
  30. sglang/srt/disaggregation/prefill.py +5 -3
  31. sglang/srt/disaggregation/utils.py +5 -50
  32. sglang/srt/distributed/parallel_state.py +67 -43
  33. sglang/srt/entrypoints/engine.py +38 -17
  34. sglang/srt/entrypoints/grpc_request_manager.py +580 -0
  35. sglang/srt/entrypoints/grpc_server.py +680 -0
  36. sglang/srt/entrypoints/http_server.py +88 -53
  37. sglang/srt/entrypoints/openai/protocol.py +7 -4
  38. sglang/srt/entrypoints/openai/serving_base.py +46 -3
  39. sglang/srt/entrypoints/openai/serving_chat.py +39 -19
  40. sglang/srt/entrypoints/openai/serving_completions.py +15 -4
  41. sglang/srt/entrypoints/openai/serving_embedding.py +9 -4
  42. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  43. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  44. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  45. sglang/srt/eplb/eplb_manager.py +2 -2
  46. sglang/srt/eplb/expert_distribution.py +26 -13
  47. sglang/srt/eplb/expert_location.py +8 -3
  48. sglang/srt/eplb/expert_location_updater.py +1 -1
  49. sglang/srt/function_call/base_format_detector.py +3 -6
  50. sglang/srt/function_call/ebnf_composer.py +11 -9
  51. sglang/srt/function_call/function_call_parser.py +6 -0
  52. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  53. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  54. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  55. sglang/srt/grpc/__init__.py +1 -0
  56. sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
  57. sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
  58. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
  59. sglang/srt/hf_transformers_utils.py +4 -0
  60. sglang/srt/layers/activation.py +142 -9
  61. sglang/srt/layers/attention/aiter_backend.py +93 -68
  62. sglang/srt/layers/attention/ascend_backend.py +11 -4
  63. sglang/srt/layers/attention/fla/chunk.py +242 -0
  64. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  65. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  66. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  67. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  68. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  69. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  70. sglang/srt/layers/attention/fla/index.py +37 -0
  71. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  72. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  73. sglang/srt/layers/attention/fla/op.py +66 -0
  74. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  75. sglang/srt/layers/attention/fla/utils.py +331 -0
  76. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  77. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  78. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  79. sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
  80. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  81. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  82. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  83. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  84. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  85. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  86. sglang/srt/layers/attention/triton_backend.py +18 -1
  87. sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
  88. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  89. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  90. sglang/srt/layers/communicator.py +45 -7
  91. sglang/srt/layers/dp_attention.py +30 -1
  92. sglang/srt/layers/layernorm.py +32 -15
  93. sglang/srt/layers/linear.py +34 -3
  94. sglang/srt/layers/logits_processor.py +29 -10
  95. sglang/srt/layers/moe/__init__.py +2 -1
  96. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  97. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  98. sglang/srt/layers/moe/ep_moe/layer.py +182 -62
  99. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
  100. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  101. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  102. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  103. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  104. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
  105. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  106. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  107. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  108. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  113. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
  114. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  115. sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
  116. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  117. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  118. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  119. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  120. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  121. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  122. sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
  123. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  124. sglang/srt/layers/moe/topk.py +30 -9
  125. sglang/srt/layers/moe/utils.py +12 -7
  126. sglang/srt/layers/quantization/awq.py +19 -7
  127. sglang/srt/layers/quantization/base_config.py +11 -6
  128. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  129. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  130. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  131. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  132. sglang/srt/layers/quantization/fp8.py +76 -47
  133. sglang/srt/layers/quantization/fp8_utils.py +50 -31
  134. sglang/srt/layers/quantization/gptq.py +25 -17
  135. sglang/srt/layers/quantization/modelopt_quant.py +182 -49
  136. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  137. sglang/srt/layers/quantization/mxfp4.py +68 -41
  138. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  139. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  140. sglang/srt/layers/quantization/quark/utils.py +97 -0
  141. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  142. sglang/srt/layers/quantization/unquant.py +135 -47
  143. sglang/srt/layers/quantization/w4afp8.py +30 -17
  144. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  145. sglang/srt/layers/quantization/w8a8_int8.py +76 -38
  146. sglang/srt/layers/rocm_linear_utils.py +44 -0
  147. sglang/srt/layers/rotary_embedding.py +0 -18
  148. sglang/srt/layers/sampler.py +162 -18
  149. sglang/srt/lora/backend/base_backend.py +50 -8
  150. sglang/srt/lora/backend/triton_backend.py +90 -2
  151. sglang/srt/lora/layers.py +32 -0
  152. sglang/srt/lora/lora.py +4 -1
  153. sglang/srt/lora/lora_manager.py +35 -112
  154. sglang/srt/lora/mem_pool.py +24 -10
  155. sglang/srt/lora/utils.py +18 -9
  156. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  157. sglang/srt/managers/cache_controller.py +200 -199
  158. sglang/srt/managers/data_parallel_controller.py +105 -35
  159. sglang/srt/managers/detokenizer_manager.py +8 -4
  160. sglang/srt/managers/disagg_service.py +46 -0
  161. sglang/srt/managers/io_struct.py +199 -12
  162. sglang/srt/managers/mm_utils.py +1 -0
  163. sglang/srt/managers/multi_tokenizer_mixin.py +351 -397
  164. sglang/srt/managers/schedule_batch.py +77 -56
  165. sglang/srt/managers/schedule_policy.py +4 -3
  166. sglang/srt/managers/scheduler.py +191 -139
  167. sglang/srt/managers/scheduler_metrics_mixin.py +116 -9
  168. sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
  169. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  170. sglang/srt/managers/template_manager.py +3 -3
  171. sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
  172. sglang/srt/managers/tokenizer_manager.py +260 -519
  173. sglang/srt/managers/tp_worker.py +53 -4
  174. sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
  175. sglang/srt/mem_cache/allocator.py +1 -1
  176. sglang/srt/mem_cache/hicache_storage.py +18 -33
  177. sglang/srt/mem_cache/hiradix_cache.py +108 -48
  178. sglang/srt/mem_cache/memory_pool.py +347 -48
  179. sglang/srt/mem_cache/memory_pool_host.py +121 -57
  180. sglang/srt/mem_cache/radix_cache.py +0 -2
  181. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  182. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  183. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +95 -5
  184. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  185. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  186. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +81 -20
  187. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  188. sglang/srt/mem_cache/swa_radix_cache.py +0 -2
  189. sglang/srt/metrics/collector.py +502 -77
  190. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  191. sglang/srt/metrics/utils.py +48 -0
  192. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  193. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  194. sglang/srt/model_executor/forward_batch_info.py +75 -19
  195. sglang/srt/model_executor/model_runner.py +357 -30
  196. sglang/srt/model_loader/__init__.py +9 -3
  197. sglang/srt/model_loader/loader.py +128 -4
  198. sglang/srt/model_loader/weight_utils.py +2 -1
  199. sglang/srt/models/apertus.py +686 -0
  200. sglang/srt/models/bailing_moe.py +798 -218
  201. sglang/srt/models/bailing_moe_nextn.py +168 -0
  202. sglang/srt/models/deepseek_v2.py +346 -48
  203. sglang/srt/models/dots_vlm.py +174 -0
  204. sglang/srt/models/dots_vlm_vit.py +337 -0
  205. sglang/srt/models/ernie4.py +1 -1
  206. sglang/srt/models/gemma3n_mm.py +1 -1
  207. sglang/srt/models/glm4_moe.py +11 -2
  208. sglang/srt/models/glm4v.py +4 -2
  209. sglang/srt/models/glm4v_moe.py +3 -0
  210. sglang/srt/models/gpt_oss.py +1 -1
  211. sglang/srt/models/internvl.py +28 -0
  212. sglang/srt/models/llama4.py +9 -0
  213. sglang/srt/models/llama_eagle3.py +13 -0
  214. sglang/srt/models/longcat_flash.py +2 -2
  215. sglang/srt/models/minicpmv.py +165 -3
  216. sglang/srt/models/mllama4.py +25 -0
  217. sglang/srt/models/opt.py +637 -0
  218. sglang/srt/models/qwen2.py +7 -0
  219. sglang/srt/models/qwen2_5_vl.py +27 -3
  220. sglang/srt/models/qwen2_moe.py +60 -13
  221. sglang/srt/models/qwen3.py +8 -2
  222. sglang/srt/models/qwen3_moe.py +40 -9
  223. sglang/srt/models/qwen3_next.py +1042 -0
  224. sglang/srt/models/qwen3_next_mtp.py +112 -0
  225. sglang/srt/models/step3_vl.py +1 -1
  226. sglang/srt/models/torch_native_llama.py +1 -1
  227. sglang/srt/multimodal/processors/dots_vlm.py +99 -0
  228. sglang/srt/multimodal/processors/glm4v.py +9 -9
  229. sglang/srt/multimodal/processors/internvl.py +141 -129
  230. sglang/srt/multimodal/processors/qwen_vl.py +15 -5
  231. sglang/srt/offloader.py +27 -3
  232. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  233. sglang/srt/remote_instance_weight_loader_utils.py +69 -0
  234. sglang/srt/sampling/sampling_batch_info.py +18 -15
  235. sglang/srt/server_args.py +355 -37
  236. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  237. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  238. sglang/srt/speculative/eagle_utils.py +0 -2
  239. sglang/srt/speculative/eagle_worker.py +197 -112
  240. sglang/srt/speculative/spec_info.py +5 -0
  241. sglang/srt/speculative/standalone_worker.py +109 -0
  242. sglang/srt/tracing/trace.py +552 -0
  243. sglang/srt/utils.py +46 -3
  244. sglang/srt/weight_sync/utils.py +1 -1
  245. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  246. sglang/test/few_shot_gsm8k.py +1 -0
  247. sglang/test/runners.py +4 -0
  248. sglang/test/test_cutlass_moe.py +24 -6
  249. sglang/test/test_disaggregation_utils.py +66 -0
  250. sglang/test/test_fp4_moe.py +370 -1
  251. sglang/test/test_utils.py +28 -1
  252. sglang/utils.py +12 -0
  253. sglang/version.py +1 -1
  254. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
  255. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +263 -200
  256. sglang/srt/disaggregation/launch_lb.py +0 -118
  257. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  258. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  259. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  260. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  261. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  262. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  263. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
  264. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
  265. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
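The inline diff reproduced below is the per-line view for entry 172 above, sglang/srt/managers/tokenizer_manager.py (+260 -519):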
@@ -31,18 +31,7 @@ from contextlib import nullcontext
 from datetime import datetime
 from enum import Enum
 from http import HTTPStatus
-from typing import (
-    Any,
-    Awaitable,
-    Deque,
-    Dict,
-    Generic,
-    List,
-    Optional,
-    Tuple,
-    TypeVar,
-    Union,
-)
+from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union

 import fastapi
 import torch
@@ -53,18 +42,15 @@ from fastapi import BackgroundTasks

 from sglang.srt.aio_rwlock import RWLock
 from sglang.srt.configs.model_config import ModelConfig
-from sglang.srt.disaggregation.utils import (
-    DisaggregationMode,
-    KVClassType,
-    TransferBackend,
-    get_kv_class,
-)
+from sglang.srt.disaggregation.utils import DisaggregationMode
 from sglang.srt.hf_transformers_utils import (
     get_processor,
     get_tokenizer,
     get_tokenizer_from_processor,
 )
 from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry
+from sglang.srt.managers.async_dynamic_batch_tokenizer import AsyncDynamicbatchTokenizer
+from sglang.srt.managers.disagg_service import start_disagg_service
 from sglang.srt.managers.io_struct import (
     AbortReq,
     BatchEmbeddingOut,
@@ -73,60 +59,38 @@ from sglang.srt.managers.io_struct import (
     BatchTokenIDOut,
     BatchTokenizedEmbeddingReqInput,
     BatchTokenizedGenerateReqInput,
-    ClearHiCacheReqInput,
-    ClearHiCacheReqOutput,
     CloseSessionReqInput,
     ConfigureLoggingReq,
     EmbeddingReqInput,
-    ExpertDistributionReq,
-    ExpertDistributionReqOutput,
-    FlushCacheReqInput,
-    FlushCacheReqOutput,
     FreezeGCReq,
     GenerateReqInput,
-    GetInternalStateReq,
-    GetInternalStateReqOutput,
-    GetWeightsByNameReqInput,
-    GetWeightsByNameReqOutput,
+    GetLoadReqInput,
     HealthCheckOutput,
-    InitWeightsUpdateGroupReqInput,
-    InitWeightsUpdateGroupReqOutput,
-    LoadLoRAAdapterReqInput,
-    LoadLoRAAdapterReqOutput,
-    LoRAUpdateResult,
-    MultiTokenizerWarpper,
+    MultiTokenizerWrapper,
     OpenSessionReqInput,
     OpenSessionReqOutput,
-    ProfileReq,
-    ProfileReqOutput,
-    ProfileReqType,
-    ReleaseMemoryOccupationReqInput,
-    ReleaseMemoryOccupationReqOutput,
-    ResumeMemoryOccupationReqInput,
-    ResumeMemoryOccupationReqOutput,
     SessionParams,
-    SetInternalStateReq,
-    SetInternalStateReqOutput,
-    SlowDownReqInput,
-    SlowDownReqOutput,
     TokenizedEmbeddingReqInput,
     TokenizedGenerateReqInput,
-    UnloadLoRAAdapterReqInput,
-    UnloadLoRAAdapterReqOutput,
     UpdateWeightFromDiskReqInput,
     UpdateWeightFromDiskReqOutput,
-    UpdateWeightsFromDistributedReqInput,
-    UpdateWeightsFromDistributedReqOutput,
-    UpdateWeightsFromTensorReqInput,
-    UpdateWeightsFromTensorReqOutput,
+    WatchLoadUpdateReq,
 )
 from sglang.srt.managers.mm_utils import TensorTransportMode
 from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
 from sglang.srt.managers.scheduler import is_health_check_generate_req
 from sglang.srt.managers.scheduler_input_blocker import input_blocker_guard_region
+from sglang.srt.managers.tokenizer_communicator_mixin import TokenizerCommunicatorMixin
 from sglang.srt.metrics.collector import TokenizerMetricsCollector
 from sglang.srt.sampling.sampling_params import SamplingParams
 from sglang.srt.server_args import PortArgs, ServerArgs
+from sglang.srt.tracing.trace import (
+    trace_get_proc_propagate_context,
+    trace_req_finish,
+    trace_req_start,
+    trace_slice_end,
+    trace_slice_start,
+)
 from sglang.srt.utils import (
     configure_gc_warning,
     dataclass_to_string_truncated,
@@ -180,7 +144,7 @@ class ReqState:
     output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)


-class TokenizerManager:
+class TokenizerManager(TokenizerCommunicatorMixin):
     """TokenizerManager is a process that tokenizes the text."""

     def __init__(
@@ -262,6 +226,18 @@ class TokenizerManager:
                     trust_remote_code=server_args.trust_remote_code,
                     revision=server_args.revision,
                 )
+        # Initialize async dynamic batch tokenizer if enabled (common for both multimodal and non-multimodal)
+        if (
+            server_args.enable_dynamic_batch_tokenizer
+            and not server_args.skip_tokenizer_init
+        ):
+            self.async_dynamic_batch_tokenizer = AsyncDynamicbatchTokenizer(
+                self.tokenizer,
+                max_batch_size=server_args.dynamic_batch_tokenizer_batch_size,
+                batch_wait_timeout_s=server_args.dynamic_batch_tokenizer_batch_timeout,
+            )
+        else:
+            self.async_dynamic_batch_tokenizer = None

         # Init inter-process communication
         context = zmq.asyncio.Context(2)
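Note: the hunk above wires in the new AsyncDynamicbatchTokenizer (entry 156 in the file list), which coalesces concurrent single-text encode() calls into one batched tokenizer call, bounded by the dynamic_batch_tokenizer_batch_size and dynamic_batch_tokenizer_batch_timeout server args seen here. A minimal self-contained sketch of that dynamic-batching idea, with invented names (DynamicBatchTokenizerSketch, tokenize_batch) rather than the actual sglang implementation:

import asyncio
from typing import Any, Callable, List, Tuple


class DynamicBatchTokenizerSketch:
    """Coalesce concurrent encode() calls into batched tokenizer calls."""

    def __init__(
        self,
        tokenize_batch: Callable[[List[str]], List[Any]],
        max_batch_size: int = 32,
        batch_wait_timeout_s: float = 0.002,
    ) -> None:
        self.tokenize_batch = tokenize_batch
        self.max_batch_size = max_batch_size
        self.batch_wait_timeout_s = batch_wait_timeout_s
        self.queue: "asyncio.Queue[Tuple[str, asyncio.Future]]" = asyncio.Queue()
        self._worker = None

    async def encode(self, text: str) -> Any:
        if self._worker is None:  # start the batching worker lazily
            self._worker = asyncio.create_task(self._run())
        fut = asyncio.get_running_loop().create_future()
        await self.queue.put((text, fut))
        return await fut

    async def _run(self) -> None:
        while True:
            batch = [await self.queue.get()]  # block for the first request
            try:
                # then wait briefly so more concurrent requests can join the batch
                while len(batch) < self.max_batch_size:
                    batch.append(
                        await asyncio.wait_for(
                            self.queue.get(), timeout=self.batch_wait_timeout_s
                        )
                    )
            except asyncio.TimeoutError:
                pass
            results = self.tokenize_batch([text for text, _ in batch])
            for (_, fut), result in zip(batch, results):
                fut.set_result(result)


async def main() -> None:
    tok = DynamicBatchTokenizerSketch(lambda texts: [t.split() for t in texts])
    print(await asyncio.gather(tok.encode("a b"), tok.encode("c")))

asyncio.run(main())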
@@ -319,8 +295,10 @@ class TokenizerManager:
         # LoRA updates and inference to overlap.
         self.lora_update_lock = asyncio.Lock()

-        # For PD disaggregtion
-        self.init_disaggregation()
+        self.disaggregation_mode = DisaggregationMode(
+            self.server_args.disaggregation_mode
+        )
+        self.bootstrap_server = start_disagg_service(self.server_args)

         # For load balancing
         self.current_load = 0
@@ -328,11 +306,16 @@ class TokenizerManager:

         # Metrics
         if self.enable_metrics:
+            labels = {
+                "model_name": self.server_args.served_model_name,
+                # TODO: Add lora name/path in the future,
+            }
+            if server_args.tokenizer_metrics_allowed_customer_labels:
+                for label in server_args.tokenizer_metrics_allowed_customer_labels:
+                    labels[label] = ""
             self.metrics_collector = TokenizerMetricsCollector(
-                labels={
-                    "model_name": self.server_args.served_model_name,
-                    # TODO: Add lora name/path in the future,
-                },
+                server_args=server_args,
+                labels=labels,
                 bucket_time_to_first_token=self.server_args.bucket_time_to_first_token,
                 bucket_e2e_request_latency=self.server_args.bucket_e2e_request_latency,
                 bucket_inter_token_latency=self.server_args.bucket_inter_token_latency,
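Note: with tokenizer_metrics_allowed_customer_labels set, the allowed label keys are registered up front with empty-string defaults so the Prometheus metric schema stays fixed; per-request values are merged in later (see the observe_* call sites further down in this diff). A small illustration of that merge, with invented values:

# Base labels plus pre-registered customer label keys (values default to "").
labels = {"model_name": "my-model", "customer_id": ""}

# Per request, customer-supplied labels override the defaults:
customer_labels = {"customer_id": "acme"}
merged = {**labels, **customer_labels} if customer_labels else labels
print(merged)  # {'model_name': 'my-model', 'customer_id': 'acme'}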
@@ -343,50 +326,6 @@ class TokenizerManager:
         if self.server_args.gc_warning_threshold_secs > 0.0:
             configure_gc_warning(self.server_args.gc_warning_threshold_secs)

-        # Communicators
-        self.init_weights_update_group_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.update_weights_from_distributed_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.update_weights_from_tensor_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.get_weights_by_name_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.release_memory_occupation_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.resume_memory_occupation_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.slow_down_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.flush_cache_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.clear_hicache_storage_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.profile_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.get_internal_state_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.set_internal_state_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.expert_distribution_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-        self.update_lora_adapter_communicator = _Communicator(
-            self.send_to_scheduler, server_args.dp_size
-        )
-
         self._result_dispatcher = TypeBasedDispatcher(
             [
                 (
@@ -404,100 +343,15 @@ class TokenizerManager:
                     UpdateWeightFromDiskReqOutput,
                     self._handle_update_weights_from_disk_req_output,
                 ),
-                (
-                    InitWeightsUpdateGroupReqOutput,
-                    self.init_weights_update_group_communicator.handle_recv,
-                ),
-                (
-                    UpdateWeightsFromDistributedReqOutput,
-                    self.update_weights_from_distributed_communicator.handle_recv,
-                ),
-                (
-                    UpdateWeightsFromTensorReqOutput,
-                    self.update_weights_from_tensor_communicator.handle_recv,
-                ),
-                (
-                    GetWeightsByNameReqOutput,
-                    self.get_weights_by_name_communicator.handle_recv,
-                ),
-                (
-                    ReleaseMemoryOccupationReqOutput,
-                    self.release_memory_occupation_communicator.handle_recv,
-                ),
-                (
-                    ResumeMemoryOccupationReqOutput,
-                    self.resume_memory_occupation_communicator.handle_recv,
-                ),
-                (
-                    SlowDownReqOutput,
-                    self.slow_down_communicator.handle_recv,
-                ),
-                (
-                    ClearHiCacheReqOutput,
-                    self.clear_hicache_storage_communicator.handle_recv,
-                ),
-                (
-                    FlushCacheReqOutput,
-                    self.flush_cache_communicator.handle_recv,
-                ),
-                (
-                    ProfileReqOutput,
-                    self.profile_communicator.handle_recv,
-                ),
                 (
                     FreezeGCReq,
                     lambda x: None,
                 ),  # For handling case when scheduler skips detokenizer and forwards back to the tokenizer manager, we ignore it.
-                (
-                    GetInternalStateReqOutput,
-                    self.get_internal_state_communicator.handle_recv,
-                ),
-                (
-                    SetInternalStateReqOutput,
-                    self.set_internal_state_communicator.handle_recv,
-                ),
-                (
-                    ExpertDistributionReqOutput,
-                    self.expert_distribution_communicator.handle_recv,
-                ),
-                (
-                    LoRAUpdateResult,
-                    self.update_lora_adapter_communicator.handle_recv,
-                ),
                 (HealthCheckOutput, lambda x: None),
             ]
         )

-    def init_disaggregation(self):
-        self.disaggregation_mode = DisaggregationMode(
-            self.server_args.disaggregation_mode
-        )
-        self.disaggregation_transfer_backend = TransferBackend(
-            self.server_args.disaggregation_transfer_backend
-        )
-        # Start kv boostrap server on prefill
-        if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            # only start bootstrap server on prefill tm
-            kv_bootstrap_server_class = get_kv_class(
-                self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
-            )
-            self.bootstrap_server = kv_bootstrap_server_class(
-                self.server_args.disaggregation_bootstrap_port
-            )
-            is_create_store = (
-                self.server_args.node_rank == 0
-                and self.server_args.disaggregation_transfer_backend == "ascend"
-            )
-            if is_create_store:
-                try:
-                    from mf_adapter import create_config_store
-
-                    ascend_url = os.getenv("ASCEND_MF_STORE_URL")
-                    create_config_store(ascend_url)
-                except Exception as e:
-                    error_message = f"Failed create mf store, invalid ascend_url."
-                    error_message += f" With exception {e}"
-                    raise error_message
+        self.init_communicators(server_args)

     async def generate_request(
         self,
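Note: the fourteen hand-rolled _Communicator fields and their dispatcher entries collapse into a single init_communicators() call provided by the new TokenizerCommunicatorMixin (entry 171 in the file list), which also hosts the public helpers (flush_cache, start_profile, get_internal_state, ...) deleted from this file further down. A hedged toy of the refactor shape, not the real mixin API:

import asyncio


class TokenizerCommunicatorMixinSketch:
    def init_communicators(self, dp_size: int) -> None:
        # the real mixin builds one fan-out communicator per request type here
        self.dp_size = dp_size

    async def flush_cache(self) -> str:
        # stands in for: (await self.flush_cache_communicator(FlushCacheReqInput()))[0]
        return f"flush fanned out to {self.dp_size} scheduler(s)"


class TokenizerManagerSketch(TokenizerCommunicatorMixinSketch):
    def __init__(self, dp_size: int = 2) -> None:
        self.init_communicators(dp_size)


print(asyncio.run(TokenizerManagerSketch().flush_cache()))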
@@ -517,6 +371,24 @@ class TokenizerManager:
                 # If it's a single value, add worker_id prefix
                 obj.rid = f"{self.worker_id}_{obj.rid}"

+        if obj.is_single:
+            bootstrap_room = (
+                obj.bootstrap_room if hasattr(obj, "bootstrap_room") else None
+            )
+            trace_req_start(obj.rid, bootstrap_room, ts=int(created_time * 1e9))
+            trace_slice_start("", obj.rid, ts=int(created_time * 1e9), anonymous=True)
+        else:
+            for i in range(len(obj.rid)):
+                bootstrap_room = (
+                    obj.bootstrap_room[i]
+                    if hasattr(obj, "bootstrap_room") and obj.bootstrap_room
+                    else None
+                )
+                trace_req_start(obj.rid[i], bootstrap_room, ts=int(created_time * 1e9))
+                trace_slice_start(
+                    "", obj.rid[i], ts=int(created_time * 1e9), anonymous=True
+                )
+
         if self.log_requests:
             max_length, skip_names, _ = self.log_request_metadata
             logger.info(
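Note: the trace_* calls above take integer nanosecond timestamps, hence the int(created_time * 1e9) conversion from the float seconds returned by time.time(). For example:

import time

created_time = time.time()       # float seconds, e.g. 1726000000.123456
ts_ns = int(created_time * 1e9)  # int nanoseconds, e.g. 1726000000123456000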
@@ -542,6 +414,144 @@ class TokenizerManager:
                 ):
                     yield response

+    def _detect_input_format(
+        self, texts: Union[str, List[str]], is_cross_encoder: bool
+    ) -> str:
+        """Detect the format of input texts for proper tokenization handling.
+
+        Returns:
+            - "single_string": Regular single text like "Hello world"
+            - "batch_strings": Regular batch like ["Hello", "World"]
+            - "cross_encoder_pairs": Cross-encoder pairs like [["query", "document"]]
+        """
+        if isinstance(texts, str):
+            return "single_string"
+
+        if (
+            is_cross_encoder
+            and len(texts) > 0
+            and isinstance(texts[0], list)
+            and len(texts[0]) == 2
+        ):
+            return "cross_encoder_pairs"
+
+        return "batch_strings"
+
+    def _prepare_tokenizer_input(
+        self, texts: Union[str, List[str]], input_format: str
+    ) -> Union[List[str], List[List[str]]]:
+        """Prepare input for the tokenizer based on detected format."""
+        if input_format == "single_string":
+            return [texts]  # Wrap single string for batch processing
+        elif input_format == "cross_encoder_pairs":
+            return texts  # Already in correct format: [["query", "doc"]]
+        else:  # batch_strings
+            return texts  # Already in correct format: ["text1", "text2"]
+
+    def _extract_tokenizer_results(
+        self,
+        input_ids: List[List[int]],
+        token_type_ids: Optional[List[List[int]]],
+        input_format: str,
+        original_batch_size: int,
+    ) -> Union[
+        Tuple[List[int], Optional[List[int]]],
+        Tuple[List[List[int]], Optional[List[List[int]]]],
+    ]:
+        """Extract results from tokenizer output based on input format."""
+
+        # For single inputs (string or single cross-encoder pair), extract first element
+        if (
+            input_format in ["single_string", "cross_encoder_pairs"]
+            and original_batch_size == 1
+        ):
+            single_input_ids = input_ids[0] if input_ids else []
+            single_token_type_ids = token_type_ids[0] if token_type_ids else None
+            return single_input_ids, single_token_type_ids
+
+        # For true batches, return as-is
+        return input_ids, token_type_ids
+
+    async def _tokenize_texts(
+        self, texts: Union[str, List[str]], is_cross_encoder: bool = False
+    ) -> Union[
+        Tuple[List[int], Optional[List[int]]],
+        Tuple[List[List[int]], Optional[List[List[int]]]],
+    ]:
+        """
+        Tokenize text(s) using the appropriate tokenizer strategy.
+
+        This method handles multiple input formats and chooses between async dynamic
+        batch tokenizer (for single texts only) and regular tokenizer.
+
+        Args:
+            texts: Text input in various formats:
+
+                Regular cases:
+                - Single string: "How are you?"
+                - Batch of strings: ["Hello", "World", "How are you?"]
+
+                Cross-encoder cases (sentence pairs for similarity/ranking):
+                - Single pair: [["query text", "document text"]]
+                - Multiple pairs: [["q1", "d1"], ["q2", "d2"], ["q3", "d3"]]
+
+            is_cross_encoder: Whether to return token_type_ids for cross-encoder models.
+                Enables proper handling of sentence pairs with segment IDs.
+
+        Returns:
+            Single input cases:
+                Tuple[List[int], Optional[List[int]]]: (input_ids, token_type_ids)
+                Example: ([101, 2129, 102], [0, 0, 0]) for single text
+                Example: ([101, 2129, 102, 4068, 102], [0, 0, 0, 1, 1]) for cross-encoder pair
+
+            Batch input cases:
+                Tuple[List[List[int]], Optional[List[List[int]]]]: (batch_input_ids, batch_token_type_ids)
+                Example: ([[101, 2129, 102], [101, 4068, 102]], None) for regular batch
+
+        Note: token_type_ids is None unless is_cross_encoder=True.
+        """
+        if not texts or self.tokenizer is None:
+            raise ValueError("texts cannot be empty and tokenizer must be initialized")
+
+        # Step 1: Detect input format and prepare for tokenization
+        input_format = self._detect_input_format(texts, is_cross_encoder)
+        tokenizer_input = self._prepare_tokenizer_input(texts, input_format)
+        original_batch_size = len(texts) if not isinstance(texts, str) else 1
+
+        # Step 2: Set up tokenizer arguments
+        tokenizer_kwargs = (
+            {"return_token_type_ids": is_cross_encoder} if is_cross_encoder else {}
+        )
+
+        # Step 3: Choose tokenization strategy
+        use_async_tokenizer = (
+            self.async_dynamic_batch_tokenizer is not None
+            and input_format == "single_string"
+        )
+
+        if use_async_tokenizer:
+            logger.debug("Using async dynamic batch tokenizer for single text")
+            result = await self.async_dynamic_batch_tokenizer.encode(
+                tokenizer_input[0], **tokenizer_kwargs
+            )
+            # Convert to batch format for consistency
+            input_ids = [result["input_ids"]]
+            token_type_ids = (
+                [result["token_type_ids"]]
+                if is_cross_encoder and result.get("token_type_ids")
+                else None
+            )
+        else:
+            logger.debug(f"Using regular tokenizer for {len(tokenizer_input)} inputs")
+            encoded = self.tokenizer(tokenizer_input, **tokenizer_kwargs)
+            input_ids = encoded["input_ids"]
+            token_type_ids = encoded.get("token_type_ids") if is_cross_encoder else None
+
+        # Step 4: Extract results based on input format
+        return self._extract_tokenizer_results(
+            input_ids, token_type_ids, input_format, original_batch_size
+        )
+
     async def _tokenize_one_request(
         self,
         obj: Union[GenerateReqInput, EmbeddingReqInput],
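Note: the new helpers route every text input through one code path: detect the format, normalize to a batch, tokenize (async dynamic batcher for single strings, plain tokenizer otherwise), then un-batch single inputs. A self-contained sketch of that dispatch using a toy whitespace tokenizer in place of self.tokenizer (IDs and names invented):

from typing import List, Optional, Tuple, Union


def toy_encode(text: str) -> List[int]:
    return [hash(tok) % 1000 for tok in text.split()]


def tokenize_texts(
    texts: Union[str, List[str]], is_cross_encoder: bool = False
) -> Tuple[Union[List[int], List[List[int]]], Optional[list]]:
    if isinstance(texts, str):  # "single_string" -> flat list
        return toy_encode(texts), None
    if is_cross_encoder and texts and isinstance(texts[0], list):
        # "cross_encoder_pairs": concatenate each [query, document] pair and
        # emit segment IDs (0 for the first sentence, 1 for the second)
        ids = [toy_encode(q) + toy_encode(d) for q, d in texts]
        type_ids = [
            [0] * len(toy_encode(q)) + [1] * len(toy_encode(d)) for q, d in texts
        ]
        if len(texts) == 1:  # a single pair is returned as flat lists
            return ids[0], type_ids[0]
        return ids, type_ids
    return [toy_encode(t) for t in texts], None  # "batch_strings"


print(tokenize_texts("How are you?"))            # flat ids, None
print(tokenize_texts(["Hello", "World"]))        # one id list per text, None
print(tokenize_texts([["query", "doc"]], True))  # flat ids plus segment ids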
@@ -572,14 +582,10 @@ class TokenizerManager:
                     "accept text prompts. Please provide input_ids or re-initialize "
                     "the engine with skip_tokenizer_init=False."
                 )
-            encoded = self.tokenizer(
-                input_text, return_token_type_ids=is_cross_encoder_request
-            )

-            input_ids = encoded["input_ids"]
-            if is_cross_encoder_request:
-                input_ids = encoded["input_ids"][0]
-                token_type_ids = encoded.get("token_type_ids", [None])[0]
+            input_ids, token_type_ids = await self._tokenize_texts(
+                input_text, is_cross_encoder_request
+            )

         if self.mm_processor and obj.contains_mm_input():
             if not isinstance(obj.image_data, list):
@@ -599,6 +605,7 @@ class TokenizerManager:
             mm_inputs = None

         self._validate_one_request(obj, input_ids)
+        trace_slice_end("tokenize", obj.rid)
         return self._create_tokenized_object(
             obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids
         )
@@ -673,7 +680,7 @@ class TokenizerManager:
         ):
             raise ValueError(
                 "The server is not configured to enable custom logit processor. "
-                "Please set `--enable-custom-logits-processor` to enable this feature."
+                "Please set `--enable-custom-logit-processor` to enable this feature."
             )

     def _validate_input_ids_in_vocab(
@@ -754,19 +761,30 @@ class TokenizerManager:
         requests = [obj[i] for i in range(batch_size)]
         texts = [req.text for req in requests]

-        # Batch tokenize all texts
-        encoded = self.tokenizer(texts)
-        input_ids_list = encoded["input_ids"]
+        # Check if any request is a cross-encoder request
+        is_cross_encoder_request = any(
+            isinstance(req, EmbeddingReqInput) and req.is_cross_encoder_request
+            for req in requests
+        )
+
+        # Batch tokenize all texts using unified method
+        input_ids_list, token_type_ids_list = await self._tokenize_texts(
+            texts, is_cross_encoder_request
+        )

         # Process all requests
         tokenized_objs = []
         for i, req in enumerate(requests):
             self._validate_one_request(obj[i], input_ids_list[i])
+            token_type_ids = (
+                token_type_ids_list[i] if token_type_ids_list is not None else None
+            )
             tokenized_objs.append(
                 self._create_tokenized_object(
-                    req, req.text, input_ids_list[i], None, None
+                    req, req.text, input_ids_list[i], None, None, token_type_ids
                 )
             )
+            trace_slice_end("tokenize", req.rid)
         logger.debug(f"Completed batch processing for {batch_size} requests")
         return tokenized_objs
@@ -794,9 +812,12 @@ class TokenizerManager:
         tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
         created_time: Optional[float] = None,
     ):
+        trace_slice_start("dispatch", obj.rid)
+        tokenized_obj.trace_context = trace_get_proc_propagate_context(obj.rid)
         self.send_to_scheduler.send_pyobj(tokenized_obj)
         state = ReqState([], False, asyncio.Event(), obj, created_time=created_time)
         self.rid_to_state[obj.rid] = state
+        trace_slice_end("dispatch", obj.rid, thread_finish_flag=True)
         return state

     def _send_batch_request(
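Note: _send_one_request is now bracketed by "dispatch" slice events, and the scheduler-side context rides along in tokenized_obj.trace_context, so a viewer can stitch per-request timelines across processes. A toy recorder showing the begin/end-slice bookkeeping (invented names, not the sglang tracing API):

import time
from collections import defaultdict

events = defaultdict(list)

def slice_start(name: str, rid: str) -> None:
    events[rid].append((name, "B", time.time_ns()))  # begin event

def slice_end(name: str, rid: str) -> None:
    events[rid].append((name, "E", time.time_ns()))  # end event

slice_start("dispatch", "req-1")
# ... send the tokenized request to the scheduler ...
slice_end("dispatch", "req-1")
print(events["req-1"])  # [('dispatch', 'B', ...), ('dispatch', 'E', ...)]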
@@ -1014,74 +1035,14 @@ class TokenizerManager:
         except StopAsyncIteration:
             pass

-    async def flush_cache(self) -> FlushCacheReqOutput:
-        return (await self.flush_cache_communicator(FlushCacheReqInput()))[0]
-
-    async def clear_hicache_storage(self) -> ClearHiCacheReqOutput:
-        """Clear the hierarchical cache storage."""
-        # Delegate to the scheduler to handle HiCacheStorage clearing
-        return (await self.clear_hicache_storage_communicator(ClearHiCacheReqInput()))[
-            0
-        ]
-
     def abort_request(self, rid: str = "", abort_all: bool = False):
         if not abort_all and rid not in self.rid_to_state:
             return
         req = AbortReq(rid, abort_all)
         self.send_to_scheduler.send_pyobj(req)
-
         if self.enable_metrics:
             self.metrics_collector.observe_one_aborted_request()

-    async def start_profile(
-        self,
-        output_dir: Optional[str] = None,
-        start_step: Optional[int] = None,
-        num_steps: Optional[int] = None,
-        activities: Optional[List[str]] = None,
-        with_stack: Optional[bool] = None,
-        record_shapes: Optional[bool] = None,
-        profile_by_stage: bool = False,
-    ):
-        self.auto_create_handle_loop()
-        env_with_stack: bool = get_bool_env_var("SGLANG_PROFILE_WITH_STACK", "true")
-        with_stack = False if with_stack is False or env_with_stack is False else True
-        req = ProfileReq(
-            type=ProfileReqType.START_PROFILE,
-            output_dir=output_dir,
-            start_step=start_step,
-            num_steps=num_steps,
-            activities=activities,
-            with_stack=with_stack,
-            record_shapes=record_shapes,
-            profile_by_stage=profile_by_stage,
-            profile_id=str(time.time()),
-        )
-        return await self._execute_profile(req)
-
-    async def stop_profile(self):
-        self.auto_create_handle_loop()
-        req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
-        return await self._execute_profile(req)
-
-    async def _execute_profile(self, req: ProfileReq):
-        result = (await self.profile_communicator(req))[0]
-        if not result.success:
-            raise RuntimeError(result.message)
-        return result
-
-    async def start_expert_distribution_record(self):
-        self.auto_create_handle_loop()
-        await self.expert_distribution_communicator(ExpertDistributionReq.START_RECORD)
-
-    async def stop_expert_distribution_record(self):
-        self.auto_create_handle_loop()
-        await self.expert_distribution_communicator(ExpertDistributionReq.STOP_RECORD)
-
-    async def dump_expert_distribution_record(self):
-        self.auto_create_handle_loop()
-        await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)
-
     async def pause_generation(self):
         async with self.is_pause_cond:
             self.is_pause = True
@@ -1117,7 +1078,7 @@ class TokenizerManager:
         self, obj: UpdateWeightFromDiskReqInput
     ) -> Tuple[bool, str]:
         if self.server_args.tokenizer_worker_num > 1:
-            obj = MultiTokenizerWarpper(self.worker_id, obj)
+            obj = MultiTokenizerWrapper(self.worker_id, obj)
         self.send_to_scheduler.send_pyobj(obj)
         self.model_update_result = asyncio.Future()
         if self.server_args.dp_size == 1:
@@ -1142,191 +1103,6 @@ class TokenizerManager:
         all_paused_requests = [r.num_paused_requests for r in result]
         return all_success, all_message, all_paused_requests

-    async def init_weights_update_group(
-        self,
-        obj: InitWeightsUpdateGroupReqInput,
-        request: Optional[fastapi.Request] = None,
-    ) -> Tuple[bool, str]:
-        self.auto_create_handle_loop()
-        assert (
-            self.server_args.dp_size == 1
-        ), "dp_size must be 1 for init parameter update group"
-        result = (await self.init_weights_update_group_communicator(obj))[0]
-        return result.success, result.message
-
-    async def update_weights_from_distributed(
-        self,
-        obj: UpdateWeightsFromDistributedReqInput,
-        request: Optional[fastapi.Request] = None,
-    ) -> Tuple[bool, str]:
-        self.auto_create_handle_loop()
-        assert (
-            self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
-        ), "dp_size must be 1 or dp attention must be enabled for update weights from distributed"
-
-        if obj.abort_all_requests:
-            self.abort_request(abort_all=True)
-
-        # This means that weight sync
-        # cannot run while requests are in progress.
-        async with self.model_update_lock.writer_lock:
-            result = (await self.update_weights_from_distributed_communicator(obj))[0]
-            return result.success, result.message
-
-    async def update_weights_from_tensor(
-        self,
-        obj: UpdateWeightsFromTensorReqInput,
-        request: Optional[fastapi.Request] = None,
-    ) -> Tuple[bool, str]:
-        self.auto_create_handle_loop()
-        assert (
-            self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
-        ), "dp_size must be 1 or dp attention must be enabled for update weights from tensor"
-
-        if obj.abort_all_requests:
-            self.abort_request(abort_all=True)
-
-        # This means that weight sync
-        # cannot run while requests are in progress.
-        async with self.model_update_lock.writer_lock:
-            result = (await self.update_weights_from_tensor_communicator(obj))[0]
-            return result.success, result.message
-
-    async def load_lora_adapter(
-        self,
-        obj: LoadLoRAAdapterReqInput,
-        _: Optional[fastapi.Request] = None,
-    ) -> LoadLoRAAdapterReqOutput:
-        self.auto_create_handle_loop()
-
-        try:
-            if not self.server_args.enable_lora:
-                raise ValueError(
-                    "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
-                )
-
-            # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
-            # with dp_size > 1.
-            assert (
-                self.server_args.dp_size == 1
-            ), "dp_size must be 1 for dynamic lora loading"
-            logger.info(
-                "Start load Lora adapter. Lora name=%s, path=%s",
-                obj.lora_name,
-                obj.lora_path,
-            )
-
-            async with self.lora_update_lock:
-                if (
-                    self.server_args.max_loaded_loras is not None
-                    and self.lora_registry.num_registered_loras
-                    >= self.server_args.max_loaded_loras
-                ):
-                    raise ValueError(
-                        f"Cannot load LoRA adapter {obj.lora_name} at path {obj.lora_path}. "
-                        f"Maximum number of loaded LoRA adapters is {self.server_args.max_loaded_loras}. "
-                        "Please unload some LoRA adapters before loading new ones."
-                    )
-
-                # Generate new uniquely identifiable LoRARef object.
-                new_adapter = LoRARef(
-                    lora_name=obj.lora_name,
-                    lora_path=obj.lora_path,
-                    pinned=obj.pinned,
-                )
-
-                # Trigger the actual loading operation at the backend processes.
-                obj.lora_id = new_adapter.lora_id
-                result = (await self.update_lora_adapter_communicator(obj))[0]
-
-                # Register the LoRA adapter only after loading is successful.
-                if result.success:
-                    await self.lora_registry.register(new_adapter)
-
-                return result
-        except ValueError as e:
-            return LoadLoRAAdapterReqOutput(
-                success=False,
-                error_message=str(e),
-            )
-
-    async def unload_lora_adapter(
-        self,
-        obj: UnloadLoRAAdapterReqInput,
-        _: Optional[fastapi.Request] = None,
-    ) -> UnloadLoRAAdapterReqOutput:
-        self.auto_create_handle_loop()
-
-        try:
-            if not self.server_args.enable_lora:
-                raise ValueError(
-                    "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
-                )
-
-            assert (
-                obj.lora_name is not None
-            ), "lora_name must be provided to unload LoRA adapter"
-
-            # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
-            # with dp_size > 1.
-            assert (
-                self.server_args.dp_size == 1
-            ), "dp_size must be 1 for dynamic lora loading"
-            logger.info(
-                "Start unload Lora adapter. Lora name=%s",
-                obj.lora_name,
-            )
-
-            async with self.lora_update_lock:
-                # Unregister the LoRA adapter from the registry to stop new requests for this adapter
-                # from being started.
-                lora_id = await self.lora_registry.unregister(obj.lora_name)
-                obj.lora_id = lora_id
-
-                # Initiate the actual unloading operation at the backend processes only after all
-                # ongoing requests using this LoRA adapter are finished.
-                await self.lora_registry.wait_for_unload(lora_id)
-                result = (await self.update_lora_adapter_communicator(obj))[0]
-
-                return result
-        except ValueError as e:
-            return UnloadLoRAAdapterReqOutput(success=False, error_message=str(e))
-
-    async def get_weights_by_name(
-        self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
-    ):
-        self.auto_create_handle_loop()
-        results = await self.get_weights_by_name_communicator(obj)
-        all_parameters = [r.parameter for r in results]
-        if self.server_args.dp_size == 1:
-            return all_parameters[0]
-        else:
-            return all_parameters
-
-    async def release_memory_occupation(
-        self,
-        obj: ReleaseMemoryOccupationReqInput,
-        request: Optional[fastapi.Request] = None,
-    ):
-        self.auto_create_handle_loop()
-        await self.release_memory_occupation_communicator(obj)
-
-    async def resume_memory_occupation(
-        self,
-        obj: ResumeMemoryOccupationReqInput,
-        request: Optional[fastapi.Request] = None,
-    ):
-        self.auto_create_handle_loop()
-        await self.resume_memory_occupation_communicator(obj)
-
-    async def slow_down(
-        self,
-        obj: SlowDownReqInput,
-        request: Optional[fastapi.Request] = None,
-    ):
-        self.auto_create_handle_loop()
-        await self.slow_down_communicator(obj)
-
     async def open_session(
         self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
     ):
@@ -1338,7 +1114,7 @@ class TokenizerManager:
             return None

         if self.server_args.tokenizer_worker_num > 1:
-            obj = MultiTokenizerWarpper(self.worker_id, obj)
+            obj = MultiTokenizerWrapper(self.worker_id, obj)
         self.send_to_scheduler.send_pyobj(obj)

         self.session_futures[obj.session_id] = asyncio.Future()
@@ -1351,28 +1127,6 @@ class TokenizerManager:
     ):
         await self.send_to_scheduler.send_pyobj(obj)

-    async def get_internal_state(self) -> List[Dict[Any, Any]]:
-        req = GetInternalStateReq()
-        responses: List[GetInternalStateReqOutput] = (
-            await self.get_internal_state_communicator(req)
-        )
-        # Many DP ranks
-        return [res.internal_state for res in responses]
-
-    async def set_internal_state(self, obj: SetInternalStateReq) -> List[bool]:
-        responses: List[SetInternalStateReqOutput] = (
-            await self.set_internal_state_communicator(obj)
-        )
-        return [res.updated for res in responses]
-
-    async def get_load(self) -> dict:
-        # TODO(lsyin): fake load report server
-        if not self.current_load_lock.locked():
-            async with self.current_load_lock:
-                internal_state = await self.get_internal_state()
-                self.current_load = internal_state[0]["load"]
-        return {"load": self.current_load}
-
     def get_log_request_metadata(self):
         max_length = None
         skip_names = None
@@ -1491,6 +1245,9 @@ class TokenizerManager:
         self.asyncio_tasks.add(
             loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
         )
+        self.asyncio_tasks.add(
+            loop.create_task(print_exception_wrapper(self.watch_load_thread))
+        )

     def dump_requests_before_crash(self):
         if self.crash_dump_performed:
@@ -1710,6 +1467,9 @@ class TokenizerManager:
                 meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
             state.finished_time = time.time()
             meta_info["e2e_latency"] = state.finished_time - state.created_time
+
+            trace_req_finish(rid, ts=int(state.finished_time * 1e9))
+
             del self.rid_to_state[rid]

             # Mark ongoing LoRA request as finished.
@@ -1859,6 +1619,12 @@ class TokenizerManager:
                 else 0
             )

+            customer_labels = getattr(state.obj, "customer_labels", None)
+            labels = (
+                {**self.metrics_collector.labels, **customer_labels}
+                if customer_labels
+                else self.metrics_collector.labels
+            )
             if (
                 state.first_token_time == 0.0
                 and self.disaggregation_mode != DisaggregationMode.PREFILL
@@ -1866,7 +1632,7 @@ class TokenizerManager:
                 state.first_token_time = state.last_time = time.time()
                 state.last_completion_tokens = completion_tokens
                 self.metrics_collector.observe_time_to_first_token(
-                    state.first_token_time - state.created_time
+                    labels, state.first_token_time - state.created_time
                 )
             else:
                 num_new_tokens = completion_tokens - state.last_completion_tokens
@@ -1874,6 +1640,7 @@ class TokenizerManager:
                     new_time = time.time()
                     interval = new_time - state.last_time
                     self.metrics_collector.observe_inter_token_latency(
+                        labels,
                         interval,
                         num_new_tokens,
                     )
@@ -1888,6 +1655,7 @@ class TokenizerManager:
                 or state.obj.sampling_params.get("structural_tag", None)
             )
             self.metrics_collector.observe_one_finished_request(
+                labels,
                 recv_obj.prompt_tokens[i],
                 completion_tokens,
                 recv_obj.cached_tokens[i],
@@ -2059,11 +1827,15 @@ class TokenizerManager:
         # the next position after the last token in the prompt
         output_logprobs = result["meta_info"].get("output_token_ids_logprobs", [])

-        # Throw an error here if output_logprobs is None
-        if output_logprobs is None:
+        # Check if output_logprobs is properly populated
+        if (
+            output_logprobs is None
+            or not output_logprobs
+            or len(output_logprobs) == 0
+        ):
             raise RuntimeError(
-                f"output_logprobs is None for request {result['meta_info'].get('id', '<unknown>')}. "
-                "This usually indicates a problem with the scoring request or the backend output."
+                f"output_logprobs is empty for request {result['meta_info'].get('id', '<unknown>')}. "
+                "This indicates token_ids_logprobs were not computed properly for the scoring request."
             )

         for logprob, token_id, _ in output_logprobs[0]:
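Note: the broadened guard rejects an empty list as well as None before output_logprobs[0] is indexed. (As a pure-Python aside, the clauses overlap: `not output_logprobs` is already true for None, [], and any zero-length sequence.) A tiny truth table:

for output_logprobs in (None, [], [[(-0.25, 42, None)]]):
    rejected = (
        output_logprobs is None or not output_logprobs or len(output_logprobs) == 0
    )
    print(repr(output_logprobs), "->", "raise" if rejected else "process")
# None -> raise, [] -> raise, [[(-0.25, 42, None)]] -> process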
@@ -2088,6 +1860,20 @@ class TokenizerManager:

         return scores

+    async def watch_load_thread(self):
+        # Only for dp_controller when dp_size > 1
+        if (
+            self.server_args.dp_size == 1
+            or self.server_args.load_balance_method == "round_robin"
+        ):
+            return
+
+        while True:
+            await asyncio.sleep(self.server_args.load_watch_interval)
+            loads = await self.get_load_communicator(GetLoadReqInput())
+            load_udpate_req = WatchLoadUpdateReq(loads=loads)
+            self.send_to_scheduler.send_pyobj(load_udpate_req)
+


 class ServerStatus(Enum):
     Up = "Up"
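Note: watch_load_thread is a poll-and-push loop: every load_watch_interval seconds it gathers per-rank loads through the communicator and forwards a WatchLoadUpdateReq toward the data-parallel controller; it is skipped for round_robin balancing, which does not use load information. The bare pattern, as a runnable sketch with invented callables:

import asyncio

async def watch_load(get_load, push_update, interval_s: float) -> None:
    while True:
        await asyncio.sleep(interval_s)  # sleep first: no push before one interval
        push_update(await get_load())

async def main() -> None:
    async def get_load():
        return {"load": 0}

    task = asyncio.create_task(watch_load(get_load, print, interval_s=0.1))
    await asyncio.sleep(0.35)  # let it fire a few times
    task.cancel()
    try:
        await task
    except asyncio.CancelledError:
        pass

asyncio.run(main())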
@@ -2139,51 +1925,6 @@ class SignalHandler:
         kill_process_tree(os.getpid())


-T = TypeVar("T")
-
-
-class _Communicator(Generic[T]):
-    """Note: The communicator now only run up to 1 in-flight request at any time."""
-
-    enable_multi_tokenizer = False
-
-    def __init__(self, sender, fan_out: int):
-        self._sender = sender
-        self._fan_out = fan_out
-        self._result_event: Optional[asyncio.Event] = None
-        self._result_values: Optional[List[T]] = None
-        self._ready_queue: Deque[asyncio.Future] = deque()
-
-    async def __call__(self, obj):
-        ready_event = asyncio.Event()
-        if self._result_event is not None or len(self._ready_queue) > 0:
-            self._ready_queue.append(ready_event)
-            await ready_event.wait()
-            assert self._result_event is None
-            assert self._result_values is None
-
-        if obj:
-            if _Communicator.enable_multi_tokenizer:
-                obj = MultiTokenizerWarpper(worker_id=os.getpid(), obj=obj)
-            self._sender.send_pyobj(obj)
-
-        self._result_event = asyncio.Event()
-        self._result_values = []
-        await self._result_event.wait()
-        result_values = self._result_values
-        self._result_event = self._result_values = None
-
-        if len(self._ready_queue) > 0:
-            self._ready_queue.popleft().set()
-
-        return result_values
-
-    def handle_recv(self, recv_obj: T):
-        self._result_values.append(recv_obj)
-        if len(self._result_values) == self._fan_out:
-            self._result_event.set()
-
-
 # Note: request abort handling logic
 # We should handle all of the following cases correctly.
 #
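Note: the deleted _Communicator (its replacement ships with the communicator mixin) implements a fan-out/gather rendezvous: send one request toward the schedulers, then sleep until fan_out (i.e. dp_size) replies have been collected by handle_recv. A self-contained toy version of that contract, without the original's in-flight request queueing:

import asyncio
from typing import Any, List, Optional


class FanOutGather:
    def __init__(self, fan_out: int) -> None:
        self._fan_out = fan_out
        self._event: Optional[asyncio.Event] = None
        self._values: List[Any] = []

    async def __call__(self) -> List[Any]:
        self._event, self._values = asyncio.Event(), []
        await self._event.wait()  # wake only once every reply has arrived
        return self._values

    def handle_recv(self, value: Any) -> None:
        self._values.append(value)
        if len(self._values) == self._fan_out:
            self._event.set()


async def main() -> None:
    comm = FanOutGather(fan_out=2)
    waiter = asyncio.create_task(comm())
    await asyncio.sleep(0)  # let the waiter install its event
    comm.handle_recv("reply from dp rank 0")
    comm.handle_recv("reply from dp rank 1")
    print(await waiter)  # both replies, in arrival order

asyncio.run(main())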