sglang 0.5.0rc1__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +0 -7
- sglang/bench_one_batch_server.py +7 -2
- sglang/bench_serving.py +3 -3
- sglang/eval/llama3_eval.py +0 -1
- sglang/srt/configs/model_config.py +25 -9
- sglang/srt/configs/update_config.py +40 -5
- sglang/srt/constrained/xgrammar_backend.py +23 -11
- sglang/srt/conversation.py +2 -15
- sglang/srt/disaggregation/ascend/conn.py +1 -3
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +1 -2
- sglang/srt/disaggregation/launch_lb.py +7 -1
- sglang/srt/disaggregation/mini_lb.py +11 -5
- sglang/srt/disaggregation/mooncake/conn.py +141 -47
- sglang/srt/disaggregation/prefill.py +261 -5
- sglang/srt/disaggregation/utils.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/device_communicators/pynccl.py +68 -18
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
- sglang/srt/distributed/naive_distributed.py +112 -0
- sglang/srt/distributed/parallel_state.py +90 -4
- sglang/srt/entrypoints/context.py +20 -1
- sglang/srt/entrypoints/engine.py +29 -4
- sglang/srt/entrypoints/http_server.py +76 -0
- sglang/srt/entrypoints/openai/protocol.py +4 -2
- sglang/srt/entrypoints/openai/serving_chat.py +23 -6
- sglang/srt/entrypoints/openai/serving_completions.py +10 -1
- sglang/srt/entrypoints/openai/serving_responses.py +2 -2
- sglang/srt/eplb/expert_distribution.py +2 -3
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +24 -0
- sglang/srt/host_shared_memory.py +83 -0
- sglang/srt/layers/attention/ascend_backend.py +132 -22
- sglang/srt/layers/attention/flashattention_backend.py +24 -17
- sglang/srt/layers/attention/flashinfer_backend.py +14 -3
- sglang/srt/layers/attention/flashinfer_mla_backend.py +227 -76
- sglang/srt/layers/attention/triton_backend.py +109 -73
- sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
- sglang/srt/layers/attention/trtllm_mha_backend.py +398 -36
- sglang/srt/layers/attention/trtllm_mla_backend.py +49 -19
- sglang/srt/layers/attention/utils.py +94 -15
- sglang/srt/layers/attention/vision.py +40 -13
- sglang/srt/layers/attention/vision_utils.py +65 -0
- sglang/srt/layers/communicator.py +58 -10
- sglang/srt/layers/dp_attention.py +137 -27
- sglang/srt/layers/elementwise.py +94 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
- sglang/srt/layers/layernorm.py +8 -1
- sglang/srt/layers/linear.py +24 -0
- sglang/srt/layers/logits_processor.py +16 -18
- sglang/srt/layers/moe/__init__.py +31 -0
- sglang/srt/layers/moe/ep_moe/layer.py +37 -33
- sglang/srt/layers/moe/fused_moe_native.py +14 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=129,N=352,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_2_0/E=161,N=192,device_name=NVIDIA_A800-SXM4-80GB,dtype=int8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_0/E=16,N=1024,device_name=NVIDIA_B200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=384,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=160,N=640,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=128,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H20-3e,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=257,N=256,device_name=NVIDIA_H200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
- sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
- sglang/srt/layers/moe/moe_runner/base.py +13 -0
- sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
- sglang/srt/layers/moe/router.py +15 -9
- sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
- sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +167 -83
- sglang/srt/layers/moe/utils.py +159 -18
- sglang/srt/layers/multimodal.py +156 -40
- sglang/srt/layers/quantization/__init__.py +18 -46
- sglang/srt/layers/quantization/awq.py +22 -23
- sglang/srt/layers/quantization/base_config.py +2 -6
- sglang/srt/layers/quantization/blockwise_int8.py +4 -12
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -29
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
- sglang/srt/layers/quantization/fp8.py +127 -119
- sglang/srt/layers/quantization/fp8_kernel.py +195 -24
- sglang/srt/layers/quantization/fp8_utils.py +34 -9
- sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
- sglang/srt/layers/quantization/gptq.py +17 -21
- sglang/srt/layers/quantization/marlin_utils.py +26 -8
- sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
- sglang/srt/layers/quantization/modelopt_quant.py +217 -98
- sglang/srt/layers/quantization/moe_wna16.py +10 -15
- sglang/srt/layers/quantization/mxfp4.py +222 -39
- sglang/srt/layers/quantization/quark/quark.py +390 -0
- sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
- sglang/srt/layers/quantization/unquant.py +34 -70
- sglang/srt/layers/quantization/utils.py +77 -2
- sglang/srt/layers/quantization/w4afp8.py +7 -8
- sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
- sglang/srt/layers/quantization/w8a8_int8.py +5 -13
- sglang/srt/layers/radix_attention.py +6 -0
- sglang/srt/layers/rotary_embedding.py +1 -0
- sglang/srt/layers/sampler.py +5 -2
- sglang/srt/lora/layers.py +6 -2
- sglang/srt/lora/lora_manager.py +21 -22
- sglang/srt/lora/lora_registry.py +3 -3
- sglang/srt/lora/mem_pool.py +26 -24
- sglang/srt/lora/utils.py +10 -12
- sglang/srt/managers/cache_controller.py +80 -19
- sglang/srt/managers/detokenizer_manager.py +10 -2
- sglang/srt/managers/io_struct.py +23 -0
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/schedule_batch.py +22 -48
- sglang/srt/managers/scheduler.py +28 -20
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/template_manager.py +7 -5
- sglang/srt/managers/tokenizer_manager.py +88 -39
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/managers/utils.py +59 -1
- sglang/srt/mem_cache/allocator.py +10 -157
- sglang/srt/mem_cache/allocator_ascend.py +147 -0
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +14 -4
- sglang/srt/mem_cache/memory_pool.py +3 -3
- sglang/srt/mem_cache/memory_pool_host.py +35 -2
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
- sglang/srt/model_executor/cuda_graph_runner.py +33 -33
- sglang/srt/model_executor/forward_batch_info.py +11 -10
- sglang/srt/model_executor/model_runner.py +93 -78
- sglang/srt/model_executor/npu_graph_runner.py +94 -0
- sglang/srt/model_loader/loader.py +24 -6
- sglang/srt/models/dbrx.py +12 -6
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +5 -2
- sglang/srt/models/deepseek_v2.py +226 -223
- sglang/srt/models/ernie4.py +2 -2
- sglang/srt/models/glm4_moe.py +27 -65
- sglang/srt/models/glm4_moe_nextn.py +2 -1
- sglang/srt/models/glm4v.py +52 -1
- sglang/srt/models/glm4v_moe.py +8 -11
- sglang/srt/models/gpt_oss.py +41 -76
- sglang/srt/models/granitemoe.py +0 -1
- sglang/srt/models/grok.py +376 -48
- sglang/srt/models/interns1.py +12 -47
- sglang/srt/models/internvl.py +6 -51
- sglang/srt/models/llama.py +10 -2
- sglang/srt/models/llama4.py +18 -7
- sglang/srt/models/minicpm3.py +0 -1
- sglang/srt/models/mixtral.py +0 -2
- sglang/srt/models/nemotron_nas.py +435 -0
- sglang/srt/models/olmoe.py +0 -1
- sglang/srt/models/phi4mm.py +3 -21
- sglang/srt/models/qwen2.py +2 -2
- sglang/srt/models/qwen2_5_vl.py +2 -0
- sglang/srt/models/qwen2_moe.py +23 -23
- sglang/srt/models/qwen3.py +2 -2
- sglang/srt/models/qwen3_classification.py +84 -0
- sglang/srt/models/qwen3_moe.py +27 -43
- sglang/srt/models/step3_vl.py +8 -3
- sglang/srt/models/xverse_moe.py +11 -5
- sglang/srt/multimodal/processors/base_processor.py +3 -3
- sglang/srt/multimodal/processors/internvl.py +7 -2
- sglang/srt/multimodal/processors/llava.py +11 -7
- sglang/srt/offloader.py +433 -0
- sglang/srt/operations.py +22 -2
- sglang/srt/reasoning_parser.py +4 -3
- sglang/srt/sampling/sampling_batch_info.py +7 -4
- sglang/srt/server_args.py +264 -105
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +8 -21
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +7 -21
- sglang/srt/speculative/eagle_utils.py +36 -13
- sglang/srt/speculative/eagle_worker.py +56 -3
- sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
- sglang/srt/two_batch_overlap.py +20 -19
- sglang/srt/utils.py +68 -70
- sglang/test/runners.py +8 -5
- sglang/test/test_block_fp8.py +5 -6
- sglang/test/test_block_fp8_ep.py +13 -19
- sglang/test/test_cutlass_moe.py +4 -6
- sglang/test/test_cutlass_w4a8_moe.py +4 -3
- sglang/test/test_fp4_moe.py +4 -3
- sglang/test/test_marlin_moe.py +1 -1
- sglang/test/test_marlin_utils.py +1 -1
- sglang/test/test_utils.py +7 -0
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/METADATA +11 -11
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/RECORD +201 -171
- sglang/srt/layers/quantization/fp4.py +0 -557
- sglang/srt/layers/quantization/scalar_type.py +0 -352
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc1.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,435 @@
+# Copyright 2023-2025 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/nemotron_nas.py
+
+"""Inference-only deci model compatible with HuggingFace weights."""
+from typing import Iterable, Optional, Tuple, Type, Union
+
+import torch
+from torch import nn
+from transformers import LlamaConfig
+
+from sglang.srt.distributed import get_pp_group
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+from sglang.srt.layers.pooler import Pooler, PoolingType
+from sglang.srt.layers.quantization import QuantizationConfig
+from sglang.srt.layers.utils import PPMissingLayer
+from sglang.srt.layers.vocab_parallel_embedding import (
+    DEFAULT_VOCAB_PADDING_SIZE,
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    maybe_remap_kv_scale_name,
+)
+from sglang.srt.models.llama import LlamaAttention, LlamaMLP
+from sglang.srt.utils import add_prefix, make_layers
+from sglang.utils import logger
+
+
+def _ffn_mult_to_intermediate_size(ffn_mult: float, n_embd: int) -> int:
+    # DeciLM-specific code
+    intermediate_size = int(2 * ffn_mult * n_embd / 3)
+    return _find_multiple(intermediate_size, 256)
+
+
+def _find_multiple(n: int, k: int) -> int:
+    # DeciLM-specific code
+    if n % k == 0:
+        return n
+    return n + k - (n % k)
+
+
+class DeciLMDecoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: LlamaConfig,
+        layer_idx: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        block_config = config.block_configs[layer_idx]
+        self._is_no_op_attention = block_config.attention.no_op
+        self._is_no_op_ffn = block_config.ffn.no_op
+
+        self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        if rope_scaling is not None and getattr(
+            config, "original_max_position_embeddings", None
+        ):
+            rope_scaling["original_max_position_embeddings"] = (
+                config.original_max_position_embeddings
+            )
+        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+        # Support abacusai/Smaug-72B-v0.1 with attention_bias
+        # Support internlm/internlm-7b with bias
+        rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
+        attention_bias = getattr(config, "attention_bias", False) or getattr(
+            config, "bias", False
+        )
+        # support internlm/internlm3-8b with qkv_bias
+        if hasattr(config, "qkv_bias"):
+            attention_bias = config.qkv_bias
+
+        if not self._is_no_op_attention:
+            num_kv_heads = (
+                config.num_attention_heads // block_config.attention.n_heads_in_group
+            )
+            self.self_attn = LlamaAttention(
+                config=config,
+                hidden_size=self.hidden_size,
+                num_heads=config.num_attention_heads,
+                num_kv_heads=num_kv_heads,
+                layer_id=layer_idx,
+                rope_theta=rope_theta,
+                rope_scaling=rope_scaling,
+                rope_is_neox_style=rope_is_neox_style,
+                max_position_embeddings=max_position_embeddings,
+                quant_config=quant_config,
+                prefix=add_prefix("self_attn", prefix),
+                bias=attention_bias,
+            )
+            self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+        if not self._is_no_op_ffn:
+            ffn_mult = block_config.ffn.ffn_mult
+            intermediate_size = _ffn_mult_to_intermediate_size(
+                ffn_mult, config.hidden_size
+            )
+            self.mlp = LlamaMLP(
+                hidden_size=self.hidden_size,
+                intermediate_size=intermediate_size,
+                hidden_act=config.hidden_act,
+                quant_config=quant_config,
+                prefix=add_prefix("mlp", prefix),
+            )
+            self.post_attention_layernorm = RMSNorm(
+                config.hidden_size, eps=config.rms_norm_eps
+            )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Self Attention
+
+        if self._is_no_op_attention:
+            pass
+        else:
+            if residual is None:
+                residual = hidden_states
+                hidden_states = self.input_layernorm(hidden_states)
+            else:
+                hidden_states, residual = self.input_layernorm(hidden_states, residual)
+            hidden_states = self.self_attn(
+                positions=positions,
+                hidden_states=hidden_states,
+                forward_batch=forward_batch,
+            )
+
+        # Fully Connected
+        if not self._is_no_op_ffn:
+            hidden_states, residual = self.post_attention_layernorm(
+                hidden_states, residual
+            )
+            hidden_states = self.mlp(hidden_states)
+        return hidden_states, residual
+
+
+class DeciModel(nn.Module):
+    def __init__(
+        self,
+        *,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        layer_type: Type[DeciLMDecoderLayer] = DeciLMDecoderLayer,
+    ):
+        super().__init__()
+
+        lora_config = None
+        self.config = config
+        self.quant_config = quant_config
+        self.padding_idx = config.pad_token_id
+        lora_vocab = (
+            (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
+            if lora_config
+            else 0
+        )
+        vocab_size = config.vocab_size + lora_vocab
+        if get_pp_group().is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                vocab_size,
+                config.hidden_size,
+                org_num_embeddings=config.vocab_size,
+                quant_config=quant_config,
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+
+        def get_layer(idx: int, prefix: str):
+            return layer_type(
+                config,
+                layer_idx=idx,
+                quant_config=quant_config,
+                prefix=prefix,
+            )
+
+        self.layers, self.start_layer, self.end_layer = make_layers(
+            config.num_hidden_layers,
+            get_layer,
+            pp_rank=get_pp_group().rank_in_group,
+            pp_size=get_pp_group().world_size,
+            prefix=add_prefix("layers", prefix),
+        )
+        if get_pp_group().is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer(return_tuple=True)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor],
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[torch.Tensor, PPProxyTensors]:
+        if get_pp_group().is_first_rank:
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
+            residual = None
+        else:
+            assert pp_proxy_tensors is not None
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]
+
+        kv_cache_index = 0
+        for i in range(self.start_layer, self.end_layer):
+            layer = self.layers[i]
+            if not layer._is_no_op_attention:
+                hidden_states, residual = layer(
+                    positions, hidden_states, forward_batch, residual
+                )
+                kv_cache_index += 1
+            else:
+                hidden_states, residual = layer(
+                    positions, hidden_states, forward_batch, residual
+                )
+
+        if not get_pp_group().is_last_rank:
+            return PPProxyTensors(
+                {"hidden_states": hidden_states, "residual": residual}
+            )
+
+        hidden_states, _ = self.norm(hidden_states, residual)
+        return hidden_states
+
+
+class DeciLMForCausalLM(nn.Module):
+    packed_modules_mapping = {
+        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+        "gate_up_proj": ["gate_proj", "up_proj"],
+    }
+
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "qkv_proj",
+        "o_proj",
+        "gate_up_proj",
+        "down_proj",
+        "embed_tokens",
+        "lm_head",
+    ]
+    embedding_modules = {
+        "embed_tokens": "input_embeddings",
+        "lm_head": "output_embeddings",
+    }
+    embedding_padding_modules = ["lm_head"]
+
+    # Mistral/Llama models can also be loaded with --load-format mistral
+    # from consolidated.safetensors checkpoints
+    mistral_mapping = {
+        "layers": "model.layers",
+        "attention": "self_attn",
+        "wq": "q_proj",
+        "wk": "k_proj",
+        "wv": "v_proj",
+        "wo": "o_proj",
+        "attention_norm": "input_layernorm",
+        "feed_forward": "mlp",
+        "w1": "gate_proj",
+        "w2": "down_proj",
+        "w3": "up_proj",
+        "ffn_norm": "post_attention_layernorm",
+        "tok_embeddings": "model.embed_tokens",
+        "output": "lm_head",
+        "norm": "model.norm",
+    }
+
+    def __init__(
+        self,
+        *,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        lora_config = None
+        self.config = config
+        self.lora_config = lora_config
+
+        self.model = self._init_model(
+            config=config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+        )
+        if self.config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.unpadded_vocab_size = config.vocab_size
+            if lora_config:
+                self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
+            self.lm_head = ParallelLMHead(
+                self.unpadded_vocab_size,
+                config.hidden_size,
+                org_num_embeddings=config.vocab_size,
+                padding_size=(
+                    DEFAULT_VOCAB_PADDING_SIZE
+                    # We need bigger padding if using lora for kernel
+                    # compatibility
+                    if not lora_config
+                    else lora_config.lora_vocab_padding_size
+                ),
+                quant_config=quant_config,
+                prefix=add_prefix("lm_head", prefix),
+            )
+        self.logits_processor = LogitsProcessor(config)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+
+    def _init_model(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        return DeciModel(config=config, quant_config=quant_config, prefix=prefix)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        get_embedding: bool = False,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> LogitsProcessorOutput:
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            inputs_embeds,
+            pp_proxy_tensors=pp_proxy_tensors,
+        )
+        if get_pp_group().is_last_rank:
+            if not get_embedding:
+                return self.logits_processor(
+                    input_ids, hidden_states, self.lm_head, forward_batch
+                )
+            else:
+                return self.pooler(hidden_states, forward_batch)
+        else:
+            return hidden_states
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> None:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            (".gate_up_proj", ".gate_proj", 0),
+            (".gate_up_proj", ".up_proj", 1),
+        ]
+
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                # Models trained using ColossalAI may include these tensors in
+                # the checkpoint. Skip them.
+                continue
+            if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                continue
+            if self.model.quant_config is not None and (
+                scale_name := self.model.quant_config.get_cache_scale(name)
+            ):
+                # Loading kv cache quantization scales
+                param = params_dict[scale_name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                loaded_weight = (
+                    loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
+                )
+                weight_loader(param, loaded_weight)
+                continue
+            if "scale" in name:
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if name in params_dict.keys():
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+                else:
+                    logger.warning(f"Parameter {name} not found in params_dict")
+
+
+EntryClass = [DeciLMForCausalLM]
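For reference, the two DeciLM sizing helpers in the new file compute `2 * ffn_mult * hidden_size / 3` and round it up to the next multiple of 256. A standalone sketch of that arithmetic (the helper bodies are copied from the hunk above; the example inputs are illustrative, not taken from any released checkpoint):

```python
def _find_multiple(n: int, k: int) -> int:
    # Round n up to the nearest multiple of k.
    if n % k == 0:
        return n
    return n + k - (n % k)


def _ffn_mult_to_intermediate_size(ffn_mult: float, n_embd: int) -> int:
    # Same formula as in nemotron_nas.py: 2/3 * ffn_mult * hidden size,
    # truncated to int, then rounded up to a multiple of 256.
    intermediate_size = int(2 * ffn_mult * n_embd / 3)
    return _find_multiple(intermediate_size, 256)


print(_ffn_mult_to_intermediate_size(2.625, 8192))  # 14336, already a multiple of 256
print(_ffn_mult_to_intermediate_size(1.3, 4096))    # int(3549.86...) = 3549, rounded up to 3584
```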
sglang/srt/models/olmoe.py CHANGED
sglang/srt/models/phi4mm.py CHANGED
@@ -54,25 +54,6 @@ VISION_ENCODER_TO_PROCESSING_CONFIG = {
 }


-def get_navit_vision_model():
-    vision_config = {
-        "hidden_size": 1152,
-        "image_size": 448,
-        "intermediate_size": 4304,
-        "model_type": "siglip_vision_model",
-        "num_attention_heads": 16,
-        "num_hidden_layers": 26,  # Model is originally 27-layer, we only need the first 26 layers for feature extraction.
-        "patch_size": 14,
-    }
-    model_config = SiglipVisionConfig(**vision_config)
-
-    vision_model = Idefics2VisionTransformer(
-        config=model_config, require_post_norm=False
-    )
-
-    return vision_model
-
-
 class Phi4MMImageEncoder(nn.Module):
     """Image embedding."""

@@ -88,8 +69,9 @@ class Phi4MMImageEncoder(nn.Module):
         # n_embed or hidden_size
         hidden_size = config.n_embd if hasattr(config, "n_embd") else config.hidden_size
         self.type_feature = "patch"
-
-
+        self.img_processor = Idefics2VisionTransformer(
+            config=config.vision_config, require_post_norm=False
+        )

         pe_weight = self.img_processor.embeddings.position_embedding.weight
         L, D = pe_weight.size()
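The removed `get_navit_vision_model` helper hard-coded a SigLIP vision configuration, while the replacement builds the Idefics2 vision transformer from `config.vision_config`. As a minimal sketch of what the old helper constructed (field values copied from the removed block; building the transformer itself is omitted since it depends on sglang internals):

```python
from transformers import SiglipVisionConfig

# Vision settings previously hard-coded in get_navit_vision_model(); after this
# change they are expected to come from the checkpoint's config.vision_config.
navit_vision_config = SiglipVisionConfig(
    hidden_size=1152,
    image_size=448,
    intermediate_size=4304,
    num_attention_heads=16,
    num_hidden_layers=26,  # first 26 of the original 27 layers, used for feature extraction
    patch_size=14,
)
print(navit_vision_config.model_type, navit_vision_config.hidden_size)
```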
sglang/srt/models/qwen2.py CHANGED
@@ -27,6 +27,7 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
 )
 from sglang.srt.layers.activation import SiluAndMul
+from sglang.srt.layers.dp_attention import is_dp_attention_enabled
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
     MergedColumnParallelLinear,
@@ -43,7 +44,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
@@ -273,7 +273,7 @@ class Qwen2Model(nn.Module):
                 config.vocab_size,
                 config.hidden_size,
                 quant_config=quant_config,
-                enable_tp=not
+                enable_tp=not is_dp_attention_enabled(),
                 prefix=add_prefix("embed_tokens", prefix),
             )
         else:
sglang/srt/models/qwen2_5_vl.py CHANGED
@@ -117,6 +117,7 @@ class Qwen2_5_VisionBlock(nn.Module):
         attn_implementation: Optional[str] = None,
         quant_config: Optional[QuantizationConfig] = None,
         prefix: str = "",
+        num_dummy_heads: int = 0,
     ) -> None:
         super().__init__()
         if norm_layer is None:
@@ -157,6 +158,7 @@ class Qwen2_5_VisionBlock(nn.Module):
             flatten_batch=flatten_batch,
             quant_config=quant_config,
             prefix=add_prefix("attn", prefix),
+            num_dummy_heads=num_dummy_heads,
         )
         self.mlp = Qwen2_5_VLMLP(
             dim,
sglang/srt/models/qwen2_moe.py CHANGED
@@ -17,8 +17,6 @@
 """Inference-only Qwen2MoE model compatible with HuggingFace weights."""

 import logging
-from dataclasses import dataclass
-from enum import Enum, auto
 from typing import Any, Dict, Iterable, Optional, Tuple, Union

 import torch
@@ -31,10 +29,7 @@ from sglang.srt.distributed import (
     get_tensor_model_parallel_world_size,
     tensor_model_parallel_all_reduce,
 )
-from sglang.srt.eplb.expert_distribution import (
-    ExpertDistributionRecorder,
-    get_global_expert_distribution_recorder,
-)
+from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
 from sglang.srt.eplb.expert_location import ModelConfigForExpertLocation
 from sglang.srt.layers.activation import SiluAndMul
 from sglang.srt.layers.communicator import (
@@ -45,7 +40,7 @@ from sglang.srt.layers.communicator import (
 from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
-    get_local_attention_dp_size,
+    is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm
 from sglang.srt.layers.linear import (
@@ -54,8 +49,8 @@ from sglang.srt.layers.linear import (
     ReplicatedLinear,
     RowParallelLinear,
 )
-from sglang.srt.layers.logits_processor import LogitsProcessor
-from sglang.srt.layers.moe.ep_moe.layer import
+from sglang.srt.layers.logits_processor import LogitsProcessor
+from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
 from sglang.srt.layers.moe.fused_moe_triton import FusedMoE
 from sglang.srt.layers.moe.topk import TopK
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
@@ -107,10 +102,14 @@ class Qwen2MoeMLP(nn.Module):
         )
         self.act_fn = SiluAndMul()

-    def forward(
+    def forward(
+        self,
+        x,
+        use_reduce_scatter: bool = False,
+    ):
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
-        x, _ = self.down_proj(x)
+        x, _ = self.down_proj(x, skip_all_reduce=use_reduce_scatter)
         return x


@@ -144,14 +143,6 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
             intermediate_size=config.moe_intermediate_size,
             quant_config=quant_config,
             prefix=add_prefix("experts", prefix),
-            # Additional args for FusedMoE
-            **(
-                dict(
-                    enable_flashinfer_cutlass_moe=True,
-                )
-                if global_server_args_dict["enable_flashinfer_cutlass_moe"]
-                else {}
-            ),
         )

         self.gate = ReplicatedLinear(
@@ -175,7 +166,10 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         self.shared_expert_gate = torch.nn.Linear(config.hidden_size, 1, bias=False)

     def forward(
-        self,
+        self,
+        hidden_states: torch.Tensor,
+        forward_batch: Optional[ForwardBatch] = None,
+        use_reduce_scatter: bool = False,
     ) -> torch.Tensor:
         num_tokens, hidden_dim = hidden_states.shape
         hidden_states = hidden_states.view(-1, hidden_dim)
@@ -193,6 +187,7 @@ class Qwen2MoeSparseMoeBlock(nn.Module):
         final_hidden_states = self.experts(hidden_states, topk_output)
         if shared_output is not None:
             final_hidden_states = final_hidden_states + shared_output
+        if self.tp_size > 1 and not use_reduce_scatter:
             final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)

         return final_hidden_states.view(num_tokens, hidden_dim)
@@ -331,7 +326,6 @@ class Qwen2MoeDecoderLayer(nn.Module):

         self.attn_tp_size = get_attention_tp_size()
         self.attn_tp_rank = get_attention_tp_rank()
-        self.local_dp_size = get_local_attention_dp_size()

         # Qwen2MoE all layers are sparse and have no nextn now
         self.is_layer_sparse = True
@@ -367,6 +361,7 @@ class Qwen2MoeDecoderLayer(nn.Module):
             layer_scatter_modes=self.layer_scatter_modes,
             input_layernorm=self.input_layernorm,
             post_attention_layernorm=self.post_attention_layernorm,
+            allow_reduce_scatter=True,
         )

     def forward(
@@ -392,7 +387,12 @@ class Qwen2MoeDecoderLayer(nn.Module):
             hidden_states, residual, forward_batch
         )

-
+        # For DP with padding, reduce scatter can be used instead of all-reduce.
+        use_reduce_scatter = self.layer_communicator.should_use_reduce_scatter(
+            forward_batch
+        )
+
+        hidden_states = self.mlp(hidden_states, forward_batch, use_reduce_scatter)

         hidden_states, residual = self.layer_communicator.postprocess_layer(
             hidden_states, residual, forward_batch
@@ -420,7 +420,7 @@ class Qwen2MoeModel(nn.Module):
             self.embed_tokens = VocabParallelEmbedding(
                 config.vocab_size,
                 config.hidden_size,
-                enable_tp=not
+                enable_tp=not is_dp_attention_enabled(),
                 prefix=add_prefix("embed_tokens", prefix),
             )
         else:
sglang/srt/models/qwen3.py CHANGED
@@ -327,8 +327,8 @@ class Qwen3ForCausalLM(nn.Module):
         # For EAGLE3 support
         self.capture_aux_hidden_states = False

-    def get_input_embeddings(self
-        return self.model.get_input_embeddings(
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.model.get_input_embeddings()

     @torch.no_grad()
     def forward(