cross-domain-saliency-maps 0.0.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cross_domain_saliency_maps/__init__.py +0 -0
- cross_domain_saliency_maps/tensorflow_ig/__init__.py +0 -0
- cross_domain_saliency_maps/tensorflow_ig/cross_domain_integrated_gradients.py +152 -0
- cross_domain_saliency_maps/tensorflow_ig/domain_transforms.py +207 -0
- cross_domain_saliency_maps/torch_ig/__init__.py +0 -0
- cross_domain_saliency_maps/torch_ig/cross_domain_integrated_gradients.py +162 -0
- cross_domain_saliency_maps/torch_ig/domain_transforms.py +208 -0
- cross_domain_saliency_maps-0.0.4.dist-info/METADATA +107 -0
- cross_domain_saliency_maps-0.0.4.dist-info/RECORD +11 -0
- cross_domain_saliency_maps-0.0.4.dist-info/WHEEL +4 -0
- cross_domain_saliency_maps-0.0.4.dist-info/licenses/LICENSE +674 -0
@@ -0,0 +1,208 @@
|
|
1
|
+
import abc
|
2
|
+
import torch
|
3
|
+
import numpy as np
|
4
|
+
|
5
|
+
class DomainBase:
    """ Defines the base class for a domain transformation.

    A domain describes a transform from the original input domain into a
    target domain, together with its inverse. Subclasses are expected to
    implement ``set_coefficients``, ``forward_transform`` and
    ``inverse_transform``.

    NOTE: ``DomainBase`` does not derive from ``abc.ABC``, so the
    ``@abc.abstractmethod`` markers below are not enforced when an
    instance is created; a method that is not overridden simply
    returns ``None``.

    Attributes:
        device (torch.device): The device used to run the model.
        dtype (torch.dtype): The dtype input samples are cast to when
            subclasses convert them to torch.Tensor.
        coefficients (torch.Tensor): The input samples expressed
            in the target domain. ``None`` until ``set_coefficients``
            has been called.
        baseline_coefficients (torch.Tensor): The baseline samples expressed
            in the target domain. ``None`` until ``set_coefficients``
            has been called.
    """
    def __init__(self, device: torch.device, dtype = torch.float32):
        """ Initialize the Domain.

        Args:
            device (torch.device): The device used to run the model.
            dtype (torch.dtype): The dtype input samples are cast to when
                they are converted to torch.Tensor.
        """
        self.device = device
        self.dtype = dtype
        # Filled in by the subclass implementation of set_coefficients().
        self.coefficients = None
        self.baseline_coefficients = None

    def get_coefficients(self):
        """ Return the coefficients of the input sample in the target domain.
        """
        return self.coefficients

    def get_coefficient_baseline(self):
        """ Return the coefficients of the baseline sample in the target domain.
        """
        return self.baseline_coefficients

    @abc.abstractmethod
    def set_coefficients(self, x, x_baseline):
        """ Sets the coefficients depending on the transform
        used. Should be implemented by class inheriting DomainBase.

        Args:
            x (np.array): The input sample.
            x_baseline (np.array): The baseline sample.
        """
        return

    @abc.abstractmethod
    def forward_transform(self, x):
        """ Performs the forward transform, transforming the input
        sample x into the corresponding target domain.

        Args:
            x : The input sample.

        Returns:
            The sample expressed in the target domain (subclass-defined).
        """
        return

    @abc.abstractmethod
    def inverse_transform(self, x_input):
        """ Performs the inverse transform, transforming the
        sample x_input from the target domain back to the original one.

        Args:
            x_input : The sample expressed in the target domain.

        Returns:
            The sample expressed in the original domain (subclass-defined).
        """
        return
