sglang 0.3.6.post3__py3-none-any.whl → 0.4.0.post1__py3-none-any.whl
This diff compares the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between those versions exactly as they appear in the public registry.
- sglang/__init__.py +1 -1
- sglang/bench_one_batch.py +4 -0
- sglang/bench_serving.py +13 -0
- sglang/check_env.py +1 -1
- sglang/srt/_custom_ops.py +118 -0
- sglang/srt/configs/device_config.py +17 -0
- sglang/srt/configs/load_config.py +84 -0
- sglang/srt/configs/model_config.py +161 -4
- sglang/srt/configs/qwen2vl.py +5 -8
- sglang/srt/constrained/outlines_backend.py +11 -1
- sglang/srt/constrained/outlines_jump_forward.py +8 -1
- sglang/srt/constrained/xgrammar_backend.py +5 -5
- sglang/srt/distributed/__init__.py +3 -0
- sglang/srt/distributed/communication_op.py +34 -0
- sglang/srt/distributed/device_communicators/__init__.py +0 -0
- sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
- sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
- sglang/srt/distributed/device_communicators/pynccl.py +204 -0
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
- sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
- sglang/srt/distributed/parallel_state.py +1275 -0
- sglang/srt/distributed/utils.py +223 -0
- sglang/srt/hf_transformers_utils.py +37 -1
- sglang/srt/layers/attention/__init__.py +5 -2
- sglang/srt/layers/attention/double_sparsity_backend.py +22 -8
- sglang/srt/layers/attention/flashinfer_backend.py +33 -20
- sglang/srt/layers/attention/torch_native_backend.py +299 -0
- sglang/srt/layers/attention/triton_backend.py +22 -8
- sglang/srt/layers/attention/triton_ops/extend_attention.py +3 -0
- sglang/srt/layers/ep_moe/__init__.py +0 -0
- sglang/srt/layers/ep_moe/kernels.py +349 -0
- sglang/srt/layers/ep_moe/layer.py +661 -0
- sglang/srt/layers/fused_moe_patch.py +20 -11
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/logits_processor.py +17 -3
- sglang/srt/layers/quantization/__init__.py +36 -2
- sglang/srt/layers/quantization/fp8.py +559 -0
- sglang/srt/layers/quantization/fp8_utils.py +27 -0
- sglang/srt/layers/radix_attention.py +4 -2
- sglang/srt/layers/sampler.py +2 -0
- sglang/srt/layers/torchao_utils.py +23 -45
- sglang/srt/layers/vocab_parallel_embedding.py +1 -0
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/managers/io_struct.py +48 -2
- sglang/srt/managers/schedule_batch.py +19 -14
- sglang/srt/managers/schedule_policy.py +7 -4
- sglang/srt/managers/scheduler.py +145 -85
- sglang/srt/managers/tokenizer_manager.py +166 -68
- sglang/srt/managers/tp_worker.py +36 -3
- sglang/srt/managers/tp_worker_overlap_thread.py +28 -8
- sglang/srt/mem_cache/memory_pool.py +5 -1
- sglang/srt/model_executor/cuda_graph_runner.py +30 -7
- sglang/srt/model_executor/forward_batch_info.py +9 -4
- sglang/srt/model_executor/model_runner.py +146 -153
- sglang/srt/model_loader/__init__.py +34 -0
- sglang/srt/model_loader/loader.py +1139 -0
- sglang/srt/model_loader/utils.py +41 -0
- sglang/srt/model_loader/weight_utils.py +640 -0
- sglang/srt/model_parallel.py +1 -5
- sglang/srt/models/baichuan.py +9 -10
- sglang/srt/models/chatglm.py +6 -15
- sglang/srt/models/commandr.py +4 -5
- sglang/srt/models/dbrx.py +2 -3
- sglang/srt/models/deepseek.py +4 -11
- sglang/srt/models/deepseek_v2.py +90 -18
- sglang/srt/models/exaone.py +2 -3
- sglang/srt/models/gemma.py +2 -6
- sglang/srt/models/gemma2.py +3 -14
- sglang/srt/models/gemma2_reward.py +0 -1
- sglang/srt/models/gpt2.py +5 -12
- sglang/srt/models/gpt_bigcode.py +6 -22
- sglang/srt/models/grok.py +3 -8
- sglang/srt/models/internlm2.py +2 -3
- sglang/srt/models/internlm2_reward.py +0 -1
- sglang/srt/models/llama.py +96 -31
- sglang/srt/models/llama_classification.py +1 -2
- sglang/srt/models/llama_embedding.py +1 -2
- sglang/srt/models/llama_reward.py +2 -3
- sglang/srt/models/llava.py +1 -4
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +4 -7
- sglang/srt/models/minicpm3.py +6 -19
- sglang/srt/models/mixtral.py +24 -14
- sglang/srt/models/mixtral_quant.py +2 -3
- sglang/srt/models/mllama.py +3 -7
- sglang/srt/models/olmo.py +2 -8
- sglang/srt/models/olmo2.py +0 -1
- sglang/srt/models/olmoe.py +3 -5
- sglang/srt/models/phi3_small.py +8 -13
- sglang/srt/models/qwen.py +2 -3
- sglang/srt/models/qwen2.py +10 -9
- sglang/srt/models/qwen2_moe.py +4 -16
- sglang/srt/models/qwen2_vl.py +2 -6
- sglang/srt/models/registry.py +99 -0
- sglang/srt/models/stablelm.py +2 -3
- sglang/srt/models/torch_native_llama.py +6 -17
- sglang/srt/models/xverse.py +2 -4
- sglang/srt/models/xverse_moe.py +4 -11
- sglang/srt/models/yivl.py +2 -3
- sglang/srt/openai_api/adapter.py +9 -5
- sglang/srt/openai_api/protocol.py +1 -0
- sglang/srt/sampling/sampling_batch_info.py +9 -8
- sglang/srt/server.py +270 -173
- sglang/srt/server_args.py +102 -29
- sglang/srt/utils.py +295 -28
- sglang/test/test_utils.py +7 -0
- sglang/version.py +1 -1
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/METADATA +5 -4
- sglang-0.4.0.post1.dist-info/RECORD +189 -0
- sglang-0.3.6.post3.dist-info/RECORD +0 -162
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/LICENSE +0 -0
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/WHEEL +0 -0
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/top_level.txt +0 -0
sglang/srt/models/gemma2.py
CHANGED
@@ -20,12 +20,8 @@ from typing import Iterable, Optional, Set, Tuple, Union
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.config import LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 
-# from vllm.model_executor.layers.rotary_embedding import GemmaRotaryEmbedding
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
-
 from sglang.srt.layers.activation import GeluAndMul
 from sglang.srt.layers.layernorm import GemmaRMSNorm
 from sglang.srt.layers.linear import (
@@ -38,6 +34,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import make_layers
 
 
@@ -106,7 +103,6 @@ class Gemma2Attention(nn.Module):
         head_dim: int,
         max_position_embeddings: int,
         rope_theta: float,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -191,7 +187,6 @@ class Gemma2DecoderLayer(nn.Module):
         self,
         layer_id: int,
         config: PretrainedConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -205,7 +200,6 @@ class Gemma2DecoderLayer(nn.Module):
             head_dim=config.head_dim,
             max_position_embeddings=config.max_position_embeddings,
             rope_theta=config.rope_theta,
-            cache_config=cache_config,
             quant_config=quant_config,
         )
         self.hidden_size = config.hidden_size
@@ -258,7 +252,6 @@ class Gemma2Model(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -273,7 +266,6 @@ class Gemma2Model(nn.Module):
             lambda idx, prefix: Gemma2DecoderLayer(
                 layer_id=idx,
                 config=config,
-                cache_config=cache_config,
                 quant_config=quant_config,
             ),
             prefix="",
@@ -342,15 +334,12 @@ class Gemma2ForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
-        lora_config: Optional[LoRAConfig] = None,
     ) -> None:
-        del lora_config # Unused.
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = Gemma2Model(config,
+        self.model = Gemma2Model(config, quant_config)
         self.logits_processor = LogitsProcessor(config)
 
     @torch.no_grad()
@@ -363,7 +352,7 @@ class Gemma2ForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.model.embed_tokens
+            input_ids, hidden_states, self.model.embed_tokens, forward_batch
         )
 
     def get_attention_sliding_window_size(self):
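The changes in this file follow a pattern that repeats across most of the model diffs below: constructors drop the vllm cache_config / lora_config parameters, default_weight_loader is now imported from sglang's own sglang.srt.model_loader.weight_utils package, and the logits processor call gains the ForwardBatch as an extra argument. A minimal sketch of the new construction path under 0.4.0.post1, assuming an environment where sglang's distributed/tensor-parallel state has already been initialized (that setup is not part of this diff, and the config used here is only an illustration):

from transformers import AutoConfig

from sglang.srt.models.gemma2 import Gemma2ForCausalLM

# 0.3.6.post3 signature: Gemma2ForCausalLM(config, cache_config=None, quant_config=None, lora_config=None)
# 0.4.0.post1 signature: only config and an optional quant_config remain.
config = AutoConfig.from_pretrained("google/gemma-2-2b-it")
model = Gemma2ForCausalLM(config, quant_config=None)

# Inside forward(), the logits call now also passes the batch metadata through:
#   self.logits_processor(input_ids, hidden_states, self.model.embed_tokens, forward_batch)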
sglang/srt/models/gpt2.py
CHANGED
@@ -22,11 +22,9 @@ from typing import Iterable, List, Optional, Tuple
 import torch
 from torch import nn
 from transformers import GPT2Config
-from vllm.config import CacheConfig
 from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.activation import get_act_fn
 from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 # from sglang.srt.layers.activation import get_act_fn
 from sglang.srt.layers.linear import (
@@ -39,6 +37,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 
 
 class GPT2Attention(nn.Module):
@@ -47,7 +46,6 @@ class GPT2Attention(nn.Module):
         self,
         layer_id: int,
         config: GPT2Config,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
@@ -140,7 +138,6 @@ class GPT2Block(nn.Module):
         self,
         layer_id: int,
         config: GPT2Config,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
@@ -150,7 +147,7 @@ class GPT2Block(nn.Module):
 
         self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
         self.attn = GPT2Attention(
-            layer_id, config,
+            layer_id, config, quant_config, prefix=f"{prefix}.attn"
         )
         self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
         self.mlp = GPT2MLP(inner_dim, config, quant_config, prefix=f"{prefix}.mlp")
@@ -182,7 +179,6 @@ class GPT2Model(nn.Module):
     def __init__(
         self,
         config: GPT2Config,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
     ):
@@ -196,7 +192,7 @@ class GPT2Model(nn.Module):
         self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
         self.h = nn.ModuleList(
             [
-                GPT2Block(i, config,
+                GPT2Block(i, config, quant_config)
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -226,15 +222,12 @@ class GPT2LMHeadModel(nn.Module):
     def __init__(
         self,
         config: GPT2Config,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.transformer = GPT2Model(
-            config, cache_config, quant_config, prefix="transformer"
-        )
+        self.transformer = GPT2Model(config, quant_config, prefix="transformer")
         self.lm_head = self.transformer.wte
 
         self.logits_processor = LogitsProcessor(config)
@@ -247,7 +240,7 @@ class GPT2LMHeadModel(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head
+            input_ids, hidden_states, self.lm_head, forward_batch
         )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/gpt_bigcode.py
CHANGED
@@ -21,9 +21,7 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import GPTBigCodeConfig
-from vllm.config import LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import get_act_fn
 from sglang.srt.layers.linear import (
@@ -36,6 +34,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 
 
 class GPTBigCodeAttention(nn.Module):
@@ -44,7 +43,6 @@ class GPTBigCodeAttention(nn.Module):
         self,
         layer_id: int,
         config: GPTBigCodeConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -145,7 +143,6 @@ class GPTBigCodeBlock(nn.Module):
         self,
         layer_id: int,
         config: GPTBigCodeConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -153,7 +150,7 @@ class GPTBigCodeBlock(nn.Module):
         inner_dim = config.n_inner if config.n_inner is not None else 4 * hidden_size
 
         self.ln_1 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
-        self.attn = GPTBigCodeAttention(layer_id, config,
+        self.attn = GPTBigCodeAttention(layer_id, config, quant_config)
         self.ln_2 = nn.LayerNorm(hidden_size, eps=config.layer_norm_epsilon)
         self.mlp = GPTBigMLP(inner_dim, config, quant_config)
 
@@ -183,20 +180,14 @@ class GPTBigCodeModel(nn.Module):
     def __init__(
         self,
         config: GPTBigCodeConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
-        lora_config: Optional[LoRAConfig] = None,
     ):
         super().__init__()
         self.config = config
         assert not config.add_cross_attention
 
         self.embed_dim = config.hidden_size
-        lora_vocab =
-            (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
-            if lora_config
-            else 0
-        )
+        lora_vocab = 0
         self.vocab_size = config.vocab_size + lora_vocab
         self.wte = VocabParallelEmbedding(
             self.vocab_size, self.embed_dim, org_num_embeddings=config.vocab_size
@@ -204,7 +195,7 @@ class GPTBigCodeModel(nn.Module):
         self.wpe = nn.Embedding(config.max_position_embeddings, self.embed_dim)
         self.h = nn.ModuleList(
             [
-                GPTBigCodeBlock(i, config,
+                GPTBigCodeBlock(i, config, quant_config)
                 for i in range(config.num_hidden_layers)
             ]
         )
@@ -243,23 +234,16 @@ class GPTBigCodeForCausalLM(nn.Module):
     def __init__(
         self,
         config: GPTBigCodeConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
-        lora_config: Optional[LoRAConfig] = None,
     ):
         super().__init__()
 
         self.config = config
-        self.lora_config = lora_config
 
         self.quant_config = quant_config
-        self.transformer = GPTBigCodeModel(
-            config, cache_config, quant_config, lora_config
-        )
+        self.transformer = GPTBigCodeModel(config, quant_config)
         self.lm_head = self.transformer.wte
         self.unpadded_vocab_size = config.vocab_size
-        if lora_config:
-            self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
         self.logits_processor = LogitsProcessor(config)
 
     @torch.no_grad()
@@ -271,7 +255,7 @@ class GPTBigCodeForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head
+            input_ids, hidden_states, self.lm_head, forward_batch
         )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/grok.py
CHANGED
@@ -24,7 +24,6 @@ from torch import nn
 from transformers import PretrainedConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
@@ -36,13 +35,13 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.loader import DefaultModelLoader
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 
 
 class Grok1MoE(nn.Module):
@@ -285,12 +284,10 @@ class Grok1ForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.torchao_config = global_server_args_dict["torchao_config"]
         self.model = Grok1Model(config, quant_config=quant_config)
         self.lm_head = ParallelLMHead(config.vocab_size, config.hidden_size)
         self.logits_processor = LogitsProcessor(config)
@@ -304,7 +301,7 @@ class Grok1ForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head
+            input_ids, hidden_states, self.lm_head, forward_batch
         )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
@@ -374,8 +371,6 @@ class Grok1ForCausalLM(nn.Module):
             )
             weight_loader(param, loaded_weight)
 
-        apply_torchao_config_(self, params_dict, set(["proj.weight"]))
-
 
 class Grok1ModelForCausalLM(Grok1ForCausalLM):
     """An alias for backward-compatbility."""
sglang/srt/models/internlm2.py
CHANGED
@@ -21,7 +21,6 @@ from torch import nn
 from transformers import PretrainedConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
@@ -38,6 +37,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 
 
 class InternLM2MLP(nn.Module):
@@ -251,7 +251,6 @@ class InternLM2ForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -270,7 +269,7 @@ class InternLM2ForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.output
+            input_ids, hidden_states, self.output, forward_batch
         )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/llama.py
CHANGED
@@ -16,6 +16,7 @@
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
 """Inference-only LLaMA model compatible with HuggingFace weights."""
 
+import logging
 from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
@@ -23,7 +24,6 @@ from torch import nn
 from transformers import LlamaConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
@@ -36,14 +36,16 @@ from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorO
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
-from sglang.srt.layers.torchao_utils import apply_torchao_config_
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import make_layers
+from sglang.utils import get_exception_traceback
+
+logger = logging.getLogger(__name__)
 
 
 class LlamaMLP(nn.Module):
@@ -255,6 +257,7 @@ class LlamaModel(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
+            quant_config=quant_config,
         )
         self.layers = make_layers(
             config.num_hidden_layers,
@@ -295,16 +298,29 @@ class LlamaForCausalLM(nn.Module):
         self,
         config: LlamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.torchao_config = global_server_args_dict["torchao_config"]
         self.model = LlamaModel(config, quant_config=quant_config)
-
+        # Llama 3.2 1B Insturct set tie_word_embeddings to True
+        # Llama 3.1 8B Insturct set tie_word_embeddings to False
+        if self.config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size, config.hidden_size, quant_config=quant_config
+            )
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+        self.stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            (".gate_up_proj", ".gate_proj", 0),
+            (".gate_up_proj", ".up_proj", 1),
+        ]
 
     @torch.no_grad()
     def forward(
@@ -318,7 +334,7 @@ class LlamaForCausalLM(nn.Module):
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         if not get_embedding:
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head
+                input_ids, hidden_states, self.lm_head, forward_batch
             )
         else:
             return self.pooler(hidden_states, forward_batch)
@@ -349,15 +365,7 @@ class LlamaForCausalLM(nn.Module):
         return params_mapping.get(name, name)
 
     def get_module_name_from_weight_name(self, name):
-
-            # (param_name, shard_name, shard_id, num_shard)
-            ("qkv_proj", "q_proj", "q", 3),
-            ("qkv_proj", "k_proj", "k", 3),
-            ("qkv_proj", "v_proj", "v", 3),
-            ("gate_up_proj", "gate_proj", 0, 2),
-            ("gate_up_proj", "up_proj", 1, 2),
-        ]
-        for param_name, weight_name, shard_id, num_shard in stacked_params_mapping:
+        for param_name, weight_name, shard_id, num_shard in self.stacked_params_mapping:
             if weight_name in name:
                 return (
                     name.replace(weight_name, param_name)[: -len(".weight")],
@@ -378,13 +386,8 @@ class LlamaForCausalLM(nn.Module):
             (".gate_up_proj", ".gate_proj", 0),
             (".gate_up_proj", ".up_proj", 1),
         ]
-        params_dict = dict(self.named_parameters())
 
-
-            hasattr(self.config, "tie_word_embeddings")
-            and self.config.tie_word_embeddings
-            and "lm_head.weight" in params_dict
-        )
+        params_dict = dict(self.named_parameters())
 
         for name, loaded_weight in weights:
             if "rotary_emb.inv_freq" in name or "projector" in name:
@@ -418,16 +421,78 @@ class LlamaForCausalLM(nn.Module):
             weight_loader = getattr(param, "weight_loader", default_weight_loader)
             weight_loader(param, loaded_weight)
 
-
-
-
-
-
-
-
-
+    def get_weights_by_name(
+        self, name: str, truncate_size: int = 100, tp_size: int = 1
+    ) -> Optional[torch.Tensor]:
+        """Get the weights of the parameter by its name. Similar to `get_parameter` in Hugging Face.
+
+        Only used for unit test with an unoptimized performance.
+        For optimized performance, please use torch.save and torch.load.
+        """
+        try:
+            if name == "lm_head.weight" and self.config.tie_word_embeddings:
+                logger.info(
+                    "word embedding is tied for this model, return embed_tokens.weight as lm_head.weight."
+                )
+                return (
+                    self.model.embed_tokens.weight.cpu()
+                    .to(torch.float32)
+                    .numpy()
+                    .tolist()[:truncate_size]
+                )
 
-
+            mapped_name = name
+            mapped_shard_id = None
+            for param_name, weight_name, shard_id in self.stacked_params_mapping:
+                if weight_name in name:
+                    mapped_name = name.replace(weight_name, param_name)
+                    mapped_shard_id = shard_id
+                    break
+            params_dict = dict(self.named_parameters())
+            param = params_dict[mapped_name]
+            if mapped_shard_id is not None:
+                if mapped_shard_id in ["q", "k", "v"]:
+                    num_heads = self.config.num_attention_heads // tp_size
+                    num_kv_heads = self.config.num_key_value_heads // tp_size
+                    head_dim = (
+                        self.config.hidden_size // self.config.num_attention_heads
+                    )
+                    if mapped_shard_id == "q":
+                        offset = 0
+                        size = num_heads * head_dim
+                    elif mapped_shard_id == "k":
+                        offset = num_heads * head_dim
+                        size = num_kv_heads * head_dim
+                    elif mapped_shard_id == "v":
+                        offset = (num_heads + num_kv_heads) * head_dim
+                        size = num_kv_heads * head_dim
+                    weight = param.data.narrow(0, offset, size)
+                elif mapped_shard_id in [0, 1]:
+                    intermediate_size = self.config.intermediate_size
+                    slice_size = intermediate_size // tp_size
+                    if mapped_shard_id == 0:  # gate_proj
+                        offset = 0
+                        size = slice_size
+                    elif mapped_shard_id == 1:  # up_proj
+                        offset = slice_size
+                        size = slice_size
+
+                    weight = param.data.narrow(0, offset, size)
+                else:
+                    weight = param.data
+            else:
+                weight = param.data
+            if tp_size > 1 and ("o_proj" in name or "down_proj" in name):
+                gathered_weights = [torch.zeros_like(weight) for _ in range(tp_size)]
+                torch.distributed.all_gather(gathered_weights, weight)
+                weight = torch.cat(gathered_weights, dim=1)
+            return weight.cpu().to(torch.float32).numpy().tolist()[:truncate_size]
+
+        except Exception:
+            logger.error(
+                f"Error getting weights by name {name} in LlamaForCausalLM: {get_exception_traceback()}"
+            )
+            return None
 
 
 class Phi3ForCausalLM(LlamaForCausalLM):
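The new get_weights_by_name helper added above recovers an individual q/k/v or gate/up shard from the fused qkv_proj and gate_up_proj parameters by slicing along dim 0 with narrow. Below is a self-contained sketch of that offset arithmetic using made-up tensor sizes; only the layout and the offsets mirror the code in the diff, and nothing here is sglang API:

import torch

# Fused qkv_proj weight layout assumed by get_weights_by_name: [q | k | v]
# stacked along dim 0 (tp_size = 1 here, so no all_gather is needed).
num_heads, num_kv_heads, head_dim, hidden = 8, 2, 16, 128
q = torch.randn(num_heads * head_dim, hidden)
k = torch.randn(num_kv_heads * head_dim, hidden)
v = torch.randn(num_kv_heads * head_dim, hidden)
qkv = torch.cat([q, k, v], dim=0)  # stands in for qkv_proj.weight

# Same offsets as the diff: q starts at 0, k after the q heads, v after q + k.
q_out = qkv.narrow(0, 0, num_heads * head_dim)
k_out = qkv.narrow(0, num_heads * head_dim, num_kv_heads * head_dim)
v_out = qkv.narrow(0, (num_heads + num_kv_heads) * head_dim, num_kv_heads * head_dim)

assert torch.equal(q_out, q) and torch.equal(k_out, k) and torch.equal(v_out, v)

# gate_up_proj follows the same idea: gate_proj occupies the first
# intermediate_size // tp_size rows and up_proj the next slice of the same size.

The tensor-parallel branch in the diff (the all_gather over o_proj / down_proj weights) only applies when tp_size > 1 and is not exercised in this sketch.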
sglang/srt/models/llama_classification.py
CHANGED
@@ -17,11 +17,11 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import LlamaConfig
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.logits_processor import LogitsProcessorOutput
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel
 
 
@@ -30,7 +30,6 @@ class LlamaForClassification(nn.Module):
         self,
         config: LlamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
sglang/srt/models/llama_embedding.py
CHANGED
@@ -3,10 +3,10 @@ from typing import Iterable, Tuple
 import torch
 from torch import nn
 from transformers import LlamaConfig
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
 from sglang.srt.model_executor.model_runner import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaModel
 
 
@@ -15,7 +15,6 @@ class LlamaEmbeddingModel(nn.Module):
         self,
         config: LlamaConfig,
         quant_config=None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.model = LlamaModel(config, quant_config=quant_config)
sglang/srt/models/llama_reward.py
CHANGED
@@ -21,6 +21,7 @@ from transformers import LlamaConfig
 from sglang.srt.layers.pooler import EmbeddingPoolerOutput, Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.llama import LlamaForCausalLM, LlamaModel
 
 
@@ -29,7 +30,6 @@ class LlamaForSequenceClassification(nn.Module):
         self,
         config: LlamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -84,9 +84,8 @@ class LlamaForSequenceClassificationWithNormal_Weights(LlamaForSequenceClassific
         self,
         config: LlamaConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
-        super().__init__(config, quant_config
+        super().__init__(config, quant_config)
         self.weights = self.Weights(config.hidden_size, self.num_labels)
 
     @torch.no_grad()