attention-visualiser 0.1.0__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,38 @@
1
+ # This workflow will install Python dependencies, run tests and lint with a single version of Python
2
+ # For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python
3
+
4
+ name: tests
5
+
6
+ on:
7
+ push:
8
+ branches: [ "main" ]
9
+ pull_request:
10
+ branches: [ "main" ]
11
+
12
+ permissions:
13
+ contents: read
14
+
15
+ jobs:
16
+ build:
17
+ runs-on: ubuntu-latest
18
+
19
+ steps:
20
+ - uses: actions/checkout@v4
21
+ - name: Set up Python 3.12
22
+ uses: actions/setup-python@v3
23
+ with:
24
+ python-version: "3.12"
25
+ - name: Install uv
26
+ uses: astral-sh/setup-uv@v5
27
+ with:
28
+ version: "0.6.16"
29
+
30
+ - name: Install dependencies
31
+ run: |
32
+ uv sync
33
+ - name: Test with pytest
34
+ run: |
35
+ source .venv/bin/activate
36
+ uv run pytest --cov --cov-report=xml
37
+ - name: Upload coverage reports to Codecov
38
+ uses: codecov/codecov-action@v5
@@ -0,0 +1,23 @@
1
+ # Python-generated files
2
+ __pycache__/
3
+ *.py[oc]
4
+ build/
5
+ dist/
6
+ wheels/
7
+ *.egg-info
8
+ .mypy_cache/
9
+
10
+ # Virtual environments
11
+ .venv
12
+
13
+ # IDE and Editor specific files
14
+ .idea/
15
+ .vscode/
16
+ .code/
17
+
18
+ # pytest
19
+ .coverage
20
+ coverage.xml
21
+
22
+ # macos
23
+ .DS_Store
@@ -0,0 +1,10 @@
1
+ repos:
2
+ - repo: https://github.com/astral-sh/ruff-pre-commit
3
+ # Ruff version.
4
+ rev: v0.11.6
5
+ hooks:
6
+ # Run the linter.
7
+ - id: ruff
8
+ args: [ --fix ]
9
+ # Run the formatter.
10
+ - id: ruff-format
@@ -0,0 +1 @@
1
+ 3.12
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2025 Shawon Ashraf
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,60 @@
1
+ Metadata-Version: 2.4
2
+ Name: attention-visualiser
3
+ Version: 0.1.0
4
+ Summary: a module to visualise attention layer activations from transformer based models from huggingface
5
+ License-File: LICENSE
6
+ Requires-Python: >=3.11
7
+ Requires-Dist: einops>=0.8.1
8
+ Requires-Dist: loguru>=0.7.3
9
+ Requires-Dist: seaborn>=0.13.2
10
+ Requires-Dist: transformers>=4.51.3
11
+ Description-Content-Type: text/markdown
12
+
13
+ # attention-visualiser
14
+
15
+ a module to visualise attention layer activations from transformer based models from huggingface
16
+
17
+ ## installation
18
+
19
+ ```bash
20
+ pip install git+https://codeberg.org/rashomon/attention-visualiser
21
+ ```
22
+
23
+ ## usage
24
+
25
+ ```python
26
+ from attention_visualiser import AttentionVisualiser
27
+ from transformers import AutoModel, AutoTokenizer
28
+
29
+ # visualising activations from gpt
30
+ model_name = "openai-community/openai-gpt"
31
+
32
+ model = AutoModel.from_pretrained(model_name)
33
+ model.eval()
34
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
35
+
36
+ text = "Look on my Works, ye Mighty, and despair!"
37
+ encoded_inputs = tokenizer.encode_plus(text, truncation=True, return_tensors="pt")
38
+
39
+ visualiser = AttentionVisualiser(model, tokenizer)
40
+
41
+ # visualise from the first attn layer
42
+ visualiser.visualise_attn_layer(0, encoded_inputs)
43
+
44
+ ```
45
+
46
+
47
+ ## local dev
48
+
49
+ ```bash
50
+ # env setup
51
+
52
+ uv sync
53
+ source .venv/bin/activate
54
+
55
+ # tests
56
+ uv run pytest
57
+
58
+ # tests with coverage
59
+ uv run pytest --cov --cov-report=xml
60
+ ```
@@ -0,0 +1,48 @@
1
+ # attention-visualiser
2
+
3
+ a module to visualise attention layer activations from transformer based models from huggingface
4
+
5
+ ## installation
6
+
7
+ ```bash
8
+ pip install git+https://codeberg.org/rashomon/attention-visualiser
9
+ ```
10
+
11
+ ## usage
12
+
13
+ ```python
14
+ from attention_visualiser import AttentionVisualiser
15
+ from transformers import AutoModel, AutoTokenizer
16
+
17
+ # visualising activations from gpt
18
+ model_name = "openai-community/openai-gpt"
19
+
20
+ model = AutoModel.from_pretrained(model_name)
21
+ model.eval()
22
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
23
+
24
+ text = "Look on my Works, ye Mighty, and despair!"
25
+ encoded_inputs = tokenizer.encode_plus(text, truncation=True, return_tensors="pt")
26
+
27
+ visualiser = AttentionVisualiser(model, tokenizer)
28
+
29
+ # visualise from the first attn layer
30
+ visualiser.visualise_attn_layer(0, encoded_inputs)
31
+
32
+ ```
33
+
34
+
35
+ ## local dev
36
+
37
+ ```bash
38
+ # env setup
39
+
40
+ uv sync
41
+ source .venv/bin/activate
42
+
43
+ # tests
44
+ uv run pytest
45
+
46
+ # tests with coverage
47
+ uv run pytest --cov --cov-report=xml
48
+ ```
@@ -0,0 +1,3 @@
1
+ from .pt import AttentionVisualiserPytorch as AttentionVisualiser
2
+
3
+ __all__ = ["AttentionVisualiser"]
@@ -0,0 +1,160 @@
1
+ import torch
2
+ import matplotlib.pyplot as plt
3
+ import seaborn as sns
4
+ from transformers import AutoTokenizer, AutoModel, FlaxAutoModel
5
+ from transformers import BatchEncoding
6
+ from loguru import logger
7
+ from einops import rearrange
8
+ from abc import ABC, abstractmethod
9
+ import numpy as np
10
+ from typing import Optional
11
+
12
+
13
class BaseAttentionVisualiser(ABC):
    """Base abstract class for visualising attention weights in transformer models.

    This class provides the foundation for visualising attention weights from
    different transformer model implementations. Concrete subclasses must
    implement methods for computing attention values and processing attention
    vectors.

    Attributes:
        model: A transformer model from the Hugging Face library
        tokenizer: A tokenizer matching the model
        config: Dictionary containing visualization configuration parameters
        current_input: Last input passed to ``compute_attentions`` (cache key)
        cache: Attention tuple computed for ``current_input``
    """

    def __init__(
        self,
        model: AutoModel | FlaxAutoModel,
        tokenizer: AutoTokenizer,
        config: Optional[dict] = None,
    ) -> None:
        """Initialize the attention visualizer with a model and tokenizer.

        Args:
            model: A transformer model from Hugging Face (PyTorch or Flax)
            tokenizer: A tokenizer matching the model
            config: Optional dictionary with visualization parameters.
                Default parameters include:
                - figsize: Tuple specifying figure dimensions
                - cmap: Colormap for the heatmap
                - annot: Whether to annotate heatmap cells with values
                - xlabel: Label for x-axis
                - ylabel: Label for y-axis
        """
        self.model = model
        self.tokenizer = tokenizer

        logger.info(f"Model config: {self.model.config}")  # type: ignore

        if not config:
            self.config = {
                "figsize": (15, 15),
                "cmap": "viridis",
                "annot": True,
                "xlabel": "",
                "ylabel": "",
            }
            logger.info(f"Setting default visualiser config: {self.config}")
        else:
            logger.info(f"Visualiser config: {config}")
            self.config = config

        # a cache for storing already computed attention vectors
        # these need to be updated by the `compute_attentions` method
        self.current_input = None
        self.cache = None

    def id_to_tokens(self, encoded_input: BatchEncoding) -> list[str]:
        """Convert token IDs to readable token strings.

        Args:
            encoded_input: The encoded input from the tokenizer

        Returns:
            List of token strings corresponding to the input IDs
        """
        # index [0] selects the single sequence in the batch
        tokens = self.tokenizer.convert_ids_to_tokens(encoded_input["input_ids"][0])  # type: ignore
        return tokens

    @abstractmethod
    def compute_attentions(self, encoded_input: BatchEncoding) -> tuple:
        """Compute attention weights for the given input.

        This method must be implemented by concrete subclasses to compute
        attention weights specific to the model implementation.

        Args:
            encoded_input: The encoded input from the tokenizer

        Returns:
            A tuple containing attention weights
        """
        pass

    @abstractmethod
    def get_attention_vector_mean(
        self, attention: torch.Tensor, axis: int = 0
    ) -> np.ndarray:
        """Calculate mean of attention vectors along specified axis.

        This method must be implemented by concrete subclasses to handle
        either PyTorch or JAX tensors appropriately.

        Args:
            attention: Attention tensor from the model
            axis: Axis along which to compute the mean (default: 0)

        Returns:
            NumPy array of mean attention values
        """
        pass

    def visualise_attn_layer(self, idx: int, encoded_input: BatchEncoding) -> None:
        """Visualize attention weights for a specific layer.

        Creates a heatmap visualization of the attention weights for the
        specified layer index.

        Args:
            idx: Index of the attention layer to visualize.
                Negative indices count from the end (-1 is the last layer).
            encoded_input: The encoded input from the tokenizer

        Raises:
            AssertionError: If idx is outside the range of available attention
                layers (valid range is [-n_layers, n_layers)).
        """
        tokens = self.id_to_tokens(encoded_input)

        attentions = self.compute_attentions(encoded_input)
        n_attns = len(attentions)

        # idx must not exceed the number of attention outputs; also reject
        # out-of-range negative indices (e.g. -n_attns - 1), which previously
        # slipped past the upper-bound-only check and raised IndexError below
        assert -n_attns <= idx < n_attns, (
            f"index must be less than the number of attention outputs in the model, which is: {n_attns}"
        )

        # normalise negative indices so the plot title shows the actual
        # layer index instead of e.g. -1
        if idx < 0:
            idx = n_attns + idx

        # get rid of the additional batch dimension since single input
        attention = rearrange(attentions[idx], "1 a b c -> a b c")
        # take mean over dim 0 (the attention heads)
        attention = self.get_attention_vector_mean(attention)

        plt.figure(figsize=self.config.get("figsize"))
        sns.heatmap(
            attention,
            cmap=self.config.get("cmap"),
            annot=self.config.get("annot"),
            xticklabels=tokens,
            yticklabels=tokens,
        )

        plt.title(f"Attention Weights for Layer idx: {idx}")
        plt.xlabel(self.config.get("xlabel"))  # type: ignore
        plt.ylabel(self.config.get("ylabel"))  # type: ignore
        plt.show()
