cross-domain-saliency-maps 0.0.4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
File without changes
File without changes
@@ -0,0 +1,152 @@
1
+ import tensorflow as tf
2
+ import numpy as np
3
+
4
+ from cross_domain_saliency_maps.tensorflow_ig.domain_transforms import DomainBase
5
+ from cross_domain_saliency_maps.tensorflow_ig.domain_transforms import FourierDomain
6
+ from cross_domain_saliency_maps.tensorflow_ig.domain_transforms import TimeDomain
7
+ from cross_domain_saliency_maps.tensorflow_ig.domain_transforms import ICADomain
8
+
9
class CrossDomainIG:
    """ Cross Domain IG base class. Defines the basic functionality
    for cross-domain Integrated Gradients (IG).

    Attributes:
        model (tf.keras.models.Model): A tensorflow model.
        n_iterations (int): Number of iterations for approximating IG.
        output_channel (int): The output channel of the model for which
            we generate the saliency map.
        dtype (dtype): The type of the target domain.
        domain (DomainBase): The target domain, created by
            `initialize_domain`.
        multiIG (tf.Tensor): The saliency map produced by the last call
            to `run` (None until `run` has been called).
    """
    def __init__(self, model: tf.keras.models.Model,
                 n_iterations: int,
                 output_channel: int,
                 dtype = tf.float32):
        """ Initializes CrossDomainIG.

        Args:
            model (tf.keras.models.Model): A tensorflow model for which
                saliency maps will be generated.
            n_iterations (int): Number of steps used in approximating
                the integral in the Integrated Gradients computation.
            output_channel (int): The channel of the model's output used
                for the saliency map (e.g. the class channel).
            dtype: The type of the target domain. Stored on the instance;
                the subclasses forward it to their domain.
        """
        self.model = model
        self.n_iterations = n_iterations
        self.output_channel = output_channel
        self.dtype = dtype
        self.multiIG = None

    def initialize_domain(self, Domain: "DomainBase", **kwargs):
        """ Initializes the target domain in which the IG is
        expressed.

        Args:
            Domain (DomainBase): The target domain class.
            **kwargs: Parameters needed to initialize the Domain.
        """
        self.domain = Domain(**kwargs)

    def run(self, x: np.array, x_baseline: np.array):
        """ Runs the saliency map generation for input sample x
        using x_baseline as the baseline.

        The straight path between the baseline and the input domain
        coefficients is sampled at `n_iterations` evenly spaced points
        (both end points included, as produced by np.linspace). Every
        point is mapped back through the domain's inverse transform and
        fed to the model. The gradients of `output_channel` with respect
        to the coefficients are then averaged.

        Args:
            x (np.array): a single sample with shape [1, n_timesteps, n_channels]
            x_baseline (np.array): the baseline sample with shape [1, n_timesteps, n_channels]

        Returns:
            tf.Tensor: The real-valued saliency map over the domain
            coefficients of the (single) sample. It is also stored in
            `self.multiIG`.

        Raises:
            RuntimeError: If `initialize_domain` has not been called, or if
                the model output is not differentiable w.r.t. the
                domain coefficients.
            ValueError: If `n_iterations` is smaller than 1.
        """
        if getattr(self, "domain", None) is None:
            raise RuntimeError("initialize_domain must be called before run.")
        if self.n_iterations < 1:
            raise ValueError("n_iterations must be a positive integer.")

        self.domain.set_coefficients(x, x_baseline)
        X_in = self.domain.get_coefficients()
        X_baseline = self.domain.get_coefficient_baseline()

        # Interpolation weights, in the coefficient dtype so that the product
        # below also works for complex (e.g. Fourier) coefficients.
        a = tf.constant(np.linspace(0, 1, self.n_iterations), dtype = X_in.dtype)

        with tf.GradientTape() as tape:
            # One interpolated point per leading index: [n_iterations, ...].
            X_samples = X_baseline + (X_in - X_baseline) * a[:, tf.newaxis, tf.newaxis]
            tape.watch(X_samples)
            x_ = self.domain.inverse_transform(X_samples)
            y_ = self.model(x_)
            # Select the target while the tape is still recording. A slice
            # taken after the context exits is not on the tape, and the
            # gradient would then be None.
            target = y_[:, self.output_channel]
        grads = tape.gradient(target, X_samples)
        if grads is None:
            raise RuntimeError(
                "The model output is not differentiable w.r.t. the domain coefficients.")

        # Average the conjugated path gradients. tf.math.conj returns real
        # tensors unchanged, so this only matters for complex coefficients.
        S = tf.math.reduce_mean(tf.math.conj(grads), axis = 0)
        self.multiIG = tf.math.real((X_in[0, :] - X_baseline[0, :]) * S)
        return self.multiIG

    def getMultiIG(self):
        """ Get the generated saliency map (None before `run`).
        """
        return self.multiIG
