sglang 0.5.1.post2__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +89 -54
- sglang/bench_serving.py +437 -40
- sglang/lang/interpreter.py +1 -1
- sglang/profiler.py +0 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +37 -7
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +6 -4
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -420
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +6 -4
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +94 -58
- sglang/srt/entrypoints/engine.py +34 -14
- sglang/srt/entrypoints/http_server.py +172 -47
- sglang/srt/entrypoints/openai/protocol.py +90 -27
- sglang/srt/entrypoints/openai/serving_base.py +6 -2
- sglang/srt/entrypoints/openai/serving_chat.py +82 -26
- sglang/srt/entrypoints/openai/serving_completions.py +25 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/eplb/eplb_manager.py +28 -4
- sglang/srt/eplb/expert_distribution.py +55 -15
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/deepseekv31_detector.py +222 -0
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +144 -256
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +28 -7
- sglang/srt/layers/activation.py +44 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +381 -136
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +241 -7
- sglang/srt/layers/attention/flashinfer_backend.py +11 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -14
- sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -8
- sglang/srt/layers/layernorm.py +54 -12
- sglang/srt/layers/logits_processor.py +10 -3
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_moe.py +0 -8
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +111 -56
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json } +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +43 -12
- sglang/srt/layers/moe/utils.py +6 -5
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +141 -235
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +31 -22
- sglang/srt/layers/quantization/fp8.py +78 -48
- sglang/srt/layers/quantization/fp8_kernel.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +45 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +107 -40
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +93 -68
- sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +60 -42
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +83 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +28 -19
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/layers/utils.py +0 -14
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/cache_controller.py +396 -365
- sglang/srt/managers/data_parallel_controller.py +30 -15
- sglang/srt/managers/detokenizer_manager.py +18 -2
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +190 -11
- sglang/srt/managers/mm_utils.py +6 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
- sglang/srt/managers/schedule_batch.py +27 -44
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +148 -122
- sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
- sglang/srt/managers/tokenizer_manager.py +77 -480
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +53 -40
- sglang/srt/mem_cache/hiradix_cache.py +196 -104
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +395 -53
- sglang/srt/mem_cache/memory_pool_host.py +27 -19
- sglang/srt/mem_cache/radix_cache.py +6 -6
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +152 -23
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +154 -95
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1 -3
- sglang/srt/metrics/collector.py +484 -63
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +72 -18
- sglang/srt/model_executor/model_runner.py +190 -32
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +33 -28
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/deepseek_v2.py +323 -53
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/gpt_oss.py +7 -19
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +17 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +33 -3
- sglang/srt/models/qwen2_5_vl.py +91 -42
- sglang/srt/models/qwen2_moe.py +79 -14
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/qwen3_next.py +1039 -0
- sglang/srt/models/qwen3_next_mtp.py +109 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/{conversation.py → parser/conversation.py} +38 -5
- sglang/srt/parser/harmony_parser.py +588 -0
- sglang/srt/parser/reasoning_parser.py +309 -0
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +307 -80
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_worker.py +216 -120
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
- sglang/srt/utils.py +96 -7
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +181 -8
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_utils.py +25 -1
- sglang/utils.py +5 -0
- sglang/version.py +1 -1
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/METADATA +13 -10
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/RECORD +253 -201
- sglang/srt/disaggregation/launch_lb.py +0 -131
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- sglang/srt/reasoning_parser.py +0 -553
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
@@ -1,11 +1,11 @@
|
|
1
1
|
import argparse
|
2
2
|
import functools
|
3
|
-
import re
|
4
3
|
from pathlib import Path
|
5
4
|
|
6
5
|
import polars as pl
|
7
6
|
import torch
|
8
7
|
|
8
|
+
from sglang.srt.debug_utils.dump_loader import find_row, read_meta
|
9
9
|
from sglang.srt.debug_utils.dumper import get_truncated_value
|
10
10
|
|
11
11
|
|
@@ -26,66 +26,77 @@ def main(args):
|
|
26
26
|
print("df_baseline", df_baseline)
|
27
27
|
|
28
28
|
for row in df_target.iter_rows(named=True):
|
29
|
-
|
30
|
-
|
31
|
-
|
32
|
-
|
33
|
-
|
34
|
-
|
35
|
-
|
36
|
-
|
37
|
-
|
38
|
-
|
39
|
-
|
40
|
-
|
41
|
-
|
29
|
+
path_target = Path(args.target_path) / row["filename"]
|
30
|
+
|
31
|
+
row_baseline = find_row(
|
32
|
+
df_baseline,
|
33
|
+
conditions=dict(
|
34
|
+
forward_pass_id=row["forward_pass_id"]
|
35
|
+
- args.start_id
|
36
|
+
+ args.baseline_start_id,
|
37
|
+
**{
|
38
|
+
k: v
|
39
|
+
for k, v in row.items()
|
40
|
+
if k not in ["forward_pass_id", "dump_index", "filename"]
|
41
|
+
},
|
42
|
+
),
|
42
43
|
)
|
43
|
-
|
44
|
-
row_baseline
|
44
|
+
|
45
|
+
if row_baseline is None:
|
46
|
+
print(f"Skip: target={str(path_target)} since no baseline")
|
47
|
+
x_target = _load_object(path_target)
|
48
|
+
if x_target is not None:
|
49
|
+
print(f"x_target(sample)={get_truncated_value(x_target)}")
|
50
|
+
continue
|
45
51
|
|
46
52
|
path_baseline = Path(args.baseline_path) / row_baseline["filename"]
|
47
|
-
path_target = Path(args.target_path) / row["filename"]
|
48
53
|
print(f"Check: target={str(path_target)} baseline={str(path_baseline)}")
|
49
|
-
check_tensor_pair(
|
54
|
+
check_tensor_pair(
|
55
|
+
path_baseline=path_baseline, path_target=path_target, name=row["name"]
|
56
|
+
)
|
50
57
|
print()
|
51
58
|
|
52
59
|
|
53
|
-
def
|
54
|
-
|
55
|
-
|
56
|
-
|
57
|
-
rows = []
|
58
|
-
for p in directory.glob("*.pt"):
|
59
|
-
full_kwargs = {}
|
60
|
-
for kv in p.stem.split("___"):
|
61
|
-
k, v = kv.split("=")
|
62
|
-
full_kwargs[k] = v
|
63
|
-
rows.append(
|
64
|
-
{
|
65
|
-
"filename": str(p.name),
|
66
|
-
**full_kwargs,
|
67
|
-
}
|
68
|
-
)
|
60
|
+
def check_tensor_pair(path_baseline, path_target, name=""):
|
61
|
+
x_baseline = _load_object(path_baseline)
|
62
|
+
x_target = _load_object(path_target)
|
69
63
|
|
70
|
-
|
71
|
-
|
72
|
-
|
73
|
-
|
64
|
+
print(
|
65
|
+
f"Raw "
|
66
|
+
f"[shape] {x_baseline.shape} vs {x_target.shape}\t"
|
67
|
+
f"[dtype] {x_baseline.dtype} vs {x_target.dtype}"
|
74
68
|
)
|
75
|
-
return df
|
76
|
-
|
77
69
|
|
78
|
-
|
79
|
-
x_baseline =
|
80
|
-
x_target = torch.load(path_target, weights_only=True)
|
70
|
+
x_baseline, x_target = _comparison_preprocessor(x_baseline, x_target, name=name)
|
71
|
+
x_baseline = _try_unify_shape(x_baseline, target_shape=x_target.shape)
|
81
72
|
|
82
73
|
print(
|
74
|
+
f"After preprocessor "
|
83
75
|
f"[shape] {x_baseline.shape} vs {x_target.shape}\t"
|
84
76
|
f"[dtype] {x_baseline.dtype} vs {x_target.dtype}"
|
85
77
|
)
|
86
78
|
|
79
|
+
x_target = x_target.float()
|
80
|
+
x_baseline = x_baseline.float()
|
81
|
+
|
82
|
+
for name, fn in (
|
83
|
+
("mean", torch.mean),
|
84
|
+
("std", torch.std),
|
85
|
+
("min", torch.min),
|
86
|
+
("max", torch.max),
|
87
|
+
("p1", functools.partial(torch.quantile, q=0.01)),
|
88
|
+
("p5", functools.partial(torch.quantile, q=0.05)),
|
89
|
+
("p95", functools.partial(torch.quantile, q=0.95)),
|
90
|
+
("p99", functools.partial(torch.quantile, q=0.99)),
|
91
|
+
):
|
92
|
+
value_baseline = fn(x_baseline).item()
|
93
|
+
value_target = fn(x_target).item()
|
94
|
+
print(
|
95
|
+
f"[{name}] {value_baseline :.4f} vs {value_target:.4f} (diff: {value_target - value_baseline:.4f})"
|
96
|
+
)
|
97
|
+
|
87
98
|
if x_baseline.shape != x_target.shape:
|
88
|
-
print(f"
|
99
|
+
print(f"⚠️ Shape mismatch")
|
89
100
|
return
|
90
101
|
|
91
102
|
raw_abs_diff = (x_target - x_baseline).abs()
|
@@ -112,6 +123,19 @@ def check_tensor_pair(path_baseline, path_target):
|
|
112
123
|
print(f"x_target(sample)={get_truncated_value(x_target)}")
|
113
124
|
|
114
125
|
|
126
|
+
def _try_unify_shape(x: torch.Tensor, target_shape):
|
127
|
+
x_shape = x.shape
|
128
|
+
num_dim_to_remove = len(x_shape) - len(target_shape)
|
129
|
+
if (x_shape[num_dim_to_remove:] == target_shape) and all(
|
130
|
+
val == 1 for val in x_shape[:num_dim_to_remove]
|
131
|
+
):
|
132
|
+
out = functools.reduce(lambda a, _: a.squeeze(0), range(num_dim_to_remove), x)
|
133
|
+
print(f"Unify shape: {x_shape} -> {out.shape} (to match {target_shape})")
|
134
|
+
return out
|
135
|
+
|
136
|
+
return x
|
137
|
+
|
138
|
+
|
115
139
|
# Copied from DeepGEMM
|
116
140
|
def _calc_rel_diff(x: torch.Tensor, y: torch.Tensor):
|
117
141
|
x, y = x.double(), y.double()
|
@@ -120,6 +144,19 @@ def _calc_rel_diff(x: torch.Tensor, y: torch.Tensor):
|
|
120
144
|
return 1 - sim
|
121
145
|
|
122
146
|
|
147
|
+
def _comparison_preprocessor(x_baseline, x_target, name):
|
148
|
+
# can insert arbitrary adhoc postprocessing logic here
|
149
|
+
return x_baseline, x_target
|
150
|
+
|
151
|
+
|
152
|
+
def _load_object(path):
|
153
|
+
x = torch.load(path, weights_only=False)
|
154
|
+
if not isinstance(x, torch.Tensor):
|
155
|
+
print(f"Skip load {path} since {type(x)=} is not a Tensor")
|
156
|
+
return None
|
157
|
+
return x.cuda()
|
158
|
+
|
159
|
+
|
123
160
|
if __name__ == "__main__":
|
124
161
|
parser = argparse.ArgumentParser()
|
125
162
|
parser.add_argument("--baseline-path", type=str)
|
@@ -0,0 +1,97 @@
|
|
1
|
+
import functools
|
2
|
+
import os
|
3
|
+
from pathlib import Path
|
4
|
+
from typing import Any, Dict
|
5
|
+
|
6
|
+
import polars as pl
|
7
|
+
import torch
|
8
|
+
|
9
|
+
|
10
|
+
class DumpLoader:
    """Look up and load objects written by the dumper, keyed by dump metadata.

    The loader is active only when the ``SGLANG_DUMP_LOADER_DIR`` environment
    variable is set at construction time. When active, the metadata of the
    dump directory is indexed once via ``read_meta`` and queried by ``load``.
    """

    def __init__(self):
        dump_dir = os.environ.get("SGLANG_DUMP_LOADER_DIR")
        self._enable = dump_dir is not None
        if not self._enable:
            # Disabled: no directory / metadata attributes are created.
            return
        self._directory = Path(dump_dir)
        self._df = read_meta(dump_dir)

    @property
    def enable(self):
        """True iff a dump directory was configured via the environment."""
        return self._enable

    def load(self, name, **kwargs):
        """Load the object dumped as ``name`` during the dumper's current forward pass.

        Args:
            name: Value required in the ``name`` metadata column.
            **kwargs: Additional metadata column filters.

        Returns:
            Whatever ``torch.load`` returns for the matching dump file.

        Raises:
            AssertionError: if the loader is disabled or no matching dump exists.
        """
        assert self._enable, "Please call DumpLoader.load only when it is enabled"

        # Imported lazily inside the method; the forward-pass counter is read
        # from the dumper singleton.
        from sglang.srt.debug_utils.dumper import dumper

        conditions = dict(name=name, forward_pass_id=dumper._forward_pass_id, **kwargs)
        row = find_row(self._df, conditions=conditions)
        assert (
            row is not None
        ), f"DumpLoader cannot find row given query {name=} {kwargs=} {self._directory=}"

        path = self._directory / row["filename"]
        # weights_only=False: the dump may contain arbitrary pickled objects.
        output = torch.load(path, weights_only=False)
        print(
            f"[DumpLoader] load from {path=} (query: {name=} {kwargs=}, output: {type(output)})"
        )
        return output
|
42
|
+
|
43
|
+
|
44
|
+
def read_meta(directory):
    """Index the ``*.pt`` dump files in ``directory`` into a polars DataFrame.

    Each file stem is parsed as ``___``-separated ``key=value`` pairs; every key
    becomes a column and the file name is kept in a ``filename`` column. The
    ``forward_pass_id``, ``rank`` and ``dump_index`` columns are cast to ints,
    so every file name must provide those keys.

    Args:
        directory: Path (str or ``Path``) of the dump directory.

    Returns:
        A polars DataFrame with one row per ``*.pt`` file.

    Raises:
        AssertionError: if ``directory`` is not a directory or has no ``*.pt`` files.
    """
    directory = Path(directory)
    assert directory.is_dir(), f"{directory=} should be a directory"

    rows = []
    for p in directory.glob("*.pt"):
        full_kwargs = {}
        for kv in p.stem.split("___"):
            # Split on the first "=" only, so values that themselves contain
            # "=" do not break the key/value unpacking.
            k, v = kv.split("=", 1)
            full_kwargs[k] = v
        rows.append(
            {
                "filename": str(p.name),
                **full_kwargs,
            }
        )

    # Without any rows the integer casts below would fail with an obscure
    # missing-column error; report the real problem instead.
    assert rows, f"No *.pt dump files found in {directory=}"

    df = pl.DataFrame(rows)
    df = df.with_columns(
        pl.col("forward_pass_id").cast(int),
        pl.col("rank").cast(int),
        pl.col("dump_index").cast(int),
    )
    return df
|
68
|
+
|
69
|
+
|
70
|
+
def find_row(df, conditions: Dict[str, Any]):
    """Return the single row of ``df`` whose columns equal every value in ``conditions``.

    Each condition value is first coerced to the column's polars dtype with
    ``_cast_to_polars_dtype``, so e.g. ``"3"`` can match an integer column.

    Args:
        df: polars DataFrame, e.g. produced by ``read_meta``.
        conditions: Mapping of column name to required value. An empty
            mapping places no constraint, so every row matches.

    Returns:
        The matching row as a dict, or ``None`` if no row matches.

    Raises:
        AssertionError: if more than one row matches.
        A condition naming a column absent from ``df`` fails at the
        ``df.schema`` lookup.
    """
    predicate = functools.reduce(
        lambda a, b: a & b,
        [
            pl.col(col) == _cast_to_polars_dtype(value, df.schema[col])
            for col, value in conditions.items()
        ],
        # Seed the fold so an empty ``conditions`` does not make reduce() raise.
        pl.lit(True),
    )
    df_sub = df.filter(predicate)
    assert (
        len(df_sub) <= 1
    ), f"{len(df_sub)} rows match {conditions=}; expected at most one"
    return df_sub.to_dicts()[0] if len(df_sub) > 0 else None
|
82
|
+
|
83
|
+
|
84
|
+
def _cast_to_polars_dtype(value, target_dtype):
    """Coerce ``value`` to the Python type that corresponds to a polars dtype.

    Handles 32/64-bit signed and unsigned ints, 32/64-bit floats, Boolean and
    String. Any other dtype leaves ``value`` untouched.

    Args:
        value: Raw value, e.g. a string parsed from a dump file name.
        target_dtype: polars dtype of the column being compared against.

    Returns:
        ``value`` converted with ``int``/``float``/``bool``/``str``, or as-is.
    """
    if target_dtype in (pl.Int64, pl.Int32, pl.UInt64, pl.UInt32):
        converter = int
    elif target_dtype in (pl.Float64, pl.Float32):
        converter = float
    elif target_dtype == pl.Boolean:
        converter = bool
    elif target_dtype == pl.String:
        converter = str
    else:
        return value
    return converter(value)
|
95
|
+
|
96
|
+
|
97
|
+
# Module-level singleton. Whether it is enabled is decided once, when this
# module is imported, from the SGLANG_DUMP_LOADER_DIR environment variable
# (read in DumpLoader.__init__).
dump_loader = DumpLoader()
|
sglang/srt/debug_utils/dumper.py
CHANGED
@@ -53,7 +53,7 @@ class _Dumper:
|
|
53
53
|
if self._partial_name is None:
|
54
54
|
self._partial_name = _get_partial_name()
|
55
55
|
|
56
|
-
rank =
|
56
|
+
rank = _get_rank()
|
57
57
|
full_kwargs = dict(
|
58
58
|
forward_pass_id=self._forward_pass_id,
|
59
59
|
rank=rank,
|
@@ -80,12 +80,20 @@ class _Dumper:
|
|
80
80
|
|
81
81
|
|
82
82
|
def _get_partial_name():
|
83
|
-
rank =
|
83
|
+
rank = _get_rank()
|
84
84
|
object_list = [str(time.time()) if rank == 0 else None]
|
85
|
-
dist.
|
85
|
+
if dist.is_initialized():
|
86
|
+
dist.broadcast_object_list(object_list, device="cuda")
|
86
87
|
return object_list[0]
|
87
88
|
|
88
89
|
|
90
|
+
def _get_rank():
|
91
|
+
if dist.is_initialized():
|
92
|
+
return dist.get_rank()
|
93
|
+
else:
|
94
|
+
return 0
|
95
|
+
|
96
|
+
|
89
97
|
def get_truncated_value(value):
|
90
98
|
if value is None:
|
91
99
|
return None
|
@@ -1,4 +1,5 @@
|
|
1
1
|
import argparse
|
2
|
+
import hashlib
|
2
3
|
import json
|
3
4
|
from pathlib import Path
|
4
5
|
|
@@ -13,7 +14,11 @@ Supported inputs:
|
|
13
14
|
|
14
15
|
|
15
16
|
def main(args):
|
16
|
-
|
17
|
+
if args.data_type == "simple_evals":
|
18
|
+
df_input = _compute_df_input_mode_simple_evals(args)
|
19
|
+
else:
|
20
|
+
df_input = _transform_df_input(_compute_df_raw(args))
|
21
|
+
|
17
22
|
assert all(
|
18
23
|
c in df_input.columns
|
19
24
|
for c in ["category", "trial_index", "prompt_id", "prompt", "output", "correct"]
|
@@ -37,8 +42,9 @@ def main(args):
|
|
37
42
|
df_meta=df_meta.to_dicts(),
|
38
43
|
df_good_to_bad=df_good_to_bad.to_dicts(),
|
39
44
|
df_bad_to_good=df_bad_to_good.to_dicts(),
|
40
|
-
)
|
41
|
-
|
45
|
+
),
|
46
|
+
indent=4,
|
47
|
+
),
|
42
48
|
)
|
43
49
|
|
44
50
|
if not args.disable_print_details:
|
@@ -65,19 +71,70 @@ def main(args):
|
|
65
71
|
print(df)
|
66
72
|
|
67
73
|
|
74
|
+
def _compute_df_input_mode_simple_evals(args):
    """Load every baseline and target simple_evals result file into one DataFrame.

    Args:
        args: Parsed CLI arguments providing ``baseline_path`` and ``target_path``.

    Returns:
        The concatenation of the per-file frames built by
        ``_compute_df_input_one_mode_simple_evals``.
    """
    frames = [
        _compute_df_input_one_mode_simple_evals(**file_info)
        for file_info in _get_file_infos(args=args)
    ]
    return pl.concat(frames)
