sglang 0.5.3rc0__py3-none-any.whl → 0.5.3rc2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (282)
  1. sglang/bench_one_batch.py +7 -9
  2. sglang/bench_one_batch_server.py +321 -31
  3. sglang/bench_serving.py +10 -3
  4. sglang/global_config.py +2 -2
  5. sglang/lang/backend/runtime_endpoint.py +1 -1
  6. sglang/launch_server.py +14 -0
  7. sglang/profiler.py +2 -2
  8. sglang/srt/batch_invariant_ops/__init__.py +27 -0
  9. sglang/srt/batch_invariant_ops/batch_invariant_ops.py +549 -0
  10. sglang/srt/configs/__init__.py +4 -0
  11. sglang/srt/configs/dots_ocr.py +64 -0
  12. sglang/srt/configs/falcon_h1.py +360 -0
  13. sglang/srt/configs/load_config.py +8 -0
  14. sglang/srt/configs/model_config.py +160 -105
  15. sglang/srt/configs/qwen3_vl.py +586 -0
  16. sglang/srt/constrained/base_grammar_backend.py +1 -0
  17. sglang/srt/constrained/outlines_jump_forward.py +1 -1
  18. sglang/srt/constrained/xgrammar_backend.py +6 -4
  19. sglang/srt/debug_utils/dumper.py +10 -3
  20. sglang/srt/disaggregation/ascend/conn.py +2 -2
  21. sglang/srt/disaggregation/ascend/transfer_engine.py +47 -9
  22. sglang/srt/disaggregation/common/conn.py +266 -98
  23. sglang/srt/disaggregation/decode.py +50 -9
  24. sglang/srt/disaggregation/decode_kvcache_offload_manager.py +185 -0
  25. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +25 -16
  26. sglang/srt/disaggregation/mooncake/conn.py +51 -541
  27. sglang/srt/disaggregation/nixl/conn.py +148 -39
  28. sglang/srt/disaggregation/prefill.py +31 -14
  29. sglang/srt/disaggregation/utils.py +36 -5
  30. sglang/srt/distributed/device_communicators/all_reduce_utils.py +16 -0
  31. sglang/srt/distributed/device_communicators/shm_broadcast.py +4 -2
  32. sglang/srt/distributed/device_communicators/symm_mem.py +164 -0
  33. sglang/srt/distributed/parallel_state.py +135 -80
  34. sglang/srt/entrypoints/engine.py +23 -3
  35. sglang/srt/entrypoints/grpc_request_manager.py +330 -55
  36. sglang/srt/entrypoints/grpc_server.py +232 -102
  37. sglang/srt/entrypoints/http_server.py +49 -9
  38. sglang/srt/entrypoints/openai/protocol.py +110 -5
  39. sglang/srt/entrypoints/openai/serving_base.py +25 -6
  40. sglang/srt/entrypoints/openai/serving_chat.py +178 -49
  41. sglang/srt/entrypoints/openai/serving_completions.py +5 -3
  42. sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
  43. sglang/srt/entrypoints/openai/serving_responses.py +42 -0
  44. sglang/srt/environ.py +285 -0
  45. sglang/srt/eplb/expert_location.py +30 -5
  46. sglang/srt/function_call/function_call_parser.py +3 -2
  47. sglang/srt/function_call/glm4_moe_detector.py +3 -3
  48. sglang/srt/function_call/gpt_oss_detector.py +23 -0
  49. sglang/srt/function_call/json_array_parser.py +63 -0
  50. sglang/srt/function_call/kimik2_detector.py +17 -4
  51. sglang/srt/function_call/utils.py +96 -5
  52. sglang/srt/grpc/compile_proto.py +245 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +73 -68
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +60 -53
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +3 -0
  56. sglang/srt/layers/activation.py +7 -6
  57. sglang/srt/layers/attention/aiter_backend.py +14 -15
  58. sglang/srt/layers/attention/ascend_backend.py +108 -9
  59. sglang/srt/layers/attention/attention_registry.py +206 -0
  60. sglang/srt/layers/attention/base_attn_backend.py +12 -3
  61. sglang/srt/layers/attention/cutlass_mla_backend.py +3 -3
  62. sglang/srt/layers/attention/dual_chunk_flashattention_backend.py +1 -1
  63. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +2 -2
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +4 -4
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +2 -2
  66. sglang/srt/layers/attention/flashattention_backend.py +41 -8
  67. sglang/srt/layers/attention/flashinfer_backend.py +112 -194
  68. sglang/srt/layers/attention/flashinfer_mla_backend.py +11 -15
  69. sglang/srt/layers/attention/flashmla_backend.py +7 -5
  70. sglang/srt/layers/attention/hybrid_attn_backend.py +11 -3
  71. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +72 -72
  72. sglang/srt/layers/attention/mamba/causal_conv1d.py +1 -0
  73. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +15 -98
  74. sglang/srt/layers/attention/mamba/mamba.py +566 -1
  75. sglang/srt/layers/attention/mamba/mamba_utils.py +81 -0
  76. sglang/srt/layers/attention/mamba/ops/__init__.py +2 -0
  77. sglang/srt/layers/attention/mamba/ops/layernorm_gated.py +172 -0
  78. sglang/srt/layers/attention/mamba/ops/mamba_ssm.py +442 -0
  79. sglang/srt/layers/attention/mamba/ops/ssd_bmm.py +264 -0
  80. sglang/srt/layers/attention/mamba/ops/ssd_chunk_scan.py +622 -0
  81. sglang/srt/layers/attention/mamba/ops/ssd_chunk_state.py +757 -0
  82. sglang/srt/layers/attention/mamba/ops/ssd_combined.py +262 -0
  83. sglang/srt/layers/attention/mamba/ops/ssd_state_passing.py +275 -0
  84. sglang/srt/layers/attention/npu_ops/mla_preprocess.py +393 -0
  85. sglang/srt/layers/attention/nsa/dequant_k_cache.py +163 -0
  86. sglang/srt/layers/attention/nsa/index_buf_accessor.py +354 -0
  87. sglang/srt/layers/attention/nsa/nsa_indexer.py +761 -0
  88. sglang/srt/layers/attention/nsa/quant_k_cache.py +255 -0
  89. sglang/srt/layers/attention/nsa/tilelang_kernel.py +785 -0
  90. sglang/srt/layers/attention/nsa/transform_index.py +144 -0
  91. sglang/srt/layers/attention/nsa/utils.py +24 -0
  92. sglang/srt/layers/attention/nsa_backend.py +887 -0
  93. sglang/srt/layers/attention/tbo_backend.py +6 -6
  94. sglang/srt/layers/attention/torch_flex_backend.py +325 -0
  95. sglang/srt/layers/attention/triton_backend.py +42 -9
  96. sglang/srt/layers/attention/trtllm_mha_backend.py +5 -7
  97. sglang/srt/layers/attention/trtllm_mla_backend.py +178 -34
  98. sglang/srt/layers/attention/vision.py +58 -0
  99. sglang/srt/layers/attention/wave_backend.py +4 -4
  100. sglang/srt/layers/communicator.py +8 -0
  101. sglang/srt/layers/dp_attention.py +11 -1
  102. sglang/srt/layers/elementwise.py +3 -1
  103. sglang/srt/layers/layernorm.py +2 -0
  104. sglang/srt/layers/linear.py +21 -4
  105. sglang/srt/layers/logits_processor.py +15 -2
  106. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  107. sglang/srt/layers/moe/ep_moe/layer.py +147 -74
  108. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +52 -25
  109. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=256,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  110. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H800,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  111. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_B200.json +146 -0
  112. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +6 -2
  113. sglang/srt/layers/moe/fused_moe_triton/layer.py +11 -12
  114. sglang/srt/layers/moe/token_dispatcher/deepep.py +77 -19
  115. sglang/srt/layers/moe/utils.py +10 -0
  116. sglang/srt/layers/parameter.py +23 -6
  117. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors.py +1 -0
  118. sglang/srt/layers/quantization/compressed_tensors/schemes/__init__.py +2 -0
  119. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_int8.py +173 -0
  120. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -10
  121. sglang/srt/layers/quantization/fp8.py +2 -2
  122. sglang/srt/layers/quantization/fp8_utils.py +1 -1
  123. sglang/srt/layers/quantization/modelopt_quant.py +44 -9
  124. sglang/srt/layers/quantization/mxfp4.py +12 -4
  125. sglang/srt/layers/quantization/quark/quark_moe.py +16 -3
  126. sglang/srt/layers/quantization/w4afp8.py +0 -4
  127. sglang/srt/layers/quantization/w8a8_int8.py +15 -3
  128. sglang/srt/layers/rotary_embedding.py +78 -31
  129. sglang/srt/layers/sampler.py +52 -4
  130. sglang/srt/layers/utils.py +23 -0
  131. sglang/srt/lora/backend/base_backend.py +3 -3
  132. sglang/srt/lora/backend/chunked_backend.py +348 -0
  133. sglang/srt/lora/backend/triton_backend.py +10 -4
  134. sglang/srt/lora/lora.py +7 -5
  135. sglang/srt/lora/lora_manager.py +17 -6
  136. sglang/srt/lora/mem_pool.py +1 -1
  137. sglang/srt/lora/triton_ops/__init__.py +4 -0
  138. sglang/srt/lora/triton_ops/chunked_sgmv_expand.py +214 -0
  139. sglang/srt/lora/triton_ops/chunked_sgmv_shrink.py +174 -0
  140. sglang/srt/lora/utils.py +7 -5
  141. sglang/srt/managers/cache_controller.py +42 -142
  142. sglang/srt/managers/data_parallel_controller.py +11 -46
  143. sglang/srt/managers/detokenizer_manager.py +11 -11
  144. sglang/srt/managers/io_struct.py +162 -118
  145. sglang/srt/managers/mm_utils.py +43 -6
  146. sglang/srt/managers/multi_tokenizer_mixin.py +17 -17
  147. sglang/srt/managers/multimodal_processor.py +1 -2
  148. sglang/srt/managers/overlap_utils.py +53 -0
  149. sglang/srt/managers/schedule_batch.py +167 -86
  150. sglang/srt/managers/schedule_policy.py +143 -16
  151. sglang/srt/managers/scheduler.py +359 -214
  152. sglang/srt/managers/scheduler_input_blocker.py +1 -1
  153. sglang/srt/managers/scheduler_metrics_mixin.py +98 -126
  154. sglang/srt/managers/scheduler_output_processor_mixin.py +21 -12
  155. sglang/srt/managers/scheduler_profiler_mixin.py +5 -5
  156. sglang/srt/managers/scheduler_update_weights_mixin.py +7 -0
  157. sglang/srt/managers/tokenizer_communicator_mixin.py +111 -5
  158. sglang/srt/managers/tokenizer_manager.py +84 -136
  159. sglang/srt/managers/tp_worker.py +39 -29
  160. sglang/srt/managers/tp_worker_overlap_thread.py +33 -41
  161. sglang/srt/managers/utils.py +1 -45
  162. sglang/srt/mem_cache/allocator.py +14 -20
  163. sglang/srt/mem_cache/allocator_ascend.py +41 -27
  164. sglang/srt/mem_cache/base_prefix_cache.py +1 -1
  165. sglang/srt/mem_cache/chunk_cache.py +8 -1
  166. sglang/srt/mem_cache/evict_policy.py +23 -0
  167. sglang/srt/mem_cache/hicache_storage.py +40 -1
  168. sglang/srt/mem_cache/hiradix_cache.py +119 -32
  169. sglang/srt/mem_cache/memory_pool.py +188 -10
  170. sglang/srt/mem_cache/memory_pool_host.py +134 -182
  171. sglang/srt/mem_cache/radix_cache.py +222 -71
  172. sglang/srt/mem_cache/radix_cache_cpp.py +11 -8
  173. sglang/srt/mem_cache/storage/__init__.py +10 -0
  174. sglang/srt/mem_cache/storage/aibrix_kvcache/aibrix_kvcache_storage.py +151 -0
  175. sglang/srt/mem_cache/storage/aibrix_kvcache/unit_test.py +109 -0
  176. sglang/srt/mem_cache/storage/backend_factory.py +223 -0
  177. sglang/srt/mem_cache/storage/eic/eic_storage.py +778 -0
  178. sglang/srt/mem_cache/storage/eic/test_unit.py +115 -0
  179. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +173 -58
  180. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +10 -6
  181. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +117 -10
  182. sglang/srt/mem_cache/swa_radix_cache.py +25 -34
  183. sglang/srt/metrics/collector.py +82 -120
  184. sglang/srt/metrics/func_timer.py +2 -7
  185. sglang/srt/metrics/utils.py +8 -1
  186. sglang/srt/model_executor/cpu_graph_runner.py +2 -2
  187. sglang/srt/model_executor/cuda_graph_runner.py +39 -32
  188. sglang/srt/model_executor/forward_batch_info.py +23 -38
  189. sglang/srt/model_executor/model_runner.py +131 -183
  190. sglang/srt/model_executor/npu_graph_runner.py +12 -5
  191. sglang/srt/model_loader/loader.py +14 -10
  192. sglang/srt/model_loader/weight_utils.py +156 -2
  193. sglang/srt/models/bailing_moe.py +27 -4
  194. sglang/srt/models/deepseek_nextn.py +6 -1
  195. sglang/srt/models/deepseek_v2.py +536 -153
  196. sglang/srt/models/dots_ocr.py +173 -0
  197. sglang/srt/models/falcon_h1.py +576 -0
  198. sglang/srt/models/gemma3_causal.py +0 -2
  199. sglang/srt/models/gemma3_mm.py +1 -1
  200. sglang/srt/models/gemma3n_mm.py +1 -1
  201. sglang/srt/models/glm4_moe.py +3 -3
  202. sglang/srt/models/glm4_moe_nextn.py +2 -2
  203. sglang/srt/models/glm4v.py +1 -1
  204. sglang/srt/models/glm4v_moe.py +1 -1
  205. sglang/srt/models/gpt_oss.py +7 -30
  206. sglang/srt/models/kimi_vl_moonvit.py +2 -2
  207. sglang/srt/models/llama.py +4 -0
  208. sglang/srt/models/longcat_flash.py +1 -1
  209. sglang/srt/models/longcat_flash_nextn.py +1 -1
  210. sglang/srt/models/mllama4.py +15 -4
  211. sglang/srt/models/qwen2.py +0 -7
  212. sglang/srt/models/qwen2_5_vl.py +2 -2
  213. sglang/srt/models/qwen2_audio.py +1 -1
  214. sglang/srt/models/qwen2_moe.py +64 -1
  215. sglang/srt/models/qwen2_vl.py +1 -1
  216. sglang/srt/models/qwen3.py +18 -3
  217. sglang/srt/models/qwen3_moe.py +31 -3
  218. sglang/srt/models/qwen3_next.py +36 -9
  219. sglang/srt/models/qwen3_vl.py +787 -0
  220. sglang/srt/models/qwen3_vl_moe.py +471 -0
  221. sglang/srt/models/registry.py +15 -3
  222. sglang/srt/models/sarashina2_vision.py +269 -0
  223. sglang/srt/models/solar.py +505 -0
  224. sglang/srt/models/starcoder2.py +357 -0
  225. sglang/srt/models/torch_native_llama.py +9 -2
  226. sglang/srt/models/utils.py +51 -0
  227. sglang/srt/multimodal/processors/base_processor.py +15 -7
  228. sglang/srt/multimodal/processors/dots_vlm.py +2 -3
  229. sglang/srt/multimodal/processors/internvl.py +20 -8
  230. sglang/srt/multimodal/processors/qwen_vl.py +8 -1
  231. sglang/srt/multimodal/processors/sarashina2_vision.py +81 -0
  232. sglang/srt/parser/jinja_template_utils.py +6 -0
  233. sglang/srt/sampling/sampling_batch_info.py +20 -2
  234. sglang/srt/sampling/sampling_params.py +7 -0
  235. sglang/srt/server_args.py +753 -295
  236. sglang/srt/server_args_config_parser.py +146 -0
  237. sglang/srt/single_batch_overlap.py +151 -0
  238. sglang/srt/speculative/cpp_ngram/ngram.cpp +374 -0
  239. sglang/srt/speculative/cpp_ngram/ngram.h +110 -0
  240. sglang/srt/speculative/cpp_ngram/ngram_cache.py +138 -0
  241. sglang/srt/speculative/cpp_ngram/ngram_cache_binding.cpp +43 -0
  242. sglang/srt/speculative/cpp_ngram/param.h +125 -0
  243. sglang/srt/speculative/cpp_ngram/queue.h +71 -0
  244. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +2 -1
  245. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +3 -1
  246. sglang/srt/speculative/{eagle_utils.py → eagle_info.py} +207 -755
  247. sglang/srt/speculative/eagle_worker.py +57 -25
  248. sglang/srt/speculative/ngram_utils.py +428 -0
  249. sglang/srt/speculative/ngram_worker.py +245 -0
  250. sglang/srt/speculative/spec_info.py +47 -0
  251. sglang/srt/speculative/spec_utils.py +606 -0
  252. sglang/srt/torch_memory_saver_adapter.py +5 -7
  253. sglang/srt/tracing/trace.py +32 -6
  254. sglang/srt/two_batch_overlap.py +8 -5
  255. sglang/srt/utils/__init__.py +2 -0
  256. sglang/srt/{utils.py → utils/common.py} +399 -74
  257. sglang/srt/{hf_transformers_utils.py → utils/hf_transformers_utils.py} +49 -5
  258. sglang/srt/{patch_torch.py → utils/patch_torch.py} +8 -0
  259. sglang/srt/utils/rpd_utils.py +452 -0
  260. sglang/srt/utils/slow_rank_detector.py +71 -0
  261. sglang/srt/warmup.py +8 -4
  262. sglang/srt/weight_sync/utils.py +1 -1
  263. sglang/test/get_logits_ut.py +57 -0
  264. sglang/test/run_eval.py +79 -11
  265. sglang/test/runners.py +1 -1
  266. sglang/test/simple_eval_common.py +5 -2
  267. sglang/test/simple_eval_mmmu_vlm.py +441 -0
  268. sglang/test/test_block_fp8.py +2 -2
  269. sglang/test/test_deterministic.py +297 -0
  270. sglang/test/test_disaggregation_utils.py +12 -1
  271. sglang/test/test_programs.py +1 -1
  272. sglang/test/test_utils.py +355 -4
  273. sglang/utils.py +10 -1
  274. sglang/version.py +1 -1
  275. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/METADATA +34 -25
  276. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/RECORD +281 -210
  277. sglang/srt/mem_cache/lora_radix_cache.py +0 -421
  278. /sglang/srt/{remote_instance_weight_loader_utils.py → model_loader/remote_instance_weight_loader_utils.py} +0 -0
  279. /sglang/srt/{poll_based_barrier.py → utils/poll_based_barrier.py} +0 -0
  280. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/WHEEL +0 -0
  281. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/licenses/LICENSE +0 -0
  282. {sglang-0.5.3rc0.dist-info → sglang-0.5.3rc2.dist-info}/top_level.txt +0 -0
