sglang 0.5.2rc1__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (265)
  1. sglang/bench_one_batch_server.py +10 -1
  2. sglang/bench_serving.py +257 -29
  3. sglang/lang/interpreter.py +1 -1
  4. sglang/srt/configs/__init__.py +4 -0
  5. sglang/srt/configs/device_config.py +3 -1
  6. sglang/srt/configs/dots_vlm.py +139 -0
  7. sglang/srt/configs/internvl.py +6 -0
  8. sglang/srt/configs/load_config.py +1 -0
  9. sglang/srt/configs/model_config.py +50 -6
  10. sglang/srt/configs/qwen3_next.py +326 -0
  11. sglang/srt/connector/__init__.py +8 -1
  12. sglang/srt/connector/remote_instance.py +82 -0
  13. sglang/srt/constrained/base_grammar_backend.py +48 -12
  14. sglang/srt/constrained/llguidance_backend.py +0 -1
  15. sglang/srt/constrained/outlines_backend.py +0 -1
  16. sglang/srt/constrained/xgrammar_backend.py +28 -9
  17. sglang/srt/custom_op.py +11 -1
  18. sglang/srt/debug_utils/dump_comparator.py +81 -44
  19. sglang/srt/debug_utils/dump_loader.py +97 -0
  20. sglang/srt/debug_utils/dumper.py +11 -3
  21. sglang/srt/debug_utils/text_comparator.py +73 -11
  22. sglang/srt/disaggregation/base/conn.py +1 -1
  23. sglang/srt/disaggregation/common/conn.py +15 -12
  24. sglang/srt/disaggregation/decode.py +21 -10
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
  26. sglang/srt/disaggregation/fake/conn.py +1 -1
  27. sglang/srt/disaggregation/mini_lb.py +6 -445
  28. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  29. sglang/srt/disaggregation/nixl/conn.py +180 -16
  30. sglang/srt/disaggregation/prefill.py +5 -3
  31. sglang/srt/disaggregation/utils.py +5 -50
  32. sglang/srt/distributed/parallel_state.py +67 -43
  33. sglang/srt/entrypoints/engine.py +38 -17
  34. sglang/srt/entrypoints/grpc_request_manager.py +580 -0
  35. sglang/srt/entrypoints/grpc_server.py +680 -0
  36. sglang/srt/entrypoints/http_server.py +88 -53
  37. sglang/srt/entrypoints/openai/protocol.py +7 -4
  38. sglang/srt/entrypoints/openai/serving_base.py +46 -3
  39. sglang/srt/entrypoints/openai/serving_chat.py +39 -19
  40. sglang/srt/entrypoints/openai/serving_completions.py +15 -4
  41. sglang/srt/entrypoints/openai/serving_embedding.py +9 -4
  42. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  43. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  44. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  45. sglang/srt/eplb/eplb_manager.py +2 -2
  46. sglang/srt/eplb/expert_distribution.py +26 -13
  47. sglang/srt/eplb/expert_location.py +8 -3
  48. sglang/srt/eplb/expert_location_updater.py +1 -1
  49. sglang/srt/function_call/base_format_detector.py +3 -6
  50. sglang/srt/function_call/ebnf_composer.py +11 -9
  51. sglang/srt/function_call/function_call_parser.py +6 -0
  52. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  53. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  54. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  55. sglang/srt/grpc/__init__.py +1 -0
  56. sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
  57. sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
  58. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
  59. sglang/srt/hf_transformers_utils.py +4 -0
  60. sglang/srt/layers/activation.py +142 -9
  61. sglang/srt/layers/attention/aiter_backend.py +93 -68
  62. sglang/srt/layers/attention/ascend_backend.py +11 -4
  63. sglang/srt/layers/attention/fla/chunk.py +242 -0
  64. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  65. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  66. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  67. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  68. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  69. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  70. sglang/srt/layers/attention/fla/index.py +37 -0
  71. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  72. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  73. sglang/srt/layers/attention/fla/op.py +66 -0
  74. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  75. sglang/srt/layers/attention/fla/utils.py +331 -0
  76. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  77. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  78. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  79. sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
  80. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  81. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  82. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  83. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  84. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  85. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  86. sglang/srt/layers/attention/triton_backend.py +18 -1
  87. sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
  88. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  89. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  90. sglang/srt/layers/communicator.py +45 -7
  91. sglang/srt/layers/dp_attention.py +30 -1
  92. sglang/srt/layers/layernorm.py +32 -15
  93. sglang/srt/layers/linear.py +34 -3
  94. sglang/srt/layers/logits_processor.py +29 -10
  95. sglang/srt/layers/moe/__init__.py +2 -1
  96. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  97. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  98. sglang/srt/layers/moe/ep_moe/layer.py +182 -62
  99. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
  100. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  101. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  102. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  103. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  104. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
  105. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  106. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  107. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  108. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  113. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
  114. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  115. sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
  116. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  117. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  118. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  119. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  120. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  121. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  122. sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
  123. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  124. sglang/srt/layers/moe/topk.py +30 -9
  125. sglang/srt/layers/moe/utils.py +12 -7
  126. sglang/srt/layers/quantization/awq.py +19 -7
  127. sglang/srt/layers/quantization/base_config.py +11 -6
  128. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  129. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  130. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  131. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  132. sglang/srt/layers/quantization/fp8.py +76 -47
  133. sglang/srt/layers/quantization/fp8_utils.py +50 -31
  134. sglang/srt/layers/quantization/gptq.py +25 -17
  135. sglang/srt/layers/quantization/modelopt_quant.py +182 -49
  136. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  137. sglang/srt/layers/quantization/mxfp4.py +68 -41
  138. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  139. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  140. sglang/srt/layers/quantization/quark/utils.py +97 -0
  141. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  142. sglang/srt/layers/quantization/unquant.py +135 -47
  143. sglang/srt/layers/quantization/w4afp8.py +30 -17
  144. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  145. sglang/srt/layers/quantization/w8a8_int8.py +76 -38
  146. sglang/srt/layers/rocm_linear_utils.py +44 -0
  147. sglang/srt/layers/rotary_embedding.py +0 -18
  148. sglang/srt/layers/sampler.py +162 -18
  149. sglang/srt/lora/backend/base_backend.py +50 -8
  150. sglang/srt/lora/backend/triton_backend.py +90 -2
  151. sglang/srt/lora/layers.py +32 -0
  152. sglang/srt/lora/lora.py +4 -1
  153. sglang/srt/lora/lora_manager.py +35 -112
  154. sglang/srt/lora/mem_pool.py +24 -10
  155. sglang/srt/lora/utils.py +18 -9
  156. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  157. sglang/srt/managers/cache_controller.py +200 -199
  158. sglang/srt/managers/data_parallel_controller.py +105 -35
  159. sglang/srt/managers/detokenizer_manager.py +8 -4
  160. sglang/srt/managers/disagg_service.py +46 -0
  161. sglang/srt/managers/io_struct.py +199 -12
  162. sglang/srt/managers/mm_utils.py +1 -0
  163. sglang/srt/managers/multi_tokenizer_mixin.py +351 -397
  164. sglang/srt/managers/schedule_batch.py +77 -56
  165. sglang/srt/managers/schedule_policy.py +4 -3
  166. sglang/srt/managers/scheduler.py +191 -139
  167. sglang/srt/managers/scheduler_metrics_mixin.py +116 -9
  168. sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
  169. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  170. sglang/srt/managers/template_manager.py +3 -3
  171. sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
  172. sglang/srt/managers/tokenizer_manager.py +260 -519
  173. sglang/srt/managers/tp_worker.py +53 -4
  174. sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
  175. sglang/srt/mem_cache/allocator.py +1 -1
  176. sglang/srt/mem_cache/hicache_storage.py +18 -33
  177. sglang/srt/mem_cache/hiradix_cache.py +108 -48
  178. sglang/srt/mem_cache/memory_pool.py +347 -48
  179. sglang/srt/mem_cache/memory_pool_host.py +121 -57
  180. sglang/srt/mem_cache/radix_cache.py +0 -2
  181. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  182. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  183. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +95 -5
  184. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  185. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  186. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +81 -20
  187. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  188. sglang/srt/mem_cache/swa_radix_cache.py +0 -2
  189. sglang/srt/metrics/collector.py +502 -77
  190. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  191. sglang/srt/metrics/utils.py +48 -0
  192. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  193. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  194. sglang/srt/model_executor/forward_batch_info.py +75 -19
  195. sglang/srt/model_executor/model_runner.py +357 -30
  196. sglang/srt/model_loader/__init__.py +9 -3
  197. sglang/srt/model_loader/loader.py +128 -4
  198. sglang/srt/model_loader/weight_utils.py +2 -1
  199. sglang/srt/models/apertus.py +686 -0
  200. sglang/srt/models/bailing_moe.py +798 -218
  201. sglang/srt/models/bailing_moe_nextn.py +168 -0
  202. sglang/srt/models/deepseek_v2.py +346 -48
  203. sglang/srt/models/dots_vlm.py +174 -0
  204. sglang/srt/models/dots_vlm_vit.py +337 -0
  205. sglang/srt/models/ernie4.py +1 -1
  206. sglang/srt/models/gemma3n_mm.py +1 -1
  207. sglang/srt/models/glm4_moe.py +11 -2
  208. sglang/srt/models/glm4v.py +4 -2
  209. sglang/srt/models/glm4v_moe.py +3 -0
  210. sglang/srt/models/gpt_oss.py +1 -1
  211. sglang/srt/models/internvl.py +28 -0
  212. sglang/srt/models/llama4.py +9 -0
  213. sglang/srt/models/llama_eagle3.py +13 -0
  214. sglang/srt/models/longcat_flash.py +2 -2
  215. sglang/srt/models/minicpmv.py +165 -3
  216. sglang/srt/models/mllama4.py +25 -0
  217. sglang/srt/models/opt.py +637 -0
  218. sglang/srt/models/qwen2.py +7 -0
  219. sglang/srt/models/qwen2_5_vl.py +27 -3
  220. sglang/srt/models/qwen2_moe.py +60 -13
  221. sglang/srt/models/qwen3.py +8 -2
  222. sglang/srt/models/qwen3_moe.py +40 -9
  223. sglang/srt/models/qwen3_next.py +1042 -0
  224. sglang/srt/models/qwen3_next_mtp.py +112 -0
  225. sglang/srt/models/step3_vl.py +1 -1
  226. sglang/srt/models/torch_native_llama.py +1 -1
  227. sglang/srt/multimodal/processors/dots_vlm.py +99 -0
  228. sglang/srt/multimodal/processors/glm4v.py +9 -9
  229. sglang/srt/multimodal/processors/internvl.py +141 -129
  230. sglang/srt/multimodal/processors/qwen_vl.py +15 -5
  231. sglang/srt/offloader.py +27 -3
  232. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  233. sglang/srt/remote_instance_weight_loader_utils.py +69 -0
  234. sglang/srt/sampling/sampling_batch_info.py +18 -15
  235. sglang/srt/server_args.py +355 -37
  236. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  237. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  238. sglang/srt/speculative/eagle_utils.py +0 -2
  239. sglang/srt/speculative/eagle_worker.py +197 -112
  240. sglang/srt/speculative/spec_info.py +5 -0
  241. sglang/srt/speculative/standalone_worker.py +109 -0
  242. sglang/srt/tracing/trace.py +552 -0
  243. sglang/srt/utils.py +46 -3
  244. sglang/srt/weight_sync/utils.py +1 -1
  245. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  246. sglang/test/few_shot_gsm8k.py +1 -0
  247. sglang/test/runners.py +4 -0
  248. sglang/test/test_cutlass_moe.py +24 -6
  249. sglang/test/test_disaggregation_utils.py +66 -0
  250. sglang/test/test_fp4_moe.py +370 -1
  251. sglang/test/test_utils.py +28 -1
  252. sglang/utils.py +12 -0
  253. sglang/version.py +1 -1
  254. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
  255. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +263 -200
  256. sglang/srt/disaggregation/launch_lb.py +0 -118
  257. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  258. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  259. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  260. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  261. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  262. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  263. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
  264. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
  265. {sglang-0.5.2rc1.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
sglang/srt/managers/scheduler.py +191 -139
@@ -79,13 +79,17 @@ from sglang.srt.managers.io_struct import (
     FreezeGCReq,
     GetInternalStateReq,
     GetInternalStateReqOutput,
+    GetLoadReqInput,
+    GetLoadReqOutput,
     GetWeightsByNameReqInput,
     HealthCheckOutput,
+    InitWeightsSendGroupForRemoteInstanceReqInput,
+    InitWeightsSendGroupForRemoteInstanceReqOutput,
     InitWeightsUpdateGroupReqInput,
     LoadLoRAAdapterReqInput,
     LoadLoRAAdapterReqOutput,
     MultiTokenizerRegisterReq,
-    MultiTokenizerWarpper,
+    MultiTokenizerWrapper,
     OpenSessionReqInput,
     OpenSessionReqOutput,
     ProfileReq,
@@ -93,6 +97,8 @@ from sglang.srt.managers.io_struct import (
     ResumeMemoryOccupationReqInput,
     RpcReqInput,
     RpcReqOutput,
+    SendWeightsToRemoteInstanceReqInput,
+    SendWeightsToRemoteInstanceReqOutput,
     SetInternalStateReq,
     SetInternalStateReqOutput,
     SlowDownReqInput,
@@ -141,10 +147,19 @@ from sglang.srt.mem_cache.lora_radix_cache import LoRARadixCache
 from sglang.srt.mem_cache.radix_cache import RadixCache
 from sglang.srt.mem_cache.swa_radix_cache import SWARadixCache
 from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
-from sglang.srt.reasoning_parser import ReasoningParser
+from sglang.srt.parser.reasoning_parser import ReasoningParser
 from sglang.srt.server_args import PortArgs, ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
+from sglang.srt.tracing.trace import (
+    process_tracing_init,
+    trace_event,
+    trace_set_proc_propagate_context,
+    trace_set_thread_info,
+    trace_slice,
+    trace_slice_end,
+    trace_slice_start,
+)
 from sglang.srt.two_batch_overlap import TboDPAttentionPreparer
 from sglang.srt.utils import (
     DynamicGradMode,
@@ -158,6 +173,7 @@ from sglang.srt.utils import (
     get_zmq_socket,
     is_cpu,
     kill_itself_when_parent_died,
+    numa_bind_to_node,
     point_to_point_pyobj,
     pyspy_dump_schedulers,
     require_mlp_sync,
@@ -348,6 +364,18 @@ class Scheduler(
                 target_worker=self.tp_worker,
                 dp_rank=dp_rank,
             )
+        elif self.spec_algorithm.is_standalone():
+            from sglang.srt.speculative.standalone_worker import StandaloneWorker
+
+            self.draft_worker = StandaloneWorker(
+                gpu_id=gpu_id,
+                tp_rank=tp_rank,
+                moe_ep_rank=moe_ep_rank,
+                server_args=server_args,
+                nccl_port=port_args.nccl_port,
+                target_worker=self.tp_worker,
+                dp_rank=dp_rank,
+            )
         else:
             self.draft_worker = None
 
@@ -401,7 +429,7 @@
             f"max_prefill_tokens={self.max_prefill_tokens}, "
             f"max_running_requests={self.max_running_requests}, "
             f"context_len={self.model_config.context_len}, "
-            f"available_gpu_mem={avail_mem:.2f} GB"
+            f"{'available_cpu_mem' if self.device == 'cpu' else 'available_gpu_mem'}={avail_mem:.2f} GB"
         )
 
         # Init memory pool and cache
@@ -488,7 +516,7 @@
             enable=server_args.enable_memory_saver
         )
         self.offload_tags = set()
-        self.init_profier()
+        self.init_profiler()
 
         self.recv_skipper = SchedulerRecvSkipper.maybe_create(server_args)
         self.input_blocker = (
@@ -500,6 +528,7 @@
         # Init metrics stats
         self.init_metrics(tp_rank, pp_rank, dp_rank)
         self.init_kv_events(server_args.kv_events_config)
+        self.init_dp_balance(dp_balance_meta)
 
         # Init disaggregation
         self.disaggregation_mode = DisaggregationMode(
@@ -524,6 +553,14 @@
                 (CloseSessionReqInput, self.close_session),
                 (UpdateWeightFromDiskReqInput, self.update_weights_from_disk),
                 (InitWeightsUpdateGroupReqInput, self.init_weights_update_group),
+                (
+                    InitWeightsSendGroupForRemoteInstanceReqInput,
+                    self.init_weights_send_group_for_remote_instance,
+                ),
+                (
+                    SendWeightsToRemoteInstanceReqInput,
+                    self.send_weights_to_remote_instance,
+                ),
                 (
                     UpdateWeightsFromDistributedReqInput,
                     self.update_weights_from_distributed,
@@ -542,18 +579,10 @@
                 (LoadLoRAAdapterReqInput, self.load_lora_adapter),
                 (UnloadLoRAAdapterReqInput, self.unload_lora_adapter),
                 (MultiTokenizerRegisterReq, self.register_multi_tokenizer),
+                (GetLoadReqInput, self.get_load),
             ]
         )
 
-        self.balance_meta = dp_balance_meta
-        if (
-            server_args.enable_dp_attention
-            and server_args.load_balance_method == "minimum_tokens"
-        ):
-            assert dp_balance_meta is not None
-
-            self.recv_dp_balance_id_this_term = []
-
     def init_tokenizer(self):
         server_args = self.server_args
         self.is_generation = self.model_config.is_generation
@@ -630,6 +659,7 @@
                 hicache_write_policy=server_args.hicache_write_policy,
                 hicache_io_backend=server_args.hicache_io_backend,
                 hicache_mem_layout=server_args.hicache_mem_layout,
+                enable_metrics=self.enable_metrics,
                 hicache_storage_backend=server_args.hicache_storage_backend,
                 hicache_storage_prefetch_policy=server_args.hicache_storage_prefetch_policy,
                 model_name=server_args.served_model_name,
@@ -662,6 +692,21 @@
                 page_size=self.page_size,
                 disable=server_args.disable_radix_cache,
             )
+        elif server_args.enable_lmcache:
+            from sglang.srt.mem_cache.storage.lmcache.lmc_radix_cache import (
+                LMCRadixCache,
+            )
+
+            self.tree_cache = LMCRadixCache(
+                req_to_token_pool=self.req_to_token_pool,
+                token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
+                page_size=self.page_size,
+                disable=server_args.disable_radix_cache,
+                model_config=self.model_config,
+                tp_size=self.tp_size,
+                rank=self.tp_rank,
+                tp_group=self.tp_group,
+            )
         else:
             self.tree_cache = RadixCache(
                 req_to_token_pool=self.req_to_token_pool,
@@ -793,6 +838,10 @@
             batch = self.get_next_batch_to_run()
             self.cur_batch = batch
 
+            if batch:
+                for req in batch.reqs:
+                    trace_event("schedule", req.rid)
+
             if batch:
                 result = self.run_batch(batch)
                 self.process_batch_result(batch, result)
@@ -814,6 +863,10 @@
             batch = self.get_next_batch_to_run()
             self.cur_batch = batch
 
+            if batch:
+                for req in batch.reqs:
+                    trace_event("schedule", req.rid)
+
             if batch:
                 batch.launch_done = threading.Event()
                 result = self.run_batch(batch)
@@ -1077,6 +1130,12 @@
                 self.tp_cpu_group,
                 src=self.tp_group.ranks[0],
             )
+
+        for req in recv_reqs:
+            if isinstance(req, (TokenizedGenerateReqInput, TokenizedEmbeddingReqInput)):
+                trace_set_proc_propagate_context(req.rid, req.trace_context)
+                trace_slice_start("", req.rid, anonymous=True)
+
         return recv_reqs
 
     def process_input_requests(self, recv_reqs: List):
@@ -1104,13 +1163,13 @@
                 self.send_to_tokenizer.send_pyobj(abort_req)
                 continue
 
-            # If it is a MultiTokenizerWarpper, unwrap it and handle the inner request.
-            if isinstance(recv_req, MultiTokenizerWarpper):
+            # If it is a MultiTokenizerWrapper, unwrap it and handle the inner request.
+            if isinstance(recv_req, MultiTokenizerWrapper):
                 worker_id = recv_req.worker_id
                 recv_req = recv_req.obj
                 output = self._request_dispatcher(recv_req)
                 if output is not None:
-                    output = MultiTokenizerWarpper(worker_id, output)
+                    output = MultiTokenizerWrapper(worker_id, output)
                     self.send_to_tokenizer.send_pyobj(output)
                 continue
 
@@ -1122,15 +1181,21 @@
             else:
                 self.send_to_tokenizer.send_pyobj(output)
 
+    def init_req_max_new_tokens(self, req):
+        req.sampling_params.max_new_tokens = min(
+            (
+                req.sampling_params.max_new_tokens
+                if req.sampling_params.max_new_tokens is not None
+                else 1 << 30
+            ),
+            self.max_req_len - len(req.origin_input_ids) - 1,
+        )
+
     def handle_generate_request(
         self,
         recv_req: TokenizedGenerateReqInput,
     ):
-        if (
-            self.server_args.enable_dp_attention
-            and self.server_args.load_balance_method == "minimum_tokens"
-        ):
-            self.recv_dp_balance_id_this_term.append(recv_req.dp_balance_id)
+        self.maybe_update_dp_balance_data(recv_req)
 
         # Create a new request
         if (
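Note on the init_req_max_new_tokens helper introduced in the hunk above: the clamp that previously sat inline in handle_generate_request (removed further down, in the -1232,26 hunk) is now a method, so every early-return path can bound max_new_tokens before the request is enqueued. A minimal standalone sketch of the clamping rule, assuming only that max_req_len counts prompt plus output tokens; the helper name and bare ints below are illustrative, not sglang API:

    from typing import Optional

    def clamp_max_new_tokens(
        requested: Optional[int], prompt_len: int, max_req_len: int
    ) -> int:
        # An unset limit is treated as effectively unbounded (1 << 30), then
        # capped so prompt_len + max_new_tokens stays below max_req_len.
        effective = requested if requested is not None else 1 << 30
        return min(effective, max_req_len - prompt_len - 1)

    # e.g. max_req_len=4096 with a 1000-token prompt: a request for 8192 new
    # tokens is clamped to min(8192, 4096 - 1000 - 1) == 3095.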
@@ -1189,6 +1254,7 @@ class Scheduler(
                 req.set_finish_with_abort(
                     f"Invalid request: session id {recv_req.session_params.id} does not exist"
                 )
+                self.init_req_max_new_tokens(req)
                 self._add_request_to_queue(req)
                 return
             else:
@@ -1196,6 +1262,7 @@
             session = self.sessions[recv_req.session_params.id]
             req = session.create_req(recv_req, self.tokenizer)
             if isinstance(req.finished_reason, FINISH_ABORT):
+                self.init_req_max_new_tokens(req)
                 self._add_request_to_queue(req)
                 return
 
@@ -1215,9 +1282,13 @@
                     f"After expanding {len(req.origin_input_ids_unpadded)=} => {len(req.origin_input_ids)} >= {self.max_req_input_len}."
                 )
             )
+            self.init_req_max_new_tokens(req)
             self._add_request_to_queue(req)
             return
 
+        # initialize before returning
+        self.init_req_max_new_tokens(req)
+
         # Validate prompt length
         error_msg = validate_input_length(
             req,
@@ -1232,26 +1303,25 @@
         # Copy more attributes
         if recv_req.logprob_start_len == -1 or not recv_req.return_logprob:
             # By default, only return the logprobs for output tokens
-            req.logprob_start_len = len(req.origin_input_ids) - 1
+            # For prefill-only requests with logprob_start_len == -1, set logprob_start_len beyond input sequence
+            # to skip input logprob computation entirely
+            if req.is_prefill_only:
+                req.logprob_start_len = len(req.origin_input_ids)
+            else:
+                # TODO: For text generation, evaluate setting logprob_start_len to len(req.origin_input_ids) as well
+                req.logprob_start_len = len(req.origin_input_ids) - 1
         else:
             req.logprob_start_len = recv_req.logprob_start_len
 
-        if req.logprob_start_len >= len(req.origin_input_ids):
+        if not req.is_prefill_only and req.logprob_start_len >= len(
+            req.origin_input_ids
+        ):
             error_msg = f"{req.logprob_start_len=} is higher than the number of input tokens {len(req.origin_input_ids)=}. Please use a smaller logprob_start_len."
             req.logprob_start_len = len(req.origin_input_ids) - 1
             req.set_finish_with_abort(error_msg)
             self._add_request_to_queue(req)
             return
 
-        req.sampling_params.max_new_tokens = min(
-            (
-                req.sampling_params.max_new_tokens
-                if req.sampling_params.max_new_tokens is not None
-                else 1 << 30
-            ),
-            self.max_req_len - len(req.origin_input_ids) - 1,
-        )
-
         # Init grammar cache for this request
         add_to_grammar_queue = False
         if (
@@ -1310,6 +1380,7 @@ class Scheduler(
         else:
             self._prefetch_kvcache(req)
             self.waiting_queue.append(req)
+            trace_slice_end("process req", req.rid, auto_next_anon=True)
 
     def _prefetch_kvcache(self, req: Req):
         if self.enable_hicache_storage:
@@ -1421,9 +1492,11 @@
         _, _, available_size, evictable_size = self._get_token_info()
         protected_size = self.tree_cache.protected_size()
         memory_leak = (available_size + evictable_size) != (
+            # self.max_total_num_tokens
+            # if not self.enable_hierarchical_cache
+            # else self.max_total_num_tokens - protected_size
             self.max_total_num_tokens
-            if not self.enable_hierarchical_cache
-            else self.max_total_num_tokens - protected_size
+            - protected_size
         )
         token_msg = f"{self.max_total_num_tokens=}, {available_size=}, {evictable_size=}, {protected_size=}\n"
 
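The reworked leak check above always subtracts protected_size (the old enable_hierarchical_cache branch survives only as comments), so the invariant asserted when the scheduler is idle is available_size + evictable_size == max_total_num_tokens - protected_size. A toy instance with made-up pool sizes:

    # Illustrative values only; the real numbers come from the KV allocator
    # and the radix tree cache.
    max_total_num_tokens = 1000
    protected_size = 50   # tokens pinned, e.g. by in-flight hierarchical-cache IO
    available_size = 900
    evictable_size = 50
    memory_leak = (available_size + evictable_size) != (
        max_total_num_tokens - protected_size
    )
    assert not memory_leak  # 950 == 950: every token is accounted for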
@@ -1474,6 +1547,20 @@
             self.stats.gen_throughput = 0
             self.stats.num_queue_reqs = len(self.waiting_queue)
             self.stats.num_grammar_queue_reqs = len(self.grammar_queue)
+            if self.disaggregation_mode == DisaggregationMode.PREFILL:
+                self.stats.num_prefill_prealloc_queue_reqs = len(
+                    self.disagg_prefill_bootstrap_queue.queue
+                )
+                self.stats.num_prefill_inflight_queue_reqs = len(
+                    self.disagg_prefill_inflight_queue
+                )
+            if self.disaggregation_mode == DisaggregationMode.DECODE:
+                self.stats.num_decode_prealloc_queue_reqs = len(
+                    self.disagg_decode_prealloc_queue.queue
+                )
+                self.stats.num_decode_transfer_queue_reqs = len(
+                    self.disagg_decode_transfer_queue.queue
+                )
             self.metrics_collector.log_stats(self.stats)
         self._publish_kv_events()
 
@@ -1521,7 +1608,12 @@
             chunked_req_to_exclude.add(self.chunked_req)
             self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True)
             # chunked request keeps its rid but will get a new req_pool_idx
-            self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
+            if self.tp_worker.worker.model_runner.is_hybrid_gdn:
+                self.req_to_token_pool.free(
+                    self.chunked_req.req_pool_idx, free_mamba_cache=False
+                )
+            else:
+                self.req_to_token_pool.free(self.chunked_req.req_pool_idx)
         if self.last_batch and self.last_batch.forward_mode.is_extend():
             if self.last_batch.chunked_req is not None:
                 # In the context pipeline parallelism, after the last chunk, the current microbatch still track outdated chunked_req.
@@ -1568,11 +1660,7 @@
 
         # Handle DP attention
         if need_dp_attn_preparation:
-            if (
-                self.server_args.load_balance_method == "minimum_tokens"
-                and self.forward_ct % 40 == 0
-            ):
-                self.handle_dp_balance_data(ret)
+            self.maybe_handle_dp_balance_data()
             ret = self.prepare_mlp_sync_batch(ret)
 
         return ret
@@ -1792,10 +1880,6 @@
         if self.spec_algorithm.is_none():
            model_worker_batch = batch.get_model_worker_batch()
 
-            # update the consumer index of hicache to the running batch
-            self.tp_worker.set_hicache_consumer(
-                model_worker_batch.hicache_consumer_index
-            )
             if self.pp_group.is_last_rank:
                 logits_output, next_token_ids, can_run_cuda_graph = (
                     self.tp_worker.forward_batch_generation(model_worker_batch)
@@ -1864,8 +1948,23 @@
         ):
             if batch.forward_mode.is_decode():
                 self.process_batch_result_decode(batch, result, launch_done)
+                for req in batch.reqs:
+                    trace_slice(
+                        "decode loop",
+                        req.rid,
+                        auto_next_anon=not req.finished(),
+                        thread_finish_flag=req.finished(),
+                    )
+
             elif batch.forward_mode.is_extend():
                 self.process_batch_result_prefill(batch, result, launch_done)
+                for req in batch.reqs:
+                    trace_slice(
+                        "prefill",
+                        req.rid,
+                        auto_next_anon=not req.finished(),
+                        thread_finish_flag=req.finished(),
+                    )
             elif batch.forward_mode.is_idle():
                 if self.enable_overlap:
                     self.tp_worker.resolve_last_batch_result(launch_done)
@@ -1897,86 +1996,6 @@ class Scheduler(
             disable_overlap_schedule=self.server_args.disable_overlap_schedule,
         )
 
-    def handle_dp_balance_data(self, local_batch: ScheduleBatch):
-        def gather_dp_balance_info(holding_tokens_list) -> Union[None, List[List[int]]]:
-            """gather recv_dp_balance_id_this_term and holding tokens per worker for dp balance"""
-            recv_list = self.recv_dp_balance_id_this_term
-            assert len(recv_list) <= 511, (
-                "The number of requests received this round is too large. "
-                "Please increase gather_tensor_size and onfly_info_size."
-            )
-            # The maximum size of the tensor used for gathering data from all workers.
-            gather_tensor_size = 512
-
-            # recv_tensor: | holding_tokens | len(recv_dp_balance_id) | recv_dp_balance_ids
-            recv_tensor = torch.zeros(gather_tensor_size, dtype=torch.int32)
-            recv_tensor[0] = holding_tokens_list
-            recv_tensor[1] = len(
-                recv_list
-            )  # The first element is the length of the list.
-            recv_tensor[2 : len(recv_list) + 2] = torch.tensor(
-                recv_list, dtype=torch.int32
-            )
-
-            if self.tp_rank == 0:
-                gathered_list = [
-                    torch.zeros(gather_tensor_size, dtype=torch.int32)
-                    for _ in range(self.balance_meta.num_workers)
-                ]
-            else:
-                gathered_list = None
-
-            torch.distributed.gather(
-                recv_tensor, gathered_list, group=self.tp_cpu_group
-            )
-
-            gathered_id_list_per_worker = None
-            if self.tp_rank == 0:
-                gathered_id_list_per_worker = []
-                holding_tokens_list = []
-                for tensor in gathered_list:
-                    holding_tokens_list.append(tensor[0].item())
-                    list_length = tensor[1].item()
-                    gathered_id_list_per_worker.append(
-                        tensor[2 : list_length + 2].tolist()
-                    )
-
-            return gathered_id_list_per_worker, holding_tokens_list
-
-        def write_shared_dp_balance_info(new_recv_rid_lists, local_tokens):
-            meta = self.balance_meta
-
-            with meta.mutex:
-                onfly_list: List[Dict[int, int]] = meta.get_shared_onfly()
-                assert len(new_recv_rid_lists) == len(
-                    onfly_list
-                ), "num_worker not equal"
-                # 1.Check if the rid received by each worker this round is present in onfly.
-                # If it is, remove the corresponding onfly item.
-                worker_id = 0
-                for new_recv_rids, on_fly_reqs in zip(new_recv_rid_lists, onfly_list):
-                    for new_recv_rid in new_recv_rids:
-                        assert (
-                            new_recv_rid in on_fly_reqs
-                        ), f"{new_recv_rid=} not in {worker_id=} {on_fly_reqs=}, data consistency is wrong"
-                        del on_fly_reqs[new_recv_rid]
-                    worker_id += 1
-                # 2. Atomically write local_tokens and onfly into shm under the mutex
-                meta.set_shared_onfly_info(onfly_list)
-                meta.set_shared_local_tokens(local_tokens)
-
-        holding_tokens = self.get_load()
-
-        new_recv_dp_balance_id_list, holding_token_list = gather_dp_balance_info(
-            holding_tokens
-        )
-
-        self.recv_dp_balance_id_this_term.clear()
-        if self.tp_rank == 0:  # only first worker write info
-            write_shared_dp_balance_info(
-                new_recv_dp_balance_id_list, holding_token_list
-            )
-
     @staticmethod
     def prepare_mlp_sync_batch_raw(
         local_batch: ScheduleBatch,
@@ -2270,39 +2289,50 @@
             if_success = False
         return if_success
 
-    def get_load(self):
+    def get_load(self, recv_req: GetLoadReqInput = None) -> GetLoadReqOutput:
         # TODO(lsyin): use dynamically maintained num_waiting_tokens
+
         if self.is_hybrid:
-            load_full = (
+            num_tokens_full = (
                 self.full_tokens_per_layer
                 - self.token_to_kv_pool_allocator.full_available_size()
                 - self.tree_cache.full_evictable_size()
             )
-            load_swa = (
+            num_tokens_swa = (
                 self.swa_tokens_per_layer
                 - self.token_to_kv_pool_allocator.swa_available_size()
                 - self.tree_cache.swa_evictable_size()
             )
-            load = max(load_full, load_swa)
+            num_tokens = max(num_tokens_full, num_tokens_swa)
         else:
-            load = (
+            num_tokens = (
                 self.max_total_num_tokens
                 - self.token_to_kv_pool_allocator.available_size()
                 - self.tree_cache.evictable_size()
            )
-        load += sum(len(req.origin_input_ids) for req in self.waiting_queue)
+
+        # Tokens in waiting queue, bootstrap queue, prealloc queue
+        num_tokens += sum(len(req.origin_input_ids) for req in self.waiting_queue)
+        num_waiting_reqs = len(self.waiting_queue)
         if self.disaggregation_mode == DisaggregationMode.PREFILL:
-            load += sum(
+            num_tokens += sum(
                 len(req.origin_input_ids)
                 for req in self.disagg_prefill_bootstrap_queue.queue
            )
+            num_waiting_reqs += len(self.disagg_prefill_bootstrap_queue.queue)
         elif self.disaggregation_mode == DisaggregationMode.DECODE:
-            load += sum(
+            num_tokens += sum(
                 len(req.req.origin_input_ids)
                 for req in self.disagg_decode_prealloc_queue.queue
            )
+            num_waiting_reqs += len(self.disagg_decode_prealloc_queue.queue)
 
-        return load
+        return GetLoadReqOutput(
+            dp_rank=self.dp_rank,
+            num_reqs=len(self.running_batch.reqs) + num_waiting_reqs,
+            num_waiting_reqs=num_waiting_reqs,
+            num_tokens=num_tokens,
+        )
 
     def get_internal_state(self, recv_req: GetInternalStateReq):
         ret = dict(global_server_args_dict)
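With the hunk above, get_load stops returning a bare token count and instead answers the new GetLoadReqInput RPC with a structured GetLoadReqOutput (dp_rank, num_reqs, num_waiting_reqs, num_tokens); note the matching dispatcher entry added earlier and the removal of ret["load"] from get_internal_state below. A sketch of a plausible consumer that routes to the least-loaded data-parallel rank: the dataclass mirrors the fields above, but pick_least_loaded and its tie-break policy are hypothetical (the shipped logic lives in sglang/srt/managers/data_parallel_controller.py, which this release also reworks):

    from dataclasses import dataclass
    from typing import List

    @dataclass
    class GetLoadReqOutput:  # field names copied from the hunk above
        dp_rank: int
        num_reqs: int
        num_waiting_reqs: int
        num_tokens: int

    def pick_least_loaded(loads: List[GetLoadReqOutput]) -> int:
        # Hypothetical policy: fewest in-flight tokens wins; fewer requests
        # breaks ties.
        return min(loads, key=lambda l: (l.num_tokens, l.num_reqs)).dp_rank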
@@ -2317,10 +2347,9 @@
             "token_capacity": int(self.max_total_num_tokens),
         }
 
-        if not _is_cpu:
-            ret["memory_usage"]["cuda_graph"] = round(
-                self.tp_worker.worker.model_runner.cuda_graph_mem_usage, 2
-            )
+        ret["memory_usage"]["graph"] = round(
+            self.tp_worker.worker.model_runner.graph_mem_usage, 2
+        )
 
         if not self.spec_algorithm.is_none() and self.cum_spec_accept_count > 0:
             ret["avg_spec_accept_length"] = (
@@ -2329,8 +2358,6 @@
             )
         if RECORD_STEP_TIME:
             ret["step_time_dict"] = self.step_time_dict
-        ret["load"] = self.get_load()
-
         return GetInternalStateReqOutput(internal_state=ret)
 
     def set_internal_state(self, recv_req: SetInternalStateReq):
@@ -2494,6 +2521,22 @@
             self.send_to_detokenizer.send_pyobj(recv_req)
         return recv_req
 
+    def init_weights_send_group_for_remote_instance(
+        self, recv_req: InitWeightsSendGroupForRemoteInstanceReqInput
+    ):
+        """Init the seed and client instance communication group."""
+        success, message = self.tp_worker.init_weights_send_group_for_remote_instance(
+            recv_req
+        )
+        return InitWeightsSendGroupForRemoteInstanceReqOutput(success, message)
+
+    def send_weights_to_remote_instance(
+        self, recv_req: SendWeightsToRemoteInstanceReqInput
+    ):
+        """Send the seed instance weights to the destination instance."""
+        success, message = self.tp_worker.send_weights_to_remote_instance(recv_req)
+        return SendWeightsToRemoteInstanceReqOutput(success, message)
+
     def slow_down(self, recv_req: SlowDownReqInput):
         t = recv_req.forward_sleep_time
         if t is not None and t <= 0:
@@ -2615,6 +2658,15 @@
     pipe_writer,
     balance_meta: Optional[DPBalanceMeta] = None,
 ):
+    if server_args.enable_trace:
+        process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
+        if server_args.disaggregation_mode == "null":
+            thread_label = "Scheduler"
+            trace_set_thread_info(thread_label, tp_rank, dp_rank)
+
+    if (numa_node := server_args.numa_node) is not None:
+        numa_bind_to_node(numa_node[gpu_id])
+
     # Generate the prefix
     prefix = ""
     if dp_rank is not None:
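Taken together, the tracing hooks added throughout this file bracket a request's life inside the scheduler: a slice opens when the tokenized request arrives in recv_requests, a "schedule" event fires each time the request is picked into a batch, and per-phase slices close as prefill and decode results are processed, with thread_finish_flag raised once the request finishes. Below is a call-order sketch for a single request, using a made-up rid and collector endpoint; the imports and argument shapes match the call sites above, but this illustrates the sequence only and is not a supported standalone API (real runs are wired up via server_args.enable_trace and the per-process trace_set_thread_info call):

    from sglang.srt.tracing.trace import (
        process_tracing_init,
        trace_event,
        trace_slice,
        trace_slice_end,
        trace_slice_start,
    )

    process_tracing_init("http://localhost:4317", "sglang")  # endpoint is illustrative
    trace_slice_start("", "rid-1", anonymous=True)           # recv_requests: request arrives
    trace_slice_end("process req", "rid-1", auto_next_anon=True)  # queued for scheduling
    trace_event("schedule", "rid-1")                         # event loop: placed in a batch
    trace_slice("prefill", "rid-1", auto_next_anon=True, thread_finish_flag=False)
    trace_slice("decode loop", "rid-1", auto_next_anon=False, thread_finish_flag=True)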