dl-backtrace 0.0.18__py3-none-any.whl → 0.0.19__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of dl-backtrace might be problematic; see the registry listing for details.

@@ -1,69 +1,46 @@
1
1
  import gc
2
-
2
+ import torch
3
3
  import numpy as np
4
- import tensorflow as tf
5
4
  from numpy.lib.stride_tricks import as_strided
6
- from tensorflow.keras import backend as K
7
-
8
5
 
9
6
  def np_swish(x, beta=0.75):
10
7
  z = 1 / (1 + np.exp(-(beta * x)))
11
8
  return x * z
12
9
 
13
-
14
10
  def np_wave(x, alpha=1.0):
15
11
  return (alpha * x * np.exp(1.0)) / (np.exp(-x) + np.exp(x))
16
12
 
17
-
18
13
  def np_pulse(x, alpha=1.0):
19
14
  return alpha * (1 - np.tanh(x) * np.tanh(x))
20
15
 
21
-
22
16
  def np_absolute(x, alpha=1.0):
23
17
  return alpha * x * np.tanh(x)
24
18
 
25
-
26
19
  def np_hard_sigmoid(x):
27
20
  return np.clip(0.2 * x + 0.5, 0, 1)
28
21
 
29
-
30
22
  def np_sigmoid(x):
31
23
  z = 1 / (1 + np.exp(-x))
32
24
  return z
33
25
 
34
-
35
26
  def np_tanh(x):
36
27
  z = np.tanh(x)
37
28
  return z.astype(np.float32)
38
29
 
39
-
40
# NOTE(review): the '-' lines below are the 0.0.18 calculate_start_wt, deleted in
# 0.0.19. It puts the max m of row 0 at its arg-max x in y_pos (zeros elsewhere).
# y_neg is a copy of arg whose entry x becomes 1 - m when arg has a single column
# and m < 1, and 0 otherwise. Only row 0 of each is returned.
# max_wt is accepted but never read.
- def calculate_start_wt(arg, max_wt=1):
41
- x = np.argmax(arg[0])
42
- m = np.max(arg[0])
43
- y_pos = np.zeros_like(arg)
44
- y_pos[0][x] = m
45
- y_neg = np.array(arg)
46
- if m < 1 and arg.shape[-1] == 1:
47
- y_neg[0][x] = 1 - m
48
- else:
49
- y_neg[0][x] = 0
50
- return y_pos[0], y_neg[0]
51
-
52
-
53
- def calculate_base_wt(p_sum=0, n_sum=0, bias=0, wt_pos=0, wt_neg=0):
30
+ def calculate_base_wt(p_sum=0,n_sum=0,bias=0,wt_pos=0,wt_neg=0):
54
31
  t_diff = p_sum + bias - n_sum
55
32
  bias = 0
56
33
  wt_sign = 1
57
- if t_diff > 0:
58
- if wt_pos > wt_neg:
34
+ if t_diff>0:
35
+ if wt_pos>wt_neg:
59
36
  p_agg_wt = wt_pos
60
37
  n_agg_wt = wt_neg
61
38
  else:
62
39
  p_agg_wt = wt_neg
63
40
  n_agg_wt = wt_pos
64
41
  wt_sign = -1
65
- elif t_diff < 0:
66
- if wt_pos < wt_neg:
42
+ elif t_diff<0:
43
+ if wt_pos<wt_neg:
67
44
  p_agg_wt = wt_pos
68
45
  n_agg_wt = wt_neg
69
46
  else:
@@ -77,8 +54,128 @@ def calculate_base_wt(p_sum=0, n_sum=0, bias=0, wt_pos=0, wt_neg=0):
77
54
  p_sum = 1
78
55
  if n_sum == 0:
79
56
  n_sum = 1
80
- return p_agg_wt, n_agg_wt, p_sum, n_sum, wt_sign
57
+ return p_agg_wt,n_agg_wt,p_sum,n_sum,wt_sign
81
58
 
59
def calculate_base_wt_array(p_sum=None, n_sum=None, bias=None, wt_pos=None, wt_neg=None):
    """Vectorised counterpart of ``calculate_base_wt`` (elementwise over arrays).

    For every element, decide how the incoming positive/negative relevance
    (``wt_pos`` / ``wt_neg``) is routed onto the positive and negative
    contributions. The decision uses the sign of ``p_sum + bias - n_sum`` and
    which of ``wt_pos`` / ``wt_neg`` is larger. Elements where the difference
    is zero, or where the two weights are equal, receive nothing.

    Args:
        p_sum: sum of the positive contributions per element.
        n_sum: magnitude of the negative contributions per element.
        bias: value added to ``p_sum`` before comparing with ``n_sum``.
        wt_pos: incoming positive relevance per element.
        wt_neg: incoming negative relevance per element.
        Any argument left as ``None`` is treated as ``0``.

    Returns:
        ``(p_agg_wt_pos, p_agg_wt_neg, n_agg_wt_pos, n_agg_wt_neg, p_sum, n_sum)``.
        The four aggregates are ``wt_pos`` / ``wt_neg`` masked to the elements
        where they apply (0 elsewhere). ``p_sum`` / ``n_sum`` are NEW arrays
        with zeros replaced by ``1.0`` so callers can divide by them safely.
        The caller's arrays are not modified.
    """
    # ``None`` replaces the old mutable ``[]`` defaults. Those could never
    # work, since lists support neither ``-`` nor boolean-mask assignment.
    p_sum = np.asarray(0.0 if p_sum is None else p_sum)
    n_sum = np.asarray(0.0 if n_sum is None else n_sum)
    bias = np.asarray(0.0 if bias is None else bias)
    wt_pos = np.asarray(0.0 if wt_pos is None else wt_pos)
    wt_neg = np.asarray(0.0 if wt_neg is None else wt_neg)

    t_diff = p_sum + bias - n_sum
    rising = t_diff > 0
    falling = t_diff < 0
    pos_dominant = wt_pos > wt_neg
    neg_dominant = wt_pos < wt_neg
    # "straight": wt_pos feeds the positive side and wt_neg the negative side.
    # "crossed": the roles are swapped.
    straight = (rising & pos_dominant) | (falling & neg_dominant)
    crossed = (rising & neg_dominant) | (falling & pos_dominant)

    # np.where (instead of ``w * mask``) keeps NaN/inf weights from leaking
    # into masked-out entries.
    p_agg_wt_pos = np.where(straight, wt_pos, 0)
    p_agg_wt_neg = np.where(crossed, wt_neg, 0)
    n_agg_wt_pos = np.where(crossed, wt_pos, 0)
    n_agg_wt_neg = np.where(straight, wt_neg, 0)

    # Zero sums become 1.0 so later ``x / p_sum`` divisions are safe. Fresh
    # arrays are built so the caller's inputs are not mutated.
    p_sum = np.where(p_sum == 0, 1.0, p_sum)
    n_sum = np.where(n_sum == 0, 1.0, n_sum)

    return p_agg_wt_pos, p_agg_wt_neg, n_agg_wt_pos, n_agg_wt_neg, p_sum, n_sum
