dl-backtrace 0.0.16.dev4__py3-none-any.whl → 0.0.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dl-backtrace might be problematic. Click here for more details.

Files changed (25) hide show
  1. dl_backtrace/old_backtrace/__init__.py +1 -0
  2. dl_backtrace/old_backtrace/pytorch_backtrace/__init__.py +1 -0
  3. dl_backtrace/old_backtrace/pytorch_backtrace/backtrace/__init__.py +4 -0
  4. dl_backtrace/old_backtrace/pytorch_backtrace/backtrace/backtrace.py +639 -0
  5. dl_backtrace/old_backtrace/pytorch_backtrace/backtrace/config.py +41 -0
  6. dl_backtrace/old_backtrace/pytorch_backtrace/backtrace/utils/__init__.py +2 -0
  7. dl_backtrace/old_backtrace/pytorch_backtrace/backtrace/utils/contrast.py +840 -0
  8. dl_backtrace/old_backtrace/pytorch_backtrace/backtrace/utils/prop.py +746 -0
  9. dl_backtrace/old_backtrace/tf_backtrace/__init__.py +1 -0
  10. dl_backtrace/old_backtrace/tf_backtrace/backtrace/__init__.py +4 -0
  11. dl_backtrace/old_backtrace/tf_backtrace/backtrace/backtrace.py +527 -0
  12. dl_backtrace/old_backtrace/tf_backtrace/backtrace/config.py +41 -0
  13. dl_backtrace/old_backtrace/tf_backtrace/backtrace/utils/__init__.py +2 -0
  14. dl_backtrace/old_backtrace/tf_backtrace/backtrace/utils/contrast.py +834 -0
  15. dl_backtrace/old_backtrace/tf_backtrace/backtrace/utils/prop.py +725 -0
  16. dl_backtrace/tf_backtrace/backtrace/backtrace.py +5 -3
  17. dl_backtrace/tf_backtrace/backtrace/utils/utils_prop.py +53 -0
  18. dl_backtrace/version.py +2 -2
  19. dl_backtrace-0.0.17.dist-info/METADATA +164 -0
  20. dl_backtrace-0.0.17.dist-info/RECORD +44 -0
  21. dl_backtrace-0.0.16.dev4.dist-info/METADATA +0 -102
  22. dl_backtrace-0.0.16.dev4.dist-info/RECORD +0 -29
  23. {dl_backtrace-0.0.16.dev4.dist-info → dl_backtrace-0.0.17.dist-info}/LICENSE +0 -0
  24. {dl_backtrace-0.0.16.dev4.dist-info → dl_backtrace-0.0.17.dist-info}/WHEEL +0 -0
  25. {dl_backtrace-0.0.16.dev4.dist-info → dl_backtrace-0.0.17.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,840 @@
1
+ import gc
2
+
3
+ import numpy as np
4
+ import tensorflow as tf
5
+ from numpy.lib.stride_tricks import as_strided
6
+ from tensorflow.keras import backend as K
7
+
8
+
9
def np_swish(x, beta=0.75):
    """Return ``x * sigmoid(beta * x)`` computed with NumPy.

    Args:
        x: Scalar or array input.
        beta: Slope applied to ``x`` inside the sigmoid.
    """
    return x * (1 / (1 + np.exp(-(beta * x))))
12
+
13
+
14
def np_wave(x, alpha=1.0):
    """Return ``alpha * x * e / (exp(-x) + exp(x))`` computed with NumPy."""
    numerator = alpha * x * np.exp(1.0)
    denominator = np.exp(-x) + np.exp(x)
    return numerator / denominator
16
+
17
+
18
def np_pulse(x, alpha=1.0):
    """Return ``alpha * (1 - tanh(x) * tanh(x))`` computed with NumPy."""
    squared_tanh = np.tanh(x) * np.tanh(x)
    return alpha * (1 - squared_tanh)
20
+
21
+
22
def np_absolute(x, alpha=1.0):
    """Return ``alpha * x * tanh(x)`` computed with NumPy."""
    scaled = alpha * x
    return scaled * np.tanh(x)
24
+
25
+
26
def np_hard_sigmoid(x):
    """Return the line ``0.2 * x + 0.5`` clipped to the interval [0, 1]."""
    linear = 0.2 * x + 0.5
    return np.clip(linear, 0, 1)
28
+
29
+
30
def np_sigmoid(x):
    """Return the logistic sigmoid ``1 / (1 + exp(-x))`` computed with NumPy."""
    return 1 / (1 + np.exp(-x))
33
+
34
+
35
def np_tanh(x):
    """Return ``tanh(x)`` cast to ``float32``."""
    return np.tanh(x).astype(np.float32)
38
+
39
+
40
def calculate_start_wt(arg, max_wt=1):
    """Build the initial positive/negative relevance vectors from ``arg``.

    Only the first row of ``arg`` is inspected. The positive vector is zero
    everywhere except at the arg-max position, which carries the maximum
    value. The negative vector is a copy of the row whose arg-max entry is
    replaced by ``1 - max`` when the maximum is below 1 and the last axis has
    length 1, and by 0 otherwise.

    Args:
        arg: Array indexed as ``arg[0]``.
            # NOTE(review): presumably a (1, n_outputs) prediction — confirm.
        max_wt: Accepted but not used by this function.

    Returns:
        Tuple ``(y_pos[0], y_neg[0]``) of first-row relevance vectors.
    """
    first_row = arg[0]
    top_idx = np.argmax(first_row)
    top_val = np.max(first_row)

    y_pos = np.zeros_like(arg)
    y_pos[0][top_idx] = top_val

    y_neg = np.array(arg)
    # The shape test is only evaluated when the maximum is below 1.
    y_neg[0][top_idx] = (1 - top_val) if top_val < 1 and arg.shape[-1] == 1 else 0
    return y_pos[0], y_neg[0]
51
+
52
+
53
def calculate_base_wt(p_sum=0, n_sum=0, bias=0, wt_pos=0, wt_neg=0):
    """Decide how ``wt_pos`` / ``wt_neg`` map onto positive and negative inputs.

    The sign of ``p_sum + bias - n_sum`` selects the comparison between
    ``wt_pos`` and ``wt_neg``. When the comparison holds, the weights are
    kept in order with sign ``1``; otherwise they are swapped with sign
    ``-1``. A zero difference gives zero aggregate weights with sign ``1``.

    Returns:
        ``(p_agg_wt, n_agg_wt, p_sum, n_sum, wt_sign)`` where a zero
        ``p_sum`` / ``n_sum`` is replaced by 1 so callers can divide by it.
    """
    t_diff = p_sum + bias - n_sum

    if t_diff > 0:
        keep_order = wt_pos > wt_neg
    elif t_diff < 0:
        keep_order = wt_pos < wt_neg
    else:
        keep_order = None

    if keep_order is None:
        p_agg_wt, n_agg_wt, wt_sign = 0, 0, 1
    elif keep_order:
        p_agg_wt, n_agg_wt, wt_sign = wt_pos, wt_neg, 1
    else:
        p_agg_wt, n_agg_wt, wt_sign = wt_neg, wt_pos, -1

    p_sum = 1 if p_sum == 0 else p_sum
    n_sum = 1 if n_sum == 0 else n_sum
    return p_agg_wt, n_agg_wt, p_sum, n_sum, wt_sign
81
+
82
+
83
class LSTM_forward(object):
    """Step-wise LSTM forward pass that records its intermediate tensors.

    For every time step, ``compute_log[step]`` stores the step input
    (``"inp"``), the bias-added gate pre-activations (``"x"``), the hidden and
    carry state before/after the step (``"hstate"``, ``"cstate"``) and the
    gate values (``"int_arrays"`` with keys ``"i"``, ``"f"``, ``"c"``, ``"o"``).
    """

    def __init__(
        self, num_cells, units, weights, return_sequence=False, go_backwards=False
    ):
        """Store the layer weights and pre-allocate one log entry per cell.

        Args:
            num_cells: Number of log entries to create (keys ``0..num_cells-1``).
            units: Width of each gate slice of the kernels.
            weights: Sequence read as ``[kernel, recurrent_kernel, bias]``.
            return_sequence: Stored on the instance; not read in this class.
            go_backwards: Stored on the instance; not read in this class.
        """
        self.num_cells = num_cells
        self.units = units
        self.kernel, self.recurrent_kernel, self.bias = (
            weights[0],
            weights[1],
            weights[2],
        )
        self.return_sequence = return_sequence
        self.go_backwards = go_backwards
        self.recurrent_activation = tf.math.sigmoid
        self.activation = tf.math.tanh

        self.compute_log = {
            step: {
                "inp": None,
                "x": None,
                "hstate": [None, None],
                "cstate": [None, None],
                "int_arrays": {},
            }
            for step in range(self.num_cells)
        }

    def compute_carry_and_output(self, x, h_tm1, c_tm1, cell_num):
        """Compute the new carry and the output gate, logging all four gates.

        Args:
            x: 4-tuple of gate pre-activations ``(x_i, x_f, x_c, x_o)``.
            h_tm1: 4-tuple of previous hidden states, one per gate.
            c_tm1: Previous carry state.
            cell_num: Key of the ``compute_log`` entry to fill.

        Returns:
            ``(c, o)``: the new carry state and the output gate.
        """
        x_i, x_f, x_c, x_o = x
        h_i, h_f, h_c, h_o = h_tm1
        width = self.units
        rec = self.recurrent_kernel

        i = self.recurrent_activation(x_i + K.dot(h_i, rec[:, :width]))
        f = self.recurrent_activation(x_f + K.dot(h_f, rec[:, width : width * 2]))
        c = f * c_tm1 + i * self.activation(
            x_c + K.dot(h_c, rec[:, width * 2 : width * 3])
        )
        o = self.recurrent_activation(x_o + K.dot(h_o, rec[:, width * 3 :]))

        gate_log = self.compute_log[cell_num]["int_arrays"]
        gate_log["i"] = i
        gate_log["f"] = f
        gate_log["c"] = c
        gate_log["o"] = o
        return c, o

    def calculate_lstm_cell_wt(self, inputs, states, cell_num, training=None):
        """Run one LSTM step and record its inputs, states and pre-activations.

        Args:
            inputs: Input tensor for this step.
            states: ``[h_tm1, c_tm1]`` — previous hidden and carry state.
            cell_num: Key of the ``compute_log`` entry to fill.
            training: Accepted but not used.

        Returns:
            ``(h, [h, c])``: the new hidden state and the new state list.
        """
        h_tm1, c_tm1 = states[0], states[1]
        entry = self.compute_log[cell_num]
        entry["inp"] = inputs
        entry["hstate"][0] = h_tm1
        entry["cstate"][0] = c_tm1

        # The kernel and bias are each split into four equal gate slices (i, f, c, o).
        k_i, k_f, k_c, k_o = tf.split(self.kernel, num_or_size_splits=4, axis=1)
        b_i, b_f, b_c, b_o = tf.split(self.bias, num_or_size_splits=4, axis=0)
        x = (
            tf.add(K.dot(inputs, k_i), b_i),
            tf.add(K.dot(inputs, k_f), b_f),
            tf.add(K.dot(inputs, k_c), b_c),
            tf.add(K.dot(inputs, k_o), b_o),
        )

        # Every gate sees the same previous hidden state.
        c, o = self.compute_carry_and_output(x, (h_tm1,) * 4, c_tm1, cell_num)
        h = o * self.activation(c)

        entry["x"] = x
        entry["hstate"][1] = h
        entry["cstate"][1] = c
        return h, [h, c]

    def calculate_lstm_wt(self, input_data):
        """Run the LSTM over ``input_data`` row by row, starting from zero states.

        Args:
            input_data: 2-D array; row ``ind`` is fed to cell ``ind``.

        Returns:
            List with the hidden state produced at each step.
        """
        hstate = tf.convert_to_tensor(np.zeros((1, self.units)), dtype=tf.float32)
        cstate = tf.convert_to_tensor(np.zeros((1, self.units)), dtype=tf.float32)
        output = []
        for ind in range(input_data.shape[0]):
            row = input_data[ind, :].reshape((1, input_data.shape[1]))
            inp = tf.convert_to_tensor(row, dtype=tf.float32)
            h, new_states = self.calculate_lstm_cell_wt(inp, [hstate, cstate], ind)
            hstate, cstate = new_states[0], new_states[1]
            output.append(h)
        return output
176
+
177
+
178
class LSTM_backtrace(object):
    """Propagate positive/negative relevance backwards through a logged LSTM.

    Relevance is always carried as a pair ``(positive, negative)``. The
    per-step tensors are read from a ``compute_log`` whose entries expose
    ``.numpy()``.
    """

    def __init__(
        self, num_cells, units, weights, return_sequence=False, go_backwards=False
    ):
        """Store the layer weights; ``weights`` is read as [kernel, recurrent, bias]."""
        self.num_cells = num_cells
        self.units = units
        self.kernel, self.recurrent_kernel, self.bias = (
            weights[0],
            weights[1],
            weights[2],
        )
        self.return_sequence = return_sequence
        self.go_backwards = go_backwards
        self.recurrent_activation = np_sigmoid
        self.activation = np_tanh

        self.compute_log = {}

    @staticmethod
    def _distribute(values, pos_out, neg_out, wt_pos, wt_neg, bias):
        """Split ``(wt_pos, wt_neg)`` over ``values`` in place, by magnitude share.

        ``pos_out`` and ``neg_out`` are written through boolean masks, so they
        must be writable views of the caller's result arrays.
        """
        p_ind = values > 0
        n_ind = values < 0
        p_sum = np.sum(values[p_ind])
        n_sum = np.sum(values[n_ind]) * -1
        has_pos = np.sum(p_ind) > 0
        has_neg = np.sum(n_ind) > 0

        if has_pos and not has_neg:
            pos_out[p_ind] = (values[p_ind] / p_sum) * wt_pos
            neg_out[p_ind] = (values[p_ind] / p_sum) * wt_neg
        elif has_neg and not has_pos:
            pos_out[n_ind] = (values[n_ind] / n_sum) * wt_pos * -1
            neg_out[n_ind] = (values[n_ind] / n_sum) * wt_neg * -1
        else:
            p_agg_wt, n_agg_wt, p_sum, n_sum, wt_sign = calculate_base_wt(
                p_sum=p_sum, n_sum=n_sum, bias=bias, wt_pos=wt_pos, wt_neg=wt_neg
            )
            from_pos = (values[p_ind] / p_sum) * p_agg_wt
            from_neg = (values[n_ind] / n_sum) * n_agg_wt * -1
            if wt_sign > 0:
                pos_out[p_ind] = from_pos
                neg_out[n_ind] = from_neg
            else:
                neg_out[p_ind] = from_pos
                pos_out[n_ind] = from_neg

    def calculate_wt_fc(self, wts, inp, w, b, act):
        """Back-propagate relevance through ``inp @ w`` (plus optional bias).

        Args:
            wts: Pair ``(positive, negative)`` relevance, indexed per output unit.
            inp: 1-D input vector.
            w: Weight matrix whose first axis matches ``inp``.
            b: Per-unit bias sequence; an empty sequence means no bias.
            act: Accepted but not used.

        Returns:
            ``(positive, negative)`` relevance summed over output units.
        """
        wts_pos, wts_neg = wts[0], wts[1]
        # One row per output unit, one column per input element.
        contrib = np.einsum("ij,i->ij", w, inp).T
        rel_pos = np.zeros(contrib.shape)
        rel_neg = np.zeros(contrib.shape)
        for unit in range(contrib.shape[0]):
            bias = b[unit] if len(b) > 0 else 0
            self._distribute(
                contrib[unit],
                rel_pos[unit],
                rel_neg[unit],
                wts_pos[unit],
                wts_neg[unit],
                bias,
            )
        return rel_pos.sum(axis=0), rel_neg.sum(axis=0)

    def calculate_wt_add(self, wts, inp=None):
        """Split relevance across the operands of an element-wise addition.

        Args:
            wts: Pair ``(positive, negative)`` relevance of the sum.
            inp: List of equally shaped 1-D operands.

        Returns:
            List with one ``(positive, negative)`` tuple per operand.
        """
        wts_pos, wts_neg = wts[0], wts[1]
        rel_pos = np.array([np.zeros_like(x) for x in inp])
        rel_neg = np.array([np.zeros_like(x) for x in inp])
        operands = np.array(inp)
        for col in range(rel_pos.shape[1]):
            # Column slices are views, so the helper writes straight into the results.
            self._distribute(
                operands[:, col],
                rel_pos[:, col],
                rel_neg[:, col],
                wts_pos[col],
                wts_neg[col],
                0.0,
            )
        pos_parts = [row.reshape(wts_pos.shape) for row in rel_pos]
        neg_parts = [row.reshape(wts_neg.shape) for row in rel_neg]
        return list(zip(pos_parts, neg_parts))

    def calculate_wt_multiply(self, wts, inp=None):
        """Split relevance across the two operands of an element-wise product.

        Each operand receives the share ``|other| / (|a| + |b|)`` of the
        relevance; positions where either operand is zero receive nothing.

        Returns:
            ``[[pos_1, neg_1], [pos_2, neg_2]]`` for the first and second operand.
        """
        wts_pos, wts_neg = wts[0], wts[1]
        mag1 = np.abs(inp[0])
        mag2 = np.abs(inp[1])
        mag_sum = mag1 + mag2
        mag_prod = mag1 * mag2
        for dead in (mag_sum == 0, mag_prod == 0):
            mag1[dead] = 0
            mag2[dead] = 0
        mag_sum[mag_sum == 0] = 1  # avoid dividing by zero below
        share1 = mag2 / mag_sum
        share2 = mag1 / mag_sum
        return [
            [np.nan_to_num(share1 * wts_pos), np.nan_to_num(share1 * wts_neg)],
            [np.nan_to_num(share2 * wts_pos), np.nan_to_num(share2 * wts_neg)],
        ]

    def compute_carry_and_output(self, wt_o, wt_c, h_tm1, c_tm1, x, cell_num):
        """Push output-gate and carry relevance back to the gate inputs.

        Args:
            wt_o: Relevance pair of the output gate.
            wt_c: Relevance pair of the new carry state.
            h_tm1: Previous hidden state (NumPy vector).
            c_tm1: Previous carry state (NumPy vector).
            x: Gate pre-activations ``(x_i, x_f, x_c, x_o)``.
            cell_num: Key of the ``compute_log`` entry to read gates from.

        Returns:
            ``(wt_x_i, wt_x_f, wt_x_c, wt_x_o, wt_h_tm1_i, wt_h_tm1_f,
            wt_h_tm1_c, wt_h_tm1_o, wt_c_tm1)``.
        """
        x_i, x_f, x_c, x_o = x
        gate_log = self.compute_log[cell_num]["int_arrays"]
        f = gate_log["f"].numpy()[0]
        i = gate_log["i"].numpy()[0]

        width = self.units
        rec_i = self.recurrent_kernel[:, :width]
        rec_f = self.recurrent_kernel[:, width : width * 2]
        rec_c = self.recurrent_kernel[:, width * 2 : width * 3]
        rec_o = self.recurrent_kernel[:, width * 3 :]

        # o = recurrent_activation(x_o + dot(h_tm1, rec_o))
        hid_o = np.dot(h_tm1, rec_o).astype(np.float32)
        wt_x_o, wt_hid_o = self.calculate_wt_add(wt_o, [x_o, hid_o])
        wt_h_tm1_o = self.calculate_wt_fc(wt_hid_o, h_tm1, rec_o, [], {"type": None})

        # c = f * c_tm1 + i * activation(x_c + dot(h_tm1, rec_c))
        carried = f * c_tm1
        hid_c = np.dot(h_tm1, rec_c)
        candidate = self.activation(x_c + hid_c)
        gated = i * candidate
        wt_carried, wt_gated = self.calculate_wt_add(wt_c, [carried, gated])
        wt_f, wt_c_tm1 = self.calculate_wt_multiply(wt_carried, [f, c_tm1])
        wt_i, wt_candidate = self.calculate_wt_multiply(wt_gated, [i, candidate])
        wt_x_c, wt_hid_c = self.calculate_wt_add(wt_candidate, [x_c, hid_c])
        wt_h_tm1_c = self.calculate_wt_fc(wt_hid_c, h_tm1, rec_c, [], {"type": None})

        # f = recurrent_activation(x_f + dot(h_tm1, rec_f))
        hid_f = np.dot(h_tm1, rec_f)
        wt_x_f, wt_hid_f = self.calculate_wt_add(wt_f, [x_f, hid_f])
        wt_h_tm1_f = self.calculate_wt_fc(wt_hid_f, h_tm1, rec_f, [], {"type": None})

        # i = recurrent_activation(x_i + dot(h_tm1, rec_i))
        hid_i = np.dot(h_tm1, rec_i)
        wt_x_i, wt_hid_i = self.calculate_wt_add(wt_i, [x_i, hid_i])
        wt_h_tm1_i = self.calculate_wt_fc(wt_hid_i, h_tm1, rec_i, [], {"type": None})

        return (
            wt_x_i,
            wt_x_f,
            wt_x_c,
            wt_x_o,
            wt_h_tm1_i,
            wt_h_tm1_f,
            wt_h_tm1_c,
            wt_h_tm1_o,
            wt_c_tm1,
        )

    def calculate_lstm_cell_wt(self, cell_num, wts_hstate, wts_cstate):
        """Back-propagate relevance through one logged LSTM step.

        Args:
            cell_num: Key of the ``compute_log`` entry for this step.
            wts_hstate: Relevance pair arriving at this step's hidden state.
            wts_cstate: Relevance pair arriving at this step's carry state.

        Returns:
            ``(wt_inputs, wt_h_tm1, wt_c_tm1)``: relevance pairs for the step
            input, the previous hidden state and the previous carry state.
        """
        entry = self.compute_log[cell_num]
        o = entry["int_arrays"]["o"].numpy()[0]
        c = entry["cstate"][1].numpy()[0]
        h_tm1 = entry["hstate"][0].numpy()[0]
        c_tm1 = entry["cstate"][0].numpy()[0]
        x = [gate.numpy()[0] for gate in entry["x"]]

        # h = o * activation(c)
        wt_o, wt_c = self.calculate_wt_multiply(wts_hstate, [o, self.activation(c)])
        wt_c[0] = wt_c[0] + wts_cstate[0]
        wt_c[1] = wt_c[1] + wts_cstate[1]

        (
            wt_x_i,
            wt_x_f,
            wt_x_c,
            wt_x_o,
            wt_h_tm1_i,
            wt_h_tm1_f,
            wt_h_tm1_c,
            wt_h_tm1_o,
            wt_c_tm1,
        ) = self.compute_carry_and_output(wt_o, wt_c, h_tm1, c_tm1, x, cell_num)

        wt_h_tm1 = [
            wt_h_tm1_i[sign] + wt_h_tm1_f[sign] + wt_h_tm1_c[sign] + wt_h_tm1_o[sign]
            for sign in (0, 1)
        ]

        inputs = entry["inp"].numpy()[0]
        k_i, k_f, k_c, k_o = np.split(self.kernel, indices_or_sections=4, axis=1)
        b_i, b_f, b_c, b_o = np.split(self.bias, indices_or_sections=4, axis=0)

        wt_inputs_i = self.calculate_wt_fc(wt_x_i, inputs, k_i, b_i, {"type": None})
        wt_inputs_f = self.calculate_wt_fc(wt_x_f, inputs, k_f, b_f, {"type": None})
        wt_inputs_c = self.calculate_wt_fc(wt_x_c, inputs, k_c, b_c, {"type": None})
        wt_inputs_o = self.calculate_wt_fc(wt_x_o, inputs, k_o, b_o, {"type": None})

        wt_inputs = [
            wt_inputs_i[sign] + wt_inputs_f[sign] + wt_inputs_c[sign] + wt_inputs_o[sign]
            for sign in (0, 1)
        ]

        return wt_inputs, wt_h_tm1, wt_c_tm1

    def calculate_lstm_wt(self, wts_pos, wts_neg, compute_log):
        """Back-propagate relevance through every logged step, last step first.

        Args:
            wts_pos: Positive relevance of the layer output.
            wts_neg: Negative relevance of the layer output.
            compute_log: Log whose keys are assumed to run ``0..len-1``.

        Returns:
            ``(output_pos, output_neg)`` arrays in forward (step 0 first) order.
        """
        self.compute_log = compute_log
        if self.return_sequence:
            wts_hstate = [wts_pos[-1, :], wts_neg[-1, :]]
        else:
            wts_hstate = [wts_pos, wts_neg]
        wts_cstate = [
            np.zeros_like(self.compute_log[0]["cstate"][1].numpy()[0]),
            np.zeros_like(self.compute_log[0]["cstate"][1].numpy()[0]),
        ]

        pos_steps = []
        neg_steps = []
        for ind in reversed(range(len(self.compute_log))):
            wt_inp, wts_hstate, wts_cstate = self.calculate_lstm_cell_wt(
                ind, wts_hstate, wts_cstate
            )
            pos_steps.append(wt_inp[0])
            neg_steps.append(wt_inp[1])
            if self.return_sequence and ind > 0:
                # With sequence output, each earlier step also has its own relevance.
                wts_hstate[0] = wts_hstate[0] + wts_pos[ind - 1, :]
                wts_hstate[1] = wts_hstate[1] + wts_neg[ind - 1, :]

        return np.array(pos_steps[::-1]), np.array(neg_steps[::-1])
456
+
457
+
458
def dummy_wt(wts, inp, *args):
    """Return an all-zero array shaped like ``inp``; other arguments are ignored."""
    return np.zeros_like(inp)
461
+
462
+
463
def calculate_wt_fc(wts_pos, wts_neg, inp, w, b, act={}):
    """Back-propagate positive/negative relevance through a dense layer.

    Args:
        wts_pos: Positive relevance, indexed per output unit.
        wts_neg: Negative relevance, indexed per output unit.
        inp: 1-D input vector.
        w: Weight tensor exposing ``.numpy()``; its transpose's first axis
            must match ``inp``.
            # NOTE(review): presumably (out_features, in_features) — confirm.
        b: Bias tensor exposing ``.numpy()``; only read when a unit has both
            positive and negative contributions (or none at all).
        act: Accepted but not used.

    Returns:
        ``(positive, negative)`` relevance summed over the output units.
    """
    # One row per output unit, one column per input element.
    contrib = np.einsum("ij,i->ij", w.numpy().T, inp).T
    rel_pos = np.zeros(contrib.shape)
    rel_neg = np.zeros(contrib.shape)
    for unit in range(contrib.shape[0]):
        row = contrib[unit]
        pos_row = rel_pos[unit]
        neg_row = rel_neg[unit]
        wt_pos = wts_pos[unit]
        wt_neg = wts_neg[unit]
        p_ind = row > 0
        n_ind = row < 0
        p_sum = np.sum(row[p_ind])
        n_sum = np.sum(row[n_ind]) * -1
        has_pos = np.sum(p_ind) > 0
        has_neg = np.sum(n_ind) > 0
        if has_pos and not has_neg:
            pos_row[p_ind] = (row[p_ind] / p_sum) * wt_pos
            neg_row[p_ind] = (row[p_ind] / p_sum) * wt_neg
        elif has_neg and not has_pos:
            pos_row[n_ind] = (row[n_ind] / n_sum) * wt_pos * -1
            neg_row[n_ind] = (row[n_ind] / n_sum) * wt_neg * -1
        else:
            p_agg_wt, n_agg_wt, p_sum, n_sum, wt_sign = calculate_base_wt(
                p_sum=p_sum,
                n_sum=n_sum,
                bias=b.numpy()[unit],
                wt_pos=wt_pos,
                wt_neg=wt_neg,
            )
            from_pos = (row[p_ind] / p_sum) * p_agg_wt
            from_neg = (row[n_ind] / n_sum) * n_agg_wt * -1
            if wt_sign > 0:
                pos_row[p_ind] = from_pos
                neg_row[n_ind] = from_neg
            else:
                neg_row[p_ind] = from_pos
                pos_row[n_ind] = from_neg
    return rel_pos.sum(axis=0), rel_neg.sum(axis=0)
502
+
503
+
504
def calculate_wt_passthru(wts):
    """Return ``wts`` unchanged (the very same object)."""
    return wts
506
+
507
+
508
def calculate_wt_rshp(wts, inp=None):
    """Reshape ``wts`` to the shape of ``inp``.

    ``inp`` defaults to ``None`` but must provide a ``.shape`` attribute.
    """
    return np.reshape(wts, inp.shape)
511
+
512
+
513
def calculate_wt_concat(wts, inp=None, axis=-1):
    """Split ``wts`` back into one piece per concatenated input.

    The split points are the cumulative sizes of the inputs along ``axis``.
    Because the final cumulative size is also used as a split point, the
    returned list ends with one extra, empty piece. A positive ``axis`` is
    reduced by one before splitting ``wts``.
    # NOTE(review): presumably because ``wts`` lacks the inputs' leading axis — confirm.

    Returns:
        List of arrays produced by ``np.split``.
    """
    boundaries = np.cumsum([part.shape[axis] for part in inp])
    split_axis = axis - 1 if axis > 0 else axis
    return np.split(wts, indices_or_sections=boundaries, axis=split_axis)
520
+
521
+
522
def calculate_wt_add(wts_pos, wts_neg, inp=None):
    """Split relevance across the operands of an element-wise addition.

    Every operand and both relevance arrays are flattened in C order, and the
    relevance of each element is shared between the operands in proportion to
    their values at that element.

    Flattening uses ``np.ravel`` rather than a hand-built strided view: a view
    built from only the last stride is valid for C-contiguous arrays alone and
    reads the wrong memory for transposed or sliced inputs.

    Args:
        wts_pos: Positive relevance of the sum.
        wts_neg: Negative relevance of the sum.
        inp: Sequence of operands, each with as many elements as ``wts_pos``.

    Returns:
        List with one ``(positive, negative)`` tuple per operand, each
        reshaped to the shape of ``wts_pos`` / ``wts_neg``.
    """
    flat_wts_pos = np.ravel(wts_pos)
    flat_wts_neg = np.ravel(wts_neg)
    inp_list = np.array([np.ravel(x) for x in inp])
    wt_mat_pos = np.zeros_like(inp_list)
    wt_mat_neg = np.zeros_like(inp_list)

    for i in range(wt_mat_pos.shape[1]):
        # Column slices are views, so masked writes land in the result arrays.
        wt_ind1_pos = wt_mat_pos[:, i]
        wt_ind1_neg = wt_mat_neg[:, i]
        wt_pos = flat_wts_pos[i]
        wt_neg = flat_wts_neg[i]
        l1_ind1 = inp_list[:, i]
        p_ind = l1_ind1 > 0
        n_ind = l1_ind1 < 0
        p_sum = np.sum(l1_ind1[p_ind])
        n_sum = np.sum(l1_ind1[n_ind]) * -1
        has_pos = p_ind.any()
        has_neg = n_ind.any()
        if has_pos and not has_neg:
            wt_ind1_pos[p_ind] = (l1_ind1[p_ind] / p_sum) * wt_pos
            wt_ind1_neg[p_ind] = (l1_ind1[p_ind] / p_sum) * wt_neg
        elif has_neg and not has_pos:
            wt_ind1_pos[n_ind] = (l1_ind1[n_ind] / n_sum) * wt_pos * -1
            wt_ind1_neg[n_ind] = (l1_ind1[n_ind] / n_sum) * wt_neg * -1
        else:
            p_agg_wt, n_agg_wt, p_sum, n_sum, wt_sign = calculate_base_wt(
                p_sum=p_sum, n_sum=n_sum, bias=0.0, wt_pos=wt_pos, wt_neg=wt_neg
            )
            if wt_sign > 0:
                wt_ind1_pos[p_ind] = (l1_ind1[p_ind] / p_sum) * p_agg_wt
                wt_ind1_neg[n_ind] = (l1_ind1[n_ind] / n_sum) * n_agg_wt * -1
            else:
                wt_ind1_neg[p_ind] = (l1_ind1[p_ind] / p_sum) * p_agg_wt
                wt_ind1_pos[n_ind] = (l1_ind1[n_ind] / n_sum) * n_agg_wt * -1

    pos_shape = np.shape(wts_pos)
    neg_shape = np.shape(wts_neg)
    return [
        (pos_row.reshape(pos_shape), neg_row.reshape(neg_shape))
        for pos_row, neg_row in zip(wt_mat_pos, wt_mat_neg)
    ]
588
+
589
+
590
def calculate_wt_passthru(wts):
    """Return ``wts`` unchanged.

    This repeats an identical definition of the same name earlier in the
    file; being later, this is the one bound at import time.
    """
    return wts
592
+
593
+
594
def calculate_wt_conv_unit(
    wt_pos, wt_neg, p_mat, n_mat, p_sum, n_sum, pbias, nbias, act={}
):
    """Distribute one output position's relevance over its receptive field.

    Args:
        wt_pos: Positive relevance of the output position.
        wt_neg: Negative relevance of the output position.
        p_mat: Positive contributions (zeros elsewhere).
        n_mat: Negative contributions (zeros elsewhere).
        p_sum: Sum of ``p_mat``.
        n_sum: Magnitude of the sum of ``n_mat``.
        pbias: Positive part of the bias.
        nbias: Magnitude of the negative part of the bias.
        act: Accepted but not used.

    Returns:
        ``(positive, negative)`` relevance arrays shaped like ``p_mat``.
    """
    zeros = np.zeros_like(p_mat)
    if n_sum == 0 and p_sum > 0:
        pos_part = (p_mat / p_sum) * wt_pos
        neg_part = (p_mat / p_sum) * wt_neg
    elif n_sum > 0 and p_sum == 0:
        pos_part = (n_mat / n_sum) * wt_pos * -1
        neg_part = (n_mat / n_sum) * wt_neg * -1
    else:
        p_agg_wt, n_agg_wt, p_sum, n_sum, wt_sign = calculate_base_wt(
            p_sum=p_sum, n_sum=n_sum, bias=pbias - nbias, wt_pos=wt_pos, wt_neg=wt_neg
        )
        from_pos = (p_mat / p_sum) * p_agg_wt
        from_neg = (n_mat / n_sum) * n_agg_wt * -1
        if wt_sign > 0:
            pos_part, neg_part = from_pos, from_neg
        else:
            pos_part, neg_part = from_neg, from_pos
    return zeros + pos_part, zeros + neg_part
616
+
617
+
618
def dummy_wt_conv(wt, p_mat, n_mat, t_sum, p_sum, n_sum, act):
    """Return a uniform array shaped like ``p_mat`` whose entries sum to 1.

    Only ``p_mat`` is used; every other argument is ignored.
    """
    uniform = np.ones_like(p_mat)
    return uniform / np.sum(uniform)
621
+
622
+
623
def calculate_wt_conv(wts_pos, wts_neg, inp, w, b, act):
    """Back-propagate positive/negative relevance through a 2-D convolution.

    All array arguments are transposed on entry. Sliding windows over the
    transposed input are built as a read-only strided view with no padding
    and unit stride, and each output position's relevance is distributed
    over its window by ``calculate_wt_conv_unit``.

    Args:
        wts_pos: Positive relevance of the layer output.
        wts_neg: Negative relevance of the layer output.
        inp: Layer input as a NumPy array.
        w: Kernel tensor exposing ``.T`` and ``.numpy()``.
        b: Bias tensor exposing ``.numpy()``, indexed per output channel.
        act: Forwarded unchanged to ``calculate_wt_conv_unit``.

    Returns:
        ``(positive, negative)`` relevance with the transposed input's shape.
        # NOTE(review): the result is not transposed back to the caller's
        # original layout — confirm that callers expect this.
    """
    wts_pos = wts_pos.T
    wts_neg = wts_neg.T
    inp = inp.T
    w = w.T
    kernels = w.numpy()
    k_rows, k_cols = kernels.shape[0], kernels.shape[1]

    windows = as_strided(
        inp,
        shape=(
            inp.shape[0] - k_rows + 1,  # the feature map is smaller than the input
            inp.shape[1] - k_cols + 1,
            inp.shape[2],
            k_rows,
            k_cols,
        ),
        strides=(
            inp.strides[0],
            inp.strides[1],
            inp.strides[2],
            inp.strides[0],  # a step inside the window is a step in the input
            inp.strides[1],
        ),
        writeable=False,  # never write through the overlapping view
    )

    test_wt_pos = np.einsum("mnc->cmn", np.zeros_like(inp), order="C", optimize=True)
    test_wt_neg = np.einsum("mnc->cmn", np.zeros_like(inp), order="C", optimize=True)

    for k in range(kernels.shape[-1]):
        kernel = kernels[:, :, :, k]
        bias_k = b.numpy()[k]
        # Split the bias into its positive and negative magnitude.
        if bias_k > 0:
            pbias, nbias = bias_k, 0
        else:
            pbias, nbias = 0, bias_k * -1

        x = np.einsum("abcmn,mnc->abcmn", windows, kernel, order="C", optimize=True)
        x_pos = x.copy()
        x_neg = x.copy()
        x_pos[x < 0] = 0
        x_neg[x > 0] = 0
        x_p_sum = np.einsum("abcmn->ab", x_pos, order="C", optimize=True)
        x_n_sum = np.einsum("abcmn->ab", x_neg, order="C", optimize=True) * -1.0

        for ind1, ind2 in np.ndindex(windows.shape[0], windows.shape[1]):
            rows = slice(ind1, ind1 + kernel.shape[0])
            cols = slice(ind2, ind2 + kernel.shape[1])
            temp_wt_mat_pos, temp_wt_mat_neg = calculate_wt_conv_unit(
                wts_pos[ind1, ind2, k],
                wts_neg[ind1, ind2, k],
                x_pos[ind1, ind2, :, :, :],
                x_neg[ind1, ind2, :, :, :],
                x_p_sum[ind1, ind2],
                x_n_sum[ind1, ind2],
                pbias,
                nbias,
                act,
            )
            test_wt_pos[:, rows, cols] += temp_wt_mat_pos
            test_wt_neg[:, rows, cols] += temp_wt_mat_neg

    test_wt_pos = np.einsum("cmn->mnc", test_wt_pos, order="C", optimize=True)
    test_wt_neg = np.einsum("cmn->mnc", test_wt_neg, order="C", optimize=True)
    gc.collect()
    return test_wt_pos, test_wt_neg
695
+
696
+
697
def get_max_index(mat=None):
    """Return the multi-dimensional index of the first maximum of ``mat``.

    The flat ``argmax`` is converted with ``np.unravel_index``. Dividing the
    flat index by the leading dimension sizes instead is only correct for
    square 2-D arrays and gives wrong, possibly out-of-range, indices for any
    other shape.

    Args:
        mat: Array to search. ``None`` is the default but is not a usable input.

    Returns:
        Tuple of plain ``int`` indices, one per axis of ``mat``.
    """
    flat_index = np.argmax(mat)
    return tuple(int(i) for i in np.unravel_index(flat_index, np.shape(mat)))
706
+
707
+
708
def calculate_wt_maxpool(wts, inp, pool_size):
    """Route each pooled relevance value to the arg-max of its pooling window.

    Both ``wts`` and ``inp`` are transposed on entry. The input is
    zero-padded by one pool size at the end of the first two axes so every
    window is full, and the padding is cropped from the result.

    Args:
        wts: Relevance of the pooled output (3-D after transposing).
        inp: Layer input (3-D after transposing).
        pool_size: Pair ``(size_0, size_1)``, also used as the window stride.

    Returns:
        Relevance array shaped like the transposed input.
    """
    wts = wts.T
    inp = inp.T
    size_0, size_1 = pool_size[0], pool_size[1]
    padded = np.pad(inp, ((0, size_0), (0, size_1), (0, 0)), "constant")
    dim1, dim2, _ = wts.shape
    test_wt = np.zeros_like(padded)
    for k in range(inp.shape[2]):
        wt_mat = wts[:, :, k]
        for ind1, ind2 in np.ndindex(dim1, dim2):
            rows = slice(ind1 * size_0, (ind1 + 1) * size_0)
            cols = slice(ind2 * size_1, (ind2 + 1) * size_1)
            max_index = get_max_index(padded[rows, cols, k])
            # The sliced window is a view, so this writes into test_wt.
            test_wt[rows, cols, k][max_index] = wt_mat[ind1, ind2]
    return test_wt[0 : inp.shape[0], 0 : inp.shape[1], :]
733
+
734
+
735
def calculate_wt_avgpool(wts_pos, wts_neg, inp, pool_size):
    """Distribute pooled relevance back over each average-pooling window.

    The input is zero-padded by one pool size at the end of the first two
    axes so every window is full; the padding is cropped from the result.
    Inside each window the relevance is shared in proportion to the values.

    The negative relevance is read from ``wts_neg``. Reading it from
    ``wts_pos`` would discard the negative relevance and propagate the
    positive relevance twice.

    Args:
        wts_pos: Positive relevance of the pooled output, indexed ``[i, j, k]``.
        wts_neg: Negative relevance of the pooled output, indexed ``[i, j, k]``.
        inp: Layer input, indexed ``[row, col, k]``.
        pool_size: Pair ``(size_0, size_1)``, also used as the window stride.

    Returns:
        ``(positive, negative)`` relevance arrays shaped like ``inp``.
    """
    size_0, size_1 = pool_size[0], pool_size[1]
    test_samp_pad = np.pad(inp, ((0, size_0), (0, size_1), (0, 0)), "constant")
    dim1, dim2, _ = wts_pos.shape
    test_wt_pos = np.zeros_like(test_samp_pad)
    test_wt_neg = np.zeros_like(test_samp_pad)
    for k in range(inp.shape[2]):
        wt_mat_pos = wts_pos[:, :, k]
        wt_mat_neg = wts_neg[:, :, k]
        for ind1 in range(dim1):
            rows = slice(ind1 * size_0, (ind1 + 1) * size_0)
            for ind2 in range(dim2):
                cols = slice(ind2 * size_1, (ind2 + 1) * size_1)
                temp_inp = test_samp_pad[rows, cols, k]
                # Views into the result arrays; masked updates write through.
                wt_ind1_pos = test_wt_pos[rows, cols, k]
                wt_ind1_neg = test_wt_neg[rows, cols, k]
                wt_pos = wt_mat_pos[ind1, ind2]
                wt_neg = wt_mat_neg[ind1, ind2]
                p_ind = temp_inp > 0
                n_ind = temp_inp < 0
                p_sum = np.sum(temp_inp[p_ind])
                n_sum = np.sum(temp_inp[n_ind]) * -1
                has_pos = p_ind.any()
                has_neg = n_ind.any()
                if has_pos and not has_neg:
                    wt_ind1_pos[p_ind] += (temp_inp[p_ind] / p_sum) * wt_pos
                    wt_ind1_neg[p_ind] += (temp_inp[p_ind] / p_sum) * wt_neg
                elif has_neg and not has_pos:
                    wt_ind1_pos[n_ind] += (temp_inp[n_ind] / n_sum) * wt_pos * -1
                    wt_ind1_neg[n_ind] += (temp_inp[n_ind] / n_sum) * wt_neg * -1
                else:
                    p_agg_wt, n_agg_wt, p_sum, n_sum, wt_sign = calculate_base_wt(
                        p_sum=p_sum, n_sum=n_sum, bias=0.0, wt_pos=wt_pos, wt_neg=wt_neg
                    )
                    if wt_sign > 0:
                        wt_ind1_pos[p_ind] += (temp_inp[p_ind] / p_sum) * p_agg_wt
                        wt_ind1_neg[n_ind] += (temp_inp[n_ind] / n_sum) * n_agg_wt * -1
                    else:
                        wt_ind1_neg[p_ind] += (temp_inp[p_ind] / p_sum) * p_agg_wt
                        wt_ind1_pos[n_ind] += (temp_inp[n_ind] / n_sum) * n_agg_wt * -1
    test_wt_pos = test_wt_pos[0 : inp.shape[0], 0 : inp.shape[1], :]
    test_wt_neg = test_wt_neg[0 : inp.shape[0], 0 : inp.shape[1], :]
    return test_wt_pos, test_wt_neg
787
+
788
+
789
def calculate_wt_gavgpool(wts_pos, wts_neg, inp):
    """Distribute per-channel relevance over a globally average-pooled input.

    For each channel ``c`` (last axis of ``inp``), the relevance
    ``wts_pos[c]`` / ``wts_neg[c]`` is shared over ``inp[..., c]`` in
    proportion to the values.

    Returns:
        ``(positive, negative)`` relevance arrays shaped like ``inp``.
    """
    wt_mat_pos = np.zeros_like(inp)
    wt_mat_neg = np.zeros_like(inp)
    for c in range(wts_pos.shape[0]):
        wt_pos = wts_pos[c]
        wt_neg = wts_neg[c]
        x = inp[..., c]
        p_mat = np.copy(x)
        n_mat = np.copy(x)
        p_mat[x < 0] = 0
        n_mat[x > 0] = 0
        p_sum = np.sum(p_mat)
        n_sum = np.sum(n_mat) * -1
        if n_sum == 0 and p_sum > 0:
            pos_part = (p_mat / p_sum) * wt_pos
            neg_part = (p_mat / p_sum) * wt_neg
        elif n_sum > 0 and p_sum == 0:
            pos_part = (n_mat / n_sum) * wt_pos * -1
            neg_part = (n_mat / n_sum) * wt_neg * -1
        else:
            p_agg_wt, n_agg_wt, p_sum, n_sum, wt_sign = calculate_base_wt(
                p_sum=p_sum, n_sum=n_sum, bias=0, wt_pos=wt_pos, wt_neg=wt_neg
            )
            from_pos = (p_mat / p_sum) * p_agg_wt
            from_neg = (n_mat / n_sum) * n_agg_wt * -1
            if wt_sign > 0:
                pos_part, neg_part = from_pos, from_neg
            else:
                pos_part, neg_part = from_neg, from_pos
        wt_mat_pos[..., c] = wt_mat_pos[..., c] + pos_part
        wt_mat_neg[..., c] = wt_mat_neg[..., c] + neg_part
    return wt_mat_pos, wt_mat_neg
824
+
825
+
826
def weight_scaler(arg, scaler=100.0):
    """Rescale ``arg`` so that its entries sum to ``scaler``.

    When the entries sum to zero there is no finite scale factor, so ``arg``
    is returned unchanged instead of dividing by zero and filling the result
    with NaN/inf.

    Args:
        arg: Array of weights.
        scaler: Target sum of the rescaled weights.

    Returns:
        The rescaled array, or ``arg`` itself when its sum is zero.
    """
    total = np.sum(arg)
    if total == 0:
        return arg
    return arg / (total / scaler)
830
+
831
+
832
def weight_normalize(arg, max_val=1.0):
    """Scale ``arg`` so its largest magnitude becomes ``max_val``.

    The divisor is the maximum when it exceeds the magnitude of the minimum,
    otherwise the magnitude of the minimum when that is positive. If neither
    holds, ``arg`` is returned unchanged.
    """
    upper = np.max(arg)
    lower = np.abs(np.min(arg))
    if upper > lower:
        divisor = upper
    elif lower > 0:
        divisor = lower
    else:
        return arg
    return (arg / divisor) * max_val