@@ -1,37 +1,30 @@
1
1
  from __future__ import annotations
2
2
 
3
- import asyncio
4
3
  import dataclasses
5
4
  import logging
6
- import queue
7
- import socket
5
+ import os
8
6
  import struct
9
7
  import threading
8
+ import time
10
9
  import uuid
11
10
  from collections import defaultdict
12
- from functools import cache
13
- from typing import Dict, List, Optional, Set, Tuple, TypeAlias, Union
11
+ from typing import Dict, List, Optional, Set
14
12
 
15
13
  import numpy as np
16
14
  import numpy.typing as npt
17
15
  import requests
18
- import zmq
19
- from aiohttp import web
20
16
 
21
- from sglang.srt.disaggregation.base.conn import BaseKVSender, KVArgs, KVPoll
17
+ from sglang.srt.disaggregation.base.conn import KVArgs, KVPoll
22
18
  from sglang.srt.disaggregation.common.conn import (
23
19
  CommonKVBootstrapServer,
24
20
  CommonKVManager,
25
21
  CommonKVReceiver,
22
+ CommonKVSender,
26
23
  )
27
24
  from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous
28
25
  from sglang.srt.disaggregation.utils import DisaggregationMode
29
26
  from sglang.srt.server_args import ServerArgs
30
- from sglang.srt.utils import (
31
- format_tcp_address,
32
- get_local_ip_auto,
33
- is_valid_ipv6_address,
34
- )
27
+ from sglang.srt.utils import get_int_env_var
35
28
 
