sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (238)
  1. sglang/bench_one_batch_server.py +10 -1
  2. sglang/bench_serving.py +257 -29
  3. sglang/srt/configs/__init__.py +4 -0
  4. sglang/srt/configs/device_config.py +3 -1
  5. sglang/srt/configs/dots_vlm.py +139 -0
  6. sglang/srt/configs/load_config.py +1 -0
  7. sglang/srt/configs/model_config.py +50 -6
  8. sglang/srt/configs/qwen3_next.py +326 -0
  9. sglang/srt/connector/__init__.py +8 -1
  10. sglang/srt/connector/remote_instance.py +82 -0
  11. sglang/srt/constrained/base_grammar_backend.py +48 -12
  12. sglang/srt/constrained/llguidance_backend.py +0 -1
  13. sglang/srt/constrained/outlines_backend.py +0 -1
  14. sglang/srt/constrained/xgrammar_backend.py +28 -9
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/base/conn.py +1 -1
  21. sglang/srt/disaggregation/common/conn.py +15 -12
  22. sglang/srt/disaggregation/decode.py +21 -10
  23. sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -445
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +5 -3
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +24 -3
  31. sglang/srt/entrypoints/engine.py +38 -17
  32. sglang/srt/entrypoints/grpc_request_manager.py +580 -0
  33. sglang/srt/entrypoints/grpc_server.py +680 -0
  34. sglang/srt/entrypoints/http_server.py +85 -54
  35. sglang/srt/entrypoints/openai/protocol.py +4 -1
  36. sglang/srt/entrypoints/openai/serving_base.py +46 -3
  37. sglang/srt/entrypoints/openai/serving_chat.py +36 -16
  38. sglang/srt/entrypoints/openai/serving_completions.py +12 -3
  39. sglang/srt/entrypoints/openai/serving_embedding.py +8 -3
  40. sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
  41. sglang/srt/entrypoints/openai/serving_responses.py +6 -3
  42. sglang/srt/entrypoints/openai/serving_score.py +1 -0
  43. sglang/srt/eplb/eplb_manager.py +2 -2
  44. sglang/srt/eplb/expert_distribution.py +26 -13
  45. sglang/srt/eplb/expert_location.py +8 -3
  46. sglang/srt/eplb/expert_location_updater.py +1 -1
  47. sglang/srt/function_call/base_format_detector.py +3 -6
  48. sglang/srt/function_call/ebnf_composer.py +11 -9
  49. sglang/srt/function_call/function_call_parser.py +6 -0
  50. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  51. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  52. sglang/srt/grpc/__init__.py +1 -0
  53. sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
  54. sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
  55. sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
  56. sglang/srt/hf_transformers_utils.py +4 -0
  57. sglang/srt/layers/activation.py +142 -9
  58. sglang/srt/layers/attention/ascend_backend.py +11 -4
  59. sglang/srt/layers/attention/fla/chunk.py +242 -0
  60. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  61. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  62. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  63. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  64. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  65. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  66. sglang/srt/layers/attention/fla/index.py +37 -0
  67. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  68. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  69. sglang/srt/layers/attention/fla/op.py +66 -0
  70. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  71. sglang/srt/layers/attention/fla/utils.py +331 -0
  72. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  73. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  74. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  75. sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
  76. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
  77. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  78. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  79. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  80. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  81. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  82. sglang/srt/layers/attention/triton_backend.py +18 -1
  83. sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
  84. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  85. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  86. sglang/srt/layers/dp_attention.py +30 -1
  87. sglang/srt/layers/layernorm.py +32 -15
  88. sglang/srt/layers/linear.py +34 -3
  89. sglang/srt/layers/logits_processor.py +29 -10
  90. sglang/srt/layers/moe/__init__.py +2 -1
  91. sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
  92. sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
  93. sglang/srt/layers/moe/ep_moe/layer.py +182 -62
  94. sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
  95. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  96. sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
  97. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
  98. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
  99. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  100. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  101. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
  102. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  103. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
  104. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  105. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
  106. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
  107. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
  108. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
  109. sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
  110. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  111. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  112. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  113. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  114. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  115. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  116. sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
  117. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  118. sglang/srt/layers/moe/topk.py +30 -9
  119. sglang/srt/layers/moe/utils.py +12 -6
  120. sglang/srt/layers/quantization/awq.py +19 -7
  121. sglang/srt/layers/quantization/base_config.py +11 -6
  122. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  123. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  124. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  125. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  126. sglang/srt/layers/quantization/fp8.py +76 -47
  127. sglang/srt/layers/quantization/fp8_utils.py +50 -31
  128. sglang/srt/layers/quantization/gptq.py +25 -17
  129. sglang/srt/layers/quantization/modelopt_quant.py +147 -47
  130. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  131. sglang/srt/layers/quantization/mxfp4.py +64 -40
  132. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  133. sglang/srt/layers/quantization/unquant.py +135 -47
  134. sglang/srt/layers/quantization/w4afp8.py +30 -17
  135. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  136. sglang/srt/layers/quantization/w8a8_int8.py +76 -38
  137. sglang/srt/layers/sampler.py +162 -18
  138. sglang/srt/lora/backend/base_backend.py +50 -8
  139. sglang/srt/lora/backend/triton_backend.py +90 -2
  140. sglang/srt/lora/layers.py +32 -0
  141. sglang/srt/lora/lora.py +4 -1
  142. sglang/srt/lora/lora_manager.py +35 -112
  143. sglang/srt/lora/mem_pool.py +24 -10
  144. sglang/srt/lora/utils.py +18 -9
  145. sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
  146. sglang/srt/managers/cache_controller.py +158 -160
  147. sglang/srt/managers/data_parallel_controller.py +105 -35
  148. sglang/srt/managers/detokenizer_manager.py +8 -4
  149. sglang/srt/managers/disagg_service.py +46 -0
  150. sglang/srt/managers/io_struct.py +199 -12
  151. sglang/srt/managers/mm_utils.py +1 -0
  152. sglang/srt/managers/multi_tokenizer_mixin.py +350 -400
  153. sglang/srt/managers/schedule_batch.py +77 -56
  154. sglang/srt/managers/schedule_policy.py +1 -1
  155. sglang/srt/managers/scheduler.py +187 -39
  156. sglang/srt/managers/scheduler_metrics_mixin.py +4 -3
  157. sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
  158. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  159. sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
  160. sglang/srt/managers/tokenizer_manager.py +259 -519
  161. sglang/srt/managers/tp_worker.py +53 -4
  162. sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
  163. sglang/srt/mem_cache/hicache_storage.py +3 -23
  164. sglang/srt/mem_cache/hiradix_cache.py +103 -43
  165. sglang/srt/mem_cache/memory_pool.py +347 -48
  166. sglang/srt/mem_cache/memory_pool_host.py +105 -46
  167. sglang/srt/mem_cache/radix_cache.py +0 -2
  168. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  169. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  170. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +86 -4
  171. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  172. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  173. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +49 -7
  174. sglang/srt/mem_cache/swa_radix_cache.py +0 -2
  175. sglang/srt/metrics/collector.py +493 -76
  176. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  177. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  178. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  179. sglang/srt/model_executor/forward_batch_info.py +59 -2
  180. sglang/srt/model_executor/model_runner.py +356 -29
  181. sglang/srt/model_loader/__init__.py +9 -3
  182. sglang/srt/model_loader/loader.py +128 -4
  183. sglang/srt/model_loader/weight_utils.py +2 -1
  184. sglang/srt/models/apertus.py +686 -0
  185. sglang/srt/models/bailing_moe.py +798 -218
  186. sglang/srt/models/bailing_moe_nextn.py +168 -0
  187. sglang/srt/models/deepseek_v2.py +109 -15
  188. sglang/srt/models/dots_vlm.py +174 -0
  189. sglang/srt/models/dots_vlm_vit.py +337 -0
  190. sglang/srt/models/ernie4.py +1 -1
  191. sglang/srt/models/gemma3n_mm.py +1 -1
  192. sglang/srt/models/glm4_moe.py +1 -1
  193. sglang/srt/models/glm4v.py +4 -2
  194. sglang/srt/models/glm4v_moe.py +3 -0
  195. sglang/srt/models/gpt_oss.py +1 -1
  196. sglang/srt/models/llama4.py +9 -0
  197. sglang/srt/models/llama_eagle3.py +13 -0
  198. sglang/srt/models/longcat_flash.py +2 -2
  199. sglang/srt/models/mllama4.py +25 -0
  200. sglang/srt/models/opt.py +637 -0
  201. sglang/srt/models/qwen2.py +7 -0
  202. sglang/srt/models/qwen2_5_vl.py +27 -3
  203. sglang/srt/models/qwen2_moe.py +56 -12
  204. sglang/srt/models/qwen3_moe.py +1 -1
  205. sglang/srt/models/qwen3_next.py +1042 -0
  206. sglang/srt/models/qwen3_next_mtp.py +112 -0
  207. sglang/srt/models/step3_vl.py +1 -1
  208. sglang/srt/multimodal/processors/dots_vlm.py +99 -0
  209. sglang/srt/multimodal/processors/glm4v.py +9 -9
  210. sglang/srt/multimodal/processors/internvl.py +141 -129
  211. sglang/srt/multimodal/processors/qwen_vl.py +15 -5
  212. sglang/srt/offloader.py +27 -3
  213. sglang/srt/remote_instance_weight_loader_utils.py +69 -0
  214. sglang/srt/sampling/sampling_batch_info.py +18 -15
  215. sglang/srt/server_args.py +276 -35
  216. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  217. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  218. sglang/srt/speculative/eagle_utils.py +0 -2
  219. sglang/srt/speculative/eagle_worker.py +43 -4
  220. sglang/srt/speculative/spec_info.py +5 -0
  221. sglang/srt/speculative/standalone_worker.py +109 -0
  222. sglang/srt/tracing/trace.py +552 -0
  223. sglang/srt/utils.py +34 -3
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +169 -5
  226. sglang/test/runners.py +4 -0
  227. sglang/test/test_cutlass_moe.py +24 -6
  228. sglang/test/test_disaggregation_utils.py +66 -0
  229. sglang/test/test_fp4_moe.py +370 -1
  230. sglang/test/test_utils.py +28 -1
  231. sglang/utils.py +11 -0
  232. sglang/version.py +1 -1
  233. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
  234. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +237 -178
  235. sglang/srt/disaggregation/launch_lb.py +0 -118
  236. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
  237. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
  238. {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
@@ -78,6 +78,9 @@ class KVArgsRegisterInfo:
78
78
  dst_kv_ptrs: list[int]
79
79
  dst_aux_ptrs: list[int]
80
80
  gpu_id: int
81
+ decode_tp_size: int
82
+ decode_tp_rank: int
83
+ dst_kv_item_len: int
81
84
 
82
85
  @classmethod
83
86
  def from_zmq(cls, msg: List[bytes]):
@@ -90,6 +93,9 @@ class KVArgsRegisterInfo:
90
93
  dst_kv_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
91
94
  dst_aux_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])),
