attention-visualiser 0.1.0__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- attention_visualiser-0.1.0/.github/workflows/tests.yml +38 -0
- attention_visualiser-0.1.0/.gitignore +23 -0
- attention_visualiser-0.1.0/.pre-commit-config.yaml +10 -0
- attention_visualiser-0.1.0/.python-version +1 -0
- attention_visualiser-0.1.0/LICENSE +21 -0
- attention_visualiser-0.1.0/PKG-INFO +60 -0
- attention_visualiser-0.1.0/README.md +48 -0
- attention_visualiser-0.1.0/attention_visualiser/__init__.py +3 -0
- attention_visualiser-0.1.0/attention_visualiser/base.py +160 -0
- attention_visualiser-0.1.0/attention_visualiser/pt.py +76 -0
- attention_visualiser-0.1.0/codecov.yml +5 -0
- attention_visualiser-0.1.0/main.py +18 -0
- attention_visualiser-0.1.0/pyproject.toml +28 -0
- attention_visualiser-0.1.0/tests/__init__.py +0 -0
- attention_visualiser-0.1.0/tests/test_pt.py +167 -0
- attention_visualiser-0.1.0/uv.lock +1714 -0
|
@@ -0,0 +1,38 @@
|
|
|
1
|
+
# This workflow will install Python dependencies, run tests and lint with a single version of Python
|
|
2
|
+
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
|
|
3
|
+
|
|
4
|
+
name: tests
|
|
5
|
+
|
|
6
|
+
on:
|
|
7
|
+
push:
|
|
8
|
+
branches: [ "main" ]
|
|
9
|
+
pull_request:
|
|
10
|
+
branches: [ "main" ]
|
|
11
|
+
|
|
12
|
+
permissions:
|
|
13
|
+
contents: read
|
|
14
|
+
|
|
15
|
+
jobs:
|
|
16
|
+
build:
|
|
17
|
+
runs-on: ubuntu-latest
|
|
18
|
+
|
|
19
|
+
steps:
|
|
20
|
+
- uses: actions/checkout@v4
|
|
21
|
+
- name: Set up Python 3.12
|
|
22
|
+
uses: actions/setup-python@v3
|
|
23
|
+
with:
|
|
24
|
+
python-version: "3.12"
|
|
25
|
+
- name: Install uv
|
|
26
|
+
uses: astral-sh/setup-uv@v5
|
|
27
|
+
with:
|
|
28
|
+
version: "0.6.16"
|
|
29
|
+
|
|
30
|
+
- name: Install dependencies
|
|
31
|
+
run: |
|
|
32
|
+
uv sync
|
|
33
|
+
- name: Test with pytest
|
|
34
|
+
run: |
|
|
35
|
+
source .venv/bin/activate
|
|
36
|
+
uv run pytest --cov --cov-report=xml
|
|
37
|
+
- name: Upload coverage reports to Codecov
|
|
38
|
+
uses: codecov/codecov-action@v5
|
|
@@ -0,0 +1,23 @@
|
|
|
1
|
+
# Python-generated files
|
|
2
|
+
__pycache__/
|
|
3
|
+
*.py[oc]
|
|
4
|
+
build/
|
|
5
|
+
dist/
|
|
6
|
+
wheels/
|
|
7
|
+
*.egg-info
|
|
8
|
+
.mypy_cache/
|
|
9
|
+
|
|
10
|
+
# Virtual environments
|
|
11
|
+
.venv
|
|
12
|
+
|
|
13
|
+
# IDE and Editor specific files
|
|
14
|
+
.idea/
|
|
15
|
+
.vscode/
|
|
16
|
+
.code/
|
|
17
|
+
|
|
18
|
+
# pytest
|
|
19
|
+
.coverage
|
|
20
|
+
coverage.xml
|
|
21
|
+
|
|
22
|
+
# macos
|
|
23
|
+
.DS_Store
|
|
@@ -0,0 +1 @@
|
|
|
1
|
+
3.12
|
|
@@ -0,0 +1,21 @@
|
|
|
1
|
+
MIT License
|
|
2
|
+
|
|
3
|
+
Copyright (c) 2025 Shawon Ashraf
|
|
4
|
+
|
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
|
7
|
+
in the Software without restriction, including without limitation the rights
|
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
|
10
|
+
furnished to do so, subject to the following conditions:
|
|
11
|
+
|
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
|
13
|
+
copies or substantial portions of the Software.
|
|
14
|
+
|
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
|
21
|
+
SOFTWARE.
|
|
@@ -0,0 +1,60 @@
|
|
|
1
|
+
Metadata-Version: 2.4
|
|
2
|
+
Name: attention-visualiser
|
|
3
|
+
Version: 0.1.0
|
|
4
|
+
Summary: a module to visualise attention layer activations from transformer based models from huggingface
|
|
5
|
+
License-File: LICENSE
|
|
6
|
+
Requires-Python: >=3.11
|
|
7
|
+
Requires-Dist: einops>=0.8.1
|
|
8
|
+
Requires-Dist: loguru>=0.7.3
|
|
9
|
+
Requires-Dist: seaborn>=0.13.2
|
|
10
|
+
Requires-Dist: transformers>=4.51.3
|
|
11
|
+
Description-Content-Type: text/markdown
|
|
12
|
+
|
|
13
|
+
# attention-visualiser
|
|
14
|
+
|
|
15
|
+
a module to visualise attention layer activations from transformer based models from huggingface
|
|
16
|
+
|
|
17
|
+
## installation
|
|
18
|
+
|
|
19
|
+
```bash
|
|
20
|
+
pip install git+https://codeberg.org/rashomon/attention-visualiser
|
|
21
|
+
```
|
|
22
|
+
|
|
23
|
+
## usage
|
|
24
|
+
|
|
25
|
+
```python
|
|
26
|
+
from attention_visualiser import AttentionVisualiser
|
|
27
|
+
from transformers import AutoModel, AutoTokenizer
|
|
28
|
+
|
|
29
|
+
# visualising activations from gpt
|
|
30
|
+
model_name = "openai-community/openai-gpt"
|
|
31
|
+
|
|
32
|
+
model = AutoModel.from_pretrained(model_name)
|
|
33
|
+
model.eval()
|
|
34
|
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
|
35
|
+
|
|
36
|
+
text = "Look on my Works, ye Mighty, and despair!"
|
|
37
|
+
encoded_inputs = tokenizer.encode_plus(text, truncation=True, return_tensors="pt")
|
|
38
|
+
|
|
39
|
+
visualiser = AttentionVisualiser(model, tokenizer)
|
|
40
|
+
|
|
41
|
+
# visualise from the first attn layer
|
|
42
|
+
visualiser.visualise_attn_layer(0, encoded_inputs)
|
|
43
|
+
|
|
44
|
+
```
|
|
45
|
+
|
|
46
|
+
|
|
47
|
+
## local dev
|
|
48
|
+
|
|
49
|
+
```bash
|
|
50
|
+
# env setup
|
|
51
|
+
|
|
52
|
+
uv sync
|
|
53
|
+
source .venv/bin/activate
|
|
54
|
+
|
|
55
|
+
# tests
|
|
56
|
+
uv run pytest
|
|
57
|
+
|
|
58
|
+
# tests with coverage
|
|
59
|
+
uv run pytest --cov --cov-report=xml
|
|
60
|
+
```
|
|
@@ -0,0 +1,48 @@
|
|
|
1
|
+
# attention-visualiser
|
|
2
|
+
|
|
3
|
+
a module to visualise attention layer activations from transformer based models from huggingface
|
|
4
|
+
|
|
5
|
+
## installation
|
|
6
|
+
|
|
7
|
+
```bash
|
|
8
|
+
pip install git+https://codeberg.org/rashomon/attention-visualiser
|
|
9
|
+
```
|
|
10
|
+
|
|
11
|
+
## usage
|
|
12
|
+
|
|
13
|
+
```python
|
|
14
|
+
from attention_visualiser import AttentionVisualiser
|
|
15
|
+
from transformers import AutoModel, AutoTokenizer
|
|
16
|
+
|
|
17
|
+
# visualising activations from gpt
|
|
18
|
+
model_name = "openai-community/openai-gpt"
|
|
19
|
+
|
|
20
|
+
model = AutoModel.from_pretrained(model_name)
|
|
21
|
+
model.eval()
|
|
22
|
+
tokenizer = AutoTokenizer.from_pretrained(model_name)
|
|
23
|
+
|
|
24
|
+
text = "Look on my Works, ye Mighty, and despair!"
|
|
25
|
+
encoded_inputs = tokenizer.encode_plus(text, truncation=True, return_tensors="pt")
|
|
26
|
+
|
|
27
|
+
visualiser = AttentionVisualiser(model, tokenizer)
|
|
28
|
+
|
|
29
|
+
# visualise from the first attn layer
|
|
30
|
+
visualiser.visualise_attn_layer(0, encoded_inputs)
|
|
31
|
+
|
|
32
|
+
```
|
|
33
|
+
|
|
34
|
+
|
|
35
|
+
## local dev
|
|
36
|
+
|
|
37
|
+
```bash
|
|
38
|
+
# env setup
|
|
39
|
+
|
|
40
|
+
uv sync
|
|
41
|
+
source .venv/bin/activate
|
|
42
|
+
|
|
43
|
+
# tests
|
|
44
|
+
uv run pytest
|
|
45
|
+
|
|
46
|
+
# tests with coverage
|
|
47
|
+
uv run pytest --cov --cov-report=xml
|
|
48
|
+
```
|
|
@@ -0,0 +1,160 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import matplotlib.pyplot as plt
|
|
3
|
+
import seaborn as sns
|
|
4
|
+
from transformers import AutoTokenizer, AutoModel, FlaxAutoModel
|
|
5
|
+
from transformers import BatchEncoding
|
|
6
|
+
from loguru import logger
|
|
7
|
+
from einops import rearrange
|
|
8
|
+
from abc import ABC, abstractmethod
|
|
9
|
+
import numpy as np
|
|
10
|
+
from typing import Optional
|
|
11
|
+
|
|
12
|
+
|
|
13
|
+
class BaseAttentionVisualiser(ABC):
    """Base abstract class for visualizing attention weights in transformer models.

    This class provides the foundation for visualizing attention weights from
    different transformer model implementations. Concrete subclasses must implement
    methods for computing attention values and processing attention vectors.

    Attributes:
        model: A transformer model from the Hugging Face library
        tokenizer: A tokenizer matching the model
        config: Dictionary containing visualization configuration parameters
        current_input: Most recently processed encoded input (cache key)
        cache: Attention weights computed for ``current_input``
    """

    def __init__(
        self,
        model: AutoModel | FlaxAutoModel,
        tokenizer: AutoTokenizer,
        config: Optional[dict] = None,
    ) -> None:
        """Initialize the attention visualizer with a model and tokenizer.

        Args:
            model: A transformer model from Hugging Face (PyTorch or Flax)
            tokenizer: A tokenizer matching the model
            config: Optional dictionary with visualization parameters.
                Recognised keys:
                - figsize: Tuple specifying figure dimensions
                - cmap: Colormap for the heatmap
                - annot: Whether to annotate heatmap cells with values
                - xlabel: Label for x-axis
                - ylabel: Label for y-axis
        """
        self.model = model
        self.tokenizer = tokenizer

        logger.info(f"Model config: {self.model.config}")  # type: ignore

        # NOTE: an empty dict is also treated as "no config" and replaced by the
        # defaults, because `.get()` on an empty config would return None for
        # every plotting parameter below.
        if not config:
            self.config = {
                "figsize": (15, 15),
                "cmap": "viridis",
                "annot": True,
                "xlabel": "",
                "ylabel": "",
            }
            logger.info(f"Setting default visualiser config: {self.config}")
        else:
            logger.info(f"Visualiser config: {config}")
            self.config = config

        # a cache for storing already computed attention vectors
        # these need to be updated by the `compute_attentions`
        # method
        self.current_input = None
        self.cache = None

    def id_to_tokens(self, encoded_input: BatchEncoding) -> list[str]:
        """Convert token IDs to readable token strings.

        Args:
            encoded_input: The encoded input from the tokenizer

        Returns:
            List of token strings corresponding to the input IDs
        """
        # [0]: only the first (single) sequence in the batch is visualised
        tokens = self.tokenizer.convert_ids_to_tokens(encoded_input["input_ids"][0])  # type: ignore
        return tokens

    @abstractmethod
    def compute_attentions(self, encoded_input: BatchEncoding) -> tuple:
        """Compute attention weights for the given input.

        This method must be implemented by concrete subclasses to compute
        attention weights specific to the model implementation.

        Args:
            encoded_input: The encoded input from the tokenizer

        Returns:
            A tuple containing attention weights
        """
        pass

    @abstractmethod
    def get_attention_vector_mean(
        self, attention: torch.Tensor, axis: int = 0
    ) -> np.ndarray:
        """Calculate mean of attention vectors along specified axis.

        This method must be implemented by concrete subclasses to handle
        either PyTorch or JAX tensors appropriately.

        Args:
            attention: Attention tensor from the model
            axis: Axis along which to compute the mean (default: 0)

        Returns:
            NumPy array of mean attention values
        """
        pass

    def visualise_attn_layer(self, idx: int, encoded_input: BatchEncoding) -> None:
        """Visualize attention weights for a specific layer.

        Creates a heatmap visualization of the attention weights for the specified
        layer index.

        Args:
            idx: Index of the attention layer to visualize.
                Negative indices count from the end (-1 is the last layer).
            encoded_input: The encoded input from the tokenizer

        Raises:
            AssertionError: If idx is outside the range of available attention layers
        """
        tokens = self.id_to_tokens(encoded_input)

        attentions = self.compute_attentions(encoded_input)
        n_attns = len(attentions)

        # idx must address one of the n_attns attention outputs. Accept
        # Python-style negative indices, but reject out-of-range values in
        # BOTH directions (the original check only bounded idx from above,
        # letting e.g. idx = -(n_attns + 1) slip through to an IndexError).
        assert -n_attns <= idx < n_attns, (
            f"index must be in range [-{n_attns}, {n_attns}) to address one of the {n_attns} attention outputs, got: {idx}"
        )

        # normalise negative indices so indexing works and the plot title
        # shows the real (non-negative) layer index
        if idx < 0:
            idx = n_attns + idx

        # get rid of the additional batch dimension since single input
        attention = rearrange(attentions[idx], "1 a b c -> a b c")
        # take mean over dim 0
        attention = self.get_attention_vector_mean(attention)

        plt.figure(figsize=self.config.get("figsize"))
        sns.heatmap(
            attention,
            cmap=self.config.get("cmap"),
            annot=self.config.get("annot"),
            xticklabels=tokens,
            yticklabels=tokens,
        )

        plt.title(f"Attention Weights for Layer idx: {idx}")
        plt.xlabel(self.config.get("xlabel"))  # type: ignore
        plt.ylabel(self.config.get("ylabel"))  # type: ignore
        plt.show()
