dl-backtrace 0.0.17-py3-none-any.whl → 0.0.19-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of dl-backtrace might be problematic.
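The diff below replaces the package's TensorFlow/Keras backend (tf, K) with PyTorch, and several converted call sites look suspect: torch.sigmoid and torch.tanh are plain functions, so assigning torch.sigmoid() (called with no argument) raises a TypeError, and torch.dot accepts only 1-D tensors, while the K.dot calls it replaces performed 2-D matrix products. A minimal sketch of the PyTorch idioms the conversion presumably intended (hypothetical shapes, not code from the package):

    import torch

    units = 4
    h_tm1 = torch.zeros(1, units)          # (1, units) hidden state, as in the LSTM code
    w = torch.randn(units, 4 * units)      # recurrent kernel holding the i/f/c/o blocks

    recurrent_activation = torch.sigmoid   # assign the function itself; torch.sigmoid() raises TypeError
    activation = torch.tanh

    # torch.dot is 1-D only; torch.matmul is the 2-D equivalent of K.dot
    i = recurrent_activation(torch.matmul(h_tm1, w[:, :units]))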

@@ -1,42 +1,32 @@
 import gc
-
+import torch
 import numpy as np
-import tensorflow as tf
 from numpy.lib.stride_tricks import as_strided
-from tensorflow.keras import backend as K
-
 
 def np_swish(x, beta=0.75):
     z = 1 / (1 + np.exp(-(beta * x)))
     return x * z
 
-
 def np_wave(x, alpha=1.0):
     return (alpha * x * np.exp(1.0)) / (np.exp(-x) + np.exp(x))
 
-
 def np_pulse(x, alpha=1.0):
     return alpha * (1 - np.tanh(x) * np.tanh(x))
 
-
 def np_absolute(x, alpha=1.0):
     return alpha * x * np.tanh(x)
 
-
 def np_hard_sigmoid(x):
     return np.clip(0.2 * x + 0.5, 0, 1)
 
-
 def np_sigmoid(x):
     z = 1 / (1 + np.exp(-x))
     return z
 
-
 def np_tanh(x):
     z = np.tanh(x)
     return z.astype(np.float32)
 
-
 class LSTM_forward(object):
     def __init__(
         self, num_cells, units, weights, return_sequence=False, go_backwards=False
@@ -48,8 +38,8 @@ class LSTM_forward(object):
         self.bias = weights[2][1]
         self.return_sequence = return_sequence
         self.go_backwards = go_backwards
-        self.recurrent_activation = tf.math.sigmoid
-        self.activation = tf.math.tanh
+        self.recurrent_activation = torch.sigmoid()
+        self.activation = torch.tanh()
         self.compute_log = {}
         for i in range(self.num_cells):
             self.compute_log[i] = {}
@@ -63,23 +53,19 @@ class LSTM_forward(object):
         """Computes carry and output using split kernels."""
         x_i, x_f, x_c, x_o = x
         h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o = h_tm1
-        #print(self.recurrent_kernel[1][:, : self.units].shape)
-        #print(h_tm1_i.shape,self.recurrent_kernel[1][:, : self.units].shape)
-        w=tf.convert_to_tensor(self.recurrent_kernel[1], dtype=tf.float32)
-        #print(K.dot(h_tm1_i, w[:, : self.units]))
-
+        w=torch.as_tensor(self.recurrent_kernel[1], dtype=torch.float32)
         i = self.recurrent_activation(
-            x_i + K.dot(h_tm1_i, w[:, : self.units])
+            x_i + torch.dot(h_tm1_i, w[:, : self.units])
         )
         f = self.recurrent_activation(
-            x_f + K.dot(h_tm1_f, w[:, self.units : self.units * 2])
+            x_f + torch.dot(h_tm1_f, w[:, self.units : self.units * 2])
         )
         c = f * c_tm1 + i * self.activation(
             x_c
-            + K.dot(h_tm1_c, w[:, self.units * 2 : self.units * 3])
+            + torch.dot(h_tm1_c, w[:, self.units * 2 : self.units * 3])
         )
         o = self.recurrent_activation(
-            x_o + K.dot(h_tm1_o, w[:, self.units * 3 :])
+            x_o + torch.dot(h_tm1_o, w[:, self.units * 3 :])
         )
         self.compute_log[cell_num]["int_arrays"]["i"] = i
         self.compute_log[cell_num]["int_arrays"]["f"] = f
@@ -97,16 +83,16 @@ class LSTM_forward(object):
         inputs_f = inputs
         inputs_c = inputs
         inputs_o = inputs
-        k_i, k_f, k_c, k_o = tf.split(self.kernel[1], num_or_size_splits=4, axis=1)
-        x_i = K.dot(inputs_i, k_i)
-        x_f = K.dot(inputs_f, k_f)
-        x_c = K.dot(inputs_c, k_c)
-        x_o = K.dot(inputs_o, k_o)
-        b_i, b_f, b_c, b_o = tf.split(self.bias, num_or_size_splits=4, axis=0)
-        x_i = tf.add(x_i, b_i)
-        x_f = tf.add(x_f, b_f)
-        x_c = tf.add(x_c, b_c)
-        x_o = tf.add(x_o, b_o)
+        k_i, k_f, k_c, k_o = torch.split(self.kernel[1],self.kernel.size(1)//4,dim=1)
+        x_i = torch.dot(inputs_i, k_i)
+        x_f = torch.dot(inputs_f, k_f)
+        x_c = torch.dot(inputs_c, k_c)
+        x_o = torch.dot(inputs_o, k_o)
+        b_i, b_f, b_c, b_o = torch.split(self.bias,self.bias.size(1)//4,dim=0)
+        x_i = x_i + b_i
+        x_f = x_f + b_f
+        x_c = x_c + b_c
+        x_o = x_o + b_o
 
         h_tm1_i = h_tm1
         h_tm1_f = h_tm1
@@ -123,12 +109,12 @@ class LSTM_forward(object):
         return h, [h, c]
 
     def calculate_lstm_wt(self, input_data):
-        hstate = tf.convert_to_tensor(np.zeros((1, self.units)), dtype=tf.float32)
-        cstate = tf.convert_to_tensor(np.zeros((1, self.units)), dtype=tf.float32)
+        hstate = torch.tensor((1,self.units),dtype=torch.float32)
+        cstate = torch.tensor((1,self.units),dtype=torch.float32)
         output = []
         for ind in range(input_data.shape[0]):
-            inp = tf.convert_to_tensor(
-                input_data[ind, :].reshape((1, input_data.shape[1])), dtype=tf.float32
+            inp = torch.tensor(
+                input_data[ind, :].reshape((1, input_data.shape[1])), dtype=torch.float32
             )
             h, s = self.calculate_lstm_cell_wt(inp, [hstate, cstate], ind)
             hstate = s[0]
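Note that torch.tensor((1, self.units)) above builds a 1-D tensor holding the two values 1 and self.units, not the zero state of shape (1, units) produced by the tf.convert_to_tensor(np.zeros((1, self.units))) it replaces. A sketch of the presumably intended initialization (hypothetical unit count):

    import torch

    units = 8
    bad = torch.tensor((1, units))                        # tensor([1, 8]): a 2-element vector, not a state
    hstate = torch.zeros(1, units, dtype=torch.float32)   # zero state of shape (1, 8)
    cstate = torch.zeros(1, units, dtype=torch.float32)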
@@ -136,9 +122,6 @@ class LSTM_forward(object):
             output.append(h)
         return output
 
-
-
-
 class LSTM_backtrace(object):
     def __init__(
         self, num_cells, units, weights, return_sequence=False, go_backwards=False
@@ -270,8 +253,6 @@ class LSTM_backtrace(object):
         x_i, x_f, x_c, x_o = x
         f = self.compute_log[cell_num]["int_arrays"]["f"].numpy()[0]
         i = self.compute_log[cell_num]["int_arrays"]["i"].numpy()[0]
-        # o = self.recurrent_activation(
-        #     x_o + np.dot(h_tm1_o, self.recurrent_kernel[:, self.units * 3:])).astype(np.float32)
         temp1 = np.dot(h_tm1_o, self.recurrent_kernel[1][:, self.units * 3 :]).astype(
             np.float32
         )
@@ -283,9 +264,6 @@ class LSTM_backtrace(object):
             [],
             {"type": None},
         )
-
-        # c = f * c_tm1 + i * self.activation(x_c + np.dot(
-        #     h_tm1_c, self.recurrent_kernel[:, self.units * 2:self.units * 3])).astype(np.float32)
         temp2 = f * c_tm1
         temp3_1 = np.dot(
             h_tm1_c, self.recurrent_kernel[1][:, self.units * 2 : self.units * 3]
@@ -303,9 +281,6 @@ class LSTM_backtrace(object):
             [],
             {"type": None},
         )
-
-        # f = self.recurrent_activation(x_f + np.dot(
-        #     h_tm1_f, self.recurrent_kernel[:, self.units:self.units * 2])).astype(np.float32)
         temp4 = np.dot(h_tm1_f, self.recurrent_kernel[1][:, self.units : self.units * 2])
         wt_x_f, wt_temp4 = self.calculate_wt_add(wt_f, [x_f, temp4])
         wt_h_tm1_f = self.calculate_wt_fc(
@@ -315,9 +290,6 @@ class LSTM_backtrace(object):
             [],
             {"type": None},
         )
-
-        # i = self.recurrent_activation(
-        #     x_i + np.dot(h_tm1_i, self.recurrent_kernel[:, :self.units])).astype(np.float32)
         temp5 = np.dot(h_tm1_i, self.recurrent_kernel[1][:, : self.units])
         wt_x_i, wt_temp5 = self.calculate_wt_add(wt_i, [x_i, temp5])
         wt_h_tm1_i = self.calculate_wt_fc(
@@ -364,7 +336,6 @@ class LSTM_backtrace(object):
         wt_h_tm1 = wt_h_tm1_i + wt_h_tm1_f + wt_h_tm1_c + wt_h_tm1_o
         inputs = self.compute_log[cell_num]["inp"].numpy()[0]
 
-        #print(np.split(self.kernel[1], indices_or_sections=4, axis=1))
         k_i, k_f, k_c, k_o = np.split(self.kernel[1], indices_or_sections=4, axis=1)
         b_i, b_f, b_c, b_o = np.split(self.bias[1], indices_or_sections=4, axis=0)
 
@@ -395,12 +366,10 @@ class LSTM_backtrace(object):
             output.reverse()
         return np.array(output)
 
-
 def dummy_wt(wts, inp, *args):
     test_wt = np.zeros_like(inp)
     return test_wt
 
-
 def calculate_wt_fc(wts, inp, w, b, act):
     mul_mat = np.einsum("ij,i->ij", w.numpy().T, inp).T
     wt_mat = np.zeros(mul_mat.shape)
@@ -461,12 +430,10 @@ def calculate_wt_fc(wts, inp, w, b, act):
     wt_mat = wt_mat.sum(axis=0)
     return wt_mat
 
-
 def calculate_wt_rshp(wts, inp=None):
     x = np.reshape(wts, inp.shape)
     return x
 
-
 def calculate_wt_concat(wts, inp=None, axis=-1):
     wts=wts.T
     splits = [i.shape[axis] for i in inp]
@@ -476,7 +443,6 @@ def calculate_wt_concat(wts, inp=None, axis=-1):
     x = np.split(wts, indices_or_sections=splits, axis=axis)
     return x
 
-
def calculate_wt_add(wts, inp=None):
     wts=wts.T
     wt_mat = []
@@ -523,199 +489,231 @@ def calculate_wt_add(wts, inp=None):
     wt_mat = [i.reshape(wts.shape) for i in list(wt_mat)]
     return wt_mat
 
-
-def calculate_start_wt(arg):
-    x = np.argmax(arg[0])
-    y = np.zeros(arg.shape)
-    y[0][x] = 1
+def calculate_start_wt(arg, scaler=None,thresholding=0.5,task="binary-classification"):
+    if arg.ndim == 2:
+        if task == "binary-classification" or task == "multi-class classification":
+            x = np.argmax(arg[0])
+            m = np.max(arg[0])
+            y = np.zeros(arg.shape)
+            if scaler:
+                y[0][x] = scaler
+            else:
+                y[0][x] = m
+        elif task == "bbox-regression":
+            y = np.zeros(arg.shape)
+            if scaler:
+                y[0] = scaler
+                num_non_zero_elements = np.count_nonzero(y)
+                if num_non_zero_elements > 0:
+                    y = y / num_non_zero_elements
+            else:
+                m = np.max(arg[0])
+                x = np.argmax(arg[0])
+                y[0][x] = m
+        else:
+            x = np.argmax(arg[0])
+            m = np.max(arg[0])
+            y = np.zeros(arg.shape)
+            if scaler:
+                y[0][x] = scaler
+            else:
+                y[0][x] = m
+
+    elif arg.ndim == 4 and task == "binary-segmentation":
+        indices = np.where(arg > thresholding)
+        y = np.zeros(arg.shape)
+        if scaler:
+            y[indices] = scaler
+            num_non_zero_elements = np.count_nonzero(y)
+            if num_non_zero_elements > 0:
+                y = y / num_non_zero_elements
+        else:
+            y[indices] = arg[indices]
+
+    else:
+        x = np.argmax(arg[0])
+        m = np.max(arg[0])
+        y = np.zeros(arg.shape)
+        if scaler:
+            y[0][x] = scaler
+        else:
+            y[0][x] = m
     return y[0]
 
-
 def calculate_wt_passthru(wts):
     return wts
+def calculate_wt_zero_pad(wts,inp,padding):
+    wt_mat = wts[padding[0][0]:inp.shape[0]+padding[0][0],padding[1][0]:inp.shape[1]+padding[1][0],:]
+    return wt_mat
 
+def calculate_padding(kernel_size, inp, padding, strides, const_val=0.0):
+    if padding=='valid':
+        return (inp, [[0,0],[0,0],[0,0]])
+    elif padding == 'same':
+        h = inp.shape[0]%strides[0]
+        if h==0:
+            pad_h = np.max([0,kernel_size[0]-strides[0]])
+        else:
+            pad_h = np.max([0,kernel_size[0]-h])
 
-def calculate_wt_conv_unit(wt, p_mat, n_mat, t_sum, p_sum, n_sum, act):
-    wt_mat = np.zeros_like(p_mat)
-    if act["type"] == "mono":
+        v = inp.shape[1]%strides[1]
+        if v==0:
+            pad_v = np.max([0,kernel_size[1]-strides[1]])
+        else:
+            pad_v = np.max([0,kernel_size[1]-v])
+
+        paddings = [np.floor([pad_h/2.0,(pad_h+1)/2.0]).astype("int32"),
+                    np.floor([pad_v/2.0,(pad_v+1)/2.0]).astype("int32"),
+                    np.zeros((2)).astype("int32")]
+        inp_pad = np.pad(inp, paddings, 'constant', constant_values=const_val)
+        return (inp_pad,paddings)
+    else:
+        if isinstance(padding, tuple) and padding != (None, None):
+            pad_h = padding[0]
+            pad_v = padding[1]
+            paddings = [np.floor([pad_h,pad_h]).astype("int32"),
+                        np.floor([pad_v,pad_v]).astype("int32"),
+                        np.zeros((2)).astype("int32")]
+            inp_pad = np.pad(inp, paddings, 'constant', constant_values=const_val)
+            return (inp_pad,paddings)
+        else:
+            return (inp, [[0,0],[0,0],[0,0]])
+
+def calculate_wt_conv_unit(patch, wts, w, b, act):
+    k = w.numpy()
+    bias = b.numpy()
+    b_ind = bias>0
+    bias_pos = bias*b_ind
+    b_ind = bias<0
+    bias_neg = bias*b_ind*-1.0
+    conv_out = np.einsum("ijkl,ijk->ijkl",k,patch)
+    p_ind = conv_out>0
+    p_ind = conv_out*p_ind
+    p_sum = np.einsum("ijkl->l",p_ind)
+    n_ind = conv_out<0
+    n_ind = conv_out*n_ind
+    n_sum = np.einsum("ijkl->l",n_ind)*-1.0
+    t_sum = p_sum+n_sum
+    wt_mat = np.zeros_like(k)
+    p_saturate = p_sum>0
+    n_saturate = n_sum>0
+    if act["type"]=='mono':
         if act["range"]["l"]:
-            if t_sum < act["range"]["l"]:
-                p_sum = 0
+            temp_ind = t_sum > act["range"]["l"]
+            p_saturate = temp_ind
         if act["range"]["u"]:
-            if t_sum > act["range"]["u"]:
-                n_sum = 0
-    elif act["type"] == "non_mono":
+            temp_ind = t_sum < act["range"]["u"]
+            n_saturate = temp_ind
+    elif act["type"]=='non_mono':
         t_act = act["func"](t_sum)
-        p_act = act["func"](p_sum)
-        n_act = act["func"](n_sum)
+        p_act = act["func"](p_sum + bias_pos)
+        n_act = act["func"](-1*(n_sum + bias_neg))
         if act["range"]["l"]:
-            if t_sum < act["range"]["l"]:
-                p_sum = 0
+            temp_ind = t_sum > act["range"]["l"]
+            p_saturate = p_saturate*temp_ind
         if act["range"]["u"]:
-            if t_sum > act["range"]["u"]:
-                n_sum = 0
-    if p_sum > 0 and n_sum > 0:
-        if t_act == p_act:
-            n_sum = 0
-        elif t_act == n_act:
-            p_sum = 0
-    p_agg_wt = 0.0
-    n_agg_wt = 0.0
-    if p_sum + n_sum > 0.0:
-        p_agg_wt = p_sum / (p_sum + n_sum)
-        n_agg_wt = n_sum / (p_sum + n_sum)
-    if p_sum == 0.0:
-        p_sum = 1.0
-    if n_sum == 0.0:
-        n_sum = 1.0
-    wt_mat = wt_mat + ((p_mat / p_sum) * wt * p_agg_wt)
-    wt_mat = wt_mat + ((n_mat / n_sum) * wt * n_agg_wt * -1.0)
+            temp_ind = t_sum < act["range"]["u"]
+            n_saturate = n_saturate*temp_ind
+        temp_ind = np.abs(t_act - p_act)>1e-5
+        n_saturate = n_saturate*temp_ind
+        temp_ind = np.abs(t_act - n_act)>1e-5
+        p_saturate = p_saturate*temp_ind
+    p_agg_wt = (1.0/(p_sum+n_sum+bias_pos+bias_neg))*wts*p_saturate
+    n_agg_wt = (1.0/(p_sum+n_sum+bias_pos+bias_neg))*wts*n_saturate
+
+    wt_mat = wt_mat+(p_ind*p_agg_wt)
+    wt_mat = wt_mat+(n_ind*n_agg_wt*-1.0)
+    wt_mat = np.sum(wt_mat,axis=-1)
     return wt_mat
 
-
-def calculate_wt_conv(wts, inp, w, b, act):
-    wts=wts.T
-    inp=inp.T
-    w=w.T
-    expanded_input = as_strided(
-        inp,
-        shape=(
-            inp.shape[0]
-            - w.numpy().shape[0]
-            + 1,  # The feature map is a few pixels smaller than the input
-            inp.shape[1] - w.numpy().shape[1] + 1,
-            inp.shape[2],
-            w.numpy().shape[0],
-            w.numpy().shape[1],
-        ),
-        strides=(
-            inp.strides[0],
-            inp.strides[1],
-            inp.strides[2],
-            inp.strides[
-                0
-            ],  # When we move one step in the 3rd dimension, we should move one step in the original data too
-            inp.strides[1],
-        ),
-        writeable=False,  # totally use this to avoid writing to memory in weird places
-    )
-    test_wt = np.einsum("mnc->cmn", np.zeros_like(inp), order="C", optimize=True)
-    for k in range(w.numpy().shape[-1]):
-        kernel = w.numpy()[:, :, :, k]
-        x = np.einsum(
-            "abcmn,mnc->abcmn", expanded_input, kernel, order="C", optimize=True
-        )
-        x_pos = x.copy()
-        x_neg = x.copy()
-        x_pos[x < 0] = 0
-        x_neg[x > 0] = 0
-        x_sum = np.einsum("abcmn->ab", x, order="C", optimize=True)
-        x_p_sum = np.einsum("abcmn->ab", x_pos, order="C", optimize=True)
-        x_n_sum = np.einsum("abcmn->ab", x_neg, order="C", optimize=True) * -1.0
-        # print(np.sum(x),np.sum(x_pos),np.sum(x_neg),np.sum(x_n_sum))
-        for ind1 in range(expanded_input.shape[0]):
-            for ind2 in range(expanded_input.shape[1]):
-                temp_wt_mat = calculate_wt_conv_unit(
-                    wts[ind1, ind2, k],
-                    x_pos[ind1, ind2, :, :, :],
-                    x_neg[ind1, ind2, :, :, :],
-                    x_sum[ind1, ind2],
-                    x_p_sum[ind1, ind2],
-                    x_n_sum[ind1, ind2],
-                    act,
-                )
-                test_wt[
-                    :, ind1 : ind1 + kernel.shape[0], ind2 : ind2 + kernel.shape[1]
-                ] += temp_wt_mat
-    test_wt = np.einsum("cmn->mnc", test_wt, order="C", optimize=True)
-    gc.collect()
-    return test_wt
-
-
-def get_max_index(mat=None):
-    max_ind = np.argmax(mat)
-    ind = []
-    rem = max_ind
-    for i in mat.shape[:-1]:
-        ind.append(rem // i)
-        rem = rem % i
-    ind.append(rem)
-    return tuple(ind)
-
-
-def calculate_wt_maxpool(wts, inp, pool_size):
+def calculate_wt_conv(wts, inp, w, b, padding, strides, act):
+    wts = wts.T
+    inp = inp.T
+    w = w.T
+    input_padded, paddings = calculate_padding(w.shape, inp, padding, strides)
+    out_ds = np.zeros_like(input_padded)
+    for ind1 in range(wts.shape[0]):
+        for ind2 in range(wts.shape[1]):
+            indexes = [np.arange(ind1*strides[0], ind1*(strides[0])+w.shape[0]),
+                       np.arange(ind2*strides[1], ind2*(strides[1])+w.shape[1])]
+            # Take slice
+            tmp_patch = input_padded[np.ix_(indexes[0],indexes[1])]
+            updates = calculate_wt_conv_unit(tmp_patch, wts[ind1,ind2,:], w, b, act)
+            # Build tensor with "filtered" gradient
+            out_ds[np.ix_(indexes[0],indexes[1])]+=updates
+    out_ds = out_ds[paddings[0][0]:(paddings[0][0]+inp.shape[0]),
+                    paddings[1][0]:(paddings[1][0]+inp.shape[1]),:]
+    return out_ds
+
+
+def calculate_wt_max_unit(patch, wts, pool_size):
+    pmax = np.einsum("ijk,k->ijk",np.ones_like(patch),np.max(np.max(patch,axis=0),axis=0))
+    indexes = (patch-pmax)==0
+    indexes = indexes.astype(np.float32)
+    indexes_norm = 1.0/np.einsum("mnc->c",indexes)
+    indexes = np.einsum("ijk,k->ijk",indexes,indexes_norm)
+    out = np.einsum("ijk,k->ijk",indexes,wts)
+    return out
+
+def calculate_wt_maxpool(wts, inp, pool_size, padding, strides):
     wts=wts.T
     inp=inp.T
-    pad1 = pool_size[0]
-    pad2 = pool_size[1]
-    test_samp_pad = np.pad(inp, ((0, pad1), (0, pad2), (0, 0)), "constant")
-    dim1, dim2, _ = wts.shape
-    test_wt = np.zeros_like(test_samp_pad)
-    for k in range(inp.shape[2]):
-        wt_mat = wts[:, :, k]
-        for ind1 in range(dim1):
-            for ind2 in range(dim2):
-                temp_inp = test_samp_pad[
-                    ind1 * pool_size[0] : (ind1 + 1) * pool_size[0],
-                    ind2 * pool_size[1] : (ind2 + 1) * pool_size[1],
-                    k,
-                ]
-                max_index = get_max_index(temp_inp)
-                test_wt[
-                    ind1 * pool_size[0] : (ind1 + 1) * pool_size[0],
-                    ind2 * pool_size[1] : (ind2 + 1) * pool_size[1],
-                    k,
-                ][max_index] = wt_mat[ind1, ind2]
-    test_wt = test_wt[0 : inp.shape[0], 0 : inp.shape[1], :]
-    return test_wt
-
+    strides = (strides,strides)
+    padding = (padding,padding)
+    input_padded, paddings = calculate_padding(pool_size, inp, padding, strides, -np.inf)
+    out_ds = np.zeros_like(input_padded)
+    for ind1 in range(wts.shape[0]):
+        for ind2 in range(wts.shape[1]):
+            indexes = [np.arange(ind1*strides[0], ind1*(strides[0])+pool_size[0]),
+                       np.arange(ind2*strides[1], ind2*(strides[1])+pool_size[1])]
+            tmp_patch = input_padded[np.ix_(indexes[0],indexes[1])]
+            updates = calculate_wt_max_unit(tmp_patch, wts[ind1,ind2,:], pool_size)
+            out_ds[np.ix_(indexes[0],indexes[1])]+=updates
+    out_ds = out_ds[paddings[0][0]:(paddings[0][0]+inp.shape[0]),
+                    paddings[1][0]:(paddings[1][0]+inp.shape[1]),:]
+    return out_ds
+
+
+def calculate_wt_avg_unit(patch, wts, pool_size):
+    p_ind = patch>0
+    p_ind = patch*p_ind
+    p_sum = np.einsum("ijk->k",p_ind)
+    n_ind = patch<0
+    n_ind = patch*n_ind
+    n_sum = np.einsum("ijk->k",n_ind)*-1.0
+    t_sum = p_sum+n_sum
+    wt_mat = np.zeros_like(patch)
+    p_saturate = p_sum>0
+    n_saturate = n_sum>0
+    t_sum[t_sum==0] = 1.0
+    p_agg_wt = (1.0/(t_sum))*wts*p_saturate
+    n_agg_wt = (1.0/(t_sum))*wts*n_saturate
+    wt_mat = wt_mat+(p_ind*p_agg_wt)
+    wt_mat = wt_mat+(n_ind*n_agg_wt*-1.0)
+    return wt_mat
 
-def calculate_wt_avgpool(wts, inp, pool_size):
+def calculate_wt_avgpool(wts, inp, pool_size, padding, strides):
     wts=wts.T
     inp=inp.T
 
     pad1 = pool_size[0]
     pad2 = pool_size[1]
-    test_samp_pad = np.pad(inp, ((0, pad1), (0, pad2), (0, 0)), "constant")
-    dim1, dim2, _ = wts.shape
-    test_wt = np.zeros_like(test_samp_pad)
-    for k in range(inp.shape[2]):
-        wt_mat = wts[:, :, k]
-        for ind1 in range(dim1):
-            for ind2 in range(dim2):
-                temp_inp = test_samp_pad[
-                    ind1 * pool_size[0] : (ind1 + 1) * pool_size[0],
-                    ind2 * pool_size[1] : (ind2 + 1) * pool_size[1],
-                    k,
-                ]
-                wt_ind1 = test_wt[
-                    ind1 * pool_size[0] : (ind1 + 1) * pool_size[0],
-                    ind2 * pool_size[1] : (ind2 + 1) * pool_size[1],
-                    k,
-                ]
-                wt = wt_mat[ind1, ind2]
-                p_ind = temp_inp > 0
-                n_ind = temp_inp < 0
-                p_sum = np.sum(temp_inp[p_ind])
-                n_sum = np.sum(temp_inp[n_ind]) * -1
-                if p_sum > 0:
-                    p_agg_wt = p_sum / (p_sum + n_sum)
-                else:
-                    p_agg_wt = 0
-                if n_sum > 0:
-                    n_agg_wt = n_sum / (p_sum + n_sum)
-                else:
-                    n_agg_wt = 0
-                if p_sum == 0:
-                    p_sum = 1
-                if n_sum == 0:
-                    n_sum = 1
-                wt_ind1[p_ind] += (temp_inp[p_ind] / p_sum) * wt * p_agg_wt
-                wt_ind1[n_ind] += (temp_inp[n_ind] / n_sum) * wt * n_agg_wt * -1.0
-    test_wt = test_wt[0 : inp.shape[0], 0 : inp.shape[1], :]
-    return test_wt
-
-
+    strides = (strides,strides)
+    padding = (padding,padding)
+    input_padded, paddings = calculate_padding(pool_size, inp, padding, strides, -np.inf)
+    out_ds = np.zeros_like(input_padded)
+    for ind1 in range(wts.shape[0]):
+        for ind2 in range(wts.shape[1]):
+            indexes = [np.arange(ind1*strides[0], ind1*(strides[0])+pool_size[0]),
+                       np.arange(ind2*strides[1], ind2*(strides[1])+pool_size[1])]
+            # Take slice
+            tmp_patch = input_padded[np.ix_(indexes[0],indexes[1])]
+            updates = calculate_wt_avg_unit(tmp_patch, wts[ind1,ind2,:], pool_size)
+            # Build tensor with "filtered" gradient
+            out_ds[np.ix_(indexes[0],indexes[1])]+=updates
+    out_ds = out_ds[paddings[0][0]:(paddings[0][0]+inp.shape[0]),
+                    paddings[1][0]:(paddings[1][0]+inp.shape[1]),:]
+    return out_ds
 def calculate_wt_gavgpool(wts, inp):
     wts=wts.T
     inp=inp.T
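The new calculate_padding helper mirrors TensorFlow-style 'same' padding along each axis: total padding is max(kernel - stride, 0) when the axis length is divisible by the stride, otherwise max(kernel - length % stride, 0), split with the extra pixel on the bottom/right. A small self-contained check of that rule (hypothetical sizes, not code from the package):

    import numpy as np

    def same_pad_total(length, kernel, stride):
        # TF-style SAME padding along one axis, as in calculate_padding above
        rem = length % stride
        return max(kernel - stride, 0) if rem == 0 else max(kernel - rem, 0)

    # an 8-wide axis, 3-wide kernel, stride 2 -> 1 padding pixel in total,
    # split as np.floor([pad/2, (pad+1)/2]) does: 0 on the left, 1 on the right
    total = same_pad_total(8, 3, 2)
    left, right = int(np.floor(total / 2.0)), int(np.floor((total + 1) / 2.0))
    assert (total, left, right) == (1, 0, 1)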
@@ -745,6 +743,438 @@ def calculate_wt_gavgpool(wts, inp):
         wt_mat[..., c] = temp_wt
     return wt_mat
 
+def calculate_wt_gmaxpool_2d(wts, inp):
+    channels = wts.shape[0]
+    wt_mat = np.zeros_like(inp)
+    for c in range(channels):
+        wt = wts[c]
+        x = inp[..., c]
+        max_val = np.max(x)
+        max_indexes = (x == max_val).astype(np.float32)
+        max_indexes_norm = 1.0 / np.sum(max_indexes)
+        max_indexes = max_indexes * max_indexes_norm
+        wt_mat[..., c] = max_indexes * wt
+    return wt_mat
+
+def calculate_padding_1d(kernel_size, inp, padding, strides, const_val=0.0):
+    if padding == 'valid':
+        return inp, [[0, 0],[0,0]]
+    elif padding == 0:
+        return inp, [[0, 0],[0,0]]
+    elif isinstance(padding, int):
+        inp_pad = np.pad(inp, ((padding, padding), (0,0)), 'constant', constant_values=const_val)
+        return inp_pad, [[padding, padding],[0,0]]
+    else:
+        remainder = inp.shape[0] % strides
+        if remainder == 0:
+            pad_total = max(0, kernel_size - strides)
+        else:
+            pad_total = max(0, kernel_size - remainder)
+
+        pad_left = int(np.floor(pad_total / 2.0))
+        pad_right = int(np.ceil(pad_total / 2.0))
+
+        inp_pad = np.pad(inp, ((pad_left, pad_right),(0,0)), 'constant', constant_values=const_val)
+        return inp_pad, [[pad_left, pad_right],[0,0]]
+
+def calculate_wt_conv_unit_1d(patch, wts, w, b, act):
+    k = w.numpy()
+    bias = b.numpy()
+    b_ind = bias > 0
+    bias_pos = bias * b_ind
+    b_ind = bias < 0
+    bias_neg = bias * b_ind * -1.0
+    conv_out = np.einsum("ijk,ij->ijk", k, patch)
+    p_ind = conv_out > 0
+    p_ind = conv_out * p_ind
+    p_sum = np.einsum("ijk->k",p_ind)
+    n_ind = conv_out < 0
+    n_ind = conv_out * n_ind
+    n_sum = np.einsum("ijk->k",n_ind) * -1.0
+    t_sum = p_sum + n_sum
+    wt_mat = np.zeros_like(k)
+    p_saturate = p_sum > 0
+    n_saturate = n_sum > 0
+    if act["type"] == 'mono':
+        if act["range"]["l"]:
+            temp_ind = t_sum > act["range"]["l"]
+            p_saturate = temp_ind
+        if act["range"]["u"]:
+            temp_ind = t_sum < act["range"]["u"]
+            n_saturate = temp_ind
+    elif act["type"] == 'non_mono':
+        t_act = act["func"](t_sum)
+        p_act = act["func"](p_sum + bias_pos)
+        n_act = act["func"](-1 * (n_sum + bias_neg))
+        if act["range"]["l"]:
+            temp_ind = t_sum > act["range"]["l"]
+            p_saturate = p_saturate * temp_ind
+        if act["range"]["u"]:
+            temp_ind = t_sum < act["range"]["u"]
+            n_saturate = n_saturate * temp_ind
+        temp_ind = np.abs(t_act - p_act) > 1e-5
+        n_saturate = n_saturate * temp_ind
+        temp_ind = np.abs(t_act - n_act) > 1e-5
+        p_saturate = p_saturate * temp_ind
+    p_agg_wt = (1.0 / (p_sum + n_sum + bias_pos + bias_neg)) * wts * p_saturate
+    n_agg_wt = (1.0 / (p_sum + n_sum + bias_pos + bias_neg)) * wts * n_saturate
+
+    wt_mat = wt_mat + (p_ind * p_agg_wt)
+    wt_mat = wt_mat + (n_ind * n_agg_wt * -1.0)
+    wt_mat = np.sum(wt_mat, axis=-1)
+    return wt_mat
+
+def calculate_wt_conv_1d(wts, inp, w, b, padding, stride, act):
+    wts = wts.T
+    inp = inp.T
+    w = w.T
+    stride=stride
+    input_padded, paddings = calculate_padding_1d(w.shape[0], inp, padding, stride)
+    out_ds = np.zeros_like(input_padded)
+    for ind in range(wts.shape[0]):
+        indexes = np.arange(ind * stride, ind * stride + w.shape[0])
+        tmp_patch = input_padded[indexes]
+        updates = calculate_wt_conv_unit_1d(tmp_patch, wts[ind, :], w, b, act)
+        out_ds[indexes] += updates
+    out_ds = out_ds[paddings[0][0]:(paddings[0][0] + inp.shape[0])]
+    return out_ds
+
+def calculate_wt_max_unit_1d(patch, wts):
+    pmax = np.max(patch, axis=0)
+    indexes = (patch - pmax) == 0
+    indexes = indexes.astype(np.float32)
+    indexes_norm = 1.0 / np.sum(indexes, axis=0)
+    indexes = np.einsum("ij,j->ij", indexes, indexes_norm)
+    out = np.einsum("ij,j->ij", indexes, wts)
+    return out
+
+def calculate_wt_maxpool_1d(wts, inp, pool_size, padding, stride):
+    inp = inp.T
+    wts = wts.T
+    input_padded, paddings = calculate_padding_1d(pool_size, inp, padding, stride, -np.inf)
+    out_ds = np.zeros_like(input_padded)
+    stride=stride
+    pool_size=pool_size
+    for ind in range(wts.shape[0]):
+        indexes = np.arange(ind * stride, ind * stride + pool_size)
+        tmp_patch = input_padded[indexes]
+        updates = calculate_wt_max_unit_1d(tmp_patch, wts[ind, :])
+        out_ds[indexes] += updates
+    out_ds = out_ds[paddings[0][0]:(paddings[0][0] + inp.shape[0])]
+    return out_ds
+
+def calculate_wt_avg_unit_1d(patch, wts):
+    p_ind = patch > 0
+    p_ind = patch * p_ind
+    p_sum = np.sum(p_ind, axis=0)
+    n_ind = patch < 0
+    n_ind = patch * n_ind
+    n_sum = np.sum(n_ind, axis=0) * -1.0
+    t_sum = p_sum + n_sum
+    wt_mat = np.zeros_like(patch)
+    p_saturate = p_sum > 0
+    n_saturate = n_sum > 0
+    t_sum[t_sum == 0] = 1.0
+    p_agg_wt = (1.0 / t_sum) * wts * p_saturate
+    n_agg_wt = (1.0 / t_sum) * wts * n_saturate
+    wt_mat = wt_mat + (p_ind * p_agg_wt)
+    wt_mat = wt_mat + (n_ind * n_agg_wt * -1.0)
+    return wt_mat
+
+def calculate_wt_avgpool_1d(wts, inp, pool_size, padding, stride):
+    wts = wts.T
+    inp = inp.T
+    stride=stride
+    pool_size=pool_size
+    input_padded, paddings = calculate_padding_1d(pool_size, inp, padding[0], stride[0], 0)
+    out_ds = np.zeros_like(input_padded)
+    for ind in range(wts.shape[0]):
+        indexes = np.arange(ind * stride[0], ind * stride[0] + pool_size[0])
+        tmp_patch = input_padded[indexes]
+        updates = calculate_wt_avg_unit_1d(tmp_patch, wts[ind, :])
+        out_ds[indexes] += updates
+    out_ds = out_ds[paddings[0][0]:(paddings[0][0] + inp.shape[0])]
+    return out_ds
+
+def calculate_wt_gavgpool_1d(wts, inp):
+    channels = wts.shape[0]
+    wt_mat = np.zeros_like(inp)
+    for c in range(channels):
+        wt = wts[c]
+        temp_wt = wt_mat[:, c]
+        x = inp[:, c]
+        p_mat = np.copy(x)
+        n_mat = np.copy(x)
+        p_mat[p_mat < 0] = 0
+        n_mat[n_mat > 0] = 0
+        p_sum = np.sum(p_mat)
+        n_sum = np.sum(n_mat) * -1
+        p_agg_wt = 0.0
+        n_agg_wt = 0.0
+        if p_sum + n_sum > 0.0:
+            p_agg_wt = p_sum / (p_sum + n_sum)
+            n_agg_wt = n_sum / (p_sum + n_sum)
+        if p_sum == 0.0:
+            p_sum = 1.0
+        if n_sum == 0.0:
+            n_sum = 1.0
+        temp_wt = temp_wt + ((p_mat / p_sum) * wt * p_agg_wt)
+        temp_wt = temp_wt + ((n_mat / n_sum) * wt * n_agg_wt * -1.0)
+        wt_mat[:, c] = temp_wt
+    return wt_mat
+
+def calculate_wt_gmaxpool_1d(wts, inp):
+    wts = wts.T
+    inp = inp.T
+    channels = wts.shape[0]
+    wt_mat = np.zeros_like(inp)
+    for c in range(channels):
+        wt = wts[c]
+        x = inp[:, c]
+        max_val = np.max(x)
+        max_indexes = (x == max_val).astype(np.float32)
+        max_indexes_norm = 1.0 / np.sum(max_indexes)
+        max_indexes = max_indexes * max_indexes_norm
+        wt_mat[:, c] = max_indexes * wt
+    return wt_mat
+
+def calculate_output_padding_conv2d_transpose(input_shape, kernel_size, padding, strides):
+    if padding == 'valid':
+        out_shape = [(input_shape[0] - 1) * strides[0] + kernel_size[0],
+                     (input_shape[1] - 1) * strides[1] + kernel_size[1]]
+        paddings = [[0, 0], [0, 0], [0, 0]]
+    elif padding == (0,0):
+        out_shape = [(input_shape[0] - 1) * strides[0] + kernel_size[0],
+                     (input_shape[1] - 1) * strides[1] + kernel_size[1]]
+        paddings = [[0, 0], [0, 0], [0, 0]]
+    elif isinstance(padding, tuple) and padding != (None, None):
+        out_shape = [input_shape[0] * strides[0], input_shape[1] * strides[1]]
+        pad_h = padding[0]
+        pad_v = padding[1]
+        paddings = [[pad_h, pad_h], [pad_v, pad_v], [0, 0]]
+    else:  # 'same' padding
+        out_shape = [input_shape[0] * strides[0], input_shape[1] * strides[1]]
+        pad_h = max(0, (input_shape[0] - 1) * strides[0] + kernel_size[0] - out_shape[0])
+        pad_v = max(0, (input_shape[1] - 1) * strides[1] + kernel_size[1] - out_shape[1])
+        paddings = [[pad_h // 2, pad_h - pad_h // 2],
+                    [pad_v // 2, pad_v - pad_v // 2],
+                    [0, 0]]
+
+    return out_shape, paddings
+
+def calculate_wt_conv2d_transpose_unit(patch, wts, w, b, act):
+    if patch.ndim == 1:
+        patch = patch.reshape(1, 1, -1)
+    elif patch.ndim == 2:
+        patch = patch.reshape(1, *patch.shape)
+    elif patch.ndim != 3:
+        raise ValueError(f"Unexpected patch shape: {patch.shape}")
+
+    k = w.permute(0, 1, 3, 2).numpy()
+    bias = b.numpy()
+    b_ind = bias > 0
+    bias_pos = bias * b_ind
+    b_ind = bias < 0
+    bias_neg = bias * b_ind * -1.0
+
+    conv_out = np.einsum('ijkl,mnk->ijkl', k, patch)
+    p_ind = conv_out > 0
+    p_ind = conv_out * p_ind
+    n_ind = conv_out < 0
+    n_ind = conv_out * n_ind
+
+    p_sum = np.einsum("ijkl->l", p_ind)
+    n_sum = np.einsum("ijkl->l", n_ind) * -1.0
+    t_sum = p_sum + n_sum
+
+    wt_mat = np.zeros_like(k)
+    p_saturate = p_sum > 0
+    n_saturate = n_sum > 0
+
+    if act["type"] == 'mono':
+        if act["range"]["l"]:
+            p_saturate = t_sum > act["range"]["l"]
+        if act["range"]["u"]:
+            n_saturate = t_sum < act["range"]["u"]
+    elif act["type"] == 'non_mono':
+        t_act = act["func"](t_sum)
+        p_act = act["func"](p_sum + bias_pos)
+        n_act = act["func"](-1 * (n_sum + bias_neg))
+        if act["range"]["l"]:
+            temp_ind = t_sum > act["range"]["l"]
+            p_saturate = p_saturate * temp_ind
+        if act["range"]["u"]:
+            temp_ind = t_sum < act["range"]["u"]
+            n_saturate = n_saturate * temp_ind
+        temp_ind = np.abs(t_act - p_act) > 1e-5
+        n_saturate = n_saturate * temp_ind
+        temp_ind = np.abs(t_act - n_act) > 1e-5
+        p_saturate = p_saturate * temp_ind
+
+    p_agg_wt = (1.0 / (p_sum + n_sum + bias_pos + bias_neg)) * wts * p_saturate
+    n_agg_wt = (1.0 / (p_sum + n_sum + bias_pos + bias_neg)) * wts * n_saturate
+
+    wt_mat = wt_mat + (p_ind * p_agg_wt)
+    wt_mat = wt_mat + (n_ind * n_agg_wt * -1.0)
+    wt_mat = np.sum(wt_mat, axis=-1)
+    return wt_mat
+
+def calculate_wt_conv2d_transpose(wts, inp, w, b, padding, strides, act):
+    wts = wts.T
+    inp = inp.T
+    w = w.T
+    out_shape, paddings = calculate_output_padding_conv2d_transpose(inp.shape, w.shape, padding, strides)
+    out_ds = np.zeros(out_shape + [w.shape[3]])
+
+    for ind1 in range(inp.shape[0]):
+        for ind2 in range(inp.shape[1]):
+            out_ind1 = ind1 * strides[0]
+            out_ind2 = ind2 * strides[1]
+            tmp_patch = inp[ind1, ind2, :]
+            updates = calculate_wt_conv2d_transpose_unit(tmp_patch, wts[ind1, ind2, :], w, b, act)
+            end_ind1 = min(out_ind1 + w.shape[0], out_shape[0])
+            end_ind2 = min(out_ind2 + w.shape[1], out_shape[1])
+            valid_updates = updates[:end_ind1 - out_ind1, :end_ind2 - out_ind2, :]
+            out_ds[out_ind1:end_ind1, out_ind2:end_ind2, :] += valid_updates
+
+    if padding == 'same':
+        adjusted_out_ds = np.zeros(inp.shape)
+        for i in range(inp.shape[0]):
+            for j in range(inp.shape[1]):
+                start_i = max(0, i * strides[0])
+                start_j = max(0, j * strides[1])
+                end_i = min(out_ds.shape[0], (i+1) * strides[0])
+                end_j = min(out_ds.shape[1], (j+1) * strides[1])
+                relevant_area = out_ds[start_i:end_i, start_j:end_j, :]
+                adjusted_out_ds[i, j, :] = np.sum(relevant_area, axis=(0, 1))
+        out_ds = adjusted_out_ds
+    elif isinstance(padding, tuple) and padding != (None, None):
+        adjusted_out_ds = np.zeros(inp.shape)
+        for i in range(inp.shape[0]):
+            for j in range(inp.shape[1]):
+                start_i = max(0, i * strides[0])
+                start_j = max(0, j * strides[1])
+                end_i = min(out_ds.shape[0], (i+1) * strides[0])
+                end_j = min(out_ds.shape[1], (j+1) * strides[1])
+                relevant_area = out_ds[start_i:end_i, start_j:end_j, :]
+                adjusted_out_ds[i, j, :] = np.sum(relevant_area, axis=(0, 1))
+        out_ds = adjusted_out_ds
+    else:
+        out_ds = out_ds[paddings[0][0]:(paddings[0][0] + inp.shape[0]),
+                        paddings[1][0]:(paddings[1][0] + inp.shape[1]), :]
+
+    return out_ds
+
+
+def calculate_output_padding_conv1d_transpose(input_shape, kernel_size, padding, strides,dilation):
+    if padding == 'valid':
+        out_shape = [(input_shape[0] - 1) * strides + kernel_size[0]]
+        paddings = [[0, 0], [0, 0]]
+    elif padding == 0:
+        out_shape = [(input_shape[0] - 1) * strides + kernel_size[0]]
+        paddings = [[0, 0], [0, 0]]
+    elif isinstance(padding, int):
+        out_shape = [input_shape[0] * strides]
+        pad_v = (dilation * (kernel_size[0] - 1)) - padding
+        out_shape = [input_shape[0] * strides + pad_v]
+        paddings = [[pad_v, pad_v],
+                    [0, 0]]
+    else:  # 'same' padding
+        out_shape = [input_shape[0] * strides]
+        pad_h = max(0, (input_shape[0] - 1) * strides + kernel_size[0] - out_shape[0])
+        paddings = [[pad_h // 2, pad_h // 2],
+                    [0, 0]]
+
+    return out_shape, paddings
+
+def calculate_wt_conv1d_transpose_unit(patch, wts, w, b, act):
+    if patch.ndim == 1:
+        patch = patch.reshape(1, -1)
+    elif patch.ndim != 2:
+        raise ValueError(f"Unexpected patch shape: {patch.shape}")
+
+    k = w.permute(0, 2, 1).numpy()
+    bias = b.numpy()
+    b_ind = bias > 0
+    bias_pos = bias * b_ind
+    b_ind = bias < 0
+    bias_neg = bias * b_ind * -1.0
+    conv_out = np.einsum('ijk,mj->ijk', k, patch)
+    p_ind = conv_out > 0
+    p_ind = conv_out * p_ind
+    n_ind = conv_out < 0
+    n_ind = conv_out * n_ind
+
+    p_sum = np.einsum("ijl->l", p_ind)
+    n_sum = np.einsum("ijl->l", n_ind) * -1.0
+    t_sum = p_sum + n_sum
+
+    wt_mat = np.zeros_like(k)
+    p_saturate = p_sum > 0
+    n_saturate = n_sum > 0
+
+    if act["type"] == 'mono':
+        if act["range"]["l"]:
+            p_saturate = t_sum > act["range"]["l"]
+        if act["range"]["u"]:
+            n_saturate = t_sum < act["range"]["u"]
+    elif act["type"] == 'non_mono':
+        t_act = act["func"](t_sum)
+        p_act = act["func"](p_sum + bias_pos)
+        n_act = act["func"](-1 * (n_sum + bias_neg))
+        if act["range"]["l"]:
+            temp_ind = t_sum > act["range"]["l"]
+            p_saturate = p_saturate * temp_ind
+        if act["range"]["u"]:
+            temp_ind = t_sum < act["range"]["u"]
+            n_saturate = n_saturate * temp_ind
+        temp_ind = np.abs(t_act - p_act) > 1e-5
+        n_saturate = n_saturate * temp_ind
+        temp_ind = np.abs(t_act - n_act) > 1e-5
+        p_saturate = p_saturate * temp_ind
+
+    p_agg_wt = (1.0 / (p_sum + n_sum + bias_pos + bias_neg)) * wts * p_saturate
+    n_agg_wt = (1.0 / (p_sum + n_sum + bias_pos + bias_neg)) * wts * n_saturate
+    wt_mat = wt_mat + (p_ind * p_agg_wt)
+    wt_mat = wt_mat + (n_ind * n_agg_wt * -1.0)
+    wt_mat = np.sum(wt_mat, axis=-1)
+    return wt_mat
+
+def calculate_wt_conv1d_transpose(wts, inp, w, b, padding, strides, dilation, act):
+    wts = wts.T
+    inp = inp.T
+    w = w.T
+    out_shape, paddings = calculate_output_padding_conv1d_transpose(inp.shape, w.shape, padding, strides, dilation)
+    out_ds = np.zeros(out_shape + [w.shape[2]])
+
+    for ind in range(inp.shape[0]):
+        out_ind = ind * strides
+        tmp_patch = inp[ind, :]
+        updates = calculate_wt_conv1d_transpose_unit(tmp_patch, wts[ind, :], w, b, act)
+        end_ind = min(out_ind + w.shape[0], out_shape[0])
+        valid_updates = updates[:end_ind - out_ind, :]
+        out_ds[out_ind:end_ind, :] += valid_updates
+
+    if padding == 'same':
+        adjusted_out_ds = np.zeros(inp.shape)
+        for i in range(inp.shape[0]):
+            start_i = max(0, i * strides)
+            end_i = min(out_ds.shape[0], (i + 1) * strides)
+            relevant_area = out_ds[start_i:end_i, :]
+            adjusted_out_ds[i, :] = np.sum(relevant_area, axis=0)
+        out_ds = adjusted_out_ds
+    elif padding == 0:
+        adjusted_out_ds = np.zeros(inp.shape)
+        for i in range(inp.shape[0]):
+            start_i = max(0, i * strides)
+            end_i = min(out_ds.shape[0], (i + 1) * strides)
+            relevant_area = out_ds[start_i:end_i, :]
+            adjusted_out_ds[i, :] = np.sum(relevant_area, axis=0)
+        out_ds = adjusted_out_ds
+    else:
+        out_ds = out_ds[paddings[0][0]:(paddings[0][0] + inp.shape[0]), :]
+    return out_ds
+
 
 ####################################################################
 ###################        Encoder Model        ####################
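The transposed-convolution helpers added above size their output as out = (in - 1) * stride + kernel for 'valid'/zero padding, which agrees with PyTorch's ConvTranspose output-length formula when padding=0, dilation=1, and output_padding=0. A quick check of that rule against torch itself (hypothetical shapes, not code from the package):

    import torch

    x = torch.randn(1, 3, 5)  # (N, C_in, L_in) with L_in = 5
    deconv = torch.nn.ConvTranspose1d(in_channels=3, out_channels=2,
                                      kernel_size=4, stride=2)  # padding defaults to 0
    assert deconv(x).shape[-1] == (5 - 1) * 2 + 4  # the 'valid' rule gives 12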