78
+
79
class TimeIG(CrossDomainIG):
    """ Cross-domain IG whose target domain is the time domain itself,
    i.e. the original Integrated Gradients formulation.
    """
    def __init__(self,
                 model: tf.keras.models.Model,
                 n_iterations: int,
                 output_channel: int,
                 dtype = tf.float32):
        """ Builds the IG object and attaches a `TimeDomain` to it.

        Args:
            model (tf.keras.models.Model): The model to explain.
            n_iterations (int): Number of integration steps.
            output_channel (int): Model output channel to explain.
            dtype: Type of the time-domain coefficients.
        """
        super().__init__(model = model,
                         n_iterations = n_iterations,
                         output_channel = output_channel,
                         dtype = dtype)
        self.initialize_domain(Domain = TimeDomain, dtype = dtype)
92
+
93
class FourierIG(CrossDomainIG):
    """ Cross-domain IG whose saliency is expressed over the
    frequency-domain (Fourier) coefficients of the input.
    """
    def __init__(self,
                 model: tf.keras.models.Model,
                 n_iterations: int,
                 output_channel: int,
                 dtype = tf.float32):
        """ Builds the IG object and attaches a `FourierDomain` to it.

        Args:
            model (tf.keras.models.Model): The model to explain.
            n_iterations (int): Number of integration steps.
            output_channel (int): Model output channel to explain.
            dtype: Type used when converting the input samples to tensors.
        """
        super().__init__(model = model,
                         n_iterations = n_iterations,
                         output_channel = output_channel,
                         dtype = dtype)
        self.initialize_domain(Domain = FourierDomain, dtype = dtype)
106
+
107
class ICAIG(CrossDomainIG):
    """ Cross-domain IG over an Independent Component Analysis (ICA)
    decomposition of the input.

    The independent source signals obtained from `ica.transform` act as
    the basis. The (transposed) mixing matrix holds the coefficients on
    which the saliency is computed.

    Args:
        model (tf.keras.models.Model): The tensorflow model.
        ica: An ICA object compatible with sklearn's FastICA (it must
            provide `transform`, `mixing_` and `mean_`).
        n_iterations (int): Number of steps used to approximate the IG
            integral.
        output_channel (int): The model's output channel for which the
            saliency map will be generated.
        dtype: The type of the features.
    """
    def __init__(self,
                 model: tf.keras.models.Model,
                 ica,
                 n_iterations: int,
                 output_channel: int,
                 dtype = tf.float32):
        super().__init__(model = model,
                         n_iterations = n_iterations,
                         output_channel = output_channel,
                         dtype = dtype)
        self.initialize_domain(Domain = ICADomain, ica = ica, dtype = dtype)

    def run(self, x: np.array, x_baseline: np.array):
        """ Runs the IG generation and aggregates it per component.

        The base-class IG assigns a value to every entry of the
        coefficient matrix (the transposed mixing matrix). Summing those
        values over axis 1 collapses each row into one value. Assuming
        sklearn's FastICA layout, this gives one saliency value per
        independent component.

        Args:
            x (np.array): Input sample for which to generate saliency.
            x_baseline (np.array): Baseline sample. It is forwarded to the
                domain, but the ICA domain always uses a zero baseline.

        Returns:
            tf.Tensor: The aggregated saliency, also stored in `self.multiIG`.
        """
        entry_saliency = super().run(x, x_baseline)
        self.multiIG = tf.reduce_sum(entry_saliency, axis = 1)
        return self.multiIG
