sglang 0.5.0rc1__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +0 -7
- sglang/bench_one_batch_server.py +7 -2
- sglang/bench_serving.py +3 -3
- sglang/eval/llama3_eval.py +0 -1
- sglang/srt/configs/model_config.py +25 -9
- sglang/srt/configs/update_config.py +40 -5
- sglang/srt/constrained/xgrammar_backend.py +23 -11
- sglang/srt/conversation.py +2 -15
- sglang/srt/disaggregation/ascend/conn.py +1 -3
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +1 -2
- sglang/srt/disaggregation/launch_lb.py +7 -1
- sglang/srt/disaggregation/mini_lb.py +11 -5
- sglang/srt/disaggregation/mooncake/conn.py +141 -47
- sglang/srt/disaggregation/prefill.py +261 -5
- sglang/srt/disaggregation/utils.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/device_communicators/pynccl.py +68 -18
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
- sglang/srt/distributed/naive_distributed.py +112 -0
- sglang/srt/distributed/parallel_state.py +90 -4
- sglang/srt/entrypoints/context.py +20 -1
- sglang/srt/entrypoints/engine.py +29 -4
- sglang/srt/entrypoints/http_server.py +76 -0
- sglang/srt/entrypoints/openai/protocol.py +4 -2
- sglang/srt/entrypoints/openai/serving_chat.py +23 -6
- sglang/srt/entrypoints/openai/serving_completions.py +10 -1
- sglang/srt/entrypoints/openai/serving_responses.py +2 -2
- sglang/srt/eplb/expert_distribution.py +2 -3
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +24 -0
- sglang/srt/host_shared_memory.py +83 -0
- sglang/srt/layers/attention/ascend_backend.py +132 -22
- sglang/srt/layers/attention/flashattention_backend.py +24 -17
- sglang/srt/layers/attention/flashinfer_backend.py +14 -3
- sglang/srt/layers/attention/flashinfer_mla_backend.py +227 -76
- sglang/srt/layers/attention/triton_backend.py +109 -73
- sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
- sglang/srt/layers/attention/trtllm_mha_backend.py +398 -36
- sglang/srt/layers/attention/trtllm_mla_backend.py +49 -19
- sglang/srt/layers/attention/utils.py +94 -15
- sglang/srt/layers/attention/vision.py +40 -13
- sglang/srt/layers/attention/vision_utils.py +65 -0
- sglang/srt/layers/communicator.py +58 -10
- sglang/srt/layers/dp_attention.py +137 -27
- sglang/srt/layers/elementwise.py +94 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
- sglang/srt/layers/layernorm.py +8 -1
- sglang/srt/layers/linear.py +24 -0
- sglang/srt/layers/logits_processor.py +16 -18
- sglang/srt/layers/moe/__init__.py +31 -0
- sglang/srt/layers/moe/ep_moe/layer.py +37 -33
- sglang/srt/layers/moe/fused_moe_native.py +14 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
- sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
- sglang/srt/layers/moe/moe_runner/base.py +13 -0
- sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
- sglang/srt/layers/moe/router.py +15 -9
- sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
- sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +167 -83
- sglang/srt/layers/moe/utils.py +159 -18
- sglang/srt/layers/multimodal.py +156 -40
- sglang/srt/layers/quantization/__init__.py +18 -46
- sglang/srt/layers/quantization/awq.py +22 -23
- sglang/srt/layers/quantization/base_config.py +2 -6
- sglang/srt/layers/quantization/blockwise_int8.py +4 -12
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -29
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
- sglang/srt/layers/quantization/fp8.py +127 -119
- sglang/srt/layers/quantization/fp8_kernel.py +195 -24
- sglang/srt/layers/quantization/fp8_utils.py +34 -9
- sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
- sglang/srt/layers/quantization/gptq.py +17 -21
- sglang/srt/layers/quantization/marlin_utils.py +26 -8
- sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
- sglang/srt/layers/quantization/modelopt_quant.py +217 -98
- sglang/srt/layers/quantization/moe_wna16.py +10 -15
- sglang/srt/layers/quantization/mxfp4.py +222 -39
- sglang/srt/layers/quantization/quark/quark.py +390 -0
- sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
- sglang/srt/layers/quantization/unquant.py +34 -70
- sglang/srt/layers/quantization/utils.py +77 -2
- sglang/srt/layers/quantization/w4afp8.py +7 -8
- sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
- sglang/srt/layers/quantization/w8a8_int8.py +5 -13
- sglang/srt/layers/radix_attention.py +6 -0
- sglang/srt/layers/rotary_embedding.py +1 -0
- sglang/srt/layers/sampler.py +5 -2
- sglang/srt/lora/layers.py +6 -2
- sglang/srt/lora/lora_manager.py +21 -22
- sglang/srt/lora/lora_registry.py +3 -3
- sglang/srt/lora/mem_pool.py +26 -24
- sglang/srt/lora/utils.py +10 -12
- sglang/srt/managers/cache_controller.py +80 -19
- sglang/srt/managers/detokenizer_manager.py +10 -2
- sglang/srt/managers/io_struct.py +23 -0
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/schedule_batch.py +22 -48
- sglang/srt/managers/scheduler.py +28 -20
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/template_manager.py +7 -5
- sglang/srt/managers/tokenizer_manager.py +88 -39
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/managers/utils.py +59 -1
- sglang/srt/mem_cache/allocator.py +10 -157
- sglang/srt/mem_cache/allocator_ascend.py +147 -0
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +14 -4
- sglang/srt/mem_cache/memory_pool.py +3 -3
- sglang/srt/mem_cache/memory_pool_host.py +35 -2
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
- sglang/srt/model_executor/cuda_graph_runner.py +33 -33
- sglang/srt/model_executor/forward_batch_info.py +11 -10
- sglang/srt/model_executor/model_runner.py +93 -78
- sglang/srt/model_executor/npu_graph_runner.py +94 -0
- sglang/srt/model_loader/loader.py +24 -6
- sglang/srt/models/dbrx.py +12 -6
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +5 -2
- sglang/srt/models/deepseek_v2.py +226 -223
- sglang/srt/models/ernie4.py +2 -2
- sglang/srt/models/glm4_moe.py +27 -65
- sglang/srt/models/glm4_moe_nextn.py +2 -1
- sglang/srt/models/glm4v.py +52 -1
- sglang/srt/models/glm4v_moe.py +8 -11
- sglang/srt/models/gpt_oss.py +41 -76
- sglang/srt/models/granitemoe.py +0 -1
- sglang/srt/models/grok.py +376 -48
- sglang/srt/models/interns1.py +12 -47
- sglang/srt/models/internvl.py +6 -51
- sglang/srt/models/llama.py +10 -2
- sglang/srt/models/llama4.py +18 -7
- sglang/srt/models/minicpm3.py +0 -1
- sglang/srt/models/mixtral.py +0 -2
- sglang/srt/models/nemotron_nas.py +435 -0
- sglang/srt/models/olmoe.py +0 -1
- sglang/srt/models/phi4mm.py +3 -21
- sglang/srt/models/qwen2.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +2 -0
- sglang/srt/models/qwen2_moe.py +23 -23
- sglang/srt/models/qwen3.py +2 -2
- sglang/srt/models/qwen3_classification.py +84 -0
- sglang/srt/models/qwen3_moe.py +27 -43
- sglang/srt/models/step3_vl.py +8 -3
- sglang/srt/models/xverse_moe.py +11 -5
- sglang/srt/multimodal/processors/base_processor.py +3 -3
- sglang/srt/multimodal/processors/internvl.py +7 -2
- sglang/srt/multimodal/processors/llava.py +11 -7
- sglang/srt/offloader.py +433 -0
- sglang/srt/operations.py +22 -2
- sglang/srt/reasoning_parser.py +4 -3
- sglang/srt/sampling/sampling_batch_info.py +7 -4
- sglang/srt/server_args.py +264 -105
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +8 -21
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_utils.py +36 -13
- sglang/srt/speculative/eagle_worker.py +56 -3
- sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
- sglang/srt/two_batch_overlap.py +20 -19
- sglang/srt/utils.py +68 -70
- sglang/test/runners.py +8 -5
- sglang/test/test_block_fp8.py +5 -6
- sglang/test/test_block_fp8_ep.py +13 -19
- sglang/test/test_cutlass_moe.py +4 -6
- sglang/test/test_cutlass_w4a8_moe.py +4 -3
- sglang/test/test_fp4_moe.py +4 -3
- sglang/test/test_marlin_moe.py +1 -1
- sglang/test/test_marlin_utils.py +1 -1
- sglang/test/test_utils.py +7 -0
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/METADATA +11 -11
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/RECORD +201 -171
- sglang/srt/layers/quantization/fp4.py +0 -557
- sglang/srt/layers/quantization/scalar_type.py +0 -352
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
@@ -2,6 +2,7 @@ from __future__ import annotations
|
|
2
2
|
|
3
3
|
import asyncio
|
4
4
|
import concurrent.futures
|
5
|
+
import ctypes
|
5
6
|
import dataclasses
|
6
7
|
import logging
|
7
8
|
import os
|
@@ -34,6 +35,7 @@ from sglang.srt.disaggregation.common.utils import (
|
|
34
35
|
)
|
35
36
|
from sglang.srt.disaggregation.mooncake.transfer_engine import MooncakeTransferEngine
|
36
37
|
from sglang.srt.disaggregation.utils import DisaggregationMode
|
38
|
+
from sglang.srt.distributed import get_pp_group
|
37
39
|
from sglang.srt.layers.dp_attention import (
|
38
40
|
get_attention_dp_rank,
|
39
41
|
get_attention_dp_size,
|
@@ -137,7 +139,29 @@ class KVArgsRegisterInfo:
|
|
137
139
|
)
|
138
140
|
|
139
141
|
|
142
|
+
class AuxDataCodec:
|
143
|
+
"""Handles serialization and deserialization of auxiliary data buffers"""
|
144
|
+
|
145
|
+
@staticmethod
|
146
|
+
def serialize_data_from_buffer(src_addr, data_length):
|
147
|
+
"""Serialize data from memory buffer to bytes"""
|
148
|
+
buffer = (ctypes.c_byte * data_length).from_address(src_addr)
|
149
|
+
return bytes(buffer)
|
150
|
+
|
151
|
+
@staticmethod
|
152
|
+
def deserialize_data_to_buffer(kv_args, buffer_index, aux_index, data):
|
153
|
+
"""Deserialize bytes into target memory buffer"""
|
154
|
+
dst_aux_ptr = kv_args.aux_data_ptrs[buffer_index]
|
155
|
+
item_len = kv_args.aux_item_lens[buffer_index]
|
156
|
+
dst_addr = dst_aux_ptr + item_len * aux_index
|
157
|
+
buffer = (ctypes.c_byte * len(data)).from_address(dst_addr)
|
158
|
+
buffer[:] = data
|
159
|
+
return
|
160
|
+
|
161
|
+
|
140
162
|
class MooncakeKVManager(BaseKVManager):
|
163
|
+
AUX_DATA_HEADER = b"AUX_DATA"
|
164
|
+
|
141
165
|
def __init__(
|
142
166
|
self,
|
143
167
|
args: KVArgs,
|
@@ -180,6 +204,7 @@ class MooncakeKVManager(BaseKVManager):
|
|
180
204
|
self.session_failures = defaultdict(int)
|
181
205
|
self.failed_sessions = set()
|
182
206
|
self.session_lock = threading.Lock()
|
207
|
+
self.pp_group = get_pp_group()
|
183
208
|
# Determine the number of threads to use for kv sender
|
184
209
|
cpu_count = os.cpu_count()
|
185
210
|
transfer_thread_pool_size = get_int_env_var(
|
@@ -281,21 +306,10 @@ class MooncakeKVManager(BaseKVManager):
|
|
281
306
|
if not transfer_blocks:
|
282
307
|
return 0
|
283
308
|
|
284
|
-
|
285
|
-
|
286
|
-
|
287
|
-
|
288
|
-
status = self.engine.transfer_sync(
|
289
|
-
mooncake_session_id, src_addr, dst_addr, length
|
290
|
-
)
|
291
|
-
if status != 0:
|
292
|
-
return status
|
293
|
-
return 0
|
294
|
-
else:
|
295
|
-
src_addrs, dst_addrs, lengths = zip(*transfer_blocks)
|
296
|
-
return self.engine.batch_transfer_sync(
|
297
|
-
mooncake_session_id, list(src_addrs), list(dst_addrs), list(lengths)
|
298
|
-
)
|
309
|
+
src_addrs, dst_addrs, lengths = zip(*transfer_blocks)
|
310
|
+
return self.engine.batch_transfer_sync(
|
311
|
+
mooncake_session_id, list(src_addrs), list(dst_addrs), list(lengths)
|
312
|
+
)
|
299
313
|
|
300
314
|
def send_kvcache(
|
301
315
|
self,
|
@@ -313,11 +327,11 @@ class MooncakeKVManager(BaseKVManager):
|
|
313
327
|
layers_params = None
|
314
328
|
|
315
329
|
# pp is not supported on the decode side yet
|
330
|
+
start_layer = self.kv_args.prefill_start_layer
|
331
|
+
end_layer = start_layer + len(self.kv_args.kv_data_ptrs)
|
316
332
|
if self.is_mla_backend:
|
317
333
|
src_kv_ptrs = self.kv_args.kv_data_ptrs
|
318
334
|
layers_per_pp_stage = len(src_kv_ptrs)
|
319
|
-
start_layer = self.pp_rank * layers_per_pp_stage
|
320
|
-
end_layer = start_layer + layers_per_pp_stage
|
321
335
|
dst_kv_ptrs = dst_kv_ptrs[start_layer:end_layer]
|
322
336
|
kv_item_len = self.kv_args.kv_item_lens[0]
|
323
337
|
layers_params = [
|
@@ -330,17 +344,15 @@ class MooncakeKVManager(BaseKVManager):
|
|
330
344
|
]
|
331
345
|
else:
|
332
346
|
num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
|
347
|
+
dst_num_total_layers = num_kv_layers * self.pp_size
|
333
348
|
src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
|
334
349
|
src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
|
335
350
|
layers_per_pp_stage = len(src_k_ptrs)
|
336
|
-
start_layer = self.pp_rank * layers_per_pp_stage
|
337
|
-
end_layer = start_layer + layers_per_pp_stage
|
338
351
|
dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
|
339
352
|
dst_v_ptrs = dst_kv_ptrs[
|
340
|
-
|
353
|
+
dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
|
341
354
|
]
|
342
355
|
kv_item_len = self.kv_args.kv_item_lens[0]
|
343
|
-
|
344
356
|
layers_params = [
|
345
357
|
(
|
346
358
|
src_k_ptrs[layer_id],
|
@@ -452,6 +464,7 @@ class MooncakeKVManager(BaseKVManager):
|
|
452
464
|
|
453
465
|
# pp is not supported on the decode side yet
|
454
466
|
num_kv_layers = len(self.kv_args.kv_data_ptrs) // 2
|
467
|
+
dst_num_total_layers = num_kv_layers * self.pp_size
|
455
468
|
src_k_ptrs = self.kv_args.kv_data_ptrs[:num_kv_layers]
|
456
469
|
src_v_ptrs = self.kv_args.kv_data_ptrs[num_kv_layers:]
|
457
470
|
layers_per_pp_stage = len(src_k_ptrs)
|
@@ -459,7 +472,7 @@ class MooncakeKVManager(BaseKVManager):
|
|
459
472
|
end_layer = start_layer + layers_per_pp_stage
|
460
473
|
dst_k_ptrs = dst_kv_ptrs[start_layer:end_layer]
|
461
474
|
dst_v_ptrs = dst_kv_ptrs[
|
462
|
-
|
475
|
+
dst_num_total_layers + start_layer : dst_num_total_layers + end_layer
|
463
476
|
]
|
464
477
|
|
465
478
|
# Calculate precise byte offset and length for the sub-slice within the token
|
@@ -569,11 +582,14 @@ class MooncakeKVManager(BaseKVManager):
|
|
569
582
|
|
570
583
|
def send_aux(
|
571
584
|
self,
|
572
|
-
|
585
|
+
req: TransferInfo,
|
573
586
|
prefill_aux_index: int,
|
574
587
|
dst_aux_ptrs: list[int],
|
575
|
-
dst_aux_index: int,
|
576
588
|
):
|
589
|
+
# TODO(shangming): Fix me when nvlink_transport of Mooncake is bug-free
|
590
|
+
if self.enable_custom_mem_pool:
|
591
|
+
return self.send_aux_tcp(req, prefill_aux_index, dst_aux_ptrs)
|
592
|
+
|
577
593
|
transfer_blocks = []
|
578
594
|
prefill_aux_ptrs = self.kv_args.aux_data_ptrs
|
579
595
|
prefill_aux_item_lens = self.kv_args.aux_item_lens
|
@@ -581,10 +597,59 @@ class MooncakeKVManager(BaseKVManager):
|
|
581
597
|
for i, dst_aux_ptr in enumerate(dst_aux_ptrs):
|
582
598
|
length = prefill_aux_item_lens[i]
|
583
599
|
src_addr = prefill_aux_ptrs[i] + length * prefill_aux_index
|
584
|
-
dst_addr = dst_aux_ptrs[i] + length * dst_aux_index
|
600
|
+
dst_addr = dst_aux_ptrs[i] + length * req.dst_aux_index
|
585
601
|
transfer_blocks.append((src_addr, dst_addr, length))
|
586
602
|
|
587
|
-
return self._transfer_data(mooncake_session_id, transfer_blocks)
|
603
|
+
return self._transfer_data(req.mooncake_session_id, transfer_blocks)
|
604
|
+
|
605
|
+
def send_aux_tcp(
|
606
|
+
self,
|
607
|
+
req: TransferInfo,
|
608
|
+
prefill_aux_index: int,
|
609
|
+
dst_aux_ptrs: list[int],
|
610
|
+
):
|
611
|
+
prefill_aux_ptrs = self.kv_args.aux_data_ptrs
|
612
|
+
prefill_aux_item_lens = self.kv_args.aux_item_lens
|
613
|
+
|
614
|
+
for i in range(len(prefill_aux_ptrs)):
|
615
|
+
length = prefill_aux_item_lens[i]
|
616
|
+
src_addr = prefill_aux_ptrs[i] + length * prefill_aux_index
|
617
|
+
data = AuxDataCodec.serialize_data_from_buffer(src_addr, length)
|
618
|
+
|
619
|
+
self.send_aux_data_to_endpoint(
|
620
|
+
remote=req.endpoint,
|
621
|
+
dst_port=req.dst_port,
|
622
|
+
room=req.room,
|
623
|
+
buffer_index=i,
|
624
|
+
aux_index=req.dst_aux_index,
|
625
|
+
data=data,
|
626
|
+
)
|
627
|
+
|
628
|
+
return 0
|
629
|
+
|
630
|
+
def send_aux_data_to_endpoint(
|
631
|
+
self,
|
632
|
+
remote: str,
|
633
|
+
dst_port: int,
|
634
|
+
room: int,
|
635
|
+
buffer_index: int,
|
636
|
+
aux_index: int,
|
637
|
+
data: bytes,
|
638
|
+
):
|
639
|
+
socket = self._connect(
|
640
|
+
format_tcp_address(remote, dst_port), is_ipv6=is_valid_ipv6_address(remote)
|
641
|
+
)
|
642
|
+
|
643
|
+
socket.send_multipart(
|
644
|
+
[
|
645
|
+
MooncakeKVManager.AUX_DATA_HEADER,
|
646
|
+
str(room).encode("ascii"),
|
647
|
+
str(buffer_index).encode("ascii"),
|
648
|
+
str(aux_index).encode("ascii"),
|
649
|
+
struct.pack(">I", len(data)),
|
650
|
+
data,
|
651
|
+
]
|
652
|
+
)
|
588
653
|
|
589
654
|
def sync_status_to_decode_endpoint(
|
590
655
|
self, remote: str, dst_port: int, room: int, status: int, prefill_rank: int
|
@@ -612,7 +677,7 @@ class MooncakeKVManager(BaseKVManager):
|
|
612
677
|
)
|
613
678
|
polls = []
|
614
679
|
dst_ranks_infos = []
|
615
|
-
local_rank = self.
|
680
|
+
local_rank = self.attn_tp_rank * self.pp_size + self.pp_rank
|
616
681
|
for req in reqs_to_be_processed:
|
617
682
|
if not req.is_dummy:
|
618
683
|
# Early exit if the request has failed
|
@@ -695,13 +760,13 @@ class MooncakeKVManager(BaseKVManager):
|
|
695
760
|
break
|
696
761
|
|
697
762
|
if kv_chunk.is_last:
|
698
|
-
|
699
|
-
|
700
|
-
|
701
|
-
|
702
|
-
|
703
|
-
|
704
|
-
|
763
|
+
if self.pp_group.is_last_rank:
|
764
|
+
# Only the last chunk we need to send the aux data
|
765
|
+
ret = self.send_aux(
|
766
|
+
req,
|
767
|
+
kv_chunk.prefill_aux_index,
|
768
|
+
target_rank_registration_info.dst_aux_ptrs,
|
769
|
+
)
|
705
770
|
polls.append(True if ret == 0 else False)
|
706
771
|
dst_ranks_infos.append(
|
707
772
|
(req.endpoint, req.dst_port, req.room)
|
@@ -776,15 +841,38 @@ class MooncakeKVManager(BaseKVManager):
|
|
776
841
|
|
777
842
|
threading.Thread(target=bootstrap_thread).start()
|
778
843
|
|
844
|
+
def _handle_aux_data(self, msg: List[bytes]):
|
845
|
+
"""Handle AUX_DATA messages received by the decode thread."""
|
846
|
+
room = int(msg[1].decode("ascii"))
|
847
|
+
buffer_index = int(msg[2].decode("ascii"))
|
848
|
+
aux_index = int(msg[3].decode("ascii"))
|
849
|
+
data_length = struct.unpack(">I", msg[4])[0]
|
850
|
+
data = msg[5]
|
851
|
+
|
852
|
+
if len(data) != data_length:
|
853
|
+
logger.error(f"AUX_DATA length mismatch for bootstrap_room {room}")
|
854
|
+
return
|
855
|
+
|
856
|
+
AuxDataCodec.deserialize_data_to_buffer(
|
857
|
+
self.kv_args, buffer_index, aux_index, data
|
858
|
+
)
|
859
|
+
|
860
|
+
logger.debug(
|
861
|
+
f"Received AUX_DATA for bootstrap_room {room} with length:{len(data)}"
|
862
|
+
)
|
863
|
+
|
779
864
|
def start_decode_thread(self):
|
780
865
|
self.rank_port = get_free_port()
|
781
866
|
self._bind_server_socket()
|
782
867
|
|
783
868
|
def decode_thread():
|
784
869
|
while True:
|
785
|
-
|
786
|
-
|
787
|
-
|
870
|
+
msg = self.server_socket.recv_multipart()
|
871
|
+
if msg[0] == MooncakeKVManager.AUX_DATA_HEADER:
|
872
|
+
self._handle_aux_data(msg)
|
873
|
+
continue
|
874
|
+
|
875
|
+
(bootstrap_room, status, prefill_rank) = msg
|
788
876
|
status = int(status.decode("ascii"))
|
789
877
|
bootstrap_room = int(bootstrap_room.decode("ascii"))
|
790
878
|
prefill_rank = int(prefill_rank.decode("ascii"))
|
@@ -798,10 +886,7 @@ class MooncakeKVManager(BaseKVManager):
|
|
798
886
|
arrived_response_num = len(
|
799
887
|
self.prefill_response_tracker[bootstrap_room]
|
800
888
|
)
|
801
|
-
if
|
802
|
-
self.is_mla_backend
|
803
|
-
or arrived_response_num == expected_response_num
|
804
|
-
):
|
889
|
+
if arrived_response_num == expected_response_num:
|
805
890
|
self.update_status(bootstrap_room, KVPoll.Success)
|
806
891
|
elif status == KVPoll.Failed:
|
807
892
|
self.record_failure(
|
@@ -1183,7 +1268,9 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|
1183
1268
|
self.kv_mgr.kv_args.engine_rank % self.kv_mgr.attn_tp_size
|
1184
1269
|
)
|
1185
1270
|
self.required_dst_info_num = 1
|
1186
|
-
self.required_prefill_response_num = 1
|
1271
|
+
self.required_prefill_response_num = 1 * (
|
1272
|
+
self.prefill_pp_size // self.kv_mgr.pp_size
|
1273
|
+
)
|
1187
1274
|
self.target_tp_ranks = [self.target_tp_rank]
|
1188
1275
|
elif self.kv_mgr.attn_tp_size > self.prefill_attn_tp_size:
|
1189
1276
|
if not self.kv_mgr.is_mla_backend:
|
@@ -1196,7 +1283,9 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|
1196
1283
|
self.required_dst_info_num = (
|
1197
1284
|
self.kv_mgr.attn_tp_size // self.prefill_attn_tp_size
|
1198
1285
|
)
|
1199
|
-
self.required_prefill_response_num = 1
|
1286
|
+
self.required_prefill_response_num = 1 * (
|
1287
|
+
self.prefill_pp_size // self.kv_mgr.pp_size
|
1288
|
+
)
|
1200
1289
|
self.target_tp_ranks = [self.target_tp_rank]
|
1201
1290
|
else:
|
1202
1291
|
if not self.kv_mgr.is_mla_backend:
|
@@ -1219,9 +1308,14 @@ class MooncakeKVReceiver(BaseKVReceiver):
|
|
1219
1308
|
# or the KVPoll will never be set correctly
|
1220
1309
|
self.target_tp_rank = self.target_tp_ranks[0]
|
1221
1310
|
self.required_dst_info_num = 1
|
1222
|
-
self.
|
1223
|
-
self.
|
1224
|
-
|
1311
|
+
if self.kv_mgr.is_mla_backend:
|
1312
|
+
self.required_prefill_response_num = (
|
1313
|
+
self.prefill_pp_size // self.kv_mgr.pp_size
|
1314
|
+
)
|
1315
|
+
else:
|
1316
|
+
self.required_prefill_response_num = (
|
1317
|
+
self.prefill_attn_tp_size // self.kv_mgr.attn_tp_size
|
1318
|
+
) * (self.prefill_pp_size // self.kv_mgr.pp_size)
|
1225
1319
|
|
1226
1320
|
if self.data_parallel_rank is not None:
|
1227
1321
|
logger.debug(f"Targeting DP rank: {self.data_parallel_rank}")
|
@@ -1530,7 +1624,7 @@ class MooncakeKVBootstrapServer(BaseKVBootstrapServer):
|
|
1530
1624
|
"rank_port": rank_port,
|
1531
1625
|
}
|
1532
1626
|
logger.debug(
|
1533
|
-
f"Register prefill bootstrap: DP
|
1627
|
+
f"Register prefill bootstrap: DP{dp_group} TP{attn_tp_rank} PP{pp_rank} with rank_ip: {rank_ip} and rank_port: {rank_port}"
|
1534
1628
|
)
|
1535
1629
|
|
1536
1630
|
return web.Response(text="OK", status=200)
|
@@ -43,8 +43,13 @@ from sglang.srt.disaggregation.utils import (
|
|
43
43
|
prepare_abort,
|
44
44
|
)
|
45
45
|
from sglang.srt.managers.schedule_batch import FINISH_LENGTH, Req, ScheduleBatch
|
46
|
-
from sglang.srt.model_executor.forward_batch_info import ForwardMode
|
47
|
-
from sglang.srt.utils import
|
46
|
+
from sglang.srt.model_executor.forward_batch_info import ForwardMode, PPProxyTensors
|
47
|
+
from sglang.srt.utils import (
|
48
|
+
DynamicGradMode,
|
49
|
+
broadcast_pyobj,
|
50
|
+
point_to_point_pyobj,
|
51
|
+
require_mlp_sync,
|
52
|
+
)
|
48
53
|
|
49
54
|
if TYPE_CHECKING:
|
50
55
|
from torch.distributed import ProcessGroup
|
@@ -107,6 +112,7 @@ class PrefillBootstrapQueue:
|
|
107
112
|
kv_args.system_dp_rank = self.scheduler.dp_rank
|
108
113
|
kv_args.decode_tp_size = self.decode_tp_size // self.decode_dp_size
|
109
114
|
kv_args.prefill_pp_size = self.pp_size
|
115
|
+
kv_args.prefill_start_layer = self.token_to_kv_pool.start_layer
|
110
116
|
kv_data_ptrs, kv_data_lens, kv_item_lens = (
|
111
117
|
self.token_to_kv_pool.get_contiguous_buf_infos()
|
112
118
|
)
|
@@ -172,7 +178,7 @@ class PrefillBootstrapQueue:
|
|
172
178
|
if len(req.origin_input_ids) > self.max_total_num_tokens:
|
173
179
|
message = f"Request {req.rid} exceeds the maximum number of tokens: {len(req.origin_input_ids)} > {self.max_total_num_tokens}"
|
174
180
|
logger.error(message)
|
175
|
-
prepare_abort(req, message)
|
181
|
+
prepare_abort(req, message, status_code=HTTPStatus.BAD_REQUEST)
|
176
182
|
self.scheduler.stream_output([req], req.return_logprob)
|
177
183
|
return True
|
178
184
|
return False
|
@@ -208,8 +214,8 @@ class PrefillBootstrapQueue:
|
|
208
214
|
polls = poll_and_all_reduce(
|
209
215
|
[req.disagg_kv_sender for req in self.queue], self.gloo_group
|
210
216
|
)
|
211
|
-
for i, (req, poll) in enumerate(zip(self.queue, polls)):
|
212
217
|
|
218
|
+
for i, (req, poll) in enumerate(zip(self.queue, polls)):
|
213
219
|
if rids_to_check is not None:
|
214
220
|
# if req not in reqs_info_to_check, skip
|
215
221
|
if req.rid not in rids_to_check:
|
@@ -395,7 +401,10 @@ class SchedulerDisaggregationPrefillMixin:
|
|
395
401
|
req.output_ids.append(next_token_id)
|
396
402
|
self.tree_cache.cache_unfinished_req(req) # update the tree and lock
|
397
403
|
self.disagg_prefill_inflight_queue.append(req)
|
398
|
-
if
|
404
|
+
if (
|
405
|
+
logits_output is not None
|
406
|
+
and logits_output.hidden_states is not None
|
407
|
+
):
|
399
408
|
last_hidden_index = (
|
400
409
|
hidden_state_offset + extend_input_len_per_req[i] - 1
|
401
410
|
)
|
@@ -603,3 +612,250 @@ class SchedulerDisaggregationPrefillMixin:
|
|
603
612
|
)
|
604
613
|
return
|
605
614
|
req.disagg_kv_sender.send(page_indices)
|
615
|
+
|
616
|
+
# PP
|
617
|
+
@DynamicGradMode()
|
618
|
+
def event_loop_pp_disagg_prefill(self: Scheduler):
|
619
|
+
"""
|
620
|
+
An event loop for the prefill server in pipeline parallelism.
|
621
|
+
|
622
|
+
Rules:
|
623
|
+
1. Each stage runs in the same order and is notified by the previous stage.
|
624
|
+
2. Each send/recv operation is blocking and matched by the neighboring stage.
|
625
|
+
|
626
|
+
Regular Schedule:
|
627
|
+
====================================================================
|
628
|
+
Stage i | Stage i+1
|
629
|
+
send ith req | recv ith req
|
630
|
+
send ith proxy | recv ith proxy
|
631
|
+
send prev (i+1)th carry | recv prev (i+1)th carry
|
632
|
+
====================================================================
|
633
|
+
|
634
|
+
Prefill Server Schedule:
|
635
|
+
====================================================================
|
636
|
+
Stage i | Stage i+1
|
637
|
+
send ith req | recv ith req
|
638
|
+
send ith bootstrap req | recv ith bootstrap req
|
639
|
+
send ith transferred req | recv ith transferred req
|
640
|
+
send ith proxy | recv ith proxy
|
641
|
+
send prev (i+1)th carry | recv prev (i+1)th carry
|
642
|
+
send prev (i+1)th release req | recv prev (i+1)th release req
|
643
|
+
====================================================================
|
644
|
+
|
645
|
+
There are two additional elements compared to the regular schedule:
|
646
|
+
|
647
|
+
1. Bootstrap Requests:
|
648
|
+
a. Instead of polling the status on the current workers, we should wait for the previous stage to notify to avoid desynchronization.
|
649
|
+
b. The first stage polls the status and propagates the bootstrapped requests down to all other stages.
|
650
|
+
c. If the first stage polls successfully, by nature, other ranks are also successful because they performed a handshake together.
|
651
|
+
|
652
|
+
2. Transferred Requests + Release Requests:
|
653
|
+
a. The first stage polls the transfer finished requests, performs an intersection with the next stage's finished requests, and propagates down to the last stage.
|
654
|
+
b. The last stage receives the requests that have finished transfer on all stages (consensus), then sends them to the first stage to release the memory.
|
655
|
+
c. The first stage receives the release requests, releases the memory, and then propagates the release requests down to the last stage.
|
656
|
+
"""
|
657
|
+
from sglang.srt.managers.scheduler import GenerationBatchResult
|
658
|
+
|
659
|
+
mbs = [None] * self.pp_size
|
660
|
+
last_mbs = [None] * self.pp_size
|
661
|
+
self.running_mbs = [
|
662
|
+
ScheduleBatch(reqs=[], batch_is_full=False) for _ in range(self.pp_size)
|
663
|
+
]
|
664
|
+
bids = [None] * self.pp_size
|
665
|
+
pp_outputs: Optional[PPProxyTensors] = None
|
666
|
+
|
667
|
+
# Either success or failed
|
668
|
+
bootstrapped_rids: List[str] = []
|
669
|
+
transferred_rids: List[str] = []
|
670
|
+
release_rids: Optional[List[str]] = None
|
671
|
+
|
672
|
+
# transferred microbatch
|
673
|
+
tmbs = [None] * self.pp_size
|
674
|
+
|
675
|
+
ENABLE_RELEASE = True # For debug
|
676
|
+
|
677
|
+
while True:
|
678
|
+
server_is_idle = True
|
679
|
+
|
680
|
+
for mb_id in range(self.pp_size):
|
681
|
+
self.running_batch = self.running_mbs[mb_id]
|
682
|
+
self.last_batch = last_mbs[mb_id]
|
683
|
+
|
684
|
+
recv_reqs = self.recv_requests()
|
685
|
+
|
686
|
+
self.process_input_requests(recv_reqs)
|
687
|
+
|
688
|
+
if self.pp_group.is_first_rank:
|
689
|
+
# First rank, pop the bootstrap reqs from the bootstrap queue
|
690
|
+
bootstrapped_reqs, failed_reqs = (
|
691
|
+
self.disagg_prefill_bootstrap_queue.pop_bootstrapped(
|
692
|
+
return_failed_reqs=True
|
693
|
+
)
|
694
|
+
)
|
695
|
+
bootstrapped_rids = [req.rid for req in bootstrapped_reqs] + [
|
696
|
+
req.rid for req in failed_reqs
|
697
|
+
]
|
698
|
+
self.waiting_queue.extend(bootstrapped_reqs)
|
699
|
+
else:
|
700
|
+
# Other ranks, receive the bootstrap reqs info from the previous rank and ensure the consensus
|
701
|
+
bootstrapped_rids = self.recv_pyobj_from_prev_stage()
|
702
|
+
bootstrapped_reqs = (
|
703
|
+
self.disagg_prefill_bootstrap_queue.pop_bootstrapped(
|
704
|
+
rids_to_check=bootstrapped_rids
|
705
|
+
)
|
706
|
+
)
|
707
|
+
self.waiting_queue.extend(bootstrapped_reqs)
|
708
|
+
|
709
|
+
if self.pp_group.is_first_rank:
|
710
|
+
transferred_rids = self.get_transferred_rids()
|
711
|
+
# if other ranks,
|
712
|
+
else:
|
713
|
+
# 1. recv previous stage's transferred reqs info
|
714
|
+
prev_transferred_rids = self.recv_pyobj_from_prev_stage()
|
715
|
+
# 2. get the current stage's transferred reqs info
|
716
|
+
curr_transferred_rids = self.get_transferred_rids()
|
717
|
+
# 3. new consensus rids = intersection(previous consensus rids, transfer finished rids)
|
718
|
+
transferred_rids = list(
|
719
|
+
set(prev_transferred_rids) & set(curr_transferred_rids)
|
720
|
+
)
|
721
|
+
|
722
|
+
tmbs[mb_id] = transferred_rids
|
723
|
+
|
724
|
+
self.process_prefill_chunk()
|
725
|
+
mbs[mb_id] = self.get_new_batch_prefill()
|
726
|
+
self.running_mbs[mb_id] = self.running_batch
|
727
|
+
|
728
|
+
self.cur_batch = mbs[mb_id]
|
729
|
+
if self.cur_batch:
|
730
|
+
server_is_idle = False
|
731
|
+
result = self.run_batch(self.cur_batch)
|
732
|
+
|
733
|
+
# send the outputs to the next step
|
734
|
+
if self.pp_group.is_last_rank:
|
735
|
+
if self.cur_batch:
|
736
|
+
next_token_ids, bids[mb_id] = (
|
737
|
+
result.next_token_ids,
|
738
|
+
result.bid,
|
739
|
+
)
|
740
|
+
pp_outputs = PPProxyTensors(
|
741
|
+
{
|
742
|
+
"next_token_ids": next_token_ids,
|
743
|
+
}
|
744
|
+
)
|
745
|
+
# send the output from the last round to let the next stage worker run post processing
|
746
|
+
self.pp_group.send_tensor_dict(
|
747
|
+
pp_outputs.tensors,
|
748
|
+
all_gather_group=self.attn_tp_group,
|
749
|
+
)
|
750
|
+
|
751
|
+
if ENABLE_RELEASE:
|
752
|
+
if self.pp_group.is_last_rank:
|
753
|
+
# At the last stage, all stages have reached consensus to release memory for transferred_rids
|
754
|
+
release_rids = transferred_rids
|
755
|
+
# send to the first rank
|
756
|
+
self.send_pyobj_to_next_stage(release_rids)
|
757
|
+
|
758
|
+
# receive outputs and post-process (filter finished reqs) the coming microbatch
|
759
|
+
next_mb_id = (mb_id + 1) % self.pp_size
|
760
|
+
next_pp_outputs = None
|
761
|
+
next_release_rids = None
|
762
|
+
|
763
|
+
if mbs[next_mb_id] is not None:
|
764
|
+
next_pp_outputs: Optional[PPProxyTensors] = PPProxyTensors(
|
765
|
+
self.pp_group.recv_tensor_dict(
|
766
|
+
all_gather_group=self.attn_tp_group
|
767
|
+
)
|
768
|
+
)
|
769
|
+
mbs[next_mb_id].output_ids = next_pp_outputs["next_token_ids"]
|
770
|
+
output_result = GenerationBatchResult(
|
771
|
+
logits_output=None,
|
772
|
+
pp_hidden_states_proxy_tensors=None,
|
773
|
+
next_token_ids=next_pp_outputs["next_token_ids"],
|
774
|
+
extend_input_len_per_req=None,
|
775
|
+
extend_logprob_start_len_per_req=None,
|
776
|
+
bid=bids[next_mb_id],
|
777
|
+
can_run_cuda_graph=result.can_run_cuda_graph,
|
778
|
+
)
|
779
|
+
self.process_batch_result_disagg_prefill(
|
780
|
+
mbs[next_mb_id], output_result
|
781
|
+
)
|
782
|
+
|
783
|
+
last_mbs[next_mb_id] = mbs[next_mb_id]
|
784
|
+
|
785
|
+
if ENABLE_RELEASE:
|
786
|
+
if tmbs[next_mb_id] is not None:
|
787
|
+
# recv consensus rids from the previous rank
|
788
|
+
next_release_rids = self.recv_pyobj_from_prev_stage()
|
789
|
+
self.process_disagg_prefill_inflight_queue(next_release_rids)
|
790
|
+
|
791
|
+
# carry the outputs to the next stage
|
792
|
+
if not self.pp_group.is_last_rank:
|
793
|
+
if self.cur_batch:
|
794
|
+
bids[mb_id] = result.bid
|
795
|
+
if pp_outputs:
|
796
|
+
# send the outputs from the last round to let the next stage worker run post processing
|
797
|
+
self.pp_group.send_tensor_dict(
|
798
|
+
pp_outputs.tensors,
|
799
|
+
all_gather_group=self.attn_tp_group,
|
800
|
+
)
|
801
|
+
if ENABLE_RELEASE:
|
802
|
+
if release_rids is not None:
|
803
|
+
self.send_pyobj_to_next_stage(release_rids)
|
804
|
+
|
805
|
+
if not self.pp_group.is_last_rank:
|
806
|
+
# send out reqs to the next stage
|
807
|
+
self.send_pyobj_to_next_stage(recv_reqs)
|
808
|
+
self.send_pyobj_to_next_stage(bootstrapped_rids)
|
809
|
+
self.send_pyobj_to_next_stage(transferred_rids)
|
810
|
+
|
811
|
+
# send out proxy tensors to the next stage
|
812
|
+
if self.cur_batch:
|
813
|
+
self.pp_group.send_tensor_dict(
|
814
|
+
result.pp_hidden_states_proxy_tensors,
|
815
|
+
all_gather_group=self.attn_tp_group,
|
816
|
+
)
|
817
|
+
|
818
|
+
pp_outputs = next_pp_outputs
|
819
|
+
release_rids = next_release_rids
|
820
|
+
|
821
|
+
self.running_batch.batch_is_full = False
|
822
|
+
|
823
|
+
if not ENABLE_RELEASE:
|
824
|
+
if len(self.disagg_prefill_inflight_queue) > 0:
|
825
|
+
self.process_disagg_prefill_inflight_queue()
|
826
|
+
|
827
|
+
# When the server is idle, self-check and re-init some states
|
828
|
+
if server_is_idle and len(self.disagg_prefill_inflight_queue) == 0:
|
829
|
+
self.check_memory()
|
830
|
+
self.check_tree_cache()
|
831
|
+
self.new_token_ratio = self.init_new_token_ratio
|
832
|
+
|
833
|
+
def send_pyobj_to_next_stage(self, data):
    """Send a picklable object to the same DP slot on the next PP stage.

    Only the first rank of each attention-TP group performs the actual
    point-to-point send; the matching receive side re-broadcasts the object
    within its own TP group (see ``recv_pyobj_from_prev_stage``).
    """
    # Non-leader TP ranks do not participate in the P2P exchange.
    if self.attn_tp_rank != 0:
        return

    dp_offset = self.attn_dp_rank * self.attn_tp_size
    src_rank = self.pp_rank * self.tp_size + dp_offset
    dst_rank = ((self.pp_rank + 1) % self.pp_size) * self.tp_size + dp_offset
    point_to_point_pyobj(
        data,
        src_rank,  # NOTE(review): presumably a message tag — confirm against point_to_point_pyobj
        self.world_group.device_group,
        src_rank,
        dst_rank,
    )
|
843
|
+
|
844
|
+
def recv_pyobj_from_prev_stage(self):
    """Receive a picklable object sent by the previous PP stage.

    Mirrors ``send_pyobj_to_next_stage``: only the first rank of each
    attention-TP group performs the point-to-point receive, then the object
    is broadcast to the remaining TP ranks so every rank returns the same
    value.
    """
    if self.attn_tp_rank == 0:
        dp_offset = self.attn_dp_rank * self.attn_tp_size
        self_rank = self.pp_rank * self.tp_size + dp_offset
        prev_rank = ((self.pp_rank - 1) % self.pp_size) * self.tp_size + dp_offset
        data = point_to_point_pyobj(
            [],
            self_rank,  # NOTE(review): presumably a message tag — confirm against point_to_point_pyobj
            self.world_group.device_group,
            prev_rank,
            self_rank,
        )
    else:
        # Filled in by the TP-group broadcast below.
        data = None

    # Fan the received object out to the rest of the TP group. With
    # tp_size == 1 the leader already holds the data and no broadcast runs.
    if self.tp_size != 1:
        data = broadcast_pyobj(
            data, self.tp_group.rank, self.tp_cpu_group, src=self.tp_group.ranks[0]
        )
    return data
|
@@ -99,7 +99,8 @@ class MetadataBuffers:
|
|
99
99
|
# For ascend backend, output tokens are placed in the NPU and will be transferred by D2D channel.
|
100
100
|
device = "npu"
|
101
101
|
elif self.custom_mem_pool:
|
102
|
-
|
102
|
+
# TODO(shangming): Fix me (use 'cuda') when nvlink_transport of Mooncake is bug-free
|
103
|
+
device = "cpu"
|
103
104
|
with (
|
104
105
|
torch.cuda.use_mem_pool(self.custom_mem_pool)
|
105
106
|
if self.custom_mem_pool
|
@@ -398,7 +398,7 @@ class CustomAllreduce:
|
|
398
398
|
else:
|
399
399
|
# If warm up, mimic the allocation pattern since custom
|
400
400
|
# allreduce is out-of-place.
|
401
|
-
return torch.
|
401
|
+
return torch.zeros_like(input)
|
402
402
|
else:
|
403
403
|
if _is_hip:
|
404
404
|
# note: outside of cuda graph context,
|