sglang 0.5.4__py3-none-any.whl → 0.5.4.post2__py3-none-any.whl
This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.
- sglang/bench_one_batch.py +149 -34
- sglang/bench_serving.py +73 -14
- sglang/compile_deep_gemm.py +13 -7
- sglang/launch_server.py +2 -0
- sglang/srt/batch_invariant_ops/__init__.py +2 -0
- sglang/srt/batch_invariant_ops/batch_invariant_ops.py +221 -4
- sglang/srt/checkpoint_engine/__init__.py +9 -0
- sglang/srt/checkpoint_engine/update.py +317 -0
- sglang/srt/compilation/backend.py +1 -1
- sglang/srt/configs/__init__.py +2 -0
- sglang/srt/configs/deepseek_ocr.py +542 -10
- sglang/srt/configs/deepseekvl2.py +95 -194
- sglang/srt/configs/kimi_linear.py +160 -0
- sglang/srt/configs/mamba_utils.py +66 -0
- sglang/srt/configs/model_config.py +30 -7
- sglang/srt/constants.py +7 -0
- sglang/srt/debug_utils/tensor_dump_forward_hook.py +149 -0
- sglang/srt/disaggregation/decode.py +34 -6
- sglang/srt/disaggregation/nixl/conn.py +2 -2
- sglang/srt/disaggregation/prefill.py +25 -3
- sglang/srt/distributed/device_communicators/custom_all_reduce.py +3 -1
- sglang/srt/distributed/parallel_state.py +9 -12
- sglang/srt/entrypoints/engine.py +31 -20
- sglang/srt/entrypoints/grpc_server.py +0 -1
- sglang/srt/entrypoints/http_server.py +94 -94
- sglang/srt/entrypoints/openai/protocol.py +7 -1
- sglang/srt/entrypoints/openai/serving_chat.py +42 -0
- sglang/srt/entrypoints/openai/serving_completions.py +10 -0
- sglang/srt/entrypoints/openai/serving_embedding.py +1 -0
- sglang/srt/environ.py +23 -2
- sglang/srt/eplb/expert_distribution.py +64 -1
- sglang/srt/eplb/expert_location.py +106 -36
- sglang/srt/function_call/function_call_parser.py +2 -0
- sglang/srt/function_call/minimax_m2.py +367 -0
- sglang/srt/grpc/compile_proto.py +3 -0
- sglang/srt/layers/activation.py +6 -0
- sglang/srt/layers/attention/ascend_backend.py +233 -5
- sglang/srt/layers/attention/attention_registry.py +3 -0
- sglang/srt/layers/attention/fla/chunk_delta_h.py +61 -32
- sglang/srt/layers/attention/fla/fused_recurrent.py +17 -4
- sglang/srt/layers/attention/fla/kda.py +1359 -0
- sglang/srt/layers/attention/fla/layernorm_gated.py +7 -1
- sglang/srt/layers/attention/flashattention_backend.py +19 -8
- sglang/srt/layers/attention/flashinfer_backend.py +10 -1
- sglang/srt/layers/attention/flashinfer_mla_backend.py +21 -11
- sglang/srt/layers/attention/flashmla_backend.py +1 -1
- sglang/srt/layers/attention/hybrid_linear_attn_backend.py +223 -0
- sglang/srt/layers/attention/mamba/mamba.py +20 -11
- sglang/srt/layers/attention/nsa/dequant_k_cache.py +138 -6
- sglang/srt/layers/attention/nsa/nsa_indexer.py +45 -22
- sglang/srt/layers/attention/nsa/quant_k_cache.py +44 -12
- sglang/srt/layers/attention/nsa/transform_index.py +1 -1
- sglang/srt/layers/attention/nsa_backend.py +157 -23
- sglang/srt/layers/attention/triton_backend.py +4 -1
- sglang/srt/layers/attention/trtllm_mha_backend.py +10 -4
- sglang/srt/layers/attention/trtllm_mla_backend.py +11 -15
- sglang/srt/layers/attention/utils.py +78 -0
- sglang/srt/layers/communicator.py +24 -1
- sglang/srt/layers/deep_gemm_wrapper/compile_utils.py +1 -1
- sglang/srt/layers/layernorm.py +35 -6
- sglang/srt/layers/logits_processor.py +9 -20
- sglang/srt/layers/moe/cutlass_w4a8_moe.py +138 -0
- sglang/srt/layers/moe/ep_moe/kernels.py +194 -0
- sglang/srt/layers/moe/ep_moe/layer.py +78 -289
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128].json +146 -0
- sglang/srt/layers/moe/fused_moe_triton/configs/triton_3_4_0/E=257,N=256,device_name=NVIDIA_H20,dtype=fp8_w8a8,block_shape=[128, 128]_down.json +164 -0
- sglang/srt/layers/moe/fused_moe_triton/fused_moe.py +68 -22
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_config.py +43 -3
- sglang/srt/layers/moe/fused_moe_triton/fused_moe_triton_kernels.py +106 -26
- sglang/srt/layers/moe/fused_moe_triton/layer.py +3 -3
- sglang/srt/layers/moe/fused_moe_triton/triton_kernels_moe.py +7 -4
- sglang/srt/layers/moe/moe_runner/deep_gemm.py +340 -55
- sglang/srt/layers/moe/moe_runner/runner.py +3 -0
- sglang/srt/layers/moe/moe_runner/triton_kernels.py +194 -0
- sglang/srt/layers/moe/token_dispatcher/__init__.py +4 -4
- sglang/srt/layers/moe/token_dispatcher/base.py +11 -5
- sglang/srt/layers/moe/token_dispatcher/deepep.py +25 -18
- sglang/srt/layers/moe/token_dispatcher/standard.py +1 -1
- sglang/srt/layers/moe/topk.py +35 -10
- sglang/srt/layers/moe/utils.py +3 -4
- sglang/srt/layers/pooler.py +21 -2
- sglang/srt/layers/quantization/__init__.py +13 -84
- sglang/srt/layers/quantization/auto_round.py +394 -0
- sglang/srt/layers/quantization/awq.py +0 -3
- sglang/srt/layers/quantization/base_config.py +7 -0
- sglang/srt/layers/quantization/fp8.py +68 -63
- sglang/srt/layers/quantization/fp8_kernel.py +1 -1
- sglang/srt/layers/quantization/fp8_utils.py +2 -2
- sglang/srt/layers/quantization/gguf.py +566 -0
- sglang/srt/layers/quantization/modelopt_quant.py +168 -11
- sglang/srt/layers/quantization/mxfp4.py +30 -38
- sglang/srt/layers/quantization/unquant.py +23 -45
- sglang/srt/layers/quantization/w4afp8.py +38 -2
- sglang/srt/layers/radix_attention.py +5 -2
- sglang/srt/layers/rotary_embedding.py +130 -46
- sglang/srt/layers/sampler.py +12 -1
- sglang/srt/lora/lora_registry.py +9 -0
- sglang/srt/managers/async_mm_data_processor.py +122 -0
- sglang/srt/managers/data_parallel_controller.py +30 -3
- sglang/srt/managers/detokenizer_manager.py +3 -0
- sglang/srt/managers/io_struct.py +29 -4
- sglang/srt/managers/multi_tokenizer_mixin.py +22 -1
- sglang/srt/managers/schedule_batch.py +74 -15
- sglang/srt/managers/scheduler.py +185 -144
- sglang/srt/managers/scheduler_metrics_mixin.py +22 -14
- sglang/srt/managers/scheduler_output_processor_mixin.py +40 -3
- sglang/srt/managers/scheduler_pp_mixin.py +7 -2
- sglang/srt/managers/scheduler_profiler_mixin.py +3 -4
- sglang/srt/managers/scheduler_runtime_checker_mixin.py +45 -0
- sglang/srt/managers/scheduler_update_weights_mixin.py +18 -3
- sglang/srt/managers/session_controller.py +6 -5
- sglang/srt/managers/tokenizer_manager.py +165 -78
- sglang/srt/managers/tp_worker.py +24 -1
- sglang/srt/mem_cache/base_prefix_cache.py +23 -4
- sglang/srt/mem_cache/common.py +1 -0
- sglang/srt/mem_cache/hicache_storage.py +7 -1
- sglang/srt/mem_cache/memory_pool.py +253 -57
- sglang/srt/mem_cache/memory_pool_host.py +12 -5
- sglang/srt/mem_cache/radix_cache.py +4 -0
- sglang/srt/mem_cache/storage/hf3fs/mini_3fs_metadata_server.py +3 -2
- sglang/srt/mem_cache/storage/mooncake_store/mooncake_store.py +1 -1
- sglang/srt/metrics/collector.py +46 -3
- sglang/srt/model_executor/cuda_graph_runner.py +15 -3
- sglang/srt/model_executor/forward_batch_info.py +55 -14
- sglang/srt/model_executor/model_runner.py +77 -170
- sglang/srt/model_executor/npu_graph_runner.py +7 -3
- sglang/srt/model_executor/piecewise_cuda_graph_runner.py +22 -12
- sglang/srt/model_loader/weight_utils.py +1 -1
- sglang/srt/models/bailing_moe.py +9 -2
- sglang/srt/models/deepseek_nextn.py +11 -2
- sglang/srt/models/deepseek_v2.py +296 -78
- sglang/srt/models/glm4.py +391 -77
- sglang/srt/models/glm4_moe.py +322 -354
- sglang/srt/models/glm4_moe_nextn.py +4 -14
- sglang/srt/models/glm4v.py +196 -55
- sglang/srt/models/glm4v_moe.py +29 -197
- sglang/srt/models/gpt_oss.py +1 -10
- sglang/srt/models/kimi_linear.py +678 -0
- sglang/srt/models/llama4.py +1 -1
- sglang/srt/models/llama_eagle3.py +11 -1
- sglang/srt/models/longcat_flash.py +2 -2
- sglang/srt/models/minimax_m2.py +922 -0
- sglang/srt/models/nvila.py +355 -0
- sglang/srt/models/nvila_lite.py +184 -0
- sglang/srt/models/qwen2.py +23 -2
- sglang/srt/models/qwen2_moe.py +30 -15
- sglang/srt/models/qwen3.py +35 -5
- sglang/srt/models/qwen3_moe.py +18 -12
- sglang/srt/models/qwen3_next.py +7 -0
- sglang/srt/multimodal/customized_mm_processor_utils.py +35 -0
- sglang/srt/multimodal/processors/base_processor.py +1 -0
- sglang/srt/multimodal/processors/glm4v.py +1 -1
- sglang/srt/multimodal/processors/{vila.py → nvila.py} +32 -24
- sglang/srt/multimodal/processors/points_v15_chat.py +2 -2
- sglang/srt/multiplex/multiplexing_mixin.py +209 -0
- sglang/srt/multiplex/pdmux_context.py +164 -0
- sglang/srt/parser/conversation.py +7 -1
- sglang/srt/parser/reasoning_parser.py +28 -1
- sglang/srt/sampling/custom_logit_processor.py +67 -1
- sglang/srt/sampling/penaltylib/frequency_penalty.py +6 -8
- sglang/srt/sampling/penaltylib/min_new_tokens.py +7 -8
- sglang/srt/sampling/penaltylib/orchestrator.py +43 -3
- sglang/srt/sampling/penaltylib/presence_penalty.py +6 -8
- sglang/srt/server_args.py +459 -199
- sglang/srt/single_batch_overlap.py +2 -4
- sglang/srt/speculative/draft_utils.py +16 -0
- sglang/srt/speculative/eagle_info.py +42 -36
- sglang/srt/speculative/eagle_info_v2.py +68 -25
- sglang/srt/speculative/eagle_utils.py +261 -16
- sglang/srt/speculative/eagle_worker.py +11 -3
- sglang/srt/speculative/eagle_worker_v2.py +15 -9
- sglang/srt/speculative/spec_info.py +305 -31
- sglang/srt/speculative/spec_utils.py +44 -8
- sglang/srt/tracing/trace.py +121 -12
- sglang/srt/utils/common.py +142 -74
- sglang/srt/utils/hf_transformers_utils.py +38 -12
- sglang/srt/utils/torch_memory_saver_adapter.py +20 -0
- sglang/test/kits/radix_cache_server_kit.py +50 -0
- sglang/test/runners.py +31 -7
- sglang/test/simple_eval_common.py +5 -3
- sglang/test/simple_eval_humaneval.py +1 -0
- sglang/test/simple_eval_math.py +1 -0
- sglang/test/simple_eval_mmlu.py +1 -0
- sglang/test/simple_eval_mmmu_vlm.py +1 -0
- sglang/test/test_deterministic.py +235 -12
- sglang/test/test_deterministic_utils.py +2 -1
- sglang/test/test_utils.py +7 -1
- sglang/version.py +1 -1
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/METADATA +15 -28
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/RECORD +194 -175
- sglang/srt/models/vila.py +0 -306
- /sglang/test/{kit_matched_stop.py → kits/matched_stop_kit.py} +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/WHEEL +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/licenses/LICENSE +0 -0
- {sglang-0.5.4.dist-info → sglang-0.5.4.post2.dist-info}/top_level.txt +0 -0

sglang/srt/models/nvila.py ADDED
@@ -0,0 +1,355 @@
+import itertools
+import math
+from collections.abc import Iterable
+from typing import Any
+
+import einops
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+from transformers.configuration_utils import PretrainedConfig
+from transformers.modeling_outputs import BaseModelOutputWithPooling
+from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
+from transformers.models.siglip import SiglipVisionConfig, SiglipVisionModel
+
+import sglang.srt.managers.mm_utils as mm_utils
+import sglang.srt.model_loader.weight_utils as weight_utils
+import sglang.srt.utils as utils
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.managers.mm_utils import MultiModalityDataPaddingPatternMultimodalTokens
+from sglang.srt.managers.schedule_batch import (
+    Modality,
+    MultimodalDataItem,
+    MultimodalInputs,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.models.qwen2 import Qwen2ForCausalLM
+
+MM_HIDDEN_SIZE = 3456
+
+
+class NVILAConfig(PretrainedConfig):
+    model_type = "nvila"
+    sub_configs = {
+        "text_config": Qwen2Config,
+        "vision_config": SiglipVisionConfig,
+    }
+    _auto_class = "AutoConfig"
+
+    def __init__(
+        self,
+        *,
+        text_config: dict[str, Any] | None = None,
+        vision_config: dict[str, Any] | None = None,
+        image_token_id: int | None = None,
+        video_token_id: int | None = None,
+        **kwargs,
+    ):
+        self.text_config = (
+            Qwen2Config(**text_config) if text_config is not None else Qwen2Config()
+        )
+        self.vision_config = (
+            SiglipVisionConfig(**vision_config)
+            if vision_config is not None
+            else SiglipVisionConfig()
+        )
+
+        self.image_token_id = image_token_id if image_token_id is not None else -1
+        self.video_token_id = video_token_id if video_token_id is not None else -1
+
+        super().__init__(**kwargs)
+
+
+class NVILAMultiModalProjectorDownsampleBlock(nn.Module):
+    def forward(self, x: Tensor) -> Tensor:
+        batch_size, sequence_length, hidden_size = x.shape
+
+        feat_size = math.isqrt(sequence_length)
+
+        features = x.reshape(batch_size, feat_size, feat_size, hidden_size)
+
+        pad_after = feat_size % 2
+        if pad_after > 0:
+            features = F.pad(features, (0, 0, 0, pad_after, 0, pad_after))
+            feat_size = feat_size + pad_after
+
+        features = features.reshape(
+            batch_size, feat_size // 2, 2, feat_size // 2, 2, hidden_size
+        )
+        features = features.permute(0, 1, 3, 2, 4, 5).contiguous()
+        features = features.reshape(batch_size, -1, 4 * hidden_size)
+
+        return features
+
+
+class NVILAMultiModalProjector(nn.Module):
+    def __init__(self, config: NVILAConfig):
+        super().__init__()
+
+        self.layers = nn.Sequential(
+            NVILAMultiModalProjectorDownsampleBlock(),
+            nn.LayerNorm(MM_HIDDEN_SIZE * 4),
+            nn.Linear(MM_HIDDEN_SIZE * 4, config.text_config.hidden_size),
+            nn.GELU(),
+            nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size),
+        )
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self.layers(x)
+
+
+class NVILAForConditionalGeneration(nn.Module):
+    def __init__(
+        self,
+        config: NVILAConfig,
+        quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        self.config = config
+
+        self.vision_tower = SiglipVisionModel(config.vision_config)
+        self.mm_projector = NVILAMultiModalProjector(config)
+        self.llm = Qwen2ForCausalLM(
+            config=config.text_config,
+            quant_config=quant_config,
+            prefix=utils.add_prefix("llm", prefix),
+        )
+
+    def forward(
+        self,
+        input_ids: Tensor,
+        positions: Tensor,
+        forward_batch: ForwardBatch,
+        get_embedding: bool = False,
+    ) -> LogitsProcessorOutput:
+        output = mm_utils.general_mm_embed_routine(
+            input_ids=input_ids,
+            forward_batch=forward_batch,
+            language_model=self.llm,
+            data_embedding_funcs={
+                Modality.IMAGE: self.get_image_feature,
+                Modality.VIDEO: self.get_image_feature,
+            },
+            get_embedding=get_embedding,
+            positions=positions,
+        )
+
+        assert isinstance(output, LogitsProcessorOutput)
+
+        return output
+
+    def get_image_feature(self, mm_input: list[MultimodalDataItem]) -> Tensor:
+        block_sizes = (
+            list(
+                itertools.chain.from_iterable(
+                    x.block_sizes for x in mm_input if hasattr(x, "block_sizes")
+                )
+            )
+            or None
+        )
+        pixel_values = torch.cat([torch.tensor(x.feature) for x in mm_input], dim=0)
+
+        vision_tower_output: BaseModelOutputWithPooling = self.vision_tower(
+            pixel_values.to(
+                device=self.vision_tower.device, dtype=self.vision_tower.dtype
+            ),
+            output_hidden_states=True,
+        )
+        assert vision_tower_output.hidden_states is not None
+
+        vision_features: Tensor = vision_tower_output.hidden_states[-2]
+
+        vision_features_list, block_sizes = merge_features_for_dynamic_s2(
+            vision_features,
+            block_sizes=(
+                block_sizes
+                if block_sizes is not None
+                else [None] * vision_features.shape[0]
+            ),
+            resize_output_to_scale_idx=-1,
+            scales=[448, 896, 1344],
+        )
+
+        vision_features_list = [
+            split_chessboard(x, block_size[0], block_size[1])
+            for x, block_size in zip(vision_features_list, block_sizes)
+        ]
+
+        vision_features = torch.cat(
+            [einops.rearrange(x, "b c h w -> b (h w) c") for x in vision_features_list]
+        )
+
+        vision_features = self.mm_projector(vision_features)
+
+        vision_features_list = list(
+            vision_features.split(
+                [block_size[0] * block_size[1] for block_size in block_sizes], dim=0
+            )
+        )
+        vision_features_list = [
+            merge_chessboard(x, block_size[0], block_size[1])
+            for x, block_size in zip(vision_features_list, block_sizes)
+        ]
+
+        vision_features = torch.stack(
+            [einops.rearrange(x, "1 c h w -> (h w) c") for x in vision_features_list]
+        )
+
+        vision_features = einops.rearrange(vision_features, "n p d -> (n p) d")
+
+        return vision_features
+
+    def load_weights(self, weights: Iterable[tuple[str, Tensor]]) -> None:
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in weights:
+            if name.startswith("llm."):
+                self.llm.load_weights([(name[len("llm.") :], loaded_weight)])
+            else:
+                param = params_dict[name]
+                weight_loader = getattr(
+                    param, "weight_loader", weight_utils.default_weight_loader
+                )
+                weight_loader(param, loaded_weight)
+
+    def pad_input_ids(
+        self, input_ids: list[int], mm_inputs: MultimodalInputs
+    ) -> list[int]:
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
+        return pattern.pad_input_tokens(input_ids, mm_inputs)
+
+
+def merge_chessboard(x, num_split_h, num_split_w):
+    """
+    x: b * n * c or b * h * w * c
+    out: b * c * h * w
+    Assuming x contains num_split**2 sub-squares concatenated along batch dimension, merge the sub-squares back to the original whole square.
+    """
+    B = x.shape[0]
+    if x.dim() == 3:
+        N = x.shape[1]
+        x = einops.rearrange(
+            x, "b (h w) c -> b c h w", h=math.isqrt(N), w=math.isqrt(N)
+        )
+
+    assert B % (num_split_h * num_split_w) == 0
+    b = B // (num_split_h * num_split_w)
+
+    x_merge = torch.cat(
+        [
+            torch.cat(
+                [
+                    x[(i * num_split_w + j) * b : (i * num_split_w + j + 1) * b]
+                    for j in range(num_split_w)
+                ],
+                dim=-1,
+            )
+            for i in range(num_split_h)
+        ],
+        dim=-2,
+    )
+
+    return x_merge
+
+
+def merge_features_for_dynamic_s2(
+    image_features, block_sizes, *, scales, resize_output_to_scale_idx
+):
+    image_features_each_image = []
+    new_block_sizes = []
+    block_cnt = 0
+    for block_size_each_image in block_sizes:
+        if block_size_each_image is None:
+            cur_features = image_features[block_cnt : block_cnt + 1]
+            cur_features = einops.rearrange(
+                cur_features,
+                "1 (h w) c -> 1 c h w",
+                h=math.isqrt(cur_features.shape[1]),
+            )
+            cur_features = cur_features.repeat(1, len(scales), 1, 1)
+            image_features_each_image.append(cur_features)
+            new_block_sizes.append((1, 1))
+            block_cnt += 1
+        else:
+            cur_features_each_scale = []
+            for scale in scales[:-1]:
+                num_blocks_this_scale = (scale // scales[0]) ** 2
+                cur_features_each_scale.append(
+                    merge_chessboard(
+                        image_features[block_cnt : block_cnt + num_blocks_this_scale],
+                        num_split_h=scale // scales[0],
+                        num_split_w=scale // scales[0],
+                    )
+                )  # 1 * C * H * W
+                block_cnt += num_blocks_this_scale
+            num_blocks_last_scale = block_size_each_image[0] * block_size_each_image[1]
+            cur_features_each_scale.append(
+                merge_chessboard(
+                    image_features[block_cnt : block_cnt + num_blocks_last_scale],
+                    num_split_h=block_size_each_image[0],
+                    num_split_w=block_size_each_image[1],
+                )
+            )  # 1 * C * H * W
+            block_cnt += num_blocks_last_scale
+
+            # resize and concat features from different scales
+            output_size = cur_features_each_scale[resize_output_to_scale_idx].shape[-2:]
+            cur_features = torch.cat(
+                [
+                    F.interpolate(
+                        cur_features_each_scale[i].to(torch.float32),
+                        size=output_size,
+                        mode="area",
+                    ).to(cur_features_each_scale[i].dtype)
+                    for i in range(len(cur_features_each_scale))
+                ],
+                dim=1,
+            )
+
+            image_features_each_image.append(cur_features)
+
+            if (
+                resize_output_to_scale_idx == len(scales) - 1
+                or resize_output_to_scale_idx == -1
+            ):
+                new_block_sizes.append(block_size_each_image)
+            else:
+                new_block_sizes.append(
+                    (
+                        scales[resize_output_to_scale_idx] // scales[0],
+                        scales[resize_output_to_scale_idx] // scales[0],
+                    )
+                )
+
+    assert block_cnt == len(
+        image_features
+    ), f"The number of blocks ({block_cnt}) does not match length of image_features ({len(image_features)})!"
+
+    return image_features_each_image, new_block_sizes
+
+
+def split_chessboard(x, num_split_h, num_split_w):
+    """
+    x: b * c * h * w
+    out: b * c * h * w
+    Dividing x into num_split**2 sub-squares, and concatenate all the sub-squares on the batch dimension
+    """
+    B, C, H, W = x.shape
+    assert H % num_split_h == 0 and W % num_split_w == 0
+    h, w = H // num_split_h, W // num_split_w
+    x_split = torch.cat(
+        [
+            x[:, :, i * h : (i + 1) * h, j * w : (j + 1) * w]
+            for i in range(num_split_h)
+            for j in range(num_split_w)
+        ],
+        dim=0,
+    )
+    return x_split
+
+
+EntryClass = [NVILAForConditionalGeneration]
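
The projector's downsample block above is a 2x2 space-to-channel step (pixel unshuffle): it trades 4x fewer visual tokens for 4x wider features before the MLP, and the S2 helpers tile large images into scale-sized blocks and merge multi-scale features back together. A standalone sanity check (made-up shapes, not part of the wheel) showing that the reshape/permute pipeline equals a single einops rearrange, and that split_chessboard and merge_chessboard invert each other:

```python
import einops
import torch

# Assumes split_chessboard / merge_chessboard from the listing above are in scope.

# 1) The downsample block's reshape/permute pipeline is a 2x2 pixel-unshuffle.
b, fs, c = 2, 16, 8                       # batch, 16x16 patch grid, channels
x = torch.randn(b, fs * fs, c)
ref = (
    x.reshape(b, fs, fs, c)
    .reshape(b, fs // 2, 2, fs // 2, 2, c)
    .permute(0, 1, 3, 2, 4, 5)
    .reshape(b, -1, 4 * c)
)
alt = einops.rearrange(
    x, "b (h p1 w p2) c -> b (h w) (p1 p2 c)", p1=2, p2=2, w=fs // 2
)
assert torch.equal(ref, alt)              # 256 tokens -> 64 tokens, 4x channels

# 2) Chessboard split/merge round-trips a feature map.
feat = torch.randn(1, 4, 12, 12)
tiles = split_chessboard(feat, 3, 3)      # (9, 4, 4, 4): 3x3 tiles on batch dim
assert torch.equal(merge_chessboard(tiles, 3, 3), feat)
```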
sglang/srt/models/nvila_lite.py ADDED
@@ -0,0 +1,184 @@
+import math
+from collections.abc import Iterable
+from typing import Any
+
+import einops
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import Tensor
+from transformers.configuration_utils import PretrainedConfig
+from transformers.modeling_outputs import BaseModelOutputWithPooling
+from transformers.models.qwen2.configuration_qwen2 import Qwen2Config
+from transformers.models.siglip import SiglipVisionConfig, SiglipVisionModel
+
+import sglang.srt.managers.mm_utils as mm_utils
+import sglang.srt.model_loader.weight_utils as weight_utils
+import sglang.srt.utils as utils
+from sglang.srt.layers.logits_processor import LogitsProcessorOutput
+from sglang.srt.layers.quantization.base_config import QuantizationConfig
+from sglang.srt.managers.mm_utils import MultiModalityDataPaddingPatternMultimodalTokens
+from sglang.srt.managers.schedule_batch import (
+    Modality,
+    MultimodalDataItem,
+    MultimodalInputs,
+)
+from sglang.srt.model_executor.forward_batch_info import ForwardBatch
+from sglang.srt.models.qwen2 import Qwen2ForCausalLM
+
+MM_HIDDEN_SIZE = 1152
+
+
+class NVILALiteConfig(PretrainedConfig):
+    model_type = "nvila_lite"
+    sub_configs = {
+        "text_config": Qwen2Config,
+        "vision_config": SiglipVisionConfig,
+    }
+    _auto_class = "AutoConfig"
+
+    def __init__(
+        self,
+        *,
+        text_config: dict[str, Any] | None = None,
+        vision_config: dict[str, Any] | None = None,
+        image_token_id: int | None = None,
+        video_token_id: int | None = None,
+        **kwargs,
+    ):
+        self.text_config = (
+            Qwen2Config(**text_config) if text_config is not None else Qwen2Config()
+        )
+        self.vision_config = (
+            SiglipVisionConfig(**vision_config)
+            if vision_config is not None
+            else SiglipVisionConfig()
+        )
+
+        self.image_token_id = image_token_id if image_token_id is not None else -1
+        self.video_token_id = video_token_id if video_token_id is not None else -1
+
+        super().__init__(**kwargs)
+
+
+class NVILALiteMultiModalProjectorDownsampleBlock(nn.Module):
+    def forward(self, x: Tensor) -> Tensor:
+        batch_size, sequence_length, hidden_size = x.shape
+
+        feat_size = math.isqrt(sequence_length)
+
+        features = x.reshape(batch_size, feat_size, feat_size, hidden_size)
+
+        pad_after = (3 - feat_size % 3) % 3
+        if pad_after > 0:
+            features = F.pad(features, (0, 0, 0, pad_after, 0, pad_after))
+            feat_size = feat_size + pad_after
+
+        features = features.reshape(
+            batch_size, feat_size // 3, 3, feat_size // 3, 3, hidden_size
+        )
+        features = features.permute(0, 1, 3, 2, 4, 5).contiguous()
+        features = features.reshape(batch_size, -1, 9 * hidden_size)
+
+        return features
+
+
+class NVILALiteMultiModalProjector(nn.Module):
+    def __init__(self, config: NVILALiteConfig):
+        super().__init__()
+
+        self.layers = nn.Sequential(
+            NVILALiteMultiModalProjectorDownsampleBlock(),
+            nn.LayerNorm(MM_HIDDEN_SIZE * 9),
+            nn.Linear(MM_HIDDEN_SIZE * 9, MM_HIDDEN_SIZE * 3),
+            nn.GELU(),
+            nn.LayerNorm(MM_HIDDEN_SIZE * 3),
+            nn.Linear(MM_HIDDEN_SIZE * 3, config.text_config.hidden_size),
+            nn.GELU(),
+            nn.Linear(config.text_config.hidden_size, config.text_config.hidden_size),
+        )
+
+    def forward(self, x: Tensor) -> Tensor:
+        return self.layers(x)
+
+
+class NVILALiteForConditionalGeneration(nn.Module):
+    def __init__(
+        self,
+        config: NVILALiteConfig,
+        quant_config: QuantizationConfig | None = None,
+        prefix: str = "",
+    ) -> None:
+        super().__init__()
+
+        self.config = config
+
+        self.vision_tower = SiglipVisionModel(config.vision_config)
+        self.mm_projector = NVILALiteMultiModalProjector(config)
+        self.llm = Qwen2ForCausalLM(
+            config=config.text_config,
+            quant_config=quant_config,
+            prefix=utils.add_prefix("llm", prefix),
+        )
+
+    def forward(
+        self,
+        input_ids: Tensor,
+        positions: Tensor,
+        forward_batch: ForwardBatch,
+        get_embedding: bool = False,
+    ) -> LogitsProcessorOutput:
+        output = mm_utils.general_mm_embed_routine(
+            input_ids=input_ids,
+            forward_batch=forward_batch,
+            language_model=self.llm,
+            data_embedding_funcs={
+                Modality.IMAGE: self.get_image_feature,
+                Modality.VIDEO: self.get_image_feature,
+            },
+            get_embedding=get_embedding,
+            positions=positions,
+        )
+
+        assert isinstance(output, LogitsProcessorOutput)
+
+        return output
+
+    def get_image_feature(self, mm_input: list[MultimodalDataItem]) -> Tensor:
+        pixel_values = torch.cat([torch.tensor(x.feature) for x in mm_input], dim=0)
+
+        vision_tower_output: BaseModelOutputWithPooling = self.vision_tower(
+            pixel_values,
+            output_hidden_states=True,
+        )
+        assert vision_tower_output.hidden_states is not None
+
+        vision_features = vision_tower_output.hidden_states[-2]
+
+        vision_features = self.mm_projector(vision_features)
+
+        vision_features = einops.rearrange(vision_features, "n p d -> (n p) d")
+
+        return vision_features
+
+    def load_weights(self, weights: Iterable[tuple[str, Tensor]]) -> None:
+        params_dict = dict(self.named_parameters())
+
+        for name, loaded_weight in weights:
+            if name.startswith("llm."):
+                self.llm.load_weights([(name[len("llm.") :], loaded_weight)])
+            else:
+                param = params_dict[name]
+                weight_loader = getattr(
+                    param, "weight_loader", weight_utils.default_weight_loader
+                )
+                weight_loader(param, loaded_weight)
+
+    def pad_input_ids(
+        self, input_ids: list[int], mm_inputs: MultimodalInputs
+    ) -> list[int]:
+        pattern = MultiModalityDataPaddingPatternMultimodalTokens()
+        return pattern.pad_input_tokens(input_ids, mm_inputs)
+
+
+EntryClass = [NVILALiteForConditionalGeneration]
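
NVILA-Lite skips the multi-scale S2 path entirely: get_image_feature is a single SigLIP pass, and the projector applies a 3x3 (rather than 2x2) space-to-channel downsample to the raw 1152-dim features, which is why MM_HIDDEN_SIZE is 1152 here versus 3456 (three concatenated 1152-dim scales) in nvila.py. A rough token-count sketch of the two downsamples; the grid sizes are illustrative assumptions, not values read from the shipped configs:

```python
import math

def tokens_after_downsample(num_patches: int, factor: int) -> int:
    # Mirrors the padding rule in both downsample blocks: pad the patch grid
    # up to a multiple of `factor`, then emit one token per factor x factor tile.
    side = math.isqrt(num_patches)
    side += (factor - side % factor) % factor
    return (side // factor) ** 2

print(tokens_after_downsample(27 * 27, 3))  # e.g. 27x27 grid -> 81 tokens (3x3)
print(tokens_after_downsample(32 * 32, 2))  # e.g. 32x32 grid -> 256 tokens (2x2)
```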
sglang/srt/models/qwen2.py CHANGED
@@ -49,6 +49,7 @@ from sglang.srt.model_loader.weight_utils import (
     default_weight_loader,
     kv_cache_scales_loader,
 )
+from sglang.srt.server_args import get_global_server_args
 from sglang.srt.utils import add_prefix, make_layers
 
 Qwen2Config = None
@@ -89,6 +90,9 @@ class Qwen2MLP(nn.Module):
         self.act_fn = SiluAndMul()
 
     def forward(self, x):
+        if get_global_server_args().rl_on_policy_target == "fsdp":
+            x = x.bfloat16()
+
         gate_up, _ = self.gate_up_proj(x)
         x = self.act_fn(gate_up)
         x, _ = self.down_proj(x)
@@ -275,6 +279,11 @@ class Qwen2Model(nn.Module):
                 quant_config=quant_config,
                 enable_tp=not is_dp_attention_enabled(),
                 prefix=add_prefix("embed_tokens", prefix),
+                params_dtype=(
+                    torch.float32
+                    if get_global_server_args().rl_on_policy_target == "fsdp"
+                    else None
+                ),
             )
         else:
             self.embed_tokens = PPMissingLayer()
@@ -295,7 +304,19 @@ class Qwen2Model(nn.Module):
             prefix=add_prefix("layers", prefix),
         )
         if self.pp_group.is_last_rank:
-            self.norm = RMSNorm(config.hidden_size, eps=config.rms_norm_eps)
+            norm_kwargs = (
+                dict(
+                    weight_dtype=torch.float32,
+                    cast_x_before_out_mul=True,
+                    override_orig_dtype=torch.float32,
+                    fp32_residual=True,
+                )
+                if get_global_server_args().rl_on_policy_target == "fsdp"
+                else {}
+            )
+            self.norm = RMSNorm(
+                config.hidden_size, eps=config.rms_norm_eps, **norm_kwargs
+            )
         else:
             self.norm = PPMissingLayer(return_tuple=True)
 
@@ -441,7 +462,7 @@ class Qwen2ForCausalLM(nn.Module):
                 self.pp_group.send(
                     self.model.embed_tokens.weight, dst=self.pp_group.last_rank
                 )
-            else:
+            elif self.pp_group.is_last_rank:
                 emb_token_weight = self.pp_group.recv(
                     size=(config.vocab_size, config.hidden_size),
                     dtype=next(self.model.parameters()).dtype,
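
All of the qwen2.py changes are gated on one new condition, get_global_server_args().rl_on_policy_target == "fsdp": the MLP casts its input to bfloat16, the token embedding is created with params_dtype=torch.float32, and the final RMSNorm keeps its weight and residual math in float32. Judging by the flag name, the goal is to pin down numerics so that logprobs from the inference engine reproduce an FSDP trainer's during on-policy RL. A freestanding numerics sketch (not sglang code) of why the norm's dtype matters:

```python
import torch

def rms_norm(x: torch.Tensor, w: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    # Reference RMSNorm: scale by the reciprocal root-mean-square, then by w.
    return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + eps) * w

x = torch.randn(4, 1024)
w = torch.randn(1024).abs()

out_fp32 = rms_norm(x, w)
out_bf16 = rms_norm(x.bfloat16(), w.bfloat16()).float()
print((out_fp32 - out_bf16).abs().max())  # nonzero: bf16 rounding shifts the
                                          # outputs, so logprobs no longer match
```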
sglang/srt/models/qwen2_moe.py CHANGED
@@ -473,10 +473,16 @@ class Qwen2MoeDecoderLayer(nn.Module):
         hidden_states: torch.Tensor,
         forward_batch: ForwardBatch,
         residual: Optional[torch.Tensor],
+        captured_last_layer_outputs: Optional[List[torch.Tensor]] = None,
     ) -> Tuple[torch.Tensor, torch.Tensor]:
 
-        hidden_states, residual = self.layer_communicator.prepare_attn(
-            hidden_states, residual, forward_batch
+        hidden_states, residual = (
+            self.layer_communicator.prepare_attn_and_capture_last_layer_outputs(
+                hidden_states,
+                residual,
+                forward_batch,
+                captured_last_layer_outputs=captured_last_layer_outputs,
+            )
         )
 
         if hidden_states.shape[0] != 0:
@@ -553,6 +559,11 @@ class Qwen2MoeModel(nn.Module):
         # For EAGLE3 support
         self.layers_to_capture = []
 
+    def set_eagle3_layers_to_capture(self, layers_to_capture: List[int]):
+        self.layers_to_capture = layers_to_capture
+        for layer_id in self.layers_to_capture:
+            setattr(self.layers[layer_id], "_is_layer_to_capture", True)
+
     def forward(
         self,
         input_ids: torch.Tensor,
@@ -585,12 +596,6 @@ class Qwen2MoeModel(nn.Module):
             )
         else:
             for i in range(self.start_layer, self.end_layer):
-                if i in self.layers_to_capture:
-                    aux_hidden_states.append(
-                        hidden_states + residual
-                        if residual is not None
-                        else hidden_states
-                    )
                 ctx = (
                     nullcontext()
                     if get_global_server_args().enable_piecewise_cuda_graph
@@ -599,7 +604,15 @@ class Qwen2MoeModel(nn.Module):
                 with ctx:
                     layer = self.layers[i]
                     hidden_states, residual = layer(
-                        positions, hidden_states, forward_batch, residual
+                        positions,
+                        hidden_states,
+                        forward_batch,
+                        residual,
+                        captured_last_layer_outputs=(
+                            aux_hidden_states
+                            if getattr(layer, "_is_layer_to_capture", False)
+                            else None
+                        ),
                     )
         if not self.pp_group.is_last_rank:
             return PPProxyTensors(
@@ -830,13 +843,15 @@ class Qwen2MoeForCausalLM(nn.Module):
         self.capture_aux_hidden_states = True
         if layer_ids is None:
             num_layers = self.config.num_hidden_layers
-            self.model.layers_to_capture = [
-                2,
-                num_layers // 2,
-                num_layers - 3,
-            ]  # Specific layers for EAGLE3 support
+            self.model.set_eagle3_layers_to_capture(
+                [
+                    2,
+                    num_layers // 2,
+                    num_layers - 3,
+                ]
+            )  # Specific layers for EAGLE3 support
         else:
-            self.model.layers_to_capture = [val + 1 for val in layer_ids]
+            self.model.set_eagle3_layers_to_capture([val + 1 for val in layer_ids])
 
 
 EntryClass = Qwen2MoeForCausalLM
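
On the qwen2_moe.py side, EAGLE3 aux-hidden-state capture moves out of the model's forward loop (which previously appended hidden_states + residual whenever i was in self.layers_to_capture) and into the flagged decoder layers themselves: set_eagle3_layers_to_capture marks the chosen layers, and each marked layer passes the shared list into the layer communicator's new prepare_attn_and_capture_last_layer_outputs call. A toy reduction of the pattern (names other than set_eagle3_layers_to_capture and _is_layer_to_capture are invented for this sketch):

```python
import torch
import torch.nn as nn

class ToyLayer(nn.Module):
    def __init__(self, dim: int):
        super().__init__()
        self.lin = nn.Linear(dim, dim)

    def forward(self, x, captured_outputs=None):
        out = self.lin(x).relu()
        if captured_outputs is not None:
            captured_outputs.append(out)   # flagged layer pushes its own output
        return out

class ToyModel(nn.Module):
    def __init__(self, dim: int = 16, n_layers: int = 8):
        super().__init__()
        self.layers = nn.ModuleList(ToyLayer(dim) for _ in range(n_layers))

    def set_eagle3_layers_to_capture(self, layer_ids):
        for i in layer_ids:
            self.layers[i]._is_layer_to_capture = True

    def forward(self, x):
        aux = []
        for layer in self.layers:
            capture = getattr(layer, "_is_layer_to_capture", False)
            x = layer(x, captured_outputs=aux if capture else None)
        return x, aux

model = ToyModel()
model.set_eagle3_layers_to_capture([2, 4, 5])  # cf. [2, n // 2, n - 3] in the diff
_, aux = model(torch.randn(1, 16))
assert len(aux) == 3
```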