86
+
87
def _start_wt_negative_seed(arg, x, m):
    """Return a copy of ``arg`` whose row-0 entry ``x`` is ``1 - m`` when ``arg``
    has one column and ``m < 1``, and ``0`` otherwise."""
    y_neg = np.array(arg)
    y_neg[0][x] = 1 - m if (m < 1 and arg.shape[-1] == 1) else 0
    return y_neg


def _start_wt_rescale(values, scaler):
    """Rescale ``values`` to sum to ``scaler`` when ``scaler`` is truthy and the sum is positive."""
    if scaler and np.sum(values) > 0:
        return values * (scaler / np.sum(values))
    return values


def _start_wt_spread(values):
    """Divide ``values`` by its number of non-zero entries (no-op when all are zero)."""
    count = np.count_nonzero(values)
    return values / count if count > 0 else values


def calculate_start_wt(arg, scaler=None, thresholding=0.5, task="binary-classification"):
    """Build the initial positive/negative relevance from a model output.

    Args:
        arg: model output as a NumPy array. 2-D arrays are handled as
            classification / bbox-regression outputs and 4-D arrays as
            segmentation maps; only batch row 0 is returned.
        scaler: optional total relevance. When truthy it replaces the raw
            output values in ``y_pos`` and rescales ``y_neg`` to sum to it.
        thresholding: 4-D only; entries ``> thresholding`` are seeded.
        task: ``"bbox-regression"`` seeds every entry of row 0. Any other 2-D
            task (including both classification tasks) seeds only the arg-max.
            ``"binary-segmentation"`` keeps raw values unnormalised, while
            other 4-D tasks spread them over the seeded entries.

    Returns:
        ``(y_pos[0], y_neg[0])``.

    Raises:
        ValueError: if ``arg`` is neither 2-D nor 4-D. Previously this case
            surfaced as an UnboundLocalError.
    """
    if arg.ndim == 2:
        x = np.argmax(arg[0])
        m = np.max(arg[0])
        y_pos = np.zeros_like(arg)
        if task == "bbox-regression":
            if scaler:
                # The old code normalised an undefined name ``y`` (NameError).
                # y_pos is the only array it can have meant, so the total
                # relevance is spread evenly over the seeded coordinates.
                y_pos[0] = scaler
                y_pos = _start_wt_spread(y_pos)
            else:
                y_pos[0] = m
        else:
            y_pos[0][x] = scaler if scaler else m
        # y_neg is now built on every 2-D path. The old bbox + scaler path
        # never defined it and crashed at return.
        y_neg = _start_wt_rescale(_start_wt_negative_seed(arg, x, m), scaler)
    elif arg.ndim == 4:
        indices = np.where(arg > thresholding)
        y_pos = np.zeros(arg.shape)
        if scaler:
            y_pos[indices] = scaler
            y_pos = _start_wt_spread(y_pos)
        else:
            y_pos[indices] = arg[indices]
            if task != "binary-segmentation":
                # Normalised exactly once. The flattened 0.19 listing may have
                # applied this a second time on the scaler path.
                y_pos = _start_wt_spread(y_pos)
        m = np.max(arg[0])
        # binary-segmentation uses an inclusive bound; other tasks a strict one.
        has_complement = m <= 1 if task == "binary-segmentation" else m < 1
        y_neg = np.array(arg)
        y_neg[indices] = 1 - arg[indices] if has_complement else 0
        y_neg = _start_wt_rescale(y_neg, scaler)
    else:
        raise ValueError(
            f"calculate_start_wt expects a 2-D or 4-D array, got ndim={arg.ndim}"
        )
    return y_pos[0], y_neg[0]
82
179
 
83
180
  class LSTM_forward(object):
