dl-backtrace 0.0.12__py3-none-any.whl → 0.0.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dl-backtrace might be problematic. Click here for more details.

Files changed (27) hide show
  1. dl_backtrace/pytorch_backtrace/backtrace/backtrace.py +173 -44
  2. dl_backtrace/pytorch_backtrace/backtrace/utils/__init__.py +3 -0
  3. dl_backtrace/pytorch_backtrace/backtrace/utils/encoder.py +183 -0
  4. dl_backtrace/pytorch_backtrace/backtrace/utils/encoder_decoder.py +489 -0
  5. dl_backtrace/pytorch_backtrace/backtrace/utils/helper.py +95 -0
  6. dl_backtrace/pytorch_backtrace/backtrace/utils/prop.py +481 -0
  7. dl_backtrace/tf_backtrace/backtrace/__init__.py +1 -2
  8. dl_backtrace/tf_backtrace/backtrace/activation_info.py +33 -0
  9. dl_backtrace/tf_backtrace/backtrace/backtrace.py +506 -279
  10. dl_backtrace/tf_backtrace/backtrace/models.py +25 -0
  11. dl_backtrace/tf_backtrace/backtrace/server.py +27 -0
  12. dl_backtrace/tf_backtrace/backtrace/utils/__init__.py +5 -2
  13. dl_backtrace/tf_backtrace/backtrace/utils/encoder.py +206 -0
  14. dl_backtrace/tf_backtrace/backtrace/utils/encoder_decoder.py +501 -0
  15. dl_backtrace/tf_backtrace/backtrace/utils/helper.py +99 -0
  16. dl_backtrace/tf_backtrace/backtrace/utils/utils_contrast.py +1132 -0
  17. dl_backtrace/tf_backtrace/backtrace/utils/utils_prop.py +1582 -0
  18. dl_backtrace/version.py +2 -2
  19. {dl_backtrace-0.0.12.dist-info → dl_backtrace-0.0.16.dist-info}/METADATA +3 -2
  20. dl_backtrace-0.0.16.dist-info/RECORD +29 -0
  21. {dl_backtrace-0.0.12.dist-info → dl_backtrace-0.0.16.dist-info}/WHEEL +1 -1
  22. dl_backtrace/tf_backtrace/backtrace/config.py +0 -41
  23. dl_backtrace/tf_backtrace/backtrace/utils/contrast.py +0 -834
  24. dl_backtrace/tf_backtrace/backtrace/utils/prop.py +0 -725
  25. dl_backtrace-0.0.12.dist-info/RECORD +0 -21
  26. {dl_backtrace-0.0.12.dist-info → dl_backtrace-0.0.16.dist-info}/LICENSE +0 -0
  27. {dl_backtrace-0.0.12.dist-info → dl_backtrace-0.0.16.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,1132 @@
1
+ import gc
2
+ import numpy as np
3
+ import tensorflow as tf
4
+ from tensorflow import keras
5
+ from tensorflow.keras import backend as K
6
+ from tensorflow.keras.backend import sigmoid
7
+ from numpy.lib.stride_tricks import as_strided
8
+
9
def np_swish(x, beta=0.75):
    """Swish activation ``x * sigmoid(beta * x)`` evaluated with NumPy.

    Args:
        x: scalar or array input.
        beta: slope applied inside the sigmoid gate.
    """
    gate = 1 / (1 + np.exp(-(beta * x)))
    return x * gate
12
+
13
def np_wave(x, alpha=1.0):
    """Wave activation ``alpha * x * e / (exp(-x) + exp(x))`` in NumPy."""
    numerator = alpha * x * np.exp(1.0)
    denominator = np.exp(-x) + np.exp(x)
    return numerator / denominator
15
+
16
def np_pulse(x, alpha=1.0):
    """Pulse activation ``alpha * (1 - tanh(x)**2)`` in NumPy."""
    t = np.tanh(x)
    return alpha * (1 - t * t)
18
+
19
def np_absolute(x, alpha=1.0):
    """Smooth absolute-value activation ``alpha * x * tanh(x)`` in NumPy."""
    scaled = alpha * x
    return scaled * np.tanh(x)
21
+
22
def np_hard_sigmoid(x):
    """Piecewise-linear sigmoid ``clip(0.2 * x + 0.5, 0, 1)``."""
    linear = 0.2 * x + 0.5
    return np.clip(linear, 0, 1)
24
+
25
def np_sigmoid(x):
    """Logistic sigmoid ``1 / (1 + exp(-x))`` in NumPy."""
    return 1 / (1 + np.exp(-x))
28
+
29
def np_tanh(x):
    """Hyperbolic tangent in NumPy, cast to ``float32``."""
    return np.tanh(x).astype(np.float32)
32
+
33
def _start_wt_neg_2d(arg, scaler=None):
    """Build negative start weights for a 2-D ``(batch, classes)`` prediction.

    Copies ``arg`` and overwrites the top-scoring entry of row 0. For a
    single-output model whose maximum is below 1 the entry becomes
    ``1 - max``; otherwise it becomes 0. When ``scaler`` is truthy and the
    copy has a positive sum, the copy is rescaled to sum to ``scaler``.
    """
    x = np.argmax(arg[0])
    m = np.max(arg[0])
    y_neg = np.array(arg)
    if m < 1 and arg.shape[-1] == 1:
        y_neg[0][x] = 1 - m
    else:
        y_neg[0][x] = 0
    if scaler and np.sum(y_neg) > 0:
        y_neg = y_neg * (scaler / np.sum(y_neg))
    return y_neg


def calculate_start_wt(arg, scaler=None, thresholding=0.5, task="binary-classification"):
    """Compute the initial positive/negative relevance for a model output.

    Args:
        arg: model output. Must be 2-D (``(batch, outputs)``) or 4-D
            (``(batch, H, W, C)``, segmentation-style). Only batch item 0 is used.
        scaler: if truthy, the total relevance to seed instead of the raw
            output values.
        thresholding: 4-D only; elements strictly above this value are
            treated as positive.
        task: ``"bbox-regression"`` seeds every output of row 0;
            ``"binary-segmentation"`` selects the 4-D segmentation rules;
            any other value uses the argmax/classification rules (2-D) or
            the generic thresholding rules (4-D).

    Returns:
        ``(y_pos[0], y_neg[0])`` — positive and negative start weights for
        the first batch item.

    Raises:
        ValueError: if ``arg`` is neither 2-D nor 4-D.
    """
    if arg.ndim == 2:
        if task == "bbox-regression":
            y_pos = np.zeros_like(arg)
            if scaler:
                # Spread the scaler evenly over all regression outputs.
                # (Previously referenced an undefined name ``y``.)
                y_pos[0] = scaler
                num_non_zero_elements = np.count_nonzero(y_pos)
                if num_non_zero_elements > 0:
                    y_pos = y_pos / num_non_zero_elements
            else:
                y_pos[0] = np.max(arg[0])
        else:
            # Classification (and any other 2-D task): all positive
            # relevance goes to the top-scoring output.
            x = np.argmax(arg[0])
            y_pos = np.zeros_like(arg)
            y_pos[0][x] = scaler if scaler else np.max(arg[0])
        # Every 2-D branch shares the same negative-weight rule; the
        # bbox+scaler branch previously never defined y_neg at all.
        y_neg = _start_wt_neg_2d(arg, scaler)
    elif arg.ndim == 4:
        indices = np.where(arg > thresholding)
        y_pos = np.zeros(arg.shape)
        if scaler:
            y_pos[indices] = scaler
        else:
            y_pos[indices] = arg[indices]
        # With a scaler the selected pixels always share it equally; without
        # one, only non-binary-segmentation tasks are normalised by count.
        if scaler or task != "binary-segmentation":
            num_non_zero_elements = np.count_nonzero(y_pos)
            if num_non_zero_elements > 0:
                y_pos = y_pos / num_non_zero_elements
        y_neg = np.array(arg)
        m = np.max(arg[0])
        # binary-segmentation also takes the complement when max == 1.
        use_complement = m <= 1 if task == "binary-segmentation" else m < 1
        if use_complement:
            y_neg[indices] = 1 - arg[indices]
        else:
            y_neg[indices] = 0
        if scaler and np.sum(y_neg) > 0:
            y_neg = y_neg * (scaler / np.sum(y_neg))
    else:
        raise ValueError(
            f"calculate_start_wt expects a 2-D or 4-D array, got ndim={arg.ndim}"
        )
    return y_pos[0], y_neg[0]
125
+
126
def calculate_base_wt(p_sum=0, n_sum=0, bias=0, wt_pos=0, wt_neg=0):
    """Decide how a unit's positive/negative relevance is split between its inputs.

    The sign of ``p_sum + bias - n_sum`` picks which of ``wt_pos`` /
    ``wt_neg`` is routed to the positive contributions (``p_agg_wt``) and
    which to the negative ones (``n_agg_wt``). ``wt_sign`` is ``1`` when
    ``wt_pos`` goes to the positive side and ``-1`` when the two are swapped.
    A zero net input routes nothing. Zero sums are replaced by 1 so callers
    can divide by them safely.

    Returns:
        ``(p_agg_wt, n_agg_wt, p_sum, n_sum, wt_sign)``
    """
    net = p_sum + bias - n_sum
    if net > 0:
        keep_order = wt_pos > wt_neg
    elif net < 0:
        keep_order = wt_pos < wt_neg
    else:
        keep_order = None

    if keep_order is None:
        p_agg_wt, n_agg_wt, wt_sign = 0, 0, 1
    elif keep_order:
        p_agg_wt, n_agg_wt, wt_sign = wt_pos, wt_neg, 1
    else:
        p_agg_wt, n_agg_wt, wt_sign = wt_neg, wt_pos, -1

    safe_p_sum = 1 if p_sum == 0 else p_sum
    safe_n_sum = 1 if n_sum == 0 else n_sum
    return p_agg_wt, n_agg_wt, safe_p_sum, safe_n_sum, wt_sign
154
+
155
def calculate_base_wt_array(p_sum=None, n_sum=None, bias=None, wt_pos=None, wt_neg=None):
    """Vectorised, element-wise counterpart of ``calculate_base_wt``.

    For each unit, decide how its positive (``wt_pos``) and negative
    (``wt_neg``) relevance is split between the unit's positive and negative
    input contributions, based on the sign of ``p_sum + bias - n_sum``.

    Args:
        p_sum: sum of positive contributions per unit.
        n_sum: magnitude of the summed negative contributions per unit.
        bias: bias per unit.
        wt_pos: positive relevance per unit.
        wt_neg: negative relevance per unit.
        Omitted arguments are treated as 0 (the old ``[]`` defaults could
        never work: list ``-`` list raises TypeError).

    Returns:
        ``(p_agg_wt_pos, p_agg_wt_neg, n_agg_wt_pos, n_agg_wt_neg, p_sum, n_sum)``.
        The ``p_agg_*`` terms are the shares of ``wt_pos`` / ``wt_neg`` routed
        to the positive contributions. The ``n_agg_*`` terms are the shares
        routed to the negative contributions. The returned ``p_sum`` and
        ``n_sum`` have zeros replaced by 1.0 so they are safe divisors. The
        caller's arrays are not modified (previously they were overwritten
        in place).
    """
    p_sum, n_sum, bias, wt_pos, wt_neg = (
        np.asarray(0.0 if v is None else v)
        for v in (p_sum, n_sum, bias, wt_pos, wt_neg)
    )
    t_diff = p_sum + bias - n_sum
    t_diff_pos = t_diff > 0
    t_diff_neg = t_diff < 0

    # "keep": wt_pos goes to the positive side, wt_neg to the negative side.
    # "swap": the reverse. Ties (wt_pos == wt_neg) fall into "swap", exactly
    # as in the scalar calculate_base_wt; the previous strict comparisons
    # matched neither case and silently dropped tied relevance.
    keep = (t_diff_pos & (wt_pos > wt_neg)) | (t_diff_neg & (wt_pos < wt_neg))
    swap = (t_diff_pos & (wt_pos <= wt_neg)) | (t_diff_neg & (wt_pos >= wt_neg))

    p_agg_wt_pos = wt_pos * keep
    p_agg_wt_neg = wt_neg * swap
    n_agg_wt_pos = wt_pos * swap
    n_agg_wt_neg = wt_neg * keep

    p_sum = np.where(p_sum == 0, 1.0, p_sum)
    n_sum = np.where(n_sum == 0, 1.0, n_sum)

    return p_agg_wt_pos, p_agg_wt_neg, n_agg_wt_pos, n_agg_wt_neg, p_sum, n_sum
182
+
183
class LSTM_forward(object):
    """Step-by-step TensorFlow replay of an LSTM layer that logs every intermediate.

    For each time step ``cell_num``, ``compute_log[cell_num]`` stores the
    step input (``"inp"``), the four gate pre-activations (``"x"``), the
    hidden and cell state before/after the step (``"hstate"``/``"cstate"``,
    ``[previous, new]``) and the gate values ``i``, ``f``, ``c``, ``o``
    (``"int_arrays"``). The four gate blocks of ``kernel``,
    ``recurrent_kernel`` and ``bias`` are taken in i, f, c, o order.

    NOTE(review): ``return_sequence`` and ``go_backwards`` are stored but not
    consulted by this class.
    """

    def __init__(self, num_cells, units, weights, return_sequence=False, go_backwards=False):
        self.num_cells = num_cells
        self.units = units
        self.kernel, self.recurrent_kernel, self.bias = weights[0], weights[1], weights[2]
        self.return_sequence = return_sequence
        self.go_backwards = go_backwards
        self.recurrent_activation = tf.math.sigmoid
        self.activation = tf.math.tanh

        # One empty log entry per time step, filled in by calculate_lstm_cell_wt.
        self.compute_log = {
            step: {
                "inp": None,
                "x": None,
                "hstate": [None, None],
                "cstate": [None, None],
                "int_arrays": {},
            }
            for step in range(self.num_cells)
        }

    def compute_carry_and_output(self, x, h_tm1, c_tm1, cell_num):
        """Computes carry and output using split kernels."""
        x_i, x_f, x_c, x_o = x
        h_i, h_f, h_c, h_o = h_tm1
        u = self.units
        rk = self.recurrent_kernel
        i = self.recurrent_activation(x_i + K.dot(h_i, rk[:, :u]))
        f = self.recurrent_activation(x_f + K.dot(h_f, rk[:, u:u * 2]))
        c = f * c_tm1 + i * self.activation(x_c + K.dot(h_c, rk[:, u * 2:u * 3]))
        o = self.recurrent_activation(x_o + K.dot(h_o, rk[:, u * 3:]))
        gate_log = self.compute_log[cell_num]["int_arrays"]
        gate_log["i"] = i
        gate_log["f"] = f
        gate_log["c"] = c
        gate_log["o"] = o
        return c, o

    def calculate_lstm_cell_wt(self, inputs, states, cell_num, training=None):
        """Run one LSTM step, logging its inputs, gates and states.

        Args:
            inputs: step input tensor.
            states: ``[h_prev, c_prev]``.
            cell_num: time-step index used as the log key.
            training: unused.

        Returns:
            ``(h, [h, c])`` — new hidden state and the new state pair.
        """
        h_prev, c_prev = states[0], states[1]
        entry = self.compute_log[cell_num]
        entry["inp"] = inputs
        entry["hstate"][0] = h_prev
        entry["cstate"][0] = c_prev

        kernels = tf.split(self.kernel, num_or_size_splits=4, axis=1)
        biases = tf.split(self.bias, num_or_size_splits=4, axis=0)
        # Gate pre-activations (i, f, c, o): inputs @ kernel_gate + bias_gate.
        x = tuple(tf.add(K.dot(inputs, k), b) for k, b in zip(kernels, biases))

        c, o = self.compute_carry_and_output(x, (h_prev, h_prev, h_prev, h_prev), c_prev, cell_num)
        h = o * self.activation(c)

        entry["x"] = x
        entry["hstate"][1] = h
        entry["cstate"][1] = c
        return h, [h, c]

    def calculate_lstm_wt(self, input_data):
        """Replay the layer over ``input_data`` row by row (one row per time step).

        Starts from zero hidden and cell states and returns the list of
        per-step hidden outputs.
        """
        hstate = tf.convert_to_tensor(np.zeros((1, self.units)), dtype=tf.float32)
        cstate = tf.convert_to_tensor(np.zeros((1, self.units)), dtype=tf.float32)
        n_features = input_data.shape[1]
        outputs = []
        for step in range(input_data.shape[0]):
            step_input = tf.convert_to_tensor(
                input_data[step, :].reshape((1, n_features)), dtype=tf.float32
            )
            h, (hstate, cstate) = self.calculate_lstm_cell_wt(step_input, [hstate, cstate], step)
            outputs.append(h)
        return outputs
269
+
270
class LSTM_backtrace(object):
    """Back-propagate ``[positive, negative]`` relevance through an LSTM layer (NumPy).

    Uses the intermediates recorded in an ``LSTM_forward`` replay's
    ``compute_log``. It walks the time steps from last to first and splits
    relevance over the input, forget, candidate and output gates with the
    contrast rules defined in the methods below.
    """
    def __init__(self, num_cells, units, weights, return_sequence=False, go_backwards=False):
        self.num_cells = num_cells
        self.units = units
        # weights = [kernel, recurrent_kernel, bias]; gate blocks are sliced
        # in i, f, c, o order below (the same order LSTM_forward uses).
        self.kernel = weights[0]
        self.recurrent_kernel = weights[1]
        self.bias = weights[2]
        self.return_sequence = return_sequence
        # NOTE(review): stored but never read in this class; reversed
        # sequences are not handled here — confirm with the caller.
        self.go_backwards = go_backwards
        # NumPy counterparts of the sigmoid/tanh used by LSTM_forward.
        self.recurrent_activation = np_sigmoid
        self.activation = np_tanh

        # Replaced by the forward replay's log in calculate_lstm_wt.
        self.compute_log = {}

    def calculate_wt_fc(self, wts, inp, w, b, act):
        """Distribute ``[pos, neg]`` relevance of a dense projection onto its inputs.

        Args:
            wts: ``[wts_pos, wts_neg]``, one entry per output unit.
            inp: input vector, indexed along the first axis of ``w``.
            w: weight matrix of shape ``(len(inp), n_out)`` (fixed by the einsum).
            b: bias per output unit, or an empty sequence for no bias.
            act: unused.

        Returns:
            ``(wt_mat_pos, wt_mat_neg)`` — relevance per input element.
        """
        wts_pos = wts[0]
        wts_neg = wts[1]
        # mul_mat[j, i] = w[i, j] * inp[i]: contribution of input i to output j.
        mul_mat = np.einsum("ij,i->ij",w,inp).T
        wt_mat_pos = np.zeros(mul_mat.shape)
        wt_mat_neg = np.zeros(mul_mat.shape)
        for i in range(mul_mat.shape[0]):
            l1_ind1 = mul_mat[i]
            # Row views: the masked writes below land in wt_mat_pos/wt_mat_neg.
            wt_ind1_pos = wt_mat_pos[i]
            wt_ind1_neg = wt_mat_neg[i]
            wt_pos = wts_pos[i]
            wt_neg = wts_neg[i]
            p_ind = l1_ind1>0
            n_ind = l1_ind1<0
            p_sum = np.sum(l1_ind1[p_ind])
            n_sum = np.sum(l1_ind1[n_ind])*-1
            if len(b)>0:
                bias = b[i]
            else:
                bias = 0
            if np.sum(n_ind)==0 and np.sum(p_ind)>0:
                # Only positive contributions: both relevance signs follow them proportionally.
                wt_ind1_pos[p_ind] = (l1_ind1[p_ind]/p_sum)*wt_pos
                wt_ind1_neg[p_ind] = (l1_ind1[p_ind]/p_sum)*wt_neg
            elif np.sum(n_ind)>0 and np.sum(p_ind)==0:
                # Only negative contributions.
                wt_ind1_pos[n_ind] = (l1_ind1[n_ind]/n_sum)*wt_pos*-1
                wt_ind1_neg[n_ind] = (l1_ind1[n_ind]/n_sum)*wt_neg*-1
            else:
                # Mixed signs (or all zero): calculate_base_wt decides which
                # side receives which relevance sign.
                p_agg_wt,n_agg_wt,p_sum,n_sum,wt_sign = calculate_base_wt(p_sum=p_sum,n_sum=n_sum,
                                                                         bias=bias,
                                                                         wt_pos=wt_pos,wt_neg=wt_neg)
                if wt_sign>0:
                    wt_ind1_pos[p_ind] = (l1_ind1[p_ind]/p_sum)*p_agg_wt
                    wt_ind1_neg[n_ind] = (l1_ind1[n_ind]/n_sum)*n_agg_wt*-1
                else:
                    wt_ind1_neg[p_ind] = (l1_ind1[p_ind]/p_sum)*p_agg_wt
                    wt_ind1_pos[n_ind] = (l1_ind1[n_ind]/n_sum)*n_agg_wt*-1
        # Sum over output units -> relevance per input element.
        wt_mat_pos = wt_mat_pos.sum(axis=0)
        wt_mat_neg = wt_mat_neg.sum(axis=0)
        return wt_mat_pos,wt_mat_neg


    def calculate_wt_add(self, wts, inp=None):
        """Split ``[pos, neg]`` relevance of an element-wise sum across its summands.

        Args:
            wts: ``[wts_pos, wts_neg]`` with the shape of each summand.
            inp: list of same-shaped summand arrays.

        Returns:
            A list with one ``(pos, neg)`` tuple per summand, each reshaped to
            ``wts_pos.shape`` / ``wts_neg.shape``.
        """
        wts_pos = wts[0]
        wts_neg = wts[1]
        wt_mat_pos = []
        wt_mat_neg = []
        inp_list = []
        for x in inp:
            wt_mat_pos.append(np.zeros_like(x))
            wt_mat_neg.append(np.zeros_like(x))
        wt_mat_pos = np.array(wt_mat_pos)
        wt_mat_neg = np.array(wt_mat_neg)
        inp_list = np.array(inp)
        # Column i holds element i of every summand.
        for i in range(wt_mat_pos.shape[1]):
            wt_ind1_pos = wt_mat_pos[:,i]
            wt_ind1_neg = wt_mat_neg[:,i]
            wt_pos = wts_pos[i]
            wt_neg = wts_neg[i]
            l1_ind1 = inp_list[:,i]
            p_ind = l1_ind1>0
            n_ind = l1_ind1<0
            p_sum = np.sum(l1_ind1[p_ind])
            n_sum = np.sum(l1_ind1[n_ind])*-1
            if np.sum(n_ind)==0 and np.sum(p_ind)>0:
                wt_ind1_pos[p_ind] = (l1_ind1[p_ind]/p_sum)*wt_pos
                wt_ind1_neg[p_ind] = (l1_ind1[p_ind]/p_sum)*wt_neg
            elif np.sum(n_ind)>0 and np.sum(p_ind)==0:
                wt_ind1_pos[n_ind] = (l1_ind1[n_ind]/n_sum)*wt_pos*-1
                wt_ind1_neg[n_ind] = (l1_ind1[n_ind]/n_sum)*wt_neg*-1
            else:
                # Same mixed-sign routing as calculate_wt_fc, with no bias.
                p_agg_wt,n_agg_wt,p_sum,n_sum,wt_sign = calculate_base_wt(p_sum=p_sum,n_sum=n_sum,
                                                                         bias=0.0,
                                                                         wt_pos=wt_pos,wt_neg=wt_neg)
                if wt_sign>0:
                    wt_ind1_pos[p_ind] = (l1_ind1[p_ind]/p_sum)*p_agg_wt
                    wt_ind1_neg[n_ind] = (l1_ind1[n_ind]/n_sum)*n_agg_wt*-1
                else:
                    wt_ind1_neg[p_ind] = (l1_ind1[p_ind]/p_sum)*p_agg_wt
                    wt_ind1_pos[n_ind] = (l1_ind1[n_ind]/n_sum)*n_agg_wt*-1
            wt_mat_pos[:,i] = wt_ind1_pos
            wt_mat_neg[:,i] = wt_ind1_neg
        wt_mat_pos = [i.reshape(wts_pos.shape) for i in list(wt_mat_pos)]
        wt_mat_neg = [i.reshape(wts_neg.shape) for i in list(wt_mat_neg)]
        output = []
        for i in range(len(wt_mat_pos)):
            output.append((wt_mat_pos[i],wt_mat_neg[i]))

        return output



    def calculate_wt_multiply(self, wts, inp=None):
        """Split ``[pos, neg]`` relevance of an element-wise product of two factors.

        Each factor receives a share of ``wts`` proportional to the magnitude
        of the *other* factor. Positions where either factor is zero receive
        nothing.

        Args:
            wts: ``[wts_pos, wts_neg]``.
            inp: ``[factor1, factor2]``.

        Returns:
            ``[[factor1_pos, factor1_neg], [factor2_pos, factor2_neg]]``.
        """
        wts_pos = wts[0]
        wts_neg = wts[1]
        # NOTE(review): inp_list / wt_mat_pos / wt_mat_neg below are built but never used.
        inp_list = []
        wt_mat_pos = []
        wt_mat_neg = []
        for x in inp:
            wt_mat_pos.append(np.zeros_like(x))
            wt_mat_neg.append(np.zeros_like(x))
        wt_mat_pos = np.array(wt_mat_pos)
        wt_mat_neg = np.array(wt_mat_neg)
        inp_list = np.array(inp)
        inp1 = np.abs(inp[0])
        inp2 = np.abs(inp[1])
        inp_sum = inp1+inp2
        inp_prod = inp1*inp2
        inp1[inp_sum==0] = 0
        inp2[inp_sum==0] = 0
        inp1[inp_prod==0] = 0
        inp2[inp_prod==0] = 0
        # Avoid 0/0 below; those positions already have zero numerators.
        inp_sum[inp_sum==0] = 1
        inp_wt1_pos = np.nan_to_num((inp2/inp_sum)*wts_pos)
        inp_wt1_neg = np.nan_to_num((inp2/inp_sum)*wts_neg)
        inp_wt2_pos = np.nan_to_num((inp1/inp_sum)*wts_pos)
        inp_wt2_neg = np.nan_to_num((inp1/inp_sum)*wts_neg)
        return [[inp_wt1_pos,inp_wt1_neg],[inp_wt2_pos,inp_wt2_neg]]


    def compute_carry_and_output(self, wt_o, wt_c, h_tm1, c_tm1, x, cell_num):
        """Computes carry and output using split kernels.

        Reverses the gate equations recorded for step ``cell_num``. Relevance
        is passed through the sigmoid/tanh non-linearities unchanged: each
        ``calculate_wt_add`` call splits a gate's relevance directly over the
        summands of its pre-activation.

        Args:
            wt_o: ``[pos, neg]`` relevance on the output gate ``o``.
            wt_c: ``[pos, neg]`` relevance on the new cell state ``c``.
            h_tm1, c_tm1: previous hidden / cell state (NumPy).
            x: ``[x_i, x_f, x_c, x_o]`` gate input projections (NumPy).
            cell_num: time-step index into ``compute_log``.

        Returns:
            Relevance on ``x_i, x_f, x_c, x_o``, on ``h_tm1`` through each of
            the four recurrent gate blocks, and on ``c_tm1``.
        """
        h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o = (h_tm1,h_tm1,h_tm1,h_tm1)
        x_i, x_f, x_c, x_o = x
        f = self.compute_log[cell_num]["int_arrays"]["f"].numpy()[0]
        i = self.compute_log[cell_num]["int_arrays"]["i"].numpy()[0]
        # Output gate: o = sigmoid(x_o + h_tm1 @ R_o).
        temp1 = np.dot(h_tm1_o, self.recurrent_kernel[:, self.units * 3:]).astype(np.float32)
        wt_x_o, wt_temp1 = self.calculate_wt_add(wt_o,[x_o,temp1])
        wt_h_tm1_o = self.calculate_wt_fc(wt_temp1, h_tm1_o, self.recurrent_kernel[:, self.units * 3:], [], {"type":None})


        # Cell state: c = f * c_tm1 + i * tanh(x_c + h_tm1 @ R_c).
        temp2 = f*c_tm1
        temp3_1 = np.dot(h_tm1_c, self.recurrent_kernel[:, self.units * 2:self.units * 3])
        temp3_2 = self.activation(x_c + temp3_1)
        temp3_3 = i*temp3_2
        wt_temp2,wt_temp3_3 = self.calculate_wt_add(wt_c,[temp2,temp3_3])
        wt_f, wt_c_tm1 = self.calculate_wt_multiply(wt_temp2,[f,c_tm1])
        wt_i, wt_temp3_2 = self.calculate_wt_multiply(wt_temp3_3,[i,temp3_2])
        wt_x_c,wt_temp3_1 = self.calculate_wt_add(wt_temp3_2,[x_c,temp3_1])
        wt_h_tm1_c = self.calculate_wt_fc(wt_temp3_1, h_tm1_c, self.recurrent_kernel[:, self.units * 2:self.units * 3], [], {"type":None})

        # Forget gate: f = sigmoid(x_f + h_tm1 @ R_f).
        temp4 = np.dot(h_tm1_f, self.recurrent_kernel[:, self.units:self.units * 2])
        wt_x_f, wt_temp4 = self.calculate_wt_add(wt_f,[x_f,temp4])
        wt_h_tm1_f = self.calculate_wt_fc(wt_temp4, h_tm1_f, self.recurrent_kernel[:, self.units:self.units * 2], [], {"type":None})

        # Input gate: i = sigmoid(x_i + h_tm1 @ R_i).
        temp5 = np.dot(h_tm1_i, self.recurrent_kernel[:, :self.units])
        wt_x_i, wt_temp5 = self.calculate_wt_add(wt_i,[x_i,temp5])
        wt_h_tm1_i = self.calculate_wt_fc(wt_temp5, h_tm1_i, self.recurrent_kernel[:, :self.units], [], {"type":None})


        return (wt_x_i, wt_x_f, wt_x_c, wt_x_o,
                wt_h_tm1_i, wt_h_tm1_f, wt_h_tm1_c, wt_h_tm1_o,
                wt_c_tm1)


    def calculate_lstm_cell_wt(self, cell_num, wts_hstate, wts_cstate):
        """Back-propagate relevance through one time step.

        Args:
            cell_num: time-step index into ``compute_log``.
            wts_hstate: ``[pos, neg]`` relevance on this step's hidden output.
            wts_cstate: ``[pos, neg]`` relevance on this step's cell state
                (flowing back from the next step).

        Returns:
            ``(wt_inputs, wt_h_tm1, wt_c_tm1)`` — ``[pos, neg]`` relevance on
            the step input, the previous hidden state and the previous cell state.
        """
        o = self.compute_log[cell_num]["int_arrays"]["o"].numpy()[0]
        c = self.compute_log[cell_num]["cstate"][1].numpy()[0]
        h_tm1 = self.compute_log[cell_num]["hstate"][0].numpy()[0]
        c_tm1 = self.compute_log[cell_num]["cstate"][0].numpy()[0]
        x = [i.numpy()[0] for i in self.compute_log[cell_num]["x"]]
        wt_o,wt_c = self.calculate_wt_multiply(wts_hstate,[o,self.activation(c)])# h = o * self.activation(c)
        # Cell-state relevance from the next step adds to what reaches c via h.
        wt_c[0] = wt_c[0]+wts_cstate[0]
        wt_c[1] = wt_c[1]+wts_cstate[1]
        wt_x_i, wt_x_f, wt_x_c, wt_x_o, wt_h_tm1_i, wt_h_tm1_f, wt_h_tm1_c, wt_h_tm1_o, wt_c_tm1 = self.compute_carry_and_output(wt_o, wt_c, h_tm1, c_tm1, x, cell_num)
        # h_tm1 feeds all four gates; sum its relevance across them.
        wt_h_tm1 = [wt_h_tm1_i[0] + wt_h_tm1_f[0] + wt_h_tm1_c[0] + wt_h_tm1_o[0],
                    wt_h_tm1_i[1] + wt_h_tm1_f[1] + wt_h_tm1_c[1] + wt_h_tm1_o[1]]
        inputs = self.compute_log[cell_num]["inp"].numpy()[0]
        k_i, k_f, k_c, k_o = np.split(
            self.kernel, indices_or_sections=4, axis=1)
        b_i, b_f, b_c, b_o = np.split(
            self.bias, indices_or_sections=4, axis=0)

        wt_inputs_i = self.calculate_wt_fc(wt_x_i, inputs, k_i, b_i, {"type":None})
        wt_inputs_f = self.calculate_wt_fc(wt_x_f, inputs, k_f, b_f, {"type":None})
        wt_inputs_c = self.calculate_wt_fc(wt_x_c, inputs, k_c, b_c, {"type":None})
        wt_inputs_o = self.calculate_wt_fc(wt_x_o, inputs, k_o, b_o, {"type":None})

        # The step input also feeds all four gates; sum its relevance too.
        wt_inputs = [wt_inputs_i[0]+wt_inputs_f[0]+wt_inputs_c[0]+wt_inputs_o[0],
                     wt_inputs_i[1]+wt_inputs_f[1]+wt_inputs_c[1]+wt_inputs_o[1]]

        return wt_inputs, wt_h_tm1, wt_c_tm1

    def calculate_lstm_wt(self,wts_pos, wts_neg, compute_log):
        """Back-propagate relevance through all time steps, last to first.

        Args:
            wts_pos, wts_neg: relevance on the layer output. With
                ``return_sequence`` there is one row per time step: the last
                row seeds the final step, and earlier rows are added to the
                hidden-state relevance as the walk moves backwards.
            compute_log: the forward replay's log (see ``LSTM_forward``).

        Returns:
            ``(output_pos, output_neg)`` — per-step input relevance stacked
            in forward time order.
        """
        self.compute_log = compute_log
        output_pos = []
        output_neg = []
        if self.return_sequence:
            temp_wts_hstate = [wts_pos[-1,:],wts_neg[-1,:]]
        else:
            temp_wts_hstate = [wts_pos,wts_neg]
        # No relevance reaches the final cell state from outside the layer.
        temp_wts_cstate = [np.zeros_like(self.compute_log[0]["cstate"][1].numpy()[0]),
                           np.zeros_like(self.compute_log[0]["cstate"][1].numpy()[0])]
        for ind in range(len(self.compute_log)-1,-1,-1):
            temp_wt_inp, temp_wts_hstate, temp_wts_cstate = self.calculate_lstm_cell_wt(ind, temp_wts_hstate, temp_wts_cstate)
            output_pos.append(temp_wt_inp[0])
            output_neg.append(temp_wt_inp[1])
            if self.return_sequence and ind>0:
                temp_wts_hstate[0] = temp_wts_hstate[0]+wts_pos[ind-1,:]
                temp_wts_hstate[1] = temp_wts_hstate[1]+wts_neg[ind-1,:]
        # Collected last-to-first; restore forward time order.
        output_pos.reverse()
        output_pos = np.array(output_pos)
        output_neg.reverse()
        output_neg = np.array(output_neg)
        return output_pos,output_neg
488
+
489
def dummy_wt(wts, inp, *args):
    """Return all-zero relevance shaped like ``inp``; ``wts`` and extra args are ignored."""
    return np.zeros_like(inp)
492
+
493
def calculate_wt_fc(wts_pos, wts_neg, inp, w, b, act=None):
    """Distribute positive/negative relevance of a dense layer onto its inputs.

    Args:
        wts_pos: positive relevance per output unit.
        wts_neg: negative relevance per output unit.
        inp: input vector (indexed along the first axis of ``w``).
        w: weight matrix of shape ``(len(inp), n_out)``. May be a TensorFlow
            tensor (anything with ``.numpy()``) or an array-like.
        b: bias per output unit, as a tensor, array-like, or ``None`` for no bias.
        act: unused; the default changed from a shared mutable ``{}`` to ``None``.

    Returns:
        ``(wt_mat_pos, wt_mat_neg)`` — relevance per input element.
    """
    w_arr = w.numpy() if hasattr(w, "numpy") else np.asarray(w)
    # mul_mat[j, i] = w[i, j] * inp[i]: contribution of input i to output j.
    mul_mat = np.einsum("ij,i->ij", w_arr, inp).T
    # Converted once instead of calling b.numpy() for every output unit.
    if b is None:
        b_arr = np.zeros(mul_mat.shape[0])
    else:
        b_arr = b.numpy() if hasattr(b, "numpy") else np.asarray(b)
    wt_mat_pos = np.zeros(mul_mat.shape)
    wt_mat_neg = np.zeros(mul_mat.shape)
    for i in range(mul_mat.shape[0]):
        l1_ind1 = mul_mat[i]
        # Row views: masked writes below land in wt_mat_pos / wt_mat_neg.
        wt_ind1_pos = wt_mat_pos[i]
        wt_ind1_neg = wt_mat_neg[i]
        wt_pos = wts_pos[i]
        wt_neg = wts_neg[i]
        p_ind = l1_ind1 > 0
        n_ind = l1_ind1 < 0
        p_sum = np.sum(l1_ind1[p_ind])
        n_sum = np.sum(l1_ind1[n_ind]) * -1
        if np.sum(n_ind) == 0 and np.sum(p_ind) > 0:
            # Only positive contributions: both signs follow them proportionally.
            wt_ind1_pos[p_ind] = (l1_ind1[p_ind] / p_sum) * wt_pos
            wt_ind1_neg[p_ind] = (l1_ind1[p_ind] / p_sum) * wt_neg
        elif np.sum(n_ind) > 0 and np.sum(p_ind) == 0:
            # Only negative contributions.
            wt_ind1_pos[n_ind] = (l1_ind1[n_ind] / n_sum) * wt_pos * -1
            wt_ind1_neg[n_ind] = (l1_ind1[n_ind] / n_sum) * wt_neg * -1
        else:
            # Mixed signs: calculate_base_wt decides which side gets which sign.
            p_agg_wt, n_agg_wt, p_sum, n_sum, wt_sign = calculate_base_wt(
                p_sum=p_sum, n_sum=n_sum, bias=b_arr[i], wt_pos=wt_pos, wt_neg=wt_neg
            )
            if wt_sign > 0:
                wt_ind1_pos[p_ind] = (l1_ind1[p_ind] / p_sum) * p_agg_wt
                wt_ind1_neg[n_ind] = (l1_ind1[n_ind] / n_sum) * n_agg_wt * -1
            else:
                wt_ind1_neg[p_ind] = (l1_ind1[p_ind] / p_sum) * p_agg_wt
                wt_ind1_pos[n_ind] = (l1_ind1[n_ind] / n_sum) * n_agg_wt * -1
    # Sum over output units -> relevance per input element.
    wt_mat_pos = wt_mat_pos.sum(axis=0)
    wt_mat_neg = wt_mat_neg.sum(axis=0)
    return wt_mat_pos, wt_mat_neg
526
+
527
def calculate_wt_passthru(wts):
    """Identity relevance rule: hand back ``wts`` itself (same object, no copy)."""
    passed_through = wts
    return passed_through
529
+
530
def calculate_wt_rshp(wts, inp=None):
    """Reshape relevance ``wts`` back to the shape of the layer input ``inp``."""
    target_shape = inp.shape
    return np.reshape(wts, target_shape)
533
+
534
def calculate_wt_concat(wts, inp=None, axis=-1):
    """Split concatenated relevance back into one piece per concatenated input.

    Args:
        wts: relevance of the concatenated output. It has no batch axis, so
            a positive ``axis`` is shifted down by one.
        inp: the concatenated inputs. Their sizes along ``axis`` (counted
            including their batch axis) define the split.
        axis: concatenation axis as given to the layer.

    Returns:
        List with exactly one relevance array per element of ``inp``.
    """
    sizes = [x.shape[axis] for x in inp]
    # np.split takes the interior boundaries only. Passing the full cumsum
    # (whose last entry is the total length) produced a spurious trailing
    # empty piece.
    boundaries = np.cumsum(sizes)[:-1]
    if axis > 0:
        axis = axis - 1
    return np.split(wts, indices_or_sections=boundaries, axis=axis)
541
+
542
def calculate_wt_add(wts_pos, wts_neg, inp=None):
    """Split positive/negative relevance of an element-wise Add across its summands.

    Args:
        wts_pos: positive relevance, shaped like each summand.
        wts_neg: negative relevance, shaped like each summand.
        inp: list of same-sized summand arrays.

    Returns:
        A list with one ``(pos, neg)`` tuple per summand, each reshaped to the
        shape of ``wts_pos`` / ``wts_neg``.
    """
    # np.ravel flattens in logical (C) order and copies only when needed.
    # The previous as_strided(..., strides=(x.strides[-1],)) flattening was
    # only correct for C-contiguous arrays; for transposed or sliced inputs
    # it read the wrong (even out-of-bounds) memory.
    flat_wts_pos = np.ravel(wts_pos)
    flat_wts_neg = np.ravel(wts_neg)
    inp_list = np.array([np.ravel(x) for x in inp])  # (n_summands, n_elements)
    wt_mat_pos = np.zeros_like(inp_list)
    wt_mat_neg = np.zeros_like(inp_list)
    for i in range(wt_mat_pos.shape[1]):
        # Column views: masked writes land directly in wt_mat_pos / wt_mat_neg.
        wt_ind1_pos = wt_mat_pos[:, i]
        wt_ind1_neg = wt_mat_neg[:, i]
        wt_pos = flat_wts_pos[i]
        wt_neg = flat_wts_neg[i]
        l1_ind1 = inp_list[:, i]
        p_ind = l1_ind1 > 0
        n_ind = l1_ind1 < 0
        p_sum = np.sum(l1_ind1[p_ind])
        n_sum = np.sum(l1_ind1[n_ind]) * -1
        if np.sum(n_ind) == 0 and np.sum(p_ind) > 0:
            wt_ind1_pos[p_ind] = (l1_ind1[p_ind] / p_sum) * wt_pos
            wt_ind1_neg[p_ind] = (l1_ind1[p_ind] / p_sum) * wt_neg
        elif np.sum(n_ind) > 0 and np.sum(p_ind) == 0:
            wt_ind1_pos[n_ind] = (l1_ind1[n_ind] / n_sum) * wt_pos * -1
            wt_ind1_neg[n_ind] = (l1_ind1[n_ind] / n_sum) * wt_neg * -1
        else:
            # Mixed signs: calculate_base_wt decides the routing (no bias).
            p_agg_wt, n_agg_wt, p_sum, n_sum, wt_sign = calculate_base_wt(
                p_sum=p_sum, n_sum=n_sum, bias=0.0, wt_pos=wt_pos, wt_neg=wt_neg
            )
            if wt_sign > 0:
                wt_ind1_pos[p_ind] = (l1_ind1[p_ind] / p_sum) * p_agg_wt
                wt_ind1_neg[n_ind] = (l1_ind1[n_ind] / n_sum) * n_agg_wt * -1
            else:
                wt_ind1_neg[p_ind] = (l1_ind1[p_ind] / p_sum) * p_agg_wt
                wt_ind1_pos[n_ind] = (l1_ind1[n_ind] / n_sum) * n_agg_wt * -1
    pos_shape = np.shape(wts_pos)
    neg_shape = np.shape(wts_neg)
    return [
        (row_pos.reshape(pos_shape), row_neg.reshape(neg_shape))
        for row_pos, row_neg in zip(wt_mat_pos, wt_mat_neg)
    ]
605
+
606
def calculate_wt_zero_pad(wts_pos, wts_neg, inp, padding):
    """Crop padded relevance back to the un-padded input region.

    ``padding[0][0]`` / ``padding[1][0]`` are the leading (top / left)
    offsets. The crop is ``inp.shape[0]`` x ``inp.shape[1]``, and all channels
    are kept.
    """
    top = padding[0][0]
    left = padding[1][0]
    rows = slice(top, top + inp.shape[0])
    cols = slice(left, left + inp.shape[1])
    return wts_pos[rows, cols, :], wts_neg[rows, cols, :]
610
+
611
def calculate_padding(kernel_size, inp, padding, strides, const_val=0.0):
    """Pad an ``(H, W, C)`` array the way a 'same' conv/pool would.

    For ``padding == 'valid'`` the input is returned untouched with
    all-zero paddings. Otherwise the total pad per spatial axis is
    ``max(0, k - s)`` when the size divides by the stride, else
    ``max(0, k - size % s)``. It is split floor/ceil between leading and
    trailing sides. Pads are filled with ``const_val``; channels are not padded.

    Returns:
        ``(padded_input, paddings)``
    """
    if padding == 'valid':
        return (inp, [[0, 0], [0, 0], [0, 0]])

    def _total_pad(size, k, s):
        remainder = size % s
        return np.max([0, k - (s if remainder == 0 else remainder)])

    pad_h = _total_pad(inp.shape[0], kernel_size[0], strides[0])
    pad_v = _total_pad(inp.shape[1], kernel_size[1], strides[1])
    paddings = [
        np.floor([pad_h / 2.0, (pad_h + 1) / 2.0]).astype("int32"),
        np.floor([pad_v / 2.0, (pad_v + 1) / 2.0]).astype("int32"),
        np.zeros((2)).astype("int32"),
    ]
    inp_pad = np.pad(inp, paddings, 'constant', constant_values=const_val)
    return (inp_pad, paddings)
632
+
633
def calculate_wt_conv_unit(patch, wts_pos, wts_neg, w, b, act):
    """Distribute one output position's relevance over its input patch.

    Args:
        patch: input window, shaped like the first three axes of the kernel.
        wts_pos: positive relevance per output filter at this position.
        wts_neg: negative relevance per output filter at this position.
        w: kernel tensor with ``.numpy()``, indexed ``(kh, kw, c_in, c_out)``
            per the einsum.
        b: bias tensor with ``.numpy()``, one value per output filter.
        act: unused.

    Returns:
        ``(wt_mat_pos, wt_mat_neg)`` shaped like ``patch``; the filter
        axis has been summed out.
    """
    kernel = w.numpy()
    bias = b.numpy()
    # contrib[..., l] = kernel[..., l] * patch: per-element contribution to filter l.
    contrib = np.einsum("ijkl,ijk->ijkl", kernel, patch)
    pos_part = contrib * (contrib > 0)
    neg_part = contrib * (contrib < 0)
    pos_total = np.einsum("ijkl->l", pos_part)
    neg_total = np.einsum("ijkl->l", neg_part) * -1.0
    (p_agg_wt_pos, p_agg_wt_neg, n_agg_wt_pos, n_agg_wt_neg,
     pos_total, neg_total) = calculate_base_wt_array(pos_total, neg_total, bias, wts_pos, wts_neg)

    wt_mat_pos = (np.zeros_like(kernel)
                  + ((pos_part / pos_total) * p_agg_wt_pos)
                  + ((neg_part / neg_total) * n_agg_wt_pos) * -1.0)
    wt_mat_neg = (np.zeros_like(kernel)
                  + ((pos_part / pos_total) * p_agg_wt_neg)
                  + ((neg_part / neg_total) * n_agg_wt_neg) * -1.0)
    return np.sum(wt_mat_pos, axis=-1), np.sum(wt_mat_neg, axis=-1)
655
+
656
def calculate_wt_conv(wts_pos, wts_neg, inp, w, b, padding, strides, act):
    """Back-propagate positive/negative relevance through a 2-D convolution.

    Pads ``inp`` like the layer did. Each output position's relevance is
    spread over its input window with ``calculate_wt_conv_unit`` and
    accumulated, and the result is cropped back to ``inp``'s spatial size.

    Returns:
        ``(out_ds_pos, out_ds_neg)`` shaped like ``inp``.
    """
    input_padded, paddings = calculate_padding(w.shape, inp, padding, strides)
    out_ds_pos = np.zeros_like(input_padded)
    out_ds_neg = np.zeros_like(input_padded)
    k_h, k_w = w.shape[0], w.shape[1]
    s_h, s_w = strides[0], strides[1]
    for row in range(wts_pos.shape[0]):
        rows = np.arange(row * s_h, row * s_h + k_h)
        for col in range(wts_pos.shape[1]):
            cols = np.arange(col * s_w, col * s_w + k_w)
            window = np.ix_(rows, cols)
            upd_pos, upd_neg = calculate_wt_conv_unit(
                input_padded[window], wts_pos[row, col, :], wts_neg[row, col, :], w, b, act
            )
            out_ds_pos[window] += upd_pos
            out_ds_neg[window] += upd_neg
    top, left = paddings[0][0], paddings[1][0]
    out_ds_pos = out_ds_pos[top:(top + inp.shape[0]), left:(left + inp.shape[1]), :]
    out_ds_neg = out_ds_neg[top:(top + inp.shape[0]), left:(left + inp.shape[1]), :]
    return out_ds_pos, out_ds_neg
675
+
676
def calculate_wt_max_unit(patch, wts, pool_size):
    """Route each channel's relevance to the max element(s) of an ``(H, W, C)`` patch.

    Ties share the channel's relevance equally. ``pool_size`` is unused.
    """
    channel_max = np.max(np.max(patch, axis=0), axis=0)
    pmax = np.einsum("ijk,k->ijk", np.ones_like(patch), channel_max)
    is_max = ((patch - pmax) == 0).astype(np.float32)
    inv_counts = 1.0 / np.einsum("mnc->c", is_max)
    share = np.einsum("ijk,k->ijk", is_max, inv_counts)
    return np.einsum("ijk,k->ijk", share, wts)
684
+
685
def calculate_wt_maxpool(wts, inp, pool_size, padding, strides):
    """Back-propagate relevance through 2-D max-pooling.

    Pads ``inp`` with ``-inf`` so padding never wins the max. Each window's
    relevance goes to its maxima via ``calculate_wt_max_unit``, and the
    result is cropped back to ``inp``'s spatial size.
    """
    input_padded, paddings = calculate_padding(pool_size, inp, padding, strides, -np.inf)
    out_ds = np.zeros_like(input_padded)
    p_h, p_w = pool_size[0], pool_size[1]
    s_h, s_w = strides[0], strides[1]
    for row in range(wts.shape[0]):
        rows = np.arange(row * s_h, row * s_h + p_h)
        for col in range(wts.shape[1]):
            cols = np.arange(col * s_w, col * s_w + p_w)
            window = np.ix_(rows, cols)
            out_ds[window] += calculate_wt_max_unit(input_padded[window], wts[row, col, :], pool_size)
    top, left = paddings[0][0], paddings[1][0]
    return out_ds[top:(top + inp.shape[0]), left:(left + inp.shape[1]), :]
700
+
701
def calculate_wt_avg_unit(patch, wts_pos, wts_neg, pool_size):
    """Distribute one pooled position's relevance over its ``(H, W, C)`` window.

    Each element gets relevance proportional to its share of the channel's
    positive or negative total. The routing comes from
    ``calculate_base_wt_array`` with zero bias. ``pool_size`` is unused.

    Returns:
        ``(wt_mat_pos, wt_mat_neg)`` shaped like ``patch``.
    """
    pos_part = patch * (patch > 0)
    neg_part = patch * (patch < 0)
    pos_total = np.einsum("ijk->k", pos_part)
    neg_total = np.einsum("ijk->k", neg_part) * -1.0
    zero_bias = np.zeros_like(wts_pos)
    (p_agg_wt_pos, p_agg_wt_neg, n_agg_wt_pos, n_agg_wt_neg,
     pos_total, neg_total) = calculate_base_wt_array(pos_total, neg_total, zero_bias, wts_pos, wts_neg)

    wt_mat_pos = (np.zeros_like(patch)
                  + ((pos_part / pos_total) * p_agg_wt_pos)
                  + ((neg_part / neg_total) * n_agg_wt_pos) * -1.0)
    wt_mat_neg = (np.zeros_like(patch)
                  + ((pos_part / pos_total) * p_agg_wt_neg)
                  + ((neg_part / neg_total) * n_agg_wt_neg) * -1.0)
    return wt_mat_pos, wt_mat_neg
718
+
719
def calculate_wt_avgpool(wts_pos, wts_neg, inp, pool_size, padding, strides, act=None):
    """Back-propagate positive/negative relevance through 2-D average pooling.

    Pads ``inp`` with zeros like the layer did. Each window's relevance is
    spread with ``calculate_wt_avg_unit`` and accumulated, and the result is
    cropped back to ``inp``'s spatial size.

    Args:
        wts_pos, wts_neg: relevance of the pooled output, indexed ``[row, col, channel]``.
        inp: un-padded layer input (``(H, W, C)``).
        pool_size, padding, strides: pooling configuration.
        act: unused. The default changed from a shared mutable ``{}`` to
            ``None``, and callers passing a value are unaffected.

    Returns:
        ``(out_ds_pos, out_ds_neg)`` shaped like ``inp``.
    """
    input_padded, paddings = calculate_padding(pool_size, inp, padding, strides)
    out_ds_pos = np.zeros_like(input_padded)
    out_ds_neg = np.zeros_like(input_padded)
    for ind1 in range(wts_pos.shape[0]):
        rows = np.arange(ind1 * strides[0], ind1 * strides[0] + pool_size[0])
        for ind2 in range(wts_pos.shape[1]):
            cols = np.arange(ind2 * strides[1], ind2 * strides[1] + pool_size[1])
            window = np.ix_(rows, cols)
            updates_pos, updates_neg = calculate_wt_avg_unit(
                input_padded[window], wts_pos[ind1, ind2, :], wts_neg[ind1, ind2, :], pool_size
            )
            # Windows may overlap, so accumulate rather than assign.
            out_ds_pos[window] += updates_pos
            out_ds_neg[window] += updates_neg
    top, left = paddings[0][0], paddings[1][0]
    out_ds_pos = out_ds_pos[top:(top + inp.shape[0]), left:(left + inp.shape[1]), :]
    out_ds_neg = out_ds_neg[top:(top + inp.shape[0]), left:(left + inp.shape[1]), :]
    return out_ds_pos, out_ds_neg
739
+
740
def calculate_wt_gavgpool(wts_pos, wts_neg, inp):
    """Back-propagate relevance through global average pooling, channel by channel.

    For channel ``c``, relevance ``wts_pos[c]`` / ``wts_neg[c]`` is spread
    over ``inp[..., c]`` in proportion to each element's share of the
    channel's positive or negative total. Mixed-sign channels use
    ``calculate_base_wt`` (zero bias) to decide the routing.

    Returns:
        ``(wt_mat_pos, wt_mat_neg)`` shaped like ``inp``.
    """
    wt_mat_pos = np.zeros_like(inp)
    wt_mat_neg = np.zeros_like(inp)
    for c in range(wts_pos.shape[0]):
        wt_pos, wt_neg = wts_pos[c], wts_neg[c]
        x = inp[..., c]
        p_mat = x.copy()
        p_mat[x < 0] = 0
        n_mat = x.copy()
        n_mat[x > 0] = 0
        p_sum = np.sum(p_mat)
        n_sum = np.sum(n_mat) * -1
        base_pos = wt_mat_pos[..., c]
        base_neg = wt_mat_neg[..., c]

        if n_sum == 0 and p_sum > 0:
            new_pos = base_pos + ((p_mat / p_sum) * wt_pos)
            new_neg = base_neg + ((p_mat / p_sum) * wt_neg)
        elif n_sum > 0 and p_sum == 0:
            new_pos = base_pos + ((n_mat / n_sum) * wt_pos * -1)
            new_neg = base_neg + ((n_mat / n_sum) * wt_neg * -1)
        else:
            p_agg_wt, n_agg_wt, p_sum, n_sum, wt_sign = calculate_base_wt(
                p_sum=p_sum, n_sum=n_sum, bias=0, wt_pos=wt_pos, wt_neg=wt_neg
            )
            if wt_sign > 0:
                new_pos = base_pos + ((p_mat / p_sum) * p_agg_wt)
                new_neg = base_neg + ((n_mat / n_sum) * n_agg_wt * -1)
            else:
                new_neg = base_neg + ((p_mat / p_sum) * p_agg_wt)
                new_pos = base_pos + ((n_mat / n_sum) * n_agg_wt * -1)

        wt_mat_pos[..., c] = new_pos
        wt_mat_neg[..., c] = new_neg
    return wt_mat_pos, wt_mat_neg
775
+
776
def calculate_wt_gmaxpool_2d(wts, inp):
    """Route each channel's weight to the position(s) holding that channel's maximum.

    For channel ``c`` the value ``wts[c]`` is split equally among all
    positions where ``inp[..., c]`` attains its maximum. Every other position
    receives zero.

    Args:
        wts: Array indexed by channel along its first axis. Only channels
            ``0 .. wts.shape[0] - 1`` are filled.
        inp: Input array whose last axis is the channel axis.

    Returns:
        Array shaped like ``inp``. Floating-point inputs keep their dtype.
        Integer/bool inputs yield ``float64``, so fractional tie shares are
        not truncated.
    """
    inp = np.asarray(inp)
    # np.zeros_like(inp) would inherit an integer dtype and silently truncate
    # tie-split shares such as 0.5 down to 0.
    out_dtype = inp.dtype if np.issubdtype(inp.dtype, np.floating) else np.float64
    wt_mat = np.zeros(inp.shape, dtype=out_dtype)
    for c in range(wts.shape[0]):
        x = inp[..., c]
        max_indexes = (x == np.max(x)).astype(np.float32)
        # Ties split the weight evenly.
        max_indexes = max_indexes * (1.0 / np.sum(max_indexes))
        wt_mat[..., c] = max_indexes * wts[c]
    return wt_mat
788
+
789
def weight_scaler(arg, scaler=100.0):
    """Rescale ``arg`` so that its elements sum to ``scaler``.

    Args:
        arg: Array-like of weights.
        scaler: Target total (default 100.0).

    Returns:
        ``arg / (sum(arg) / scaler)`` as a NumPy array. If ``sum(arg)`` is
        zero the ratio is undefined, so ``arg`` is returned unchanged (the
        same fallback ``weight_normalize`` uses for all-zero input) rather
        than an array of ``inf``/``nan`` values.
    """
    arg = np.asarray(arg)
    total = np.sum(arg)
    if total == 0:
        # A zero scale factor would turn every element into inf/nan.
        return arg
    return arg / (total / scaler)
793
+
794
def weight_normalize(arg, max_val=1.0):
    """Scale ``arg`` so its largest-magnitude element becomes ``+/-max_val``.

    The divisor is ``max(|arg|)``, i.e. the larger of ``max(arg)`` and
    ``|min(arg)|``. When that divisor is not positive (all elements zero) or
    is NaN, ``arg`` itself is returned unchanged.
    """
    peak = np.max(np.abs(arg))
    if not peak > 0:
        return arg
    return (arg / peak) * max_val
803
+
804
def calculate_padding_1d(kernel_size, inp, padding, strides, const_val=0.0):
    """Pad the first (length) axis of ``inp`` following TF/Keras 'same' padding.

    Args:
        kernel_size: Window length, as an int or a length-1 sequence
            (Keras layers store e.g. ``pool_size=(2,)``).
        inp: Array whose first axis is the length axis. Any further axes
            (e.g. channels) are left unpadded.
        padding: ``'valid'`` returns ``inp`` untouched. Any other value is
            treated as ``'same'``.
        strides: Stride, as an int or a length-1 sequence.
        const_val: Fill value for the padded positions (e.g. ``-np.inf`` for
            max pooling).

    Returns:
        Tuple ``(padded, [pad_left, pad_right])``.
    """
    if padding == 'valid':
        return inp, [0, 0]

    # Accept Keras-style 1-tuples as well as plain ints.
    if np.ndim(kernel_size) > 0:
        kernel_size = kernel_size[0]
    if np.ndim(strides) > 0:
        strides = strides[0]

    remainder = inp.shape[0] % strides
    if remainder == 0:
        pad_total = max(0, kernel_size - strides)
    else:
        pad_total = max(0, kernel_size - remainder)

    # TF puts the odd extra element on the right.
    pad_left = pad_total // 2
    pad_right = pad_total - pad_left

    # Pad only the length axis. Passing a bare (left, right) pair to np.pad
    # would pad *every* axis, including channels.
    pad_width = [(pad_left, pad_right)] + [(0, 0)] * (inp.ndim - 1)
    inp_pad = np.pad(inp, pad_width, 'constant', constant_values=const_val)
    return inp_pad, [pad_left, pad_right]
819
+
820
def calculate_wt_conv_unit_1d(patch, wts_pos, wts_neg, w, b, act):
    """Split one 1D-convolution window's weights over its input elements.

    Each kernel entry is multiplied by the matching patch value
    (``einsum("ijk,ij->ijk")``). The resulting contributions are separated
    into positive and negative parts and summed per last-axis index ``k``
    (presumably per filter; confirm against the kernel layout). Those sums,
    the bias and the incoming weights go to ``calculate_base_wt_array``. Each
    contribution then receives its proportional share of the returned
    aggregates, and the shares are summed over ``k``.

    Args:
        patch: Window of the padded input, indexed ``[i, j]``.
        wts_pos, wts_neg: Positive/negative weights indexed by ``k``.
        w, b: Kernel and bias objects exposing ``.numpy()``.
        act: Unused.

    Returns:
        Tuple ``(wt_mat_pos, wt_mat_neg)`` with the kernel's first two axes.
    """
    kernel = w.numpy()
    bias = b.numpy()
    contrib = np.einsum("ijk,ij->ijk", kernel, patch)
    pos_part = contrib * (contrib > 0)
    neg_part = contrib * (contrib < 0)
    pos_total = np.einsum("ijk->k", pos_part)
    neg_total = np.einsum("ijk->k", neg_part) * -1.0

    (p_agg_wt_pos, p_agg_wt_neg, n_agg_wt_pos, n_agg_wt_neg,
     pos_total, neg_total) = calculate_base_wt_array(pos_total, neg_total, bias, wts_pos, wts_neg)

    pos_share = pos_part / pos_total
    neg_share = neg_part / neg_total

    wt_mat_pos = np.zeros_like(kernel) + pos_share * p_agg_wt_pos
    wt_mat_pos = wt_mat_pos + (neg_share * n_agg_wt_pos) * -1.0
    wt_mat_neg = np.zeros_like(kernel) + pos_share * p_agg_wt_neg
    wt_mat_neg = wt_mat_neg + (neg_share * n_agg_wt_neg) * -1.0

    return np.sum(wt_mat_pos, axis=-1), np.sum(wt_mat_neg, axis=-1)
842
+
843
def calculate_wt_conv_1d(wts_pos, wts_neg, inp, w, b, padding, stride, act):
    """Propagate positive/negative weights back through a 1D convolution.

    The input is padded with ``calculate_padding_1d``. For each output step
    the window of ``w.shape[0]`` padded positions starting at
    ``step * stride`` is handed to ``calculate_wt_conv_unit_1d``. The results
    accumulate (overlapping windows add up), and the padding is cropped off.

    Args:
        wts_pos, wts_neg: Weights indexed as ``[step, :]``.
        inp: Layer input; the result is cropped to ``inp.shape[0]`` rows.
        w, b: Kernel and bias forwarded to the unit function.
        padding: Padding mode for ``calculate_padding_1d``.
        stride: Used in integer arithmetic (``step * stride``), so it must be
            an int here.
        act: Forwarded unchanged to the unit function.

    Returns:
        Tuple ``(out_ds_pos, out_ds_neg)``.
    """
    kernel_len = w.shape[0]
    input_padded, paddings = calculate_padding_1d(kernel_len, inp, padding, stride)
    out_ds_pos = np.zeros_like(input_padded)
    out_ds_neg = np.zeros_like(input_padded)
    for step in range(wts_pos.shape[0]):
        taps = np.arange(step * stride, step * stride + kernel_len)
        upd_pos, upd_neg = calculate_wt_conv_unit_1d(
            input_padded[taps], wts_pos[step, :], wts_neg[step, :], w, b, act
        )
        out_ds_pos[taps] += upd_pos
        out_ds_neg[taps] += upd_neg
    lo = paddings[0]
    hi = lo + inp.shape[0]
    return out_ds_pos[lo:hi], out_ds_neg[lo:hi]
858
+
859
def calculate_wt_max_unit_1d(patch, wts, pool_size):
    """Share each column's weight among the rows holding that column's maximum.

    Args:
        patch: 2-D pooling window indexed ``[row, column]``.
        wts: Weights, one per column.
        pool_size: Unused.

    Returns:
        Array shaped like ``patch``. Each column's weight is split evenly over
        its tied maxima, and all other rows are zero.
    """
    peak = np.max(patch, axis=0)
    tied = ((patch - peak) == 0).astype(np.float32)
    inv_ties = 1.0 / np.sum(tied, axis=0)
    normalized = np.einsum("ij,j->ij", tied, inv_ties)
    return np.einsum("ij,j->ij", normalized, wts)
867
+
868
def calculate_wt_maxpool_1d(wts, inp, pool_size, padding, strides):
    """Propagate weights back through a 1D max-pooling layer.

    The input is padded with ``-inf`` so padded positions can never be the
    window maximum. Each window's weights are routed to its maxima by
    ``calculate_wt_max_unit_1d``, accumulated, and the padding is cropped off.

    Args:
        wts: Weights indexed as ``[step, :]``.
        inp: Layer input; the result is cropped to ``inp.shape[0]`` rows.
        pool_size: Window length, as a length-1 sequence (Keras style) or an int.
        padding: Padding mode for ``calculate_padding_1d``.
        strides: Stride, as a length-1 sequence or an int.

    Returns:
        Array of accumulated weights, cropped to the unpadded length.
    """
    # Unwrap the Keras 1-tuples *before* computing the padding. Previously the
    # raw tuples reached calculate_padding_1d, whose integer arithmetic raised
    # TypeError for any non-'valid' padding.
    pool = pool_size[0] if np.ndim(pool_size) > 0 else pool_size
    stride = strides[0] if np.ndim(strides) > 0 else strides
    input_padded, paddings = calculate_padding_1d(pool, inp, padding, stride, -np.inf)
    out_ds = np.zeros_like(input_padded)
    for ind in range(wts.shape[0]):
        indexes = np.arange(ind * stride, ind * stride + pool)
        updates = calculate_wt_max_unit_1d(input_padded[indexes], wts[ind, :], pool)
        out_ds[indexes] += updates
    return out_ds[paddings[0]:(paddings[0] + inp.shape[0])]
880
+
881
def calculate_wt_avg_unit_1d(patch, wts_pos, wts_neg, pool_size):
    """Split one 1D average-pooling window's weights over the window's elements.

    The patch is separated into positive and negative parts, summed along
    axis 0, and passed to ``calculate_base_wt_array`` with a zero bias. Each
    element then receives its proportional share of the returned positive and
    negative aggregates.

    Args:
        patch: Pooling window; axis 0 runs over the pooled positions.
        wts_pos, wts_neg: Positive/negative weights, one per column of ``patch``.
        pool_size: Unused.

    Returns:
        Tuple ``(wt_mat_pos, wt_mat_neg)`` shaped like ``patch``.
    """
    positive = patch * (patch > 0)
    negative = patch * (patch < 0)
    zero_bias = np.zeros_like(wts_pos)

    (p_agg_wt_pos, p_agg_wt_neg, n_agg_wt_pos, n_agg_wt_neg,
     pos_total, neg_total) = calculate_base_wt_array(
        np.sum(positive, axis=0), np.sum(negative, axis=0) * -1.0, zero_bias, wts_pos, wts_neg
    )

    pos_frac = positive / pos_total
    neg_frac = negative / neg_total

    wt_mat_pos = np.zeros_like(patch) + pos_frac * p_agg_wt_pos
    wt_mat_pos = wt_mat_pos + (neg_frac * n_agg_wt_pos) * -1.0
    wt_mat_neg = np.zeros_like(patch) + pos_frac * p_agg_wt_neg
    wt_mat_neg = wt_mat_neg + (neg_frac * n_agg_wt_neg) * -1.0
    return wt_mat_pos, wt_mat_neg
898
+
899
def calculate_wt_avgpool_1d(wts_pos, wts_neg, inp, pool_size, padding, strides, act=None):
    """Propagate positive/negative weights back through a 1D average-pooling layer.

    The input is zero-padded with ``calculate_padding_1d``. Each window's
    weights are split by ``calculate_wt_avg_unit_1d`` and accumulated, and
    the padding is cropped off.

    Args:
        wts_pos, wts_neg: Weights indexed as ``[step, :]``.
        inp: Layer input; the result is cropped to ``inp.shape[0]`` rows.
        pool_size: Window length, as a length-1 sequence (Keras style) or an int.
        padding: Padding mode for ``calculate_padding_1d``.
        strides: Stride, as a length-1 sequence or an int.
        act: Unused; kept for signature compatibility (was a mutable ``{}``
            default).

    Returns:
        Tuple ``(out_ds_pos, out_ds_neg)``.
    """
    # Unwrap before padding. Passing the raw tuples to calculate_padding_1d
    # raised TypeError for any non-'valid' padding.
    pool = pool_size[0] if np.ndim(pool_size) > 0 else pool_size
    stride = strides[0] if np.ndim(strides) > 0 else strides
    input_padded, paddings = calculate_padding_1d(pool, inp, padding, stride)
    out_ds_pos = np.zeros_like(input_padded)
    out_ds_neg = np.zeros_like(input_padded)
    for ind in range(wts_pos.shape[0]):
        indexes = np.arange(ind * stride, ind * stride + pool)
        updates_pos, updates_neg = calculate_wt_avg_unit_1d(
            input_padded[indexes], wts_pos[ind, :], wts_neg[ind, :], pool
        )
        out_ds_pos[indexes] += updates_pos
        out_ds_neg[indexes] += updates_neg

    out_ds_pos = out_ds_pos[paddings[0]:(paddings[0] + inp.shape[0])]
    out_ds_neg = out_ds_neg[paddings[0]:(paddings[0] + inp.shape[0])]
    return out_ds_pos, out_ds_neg
915
+
916
def calculate_wt_gavgpool_1d(wts_pos, wts_neg, inp):
    """Spread per-channel weights back over a global-average-pooled 1D input.

    This body was a verbatim copy of ``calculate_wt_gavgpool``. That function
    only accesses ``inp`` as ``inp[..., c]``, which works unchanged for an
    input whose last axis is the channel axis. Delegating keeps the two
    implementations from drifting apart.

    Returns:
        Tuple ``(wt_mat_pos, wt_mat_neg)`` shaped like ``inp``.
    """
    return calculate_wt_gavgpool(wts_pos, wts_neg, inp)
951
+
952
def calculate_wt_gmaxpool_1d(wts, inp):
    """Route each channel's weight to the position(s) holding that channel's maximum.

    For channel ``c`` the value ``wts[c]`` is split equally among all rows
    where ``inp[:, c]`` attains its maximum. All other rows receive zero.

    Args:
        wts: Array indexed by channel along its first axis.
        inp: 2-D array indexed ``[position, channel]``.

    Returns:
        Array shaped like ``inp``. Floating-point inputs keep their dtype.
        Integer/bool inputs yield ``float64``, so fractional tie shares are
        not truncated.
    """
    inp = np.asarray(inp)
    # np.zeros_like(inp) would inherit an integer dtype and truncate e.g. a
    # 0.5 tie share to 0.
    out_dtype = inp.dtype if np.issubdtype(inp.dtype, np.floating) else np.float64
    wt_mat = np.zeros(inp.shape, dtype=out_dtype)
    for c in range(wts.shape[0]):
        x = inp[:, c]
        max_indexes = (x == np.max(x)).astype(np.float32)
        # Ties split the weight evenly.
        max_indexes = max_indexes * (1.0 / np.sum(max_indexes))
        wt_mat[:, c] = max_indexes * wts[c]
    return wt_mat
964
+
965
def calculate_output_padding_conv2d_transpose(input_shape, kernel_size, padding, strides):
    """Compute the spatial output size and paddings of a 2D transposed convolution.

    Args:
        input_shape: Input shape; only the first two entries (rows, cols) are used.
        kernel_size: Kernel shape; only the first two entries are used.
        padding: ``'valid'``; any other value is treated as ``'same'``.
        strides: Two-element sequence of strides.

    Returns:
        ``(out_shape, paddings)``. ``out_shape`` is a 2-element list. For
        ``'valid'``, ``paddings`` is ``[[0, 0], [0, 0], [0, 0]]``. For
        ``'same'``, it is three int32 arrays: ``[floor(p/2), floor((p+1)/2)]``
        for each spatial axis, where ``p`` is the overhang
        ``(n - 1) * s + k - n * s`` clamped at 0, followed by ``[0, 0]`` for
        channels.
    """
    axes = range(2)
    if padding == 'valid':
        out_shape = [(input_shape[d] - 1) * strides[d] + kernel_size[d] for d in axes]
        return (out_shape, [[0, 0], [0, 0], [0, 0]])

    out_shape = [input_shape[d] * strides[d] for d in axes]
    paddings = []
    for d in axes:
        overhang = max(0, (input_shape[d] - 1) * strides[d] + kernel_size[d] - out_shape[d])
        paddings.append(np.floor([overhang / 2.0, (overhang + 1) / 2.0]).astype("int32"))
    paddings.append(np.zeros((2)).astype("int32"))
    return (out_shape, paddings)
978
+
979
def calculate_wt_conv2d_transpose_unit(patch, wts_pos, wts_neg, w, b, act):
    """Split one input position's weights over the kernel of a 2D transposed convolution.

    The kernel is reordered with ``tf.transpose(w, perm=[0, 1, 3, 2])``, and
    each entry is scaled by the patch value along axis ``k``
    (``einsum('ijkl,mnk->ijkl')``, which also sums over the patch's leading
    ``m, n`` axes; both have size 1 for a 1-D patch). The contributions are
    split by sign and summed per ``l``. Those sums, the raw bias and the
    incoming weights go to ``calculate_base_wt_array``. Every contribution
    then gets its proportional share of the returned aggregates, and the
    result is summed over ``l``.

    Args:
        patch: 1-D channel vector of one input position, or an already
            expanded 2-D/3-D array (1-D and 2-D inputs are reshaped to 3-D).
        wts_pos, wts_neg: Positive/negative weights indexed by ``l``.
        w: Kernel tensor (transposed via TensorFlow).
        b: Bias tensor; ``b.numpy()`` is passed as-is to ``calculate_base_wt_array``.
        act: Unused.

    Returns:
        Tuple ``(wt_mat_pos, wt_mat_neg)`` shaped like the transposed kernel
        without its last axis.

    Raises:
        ValueError: if ``patch`` is not 1-, 2- or 3-dimensional.
    """
    if patch.ndim == 1:
        patch = patch.reshape(1, 1, -1)
    elif patch.ndim == 2:
        patch = patch.reshape(1, *patch.shape)
    elif patch.ndim != 3:
        raise ValueError(f"Unexpected patch shape: {patch.shape}")

    k = tf.transpose(w, perm=[0, 1, 3, 2]).numpy()
    # The raw bias goes to calculate_base_wt_array. The former bias_pos /
    # bias_neg split was computed but never used, so it has been removed.
    bias = b.numpy()
    conv_out = np.einsum('ijkl,mnk->ijkl', k, patch)
    p_ind = conv_out * (conv_out > 0)
    n_ind = conv_out * (conv_out < 0)
    p_sum = np.einsum("ijkl->l", p_ind)
    n_sum = np.einsum("ijkl->l", n_ind) * -1.0
    (p_agg_wt_pos, p_agg_wt_neg, n_agg_wt_pos, n_agg_wt_neg,
     p_sum, n_sum) = calculate_base_wt_array(p_sum, n_sum, bias, wts_pos, wts_neg)

    wt_mat_pos = np.zeros_like(k)
    wt_mat_neg = np.zeros_like(k)
    wt_mat_pos = wt_mat_pos + ((p_ind / p_sum) * p_agg_wt_pos)
    wt_mat_pos = wt_mat_pos + ((n_ind / n_sum) * n_agg_wt_pos) * -1.0
    wt_mat_neg = wt_mat_neg + ((p_ind / p_sum) * p_agg_wt_neg)
    wt_mat_neg = wt_mat_neg + ((n_ind / n_sum) * n_agg_wt_neg) * -1.0

    return np.sum(wt_mat_pos, axis=-1), np.sum(wt_mat_neg, axis=-1)
1013
+
1014
def calculate_wt_conv2d_transpose(wts_pos, wts_neg, inp, w, b, padding, strides, act):
    """Propagate positive/negative weights back through a 2D transposed convolution.

    For every input position ``(r, c)``, ``calculate_wt_conv2d_transpose_unit``
    produces a kernel-sized weight map. That map is added into an
    output-sized accumulator at offset ``(r * strides[0], c * strides[1])``,
    clipped to the accumulator bounds. Afterwards:

    * ``'same'``: each input position receives the sum of its
      ``strides[0] x strides[1]`` block of the accumulator;
    * otherwise: an ``inp``-sized window is sliced out at the offsets given by
      ``calculate_output_padding_conv2d_transpose`` (which are zero for
      ``'valid'``). NOTE(review): for ``'valid'`` this keeps only the top-left
      region of the accumulator; confirm this is intended.

    Returns:
        Tuple ``(out_ds_pos, out_ds_neg)``.
    """
    out_shape, paddings = calculate_output_padding_conv2d_transpose(inp.shape, w.shape, padding, strides)
    k_rows = w.shape[0]
    k_cols = w.shape[1]
    acc_shape = out_shape + [w.shape[3]]
    out_ds_pos = np.zeros(acc_shape)
    out_ds_neg = np.zeros(acc_shape)

    for r in range(inp.shape[0]):
        r_lo = r * strides[0]
        r_hi = min(r_lo + k_rows, out_shape[0])
        for c in range(inp.shape[1]):
            c_lo = c * strides[1]
            c_hi = min(c_lo + k_cols, out_shape[1])
            upd_pos, upd_neg = calculate_wt_conv2d_transpose_unit(
                inp[r, c, :], wts_pos[r, c, :], wts_neg[r, c, :], w, b, act
            )
            out_ds_pos[r_lo:r_hi, c_lo:c_hi, :] += upd_pos[:r_hi - r_lo, :c_hi - c_lo, :]
            out_ds_neg[r_lo:r_hi, c_lo:c_hi, :] += upd_neg[:r_hi - r_lo, :c_hi - c_lo, :]

    if padding == 'same':
        folded_pos = np.zeros(inp.shape)
        folded_neg = np.zeros(inp.shape)
        for i in range(inp.shape[0]):
            i_lo = max(0, i * strides[0])
            i_hi = min(out_ds_pos.shape[0], (i + 1) * strides[0])
            for j in range(inp.shape[1]):
                j_lo = max(0, j * strides[1])
                j_hi = min(out_ds_pos.shape[1], (j + 1) * strides[1])
                folded_pos[i, j, :] = np.sum(out_ds_pos[i_lo:i_hi, j_lo:j_hi, :], axis=(0, 1))
                folded_neg[i, j, :] = np.sum(out_ds_neg[i_lo:i_hi, j_lo:j_hi, :], axis=(0, 1))
        return folded_pos, folded_neg

    top = paddings[0][0]
    left = paddings[1][0]
    return (out_ds_pos[top:(top + inp.shape[0]), left:(left + inp.shape[1]), :],
            out_ds_neg[top:(top + inp.shape[0]), left:(left + inp.shape[1]), :])
1054
+
1055
def calculate_output_padding_conv1d_transpose(input_shape, kernel_size, padding, strides):
    """Compute the output length and paddings of a 1D transposed convolution.

    Args:
        input_shape: Input shape; only ``input_shape[0]`` (length) is used.
        kernel_size: Kernel shape; only ``kernel_size[0]`` is used.
        padding: ``'valid'``; any other value is treated as ``'same'``.
        strides: Integer stride.

    Returns:
        ``(out_shape, paddings)``. ``out_shape`` is a 1-element list. For
        ``'valid'``, ``paddings`` is ``[0, 0]``. For ``'same'``, it is an int32
        array ``[floor(p/2), floor((p+1)/2)]``, where ``p`` is the overhang
        clamped at 0.
    """
    length = input_shape[0]
    if padding == 'valid':
        return ([(length - 1) * strides + kernel_size[0]], [0, 0])
    out_len = length * strides
    overhang = max(0, (length - 1) * strides + kernel_size[0] - out_len)
    split = np.floor([overhang / 2.0, (overhang + 1) / 2.0]).astype("int32")
    return ([out_len], split)
1064
+
1065
def calculate_wt_conv1d_transpose_unit(patch, wts_pos, wts_neg, w, b, act):
    """Split one input position's weights over the kernel of a 1D transposed convolution.

    The kernel is reordered with ``tf.transpose(w, perm=[0, 2, 1])``, and each
    entry is scaled by the patch value along axis ``j``
    (``einsum('ijk,mj->ijk')``, which also sums over the patch's leading
    ``m`` axis; it has size 1 for a 1-D patch). The contributions are split by
    sign and summed per ``k``. Those sums, the raw bias and the incoming
    weights go to ``calculate_base_wt_array``. Every contribution then gets
    its proportional share of the returned aggregates, and the result is
    summed over ``k``.

    Args:
        patch: 1-D channel vector of one input position, or a 2-D array.
        wts_pos, wts_neg: Positive/negative weights indexed by ``k``.
        w: Kernel tensor (transposed via TensorFlow).
        b: Bias tensor; ``b.numpy()`` is passed as-is to ``calculate_base_wt_array``.
        act: Unused.

    Returns:
        Tuple ``(wt_mat_pos, wt_mat_neg)`` shaped like the transposed kernel
        without its last axis.

    Raises:
        ValueError: if ``patch`` is not 1- or 2-dimensional.
    """
    if patch.ndim == 1:
        patch = patch.reshape(1, -1)
    elif patch.ndim != 2:
        raise ValueError(f"Unexpected patch shape: {patch.shape}")

    k = tf.transpose(w, perm=[0, 2, 1]).numpy()
    # The raw bias goes to calculate_base_wt_array. The former bias_pos /
    # bias_neg split was computed but never used, so it has been removed.
    bias = b.numpy()

    conv_out = np.einsum('ijk,mj->ijk', k, patch)
    p_ind = conv_out * (conv_out > 0)
    n_ind = conv_out * (conv_out < 0)
    p_sum = np.einsum("ijk->k", p_ind)
    n_sum = np.einsum("ijk->k", n_ind) * -1.0

    (p_agg_wt_pos, p_agg_wt_neg, n_agg_wt_pos, n_agg_wt_neg,
     p_sum, n_sum) = calculate_base_wt_array(p_sum, n_sum, bias, wts_pos, wts_neg)

    # Out-of-place adds, as in the sibling unit functions. The previous
    # in-place `+=` on np.zeros_like(k) silently cast the weight maps back to
    # the kernel's dtype.
    wt_mat_pos = np.zeros_like(k)
    wt_mat_neg = np.zeros_like(k)
    wt_mat_pos = wt_mat_pos + (p_ind / p_sum) * p_agg_wt_pos
    wt_mat_pos = wt_mat_pos + (n_ind / n_sum) * n_agg_wt_pos * -1.0
    wt_mat_neg = wt_mat_neg + (p_ind / p_sum) * p_agg_wt_neg
    wt_mat_neg = wt_mat_neg + (n_ind / n_sum) * n_agg_wt_neg * -1.0

    return np.sum(wt_mat_pos, axis=-1), np.sum(wt_mat_neg, axis=-1)
1099
+
1100
def calculate_wt_conv1d_transpose(wts_pos, wts_neg, inp, w, b, padding, strides, act):
    """Propagate positive/negative weights back through a 1D transposed convolution.

    For every input position ``pos``, ``calculate_wt_conv1d_transpose_unit``
    produces a kernel-length weight map. That map is added into an
    output-length accumulator at offset ``pos * strides``, clipped to the
    accumulator bounds. Afterwards:

    * ``'same'``: each input position receives the sum of its ``strides``-long
      block of the accumulator;
    * otherwise: ``inp.shape[0]`` rows are sliced out starting at
      ``paddings[0]`` (zero for ``'valid'``). NOTE(review): that keeps only
      the leading region of the accumulator; confirm this is intended.

    Args:
        strides: Integer stride (used as ``pos * strides``).

    Returns:
        Tuple ``(out_ds_pos, out_ds_neg)``.
    """
    out_shape, paddings = calculate_output_padding_conv1d_transpose(inp.shape, w.shape, padding, strides)
    kernel_len = w.shape[0]
    acc_shape = out_shape + [w.shape[2]]
    out_ds_pos = np.zeros(acc_shape)
    out_ds_neg = np.zeros(acc_shape)

    for pos in range(inp.shape[0]):
        start = pos * strides
        stop = min(start + kernel_len, out_shape[0])
        upd_pos, upd_neg = calculate_wt_conv1d_transpose_unit(
            inp[pos, :], wts_pos[pos, :], wts_neg[pos, :], w, b, act
        )
        out_ds_pos[start:stop, :] += upd_pos[:stop - start, :]
        out_ds_neg[start:stop, :] += upd_neg[:stop - start, :]

    if padding == 'same':
        folded_pos = np.zeros(inp.shape)
        folded_neg = np.zeros(inp.shape)
        for i in range(inp.shape[0]):
            lo = max(0, i * strides)
            hi = min(out_ds_pos.shape[0], (i + 1) * strides)
            folded_pos[i, :] = np.sum(out_ds_pos[lo:hi, :], axis=0)
            folded_neg[i, :] = np.sum(out_ds_neg[lo:hi, :], axis=0)
        return folded_pos, folded_neg

    offset = paddings[0]
    return (out_ds_pos[offset:(offset + inp.shape[0]), :],
            out_ds_neg[offset:(offset + inp.shape[0]), :])