dl-backtrace 0.0.18__py3-none-any.whl → 0.0.20.dev36__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dl-backtrace might be problematic. Click here for more details.
- dl_backtrace/pytorch_backtrace/backtrace/backtrace.py +194 -70
- dl_backtrace/pytorch_backtrace/backtrace/utils/contrast.py +607 -156
- dl_backtrace/pytorch_backtrace/backtrace/utils/prop.py +892 -265
- dl_backtrace/tf_backtrace/backtrace/backtrace.py +11 -7
- dl_backtrace/tf_backtrace/backtrace/utils/utils_prop.py +249 -47
- dl_backtrace/version.py +2 -2
- {dl_backtrace-0.0.18.dist-info → dl_backtrace-0.0.20.dev36.dist-info}/METADATA +1 -1
- {dl_backtrace-0.0.18.dist-info → dl_backtrace-0.0.20.dev36.dist-info}/RECORD +11 -11
- {dl_backtrace-0.0.18.dist-info → dl_backtrace-0.0.20.dev36.dist-info}/WHEEL +1 -1
- {dl_backtrace-0.0.18.dist-info → dl_backtrace-0.0.20.dev36.dist-info}/LICENSE +0 -0
- {dl_backtrace-0.0.18.dist-info → dl_backtrace-0.0.20.dev36.dist-info}/top_level.txt +0 -0
|
@@ -1,42 +1,32 @@
|
|
|
1
1
|
import gc
|
|
2
|
-
|
|
2
|
+
import torch
|
|
3
3
|
import numpy as np
|
|
4
|
-
import tensorflow as tf
|
|
5
4
|
from numpy.lib.stride_tricks import as_strided
|
|
6
|
-
from tensorflow.keras import backend as K
|
|
7
|
-
|
|
8
5
|
|
|
9
6
|
def np_swish(x, beta=0.75):
|
|
10
7
|
z = 1 / (1 + np.exp(-(beta * x)))
|
|
11
8
|
return x * z
|
|
12
9
|
|
|
13
|
-
|
|
14
10
|
def np_wave(x, alpha=1.0):
|
|
15
11
|
return (alpha * x * np.exp(1.0)) / (np.exp(-x) + np.exp(x))
|
|
16
12
|
|
|
17
|
-
|
|
18
13
|
def np_pulse(x, alpha=1.0):
|
|
19
14
|
return alpha * (1 - np.tanh(x) * np.tanh(x))
|
|
20
15
|
|
|
21
|
-
|
|
22
16
|
def np_absolute(x, alpha=1.0):
|
|
23
17
|
return alpha * x * np.tanh(x)
|
|
24
18
|
|
|
25
|
-
|
|
26
19
|
def np_hard_sigmoid(x):
|
|
27
20
|
return np.clip(0.2 * x + 0.5, 0, 1)
|
|
28
21
|
|
|
29
|
-
|
|
30
22
|
def np_sigmoid(x):
|
|
31
23
|
z = 1 / (1 + np.exp(-x))
|
|
32
24
|
return z
|
|
33
25
|
|
|
34
|
-
|
|
35
26
|
def np_tanh(x):
|
|
36
27
|
z = np.tanh(x)
|
|
37
28
|
return z.astype(np.float32)
|
|
38
29
|
|
|
39
|
-
|
|
40
30
|
class LSTM_forward(object):
|
|
41
31
|
def __init__(
|
|
42
32
|
self, num_cells, units, weights, return_sequence=False, go_backwards=False
|
|
@@ -48,8 +38,8 @@ class LSTM_forward(object):
|
|
|
48
38
|
self.bias = weights[2][1]
|
|
49
39
|
self.return_sequence = return_sequence
|
|
50
40
|
self.go_backwards = go_backwards
|
|
51
|
-
self.recurrent_activation =
|
|
52
|
-
self.activation =
|
|
41
|
+
self.recurrent_activation = torch.sigmoid()
|
|
42
|
+
self.activation = torch.tanh()
|
|
53
43
|
self.compute_log = {}
|
|
54
44
|
for i in range(self.num_cells):
|
|
55
45
|
self.compute_log[i] = {}
|
|
@@ -63,23 +53,19 @@ class LSTM_forward(object):
|
|
|
63
53
|
"""Computes carry and output using split kernels."""
|
|
64
54
|
x_i, x_f, x_c, x_o = x
|
|
65
55
|
h_tm1_i, h_tm1_f, h_tm1_c, h_tm1_o = h_tm1
|
|
66
|
-
|
|
67
|
-
#print(h_tm1_i.shape,self.recurrent_kernel[1][:, : self.units].shape)
|
|
68
|
-
w=tf.convert_to_tensor(self.recurrent_kernel[1], dtype=tf.float32)
|
|
69
|
-
#print(K.dot(h_tm1_i, w[:, : self.units]))
|
|
70
|
-
|
|
56
|
+
w=torch.as_tensor(self.recurrent_kernel[1], dtype=torch.float32)
|
|
71
57
|
i = self.recurrent_activation(
|
|
72
|
-
x_i +
|
|
58
|
+
x_i + torch.dot(h_tm1_i, w[:, : self.units])
|
|
73
59
|
)
|
|
74
60
|
f = self.recurrent_activation(
|
|
75
|
-
x_f +
|
|
61
|
+
x_f + torch.dot(h_tm1_f, w[:, self.units : self.units * 2])
|
|
76
62
|
)
|
|
77
63
|
c = f * c_tm1 + i * self.activation(
|
|
78
64
|
x_c
|
|
79
|
-
+
|
|
65
|
+
+ torch.dot(h_tm1_c, w[:, self.units * 2 : self.units * 3])
|
|
80
66
|
)
|
|
81
67
|
o = self.recurrent_activation(
|
|
82
|
-
x_o +
|
|
68
|
+
x_o + torch.dot(h_tm1_o, w[:, self.units * 3 :])
|
|
83
69
|
)
|
|
84
70
|
self.compute_log[cell_num]["int_arrays"]["i"] = i
|
|
85
71
|
self.compute_log[cell_num]["int_arrays"]["f"] = f
|
|
@@ -97,16 +83,16 @@ class LSTM_forward(object):
|
|
|
97
83
|
inputs_f = inputs
|
|
98
84
|
inputs_c = inputs
|
|
99
85
|
inputs_o = inputs
|
|
100
|
-
k_i, k_f, k_c, k_o =
|
|
101
|
-
x_i =
|
|
102
|
-
x_f =
|
|
103
|
-
x_c =
|
|
104
|
-
x_o =
|
|
105
|
-
b_i, b_f, b_c, b_o =
|
|
106
|
-
x_i =
|
|
107
|
-
x_f =
|
|
108
|
-
x_c =
|
|
109
|
-
x_o =
|
|
86
|
+
k_i, k_f, k_c, k_o = torch.split(self.kernel[1],self.kernel.size(1)//4,dim=1)
|
|
87
|
+
x_i = torch.dot(inputs_i, k_i)
|
|
88
|
+
x_f = torch.dot(inputs_f, k_f)
|
|
89
|
+
x_c = torch.dot(inputs_c, k_c)
|
|
90
|
+
x_o = torch.dot(inputs_o, k_o)
|
|
91
|
+
b_i, b_f, b_c, b_o = torch.split(self.bias,self.bias.size(1)//4,dim=0)
|
|
92
|
+
x_i = x_i + b_i
|
|
93
|
+
x_f = x_f + b_f
|
|
94
|
+
x_c = x_c + b_c
|
|
95
|
+
x_o = x_o + b_o
|
|
110
96
|
|
|
111
97
|
h_tm1_i = h_tm1
|
|
112
98
|
h_tm1_f = h_tm1
|
|
@@ -123,12 +109,12 @@ class LSTM_forward(object):
|
|
|
123
109
|
return h, [h, c]
|
|
124
110
|
|
|
125
111
|
def calculate_lstm_wt(self, input_data):
|
|
126
|
-
hstate =
|
|
127
|
-
cstate =
|
|
112
|
+
hstate = torch.tensor((1,self.units),dtype=torch.float32)
|
|
113
|
+
cstate = torch.tensor((1,self.units),dtype=torch.float32)
|
|
128
114
|
output = []
|
|
129
115
|
for ind in range(input_data.shape[0]):
|
|
130
|
-
inp =
|
|
131
|
-
input_data[ind, :].reshape((1, input_data.shape[1])), dtype=
|
|
116
|
+
inp = torch.tensor(
|
|
117
|
+
input_data[ind, :].reshape((1, input_data.shape[1])), dtype=torch.float32
|
|
132
118
|
)
|
|
133
119
|
h, s = self.calculate_lstm_cell_wt(inp, [hstate, cstate], ind)
|
|
134
120
|
hstate = s[0]
|
|
@@ -136,9 +122,6 @@ class LSTM_forward(object):
|
|
|
136
122
|
output.append(h)
|
|
137
123
|
return output
|
|
138
124
|
|
|
139
|
-
|
|
140
|
-
|
|
141
|
-
|
|
142
125
|
class LSTM_backtrace(object):
|
|
143
126
|
def __init__(
|
|
144
127
|
self, num_cells, units, weights, return_sequence=False, go_backwards=False
|
|
@@ -270,8 +253,6 @@ class LSTM_backtrace(object):
|
|
|
270
253
|
x_i, x_f, x_c, x_o = x
|
|
271
254
|
f = self.compute_log[cell_num]["int_arrays"]["f"].numpy()[0]
|
|
272
255
|
i = self.compute_log[cell_num]["int_arrays"]["i"].numpy()[0]
|
|
273
|
-
# o = self.recurrent_activation(
|
|
274
|
-
# x_o + np.dot(h_tm1_o, self.recurrent_kernel[:, self.units * 3:])).astype(np.float32)
|
|
275
256
|
temp1 = np.dot(h_tm1_o, self.recurrent_kernel[1][:, self.units * 3 :]).astype(
|
|
276
257
|
np.float32
|
|
277
258
|
)
|
|
@@ -283,9 +264,6 @@ class LSTM_backtrace(object):
|
|
|
283
264
|
[],
|
|
284
265
|
{"type": None},
|
|
285
266
|
)
|
|
286
|
-
|
|
287
|
-
# c = f * c_tm1 + i * self.activation(x_c + np.dot(
|
|
288
|
-
# h_tm1_c, self.recurrent_kernel[:, self.units * 2:self.units * 3])).astype(np.float32)
|
|
289
267
|
temp2 = f * c_tm1
|
|
290
268
|
temp3_1 = np.dot(
|
|
291
269
|
h_tm1_c, self.recurrent_kernel[1][:, self.units * 2 : self.units * 3]
|
|
@@ -303,9 +281,6 @@ class LSTM_backtrace(object):
|
|
|
303
281
|
[],
|
|
304
282
|
{"type": None},
|
|
305
283
|
)
|
|
306
|
-
|
|
307
|
-
# f = self.recurrent_activation(x_f + np.dot(
|
|
308
|
-
# h_tm1_f, self.recurrent_kernel[:, self.units:self.units * 2])).astype(np.float32)
|
|
309
284
|
temp4 = np.dot(h_tm1_f, self.recurrent_kernel[1][:, self.units : self.units * 2])
|
|
310
285
|
wt_x_f, wt_temp4 = self.calculate_wt_add(wt_f, [x_f, temp4])
|
|
311
286
|
wt_h_tm1_f = self.calculate_wt_fc(
|
|
@@ -315,9 +290,6 @@ class LSTM_backtrace(object):
|
|
|
315
290
|
[],
|
|
316
291
|
{"type": None},
|
|
317
292
|
)
|
|
318
|
-
|
|
319
|
-
# i = self.recurrent_activation(
|
|
320
|
-
# x_i + np.dot(h_tm1_i, self.recurrent_kernel[:, :self.units])).astype(np.float32)
|
|
321
293
|
temp5 = np.dot(h_tm1_i, self.recurrent_kernel[1][:, : self.units])
|
|
322
294
|
wt_x_i, wt_temp5 = self.calculate_wt_add(wt_i, [x_i, temp5])
|
|
323
295
|
wt_h_tm1_i = self.calculate_wt_fc(
|
|
@@ -364,7 +336,6 @@ class LSTM_backtrace(object):
|
|
|
364
336
|
wt_h_tm1 = wt_h_tm1_i + wt_h_tm1_f + wt_h_tm1_c + wt_h_tm1_o
|
|
365
337
|
inputs = self.compute_log[cell_num]["inp"].numpy()[0]
|
|
366
338
|
|
|
367
|
-
#print(np.split(self.kernel[1], indices_or_sections=4, axis=1))
|
|
368
339
|
k_i, k_f, k_c, k_o = np.split(self.kernel[1], indices_or_sections=4, axis=1)
|
|
369
340
|
b_i, b_f, b_c, b_o = np.split(self.bias[1], indices_or_sections=4, axis=0)
|
|
370
341
|
|
|
@@ -395,12 +366,10 @@ class LSTM_backtrace(object):
|
|
|
395
366
|
output.reverse()
|
|
396
367
|
return np.array(output)
|
|
397
368
|
|
|
398
|
-
|
|
399
369
|
def dummy_wt(wts, inp, *args):
|
|
400
370
|
test_wt = np.zeros_like(inp)
|
|
401
371
|
return test_wt
|
|
402
372
|
|
|
403
|
-
|
|
404
373
|
def calculate_wt_fc(wts, inp, w, b, act):
|
|
405
374
|
mul_mat = np.einsum("ij,i->ij", w.numpy().T, inp).T
|
|
406
375
|
wt_mat = np.zeros(mul_mat.shape)
|
|
@@ -461,12 +430,10 @@ def calculate_wt_fc(wts, inp, w, b, act):
|
|
|
461
430
|
wt_mat = wt_mat.sum(axis=0)
|
|
462
431
|
return wt_mat
|
|
463
432
|
|
|
464
|
-
|
|
465
433
|
def calculate_wt_rshp(wts, inp=None):
|
|
466
434
|
x = np.reshape(wts, inp.shape)
|
|
467
435
|
return x
|
|
468
436
|
|
|
469
|
-
|
|
470
437
|
def calculate_wt_concat(wts, inp=None, axis=-1):
|
|
471
438
|
wts=wts.T
|
|
472
439
|
splits = [i.shape[axis] for i in inp]
|
|
@@ -476,7 +443,6 @@ def calculate_wt_concat(wts, inp=None, axis=-1):
|
|
|
476
443
|
x = np.split(wts, indices_or_sections=splits, axis=axis)
|
|
477
444
|
return x
|
|
478
445
|
|
|
479
|
-
|
|
480
446
|
def calculate_wt_add(wts, inp=None):
|
|
481
447
|
wts=wts.T
|
|
482
448
|
wt_mat = []
|
|
@@ -523,199 +489,231 @@ def calculate_wt_add(wts, inp=None):
|
|
|
523
489
|
wt_mat = [i.reshape(wts.shape) for i in list(wt_mat)]
|
|
524
490
|
return wt_mat
|
|
525
491
|
|
|
526
|
-
|
|
527
|
-
|
|
528
|
-
|
|
529
|
-
|
|
530
|
-
|
|
492
|
+
def calculate_start_wt(arg, scaler=None,thresholding=0.5,task="binary-classification"):
|
|
493
|
+
if arg.ndim == 2:
|
|
494
|
+
if task == "binary-classification" or task == "multi-class classification":
|
|
495
|
+
x = np.argmax(arg[0])
|
|
496
|
+
m = np.max(arg[0])
|
|
497
|
+
y = np.zeros(arg.shape)
|
|
498
|
+
if scaler:
|
|
499
|
+
y[0][x] = scaler
|
|
500
|
+
else:
|
|
501
|
+
y[0][x] = m
|
|
502
|
+
elif task == "bbox-regression":
|
|
503
|
+
y = np.zeros(arg.shape)
|
|
504
|
+
if scaler:
|
|
505
|
+
y[0] = scaler
|
|
506
|
+
num_non_zero_elements = np.count_nonzero(y)
|
|
507
|
+
if num_non_zero_elements > 0:
|
|
508
|
+
y = y / num_non_zero_elements
|
|
509
|
+
else:
|
|
510
|
+
m = np.max(arg[0])
|
|
511
|
+
x = np.argmax(arg[0])
|
|
512
|
+
y[0][x] = m
|
|
513
|
+
else:
|
|
514
|
+
x = np.argmax(arg[0])
|
|
515
|
+
m = np.max(arg[0])
|
|
516
|
+
y = np.zeros(arg.shape)
|
|
517
|
+
if scaler:
|
|
518
|
+
y[0][x] = scaler
|
|
519
|
+
else:
|
|
520
|
+
y[0][x] = m
|
|
521
|
+
|
|
522
|
+
elif arg.ndim == 4 and task == "binary-segmentation":
|
|
523
|
+
indices = np.where(arg > thresholding)
|
|
524
|
+
y = np.zeros(arg.shape)
|
|
525
|
+
if scaler:
|
|
526
|
+
y[indices] = scaler
|
|
527
|
+
num_non_zero_elements = np.count_nonzero(y)
|
|
528
|
+
if num_non_zero_elements > 0:
|
|
529
|
+
y = y / num_non_zero_elements
|
|
530
|
+
else:
|
|
531
|
+
y[indices] = arg[indices]
|
|
532
|
+
|
|
533
|
+
else:
|
|
534
|
+
x = np.argmax(arg[0])
|
|
535
|
+
m = np.max(arg[0])
|
|
536
|
+
y = np.zeros(arg.shape)
|
|
537
|
+
if scaler:
|
|
538
|
+
y[0][x] = scaler
|
|
539
|
+
else:
|
|
540
|
+
y[0][x] = m
|
|
531
541
|
return y[0]
|
|
532
542
|
|
|
533
|
-
|
|
534
543
|
def calculate_wt_passthru(wts):
|
|
535
544
|
return wts
|
|
545
|
+
def calculate_wt_zero_pad(wts,inp,padding):
|
|
546
|
+
wt_mat = wts[padding[0][0]:inp.shape[0]+padding[0][0],padding[1][0]:inp.shape[1]+padding[1][0],:]
|
|
547
|
+
return wt_mat
|
|
536
548
|
|
|
549
|
+
def calculate_padding(kernel_size, inp, padding, strides, const_val=0.0):
|
|
550
|
+
if padding=='valid':
|
|
551
|
+
return (inp, [[0,0],[0,0],[0,0]])
|
|
552
|
+
elif padding == 'same':
|
|
553
|
+
h = inp.shape[0]%strides[0]
|
|
554
|
+
if h==0:
|
|
555
|
+
pad_h = np.max([0,kernel_size[0]-strides[0]])
|
|
556
|
+
else:
|
|
557
|
+
pad_h = np.max([0,kernel_size[0]-h])
|
|
537
558
|
|
|
538
|
-
|
|
539
|
-
|
|
540
|
-
|
|
559
|
+
v = inp.shape[1]%strides[1]
|
|
560
|
+
if v==0:
|
|
561
|
+
pad_v = np.max([0,kernel_size[1]-strides[1]])
|
|
562
|
+
else:
|
|
563
|
+
pad_v = np.max([0,kernel_size[1]-v])
|
|
564
|
+
|
|
565
|
+
paddings = [np.floor([pad_h/2.0,(pad_h+1)/2.0]).astype("int32"),
|
|
566
|
+
np.floor([pad_v/2.0,(pad_v+1)/2.0]).astype("int32"),
|
|
567
|
+
np.zeros((2)).astype("int32")]
|
|
568
|
+
inp_pad = np.pad(inp, paddings, 'constant', constant_values=const_val)
|
|
569
|
+
return (inp_pad,paddings)
|
|
570
|
+
else:
|
|
571
|
+
if isinstance(padding, tuple) and padding != (None, None):
|
|
572
|
+
pad_h = padding[0]
|
|
573
|
+
pad_v = padding[1]
|
|
574
|
+
paddings = [np.floor([pad_h,pad_h]).astype("int32"),
|
|
575
|
+
np.floor([pad_v,pad_v]).astype("int32"),
|
|
576
|
+
np.zeros((2)).astype("int32")]
|
|
577
|
+
inp_pad = np.pad(inp, paddings, 'constant', constant_values=const_val)
|
|
578
|
+
return (inp_pad,paddings)
|
|
579
|
+
else:
|
|
580
|
+
return (inp, [[0,0],[0,0],[0,0]])
|
|
581
|
+
|
|
582
|
+
def calculate_wt_conv_unit(patch, wts, w, b, act):
|
|
583
|
+
k = w.numpy()
|
|
584
|
+
bias = b.numpy()
|
|
585
|
+
b_ind = bias>0
|
|
586
|
+
bias_pos = bias*b_ind
|
|
587
|
+
b_ind = bias<0
|
|
588
|
+
bias_neg = bias*b_ind*-1.0
|
|
589
|
+
conv_out = np.einsum("ijkl,ijk->ijkl",k,patch)
|
|
590
|
+
p_ind = conv_out>0
|
|
591
|
+
p_ind = conv_out*p_ind
|
|
592
|
+
p_sum = np.einsum("ijkl->l",p_ind)
|
|
593
|
+
n_ind = conv_out<0
|
|
594
|
+
n_ind = conv_out*n_ind
|
|
595
|
+
n_sum = np.einsum("ijkl->l",n_ind)*-1.0
|
|
596
|
+
t_sum = p_sum+n_sum
|
|
597
|
+
wt_mat = np.zeros_like(k)
|
|
598
|
+
p_saturate = p_sum>0
|
|
599
|
+
n_saturate = n_sum>0
|
|
600
|
+
if act["type"]=='mono':
|
|
541
601
|
if act["range"]["l"]:
|
|
542
|
-
|
|
543
|
-
|
|
602
|
+
temp_ind = t_sum > act["range"]["l"]
|
|
603
|
+
p_saturate = temp_ind
|
|
544
604
|
if act["range"]["u"]:
|
|
545
|
-
|
|
546
|
-
|
|
547
|
-
elif act["type"]
|
|
605
|
+
temp_ind = t_sum < act["range"]["u"]
|
|
606
|
+
n_saturate = temp_ind
|
|
607
|
+
elif act["type"]=='non_mono':
|
|
548
608
|
t_act = act["func"](t_sum)
|
|
549
|
-
p_act = act["func"](p_sum)
|
|
550
|
-
n_act = act["func"](n_sum)
|
|
609
|
+
p_act = act["func"](p_sum + bias_pos)
|
|
610
|
+
n_act = act["func"](-1*(n_sum + bias_neg))
|
|
551
611
|
if act["range"]["l"]:
|
|
552
|
-
|
|
553
|
-
|
|
612
|
+
temp_ind = t_sum > act["range"]["l"]
|
|
613
|
+
p_saturate = p_saturate*temp_ind
|
|
554
614
|
if act["range"]["u"]:
|
|
555
|
-
|
|
556
|
-
|
|
557
|
-
|
|
558
|
-
|
|
559
|
-
|
|
560
|
-
|
|
561
|
-
|
|
562
|
-
|
|
563
|
-
|
|
564
|
-
|
|
565
|
-
|
|
566
|
-
|
|
567
|
-
if p_sum == 0.0:
|
|
568
|
-
p_sum = 1.0
|
|
569
|
-
if n_sum == 0.0:
|
|
570
|
-
n_sum = 1.0
|
|
571
|
-
wt_mat = wt_mat + ((p_mat / p_sum) * wt * p_agg_wt)
|
|
572
|
-
wt_mat = wt_mat + ((n_mat / n_sum) * wt * n_agg_wt * -1.0)
|
|
615
|
+
temp_ind = t_sum < act["range"]["u"]
|
|
616
|
+
n_saturate = n_saturate*temp_ind
|
|
617
|
+
temp_ind = np.abs(t_act - p_act)>1e-5
|
|
618
|
+
n_saturate = n_saturate*temp_ind
|
|
619
|
+
temp_ind = np.abs(t_act - n_act)>1e-5
|
|
620
|
+
p_saturate = p_saturate*temp_ind
|
|
621
|
+
p_agg_wt = (1.0/(p_sum+n_sum+bias_pos+bias_neg))*wts*p_saturate
|
|
622
|
+
n_agg_wt = (1.0/(p_sum+n_sum+bias_pos+bias_neg))*wts*n_saturate
|
|
623
|
+
|
|
624
|
+
wt_mat = wt_mat+(p_ind*p_agg_wt)
|
|
625
|
+
wt_mat = wt_mat+(n_ind*n_agg_wt*-1.0)
|
|
626
|
+
wt_mat = np.sum(wt_mat,axis=-1)
|
|
573
627
|
return wt_mat
|
|
574
628
|
|
|
575
|
-
|
|
576
|
-
|
|
577
|
-
|
|
578
|
-
|
|
579
|
-
|
|
580
|
-
|
|
581
|
-
|
|
582
|
-
shape
|
|
583
|
-
|
|
584
|
-
|
|
585
|
-
|
|
586
|
-
|
|
587
|
-
|
|
588
|
-
|
|
589
|
-
|
|
590
|
-
|
|
591
|
-
|
|
592
|
-
|
|
593
|
-
|
|
594
|
-
|
|
595
|
-
|
|
596
|
-
|
|
597
|
-
|
|
598
|
-
|
|
599
|
-
|
|
600
|
-
|
|
601
|
-
)
|
|
602
|
-
|
|
603
|
-
|
|
604
|
-
|
|
605
|
-
x = np.einsum(
|
|
606
|
-
"abcmn,mnc->abcmn", expanded_input, kernel, order="C", optimize=True
|
|
607
|
-
)
|
|
608
|
-
x_pos = x.copy()
|
|
609
|
-
x_neg = x.copy()
|
|
610
|
-
x_pos[x < 0] = 0
|
|
611
|
-
x_neg[x > 0] = 0
|
|
612
|
-
x_sum = np.einsum("abcmn->ab", x, order="C", optimize=True)
|
|
613
|
-
x_p_sum = np.einsum("abcmn->ab", x_pos, order="C", optimize=True)
|
|
614
|
-
x_n_sum = np.einsum("abcmn->ab", x_neg, order="C", optimize=True) * -1.0
|
|
615
|
-
# print(np.sum(x),np.sum(x_pos),np.sum(x_neg),np.sum(x_n_sum))
|
|
616
|
-
for ind1 in range(expanded_input.shape[0]):
|
|
617
|
-
for ind2 in range(expanded_input.shape[1]):
|
|
618
|
-
temp_wt_mat = calculate_wt_conv_unit(
|
|
619
|
-
wts[ind1, ind2, k],
|
|
620
|
-
x_pos[ind1, ind2, :, :, :],
|
|
621
|
-
x_neg[ind1, ind2, :, :, :],
|
|
622
|
-
x_sum[ind1, ind2],
|
|
623
|
-
x_p_sum[ind1, ind2],
|
|
624
|
-
x_n_sum[ind1, ind2],
|
|
625
|
-
act,
|
|
626
|
-
)
|
|
627
|
-
test_wt[
|
|
628
|
-
:, ind1 : ind1 + kernel.shape[0], ind2 : ind2 + kernel.shape[1]
|
|
629
|
-
] += temp_wt_mat
|
|
630
|
-
test_wt = np.einsum("cmn->mnc", test_wt, order="C", optimize=True)
|
|
631
|
-
gc.collect()
|
|
632
|
-
return test_wt
|
|
633
|
-
|
|
634
|
-
|
|
635
|
-
def get_max_index(mat=None):
|
|
636
|
-
max_ind = np.argmax(mat)
|
|
637
|
-
ind = []
|
|
638
|
-
rem = max_ind
|
|
639
|
-
for i in mat.shape[:-1]:
|
|
640
|
-
ind.append(rem // i)
|
|
641
|
-
rem = rem % i
|
|
642
|
-
ind.append(rem)
|
|
643
|
-
return tuple(ind)
|
|
644
|
-
|
|
645
|
-
|
|
646
|
-
def calculate_wt_maxpool(wts, inp, pool_size):
|
|
629
|
+
def calculate_wt_conv(wts, inp, w, b, padding, strides, act):
|
|
630
|
+
wts = wts.T
|
|
631
|
+
inp = inp.T
|
|
632
|
+
w = w.T
|
|
633
|
+
input_padded, paddings = calculate_padding(w.shape, inp, padding, strides)
|
|
634
|
+
out_ds = np.zeros_like(input_padded)
|
|
635
|
+
for ind1 in range(wts.shape[0]):
|
|
636
|
+
for ind2 in range(wts.shape[1]):
|
|
637
|
+
indexes = [np.arange(ind1*strides[0], ind1*(strides[0])+w.shape[0]),
|
|
638
|
+
np.arange(ind2*strides[1], ind2*(strides[1])+w.shape[1])]
|
|
639
|
+
# Take slice
|
|
640
|
+
tmp_patch = input_padded[np.ix_(indexes[0],indexes[1])]
|
|
641
|
+
updates = calculate_wt_conv_unit(tmp_patch, wts[ind1,ind2,:], w, b, act)
|
|
642
|
+
# Build tensor with "filtered" gradient
|
|
643
|
+
out_ds[np.ix_(indexes[0],indexes[1])]+=updates
|
|
644
|
+
out_ds = out_ds[paddings[0][0]:(paddings[0][0]+inp.shape[0]),
|
|
645
|
+
paddings[1][0]:(paddings[1][0]+inp.shape[1]),:]
|
|
646
|
+
return out_ds
|
|
647
|
+
|
|
648
|
+
|
|
649
|
+
def calculate_wt_max_unit(patch, wts, pool_size):
|
|
650
|
+
pmax = np.einsum("ijk,k->ijk",np.ones_like(patch),np.max(np.max(patch,axis=0),axis=0))
|
|
651
|
+
indexes = (patch-pmax)==0
|
|
652
|
+
indexes = indexes.astype(np.float32)
|
|
653
|
+
indexes_norm = 1.0/np.einsum("mnc->c",indexes)
|
|
654
|
+
indexes = np.einsum("ijk,k->ijk",indexes,indexes_norm)
|
|
655
|
+
out = np.einsum("ijk,k->ijk",indexes,wts)
|
|
656
|
+
return out
|
|
657
|
+
|
|
658
|
+
def calculate_wt_maxpool(wts, inp, pool_size, padding, strides):
|
|
647
659
|
wts=wts.T
|
|
648
660
|
inp=inp.T
|
|
649
|
-
|
|
650
|
-
|
|
651
|
-
|
|
652
|
-
|
|
653
|
-
|
|
654
|
-
|
|
655
|
-
|
|
656
|
-
|
|
657
|
-
|
|
658
|
-
|
|
659
|
-
|
|
660
|
-
|
|
661
|
-
|
|
662
|
-
|
|
663
|
-
|
|
664
|
-
|
|
665
|
-
|
|
666
|
-
|
|
667
|
-
|
|
668
|
-
|
|
669
|
-
|
|
670
|
-
|
|
671
|
-
|
|
661
|
+
strides = (strides,strides)
|
|
662
|
+
padding = (padding,padding)
|
|
663
|
+
input_padded, paddings = calculate_padding(pool_size, inp, padding, strides, -np.inf)
|
|
664
|
+
out_ds = np.zeros_like(input_padded)
|
|
665
|
+
for ind1 in range(wts.shape[0]):
|
|
666
|
+
for ind2 in range(wts.shape[1]):
|
|
667
|
+
indexes = [np.arange(ind1*strides[0], ind1*(strides[0])+pool_size[0]),
|
|
668
|
+
np.arange(ind2*strides[1], ind2*(strides[1])+pool_size[1])]
|
|
669
|
+
tmp_patch = input_padded[np.ix_(indexes[0],indexes[1])]
|
|
670
|
+
updates = calculate_wt_max_unit(tmp_patch, wts[ind1,ind2,:], pool_size)
|
|
671
|
+
out_ds[np.ix_(indexes[0],indexes[1])]+=updates
|
|
672
|
+
out_ds = out_ds[paddings[0][0]:(paddings[0][0]+inp.shape[0]),
|
|
673
|
+
paddings[1][0]:(paddings[1][0]+inp.shape[1]),:]
|
|
674
|
+
return out_ds
|
|
675
|
+
|
|
676
|
+
|
|
677
|
+
def calculate_wt_avg_unit(patch, wts, pool_size):
|
|
678
|
+
p_ind = patch>0
|
|
679
|
+
p_ind = patch*p_ind
|
|
680
|
+
p_sum = np.einsum("ijk->k",p_ind)
|
|
681
|
+
n_ind = patch<0
|
|
682
|
+
n_ind = patch*n_ind
|
|
683
|
+
n_sum = np.einsum("ijk->k",n_ind)*-1.0
|
|
684
|
+
t_sum = p_sum+n_sum
|
|
685
|
+
wt_mat = np.zeros_like(patch)
|
|
686
|
+
p_saturate = p_sum>0
|
|
687
|
+
n_saturate = n_sum>0
|
|
688
|
+
t_sum[t_sum==0] = 1.0
|
|
689
|
+
p_agg_wt = (1.0/(t_sum))*wts*p_saturate
|
|
690
|
+
n_agg_wt = (1.0/(t_sum))*wts*n_saturate
|
|
691
|
+
wt_mat = wt_mat+(p_ind*p_agg_wt)
|
|
692
|
+
wt_mat = wt_mat+(n_ind*n_agg_wt*-1.0)
|
|
693
|
+
return wt_mat
|
|
672
694
|
|
|
673
|
-
def calculate_wt_avgpool(wts, inp, pool_size):
|
|
695
|
+
def calculate_wt_avgpool(wts, inp, pool_size, padding, strides):
|
|
674
696
|
wts=wts.T
|
|
675
697
|
inp=inp.T
|
|
676
698
|
|
|
677
699
|
pad1 = pool_size[0]
|
|
678
700
|
pad2 = pool_size[1]
|
|
679
|
-
|
|
680
|
-
|
|
681
|
-
|
|
682
|
-
|
|
683
|
-
|
|
684
|
-
for
|
|
685
|
-
|
|
686
|
-
|
|
687
|
-
|
|
688
|
-
|
|
689
|
-
|
|
690
|
-
|
|
691
|
-
|
|
692
|
-
|
|
693
|
-
|
|
694
|
-
|
|
695
|
-
]
|
|
696
|
-
wt = wt_mat[ind1, ind2]
|
|
697
|
-
p_ind = temp_inp > 0
|
|
698
|
-
n_ind = temp_inp < 0
|
|
699
|
-
p_sum = np.sum(temp_inp[p_ind])
|
|
700
|
-
n_sum = np.sum(temp_inp[n_ind]) * -1
|
|
701
|
-
if p_sum > 0:
|
|
702
|
-
p_agg_wt = p_sum / (p_sum + n_sum)
|
|
703
|
-
else:
|
|
704
|
-
p_agg_wt = 0
|
|
705
|
-
if n_sum > 0:
|
|
706
|
-
n_agg_wt = n_sum / (p_sum + n_sum)
|
|
707
|
-
else:
|
|
708
|
-
n_agg_wt = 0
|
|
709
|
-
if p_sum == 0:
|
|
710
|
-
p_sum = 1
|
|
711
|
-
if n_sum == 0:
|
|
712
|
-
n_sum = 1
|
|
713
|
-
wt_ind1[p_ind] += (temp_inp[p_ind] / p_sum) * wt * p_agg_wt
|
|
714
|
-
wt_ind1[n_ind] += (temp_inp[n_ind] / n_sum) * wt * n_agg_wt * -1.0
|
|
715
|
-
test_wt = test_wt[0 : inp.shape[0], 0 : inp.shape[1], :]
|
|
716
|
-
return test_wt
|
|
717
|
-
|
|
718
|
-
|
|
701
|
+
strides = (strides,strides)
|
|
702
|
+
padding = (padding,padding)
|
|
703
|
+
input_padded, paddings = calculate_padding(pool_size, inp, padding, strides, -np.inf)
|
|
704
|
+
out_ds = np.zeros_like(input_padded)
|
|
705
|
+
for ind1 in range(wts.shape[0]):
|
|
706
|
+
for ind2 in range(wts.shape[1]):
|
|
707
|
+
indexes = [np.arange(ind1*strides[0], ind1*(strides[0])+pool_size[0]),
|
|
708
|
+
np.arange(ind2*strides[1], ind2*(strides[1])+pool_size[1])]
|
|
709
|
+
# Take slice
|
|
710
|
+
tmp_patch = input_padded[np.ix_(indexes[0],indexes[1])]
|
|
711
|
+
updates = calculate_wt_avg_unit(tmp_patch, wts[ind1,ind2,:], pool_size)
|
|
712
|
+
# Build tensor with "filtered" gradient
|
|
713
|
+
out_ds[np.ix_(indexes[0],indexes[1])]+=updates
|
|
714
|
+
out_ds = out_ds[paddings[0][0]:(paddings[0][0]+inp.shape[0]),
|
|
715
|
+
paddings[1][0]:(paddings[1][0]+inp.shape[1]),:]
|
|
716
|
+
return out_ds
|
|
719
717
|
def calculate_wt_gavgpool(wts, inp):
|
|
720
718
|
wts=wts.T
|
|
721
719
|
inp=inp.T
|
|
@@ -745,6 +743,438 @@ def calculate_wt_gavgpool(wts, inp):
|
|
|
745
743
|
wt_mat[..., c] = temp_wt
|
|
746
744
|
return wt_mat
|
|
747
745
|
|
|
746
|
+
def calculate_wt_gmaxpool_2d(wts, inp):
|
|
747
|
+
channels = wts.shape[0]
|
|
748
|
+
wt_mat = np.zeros_like(inp)
|
|
749
|
+
for c in range(channels):
|
|
750
|
+
wt = wts[c]
|
|
751
|
+
x = inp[..., c]
|
|
752
|
+
max_val = np.max(x)
|
|
753
|
+
max_indexes = (x == max_val).astype(np.float32)
|
|
754
|
+
max_indexes_norm = 1.0 / np.sum(max_indexes)
|
|
755
|
+
max_indexes = max_indexes * max_indexes_norm
|
|
756
|
+
wt_mat[..., c] = max_indexes * wt
|
|
757
|
+
return wt_mat
|
|
758
|
+
|
|
759
|
+
def calculate_padding_1d(kernel_size, inp, padding, strides, const_val=0.0):
|
|
760
|
+
if padding == 'valid':
|
|
761
|
+
return inp, [[0, 0],[0,0]]
|
|
762
|
+
elif padding == 0:
|
|
763
|
+
return inp, [[0, 0],[0,0]]
|
|
764
|
+
elif isinstance(padding, int):
|
|
765
|
+
inp_pad = np.pad(inp, ((padding, padding), (0,0)), 'constant', constant_values=const_val)
|
|
766
|
+
return inp_pad, [[padding, padding],[0,0]]
|
|
767
|
+
else:
|
|
768
|
+
remainder = inp.shape[0] % strides
|
|
769
|
+
if remainder == 0:
|
|
770
|
+
pad_total = max(0, kernel_size - strides)
|
|
771
|
+
else:
|
|
772
|
+
pad_total = max(0, kernel_size - remainder)
|
|
773
|
+
|
|
774
|
+
pad_left = int(np.floor(pad_total / 2.0))
|
|
775
|
+
pad_right = int(np.ceil(pad_total / 2.0))
|
|
776
|
+
|
|
777
|
+
inp_pad = np.pad(inp, ((pad_left, pad_right),(0,0)), 'constant', constant_values=const_val)
|
|
778
|
+
return inp_pad, [[pad_left, pad_right],[0,0]]
|
|
779
|
+
|
|
780
|
+
def calculate_wt_conv_unit_1d(patch, wts, w, b, act):
|
|
781
|
+
k = w.numpy()
|
|
782
|
+
bias = b.numpy()
|
|
783
|
+
b_ind = bias > 0
|
|
784
|
+
bias_pos = bias * b_ind
|
|
785
|
+
b_ind = bias < 0
|
|
786
|
+
bias_neg = bias * b_ind * -1.0
|
|
787
|
+
conv_out = np.einsum("ijk,ij->ijk", k, patch)
|
|
788
|
+
p_ind = conv_out > 0
|
|
789
|
+
p_ind = conv_out * p_ind
|
|
790
|
+
p_sum = np.einsum("ijk->k",p_ind)
|
|
791
|
+
n_ind = conv_out < 0
|
|
792
|
+
n_ind = conv_out * n_ind
|
|
793
|
+
n_sum = np.einsum("ijk->k",n_ind) * -1.0
|
|
794
|
+
t_sum = p_sum + n_sum
|
|
795
|
+
wt_mat = np.zeros_like(k)
|
|
796
|
+
p_saturate = p_sum > 0
|
|
797
|
+
n_saturate = n_sum > 0
|
|
798
|
+
if act["type"] == 'mono':
|
|
799
|
+
if act["range"]["l"]:
|
|
800
|
+
temp_ind = t_sum > act["range"]["l"]
|
|
801
|
+
p_saturate = temp_ind
|
|
802
|
+
if act["range"]["u"]:
|
|
803
|
+
temp_ind = t_sum < act["range"]["u"]
|
|
804
|
+
n_saturate = temp_ind
|
|
805
|
+
elif act["type"] == 'non_mono':
|
|
806
|
+
t_act = act["func"](t_sum)
|
|
807
|
+
p_act = act["func"](p_sum + bias_pos)
|
|
808
|
+
n_act = act["func"](-1 * (n_sum + bias_neg))
|
|
809
|
+
if act["range"]["l"]:
|
|
810
|
+
temp_ind = t_sum > act["range"]["l"]
|
|
811
|
+
p_saturate = p_saturate * temp_ind
|
|
812
|
+
if act["range"]["u"]:
|
|
813
|
+
temp_ind = t_sum < act["range"]["u"]
|
|
814
|
+
n_saturate = n_saturate * temp_ind
|
|
815
|
+
temp_ind = np.abs(t_act - p_act) > 1e-5
|
|
816
|
+
n_saturate = n_saturate * temp_ind
|
|
817
|
+
temp_ind = np.abs(t_act - n_act) > 1e-5
|
|
818
|
+
p_saturate = p_saturate * temp_ind
|
|
819
|
+
p_agg_wt = (1.0 / (p_sum + n_sum + bias_pos + bias_neg)) * wts * p_saturate
|
|
820
|
+
n_agg_wt = (1.0 / (p_sum + n_sum + bias_pos + bias_neg)) * wts * n_saturate
|
|
821
|
+
|
|
822
|
+
wt_mat = wt_mat + (p_ind * p_agg_wt)
|
|
823
|
+
wt_mat = wt_mat + (n_ind * n_agg_wt * -1.0)
|
|
824
|
+
wt_mat = np.sum(wt_mat, axis=-1)
|
|
825
|
+
return wt_mat
|
|
826
|
+
|
|
827
|
+
def calculate_wt_conv_1d(wts, inp, w, b, padding, stride, act):
|
|
828
|
+
wts = wts.T
|
|
829
|
+
inp = inp.T
|
|
830
|
+
w = w.T
|
|
831
|
+
stride=stride
|
|
832
|
+
input_padded, paddings = calculate_padding_1d(w.shape[0], inp, padding, stride)
|
|
833
|
+
out_ds = np.zeros_like(input_padded)
|
|
834
|
+
for ind in range(wts.shape[0]):
|
|
835
|
+
indexes = np.arange(ind * stride, ind * stride + w.shape[0])
|
|
836
|
+
tmp_patch = input_padded[indexes]
|
|
837
|
+
updates = calculate_wt_conv_unit_1d(tmp_patch, wts[ind, :], w, b, act)
|
|
838
|
+
out_ds[indexes] += updates
|
|
839
|
+
out_ds = out_ds[paddings[0][0]:(paddings[0][0] + inp.shape[0])]
|
|
840
|
+
return out_ds
|
|
841
|
+
|
|
842
|
+
def calculate_wt_max_unit_1d(patch, wts):
|
|
843
|
+
pmax = np.max(patch, axis=0)
|
|
844
|
+
indexes = (patch - pmax) == 0
|
|
845
|
+
indexes = indexes.astype(np.float32)
|
|
846
|
+
indexes_norm = 1.0 / np.sum(indexes, axis=0)
|
|
847
|
+
indexes = np.einsum("ij,j->ij", indexes, indexes_norm)
|
|
848
|
+
out = np.einsum("ij,j->ij", indexes, wts)
|
|
849
|
+
return out
|
|
850
|
+
|
|
851
|
+
def calculate_wt_maxpool_1d(wts, inp, pool_size, padding, stride):
|
|
852
|
+
inp = inp.T
|
|
853
|
+
wts = wts.T
|
|
854
|
+
input_padded, paddings = calculate_padding_1d(pool_size, inp, padding, stride, -np.inf)
|
|
855
|
+
out_ds = np.zeros_like(input_padded)
|
|
856
|
+
stride=stride
|
|
857
|
+
pool_size=pool_size
|
|
858
|
+
for ind in range(wts.shape[0]):
|
|
859
|
+
indexes = np.arange(ind * stride, ind * stride + pool_size)
|
|
860
|
+
tmp_patch = input_padded[indexes]
|
|
861
|
+
updates = calculate_wt_max_unit_1d(tmp_patch, wts[ind, :])
|
|
862
|
+
out_ds[indexes] += updates
|
|
863
|
+
out_ds = out_ds[paddings[0][0]:(paddings[0][0] + inp.shape[0])]
|
|
864
|
+
return out_ds
|
|
865
|
+
|
|
866
|
+
def calculate_wt_avg_unit_1d(patch, wts):
|
|
867
|
+
p_ind = patch > 0
|
|
868
|
+
p_ind = patch * p_ind
|
|
869
|
+
p_sum = np.sum(p_ind, axis=0)
|
|
870
|
+
n_ind = patch < 0
|
|
871
|
+
n_ind = patch * n_ind
|
|
872
|
+
n_sum = np.sum(n_ind, axis=0) * -1.0
|
|
873
|
+
t_sum = p_sum + n_sum
|
|
874
|
+
wt_mat = np.zeros_like(patch)
|
|
875
|
+
p_saturate = p_sum > 0
|
|
876
|
+
n_saturate = n_sum > 0
|
|
877
|
+
t_sum[t_sum == 0] = 1.0
|
|
878
|
+
p_agg_wt = (1.0 / t_sum) * wts * p_saturate
|
|
879
|
+
n_agg_wt = (1.0 / t_sum) * wts * n_saturate
|
|
880
|
+
wt_mat = wt_mat + (p_ind * p_agg_wt)
|
|
881
|
+
wt_mat = wt_mat + (n_ind * n_agg_wt * -1.0)
|
|
882
|
+
return wt_mat
|
|
883
|
+
|
|
884
|
+
def calculate_wt_avgpool_1d(wts, inp, pool_size, padding, stride):
|
|
885
|
+
wts = wts.T
|
|
886
|
+
inp = inp.T
|
|
887
|
+
stride=stride
|
|
888
|
+
pool_size=pool_size
|
|
889
|
+
input_padded, paddings = calculate_padding_1d(pool_size, inp, padding[0], stride[0], 0)
|
|
890
|
+
out_ds = np.zeros_like(input_padded)
|
|
891
|
+
for ind in range(wts.shape[0]):
|
|
892
|
+
indexes = np.arange(ind * stride[0], ind * stride[0] + pool_size[0])
|
|
893
|
+
tmp_patch = input_padded[indexes]
|
|
894
|
+
updates = calculate_wt_avg_unit_1d(tmp_patch, wts[ind, :])
|
|
895
|
+
out_ds[indexes] += updates
|
|
896
|
+
out_ds = out_ds[paddings[0][0]:(paddings[0][0] + inp.shape[0])]
|
|
897
|
+
return out_ds
|
|
898
|
+
|
|
899
|
+
def calculate_wt_gavgpool_1d(wts, inp):
|
|
900
|
+
channels = wts.shape[0]
|
|
901
|
+
wt_mat = np.zeros_like(inp)
|
|
902
|
+
for c in range(channels):
|
|
903
|
+
wt = wts[c]
|
|
904
|
+
temp_wt = wt_mat[:, c]
|
|
905
|
+
x = inp[:, c]
|
|
906
|
+
p_mat = np.copy(x)
|
|
907
|
+
n_mat = np.copy(x)
|
|
908
|
+
p_mat[p_mat < 0] = 0
|
|
909
|
+
n_mat[n_mat > 0] = 0
|
|
910
|
+
p_sum = np.sum(p_mat)
|
|
911
|
+
n_sum = np.sum(n_mat) * -1
|
|
912
|
+
p_agg_wt = 0.0
|
|
913
|
+
n_agg_wt = 0.0
|
|
914
|
+
if p_sum + n_sum > 0.0:
|
|
915
|
+
p_agg_wt = p_sum / (p_sum + n_sum)
|
|
916
|
+
n_agg_wt = n_sum / (p_sum + n_sum)
|
|
917
|
+
if p_sum == 0.0:
|
|
918
|
+
p_sum = 1.0
|
|
919
|
+
if n_sum == 0.0:
|
|
920
|
+
n_sum = 1.0
|
|
921
|
+
temp_wt = temp_wt + ((p_mat / p_sum) * wt * p_agg_wt)
|
|
922
|
+
temp_wt = temp_wt + ((n_mat / n_sum) * wt * n_agg_wt * -1.0)
|
|
923
|
+
wt_mat[:, c] = temp_wt
|
|
924
|
+
return wt_mat
|
|
925
|
+
|
|
926
|
+
def calculate_wt_gmaxpool_1d(wts, inp):
|
|
927
|
+
wts = wts.T
|
|
928
|
+
inp = inp.T
|
|
929
|
+
channels = wts.shape[0]
|
|
930
|
+
wt_mat = np.zeros_like(inp)
|
|
931
|
+
for c in range(channels):
|
|
932
|
+
wt = wts[c]
|
|
933
|
+
x = inp[:, c]
|
|
934
|
+
max_val = np.max(x)
|
|
935
|
+
max_indexes = (x == max_val).astype(np.float32)
|
|
936
|
+
max_indexes_norm = 1.0 / np.sum(max_indexes)
|
|
937
|
+
max_indexes = max_indexes * max_indexes_norm
|
|
938
|
+
wt_mat[:, c] = max_indexes * wt
|
|
939
|
+
return wt_mat
|
|
940
|
+
|
|
941
|
+
def calculate_output_padding_conv2d_transpose(input_shape, kernel_size, padding, strides):
    """Compute the canvas size and per-axis padding for a 2-D transposed conv.

    Args:
        input_shape: sequence whose first two entries are the spatial sizes.
        kernel_size: sequence whose first two entries are the kernel sizes.
        padding: ``'valid'`` or ``(0, 0)`` (no cropping), a tuple
            ``(pad_h, pad_v)`` other than ``(None, None)`` (explicit
            symmetric padding), or anything else, treated as ``'same'``.
        strides: sequence of two spatial strides.

    Returns:
        ``(out_shape, paddings)``: ``out_shape`` is a two-element list and
        ``paddings`` is ``[[before_h, after_h], [before_v, after_v], [0, 0]]``.
    """
    in_h, in_w = input_shape[0], input_shape[1]
    step_h, step_w = strides[0], strides[1]

    if padding == 'valid' or padding == (0, 0):
        # Every input cell expands to a full kernel footprint; nothing is cropped.
        full_shape = [(in_h - 1) * step_h + kernel_size[0],
                      (in_w - 1) * step_w + kernel_size[1]]
        return full_shape, [[0, 0], [0, 0], [0, 0]]

    scaled_shape = [in_h * step_h, in_w * step_w]

    if isinstance(padding, tuple) and padding != (None, None):
        top, side = padding[0], padding[1]
        return scaled_shape, [[top, top], [side, side], [0, 0]]

    # 'same': split the overshoot of the full footprint, extra unit at the end.
    extra_h = max(0, (in_h - 1) * step_h + kernel_size[0] - scaled_shape[0])
    extra_w = max(0, (in_w - 1) * step_w + kernel_size[1] - scaled_shape[1])
    return scaled_shape, [[extra_h // 2, extra_h - extra_h // 2],
                          [extra_w // 2, extra_w - extra_w // 2],
                          [0, 0]]
|
|
964
|
+
|
|
965
|
+
def calculate_wt_conv2d_transpose_unit(patch, wts, w, b, act):
    """Split the relevance of one transposed-conv position over its kernel footprint.

    Args:
        patch: activations at one spatial position. Rank-1 and rank-2 arrays
            are left-padded with singleton axes to rank 3.
        wts: relevance for this position; broadcast against the per-channel
            sums taken over the last axis of the permuted kernel.
        w: 4-D tensor exposing ``.permute``/``.numpy`` (a torch tensor);
            its axes 2 and 3 are swapped before use.
        b: bias tensor exposing ``.numpy``; split into positive and negative parts.
        act: activation descriptor with ``"type"`` (``'mono'`` or
            ``'non_mono'``), ``"range"`` (``{"l": ..., "u": ...}``) and, for
            ``'non_mono'``, a callable ``"func"``.

    Returns:
        numpy array with the permuted kernel's shape minus its last axis.

    Raises:
        ValueError: if ``patch`` is rank 0 or has rank greater than 3.
    """
    if patch.ndim not in (1, 2, 3):
        raise ValueError(f"Unexpected patch shape: {patch.shape}")
    patch = patch.reshape((1,) * (3 - patch.ndim) + patch.shape)

    kernel = w.permute(0, 1, 3, 2).numpy()
    bias = b.numpy()
    bias_up = bias * (bias > 0)
    bias_down = bias * (bias < 0) * -1.0

    contrib = np.einsum('ijkl,mnk->ijkl', kernel, patch)
    contrib_pos = contrib * (contrib > 0)
    contrib_neg = contrib * (contrib < 0)
    mass_pos = np.einsum("ijkl->l", contrib_pos)
    mass_neg = np.einsum("ijkl->l", contrib_neg) * -1.0
    mass_total = mass_pos + mass_neg

    gate_pos = mass_pos > 0
    gate_neg = mass_neg > 0
    kind = act["type"]
    if kind == 'mono':
        lo, hi = act["range"]["l"], act["range"]["u"]
        if lo:
            gate_pos = mass_total > lo
        if hi:
            gate_neg = mass_total < hi
    elif kind == 'non_mono':
        fn = act["func"]
        total_out = fn(mass_total)
        pos_out = fn(mass_pos + bias_up)
        neg_out = fn(-1 * (mass_neg + bias_down))
        lo, hi = act["range"]["l"], act["range"]["u"]
        if lo:
            gate_pos = gate_pos * (mass_total > lo)
        if hi:
            gate_neg = gate_neg * (mass_total < hi)
        # A side whose removal would not change the activation gets no share.
        gate_neg = gate_neg * (np.abs(total_out - pos_out) > 1e-5)
        gate_pos = gate_pos * (np.abs(total_out - neg_out) > 1e-5)

    reciprocal = 1.0 / (mass_pos + mass_neg + bias_up + bias_down)
    share_pos = reciprocal * wts * gate_pos
    share_neg = reciprocal * wts * gate_neg
    return np.sum(contrib_pos * share_pos + contrib_neg * share_neg * -1.0, axis=-1)
|
|
1021
|
+
|
|
1022
|
+
def calculate_wt_conv2d_transpose(wts, inp, w, b, padding, strides, act):
    """Propagate relevance through a 2-D transposed convolution.

    ``wts``, ``inp`` and ``w`` are transposed (all axes reversed) first, so
    below ``inp`` and ``wts`` are indexed ``[row, col, channel]``. For every
    input position, the split computed by
    ``calculate_wt_conv2d_transpose_unit`` is added onto an output canvas at
    offset ``(row * strides[0], col * strides[1])``, clipped to the canvas.

    The canvas is then mapped back onto the transposed ``inp`` grid:
    for ``padding == 'same'`` or a tuple padding other than ``(None, None)``
    each ``strides[0] x strides[1]`` cell is summed into one position;
    otherwise the canvas is cropped at the leading offsets returned by
    ``calculate_output_padding_conv2d_transpose``.

    Args:
        wts: relevance array (transposed before use).
        inp: input activations (transposed before use).
        w: kernel weights (transposed before use and passed to the unit helper).
        b: bias, passed through to the unit helper.
        padding: ``'valid'``/``'same'``/tuple, as accepted by
            ``calculate_output_padding_conv2d_transpose``.
        strides: two spatial strides.
        act: activation descriptor passed through to the unit helper.

    Returns:
        numpy array laid out like the transposed ``inp`` spatially.
    """
    wts, inp, w = wts.T, inp.T, w.T
    out_shape, paddings = calculate_output_padding_conv2d_transpose(inp.shape, w.shape, padding, strides)
    kernel_h, kernel_w = w.shape[0], w.shape[1]
    canvas = np.zeros(out_shape + [w.shape[3]])

    n_rows, n_cols = inp.shape[0], inp.shape[1]
    for row in range(n_rows):
        top = row * strides[0]
        bottom = min(top + kernel_h, out_shape[0])
        for col in range(n_cols):
            left = col * strides[1]
            right = min(left + kernel_w, out_shape[1])
            piece = calculate_wt_conv2d_transpose_unit(inp[row, col, :], wts[row, col, :], w, b, act)
            canvas[top:bottom, left:right, :] += piece[:bottom - top, :right - left, :]

    if padding == 'same' or (isinstance(padding, tuple) and padding != (None, None)):
        return _fold_stride_cells_2d(canvas, inp.shape, strides)

    row0, col0 = paddings[0][0], paddings[1][0]
    return canvas[row0:(row0 + n_rows), col0:(col0 + n_cols), :]


def _fold_stride_cells_2d(canvas, target_shape, strides):
    """Sum each ``strides[0] x strides[1]`` cell of ``canvas`` into one entry of a ``target_shape`` array."""
    folded = np.zeros(target_shape)
    for i in range(target_shape[0]):
        r0 = max(0, i * strides[0])
        r1 = min(canvas.shape[0], (i + 1) * strides[0])
        for j in range(target_shape[1]):
            c0 = max(0, j * strides[1])
            c1 = min(canvas.shape[1], (j + 1) * strides[1])
            folded[i, j, :] = np.sum(canvas[r0:r1, c0:c1, :], axis=(0, 1))
    return folded
|
|
1067
|
+
|
|
1068
|
+
|
|
1069
|
+
def calculate_output_padding_conv1d_transpose(input_shape, kernel_size, padding, strides, dilation):
    """Compute the canvas length and padding for a 1-D transposed convolution.

    Args:
        input_shape: sequence whose first entry is the input length.
        kernel_size: sequence whose first entry is the kernel length.
        padding: ``'valid'`` or ``0`` (no cropping), an ``int`` (explicit
            padding), or anything else, treated as ``'same'``.
        strides: integer stride.
        dilation: integer dilation; only used for integer padding.

    Returns:
        ``(out_shape, paddings)``: ``out_shape`` is a one-element list and
        ``paddings`` is ``[[before, after], [0, 0]]``.
    """
    length = input_shape[0]

    if padding == 'valid' or padding == 0:
        return [(length - 1) * strides + kernel_size[0]], [[0, 0], [0, 0]]

    if isinstance(padding, int):
        # NOTE(review): pad_v can be negative when padding > dilation * (k - 1).
        pad_v = (dilation * (kernel_size[0] - 1)) - padding
        return [length * strides + pad_v], [[pad_v, pad_v], [0, 0]]

    # 'same': split the overshoot; the extra unit of an odd total goes at the
    # end, matching calculate_output_padding_conv2d_transpose (the previous
    # ``[pad_h // 2, pad_h // 2]`` silently dropped it).
    out_len = length * strides
    pad_h = max(0, (length - 1) * strides + kernel_size[0] - out_len)
    return [out_len], [[pad_h // 2, pad_h - pad_h // 2], [0, 0]]
|
|
1089
|
+
|
|
1090
|
+
def calculate_wt_conv1d_transpose_unit(patch, wts, w, b, act):
    """Split the relevance of one 1-D transposed-conv position over its kernel footprint.

    Args:
        patch: activations at one position; a rank-1 array is reshaped to
            ``(1, -1)``, a rank-2 array is used as is.
        wts: relevance for this position; broadcast against the per-channel
            sums taken over the last axis of the permuted kernel.
        w: 3-D tensor exposing ``.permute``/``.numpy`` (a torch tensor);
            its axes 1 and 2 are swapped before use.
        b: bias tensor exposing ``.numpy``; split into positive and negative parts.
        act: activation descriptor with ``"type"`` (``'mono'`` or
            ``'non_mono'``), ``"range"`` (``{"l": ..., "u": ...}``) and, for
            ``'non_mono'``, a callable ``"func"``.

    Returns:
        numpy array with the permuted kernel's shape minus its last axis.

    Raises:
        ValueError: if ``patch`` is neither rank 1 nor rank 2.
    """
    if patch.ndim not in (1, 2):
        raise ValueError(f"Unexpected patch shape: {patch.shape}")
    if patch.ndim == 1:
        patch = patch.reshape(1, -1)

    kernel = w.permute(0, 2, 1).numpy()
    bias = b.numpy()
    positive_bias = bias * (bias > 0)
    negative_bias = bias * (bias < 0) * -1.0

    terms = np.einsum('ijk,mj->ijk', kernel, patch)
    pos_terms = terms * (terms > 0)
    neg_terms = terms * (terms < 0)
    pos_mass = np.einsum("ijl->l", pos_terms)
    neg_mass = np.einsum("ijl->l", neg_terms) * -1.0
    net = pos_mass + neg_mass

    keep_pos = pos_mass > 0
    keep_neg = neg_mass > 0
    if act["type"] == 'mono':
        bounds = act["range"]
        if bounds["l"]:
            keep_pos = net > bounds["l"]
        if bounds["u"]:
            keep_neg = net < bounds["u"]
    elif act["type"] == 'non_mono':
        net_act = act["func"](net)
        pos_act = act["func"](pos_mass + positive_bias)
        neg_act = act["func"](-1 * (neg_mass + negative_bias))
        bounds = act["range"]
        if bounds["l"]:
            keep_pos = keep_pos * (net > bounds["l"])
        if bounds["u"]:
            keep_neg = keep_neg * (net < bounds["u"])
        # A side whose removal would not change the activation gets no share.
        keep_neg = keep_neg * (np.abs(net_act - pos_act) > 1e-5)
        keep_pos = keep_pos * (np.abs(net_act - neg_act) > 1e-5)

    inv_total = 1.0 / (pos_mass + neg_mass + positive_bias + negative_bias)
    pos_scale = inv_total * wts * keep_pos
    neg_scale = inv_total * wts * keep_neg
    return np.sum(pos_terms * pos_scale + neg_terms * neg_scale * -1.0, axis=-1)
|
|
1142
|
+
|
|
1143
|
+
def calculate_wt_conv1d_transpose(wts, inp, w, b, padding, strides, dilation, act):
    """Propagate relevance through a 1-D transposed convolution.

    ``wts``, ``inp`` and ``w`` are transposed (all axes reversed) first, so
    below ``inp`` and ``wts`` are indexed ``[position, channel]``. For every
    input position, the split computed by
    ``calculate_wt_conv1d_transpose_unit`` is added onto an output canvas at
    offset ``position * strides``, clipped to the canvas length.

    The canvas is then mapped back onto the transposed ``inp`` grid:
    for ``padding == 'same'`` or ``padding == 0`` each ``strides``-long cell
    is summed into one position; otherwise the canvas is cropped at the
    leading offset from ``calculate_output_padding_conv1d_transpose``.

    Args:
        wts: relevance array (transposed before use).
        inp: input activations (transposed before use).
        w: kernel weights (transposed before use and passed to the unit helper).
        b: bias, passed through to the unit helper.
        padding: ``'valid'``/``'same'``/int, as accepted by
            ``calculate_output_padding_conv1d_transpose``.
        strides: integer stride.
        dilation: integer dilation, forwarded to the padding helper.
        act: activation descriptor passed through to the unit helper.

    Returns:
        numpy array laid out like the transposed ``inp``.
    """
    wts = wts.T
    inp = inp.T
    w = w.T
    out_shape, paddings = calculate_output_padding_conv1d_transpose(inp.shape, w.shape, padding, strides, dilation)
    kernel_len = w.shape[0]
    out_ds = np.zeros(out_shape + [w.shape[2]])

    for ind in range(inp.shape[0]):
        out_ind = ind * strides
        end_ind = min(out_ind + kernel_len, out_shape[0])
        if end_ind <= out_ind:
            # The footprint lies entirely past the canvas (possible when a
            # large integer padding shrinks out_shape below len * stride);
            # slicing ``updates`` with a negative length would otherwise
            # raise a broadcast error.
            continue
        updates = calculate_wt_conv1d_transpose_unit(inp[ind, :], wts[ind, :], w, b, act)
        out_ds[out_ind:end_ind, :] += updates[:end_ind - out_ind, :]

    if padding == 'same' or padding == 0:
        # Collapse each stride-sized cell of the canvas back onto one input position.
        folded = np.zeros(inp.shape)
        for i in range(inp.shape[0]):
            start_i = max(0, i * strides)
            end_i = min(out_ds.shape[0], (i + 1) * strides)
            folded[i, :] = np.sum(out_ds[start_i:end_i, :], axis=0)
        return folded

    return out_ds[paddings[0][0]:(paddings[0][0] + inp.shape[0]), :]
|
|
1177
|
+
|
|
748
1178
|
|
|
749
1179
|
####################################################################
|
|
750
1180
|
################### Encoder Model ####################
|
|
@@ -753,27 +1183,85 @@ def stabilize(matrix, epsilon=1e-6):
|
|
|
753
1183
|
return matrix + epsilon * np.sign(matrix)
|
|
754
1184
|
|
|
755
1185
|
|
|
756
|
-
def
|
|
757
|
-
|
|
758
|
-
|
|
1186
|
+
def calculate_wt_residual(wts, inp=None):
    """Split the relevance of a residual (element-wise sum) node among its inputs.

    For every element position, the relevance ``wts`` at that position is
    divided between the inputs:

    * inputs with a positive value share ``p_sum / (p_sum + n_sum)`` of it;
    * inputs with a negative value share ``n_sum / (p_sum + n_sum)`` of it.

    Within each group the share is proportional to the value's magnitude.

    Args:
        wts: relevance of the residual output.
        inp: iterable of input arrays, each holding ``wts.size`` elements.
            Despite the ``None`` default it must be supplied.

    Returns:
        list with one array per input, each reshaped to ``wts.shape``; an
        empty list when ``inp`` is empty. Floating inputs keep their
        (promoted) dtype; integer inputs yield float64.
    """
    # np.ravel follows the actual memory layout, so transposed or sliced
    # views are flattened in logical C order. Flattening via
    # as_strided(strides=(x.strides[-1],)) is only valid for C-contiguous arrays.
    flat_wts = np.ravel(wts)
    flat_inputs = [np.ravel(x) for x in inp]
    if not flat_inputs:
        return []

    stacked = np.stack(flat_inputs)  # (num_inputs, wts.size)
    out_dtype = stacked.dtype if np.issubdtype(stacked.dtype, np.floating) else np.float64

    pos_mask = stacked > 0
    neg_mask = stacked < 0
    p_sum = np.sum(np.where(pos_mask, stacked, 0), axis=0)
    n_sum = -np.sum(np.where(neg_mask, stacked, 0), axis=0)
    total = p_sum + n_sum
    with np.errstate(divide='ignore', invalid='ignore'):
        # Values computed where total == 0 are discarded by np.where.
        p_agg_wt = np.where(total > 0, p_sum / total, 0.0)
        n_agg_wt = np.where(total > 0, n_sum / total, 0.0)
    # Avoid 0/0 for positions without positive (or negative) entries.
    p_norm = np.where(p_sum == 0, 1, p_sum)
    n_norm = np.where(n_sum == 0, 1, n_sum)

    relevance = np.where(pos_mask, (stacked / p_norm) * flat_wts * p_agg_wt, 0.0)
    relevance = np.where(neg_mask, (stacked / n_norm) * flat_wts * n_agg_wt * -1.0, relevance)
    relevance = relevance.astype(out_dtype, copy=False)
    return [row.reshape(np.shape(wts)) for row in relevance]
|
|
1230
|
+
|
|
1231
|
+
|
|
1232
|
+
def calculate_relevance_V(wts, value_output, w):
|
|
1233
|
+
wt_mat_V = np.zeros(value_output.shape)
|
|
1234
|
+
|
|
1235
|
+
if 'b_v' in w:
|
|
1236
|
+
bias_v = w['b_v']
|
|
1237
|
+
else:
|
|
1238
|
+
bias_v = 0
|
|
759
1239
|
|
|
760
1240
|
for i in range(wts.shape[0]):
|
|
761
1241
|
for j in range(wts.shape[1]):
|
|
762
1242
|
l1_ind1 = value_output
|
|
763
|
-
wt_ind1 = wt_mat_V[i, j]
|
|
764
1243
|
wt = wts[i, j]
|
|
765
1244
|
|
|
766
1245
|
p_ind = l1_ind1 > 0
|
|
767
1246
|
n_ind = l1_ind1 < 0
|
|
768
1247
|
p_sum = np.sum(l1_ind1[p_ind])
|
|
769
1248
|
n_sum = np.sum(l1_ind1[n_ind]) * -1
|
|
1249
|
+
|
|
1250
|
+
if bias_v[i] > 0:
|
|
1251
|
+
pbias = bias_v[i]
|
|
1252
|
+
nbias = 0
|
|
1253
|
+
else:
|
|
1254
|
+
pbias = 0
|
|
1255
|
+
nbias = bias_v[i] * -1
|
|
770
1256
|
|
|
771
1257
|
if p_sum > 0:
|
|
772
|
-
p_agg_wt = p_sum / (p_sum + n_sum)
|
|
1258
|
+
p_agg_wt = (p_sum + pbias) / (p_sum + n_sum + pbias + nbias)
|
|
1259
|
+
p_agg_wt = p_agg_wt * (p_sum / (p_sum + pbias))
|
|
773
1260
|
else:
|
|
774
1261
|
p_agg_wt = 0
|
|
775
1262
|
if n_sum > 0:
|
|
776
|
-
n_agg_wt = n_sum / (p_sum + n_sum)
|
|
1263
|
+
n_agg_wt = (n_sum + nbias) / (p_sum + n_sum + pbias + nbias)
|
|
1264
|
+
n_agg_wt = n_agg_wt * (n_sum / (n_sum + nbias))
|
|
777
1265
|
else:
|
|
778
1266
|
n_agg_wt = 0
|
|
779
1267
|
|
|
@@ -782,21 +1270,22 @@ def calculate_relevance_V(wts, value_output):
|
|
|
782
1270
|
if n_sum == 0:
|
|
783
1271
|
n_sum = 1
|
|
784
1272
|
|
|
785
|
-
|
|
786
|
-
|
|
1273
|
+
wt_mat_V[p_ind] += (l1_ind1[p_ind] / p_sum) * wt * p_agg_wt
|
|
1274
|
+
wt_mat_V[n_ind] += (l1_ind1[n_ind] / n_sum) * wt * n_agg_wt * -1.0
|
|
787
1275
|
|
|
788
|
-
wt_mat_V = np.sum(wt_mat_V, axis=(0,1))
|
|
789
1276
|
return wt_mat_V
|
|
790
1277
|
|
|
791
1278
|
|
|
792
|
-
def calculate_relevance_QK(wts, QK_output):
|
|
793
|
-
|
|
794
|
-
|
|
1279
|
+
def calculate_relevance_QK(wts, QK_output, w):
|
|
1280
|
+
wt_mat_QK = np.zeros(QK_output.shape)
|
|
1281
|
+
|
|
1282
|
+
# Check if 'b_q' and 'b_k' exist in the weights, default to 0 if not
|
|
1283
|
+
b_q = w['b_q'] if 'b_q' in w else 0
|
|
1284
|
+
b_k = w['b_k'] if 'b_k' in w else 0
|
|
795
1285
|
|
|
796
1286
|
for i in range(wts.shape[0]):
|
|
797
1287
|
for j in range(wts.shape[1]):
|
|
798
1288
|
l1_ind1 = QK_output
|
|
799
|
-
wt_ind1 = wt_mat_QK[i, j]
|
|
800
1289
|
wt = wts[i, j]
|
|
801
1290
|
|
|
802
1291
|
p_ind = l1_ind1 > 0
|
|
@@ -804,7 +1293,21 @@ def calculate_relevance_QK(wts, QK_output):
|
|
|
804
1293
|
p_sum = np.sum(l1_ind1[p_ind])
|
|
805
1294
|
n_sum = np.sum(l1_ind1[n_ind]) * -1
|
|
806
1295
|
|
|
807
|
-
|
|
1296
|
+
if b_q[i] > 0 and b_k[i] > 0:
|
|
1297
|
+
pbias = b_q[i] + b_k[i]
|
|
1298
|
+
nbias = 0
|
|
1299
|
+
elif b_q[i] > 0 and b_k[i] < 0:
|
|
1300
|
+
pbias = b_q[i]
|
|
1301
|
+
nbias = b_k[i] * -1
|
|
1302
|
+
elif b_q[i] < 0 and b_k[i] > 0:
|
|
1303
|
+
pbias = b_k[i]
|
|
1304
|
+
nbias = b_q[i] * -1
|
|
1305
|
+
else:
|
|
1306
|
+
pbias = 0
|
|
1307
|
+
nbias = b_q[i] + b_k[i]
|
|
1308
|
+
nbias *= -1
|
|
1309
|
+
|
|
1310
|
+
t_sum = p_sum + pbias - n_sum - nbias
|
|
808
1311
|
|
|
809
1312
|
# This layer has a softmax activation function
|
|
810
1313
|
act = {
|
|
@@ -823,12 +1326,13 @@ def calculate_relevance_QK(wts, QK_output):
|
|
|
823
1326
|
n_sum = 0
|
|
824
1327
|
|
|
825
1328
|
if p_sum > 0:
|
|
826
|
-
p_agg_wt = p_sum / (p_sum + n_sum)
|
|
1329
|
+
p_agg_wt = (p_sum + pbias) / (p_sum + n_sum + pbias + nbias)
|
|
1330
|
+
p_agg_wt = p_agg_wt * (p_sum / (p_sum + pbias))
|
|
827
1331
|
else:
|
|
828
1332
|
p_agg_wt = 0
|
|
829
|
-
|
|
830
1333
|
if n_sum > 0:
|
|
831
|
-
n_agg_wt = n_sum / (p_sum + n_sum)
|
|
1334
|
+
n_agg_wt = (n_sum + nbias) / (p_sum + n_sum + pbias + nbias)
|
|
1335
|
+
n_agg_wt = n_agg_wt * (n_sum / (n_sum + nbias))
|
|
832
1336
|
else:
|
|
833
1337
|
n_agg_wt = 0
|
|
834
1338
|
|
|
@@ -837,14 +1341,60 @@ def calculate_relevance_QK(wts, QK_output):
|
|
|
837
1341
|
if n_sum == 0:
|
|
838
1342
|
n_sum = 1
|
|
839
1343
|
|
|
840
|
-
|
|
841
|
-
|
|
1344
|
+
wt_mat_QK[p_ind] += (l1_ind1[p_ind] / p_sum) * wt * p_agg_wt
|
|
1345
|
+
wt_mat_QK[n_ind] += (l1_ind1[n_ind] / n_sum) * wt * n_agg_wt * -1.0
|
|
842
1346
|
|
|
843
|
-
wt_mat_QK = np.sum(wt_mat_QK, axis=(0, 1))
|
|
844
1347
|
return wt_mat_QK
|
|
845
1348
|
|
|
846
1349
|
|
|
847
|
-
def
|
|
1350
|
+
def calculate_wt_attention_output_projection(wts, proj_output, w):
    """Redistribute relevance over the entries of the attention output projection.

    The positive and negative entries of ``proj_output`` are pooled
    globally. Each ``wts[i, j]`` is split between the two pools by the
    aggregation weights of row ``i``, which account for the bias ``b_d[i]``.
    Each pool's share is then spread over its entries in proportion to
    their values. This equals the element-by-element accumulation
    ``sum_ij (x / p_sum) * wts[i, j] * p_agg_wt(i)`` (and the negative
    analogue); it is computed in closed form because the sign masks and
    sums do not depend on ``(i, j)``.

    Args:
        wts: 2-D relevance array.
        proj_output: projection outputs to receive the relevance.
        w: parameter dict; the optional ``'b_d'`` bias is indexed by the row
            of ``wts``. A missing bias is treated as all zeros.

    Returns:
        float64 array shaped like ``proj_output``.

    Raises:
        IndexError: if ``'b_d'`` has fewer entries than ``wts`` has rows.
    """
    wts = np.asarray(wts)
    proj_output = np.asarray(proj_output)
    wt_mat_proj_output = np.zeros(proj_output.shape)
    n_rows = wts.shape[0]

    # NOTE(review): the bias is paired with the *row* index of ``wts`` (as
    # the sibling relevance helpers in this module do), not its column —
    # confirm this pairing is intended.
    if 'b_d' in w:
        bias_d = np.asarray(w['b_d'])[np.arange(n_rows)]
    else:
        # A scalar 0 default cannot be indexed per row; use an all-zero bias.
        bias_d = np.zeros(n_rows)

    positive = proj_output > 0
    negative = proj_output < 0
    p_sum = np.sum(proj_output[positive])
    n_sum = np.sum(proj_output[negative]) * -1

    pbias = np.where(bias_d > 0, bias_d, 0.0)
    nbias = np.where(bias_d > 0, 0.0, -bias_d)
    denom = p_sum + n_sum + pbias + nbias  # one entry per row of wts

    if p_sum > 0:
        p_agg_wt = (p_sum + pbias) / denom * (p_sum / (p_sum + pbias))
    else:
        p_agg_wt = np.zeros(n_rows)
    if n_sum > 0:
        n_agg_wt = (n_sum + nbias) / denom * (n_sum / (n_sum + nbias))
    else:
        n_agg_wt = np.zeros(n_rows)

    # Total relevance routed to each pool over all (i, j).
    row_totals = wts.sum(axis=1)
    p_share = np.sum(p_agg_wt * row_totals)
    n_share = np.sum(n_agg_wt * row_totals)

    # Avoid 0/0 when a pool is empty (its mask then selects nothing).
    p_norm = p_sum if p_sum != 0 else 1
    n_norm = n_sum if n_sum != 0 else 1
    wt_mat_proj_output[positive] = (proj_output[positive] / p_norm) * p_share
    wt_mat_proj_output[negative] = (proj_output[negative] / n_norm) * n_share * -1.0
    return wt_mat_proj_output
|
|
1395
|
+
|
|
1396
|
+
|
|
1397
|
+
def calculate_wt_self_attention(wts, inp, w, config):
|
|
848
1398
|
'''
|
|
849
1399
|
Input:
|
|
850
1400
|
wts: relevance score of the layer
|
|
@@ -856,28 +1406,82 @@ def calculate_wt_self_attention(wts, inp, w):
|
|
|
856
1406
|
Step-2: outputs = F.softmax(inputs, dim=dim, dtype=dtype)
|
|
857
1407
|
Step-3: outputs = input_a * input_b
|
|
858
1408
|
'''
|
|
1409
|
+
# print(f"inp: {inp.shape}, wts: {wts.shape}") # (1, 512)
|
|
1410
|
+
# print(f"w['W_q']: {w['W_q'].shape}, w['W_k']: {w['W_k'].shape}, w['W_v']: {w['W_v'].shape}")
|
|
1411
|
+
|
|
859
1412
|
query_output = np.einsum('ij,kj->ik', inp, w['W_q'])
|
|
860
1413
|
key_output = np.einsum('ij,kj->ik', inp, w['W_k'])
|
|
861
1414
|
value_output = np.einsum('ij,kj->ik', inp, w['W_v'])
|
|
862
1415
|
|
|
1416
|
+
# --------------- Reshape for Multi-Head Attention ----------------------
|
|
1417
|
+
num_heads = getattr(config, 'num_attention_heads', getattr(config, 'num_heads', None)) # will work for BERT as well as T5/ Llama
|
|
1418
|
+
hidden_size = getattr(config, 'hidden_size', getattr(config, 'd_model', None)) # will work for BERT as well as T5/Llama
|
|
1419
|
+
if hasattr(config, 'num_key_value_heads'):
|
|
1420
|
+
num_key_value_heads = config.num_key_value_heads
|
|
1421
|
+
else:
|
|
1422
|
+
num_key_value_heads = num_heads
|
|
1423
|
+
head_dim = hidden_size // num_heads # dimension of each attention head
|
|
1424
|
+
|
|
1425
|
+
query_states = np.einsum('thd->htd', query_output.reshape(query_output.shape[0], num_heads, head_dim)) # (num_heads, num_tokens, head_dim)
|
|
1426
|
+
key_states = np.einsum('thd->htd', key_output.reshape(key_output.shape[0], num_key_value_heads, head_dim)) # (num_key_value_heads, num_tokens, head_dim)
|
|
1427
|
+
value_states = np.einsum('thd->htd', value_output.reshape(value_output.shape[0], num_key_value_heads, head_dim)) # (num_key_value_heads, num_tokens, head_dim)
|
|
1428
|
+
|
|
1429
|
+
# calculate how many times we need to repeat the key/value heads
|
|
1430
|
+
n_rep = num_heads // num_key_value_heads
|
|
1431
|
+
key_states = np.repeat(key_states, n_rep, axis=0)
|
|
1432
|
+
value_states = np.repeat(value_states, n_rep, axis=0)
|
|
1433
|
+
|
|
1434
|
+
QK_output = np.einsum('hqd,hkd->hqk', query_states, key_states) # (num_heads, num_tokens, num_tokens)
|
|
1435
|
+
attn_weights = QK_output / np.sqrt(head_dim)
|
|
1436
|
+
|
|
1437
|
+
# Apply softmax along the last dimension (softmax over key dimension)
|
|
1438
|
+
attn_weights = np.exp(attn_weights - np.max(attn_weights, axis=-1, keepdims=True)) # Numerically stable softmax
|
|
1439
|
+
attn_weights = attn_weights / np.sum(attn_weights, axis=-1, keepdims=True)
|
|
1440
|
+
|
|
1441
|
+
# Weighted sum of values (num_heads, num_tokens, head_dim)
|
|
1442
|
+
attn_output = np.einsum('hqk,hkl->hql', attn_weights, value_states)
|
|
1443
|
+
|
|
1444
|
+
transposed_attn_output = np.einsum('hqd->qhd', attn_output)
|
|
1445
|
+
reshaped_attn_output = transposed_attn_output.reshape(transposed_attn_output.shape[0], num_heads * head_dim)
|
|
1446
|
+
|
|
1447
|
+
# Perform final linear projection (num_tokens, hidden_size)
|
|
1448
|
+
final_output = np.einsum('qd,dh->qh', reshaped_attn_output, w['W_d'])
|
|
1449
|
+
|
|
1450
|
+
# ------------- Relevance calculation for Final Linear Projection -------------
|
|
1451
|
+
wt_mat_attn_proj = calculate_wt_attention_output_projection(wts, final_output, w)
|
|
1452
|
+
|
|
863
1453
|
# --------------- Relevance Calculation for Step-3 -----------------------
|
|
864
|
-
|
|
865
|
-
|
|
1454
|
+
# divide the relevance among `attn_weights` and `value_states`
|
|
1455
|
+
wt_mat_attn_proj = wt_mat_attn_proj.reshape(-1, num_heads, head_dim)
|
|
1456
|
+
wt_mat_attn_proj = np.einsum('qhd->hqd', wt_mat_attn_proj)
|
|
1457
|
+
|
|
1458
|
+
stabilized_attn_output = stabilize(attn_output * 2)
|
|
1459
|
+
norm_wt_mat_attn_proj = wt_mat_attn_proj / stabilized_attn_output
|
|
1460
|
+
relevance_QK = np.einsum('htd,hbd->htb', norm_wt_mat_attn_proj, value_states) * attn_weights
|
|
1461
|
+
relevance_V = np.einsum('hdt,hdb->htb', attn_weights, norm_wt_mat_attn_proj) * value_states
|
|
866
1462
|
|
|
867
1463
|
# --------------- Relevance Calculation for V --------------------------------
|
|
868
|
-
|
|
1464
|
+
relevance_V = np.einsum('hqd->qhd', relevance_V)
|
|
1465
|
+
relevance_V = relevance_V.reshape(-1, num_heads * head_dim)
|
|
1466
|
+
wt_mat_V = calculate_relevance_V(relevance_V, value_states, w)
|
|
869
1467
|
|
|
870
1468
|
# --------------- Transformed Relevance QK ----------------------------------
|
|
871
|
-
|
|
872
|
-
|
|
1469
|
+
relevance_QK = np.einsum('hqd->qhd', relevance_QK)
|
|
1470
|
+
relevance_QK = relevance_QK.reshape(-1, relevance_QK.shape[1] * relevance_QK.shape[2])
|
|
1471
|
+
wt_mat_QK = calculate_relevance_QK(relevance_QK, QK_output, w)
|
|
873
1472
|
|
|
874
1473
|
# --------------- Relevance Calculation for K and Q --------------------------------
|
|
875
1474
|
stabilized_QK_output = stabilize(QK_output * 2)
|
|
876
1475
|
norm_wt_mat_QK = wt_mat_QK / stabilized_QK_output
|
|
877
|
-
wt_mat_Q = np.einsum('
|
|
878
|
-
wt_mat_K = np.einsum('
|
|
1476
|
+
wt_mat_Q = np.einsum('htd,hdb->htb', norm_wt_mat_QK, key_states) * query_states
|
|
1477
|
+
wt_mat_K = np.einsum('htd,htb->hbd', query_states, norm_wt_mat_QK) * key_states
|
|
879
1478
|
|
|
880
1479
|
wt_mat = wt_mat_V + wt_mat_K + wt_mat_Q
|
|
1480
|
+
|
|
1481
|
+
# Reshape wt_mat
|
|
1482
|
+
wt_mat = np.einsum('htd->thd', wt_mat)
|
|
1483
|
+
wt_mat = wt_mat.reshape(wt_mat.shape[0], wt_mat.shape[1] * wt_mat.shape[2]) # reshaped_array = array.reshape(8, 32 * 128)
|
|
1484
|
+
|
|
881
1485
|
return wt_mat
|
|
882
1486
|
|
|
883
1487
|
|
|
@@ -893,7 +1497,9 @@ def calculate_wt_feed_forward(wts, inp, w):
|
|
|
893
1497
|
R2 = wts[i]
|
|
894
1498
|
contribution_matrix2 = np.einsum('ij,j->ij', w['W_out'], intermediate_output[i])
|
|
895
1499
|
wt_mat2 = np.zeros(contribution_matrix2.shape)
|
|
896
|
-
|
|
1500
|
+
|
|
1501
|
+
bias_out = w['b_out'] if 'b_out' in w else 0
|
|
1502
|
+
|
|
897
1503
|
for j in range(contribution_matrix2.shape[0]):
|
|
898
1504
|
l1_ind1 = contribution_matrix2[j]
|
|
899
1505
|
wt_ind1 = wt_mat2[j]
|
|
@@ -903,14 +1509,23 @@ def calculate_wt_feed_forward(wts, inp, w):
|
|
|
903
1509
|
n_ind = l1_ind1 < 0
|
|
904
1510
|
p_sum = np.sum(l1_ind1[p_ind])
|
|
905
1511
|
n_sum = np.sum(l1_ind1[n_ind]) * -1
|
|
1512
|
+
|
|
1513
|
+
# Handle positive and negative bias contributions
|
|
1514
|
+
if bias_out[i] > 0:
|
|
1515
|
+
pbias = bias_out[i]
|
|
1516
|
+
nbias = 0
|
|
1517
|
+
else:
|
|
1518
|
+
pbias = 0
|
|
1519
|
+
nbias = -bias_out[i]
|
|
906
1520
|
|
|
907
1521
|
if p_sum > 0:
|
|
908
|
-
p_agg_wt = p_sum / (p_sum + n_sum)
|
|
1522
|
+
p_agg_wt = (p_sum + pbias) / (p_sum + n_sum + pbias + nbias)
|
|
1523
|
+
p_agg_wt = p_agg_wt * (p_sum / (p_sum + pbias))
|
|
909
1524
|
else:
|
|
910
1525
|
p_agg_wt = 0
|
|
911
|
-
|
|
912
1526
|
if n_sum > 0:
|
|
913
|
-
n_agg_wt = n_sum / (p_sum + n_sum)
|
|
1527
|
+
n_agg_wt = (n_sum + nbias) / (p_sum + n_sum + pbias + nbias)
|
|
1528
|
+
n_agg_wt = n_agg_wt * (n_sum / (n_sum + nbias))
|
|
914
1529
|
else:
|
|
915
1530
|
n_agg_wt = 0
|
|
916
1531
|
|
|
@@ -929,6 +1544,9 @@ def calculate_wt_feed_forward(wts, inp, w):
|
|
|
929
1544
|
R1 = relevance_out[i]
|
|
930
1545
|
contribution_matrix1 = np.einsum('ij,j->ij', w['W_int'], inp[i])
|
|
931
1546
|
wt_mat1 = np.zeros(contribution_matrix1.shape)
|
|
1547
|
+
|
|
1548
|
+
# Check if bias 'b_int' exists, default to 0 if not
|
|
1549
|
+
bias_int = w['b_int'] if 'b_int' in w else 0
|
|
932
1550
|
|
|
933
1551
|
for j in range(contribution_matrix1.shape[0]):
|
|
934
1552
|
l1_ind1 = contribution_matrix1[j]
|
|
@@ -940,7 +1558,15 @@ def calculate_wt_feed_forward(wts, inp, w):
|
|
|
940
1558
|
p_sum = np.sum(l1_ind1[p_ind])
|
|
941
1559
|
n_sum = np.sum(l1_ind1[n_ind]) * -1
|
|
942
1560
|
|
|
943
|
-
|
|
1561
|
+
# Handle positive and negative bias
|
|
1562
|
+
if bias_int[i] > 0:
|
|
1563
|
+
pbias = bias_int[i]
|
|
1564
|
+
nbias = 0
|
|
1565
|
+
else:
|
|
1566
|
+
pbias = 0
|
|
1567
|
+
nbias = -bias_int[i]
|
|
1568
|
+
|
|
1569
|
+
t_sum = p_sum + pbias - n_sum - nbias
|
|
944
1570
|
|
|
945
1571
|
# This layer has a ReLU activation function
|
|
946
1572
|
act = {
|
|
@@ -959,12 +1585,13 @@ def calculate_wt_feed_forward(wts, inp, w):
|
|
|
959
1585
|
n_sum = 0
|
|
960
1586
|
|
|
961
1587
|
if p_sum > 0:
|
|
962
|
-
p_agg_wt = p_sum / (p_sum + n_sum)
|
|
1588
|
+
p_agg_wt = (p_sum + pbias) / (p_sum + n_sum + pbias + nbias)
|
|
1589
|
+
p_agg_wt = p_agg_wt * (p_sum / (p_sum + pbias))
|
|
963
1590
|
else:
|
|
964
1591
|
p_agg_wt = 0
|
|
965
|
-
|
|
966
1592
|
if n_sum > 0:
|
|
967
|
-
n_agg_wt = n_sum / (p_sum + n_sum)
|
|
1593
|
+
n_agg_wt = (n_sum + nbias) / (p_sum + n_sum + pbias + nbias)
|
|
1594
|
+
n_agg_wt = n_agg_wt * (n_sum / (n_sum + nbias))
|
|
968
1595
|
else:
|
|
969
1596
|
n_agg_wt = 0
|
|
970
1597
|
|
|
@@ -1121,7 +1748,7 @@ def calculate_wt_pooler(wts, inp, w):
|
|
|
1121
1748
|
# Calculate relevance for each token
|
|
1122
1749
|
relevance_inp[i] = wt_mat.sum(axis=0)
|
|
1123
1750
|
|
|
1124
|
-
relevance_inp *= (
|
|
1751
|
+
relevance_inp *= (np.sum(wts) / np.sum(relevance_inp))
|
|
1125
1752
|
return relevance_inp
|
|
1126
1753
|
|
|
1127
1754
|
|