84
181
  def __init__(
@@ -91,8 +188,8 @@ class LSTM_forward(object):
91
188
  self.bias = weights[2]
92
189
  self.return_sequence = return_sequence
93
190
  self.go_backwards = go_backwards
94
- self.recurrent_activation = tf.math.sigmoid
95
- self.activation = tf.math.tanh
191
+ self.recurrent_activation = torch.sigmoid()
192
+ self.activation = torch.tanh()
96
193
 
97
194
  self.compute_log = {}
98
195
  for i in range(self.num_cells):
@@ -108,17 +205,17 @@ class LSTM_forward(object):
108
205
  x_i, x_f, x_c, x_o = x
109
206
  h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o = h_tm1
110
207
  i = self.recurrent_activation(
111
- x_i + K.dot(h_tm1_i, self.recurrent_kernel[:, : self.units])
208
+ x_i + torch.dot(h_tm1_i, self.recurrent_kernel[:, : self.units])
112
209
  )
113
210
  f = self.recurrent_activation(
114
- x_f + K.dot(h_tm1_f, self.recurrent_kernel[:, self.units : self.units * 2])
211
+ x_f + torch.dot(h_tm1_f, self.recurrent_kernel[:, self.units : self.units * 2])
115
212
  )
116
213
  c = f * c_tm1 + i * self.activation(
117
214
  x_c
118
- + K.dot(h_tm1_c, self.recurrent_kernel[:, self.units * 2 : self.units * 3])
215
+ + torch.dot(h_tm1_c, self.recurrent_kernel[:, self.units * 2 : self.units * 3])
119
216
  )
120
217
  o = self.recurrent_activation(
121
- x_o + K.dot(h_tm1_o, self.recurrent_kernel[:, self.units * 3 :])
218
+ x_o + torch.dot(h_tm1_o, self.recurrent_kernel[:, self.units * 3 :])
122
219
  )
123
220
  self.compute_log[cell_num]["int_arrays"]["i"] = i
124
221
  self.compute_log[cell_num]["int_arrays"]["f"] = f
@@ -136,16 +233,16 @@ class LSTM_forward(object):
136
233
  inputs_f = inputs
137
234
  inputs_c = inputs
138
235
  inputs_o = inputs
139
- k_i, k_f, k_c, k_o = tf.split(self.kernel, num_or_size_splits=4, axis=1)
140
- x_i = K.dot(inputs_i, k_i)
141
- x_f = K.dot(inputs_f, k_f)
142
- x_c = K.dot(inputs_c, k_c)
143
- x_o = K.dot(inputs_o, k_o)
144
- b_i, b_f, b_c, b_o = tf.split(self.bias, num_or_size_splits=4, axis=0)
145
- x_i = tf.add(x_i, b_i)
146
- x_f = tf.add(x_f, b_f)
147
- x_c = tf.add(x_c, b_c)
148
- x_o = tf.add(x_o, b_o)
236
+ k_i, k_f, k_c, k_o = torch.split(self.kernel[1],self.kernel.size(1)//4,dim=1)
237
+ x_i = torch.dot(inputs_i, k_i)
238
+ x_f = torch.dot(inputs_f, k_f)
239
+ x_c = torch.dot(inputs_c, k_c)
240
+ x_o = torch.dot(inputs_o, k_o)
241
+ b_i, b_f, b_c, b_o = torch.split(self.bias,self.bias.size(1)//4,dim=0)
242
+ x_i = x_i + b_i
243
+ x_f = x_f + b_f
244
+ x_c = x_c + b_c
245
+ x_o = x_o + b_o
149
246
 
150
247
  h_tm1_i = h_tm1
151
248
  h_tm1_f = h_tm1
@@ -161,12 +258,12 @@ class LSTM_forward(object):
161
258
  return h, [h, c]
162
259
 
163
260
  def calculate_lstm_wt(self, input_data):
164
- hstate = tf.convert_to_tensor(np.zeros((1, self.units)), dtype=tf.float32)
165
- cstate = tf.convert_to_tensor(np.zeros((1, self.units)), dtype=tf.float32)
261
+ hstate = torch.tensor(np.zeros((1,self.units)),dtype=torch.float32)
262
+ cstate = torch.tensor(np.zeros((1,self.units)),dtype=torch.float32)
166
263
  output = []
167
264
  for ind in range(input_data.shape[0]):
168
- inp = tf.convert_to_tensor(
169
- input_data[ind, :].reshape((1, input_data.shape[1])), dtype=tf.float32
265
+ inp = torch.tensor(
266
+ input_data[ind, :].reshape((1, input_data.shape[1])), dtype=torch.float32
170
267
  )
171
268
  h, s = self.calculate_lstm_cell_wt(inp, [hstate, cstate], ind)
172
269
  hstate = s[0]
@@ -454,12 +551,10 @@ class LSTM_backtrace(object):
454
551
  output_neg = np.array(output_neg)
455
552
  return output_pos, output_neg
456
553
 
457
-
458
554
  def dummy_wt(wts, inp, *args):
459
555
  test_wt = np.zeros_like(inp)
460
556
  return test_wt
461
557
 
462
-
463
558
  def calculate_wt_fc(wts_pos, wts_neg, inp, w, b, act={}):
464
559
  mul_mat = np.einsum("ij,i->ij", w.numpy().T, inp).T
465
560
  wt_mat_pos = np.zeros(mul_mat.shape)
@@ -494,22 +589,17 @@ def calculate_wt_fc(wts_pos, wts_neg, inp, w, b, act={}):
494
589
  else:
495
590
  wt_ind1_neg[p_ind] = (l1_ind1[p_ind] / p_sum) * p_agg_wt
496
591
  wt_ind1_pos[n_ind] = (l1_ind1[n_ind] / n_sum) * n_agg_wt * -1
497
- # print(wt_pos,wt_neg,p_agg_wt,n_agg_wt,wt_sign)
498
- # print("---------------------------------")
499
592
  wt_mat_pos = wt_mat_pos.sum(axis=0)
500
593
  wt_mat_neg = wt_mat_neg.sum(axis=0)
501
594
  return wt_mat_pos, wt_mat_neg
502
595
 
503
-
504
596
  def calculate_wt_passthru(wts):
505
597
  return wts
506
598
 
507
-
508
599
  def calculate_wt_rshp(wts, inp=None):
509
600
  x = np.reshape(wts, inp.shape)
510
601
  return x
511
602
 
512
-
513
603
  def calculate_wt_concat(wts, inp=None, axis=-1):
514
604
  splits = [i.shape[axis] for i in inp]
515
605
  splits = np.cumsum(splits)
@@ -518,7 +608,6 @@ def calculate_wt_concat(wts, inp=None, axis=-1):
518
608
  x = np.split(wts, indices_or_sections=splits, axis=axis)
519
609
  return x
520
610
 
521
-
522
611
  def calculate_wt_add(wts_pos, wts_neg, inp=None):
523
612
  wts_pos = wts_pos
524
613
  wts_neg = wts_neg
@@ -586,113 +675,92 @@ def calculate_wt_add(wts_pos, wts_neg, inp=None):
586
675
  output.append((wt_mat_pos[i], wt_mat_neg[i]))
587
676
  return output
588
677
 
589
-
590
678
  def calculate_wt_passthru(wts):
591
679
  return wts
592
680
 
681
def calculate_padding(kernel_size, inp, padding, strides, const_val=0.0):
    """Pad axes 0 and 1 of ``inp`` for a 2-D sliding-window layer.

    Args:
        kernel_size: window extent; only ``kernel_size[0]`` / ``[1]`` are read.
        inp: array padded along axes 0 and 1 (axis 2 is never padded).
        padding: one of the following.
            ``'valid'``: no padding.
            ``'same'``: the total per axis comes from kernel size, stride and
            input size, with any odd cell going to the end.
            A ``(pad_0, pad_1)`` tuple, or a single integer meaning the same
            amount on both axes: symmetric padding on both sides.
            ``(None, None)`` or anything else: no padding.
        strides: window step; only ``strides[0]`` / ``[1]`` are read.
        const_val: fill value for padded cells.

    Returns:
        ``(padded_inp, paddings)``, where ``paddings`` holds ``[before, after]``
        for axis 0, axis 1 and axis 2.
    """
    no_padding = [[0, 0], [0, 0], [0, 0]]
    if padding == 'valid':
        return (inp, no_padding)

    if padding == "same":
        def _total(size, k, s):
            remainder = size % s
            return np.max([0, k - (s if remainder == 0 else remainder)])

        pad_h = _total(inp.shape[0], kernel_size[0], strides[0])
        pad_v = _total(inp.shape[1], kernel_size[1], strides[1])
        paddings = [np.floor([pad_h / 2.0, (pad_h + 1) / 2.0]).astype("int32"),
                    np.floor([pad_v / 2.0, (pad_v + 1) / 2.0]).astype("int32"),
                    np.zeros((2)).astype("int32")]
        inp_pad = np.pad(inp, paddings, 'constant', constant_values=const_val)
        return (inp_pad, paddings)

    # A bare integer (e.g. an int pooling padding) used to be silently ignored.
    # It now means symmetric padding, as in calculate_padding_1d.
    if isinstance(padding, (int, np.integer)) and not isinstance(padding, bool):
        padding = (padding, padding)

    if isinstance(padding, tuple) and padding != (None, None):
        pad_h = padding[0]
        pad_v = padding[1]
        paddings = [np.floor([pad_h, pad_h]).astype("int32"),
                    np.floor([pad_v, pad_v]).astype("int32"),
                    np.zeros((2)).astype("int32")]
        inp_pad = np.pad(inp, paddings, 'constant', constant_values=const_val)
        return (inp_pad, paddings)

    return (inp, no_padding)
713
+
714
def calculate_wt_conv_unit(patch, wts_pos, wts_neg, w, b, act):
    """Split one 2-D conv output position's relevance over its input window.

    ``w`` / ``b`` must expose ``.numpy()``. ``patch`` must match the first
    three kernel axes, as the einsum requires. ``act`` is accepted but unused.
    The incoming ``wts_pos`` / ``wts_neg`` are routed per output channel by
    ``calculate_base_wt_array`` and shared in proportion to each positive
    (or negative) kernel * input contribution.

    Returns:
        ``(wt_mat_pos, wt_mat_neg)``, shaped like the kernel without its last axis.
    """
    kernel = w.numpy()
    bias = b.numpy()
    # Elementwise kernel * input, kept separate per output channel (last axis).
    contrib = np.einsum("ijkl,ijk->ijkl", kernel, patch)
    pos_part = contrib * (contrib > 0)
    neg_part = contrib * (contrib < 0)
    pos_total = np.einsum("ijkl->l", pos_part)
    neg_total = np.einsum("ijkl->l", neg_part) * -1.0
    (p_agg_pos, p_agg_neg, n_agg_pos, n_agg_neg,
     pos_total, neg_total) = calculate_base_wt_array(pos_total, neg_total, bias, wts_pos, wts_neg)
    pos_share = pos_part / pos_total
    neg_share = neg_part / neg_total
    base = np.zeros_like(kernel)
    wt_mat_pos = base + pos_share * p_agg_pos + (neg_share * n_agg_pos) * -1.0
    wt_mat_neg = base + pos_share * p_agg_neg + (neg_share * n_agg_neg) * -1.0
    return np.sum(wt_mat_pos, axis=-1), np.sum(wt_mat_neg, axis=-1)
616
736
 
617
-
618
# NOTE(review): the '-' lines below are the 0.0.18 dummy_wt_conv, removed at this
# position in 0.0.19. The release re-adds an identical definition after
# calculate_wt_conv. It returns a uniform array shaped like p_mat that sums to 1;
# every other argument is ignored.
- def dummy_wt_conv(wt, p_mat, n_mat, t_sum, p_sum, n_sum, act):
619
- wt_mat = np.ones_like(p_mat)
620
- return wt_mat / np.sum(wt_mat)
621
-
622
-
623
def calculate_wt_conv(wts_pos, wts_neg, inp, w, b, padding, strides, act):
    """Propagate positive/negative relevance back through a 2-D convolution.

    ``wts_pos``, ``wts_neg`` and ``inp`` are NumPy arrays whose axes are
    reversed with ``.T`` before use. They are presumably channel-first
    layouts (TODO confirm against callers). ``w`` / ``b`` are the layer's
    weight and bias tensors; ``calculate_wt_conv_unit`` calls ``.numpy()``
    on them.

    Each output position ``(ind1, ind2)`` of ``wts_pos`` is mapped to its
    strided window of the padded input. ``calculate_wt_conv_unit`` splits the
    relevance, and the result is accumulated back into that window.

    Returns:
        ``(out_ds_pos, out_ds_neg)`` with the padding from
        ``calculate_padding`` cropped off, i.e. shaped like the transposed ``inp``.
    """
    wts_pos = wts_pos.T
    wts_neg = wts_neg.T
    inp = inp.T
    # Reverse the kernel axes. ``Tensor.T`` on a non-2-D tensor is deprecated
    # in PyTorch; ``permute`` with reversed axes is the documented equivalent.
    w = w.permute(*range(w.ndim - 1, -1, -1)) if hasattr(w, "permute") else w.T
    input_padded, paddings = calculate_padding(w.shape, inp, padding, strides)
    # Accumulate in a floating dtype; integer inputs used to fail on ``+=``.
    acc_dtype = np.result_type(input_padded.dtype, np.float32)
    out_ds_pos = np.zeros(input_padded.shape, dtype=acc_dtype)
    out_ds_neg = np.zeros(input_padded.shape, dtype=acc_dtype)
    k_rows, k_cols = w.shape[0], w.shape[1]
    for ind1 in range(wts_pos.shape[0]):
        row0 = ind1 * strides[0]
        rows = slice(row0, row0 + k_rows)
        for ind2 in range(wts_pos.shape[1]):
            col0 = ind2 * strides[1]
            cols = slice(col0, col0 + k_cols)
            updates_pos, updates_neg = calculate_wt_conv_unit(
                input_padded[rows, cols], wts_pos[ind1, ind2, :],
                wts_neg[ind1, ind2, :], w, b, act)
            out_ds_pos[rows, cols] += updates_pos
            out_ds_neg[rows, cols] += updates_neg
    top, left = paddings[0][0], paddings[1][0]
    out_ds_pos = out_ds_pos[top:top + inp.shape[0], left:left + inp.shape[1], :]
    out_ds_neg = out_ds_neg[top:top + inp.shape[0], left:left + inp.shape[1], :]
    return out_ds_pos, out_ds_neg
695
760
 
761
def dummy_wt_conv(wt, p_mat, n_mat, t_sum, p_sum, n_sum, act):
    """Placeholder conv rule: uniform weights shaped like ``p_mat`` that sum to 1.

    Every argument except ``p_mat`` is ignored.
    """
    uniform = np.ones_like(p_mat)
    return uniform / uniform.sum()
696
764
 
697
765
  def get_max_index(mat=None):
698
766
  max_ind = np.argmax(mat)
@@ -704,7 +772,6 @@ def get_max_index(mat=None):
704
772
  ind.append(rem)
705
773
  return tuple(ind)
706
774
 
707
-
708
775
  def calculate_wt_maxpool(wts, inp, pool_size):
709
776
  wts=wts.T
710
777
  inp=inp.T
@@ -822,13 +889,11 @@ def calculate_wt_gavgpool(wts_pos, wts_neg, inp):
822
889
  wt_mat_neg[..., c] = temp_wt_neg
823
890
  return wt_mat_pos, wt_mat_neg
824
891
 
825
-
826
892
  def weight_scaler(arg, scaler=100.0):
827
893
  s1 = np.sum(arg)
828
894
  scale_factor = s1 / scaler
829
895
  return arg / scale_factor
830
896
 
831
-
832
897
  def weight_normalize(arg, max_val=1.0):
833
898
  arg_max = np.max(arg)
834
899
  arg_min = np.abs(np.min(arg))
@@ -838,3 +903,389 @@ def weight_normalize(arg, max_val=1.0):
838
903
  return (arg / arg_min) * max_val
839
904
  else:
840
905
  return arg
906
+
907
def calculate_padding_1d(kernel_size, inp, padding, strides, const_val=0.0):
    """Pad axis 0 of a 2-D array for a 1-D sliding-window layer.

    Args:
        kernel_size: window length.
        inp: 2-D array; only axis 0 is padded.
        padding: one of the following.
            ``'valid'`` or ``0``: no padding.
            An integer ``p``: ``p`` cells on each side.
            Anything else: "same"-style padding. The total comes from
            ``kernel_size``, ``strides`` and the input length, with the odd
            cell going to the end.
        strides: window step.
        const_val: fill value for padded cells.

        One-element tuples/lists are unwrapped first for ``kernel_size``,
        ``padding`` and ``strides``. PyTorch ``Conv1d`` layers expose these as
        1-tuples such as ``(2,)``, which previously fell through to the
        "same" branch.

    Returns:
        ``(padded_inp, [[before, after], [0, 0]])``.
    """
    def _unwrap(value):
        if isinstance(value, (tuple, list)) and len(value) == 1:
            return value[0]
        return value

    padding = _unwrap(padding)
    strides = _unwrap(strides)
    kernel_size = _unwrap(kernel_size)

    if padding == 'valid' or padding == 0:
        return inp, [[0, 0], [0, 0]]

    if isinstance(padding, (int, np.integer)):
        inp_pad = np.pad(inp, ((padding, padding), (0, 0)), 'constant', constant_values=const_val)
        return inp_pad, [[padding, padding], [0, 0]]

    remainder = inp.shape[0] % strides
    pad_total = max(0, kernel_size - (strides if remainder == 0 else remainder))
    pad_left = int(np.floor(pad_total / 2.0))
    pad_right = int(np.ceil(pad_total / 2.0))
    inp_pad = np.pad(inp, ((pad_left, pad_right), (0, 0)), 'constant', constant_values=const_val)
    return inp_pad, [[pad_left, pad_right], [0, 0]]
927
+
928
def calculate_wt_conv_unit_1d(patch, wts_pos, wts_neg, w, b, act):
    """Split one 1-D conv output position's relevance over its input window.

    ``w`` / ``b`` must expose ``.numpy()``. ``patch`` must match the first two
    kernel axes for the einsum. ``act`` is accepted but unused.

    Returns:
        ``(wt_mat_pos, wt_mat_neg)``, shaped like the kernel without its last axis.
    """
    kernel = w.numpy()
    bias = b.numpy()
    contrib = np.einsum("ijk,ij->ijk", kernel, patch)
    pos_part = contrib * (contrib > 0)
    neg_part = contrib * (contrib < 0)
    pos_total = np.einsum("ijk->k", pos_part)
    neg_total = np.einsum("ijk->k", neg_part) * -1.0
    (p_agg_pos, p_agg_neg, n_agg_pos, n_agg_neg,
     pos_total, neg_total) = calculate_base_wt_array(pos_total, neg_total, bias, wts_pos, wts_neg)
    pos_share = pos_part / pos_total
    neg_share = neg_part / neg_total
    base = np.zeros_like(kernel)
    wt_mat_pos = base + pos_share * p_agg_pos + (neg_share * n_agg_pos) * -1.0
    wt_mat_neg = base + pos_share * p_agg_neg + (neg_share * n_agg_neg) * -1.0
    return np.sum(wt_mat_pos, axis=-1), np.sum(wt_mat_neg, axis=-1)
949
+
950
def calculate_wt_conv_1d(wts_pos, wts_neg, inp, w, b, padding, stride, act):
    """Propagate positive/negative relevance back through a 1-D convolution.

    ``wts_pos``, ``wts_neg`` and ``inp`` are NumPy arrays whose axes are
    reversed with ``.T`` before use (presumably channel-first; TODO confirm).
    ``w`` / ``b`` are the layer's tensors; ``calculate_wt_conv_unit_1d``
    calls ``.numpy()`` on them. ``stride`` and ``padding`` may be plain values
    or PyTorch-style 1-tuples.

    Returns:
        ``(out_ds_pos, out_ds_neg)`` cropped back to the transposed ``inp`` length.
    """
    if isinstance(stride, (tuple, list)) and len(stride) == 1:
        stride = stride[0]
    if isinstance(padding, (tuple, list)) and len(padding) == 1:
        padding = padding[0]
    wts_pos = wts_pos.T
    wts_neg = wts_neg.T
    inp = inp.T
    # Reverse the kernel axes. ``Tensor.T`` on a non-2-D tensor is deprecated
    # in PyTorch; ``permute`` with reversed axes is the documented equivalent.
    w = w.permute(*range(w.ndim - 1, -1, -1)) if hasattr(w, "permute") else w.T
    kernel_len = w.shape[0]
    input_padded, paddings = calculate_padding_1d(kernel_len, inp, padding, stride)
    # Accumulate in a floating dtype so integer inputs are not truncated.
    acc_dtype = np.result_type(input_padded.dtype, np.float32)
    out_ds_pos = np.zeros(input_padded.shape, dtype=acc_dtype)
    out_ds_neg = np.zeros(input_padded.shape, dtype=acc_dtype)
    for ind in range(wts_pos.shape[0]):
        indexes = np.arange(ind * stride, ind * stride + kernel_len)
        updates_pos, updates_neg = calculate_wt_conv_unit_1d(
            input_padded[indexes], wts_pos[ind, :], wts_neg[ind, :], w, b, act)
        out_ds_pos[indexes] += updates_pos
        out_ds_neg[indexes] += updates_neg
    start = paddings[0][0]
    stop = start + inp.shape[0]
    return out_ds_pos[start:stop], out_ds_neg[start:stop]
969
+
970
def calculate_wt_max_unit_1d(patch, wts, pool_size):
    """Route each window column's relevance to its maximal entries.

    Ties share the column's relevance equally. ``pool_size`` is unused.
    The max test is ``patch - max == 0``, so a column that is entirely
    ``-inf`` yields NaN rather than a share.
    """
    column_max = np.max(patch, axis=0)
    winners = ((patch - column_max) == 0).astype(np.float32)
    share = 1.0 / np.sum(winners, axis=0)
    winners = np.einsum("ij,j->ij", winners, share)
    return np.einsum("ij,j->ij", winners, wts)
978
+
979
def calculate_wt_maxpool_1d(wts, inp, pool_size, padding, strides):
    """Propagate relevance back through a 1-D max-pooling layer.

    ``wts`` and ``inp`` have their axes reversed with ``.T``. The input is
    padded with ``-inf`` so padded cells never win a window. Each output
    row's relevance goes to the maxima of its window, and the result is
    cropped back to the transposed ``inp`` length.
    """
    wts = wts.T
    inp = inp.T
    padded, pads = calculate_padding_1d(pool_size, inp, padding, strides, -np.inf)
    relevance = np.zeros_like(padded)
    for pos in range(wts.shape[0]):
        window = np.arange(pos * strides, pos * strides + pool_size)
        relevance[window] += calculate_wt_max_unit_1d(padded[window], wts[pos, :], pool_size)
    start = pads[0][0]
    return relevance[start:start + inp.shape[0]]
993
+
994
def calculate_wt_avg_unit_1d(patch, wts_pos, wts_neg, pool_size):
    """Split one average-pool window's relevance over its inputs.

    Positive and negative inputs share the routed relevance in proportion to
    their value, using ``calculate_base_wt_array`` with zero bias.
    ``pool_size`` is unused.

    Returns:
        ``(wt_mat_pos, wt_mat_neg)``, shaped like ``patch``.
    """
    pos_part = patch * (patch > 0)
    neg_part = patch * (patch < 0)
    pos_total = np.sum(pos_part, axis=0)
    neg_total = np.sum(neg_part, axis=0) * -1.0
    zero_bias = np.zeros_like(wts_pos)
    (p_agg_pos, p_agg_neg, n_agg_pos, n_agg_neg,
     pos_total, neg_total) = calculate_base_wt_array(pos_total, neg_total, zero_bias, wts_pos, wts_neg)
    pos_share = pos_part / pos_total
    neg_share = neg_part / neg_total
    base = np.zeros_like(patch)
    wt_mat_pos = base + pos_share * p_agg_pos + (neg_share * n_agg_pos) * -1.0
    wt_mat_neg = base + pos_share * p_agg_neg + (neg_share * n_agg_neg) * -1.0
    return wt_mat_pos, wt_mat_neg
1011
+
1012
def calculate_wt_avgpool_1d(wts_pos, wts_neg, inp, pool_size, padding, strides, act={}):
    """Propagate relevance back through a 1-D average-pooling layer.

    ``pool_size``, ``padding`` and ``strides`` are indexed at ``[0]``, so
    they are expected as 1-element sequences. ``act`` is unused. Arrays have
    their axes reversed with ``.T``, and the result is cropped back to the
    transposed ``inp`` length.
    """
    wts_pos, wts_neg, inp = wts_pos.T, wts_neg.T, inp.T
    window, step = pool_size[0], strides[0]
    input_padded, paddings = calculate_padding_1d(window, inp, padding[0], step)
    out_ds_pos = np.zeros_like(input_padded)
    out_ds_neg = np.zeros_like(input_padded)
    for ind in range(wts_pos.shape[0]):
        rows = np.arange(ind * step, ind * step + window)
        upd_pos, upd_neg = calculate_wt_avg_unit_1d(
            input_padded[rows], wts_pos[ind, :], wts_neg[ind, :], window)
        out_ds_pos[rows] += upd_pos
        out_ds_neg[rows] += upd_neg
    start = paddings[0][0]
    stop = start + inp.shape[0]
    return out_ds_pos[start:stop], out_ds_neg[start:stop]
1031
+
1032
def calculate_wt_gavgpool_1d(wts_pos,wts_neg,inp):
    """Propagate relevance back through 1-D global average pooling.

    After ``.T``, channel ``c`` is the last axis of ``inp``, and
    ``wts_pos[c]`` / ``wts_neg[c]`` is that channel's relevance. It is spread
    over the channel's positions in proportion to the positive and negative
    inputs:
    - only positives: both relevances follow the positive inputs;
    - only negatives: both follow the negative inputs (sign flipped);
    - otherwise: ``calculate_base_wt`` (bias 0) decides the routing.
    """
    wts_pos = wts_pos.T
    wts_neg = wts_neg.T
    inp = inp.T
    wt_mat_pos = np.zeros_like(inp)
    wt_mat_neg = np.zeros_like(inp)
    for c in range(wts_pos.shape[0]):
        wt_pos, wt_neg = wts_pos[c], wts_neg[c]
        x = inp[..., c]
        p_mat = np.where(x < 0, 0, x)
        n_mat = np.where(x > 0, 0, x)
        p_sum = np.sum(p_mat)
        n_sum = np.sum(n_mat) * -1
        if n_sum == 0 and p_sum > 0:
            add_pos = (p_mat / p_sum) * wt_pos
            add_neg = (p_mat / p_sum) * wt_neg
        elif n_sum > 0 and p_sum == 0:
            add_pos = (n_mat / n_sum) * wt_pos * -1
            add_neg = (n_mat / n_sum) * wt_neg * -1
        else:
            p_agg_wt, n_agg_wt, p_sum, n_sum, wt_sign = calculate_base_wt(p_sum=p_sum, n_sum=n_sum,
                                                                           bias=0,
                                                                           wt_pos=wt_pos, wt_neg=wt_neg)
            from_pos = (p_mat / p_sum) * p_agg_wt
            from_neg = (n_mat / n_sum) * n_agg_wt * -1
            if wt_sign > 0:
                add_pos, add_neg = from_pos, from_neg
            else:
                add_pos, add_neg = from_neg, from_pos
        wt_mat_pos[..., c] = wt_mat_pos[..., c] + add_pos
        wt_mat_neg[..., c] = wt_mat_neg[..., c] + add_neg
    return wt_mat_pos,wt_mat_neg
1070
+
1071
def calculate_wt_gmaxpool_1d(wts, inp):
    """Propagate relevance back through 1-D global max pooling.

    After ``.T``, column ``c`` of ``inp`` is one channel. Its relevance
    ``wts[c]`` is shared equally among the positions holding the column
    maximum. Columns beyond ``wts.shape[0]`` stay zero.
    """
    wts = wts.T
    inp = inp.T
    wt_mat = np.zeros_like(inp)
    for c in range(wts.shape[0]):
        column = inp[:, c]
        is_max = (column == np.max(column)).astype(np.float32)
        wt_mat[:, c] = (is_max * (1.0 / np.sum(is_max))) * wts[c]
    return wt_mat
1085
+
1086
def calculate_output_padding_conv2d_transpose(input_shape, kernel_size, padding, strides):
    """Output size and crop amounts for a 2-D transposed convolution.

    For ``'valid'`` or a ``(0, 0)`` tuple, the full output
    ``(in - 1) * stride + kernel`` is returned with zero paddings. Any other
    ``padding`` is handled as "same": the output is ``in * stride``, and the
    overflow per axis is split ``[floor(p / 2), floor((p + 1) / 2)]``.

    Returns:
        ``(out_shape, paddings)``, where ``out_shape`` is a 2-element list.
    """
    no_crop = padding == 'valid' or (isinstance(padding, tuple) and padding == (0, 0))
    if no_crop:
        full = [(input_shape[0] - 1) * strides[0] + kernel_size[0],
                (input_shape[1] - 1) * strides[1] + kernel_size[1]]
        return (full, [[0, 0], [0, 0], [0, 0]])
    out_shape = [input_shape[0] * strides[0], input_shape[1] * strides[1]]
    excess = [max(0, (input_shape[a] - 1) * strides[a] + kernel_size[a] - out_shape[a])
              for a in (0, 1)]
    paddings = [np.floor([e / 2.0, (e + 1) / 2.0]).astype("int32") for e in excess]
    paddings.append(np.zeros((2)).astype("int32"))
    return (out_shape, paddings)
1103
+
1104
def calculate_wt_conv2d_transpose_unit(patch, wts_pos, wts_neg, w, b, act):
    """Split one transposed-conv input position's relevance over the kernel footprint.

    ``patch`` is lifted to 3-D (1-D becomes ``(1, 1, n)``, 2-D gains a leading
    axis). ``w`` must support ``.permute`` / ``.numpy()`` (a torch tensor) and
    ``b`` must expose ``.numpy()``. ``act`` is unused. The einsum sums
    ``patch`` over its first two axes before multiplying into the kernel.

    Raises:
        ValueError: if ``patch`` has more than 3 dimensions.
    """
    if patch.ndim == 1:
        patch = patch.reshape(1, 1, -1)
    elif patch.ndim == 2:
        patch = patch.reshape(1, *patch.shape)
    elif patch.ndim != 3:
        raise ValueError(f"Unexpected patch shape: {patch.shape}")
    kernel = w.permute(0, 1, 3, 2).numpy()
    bias = b.numpy()
    contrib = np.einsum('ijkl,mnk->ijkl', kernel, patch)
    pos_part = contrib * (contrib > 0)
    neg_part = contrib * (contrib < 0)
    pos_total = np.einsum("ijkl->l", pos_part)
    neg_total = np.einsum("ijkl->l", neg_part) * -1.0
    (p_agg_pos, p_agg_neg, n_agg_pos, n_agg_neg,
     pos_total, neg_total) = calculate_base_wt_array(pos_total, neg_total, bias, wts_pos, wts_neg)
    pos_share = pos_part / pos_total
    neg_share = neg_part / neg_total
    base = np.zeros_like(kernel)
    wt_mat_pos = base + pos_share * p_agg_pos + (neg_share * n_agg_pos) * -1.0
    wt_mat_neg = base + pos_share * p_agg_neg + (neg_share * n_agg_neg) * -1.0
    return np.sum(wt_mat_pos, axis=-1), np.sum(wt_mat_neg, axis=-1)
1134
+
1135
def _fold_transpose_relevance(out_ds, inp_shape, strides):
    """Sum each ``strides``-sized tile of ``out_ds`` onto the input cell that starts it."""
    folded = np.zeros(inp_shape)
    for i in range(inp_shape[0]):
        rows = slice(max(0, i * strides[0]), min(out_ds.shape[0], (i + 1) * strides[0]))
        for j in range(inp_shape[1]):
            cols = slice(max(0, j * strides[1]), min(out_ds.shape[1], (j + 1) * strides[1]))
            folded[i, j, :] = np.sum(out_ds[rows, cols, :], axis=(0, 1))
    return folded


def calculate_wt_conv2d_transpose(wts_pos, wts_neg, inp, w, b, padding, strides, act):
    """Propagate positive/negative relevance back through a 2-D transposed convolution.

    Arrays have their axes reversed with ``.T`` before use. Each input
    position's relevance is split by ``calculate_wt_conv2d_transpose_unit``
    and scattered into its strided kernel footprint, clipped to the output
    size.

    For ``'same'`` or ``(0, 0)`` padding, each stride-sized output tile is then
    summed back onto its input cell. Otherwise the result is cropped by the
    paddings from ``calculate_output_padding_conv2d_transpose``.

    Returns:
        ``(out_ds_pos, out_ds_neg)``.
    """
    wts_pos = wts_pos.T
    wts_neg = wts_neg.T
    inp = inp.T
    # Reverse the kernel axes. ``Tensor.T`` on a non-2-D tensor is deprecated
    # in PyTorch; ``permute`` with reversed axes is the documented equivalent.
    w = w.permute(*range(w.ndim - 1, -1, -1)) if hasattr(w, "permute") else w.T
    out_shape, paddings = calculate_output_padding_conv2d_transpose(inp.shape, w.shape, padding, strides)
    out_ds_pos = np.zeros(out_shape + [w.shape[3]])
    out_ds_neg = np.zeros(out_shape + [w.shape[3]])
    for ind1 in range(inp.shape[0]):
        out_ind1 = ind1 * strides[0]
        end_ind1 = min(out_ind1 + w.shape[0], out_shape[0])
        for ind2 in range(inp.shape[1]):
            out_ind2 = ind2 * strides[1]
            end_ind2 = min(out_ind2 + w.shape[1], out_shape[1])
            updates_pos, updates_neg = calculate_wt_conv2d_transpose_unit(
                inp[ind1, ind2, :], wts_pos[ind1, ind2, :], wts_neg[ind1, ind2, :], w, b, act)
            # Keep only the part of the footprint that lies inside the output.
            out_ds_pos[out_ind1:end_ind1, out_ind2:end_ind2, :] += \
                updates_pos[:end_ind1 - out_ind1, :end_ind2 - out_ind2, :]
            out_ds_neg[out_ind1:end_ind1, out_ind2:end_ind2, :] += \
                updates_neg[:end_ind1 - out_ind1, :end_ind2 - out_ind2, :]

    # 'same' and (0, 0) used two identical copies of this fold loop.
    if padding == 'same' or (isinstance(padding, tuple) and padding == (0, 0)):
        return (_fold_transpose_relevance(out_ds_pos, inp.shape, strides),
                _fold_transpose_relevance(out_ds_neg, inp.shape, strides))

    top, left = paddings[0][0], paddings[1][0]
    out_ds_pos = out_ds_pos[top:top + inp.shape[0], left:left + inp.shape[1], :]
    out_ds_neg = out_ds_neg[top:top + inp.shape[0], left:left + inp.shape[1], :]
    return out_ds_pos, out_ds_neg
1194
+
1195
def calculate_output_padding_conv1d_transpose(input_shape, kernel_size, padding, strides):
    """Compute the output length and cropping of a 1-D transposed convolution.

    Args:
        input_shape: sequence whose first entry is the input length.
        kernel_size: sequence whose first entry is the kernel length.
        padding: ``'valid'`` or ``0`` means no padding; any other value is
            handled as ``'same'``.
        strides: integer stride.

    Returns:
        ``(out_shape, paddings)``. ``out_shape`` is a one-element list holding
        the output length.
        - No padding: the length is ``(L - 1) * strides + K`` and ``paddings``
          is the list ``[0, 0]``.
        - Otherwise: the length is ``L * strides`` and ``paddings`` is an int32
          numpy array ``[floor(p / 2), floor((p + 1) / 2)]``, where ``p`` is
          the non-negative surplus of the full length over ``L * strides``.
    """
    length = input_shape[0]
    kernel = kernel_size[0]
    full_length = (length - 1) * strides + kernel

    # 'valid' and 0 both mean "no padding": keep the full transposed length.
    if padding in ('valid', 0):
        return ([full_length], [0, 0])

    # 'same': the output is exactly stride times longer than the input; the
    # surplus is split with the extra element (if any) on the right.
    target = length * strides
    surplus = max(0, full_length - target)
    halves = np.floor([surplus / 2.0, (surplus + 1) / 2.0]).astype("int32")
    return ([target], halves)
1207
+
1208
def calculate_wt_conv1d_transpose_unit(patch, wts_pos, wts_neg, w, b, act):
    """Distribute relevance for one input position of a 1-D transposed conv.

    Args:
        patch: a 1-D input vector (promoted to a single row) or a 2-D array.
            Its last axis must match the second axis of the permuted weight.
            Rows of a 2-D patch are summed by the einsum below.
        wts_pos, wts_neg: positive / negative relevance for this position,
            forwarded to ``calculate_base_wt_array``.
        w: 3-D weight, used with axes permuted as ``(0, 2, 1)``. May be a torch
            tensor (any device, with or without grad) or a numpy array.
        b: bias vector, as a torch tensor or numpy array.
        act: unused; kept for signature compatibility.

    Returns:
        ``(wt_mat_pos, wt_mat_neg)``: arrays shaped like the first two axes of
        the permuted weight.

    Raises:
        ValueError: if ``patch`` is neither 1-D nor 2-D.
    """
    if patch.ndim == 1:
        patch = patch.reshape(1, -1)
    elif patch.ndim != 2:
        raise ValueError(f"Unexpected patch shape: {patch.shape}")

    def to_numpy(t):
        # Tensor.numpy() refuses tensors that require grad or live off the CPU,
        # so detach and move torch tensors first. numpy input passes through.
        if hasattr(t, "detach"):
            return t.detach().cpu().numpy()
        return np.asarray(t)

    if hasattr(w, "permute"):
        k = to_numpy(w.permute(0, 2, 1))
    else:
        k = np.transpose(np.asarray(w), (0, 2, 1))
    bias = to_numpy(b)

    # conv_out[i, j, c] = k[i, j, c] * sum_m patch[m, j]
    conv_out = np.einsum('ijk,mj->ijk', k, patch)
    p_ind = conv_out * (conv_out > 0)  # positive contributions, zeros elsewhere
    n_ind = conv_out * (conv_out < 0)  # negative contributions, zeros elsewhere
    p_sum = np.einsum("ijk->k", p_ind)
    n_sum = np.einsum("ijk->k", n_ind) * -1.0  # magnitude of the negative part

    (p_agg_wt_pos, p_agg_wt_neg,
     n_agg_wt_pos, n_agg_wt_neg,
     p_sum, n_sum) = calculate_base_wt_array(p_sum, n_sum, bias, wts_pos, wts_neg)

    # Accumulate in place into arrays of k's dtype (the result keeps that dtype).
    # NOTE(review): any zero in p_sum / n_sum yields 0/0 = NaN below unless
    # calculate_base_wt_array replaces it -- confirm.
    wt_mat_pos = np.zeros_like(k)
    wt_mat_neg = np.zeros_like(k)

    wt_mat_pos += (p_ind / p_sum) * p_agg_wt_pos
    wt_mat_pos += (n_ind / n_sum) * n_agg_wt_pos * -1.0
    wt_mat_neg += (p_ind / p_sum) * p_agg_wt_neg
    wt_mat_neg += (n_ind / n_sum) * n_agg_wt_neg * -1.0

    # Collapse the last (per-output-channel) axis.
    wt_mat_pos = np.sum(wt_mat_pos, axis=-1)
    wt_mat_neg = np.sum(wt_mat_neg, axis=-1)

    return wt_mat_pos, wt_mat_neg
1242
+
1243
def calculate_wt_conv1d_transpose(wts_pos, wts_neg, inp, w, b, padding, strides, act):
    """Propagate positive/negative relevance through a 1-D transposed convolution.

    ``wts_pos``, ``wts_neg`` and ``inp`` are transposed first, so that row
    ``ind`` is one position. This presumably assumes the caller passes
    channel-first 2-D arrays -- TODO confirm. Every axis of ``w`` is reversed.

    Args:
        wts_pos, wts_neg: relevance arrays. After transposition, row ``ind``
            is handed to the unit function together with input row ``ind``.
        inp: layer input.
        w: 3-D weight (torch tensor or numpy array); all axes are reversed.
        b: bias, forwarded to ``calculate_wt_conv1d_transpose_unit``.
        padding: ``'same'``, ``0``, or another value (see the crop branch).
        strides: integer stride.
        act: forwarded unchanged to the unit function.

    Returns:
        ``(out_ds_pos, out_ds_neg)``.
        - For ``'same'`` / ``0``: both have the transposed input's shape.
        - Otherwise: rows ``[paddings[0], paddings[0] + len(inp))`` of the
          accumulated output.
    """
    wts_pos = wts_pos.T
    wts_neg = wts_neg.T
    inp = inp.T
    # Reverse every axis of the weight. Tensor.T on a torch tensor that is not
    # 2-D is deprecated, so torch weights use the equivalent explicit permute.
    if hasattr(w, "permute"):
        w = w.permute(*range(w.ndim - 1, -1, -1))
    else:
        w = w.T
    out_shape, paddings = calculate_output_padding_conv1d_transpose(
        inp.shape, w.shape, padding, strides
    )
    out_ds_pos = np.zeros(out_shape + [w.shape[2]])
    out_ds_neg = np.zeros(out_shape + [w.shape[2]])

    # Add each input position's contribution into output rows
    # [ind * strides, ind * strides + w.shape[0]), clipped to the output length.
    for ind in range(inp.shape[0]):
        out_ind = ind * strides
        updates_pos, updates_neg = calculate_wt_conv1d_transpose_unit(
            inp[ind, :], wts_pos[ind, :], wts_neg[ind, :], w, b, act
        )
        end_ind = min(out_ind + w.shape[0], out_shape[0])
        n_valid = end_ind - out_ind
        out_ds_pos[out_ind:end_ind, :] += updates_pos[:n_valid, :]
        out_ds_neg[out_ind:end_ind, :] += updates_neg[:n_valid, :]

    if padding == 'same' or padding == 0:
        # Both cases fold the output back to the input length by summing rows
        # [i * strides, (i + 1) * strides) into row i. Rows past
        # inp.shape[0] * strides are dropped.
        n_out = out_ds_pos.shape[0]
        adjusted_pos = np.zeros(inp.shape)
        adjusted_neg = np.zeros(inp.shape)
        for i in range(inp.shape[0]):
            start_i = max(0, i * strides)
            end_i = min(n_out, (i + 1) * strides)
            adjusted_pos[i, :] = np.sum(out_ds_pos[start_i:end_i, :], axis=0)
            adjusted_neg[i, :] = np.sum(out_ds_neg[start_i:end_i, :], axis=0)
        out_ds_pos, out_ds_neg = adjusted_pos, adjusted_neg
    else:
        # NOTE(review): calculate_output_padding_conv1d_transpose treats
        # 'valid' exactly like 0, but 'valid' lands here (a plain crop with no
        # stride folding) -- confirm this asymmetry is intended.
        start = paddings[0]
        stop = paddings[0] + inp.shape[0]
        out_ds_pos = out_ds_pos[start:stop, :]
        out_ds_neg = out_ds_neg[start:stop, :]

    return out_ds_pos, out_ds_neg