layer_scan-0.1.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- layer_scan/__init__.py +3 -0
- layer_scan/backends/__init__.py +9 -0
- layer_scan/backends/base.py +51 -0
- layer_scan/backends/exllamav2.py +224 -0
- layer_scan/backends/transformers_backend.py +254 -0
- layer_scan/cli.py +234 -0
- layer_scan/config.py +112 -0
- layer_scan/export.py +63 -0
- layer_scan/heatmap.py +234 -0
- layer_scan/probes/__init__.py +9 -0
- layer_scan/probes/base.py +93 -0
- layer_scan/probes/custom.py +77 -0
- layer_scan/probes/eq_probe.py +162 -0
- layer_scan/probes/json_probe.py +179 -0
- layer_scan/probes/math_probe.py +190 -0
- layer_scan/scanner.py +249 -0
- layer_scan/scoring.py +150 -0
- layer_scan-0.1.0.dist-info/METADATA +347 -0
- layer_scan-0.1.0.dist-info/RECORD +23 -0
- layer_scan-0.1.0.dist-info/WHEEL +4 -0
- layer_scan-0.1.0.dist-info/entry_points.txt +2 -0
- layer_scan-0.1.0.dist-info/licenses/LICENSE +21 -0
- layer_scan-0.1.0.dist-info/licenses/NOTICE +14 -0
layer_scan/backends/base.py
ADDED
@@ -0,0 +1,51 @@
"""Abstract backend interface for layer-scan."""

from __future__ import annotations

from abc import ABC, abstractmethod

import torch

from layer_scan.config import DuplicationConfig


class Backend(ABC):
    """Abstract base for inference backends.

    A backend must provide:
    1. Model loading and tokenization
    2. Forward pass with optional layer duplication
    3. Layer count information
    """

    @abstractmethod
    def load(self, model_path: str, **kwargs) -> None:
        """Load a model from the given path or HuggingFace ID."""

    @abstractmethod
    def get_total_layers(self) -> int:
        """Return the total number of transformer decoder layers."""

    @abstractmethod
    def get_tokenizer(self):
        """Return the model's tokenizer."""

    @abstractmethod
    def forward_with_duplication(
        self,
        text: str,
        duplication_config: DuplicationConfig | None = None,
    ) -> torch.Tensor:
        """Run forward pass with optional layer duplication.

        Args:
            text: Input text to process.
            duplication_config: If provided, duplicate layers [i..j-1].
                If None, run the standard forward pass (baseline).

        Returns:
            Logits tensor of shape (vocab_size,) for the last token position.
        """

    def cleanup(self) -> None:
        """Release resources (GPU memory, etc.). Optional."""
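For orientation: the docstring above describes duplicating layers [i..j-1], and the Transformers backend below spells out the resulting execution order as [0..j-1, i..N-1]. A minimal sketch of what DuplicationConfig.execution_order() plausibly computes follows; the real class lives in layer_scan/config.py (not shown in this diff), so the field names here are assumptions, not the package's actual API.

# Hypothetical stand-in for layer_scan.config.DuplicationConfig;
# field names are assumed for illustration only.
from dataclasses import dataclass

@dataclass
class DuplicationSketch:
    start: int         # i: first duplicated layer (assumed field name)
    end: int           # j: one past the last duplicated layer (assumed)
    total_layers: int  # N: decoder layer count (assumed)

    def execution_order(self) -> list[int]:
        # Run layers 0..j-1, then re-enter at i and run i..N-1,
        # so layers [i..j-1] execute exactly twice.
        return list(range(self.end)) + list(range(self.start, self.total_layers))

# N=6, duplicating layers [2..3] -> [0, 1, 2, 3, 2, 3, 4, 5]
assert DuplicationSketch(2, 4, 6).execution_order() == [0, 1, 2, 3, 2, 3, 4, 5]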
layer_scan/backends/exllamav2.py
ADDED
@@ -0,0 +1,224 @@
"""ExLlamaV2 backend — optimized for consumer GPUs with quantized models.

This is the recommended backend for scanning large models (70B+) on
consumer hardware (e.g., 2×RTX 4090). ExLlamaV2 supports GPTQ and EXL2
quantization with excellent memory efficiency.

Requires: pip install layer-scan[exllamav2]
"""

from __future__ import annotations

import gc
import logging

import torch

from layer_scan.backends.base import Backend
from layer_scan.config import DuplicationConfig

logger = logging.getLogger(__name__)


class ExLlamaV2Backend(Backend):
    """ExLlamaV2 backend with runtime layer duplication.

    Uses ExLlamaV2's modular architecture to execute individual layers
    in a custom order, enabling layer duplication without model copies.
    """

    def __init__(self) -> None:
        self._model = None
        self._tokenizer = None
        self._cache = None
        self._config = None
        self._total_layers: int = 0

    def load(self, model_path: str, **kwargs) -> None:
        """Load an EXL2/GPTQ quantized model.

        Kwargs:
            gpu_split: List of per-GPU memory limits in MB (e.g., [22000, 22000]).
            max_seq_len: Maximum sequence length (default: 4096).
            rope_scale: RoPE scaling factor (default: 1.0).
        """
        try:
            from exllamav2 import (
                ExLlamaV2,
                ExLlamaV2Cache,
                ExLlamaV2Config,
                ExLlamaV2Tokenizer,
            )
        except ImportError as e:
            raise ImportError(
                "ExLlamaV2 not installed. Install with: "
                "pip install layer-scan[exllamav2]"
            ) from e

        logger.info("Loading ExLlamaV2 model from %s", model_path)

        self._config = ExLlamaV2Config(model_path)
        self._config.max_seq_len = kwargs.get("max_seq_len", 4096)

        if "rope_scale" in kwargs:
            self._config.scale_pos_emb = kwargs["rope_scale"]

        self._model = ExLlamaV2(self._config)

        gpu_split = kwargs.get("gpu_split")
        if gpu_split:
            self._model.load(gpu_split)
            self._cache = ExLlamaV2Cache(
                self._model, max_seq_len=self._config.max_seq_len
            )
        else:
            # load_autosplit() plans the device split against a lazily
            # allocated cache, so the cache must be created first.
            self._cache = ExLlamaV2Cache(
                self._model, max_seq_len=self._config.max_seq_len, lazy=True
            )
            self._model.load_autosplit(self._cache)

        self._tokenizer = ExLlamaV2Tokenizer(self._config)

        # Count transformer layers (excluding embedding, head, norms)
        self._total_layers = self._count_decoder_layers()
        logger.info("ExLlamaV2 model loaded: %d decoder layers", self._total_layers)

    def _count_decoder_layers(self) -> int:
        """Count the number of decoder layers in the ExLlamaV2 model."""
        count = 0
        for module in self._model.modules:
            # Each transformer block pairs an attention module with an MLP;
            # counting attention modules gives the block count.
            if "Attention" in type(module).__name__:
                count += 1
        return count

    def get_total_layers(self) -> int:
        if self._total_layers == 0:
            raise RuntimeError("Model not loaded. Call load() first.")
        return self._total_layers

    def get_tokenizer(self):
        """Return a tokenizer adapter compatible with the HuggingFace interface."""
        if self._tokenizer is None:
            raise RuntimeError("Model not loaded. Call load() first.")
        return _ExLlamaV2TokenizerAdapter(self._tokenizer)

    @torch.no_grad()
    def forward_with_duplication(
        self,
        text: str,
        duplication_config: DuplicationConfig | None = None,
    ) -> torch.Tensor:
        """Run forward pass with optional layer duplication.

        For ExLlamaV2, we use the module-level forward to control the
        execution order of individual transformer blocks.
        """
        input_ids = self._tokenizer.encode(text)
        input_ids = input_ids.to(self._model.modules[0].device())

        self._cache.current_seq_len = 0

        if duplication_config is None:
            # Standard forward pass
            logits = self._model.forward(input_ids, self._cache)
            return logits[0, -1, :]

        # Duplicated forward: execute modules in custom order
        return self._forward_duplicated(input_ids, duplication_config)

    def _forward_duplicated(
        self,
        input_ids: torch.Tensor,
        config: DuplicationConfig,
    ) -> torch.Tensor:
        """Execute ExLlamaV2 modules with layer duplication.

        ExLlamaV2 exposes modules as a flat list:
        [embedding, layer0_attn, layer0_mlp, ..., norm, head].
        We map decoder layer indices to module indices and run the
        modules in the custom execution order.
        """
        exec_order = config.execution_order()

        # Map layer indices to module ranges
        layer_modules = self._get_layer_module_map()

        # Execute embedding
        hidden = self._model.modules[0].forward(input_ids, self._cache)

        # Execute decoder layers in duplicated order
        for layer_idx in exec_order:
            for mod_idx in layer_modules[layer_idx]:
                hidden = self._model.modules[mod_idx].forward(hidden, self._cache)

        # Execute final norm and head
        for mod_idx in self._get_post_layer_modules():
            hidden = self._model.modules[mod_idx].forward(hidden, self._cache)

        return hidden[0, -1, :]

    def _get_layer_module_map(self) -> dict[int, list[int]]:
        """Map decoder layer indices to ExLlamaV2 module indices."""
        layer_map: dict[int, list[int]] = {}
        current_layer = -1

        for mod_idx, module in enumerate(self._model.modules):
            name = type(module).__name__
            if "Attention" in name:
                current_layer += 1
                layer_map[current_layer] = [mod_idx]
            elif "MLP" in name and current_layer >= 0:
                layer_map[current_layer].append(mod_idx)

        return layer_map

    def _get_post_layer_modules(self) -> list[int]:
        """Get module indices for the norm and lm_head (after all decoder layers)."""
        # Walk backwards to the last attention/MLP module; everything
        # after it (final norm, lm_head) is post-layer.
        for mod_idx in range(len(self._model.modules) - 1, -1, -1):
            name = type(self._model.modules[mod_idx]).__name__
            if "MLP" in name or "Attention" in name:
                return list(range(mod_idx + 1, len(self._model.modules)))
        return []

    def cleanup(self) -> None:
        if self._model is not None:
            del self._model
            self._model = None
        if self._cache is not None:
            del self._cache
            self._cache = None
        self._tokenizer = None
        self._total_layers = 0

        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()


class _ExLlamaV2TokenizerAdapter:
    """Adapter making the ExLlamaV2 tokenizer compatible with the HuggingFace interface.

    layer-scan's scoring module expects tokenizer.encode(str) -> list[int].
    """

    def __init__(self, exl2_tokenizer) -> None:
        self._tokenizer = exl2_tokenizer

    def encode(self, text: str, add_special_tokens: bool = True) -> list[int]:
        ids = self._tokenizer.encode(text)
        return ids[0].tolist() if hasattr(ids, "tolist") else list(ids[0])

    def decode(self, token_ids: list[int]) -> str:
        ids = torch.tensor([token_ids])
        return self._tokenizer.decode(ids)[0]
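A minimal usage sketch for this backend, assuming a local EXL2 model directory; the duplicated-pass call is left commented out because DuplicationConfig's constructor is defined in layer_scan/config.py, which this diff does not show.

# Usage sketch (hypothetical model path and values).
from layer_scan.backends.exllamav2 import ExLlamaV2Backend

backend = ExLlamaV2Backend()
backend.load(
    "/models/llama-70b-exl2",   # hypothetical EXL2 model directory
    gpu_split=[22000, 22000],   # two GPUs, per the docstring's example
    max_seq_len=4096,
)
logits = backend.forward_with_duplication("2 + 2 =")  # baseline pass
# dup = backend.forward_with_duplication("2 + 2 =", duplication_config=...)
print(logits.shape)  # (vocab_size,) logits for the last token
backend.cleanup()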
layer_scan/backends/transformers_backend.py
ADDED
@@ -0,0 +1,254 @@
"""HuggingFace Transformers backend — the default reference implementation.

This backend hooks into the model's forward pass at the layer level,
implementing layer duplication by re-executing specified layers without
modifying model weights. It works with any CausalLM model.

Note: This is the reference backend for correctness. For large models
(70B+), use the ExLlamaV2 backend for quantized inference on consumer GPUs.
"""

from __future__ import annotations

import gc
import logging
from typing import Any

import torch

from layer_scan.backends.base import Backend
from layer_scan.config import DuplicationConfig

logger = logging.getLogger(__name__)


class TransformersBackend(Backend):
    """HuggingFace Transformers backend with runtime layer duplication."""

    def __init__(self) -> None:
        self._model: Any = None
        self._tokenizer: Any = None
        self._layers: list[Any] = []
        self._total_layers: int = 0

    def load(self, model_path: str, **kwargs) -> None:
        """Load a CausalLM model from a path or HuggingFace ID.

        Kwargs:
            dtype: Torch dtype string (default: "float16").
            device_map: Device map for model parallelism (default: "auto").
            trust_remote_code: Whether to trust remote code (default: False).
        """
        from transformers import AutoModelForCausalLM, AutoTokenizer

        dtype_str = kwargs.get("dtype", "float16")
        dtype_map = {
            "float16": torch.float16,
            "bfloat16": torch.bfloat16,
            "float32": torch.float32,
        }
        dtype = dtype_map.get(dtype_str, torch.float16)

        logger.info("Loading tokenizer from %s", model_path)
        self._tokenizer = AutoTokenizer.from_pretrained(
            model_path,
            trust_remote_code=kwargs.get("trust_remote_code", False),
        )

        logger.info("Loading model from %s (dtype=%s)", model_path, dtype_str)
        self._model = AutoModelForCausalLM.from_pretrained(
            model_path,
            torch_dtype=dtype,
            device_map=kwargs.get("device_map", "auto"),
            trust_remote_code=kwargs.get("trust_remote_code", False),
        )
        self._model.eval()

        # Discover layer structure
        self._layers = self._find_layers()
        self._total_layers = len(self._layers)
        logger.info("Model loaded: %d layers discovered", self._total_layers)

    def _find_layers(self) -> list[Any]:
        """Discover the decoder layers in the model.

        Supports common architectures: LLaMA, Mistral, Qwen, GPT-NeoX, etc.
        """
        model = self._model

        # Try common layer container names
        for attr_path in [
            "model.layers",          # LLaMA, Mistral, Qwen2
            "transformer.h",         # GPT-2, GPT-Neo
            "gpt_neox.layers",       # GPT-NeoX, Pythia
            "transformer.blocks",    # MPT
            "model.decoder.layers",  # OPT
        ]:
            obj = model
            found = True
            for part in attr_path.split("."):
                if hasattr(obj, part):
                    obj = getattr(obj, part)
                else:
                    found = False
                    break
            if found and hasattr(obj, "__len__"):
                return list(obj)

        raise RuntimeError(
            "Could not find decoder layers. Supported architectures: "
            "LLaMA, Mistral, Qwen2, GPT-2, GPT-NeoX, MPT, OPT. "
            "For other architectures, use a custom backend."
        )

    def get_total_layers(self) -> int:
        if self._total_layers == 0:
            raise RuntimeError("Model not loaded. Call load() first.")
        return self._total_layers

    def get_tokenizer(self):
        if self._tokenizer is None:
            raise RuntimeError("Model not loaded. Call load() first.")
        return self._tokenizer

    @torch.no_grad()
    def forward_with_duplication(
        self,
        text: str,
        duplication_config: DuplicationConfig | None = None,
    ) -> torch.Tensor:
        """Run forward pass, optionally duplicating layers [i..j-1].

        For the baseline (duplication_config=None), uses the standard
        model forward pass. For duplicated configs, hooks into the
        layer execution to replay the specified layers.

        Returns:
            Logits tensor of shape (vocab_size,) for the last token.
        """
        inputs = self._tokenizer(text, return_tensors="pt")
        input_ids = inputs["input_ids"].to(self._model.device)
        attention_mask = inputs.get("attention_mask")
        if attention_mask is not None:
            attention_mask = attention_mask.to(self._model.device)

        if duplication_config is None:
            # Standard forward pass (baseline)
            outputs = self._model(
                input_ids=input_ids,
                attention_mask=attention_mask,
            )
            return outputs.logits[0, -1, :]

        # Layer-duplicated forward pass
        return self._forward_duplicated(
            input_ids=input_ids,
            attention_mask=attention_mask,
            config=duplication_config,
        )

    def _forward_duplicated(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor | None,
        config: DuplicationConfig,
    ) -> torch.Tensor:
        """Execute the forward pass with layer duplication.

        Execution order: [0..j-1, i..N-1].
        Layers [i..j-1] execute twice — this gives the model
        additional "thinking time" on the reasoning circuits.
        """
        model = self._model

        # Locate the base module that holds embeddings and decoder layers
        if hasattr(model, "model"):
            # LLaMA-style: model.model.embed_tokens
            base = model.model
        elif hasattr(model, "transformer"):
            # GPT-style: model.transformer.wte
            base = model.transformer
        else:
            raise RuntimeError("Cannot determine embedding layer location")

        # Compute embeddings
        if hasattr(base, "embed_tokens"):
            hidden_states = base.embed_tokens(input_ids)
        elif hasattr(base, "wte"):
            hidden_states = base.wte(input_ids)
        else:
            raise RuntimeError("Cannot find embedding layer")

        # Build execution order
        exec_order = config.execution_order()

        # Execute layers in the duplicated order.
        # Note: for models with RoPE, position IDs remain based on
        # sequence length — we don't change them for duplicated layers.
        # This matches the RYS methodology.
        position_ids = torch.arange(
            input_ids.shape[1], device=input_ids.device
        ).unsqueeze(0)

        for layer_idx in exec_order:
            layer = self._layers[layer_idx]
            # Most decoder layers accept (hidden_states, attention_mask,
            # position_ids), but the exact signature varies by architecture.
            layer_output = layer(
                hidden_states,
                attention_mask=self._prepare_causal_mask(
                    attention_mask, hidden_states
                ),
                position_ids=position_ids,
            )
            # Layer output is typically a tuple; the first element is hidden_states
            if isinstance(layer_output, tuple):
                hidden_states = layer_output[0]
            else:
                hidden_states = layer_output

        # Apply final norm
        if hasattr(base, "norm"):
            hidden_states = base.norm(hidden_states)
        elif hasattr(base, "ln_f"):
            hidden_states = base.ln_f(hidden_states)

        # Project to logits
        if hasattr(model, "lm_head"):
            logits = model.lm_head(hidden_states)
        elif hasattr(base, "lm_head"):
            logits = base.lm_head(hidden_states)
        else:
            raise RuntimeError("Cannot find LM head for logit projection")

        return logits[0, -1, :]

    @staticmethod
    def _prepare_causal_mask(
        attention_mask: torch.Tensor | None,
        hidden_states: torch.Tensor,
    ) -> torch.Tensor | None:
        """Prepare a 4D causal attention mask if needed.

        Some models expect a pre-expanded mask. We pass the mask through
        unchanged and let the layer handle masking internally when possible,
        which most modern architectures do for a None or 2D mask.
        """
        return attention_mask

    def cleanup(self) -> None:
        """Free GPU memory."""
        if self._model is not None:
            del self._model
            self._model = None
        if self._tokenizer is not None:
            del self._tokenizer
            self._tokenizer = None
        self._layers = []
        self._total_layers = 0

        gc.collect()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
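The same flow on the Transformers backend runs unquantized and is practical for small models; "gpt2" is used below only as a convenient example checkpoint, and the comparison against DuplicationConfig windows is sketched in a comment because the scanner and scoring modules are not shown in this diff.

# Usage sketch: baseline forward pass on a small HuggingFace model.
import torch
from layer_scan.backends.transformers_backend import TransformersBackend

backend = TransformersBackend()
backend.load("gpt2", dtype="float32", device_map="cpu")

baseline = backend.forward_with_duplication("The capital of France is")
# A scanner would re-run with each DuplicationConfig window and compare
# the duplicated run's distribution to this baseline (e.g., KL divergence).
probs = torch.softmax(baseline, dim=-1)
print(probs.argmax().item())  # most likely next-token id
backend.cleanup()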