@@ -0,0 +1,207 @@
1
+ import abc
2
+ import tensorflow as tf
3
+ import numpy as np
4
+
5
class DomainBase(abc.ABC):
    """ Defines the base class for a domain transformation.

    Subclasses must implement `set_coefficients`, `forward_transform`
    and `inverse_transform`. Inheriting from `abc.ABC` makes the
    `@abc.abstractmethod` markers below take effect. A subclass that
    forgets one of them now fails at instantiation, instead of silently
    returning None when the missing method is called.

    Attributes:
        coefficients (tf.Tensor): The input samples expressed
            in the target domain.
        baseline_coefficients (tf.Tensor): The baseline samples expressed
            in the target domain.
        dtype: The dtype of the coefficients when they are
            converted to tf.Tensor.
    """
    def __init__(self, dtype = tf.float32):
        """ Initialize the Domain with no coefficients set yet.

        Args:
            dtype: The dtype used when converting samples to tf.Tensor.
        """
        self.dtype = dtype
        self.coefficients = None
        self.baseline_coefficients = None

    def get_coefficients(self):
        """ Return the coefficients of the input sample (None before
        `set_coefficients` has been called).
        """
        return self.coefficients

    def get_coefficient_baseline(self):
        """ Return the baseline coefficients (None before
        `set_coefficients` has been called).
        """
        return self.baseline_coefficients

    @abc.abstractmethod
    def set_coefficients(self, x, x_baseline):
        """ Sets the coefficients depending on the transform
        used. Must be implemented by classes inheriting DomainBase.

        Args:
            x (np.array): The input sample.
            x_baseline (np.array): The baseline sample.
        """
        return

    @abc.abstractmethod
    def forward_transform(self, x):
        """ Performs the forward transform, transforming the input
        sample x into the corresponding target domain.

        Args:
            x : The input sample
        """
        return

    @abc.abstractmethod
    def inverse_transform(self, x_input):
        """ Performs the inverse transform, transforming the
        sample x_input from the target domain back to the original one.

        Args:
            x_input : The input sample
        """
        return
63
+
64
class FourierDomain(DomainBase):
    """ Domain implementation for the Fourier transform, mapping
    time-domain samples into the frequency domain.
    """
    def __init__(self, dtype = tf.float32, channel_permutation = (0, 2, 1)):
        """ Initialize the Fourier Domain.

        Args:
            dtype: The type used when converting the input samples to
                tensors.
            channel_permutation: Permutation of the input samples that
                places the time samples in the last axis. This is needed
                because tensorflow's fft transforms along the last
                dimension. The inverse transform applies the inverse of
                this permutation, so any permutation of the axes is valid.
        """
        super().__init__(dtype)
        self.channel_permutation = channel_permutation

    def _inverse_permutation(self):
        """ Return the axis permutation that undoes `channel_permutation`. """
        return [int(axis) for axis in np.argsort(self.channel_permutation)]

    def forward_transform(self, x: tf.Tensor):
        """ Transform the input time sample to the frequency domain.

        The permuted input is cast to complex64 and passed to
        tf.signal.fft, which transforms along the last axis.

        Args:
            x (tf.Tensor): Input time-domain sample.
        """
        return tf.signal.fft(tf.cast(tf.transpose(x, perm = self.channel_permutation),
                                     dtype = tf.complex64))

    def set_coefficients(self, x: np.array, x_baseline: np.array):
        """ Sets the frequency coefficients by transforming the input and
        baseline samples into the frequency domain.

        Args:
            x (np.array): The input sample in time-domain.
            x_baseline (np.array): The baseline sample in time-domain.
        """
        x_tf = tf.constant(x, dtype = self.dtype)
        x_baseline_tf = tf.constant(x_baseline, dtype = self.dtype)

        self.coefficients = self.forward_transform(x_tf)
        self.baseline_coefficients = self.forward_transform(x_baseline_tf)

    def inverse_transform(self, x_input: tf.Tensor):
        """ Transform frequency-domain points back into the time domain.

        The inverse FFT result is cast to float32, which keeps only the
        real part. The axes are then restored to the original layout by
        applying the inverse of `channel_permutation`.

        Args:
            x_input (tf.Tensor): The frequency domain input.
        """
        return tf.transpose(tf.cast(tf.signal.ifft(x_input),
                                    dtype = tf.float32),
                            perm = self._inverse_permutation())