|
65
|
+
|
66
|
+
class FourierDomain(DomainBase):
    """ Domain implementation for the Fourier transform, mapping
    time-domain samples into the frequency domain.

    The forward transform is a complex FFT along ``time_dimension``.
    The inverse transform is the inverse FFT, of which only the real
    part is returned, cast back to ``dtype``.
    """
    def __init__(self, device, dtype = torch.float32, time_dimension = -1):
        """ Initialize the Fourier Domain.

        Args:
            device (torch.device): The device used to run the model.
            dtype: The type of the input features. It also selects the
                precision of the complex frequency coefficients:
                ``torch.float64`` maps to ``torch.complex128``, and every
                other dtype maps to ``torch.complex64``.
            time_dimension: The dimension in the input which
                corresponds to the time-points.
        """
        super().__init__(device = device, dtype = dtype)
        self.time_dimension = time_dimension

    def _complex_dtype(self):
        """ Return the complex dtype matching the precision of ``self.dtype``.
        """
        # complex128 keeps double-precision inputs exact; all other real
        # dtypes keep the original complex64 behaviour.
        if self.dtype == torch.float64:
            return torch.complex128
        return torch.complex64

    def forward_transform(self, x: torch.Tensor):
        """ Implementation of the forward transform, transforming the input
        time sample to the corresponding frequency domain sample.

        Args:
            x (torch.Tensor): Input time-domain sample.

        Returns:
            torch.Tensor: Complex FFT of ``x`` along ``time_dimension``.
        """
        return torch.fft.fft(x.type(self._complex_dtype()), dim = self.time_dimension)

    def set_coefficients(self, x: np.array, x_baseline: np.array):
        """ Sets the frequency coefficients transforming the input and
        baseline samples into the frequency domain.

        Args:
            x (np.array): The input sample in time-domain.
            x_baseline (np.array): The baseline sample in time-domain.
        """
        x_tf = torch.from_numpy(x).type(self.dtype).to(self.device)
        x_baseline_tf = torch.from_numpy(x_baseline).type(self.dtype).to(self.device)

        self.coefficients = self.forward_transform(x_tf)
        self.baseline_coefficients = self.forward_transform(x_baseline_tf)

    def inverse_transform(self, x_input: torch.Tensor):
        """ Inverse transform, transforming the frequency domain input
        x_input points back into the time domain.

        Args:
            x_input (torch.Tensor): The (complex) frequency domain input.

        Returns:
            torch.Tensor: Real part of the inverse FFT along
            ``time_dimension``, with dtype ``self.dtype``.
        """
        # Take the real part explicitly: casting a complex tensor straight to
        # a real dtype also drops the imaginary part, but emits a UserWarning
        # on every call (i.e. on every Integrated Gradients step).
        return torch.fft.ifft(x_input, dim = self.time_dimension).real.type(self.dtype)
|
112
|
+
|
113
|
+
class ICADomain(DomainBase):
    """ Implements the Domain for the Independent Component Analysis
    (ICA) decomposition.

    We consider the independent channels to form the basis of
    the input. The mixing matrix forms the actual input coefficients.
    This way, ICA IG expresses significance of each independent component.
    """
    def __init__(self, device, ica, dtype = torch.float32, channel_permutation = (0, 2, 1)):
        """ Initialize the ICA Domain.

        Args:
            device (torch.device): The device used to run the model.
            ica: A fitted ICA object exposing the
                ``sklearn.decomposition.FastICA`` interface. This class uses
                its ``transform`` method and its ``mixing_`` and ``mean_``
                attributes.
            dtype: The type used for the tensors built from the ICA.
            channel_permutation: Stored on the instance; it is not used
                by any method of this class.
        """
        super().__init__(device = device, dtype = dtype)
        self.channel_permutation = channel_permutation
        self.ica = ica
        # Populated by set_coefficients(). They are declared here so that the
        # attributes always exist and inverse_transform can report misuse.
        self.basis = None
        self.mean = None

    def forward_transform(self, x: np.array):
        """ Forward transform, transforms the input channels into
        independent channels.

        Args:
            x (np.array): Input laid out as (n_samples, n_features), as
                expected by ``ica.transform``.

        Returns:
            np.array: The independent sources, of shape (n_samples, n_components).
        """
        X = self.ica.transform(x)
        return X

    def set_coefficients(self, x:np.array, x_baseline: np.array):
        """ Sets the coefficients in the ICA space. We consider the
        independent channels to form the basis of the input. The
        mixing matrix forms the actual input coefficients. This way
        ICA IG expresses significance of each independent component.

        Since the basis of the transform are the independent components,
        we only consider zero baseline.

        Args:
            x (np.array): A batch holding a single input sample. Only
                ``x[0]`` is used, and ``x[0].T`` is passed to
                ``ica.transform``, so it must be laid out as
                (n_samples, n_features) for the fitted ICA.
            x_baseline (np.array): Ignored; the baseline coefficients are
                always zero.
        """
        # Independent sources of the sample: (n_samples, n_components).
        self.basis = torch.from_numpy(self.forward_transform(x[0, ...].T)).type(self.dtype).to(self.device)

        # sklearn's mixing_ is (n_features, n_components), so the coefficients
        # are (1, n_components, n_features).
        self.coefficients = torch.from_numpy(self.ica.mixing_.T).type(self.dtype).to(self.device)[None, ...]
        self.baseline_coefficients = torch.zeros_like(self.coefficients)
        self.mean = torch.from_numpy(self.ica.mean_).type(self.dtype).to(self.device)

    def inverse_transform(self, x_input: torch.Tensor):
        """ Inverses the ICA transform given the coefficients
        x_input and the basis stored in the domain.

        Mirrors the ICA reconstruction ``sources @ mixing_.T + mean_``.

        Args:
            x_input (torch.Tensor): Batched coefficients shaped like
                ``self.coefficients``, i.e. (batch, n_components, n_features).

        Returns:
            torch.Tensor: The reconstruction of shape
            (batch, n_features, n_samples).

        Raises:
            RuntimeError: If ``set_coefficients`` has not been called yet.
        """
        if self.basis is None or self.mean is None:
            raise RuntimeError("ICADomain.set_coefficients must be called before inverse_transform")
        # (n_samples, n_components) @ (batch, n_components, n_features)
        # -> (batch, n_samples, n_features); then swap the last two axes.
        return torch.transpose(torch.matmul(self.basis, x_input) + self.mean, 1, 2)