@@ -0,0 +1,76 @@
1
+ import torch
2
+ from transformers import AutoTokenizer, AutoModel
3
+ from transformers import BatchEncoding
4
+ from attention_visualiser.base import BaseAttentionVisualiser
5
+ import numpy as np
6
+ from typing import Optional
7
+
8
+
9
class AttentionVisualiserPytorch(BaseAttentionVisualiser):
    """Attention visualizer for PyTorch-based transformer models.

    This class implements the abstract methods from BaseAttentionVisualiser
    specifically for models implemented in PyTorch. It handles the extraction
    and processing of attention weights from PyTorch transformer models.

    Attributes:
        model: A PyTorch-based transformer model from Hugging Face
        tokenizer: A tokenizer matching the model
        config: Dictionary containing visualization configuration parameters
    """

    def __init__(
        self, model: AutoModel, tokenizer: AutoTokenizer, config: Optional[dict] = None
    ) -> None:
        """Initialize the PyTorch-specific attention visualizer.

        Args:
            model: A PyTorch-based transformer model from Hugging Face
            tokenizer: A tokenizer matching the model
            config: Optional dictionary with visualization parameters
        """
        super().__init__(model, tokenizer, config)

    def compute_attentions(self, encoded_input: BatchEncoding) -> tuple:
        """Compute attention weights for the given input using a PyTorch model.

        Runs the PyTorch model in inference mode with the output_attentions
        flag set to True and extracts the attention weights from the model
        output. Results for the most recent input object are cached.

        Args:
            encoded_input: The encoded input from the tokenizer

        Returns:
            A tuple containing attention weights from all layers of the model
        """
        # Use identity, not equality, as the cache key: BatchEncoding is
        # dict-like, so `==` compares tensor values element-wise and taking
        # the truth value of the resulting multi-element tensor raises
        # "Boolean value of Tensor is ambiguous". Identity still caches
        # repeated calls with the same input object.
        if encoded_input is self.current_input:
            # return from cache
            return self.cache

        # else recompute; no_grad avoids building an autograd graph
        with torch.no_grad():
            output = self.model(**encoded_input, output_attentions=True)  # type: ignore

        attentions = output.attentions

        # update cache and current input
        self.current_input = encoded_input
        self.cache = attentions

        return attentions

    def get_attention_vector_mean(
        self, attention: torch.Tensor, axis: int = 0
    ) -> np.ndarray:
        """Calculate mean of PyTorch attention vectors along specified axis.

        Computes the mean of the attention tensor and converts it to a NumPy
        array (detached and moved to CPU first, so it is safe for plotting).

        Args:
            attention: PyTorch tensor containing attention weights
            axis: Axis along which to compute the mean (default: 0)

        Returns:
            NumPy array of mean attention values
        """
        return torch.mean(attention, dim=axis).detach().cpu().numpy()
