sglang 0.5.0rc2__py3-none-any.whl → 0.5.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +0 -6
- sglang/bench_one_batch_server.py +7 -2
- sglang/bench_serving.py +3 -3
- sglang/eval/llama3_eval.py +0 -1
- sglang/srt/configs/model_config.py +24 -9
- sglang/srt/configs/update_config.py +40 -5
- sglang/srt/constrained/xgrammar_backend.py +23 -11
- sglang/srt/conversation.py +2 -15
- sglang/srt/disaggregation/ascend/conn.py +1 -3
- sglang/srt/disaggregation/base/conn.py +1 -0
- sglang/srt/disaggregation/decode.py +1 -1
- sglang/srt/disaggregation/launch_lb.py +7 -1
- sglang/srt/disaggregation/mini_lb.py +11 -5
- sglang/srt/disaggregation/mooncake/conn.py +141 -47
- sglang/srt/disaggregation/prefill.py +261 -5
- sglang/srt/disaggregation/utils.py +2 -1
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +1 -1
- sglang/srt/distributed/device_communicators/pynccl.py +68 -18
- sglang/srt/distributed/device_communicators/pynccl_wrapper.py +52 -0
- sglang/srt/distributed/naive_distributed.py +112 -0
- sglang/srt/distributed/parallel_state.py +90 -4
- sglang/srt/entrypoints/context.py +20 -1
- sglang/srt/entrypoints/engine.py +27 -2
- sglang/srt/entrypoints/http_server.py +12 -0
- sglang/srt/entrypoints/openai/protocol.py +2 -2
- sglang/srt/entrypoints/openai/serving_chat.py +22 -6
- sglang/srt/entrypoints/openai/serving_completions.py +9 -1
- sglang/srt/entrypoints/openai/serving_responses.py +2 -2
- sglang/srt/eplb/expert_distribution.py +2 -3
- sglang/srt/function_call/deepseekv3_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +24 -0
- sglang/srt/host_shared_memory.py +83 -0
- sglang/srt/layers/attention/ascend_backend.py +132 -22
- sglang/srt/layers/attention/flashattention_backend.py +24 -17
- sglang/srt/layers/attention/flashinfer_backend.py +11 -3
- sglang/srt/layers/attention/flashinfer_mla_backend.py +226 -76
- sglang/srt/layers/attention/triton_backend.py +85 -46
- sglang/srt/layers/attention/triton_ops/decode_attention.py +33 -2
- sglang/srt/layers/attention/triton_ops/extend_attention.py +32 -2
- sglang/srt/layers/attention/trtllm_mha_backend.py +390 -30
- sglang/srt/layers/attention/trtllm_mla_backend.py +39 -16
- sglang/srt/layers/attention/utils.py +94 -15
- sglang/srt/layers/attention/vision.py +40 -13
- sglang/srt/layers/attention/vision_utils.py +65 -0
- sglang/srt/layers/communicator.py +51 -3
- sglang/srt/layers/dp_attention.py +23 -4
- sglang/srt/layers/elementwise.py +94 -0
- sglang/srt/layers/flashinfer_comm_fusion.py +29 -1
- sglang/srt/layers/layernorm.py +8 -1
- sglang/srt/layers/linear.py +24 -0
- sglang/srt/layers/logits_processor.py +5 -1
- sglang/srt/layers/moe/__init__.py +31 -0
- sglang/srt/layers/moe/ep_moe/layer.py +37 -33
- sglang/srt/layers/moe/fused_moe_native.py +14 -25
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=384,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=704,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=161,N=384,device_name=NVIDIA_RTX_PRO_6000_Blackwell_Max-Q_Workstation_Edition,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +69 -76
- sglang/srt/layers/moe/fused_moe_triton/layer.py +66 -123
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +20 -18
- sglang/srt/layers/moe/moe_runner/__init__.py +3 -0
- sglang/srt/layers/moe/moe_runner/base.py +13 -0
- sglang/srt/layers/moe/rocm_moe_utils.py +141 -0
- sglang/srt/layers/moe/router.py +15 -9
- sglang/srt/layers/moe/token_dispatcher/__init__.py +6 -0
- sglang/srt/layers/moe/token_dispatcher/base_dispatcher.py +55 -14
- sglang/srt/layers/moe/token_dispatcher/deepep.py +11 -21
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +167 -83
- sglang/srt/layers/moe/utils.py +159 -18
- sglang/srt/layers/quantization/__init__.py +13 -14
- sglang/srt/layers/quantization/awq.py +7 -7
- sglang/srt/layers/quantization/base_config.py +2 -6
- sglang/srt/layers/quantization/blockwise_int8.py +4 -12
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +72 -28
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +2 -1
- sglang/srt/layers/quantization/fp8.py +127 -119
- sglang/srt/layers/quantization/fp8_kernel.py +195 -24
- sglang/srt/layers/quantization/fp8_utils.py +34 -9
- sglang/srt/layers/quantization/fpgemm_fp8.py +203 -0
- sglang/srt/layers/quantization/gptq.py +5 -4
- sglang/srt/layers/quantization/marlin_utils.py +11 -3
- sglang/srt/layers/quantization/marlin_utils_fp8.py +352 -0
- sglang/srt/layers/quantization/modelopt_quant.py +165 -68
- sglang/srt/layers/quantization/moe_wna16.py +10 -15
- sglang/srt/layers/quantization/mxfp4.py +206 -37
- sglang/srt/layers/quantization/quark/quark.py +390 -0
- sglang/srt/layers/quantization/quark/quark_moe.py +197 -0
- sglang/srt/layers/quantization/unquant.py +34 -70
- sglang/srt/layers/quantization/utils.py +25 -0
- sglang/srt/layers/quantization/w4afp8.py +7 -8
- sglang/srt/layers/quantization/w8a8_fp8.py +5 -13
- sglang/srt/layers/quantization/w8a8_int8.py +5 -13
- sglang/srt/layers/radix_attention.py +6 -0
- sglang/srt/layers/rotary_embedding.py +1 -0
- sglang/srt/lora/lora_manager.py +21 -22
- sglang/srt/lora/lora_registry.py +3 -3
- sglang/srt/lora/mem_pool.py +26 -24
- sglang/srt/lora/utils.py +10 -12
- sglang/srt/managers/cache_controller.py +76 -18
- sglang/srt/managers/detokenizer_manager.py +10 -2
- sglang/srt/managers/io_struct.py +9 -0
- sglang/srt/managers/mm_utils.py +1 -1
- sglang/srt/managers/schedule_batch.py +4 -9
- sglang/srt/managers/scheduler.py +25 -16
- sglang/srt/managers/session_controller.py +1 -1
- sglang/srt/managers/template_manager.py +7 -5
- sglang/srt/managers/tokenizer_manager.py +60 -21
- sglang/srt/managers/tp_worker.py +1 -0
- sglang/srt/managers/utils.py +59 -1
- sglang/srt/mem_cache/allocator.py +7 -5
- sglang/srt/mem_cache/allocator_ascend.py +0 -11
- sglang/srt/mem_cache/hicache_storage.py +14 -4
- sglang/srt/mem_cache/memory_pool.py +3 -3
- sglang/srt/mem_cache/memory_pool_host.py +35 -2
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +56 -12
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +8 -4
- sglang/srt/mem_cache/storage/nixl/hicache_nixl.py +153 -59
- sglang/srt/mem_cache/storage/nixl/nixl_utils.py +19 -53
- sglang/srt/mem_cache/storage/nixl/test_hicache_nixl_storage.py +46 -7
- sglang/srt/model_executor/cuda_graph_runner.py +25 -12
- sglang/srt/model_executor/forward_batch_info.py +4 -1
- sglang/srt/model_executor/model_runner.py +43 -32
- sglang/srt/model_executor/npu_graph_runner.py +94 -0
- sglang/srt/model_loader/loader.py +24 -6
- sglang/srt/models/dbrx.py +12 -6
- sglang/srt/models/deepseek.py +2 -1
- sglang/srt/models/deepseek_nextn.py +3 -1
- sglang/srt/models/deepseek_v2.py +224 -223
- sglang/srt/models/ernie4.py +2 -2
- sglang/srt/models/glm4_moe.py +25 -63
- sglang/srt/models/glm4v.py +52 -1
- sglang/srt/models/glm4v_moe.py +8 -11
- sglang/srt/models/gpt_oss.py +34 -74
- sglang/srt/models/granitemoe.py +0 -1
- sglang/srt/models/grok.py +376 -48
- sglang/srt/models/interns1.py +12 -47
- sglang/srt/models/internvl.py +6 -51
- sglang/srt/models/llama4.py +0 -2
- sglang/srt/models/minicpm3.py +0 -1
- sglang/srt/models/mixtral.py +0 -2
- sglang/srt/models/nemotron_nas.py +435 -0
- sglang/srt/models/olmoe.py +0 -1
- sglang/srt/models/phi4mm.py +3 -21
- sglang/srt/models/qwen2_5_vl.py +2 -0
- sglang/srt/models/qwen2_moe.py +3 -18
- sglang/srt/models/qwen3.py +2 -2
- sglang/srt/models/qwen3_classification.py +7 -1
- sglang/srt/models/qwen3_moe.py +9 -38
- sglang/srt/models/step3_vl.py +2 -1
- sglang/srt/models/xverse_moe.py +11 -5
- sglang/srt/multimodal/processors/base_processor.py +3 -3
- sglang/srt/multimodal/processors/internvl.py +7 -2
- sglang/srt/multimodal/processors/llava.py +11 -7
- sglang/srt/offloader.py +433 -0
- sglang/srt/operations.py +6 -1
- sglang/srt/reasoning_parser.py +4 -3
- sglang/srt/server_args.py +237 -104
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +1 -0
- sglang/srt/speculative/eagle_utils.py +36 -13
- sglang/srt/speculative/eagle_worker.py +56 -3
- sglang/srt/tokenizer/tiktoken_tokenizer.py +161 -0
- sglang/srt/two_batch_overlap.py +16 -11
- sglang/srt/utils.py +68 -70
- sglang/test/runners.py +8 -5
- sglang/test/test_block_fp8.py +5 -6
- sglang/test/test_block_fp8_ep.py +13 -19
- sglang/test/test_cutlass_moe.py +4 -6
- sglang/test/test_cutlass_w4a8_moe.py +4 -3
- sglang/test/test_fp4_moe.py +4 -3
- sglang/test/test_utils.py +7 -0
- sglang/utils.py +0 -1
- sglang/version.py +1 -1
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/METADATA +7 -7
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/RECORD +179 -161
- sglang/srt/layers/quantization/fp4.py +0 -557
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/WHEEL +0 -0
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.0rc2.dist-info → sglang-0.5.1.dist-info}/top_level.txt +0 -0
sglang/srt/models/interns1.py
CHANGED
@@ -4,8 +4,9 @@ import torch
 from torch import nn
 from transformers import PretrainedConfig
 
-from sglang.srt.
+from sglang.srt.layers.attention import vision_utils
 from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
@@ -20,6 +21,7 @@ from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.models.internvl import InternVisionModel
 from sglang.srt.models.qwen2 import Qwen2ForCausalLM
+from sglang.srt.models.qwen3 import Qwen3ForCausalLM
 from sglang.srt.models.qwen3_moe import Qwen3MoeForCausalLM
 from sglang.utils import logger
 
@@ -34,7 +36,7 @@ class InternS1ForConditionalGeneration(nn.Module):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.
+        vision_utils.update_vit_attn_dummy_heads_config(self.config)
         image_size = (
             getattr(config, "force_image_size", None) or config.vision_config.image_size
         )
@@ -69,6 +71,10 @@ class InternS1ForConditionalGeneration(nn.Module):
             self.language_model = Qwen3MoeForCausalLM(
                 config=config.text_config, quant_config=quant_config
             )
+        elif config.text_config.architectures[0] == "Qwen3ForCausalLM":
+            self.language_model = Qwen3ForCausalLM(
+                config=config.text_config, quant_config=quant_config
+            )
         else:
             raise NotImplementedError(
                 f"{config.text_config.architectures[0]} is not implemented."
@@ -86,21 +92,6 @@ class InternS1ForConditionalGeneration(nn.Module):
             nn.Linear(llm_hidden_size, llm_hidden_size),
         )
 
-    def _update_hf_config(self):
-        """update hf config to support tp"""
-        world_size = parallel_state.get_tensor_model_parallel_world_size()
-        num_heads = self.config.vision_config.num_attention_heads
-        head_dim = self.config.vision_config.hidden_size // num_heads
-        num_dummy_heads = 0
-
-        if num_heads % world_size != 0:
-            num_dummy_heads = (
-                (num_heads + world_size) // world_size
-            ) * world_size - num_heads
-
-        setattr(self.config.vision_config, "head_dim", head_dim)
-        setattr(self.config.vision_config, "num_dummy_heads", num_dummy_heads)
-
     def pixel_shuffle(self, x, scale_factor=0.5):
         n, w, h, c = x.size()
         # N, W, H, C --> N, W, H * scale, C // scale
@@ -183,34 +174,6 @@ class InternS1ForConditionalGeneration(nn.Module):
 
         return helper.pad_input_tokens(input_ids, mm_inputs)
 
-    def _pad_vit_attn_dummy_heads(self, name: str, loaded_weight: torch.Tensor):
-        """pad attn qkv weights for dummy heads"""
-        num_dummy_heads = self.config.vision_config.num_dummy_heads
-        if num_dummy_heads == 0:
-            return loaded_weight
-        head_dim = self.config.vision_config.head_dim
-
-        if any([_ in name for _ in ["attn.q_proj", "attn.k_proj", "attn.v_proj"]]):
-            if name.endswith(".weight"):
-                dummy_shape = [num_dummy_heads, head_dim, loaded_weight.shape[-1]]
-            elif name.endswith(".bias"):
-                dummy_shape = [num_dummy_heads, head_dim]
-            else:
-                raise RuntimeError(f"Unsupported weight with name={name}")
-            padded_weight = loaded_weight.new_zeros(dummy_shape)
-            loaded_weight = torch.cat(
-                [loaded_weight.unflatten(0, (-1, head_dim)), padded_weight], dim=0
-            ).flatten(0, 1)
-        if "attn.proj.weight" in name:
-            padded_weight = loaded_weight.new_zeros(
-                loaded_weight.shape[0], head_dim * num_dummy_heads
-            )
-            loaded_weight = torch.cat([loaded_weight, padded_weight], dim=-1)
-        if "attn.q_norm.weight" in name or "attn.k_norm.weight" in name:
-            padded_weight = loaded_weight.new_zeros(head_dim * num_dummy_heads)
-            loaded_weight = torch.cat([loaded_weight, padded_weight], dim=0)
-        return loaded_weight
-
     def _mapping_interns1_name(self, name):
         names_map = {
             "lm_head.weight": "language_model.lm_head.weight",
@@ -254,7 +217,7 @@ class InternS1ForConditionalGeneration(nn.Module):
         ]
         expert_params_mapping = []
         if "Qwen3MoeForCausalLM" in self.config.text_config.architectures:
-            expert_params_mapping =
+            expert_params_mapping = FusedMoE.make_expert_params_mapping(
                 ckpt_gate_proj_name="gate_proj",
                 ckpt_down_proj_name="down_proj",
                 ckpt_up_proj_name="up_proj",
@@ -269,7 +232,9 @@ class InternS1ForConditionalGeneration(nn.Module):
                 continue
             name = self._mapping_interns1_name(name)
            if "vision_model" in name:
-                loaded_weight =
+                loaded_weight = vision_utils.pad_vit_attn_dummy_heads(
+                    self.config, name, loaded_weight
+                )
 
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
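The helpers deleted above (`_update_hf_config`, `_pad_vit_attn_dummy_heads`) move into the new shared module sglang/srt/layers/attention/vision_utils.py (+65 in the file list), which interns1.py now calls instead. The arithmetic being centralized is visible in the removed code: when the ViT head count does not divide evenly across tensor-parallel ranks, zero-weight dummy heads pad it up to the next multiple of the world size. Below is a minimal standalone sketch of that rounding; the function name dummy_heads_needed and the example head counts are mine, not part of the codebase.

def dummy_heads_needed(num_heads: int, world_size: int) -> int:
    # Round the ViT head count up to the next multiple of the TP world size,
    # mirroring the formula in the removed _update_hf_config helper.
    if num_heads % world_size == 0:
        return 0
    return ((num_heads + world_size) // world_size) * world_size - num_heads

# e.g. a 25-head vision tower split across 4 ranks needs 3 dummy heads (25 -> 28)
assert dummy_heads_needed(25, 4) == 3
assert dummy_heads_needed(16, 8) == 0

Because the extra heads are zero-padded at weight-loading time (see the removed _pad_vit_attn_dummy_heads), the intent is to change only the sharding layout, not the computed output.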
sglang/srt/models/internvl.py
CHANGED
@@ -10,9 +10,9 @@ from transformers import PretrainedConfig, PreTrainedModel
 from transformers.activations import ACT2FN
 from transformers.modeling_outputs import BaseModelOutput, BaseModelOutputWithPooling
 
-from sglang.srt.
+from sglang.srt.layers.attention import vision_utils
 from sglang.srt.layers.attention.vision import SingletonCache, VisionAttention
-from sglang.srt.layers.moe.
+from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
 from sglang.srt.layers.quantization.base_config import QuantizationConfig
 from sglang.srt.managers.mm_utils import (
     MultiModalityDataPaddingPatternTokenPairs,
@@ -412,7 +412,7 @@ class InternVLChatModel(nn.Module):
         super().__init__()
         self.config = config
         self.quant_config = quant_config
-        self.
+        vision_utils.update_vit_attn_dummy_heads_config(self.config)
         image_size = config.force_image_size or config.vision_config.image_size
         patch_size = config.vision_config.patch_size
         self.patch_size = patch_size
@@ -462,21 +462,6 @@ class InternVLChatModel(nn.Module):
             nn.Linear(llm_hidden_size, llm_hidden_size),
         )
 
-    def _update_vision_config(self):
-        """update vision config to support tp"""
-        world_size = parallel_state.get_tensor_model_parallel_world_size()
-        num_heads = self.config.vision_config.num_attention_heads
-        head_dim = self.config.vision_config.hidden_size // num_heads
-        num_dummy_heads = 0
-
-        if num_heads % world_size != 0:
-            num_dummy_heads = (
-                (num_heads + world_size) // world_size
-            ) * world_size - num_heads
-
-        setattr(self.config.vision_config, "head_dim", head_dim)
-        setattr(self.config.vision_config, "num_dummy_heads", num_dummy_heads)
-
     def pixel_shuffle(self, x, scale_factor=0.5):
         n, w, h, c = x.size()
         # N, W, H, C --> N, W, H * scale, C // scale
@@ -559,36 +544,6 @@ class InternVLChatModel(nn.Module):
 
         return helper.pad_input_tokens(input_ids, mm_inputs)
 
-    def _pad_vit_attn_dummy_heads(self, name: str, loaded_weight: torch.Tensor):
-        """pad attn qkv weights for dummy heads"""
-        num_dummy_heads = self.config.vision_config.num_dummy_heads
-        if num_dummy_heads == 0:
-            return loaded_weight
-        head_dim = self.config.vision_config.head_dim
-
-        if "attn.qkv_proj" in name:
-            wq, wk, wv = loaded_weight.chunk(3, dim=0)
-            if name.endswith(".weight"):
-                dummy_shape = [num_dummy_heads, head_dim, wq.shape[-1]]
-            elif name.endswith(".bias"):
-                dummy_shape = [num_dummy_heads, head_dim]
-            else:
-                raise RuntimeError(f"Unsupported weight with name={name}")
-            pad_func = lambda x: torch.cat(
-                [x.unflatten(0, (-1, head_dim)), x.new_zeros(dummy_shape)], dim=0
-            ).flatten(0, 1)
-            wq, wk, wv = pad_func(wq), pad_func(wk), pad_func(wv)
-            loaded_weight = torch.cat([wq, wk, wv], dim=0)
-        if "attn.proj.weight" in name:
-            padded_weight = loaded_weight.new_zeros(
-                loaded_weight.shape[0], head_dim * num_dummy_heads
-            )
-            loaded_weight = torch.cat([loaded_weight, padded_weight], dim=-1)
-        if "attn.q_norm.weight" in name or "attn.k_norm.weight" in name:
-            padded_weight = loaded_weight.new_zeros(head_dim * num_dummy_heads)
-            loaded_weight = torch.cat([loaded_weight, padded_weight], dim=0)
-        return loaded_weight
-
     def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
         expert_params_mapping = []
         if "InternLM2ForCausalLM" in self.config.llm_config.architectures:
@@ -616,7 +571,7 @@ class InternVLChatModel(nn.Module):
             ("gate_up_proj", "up_proj", 1),
         ]
 
-        expert_params_mapping =
+        expert_params_mapping = FusedMoE.make_expert_params_mapping(
            ckpt_gate_proj_name="gate_proj",
            ckpt_down_proj_name="down_proj",
            ckpt_up_proj_name="up_proj",
@@ -699,8 +654,8 @@ class InternVLChatModel(nn.Module):
                     param, "weight_loader", default_weight_loader
                 )
                 if "vision_model" in name:
-                    loaded_weight =
-                        name, loaded_weight
+                    loaded_weight = vision_utils.pad_vit_attn_dummy_heads(
+                        self.config, name, loaded_weight
                     )
                 weight_loader(param, loaded_weight)
 
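internvl.py gets the same refactor as interns1.py above, with one difference worth noting: its removed _pad_vit_attn_dummy_heads handled the fused attn.qkv_proj tensor, chunking it into q/k/v, padding each chunk per head, and re-concatenating. Here is a self-contained sketch of that reshaping with made-up shapes; pad_heads and all dimensions are illustrative, and the real logic now lives in vision_utils.pad_vit_attn_dummy_heads.

import torch

num_heads, num_dummy_heads, head_dim, hidden = 6, 2, 8, 48
qkv_weight = torch.randn(3 * num_heads * head_dim, hidden)  # fused [q; k; v] projection

wq, wk, wv = qkv_weight.chunk(3, dim=0)

def pad_heads(x: torch.Tensor) -> torch.Tensor:
    # View as (heads, head_dim, in_features), append zero heads, flatten back.
    dummy = x.new_zeros(num_dummy_heads, head_dim, x.shape[-1])
    return torch.cat([x.unflatten(0, (-1, head_dim)), dummy], dim=0).flatten(0, 1)

padded = torch.cat([pad_heads(wq), pad_heads(wk), pad_heads(wv)], dim=0)
assert padded.shape == (3 * (num_heads + num_dummy_heads) * head_dim, hidden)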
sglang/srt/models/llama4.py
CHANGED
@@ -31,7 +31,6 @@ from sglang.srt.layers.communicator import LayerCommunicator, LayerScatterModes
 from sglang.srt.layers.dp_attention import (
     get_attention_tp_rank,
     get_attention_tp_size,
-    get_local_attention_dp_size,
     is_dp_attention_enabled,
 )
 from sglang.srt.layers.layernorm import RMSNorm
@@ -364,7 +363,6 @@ class Llama4DecoderLayer(nn.Module):
         rope_theta = config.rope_theta
         rope_scaling = config.rope_scaling
         max_position_embeddings = config.max_position_embeddings
-        self.local_dp_size = get_local_attention_dp_size()
         self.attn_tp_size = get_attention_tp_size()
         self.attn_tp_rank = get_attention_tp_rank()
 
sglang/srt/models/minicpm3.py
CHANGED
@@ -37,7 +37,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import add_prefix, is_cuda
sglang/srt/models/mixtral.py
CHANGED
@@ -47,7 +47,6 @@ from sglang.srt.layers.vocab_parallel_embedding import (
     ParallelLMHead,
     VocabParallelEmbedding,
 )
-from sglang.srt.managers.schedule_batch import global_server_args_dict
 from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
 from sglang.srt.model_loader.weight_utils import default_weight_loader
 from sglang.srt.utils import add_prefix, make_layers
@@ -104,7 +103,6 @@ class MixtralMoE(nn.Module):
             intermediate_size=intermediate_size,
             params_dtype=params_dtype,
             quant_config=quant_config,
-            tp_size=tp_size,
             prefix=add_prefix("experts", prefix),
         )
 
sglang/srt/models/nemotron_nas.py
ADDED
@@ -0,0 +1,435 @@
+# Copyright 2023-2025 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+# Adapted from https://github.com/vllm-project/vllm/blob/main/vllm/model_executor/models/nemotron_nas.py
+
+"""Inference-only deci model compatible with HuggingFace weights."""
+from typing import Iterable, Optional, Tuple, Type, Union
+
+import torch
+from torch import nn
+from transformers import LlamaConfig
+
+from sglang.srt.distributed import get_pp_group
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+from sglang.srt.layers.pooler import Pooler, PoolingType
+from sglang.srt.layers.quantization import QuantizationConfig
+from sglang.srt.layers.utils import PPMissingLayer
+from sglang.srt.layers.vocab_parallel_embedding import (
+    DEFAULT_VOCAB_PADDING_SIZE,
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    maybe_remap_kv_scale_name,
+)
+from sglang.srt.models.llama import LlamaAttention, LlamaMLP
+from sglang.srt.utils import add_prefix, make_layers
+from sglang.utils import logger
+
+
+def _ffn_mult_to_intermediate_size(ffn_mult: float, n_embd: int) -> int:
+    # DeciLM-specific code
+    intermediate_size = int(2 * ffn_mult * n_embd / 3)
+    return _find_multiple(intermediate_size, 256)
+
+
+def _find_multiple(n: int, k: int) -> int:
+    # DeciLM-specific code
+    if n % k == 0:
+        return n
+    return n + k - (n % k)
+
+
+class DeciLMDecoderLayer(nn.Module):
+
+    def __init__(
+        self,
+        config: LlamaConfig,
+        layer_idx: int,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        block_config = config.block_configs[layer_idx]
+        self._is_no_op_attention = block_config.attention.no_op
+        self._is_no_op_ffn = block_config.ffn.no_op
+
+        self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        if rope_scaling is not None and getattr(
+            config, "original_max_position_embeddings", None
+        ):
+            rope_scaling["original_max_position_embeddings"] = (
+                config.original_max_position_embeddings
+            )
+        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+        # Support abacusai/Smaug-72B-v0.1 with attention_bias
+        # Support internlm/internlm-7b with bias
+        rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
+        attention_bias = getattr(config, "attention_bias", False) or getattr(
+            config, "bias", False
+        )
+        # support internlm/internlm3-8b with qkv_bias
+        if hasattr(config, "qkv_bias"):
+            attention_bias = config.qkv_bias
+
+        if not self._is_no_op_attention:
+            num_kv_heads = (
+                config.num_attention_heads // block_config.attention.n_heads_in_group
+            )
+            self.self_attn = LlamaAttention(
+                config=config,
+                hidden_size=self.hidden_size,
+                num_heads=config.num_attention_heads,
+                num_kv_heads=num_kv_heads,
+                layer_id=layer_idx,
+                rope_theta=rope_theta,
+                rope_scaling=rope_scaling,
+                rope_is_neox_style=rope_is_neox_style,
+                max_position_embeddings=max_position_embeddings,
+                quant_config=quant_config,
+                prefix=add_prefix("self_attn", prefix),
+                bias=attention_bias,
+            )
+            self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+
+        if not self._is_no_op_ffn:
+            ffn_mult = block_config.ffn.ffn_mult
+            intermediate_size = _ffn_mult_to_intermediate_size(
+                ffn_mult, config.hidden_size
+            )
+            self.mlp = LlamaMLP(
+                hidden_size=self.hidden_size,
+                intermediate_size=intermediate_size,
+                hidden_act=config.hidden_act,
+                quant_config=quant_config,
+                prefix=add_prefix("mlp", prefix),
+            )
+            self.post_attention_layernorm = RMSNorm(
+                config.hidden_size, eps=config.rms_norm_eps
+            )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Self Attention
+
+        if self._is_no_op_attention:
+            pass
+        else:
+            if residual is None:
+                residual = hidden_states
+                hidden_states = self.input_layernorm(hidden_states)
+            else:
+                hidden_states, residual = self.input_layernorm(hidden_states, residual)
+            hidden_states = self.self_attn(
+                positions=positions,
+                hidden_states=hidden_states,
+                forward_batch=forward_batch,
+            )
+
+        # Fully Connected
+        if not self._is_no_op_ffn:
+            hidden_states, residual = self.post_attention_layernorm(
+                hidden_states, residual
+            )
+            hidden_states = self.mlp(hidden_states)
+        return hidden_states, residual
+
+
+class DeciModel(nn.Module):
+    def __init__(
+        self,
+        *,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        layer_type: Type[DeciLMDecoderLayer] = DeciLMDecoderLayer,
+    ):
+        super().__init__()
+
+        lora_config = None
+        self.config = config
+        self.quant_config = quant_config
+        self.padding_idx = config.pad_token_id
+        lora_vocab = (
+            (lora_config.lora_extra_vocab_size * (lora_config.max_loras or 1))
+            if lora_config
+            else 0
+        )
+        vocab_size = config.vocab_size + lora_vocab
+        if get_pp_group().is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                vocab_size,
+                config.hidden_size,
+                org_num_embeddings=config.vocab_size,
+                quant_config=quant_config,
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+
+        def get_layer(idx: int, prefix: str):
+            return layer_type(
+                config,
+                layer_idx=idx,
+                quant_config=quant_config,
+                prefix=prefix,
+            )
+
+        self.layers, self.start_layer, self.end_layer = make_layers(
+            config.num_hidden_layers,
+            get_layer,
+            pp_rank=get_pp_group().rank_in_group,
+            pp_size=get_pp_group().world_size,
+            prefix=add_prefix("layers", prefix),
+        )
+        if get_pp_group().is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer(return_tuple=True)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.embed_tokens(input_ids)
+
+    def forward(
+        self,
+        input_ids: Optional[torch.Tensor],
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[torch.Tensor, PPProxyTensors]:
+        if get_pp_group().is_first_rank:
+            if inputs_embeds is not None:
+                hidden_states = inputs_embeds
+            else:
+                hidden_states = self.get_input_embeddings(input_ids)
+            residual = None
+        else:
+            assert pp_proxy_tensors is not None
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]
+
+        kv_cache_index = 0
+        for i in range(self.start_layer, self.end_layer):
+            layer = self.layers[i]
+            if not layer._is_no_op_attention:
+                hidden_states, residual = layer(
+                    positions, hidden_states, forward_batch, residual
+                )
+                kv_cache_index += 1
+            else:
+                hidden_states, residual = layer(
+                    positions, hidden_states, forward_batch, residual
+                )
+
+        if not get_pp_group().is_last_rank:
+            return PPProxyTensors(
+                {"hidden_states": hidden_states, "residual": residual}
+            )
+
+        hidden_states, _ = self.norm(hidden_states, residual)
+        return hidden_states
+
+
+class DeciLMForCausalLM(nn.Module):
+    packed_modules_mapping = {
+        "qkv_proj": ["q_proj", "k_proj", "v_proj"],
+        "gate_up_proj": ["gate_proj", "up_proj"],
+    }
+
+    # LoRA specific attributes
+    supported_lora_modules = [
+        "qkv_proj",
+        "o_proj",
+        "gate_up_proj",
+        "down_proj",
+        "embed_tokens",
+        "lm_head",
+    ]
+    embedding_modules = {
+        "embed_tokens": "input_embeddings",
+        "lm_head": "output_embeddings",
+    }
+    embedding_padding_modules = ["lm_head"]
+
+    # Mistral/Llama models can also be loaded with --load-format mistral
+    # from consolidated.safetensors checkpoints
+    mistral_mapping = {
+        "layers": "model.layers",
+        "attention": "self_attn",
+        "wq": "q_proj",
+        "wk": "k_proj",
+        "wv": "v_proj",
+        "wo": "o_proj",
+        "attention_norm": "input_layernorm",
+        "feed_forward": "mlp",
+        "w1": "gate_proj",
+        "w2": "down_proj",
+        "w3": "up_proj",
+        "ffn_norm": "post_attention_layernorm",
+        "tok_embeddings": "model.embed_tokens",
+        "output": "lm_head",
+        "norm": "model.norm",
+    }
+
+    def __init__(
+        self,
+        *,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        super().__init__()
+        lora_config = None
+        self.config = config
+        self.lora_config = lora_config
+
+        self.model = self._init_model(
+            config=config, quant_config=quant_config, prefix=add_prefix("model", prefix)
+        )
+        if self.config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.unpadded_vocab_size = config.vocab_size
+            if lora_config:
+                self.unpadded_vocab_size += lora_config.lora_extra_vocab_size
+            self.lm_head = ParallelLMHead(
+                self.unpadded_vocab_size,
+                config.hidden_size,
+                org_num_embeddings=config.vocab_size,
+                padding_size=(
+                    DEFAULT_VOCAB_PADDING_SIZE
+                    # We need bigger padding if using lora for kernel
+                    # compatibility
+                    if not lora_config
+                    else lora_config.lora_vocab_padding_size
+                ),
+                quant_config=quant_config,
+                prefix=add_prefix("lm_head", prefix),
+            )
+        self.logits_processor = LogitsProcessor(config)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+
+    def _init_model(
+        self,
+        config: LlamaConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        return DeciModel(config=config, quant_config=quant_config, prefix=prefix)
+
+    def get_input_embeddings(self, input_ids: torch.Tensor) -> torch.Tensor:
+        return self.model.get_input_embeddings(input_ids)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        inputs_embeds: Optional[torch.Tensor] = None,
+        get_embedding: bool = False,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> LogitsProcessorOutput:
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            inputs_embeds,
+            pp_proxy_tensors=pp_proxy_tensors,
+        )
+        if get_pp_group().is_last_rank:
+            if not get_embedding:
+                return self.logits_processor(
+                    input_ids, hidden_states, self.lm_head, forward_batch
+                )
+            else:
+                return self.pooler(hidden_states, forward_batch)
+        else:
+            return hidden_states
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]) -> None:
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+            (".gate_up_proj", ".gate_proj", 0),
+            (".gate_up_proj", ".up_proj", 1),
+        ]
+
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in weights:
+            if "rotary_emb.inv_freq" in name:
+                continue
+            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                # Models trained using ColossalAI may include these tensors in
+                # the checkpoint. Skip them.
+                continue
+            if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                continue
+            if self.model.quant_config is not None and (
+                scale_name := self.model.quant_config.get_cache_scale(name)
+            ):
+                # Loading kv cache quantization scales
+                param = params_dict[scale_name]
+                weight_loader = getattr(param, "weight_loader", default_weight_loader)
+                loaded_weight = (
+                    loaded_weight if loaded_weight.dim() == 0 else loaded_weight[0]
+                )
+                weight_loader(param, loaded_weight)
+                continue
+            if "scale" in name:
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if name in params_dict.keys():
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+                else:
+                    logger.warning(f"Parameter {name} not found in params_dict")
+
+
+EntryClass = [DeciLMForCausalLM]