115
+
116
class ICADomain(DomainBase):
    """ Domain based on an Independent Component Analysis (ICA)
    decomposition.

    The independent source signals form the basis of the input, and the
    transposed mixing matrix holds the coefficients. IG computed over
    these coefficients therefore measures the significance of each
    independent component.
    """
    def __init__(self, ica, dtype = tf.float32, channel_permutation = (0, 2, 1)):
        """ Store the ICA object and the tensor dtype.

        Args:
            ica: An ICA object compatible with sklearn's FastICA (it must
                provide `transform`, `mixing_` and `mean_`).
            dtype: The dtype of the tensors built from the ICA quantities.
            channel_permutation: Stored for API symmetry with the other
                domains; it is not used by this domain.
        """
        super().__init__(dtype)
        self.ica = ica
        self.channel_permutation = channel_permutation

    def forward_transform(self, x: np.array):
        """ Project the input channels onto the independent sources
        using `ica.transform`.

        Args:
            x (np.array): Input sample (without the batch axis).
        """
        return self.ica.transform(x)

    def set_coefficients(self, x:np.array, x_baseline: np.array):
        """ Build the ICA basis and coefficients for the input sample.

        The basis is the source signals of `x[0]`. The coefficients are
        the transposed mixing matrix with a leading batch axis of size 1.
        The baseline is always all zeros, so `x_baseline` is ignored.

        Args:
            x (np.array): Input sample of size [1, n_time_points, n_channels]
            x_baseline (np.array): Unused.
        """
        sources = self.forward_transform(x[0, ...])
        self.basis = tf.constant(sources, dtype = self.dtype)

        mixing_t = tf.constant(self.ica.mixing_.T, dtype = self.dtype)
        self.coefficients = mixing_t[tf.newaxis, ...]
        self.baseline_coefficients = tf.zeros_like(self.coefficients)
        self.mean = tf.constant(self.ica.mean_, dtype = self.dtype)

    def inverse_transform(self, x_input: tf.Tensor):
        """ Reconstruct signals from coefficient matrices.

        Computes `basis @ x_input + mean` using the basis and mean
        stored by `set_coefficients`.

        Args:
            x_input (tf.Tensor): Coefficient matrices (transposed mixing
                matrices). Presumably [..., n_components, n_channels] for
                sklearn's FastICA; confirm against the ICA object in use.
        """
        reconstructed = tf.matmul(self.basis, x_input)
        return reconstructed + self.mean
167
+
168
class TimeDomain(DomainBase):
    """ Time domain implementation (original input domain).
    """
    def __init__(self, dtype = tf.float32, channel_permutation = (0, 2, 1)):
        """ Initialize the Time Domain.

        Args:
            dtype: The type used when converting the input samples to
                tensors.
            channel_permutation: Axis permutation applied by the forward
                transform. The inverse transform applies its inverse, so
                any permutation of the axes is valid.
        """
        super().__init__(dtype)

        self.channel_permutation = channel_permutation

    def _inverse_permutation(self):
        """ Return the axis permutation that undoes `channel_permutation`. """
        return [int(axis) for axis in np.argsort(self.channel_permutation)]

    def forward_transform(self, x: tf.Tensor):
        """ Implementation of the forward transform. Only the axes are
        permuted; the values are not transformed.

        Args:
            x (tf.Tensor): Input time-domain sample.
        """
        return tf.transpose(x, perm = self.channel_permutation)

    def set_coefficients(self, x: np.array, x_baseline: np.array):
        """ Sets the time coefficients of the input and baseline samples.

        Args:
            x (np.array): The input sample in time-domain.
            x_baseline (np.array): The baseline sample in time-domain.
        """
        x_tf = tf.constant(x, dtype = self.dtype)
        x_baseline_tf = tf.constant(x_baseline, dtype = self.dtype)

        self.coefficients = self.forward_transform(x_tf)
        self.baseline_coefficients = self.forward_transform(x_baseline_tf)

    def inverse_transform(self, x_input: tf.Tensor):
        """ Inverse transform. Restores the original axis order of
        permuted time-domain coefficients.

        Args:
            x_input (tf.Tensor): Time-domain coefficients in the permuted
                layout produced by `forward_transform`.
        """
        return tf.transpose(x_input, perm = self._inverse_permutation())
