dl-backtrace 0.0.18__py3-none-any.whl → 0.0.20.dev36__py3-none-any.whl

This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the versions as they appear in their respective public registries.

Potentially problematic release.


This version of dl-backtrace might be problematic.

@@ -1,42 +1,32 @@
1
1
  import gc
2
-
2
+ import torch
3
3
  import numpy as np
4
- import tensorflow as tf
5
4
  from numpy.lib.stride_tricks import as_strided
6
- from tensorflow.keras import backend as K
7
-
8
5
 
9
6
  def np_swish(x, beta=0.75):
10
7
  z = 1 / (1 + np.exp(-(beta * x)))
11
8
  return x * z
12
9
 
13
-
14
10
  def np_wave(x, alpha=1.0):
15
11
  return (alpha * x * np.exp(1.0)) / (np.exp(-x) + np.exp(x))
16
12
 
17
-
18
13
  def np_pulse(x, alpha=1.0):
19
14
  return alpha * (1 - np.tanh(x) * np.tanh(x))
20
15
 
21
-
22
16
  def np_absolute(x, alpha=1.0):
23
17
  return alpha * x * np.tanh(x)
24
18
 
25
-
26
19
  def np_hard_sigmoid(x):
27
20
  return np.clip(0.2 * x + 0.5, 0, 1)
28
21
 
29
-
30
22
  def np_sigmoid(x):
31
23
  z = 1 / (1 + np.exp(-x))
32
24
  return z
33
25
 
34
-
35
26
  def np_tanh(x):
36
27
  z = np.tanh(x)
37
28
  return z.astype(np.float32)
38
29
 
39
-
40
30
  class LSTM_forward(object):
41
31
  def __init__(
42
32
  self, num_cells, units, weights, return_sequence=False, go_backwards=False
@@ -48,8 +38,8 @@ class LSTM_forward(object):
48
38
  self.bias = weights[2][1]
49
39
  self.return_sequence = return_sequence
50
40
  self.go_backwards = go_backwards
51
- self.recurrent_activation = tf.math.sigmoid
52
- self.activation = tf.math.tanh
41
+ self.recurrent_activation = torch.sigmoid()
42
+ self.activation = torch.tanh()
53
43
  self.compute_log = {}
54
44
  for i in range(self.num_cells):
55
45
  self.compute_log[i] = {}
@@ -63,23 +53,19 @@ class LSTM_forward(object):
63
53
  """Computes carry and output using split kernels."""
64
54
  x_i, x_f, x_c, x_o = x
65
55
  h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o = h_tm1
66
- #print(self.recurrent_kernel[1][:, : self.units].shape)
67
- #print(h_tm1_i.shape,self.recurrent_kernel[1][:, : self.units].shape)
68
- w=tf.convert_to_tensor(self.recurrent_kernel[1], dtype=tf.float32)
69
- #print(K.dot(h_tm1_i, w[:, : self.units]))
70
-
56
+ w=torch.as_tensor(self.recurrent_kernel[1], dtype=torch.float32)
71
57
  i = self.recurrent_activation(
72
- x_i + K.dot(h_tm1_i, w[:, : self.units])
58
+ x_i + torch.dot(h_tm1_i, w[:, : self.units])
73
59
  )
74
60
  f = self.recurrent_activation(
75
- x_f + K.dot(h_tm1_f, w[:, self.units : self.units * 2])
61
+ x_f + torch.dot(h_tm1_f, w[:, self.units : self.units * 2])
76
62
  )
77
63
  c = f * c_tm1 + i * self.activation(
78
64
  x_c
79
- + K.dot(h_tm1_c, w[:, self.units * 2 : self.units * 3])
65
+ + torch.dot(h_tm1_c, w[:, self.units * 2 : self.units * 3])
80
66
  )
81
67
  o = self.recurrent_activation(
82
- x_o + K.dot(h_tm1_o, w[:, self.units * 3 :])
68
+ x_o + torch.dot(h_tm1_o, w[:, self.units * 3 :])
83
69
  )
84
70
  self.compute_log[cell_num]["int_arrays"]["i"] = i
85
71
  self.compute_log[cell_num]["int_arrays"]["f"] = f
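Note on the hunk above: in the released 0.0.20.dev36 code, self.recurrent_activation = torch.sigmoid() and self.activation = torch.tanh() call the functions with no argument (a TypeError in PyTorch), and torch.dot only accepts 1-D tensors, so the matrix products would not run as written. The following is a minimal sketch of how the same split-kernel gate computation could be written in runnable PyTorch, assuming (1, units) states and a (units, 4 * units) recurrent kernel; it is an illustration, not the package's code.

    # Editor's sketch (assumed shapes, not the released implementation).
    import torch

    def compute_gates(x, h_tm1, c_tm1, recurrent_kernel, units):
        x_i, x_f, x_c, x_o = x
        h_i, h_f, h_c, h_o = h_tm1
        w = torch.as_tensor(recurrent_kernel, dtype=torch.float32)
        # torch.sigmoid / torch.tanh are used as functions, and torch.matmul
        # replaces torch.dot, which only supports 1-D tensors.
        i = torch.sigmoid(x_i + torch.matmul(h_i, w[:, :units]))
        f = torch.sigmoid(x_f + torch.matmul(h_f, w[:, units:units * 2]))
        c = f * c_tm1 + i * torch.tanh(x_c + torch.matmul(h_c, w[:, units * 2:units * 3]))
        o = torch.sigmoid(x_o + torch.matmul(h_o, w[:, units * 3:]))
        return c, o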
@@ -97,16 +83,16 @@ class LSTM_forward(object):
97
83
  inputs_f = inputs
98
84
  inputs_c = inputs
99
85
  inputs_o = inputs
100
- k_i, k_f, k_c, k_o = tf.split(self.kernel[1], num_or_size_splits=4, axis=1)
101
- x_i = K.dot(inputs_i, k_i)
102
- x_f = K.dot(inputs_f, k_f)
103
- x_c = K.dot(inputs_c, k_c)
104
- x_o = K.dot(inputs_o, k_o)
105
- b_i, b_f, b_c, b_o = tf.split(self.bias, num_or_size_splits=4, axis=0)
106
- x_i = tf.add(x_i, b_i)
107
- x_f = tf.add(x_f, b_f)
108
- x_c = tf.add(x_c, b_c)
109
- x_o = tf.add(x_o, b_o)
86
+ k_i, k_f, k_c, k_o = torch.split(self.kernel[1],self.kernel.size(1)//4,dim=1)
87
+ x_i = torch.dot(inputs_i, k_i)
88
+ x_f = torch.dot(inputs_f, k_f)
89
+ x_c = torch.dot(inputs_c, k_c)
90
+ x_o = torch.dot(inputs_o, k_o)
91
+ b_i, b_f, b_c, b_o = torch.split(self.bias,self.bias.size(1)//4,dim=0)
92
+ x_i = x_i + b_i
93
+ x_f = x_f + b_f
94
+ x_c = x_c + b_c
95
+ x_o = x_o + b_o
110
96
 
111
97
  h_tm1_i = h_tm1
112
98
  h_tm1_f = h_tm1
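Note on the hunk above: self.kernel.size(1) is called on what appears to be a (name, array) pair and self.bias.size(1) on a one-dimensional bias, so the split sizes would not evaluate as written. A hedged sketch of the gate split under the usual Keras LSTM export shapes, which are an assumption here:

    # Editor's sketch: assumed shapes kernel (input_dim, 4 * units), bias (4 * units,).
    import torch

    def split_kernel_and_bias(kernel, bias, units):
        kernel = torch.as_tensor(kernel, dtype=torch.float32)
        bias = torch.as_tensor(bias, dtype=torch.float32)
        k_i, k_f, k_c, k_o = torch.split(kernel, units, dim=1)   # four (input_dim, units) blocks
        b_i, b_f, b_c, b_o = torch.split(bias, units, dim=0)     # four (units,) blocks
        return (k_i, k_f, k_c, k_o), (b_i, b_f, b_c, b_o)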
@@ -123,12 +109,12 @@ class LSTM_forward(object):
123
109
  return h, [h, c]
124
110
 
125
111
  def calculate_lstm_wt(self, input_data):
126
- hstate = tf.convert_to_tensor(np.zeros((1, self.units)), dtype=tf.float32)
127
- cstate = tf.convert_to_tensor(np.zeros((1, self.units)), dtype=tf.float32)
112
+ hstate = torch.tensor((1,self.units),dtype=torch.float32)
113
+ cstate = torch.tensor((1,self.units),dtype=torch.float32)
128
114
  output = []
129
115
  for ind in range(input_data.shape[0]):
130
- inp = tf.convert_to_tensor(
131
- input_data[ind, :].reshape((1, input_data.shape[1])), dtype=tf.float32
116
+ inp = torch.tensor(
117
+ input_data[ind, :].reshape((1, input_data.shape[1])), dtype=torch.float32
132
118
  )
133
119
  h, s = self.calculate_lstm_cell_wt(inp, [hstate, cstate], ind)
134
120
  hstate = s[0]
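Note on the hunk above: torch.tensor((1, self.units), dtype=torch.float32) builds a two-element tensor holding the values 1 and units, not a zero state of shape (1, units); version 0.0.18 initialised the states with np.zeros((1, units)). A sketch of the presumably intended zero initialisation (units is a placeholder value chosen for illustration):

    # Editor's sketch: zero-initialised hidden and cell states of shape (1, units).
    import torch

    units = 64  # hypothetical size, for illustration only
    hstate = torch.zeros((1, units), dtype=torch.float32)
    cstate = torch.zeros((1, units), dtype=torch.float32)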
@@ -136,9 +122,6 @@ class LSTM_forward(object):
136
122
  output.append(h)
137
123
  return output
138
124
 
139
-
140
-
141
-
142
125
  class LSTM_backtrace(object):
143
126
  def __init__(
144
127
  self, num_cells, units, weights, return_sequence=False, go_backwards=False
@@ -270,8 +253,6 @@ class LSTM_backtrace(object):
270
253
  x_i, x_f, x_c, x_o = x
271
254
  f = self.compute_log[cell_num]["int_arrays"]["f"].numpy()[0]
272
255
  i = self.compute_log[cell_num]["int_arrays"]["i"].numpy()[0]
273
- # o = self.recurrent_activation(
274
- # x_o + np.dot(h_tm1_o, self.recurrent_kernel[:, self.units * 3:])).astype(np.float32)
275
256
  temp1 = np.dot(h_tm1_o, self.recurrent_kernel[1][:, self.units * 3 :]).astype(
276
257
  np.float32
277
258
  )
@@ -283,9 +264,6 @@ class LSTM_backtrace(object):
283
264
  [],
284
265
  {"type": None},
285
266
  )
286
-
287
- # c = f * c_tm1 + i * self.activation(x_c + np.dot(
288
- # h_tm1_c, self.recurrent_kernel[:, self.units * 2:self.units * 3])).astype(np.float32)
289
267
  temp2 = f * c_tm1
290
268
  temp3_1 = np.dot(
291
269
  h_tm1_c, self.recurrent_kernel[1][:, self.units * 2 : self.units * 3]
@@ -303,9 +281,6 @@ class LSTM_backtrace(object):
303
281
  [],
304
282
  {"type": None},
305
283
  )
306
-
307
- # f = self.recurrent_activation(x_f + np.dot(
308
- # h_tm1_f, self.recurrent_kernel[:, self.units:self.units * 2])).astype(np.float32)
309
284
  temp4 = np.dot(h_tm1_f, self.recurrent_kernel[1][:, self.units : self.units * 2])
310
285
  wt_x_f, wt_temp4 = self.calculate_wt_add(wt_f, [x_f, temp4])
311
286
  wt_h_tm1_f = self.calculate_wt_fc(
@@ -315,9 +290,6 @@ class LSTM_backtrace(object):
315
290
  [],
316
291
  {"type": None},
317
292
  )
318
-
319
- # i = self.recurrent_activation(
320
- # x_i + np.dot(h_tm1_i, self.recurrent_kernel[:, :self.units])).astype(np.float32)
321
293
  temp5 = np.dot(h_tm1_i, self.recurrent_kernel[1][:, : self.units])
322
294
  wt_x_i, wt_temp5 = self.calculate_wt_add(wt_i, [x_i, temp5])
323
295
  wt_h_tm1_i = self.calculate_wt_fc(
@@ -364,7 +336,6 @@ class LSTM_backtrace(object):
364
336
  wt_h_tm1 = wt_h_tm1_i + wt_h_tm1_f + wt_h_tm1_c + wt_h_tm1_o
365
337
  inputs = self.compute_log[cell_num]["inp"].numpy()[0]
366
338
 
367
- #print(np.split(self.kernel[1], indices_or_sections=4, axis=1))
368
339
  k_i, k_f, k_c, k_o = np.split(self.kernel[1], indices_or_sections=4, axis=1)
369
340
  b_i, b_f, b_c, b_o = np.split(self.bias[1], indices_or_sections=4, axis=0)
370
341
 
@@ -395,12 +366,10 @@ class LSTM_backtrace(object):
395
366
  output.reverse()
396
367
  return np.array(output)
397
368
 
398
-
399
369
  def dummy_wt(wts, inp, *args):
400
370
  test_wt = np.zeros_like(inp)
401
371
  return test_wt
402
372
 
403
-
404
373
  def calculate_wt_fc(wts, inp, w, b, act):
405
374
  mul_mat = np.einsum("ij,i->ij", w.numpy().T, inp).T
406
375
  wt_mat = np.zeros(mul_mat.shape)
@@ -461,12 +430,10 @@ def calculate_wt_fc(wts, inp, w, b, act):
461
430
  wt_mat = wt_mat.sum(axis=0)
462
431
  return wt_mat
463
432
 
464
-
465
433
  def calculate_wt_rshp(wts, inp=None):
466
434
  x = np.reshape(wts, inp.shape)
467
435
  return x
468
436
 
469
-
470
437
  def calculate_wt_concat(wts, inp=None, axis=-1):
471
438
  wts=wts.T
472
439
  splits = [i.shape[axis] for i in inp]
@@ -476,7 +443,6 @@ def calculate_wt_concat(wts, inp=None, axis=-1):
476
443
  x = np.split(wts, indices_or_sections=splits, axis=axis)
477
444
  return x
478
445
 
479
-
480
446
  def calculate_wt_add(wts, inp=None):
481
447
  wts=wts.T
482
448
  wt_mat = []
@@ -523,199 +489,231 @@ def calculate_wt_add(wts, inp=None):
523
489
  wt_mat = [i.reshape(wts.shape) for i in list(wt_mat)]
524
490
  return wt_mat
525
491
 
526
-
527
- def calculate_start_wt(arg):
528
- x = np.argmax(arg[0])
529
- y = np.zeros(arg.shape)
530
- y[0][x] = 1
492
+ def calculate_start_wt(arg, scaler=None,thresholding=0.5,task="binary-classification"):
493
+ if arg.ndim == 2:
494
+ if task == "binary-classification" or task == "multi-class classification":
495
+ x = np.argmax(arg[0])
496
+ m = np.max(arg[0])
497
+ y = np.zeros(arg.shape)
498
+ if scaler:
499
+ y[0][x] = scaler
500
+ else:
501
+ y[0][x] = m
502
+ elif task == "bbox-regression":
503
+ y = np.zeros(arg.shape)
504
+ if scaler:
505
+ y[0] = scaler
506
+ num_non_zero_elements = np.count_nonzero(y)
507
+ if num_non_zero_elements > 0:
508
+ y = y / num_non_zero_elements
509
+ else:
510
+ m = np.max(arg[0])
511
+ x = np.argmax(arg[0])
512
+ y[0][x] = m
513
+ else:
514
+ x = np.argmax(arg[0])
515
+ m = np.max(arg[0])
516
+ y = np.zeros(arg.shape)
517
+ if scaler:
518
+ y[0][x] = scaler
519
+ else:
520
+ y[0][x] = m
521
+
522
+ elif arg.ndim == 4 and task == "binary-segmentation":
523
+ indices = np.where(arg > thresholding)
524
+ y = np.zeros(arg.shape)
525
+ if scaler:
526
+ y[indices] = scaler
527
+ num_non_zero_elements = np.count_nonzero(y)
528
+ if num_non_zero_elements > 0:
529
+ y = y / num_non_zero_elements
530
+ else:
531
+ y[indices] = arg[indices]
532
+
533
+ else:
534
+ x = np.argmax(arg[0])
535
+ m = np.max(arg[0])
536
+ y = np.zeros(arg.shape)
537
+ if scaler:
538
+ y[0][x] = scaler
539
+ else:
540
+ y[0][x] = m
531
541
  return y[0]
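Note: the rewritten calculate_start_wt above branches on the task (classification, bbox-regression, binary-segmentation), optionally seeds the starting relevance with a user-supplied scaler, and normalises by the count of non-zero positions when several outputs are selected. A hedged usage sketch for the common 2-D classification path, assuming the function is importable from the module; the logits are made up for illustration:

    # Editor's sketch: starting relevance for a 2-D classifier output.
    import numpy as np

    logits = np.array([[0.1, 2.3, -0.5, 0.9]])   # hypothetical model output, shape (1, 4)
    start = calculate_start_wt(logits, scaler=None, task="binary-classification")
    # start has shape (4,); only the argmax position (index 1) is non-zero and holds
    # the maximum logit, 2.3, which seeds the backward relevance pass.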
532
542
 
533
-
534
543
  def calculate_wt_passthru(wts):
535
544
  return wts
545
+ def calculate_wt_zero_pad(wts,inp,padding):
546
+ wt_mat = wts[padding[0][0]:inp.shape[0]+padding[0][0],padding[1][0]:inp.shape[1]+padding[1][0],:]
547
+ return wt_mat
536
548
 
549
+ def calculate_padding(kernel_size, inp, padding, strides, const_val=0.0):
550
+ if padding=='valid':
551
+ return (inp, [[0,0],[0,0],[0,0]])
552
+ elif padding == 'same':
553
+ h = inp.shape[0]%strides[0]
554
+ if h==0:
555
+ pad_h = np.max([0,kernel_size[0]-strides[0]])
556
+ else:
557
+ pad_h = np.max([0,kernel_size[0]-h])
537
558
 
538
- def calculate_wt_conv_unit(wt, p_mat, n_mat, t_sum, p_sum, n_sum, act):
539
- wt_mat = np.zeros_like(p_mat)
540
- if act["type"] == "mono":
559
+ v = inp.shape[1]%strides[1]
560
+ if v==0:
561
+ pad_v = np.max([0,kernel_size[1]-strides[1]])
562
+ else:
563
+ pad_v = np.max([0,kernel_size[1]-v])
564
+
565
+ paddings = [np.floor([pad_h/2.0,(pad_h+1)/2.0]).astype("int32"),
566
+ np.floor([pad_v/2.0,(pad_v+1)/2.0]).astype("int32"),
567
+ np.zeros((2)).astype("int32")]
568
+ inp_pad = np.pad(inp, paddings, 'constant', constant_values=const_val)
569
+ return (inp_pad,paddings)
570
+ else:
571
+ if isinstance(padding, tuple) and padding != (None, None):
572
+ pad_h = padding[0]
573
+ pad_v = padding[1]
574
+ paddings = [np.floor([pad_h,pad_h]).astype("int32"),
575
+ np.floor([pad_v,pad_v]).astype("int32"),
576
+ np.zeros((2)).astype("int32")]
577
+ inp_pad = np.pad(inp, paddings, 'constant', constant_values=const_val)
578
+ return (inp_pad,paddings)
579
+ else:
580
+ return (inp, [[0,0],[0,0],[0,0]])
581
+
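Note: the 'same' branch of calculate_padding above mirrors TensorFlow-style SAME padding, padding each spatial dimension so the output size equals ceil(input / stride), with the extra pixel placed on the bottom/right. A standalone check of the per-dimension formula (sizes are illustrative):

    # Editor's sketch: TF-style SAME padding for one spatial dimension.
    import numpy as np

    def same_pad_1d(size, kernel, stride):
        rem = size % stride
        pad = max(0, kernel - stride) if rem == 0 else max(0, kernel - rem)
        return int(np.floor(pad / 2.0)), int(np.floor((pad + 1) / 2.0))

    # size=7, kernel=3, stride=2 -> rem=1, pad=2 -> (1, 1); the padded length 9
    # yields ceil(7 / 2) = 4 output positions.
    print(same_pad_1d(7, 3, 2))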
582
+ def calculate_wt_conv_unit(patch, wts, w, b, act):
583
+ k = w.numpy()
584
+ bias = b.numpy()
585
+ b_ind = bias>0
586
+ bias_pos = bias*b_ind
587
+ b_ind = bias<0
588
+ bias_neg = bias*b_ind*-1.0
589
+ conv_out = np.einsum("ijkl,ijk->ijkl",k,patch)
590
+ p_ind = conv_out>0
591
+ p_ind = conv_out*p_ind
592
+ p_sum = np.einsum("ijkl->l",p_ind)
593
+ n_ind = conv_out<0
594
+ n_ind = conv_out*n_ind
595
+ n_sum = np.einsum("ijkl->l",n_ind)*-1.0
596
+ t_sum = p_sum+n_sum
597
+ wt_mat = np.zeros_like(k)
598
+ p_saturate = p_sum>0
599
+ n_saturate = n_sum>0
600
+ if act["type"]=='mono':
541
601
  if act["range"]["l"]:
542
- if t_sum < act["range"]["l"]:
543
- p_sum = 0
602
+ temp_ind = t_sum > act["range"]["l"]
603
+ p_saturate = temp_ind
544
604
  if act["range"]["u"]:
545
- if t_sum > act["range"]["u"]:
546
- n_sum = 0
547
- elif act["type"] == "non_mono":
605
+ temp_ind = t_sum < act["range"]["u"]
606
+ n_saturate = temp_ind
607
+ elif act["type"]=='non_mono':
548
608
  t_act = act["func"](t_sum)
549
- p_act = act["func"](p_sum)
550
- n_act = act["func"](n_sum)
609
+ p_act = act["func"](p_sum + bias_pos)
610
+ n_act = act["func"](-1*(n_sum + bias_neg))
551
611
  if act["range"]["l"]:
552
- if t_sum < act["range"]["l"]:
553
- p_sum = 0
612
+ temp_ind = t_sum > act["range"]["l"]
613
+ p_saturate = p_saturate*temp_ind
554
614
  if act["range"]["u"]:
555
- if t_sum > act["range"]["u"]:
556
- n_sum = 0
557
- if p_sum > 0 and n_sum > 0:
558
- if t_act == p_act:
559
- n_sum = 0
560
- elif t_act == n_act:
561
- p_sum = 0
562
- p_agg_wt = 0.0
563
- n_agg_wt = 0.0
564
- if p_sum + n_sum > 0.0:
565
- p_agg_wt = p_sum / (p_sum + n_sum)
566
- n_agg_wt = n_sum / (p_sum + n_sum)
567
- if p_sum == 0.0:
568
- p_sum = 1.0
569
- if n_sum == 0.0:
570
- n_sum = 1.0
571
- wt_mat = wt_mat + ((p_mat / p_sum) * wt * p_agg_wt)
572
- wt_mat = wt_mat + ((n_mat / n_sum) * wt * n_agg_wt * -1.0)
615
+ temp_ind = t_sum < act["range"]["u"]
616
+ n_saturate = n_saturate*temp_ind
617
+ temp_ind = np.abs(t_act - p_act)>1e-5
618
+ n_saturate = n_saturate*temp_ind
619
+ temp_ind = np.abs(t_act - n_act)>1e-5
620
+ p_saturate = p_saturate*temp_ind
621
+ p_agg_wt = (1.0/(p_sum+n_sum+bias_pos+bias_neg))*wts*p_saturate
622
+ n_agg_wt = (1.0/(p_sum+n_sum+bias_pos+bias_neg))*wts*n_saturate
623
+
624
+ wt_mat = wt_mat+(p_ind*p_agg_wt)
625
+ wt_mat = wt_mat+(n_ind*n_agg_wt*-1.0)
626
+ wt_mat = np.sum(wt_mat,axis=-1)
573
627
  return wt_mat
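Note: calculate_wt_conv_unit above distributes the relevance arriving at each output unit across its receptive field in proportion to the magnitude of each w * x contribution, with the saturation flags switching either side off when the activation-range checks fail. A small self-contained numpy sketch of that proportional split for a single unit (identity activation, no bias, made-up numbers):

    # Editor's sketch: proportional relevance split for one output unit.
    import numpy as np

    contrib = np.array([2.0, -1.0, 0.5, -0.5])   # hypothetical w * x terms in a patch
    relevance = 1.0                               # relevance arriving at this unit
    pos = np.where(contrib > 0, contrib, 0.0)
    neg = np.where(contrib < 0, contrib, 0.0)     # kept negative, as n_ind is above
    share = relevance / (pos.sum() - neg.sum())   # 1 / (2.5 + 1.5) = 0.25
    wt = pos * share + neg * share * -1.0         # [0.5, 0.25, 0.125, 0.125]
    # Each position receives relevance in proportion to |w * x|; the -1 flips the
    # sign of the negative contributions, so the shares sum to the incoming
    # relevance when nothing is saturated.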
574
628
 
575
-
576
- def calculate_wt_conv(wts, inp, w, b, act):
577
- wts=wts.T
578
- inp=inp.T
579
- w=w.T
580
- expanded_input = as_strided(
581
- inp,
582
- shape=(
583
- inp.shape[0]
584
- - w.numpy().shape[0]
585
- + 1, # The feature map is a few pixels smaller than the input
586
- inp.shape[1] - w.numpy().shape[1] + 1,
587
- inp.shape[2],
588
- w.numpy().shape[0],
589
- w.numpy().shape[1],
590
- ),
591
- strides=(
592
- inp.strides[0],
593
- inp.strides[1],
594
- inp.strides[2],
595
- inp.strides[
596
- 0
597
- ], # When we move one step in the 3rd dimension, we should move one step in the original data too
598
- inp.strides[1],
599
- ),
600
- writeable=False, # totally use this to avoid writing to memory in weird places
601
- )
602
- test_wt = np.einsum("mnc->cmn", np.zeros_like(inp), order="C", optimize=True)
603
- for k in range(w.numpy().shape[-1]):
604
- kernel = w.numpy()[:, :, :, k]
605
- x = np.einsum(
606
- "abcmn,mnc->abcmn", expanded_input, kernel, order="C", optimize=True
607
- )
608
- x_pos = x.copy()
609
- x_neg = x.copy()
610
- x_pos[x < 0] = 0
611
- x_neg[x > 0] = 0
612
- x_sum = np.einsum("abcmn->ab", x, order="C", optimize=True)
613
- x_p_sum = np.einsum("abcmn->ab", x_pos, order="C", optimize=True)
614
- x_n_sum = np.einsum("abcmn->ab", x_neg, order="C", optimize=True) * -1.0
615
- # print(np.sum(x),np.sum(x_pos),np.sum(x_neg),np.sum(x_n_sum))
616
- for ind1 in range(expanded_input.shape[0]):
617
- for ind2 in range(expanded_input.shape[1]):
618
- temp_wt_mat = calculate_wt_conv_unit(
619
- wts[ind1, ind2, k],
620
- x_pos[ind1, ind2, :, :, :],
621
- x_neg[ind1, ind2, :, :, :],
622
- x_sum[ind1, ind2],
623
- x_p_sum[ind1, ind2],
624
- x_n_sum[ind1, ind2],
625
- act,
626
- )
627
- test_wt[
628
- :, ind1 : ind1 + kernel.shape[0], ind2 : ind2 + kernel.shape[1]
629
- ] += temp_wt_mat
630
- test_wt = np.einsum("cmn->mnc", test_wt, order="C", optimize=True)
631
- gc.collect()
632
- return test_wt
633
-
634
-
635
- def get_max_index(mat=None):
636
- max_ind = np.argmax(mat)
637
- ind = []
638
- rem = max_ind
639
- for i in mat.shape[:-1]:
640
- ind.append(rem // i)
641
- rem = rem % i
642
- ind.append(rem)
643
- return tuple(ind)
644
-
645
-
646
- def calculate_wt_maxpool(wts, inp, pool_size):
629
+ def calculate_wt_conv(wts, inp, w, b, padding, strides, act):
630
+ wts = wts.T
631
+ inp = inp.T
632
+ w = w.T
633
+ input_padded, paddings = calculate_padding(w.shape, inp, padding, strides)
634
+ out_ds = np.zeros_like(input_padded)
635
+ for ind1 in range(wts.shape[0]):
636
+ for ind2 in range(wts.shape[1]):
637
+ indexes = [np.arange(ind1*strides[0], ind1*(strides[0])+w.shape[0]),
638
+ np.arange(ind2*strides[1], ind2*(strides[1])+w.shape[1])]
639
+ # Take slice
640
+ tmp_patch = input_padded[np.ix_(indexes[0],indexes[1])]
641
+ updates = calculate_wt_conv_unit(tmp_patch, wts[ind1,ind2,:], w, b, act)
642
+ # Build tensor with "filtered" gradient
643
+ out_ds[np.ix_(indexes[0],indexes[1])]+=updates
644
+ out_ds = out_ds[paddings[0][0]:(paddings[0][0]+inp.shape[0]),
645
+ paddings[1][0]:(paddings[1][0]+inp.shape[1]),:]
646
+ return out_ds
647
+
648
+
649
+ def calculate_wt_max_unit(patch, wts, pool_size):
650
+ pmax = np.einsum("ijk,k->ijk",np.ones_like(patch),np.max(np.max(patch,axis=0),axis=0))
651
+ indexes = (patch-pmax)==0
652
+ indexes = indexes.astype(np.float32)
653
+ indexes_norm = 1.0/np.einsum("mnc->c",indexes)
654
+ indexes = np.einsum("ijk,k->ijk",indexes,indexes_norm)
655
+ out = np.einsum("ijk,k->ijk",indexes,wts)
656
+ return out
657
+
658
+ def calculate_wt_maxpool(wts, inp, pool_size, padding, strides):
647
659
  wts=wts.T
648
660
  inp=inp.T
649
- pad1 = pool_size[0]
650
- pad2 = pool_size[1]
651
- test_samp_pad = np.pad(inp, ((0, pad1), (0, pad2), (0, 0)), "constant")
652
- dim1, dim2, _ = wts.shape
653
- test_wt = np.zeros_like(test_samp_pad)
654
- for k in range(inp.shape[2]):
655
- wt_mat = wts[:, :, k]
656
- for ind1 in range(dim1):
657
- for ind2 in range(dim2):
658
- temp_inp = test_samp_pad[
659
- ind1 * pool_size[0] : (ind1 + 1) * pool_size[0],
660
- ind2 * pool_size[1] : (ind2 + 1) * pool_size[1],
661
- k,
662
- ]
663
- max_index = get_max_index(temp_inp)
664
- test_wt[
665
- ind1 * pool_size[0] : (ind1 + 1) * pool_size[0],
666
- ind2 * pool_size[1] : (ind2 + 1) * pool_size[1],
667
- k,
668
- ][max_index] = wt_mat[ind1, ind2]
669
- test_wt = test_wt[0 : inp.shape[0], 0 : inp.shape[1], :]
670
- return test_wt
671
-
661
+ strides = (strides,strides)
662
+ padding = (padding,padding)
663
+ input_padded, paddings = calculate_padding(pool_size, inp, padding, strides, -np.inf)
664
+ out_ds = np.zeros_like(input_padded)
665
+ for ind1 in range(wts.shape[0]):
666
+ for ind2 in range(wts.shape[1]):
667
+ indexes = [np.arange(ind1*strides[0], ind1*(strides[0])+pool_size[0]),
668
+ np.arange(ind2*strides[1], ind2*(strides[1])+pool_size[1])]
669
+ tmp_patch = input_padded[np.ix_(indexes[0],indexes[1])]
670
+ updates = calculate_wt_max_unit(tmp_patch, wts[ind1,ind2,:], pool_size)
671
+ out_ds[np.ix_(indexes[0],indexes[1])]+=updates
672
+ out_ds = out_ds[paddings[0][0]:(paddings[0][0]+inp.shape[0]),
673
+ paddings[1][0]:(paddings[1][0]+inp.shape[1]),:]
674
+ return out_ds
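Note: calculate_wt_max_unit above routes a window's relevance entirely to the maximal entries of each channel, splitting it evenly across ties. A one-window sketch (single channel, made-up numbers):

    # Editor's sketch: relevance routing for a single 2x2 max-pool window, one channel.
    import numpy as np

    patch = np.array([[1.0, 3.0],
                      [3.0, 0.5]])          # hypothetical activations, max 3.0 (tied twice)
    relevance = 0.8
    mask = (patch == patch.max()).astype(np.float32)
    wt = mask / mask.sum() * relevance       # each tied maximum gets 0.4, the rest 0.0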
675
+
676
+
677
+ def calculate_wt_avg_unit(patch, wts, pool_size):
678
+ p_ind = patch>0
679
+ p_ind = patch*p_ind
680
+ p_sum = np.einsum("ijk->k",p_ind)
681
+ n_ind = patch<0
682
+ n_ind = patch*n_ind
683
+ n_sum = np.einsum("ijk->k",n_ind)*-1.0
684
+ t_sum = p_sum+n_sum
685
+ wt_mat = np.zeros_like(patch)
686
+ p_saturate = p_sum>0
687
+ n_saturate = n_sum>0
688
+ t_sum[t_sum==0] = 1.0
689
+ p_agg_wt = (1.0/(t_sum))*wts*p_saturate
690
+ n_agg_wt = (1.0/(t_sum))*wts*n_saturate
691
+ wt_mat = wt_mat+(p_ind*p_agg_wt)
692
+ wt_mat = wt_mat+(n_ind*n_agg_wt*-1.0)
693
+ return wt_mat
672
694
 
673
- def calculate_wt_avgpool(wts, inp, pool_size):
695
+ def calculate_wt_avgpool(wts, inp, pool_size, padding, strides):
674
696
  wts=wts.T
675
697
  inp=inp.T
676
698
 
677
699
  pad1 = pool_size[0]
678
700
  pad2 = pool_size[1]
679
- test_samp_pad = np.pad(inp, ((0, pad1), (0, pad2), (0, 0)), "constant")
680
- dim1, dim2, _ = wts.shape
681
- test_wt = np.zeros_like(test_samp_pad)
682
- for k in range(inp.shape[2]):
683
- wt_mat = wts[:, :, k]
684
- for ind1 in range(dim1):
685
- for ind2 in range(dim2):
686
- temp_inp = test_samp_pad[
687
- ind1 * pool_size[0] : (ind1 + 1) * pool_size[0],
688
- ind2 * pool_size[1] : (ind2 + 1) * pool_size[1],
689
- k,
690
- ]
691
- wt_ind1 = test_wt[
692
- ind1 * pool_size[0] : (ind1 + 1) * pool_size[0],
693
- ind2 * pool_size[1] : (ind2 + 1) * pool_size[1],
694
- k,
695
- ]
696
- wt = wt_mat[ind1, ind2]
697
- p_ind = temp_inp > 0
698
- n_ind = temp_inp < 0
699
- p_sum = np.sum(temp_inp[p_ind])
700
- n_sum = np.sum(temp_inp[n_ind]) * -1
701
- if p_sum > 0:
702
- p_agg_wt = p_sum / (p_sum + n_sum)
703
- else:
704
- p_agg_wt = 0
705
- if n_sum > 0:
706
- n_agg_wt = n_sum / (p_sum + n_sum)
707
- else:
708
- n_agg_wt = 0
709
- if p_sum == 0:
710
- p_sum = 1
711
- if n_sum == 0:
712
- n_sum = 1
713
- wt_ind1[p_ind] += (temp_inp[p_ind] / p_sum) * wt * p_agg_wt
714
- wt_ind1[n_ind] += (temp_inp[n_ind] / n_sum) * wt * n_agg_wt * -1.0
715
- test_wt = test_wt[0 : inp.shape[0], 0 : inp.shape[1], :]
716
- return test_wt
717
-
718
-
701
+ strides = (strides,strides)
702
+ padding = (padding,padding)
703
+ input_padded, paddings = calculate_padding(pool_size, inp, padding, strides, -np.inf)
704
+ out_ds = np.zeros_like(input_padded)
705
+ for ind1 in range(wts.shape[0]):
706
+ for ind2 in range(wts.shape[1]):
707
+ indexes = [np.arange(ind1*strides[0], ind1*(strides[0])+pool_size[0]),
708
+ np.arange(ind2*strides[1], ind2*(strides[1])+pool_size[1])]
709
+ # Take slice
710
+ tmp_patch = input_padded[np.ix_(indexes[0],indexes[1])]
711
+ updates = calculate_wt_avg_unit(tmp_patch, wts[ind1,ind2,:], pool_size)
712
+ # Build tensor with "filtered" gradient
713
+ out_ds[np.ix_(indexes[0],indexes[1])]+=updates
714
+ out_ds = out_ds[paddings[0][0]:(paddings[0][0]+inp.shape[0]),
715
+ paddings[1][0]:(paddings[1][0]+inp.shape[1]),:]
716
+ return out_ds
719
717
  def calculate_wt_gavgpool(wts, inp):
720
718
  wts=wts.T
721
719
  inp=inp.T
@@ -745,6 +743,438 @@ def calculate_wt_gavgpool(wts, inp):
745
743
  wt_mat[..., c] = temp_wt
746
744
  return wt_mat
747
745
 
746
+ def calculate_wt_gmaxpool_2d(wts, inp):
747
+ channels = wts.shape[0]
748
+ wt_mat = np.zeros_like(inp)
749
+ for c in range(channels):
750
+ wt = wts[c]
751
+ x = inp[..., c]
752
+ max_val = np.max(x)
753
+ max_indexes = (x == max_val).astype(np.float32)
754
+ max_indexes_norm = 1.0 / np.sum(max_indexes)
755
+ max_indexes = max_indexes * max_indexes_norm
756
+ wt_mat[..., c] = max_indexes * wt
757
+ return wt_mat
758
+
759
+ def calculate_padding_1d(kernel_size, inp, padding, strides, const_val=0.0):
760
+ if padding == 'valid':
761
+ return inp, [[0, 0],[0,0]]
762
+ elif padding == 0:
763
+ return inp, [[0, 0],[0,0]]
764
+ elif isinstance(padding, int):
765
+ inp_pad = np.pad(inp, ((padding, padding), (0,0)), 'constant', constant_values=const_val)
766
+ return inp_pad, [[padding, padding],[0,0]]
767
+ else:
768
+ remainder = inp.shape[0] % strides
769
+ if remainder == 0:
770
+ pad_total = max(0, kernel_size - strides)
771
+ else:
772
+ pad_total = max(0, kernel_size - remainder)
773
+
774
+ pad_left = int(np.floor(pad_total / 2.0))
775
+ pad_right = int(np.ceil(pad_total / 2.0))
776
+
777
+ inp_pad = np.pad(inp, ((pad_left, pad_right),(0,0)), 'constant', constant_values=const_val)
778
+ return inp_pad, [[pad_left, pad_right],[0,0]]
779
+
780
+ def calculate_wt_conv_unit_1d(patch, wts, w, b, act):
781
+ k = w.numpy()
782
+ bias = b.numpy()
783
+ b_ind = bias > 0
784
+ bias_pos = bias * b_ind
785
+ b_ind = bias < 0
786
+ bias_neg = bias * b_ind * -1.0
787
+ conv_out = np.einsum("ijk,ij->ijk", k, patch)
788
+ p_ind = conv_out > 0
789
+ p_ind = conv_out * p_ind
790
+ p_sum = np.einsum("ijk->k",p_ind)
791
+ n_ind = conv_out < 0
792
+ n_ind = conv_out * n_ind
793
+ n_sum = np.einsum("ijk->k",n_ind) * -1.0
794
+ t_sum = p_sum + n_sum
795
+ wt_mat = np.zeros_like(k)
796
+ p_saturate = p_sum > 0
797
+ n_saturate = n_sum > 0
798
+ if act["type"] == 'mono':
799
+ if act["range"]["l"]:
800
+ temp_ind = t_sum > act["range"]["l"]
801
+ p_saturate = temp_ind
802
+ if act["range"]["u"]:
803
+ temp_ind = t_sum < act["range"]["u"]
804
+ n_saturate = temp_ind
805
+ elif act["type"] == 'non_mono':
806
+ t_act = act["func"](t_sum)
807
+ p_act = act["func"](p_sum + bias_pos)
808
+ n_act = act["func"](-1 * (n_sum + bias_neg))
809
+ if act["range"]["l"]:
810
+ temp_ind = t_sum > act["range"]["l"]
811
+ p_saturate = p_saturate * temp_ind
812
+ if act["range"]["u"]:
813
+ temp_ind = t_sum < act["range"]["u"]
814
+ n_saturate = n_saturate * temp_ind
815
+ temp_ind = np.abs(t_act - p_act) > 1e-5
816
+ n_saturate = n_saturate * temp_ind
817
+ temp_ind = np.abs(t_act - n_act) > 1e-5
818
+ p_saturate = p_saturate * temp_ind
819
+ p_agg_wt = (1.0 / (p_sum + n_sum + bias_pos + bias_neg)) * wts * p_saturate
820
+ n_agg_wt = (1.0 / (p_sum + n_sum + bias_pos + bias_neg)) * wts * n_saturate
821
+
822
+ wt_mat = wt_mat + (p_ind * p_agg_wt)
823
+ wt_mat = wt_mat + (n_ind * n_agg_wt * -1.0)
824
+ wt_mat = np.sum(wt_mat, axis=-1)
825
+ return wt_mat
826
+
827
+ def calculate_wt_conv_1d(wts, inp, w, b, padding, stride, act):
828
+ wts = wts.T
829
+ inp = inp.T
830
+ w = w.T
831
+ stride=stride
832
+ input_padded, paddings = calculate_padding_1d(w.shape[0], inp, padding, stride)
833
+ out_ds = np.zeros_like(input_padded)
834
+ for ind in range(wts.shape[0]):
835
+ indexes = np.arange(ind * stride, ind * stride + w.shape[0])
836
+ tmp_patch = input_padded[indexes]
837
+ updates = calculate_wt_conv_unit_1d(tmp_patch, wts[ind, :], w, b, act)
838
+ out_ds[indexes] += updates
839
+ out_ds = out_ds[paddings[0][0]:(paddings[0][0] + inp.shape[0])]
840
+ return out_ds
841
+
842
+ def calculate_wt_max_unit_1d(patch, wts):
843
+ pmax = np.max(patch, axis=0)
844
+ indexes = (patch - pmax) == 0
845
+ indexes = indexes.astype(np.float32)
846
+ indexes_norm = 1.0 / np.sum(indexes, axis=0)
847
+ indexes = np.einsum("ij,j->ij", indexes, indexes_norm)
848
+ out = np.einsum("ij,j->ij", indexes, wts)
849
+ return out
850
+
851
+ def calculate_wt_maxpool_1d(wts, inp, pool_size, padding, stride):
852
+ inp = inp.T
853
+ wts = wts.T
854
+ input_padded, paddings = calculate_padding_1d(pool_size, inp, padding, stride, -np.inf)
855
+ out_ds = np.zeros_like(input_padded)
856
+ stride=stride
857
+ pool_size=pool_size
858
+ for ind in range(wts.shape[0]):
859
+ indexes = np.arange(ind * stride, ind * stride + pool_size)
860
+ tmp_patch = input_padded[indexes]
861
+ updates = calculate_wt_max_unit_1d(tmp_patch, wts[ind, :])
862
+ out_ds[indexes] += updates
863
+ out_ds = out_ds[paddings[0][0]:(paddings[0][0] + inp.shape[0])]
864
+ return out_ds
865
+
866
+ def calculate_wt_avg_unit_1d(patch, wts):
867
+ p_ind = patch > 0
868
+ p_ind = patch * p_ind
869
+ p_sum = np.sum(p_ind, axis=0)
870
+ n_ind = patch < 0
871
+ n_ind = patch * n_ind
872
+ n_sum = np.sum(n_ind, axis=0) * -1.0
873
+ t_sum = p_sum + n_sum
874
+ wt_mat = np.zeros_like(patch)
875
+ p_saturate = p_sum > 0
876
+ n_saturate = n_sum > 0
877
+ t_sum[t_sum == 0] = 1.0
878
+ p_agg_wt = (1.0 / t_sum) * wts * p_saturate
879
+ n_agg_wt = (1.0 / t_sum) * wts * n_saturate
880
+ wt_mat = wt_mat + (p_ind * p_agg_wt)
881
+ wt_mat = wt_mat + (n_ind * n_agg_wt * -1.0)
882
+ return wt_mat
883
+
884
+ def calculate_wt_avgpool_1d(wts, inp, pool_size, padding, stride):
885
+ wts = wts.T
886
+ inp = inp.T
887
+ stride=stride
888
+ pool_size=pool_size
889
+ input_padded, paddings = calculate_padding_1d(pool_size, inp, padding[0], stride[0], 0)
890
+ out_ds = np.zeros_like(input_padded)
891
+ for ind in range(wts.shape[0]):
892
+ indexes = np.arange(ind * stride[0], ind * stride[0] + pool_size[0])
893
+ tmp_patch = input_padded[indexes]
894
+ updates = calculate_wt_avg_unit_1d(tmp_patch, wts[ind, :])
895
+ out_ds[indexes] += updates
896
+ out_ds = out_ds[paddings[0][0]:(paddings[0][0] + inp.shape[0])]
897
+ return out_ds
898
+
899
+ def calculate_wt_gavgpool_1d(wts, inp):
900
+ channels = wts.shape[0]
901
+ wt_mat = np.zeros_like(inp)
902
+ for c in range(channels):
903
+ wt = wts[c]
904
+ temp_wt = wt_mat[:, c]
905
+ x = inp[:, c]
906
+ p_mat = np.copy(x)
907
+ n_mat = np.copy(x)
908
+ p_mat[p_mat < 0] = 0
909
+ n_mat[n_mat > 0] = 0
910
+ p_sum = np.sum(p_mat)
911
+ n_sum = np.sum(n_mat) * -1
912
+ p_agg_wt = 0.0
913
+ n_agg_wt = 0.0
914
+ if p_sum + n_sum > 0.0:
915
+ p_agg_wt = p_sum / (p_sum + n_sum)
916
+ n_agg_wt = n_sum / (p_sum + n_sum)
917
+ if p_sum == 0.0:
918
+ p_sum = 1.0
919
+ if n_sum == 0.0:
920
+ n_sum = 1.0
921
+ temp_wt = temp_wt + ((p_mat / p_sum) * wt * p_agg_wt)
922
+ temp_wt = temp_wt + ((n_mat / n_sum) * wt * n_agg_wt * -1.0)
923
+ wt_mat[:, c] = temp_wt
924
+ return wt_mat
925
+
926
+ def calculate_wt_gmaxpool_1d(wts, inp):
927
+ wts = wts.T
928
+ inp = inp.T
929
+ channels = wts.shape[0]
930
+ wt_mat = np.zeros_like(inp)
931
+ for c in range(channels):
932
+ wt = wts[c]
933
+ x = inp[:, c]
934
+ max_val = np.max(x)
935
+ max_indexes = (x == max_val).astype(np.float32)
936
+ max_indexes_norm = 1.0 / np.sum(max_indexes)
937
+ max_indexes = max_indexes * max_indexes_norm
938
+ wt_mat[:, c] = max_indexes * wt
939
+ return wt_mat
940
+
941
+ def calculate_output_padding_conv2d_transpose(input_shape, kernel_size, padding, strides):
942
+ if padding == 'valid':
943
+ out_shape = [(input_shape[0] - 1) * strides[0] + kernel_size[0],
944
+ (input_shape[1] - 1) * strides[1] + kernel_size[1]]
945
+ paddings = [[0, 0], [0, 0], [0, 0]]
946
+ elif padding == (0,0):
947
+ out_shape = [(input_shape[0] - 1) * strides[0] + kernel_size[0],
948
+ (input_shape[1] - 1) * strides[1] + kernel_size[1]]
949
+ paddings = [[0, 0], [0, 0], [0, 0]]
950
+ elif isinstance(padding, tuple) and padding != (None, None):
951
+ out_shape = [input_shape[0] * strides[0], input_shape[1] * strides[1]]
952
+ pad_h = padding[0]
953
+ pad_v = padding[1]
954
+ paddings = [[pad_h, pad_h], [pad_v, pad_v], [0, 0]]
955
+ else: # 'same' padding
956
+ out_shape = [input_shape[0] * strides[0], input_shape[1] * strides[1]]
957
+ pad_h = max(0, (input_shape[0] - 1) * strides[0] + kernel_size[0] - out_shape[0])
958
+ pad_v = max(0, (input_shape[1] - 1) * strides[1] + kernel_size[1] - out_shape[1])
959
+ paddings = [[pad_h // 2, pad_h - pad_h // 2],
960
+ [pad_v // 2, pad_v - pad_v // 2],
961
+ [0, 0]]
962
+
963
+ return out_shape, paddings
964
+
965
+ def calculate_wt_conv2d_transpose_unit(patch, wts, w, b, act):
966
+ if patch.ndim == 1:
967
+ patch = patch.reshape(1, 1, -1)
968
+ elif patch.ndim == 2:
969
+ patch = patch.reshape(1, *patch.shape)
970
+ elif patch.ndim != 3:
971
+ raise ValueError(f"Unexpected patch shape: {patch.shape}")
972
+
973
+ k = w.permute(0, 1, 3, 2).numpy()
974
+ bias = b.numpy()
975
+ b_ind = bias > 0
976
+ bias_pos = bias * b_ind
977
+ b_ind = bias < 0
978
+ bias_neg = bias * b_ind * -1.0
979
+
980
+ conv_out = np.einsum('ijkl,mnk->ijkl', k, patch)
981
+ p_ind = conv_out > 0
982
+ p_ind = conv_out * p_ind
983
+ n_ind = conv_out < 0
984
+ n_ind = conv_out * n_ind
985
+
986
+ p_sum = np.einsum("ijkl->l", p_ind)
987
+ n_sum = np.einsum("ijkl->l", n_ind) * -1.0
988
+ t_sum = p_sum + n_sum
989
+
990
+ wt_mat = np.zeros_like(k)
991
+ p_saturate = p_sum > 0
992
+ n_saturate = n_sum > 0
993
+
994
+ if act["type"] == 'mono':
995
+ if act["range"]["l"]:
996
+ p_saturate = t_sum > act["range"]["l"]
997
+ if act["range"]["u"]:
998
+ n_saturate = t_sum < act["range"]["u"]
999
+ elif act["type"] == 'non_mono':
1000
+ t_act = act["func"](t_sum)
1001
+ p_act = act["func"](p_sum + bias_pos)
1002
+ n_act = act["func"](-1 * (n_sum + bias_neg))
1003
+ if act["range"]["l"]:
1004
+ temp_ind = t_sum > act["range"]["l"]
1005
+ p_saturate = p_saturate * temp_ind
1006
+ if act["range"]["u"]:
1007
+ temp_ind = t_sum < act["range"]["u"]
1008
+ n_saturate = n_saturate * temp_ind
1009
+ temp_ind = np.abs(t_act - p_act) > 1e-5
1010
+ n_saturate = n_saturate * temp_ind
1011
+ temp_ind = np.abs(t_act - n_act) > 1e-5
1012
+ p_saturate = p_saturate * temp_ind
1013
+
1014
+ p_agg_wt = (1.0 / (p_sum + n_sum + bias_pos + bias_neg)) * wts * p_saturate
1015
+ n_agg_wt = (1.0 / (p_sum + n_sum + bias_pos + bias_neg)) * wts * n_saturate
1016
+
1017
+ wt_mat = wt_mat + (p_ind * p_agg_wt)
1018
+ wt_mat = wt_mat + (n_ind * n_agg_wt * -1.0)
1019
+ wt_mat = np.sum(wt_mat, axis=-1)
1020
+ return wt_mat
1021
+
1022
+ def calculate_wt_conv2d_transpose(wts, inp, w, b, padding, strides, act):
1023
+ wts = wts.T
1024
+ inp = inp.T
1025
+ w = w.T
1026
+ out_shape, paddings = calculate_output_padding_conv2d_transpose(inp.shape, w.shape, padding, strides)
1027
+ out_ds = np.zeros(out_shape + [w.shape[3]])
1028
+
1029
+ for ind1 in range(inp.shape[0]):
1030
+ for ind2 in range(inp.shape[1]):
1031
+ out_ind1 = ind1 * strides[0]
1032
+ out_ind2 = ind2 * strides[1]
1033
+ tmp_patch = inp[ind1, ind2, :]
1034
+ updates = calculate_wt_conv2d_transpose_unit(tmp_patch, wts[ind1, ind2, :], w, b, act)
1035
+ end_ind1 = min(out_ind1 + w.shape[0], out_shape[0])
1036
+ end_ind2 = min(out_ind2 + w.shape[1], out_shape[1])
1037
+ valid_updates = updates[:end_ind1 - out_ind1, :end_ind2 - out_ind2, :]
1038
+ out_ds[out_ind1:end_ind1, out_ind2:end_ind2, :] += valid_updates
1039
+
1040
+ if padding == 'same':
1041
+ adjusted_out_ds = np.zeros(inp.shape)
1042
+ for i in range(inp.shape[0]):
1043
+ for j in range(inp.shape[1]):
1044
+ start_i = max(0, i * strides[0])
1045
+ start_j = max(0, j * strides[1])
1046
+ end_i = min(out_ds.shape[0], (i+1) * strides[0])
1047
+ end_j = min(out_ds.shape[1], (j+1) * strides[1])
1048
+ relevant_area = out_ds[start_i:end_i, start_j:end_j, :]
1049
+ adjusted_out_ds[i, j, :] = np.sum(relevant_area, axis=(0, 1))
1050
+ out_ds = adjusted_out_ds
1051
+ elif isinstance(padding, tuple) and padding != (None, None):
1052
+ adjusted_out_ds = np.zeros(inp.shape)
1053
+ for i in range(inp.shape[0]):
1054
+ for j in range(inp.shape[1]):
1055
+ start_i = max(0, i * strides[0])
1056
+ start_j = max(0, j * strides[1])
1057
+ end_i = min(out_ds.shape[0], (i+1) * strides[0])
1058
+ end_j = min(out_ds.shape[1], (j+1) * strides[1])
1059
+ relevant_area = out_ds[start_i:end_i, start_j:end_j, :]
1060
+ adjusted_out_ds[i, j, :] = np.sum(relevant_area, axis=(0, 1))
1061
+ out_ds = adjusted_out_ds
1062
+ else:
1063
+ out_ds = out_ds[paddings[0][0]:(paddings[0][0] + inp.shape[0]),
1064
+ paddings[1][0]:(paddings[1][0] + inp.shape[1]), :]
1065
+
1066
+ return out_ds
1067
+
1068
+
1069
+ def calculate_output_padding_conv1d_transpose(input_shape, kernel_size, padding, strides,dilation):
1070
+ if padding == 'valid':
1071
+ out_shape = [(input_shape[0] - 1) * strides + kernel_size[0]]
1072
+ paddings = [[0, 0], [0, 0]]
1073
+ elif padding == 0:
1074
+ out_shape = [(input_shape[0] - 1) * strides + kernel_size[0]]
1075
+ paddings = [[0, 0], [0, 0]]
1076
+ elif isinstance(padding, int):
1077
+ out_shape = [input_shape[0] * strides]
1078
+ pad_v = (dilation * (kernel_size[0] - 1)) - padding
1079
+ out_shape = [input_shape[0] * strides + pad_v]
1080
+ paddings = [[pad_v, pad_v],
1081
+ [0, 0]]
1082
+ else: # 'same' padding
1083
+ out_shape = [input_shape[0] * strides]
1084
+ pad_h = max(0, (input_shape[0] - 1) * strides + kernel_size[0] - out_shape[0])
1085
+ paddings = [[pad_h // 2, pad_h // 2],
1086
+ [0, 0]]
1087
+
1088
+ return out_shape, paddings
1089
+
1090
+ def calculate_wt_conv1d_transpose_unit(patch, wts, w, b, act):
1091
+ if patch.ndim == 1:
1092
+ patch = patch.reshape(1, -1)
1093
+ elif patch.ndim != 2:
1094
+ raise ValueError(f"Unexpected patch shape: {patch.shape}")
1095
+
1096
+ k = w.permute(0, 2, 1).numpy()
1097
+ bias = b.numpy()
1098
+ b_ind = bias > 0
1099
+ bias_pos = bias * b_ind
1100
+ b_ind = bias < 0
1101
+ bias_neg = bias * b_ind * -1.0
1102
+ conv_out = np.einsum('ijk,mj->ijk', k, patch)
1103
+ p_ind = conv_out > 0
1104
+ p_ind = conv_out * p_ind
1105
+ n_ind = conv_out < 0
1106
+ n_ind = conv_out * n_ind
1107
+
1108
+ p_sum = np.einsum("ijl->l", p_ind)
1109
+ n_sum = np.einsum("ijl->l", n_ind) * -1.0
1110
+ t_sum = p_sum + n_sum
1111
+
1112
+ wt_mat = np.zeros_like(k)
1113
+ p_saturate = p_sum > 0
1114
+ n_saturate = n_sum > 0
1115
+
1116
+ if act["type"] == 'mono':
1117
+ if act["range"]["l"]:
1118
+ p_saturate = t_sum > act["range"]["l"]
1119
+ if act["range"]["u"]:
1120
+ n_saturate = t_sum < act["range"]["u"]
1121
+ elif act["type"] == 'non_mono':
1122
+ t_act = act["func"](t_sum)
1123
+ p_act = act["func"](p_sum + bias_pos)
1124
+ n_act = act["func"](-1 * (n_sum + bias_neg))
1125
+ if act["range"]["l"]:
1126
+ temp_ind = t_sum > act["range"]["l"]
1127
+ p_saturate = p_saturate * temp_ind
1128
+ if act["range"]["u"]:
1129
+ temp_ind = t_sum < act["range"]["u"]
1130
+ n_saturate = n_saturate * temp_ind
1131
+ temp_ind = np.abs(t_act - p_act) > 1e-5
1132
+ n_saturate = n_saturate * temp_ind
1133
+ temp_ind = np.abs(t_act - n_act) > 1e-5
1134
+ p_saturate = p_saturate * temp_ind
1135
+
1136
+ p_agg_wt = (1.0 / (p_sum + n_sum + bias_pos + bias_neg)) * wts * p_saturate
1137
+ n_agg_wt = (1.0 / (p_sum + n_sum + bias_pos + bias_neg)) * wts * n_saturate
1138
+ wt_mat = wt_mat + (p_ind * p_agg_wt)
1139
+ wt_mat = wt_mat + (n_ind * n_agg_wt * -1.0)
1140
+ wt_mat = np.sum(wt_mat, axis=-1)
1141
+ return wt_mat
1142
+
1143
+ def calculate_wt_conv1d_transpose(wts, inp, w, b, padding, strides, dilation, act):
1144
+ wts = wts.T
1145
+ inp = inp.T
1146
+ w = w.T
1147
+ out_shape, paddings = calculate_output_padding_conv1d_transpose(inp.shape, w.shape, padding, strides, dilation)
1148
+ out_ds = np.zeros(out_shape + [w.shape[2]])
1149
+
1150
+ for ind in range(inp.shape[0]):
1151
+ out_ind = ind * strides
1152
+ tmp_patch = inp[ind, :]
1153
+ updates = calculate_wt_conv1d_transpose_unit(tmp_patch, wts[ind, :], w, b, act)
1154
+ end_ind = min(out_ind + w.shape[0], out_shape[0])
1155
+ valid_updates = updates[:end_ind - out_ind, :]
1156
+ out_ds[out_ind:end_ind, :] += valid_updates
1157
+
1158
+ if padding == 'same':
1159
+ adjusted_out_ds = np.zeros(inp.shape)
1160
+ for i in range(inp.shape[0]):
1161
+ start_i = max(0, i * strides)
1162
+ end_i = min(out_ds.shape[0], (i + 1) * strides)
1163
+ relevant_area = out_ds[start_i:end_i, :]
1164
+ adjusted_out_ds[i, :] = np.sum(relevant_area, axis=0)
1165
+ out_ds = adjusted_out_ds
1166
+ elif padding == 0:
1167
+ adjusted_out_ds = np.zeros(inp.shape)
1168
+ for i in range(inp.shape[0]):
1169
+ start_i = max(0, i * strides)
1170
+ end_i = min(out_ds.shape[0], (i + 1) * strides)
1171
+ relevant_area = out_ds[start_i:end_i, :]
1172
+ adjusted_out_ds[i, :] = np.sum(relevant_area, axis=0)
1173
+ out_ds = adjusted_out_ds
1174
+ else:
1175
+ out_ds = out_ds[paddings[0][0]:(paddings[0][0] + inp.shape[0]), :]
1176
+ return out_ds
1177
+
748
1178
 
749
1179
  ####################################################################
750
1180
  ################### Encoder Model ####################
@@ -753,27 +1183,85 @@ def stabilize(matrix, epsilon=1e-6):
753
1183
  return matrix + epsilon * np.sign(matrix)
754
1184
 
755
1185
 
756
- def calculate_relevance_V(wts, value_output):
757
- # Initialize wt_mat with zeros
758
- wt_mat_V = np.zeros((wts.shape[0], wts.shape[1], *value_output.shape))
1186
+ def calculate_wt_residual(wts, inp=None):
1187
+ wt_mat = []
1188
+ inp_list = []
1189
+ expanded_wts = as_strided(
1190
+ wts,
1191
+ shape=(np.prod(wts.shape),),
1192
+ strides=(wts.strides[-1],),
1193
+ writeable=False, # totally use this to avoid writing to memory in weird places
1194
+ )
1195
+
1196
+ for x in inp:
1197
+ expanded_input = as_strided(
1198
+ x,
1199
+ shape=(np.prod(x.shape),),
1200
+ strides=(x.strides[-1],),
1201
+ writeable=False, # totally use this to avoid writing to memory in weird places
1202
+ )
1203
+ inp_list.append(expanded_input)
1204
+ wt_mat.append(np.zeros_like(expanded_input))
1205
+ wt_mat = np.array(wt_mat)
1206
+ inp_list = np.array(inp_list)
1207
+ for i in range(wt_mat.shape[1]):
1208
+ wt_ind1 = wt_mat[:, i]
1209
+ wt = expanded_wts[i]
1210
+ l1_ind1 = inp_list[:, i]
1211
+ p_ind = l1_ind1 > 0
1212
+ n_ind = l1_ind1 < 0
1213
+ p_sum = np.sum(l1_ind1[p_ind])
1214
+ n_sum = np.sum(l1_ind1[n_ind]) * -1
1215
+ t_sum = p_sum - n_sum
1216
+ p_agg_wt = 0
1217
+ n_agg_wt = 0
1218
+ if p_sum + n_sum > 0:
1219
+ p_agg_wt = p_sum / (p_sum + n_sum)
1220
+ n_agg_wt = n_sum / (p_sum + n_sum)
1221
+ if p_sum == 0:
1222
+ p_sum = 1
1223
+ if n_sum == 0:
1224
+ n_sum = 1
1225
+ wt_ind1[p_ind] = (l1_ind1[p_ind] / p_sum) * wt * p_agg_wt
1226
+ wt_ind1[n_ind] = (l1_ind1[n_ind] / n_sum) * wt * n_agg_wt * -1.0
1227
+ wt_mat[:, i] = wt_ind1
1228
+ wt_mat = [i.reshape(wts.shape) for i in list(wt_mat)]
1229
+ return wt_mat
1230
+
1231
+
1232
+ def calculate_relevance_V(wts, value_output, w):
1233
+ wt_mat_V = np.zeros(value_output.shape)
1234
+
1235
+ if 'b_v' in w:
1236
+ bias_v = w['b_v']
1237
+ else:
1238
+ bias_v = 0
759
1239
 
760
1240
  for i in range(wts.shape[0]):
761
1241
  for j in range(wts.shape[1]):
762
1242
  l1_ind1 = value_output
763
- wt_ind1 = wt_mat_V[i, j]
764
1243
  wt = wts[i, j]
765
1244
 
766
1245
  p_ind = l1_ind1 > 0
767
1246
  n_ind = l1_ind1 < 0
768
1247
  p_sum = np.sum(l1_ind1[p_ind])
769
1248
  n_sum = np.sum(l1_ind1[n_ind]) * -1
1249
+
1250
+ if bias_v[i] > 0:
1251
+ pbias = bias_v[i]
1252
+ nbias = 0
1253
+ else:
1254
+ pbias = 0
1255
+ nbias = bias_v[i] * -1
770
1256
 
771
1257
  if p_sum > 0:
772
- p_agg_wt = p_sum / (p_sum + n_sum)
1258
+ p_agg_wt = (p_sum + pbias) / (p_sum + n_sum + pbias + nbias)
1259
+ p_agg_wt = p_agg_wt * (p_sum / (p_sum + pbias))
773
1260
  else:
774
1261
  p_agg_wt = 0
775
1262
  if n_sum > 0:
776
- n_agg_wt = n_sum / (p_sum + n_sum)
1263
+ n_agg_wt = (n_sum + nbias) / (p_sum + n_sum + pbias + nbias)
1264
+ n_agg_wt = n_agg_wt * (n_sum / (n_sum + nbias))
777
1265
  else:
778
1266
  n_agg_wt = 0
779
1267
 
@@ -782,21 +1270,22 @@ def calculate_relevance_V(wts, value_output):
782
1270
  if n_sum == 0:
783
1271
  n_sum = 1
784
1272
 
785
- wt_ind1[p_ind] = (l1_ind1[p_ind] / p_sum) * wt * p_agg_wt
786
- wt_ind1[n_ind] = (l1_ind1[n_ind] / n_sum) * wt * n_agg_wt * -1.0
1273
+ wt_mat_V[p_ind] += (l1_ind1[p_ind] / p_sum) * wt * p_agg_wt
1274
+ wt_mat_V[n_ind] += (l1_ind1[n_ind] / n_sum) * wt * n_agg_wt * -1.0
787
1275
 
788
- wt_mat_V = np.sum(wt_mat_V, axis=(0,1))
789
1276
  return wt_mat_V
790
1277
 
791
1278
 
792
- def calculate_relevance_QK(wts, QK_output):
793
- # Initialize wt_mat with zeros
794
- wt_mat_QK = np.zeros((wts.shape[0], wts.shape[1], *QK_output.shape))
1279
+ def calculate_relevance_QK(wts, QK_output, w):
1280
+ wt_mat_QK = np.zeros(QK_output.shape)
1281
+
1282
+ # Check if 'b_q' and 'b_k' exist in the weights, default to 0 if not
1283
+ b_q = w['b_q'] if 'b_q' in w else 0
1284
+ b_k = w['b_k'] if 'b_k' in w else 0
795
1285
 
796
1286
  for i in range(wts.shape[0]):
797
1287
  for j in range(wts.shape[1]):
798
1288
  l1_ind1 = QK_output
799
- wt_ind1 = wt_mat_QK[i, j]
800
1289
  wt = wts[i, j]
801
1290
 
802
1291
  p_ind = l1_ind1 > 0
@@ -804,7 +1293,21 @@ def calculate_relevance_QK(wts, QK_output):
804
1293
  p_sum = np.sum(l1_ind1[p_ind])
805
1294
  n_sum = np.sum(l1_ind1[n_ind]) * -1
806
1295
 
807
- t_sum = p_sum - n_sum
1296
+ if b_q[i] > 0 and b_k[i] > 0:
1297
+ pbias = b_q[i] + b_k[i]
1298
+ nbias = 0
1299
+ elif b_q[i] > 0 and b_k[i] < 0:
1300
+ pbias = b_q[i]
1301
+ nbias = b_k[i] * -1
1302
+ elif b_q[i] < 0 and b_k[i] > 0:
1303
+ pbias = b_k[i]
1304
+ nbias = b_q[i] * -1
1305
+ else:
1306
+ pbias = 0
1307
+ nbias = b_q[i] + b_k[i]
1308
+ nbias *= -1
1309
+
1310
+ t_sum = p_sum + pbias - n_sum - nbias
808
1311
 
809
1312
  # This layer has a softmax activation function
810
1313
  act = {
@@ -823,12 +1326,13 @@ def calculate_relevance_QK(wts, QK_output):
823
1326
  n_sum = 0
824
1327
 
825
1328
  if p_sum > 0:
826
- p_agg_wt = p_sum / (p_sum + n_sum)
1329
+ p_agg_wt = (p_sum + pbias) / (p_sum + n_sum + pbias + nbias)
1330
+ p_agg_wt = p_agg_wt * (p_sum / (p_sum + pbias))
827
1331
  else:
828
1332
  p_agg_wt = 0
829
-
830
1333
  if n_sum > 0:
831
- n_agg_wt = n_sum / (p_sum + n_sum)
1334
+ n_agg_wt = (n_sum + nbias) / (p_sum + n_sum + pbias + nbias)
1335
+ n_agg_wt = n_agg_wt * (n_sum / (n_sum + nbias))
832
1336
  else:
833
1337
  n_agg_wt = 0
834
1338
 
@@ -837,14 +1341,60 @@ def calculate_relevance_QK(wts, QK_output):
837
1341
  if n_sum == 0:
838
1342
  n_sum = 1
839
1343
 
840
- wt_ind1[p_ind] = (l1_ind1[p_ind] / p_sum) * wt * p_agg_wt
841
- wt_ind1[n_ind] = (l1_ind1[n_ind] / n_sum) * wt * n_agg_wt * -1.0
1344
+ wt_mat_QK[p_ind] += (l1_ind1[p_ind] / p_sum) * wt * p_agg_wt
1345
+ wt_mat_QK[n_ind] += (l1_ind1[n_ind] / n_sum) * wt * n_agg_wt * -1.0
842
1346
 
843
- wt_mat_QK = np.sum(wt_mat_QK, axis=(0, 1))
844
1347
  return wt_mat_QK
845
1348
 
846
1349
 
847
- def calculate_wt_self_attention(wts, inp, w):
1350
+ def calculate_wt_attention_output_projection(wts, proj_output, w):
1351
+ wt_mat_proj_output = np.zeros(proj_output.shape)
1352
+
1353
+ if 'b_d' in w:
1354
+ bias_d = w['b_d']
1355
+ else:
1356
+ bias_d = 0
1357
+
1358
+ for i in range(wts.shape[0]):
1359
+ for j in range(wts.shape[1]):
1360
+ l1_ind1 = proj_output
1361
+ wt = wts[i, j]
1362
+
1363
+ p_ind = l1_ind1 > 0
1364
+ n_ind = l1_ind1 < 0
1365
+ p_sum = np.sum(l1_ind1[p_ind])
1366
+ n_sum = np.sum(l1_ind1[n_ind]) * -1
1367
+
1368
+ if bias_d[i] > 0:
1369
+ pbias = bias_d[i]
1370
+ nbias = 0
1371
+ else:
1372
+ pbias = 0
1373
+ nbias = bias_d[i] * -1
1374
+
1375
+ if p_sum > 0:
1376
+ p_agg_wt = (p_sum + pbias) / (p_sum + n_sum + pbias + nbias)
1377
+ p_agg_wt = p_agg_wt * (p_sum / (p_sum + pbias))
1378
+ else:
1379
+ p_agg_wt = 0
1380
+ if n_sum > 0:
1381
+ n_agg_wt = (n_sum + nbias) / (p_sum + n_sum + pbias + nbias)
1382
+ n_agg_wt = n_agg_wt * (n_sum / (n_sum + nbias))
1383
+ else:
1384
+ n_agg_wt = 0
1385
+
1386
+ if p_sum == 0:
1387
+ p_sum = 1
1388
+ if n_sum == 0:
1389
+ n_sum = 1
1390
+
1391
+ wt_mat_proj_output[p_ind] += (l1_ind1[p_ind] / p_sum) * wt * p_agg_wt
1392
+ wt_mat_proj_output[n_ind] += (l1_ind1[n_ind] / n_sum) * wt * n_agg_wt * -1.0
1393
+
1394
+ return wt_mat_proj_output
1395
+
1396
+
1397
+ def calculate_wt_self_attention(wts, inp, w, config):
848
1398
  '''
849
1399
  Input:
850
1400
  wts: relevance score of the layer
@@ -856,28 +1406,82 @@ def calculate_wt_self_attention(wts, inp, w):
856
1406
  Step-2: outputs = F.softmax(inputs, dim=dim, dtype=dtype)
857
1407
  Step-3: outputs = input_a * input_b
858
1408
  '''
1409
+ # print(f"inp: {inp.shape}, wts: {wts.shape}") # (1, 512)
1410
+ # print(f"w['W_q']: {w['W_q'].shape}, w['W_k']: {w['W_k'].shape}, w['W_v']: {w['W_v'].shape}")
1411
+
859
1412
  query_output = np.einsum('ij,kj->ik', inp, w['W_q'])
860
1413
  key_output = np.einsum('ij,kj->ik', inp, w['W_k'])
861
1414
  value_output = np.einsum('ij,kj->ik', inp, w['W_v'])
862
1415
 
1416
+ # --------------- Reshape for Multi-Head Attention ----------------------
1417
+ num_heads = getattr(config, 'num_attention_heads', getattr(config, 'num_heads', None)) # will work for BERT as well as T5/ Llama
1418
+ hidden_size = getattr(config, 'hidden_size', getattr(config, 'd_model', None)) # will work for BERT as well as T5/Llama
1419
+ if hasattr(config, 'num_key_value_heads'):
1420
+ num_key_value_heads = config.num_key_value_heads
1421
+ else:
1422
+ num_key_value_heads = num_heads
1423
+ head_dim = hidden_size // num_heads # dimension of each attention head
1424
+
1425
+ query_states = np.einsum('thd->htd', query_output.reshape(query_output.shape[0], num_heads, head_dim)) # (num_heads, num_tokens, head_dim)
1426
+ key_states = np.einsum('thd->htd', key_output.reshape(key_output.shape[0], num_key_value_heads, head_dim)) # (num_key_value_heads, num_tokens, head_dim)
1427
+ value_states = np.einsum('thd->htd', value_output.reshape(value_output.shape[0], num_key_value_heads, head_dim)) # (num_key_value_heads, num_tokens, head_dim)
1428
+
1429
+ # calculate how many times we need to repeat the key/value heads
1430
+ n_rep = num_heads // num_key_value_heads
1431
+ key_states = np.repeat(key_states, n_rep, axis=0)
1432
+ value_states = np.repeat(value_states, n_rep, axis=0)
1433
+
1434
+ QK_output = np.einsum('hqd,hkd->hqk', query_states, key_states) # (num_heads, num_tokens, num_tokens)
1435
+ attn_weights = QK_output / np.sqrt(head_dim)
1436
+
1437
+ # Apply softmax along the last dimension (softmax over key dimension)
1438
+ attn_weights = np.exp(attn_weights - np.max(attn_weights, axis=-1, keepdims=True)) # Numerically stable softmax
1439
+ attn_weights = attn_weights / np.sum(attn_weights, axis=-1, keepdims=True)
1440
+
1441
+ # Weighted sum of values (num_heads, num_tokens, head_dim)
1442
+ attn_output = np.einsum('hqk,hkl->hql', attn_weights, value_states)
1443
+
1444
+ transposed_attn_output = np.einsum('hqd->qhd', attn_output)
1445
+ reshaped_attn_output = transposed_attn_output.reshape(transposed_attn_output.shape[0], num_heads * head_dim)
1446
+
1447
+ # Perform final linear projection (num_tokens, hidden_size)
1448
+ final_output = np.einsum('qd,dh->qh', reshaped_attn_output, w['W_d'])
1449
+
1450
+ # ------------- Relevance calculation for Final Linear Projection -------------
1451
+ wt_mat_attn_proj = calculate_wt_attention_output_projection(wts, final_output, w)
1452
+
863
1453
  # --------------- Relevance Calculation for Step-3 -----------------------
864
- relevance_V = wts / 2
865
- relevance_QK = wts / 2
1454
+ # divide the relevance among `attn_weights` and `value_states`
1455
+ wt_mat_attn_proj = wt_mat_attn_proj.reshape(-1, num_heads, head_dim)
1456
+ wt_mat_attn_proj = np.einsum('qhd->hqd', wt_mat_attn_proj)
1457
+
1458
+ stabilized_attn_output = stabilize(attn_output * 2)
1459
+ norm_wt_mat_attn_proj = wt_mat_attn_proj / stabilized_attn_output
1460
+ relevance_QK = np.einsum('htd,hbd->htb', norm_wt_mat_attn_proj, value_states) * attn_weights
1461
+ relevance_V = np.einsum('hdt,hdb->htb', attn_weights, norm_wt_mat_attn_proj) * value_states
866
1462
 
867
1463
  # --------------- Relevance Calculation for V --------------------------------
868
- wt_mat_V = calculate_relevance_V(relevance_V, value_output)
1464
+ relevance_V = np.einsum('hqd->qhd', relevance_V)
1465
+ relevance_V = relevance_V.reshape(-1, num_heads * head_dim)
1466
+ wt_mat_V = calculate_relevance_V(relevance_V, value_states, w)
869
1467
 
870
1468
  # --------------- Transformed Relevance QK ----------------------------------
871
- QK_output = np.einsum('ij,kj->ik', query_output, key_output)
872
- wt_mat_QK = calculate_relevance_QK(relevance_QK, QK_output)
1469
+ relevance_QK = np.einsum('hqd->qhd', relevance_QK)
1470
+ relevance_QK = relevance_QK.reshape(-1, relevance_QK.shape[1] * relevance_QK.shape[2])
1471
+ wt_mat_QK = calculate_relevance_QK(relevance_QK, QK_output, w)
873
1472
 
874
1473
  # --------------- Relevance Calculation for K and Q --------------------------------
875
1474
  stabilized_QK_output = stabilize(QK_output * 2)
876
1475
  norm_wt_mat_QK = wt_mat_QK / stabilized_QK_output
877
- wt_mat_Q = np.einsum('ij,jk->ik', norm_wt_mat_QK, key_output) * query_output
878
- wt_mat_K = np.einsum('ij,ik->kj', query_output, norm_wt_mat_QK) * key_output
1476
+ wt_mat_Q = np.einsum('htd,hdb->htb', norm_wt_mat_QK, key_states) * query_states
1477
+ wt_mat_K = np.einsum('htd,htb->hbd', query_states, norm_wt_mat_QK) * key_states
879
1478
 
880
1479
  wt_mat = wt_mat_V + wt_mat_K + wt_mat_Q
1480
+
1481
+ # Reshape wt_mat
1482
+ wt_mat = np.einsum('htd->thd', wt_mat)
1483
+ wt_mat = wt_mat.reshape(wt_mat.shape[0], wt_mat.shape[1] * wt_mat.shape[2]) # reshaped_array = array.reshape(8, 32 * 128)
1484
+
881
1485
  return wt_mat
882
1486
 
883
1487
 
@@ -893,7 +1497,9 @@ def calculate_wt_feed_forward(wts, inp, w):
893
1497
  R2 = wts[i]
894
1498
  contribution_matrix2 = np.einsum('ij,j->ij', w['W_out'], intermediate_output[i])
895
1499
  wt_mat2 = np.zeros(contribution_matrix2.shape)
896
-
1500
+
1501
+ bias_out = w['b_out'] if 'b_out' in w else 0
1502
+
897
1503
  for j in range(contribution_matrix2.shape[0]):
898
1504
  l1_ind1 = contribution_matrix2[j]
899
1505
  wt_ind1 = wt_mat2[j]
@@ -903,14 +1509,23 @@ def calculate_wt_feed_forward(wts, inp, w):
903
1509
  n_ind = l1_ind1 < 0
904
1510
  p_sum = np.sum(l1_ind1[p_ind])
905
1511
  n_sum = np.sum(l1_ind1[n_ind]) * -1
1512
+
1513
+ # Handle positive and negative bias contributions
1514
+ if bias_out[i] > 0:
1515
+ pbias = bias_out[i]
1516
+ nbias = 0
1517
+ else:
1518
+ pbias = 0
1519
+ nbias = -bias_out[i]
906
1520
 
907
1521
  if p_sum > 0:
908
- p_agg_wt = p_sum / (p_sum + n_sum)
1522
+ p_agg_wt = (p_sum + pbias) / (p_sum + n_sum + pbias + nbias)
1523
+ p_agg_wt = p_agg_wt * (p_sum / (p_sum + pbias))
909
1524
  else:
910
1525
  p_agg_wt = 0
911
-
912
1526
  if n_sum > 0:
913
- n_agg_wt = n_sum / (p_sum + n_sum)
1527
+ n_agg_wt = (n_sum + nbias) / (p_sum + n_sum + pbias + nbias)
1528
+ n_agg_wt = n_agg_wt * (n_sum / (n_sum + nbias))
914
1529
  else:
915
1530
  n_agg_wt = 0
916
1531
 
@@ -929,6 +1544,9 @@ def calculate_wt_feed_forward(wts, inp, w):
929
1544
  R1 = relevance_out[i]
930
1545
  contribution_matrix1 = np.einsum('ij,j->ij', w['W_int'], inp[i])
931
1546
  wt_mat1 = np.zeros(contribution_matrix1.shape)
1547
+
1548
+ # Check if bias 'b_int' exists, default to 0 if not
1549
+ bias_int = w['b_int'] if 'b_int' in w else 0
932
1550
 
933
1551
  for j in range(contribution_matrix1.shape[0]):
934
1552
  l1_ind1 = contribution_matrix1[j]
@@ -940,7 +1558,15 @@ def calculate_wt_feed_forward(wts, inp, w):
940
1558
  p_sum = np.sum(l1_ind1[p_ind])
941
1559
  n_sum = np.sum(l1_ind1[n_ind]) * -1
942
1560
 
943
- t_sum = p_sum - n_sum
1561
+ # Handle positive and negative bias
1562
+ if bias_int[i] > 0:
1563
+ pbias = bias_int[i]
1564
+ nbias = 0
1565
+ else:
1566
+ pbias = 0
1567
+ nbias = -bias_int[i]
1568
+
1569
+ t_sum = p_sum + pbias - n_sum - nbias
944
1570
 
945
1571
  # This layer has a ReLU activation function
946
1572
  act = {
@@ -959,12 +1585,13 @@ def calculate_wt_feed_forward(wts, inp, w):
959
1585
  n_sum = 0
960
1586
 
961
1587
  if p_sum > 0:
962
- p_agg_wt = p_sum / (p_sum + n_sum)
1588
+ p_agg_wt = (p_sum + pbias) / (p_sum + n_sum + pbias + nbias)
1589
+ p_agg_wt = p_agg_wt * (p_sum / (p_sum + pbias))
963
1590
  else:
964
1591
  p_agg_wt = 0
965
-
966
1592
  if n_sum > 0:
967
- n_agg_wt = n_sum / (p_sum + n_sum)
1593
+ n_agg_wt = (n_sum + nbias) / (p_sum + n_sum + pbias + nbias)
1594
+ n_agg_wt = n_agg_wt * (n_sum / (n_sum + nbias))
968
1595
  else:
969
1596
  n_agg_wt = 0
970
1597
 
@@ -1121,7 +1748,7 @@ def calculate_wt_pooler(wts, inp, w):
1121
1748
  # Calculate relevance for each token
1122
1749
  relevance_inp[i] = wt_mat.sum(axis=0)
1123
1750
 
1124
- relevance_inp *= (100 / np.sum(relevance_inp))
1751
+ relevance_inp *= (np.sum(wts) / np.sum(relevance_inp))
1125
1752
  return relevance_inp
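Note: the change just above replaces the fixed rescaling of the pooler relevance to a total of 100 with a conservation-style rescale, so the relevance leaving the layer sums to the relevance that entered it. A minimal illustration with made-up values:

    # Editor's sketch: conservation-style rescaling of propagated relevance.
    import numpy as np

    wts = np.array([0.6, 0.4])                 # relevance entering the layer (sums to 1.0)
    relevance_inp = np.array([0.3, 0.9, 0.3])  # hypothetical propagated relevance (sums to 1.5)
    relevance_inp = relevance_inp * (np.sum(wts) / np.sum(relevance_inp))
    # relevance_inp now sums to np.sum(wts) == 1.0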
1126
1753
 
1127
1754