36
29
  logger = logging.getLogger(__name__)
37
30
 
@@ -113,8 +106,14 @@ class TransferStatus:
113
106
  def is_done(self):
114
107
  if self.num_kvs_expected is None:
115
108
  return False
109
+ # Check for failure state
110
+ if self.num_kvs_expected == -1:
111
+ return True # Failed transfers are considered "done"
116
112
  return self.num_kvs_expected == len(self.received_kvs) and self.received_aux
117
113
 
114
+ def is_failed(self):
115
+ return self.num_kvs_expected == -1
116
+
118
117
 
119
118
  class NixlKVManager(CommonKVManager):
120
119
  def __init__(
@@ -134,26 +133,133 @@ class NixlKVManager(CommonKVManager):
134
133
  "to run SGLang with NixlTransferEngine."
135
134
  ) from e
136
135
  self.agent = nixl_agent(str(uuid.uuid4()))
137
- self.local_ip = get_local_ip_auto()
138
- self.server_socket = zmq.Context().socket(zmq.PULL)
139
- if is_valid_ipv6_address(self.local_ip):
140
- self.server_socket.setsockopt(zmq.IPV6, 1)
141
136
  self.register_buffer_to_engine()
142
137
 
143
138
  if self.disaggregation_mode == DisaggregationMode.PREFILL:
144
- self.request_status: Dict[int, KVPoll] = {}
145
- self.transfer_infos: Dict[int, Dict[str, TransferInfo]] = {}
146
- self.decode_kv_args_table: Dict[str, KVArgsRegisterInfo] = {}
147
139
  self._start_bootstrap_thread()
148
140
  elif self.disaggregation_mode == DisaggregationMode.DECODE:
149
141
  self.transfer_statuses: Dict[int, TransferStatus] = defaultdict(
150
142
  TransferStatus
151
143
  )
144
+ self.heartbeat_failures = {}
145
+ self.session_pool = defaultdict(requests.Session)
146
+ self.session_pool_lock = threading.Lock()
147
+ self.addr_to_rooms_tracker = defaultdict(set)
148
+ self.connection_lock = threading.Lock()
149
+
150
+ # Heartbeat interval should be at least 2 seconds
151
+ self.heartbeat_interval = max(
152
+ float(os.getenv("SGLANG_DISAGGREGATION_HEARTBEAT_INTERVAL", 5.0)), 2.0
153
+ )
154
+ # Heartbeat failure should be at least 1
155
+ self.max_failures = max(
156
+ get_int_env_var("SGLANG_DISAGGREGATION_HEARTBEAT_MAX_FAILURE", 2), 1
157
+ )
158
+ self._start_heartbeat_checker_thread()
152
159
  else:
153
160
  raise ValueError(
154
161
  f"Unsupported DisaggregationMode: {self.disaggregation_mode}"
155
162
  )
156
163
 
