sglang 0.4.1.post6__py3-none-any.whl → 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +21 -23
- sglang/api.py +2 -7
- sglang/bench_offline_throughput.py +41 -27
- sglang/bench_one_batch.py +60 -4
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +83 -71
- sglang/lang/backend/runtime_endpoint.py +183 -4
- sglang/lang/chat_template.py +46 -4
- sglang/launch_server.py +1 -1
- sglang/srt/_custom_ops.py +80 -42
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/constrained/base_grammar_backend.py +21 -0
- sglang/srt/constrained/xgrammar_backend.py +8 -4
- sglang/srt/conversation.py +14 -1
- sglang/srt/distributed/__init__.py +3 -3
- sglang/srt/distributed/communication_op.py +2 -1
- sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +112 -42
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
- sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
- sglang/srt/distributed/device_communicators/pynccl.py +80 -1
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
- sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
- sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
- sglang/srt/distributed/parallel_state.py +1 -1
- sglang/srt/distributed/utils.py +2 -1
- sglang/srt/entrypoints/engine.py +452 -0
- sglang/srt/entrypoints/http_server.py +603 -0
- sglang/srt/function_call_parser.py +494 -0
- sglang/srt/layers/activation.py +8 -8
- sglang/srt/layers/attention/flashinfer_backend.py +10 -9
- sglang/srt/layers/attention/triton_backend.py +4 -6
- sglang/srt/layers/attention/vision.py +204 -0
- sglang/srt/layers/dp_attention.py +71 -0
- sglang/srt/layers/layernorm.py +5 -5
- sglang/srt/layers/linear.py +65 -14
- sglang/srt/layers/logits_processor.py +49 -64
- sglang/srt/layers/moe/ep_moe/layer.py +24 -16
- sglang/srt/layers/moe/fused_moe_native.py +84 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +27 -7
- sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -5
- sglang/srt/layers/parameter.py +18 -8
- sglang/srt/layers/quantization/__init__.py +20 -23
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/fp8.py +10 -4
- sglang/srt/layers/quantization/modelopt_quant.py +1 -2
- sglang/srt/layers/quantization/w8a8_int8.py +1 -1
- sglang/srt/layers/radix_attention.py +2 -2
- sglang/srt/layers/rotary_embedding.py +1184 -31
- sglang/srt/layers/sampler.py +64 -6
- sglang/srt/layers/torchao_utils.py +12 -6
- sglang/srt/layers/vocab_parallel_embedding.py +2 -2
- sglang/srt/lora/lora.py +1 -9
- sglang/srt/managers/configure_logging.py +3 -0
- sglang/srt/managers/data_parallel_controller.py +79 -72
- sglang/srt/managers/detokenizer_manager.py +24 -6
- sglang/srt/managers/image_processor.py +158 -2
- sglang/srt/managers/io_struct.py +57 -3
- sglang/srt/managers/schedule_batch.py +78 -45
- sglang/srt/managers/schedule_policy.py +26 -12
- sglang/srt/managers/scheduler.py +326 -201
- sglang/srt/managers/session_controller.py +1 -0
- sglang/srt/managers/tokenizer_manager.py +210 -121
- sglang/srt/managers/tp_worker.py +6 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
- sglang/srt/managers/utils.py +44 -0
- sglang/srt/mem_cache/memory_pool.py +10 -32
- sglang/srt/metrics/collector.py +15 -6
- sglang/srt/model_executor/cuda_graph_runner.py +26 -30
- sglang/srt/model_executor/forward_batch_info.py +5 -7
- sglang/srt/model_executor/model_runner.py +44 -19
- sglang/srt/model_loader/loader.py +83 -6
- sglang/srt/model_loader/weight_utils.py +145 -6
- sglang/srt/models/baichuan.py +6 -6
- sglang/srt/models/chatglm.py +2 -2
- sglang/srt/models/commandr.py +17 -5
- sglang/srt/models/dbrx.py +13 -5
- sglang/srt/models/deepseek.py +3 -3
- sglang/srt/models/deepseek_v2.py +11 -11
- sglang/srt/models/exaone.py +2 -2
- sglang/srt/models/gemma.py +2 -2
- sglang/srt/models/gemma2.py +15 -25
- sglang/srt/models/gpt2.py +3 -5
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/granite.py +2 -2
- sglang/srt/models/grok.py +4 -3
- sglang/srt/models/internlm2.py +2 -2
- sglang/srt/models/llama.py +7 -5
- sglang/srt/models/minicpm.py +2 -2
- sglang/srt/models/minicpm3.py +9 -9
- sglang/srt/models/minicpmv.py +1238 -0
- sglang/srt/models/mixtral.py +3 -3
- sglang/srt/models/mixtral_quant.py +3 -3
- sglang/srt/models/mllama.py +2 -2
- sglang/srt/models/olmo.py +3 -3
- sglang/srt/models/olmo2.py +4 -4
- sglang/srt/models/olmoe.py +7 -13
- sglang/srt/models/phi3_small.py +2 -2
- sglang/srt/models/qwen.py +2 -2
- sglang/srt/models/qwen2.py +41 -4
- sglang/srt/models/qwen2_moe.py +3 -3
- sglang/srt/models/qwen2_vl.py +22 -122
- sglang/srt/models/stablelm.py +2 -2
- sglang/srt/models/torch_native_llama.py +20 -7
- sglang/srt/models/xverse.py +6 -6
- sglang/srt/models/xverse_moe.py +6 -6
- sglang/srt/openai_api/adapter.py +139 -37
- sglang/srt/openai_api/protocol.py +7 -4
- sglang/srt/sampling/custom_logit_processor.py +38 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +11 -14
- sglang/srt/sampling/sampling_batch_info.py +143 -18
- sglang/srt/sampling/sampling_params.py +3 -1
- sglang/srt/server.py +4 -1090
- sglang/srt/server_args.py +77 -15
- sglang/srt/speculative/eagle_utils.py +37 -15
- sglang/srt/speculative/eagle_worker.py +11 -13
- sglang/srt/utils.py +164 -129
- sglang/test/runners.py +8 -13
- sglang/test/test_programs.py +2 -1
- sglang/test/test_utils.py +83 -22
- sglang/utils.py +12 -2
- sglang/version.py +1 -1
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/METADATA +21 -10
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/RECORD +138 -123
- sglang/launch_server_llavavid.py +0 -25
- sglang/srt/constrained/__init__.py +0 -16
- sglang/srt/distributed/device_communicators/__init__.py +0 -0
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/WHEEL +0 -0
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/top_level.txt +0 -0
sglang/srt/models/mixtral.py
CHANGED
@@ -21,12 +21,11 @@ from typing import Iterable, Optional, Tuple
|
|
21
21
|
import torch
|
22
22
|
from torch import nn
|
23
23
|
from transformers import MixtralConfig
|
24
|
-
|
24
|
+
|
25
|
+
from sglang.srt.distributed import (
|
25
26
|
get_tensor_model_parallel_world_size,
|
26
27
|
tensor_model_parallel_all_reduce,
|
27
28
|
)
|
28
|
-
from vllm.model_executor.layers.rotary_embedding import get_rope
|
29
|
-
|
30
29
|
from sglang.srt.layers.layernorm import RMSNorm
|
31
30
|
from sglang.srt.layers.linear import (
|
32
31
|
QKVParallelLinear,
|
@@ -38,6 +37,7 @@ from sglang.srt.layers.moe.ep_moe.layer import EPMoE
|
|
38
37
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
39
38
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
40
39
|
from sglang.srt.layers.radix_attention import RadixAttention
|
40
|
+
from sglang.srt.layers.rotary_embedding import get_rope
|
41
41
|
from sglang.srt.layers.vocab_parallel_embedding import (
|
42
42
|
ParallelLMHead,
|
43
43
|
VocabParallelEmbedding,
|
@@ -23,13 +23,12 @@ import torch
|
|
23
23
|
import torch.nn.functional as F
|
24
24
|
from torch import nn
|
25
25
|
from transformers import MixtralConfig
|
26
|
-
|
26
|
+
|
27
|
+
from sglang.srt.distributed import (
|
27
28
|
get_tensor_model_parallel_rank,
|
28
29
|
get_tensor_model_parallel_world_size,
|
29
30
|
tensor_model_parallel_all_reduce,
|
30
31
|
)
|
31
|
-
from vllm.model_executor.layers.rotary_embedding import get_rope
|
32
|
-
|
33
32
|
from sglang.srt.layers.layernorm import RMSNorm
|
34
33
|
from sglang.srt.layers.linear import (
|
35
34
|
QKVParallelLinear,
|
@@ -39,6 +38,7 @@ from sglang.srt.layers.linear import (
|
|
39
38
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
40
39
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
41
40
|
from sglang.srt.layers.radix_attention import RadixAttention
|
41
|
+
from sglang.srt.layers.rotary_embedding import get_rope
|
42
42
|
from sglang.srt.layers.vocab_parallel_embedding import (
|
43
43
|
ParallelLMHead,
|
44
44
|
VocabParallelEmbedding,
|
sglang/srt/models/mllama.py
CHANGED
@@ -8,14 +8,14 @@ import torch
|
|
8
8
|
import torch.nn.functional as F
|
9
9
|
import torch.utils.checkpoint
|
10
10
|
import transformers.models.mllama.configuration_mllama as config_mllama
|
11
|
-
import vllm.distributed.parallel_state as ps
|
12
11
|
from torch import nn
|
13
12
|
from transformers.modeling_outputs import BaseModelOutput, CausalLMOutputWithPast
|
14
13
|
from transformers.models.mllama.modeling_mllama import (
|
15
14
|
_prepare_aspect_ratio_attention_mask,
|
16
15
|
)
|
17
|
-
from vllm.distributed import get_tensor_model_parallel_world_size
|
18
16
|
|
17
|
+
import sglang.srt.distributed.parallel_state as ps
|
18
|
+
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
19
19
|
from sglang.srt.layers.activation import get_act_fn
|
20
20
|
from sglang.srt.layers.layernorm import RMSNorm
|
21
21
|
from sglang.srt.layers.linear import (
|
sglang/srt/models/olmo.py
CHANGED
@@ -15,14 +15,13 @@
|
|
15
15
|
# Adapted from
|
16
16
|
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/olmo.py#L1
|
17
17
|
"""Inference-only OLMo model compatible with HuggingFace weights."""
|
18
|
-
from typing import Iterable,
|
18
|
+
from typing import Iterable, Optional, Tuple
|
19
19
|
|
20
20
|
import torch
|
21
21
|
from torch import nn
|
22
22
|
from transformers import OlmoConfig
|
23
|
-
from vllm.distributed import get_tensor_model_parallel_world_size
|
24
|
-
from vllm.model_executor.layers.rotary_embedding import get_rope
|
25
23
|
|
24
|
+
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
26
25
|
from sglang.srt.layers.activation import SiluAndMul
|
27
26
|
from sglang.srt.layers.linear import (
|
28
27
|
MergedColumnParallelLinear,
|
@@ -32,6 +31,7 @@ from sglang.srt.layers.linear import (
|
|
32
31
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
33
32
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
34
33
|
from sglang.srt.layers.radix_attention import RadixAttention
|
34
|
+
from sglang.srt.layers.rotary_embedding import get_rope
|
35
35
|
from sglang.srt.layers.vocab_parallel_embedding import (
|
36
36
|
ParallelLMHead,
|
37
37
|
VocabParallelEmbedding,
|
sglang/srt/models/olmo2.py
CHANGED
@@ -21,15 +21,13 @@ from typing import Iterable, Optional, Tuple
|
|
21
21
|
import torch
|
22
22
|
from torch import nn
|
23
23
|
from transformers import PretrainedConfig
|
24
|
-
|
24
|
+
|
25
|
+
from sglang.srt.distributed import (
|
25
26
|
get_tensor_model_parallel_rank,
|
26
27
|
get_tensor_model_parallel_world_size,
|
27
28
|
split_tensor_along_last_dim,
|
28
29
|
tensor_model_parallel_all_gather,
|
29
30
|
)
|
30
|
-
from vllm.model_executor.layers.rotary_embedding import get_rope
|
31
|
-
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
32
|
-
|
33
31
|
from sglang.srt.layers.activation import SiluAndMul
|
34
32
|
from sglang.srt.layers.layernorm import RMSNorm
|
35
33
|
from sglang.srt.layers.linear import (
|
@@ -40,11 +38,13 @@ from sglang.srt.layers.linear import (
|
|
40
38
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
41
39
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
42
40
|
from sglang.srt.layers.radix_attention import RadixAttention
|
41
|
+
from sglang.srt.layers.rotary_embedding import get_rope
|
43
42
|
from sglang.srt.layers.vocab_parallel_embedding import (
|
44
43
|
ParallelLMHead,
|
45
44
|
VocabParallelEmbedding,
|
46
45
|
)
|
47
46
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
47
|
+
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
48
48
|
from sglang.srt.utils import make_layers
|
49
49
|
|
50
50
|
|
sglang/srt/models/olmoe.py
CHANGED
@@ -17,30 +17,24 @@
|
|
17
17
|
|
18
18
|
"""Inference-only OLMoE model compatible with HuggingFace weights."""
|
19
19
|
|
20
|
-
from typing import Any, Dict, Iterable,
|
20
|
+
from typing import Any, Dict, Iterable, Optional, Tuple
|
21
21
|
|
22
22
|
import torch
|
23
|
-
import torch.nn.functional as F
|
24
23
|
from torch import nn
|
25
24
|
from transformers import PretrainedConfig
|
26
|
-
|
27
|
-
|
28
|
-
|
29
|
-
|
30
|
-
from vllm.model_executor.layers.linear import (
|
31
|
-
MergedColumnParallelLinear,
|
25
|
+
|
26
|
+
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
27
|
+
from sglang.srt.layers.layernorm import RMSNorm
|
28
|
+
from sglang.srt.layers.linear import (
|
32
29
|
QKVParallelLinear,
|
33
30
|
ReplicatedLinear,
|
34
31
|
RowParallelLinear,
|
35
32
|
)
|
36
|
-
from
|
37
|
-
|
38
|
-
from sglang.srt.layers.activation import SiluAndMul
|
39
|
-
from sglang.srt.layers.layernorm import RMSNorm
|
40
|
-
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
|
33
|
+
from sglang.srt.layers.logits_processor import LogitsProcessor
|
41
34
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
42
35
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
43
36
|
from sglang.srt.layers.radix_attention import RadixAttention
|
37
|
+
from sglang.srt.layers.rotary_embedding import get_rope
|
44
38
|
from sglang.srt.layers.vocab_parallel_embedding import (
|
45
39
|
ParallelLMHead,
|
46
40
|
VocabParallelEmbedding,
|
sglang/srt/models/phi3_small.py
CHANGED
@@ -5,9 +5,8 @@ import torch
|
|
5
5
|
from torch import nn
|
6
6
|
from transformers import Phi3Config
|
7
7
|
from transformers.configuration_utils import PretrainedConfig
|
8
|
-
from vllm.distributed import get_tensor_model_parallel_world_size
|
9
|
-
from vllm.model_executor.layers.rotary_embedding import get_rope
|
10
8
|
|
9
|
+
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
11
10
|
from sglang.srt.layers.linear import (
|
12
11
|
MergedColumnParallelLinear,
|
13
12
|
QKVParallelLinear,
|
@@ -17,6 +16,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorO
|
|
17
16
|
from sglang.srt.layers.pooler import Pooler, PoolingType
|
18
17
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
19
18
|
from sglang.srt.layers.radix_attention import RadixAttention
|
19
|
+
from sglang.srt.layers.rotary_embedding import get_rope
|
20
20
|
from sglang.srt.layers.vocab_parallel_embedding import (
|
21
21
|
DEFAULT_VOCAB_PADDING_SIZE,
|
22
22
|
ParallelLMHead,
|
sglang/srt/models/qwen.py
CHANGED
@@ -20,9 +20,8 @@ from typing import Any, Dict, Iterable, Optional, Tuple
|
|
20
20
|
import torch
|
21
21
|
from torch import nn
|
22
22
|
from transformers import PretrainedConfig
|
23
|
-
from vllm.distributed import get_tensor_model_parallel_world_size
|
24
|
-
from vllm.model_executor.layers.rotary_embedding import get_rope
|
25
23
|
|
24
|
+
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
26
25
|
from sglang.srt.layers.activation import SiluAndMul
|
27
26
|
from sglang.srt.layers.layernorm import RMSNorm
|
28
27
|
from sglang.srt.layers.linear import (
|
@@ -33,6 +32,7 @@ from sglang.srt.layers.linear import (
|
|
33
32
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
34
33
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
35
34
|
from sglang.srt.layers.radix_attention import RadixAttention
|
35
|
+
from sglang.srt.layers.rotary_embedding import get_rope
|
36
36
|
from sglang.srt.layers.vocab_parallel_embedding import (
|
37
37
|
ParallelLMHead,
|
38
38
|
VocabParallelEmbedding,
|
sglang/srt/models/qwen2.py
CHANGED
@@ -20,9 +20,11 @@ from typing import Any, Dict, Iterable, Optional, Tuple
|
|
20
20
|
|
21
21
|
import torch
|
22
22
|
from torch import nn
|
23
|
-
from vllm.distributed import get_tensor_model_parallel_world_size
|
24
|
-
from vllm.model_executor.layers.rotary_embedding import get_rope
|
25
23
|
|
24
|
+
from sglang.srt.distributed import (
|
25
|
+
get_tensor_model_parallel_rank,
|
26
|
+
get_tensor_model_parallel_world_size,
|
27
|
+
)
|
26
28
|
from sglang.srt.layers.activation import SiluAndMul
|
27
29
|
from sglang.srt.layers.layernorm import RMSNorm
|
28
30
|
from sglang.srt.layers.linear import (
|
@@ -34,12 +36,16 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
|
|
34
36
|
from sglang.srt.layers.pooler import Pooler, PoolingType
|
35
37
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
36
38
|
from sglang.srt.layers.radix_attention import RadixAttention
|
39
|
+
from sglang.srt.layers.rotary_embedding import get_rope
|
37
40
|
from sglang.srt.layers.vocab_parallel_embedding import (
|
38
41
|
ParallelLMHead,
|
39
42
|
VocabParallelEmbedding,
|
40
43
|
)
|
41
44
|
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
42
|
-
from sglang.srt.model_loader.weight_utils import
|
45
|
+
from sglang.srt.model_loader.weight_utils import (
|
46
|
+
default_weight_loader,
|
47
|
+
kv_cache_scales_loader,
|
48
|
+
)
|
43
49
|
from sglang.srt.utils import make_layers
|
44
50
|
|
45
51
|
Qwen2Config = None
|
@@ -242,6 +248,9 @@ class Qwen2Model(nn.Module):
|
|
242
248
|
)
|
243
249
|
self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
|
244
250
|
|
251
|
+
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
252
|
+
return self.embed_tokens(input_ids)
|
253
|
+
|
245
254
|
def forward(
|
246
255
|
self,
|
247
256
|
input_ids: torch.Tensor,
|
@@ -265,9 +274,31 @@ class Qwen2Model(nn.Module):
|
|
265
274
|
hidden_states, _ = self.norm(hidden_states, residual)
|
266
275
|
return hidden_states
|
267
276
|
|
277
|
+
# If this function is called, it should always initialize KV cache scale
|
278
|
+
# factors (or else raise an exception). Thus, handled exceptions should
|
279
|
+
# make sure to leave KV cache scale factors in a known good (dummy) state
|
280
|
+
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
|
281
|
+
tp_size = get_tensor_model_parallel_world_size()
|
282
|
+
tp_rank = get_tensor_model_parallel_rank()
|
283
|
+
for layer_idx, scaling_factor in kv_cache_scales_loader(
|
284
|
+
quantization_param_path,
|
285
|
+
tp_rank,
|
286
|
+
tp_size,
|
287
|
+
self.config.num_hidden_layers,
|
288
|
+
self.config.__class__.model_type,
|
289
|
+
):
|
290
|
+
if not isinstance(self.layers[layer_idx], nn.Identity):
|
291
|
+
layer_self_attn = self.layers[layer_idx].self_attn
|
292
|
+
if hasattr(layer_self_attn.attn, "k_scale"):
|
293
|
+
layer_self_attn.attn.k_scale = scaling_factor
|
294
|
+
layer_self_attn.attn.v_scale = scaling_factor
|
295
|
+
else:
|
296
|
+
raise RuntimeError(
|
297
|
+
"Self attention has no KV cache scaling " "factor attribute!"
|
298
|
+
)
|
268
299
|
|
269
|
-
class Qwen2ForCausalLM(nn.Module):
|
270
300
|
|
301
|
+
class Qwen2ForCausalLM(nn.Module):
|
271
302
|
# BitandBytes specific attributes
|
272
303
|
default_bitsandbytes_target_modules = [
|
273
304
|
".gate_proj.",
|
@@ -305,6 +336,9 @@ class Qwen2ForCausalLM(nn.Module):
|
|
305
336
|
self.logits_processor = LogitsProcessor(config)
|
306
337
|
self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
|
307
338
|
|
339
|
+
def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
|
340
|
+
return self.model.get_input_embeddings(input_ids)
|
341
|
+
|
308
342
|
@torch.no_grad()
|
309
343
|
def forward(
|
310
344
|
self,
|
@@ -373,5 +407,8 @@ class Qwen2ForCausalLM(nn.Module):
|
|
373
407
|
torch.cuda.empty_cache()
|
374
408
|
torch.cuda.synchronize()
|
375
409
|
|
410
|
+
def load_kv_cache_scales(self, quantization_param_path: str) -> None:
|
411
|
+
self.model.load_kv_cache_scales(quantization_param_path)
|
412
|
+
|
376
413
|
|
377
414
|
EntryClass = Qwen2ForCausalLM
|
sglang/srt/models/qwen2_moe.py
CHANGED
@@ -22,12 +22,11 @@ import torch
|
|
22
22
|
import torch.nn.functional as F
|
23
23
|
from torch import nn
|
24
24
|
from transformers import PretrainedConfig
|
25
|
-
|
25
|
+
|
26
|
+
from sglang.srt.distributed import (
|
26
27
|
get_tensor_model_parallel_world_size,
|
27
28
|
tensor_model_parallel_all_reduce,
|
28
29
|
)
|
29
|
-
from vllm.model_executor.layers.rotary_embedding import get_rope
|
30
|
-
|
31
30
|
from sglang.srt.layers.activation import SiluAndMul
|
32
31
|
from sglang.srt.layers.layernorm import RMSNorm
|
33
32
|
from sglang.srt.layers.linear import (
|
@@ -40,6 +39,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
|
|
40
39
|
from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
|
41
40
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
42
41
|
from sglang.srt.layers.radix_attention import RadixAttention
|
42
|
+
from sglang.srt.layers.rotary_embedding import get_rope
|
43
43
|
from sglang.srt.layers.vocab_parallel_embedding import (
|
44
44
|
ParallelLMHead,
|
45
45
|
VocabParallelEmbedding,
|
sglang/srt/models/qwen2_vl.py
CHANGED
@@ -22,6 +22,7 @@
|
|
22
22
|
# See the License for the specific language governing permissions and
|
23
23
|
# limitations under the License.
|
24
24
|
"""Inference-only Qwen2-VL model compatible with HuggingFace weights."""
|
25
|
+
import logging
|
25
26
|
from functools import lru_cache, partial
|
26
27
|
from typing import Iterable, List, Optional, Tuple, Type, TypedDict
|
27
28
|
|
@@ -30,16 +31,13 @@ import torch
|
|
30
31
|
import torch.nn as nn
|
31
32
|
import torch.nn.functional as F
|
32
33
|
from einops import rearrange, repeat
|
33
|
-
from vllm.distributed import parallel_state
|
34
|
-
from vllm.distributed import utils as dist_utils
|
35
|
-
from vllm.logger import init_logger
|
36
34
|
from vllm.model_executor.layers.activation import QuickGELU
|
37
35
|
|
38
36
|
from sglang.srt.configs import Qwen2VLConfig, Qwen2VLVisionConfig
|
37
|
+
from sglang.srt.distributed import parallel_state
|
38
|
+
from sglang.srt.distributed import utils as dist_utils
|
39
39
|
from sglang.srt.hf_transformers_utils import get_processor
|
40
|
-
from sglang.srt.layers.attention.
|
41
|
-
context_attention_fwd,
|
42
|
-
)
|
40
|
+
from sglang.srt.layers.attention.vision import VisionAttention
|
43
41
|
from sglang.srt.layers.linear import ColumnParallelLinear, RowParallelLinear
|
44
42
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
45
43
|
from sglang.srt.layers.pooler import Pooler, PoolingType
|
@@ -50,7 +48,8 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
|
50
48
|
from sglang.srt.model_loader.weight_utils import default_weight_loader
|
51
49
|
from sglang.srt.models.qwen2 import Qwen2Model
|
52
50
|
|
53
|
-
logger =
|
51
|
+
logger = logging.getLogger(__name__)
|
52
|
+
|
54
53
|
|
55
54
|
# === Vision Inputs === #
|
56
55
|
|
@@ -110,118 +109,6 @@ class Qwen2VisionMLP(nn.Module):
|
|
110
109
|
return x
|
111
110
|
|
112
111
|
|
113
|
-
def rotate_half(x: torch.Tensor, interleaved: bool = False) -> torch.Tensor:
|
114
|
-
if not interleaved:
|
115
|
-
x1, x2 = x.chunk(2, dim=-1)
|
116
|
-
return torch.cat((-x2, x1), dim=-1)
|
117
|
-
else:
|
118
|
-
x1, x2 = x[..., ::2], x[..., 1::2]
|
119
|
-
return rearrange(
|
120
|
-
torch.stack((-x2, x1), dim=-1), "... d two -> ... (d two)", two=2
|
121
|
-
)
|
122
|
-
|
123
|
-
|
124
|
-
def apply_rotary_emb_torch(
|
125
|
-
x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor, interleaved: bool = False
|
126
|
-
) -> torch.Tensor:
|
127
|
-
"""
|
128
|
-
x: (batch_size, seqlen, nheads, headdim)
|
129
|
-
cos, sin: (seqlen, rotary_dim / 2) or (batch_size, seqlen, rotary_dim / 2)
|
130
|
-
"""
|
131
|
-
ro_dim = cos.shape[-1] * 2
|
132
|
-
assert ro_dim <= x.shape[-1]
|
133
|
-
cos = repeat(
|
134
|
-
cos, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
|
135
|
-
)
|
136
|
-
sin = repeat(
|
137
|
-
sin, "... d -> ... 1 (2 d)" if not interleaved else "... d -> ... 1 (d 2)"
|
138
|
-
)
|
139
|
-
return torch.cat(
|
140
|
-
[
|
141
|
-
x[..., :ro_dim] * cos + rotate_half(x[..., :ro_dim], interleaved) * sin,
|
142
|
-
x[..., ro_dim:],
|
143
|
-
],
|
144
|
-
dim=-1,
|
145
|
-
)
|
146
|
-
|
147
|
-
|
148
|
-
def apply_rotary_pos_emb_vision(t: torch.Tensor, freqs: torch.Tensor) -> torch.Tensor:
|
149
|
-
t_ = t.float()
|
150
|
-
cos = freqs.cos()
|
151
|
-
sin = freqs.sin()
|
152
|
-
output = apply_rotary_emb_torch(t_, cos, sin).type_as(t)
|
153
|
-
return output
|
154
|
-
|
155
|
-
|
156
|
-
class Qwen2VisionAttention(nn.Module):
|
157
|
-
|
158
|
-
def __init__(
|
159
|
-
self,
|
160
|
-
embed_dim: Optional[int] = None,
|
161
|
-
num_heads: Optional[int] = None,
|
162
|
-
projection_size: Optional[int] = None,
|
163
|
-
quant_config: Optional[QuantizationConfig] = None,
|
164
|
-
) -> None:
|
165
|
-
super().__init__()
|
166
|
-
# Per attention head and per partition values.
|
167
|
-
world_size = parallel_state.get_tensor_model_parallel_world_size()
|
168
|
-
self.hidden_size_per_attention_head = dist_utils.divide(
|
169
|
-
projection_size, num_heads
|
170
|
-
)
|
171
|
-
self.num_attention_heads_per_partition = dist_utils.divide(
|
172
|
-
num_heads, world_size
|
173
|
-
)
|
174
|
-
|
175
|
-
self.qkv = ColumnParallelLinear(
|
176
|
-
input_size=embed_dim,
|
177
|
-
output_size=3 * projection_size,
|
178
|
-
quant_config=quant_config,
|
179
|
-
)
|
180
|
-
self.proj = RowParallelLinear(
|
181
|
-
input_size=projection_size, output_size=embed_dim, quant_config=quant_config
|
182
|
-
)
|
183
|
-
|
184
|
-
def forward(
|
185
|
-
self,
|
186
|
-
x: torch.Tensor,
|
187
|
-
cu_seqlens: torch.Tensor,
|
188
|
-
rotary_pos_emb: torch.Tensor = None,
|
189
|
-
) -> torch.Tensor:
|
190
|
-
# [s, b, c] --> [s, b, head * 3 * head_dim]
|
191
|
-
x, _ = self.qkv(x)
|
192
|
-
|
193
|
-
# [s, b, head * 3 * head_dim] --> [s, b, head, 3 * head_dim]
|
194
|
-
new_x_shape = x.size()[:-1] + (
|
195
|
-
self.num_attention_heads_per_partition,
|
196
|
-
3 * self.hidden_size_per_attention_head,
|
197
|
-
)
|
198
|
-
x = x.view(*new_x_shape)
|
199
|
-
|
200
|
-
# [s, b, head, 3 * head_dim] --> 3 [s, b, head, head_dim]
|
201
|
-
q, k, v = dist_utils.split_tensor_along_last_dim(x, 3)
|
202
|
-
batch_size = q.shape[1]
|
203
|
-
|
204
|
-
q, k, v = [rearrange(x, "s b ... -> b s ...").contiguous() for x in (q, k, v)]
|
205
|
-
if rotary_pos_emb is not None:
|
206
|
-
q = apply_rotary_pos_emb_vision(q, rotary_pos_emb)
|
207
|
-
k = apply_rotary_pos_emb_vision(k, rotary_pos_emb)
|
208
|
-
|
209
|
-
seq_lens = cu_seqlens[1:] - cu_seqlens[:-1]
|
210
|
-
max_seqlen = (seq_lens).max().item()
|
211
|
-
q, k, v = [rearrange(x, "b s ... -> (b s) ...") for x in [q, k, v]]
|
212
|
-
|
213
|
-
output = torch.empty_like(q)
|
214
|
-
context_attention_fwd(
|
215
|
-
q, k, v, output, cu_seqlens, seq_lens, max_seqlen, is_causal=False
|
216
|
-
)
|
217
|
-
|
218
|
-
context_layer = rearrange(output, "(b s) ... -> b s ...", b=batch_size)
|
219
|
-
context_layer = rearrange(context_layer, "b s h d -> s b (h d)").contiguous()
|
220
|
-
|
221
|
-
output, _ = self.proj(context_layer)
|
222
|
-
return output
|
223
|
-
|
224
|
-
|
225
112
|
class Qwen2VisionBlock(nn.Module):
|
226
113
|
|
227
114
|
def __init__(
|
@@ -240,10 +127,11 @@ class Qwen2VisionBlock(nn.Module):
|
|
240
127
|
self.norm2 = norm_layer(dim)
|
241
128
|
mlp_hidden_dim = int(dim * mlp_ratio)
|
242
129
|
|
243
|
-
self.attn =
|
130
|
+
self.attn = VisionAttention(
|
244
131
|
embed_dim=dim,
|
245
132
|
num_heads=num_heads,
|
246
133
|
projection_size=dim,
|
134
|
+
use_qkv_parallel=False,
|
247
135
|
quant_config=quant_config,
|
248
136
|
)
|
249
137
|
self.mlp = Qwen2VisionMLP(
|
@@ -253,9 +141,13 @@ class Qwen2VisionBlock(nn.Module):
|
|
253
141
|
def forward(
|
254
142
|
self, x: torch.Tensor, cu_seqlens: torch.Tensor, rotary_pos_emb: torch.Tensor
|
255
143
|
) -> torch.Tensor:
|
256
|
-
|
257
|
-
|
144
|
+
hidden_states = self.norm1(x)
|
145
|
+
hidden_states = rearrange(hidden_states, "s b ... -> b s ...")
|
146
|
+
attn = self.attn(
|
147
|
+
hidden_states, cu_seqlens=cu_seqlens, rotary_pos_emb=rotary_pos_emb
|
258
148
|
)
|
149
|
+
attn = rearrange(attn, "b s ... -> s b ...")
|
150
|
+
x = x + attn
|
259
151
|
x = x + self.mlp(self.norm2(x))
|
260
152
|
return x
|
261
153
|
|
@@ -684,10 +576,12 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
|
684
576
|
for name, loaded_weight in weights:
|
685
577
|
if "rotary_emb.inv_freq" in name:
|
686
578
|
continue
|
579
|
+
|
687
580
|
for param_name, weight_name, shard_id in stacked_params_mapping:
|
688
581
|
if weight_name not in name:
|
689
582
|
continue
|
690
583
|
name = name.replace(weight_name, param_name)
|
584
|
+
|
691
585
|
# Skip loading extra bias for GPTQ models.
|
692
586
|
if name.endswith(".bias") and name not in params_dict:
|
693
587
|
continue
|
@@ -696,6 +590,7 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
|
696
590
|
weight_loader(param, loaded_weight, shard_id)
|
697
591
|
break
|
698
592
|
else:
|
593
|
+
|
699
594
|
if "visual" in name and "qkv.weight" in name:
|
700
595
|
visual_num_heads = self.config.vision_config.num_heads
|
701
596
|
visual_embed_dim = self.config.vision_config.embed_dim
|
@@ -712,6 +607,11 @@ class Qwen2VLForConditionalGeneration(nn.Module):
|
|
712
607
|
loaded_weight = loaded_weight.view(3, visual_num_heads, head_size)
|
713
608
|
loaded_weight = loaded_weight.transpose(0, 1)
|
714
609
|
loaded_weight = loaded_weight.reshape(-1)
|
610
|
+
|
611
|
+
if "visual" in name:
|
612
|
+
# adapt to VisionAttention
|
613
|
+
name = name.replace(r"attn.qkv.", r"attn.qkv_proj.")
|
614
|
+
|
715
615
|
try:
|
716
616
|
# Skip loading extra bias for GPTQ models.
|
717
617
|
if name.endswith(".bias") and name not in params_dict:
|
sglang/srt/models/stablelm.py
CHANGED
@@ -24,9 +24,8 @@ from typing import Iterable, Optional, Tuple
|
|
24
24
|
import torch
|
25
25
|
from torch import nn
|
26
26
|
from transformers import PretrainedConfig
|
27
|
-
from vllm.distributed import get_tensor_model_parallel_world_size
|
28
|
-
from vllm.model_executor.layers.rotary_embedding import get_rope
|
29
27
|
|
28
|
+
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
30
29
|
from sglang.srt.layers.activation import SiluAndMul
|
31
30
|
from sglang.srt.layers.linear import (
|
32
31
|
MergedColumnParallelLinear,
|
@@ -36,6 +35,7 @@ from sglang.srt.layers.linear import (
|
|
36
35
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
37
36
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
38
37
|
from sglang.srt.layers.radix_attention import RadixAttention
|
38
|
+
from sglang.srt.layers.rotary_embedding import get_rope
|
39
39
|
from sglang.srt.layers.vocab_parallel_embedding import (
|
40
40
|
ParallelLMHead,
|
41
41
|
VocabParallelEmbedding,
|
@@ -47,17 +47,17 @@ import torch
|
|
47
47
|
from torch import nn
|
48
48
|
from torch.nn.parameter import Parameter
|
49
49
|
from transformers import LlamaConfig
|
50
|
-
|
50
|
+
|
51
|
+
from sglang.srt.distributed import (
|
51
52
|
get_tensor_model_parallel_rank,
|
52
53
|
get_tensor_model_parallel_world_size,
|
53
54
|
)
|
54
|
-
from vllm.model_executor.layers.rotary_embedding import get_rope
|
55
|
-
|
56
55
|
from sglang.srt.layers.activation import SiluAndMul
|
57
56
|
from sglang.srt.layers.layernorm import RMSNorm
|
58
57
|
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
|
59
58
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
60
59
|
from sglang.srt.layers.radix_attention import RadixAttention
|
60
|
+
from sglang.srt.layers.rotary_embedding import get_rope
|
61
61
|
from sglang.srt.layers.vocab_parallel_embedding import (
|
62
62
|
ParallelLMHead,
|
63
63
|
VocabParallelEmbedding,
|
@@ -460,7 +460,12 @@ class TorchNativeLlamaForCausalLM(nn.Module):
|
|
460
460
|
params_dict = dict(self.named_parameters())
|
461
461
|
return len(params_dict)
|
462
462
|
|
463
|
-
def
|
463
|
+
def load_weights_to_module(
|
464
|
+
self,
|
465
|
+
fqn: str,
|
466
|
+
weights: Iterable[Tuple[str, torch.Tensor]],
|
467
|
+
):
|
468
|
+
"""Load weights onto submodule pointed by path `fqn`."""
|
464
469
|
stacked_params_mapping = [
|
465
470
|
# (param_name, shard_name, shard_id)
|
466
471
|
(".qkv_proj", ".q_proj", "q"),
|
@@ -469,7 +474,8 @@ class TorchNativeLlamaForCausalLM(nn.Module):
|
|
469
474
|
(".gate_up_proj", ".gate_proj", 0),
|
470
475
|
(".gate_up_proj", ".up_proj", 1),
|
471
476
|
]
|
472
|
-
|
477
|
+
module = self.get_submodule(fqn)
|
478
|
+
params_dict = dict(module.named_parameters(prefix=fqn, recurse=False))
|
473
479
|
|
474
480
|
for name, loaded_weight in weights:
|
475
481
|
if "rotary_emb.inv_freq" in name or "projector" in name:
|
@@ -486,7 +492,7 @@ class TorchNativeLlamaForCausalLM(nn.Module):
|
|
486
492
|
continue
|
487
493
|
name = name.replace(weight_name, param_name)
|
488
494
|
# Skip loading extra bias for GPTQ models.
|
489
|
-
if name.endswith(".bias")
|
495
|
+
if name.endswith(".bias") or name not in params_dict:
|
490
496
|
continue
|
491
497
|
param = params_dict[name]
|
492
498
|
weight_loader = param.weight_loader
|
@@ -494,12 +500,19 @@ class TorchNativeLlamaForCausalLM(nn.Module):
|
|
494
500
|
break
|
495
501
|
else:
|
496
502
|
# Skip loading extra bias for GPTQ models.
|
497
|
-
if name.endswith(".bias")
|
503
|
+
if name.endswith(".bias") or name not in params_dict:
|
498
504
|
continue
|
499
505
|
param = params_dict[name]
|
500
506
|
weight_loader = getattr(param, "weight_loader", default_weight_loader)
|
501
507
|
weight_loader(param, loaded_weight)
|
502
508
|
|
509
|
+
def load_weights(
|
510
|
+
self,
|
511
|
+
weights: Iterable[Tuple[str, torch.Tensor]],
|
512
|
+
):
|
513
|
+
"""Load weights onto the full model."""
|
514
|
+
self.load_weights_to_module("", weights)
|
515
|
+
|
503
516
|
|
504
517
|
class TorchNativePhi3ForCausalLM(TorchNativeLlamaForCausalLM):
|
505
518
|
pass
|
sglang/srt/models/xverse.py
CHANGED
@@ -21,19 +21,19 @@ from typing import Any, Dict, Iterable, Optional, Tuple
|
|
21
21
|
import torch
|
22
22
|
from torch import nn
|
23
23
|
from transformers import LlamaConfig
|
24
|
-
|
25
|
-
from
|
26
|
-
from
|
27
|
-
from
|
24
|
+
|
25
|
+
from sglang.srt.distributed import get_tensor_model_parallel_world_size
|
26
|
+
from sglang.srt.layers.activation import SiluAndMul
|
27
|
+
from sglang.srt.layers.layernorm import RMSNorm
|
28
|
+
from sglang.srt.layers.linear import (
|
28
29
|
MergedColumnParallelLinear,
|
29
30
|
QKVParallelLinear,
|
30
31
|
RowParallelLinear,
|
31
32
|
)
|
32
|
-
from vllm.model_executor.layers.rotary_embedding import get_rope
|
33
|
-
|
34
33
|
from sglang.srt.layers.logits_processor import LogitsProcessor
|
35
34
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
36
35
|
from sglang.srt.layers.radix_attention import RadixAttention
|
36
|
+
from sglang.srt.layers.rotary_embedding import get_rope
|
37
37
|
from sglang.srt.layers.vocab_parallel_embedding import (
|
38
38
|
ParallelLMHead,
|
39
39
|
VocabParallelEmbedding,
|