sglang 0.4.3.post3__py3-none-any.whl → 0.4.4__py3-none-any.whl
This diff shows the changes between publicly released package versions as they appear in their respective public registries. It is provided for informational purposes only.
- sglang/bench_serving.py +2 -2
- sglang/lang/chat_template.py +29 -0
- sglang/srt/_custom_ops.py +19 -17
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/janus_pro.py +629 -0
- sglang/srt/configs/model_config.py +24 -14
- sglang/srt/conversation.py +80 -2
- sglang/srt/custom_op.py +64 -3
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +18 -17
- sglang/srt/distributed/parallel_state.py +10 -1
- sglang/srt/entrypoints/engine.py +5 -3
- sglang/srt/entrypoints/http_server.py +1 -1
- sglang/srt/hf_transformers_utils.py +16 -1
- sglang/srt/layers/attention/flashinfer_backend.py +95 -49
- sglang/srt/layers/attention/flashinfer_mla_backend.py +317 -57
- sglang/srt/layers/attention/triton_backend.py +5 -5
- sglang/srt/layers/attention/triton_ops/decode_attention.py +6 -6
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +3 -3
- sglang/srt/layers/attention/triton_ops/extend_attention.py +4 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +3 -3
- sglang/srt/layers/attention/vision.py +43 -62
- sglang/srt/layers/linear.py +1 -1
- sglang/srt/layers/moe/ep_moe/kernels.py +2 -1
- sglang/srt/layers/moe/ep_moe/layer.py +25 -9
- sglang/srt/layers/moe/fused_moe_triton/configs/E=160,N=192,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=64,device_name=NVIDIA_A800-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +63 -23
- sglang/srt/layers/moe/fused_moe_triton/layer.py +16 -4
- sglang/srt/layers/parameter.py +10 -0
- sglang/srt/layers/quantization/__init__.py +90 -68
- sglang/srt/layers/quantization/blockwise_int8.py +1 -2
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=3072,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=NVIDIA_H20,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +174 -106
- sglang/srt/layers/quantization/fp8_kernel.py +210 -38
- sglang/srt/layers/quantization/fp8_utils.py +156 -15
- sglang/srt/layers/quantization/modelopt_quant.py +5 -1
- sglang/srt/layers/quantization/w8a8_fp8.py +128 -0
- sglang/srt/layers/quantization/w8a8_int8.py +152 -3
- sglang/srt/layers/rotary_embedding.py +5 -3
- sglang/srt/layers/sampler.py +29 -35
- sglang/srt/layers/vocab_parallel_embedding.py +0 -1
- sglang/srt/lora/backend/__init__.py +9 -12
- sglang/srt/managers/cache_controller.py +72 -8
- sglang/srt/managers/image_processor.py +37 -631
- sglang/srt/managers/image_processors/base_image_processor.py +219 -0
- sglang/srt/managers/image_processors/janus_pro.py +79 -0
- sglang/srt/managers/image_processors/llava.py +152 -0
- sglang/srt/managers/image_processors/minicpmv.py +86 -0
- sglang/srt/managers/image_processors/mlama.py +60 -0
- sglang/srt/managers/image_processors/qwen_vl.py +161 -0
- sglang/srt/managers/io_struct.py +33 -15
- sglang/srt/managers/multi_modality_padding.py +134 -0
- sglang/srt/managers/schedule_batch.py +212 -117
- sglang/srt/managers/schedule_policy.py +40 -8
- sglang/srt/managers/scheduler.py +258 -782
- sglang/srt/managers/scheduler_output_processor_mixin.py +611 -0
- sglang/srt/managers/tokenizer_manager.py +7 -6
- sglang/srt/managers/tp_worker_overlap_thread.py +4 -1
- sglang/srt/mem_cache/base_prefix_cache.py +6 -8
- sglang/srt/mem_cache/chunk_cache.py +12 -44
- sglang/srt/mem_cache/hiradix_cache.py +63 -34
- sglang/srt/mem_cache/memory_pool.py +112 -46
- sglang/srt/mem_cache/paged_allocator.py +283 -0
- sglang/srt/mem_cache/radix_cache.py +117 -36
- sglang/srt/metrics/collector.py +8 -0
- sglang/srt/model_executor/cuda_graph_runner.py +10 -11
- sglang/srt/model_executor/forward_batch_info.py +12 -8
- sglang/srt/model_executor/model_runner.py +153 -134
- sglang/srt/model_loader/loader.py +2 -1
- sglang/srt/model_loader/weight_utils.py +1 -1
- sglang/srt/models/deepseek_janus_pro.py +2127 -0
- sglang/srt/models/deepseek_nextn.py +23 -3
- sglang/srt/models/deepseek_v2.py +25 -19
- sglang/srt/models/minicpmv.py +28 -89
- sglang/srt/models/mllama.py +1 -1
- sglang/srt/models/qwen2.py +0 -1
- sglang/srt/models/qwen2_5_vl.py +25 -50
- sglang/srt/models/qwen2_vl.py +33 -49
- sglang/srt/openai_api/adapter.py +37 -15
- sglang/srt/openai_api/protocol.py +8 -1
- sglang/srt/sampling/penaltylib/frequency_penalty.py +0 -1
- sglang/srt/sampling/penaltylib/presence_penalty.py +0 -1
- sglang/srt/server_args.py +19 -20
- sglang/srt/speculative/build_eagle_tree.py +6 -1
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -11
- sglang/srt/speculative/eagle_utils.py +2 -1
- sglang/srt/speculative/eagle_worker.py +109 -38
- sglang/srt/utils.py +104 -9
- sglang/test/runners.py +104 -10
- sglang/test/test_block_fp8.py +106 -16
- sglang/test/test_custom_ops.py +88 -0
- sglang/test/test_utils.py +20 -4
- sglang/utils.py +0 -4
- sglang/version.py +1 -1
- {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/METADATA +9 -9
- {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/RECORD +128 -83
- {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/WHEEL +1 -1
- {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post3.dist-info → sglang-0.4.4.dist-info}/top_level.txt +0 -0
```diff
--- a/sglang/srt/model_executor/forward_batch_info.py
+++ b/sglang/srt/model_executor/forward_batch_info.py
@@ -38,12 +38,12 @@ import triton
 import triton.language as tl
 
 from sglang.srt.layers.rotary_embedding import MRotaryEmbedding
-from sglang.srt.utils import get_compiler_backend
+from sglang.srt.utils import get_compiler_backend, next_power_of_2
 
 if TYPE_CHECKING:
     from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
     from sglang.srt.managers.schedule_batch import ImageInputs, ModelWorkerBatch
-    from sglang.srt.mem_cache.memory_pool import
+    from sglang.srt.mem_cache.memory_pool import KVCache, ReqToTokenPool
     from sglang.srt.model_executor.model_runner import ModelRunner
     from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
     from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
@@ -51,9 +51,8 @@ if TYPE_CHECKING:
 
 
 class ForwardMode(IntEnum):
-    # Prefill a new sequence. This is deprecated now. "EXTEND" covers this case.
-    PREFILL = auto()
     # Extend a sequence. The KV cache of the beginning part of the sequence is already computed (e.g., system prompt).
+    # It is also called "prefill" in common terminology.
     EXTEND = auto()
     # Decode one token.
     DECODE = auto()
```
```diff
@@ -153,6 +152,12 @@ class ForwardBatch:
     top_logprobs_nums: Optional[List[int]] = None
     token_ids_logprobs: Optional[List[List[int]]] = None
 
+    # For logits and logprobs post processing
+    temp_scaled_logprobs: bool = False
+    temperature: torch.Tensor = None
+    top_p_normalized_logprobs: bool = False
+    top_p: torch.Tensor = None
+
     # Position information
     positions: torch.Tensor = None
 
```
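The new `temp_scaled_logprobs` / `top_p_normalized_logprobs` fields carry per-request temperature and top-p tensors so that returned logprobs can be post-processed to match the sampling distribution. The actual post-processing lives elsewhere in this release; the snippet below is only an illustrative sketch of what temperature-scaled logprobs typically mean, not the sglang implementation.

```python
import torch

def temp_scaled_logprobs(logits: torch.Tensor, temperature: torch.Tensor) -> torch.Tensor:
    # Illustrative only: rescale logits by a per-request temperature before
    # log_softmax, so reported logprobs match the temperature-adjusted distribution.
    return torch.log_softmax(logits / temperature, dim=-1)

# Two requests with different temperatures; temperature broadcasts over the vocab dim.
logits = torch.randn(2, 32000)
temperature = torch.tensor([[0.7], [1.0]])
print(temp_scaled_logprobs(logits, temperature).shape)  # torch.Size([2, 32000])
```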
```diff
@@ -189,7 +194,7 @@ class ForwardBatch:
 
     # Attention backend
     req_to_token_pool: ReqToTokenPool = None
-    token_to_kv_pool:
+    token_to_kv_pool: KVCache = None
     attn_backend: AttentionBackend = None
 
     # For DP attention
@@ -229,7 +234,6 @@ class ForwardBatch:
         extend_input_logprob_token_ids_gpu = (
             batch.extend_input_logprob_token_ids.to(device, non_blocking=True)
         )
-
         ret = cls(
             forward_mode=batch.forward_mode,
             batch_size=len(batch.seq_lens),
```
```diff
@@ -417,8 +421,8 @@ def compute_position_kernel(
     prefix_len = tl.load(extend_prefix_lens + pid) if has_prefix else 0
     seq_len = tl.load(extend_seq_lens + pid)
 
-    #
-    cumsum_start = 0
+    # NOTE: This can be slow for large bs
+    cumsum_start = tl.cast(0, tl.int64)
     for i in range(pid):
         cumsum_start += tl.load(extend_seq_lens + i)
 
```
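The `tl.cast(0, tl.int64)` change widens the accumulator for the per-request position offset. The toy check below (plain PyTorch rather than Triton) shows why: with large batches the running sum of `extend_seq_lens` can exceed the int32 range.

```python
import torch

# ~70k requests of ~32k tokens each: the running total (2_293_760_000)
# is larger than 2**31 - 1, so a 32-bit accumulator wraps around.
seq_lens = torch.full((70_000,), 32_768, dtype=torch.int32)

wrapped = seq_lens.cumsum(0, dtype=torch.int32)[-1]  # silently overflows
safe = seq_lens.cumsum(0, dtype=torch.int64)[-1]     # 2_293_760_000

print(wrapped.item(), safe.item())
```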
```diff
--- a/sglang/srt/model_executor/model_runner.py
+++ b/sglang/srt/model_executor/model_runner.py
@@ -35,17 +35,13 @@ from sglang.srt.distributed import (
     set_custom_all_reduce,
 )
 from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
-from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
-from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
-from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
-from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
-from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
 from sglang.srt.layers.dp_attention import (
     get_attention_tp_group,
     get_attention_tp_size,
     initialize_dp_attention,
 )
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.quantization import monkey_patch_isinstance_for_vllm_base_layer
 from sglang.srt.layers.sampler import Sampler
 from sglang.srt.layers.torchao_utils import apply_torchao_config_to_model
 from sglang.srt.lora.lora_manager import LoRAManager
@@ -57,9 +53,16 @@ from sglang.srt.mem_cache.memory_pool import (
     ReqToTokenPool,
     TokenToKVPoolAllocator,
 )
+from sglang.srt.mem_cache.paged_allocator import PagedTokenToKVPoolAllocator
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader import get_model
+from sglang.srt.model_loader.loader import (
+    DefaultModelLoader,
+    device_loading_context,
+    get_model_loader,
+)
+from sglang.srt.model_loader.utils import set_default_torch_dtype
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
@@ -77,11 +80,9 @@ from sglang.srt.utils import (
     set_cpu_offload_max_bytes,
     set_cuda_arch,
 )
-from sglang.utils import get_exception_traceback
 
 logger = logging.getLogger(__name__)
 
-
 SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
 UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
 
```
```diff
@@ -118,70 +119,22 @@ class ModelRunner:
         self.spec_algorithm = SpeculativeAlgorithm.from_string(
             server_args.speculative_algorithm
         )
+        self.page_size = server_args.page_size
         self.req_to_token_pool = req_to_token_pool
         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator
 
         # Model-specific adjustment
-        if (
-            self.model_config.attention_arch == AttentionArch.MLA
-            and not self.server_args.disable_mla
-        ):
-            # TODO: add MLA optimization on CPU
-            if self.server_args.device != "cpu":
-                if server_args.enable_flashinfer_mla:
-                    logger.info(
-                        "MLA optimization is turned on. Use flashinfer mla backend."
-                    )
-                    self.server_args.attention_backend = "flashinfer_mla"
-                else:
-                    logger.info("MLA optimization is turned on. Use triton backend.")
-                    self.server_args.attention_backend = "triton"
-
-        if self.server_args.enable_double_sparsity:
-            logger.info(
-                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
-            )
-            self.server_args.attention_backend = "triton"
-            self.server_args.disable_cuda_graph = True
-            if self.server_args.ds_heavy_channel_type is None:
-                raise ValueError(
-                    "Please specify the heavy channel type for double sparsity optimization."
-                )
-            self.init_double_sparsity_channel_config(
-                self.server_args.ds_heavy_channel_type
-            )
+        self.model_specific_adjustment()
 
-        if self.is_multimodal:
-            self.mem_fraction_static *= 0.95
-            logger.info(
-                f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
-                f"because this is a multimodal model."
-            )
-
-        if self.model_config.hf_config.architectures == [
-            "MllamaForConditionalGeneration"
-        ]:
-            logger.info("Automatically turn off --chunked-prefill-size for mllama.")
-            server_args.chunked_prefill_size = -1
-
-        if self.model_config.hf_config.architectures == [
-            "Qwen2VLForConditionalGeneration"
-        ]:
-            # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
-            logger.info(
-                "Automatically turn off --chunked-prefill-size and disable radix cache for qwen2-vl."
-            )
-            server_args.chunked_prefill_size = -1
-            server_args.disable_radix_cache = True
-
-        # Global vars
         if server_args.show_time_cost:
             enable_show_time_cost()
+
         if server_args.disable_outlines_disk_cache:
             from outlines.caching import disable_cache
 
             disable_cache()
 
+        # Global vars
         global_server_args_dict.update(
             {
                 "attention_backend": server_args.attention_backend,
@@ -203,11 +156,17 @@ class ModelRunner:
             }
         )
 
+        # CPU offload
         set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))
 
         # Get memory before model loading
         min_per_gpu_memory = self.init_torch_distributed()
 
+        # If it is a draft model tp_group can be different.
+        self.initialize(min_per_gpu_memory)
+
+    def initialize(self, min_per_gpu_memory: float):
+        server_args = self.server_args
         self.memory_saver_adapter = TorchMemorySaverAdapter.create(
             enable=self.server_args.enable_memory_saver
         )
```
```diff
@@ -216,18 +175,6 @@ class ModelRunner:
         self.sampler = Sampler()
         self.load_model()
 
-        # Handle the case where some of models don't finish loading.
-        try:
-            dist.monitored_barrier(
-                group=get_tp_group().cpu_group,
-                timeout=datetime.timedelta(seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S),
-                wait_all_ranks=True,
-            )
-        except RuntimeError:
-            raise ValueError(
-                f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
-            ) from None
-
         # Apply torchao quantization
         torchao_applied = getattr(self.model, "torchao_applied", False)
         # In layered loading, torchao may have been applied
@@ -244,9 +191,11 @@ class ModelRunner:
         else:
             self.torch_tp_applied = False
 
-        # Init
+        # Init lora
         if server_args.lora_paths is not None:
             self.init_lora_manager()
+
+        # Init memory pool and attention backends
         self.init_memory_pool(
             min_per_gpu_memory,
             server_args.max_running_requests,
```
```diff
@@ -260,10 +209,63 @@ class ModelRunner:
         self.cuda_graph_runner = None
         self.init_attention_backend()
 
+    def model_specific_adjustment(self):
+        server_args = self.server_args
+
+        if (
+            self.model_config.attention_arch == AttentionArch.MLA
+            and not server_args.disable_mla
+        ):
+            # TODO: add MLA optimization on CPU
+            if server_args.device != "cpu":
+                if server_args.enable_flashinfer_mla:
+                    logger.info(
+                        "MLA optimization is turned on. Use flashinfer mla backend."
+                    )
+                    server_args.attention_backend = "flashinfer_mla"
+                else:
+                    logger.info("MLA optimization is turned on. Use triton backend.")
+                    server_args.attention_backend = "triton"
+
+        if server_args.enable_double_sparsity:
+            logger.info(
+                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
+            )
+            server_args.attention_backend = "triton"
+            server_args.disable_cuda_graph = True
+            if server_args.ds_heavy_channel_type is None:
+                raise ValueError(
+                    "Please specify the heavy channel type for double sparsity optimization."
+                )
+            self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)
+
+        if self.is_multimodal:
+            self.mem_fraction_static *= 0.95
+            logger.info(
+                f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
+                f"because this is a multimodal model."
+            )
+
+        if self.model_config.hf_config.architectures == [
+            "MllamaForConditionalGeneration"
+        ]:
+            logger.info("Automatically turn off --chunked-prefill-size for mllama.")
+            server_args.chunked_prefill_size = -1
+
+        if self.model_config.hf_config.architectures == [
+            "Qwen2VLForConditionalGeneration"
+        ]:
+            # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
+            logger.info(
+                "Automatically turn off --chunked-prefill-size and disable radix cache for qwen2-vl."
+            )
+            server_args.chunked_prefill_size = -1
+            server_args.disable_radix_cache = True
+
     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")
-        torch.get_device_module(self.device).set_device(self.gpu_id)
 
+        torch.get_device_module(self.device).set_device(self.gpu_id)
         if self.device == "cuda":
             backend = "nccl"
         elif self.device == "xpu":
```
```diff
@@ -304,15 +306,16 @@ class ModelRunner:
         min_per_gpu_memory = get_available_gpu_memory(
             self.device, self.gpu_id, distributed=self.tp_size > 1
         )
-        local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
         self.tp_group = get_tp_group()
         self.attention_tp_group = get_attention_tp_group()
 
         # Check memory for tensor parallelism
+        local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
         if self.tp_size > 1:
             if min_per_gpu_memory < local_gpu_memory * 0.9:
                 raise ValueError(
-                    "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
+                    "The memory capacity is unbalanced. Some GPUs may be occupied by other processes. "
+                    f"{min_per_gpu_memory=}, {local_gpu_memory=}, {local_gpu_memory * 0.9=}"
                 )
 
         logger.info(
@@ -352,6 +355,8 @@ class ModelRunner:
         # Load the model
         # Remove monkey_patch when linear.py quant remove dependencies with vllm
         monkey_patch_vllm_parallel_state()
+        monkey_patch_isinstance_for_vllm_base_layer()
+
         with self.memory_saver_adapter.region():
             self.model = get_model(
                 model_config=self.model_config,
@@ -359,6 +364,7 @@ class ModelRunner:
                 device_config=DeviceConfig(self.device),
             )
         monkey_patch_vllm_parallel_state(reverse=True)
+        monkey_patch_isinstance_for_vllm_base_layer(reverse=True)
 
         if self.server_args.kv_cache_dtype == "fp8_e4m3":
             if self.server_args.quantization_param_path is not None:
```
```diff
@@ -400,17 +406,22 @@ class ModelRunner:
             f"mem usage={(before_avail_memory - after_avail_memory):.2f} GB."
         )
 
+        # Handle the case where some ranks do not finish loading.
+        try:
+            dist.monitored_barrier(
+                group=get_tp_group().cpu_group,
+                timeout=datetime.timedelta(seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S),
+                wait_all_ranks=True,
+            )
+        except RuntimeError:
+            raise ValueError(
+                f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
+            ) from None
+
     def update_weights_from_disk(
         self, model_path: str, load_format: str
     ) -> tuple[bool, str]:
         """Update engine weights in-place from the disk."""
-        from sglang.srt.model_loader.loader import (
-            DefaultModelLoader,
-            device_loading_context,
-            get_model_loader,
-        )
-        from sglang.srt.model_loader.utils import set_default_torch_dtype
-
         logger.info(
             f"Update engine weights online from disk begin. "
             f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
@@ -420,7 +431,7 @@ class ModelRunner:
         self.model_config.model_path = model_path
         load_config = LoadConfig(load_format=load_format)
 
-        # Only support
+        # Only support DefaultModelLoader for now
         loader = get_model_loader(load_config)
         if not isinstance(loader, DefaultModelLoader):
             message = f"Failed to get model loader: {loader}."
```
```diff
@@ -694,6 +705,12 @@ class ModelRunner:
         )
         self.max_total_num_tokens = min(self.max_total_num_tokens, max_total_tokens)
 
+        self.max_total_num_tokens = (
+            self.max_total_num_tokens
+            // self.server_args.page_size
+            * self.server_args.page_size
+        )
+
         if self.max_total_num_tokens <= 0:
             raise RuntimeError(
                 "Not enough memory. Please try to increase --mem-fraction-static."
```
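The added rounding keeps the KV-cache token budget a whole number of pages. A quick check of the floor-divide-then-multiply arithmetic used above:

```python
def round_down_to_page(max_total_num_tokens: int, page_size: int) -> int:
    # Same arithmetic as the diff above: drop the partial page at the end.
    return max_total_num_tokens // page_size * page_size

assert round_down_to_page(130_003, 16) == 130_000
assert round_down_to_page(130_003, 1) == 130_003   # page_size=1 changes nothing
```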
```diff
@@ -710,21 +727,13 @@ class ModelRunner:
             # Draft worker shares req_to_token_pool with the target worker.
             assert self.is_draft_worker
 
-        if self.token_to_kv_pool_allocator is None:
-            self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
-                self.max_total_num_tokens,
-                dtype=self.kv_cache_dtype,
-                device=self.device,
-            )
-        else:
-            assert self.is_draft_worker
-
         if (
             self.model_config.attention_arch == AttentionArch.MLA
             and not self.server_args.disable_mla
         ):
             self.token_to_kv_pool = MLATokenToKVPool(
                 self.max_total_num_tokens,
+                page_size=self.page_size,
                 dtype=self.kv_cache_dtype,
                 kv_lora_rank=self.model_config.kv_lora_rank,
                 qk_rope_head_dim=self.model_config.qk_rope_head_dim,
@@ -735,6 +744,7 @@ class ModelRunner:
         elif self.server_args.enable_double_sparsity:
             self.token_to_kv_pool = DoubleSparseTokenToKVPool(
                 self.max_total_num_tokens,
+                page_size=self.page_size,
                 dtype=self.kv_cache_dtype,
                 head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                 head_dim=self.model_config.head_dim,
@@ -746,6 +756,7 @@ class ModelRunner:
         else:
             self.token_to_kv_pool = MHATokenToKVPool(
                 self.max_total_num_tokens,
+                page_size=self.page_size,
                 dtype=self.kv_cache_dtype,
                 head_num=self.model_config.get_num_kv_heads(get_attention_tp_size()),
                 head_dim=self.model_config.head_dim,
```
```diff
@@ -753,6 +764,26 @@ class ModelRunner:
                 device=self.device,
                 enable_memory_saver=self.server_args.enable_memory_saver,
             )
+
+        if self.token_to_kv_pool_allocator is None:
+            if self.page_size == 1:
+                self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
+                    self.max_total_num_tokens,
+                    dtype=self.kv_cache_dtype,
+                    device=self.device,
+                    kvcache=self.token_to_kv_pool,
+                )
+            else:
+                self.token_to_kv_pool_allocator = PagedTokenToKVPoolAllocator(
+                    self.max_total_num_tokens,
+                    page_size=self.page_size,
+                    dtype=self.kv_cache_dtype,
+                    device=self.device,
+                    kvcache=self.token_to_kv_pool,
+                )
+        else:
+            assert self.is_draft_worker
+
         logger.info(
             f"Memory pool end. "
             f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
```
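With `--page-size 1` the existing token-granular `TokenToKVPoolAllocator` is kept; otherwise allocation goes through the new `PagedTokenToKVPoolAllocator` from `sglang/srt/mem_cache/paged_allocator.py` (not shown in this section). The toy class below sketches the general idea of page-granular allocation; the names and logic are purely illustrative, not the actual sglang implementation.

```python
from typing import List, Optional

class ToyPagedAllocator:
    """Hypothetical sketch: hand out KV-cache slots in whole pages of
    page_size contiguous token slots each."""

    def __init__(self, max_total_num_tokens: int, page_size: int):
        self.page_size = page_size
        self.free_pages: List[int] = list(range(max_total_num_tokens // page_size))

    def alloc(self, num_tokens: int) -> Optional[List[int]]:
        need = -(-num_tokens // self.page_size)  # ceil division: pages needed
        if len(self.free_pages) < need:
            return None  # not enough KV-cache pages left
        pages = [self.free_pages.pop() for _ in range(need)]
        # Expand the pages into token-slot indices; the tail of the last page stays unused.
        slots = [p * self.page_size + i for p in pages for i in range(self.page_size)]
        return slots[:num_tokens]

    def free(self, pages: List[int]) -> None:
        # Toy logic: the caller returns whole pages; a real allocator tracks per-page usage.
        self.free_pages.extend(pages)

allocator = ToyPagedAllocator(max_total_num_tokens=1024, page_size=16)
print(len(allocator.alloc(40)))  # 40 slots drawn from 3 pages
```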
```diff
@@ -770,6 +801,13 @@ class ModelRunner:
     def init_attention_backend(self):
         """Init attention kernel backend."""
         if self.server_args.attention_backend == "flashinfer":
+            from sglang.srt.layers.attention.flashinfer_backend import (
+                FlashInferAttnBackend,
+            )
+
+            # Init streams
+            if self.server_args.speculative_algorithm == "EAGLE":
+                self.plan_stream_for_flashinfer = torch.cuda.Stream()
             self.attn_backend = FlashInferAttnBackend(self)
         elif self.server_args.attention_backend == "triton":
             assert self.sliding_window_size is None, (
@@ -781,12 +819,26 @@ class ModelRunner:
                 "Please use `--attention-backend flashinfer`."
             )
             if self.server_args.enable_double_sparsity:
+                from sglang.srt.layers.attention.double_sparsity_backend import (
+                    DoubleSparseAttnBackend,
+                )
+
                 self.attn_backend = DoubleSparseAttnBackend(self)
             else:
+                from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
+
                 self.attn_backend = TritonAttnBackend(self)
         elif self.server_args.attention_backend == "torch_native":
+            from sglang.srt.layers.attention.torch_native_backend import (
+                TorchNativeAttnBackend,
+            )
+
             self.attn_backend = TorchNativeAttnBackend(self)
         elif self.server_args.attention_backend == "flashinfer_mla":
+            from sglang.srt.layers.attention.flashinfer_mla_backend import (
+                FlashInferMLAAttnBackend,
+            )
+
             self.attn_backend = FlashInferMLAAttnBackend(self)
         else:
             raise ValueError(
```
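The attention backends are now imported inside `init_attention_backend` instead of at module import time, so only the selected backend (and its heavy dependencies, such as flashinfer) is actually imported. A condensed sketch of this deferred-import pattern, using the same module paths as the diff but a hypothetical helper name:

```python
def make_attn_backend(name: str, model_runner):
    # Hypothetical helper illustrating the deferred-import pattern above:
    # the backend module is imported only when that backend is selected.
    if name == "flashinfer":
        from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend

        return FlashInferAttnBackend(model_runner)
    elif name == "triton":
        from sglang.srt.layers.attention.triton_backend import TritonAttnBackend

        return TritonAttnBackend(model_runner)
    raise ValueError(f"Unsupported attention backend: {name}")
```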
```diff
@@ -878,18 +930,24 @@ class ModelRunner:
             forward_batch.input_ids, forward_batch.positions, forward_batch
         )
 
-    def forward(
+    def forward(
+        self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
+    ) -> LogitsProcessorOutput:
         if (
             forward_batch.forward_mode.is_cuda_graph()
             and self.cuda_graph_runner
             and self.cuda_graph_runner.can_run(forward_batch)
         ):
-            return self.cuda_graph_runner.replay(
+            return self.cuda_graph_runner.replay(
+                forward_batch, skip_attn_backend_init=skip_attn_backend_init
+            )
 
         if forward_batch.forward_mode.is_decode():
             return self.forward_decode(forward_batch)
         elif forward_batch.forward_mode.is_extend():
-            return self.forward_extend(
+            return self.forward_extend(
+                forward_batch, skip_attn_backend_init=skip_attn_backend_init
+            )
         elif forward_batch.forward_mode.is_idle():
             return self.forward_idle(forward_batch)
         else:
@@ -909,45 +967,6 @@ class ModelRunner:
         sampling_info.update_regex_vocab_mask()
         sampling_info.apply_logits_bias(logits_output.next_token_logits)
 
-    def update_output_logprobs(
-        self,
-        logits_output: LogitsProcessorOutput,
-        sampling_info: SamplingBatchInfo,
-        top_logprobs_nums: List[int],
-        token_ids_logprobs: List[int],
-        next_token_ids: torch.Tensor,
-        *,
-        num_tokens_per_req: List[int],
-    ):
-        """Update the logits_output's output logprob based on next_token_ids
-
-        Args:
-            logits_output: The logits output from the model forward
-            sampling_info: Sampling info for logprob calculation
-            top_logprobs_nums: Number of logprobs per request.
-            next_token_ids: Next token ids.
-            num_tokens_per_req: The number of tokens per request.
-
-        Returns:
-            A list of next_token_ids
-        """
-        self._preprocess_logits(logits_output, sampling_info)
-        # We should repeat top_logprobs_nums to match num_tokens_per_req.
-        top_logprobs_nums_repeat_interleaved = []
-        token_ids_logprobs_repeat_interleaved = []
-        for num, num_tokens in zip(top_logprobs_nums, num_tokens_per_req):
-            top_logprobs_nums_repeat_interleaved.extend([num] * num_tokens)
-        for token_ids, num_tokens in zip(token_ids_logprobs, num_tokens_per_req):
-            token_ids_logprobs_repeat_interleaved.extend([token_ids] * num_tokens)
-        self.sampler(
-            logits_output,
-            sampling_info,
-            True,
-            top_logprobs_nums_repeat_interleaved,
-            token_ids_logprobs_repeat_interleaved,
-            batch_next_token_ids=next_token_ids,
-        )
-
     def sample(
         self,
         logits_output: LogitsProcessorOutput,
```
```diff
--- a/sglang/srt/model_loader/loader.py
+++ b/sglang/srt/model_loader/loader.py
@@ -48,6 +48,7 @@ from sglang.srt.model_loader.weight_utils import (
     safetensors_weights_iterator,
 )
 from sglang.srt.utils import (
+    get_bool_env_var,
     get_device_capability,
     is_pin_memory_available,
     set_weight_attrs,
@@ -197,7 +198,7 @@ class DefaultModelLoader(BaseModelLoader):
 
         Returns the path to the downloaded model, or None if the model is not
         downloaded from ModelScope."""
-        if
+        if get_bool_env_var("SGLANG_USE_MODELSCOPE"):
             # download model from ModelScope hub,
             # lazy import so that modelscope is not required for normal use.
             # pylint: disable=C.
```
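The hard-coded ModelScope check is replaced by `get_bool_env_var("SGLANG_USE_MODELSCOPE")` from `sglang.srt.utils`. A rough stand-in for what such a helper does (the exact set of accepted values in sglang may differ):

```python
import os

def get_bool_env_var_sketch(name: str, default: str = "false") -> bool:
    # Approximation of a boolean env-var helper; not the actual sglang code.
    return os.getenv(name, default).strip().lower() in ("1", "true")

# Mirrors the call site in the diff:
if get_bool_env_var_sketch("SGLANG_USE_MODELSCOPE"):
    print("Model weights would be pulled from the ModelScope hub.")
```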
```diff
--- a/sglang/srt/model_loader/weight_utils.py
+++ b/sglang/srt/model_loader/weight_utils.py
@@ -455,7 +455,7 @@ def pt_weights_iterator(
         disable=not enable_tqdm,
         bar_format=_BAR_FORMAT,
     ):
-        state = torch.load(bin_file, map_location="cpu")
+        state = torch.load(bin_file, map_location="cpu", weights_only=True)
        yield from state.items()
        del state
        torch.cuda.empty_cache()
```
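Passing `weights_only=True` makes `torch.load` refuse to unpickle arbitrary Python objects, so iterating over untrusted `.bin` checkpoints no longer executes embedded pickle code. A minimal round trip showing the flag in use:

```python
import torch

# weights_only=True restricts deserialization to tensors and plain containers,
# which is safer when the .bin file comes from an untrusted source.
torch.save({"w": torch.randn(2, 2)}, "/tmp/ckpt.bin")
state = torch.load("/tmp/ckpt.bin", map_location="cpu", weights_only=True)
print(state["w"].shape)  # torch.Size([2, 2])
```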