164
+ def _start_heartbeat_checker_thread(self):
165
+ """
166
+ Start the heartbeat checker thread for Decode worker.
167
+ TODO (smor): unite nixl heartbeat checker with mooncake's.
168
+ """
169
+
170
+ def heartbeat_checker():
171
+ while True:
172
+ time.sleep(self.heartbeat_interval)
173
+ with self.connection_lock:
174
+ addresses = list(self.prefill_dp_size_table.keys())
175
+
176
+ for bootstrap_addr in addresses:
177
+ session = None
178
+ try:
179
+ with self.session_pool_lock:
180
+ session = self.session_pool[bootstrap_addr]
181
+ response = session.get(
182
+ f"http://{bootstrap_addr}/health",
183
+ timeout=(2, 3),
184
+ headers={"Connection": "keep-alive"},
185
+ )
186
+ if response.status_code == 200:
187
+ self.heartbeat_failures[bootstrap_addr] = 0
188
+
189
+ current_rooms = self.addr_to_rooms_tracker[
190
+ bootstrap_addr
191
+ ].copy()
192
+
193
+ for bootstrap_room in current_rooms:
194
+ # Remove successful transfers from the tracker
195
+ if bootstrap_room not in self.transfer_statuses:
196
+ self.addr_to_rooms_tracker[bootstrap_addr].discard(
197
+ bootstrap_room
198
+ )
199
+ else:
200
+ logger.info(
201
+ f"Attempting to reconnect to {bootstrap_addr}..."
202
+ )
203
+ self.heartbeat_failures[bootstrap_addr] = (
204
+ self.heartbeat_failures.get(bootstrap_addr, 0) + 1
205
+ )
206
+ with self.session_pool_lock:
207
+ if bootstrap_addr in self.session_pool:
208
+ del self.session_pool[bootstrap_addr]
209
+ except Exception:
210
+ logger.info(f"Attempting to reconnect to {bootstrap_addr}...")
211
+ self.heartbeat_failures[bootstrap_addr] = (
212
+ self.heartbeat_failures.get(bootstrap_addr, 0) + 1
213
+ )
214
+
215
+ if (
216
+ self.heartbeat_failures.get(bootstrap_addr, 0)
217
+ >= self.max_failures
218
+ ):
219
+ self._handle_node_failure(bootstrap_addr)
220
+ with self.session_pool_lock:
221
+ if bootstrap_addr in self.session_pool:
222
+ del self.session_pool[bootstrap_addr]
223
+
224
+ threading.Thread(target=heartbeat_checker, daemon=True).start()
225
+
226
+ def _handle_node_failure(self, failed_bootstrap_addr):
227
+ """Handle failure of a prefill node."""
228
+ with self.connection_lock:
229
+ keys_to_remove = [
230
+ k for k in self.connection_pool if k.startswith(failed_bootstrap_addr)
231
+ ]
232
+ for k in keys_to_remove:
233
+ del self.connection_pool[k]
234
+ if failed_bootstrap_addr in self.prefill_tp_size_table:
235
+ del self.prefill_tp_size_table[failed_bootstrap_addr]
236
+ if failed_bootstrap_addr in self.prefill_dp_size_table:
237
+ del self.prefill_dp_size_table[failed_bootstrap_addr]
238
+ if failed_bootstrap_addr in self.prefill_pp_size_table:
239
+ del self.prefill_pp_size_table[failed_bootstrap_addr]
240
+
241
+ possible_affected_rooms = self.addr_to_rooms_tracker.get(
242
+ failed_bootstrap_addr, []
243
+ )
244
+ if failed_bootstrap_addr in self.addr_to_rooms_tracker:
245
+ del self.addr_to_rooms_tracker[failed_bootstrap_addr]
246
+
247
+ # Mark all pending transfers associated with the failed node as failed
248
+ affected_rooms = []
249
+ for room in possible_affected_rooms:
250
+ if (
251
+ room in self.transfer_statuses
252
+ and not self.transfer_statuses[room].is_done()
253
+ ):
254
+ # Mark the transfer as failed by setting a special state
255
+ self.transfer_statuses[room].num_kvs_expected = -1 # Indicates failure
256
+ affected_rooms.append(room)
257
+
258
+ logger.error(
259
+ f"Lost connection with prefill instance (bootstrap_addr: {failed_bootstrap_addr}), "
260
+ f"{len(affected_rooms)} transfers affected"
261
+ )
262
+
157
263
  def check_status(self, bootstrap_room: int):
158
264
  return self.request_status[bootstrap_room]
159
265
 
@@ -166,6 +272,9 @@ class NixlKVManager(CommonKVManager):
166
272
  self.request_status[bootstrap_room], status
167
273
  )
168
274
 
275
+ def record_failure(self, bootstrap_room: int, failure_reason: str):
276
+ pass
277
+
169
278
  def register_buffer_to_engine(self):
170
279
  kv_addrs = []