|
|
@@ -0,0 +1,76 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from transformers import AutoTokenizer, AutoModel
|
|
3
|
+
from transformers import BatchEncoding
|
|
4
|
+
from attention_visualiser.base import BaseAttentionVisualiser
|
|
5
|
+
import numpy as np
|
|
6
|
+
from typing import Optional
|
|
7
|
+
|
|
8
|
+
|
|
9
|
+
class AttentionVisualiserPytorch(BaseAttentionVisualiser):
    """Attention visualizer for PyTorch-based transformer models.

    This class implements the abstract methods from BaseAttentionVisualiser
    specifically for models implemented in PyTorch. It handles the extraction
    and processing of attention weights from PyTorch transformer models.

    Attributes:
        model: A PyTorch-based transformer model from Hugging Face
        tokenizer: A tokenizer matching the model
        config: Dictionary containing visualization configuration parameters
    """

    def __init__(
        self, model: AutoModel, tokenizer: AutoTokenizer, config: Optional[dict] = None
    ) -> None:
        """Initialize the PyTorch-specific attention visualizer.

        Args:
            model: A PyTorch-based transformer model from Hugging Face
            tokenizer: A tokenizer matching the model
            config: Optional dictionary with visualization parameters
        """
        super().__init__(model, tokenizer, config)

    def compute_attentions(self, encoded_input: BatchEncoding) -> tuple:
        """Compute attention weights for the given input using a PyTorch model.

        Runs the PyTorch model in inference mode with output_attentions flag set to True
        and extracts the attention weights from the model output. The result is
        cached per input object so repeated calls with the same encoding do not
        rerun the model.

        Args:
            encoded_input: The encoded input from the tokenizer

        Returns:
            A tuple containing attention weights from all layers of the model
        """
        # BUGFIX: use an identity check instead of `==`. Comparing two
        # BatchEncoding objects with `==` compares their tensor values
        # element-wise, and truth-testing that multi-element result raises
        # "Boolean value of Tensor is ambiguous" as soon as a second,
        # different input is passed in. Identity matches the caching intent.
        if encoded_input is self.current_input:
            # return from cache
            return self.cache

        # else recompute; no_grad since this is inference only
        with torch.no_grad():
            output = self.model(**encoded_input, output_attentions=True)  # type: ignore

        attentions = output.attentions

        # update cache and current input
        self.current_input = encoded_input
        self.cache = attentions

        return attentions

    def get_attention_vector_mean(
        self, attention: torch.Tensor, axis: int = 0
    ) -> np.ndarray:
        """Calculate mean of PyTorch attention vectors along specified axis.

        Computes the mean of the attention tensor and converts it to a NumPy array.

        Args:
            attention: PyTorch tensor containing attention weights
            axis: Axis along which to compute the mean (default: 0)

        Returns:
            NumPy array of mean attention values
        """
        # detach + cpu so the conversion works regardless of grad state/device
        return torch.mean(attention, dim=axis).detach().cpu().numpy()
