sglang 0.3.3__py3-none-any.whl → 0.3.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_latency.py +31 -13
- sglang/bench_server_latency.py +21 -10
- sglang/bench_serving.py +101 -7
- sglang/global_config.py +0 -1
- sglang/srt/conversation.py +11 -2
- sglang/srt/layers/attention/__init__.py +27 -5
- sglang/srt/layers/attention/double_sparsity_backend.py +281 -0
- sglang/srt/layers/attention/flashinfer_backend.py +352 -83
- sglang/srt/layers/attention/triton_backend.py +6 -4
- sglang/srt/layers/attention/triton_ops/double_sparsity_attention.py +772 -0
- sglang/srt/layers/attention/triton_ops/extend_attention.py +5 -3
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +4 -2
- sglang/srt/layers/sampler.py +6 -2
- sglang/srt/managers/data_parallel_controller.py +177 -0
- sglang/srt/managers/detokenizer_manager.py +31 -10
- sglang/srt/managers/io_struct.py +11 -2
- sglang/srt/managers/schedule_batch.py +126 -43
- sglang/srt/managers/schedule_policy.py +2 -1
- sglang/srt/managers/scheduler.py +245 -142
- sglang/srt/managers/tokenizer_manager.py +14 -1
- sglang/srt/managers/tp_worker.py +111 -1
- sglang/srt/mem_cache/chunk_cache.py +8 -4
- sglang/srt/mem_cache/memory_pool.py +77 -4
- sglang/srt/mem_cache/radix_cache.py +15 -7
- sglang/srt/model_executor/cuda_graph_runner.py +4 -4
- sglang/srt/model_executor/forward_batch_info.py +16 -21
- sglang/srt/model_executor/model_runner.py +100 -36
- sglang/srt/models/baichuan.py +2 -3
- sglang/srt/models/chatglm.py +5 -6
- sglang/srt/models/commandr.py +1 -2
- sglang/srt/models/dbrx.py +1 -2
- sglang/srt/models/deepseek.py +4 -5
- sglang/srt/models/deepseek_v2.py +5 -6
- sglang/srt/models/exaone.py +1 -2
- sglang/srt/models/gemma.py +2 -2
- sglang/srt/models/gemma2.py +5 -5
- sglang/srt/models/gpt_bigcode.py +5 -5
- sglang/srt/models/grok.py +1 -2
- sglang/srt/models/internlm2.py +1 -2
- sglang/srt/models/llama.py +1 -2
- sglang/srt/models/llama_classification.py +1 -2
- sglang/srt/models/llama_reward.py +2 -3
- sglang/srt/models/llava.py +4 -8
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +1 -2
- sglang/srt/models/minicpm3.py +5 -6
- sglang/srt/models/mixtral.py +1 -2
- sglang/srt/models/mixtral_quant.py +1 -2
- sglang/srt/models/olmo.py +352 -0
- sglang/srt/models/olmoe.py +1 -2
- sglang/srt/models/qwen.py +1 -2
- sglang/srt/models/qwen2.py +1 -2
- sglang/srt/models/qwen2_moe.py +4 -5
- sglang/srt/models/stablelm.py +1 -2
- sglang/srt/models/torch_native_llama.py +1 -2
- sglang/srt/models/xverse.py +1 -2
- sglang/srt/models/xverse_moe.py +4 -5
- sglang/srt/models/yivl.py +1 -2
- sglang/srt/openai_api/adapter.py +97 -52
- sglang/srt/openai_api/protocol.py +10 -2
- sglang/srt/sampling/penaltylib/orchestrator.py +28 -9
- sglang/srt/sampling/sampling_batch_info.py +105 -59
- sglang/srt/sampling/sampling_params.py +2 -0
- sglang/srt/server.py +171 -37
- sglang/srt/server_args.py +127 -48
- sglang/srt/utils.py +37 -14
- sglang/test/few_shot_gsm8k.py +4 -1
- sglang/test/few_shot_gsm8k_engine.py +144 -0
- sglang/test/srt/sampling/penaltylib/utils.py +16 -12
- sglang/version.py +1 -1
- {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/METADATA +82 -32
- sglang-0.3.4.dist-info/RECORD +143 -0
- {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/WHEEL +1 -1
- sglang/srt/layers/attention/flashinfer_utils.py +0 -237
- sglang-0.3.3.dist-info/RECORD +0 -139
- {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/LICENSE +0 -0
- {sglang-0.3.3.dist-info → sglang-0.3.4.dist-info}/top_level.txt +0 -0
sglang/srt/models/llama.py
CHANGED
@@ -22,7 +22,6 @@ from typing import Any, Dict, Iterable, Optional, Tuple
|
|
22
22
|
import torch
|
23
23
|
from torch import nn
|
24
24
|
from transformers import LlamaConfig
|
25
|
-
from vllm.config import CacheConfig
|
26
25
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
27
26
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
28
27
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
@@ -295,7 +294,7 @@ class LlamaForCausalLM(nn.Module):
|
|
295
294
|
self,
|
296
295
|
config: LlamaConfig,
|
297
296
|
quant_config: Optional[QuantizationConfig] = None,
|
298
|
-
cache_config
|
297
|
+
cache_config=None,
|
299
298
|
) -> None:
|
300
299
|
super().__init__()
|
301
300
|
self.config = config
|
@@ -18,7 +18,6 @@ from typing import Iterable, Optional, Tuple
|
|
18
18
|
import torch
|
19
19
|
from torch import nn
|
20
20
|
from transformers import LlamaConfig
|
21
|
-
from vllm.config import CacheConfig
|
22
21
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
23
22
|
|
24
23
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
@@ -32,7 +31,7 @@ class LlamaForClassification(nn.Module):
|
|
32
31
|
self,
|
33
32
|
config: LlamaConfig,
|
34
33
|
quant_config: Optional[QuantizationConfig] = None,
|
35
|
-
cache_config
|
34
|
+
cache_config=None,
|
36
35
|
) -> None:
|
37
36
|
super().__init__()
|
38
37
|
self.config = config
|
@@ -18,7 +18,6 @@ from typing import Iterable, Optional, Tuple
|
|
18
18
|
import torch
|
19
19
|
from torch import nn
|
20
20
|
from transformers import LlamaConfig
|
21
|
-
from vllm.config import CacheConfig
|
22
21
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
23
22
|
|
24
23
|
from sglang.srt.layers.logits_processor import LogitsProcessorOutput
|
@@ -33,7 +32,7 @@ class LlamaForSequenceClassification(nn.Module):
|
|
33
32
|
self,
|
34
33
|
config: LlamaConfig,
|
35
34
|
quant_config: Optional[QuantizationConfig] = None,
|
36
|
-
cache_config
|
35
|
+
cache_config=None,
|
37
36
|
) -> None:
|
38
37
|
super().__init__()
|
39
38
|
self.config = config
|
@@ -92,7 +91,7 @@ class LlamaForSequenceClassificationWithNormal_Weights(LlamaForSequenceClassific
|
|
92
91
|
self,
|
93
92
|
config: LlamaConfig,
|
94
93
|
quant_config: Optional[QuantizationConfig] = None,
|
95
|
-
cache_config
|
94
|
+
cache_config=None,
|
96
95
|
) -> None:
|
97
96
|
super().__init__(config, quant_config, cache_config)
|
98
97
|
self.weights = self.Weights(config.hidden_size, self.num_labels)
|
sglang/srt/models/llava.py
CHANGED
@@ -31,7 +31,6 @@ from transformers import (
|
|
31
31
|
SiglipVisionModel,
|
32
32
|
)
|
33
33
|
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
|
34
|
-
from vllm.config import CacheConfig
|
35
34
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
36
35
|
|
37
36
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
@@ -161,9 +160,6 @@ class LlavaBaseForCausalLM(nn.Module):
|
|
161
160
|
image_sizes = [
|
162
161
|
image_inputs[i].image_sizes for i in range(bs) if need_vision[i]
|
163
162
|
]
|
164
|
-
image_offsets = [
|
165
|
-
image_inputs[i].image_offsets for i in range(bs) if need_vision[i]
|
166
|
-
]
|
167
163
|
|
168
164
|
########## Encode Image ########
|
169
165
|
|
@@ -359,7 +355,7 @@ class LlavaBaseForCausalLM(nn.Module):
|
|
359
355
|
prefix_len = prefix_lens_cpu[i]
|
360
356
|
|
361
357
|
# Multiple images
|
362
|
-
for j, image_offset in enumerate(
|
358
|
+
for j, image_offset in enumerate(image_inputs[i].image_offsets):
|
363
359
|
if image_offset < prefix_len:
|
364
360
|
continue
|
365
361
|
|
@@ -450,7 +446,7 @@ class LlavaLlamaForCausalLM(LlavaBaseForCausalLM):
|
|
450
446
|
self,
|
451
447
|
config: LlavaConfig,
|
452
448
|
quant_config: Optional[QuantizationConfig] = None,
|
453
|
-
cache_config
|
449
|
+
cache_config=None,
|
454
450
|
) -> None:
|
455
451
|
super().__init__()
|
456
452
|
|
@@ -472,7 +468,7 @@ class LlavaQwenForCausalLM(LlavaBaseForCausalLM):
|
|
472
468
|
self,
|
473
469
|
config: LlavaConfig,
|
474
470
|
quant_config: Optional[QuantizationConfig] = None,
|
475
|
-
cache_config
|
471
|
+
cache_config=None,
|
476
472
|
) -> None:
|
477
473
|
super().__init__()
|
478
474
|
|
@@ -505,7 +501,7 @@ class LlavaMistralForCausalLM(LlavaBaseForCausalLM):
|
|
505
501
|
self,
|
506
502
|
config: LlavaConfig,
|
507
503
|
quant_config: Optional[QuantizationConfig] = None,
|
508
|
-
cache_config
|
504
|
+
cache_config=None,
|
509
505
|
) -> None:
|
510
506
|
super().__init__()
|
511
507
|
|
sglang/srt/models/llavavid.py
CHANGED
@@ -22,7 +22,6 @@ import torch
|
|
22
22
|
from torch import nn
|
23
23
|
from transformers import CLIPVisionModel, LlavaConfig
|
24
24
|
from transformers.models.llava.modeling_llava import LlavaMultiModalProjector
|
25
|
-
from vllm.config import CacheConfig
|
26
25
|
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
27
26
|
|
28
27
|
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
@@ -36,7 +35,7 @@ class LlavaVidForCausalLM(nn.Module):
|
|
36
35
|
self,
|
37
36
|
config: LlavaConfig,
|
38
37
|
quant_config: Optional[QuantizationConfig] = None,
|
39
|
-
cache_config
|
38
|
+
cache_config=None,
|
40
39
|
) -> None:
|
41
40
|
super().__init__()
|
42
41
|
self.config = config
|
sglang/srt/models/minicpm.py
CHANGED
@@ -20,7 +20,6 @@ from typing import Any, Dict, Iterable, Optional, Tuple
|
|
20
20
|
|
21
21
|
import torch
|
22
22
|
from torch import nn
|
23
|
-
from vllm.config import CacheConfig
|
24
23
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
25
24
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
26
25
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
@@ -278,7 +277,7 @@ class MiniCPMForCausalLM(nn.Module):
|
|
278
277
|
self,
|
279
278
|
config,
|
280
279
|
quant_config: Optional[QuantizationConfig] = None,
|
281
|
-
cache_config
|
280
|
+
cache_config=None,
|
282
281
|
) -> None:
|
283
282
|
super().__init__()
|
284
283
|
self.config = config
|
sglang/srt/models/minicpm3.py
CHANGED
@@ -21,7 +21,6 @@ from typing import Any, Dict, Iterable, Optional, Tuple
|
|
21
21
|
import torch
|
22
22
|
from torch import nn
|
23
23
|
from transformers import PretrainedConfig
|
24
|
-
from vllm.config import CacheConfig
|
25
24
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
26
25
|
from vllm.model_executor.layers.linear import (
|
27
26
|
ColumnParallelLinear,
|
@@ -108,7 +107,7 @@ class MiniCPM3Attention(nn.Module):
|
|
108
107
|
rope_theta: float = 10000,
|
109
108
|
rope_scaling: Optional[Dict[str, Any]] = None,
|
110
109
|
max_position_embeddings: int = 8192,
|
111
|
-
cache_config
|
110
|
+
cache_config=None,
|
112
111
|
quant_config: Optional[QuantizationConfig] = None,
|
113
112
|
layer_id=None,
|
114
113
|
) -> None:
|
@@ -252,7 +251,7 @@ class MiniCPM3AttentionMLA(nn.Module):
|
|
252
251
|
rope_theta: float = 10000,
|
253
252
|
rope_scaling: Optional[Dict[str, Any]] = None,
|
254
253
|
max_position_embeddings: int = 8192,
|
255
|
-
cache_config
|
254
|
+
cache_config=None,
|
256
255
|
quant_config: Optional[QuantizationConfig] = None,
|
257
256
|
layer_id=None,
|
258
257
|
) -> None:
|
@@ -409,7 +408,7 @@ class MiniCPM3DecoderLayer(nn.Module):
|
|
409
408
|
self,
|
410
409
|
config: PretrainedConfig,
|
411
410
|
layer_id: int,
|
412
|
-
cache_config
|
411
|
+
cache_config=None,
|
413
412
|
quant_config: Optional[QuantizationConfig] = None,
|
414
413
|
) -> None:
|
415
414
|
super().__init__()
|
@@ -501,7 +500,7 @@ class MiniCPM3Model(nn.Module):
|
|
501
500
|
def __init__(
|
502
501
|
self,
|
503
502
|
config: PretrainedConfig,
|
504
|
-
cache_config
|
503
|
+
cache_config=None,
|
505
504
|
quant_config: Optional[QuantizationConfig] = None,
|
506
505
|
) -> None:
|
507
506
|
super().__init__()
|
@@ -552,7 +551,7 @@ class MiniCPM3ForCausalLM(nn.Module):
|
|
552
551
|
def __init__(
|
553
552
|
self,
|
554
553
|
config: PretrainedConfig,
|
555
|
-
cache_config
|
554
|
+
cache_config=None,
|
556
555
|
quant_config: Optional[QuantizationConfig] = None,
|
557
556
|
) -> None:
|
558
557
|
super().__init__()
|
sglang/srt/models/mixtral.py
CHANGED
@@ -21,7 +21,6 @@ from typing import Iterable, Optional, Tuple
|
|
21
21
|
import torch
|
22
22
|
from torch import nn
|
23
23
|
from transformers import MixtralConfig
|
24
|
-
from vllm.config import CacheConfig
|
25
24
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
26
25
|
from vllm.model_executor.layers.fused_moe import FusedMoE
|
27
26
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
@@ -293,7 +292,7 @@ class MixtralForCausalLM(nn.Module):
|
|
293
292
|
self,
|
294
293
|
config: MixtralConfig,
|
295
294
|
quant_config: Optional[QuantizationConfig] = None,
|
296
|
-
cache_config
|
295
|
+
cache_config=None,
|
297
296
|
) -> None:
|
298
297
|
super().__init__()
|
299
298
|
self.config = config
|
@@ -23,7 +23,6 @@ import torch
|
|
23
23
|
import torch.nn.functional as F
|
24
24
|
from torch import nn
|
25
25
|
from transformers import MixtralConfig
|
26
|
-
from vllm.config import CacheConfig
|
27
26
|
from vllm.distributed import (
|
28
27
|
get_tensor_model_parallel_rank,
|
29
28
|
get_tensor_model_parallel_world_size,
|
@@ -325,7 +324,7 @@ class QuantMixtralForCausalLM(nn.Module):
|
|
325
324
|
self,
|
326
325
|
config: MixtralConfig,
|
327
326
|
quant_config: Optional[QuantizationConfig] = None,
|
328
|
-
cache_config
|
327
|
+
cache_config=None,
|
329
328
|
) -> None:
|
330
329
|
super().__init__()
|
331
330
|
self.config = config
|
@@ -0,0 +1,352 @@
|
|
1
|
+
"""
|
2
|
+
Copyright 2023-2024 SGLang Team
|
3
|
+
Licensed under the Apache License, Version 2.0 (the "License");
|
4
|
+
you may not use this file except in compliance with the License.
|
5
|
+
You may obtain a copy of the License at
|
6
|
+
|
7
|
+
http://www.apache.org/licenses/LICENSE-2.0
|
8
|
+
|
9
|
+
Unless required by applicable law or agreed to in writing, software
|
10
|
+
distributed under the License is distributed on an "AS IS" BASIS,
|
11
|
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
12
|
+
See the License for the specific language governing permissions and
|
13
|
+
limitations under the License.
|
14
|
+
"""
|
15
|
+
|
16
|
+
# Adapted from
|
17
|
+
# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/olmo.py#L1
|
18
|
+
"""Inference-only OLMo model compatible with HuggingFace weights."""
|
19
|
+
from typing import Iterable, List, Optional, Tuple
|
20
|
+
|
21
|
+
import torch
|
22
|
+
from torch import nn
|
23
|
+
from transformers import OlmoConfig
|
24
|
+
from vllm.distributed import get_tensor_model_parallel_world_size
|
25
|
+
from vllm.model_executor.layers.rotary_embedding import get_rope
|
26
|
+
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
27
|
+
ParallelLMHead,
|
28
|
+
VocabParallelEmbedding,
|
29
|
+
)
|
30
|
+
from vllm.model_executor.model_loader.weight_utils import default_weight_loader
|
31
|
+
|
32
|
+
from sglang.srt.layers.activation import SiluAndMul
|
33
|
+
from sglang.srt.layers.linear import (
|
34
|
+
MergedColumnParallelLinear,
|
35
|
+
QKVParallelLinear,
|
36
|
+
RowParallelLinear,
|
37
|
+
)
|
38
|
+
from sglang.srt.layers.logits_processor import LogitsProcessor
|
39
|
+
from sglang.srt.layers.quantization.base_config import QuantizationConfig
|
40
|
+
from sglang.srt.layers.radix_attention import RadixAttention
|
41
|
+
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
|
42
|
+
|
43
|
+
|
44
|
+
class OlmoAttention(nn.Module):
    """
    This is the attention block where the output is computed as
    ``Attention(LN(x))`` in ``MLP(LN(x + Attention(LN(x))))``
    (plus another skip connection).

    Args:
        config: OLMo config supplying ``hidden_size``, ``num_attention_heads``,
            ``max_position_embeddings``, ``rope_theta``, ``clip_qkv`` and
            ``attention_bias``.
        layer_id: Index of the owning decoder layer; forwarded to
            ``RadixAttention``.
        quant_config: Optional quantization config. It is forwarded to both the
            QKV and output projections, matching how ``OlmoMLP`` forwards it to
            its projections, so quantized checkpoints get matching layers.
    """

    def __init__(
        self,
        config: OlmoConfig,
        layer_id: int = 0,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.hidden_size = config.hidden_size
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
        self.total_num_heads = config.num_attention_heads

        assert self.hidden_size % self.total_num_heads == 0
        assert self.total_num_heads % tensor_model_parallel_world_size == 0

        # Heads handled by this tensor-parallel rank.
        self.num_heads = self.total_num_heads // tensor_model_parallel_world_size
        self.head_dim = self.hidden_size // self.total_num_heads
        self.max_position_embeddings = config.max_position_embeddings
        self.rope_theta = config.rope_theta
        # Optional symmetric clamp applied to the fused q/k/v activations.
        self.clip_qkv = config.clip_qkv

        # Attention input projection. Projects x -> (q, k, v)
        self.qkv_proj = QKVParallelLinear(
            self.hidden_size,
            self.head_dim,
            self.total_num_heads,
            bias=config.attention_bias,
            quant_config=quant_config,
        )

        # Rotary embeddings.
        self.rotary_emb = get_rope(
            self.head_dim,
            rotary_dim=self.head_dim,
            max_position=self.max_position_embeddings,
            base=self.rope_theta,
        )
        self.scaling = self.head_dim**-0.5
        # K/V use the same head count as Q (no grouped-query attention).
        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_heads,
            layer_id=layer_id,
        )

        # Attention output projection.
        self.o_proj = RowParallelLinear(
            self.hidden_size,
            self.hidden_size,
            bias=config.attention_bias,
            quant_config=quant_config,
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Run QKV projection, optional clipping, RoPE, attention and output projection.

        Args:
            positions: Position ids consumed by the rotary embedding.
            hidden_states: Input activations (already layer-normed by the caller).
            forward_batch: Batch metadata passed through to ``RadixAttention``.

        Returns:
            The output of ``o_proj`` applied to the attention result.
        """
        qkv, _ = self.qkv_proj(hidden_states)
        if self.clip_qkv is not None:
            # In-place clamp to [-clip_qkv, clip_qkv].
            qkv.clamp_(min=-self.clip_qkv, max=self.clip_qkv)
        # Equal thirds are valid because num_kv_heads == num_heads above.
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        q, k = self.rotary_emb(positions, q, k)
        attn_output = self.attn(q, k, v, forward_batch)
        output, _ = self.o_proj(attn_output)
        return output
|
117
|
+
|
118
|
+
|
119
|
+
class OlmoMLP(nn.Module):
    """
    Feed-forward block producing the ``MLP(LN(x))`` term of
    ``MLP(LN(x + Attention(LN(x))))`` (plus another skip connection).

    Structure: a fused, bias-free gate/up projection to two
    ``intermediate_size`` halves, the ``SiluAndMul`` activation, then a
    bias-free down projection back to ``hidden_size``.

    Args:
        config: OLMo config supplying ``hidden_size`` and ``intermediate_size``.
        quant_config: Optional quantization config forwarded to both projections.
    """

    def __init__(
        self,
        config: OlmoConfig,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        hidden = config.hidden_size
        inner = config.intermediate_size
        self.hidden_size = hidden
        self.intermediate_size = inner

        # Fused gate + up projection: hidden -> [inner, inner].
        self.gate_up_proj = MergedColumnParallelLinear(
            hidden,
            [inner, inner],
            bias=False,
            quant_config=quant_config,
        )

        # Combines the two halves produced by gate_up_proj.
        self.act_fn = SiluAndMul()

        # Projection back from inner to hidden.
        self.down_proj = RowParallelLinear(
            inner,
            hidden,
            bias=False,
            quant_config=quant_config,
        )

    def forward(
        self,
        x: torch.Tensor,
    ) -> torch.Tensor:
        """Apply gate/up projection, activation and down projection to ``x``."""
        fused, _unused_bias = self.gate_up_proj(x)
        activated = self.act_fn(fused)
        projected, _unused_bias = self.down_proj(activated)
        return projected
|
163
|
+
|
164
|
+
|
165
|
+
class OlmoDecoderLayer(nn.Module):
    """
    One pre-norm OLMo transformer block.

    The forward pass computes::

        h   = x + Attention(LN(x))
        out = h + MLP(LN(h))

    Both ``LN`` modules are parameter-free LayerNorms
    (``elementwise_affine=False, bias=False``), so they hold no weights.

    Args:
        config: OLMo config shared by the attention and MLP sub-blocks.
        layer_id: Index of this layer; forwarded to ``OlmoAttention``.
        quant_config: Optional quantization config forwarded to both sub-blocks.
    """

    def __init__(
        self,
        config: OlmoConfig,
        layer_id: int = 0,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        # Attention block.
        self.self_attn = OlmoAttention(config, layer_id, quant_config)

        # MLP block.
        self.mlp = OlmoMLP(config, quant_config)

        # LayerNorm without learnable scale or bias.
        self.input_layernorm = nn.LayerNorm(
            config.hidden_size, elementwise_affine=False, bias=False
        )
        self.post_attention_layernorm = nn.LayerNorm(
            config.hidden_size, elementwise_affine=False, bias=False
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        """Apply attention and MLP sub-blocks, each with a residual connection.

        Args:
            positions: Position ids forwarded to the attention's rotary embedding.
            hidden_states: Input activations for this layer.
            forward_batch: Batch metadata forwarded to attention.

        Returns:
            The updated hidden states as a single tensor.
        """
        # Attention block.
        residual = hidden_states
        hidden_states = self.input_layernorm(hidden_states)
        hidden_states = self.self_attn(positions, hidden_states, forward_batch)
        hidden_states = hidden_states + residual

        # MLP block.
        residual = hidden_states
        hidden_states = self.post_attention_layernorm(hidden_states)
        hidden_states = self.mlp(hidden_states)
        hidden_states = residual + hidden_states
        return hidden_states
|
211
|
+
|
212
|
+
|
213
|
+
class OlmoModel(nn.Module):
    """OLMo backbone: token embedding, a stack of decoder layers, final LayerNorm.

    Args:
        config: OLMo config supplying ``vocab_size``, ``hidden_size`` and
            ``num_hidden_layers``.
        quant_config: Optional quantization config forwarded to every layer.
    """

    def __init__(
        self, config: OlmoConfig, quant_config: Optional[QuantizationConfig] = None
    ):
        super().__init__()
        self.config = config

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size, config.hidden_size
        )
        decoder_layers = [
            OlmoDecoderLayer(config, idx, quant_config)
            for idx in range(config.num_hidden_layers)
        ]
        self.layers = nn.ModuleList(decoder_layers)
        # Final parameter-free LayerNorm.
        self.norm = nn.LayerNorm(
            config.hidden_size, elementwise_affine=False, bias=False
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
    ) -> torch.Tensor:
        """
        Embed the inputs (unless ``input_embeds`` is given), run every decoder
        layer in order, and apply the final norm.

        :param input_ids: Token ids; ignored when ``input_embeds`` is provided.
            NOTE(review): upstream documented this as ``(batch_size, seq_len)``;
            the exact layout sglang passes is not visible here — confirm.
        :param input_embeds: Optional precomputed embeddings used instead of
            ``embed_tokens(input_ids)``.
        """
        hidden_states = (
            self.embed_tokens(input_ids) if input_embeds is None else input_embeds
        )

        for block in self.layers:
            hidden_states = block(positions, hidden_states, forward_batch)

        return self.norm(hidden_states)
|
265
|
+
|
266
|
+
|
267
|
+
class OlmoForCausalLM(nn.Module):
    """
    Extremely barebones HF model wrapper.

    Wraps ``OlmoModel`` with an LM head and a ``LogitsProcessor``. When
    ``config.tie_word_embeddings`` is set, the LM head is the input embedding
    module itself; otherwise a separate ``ParallelLMHead`` is created.

    Args:
        config: OLMo config.
        cache_config: Accepted but unused in this class.
        quant_config: Optional quantization config forwarded to the backbone
            and (untied case) the LM head.
    """

    def __init__(
        self,
        config: OlmoConfig,
        cache_config=None,
        quant_config: Optional[QuantizationConfig] = None,
    ):
        super().__init__()
        self.config = config
        self.model = OlmoModel(config, quant_config)
        if config.tie_word_embeddings:
            # Share the embedding module (and its weight) as the output head.
            self.lm_head = self.model.embed_tokens
        else:
            # NOTE(review): only set in the untied branch; not read elsewhere in this class.
            self.unpadded_vocab_size = config.vocab_size
            self.lm_head = ParallelLMHead(
                self.unpadded_vocab_size,
                config.hidden_size,
                org_num_embeddings=config.vocab_size,
                quant_config=quant_config,
            )
        self.logits_processor = LogitsProcessor(config)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        input_embeds: torch.Tensor = None,
    ) -> torch.Tensor:
        """Run the backbone and turn its hidden states into logits.

        Returns whatever ``self.logits_processor`` returns.
        NOTE(review): the ``torch.Tensor`` annotation may not match the
        processor's actual return type — confirm against ``LogitsProcessor``.
        """
        hidden_states = self.model(
            input_ids=input_ids,
            positions=positions,
            forward_batch=forward_batch,
            input_embeds=input_embeds,
        )
        return self.logits_processor(
            input_ids, hidden_states, self.lm_head.weight, forward_batch
        )

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> None:
        """Copy checkpoint tensors into this module's parameters.

        Separate q/k/v and gate/up checkpoint tensors are routed into the fused
        ``qkv_proj`` / ``gate_up_proj`` parameters via their ``weight_loader``
        with a shard id. Everything else goes through the parameter's own
        ``weight_loader`` or ``default_weight_loader``.

        Args:
            weights: Iterable of ``(checkpoint_name, tensor)`` pairs.

        Raises:
            KeyError: if a (non-skipped) checkpoint name has no matching parameter.
        """
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
            ("gate_up_proj", "gate_proj", 0),
            ("gate_up_proj", "up_proj", 1),
        ]
        # remove_duplicate=False keeps both names for shared (tied) parameters.
        params_dict = dict(self.named_parameters(remove_duplicate=False))
        for name, loaded_weight in weights:
            if "rotary_emb.inv_freq" in name:
                continue
            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            # With tie_word_embeddings, we can skip lm_head.weight
            # The weight might appear unnecessarily in the files if the model is
            # processed with quantization, LoRA, fine-tuning, etc.
            if self.config.tie_word_embeddings and "lm_head.weight" in name:
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                # NOTE(review): this `continue` resumes the *inner* mapping loop,
                # not the outer weights loop. Because "qkv_proj" contains
                # "v_proj", the v entry can then rewrite the name again
                # (e.g. "qqkv_proj.bias"). That name is still dropped by the
                # identical bias check in the `else` branch below.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Reached only when no stacked mapping loaded the tensor (no `break`).
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                param = params_dict[name]
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)
|
350
|
+
|
351
|
+
|
352
|
+
# Module-level export of this file's top-level model class.
# NOTE(review): presumably discovered by name ("EntryClass") when sglang
# registers model implementations — confirm against the model loader.
EntryClass = OlmoForCausalLM
|
sglang/srt/models/olmoe.py
CHANGED
@@ -23,7 +23,6 @@ import torch
|
|
23
23
|
import torch.nn.functional as F
|
24
24
|
from torch import nn
|
25
25
|
from transformers import PretrainedConfig
|
26
|
-
from vllm.config import CacheConfig
|
27
26
|
from vllm.distributed import (
|
28
27
|
get_tensor_model_parallel_world_size,
|
29
28
|
tensor_model_parallel_all_reduce,
|
@@ -298,7 +297,7 @@ class OlmoeForCausalLM(nn.Module):
|
|
298
297
|
def __init__(
|
299
298
|
self,
|
300
299
|
config: PretrainedConfig,
|
301
|
-
cache_config
|
300
|
+
cache_config=None,
|
302
301
|
quant_config: Optional[QuantizationConfig] = None,
|
303
302
|
) -> None:
|
304
303
|
super().__init__()
|
sglang/srt/models/qwen.py
CHANGED
@@ -20,7 +20,6 @@ from typing import Any, Dict, Iterable, Optional, Tuple
|
|
20
20
|
import torch
|
21
21
|
from torch import nn
|
22
22
|
from transformers import PretrainedConfig
|
23
|
-
from vllm.config import CacheConfig
|
24
23
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
25
24
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
26
25
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
@@ -243,7 +242,7 @@ class QWenLMHeadModel(nn.Module):
|
|
243
242
|
self,
|
244
243
|
config: PretrainedConfig,
|
245
244
|
quant_config: Optional[QuantizationConfig] = None,
|
246
|
-
cache_config
|
245
|
+
cache_config=None,
|
247
246
|
):
|
248
247
|
super().__init__()
|
249
248
|
self.config = config
|
sglang/srt/models/qwen2.py
CHANGED
@@ -20,7 +20,6 @@ from typing import Any, Dict, Iterable, Optional, Tuple
|
|
20
20
|
|
21
21
|
import torch
|
22
22
|
from torch import nn
|
23
|
-
from vllm.config import CacheConfig
|
24
23
|
from vllm.distributed import get_tensor_model_parallel_world_size
|
25
24
|
from vllm.model_executor.layers.rotary_embedding import get_rope
|
26
25
|
from vllm.model_executor.layers.vocab_parallel_embedding import (
|
@@ -268,7 +267,7 @@ class Qwen2ForCausalLM(nn.Module):
|
|
268
267
|
self,
|
269
268
|
config: Qwen2Config,
|
270
269
|
quant_config: Optional[QuantizationConfig] = None,
|
271
|
-
cache_config
|
270
|
+
cache_config=None,
|
272
271
|
) -> None:
|
273
272
|
super().__init__()
|
274
273
|
self.config = config
|
sglang/srt/models/qwen2_moe.py
CHANGED
@@ -23,7 +23,6 @@ import torch
|
|
23
23
|
import torch.nn.functional as F
|
24
24
|
from torch import nn
|
25
25
|
from transformers import PretrainedConfig
|
26
|
-
from vllm.config import CacheConfig
|
27
26
|
from vllm.distributed import (
|
28
27
|
get_tensor_model_parallel_world_size,
|
29
28
|
tensor_model_parallel_all_reduce,
|
@@ -160,7 +159,7 @@ class Qwen2MoeAttention(nn.Module):
|
|
160
159
|
rope_theta: float = 10000,
|
161
160
|
rope_scaling: Optional[Dict[str, Any]] = None,
|
162
161
|
max_position_embeddings: int = 8192,
|
163
|
-
cache_config
|
162
|
+
cache_config=None,
|
164
163
|
quant_config: Optional[QuantizationConfig] = None,
|
165
164
|
) -> None:
|
166
165
|
super().__init__()
|
@@ -236,7 +235,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
|
|
236
235
|
self,
|
237
236
|
config: PretrainedConfig,
|
238
237
|
layer_id: int,
|
239
|
-
cache_config
|
238
|
+
cache_config=None,
|
240
239
|
quant_config: Optional[QuantizationConfig] = None,
|
241
240
|
) -> None:
|
242
241
|
super().__init__()
|
@@ -306,7 +305,7 @@ class Qwen2MoeModel(nn.Module):
|
|
306
305
|
def __init__(
|
307
306
|
self,
|
308
307
|
config: PretrainedConfig,
|
309
|
-
cache_config
|
308
|
+
cache_config=None,
|
310
309
|
quant_config: Optional[QuantizationConfig] = None,
|
311
310
|
) -> None:
|
312
311
|
super().__init__()
|
@@ -355,7 +354,7 @@ class Qwen2MoeForCausalLM(nn.Module):
|
|
355
354
|
def __init__(
|
356
355
|
self,
|
357
356
|
config: PretrainedConfig,
|
358
|
-
cache_config
|
357
|
+
cache_config=None,
|
359
358
|
quant_config: Optional[QuantizationConfig] = None,
|
360
359
|
) -> None:
|
361
360
|
super().__init__()
|