|
81
|
+
|
82
|
+
|
83
|
+
def _compute_df_input_one_mode_simple_evals(path, category, trial_index):
    """Convert one simple_evals JSON result file into a per-example DataFrame.

    Reads ``data["metadata"]["single_eval_results"]``. For each result, the
    prompt is taken from ``example_level_metadata.actual_queried_prompt_messages``
    and the output from ``example_level_metadata.response_text``. The score is
    required to be 0.0 or 1.0.

    Args:
        path: Path of the simple_evals JSON file.
        category: Label for every row (e.g. "baseline" or "target").
        trial_index: Index of this file within its category.

    Returns:
        A polars DataFrame with columns category, trial_index, prompt_id,
        prompt (JSON string), output and correct (bool).
    """
    data = json.loads(Path(path).read_text())

    def _to_row(single_eval_result):
        example_meta = single_eval_result["example_level_metadata"]
        prompt = example_meta["actual_queried_prompt_messages"]
        score = single_eval_result["score"]
        assert score in {0.0, 1.0}, f"{score=}"
        return dict(
            category=category,
            trial_index=trial_index,
            prompt_id=_compute_id_from_object(prompt),
            prompt=json.dumps(prompt),
            output=example_meta["response_text"],
            correct=score == 1.0,
        )

    rows = [_to_row(result) for result in data["metadata"]["single_eval_results"]]
    return pl.DataFrame(rows)
