sglang 0.5.1.post2__py3-none-any.whl → 0.5.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- sglang/bench_one_batch.py +3 -0
- sglang/bench_one_batch_server.py +89 -54
- sglang/bench_serving.py +437 -40
- sglang/lang/interpreter.py +1 -1
- sglang/profiler.py +0 -1
- sglang/srt/configs/__init__.py +4 -0
- sglang/srt/configs/internvl.py +6 -0
- sglang/srt/configs/longcat_flash.py +104 -0
- sglang/srt/configs/model_config.py +37 -7
- sglang/srt/configs/qwen3_next.py +326 -0
- sglang/srt/connector/__init__.py +1 -1
- sglang/srt/connector/base_connector.py +1 -2
- sglang/srt/connector/redis.py +2 -2
- sglang/srt/connector/serde/__init__.py +1 -1
- sglang/srt/connector/serde/safe_serde.py +4 -3
- sglang/srt/custom_op.py +11 -1
- sglang/srt/debug_utils/dump_comparator.py +81 -44
- sglang/srt/debug_utils/dump_loader.py +97 -0
- sglang/srt/debug_utils/dumper.py +11 -3
- sglang/srt/debug_utils/text_comparator.py +73 -11
- sglang/srt/disaggregation/ascend/conn.py +75 -0
- sglang/srt/disaggregation/base/conn.py +1 -1
- sglang/srt/disaggregation/common/conn.py +15 -12
- sglang/srt/disaggregation/decode.py +6 -4
- sglang/srt/disaggregation/fake/conn.py +1 -1
- sglang/srt/disaggregation/mini_lb.py +6 -420
- sglang/srt/disaggregation/mooncake/conn.py +18 -10
- sglang/srt/disaggregation/nixl/conn.py +180 -16
- sglang/srt/disaggregation/prefill.py +6 -4
- sglang/srt/disaggregation/utils.py +5 -50
- sglang/srt/distributed/parallel_state.py +94 -58
- sglang/srt/entrypoints/engine.py +34 -14
- sglang/srt/entrypoints/http_server.py +172 -47
- sglang/srt/entrypoints/openai/protocol.py +90 -27
- sglang/srt/entrypoints/openai/serving_base.py +6 -2
- sglang/srt/entrypoints/openai/serving_chat.py +82 -26
- sglang/srt/entrypoints/openai/serving_completions.py +25 -4
- sglang/srt/entrypoints/openai/serving_embedding.py +8 -4
- sglang/srt/entrypoints/openai/serving_responses.py +7 -4
- sglang/srt/eplb/eplb_manager.py +28 -4
- sglang/srt/eplb/expert_distribution.py +55 -15
- sglang/srt/eplb/expert_location.py +8 -3
- sglang/srt/eplb/expert_location_updater.py +1 -1
- sglang/srt/function_call/deepseekv31_detector.py +222 -0
- sglang/srt/function_call/ebnf_composer.py +11 -9
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/glm4_moe_detector.py +1 -1
- sglang/srt/function_call/gpt_oss_detector.py +144 -256
- sglang/srt/function_call/qwen3_coder_detector.py +1 -1
- sglang/srt/hf_transformers_utils.py +28 -7
- sglang/srt/layers/activation.py +44 -9
- sglang/srt/layers/attention/aiter_backend.py +93 -68
- sglang/srt/layers/attention/ascend_backend.py +381 -136
- sglang/srt/layers/attention/fla/chunk.py +242 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +314 -0
- sglang/srt/layers/attention/fla/chunk_o.py +178 -0
- sglang/srt/layers/attention/fla/chunk_scaled_dot_kkt.py +151 -0
- sglang/srt/layers/attention/fla/cumsum.py +300 -0
- sglang/srt/layers/attention/fla/fused_recurrent.py +640 -0
- sglang/srt/layers/attention/fla/fused_sigmoid_gating_recurrent.py +232 -0
- sglang/srt/layers/attention/fla/index.py +37 -0
- sglang/srt/layers/attention/fla/l2norm.py +150 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +326 -0
- sglang/srt/layers/attention/fla/op.py +66 -0
- sglang/srt/layers/attention/fla/solve_tril.py +465 -0
- sglang/srt/layers/attention/fla/utils.py +331 -0
- sglang/srt/layers/attention/fla/wy_fast.py +158 -0
- sglang/srt/layers/attention/flashattention_backend.py +241 -7
- sglang/srt/layers/attention/flashinfer_backend.py +11 -6
- sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -14
- sglang/srt/layers/attention/hybrid_attn_backend.py +47 -8
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +584 -0
- sglang/srt/layers/attention/intel_amx_backend.py +3 -0
- sglang/srt/layers/attention/mamba/causal_conv1d.py +128 -0
- sglang/srt/layers/attention/mamba/causal_conv1d_triton.py +1052 -0
- sglang/srt/layers/attention/mamba/mamba.py +64 -0
- sglang/srt/layers/attention/torch_native_backend.py +12 -6
- sglang/srt/layers/attention/trtllm_mla_backend.py +126 -36
- sglang/srt/layers/attention/wave_ops/decode_attention.py +2 -4
- sglang/srt/layers/attention/wave_ops/extend_attention.py +1 -3
- sglang/srt/layers/communicator.py +45 -8
- sglang/srt/layers/layernorm.py +54 -12
- sglang/srt/layers/logits_processor.py +10 -3
- sglang/srt/layers/moe/__init__.py +2 -1
- sglang/srt/layers/moe/cutlass_moe.py +0 -8
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +4 -12
- sglang/srt/layers/moe/ep_moe/kernels.py +74 -0
- sglang/srt/layers/moe/ep_moe/layer.py +111 -56
- sglang/srt/layers/moe/fused_moe_native.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/__init__.py +5 -3
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_3_1/E=128,N=768,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=129,N=352,device_name=NVIDIA_B200,dtype=fp8_w8a8.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/{E=128,N=768,device_name=NVIDIA_B200,dtype=fp8_w8a8,block_shape=[128, 128].json → E=257,N=128,device_name=NVIDIA_H100_80GB_HBM3,dtype=fp8_w8a8,block_shape=[128, 128].json} +29 -29
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=64,device_name=NVIDIA_A100-SXM4-80GB.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=128,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=256,device_name=NVIDIA_H20-3e.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=512,N=64,device_name=NVIDIA_H100_80GB_HBM3.json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +9 -1049
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +212 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +799 -0
- sglang/srt/layers/moe/fused_moe_triton/layer.py +56 -45
- sglang/srt/layers/moe/fused_moe_triton/moe_align_block_size.py +87 -0
- sglang/srt/layers/moe/moe_runner/__init__.py +2 -1
- sglang/srt/layers/moe/moe_runner/base.py +274 -1
- sglang/srt/layers/moe/moe_runner/runner.py +80 -0
- sglang/srt/layers/moe/moe_runner/triton.py +448 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +16 -4
- sglang/srt/layers/moe/token_dispatcher/{base_dispatcher.py → base.py} +67 -17
- sglang/srt/layers/moe/token_dispatcher/deepep.py +41 -38
- sglang/srt/layers/moe/token_dispatcher/standard.py +44 -2
- sglang/srt/layers/moe/topk.py +43 -12
- sglang/srt/layers/moe/utils.py +6 -5
- sglang/srt/layers/quantization/awq.py +19 -7
- sglang/srt/layers/quantization/base_config.py +11 -6
- sglang/srt/layers/quantization/blockwise_int8.py +38 -27
- sglang/srt/layers/quantization/compressed_tensors/compressed_tensors_moe.py +50 -30
- sglang/srt/layers/quantization/compressed_tensors/schemes/compressed_tensors_w8a8_fp8.py +13 -1
- sglang/srt/layers/quantization/deep_gemm_wrapper/compile_utils.py +141 -235
- sglang/srt/layers/quantization/deep_gemm_wrapper/configurer.py +5 -10
- sglang/srt/layers/quantization/deep_gemm_wrapper/entrypoint.py +31 -22
- sglang/srt/layers/quantization/fp8.py +78 -48
- sglang/srt/layers/quantization/fp8_kernel.py +2 -2
- sglang/srt/layers/quantization/fp8_utils.py +45 -31
- sglang/srt/layers/quantization/gptq.py +25 -17
- sglang/srt/layers/quantization/modelopt_quant.py +107 -40
- sglang/srt/layers/quantization/moe_wna16.py +21 -18
- sglang/srt/layers/quantization/mxfp4.py +93 -68
- sglang/srt/layers/quantization/mxfp4_tensor.py +3 -1
- sglang/srt/layers/quantization/quark/quark_moe.py +32 -27
- sglang/srt/layers/quantization/quark/schemes/quark_w4a4_mxfp4.py +49 -30
- sglang/srt/layers/quantization/quark/utils.py +97 -0
- sglang/srt/layers/quantization/rocm_mxfp4_utils.py +13 -0
- sglang/srt/layers/quantization/unquant.py +135 -47
- sglang/srt/layers/quantization/utils.py +13 -0
- sglang/srt/layers/quantization/w4afp8.py +60 -42
- sglang/srt/layers/quantization/w8a8_fp8.py +35 -20
- sglang/srt/layers/quantization/w8a8_int8.py +83 -41
- sglang/srt/layers/rocm_linear_utils.py +44 -0
- sglang/srt/layers/rotary_embedding.py +28 -19
- sglang/srt/layers/sampler.py +29 -5
- sglang/srt/layers/utils.py +0 -14
- sglang/srt/lora/backend/base_backend.py +50 -8
- sglang/srt/lora/backend/triton_backend.py +90 -2
- sglang/srt/lora/layers.py +32 -0
- sglang/srt/lora/lora.py +4 -1
- sglang/srt/lora/lora_manager.py +35 -112
- sglang/srt/lora/mem_pool.py +24 -10
- sglang/srt/lora/utils.py +18 -9
- sglang/srt/managers/cache_controller.py +396 -365
- sglang/srt/managers/data_parallel_controller.py +30 -15
- sglang/srt/managers/detokenizer_manager.py +18 -2
- sglang/srt/managers/disagg_service.py +46 -0
- sglang/srt/managers/io_struct.py +190 -11
- sglang/srt/managers/mm_utils.py +6 -1
- sglang/srt/managers/multi_tokenizer_mixin.py +579 -0
- sglang/srt/managers/schedule_batch.py +27 -44
- sglang/srt/managers/schedule_policy.py +4 -3
- sglang/srt/managers/scheduler.py +148 -122
- sglang/srt/managers/scheduler_metrics_mixin.py +114 -8
- sglang/srt/managers/scheduler_output_processor_mixin.py +29 -19
- sglang/srt/managers/scheduler_profiler_mixin.py +1 -1
- sglang/srt/managers/scheduler_update_weights_mixin.py +8 -1
- sglang/srt/managers/template_manager.py +3 -3
- sglang/srt/managers/tokenizer_communicator_mixin.py +491 -0
- sglang/srt/managers/tokenizer_manager.py +77 -480
- sglang/srt/managers/tp_worker.py +16 -4
- sglang/srt/managers/tp_worker_overlap_thread.py +8 -10
- sglang/srt/mem_cache/allocator.py +1 -1
- sglang/srt/mem_cache/chunk_cache.py +1 -1
- sglang/srt/mem_cache/hicache_storage.py +53 -40
- sglang/srt/mem_cache/hiradix_cache.py +196 -104
- sglang/srt/mem_cache/lora_radix_cache.py +1 -1
- sglang/srt/mem_cache/memory_pool.py +395 -53
- sglang/srt/mem_cache/memory_pool_host.py +27 -19
- sglang/srt/mem_cache/radix_cache.py +6 -6
- sglang/srt/mem_cache/radix_cache_cpp.py +1 -1
- sglang/srt/mem_cache/storage/hf3fs/hf3fs_client.py +164 -0
- sglang/srt/mem_cache/storage/hf3fs/{client_hf3fs.py → hf3fs_usrbio_client.py} +5 -1
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +61 -34
- sglang/srt/mem_cache/storage/hf3fs/storage_hf3fs.py +152 -23
- sglang/srt/mem_cache/storage/lmcache/lmc_radix_cache.py +280 -0
- sglang/srt/mem_cache/storage/lmcache/unit_test.py +121 -0
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +154 -95
- sglang/srt/mem_cache/storage/mooncake_store/test_mooncake_store.py +161 -0
- sglang/srt/mem_cache/swa_radix_cache.py +1 -3
- sglang/srt/metrics/collector.py +484 -63
- sglang/srt/metrics/startup_func_log_and_timer.py +150 -0
- sglang/srt/metrics/utils.py +48 -0
- sglang/srt/model_executor/cpu_graph_runner.py +640 -0
- sglang/srt/model_executor/cuda_graph_runner.py +13 -5
- sglang/srt/model_executor/forward_batch_info.py +72 -18
- sglang/srt/model_executor/model_runner.py +190 -32
- sglang/srt/model_loader/__init__.py +9 -3
- sglang/srt/model_loader/loader.py +33 -28
- sglang/srt/model_loader/utils.py +12 -0
- sglang/srt/model_loader/weight_utils.py +2 -1
- sglang/srt/models/deepseek_v2.py +323 -53
- sglang/srt/models/gemma3n_mm.py +1 -1
- sglang/srt/models/glm4_moe.py +10 -1
- sglang/srt/models/glm4v.py +4 -2
- sglang/srt/models/gpt_oss.py +7 -19
- sglang/srt/models/internvl.py +28 -0
- sglang/srt/models/llama4.py +9 -0
- sglang/srt/models/llama_eagle3.py +17 -0
- sglang/srt/models/longcat_flash.py +1026 -0
- sglang/srt/models/longcat_flash_nextn.py +699 -0
- sglang/srt/models/minicpmv.py +165 -3
- sglang/srt/models/mllama4.py +25 -0
- sglang/srt/models/opt.py +637 -0
- sglang/srt/models/qwen2.py +33 -3
- sglang/srt/models/qwen2_5_vl.py +91 -42
- sglang/srt/models/qwen2_moe.py +79 -14
- sglang/srt/models/qwen3.py +8 -2
- sglang/srt/models/qwen3_moe.py +39 -8
- sglang/srt/models/qwen3_next.py +1039 -0
- sglang/srt/models/qwen3_next_mtp.py +109 -0
- sglang/srt/models/torch_native_llama.py +1 -1
- sglang/srt/models/transformers.py +1 -1
- sglang/srt/multimodal/processors/base_processor.py +4 -2
- sglang/srt/multimodal/processors/glm4v.py +9 -9
- sglang/srt/multimodal/processors/internvl.py +141 -129
- sglang/srt/{conversation.py → parser/conversation.py} +38 -5
- sglang/srt/parser/harmony_parser.py +588 -0
- sglang/srt/parser/reasoning_parser.py +309 -0
- sglang/srt/sampling/penaltylib/orchestrator.py +14 -2
- sglang/srt/sampling/sampling_batch_info.py +18 -15
- sglang/srt/server_args.py +307 -80
- sglang/srt/speculative/eagle_draft_cuda_graph_runner.py +5 -0
- sglang/srt/speculative/eagle_draft_extend_cuda_graph_runner.py +10 -1
- sglang/srt/speculative/eagle_worker.py +216 -120
- sglang/srt/speculative/spec_info.py +5 -0
- sglang/srt/speculative/standalone_worker.py +109 -0
- sglang/srt/tokenizer/tiktoken_tokenizer.py +6 -1
- sglang/srt/utils.py +96 -7
- sglang/srt/weight_sync/utils.py +1 -1
- sglang/test/attention/test_trtllm_mla_backend.py +181 -8
- sglang/test/few_shot_gsm8k.py +1 -0
- sglang/test/runners.py +4 -0
- sglang/test/test_cutlass_moe.py +24 -6
- sglang/test/test_cutlass_w4a8_moe.py +24 -9
- sglang/test/test_disaggregation_utils.py +66 -0
- sglang/test/test_utils.py +25 -1
- sglang/utils.py +5 -0
- sglang/version.py +1 -1
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/METADATA +13 -10
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/RECORD +253 -201
- sglang/srt/disaggregation/launch_lb.py +0 -131
- sglang/srt/mem_cache/storage/mooncake_store/unit_test.py +0 -40
- sglang/srt/reasoning_parser.py +0 -553
- /sglang/srt/{model_parallel.py → layers/model_parallel.py} +0 -0
- /sglang/srt/{code_completion_parser.py → parser/code_completion_parser.py} +0 -0
- /sglang/srt/{jinja_template_utils.py → parser/jinja_template_utils.py} +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/WHEEL +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.1.post2.dist-info → sglang-0.5.2.dist-info}/top_level.txt +0 -0
sglang/srt/models/opt.py
ADDED
@@ -0,0 +1,637 @@
# Copyright 2023-2024 SGLang Team
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================

"""Inference-only OPT model compatible with HuggingFace weights."""
import logging
from collections.abc import Iterable
from typing import Optional, Union

import torch
import torch.nn.functional as F
from torch import nn
from transformers import OPTConfig

from sglang.srt.distributed import (
    get_pp_group,
    get_tensor_model_parallel_rank,
    get_tensor_model_parallel_world_size,
)
from sglang.srt.layers.activation import get_act_fn
from sglang.srt.layers.linear import (
    ColumnParallelLinear,
    MergedColumnParallelLinear,
    QKVParallelLinear,
    ReplicatedLinear,
    RowParallelLinear,
)
from sglang.srt.layers.logits_processor import LogitsProcessor, LogitsProcessorOutput
from sglang.srt.layers.pooler import Pooler, PoolingType
from sglang.srt.layers.quantization.base_config import QuantizationConfig
from sglang.srt.layers.radix_attention import RadixAttention
from sglang.srt.layers.utils import PPMissingLayer, get_layer_id
from sglang.srt.layers.vocab_parallel_embedding import (
    ParallelLMHead,
    VocabParallelEmbedding,
)
from sglang.srt.model_executor.forward_batch_info import ForwardBatch, PPProxyTensors
from sglang.srt.model_loader.weight_utils import (
    default_weight_loader,
    kv_cache_scales_loader,
    maybe_remap_kv_scale_name,
)
from sglang.srt.utils import add_prefix, get_exception_traceback, make_layers

logger = logging.getLogger(__name__)


def get_activation(name="relu"):
    """Select an activation function by name.

    Args:
        name: str
            activation function name,
            one of ["relu", "gelu", "swish", "sigmoid"],
            default "relu".
    """
    name = name.lower()
    if name == "relu":
        return nn.ReLU()
    if name == "gelu":
        return nn.GELU()
    if name == "sigmoid":
        return torch.nn.Sigmoid()
    return nn.Identity()


class OPTLearnedPositionalEmbedding(nn.Embedding):

    def __init__(self, num_embeddings: int, embedding_dim: int):
        # OPT is set up so that if padding_idx is specified then offset the
        # embedding ids by 2 and adjust num_embeddings appropriately. Other
        # models don't have this hack
        self.offset = 2
        super().__init__(num_embeddings + self.offset, embedding_dim)

    def forward(self, positions: torch.Tensor):
        return super().forward(positions + self.offset)


class OPTAttention(nn.Module):

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        layer_id: int = 0,
        bias: bool = True,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()
        self.embed_dim = embed_dim
        tensor_model_parallel_world_size = get_tensor_model_parallel_world_size()
        total_num_heads = num_heads
        assert num_heads % tensor_model_parallel_world_size == 0
        self.num_heads = total_num_heads // tensor_model_parallel_world_size
        self.head_dim = embed_dim // total_num_heads
        self.scaling = self.head_dim**-0.5

        self.qkv_proj = QKVParallelLinear(
            embed_dim,
            self.head_dim,
            total_num_heads,
            bias=bias,
            quant_config=quant_config,
            prefix=add_prefix("qkv_proj", prefix),
        )
        self.out_proj = RowParallelLinear(
            embed_dim,
            embed_dim,
            bias=bias,
            quant_config=quant_config,
            prefix=add_prefix("o_proj", prefix),
        )

        self.attn = RadixAttention(
            self.num_heads,
            self.head_dim,
            self.scaling,
            num_kv_heads=self.num_heads,
            layer_id=layer_id,
            quant_config=quant_config,
            prefix=add_prefix("attn", prefix),
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        qkv, _ = self.qkv_proj(hidden_states)
        q, k, v = qkv.chunk(chunks=3, dim=-1)
        attn_output = self.attn(q, k, v, forward_batch)
        output, _ = self.out_proj(attn_output)
        return output


class OPTDecoderLayer(nn.Module):

    def __init__(
        self,
        config: OPTConfig,
        layer_id: int = 0,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.embed_dim = config.hidden_size
        self.self_attn = OPTAttention(
            embed_dim=self.embed_dim,
            num_heads=config.num_attention_heads,
            layer_id=layer_id,
            bias=config.enable_bias,
            quant_config=quant_config,
            prefix=add_prefix("self_attn", prefix),
        )
        self.do_layer_norm_before = config.do_layer_norm_before

        self.self_attn_layer_norm = nn.LayerNorm(
            self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine
        )
        self.fc1 = ColumnParallelLinear(
            self.embed_dim,
            config.ffn_dim,
            bias=config.enable_bias,
            quant_config=quant_config,
            prefix=add_prefix("fc1", prefix),
        )
        self.activation_fn = get_activation(config.activation_function)
        self.fc2 = RowParallelLinear(
            config.ffn_dim,
            self.embed_dim,
            bias=config.enable_bias,
            quant_config=quant_config,
            prefix=add_prefix("fc2", prefix),
        )
        self.final_layer_norm = nn.LayerNorm(
            self.embed_dim, elementwise_affine=config.layer_norm_elementwise_affine
        )

    def forward(
        self,
        hidden_states: torch.Tensor,
        forward_batch: ForwardBatch,
    ) -> torch.Tensor:
        # Self Attention
        residual = hidden_states
        # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
        if self.do_layer_norm_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)
        hidden_states = self.self_attn(
            hidden_states=hidden_states, forward_batch=forward_batch
        )
        hidden_states = residual + hidden_states
        # 350m applies layer norm AFTER attention
        if not self.do_layer_norm_before:
            hidden_states = self.self_attn_layer_norm(hidden_states)

        # Fully Connected
        residual = hidden_states
        # 125m, 1.7B, ..., 175B applies layer norm BEFORE attention
        if self.do_layer_norm_before:
            hidden_states = self.final_layer_norm(hidden_states)
        hidden_states, _ = self.fc1(hidden_states)
        hidden_states = self.activation_fn(hidden_states)
        hidden_states, _ = self.fc2(hidden_states)
        hidden_states = residual + hidden_states
        # 350m applies layer norm AFTER attention
        if not self.do_layer_norm_before:
            hidden_states = self.final_layer_norm(hidden_states)
        return hidden_states


class OPTDecoder(nn.Module):

    def __init__(
        self,
        config: OPTConfig,
        layer_id: int = 0,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.max_target_positions = config.max_position_embeddings
        self.vocab_size = config.vocab_size

        self.pp_group = get_pp_group()

        self.embed_tokens = VocabParallelEmbedding(
            config.vocab_size,
            config.word_embed_proj_dim,
            prefix=add_prefix("embed_tokens", prefix),
        )
        # Positional embeddings are replicated (not sharded).
        self.embed_positions = OPTLearnedPositionalEmbedding(
            config.max_position_embeddings, config.hidden_size
        )

        # Project out & in will be replicated if they exist.
        if config.word_embed_proj_dim != config.hidden_size:
            self.project_out = ReplicatedLinear(
                config.hidden_size,
                config.word_embed_proj_dim,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("project_out", prefix),
            )
        else:
            self.project_out = None

        if config.word_embed_proj_dim != config.hidden_size:
            self.project_in = ReplicatedLinear(
                config.word_embed_proj_dim,
                config.hidden_size,
                bias=False,
                quant_config=quant_config,
                prefix=add_prefix("project_in", prefix),
            )
        else:
            self.project_in = None

        # Note that the only purpose of `config._remove_final_layer_norm` is to
        # keep backward compatibility with checkpoints that have been fine-tuned
        # before transformers v4.20.1
        # see https://github.com/facebookresearch/metaseq/pull/164
        if config.do_layer_norm_before and not config._remove_final_layer_norm:
            self.final_layer_norm = nn.LayerNorm(
                config.hidden_size,
                elementwise_affine=config.layer_norm_elementwise_affine,
            )
        else:
            self.final_layer_norm = None

        self.layers, self.start_layer, self.end_layer = make_layers(
            config.num_hidden_layers,
            lambda idx, prefix: OPTDecoderLayer(
                config=config, layer_id=idx, quant_config=quant_config, prefix=prefix
            ),
            pp_rank=self.pp_group.rank_in_group,
            pp_size=self.pp_group.world_size,
            prefix="model.layers",
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
        input_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, PPProxyTensors]:
        if self.pp_group.is_first_rank:
            if input_embeds is None:
                input_embeds = self.embed_tokens(input_ids)
            pos_embeds = self.embed_positions(positions)
            if self.project_in is not None:
                input_embeds, _ = self.project_in(input_embeds)
            hidden_states = input_embeds + pos_embeds
        else:
            assert pp_proxy_tensors is not None
            hidden_states = pp_proxy_tensors["hidden_states"]

        for layer in self.layers[self.start_layer : self.end_layer]:
            hidden_states = layer(
                hidden_states=hidden_states, forward_batch=forward_batch
            )
        if not self.pp_group.is_last_rank:
            return PPProxyTensors({"hidden_states": hidden_states})
        if self.final_layer_norm is not None:
            hidden_states = self.final_layer_norm(hidden_states)
        # Not reached here (project_out is only set when word_embed_proj_dim != hidden_size)
        if self.project_out is not None:
            hidden_states, _ = self.project_out(hidden_states)
        return hidden_states


class OPTModel(nn.Module):

    def __init__(
        self,
        config: OPTConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ) -> None:
        super().__init__()

        # config = vllm_config.model_config.hf_config
        # quant_config = vllm_config.quant_config
        self.config = config
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size
        self.pp_group = get_pp_group()

        self.decoder = OPTDecoder(
            config=config,
            quant_config=quant_config,
            prefix=add_prefix("decoder", prefix),
        )

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        pp_proxy_tensors: Optional[PPProxyTensors],
        input_embeds: Optional[torch.Tensor] = None,
    ) -> Union[torch.Tensor, PPProxyTensors]:
        return self.decoder(
            input_ids,
            positions,
            pp_proxy_tensors=pp_proxy_tensors,
            input_embeds=input_embeds,
            forward_batch=forward_batch,
        )

    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
        tp_size = get_tensor_model_parallel_world_size()
        tp_rank = get_tensor_model_parallel_rank()
        for layer_idx, scaling_factor in kv_cache_scales_loader(
            quantization_param_path,
            tp_rank,
            tp_size,
            self.config.num_hidden_layers,
            self.config.__class__.model_type,
        ):
            if not isinstance(self.decoder.layers[layer_idx], nn.Identity):
                layer_self_attn = self.decoder.layers[layer_idx].self_attn

            if hasattr(layer_self_attn.attn, "k_scale"):
                layer_self_attn.attn.k_scale = scaling_factor
                layer_self_attn.attn.v_scale = scaling_factor
            else:
                raise RuntimeError(
                    "Self attention has no KV cache scaling " "factor attribute!"
                )


class OPTForCausalLM(nn.Module):
    # BitandBytes specific attributes
    # in TP, these weights are partitioned along the column dimension (dim=-1)
    column_parallel_weights_modules = [".down_proj.", ".o_proj."]

    def __init__(
        self,
        config: OPTConfig,
        quant_config: Optional[QuantizationConfig] = None,
        prefix: str = "",
    ):
        super().__init__()
        self.config = config
        self.quant_config = quant_config

        self.model = OPTModel(
            config=config, quant_config=quant_config, prefix=add_prefix("model", prefix)
        )
        if self.config.tie_word_embeddings:
            self.lm_head = self.model.decoder.embed_tokens
        else:
            self.lm_head = ParallelLMHead(
                config.vocab_size,
                config.word_embed_proj_dim,
                prefix=add_prefix("lm_head", prefix),
            )
        self.logits_processor = LogitsProcessor(config)
        self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
        self.capture_aux_hidden_states = False
        self.pp_group = get_pp_group()
        self.stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            (".qkv_proj", ".q_proj", "q"),
            (".qkv_proj", ".k_proj", "k"),
            (".qkv_proj", ".v_proj", "v"),
        ]

    def forward(
        self,
        input_ids: torch.Tensor,
        positions: torch.Tensor,
        forward_batch: ForwardBatch,
        pp_proxy_tensors: Optional[PPProxyTensors] = None,
        input_embeds: Optional[torch.Tensor] = None,
        get_embedding: bool = False,
    ) -> LogitsProcessorOutput:
        hidden_states = self.model(
            input_ids=input_ids,
            positions=positions,
            forward_batch=forward_batch,
            input_embeds=input_embeds,
            pp_proxy_tensors=pp_proxy_tensors,
        )
        aux_hidden_states = None
        if self.capture_aux_hidden_states:
            hidden_states, aux_hidden_states = hidden_states

        if self.pp_group.is_last_rank:
            if not get_embedding:
                return self.logits_processor(
                    input_ids,
                    hidden_states,
                    self.lm_head,
                    forward_batch,
                    aux_hidden_states=aux_hidden_states,
                )
            else:
                return self.pooler(hidden_states, forward_batch)
        else:
            return hidden_states

    def load_weights(self, weights: Iterable[tuple[str, torch.Tensor]]) -> None:
        stacked_params_mapping = [
            # (param_name, shard_name, shard_id)
            ("qkv_proj", "q_proj", "q"),
            ("qkv_proj", "k_proj", "k"),
            ("qkv_proj", "v_proj", "v"),
        ]
        params_dict = dict(self.named_parameters(remove_duplicate=False))

        for name, loaded_weight in weights:
            if name.startswith("decoder"):
                name = name.replace("decoder.", "model.decoder.")
            layer_id = get_layer_id(name)
            if (
                layer_id is not None
                and hasattr(self.model, "start_layer")
                and (
                    layer_id < self.model.start_layer
                    or layer_id >= self.model.end_layer
                )
            ):
                continue
            for param_name, weight_name, shard_id in stacked_params_mapping:
                if weight_name not in name:
                    continue
                name = name.replace(weight_name, param_name)
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # if is_pp_missing_parameter(name, self):
                #     continue
                param = params_dict[name]
                weight_loader = param.weight_loader
                weight_loader(param, loaded_weight, shard_id)
                break
            else:
                # Skip loading extra bias for GPTQ models.
                if name.endswith(".bias") and name not in params_dict:
                    continue
                # if is_pp_missing_parameter(name, self):
                #     continue
                if name not in params_dict:
                    continue
                if name in params_dict.keys():
                    param = params_dict[name]
                    weight_loader = getattr(
                        param, "weight_loader", default_weight_loader
                    )
                    weight_loader(param, loaded_weight)
                else:
                    logger.warning(f"Parameter {name} not found in params_dict")

    @property
    def start_layer(self):
        return self.model.start_layer

    @property
    def end_layer(self):
        return self.model.end_layer

    def get_input_embeddings(self) -> nn.Embedding:
        return self.model.embed_tokens

    def get_module_name_from_weight_name(self, name):
        for param_name, weight_name, shard_id, num_shard in self.stacked_params_mapping:
            if weight_name in name:
                return (
                    name.replace(weight_name, param_name)[: -len(".weight")],
                    num_shard,
                )
        return name[: -len(".weight")], 1

    def get_num_params(self):
        params_dict = dict(self.named_parameters())
        return len(params_dict)

    def get_weights_by_name(
        self, name: str, truncate_size: int = 100, tp_size: int = 1
    ) -> Optional[torch.Tensor]:
        """Get the weights of the parameter by its name. Similar to `get_parameter` in Hugging Face.

        Only used for unit test with an unoptimized performance.
        For optimized performance, please use torch.save and torch.load.
        """
        try:
            if name == "lm_head.weight" and self.config.tie_word_embeddings:
                logger.info(
                    "word embedding is tied for this model, return embed_tokens.weight as lm_head.weight."
                )
                return (
                    self.model.embed_tokens.weight.cpu()
                    .to(torch.float32)
                    .numpy()
                    .tolist()[:truncate_size]
                )

            mapped_name = name
            mapped_shard_id = None
            for param_name, weight_name, shard_id in self.stacked_params_mapping:
                if weight_name in name:
                    mapped_name = name.replace(weight_name, param_name)
                    mapped_shard_id = shard_id
                    break
            params_dict = dict(self.named_parameters())
            param = params_dict[mapped_name]
            if mapped_shard_id is not None:
                if mapped_shard_id in ["q", "k", "v"]:
                    num_heads = self.config.num_attention_heads // tp_size
                    num_kv_heads = self.config.num_attention_heads // tp_size
                    head_dim = (
                        self.config.hidden_size // self.config.num_attention_heads
                    )
                    if mapped_shard_id == "q":
                        offset = 0
                        size = num_heads * head_dim
                    elif mapped_shard_id == "k":
                        offset = num_heads * head_dim
                        size = num_kv_heads * head_dim
                    elif mapped_shard_id == "v":
                        offset = (num_heads + num_kv_heads) * head_dim
                        size = num_kv_heads * head_dim
                    weight = param.data.narrow(0, offset, size)
                elif mapped_shard_id in [0, 1]:
                    intermediate_size = self.config.ffn_dim
                    slice_size = intermediate_size // tp_size
                    if mapped_shard_id == 0:  # gate_proj
                        offset = 0
                        size = slice_size
                    elif mapped_shard_id == 1:  # up_proj
                        offset = slice_size
                        size = slice_size

                    weight = param.data.narrow(0, offset, size)
                else:
                    weight = param.data
            else:
                weight = param.data
            if tp_size > 1 and ("o_proj" in name or "down_proj" in name):
                gathered_weights = [torch.zeros_like(weight) for _ in range(tp_size)]
                torch.distributed.all_gather(gathered_weights, weight)
                weight = torch.cat(gathered_weights, dim=1)
            return weight.cpu().to(torch.float32).numpy().tolist()[:truncate_size]

        except Exception:
            logger.error(
                f"Error getting weights by name {name} in OPTForCausalLM: {get_exception_traceback()}"
            )
            return None

    def get_embed_and_head(self):
        return self.model.embed_tokens.weight, self.lm_head.weight

    def set_embed_and_head(self, embed, head):
        del self.model.embed_tokens.weight
        del self.lm_head.weight
        self.model.embed_tokens.weight = embed
        self.lm_head.weight = head
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    def get_embed(self):
        return self.model.embed_tokens.weight

    def set_embed(self, embed):
        # NOTE: If draft hidden size != target hidden size, the embed weight cannot be shared for EAGLE3
        if (
            hasattr(self.config, "target_hidden_size")
            and self.config.target_hidden_size != self.config.hidden_size
        ):
            return
        del self.model.embed_tokens.weight
        self.model.embed_tokens.weight = embed
        torch.cuda.empty_cache()
        torch.cuda.synchronize()

    def load_kv_cache_scales(self, quantization_param_path: str) -> None:
        self.model.load_kv_cache_scales(quantization_param_path)


EntryClass = [OPTForCausalLM]
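The `stacked_params_mapping` in the new OPT implementation above folds the separate `q_proj`/`k_proj`/`v_proj` checkpoint tensors into the fused `qkv_proj` parameter, and `get_weights_by_name` recovers each shard again by slicing rows with `narrow()`. The standalone sketch below only illustrates that offset arithmetic with hypothetical OPT-125m-like sizes (hidden_size=768, 12 heads, tp_size=1); it is not part of the packaged file.

import torch

# Hypothetical OPT-125m-like shapes: hidden_size=768, 12 attention heads, tp_size=1.
hidden_size, num_heads, tp_size = 768, 12, 1
head_dim = hidden_size // num_heads
heads_per_rank = num_heads // tp_size

# Fused qkv_proj weight: rows are stacked as [q | k | v] along dim 0.
qkv_weight = torch.randn(3 * heads_per_rank * head_dim, hidden_size)


def shard_slice(shard_id: str) -> torch.Tensor:
    """Mirror the offset/size computation in OPTForCausalLM.get_weights_by_name."""
    q_size = heads_per_rank * head_dim
    kv_size = heads_per_rank * head_dim  # OPT has no GQA, so kv heads == query heads
    if shard_id == "q":
        offset, size = 0, q_size
    elif shard_id == "k":
        offset, size = q_size, kv_size
    elif shard_id == "v":
        offset, size = q_size + kv_size, kv_size
    else:
        raise ValueError(f"unknown shard id: {shard_id}")
    return qkv_weight.narrow(0, offset, size)


assert shard_slice("q").shape == (hidden_size, hidden_size)
assert shard_slice("v").shape == (hidden_size, hidden_size)

At tp_size > 1 the same arithmetic runs per rank, which is why the method divides the head counts by tp_size before computing offsets.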
sglang/srt/models/qwen2.py
CHANGED
@@ -16,7 +16,7 @@
 # Modify details for the adaptation of Qwen2 model.
 """Inference-only Qwen2 model compatible with HuggingFace weights."""
 import logging
-from typing import Any, Dict, Iterable, Optional, Tuple, Union
+from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
 
 import torch
 from torch import nn
@@ -431,7 +431,6 @@ class Qwen2ForCausalLM(nn.Module):
                 quant_config=quant_config,
                 prefix=add_prefix("lm_head", prefix),
             )
-
         else:
             # ranks other than the last rank will have a placeholder layer
             self.lm_head = PPMissingLayer()
@@ -452,6 +451,11 @@ class Qwen2ForCausalLM(nn.Module):
 
         self.logits_processor = LogitsProcessor(config)
         self.pooler = Pooler(pooling_type=PoolingType.LAST, normalize=True)
+        # For EAGLE3 support
+        self.capture_aux_hidden_states = False
+
+        # For EAGLE3 support
+        self.capture_aux_hidden_states = False
 
     def get_input_embedding(self, input_ids: torch.Tensor) -> torch.Tensor:
         return self.model.get_input_embedding(input_ids)
@@ -476,11 +480,22 @@ class Qwen2ForCausalLM(nn.Module):
             input_embeds,
             pp_proxy_tensors=pp_proxy_tensors,
         )
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
+
+        aux_hidden_states = None
+        if self.capture_aux_hidden_states:
+            hidden_states, aux_hidden_states = hidden_states
 
         if self.pp_group.is_last_rank:
             if not get_embedding:
                 return self.logits_processor(
-                    input_ids, hidden_states, self.lm_head, forward_batch
+                    input_ids,
+                    hidden_states,
+                    self.lm_head,
+                    forward_batch,
+                    aux_hidden_states,
                 )
             else:
                 return self.pooler(hidden_states, forward_batch)
@@ -619,5 +634,20 @@ class Qwen2ForCausalLM(nn.Module):
     def load_kv_cache_scales(self, quantization_param_path: str) -> None:
         self.model.load_kv_cache_scales(quantization_param_path)
 
+    def set_eagle3_layers_to_capture(self, layer_ids: Optional[List[int]] = None):
+        if not self.pp_group.is_last_rank:
+            return
+
+        self.capture_aux_hidden_states = True
+        if layer_ids is None:
+            num_layers = self.config.num_hidden_layers
+            self.model.layers_to_capture = [
+                2,
+                num_layers // 2,
+                num_layers - 3,
+            ]  # Specific layers for EAGLE3 support
+        else:
+            self.model.layers_to_capture = [val + 1 for val in layer_ids]
+
 
 EntryClass = Qwen2ForCausalLM
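The qwen2.py changes above wire in EAGLE3 speculative-decoding support: `set_eagle3_layers_to_capture` turns on `capture_aux_hidden_states`, after which `forward` unpacks a `(hidden_states, aux_hidden_states)` pair before calling the logits processor. The sketch below reproduces only the layer-selection arithmetic from the added method, with a hypothetical 28-layer config; it is an illustration, not the sglang API.

from typing import List, Optional


def eagle3_capture_layers(
    num_hidden_layers: int, layer_ids: Optional[List[int]] = None
) -> List[int]:
    """Replicate the layer selection added to Qwen2ForCausalLM in this release."""
    if layer_ids is None:
        # Default: one early, one middle, and one late layer.
        return [2, num_hidden_layers // 2, num_hidden_layers - 3]
    # Explicit layer ids are shifted by one, as in the diff above.
    return [val + 1 for val in layer_ids]


# For a hypothetical 28-layer model:
assert eagle3_capture_layers(28) == [2, 14, 25]
assert eagle3_capture_layers(28, layer_ids=[1, 13, 24]) == [2, 14, 25]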