92
95
  gpu_id=int(msg[7].decode("ascii")),
96
+ decode_tp_size=int(msg[8].decode("ascii")),
97
+ decode_tp_rank=int(msg[9].decode("ascii")),
98
+ dst_kv_item_len=int(msg[10].decode("ascii")),
93
99
  )
94
100
 
95
101
 
@@ -166,7 +172,7 @@ class NixlKVManager(CommonKVManager):
166
172
  self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens
167
173
  ):
168
174
  kv_addrs.append((kv_data_ptr, kv_data_len, self.kv_args.gpu_id, ""))
169
- self.kv_descs = self.agent.register_memory(kv_addrs, "VRAM", is_sorted=False)
175
+ self.kv_descs = self.agent.register_memory(kv_addrs, "VRAM")
170
176
  logger.debug(f"Register kv tensors, len(kv_addr)= {len(kv_addrs)}")
171
177
  if not self.kv_descs:
172
178
  raise Exception("NIXL memory registration failed for kv tensors")
@@ -175,7 +181,7 @@ class NixlKVManager(CommonKVManager):
175
181
  self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
176
182
  ):
177
183
  aux_addrs.append((aux_data_ptr, aux_data_len, 0, ""))
178
- self.aux_descs = self.agent.register_memory(aux_addrs, "DRAM", is_sorted=False)
184
+ self.aux_descs = self.agent.register_memory(aux_addrs, "DRAM")
179
185
  logger.debug(f"Register aux tensors, len(aux_addrs)= {len(aux_addrs)}")
