sglang 0.5.1.post3__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +251 -26
- sglang/lang/interpreter.py +1 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +37 -7
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +6 -4
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -420
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +6 -4
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +94 -58
- sglang/srt/entrypoints/engine.py +34 -14
- sglang/srt/entrypoints/http_server.py +172 -47
- sglang/srt/entrypoints/openai/protocol.py +63 -3
- sglang/srt/entrypoints/openai/serving_base.py +6 -2
- sglang/srt/entrypoints/openai/serving_chat.py +34 -19
- sglang/srt/entrypoints/openai/serving_completions.py +10 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/eplb/eplb_manager.py +28 -4
- sglang/srt/eplb/expert_distribution.py +55 -15
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +12 -0
- sglang/srt/layers/activation.py +44 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +250 -112
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -7
- sglang/srt/layers/layernorm.py +54 -12
- sglang/srt/layers/logits_processor.py +10 -3
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +110 -49
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +43 -12
- sglang/srt/layers/moe/utils.py +6 -5
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +9 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +0 -3
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +43 -29
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +107 -40
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +77 -45
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +60 -42
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +83 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +28 -19
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/cache_controller.py +242 -278
- sglang/srt/managers/data_parallel_controller.py +30 -15
- sglang/srt/managers/detokenizer_manager.py +13 -2
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +160 -11
- sglang/srt/managers/mm_utils.py +6 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
- sglang/srt/managers/schedule_batch.py +27 -44
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +90 -115
- sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
- sglang/srt/managers/tokenizer_manager.py +41 -477
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +24 -22
- sglang/srt/mem_cache/hiradix_cache.py +184 -101
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +324 -41
- sglang/srt/mem_cache/memory_pool_host.py +25 -18
- sglang/srt/mem_cache/radix_cache.py +5 -6
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +149 -12
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +74 -19
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1 -3
- sglang/srt/metrics/collector.py +484 -63
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +72 -18
- sglang/srt/model_executor/model_runner.py +189 -31
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +33 -28
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/deepseek_v2.py +311 -50
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/gpt_oss.py +5 -18
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +17 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +33 -3
- sglang/srt/models/qwen2_5_vl.py +90 -42
- sglang/srt/models/qwen2_moe.py +79 -14
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/qwen3_next.py +1039 -0
- sglang/srt/models/qwen3_next_mtp.py +109 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/{reasoning_parser.py → parser/reasoning_parser.py} +1 -1
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +297 -79
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_worker.py +216 -120
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/utils.py +37 -2
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +181 -8
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_utils.py +25 -1
- sglang/utils.py +5 -0
- sglang/version.py +1 -1
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/METADATA +11 -9
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/RECORD +243 -194
- sglang/srt/disaggregation/launch_lb.py +0 -131
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{conversation.py → parser/conversation.py} +0 -0
- /sglang/srt/{harmony_parser.py → parser/harmony_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post3.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
--- a/sglang/srt/disaggregation/nixl/conn.py
+++ b/sglang/srt/disaggregation/nixl/conn.py
@@ -78,6 +78,9 @@ class KVArgsRegisterInfo:
     dst_kv_ptrs: list[int]
     dst_aux_ptrs: list[int]
     gpu_id: int
+    decode_tp_size: int
+    decode_tp_rank: int
+    dst_kv_item_len: int

     @classmethod
     def from_zmq(cls, msg: List[bytes]):
@@ -90,6 +93,9 @@ class KVArgsRegisterInfo:
             dst_kv_ptrs=list(struct.unpack(f"{len(msg[5])//8}Q", msg[5])),
             dst_aux_ptrs=list(struct.unpack(f"{len(msg[6])//8}Q", msg[6])),
             gpu_id=int(msg[7].decode("ascii")),
+            decode_tp_size=int(msg[8].decode("ascii")),
+            decode_tp_rank=int(msg[9].decode("ascii")),
+            dst_kv_item_len=int(msg[10].decode("ascii")),
         )


@@ -166,7 +172,7 @@ class NixlKVManager(CommonKVManager):
             self.kv_args.kv_data_ptrs, self.kv_args.kv_data_lens
         ):
             kv_addrs.append((kv_data_ptr, kv_data_len, self.kv_args.gpu_id, ""))
-        self.kv_descs = self.agent.register_memory(kv_addrs, "VRAM"
+        self.kv_descs = self.agent.register_memory(kv_addrs, "VRAM")
         logger.debug(f"Register kv tensors, len(kv_addr)= {len(kv_addrs)}")
         if not self.kv_descs:
             raise Exception("NIXL memory registration failed for kv tensors")
@@ -175,7 +181,7 @@ class NixlKVManager(CommonKVManager):
             self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
         ):
             aux_addrs.append((aux_data_ptr, aux_data_len, 0, ""))
-        self.aux_descs = self.agent.register_memory(aux_addrs, "DRAM"
+        self.aux_descs = self.agent.register_memory(aux_addrs, "DRAM")
         logger.debug(f"Register aux tensors, len(aux_addrs)= {len(aux_addrs)}")
         if not self.aux_descs:
             raise Exception("NIXL memory registration failed for aux tensors")
@@ -222,8 +228,8 @@ class NixlKVManager(CommonKVManager):
         logger.debug(
             f"len(src_addrs): before group: {len(prefill_kv_indices)}, after group: {len(src_addrs)}"
         )
-        src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM"
-        dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM"
+        src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM")
+        dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM")
         # Transfer data
         xfer_handle = self.agent.initialize_xfer(
             "WRITE",
@@ -239,6 +245,140 @@ class NixlKVManager(CommonKVManager):
             raise Exception("KVSender failed to post transfer")
         return xfer_handle

+    def send_kvcache_slice(
+        self,
+        peer_name: str,
+        prefill_kv_indices: npt.NDArray[np.int32],
+        dst_kv_ptrs: list[int],
+        dst_kv_indices: npt.NDArray[np.int32],
+        dst_gpu_id: int,
+        notif: str,
+        prefill_tp_size: int,
+        decode_tp_size: int,
+        decode_tp_rank: int,
+        dst_kv_item_len: int,
+    ):
+        # Get configuration from kv_args
+        local_tp_rank_in_group = self.kv_args.engine_rank % prefill_tp_size
+        dst_tp_rank_in_group = decode_tp_rank % decode_tp_size
+        num_kv_heads = self.kv_args.kv_head_num
+
+        # Calculate head distribution
+        src_heads_per_rank = num_kv_heads
+        dst_heads_per_rank = num_kv_heads * prefill_tp_size // decode_tp_size
+
+        src_kv_item_len = self.kv_args.kv_item_lens[0]
+        page_size = self.kv_args.page_size
+
+        bytes_per_head_slice_to_send = (
+            dst_kv_item_len // page_size // dst_heads_per_rank
+        )
+
+        # Determine which heads to send
+        if prefill_tp_size > decode_tp_size:
+            # Multiple prefill ranks to one decode rank
+            src_head_start_offset = 0
+            num_heads_to_send = src_heads_per_rank
+            dst_head_start_offset = local_tp_rank_in_group * src_heads_per_rank
+        else:
+            # Send KVCache from 1 prefill instance to multiple decode instances
+            src_head_start_offset = (
+                dst_tp_rank_in_group * dst_heads_per_rank
+            ) % src_heads_per_rank
+            num_heads_to_send = dst_heads_per_rank
+            dst_head_start_offset = 0
+
+        # Create transfer descriptors
+        src_addrs = []
+        dst_addrs = []
+
+        bytes_per_token_on_prefill = src_kv_item_len // page_size
+        bytes_per_token_on_decode = dst_kv_item_len // page_size
+
+        num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
+        src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
+        src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
+        dst_k_ptrs = dst_kv_ptrs[0 : len(src_k_ptrs)]
+        dst_v_ptrs = dst_kv_ptrs[num_kv_layers : num_kv_layers + len(src_v_ptrs)]
+
+        # Calculate precise byte offset and length for the sub-slice within the token
+        src_head_slice_offset = src_head_start_offset * bytes_per_head_slice_to_send
+        dst_head_slice_offset = dst_head_start_offset * bytes_per_head_slice_to_send
+        heads_bytes_per_token_to_send = num_heads_to_send * bytes_per_head_slice_to_send
+
+        src_dst_ptr_pairs = [
+            (
+                src_k_ptrs[layer_id],
+                dst_k_ptrs[layer_id],
+            )
+            for layer_id in range(len(src_k_ptrs))
+        ] + [
+            (
+                src_v_ptrs[layer_id],
+                dst_v_ptrs[layer_id],
+            )
+            for layer_id in range(len(src_v_ptrs))
+        ]
+
+        src_addrs = []
+        dst_addrs = []
+
+        # Calculate strides for a single token slot
+        bytes_per_token_on_prefill = src_kv_item_len // page_size
+        bytes_per_token_on_decode = dst_kv_item_len // page_size
+
+        for src_ptr, dst_ptr in src_dst_ptr_pairs:
+            for i in range(len(prefill_kv_indices)):
+                prefill_page_idx = int(prefill_kv_indices[i])
+                decode_page_idx = int(dst_kv_indices[i])
+
+                # Get the starting addresses for the current src and dst pages
+                src_page_start_addr = src_ptr + prefill_page_idx * src_kv_item_len
+                dst_page_start_addr = dst_ptr + decode_page_idx * dst_kv_item_len
+
+                # Iterate through each valid token slot within the current page
+                for token_slot_in_page in range(page_size):
+                    # Calculate the start address of the current token slot
+                    src_token_slot_start_addr = (
+                        src_page_start_addr
+                        + token_slot_in_page * bytes_per_token_on_prefill
+                    )
+                    dst_token_slot_start_addr = (
+                        dst_page_start_addr
+                        + token_slot_in_page * bytes_per_token_on_decode
+                    )
+
+                    # Calculate final src and dst addresses by applying head-slice offsets
+                    src_slice_addr = src_token_slot_start_addr + src_head_slice_offset
+                    dst_slice_addr = dst_token_slot_start_addr + dst_head_slice_offset
+
+                    src_addrs.append(
+                        (
+                            src_slice_addr,
+                            heads_bytes_per_token_to_send,
+                            self.kv_args.gpu_id,
+                        )
+                    )
+                    dst_addrs.append(
+                        (dst_slice_addr, heads_bytes_per_token_to_send, dst_gpu_id)
+                    )
+
+        # Use NIXL agent for transfer
+        src_descs = self.agent.get_xfer_descs(src_addrs, "VRAM")
+        dst_descs = self.agent.get_xfer_descs(dst_addrs, "VRAM")
+
+        xfer_handle = self.agent.initialize_xfer(
+            "WRITE", src_descs, dst_descs, peer_name, notif.encode("ascii")
+        )
+        if not xfer_handle:
+            raise Exception("Failed to create sliced KV transfer")
+
+        state = self.agent.transfer(xfer_handle)
+        if state == "ERR":
+            raise Exception("Failed to post sliced KV transfer")
+
+        return xfer_handle
+
     def send_aux(
         self,
         peer_name: str,
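To make the head-slicing arithmetic in send_kvcache_slice concrete, here is a small worked example with hypothetical sizes that are not taken from this diff: 8 KV heads in total, prefill TP = 2 and decode TP = 4, so each prefill rank holds 4 heads (kv_args.kv_head_num) and each decode rank expects 2.

# Hypothetical sizes; kv_head_num is the per-prefill-rank head count.
num_kv_heads = 4                       # heads held by one prefill rank (8 total / TP=2)
prefill_tp_size, decode_tp_size = 2, 4

src_heads_per_rank = num_kv_heads                                      # 4
dst_heads_per_rank = num_kv_heads * prefill_tp_size // decode_tp_size  # 2

# prefill_tp_size < decode_tp_size, so the else-branch above applies: each decode
# rank receives a contiguous dst_heads_per_rank slice of the sender's local heads.
for decode_tp_rank in range(decode_tp_size):
    dst_tp_rank_in_group = decode_tp_rank % decode_tp_size
    src_head_start = (dst_tp_rank_in_group * dst_heads_per_rank) % src_heads_per_rank
    print(decode_tp_rank, src_head_start, dst_heads_per_rank)
# Prints source head offsets 0, 2, 0, 2 for decode ranks 0..3, two heads each;
# which prefill rank serves which decode rank is decided by the bootstrap pairing,
# which is outside this hunk.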
@@ -255,8 +395,8 @@ class NixlKVManager(CommonKVManager):
         decode_aux_addr = dst_aux_ptrs[0] + dst_aux_index * aux_item_len
         src_addrs = [(prefill_aux_addr, aux_item_len, 0)]
         dst_addrs = [(decode_aux_addr, aux_item_len, 0)]
-        src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM"
-        dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM"
+        src_descs = self.agent.get_xfer_descs(src_addrs, "DRAM")
+        dst_descs = self.agent.get_xfer_descs(dst_addrs, "DRAM")
         # Transfer data
         xfer_handle = self.agent.initialize_xfer(
             "WRITE",
@@ -296,14 +436,35 @@ class NixlKVManager(CommonKVManager):
             assert req.agent_name in self.decode_kv_args_table

             notif = "_".join([str(req.room), "kv", str(chunk_id), str(int(is_last))])
-            kv_xfer_handle = self.send_kvcache(
-                req.agent_name,
-                kv_indices,
-                self.decode_kv_args_table[req.agent_name].dst_kv_ptrs,
-                chunked_dst_kv_indice,
-                self.decode_kv_args_table[req.agent_name].gpu_id,
-                notif,
-            )
+            decode_tp_size = self.decode_kv_args_table[req.agent_name].decode_tp_size
+
+            if decode_tp_size == self.tp_size:
+                kv_xfer_handle = self.send_kvcache(
+                    req.agent_name,
+                    kv_indices,
+                    self.decode_kv_args_table[req.agent_name].dst_kv_ptrs,
+                    chunked_dst_kv_indice,
+                    self.decode_kv_args_table[req.agent_name].gpu_id,
+                    notif,
+                )
+            else:
+                kv_xfer_handle = self.send_kvcache_slice(
+                    req.agent_name,
+                    kv_indices,
+                    self.decode_kv_args_table[req.agent_name].dst_kv_ptrs,
+                    chunked_dst_kv_indice,
+                    self.decode_kv_args_table[req.agent_name].gpu_id,
+                    notif,
+                    prefill_tp_size=self.tp_size,
+                    decode_tp_size=decode_tp_size,
+                    decode_tp_rank=self.decode_kv_args_table[
+                        req.agent_name
+                    ].decode_tp_rank,
+                    dst_kv_item_len=self.decode_kv_args_table[
+                        req.agent_name
+                    ].dst_kv_item_len,
+                )
+
             handles.append(kv_xfer_handle)
             # Only the last chunk we need to send the aux data.
             if is_last:
@@ -454,11 +615,11 @@ class NixlKVReceiver(CommonKVReceiver):
         mgr: NixlKVManager,
         bootstrap_addr: str,
         bootstrap_room: Optional[int] = None,
-
+        prefill_dp_rank: Optional[int] = None,
     ):
         self.started_transfer = False
         self.conclude_state = None
-        super().__init__(mgr, bootstrap_addr, bootstrap_room,
+        super().__init__(mgr, bootstrap_addr, bootstrap_room, prefill_dp_rank)

     def init(self, kv_indices: npt.NDArray[np.int32], aux_index: Optional[int] = None):
         for bootstrap_info in self.bootstrap_infos:
@@ -521,6 +682,9 @@ class NixlKVReceiver(CommonKVReceiver):
                     packed_kv_data_ptrs,
                     packed_aux_data_ptrs,
                     str(self.kv_mgr.kv_args.gpu_id).encode("ascii"),
+                    str(self.kv_mgr.kv_args.decode_tp_size).encode("ascii"),
+                    str(self.kv_mgr.kv_args.engine_rank).encode("ascii"),
+                    str(self.kv_mgr.kv_args.kv_item_lens[0]).encode("ascii"),
                 ]
             )

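The new decode_tp_size, decode_tp_rank, and dst_kv_item_len values ride along as extra ASCII frames on the existing ZMQ registration message: the decode side appends them in the hunk above, and the prefill side reads them back in KVArgsRegisterInfo.from_zmq. Below is a minimal sketch of the parsing side only; it assumes the frame indices shown in the from_zmq hunk and omits msg[0:5], whose contents are not part of this diff.

import struct

def parse_register_tail(msg: list[bytes]) -> dict:
    # Mirrors KVArgsRegisterInfo.from_zmq for the frames shown in the diff.
    return {
        "dst_kv_ptrs": list(struct.unpack(f"{len(msg[5]) // 8}Q", msg[5])),
        "dst_aux_ptrs": list(struct.unpack(f"{len(msg[6]) // 8}Q", msg[6])),
        "gpu_id": int(msg[7].decode("ascii")),
        "decode_tp_size": int(msg[8].decode("ascii")),    # new in 0.5.2
        "decode_tp_rank": int(msg[9].decode("ascii")),    # new in 0.5.2
        "dst_kv_item_len": int(msg[10].decode("ascii")),  # new in 0.5.2
    }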
--- a/sglang/srt/disaggregation/prefill.py
+++ b/sglang/srt/disaggregation/prefill.py
@@ -23,7 +23,7 @@ import logging
 import threading
 from collections import deque
 from http import HTTPStatus
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, List, Optional, Type

 import torch

@@ -140,8 +140,10 @@ class PrefillBootstrapQueue:
         kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
         kv_args.gpu_id = self.scheduler.gpu_id

-        kv_manager_class = get_kv_class(
-        kv_manager = kv_manager_class(
+        kv_manager_class: Type[BaseKVManager] = get_kv_class(
+            self.transfer_backend, KVClassType.MANAGER
+        )
+        kv_manager: BaseKVManager = kv_manager_class(
             kv_args,
             DisaggregationMode.PREFILL,
             self.scheduler.server_args,
@@ -567,7 +569,7 @@ class SchedulerDisaggregationPrefillMixin:
             # Move the chunked request out of the batch so that we can merge
             # only finished requests to running_batch.
             self.last_batch.filter_batch(chunked_req_to_exclude=self.chunked_req)
-            self.tree_cache.cache_unfinished_req(self.chunked_req)
+            self.tree_cache.cache_unfinished_req(self.chunked_req, chunked=True)
             if self.enable_overlap:
                 # Delay KV transfer to process_batch_result_disagg_prefill when overlap is enabled to ensure results are resolved
                 self.chunked_req.tmp_end_idx = min(
--- a/sglang/srt/disaggregation/utils.py
+++ b/sglang/srt/disaggregation/utils.py
@@ -1,21 +1,17 @@
 from __future__ import annotations

-import dataclasses
 import os
 import random
-import threading
-import warnings
 from collections import deque
 from contextlib import nullcontext
 from enum import Enum
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, List, Optional, Type, Union

 import numpy as np
-import requests
 import torch
 import torch.distributed as dist

-from sglang.srt.utils import
+from sglang.srt.utils import is_npu

 if TYPE_CHECKING:
     from sglang.srt.managers.schedule_batch import Req
@@ -217,7 +213,9 @@ class KVClassType(Enum):
     BOOTSTRAP_SERVER = "bootstrap_server"


-def get_kv_class(
+def get_kv_class(
+    transfer_backend: TransferBackend, class_type: KVClassType
+) -> Optional[Type]:
     from sglang.srt.disaggregation.fake import FakeKVReceiver, FakeKVSender

     if transfer_backend == TransferBackend.MOONCAKE:
@@ -305,49 +303,6 @@ def kv_to_page_num(num_kv_indices: int, page_size: int):
     return (num_kv_indices + page_size - 1) // page_size


-#########################
-# PDLB Registry
-#########################
-
-
-@dataclasses.dataclass
-class PDRegistryRequest:
-    """A request to register a machine itself to the LB."""
-
-    mode: str
-    registry_url: str
-    bootstrap_port: Optional[int] = None
-
-    def __post_init__(self):
-        if self.mode == "prefill" and self.bootstrap_port is None:
-            raise ValueError("Bootstrap port must be set in PREFILL mode.")
-        elif self.mode == "decode" and self.bootstrap_port is not None:
-            raise ValueError("Bootstrap port must not be set in DECODE mode.")
-        elif self.mode not in ["prefill", "decode"]:
-            raise ValueError(
-                f"Invalid mode: {self.mode}. Must be 'prefill' or 'decode'."
-            )
-
-
-def register_disaggregation_server(
-    mode: str, server_port: int, bootstrap_port: int, pdlb_url: str
-):
-    boostrap_port = bootstrap_port if mode == "prefill" else None
-    registry_request = PDRegistryRequest(
-        mode=mode,
-        registry_url=f"http://{get_ip()}:{server_port}",
-        bootstrap_port=boostrap_port,
-    )
-    res = requests.post(
-        f"{pdlb_url}/register",
-        json=dataclasses.asdict(registry_request),
-    )
-    if res.status_code != 200:
-        warnings.warn(
-            f"Failed to register disaggregation server: {res.status_code} {res.text}"
-        )
-
-
 #########################
 # Misc
 #########################
--- a/sglang/srt/distributed/parallel_state.py
+++ b/sglang/srt/distributed/parallel_state.py
@@ -43,6 +43,7 @@ from sglang.srt.utils import (
     direct_register_custom_op,
     get_bool_env_var,
     get_int_env_var,
+    is_cpu,
     is_cuda_alike,
     is_hip,
     is_npu,
@@ -51,6 +52,9 @@ from sglang.srt.utils import (
 )

 _is_npu = is_npu()
+_is_cpu = is_cpu()
+
+IS_ONE_DEVICE_PER_PROCESS = get_bool_env_var("SGLANG_ONE_DEVICE_PER_PROCESS")


 @dataclass
@@ -60,6 +64,9 @@ class GraphCaptureContext:

 TensorMetadata = namedtuple("TensorMetadata", ["device", "dtype", "size"])

+# use int value instead of ReduceOp.SUM to support torch compile
+REDUCE_OP_SUM = int(torch.distributed.ReduceOp.SUM)
+

 def _split_tensor_dict(
     tensor_dict: Dict[str, Union[torch.Tensor, Any]]
@@ -223,10 +230,12 @@ class GroupCoordinator:
         use_message_queue_broadcaster: bool = False,
         group_name: Optional[str] = None,
     ):
+        # Set group info
         group_name = group_name or "anonymous"
         self.unique_name = _get_unique_name(group_name)
         _register_group(self)

+        # Set rank info
         self.rank = torch.distributed.get_rank()
         self.local_rank = local_rank
         self.device_group = None
@@ -250,14 +259,16 @@ class GroupCoordinator:
         assert self.cpu_group is not None
         assert self.device_group is not None

+        device_id = 0 if IS_ONE_DEVICE_PER_PROCESS else local_rank
         if is_cuda_alike():
-            self.device = torch.device(f"cuda:{local_rank}")
+            self.device = torch.device(f"cuda:{device_id}")
         elif _is_npu:
-            self.device = torch.device(f"npu:{local_rank}")
+            self.device = torch.device(f"npu:{device_id}")
         else:
             self.device = torch.device("cpu")
         self.device_module = torch.get_device_module(self.device)

+        # Import communicators
         self.use_pynccl = use_pynccl
         self.use_pymscclpp = use_pymscclpp
         self.use_custom_allreduce = use_custom_allreduce
@@ -270,6 +281,9 @@ class GroupCoordinator:
         from sglang.srt.distributed.device_communicators.custom_all_reduce import (
             CustomAllreduce,
         )
+        from sglang.srt.distributed.device_communicators.pymscclpp import (
+            PyMscclppCommunicator,
+        )
         from sglang.srt.distributed.device_communicators.pynccl import (
             PyNcclCommunicator,
         )
@@ -287,10 +301,6 @@ class GroupCoordinator:
             device=self.device,
         )

-        from sglang.srt.distributed.device_communicators.pymscclpp import (
-            PyMscclppCommunicator,
-        )
-
         self.pymscclpp_comm: Optional[PyMscclppCommunicator] = None
         if use_pymscclpp and self.world_size > 1:
             self.pymscclpp_comm = PyMscclppCommunicator(
@@ -325,30 +335,30 @@ class GroupCoordinator:
         except Exception as e:
             logger.warning(f"Failed to initialize QuickAllReduce: {e}")

+        # Create communicator for other hardware backends
         from sglang.srt.distributed.device_communicators.hpu_communicator import (
             HpuCommunicator,
         )
+        from sglang.srt.distributed.device_communicators.npu_communicator import (
+            NpuCommunicator,
+        )
+        from sglang.srt.distributed.device_communicators.xpu_communicator import (
+            XpuCommunicator,
+        )

         self.hpu_communicator: Optional[HpuCommunicator] = None
         if use_hpu_communicator and self.world_size > 1:
             self.hpu_communicator = HpuCommunicator(group=self.device_group)

-        from sglang.srt.distributed.device_communicators.xpu_communicator import (
-            XpuCommunicator,
-        )
-
         self.xpu_communicator: Optional[XpuCommunicator] = None
         if use_xpu_communicator and self.world_size > 1:
             self.xpu_communicator = XpuCommunicator(group=self.device_group)

-        from sglang.srt.distributed.device_communicators.npu_communicator import (
-            NpuCommunicator,
-        )
-
         self.npu_communicator: Optional[NpuCommunicator] = None
         if use_npu_communicator and self.world_size > 1:
             self.npu_communicator = NpuCommunicator(group=self.device_group)

+        # Create message queue
         from sglang.srt.distributed.device_communicators.shm_broadcast import (
             MessageQueue,
         )
@@ -482,9 +492,7 @@ class GroupCoordinator:

         if input_.is_cpu:
             if is_shm_available(input_.dtype, self.world_size, self.local_size):
-                torch.ops.sgl_kernel.shm_allreduce(
-                    input_, torch.distributed.ReduceOp.SUM
-                )
+                torch.ops.sgl_kernel.shm_allreduce(input_, REDUCE_OP_SUM)
             else:
                 torch.distributed.all_reduce(input_, group=self.device_group)
             return input_
@@ -848,6 +856,11 @@ class GroupCoordinator:
         )
         return obj_list

+    def all_gather_object(self, obj: Any) -> List[Any]:
+        objs = [None] * self.world_size
+        torch.distributed.all_gather_object(objs, obj, group=self.cpu_group)
+        return objs
+
     def send_object(self, obj: Any, dst: int) -> None:
         """Send the input object list to the destination rank."""
         """NOTE: `dst` is the local rank of the destination rank."""
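A short, hypothetical usage sketch for the new GroupCoordinator.all_gather_object helper added above: every rank in the group contributes one picklable object and receives the list of all contributions, exchanged over the group's CPU process group. The payload dict below is made up for illustration.

from sglang.srt.distributed.parallel_state import get_tp_group

tp_group = get_tp_group()
# Each TP rank contributes a small dict; the result list is ordered by rank in the group.
rank_infos = tp_group.all_gather_object({"rank_in_group": tp_group.rank_in_group})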
@@ -867,17 +880,16 @@ class GroupCoordinator:
         size_tensor = torch.tensor(
             [object_tensor.numel()],
             dtype=torch.long,
-            device=torch.cuda.current_device(),
+            device="cpu",
         )
-
         # Send object size
-        torch.distributed.send(
-            size_tensor, dst=self.ranks[dst], group=self.device_group
-        )
+        torch.distributed.send(size_tensor, dst=self.ranks[dst], group=self.cpu_group)

         # Send object
         torch.distributed.send(
-            object_tensor, dst=self.ranks[dst], group=self.device_group
+            object_tensor,
+            dst=self.ranks[dst],
+            group=self.device_group,
         )

         return None
@@ -892,13 +904,11 @@ class GroupCoordinator:
             src != self.rank_in_group
         ), "Invalid source rank. Source rank is the same as the current rank."

-        size_tensor = torch.empty(
-            1, dtype=torch.long, device=torch.cuda.current_device()
-        )
+        size_tensor = torch.empty(1, dtype=torch.long, device="cpu")

         # Receive object size
         rank_size = torch.distributed.recv(
-            size_tensor, src=self.ranks[src], group=self.device_group
+            size_tensor, src=self.ranks[src], group=self.cpu_group
         )

         # Tensor to receive serialized objects into.
@@ -916,7 +926,7 @@ class GroupCoordinator:
             rank_object == rank_size
         ), "Received object sender rank does not match the size sender rank."

-        obj = pickle.loads(object_tensor.cpu().numpy()
+        obj = pickle.loads(object_tensor.cpu().numpy())

         return obj

@@ -1449,43 +1459,49 @@ def initialize_model_parallel(
         _PDMUX_PREFILL_TP_GROUP.pynccl_comm.disabled = False

     moe_ep_size = expert_model_parallel_size
-
     moe_tp_size = tensor_model_parallel_size // moe_ep_size
+
     global _MOE_EP
     assert _MOE_EP is None, "expert model parallel group is already initialized"
-    group_ranks = []
-    for i in range(num_tensor_model_parallel_groups):
-        for j in range(moe_tp_size):
-            st = i * tensor_model_parallel_size + j
-            en = (i + 1) * tensor_model_parallel_size + j
-            ranks = list(range(st, en, moe_tp_size))
-            group_ranks.append(ranks)

-
-
-
-
-
-
-
+    if moe_ep_size == tensor_model_parallel_size:
+        _MOE_EP = _TP
+    else:
+        # TODO(ch-wan): use split_group to save memory
+        group_ranks = []
+        for i in range(num_tensor_model_parallel_groups):
+            for j in range(moe_tp_size):
+                st = i * tensor_model_parallel_size + j
+                en = (i + 1) * tensor_model_parallel_size + j
+                ranks = list(range(st, en, moe_tp_size))
+                group_ranks.append(ranks)
+        _MOE_EP = init_model_parallel_group(
+            group_ranks,
+            get_world_group().local_rank,
+            backend,
+            group_name="moe_ep",
+        )

     global _MOE_TP
     assert _MOE_TP is None, "expert model parallel group is already initialized"
-    group_ranks = []
-    for i in range(num_tensor_model_parallel_groups):
-        for j in range(moe_ep_size):
-            st = i * tensor_model_parallel_size + j * moe_tp_size
-            en = i * tensor_model_parallel_size + (j + 1) * moe_tp_size
-            ranks = list(range(st, en))
-            group_ranks.append(ranks)

-
-
-
-
-
-
-
+    if moe_tp_size == tensor_model_parallel_size:
+        _MOE_TP = _TP
+    else:
+        # TODO(ch-wan): use split_group to save memory
+        group_ranks = []
+        for i in range(num_tensor_model_parallel_groups):
+            for j in range(moe_ep_size):
+                st = i * tensor_model_parallel_size + j * moe_tp_size
+                en = i * tensor_model_parallel_size + (j + 1) * moe_tp_size
+                ranks = list(range(st, en))
+                group_ranks.append(ranks)
+        _MOE_TP = init_model_parallel_group(
+            group_ranks,
+            get_world_group().local_rank,
+            backend,
+            group_name="moe_tp",
+        )

     # Build the pipeline model-parallel groups.
     num_pipeline_model_parallel_groups: int = world_size // pipeline_model_parallel_size
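As a worked example of the MoE group construction above, take hypothetical sizes on a single 8-GPU node: tensor_model_parallel_size = 8 and expert_model_parallel_size = 4, so moe_tp_size = 2 and neither new shortcut branch (_MOE_EP = _TP or _MOE_TP = _TP) applies; both nested loops run and produce the groups below.

tensor_model_parallel_size = 8
moe_ep_size = 4
moe_tp_size = tensor_model_parallel_size // moe_ep_size  # 2

# MoE expert-parallel groups: stride moe_tp_size within the single TP group (i = 0 here).
ep_groups = [
    list(range(j, tensor_model_parallel_size + j, moe_tp_size))
    for j in range(moe_tp_size)
]
# MoE tensor-parallel groups: contiguous blocks of moe_tp_size ranks.
tp_groups = [
    list(range(j * moe_tp_size, (j + 1) * moe_tp_size))
    for j in range(moe_ep_size)
]

assert ep_groups == [[0, 2, 4, 6], [1, 3, 5, 7]]
assert tp_groups == [[0, 1], [2, 3], [4, 5], [6, 7]]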
@@ -1571,6 +1587,16 @@ def patch_tensor_parallel_group(tp_group: GroupCoordinator):
     _TP = old_tp_group


+def get_world_size():
+    """Return world size for the world group."""
+    return get_world_group().world_size
+
+
+def get_world_rank():
+    """Return my rank for the world group."""
+    return get_world_group().rank_in_group
+
+
 def get_tensor_model_parallel_world_size():
     """Return world size for the tensor model parallel group."""
     return get_tp_group().world_size
@@ -1581,6 +1607,16 @@ def get_tensor_model_parallel_rank():
     return get_tp_group().rank_in_group


+def get_pipeline_model_parallel_world_size():
+    """Return world size for the pipeline model parallel group."""
+    return get_pp_group().world_size
+
+
+def get_pipeline_model_parallel_rank():
+    """Return my rank for the pipeline model parallel group."""
+    return get_pp_group().rank_in_group
+
+
 def get_moe_expert_parallel_world_size():
     """Return world size for the moe expert parallel group."""
     return get_moe_ep_group().world_size
@@ -1633,7 +1669,7 @@ def cleanup_dist_env_and_memory(shutdown_ray: bool = False):

         ray.shutdown()
     gc.collect()
-    if not
+    if not _is_cpu:
         if hasattr(torch, "cuda") and torch.cuda.is_available():
             torch.cuda.empty_cache()
             if hasattr(torch._C, "_host_emptyCache"):