sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post3__py3-none-any.whl
This diff compares two publicly released versions of the package as published to one of the supported registries. It is provided for informational purposes only and reflects the package contents as they appear in their public registries.
- sglang/api.py +1 -1
- sglang/bench_offline_throughput.py +19 -0
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +123 -79
- sglang/global_config.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/ir.py +1 -1
- sglang/srt/_custom_ops.py +83 -91
- sglang/srt/configs/load_config.py +4 -1
- sglang/srt/configs/model_config.py +48 -2
- sglang/srt/configs/qwen2_5_vl_config.py +5 -2
- sglang/srt/constrained/base_grammar_backend.py +117 -15
- sglang/srt/constrained/llguidance_backend.py +151 -0
- sglang/srt/constrained/outlines_backend.py +24 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -38
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
- sglang/srt/distributed/parallel_state.py +48 -3
- sglang/srt/entrypoints/engine.py +67 -9
- sglang/srt/entrypoints/http_server.py +190 -41
- sglang/srt/entrypoints/verl_engine.py +147 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/activation.py +11 -0
- sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
- sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_backend.py +220 -378
- sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
- sglang/srt/layers/attention/torch_native_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +9 -6
- sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
- sglang/srt/layers/attention/utils.py +39 -0
- sglang/srt/layers/attention/vision.py +60 -63
- sglang/srt/layers/dp_attention.py +142 -1
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +3 -1
- sglang/srt/layers/logits_processor.py +281 -45
- sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
- sglang/srt/layers/moe/ep_moe/layer.py +140 -28
- sglang/srt/layers/moe/fused_moe_native.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
- sglang/srt/layers/moe/topk.py +13 -4
- sglang/srt/layers/quantization/__init__.py +111 -7
- sglang/srt/layers/quantization/blockwise_int8.py +409 -0
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +69 -28
- sglang/srt/layers/quantization/fp8_utils.py +17 -1
- sglang/srt/layers/quantization/gptq.py +416 -0
- sglang/srt/layers/quantization/int8_kernel.py +327 -0
- sglang/srt/layers/quantization/int8_utils.py +73 -0
- sglang/srt/layers/quantization/modelopt_quant.py +18 -1
- sglang/srt/layers/radix_attention.py +1 -0
- sglang/srt/layers/rotary_embedding.py +0 -1
- sglang/srt/layers/sampler.py +76 -31
- sglang/srt/layers/vocab_parallel_embedding.py +14 -13
- sglang/srt/lora/lora.py +17 -1
- sglang/srt/lora/lora_config.py +5 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/cache_controller.py +193 -62
- sglang/srt/managers/configure_logging.py +2 -1
- sglang/srt/managers/data_parallel_controller.py +6 -2
- sglang/srt/managers/detokenizer_manager.py +124 -102
- sglang/srt/managers/image_processor.py +2 -1
- sglang/srt/managers/io_struct.py +143 -6
- sglang/srt/managers/schedule_batch.py +237 -197
- sglang/srt/managers/schedule_policy.py +29 -29
- sglang/srt/managers/scheduler.py +681 -259
- sglang/srt/managers/session_controller.py +6 -2
- sglang/srt/managers/tokenizer_manager.py +224 -68
- sglang/srt/managers/tp_worker.py +15 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/mem_cache/chunk_cache.py +18 -11
- sglang/srt/mem_cache/hiradix_cache.py +394 -0
- sglang/srt/mem_cache/memory_pool.py +44 -18
- sglang/srt/mem_cache/radix_cache.py +58 -47
- sglang/srt/metrics/collector.py +94 -36
- sglang/srt/model_executor/cuda_graph_runner.py +55 -24
- sglang/srt/model_executor/forward_batch_info.py +49 -16
- sglang/srt/model_executor/model_runner.py +208 -28
- sglang/srt/model_loader/loader.py +3 -3
- sglang/srt/model_loader/weight_utils.py +36 -14
- sglang/srt/models/baichuan.py +31 -6
- sglang/srt/models/chatglm.py +39 -7
- sglang/srt/models/commandr.py +29 -5
- sglang/srt/models/dbrx.py +31 -5
- sglang/srt/models/deepseek.py +43 -6
- sglang/srt/models/deepseek_nextn.py +32 -19
- sglang/srt/models/deepseek_v2.py +265 -32
- sglang/srt/models/exaone.py +19 -9
- sglang/srt/models/gemma.py +22 -8
- sglang/srt/models/gemma2.py +25 -12
- sglang/srt/models/gemma2_reward.py +5 -1
- sglang/srt/models/gpt2.py +28 -13
- sglang/srt/models/gpt_bigcode.py +27 -5
- sglang/srt/models/granite.py +21 -9
- sglang/srt/models/grok.py +21 -4
- sglang/srt/models/internlm2.py +36 -6
- sglang/srt/models/internlm2_reward.py +5 -1
- sglang/srt/models/llama.py +26 -9
- sglang/srt/models/llama_classification.py +5 -1
- sglang/srt/models/llama_eagle.py +17 -4
- sglang/srt/models/llama_embedding.py +5 -1
- sglang/srt/models/llama_reward.py +7 -2
- sglang/srt/models/llava.py +19 -3
- sglang/srt/models/llavavid.py +10 -1
- sglang/srt/models/minicpm.py +26 -2
- sglang/srt/models/minicpm3.py +39 -3
- sglang/srt/models/minicpmv.py +45 -14
- sglang/srt/models/mixtral.py +20 -9
- sglang/srt/models/mixtral_quant.py +50 -8
- sglang/srt/models/mllama.py +57 -11
- sglang/srt/models/olmo.py +34 -6
- sglang/srt/models/olmo2.py +34 -13
- sglang/srt/models/olmoe.py +26 -4
- sglang/srt/models/phi3_small.py +29 -10
- sglang/srt/models/qwen.py +26 -3
- sglang/srt/models/qwen2.py +26 -4
- sglang/srt/models/qwen2_5_vl.py +46 -8
- sglang/srt/models/qwen2_eagle.py +17 -5
- sglang/srt/models/qwen2_moe.py +44 -6
- sglang/srt/models/qwen2_rm.py +78 -0
- sglang/srt/models/qwen2_vl.py +39 -8
- sglang/srt/models/stablelm.py +32 -5
- sglang/srt/models/torch_native_llama.py +5 -2
- sglang/srt/models/xverse.py +21 -9
- sglang/srt/models/xverse_moe.py +45 -7
- sglang/srt/models/yivl.py +2 -1
- sglang/srt/openai_api/adapter.py +109 -24
- sglang/srt/openai_api/protocol.py +17 -1
- sglang/srt/reasoning_parser.py +154 -0
- sglang/srt/sampling/penaltylib/__init__.py +4 -6
- sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
- sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
- sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
- sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
- sglang/srt/sampling/sampling_batch_info.py +79 -157
- sglang/srt/sampling/sampling_params.py +16 -13
- sglang/srt/server_args.py +136 -52
- sglang/srt/speculative/build_eagle_tree.py +2 -8
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +0 -1
- sglang/srt/speculative/eagle_utils.py +92 -58
- sglang/srt/speculative/eagle_worker.py +186 -94
- sglang/srt/speculative/spec_info.py +1 -13
- sglang/srt/utils.py +43 -17
- sglang/srt/warmup.py +47 -0
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/runners.py +389 -126
- sglang/test/send_one.py +88 -0
- sglang/test/test_block_fp8_ep.py +361 -0
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +138 -84
- sglang/utils.py +50 -60
- sglang/version.py +1 -1
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/METADATA +21 -15
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/RECORD +200 -166
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/WHEEL +1 -1
- sglang/bench_latency.py +0 -1
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
- sglang/test/srt/sampling/penaltylib/utils.py +0 -344
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post3.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/forward_batch_info.py

@@ -31,7 +31,7 @@ from __future__ import annotations

from dataclasses import dataclass
from enum import IntEnum, auto
- from typing import TYPE_CHECKING, List, Optional
+ from typing import TYPE_CHECKING, List, Optional, Union

import torch
import triton

@@ -41,12 +41,13 @@ from sglang.srt.layers.rotary_embedding import MRotaryEmbedding

from sglang.srt.utils import get_compiler_backend

if TYPE_CHECKING:
-     from sglang.srt.layers.attention import AttentionBackend
+     from sglang.srt.layers.attention.base_attn_backend import AttentionBackend
    from sglang.srt.managers.schedule_batch import ImageInputs, ModelWorkerBatch
    from sglang.srt.mem_cache.memory_pool import BaseTokenToKVPool, ReqToTokenPool
    from sglang.srt.model_executor.model_runner import ModelRunner
    from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
-     from sglang.srt.speculative.
+     from sglang.srt.speculative.eagle_utils import EagleDraftInput, EagleVerifyInput
+     from sglang.srt.speculative.spec_info import SpeculativeAlgorithm


class ForwardMode(IntEnum):

@@ -112,7 +113,9 @@ class ForwardMode(IntEnum):

class CaptureHiddenMode(IntEnum):
    NULL = auto()
+     # Capture hidden states of all tokens.
    FULL = auto()
+     # Capture a hidden state of the last token.
    LAST = auto()

    def need_capture(self):

@@ -148,10 +151,14 @@ class ForwardBatch:

    # For logprob
    return_logprob: bool = False
    top_logprobs_nums: Optional[List[int]] = None
+     token_ids_logprobs: Optional[List[List[int]]] = None

    # Position information
    positions: torch.Tensor = None

+     # For decode
+     decode_seq_lens_cpu: Optional[torch.Tensor] = None
+
    # For extend
    extend_num_tokens: Optional[int] = None
    extend_seq_lens: Optional[torch.Tensor] = None

@@ -160,6 +167,7 @@ class ForwardBatch:

    extend_prefix_lens_cpu: Optional[List[int]] = None
    extend_seq_lens_cpu: Optional[List[int]] = None
    extend_logprob_start_lens_cpu: Optional[List[int]] = None
+     extend_input_logprob_token_ids_gpu: Optional[torch.Tensor] = None

    # For multimodal
    image_inputs: Optional[List[ImageInputs]] = None

@@ -185,15 +193,27 @@ class ForwardBatch:

    attn_backend: AttentionBackend = None

    # For DP attention
-
+     global_num_tokens_cpu: Optional[List[int]] = None
+     global_num_tokens_gpu: Optional[torch.Tensor] = None
+     # Has to be None when cuda graph is captured.
+     global_num_tokens_for_logprob_cpu: Optional[List[int]] = None
+     global_num_tokens_for_logprob_gpu: Optional[torch.Tensor] = None
+     # for extend, local start pos and num tokens is different in logits processor
+     # this will be computed in get_dp_local_info
+     # this will be recomputed in LogitsMetadata.from_forward_batch
+     dp_local_start_pos: Optional[torch.Tensor] = None  # cached info at runtime
+     dp_local_num_tokens: Optional[torch.Tensor] = None  # cached info at runtime
    gathered_buffer: Optional[torch.Tensor] = None
    can_run_dp_cuda_graph: bool = False

    # Speculative decoding
-     spec_info:
+     spec_info: Optional[Union[EagleVerifyInput, EagleDraftInput]] = None
    spec_algorithm: SpeculativeAlgorithm = None
    capture_hidden_mode: CaptureHiddenMode = None

+     # For padding
+     padded_static_len: int = -1  # -1 if not padded
+
    # For Qwen2-VL
    mrope_positions: torch.Tensor = None

@@ -203,8 +223,13 @@ class ForwardBatch:

        batch: ModelWorkerBatch,
        model_runner: ModelRunner,
    ):
-
        device = model_runner.device
+         extend_input_logprob_token_ids_gpu = None
+         if batch.extend_input_logprob_token_ids is not None:
+             extend_input_logprob_token_ids_gpu = (
+                 batch.extend_input_logprob_token_ids.to(device, non_blocking=True)
+             )
+
        ret = cls(
            forward_mode=batch.forward_mode,
            batch_size=len(batch.seq_lens),

@@ -220,7 +245,7 @@ class ForwardBatch:

            seq_lens_sum=batch.seq_lens_sum,
            return_logprob=batch.return_logprob,
            top_logprobs_nums=batch.top_logprobs_nums,
-
+             token_ids_logprobs=batch.token_ids_logprobs,
            can_run_dp_cuda_graph=batch.can_run_dp_cuda_graph,
            lora_paths=batch.lora_paths,
            sampling_info=batch.sampling_info,

@@ -231,10 +256,12 @@ class ForwardBatch:

            spec_info=batch.spec_info,
            capture_hidden_mode=batch.capture_hidden_mode,
            input_embeds=batch.input_embeds,
+             extend_input_logprob_token_ids_gpu=extend_input_logprob_token_ids_gpu,
        )

-         if
-
+         if batch.global_num_tokens is not None:
+             ret.global_num_tokens_cpu = batch.global_num_tokens
+             max_len = max(ret.global_num_tokens_cpu)
            ret.gathered_buffer = torch.zeros(
                (max_len * model_runner.tp_size, model_runner.model_config.hidden_size),
                dtype=model_runner.dtype,

@@ -256,6 +283,8 @@ class ForwardBatch:

        if ret.forward_mode.is_decode():
            if ret.positions is None:
                ret.positions = clamp_position(batch.seq_lens)
+             if ret.decode_seq_lens_cpu is None:
+                 ret.decode_seq_lens_cpu = batch.decode_seq_lens
        else:
            ret.extend_seq_lens = torch.tensor(
                batch.extend_seq_lens, dtype=torch.int32

@@ -263,13 +292,12 @@ class ForwardBatch:

            ret.extend_prefix_lens = torch.tensor(
                batch.extend_prefix_lens, dtype=torch.int32
            ).to(device, non_blocking=True)
-             if (
-                 model_runner.server_args.attention_backend != "torch_native"
-                 and model_runner.server_args.speculative_algorithm != "NEXTN"
-             ):
+             if model_runner.server_args.attention_backend != "torch_native":
                ret.extend_num_tokens = batch.extend_num_tokens
                positions, ret.extend_start_loc = compute_position_triton(
-                     ret.extend_prefix_lens,
+                     ret.extend_prefix_lens,
+                     ret.extend_seq_lens,
+                     ret.extend_num_tokens,
                )
            else:
                positions, ret.extend_start_loc = compute_position_torch(

@@ -341,6 +369,7 @@ class ForwardBatch:

            )
            batch.image_inputs[i].mrope_position_delta = mrope_position_delta
            mrope_positions_list[i] = mrope_positions
+
        self.mrope_positions = torch.concat(
            [torch.tensor(pos, device=device) for pos in mrope_positions_list],
            axis=1,

@@ -353,6 +382,8 @@ def compute_position_triton(

):
    """Compute positions. It is a fused version of `compute_position_torch`."""
    batch_size = extend_seq_lens.shape[0]
+     has_prefix = extend_prefix_lens.shape[0] == batch_size
+
    positions = torch.empty(
        extend_seq_lens_sum, dtype=torch.int64, device=extend_seq_lens.device
    )

@@ -366,6 +397,7 @@ def compute_position_triton(

        extend_start_loc,
        extend_prefix_lens,
        extend_seq_lens,
+         has_prefix,
    )

    return positions, extend_start_loc

@@ -377,11 +409,12 @@ def compute_position_kernel(

    extend_start_loc,
    extend_prefix_lens,
    extend_seq_lens,
+     has_prefix: tl.constexpr,
):
    BLOCK_SIZE: tl.constexpr = 512
-     pid = tl.program_id(0)
+     pid = tl.program_id(0).to(tl.int64)

-     prefix_len = tl.load(extend_prefix_lens + pid)
+     prefix_len = tl.load(extend_prefix_lens + pid) if has_prefix else 0
    seq_len = tl.load(extend_seq_lens + pid)

    # TODO: optimize this?
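
For orientation, the following is a minimal plain-PyTorch sketch (not the package's compute_position_torch implementation) of the pair that compute_position_triton produces: per-token positions that start at each request's prefix length, plus the start offset of each request inside the flattened extend buffer. The new has_prefix flag only controls whether extend_prefix_lens is read or treated as all zeros.

import torch

def compute_position_reference(extend_prefix_lens, extend_seq_lens):
    # Positions count up from each request's prefix length.
    positions = torch.cat(
        [
            torch.arange(prefix, prefix + seq)
            for prefix, seq in zip(extend_prefix_lens.tolist(), extend_seq_lens.tolist())
        ]
    )
    # Exclusive prefix sum: where each request's tokens start in the flat buffer.
    extend_start_loc = torch.cumsum(extend_seq_lens, dim=0) - extend_seq_lens
    return positions, extend_start_loc

# Example: prefix lens [3, 0] and extend lens [2, 4] give
# positions [3, 4, 0, 1, 2, 3] and extend_start_loc [0, 2].
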
sglang/srt/model_executor/model_runner.py

@@ -13,11 +13,14 @@

# ==============================================================================
"""ModelRunner runs the forward passes of the models."""

+ import datetime
import gc
import json
import logging
+ import os
import time
- from
+ from dataclasses import dataclass
+ from typing import List, Optional, Tuple, Union

import torch
import torch.distributed as dist

@@ -34,6 +37,7 @@ from sglang.srt.distributed import (

from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
+ from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
from sglang.srt.layers.dp_attention import (

@@ -51,14 +55,18 @@ from sglang.srt.mem_cache.memory_pool import (

    MHATokenToKVPool,
    MLATokenToKVPool,
    ReqToTokenPool,
+     TokenToKVPoolAllocator,
)
from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader import get_model
+ from sglang.srt.model_loader.weight_utils import default_weight_loader
+ from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
from sglang.srt.server_args import ServerArgs
from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
from sglang.srt.utils import (
+     MultiprocessingSerializer,
    enable_show_time_cost,
    get_available_gpu_memory,
    init_custom_process_group,

@@ -69,10 +77,15 @@ from sglang.srt.utils import (

    set_cpu_offload_max_bytes,
    set_cuda_arch,
)
+ from sglang.utils import get_exception_traceback

logger = logging.getLogger(__name__)


+ SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
+ UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
+
+
class ModelRunner:
    """ModelRunner runs the forward passes of the models."""

@@ -86,6 +99,8 @@ class ModelRunner:

        nccl_port: int,
        server_args: ServerArgs,
        is_draft_worker: bool = False,
+         req_to_token_pool: Optional[ReqToTokenPool] = None,
+         token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
    ):
        # Parse args
        self.model_config = model_config

@@ -103,6 +118,8 @@ class ModelRunner:

        self.spec_algorithm = SpeculativeAlgorithm.from_string(
            server_args.speculative_algorithm
        )
+         self.req_to_token_pool = req_to_token_pool
+         self.token_to_kv_pool_allocator = token_to_kv_pool_allocator

        # Model-specific adjustment
        if (

@@ -113,9 +130,9 @@ class ModelRunner:

        if self.server_args.device != "cpu":
            if server_args.enable_flashinfer_mla:
                logger.info(
-                     "
+                     "MLA optimization is turned on. Use flashinfer mla backend."
                )
-                 self.server_args.attention_backend = "
+                 self.server_args.attention_backend = "flashinfer_mla"
            else:
                logger.info("MLA optimization is turned on. Use triton backend.")
                self.server_args.attention_backend = "triton"

@@ -176,8 +193,13 @@ class ModelRunner:

                "enable_dp_attention": server_args.enable_dp_attention,
                "enable_ep_moe": server_args.enable_ep_moe,
                "device": server_args.device,
+                 "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
+                 "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
                "enable_flashinfer_mla": server_args.enable_flashinfer_mla,
                "disable_radix_cache": server_args.disable_radix_cache,
+                 "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
+                 "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
+                 "debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
            }
        )

@@ -194,6 +216,18 @@ class ModelRunner:

        self.sampler = Sampler()
        self.load_model()

+         # Handle the case where some of models don't finish loading.
+         try:
+             dist.monitored_barrier(
+                 group=get_tp_group().cpu_group,
+                 timeout=datetime.timedelta(seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S),
+                 wait_all_ranks=True,
+             )
+         except RuntimeError:
+             raise ValueError(
+                 f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
+             ) from None
+
        # Apply torchao quantization
        torchao_applied = getattr(self.model, "torchao_applied", False)
        # In layered loading, torchao may have been applied

@@ -228,19 +262,18 @@ class ModelRunner:


    def init_torch_distributed(self):
        logger.info("Init torch distributed begin.")
-
        torch.get_device_module(self.device).set_device(self.gpu_id)
+
        if self.device == "cuda":
            backend = "nccl"
        elif self.device == "xpu":
-
-             # Need to use xccl for xpu backend in the future
-             backend = "gloo"
+             backend = "xccl"
        elif self.device == "hpu":
            backend = "hccl"
        elif self.device == "cpu":
            backend = "gloo"

+         before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
        if not self.server_args.enable_p2p_check:
            monkey_patch_p2p_access_check()

@@ -258,6 +291,7 @@ class ModelRunner:

            rank=self.tp_rank,
            local_rank=self.gpu_id,
            distributed_init_method=dist_init_method,
+             timeout=self.server_args.dist_timeout,
        )
        initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
        initialize_dp_attention(

@@ -270,20 +304,24 @@ class ModelRunner:

        min_per_gpu_memory = get_available_gpu_memory(
            self.device, self.gpu_id, distributed=self.tp_size > 1
        )
+         local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
        self.tp_group = get_tp_group()
        self.attention_tp_group = get_attention_tp_group()

        # Check memory for tensor parallelism
        if self.tp_size > 1:
-             local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
            if min_per_gpu_memory < local_gpu_memory * 0.9:
                raise ValueError(
                    "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
                )

+         logger.info(
+             f"Init torch distributed ends. mem usage={(before_avail_memory - local_gpu_memory):.2f} GB"
+         )
        return min_per_gpu_memory

    def load_model(self):
+         before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
        )

@@ -353,11 +391,13 @@ class ModelRunner:

        )
        self.dtype = self.model_config.dtype

+         after_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
        logger.info(
            f"Load weight end. "
            f"type={type(self.model).__name__}, "
            f"dtype={self.dtype}, "
-             f"avail mem={
+             f"avail mem={after_avail_memory:.2f} GB, "
+             f"mem usage={(before_avail_memory - after_avail_memory):.2f} GB."
        )

    def update_weights_from_disk(

@@ -512,8 +552,21 @@ class ModelRunner:

            logger.error(error_msg)
            return False, error_msg

-     def update_weights_from_tensor(
-         self
+     def update_weights_from_tensor(
+         self,
+         named_tensors: List[Tuple[str, Union[torch.Tensor, "LocalSerializedTensor"]]],
+         load_format: Optional[str] = None,
+     ):
+         named_tensors = [
+             (name, _unwrap_tensor(tensor, tp_rank=self.tp_rank))
+             for name, tensor in named_tensors
+         ]
+         if load_format == "direct":
+             _model_load_weights_direct(self.model, named_tensors)
+         elif load_format is None:
+             self.model.load_weights(named_tensors)
+         else:
+             raise NotImplementedError(f"Unknown load_format={load_format}")
        return True, "Success"

    def get_weights_by_name(
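
A hedged caller-side sketch of the reworked update_weights_from_tensor (the tensor name and shape below are placeholders, not taken from the package): the default path hands the pairs to the model's own load_weights, while load_format="direct" writes them straight into named_parameters via default_weight_loader.

import torch

# `model_runner` is an already-initialized ModelRunner instance.
new_weights = [("model.embed_tokens.weight", torch.zeros(32000, 4096))]  # placeholder

# Default path: delegate to the model's load_weights().
model_runner.update_weights_from_tensor(new_weights)

# "direct" path: copy each tensor into model.named_parameters(), bypassing
# model-specific weight mapping.
model_runner.update_weights_from_tensor(new_weights, load_format="direct")
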
@@ -606,15 +659,31 @@ class ModelRunner:

            4096,
        )

+         if SGLANG_CI_SMALL_KV_SIZE:
+             self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)
+
        if not self.spec_algorithm.is_none():
            if self.is_draft_worker:
                self.max_total_num_tokens = self.server_args.draft_runner_cache_size
+                 max_num_reqs = self.server_args.max_num_reqs
            else:
+                 # We are sharing the `token_to_kv_pool`, and both verify and draft tokens
+                 # can be concurrently allocated, so we should give a headroom for it.
                self.server_args.draft_runner_cache_size = (
                    self.max_total_num_tokens
-
+                     # draft
+                     + max_num_reqs
+                     * self.server_args.speculative_num_steps
+                     * self.server_args.speculative_eagle_topk
+                     # verify
+                     + max_num_reqs * self.server_args.speculative_num_draft_tokens
+                     # buffer
                    + 100
                )
+                 # Target worker and draft worker shares the same indices for the
+                 # token_to_kv_pool, so we should make sure to match max_total_num_tokens.
+                 self.max_total_num_tokens = self.server_args.draft_runner_cache_size
+                 self.server_args.max_num_reqs = max_num_reqs

        if max_total_tokens is not None:
            if max_total_tokens > self.max_total_num_tokens:
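
A worked example of the headroom formula above, with illustrative numbers only (real values come from ServerArgs):

max_total_num_tokens = 100_000        # target worker KV capacity
max_num_reqs = 32
speculative_num_steps = 5
speculative_eagle_topk = 4
speculative_num_draft_tokens = 8

draft_runner_cache_size = (
    max_total_num_tokens
    + max_num_reqs * speculative_num_steps * speculative_eagle_topk  # draft: 640
    + max_num_reqs * speculative_num_draft_tokens                    # verify: 256
    + 100                                                            # buffer
)
# draft_runner_cache_size == 100_996; both workers then share it as max_total_num_tokens.
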
@@ -630,12 +699,26 @@ class ModelRunner:

                "Not enough memory. Please try to increase --mem-fraction-static."
            )

-         self.req_to_token_pool
-
-
-
-
-
+         if self.req_to_token_pool is None:
+             self.req_to_token_pool = ReqToTokenPool(
+                 size=max_num_reqs + 1,
+                 max_context_len=self.model_config.context_len + 4,
+                 device=self.device,
+                 enable_memory_saver=self.server_args.enable_memory_saver,
+             )
+         else:
+             # Draft worker shares req_to_token_pool with the target worker.
+             assert self.is_draft_worker
+
+         if self.token_to_kv_pool_allocator is None:
+             self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
+                 self.max_total_num_tokens,
+                 dtype=self.kv_cache_dtype,
+                 device=self.device,
+             )
+         else:
+             assert self.is_draft_worker
+
        if (
            self.model_config.attention_arch == AttentionArch.MLA
            and not self.server_args.disable_mla

@@ -703,6 +786,8 @@ class ModelRunner:

            self.attn_backend = TritonAttnBackend(self)
        elif self.server_args.attention_backend == "torch_native":
            self.attn_backend = TorchNativeAttnBackend(self)
+         elif self.server_args.attention_backend == "flashinfer_mla":
+             self.attn_backend = FlashInferMLAAttnBackend(self)
        else:
            raise ValueError(
                f"Invalid attention backend: {self.server_args.attention_backend}"
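
The new branch is taken once server_args.attention_backend has been rewritten to "flashinfer_mla" by the enable_flashinfer_mla adjustment shown earlier. A hedged sketch of turning it on through the offline engine, assuming the keyword argument is forwarded into ServerArgs as other engine kwargs are (model name is illustrative):

import sglang as sgl

# Assumption: Engine forwards kwargs to ServerArgs, so enable_flashinfer_mla=True
# triggers the "flashinfer_mla" backend selection shown above.
llm = sgl.Engine(
    model_path="deepseek-ai/DeepSeek-V2-Lite",  # an MLA-style model, for illustration
    enable_flashinfer_mla=True,
)
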
@@ -737,9 +822,16 @@ class ModelRunner:

            return

        tic = time.time()
-
+         before_mem = get_available_gpu_memory(self.device, self.gpu_id)
+         logger.info(
+             f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
+         )
        self.cuda_graph_runner = CudaGraphRunner(self)
-
+         after_mem = get_available_gpu_memory(self.device, self.gpu_id)
+         logger.info(
+             f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s. "
+             f"avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
+         )

    def apply_torch_tp(self):
        logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")

@@ -754,8 +846,12 @@ class ModelRunner:

            forward_batch.input_ids, forward_batch.positions, forward_batch
        )

-     def forward_extend(
-         self
+     def forward_extend(
+         self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
+     ):
+         if not skip_attn_backend_init:
+             self.attn_backend.init_forward_metadata(forward_batch)
+
        if self.is_generation:
            if forward_batch.input_embeds is None:
                return self.model.forward(

@@ -799,11 +895,10 @@ class ModelRunner:

        else:
            raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")

-     def
-         self, logits_output: LogitsProcessorOutput,
-     )
+     def _preprocess_logits(
+         self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
+     ):
        # Apply logit bias
-         sampling_info = forward_batch.sampling_info
        if sampling_info.sampling_info_done:
            # Overlap mode: the function update_regex_vocab_mask was executed
            # in process_batch_result of the last batch.

@@ -812,15 +907,77 @@ class ModelRunner:

        else:
            # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
            sampling_info.update_regex_vocab_mask()
-             sampling_info.update_penalties()
            sampling_info.apply_logits_bias(logits_output.next_token_logits)

+     def update_output_logprobs(
+         self,
+         logits_output: LogitsProcessorOutput,
+         sampling_info: SamplingBatchInfo,
+         top_logprobs_nums: List[int],
+         token_ids_logprobs: List[int],
+         next_token_ids: torch.Tensor,
+         *,
+         num_tokens_per_req: List[int],
+     ):
+         """Update the logits_output's output logprob based on next_token_ids
+
+         Args:
+             logits_output: The logits output from the model forward
+             sampling_info: Sampling info for logprob calculation
+             top_logprobs_nums: Number of logprobs per request.
+             next_token_ids: Next token ids.
+             num_tokens_per_req: The number of tokens per request.
+
+         Returns:
+             A list of next_token_ids
+         """
+         self._preprocess_logits(logits_output, sampling_info)
+         # We should repeat top_logprobs_nums to match num_tokens_per_req.
+         top_logprobs_nums_repeat_interleaved = []
+         token_ids_logprobs_repeat_interleaved = []
+         for num, num_tokens in zip(top_logprobs_nums, num_tokens_per_req):
+             top_logprobs_nums_repeat_interleaved.extend([num] * num_tokens)
+         for token_ids, num_tokens in zip(token_ids_logprobs, num_tokens_per_req):
+             token_ids_logprobs_repeat_interleaved.extend([token_ids] * num_tokens)
+         self.sampler(
+             logits_output,
+             sampling_info,
+             True,
+             top_logprobs_nums_repeat_interleaved,
+             token_ids_logprobs_repeat_interleaved,
+             batch_next_token_ids=next_token_ids,
+         )
+
+     def sample(
+         self,
+         logits_output: LogitsProcessorOutput,
+         forward_batch: ForwardBatch,
+     ) -> torch.Tensor:
+         """Sample and compute logprobs and update logits_output.
+
+         Args:
+             logits_output: The logits output from the model forward
+             forward_batch: The forward batch that generates logits_output
+
+         Returns:
+             A list of next_token_ids
+         """
+         # For duplex models with multiple output streams.
+         if isinstance(logits_output, tuple):
+             return torch.stack(
+                 [self.sample(values, forward_batch) for values in logits_output],
+                 axis=-1,
+             )
+
+         self._preprocess_logits(logits_output, forward_batch.sampling_info)
+
        # Sample the next tokens
        next_token_ids = self.sampler(
            logits_output,
-             sampling_info,
+             forward_batch.sampling_info,
            forward_batch.return_logprob,
            forward_batch.top_logprobs_nums,
+             forward_batch.token_ids_logprobs,
        )
        return next_token_ids

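
To make the repeat-interleave step in update_output_logprobs concrete, here is a toy expansion (values invented for illustration): each per-request setting is duplicated once for every token that request contributed to the batch.

top_logprobs_nums = [5, 0]        # per request
token_ids_logprobs = [[42], []]   # per request
num_tokens_per_req = [3, 1]       # tokens contributed by each request

expanded_top, expanded_ids = [], []
for num, n_tok in zip(top_logprobs_nums, num_tokens_per_req):
    expanded_top.extend([num] * n_tok)
for ids, n_tok in zip(token_ids_logprobs, num_tokens_per_req):
    expanded_ids.extend([ids] * n_tok)

# expanded_top -> [5, 5, 5, 0]
# expanded_ids -> [[42], [42], [42], []]
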
@@ -832,3 +989,26 @@ class ModelRunner:

        if rope_scaling is None:
            return False
        return rope_scaling.get("type", None) == "mrope"
+
+
+ def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
+     params_dict = dict(model.named_parameters())
+     for name, tensor in named_tensors:
+         default_weight_loader(params_dict[name], tensor)
+
+
+ def _unwrap_tensor(tensor, tp_rank):
+     if isinstance(tensor, LocalSerializedTensor):
+         return tensor.get(tp_rank)
+     return tensor
+
+
+ @dataclass
+ class LocalSerializedTensor:
+     """torch.Tensor that gets serialized by MultiprocessingSerializer (which only serializes a pointer and not the data).
+     The i-th element in the list corresponds to i-th rank's GPU."""
+
+     values: List[bytes]
+
+     def get(self, rank: int):
+         return MultiprocessingSerializer.deserialize(self.values[rank])
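
A brief sketch of how these helpers compose (hedged: per_rank_payloads below is hypothetical, and we assume MultiprocessingSerializer offers a serialize counterpart on the sending side to the deserialize call shown above): every TP rank receives the same LocalSerializedTensor, and _unwrap_tensor picks out and deserializes only that rank's entry before the weights are loaded.

# Hypothetical payload produced on the sending side, one serialized blob per TP rank.
payload = LocalSerializedTensor(values=per_rank_payloads)  # List[bytes]

named_tensors = [("model.embed_tokens.weight", payload)]  # name is a placeholder
# Inside update_weights_from_tensor(), rank r resolves payload.get(r) via
# MultiprocessingSerializer.deserialize(values[r]) and loads the resulting tensor.
model_runner.update_weights_from_tensor(named_tensors)
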
sglang/srt/model_loader/loader.py

@@ -11,7 +11,7 @@ import math

import os
from abc import ABC, abstractmethod
from contextlib import contextmanager
- from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple,
+ from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast

import gguf
import huggingface_hub

@@ -19,7 +19,7 @@ import numpy as np

import torch
from huggingface_hub import HfApi, hf_hub_download
from torch import nn
- from transformers import AutoModelForCausalLM
+ from transformers import AutoModelForCausalLM
from transformers.utils import SAFE_WEIGHTS_INDEX_NAME

from sglang.srt.configs.device_config import DeviceConfig

@@ -197,7 +197,7 @@ class DefaultModelLoader(BaseModelLoader):


        Returns the path to the downloaded model, or None if the model is not
        downloaded from ModelScope."""
-         if "SGLANG_USE_MODELSCOPE"
+         if os.environ.get("SGLANG_USE_MODELSCOPE", None) == "True":
            # download model from ModelScope hub,
            # lazy import so that modelscope is not required for normal use.
            # pylint: disable=C.