|
105
|
+
|
106
|
+
|
107
|
+
def _compute_id_from_object(obj):
    """Return a stable SHA-256 hex digest identifying a JSON-serializable object.

    A polars Series is converted to a list first. Keys are sorted before
    hashing, so dicts with the same content but different key order get the
    same id.
    """
    payload = obj.to_list() if isinstance(obj, pl.Series) else obj
    canonical = json.dumps(payload, sort_keys=True, ensure_ascii=False)
    digest = hashlib.sha256(canonical.encode("utf-8"))
    return digest.hexdigest()
|
112
|
+
|
113
|
+
|
68
114
|
def _compute_df_raw(args):
|
69
115
|
return pl.concat(
|
70
116
|
[
|
71
|
-
_read_df_raw(
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
for
|
117
|
+
_read_df_raw(
|
118
|
+
path=info["path"],
|
119
|
+
category=info["category"],
|
120
|
+
trial_index=info["trial_index"],
|
121
|
+
)
|
122
|
+
for info in _get_file_infos(args=args)
|
77
123
|
]
|
78
124
|
)
|
79
125
|
|
80
126
|
|
127
|
+
def _get_file_infos(args):
|
128
|
+
return [
|
129
|
+
dict(path=path, category=category, trial_index=trial_index)
|
130
|
+
for category, paths in [
|
131
|
+
("baseline", args.baseline_path),
|
132
|
+
("target", args.target_path),
|
133
|
+
]
|
134
|
+
for trial_index, path in enumerate(paths)
|
135
|
+
]
|
136
|
+
|
137
|
+
|
81
138
|
def _read_df_raw(path: str, category: str, trial_index: int):
|
82
139
|
return pl.read_ndjson(path).with_columns(
|
83
140
|
category=pl.lit(category), trial_index=trial_index
|
@@ -108,7 +165,9 @@ def _transform_df_input(df: pl.DataFrame):
|
|
108
165
|
print("Transform mode: SGLang bench")
|
109
166
|
return df
|
110
167
|
else:
|
111
|
-
raise Exception(
|
168
|
+
raise Exception(
|
169
|
+
f"Unknown data: {df.columns}. You may need to set `--data-type` if using e.g. simple_evals."
|
170
|
+
)
|
112
171
|
|
113
172
|
|
114
173
|
def _compute_df_meta(df_input: pl.DataFrame):
|
@@ -127,7 +186,9 @@ def _compute_df_meta(df_input: pl.DataFrame):
|
|
127
186
|
|
128
187
|
|
129
188
|
def _handle_one_prompt(df_one_prompt: pl.DataFrame):
|
130
|
-
assert
|
189
|
+
assert (
|
190
|
+
len(set(_compute_id_from_object(obj) for obj in df_one_prompt["prompt"])) == 1
|
191
|
+
)
|
131
192
|
|
132
193
|
df_baseline = df_one_prompt.filter(pl.col("category") == "baseline")
|
133
194
|
df_target = df_one_prompt.filter(pl.col("category") == "target")
|
@@ -162,6 +223,7 @@ def _compute_str_prefix_len(a: str, b: str) -> int:
|
|
162
223
|
|
163
224
|
if __name__ == "__main__":
|
164
225
|
parser = argparse.ArgumentParser(description=_DESCRIPTION)
|
226
|
+
parser.add_argument("--data-type", type=str, default="auto")
|
165
227
|
parser.add_argument("--baseline-path", type=str, nargs="+")
|
166
228
|
parser.add_argument("--target-path", type=str, nargs="+")
|
167
229
|
parser.add_argument(
|
@@ -1,6 +1,12 @@
|
|
1
|
+
import concurrent.futures
|
1
2
|
import logging
|
3
|
+
from typing import List, Tuple
|
4
|
+
|
5
|
+
import numpy as np
|
6
|
+
import numpy.typing as npt
|
2
7
|
|
3
8
|
from sglang.srt.disaggregation.ascend.transfer_engine import AscendTransferEngine
|
9
|
+
from sglang.srt.disaggregation.common.utils import group_concurrent_contiguous
|
4
10
|
from sglang.srt.disaggregation.mooncake.conn import (
|
5
11
|
MooncakeKVBootstrapServer,
|
6
12
|
MooncakeKVManager,
|
@@ -29,6 +35,75 @@ class AscendKVManager(MooncakeKVManager):
|
|
29
35
|
self.kv_args.aux_data_ptrs, self.kv_args.aux_data_lens
|
30
36
|
)
|
31
37
|
|
38
|
+
def send_kvcache(
    self,
    mooncake_session_id: str,
    prefill_kv_indices: npt.NDArray[np.int32],
    dst_kv_ptrs: list[int],
    dst_kv_indices: npt.NDArray[np.int32],
    executor: concurrent.futures.ThreadPoolExecutor,
):
    """Transfer the KV cache of one request from prefill to decode buffers.

    Index pairs are grouped with ``group_concurrent_contiguous``. Each
    resulting block is copied as one ``(src_addr, dst_addr, length)`` range,
    starting at the block's first index, for every layer.

    If ``self.enable_custom_mem_pool`` is set, each layer is submitted to
    ``executor`` as its own transfer. Otherwise all layers are sent in a single
    batched ``_transfer_data`` call.

    Returns:
        0 on success, otherwise the first non-zero status returned by
        ``self._transfer_data``.
    """
    # Group by indices.
    src_blocks, dst_blocks = group_concurrent_contiguous(
        prefill_kv_indices, dst_kv_indices
    )

    # One (src_ptr, dst_ptr, item_len) triple per layer of the local KV cache.
    per_layer = [
        (src_ptr, dst_kv_ptrs[layer_id], self.kv_args.kv_item_lens[layer_id])
        for layer_id, src_ptr in enumerate(self.kv_args.kv_data_ptrs)
    ]

    def build_transfer_blocks(
        src_ptr: int, dst_ptr: int, item_len: int
    ) -> list[tuple[int, int, int]]:
        return [
            (
                src_ptr + int(src_idx[0]) * item_len,
                dst_ptr + int(dst_idx[0]) * item_len,
                item_len * len(src_idx),
            )
            for src_idx, dst_idx in zip(src_blocks, dst_blocks)
        ]

    if not self.enable_custom_mem_pool:
        # Combining all layers' params in one batch transfer is more efficient
        # compared to using multiple threads.
        batched = []
        for src_ptr, dst_ptr, item_len in per_layer:
            batched.extend(build_transfer_blocks(src_ptr, dst_ptr, item_len))
        return self._transfer_data(mooncake_session_id, batched)

    def transfer_one_layer(src_ptr: int, dst_ptr: int, item_len: int) -> int:
        return self._transfer_data(
            mooncake_session_id, build_transfer_blocks(src_ptr, dst_ptr, item_len)
        )

    futures = [
        executor.submit(transfer_one_layer, src_ptr, dst_ptr, item_len)
        for (src_ptr, dst_ptr, item_len) in per_layer
    ]
    for done in concurrent.futures.as_completed(futures):
        status = done.result()
        if status != 0:
            # Abort: cancel whatever has not started yet and report the failure.
            for pending in futures:
                pending.cancel()
            return status

    return 0
|
106
|
+
|
32
107
|
|
33
108
|
class AscendKVSender(MooncakeKVSender):
|
34
109
|
pass
|
@@ -47,6 +47,7 @@ class CommonKVManager(BaseKVManager):
|
|
47
47
|
self.is_mla_backend = is_mla_backend
|
48
48
|
self.disaggregation_mode = disaggregation_mode
|
49
49
|
# for p/d multi node infer
|
50
|
+
self.bootstrap_host = server_args.host
|
50
51
|
self.bootstrap_port = server_args.disaggregation_bootstrap_port
|
51
52
|
self.dist_init_addr = server_args.dist_init_addr
|
52
53
|
self.tp_size = server_args.tp_size
|
@@ -72,6 +73,7 @@ class CommonKVManager(BaseKVManager):
|
|
72
73
|
def _register_to_bootstrap(self):
|
73
74
|
"""Register KVSender to bootstrap server via HTTP POST."""
|
74
75
|
if self.dist_init_addr:
|
76
|
+
# multi node: bootstrap server's host is dist_init_addr
|
75
77
|
if self.dist_init_addr.startswith("["): # [ipv6]:port or [ipv6]
|
76
78
|
if self.dist_init_addr.endswith("]"):
|
77
79
|
host = self.dist_init_addr
|
@@ -80,7 +82,8 @@ class CommonKVManager(BaseKVManager):
|
|
80
82
|
else:
|
81
83
|
host = socket.gethostbyname(self.dist_init_addr.rsplit(":", 1)[0])
|
82
84
|
else:
|
83
|
-
host
|
85
|
+
# single node: bootstrap server's host is same as http server's host
|
86
|
+
host = self.bootstrap_host
|
84
87
|
host = maybe_wrap_ipv6_address(host)
|
85
88
|
|
86
89
|
bootstrap_server_url = f"{host}:{self.bootstrap_port}"
|
@@ -125,12 +128,11 @@ class CommonKVReceiver(BaseKVReceiver):
|
|
125
128
|
mgr: BaseKVManager,
|
126
129
|
bootstrap_addr: str,
|
127
130
|
bootstrap_room: Optional[int] = None,
|
128
|
-
|
131
|
+
prefill_dp_rank: Optional[int] = None,
|
129
132
|
):
|
130
133
|
self.bootstrap_room = bootstrap_room
|
131
134
|
self.bootstrap_addr = bootstrap_addr
|
132
135
|
self.kv_mgr = mgr
|
133
|
-
self.data_parallel_rank = data_parallel_rank
|
134
136
|
|
135
137
|
if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
|
136
138
|
self.prefill_tp_size, self.prefill_dp_size = (
|
@@ -166,9 +168,6 @@ class CommonKVReceiver(BaseKVReceiver):
|
|
166
168
|
self.required_dst_info_num = 1
|
167
169
|
self.target_tp_ranks = [self.target_tp_rank]
|
168
170
|
elif local_tp_size_per_dp_rank > prefill_tp_size_per_dp_rank:
|
169
|
-
assert (
|
170
|
-
self.kv_mgr.is_mla_backend
|
171
|
-
), "PD with different TP sizes per DP rank is not yet supported for non-MLA models"
|
172
171
|
self.target_tp_rank = (
|
173
172
|
self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
|
174
173
|
) // (local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank)
|
@@ -198,11 +197,14 @@ class CommonKVReceiver(BaseKVReceiver):
|
|
198
197
|
self.target_tp_rank = self.target_tp_ranks[0]
|
199
198
|
self.required_dst_info_num = 1
|
200
199
|
|
201
|
-
if
|
202
|
-
logger.debug(f"Targeting DP rank: {
|
203
|
-
self.
|
200
|
+
if prefill_dp_rank is not None:
|
201
|
+
logger.debug(f"Targeting DP rank: {prefill_dp_rank}")
|
202
|
+
self.prefill_dp_rank = prefill_dp_rank
|
204
203
|
else:
|
205
|
-
self.
|
204
|
+
self.prefill_dp_rank = bootstrap_room % self.prefill_dp_size
|
205
|
+
|
206
|
+
# FIXME: alias here: target_dp_group -> prefill_dp_rank
|
207
|
+
self.target_dp_group = self.prefill_dp_rank
|
206
208
|
|
207
209
|
# NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
|
208
210
|
bootstrap_key = (
|
@@ -308,7 +310,8 @@ class CommonKVReceiver(BaseKVReceiver):
|
|
308
310
|
|
309
311
|
|
310
312
|
class CommonKVBootstrapServer(BaseKVBootstrapServer):
|
311
|
-
def __init__(self, port: int):
|
313
|
+
def __init__(self, host: str, port: int):
|
314
|
+
self.host = host
|
312
315
|
self.port = port
|
313
316
|
self.app = web.Application()
|
314
317
|
self.store = dict()
|
@@ -412,7 +415,7 @@ class CommonKVBootstrapServer(BaseKVBootstrapServer):
|
|
412
415
|
self._runner = web.AppRunner(self.app)
|
413
416
|
self._loop.run_until_complete(self._runner.setup())
|
414
417
|
|
415
|
-
site = web.TCPSite(self._runner, port=self.port)
|
418
|
+
site = web.TCPSite(self._runner, host=self.host, port=self.port)
|
416
419
|
self._loop.run_until_complete(site.start())
|
417
420
|
self._loop.run_forever()
|
418
421
|
except Exception as e:
|
@@ -24,7 +24,7 @@ import logging
|
|
24
24
|
from collections import deque
|
25
25
|
from dataclasses import dataclass
|
26
26
|
from http import HTTPStatus
|
27
|
-
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
27
|
+
from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union
|
28
28
|
|
29
29
|
import torch
|
30
30
|
from torch.distributed import ProcessGroup
|
@@ -218,8 +218,10 @@ class DecodePreallocQueue:
|
|
218
218
|
|
219
219
|
kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
|
220
220
|
kv_args.gpu_id = self.scheduler.gpu_id
|
221
|
-
kv_manager_class = get_kv_class(
|
222
|
-
|
221
|
+
kv_manager_class: Type[BaseKVManager] = get_kv_class(
|
222
|
+
self.transfer_backend, KVClassType.MANAGER
|
223
|
+
)
|
224
|
+
kv_manager: BaseKVManager = kv_manager_class(
|
223
225
|
kv_args,
|
224
226
|
DisaggregationMode.DECODE,
|
225
227
|
self.scheduler.server_args,
|
@@ -248,7 +250,7 @@ class DecodePreallocQueue:
|
|
248
250
|
mgr=self.kv_manager,
|
249
251
|
bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
|
250
252
|
bootstrap_room=req.bootstrap_room,
|
251
|
-
|
253
|
+
prefill_dp_rank=req.data_parallel_rank,
|
252
254
|
)
|
253
255
|
|
254
256
|
self.queue.append(
|