180
186
  if not self.aux_descs:
181
187
  raise Exception("NIXL memory registration failed for aux tensors")
@@ -222,8 +228,8 @@ class NixlKVManager(CommonKVManager):
222
228
  logger.debug(
223
229
  f"len(src_addrs): before group: {len(prefill_kv_indices)}, after group: {len(src_addrs)}"
224
230
  )
225
- src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM", is_sorted=False)
226
- dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM", is_sorted=False)
231
+ src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM")
232
+ dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM")
227
233
  # Transfer data
228
234
  xfer_handle = self.agent.initialize_xfer(
229
235
  "WRITE",
@@ -239,6 +245,140 @@ class NixlKVManager(CommonKVManager):
239
245
  raise Exception("KVSender failed to post transfer")
240
246
  return xfer_handle
241
247
 
248
+ def send_kvcache_slice(
249
+ self,
250
+ peer_name: str,
251
+ prefill_kv_indices: npt.NDArray[np.int32],
252
+ dst_kv_ptrs: list[int],
253
+ dst_kv_indices: npt.NDArray[np.int32],
254
+ dst_gpu_id: int,
255
+ notif: str,
256
+ prefill_tp_size: int,
257
+ decode_tp_size: int,
258
+ decode_tp_rank: int,
259
+ dst_kv_item_len: int,
260
+ ):
261
+ # Get configuration from kv_args
262
+ local_tp_rank_in_group = self.kv_args.engine_rank % prefill_tp_size
263
+ dst_tp_rank_in_group = decode_tp_rank % decode_tp_size
264
+ num_kv_heads = self.kv_args.kv_head_num
265
+
266
+ # Calculate head distribution
267
+ src_heads_per_rank = num_kv_heads
268
+ dst_heads_per_rank = num_kv_heads * prefill_tp_size // decode_tp_size
269
+
270
+ src_kv_item_len = self.kv_args.kv_item_lens[0]
271
+ page_size = self.kv_args.page_size
272
+
273
+ bytes_per_head_slice_to_send = (
274
+ dst_kv_item_len // page_size // dst_heads_per_rank
275
+ )
276
+
277
+ # Determine which heads to send
278
+ if prefill_tp_size > decode_tp_size:
279
+ # Multiple prefill ranks to one decode rank
280
+ src_head_start_offset = 0
281
+ num_heads_to_send = src_heads_per_rank
282
+ dst_head_start_offset = local_tp_rank_in_group * src_heads_per_rank
283
+ else:
284
+ # Send KVCache from 1 prefill instance to multiple decode instances
285
+ src_head_start_offset = (
286
+ dst_tp_rank_in_group * dst_heads_per_rank
287
+ ) % src_heads_per_rank
288
+ num_heads_to_send = dst_heads_per_rank
289
+ dst_head_start_offset = 0
290
+
291
+ # Create transfer descriptors
292
+ src_addrs = []
293
+ dst_addrs = []
294
+
295
+ bytes_per_token_on_prefill = src_kv_item_len // page_size
296
+ bytes_per_token_on_decode = dst_kv_item_len // page_size
297
+
298
+ num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
299
+ src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
300
+ src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
301
+ dst_k_ptrs = dst_kv_ptrs[0 : len(src_k_ptrs)]
302
+ dst_v_ptrs = dst_kv_ptrs[num_kv_layers : num_kv_layers + len(src_v_ptrs)]
303
+
304
+ # Calculate precise byte offset and length for the sub-slice within the token
305
+ src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send
306
+ dst_head_slice_offset = dst_head_start_offset * bytes_per_head_slice_to_send
307
+ heads_bytes_per_token_to_send = num_heads_to_send * bytes_per_head_slice_to_send
308
+
309
+ src_dst_ptr_pairs = [
310
+ (
311
+ src_k_ptrs[layer_id],
312
+ dst_k_ptrs[layer_id],
313
+ )
314
+ for layer_id in range(len(src_k_ptrs))
315
+ ] + [
316
+ (
317
+ src_v_ptrs[layer_id],
318
+ dst_v_ptrs[layer_id],
319
+ )
320
+ for layer_id in range(len(src_v_ptrs))
321
+ ]
322
+
323
+ src_addrs = []
324
+ dst_addrs = []
325
+
326
+ # Calculate strides for a single token slot
327
+ bytes_per_token_on_prefill = src_kv_item_len // page_size
328
+ bytes_per_token_on_decode = dst_kv_item_len // page_size
329
+
330
+ for src_ptr, dst_ptr in src_dst_ptr_pairs:
331
+ for i in range(len(prefill_kv_indices)):
332
+ prefill_page_idx = int(prefill_kv_indices[i])
333
+ decode_page_idx = int(dst_kv_indices[i])
334
+
335
+ # Get the starting addresses for the current src and dst pages
336
+ src_page_start_addr = src_ptr + prefill_page_idx * src_kv_item_len
337
+ dst_page_start_addr = dst_ptr + decode_page_idx * dst_kv_item_len
338
+
339
+ # Iterate through each valid token slot within the current page
340
+ for token_slot_in_page in range(page_size):
341
+ # Calculate the start address of the current token slot
342
+ src_token_slot_start_addr = (
343
+ src_page_start_addr
344
+ + token_slot_in_page * bytes_per_token_on_prefill
345
+ )
346
+ dst_token_slot_start_addr = (
347
+ dst_page_start_addr
348
+ + token_slot_in_page * bytes_per_token_on_decode
349
+ )
350
+
351
+ # Calculate final src and dst addresses by applying head-slice offsets
352
+ src_slice_addr = src_token_slot_start_addr + src_head_slice_offset
353
+ dst_slice_addr = dst_token_slot_start_addr + dst_head_slice_offset
354
+
355
+ src_addrs.append(
356
+ (
357
+ src_slice_addr,
358
+ heads_bytes_per_token_to_send,
359
+ self.kv_args.gpu_id,
360
+ )
361
+ )
362
+ dst_addrs.append(
363
+ (dst_slice_addr, heads_bytes_per_token_to_send, dst_gpu_id)
364
+ )
365
+
366
+ # Use NIXL agent for transfer
367
+ src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM")
368
+ dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM")
369
+
370
+ xfer_handle = self.agent.initialize_xfer(
371
+ "WRITE", src_descs, dst_descs, peer_name, notif.encode("ascii")
372
+ )
373
+ if not xfer_handle:
374
+ raise Exception("Failed to create sliced KV transfer")
375
+
376
+ state = self.agent.transfer(xfer_handle)
377
+ if state == "ERR":
378
+ raise Exception("Failed to post sliced KV transfer")
379
+
380
+ return xfer_handle
381
+
242
382
  def send_aux(
243
383
  self,
244
384
  peer_name: str,
@@ -255,8 +395,8 @@ class NixlKVManager(CommonKVManager):
255
395
  decode_aux_addr = dst_aux_ptrs[0] + dst_aux_index * aux_item_len
256
396
  src_addrs = [(prefill_aux_addr, aux_item_len, 0)]
257
397
  dst_addrs = [(decode_aux_addr, aux_item_len, 0)]
258
- src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM", is_sorted=False)
259
- dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM", is_sorted=False)
398
+ src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM")
399
+ dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM")
260
400
  # Transfer data
261
401
  xfer_handle = self.agent.initialize_xfer(
262
402
  "WRITE",
@@ -296,14 +436,35 @@ class NixlKVManager(CommonKVManager):
296
436
  assert req.agent_name in self.decode_kv_args_table
297
437
 
298
438
  notif = "_".join([str(req.room), "kv", str(chunk_id), str(int(is_last))])
299
- kv_xfer_handle = self.send_kvcache(
300
- req.agent_name,
301
- kv_indices,
302
- self.decode_kv_args_table[req.agent_name].dst_kv_ptrs,
303
- chunked_dst_kv_indice,
304
- self.decode_kv_args_table[req.agent_name].gpu_id,
305
- notif,
306
- )
439
+ decode_tp_size = self.decode_kv_args_table[req.agent_name].decode_tp_size
440
+
441
+ if decode_tp_size == self.tp_size:
442
+ kv_xfer_handle = self.send_kvcache(
443
+ req.agent_name,
444
+ kv_indices,
445
+ self.decode_kv_args_table[req.agent_name].dst_kv_ptrs,
446
+ chunked_dst_kv_indice,
447
+ self.decode_kv_args_table[req.agent_name].gpu_id,
448
+ notif,
449
+ )
450
+ else:
451
+ kv_xfer_handle = self.send_kvcache_slice(
452
+ req.agent_name,
453
+ kv_indices,
454
+ self.decode_kv_args_table[req.agent_name].dst_kv_ptrs,
455
+ chunked_dst_kv_indice,
456
+ self.decode_kv_args_table[req.agent_name].gpu_id,
457
+ notif,
458
+ prefill_tp_size=self.tp_size,
459
+ decode_tp_size=decode_tp_size,
460
+ decode_tp_rank=self.decode_kv_args_table[
461
+ req.agent_name
462
+ ].decode_tp_rank,
463
+ dst_kv_item_len=self.decode_kv_args_table[
464
+ req.agent_name
465
+ ].dst_kv_item_len,
466
+ )
467
+
307
468
  handles.append(kv_xfer_handle)
308
469
  # Only the last chunk we need to send the aux data.
309
470
  if is_last:
@@ -454,11 +615,11 @@ class NixlKVReceiver(CommonKVReceiver):
454
615
  mgr: NixlKVManager,
455
616
  bootstrap_addr: str,
456
617
  bootstrap_room: Optional[int] = None,
457
- data_parallel_rank: Optional[int] = None,
618
+ prefill_dp_rank: Optional[int] = None,
458
619
  ):
459
620
  self.started_transfer = False
460
621
  self.conclude_state = None
461
- super().__init__(mgr, bootstrap_addr, bootstrap_room, data_parallel_rank)
622
+ super().__init__(mgr, bootstrap_addr, bootstrap_room, prefill_dp_rank)
462
623
 
463
624
  def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
464
625
  for bootstrap_info in self.bootstrap_infos:
@@ -521,6 +682,9 @@ class NixlKVReceiver(CommonKVReceiver):
521
682
  packed_kv_data_ptrs,
522
683
  packed_aux_data_ptrs,
523
684
  str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
685
+ str(self.kv_mgr.kv_args.decode_tp_size).encode("ascii"),
686
+ str(self.kv_mgr.kv_args.engine_rank).encode("ascii"),
687
+ str(self.kv_mgr.kv_args.kv_item_lens[0]).encode("ascii"),
524
688
  ]
525
689
  )
