cross-domain-saliency-maps 0.0.4 (cross_domain_saliency_maps-0.0.4-py3-none-any.whl)

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cross_domain_saliency_maps/torch_ig/domain_transforms.py
@@ -0,0 +1,208 @@
+ import abc
+ import torch
+ import numpy as np
+
+ class DomainBase:
+     """ Defines the base class for a domain transformation.
+
+     Attributes:
+         device (torch.device): The device used to run the model.
+         coefficients (torch.Tensor): The input samples expressed
+             in the target domain.
+         baseline_coefficients (torch.Tensor): The baseline samples expressed
+             in the target domain.
+         dtype: The dtype of the coefficients when they are
+             converted to torch.Tensor.
+     """
+     def __init__(self, device: torch.device, dtype = torch.float32):
+         """ Initialize the Domain.
+         """
+         self.device = device
+         self.dtype = dtype
+         self.coefficients = None
+         self.baseline_coefficients = None
+
+     def get_coefficients(self):
+         """ Return the coefficients.
+         """
+         return self.coefficients
+
+     def get_coefficient_baseline(self):
+         """ Return the baseline coefficients.
+         """
+         return self.baseline_coefficients
+
+     @abc.abstractmethod
+     def set_coefficients(self, x, x_baseline):
+         """ Sets the coefficients depending on the transform
+         used. Should be implemented by classes inheriting DomainBase.
+
+         Args:
+             x (np.ndarray): The input sample.
+             x_baseline (np.ndarray): The baseline sample.
+         """
+         return
+
+     @abc.abstractmethod
+     def forward_transform(self, x):
+         """ Performs the forward transform, transforming the input
+         sample x into the corresponding target domain.
+
+         Args:
+             x : The input sample.
+         """
+         return
+
+     @abc.abstractmethod
+     def inverse_transform(self, x_input):
+         """ Performs the inverse transform, transforming the
+         sample x_input from the target domain back to the original one.
+
+         Args:
+             x_input : The input sample.
+         """
+         return
+
+ class FourierDomain(DomainBase):
+     """ Domain implementation for the Fourier transform, mapping
+     time-domain samples into the frequency domain.
+     """
+     def __init__(self, device, dtype = torch.float32, time_dimension = -1):
+         """ Initialize the Fourier Domain.
+
+         Args:
+             dtype: The type of the input features.
+             time_dimension: The dimension in the input which
+                 corresponds to the time-points.
+         """
+         super().__init__(device = device, dtype = dtype)
+         self.time_dimension = time_dimension
+
+     def forward_transform(self, x: torch.Tensor):
+         """ Implementation of the forward transform, transforming the input
+         time-domain sample into the corresponding frequency-domain sample.
+
+         Args:
+             x (torch.Tensor): Input time-domain sample.
+         """
+         return torch.fft.fft(x.type(torch.complex64), dim = self.time_dimension)
+
+     def set_coefficients(self, x: np.ndarray, x_baseline: np.ndarray):
+         """ Sets the frequency coefficients by transforming the input and
+         baseline samples into the frequency domain.
+
+         Args:
+             x (np.ndarray): The input sample in the time domain.
+             x_baseline (np.ndarray): The baseline sample in the time domain.
+         """
+         x_tf = torch.from_numpy(x).type(self.dtype).to(self.device)
+         x_baseline_tf = torch.from_numpy(x_baseline).type(self.dtype).to(self.device)
+
+         self.coefficients = self.forward_transform(x_tf)
+         self.baseline_coefficients = self.forward_transform(x_baseline_tf)
+
+     def inverse_transform(self, x_input: torch.Tensor):
+         """ Inverse transform, transforming the frequency-domain input
+         x_input back into the time domain.
+
+         Args:
+             x_input (torch.Tensor): The frequency-domain input.
+         """
+         return torch.fft.ifft(x_input, dim = self.time_dimension).to(torch.float32)
+
+ class ICADomain(DomainBase):
+     """ Implements the Domain for the Independent Component Analysis
+     (ICA) decomposition.
+
+     We consider the independent channels to form the basis of
+     the input. The mixing matrix forms the actual input coefficients.
+     This way, ICA IG expresses the significance of each independent component.
+
+     """
+     def __init__(self, device, ica, dtype = torch.float32, channel_permutation = (0, 2, 1)):
+         super().__init__(device = device, dtype = dtype)
+         self.channel_permutation = channel_permutation
+         self.ica = ica
+
+     def forward_transform(self, x: np.ndarray):
+         """ Forward transform, transforming the input channels into
+         independent channels.
+
+         Args:
+             x (np.ndarray): Input sample.
+         """
+         X = self.ica.transform(x)
+         return X
+
+     def set_coefficients(self, x: np.ndarray, x_baseline: np.ndarray):
+         """ Sets the coefficients in the ICA space. We consider the
+         independent channels to form the basis of the input. The
+         mixing matrix forms the actual input coefficients. This way,
+         ICA IG expresses the significance of each independent component.
+
+         Since the basis of the transform is formed by the independent
+         components, we only consider a zero baseline.
+
+         Args:
+             x (np.ndarray): Input sample of size [1, n_time_points, n_channels].
+         """
+
+         self.basis = torch.from_numpy(self.forward_transform(x[0, ...].T)).type(self.dtype).to(self.device)
+
+         self.coefficients = torch.from_numpy(self.ica.mixing_.T).type(self.dtype).to(self.device)[None, ...]
+         self.baseline_coefficients = torch.zeros_like(self.coefficients)
+         self.mean = torch.from_numpy(self.ica.mean_).type(self.dtype).to(self.device)
+
+     def inverse_transform(self, x_input: torch.Tensor):
+         """ Inverts the ICA transform given an input matrix
+         x_input and the basis stored in the domain.
+
+         Args:
+             x_input (torch.Tensor): Input unmixing matrix of size [n_channels, n_channels].
+         """
+         return torch.transpose(torch.matmul(self.basis, x_input) + self.mean, 1, 2)
+
+ class TimeDomain(DomainBase):
+     """ Domain implementation for the time domain. This is the original input domain;
+     no transform takes place.
+     """
+     def __init__(self, device, dtype = torch.float32, time_dimension = -1):
+         """ Initialize the Time Domain.
+
+         Args:
+             dtype: The type of the input features.
+             time_dimension: The dimension in the input which
+                 corresponds to the time-points.
+         """
+         super().__init__(device = device, dtype = dtype)
+         self.time_dimension = time_dimension
+
+     def forward_transform(self, x: torch.Tensor):
+         """ Implementation of the forward transform. No transform takes place.
+
+         Args:
+             x (torch.Tensor): Input time-domain sample.
+         """
+         return x
+
+     def set_coefficients(self, x: np.ndarray, x_baseline: np.ndarray):
+         """ Sets the coefficients. Since no transform takes place, the input
+         and baseline samples are used directly as the coefficients.
+
+         Args:
+             x (np.ndarray): The input sample in the time domain.
+             x_baseline (np.ndarray): The baseline sample in the time domain.
+         """
+         x_tf = torch.from_numpy(x).type(self.dtype).to(self.device)
+         x_baseline_tf = torch.from_numpy(x_baseline).type(self.dtype).to(self.device)
+
+         self.coefficients = self.forward_transform(x_tf)
+         self.baseline_coefficients = self.forward_transform(x_baseline_tf)
+
+     def inverse_transform(self, x_input: torch.Tensor):
+         """ Inverse transform. No transform takes place.
+
+         Args:
+             x_input (torch.Tensor): The time-domain input.
+         """
+         return x_input
cross_domain_saliency_maps-0.0.4.dist-info/METADATA
@@ -0,0 +1,107 @@
+ Metadata-Version: 2.4
+ Name: cross_domain_saliency_maps
+ Version: 0.0.4
+ Summary: A pytorch/tensorflow library for generating Cross-Domain saliency maps.
+ Project-URL: Homepage, https://github.com/esl-epfl/cross-domain-saliency-maps
+ Author-email: Christodoulos Kechris <christodoulos.kechris@epfl.ch>
+ License-File: LICENSE
+ Classifier: License :: OSI Approved :: GNU General Public License v3 (GPLv3)
+ Classifier: Operating System :: OS Independent
+ Classifier: Programming Language :: Python :: 3
+ Requires-Python: >=3.10.16
+ Requires-Dist: tensorflow<=2.19,>=2.13.0
+ Requires-Dist: torch<=2.7,>=2.6.0
+ Requires-Dist: tqdm==4.67.1
+ Description-Content-Type: text/markdown
+
+ # Timeseries Saliency Maps: Explaining models across multiple domains
+
+ Official PyTorch/TensorFlow implementation of Cross-Domain Saliency Maps.
+ The method does not require any model retraining or modifications.
+
+ [![arXiv](https://img.shields.io/badge/arXiv-2505.13100-b31b1b.svg)](https://arxiv.org/abs/2505.13100)
+
+ <img src="./figures/cross_domain_saliency_maps_banner.svg" width="755">
+
+ # Installation
+ Download this repository:
+ ```
+ git clone https://github.com/esl-epfl/cross-domain-saliency-maps
+ ```
+
+ Install using ```pip```:
+ ```
+ pip install ./cross-domain-saliency-maps
+ ```
+
+ # Examples
+ Get started with our PyTorch/TensorFlow examples (one-click run):
+ 1. [Pytorch getting started](./examples/torch_demo.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/esl-epfl/cross-domain-saliency-maps/blob/main/examples/torch_demo.ipynb)
+ 2. [Tensorflow getting started](./examples/tensorflow_demo.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/esl-epfl/cross-domain-saliency-maps/blob/main/examples/tensorflow_demo.ipynb)
+ 3. [What does your model see in your EEG?](./examples/seizure_detection.ipynb) [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/esl-epfl/cross-domain-saliency-maps/blob/main/examples/seizure_detection.ipynb)
+
+ # Usage
+ The library supports generating saliency maps for any domain which
+ can be formulated as an invertible transformation with a differentiable
+ inverse transformation.
+
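+ As a rough sketch of the underlying idea (our notation, not the package's): if $T$ is the
+ invertible transform, $f$ the model, and $c = T(x)$, $c' = T(x')$ are the coefficients of the
+ input and the baseline, then Integrated Gradients can be evaluated on the composition of the
+ model with the differentiable inverse transform,
+
+ $$\mathrm{IG}_i(c) = (c_i - c'_i) \int_0^1 \frac{\partial f\big(T^{-1}(c' + \alpha\,(c - c'))\big)}{\partial c_i}\, d\alpha ,$$
+
+ so that each attribution is assigned to a coefficient of the target domain rather than to a
+ raw time point. The ```n_iterations``` argument used below presumably controls the
+ discretization of this integral.
+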
+ To generate maps expressed in a domain, a corresponding ```Domain```
+ object needs to be defined. This describes the operations performed
+ during the forward and inverse transformations.
+
+ The [Frequency and Independent Component Analysis (ICA)](#saliency-maps-in-the-frequency-and-ica-domains)
+ transformations are already implemented and can be deployed directly.
+ Additionally, the library provides the flexibility of
+ [defining new transformations](#saliency-maps-in-any-domain).
+
+ ## Saliency Maps in the Frequency and ICA domains
+ The following domains are already implemented and can be
+ directly used to generate saliency maps (a short usage sketch follows the list):
+
+ 1. **Time Domain.** This is the original Integrated Gradients,
+ expressing saliency maps in the raw input domain (time). The
+ corresponding ```Domain``` object is ```TimeDomain```. The map
+ can be directly generated:
+ ```timeIG = TimeIG(model, n_iterations, output_channel = 0)```
+
+ 2. **Frequency Domain.** Each point in the map corresponds to
+ the importance of the corresponding frequency component. The
+ Fourier transform is used to transform the time domain to
+ the frequency domain. The corresponding ```Domain``` object
+ is ```FourierDomain```. The map can be directly generated:
+ ```fourierIG = FourierIG(model, n_iterations, output_channel = 0)```
+
+ 3. **Independent Component Domain.** Each point in the
+ map corresponds to an independent component (IC) of the ICA
+ decomposition. Any ICA implementation can be used as long as it
+ complies with [```sklearn.decomposition.FastICA```](https://scikit-learn.org/stable/modules/generated/sklearn.decomposition.FastICA.html). The domain is defined
+ by ```ICADomain```. Before generating the map, a ```FastICA``` object
+ needs to be fitted to the input sample (see [example](./examples/tensorflow_demo.ipynb)). The map can be directly generated:
+ ```icaIG = ICAIG(model, fastICA, n_iterations, output_channel = 0)```
+
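+ As a quick, minimal sketch of the ```Domain``` building block behind these classes (based on
+ the PyTorch ```FourierDomain``` implementation shown earlier in this diff; the array shapes
+ and values are illustrative):
+
+ ```python
+ import numpy as np
+ import torch
+ from cross_domain_saliency_maps.torch_ig.domain_transforms import FourierDomain
+
+ # Toy input and an all-zero baseline, shaped [batch, channels, time].
+ x = np.random.randn(1, 3, 256).astype(np.float32)
+ x_baseline = np.zeros_like(x)
+
+ domain = FourierDomain(device=torch.device("cpu"), time_dimension=-1)
+ domain.set_coefficients(x, x_baseline)      # FFT of the input and the baseline
+ coeffs = domain.get_coefficients()          # complex frequency coefficients
+ x_time = domain.inverse_transform(coeffs)   # inverse FFT, back to the time domain
+ ```
+
+ In normal use the one-liners above are all that is needed; the sketch only shows what the
+ underlying ```Domain``` object does.
+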
+ ## Saliency Maps in any domain
+ The library supports extending Cross-Domain Integrated Gradients
+ to any invertible domain with a differentiable inverse transform. This
+ requires two steps (a minimal sketch follows the list):
+ 1. Creating the proper ```Domain``` object describing the corresponding
+ transform. ```Domain``` objects need to inherit from ```DomainBase``` and
+ implement the required functions. More details can be found in the
+ implementations of ```FourierDomain``` and ```ICADomain``` (
+ [tensorflow](/src/cross_domain_saliency_maps/tensorflow_ig/domain_transforms.py), [pytorch](/src/cross_domain_saliency_maps/torch_ig/domain_transforms.py)).
+
+ 2. Calling ```CrossDomainIG``` with the new domain as the input. This
+ can be done either by creating a ```CrossDomainIG``` and initializing it
+ with the new domain, or by implementing a new dedicated class inheriting
+ ```CrossDomainIG```. For more details check the implementations of
+ ```FourierIG``` and ```ICAIG``` (
+ [tensorflow](/src/cross_domain_saliency_maps/tensorflow_ig/cross_domain_integrated_gradients.py), [pytorch](/src/cross_domain_saliency_maps/torch_ig/cross_domain_integrated_gradients.py)).
+
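+ A minimal sketch of such a custom domain (modeled on the PyTorch ```DomainBase``` shown
+ earlier in this diff; the class name and the orthonormal-matrix transform are illustrative,
+ not part of the package):
+
+ ```python
+ import numpy as np
+ import torch
+ from cross_domain_saliency_maps.torch_ig.domain_transforms import DomainBase
+
+ class OrthogonalDomain(DomainBase):
+     """Toy domain: coefficients with respect to a fixed orthonormal basis Q."""
+     def __init__(self, device, Q: torch.Tensor, dtype=torch.float32):
+         super().__init__(device=device, dtype=dtype)
+         self.Q = Q.to(device)                      # [time, time], orthonormal columns
+
+     def forward_transform(self, x: torch.Tensor):
+         return torch.matmul(x, self.Q)             # project onto the basis
+
+     def inverse_transform(self, x_input: torch.Tensor):
+         return torch.matmul(x_input, self.Q.T)     # differentiable inverse (Q is orthonormal)
+
+     def set_coefficients(self, x: np.ndarray, x_baseline: np.ndarray):
+         x_t = torch.from_numpy(x).type(self.dtype).to(self.device)
+         b_t = torch.from_numpy(x_baseline).type(self.dtype).to(self.device)
+         self.coefficients = self.forward_transform(x_t)
+         self.baseline_coefficients = self.forward_transform(b_t)
+
+ # Example basis: a random orthonormal matrix for 256 time points.
+ Q, _ = torch.linalg.qr(torch.randn(256, 256))
+ domain = OrthogonalDomain(device=torch.device("cpu"), Q=Q)
+ ```
+
+ The resulting object can then be used to initialize a ```CrossDomainIG``` as described in
+ step 2 above.
+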
+ # Reference
+ **BibTeX**
+ ```bibtex
+ @article{kechris2025time,
+   title={Time series saliency maps: Explaining models across multiple domains},
+   author={Kechris, Christodoulos and Dan, Jonathan and Atienza, David},
+   journal={arXiv preprint arXiv:2505.13100},
+   year={2025}
+ }
+ ```
cross_domain_saliency_maps-0.0.4.dist-info/RECORD
@@ -0,0 +1,11 @@
+ cross_domain_saliency_maps/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ cross_domain_saliency_maps/tensorflow_ig/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ cross_domain_saliency_maps/tensorflow_ig/cross_domain_integrated_gradients.py,sha256=og7t6XrLfzu3s2sbpeqi5REG4hkXdcIIvwumShT5q7Q,5991
+ cross_domain_saliency_maps/tensorflow_ig/domain_transforms.py,sha256=LeayOMh2EWmGU_GQ2Wa50Z5A_YH5gnEErFpi2FEz4Sw,7531
+ cross_domain_saliency_maps/torch_ig/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ cross_domain_saliency_maps/torch_ig/cross_domain_integrated_gradients.py,sha256=xWSA88HkUFw-FoUW9cfEmMCNIzpj4KYB5wabmx9PeBI,6128
+ cross_domain_saliency_maps/torch_ig/domain_transforms.py,sha256=wpAljfS9oo5fn68FHH6rPz1y64UykINmk7hRJVSHv0E,7642
+ cross_domain_saliency_maps-0.0.4.dist-info/METADATA,sha256=-WT7nRI7LV2jSzZKLlZT7SY3m8pbjjL9WIRn5-xYQGE,5581
+ cross_domain_saliency_maps-0.0.4.dist-info/WHEEL,sha256=qtCwoSJWgHk21S1Kb4ihdzI2rlJ1ZKaIurTj_ngOhyQ,87
+ cross_domain_saliency_maps-0.0.4.dist-info/licenses/LICENSE,sha256=OXLcl0T2SZ8Pmy2_dmlvKuetivmyPd5m1q-Gyd-zaYY,35149
+ cross_domain_saliency_maps-0.0.4.dist-info/RECORD,,
cross_domain_saliency_maps-0.0.4.dist-info/WHEEL
@@ -0,0 +1,4 @@
+ Wheel-Version: 1.0
+ Generator: hatchling 1.27.0
+ Root-Is-Purelib: true
+ Tag: py3-none-any