sglang 0.4.1.post6__py3-none-any.whl → 0.4.2__py3-none-any.whl
This diff shows the contents of two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between those versions as they appear in the public registry.
- sglang/__init__.py +21 -23
- sglang/api.py +2 -7
- sglang/bench_offline_throughput.py +41 -27
- sglang/bench_one_batch.py +60 -4
- sglang/bench_one_batch_server.py +1 -1
- sglang/bench_serving.py +83 -71
- sglang/lang/backend/runtime_endpoint.py +183 -4
- sglang/lang/chat_template.py +46 -4
- sglang/launch_server.py +1 -1
- sglang/srt/_custom_ops.py +80 -42
- sglang/srt/configs/device_config.py +1 -1
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +1 -0
- sglang/srt/constrained/base_grammar_backend.py +21 -0
- sglang/srt/constrained/xgrammar_backend.py +8 -4
- sglang/srt/conversation.py +14 -1
- sglang/srt/distributed/__init__.py +3 -3
- sglang/srt/distributed/communication_op.py +2 -1
- sglang/srt/distributed/device_communicators/cuda_wrapper.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +112 -42
- sglang/srt/distributed/device_communicators/custom_all_reduce_utils.py +2 -2
- sglang/srt/distributed/device_communicators/hpu_communicator.py +2 -1
- sglang/srt/distributed/device_communicators/pynccl.py +80 -1
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +112 -2
- sglang/srt/distributed/device_communicators/shm_broadcast.py +5 -72
- sglang/srt/distributed/device_communicators/xpu_communicator.py +2 -1
- sglang/srt/distributed/parallel_state.py +1 -1
- sglang/srt/distributed/utils.py +2 -1
- sglang/srt/entrypoints/engine.py +452 -0
- sglang/srt/entrypoints/http_server.py +603 -0
- sglang/srt/function_call_parser.py +494 -0
- sglang/srt/layers/activation.py +8 -8
- sglang/srt/layers/attention/flashinfer_backend.py +10 -9
- sglang/srt/layers/attention/triton_backend.py +4 -6
- sglang/srt/layers/attention/vision.py +204 -0
- sglang/srt/layers/dp_attention.py +71 -0
- sglang/srt/layers/layernorm.py +5 -5
- sglang/srt/layers/linear.py +65 -14
- sglang/srt/layers/logits_processor.py +49 -64
- sglang/srt/layers/moe/ep_moe/layer.py +24 -16
- sglang/srt/layers/moe/fused_moe_native.py +84 -1
- sglang/srt/layers/moe/fused_moe_triton/configs/E=256,N=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +27 -7
- sglang/srt/layers/moe/fused_moe_triton/layer.py +38 -5
- sglang/srt/layers/parameter.py +18 -8
- sglang/srt/layers/quantization/__init__.py +20 -23
- sglang/srt/layers/quantization/configs/N=1536,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=3072,K=1536,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4096,K=512,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=4608,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=512,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=576,K=7168,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2048,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=2304,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/configs/N=7168,K=256,device_name=AMD_Instinct_MI300X,dtype=fp8_w8a8,block_shape=[128, 128].json +164 -0
- sglang/srt/layers/quantization/fp8.py +10 -4
- sglang/srt/layers/quantization/modelopt_quant.py +1 -2
- sglang/srt/layers/quantization/w8a8_int8.py +1 -1
- sglang/srt/layers/radix_attention.py +2 -2
- sglang/srt/layers/rotary_embedding.py +1184 -31
- sglang/srt/layers/sampler.py +64 -6
- sglang/srt/layers/torchao_utils.py +12 -6
- sglang/srt/layers/vocab_parallel_embedding.py +2 -2
- sglang/srt/lora/lora.py +1 -9
- sglang/srt/managers/configure_logging.py +3 -0
- sglang/srt/managers/data_parallel_controller.py +79 -72
- sglang/srt/managers/detokenizer_manager.py +24 -6
- sglang/srt/managers/image_processor.py +158 -2
- sglang/srt/managers/io_struct.py +57 -3
- sglang/srt/managers/schedule_batch.py +78 -45
- sglang/srt/managers/schedule_policy.py +26 -12
- sglang/srt/managers/scheduler.py +326 -201
- sglang/srt/managers/session_controller.py +1 -0
- sglang/srt/managers/tokenizer_manager.py +210 -121
- sglang/srt/managers/tp_worker.py +6 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +5 -8
- sglang/srt/managers/utils.py +44 -0
- sglang/srt/mem_cache/memory_pool.py +10 -32
- sglang/srt/metrics/collector.py +15 -6
- sglang/srt/model_executor/cuda_graph_runner.py +26 -30
- sglang/srt/model_executor/forward_batch_info.py +5 -7
- sglang/srt/model_executor/model_runner.py +44 -19
- sglang/srt/model_loader/loader.py +83 -6
- sglang/srt/model_loader/weight_utils.py +145 -6
- sglang/srt/models/baichuan.py +6 -6
- sglang/srt/models/chatglm.py +2 -2
- sglang/srt/models/commandr.py +17 -5
- sglang/srt/models/dbrx.py +13 -5
- sglang/srt/models/deepseek.py +3 -3
- sglang/srt/models/deepseek_v2.py +11 -11
- sglang/srt/models/exaone.py +2 -2
- sglang/srt/models/gemma.py +2 -2
- sglang/srt/models/gemma2.py +15 -25
- sglang/srt/models/gpt2.py +3 -5
- sglang/srt/models/gpt_bigcode.py +1 -1
- sglang/srt/models/granite.py +2 -2
- sglang/srt/models/grok.py +4 -3
- sglang/srt/models/internlm2.py +2 -2
- sglang/srt/models/llama.py +7 -5
- sglang/srt/models/minicpm.py +2 -2
- sglang/srt/models/minicpm3.py +9 -9
- sglang/srt/models/minicpmv.py +1238 -0
- sglang/srt/models/mixtral.py +3 -3
- sglang/srt/models/mixtral_quant.py +3 -3
- sglang/srt/models/mllama.py +2 -2
- sglang/srt/models/olmo.py +3 -3
- sglang/srt/models/olmo2.py +4 -4
- sglang/srt/models/olmoe.py +7 -13
- sglang/srt/models/phi3_small.py +2 -2
- sglang/srt/models/qwen.py +2 -2
- sglang/srt/models/qwen2.py +41 -4
- sglang/srt/models/qwen2_moe.py +3 -3
- sglang/srt/models/qwen2_vl.py +22 -122
- sglang/srt/models/stablelm.py +2 -2
- sglang/srt/models/torch_native_llama.py +20 -7
- sglang/srt/models/xverse.py +6 -6
- sglang/srt/models/xverse_moe.py +6 -6
- sglang/srt/openai_api/adapter.py +139 -37
- sglang/srt/openai_api/protocol.py +7 -4
- sglang/srt/sampling/custom_logit_processor.py +38 -0
- sglang/srt/sampling/penaltylib/penalizers/repetition_penalty.py +11 -14
- sglang/srt/sampling/sampling_batch_info.py +143 -18
- sglang/srt/sampling/sampling_params.py +3 -1
- sglang/srt/server.py +4 -1090
- sglang/srt/server_args.py +77 -15
- sglang/srt/speculative/eagle_utils.py +37 -15
- sglang/srt/speculative/eagle_worker.py +11 -13
- sglang/srt/utils.py +164 -129
- sglang/test/runners.py +8 -13
- sglang/test/test_programs.py +2 -1
- sglang/test/test_utils.py +83 -22
- sglang/utils.py +12 -2
- sglang/version.py +1 -1
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/METADATA +21 -10
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/RECORD +138 -123
- sglang/launch_server_llavavid.py +0 -25
- sglang/srt/constrained/__init__.py +0 -16
- sglang/srt/distributed/device_communicators/__init__.py +0 -0
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/LICENSE +0 -0
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/WHEEL +0 -0
- {sglang-0.4.1.post6.dist-info → sglang-0.4.2.dist-info}/top_level.txt +0 -0
@@ -9,7 +9,17 @@ import logging
 import os
 import tempfile
 from collections import defaultdict
-from typing import
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    Generator,
+    Iterable,
+    List,
+    Optional,
+    Tuple,
+    Union,
+)
 
 import filelock
 import gguf
@@ -17,12 +27,13 @@ import huggingface_hub.constants
 import numpy as np
 import torch
 from huggingface_hub import HfFileSystem, hf_hub_download, snapshot_download
+from pydantic import BaseModel, ConfigDict, ValidationInfo, model_validator
 from safetensors.torch import load_file, safe_open, save_file
 from tqdm.auto import tqdm
-from vllm.distributed import get_tensor_model_parallel_rank
 
 from sglang.srt.configs.load_config import LoadConfig
 from sglang.srt.configs.model_config import ModelConfig
+from sglang.srt.distributed import get_tensor_model_parallel_rank
 from sglang.srt.layers.quantization import QuantizationConfig, get_quantization_config
 from sglang.srt.utils import print_warning_once
 
@@ -393,8 +404,13 @@ def np_cache_weights_iterator(
 
 def safetensors_weights_iterator(
     hf_weights_files: List[str],
+    is_all_weights_sharded: bool = False,
 ) -> Generator[Tuple[str, torch.Tensor], None, None]:
-    """Iterate over the weights in the model safetensor files.
+    """Iterate over the weights in the model safetensor files.
+
+    If is_all_weights_sharded is True, it uses more optimize read by reading an
+    entire file instead of reading each tensor one by one.
+    """
     enable_tqdm = (
         not torch.distributed.is_initialized() or torch.distributed.get_rank() == 0
     )
@@ -404,9 +420,14 @@ def safetensors_weights_iterator(
         disable=not enable_tqdm,
         bar_format=_BAR_FORMAT,
     ):
-        with safe_open(st_file, framework="pt") as f:
-            for name in f.keys():  # noqa: SIM118
-                param = f.get_tensor(name)
+        if not is_all_weights_sharded:
+            with safe_open(st_file, framework="pt") as f:
+                for name in f.keys():  # noqa: SIM118
+                    param = f.get_tensor(name)
+                    yield name, param
+        else:
+            result = load_file(st_file, device="cpu")
+            for name, param in result.items():
                 yield name, param
 
 
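Note (added for illustration): a minimal usage sketch of the new is_all_weights_sharded flag, assuming an installed sglang and a placeholder checkpoint directory; this is not code from the package itself.

# Illustrative sketch (not part of sglang): compare the two iteration modes of
# safetensors_weights_iterator. With is_all_weights_sharded=True each shard is
# read with a single load_file() call instead of tensor-by-tensor via safe_open().
from glob import glob

from sglang.srt.model_loader.weight_utils import safetensors_weights_iterator

shard_files = sorted(glob("/path/to/checkpoint/*.safetensors"))  # placeholder path

# Default path: open each shard and fetch tensors one at a time.
for name, tensor in safetensors_weights_iterator(shard_files):
    print(name, tuple(tensor.shape))

# New optional fast path: load each whole shard file at once.
for name, tensor in safetensors_weights_iterator(
    shard_files, is_all_weights_sharded=True
):
    print(name, tuple(tensor.shape))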
@@ -638,3 +659,121 @@ def maybe_remap_kv_scale_name(name: str, params_dict: dict) -> Optional[str]:
 
     # If there were no matches, return the untouched param name
     return name
+
+
+# Adapted from https://github.com/vllm-project/vllm/blob/68ad4e3a8d8a66fb2a43be57471ee13a8bec4ec0/vllm/model_executor/layers/quantization/schema.py
+class KVCacheQuantSchema(BaseModel):
+    dtype: str
+    # Each key is a TP rank. Each value is a dictionary mapping a TP rank's
+    # layer indices to their per-tensor KV cache scaling factor.
+    # TODO: Consider pulling this and its validation methods out into its
+    # own schema class (tricky as its members are variable)
+    scaling_factor: Dict[int, Dict[int, float]]
+
+    @model_validator(mode="after")
+    def check_is_fp8(self) -> "KVCacheQuantSchema":
+        assert self.dtype == "float8_e4m3fn", (
+            "Loaded scaling factors intended for KV cache dtype = "
+            f"{self.dtype} rather than float8_e4m3fn!"
+        )
+        return self
+
+    @model_validator(mode="after")
+    def check_tp_ranks(self, info: ValidationInfo) -> "KVCacheQuantSchema":
+        context = info.context
+        if context:
+            tp_size = context["tp_size"]
+            num_hidden_layers = context["num_hidden_layers"]
+            assert len(self.scaling_factor) == tp_size, (
+                f"Loaded dictionary has TP size {len(self.scaling_factor)} "
+                f"but LLM engine is currently running with TP size {tp_size}."
+            )
+            for tp_rank, layer_maps in self.scaling_factor.items():
+                assert len(layer_maps) == num_hidden_layers, (
+                    f"KV cache scales map for TP rank {tp_rank} is malformed. "
+                    f"Expected {num_hidden_layers} layers, got "
+                    f"{len(layer_maps)}."
+                )
+            for i in range(tp_size):
+                assert (
+                    i in self.scaling_factor
+                ), f"KV cache scales map for TP rank {i} not found."
+        return self
+
+    @model_validator(mode="after")
+    def check_current_rank(self, info: ValidationInfo) -> "KVCacheQuantSchema":
+        context = info.context
+        if context:
+            tp_rank = context["tp_rank"]
+            num_hidden_layers = context["num_hidden_layers"]
+            layer_scales_map = self.scaling_factor[tp_rank]
+            for i in range(num_hidden_layers):
+                assert i in layer_scales_map, (
+                    f"Could not find KV cache scales for layer {i} in "
+                    f"TP rank {tp_rank}."
+                )
+        return self
+
+
+class QuantParamSchema(BaseModel):
+    # TODO: Generalize and extend with more fields
+    # (e.g. weights/activations params) once functionality is enabled
+    model_config = ConfigDict(protected_namespaces=())
+    model_type: Optional[str]
+    kv_cache: KVCacheQuantSchema
+
+    @model_validator(mode="after")
+    def check_model_type(self, info: ValidationInfo) -> "QuantParamSchema":
+        context = info.context
+        if context:
+            model_type = context.get("model_type", None)
+            if model_type is not None:
+                assert model_type == self.model_type, (
+                    f"Model type is {model_type} but loaded "
+                    f"scaling factors belonging to different "
+                    f"model type {self.model_type}!"
+                )
+        return self
+
+
+def kv_cache_scales_loader(
+    filename: str,
+    tp_rank: int,
+    tp_size: int,
+    num_hidden_layers: int,
+    model_type: Optional[str],
+) -> Iterable[Tuple[int, float]]:
+    """
+    A simple utility to read in KV cache scaling factors that have been
+    previously serialized to disk. Used by the model to populate the appropriate
+    KV cache scaling factors. The serialization should represent a dictionary
+    whose keys are the TP ranks and values are another dictionary mapping layers
+    to their KV cache scaling factors.
+    """
+    try:
+        with open(filename) as f:
+            context = {
+                "model_type": model_type,
+                "num_hidden_layers": num_hidden_layers,
+                "tp_rank": tp_rank,
+                "tp_size": tp_size,
+            }
+            schema_dct = json.load(f)
+            schema = QuantParamSchema.model_validate(schema_dct, context=context)
+            layer_scales_map = schema.kv_cache.scaling_factor[tp_rank]
+            return layer_scales_map.items()
+    except FileNotFoundError:
+        logger.error("File or directory '%s' not found.", filename)
+    except json.JSONDecodeError:
+        logger.error("Error decoding JSON in file '%s'.", filename)
+    except Exception:
+        logger.error("An error occurred while reading '%s'.", filename)
+    # This section is reached if and only if any of the excepts are hit
+    # Return an empty iterable (list) => no KV cache scales are loaded
+    # which ultimately defaults to 1.0 scales
+    logger.warning(
+        "Defaulting to KV cache scaling factors = 1.0 for all "
+        "layers in TP rank %d as an error occurred during loading.",
+        tp_rank,
+    )
+    return []
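Note (added for illustration): a small sketch, not from the package, showing the on-disk JSON layout implied by KVCacheQuantSchema/QuantParamSchema above for a hypothetical 2-layer model running with TP size 1, and how kv_cache_scales_loader reads it back.

# Illustrative sketch (not part of sglang): write a scales file matching the
# schema above, then load it. Keys are strings in JSON; pydantic coerces them
# to the int keys declared in scaling_factor: Dict[int, Dict[int, float]].
import json
import tempfile

from sglang.srt.model_loader.weight_utils import kv_cache_scales_loader

scales = {
    "model_type": "llama",  # hypothetical; must match the running model's type
    "kv_cache": {
        "dtype": "float8_e4m3fn",  # the only dtype check_is_fp8 accepts
        # TP rank -> {layer index -> per-tensor KV cache scaling factor}
        "scaling_factor": {"0": {"0": 0.0123, "1": 0.0456}},
    },
}

with tempfile.NamedTemporaryFile("w", suffix=".json", delete=False) as f:
    json.dump(scales, f)
    path = f.name

# Yields (layer_index, scale) pairs for the requested TP rank, or an empty
# list (meaning default 1.0 scales) if the file fails to load or validate.
for layer_idx, scale in kv_cache_scales_loader(
    path, tp_rank=0, tp_size=1, num_hidden_layers=2, model_type="llama"
):
    print(layer_idx, scale)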
sglang/srt/models/baichuan.py
CHANGED
@@ -24,22 +24,22 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.distributed import (
+
+from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
-from
+from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,
     QKVParallelLinear,
     RowParallelLinear,
 )
-from vllm.model_executor.layers.rotary_embedding import get_rope
-
-from sglang.srt.layers.activation import SiluAndMul
-from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
sglang/srt/models/chatglm.py
CHANGED
@@ -21,10 +21,9 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from torch.nn import LayerNorm
-from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.rotary_embedding import get_rope
 
 from sglang.srt.configs import ChatGLMConfig
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -35,6 +34,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
sglang/srt/models/commandr.py
CHANGED
@@ -44,12 +44,11 @@ import torch.utils.checkpoint
 from torch import nn
 from torch.nn.parameter import Parameter
 from transformers import PretrainedConfig
-from vllm.distributed import (
+
+from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
-from vllm.model_executor.layers.rotary_embedding import get_rope
-
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,
@@ -59,9 +58,13 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.model_loader.weight_utils import
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    maybe_remap_kv_scale_name,
+)
 from sglang.srt.utils import get_compiler_backend, set_weight_attrs
 
 
@@ -372,10 +375,19 @@ class CohereForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                # Remapping the name of FP8 kv-scale.
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
                 loaded_params.add(name)
 
 
-
+class Cohere2ForCausalLM(CohereForCausalLM):
+    pass
+
+
+EntryClass = [CohereForCausalLM, Cohere2ForCausalLM]
sglang/srt/models/dbrx.py
CHANGED
@@ -19,14 +19,13 @@ from typing import Iterable, Optional, Tuple
 
 import torch
 import torch.nn as nn
-from vllm.distributed import (
+
+from sglang.srt.configs import DbrxConfig
+from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.rotary_embedding import get_rope
-
-from sglang.srt.configs import DbrxConfig
 from sglang.srt.layers.linear import (
     QKVParallelLinear,
     ReplicatedLinear,
@@ -36,13 +35,17 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.fused_moe_triton import fused_moe
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     DEFAULT_VOCAB_PADDING_SIZE,
     ParallelLMHead,
     VocabParallelEmbedding,
 )
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.model_loader.weight_utils import
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    maybe_remap_kv_scale_name,
+)
 from sglang.srt.utils import set_weight_attrs
 
 
@@ -411,6 +414,11 @@ class DbrxForCausalLM(nn.Module):
                 weight_loader(param, loaded_weight, weight_name)
                 break
             else:
+                # Remapping the name of FP8 kv-scale.
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
sglang/srt/models/deepseek.py
CHANGED
@@ -21,13 +21,12 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.distributed import (
+
+from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.rotary_embedding import get_rope
-
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -40,6 +39,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.fused_moe_triton import fused_moe
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
sglang/srt/models/deepseek_v2.py
CHANGED
@@ -23,14 +23,13 @@ import torch.nn.functional as F
 from torch import nn
 from transformers import PretrainedConfig
 from vllm import _custom_ops as ops
-from vllm.distributed import (
+
+from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
     get_tp_group,
     tensor_model_parallel_all_reduce,
 )
-from vllm.model_executor.layers.rotary_embedding import get_rope
-
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -49,6 +48,7 @@ from sglang.srt.layers.quantization.fp8_utils import (
     normalize_e4m3fn_to_e4m3fnuz,
 )
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope, get_rope_wrapper
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -56,12 +56,12 @@ from sglang.srt.layers.vocab_parallel_embedding import (
 from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
-from sglang.srt.utils import
+from sglang.srt.utils import is_cuda_available, is_hip
 
 is_hip_ = is_hip()
 
-if
-from
+if is_cuda_available():
+    from sgl_kernel import bmm_fp8
 
 
 class DeepseekV2MLP(nn.Module):
@@ -271,13 +271,14 @@ class DeepseekV2Attention(nn.Module):
             quant_config=quant_config,
         )
         rope_scaling["rope_type"] = "deepseek_yarn"
-        self.rotary_emb =
+        self.rotary_emb = get_rope_wrapper(
             qk_rope_head_dim,
             rotary_dim=qk_rope_head_dim,
             max_position=max_position_embeddings,
             base=rope_theta,
             rope_scaling=rope_scaling,
             is_neox_style=False,
+            device=global_server_args_dict["device"],
         )
 
         if rope_scaling:
@@ -855,10 +856,9 @@ class DeepseekV2ForCausalLM(nn.Module):
         forward_batch: ForwardBatch,
     ) -> torch.Tensor:
         hidden_states = self.model(input_ids, positions, forward_batch)
-
-
-
-        )
+        return self.logits_processor(
+            input_ids, hidden_states, self.lm_head, forward_batch
+        )
 
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         stacked_params_mapping = [
sglang/srt/models/exaone.py
CHANGED
@@ -20,9 +20,8 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 
 import torch
 from torch import nn
-from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.rotary_embedding import get_rope
 
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -33,6 +32,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
sglang/srt/models/gemma.py
CHANGED
@@ -21,9 +21,8 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.rotary_embedding import get_rope
 
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import GeluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -34,6 +33,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
sglang/srt/models/gemma2.py
CHANGED
@@ -15,13 +15,13 @@
 # Adapted from:
 # https://github.com/vllm-project/vllm/blob/56b325e977435af744f8b3dca7af0ca209663558/vllm/model_executor/models/gemma2.py
 
-from typing import Iterable, Optional, Set, Tuple
+from typing import Iterable, Optional, Set, Tuple
 
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
 
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import GeluAndMul
 from sglang.srt.layers.layernorm import GemmaRMSNorm
 from sglang.srt.layers.linear import (
@@ -32,9 +32,13 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import VocabParallelEmbedding
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
-from sglang.srt.model_loader.weight_utils import
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    maybe_remap_kv_scale_name,
+)
 from sglang.srt.utils import make_layers
 
 
@@ -44,23 +48,6 @@ def get_attention_sliding_window_size(config):
     return config.sliding_window - 1
 
 
-# FIXME: temporary solution, remove after next vllm release
-from vllm.model_executor.layers.rotary_embedding import RotaryEmbedding
-
-
-class GemmaRotaryEmbedding(RotaryEmbedding):
-    def _compute_inv_freq(self, base: Union[int, float]) -> torch.Tensor:
-        # https://github.com/huggingface/transformers/blob/v4.41.2/src/transformers/models/gemma/modeling_gemma.py#L107
-        inv_freq = 1.0 / (
-            base
-            ** (
-                torch.arange(0, self.rotary_dim, 2, dtype=torch.int64).float()
-                / self.rotary_dim
-            )
-        )
-        return inv_freq
-
-
 class Gemma2MLP(nn.Module):
     def __init__(
         self,
@@ -143,14 +130,12 @@ class Gemma2Attention(nn.Module):
             bias=config.attention_bias,
             quant_config=quant_config,
         )
-
-        self.rotary_emb = GemmaRotaryEmbedding(
-            self.head_dim,
+        self.rotary_emb = get_rope(
             self.head_dim,
-
+            rotary_dim=self.head_dim,
+            max_position=max_position_embeddings,
             base=self.rope_theta,
             is_neox_style=True,
-            dtype=torch.get_default_dtype(),
         )
 
         use_sliding_window = layer_id % 2 == 0 and hasattr(config, "sliding_window")
@@ -442,6 +427,11 @@ class Gemma2ForCausalLM(nn.Module):
                 # Skip loading extra bias for GPTQ models.
                 if name.endswith(".bias") and name not in params_dict:
                     continue
+                # Remapping the name of FP8 kv-scale.
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+
                 param = params_dict[name]
                 weight_loader = getattr(param, "weight_loader", default_weight_loader)
                 weight_loader(param, loaded_weight)
sglang/srt/models/gpt2.py
CHANGED
@@ -17,16 +17,14 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 """Inference-only GPT-2 model compatible with HuggingFace weights."""
-from typing import Iterable,
+from typing import Iterable, Optional, Tuple
 
 import torch
 from torch import nn
 from transformers import GPT2Config
-from vllm.distributed.parallel_state import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.activation import get_act_fn
-from vllm.model_executor.layers.vocab_parallel_embedding import VocabParallelEmbedding
 
-
+from sglang.srt.distributed.parallel_state import get_tensor_model_parallel_world_size
+from sglang.srt.layers.activation import get_act_fn
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
     QKVParallelLinear,
sglang/srt/models/gpt_bigcode.py
CHANGED
@@ -21,8 +21,8 @@ from typing import Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import GPTBigCodeConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
 
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import get_act_fn
 from sglang.srt.layers.linear import (
     ColumnParallelLinear,
sglang/srt/models/granite.py
CHANGED
@@ -22,9 +22,8 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import GraniteConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.rotary_embedding import get_rope
 
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -36,6 +35,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorO
 from sglang.srt.layers.pooler import Pooler, PoolingType
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
sglang/srt/models/grok.py
CHANGED
@@ -22,12 +22,11 @@ import torch
 import torch.nn.functional as F
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.distributed import (
+
+from sglang.srt.distributed import (
     get_tensor_model_parallel_rank,
     get_tensor_model_parallel_world_size,
 )
-from vllm.model_executor.layers.rotary_embedding import get_rope
-
 from sglang.srt.layers.activation import GeluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -40,6 +39,7 @@ from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
@@ -133,6 +133,7 @@ class Grok1MoE(nn.Module):
             renormalize=False,
             quant_config=quant_config,
             tp_size=tp_size,
+            activation="gelu",
             use_presharded_weights=use_presharded_weights,
         )
 
sglang/srt/models/internlm2.py
CHANGED
@@ -19,9 +19,8 @@ from typing import Any, Dict, Iterable, Optional, Tuple
 import torch
 from torch import nn
 from transformers import PretrainedConfig
-from vllm.distributed import get_tensor_model_parallel_world_size
-from vllm.model_executor.layers.rotary_embedding import get_rope
 
+from sglang.srt.distributed import get_tensor_model_parallel_world_size
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -32,6 +31,7 @@ from sglang.srt.layers.linear import (
 from sglang.srt.layers.logits_processor import LogitsProcessor
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
 from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,