sglang 0.4.3.post1__py3-none-any.whl → 0.4.3.post3__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +1 -1
- sglang/bench_offline_throughput.py +19 -0
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +123 -79
- sglang/global_config.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/ir.py +1 -1
- sglang/srt/_custom_ops.py +83 -91
- sglang/srt/configs/load_config.py +4 -1
- sglang/srt/configs/model_config.py +48 -2
- sglang/srt/configs/qwen2_5_vl_config.py +5 -2
- sglang/srt/constrained/base_grammar_backend.py +117 -15
- sglang/srt/constrained/llguidance_backend.py +151 -0
- sglang/srt/constrained/outlines_backend.py +24 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -38
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
- sglang/srt/distributed/parallel_state.py +48 -3
- sglang/srt/entrypoints/engine.py +67 -9
- sglang/srt/entrypoints/http_server.py +190 -41
- sglang/srt/entrypoints/verl_engine.py +147 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/activation.py +11 -0
- sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
- sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_backend.py +208 -295
- sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
- sglang/srt/layers/attention/torch_native_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +9 -6
- sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
- sglang/srt/layers/attention/utils.py +39 -0
- sglang/srt/layers/attention/vision.py +60 -63
- sglang/srt/layers/dp_attention.py +142 -1
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +3 -1
- sglang/srt/layers/logits_processor.py +281 -45
- sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
- sglang/srt/layers/moe/ep_moe/layer.py +140 -28
- sglang/srt/layers/moe/fused_moe_native.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
- sglang/srt/layers/moe/topk.py +13 -4
- sglang/srt/layers/quantization/__init__.py +111 -7
- sglang/srt/layers/quantization/blockwise_int8.py +409 -0
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/fp8.py +69 -28
- sglang/srt/layers/quantization/fp8_utils.py +17 -1
- sglang/srt/layers/quantization/gptq.py +416 -0
- sglang/srt/layers/quantization/int8_kernel.py +327 -0
- sglang/srt/layers/quantization/int8_utils.py +73 -0
- sglang/srt/layers/quantization/modelopt_quant.py +18 -1
- sglang/srt/layers/radix_attention.py +1 -0
- sglang/srt/layers/rotary_embedding.py +0 -1
- sglang/srt/layers/sampler.py +76 -31
- sglang/srt/layers/vocab_parallel_embedding.py +14 -13
- sglang/srt/lora/lora.py +17 -1
- sglang/srt/lora/lora_config.py +5 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/cache_controller.py +193 -62
- sglang/srt/managers/configure_logging.py +2 -1
- sglang/srt/managers/data_parallel_controller.py +6 -2
- sglang/srt/managers/detokenizer_manager.py +124 -102
- sglang/srt/managers/image_processor.py +2 -1
- sglang/srt/managers/io_struct.py +143 -6
- sglang/srt/managers/schedule_batch.py +238 -197
- sglang/srt/managers/schedule_policy.py +29 -29
- sglang/srt/managers/scheduler.py +681 -259
- sglang/srt/managers/session_controller.py +6 -2
- sglang/srt/managers/tokenizer_manager.py +224 -68
- sglang/srt/managers/tp_worker.py +15 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/mem_cache/chunk_cache.py +18 -11
- sglang/srt/mem_cache/hiradix_cache.py +394 -0
- sglang/srt/mem_cache/memory_pool.py +44 -18
- sglang/srt/mem_cache/radix_cache.py +58 -47
- sglang/srt/metrics/collector.py +94 -36
- sglang/srt/model_executor/cuda_graph_runner.py +55 -24
- sglang/srt/model_executor/forward_batch_info.py +49 -16
- sglang/srt/model_executor/model_runner.py +209 -28
- sglang/srt/model_loader/loader.py +3 -3
- sglang/srt/model_loader/weight_utils.py +36 -14
- sglang/srt/models/baichuan.py +31 -6
- sglang/srt/models/chatglm.py +39 -7
- sglang/srt/models/commandr.py +29 -5
- sglang/srt/models/dbrx.py +31 -5
- sglang/srt/models/deepseek.py +43 -6
- sglang/srt/models/deepseek_nextn.py +32 -19
- sglang/srt/models/deepseek_v2.py +265 -29
- sglang/srt/models/exaone.py +19 -9
- sglang/srt/models/gemma.py +22 -8
- sglang/srt/models/gemma2.py +25 -12
- sglang/srt/models/gemma2_reward.py +5 -1
- sglang/srt/models/gpt2.py +28 -13
- sglang/srt/models/gpt_bigcode.py +27 -5
- sglang/srt/models/granite.py +21 -9
- sglang/srt/models/grok.py +21 -4
- sglang/srt/models/internlm2.py +36 -6
- sglang/srt/models/internlm2_reward.py +5 -1
- sglang/srt/models/llama.py +26 -9
- sglang/srt/models/llama_classification.py +5 -1
- sglang/srt/models/llama_eagle.py +17 -4
- sglang/srt/models/llama_embedding.py +5 -1
- sglang/srt/models/llama_reward.py +7 -2
- sglang/srt/models/llava.py +19 -3
- sglang/srt/models/llavavid.py +10 -1
- sglang/srt/models/minicpm.py +26 -2
- sglang/srt/models/minicpm3.py +39 -3
- sglang/srt/models/minicpmv.py +45 -14
- sglang/srt/models/mixtral.py +20 -9
- sglang/srt/models/mixtral_quant.py +50 -8
- sglang/srt/models/mllama.py +57 -11
- sglang/srt/models/olmo.py +34 -6
- sglang/srt/models/olmo2.py +34 -13
- sglang/srt/models/olmoe.py +26 -4
- sglang/srt/models/phi3_small.py +29 -10
- sglang/srt/models/qwen.py +26 -3
- sglang/srt/models/qwen2.py +26 -4
- sglang/srt/models/qwen2_5_vl.py +46 -8
- sglang/srt/models/qwen2_eagle.py +17 -5
- sglang/srt/models/qwen2_moe.py +44 -6
- sglang/srt/models/qwen2_rm.py +78 -0
- sglang/srt/models/qwen2_vl.py +39 -8
- sglang/srt/models/stablelm.py +32 -5
- sglang/srt/models/torch_native_llama.py +5 -2
- sglang/srt/models/xverse.py +21 -9
- sglang/srt/models/xverse_moe.py +45 -7
- sglang/srt/models/yivl.py +2 -1
- sglang/srt/openai_api/adapter.py +109 -24
- sglang/srt/openai_api/protocol.py +17 -1
- sglang/srt/reasoning_parser.py +154 -0
- sglang/srt/sampling/penaltylib/__init__.py +4 -6
- sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
- sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
- sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
- sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
- sglang/srt/sampling/sampling_batch_info.py +79 -157
- sglang/srt/sampling/sampling_params.py +16 -13
- sglang/srt/server_args.py +136 -52
- sglang/srt/speculative/build_eagle_tree.py +2 -8
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
- sglang/srt/speculative/eagle_utils.py +92 -58
- sglang/srt/speculative/eagle_worker.py +186 -94
- sglang/srt/speculative/spec_info.py +1 -13
- sglang/srt/utils.py +43 -17
- sglang/srt/warmup.py +47 -0
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/runners.py +389 -126
- sglang/test/send_one.py +88 -0
- sglang/test/test_block_fp8_ep.py +361 -0
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +138 -84
- sglang/utils.py +50 -60
- sglang/version.py +1 -1
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +214 -166
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
- sglang/bench_latency.py +0 -1
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
- sglang/test/srt/sampling/penaltylib/utils.py +0 -344
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
sglang/srt/speculative/eagle_worker.py
CHANGED
@@ -1,8 +1,10 @@
 import logging
+import os
 import time
-from typing import List, Optional, Union
+from typing import Dict, List, Optional, Tuple, Union

 import torch
+from huggingface_hub import snapshot_download

 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.managers.schedule_batch import Req, ScheduleBatch
@@ -20,11 +22,13 @@ from sglang.srt.speculative.eagle_draft_cuda_graph_runner import (
 from sglang.srt.speculative.eagle_utils import (
     EagleDraftInput,
     EagleVerifyInput,
+    EagleVerifyOutput,
     assign_draft_cache_locs,
     fast_topk,
     select_top_k_tokens,
 )
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
+from sglang.srt.utils import get_available_gpu_memory

 logger = logging.getLogger(__name__)

@@ -40,10 +44,31 @@ class EAGLEWorker(TpModelWorker):
         nccl_port: int,
         target_worker: TpModelWorker,
     ):
+        # Override context length with target model's context length
+        server_args.context_length = target_worker.model_runner.model_config.context_len
+        os.environ["SGLANG_ALLOW_OVERWRITE_LONGER_CONTEXT_LEN"] = "1"
+
         # Do not capture cuda graph in `super().__init__()`
         # We will capture it later
         backup_disable_cuda_graph = server_args.disable_cuda_graph
         server_args.disable_cuda_graph = True
+
+        # Lossy optimization by using hot tokens
+        if server_args.speculative_token_map is not None:
+            self.hot_token_id = load_token_map(server_args.speculative_token_map)
+            server_args.json_model_override_args = (
+                f'{{"hot_vocab_size": {len(self.hot_token_id)}}}'
+            )
+        else:
+            self.hot_token_id = None
+
+        # We share the allocator with a target worker. Draft/target worker
+        # owns its own KV cache.
+        self.req_to_token_pool, self.token_to_kv_pool_allocator = (
+            target_worker.get_memory_pool()
+        )
+
+        # Init target worker
         super().__init__(
             gpu_id=gpu_id,
             tp_rank=tp_rank,
@@ -51,9 +76,10 @@ class EAGLEWorker(TpModelWorker):
             nccl_port=nccl_port,
             dp_rank=dp_rank,
             is_draft_worker=True,
+            req_to_token_pool=self.req_to_token_pool,
+            token_to_kv_pool_allocator=self.token_to_kv_pool_allocator,
         )
         self.target_worker = target_worker
-        self.finish_extend_len = []

         # Parse arguments
         self.topk = server_args.speculative_eagle_topk
@@ -62,12 +88,20 @@ class EAGLEWorker(TpModelWorker):
             server_args.speculative_algorithm
         )
         self.server_args = server_args
+        self.use_nan_detection = self.server_args.enable_nan_detection
+        self.device = self.model_runner.device
+        self.gpu_id = self.model_runner.gpu_id

         # Share the embedding and lm_head
-
-
-
-
+        embed, head = self.target_worker.model_runner.model.get_embed_and_head()
+        if self.hot_token_id is not None:
+            head = head.clone()
+            self.hot_token_id = self.hot_token_id.to(head.device)
+            head.data = head.data[self.hot_token_id]
+        self.draft_model_runner.model.set_embed_and_head(embed, head)
+        self.draft_model_runner.server_args.disable_cuda_graph = (
+            backup_disable_cuda_graph
+        )

         # Create multi-step attn backends and cuda graph runners
         if server_args.attention_backend == "flashinfer":
@@ -95,7 +129,7 @@ class EAGLEWorker(TpModelWorker):
                 f"EAGLE is not supportted in attention backend {server_args.attention_backend}"
             )

-        self.
+        self.draft_model_runner.draft_attn_backend = self.draft_attn_backend
         self.init_cuda_graphs()

     def init_cuda_graphs(self):
@@ -106,55 +140,81 @@ class EAGLEWorker(TpModelWorker):
             return

         tic = time.time()
-        logger.info(
+        logger.info(
+            f"Capture draft cuda graph begin. This can take up to several minutes. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+        )
         self.cuda_graph_runner = EAGLEDraftCudaGraphRunner(self)
-        logger.info(
+        logger.info(
+            f"Capture draft cuda graph end. Time elapsed: {time.time() - tic:.2f} s. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
+        )

-
+    @property
+    def draft_model_runner(self):
+        return self.model_runner
+
+    def forward_batch_speculative_generation(
+        self, batch: ScheduleBatch
+    ) -> Tuple[LogitsProcessorOutput, List[int], int, int]:
+        """Run speculative decoding forward.
+
+        NOTE: Many states of batch is modified as you go through. It is not guaranteed
+        the final output batch doesn't have the same state as the input.
+
+        Args:
+            batch: The batch to run forward. The state of the batch is modified as it runs.
+        Returns:
+            A tuple of the final logit output of the target model, next tokens accepeted,
+            the batch id (used for overlap schedule), and number of accepeted tokens.
+        """
+        assert not batch.spec_algorithm.is_none()
         if batch.forward_mode.is_decode():
-
-
-
-
-            (
-
-
-                verified_id,
-                self.finish_extend_len,
-                accept_length_cpu,
-                model_worker_batch,
-            ) = self.verify(batch, spec_info)
-            batch.spec_info = next_draft_input
-            # if it is None, means all requsets are finished
+            spec_info, to_free_cache_loc = self.draft(batch)
+            logits_output, verify_output, model_worker_batch = self.verify(
+                batch, spec_info
+            )
+            # Free cache loc (we put it here to avoid synchronization and hide kernel launch overhead.)
+            self.token_to_kv_pool_allocator.free(to_free_cache_loc)
+            # if it is None, means all requests are finished
             if batch.spec_info.verified_id is not None:
                 self.forward_draft_extend_after_decode(batch)
+
             return (
                 logits_output,
-                verified_id,
-                model_worker_batch,
-                sum(
+                verify_output.verified_id,
+                model_worker_batch.bid,
+                sum(verify_output.accept_length_per_req_cpu),
             )

         else:
-
-
-
-            model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
-            logits_output, next_token_ids = self.target_worker.forward_batch_generation(
-                model_worker_batch
-            )
-
-            # Forward with the draft model.
-            batch.spec_info = EagleDraftInput(
-                hidden_states=logits_output.hidden_states,
-                verified_id=next_token_ids,
+            logits_output, next_token_ids, bid = self.forward_target_extend(batch)
+            self.forward_draft_extend(
+                batch, logits_output.hidden_states, next_token_ids
             )
-
-
+            return logits_output, next_token_ids, bid, 0
+
+    def forward_target_extend(
+        self, batch: ScheduleBatch
+    ) -> Tuple[LogitsProcessorOutput, List[int], int]:
+        """Run the target extend.
+
+        Args:
+            batch: The batch to run. States could be modified.
+
+        Returns:
+            logits_output: The output of logits. It will contain the full hidden states.
+            next_token_ids: Next token ids generated.
+            bid: The model batch ID. Used for overlap schedule.
+        """
+        # Forward with the target model and get hidden states.
+        # We need the full hidden states to prefill the KV cache of the draft model.
+        model_worker_batch = batch.get_model_worker_batch()
+        model_worker_batch.capture_hidden_mode = CaptureHiddenMode.FULL
+        logits_output, next_token_ids = self.target_worker.forward_batch_generation(
+            model_worker_batch
+        )
+        return logits_output, next_token_ids, model_worker_batch.bid

     def draft(self, batch: ScheduleBatch):
-        self._set_mem_pool(batch, self.model_runner)
-
         # Parse args
         num_seqs = batch.batch_size()
         spec_info = batch.spec_info
@@ -172,7 +232,6 @@ class EAGLEWorker(TpModelWorker):
             self.topk,
             self.speculative_num_steps,
         )
-
         batch.out_cache_loc = out_cache_loc
         batch.seq_lens_sum = torch.sum(batch.seq_lens).item()
         spec_info.positions = batch.seq_lens.repeat_interleave(self.topk, dim=0)
@@ -180,11 +239,12 @@
         # Get forward batch
         spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
         model_worker_batch = batch.get_model_worker_batch()
-        forward_batch = ForwardBatch.init_new(
+        forward_batch = ForwardBatch.init_new(
+            model_worker_batch, self.draft_model_runner
+        )
         can_cuda_graph = self.cuda_graph_runner and self.cuda_graph_runner.can_run(
             forward_batch
         )
-
         if can_cuda_graph:
             score_list, token_list, parents_list = self.cuda_graph_runner.replay(
                 forward_batch
@@ -192,7 +252,9 @@ class EAGLEWorker(TpModelWorker):
         else:
             # Initialize attention backend
             self.draft_attn_backend.init_forward_metadata(forward_batch)
-
+            forward_batch = ForwardBatch.init_new(
+                model_worker_batch, self.draft_model_runner
+            )
             # Run forward steps
             score_list, token_list, parents_list = self.draft_forward(forward_batch)

@@ -209,10 +271,7 @@ class EAGLEWorker(TpModelWorker):
             batch.sampling_info.is_all_greedy,
         )

-
-        batch.token_to_kv_pool.free(out_cache_loc)
-        self._set_mem_pool(batch, self.target_worker.model_runner)
-        return ret
+        return ret, out_cache_loc

     def draft_forward(self, forward_batch: ForwardBatch):
         # Parse args
@@ -223,6 +282,8 @@ class EAGLEWorker(TpModelWorker):
             spec_info.topk_index,
             spec_info.hidden_states,
         )
+        if self.hot_token_id is not None:
+            topk_index = self.hot_token_id[topk_index]

         # Return values
         score_list: List[torch.Tensor] = []
@@ -260,8 +321,11 @@ class EAGLEWorker(TpModelWorker):
             logits_output = self.model_runner.model.forward(
                 forward_batch.input_ids, forward_batch.positions, forward_batch
             )
+            self._detect_nan_if_needed(logits_output)
             probs = torch.softmax(logits_output.next_token_logits, dim=-1)
             topk_p, topk_index = fast_topk(probs, self.topk, dim=-1)
+            if self.hot_token_id is not None:
+                topk_index = self.hot_token_id[topk_index]
             hidden_states = logits_output.hidden_states

         return score_list, token_list, parents_list
@@ -274,68 +338,96 @@ class EAGLEWorker(TpModelWorker):
         logits_output, _ = self.target_worker.forward_batch_generation(
             model_worker_batch, skip_sample=True
         )
+        self._detect_nan_if_needed(logits_output)
         spec_info.hidden_states = logits_output.hidden_states
-        res = spec_info.verify(
+        res: EagleVerifyOutput = spec_info.verify(
+            batch, logits_output, self.token_to_kv_pool_allocator
+        )
+
+        # Post process based on verified outputs.
+        # Pick indices that we care (accepeted)
+        logits_output.next_token_logits = logits_output.next_token_logits[
+            res.accepeted_indices_cpu
+        ]
+        logits_output.hidden_states = logits_output.hidden_states[
+            res.accepeted_indices_cpu
+        ]
+        # Prepare the batch for the next draft forwards.
         batch.forward_mode = ForwardMode.DECODE
-
+        batch.spec_info = res.draft_input
+
+        return logits_output, res, model_worker_batch

-    def forward_draft_extend(
-        self
+    def forward_draft_extend(
+        self,
+        batch: ScheduleBatch,
+        hidden_states: torch.Tensor,
+        next_token_ids: List[int],
+    ):
+        """Run draft model extend. This API modifies the states of the batch.
+
+        Args:
+            batch: The batch to run.
+            hidden_states: Hidden states from the target model forward
+            next_token_ids: Next token ids generated from the target forward.
+        """
+        batch.spec_info = EagleDraftInput(
+            hidden_states=hidden_states,
+            verified_id=next_token_ids,
+        )
         batch.spec_info.prepare_for_extend(batch)
         batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
         model_worker_batch = batch.get_model_worker_batch()
-        forward_batch = ForwardBatch.init_new(
-
-
-
-
-
-
-
+        forward_batch = ForwardBatch.init_new(
+            model_worker_batch, self.draft_model_runner
+        )
+        logits_output = self.draft_model_runner.forward(forward_batch)
+        self._detect_nan_if_needed(logits_output)
+        assert isinstance(forward_batch.spec_info, EagleDraftInput)
+        assert forward_batch.spec_info is batch.spec_info
+        self.capture_for_decode(logits_output, forward_batch.spec_info)

     def forward_draft_extend_after_decode(self, batch: ScheduleBatch):
         seq_lens_backup = batch.seq_lens
-        req_pool_indices_backup = batch.req_pool_indices
-
-        self._set_mem_pool(batch, self.model_runner)
         batch.forward_mode = ForwardMode.DRAFT_EXTEND
         batch.spec_info.prepare_extend_after_decode(batch, self.speculative_num_steps)
         batch.spec_info.capture_hidden_mode = CaptureHiddenMode.LAST
+        # We don't need logprob for this extend.
         model_worker_batch = batch.get_model_worker_batch()
-        forward_batch = ForwardBatch.init_new(
-
-
-
+        forward_batch = ForwardBatch.init_new(
+            model_worker_batch, self.draft_model_runner
+        )
+        logits_output = self.draft_model_runner.forward(forward_batch)
+        self._detect_nan_if_needed(logits_output)
+        assert forward_batch.spec_info is batch.spec_info
+        self.capture_for_decode(logits_output, forward_batch.spec_info)

         # Restore backup.
         # This is because `seq_lens` can be modified in `prepare_extend_after_decode`
         batch.forward_mode = ForwardMode.DECODE
         batch.seq_lens = seq_lens_backup
-        batch.req_pool_indices = req_pool_indices_backup

     def capture_for_decode(
-        self, logits_output: LogitsProcessorOutput,
+        self, logits_output: LogitsProcessorOutput, draft_input: EagleDraftInput
    ):
         probs = torch.softmax(logits_output.next_token_logits, dim=-1)
-        ... (old lines 320-339 removed; contents not rendered in the source view)
-        self.model_runner.token_to_kv_pool.free(kv_indices)
-        self.model_runner.req_to_token_pool.free(req.req_pool_idx)
+        draft_input.topk_p, draft_input.topk_index = fast_topk(probs, self.topk, dim=-1)
+        draft_input.hidden_states = logits_output.hidden_states
+
+    def _detect_nan_if_needed(self, logits_output: LogitsProcessorOutput):
+        if self.use_nan_detection:
+            logits = logits_output.next_token_logits
+            if torch.any(torch.isnan(logits)):
+                logger.warning("Detected errors during sampling! NaN in the logits.")
+                raise ValueError("Detected errors during sampling! NaN in the logits.")
+
+
+def load_token_map(token_map_path: str) -> List[int]:
+    if not os.path.exists(token_map_path):
+        cache_dir = snapshot_download(
+            os.path.dirname(token_map_path),
+            ignore_patterns=["*.bin", "*.safetensors"],
+        )
+        token_map_path = os.path.join(cache_dir, os.path.basename(token_map_path))
+    hot_token_id = torch.load(token_map_path)
+    return torch.tensor(hot_token_id, dtype=torch.int32)
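Note on the hot-token path added above: when server_args.speculative_token_map is set, the draft head keeps only the lm_head rows for a small "hot" vocabulary, so the draft model's top-k indices are positions in that hot vocabulary and must be mapped back to real token ids (the self.hot_token_id[topk_index] lines). A minimal standalone PyTorch sketch of that remapping, with made-up sizes and ids:

import torch

# Hypothetical hot vocabulary of 4 token ids out of a 32000-token vocab.
hot_token_id = torch.tensor([11, 42, 7, 99])
full_head = torch.randn(32000, 16)          # full lm_head weight (vocab x hidden)
draft_head = full_head[hot_token_id]        # shrunken head the draft model would use

hidden = torch.randn(1, 16)
probs = torch.softmax(hidden @ draft_head.T, dim=-1)   # scores over the hot vocab only
topk_p, topk_index = probs.topk(2, dim=-1)             # indices into the hot vocab
real_token_ids = hot_token_id[topk_index]              # same remap as hot_token_id[topk_index] above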
sglang/srt/speculative/spec_info.py
CHANGED
@@ -5,30 +5,18 @@ class SpeculativeAlgorithm(IntEnum):
     NONE = auto()
     EAGLE = auto()

-    # NEXTN spec decoding is for DeepSeek V3/R1
-    # currently it's implemented based on EAGLE
-    NEXTN = auto()
-
     def is_none(self):
         return self == SpeculativeAlgorithm.NONE

     def is_eagle(self):
-        return self == SpeculativeAlgorithm.EAGLE
-
-    def is_nextn(self):
-        return self == SpeculativeAlgorithm.NEXTN
+        return self == SpeculativeAlgorithm.EAGLE

     @staticmethod
     def from_string(name: str):
         name_map = {
             "EAGLE": SpeculativeAlgorithm.EAGLE,
-            "NEXTN": SpeculativeAlgorithm.NEXTN,
             None: SpeculativeAlgorithm.NONE,
         }
         if name is not None:
             name = name.upper()
         return name_map[name]
-
-
-class SpecInfo:
-    pass
sglang/srt/utils.py
CHANGED
@@ -32,13 +32,15 @@ import socket
 import subprocess
 import sys
 import tempfile
+import threading
 import time
 import warnings
 from functools import lru_cache
 from importlib.metadata import PackageNotFoundError, version
 from io import BytesIO
+from multiprocessing import Pool
 from multiprocessing.reduction import ForkingPickler
-from typing import Any, Callable, Dict, List, Optional, Protocol, Tuple, Union
+from typing import Any, Callable, Dict, List, Optional, Protocol, Set, Tuple, Union

 import numpy as np
 import psutil
@@ -311,7 +313,7 @@ def make_layers(
     """Make a list of layers with the given layer function"""
     modules = torch.nn.ModuleList(
         [
-            maybe_offload_to_cpu(layer_fn(idx=idx, prefix=
+            maybe_offload_to_cpu(layer_fn(idx=idx, prefix=add_prefix(idx, prefix)))
             for idx in range(num_hidden_layers)
         ]
     )
@@ -480,6 +482,10 @@ def assert_pkg_version(pkg: str, min_version: str, message: str):

 def kill_process_tree(parent_pid, include_parent: bool = True, skip_pid: int = None):
     """Kill the process and all its child processes."""
+    # Remove sigchld handler to avoid spammy logs.
+    if threading.current_thread() is threading.main_thread():
+        signal.signal(signal.SIGCHLD, signal.SIG_DFL)
+
     if parent_pid is None:
         parent_pid = os.getpid()
         include_parent = False
@@ -735,13 +741,6 @@ def pytorch_profile(name, func, *args, data_size=-1):
     return result


-def first_rank_print(*args, **kwargs):
-    if torch.cuda.current_device() == 0:
-        print(*args, **kwargs)
-    else:
-        pass
-
-
 def get_zmq_socket(
     context: zmq.Context, socket_type: zmq.SocketType, endpoint: str, bind: bool
 ):
@@ -1154,9 +1153,9 @@ def set_gpu_proc_affinity(

     if psutil.cpu_count() != psutil.cpu_count(logical=False):
         # HT on
-
-
-        bind_cpu_ids = list(itertools.chain(
+        lower_cpu_ids = [id for id in range(start_cpu_id, end_cpu_id)]
+        upper_cpu_ids = [id + total_pcores for id in range(start_cpu_id, end_cpu_id)]
+        bind_cpu_ids = list(itertools.chain(lower_cpu_ids, upper_cpu_ids))
     else:
         # HT off
         bind_cpu_ids = [id for id in range(start_cpu_id, end_cpu_id)]
@@ -1171,6 +1170,11 @@ def get_bool_env_var(name: str, default: str = "false") -> bool:
     return value.lower() in ("true", "1")


+@lru_cache(maxsize=2)
+def disable_request_logging() -> bool:
+    return get_bool_env_var("SGLANG_DISABLE_REQUEST_LOGGING")
+
+
 @lru_cache(maxsize=8)
 def _cuda_device_count_stateless(cuda_visible_devices: Optional[str] = None) -> int:
     # Note: cuda_visible_devices is not used, but we keep it as an argument for
@@ -1212,7 +1216,11 @@ def cuda_device_count_stateless() -> int:
     return _cuda_device_count_stateless(os.environ.get("CUDA_VISIBLE_DEVICES", None))


-def dataclass_to_string_truncated(
+def dataclass_to_string_truncated(
+    data, max_length=2048, skip_names: Optional[Set[str]] = None
+):
+    if skip_names is None:
+        skip_names = set()
     if isinstance(data, str):
         if len(data) > max_length:
             half_length = max_length // 2
@@ -1231,6 +1239,7 @@ def dataclass_to_string_truncated(data, max_length=2048):
             + ", ".join(
                 f"'{k}': {dataclass_to_string_truncated(v, max_length)}"
                 for k, v in data.items()
+                if k not in skip_names
             )
             + "}"
         )
@@ -1241,6 +1250,7 @@ def dataclass_to_string_truncated(data, max_length=2048):
             + ", ".join(
                 f"{f.name}={dataclass_to_string_truncated(getattr(data, f.name), max_length)}"
                 for f in fields
+                if f.name not in skip_names
             )
             + ")"
         )
@@ -1289,7 +1299,7 @@ def debug_timing(func):
             tic.record()
             result = func(*args, **kwargs)
             toc.record()
-
+            toc.synchronize()  # Wait for the function to complete without synchronizing all ops on the GPU
             elapsed = tic.elapsed_time(toc)
             indices = kwargs.get("indices", args[1] if len(args) > 1 else None)
             num_tokens = len(indices) if indices is not None else 0
@@ -1319,9 +1329,9 @@ def pyspy_dump_schedulers():
             result = subprocess.run(
                 cmd, shell=True, capture_output=True, text=True, check=True
             )
-            logger.
+            logger.error(f"Pyspy dump for PID {pid}:\n{result.stdout}")
         except subprocess.CalledProcessError as e:
-            logger.
+            logger.error(f"Pyspy failed to dump PID {pid}. Error: {e.stderr}")


 def kill_itself_when_parent_died():
@@ -1383,7 +1393,6 @@ def get_ip() -> str:


 def get_open_port() -> int:
-
     port = os.getenv("SGLANG_PORT")
     if port is not None:
         while True:
@@ -1446,8 +1455,25 @@ def launch_dummy_health_check_server(host, port):
     )


+def create_checksum(directory: str):
+    raise NotImplementedError()
+
+
 def set_cuda_arch():
     if is_flashinfer_available():
         capability = torch.cuda.get_device_capability()
         arch = f"{capability[0]}.{capability[1]}"
         os.environ["TORCH_CUDA_ARCH_LIST"] = f"{arch}{'+PTX' if arch == '9.0' else ''}"
+
+
+def add_prefix(name: str, prefix: str) -> str:
+    """Add a weight path prefix to a module name.
+
+    Args:
+        name: base module name.
+        prefix: weight prefix str to added to the front of `name` concatenated with `.`.
+
+    Returns:
+        The string `prefix.name` if prefix is non-empty, otherwise just `name`.
+    """
+    return name if not prefix else f"{prefix}.{name}"
sglang/srt/warmup.py
ADDED
@@ -0,0 +1,47 @@
+import logging
+from typing import List
+
+import numpy as np
+import tqdm
+
+from sglang.srt.managers.io_struct import GenerateReqInput
+from sglang.srt.managers.tokenizer_manager import TokenizerManager
+
+logger = logging.getLogger(__file__)
+
+_warmup_registry = {}
+
+
+def warmup(name: str) -> callable:
+    def decorator(fn: callable):
+        _warmup_registry[name] = fn
+        return fn
+
+    return decorator
+
+
+async def execute_warmups(warmup_names: List[str], tokenizer_manager: TokenizerManager):
+    for warmup_name in warmup_names:
+        if warmup_name not in _warmup_registry:
+            logger.warning(f"Could not find custom warmup {warmup_name}")
+            continue
+        logger.info(f"Running warmup {warmup_name}")
+        await _warmup_registry[warmup_name](tokenizer_manager)
+
+
+@warmup("voice_chat")
+async def voice_chat(tokenizer_manager: TokenizerManager):
+    # this warms up the fused_moe triton kernels and caches them
+    # if we don't do this we break real time inference for voice chat
+    for i in tqdm.trange(1, 512):
+        size = i * 4
+        generate_req_input = GenerateReqInput(
+            input_ids=(np.random.randint(2**16, size=[size])).tolist(),
+            sampling_params={
+                "max_new_tokens": 30,
+                "temperature": 0.8,
+                "stop_token_ids": [1],
+                "min_p": 0.0,
+            },
+        )
+        await tokenizer_manager.generate_request(generate_req_input, None).__anext__()
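The new warmup.py is a small registry: functions decorated with @warmup(name) are stored by name, and execute_warmups awaits each requested one with the server's TokenizerManager. A sketch of registering a custom warmup (the warmup name and request contents are illustrative, and how names are passed to the server at startup is not shown in this diff):

from sglang.srt.managers.io_struct import GenerateReqInput
from sglang.srt.warmup import execute_warmups, warmup


@warmup("tiny_prompt")  # hypothetical warmup name
async def tiny_prompt(tokenizer_manager):
    # Issue one tiny generation request to prime kernels and caches.
    req = GenerateReqInput(
        input_ids=[1, 2, 3],                    # illustrative token ids
        sampling_params={"max_new_tokens": 1},
    )
    await tokenizer_manager.generate_request(req, None).__anext__()

# Later, during async server startup, something like:
#     await execute_warmups(["tiny_prompt"], tokenizer_manager)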
sglang/test/few_shot_gsm8k.py
CHANGED
@@ -93,9 +93,11 @@ def run_eval(args):
     tic = time.time()
     states = few_shot_gsm8k.run_batch(
         arguments,
-        temperature=0,
+        temperature=args.temperature if hasattr(args, "temperature") else 0,
         num_threads=args.parallel,
         progress_bar=True,
+        return_logprob=getattr(args, "return_logprob", None),
+        logprob_start_len=getattr(args, "logprob_start_len", None),
     )
     latency = time.time() - tic

@@ -141,5 +143,6 @@ if __name__ == "__main__":
     parser.add_argument("--parallel", type=int, default=128)
     parser.add_argument("--host", type=str, default="http://127.0.0.1")
     parser.add_argument("--port", type=int, default=30000)
+    parser.add_argument("--temperature", type=float, default=0.0)
     args = parser.parse_args()
     run_eval(args)
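The few_shot_gsm8k change threads a new --temperature flag (default 0.0) through to run_batch, along with optional logprob settings when present on args. Assuming the script is invoked as a module against a local server, an illustrative run would be: python3 -m sglang.test.few_shot_gsm8k --parallel 64 --port 30000 --temperature 0.2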