sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +257 -29
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +50 -6
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +48 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/xgrammar_backend.py +28 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +21 -10
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +5 -3
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +24 -3
- sglang/srt/entrypoints/engine.py +38 -17
- sglang/srt/entrypoints/grpc_request_manager.py +580 -0
- sglang/srt/entrypoints/grpc_server.py +680 -0
- sglang/srt/entrypoints/http_server.py +85 -54
- sglang/srt/entrypoints/openai/protocol.py +4 -1
- sglang/srt/entrypoints/openai/serving_base.py +46 -3
- sglang/srt/entrypoints/openai/serving_chat.py +36 -16
- sglang/srt/entrypoints/openai/serving_completions.py +12 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +6 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +6 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/activation.py +142 -9
- sglang/srt/layers/attention/ascend_backend.py +11 -4
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +18 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/dp_attention.py +30 -1
- sglang/srt/layers/layernorm.py +32 -15
- sglang/srt/layers/linear.py +34 -3
- sglang/srt/layers/logits_processor.py +29 -10
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +182 -62
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +12 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +50 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +147 -47
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +64 -40
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +30 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +76 -38
- sglang/srt/layers/sampler.py +162 -18
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +158 -160
- sglang/srt/managers/data_parallel_controller.py +105 -35
- sglang/srt/managers/detokenizer_manager.py +8 -4
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +199 -12
- sglang/srt/managers/mm_utils.py +1 -0
- sglang/srt/managers/multi_tokenizer_mixin.py +350 -400
- sglang/srt/managers/schedule_batch.py +77 -56
- sglang/srt/managers/schedule_policy.py +1 -1
- sglang/srt/managers/scheduler.py +187 -39
- sglang/srt/managers/scheduler_metrics_mixin.py +4 -3
- sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
- sglang/srt/managers/tokenizer_manager.py +259 -519
- sglang/srt/managers/tp_worker.py +53 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
- sglang/srt/mem_cache/hicache_storage.py +3 -23
- sglang/srt/mem_cache/hiradix_cache.py +103 -43
- sglang/srt/mem_cache/memory_pool.py +347 -48
- sglang/srt/mem_cache/memory_pool_host.py +105 -46
- sglang/srt/mem_cache/radix_cache.py +0 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +86 -4
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +49 -7
- sglang/srt/mem_cache/swa_radix_cache.py +0 -2
- sglang/srt/metrics/collector.py +493 -76
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +59 -2
- sglang/srt/model_executor/model_runner.py +356 -29
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +128 -4
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +798 -218
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_v2.py +109 -15
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +1 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/glm4v_moe.py +3 -0
- sglang/srt/models/gpt_oss.py +1 -1
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +7 -0
- sglang/srt/models/qwen2_5_vl.py +27 -3
- sglang/srt/models/qwen2_moe.py +56 -12
- sglang/srt/models/qwen3_moe.py +1 -1
- sglang/srt/models/qwen3_next.py +1042 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/multimodal/processors/dots_vlm.py +99 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/multimodal/processors/qwen_vl.py +15 -5
- sglang/srt/offloader.py +27 -3
- sglang/srt/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +276 -35
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_utils.py +0 -2
- sglang/srt/speculative/eagle_worker.py +43 -4
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tracing/trace.py +552 -0
- sglang/srt/utils.py +34 -3
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_utils.py +28 -1
- sglang/utils.py +11 -0
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +237 -178
- sglang/srt/disaggregation/launch_lb.py +0 -118
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
@@ -1,4 +1,5 @@
|
|
1
1
|
import argparse
|
2
|
+
import hashlib
|
2
3
|
import json
|
3
4
|
from pathlib import Path
|
4
5
|
|
@@ -13,7 +14,11 @@ Supported inputs:
|
|
13
14
|
|
14
15
|
|
15
16
|
def main(args):
|
16
|
-
|
17
|
+
if args.data_type == "simple_evals":
|
18
|
+
df_input = _compute_df_input_mode_simple_evals(args)
|
19
|
+
else:
|
20
|
+
df_input = _transform_df_input(_compute_df_raw(args))
|
21
|
+
|
17
22
|
assert all(
|
18
23
|
c in df_input.columns
|
19
24
|
for c in ["category", "trial_index", "prompt_id", "prompt", "output", "correct"]
|
@@ -37,8 +42,9 @@ def main(args):
|
|
37
42
|
df_meta=df_meta.to_dicts(),
|
38
43
|
df_good_to_bad=df_good_to_bad.to_dicts(),
|
39
44
|
df_bad_to_good=df_bad_to_good.to_dicts(),
|
40
|
-
)
|
41
|
-
|
45
|
+
),
|
46
|
+
indent=4,
|
47
|
+
),
|
42
48
|
)
|
43
49
|
|
44
50
|
if not args.disable_print_details:
|
@@ -65,19 +71,70 @@ def main(args):
|
|
65
71
|
print(df)
|
66
72
|
|
67
73
|
|
74
|
+
def _compute_df_input_mode_simple_evals(args):
|
75
|
+
return pl.concat(
|
76
|
+
[
|
77
|
+
_compute_df_input_one_mode_simple_evals(**info)
|
78
|
+
for info in _get_file_infos(args=args)
|
79
|
+
]
|
80
|
+
)
|
81
|
+
|
82
|
+
|
83
|
+
def _compute_df_input_one_mode_simple_evals(path, category, trial_index):
|
84
|
+
data = json.loads(Path(path).read_text())
|
85
|
+
rows = []
|
86
|
+
|
87
|
+
for single_eval_result in data["metadata"]["single_eval_results"]:
|
88
|
+
prompt = single_eval_result["example_level_metadata"][
|
89
|
+
"actual_queried_prompt_messages"
|
90
|
+
]
|
91
|
+
score = single_eval_result["score"]
|
92
|
+
assert score in {0.0, 1.0}, f"{score=}"
|
93
|
+
|
94
|
+
row = dict(
|
95
|
+
category=category,
|
96
|
+
trial_index=trial_index,
|
97
|
+
prompt_id=_compute_id_from_object(prompt),
|
98
|
+
prompt=json.dumps(prompt),
|
99
|
+
output=single_eval_result["example_level_metadata"]["response_text"],
|
100
|
+
correct=score == 1.0,
|
101
|
+
)
|
102
|
+
rows.append(row)
|
103
|
+
|
104
|
+
return pl.DataFrame(rows)
|
105
|
+
|
106
|
+
|
107
|
+
def _compute_id_from_object(obj):
|
108
|
+
if isinstance(obj, pl.Series):
|
109
|
+
obj = obj.to_list()
|
110
|
+
json_str = json.dumps(obj, sort_keys=True, ensure_ascii=False)
|
111
|
+
return hashlib.sha256(json_str.encode("utf-8")).hexdigest()
|
112
|
+
|
113
|
+
|
68
114
|
def _compute_df_raw(args):
|
69
115
|
return pl.concat(
|
70
116
|
[
|
71
|
-
_read_df_raw(
|
72
|
-
|
73
|
-
|
74
|
-
|
75
|
-
|
76
|
-
for
|
117
|
+
_read_df_raw(
|
118
|
+
path=info["path"],
|
119
|
+
category=info["category"],
|
120
|
+
trial_index=info["trial_index"],
|
121
|
+
)
|
122
|
+
for info in _get_file_infos(args=args)
|
77
123
|
]
|
78
124
|
)
|
79
125
|
|
80
126
|
|
127
|
+
def _get_file_infos(args):
|
128
|
+
return [
|
129
|
+
dict(path=path, category=category, trial_index=trial_index)
|
130
|
+
for category, paths in [
|
131
|
+
("baseline", args.baseline_path),
|
132
|
+
("target", args.target_path),
|
133
|
+
]
|
134
|
+
for trial_index, path in enumerate(paths)
|
135
|
+
]
|
136
|
+
|
137
|
+
|
81
138
|
def _read_df_raw(path: str, category: str, trial_index: int):
|
82
139
|
return pl.read_ndjson(path).with_columns(
|
83
140
|
category=pl.lit(category), trial_index=trial_index
|
@@ -108,7 +165,9 @@ def _transform_df_input(df: pl.DataFrame):
|
|
108
165
|
print("Transform mode: SGLang bench")
|
109
166
|
return df
|
110
167
|
else:
|
111
|
-
raise Exception(
|
168
|
+
raise Exception(
|
169
|
+
f"Unknown data: {df.columns}. You may need to set `--data-type` if using e.g. simple_evals."
|
170
|
+
)
|
112
171
|
|
113
172
|
|
114
173
|
def _compute_df_meta(df_input: pl.DataFrame):
|
@@ -127,7 +186,9 @@ def _compute_df_meta(df_input: pl.DataFrame):
|
|
127
186
|
|
128
187
|
|
129
188
|
def _handle_one_prompt(df_one_prompt: pl.DataFrame):
|
130
|
-
assert
|
189
|
+
assert (
|
190
|
+
len(set(_compute_id_from_object(obj) for obj in df_one_prompt["prompt"])) == 1
|
191
|
+
)
|
131
192
|
|
132
193
|
df_baseline = df_one_prompt.filter(pl.col("category") == "baseline")
|
133
194
|
df_target = df_one_prompt.filter(pl.col("category") == "target")
|
@@ -162,6 +223,7 @@ def _compute_str_prefix_len(a: str, b: str) -> int:
|
|
162
223
|
|
163
224
|
if __name__ == "__main__":
|
164
225
|
parser = argparse.ArgumentParser(description=_DESCRIPTION)
|
226
|
+
parser.add_argument("--data-type", type=str, default="auto")
|
165
227
|
parser.add_argument("--baseline-path", type=str, nargs="+")
|
166
228
|
parser.add_argument("--target-path", type=str, nargs="+")
|
167
229
|
parser.add_argument(
|
@@ -47,6 +47,7 @@ class CommonKVManager(BaseKVManager):
|
|
47
47
|
self.is_mla_backend = is_mla_backend
|
48
48
|
self.disaggregation_mode = disaggregation_mode
|
49
49
|
# for p/d multi node infer
|
50
|
+
self.bootstrap_host = server_args.host
|
50
51
|
self.bootstrap_port = server_args.disaggregation_bootstrap_port
|
51
52
|
self.dist_init_addr = server_args.dist_init_addr
|
52
53
|
self.tp_size = server_args.tp_size
|
@@ -72,6 +73,7 @@ class CommonKVManager(BaseKVManager):
|
|
72
73
|
def _register_to_bootstrap(self):
|
73
74
|
"""Register KVSender to bootstrap server via HTTP POST."""
|
74
75
|
if self.dist_init_addr:
|
76
|
+
# multi node: bootstrap server's host is dist_init_addr
|
75
77
|
if self.dist_init_addr.startswith("["): # [ipv6]:port or [ipv6]
|
76
78
|
if self.dist_init_addr.endswith("]"):
|
77
79
|
host = self.dist_init_addr
|
@@ -80,7 +82,8 @@ class CommonKVManager(BaseKVManager):
|
|
80
82
|
else:
|
81
83
|
host = socket.gethostbyname(self.dist_init_addr.rsplit(":", 1)[0])
|
82
84
|
else:
|
83
|
-
host
|
85
|
+
# single node: bootstrap server's host is same as http server's host
|
86
|
+
host = self.bootstrap_host
|
84
87
|
host = maybe_wrap_ipv6_address(host)
|
85
88
|
|
86
89
|
bootstrap_server_url = f"{host}:{self.bootstrap_port}"
|
@@ -125,12 +128,11 @@ class CommonKVReceiver(BaseKVReceiver):
|
|
125
128
|
mgr: BaseKVManager,
|
126
129
|
bootstrap_addr: str,
|
127
130
|
bootstrap_room: Optional[int] = None,
|
128
|
-
|
131
|
+
prefill_dp_rank: Optional[int] = None,
|
129
132
|
):
|
130
133
|
self.bootstrap_room = bootstrap_room
|
131
134
|
self.bootstrap_addr = bootstrap_addr
|
132
135
|
self.kv_mgr = mgr
|
133
|
-
self.data_parallel_rank = data_parallel_rank
|
134
136
|
|
135
137
|
if self.bootstrap_addr not in self.kv_mgr.prefill_dp_size_table:
|
136
138
|
self.prefill_tp_size, self.prefill_dp_size = (
|
@@ -166,9 +168,6 @@ class CommonKVReceiver(BaseKVReceiver):
|
|
166
168
|
self.required_dst_info_num = 1
|
167
169
|
self.target_tp_ranks = [self.target_tp_rank]
|
168
170
|
elif local_tp_size_per_dp_rank > prefill_tp_size_per_dp_rank:
|
169
|
-
assert (
|
170
|
-
self.kv_mgr.is_mla_backend
|
171
|
-
), "PD with different TP sizes per DP rank is not yet supported for non-MLA models"
|
172
171
|
self.target_tp_rank = (
|
173
172
|
self.kv_mgr.kv_args.engine_rank % local_tp_size_per_dp_rank
|
174
173
|
) // (local_tp_size_per_dp_rank // prefill_tp_size_per_dp_rank)
|
@@ -198,11 +197,14 @@ class CommonKVReceiver(BaseKVReceiver):
|
|
198
197
|
self.target_tp_rank = self.target_tp_ranks[0]
|
199
198
|
self.required_dst_info_num = 1
|
200
199
|
|
201
|
-
if
|
202
|
-
logger.debug(f"Targeting DP rank: {
|
203
|
-
self.
|
200
|
+
if prefill_dp_rank is not None:
|
201
|
+
logger.debug(f"Targeting DP rank: {prefill_dp_rank}")
|
202
|
+
self.prefill_dp_rank = prefill_dp_rank
|
204
203
|
else:
|
205
|
-
self.
|
204
|
+
self.prefill_dp_rank = bootstrap_room % self.prefill_dp_size
|
205
|
+
|
206
|
+
# FIXME: alias here: target_dp_group -> prefill_dp_rank
|
207
|
+
self.target_dp_group = self.prefill_dp_rank
|
206
208
|
|
207
209
|
# NOTE: key distinguished by bootstrap_addr, target_dp_group, and target_tp_rank
|
208
210
|
bootstrap_key = (
|
@@ -308,7 +310,8 @@ class CommonKVReceiver(BaseKVReceiver):
|
|
308
310
|
|
309
311
|
|
310
312
|
class CommonKVBootstrapServer(BaseKVBootstrapServer):
|
311
|
-
def __init__(self, port: int):
|
313
|
+
def __init__(self, host: str, port: int):
|
314
|
+
self.host = host
|
312
315
|
self.port = port
|
313
316
|
self.app = web.Application()
|
314
317
|
self.store = dict()
|
@@ -412,7 +415,7 @@ class CommonKVBootstrapServer(BaseKVBootstrapServer):
|
|
412
415
|
self._runner = web.AppRunner(self.app)
|
413
416
|
self._loop.run_until_complete(self._runner.setup())
|
414
417
|
|
415
|
-
site = web.TCPSite(self._runner, port=self.port)
|
418
|
+
site = web.TCPSite(self._runner, host=self.host, port=self.port)
|
416
419
|
self._loop.run_until_complete(site.start())
|
417
420
|
self._loop.run_forever()
|
418
421
|
except Exception as e:
|
@@ -24,7 +24,7 @@ import logging
|
|
24
24
|
from collections import deque
|
25
25
|
from dataclasses import dataclass
|
26
26
|
from http import HTTPStatus
|
27
|
-
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
27
|
+
from typing import TYPE_CHECKING, List, Optional, Tuple, Type, Union
|
28
28
|
|
29
29
|
import torch
|
30
30
|
from torch.distributed import ProcessGroup
|
@@ -218,8 +218,10 @@ class DecodePreallocQueue:
|
|
218
218
|
|
219
219
|
kv_args.ib_device = self.scheduler.server_args.disaggregation_ib_device
|
220
220
|
kv_args.gpu_id = self.scheduler.gpu_id
|
221
|
-
kv_manager_class = get_kv_class(
|
222
|
-
|
221
|
+
kv_manager_class: Type[BaseKVManager] = get_kv_class(
|
222
|
+
self.transfer_backend, KVClassType.MANAGER
|
223
|
+
)
|
224
|
+
kv_manager: BaseKVManager = kv_manager_class(
|
223
225
|
kv_args,
|
224
226
|
DisaggregationMode.DECODE,
|
225
227
|
self.scheduler.server_args,
|
@@ -248,7 +250,7 @@ class DecodePreallocQueue:
|
|
248
250
|
mgr=self.kv_manager,
|
249
251
|
bootstrap_addr=f"{req.bootstrap_host}:{req.bootstrap_port}",
|
250
252
|
bootstrap_room=req.bootstrap_room,
|
251
|
-
|
253
|
+
prefill_dp_rank=req.data_parallel_rank,
|
252
254
|
)
|
253
255
|
|
254
256
|
self.queue.append(
|
@@ -884,9 +886,18 @@ class SchedulerDisaggregationDecodeMixin:
|
|
884
886
|
# if there are still retracted requests, we do not allocate new requests
|
885
887
|
return
|
886
888
|
|
887
|
-
|
888
|
-
|
889
|
-
|
890
|
-
|
891
|
-
|
892
|
-
|
889
|
+
if not hasattr(self, "polling_count"):
|
890
|
+
self.polling_count = 0
|
891
|
+
self.polling_interval = (
|
892
|
+
self.server_args.disaggregation_decode_polling_interval
|
893
|
+
)
|
894
|
+
|
895
|
+
self.polling_count = (self.polling_count + 1) % self.polling_interval
|
896
|
+
|
897
|
+
if self.polling_count % self.polling_interval == 0:
|
898
|
+
req_conns = self.disagg_decode_prealloc_queue.pop_preallocated()
|
899
|
+
self.disagg_decode_transfer_queue.extend(req_conns)
|
900
|
+
alloc_reqs = (
|
901
|
+
self.disagg_decode_transfer_queue.pop_transferred()
|
902
|
+
) # the requests which kv has arrived
|
903
|
+
self.waiting_queue.extend(alloc_reqs)
|
@@ -110,7 +110,10 @@ class ScheduleBatchDisaggregationDecodeMixin:
|
|
110
110
|
if req.grammar is not None:
|
111
111
|
# FIXME: this try-except block is for handling unexpected xgrammar issue.
|
112
112
|
try:
|
113
|
-
|
113
|
+
# if it is not None, then the grammar is from a retracted request, and we should not
|
114
|
+
# accept the token as it's already accepted
|
115
|
+
if req.grammar.current_token is None:
|
116
|
+
req.grammar.accept_token(req.output_ids[-1])
|
114
117
|
except ValueError as e:
|
115
118
|
# Grammar accept_token can raise ValueError if the token is not in the grammar.
|
116
119
|
# This can happen if the grammar is not set correctly or the token is invalid.
|