526
690
 
@@ -23,7 +23,7 @@ import logging
23
23
  import threading
24
24
  from collections import deque
25
25
  from http import HTTPStatus
26
- from typing import TYPE_CHECKING, List, Optional
26
+ from typing import TYPE_CHECKING, List, Optional, Type
27
27
 
28
28
  import torch
29
29
 
@@ -140,8 +140,10 @@ class PrefillBootstrapQueue:
140
140
  kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
141
141
  kv_args.gpu_id = self.scheduler.gpu_id
142
142
 
143
- kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
144
- kv_manager = kv_manager_class(
143
+ kv_manager_class: Type[BaseKVManager] = get_kv_class(
144
+ self.transfer_backend, KVClassType.MANAGER
145
+ )
146
+ kv_manager: BaseKVManager = kv_manager_class(
145
147
  kv_args,
146
148
  DisaggregationMode.PREFILL,
147
149
  self.scheduler.server_args,
@@ -1,21 +1,17 @@
1
1
  from __future__ import annotations
2
2
 
3
- import dataclasses
4
3
  import os
5
4
  import random
6
- import threading
7
- import warnings
8
5
  from collections import deque
9
6
  from contextlib import nullcontext
10
7
  from enum import Enum
11
- from typing import TYPE_CHECKING, List, Optional
8
+ from typing import TYPE_CHECKING, List, Optional, Type, Union
12
9
 
13
10
  import numpy as np
14
- import requests
15
11
  import torch
16
12
  import torch.distributed as dist
17
13
 
18
- from sglang.srt.utils import get_ip, is_npu
14
+ from sglang.srt.utils import is_npu
19
15
 
20
16
  if TYPE_CHECKING:
21
17
  from sglang.srt.managers.schedule_batch import Req
@@ -217,7 +213,9 @@ class KVClassType(Enum):
217
213
  BOOTSTRAP_SERVER = "bootstrap_server"
218
214
 
219
215
 
220
- def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
216
+ def get_kv_class(
217
+ transfer_backend: TransferBackend, class_type: KVClassType
218
+ ) -> Optional[Type]:
221
219
  from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender
222
220
 
223
221
  if transfer_backend == TransferBackend.MOONCAKE:
@@ -305,49 +303,6 @@ def kv_to_page_num(num_kv_indices: int, page_size: int):
305
303
  return (num_kv_indices + page_size - 1) // page_size
306
304
 
307
305
 
308
- #########################
309
- # PDLB Registry
310
- #########################
311
-
312
-
313
- @dataclasses.dataclass
314
- class PDRegistryRequest:
315
- """A request to register a machine itself to the LB."""
316
-
317
- mode: str
318
- registry_url: str
319
- bootstrap_port: Optional[int] = None
320
-
321
- def __post_init__(self):
322
- if self.mode == "prefill" and self.bootstrap_port is None:
323
- raise ValueError("Bootstrap port must be set in PREFILL mode.")
324
- elif self.mode == "decode" and self.bootstrap_port is not None:
325
- raise ValueError("Bootstrap port must not be set in DECODE mode.")
326
- elif self.mode not in ["prefill", "decode"]:
327
- raise ValueError(
328
- f"Invalid mode: {self.mode}. Must be 'prefill' or 'decode'."
329
- )
330
-
331
-
332
- def register_disaggregation_server(
333
- mode: str, server_port: int, bootstrap_port: int, pdlb_url: str
334
- ):
335
- boostrap_port = bootstrap_port if mode == "prefill" else None
336
- registry_request = PDRegistryRequest(
337
- mode=mode,
338
- registry_url=f"http://{get_ip()}:{server_port}",
339
- bootstrap_port=boostrap_port,
340
- )
341
- res = requests.post(
342
- f"{pdlb_url}/register",
343
- json=dataclasses.asdict(registry_request),
344
- )
345
- if res.status_code != 200:
346
- warnings.warn(
347
- f"Failed to register disaggregation server: {res.status_code} {res.text}"
348
- )
349
-
350
-
351
306
  #########################
352
307
  # Misc
353
308
  #########################
@@ -64,6 +64,9 @@ class GraphCaptureContext:
64
64
 
65
65
  TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])
