sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (245)
  1. sglang/bench_one_batch.py +3 -0
  2. sglang/bench_one_batch_server.py +10 -1
  3. sglang/bench_serving.py +251 -26
  4. sglang/lang/interpreter.py +1 -1
  5. sglang/srt/configs/__init__.py +4 -0
  6. sglang/srt/configs/internvl.py +6 -0
  7. sglang/srt/configs/longcat_flash.py +104 -0
  8. sglang/srt/configs/model_config.py +37 -7
  9. sglang/srt/configs/qwen3_next.py +326 -0
  10. sglang/srt/connector/__init__.py +1 -1
  11. sglang/srt/connector/base_connector.py +1 -2
  12. sglang/srt/connector/redis.py +2 -2
  13. sglang/srt/connector/serde/__init__.py +1 -1
  14. sglang/srt/connector/serde/safe_serde.py +4 -3
  15. sglang/srt/custom_op.py +11 -1
  16. sglang/srt/debug_utils/dump_comparator.py +81 -44
  17. sglang/srt/debug_utils/dump_loader.py +97 -0
  18. sglang/srt/debug_utils/dumper.py +11 -3
  19. sglang/srt/debug_utils/text_comparator.py +73 -11
  20. sglang/srt/disaggregation/ascend/conn.py +75 -0
  21. sglang/srt/disaggregation/base/conn.py +1 -1
  22. sglang/srt/disaggregation/common/conn.py +15 -12
  23. sglang/srt/disaggregation/decode.py +6 -4
  24. sglang/srt/disaggregation/fake/conn.py +1 -1
  25. sglang/srt/disaggregation/mini_lb.py +6 -420
  26. sglang/srt/disaggregation/mooncake/conn.py +18 -10
  27. sglang/srt/disaggregation/nixl/conn.py +180 -16
  28. sglang/srt/disaggregation/prefill.py +6 -4
  29. sglang/srt/disaggregation/utils.py +5 -50
  30. sglang/srt/distributed/parallel_state.py +94 -58
  31. sglang/srt/entrypoints/engine.py +34 -14
  32. sglang/srt/entrypoints/http_server.py +172 -47
  33. sglang/srt/entrypoints/openai/protocol.py +63 -3
  34. sglang/srt/entrypoints/openai/serving_base.py +6 -2
  35. sglang/srt/entrypoints/openai/serving_chat.py +34 -19
  36. sglang/srt/entrypoints/openai/serving_completions.py +10 -4
  37. sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
  38. sglang/srt/entrypoints/openai/serving_responses.py +7 -4
  39. sglang/srt/eplb/eplb_manager.py +28 -4
  40. sglang/srt/eplb/expert_distribution.py +55 -15
  41. sglang/srt/eplb/expert_location.py +8 -3
  42. sglang/srt/eplb/expert_location_updater.py +1 -1
  43. sglang/srt/function_call/ebnf_composer.py +11 -9
  44. sglang/srt/function_call/glm4_moe_detector.py +1 -1
  45. sglang/srt/function_call/gpt_oss_detector.py +1 -1
  46. sglang/srt/function_call/qwen3_coder_detector.py +1 -1
  47. sglang/srt/hf_transformers_utils.py +12 -0
  48. sglang/srt/layers/activation.py +44 -9
  49. sglang/srt/layers/attention/aiter_backend.py +93 -68
  50. sglang/srt/layers/attention/ascend_backend.py +250 -112
  51. sglang/srt/layers/attention/fla/chunk.py +242 -0
  52. sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
  53. sglang/srt/layers/attention/fla/chunk_o.py +178 -0
  54. sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
  55. sglang/srt/layers/attention/fla/cumsum.py +300 -0
  56. sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
  57. sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
  58. sglang/srt/layers/attention/fla/index.py +37 -0
  59. sglang/srt/layers/attention/fla/l2norm.py +150 -0
  60. sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
  61. sglang/srt/layers/attention/fla/op.py +66 -0
  62. sglang/srt/layers/attention/fla/solve_tril.py +465 -0
  63. sglang/srt/layers/attention/fla/utils.py +331 -0
  64. sglang/srt/layers/attention/fla/wy_fast.py +158 -0
  65. sglang/srt/layers/attention/flashinfer_backend.py +6 -4
  66. sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
  67. sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
  68. sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
  69. sglang/srt/layers/attention/intel_amx_backend.py +3 -0
  70. sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
  71. sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
  72. sglang/srt/layers/attention/mamba/mamba.py +64 -0
  73. sglang/srt/layers/attention/torch_native_backend.py +12 -6
  74. sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
  75. sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
  76. sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
  77. sglang/srt/layers/communicator.py +45 -7
  78. sglang/srt/layers/layernorm.py +54 -12
  79. sglang/srt/layers/logits_processor.py +10 -3
  80. sglang/srt/layers/moe/__init__.py +2 -1
  81. sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
  82. sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
  83. sglang/srt/layers/moe/ep_moe/layer.py +110 -49
  84. sglang/srt/layers/moe/fused_moe_native.py +5 -3
  85. sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
  86. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
  87. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
  88. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
  89. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  90. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
  91. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
  92. sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
  93. sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
  94. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
  95. sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
  96. sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
  97. sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
  98. sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
  99. sglang/srt/layers/moe/moe_runner/base.py +274 -1
  100. sglang/srt/layers/moe/moe_runner/runner.py +80 -0
  101. sglang/srt/layers/moe/moe_runner/triton.py +448 -0
  102. sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
  103. sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
  104. sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
  105. sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
  106. sglang/srt/layers/moe/topk.py +43 -12
  107. sglang/srt/layers/moe/utils.py +6 -5
  108. sglang/srt/layers/quantization/awq.py +19 -7
  109. sglang/srt/layers/quantization/base_config.py +11 -6
  110. sglang/srt/layers/quantization/blockwise_int8.py +38 -27
  111. sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
  112. sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
  113. sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
  114. sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
  115. sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
  116. sglang/srt/layers/quantization/fp8.py +76 -47
  117. sglang/srt/layers/quantization/fp8_utils.py +43 -29
  118. sglang/srt/layers/quantization/gptq.py +25 -17
  119. sglang/srt/layers/quantization/modelopt_quant.py +107 -40
  120. sglang/srt/layers/quantization/moe_wna16.py +21 -18
  121. sglang/srt/layers/quantization/mxfp4.py +77 -45
  122. sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
  123. sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
  124. sglang/srt/layers/quantization/quark/utils.py +97 -0
  125. sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
  126. sglang/srt/layers/quantization/unquant.py +135 -47
  127. sglang/srt/layers/quantization/utils.py +13 -0
  128. sglang/srt/layers/quantization/w4afp8.py +60 -42
  129. sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
  130. sglang/srt/layers/quantization/w8a8_int8.py +83 -41
  131. sglang/srt/layers/rocm_linear_utils.py +44 -0
  132. sglang/srt/layers/rotary_embedding.py +28 -19
  133. sglang/srt/layers/sampler.py +29 -5
  134. sglang/srt/lora/backend/base_backend.py +50 -8
  135. sglang/srt/lora/backend/triton_backend.py +90 -2
  136. sglang/srt/lora/layers.py +32 -0
  137. sglang/srt/lora/lora.py +4 -1
  138. sglang/srt/lora/lora_manager.py +35 -112
  139. sglang/srt/lora/mem_pool.py +24 -10
  140. sglang/srt/lora/utils.py +18 -9
  141. sglang/srt/managers/cache_controller.py +242 -278
  142. sglang/srt/managers/data_parallel_controller.py +30 -15
  143. sglang/srt/managers/detokenizer_manager.py +13 -2
  144. sglang/srt/managers/disagg_service.py +46 -0
  145. sglang/srt/managers/io_struct.py +160 -11
  146. sglang/srt/managers/mm_utils.py +6 -1
  147. sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
  148. sglang/srt/managers/schedule_batch.py +27 -44
  149. sglang/srt/managers/schedule_policy.py +4 -3
  150. sglang/srt/managers/scheduler.py +90 -115
  151. sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
  152. sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
  153. sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
  154. sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
  155. sglang/srt/managers/template_manager.py +3 -3
  156. sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
  157. sglang/srt/managers/tokenizer_manager.py +41 -477
  158. sglang/srt/managers/tp_worker.py +16 -4
  159. sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
  160. sglang/srt/mem_cache/allocator.py +1 -1
  161. sglang/srt/mem_cache/chunk_cache.py +1 -1
  162. sglang/srt/mem_cache/hicache_storage.py +24 -22
  163. sglang/srt/mem_cache/hiradix_cache.py +184 -101
  164. sglang/srt/mem_cache/lora_radix_cache.py +1 -1
  165. sglang/srt/mem_cache/memory_pool.py +324 -41
  166. sglang/srt/mem_cache/memory_pool_host.py +25 -18
  167. sglang/srt/mem_cache/radix_cache.py +5 -6
  168. sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
  169. sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
  170. sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
  171. sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
  172. sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
  173. sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
  174. sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
  175. sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
  176. sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
  177. sglang/srt/mem_cache/swa_radix_cache.py +1 -3
  178. sglang/srt/metrics/collector.py +484 -63
  179. sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
  180. sglang/srt/metrics/utils.py +48 -0
  181. sglang/srt/model_executor/cpu_graph_runner.py +640 -0
  182. sglang/srt/model_executor/cuda_graph_runner.py +13 -5
  183. sglang/srt/model_executor/forward_batch_info.py +72 -18
  184. sglang/srt/model_executor/model_runner.py +189 -31
  185. sglang/srt/model_loader/__init__.py +9 -3
  186. sglang/srt/model_loader/loader.py +33 -28
  187. sglang/srt/model_loader/utils.py +12 -0
  188. sglang/srt/model_loader/weight_utils.py +2 -1
  189. sglang/srt/models/deepseek_v2.py +311 -50
  190. sglang/srt/models/gemma3n_mm.py +1 -1
  191. sglang/srt/models/glm4_moe.py +10 -1
  192. sglang/srt/models/glm4v.py +4 -2
  193. sglang/srt/models/gpt_oss.py +5 -18
  194. sglang/srt/models/internvl.py +28 -0
  195. sglang/srt/models/llama4.py +9 -0
  196. sglang/srt/models/llama_eagle3.py +17 -0
  197. sglang/srt/models/longcat_flash.py +1026 -0
  198. sglang/srt/models/longcat_flash_nextn.py +699 -0
  199. sglang/srt/models/minicpmv.py +165 -3
  200. sglang/srt/models/mllama4.py +25 -0
  201. sglang/srt/models/opt.py +637 -0
  202. sglang/srt/models/qwen2.py +33 -3
  203. sglang/srt/models/qwen2_5_vl.py +90 -42
  204. sglang/srt/models/qwen2_moe.py +79 -14
  205. sglang/srt/models/qwen3.py +8 -2
  206. sglang/srt/models/qwen3_moe.py +39 -8
  207. sglang/srt/models/qwen3_next.py +1039 -0
  208. sglang/srt/models/qwen3_next_mtp.py +109 -0
  209. sglang/srt/models/torch_native_llama.py +1 -1
  210. sglang/srt/models/transformers.py +1 -1
  211. sglang/srt/multimodal/processors/base_processor.py +4 -2
  212. sglang/srt/multimodal/processors/glm4v.py +9 -9
  213. sglang/srt/multimodal/processors/internvl.py +141 -129
  214. sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
  215. sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
  216. sglang/srt/sampling/sampling_batch_info.py +18 -15
  217. sglang/srt/server_args.py +297 -79
  218. sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
  219. sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
  220. sglang/srt/speculative/eagle_worker.py +216 -120
  221. sglang/srt/speculative/spec_info.py +5 -0
  222. sglang/srt/speculative/standalone_worker.py +109 -0
  223. sglang/srt/utils.py +37 -2
  224. sglang/srt/weight_sync/utils.py +1 -1
  225. sglang/test/attention/test_trtllm_mla_backend.py +181 -8
  226. sglang/test/few_shot_gsm8k.py +1 -0
  227. sglang/test/runners.py +4 -0
  228. sglang/test/test_cutlass_moe.py +24 -6
  229. sglang/test/test_cutlass_w4a8_moe.py +24 -9
  230. sglang/test/test_disaggregation_utils.py +66 -0
  231. sglang/test/test_utils.py +25 -1
  232. sglang/utils.py +5 -0
  233. sglang/version.py +1 -1
  234. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
  235. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
  236. sglang/srt/disaggregation/launch_lb.py +0 -131
  237. sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
  238. /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
  239. /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
  240. /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
  241. /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
  242. /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
  243. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
  244. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
  245. {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/disaggregation/nixl/conn.py

@@ -78,6 +78,9 @@ class KVArgsRegisterInfo:
     dst_kv_ptrs: list[int]
     dst_aux_ptrs: list[int]
     gpu_id: int
+    decode_tp_size: int
+    decode_tp_rank: int
+    dst_kv_item_len: int

     @classmethod
     def from_zmq(cls, msg: List[bytes]):

@@ -90,6 +93,9 @@ class KVArgsRegisterInfo:
             dst_kv_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
             dst_aux_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])),
             gpu_id=int(msg[7].decode("ascii")),
+            decode_tp_size=int(msg[8].decode("ascii")),
+            decode_tp_rank=int(msg[9].decode("ascii")),
+            dst_kv_item_len=int(msg[10].decode("ascii")),
         )


@@ -166,7 +172,7 @@ class NixlKVManager(CommonKVManager):
             self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens
         ):
             kv_addrs.append((kv_data_ptr, kv_data_len, self.kv_args.gpu_id, ""))
-        self.kv_descs = self.agent.register_memory(kv_addrs, "VRAM", is_sorted=False)
+        self.kv_descs = self.agent.register_memory(kv_addrs, "VRAM")
         logger.debug(f"Register kv tensors, len(kv_addr)= {len(kv_addrs)}")
         if not self.kv_descs:
             raise Exception("NIXL memory registration failed for kv tensors")

@@ -175,7 +181,7 @@ class NixlKVManager(CommonKVManager):
             self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
         ):
             aux_addrs.append((aux_data_ptr, aux_data_len, 0, ""))
-        self.aux_descs = self.agent.register_memory(aux_addrs, "DRAM", is_sorted=False)
+        self.aux_descs = self.agent.register_memory(aux_addrs, "DRAM")
         logger.debug(f"Register aux tensors, len(aux_addrs)= {len(aux_addrs)}")
         if not self.aux_descs:
             raise Exception("NIXL memory registration failed for aux tensors")

@@ -222,8 +228,8 @@ class NixlKVManager(CommonKVManager):
         logger.debug(
             f"len(src_addrs): before group: {len(prefill_kv_indices)}, after group: {len(src_addrs)}"
         )
-        src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM", is_sorted=False)
-        dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM", is_sorted=False)
+        src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM")
+        dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM")
         # Transfer data
         xfer_handle = self.agent.initialize_xfer(
             "WRITE",

@@ -239,6 +245,140 @@ class NixlKVManager(CommonKVManager):
             raise Exception("KVSender failed to post transfer")
         return xfer_handle

+    def send_kvcache_slice(
+        self,
+        peer_name: str,
+        prefill_kv_indices: npt.NDArray[np.int32],
+        dst_kv_ptrs: list[int],
+        dst_kv_indices: npt.NDArray[np.int32],
+        dst_gpu_id: int,
+        notif: str,
+        prefill_tp_size: int,
+        decode_tp_size: int,
+        decode_tp_rank: int,
+        dst_kv_item_len: int,
+    ):
+        # Get configuration from kv_args
+        local_tp_rank_in_group = self.kv_args.engine_rank % prefill_tp_size
+        dst_tp_rank_in_group = decode_tp_rank % decode_tp_size
+        num_kv_heads = self.kv_args.kv_head_num
+
+        # Calculate head distribution
+        src_heads_per_rank = num_kv_heads
+        dst_heads_per_rank = num_kv_heads * prefill_tp_size // decode_tp_size
+
+        src_kv_item_len = self.kv_args.kv_item_lens[0]
+        page_size = self.kv_args.page_size
+
+        bytes_per_head_slice_to_send = (
+            dst_kv_item_len // page_size // dst_heads_per_rank
+        )
+
+        # Determine which heads to send
+        if prefill_tp_size > decode_tp_size:
+            # Multiple prefill ranks to one decode rank
+            src_head_start_offset = 0
+            num_heads_to_send = src_heads_per_rank
+            dst_head_start_offset = local_tp_rank_in_group * src_heads_per_rank
+        else:
+            # Send KVCache from 1 prefill instance to multiple decode instances
+            src_head_start_offset = (
+                dst_tp_rank_in_group * dst_heads_per_rank
+            ) % src_heads_per_rank
+            num_heads_to_send = dst_heads_per_rank
+            dst_head_start_offset = 0
+
+        # Create transfer descriptors
+        src_addrs = []
+        dst_addrs = []
+
+        bytes_per_token_on_prefill = src_kv_item_len // page_size
+        bytes_per_token_on_decode = dst_kv_item_len // page_size
+
+        num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
+        src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
+        src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
+        dst_k_ptrs = dst_kv_ptrs[0 : len(src_k_ptrs)]
+        dst_v_ptrs = dst_kv_ptrs[num_kv_layers : num_kv_layers + len(src_v_ptrs)]
+
+        # Calculate precise byte offset and length for the sub-slice within the token
+        src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send
+        dst_head_slice_offset = dst_head_start_offset * bytes_per_head_slice_to_send
+        heads_bytes_per_token_to_send = num_heads_to_send * bytes_per_head_slice_to_send
+
+        src_dst_ptr_pairs = [
+            (
+                src_k_ptrs[layer_id],
+                dst_k_ptrs[layer_id],
+            )
+            for layer_id in range(len(src_k_ptrs))
+        ] + [
+            (
+                src_v_ptrs[layer_id],
+                dst_v_ptrs[layer_id],
+            )
+            for layer_id in range(len(src_v_ptrs))
+        ]
+
+        src_addrs = []
+        dst_addrs = []
+
+        # Calculate strides for a single token slot
+        bytes_per_token_on_prefill = src_kv_item_len // page_size
+        bytes_per_token_on_decode = dst_kv_item_len // page_size
+
+        for src_ptr, dst_ptr in src_dst_ptr_pairs:
+            for i in range(len(prefill_kv_indices)):
+                prefill_page_idx = int(prefill_kv_indices[i])
+                decode_page_idx = int(dst_kv_indices[i])
+
+                # Get the starting addresses for the current src and dst pages
+                src_page_start_addr = src_ptr + prefill_page_idx * src_kv_item_len
+                dst_page_start_addr = dst_ptr + decode_page_idx * dst_kv_item_len
+
+                # Iterate through each valid token slot within the current page
+                for token_slot_in_page in range(page_size):
+                    # Calculate the start address of the current token slot
+                    src_token_slot_start_addr = (
+                        src_page_start_addr
+                        + token_slot_in_page * bytes_per_token_on_prefill
+                    )
+                    dst_token_slot_start_addr = (
+                        dst_page_start_addr
+                        + token_slot_in_page * bytes_per_token_on_decode
+                    )
+
+                    # Calculate final src and dst addresses by applying head-slice offsets
+                    src_slice_addr = src_token_slot_start_addr + src_head_slice_offset
+                    dst_slice_addr = dst_token_slot_start_addr + dst_head_slice_offset
+
+                    src_addrs.append(
+                        (
+                            src_slice_addr,
+                            heads_bytes_per_token_to_send,
+                            self.kv_args.gpu_id,
+                        )
+                    )
+                    dst_addrs.append(
+                        (dst_slice_addr, heads_bytes_per_token_to_send, dst_gpu_id)
+                    )
+
+        # Use NIXL agent for transfer
+        src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM")
+        dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM")
+
+        xfer_handle = self.agent.initialize_xfer(
+            "WRITE", src_descs, dst_descs, peer_name, notif.encode("ascii")
+        )
+        if not xfer_handle:
+            raise Exception("Failed to create sliced KV transfer")
+
+        state = self.agent.transfer(xfer_handle)
+        if state == "ERR":
+            raise Exception("Failed to post sliced KV transfer")
+
+        return xfer_handle
+
     def send_aux(
         self,
         peer_name: str,
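Note: the head-slicing arithmetic in send_kvcache_slice is easiest to follow with concrete numbers. The standalone sketch below recomputes the offsets for a hypothetical setup (8 KV heads per prefill rank, prefill TP 2, decode TP 4, fp16 heads of dimension 128, one token per page); the variable names mirror the method above, but the numbers are illustrative only and not part of the diff.

# Illustrative recomputation of the send_kvcache_slice offsets (hypothetical sizes).
num_kv_heads = 8                       # kv_args.kv_head_num on one prefill rank
prefill_tp_size, decode_tp_size = 2, 4
decode_tp_rank = 3
page_size = 1
head_bytes = 128 * 2                   # head_dim 128, fp16

src_heads_per_rank = num_kv_heads                                          # 8
dst_heads_per_rank = num_kv_heads * prefill_tp_size // decode_tp_size      # 4

src_kv_item_len = page_size * src_heads_per_rank * head_bytes              # 2048 bytes per page (prefill)
dst_kv_item_len = page_size * dst_heads_per_rank * head_bytes              # 1024 bytes per page (decode)
bytes_per_head_slice = dst_kv_item_len // page_size // dst_heads_per_rank  # 256

# prefill TP < decode TP: one prefill rank fans out to several decode ranks,
# so each decode rank receives a contiguous slice of this rank's heads.
dst_tp_rank_in_group = decode_tp_rank % decode_tp_size                     # 3
src_head_start_offset = (dst_tp_rank_in_group * dst_heads_per_rank) % src_heads_per_rank  # 4
num_heads_to_send = dst_heads_per_rank                                      # 4

src_head_slice_offset = src_head_start_offset * bytes_per_head_slice        # byte offset inside each token slot
heads_bytes_per_token_to_send = num_heads_to_send * bytes_per_head_slice    # contiguous bytes sent per token slot
print(src_head_slice_offset, heads_bytes_per_token_to_send)                 # 1024 1024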
@@ -255,8 +395,8 @@ class NixlKVManager(CommonKVManager):
         decode_aux_addr = dst_aux_ptrs[0] + dst_aux_index * aux_item_len
         src_addrs = [(prefill_aux_addr, aux_item_len, 0)]
         dst_addrs = [(decode_aux_addr, aux_item_len, 0)]
-        src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM", is_sorted=False)
-        dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM", is_sorted=False)
+        src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM")
+        dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM")
         # Transfer data
         xfer_handle = self.agent.initialize_xfer(
             "WRITE",

@@ -296,14 +436,35 @@ class NixlKVManager(CommonKVManager):
         assert req.agent_name in self.decode_kv_args_table

         notif = "_".join([str(req.room), "kv", str(chunk_id), str(int(is_last))])
-        kv_xfer_handle = self.send_kvcache(
-            req.agent_name,
-            kv_indices,
-            self.decode_kv_args_table[req.agent_name].dst_kv_ptrs,
-            chunked_dst_kv_indice,
-            self.decode_kv_args_table[req.agent_name].gpu_id,
-            notif,
-        )
+        decode_tp_size = self.decode_kv_args_table[req.agent_name].decode_tp_size
+
+        if decode_tp_size == self.tp_size:
+            kv_xfer_handle = self.send_kvcache(
+                req.agent_name,
+                kv_indices,
+                self.decode_kv_args_table[req.agent_name].dst_kv_ptrs,
+                chunked_dst_kv_indice,
+                self.decode_kv_args_table[req.agent_name].gpu_id,
+                notif,
+            )
+        else:
+            kv_xfer_handle = self.send_kvcache_slice(
+                req.agent_name,
+                kv_indices,
+                self.decode_kv_args_table[req.agent_name].dst_kv_ptrs,
+                chunked_dst_kv_indice,
+                self.decode_kv_args_table[req.agent_name].gpu_id,
+                notif,
+                prefill_tp_size=self.tp_size,
+                decode_tp_size=decode_tp_size,
+                decode_tp_rank=self.decode_kv_args_table[
+                    req.agent_name
+                ].decode_tp_rank,
+                dst_kv_item_len=self.decode_kv_args_table[
+                    req.agent_name
+                ].dst_kv_item_len,
+            )
+
         handles.append(kv_xfer_handle)
         # Only the last chunk we need to send the aux data.
         if is_last:

@@ -454,11 +615,11 @@ class NixlKVReceiver(CommonKVReceiver):
         mgr: NixlKVManager,
         bootstrap_addr: str,
         bootstrap_room: Optional[int] = None,
-        data_parallel_rank: Optional[int] = None,
+        prefill_dp_rank: Optional[int] = None,
     ):
         self.started_transfer = False
         self.conclude_state = None
-        super().__init__(mgr, bootstrap_addr, bootstrap_room, data_parallel_rank)
+        super().__init__(mgr, bootstrap_addr, bootstrap_room, prefill_dp_rank)

     def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
         for bootstrap_info in self.bootstrap_infos:

@@ -521,6 +682,9 @@ class NixlKVReceiver(CommonKVReceiver):
                 packed_kv_data_ptrs,
                 packed_aux_data_ptrs,
                 str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
+                str(self.kv_mgr.kv_args.decode_tp_size).encode("ascii"),
+                str(self.kv_mgr.kv_args.engine_rank).encode("ascii"),
+                str(self.kv_mgr.kv_args.kv_item_lens[0]).encode("ascii"),
             ]
         )

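Note: the hunk above and the earlier from_zmq hunk are the two ends of the same handshake: the decode-side receiver appends three extra ASCII frames to its registration message, and the prefill-side KVArgsRegisterInfo.from_zmq reads them back as msg[8], msg[9], and msg[10]. A toy round trip of just those frames (standalone sketch, not sglang code):

# Toy round trip of the three registration frames added in this diff.
def pack_extra_frames(decode_tp_size: int, engine_rank: int, kv_item_len: int) -> list:
    return [
        str(decode_tp_size).encode("ascii"),
        str(engine_rank).encode("ascii"),
        str(kv_item_len).encode("ascii"),
    ]

def unpack_extra_frames(frames: list) -> tuple:
    # Mirrors from_zmq, which parses decode_tp_size, decode_tp_rank, and
    # dst_kv_item_len from the trailing message frames.
    return tuple(int(frame.decode("ascii")) for frame in frames)

frames = pack_extra_frames(decode_tp_size=4, engine_rank=1, kv_item_len=1024)
assert unpack_extra_frames(frames) == (4, 1, 1024)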
sglang/srt/disaggregation/prefill.py

@@ -23,7 +23,7 @@ import logging
 import threading
 from collections import deque
 from http import HTTPStatus
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, List, Optional, Type

 import torch


@@ -140,8 +140,10 @@ class PrefillBootstrapQueue:
         kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
         kv_args.gpu_id = self.scheduler.gpu_id

-        kv_manager_class = get_kv_class(self.transfer_backend, KVClassType.MANAGER)
-        kv_manager = kv_manager_class(
+        kv_manager_class: Type[BaseKVManager] = get_kv_class(
+            self.transfer_backend, KVClassType.MANAGER
+        )
+        kv_manager: BaseKVManager = kv_manager_class(
             kv_args,
             DisaggregationMode.PREFILL,
             self.scheduler.server_args,

@@ -567,7 +569,7 @@ class SchedulerDisaggregationPrefillMixin:
             # Move the chunked request out of the batch so that we can merge
             # only finished requests to running_batch.
             self.last_batch.filter_batch(chunked_req_to_exclude=self.chunked_req)
-            self.tree_cache.cache_unfinished_req(self.chunked_req)
+            self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True)
             if self.enable_overlap:
                 # Delay KV transfer to process_batch_result_disagg_prefill when overlap is enabled to ensure results are resolved
                 self.chunked_req.tmp_end_idx = min(
sglang/srt/disaggregation/utils.py

@@ -1,21 +1,17 @@
 from __future__ import annotations

-import dataclasses
 import os
 import random
-import threading
-import warnings
 from collections import deque
 from contextlib import nullcontext
 from enum import Enum
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, List, Optional, Type, Union

 import numpy as np
-import requests
 import torch
 import torch.distributed as dist

-from sglang.srt.utils import get_ip, is_npu
+from sglang.srt.utils import is_npu

 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import Req

@@ -217,7 +213,9 @@ class KVClassType(Enum):
     BOOTSTRAP_SERVER = "bootstrap_server"


-def get_kv_class(transfer_backend: TransferBackend, class_type: KVClassType):
+def get_kv_class(
+    transfer_backend: TransferBackend, class_type: KVClassType
+) -> Optional[Type]:
     from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender

     if transfer_backend == TransferBackend.MOONCAKE:

@@ -305,49 +303,6 @@ def kv_to_page_num(num_kv_indices: int, page_size: int):
     return (num_kv_indices + page_size - 1) // page_size


-#########################
-# PDLB Registry
-#########################
-
-
-@dataclasses.dataclass
-class PDRegistryRequest:
-    """A request to register a machine itself to the LB."""
-
-    mode: str
-    registry_url: str
-    bootstrap_port: Optional[int] = None
-
-    def __post_init__(self):
-        if self.mode == "prefill" and self.bootstrap_port is None:
-            raise ValueError("Bootstrap port must be set in PREFILL mode.")
-        elif self.mode == "decode" and self.bootstrap_port is not None:
-            raise ValueError("Bootstrap port must not be set in DECODE mode.")
-        elif self.mode not in ["prefill", "decode"]:
-            raise ValueError(
-                f"Invalid mode: {self.mode}. Must be 'prefill' or 'decode'."
-            )
-
-
-def register_disaggregation_server(
-    mode: str, server_port: int, bootstrap_port: int, pdlb_url: str
-):
-    boostrap_port = bootstrap_port if mode == "prefill" else None
-    registry_request = PDRegistryRequest(
-        mode=mode,
-        registry_url=f"http://{get_ip()}:{server_port}",
-        bootstrap_port=boostrap_port,
-    )
-    res = requests.post(
-        f"{pdlb_url}/register",
-        json=dataclasses.asdict(registry_request),
-    )
-    if res.status_code != 200:
-        warnings.warn(
-            f"Failed to register disaggregation server: {res.status_code} {res.text}"
-        )
-
-
 #########################
 # Misc
 #########################
sglang/srt/distributed/parallel_state.py

@@ -43,6 +43,7 @@ from sglang.srt.utils import (
     direct_register_custom_op,
     get_bool_env_var,
     get_int_env_var,
+    is_cpu,
     is_cuda_alike,
     is_hip,
     is_npu,

@@ -51,6 +52,9 @@ from sglang.srt.utils import (
 )

 _is_npu = is_npu()
+_is_cpu = is_cpu()
+
+IS_ONE_DEVICE_PER_PROCESS = get_bool_env_var("SGLANG_ONE_DEVICE_PER_PROCESS")


 @dataclass

@@ -60,6 +64,9 @@ class GraphCaptureContext:

 TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])

+# use int value instead of ReduceOp.SUM to support torch compile
+REDUCE_OP_SUM = int(torch.distributed.ReduceOp.SUM)
+

 def _split_tensor_dict(
     tensor_dict: Dict[str, Union[torch.Tensor, Any]]

@@ -223,10 +230,12 @@ class GroupCoordinator:
         use_message_queue_broadcaster: bool = False,
         group_name: Optional[str] = None,
     ):
+        # Set group info
         group_name = group_name or "anonymous"
         self.unique_name = _get_unique_name(group_name)
         _register_group(self)

+        # Set rank info
         self.rank = torch.distributed.get_rank()
         self.local_rank = local_rank
         self.device_group = None

@@ -250,14 +259,16 @@ class GroupCoordinator:
         assert self.cpu_group is not None
         assert self.device_group is not None

+        device_id = 0 if IS_ONE_DEVICE_PER_PROCESS else local_rank
         if is_cuda_alike():
-            self.device = torch.device(f"cuda:{local_rank}")
+            self.device = torch.device(f"cuda:{device_id}")
         elif _is_npu:
-            self.device = torch.device(f"npu:{local_rank}")
+            self.device = torch.device(f"npu:{device_id}")
         else:
             self.device = torch.device("cpu")
         self.device_module = torch.get_device_module(self.device)

+        # Import communicators
         self.use_pynccl = use_pynccl
         self.use_pymscclpp = use_pymscclpp
         self.use_custom_allreduce = use_custom_allreduce

@@ -270,6 +281,9 @@ class GroupCoordinator:
             from sglang.srt.distributed.device_communicators.custom_all_reduce import (
                 CustomAllreduce,
             )
+            from sglang.srt.distributed.device_communicators.pymscclpp import (
+                PyMscclppCommunicator,
+            )
             from sglang.srt.distributed.device_communicators.pynccl import (
                 PyNcclCommunicator,
             )

@@ -287,10 +301,6 @@ class GroupCoordinator:
                 device=self.device,
             )

-        from sglang.srt.distributed.device_communicators.pymscclpp import (
-            PyMscclppCommunicator,
-        )
-
         self.pymscclpp_comm: Optional[PyMscclppCommunicator] = None
         if use_pymscclpp and self.world_size > 1:
             self.pymscclpp_comm = PyMscclppCommunicator(

@@ -325,30 +335,30 @@ class GroupCoordinator:
             except Exception as e:
                 logger.warning(f"Failed to initialize QuickAllReduce: {e}")

+        # Create communicator for other hardware backends
         from sglang.srt.distributed.device_communicators.hpu_communicator import (
             HpuCommunicator,
         )
+        from sglang.srt.distributed.device_communicators.npu_communicator import (
+            NpuCommunicator,
+        )
+        from sglang.srt.distributed.device_communicators.xpu_communicator import (
+            XpuCommunicator,
+        )

         self.hpu_communicator: Optional[HpuCommunicator] = None
         if use_hpu_communicator and self.world_size > 1:
             self.hpu_communicator = HpuCommunicator(group=self.device_group)

-        from sglang.srt.distributed.device_communicators.xpu_communicator import (
-            XpuCommunicator,
-        )
-
         self.xpu_communicator: Optional[XpuCommunicator] = None
         if use_xpu_communicator and self.world_size > 1:
             self.xpu_communicator = XpuCommunicator(group=self.device_group)

-        from sglang.srt.distributed.device_communicators.npu_communicator import (
-            NpuCommunicator,
-        )
-
         self.npu_communicator: Optional[NpuCommunicator] = None
         if use_npu_communicator and self.world_size > 1:
             self.npu_communicator = NpuCommunicator(group=self.device_group)

+        # Create message queue
         from sglang.srt.distributed.device_communicators.shm_broadcast import (
             MessageQueue,
         )

@@ -482,9 +492,7 @@ class GroupCoordinator:

         if input_.is_cpu:
             if is_shm_available(input_.dtype, self.world_size, self.local_size):
-                torch.ops.sgl_kernel.shm_allreduce(
-                    input_, torch.distributed.ReduceOp.SUM
-                )
+                torch.ops.sgl_kernel.shm_allreduce(input_, REDUCE_OP_SUM)
             else:
                 torch.distributed.all_reduce(input_, group=self.device_group)
             return input_

@@ -848,6 +856,11 @@ class GroupCoordinator:
         )
         return obj_list

+    def all_gather_object(self, obj: Any) -> List[Any]:
+        objs = [None] * self.world_size
+        torch.distributed.all_gather_object(objs, obj, group=self.cpu_group)
+        return objs
+
     def send_object(self, obj: Any, dst: int) -> None:
         """Send the input object list to the destination rank."""
         """NOTE: `dst` is the local rank of the destination rank."""

@@ -867,17 +880,16 @@ class GroupCoordinator:
         size_tensor = torch.tensor(
             [object_tensor.numel()],
             dtype=torch.long,
-            device=torch.cuda.current_device(),
+            device="cpu",
         )
-
         # Send object size
-        torch.distributed.send(
-            size_tensor, dst=self.ranks[dst], group=self.device_group
-        )
+        torch.distributed.send(size_tensor, dst=self.ranks[dst], group=self.cpu_group)

         # Send object
         torch.distributed.send(
-            object_tensor, dst=self.ranks[dst], group=self.device_group
+            object_tensor,
+            dst=self.ranks[dst],
+            group=self.device_group,
         )

         return None

@@ -892,13 +904,11 @@ class GroupCoordinator:
             src != self.rank_in_group
         ), "Invalid source rank. Source rank is the same as the current rank."

-        size_tensor = torch.empty(
-            1, dtype=torch.long, device=torch.cuda.current_device()
-        )
+        size_tensor = torch.empty(1, dtype=torch.long, device="cpu")

         # Receive object size
         rank_size = torch.distributed.recv(
-            size_tensor, src=self.ranks[src], group=self.device_group
+            size_tensor, src=self.ranks[src], group=self.cpu_group
         )

         # Tensor to receive serialized objects into.

@@ -916,7 +926,7 @@ class GroupCoordinator:
             rank_object == rank_size
         ), "Received object sender rank does not match the size sender rank."

-        obj = pickle.loads(object_tensor.cpu().numpy().tobytes())
+        obj = pickle.loads(object_tensor.cpu().numpy())

         return obj

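Note: a recurring theme in the GroupCoordinator hunks above is moving small control-plane traffic (object sizes, gathered Python objects) onto the CPU/gloo group. The new all_gather_object helper is a thin wrapper over torch.distributed.all_gather_object; the sketch below exercises that underlying call in a single-process gloo group (assumed standalone setup, not sglang code; the address and port are placeholders).

import torch.distributed as dist

# Standalone single-process group, only to exercise the call; sglang's
# GroupCoordinator passes its own cpu_group instead of the default group.
dist.init_process_group(
    backend="gloo",
    init_method="tcp://127.0.0.1:29512",
    rank=0,
    world_size=1,
)

objs = [None] * dist.get_world_size()
# Arbitrary picklable Python objects are gathered over the CPU (gloo) group.
dist.all_gather_object(objs, {"rank": dist.get_rank(), "ready": True})
print(objs)  # [{'rank': 0, 'ready': True}]

dist.destroy_process_group()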
@@ -1449,43 +1459,49 @@ def initialize_model_parallel(
         _PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False

     moe_ep_size = expert_model_parallel_size
-
     moe_tp_size = tensor_model_parallel_size // moe_ep_size
+
     global _MOE_EP
     assert _MOE_EP is None, "expert model parallel group is already initialized"
-    group_ranks = []
-    for i in range(num_tensor_model_parallel_groups):
-        for j in range(moe_tp_size):
-            st = i * tensor_model_parallel_size + j
-            en = (i + 1) * tensor_model_parallel_size + j
-            ranks = list(range(st, en, moe_tp_size))
-            group_ranks.append(ranks)

-    _MOE_EP = init_model_parallel_group(
-        group_ranks,
-        get_world_group().local_rank,
-        backend,
-        use_custom_allreduce=False,
-        group_name="moe_ep",
-    )
+    if moe_ep_size == tensor_model_parallel_size:
+        _MOE_EP = _TP
+    else:
+        # TODO(ch-wan): use split_group to save memory
+        group_ranks = []
+        for i in range(num_tensor_model_parallel_groups):
+            for j in range(moe_tp_size):
+                st = i * tensor_model_parallel_size + j
+                en = (i + 1) * tensor_model_parallel_size + j
+                ranks = list(range(st, en, moe_tp_size))
+                group_ranks.append(ranks)
+        _MOE_EP = init_model_parallel_group(
+            group_ranks,
+            get_world_group().local_rank,
+            backend,
+            group_name="moe_ep",
+        )

     global _MOE_TP
     assert _MOE_TP is None, "expert model parallel group is already initialized"
-    group_ranks = []
-    for i in range(num_tensor_model_parallel_groups):
-        for j in range(moe_ep_size):
-            st = i * tensor_model_parallel_size + j * moe_tp_size
-            en = i * tensor_model_parallel_size + (j + 1) * moe_tp_size
-            ranks = list(range(st, en))
-            group_ranks.append(ranks)

-    _MOE_TP = init_model_parallel_group(
-        group_ranks,
-        get_world_group().local_rank,
-        backend,
-        use_custom_allreduce=False,
-        group_name="moe_tp",
-    )
+    if moe_tp_size == tensor_model_parallel_size:
+        _MOE_TP = _TP
+    else:
+        # TODO(ch-wan): use split_group to save memory
+        group_ranks = []
+        for i in range(num_tensor_model_parallel_groups):
+            for j in range(moe_ep_size):
+                st = i * tensor_model_parallel_size + j * moe_tp_size
+                en = i * tensor_model_parallel_size + (j + 1) * moe_tp_size
+                ranks = list(range(st, en))
+                group_ranks.append(ranks)
+        _MOE_TP = init_model_parallel_group(
+            group_ranks,
+            get_world_group().local_rank,
+            backend,
+            group_name="moe_tp",
+        )

     # Build the pipeline model-parallel groups.
     num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
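Note: when the MoE expert-parallel size equals the TP size (or the MoE TP size does), the hunk above now reuses the existing _TP group instead of building a duplicate process group; only the mixed case still constructs the strided subgroups. The rank layout those loops produce can be checked standalone, e.g. for a hypothetical 8-rank node with tensor_model_parallel_size=8 and expert_model_parallel_size=4 (so moe_tp_size=2); the sizes below are illustrative only.

# Recompute the MoE subgroup layout from the loop structure above (hypothetical sizes).
world_size = 8
tensor_model_parallel_size = 8
moe_ep_size = 4
moe_tp_size = tensor_model_parallel_size // moe_ep_size                        # 2
num_tensor_model_parallel_groups = world_size // tensor_model_parallel_size    # 1

ep_groups = []
for i in range(num_tensor_model_parallel_groups):
    for j in range(moe_tp_size):
        st = i * tensor_model_parallel_size + j
        en = (i + 1) * tensor_model_parallel_size + j
        ep_groups.append(list(range(st, en, moe_tp_size)))

tp_groups = []
for i in range(num_tensor_model_parallel_groups):
    for j in range(moe_ep_size):
        st = i * tensor_model_parallel_size + j * moe_tp_size
        en = i * tensor_model_parallel_size + (j + 1) * moe_tp_size
        tp_groups.append(list(range(st, en)))

print(ep_groups)  # [[0, 2, 4, 6], [1, 3, 5, 7]]
print(tp_groups)  # [[0, 1], [2, 3], [4, 5], [6, 7]]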
@@ -1571,6 +1587,16 @@ def patch_tensor_parallel_group(tp_group: GroupCoordinator):
         _TP = old_tp_group


+def get_world_size():
+    """Return world size for the world group."""
+    return get_world_group().world_size
+
+
+def get_world_rank():
+    """Return my rank for the world group."""
+    return get_world_group().rank_in_group
+
+
 def get_tensor_model_parallel_world_size():
     """Return world size for the tensor model parallel group."""
     return get_tp_group().world_size

@@ -1581,6 +1607,16 @@ def get_tensor_model_parallel_rank():
     return get_tp_group().rank_in_group


+def get_pipeline_model_parallel_world_size():
+    """Return world size for the pipeline model parallel group."""
+    return get_pp_group().world_size
+
+
+def get_pipeline_model_parallel_rank():
+    """Return my rank for the pipeline model parallel group."""
+    return get_pp_group().rank_in_group
+
+
 def get_moe_expert_parallel_world_size():
     """Return world size for the moe expert parallel group."""
     return get_moe_ep_group().world_size

@@ -1633,7 +1669,7 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):

         ray.shutdown()
     gc.collect()
-    if not current_platform.is_cpu():
+    if not _is_cpu:
         if hasattr(torch, "cuda") and torch.cuda.is_available():
             torch.cuda.empty_cache()
             if hasattr(torch._C, "_host_emptyCache"):