|
164
|
+
|
165
|
+
class TimeDomain(DomainBase):
    """ Domain implementation for the Time transform. This is the original input domain,
    no transform takes place.
    """
    def __init__(self, device, dtype = torch.float32, time_dimension = -1):
        """ Initialize the Time Domain.

        Args:
            device (torch.device): The device used to run the model.
            dtype: The type of the input features.
            time_dimension: The dimension in the input which
                corresponds to the time-points. Stored on the instance;
                it is not used by any method of this class.
        """
        super().__init__(device = device, dtype = dtype)
        self.time_dimension = time_dimension

    def forward_transform(self, x: torch.Tensor):
        """ Implementation of the forward transform. No transform takes place.

        Args:
            x (torch.Tensor): Input time-domain sample.

        Returns:
            torch.Tensor: ``x`` itself, unchanged.
        """
        return x

    def set_coefficients(self, x: np.array, x_baseline: np.array):
        """ Sets the coefficients to the input and baseline samples
        themselves, converted to tensors of ``self.dtype`` on ``self.device``.

        Args:
            x (np.array): The input sample in time-domain.
            x_baseline (np.array): The baseline sample in time-domain.
        """
        x_tf = torch.from_numpy(x).type(self.dtype).to(self.device)
        x_baseline_tf = torch.from_numpy(x_baseline).type(self.dtype).to(self.device)

        self.coefficients = self.forward_transform(x_tf)
        self.baseline_coefficients = self.forward_transform(x_baseline_tf)

    def inverse_transform(self, x_input: torch.Tensor):
        """ Inverse transform. No transform takes place.

        Args:
            x_input (torch.Tensor): The time domain input.

        Returns:
            torch.Tensor: ``x_input`` itself, unchanged.
        """
        return x_input