171
280
  for kv_data_ptr, kv_data_len in zip(
@@ -438,7 +547,7 @@ class NixlKVManager(CommonKVManager):
438
547
  notif = "_".join([str(req.room), "kv", str(chunk_id), str(int(is_last))])
439
548
  decode_tp_size = self.decode_kv_args_table[req.agent_name].decode_tp_size
440
549
 
441
- if decode_tp_size == self.tp_size:
550
+ if self.is_mla_backend or (decode_tp_size == self.attn_tp_size):
442
551
  kv_xfer_handle = self.send_kvcache(
443
552
  req.agent_name,
444
553
  kv_indices,
@@ -455,7 +564,7 @@ class NixlKVManager(CommonKVManager):
455
564
  chunked_dst_kv_indice,
456
565
  self.decode_kv_args_table[req.agent_name].gpu_id,
457
566
  notif,
458
- prefill_tp_size=self.tp_size,
567
+ prefill_tp_size=self.attn_tp_size,
459
568
  decode_tp_size=decode_tp_size,
460
569
  decode_tp_rank=self.decode_kv_args_table[
461
570
  req.agent_name
@@ -505,9 +614,6 @@ class NixlKVManager(CommonKVManager):
505
614
  return False
506
615
  return self.transfer_statuses[room].is_done()
507
616
 
508
- def _bind_server_socket(self):
509
- self.server_socket.bind(format_tcp_address(self.local_ip, self.rank_port))
510
-
511
617
  def _start_bootstrap_thread(self):
512
618
  self._bind_server_socket()
513
619
 
@@ -548,7 +654,7 @@ class NixlKVManager(CommonKVManager):
548
654
  threading.Thread(target=bootstrap_thread).start()
549
655
 
550
656
 
551
- class NixlKVSender(BaseKVSender):
657
+ class NixlKVSender(CommonKVSender):
552
658
 
553
659
  def __init__(
554
660
  self,
@@ -558,20 +664,10 @@ class NixlKVSender(BaseKVSender):
558
664
  dest_tp_ranks: List[int],
559
665
  pp_rank: int,
560
666
  ):
561
- self.kv_mgr = mgr
562
- self.bootstrap_room = bootstrap_room
563
- self.aux_index = None
564
- self.bootstrap_server_url = bootstrap_addr
667
+ super().__init__(mgr, bootstrap_addr, bootstrap_room, dest_tp_ranks, pp_rank)
565
668
  self.xfer_handles = []
566
669
  self.has_sent = False
567
670
  self.chunk_id = 0
568
- self.kv_mgr.update_status(self.bootstrap_room, KVPoll.Bootstrapping)
569
- # inner state
570
- self.curr_idx = 0
571
-
572
- def init(self, num_kv_indices: int, aux_index: Optional[int] = None):
573
- self.num_kv_indices = num_kv_indices
574
- self.aux_index = aux_index
575
671
 
576
672
  def send(
577
673
  self,
@@ -621,6 +717,12 @@ class NixlKVReceiver(CommonKVReceiver):
621
717
  self.conclude_state = None
622
718
  super().__init__(mgr, bootstrap_addr, bootstrap_room, prefill_dp_rank)
623
719
 
720
+ # Track this room with its bootstrap address for heartbeat monitoring
721
+ if hasattr(self.kv_mgr, "addr_to_rooms_tracker"):
722
+ self.kv_mgr.addr_to_rooms_tracker[self.bootstrap_addr].add(
723
+ self.bootstrap_room
724
+ )
725
+
624
726
  def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
625
727
  for bootstrap_info in self.bootstrap_infos:
626
728
  logger.debug(
@@ -655,9 +757,16 @@ class NixlKVReceiver(CommonKVReceiver):
655
757
 
656
758
  self.kv_mgr.update_transfer_status()
657
759
  if self.kv_mgr.check_transfer_done(self.bootstrap_room): # type: ignore
658
- self.conclude_state = KVPoll.Success
760
+ # Check if the transfer failed
761
+ if self.kv_mgr.transfer_statuses[self.bootstrap_room].is_failed():
762
+ self.conclude_state = KVPoll.Failed
763
+ logger.error(
764
+ f"Transfer for room {self.bootstrap_room} failed due to node failure"
765
+ )
766
+ else:
767
+ self.conclude_state = KVPoll.Success
659
768
  del self.kv_mgr.transfer_statuses[self.bootstrap_room]
660
- return KVPoll.Success # type: ignore
769
+ return self.conclude_state # type: ignore
661
770
  return KVPoll.WaitingForInput # type: ignore
662
771
 
663
772
  def _register_kv_args(self):
@@ -21,6 +21,7 @@ from __future__ import annotations
21
21
 
22
22
  import logging
23
23
  import threading
24
+ import time
24
25
  from collections import deque
25
26
  from http import HTTPStatus
26
27
  from typing import TYPE_CHECKING, List, Optional, Type
@@ -42,7 +43,12 @@ from sglang.srt.disaggregation.utils import (
42
43
  poll_and_all_reduce,
43
44
  prepare_abort,
44
45
  )
45
- from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
46
+ from sglang.srt.managers.schedule_batch import (
47
+ FINISH_LENGTH,
48
+ Req,
49
+ RequestStage,
50
+ ScheduleBatch,
51
+ )
46
52
  from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
47
53
  from sglang.srt.utils import (
48
54
  DynamicGradMode,
@@ -170,6 +176,7 @@ class PrefillBootstrapQueue:
170
176
  pp_rank=self.pp_rank,
171
177
  )
172
178
  self._process_req(req)
179
+ req.add_latency(RequestStage.PREFILL_PREPARE)
173
180
  self.queue.append(req)
174
181
 
175
182
  def extend(self, reqs: List[Req], num_kv_heads: int) -> None:
@@ -256,8 +263,11 @@ class PrefillBootstrapQueue:
256
263
 
257
264
  num_pages = kv_to_page_num(num_kv_indices, self.token_to_kv_pool.page_size)
258
265
  req.disagg_kv_sender.init(num_pages, req.metadata_buffer_index)
266
+
259
267
  bootstrapped_reqs.append(req)
260
268
  indices_to_remove.add(i)
269
+ req.time_stats.wait_queue_entry_time = time.perf_counter()
270
+ req.add_latency(RequestStage.PREFILL_BOOTSTRAP)
261
271
 
262
272
  self.queue = [
263
273
  entry for i, entry in enumerate(self.queue) if i not in indices_to_remove
@@ -399,11 +409,11 @@ class SchedulerDisaggregationPrefillMixin:
399
409
  for i, (req, next_token_id) in enumerate(
400
410
  zip(batch.reqs, next_token_ids, strict=True)
401
411
  ):
402
- req: Req
403
412
  if req.is_chunked <= 0:
404
413
  # There is no output_ids for prefill
405
414
  req.output_ids.append(next_token_id)
406
415
  self.tree_cache.cache_unfinished_req(req) # update the tree and lock
416
+ req.add_latency(RequestStage.PREFILL_FORWARD)
407
417
  self.disagg_prefill_inflight_queue.append(req)
408
418
  if (
409
419
  logits_output is not None
@@ -412,9 +422,16 @@ class SchedulerDisaggregationPrefillMixin:
412
422
  last_hidden_index = (
413
423
  hidden_state_offset + extend_input_len_per_req[i] - 1
414
424
  )
415
- req.hidden_states_tensor = (
416
- logits_output.hidden_states[last_hidden_index].cpu().clone()
417
- )
425
+ req.output_topk_p = batch.spec_info.topk_p[i]
426
+ req.output_topk_index = batch.spec_info.topk_index[i]
427
+ if self.spec_algorithm.is_eagle3():
428
+ req.hidden_states_tensor = (
429
+ batch.spec_info.hidden_states[i].cpu().clone()
430
+ )
431
+ else:
432
+ req.hidden_states_tensor = (
433
+ logits_output.hidden_states[last_hidden_index].cpu().clone()
434
+ )
418
435
  hidden_state_offset += extend_input_len_per_req[i]
419
436
  else:
420
437
  req.hidden_states_tensor = None
@@ -434,6 +451,7 @@ class SchedulerDisaggregationPrefillMixin:
434
451
  )
435
452
  logprob_pt += num_input_logprobs
436
453
  self.send_kv_chunk(req, last_chunk=True)
454
+ req.time_stats.prefill_transfer_queue_entry_time = time.perf_counter()
437
455
 
438
456
  if req.grammar is not None:
439
457
  # FIXME: this try-except block is for handling unexpected xgrammar issue.
@@ -531,6 +549,9 @@ class SchedulerDisaggregationPrefillMixin:
531
549
  else:
532
550
  assert False, f"Unexpected polling state {poll=}"
533
551
 
552
+ for req in done_reqs:
553
+ req.time_stats.completion_time = time.perf_counter()
554
+
534
555
  # Stream requests which have finished transfer
535
556
  self.stream_output(
536
557
  done_reqs,
@@ -539,6 +560,7 @@ class SchedulerDisaggregationPrefillMixin:
539
560
  )
540
561
  for req in done_reqs:
541
562
  req: Req
563
+ req.add_latency(RequestStage.PREFILL_TRANSFER_KV_CACHE)
542
564
  self.req_to_metadata_buffer_idx_allocator.free(req.metadata_buffer_index)
543
565
  req.metadata_buffer_index = -1
544
566
 
@@ -667,7 +689,6 @@ class SchedulerDisaggregationPrefillMixin:
667
689
  self.running_mbs = [
668
690
  ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
669
691
  ]
670
- bids = [None] * self.pp_size
671
692
  pp_outputs: Optional[PPProxyTensors] = None
672
693
 
673
694
  # Either success or failed
@@ -739,10 +760,7 @@ class SchedulerDisaggregationPrefillMixin:
739
760
  # send the outputs to the next step
740
761
  if self.pp_group.is_last_rank:
741
762
  if self.cur_batch:
742
- next_token_ids, bids[mb_id] = (
743
- result.next_token_ids,
744
- result.bid,
745
- )
763
+ next_token_ids = result.next_token_ids
746
764
  pp_outputs = PPProxyTensors(
747
765
  {
748
766
  "next_token_ids": next_token_ids,
@@ -779,7 +797,6 @@ class SchedulerDisaggregationPrefillMixin:
779
797
  next_token_ids=next_pp_outputs["next_token_ids"],
780
798
  extend_input_len_per_req=None,
781
799
  extend_logprob_start_len_per_req=None,
782
- bid=bids[next_mb_id],
783
800
  can_run_cuda_graph=result.can_run_cuda_graph,
784
801
  )
785
802
  self.process_batch_result_disagg_prefill(
@@ -796,8 +813,6 @@ class SchedulerDisaggregationPrefillMixin:
796
813
 
797
814
  # carry the outputs to the next stage
798
815
  if not self.pp_group.is_last_rank:
799
- if self.cur_batch:
800
- bids[mb_id] = result.bid
801
816
  if pp_outputs:
802
817
  # send the outputs from the last round to let the next stage worker run post processing
803
818
  self.pp_group.send_tensor_dict(
@@ -816,8 +831,10 @@ class SchedulerDisaggregationPrefillMixin:
816
831
 
817
832
  # send out proxy tensors to the next stage
818
833
  if self.cur_batch:
834
+ # FIXME(lsyin): remove this assert
835
+ assert result.pp_hidden_states_proxy_tensors.tensors is not None
819
836
  self.pp_group.send_tensor_dict(
820
- result.pp_hidden_states_proxy_tensors,
837
+ result.pp_hidden_states_proxy_tensors.tensors,
821
838
  all_gather_group=self.attn_tp_group,
822
839
  )
823
840
 
@@ -5,7 +5,7 @@ import random
5
5
  from collections import deque
6
6
  from contextlib import nullcontext
7
7
  from enum import Enum
8
- from typing import TYPE_CHECKING, List, Optional, Type, Union
8
+ from typing import TYPE_CHECKING, Optional, Type
9
9
 
10
10
  import numpy as np
11
11
  import torch
@@ -85,7 +85,7 @@ class MetadataBuffers:
85
85
  self,
86
86
  size: int,
87
87
  hidden_size: int,
88
- dtype: torch.dtype,
88
+ hidden_states_dtype: torch.dtype,
89
89
  max_top_logprobs_num: int = 128,
90
90
  custom_mem_pool: torch.cuda.MemPool = None,
91
91
  ):
@@ -107,7 +107,9 @@ class MetadataBuffers:
107
107
  # We transfer the metadata of first output token to decode
108
108
  # The minimal size for RDMA is 64Bytes, so we pad it to > 64Bytes
109
109
  self.output_ids = torch.zeros((size, 16), dtype=torch.int32, device=device)
110
-
110
+ self.cached_tokens = torch.zeros(
111
+ (size, 16), dtype=torch.int32, device=device
112
+ )
111
113
  self.output_token_logprobs_val = torch.zeros(
112
114
  (size, 16), dtype=torch.float32, device=device
113
115
  )
@@ -120,33 +122,49 @@ class MetadataBuffers:
120
122
  self.output_top_logprobs_idx = torch.zeros(
121
123
  (size, max_top_logprobs_num), dtype=torch.int32, device=device
122
124
  )
125
+ # For PD + spec decode
126
+ self.output_topk_p = torch.zeros(
127
+ (size, 16), dtype=torch.float32, device=device
128
+ )
129
+ self.output_topk_index = torch.zeros(
130
+ (size, 16), dtype=torch.int64, device=device
131
+ )
123
132
  self.output_hidden_states = torch.zeros(
124
- (size, hidden_size), dtype=dtype, device=device
133
+ (size, hidden_size), dtype=hidden_states_dtype, device=device
125
134
  )
126
135
 
127
136
  def get_buf_infos(self):
128
137
  ptrs = [
129
138
  self.output_ids.data_ptr(),
139
+ self.cached_tokens.data_ptr(),
130
140
  self.output_token_logprobs_val.data_ptr(),
131
141
  self.output_token_logprobs_idx.data_ptr(),
132
142
  self.output_top_logprobs_val.data_ptr(),
133
143
  self.output_top_logprobs_idx.data_ptr(),
144
+ self.output_topk_p.data_ptr(),
145
+ self.output_topk_index.data_ptr(),
134
146
  self.output_hidden_states.data_ptr(),
135
147
  ]
136
148
  data_lens = [
137
149
  self.output_ids.nbytes,
150
+ self.cached_tokens.nbytes,
138
151
  self.output_token_logprobs_val.nbytes,
139
152
  self.output_token_logprobs_idx.nbytes,
140
153
  self.output_top_logprobs_val.nbytes,
141
154
  self.output_top_logprobs_idx.nbytes,
155
+ self.output_topk_p.nbytes,
156
+ self.output_topk_index.nbytes,
142
157
  self.output_hidden_states.nbytes,
143
158
  ]
144
159
  item_lens = [
145
160
  self.output_ids[0].nbytes,
161
+ self.cached_tokens[0].nbytes,
146
162
  self.output_token_logprobs_val[0].nbytes,
147
163
  self.output_token_logprobs_idx[0].nbytes,
148
164
  self.output_top_logprobs_val[0].nbytes,
149
165
  self.output_top_logprobs_idx[0].nbytes,
166
+ self.output_topk_p[0].nbytes,
167
+ self.output_topk_index[0].nbytes,
150
168
  self.output_hidden_states[0].nbytes,
151
169
  ]
152
170
  return ptrs, data_lens, item_lens
@@ -154,16 +172,20 @@ class MetadataBuffers:
154
172
  def get_buf(self, idx: int):
155
173
  return (
156
174
  self.output_ids[idx],
175
+ self.cached_tokens[idx],
157
176
  self.output_token_logprobs_val[idx],
158
177
  self.output_token_logprobs_idx[idx],
159
178
  self.output_top_logprobs_val[idx],
160
179
  self.output_top_logprobs_idx[idx],
180
+ self.output_topk_p[idx],
181
+ self.output_topk_index[idx],
161
182
  self.output_hidden_states[idx],
162
183
  )
163
184
 
164
185
  def set_buf(self, req: Req):
165
186
 
166
187
  self.output_ids[req.metadata_buffer_index][0] = req.output_ids[0]
188
+ self.cached_tokens[req.metadata_buffer_index][0] = req.cached_tokens
167
189
  if req.return_logprob:
168
190
  if req.output_token_logprobs_val: # not none or empty list
169
191
  self.output_token_logprobs_val[req.metadata_buffer_index][0] = (
@@ -186,8 +208,17 @@ class MetadataBuffers:
186
208
  ] = torch.tensor(
187
209
  req.output_top_logprobs_idx[0], dtype=torch.int32, device="cpu"
188
210
  )
189
- # for PD + spec decode
211
+ # For PD + spec decode
190
212
  if req.hidden_states_tensor is not None:
213
+ # speculative_eagle_topk should not be greater than 16 currently
214
+ topk = req.output_topk_p.size(0)
215
+
216
+ self.output_topk_p[req.metadata_buffer_index, :topk].copy_(
217
+ req.output_topk_p
218
+ )
219
+ self.output_topk_index[req.metadata_buffer_index, :topk].copy_(
220
+ req.output_topk_index
221
+ )
191
222
  self.output_hidden_states[req.metadata_buffer_index].copy_(
192
223
  req.hidden_states_tensor
193
224
  )
@@ -0,0 +1,16 @@
1
+ MiB = 1024 * 1024
2
+
3
+ SYMM_MEM_ALL_REDUCE_MAX_SIZES = {
4
+ 9: {
5
+ 2: 64 * MiB, # 64 MB
6
+ 4: 32 * MiB, # 32 MB
7
+ 6: 64 * MiB, # 64 MB
8
+ 8: 64 * MiB, # 64 MB
9
+ },
10
+ 10: {
11
+ 2: 64 * MiB, # 64 MB
12
+ 4: 32 * MiB, # 32 MB
13
+ 6: 128 * MiB, # 128 MB
14
+ 8: 128 * MiB, # 128 MB
15
+ },
16
+ }
@@ -18,7 +18,7 @@ from zmq import SUB, SUBSCRIBE, XPUB, XPUB_VERBOSE, Context # type: ignore
18
18
 
19
19
  from sglang.srt.utils import (
20
20
  format_tcp_address,
21
- get_ip,
21
+ get_local_ip_auto,
22
22
  get_open_port,
23
23
  is_valid_ipv6_address,
24
24
  )
@@ -191,7 +191,9 @@ class MessageQueue:
191
191
  self.n_remote_reader = n_remote_reader
192
192
 
193
193
  if connect_ip is None:
194
- connect_ip = get_ip() if n_remote_reader > 0 else "127.0.0.1"
194
+ connect_ip = (
195
+ get_local_ip_auto("0.0.0.0") if n_remote_reader > 0 else "127.0.0.1"
196
+ )
195
197
 
196
198
  context = Context()
197
199