File without changes
@@ -0,0 +1,162 @@
1
+ import torch
2
+ import numpy as np
3
+
4
+ from cross_domain_saliency_maps.torch_ig.domain_transforms import DomainBase
5
+ from cross_domain_saliency_maps.torch_ig.domain_transforms import FourierDomain
6
+ from cross_domain_saliency_maps.torch_ig.domain_transforms import ICADomain
7
+ from cross_domain_saliency_maps.torch_ig.domain_transforms import TimeDomain
8
+
9
+ from tqdm import tqdm
10
+
11
class CrossDomainIG:
    """ Cross Domain IG base class. Defines the basic functionality
    for cross-domain Integrated Gradients (IG).

    Attributes:
        model (torch.nn.Module): A pytorch model.
        n_iterations (int): Number of iterations for approximating IG.
        output_channel (int): The output channel of the model for which
            we generate the saliency map.
        dtype (dtype): The type of the target domain.
        domain (DomainBase): The target domain, created by
            `initialize_domain`.
        multiIG (torch.Tensor): The saliency map produced by the last call
            to `run` (None until `run` has been called).
    """
    def __init__(self, model: torch.nn.Module,
                 n_iterations: int,
                 output_channel: int,
                 dtype = torch.float32):
        """ Initializes CrossDomainIG.

        Args:
            model (torch.nn.Module): A pytorch model for which
                saliency maps will be generated.
            n_iterations (int): Number of steps used in approximating
                the integral in the Integrated Gradients computation.
            output_channel (int): The channel of the model's output used
                for the saliency map (e.g. the class channel).
            dtype: The type of the target domain. Stored on the instance;
                the subclasses forward it to their domain.
        """
        self.model = model
        self.n_iterations = n_iterations
        self.output_channel = output_channel
        self.dtype = dtype
        self.multiIG = None

    def initialize_domain(self, Domain: "DomainBase", **kwargs):
        """ Initializes the target domain in which the IG is
        expressed.

        Args:
            Domain (DomainBase): The target domain class.
            **kwargs: Parameters needed to initialize the Domain.
        """
        self.domain = Domain(**kwargs)

    def run(self, x: np.array, x_baseline: np.array):
        """ Runs the saliency map generation for input sample x
        using x_baseline as the baseline.

        Uses a right Riemann sum over the interpolation weights
        1/n, 2/n, ..., 1 (n = `n_iterations`). The gradients are obtained
        with torch.autograd.grad, so only the path points are
        differentiated. The model parameters' `.grad` fields are left
        untouched.

        Args:
            x (np.array): a single sample with shape [1, n_timesteps, n_channels]
            x_baseline (np.array): the baseline sample with shape [1, n_timesteps, n_channels]

        Returns:
            torch.Tensor: The real-valued saliency map, with the same shape
            as the domain coefficients. It is also stored in
            `self.multiIG`.

        Raises:
            RuntimeError: If `initialize_domain` has not been called.
            ValueError: If `n_iterations` is smaller than 1.
        """
        if getattr(self, "domain", None) is None:
            raise RuntimeError("initialize_domain must be called before run.")
        if self.n_iterations < 1:
            raise ValueError("n_iterations must be a positive integer.")

        self.domain.set_coefficients(x, x_baseline)

        X_in = self.domain.get_coefficients()
        X_baseline = self.domain.get_coefficient_baseline()
        delta = X_in - X_baseline

        alphas = [float(i) / self.n_iterations for i in range(1, self.n_iterations + 1)]

        grad_sum = 0
        # Gradients are needed even if the caller is inside torch.no_grad().
        with torch.enable_grad():
            for alpha in tqdm(alphas):
                # Build each path point lazily as a fresh leaf tensor, so only
                # one point is alive at a time.
                X_sample = (X_baseline + alpha * delta).detach().requires_grad_(True)
                x_ = self.domain.inverse_transform(X_sample)
                prediction = self.model(x_)
                # Unlike .backward(), autograd.grad does not accumulate into
                # the model parameters' .grad buffers.
                (grad,) = torch.autograd.grad(prediction[0, self.output_channel], X_sample)
                grad_sum = grad_sum + torch.conj(grad)

        grad_sum = grad_sum / self.n_iterations
        self.multiIG = torch.real(delta * grad_sum)

        return self.multiIG

    def getMultiIG(self):
        """ Get the generated saliency map (None before `run`).
        """
        return self.multiIG
