dl-backtrace 0.0.16.dev4__py3-none-any.whl → 0.0.17__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of dl-backtrace might be problematic.
Files changed (25)
  1. dl_backtrace/old_backtrace/__init__.py +1 -0
  2. dl_backtrace/old_backtrace/pytorch_backtrace/__init__.py +1 -0
  3. dl_backtrace/old_backtrace/pytorch_backtrace/backtrace/__init__.py +4 -0
  4. dl_backtrace/old_backtrace/pytorch_backtrace/backtrace/backtrace.py +639 -0
  5. dl_backtrace/old_backtrace/pytorch_backtrace/backtrace/config.py +41 -0
  6. dl_backtrace/old_backtrace/pytorch_backtrace/backtrace/utils/__init__.py +2 -0
  7. dl_backtrace/old_backtrace/pytorch_backtrace/backtrace/utils/contrast.py +840 -0
  8. dl_backtrace/old_backtrace/pytorch_backtrace/backtrace/utils/prop.py +746 -0
  9. dl_backtrace/old_backtrace/tf_backtrace/__init__.py +1 -0
  10. dl_backtrace/old_backtrace/tf_backtrace/backtrace/__init__.py +4 -0
  11. dl_backtrace/old_backtrace/tf_backtrace/backtrace/backtrace.py +527 -0
  12. dl_backtrace/old_backtrace/tf_backtrace/backtrace/config.py +41 -0
  13. dl_backtrace/old_backtrace/tf_backtrace/backtrace/utils/__init__.py +2 -0
  14. dl_backtrace/old_backtrace/tf_backtrace/backtrace/utils/contrast.py +834 -0
  15. dl_backtrace/old_backtrace/tf_backtrace/backtrace/utils/prop.py +725 -0
  16. dl_backtrace/tf_backtrace/backtrace/backtrace.py +5 -3
  17. dl_backtrace/tf_backtrace/backtrace/utils/utils_prop.py +53 -0
  18. dl_backtrace/version.py +2 -2
  19. dl_backtrace-0.0.17.dist-info/METADATA +164 -0
  20. dl_backtrace-0.0.17.dist-info/RECORD +44 -0
  21. dl_backtrace-0.0.16.dev4.dist-info/METADATA +0 -102
  22. dl_backtrace-0.0.16.dev4.dist-info/RECORD +0 -29
  23. {dl_backtrace-0.0.16.dev4.dist-info → dl_backtrace-0.0.17.dist-info}/LICENSE +0 -0
  24. {dl_backtrace-0.0.16.dev4.dist-info → dl_backtrace-0.0.17.dist-info}/WHEEL +0 -0
  25. {dl_backtrace-0.0.16.dev4.dist-info → dl_backtrace-0.0.17.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,746 @@
+ import gc
+
+ import numpy as np
+ import tensorflow as tf
+ from numpy.lib.stride_tricks import as_strided
+ from tensorflow.keras import backend as K
+
+
+ def np_swish(x, beta=0.75):
+     z = 1 / (1 + np.exp(-(beta * x)))
+     return x * z
+
+
+ def np_wave(x, alpha=1.0):
+     return (alpha * x * np.exp(1.0)) / (np.exp(-x) + np.exp(x))
+
+
+ def np_pulse(x, alpha=1.0):
+     return alpha * (1 - np.tanh(x) * np.tanh(x))
+
+
+ def np_absolute(x, alpha=1.0):
+     return alpha * x * np.tanh(x)
+
+
+ def np_hard_sigmoid(x):
+     return np.clip(0.2 * x + 0.5, 0, 1)
+
+
+ def np_sigmoid(x):
+     z = 1 / (1 + np.exp(-x))
+     return z
+
+
+ def np_tanh(x):
+     z = np.tanh(x)
+     return z.astype(np.float32)
+
+
+ class LSTM_forward(object):
+     def __init__(
+         self, num_cells, units, weights, return_sequence=False, go_backwards=False
+     ):
+         self.num_cells = num_cells
+         self.units = units
+         self.kernel = weights[0]
+         self.recurrent_kernel = weights[1]
+         self.bias = weights[2][1]
+         self.return_sequence = return_sequence
+         self.go_backwards = go_backwards
+         self.recurrent_activation = tf.math.sigmoid
+         self.activation = tf.math.tanh
+         self.compute_log = {}
+         for i in range(self.num_cells):
+             self.compute_log[i] = {}
+             self.compute_log[i]["inp"] = None
+             self.compute_log[i]["x"] = None
+             self.compute_log[i]["hstate"] = [None, None]
+             self.compute_log[i]["cstate"] = [None, None]
+             self.compute_log[i]["int_arrays"] = {}
+
+     def compute_carry_and_output(self, x, h_tm1, c_tm1, cell_num):
+         """Computes carry and output using split kernels."""
+         x_i, x_f, x_c, x_o = x
+         h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o = h_tm1
+         w = tf.convert_to_tensor(self.recurrent_kernel[1], dtype=tf.float32)
+         i = self.recurrent_activation(
+             x_i + K.dot(h_tm1_i, w[:, : self.units])
+         )
+         f = self.recurrent_activation(
+             x_f + K.dot(h_tm1_f, w[:, self.units : self.units * 2])
+         )
+         c = f * c_tm1 + i * self.activation(
+             x_c + K.dot(h_tm1_c, w[:, self.units * 2 : self.units * 3])
+         )
+         o = self.recurrent_activation(
+             x_o + K.dot(h_tm1_o, w[:, self.units * 3 :])
+         )
+         self.compute_log[cell_num]["int_arrays"]["i"] = i
+         self.compute_log[cell_num]["int_arrays"]["f"] = f
+         self.compute_log[cell_num]["int_arrays"]["c"] = c
+         self.compute_log[cell_num]["int_arrays"]["o"] = o
+         return c, o
+
+     def calculate_lstm_cell_wt(self, inputs, states, cell_num, training=None):
+         h_tm1 = states[0]  # previous memory state
+         c_tm1 = states[1]  # previous carry state
+         self.compute_log[cell_num]["inp"] = inputs
+         self.compute_log[cell_num]["hstate"][0] = h_tm1
+         self.compute_log[cell_num]["cstate"][0] = c_tm1
+         inputs_i = inputs
+         inputs_f = inputs
+         inputs_c = inputs
+         inputs_o = inputs
+         k_i, k_f, k_c, k_o = tf.split(self.kernel[1], num_or_size_splits=4, axis=1)
+         x_i = K.dot(inputs_i, k_i)
+         x_f = K.dot(inputs_f, k_f)
+         x_c = K.dot(inputs_c, k_c)
+         x_o = K.dot(inputs_o, k_o)
+         b_i, b_f, b_c, b_o = tf.split(self.bias, num_or_size_splits=4, axis=0)
+         x_i = tf.add(x_i, b_i)
+         x_f = tf.add(x_f, b_f)
+         x_c = tf.add(x_c, b_c)
+         x_o = tf.add(x_o, b_o)
+
+         h_tm1_i = h_tm1
+         h_tm1_f = h_tm1
+         h_tm1_c = h_tm1
+         h_tm1_o = h_tm1
+         x = (x_i, x_f, x_c, x_o)
+         h_tm1 = (h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o)
+
+         c, o = self.compute_carry_and_output(x, h_tm1, c_tm1, cell_num)
+         h = o * self.activation(c)
+         self.compute_log[cell_num]["x"] = x
+         self.compute_log[cell_num]["hstate"][1] = h
+         self.compute_log[cell_num]["cstate"][1] = c
+         return h, [h, c]
+
+     def calculate_lstm_wt(self, input_data):
+         hstate = tf.convert_to_tensor(np.zeros((1, self.units)), dtype=tf.float32)
+         cstate = tf.convert_to_tensor(np.zeros((1, self.units)), dtype=tf.float32)
+         output = []
+         for ind in range(input_data.shape[0]):
+             inp = tf.convert_to_tensor(
+                 input_data[ind, :].reshape((1, input_data.shape[1])), dtype=tf.float32
+             )
+             h, s = self.calculate_lstm_cell_wt(inp, [hstate, cstate], ind)
+             hstate = s[0]
+             cstate = s[1]
+             output.append(h)
+         return output
+
+
+ class LSTM_backtrace(object):
+     def __init__(
+         self, num_cells, units, weights, return_sequence=False, go_backwards=False
+     ):
+         self.num_cells = num_cells
+         self.units = units
+         self.kernel = weights[0]
+         self.recurrent_kernel = weights[1]
+         self.bias = weights[2]
+         self.return_sequence = return_sequence
+         self.go_backwards = go_backwards
+         self.recurrent_activation = np_sigmoid
+         self.activation = np_tanh
+
+         self.compute_log = {}
+
+     def calculate_wt_fc(self, wts, inp, w, b, act):
+         mul_mat = np.einsum("ij,i->ij", w, inp).T
+         wt_mat = np.zeros(mul_mat.shape)
+         for i in range(mul_mat.shape[0]):
+             l1_ind1 = mul_mat[i]
+             wt_ind1 = wt_mat[i]
+             wt = wts[i]
+             p_ind = l1_ind1 > 0
+             n_ind = l1_ind1 < 0
+             p_sum = np.sum(l1_ind1[p_ind])
+             n_sum = np.sum(l1_ind1[n_ind]) * -1
+             if len(b) > 0:
+                 if b[i] > 0:
+                     pbias = b[i]
+                     nbias = 0
+                 else:
+                     pbias = 0
+                     nbias = b[i] * -1
+             else:
+                 pbias = 0
+                 nbias = 0
+             t_sum = p_sum + pbias - n_sum - nbias
+             if act["type"] == "mono":
+                 if act["range"]["l"]:
+                     if t_sum < act["range"]["l"]:
+                         p_sum = 0
+                 if act["range"]["u"]:
+                     if t_sum > act["range"]["u"]:
+                         n_sum = 0
+             elif act["type"] == "non_mono":
+                 t_act = act["func"](t_sum)
+                 p_act = act["func"](p_sum + pbias)
+                 n_act = act["func"](-1 * (n_sum + nbias))
+                 if act["range"]["l"]:
+                     if t_sum < act["range"]["l"]:
+                         p_sum = 0
+                 if act["range"]["u"]:
+                     if t_sum > act["range"]["u"]:
+                         n_sum = 0
+                 if p_sum > 0 and n_sum > 0:
+                     if t_act == p_act:
+                         n_sum = 0
+                     elif t_act == n_act:
+                         p_sum = 0
+             if p_sum > 0:
+                 p_agg_wt = (p_sum + pbias) / (p_sum + n_sum + pbias + nbias)
+                 p_agg_wt = p_agg_wt * (p_sum / (p_sum + pbias))
+             else:
+                 p_agg_wt = 0
+             if n_sum > 0:
+                 n_agg_wt = (n_sum + nbias) / (p_sum + n_sum + pbias + nbias)
+                 n_agg_wt = n_agg_wt * (n_sum / (n_sum + nbias))
+             else:
+                 n_agg_wt = 0
+             if p_sum == 0:
+                 p_sum = 1
+             if n_sum == 0:
+                 n_sum = 1
+             wt_ind1[p_ind] = (l1_ind1[p_ind] / p_sum) * wt * p_agg_wt
+             wt_ind1[n_ind] = (l1_ind1[n_ind] / n_sum) * wt * n_agg_wt * -1.0
+         wt_mat = wt_mat.sum(axis=0)
+         return wt_mat
+
+     def calculate_wt_add(self, wts, inp=None):
+         wt_mat = []
+         for x in inp:
+             wt_mat.append(np.zeros_like(x))
+         wt_mat = np.array(wt_mat)
+         inp_list = np.array(inp)
+         for i in range(wt_mat.shape[1]):
+             wt_ind1 = wt_mat[:, i]
+             wt = wts[i]
+             l1_ind1 = inp_list[:, i]
+             p_ind = l1_ind1 > 0
+             n_ind = l1_ind1 < 0
+             p_sum = np.sum(l1_ind1[p_ind])
+             n_sum = np.sum(l1_ind1[n_ind]) * -1
+             p_agg_wt = 0
+             n_agg_wt = 0
+             if p_sum + n_sum > 0:
+                 p_agg_wt = p_sum / (p_sum + n_sum)
+                 n_agg_wt = n_sum / (p_sum + n_sum)
+             if p_sum == 0:
+                 p_sum = 1
+             if n_sum == 0:
+                 n_sum = 1
+             wt_ind1[p_ind] = (l1_ind1[p_ind] / p_sum) * wt * p_agg_wt
+             wt_ind1[n_ind] = (l1_ind1[n_ind] / n_sum) * wt * n_agg_wt * -1.0
+             wt_mat[:, i] = wt_ind1
+         wt_mat = [i.reshape(wts.shape) for i in list(wt_mat)]
+         return wt_mat
+
+     def calculate_wt_multiply(self, wts, inp=None):
+         # Split relevance between the two factors of an elementwise product,
+         # in proportion to each factor's distance from the product.
+         inp_prod = inp[0] * inp[1]
+         inp_diff1 = np.abs(inp_prod - inp[0])
+         inp_diff2 = np.abs(inp_prod - inp[1])
+         inp_diff_sum = inp_diff1 + inp_diff2
+         inp_wt1 = (inp_diff1 / inp_diff_sum) * wts
+         inp_wt2 = (inp_diff2 / inp_diff_sum) * wts
+         return [inp_wt1, inp_wt2]
+
+     def compute_carry_and_output(self, wt_o, wt_c, h_tm1, c_tm1, x, cell_num):
+         """Computes carry and output using split kernels."""
+         h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o = (h_tm1, h_tm1, h_tm1, h_tm1)
+         x_i, x_f, x_c, x_o = x
+         f = self.compute_log[cell_num]["int_arrays"]["f"].numpy()[0]
+         i = self.compute_log[cell_num]["int_arrays"]["i"].numpy()[0]
+         # o = self.recurrent_activation(
+         #     x_o + np.dot(h_tm1_o, self.recurrent_kernel[:, self.units * 3:]))
+         temp1 = np.dot(h_tm1_o, self.recurrent_kernel[1][:, self.units * 3 :]).astype(
+             np.float32
+         )
+         wt_x_o, wt_temp1 = self.calculate_wt_add(wt_o, [x_o, temp1])
+         wt_h_tm1_o = self.calculate_wt_fc(
+             wt_temp1,
+             h_tm1_o,
+             self.recurrent_kernel[1][:, self.units * 3 :],
+             [],
+             {"type": None},
+         )
+
+         # c = f * c_tm1 + i * self.activation(x_c + np.dot(
+         #     h_tm1_c, self.recurrent_kernel[:, self.units * 2 : self.units * 3]))
+         temp2 = f * c_tm1
+         temp3_1 = np.dot(
+             h_tm1_c, self.recurrent_kernel[1][:, self.units * 2 : self.units * 3]
+         )
+         temp3_2 = self.activation(x_c + temp3_1)
+         temp3_3 = i * temp3_2
+         wt_temp2, wt_temp3_3 = self.calculate_wt_add(wt_c, [temp2, temp3_3])
+         wt_f, wt_c_tm1 = self.calculate_wt_multiply(wt_temp2, [f, c_tm1])
+         wt_i, wt_temp3_2 = self.calculate_wt_multiply(wt_temp3_3, [i, temp3_2])
+         wt_x_c, wt_temp3_1 = self.calculate_wt_add(wt_temp3_2, [x_c, temp3_1])
+         wt_h_tm1_c = self.calculate_wt_fc(
+             wt_temp3_1,
+             h_tm1_c,
+             self.recurrent_kernel[1][:, self.units * 2 : self.units * 3],
+             [],
+             {"type": None},
+         )
+
+         # f = self.recurrent_activation(x_f + np.dot(
+         #     h_tm1_f, self.recurrent_kernel[:, self.units : self.units * 2]))
+         temp4 = np.dot(h_tm1_f, self.recurrent_kernel[1][:, self.units : self.units * 2])
+         wt_x_f, wt_temp4 = self.calculate_wt_add(wt_f, [x_f, temp4])
+         wt_h_tm1_f = self.calculate_wt_fc(
+             wt_temp4,
+             h_tm1_f,
+             self.recurrent_kernel[1][:, self.units : self.units * 2],
+             [],
+             {"type": None},
+         )
+
+         # i = self.recurrent_activation(
+         #     x_i + np.dot(h_tm1_i, self.recurrent_kernel[:, : self.units]))
+         temp5 = np.dot(h_tm1_i, self.recurrent_kernel[1][:, : self.units])
+         wt_x_i, wt_temp5 = self.calculate_wt_add(wt_i, [x_i, temp5])
+         wt_h_tm1_i = self.calculate_wt_fc(
+             wt_temp5,
+             h_tm1_i,
+             self.recurrent_kernel[1][:, : self.units],
+             [],
+             {"type": None},
+         )
+
+         return (
+             wt_x_i,
+             wt_x_f,
+             wt_x_c,
+             wt_x_o,
+             wt_h_tm1_i,
+             wt_h_tm1_f,
+             wt_h_tm1_c,
+             wt_h_tm1_o,
+             wt_c_tm1,
+         )
+
+     def calculate_lstm_cell_wt(self, cell_num, wts_hstate, wts_cstate):
+         o = self.compute_log[cell_num]["int_arrays"]["o"].numpy()[0]
+         c = self.compute_log[cell_num]["cstate"][1].numpy()[0]
+         h_tm1 = self.compute_log[cell_num]["hstate"][0].numpy()[0]
+         c_tm1 = self.compute_log[cell_num]["cstate"][0].numpy()[0]
+         x = [i.numpy()[0] for i in self.compute_log[cell_num]["x"]]
+         wt_o, wt_c = self.calculate_wt_multiply(
+             wts_hstate, [o, self.activation(c)]
+         )  # h = o * self.activation(c)
+         wt_c = wt_c + wts_cstate
+         (
+             wt_x_i,
+             wt_x_f,
+             wt_x_c,
+             wt_x_o,
+             wt_h_tm1_i,
+             wt_h_tm1_f,
+             wt_h_tm1_c,
+             wt_h_tm1_o,
+             wt_c_tm1,
+         ) = self.compute_carry_and_output(wt_o, wt_c, h_tm1, c_tm1, x, cell_num)
+         wt_h_tm1 = wt_h_tm1_i + wt_h_tm1_f + wt_h_tm1_c + wt_h_tm1_o
+         inputs = self.compute_log[cell_num]["inp"].numpy()[0]
+
+         k_i, k_f, k_c, k_o = np.split(self.kernel[1], indices_or_sections=4, axis=1)
+         b_i, b_f, b_c, b_o = np.split(self.bias[1], indices_or_sections=4, axis=0)
+
+         wt_inputs_i = self.calculate_wt_fc(wt_x_i, inputs, k_i, b_i, {"type": None})
+         wt_inputs_f = self.calculate_wt_fc(wt_x_f, inputs, k_f, b_f, {"type": None})
+         wt_inputs_c = self.calculate_wt_fc(wt_x_c, inputs, k_c, b_c, {"type": None})
+         wt_inputs_o = self.calculate_wt_fc(wt_x_o, inputs, k_o, b_o, {"type": None})
+
+         wt_inputs = wt_inputs_i + wt_inputs_f + wt_inputs_c + wt_inputs_o
+
+         return wt_inputs, wt_h_tm1, wt_c_tm1
+
+     def calculate_lstm_wt(self, wts, compute_log):
+         self.compute_log = compute_log
+         output = []
+         if self.return_sequence:
+             temp_wts_hstate = wts[-1, :]
+         else:
+             temp_wts_hstate = wts
+         temp_wts_cstate = np.zeros_like(self.compute_log[0]["cstate"][1].numpy()[0])
+         for ind in range(len(self.compute_log) - 1, -1, -1):
+             temp_wt_inp, temp_wts_hstate, temp_wts_cstate = self.calculate_lstm_cell_wt(
+                 ind, temp_wts_hstate, temp_wts_cstate
+             )
+             output.append(temp_wt_inp)
+             if self.return_sequence and ind > 0:
+                 temp_wts_hstate = temp_wts_hstate + wts[ind - 1, :]
+         output.reverse()
+         return np.array(output)
+
+
+ def dummy_wt(wts, inp, *args):
+     test_wt = np.zeros_like(inp)
+     return test_wt
+
+
+ def calculate_wt_fc(wts, inp, w, b, act):
+     mul_mat = np.einsum("ij,i->ij", w.numpy().T, inp).T
+     wt_mat = np.zeros(mul_mat.shape)
+     for i in range(mul_mat.shape[0]):
+         l1_ind1 = mul_mat[i]
+         wt_ind1 = wt_mat[i]
+         wt = wts[i]
+         p_ind = l1_ind1 > 0
+         n_ind = l1_ind1 < 0
+         p_sum = np.sum(l1_ind1[p_ind])
+         n_sum = np.sum(l1_ind1[n_ind]) * -1
+         if b.numpy()[i] > 0:
+             pbias = b.numpy()[i]
+             nbias = 0
+         else:
+             pbias = 0
+             nbias = b.numpy()[i] * -1
+         t_sum = p_sum + pbias - n_sum - nbias
+         if act["type"] == "mono":
+             if act["range"]["l"]:
+                 if t_sum < act["range"]["l"]:
+                     p_sum = 0
+             if act["range"]["u"]:
+                 if t_sum > act["range"]["u"]:
+                     n_sum = 0
+         elif act["type"] == "non_mono":
+             t_act = act["func"](t_sum)
+             p_act = act["func"](p_sum + pbias)
+             n_act = act["func"](-1 * (n_sum + nbias))
+             if act["range"]["l"]:
+                 if t_sum < act["range"]["l"]:
+                     p_sum = 0
+             if act["range"]["u"]:
+                 if t_sum > act["range"]["u"]:
+                     n_sum = 0
+             if p_sum > 0 and n_sum > 0:
+                 if t_act == p_act:
+                     n_sum = 0
+                 elif t_act == n_act:
+                     p_sum = 0
+         if p_sum > 0:
+             p_agg_wt = (p_sum + pbias) / (p_sum + n_sum + pbias + nbias)
+             p_agg_wt = p_agg_wt * (p_sum / (p_sum + pbias))
+         else:
+             p_agg_wt = 0
+         if n_sum > 0:
+             n_agg_wt = (n_sum + nbias) / (p_sum + n_sum + pbias + nbias)
+             n_agg_wt = n_agg_wt * (n_sum / (n_sum + nbias))
+         else:
+             n_agg_wt = 0
+         if p_sum == 0:
+             p_sum = 1
+         if n_sum == 0:
+             n_sum = 1
+         wt_ind1[p_ind] = (l1_ind1[p_ind] / p_sum) * wt * p_agg_wt
+         wt_ind1[n_ind] = (l1_ind1[n_ind] / n_sum) * wt * n_agg_wt * -1.0
+
+     wt_mat = wt_mat.sum(axis=0)
+     return wt_mat
+
+
+ def calculate_wt_rshp(wts, inp=None):
+     x = np.reshape(wts, inp.shape)
+     return x
+
+
+ def calculate_wt_concat(wts, inp=None, axis=-1):
+     wts = wts.T
+     splits = [i.shape[axis] for i in inp]
+     splits = np.cumsum(splits)
+     if axis > 0:
+         axis = axis - 1
+     x = np.split(wts, indices_or_sections=splits, axis=axis)
+     return x
+
+
+ def calculate_wt_add(wts, inp=None):
+     wts = wts.T
+     wt_mat = []
+     inp_list = []
+     expanded_wts = as_strided(
+         wts,
+         shape=(np.prod(wts.shape),),
+         strides=(wts.strides[-1],),
+         writeable=False,  # read-only view; avoids accidental writes
+     )
+
+     for x in inp:
+         expanded_input = as_strided(
+             x,
+             shape=(np.prod(x.shape),),
+             strides=(x.strides[-1],),
+             writeable=False,  # read-only view; avoids accidental writes
+         )
+         inp_list.append(expanded_input)
+         wt_mat.append(np.zeros_like(expanded_input))
+     wt_mat = np.array(wt_mat)
+     inp_list = np.array(inp_list)
+     for i in range(wt_mat.shape[1]):
+         wt_ind1 = wt_mat[:, i]
+         wt = expanded_wts[i]
+         l1_ind1 = inp_list[:, i]
+         p_ind = l1_ind1 > 0
+         n_ind = l1_ind1 < 0
+         p_sum = np.sum(l1_ind1[p_ind])
+         n_sum = np.sum(l1_ind1[n_ind]) * -1
+         p_agg_wt = 0
+         n_agg_wt = 0
+         if p_sum + n_sum > 0:
+             p_agg_wt = p_sum / (p_sum + n_sum)
+             n_agg_wt = n_sum / (p_sum + n_sum)
+         if p_sum == 0:
+             p_sum = 1
+         if n_sum == 0:
+             n_sum = 1
+         wt_ind1[p_ind] = (l1_ind1[p_ind] / p_sum) * wt * p_agg_wt
+         wt_ind1[n_ind] = (l1_ind1[n_ind] / n_sum) * wt * n_agg_wt * -1.0
+         wt_mat[:, i] = wt_ind1
+     wt_mat = [i.reshape(wts.shape) for i in list(wt_mat)]
+     return wt_mat
+
+
+ def calculate_start_wt(arg):
+     # One-hot relevance seed at the argmax of the model output.
+     x = np.argmax(arg[0])
+     y = np.zeros(arg.shape)
+     y[0][x] = 1
+     return y[0]
+
+
+ def calculate_wt_passthru(wts):
+     return wts
+
+
+ def calculate_wt_conv_unit(wt, p_mat, n_mat, t_sum, p_sum, n_sum, act):
+     wt_mat = np.zeros_like(p_mat)
+     if act["type"] == "mono":
+         if act["range"]["l"]:
+             if t_sum < act["range"]["l"]:
+                 p_sum = 0
+         if act["range"]["u"]:
+             if t_sum > act["range"]["u"]:
+                 n_sum = 0
+     elif act["type"] == "non_mono":
+         t_act = act["func"](t_sum)
+         p_act = act["func"](p_sum)
+         n_act = act["func"](n_sum)
+         if act["range"]["l"]:
+             if t_sum < act["range"]["l"]:
+                 p_sum = 0
+         if act["range"]["u"]:
+             if t_sum > act["range"]["u"]:
+                 n_sum = 0
+         if p_sum > 0 and n_sum > 0:
+             if t_act == p_act:
+                 n_sum = 0
+             elif t_act == n_act:
+                 p_sum = 0
+     p_agg_wt = 0.0
+     n_agg_wt = 0.0
+     if p_sum + n_sum > 0.0:
+         p_agg_wt = p_sum / (p_sum + n_sum)
+         n_agg_wt = n_sum / (p_sum + n_sum)
+     if p_sum == 0.0:
+         p_sum = 1.0
+     if n_sum == 0.0:
+         n_sum = 1.0
+     wt_mat = wt_mat + ((p_mat / p_sum) * wt * p_agg_wt)
+     wt_mat = wt_mat + ((n_mat / n_sum) * wt * n_agg_wt * -1.0)
+     return wt_mat
+
+
+ def calculate_wt_conv(wts, inp, w, b, act):
+     wts = wts.T
+     inp = inp.T
+     w = w.T
+     expanded_input = as_strided(
+         inp,
+         shape=(
+             inp.shape[0] - w.numpy().shape[0] + 1,  # the feature map is a few pixels smaller than the input
+             inp.shape[1] - w.numpy().shape[1] + 1,
+             inp.shape[2],
+             w.numpy().shape[0],
+             w.numpy().shape[1],
+         ),
+         strides=(
+             inp.strides[0],
+             inp.strides[1],
+             inp.strides[2],
+             inp.strides[0],  # one step in a window dim moves one step in the original data
+             inp.strides[1],
+         ),
+         writeable=False,  # read-only view; avoids accidental writes
+     )
+     test_wt = np.einsum("mnc->cmn", np.zeros_like(inp), order="C", optimize=True)
+     for k in range(w.numpy().shape[-1]):
+         kernel = w.numpy()[:, :, :, k]
+         x = np.einsum(
+             "abcmn,mnc->abcmn", expanded_input, kernel, order="C", optimize=True
+         )
+         x_pos = x.copy()
+         x_neg = x.copy()
+         x_pos[x < 0] = 0
+         x_neg[x > 0] = 0
+         x_sum = np.einsum("abcmn->ab", x, order="C", optimize=True)
+         x_p_sum = np.einsum("abcmn->ab", x_pos, order="C", optimize=True)
+         x_n_sum = np.einsum("abcmn->ab", x_neg, order="C", optimize=True) * -1.0
+         for ind1 in range(expanded_input.shape[0]):
+             for ind2 in range(expanded_input.shape[1]):
+                 temp_wt_mat = calculate_wt_conv_unit(
+                     wts[ind1, ind2, k],
+                     x_pos[ind1, ind2, :, :, :],
+                     x_neg[ind1, ind2, :, :, :],
+                     x_sum[ind1, ind2],
+                     x_p_sum[ind1, ind2],
+                     x_n_sum[ind1, ind2],
+                     act,
+                 )
+                 test_wt[
+                     :, ind1 : ind1 + kernel.shape[0], ind2 : ind2 + kernel.shape[1]
+                 ] += temp_wt_mat
+     test_wt = np.einsum("cmn->mnc", test_wt, order="C", optimize=True)
+     gc.collect()
+     return test_wt
+
+
+ def get_max_index(mat=None):
+     # Map the flat argmax back to a per-axis index tuple.
+     return tuple(np.unravel_index(np.argmax(mat), mat.shape))
+
+
+ def calculate_wt_maxpool(wts, inp, pool_size):
+     wts = wts.T
+     inp = inp.T
+     pad1 = pool_size[0]
+     pad2 = pool_size[1]
+     test_samp_pad = np.pad(inp, ((0, pad1), (0, pad2), (0, 0)), "constant")
+     dim1, dim2, _ = wts.shape
+     test_wt = np.zeros_like(test_samp_pad)
+     for k in range(inp.shape[2]):
+         wt_mat = wts[:, :, k]
+         for ind1 in range(dim1):
+             for ind2 in range(dim2):
+                 temp_inp = test_samp_pad[
+                     ind1 * pool_size[0] : (ind1 + 1) * pool_size[0],
+                     ind2 * pool_size[1] : (ind2 + 1) * pool_size[1],
+                     k,
+                 ]
+                 max_index = get_max_index(temp_inp)
+                 test_wt[
+                     ind1 * pool_size[0] : (ind1 + 1) * pool_size[0],
+                     ind2 * pool_size[1] : (ind2 + 1) * pool_size[1],
+                     k,
+                 ][max_index] = wt_mat[ind1, ind2]
+     test_wt = test_wt[0 : inp.shape[0], 0 : inp.shape[1], :]
+     return test_wt
+
+
+ def calculate_wt_avgpool(wts, inp, pool_size):
+     wts = wts.T
+     inp = inp.T
+     pad1 = pool_size[0]
+     pad2 = pool_size[1]
+     test_samp_pad = np.pad(inp, ((0, pad1), (0, pad2), (0, 0)), "constant")
+     dim1, dim2, _ = wts.shape
+     test_wt = np.zeros_like(test_samp_pad)
+     for k in range(inp.shape[2]):
+         wt_mat = wts[:, :, k]
+         for ind1 in range(dim1):
+             for ind2 in range(dim2):
+                 temp_inp = test_samp_pad[
+                     ind1 * pool_size[0] : (ind1 + 1) * pool_size[0],
+                     ind2 * pool_size[1] : (ind2 + 1) * pool_size[1],
+                     k,
+                 ]
+                 wt_ind1 = test_wt[
+                     ind1 * pool_size[0] : (ind1 + 1) * pool_size[0],
+                     ind2 * pool_size[1] : (ind2 + 1) * pool_size[1],
+                     k,
+                 ]
+                 wt = wt_mat[ind1, ind2]
+                 p_ind = temp_inp > 0
+                 n_ind = temp_inp < 0
+                 p_sum = np.sum(temp_inp[p_ind])
+                 n_sum = np.sum(temp_inp[n_ind]) * -1
+                 if p_sum > 0:
+                     p_agg_wt = p_sum / (p_sum + n_sum)
+                 else:
+                     p_agg_wt = 0
+                 if n_sum > 0:
+                     n_agg_wt = n_sum / (p_sum + n_sum)
+                 else:
+                     n_agg_wt = 0
+                 if p_sum == 0:
+                     p_sum = 1
+                 if n_sum == 0:
+                     n_sum = 1
+                 wt_ind1[p_ind] += (temp_inp[p_ind] / p_sum) * wt * p_agg_wt
+                 wt_ind1[n_ind] += (temp_inp[n_ind] / n_sum) * wt * n_agg_wt * -1.0
+     test_wt = test_wt[0 : inp.shape[0], 0 : inp.shape[1], :]
+     return test_wt
+
+
+ def calculate_wt_gavgpool(wts, inp):
+     wts = wts.T
+     inp = inp.T
+     channels = wts.shape[0]
+     wt_mat = np.zeros_like(inp)
+     for c in range(channels):
+         wt = wts[c]
+         temp_wt = wt_mat[..., c]
+         x = inp[..., c]
+         p_mat = np.copy(x)
+         n_mat = np.copy(x)
+         p_mat[x < 0] = 0
+         n_mat[x > 0] = 0
+         p_sum = np.sum(p_mat)
+         n_sum = np.sum(n_mat) * -1
+         p_agg_wt = 0.0
+         n_agg_wt = 0.0
+         if p_sum + n_sum > 0.0:
+             p_agg_wt = p_sum / (p_sum + n_sum)
+             n_agg_wt = n_sum / (p_sum + n_sum)
+         if p_sum == 0.0:
+             p_sum = 1.0
+         if n_sum == 0.0:
+             n_sum = 1.0
+         temp_wt = temp_wt + ((p_mat / p_sum) * wt * p_agg_wt)
+         temp_wt = temp_wt + ((n_mat / n_sum) * wt * n_agg_wt * -1.0)
+         wt_mat[..., c] = temp_wt
+     return wt_mat
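
For orientation, the same signed relevance split recurs throughout this file (calculate_wt_fc, calculate_wt_add, calculate_wt_conv_unit, calculate_wt_gavgpool). The standalone sketch below re-implements that rule for a single output unit to show its conservation property; split_relevance is a hypothetical helper written for illustration and is not part of the package.

    import numpy as np

    def split_relevance(wt, contribs):
        # Distribute one unit's relevance `wt` over its signed contributions,
        # mirroring the p_sum/n_sum bookkeeping used in this module.
        p_ind = contribs > 0
        n_ind = contribs < 0
        p_sum = contribs[p_ind].sum()
        n_sum = -contribs[n_ind].sum()
        p_agg_wt = p_sum / (p_sum + n_sum) if p_sum + n_sum > 0 else 0.0
        n_agg_wt = n_sum / (p_sum + n_sum) if p_sum + n_sum > 0 else 0.0
        p_sum = p_sum if p_sum != 0 else 1.0  # guard the divisions below
        n_sum = n_sum if n_sum != 0 else 1.0
        out = np.zeros_like(contribs)
        out[p_ind] = (contribs[p_ind] / p_sum) * wt * p_agg_wt
        out[n_ind] = (contribs[n_ind] / n_sum) * wt * n_agg_wt * -1.0
        return out

    contribs = np.array([2.0, -1.0, 3.0])
    r = split_relevance(1.0, contribs)
    print(r, r.sum())  # -> [0.3333... 0.1666... 0.5] 1.0; |wt| is conserved

Positive and negative contributions each receive a share proportional to their aggregate magnitude, so the total outgoing relevance equals the incoming wt, which is the invariant these propagation helpers rely on layer after layer.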