sglang 0.4.3.post1__py3-none-any.whl → 0.4.3.post3__py3-none-any.whl
This diff shows the contents of publicly released package versions as published to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in that registry.
- sglang/api.py +1 -1
- sglang/bench_offline_throughput.py +19 -0
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +123 -79
- sglang/global_config.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/ir.py +1 -1
- sglang/srt/_custom_ops.py +83 -91
- sglang/srt/configs/load_config.py +4 -1
- sglang/srt/configs/model_config.py +48 -2
- sglang/srt/configs/qwen2_5_vl_config.py +5 -2
- sglang/srt/constrained/base_grammar_backend.py +117 -15
- sglang/srt/constrained/llguidance_backend.py +151 -0
- sglang/srt/constrained/outlines_backend.py +24 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -38
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
- sglang/srt/distributed/parallel_state.py +48 -3
- sglang/srt/entrypoints/engine.py +67 -9
- sglang/srt/entrypoints/http_server.py +190 -41
- sglang/srt/entrypoints/verl_engine.py +147 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/activation.py +11 -0
- sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
- sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_backend.py +208 -295
- sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
- sglang/srt/layers/attention/torch_native_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +9 -6
- sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
- sglang/srt/layers/attention/utils.py +39 -0
- sglang/srt/layers/attention/vision.py +60 -63
- sglang/srt/layers/dp_attention.py +142 -1
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +3 -1
- sglang/srt/layers/logits_processor.py +281 -45
- sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
- sglang/srt/layers/moe/ep_moe/layer.py +140 -28
- sglang/srt/layers/moe/fused_moe_native.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
- sglang/srt/layers/moe/topk.py +13 -4
- sglang/srt/layers/quantization/__init__.py +111 -7
- sglang/srt/layers/quantization/blockwise_int8.py +409 -0
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_L20Y,dtype=fp8_w8a8,block_shape=[128, 128].json +26 -0
- sglang/srt/layers/quantization/fp8.py +69 -28
- sglang/srt/layers/quantization/fp8_utils.py +17 -1
- sglang/srt/layers/quantization/gptq.py +416 -0
- sglang/srt/layers/quantization/int8_kernel.py +327 -0
- sglang/srt/layers/quantization/int8_utils.py +73 -0
- sglang/srt/layers/quantization/modelopt_quant.py +18 -1
- sglang/srt/layers/radix_attention.py +1 -0
- sglang/srt/layers/rotary_embedding.py +0 -1
- sglang/srt/layers/sampler.py +76 -31
- sglang/srt/layers/vocab_parallel_embedding.py +14 -13
- sglang/srt/lora/lora.py +17 -1
- sglang/srt/lora/lora_config.py +5 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/cache_controller.py +193 -62
- sglang/srt/managers/configure_logging.py +2 -1
- sglang/srt/managers/data_parallel_controller.py +6 -2
- sglang/srt/managers/detokenizer_manager.py +124 -102
- sglang/srt/managers/image_processor.py +2 -1
- sglang/srt/managers/io_struct.py +143 -6
- sglang/srt/managers/schedule_batch.py +238 -197
- sglang/srt/managers/schedule_policy.py +29 -29
- sglang/srt/managers/scheduler.py +681 -259
- sglang/srt/managers/session_controller.py +6 -2
- sglang/srt/managers/tokenizer_manager.py +224 -68
- sglang/srt/managers/tp_worker.py +15 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/mem_cache/chunk_cache.py +18 -11
- sglang/srt/mem_cache/hiradix_cache.py +394 -0
- sglang/srt/mem_cache/memory_pool.py +44 -18
- sglang/srt/mem_cache/radix_cache.py +58 -47
- sglang/srt/metrics/collector.py +94 -36
- sglang/srt/model_executor/cuda_graph_runner.py +55 -24
- sglang/srt/model_executor/forward_batch_info.py +49 -16
- sglang/srt/model_executor/model_runner.py +209 -28
- sglang/srt/model_loader/loader.py +3 -3
- sglang/srt/model_loader/weight_utils.py +36 -14
- sglang/srt/models/baichuan.py +31 -6
- sglang/srt/models/chatglm.py +39 -7
- sglang/srt/models/commandr.py +29 -5
- sglang/srt/models/dbrx.py +31 -5
- sglang/srt/models/deepseek.py +43 -6
- sglang/srt/models/deepseek_nextn.py +32 -19
- sglang/srt/models/deepseek_v2.py +265 -29
- sglang/srt/models/exaone.py +19 -9
- sglang/srt/models/gemma.py +22 -8
- sglang/srt/models/gemma2.py +25 -12
- sglang/srt/models/gemma2_reward.py +5 -1
- sglang/srt/models/gpt2.py +28 -13
- sglang/srt/models/gpt_bigcode.py +27 -5
- sglang/srt/models/granite.py +21 -9
- sglang/srt/models/grok.py +21 -4
- sglang/srt/models/internlm2.py +36 -6
- sglang/srt/models/internlm2_reward.py +5 -1
- sglang/srt/models/llama.py +26 -9
- sglang/srt/models/llama_classification.py +5 -1
- sglang/srt/models/llama_eagle.py +17 -4
- sglang/srt/models/llama_embedding.py +5 -1
- sglang/srt/models/llama_reward.py +7 -2
- sglang/srt/models/llava.py +19 -3
- sglang/srt/models/llavavid.py +10 -1
- sglang/srt/models/minicpm.py +26 -2
- sglang/srt/models/minicpm3.py +39 -3
- sglang/srt/models/minicpmv.py +45 -14
- sglang/srt/models/mixtral.py +20 -9
- sglang/srt/models/mixtral_quant.py +50 -8
- sglang/srt/models/mllama.py +57 -11
- sglang/srt/models/olmo.py +34 -6
- sglang/srt/models/olmo2.py +34 -13
- sglang/srt/models/olmoe.py +26 -4
- sglang/srt/models/phi3_small.py +29 -10
- sglang/srt/models/qwen.py +26 -3
- sglang/srt/models/qwen2.py +26 -4
- sglang/srt/models/qwen2_5_vl.py +46 -8
- sglang/srt/models/qwen2_eagle.py +17 -5
- sglang/srt/models/qwen2_moe.py +44 -6
- sglang/srt/models/qwen2_rm.py +78 -0
- sglang/srt/models/qwen2_vl.py +39 -8
- sglang/srt/models/stablelm.py +32 -5
- sglang/srt/models/torch_native_llama.py +5 -2
- sglang/srt/models/xverse.py +21 -9
- sglang/srt/models/xverse_moe.py +45 -7
- sglang/srt/models/yivl.py +2 -1
- sglang/srt/openai_api/adapter.py +109 -24
- sglang/srt/openai_api/protocol.py +17 -1
- sglang/srt/reasoning_parser.py +154 -0
- sglang/srt/sampling/penaltylib/__init__.py +4 -6
- sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
- sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
- sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
- sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
- sglang/srt/sampling/sampling_batch_info.py +79 -157
- sglang/srt/sampling/sampling_params.py +16 -13
- sglang/srt/server_args.py +136 -52
- sglang/srt/speculative/build_eagle_tree.py +2 -8
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
- sglang/srt/speculative/eagle_utils.py +92 -58
- sglang/srt/speculative/eagle_worker.py +186 -94
- sglang/srt/speculative/spec_info.py +1 -13
- sglang/srt/utils.py +43 -17
- sglang/srt/warmup.py +47 -0
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/runners.py +389 -126
- sglang/test/send_one.py +88 -0
- sglang/test/test_block_fp8_ep.py +361 -0
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +138 -84
- sglang/utils.py +50 -60
- sglang/version.py +1 -1
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +214 -166
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
- sglang/bench_latency.py +0 -1
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
- sglang/test/srt/sampling/penaltylib/utils.py +0 -344
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post1.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
```diff
--- a/sglang/srt/model_executor/forward_batch_info.py
+++ b/sglang/srt/model_executor/forward_batch_info.py
@@ -31,7 +31,7 @@ from __future__ import annotations
 
 from dataclasses import dataclass
 from enum import IntEnum, auto
-from typing import TYPE_CHECKING, List, Optional
+from typing import TYPE_CHECKING, List, Optional, Union
 
 import torch
 import triton
@@ -41,12 +41,13 @@ from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
 from sglang.srt.utils import get_compiler_backend
 
 if TYPE_CHECKING:
-    from sglang.srt.layers.attention import AttentionBackend
+    from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
     from sglang.srt.managers.schedule_batch import ImageInputs, ModelWorkerBatch
     from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
     from sglang.srt.model_executor.model_runner import ModelRunner
     from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
-    from sglang.srt.speculative.
+    from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+    from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 
 
 class ForwardMode(IntEnum):
@@ -112,7 +113,9 @@ class ForwardMode(IntEnum):
 
 class CaptureHiddenMode(IntEnum):
     NULL = auto()
+    # Capture hidden states of all tokens.
     FULL = auto()
+    # Capture a hidden state of the last token.
     LAST = auto()
 
     def need_capture(self):
@@ -148,10 +151,14 @@ class ForwardBatch:
     # For logprob
     return_logprob: bool = False
     top_logprobs_nums: Optional[List[int]] = None
+    token_ids_logprobs: Optional[List[List[int]]] = None
 
     # Position information
     positions: torch.Tensor = None
 
+    # For decode
+    decode_seq_lens_cpu: Optional[torch.Tensor] = None
+
     # For extend
     extend_num_tokens: Optional[int] = None
     extend_seq_lens: Optional[torch.Tensor] = None
@@ -160,6 +167,7 @@ class ForwardBatch:
     extend_prefix_lens_cpu: Optional[List[int]] = None
     extend_seq_lens_cpu: Optional[List[int]] = None
     extend_logprob_start_lens_cpu: Optional[List[int]] = None
+    extend_input_logprob_token_ids_gpu: Optional[torch.Tensor] = None
 
     # For multimodal
     image_inputs: Optional[List[ImageInputs]] = None
@@ -185,15 +193,27 @@ class ForwardBatch:
     attn_backend: AttentionBackend = None
 
     # For DP attention
-
+    global_num_tokens_cpu: Optional[List[int]] = None
+    global_num_tokens_gpu: Optional[torch.Tensor] = None
+    # Has to be None when cuda graph is captured.
+    global_num_tokens_for_logprob_cpu: Optional[List[int]] = None
+    global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None
+    # for extend, local start pos and num tokens is different in logits processor
+    # this will be computed in get_dp_local_info
+    # this will be recomputed in LogitsMetadata.from_forward_batch
+    dp_local_start_pos: Optional[torch.Tensor] = None  # cached info at runtime
+    dp_local_num_tokens: Optional[torch.Tensor] = None  # cached info at runtime
     gathered_buffer: Optional[torch.Tensor] = None
     can_run_dp_cuda_graph: bool = False
 
     # Speculative decoding
-    spec_info:
+    spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
     spec_algorithm: SpeculativeAlgorithm = None
     capture_hidden_mode: CaptureHiddenMode = None
 
+    # For padding
+    padded_static_len: int = -1  # -1 if not padded
+
     # For Qwen2-VL
     mrope_positions: torch.Tensor = None
 
@@ -203,8 +223,13 @@ class ForwardBatch:
         batch: ModelWorkerBatch,
         model_runner: ModelRunner,
     ):
-
         device = model_runner.device
+        extend_input_logprob_token_ids_gpu = None
+        if batch.extend_input_logprob_token_ids is not None:
+            extend_input_logprob_token_ids_gpu = (
+                batch.extend_input_logprob_token_ids.to(device, non_blocking=True)
+            )
+
         ret = cls(
             forward_mode=batch.forward_mode,
             batch_size=len(batch.seq_lens),
@@ -220,7 +245,7 @@ class ForwardBatch:
             seq_lens_sum=batch.seq_lens_sum,
             return_logprob=batch.return_logprob,
             top_logprobs_nums=batch.top_logprobs_nums,
-
+            token_ids_logprobs=batch.token_ids_logprobs,
             can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
             lora_paths=batch.lora_paths,
             sampling_info=batch.sampling_info,
@@ -231,10 +256,12 @@ class ForwardBatch:
             spec_info=batch.spec_info,
             capture_hidden_mode=batch.capture_hidden_mode,
             input_embeds=batch.input_embeds,
+            extend_input_logprob_token_ids_gpu=extend_input_logprob_token_ids_gpu,
         )
 
-        if
-
+        if batch.global_num_tokens is not None:
+            ret.global_num_tokens_cpu = batch.global_num_tokens
+            max_len = max(ret.global_num_tokens_cpu)
             ret.gathered_buffer = torch.zeros(
                 (max_len * model_runner.tp_size, model_runner.model_config.hidden_size),
                 dtype=model_runner.dtype,
@@ -256,6 +283,8 @@ class ForwardBatch:
         if ret.forward_mode.is_decode():
             if ret.positions is None:
                 ret.positions = clamp_position(batch.seq_lens)
+            if ret.decode_seq_lens_cpu is None:
+                ret.decode_seq_lens_cpu = batch.decode_seq_lens
         else:
             ret.extend_seq_lens = torch.tensor(
                 batch.extend_seq_lens, dtype=torch.int32
@@ -263,13 +292,12 @@ class ForwardBatch:
             ret.extend_prefix_lens = torch.tensor(
                 batch.extend_prefix_lens, dtype=torch.int32
             ).to(device, non_blocking=True)
-            if
-                model_runner.server_args.attention_backend != "torch_native"
-                and model_runner.server_args.speculative_algorithm != "NEXTN"
-            ):
+            if model_runner.server_args.attention_backend != "torch_native":
                 ret.extend_num_tokens = batch.extend_num_tokens
                 positions, ret.extend_start_loc = compute_position_triton(
-                    ret.extend_prefix_lens,
+                    ret.extend_prefix_lens,
+                    ret.extend_seq_lens,
+                    ret.extend_num_tokens,
                 )
             else:
                 positions, ret.extend_start_loc = compute_position_torch(
@@ -341,6 +369,7 @@ class ForwardBatch:
                 )
                 batch.image_inputs[i].mrope_position_delta = mrope_position_delta
                 mrope_positions_list[i] = mrope_positions
+
         self.mrope_positions = torch.concat(
             [torch.tensor(pos, device=device) for pos in mrope_positions_list],
             axis=1,
@@ -353,6 +382,8 @@ def compute_position_triton(
 ):
     """Compute positions. It is a fused version of `compute_position_torch`."""
     batch_size = extend_seq_lens.shape[0]
+    has_prefix = extend_prefix_lens.shape[0] == batch_size
+
     positions = torch.empty(
         extend_seq_lens_sum, dtype=torch.int64, device=extend_seq_lens.device
     )
@@ -366,6 +397,7 @@ def compute_position_triton(
         extend_start_loc,
         extend_prefix_lens,
         extend_seq_lens,
+        has_prefix,
     )
 
     return positions, extend_start_loc
@@ -377,11 +409,12 @@ def compute_position_kernel(
     extend_start_loc,
     extend_prefix_lens,
     extend_seq_lens,
+    has_prefix: tl.constexpr,
 ):
     BLOCK_SIZE: tl.constexpr = 512
-    pid = tl.program_id(0)
+    pid = tl.program_id(0).to(tl.int64)
 
-    prefix_len = tl.load(extend_prefix_lens + pid)
+    prefix_len = tl.load(extend_prefix_lens + pid) if has_prefix else 0
     seq_len = tl.load(extend_seq_lens + pid)
 
     # TODO: optimize this?
```
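The `has_prefix` flag added above lets the fused kernel treat prefix lengths as zero whenever `extend_prefix_lens` does not carry one entry per request. As a reading aid, here is a minimal plain-PyTorch sketch of the computation the kernel performs; the function name and sample tensors are illustrative, not sglang's own `compute_position_torch`.

```python
import torch

def compute_positions_reference(extend_prefix_lens, extend_seq_lens):
    """Plain-PyTorch sketch of what the fused Triton kernel computes.

    For request i, the extend tokens get positions
    prefix_len[i], prefix_len[i] + 1, ..., prefix_len[i] + seq_len[i] - 1,
    and extend_start_loc[i] is the offset of request i in the flattened batch.
    When extend_prefix_lens is empty (the `has_prefix` fast path above),
    prefix lengths are treated as zero.
    """
    batch_size = extend_seq_lens.shape[0]
    has_prefix = extend_prefix_lens.shape[0] == batch_size
    prefix = extend_prefix_lens if has_prefix else torch.zeros_like(extend_seq_lens)

    positions = torch.cat(
        [
            torch.arange(p, p + s, dtype=torch.int64)
            for p, s in zip(prefix.tolist(), extend_seq_lens.tolist())
        ]
    )
    # Start offset of each request's tokens in the flattened `positions`.
    extend_start_loc = torch.cumsum(extend_seq_lens, dim=0) - extend_seq_lens
    return positions, extend_start_loc

# Example: two requests, the first with a 3-token cached prefix.
pos, loc = compute_positions_reference(torch.tensor([3, 0]), torch.tensor([2, 4]))
# pos -> tensor([3, 4, 0, 1, 2, 3]); loc -> tensor([0, 2])
```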
```diff
--- a/sglang/srt/model_executor/model_runner.py
+++ b/sglang/srt/model_executor/model_runner.py
@@ -13,11 +13,14 @@
 # ==============================================================================
 """ModelRunner runs the forward passes of the models."""
 
+import datetime
 import gc
 import json
 import logging
+import os
 import time
-from
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union
 
 import torch
 import torch.distributed as dist
@@ -34,6 +37,7 @@ from sglang.srt.distributed import (
 from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
 from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
 from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
+from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
 from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
 from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
 from sglang.srt.layers.dp_attention import (
@@ -51,14 +55,18 @@ from sglang.srt.mem_cache.memory_pool import (
     MHATokenToKVPool,
     MLATokenToKVPool,
     ReqToTokenPool,
+    TokenToKVPoolAllocator,
 )
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader import get_model
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
+    MultiprocessingSerializer,
     enable_show_time_cost,
     get_available_gpu_memory,
     init_custom_process_group,
@@ -69,10 +77,15 @@ from sglang.srt.utils import (
     set_cpu_offload_max_bytes,
     set_cuda_arch,
 )
+from sglang.utils import get_exception_traceback
 
 logger = logging.getLogger(__name__)
 
 
+SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
+UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
+
+
 class ModelRunner:
     """ModelRunner runs the forward passes of the models."""
 
@@ -86,6 +99,8 @@ class ModelRunner:
         nccl_port: int,
         server_args: ServerArgs,
         is_draft_worker: bool = False,
+        req_to_token_pool: Optional[ReqToTokenPool] = None,
+        token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
     ):
         # Parse args
         self.model_config = model_config
@@ -103,6 +118,8 @@ class ModelRunner:
         self.spec_algorithm = SpeculativeAlgorithm.from_string(
             server_args.speculative_algorithm
         )
+        self.req_to_token_pool = req_to_token_pool
+        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
 
         # Model-specific adjustment
         if (
@@ -113,9 +130,9 @@ class ModelRunner:
             if self.server_args.device != "cpu":
                 if server_args.enable_flashinfer_mla:
                     logger.info(
-                        "
+                        "MLA optimization is turned on. Use flashinfer mla backend."
                     )
-                    self.server_args.attention_backend = "
+                    self.server_args.attention_backend = "flashinfer_mla"
                 else:
                     logger.info("MLA optimization is turned on. Use triton backend.")
                     self.server_args.attention_backend = "triton"
@@ -176,7 +193,13 @@ class ModelRunner:
                 "enable_dp_attention": server_args.enable_dp_attention,
                 "enable_ep_moe": server_args.enable_ep_moe,
                 "device": server_args.device,
+                "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
+                "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
                 "enable_flashinfer_mla": server_args.enable_flashinfer_mla,
+                "disable_radix_cache": server_args.disable_radix_cache,
+                "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
+                "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
+                "debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
             }
         )
 
@@ -193,6 +216,18 @@ class ModelRunner:
         self.sampler = Sampler()
         self.load_model()
 
+        # Handle the case where some of models don't finish loading.
+        try:
+            dist.monitored_barrier(
+                group=get_tp_group().cpu_group,
+                timeout=datetime.timedelta(seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S),
+                wait_all_ranks=True,
+            )
+        except RuntimeError:
+            raise ValueError(
+                f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
+            ) from None
+
         # Apply torchao quantization
         torchao_applied = getattr(self.model, "torchao_applied", False)
         # In layered loading, torchao may have been applied
@@ -227,19 +262,18 @@ class ModelRunner:
 
     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")
-
         torch.get_device_module(self.device).set_device(self.gpu_id)
+
         if self.device == "cuda":
             backend = "nccl"
         elif self.device == "xpu":
-
-            # Need to use xccl for xpu backend in the future
-            backend = "gloo"
+            backend = "xccl"
         elif self.device == "hpu":
             backend = "hccl"
         elif self.device == "cpu":
             backend = "gloo"
 
+        before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
         if not self.server_args.enable_p2p_check:
             monkey_patch_p2p_access_check()
 
@@ -257,6 +291,7 @@ class ModelRunner:
             rank=self.tp_rank,
             local_rank=self.gpu_id,
             distributed_init_method=dist_init_method,
+            timeout=self.server_args.dist_timeout,
         )
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
         initialize_dp_attention(
@@ -269,20 +304,24 @@ class ModelRunner:
         min_per_gpu_memory = get_available_gpu_memory(
             self.device, self.gpu_id, distributed=self.tp_size > 1
         )
+        local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
         self.tp_group = get_tp_group()
         self.attention_tp_group = get_attention_tp_group()
 
         # Check memory for tensor parallelism
         if self.tp_size > 1:
-            local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
            if min_per_gpu_memory < local_gpu_memory * 0.9:
                 raise ValueError(
                     "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
                 )
 
+        logger.info(
+            f"Init torch distributed ends. mem usage={(before_avail_memory - local_gpu_memory):.2f} GB"
+        )
         return min_per_gpu_memory
 
     def load_model(self):
+        before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
             f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
         )
@@ -352,11 +391,13 @@ class ModelRunner:
         )
         self.dtype = self.model_config.dtype
 
+        after_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
             f"Load weight end. "
             f"type={type(self.model).__name__}, "
             f"dtype={self.dtype}, "
-            f"avail mem={
+            f"avail mem={after_avail_memory:.2f} GB, "
+            f"mem usage={(before_avail_memory - after_avail_memory):.2f} GB."
         )
 
     def update_weights_from_disk(
@@ -511,8 +552,21 @@ class ModelRunner:
             logger.error(error_msg)
             return False, error_msg
 
-    def update_weights_from_tensor(
-        self
+    def update_weights_from_tensor(
+        self,
+        named_tensors: List[Tuple[str, Union[torch.Tensor, "LocalSerializedTensor"]]],
+        load_format: Optional[str] = None,
+    ):
+        named_tensors = [
+            (name, _unwrap_tensor(tensor, tp_rank=self.tp_rank))
+            for name, tensor in named_tensors
+        ]
+        if load_format == "direct":
+            _model_load_weights_direct(self.model, named_tensors)
+        elif load_format is None:
+            self.model.load_weights(named_tensors)
+        else:
+            raise NotImplementedError(f"Unknown load_format={load_format}")
         return True, "Success"
 
     def get_weights_by_name(
```
```diff
--- a/sglang/srt/model_executor/model_runner.py
+++ b/sglang/srt/model_executor/model_runner.py
@@ -605,15 +659,31 @@ class ModelRunner:
             4096,
         )
 
+        if SGLANG_CI_SMALL_KV_SIZE:
+            self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)
+
         if not self.spec_algorithm.is_none():
             if self.is_draft_worker:
                 self.max_total_num_tokens = self.server_args.draft_runner_cache_size
+                max_num_reqs = self.server_args.max_num_reqs
             else:
+                # We are sharing the `token_to_kv_pool`, and both verify and draft tokens
+                # can be concurrently allocated, so we should give a headroom for it.
                 self.server_args.draft_runner_cache_size = (
                     self.max_total_num_tokens
-
+                    # draft
+                    + max_num_reqs
+                    * self.server_args.speculative_num_steps
+                    * self.server_args.speculative_eagle_topk
+                    # verify
+                    + max_num_reqs * self.server_args.speculative_num_draft_tokens
+                    # buffer
                     + 100
                 )
+                # Target worker and draft worker shares the same indices for the
+                # token_to_kv_pool, so we should make sure to match max_total_num_tokens.
+                self.max_total_num_tokens = self.server_args.draft_runner_cache_size
+                self.server_args.max_num_reqs = max_num_reqs
 
         if max_total_tokens is not None:
             if max_total_tokens > self.max_total_num_tokens:
@@ -629,12 +699,26 @@ class ModelRunner:
                 "Not enough memory. Please try to increase --mem-fraction-static."
             )
 
-        self.req_to_token_pool
-
-
-
-
-
+        if self.req_to_token_pool is None:
+            self.req_to_token_pool = ReqToTokenPool(
+                size=max_num_reqs + 1,
+                max_context_len=self.model_config.context_len + 4,
+                device=self.device,
+                enable_memory_saver=self.server_args.enable_memory_saver,
+            )
+        else:
+            # Draft worker shares req_to_token_pool with the target worker.
+            assert self.is_draft_worker
+
+        if self.token_to_kv_pool_allocator is None:
+            self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
+                self.max_total_num_tokens,
+                dtype=self.kv_cache_dtype,
+                device=self.device,
+            )
+        else:
+            assert self.is_draft_worker
+
         if (
             self.model_config.attention_arch == AttentionArch.MLA
             and not self.server_args.disable_mla
@@ -702,6 +786,8 @@ class ModelRunner:
             self.attn_backend = TritonAttnBackend(self)
         elif self.server_args.attention_backend == "torch_native":
             self.attn_backend = TorchNativeAttnBackend(self)
+        elif self.server_args.attention_backend == "flashinfer_mla":
+            self.attn_backend = FlashInferMLAAttnBackend(self)
         else:
             raise ValueError(
                 f"Invalid attention backend: {self.server_args.attention_backend}"
@@ -736,9 +822,16 @@ class ModelRunner:
             return
 
         tic = time.time()
-
+        before_mem = get_available_gpu_memory(self.device, self.gpu_id)
+        logger.info(
+            f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
+        )
         self.cuda_graph_runner = CudaGraphRunner(self)
-
+        after_mem = get_available_gpu_memory(self.device, self.gpu_id)
+        logger.info(
+            f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s. "
+            f"avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
+        )
 
     def apply_torch_tp(self):
         logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
@@ -753,8 +846,12 @@ class ModelRunner:
             forward_batch.input_ids, forward_batch.positions, forward_batch
         )
 
-    def forward_extend(
-        self
+    def forward_extend(
+        self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
+    ):
+        if not skip_attn_backend_init:
+            self.attn_backend.init_forward_metadata(forward_batch)
+
         if self.is_generation:
             if forward_batch.input_embeds is None:
                 return self.model.forward(
@@ -798,11 +895,10 @@ class ModelRunner:
         else:
             raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")
 
-    def
-        self, logits_output: LogitsProcessorOutput,
-    )
+    def _preprocess_logits(
+        self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
+    ):
         # Apply logit bias
-        sampling_info = forward_batch.sampling_info
         if sampling_info.sampling_info_done:
             # Overlap mode: the function update_regex_vocab_mask was executed
             # in process_batch_result of the last batch.
@@ -811,15 +907,77 @@ class ModelRunner:
         else:
             # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
             sampling_info.update_regex_vocab_mask()
-            sampling_info.update_penalties()
         sampling_info.apply_logits_bias(logits_output.next_token_logits)
 
+    def update_output_logprobs(
+        self,
+        logits_output: LogitsProcessorOutput,
+        sampling_info: SamplingBatchInfo,
+        top_logprobs_nums: List[int],
+        token_ids_logprobs: List[int],
+        next_token_ids: torch.Tensor,
+        *,
+        num_tokens_per_req: List[int],
+    ):
+        """Update the logits_output's output logprob based on next_token_ids
+
+        Args:
+            logits_output: The logits output from the model forward
+            sampling_info: Sampling info for logprob calculation
+            top_logprobs_nums: Number of logprobs per request.
+            next_token_ids: Next token ids.
+            num_tokens_per_req: The number of tokens per request.
+
+        Returns:
+            A list of next_token_ids
+        """
+        self._preprocess_logits(logits_output, sampling_info)
+        # We should repeat top_logprobs_nums to match num_tokens_per_req.
+        top_logprobs_nums_repeat_interleaved = []
+        token_ids_logprobs_repeat_interleaved = []
+        for num, num_tokens in zip(top_logprobs_nums, num_tokens_per_req):
+            top_logprobs_nums_repeat_interleaved.extend([num] * num_tokens)
+        for token_ids, num_tokens in zip(token_ids_logprobs, num_tokens_per_req):
+            token_ids_logprobs_repeat_interleaved.extend([token_ids] * num_tokens)
+        self.sampler(
+            logits_output,
+            sampling_info,
+            True,
+            top_logprobs_nums_repeat_interleaved,
+            token_ids_logprobs_repeat_interleaved,
+            batch_next_token_ids=next_token_ids,
+        )
+
+    def sample(
+        self,
+        logits_output: LogitsProcessorOutput,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        """Sample and compute logprobs and update logits_output.
+
+        Args:
+            logits_output: The logits output from the model forward
+            forward_batch: The forward batch that generates logits_output
+
+        Returns:
+            A list of next_token_ids
+        """
+        # For duplex models with multiple output streams.
+        if isinstance(logits_output, tuple):
+            return torch.stack(
+                [self.sample(values, forward_batch) for values in logits_output],
+                axis=-1,
+            )
+
+        self._preprocess_logits(logits_output, forward_batch.sampling_info)
+
         # Sample the next tokens
         next_token_ids = self.sampler(
             logits_output,
-            sampling_info,
+            forward_batch.sampling_info,
             forward_batch.return_logprob,
             forward_batch.top_logprobs_nums,
+            forward_batch.token_ids_logprobs,
         )
         return next_token_ids
 
@@ -831,3 +989,26 @@ class ModelRunner:
         if rope_scaling is None:
             return False
         return rope_scaling.get("type", None) == "mrope"
+
+
+def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
+    params_dict = dict(model.named_parameters())
+    for name, tensor in named_tensors:
+        default_weight_loader(params_dict[name], tensor)
+
+
+def _unwrap_tensor(tensor, tp_rank):
+    if isinstance(tensor, LocalSerializedTensor):
+        return tensor.get(tp_rank)
+    return tensor
+
+
+@dataclass
+class LocalSerializedTensor:
+    """torch.Tensor that gets serialized by MultiprocessingSerializer (which only serializes a pointer and not the data).
+    The i-th element in the list corresponds to i-th rank's GPU."""
+
+    values: List[bytes]
+
+    def get(self, rank: int):
+        return MultiprocessingSerializer.deserialize(self.values[rank])
```
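The speculative-decoding headroom computed above is easiest to see with concrete numbers. The values below are purely illustrative; the real ones come from `ServerArgs`.

```python
# Illustrative numbers only; the real values come from ServerArgs.
max_total_num_tokens = 500_000          # target worker KV-cache capacity
max_num_reqs = 32
speculative_num_steps = 5
speculative_eagle_topk = 4
speculative_num_draft_tokens = 8

draft_runner_cache_size = (
    max_total_num_tokens
    # draft: each request can hold num_steps * topk in-flight draft tokens
    + max_num_reqs * speculative_num_steps * speculative_eagle_topk
    # verify: plus the draft tokens being verified
    + max_num_reqs * speculative_num_draft_tokens
    # buffer
    + 100
)
print(draft_runner_cache_size)  # 500_000 + 640 + 256 + 100 = 500_996
```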
```diff
--- a/sglang/srt/model_loader/loader.py
+++ b/sglang/srt/model_loader/loader.py
@@ -11,7 +11,7 @@ import math
 import os
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
-from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple,
+from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast
 
 import gguf
 import huggingface_hub
@@ -19,7 +19,7 @@ import numpy as np
 import torch
 from huggingface_hub import HfApi, hf_hub_download
 from torch import nn
-from transformers import AutoModelForCausalLM
+from transformers import AutoModelForCausalLM
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME
 
 from sglang.srt.configs.device_config import DeviceConfig
@@ -197,7 +197,7 @@ class DefaultModelLoader(BaseModelLoader):
 
         Returns the path to the downloaded model, or None if the model is not
         downloaded from ModelScope."""
-        if "SGLANG_USE_MODELSCOPE"
+        if os.environ.get("SGLANG_USE_MODELSCOPE", None) == "True":
             # download model from ModelScope hub,
             # lazy import so that modelscope is not required for normal use.
             # pylint: disable=C.
```