|
|
@@ -0,0 +1,18 @@
|
|
|
1
|
+
from attention_visualiser import AttentionVisualiser
|
|
2
|
+
from transformers import AutoModel, AutoTokenizer
|
|
3
|
+
|
|
4
|
+
def main() -> None:
    """Demo: visualise the first attention layer of GPT for a sample sentence.

    Downloads the model/tokenizer from the Hugging Face hub, encodes a fixed
    text, and renders a heatmap of the layer-0 attention weights. Wrapped in
    a function (instead of living at module top level) so the demo can also
    be invoked programmatically; `python main.py` behaves exactly as before.
    """
    # visualising activations from gpt
    model_name = "openai-community/openai-gpt"

    model = AutoModel.from_pretrained(model_name)
    # inference only: disable dropout etc.
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    text = "Look on my Works, ye Mighty, and despair!"
    encoded_inputs = tokenizer.encode_plus(text, truncation=True, return_tensors="pt")

    visualiser = AttentionVisualiser(model, tokenizer)

    # visualise from the first attn layer
    visualiser.visualise_attn_layer(0, encoded_inputs)


if __name__ == "__main__":
    main()
|
|
@@ -0,0 +1,28 @@
|
|
|
1
|
+
[project]
|
|
2
|
+
name = "attention-visualiser"
|
|
3
|
+
version = "0.1.0"
|
|
4
|
+
description = "a module to visualise attention layer activations from transformer based models from huggingface"
|
|
5
|
+
readme = "README.md"
|
|
6
|
+
requires-python = ">=3.11"
|
|
7
|
+
dependencies = [
|
|
8
|
+
"einops>=0.8.1",
|
|
9
|
+
"loguru>=0.7.3",
|
|
10
|
+
"seaborn>=0.13.2",
|
|
11
|
+
"transformers>=4.51.3",
|
|
12
|
+
]
|
|
13
|
+
|
|
14
|
+
[dependency-groups]
|
|
15
|
+
dev = [
|
|
16
|
+
"pre-commit>=4.2.0",
|
|
17
|
+
"pytest>=8.3.5",
|
|
18
|
+
"pytest-cov>=6.1.1",
|
|
19
|
+
"ruff>=0.11.6",
|
|
20
|
+
"torch>=2.7.0",
|
|
21
|
+
]
|
|
22
|
+
|
|
23
|
+
[build-system]
|
|
24
|
+
requires = ["hatchling"]
|
|
25
|
+
build-backend = "hatchling.build"
|
|
26
|
+
|
|
27
|
+
[tool.hatch.build.targets.wheel]
|
|
28
|
+
packages = ["attention_visualiser"]
|
|
File without changes
|
|
@@ -0,0 +1,167 @@
|
|
|
1
|
+
import pytest
|
|
2
|
+
import numpy as np
|
|
3
|
+
import torch
|
|
4
|
+
from attention_visualiser import AttentionVisualiser
|
|
5
|
+
|
|
6
|
+
|
|
7
|
+
class TestAttentionVisualiser:
    """Unit tests for AttentionVisualiser using lightweight stand-in objects
    instead of real Hugging Face models/tokenizers, so no downloads occur."""

    @pytest.fixture
    def mock_model(self):
        """A fake model exposing just what the visualiser touches: a `.config`
        attribute (logged in __init__) and a callable interface that records
        its kwargs and returns a preset `.output` (set per-test)."""
        class MockConfig:
            # mirrors the attributes a transformer config would expose
            num_attention_heads = 12
            num_hidden_layers = 12

        class MockModel:
            def __init__(self):
                self.config = MockConfig()
                # incremented on every forward call, so tests can assert
                # whether the model was (not) invoked
                self.call_count = 0

            def __call__(self, **kwargs):
                self.call_count += 1
                self.last_kwargs = kwargs
                # `.output` is attached by the test before the call
                return self.output  # type: ignore

        model = MockModel()
        return model

    @pytest.fixture
    def mock_tokenizer(self):
        """A fake tokenizer that returns a fixed token list and counts calls."""
        class MockTokenizer:
            def __init__(self):
                self.call_count = 0

            def convert_ids_to_tokens(self, ids):
                self.call_count += 1
                return ["[CLS]", "Hello", "world", "[SEP]"]

        return MockTokenizer()

    @pytest.fixture
    def visualiser(self, mock_model, mock_tokenizer):
        """An AttentionVisualiser wired up with the mock model/tokenizer."""
        return AttentionVisualiser(
            model=mock_model,
            tokenizer=mock_tokenizer,
        )

    @pytest.fixture
    def mock_encoded_input(self):
        """A minimal BatchEncoding stand-in: dict-style access plus the
        mapping protocol needed for ** unpacking into the model call.
        Equality is identity-based, matching how the visualiser's cache
        is expected to distinguish inputs."""
        class MockEncodedInput:
            def __init__(self):
                self.data = {"input_ids": torch.tensor([[101, 7592, 2088, 102]])}

            def __getitem__(self, key):
                return self.data.get(key)

            def __eq__(self, other):
                # identity-based equality: two distinct encodings never
                # compare equal, so the cache only hits on the same object
                return id(self) == id(other)

            # Add these methods to support ** unpacking
            def keys(self):
                return self.data.keys()

            def __iter__(self):
                return iter(self.data)

            def get(self, key, default=None):
                return self.data.get(key, default)

        return MockEncodedInput()

    @pytest.fixture
    def mock_attention_data(self):
        # Create mock attention data: (batch_size, num_heads, seq_len, seq_len)
        # Shape: (1, 12, 4, 4) - 1 batch, 12 attention heads, sequence length 4
        return torch.ones((1, 12, 4, 4)) * 0.25

    def test_init(self, mock_model, mock_tokenizer):
        """Test initialization of AttentionVisualiser."""
        visualiser = AttentionVisualiser(mock_model, mock_tokenizer)

        assert visualiser.model == mock_model
        assert visualiser.tokenizer == mock_tokenizer
        # default config should be installed when none is supplied
        assert visualiser.config is not None
        # cache starts empty
        assert visualiser.current_input is None
        assert visualiser.cache is None

    def test_id_to_tokens(self, visualiser, mock_encoded_input):
        """Test the id_to_tokens method."""
        tokens = visualiser.id_to_tokens(mock_encoded_input)

        # exactly one tokenizer round-trip, returning the canned token list
        assert visualiser.tokenizer.call_count == 1
        assert tokens == ["[CLS]", "Hello", "world", "[SEP]"]

    def test_compute_attentions_new_input(
        self, visualiser, mock_encoded_input, mock_attention_data
    ):
        """Test compute_attentions with new input."""

        # Setup mock return value for model call
        class MockOutput:
            def __init__(self, attn_data):
                self.attentions = attn_data

        # Add the output attribute to the model
        output = MockOutput(mock_attention_data)
        visualiser.model.output = output

        # Initialize call_count if not present
        if not hasattr(visualiser.model, "call_count"):
            visualiser.model.call_count = 0
        if not hasattr(visualiser.model, "last_kwargs"):
            visualiser.model.last_kwargs = {}

        # Call the method
        attentions = visualiser.compute_attentions(mock_encoded_input)

        # Verify the model was called
        assert visualiser.model.call_count > 0
        # the visualiser must request attention outputs from the model
        assert visualiser.model.last_kwargs.get("output_attentions")

        # Verify the return value and cache update
        assert attentions is mock_attention_data
        assert visualiser.current_input == mock_encoded_input
        assert visualiser.cache is mock_attention_data

    def test_compute_attentions_cached_input(
        self, visualiser, mock_encoded_input, mock_attention_data
    ):
        """Test compute_attentions with cached input."""
        # Setup cache: pretend this input was already computed
        visualiser.current_input = mock_encoded_input
        visualiser.cache = mock_attention_data
        visualiser.model.call_count = 0

        # Call the method
        attentions = visualiser.compute_attentions(mock_encoded_input)

        # Verify model wasn't called and cached data was returned
        assert visualiser.model.call_count == 0
        assert attentions is mock_attention_data

    def test_get_attention_vector_mean(self, visualiser):
        """Test get_attention_vector_mean method."""
        # Create test data: shape (2, 2, 2)
        test_data = torch.tensor([[[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6], [0.7, 0.8]]])

        # Expected result: mean along axis 0 (the default)
        expected = np.array([[0.3, 0.4], [0.5, 0.6]])

        # Call the method
        result = visualiser.get_attention_vector_mean(test_data)

        # Verify (allclose: float32 tensor mean vs float64 numpy literal)
        np.testing.assert_allclose(result, expected)

    def test_get_attention_vector_mean_different_axis(self, visualiser):
        """Test get_attention_vector_mean with a different axis."""
        # Create test data: shape (2, 2, 2)
        test_data = torch.tensor([[[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6], [0.7, 0.8]]])

        # Expected result: mean along axis 1
        expected = np.array([[0.2, 0.3], [0.6, 0.7]])

        # Call the method
        result = visualiser.get_attention_vector_mean(test_data, axis=1)

        # Verify
        np.testing.assert_allclose(result, expected)
|