sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the packages exactly as they appear in that registry.
Files changed (238)
  1. sglang/bench_one_batch_server.py +10 -1
  2. sglang/bench_serving.py +257 -29
  3. sglang/srt/configs/__init__.py +4 -0
  4. sglang/srt/configs/device_config.py +3 -1
  5. sglang/srt/configs/dots_vlm.py +139 -0
  6. sglang/srt/configs/load_config.py +1 -0
  7. sglang/srt/configs/model_config.py +50 -6
  8. sglang/srt/configs/qwen3_next.py +326 -0
  9. sglang/srt/connector/__init__.py +8 -1
  10. sglang/srt/connector/remote_instance.py +82 -0
  11. sglang/srt/constrained/base_grammar_backend.py +48 -12
  12. sglang/srt/constrained/llguidance_backend.py +0 -1
  13. sglang/srt/constrained/outlines_backend.py +0 -1
  14. sglang/srt/constrained/xgrammar_backend.py +28 -9
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/base/conn.py +1 -1
  21. sglang/srt/disaggregation/common/conn.py +15 -12
  22. sglang/srt/disaggregation/decode.py +21 -10
  23. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -445
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +5 -3
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +24 -3
  31. sglang/srt/entrypoints/engine.py +38 -17
  32. sglang/srt/entrypoints/grpc_request_manager.py +580 -0
  33. sglang/srt/entrypoints/grpc_server.py +680 -0
  34. sglang/srt/entrypoints/http_server.py +85 -54
  35. sglang/srt/entrypoints/openai/protocol.py +4 -1
  36. sglang/srt/entrypoints/openai/serving_base.py +46 -3
  37. sglang/srt/entrypoints/openai/serving_chat.py +36 -16
  38. sglang/srt/entrypoints/openai/serving_completions.py +12 -3
  39. sglang/srt/entrypoints/openai/serving_embedding.py +8 -3
  40. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  41. sglang/srt/entrypoints/openai/serving_responses.py +6 -3
  42. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  43. sglang/srt/eplb/eplb_manager.py +2 -2
  44. sglang/srt/eplb/expert_distribution.py +26 -13
  45. sglang/srt/eplb/expert_location.py +8 -3
  46. sglang/srt/eplb/expert_location_updater.py +1 -1
  47. sglang/srt/function_call/base_format_detector.py +3 -6
  48. sglang/srt/function_call/ebnf_composer.py +11 -9
  49. sglang/srt/function_call/function_call_parser.py +6 -0
  50. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  51. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  52. sglang/srt/grpc/__init__.py +1 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
  56. sglang/srt/hf_transformers_utils.py +4 -0
  57. sglang/srt/layers/activation.py +142 -9
  58. sglang/srt/layers/attention/ascend_backend.py +11 -4
  59. sglang/srt/layers/attention/fla/chunk.py +242 -0
  60. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  61. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  62. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  63. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  66. sglang/srt/layers/attention/fla/index.py +37 -0
  67. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  68. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  69. sglang/srt/layers/attention/fla/op.py +66 -0
  70. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  71. sglang/srt/layers/attention/fla/utils.py +331 -0
  72. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  73. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  74. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  75. sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
  76. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  77. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  78. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  79. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  80. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  81. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  82. sglang/srt/layers/attention/triton_backend.py +18 -1
  83. sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
  84. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  85. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  86. sglang/srt/layers/dp_attention.py +30 -1
  87. sglang/srt/layers/layernorm.py +32 -15
  88. sglang/srt/layers/linear.py +34 -3
  89. sglang/srt/layers/logits_processor.py +29 -10
  90. sglang/srt/layers/moe/__init__.py +2 -1
  91. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  92. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  93. sglang/srt/layers/moe/ep_moe/layer.py +182 -62
  94. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
  95. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  96. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  97. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  100. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  101. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  102. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  103. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  104. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  105. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  106. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  107. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
  108. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  109. sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
  110. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  111. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  112. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  113. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  114. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  115. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  116. sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
  117. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  118. sglang/srt/layers/moe/topk.py +30 -9
  119. sglang/srt/layers/moe/utils.py +12 -6
  120. sglang/srt/layers/quantization/awq.py +19 -7
  121. sglang/srt/layers/quantization/base_config.py +11 -6
  122. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  123. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  124. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  125. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  126. sglang/srt/layers/quantization/fp8.py +76 -47
  127. sglang/srt/layers/quantization/fp8_utils.py +50 -31
  128. sglang/srt/layers/quantization/gptq.py +25 -17
  129. sglang/srt/layers/quantization/modelopt_quant.py +147 -47
  130. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  131. sglang/srt/layers/quantization/mxfp4.py +64 -40
  132. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  133. sglang/srt/layers/quantization/unquant.py +135 -47
  134. sglang/srt/layers/quantization/w4afp8.py +30 -17
  135. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  136. sglang/srt/layers/quantization/w8a8_int8.py +76 -38
  137. sglang/srt/layers/sampler.py +162 -18
  138. sglang/srt/lora/backend/base_backend.py +50 -8
  139. sglang/srt/lora/backend/triton_backend.py +90 -2
  140. sglang/srt/lora/layers.py +32 -0
  141. sglang/srt/lora/lora.py +4 -1
  142. sglang/srt/lora/lora_manager.py +35 -112
  143. sglang/srt/lora/mem_pool.py +24 -10
  144. sglang/srt/lora/utils.py +18 -9
  145. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  146. sglang/srt/managers/cache_controller.py +158 -160
  147. sglang/srt/managers/data_parallel_controller.py +105 -35
  148. sglang/srt/managers/detokenizer_manager.py +8 -4
  149. sglang/srt/managers/disagg_service.py +46 -0
  150. sglang/srt/managers/io_struct.py +199 -12
  151. sglang/srt/managers/mm_utils.py +1 -0
  152. sglang/srt/managers/multi_tokenizer_mixin.py +350 -400
  153. sglang/srt/managers/schedule_batch.py +77 -56
  154. sglang/srt/managers/schedule_policy.py +1 -1
  155. sglang/srt/managers/scheduler.py +187 -39
  156. sglang/srt/managers/scheduler_metrics_mixin.py +4 -3
  157. sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
  158. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  159. sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
  160. sglang/srt/managers/tokenizer_manager.py +259 -519
  161. sglang/srt/managers/tp_worker.py +53 -4
  162. sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
  163. sglang/srt/mem_cache/hicache_storage.py +3 -23
  164. sglang/srt/mem_cache/hiradix_cache.py +103 -43
  165. sglang/srt/mem_cache/memory_pool.py +347 -48
  166. sglang/srt/mem_cache/memory_pool_host.py +105 -46
  167. sglang/srt/mem_cache/radix_cache.py +0 -2
  168. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  169. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  170. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +86 -4
  171. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  172. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  173. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +49 -7
  174. sglang/srt/mem_cache/swa_radix_cache.py +0 -2
  175. sglang/srt/metrics/collector.py +493 -76
  176. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  177. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  178. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  179. sglang/srt/model_executor/forward_batch_info.py +59 -2
  180. sglang/srt/model_executor/model_runner.py +356 -29
  181. sglang/srt/model_loader/__init__.py +9 -3
  182. sglang/srt/model_loader/loader.py +128 -4
  183. sglang/srt/model_loader/weight_utils.py +2 -1
  184. sglang/srt/models/apertus.py +686 -0
  185. sglang/srt/models/bailing_moe.py +798 -218
  186. sglang/srt/models/bailing_moe_nextn.py +168 -0
  187. sglang/srt/models/deepseek_v2.py +109 -15
  188. sglang/srt/models/dots_vlm.py +174 -0
  189. sglang/srt/models/dots_vlm_vit.py +337 -0
  190. sglang/srt/models/ernie4.py +1 -1
  191. sglang/srt/models/gemma3n_mm.py +1 -1
  192. sglang/srt/models/glm4_moe.py +1 -1
  193. sglang/srt/models/glm4v.py +4 -2
  194. sglang/srt/models/glm4v_moe.py +3 -0
  195. sglang/srt/models/gpt_oss.py +1 -1
  196. sglang/srt/models/llama4.py +9 -0
  197. sglang/srt/models/llama_eagle3.py +13 -0
  198. sglang/srt/models/longcat_flash.py +2 -2
  199. sglang/srt/models/mllama4.py +25 -0
  200. sglang/srt/models/opt.py +637 -0
  201. sglang/srt/models/qwen2.py +7 -0
  202. sglang/srt/models/qwen2_5_vl.py +27 -3
  203. sglang/srt/models/qwen2_moe.py +56 -12
  204. sglang/srt/models/qwen3_moe.py +1 -1
  205. sglang/srt/models/qwen3_next.py +1042 -0
  206. sglang/srt/models/qwen3_next_mtp.py +112 -0
  207. sglang/srt/models/step3_vl.py +1 -1
  208. sglang/srt/multimodal/processors/dots_vlm.py +99 -0
  209. sglang/srt/multimodal/processors/glm4v.py +9 -9
  210. sglang/srt/multimodal/processors/internvl.py +141 -129
  211. sglang/srt/multimodal/processors/qwen_vl.py +15 -5
  212. sglang/srt/offloader.py +27 -3
  213. sglang/srt/remote_instance_weight_loader_utils.py +69 -0
  214. sglang/srt/sampling/sampling_batch_info.py +18 -15
  215. sglang/srt/server_args.py +276 -35
  216. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  217. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  218. sglang/srt/speculative/eagle_utils.py +0 -2
  219. sglang/srt/speculative/eagle_worker.py +43 -4
  220. sglang/srt/speculative/spec_info.py +5 -0
  221. sglang/srt/speculative/standalone_worker.py +109 -0
  222. sglang/srt/tracing/trace.py +552 -0
  223. sglang/srt/utils.py +34 -3
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  226. sglang/test/runners.py +4 -0
  227. sglang/test/test_cutlass_moe.py +24 -6
  228. sglang/test/test_disaggregation_utils.py +66 -0
  229. sglang/test/test_fp4_moe.py +370 -1
  230. sglang/test/test_utils.py +28 -1
  231. sglang/utils.py +11 -0
  232. sglang/version.py +1 -1
  233. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
  234. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +237 -178
  235. sglang/srt/disaggregation/launch_lb.py +0 -118
  236. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
  237. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
  238. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
sglang/srt/managers/tokenizer_manager.py

@@ -31,18 +31,7 @@ from contextlib import nullcontext
  from datetime import datetime
  from enum import Enum
  from http import HTTPStatus
- from typing import (
- Any,
- Awaitable,
- Deque,
- Dict,
- Generic,
- List,
- Optional,
- Tuple,
- TypeVar,
- Union,
- )
+ from typing import Any, Awaitable, Dict, List, Optional, Tuple, Union

  import fastapi
  import torch
@@ -53,18 +42,15 @@ from fastapi import BackgroundTasks

  from sglang.srt.aio_rwlock import RWLock
  from sglang.srt.configs.model_config import ModelConfig
- from sglang.srt.disaggregation.utils import (
- DisaggregationMode,
- KVClassType,
- TransferBackend,
- get_kv_class,
- )
+ from sglang.srt.disaggregation.utils import DisaggregationMode
  from sglang.srt.hf_transformers_utils import (
  get_processor,
  get_tokenizer,
  get_tokenizer_from_processor,
  )
  from sglang.srt.lora.lora_registry import LoRARef, LoRARegistry
+ from sglang.srt.managers.async_dynamic_batch_tokenizer import AsyncDynamicbatchTokenizer
+ from sglang.srt.managers.disagg_service import start_disagg_service
  from sglang.srt.managers.io_struct import (
  AbortReq,
  BatchEmbeddingOut,
@@ -73,60 +59,38 @@ from sglang.srt.managers.io_struct import (
  BatchTokenIDOut,
  BatchTokenizedEmbeddingReqInput,
  BatchTokenizedGenerateReqInput,
- ClearHiCacheReqInput,
- ClearHiCacheReqOutput,
  CloseSessionReqInput,
  ConfigureLoggingReq,
  EmbeddingReqInput,
- ExpertDistributionReq,
- ExpertDistributionReqOutput,
- FlushCacheReqInput,
- FlushCacheReqOutput,
  FreezeGCReq,
  GenerateReqInput,
- GetInternalStateReq,
- GetInternalStateReqOutput,
- GetWeightsByNameReqInput,
- GetWeightsByNameReqOutput,
+ GetLoadReqInput,
  HealthCheckOutput,
- InitWeightsUpdateGroupReqInput,
- InitWeightsUpdateGroupReqOutput,
- LoadLoRAAdapterReqInput,
- LoadLoRAAdapterReqOutput,
- LoRAUpdateResult,
- MultiTokenizerWarpper,
+ MultiTokenizerWrapper,
  OpenSessionReqInput,
  OpenSessionReqOutput,
- ProfileReq,
- ProfileReqOutput,
- ProfileReqType,
- ReleaseMemoryOccupationReqInput,
- ReleaseMemoryOccupationReqOutput,
- ResumeMemoryOccupationReqInput,
- ResumeMemoryOccupationReqOutput,
  SessionParams,
- SetInternalStateReq,
- SetInternalStateReqOutput,
- SlowDownReqInput,
- SlowDownReqOutput,
  TokenizedEmbeddingReqInput,
  TokenizedGenerateReqInput,
- UnloadLoRAAdapterReqInput,
- UnloadLoRAAdapterReqOutput,
  UpdateWeightFromDiskReqInput,
  UpdateWeightFromDiskReqOutput,
- UpdateWeightsFromDistributedReqInput,
- UpdateWeightsFromDistributedReqOutput,
- UpdateWeightsFromTensorReqInput,
- UpdateWeightsFromTensorReqOutput,
+ WatchLoadUpdateReq,
  )
  from sglang.srt.managers.mm_utils import TensorTransportMode
  from sglang.srt.managers.multimodal_processor import get_mm_processor, import_processors
  from sglang.srt.managers.scheduler import is_health_check_generate_req
  from sglang.srt.managers.scheduler_input_blocker import input_blocker_guard_region
+ from sglang.srt.managers.tokenizer_communicator_mixin import TokenizerCommunicatorMixin
  from sglang.srt.metrics.collector import TokenizerMetricsCollector
  from sglang.srt.sampling.sampling_params import SamplingParams
  from sglang.srt.server_args import PortArgs, ServerArgs
+ from sglang.srt.tracing.trace import (
+ trace_get_proc_propagate_context,
+ trace_req_finish,
+ trace_req_start,
+ trace_slice_end,
+ trace_slice_start,
+ )
  from sglang.srt.utils import (
  configure_gc_warning,
  dataclass_to_string_truncated,
@@ -180,7 +144,7 @@ class ReqState:
  output_token_ids_logprobs_idx: List = dataclasses.field(default_factory=list)


- class TokenizerManager:
+ class TokenizerManager(TokenizerCommunicatorMixin):
  """TokenizerManager is a process that tokenizes the text."""

  def __init__(
@@ -262,6 +226,18 @@ class TokenizerManager:
  trust_remote_code=server_args.trust_remote_code,
  revision=server_args.revision,
  )
+ # Initialize async dynamic batch tokenizer if enabled (common for both multimodal and non-multimodal)
+ if (
+ server_args.enable_dynamic_batch_tokenizer
+ and not server_args.skip_tokenizer_init
+ ):
+ self.async_dynamic_batch_tokenizer = AsyncDynamicbatchTokenizer(
+ self.tokenizer,
+ max_batch_size=server_args.dynamic_batch_tokenizer_batch_size,
+ batch_wait_timeout_s=server_args.dynamic_batch_tokenizer_batch_timeout,
+ )
+ else:
+ self.async_dynamic_batch_tokenizer = None

  # Init inter-process communication
  context = zmq.asyncio.Context(2)
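For reference, the AsyncDynamicbatchTokenizer configured above coalesces individual encode() calls into batched tokenizer calls controlled by `server_args.dynamic_batch_tokenizer_batch_size` and `server_args.dynamic_batch_tokenizer_batch_timeout`. Below is a minimal sketch of that dynamic-batching idea; the class name, defaults, and the `tokenize_batch` callback are illustrative assumptions, not the actual sglang implementation.

```python
import asyncio
from typing import Callable, List, Optional


class BatchingEncoder:
    """Illustrative sketch of dynamic batch tokenization (not the sglang class).

    `tokenize_batch` stands in for a HF-style tokenizer call that maps a list
    of strings to a list of token-id lists.
    """

    def __init__(
        self,
        tokenize_batch: Callable[[List[str]], List[List[int]]],
        max_batch_size: int = 32,
        batch_wait_timeout_s: float = 0.002,
    ) -> None:
        self._tokenize_batch = tokenize_batch
        self._max_batch_size = max_batch_size
        self._batch_wait_timeout_s = batch_wait_timeout_s
        self._queue: asyncio.Queue = asyncio.Queue()
        self._worker: Optional[asyncio.Task] = None

    async def encode(self, text: str) -> List[int]:
        # Start the background worker lazily, inside the running event loop.
        if self._worker is None:
            self._worker = asyncio.create_task(self._run())
        fut = asyncio.get_running_loop().create_future()
        await self._queue.put((text, fut))
        return await fut

    async def _run(self) -> None:
        loop = asyncio.get_running_loop()
        while True:
            # Block for the first request, then briefly wait for more to arrive.
            batch = [await self._queue.get()]
            deadline = loop.time() + self._batch_wait_timeout_s
            while len(batch) < self._max_batch_size:
                remaining = deadline - loop.time()
                if remaining <= 0:
                    break
                try:
                    batch.append(await asyncio.wait_for(self._queue.get(), remaining))
                except asyncio.TimeoutError:
                    break
            # One batched tokenizer call serves every queued request.
            all_ids = self._tokenize_batch([text for text, _ in batch])
            for (_, fut), ids in zip(batch, all_ids):
                fut.set_result(ids)
```

The trade-off is a small added wait (`batch_wait_timeout_s`) in exchange for fewer, larger tokenizer calls when many single-text requests arrive concurrently.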
@@ -319,8 +295,10 @@ class TokenizerManager:
  # LoRA updates and inference to overlap.
  self.lora_update_lock = asyncio.Lock()

- # For PD disaggregtion
- self.init_disaggregation()
+ self.disaggregation_mode = DisaggregationMode(
+ self.server_args.disaggregation_mode
+ )
+ self.bootstrap_server = start_disagg_service(self.server_args)

  # For load balancing
  self.current_load = 0
@@ -328,12 +306,16 @@ class TokenizerManager:

  # Metrics
  if self.enable_metrics:
+ labels = {
+ "model_name": self.server_args.served_model_name,
+ # TODO: Add lora name/path in the future,
+ }
+ if server_args.tokenizer_metrics_allowed_customer_labels:
+ for label in server_args.tokenizer_metrics_allowed_customer_labels:
+ labels[label] = ""
  self.metrics_collector = TokenizerMetricsCollector(
  server_args=server_args,
- labels={
- "model_name": self.server_args.served_model_name,
- # TODO: Add lora name/path in the future,
- },
+ labels=labels,
  bucket_time_to_first_token=self.server_args.bucket_time_to_first_token,
  bucket_e2e_request_latency=self.server_args.bucket_e2e_request_latency,
  bucket_inter_token_latency=self.server_args.bucket_inter_token_latency,
@@ -344,50 +326,6 @@ class TokenizerManager:
  if self.server_args.gc_warning_threshold_secs > 0.0:
  configure_gc_warning(self.server_args.gc_warning_threshold_secs)

- # Communicators
- self.init_weights_update_group_communicator = _Communicator(
- self.send_to_scheduler, server_args.dp_size
- )
- self.update_weights_from_distributed_communicator = _Communicator(
- self.send_to_scheduler, server_args.dp_size
- )
- self.update_weights_from_tensor_communicator = _Communicator(
- self.send_to_scheduler, server_args.dp_size
- )
- self.get_weights_by_name_communicator = _Communicator(
- self.send_to_scheduler, server_args.dp_size
- )
- self.release_memory_occupation_communicator = _Communicator(
- self.send_to_scheduler, server_args.dp_size
- )
- self.resume_memory_occupation_communicator = _Communicator(
- self.send_to_scheduler, server_args.dp_size
- )
- self.slow_down_communicator = _Communicator(
- self.send_to_scheduler, server_args.dp_size
- )
- self.flush_cache_communicator = _Communicator(
- self.send_to_scheduler, server_args.dp_size
- )
- self.clear_hicache_storage_communicator = _Communicator(
- self.send_to_scheduler, server_args.dp_size
- )
- self.profile_communicator = _Communicator(
- self.send_to_scheduler, server_args.dp_size
- )
- self.get_internal_state_communicator = _Communicator(
- self.send_to_scheduler, server_args.dp_size
- )
- self.set_internal_state_communicator = _Communicator(
- self.send_to_scheduler, server_args.dp_size
- )
- self.expert_distribution_communicator = _Communicator(
- self.send_to_scheduler, server_args.dp_size
- )
- self.update_lora_adapter_communicator = _Communicator(
- self.send_to_scheduler, server_args.dp_size
- )
-
  self._result_dispatcher = TypeBasedDispatcher(
  [
  (
@@ -405,100 +343,15 @@ class TokenizerManager:
  UpdateWeightFromDiskReqOutput,
  self._handle_update_weights_from_disk_req_output,
  ),
- (
- InitWeightsUpdateGroupReqOutput,
- self.init_weights_update_group_communicator.handle_recv,
- ),
- (
- UpdateWeightsFromDistributedReqOutput,
- self.update_weights_from_distributed_communicator.handle_recv,
- ),
- (
- UpdateWeightsFromTensorReqOutput,
- self.update_weights_from_tensor_communicator.handle_recv,
- ),
- (
- GetWeightsByNameReqOutput,
- self.get_weights_by_name_communicator.handle_recv,
- ),
- (
- ReleaseMemoryOccupationReqOutput,
- self.release_memory_occupation_communicator.handle_recv,
- ),
- (
- ResumeMemoryOccupationReqOutput,
- self.resume_memory_occupation_communicator.handle_recv,
- ),
- (
- SlowDownReqOutput,
- self.slow_down_communicator.handle_recv,
- ),
- (
- ClearHiCacheReqOutput,
- self.clear_hicache_storage_communicator.handle_recv,
- ),
- (
- FlushCacheReqOutput,
- self.flush_cache_communicator.handle_recv,
- ),
- (
- ProfileReqOutput,
- self.profile_communicator.handle_recv,
- ),
  (
  FreezeGCReq,
  lambda x: None,
  ), # For handling case when scheduler skips detokenizer and forwards back to the tokenizer manager, we ignore it.
- (
- GetInternalStateReqOutput,
- self.get_internal_state_communicator.handle_recv,
- ),
- (
- SetInternalStateReqOutput,
- self.set_internal_state_communicator.handle_recv,
- ),
- (
- ExpertDistributionReqOutput,
- self.expert_distribution_communicator.handle_recv,
- ),
- (
- LoRAUpdateResult,
- self.update_lora_adapter_communicator.handle_recv,
- ),
  (HealthCheckOutput, lambda x: None),
  ]
  )

- def init_disaggregation(self):
- self.disaggregation_mode = DisaggregationMode(
- self.server_args.disaggregation_mode
- )
- self.disaggregation_transfer_backend = TransferBackend(
- self.server_args.disaggregation_transfer_backend
- )
- # Start kv boostrap server on prefill
- if self.disaggregation_mode == DisaggregationMode.PREFILL:
- # only start bootstrap server on prefill tm
- kv_bootstrap_server_class = get_kv_class(
- self.disaggregation_transfer_backend, KVClassType.BOOTSTRAP_SERVER
- )
- self.bootstrap_server = kv_bootstrap_server_class(
- self.server_args.disaggregation_bootstrap_port
- )
- is_create_store = (
- self.server_args.node_rank == 0
- and self.server_args.disaggregation_transfer_backend == "ascend"
- )
- if is_create_store:
- try:
- from mf_adapter import create_config_store
-
- ascend_url = os.getenv("ASCEND_MF_STORE_URL")
- create_config_store(ascend_url)
- except Exception as e:
- error_message = f"Failed create mf store, invalid ascend_url."
- error_message += f" With exception {e}"
- raise error_message
+ self.init_communicators(server_args)

  async def generate_request(
  self,
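The per-request-type `_Communicator` instances and dispatcher entries deleted above are now created by `self.init_communicators(server_args)` from TokenizerCommunicatorMixin (the old `_Communicator` helper itself is removed near the bottom of this diff). A condensed sketch of the pattern is below; `_FanOutCommunicator`, `CommunicatorMixinSketch`, and the `flush_cache` wiring are illustrative assumptions about the mixin's shape, not the actual sglang API.

```python
import asyncio
from typing import Any, Callable, List, Optional


class _FanOutCommunicator:
    """Condensed version of the removed `_Communicator`: send one request,
    then wait until one reply per DP rank has been collected."""

    def __init__(self, sender: Callable[[Any], None], fan_out: int):
        self._sender = sender
        self._fan_out = fan_out
        self._event: Optional[asyncio.Event] = None
        self._values: Optional[List[Any]] = None

    async def __call__(self, obj: Any) -> List[Any]:
        # One in-flight request at a time, as in the original helper.
        self._sender(obj)
        self._event, self._values = asyncio.Event(), []
        await self._event.wait()
        values, self._event, self._values = self._values, None, None
        return values

    def handle_recv(self, recv_obj: Any) -> None:
        # Called by the result dispatcher once per DP rank's reply.
        self._values.append(recv_obj)
        if len(self._values) == self._fan_out:
            self._event.set()


class CommunicatorMixinSketch:
    """Hypothetical mixin: one communicator per control-request type, plus an
    async method for each, so the manager class itself stays small."""

    def init_communicators(self, send_to_scheduler, dp_size: int) -> None:
        self._flush_cache_communicator = _FanOutCommunicator(send_to_scheduler, dp_size)

    async def flush_cache(self, req: Any) -> Any:
        # Fan the request out and return the first rank's reply, as the old
        # TokenizerManager.flush_cache() did.
        return (await self._flush_cache_communicator(req))[0]
```

Each control-plane request type (flush cache, profiling, expert distribution, weight updates, and so on) gets one such communicator, which is why the long per-type blocks above could be dropped from TokenizerManager itself.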
@@ -518,6 +371,24 @@ class TokenizerManager:
  # If it's a single value, add worker_id prefix
  obj.rid = f"{self.worker_id}_{obj.rid}"

+ if obj.is_single:
+ bootstrap_room = (
+ obj.bootstrap_room if hasattr(obj, "bootstrap_room") else None
+ )
+ trace_req_start(obj.rid, bootstrap_room, ts=int(created_time * 1e9))
+ trace_slice_start("", obj.rid, ts=int(created_time * 1e9), anonymous=True)
+ else:
+ for i in range(len(obj.rid)):
+ bootstrap_room = (
+ obj.bootstrap_room[i]
+ if hasattr(obj, "bootstrap_room") and obj.bootstrap_room
+ else None
+ )
+ trace_req_start(obj.rid[i], bootstrap_room, ts=int(created_time * 1e9))
+ trace_slice_start(
+ "", obj.rid[i], ts=int(created_time * 1e9), anonymous=True
+ )
+
  if self.log_requests:
  max_length, skip_names, _ = self.log_request_metadata
  logger.info(
@@ -543,6 +414,144 @@ class TokenizerManager:
  ):
  yield response

+ def _detect_input_format(
+ self, texts: Union[str, List[str]], is_cross_encoder: bool
+ ) -> str:
+ """Detect the format of input texts for proper tokenization handling.
+
+ Returns:
+ - "single_string": Regular single text like "Hello world"
+ - "batch_strings": Regular batch like ["Hello", "World"]
+ - "cross_encoder_pairs": Cross-encoder pairs like [["query", "document"]]
+ """
+ if isinstance(texts, str):
+ return "single_string"
+
+ if (
+ is_cross_encoder
+ and len(texts) > 0
+ and isinstance(texts[0], list)
+ and len(texts[0]) == 2
+ ):
+ return "cross_encoder_pairs"
+
+ return "batch_strings"
+
+ def _prepare_tokenizer_input(
+ self, texts: Union[str, List[str]], input_format: str
+ ) -> Union[List[str], List[List[str]]]:
+ """Prepare input for the tokenizer based on detected format."""
+ if input_format == "single_string":
+ return [texts] # Wrap single string for batch processing
+ elif input_format == "cross_encoder_pairs":
+ return texts # Already in correct format: [["query", "doc"]]
+ else: # batch_strings
+ return texts # Already in correct format: ["text1", "text2"]
+
+ def _extract_tokenizer_results(
+ self,
+ input_ids: List[List[int]],
+ token_type_ids: Optional[List[List[int]]],
+ input_format: str,
+ original_batch_size: int,
+ ) -> Union[
+ Tuple[List[int], Optional[List[int]]],
+ Tuple[List[List[int]], Optional[List[List[int]]]],
+ ]:
+ """Extract results from tokenizer output based on input format."""
+
+ # For single inputs (string or single cross-encoder pair), extract first element
+ if (
+ input_format in ["single_string", "cross_encoder_pairs"]
+ and original_batch_size == 1
+ ):
+ single_input_ids = input_ids[0] if input_ids else []
+ single_token_type_ids = token_type_ids[0] if token_type_ids else None
+ return single_input_ids, single_token_type_ids
+
+ # For true batches, return as-is
+ return input_ids, token_type_ids
+
+ async def _tokenize_texts(
+ self, texts: Union[str, List[str]], is_cross_encoder: bool = False
+ ) -> Union[
+ Tuple[List[int], Optional[List[int]]],
+ Tuple[List[List[int]], Optional[List[List[int]]]],
+ ]:
+ """
+ Tokenize text(s) using the appropriate tokenizer strategy.
+
+ This method handles multiple input formats and chooses between async dynamic
+ batch tokenizer (for single texts only) and regular tokenizer.
+
+ Args:
+ texts: Text input in various formats:
+
+ Regular cases:
+ - Single string: "How are you?"
+ - Batch of strings: ["Hello", "World", "How are you?"]
+
+ Cross-encoder cases (sentence pairs for similarity/ranking):
+ - Single pair: [["query text", "document text"]]
+ - Multiple pairs: [["q1", "d1"], ["q2", "d2"], ["q3", "d3"]]
+
+ is_cross_encoder: Whether to return token_type_ids for cross-encoder models.
+ Enables proper handling of sentence pairs with segment IDs.
+
+ Returns:
+ Single input cases:
+ Tuple[List[int], Optional[List[int]]]: (input_ids, token_type_ids)
+ Example: ([101, 2129, 102], [0, 0, 0]) for single text
+ Example: ([101, 2129, 102, 4068, 102], [0, 0, 0, 1, 1]) for cross-encoder pair
+
+ Batch input cases:
+ Tuple[List[List[int]], Optional[List[List[int]]]]: (batch_input_ids, batch_token_type_ids)
+ Example: ([[101, 2129, 102], [101, 4068, 102]], None) for regular batch
+
+ Note: token_type_ids is None unless is_cross_encoder=True.
+ """
+ if not texts or self.tokenizer is None:
+ raise ValueError("texts cannot be empty and tokenizer must be initialized")
+
+ # Step 1: Detect input format and prepare for tokenization
+ input_format = self._detect_input_format(texts, is_cross_encoder)
+ tokenizer_input = self._prepare_tokenizer_input(texts, input_format)
+ original_batch_size = len(texts) if not isinstance(texts, str) else 1
+
+ # Step 2: Set up tokenizer arguments
+ tokenizer_kwargs = (
+ {"return_token_type_ids": is_cross_encoder} if is_cross_encoder else {}
+ )
+
+ # Step 3: Choose tokenization strategy
+ use_async_tokenizer = (
+ self.async_dynamic_batch_tokenizer is not None
+ and input_format == "single_string"
+ )
+
+ if use_async_tokenizer:
+ logger.debug("Using async dynamic batch tokenizer for single text")
+ result = await self.async_dynamic_batch_tokenizer.encode(
+ tokenizer_input[0], **tokenizer_kwargs
+ )
+ # Convert to batch format for consistency
+ input_ids = [result["input_ids"]]
+ token_type_ids = (
+ [result["token_type_ids"]]
+ if is_cross_encoder and result.get("token_type_ids")
+ else None
+ )
+ else:
+ logger.debug(f"Using regular tokenizer for {len(tokenizer_input)} inputs")
+ encoded = self.tokenizer(tokenizer_input, **tokenizer_kwargs)
+ input_ids = encoded["input_ids"]
+ token_type_ids = encoded.get("token_type_ids") if is_cross_encoder else None
+
+ # Step 4: Extract results based on input format
+ return self._extract_tokenizer_results(
+ input_ids, token_type_ids, input_format, original_batch_size
+ )
+
  async def _tokenize_one_request(
  self,
  obj: Union[GenerateReqInput, EmbeddingReqInput],
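To make the new input handling concrete, here is a small self-contained restatement of the decision rule from `_detect_input_format` and of the result shapes documented in the `_tokenize_texts` docstring. It is illustrative only and does not call the sglang tokenizer.

```python
from typing import List, Union


def detect_input_format(texts: Union[str, List], is_cross_encoder: bool) -> str:
    # Same decision rule as _detect_input_format above.
    if isinstance(texts, str):
        return "single_string"
    if (
        is_cross_encoder
        and len(texts) > 0
        and isinstance(texts[0], list)
        and len(texts[0]) == 2
    ):
        return "cross_encoder_pairs"
    return "batch_strings"


assert detect_input_format("How are you?", False) == "single_string"
assert detect_input_format(["Hello", "World"], False) == "batch_strings"
assert detect_input_format([["query text", "document text"]], True) == "cross_encoder_pairs"

# Result shapes, per the _tokenize_texts docstring:
#   single string / single cross-encoder pair -> (input_ids, token_type_ids)
#   true batch -> (batch_input_ids, batch_token_type_ids)
```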
@@ -573,14 +582,10 @@ class TokenizerManager:
  "accept text prompts. Please provide input_ids or re-initialize "
  "the engine with skip_tokenizer_init=False."
  )
- encoded = self.tokenizer(
- input_text, return_token_type_ids=is_cross_encoder_request
- )

- input_ids = encoded["input_ids"]
- if is_cross_encoder_request:
- input_ids = encoded["input_ids"][0]
- token_type_ids = encoded.get("token_type_ids", [None])[0]
+ input_ids, token_type_ids = await self._tokenize_texts(
+ input_text, is_cross_encoder_request
+ )

  if self.mm_processor and obj.contains_mm_input():
  if not isinstance(obj.image_data, list):
@@ -600,6 +605,7 @@ class TokenizerManager:
  mm_inputs = None

  self._validate_one_request(obj, input_ids)
+ trace_slice_end("tokenize", obj.rid)
  return self._create_tokenized_object(
  obj, input_text, input_ids, input_embeds, mm_inputs, token_type_ids
  )
@@ -674,7 +680,7 @@ class TokenizerManager:
  ):
  raise ValueError(
  "The server is not configured to enable custom logit processor. "
- "Please set `--enable-custom-logits-processor` to enable this feature."
+ "Please set `--enable-custom-logit-processor` to enable this feature."
  )

  def _validate_input_ids_in_vocab(
@@ -755,19 +761,30 @@ class TokenizerManager:
  requests = [obj[i] for i in range(batch_size)]
  texts = [req.text for req in requests]

- # Batch tokenize all texts
- encoded = self.tokenizer(texts)
- input_ids_list = encoded["input_ids"]
+ # Check if any request is a cross-encoder request
+ is_cross_encoder_request = any(
+ isinstance(req, EmbeddingReqInput) and req.is_cross_encoder_request
+ for req in requests
+ )
+
+ # Batch tokenize all texts using unified method
+ input_ids_list, token_type_ids_list = await self._tokenize_texts(
+ texts, is_cross_encoder_request
+ )

  # Process all requests
  tokenized_objs = []
  for i, req in enumerate(requests):
  self._validate_one_request(obj[i], input_ids_list[i])
+ token_type_ids = (
+ token_type_ids_list[i] if token_type_ids_list is not None else None
+ )
  tokenized_objs.append(
  self._create_tokenized_object(
- req, req.text, input_ids_list[i], None, None
+ req, req.text, input_ids_list[i], None, None, token_type_ids
  )
  )
+ trace_slice_end("tokenize", req.rid)
  logger.debug(f"Completed batch processing for {batch_size} requests")
  return tokenized_objs

@@ -795,9 +812,12 @@ class TokenizerManager:
  tokenized_obj: Union[TokenizedGenerateReqInput, TokenizedEmbeddingReqInput],
  created_time: Optional[float] = None,
  ):
+ trace_slice_start("dispatch", obj.rid)
+ tokenized_obj.trace_context = trace_get_proc_propagate_context(obj.rid)
  self.send_to_scheduler.send_pyobj(tokenized_obj)
  state = ReqState([], False, asyncio.Event(), obj, created_time=created_time)
  self.rid_to_state[obj.rid] = state
+ trace_slice_end("dispatch", obj.rid, thread_finish_flag=True)
  return state

  def _send_batch_request(
@@ -1015,74 +1035,14 @@ class TokenizerManager:
  except StopAsyncIteration:
  pass

- async def flush_cache(self) -> FlushCacheReqOutput:
- return (await self.flush_cache_communicator(FlushCacheReqInput()))[0]
-
- async def clear_hicache_storage(self) -> ClearHiCacheReqOutput:
- """Clear the hierarchical cache storage."""
- # Delegate to the scheduler to handle HiCacheStorage clearing
- return (await self.clear_hicache_storage_communicator(ClearHiCacheReqInput()))[
- 0
- ]
-
  def abort_request(self, rid: str = "", abort_all: bool = False):
  if not abort_all and rid not in self.rid_to_state:
  return
  req = AbortReq(rid, abort_all)
  self.send_to_scheduler.send_pyobj(req)
-
  if self.enable_metrics:
  self.metrics_collector.observe_one_aborted_request()

- async def start_profile(
- self,
- output_dir: Optional[str] = None,
- start_step: Optional[int] = None,
- num_steps: Optional[int] = None,
- activities: Optional[List[str]] = None,
- with_stack: Optional[bool] = None,
- record_shapes: Optional[bool] = None,
- profile_by_stage: bool = False,
- ):
- self.auto_create_handle_loop()
- env_with_stack: bool = get_bool_env_var("SGLANG_PROFILE_WITH_STACK", "true")
- with_stack = False if with_stack is False or env_with_stack is False else True
- req = ProfileReq(
- type=ProfileReqType.START_PROFILE,
- output_dir=output_dir,
- start_step=start_step,
- num_steps=num_steps,
- activities=activities,
- with_stack=with_stack,
- record_shapes=record_shapes,
- profile_by_stage=profile_by_stage,
- profile_id=str(time.time()),
- )
- return await self._execute_profile(req)
-
- async def stop_profile(self):
- self.auto_create_handle_loop()
- req = ProfileReq(type=ProfileReqType.STOP_PROFILE)
- return await self._execute_profile(req)
-
- async def _execute_profile(self, req: ProfileReq):
- result = (await self.profile_communicator(req))[0]
- if not result.success:
- raise RuntimeError(result.message)
- return result
-
- async def start_expert_distribution_record(self):
- self.auto_create_handle_loop()
- await self.expert_distribution_communicator(ExpertDistributionReq.START_RECORD)
-
- async def stop_expert_distribution_record(self):
- self.auto_create_handle_loop()
- await self.expert_distribution_communicator(ExpertDistributionReq.STOP_RECORD)
-
- async def dump_expert_distribution_record(self):
- self.auto_create_handle_loop()
- await self.expert_distribution_communicator(ExpertDistributionReq.DUMP_RECORD)
-
  async def pause_generation(self):
  async with self.is_pause_cond:
  self.is_pause = True
@@ -1118,7 +1078,7 @@ class TokenizerManager:
  self, obj: UpdateWeightFromDiskReqInput
  ) -> Tuple[bool, str]:
  if self.server_args.tokenizer_worker_num > 1:
- obj = MultiTokenizerWarpper(self.worker_id, obj)
+ obj = MultiTokenizerWrapper(self.worker_id, obj)
  self.send_to_scheduler.send_pyobj(obj)
  self.model_update_result = asyncio.Future()
  if self.server_args.dp_size == 1:
@@ -1143,191 +1103,6 @@ class TokenizerManager:
  all_paused_requests = [r.num_paused_requests for r in result]
  return all_success, all_message, all_paused_requests

- async def init_weights_update_group(
- self,
- obj: InitWeightsUpdateGroupReqInput,
- request: Optional[fastapi.Request] = None,
- ) -> Tuple[bool, str]:
- self.auto_create_handle_loop()
- assert (
- self.server_args.dp_size == 1
- ), "dp_size must be 1 for init parameter update group"
- result = (await self.init_weights_update_group_communicator(obj))[0]
- return result.success, result.message
-
- async def update_weights_from_distributed(
- self,
- obj: UpdateWeightsFromDistributedReqInput,
- request: Optional[fastapi.Request] = None,
- ) -> Tuple[bool, str]:
- self.auto_create_handle_loop()
- assert (
- self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
- ), "dp_size must be 1 or dp attention must be enabled for update weights from distributed"
-
- if obj.abort_all_requests:
- self.abort_request(abort_all=True)
-
- # This means that weight sync
- # cannot run while requests are in progress.
- async with self.model_update_lock.writer_lock:
- result = (await self.update_weights_from_distributed_communicator(obj))[0]
- return result.success, result.message
-
- async def update_weights_from_tensor(
- self,
- obj: UpdateWeightsFromTensorReqInput,
- request: Optional[fastapi.Request] = None,
- ) -> Tuple[bool, str]:
- self.auto_create_handle_loop()
- assert (
- self.server_args.dp_size == 1 or self.server_args.enable_dp_attention
- ), "dp_size must be 1 or dp attention must be enabled for update weights from tensor"
-
- if obj.abort_all_requests:
- self.abort_request(abort_all=True)
-
- # This means that weight sync
- # cannot run while requests are in progress.
- async with self.model_update_lock.writer_lock:
- result = (await self.update_weights_from_tensor_communicator(obj))[0]
- return result.success, result.message
-
- async def load_lora_adapter(
- self,
- obj: LoadLoRAAdapterReqInput,
- _: Optional[fastapi.Request] = None,
- ) -> LoadLoRAAdapterReqOutput:
- self.auto_create_handle_loop()
-
- try:
- if not self.server_args.enable_lora:
- raise ValueError(
- "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
- )
-
- # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
- # with dp_size > 1.
- assert (
- self.server_args.dp_size == 1
- ), "dp_size must be 1 for dynamic lora loading"
- logger.info(
- "Start load Lora adapter. Lora name=%s, path=%s",
- obj.lora_name,
- obj.lora_path,
- )
-
- async with self.lora_update_lock:
- if (
- self.server_args.max_loaded_loras is not None
- and self.lora_registry.num_registered_loras
- >= self.server_args.max_loaded_loras
- ):
- raise ValueError(
- f"Cannot load LoRA adapter {obj.lora_name} at path {obj.lora_path}. "
- f"Maximum number of loaded LoRA adapters is {self.server_args.max_loaded_loras}. "
- "Please unload some LoRA adapters before loading new ones."
- )
-
- # Generate new uniquely identifiable LoRARef object.
- new_adapter = LoRARef(
- lora_name=obj.lora_name,
- lora_path=obj.lora_path,
- pinned=obj.pinned,
- )
-
- # Trigger the actual loading operation at the backend processes.
- obj.lora_id = new_adapter.lora_id
- result = (await self.update_lora_adapter_communicator(obj))[0]
-
- # Register the LoRA adapter only after loading is successful.
- if result.success:
- await self.lora_registry.register(new_adapter)
-
- return result
- except ValueError as e:
- return LoadLoRAAdapterReqOutput(
- success=False,
- error_message=str(e),
- )
-
- async def unload_lora_adapter(
- self,
- obj: UnloadLoRAAdapterReqInput,
- _: Optional[fastapi.Request] = None,
- ) -> UnloadLoRAAdapterReqOutput:
- self.auto_create_handle_loop()
-
- try:
- if not self.server_args.enable_lora:
- raise ValueError(
- "LoRA is not enabled. Please set `--enable-lora` to enable LoRA."
- )
-
- assert (
- obj.lora_name is not None
- ), "lora_name must be provided to unload LoRA adapter"
-
- # TODO (lifuhuang): Remove this after we verify that dynamic lora loading works
- # with dp_size > 1.
- assert (
- self.server_args.dp_size == 1
- ), "dp_size must be 1 for dynamic lora loading"
- logger.info(
- "Start unload Lora adapter. Lora name=%s",
- obj.lora_name,
- )
-
- async with self.lora_update_lock:
- # Unregister the LoRA adapter from the registry to stop new requests for this adapter
- # from being started.
- lora_id = await self.lora_registry.unregister(obj.lora_name)
- obj.lora_id = lora_id
-
- # Initiate the actual unloading operation at the backend processes only after all
- # ongoing requests using this LoRA adapter are finished.
- await self.lora_registry.wait_for_unload(lora_id)
- result = (await self.update_lora_adapter_communicator(obj))[0]
-
- return result
- except ValueError as e:
- return UnloadLoRAAdapterReqOutput(success=False, error_message=str(e))
-
- async def get_weights_by_name(
- self, obj: GetWeightsByNameReqInput, request: Optional[fastapi.Request] = None
- ):
- self.auto_create_handle_loop()
- results = await self.get_weights_by_name_communicator(obj)
- all_parameters = [r.parameter for r in results]
- if self.server_args.dp_size == 1:
- return all_parameters[0]
- else:
- return all_parameters
-
- async def release_memory_occupation(
- self,
- obj: ReleaseMemoryOccupationReqInput,
- request: Optional[fastapi.Request] = None,
- ):
- self.auto_create_handle_loop()
- await self.release_memory_occupation_communicator(obj)
-
- async def resume_memory_occupation(
- self,
- obj: ResumeMemoryOccupationReqInput,
- request: Optional[fastapi.Request] = None,
- ):
- self.auto_create_handle_loop()
- await self.resume_memory_occupation_communicator(obj)
-
- async def slow_down(
- self,
- obj: SlowDownReqInput,
- request: Optional[fastapi.Request] = None,
- ):
- self.auto_create_handle_loop()
- await self.slow_down_communicator(obj)
-
  async def open_session(
  self, obj: OpenSessionReqInput, request: Optional[fastapi.Request] = None
  ):
@@ -1339,7 +1114,7 @@ class TokenizerManager:
  return None

  if self.server_args.tokenizer_worker_num > 1:
- obj = MultiTokenizerWarpper(self.worker_id, obj)
+ obj = MultiTokenizerWrapper(self.worker_id, obj)
  self.send_to_scheduler.send_pyobj(obj)

  self.session_futures[obj.session_id] = asyncio.Future()
@@ -1352,28 +1127,6 @@ class TokenizerManager:
  ):
  await self.send_to_scheduler.send_pyobj(obj)

- async def get_internal_state(self) -> List[Dict[Any, Any]]:
- req = GetInternalStateReq()
- responses: List[GetInternalStateReqOutput] = (
- await self.get_internal_state_communicator(req)
- )
- # Many DP ranks
- return [res.internal_state for res in responses]
-
- async def set_internal_state(self, obj: SetInternalStateReq) -> List[bool]:
- responses: List[SetInternalStateReqOutput] = (
- await self.set_internal_state_communicator(obj)
- )
- return [res.updated for res in responses]
-
- async def get_load(self) -> dict:
- # TODO(lsyin): fake load report server
- if not self.current_load_lock.locked():
- async with self.current_load_lock:
- internal_state = await self.get_internal_state()
- self.current_load = internal_state[0]["load"]
- return {"load": self.current_load}
-
  def get_log_request_metadata(self):
  max_length = None
  skip_names = None
@@ -1492,6 +1245,9 @@ class TokenizerManager:
  self.asyncio_tasks.add(
  loop.create_task(print_exception_wrapper(self.sigterm_watchdog))
  )
+ self.asyncio_tasks.add(
+ loop.create_task(print_exception_wrapper(self.watch_load_thread))
+ )

  def dump_requests_before_crash(self):
  if self.crash_dump_performed:
@@ -1711,6 +1467,9 @@ class TokenizerManager:
  meta_info["spec_verify_ct"] = recv_obj.spec_verify_ct[i]
  state.finished_time = time.time()
  meta_info["e2e_latency"] = state.finished_time - state.created_time
+
+ trace_req_finish(rid, ts=int(state.finished_time * 1e9))
+
  del self.rid_to_state[rid]

  # Mark ongoing LoRA request as finished.
@@ -1860,6 +1619,12 @@ class TokenizerManager:
  else 0
  )

+ customer_labels = getattr(state.obj, "customer_labels", None)
+ labels = (
+ {**self.metrics_collector.labels, **customer_labels}
+ if customer_labels
+ else self.metrics_collector.labels
+ )
  if (
  state.first_token_time == 0.0
  and self.disaggregation_mode != DisaggregationMode.PREFILL
@@ -1867,7 +1632,7 @@ class TokenizerManager:
  state.first_token_time = state.last_time = time.time()
  state.last_completion_tokens = completion_tokens
  self.metrics_collector.observe_time_to_first_token(
- state.first_token_time - state.created_time
+ labels, state.first_token_time - state.created_time
  )
  else:
  num_new_tokens = completion_tokens - state.last_completion_tokens
@@ -1875,6 +1640,7 @@ class TokenizerManager:
  new_time = time.time()
  interval = new_time - state.last_time
  self.metrics_collector.observe_inter_token_latency(
+ labels,
  interval,
  num_new_tokens,
  )
@@ -1889,6 +1655,7 @@ class TokenizerManager:
  or state.obj.sampling_params.get("structural_tag", None)
  )
  self.metrics_collector.observe_one_finished_request(
+ labels,
  recv_obj.prompt_tokens[i],
  completion_tokens,
  recv_obj.cached_tokens[i],
@@ -2060,11 +1827,15 @@ class TokenizerManager:
  # the next position after the last token in the prompt
  output_logprobs = result["meta_info"].get("output_token_ids_logprobs", [])

- # Throw an error here if output_logprobs is None
- if output_logprobs is None:
+ # Check if output_logprobs is properly populated
+ if (
+ output_logprobs is None
+ or not output_logprobs
+ or len(output_logprobs) == 0
+ ):
  raise RuntimeError(
- f"output_logprobs is None for request {result['meta_info'].get('id', '<unknown>')}. "
- "This usually indicates a problem with the scoring request or the backend output."
+ f"output_logprobs is empty for request {result['meta_info'].get('id', '<unknown>')}. "
+ "This indicates token_ids_logprobs were not computed properly for the scoring request."
  )

  for logprob, token_id, _ in output_logprobs[0]:
@@ -2089,6 +1860,20 @@ class TokenizerManager:

  return scores

+ async def watch_load_thread(self):
+ # Only for dp_controller when dp_size > 1
+ if (
+ self.server_args.dp_size == 1
+ or self.server_args.load_balance_method == "round_robin"
+ ):
+ return
+
+ while True:
+ await asyncio.sleep(self.server_args.load_watch_interval)
+ loads = await self.get_load_communicator(GetLoadReqInput())
+ load_udpate_req = WatchLoadUpdateReq(loads=loads)
+ self.send_to_scheduler.send_pyobj(load_udpate_req)
+

  class ServerStatus(Enum):
  Up = "Up"
@@ -2140,51 +1925,6 @@ class SignalHandler:
  kill_process_tree(os.getpid())


- T = TypeVar("T")
-
-
- class _Communicator(Generic[T]):
- """Note: The communicator now only run up to 1 in-flight request at any time."""
-
- enable_multi_tokenizer = False
-
- def __init__(self, sender, fan_out: int):
- self._sender = sender
- self._fan_out = fan_out
- self._result_event: Optional[asyncio.Event] = None
- self._result_values: Optional[List[T]] = None
- self._ready_queue: Deque[asyncio.Future] = deque()
-
- async def __call__(self, obj):
- ready_event = asyncio.Event()
- if self._result_event is not None or len(self._ready_queue) > 0:
- self._ready_queue.append(ready_event)
- await ready_event.wait()
- assert self._result_event is None
- assert self._result_values is None
-
- if obj:
- if _Communicator.enable_multi_tokenizer:
- obj = MultiTokenizerWarpper(worker_id=os.getpid(), obj=obj)
- self._sender.send_pyobj(obj)
-
- self._result_event = asyncio.Event()
- self._result_values = []
- await self._result_event.wait()
- result_values = self._result_values
- self._result_event = self._result_values = None
-
- if len(self._ready_queue) > 0:
- self._ready_queue.popleft().set()
-
- return result_values
-
- def handle_recv(self, recv_obj: T):
- self._result_values.append(recv_obj)
- if len(self._result_values) == self._fan_out:
- self._result_event.set()
-
-
  # Note: request abort handling logic
  # We should handle all of the following cases correctly.
  #