|
@@ -0,0 +1,107 @@
|
|
1
|
+
Metadata-Version: 2.4
|
2
|
+
Name: cross_domain_saliency_maps
|
3
|
+
Version: 0.0.4
|
4
|
+
Summary: A pytorch/tensorflow library for generating Cross-Domain saliency maps.
|
5
|
+
Project-URL: Homepage, https://github.com/esl-epfl/cross-domain-saliency-maps
|
6
|
+
Author-email: Christodoulos Kechris <christodoulos.kechris@epfl.ch>
|
7
|
+
License-File: LICENSE
|
8
|
+
Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3)
|
9
|
+
Classifier: Operating System :: OS Independent
|
10
|
+
Classifier: Programming Language :: Python :: 3
|
11
|
+
Requires-Python: >=3.10.16
|
12
|
+
Requires-Dist: tensorflow<=2.19,>=2.13.0
|
13
|
+
Requires-Dist: torch<=2.7,>=2.6.0
|
14
|
+
Requires-Dist: tqdm==4.67.1
|
15
|
+
Description-Content-Type: text/markdown
|
16
|
+
|
17
|
+
# Timeseries Saliency Maps: Explaining models across multiple domains
|
18
|
+
|
19
|
+
Official PyTorch/TensorFlow implementation of Cross-Domain Saliency Maps.
|
20
|
+
The method does not require any model retraining or modifications.
|
21
|
+
|
22
|
+
[arXiv:2505.13100](https://arxiv.org/abs/2505.13100)
|
23
|
+
|
24
|
+
<img src="./figures/cross_domain_saliency_maps_banner.svg" width="755">
|
25
|
+
|
26
|
+
# Installation
|
27
|
+
Download this repository:
|
28
|
+
```
|
29
|
+
git clone https://github.com/esl-epfl/cross-domain-saliency-maps
|
30
|
+
```
|
31
|
+
|
32
|
+
Install using ```pip```:
|
33
|
+
```
|
34
|
+
pip install ./cross_domain_saliency_maps
|
35
|
+
```
|
36
|
+
|
37
|
+
# Examples
|
38
|
+
Get started with our PyTorch/TensorFlow examples (one-click run)
|
39
|
+
1. [PyTorch getting started](./examples/torch_demo.ipynb) [Open in Colab](https://colab.research.google.com/github/esl-epfl/cross-domain-saliency-maps/blob/main/examples/torch_demo.ipynb)
|
40
|
+
2. [TensorFlow getting started](./examples/tensorflow_demo.ipynb) [Open in Colab](https://colab.research.google.com/github/esl-epfl/cross-domain-saliency-maps/blob/main/examples/tensorflow_demo.ipynb)
|
41
|
+
3. [What does your model see in your EEG?](./examples/seizure_detection.ipynb) [Open in Colab](https://colab.research.google.com/github/esl-epfl/cross-domain-saliency-maps/blob/main/examples/seizure_detection.ipynb)
|
42
|
+
|
43
|
+
# Usage
|
44
|
+
The library supports generating saliency maps for any domain which
|
45
|
+
can be formulated as an invertible transformation with a differentiable
|
46
|
+
inverse transformation.
|
47
|
+
|
48
|
+
To generate maps expressed in a domain, a corresponding ```Domain```
|
49
|
+
object needs to be defined. This describes the operations performed
|
50
|
+
during the forward and inverse transformations.
|
51
|
+
|
52
|
+
Implementations for the [Frequency and Independent Component Analysis (ICA)](#saliency-maps-in-the-frequency-and-ica-domains)
|
53
|
+
transformations are already implemented and can be directly deployed.
|
54
|
+
Additionally, the library provides the flexibility of
|
55
|
+
[defining new transformations](#saliency-maps-in-any-domain).
|
56
|
+
|
57
|
+
## Saliency Maps in the Frequency and ICA domains
|
58
|
+
The following domains are already implemented and can be
|
59
|
+
directly used to generate saliency maps:
|
60
|
+
|
61
|
+
1. **Time Domain.** This is the original Integrated Gradients,
|
62
|
+
expressing saliency maps in the raw input domain (time). The
|
63
|
+
corresponding ```Domain``` object is ```TimeDomain```. The map
|
64
|
+
can be directly generated:
|
65
|
+
```timeIG = TimeIG(model, n_iterations, output_channel = 0)```
|
66
|
+
|
67
|
+
2. **Frequency Domain.** Each point in the map corresponds to
|
68
|
+
the importance of the corresponding frequency component. The
|
69
|
+
Fourier transform is used to transform the time-domain to
|
70
|
+
the frequency domain. The corresponding ```Domain``` object
|
71
|
+
is ```FourierDomain```. The map can be directly generated:
|
72
|
+
```fourierIG = FourierIG(model, n_iterations, output_channel = 0)```
|
73
|
+
|
74
|
+
3. **Independent Component Domain.** Each point in the
|
75
|
+
map corresponds to an independent component (IC) of the ICA
|
76
|
+
decomposition. Any ICA implementation can be used as long as it
|
77
|
+
complies with [```sklearn.decomposition.FastICA```](https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.FastICA.html). The domain is defined
|
78
|
+
by ```ICADomain```. Before generating the map a ```FastICA```
|
79
|
+
needs to be fitted to the input sample (see [example](./examples/tensorflow_demo.ipynb)). The map can be directly generated:
|
80
|
+
``` icaIG = ICAIG(model, fastICA, n_iterations, output_channel = 0)```
|
81
|
+
|
82
|
+
## Saliency Maps in any domain
|
83
|
+
The library supports extending the Cross-domain Integrated Gradients
|
84
|
+
for any invertible domain with a differentiable inverse transform. This
|
85
|
+
requires:
|
86
|
+
1. Creating the proper ```Domain``` object describing the corresponding
|
87
|
+
transform. ```Domain``` objects need to inherit from ```DomainBase``` and
|
88
|
+
implement the required functions. More details can be found in the
|
89
|
+
implementation of the ```FourierDomain``` and ```ICADomain``` (
|
90
|
+
[tensorflow](/src/cross_domain_saliency_maps/tensorflow_ig/domain_transforms.py), [pytorch](/src/cross_domain_saliency_maps/torch_ig/domain_transforms.py)).
|
91
|
+
|
92
|
+
2. Calling ```CrossDomainIG``` with the new domain as the input. This
|
93
|
+
can be done either by creating a ```CrossDomainIG```, initializing it
|
94
|
+
with the new domain, or by implementing a new dedicated class inheriting
|
95
|
+
```CrossDomainIG```. For more details check the implementations of
|
96
|
+
```FourierIG``` and ```ICAIG``` (
|
97
|
+
[tensorflow](/src/cross_domain_saliency_maps/tensorflow_ig/cross_domain_integrated_gradients.py), [pytorch](/src/cross_domain_saliency_maps/torch_ig/cross_domain_integrated_gradients.py)).
|
98
|
+
|
99
|
+
# Reference
|
100
|
+
**BibTeX**
|
101
|
+
```bibtex
|
102
|
+
@article{kechris2025time,
|
103
|
+
title={Time series saliency maps: Explaining models across multiple domains},
|
104
|
+
author={Kechris, Christodoulos and Dan, Jonathan and Atienza, David},
|
105
|
+
journal={arXiv preprint arXiv:2505.13100},
|
106
|
+
year={2025}
|
107
|
+
}
|
@@ -0,0 +1,11 @@
|
|
1
|
+
cross_domain_saliency_maps/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
2
|
+
cross_domain_saliency_maps/tensorflow_ig/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
3
|
+
cross_domain_saliency_maps/tensorflow_ig/cross_domain_integrated_gradients.py,sha256=og7t6XrLfzu3s2sbpeqi5REG4hkXdcIIvwumShT5q7Q,5991
|
4
|
+
cross_domain_saliency_maps/tensorflow_ig/domain_transforms.py,sha256=LeayOMh2EWmGU_GQ2Wa50Z5A_YH5gnEErFpi2FEz4Sw,7531
|
5
|
+
cross_domain_saliency_maps/torch_ig/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
6
|
+
cross_domain_saliency_maps/torch_ig/cross_domain_integrated_gradients.py,sha256=xWSA88HkUFw-FoUW9cfEmMCNIzpj4KYB5wabmx9PeBI,6128
|
7
|
+
cross_domain_saliency_maps/torch_ig/domain_transforms.py,sha256=wpAljfS9oo5fn68FHH6rPz1y64UykINmk7hRJVSHv0E,7642
|
8
|
+
cross_domain_saliency_maps-0.0.4.dist-info/METADATA,sha256=-WT7nRI7LV2jSzZKLlZT7SY3m8pbjjL9WIRn5-xYQGE,5581
|
9
|
+
cross_domain_saliency_maps-0.0.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
|
10
|
+
cross_domain_saliency_maps-0.0.4.dist-info/licenses/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
|
11
|
+
cross_domain_saliency_maps-0.0.4.dist-info/RECORD,,
|