sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post4__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as they appear in their public registry. It is provided for informational purposes only.
- sglang/api.py +1 -1
- sglang/bench_offline_throughput.py +19 -0
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +123 -79
- sglang/global_config.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/ir.py +1 -1
- sglang/srt/_custom_ops.py +83 -91
- sglang/srt/configs/load_config.py +4 -1
- sglang/srt/configs/model_config.py +48 -2
- sglang/srt/configs/qwen2_5_vl_config.py +5 -2
- sglang/srt/constrained/base_grammar_backend.py +117 -15
- sglang/srt/constrained/llguidance_backend.py +151 -0
- sglang/srt/constrained/outlines_backend.py +24 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -38
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
- sglang/srt/distributed/parallel_state.py +48 -3
- sglang/srt/entrypoints/engine.py +67 -9
- sglang/srt/entrypoints/http_server.py +190 -41
- sglang/srt/entrypoints/verl_engine.py +147 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/activation.py +11 -0
- sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
- sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_backend.py +302 -414
- sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
- sglang/srt/layers/attention/torch_native_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +13 -8
- sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
- sglang/srt/layers/attention/utils.py +39 -0
- sglang/srt/layers/attention/vision.py +60 -63
- sglang/srt/layers/dp_attention.py +142 -1
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +3 -1
- sglang/srt/layers/logits_processor.py +281 -45
- sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
- sglang/srt/layers/moe/ep_moe/layer.py +140 -28
- sglang/srt/layers/moe/fused_moe_native.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
- sglang/srt/layers/moe/topk.py +13 -4
- sglang/srt/layers/quantization/__init__.py +111 -7
- sglang/srt/layers/quantization/blockwise_int8.py +409 -0
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +69 -28
- sglang/srt/layers/quantization/fp8_utils.py +17 -1
- sglang/srt/layers/quantization/gptq.py +416 -0
- sglang/srt/layers/quantization/int8_kernel.py +327 -0
- sglang/srt/layers/quantization/int8_utils.py +73 -0
- sglang/srt/layers/quantization/modelopt_quant.py +18 -1
- sglang/srt/layers/radix_attention.py +1 -0
- sglang/srt/layers/rotary_embedding.py +0 -1
- sglang/srt/layers/sampler.py +76 -31
- sglang/srt/layers/vocab_parallel_embedding.py +14 -13
- sglang/srt/lora/lora.py +17 -1
- sglang/srt/lora/lora_config.py +5 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/cache_controller.py +193 -62
- sglang/srt/managers/configure_logging.py +2 -1
- sglang/srt/managers/data_parallel_controller.py +6 -2
- sglang/srt/managers/detokenizer_manager.py +124 -102
- sglang/srt/managers/image_processor.py +2 -1
- sglang/srt/managers/io_struct.py +144 -6
- sglang/srt/managers/schedule_batch.py +237 -197
- sglang/srt/managers/schedule_policy.py +29 -29
- sglang/srt/managers/scheduler.py +773 -334
- sglang/srt/managers/session_controller.py +6 -2
- sglang/srt/managers/tokenizer_manager.py +225 -68
- sglang/srt/managers/tp_worker.py +15 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/mem_cache/chunk_cache.py +18 -11
- sglang/srt/mem_cache/hiradix_cache.py +394 -0
- sglang/srt/mem_cache/memory_pool.py +68 -37
- sglang/srt/mem_cache/radix_cache.py +58 -47
- sglang/srt/metrics/collector.py +102 -36
- sglang/srt/model_executor/cuda_graph_runner.py +56 -31
- sglang/srt/model_executor/forward_batch_info.py +49 -16
- sglang/srt/model_executor/model_runner.py +280 -81
- sglang/srt/model_loader/loader.py +3 -3
- sglang/srt/model_loader/weight_utils.py +36 -14
- sglang/srt/models/baichuan.py +31 -6
- sglang/srt/models/chatglm.py +39 -7
- sglang/srt/models/commandr.py +29 -5
- sglang/srt/models/dbrx.py +31 -5
- sglang/srt/models/deepseek.py +43 -6
- sglang/srt/models/deepseek_nextn.py +32 -19
- sglang/srt/models/deepseek_v2.py +265 -32
- sglang/srt/models/exaone.py +19 -9
- sglang/srt/models/gemma.py +22 -8
- sglang/srt/models/gemma2.py +25 -12
- sglang/srt/models/gemma2_reward.py +5 -1
- sglang/srt/models/gpt2.py +28 -13
- sglang/srt/models/gpt_bigcode.py +27 -5
- sglang/srt/models/granite.py +21 -9
- sglang/srt/models/grok.py +21 -4
- sglang/srt/models/internlm2.py +36 -6
- sglang/srt/models/internlm2_reward.py +5 -1
- sglang/srt/models/llama.py +26 -9
- sglang/srt/models/llama_classification.py +5 -1
- sglang/srt/models/llama_eagle.py +17 -4
- sglang/srt/models/llama_embedding.py +5 -1
- sglang/srt/models/llama_reward.py +7 -2
- sglang/srt/models/llava.py +19 -3
- sglang/srt/models/llavavid.py +10 -1
- sglang/srt/models/minicpm.py +26 -2
- sglang/srt/models/minicpm3.py +39 -3
- sglang/srt/models/minicpmv.py +45 -14
- sglang/srt/models/mixtral.py +20 -9
- sglang/srt/models/mixtral_quant.py +50 -8
- sglang/srt/models/mllama.py +57 -11
- sglang/srt/models/olmo.py +34 -6
- sglang/srt/models/olmo2.py +34 -13
- sglang/srt/models/olmoe.py +26 -4
- sglang/srt/models/phi3_small.py +29 -10
- sglang/srt/models/qwen.py +26 -3
- sglang/srt/models/qwen2.py +26 -4
- sglang/srt/models/qwen2_5_vl.py +46 -8
- sglang/srt/models/qwen2_eagle.py +17 -5
- sglang/srt/models/qwen2_moe.py +44 -6
- sglang/srt/models/qwen2_rm.py +78 -0
- sglang/srt/models/qwen2_vl.py +39 -8
- sglang/srt/models/stablelm.py +32 -5
- sglang/srt/models/torch_native_llama.py +5 -2
- sglang/srt/models/xverse.py +21 -9
- sglang/srt/models/xverse_moe.py +45 -7
- sglang/srt/models/yivl.py +2 -1
- sglang/srt/openai_api/adapter.py +109 -24
- sglang/srt/openai_api/protocol.py +17 -1
- sglang/srt/reasoning_parser.py +154 -0
- sglang/srt/sampling/penaltylib/__init__.py +4 -6
- sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
- sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
- sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
- sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
- sglang/srt/sampling/sampling_batch_info.py +79 -157
- sglang/srt/sampling/sampling_params.py +16 -13
- sglang/srt/server_args.py +135 -60
- sglang/srt/speculative/build_eagle_tree.py +8 -9
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -12
- sglang/srt/speculative/eagle_utils.py +92 -57
- sglang/srt/speculative/eagle_worker.py +238 -111
- sglang/srt/speculative/spec_info.py +1 -13
- sglang/srt/utils.py +43 -17
- sglang/srt/warmup.py +47 -0
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/runners.py +389 -126
- sglang/test/send_one.py +88 -0
- sglang/test/test_block_fp8_ep.py +361 -0
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +138 -84
- sglang/utils.py +50 -60
- sglang/version.py +1 -1
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/METADATA +22 -15
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/RECORD +200 -166
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/WHEEL +1 -1
- sglang/bench_latency.py +0 -1
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
- sglang/test/srt/sampling/penaltylib/utils.py +0 -344
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/top_level.txt +0 -0
sglang/test/runners.py
CHANGED
@@ -15,15 +15,15 @@
 import multiprocessing as mp
 import os
 from dataclasses import dataclass
-from typing import List, Union
+from typing import List, Optional, Tuple, Union

 import torch
 import torch.nn.functional as F
 from transformers import AutoModelForCausalLM

-from sglang.srt.entrypoints.engine import Engine
 from sglang.srt.hf_transformers_utils import get_tokenizer
-from sglang.
+from sglang.srt.server import Engine
+from sglang.test.test_utils import DEFAULT_PORT_FOR_SRT_TEST_RUNNER, calculate_rouge_l

 DEFAULT_PROMPTS = [
     "Apple is red. Banana is Yellow. " * 800 + "Apple is",
@@ -56,6 +56,13 @@ def get_top_logprobs(logits, k):
     return logprobs


+def get_token_ids_logprobs(logits, token_ids):
+    logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
+    del logits
+    logprobs = logprobs[..., token_ids]
+    return logprobs
+
+
 def _get_sentence_transformer_embedding_model(model_path, torch_dtype):
     from sentence_transformers import SentenceTransformer
     from sentence_transformers.util import is_sentence_transformer_model
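
As a quick illustration of what the get_token_ids_logprobs helper added above computes, here is a minimal standalone sketch; the tensor sizes and token ids are made up:

import torch
import torch.nn.functional as F

def get_token_ids_logprobs(logits, token_ids):
    # Same body as the helper above: log-softmax over the vocab axis,
    # then keep only the columns for the requested token ids.
    logprobs = F.log_softmax(logits, dim=-1, dtype=torch.float32)
    del logits
    logprobs = logprobs[..., token_ids]
    return logprobs

logits = torch.randn(4, 32000)             # (num_positions, vocab_size), illustrative sizes
selected = get_token_ids_logprobs(logits, [0, 5, 42])
print(selected.shape)                      # torch.Size([4, 3]): one logprob per requested id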
@@ -84,8 +91,13 @@ class ModelOutput:
     output_ids: List[int] = None
     top_input_logprobs: List[torch.Tensor] = None
     top_output_logprobs: List[torch.Tensor] = None
+    top_output_logprob_idx: List[List[int]] = None
     embed_logits: List[torch.Tensor] = None
     scores: List[float] = None
+    input_token_logprobs_lst: List[List[Tuple[float, int, None]]] = None
+    output_token_logprobs_lst: List[List[Tuple[float, int, None]]] = None
+    token_ids_input_logprobs: List[torch.Tensor] = None
+    token_ids_output_logprobs: List[torch.Tensor] = None


 class HFRunner:
@@ -95,9 +107,11 @@ class HFRunner:
         torch_dtype: torch.dtype,
         model_type: str = "generation",
         output_str_only: bool = False,
+        trust_remote_code: bool = False,
     ):
         self.model_type = model_type
         self.output_str_only = output_str_only
+        self.trust_remote_code = trust_remote_code

         self.in_queue = mp.Queue()
         self.out_queue = mp.Queue()
@@ -130,7 +144,7 @@ class HFRunner:
             self.base_model = AutoModelForCausalLM.from_pretrained(
                 model_path,
                 torch_dtype=torch_dtype,
-                trust_remote_code=
+                trust_remote_code=self.trust_remote_code,
                 low_cpu_mem_usage=True,
             ).cuda()
         elif self.model_type == "embedding":
@@ -147,79 +161,32 @@ class HFRunner:
             ).cuda()
         else:
             raise Exception(f"Unrecognized model type {self.model_type}")
-        self.tokenizer = get_tokenizer(
+        self.tokenizer = get_tokenizer(
+            model_path,
+            torch_dtype=torch.dtype,
+            trust_remote_code=self.trust_remote_code,
+        )

         # Run forward
         while True:
-            prompts, max_new_tokens, lora_paths = in_queue.get()
+            prompts, max_new_tokens, lora_paths, token_ids_logprob = in_queue.get()
             if lora_paths is not None:
                 assert len(prompts) == len(lora_paths)

             if prompts is not None:
                 if self.model_type == "generation":
-                    output_strs = []
-                    top_input_logprobs = []
-                    top_output_logprobs = []
-                    for i, p in enumerate(prompts):
-                        if isinstance(p, str):
-                            input_ids = self.tokenizer.encode(
-                                p, return_tensors="pt"
-                            ).cuda()
-                        else:
-                            input_ids = torch.tensor([p], device="cuda")
-
-                        if lora_paths is not None and lora_paths[i] is not None:
-                            from peft import PeftModel
-
-                            self.model = PeftModel.from_pretrained(
-                                self.base_model,
-                                lora_paths[i],
-                                torch_dtype=torch_dtype,
-                                is_trainable=False,
-                            )
-                        else:
-                            self.model = self.base_model
-
-                        outputs = self.model.generate(
-                            input_ids,
-                            do_sample=False,
-                            temperature=None,
-                            top_p=None,
-                            max_new_tokens=max_new_tokens,
-                            return_dict_in_generate=True,
-                            output_scores=(not self.output_str_only),
-                        )
-                        output_strs.append(
-                            self.tokenizer.decode(outputs[0][0][len(input_ids[0]) :])
-                        )
-                        if not self.output_str_only:
-                            # outputs.scores: (num_token, 1, vocab_size)
-                            top_output_logprobs.append(
-                                [
-                                    get_top_logprobs(
-                                        logits[0], NUM_TOP_LOGPROBS
-                                    ).tolist()
-                                    for logits in outputs.scores
-                                ]
-                            )
-                            del outputs
-
-                            input_logits = self.model.forward(input_ids).logits[0]
-                            top_input_logprobs.append(
-                                get_top_logprobs(
-                                    input_logits, NUM_TOP_LOGPROBS
-                                ).tolist()
-                            )
-                            del input_logits
-
                     out_queue.put(
-                        ModelOutput(
-                            output_strs=output_strs,
-                            top_input_logprobs=top_input_logprobs,
-                            top_output_logprobs=top_output_logprobs,
+                        self.forward_generation_raw(
+                            base_model=self.base_model,
+                            prompts=prompts,
+                            max_new_tokens=max_new_tokens,
+                            tokenizer=self.tokenizer,
+                            lora_paths=lora_paths,
+                            torch_dtype=torch_dtype,
+                            output_str_only=self.output_str_only,
+                            token_ids_logprob=token_ids_logprob,
                         )
                     )
-
                 elif self.model_type == "embedding":
                     assert not self.output_str_only
                     logits = self.model.encode(prompts).tolist()
@@ -244,10 +211,11 @@ class HFRunner:
     def forward(
         self,
         prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
-        max_new_tokens=8,
-        lora_paths=None,
+        max_new_tokens: int = 8,
+        lora_paths: Optional[List[str]] = None,
+        token_ids_logprob: Optional[int] = None,
     ):
-        self.in_queue.put((prompts, max_new_tokens, lora_paths))
+        self.in_queue.put((prompts, max_new_tokens, lora_paths, token_ids_logprob))
         return self.out_queue.get()

     def terminate(self):
@@ -261,6 +229,101 @@ class HFRunner:
             self.model_proc.terminate()
         self.in_queue = self.out_queue = None

+    @staticmethod
+    def forward_generation_raw(
+        base_model,
+        prompts: Union[List[str], List[torch.Tensor]],
+        max_new_tokens: int,
+        tokenizer,
+        torch_dtype: torch.dtype,
+        lora_paths: Optional[List[str]] = None,
+        output_str_only: bool = False,
+        token_ids_logprob: Optional[int] = None,
+    ) -> ModelOutput:
+        output_strs = []
+        top_input_logprobs = []
+        top_output_logprobs = []
+        if token_ids_logprob is not None:
+            token_ids_input_logprobs = []
+            token_ids_output_logprobs = []
+        else:
+            token_ids_input_logprobs = token_ids_output_logprobs = None
+
+        for i, p in enumerate(prompts):
+            if isinstance(p, str):
+                input_ids = tokenizer.encode(p, return_tensors="pt").cuda()
+            else:
+                input_ids = torch.tensor([p], device="cuda")
+
+            if lora_paths is not None and lora_paths[i] is not None:
+                from peft import PeftModel
+
+                model = PeftModel.from_pretrained(
+                    base_model,
+                    lora_paths[i],
+                    torch_dtype=torch_dtype,
+                    is_trainable=False,
+                )
+            else:
+                model = base_model
+
+            outputs = model.generate(
+                input_ids,
+                do_sample=False,
+                temperature=None,
+                top_p=None,
+                max_new_tokens=max_new_tokens,
+                return_dict_in_generate=True,
+                output_scores=(not output_str_only),
+            )
+
+            text = tokenizer.decode(
+                outputs[0][0][len(input_ids[0]) :], skip_special_tokens=True
+            )
+            # Check if the text is empty or only whitespace.
+            if not text.strip():
+                raise ValueError(
+                    "Received an empty text response. Please verify your input or model configuration."
+                )
+            output_strs.append(text)
+
+            if not output_str_only:
+                # outputs.scores: (num_token, 1, vocab_size)
+                top_output_logprobs.append(
+                    [
+                        get_top_logprobs(logits[0], NUM_TOP_LOGPROBS).tolist()
+                        for logits in outputs.scores
+                    ]
+                )
+                if token_ids_logprob is not None:
+                    token_ids_output_logprobs.append(
+                        [
+                            get_token_ids_logprobs(
+                                logits[0], token_ids_logprob
+                            ).tolist()
+                            for logits in outputs.scores
+                        ]
+                    )
+                del outputs
+
+                input_logits = model.forward(input_ids).logits[0]
+                top_input_logprobs.append(
+                    get_top_logprobs(input_logits, NUM_TOP_LOGPROBS).tolist()
+                )
+                if token_ids_logprob is not None:
+                    token_ids_input_logprobs.append(
+                        get_token_ids_logprobs(input_logits, token_ids_logprob).tolist()
+                    )
+                del input_logits
+
+        return ModelOutput(
+            output_strs=output_strs,
+            top_input_logprobs=top_input_logprobs,
+            top_output_logprobs=top_output_logprobs,
+            token_ids_input_logprobs=token_ids_input_logprobs,
+            token_ids_output_logprobs=token_ids_output_logprobs,
+        )
+

 class SRTRunner:
     def __init__(
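
A hedged usage sketch of the refactored HFRunner path shown above (a GPU is required; the model path is a placeholder, and passing it as the first constructor argument is an assumption, since only the later constructor parameters appear in this diff):

import torch
from sglang.test.runners import HFRunner

MODEL_PATH = "meta-llama/Llama-3.1-8B-Instruct"  # placeholder model

hf_runner = HFRunner(
    MODEL_PATH,                       # assumed first positional argument
    torch_dtype=torch.float16,
    model_type="generation",
    trust_remote_code=False,          # new keyword in this version
)
outputs = hf_runner.forward(
    ["The capital of France is"],
    max_new_tokens=8,
    token_ids_logprob=[10, 20, 30],   # also collect logprobs for these specific token ids
)
print(outputs.output_strs)
print(outputs.token_ids_output_logprobs)  # new ModelOutput field
hf_runner.terminate()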
@@ -275,72 +338,79 @@ class SRTRunner:
         lora_backend: str = "triton",
         disable_cuda_graph: bool = False,
         disable_radix_cache: bool = False,
+        chunked_prefill_size: Optional[int] = None,
+        dp_size: int = 1,
+        tokenizer_path: Optional[str] = None,
+        enable_ep_moe: bool = False,
+        mem_fraction_static: float = 0.65,
+        trust_remote_code: bool = False,
+        speculative_draft_model_path: Optional[str] = None,
+        speculative_algorithm: Optional[str] = None,
+        speculative_num_steps: Optional[int] = None,
+        speculative_eagle_topk: Optional[int] = None,
+        speculative_num_draft_tokens: Optional[int] = None,
+        disable_overlap_schedule: bool = False,
     ):
         self.model_type = model_type
         self.is_generation = model_type == "generation"
+        enable_dp_attention = dp_size > 1
+
+        spec_kwargs = {}
+        if speculative_draft_model_path:
+            spec_kwargs["speculative_draft_model_path"] = speculative_draft_model_path
+            spec_kwargs["speculative_algorithm"] = speculative_algorithm
+            spec_kwargs["speculative_num_steps"] = speculative_num_steps
+            spec_kwargs["speculative_eagle_topk"] = speculative_eagle_topk
+            spec_kwargs["speculative_num_draft_tokens"] = speculative_num_draft_tokens
+
         self.engine = Engine(
             model_path=model_path,
             tp_size=tp_size,
             dtype=get_dtype_str(torch_dtype),
             port=port,
-            mem_fraction_static=
-            trust_remote_code=
+            mem_fraction_static=mem_fraction_static,
+            trust_remote_code=trust_remote_code,
             is_embedding=not self.is_generation,
             lora_paths=lora_paths,
             max_loras_per_batch=max_loras_per_batch,
             lora_backend=lora_backend,
             disable_cuda_graph=disable_cuda_graph,
             disable_radix_cache=disable_radix_cache,
+            chunked_prefill_size=chunked_prefill_size,
+            enable_dp_attention=enable_dp_attention,
+            dp_size=dp_size,
+            tokenizer_path=tokenizer_path,
+            enable_ep_moe=enable_ep_moe,
+            disable_overlap_schedule=disable_overlap_schedule,
+            cuda_graph_max_bs=4,
+            **spec_kwargs,
         )
-
+
+        if tokenizer_path is None:
+            self.tokenizer = get_tokenizer(
+                model_path, trust_remote_code=trust_remote_code
+            )
+        else:
+            self.tokenizer = None

     def forward(
         self,
         prompts: Union[List[str], List[torch.Tensor]] = DEFAULT_PROMPTS,
-        max_new_tokens=8,
-        lora_paths=None,
+        max_new_tokens: int = 8,
+        lora_paths: Optional[List[str]] = None,
+        logprob_start_len: int = 0,
+        top_k: Optional[int] = None,
+        token_ids_logprob: Optional[List[int]] = None,
     ):
         if self.is_generation:
-            # the return value contains logprobs from prefill
-            output_strs = []
-            top_input_logprobs = []
-            top_output_logprobs = []
-            sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
-            for i, prompt in enumerate(prompts):
-                response = self.engine.generate(
-                    prompt,
-                    lora_path=lora_paths[i] if lora_paths else None,
-                    sampling_params=sampling_params,
-                    return_logprob=True,
-                    logprob_start_len=0,
-                    top_logprobs_num=NUM_TOP_LOGPROBS,
-                )
-                output_strs.append(response["text"])
-                top_input_logprobs.append(
-                    [
-                        [tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
-                        for x in response["meta_info"]["input_top_logprobs"][1:]
-                    ]
-                    + [
-                        [
-                            tup[0]
-                            for tup in response["meta_info"]["output_top_logprobs"][0][
-                                :NUM_TOP_LOGPROBS
-                            ]
-                        ]
-                    ]
-                )
-                top_output_logprobs.append(
-                    [
-                        [tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
-                        for x in response["meta_info"]["output_top_logprobs"]
-                    ]
-                )
-
-            return ModelOutput(
-                output_strs=output_strs,
-                top_input_logprobs=top_input_logprobs,
-                top_output_logprobs=top_output_logprobs,
+            return self.forward_generation_raw(
+                engine=self.engine,
+                prompts=prompts,
+                max_new_tokens=max_new_tokens,
+                lora_paths=lora_paths,
+                logprob_start_len=logprob_start_len,
+                top_k=top_k,
+                token_ids_logprob=token_ids_logprob,
             )
         else:
             response = self.engine.encode(prompts)
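
A sketch of how the new SRTRunner constructor options might be combined. The target and draft model paths are placeholders, and the "EAGLE" algorithm value is an assumption not shown in this diff (only the keyword names above are):

import torch
from sglang.test.runners import SRTRunner

srt_runner = SRTRunner(
    "meta-llama/Llama-3.1-8B-Instruct",          # placeholder target model
    torch_dtype=torch.float16,
    model_type="generation",
    tp_size=1,
    mem_fraction_static=0.65,                    # now configurable; default shown in the diff
    trust_remote_code=False,
    speculative_draft_model_path="lmsys/sglang-EAGLE-LLaMA3-Instruct-8B",  # placeholder draft model
    speculative_algorithm="EAGLE",               # assumed value
    speculative_num_steps=3,
    speculative_eagle_topk=4,
    speculative_num_draft_tokens=16,
    disable_overlap_schedule=True,
)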
@@ -362,18 +432,11 @@ class SRTRunner:
         only return output strings and no logprobs
         """
         if self.is_generation:
-            # the return value contains logprobs from prefill
-            output_strs = []
-            sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
-            response = self.engine.generate(
-                prompts,
-                lora_path=lora_paths if lora_paths else None,
-                sampling_params=sampling_params,
-            )
-            output_strs = [r["text"] for r in response]
-
-            return ModelOutput(
-                output_strs=output_strs,
+            return self.batch_forward_generation_raw(
+                engine=self.engine,
+                prompts=prompts,
+                max_new_tokens=max_new_tokens,
+                lora_paths=lora_paths,
             )
         else:
             response = self.engine.encode(prompts)
@@ -391,6 +454,157 @@ class SRTRunner:
         self.engine.shutdown()
         del self.engine

+    @staticmethod
+    def forward_generation_raw(
+        engine: Engine,
+        prompts: Union[List[str], List[torch.Tensor]],
+        max_new_tokens: int = 8,
+        lora_paths: Optional[List[str]] = None,
+        logprob_start_len: int = 0,
+        top_k: Optional[int] = None,
+        token_ids_logprob: Optional[List[int]] = None,
+    ):
+        # the return value contains logprobs from prefill
+        output_strs = []
+        output_ids = []
+        # Input logprobs. Note that the last item in input logprob is equivalent to
+        # the first item in the output logprob.
+        top_input_logprobs = []
+        input_token_logprobs_lst = []
+        top_output_logprobs = []
+        output_token_logprobs_lst = []
+        top_output_logprob_idx = []
+        if token_ids_logprob is not None:
+            token_ids_input_logprobs = []
+            token_ids_output_logprobs = []
+        else:
+            token_ids_input_logprobs = token_ids_output_logprobs = None
+
+        sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
+        if top_k:
+            sampling_params["top_k"] = top_k
+
+        for i, prompt in enumerate(prompts):
+            response = engine.generate(
+                prompt,
+                lora_path=lora_paths[i] if lora_paths else None,
+                sampling_params=sampling_params,
+                return_logprob=True,
+                logprob_start_len=logprob_start_len,
+                top_logprobs_num=NUM_TOP_LOGPROBS,
+                token_ids_logprob=token_ids_logprob,
+            )
+            text = response["text"]
+
+            # Check if the text is empty or only whitespace.
+            if not text.strip():
+                raise ValueError(
+                    "Received an empty text response. Please verify your input or model configuration."
+                )
+            output_strs.append(text)
+            # output_ids.append(response["output_ids"])
+
+            input_token_logprobs = response["meta_info"]["input_token_logprobs"]
+            output_token_logprobs = response["meta_info"]["output_token_logprobs"]
+            # print(i, input_token_logprobs)
+            # print(i, output_token_logprobs)
+            logprobs = response["meta_info"]["input_top_logprobs"]
+            if token_ids_logprob is not None:
+                input_token_ids_logprobs = response["meta_info"][
+                    "input_token_ids_logprobs"
+                ][1:]
+            else:
+                input_token_ids_logprobs = None
+
+            num_prompt_tokens = response["meta_info"]["prompt_tokens"]
+            assert len(input_token_logprobs) == num_prompt_tokens - logprob_start_len
+            assert len(logprobs) == num_prompt_tokens - logprob_start_len
+
+            # The first token logprob has no meaning in sglang.
+            input_token_logprobs = input_token_logprobs[1:]
+            logprobs = logprobs[1:]
+            assert len(input_token_logprobs) == len(logprobs)
+
+            input_token_logprobs_lst.append(
+                input_token_logprobs + [output_token_logprobs[0]]
+            )
+            output_token_logprobs_lst.append(output_token_logprobs)
+
+            top_input_logprobs.append(
+                [[tup[0] for tup in x[:NUM_TOP_LOGPROBS]] for x in logprobs]
+                + [
+                    [
+                        tup[0]
+                        for tup in response["meta_info"]["output_top_logprobs"][0][
+                            :NUM_TOP_LOGPROBS
+                        ]
+                    ]
+                ]
+            )
+            top_output_logprobs.append(
+                [
+                    [tup[0] for tup in x[:NUM_TOP_LOGPROBS]]
+                    for x in response["meta_info"]["output_top_logprobs"]
+                ]
+            )
+            top_output_logprob_idx.append(
+                [
+                    [tup[1] for tup in x[:NUM_TOP_LOGPROBS]]
+                    for x in response["meta_info"]["output_top_logprobs"]
+                ]
+            )
+            if token_ids_logprob is not None:
+                token_ids_input_logprobs.append(
+                    [[tup[0] for tup in x] for x in input_token_ids_logprobs]
+                    + [
+                        [
+                            tup[0]
+                            for tup in response["meta_info"][
+                                "output_token_ids_logprobs"
+                            ][0]
+                        ]
+                    ]
+                )
+                token_ids_output_logprobs.append(
+                    [
+                        [tup[0] for tup in x]
+                        for x in response["meta_info"]["output_token_ids_logprobs"]
+                    ]
+                )
+
+        return ModelOutput(
+            output_strs=output_strs,
+            output_ids=output_ids,
+            top_input_logprobs=top_input_logprobs,
+            top_output_logprobs=top_output_logprobs,
+            input_token_logprobs_lst=input_token_logprobs_lst,
+            output_token_logprobs_lst=output_token_logprobs_lst,
+            top_output_logprob_idx=top_output_logprob_idx,
+            token_ids_input_logprobs=token_ids_input_logprobs,
+            token_ids_output_logprobs=token_ids_output_logprobs,
+        )
+
+    @staticmethod
+    def batch_forward_generation_raw(
+        prompts: Union[List[str], List[torch.Tensor]],
+        max_new_tokens,
+        lora_paths,
+        engine,
+    ):
+        # the return value contains logprobs from prefill
+        output_strs = []
+        sampling_params = {"max_new_tokens": max_new_tokens, "temperature": 0}
+        response = engine.generate(
+            prompts,
+            lora_path=lora_paths if lora_paths else None,
+            sampling_params=sampling_params,
+        )
+        output_strs = [r["text"] for r in response]
+
+        return ModelOutput(
+            output_strs=output_strs,
+        )
+

 def monkey_patch_gemma2_sdpa():
     """
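
An illustrative call through the new SRTRunner.forward path, reading the ModelOutput fields added in this diff (the model path is a placeholder and a GPU is required):

import torch
from sglang.test.runners import SRTRunner

srt_runner = SRTRunner(
    "meta-llama/Llama-3.1-8B-Instruct",   # placeholder model
    torch_dtype=torch.float16,
    model_type="generation",
)
out = srt_runner.forward(
    ["The capital of France is"],
    max_new_tokens=8,
    logprob_start_len=0,
    top_k=1,
    token_ids_logprob=[10, 20, 30],
)
print(out.output_strs)
print(out.output_token_logprobs_lst[0][:3])   # (logprob, token_id, None) tuples
print(out.top_output_logprob_idx[0][0])       # ids behind the top output logprobs
print(out.token_ids_output_logprobs[0][0])    # logprobs for the requested token ids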
@@ -405,3 +619,52 @@ def monkey_patch_gemma2_sdpa():
         return config

     setattr(Gemma2PreTrainedModel, "_check_and_enable_sdpa", _check_and_enable_sdpa)
+
+
+def check_close_model_outputs(
+    hf_outputs: ModelOutput,
+    srt_outputs: ModelOutput,
+    prefill_tolerance: float,
+    decode_tolerance: float,
+    rouge_l_tolerance: float,
+    debug_text: str = "",
+    check_logprobs: bool = True,
+):
+    # Compare output strings
+    print(f"{hf_outputs.output_strs=}")
+    print(f"{srt_outputs.output_strs=}")
+    rouge_l_scores = calculate_rouge_l(hf_outputs.output_strs, srt_outputs.output_strs)
+    print(f"{rouge_l_scores=}")
+    assert all(
+        score >= rouge_l_tolerance for score in rouge_l_scores
+    ), f"Not all ROUGE-L scores are greater than rouge_l_tolerance={rouge_l_tolerance}"
+
+    if check_logprobs:
+        for i in range(len(hf_outputs.output_strs)):
+            # Compare input logprobs
+            hf_logprobs = torch.Tensor(hf_outputs.top_input_logprobs[i])
+            srt_logprobs = torch.Tensor(srt_outputs.top_input_logprobs[i])
+            input_len = hf_logprobs.shape[0]
+            print(
+                "prefill logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
+            )
+            if input_len <= 100:
+                assert torch.all(abs(hf_logprobs - srt_logprobs) < prefill_tolerance), (
+                    f"prefill logprobs are not all close with {debug_text} "
+                    f"prefill_tolerance={prefill_tolerance}."
+                    f"{hf_logprobs=}, {srt_logprobs=}"
+                )
+
+            # Compare output logprobs
+            hf_logprobs = torch.Tensor(hf_outputs.top_output_logprobs[i])
+            srt_logprobs = torch.Tensor(srt_outputs.top_output_logprobs[i])
+
+            print(
+                "decode logprobs max_diff", torch.max(abs(hf_logprobs - srt_logprobs))
+            )
+            if input_len <= 100:
+                assert torch.all(abs(hf_logprobs - srt_logprobs) < decode_tolerance), (
+                    f"decode logprobs are not all close with {debug_text} "
+                    f"decode_tolerance={decode_tolerance}."
+                    f"{hf_logprobs=}, {srt_logprobs=}"
+                )
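
Finally, a sketch of how the new check_close_model_outputs helper ties the two runners together; the model path and tolerance values are illustrative only:

import torch
from sglang.test.runners import HFRunner, SRTRunner, check_close_model_outputs

MODEL_PATH = "meta-llama/Llama-3.1-8B-Instruct"   # placeholder
PROMPTS = ["The capital of France is"]

hf_runner = HFRunner(MODEL_PATH, torch_dtype=torch.float16, model_type="generation")
hf_outputs = hf_runner.forward(PROMPTS, max_new_tokens=8)
hf_runner.terminate()

srt_runner = SRTRunner(MODEL_PATH, torch_dtype=torch.float16, model_type="generation")
srt_outputs = srt_runner.forward(PROMPTS, max_new_tokens=8)

check_close_model_outputs(
    hf_outputs,
    srt_outputs,
    prefill_tolerance=5e-2,      # illustrative tolerances
    decode_tolerance=5e-2,
    rouge_l_tolerance=1.0,
    debug_text=f"model_path={MODEL_PATH}",
)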