sglang 0.5.4__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl
This diff compares two publicly released versions of the package as published to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions exactly as they appear in the public registry.
- sglang/bench_one_batch.py +149 -34
- sglang/bench_serving.py +73 -14
- sglang/compile_deep_gemm.py +13 -7
- sglang/launch_server.py +2 -0
- sglang/srt/batch_invariant_ops/__init__.py +2 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +221 -4
- sglang/srt/checkpoint_engine/__init__.py +9 -0
- sglang/srt/checkpoint_engine/update.py +317 -0
- sglang/srt/compilation/backend.py +1 -1
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/deepseek_ocr.py +542 -10
- sglang/srt/configs/deepseekvl2.py +95 -194
- sglang/srt/configs/kimi_linear.py +160 -0
- sglang/srt/configs/mamba_utils.py +66 -0
- sglang/srt/configs/model_config.py +30 -7
- sglang/srt/constants.py +7 -0
- sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
- sglang/srt/disaggregation/decode.py +34 -6
- sglang/srt/disaggregation/nixl/conn.py +2 -2
- sglang/srt/disaggregation/prefill.py +25 -3
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
- sglang/srt/distributed/parallel_state.py +9 -12
- sglang/srt/entrypoints/engine.py +31 -20
- sglang/srt/entrypoints/grpc_server.py +0 -1
- sglang/srt/entrypoints/http_server.py +94 -94
- sglang/srt/entrypoints/openai/protocol.py +7 -1
- sglang/srt/entrypoints/openai/serving_chat.py +42 -0
- sglang/srt/entrypoints/openai/serving_completions.py +10 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/environ.py +23 -2
- sglang/srt/eplb/expert_distribution.py +64 -1
- sglang/srt/eplb/expert_location.py +106 -36
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/minimax_m2.py +367 -0
- sglang/srt/grpc/compile_proto.py +3 -0
- sglang/srt/layers/activation.py +6 -0
- sglang/srt/layers/attention/ascend_backend.py +233 -5
- sglang/srt/layers/attention/attention_registry.py +3 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
- sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
- sglang/srt/layers/attention/fla/kda.py +1359 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
- sglang/srt/layers/attention/flashattention_backend.py +19 -8
- sglang/srt/layers/attention/flashinfer_backend.py +10 -1
- sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -11
- sglang/srt/layers/attention/flashmla_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
- sglang/srt/layers/attention/mamba/mamba.py +20 -11
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
- sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
- sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
- sglang/srt/layers/attention/nsa/transform_index.py +1 -1
- sglang/srt/layers/attention/nsa_backend.py +157 -23
- sglang/srt/layers/attention/triton_backend.py +4 -1
- sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
- sglang/srt/layers/attention/trtllm_mla_backend.py +11 -15
- sglang/srt/layers/attention/utils.py +78 -0
- sglang/srt/layers/communicator.py +24 -1
- sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
- sglang/srt/layers/layernorm.py +35 -6
- sglang/srt/layers/logits_processor.py +9 -20
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
- sglang/srt/layers/moe/ep_moe/layer.py +78 -289
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
- sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +340 -55
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
- sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
- sglang/srt/layers/moe/token_dispatcher/deepep.py +25 -18
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +35 -10
- sglang/srt/layers/moe/utils.py +3 -4
- sglang/srt/layers/pooler.py +21 -2
- sglang/srt/layers/quantization/__init__.py +13 -84
- sglang/srt/layers/quantization/auto_round.py +394 -0
- sglang/srt/layers/quantization/awq.py +0 -3
- sglang/srt/layers/quantization/base_config.py +7 -0
- sglang/srt/layers/quantization/fp8.py +68 -63
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gguf.py +566 -0
- sglang/srt/layers/quantization/modelopt_quant.py +168 -11
- sglang/srt/layers/quantization/mxfp4.py +30 -38
- sglang/srt/layers/quantization/unquant.py +23 -45
- sglang/srt/layers/quantization/w4afp8.py +38 -2
- sglang/srt/layers/radix_attention.py +5 -2
- sglang/srt/layers/rotary_embedding.py +130 -46
- sglang/srt/layers/sampler.py +12 -1
- sglang/srt/lora/lora_registry.py +9 -0
- sglang/srt/managers/async_mm_data_processor.py +122 -0
- sglang/srt/managers/data_parallel_controller.py +30 -3
- sglang/srt/managers/detokenizer_manager.py +3 -0
- sglang/srt/managers/io_struct.py +29 -4
- sglang/srt/managers/multi_tokenizer_mixin.py +22 -1
- sglang/srt/managers/schedule_batch.py +74 -15
- sglang/srt/managers/scheduler.py +185 -144
- sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
- sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
- sglang/srt/managers/scheduler_pp_mixin.py +7 -2
- sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
- sglang/srt/managers/session_controller.py +6 -5
- sglang/srt/managers/tokenizer_manager.py +165 -78
- sglang/srt/managers/tp_worker.py +24 -1
- sglang/srt/mem_cache/base_prefix_cache.py +23 -4
- sglang/srt/mem_cache/common.py +1 -0
- sglang/srt/mem_cache/hicache_storage.py +7 -1
- sglang/srt/mem_cache/memory_pool.py +253 -57
- sglang/srt/mem_cache/memory_pool_host.py +12 -5
- sglang/srt/mem_cache/radix_cache.py +4 -0
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
- sglang/srt/metrics/collector.py +46 -3
- sglang/srt/model_executor/cuda_graph_runner.py +15 -3
- sglang/srt/model_executor/forward_batch_info.py +55 -14
- sglang/srt/model_executor/model_runner.py +77 -170
- sglang/srt/model_executor/npu_graph_runner.py +7 -3
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
- sglang/srt/model_loader/weight_utils.py +1 -1
- sglang/srt/models/bailing_moe.py +9 -2
- sglang/srt/models/deepseek_nextn.py +11 -2
- sglang/srt/models/deepseek_v2.py +296 -78
- sglang/srt/models/glm4.py +391 -77
- sglang/srt/models/glm4_moe.py +322 -354
- sglang/srt/models/glm4_moe_nextn.py +4 -14
- sglang/srt/models/glm4v.py +196 -55
- sglang/srt/models/glm4v_moe.py +29 -197
- sglang/srt/models/gpt_oss.py +1 -10
- sglang/srt/models/kimi_linear.py +678 -0
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/llama_eagle3.py +11 -1
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/minimax_m2.py +922 -0
- sglang/srt/models/nvila.py +355 -0
- sglang/srt/models/nvila_lite.py +184 -0
- sglang/srt/models/qwen2.py +23 -2
- sglang/srt/models/qwen2_moe.py +30 -15
- sglang/srt/models/qwen3.py +35 -5
- sglang/srt/models/qwen3_moe.py +18 -12
- sglang/srt/models/qwen3_next.py +7 -0
- sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
- sglang/srt/multimodal/processors/base_processor.py +1 -0
- sglang/srt/multimodal/processors/glm4v.py +1 -1
- sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
- sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
- sglang/srt/multiplex/multiplexing_mixin.py +209 -0
- sglang/srt/multiplex/pdmux_context.py +164 -0
- sglang/srt/parser/conversation.py +7 -1
- sglang/srt/parser/reasoning_parser.py +28 -1
- sglang/srt/sampling/custom_logit_processor.py +67 -1
- sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
- sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
- sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
- sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
- sglang/srt/server_args.py +459 -199
- sglang/srt/single_batch_overlap.py +2 -4
- sglang/srt/speculative/draft_utils.py +16 -0
- sglang/srt/speculative/eagle_info.py +42 -36
- sglang/srt/speculative/eagle_info_v2.py +68 -25
- sglang/srt/speculative/eagle_utils.py +261 -16
- sglang/srt/speculative/eagle_worker.py +11 -3
- sglang/srt/speculative/eagle_worker_v2.py +15 -9
- sglang/srt/speculative/spec_info.py +305 -31
- sglang/srt/speculative/spec_utils.py +44 -8
- sglang/srt/tracing/trace.py +121 -12
- sglang/srt/utils/common.py +142 -74
- sglang/srt/utils/hf_transformers_utils.py +38 -12
- sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
- sglang/test/kits/radix_cache_server_kit.py +50 -0
- sglang/test/runners.py +31 -7
- sglang/test/simple_eval_common.py +5 -3
- sglang/test/simple_eval_humaneval.py +1 -0
- sglang/test/simple_eval_math.py +1 -0
- sglang/test/simple_eval_mmlu.py +1 -0
- sglang/test/simple_eval_mmmu_vlm.py +1 -0
- sglang/test/test_deterministic.py +235 -12
- sglang/test/test_deterministic_utils.py +2 -1
- sglang/test/test_utils.py +7 -1
- sglang/version.py +1 -1
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +15 -28
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +194 -175
- sglang/srt/models/vila.py +0 -306
- /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0
sglang/srt/models/kimi_linear.py
ADDED
@@ -0,0 +1,678 @@
# Adapted from: https://github.com/vllm-project/vllm/blob/0384aa7150c4c9778efca041ffd1beb3ad2bd694/vllm/model_executor/models/kimi_linear.py

from collections.abc import Iterable
from typing import Optional

import torch
from einops import rearrange
from torch import nn

from sglang.srt.configs.kimi_linear import KimiLinearConfig
from sglang.srt.distributed import (
    divide,
    get_pp_group,
    get_tensor_model_parallel_world_size,
    tensor_model_parallel_all_reduce,
)
from sglang.srt.eplb.expert_distribution import get_global_expert_distribution_recorder
from sglang.srt.layers.attention.fla.kda import FusedRMSNormGated
from sglang.srt.layers.layernorm import RMSNorm
from sglang.srt.layers.linear import (
    ColumnParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor
from sglang.srt.layers.moe.ep_moe.layer import get_moe_impl_class
from sglang.srt.layers.moe.fused_moe_triton.layer import FusedMoE
from sglang.srt.layers.moe.topk import TopK, TopKOutputFormat
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.utils import PPMissingLayer
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import (
    default_weight_loader,
    maybe_remap_kv_scale_name,
    sharded_weight_loader,
)
from sglang.srt.models.deepseek_v2 import DeepseekV2AttentionMLA as KimiMLAAttention
from sglang.srt.models.llama import LlamaMLP as KimiMLP
from sglang.srt.models.transformers import maybe_prefix
from sglang.srt.utils import make_layers
from sglang.srt.utils.common import BumpAllocator, add_prefix, set_weight_attrs


class KimiMoE(nn.Module):
    def __init__(
        self,
        config: KimiLinearConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
        layer_idx: int = 0,
    ):
        super().__init__()
        hidden_size = config.hidden_size
        intermediate_size = config.intermediate_size
        moe_intermediate_size = config.moe_intermediate_size
        num_experts = config.num_experts
        moe_renormalize = config.moe_renormalize
        self.tp_size = get_tensor_model_parallel_world_size()
        self.routed_scaling_factor = config.routed_scaling_factor
        self.num_shared_experts = config.num_shared_experts
        self.layer_idx = layer_idx

        if config.hidden_act != "silu":
            raise ValueError(
                f"Unsupported activation: {config.hidden_act}. "
                "Only silu is supported for now."
            )

        # Gate always runs at half / full precision for now.
        self.gate = ReplicatedLinear(
            hidden_size,
            num_experts,
            bias=False,
            quant_config=None,
            prefix=f"{prefix}.gate",
        )

        self.gate.e_score_correction_bias = nn.Parameter(torch.empty(num_experts))

        self.experts = get_moe_impl_class(quant_config)(
            num_experts=config.n_routed_experts,
            top_k=config.num_experts_per_token,
            hidden_size=config.hidden_size,
            intermediate_size=config.moe_intermediate_size,
            layer_id=self.layer_idx,
            quant_config=quant_config,
            routed_scaling_factor=self.routed_scaling_factor,
            prefix=add_prefix("experts", prefix),
        )

        self.topk = TopK(
            top_k=config.num_experts_per_token,
            renormalize=moe_renormalize,
            use_grouped_topk=True,
            num_expert_group=config.num_expert_group,
            topk_group=config.topk_group,
            correction_bias=self.gate.e_score_correction_bias,
            quant_config=quant_config,
            routed_scaling_factor=self.routed_scaling_factor,
            apply_routed_scaling_factor_on_output=self.experts.should_fuse_routed_scaling_factor_in_topk,
            # Some Fp4 MoE backends require the output format to be bypassed but the MTP layers are unquantized
            # and requires the output format to be standard. We use quant_config to determine the output format.
            output_format=TopKOutputFormat.STANDARD if quant_config is None else None,
        )

        if self.num_shared_experts is not None:
            intermediate_size = moe_intermediate_size * self.num_shared_experts
            self.shared_experts = KimiMLP(
                hidden_size=config.hidden_size,
                intermediate_size=intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                reduce_results=False,
            )

    def forward(self, hidden_states: torch.Tensor) -> torch.Tensor:
        num_tokens, hidden_size = hidden_states.shape
        hidden_states = hidden_states.view(-1, hidden_size)
        if self.num_shared_experts is not None:
            shared_output = self.shared_experts(hidden_states)
        router_logits, _ = self.gate(hidden_states)
        topk_output = self.topk(hidden_states, router_logits)
        final_hidden_states = self.experts(hidden_states, topk_output)

        if shared_output is not None:
            final_hidden_states = final_hidden_states + shared_output

        if self.tp_size > 1:
            final_hidden_states = tensor_model_parallel_all_reduce(final_hidden_states)
        return final_hidden_states.view(num_tokens, hidden_size)


class KimiDeltaAttention(nn.Module):
    def __init__(
        self,
        layer_idx: int,
        hidden_size: int,
        config: KimiLinearConfig,
        quant_config: Optional[QuantizationConfig] = None,
        rms_norm_eps: float = 1e-5,
        prefix: str = "",
        **kwargs,
    ) -> None:
        super().__init__()
        self.tp_size = get_tensor_model_parallel_world_size()
        self.hidden_size = hidden_size
        self.config = config
        self.head_dim = config.linear_attn_config["head_dim"]
        self.num_heads = config.linear_attn_config["num_heads"]
        self.layer_idx = layer_idx
        self.prefix = prefix
        assert self.num_heads % self.tp_size == 0
        self.local_num_heads = divide(self.num_heads, self.tp_size)

        projection_size = self.head_dim * self.num_heads
        self.conv_size = config.linear_attn_config["short_conv_kernel_size"]

        self.q_proj = ColumnParallelLinear(
            self.hidden_size,
            projection_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.q_proj",
        )
        self.k_proj = ColumnParallelLinear(
            self.hidden_size,
            projection_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.k_proj",
        )
        self.v_proj = ColumnParallelLinear(
            self.hidden_size,
            projection_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.v_proj",
        )

        self.f_a_proj = ReplicatedLinear(
            self.hidden_size,
            self.head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.f_a_proj",
        )

        self.f_b_proj = ColumnParallelLinear(
            self.head_dim,
            projection_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.f_b_proj",
        )
        self.dt_bias = nn.Parameter(
            torch.empty(divide(projection_size, self.tp_size), dtype=torch.float32)
        )

        set_weight_attrs(self.dt_bias, {"weight_loader": sharded_weight_loader(0)})

        self.b_proj = ColumnParallelLinear(
            self.hidden_size,
            self.num_heads,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.b_proj",
        )

        self.q_conv1d = ColumnParallelLinear(
            input_size=self.conv_size,
            output_size=projection_size,
            bias=False,
            params_dtype=torch.float32,
            prefix=f"{prefix}.q_conv1d",
        )
        self.k_conv1d = ColumnParallelLinear(
            input_size=self.conv_size,
            output_size=projection_size,
            bias=False,
            params_dtype=torch.float32,
            prefix=f"{prefix}.k_conv1d",
        )
        self.v_conv1d = ColumnParallelLinear(
            input_size=self.conv_size,
            output_size=projection_size,
            bias=False,
            params_dtype=torch.float32,
            prefix=f"{prefix}.v_conv1d",
        )
        # unsqueeze to fit conv1d weights shape into the linear weights shape.
        # Can't do this in `weight_loader` since it already exists in
        # `ColumnParallelLinear` and `set_weight_attrs`
        # doesn't allow to override it
        self.q_conv1d.weight.data = self.q_conv1d.weight.data.unsqueeze(1)
        self.k_conv1d.weight.data = self.k_conv1d.weight.data.unsqueeze(1)
        self.v_conv1d.weight.data = self.v_conv1d.weight.data.unsqueeze(1)

        self.A_log = nn.Parameter(
            torch.empty(1, 1, self.local_num_heads, 1, dtype=torch.float32)
        )
        set_weight_attrs(self.A_log, {"weight_loader": sharded_weight_loader(2)})

        self.g_a_proj = ReplicatedLinear(
            self.hidden_size,
            self.head_dim,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.g_a_proj",
        )
        self.g_b_proj = ColumnParallelLinear(
            self.head_dim,
            projection_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.g_b_proj",
        )
        self.o_norm = FusedRMSNormGated(
            self.head_dim, eps=rms_norm_eps, activation="sigmoid"
        )
        self.o_proj = RowParallelLinear(
            projection_size,
            self.hidden_size,
            bias=False,
            quant_config=quant_config,
            prefix=f"{prefix}.o_proj",
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        zero_allocator: BumpAllocator,
    ) -> None:
        q_proj_states = self.q_proj(hidden_states)[0]
        k_proj_states = self.k_proj(hidden_states)[0]
        v_proj_states = self.v_proj(hidden_states)[0]

        q_conv_weights = self.q_conv1d.weight.view(
            self.q_conv1d.weight.size(0), self.q_conv1d.weight.size(2)
        )
        k_conv_weights = self.k_conv1d.weight.view(
            self.k_conv1d.weight.size(0), self.k_conv1d.weight.size(2)
        )
        v_conv_weights = self.v_conv1d.weight.view(
            self.v_conv1d.weight.size(0), self.v_conv1d.weight.size(2)
        )

        kwargs = {
            "q_proj_states": q_proj_states,
            "k_proj_states": k_proj_states,
            "v_proj_states": v_proj_states,
            "q_conv_weights": q_conv_weights,
            "k_conv_weights": k_conv_weights,
            "v_conv_weights": v_conv_weights,
            "q_conv_bias": self.q_conv1d.bias,
            "k_conv_bias": self.k_conv1d.bias,
            "v_conv_bias": self.v_conv1d.bias,
            "dt_bias": self.dt_bias,
            "b_proj": self.b_proj,
            "f_a_proj": self.f_a_proj,
            "f_b_proj": self.f_b_proj,
            "A_log": self.A_log,
            "head_dim": self.head_dim,
            "hidden_states": hidden_states,
            "layer_id": self.layer_idx,
        }

        core_attn_out = forward_batch.attn_backend.forward(
            q=None,
            k=None,
            v=None,
            layer=None,
            forward_batch=forward_batch,
            **kwargs,
        )

        g_proj_states = self.g_b_proj(self.g_a_proj(hidden_states)[0])[0]
        g = rearrange(g_proj_states, "... (h d) -> ... h d", d=self.head_dim)
        core_attn_out = self.o_norm(core_attn_out, g)
        core_attn_out = rearrange(core_attn_out, "1 n h d -> n (h d)")

        return self.o_proj(core_attn_out)[0]


class KimiDecoderLayer(nn.Module):
    def __init__(
        self,
        config: KimiLinearConfig,
        layer_idx: int,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.hidden_size = config.hidden_size

        self.is_moe = config.is_moe

        if config.is_kda_layer(layer_idx):
            self.self_attn = KimiDeltaAttention(
                layer_idx=layer_idx,
                hidden_size=config.hidden_size,
                config=config,
                quant_config=quant_config,
                prefix=f"{prefix}.self_attn",
            )
        else:
            self.self_attn = KimiMLAAttention(
                layer_id=layer_idx,
                hidden_size=self.hidden_size,
                num_heads=config.num_attention_heads,
                quant_config=quant_config,
                prefix=f"{prefix}.self_attn",
                config=config,
                qk_nope_head_dim=config.qk_nope_head_dim,
                qk_rope_head_dim=config.qk_rope_head_dim,
                v_head_dim=config.v_head_dim,
                q_lora_rank=config.q_lora_rank,
                kv_lora_rank=config.kv_lora_rank,
                skip_rope=True,
            )

        if (
            self.is_moe
            and config.num_experts is not None
            and layer_idx >= config.first_k_dense_replace
            and layer_idx % config.moe_layer_freq == 0
        ):
            self.block_sparse_moe = KimiMoE(
                config=config,
                quant_config=quant_config,
                layer_idx=layer_idx,
                prefix=f"{prefix}.mlp",
            )
            self.mlp = self.block_sparse_moe
        else:
            self.mlp = KimiMLP(
                hidden_size=self.hidden_size,
                intermediate_size=config.intermediate_size,
                hidden_act=config.hidden_act,
                quant_config=quant_config,
                prefix=f"{prefix}.mlp",
            )
        self.input_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        self.post_attention_layernorm = RMSNorm(
            config.hidden_size, eps=config.rms_norm_eps
        )

    def forward(
        self,
        positions: torch.Tensor,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
        residual: Optional[torch.Tensor],
        zero_allocator: BumpAllocator,
    ) -> tuple[torch.Tensor, torch.Tensor]:
        # Self Attention
        if residual is None:
            residual = hidden_states
            hidden_states = self.input_layernorm(hidden_states)
        else:
            hidden_states, residual = self.input_layernorm(hidden_states, residual)

        hidden_states = self.self_attn(
            hidden_states=hidden_states,
            positions=positions,
            forward_batch=forward_batch,
            zero_allocator=zero_allocator,
        )

        # Fully Connected
        hidden_states, residual = self.post_attention_layernorm(hidden_states, residual)
        hidden_states = self.mlp(hidden_states)
        return hidden_states, residual


class KimiLinearModel(nn.Module):
    def __init__(
        self,
        config: KimiLinearConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()

        self.config = config

        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.pp_group = get_pp_group()

        if self.pp_group.is_first_rank:
            self.embed_tokens = VocabParallelEmbedding(
                config.vocab_size,
                config.hidden_size,
                prefix=f"{prefix}.embed_tokens",
            )
        else:
            self.embed_tokens = PPMissingLayer()

        self.layers, self.start_layer, self.end_layer = make_layers(
            config.num_hidden_layers,
            lambda idx, prefix: KimiDecoderLayer(
                layer_idx=idx,
                config=config,
                quant_config=quant_config,
                prefix=prefix,
            ),
            pp_rank=self.pp_group.rank_in_group,
            pp_size=self.pp_group.world_size,
            prefix=f"{prefix}.layers",
        )

        if self.pp_group.is_last_rank:
            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
        else:
            self.norm = PPMissingLayer()

        world_size = get_tensor_model_parallel_world_size()
        assert (
            config.num_attention_heads % world_size == 0
        ), "num_attention_heads must be divisible by world_size"

    def forward(
        self,
        input_ids: torch.Tensor | None,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        inputs_embeds: torch.Tensor | None = None,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> torch.Tensor:
        if get_pp_group().is_first_rank:
            if inputs_embeds is not None:
                hidden_states = inputs_embeds
            else:
                hidden_states = self.embed_tokens(input_ids)
            residual = None
        else:
            assert pp_proxy_tensors is not None
            hidden_states = pp_proxy_tensors["hidden_states"]
            residual = pp_proxy_tensors["residual"]

        total_num_layers = self.end_layer - self.start_layer
        device = hidden_states.device
        zero_allocator = BumpAllocator(
            buffer_size=total_num_layers * 2,
            dtype=torch.float32,
            device=device,
        )
        # TODO: capture aux hidden states
        aux_hidden_states = []
        for i in range(self.start_layer, self.end_layer):
            ctx = get_global_expert_distribution_recorder().with_current_layer(i)
            with ctx:
                layer = self.layers[i]
                hidden_states, residual = layer(
                    positions=positions,
                    hidden_states=hidden_states,
                    forward_batch=forward_batch,
                    residual=residual,
                    zero_allocator=zero_allocator,
                )

        if not self.pp_group.is_last_rank:
            return PPProxyTensors(
                {
                    "hidden_states": hidden_states,
                    "residual": residual,
                }
            )
        else:
            if hidden_states.shape[0] != 0:
                if residual is None:
                    hidden_states = self.norm(hidden_states)
                else:
                    hidden_states, _ = self.norm(hidden_states, residual)

        if len(aux_hidden_states) == 0:
            return hidden_states

        return hidden_states, aux_hidden_states


class KimiLinearForCausalLM(nn.Module):
    def __init__(
        self,
        config: KimiLinearConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.config = config
        self.quant_config = quant_config
        self.model = KimiLinearModel(
            config, quant_config, prefix=maybe_prefix(prefix, "model")
        )
        self.pp_group = get_pp_group()
        if self.pp_group.is_last_rank:
            self.lm_head = ParallelLMHead(
                self.config.vocab_size,
                self.config.hidden_size,
                quant_config=quant_config,
                prefix=maybe_prefix(prefix, "lm_head"),
            )
        else:
            self.lm_head = PPMissingLayer()
        logit_scale = getattr(self.config, "logit_scale", 1.0)
        self.logits_processor = LogitsProcessor(config=config, logit_scale=logit_scale)

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        inputs_embeds: Optional[torch.Tensor] = None,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
    ) -> torch.Tensor:
        hidden_states = self.model(
            input_ids,
            positions,
            forward_batch,
            inputs_embeds,
            pp_proxy_tensors,
        )
        if self.pp_group.is_last_rank:
            return self.logits_processor(
                input_ids, hidden_states, self.lm_head, forward_batch
            )
        else:
            return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]):
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".gate_up_proj", ".gate_proj", 0),
            (".gate_up_proj", ".up_proj", 1),
        ]
        if self.config.is_moe:
            # Params for weights, fp8 weight scales, fp8 activation scales
            # (param_name, weight_name, expert_id, shard_id)
            expert_params_mapping = FusedMoE.make_expert_params_mapping(
                ckpt_gate_proj_name="w1",
                ckpt_down_proj_name="w2",
                ckpt_up_proj_name="w3",
                num_experts=self.config.num_experts,
            )
        else:
            expert_params_mapping = []
        params_dict = dict(self.named_parameters())
        loaded_params: set[str] = set()
        for args in weights:
            name, loaded_weight = args[:2]
            kwargs = args[2] if len(args) > 2 else {}
            if "rotary_emb.inv_freq" in name:
                continue

            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
                # Models trained using ColossalAI may include these tensors in
                # the checkpoint. Skip them.
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                # We have mlp.experts[0].gate_proj in the checkpoint.
                # Since we handle the experts below in expert_params_mapping,
                # we need to skip here BEFORE we update the name, otherwise
                # name will be updated to mlp.experts[0].gate_up_proj, which
                # will then be updated below in expert_params_mapping
                # for mlp.experts[0].gate_gate_up_proj, which breaks load.
                if ("mlp.experts." in name) and name not in params_dict:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # if is_pp_missing_parameter(name, self):
                #     continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                for idx, (param_name, weight_name, expert_id, shard_id) in enumerate(
                    expert_params_mapping
                ):
                    if weight_name not in name:
                        continue
                    name = name.replace(weight_name, param_name)
                    # if is_pp_missing_parameter(name, self):
                    #     continue
                    param = params_dict[name]
                    weight_loader = param.weight_loader
                    weight_loader(
                        param,
                        loaded_weight,
                        name,
                        expert_id=expert_id,
                        shard_id=shard_id,
                    )
                    break
                else:
                    # Skip loading extra bias for GPTQ models.
                    if (
                        name.endswith(".bias")
                        and name not in params_dict
                        and not self.config.is_linear_attn
                    ):  # noqa: E501
                        continue
                    # Remapping the name of FP8 kv-scale.
                    name = maybe_remap_kv_scale_name(name, params_dict)
                    if name is None:
                        continue
                    # if is_pp_missing_parameter(name, self):
                    #     continue

                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight, **kwargs)
            loaded_params.add(name)

        for layer_id in self.config.full_attention_layer_ids:
            self_attn = self.model.layers[layer_id].self_attn
            w_kc, w_vc = self_attn.kv_b_proj.weight.unflatten(
                0, (-1, self_attn.qk_nope_head_dim + self_attn.v_head_dim)
            ).split([self_attn.qk_nope_head_dim, self_attn.v_head_dim], dim=1)
            self_attn.w_kc = w_kc.transpose(1, 2).contiguous().transpose(1, 2)
            self_attn.w_vc = w_vc.contiguous().transpose(1, 2)
            if hasattr(self_attn.kv_b_proj, "weight_scale"):
                self_attn.w_scale = self_attn.kv_b_proj.weight_scale


EntryClass = KimiLinearForCausalLM
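The new module exposes `KimiLinearForCausalLM` through `EntryClass`, so sglang's model loader resolves it like any other architecture. A minimal sketch of serving such a checkpoint with the offline engine API; the model path and `tp_size` value are assumptions for illustration, not part of this diff:

# Hypothetical usage sketch; substitute a real Kimi-Linear checkpoint path.
import sglang as sgl

if __name__ == "__main__":
    llm = sgl.Engine(model_path="moonshotai/Kimi-Linear-48B-A3B-Instruct", tp_size=4)
    outputs = llm.generate(
        ["The capital of France is"],
        {"temperature": 0.0, "max_new_tokens": 16},
    )
    print(outputs[0]["text"])
    llm.shutdown()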
sglang/srt/models/llama4.py
CHANGED
@@ -148,7 +148,7 @@ class Llama4MoE(nn.Module):
         return out_aD

     def _forward_core(self, hidden_states, forward_mode: ForwardMode):
-        if
+        if _is_cuda:
             return self._forward_core_shared_routed_overlap(hidden_states)
         else:
             return self._forward_core_normal(hidden_states)