sglang 0.3.6.post3__py3-none-any.whl → 0.4.0.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +1 -1
- sglang/bench_one_batch.py +4 -0
- sglang/bench_serving.py +13 -0
- sglang/check_env.py +1 -1
- sglang/srt/_custom_ops.py +118 -0
- sglang/srt/configs/device_config.py +17 -0
- sglang/srt/configs/load_config.py +84 -0
- sglang/srt/configs/model_config.py +161 -4
- sglang/srt/configs/qwen2vl.py +5 -8
- sglang/srt/constrained/outlines_backend.py +11 -1
- sglang/srt/constrained/outlines_jump_forward.py +8 -1
- sglang/srt/constrained/xgrammar_backend.py +5 -5
- sglang/srt/distributed/__init__.py +3 -0
- sglang/srt/distributed/communication_op.py +34 -0
- sglang/srt/distributed/device_communicators/__init__.py +0 -0
- sglang/srt/distributed/device_communicators/cuda_wrapper.py +182 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +352 -0
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +291 -0
- sglang/srt/distributed/device_communicators/hpu_communicator.py +48 -0
- sglang/srt/distributed/device_communicators/pynccl.py +204 -0
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +362 -0
- sglang/srt/distributed/device_communicators/shm_broadcast.py +568 -0
- sglang/srt/distributed/device_communicators/xpu_communicator.py +47 -0
- sglang/srt/distributed/parallel_state.py +1275 -0
- sglang/srt/distributed/utils.py +223 -0
- sglang/srt/hf_transformers_utils.py +37 -1
- sglang/srt/layers/attention/__init__.py +5 -2
- sglang/srt/layers/attention/double_sparsity_backend.py +22 -8
- sglang/srt/layers/attention/flashinfer_backend.py +33 -20
- sglang/srt/layers/attention/torch_native_backend.py +299 -0
- sglang/srt/layers/attention/triton_backend.py +22 -8
- sglang/srt/layers/attention/triton_ops/extend_attention.py +3 -0
- sglang/srt/layers/ep_moe/__init__.py +0 -0
- sglang/srt/layers/ep_moe/kernels.py +349 -0
- sglang/srt/layers/ep_moe/layer.py +661 -0
- sglang/srt/layers/fused_moe_patch.py +20 -11
- sglang/srt/layers/linear.py +1 -0
- sglang/srt/layers/logits_processor.py +17 -3
- sglang/srt/layers/quantization/__init__.py +36 -2
- sglang/srt/layers/quantization/fp8.py +559 -0
- sglang/srt/layers/quantization/fp8_utils.py +27 -0
- sglang/srt/layers/radix_attention.py +4 -2
- sglang/srt/layers/sampler.py +2 -0
- sglang/srt/layers/torchao_utils.py +23 -45
- sglang/srt/layers/vocab_parallel_embedding.py +1 -0
- sglang/srt/lora/lora.py +1 -1
- sglang/srt/managers/io_struct.py +48 -2
- sglang/srt/managers/schedule_batch.py +19 -14
- sglang/srt/managers/schedule_policy.py +7 -4
- sglang/srt/managers/scheduler.py +145 -85
- sglang/srt/managers/tokenizer_manager.py +166 -68
- sglang/srt/managers/tp_worker.py +36 -3
- sglang/srt/managers/tp_worker_overlap_thread.py +28 -8
- sglang/srt/mem_cache/memory_pool.py +5 -1
- sglang/srt/model_executor/cuda_graph_runner.py +30 -7
- sglang/srt/model_executor/forward_batch_info.py +9 -4
- sglang/srt/model_executor/model_runner.py +146 -153
- sglang/srt/model_loader/__init__.py +34 -0
- sglang/srt/model_loader/loader.py +1139 -0
- sglang/srt/model_loader/utils.py +41 -0
- sglang/srt/model_loader/weight_utils.py +640 -0
- sglang/srt/model_parallel.py +1 -5
- sglang/srt/models/baichuan.py +9 -10
- sglang/srt/models/chatglm.py +6 -15
- sglang/srt/models/commandr.py +4 -5
- sglang/srt/models/dbrx.py +2 -3
- sglang/srt/models/deepseek.py +4 -11
- sglang/srt/models/deepseek_v2.py +90 -18
- sglang/srt/models/exaone.py +2 -3
- sglang/srt/models/gemma.py +2 -6
- sglang/srt/models/gemma2.py +3 -14
- sglang/srt/models/gemma2_reward.py +0 -1
- sglang/srt/models/gpt2.py +5 -12
- sglang/srt/models/gpt_bigcode.py +6 -22
- sglang/srt/models/grok.py +3 -8
- sglang/srt/models/internlm2.py +2 -3
- sglang/srt/models/internlm2_reward.py +0 -1
- sglang/srt/models/llama.py +96 -31
- sglang/srt/models/llama_classification.py +1 -2
- sglang/srt/models/llama_embedding.py +1 -2
- sglang/srt/models/llama_reward.py +2 -3
- sglang/srt/models/llava.py +1 -4
- sglang/srt/models/llavavid.py +1 -2
- sglang/srt/models/minicpm.py +4 -7
- sglang/srt/models/minicpm3.py +6 -19
- sglang/srt/models/mixtral.py +24 -14
- sglang/srt/models/mixtral_quant.py +2 -3
- sglang/srt/models/mllama.py +3 -7
- sglang/srt/models/olmo.py +2 -8
- sglang/srt/models/olmo2.py +0 -1
- sglang/srt/models/olmoe.py +3 -5
- sglang/srt/models/phi3_small.py +8 -13
- sglang/srt/models/qwen.py +2 -3
- sglang/srt/models/qwen2.py +10 -9
- sglang/srt/models/qwen2_moe.py +4 -16
- sglang/srt/models/qwen2_vl.py +2 -6
- sglang/srt/models/registry.py +99 -0
- sglang/srt/models/stablelm.py +2 -3
- sglang/srt/models/torch_native_llama.py +6 -17
- sglang/srt/models/xverse.py +2 -4
- sglang/srt/models/xverse_moe.py +4 -11
- sglang/srt/models/yivl.py +2 -3
- sglang/srt/openai_api/adapter.py +9 -5
- sglang/srt/openai_api/protocol.py +1 -0
- sglang/srt/sampling/sampling_batch_info.py +9 -8
- sglang/srt/server.py +270 -173
- sglang/srt/server_args.py +102 -29
- sglang/srt/utils.py +295 -28
- sglang/test/test_utils.py +7 -0
- sglang/version.py +1 -1
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/METADATA +5 -4
- sglang-0.4.0.post1.dist-info/RECORD +189 -0
- sglang-0.3.6.post3.dist-info/RECORD +0 -162
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/LICENSE +0 -0
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/WHEEL +0 -0
- {sglang-0.3.6.post3.dist-info → sglang-0.4.0.post1.dist-info}/top_level.txt +0 -0
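The version bump itself lives in sglang/version.py and sglang/__init__.py (both +1 -1 above). Before relying on any of the 0.4.0 interfaces shown in the per-file diffs below, it can help to confirm which side of this diff an environment is actually running; a minimal check, assuming the wheel is installed:

# Minimal version check, assuming an installed sglang wheel.
import sglang

# Prints "0.3.6.post3" for the old wheel and "0.4.0.post1" for the new one.
print(sglang.__version__)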
sglang/srt/models/baichuan.py
CHANGED
@@ -34,7 +34,6 @@ from vllm.model_executor.layers.linear import (
     RowParallelLinear,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
@@ -46,6 +45,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader


 def _get_alibi_slopes(total_num_heads: int) -> torch.Tensor:
@@ -329,7 +329,6 @@ class BaiChuanBaseForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         position_embedding: str,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -338,11 +337,12 @@ class BaiChuanBaseForCausalLM(nn.Module):

         self.quant_config = quant_config
         self.model = BaiChuanModel(config, position_embedding, quant_config)
-        self.lm_head = ParallelLMHead(
-            config.vocab_size, config.hidden_size, quant_config=quant_config
-        )
         if self.config.tie_word_embeddings:
-            self.lm_head
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size, config.hidden_size, quant_config=quant_config
+            )
         self.logits_processor = LogitsProcessor(config)

     def forward(
@@ -353,7 +353,7 @@ class BaiChuanBaseForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head
+            input_ids, hidden_states, self.lm_head, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
@@ -403,13 +403,12 @@ class BaichuanForCausalLM(BaiChuanBaseForCausalLM):
     def __init__(
         self,
         config,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         if config.hidden_size == 4096:  # baichuan2 7b
-            super().__init__(config, "ROPE",
+            super().__init__(config, "ROPE", quant_config)
         else:  # baichuan 13b, baichuan2 13b
-            super().__init__(config, "ALIBI",
+            super().__init__(config, "ALIBI", quant_config)


 EntryClass = [BaichuanForCausalLM]
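The baichuan.py hunks above show the migration pattern that repeats across nearly every model file in this diff: the cache_config constructor argument is dropped, default_weight_loader now comes from sglang.srt.model_loader.weight_utils instead of vllm, and the logits-processor call gains forward_batch as a fourth argument. A minimal sketch of an out-of-tree model written against the 0.4.0 interfaces; the class name and the build_backbone factory are placeholders for illustration, not code from this diff:

# Sketch only: illustrates the 0.4.0-style interfaces used by the in-tree models
# above. MyTinyForCausalLM and build_backbone are placeholders, not sglang code.
from typing import Iterable, Optional, Tuple

import torch
from torch import nn

from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.model_executor.forward_batch_info import ForwardBatch
from sglang.srt.model_loader.weight_utils import default_weight_loader  # moved from vllm in 0.4.0


class MyTinyForCausalLM(nn.Module):
    def __init__(
        self,
        config,
        quant_config: Optional[QuantizationConfig] = None,  # no cache_config in 0.4.0
    ) -> None:
        super().__init__()
        self.config = config
        self.model = build_backbone(config, quant_config)  # placeholder backbone factory
        self.lm_head = self.model.embed_tokens  # tied embeddings, as in baichuan above
        self.logits_processor = LogitsProcessor(config)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        hidden_states = self.model(input_ids, positions, forward_batch)
        # 0.4.0: forward_batch is passed through to the logits processor.
        return self.logits_processor(input_ids, hidden_states, self.lm_head, forward_batch)

    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
        params_dict = dict(self.named_parameters())
        for name, loaded_weight in weights:
            param = params_dict.get(name)
            if param is not None:
                weight_loader = getattr(param, "weight_loader", default_weight_loader)
                weight_loader(param, loaded_weight)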
sglang/srt/models/chatglm.py
CHANGED
@@ -23,7 +23,6 @@ from torch import nn
 from torch.nn import LayerNorm
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.transformers_utils.configs import ChatGLMConfig

 from sglang.srt.layers.activation import SiluAndMul
@@ -41,6 +40,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader

 LoraConfig = None

@@ -50,7 +50,6 @@ class GLMAttention(nn.Module):
         self,
         config,
         layer_id: int = 0,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -186,7 +185,6 @@ class GLMBlock(nn.Module):
         self,
         config,
         layer_id: int,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -203,7 +201,7 @@ class GLMBlock(nn.Module):
         )

         # Self attention.
-        self.self_attention = GLMAttention(config, layer_id,
+        self.self_attention = GLMAttention(config, layer_id, quant_config)
         self.hidden_dropout = config.hidden_dropout

         # Layernorm on the attention output
@@ -258,7 +256,6 @@ class GLMTransformer(nn.Module):
     def __init__(
         self,
         config,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -269,10 +266,7 @@ class GLMTransformer(nn.Module):

         # Transformer layers.
         self.layers = nn.ModuleList(
-            [
-                GLMBlock(config, i, cache_config, quant_config)
-                for i in range(self.num_layers)
-            ]
+            [GLMBlock(config, i, quant_config) for i in range(self.num_layers)]
         )

         if self.post_layer_norm:
@@ -306,7 +300,6 @@ class ChatGLMM(nn.Module):
     def __init__(
         self,
         config,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ):
         super().__init__()
@@ -318,7 +311,7 @@ class ChatGLMM(nn.Module):
         self.num_layers = config.num_layers
         self.multi_query_group_num = config.multi_query_group_num
         self.kv_channels = config.kv_channels
-        self.encoder = GLMTransformer(config,
+        self.encoder = GLMTransformer(config, quant_config)

         self.output_layer = ParallelLMHead(config.padded_vocab_size, config.hidden_size)

@@ -357,15 +350,13 @@ class ChatGLMForCausalLM(nn.Module):
     def __init__(
         self,
         config: ChatGLMConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
-        lora_config: Optional[LoraConfig] = None,
     ):
         super().__init__()
         self.config: ChatGLMConfig = config
         self.quant_config = quant_config
         self.max_position_embeddings = getattr(config, "max_sequence_length", 8192)
-        self.transformer = ChatGLMM(config,
+        self.transformer = ChatGLMM(config, quant_config)
         self.lm_head = self.transformer.output_layer
         self.logits_processor = LogitsProcessor(config)

@@ -378,7 +369,7 @@ class ChatGLMForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head
+            input_ids, hidden_states, self.lm_head, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/commandr.py
CHANGED
@@ -49,7 +49,6 @@ from vllm.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.linear import (
@@ -62,10 +61,11 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.
+from sglang.srt.model_loader.weight_utils import default_weight_loader
+from sglang.srt.utils import get_compiler_backend, set_weight_attrs


-@torch.compile
+@torch.compile(backend=get_compiler_backend())
 def layer_norm_func(hidden_states, weight, variance_epsilon):
     input_dtype = hidden_states.dtype
     hidden_states = hidden_states.to(torch.float32)
@@ -318,7 +318,6 @@ class CohereForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -339,7 +338,7 @@ class CohereForCausalLM(nn.Module):
             forward_batch,
         )
         return self.logits_processor(
-            input_ids, hidden_states, self.model.embed_tokens
+            input_ids, hidden_states, self.model.embed_tokens, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
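Beyond the shared signature changes, commandr.py swaps the bare @torch.compile decorator for @torch.compile(backend=get_compiler_backend()), pulling the backend from the new helper in sglang.srt.utils so the compiled layer norm is not hard-wired to one compiler. A hedged illustration of the decorator pattern; the helper body and the layer-norm math below are generic stand-ins rather than the sglang implementations:

import torch


def get_compiler_backend() -> str:
    # Stand-in for sglang.srt.utils.get_compiler_backend: return the name of a
    # registered torch.compile backend ("inductor" is the PyTorch default).
    return "inductor"


@torch.compile(backend=get_compiler_backend())
def layer_norm_func(hidden_states, weight, variance_epsilon):
    # Generic float32 layer norm; only the first two lines are taken from the
    # context kept in the hunk above, the rest is illustrative.
    input_dtype = hidden_states.dtype
    hidden_states = hidden_states.to(torch.float32)
    mean = hidden_states.mean(-1, keepdim=True)
    variance = (hidden_states - mean).pow(2).mean(-1, keepdim=True)
    hidden_states = (hidden_states - mean) * torch.rsqrt(variance + variance_epsilon)
    return (weight.to(torch.float32) * hidden_states).to(input_dtype)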
sglang/srt/models/dbrx.py
CHANGED
@@ -25,7 +25,6 @@ from vllm.distributed import (
     tensor_model_parallel_all_reduce,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 from vllm.transformers_utils.configs.dbrx import DbrxConfig

 from sglang.srt.layers.fused_moe_triton import fused_moe
@@ -43,6 +42,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import set_weight_attrs


@@ -366,7 +366,6 @@ class DbrxForCausalLM(nn.Module):
         self,
         config: DbrxConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ):
         super().__init__()
         self.config = config
@@ -390,7 +389,7 @@ class DbrxForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.transformer(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head
+            input_ids, hidden_states, self.lm_head, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/deepseek.py
CHANGED
@@ -27,7 +27,6 @@ from vllm.distributed import (
     tensor_model_parallel_all_reduce,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.fused_moe_triton import fused_moe
@@ -46,6 +45,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader


 class DeepseekMLP(nn.Module):
@@ -184,7 +184,6 @@ class DeepseekAttention(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -261,7 +260,6 @@ class DeepseekDecoderLayer(nn.Module):
         self,
         config: PretrainedConfig,
         layer_id: int,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -277,7 +275,6 @@ class DeepseekDecoderLayer(nn.Module):
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
-            cache_config=cache_config,
             quant_config=quant_config,
         )
         if (
@@ -330,7 +327,6 @@ class DeepseekModel(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -343,9 +339,7 @@ class DeepseekModel(nn.Module):
         )
         self.layers = nn.ModuleList(
             [
-                DeepseekDecoderLayer(
-                    config, layer_id, cache_config, quant_config=quant_config
-                )
+                DeepseekDecoderLayer(config, layer_id, quant_config=quant_config)
                 for layer_id in range(config.num_hidden_layers)
             ]
         )
@@ -373,13 +367,12 @@ class DeepseekForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = DeepseekModel(config,
+        self.model = DeepseekModel(config, quant_config)
         self.lm_head = ParallelLMHead(
             config.vocab_size, config.hidden_size, quant_config=quant_config
         )
@@ -394,7 +387,7 @@ class DeepseekForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch)
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head
+            input_ids, hidden_states, self.lm_head, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -21,6 +21,7 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
+from vllm import _custom_ops as ops
 from vllm.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
@@ -28,9 +29,9 @@ from vllm.distributed import (
     tensor_model_parallel_all_reduce,
 )
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.ep_moe.layer import EPMoE
 from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -48,6 +49,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 )
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import is_flashinfer_available

 if is_flashinfer_available():
@@ -112,12 +114,12 @@ class DeepseekV2MoE(nn.Module):
             "Only silu is supported for now."
         )

-
+        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+        self.experts = MoEImpl(
             num_experts=config.n_routed_experts,
             top_k=config.num_experts_per_tok,
             hidden_size=config.hidden_size,
             intermediate_size=config.moe_intermediate_size,
-            reduce_results=False,
             renormalize=config.norm_topk_prob,
             quant_config=quant_config,
             use_grouped_topk=True,
@@ -189,7 +191,6 @@ class DeepseekV2Attention(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         layer_id=None,
     ) -> None:
@@ -337,7 +338,6 @@ class DeepseekV2AttentionMLA(nn.Module):
         rope_theta: float = 10000,
         rope_scaling: Optional[Dict[str, Any]] = None,
         max_position_embeddings: int = 8192,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         layer_id=None,
         use_dp=False,
@@ -455,7 +455,7 @@
             mscale = yarn_get_mscale(scaling_factor, float(mscale_all_dim))
             self.scaling = self.scaling * mscale * mscale

-        self.
+        self.attn_mqa = RadixAttention(
             self.num_local_heads,
             self.kv_lora_rank + self.qk_rope_head_dim,
             self.scaling,
@@ -464,6 +464,15 @@
             v_head_dim=self.kv_lora_rank,
         )

+        self.attn_mha = RadixAttention(
+            self.num_local_heads,
+            self.qk_nope_head_dim + self.qk_rope_head_dim,
+            self.scaling,
+            num_kv_heads=self.num_local_heads,
+            layer_id=layer_id,
+            v_head_dim=self.v_head_dim,
+        )
+
         self.w_kc = None
         self.w_vc = None
         self.w_scale = None
@@ -473,6 +482,63 @@
         positions: torch.Tensor,
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        # Use normal computation for prefill and use weight absorption for extend/decode
+        if (
+            forward_batch.forward_mode.is_extend()
+            and forward_batch.extend_prefix_lens.sum() == 0
+        ):
+            return self.forward_normal(positions, hidden_states, forward_batch)
+        else:
+            return self.forward_absorb(positions, hidden_states, forward_batch)
+
+    def forward_normal(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        if self.q_lora_rank is not None:
+            q = self.q_a_proj(hidden_states)[0]
+            q = self.q_a_layernorm(q)
+            q = self.q_b_proj(q)[0].view(-1, self.num_local_heads, self.qk_head_dim)
+        else:
+            q = self.q_proj(hidden_states)[0].view(
+                -1, self.num_local_heads, self.qk_head_dim
+            )
+        _, q_pe = q.split([self.qk_nope_head_dim, self.qk_rope_head_dim], dim=-1)
+        latent_cache = self.kv_a_proj_with_mqa(hidden_states)[0]
+        kv_a, _ = latent_cache.split([self.kv_lora_rank, self.qk_rope_head_dim], dim=-1)
+        latent_cache = latent_cache.unsqueeze(1)
+        kv_a = self.kv_a_layernorm(kv_a.contiguous())
+        kv = self.kv_b_proj(kv_a)[0]
+        kv = kv.view(-1, self.num_local_heads, self.qk_nope_head_dim + self.v_head_dim)
+        k_nope = kv[..., : self.qk_nope_head_dim]
+        v = kv[..., self.qk_nope_head_dim :]
+        k_pe = latent_cache[:, :, self.kv_lora_rank :]
+        q_pe, k_pe = self.rotary_emb(positions, q_pe, k_pe)
+        q[..., self.qk_nope_head_dim :] = q_pe
+        k = torch.empty_like(q)
+        k[..., : self.qk_nope_head_dim] = k_nope
+        k[..., self.qk_nope_head_dim :] = k_pe
+
+        latent_cache[:, :, : self.kv_lora_rank] = kv_a.unsqueeze(1)
+        latent_cache[:, :, self.kv_lora_rank :] = k_pe
+
+        # Save latent cache
+        forward_batch.token_to_kv_pool.set_kv_buffer(
+            self.attn_mha, forward_batch.out_cache_loc, latent_cache, None
+        )
+        attn_output = self.attn_mha(q, k, v, forward_batch, save_kv_cache=False)
+        attn_output = attn_output.reshape(-1, self.num_local_heads * self.v_head_dim)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+    def forward_absorb(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         q_len = hidden_states.shape[0]
         q_input = hidden_states.new_empty(
@@ -510,7 +576,7 @@
         q_input[..., self.kv_lora_rank :] = q_pe
         k_input[..., self.kv_lora_rank :] = k_pe

-        attn_output = self.
+        attn_output = self.attn_mqa(q_input, k_input, v_input, forward_batch)
         attn_output = attn_output.view(-1, self.num_local_heads, self.kv_lora_rank)

         if self.w_vc.dtype == torch.float8_e4m3fn:
@@ -568,7 +634,6 @@ class DeepseekV2DecoderLayer(nn.Module):
         self,
         config: PretrainedConfig,
         layer_id: int,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -599,7 +664,6 @@
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
-            cache_config=cache_config,
             quant_config=quant_config,
             layer_id=layer_id,
             use_dp=self.enable_dp_attention,
@@ -619,7 +683,6 @@
             rope_theta=rope_theta,
             rope_scaling=rope_scaling,
             max_position_embeddings=max_position_embeddings,
-            cache_config=cache_config,
             quant_config=quant_config,
             layer_id=layer_id,
         )
@@ -685,7 +748,6 @@ class DeepseekV2Model(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
@@ -702,7 +764,6 @@
                 DeepseekV2DecoderLayer(
                     config,
                     layer_id,
-                    cache_config=cache_config,
                     quant_config=quant_config,
                 )
                 for layer_id in range(config.num_hidden_layers)
@@ -733,13 +794,12 @@ class DeepseekV2ForCausalLM(nn.Module):
     def __init__(
         self,
         config: PretrainedConfig,
-        cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.model = DeepseekV2Model(config,
+        self.model = DeepseekV2Model(config, quant_config)
         if global_server_args_dict["enable_dp_attention"]:
             self.lm_head = ReplicatedLinear(
                 config.hidden_size,
@@ -763,7 +823,7 @@
         hidden_states = self.model(input_ids, positions, forward_batch)
         if not forward_batch.forward_mode.is_idle():
             return self.logits_processor(
-                input_ids, hidden_states, self.lm_head
+                input_ids, hidden_states, self.lm_head, forward_batch
             )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
@@ -775,7 +835,8 @@

         # Params for weights, fp8 weight scales, fp8 activation scales
         # (param_name, weight_name, expert_id, shard_id)
-
+        MoEImpl = EPMoE if global_server_args_dict["enable_ep_moe"] else FusedMoE
+        expert_params_mapping = MoEImpl.make_expert_params_mapping(
             ckpt_gate_proj_name="gate_proj",
             ckpt_down_proj_name="down_proj",
             ckpt_up_proj_name="up_proj",
@@ -836,14 +897,25 @@
         if not global_server_args_dict["disable_mla"]:
             for layer_id in range(self.config.num_hidden_layers):
                 self_attn = self.model.layers[layer_id].self_attn
-
+                if hasattr(self_attn.kv_b_proj, "qweight"):
+                    # AWQ compatible
+                    w = ops.awq_dequantize(
+                        self_attn.kv_b_proj.qweight,
+                        self_attn.kv_b_proj.scales,
+                        self_attn.kv_b_proj.qzeros,
+                        0,
+                        0,
+                        0,
+                    ).T
+                else:
+                    w = self_attn.kv_b_proj.weight
+                w_kc, w_vc = w.unflatten(
                     0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
                 ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
                 self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
                 self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
                 if hasattr(self_attn.kv_b_proj, "weight_scale"):
                     self_attn.w_scale = self_attn.kv_b_proj.weight_scale
-                del self_attn.kv_b_proj


 EntryClass = DeepseekV2ForCausalLM
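The deepseek_v2.py changes are the largest model-side piece of this release: the MoE layer can now be backed by the new EPMoE (expert-parallel) implementation when global_server_args_dict["enable_ep_moe"] is set, the MLA attention is split into a forward_normal path for pure prefill and a forward_absorb (weight-absorption) path for extend/decode with separate attn_mha/attn_mqa RadixAttention objects, and kv_b_proj is kept (with an AWQ dequantize branch) instead of being deleted after the w_kc/w_vc split. A hedged usage sketch for the expert-parallel toggle; the enable_ep_moe field name is inferred from the global_server_args_dict key above and the server_args.py entry in the file list, not from documented API:

from sglang.srt.server_args import ServerArgs

# Assumption: ServerArgs in 0.4.0 exposes an enable_ep_moe flag that feeds the
# global_server_args_dict["enable_ep_moe"] switch read by DeepseekV2MoE.
server_args = ServerArgs(
    model_path="deepseek-ai/DeepSeek-V2-Lite",  # example model id, not taken from this diff
    trust_remote_code=True,
    enable_ep_moe=True,
)
print(server_args.enable_ep_moe)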
sglang/srt/models/exaone.py
CHANGED
@@ -22,7 +22,6 @@ import torch
 from torch import nn
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
@@ -39,6 +38,7 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader


 class ExaoneGatedMLP(nn.Module):
@@ -293,7 +293,6 @@ class ExaoneForCausalLM(nn.Module):
         self,
         config,
         quant_config: Optional[QuantizationConfig] = None,
-        cache_config=None,
     ) -> None:
         super().__init__()
         self.config = config
@@ -314,7 +313,7 @@ class ExaoneForCausalLM(nn.Module):
             input_ids, positions, forward_batch, input_embeds
         )
         return self.logits_processor(
-            input_ids, hidden_states, self.lm_head
+            input_ids, hidden_states, self.lm_head, forward_batch
         )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
sglang/srt/models/gemma.py
CHANGED
@@ -21,10 +21,8 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.config import LoRAConfig
 from vllm.distributed import get_tensor_model_parallel_world_size
 from vllm.model_executor.layers.rotary_embedding import get_rope
-from vllm.model_executor.model_loader.weight_utils import default_weight_loader

 from sglang.srt.layers.activation import GeluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
@@ -38,6 +36,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.model_loader.weight_utils import default_weight_loader


 class GemmaMLP(nn.Module):
@@ -278,10 +277,7 @@ class GemmaForCausalLM(nn.Module):
         self,
         config: PretrainedConfig,
         quant_config: Optional[QuantizationConfig] = None,
-        lora_config: Optional[LoRAConfig] = None,
-        cache_config=None,
     ) -> None:
-        del lora_config  # Unused.
         super().__init__()
         self.config = config
         self.quant_config = quant_config
@@ -298,7 +294,7 @@ class GemmaForCausalLM(nn.Module):
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch, input_embeds)
         return self.logits_processor(
-            input_ids, hidden_states, self.model.embed_tokens
+            input_ids, hidden_states, self.model.embed_tokens, forward_batch
        )

     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
|