66
66
 
67
+ # use int value instead of ReduceOp.SUM to support torch compile
68
+ REDUCE_OP_SUM = int(torch.distributed.ReduceOp.SUM)
69
+
67
70
 
68
71
  def _split_tensor_dict(
69
72
  tensor_dict: Dict[str, Union[torch.Tensor, Any]]
@@ -489,9 +492,7 @@ class GroupCoordinator:
489
492
 
490
493
  if input_.is_cpu:
491
494
  if is_shm_available(input_.dtype, self.world_size, self.local_size):
492
- torch.ops.sgl_kernel.shm_allreduce(
493
- input_, torch.distributed.ReduceOp.SUM
494
- )
495
+ torch.ops.sgl_kernel.shm_allreduce(input_, REDUCE_OP_SUM)
495
496
  else:
496
497
  torch.distributed.all_reduce(input_, group=self.device_group)
497
498
  return input_
@@ -1586,6 +1587,16 @@ def patch_tensor_parallel_group(tp_group: GroupCoordinator):
1586
1587
  _TP = old_tp_group
1587
1588
 
1588
1589
 
1590
+ def get_world_size():
1591
+ """Return world size for the world group."""
1592
+ return get_world_group().world_size
1593
+
1594
+
1595
+ def get_world_rank():
1596
+ """Return my rank for the world group."""
1597
+ return get_world_group().rank_in_group
1598
+
1599
+
1589
1600
  def get_tensor_model_parallel_world_size():
1590
1601
  """Return world size for the tensor model parallel group."""
1591
1602
  return get_tp_group().world_size
@@ -1596,6 +1607,16 @@ def get_tensor_model_parallel_rank():
1596
1607
  return get_tp_group().rank_in_group
1597
1608
 
1598
1609
 
1610
+ def get_pipeline_model_parallel_world_size():
1611
+ """Return world size for the pipeline model parallel group."""
1612
+ return get_pp_group().world_size
1613
+
1614
+
1615
+ def get_pipeline_model_parallel_rank():
1616
+ """Return my rank for the pipeline model parallel group."""
1617
+ return get_pp_group().rank_in_group
1618
+
1619
+
1599
1620
  def get_moe_expert_parallel_world_size():
1600
1621
  """Return world size for the moe expert parallel group."""
1601
1622
  return get_moe_ep_group().world_size
@@ -33,6 +33,8 @@ import zmq
33
33
  import zmq.asyncio
34
34
  from PIL.Image import Image
35
35
 
36
+ from sglang.srt.tracing.trace import process_tracing_init, trace_set_thread_info
37
+
36
38
  # Fix a bug of Python threading
37
39
  setattr(threading, "_register_atexit", lambda *args, **kwargs: None)
38
40
 
@@ -138,6 +140,12 @@ class Engine(EngineBase):
138
140
  context, zmq.DEALER, self.port_args.rpc_ipc_name, True
139
141
  )