@@ -0,0 +1,5 @@
1
+ coverage:
2
+ status:
3
+ project:
4
+ default:
5
+ target: 80%
@@ -0,0 +1,18 @@
1
from attention_visualiser import AttentionVisualiser
from transformers import AutoModel, AutoTokenizer

if __name__ == "__main__":
    # visualising activations from gpt (GPT-1 community checkpoint)
    model_name = "openai-community/openai-gpt"

    # download/load the pretrained model; eval() switches off dropout so the
    # attention weights produced during the forward pass are deterministic
    model = AutoModel.from_pretrained(model_name)
    model.eval()
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    text = "Look on my Works, ye Mighty, and despair!"
    # return_tensors="pt" yields a BatchEncoding of PyTorch tensors,
    # matching the PyTorch-backed visualiser below
    encoded_inputs = tokenizer.encode_plus(text, truncation=True, return_tensors="pt")

    visualiser = AttentionVisualiser(model, tokenizer)

    # visualise from the first attn layer (index 0); opens a matplotlib window
    visualiser.visualise_attn_layer(0, encoded_inputs)
@@ -0,0 +1,28 @@
1
+ [project]
2
+ name = "attention-visualiser"
3
+ version = "0.1.0"
4
+ description = "a module to visualise attention layer activations from transformer based models from huggingface"
5
+ readme = "README.md"
6
+ requires-python = ">=3.11"
7
+ dependencies = [
8
+ "einops>=0.8.1",
9
+ "loguru>=0.7.3",
10
+ "seaborn>=0.13.2",
11
+ "transformers>=4.51.3",
12
+ ]
13
+
14
+ [dependency-groups]
15
+ dev = [
16
+ "pre-commit>=4.2.0",
17
+ "pytest>=8.3.5",
18
+ "pytest-cov>=6.1.1",
19
+ "ruff>=0.11.6",
20
+ "torch>=2.7.0",
21
+ ]
22
+
23
+ [build-system]
24
+ requires = ["hatchling"]
25
+ build-backend = "hatchling.build"
26
+
27
+ [tool.hatch.build.targets.wheel]
28
+ packages = ["attention-visualiser"]
File without changes
@@ -0,0 +1,167 @@
1
+ import pytest
2
+ import numpy as np
3
+ import torch
4
+ from attention_visualiser import AttentionVisualiser
5
+
6
+
7
+ class TestAttentionVisualiser:
8
+ @pytest.fixture
9
+ def mock_model(self):
10
+ class MockConfig:
11
+ num_attention_heads = 12
12
+ num_hidden_layers = 12
13
+
14
+ class MockModel:
15
+ def __init__(self):
16
+ self.config = MockConfig()
17
+ self.call_count = 0
18
+
19
+ def __call__(self, **kwargs):
20
+ self.call_count += 1
21
+ self.last_kwargs = kwargs
22
+ return self.output # type: ignore
23
+
24
+ model = MockModel()
25
+ return model
26
+
27
+ @pytest.fixture
28
+ def mock_tokenizer(self):
29
+ class MockTokenizer:
30
+ def __init__(self):
31
+ self.call_count = 0
32
+
33
+ def convert_ids_to_tokens(self, ids):
34
+ self.call_count += 1
35
+ return ["[CLS]", "Hello", "world", "[SEP]"]
36
+
37
+ return MockTokenizer()
38
+
39
+ @pytest.fixture
40
+ def visualiser(self, mock_model, mock_tokenizer):
41
+ return AttentionVisualiser(
42
+ model=mock_model,
43
+ tokenizer=mock_tokenizer,
44
+ )
45
+
46
+ @pytest.fixture
47
+ def mock_encoded_input(self):
48
+ class MockEncodedInput:
49
+ def __init__(self):
50
+ self.data = {"input_ids": torch.tensor([[101, 7592, 2088, 102]])}
51
+
52
+ def __getitem__(self, key):
53
+ return self.data.get(key)
54
+
55
+ def __eq__(self, other):
56
+ return id(self) == id(other)
57
+
58
+ # Add these methods to support ** unpacking
59
+ def keys(self):
60
+ return self.data.keys()
61
+
62
+ def __iter__(self):
63
+ return iter(self.data)
64
+
65
+ def get(self, key, default=None):
66
+ return self.data.get(key, default)
67
+
68
+ return MockEncodedInput()
69
+
70
+ @pytest.fixture
71
+ def mock_attention_data(self):
72
+ # Create mock attention data: (batch_size, num_heads, seq_len, seq_len)
73
+ # Shape: (1, 12, 4, 4) - 1 batch, 12 attention heads, sequence length 4
74
+ return torch.ones((1, 12, 4, 4)) * 0.25
75
+
76
+ def test_init(self, mock_model, mock_tokenizer):
77
+ """Test initialization of AttentionVisualiser."""
78
+ visualiser = AttentionVisualiser(mock_model, mock_tokenizer)
79
+
80
+ assert visualiser.model == mock_model
81
+ assert visualiser.tokenizer == mock_tokenizer
82
+ assert visualiser.config is not None
83
+ assert visualiser.current_input is None
84
+ assert visualiser.cache is None
85
+
86
+ def test_id_to_tokens(self, visualiser, mock_encoded_input):
87
+ """Test the id_to_tokens method."""
88
+ tokens = visualiser.id_to_tokens(mock_encoded_input)
89
+
90
+ assert visualiser.tokenizer.call_count == 1
91
+ assert tokens == ["[CLS]", "Hello", "world", "[SEP]"]
92
+
93
+ def test_compute_attentions_new_input(
94
+ self, visualiser, mock_encoded_input, mock_attention_data
95
+ ):
96
+ """Test compute_attentions with new input."""
97
+
98
+ # Setup mock return value for model call
99
+ class MockOutput:
100
+ def __init__(self, attn_data):
101
+ self.attentions = attn_data
102
+
103
+ # Add the output attribute to the model
104
+ output = MockOutput(mock_attention_data)
105
+ visualiser.model.output = output
106
+
107
+ # Initialize call_count if not present
108
+ if not hasattr(visualiser.model, "call_count"):
109
+ visualiser.model.call_count = 0
110
+ if not hasattr(visualiser.model, "last_kwargs"):
111
+ visualiser.model.last_kwargs = {}
112
+
113
+ # Call the method
114
+ attentions = visualiser.compute_attentions(mock_encoded_input)
115
+
116
+ # Verify the model was called
117
+ assert visualiser.model.call_count > 0
118
+ assert visualiser.model.last_kwargs.get("output_attentions")
119
+
120
+ # Verify the return value and cache update
121
+ assert attentions is mock_attention_data
122
+ assert visualiser.current_input == mock_encoded_input
123
+ assert visualiser.cache is mock_attention_data
124
+
125
+ def test_compute_attentions_cached_input(
126
+ self, visualiser, mock_encoded_input, mock_attention_data
127
+ ):
128
+ """Test compute_attentions with cached input."""
129
+ # Setup cache
130
+ visualiser.current_input = mock_encoded_input
131
+ visualiser.cache = mock_attention_data
132
+ visualiser.model.call_count = 0
133
+
134
+ # Call the method
135
+ attentions = visualiser.compute_attentions(mock_encoded_input)
136
+
137
+ # Verify model wasn't called and cached data was returned
138
+ assert visualiser.model.call_count == 0
139
+ assert attentions is mock_attention_data
140
+
141
+ def test_get_attention_vector_mean(self, visualiser):
142
+ """Test get_attention_vector_mean method."""
143
+ # Create test data
144
+ test_data = torch.tensor([[[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6], [0.7, 0.8]]])
145
+
146
+ # Expected result: mean along axis 0
147
+ expected = np.array([[0.3, 0.4], [0.5, 0.6]])
148
+
149
+ # Call the method
150
+ result = visualiser.get_attention_vector_mean(test_data)
151
+
152
+ # Verify
153
+ np.testing.assert_allclose(result, expected)
154
+
155
+ def test_get_attention_vector_mean_different_axis(self, visualiser):
156
+ """Test get_attention_vector_mean with a different axis."""
157
+ # Create test data
158
+ test_data = torch.tensor([[[0.1, 0.2], [0.3, 0.4]], [[0.5, 0.6], [0.7, 0.8]]])
159
+
160
+ # Expected result: mean along axis 1
161
+ expected = np.array([[0.2, 0.3], [0.6, 0.7]])
162
+
163
+ # Call the method
164
+ result = visualiser.get_attention_vector_mean(test_data, axis=1)
165
+
166
+ # Verify
167
+ np.testing.assert_allclose(result, expected)