sglang 0.5.2rc2__py3-none-any.whl → 0.5.3rc0__py3-none-any.whl
This diff compares two publicly released versions of the package as published to their public registry. It is provided for informational purposes only and reflects the package contents exactly as they appear in that registry.
- sglang/bench_one_batch_server.py +10 -1
- sglang/bench_serving.py +257 -29
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/device_config.py +3 -1
- sglang/srt/configs/dots_vlm.py +139 -0
- sglang/srt/configs/load_config.py +1 -0
- sglang/srt/configs/model_config.py +50 -6
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +8 -1
- sglang/srt/connector/remote_instance.py +82 -0
- sglang/srt/constrained/base_grammar_backend.py +48 -12
- sglang/srt/constrained/llguidance_backend.py +0 -1
- sglang/srt/constrained/outlines_backend.py +0 -1
- sglang/srt/constrained/xgrammar_backend.py +28 -9
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +21 -10
- sglang/srt/disaggregation/decode_schedule_batch_mixin.py +4 -1
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -445
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +5 -3
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +24 -3
- sglang/srt/entrypoints/engine.py +38 -17
- sglang/srt/entrypoints/grpc_request_manager.py +580 -0
- sglang/srt/entrypoints/grpc_server.py +680 -0
- sglang/srt/entrypoints/http_server.py +85 -54
- sglang/srt/entrypoints/openai/protocol.py +4 -1
- sglang/srt/entrypoints/openai/serving_base.py +46 -3
- sglang/srt/entrypoints/openai/serving_chat.py +36 -16
- sglang/srt/entrypoints/openai/serving_completions.py +12 -3
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -3
- sglang/srt/entrypoints/openai/serving_rerank.py +3 -1
- sglang/srt/entrypoints/openai/serving_responses.py +6 -3
- sglang/srt/entrypoints/openai/serving_score.py +1 -0
- sglang/srt/eplb/eplb_manager.py +2 -2
- sglang/srt/eplb/expert_distribution.py +26 -13
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/base_format_detector.py +3 -6
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +6 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/grpc/__init__.py +1 -0
- sglang/srt/grpc/sglang_scheduler_pb2.py +106 -0
- sglang/srt/grpc/sglang_scheduler_pb2.pyi +427 -0
- sglang/srt/grpc/sglang_scheduler_pb2_grpc.py +236 -0
- sglang/srt/hf_transformers_utils.py +4 -0
- sglang/srt/layers/activation.py +142 -9
- sglang/srt/layers/attention/ascend_backend.py +11 -4
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashinfer_backend.py +6 -4
- sglang/srt/layers/attention/flashinfer_mla_backend.py +16 -12
- sglang/srt/layers/attention/hybrid_attn_backend.py +57 -50
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +602 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/triton_backend.py +18 -1
- sglang/srt/layers/attention/trtllm_mla_backend.py +124 -31
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/dp_attention.py +30 -1
- sglang/srt/layers/layernorm.py +32 -15
- sglang/srt/layers/linear.py +34 -3
- sglang/srt/layers/logits_processor.py +29 -10
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +3 -3
- sglang/srt/layers/moe/ep_moe/kernels.py +1 -1
- sglang/srt/layers/moe/ep_moe/layer.py +182 -62
- sglang/srt/layers/moe/flashinfer_cutedsl_moe.py +156 -0
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/{triton_3_4_0/E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json } +35 -35
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=128,N=352,device_name=NVIDIA_RTX_5880_Ada_Generation,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=256,N=512,device_name=NVIDIA_H20.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H200.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +5 -2
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +1 -1
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +23 -20
- sglang/srt/layers/moe/fused_moe_triton/layer.py +61 -59
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +43 -39
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +30 -9
- sglang/srt/layers/moe/utils.py +12 -6
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +27 -0
- sglang/srt/layers/quantization/fp8.py +76 -47
- sglang/srt/layers/quantization/fp8_utils.py +50 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +147 -47
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +64 -40
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/w4afp8.py +30 -17
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +76 -38
- sglang/srt/layers/sampler.py +162 -18
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/async_dynamic_batch_tokenizer.py +170 -0
- sglang/srt/managers/cache_controller.py +158 -160
- sglang/srt/managers/data_parallel_controller.py +105 -35
- sglang/srt/managers/detokenizer_manager.py +8 -4
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +199 -12
- sglang/srt/managers/mm_utils.py +1 -0
- sglang/srt/managers/multi_tokenizer_mixin.py +350 -400
- sglang/srt/managers/schedule_batch.py +77 -56
- sglang/srt/managers/schedule_policy.py +1 -1
- sglang/srt/managers/scheduler.py +187 -39
- sglang/srt/managers/scheduler_metrics_mixin.py +4 -3
- sglang/srt/managers/scheduler_output_processor_mixin.py +55 -11
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/tokenizer_communicator_mixin.py +569 -0
- sglang/srt/managers/tokenizer_manager.py +259 -519
- sglang/srt/managers/tp_worker.py +53 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +42 -19
- sglang/srt/mem_cache/hicache_storage.py +3 -23
- sglang/srt/mem_cache/hiradix_cache.py +103 -43
- sglang/srt/mem_cache/memory_pool.py +347 -48
- sglang/srt/mem_cache/memory_pool_host.py +105 -46
- sglang/srt/mem_cache/radix_cache.py +0 -2
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +86 -4
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +49 -7
- sglang/srt/mem_cache/swa_radix_cache.py +0 -2
- sglang/srt/metrics/collector.py +493 -76
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +59 -2
- sglang/srt/model_executor/model_runner.py +356 -29
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +128 -4
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/apertus.py +686 -0
- sglang/srt/models/bailing_moe.py +798 -218
- sglang/srt/models/bailing_moe_nextn.py +168 -0
- sglang/srt/models/deepseek_v2.py +109 -15
- sglang/srt/models/dots_vlm.py +174 -0
- sglang/srt/models/dots_vlm_vit.py +337 -0
- sglang/srt/models/ernie4.py +1 -1
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +1 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/glm4v_moe.py +3 -0
- sglang/srt/models/gpt_oss.py +1 -1
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +13 -0
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +7 -0
- sglang/srt/models/qwen2_5_vl.py +27 -3
- sglang/srt/models/qwen2_moe.py +56 -12
- sglang/srt/models/qwen3_moe.py +1 -1
- sglang/srt/models/qwen3_next.py +1042 -0
- sglang/srt/models/qwen3_next_mtp.py +112 -0
- sglang/srt/models/step3_vl.py +1 -1
- sglang/srt/multimodal/processors/dots_vlm.py +99 -0
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/multimodal/processors/qwen_vl.py +15 -5
- sglang/srt/offloader.py +27 -3
- sglang/srt/remote_instance_weight_loader_utils.py +69 -0
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +276 -35
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_utils.py +0 -2
- sglang/srt/speculative/eagle_worker.py +43 -4
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tracing/trace.py +552 -0
- sglang/srt/utils.py +34 -3
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +169 -5
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_fp4_moe.py +370 -1
- sglang/test/test_utils.py +28 -1
- sglang/utils.py +11 -0
- sglang/version.py +1 -1
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/METADATA +59 -123
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/RECORD +237 -178
- sglang/srt/disaggregation/launch_lb.py +0 -118
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/WHEEL +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.2rc2.dist-info → sglang-0.5.3rc0.dist-info}/top_level.txt +0 -0
sglang/srt/models/apertus.py
@@ -0,0 +1,686 @@
+# Copyright 2025 The SwissAI Initiative
+# Copyright 2023-2024 SGLang Team
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+# ==============================================================================
+
+# Adapted from
+# https://github.com/vllm-project/vllm/blob/c7f2cf2b7f67bce5842fedfdba508440fe257375/vllm/model_executor/models/llama.py#L1
+"""Inference-only Apertus model compatible with HuggingFace weights."""
+
+import logging
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
+
+import torch
+from torch import nn
+from transformers import ApertusConfig
+
+from sglang.srt.distributed import (
+    get_pp_group,
+    get_tensor_model_parallel_rank,
+    get_tensor_model_parallel_world_size,
+)
+from sglang.srt.layers.activation import XIELU
+from sglang.srt.layers.layernorm import RMSNorm
+from sglang.srt.layers.linear import (
+    ColumnParallelLinear,
+    QKVParallelLinear,
+    RowParallelLinear,
+)
+from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
+from sglang.srt.layers.pooler import Pooler, PoolingType
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.layers.radix_attention import RadixAttention
+from sglang.srt.layers.rotary_embedding import get_rope
+from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
+from sglang.srt.layers.vocab_parallel_embedding import (
+    ParallelLMHead,
+    VocabParallelEmbedding,
+)
+from sglang.srt.managers.schedule_batch import global_server_args_dict
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
+from sglang.srt.model_loader.weight_utils import (
+    default_weight_loader,
+    kv_cache_scales_loader,
+    maybe_remap_kv_scale_name,
+)
+from sglang.srt.utils import add_prefix, make_layers
+from sglang.utils import get_exception_traceback
+
+logger = logging.getLogger(__name__)
+
+
+class ApertusMLP(nn.Module):
+    def __init__(
+        self,
+        hidden_size: int,
+        intermediate_size: int,
+        hidden_act: str,
+        quant_config: Optional[QuantizationConfig] = None,
+        bias: bool = False,
+        prefix: str = "",
+        reduce_results: bool = True,
+    ) -> None:
+        super().__init__()
+        self.up_proj = ColumnParallelLinear(
+            hidden_size,
+            intermediate_size,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=add_prefix("up_proj", prefix),
+        )
+        self.down_proj = RowParallelLinear(
+            intermediate_size,
+            hidden_size,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=add_prefix("down_proj", prefix),
+            reduce_results=reduce_results,
+        )
+        if hidden_act != "xielu":
+            raise ValueError(
+                f"Unsupported activation: {hidden_act}. "
+                "Only xIELU is supported for now."
+            )
+        self.act_fn = XIELU()
+
+    def forward(
+        self,
+        x,
+        forward_batch=None,
+        use_reduce_scatter: bool = False,
+    ):
+        # note: with xielu, there's no gate_proj
+        x, _ = self.up_proj(x)
+        x = self.act_fn(x)
+        x, _ = self.down_proj(
+            x,
+            skip_all_reduce=use_reduce_scatter,
+        )
+        return x
+
+
+class ApertusAttention(nn.Module):
+    def __init__(
+        self,
+        config: ApertusConfig,
+        hidden_size: int,
+        num_heads: int,
+        num_kv_heads: int,
+        layer_id: int = 0,
+        rope_theta: float = 10000,
+        rope_scaling: Optional[Dict[str, Any]] = None,
+        rope_is_neox_style: bool = True,
+        max_position_embeddings: int = 8192,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+        bias: bool = False,
+        bias_o_proj: bool = False,
+    ) -> None:
+        super().__init__()
+        self.layer_id = layer_id
+        self.hidden_size = hidden_size
+        tp_size = get_tensor_model_parallel_world_size()
+        self.total_num_heads = num_heads
+        assert self.total_num_heads % tp_size == 0
+        self.num_heads = self.total_num_heads // tp_size
+        self.total_num_kv_heads = num_kv_heads
+        if self.total_num_kv_heads >= tp_size:
+            # Number of KV heads is greater than TP size, so we partition
+            # the KV heads across multiple tensor parallel GPUs.
+            assert self.total_num_kv_heads % tp_size == 0
+        else:
+            # Number of KV heads is less than TP size, so we replicate
+            # the KV heads across multiple tensor parallel GPUs.
+            assert tp_size % self.total_num_kv_heads == 0
+        self.num_kv_heads = max(1, self.total_num_kv_heads // tp_size)
+        # MistralConfig has an optional head_dim introduced by Mistral-Nemo
+        self.head_dim = getattr(
+            config, "head_dim", self.hidden_size // self.total_num_heads
+        )
+        partial_rotary_factor = getattr(config, "partial_rotary_factor", 1)
+        self.rotary_dim = int(partial_rotary_factor * self.head_dim)
+        self.q_size = self.num_heads * self.head_dim
+        self.kv_size = self.num_kv_heads * self.head_dim
+        self.scaling = self.head_dim**-0.5
+        self.rope_theta = rope_theta
+        self.max_position_embeddings = max_position_embeddings
+
+        self.qkv_proj = QKVParallelLinear(
+            hidden_size,
+            self.head_dim,
+            self.total_num_heads,
+            self.total_num_kv_heads,
+            bias=bias,
+            quant_config=quant_config,
+            prefix=add_prefix("qkv_proj", prefix),
+        )
+        self.o_proj = RowParallelLinear(
+            self.total_num_heads * self.head_dim,
+            hidden_size,
+            bias=bias_o_proj,
+            quant_config=quant_config,
+            prefix=add_prefix("o_proj", prefix),
+        )
+
+        self.rotary_emb = get_rope(
+            self.head_dim,
+            rotary_dim=self.rotary_dim,
+            max_position=max_position_embeddings,
+            base=rope_theta,
+            rope_scaling=rope_scaling,
+            is_neox_style=rope_is_neox_style,
+        )
+        self.attn = RadixAttention(
+            self.num_heads,
+            self.head_dim,
+            self.scaling,
+            num_kv_heads=self.num_kv_heads,
+            layer_id=layer_id,
+            quant_config=quant_config,
+            prefix=add_prefix("attn", prefix),
+        )
+        self.q_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+        self.k_norm = RMSNorm(self.head_dim, eps=config.rms_norm_eps)
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+    ) -> torch.Tensor:
+        qkv, _ = self.qkv_proj(hidden_states)
+        q, k, v = qkv.split([self.q_size, self.kv_size, self.kv_size], dim=-1)
+        q = self.q_norm(q.contiguous().view(-1, self.head_dim)).view_as(q)
+        k = self.k_norm(k.contiguous().view(-1, self.head_dim)).view_as(k)
+        q, k = self.rotary_emb(positions, q, k)
+        attn_output = self.attn(q, k, v, forward_batch)
+        output, _ = self.o_proj(attn_output)
+        return output
+
+
+class ApertusDecoderLayer(nn.Module):
+    def __init__(
+        self,
+        config: ApertusConfig,
+        layer_id: int = 0,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.hidden_size = config.hidden_size
+        rope_theta = getattr(config, "rope_theta", 10000)
+        rope_scaling = getattr(config, "rope_scaling", None)
+        if rope_scaling is not None and getattr(
+            config, "original_max_position_embeddings", None
+        ):
+            rope_scaling["original_max_position_embeddings"] = (
+                config.original_max_position_embeddings
+            )
+        rope_is_neox_style = getattr(config, "rope_is_neox_style", True)
+        max_position_embeddings = getattr(config, "max_position_embeddings", 8192)
+        # Support llamafy/Qwen-Qwen2.5-7B-Instruct-llamafied with attention_bias
+        # Support internlm/internlm-7b with bias
+        attention_bias = getattr(config, "attention_bias", False) or getattr(
+            config, "bias", False
+        )
+        bias_o_proj = attention_bias
+        # support internlm/internlm3-8b with qkv_bias
+        if hasattr(config, "qkv_bias"):
+            attention_bias = config.qkv_bias
+        self.self_attn = ApertusAttention(
+            config=config,
+            hidden_size=self.hidden_size,
+            num_heads=config.num_attention_heads,
+            num_kv_heads=config.num_key_value_heads,
+            layer_id=layer_id,
+            rope_theta=rope_theta,
+            rope_scaling=rope_scaling,
+            rope_is_neox_style=rope_is_neox_style,
+            max_position_embeddings=max_position_embeddings,
+            quant_config=quant_config,
+            prefix=add_prefix("self_attn", prefix),
+            bias=attention_bias,
+            bias_o_proj=bias_o_proj,
+        )
+        self.mlp = ApertusMLP(
+            hidden_size=self.hidden_size,
+            intermediate_size=config.intermediate_size,
+            hidden_act=config.hidden_act,
+            quant_config=quant_config,
+            bias=getattr(config, "mlp_bias", False),
+            prefix=add_prefix("mlp", prefix),
+        )
+        self.attention_layernorm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        self.feedforward_layernorm = RMSNorm(
+            config.hidden_size, eps=config.rms_norm_eps
+        )
+
+    def forward(
+        self,
+        positions: torch.Tensor,
+        hidden_states: torch.Tensor,
+        forward_batch: ForwardBatch,
+        residual: Optional[torch.Tensor],
+    ) -> Tuple[torch.Tensor, torch.Tensor]:
+        # Self Attention
+        if residual is None:
+            residual = hidden_states
+            hidden_states = self.attention_layernorm(hidden_states)
+        else:
+            hidden_states, residual = self.attention_layernorm(hidden_states, residual)
+        hidden_states = self.self_attn(
+            positions=positions,
+            hidden_states=hidden_states,
+            forward_batch=forward_batch,
+        )
+
+        # Fully Connected
+        hidden_states, residual = self.feedforward_layernorm(hidden_states, residual)
+        hidden_states = self.mlp(hidden_states)
+        return hidden_states, residual
+
+
+class ApertusModel(nn.Module):
+    def __init__(
+        self,
+        config: ApertusConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.quant_config = quant_config
+        self.config = config
+        self.padding_idx = config.pad_token_id
+        self.vocab_size = config.vocab_size
+        self.org_vocab_size = config.vocab_size
+        self.pp_group = get_pp_group()
+        if self.pp_group.is_first_rank:
+            self.embed_tokens = VocabParallelEmbedding(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=add_prefix("embed_tokens", prefix),
+            )
+        else:
+            self.embed_tokens = PPMissingLayer()
+
+        self.layers, self.start_layer, self.end_layer = make_layers(
+            config.num_hidden_layers,
+            lambda idx, prefix: ApertusDecoderLayer(
+                config=config, quant_config=quant_config, layer_id=idx, prefix=prefix
+            ),
+            pp_rank=self.pp_group.rank_in_group,
+            pp_size=self.pp_group.world_size,
+            prefix="model.layers",
+        )
+
+        if self.pp_group.is_last_rank:
+            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+        else:
+            self.norm = PPMissingLayer(return_tuple=True)
+        self.layers_to_capture = []
+
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]], PPProxyTensors]:
+        if self.pp_group.is_first_rank:
+            if input_embeds is None:
+                hidden_states = self.embed_tokens(input_ids)
+            else:
+                hidden_states = input_embeds
+            residual = None
+        else:
+            assert pp_proxy_tensors is not None
+            # FIXME(@ying): reduce the number of proxy tensors by not fusing layer norms
+            hidden_states = pp_proxy_tensors["hidden_states"]
+            residual = pp_proxy_tensors["residual"]
+            deferred_norm = None
+
+        aux_hidden_states = []
+        for i in range(self.start_layer, self.end_layer):
+            if i in self.layers_to_capture:
+                aux_hidden_states.append(hidden_states + residual)
+            layer = self.layers[i]
+            hidden_states, residual = layer(
+                positions,
+                hidden_states,
+                forward_batch,
+                residual,
+            )
+
+        if not self.pp_group.is_last_rank:
+            return PPProxyTensors(
+                {
+                    "hidden_states": hidden_states,
+                    "residual": residual,
+                }
+            )
+        else:
+            hidden_states, _ = self.norm(hidden_states, residual)
+
+        if len(aux_hidden_states) == 0:
+            return hidden_states
+
+        return hidden_states, aux_hidden_states
+
+    # If this function is called, it should always initialize KV cache scale
+    # factors (or else raise an exception). Thus, handled exceptions should
+    # make sure to leave KV cache scale factors in a known good (dummy) state
+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        tp_size = get_tensor_model_parallel_world_size()
+        tp_rank = get_tensor_model_parallel_rank()
+        for layer_idx, scaling_factor in kv_cache_scales_loader(
+            quantization_param_path,
+            tp_rank,
+            tp_size,
+            self.config.num_hidden_layers,
+            self.config.__class__.model_type,
+        ):
+            if not isinstance(self.layers[layer_idx], nn.Identity):
+                layer_self_attn = self.layers[layer_idx].self_attn
+
+            if hasattr(layer_self_attn.attn, "k_scale"):
+                layer_self_attn.attn.k_scale = scaling_factor
+                layer_self_attn.attn.v_scale = scaling_factor
+            else:
+                raise RuntimeError(
+                    "Self attention has no KV cache scaling " "factor attribute!"
+                )
+
+
+class ApertusForCausalLM(nn.Module):
+    # LoRA specific attributes
+    embedding_modules = {
+        "embed_tokens": "input_embeddings",
+        "lm_head": "output_embeddings",
+    }
+    embedding_padding_modules = ["lm_head"]
+    # BitandBytes specific attributes
+    default_bitsandbytes_target_modules = [
+        ".down_proj.",
+        ".up_proj.",
+        ".q_proj.",
+        ".k_proj.",
+        ".v_proj.",
+        ".o_proj.",
+    ]
+    # in TP, these weights are partitioned along the column dimension (dim=-1)
+    column_parallel_weights_modules = [".down_proj.", ".o_proj."]
+    bitsandbytes_stacked_params_mapping = {
+        # shard_name, weight_name, index
+        ".q_proj": (".qkv_proj", 0),
+        ".k_proj": (".qkv_proj", 1),
+        ".v_proj": (".qkv_proj", 2),
+    }
+
+    def __init__(
+        self,
+        config: ApertusConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+        self.pp_group = get_pp_group()
+        self.config = config
+        self.quant_config = quant_config
+        self.model = self._init_model(config, quant_config, add_prefix("model", prefix))
+        if self.config.tie_word_embeddings:
+            self.lm_head = self.model.embed_tokens
+        else:
+            self.lm_head = ParallelLMHead(
+                config.vocab_size,
+                config.hidden_size,
+                quant_config=quant_config,
+                prefix=add_prefix("lm_head", prefix),
+                use_attn_tp_group=global_server_args_dict["enable_dp_lm_head"],
+            )
+        self.logits_processor = LogitsProcessor(config)
+        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+        self.stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+        ]
+
+        self.capture_aux_hidden_states = False
+
+    def _init_model(
+        self,
+        config: ApertusConfig,
+        quant_config: Optional[QuantizationConfig] = None,
+        prefix: str = "",
+    ):
+        return ApertusModel(config, quant_config=quant_config, prefix=prefix)
+
+    @torch.no_grad()
+    def forward(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        input_embeds: torch.Tensor = None,
+        get_embedding: bool = False,
+        pp_proxy_tensors: Optional[PPProxyTensors] = None,
+    ) -> LogitsProcessorOutput:
+        hidden_states = self.model(
+            input_ids,
+            positions,
+            forward_batch,
+            input_embeds,
+            pp_proxy_tensors=pp_proxy_tensors,
+        )
+
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
+
+        if self.pp_group.is_last_rank:
+            if not get_embedding:
+                return self.logits_processor(
+                    input_ids,
+                    hidden_states,
+                    self.lm_head,
+                    forward_batch,
+                    aux_hidden_states,
+                )
+            else:
+                return self.pooler(hidden_states, forward_batch)
+        else:
+            return hidden_states
+
+    @torch.no_grad()
+    def forward_split_prefill(
+        self,
+        input_ids: torch.Tensor,
+        positions: torch.Tensor,
+        forward_batch: ForwardBatch,
+        split_interval: Tuple[int, int],  # [start, end) 0-based
+        input_embeds: torch.Tensor = None,
+    ) -> Optional[LogitsProcessorOutput]:
+        start, end = split_interval
+        # embed
+        if start == 0:
+            if input_embeds is None:
+                forward_batch.hidden_states = self.model.embed_tokens(input_ids)
+            else:
+                forward_batch.hidden_states = input_embeds
+        # decoder layer
+        for i in range(start, end):
+            layer = self.model.layers[i]
+            forward_batch.hidden_states, forward_batch.residual = layer(
+                positions,
+                forward_batch.hidden_states,
+                forward_batch,
+                forward_batch.residual,
+            )
+
+        if end == self.model.config.num_hidden_layers:
+            # norm
+            hidden_states, _ = self.model.norm(
+                forward_batch.hidden_states, forward_batch.residual
+            )
+            forward_batch.hidden_states = hidden_states
+            # logits process
+            result = self.logits_processor(
+                input_ids, forward_batch.hidden_states, self.lm_head, forward_batch
+            )
+        else:
+            result = None
+
+        return result
+
+    @property
+    def start_layer(self):
+        return self.model.start_layer
+
+    @property
+    def end_layer(self):
+        return self.model.end_layer
+
+    def get_input_embeddings(self) -> nn.Embedding:
+        return self.model.embed_tokens
+
+    def get_module_name_from_weight_name(self, name):
+        for param_name, weight_name, shard_id, num_shard in self.stacked_params_mapping:
+            if weight_name in name:
+                return (
+                    name.replace(weight_name, param_name)[: -len(".weight")],
+                    num_shard,
+                )
+        return name[: -len(".weight")], 1
+
+    def get_num_params(self):
+        params_dict = dict(self.named_parameters())
+        return len(params_dict)
+
+    def load_weights(self, weights: Iterable[Tuple[str, torch.Tensor]]):
+        stacked_params_mapping = [
+            # (param_name, shard_name, shard_id)
+            (".qkv_proj", ".q_proj", "q"),
+            (".qkv_proj", ".k_proj", "k"),
+            (".qkv_proj", ".v_proj", "v"),
+        ]
+
+        params_dict = dict(self.named_parameters())
+
+        for name, buffer in self.named_buffers():
+            if name.endswith(".beta") or name.endswith(".eps"):
+                params_dict[name] = buffer
+
+        for name, loaded_weight in weights:
+            layer_id = get_layer_id(name)
+            if (
+                layer_id is not None
+                and hasattr(self.model, "start_layer")
+                and (
+                    layer_id < self.model.start_layer
+                    or layer_id >= self.model.end_layer
+                )
+            ):
+                continue
+            if "rotary_emb.inv_freq" in name or "projector" in name:
+                continue
+            if "rotary_emb.cos_cached" in name or "rotary_emb.sin_cached" in name:
+                # Models trained using ColossalAI may include these tensors in
+                # the checkpoint. Skip them.
+                continue
+            if name.startswith("model.vision_tower") and name not in params_dict:
+                continue
+            if self.config.tie_word_embeddings and "lm_head.weight" in name:
+                continue
+            # Handle FP8 kv-scale remapping
+            if "scale" in name:
+                name = maybe_remap_kv_scale_name(name, params_dict)
+                if name is None:
+                    continue
+
+            for param_name, weight_name, shard_id in stacked_params_mapping:
+                if weight_name not in name:
+                    continue
+                name = name.replace(weight_name, param_name)
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                if name not in params_dict:
+                    continue
+                param = params_dict[name]
+                weight_loader = param.weight_loader
+                weight_loader(param, loaded_weight, shard_id)
+                break
+            else:
+                # Skip loading extra bias for GPTQ models.
+                if name.endswith(".bias") and name not in params_dict:
+                    continue
+                # Skip loading kv_scale from ckpts towards new design.
+                if name.endswith(".kv_scale") and name not in params_dict:
+                    continue
+                if name in params_dict.keys():
+                    param = params_dict[name]
+                    weight_loader = getattr(
+                        param, "weight_loader", default_weight_loader
+                    )
+                    weight_loader(param, loaded_weight)
+                else:
+                    logger.warning(f"Parameter {name} not found in params_dict")
+
+    def get_embed_and_head(self):
+        return self.model.embed_tokens.weight, self.lm_head.weight
+
+    def set_embed_and_head(self, embed, head):
+        del self.model.embed_tokens.weight
+        del self.lm_head.weight
+        self.model.embed_tokens.weight = embed
+        self.lm_head.weight = head
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+
+    def get_embed(self):
+        return self.model.embed_tokens.weight
+
+    def set_embed(self, embed):
+        # NOTE: If draft hidden size != target hidden size, the embed weight cannot be shared for EAGLE3
+        if (
+            hasattr(self.config, "target_hidden_size")
+            and self.config.target_hidden_size != self.config.hidden_size
+        ):
+            return
+        del self.model.embed_tokens.weight
+        self.model.embed_tokens.weight = embed
+        torch.cuda.empty_cache()
+        torch.cuda.synchronize()
+
+    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
+        self.model.load_kv_cache_scales(quantization_param_path)
+
+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
+        if not self.pp_group.is_last_rank:
+            return
+
+        if layer_ids is None:
+            self.capture_aux_hidden_states = True
+            num_layers = self.config.num_hidden_layers
+            self.model.layers_to_capture = [2, num_layers // 2, num_layers - 3]
+        else:
+            self.capture_aux_hidden_states = True
+            # we plus 1 here because in sglang, for the ith layer, it takes the output
+            # of the (i-1)th layer as aux hidden state
+            self.model.layers_to_capture = [val + 1 for val in layer_ids]
+
+
+EntryClass = [ApertusForCausalLM]
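
The new sglang/srt/models/apertus.py file above registers ApertusForCausalLM through its EntryClass list, so the architecture becomes loadable like any other model in sglang's registry. As a quick illustration that is not part of this diff, a smoke test with the offline engine might look like the sketch below; the model path and sampling parameters are illustrative assumptions, and the sgl.Engine offline API is assumed to behave as in earlier 0.5.x releases.

# Hypothetical smoke test for the new Apertus entry; model path and
# sampling parameters are illustrative, not taken from this diff.
import sglang as sgl

if __name__ == "__main__":
    # The engine resolves ApertusForCausalLM through the model registry,
    # which is populated from each model file's EntryClass list.
    llm = sgl.Engine(model_path="swiss-ai/Apertus-8B-Instruct-2509")
    outputs = llm.generate(
        ["The capital of Switzerland is"],
        {"temperature": 0.0, "max_new_tokens": 16},
    )
    print(outputs[0]["text"])
    llm.shutdown()

Serving over HTTP (python -m sglang.launch_server --model-path <model>) goes through the same model loader, so the same entry applies there as well.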