140
142
 
143
+ if server_args.enable_trace:
144
+ process_tracing_init(server_args.oltp_traces_endpoint, "sglang")
145
+ if server_args.disaggregation_mode == "null":
146
+ thread_label = "Tokenizer"
147
+ trace_set_thread_info(thread_label)
148
+
141
149
  def generate(
142
150
  self,
143
151
  # The input prompt. It can be a single prompt or a batch of prompts.
@@ -364,9 +372,9 @@ class Engine(EngineBase):
364
372
  loop = asyncio.get_event_loop()
365
373
  return loop.run_until_complete(self.tokenizer_manager.flush_cache())
366
374
 
367
- def start_profile(self):
375
+ def start_profile(self, **kwargs):
368
376
  loop = asyncio.get_event_loop()
369
- loop.run_until_complete(self.tokenizer_manager.start_profile())
377
+ loop.run_until_complete(self.tokenizer_manager.start_profile(**kwargs))
370
378
 
371
379
  def stop_profile(self):
372
380
  loop = asyncio.get_event_loop()
@@ -655,7 +663,8 @@ def _set_envs_and_config(server_args: ServerArgs):
655
663
  os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "4"
656
664
  os.environ["CUDA_MODULE_LOADING"] = "AUTO"
657
665
  # flashinfer uses this environment variable for various kernels from MoE to quant kernels
658
- os.environ["TRTLLM_ENABLE_PDL"] = "1"
666
+ if os.environ.get("TRTLLM_ENABLE_PDL", "1") != "0":
667
+ os.environ["TRTLLM_ENABLE_PDL"] = "1"
659
668
 
660
669
  # Can also be passed as argument
661
670
  os.environ["SGLANG_RUN_ID"] = (
@@ -673,7 +682,7 @@ def _set_envs_and_config(server_args: ServerArgs):
673
682
  if server_args.attention_backend == "flashinfer":
674
683
  assert_pkg_version(
675
684
  "flashinfer_python",
676
- "0.3.0",
685
+ "0.3.1",
677
686
  "Please uninstall the old version and "
678
687
  "reinstall the latest version by following the instructions "
679
688
  "at https://docs.flashinfer.ai/installation.html.",
@@ -681,7 +690,7 @@ def _set_envs_and_config(server_args: ServerArgs):
681
690
  if _is_cuda and not get_bool_env_var("SGLANG_SKIP_SGL_KERNEL_VERSION_CHECK"):
682
691
  assert_pkg_version(
683
692
  "sgl-kernel",
684
- "0.3.8",
693
+ "0.3.9.post2",
685
694
  "Please reinstall the latest version with `pip install sgl-kernel --force-reinstall`",
686
695
  )
687
696
 
@@ -703,6 +712,24 @@ def _set_envs_and_config(server_args: ServerArgs):
703
712
  mp.set_start_method("spawn", force=True)
704
713
 
705
714
 
715
+ def _init_tokenizer_manager(
716
+ server_args: ServerArgs, port_args: PortArgs
717
+ ) -> TokenizerManager:
718
+ # Launch tokenizer process
719
+ tokenizer_manager = TokenizerManager(server_args, port_args)
720
+
721
+ # Initialize templates
722
+ template_manager = TemplateManager()
723
+ template_manager.initialize_templates(
724
+ tokenizer_manager=tokenizer_manager,
725
+ model_path=server_args.model_path,
726
+ chat_template=server_args.chat_template,
727
+ completion_template=server_args.completion_template,
728
+ )
729
+
730
+ return tokenizer_manager, template_manager
731
+
732
+
706
733
  def _launch_subprocesses(
707
734
  server_args: ServerArgs, port_args: Optional[PortArgs] = None
708
735
  ) -> Tuple[TokenizerManager, TemplateManager, Dict]:
@@ -815,23 +842,15 @@ def _launch_subprocesses(
815
842
  ),
816
843
  )
817
844
  detoken_proc.start()
845
+
846
+ # Init tokenizer manager first, as the bootstrap server is initialized here
818
847
  if server_args.tokenizer_worker_num > 1:
819
848
  # Launch multi-tokenizer router
820
849
  tokenizer_manager = MultiTokenizerRouter(server_args, port_args)
821
-
822
- # Initialize templates
823
850
  template_manager = None
824
851
  else:
825
- # Launch tokenizer process
826
- tokenizer_manager = TokenizerManager(server_args, port_args)
827
-
828
- # Initialize templates
829
- template_manager = TemplateManager()
830
- template_manager.initialize_templates(
831
- tokenizer_manager=tokenizer_manager,
832
- model_path=server_args.model_path,
833
- chat_template=server_args.chat_template,
834
- completion_template=server_args.completion_template,
852
+ tokenizer_manager, template_manager = _init_tokenizer_manager(
853
+ server_args, port_args
835
854
  )
836
855
 
837
856
  # Wait for the model to finish loading
@@ -855,5 +874,7 @@ def _launch_subprocesses(
855
874
 
856
875
  # Assume all schedulers have the same scheduler_info
857
876
  scheduler_info = scheduler_infos[0]
877
+
858
878
  tokenizer_manager.max_req_input_len = scheduler_info["max_req_input_len"]
879
+
859
880
  return tokenizer_manager, template_manager, scheduler_info