sglang 0.4.3.post2__py3-none-any.whl → 0.4.3.post4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/api.py +1 -1
- sglang/bench_offline_throughput.py +19 -0
- sglang/bench_one_batch.py +2 -2
- sglang/bench_serving.py +123 -79
- sglang/global_config.py +8 -3
- sglang/lang/backend/runtime_endpoint.py +1 -1
- sglang/lang/ir.py +1 -1
- sglang/srt/_custom_ops.py +83 -91
- sglang/srt/configs/load_config.py +4 -1
- sglang/srt/configs/model_config.py +48 -2
- sglang/srt/configs/qwen2_5_vl_config.py +5 -2
- sglang/srt/constrained/base_grammar_backend.py +117 -15
- sglang/srt/constrained/llguidance_backend.py +151 -0
- sglang/srt/constrained/outlines_backend.py +24 -33
- sglang/srt/constrained/xgrammar_backend.py +69 -38
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +225 -80
- sglang/srt/distributed/parallel_state.py +48 -3
- sglang/srt/entrypoints/engine.py +67 -9
- sglang/srt/entrypoints/http_server.py +190 -41
- sglang/srt/entrypoints/verl_engine.py +147 -0
- sglang/srt/function_call_parser.py +0 -1
- sglang/srt/layers/activation.py +11 -0
- sglang/srt/layers/attention/{__init__.py → base_attn_backend.py} +14 -6
- sglang/srt/layers/attention/double_sparsity_backend.py +1 -1
- sglang/srt/layers/attention/flashinfer_backend.py +302 -414
- sglang/srt/layers/attention/flashinfer_mla_backend.py +582 -0
- sglang/srt/layers/attention/torch_native_backend.py +1 -1
- sglang/srt/layers/attention/triton_backend.py +13 -8
- sglang/srt/layers/attention/triton_ops/decode_attention.py +3 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +20 -4
- sglang/srt/layers/attention/triton_ops/rocm_mla_decode_rope.py +439 -0
- sglang/srt/layers/attention/utils.py +39 -0
- sglang/srt/layers/attention/vision.py +60 -63
- sglang/srt/layers/dp_attention.py +142 -1
- sglang/srt/layers/layernorm.py +1 -1
- sglang/srt/layers/linear.py +3 -1
- sglang/srt/layers/logits_processor.py +281 -45
- sglang/srt/layers/moe/ep_moe/kernels.py +126 -8
- sglang/srt/layers/moe/ep_moe/layer.py +140 -28
- sglang/srt/layers/moe/fused_moe_native.py +2 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +50 -50
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=14336,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=1792,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=3584,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=4096,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +16 -16
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI300X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Instinct_MI325X.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=7168,device_name=AMD_Radeon_Graphics.json +18 -18
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/configs/E=8,N=8192,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8.json +15 -15
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +88 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +34 -13
- sglang/srt/layers/moe/topk.py +13 -4
- sglang/srt/layers/quantization/__init__.py +111 -7
- sglang/srt/layers/quantization/blockwise_int8.py +409 -0
- sglang/srt/layers/quantization/configs/N=1536,K=1536,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2048,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=2304,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=1536,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=24576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=256,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=32768,K=512,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1024,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=1152,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=128,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Instinct_MI325X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=AMD_Radeon_Graphics,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=16384,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/configs/N=7168,K=18432,device_name=NVIDIA_A100-SXM4-80GB,dtype=int8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/quantization/fp8.py +69 -28
- sglang/srt/layers/quantization/fp8_utils.py +17 -1
- sglang/srt/layers/quantization/gptq.py +416 -0
- sglang/srt/layers/quantization/int8_kernel.py +327 -0
- sglang/srt/layers/quantization/int8_utils.py +73 -0
- sglang/srt/layers/quantization/modelopt_quant.py +18 -1
- sglang/srt/layers/radix_attention.py +1 -0
- sglang/srt/layers/rotary_embedding.py +0 -1
- sglang/srt/layers/sampler.py +76 -31
- sglang/srt/layers/vocab_parallel_embedding.py +14 -13
- sglang/srt/lora/lora.py +17 -1
- sglang/srt/lora/lora_config.py +5 -0
- sglang/srt/lora/lora_manager.py +1 -3
- sglang/srt/managers/cache_controller.py +193 -62
- sglang/srt/managers/configure_logging.py +2 -1
- sglang/srt/managers/data_parallel_controller.py +6 -2
- sglang/srt/managers/detokenizer_manager.py +124 -102
- sglang/srt/managers/image_processor.py +2 -1
- sglang/srt/managers/io_struct.py +144 -6
- sglang/srt/managers/schedule_batch.py +237 -197
- sglang/srt/managers/schedule_policy.py +29 -29
- sglang/srt/managers/scheduler.py +773 -334
- sglang/srt/managers/session_controller.py +6 -2
- sglang/srt/managers/tokenizer_manager.py +225 -68
- sglang/srt/managers/tp_worker.py +15 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +3 -4
- sglang/srt/mem_cache/chunk_cache.py +18 -11
- sglang/srt/mem_cache/hiradix_cache.py +394 -0
- sglang/srt/mem_cache/memory_pool.py +68 -37
- sglang/srt/mem_cache/radix_cache.py +58 -47
- sglang/srt/metrics/collector.py +102 -36
- sglang/srt/model_executor/cuda_graph_runner.py +56 -31
- sglang/srt/model_executor/forward_batch_info.py +49 -16
- sglang/srt/model_executor/model_runner.py +280 -81
- sglang/srt/model_loader/loader.py +3 -3
- sglang/srt/model_loader/weight_utils.py +36 -14
- sglang/srt/models/baichuan.py +31 -6
- sglang/srt/models/chatglm.py +39 -7
- sglang/srt/models/commandr.py +29 -5
- sglang/srt/models/dbrx.py +31 -5
- sglang/srt/models/deepseek.py +43 -6
- sglang/srt/models/deepseek_nextn.py +32 -19
- sglang/srt/models/deepseek_v2.py +265 -32
- sglang/srt/models/exaone.py +19 -9
- sglang/srt/models/gemma.py +22 -8
- sglang/srt/models/gemma2.py +25 -12
- sglang/srt/models/gemma2_reward.py +5 -1
- sglang/srt/models/gpt2.py +28 -13
- sglang/srt/models/gpt_bigcode.py +27 -5
- sglang/srt/models/granite.py +21 -9
- sglang/srt/models/grok.py +21 -4
- sglang/srt/models/internlm2.py +36 -6
- sglang/srt/models/internlm2_reward.py +5 -1
- sglang/srt/models/llama.py +26 -9
- sglang/srt/models/llama_classification.py +5 -1
- sglang/srt/models/llama_eagle.py +17 -4
- sglang/srt/models/llama_embedding.py +5 -1
- sglang/srt/models/llama_reward.py +7 -2
- sglang/srt/models/llava.py +19 -3
- sglang/srt/models/llavavid.py +10 -1
- sglang/srt/models/minicpm.py +26 -2
- sglang/srt/models/minicpm3.py +39 -3
- sglang/srt/models/minicpmv.py +45 -14
- sglang/srt/models/mixtral.py +20 -9
- sglang/srt/models/mixtral_quant.py +50 -8
- sglang/srt/models/mllama.py +57 -11
- sglang/srt/models/olmo.py +34 -6
- sglang/srt/models/olmo2.py +34 -13
- sglang/srt/models/olmoe.py +26 -4
- sglang/srt/models/phi3_small.py +29 -10
- sglang/srt/models/qwen.py +26 -3
- sglang/srt/models/qwen2.py +26 -4
- sglang/srt/models/qwen2_5_vl.py +46 -8
- sglang/srt/models/qwen2_eagle.py +17 -5
- sglang/srt/models/qwen2_moe.py +44 -6
- sglang/srt/models/qwen2_rm.py +78 -0
- sglang/srt/models/qwen2_vl.py +39 -8
- sglang/srt/models/stablelm.py +32 -5
- sglang/srt/models/torch_native_llama.py +5 -2
- sglang/srt/models/xverse.py +21 -9
- sglang/srt/models/xverse_moe.py +45 -7
- sglang/srt/models/yivl.py +2 -1
- sglang/srt/openai_api/adapter.py +109 -24
- sglang/srt/openai_api/protocol.py +17 -1
- sglang/srt/reasoning_parser.py +154 -0
- sglang/srt/sampling/penaltylib/__init__.py +4 -6
- sglang/srt/sampling/penaltylib/frequency_penalty.py +66 -0
- sglang/srt/sampling/penaltylib/{penalizers/min_new_tokens.py → min_new_tokens.py} +15 -23
- sglang/srt/sampling/penaltylib/orchestrator.py +39 -188
- sglang/srt/sampling/penaltylib/presence_penalty.py +66 -0
- sglang/srt/sampling/sampling_batch_info.py +79 -157
- sglang/srt/sampling/sampling_params.py +16 -13
- sglang/srt/server_args.py +135 -60
- sglang/srt/speculative/build_eagle_tree.py +8 -9
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -12
- sglang/srt/speculative/eagle_utils.py +92 -57
- sglang/srt/speculative/eagle_worker.py +238 -111
- sglang/srt/speculative/spec_info.py +1 -13
- sglang/srt/utils.py +43 -17
- sglang/srt/warmup.py +47 -0
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/runners.py +389 -126
- sglang/test/send_one.py +88 -0
- sglang/test/test_block_fp8_ep.py +361 -0
- sglang/test/test_programs.py +1 -1
- sglang/test/test_utils.py +138 -84
- sglang/utils.py +50 -60
- sglang/version.py +1 -1
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/METADATA +22 -15
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/RECORD +200 -166
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/WHEEL +1 -1
- sglang/bench_latency.py +0 -1
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +0 -75
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +0 -74
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +0 -85
- sglang/test/srt/sampling/penaltylib/utils.py +0 -344
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/LICENSE +0 -0
- {sglang-0.4.3.post2.dist-info → sglang-0.4.3.post4.dist-info}/top_level.txt +0 -0
sglang/srt/model_executor/model_runner.py

@@ -13,11 +13,14 @@
 # ==============================================================================
 """ModelRunner runs the forward passes of the models."""

+import datetime
 import gc
 import json
 import logging
+import os
 import time
-from
+from dataclasses import dataclass
+from typing import List, Optional, Tuple, Union

 import torch
 import torch.distributed as dist
@@ -34,6 +37,7 @@ from sglang.srt.distributed import (
 from sglang.srt.distributed.parallel_state import monkey_patch_vllm_parallel_state
 from sglang.srt.layers.attention.double_sparsity_backend import DoubleSparseAttnBackend
 from sglang.srt.layers.attention.flashinfer_backend import FlashInferAttnBackend
+from sglang.srt.layers.attention.flashinfer_mla_backend import FlashInferMLAAttnBackend
 from sglang.srt.layers.attention.torch_native_backend import TorchNativeAttnBackend
 from sglang.srt.layers.attention.triton_backend import TritonAttnBackend
 from sglang.srt.layers.dp_attention import (
@@ -51,14 +55,18 @@ from sglang.srt.mem_cache.memory_pool import (
     MHATokenToKVPool,
     MLATokenToKVPool,
     ReqToTokenPool,
+    TokenToKVPoolAllocator,
 )
 from sglang.srt.model_executor.cuda_graph_runner import CudaGraphRunner
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader import get_model
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.sampling.sampling_batch_info import SamplingBatchInfo
 from sglang.srt.server_args import ServerArgs
 from sglang.srt.speculative.spec_info import SpeculativeAlgorithm
 from sglang.srt.torch_memory_saver_adapter import TorchMemorySaverAdapter
 from sglang.srt.utils import (
+    MultiprocessingSerializer,
     enable_show_time_cost,
     get_available_gpu_memory,
     init_custom_process_group,
@@ -69,10 +77,15 @@ from sglang.srt.utils import (
     set_cpu_offload_max_bytes,
     set_cuda_arch,
 )
+from sglang.utils import get_exception_traceback

 logger = logging.getLogger(__name__)


+SGLANG_CI_SMALL_KV_SIZE = os.getenv("SGLANG_CI_SMALL_KV_SIZE", None)
+UNBALANCED_MODEL_LOADING_TIMEOUT_S = 300
+
+
 class ModelRunner:
     """ModelRunner runs the forward passes of the models."""

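The two new module-level knobs above are worth noting: SGLANG_CI_SMALL_KV_SIZE is read from the environment once at import time and, as a later hunk shows, caps max_total_num_tokens for CI runs, while UNBALANCED_MODEL_LOADING_TIMEOUT_S bounds the post-load barrier. A minimal sketch of using the CI knob; the value 2048 is an arbitrary example, not a default:

```python
import os

# Must be set before sglang.srt.model_executor.model_runner is imported,
# because the module reads it once via os.getenv at import time.
os.environ["SGLANG_CI_SMALL_KV_SIZE"] = "2048"

import sglang  # noqa: E402  (imported after the env var on purpose)
```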
@@ -86,6 +99,8 @@ class ModelRunner:
         nccl_port: int,
         server_args: ServerArgs,
         is_draft_worker: bool = False,
+        req_to_token_pool: Optional[ReqToTokenPool] = None,
+        token_to_kv_pool_allocator: Optional[TokenToKVPoolAllocator] = None,
     ):
         # Parse args
         self.model_config = model_config
@@ -103,68 +118,21 @@ class ModelRunner:
         self.spec_algorithm = SpeculativeAlgorithm.from_string(
             server_args.speculative_algorithm
         )
+        self.req_to_token_pool = req_to_token_pool
+        self.token_to_kv_pool_allocator = token_to_kv_pool_allocator

         # Model-specific adjustment
-        if (
-            self.model_config.attention_arch == AttentionArch.MLA
-            and not self.server_args.disable_mla
-        ):
-            # TODO: add MLA optimization on CPU
-            if self.server_args.device != "cpu":
-                if server_args.enable_flashinfer_mla:
-                    logger.info(
-                        "FlashInfer MLA optimization is turned on. Use flashinfer backend for DeepseekV3ForCausalLM."
-                    )
-                    self.server_args.attention_backend = "flashinfer"
-                else:
-                    logger.info("MLA optimization is turned on. Use triton backend.")
-                    self.server_args.attention_backend = "triton"
-
-        if self.server_args.enable_double_sparsity:
-            logger.info(
-                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
-            )
-            self.server_args.attention_backend = "triton"
-            self.server_args.disable_cuda_graph = True
-            if self.server_args.ds_heavy_channel_type is None:
-                raise ValueError(
-                    "Please specify the heavy channel type for double sparsity optimization."
-                )
-            self.init_double_sparsity_channel_config(
-                self.server_args.ds_heavy_channel_type
-            )
+        self.model_specific_adjustment()

-        if self.is_multimodal:
-            self.mem_fraction_static *= 0.95
-            logger.info(
-                f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
-                f"because this is a multimodal model."
-            )
-
-        if self.model_config.hf_config.architectures == [
-            "MllamaForConditionalGeneration"
-        ]:
-            logger.info("Automatically turn off --chunked-prefill-size for mllama.")
-            server_args.chunked_prefill_size = -1
-
-        if self.model_config.hf_config.architectures == [
-            "Qwen2VLForConditionalGeneration"
-        ]:
-            # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
-            logger.info(
-                "Automatically turn off --chunked-prefill-size and disable radix cache for qwen2-vl."
-            )
-            server_args.chunked_prefill_size = -1
-            server_args.disable_radix_cache = True
-
-        # Global vars
         if server_args.show_time_cost:
             enable_show_time_cost()
+
         if server_args.disable_outlines_disk_cache:
             from outlines.caching import disable_cache

             disable_cache()

+        # Global vars
         global_server_args_dict.update(
             {
                 "attention_backend": server_args.attention_backend,
@@ -176,11 +144,17 @@ class ModelRunner:
                 "enable_dp_attention": server_args.enable_dp_attention,
                 "enable_ep_moe": server_args.enable_ep_moe,
                 "device": server_args.device,
+                "speculative_accept_threshold_single": server_args.speculative_accept_threshold_single,
+                "speculative_accept_threshold_acc": server_args.speculative_accept_threshold_acc,
                 "enable_flashinfer_mla": server_args.enable_flashinfer_mla,
                 "disable_radix_cache": server_args.disable_radix_cache,
+                "flashinfer_mla_disable_ragged": server_args.flashinfer_mla_disable_ragged,
+                "debug_tensor_dump_output_folder": server_args.debug_tensor_dump_output_folder,
+                "debug_tensor_dump_inject": server_args.debug_tensor_dump_inject,
             }
         )

+        # CPU offload
         set_cpu_offload_max_bytes(int(server_args.cpu_offload_gb * 1024**3))

         # Get memory before model loading
@@ -210,9 +184,11 @@ class ModelRunner:
         else:
             self.torch_tp_applied = False

-        # Init
+        # Init lora
         if server_args.lora_paths is not None:
             self.init_lora_manager()
+
+        # Init memory pool and attention backends
         self.init_memory_pool(
             min_per_gpu_memory,
             server_args.max_running_requests,
@@ -226,6 +202,59 @@ class ModelRunner:
         self.cuda_graph_runner = None
         self.init_attention_backend()

+    def model_specific_adjustment(self):
+        server_args = self.server_args
+
+        if (
+            self.model_config.attention_arch == AttentionArch.MLA
+            and not server_args.disable_mla
+        ):
+            # TODO: add MLA optimization on CPU
+            if server_args.device != "cpu":
+                if server_args.enable_flashinfer_mla:
+                    logger.info(
+                        "MLA optimization is turned on. Use flashinfer mla backend."
+                    )
+                    server_args.attention_backend = "flashinfer_mla"
+                else:
+                    logger.info("MLA optimization is turned on. Use triton backend.")
+                    server_args.attention_backend = "triton"
+
+        if server_args.enable_double_sparsity:
+            logger.info(
+                "Double sparsity optimization is turned on. Use triton backend without CUDA graph."
+            )
+            server_args.attention_backend = "triton"
+            server_args.disable_cuda_graph = True
+            if server_args.ds_heavy_channel_type is None:
+                raise ValueError(
+                    "Please specify the heavy channel type for double sparsity optimization."
+                )
+            self.init_double_sparsity_channel_config(server_args.ds_heavy_channel_type)
+
+        if self.is_multimodal:
+            self.mem_fraction_static *= 0.95
+            logger.info(
+                f"Automatically reduce --mem-fraction-static to {self.mem_fraction_static:.3f} "
+                f"because this is a multimodal model."
+            )
+
+        if self.model_config.hf_config.architectures == [
+            "MllamaForConditionalGeneration"
+        ]:
+            logger.info("Automatically turn off --chunked-prefill-size for mllama.")
+            server_args.chunked_prefill_size = -1
+
+        if self.model_config.hf_config.architectures == [
+            "Qwen2VLForConditionalGeneration"
+        ]:
+            # TODO: qwen2-vl does not support radix cache now, set disable_radix_cache=True automatically
+            logger.info(
+                "Automatically turn off --chunked-prefill-size and disable radix cache for qwen2-vl."
+            )
+            server_args.chunked_prefill_size = -1
+            server_args.disable_radix_cache = True
+
     def init_torch_distributed(self):
         logger.info("Init torch distributed begin.")

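The MLA branch above now selects a dedicated "flashinfer_mla" attention backend (handled by the new FlashInferMLAAttnBackend) instead of routing MLA models through the generic flashinfer backend. A rough sketch of how this is triggered, assuming ServerArgs is constructed directly and the model path is a placeholder for any MLA-architecture checkpoint such as DeepSeek-V3:

```python
from sglang.srt.server_args import ServerArgs

# Placeholder path; any MLA-architecture model applies.
server_args = ServerArgs(
    model_path="/models/DeepSeek-V3",
    enable_flashinfer_mla=True,
)

# During ModelRunner construction, model_specific_adjustment() rewrites
# server_args.attention_backend to "flashinfer_mla", and
# init_attention_backend() then instantiates FlashInferMLAAttnBackend.
```

In practice these fields are usually populated from the CLI (an --enable-flashinfer-mla style flag) rather than built by hand; the dataclass route just keeps the sketch self-contained.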
@@ -233,14 +262,13 @@ class ModelRunner:
         if self.device == "cuda":
             backend = "nccl"
         elif self.device == "xpu":
-
-            # Need to use xccl for xpu backend in the future
-            backend = "gloo"
+            backend = "xccl"
         elif self.device == "hpu":
             backend = "hccl"
         elif self.device == "cpu":
             backend = "gloo"

+        before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
         if not self.server_args.enable_p2p_check:
             monkey_patch_p2p_access_check()

@@ -258,6 +286,7 @@ class ModelRunner:
             rank=self.tp_rank,
             local_rank=self.gpu_id,
             distributed_init_method=dist_init_method,
+            timeout=self.server_args.dist_timeout,
         )
         initialize_model_parallel(tensor_model_parallel_size=self.tp_size)
         initialize_dp_attention(
@@ -270,20 +299,24 @@ class ModelRunner:
         min_per_gpu_memory = get_available_gpu_memory(
             self.device, self.gpu_id, distributed=self.tp_size > 1
         )
+        local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
         self.tp_group = get_tp_group()
         self.attention_tp_group = get_attention_tp_group()

         # Check memory for tensor parallelism
         if self.tp_size > 1:
-            local_gpu_memory = get_available_gpu_memory(self.device, self.gpu_id)
             if min_per_gpu_memory < local_gpu_memory * 0.9:
                 raise ValueError(
                     "The memory capacity is unbalanced. Some GPUs may be occupied by other processes."
                 )

+        logger.info(
+            f"Init torch distributed ends. mem usage={(before_avail_memory - local_gpu_memory):.2f} GB"
+        )
         return min_per_gpu_memory

     def load_model(self):
+        before_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
             f"Load weight begin. avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
         )
@@ -353,13 +386,27 @@ class ModelRunner:
         )
         self.dtype = self.model_config.dtype

+        after_avail_memory = get_available_gpu_memory(self.device, self.gpu_id)
         logger.info(
             f"Load weight end. "
             f"type={type(self.model).__name__}, "
             f"dtype={self.dtype}, "
-            f"avail mem={
+            f"avail mem={after_avail_memory:.2f} GB, "
+            f"mem usage={(before_avail_memory - after_avail_memory):.2f} GB."
         )

+        # Handle the case where some ranks do not finish loading.
+        try:
+            dist.monitored_barrier(
+                group=get_tp_group().cpu_group,
+                timeout=datetime.timedelta(seconds=UNBALANCED_MODEL_LOADING_TIMEOUT_S),
+                wait_all_ranks=True,
+            )
+        except RuntimeError:
+            raise ValueError(
+                f"TP rank {self.tp_rank} could finish the model loading, but there are other ranks that didn't finish loading. It is likely due to unexpected failures (e.g., OOM) or a slow node."
+            ) from None
+
     def update_weights_from_disk(
         self, model_path: str, load_format: str
     ) -> tuple[bool, str]:
@@ -512,8 +559,21 @@ class ModelRunner:
         logger.error(error_msg)
         return False, error_msg

-    def update_weights_from_tensor(
-        self
+    def update_weights_from_tensor(
+        self,
+        named_tensors: List[Tuple[str, Union[torch.Tensor, "LocalSerializedTensor"]]],
+        load_format: Optional[str] = None,
+    ):
+        named_tensors = [
+            (name, _unwrap_tensor(tensor, tp_rank=self.tp_rank))
+            for name, tensor in named_tensors
+        ]
+        if load_format == "direct":
+            _model_load_weights_direct(self.model, named_tensors)
+        elif load_format is None:
+            self.model.load_weights(named_tensors)
+        else:
+            raise NotImplementedError(f"Unknown load_format={load_format}")
         return True, "Success"

     def get_weights_by_name(
@@ -606,15 +666,31 @@ class ModelRunner:
             4096,
         )

+        if SGLANG_CI_SMALL_KV_SIZE:
+            self.max_total_num_tokens = int(SGLANG_CI_SMALL_KV_SIZE)
+
         if not self.spec_algorithm.is_none():
             if self.is_draft_worker:
                 self.max_total_num_tokens = self.server_args.draft_runner_cache_size
+                max_num_reqs = self.server_args.max_num_reqs
             else:
+                # We are sharing the `token_to_kv_pool`, and both verify and draft tokens
+                # can be concurrently allocated, so we should give a headroom for it.
                 self.server_args.draft_runner_cache_size = (
                     self.max_total_num_tokens
-
+                    # draft
+                    + max_num_reqs
+                    * self.server_args.speculative_num_steps
+                    * self.server_args.speculative_eagle_topk
+                    # verify
+                    + max_num_reqs * self.server_args.speculative_num_draft_tokens
+                    # buffer
                     + 100
                 )
+                # Target worker and draft worker shares the same indices for the
+                # token_to_kv_pool, so we should make sure to match max_total_num_tokens.
+                self.max_total_num_tokens = self.server_args.draft_runner_cache_size
+                self.server_args.max_num_reqs = max_num_reqs

         if max_total_tokens is not None:
             if max_total_tokens > self.max_total_num_tokens:
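For speculative decoding, the target worker now reserves explicit headroom in the shared KV pool for both draft and verify tokens. A worked example of the draft_runner_cache_size arithmetic from the hunk above, with illustrative (not default) settings:

```python
# Assumed, illustrative values; only the formula itself comes from the diff.
max_total_num_tokens = 500_000     # KV-token capacity of the target worker
max_num_reqs = 256                 # server_args.max_num_reqs
speculative_num_steps = 5          # draft steps per verify round
speculative_eagle_topk = 4         # draft branches kept per step
speculative_num_draft_tokens = 8   # tokens sent to the target model for verification

draft_runner_cache_size = (
    max_total_num_tokens
    + max_num_reqs * speculative_num_steps * speculative_eagle_topk  # draft headroom
    + max_num_reqs * speculative_num_draft_tokens                    # verify headroom
    + 100                                                            # buffer
)
print(draft_runner_cache_size)  # 507268
```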
@@ -630,12 +706,17 @@ class ModelRunner:
                 "Not enough memory. Please try to increase --mem-fraction-static."
             )

-        self.req_to_token_pool = ReqToTokenPool(
-            size=max_num_reqs + 1,
-            max_context_len=self.model_config.context_len + 4,
-            device=self.device,
-            enable_memory_saver=self.server_args.enable_memory_saver,
-        )
+        if self.req_to_token_pool is None:
+            self.req_to_token_pool = ReqToTokenPool(
+                size=max_num_reqs + 1,
+                max_context_len=self.model_config.context_len + 4,
+                device=self.device,
+                enable_memory_saver=self.server_args.enable_memory_saver,
+            )
+        else:
+            # Draft worker shares req_to_token_pool with the target worker.
+            assert self.is_draft_worker
+
         if (
             self.model_config.attention_arch == AttentionArch.MLA
             and not self.server_args.disable_mla
@@ -670,6 +751,17 @@ class ModelRunner:
                 device=self.device,
                 enable_memory_saver=self.server_args.enable_memory_saver,
             )
+
+        if self.token_to_kv_pool_allocator is None:
+            self.token_to_kv_pool_allocator = TokenToKVPoolAllocator(
+                self.max_total_num_tokens,
+                dtype=self.kv_cache_dtype,
+                device=self.device,
+                kvcache=self.token_to_kv_pool,
+            )
+        else:
+            assert self.is_draft_worker
+
         logger.info(
             f"Memory pool end. "
             f"avail mem={get_available_gpu_memory(self.device, self.gpu_id):.2f} GB"
@@ -687,6 +779,10 @@ class ModelRunner:
     def init_attention_backend(self):
         """Init attention kernel backend."""
         if self.server_args.attention_backend == "flashinfer":
+            # Init streams
+            if self.server_args.speculative_algorithm == "EAGLE":
+                self.plan_stream_for_flashinfer = torch.cuda.Stream()
+
             self.attn_backend = FlashInferAttnBackend(self)
         elif self.server_args.attention_backend == "triton":
             assert self.sliding_window_size is None, (
@@ -703,6 +799,8 @@ class ModelRunner:
             self.attn_backend = TritonAttnBackend(self)
         elif self.server_args.attention_backend == "torch_native":
             self.attn_backend = TorchNativeAttnBackend(self)
+        elif self.server_args.attention_backend == "flashinfer_mla":
+            self.attn_backend = FlashInferMLAAttnBackend(self)
         else:
             raise ValueError(
                 f"Invalid attention backend: {self.server_args.attention_backend}"
@@ -737,9 +835,16 @@ class ModelRunner:
             return

         tic = time.time()
-
+        before_mem = get_available_gpu_memory(self.device, self.gpu_id)
+        logger.info(
+            f"Capture cuda graph begin. This can take up to several minutes. avail mem={before_mem:.2f} GB"
+        )
         self.cuda_graph_runner = CudaGraphRunner(self)
-
+        after_mem = get_available_gpu_memory(self.device, self.gpu_id)
+        logger.info(
+            f"Capture cuda graph end. Time elapsed: {time.time() - tic:.2f} s. "
+            f"avail mem={after_mem:.2f} GB. mem usage={(before_mem - after_mem):.2f} GB."
+        )

     def apply_torch_tp(self):
         logger.info(f"Enabling torch tensor parallelism on {self.tp_size} devices.")
@@ -754,8 +859,12 @@ class ModelRunner:
             forward_batch.input_ids, forward_batch.positions, forward_batch
         )

-    def forward_extend(
-        self
+    def forward_extend(
+        self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
+    ):
+        if not skip_attn_backend_init:
+            self.attn_backend.init_forward_metadata(forward_batch)
+
         if self.is_generation:
             if forward_batch.input_embeds is None:
                 return self.model.forward(
@@ -782,28 +891,33 @@ class ModelRunner:
             forward_batch.input_ids, forward_batch.positions, forward_batch
         )

-    def forward(
+    def forward(
+        self, forward_batch: ForwardBatch, skip_attn_backend_init: bool = False
+    ) -> LogitsProcessorOutput:
         if (
             forward_batch.forward_mode.is_cuda_graph()
             and self.cuda_graph_runner
             and self.cuda_graph_runner.can_run(forward_batch)
         ):
-            return self.cuda_graph_runner.replay(
+            return self.cuda_graph_runner.replay(
+                forward_batch, skip_attn_backend_init=skip_attn_backend_init
+            )

         if forward_batch.forward_mode.is_decode():
             return self.forward_decode(forward_batch)
         elif forward_batch.forward_mode.is_extend():
-            return self.forward_extend(
+            return self.forward_extend(
+                forward_batch, skip_attn_backend_init=skip_attn_backend_init
+            )
         elif forward_batch.forward_mode.is_idle():
             return self.forward_idle(forward_batch)
         else:
             raise ValueError(f"Invalid forward mode: {forward_batch.forward_mode}")

-    def
-        self, logits_output: LogitsProcessorOutput,
-    )
+    def _preprocess_logits(
+        self, logits_output: LogitsProcessorOutput, sampling_info: SamplingBatchInfo
+    ):
         # Apply logit bias
-        sampling_info = forward_batch.sampling_info
         if sampling_info.sampling_info_done:
             # Overlap mode: the function update_regex_vocab_mask was executed
             # in process_batch_result of the last batch.
@@ -812,15 +926,77 @@ class ModelRunner:
         else:
             # Normal mode: Put CPU-heavy tasks here. They will be overlapped with the forward pass.
             sampling_info.update_regex_vocab_mask()
-            sampling_info.update_penalties()
             sampling_info.apply_logits_bias(logits_output.next_token_logits)

+    def update_output_logprobs(
+        self,
+        logits_output: LogitsProcessorOutput,
+        sampling_info: SamplingBatchInfo,
+        top_logprobs_nums: List[int],
+        token_ids_logprobs: List[int],
+        next_token_ids: torch.Tensor,
+        *,
+        num_tokens_per_req: List[int],
+    ):
+        """Update the logits_output's output logprob based on next_token_ids
+
+        Args:
+            logits_output: The logits output from the model forward
+            sampling_info: Sampling info for logprob calculation
+            top_logprobs_nums: Number of logprobs per request.
+            next_token_ids: Next token ids.
+            num_tokens_per_req: The number of tokens per request.
+
+        Returns:
+            A list of next_token_ids
+        """
+        self._preprocess_logits(logits_output, sampling_info)
+        # We should repeat top_logprobs_nums to match num_tokens_per_req.
+        top_logprobs_nums_repeat_interleaved = []
+        token_ids_logprobs_repeat_interleaved = []
+        for num, num_tokens in zip(top_logprobs_nums, num_tokens_per_req):
+            top_logprobs_nums_repeat_interleaved.extend([num] * num_tokens)
+        for token_ids, num_tokens in zip(token_ids_logprobs, num_tokens_per_req):
+            token_ids_logprobs_repeat_interleaved.extend([token_ids] * num_tokens)
+        self.sampler(
+            logits_output,
+            sampling_info,
+            True,
+            top_logprobs_nums_repeat_interleaved,
+            token_ids_logprobs_repeat_interleaved,
+            batch_next_token_ids=next_token_ids,
+        )
+
+    def sample(
+        self,
+        logits_output: LogitsProcessorOutput,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        """Sample and compute logprobs and update logits_output.
+
+        Args:
+            logits_output: The logits output from the model forward
+            forward_batch: The forward batch that generates logits_output
+
+        Returns:
+            A list of next_token_ids
+        """
+        # For duplex models with multiple output streams.
+        if isinstance(logits_output, tuple):
+            return torch.stack(
+                [self.sample(values, forward_batch) for values in logits_output],
+                axis=-1,
+            )
+
+        self._preprocess_logits(logits_output, forward_batch.sampling_info)
+
         # Sample the next tokens
         next_token_ids = self.sampler(
             logits_output,
-            sampling_info,
+            forward_batch.sampling_info,
             forward_batch.return_logprob,
             forward_batch.top_logprobs_nums,
+            forward_batch.token_ids_logprobs,
         )
         return next_token_ids

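The new update_output_logprobs helper expands per-request logprob settings into per-token settings before calling the sampler, since one request can contribute several (draft) tokens. A small standalone illustration of the two expansion loops, with made-up values:

```python
# Mirrors the zip/extend loops in the hunk above; values are illustrative.
top_logprobs_nums = [2, 5]           # request 0 wants top-2, request 1 wants top-5
token_ids_logprobs = [[7], [9, 11]]  # token ids whose logprobs are requested
num_tokens_per_req = [1, 3]          # request 0 has 1 token, request 1 has 3

top_expanded, ids_expanded = [], []
for num, num_tokens in zip(top_logprobs_nums, num_tokens_per_req):
    top_expanded.extend([num] * num_tokens)
for token_ids, num_tokens in zip(token_ids_logprobs, num_tokens_per_req):
    ids_expanded.extend([token_ids] * num_tokens)

print(top_expanded)  # [2, 5, 5, 5]
print(ids_expanded)  # [[7], [9, 11], [9, 11], [9, 11]]
```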
@@ -832,3 +1008,26 @@ class ModelRunner:
         if rope_scaling is None:
             return False
         return rope_scaling.get("type", None) == "mrope"
+
+
+def _model_load_weights_direct(model, named_tensors: List[Tuple[str, torch.Tensor]]):
+    params_dict = dict(model.named_parameters())
+    for name, tensor in named_tensors:
+        default_weight_loader(params_dict[name], tensor)
+
+
+def _unwrap_tensor(tensor, tp_rank):
+    if isinstance(tensor, LocalSerializedTensor):
+        return tensor.get(tp_rank)
+    return tensor
+
+
+@dataclass
+class LocalSerializedTensor:
+    """torch.Tensor that gets serialized by MultiprocessingSerializer (which only serializes a pointer and not the data).
+    The i-th element in the list corresponds to i-th rank's GPU."""
+
+    values: List[bytes]
+
+    def get(self, rank: int):
+        return MultiprocessingSerializer.deserialize(self.values[rank])
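Together with update_weights_from_tensor earlier in this diff, these helpers let another process hand weights to each TP rank either by value or by shared handle. The sketch below is illustrative only: the weight name is hypothetical, and it assumes MultiprocessingSerializer exposes a serialize() counterpart to the deserialize() call shown above.

```python
import torch

from sglang.srt.model_executor.model_runner import LocalSerializedTensor, ModelRunner
from sglang.srt.utils import MultiprocessingSerializer


def push_one_weight(runner: ModelRunner, tp_size: int) -> None:
    # Wrap one serialized handle per TP rank; LocalSerializedTensor.get(rank)
    # picks the matching entry inside each rank's ModelRunner.
    weight = torch.zeros(16, 16, device="cuda")
    wrapped = LocalSerializedTensor(
        values=[MultiprocessingSerializer.serialize(weight) for _ in range(tp_size)]
    )
    runner.update_weights_from_tensor(
        # Hypothetical parameter name, for illustration only.
        named_tensors=[("model.layers.0.mlp.gate_proj.weight", wrapped)],
        load_format=None,  # default path: runner.model.load_weights(...)
    )
```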
sglang/srt/model_loader/loader.py

@@ -11,7 +11,7 @@ import math
 import os
 from abc import ABC, abstractmethod
 from contextlib import contextmanager
-from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple,
+from typing import Any, Dict, Generator, Iterable, List, Optional, Tuple, cast

 import gguf
 import huggingface_hub
@@ -19,7 +19,7 @@ import numpy as np
 import torch
 from huggingface_hub import HfApi, hf_hub_download
 from torch import nn
-from transformers import AutoModelForCausalLM
+from transformers import AutoModelForCausalLM
 from transformers.utils import SAFE_WEIGHTS_INDEX_NAME

 from sglang.srt.configs.device_config import DeviceConfig
@@ -197,7 +197,7 @@ class DefaultModelLoader(BaseModelLoader):

         Returns the path to the downloaded model, or None if the model is not
         downloaded from ModelScope."""
-        if "SGLANG_USE_MODELSCOPE"
+        if os.environ.get("SGLANG_USE_MODELSCOPE", None) == "True":
            # download model from ModelScope hub,
            # lazy import so that modelscope is not required for normal use.
            # pylint: disable=C.