83
+
84
+
85
class TimeIG(CrossDomainIG):
    """ Cross-domain IG whose target domain is the time domain itself.
    This is the original Integrated Gradients method.
    """
    def __init__(self,
                 model: torch.nn.Module,
                 n_iterations: int,
                 output_channel: int,
                 device: torch.device,
                 dtype = torch.float32):
        """ Builds the IG object and attaches a `TimeDomain` to it.

        Args:
            model (torch.nn.Module): The model to explain.
            n_iterations (int): Number of integration steps.
            output_channel (int): Model output channel to explain.
            device (torch.device): Forwarded to `TimeDomain` (presumably
                the device the coefficients live on; the domain class is
                defined in another module).
            dtype: Type of the time-domain coefficients.
        """
        super().__init__(model = model,
                         n_iterations = n_iterations,
                         output_channel = output_channel,
                         dtype = dtype)
        self.initialize_domain(Domain = TimeDomain, device = device, dtype = dtype)
100
+
101
class FourierIG(CrossDomainIG):
    """ Cross-domain IG whose saliency is expressed over the
    frequency-domain (Fourier) coefficients of the input.
    """
    def __init__(self,
                 model: torch.nn.Module,
                 n_iterations: int,
                 output_channel: int,
                 device: torch.device,
                 dtype = torch.float32):
        """ Builds the IG object and attaches a `FourierDomain` to it.

        Args:
            model (torch.nn.Module): The model to explain.
            n_iterations (int): Number of integration steps.
            output_channel (int): Model output channel to explain.
            device (torch.device): Forwarded to `FourierDomain` (presumably
                the device the coefficients live on; the domain class is
                defined in another module).
            dtype: Type forwarded to the domain.
        """
        super().__init__(model = model,
                         n_iterations = n_iterations,
                         output_channel = output_channel,
                         dtype = dtype)
        self.initialize_domain(Domain = FourierDomain, device = device, dtype = dtype)
115
+
116
class ICAIG(CrossDomainIG):
    """ Cross-domain IG over an Independent Component Analysis (ICA)
    decomposition of the input.

    The independent source signals act as the basis. The mixing matrix
    holds the coefficients on which the saliency is computed.

    Args:
        model (torch.nn.Module): The pytorch model.
        ica: An ICA object compatible with sklearn's FastICA object.
        n_iterations (int): Number of steps used to approximate the IG
            integral.
        output_channel (int): The model's output channel for which the
            saliency map will be generated.
        device (torch.device): Forwarded to `ICADomain`.
        dtype: The type of the features.
    """
    def __init__(self,
                 model: torch.nn.Module,
                 ica,
                 n_iterations: int,
                 output_channel: int,
                 device: torch.device,
                 dtype = torch.float32):
        super().__init__(model = model,
                         n_iterations = n_iterations,
                         output_channel = output_channel,
                         dtype = dtype)
        self.initialize_domain(Domain = ICADomain, ica = ica, dtype = dtype, device = device)

    def run(self, x: np.array, x_baseline: np.array):
        """ Runs the IG generation and aggregates it per component.

        The base-class IG keeps a leading batch axis. The first (only)
        batch entry is taken, and its values are summed over dim 1, giving
        one value per row of the coefficient matrix.

        Args:
            x (np.array): Input sample for which to generate saliency.
            x_baseline (np.array): Baseline sample (forwarded to the domain).

        Returns:
            torch.Tensor: The aggregated saliency, also stored in
            `self.multiIG`.
        """
        entry_saliency = super().run(x, x_baseline)
        self.multiIG = entry_saliency[0].sum(dim = 1)
        return self.multiIG