sglang 0.4.6.post3__py3-none-any.whl → 0.4.6.post5__py3-none-any.whl
This diff covers publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
- sglang/bench_offline_throughput.py +10 -8
- sglang/bench_one_batch.py +7 -6
- sglang/bench_one_batch_server.py +157 -21
- sglang/bench_serving.py +137 -59
- sglang/compile_deep_gemm.py +5 -5
- sglang/eval/loogle_eval.py +157 -0
- sglang/lang/chat_template.py +78 -78
- sglang/lang/tracer.py +1 -1
- sglang/srt/code_completion_parser.py +1 -1
- sglang/srt/configs/deepseekvl2.py +2 -2
- sglang/srt/configs/model_config.py +40 -28
- sglang/srt/constrained/base_grammar_backend.py +55 -72
- sglang/srt/constrained/llguidance_backend.py +25 -21
- sglang/srt/constrained/outlines_backend.py +27 -26
- sglang/srt/constrained/reasoner_grammar_backend.py +22 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -43
- sglang/srt/conversation.py +49 -44
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +129 -135
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +142 -0
- sglang/srt/disaggregation/fake/conn.py +3 -13
- sglang/srt/disaggregation/kv_events.py +357 -0
- sglang/srt/disaggregation/mini_lb.py +57 -24
- sglang/srt/disaggregation/mooncake/conn.py +238 -122
- sglang/srt/disaggregation/mooncake/transfer_engine.py +2 -1
- sglang/srt/disaggregation/nixl/conn.py +10 -19
- sglang/srt/disaggregation/prefill.py +132 -47
- sglang/srt/disaggregation/utils.py +123 -6
- sglang/srt/distributed/utils.py +3 -3
- sglang/srt/entrypoints/EngineBase.py +5 -0
- sglang/srt/entrypoints/engine.py +44 -9
- sglang/srt/entrypoints/http_server.py +23 -6
- sglang/srt/entrypoints/http_server_engine.py +5 -2
- sglang/srt/function_call/base_format_detector.py +250 -0
- sglang/srt/function_call/core_types.py +34 -0
- sglang/srt/function_call/deepseekv3_detector.py +157 -0
- sglang/srt/function_call/ebnf_composer.py +234 -0
- sglang/srt/function_call/function_call_parser.py +175 -0
- sglang/srt/function_call/llama32_detector.py +74 -0
- sglang/srt/function_call/mistral_detector.py +84 -0
- sglang/srt/function_call/pythonic_detector.py +163 -0
- sglang/srt/function_call/qwen25_detector.py +67 -0
- sglang/srt/function_call/utils.py +35 -0
- sglang/srt/hf_transformers_utils.py +46 -7
- sglang/srt/layers/attention/aiter_backend.py +513 -0
- sglang/srt/layers/attention/flashattention_backend.py +64 -18
- sglang/srt/layers/attention/flashinfer_mla_backend.py +8 -4
- sglang/srt/layers/attention/flashmla_backend.py +340 -78
- sglang/srt/layers/attention/triton_backend.py +3 -0
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +1 -1
- sglang/srt/layers/attention/utils.py +6 -4
- sglang/srt/layers/attention/vision.py +1 -1
- sglang/srt/layers/communicator.py +451 -0
- sglang/srt/layers/dp_attention.py +61 -21
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/logits_processor.py +46 -11
- sglang/srt/layers/moe/cutlass_moe.py +207 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +34 -12
- sglang/srt/layers/moe/ep_moe/layer.py +105 -51
- sglang/srt/layers/moe/ep_moe/token_dispatcher.py +82 -7
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/layer.py +14 -0
- sglang/srt/layers/moe/topk.py +67 -10
- sglang/srt/layers/multimodal.py +70 -0
- sglang/srt/layers/quantization/__init__.py +8 -3
- sglang/srt/layers/quantization/blockwise_int8.py +2 -2
- sglang/srt/layers/quantization/deep_gemm.py +77 -74
- sglang/srt/layers/quantization/fp8.py +92 -2
- sglang/srt/layers/quantization/fp8_kernel.py +3 -3
- sglang/srt/layers/quantization/fp8_utils.py +6 -0
- sglang/srt/layers/quantization/gptq.py +298 -6
- sglang/srt/layers/quantization/int8_kernel.py +20 -7
- sglang/srt/layers/quantization/qoq.py +244 -0
- sglang/srt/layers/sampler.py +0 -4
- sglang/srt/layers/vocab_parallel_embedding.py +18 -7
- sglang/srt/lora/lora_manager.py +2 -4
- sglang/srt/lora/mem_pool.py +4 -4
- sglang/srt/lora/triton_ops/gate_up_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/qkv_lora_b.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_a.py +1 -1
- sglang/srt/lora/triton_ops/sgemm_lora_b.py +1 -1
- sglang/srt/lora/utils.py +1 -1
- sglang/srt/managers/data_parallel_controller.py +3 -3
- sglang/srt/managers/deepseek_eplb.py +278 -0
- sglang/srt/managers/detokenizer_manager.py +21 -8
- sglang/srt/managers/eplb_manager.py +55 -0
- sglang/srt/managers/expert_distribution.py +704 -56
- sglang/srt/managers/expert_location.py +394 -0
- sglang/srt/managers/expert_location_dispatch.py +91 -0
- sglang/srt/managers/io_struct.py +19 -4
- sglang/srt/managers/mm_utils.py +294 -140
- sglang/srt/managers/multimodal_processors/base_processor.py +127 -42
- sglang/srt/managers/multimodal_processors/deepseek_vl_v2.py +6 -1
- sglang/srt/managers/multimodal_processors/gemma3.py +31 -6
- sglang/srt/managers/multimodal_processors/internvl.py +14 -5
- sglang/srt/managers/multimodal_processors/janus_pro.py +7 -1
- sglang/srt/managers/multimodal_processors/kimi_vl.py +7 -6
- sglang/srt/managers/multimodal_processors/llava.py +46 -0
- sglang/srt/managers/multimodal_processors/minicpm.py +25 -31
- sglang/srt/managers/multimodal_processors/mllama4.py +6 -0
- sglang/srt/managers/multimodal_processors/pixtral.py +127 -0
- sglang/srt/managers/multimodal_processors/qwen_vl.py +58 -16
- sglang/srt/managers/schedule_batch.py +122 -42
- sglang/srt/managers/schedule_policy.py +1 -5
- sglang/srt/managers/scheduler.py +205 -138
- sglang/srt/managers/scheduler_output_processor_mixin.py +124 -55
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/tokenizer_manager.py +232 -58
- sglang/srt/managers/tp_worker.py +12 -9
- sglang/srt/managers/tp_worker_overlap_thread.py +22 -11
- sglang/srt/mem_cache/base_prefix_cache.py +3 -0
- sglang/srt/mem_cache/chunk_cache.py +3 -1
- sglang/srt/mem_cache/hiradix_cache.py +4 -4
- sglang/srt/mem_cache/memory_pool.py +76 -52
- sglang/srt/mem_cache/multimodal_cache.py +45 -0
- sglang/srt/mem_cache/radix_cache.py +58 -5
- sglang/srt/metrics/collector.py +314 -39
- sglang/srt/mm_utils.py +10 -0
- sglang/srt/model_executor/cuda_graph_runner.py +29 -19
- sglang/srt/model_executor/expert_location_updater.py +422 -0
- sglang/srt/model_executor/forward_batch_info.py +5 -1
- sglang/srt/model_executor/model_runner.py +163 -68
- sglang/srt/model_loader/loader.py +10 -6
- sglang/srt/models/clip.py +5 -1
- sglang/srt/models/deepseek_janus_pro.py +2 -2
- sglang/srt/models/deepseek_v2.py +308 -351
- sglang/srt/models/exaone.py +8 -3
- sglang/srt/models/gemma3_mm.py +70 -33
- sglang/srt/models/llama.py +2 -0
- sglang/srt/models/llama4.py +15 -8
- sglang/srt/models/llava.py +258 -7
- sglang/srt/models/mimo_mtp.py +220 -0
- sglang/srt/models/minicpmo.py +5 -12
- sglang/srt/models/mistral.py +71 -1
- sglang/srt/models/mixtral.py +98 -34
- sglang/srt/models/mllama.py +3 -3
- sglang/srt/models/pixtral.py +467 -0
- sglang/srt/models/qwen2.py +95 -26
- sglang/srt/models/qwen2_5_vl.py +8 -0
- sglang/srt/models/qwen2_moe.py +330 -60
- sglang/srt/models/qwen2_vl.py +6 -0
- sglang/srt/models/qwen3.py +52 -10
- sglang/srt/models/qwen3_moe.py +411 -48
- sglang/srt/models/roberta.py +1 -1
- sglang/srt/models/siglip.py +294 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/openai_api/adapter.py +58 -20
- sglang/srt/openai_api/protocol.py +6 -8
- sglang/srt/operations.py +154 -0
- sglang/srt/operations_strategy.py +31 -0
- sglang/srt/reasoning_parser.py +3 -3
- sglang/srt/sampling/custom_logit_processor.py +18 -3
- sglang/srt/sampling/sampling_batch_info.py +4 -56
- sglang/srt/sampling/sampling_params.py +2 -2
- sglang/srt/server_args.py +162 -22
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +3 -3
- sglang/srt/speculative/eagle_utils.py +138 -7
- sglang/srt/speculative/eagle_worker.py +69 -21
- sglang/srt/utils.py +74 -17
- sglang/test/few_shot_gsm8k.py +2 -2
- sglang/test/few_shot_gsm8k_engine.py +2 -2
- sglang/test/run_eval.py +2 -2
- sglang/test/runners.py +8 -1
- sglang/test/send_one.py +13 -3
- sglang/test/simple_eval_common.py +1 -1
- sglang/test/simple_eval_humaneval.py +1 -1
- sglang/test/test_cutlass_moe.py +278 -0
- sglang/test/test_programs.py +5 -5
- sglang/test/test_utils.py +55 -14
- sglang/utils.py +3 -3
- sglang/version.py +1 -1
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/METADATA +23 -13
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/RECORD +178 -149
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/WHEEL +1 -1
- sglang/srt/function_call_parser.py +0 -858
- sglang/srt/platforms/interface.py +0 -371
- /sglang/{llama3_eval.py → eval/llama3_eval.py} +0 -0
- /sglang/srt/models/{xiaomi_mimo.py → mimo.py} +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.4.6.post3.dist-info → sglang-0.4.6.post5.dist-info}/top_level.txt +0 -0
sglang/bench_serving.py
CHANGED
@@ -24,6 +24,7 @@ import warnings
 from argparse import ArgumentParser
 from dataclasses import dataclass, field
 from datetime import datetime
+from json import JSONDecodeError
 from pathlib import Path
 from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple, Union

@@ -73,6 +74,12 @@ class RequestFuncOutput:
     error: str = ""
     output_len: int = 0

+    @staticmethod
+    def init_new(request_func_input: RequestFuncInput):
+        output = RequestFuncOutput()
+        output.prompt_len = request_func_input.prompt_len
+        return output
+

 def remove_prefix(text: str, prefix: str) -> str:
     return text[len(prefix) :] if text.startswith(prefix) else text
@@ -114,8 +121,7 @@ async def async_request_trt_llm(
     if args.disable_ignore_eos:
         del payload["min_length"]
         del payload["end_id"]
-    output = RequestFuncOutput()
-    output.prompt_len = request_func_input.prompt_len
+    output = RequestFuncOutput.init_new(request_func_input)

     ttft = 0.0
     st = time.perf_counter()
@@ -186,8 +192,7 @@ async def async_request_openai_completions(
     }
     headers = get_auth_headers()

-    output = RequestFuncOutput()
-    output.prompt_len = request_func_input.prompt_len
+    output = RequestFuncOutput.init_new(request_func_input)

     generated_text = ""
     output_len = request_func_input.output_len
@@ -269,8 +274,7 @@ async def async_request_truss(
     }
     headers = get_auth_headers()

-    output = RequestFuncOutput()
-    output.prompt_len = request_func_input.prompt_len
+    output = RequestFuncOutput.init_new(request_func_input)

     generated_text = ""
     ttft = 0.0
@@ -355,8 +359,7 @@ async def async_request_sglang_generate(

     headers = get_auth_headers()

-    output = RequestFuncOutput()
-    output.prompt_len = request_func_input.prompt_len
+    output = RequestFuncOutput.init_new(request_func_input)

     generated_text = ""
     output_len = request_func_input.output_len
@@ -469,6 +472,10 @@ def get_model(pretrained_model_name_or_path: str) -> str:
 def get_tokenizer(
     pretrained_model_name_or_path: str,
 ) -> Union[PreTrainedTokenizer, PreTrainedTokenizerFast]:
+    assert (
+        pretrained_model_name_or_path is not None
+        and pretrained_model_name_or_path != ""
+    )
     if pretrained_model_name_or_path.endswith(
         ".json"
     ) or pretrained_model_name_or_path.endswith(".model"):
@@ -582,7 +589,7 @@ def download_and_cache_file(url: str, filename: Optional[str] = None):
         filename = os.path.join("/tmp", url.split("/")[-1])

     # Check if the cache file already exists
-    if
+    if is_file_valid_json(filename):
         return filename

     print(f"Downloading from {url} to {filename}")
@@ -610,12 +617,35 @@ def download_and_cache_file(url: str, filename: Optional[str] = None):
     return filename


+def is_file_valid_json(path):
+    if not os.path.isfile(path):
+        return False
+
+    # TODO can fuse into the real file open later
+    try:
+        with open(path) as f:
+            json.load(f)
+        return True
+    except JSONDecodeError as e:
+        print(
+            f"{path} exists but json loading fails ({e=}), thus treat as invalid file"
+        )
+        return False
+
+
+@dataclass
+class DatasetRow:
+    prompt: str
+    prompt_len: int
+    output_len: int
+
+
 def sample_mmmu_requests(
     num_requests: int,
     tokenizer: PreTrainedTokenizerBase,
     fixed_output_len: Optional[int] = None,
     random_sample: bool = True,
-) -> List[
+) -> List[DatasetRow]:
     """
     Sample requests from the MMMU dataset using HuggingFace datasets.

@@ -716,7 +746,11 @@ def sample_mmmu_requests(

             output_len = fixed_output_len if fixed_output_len is not None else 256

-            filtered_dataset.append(
+            filtered_dataset.append(
+                DatasetRow(
+                    prompt=prompt, prompt_len=prompt_len, output_len=output_len
+                )
+            )

         except Exception as e:
             print(f"Error processing example {i}: {e}")
@@ -733,12 +767,12 @@ def sample_sharegpt_requests(
     context_len: Optional[int] = None,
     prompt_suffix: Optional[str] = "",
     apply_chat_template=False,
-) -> List[
+) -> List[DatasetRow]:
     if fixed_output_len is not None and fixed_output_len < 4:
         raise ValueError("output_len too small")

     # Download sharegpt if necessary
-    if not
+    if not is_file_valid_json(dataset_path) and dataset_path == "":
         dataset_path = download_and_cache_file(SHAREGPT_URL)

     # Load the dataset.
@@ -764,7 +798,7 @@ def sample_sharegpt_requests(
     random.shuffle(dataset)

     # Filter out sequences that are too long or too short
-    filtered_dataset: List[
+    filtered_dataset: List[DatasetRow] = []
     for i in range(len(dataset)):
         if len(filtered_dataset) == num_requests:
             break
@@ -802,10 +836,12 @@ def sample_sharegpt_requests(
             # Prune too long sequences.
             continue

-        filtered_dataset.append(
+        filtered_dataset.append(
+            DatasetRow(prompt=prompt, prompt_len=prompt_len, output_len=output_len)
+        )

-    print(f"#Input tokens: {np.sum([x
-    print(f"#Output tokens: {np.sum([x
+    print(f"#Input tokens: {np.sum([x.prompt_len for x in filtered_dataset])}")
+    print(f"#Output tokens: {np.sum([x.output_len for x in filtered_dataset])}")
     return filtered_dataset


@@ -817,7 +853,8 @@ def sample_random_requests(
     tokenizer: PreTrainedTokenizerBase,
     dataset_path: str,
     random_sample: bool = True,
-
+    return_text: bool = True,
+) -> List[DatasetRow]:
     input_lens = np.random.randint(
         max(int(input_len * range_ratio), 1),
         input_len + 1,
@@ -833,7 +870,7 @@ def sample_random_requests(
         # Sample token ids from ShareGPT and repeat/truncate them to satisfy the input_lens

         # Download sharegpt if necessary
-        if not
+        if not is_file_valid_json(dataset_path):
            dataset_path = download_and_cache_file(SHAREGPT_URL)

         # Load the dataset.
@@ -857,7 +894,7 @@ def sample_random_requests(
         random.shuffle(dataset)

         # Filter out sequences that are too long or too short
-        input_requests: List[
+        input_requests: List[DatasetRow] = []
         for data in dataset:
             i = len(input_requests)
             if i == num_prompts:
@@ -877,20 +914,34 @@ def sample_random_requests(
             else:
                 ratio = (input_lens[i] + prompt_len - 1) // prompt_len
                 input_ids = (prompt_token_ids * ratio)[: input_lens[i]]
-
-
+            input_content = input_ids
+            if return_text:
+                input_content = tokenizer.decode(input_content)
+            input_requests.append(
+                DatasetRow(
+                    prompt=input_content,
+                    prompt_len=int(input_lens[i]),
+                    output_len=int(output_lens[i]),
+                )
+            )
     else:
         # Sample token ids from random integers. This can cause some NaN issues.
         offsets = np.random.randint(0, tokenizer.vocab_size, size=num_prompts)
         input_requests = []
         for i in range(num_prompts):
-
-            [
-
-
+            input_content = [
+                (offsets[i] + i + j) % tokenizer.vocab_size
+                for j in range(input_lens[i])
+            ]
+            if return_text:
+                input_content = tokenizer.decode(input_content)
+            input_requests.append(
+                DatasetRow(
+                    prompt=input_content,
+                    prompt_len=int(input_lens[i]),
+                    output_len=int(output_lens[i]),
+                )
             )
-            input_requests.append((prompt, int(input_lens[i]), int(output_lens[i])))

     print(f"#Input tokens: {np.sum(input_lens)}")
     print(f"#Output tokens: {np.sum(output_lens)}")
@@ -925,7 +976,7 @@ def sample_generated_shared_prefix_requests(
     output_len: int,
     tokenizer: PreTrainedTokenizerBase,
     args: argparse.Namespace,
-) -> List[
+) -> List[DatasetRow]:
     """Generate benchmark requests with shared system prompts using random tokens and caching."""
     cache_path = get_gen_prefix_cache_path(args, tokenizer)

@@ -963,7 +1014,11 @@ def sample_generated_shared_prefix_requests(
             full_prompt = f"{system_prompt}\n\n{question}"
             prompt_len = len(tokenizer.encode(full_prompt))

-            input_requests.append(
+            input_requests.append(
+                DatasetRow(
+                    prompt=full_prompt, prompt_len=prompt_len, output_len=output_len
+                )
+            )
             total_input_tokens += prompt_len
             total_output_tokens += output_len

@@ -994,9 +1049,9 @@ def sample_generated_shared_prefix_requests(


 async def get_request(
-    input_requests: List[
+    input_requests: List[DatasetRow],
     request_rate: float,
-) -> AsyncGenerator[
+) -> AsyncGenerator[DatasetRow, None]:
     input_requests = iter(input_requests)
     for request in input_requests:
         yield request
@@ -1012,7 +1067,7 @@ async def get_request(


 def calculate_metrics(
-    input_requests: List[
+    input_requests: List[DatasetRow],
     outputs: List[RequestFuncOutput],
     dur_s: float,
     tokenizer: PreTrainedTokenizerBase,
@@ -1034,7 +1089,7 @@ def calculate_metrics(
             tokenizer.encode(outputs[i].generated_text, add_special_tokens=False)
         )
         retokenized_output_lens.append(retokenized_output_len)
-        total_input += input_requests[i]
+        total_input += input_requests[i].prompt_len
         if output_len > 1:
             tpots.append((outputs[i].latency - outputs[i].ttft) / (output_len - 1))
         itls += outputs[i].itl
@@ -1096,14 +1151,14 @@ async def benchmark(
     base_url: str,
     model_id: str,
     tokenizer: PreTrainedTokenizerBase,
-    input_requests: List[
+    input_requests: List[DatasetRow],
     request_rate: float,
     max_concurrency: Optional[int],
     disable_tqdm: bool,
     lora_names: List[str],
     extra_request_body: Dict[str, Any],
     profile: bool,
-
+    pd_separated: bool = False,
     flush_cache: bool = False,
     warmup_requests: int = 1,
 ):
@@ -1126,7 +1181,12 @@ async def benchmark(
     print(f"Starting warmup with {warmup_requests} sequences...")

     # Use the first request for all warmup iterations
-
+    test_request = input_requests[0]
+    test_prompt, test_prompt_len, test_output_len = (
+        test_request.prompt,
+        test_request.prompt_len,
+        test_request.output_len,
+    )
     if lora_names is not None and len(lora_names) != 0:
         lora_name = lora_names[0]
     else:
@@ -1194,7 +1254,11 @@ async def benchmark(
     benchmark_start_time = time.perf_counter()
     tasks: List[asyncio.Task] = []
     async for request in get_request(input_requests, request_rate):
-        prompt, prompt_len, output_len =
+        prompt, prompt_len, output_len = (
+            request.prompt,
+            request.prompt_len,
+            request.output_len,
+        )
         if lora_names is not None and len(lora_names) != 0:
             idx = random.randint(0, len(lora_names) - 1)
             lora_name = lora_names[idx]
@@ -1239,12 +1303,17 @@ async def benchmark(

     if "sglang" in backend:
         server_info = requests.get(base_url + "/get_server_info")
-        if
-
-            "
-
+        if server_info.status_code == 200:
+            if pd_separated:
+                accept_length = server_info.json()["decode"][0]["internal_states"][
+                    0
+                ].get("avg_spec_accept_length", None)
+            else:
+                accept_length = server_info.json()["internal_states"][0].get(
+                    "avg_spec_accept_length", None
+                )
         else:
-            accept_length =
+            accept_length = None
     else:
         accept_length = None

@@ -1263,7 +1332,7 @@ async def benchmark(
     print("{:<40} {:<10}".format("Traffic request rate:", request_rate))
     print(
         "{:<40} {:<10}".format(
-            "Max
+            "Max request concurrency:",
             max_concurrency if max_concurrency else "not set",
         )
     )
@@ -1378,21 +1447,24 @@ async def benchmark(
     else:
         output_file_name = f"{args.backend}_{now}_{args.num_prompts}_sharegpt.jsonl"

+    result_details = {
+        "input_lens": [output.prompt_len for output in outputs],
+        "output_lens": output_lens,
+        "ttfts": [output.ttft for output in outputs],
+        "itls": [output.itl for output in outputs],
+        "generated_texts": [output.generated_text for output in outputs],
+        "errors": [output.error for output in outputs],
+    }
+
     # Append results to a JSONL file
     with open(output_file_name, "a") as file:
-
-
-
-
-
-
-
-                "itls": [output.itl for output in outputs],
-                "generated_texts": [output.generated_text for output in outputs],
-                "errors": [output.error for output in outputs],
-            }
-        )
-    return result
+        if args.output_details:
+            result_for_dump = result | result_details
+        else:
+            result_for_dump = result
+        file.write(json.dumps(result_for_dump) + "\n")
+
+    return result | result_details


 def check_chat_template(model_path):
@@ -1422,6 +1494,9 @@ def run_benchmark(args_: argparse.Namespace):
     if not hasattr(args, "warmup_requests"):
         args.warmup_requests = 1

+    if not hasattr(args, "output_details"):
+        args.output_details = False
+
     print(f"benchmark_args={args}")

     # Set global environments
@@ -1541,7 +1616,7 @@ def run_benchmark(args_: argparse.Namespace):
             lora_names=args.lora_name,
             extra_request_body=extra_request_body,
             profile=args.profile,
-
+            pd_separated=args.pd_separated,
             flush_cache=args.flush_cache,
         )
     )
@@ -1666,6 +1741,9 @@ if __name__ == "__main__":
         "if the server is not processing requests fast enough to keep up.",
     )
     parser.add_argument("--output-file", type=str, help="Output JSONL file name.")
+    parser.add_argument(
+        "--output-details", action="store_true", help="Output details of benchmarking."
+    )
     parser.add_argument(
         "--disable-tqdm",
         action="store_true",
@@ -1720,7 +1798,7 @@ if __name__ == "__main__":
         help="Suffix applied to the end of all user prompts, followed by assistant prompt suffix.",
     )
     parser.add_argument(
-        "--pd-
+        "--pd-separated",
         action="store_true",
         help="Benchmark PD disaggregation server",
     )
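The net effect of the bench_serving.py changes above is an API shift: the sampling helpers now return DatasetRow dataclass instances instead of (prompt, prompt_len, output_len) tuples, request outputs are built with RequestFuncOutput.init_new, cached dataset files are checked with is_file_valid_json, and the benchmark gains --output-details and --pd-separated flags. A minimal sketch of how a caller unpacks the new return type (illustrative only, not code from the diff; it assumes sglang 0.4.6.post5 is installed and the token counts shown are made up):

    from sglang.bench_serving import DatasetRow

    # 0.4.6.post3 returned tuples:   prompt, prompt_len, output_len = request
    # 0.4.6.post5 returns dataclasses:
    row = DatasetRow(prompt="What is 2 + 2?", prompt_len=7, output_len=16)
    prompt, prompt_len, output_len = row.prompt, row.prompt_len, row.output_len
    print(prompt, prompt_len, output_len)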
sglang/compile_deep_gemm.py
CHANGED
@@ -82,8 +82,8 @@ def launch_server_process_and_send_one_request(
     base_url = f"http://{server_args.host}:{server_args.port}"
     timeout = compile_args.timeout

-    start_time = time.
-    while time.
+    start_time = time.perf_counter()
+    while time.perf_counter() - start_time < timeout:
         try:
             headers = {
                 "Content-Type": "application/json; charset=utf-8",
@@ -112,9 +112,9 @@ def launch_server_process_and_send_one_request(
             raise RuntimeError(f"Sync request failed: {error}")
         # Other nodes should wait for the exit signal from Rank-0 node.
         else:
-            start_time_waiting = time.
+            start_time_waiting = time.perf_counter()
             while proc.is_alive():
-                if time.
+                if time.perf_counter() - start_time_waiting < timeout:
                     time.sleep(10)
                 else:
                     raise TimeoutError("Waiting for main node timeout!")
@@ -129,7 +129,7 @@ def launch_server_process_and_send_one_request(


 def refine_server_args(server_args: ServerArgs, compile_args: CompileArgs):
-    #
+    # Disable cuda graph and torch compile to save time
     server_args.disable_cuda_graph = True
     server_args.enable_torch_compile = False
     print(f"Disable CUDA Graph and Torch Compile to save time...")
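The compile_deep_gemm.py change above replaces the previous timer calls (elided in this view) with time.perf_counter(), a monotonic clock, so the startup and wait timeouts cannot be skewed by system clock adjustments. A small standalone sketch of the same wait-with-timeout pattern (illustrative, not from the diff; wait_for is a hypothetical helper):

    import time

    def wait_for(condition, timeout_s: float, poll_s: float = 0.5) -> bool:
        # Poll `condition` until it returns True or `timeout_s` elapses,
        # measuring elapsed time with the monotonic perf_counter clock.
        start = time.perf_counter()
        while time.perf_counter() - start < timeout_s:
            if condition():
                return True
            time.sleep(poll_s)
        return False

    # Example: the condition becomes true after roughly one second.
    deadline = time.perf_counter() + 1.0
    print(wait_for(lambda: time.perf_counter() >= deadline, timeout_s=5.0))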
sglang/eval/loogle_eval.py
ADDED
@@ -0,0 +1,157 @@
+import argparse
+import asyncio
+import os
+import pickle
+from pathlib import Path
+from typing import List
+
+import openai
+import torch
+from bert_score import BERTScorer
+from datasets import load_dataset
+from tqdm import tqdm
+
+
+def get_client(api_url: str) -> openai.AsyncOpenAI:
+    if os.getenv("OPENAI_API_KEY") is None:
+        os.environ["OPENAI_API_KEY"] = "EMPTY"
+    return openai.AsyncOpenAI(base_url=api_url)
+
+
+def get_dataset():
+    return load_dataset("bigai-nlco/LooGLE", "longdep_qa", split="test")
+
+
+async def fetch_response(
+    client: openai.AsyncOpenAI,
+    context: str,
+    question: str,
+    semaphore: asyncio.Semaphore,
+    index: int,
+    model: str,
+    output_dir: Path,
+):
+    output_file = output_dir / f"response_{index}.pkl"
+    if output_file.exists():
+        return
+
+    prompt = (
+        "Please answer the question based on the long texts below.\n"
+        f"{context}\n"
+        f"Question: {question}\n"
+        "Answer:"
+    )
+    messages = [
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": prompt},
+    ]
+
+    async with semaphore:
+        try:
+            response = await client.chat.completions.create(
+                model=model,
+                messages=messages,
+                temperature=0.0,
+                max_tokens=512,
+            )
+        except openai.BadRequestError as e:
+            with open(output_file, "wb") as f:
+                pickle.dump({"error": str(e)}, f)
+            return
+
+    with open(output_file, "wb") as f:
+        pickle.dump(response, f)
+
+
+async def benchmark(args):
+    dataset = get_dataset()
+    output_dir = Path(args.output_dir)
+    output_dir.mkdir(parents=True, exist_ok=True)
+
+    client = get_client(args.api_url)
+    semaphore = asyncio.Semaphore(args.max_concurrency)
+
+    tasks: List[asyncio.Task] = []
+    for idx, ex in enumerate(dataset):
+        tasks.append(
+            asyncio.create_task(
+                fetch_response(
+                    client,
+                    ex["context"],
+                    ex["question"],
+                    semaphore,
+                    idx,
+                    args.model,
+                    output_dir,
+                )
+            )
+        )
+
+    for _ in tqdm(
+        asyncio.as_completed(tasks), total=len(tasks), desc="Running benchmark"
+    ):
+        await _
+
+
+def analyse(args):
+    dataset = get_dataset()
+    output_dir = Path(args.output_dir)
+
+    device = "cuda" if torch.cuda.is_available() else "cpu"
+    scorer = BERTScorer(lang="en", device=device)
+
+    hyps: List[str] = []
+    refs: List[str] = []
+    for idx, ex in enumerate(tqdm(dataset, desc="Loading responses")):
+        pkl_file = output_dir / f"response_{idx}.pkl"
+        if not pkl_file.exists():
+            raise FileNotFoundError(pkl_file)
+
+        response = pickle.load(open(pkl_file, "rb"))
+        if isinstance(response, dict) and "error" in response:
+            continue
+
+        hyps.append(response.choices[0].message.content.strip())
+        refs.append(ex["answer"])
+
+    if not hyps:
+        print("No valid responses to score!")
+        return
+
+    batch_size = 64
+    all_f1: List[float] = []
+    for i in tqdm(range(0, len(hyps), batch_size), desc="Scoring batches"):
+        h_batch = hyps[i : i + batch_size]
+        r_batch = refs[i : i + batch_size]
+        _, _, f1_scores = scorer.score(h_batch, r_batch, verbose=False)
+        all_f1.extend([float(x) for x in f1_scores])
+
+    avg = sum(all_f1) / len(all_f1)
+    print(f"Average BERTScore (F1): {avg:.2%}")
+
+
+if __name__ == "__main__":
+    parser = argparse.ArgumentParser(
+        description="Run benchmark and evaluation in one go."
+    )
+    parser.add_argument(
+        "--api-url",
+        default="http://127.0.0.1:30000/v1",
+        help="OpenAI‑compatible API base URL",
+    )
+    parser.add_argument(
+        "--model",
+        default="meta-llama/Llama-4-Maverick-17B-128E-Instruct",
+        help="Model name or ID, only used for model name",
+    )
+    parser.add_argument(
+        "--max-concurrency", type=int, default=144, help="Maximum concurrent requests"
+    )
+    parser.add_argument(
+        "--output-dir", default="tmp-output-dir", help="Directory for cached responses"
+    )
+    args = parser.parse_args()
+
+    asyncio.run(benchmark(args))
+
+    analyse(args)