dl-backtrace 0.0.14__py3-none-any.whl → 0.0.16.dev4__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of dl-backtrace might be problematic.

Files changed (27)
  1. dl_backtrace/pytorch_backtrace/backtrace/backtrace.py +173 -44
  2. dl_backtrace/pytorch_backtrace/backtrace/utils/__init__.py +3 -0
  3. dl_backtrace/pytorch_backtrace/backtrace/utils/encoder.py +183 -0
  4. dl_backtrace/pytorch_backtrace/backtrace/utils/encoder_decoder.py +489 -0
  5. dl_backtrace/pytorch_backtrace/backtrace/utils/helper.py +95 -0
  6. dl_backtrace/pytorch_backtrace/backtrace/utils/prop.py +481 -0
  7. dl_backtrace/tf_backtrace/backtrace/__init__.py +1 -2
  8. dl_backtrace/tf_backtrace/backtrace/activation_info.py +33 -0
  9. dl_backtrace/tf_backtrace/backtrace/backtrace.py +506 -279
  10. dl_backtrace/tf_backtrace/backtrace/models.py +25 -0
  11. dl_backtrace/tf_backtrace/backtrace/server.py +27 -0
  12. dl_backtrace/tf_backtrace/backtrace/utils/__init__.py +5 -2
  13. dl_backtrace/tf_backtrace/backtrace/utils/encoder.py +206 -0
  14. dl_backtrace/tf_backtrace/backtrace/utils/encoder_decoder.py +501 -0
  15. dl_backtrace/tf_backtrace/backtrace/utils/helper.py +99 -0
  16. dl_backtrace/tf_backtrace/backtrace/utils/utils_contrast.py +1132 -0
  17. dl_backtrace/tf_backtrace/backtrace/utils/utils_prop.py +1582 -0
  18. dl_backtrace/version.py +2 -2
  19. {dl_backtrace-0.0.14.dist-info → dl_backtrace-0.0.16.dev4.dist-info}/METADATA +2 -2
  20. dl_backtrace-0.0.16.dev4.dist-info/RECORD +29 -0
  21. {dl_backtrace-0.0.14.dist-info → dl_backtrace-0.0.16.dev4.dist-info}/WHEEL +1 -1
  22. dl_backtrace/tf_backtrace/backtrace/config.py +0 -41
  23. dl_backtrace/tf_backtrace/backtrace/utils/contrast.py +0 -834
  24. dl_backtrace/tf_backtrace/backtrace/utils/prop.py +0 -725
  25. dl_backtrace-0.0.14.dist-info/RECORD +0 -21
  26. {dl_backtrace-0.0.14.dist-info → dl_backtrace-0.0.16.dev4.dist-info}/LICENSE +0 -0
  27. {dl_backtrace-0.0.14.dist-info → dl_backtrace-0.0.16.dev4.dist-info}/top_level.txt +0 -0
dl_backtrace/tf_backtrace/backtrace/utils/utils_prop.py (new file; the only +1582-line entry above)
@@ -0,0 +1,1582 @@
import gc
import numpy as np  # type: ignore
import tensorflow as tf  # type: ignore
from tensorflow import keras
from tensorflow.keras import backend as K  # type: ignore
from tensorflow.keras.backend import sigmoid  # type: ignore
from numpy.lib.stride_tricks import as_strided  # type: ignore

def np_swish(x, beta=0.75):
    z = 1 / (1 + np.exp(-(beta * x)))
    return x * z

def np_wave(x, alpha=1.0):
    return (alpha * x * np.exp(1.0)) / (np.exp(-x) + np.exp(x))

def np_pulse(x, alpha=1.0):
    return alpha * (1 - np.tanh(x) * np.tanh(x))

def np_absolute(x, alpha=1.0):
    return alpha * x * np.tanh(x)

def np_hard_sigmoid(x):
    return np.clip(0.2 * x + 0.5, 0, 1)

def np_sigmoid(x):
    z = 1 / (1 + np.exp(-x))
    return z

def np_tanh(x):
    z = np.tanh(x)
    return z.astype(np.float32)

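# --- Illustrative sketch, not part of the released file: the np_* helpers above
# --- are NumPy mirrors of TF activations, used when relevance is propagated
# --- outside the TF graph. A quick parity check under that assumption:
def _demo_numpy_activation_parity():
    x = np.linspace(-3.0, 3.0, 7).astype(np.float32)
    assert np.allclose(np_sigmoid(x), tf.math.sigmoid(x).numpy(), atol=1e-6)
    assert np.allclose(np_tanh(x), tf.math.tanh(x).numpy(), atol=1e-6)
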
class LSTM_forward(object):
    def __init__(
        self, num_cells, units, weights, return_sequence=False, go_backwards=False
    ):
        self.num_cells = num_cells
        self.units = units
        self.kernel = weights[0]
        self.recurrent_kernel = weights[1]
        self.bias = weights[2]
        self.return_sequence = return_sequence
        self.go_backwards = go_backwards
        self.recurrent_activation = tf.math.sigmoid
        self.activation = tf.math.tanh

        self.compute_log = {}
        for i in range(self.num_cells):
            self.compute_log[i] = {}
            self.compute_log[i]["inp"] = None
            self.compute_log[i]["x"] = None
            self.compute_log[i]["hstate"] = [None, None]
            self.compute_log[i]["cstate"] = [None, None]
            self.compute_log[i]["int_arrays"] = {}

    def compute_carry_and_output(self, x, h_tm1, c_tm1, cell_num):
        """Computes carry and output using split kernels."""
        x_i, x_f, x_c, x_o = x
        h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o = h_tm1
        i = self.recurrent_activation(
            x_i + K.dot(h_tm1_i, self.recurrent_kernel[:, : self.units])
        )
        f = self.recurrent_activation(
            x_f + K.dot(h_tm1_f, self.recurrent_kernel[:, self.units : self.units * 2])
        )
        c = f * c_tm1 + i * self.activation(
            x_c
            + K.dot(h_tm1_c, self.recurrent_kernel[:, self.units * 2 : self.units * 3])
        )
        o = self.recurrent_activation(
            x_o + K.dot(h_tm1_o, self.recurrent_kernel[:, self.units * 3 :])
        )
        self.compute_log[cell_num]["int_arrays"]["i"] = i
        self.compute_log[cell_num]["int_arrays"]["f"] = f
        self.compute_log[cell_num]["int_arrays"]["c"] = c
        self.compute_log[cell_num]["int_arrays"]["o"] = o
        return c, o

    def calculate_lstm_cell_wt(self, inputs, states, cell_num, training=None):
        h_tm1 = states[0]  # previous memory state
        c_tm1 = states[1]  # previous carry state
        self.compute_log[cell_num]["inp"] = inputs
        self.compute_log[cell_num]["hstate"][0] = h_tm1
        self.compute_log[cell_num]["cstate"][0] = c_tm1
        inputs_i = inputs
        inputs_f = inputs
        inputs_c = inputs
        inputs_o = inputs
        k_i, k_f, k_c, k_o = tf.split(self.kernel, num_or_size_splits=4, axis=1)
        x_i = K.dot(inputs_i, k_i)
        x_f = K.dot(inputs_f, k_f)
        x_c = K.dot(inputs_c, k_c)
        x_o = K.dot(inputs_o, k_o)
        b_i, b_f, b_c, b_o = tf.split(self.bias, num_or_size_splits=4, axis=0)
        x_i = tf.add(x_i, b_i)
        x_f = tf.add(x_f, b_f)
        x_c = tf.add(x_c, b_c)
        x_o = tf.add(x_o, b_o)

        h_tm1_i = h_tm1
        h_tm1_f = h_tm1
        h_tm1_c = h_tm1
        h_tm1_o = h_tm1
        x = (x_i, x_f, x_c, x_o)
        h_tm1 = (h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o)
        c, o = self.compute_carry_and_output(x, h_tm1, c_tm1, cell_num)
        h = o * self.activation(c)
        self.compute_log[cell_num]["x"] = x
        self.compute_log[cell_num]["hstate"][1] = h
        self.compute_log[cell_num]["cstate"][1] = c
        return h, [h, c]

    def calculate_lstm_wt(self, input_data):
        hstate = tf.convert_to_tensor(np.zeros((1, self.units)), dtype=tf.float32)
        cstate = tf.convert_to_tensor(np.zeros((1, self.units)), dtype=tf.float32)
        output = []
        for ind in range(input_data.shape[0]):
            inp = tf.convert_to_tensor(
                input_data[ind, :].reshape((1, input_data.shape[1])), dtype=tf.float32
            )
            h, s = self.calculate_lstm_cell_wt(inp, [hstate, cstate], ind)
            hstate = s[0]
            cstate = s[1]
            output.append(h)
        return output

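# --- Illustrative sketch, not part of the released file: LSTM_forward replays a
# --- trained Keras LSTM step by step so every gate activation lands in
# --- compute_log. The layer and shapes below are demo assumptions.
def _demo_lstm_forward():
    timesteps, features, units = 5, 3, 4
    layer = keras.layers.LSTM(units)
    layer(np.zeros((1, timesteps, features), dtype=np.float32))  # build weights
    lstm_fwd = LSTM_forward(timesteps, units, layer.get_weights())
    seq = np.random.rand(timesteps, features).astype(np.float32)
    hidden_states = lstm_fwd.calculate_lstm_wt(seq)  # list of (1, units) tensors
    return lstm_fwd, hidden_states
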
class LSTM_backtrace(object):
    def __init__(
        self, num_cells, units, weights, return_sequence=False, go_backwards=False
    ):
        self.num_cells = num_cells
        self.units = units
        self.kernel = weights[0]
        self.recurrent_kernel = weights[1]
        self.bias = weights[2]
        self.return_sequence = return_sequence
        self.go_backwards = go_backwards
        self.recurrent_activation = np_sigmoid
        self.activation = np_tanh

        self.compute_log = {}

    def calculate_wt_fc(self, wts, inp, w, b, act):
        mul_mat = np.einsum("ij,i->ij", w, inp).T
        wt_mat = np.zeros(mul_mat.shape)
        for i in range(mul_mat.shape[0]):
            l1_ind1 = mul_mat[i]
            wt_ind1 = wt_mat[i]
            wt = wts[i]
            p_ind = l1_ind1 > 0
            n_ind = l1_ind1 < 0
            p_sum = np.sum(l1_ind1[p_ind])
            n_sum = np.sum(l1_ind1[n_ind]) * -1
            if len(b) > 0:
                if b[i] > 0:
                    pbias = b[i]
                    nbias = 0
                else:
                    pbias = 0
                    nbias = b[i] * -1
            else:
                pbias = 0
                nbias = 0
            t_sum = p_sum + pbias - n_sum - nbias
            if act["type"] == "mono":
                if act["range"]["l"]:
                    if t_sum < act["range"]["l"]:
                        p_sum = 0
                if act["range"]["u"]:
                    if t_sum > act["range"]["u"]:
                        n_sum = 0
            elif act["type"] == "non_mono":
                t_act = act["func"](t_sum)
                p_act = act["func"](p_sum + pbias)
                n_act = act["func"](-1 * (n_sum + nbias))
                if act["range"]["l"]:
                    if t_sum < act["range"]["l"]:
                        p_sum = 0
                if act["range"]["u"]:
                    if t_sum > act["range"]["u"]:
                        n_sum = 0
                if p_sum > 0 and n_sum > 0:
                    if t_act == p_act:
                        n_sum = 0
                    elif t_act == n_act:
                        p_sum = 0
            if p_sum > 0:
                p_agg_wt = (p_sum + pbias) / (p_sum + n_sum + pbias + nbias)
                p_agg_wt = p_agg_wt * (p_sum / (p_sum + pbias))
            else:
                p_agg_wt = 0
            if n_sum > 0:
                n_agg_wt = (n_sum + nbias) / (p_sum + n_sum + pbias + nbias)
                n_agg_wt = n_agg_wt * (n_sum / (n_sum + nbias))
            else:
                n_agg_wt = 0
            if p_sum == 0:
                p_sum = 1
            if n_sum == 0:
                n_sum = 1
            wt_ind1[p_ind] = (l1_ind1[p_ind] / p_sum) * wt * p_agg_wt
            wt_ind1[n_ind] = (l1_ind1[n_ind] / n_sum) * wt * n_agg_wt * -1.0
        wt_mat = wt_mat.sum(axis=0)
        return wt_mat

    def calculate_wt_add(self, wts, inp=None):
        wt_mat = []
        for x in inp:
            wt_mat.append(np.zeros_like(x))
        wt_mat = np.array(wt_mat)
        inp_list = np.array(inp)
        for i in range(wt_mat.shape[1]):
            wt_ind1 = wt_mat[:, i]
            wt = wts[i]
            l1_ind1 = inp_list[:, i]
            p_ind = l1_ind1 > 0
            n_ind = l1_ind1 < 0
            p_sum = np.sum(l1_ind1[p_ind])
            n_sum = np.sum(l1_ind1[n_ind]) * -1
            t_sum = p_sum - n_sum
            p_agg_wt = 0
            n_agg_wt = 0
            if p_sum + n_sum > 0:
                p_agg_wt = p_sum / (p_sum + n_sum)
                n_agg_wt = n_sum / (p_sum + n_sum)
            if p_sum == 0:
                p_sum = 1
            if n_sum == 0:
                n_sum = 1
            wt_ind1[p_ind] = (l1_ind1[p_ind] / p_sum) * wt * p_agg_wt
            wt_ind1[n_ind] = (l1_ind1[n_ind] / n_sum) * wt * n_agg_wt * -1.0
            wt_mat[:, i] = wt_ind1
        wt_mat = [i.reshape(wts.shape) for i in list(wt_mat)]
        return wt_mat

    def calculate_wt_multiply(self, wts, inp=None):
        # Relevance is split between the two factors in proportion to how far
        # the product moves away from each input.
        inp_prod = inp[0] * inp[1]
        inp_diff1 = np.abs(inp_prod - inp[0])
        inp_diff2 = np.abs(inp_prod - inp[1])
        inp_diff_sum = inp_diff1 + inp_diff2
        inp_wt1 = (inp_diff1 / inp_diff_sum) * wts
        inp_wt2 = (inp_diff2 / inp_diff_sum) * wts
        return [inp_wt1, inp_wt2]

    def compute_carry_and_output(self, wt_o, wt_c, h_tm1, c_tm1, x, cell_num):
        """Computes carry and output using split kernels."""
        h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o = (h_tm1, h_tm1, h_tm1, h_tm1)
        x_i, x_f, x_c, x_o = x
        f = self.compute_log[cell_num]["int_arrays"]["f"].numpy()[0]
        i = self.compute_log[cell_num]["int_arrays"]["i"].numpy()[0]
        # o = self.recurrent_activation(
        #     x_o + np.dot(h_tm1_o, self.recurrent_kernel[:, self.units * 3:])).astype(np.float32)
        temp1 = np.dot(h_tm1_o, self.recurrent_kernel[:, self.units * 3 :]).astype(
            np.float32
        )
        wt_x_o, wt_temp1 = self.calculate_wt_add(wt_o, [x_o, temp1])
        wt_h_tm1_o = self.calculate_wt_fc(
            wt_temp1,
            h_tm1_o,
            self.recurrent_kernel[:, self.units * 3 :],
            [],
            {"type": None},
        )

        # c = f * c_tm1 + i * self.activation(x_c + np.dot(
        #     h_tm1_c, self.recurrent_kernel[:, self.units * 2:self.units * 3])).astype(np.float32)
        temp2 = f * c_tm1
        temp3_1 = np.dot(
            h_tm1_c, self.recurrent_kernel[:, self.units * 2 : self.units * 3]
        )
        temp3_2 = self.activation(x_c + temp3_1)
        temp3_3 = i * temp3_2
        wt_temp2, wt_temp3_3 = self.calculate_wt_add(wt_c, [temp2, temp3_3])
        wt_f, wt_c_tm1 = self.calculate_wt_multiply(wt_temp2, [f, c_tm1])
        wt_i, wt_temp3_2 = self.calculate_wt_multiply(wt_temp3_3, [i, temp3_2])
        wt_x_c, wt_temp3_1 = self.calculate_wt_add(wt_temp3_2, [x_c, temp3_1])
        wt_h_tm1_c = self.calculate_wt_fc(
            wt_temp3_1,
            h_tm1_c,
            self.recurrent_kernel[:, self.units * 2 : self.units * 3],
            [],
            {"type": None},
        )

        # f = self.recurrent_activation(x_f + np.dot(
        #     h_tm1_f, self.recurrent_kernel[:, self.units:self.units * 2])).astype(np.float32)
        temp4 = np.dot(h_tm1_f, self.recurrent_kernel[:, self.units : self.units * 2])
        wt_x_f, wt_temp4 = self.calculate_wt_add(wt_f, [x_f, temp4])
        wt_h_tm1_f = self.calculate_wt_fc(
            wt_temp4,
            h_tm1_f,
            self.recurrent_kernel[:, self.units : self.units * 2],
            [],
            {"type": None},
        )

        # i = self.recurrent_activation(
        #     x_i + np.dot(h_tm1_i, self.recurrent_kernel[:, :self.units])).astype(np.float32)
        temp5 = np.dot(h_tm1_i, self.recurrent_kernel[:, : self.units])
        wt_x_i, wt_temp5 = self.calculate_wt_add(wt_i, [x_i, temp5])
        wt_h_tm1_i = self.calculate_wt_fc(
            wt_temp5,
            h_tm1_i,
            self.recurrent_kernel[:, : self.units],
            [],
            {"type": None},
        )

        return (
            wt_x_i,
            wt_x_f,
            wt_x_c,
            wt_x_o,
            wt_h_tm1_i,
            wt_h_tm1_f,
            wt_h_tm1_c,
            wt_h_tm1_o,
            wt_c_tm1,
        )

    def calculate_lstm_cell_wt(self, cell_num, wts_hstate, wts_cstate):
        o = self.compute_log[cell_num]["int_arrays"]["o"].numpy()[0]
        c = self.compute_log[cell_num]["cstate"][1].numpy()[0]
        h_tm1 = self.compute_log[cell_num]["hstate"][0].numpy()[0]
        c_tm1 = self.compute_log[cell_num]["cstate"][0].numpy()[0]
        x = [i.numpy()[0] for i in self.compute_log[cell_num]["x"]]
        wt_o, wt_c = self.calculate_wt_multiply(
            wts_hstate, [o, self.activation(c)]
        )  # h = o * self.activation(c)
        wt_c = wt_c + wts_cstate
        (
            wt_x_i,
            wt_x_f,
            wt_x_c,
            wt_x_o,
            wt_h_tm1_i,
            wt_h_tm1_f,
            wt_h_tm1_c,
            wt_h_tm1_o,
            wt_c_tm1,
        ) = self.compute_carry_and_output(wt_o, wt_c, h_tm1, c_tm1, x, cell_num)
        wt_h_tm1 = wt_h_tm1_i + wt_h_tm1_f + wt_h_tm1_c + wt_h_tm1_o
        inputs = self.compute_log[cell_num]["inp"].numpy()[0]
        k_i, k_f, k_c, k_o = np.split(self.kernel, indices_or_sections=4, axis=1)
        b_i, b_f, b_c, b_o = np.split(self.bias, indices_or_sections=4, axis=0)

        wt_inputs_i = self.calculate_wt_fc(wt_x_i, inputs, k_i, b_i, {"type": None})
        wt_inputs_f = self.calculate_wt_fc(wt_x_f, inputs, k_f, b_f, {"type": None})
        wt_inputs_c = self.calculate_wt_fc(wt_x_c, inputs, k_c, b_c, {"type": None})
        wt_inputs_o = self.calculate_wt_fc(wt_x_o, inputs, k_o, b_o, {"type": None})

        wt_inputs = wt_inputs_i + wt_inputs_f + wt_inputs_c + wt_inputs_o

        return wt_inputs, wt_h_tm1, wt_c_tm1

    def calculate_lstm_wt(self, wts, compute_log):
        self.compute_log = compute_log
        output = []
        if self.return_sequence:
            temp_wts_hstate = wts[-1, :]
        else:
            temp_wts_hstate = wts
        temp_wts_cstate = np.zeros_like(self.compute_log[0]["cstate"][1].numpy()[0])
        for ind in range(len(self.compute_log) - 1, -1, -1):
            temp_wt_inp, temp_wts_hstate, temp_wts_cstate = self.calculate_lstm_cell_wt(
                ind, temp_wts_hstate, temp_wts_cstate
            )
            output.append(temp_wt_inp)
            if self.return_sequence and ind > 0:
                temp_wts_hstate = temp_wts_hstate + wts[ind - 1, :]
        output.reverse()
        return np.array(output)

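# --- Illustrative sketch, not part of the released file: LSTM_backtrace consumes
# --- the compute_log recorded by LSTM_forward and pushes relevance back onto the
# --- input sequence. Builds on the (assumed) _demo_lstm_forward() above.
def _demo_lstm_backtrace():
    lstm_fwd, hidden_states = _demo_lstm_forward()
    lstm_bt = LSTM_backtrace(
        lstm_fwd.num_cells, lstm_fwd.units,
        [lstm_fwd.kernel, lstm_fwd.recurrent_kernel, lstm_fwd.bias],
    )
    # Relevance of the final hidden state; uniform for the demo.
    wts = np.ones(lstm_fwd.units, dtype=np.float32)
    relevance = lstm_bt.calculate_lstm_wt(wts, lstm_fwd.compute_log)
    return relevance  # shape: (timesteps, features)
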
def dummy_wt(wts, inp, *args):
    test_wt = np.zeros_like(inp)
    return test_wt

def calculate_wt_fc(wts, inp, w, b, act):
    mul_mat = np.einsum("ij,i->ij", w.numpy(), inp).T
    wt_mat = np.zeros(mul_mat.shape)
    for i in range(mul_mat.shape[0]):
        l1_ind1 = mul_mat[i]
        wt_ind1 = wt_mat[i]
        wt = wts[i]
        p_ind = l1_ind1 > 0
        n_ind = l1_ind1 < 0
        p_sum = np.sum(l1_ind1[p_ind])
        n_sum = np.sum(l1_ind1[n_ind]) * -1
        if b.numpy()[i] > 0:
            pbias = b.numpy()[i]
            nbias = 0
        else:
            pbias = 0
            nbias = b.numpy()[i] * -1
        t_sum = p_sum + pbias - n_sum - nbias
        if act["type"] == "mono":
            if act["range"]["l"]:
                if t_sum < act["range"]["l"]:
                    p_sum = 0
            if act["range"]["u"]:
                if t_sum > act["range"]["u"]:
                    n_sum = 0
        elif act["type"] == "non_mono":
            t_act = act["func"](t_sum)
            p_act = act["func"](p_sum + pbias)
            n_act = act["func"](-1 * (n_sum + nbias))
            if act["range"]["l"]:
                if t_sum < act["range"]["l"]:
                    p_sum = 0
            if act["range"]["u"]:
                if t_sum > act["range"]["u"]:
                    n_sum = 0
            if p_sum > 0 and n_sum > 0:
                if t_act == p_act:
                    n_sum = 0
                elif t_act == n_act:
                    p_sum = 0
        if p_sum > 0:
            p_agg_wt = (p_sum + pbias) / (p_sum + n_sum + pbias + nbias)
            p_agg_wt = p_agg_wt * (p_sum / (p_sum + pbias))
        else:
            p_agg_wt = 0
        if n_sum > 0:
            n_agg_wt = (n_sum + nbias) / (p_sum + n_sum + pbias + nbias)
            n_agg_wt = n_agg_wt * (n_sum / (n_sum + nbias))
        else:
            n_agg_wt = 0
        if p_sum == 0:
            p_sum = 1
        if n_sum == 0:
            n_sum = 1
        wt_ind1[p_ind] = (l1_ind1[p_ind] / p_sum) * wt * p_agg_wt
        wt_ind1[n_ind] = (l1_ind1[n_ind] / n_sum) * wt * n_agg_wt * -1.0

    wt_mat = wt_mat.sum(axis=0)
    return wt_mat

def calculate_wt_rshp(wts, inp=None):
    x = np.reshape(wts, inp.shape)
    return x

def calculate_wt_concat(wts, inp=None, axis=-1):
    splits = [i.shape[axis] for i in inp]
    splits = np.cumsum(splits)
    if axis > 0:
        axis = axis - 1
    x = np.split(wts, indices_or_sections=splits, axis=axis)
    return x

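# --- Illustrative sketch, not part of the released file: relevance propagation
# --- through a Dense layer. The layer and the act dict are demo assumptions;
# --- {"type": None} skips the activation-saturation checks.
def _demo_dense_relevance():
    dense = keras.layers.Dense(3, activation=None)
    dense(np.zeros((1, 5), dtype=np.float32))  # build weights
    inp = np.random.rand(5).astype(np.float32)
    wts = np.ones(3, dtype=np.float32)  # relevance arriving at the output
    rel_inp = calculate_wt_fc(wts, inp, dense.kernel, dense.bias, {"type": None})
    return rel_inp  # shape: (5,), relevance redistributed onto the input
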
def calculate_wt_add(wts, inp=None):
    wt_mat = []
    inp_list = []
    expanded_wts = as_strided(
        wts,
        shape=(np.prod(wts.shape),),
        strides=(wts.strides[-1],),
        writeable=False,  # read-only view; guards against accidental writes
    )

    for x in inp:
        expanded_input = as_strided(
            x,
            shape=(np.prod(x.shape),),
            strides=(x.strides[-1],),
            writeable=False,  # read-only view; guards against accidental writes
        )
        inp_list.append(expanded_input)
        wt_mat.append(np.zeros_like(expanded_input))
    wt_mat = np.array(wt_mat)
    inp_list = np.array(inp_list)
    for i in range(wt_mat.shape[1]):
        wt_ind1 = wt_mat[:, i]
        wt = expanded_wts[i]
        l1_ind1 = inp_list[:, i]
        p_ind = l1_ind1 > 0
        n_ind = l1_ind1 < 0
        p_sum = np.sum(l1_ind1[p_ind])
        n_sum = np.sum(l1_ind1[n_ind]) * -1
        t_sum = p_sum - n_sum
        p_agg_wt = 0
        n_agg_wt = 0
        if p_sum + n_sum > 0:
            p_agg_wt = p_sum / (p_sum + n_sum)
            n_agg_wt = n_sum / (p_sum + n_sum)
        if p_sum == 0:
            p_sum = 1
        if n_sum == 0:
            n_sum = 1
        wt_ind1[p_ind] = (l1_ind1[p_ind] / p_sum) * wt * p_agg_wt
        wt_ind1[n_ind] = (l1_ind1[n_ind] / n_sum) * wt * n_agg_wt * -1.0
        wt_mat[:, i] = wt_ind1
    wt_mat = [i.reshape(wts.shape) for i in list(wt_mat)]
    return wt_mat

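# --- Illustrative sketch, not part of the released file: relevance for an Add
# --- layer is split across its input tensors in proportion to each tensor's
# --- positive/negative contribution at every position.
def _demo_add_relevance():
    a = np.array([[1.0, -2.0], [0.5, 4.0]], dtype=np.float32)
    b = np.array([[3.0, 1.0], [-0.5, 2.0]], dtype=np.float32)
    wts = np.ones_like(a)  # relevance at the Add output
    rel_a, rel_b = calculate_wt_add(wts, [a, b])
    return rel_a, rel_b  # same shapes as a and b
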
def calculate_start_wt(arg, scaler=None, thresholding=0.5, task="binary-classification"):
    if arg.ndim == 2:
        if task == "binary-classification" or task == "multi-class classification":
            x = np.argmax(arg[0])
            m = np.max(arg[0])
            y = np.zeros(arg.shape)
            if scaler:
                y[0][x] = scaler
            else:
                y[0][x] = m
        elif task == "bbox-regression":
            y = np.zeros(arg.shape)
            if scaler:
                y[0] = scaler
                num_non_zero_elements = np.count_nonzero(y)
                if num_non_zero_elements > 0:
                    y = y / num_non_zero_elements
            else:
                m = np.max(arg[0])
                x = np.argmax(arg[0])
                y[0][x] = m
        else:
            x = np.argmax(arg[0])
            m = np.max(arg[0])
            y = np.zeros(arg.shape)
            if scaler:
                y[0][x] = scaler
            else:
                y[0][x] = m

    elif arg.ndim == 4 and task == "binary-segmentation":
        indices = np.where(arg > thresholding)
        y = np.zeros(arg.shape)
        if scaler:
            y[indices] = scaler
            num_non_zero_elements = np.count_nonzero(y)
            if num_non_zero_elements > 0:
                y = y / num_non_zero_elements
        else:
            y[indices] = arg[indices]

    else:
        x = np.argmax(arg[0])
        m = np.max(arg[0])
        y = np.zeros(arg.shape)
        if scaler:
            y[0][x] = scaler
        else:
            y[0][x] = m
    return y[0]

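# --- Illustrative sketch, not part of the released file: calculate_start_wt
# --- seeds the backward pass by turning the model output into an initial
# --- relevance vector (here: a softmax-style classification head).
def _demo_start_wt():
    probs = np.array([[0.1, 0.7, 0.2]], dtype=np.float32)  # (batch=1, classes)
    start = calculate_start_wt(probs, scaler=1.0)
    # Only the argmax class carries relevance: [0., 1., 0.]
    return start
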
def calculate_wt_passthru(wts):
    return wts

def calculate_wt_zero_pad(wts, inp, padding):
    wt_mat = wts[
        padding[0][0] : inp.shape[0] + padding[0][0],
        padding[1][0] : inp.shape[1] + padding[1][0],
        :,
    ]
    return wt_mat

def calculate_padding(kernel_size, inp, padding, strides, const_val=0.0):
    if padding == 'valid':
        return (inp, [[0, 0], [0, 0], [0, 0]])
    else:
        h = inp.shape[0] % strides[0]
        if h == 0:
            pad_h = np.max([0, kernel_size[0] - strides[0]])
        else:
            pad_h = np.max([0, kernel_size[0] - h])

        v = inp.shape[1] % strides[1]
        if v == 0:
            pad_v = np.max([0, kernel_size[1] - strides[1]])
        else:
            pad_v = np.max([0, kernel_size[1] - v])

        paddings = [
            np.floor([pad_h / 2.0, (pad_h + 1) / 2.0]).astype("int32"),
            np.floor([pad_v / 2.0, (pad_v + 1) / 2.0]).astype("int32"),
            np.zeros((2)).astype("int32"),
        ]
        inp_pad = np.pad(inp, paddings, 'constant', constant_values=const_val)
        return (inp_pad, paddings)

def calculate_wt_conv_unit(patch, wts, w, b, act):
    k = w.numpy()
    bias = b.numpy()
    b_ind = bias > 0
    bias_pos = bias * b_ind
    b_ind = bias < 0
    bias_neg = bias * b_ind * -1.0
    conv_out = np.einsum("ijkl,ijk->ijkl", k, patch)
    p_ind = conv_out > 0
    p_ind = conv_out * p_ind
    p_sum = np.einsum("ijkl->l", p_ind)
    n_ind = conv_out < 0
    n_ind = conv_out * n_ind
    n_sum = np.einsum("ijkl->l", n_ind) * -1.0
    t_sum = p_sum + n_sum
    wt_mat = np.zeros_like(k)
    p_saturate = p_sum > 0
    n_saturate = n_sum > 0
    if act["type"] == 'mono':
        if act["range"]["l"]:
            temp_ind = t_sum > act["range"]["l"]
            p_saturate = temp_ind
        if act["range"]["u"]:
            temp_ind = t_sum < act["range"]["u"]
            n_saturate = temp_ind
    elif act["type"] == 'non_mono':
        t_act = act["func"](t_sum)
        p_act = act["func"](p_sum + bias_pos)
        n_act = act["func"](-1 * (n_sum + bias_neg))
        if act["range"]["l"]:
            temp_ind = t_sum > act["range"]["l"]
            p_saturate = p_saturate * temp_ind
        if act["range"]["u"]:
            temp_ind = t_sum < act["range"]["u"]
            n_saturate = n_saturate * temp_ind
        temp_ind = np.abs(t_act - p_act) > 1e-5
        n_saturate = n_saturate * temp_ind
        temp_ind = np.abs(t_act - n_act) > 1e-5
        p_saturate = p_saturate * temp_ind
    p_agg_wt = (1.0 / (p_sum + n_sum + bias_pos + bias_neg)) * wts * p_saturate
    n_agg_wt = (1.0 / (p_sum + n_sum + bias_pos + bias_neg)) * wts * n_saturate

    wt_mat = wt_mat + (p_ind * p_agg_wt)
    wt_mat = wt_mat + (n_ind * n_agg_wt * -1.0)
    wt_mat = np.sum(wt_mat, axis=-1)
    return wt_mat

def calculate_wt_conv(wts, inp, w, b, padding, strides, act):
    input_padded, paddings = calculate_padding(w.shape, inp, padding, strides)
    out_ds = np.zeros_like(input_padded)
    for ind1 in range(wts.shape[0]):
        for ind2 in range(wts.shape[1]):
            indexes = [
                np.arange(ind1 * strides[0], ind1 * (strides[0]) + w.shape[0]),
                np.arange(ind2 * strides[1], ind2 * (strides[1]) + w.shape[1]),
            ]
            # Take slice
            tmp_patch = input_padded[np.ix_(indexes[0], indexes[1])]
            updates = calculate_wt_conv_unit(tmp_patch, wts[ind1, ind2, :], w, b, act)
            # Build tensor with "filtered" gradient
            out_ds[np.ix_(indexes[0], indexes[1])] += updates
    out_ds = out_ds[
        paddings[0][0] : (paddings[0][0] + inp.shape[0]),
        paddings[1][0] : (paddings[1][0] + inp.shape[1]),
        :,
    ]
    return out_ds

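# --- Illustrative sketch, not part of the released file: propagating relevance
# --- through a Conv2D layer. The layer, shapes, and act dict are demo
# --- assumptions; the act dict mirrors a ReLU (a lower bound of 0 is falsy, so
# --- no clipping fires).
def _demo_conv_relevance():
    conv = keras.layers.Conv2D(4, (3, 3), padding="same", strides=(1, 1))
    conv(np.zeros((1, 8, 8, 2), dtype=np.float32))  # build weights
    inp = np.random.rand(8, 8, 2).astype(np.float32)  # single HWC sample
    wts = np.ones((8, 8, 4), dtype=np.float32)  # relevance at the conv output
    act = {"name": "relu", "range": {"l": 0, "u": None}, "type": "mono", "func": None}
    rel = calculate_wt_conv(wts, inp, conv.kernel, conv.bias, "same", (1, 1), act)
    return rel  # same shape as inp
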
def calculate_wt_max_unit(patch, wts, pool_size):
    pmax = np.einsum(
        "ijk,k->ijk", np.ones_like(patch), np.max(np.max(patch, axis=0), axis=0)
    )
    indexes = (patch - pmax) == 0
    indexes = indexes.astype(np.float32)
    indexes_norm = 1.0 / np.einsum("mnc->c", indexes)
    indexes = np.einsum("ijk,k->ijk", indexes, indexes_norm)
    out = np.einsum("ijk,k->ijk", indexes, wts)
    return out

def calculate_wt_maxpool(wts, inp, pool_size, padding, strides):
    input_padded, paddings = calculate_padding(pool_size, inp, padding, strides, -np.inf)
    out_ds = np.zeros_like(input_padded)
    for ind1 in range(wts.shape[0]):
        for ind2 in range(wts.shape[1]):
            indexes = [
                np.arange(ind1 * strides[0], ind1 * (strides[0]) + pool_size[0]),
                np.arange(ind2 * strides[1], ind2 * (strides[1]) + pool_size[1]),
            ]
            # Take slice
            tmp_patch = input_padded[np.ix_(indexes[0], indexes[1])]
            updates = calculate_wt_max_unit(tmp_patch, wts[ind1, ind2, :], pool_size)
            # Build tensor with "filtered" gradient
            out_ds[np.ix_(indexes[0], indexes[1])] += updates
    out_ds = out_ds[
        paddings[0][0] : (paddings[0][0] + inp.shape[0]),
        paddings[1][0] : (paddings[1][0] + inp.shape[1]),
        :,
    ]
    return out_ds

def calculate_wt_avg_unit(patch, wts, pool_size):
    p_ind = patch > 0
    p_ind = patch * p_ind
    p_sum = np.einsum("ijk->k", p_ind)
    n_ind = patch < 0
    n_ind = patch * n_ind
    n_sum = np.einsum("ijk->k", n_ind) * -1.0
    t_sum = p_sum + n_sum
    wt_mat = np.zeros_like(patch)
    p_saturate = p_sum > 0
    n_saturate = n_sum > 0
    t_sum[t_sum == 0] = 1.0
    p_agg_wt = (1.0 / (t_sum)) * wts * p_saturate
    n_agg_wt = (1.0 / (t_sum)) * wts * n_saturate
    wt_mat = wt_mat + (p_ind * p_agg_wt)
    wt_mat = wt_mat + (n_ind * n_agg_wt * -1.0)
    return wt_mat

def calculate_wt_avgpool(wts, inp, pool_size, padding, strides):
    # Pad with zeros: -inf padding is only appropriate for max-pool and would
    # poison the average-pool contribution sums (matching the 1-D variant).
    input_padded, paddings = calculate_padding(pool_size, inp, padding, strides, 0.0)
    out_ds = np.zeros_like(input_padded)
    for ind1 in range(wts.shape[0]):
        for ind2 in range(wts.shape[1]):
            indexes = [
                np.arange(ind1 * strides[0], ind1 * (strides[0]) + pool_size[0]),
                np.arange(ind2 * strides[1], ind2 * (strides[1]) + pool_size[1]),
            ]
            # Take slice
            tmp_patch = input_padded[np.ix_(indexes[0], indexes[1])]
            updates = calculate_wt_avg_unit(tmp_patch, wts[ind1, ind2, :], pool_size)
            # Build tensor with "filtered" gradient
            out_ds[np.ix_(indexes[0], indexes[1])] += updates
    out_ds = out_ds[
        paddings[0][0] : (paddings[0][0] + inp.shape[0]),
        paddings[1][0] : (paddings[1][0] + inp.shape[1]),
        :,
    ]
    return out_ds

def calculate_wt_gavgpool(wts, inp):
    channels = wts.shape[0]
    wt_mat = np.zeros_like(inp)
    for c in range(channels):
        wt = wts[c]
        temp_wt = wt_mat[..., c]
        x = inp[..., c]
        p_mat = np.copy(x)
        n_mat = np.copy(x)
        p_mat[p_mat < 0] = 0
        n_mat[n_mat > 0] = 0
        p_sum = np.sum(p_mat)
        n_sum = np.sum(n_mat) * -1
        p_agg_wt = 0.0
        n_agg_wt = 0.0
        if p_sum + n_sum > 0.0:
            p_agg_wt = p_sum / (p_sum + n_sum)
            n_agg_wt = n_sum / (p_sum + n_sum)
        if p_sum == 0.0:
            p_sum = 1.0
        if n_sum == 0.0:
            n_sum = 1.0
        temp_wt = temp_wt + ((p_mat / p_sum) * wt * p_agg_wt)
        temp_wt = temp_wt + ((n_mat / n_sum) * wt * n_agg_wt * -1.0)
        wt_mat[..., c] = temp_wt
    return wt_mat

def calculate_wt_gmaxpool_2d(wts, inp):
    channels = wts.shape[0]
    wt_mat = np.zeros_like(inp)
    for c in range(channels):
        wt = wts[c]
        x = inp[..., c]
        max_val = np.max(x)
        max_indexes = (x == max_val).astype(np.float32)
        max_indexes_norm = 1.0 / np.sum(max_indexes)
        max_indexes = max_indexes * max_indexes_norm
        wt_mat[..., c] = max_indexes * wt
    return wt_mat

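# --- Illustrative sketch, not part of the released file: global-average-pool
# --- relevance spreads each channel's score over the spatial positions in
# --- proportion to their signed contribution to the channel mean.
def _demo_gavgpool_relevance():
    feat = np.random.randn(4, 4, 8).astype(np.float32)  # HWC feature map
    wts = np.ones(8, dtype=np.float32)  # per-channel relevance
    rel = calculate_wt_gavgpool(wts, feat)
    return rel  # same shape as feat
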
def calculate_padding_1d(kernel_size, inp, padding, strides, const_val=0.0):
    if padding == 'valid':
        return inp, [0, 0]
    else:
        remainder = inp.shape[0] % strides
        if remainder == 0:
            pad_total = max(0, kernel_size - strides)
        else:
            pad_total = max(0, kernel_size - remainder)

        pad_left = int(np.floor(pad_total / 2.0))
        pad_right = int(np.ceil(pad_total / 2.0))

        # Pad only the temporal axis: a bare (left, right) tuple would be
        # broadcast to every axis of a 2-D (steps, channels) input.
        pad_width = [(pad_left, pad_right)] + [(0, 0)] * (inp.ndim - 1)
        inp_pad = np.pad(inp, pad_width, 'constant', constant_values=const_val)
        return inp_pad, [pad_left, pad_right]

def calculate_wt_conv_unit_1d(patch, wts, w, b, act):
    k = w.numpy()
    bias = b.numpy()
    b_ind = bias > 0
    bias_pos = bias * b_ind
    b_ind = bias < 0
    bias_neg = bias * b_ind * -1.0
    conv_out = np.einsum("ijk,ij->ijk", k, patch)
    p_ind = conv_out > 0
    p_ind = conv_out * p_ind
    p_sum = np.einsum("ijk->k", p_ind)
    n_ind = conv_out < 0
    n_ind = conv_out * n_ind
    n_sum = np.einsum("ijk->k", n_ind) * -1.0
    t_sum = p_sum + n_sum
    wt_mat = np.zeros_like(k)
    p_saturate = p_sum > 0
    n_saturate = n_sum > 0
    if act["type"] == 'mono':
        if act["range"]["l"]:
            temp_ind = t_sum > act["range"]["l"]
            p_saturate = temp_ind
        if act["range"]["u"]:
            temp_ind = t_sum < act["range"]["u"]
            n_saturate = temp_ind
    elif act["type"] == 'non_mono':
        t_act = act["func"](t_sum)
        p_act = act["func"](p_sum + bias_pos)
        n_act = act["func"](-1 * (n_sum + bias_neg))
        if act["range"]["l"]:
            temp_ind = t_sum > act["range"]["l"]
            p_saturate = p_saturate * temp_ind
        if act["range"]["u"]:
            temp_ind = t_sum < act["range"]["u"]
            n_saturate = n_saturate * temp_ind
        temp_ind = np.abs(t_act - p_act) > 1e-5
        n_saturate = n_saturate * temp_ind
        temp_ind = np.abs(t_act - n_act) > 1e-5
        p_saturate = p_saturate * temp_ind
    p_agg_wt = (1.0 / (p_sum + n_sum + bias_pos + bias_neg)) * wts * p_saturate
    n_agg_wt = (1.0 / (p_sum + n_sum + bias_pos + bias_neg)) * wts * n_saturate

    wt_mat = wt_mat + (p_ind * p_agg_wt)
    wt_mat = wt_mat + (n_ind * n_agg_wt * -1.0)
    wt_mat = np.sum(wt_mat, axis=-1)
    return wt_mat

def calculate_wt_conv_1d(wts, inp, w, b, padding, stride, act):
    input_padded, paddings = calculate_padding_1d(w.shape[0], inp, padding, stride)
    out_ds = np.zeros_like(input_padded)
    for ind in range(wts.shape[0]):
        indexes = np.arange(ind * stride, ind * stride + w.shape[0])
        tmp_patch = input_padded[indexes]
        updates = calculate_wt_conv_unit_1d(tmp_patch, wts[ind, :], w, b, act)
        out_ds[indexes] += updates
    out_ds = out_ds[paddings[0]:(paddings[0] + inp.shape[0])]
    return out_ds

def calculate_wt_max_unit_1d(patch, wts):
    pmax = np.max(patch, axis=0)
    indexes = (patch - pmax) == 0
    indexes = indexes.astype(np.float32)
    indexes_norm = 1.0 / np.sum(indexes, axis=0)
    indexes = np.einsum("ij,j->ij", indexes, indexes_norm)
    out = np.einsum("ij,j->ij", indexes, wts)
    return out

def calculate_wt_maxpool_1d(wts, inp, pool_size, padding, stride):
    # Unpack the 1-tuples before computing padding; calculate_padding_1d
    # expects plain ints.
    stride = stride[0]
    pool_size = pool_size[0]
    input_padded, paddings = calculate_padding_1d(pool_size, inp, padding, stride, -np.inf)
    out_ds = np.zeros_like(input_padded)
    for ind in range(wts.shape[0]):
        indexes = np.arange(ind * stride, ind * stride + pool_size)
        tmp_patch = input_padded[indexes]
        updates = calculate_wt_max_unit_1d(tmp_patch, wts[ind, :])
        out_ds[indexes] += updates
    out_ds = out_ds[paddings[0]:(paddings[0] + inp.shape[0])]
    return out_ds

def calculate_wt_avg_unit_1d(patch, wts):
    p_ind = patch > 0
    p_ind = patch * p_ind
    p_sum = np.sum(p_ind, axis=0)
    n_ind = patch < 0
    n_ind = patch * n_ind
    n_sum = np.sum(n_ind, axis=0) * -1.0
    t_sum = p_sum + n_sum
    wt_mat = np.zeros_like(patch)
    p_saturate = p_sum > 0
    n_saturate = n_sum > 0
    t_sum[t_sum == 0] = 1.0
    p_agg_wt = (1.0 / t_sum) * wts * p_saturate
    n_agg_wt = (1.0 / t_sum) * wts * n_saturate
    wt_mat = wt_mat + (p_ind * p_agg_wt)
    wt_mat = wt_mat + (n_ind * n_agg_wt * -1.0)
    return wt_mat

def calculate_wt_avgpool_1d(wts, inp, pool_size, padding, stride):
    # Same unpacking as calculate_wt_maxpool_1d; zeros are the right pad value
    # for an average pool.
    stride = stride[0]
    pool_size = pool_size[0]
    input_padded, paddings = calculate_padding_1d(pool_size, inp, padding, stride, 0)
    out_ds = np.zeros_like(input_padded)
    for ind in range(wts.shape[0]):
        indexes = np.arange(ind * stride, ind * stride + pool_size)
        tmp_patch = input_padded[indexes]
        updates = calculate_wt_avg_unit_1d(tmp_patch, wts[ind, :])
        out_ds[indexes] += updates
    out_ds = out_ds[paddings[0]:(paddings[0] + inp.shape[0])]
    return out_ds

def calculate_wt_gavgpool_1d(wts, inp):
    channels = wts.shape[0]
    wt_mat = np.zeros_like(inp)
    for c in range(channels):
        wt = wts[c]
        temp_wt = wt_mat[:, c]
        x = inp[:, c]
        p_mat = np.copy(x)
        n_mat = np.copy(x)
        p_mat[p_mat < 0] = 0
        n_mat[n_mat > 0] = 0
        p_sum = np.sum(p_mat)
        n_sum = np.sum(n_mat) * -1
        p_agg_wt = 0.0
        n_agg_wt = 0.0
        if p_sum + n_sum > 0.0:
            p_agg_wt = p_sum / (p_sum + n_sum)
            n_agg_wt = n_sum / (p_sum + n_sum)
        if p_sum == 0.0:
            p_sum = 1.0
        if n_sum == 0.0:
            n_sum = 1.0
        temp_wt = temp_wt + ((p_mat / p_sum) * wt * p_agg_wt)
        temp_wt = temp_wt + ((n_mat / n_sum) * wt * n_agg_wt * -1.0)
        wt_mat[:, c] = temp_wt
    return wt_mat

def calculate_wt_gmaxpool_1d(wts, inp):
    channels = wts.shape[0]
    wt_mat = np.zeros_like(inp)
    for c in range(channels):
        wt = wts[c]
        x = inp[:, c]
        max_val = np.max(x)
        max_indexes = (x == max_val).astype(np.float32)
        max_indexes_norm = 1.0 / np.sum(max_indexes)
        max_indexes = max_indexes * max_indexes_norm
        wt_mat[:, c] = max_indexes * wt
    return wt_mat

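# --- Illustrative sketch, not part of the released file: the 1-D variants mirror
# --- the 2-D ones for Conv1D/pooling stacks. Layer and shapes are demo
# --- assumptions.
def _demo_conv1d_relevance():
    conv = keras.layers.Conv1D(6, 3, padding="same", strides=1)
    conv(np.zeros((1, 10, 2), dtype=np.float32))  # build weights
    inp = np.random.rand(10, 2).astype(np.float32)  # (steps, channels)
    wts = np.ones((10, 6), dtype=np.float32)  # relevance at the conv output
    act = {"name": "relu", "range": {"l": 0, "u": None}, "type": "mono", "func": None}
    rel = calculate_wt_conv_1d(wts, inp, conv.kernel, conv.bias, "same", 1, act)
    return rel  # same shape as inp
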
def calculate_output_padding_conv2d_transpose(input_shape, kernel_size, padding, strides):
    if padding == 'valid':
        out_shape = [(input_shape[0] - 1) * strides[0] + kernel_size[0],
                     (input_shape[1] - 1) * strides[1] + kernel_size[1]]
        paddings = [[0, 0], [0, 0], [0, 0]]
    else:  # 'same' padding
        out_shape = [input_shape[0] * strides[0], input_shape[1] * strides[1]]
        pad_h = max(0, (input_shape[0] - 1) * strides[0] + kernel_size[0] - out_shape[0])
        pad_v = max(0, (input_shape[1] - 1) * strides[1] + kernel_size[1] - out_shape[1])
        paddings = [[pad_h // 2, pad_h - pad_h // 2],
                    [pad_v // 2, pad_v - pad_v // 2],
                    [0, 0]]

    return out_shape, paddings

def calculate_wt_conv2d_transpose_unit(patch, wts, w, b, act):
    if patch.ndim == 1:
        patch = patch.reshape(1, 1, -1)
    elif patch.ndim == 2:
        patch = patch.reshape(1, *patch.shape)
    elif patch.ndim != 3:
        raise ValueError(f"Unexpected patch shape: {patch.shape}")

    k = tf.transpose(w, perm=[0, 1, 3, 2]).numpy()
    bias = b.numpy()
    b_ind = bias > 0
    bias_pos = bias * b_ind
    b_ind = bias < 0
    bias_neg = bias * b_ind * -1.0

    conv_out = np.einsum('ijkl,mnk->ijkl', k, patch)
    p_ind = conv_out > 0
    p_ind = conv_out * p_ind
    n_ind = conv_out < 0
    n_ind = conv_out * n_ind

    p_sum = np.einsum("ijkl->l", p_ind)
    n_sum = np.einsum("ijkl->l", n_ind) * -1.0
    t_sum = p_sum + n_sum

    wt_mat = np.zeros_like(k)
    p_saturate = p_sum > 0
    n_saturate = n_sum > 0

    if act["type"] == 'mono':
        if act["range"]["l"]:
            p_saturate = t_sum > act["range"]["l"]
        if act["range"]["u"]:
            n_saturate = t_sum < act["range"]["u"]
    elif act["type"] == 'non_mono':
        t_act = act["func"](t_sum)
        p_act = act["func"](p_sum + bias_pos)
        n_act = act["func"](-1 * (n_sum + bias_neg))
        if act["range"]["l"]:
            temp_ind = t_sum > act["range"]["l"]
            p_saturate = p_saturate * temp_ind
        if act["range"]["u"]:
            temp_ind = t_sum < act["range"]["u"]
            n_saturate = n_saturate * temp_ind
        temp_ind = np.abs(t_act - p_act) > 1e-5
        n_saturate = n_saturate * temp_ind
        temp_ind = np.abs(t_act - n_act) > 1e-5
        p_saturate = p_saturate * temp_ind

    p_agg_wt = (1.0 / (p_sum + n_sum + bias_pos + bias_neg)) * wts * p_saturate
    n_agg_wt = (1.0 / (p_sum + n_sum + bias_pos + bias_neg)) * wts * n_saturate

    wt_mat = wt_mat + (p_ind * p_agg_wt)
    wt_mat = wt_mat + (n_ind * n_agg_wt * -1.0)
    wt_mat = np.sum(wt_mat, axis=-1)
    return wt_mat

def calculate_wt_conv2d_transpose(wts, inp, w, b, padding, strides, act):
    out_shape, paddings = calculate_output_padding_conv2d_transpose(inp.shape, w.shape, padding, strides)
    out_ds = np.zeros(out_shape + [w.shape[3]])

    for ind1 in range(inp.shape[0]):
        for ind2 in range(inp.shape[1]):
            out_ind1 = ind1 * strides[0]
            out_ind2 = ind2 * strides[1]
            tmp_patch = inp[ind1, ind2, :]
            updates = calculate_wt_conv2d_transpose_unit(tmp_patch, wts[ind1, ind2, :], w, b, act)
            end_ind1 = min(out_ind1 + w.shape[0], out_shape[0])
            end_ind2 = min(out_ind2 + w.shape[1], out_shape[1])
            valid_updates = updates[:end_ind1 - out_ind1, :end_ind2 - out_ind2, :]
            out_ds[out_ind1:end_ind1, out_ind2:end_ind2, :] += valid_updates

    if padding == 'same':
        adjusted_out_ds = np.zeros(inp.shape)
        for i in range(inp.shape[0]):
            for j in range(inp.shape[1]):
                start_i = max(0, i * strides[0])
                start_j = max(0, j * strides[1])
                end_i = min(out_ds.shape[0], (i + 1) * strides[0])
                end_j = min(out_ds.shape[1], (j + 1) * strides[1])
                relevant_area = out_ds[start_i:end_i, start_j:end_j, :]
                adjusted_out_ds[i, j, :] = np.sum(relevant_area, axis=(0, 1))
        out_ds = adjusted_out_ds
    else:
        out_ds = out_ds[paddings[0][0]:(paddings[0][0] + inp.shape[0]),
                        paddings[1][0]:(paddings[1][0] + inp.shape[1]), :]

    return out_ds

def calculate_output_padding_conv1d_transpose(input_shape, kernel_size, padding, strides):
    if padding == 'valid':
        out_shape = [(input_shape[0] - 1) * strides + kernel_size[0]]
        paddings = [[0, 0], [0, 0]]
    else:  # 'same' padding
        out_shape = [input_shape[0] * strides]
        pad_h = max(0, (input_shape[0] - 1) * strides + kernel_size[0] - out_shape[0])
        paddings = [[pad_h // 2, pad_h // 2],
                    [0, 0]]

    return out_shape, paddings

def calculate_wt_conv1d_transpose_unit(patch, wts, w, b, act):
    if patch.ndim == 1:
        patch = patch.reshape(1, -1)
    elif patch.ndim != 2:
        raise ValueError(f"Unexpected patch shape: {patch.shape}")
    k = tf.transpose(w, perm=[0, 2, 1]).numpy()
    bias = b.numpy()
    b_ind = bias > 0
    bias_pos = bias * b_ind
    b_ind = bias < 0
    bias_neg = bias * b_ind * -1.0
    conv_out = np.einsum('ijk,mj->ijk', k, patch)
    p_ind = conv_out > 0
    p_ind = conv_out * p_ind
    n_ind = conv_out < 0
    n_ind = conv_out * n_ind

    p_sum = np.einsum("ijl->l", p_ind)
    n_sum = np.einsum("ijl->l", n_ind) * -1.0
    t_sum = p_sum + n_sum

    wt_mat = np.zeros_like(k)
    p_saturate = p_sum > 0
    n_saturate = n_sum > 0

    if act["type"] == 'mono':
        if act["range"]["l"]:
            p_saturate = t_sum > act["range"]["l"]
        if act["range"]["u"]:
            n_saturate = t_sum < act["range"]["u"]
    elif act["type"] == 'non_mono':
        t_act = act["func"](t_sum)
        p_act = act["func"](p_sum + bias_pos)
        n_act = act["func"](-1 * (n_sum + bias_neg))
        if act["range"]["l"]:
            temp_ind = t_sum > act["range"]["l"]
            p_saturate = p_saturate * temp_ind
        if act["range"]["u"]:
            temp_ind = t_sum < act["range"]["u"]
            n_saturate = n_saturate * temp_ind
        temp_ind = np.abs(t_act - p_act) > 1e-5
        n_saturate = n_saturate * temp_ind
        temp_ind = np.abs(t_act - n_act) > 1e-5
        p_saturate = p_saturate * temp_ind

    p_agg_wt = (1.0 / (p_sum + n_sum + bias_pos + bias_neg)) * wts * p_saturate
    n_agg_wt = (1.0 / (p_sum + n_sum + bias_pos + bias_neg)) * wts * n_saturate
    wt_mat = wt_mat + (p_ind * p_agg_wt)
    wt_mat = wt_mat + (n_ind * n_agg_wt * -1.0)
    wt_mat = np.sum(wt_mat, axis=-1)
    return wt_mat

def calculate_wt_conv1d_transpose(wts, inp, w, b, padding, strides, act):
    out_shape, paddings = calculate_output_padding_conv1d_transpose(inp.shape, w.shape, padding, strides)
    out_ds = np.zeros(out_shape + [w.shape[2]])
    for ind in range(inp.shape[0]):
        out_ind = ind * strides
        tmp_patch = inp[ind, :]
        updates = calculate_wt_conv1d_transpose_unit(tmp_patch, wts[ind, :], w, b, act)
        end_ind = min(out_ind + w.shape[0], out_shape[0])
        valid_updates = updates[:end_ind - out_ind, :]
        out_ds[out_ind:end_ind, :] += valid_updates

    if padding == 'same':
        adjusted_out_ds = np.zeros(inp.shape)
        for i in range(inp.shape[0]):
            start_i = max(0, i * strides)
            end_i = min(out_ds.shape[0], (i + 1) * strides)
            relevant_area = out_ds[start_i:end_i, :]
            adjusted_out_ds[i, :] = np.sum(relevant_area, axis=0)
        out_ds = adjusted_out_ds
    else:
        out_ds = out_ds[paddings[0][0]:(paddings[0][0] + inp.shape[0]), :]
    return out_ds

####################################################################
###################        Encoder Model       ####################
####################################################################
def stabilize(matrix, epsilon=1e-6):
    return matrix + epsilon * np.sign(matrix)


def calculate_relevance_V(wts, value_output):
    # Initialize wt_mat with zeros
    wt_mat_V = np.zeros((wts.shape[0], wts.shape[1], *value_output.shape))

    for i in range(wts.shape[0]):
        for j in range(wts.shape[1]):
            l1_ind1 = value_output
            wt_ind1 = wt_mat_V[i, j]
            wt = wts[i, j]

            p_ind = l1_ind1 > 0
            n_ind = l1_ind1 < 0
            p_sum = np.sum(l1_ind1[p_ind])
            n_sum = np.sum(l1_ind1[n_ind]) * -1

            if p_sum > 0:
                p_agg_wt = p_sum / (p_sum + n_sum)
            else:
                p_agg_wt = 0
            if n_sum > 0:
                n_agg_wt = n_sum / (p_sum + n_sum)
            else:
                n_agg_wt = 0

            if p_sum == 0:
                p_sum = 1
            if n_sum == 0:
                n_sum = 1

            wt_ind1[p_ind] = (l1_ind1[p_ind] / p_sum) * wt * p_agg_wt
            wt_ind1[n_ind] = (l1_ind1[n_ind] / n_sum) * wt * n_agg_wt * -1.0

    wt_mat_V = np.sum(wt_mat_V, axis=(0, 1))
    return wt_mat_V


def calculate_relevance_QK(wts, QK_output):
    # Initialize wt_mat with zeros
    wt_mat_QK = np.zeros((wts.shape[0], wts.shape[1], *QK_output.shape))

    for i in range(wts.shape[0]):
        for j in range(wts.shape[1]):
            l1_ind1 = QK_output
            wt_ind1 = wt_mat_QK[i, j]
            wt = wts[i, j]

            p_ind = l1_ind1 > 0
            n_ind = l1_ind1 < 0
            p_sum = np.sum(l1_ind1[p_ind])
            n_sum = np.sum(l1_ind1[n_ind]) * -1

            t_sum = p_sum - n_sum

            # This layer has a softmax activation function
            act = {
                "name": "softmax",
                "range": {"l": -1, "u": 2},
                "type": "mono",
                "func": None,
            }

            if act["type"] == "mono":
                if act["range"]["l"]:
                    if t_sum < act["range"]["l"]:
                        p_sum = 0
                if act["range"]["u"]:
                    if t_sum > act["range"]["u"]:
                        n_sum = 0

            if p_sum > 0:
                p_agg_wt = p_sum / (p_sum + n_sum)
            else:
                p_agg_wt = 0

            if n_sum > 0:
                n_agg_wt = n_sum / (p_sum + n_sum)
            else:
                n_agg_wt = 0

            if p_sum == 0:
                p_sum = 1
            if n_sum == 0:
                n_sum = 1

            wt_ind1[p_ind] = (l1_ind1[p_ind] / p_sum) * wt * p_agg_wt
            wt_ind1[n_ind] = (l1_ind1[n_ind] / n_sum) * wt * n_agg_wt * -1.0

    wt_mat_QK = np.sum(wt_mat_QK, axis=(0, 1))
    return wt_mat_QK


def calculate_wt_self_attention(wts, inp, w):
    '''
    Input:
        wts: relevance score of the layer
        inp: input to the layer
        w: weights of the layer- ['W_q', 'W_k', 'W_v', 'W_o']

    Outputs:
        Step-1: outputs = torch.matmul(input_a, input_b)
        Step-2: outputs = F.softmax(inputs, dim=dim, dtype=dtype)
        Step-3: outputs = input_a * input_b
    '''
    query_output = np.einsum('ij,kj->ik', inp, w['W_q'].T)
    key_output = np.einsum('ij,kj->ik', inp, w['W_k'].T)
    value_output = np.einsum('ij,kj->ik', inp, w['W_v'].T)

    # --------------- Relevance Calculation for Step-3 -----------------------
    relevance_V = wts / 2
    relevance_QK = wts / 2

    # --------------- Relevance Calculation for V --------------------------------
    wt_mat_V = calculate_relevance_V(relevance_V, value_output)

    # --------------- Transformed Relevance QK ----------------------------------
    QK_output = np.einsum('ij,kj->ik', query_output, key_output)
    wt_mat_QK = calculate_relevance_QK(relevance_QK, QK_output)

    # --------------- Relevance Calculation for K and Q --------------------------------
    stabilized_QK_output = stabilize(QK_output * 2)
    norm_wt_mat_QK = wt_mat_QK / stabilized_QK_output
    wt_mat_Q = np.einsum('ij,jk->ik', norm_wt_mat_QK, key_output) * query_output
    wt_mat_K = np.einsum('ij,ik->kj', query_output, norm_wt_mat_QK) * key_output

    wt_mat = wt_mat_V + wt_mat_K + wt_mat_Q
    return wt_mat


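# --- Illustrative sketch, not part of the released file: single-head
# --- self-attention relevance. The random weight dict and hidden size are demo
# --- assumptions; wts is the relevance of the attention block output.
def _demo_self_attention_relevance():
    tokens, d = 4, 8
    rng = np.random.default_rng(0)
    inp = rng.standard_normal((tokens, d)).astype(np.float32)
    w = {name: rng.standard_normal((d, d)).astype(np.float32)
         for name in ("W_q", "W_k", "W_v", "W_o")}
    wts = np.ones((tokens, d), dtype=np.float32)
    rel = calculate_wt_self_attention(wts, inp, w)
    return rel  # shape: (tokens, d)

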
def calculate_wt_feed_forward(wts, inp, w):
    intermediate_output = np.einsum('ij,jk->ik', inp, w['W_int'])
    feed_forward_output = np.einsum('ij,jk->ik', intermediate_output, w['W_out'])

    relevance_input = np.zeros(inp.shape)
    relevance_out = np.zeros(intermediate_output.shape)

    # Relevance propagation for 2nd layer
    for i in range(wts.shape[0]):
        R2 = wts[i]
        contribution_matrix2 = np.einsum('ij,j->ij', w['W_out'].T, intermediate_output[i])
        wt_mat2 = np.zeros(contribution_matrix2.shape)

        for j in range(contribution_matrix2.shape[0]):
            l1_ind1 = contribution_matrix2[j]
            wt_ind1 = wt_mat2[j]
            wt = R2[j]

            p_ind = l1_ind1 > 0
            n_ind = l1_ind1 < 0
            p_sum = np.sum(l1_ind1[p_ind])
            n_sum = np.sum(l1_ind1[n_ind]) * -1

            if p_sum > 0:
                p_agg_wt = p_sum / (p_sum + n_sum)
            else:
                p_agg_wt = 0

            if n_sum > 0:
                n_agg_wt = n_sum / (p_sum + n_sum)
            else:
                n_agg_wt = 0

            if p_sum == 0:
                p_sum = 1
            if n_sum == 0:
                n_sum = 1

            wt_ind1[p_ind] = (l1_ind1[p_ind] / p_sum) * wt * p_agg_wt
            wt_ind1[n_ind] = (l1_ind1[n_ind] / n_sum) * wt * n_agg_wt * -1.0

        relevance_out[i] = wt_mat2.sum(axis=0)

    # Relevance propagation for 1st layer
    for i in range(relevance_out.shape[0]):
        R1 = relevance_out[i]
        contribution_matrix1 = np.einsum('ij,j->ij', w['W_int'].T, inp[i])
        wt_mat1 = np.zeros(contribution_matrix1.shape)

        for j in range(contribution_matrix1.shape[0]):
            l1_ind1 = contribution_matrix1[j]
            wt_ind1 = wt_mat1[j]
            wt = R1[j]

            p_ind = l1_ind1 > 0
            n_ind = l1_ind1 < 0
            p_sum = np.sum(l1_ind1[p_ind])
            n_sum = np.sum(l1_ind1[n_ind]) * -1

            t_sum = p_sum - n_sum

            # This layer has a ReLU activation function
            act = {
                "name": "relu",
                "range": {"l": 0, "u": None},
                "type": "mono",
                "func": None,
            }

            if act["type"] == "mono":
                if act["range"]["l"]:
                    if t_sum < act["range"]["l"]:
                        p_sum = 0
                if act["range"]["u"]:
                    if t_sum > act["range"]["u"]:
                        n_sum = 0

            if p_sum > 0:
                p_agg_wt = p_sum / (p_sum + n_sum)
            else:
                p_agg_wt = 0

            if n_sum > 0:
                n_agg_wt = n_sum / (p_sum + n_sum)
            else:
                n_agg_wt = 0

            if p_sum == 0:
                p_sum = 1
            if n_sum == 0:
                n_sum = 1

            wt_ind1[p_ind] = (l1_ind1[p_ind] / p_sum) * wt * p_agg_wt
            wt_ind1[n_ind] = (l1_ind1[n_ind] / n_sum) * wt * n_agg_wt * -1.0

        relevance_input[i] = wt_mat1.sum(axis=0)

    return relevance_input


1339
+ def calculate_wt_pooler(wts, inp, w):
1340
+ '''
1341
+ Input:
1342
+ wts: relevance score of the layer
1343
+ inp: input to the layer
1344
+ w: weights of the layer- ['W_p', 'b_p']
1345
+ '''
1346
+ relevance_inp = np.zeros(inp.shape)
1347
+
1348
+ for i in range(inp.shape[0]):
1349
+ # Compute contribution matrix
1350
+ contribution_matrix = np.einsum('ij,j->ij', w['W_p'], inp[i])
1351
+ wt_mat = np.zeros(contribution_matrix.shape)
1352
+
1353
+ # Iterate over each unit
1354
+ for j in range(contribution_matrix.shape[0]):
1355
+ l1_ind1 = contribution_matrix[j]
1356
+ wt_ind1 = wt_mat[j]
1357
+ wt = wts[j]
1358
+
1359
+ p_ind = l1_ind1 > 0
1360
+ n_ind = l1_ind1 < 0
1361
+ p_sum = np.sum(l1_ind1[p_ind])
1362
+ n_sum = np.sum(l1_ind1[n_ind]) * -1
1363
+
1364
+ # Calculate biases
1365
+ pbias = max(w['b_p'][j], 0)
1366
+ nbias = min(w['b_p'][j], 0) * -1
1367
+
1368
+ t_sum = p_sum + pbias - n_sum - nbias
1369
+
1370
+ # This layer has a tanh activation function
1371
+ act = {
1372
+ "name": "tanh",
1373
+ "range": {"l": -2, "u": 2},
1374
+ "type": "mono",
1375
+ "func": None
1376
+ }
1377
+
1378
+ # Apply activation function constraints
1379
+ if act["type"] == "mono":
1380
+ if act["range"]["l"]:
1381
+ if t_sum < act["range"]["l"]:
1382
+ p_sum = 0
1383
+ if act["range"]["u"]:
1384
+ if t_sum > act["range"]["u"]:
1385
+ n_sum = 0
1386
+
1387
+ # Aggregate weights based on positive and negative contributions
1388
+ p_agg_wt = 0
1389
+ n_agg_wt = 0
1390
+ if p_sum > 0:
1391
+ p_agg_wt = (p_sum + pbias) / (p_sum + n_sum + pbias + nbias)
1392
+ p_agg_wt *= (p_sum / (p_sum + pbias))
1393
+
1394
+ if n_sum > 0:
1395
+ n_agg_wt = (n_sum + nbias) / (p_sum + n_sum + pbias + nbias)
1396
+ n_agg_wt *= (n_sum / (n_sum + nbias))
1397
+
1398
+ # Prevent division by zero
1399
+ if p_sum == 0:
1400
+ p_sum = 1
1401
+ if n_sum == 0:
1402
+ n_sum = 1
1403
+
1404
+ # Update weight matrix
1405
+ wt_ind1[p_ind] = (l1_ind1[p_ind] / p_sum) * wt * p_agg_wt
1406
+ wt_ind1[n_ind] = (l1_ind1[n_ind] / n_sum) * wt * n_agg_wt * -1.0
1407
+
1408
+ # Calculate relevance for each token
1409
+ relevance_inp[i] = wt_mat.sum(axis=0)
1410
+
1411
+ relevance_inp *= (100 / np.sum(relevance_inp))
1412
+ return relevance_inp
1413
+
1414
+
1415
+ def calculate_wt_classifier(wts, inp, w):
1416
+ '''
1417
+ Input:
1418
+ wts: relevance score of the layer
1419
+ inp: input to the layer
1420
+ w: weights of the layer- ['W_cls', 'b_cls']
1421
+ '''
1422
+ mul_mat = np.einsum("ij, i->ij", w['W_cls'], inp).T
1423
+ wt_mat = np.zeros(mul_mat.shape)
1424
+
1425
+ for i in range(mul_mat.shape[0]):
1426
+ l1_ind1 = mul_mat[i]
1427
+ wt_ind1 = wt_mat[i]
1428
+ wt = wts[i]
1429
+
1430
+ p_ind = l1_ind1 > 0
1431
+ n_ind = l1_ind1 < 0
1432
+ p_sum = np.sum(l1_ind1[p_ind])
1433
+ n_sum = np.sum(l1_ind1[n_ind]) * -1
1434
+
1435
+ if w['b_cls'][i] > 0:
1436
+ pbias = w['b_cls'][i]
1437
+ nbias = 0
1438
+ else:
1439
+ pbias = 0
1440
+ nbias = w['b_cls'][i]
1441
+
1442
+ t_sum = p_sum + pbias - n_sum - nbias
1443
+
1444
+ # This layer has a softmax activation function
1445
+ act = {
1446
+ "name": "softmax",
1447
+ "range": {"l": -1, "u": 2},
1448
+ "type": "mono",
1449
+ "func": None,
1450
+ }
1451
+
1452
+ if act["type"] == "mono":
1453
+ if act["range"]["l"]:
1454
+ if t_sum < act["range"]["l"]:
1455
+ p_sum = 0
1456
+ if act["range"]["u"]:
1457
+ if t_sum > act["range"]["u"]:
1458
+ n_sum = 0
1459
+
1460
+ if p_sum > 0:
1461
+ p_agg_wt = (p_sum + pbias) / (p_sum + n_sum + pbias + nbias)
1462
+ p_agg_wt = p_agg_wt * (p_sum / (p_sum + pbias))
1463
+ else:
1464
+ p_agg_wt = 0
1465
+ if n_sum > 0:
1466
+ n_agg_wt = (n_sum + nbias) / (p_sum + n_sum + pbias + nbias)
1467
+ n_agg_wt = n_agg_wt * (n_sum / (n_sum + nbias))
1468
+ else:
1469
+ n_agg_wt = 0
1470
+
1471
+ if p_sum == 0:
1472
+ p_sum = 1
1473
+ if n_sum == 0:
1474
+ n_sum = 1
1475
+
1476
+ wt_ind1[p_ind] = (l1_ind1[p_ind] / p_sum) * wt * p_agg_wt
1477
+ wt_ind1[n_ind] = (l1_ind1[n_ind] / n_sum) * wt * n_agg_wt * -1.0
1478
+
1479
+ wt_mat = wt_mat.sum(axis=0)
1480
+ return wt_mat
1481
+
1482
+
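# --- Illustrative sketch, not part of the released file: classifier-head
# --- relevance for a single pooled vector. The random weight dict and shapes
# --- are demo assumptions; weight names follow the w-dict convention above.
def _demo_classifier_relevance():
    hidden, classes = 8, 3
    rng = np.random.default_rng(1)
    inp = rng.standard_normal(hidden).astype(np.float32)
    w = {
        "W_cls": rng.standard_normal((hidden, classes)).astype(np.float32),
        "b_cls": np.zeros(classes, dtype=np.float32),
    }
    wts = np.array([0.0, 1.0, 0.0], dtype=np.float32)  # relevance of one class
    rel = calculate_wt_classifier(wts, inp, w)
    return rel  # shape: (hidden,)

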
####################################################################
###################    Encoder-Decoder Model   ####################
####################################################################

def calculate_enc_dec_start_wt(arg, indices):
    y = np.zeros(arg.shape, dtype=np.float64)
    value = 1 / arg.shape[0]

    for i in range(arg.shape[0]):
        y[i][indices[i]] = value

    return y


def calculate_wt_lm_head(wts, inp, w):
    '''
    Input:
        wts: relevance score of the layer
        inp: input to the layer
        w: weights of the layer- ['W_lm_head']
    '''
    relevance_input = np.zeros(inp.shape)

    for i in range(wts.shape[0]):
        R = wts[i]
        contribution_matrix = np.einsum('ij,j->ij', w['W_lm_head'], inp[i])
        wt_mat = np.zeros(contribution_matrix.shape)

        for j in range(contribution_matrix.shape[0]):
            l1_ind1 = contribution_matrix[j]
            wt_ind1 = wt_mat[j]
            wt = R[j]

            p_ind = l1_ind1 > 0
            n_ind = l1_ind1 < 0

            p_sum = np.sum(l1_ind1[p_ind])
            n_sum = np.sum(l1_ind1[n_ind]) * -1

            if p_sum > 0:
                p_agg_wt = p_sum / (p_sum + n_sum)
            else:
                p_agg_wt = 0

            if n_sum > 0:
                n_agg_wt = n_sum / (p_sum + n_sum)
            else:
                n_agg_wt = 0

            if p_sum == 0:
                p_sum = 1
            if n_sum == 0:
                n_sum = 1

            wt_ind1[p_ind] = (l1_ind1[p_ind] / p_sum) * wt * p_agg_wt
            wt_ind1[n_ind] = (l1_ind1[n_ind] / n_sum) * wt * n_agg_wt * -1.0

        relevance_input[i] = wt_mat.sum(axis=0)

    return relevance_input


1545
+ def calculate_wt_cross_attention(wts, inp, w):
1546
+ '''
1547
+ Input:
1548
+ wts: relevance score of the layer
1549
+ inp: input to the layer
1550
+ w: weights of the layer- ['W_q', 'W_k', 'W_v', 'W_o']
1551
+ inputs: dict_keys(['query', 'key', 'value'])
1552
+
1553
+ Outputs:
1554
+ Step-1: outputs = torch.matmul(input_a, input_b)
1555
+ Step-2: outputs = F.softmax(inputs, dim=dim, dtype=dtype)
1556
+ Step-3: outputs = input_a * input_b
1557
+ '''
1558
+ k_v_inp, q_inp = inp
1559
+ query_output = np.einsum('ij,kj->ik', q_inp, w['W_q'].T)
1560
+ key_output = np.einsum('ij,kj->ik', k_v_inp, w['W_k'].T)
1561
+ value_output = np.einsum('ij,kj->ik', k_v_inp, w['W_v'].T)
1562
+
1563
+ # --------------- Relevance Calculation for Step-3 -----------------------
1564
+ relevance_V = wts / 2
1565
+ relevance_QK = wts / 2
1566
+
1567
+ # --------------- Relevance Calculation for V --------------------------------
1568
+ wt_mat_V = calculate_relevance_V(relevance_V, value_output)
1569
+
1570
+ # --------------- Transformed Relevance QK ----------------------------------
1571
+ QK_output = np.einsum('ij,kj->ik', query_output, key_output)
1572
+ wt_mat_QK = calculate_relevance_QK(relevance_QK, QK_output)
1573
+
1574
+ # --------------- Relevance Calculation for K and Q --------------------------------
1575
+ stabilized_QK_output = stabilize(QK_output * 2)
1576
+ norm_wt_mat_QK = wt_mat_QK / stabilized_QK_output
1577
+ wt_mat_Q = np.einsum('ij,jk->ik', norm_wt_mat_QK, key_output) * query_output
1578
+ wt_mat_K = np.einsum('ij,ik->kj', query_output, norm_wt_mat_QK) * key_output
1579
+
1580
+ wt_mat_KV = wt_mat_V + wt_mat_K
1581
+ wt_mat = [wt_mat_KV, wt_mat_Q]
1582
+ return wt_mat
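

# --- Illustrative sketch, not part of the released file: cross-attention takes
# --- [encoder_states, decoder_query] as inputs and returns a relevance pair
# --- [key/value stream, query stream]. The random weights and shapes are demo
# --- assumptions.
def _demo_cross_attention_relevance():
    src_tokens, tgt_tokens, d = 6, 4, 8
    rng = np.random.default_rng(2)
    k_v_inp = rng.standard_normal((src_tokens, d)).astype(np.float32)
    q_inp = rng.standard_normal((tgt_tokens, d)).astype(np.float32)
    w = {name: rng.standard_normal((d, d)).astype(np.float32)
         for name in ("W_q", "W_k", "W_v", "W_o")}
    wts = np.ones((tgt_tokens, d), dtype=np.float32)  # relevance at the output
    rel_kv, rel_q = calculate_wt_cross_attention(wts, [k_v_inp, q_inp], w)
    return rel_kv, rel_q  # shapes: (src_tokens, d) and (tgt_tokens, d)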