sglang 0.3.5.post2__py3-none-any.whl → 0.3.6.post1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/__init__.py +2 -2
- sglang/api.py +2 -2
- sglang/bench_latency.py +1 -553
- sglang/bench_offline_throughput.py +48 -20
- sglang/bench_one_batch.py +472 -0
- sglang/{bench_server_latency.py → bench_one_batch_server.py} +3 -3
- sglang/bench_serving.py +125 -6
- sglang/check_env.py +3 -6
- sglang/lang/backend/base_backend.py +1 -1
- sglang/lang/backend/runtime_endpoint.py +2 -2
- sglang/srt/configs/model_config.py +13 -14
- sglang/srt/constrained/__init__.py +13 -14
- sglang/srt/constrained/base_grammar_backend.py +13 -15
- sglang/srt/constrained/outlines_backend.py +28 -17
- sglang/srt/constrained/outlines_jump_forward.py +13 -15
- sglang/srt/constrained/xgrammar_backend.py +47 -58
- sglang/srt/conversation.py +13 -15
- sglang/srt/hf_transformers_utils.py +13 -15
- sglang/srt/layers/activation.py +16 -13
- sglang/srt/layers/attention/flashinfer_backend.py +106 -54
- sglang/srt/layers/attention/triton_backend.py +9 -7
- sglang/srt/layers/attention/triton_ops/decode_attention.py +51 -55
- sglang/srt/layers/attention/triton_ops/extend_attention.py +16 -16
- sglang/srt/layers/attention/triton_ops/prefill_attention.py +13 -15
- sglang/srt/layers/custom_op_util.py +25 -0
- sglang/srt/layers/fused_moe_grok/__init__.py +1 -0
- sglang/srt/layers/{fused_moe → fused_moe_grok}/fused_moe.py +11 -4
- sglang/srt/layers/{fused_moe → fused_moe_grok}/layer.py +4 -9
- sglang/srt/layers/{fused_moe/patch.py → fused_moe_patch.py} +5 -0
- sglang/srt/layers/fused_moe_triton/__init__.py +44 -0
- sglang/srt/layers/fused_moe_triton/fused_moe.py +861 -0
- sglang/srt/layers/fused_moe_triton/layer.py +633 -0
- sglang/srt/layers/layernorm.py +17 -15
- sglang/srt/layers/logits_processor.py +23 -25
- sglang/srt/layers/quantization/__init__.py +77 -17
- sglang/srt/layers/radix_attention.py +13 -15
- sglang/srt/layers/rotary_embedding.py +13 -13
- sglang/srt/layers/sampler.py +4 -8
- sglang/srt/layers/torchao_utils.py +2 -0
- sglang/srt/lora/lora.py +13 -14
- sglang/srt/lora/lora_config.py +13 -14
- sglang/srt/lora/lora_manager.py +22 -24
- sglang/srt/managers/data_parallel_controller.py +98 -27
- sglang/srt/managers/detokenizer_manager.py +13 -15
- sglang/srt/managers/io_struct.py +63 -21
- sglang/srt/managers/schedule_batch.py +154 -59
- sglang/srt/managers/schedule_policy.py +18 -16
- sglang/srt/managers/scheduler.py +278 -109
- sglang/srt/managers/session_controller.py +61 -0
- sglang/srt/managers/tokenizer_manager.py +63 -18
- sglang/srt/managers/tp_worker.py +25 -16
- sglang/srt/managers/tp_worker_overlap_thread.py +62 -67
- sglang/srt/metrics/collector.py +13 -15
- sglang/srt/metrics/func_timer.py +13 -15
- sglang/srt/mm_utils.py +13 -14
- sglang/srt/model_executor/cuda_graph_runner.py +63 -25
- sglang/srt/model_executor/forward_batch_info.py +128 -32
- sglang/srt/model_executor/model_runner.py +132 -64
- sglang/srt/model_parallel.py +98 -0
- sglang/srt/models/chatglm.py +15 -16
- sglang/srt/models/commandr.py +15 -16
- sglang/srt/models/dbrx.py +15 -16
- sglang/srt/models/deepseek.py +15 -15
- sglang/srt/models/deepseek_v2.py +162 -59
- sglang/srt/models/exaone.py +14 -15
- sglang/srt/models/gemma.py +14 -14
- sglang/srt/models/gemma2.py +31 -25
- sglang/srt/models/gemma2_reward.py +13 -14
- sglang/srt/models/gpt_bigcode.py +14 -14
- sglang/srt/models/grok.py +15 -15
- sglang/srt/models/internlm2.py +13 -15
- sglang/srt/models/internlm2_reward.py +13 -14
- sglang/srt/models/llama.py +21 -21
- sglang/srt/models/llama_classification.py +13 -14
- sglang/srt/models/llama_reward.py +13 -14
- sglang/srt/models/llava.py +14 -16
- sglang/srt/models/llavavid.py +14 -16
- sglang/srt/models/minicpm.py +13 -15
- sglang/srt/models/minicpm3.py +13 -15
- sglang/srt/models/mistral.py +13 -15
- sglang/srt/models/mixtral.py +15 -15
- sglang/srt/models/mixtral_quant.py +14 -14
- sglang/srt/models/olmo.py +22 -20
- sglang/srt/models/olmoe.py +23 -20
- sglang/srt/models/phi3_small.py +447 -0
- sglang/srt/models/qwen.py +14 -14
- sglang/srt/models/qwen2.py +22 -19
- sglang/srt/models/qwen2_moe.py +17 -18
- sglang/srt/models/qwen2_vl.py +13 -6
- sglang/srt/models/stablelm.py +18 -16
- sglang/srt/models/torch_native_llama.py +107 -93
- sglang/srt/models/xverse.py +13 -14
- sglang/srt/models/xverse_moe.py +15 -16
- sglang/srt/models/yivl.py +13 -15
- sglang/srt/openai_api/adapter.py +19 -17
- sglang/srt/openai_api/protocol.py +14 -16
- sglang/srt/sampling/penaltylib/orchestrator.py +49 -79
- sglang/srt/sampling/penaltylib/penalizers/frequency_penalty.py +3 -8
- sglang/srt/sampling/penaltylib/penalizers/min_new_tokens.py +3 -9
- sglang/srt/sampling/penaltylib/penalizers/presence_penalty.py +3 -8
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +3 -8
- sglang/srt/sampling/sampling_batch_info.py +61 -57
- sglang/srt/sampling/sampling_params.py +14 -16
- sglang/srt/server.py +86 -35
- sglang/srt/server_args.py +96 -80
- sglang/srt/utils.py +266 -68
- sglang/test/few_shot_gsm8k.py +8 -4
- sglang/test/runners.py +38 -20
- sglang/test/srt/sampling/penaltylib/utils.py +23 -21
- sglang/test/test_utils.py +31 -20
- sglang/version.py +1 -1
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/LICENSE +1 -1
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/METADATA +66 -57
- sglang-0.3.6.post1.dist-info/RECORD +164 -0
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/WHEEL +1 -1
- sglang/srt/layers/fused_moe/__init__.py +0 -1
- sglang-0.3.5.post2.dist-info/RECORD +0 -156
- {sglang-0.3.5.post2.dist-info → sglang-0.3.6.post1.dist-info}/top_level.txt +0 -0
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -1,35 +1,37 @@
-… (old lines 1-13: docstring-style license header; text not captured in this view)
-"""
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 
 # Adapted from:
 # https://github.com/vllm-project/vllm/blob/fb6af8bc086328ca6659e72d11ffd4309ce4de22/vllm/model_executor/models/deepseek_v2.py
 """Inference-only DeepseekV2 model."""
+
 from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
 from torch import nn
 from transformers import PretrainedConfig
 from vllm.distributed import (
+    get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
+    get_tp_group,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.fused_moe import FusedMoE
 from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.fused_moe_triton import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
@@ -338,6 +340,7 @@ class DeepseekV2AttentionMLA(nn.Module):
         cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
         layer_id=None,
+        use_dp=False,
     ) -> None:
         super().__init__()
         self.layer_id = layer_id
@@ -351,29 +354,80 @@ class DeepseekV2AttentionMLA(nn.Module):
         self.num_heads = num_heads
         tp_size = get_tensor_model_parallel_world_size()
         assert num_heads % tp_size == 0
-        self.num_local_heads = num_heads // tp_size
+        self.num_local_heads = num_heads if use_dp else num_heads // tp_size
         self.scaling = self.qk_head_dim**-0.5
         self.rope_theta = rope_theta
         self.max_position_embeddings = max_position_embeddings
 
-… (old lines 359-362: start of the original projection setup; text only partially captured)
+        if use_dp:
+            # For data parallel attention
+            if self.q_lora_rank is not None:
+                self.q_a_proj = ReplicatedLinear(
+                    self.hidden_size,
+                    self.q_lora_rank,
+                    bias=False,
+                    quant_config=quant_config,
+                )
+                self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
+                self.q_b_proj = ReplicatedLinear(
+                    q_lora_rank,
+                    self.num_heads * self.qk_head_dim,
+                    bias=False,
+                    quant_config=quant_config,
+                )
+            else:
+                self.q_proj = ReplicatedLinear(
+                    self.hidden_size,
+                    self.num_heads * self.qk_head_dim,
+                    bias=False,
+                    quant_config=quant_config,
+                )
+            self.kv_b_proj = ReplicatedLinear(
+                self.kv_lora_rank,
+                self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
                 bias=False,
                 quant_config=quant_config,
             )
-… (old lines 366-369: text only partially captured)
+            # O projection.
+            self.o_proj = ReplicatedLinear(
+                self.num_heads * self.v_head_dim,
+                self.hidden_size,
                 bias=False,
                 quant_config=quant_config,
             )
         else:
-… (old line 374: text not captured)
+            # For tensor parallel attention
+            if self.q_lora_rank is not None:
+                self.q_a_proj = ReplicatedLinear(
+                    self.hidden_size,
+                    self.q_lora_rank,
+                    bias=False,
+                    quant_config=quant_config,
+                )
+                self.q_a_layernorm = RMSNorm(self.q_lora_rank, eps=config.rms_norm_eps)
+                self.q_b_proj = ColumnParallelLinear(
+                    q_lora_rank,
+                    self.num_heads * self.qk_head_dim,
+                    bias=False,
+                    quant_config=quant_config,
+                )
+            else:
+                self.q_proj = ColumnParallelLinear(
+                    self.hidden_size,
+                    self.num_heads * self.qk_head_dim,
+                    bias=False,
+                    quant_config=quant_config,
+                )
+            self.kv_b_proj = ColumnParallelLinear(
+                self.kv_lora_rank,
+                self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
+                bias=False,
+                quant_config=quant_config,
+            )
+            # O projection.
+            self.o_proj = RowParallelLinear(
+                self.num_heads * self.v_head_dim,
                 self.hidden_size,
-            self.num_heads * self.qk_head_dim,
                 bias=False,
                 quant_config=quant_config,
             )
@@ -385,19 +439,6 @@ class DeepseekV2AttentionMLA(nn.Module):
             quant_config=quant_config,
         )
         self.kv_a_layernorm = RMSNorm(self.kv_lora_rank, eps=config.rms_norm_eps)
-        self.kv_b_proj = ColumnParallelLinear(
-            self.kv_lora_rank,
-            self.num_heads * (self.qk_nope_head_dim + self.v_head_dim),
-            bias=False,
-            quant_config=quant_config,
-        )
-        # O projection.
-        self.o_proj = RowParallelLinear(
-            self.num_heads * self.v_head_dim,
-            self.hidden_size,
-            bias=False,
-            quant_config=quant_config,
-        )
         rope_scaling["rope_type"] = "deepseek_yarn"
         self.rotary_emb = get_rope(
             qk_rope_head_dim,
@@ -491,6 +532,36 @@ class DeepseekV2AttentionMLA(nn.Module):
         return output
 
 
+def all_gather(
+    input_tensor: torch.Tensor, forward_batch: ForwardBatch, rank, world_size, group
+):
+    if world_size == 1:
+        return input_tensor
+
+    all_lens = forward_batch.global_num_tokens
+    max_len = max(forward_batch.global_num_tokens)
+
+    padded_tensor = torch.nn.functional.pad(
+        input_tensor, (0, 0, 0, max_len - input_tensor.shape[0])
+    )
+
+    torch.distributed.all_gather_into_tensor(
+        forward_batch.gathered_buffer, padded_tensor, group=group
+    )
+
+    gathered_tensors = torch.concat(
+        [
+            forward_batch.gathered_buffer[i * max_len : i * max_len + all_lens[i]]
+            for i in range(world_size)
+        ]
+    )
+
+    start_index = 0 if rank == 0 else sum(all_lens[:rank])
+    end_index = start_index + all_lens[rank]
+
+    return gathered_tensors, start_index, end_index
+
+
 class DeepseekV2DecoderLayer(nn.Module):
 
     def __init__(
@@ -505,6 +576,14 @@ class DeepseekV2DecoderLayer(nn.Module):
         rope_theta = getattr(config, "rope_theta", 10000)
         rope_scaling = getattr(config, "rope_scaling", None)
         max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+        self.enable_dp_attention = (
+            not global_server_args_dict["disable_mla"]
+            and global_server_args_dict["enable_dp_attention"]
+        )
+        if self.enable_dp_attention:
+            self.tp_rank = get_tensor_model_parallel_rank()
+            self.tp_size = get_tensor_model_parallel_world_size()
+            self.tp_group = get_tp_group().device_group
         if not global_server_args_dict["disable_mla"]:
             self.self_attn = DeepseekV2AttentionMLA(
                 config=config,
@@ -523,6 +602,7 @@ class DeepseekV2DecoderLayer(nn.Module):
                 cache_config=cache_config,
                 quant_config=quant_config,
                 layer_id=layer_id,
+                use_dp=self.enable_dp_attention,
             )
         else:
             self.self_attn = DeepseekV2Attention(
@@ -569,20 +649,32 @@ class DeepseekV2DecoderLayer(nn.Module):
         residual: Optional[torch.Tensor],
     ) -> torch.Tensor:
         # Self Attention
-… (old lines 572-581: original attention and layernorm sequence; text only partially captured)
+        if not forward_batch.forward_mode.is_idle():
+            if residual is None:
+                residual = hidden_states
+                hidden_states = self.input_layernorm(hidden_states)
+            else:
+                hidden_states, residual = self.input_layernorm(hidden_states, residual)
+
+            hidden_states = self.self_attn(
+                positions=positions,
+                hidden_states=hidden_states,
+                forward_batch=forward_batch,
+            )
+            hidden_states, residual = self.post_attention_layernorm(
+                hidden_states, residual
+            )
 
         # Fully Connected
-… (old lines 584-585: text not captured)
+        if self.enable_dp_attention:
+            hidden_states, start_idx, end_idx = all_gather(
+                hidden_states, forward_batch, self.tp_rank, self.tp_size, self.tp_group
+            )
+            hidden_states = self.mlp(hidden_states)
+            hidden_states = hidden_states[start_idx:end_idx]
+        else:
+            hidden_states = self.mlp(hidden_states)
+
         return hidden_states, residual
 
 
@@ -603,6 +695,7 @@ class DeepseekV2Model(nn.Module):
         self.embed_tokens = VocabParallelEmbedding(
             config.vocab_size,
             config.hidden_size,
+            enable_tp=not global_server_args_dict["enable_dp_attention"],
         )
         self.layers = nn.ModuleList(
             [
@@ -630,7 +723,8 @@ class DeepseekV2Model(nn.Module):
             hidden_states, residual = layer(
                 positions, hidden_states, forward_batch, residual
             )
-… (old line 633: text not captured)
+        if not forward_batch.forward_mode.is_idle():
+            hidden_states, _ = self.norm(hidden_states, residual)
         return hidden_states
 
 
@@ -646,10 +740,18 @@ class DeepseekV2ForCausalLM(nn.Module):
         self.config = config
         self.quant_config = quant_config
         self.model = DeepseekV2Model(config, cache_config, quant_config)
-… (old lines 649-652: original lm_head / logits_processor setup; text not captured)
+        if global_server_args_dict["enable_dp_attention"]:
+            self.lm_head = ReplicatedLinear(
+                config.hidden_size,
+                config.vocab_size,
+                bias=False,
+            )
+            self.logits_processor = LogitsProcessor(config, skip_all_gather=True)
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size, config.hidden_size, quant_config=quant_config
+            )
+            self.logits_processor = LogitsProcessor(config)
 
     @torch.no_grad()
     def forward(
@@ -659,9 +761,10 @@ class DeepseekV2ForCausalLM(nn.Module):
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch)
-… (old lines 662-664: original logits_processor return; text not captured)
+        if not forward_batch.forward_mode.is_idle():
+            return self.logits_processor(
+                input_ids, hidden_states, self.lm_head.weight, forward_batch
+            )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
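The biggest functional change above is the data-parallel attention path. The new `all_gather` helper has to cope with ranks that hold different numbers of tokens: each rank pads its hidden states up to the longest per-rank length, the padded blocks are gathered into `forward_batch.gathered_buffer`, the padding is stripped, and after the MLP each rank slices its own rows back out (the `hidden_states[start_idx:end_idx]` step in the decoder layer). A minimal single-process sketch of that bookkeeping, with made-up sizes and plain tensor ops standing in for `torch.distributed.all_gather_into_tensor`:

```python
import torch
import torch.nn.functional as F

# Illustration of the padding/slicing bookkeeping used by the new all_gather helper
# in deepseek_v2.py. Sizes are invented; no process group is involved here.
world_size = 3
hidden = 4
global_num_tokens = [2, 5, 3]               # tokens held by each "rank"
max_len = max(global_num_tokens)

# Each rank pads its local hidden states to max_len rows (F.pad pads the row dim here).
local = [torch.randn(n, hidden) for n in global_num_tokens]
padded = [F.pad(t, (0, 0, 0, max_len - t.shape[0])) for t in local]

# all_gather_into_tensor would fill one flat buffer with every rank's padded block;
# concatenating the padded blocks reproduces that layout in a single process.
gathered_buffer = torch.cat(padded, dim=0)  # shape: (world_size * max_len, hidden)

# Strip the padding: keep only the first global_num_tokens[i] rows of each block.
gathered = torch.cat(
    [gathered_buffer[i * max_len : i * max_len + global_num_tokens[i]] for i in range(world_size)]
)
assert gathered.shape[0] == sum(global_num_tokens)

# After the MLP runs on the gathered tokens, rank r keeps only its own slice.
rank = 1
start = sum(global_num_tokens[:rank])
end = start + global_num_tokens[rank]
assert torch.allclose(gathered[start:end], local[rank])
```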
sglang/srt/models/exaone.py
CHANGED
@@ -1,18 +1,17 @@
-… (old lines 1-14: docstring-style license header beginning "Copyright 2024 …"; text not fully captured)
-"""
+# Copyright 2024 The LGcns AI Engineering Team
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 
 # Adapted from llama2.py
 """Inference-only Exaone model compatible with HuggingFace weights."""
sglang/srt/models/gemma.py
CHANGED
@@ -1,21 +1,21 @@
-… (old lines 1-13: docstring-style license header; text not captured)
-"""
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 
 # Adapted from:
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/gemma.py#L1
 """Inference-only Gemma model compatible with HuggingFace weights."""
+
 from typing import Iterable, Optional, Tuple
 
 import torch
sglang/srt/models/gemma2.py
CHANGED
@@ -1,20 +1,20 @@
-… (old lines 1-13: docstring-style license header; text not captured)
-"""
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 
 # Adapted from:
 # https://github.com/vllm-project/vllm/blob/56b325e977435af744f8b3dca7af0ca209663558/vllm/model_executor/models/gemma2.py
+
 from typing import Iterable, Optional, Set, Tuple, Union
 
 import torch
@@ -38,6 +38,7 @@ from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.utils import make_layers
 
 
 # Aligned with HF's implementation, using sliding window inclusive with the last token
@@ -97,7 +98,7 @@ class Gemma2MLP(nn.Module):
 class Gemma2Attention(nn.Module):
     def __init__(
         self,
-… (old line 100: text not captured)
+        layer_id: int,
         config: PretrainedConfig,
         hidden_size: int,
         num_heads: int,
@@ -109,7 +110,7 @@ class Gemma2Attention(nn.Module):
         quant_config: Optional[QuantizationConfig] = None,
     ) -> None:
         super().__init__()
-        self.… (old line 112: text only partially captured)
+        self.layer_id = layer_id
         self.config = config
         self.hidden_size = hidden_size
         tp_size = get_tensor_model_parallel_world_size()
@@ -156,13 +157,13 @@ class Gemma2Attention(nn.Module):
             dtype=torch.get_default_dtype(),
         )
 
-        use_sliding_window = … (old line 159: text only partially captured)
+        use_sliding_window = layer_id % 2 == 0 and hasattr(config, "sliding_window")
         self.attn = RadixAttention(
             self.num_heads,
             self.head_dim,
             self.scaling,
             num_kv_heads=self.num_kv_heads,
-            layer_id=… (old line 165: text only partially captured)
+            layer_id=layer_id,
             logit_cap=self.config.attn_logit_softcapping,
             sliding_window_size=(
                 get_attention_sliding_window_size(config)
@@ -188,7 +189,7 @@ class Gemma2Attention(nn.Module):
 class Gemma2DecoderLayer(nn.Module):
     def __init__(
         self,
-… (old line 191: text not captured)
+        layer_id: int,
         config: PretrainedConfig,
         cache_config=None,
         quant_config: Optional[QuantizationConfig] = None,
@@ -196,7 +197,7 @@ class Gemma2DecoderLayer(nn.Module):
         super().__init__()
         self.hidden_size = config.hidden_size
         self.self_attn = Gemma2Attention(
-… (old line 199: text not captured)
+            layer_id=layer_id,
             config=config,
             hidden_size=self.hidden_size,
             num_heads=config.num_attention_heads,
@@ -267,11 +268,15 @@ class Gemma2Model(nn.Module):
             config.vocab_size,
             config.hidden_size,
         )
-        self.layers = … (old lines 270-274: original layer-list construction; text only partially captured)
+        self.layers = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: Gemma2DecoderLayer(
+                layer_id=idx,
+                config=config,
+                cache_config=cache_config,
+                quant_config=quant_config,
+            ),
+            prefix="",
         )
         self.norm = GemmaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)
 
@@ -332,6 +337,7 @@ class Gemma2ForCausalLM(nn.Module):
     # Gemma does not apply LoRA to the embedding layer.
     embedding_modules = {}
    embedding_padding_modules = []
+    supports_lora = True
 
     def __init__(
         self,
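Gemma2Model now builds its decoder stack through `make_layers` from `sglang.srt.utils` instead of an explicit `nn.ModuleList` comprehension, and the per-layer index it passes through becomes the `layer_id` that drives the new alternating sliding-window logic (`layer_id % 2 == 0`). The helper's implementation is not part of this diff; a hypothetical minimal equivalent, inferred only from the call site above, might look like:

```python
import torch.nn as nn

def make_layers_sketch(num_hidden_layers, layer_fn, prefix=""):
    """Hypothetical stand-in for sglang.srt.utils.make_layers, inferred from the
    Gemma2Model call site: invoke the factory once per layer index and collect the
    results in an nn.ModuleList. The real helper may handle prefixes or
    pipeline-stage partitioning differently."""
    return nn.ModuleList(
        layer_fn(idx, f"{prefix}{idx}") for idx in range(num_hidden_layers)
    )
```

Called with `lambda idx, prefix: Gemma2DecoderLayer(layer_id=idx, ...)` as in the hunk above, this yields one decoder layer per index with the index threaded through as `layer_id`.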
sglang/srt/models/gemma2_reward.py
CHANGED
@@ -1,17 +1,16 @@
-… (old lines 1-13: docstring-style license header; text not captured)
-"""
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 
 from typing import Iterable, Optional, Tuple
 
sglang/srt/models/gpt_bigcode.py
CHANGED
@@ -1,21 +1,21 @@
-… (old lines 1-13: docstring-style license header; text not captured)
-"""
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 
 # Adapted from:
 # https://github.com/vllm-project/vllm/blob/07eb6f19f3b0ee9f7adf6eb689607028aa40bfd5/vllm/model_executor/models/gpt_bigcode.py
 """Inference-only GPTBigCode model compatible with HuggingFace weights."""
+
 from typing import Iterable, Optional, Tuple
 
 import torch
sglang/srt/models/grok.py
CHANGED
@@ -1,21 +1,21 @@
-… (old lines 1-13: docstring-style license header; text not captured)
-"""
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
 
 # Adapted from
 # https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/mixtral.py#L1
 """Inference-only Grok1 model."""
+
 import warnings
 from typing import Iterable, List, Optional, Tuple
 
@@ -31,7 +31,7 @@ from vllm.model_executor.layers.rotary_embedding import get_rope
 from vllm.model_executor.model_loader.loader import DefaultModelLoader
 from vllm.model_executor.model_loader.weight_utils import default_weight_loader
 
-from sglang.srt.layers.… (old line 34: previous FusedMoE import; rest of the line not captured)
+from sglang.srt.layers.fused_moe_grok import FusedMoE
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     QKVParallelLinear,
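Together with the deepseek_v2.py hunk near the top of this diff, the grok.py import change completes the fused-MoE reorganization visible in the file list: the shared `sglang/srt/layers/fused_moe` package is removed, the Grok-specific kernels move to `fused_moe_grok`, and a new Triton-based implementation is added as `fused_moe_triton` (which deepseek_v2.py now uses in place of vLLM's `FusedMoE`). For reference, the two new import paths exactly as they appear in the hunks, aliased here only so both fit in one snippet:

```python
# New import paths in 0.3.6.post1, taken verbatim from the hunks in this diff
# (the aliases are added here just to show both imports side by side).
from sglang.srt.layers.fused_moe_grok import FusedMoE as GrokFusedMoE        # models/grok.py
from sglang.srt.layers.fused_moe_triton import FusedMoE as TritonFusedMoE    # models/deepseek_v2.py
```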