dl-backtrace 0.0.18__py3-none-any.whl → 0.0.20.dev36__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of dl-backtrace might be problematic.
- dl_backtrace/pytorch_backtrace/backtrace/backtrace.py +194 -70
- dl_backtrace/pytorch_backtrace/backtrace/utils/contrast.py +607 -156
- dl_backtrace/pytorch_backtrace/backtrace/utils/prop.py +892 -265
- dl_backtrace/tf_backtrace/backtrace/backtrace.py +11 -7
- dl_backtrace/tf_backtrace/backtrace/utils/utils_prop.py +249 -47
- dl_backtrace/version.py +2 -2
- {dl_backtrace-0.0.18.dist-info → dl_backtrace-0.0.20.dev36.dist-info}/METADATA +1 -1
- {dl_backtrace-0.0.18.dist-info → dl_backtrace-0.0.20.dev36.dist-info}/RECORD +11 -11
- {dl_backtrace-0.0.18.dist-info → dl_backtrace-0.0.20.dev36.dist-info}/WHEEL +1 -1
- {dl_backtrace-0.0.18.dist-info → dl_backtrace-0.0.20.dev36.dist-info}/LICENSE +0 -0
- {dl_backtrace-0.0.18.dist-info → dl_backtrace-0.0.20.dev36.dist-info}/top_level.txt +0 -0
@@ -192,15 +192,15 @@ class Backtrace(object):
         return temp_out

     def eval(self, all_out, start_wt=[], mode="default",multiplier=100.0,
-             scaler=
+             scaler=0, max_unit=0,thresholding=0.5,
              task="binary-classification",predicted_token=None):

         if mode=="default":
             output = self.proportional_eval(all_out=all_out,
-
-
-
-
+                                            start_wt=start_wt ,
+                                            multiplier=multiplier,
+                                            scaler=scaler,
+                                            max_unit=max_unit,
                                             thresholding=thresholding,
                                             task=task,
                                             predicted_token=predicted_token)
@@ -219,7 +219,7 @@ class Backtrace(object):
         return output

     def proportional_eval(self, all_out, start_wt=[] ,
-                          multiplier=100.0, scaler=
+                          multiplier=100.0, scaler=0, max_unit=0,
                           predicted_token=None, thresholding=0.5,
                           task="binary-classification"):
         model_resource = self.model_resource
@@ -229,7 +229,7 @@ class Backtrace(object):
         all_wt = {}
         if len(start_wt) == 0:
             if self.model_type == 'encoder':
-                start_wt = UP.calculate_start_wt(all_out[out_layer])
+                start_wt = UP.calculate_start_wt(all_out[out_layer], scaler=scaler)
         all_wt[out_layer] = start_wt * multiplier
         layer_stack = self.layer_stack
         all_wts = self.model_weights
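Taken together, the three hunks above add two new keyword arguments, `scaler` and `max_unit`, to `eval()` and `proportional_eval()`, and forward `scaler` into `UP.calculate_start_wt()` for encoder models. Below is a minimal sketch of a call site that exercises the new signature; only the keyword names and defaults are taken from this diff, while the `Backtrace` instance and the `all_out` layer-output dict are assumed to have been built elsewhere by the library.

```python
def run_backtrace_eval(bt, all_out):
    """Hedged usage sketch: `bt` is an already-constructed Backtrace instance and
    `all_out` the per-layer output dict the library produces elsewhere."""
    return bt.eval(
        all_out,
        mode="default",          # routes to proportional_eval()
        multiplier=100.0,        # scales the start weights: all_wt[out_layer] = start_wt * multiplier
        scaler=0,                # new in this release, forwarded to calculate_start_wt()
        max_unit=0,              # new in this release, forwarded to proportional_eval()
        thresholding=0.5,
        task="binary-classification",
        predicted_token=None,
    )
```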
@@ -442,10 +442,12 @@ class Backtrace(object):
         elif model_resource["graph"][start_layer]["class"] == "Self_Attention":
             weights = all_wts[start_layer]
             self_attention_weights = HP.rename_self_attention_keys(weights)
+            config = self.model.config
             temp_wt = UP.calculate_wt_self_attention(
                 all_wt[start_layer],
                 all_out[child_nodes[0]][0],
                 self_attention_weights,
+                config
             )
             all_wt[child_nodes[0]] += temp_wt
         elif model_resource["graph"][start_layer]["class"] == 'Residual':
@@ -502,10 +504,12 @@ class Backtrace(object):
         elif model_resource["graph"][start_layer]["class"] == 'Cross_Attention':
             weights = all_wts[start_layer]
             cross_attention_weights = HP.rename_cross_attention_keys(weights)
+            config = self.model.config
             temp_wt = UP.calculate_wt_cross_attention(
                 all_wt[start_layer],
                 [all_out[ch][0] for ch in child_nodes],
                 cross_attention_weights,
+                config
             )
             for ind, ch in enumerate(child_nodes):
                 all_wt[ch] += temp_wt[ind]
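Both attention branches above now fetch `self.model.config` and pass it into the relevance routines. Judging from the hunks later in this diff, the only fields read from that config are the attention-head count, the hidden size and, when present, `num_key_value_heads`, with `getattr` fallbacks so that BERT-style and T5/Llama-style attribute names both resolve. A small self-contained illustration of that lookup, using `types.SimpleNamespace` objects as stand-ins for real Hugging Face config instances (an assumption; the library passes the model's actual config):

```python
from types import SimpleNamespace

def read_attention_geometry(config):
    # Same fallback chain used by the new code in this diff.
    num_heads = getattr(config, 'num_attention_heads', getattr(config, 'num_heads', None))
    hidden_size = getattr(config, 'hidden_size', getattr(config, 'd_model', None))
    num_key_value_heads = getattr(config, 'num_key_value_heads', num_heads)
    head_dim = hidden_size // num_heads
    return num_heads, hidden_size, head_dim, num_key_value_heads

bert_like = SimpleNamespace(num_attention_heads=12, hidden_size=768)   # illustrative values
gqa_like = SimpleNamespace(num_attention_heads=32, hidden_size=4096,
                           num_key_value_heads=8)                      # illustrative values

print(read_attention_geometry(bert_like))   # (12, 768, 64, 12)
print(read_attention_geometry(gqa_like))    # (32, 4096, 128, 8)
```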
@@ -1161,14 +1161,17 @@ def calculate_wt_residual(wts, inp=None):
     return wt_mat


-def calculate_relevance_V(wts, value_output):
-
-
+def calculate_relevance_V(wts, value_output, w):
+    wt_mat_V = np.zeros(value_output.shape)
+
+    if 'b_v' in w:
+        bias_v = w['b_v']
+    else:
+        bias_v = 0

     for i in range(wts.shape[0]):
         for j in range(wts.shape[1]):
             l1_ind1 = value_output
-            wt_ind1 = wt_mat_V[i, j]
             wt = wts[i, j]

             p_ind = l1_ind1 > 0
@@ -1176,12 +1179,21 @@ def calculate_relevance_V(wts, value_output):
             p_sum = np.sum(l1_ind1[p_ind])
             n_sum = np.sum(l1_ind1[n_ind]) * -1

+            if bias_v[i] > 0:
+                pbias = bias_v[i]
+                nbias = 0
+            else:
+                pbias = 0
+                nbias = bias_v[i] * -1
+
             if p_sum > 0:
-                p_agg_wt = p_sum / (p_sum + n_sum)
+                p_agg_wt = (p_sum + pbias) / (p_sum + n_sum + pbias + nbias)
+                p_agg_wt = p_agg_wt * (p_sum / (p_sum + pbias))
             else:
                 p_agg_wt = 0
             if n_sum > 0:
-                n_agg_wt = n_sum / (p_sum + n_sum)
+                n_agg_wt = (n_sum + nbias) / (p_sum + n_sum + pbias + nbias)
+                n_agg_wt = n_agg_wt * (n_sum / (n_sum + nbias))
             else:
                 n_agg_wt = 0

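The change above replaces the plain positive/negative split `p_sum / (p_sum + n_sum)` with a bias-aware version: the value-projection bias `b_v` is separated into a positive part (`pbias`) and the magnitude of its negative part (`nbias`), both are folded into the denominator, and the second multiplication then strips the bias's own share back out so that only the input contributions keep propagating. A condensed, runnable restatement of that rule follows (not library API; the function name is mine):

```python
def bias_aware_shares(p_sum, n_sum, bias):
    """Reproduce the updated p_agg_wt / n_agg_wt rule for a single unit."""
    pbias, nbias = (bias, 0.0) if bias > 0 else (0.0, -bias)
    p_agg_wt = n_agg_wt = 0.0
    if p_sum > 0:
        p_agg_wt = (p_sum + pbias) / (p_sum + n_sum + pbias + nbias)
        p_agg_wt *= p_sum / (p_sum + pbias)   # remove the bias's own share again
    if n_sum > 0:
        n_agg_wt = (n_sum + nbias) / (p_sum + n_sum + pbias + nbias)
        n_agg_wt *= n_sum / (n_sum + nbias)
    return p_agg_wt, n_agg_wt

print(bias_aware_shares(3.0, 1.0, 0.0))   # (0.75, 0.25): with no bias this reduces to the old rule
print(bias_aware_shares(3.0, 1.0, 4.0))   # ~(0.375, 0.125): a positive bias absorbs part of the relevance
```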
@@ -1190,21 +1202,22 @@ def calculate_relevance_V(wts, value_output):
             if n_sum == 0:
                 n_sum = 1

-
-
+            wt_mat_V[p_ind] += (l1_ind1[p_ind] / p_sum) * wt * p_agg_wt
+            wt_mat_V[n_ind] += (l1_ind1[n_ind] / n_sum) * wt * n_agg_wt * -1.0

-
-    return wt_mat_V
+    return wt_mat_V


-def calculate_relevance_QK(wts, QK_output):
-
-
+def calculate_relevance_QK(wts, QK_output, w):
+    wt_mat_QK = np.zeros(QK_output.shape)
+
+    # Check if 'b_q' and 'b_k' exist in the weights, default to 0 if not
+    b_q = w['b_q'] if 'b_q' in w else 0
+    b_k = w['b_k'] if 'b_k' in w else 0

     for i in range(wts.shape[0]):
         for j in range(wts.shape[1]):
             l1_ind1 = QK_output
-            wt_ind1 = wt_mat_QK[i, j]
             wt = wts[i, j]

             p_ind = l1_ind1 > 0
@@ -1212,7 +1225,21 @@ def calculate_relevance_QK(wts, QK_output):
             p_sum = np.sum(l1_ind1[p_ind])
             n_sum = np.sum(l1_ind1[n_ind]) * -1

-
+            if b_q[i] > 0 and b_k[i] > 0:
+                pbias = b_q[i] + b_k[i]
+                nbias = 0
+            elif b_q[i] > 0 and b_k[i] < 0:
+                pbias = b_q[i]
+                nbias = b_k[i] * -1
+            elif b_q[i] < 0 and b_k[i] > 0:
+                pbias = b_k[i]
+                nbias = b_q[i] * -1
+            else:
+                pbias = 0
+                nbias = b_q[i] + b_k[i]
+                nbias *= -1
+
+            t_sum = p_sum + pbias - n_sum - nbias

             # This layer has a softmax activation function
             act = {
@@ -1231,12 +1258,13 @@ def calculate_relevance_QK(wts, QK_output):
                 n_sum = 0

             if p_sum > 0:
-                p_agg_wt = p_sum / (p_sum + n_sum)
+                p_agg_wt = (p_sum + pbias) / (p_sum + n_sum + pbias + nbias)
+                p_agg_wt = p_agg_wt * (p_sum / (p_sum + pbias))
             else:
                 p_agg_wt = 0
-
             if n_sum > 0:
-                n_agg_wt = n_sum / (p_sum + n_sum)
+                n_agg_wt = (n_sum + nbias) / (p_sum + n_sum + pbias + nbias)
+                n_agg_wt = n_agg_wt * (n_sum / (n_sum + nbias))
             else:
                 n_agg_wt = 0

@@ -1245,14 +1273,60 @@ def calculate_relevance_QK(wts, QK_output):
             if n_sum == 0:
                 n_sum = 1

-
-
+            wt_mat_QK[p_ind] += (l1_ind1[p_ind] / p_sum) * wt * p_agg_wt
+            wt_mat_QK[n_ind] += (l1_ind1[n_ind] / n_sum) * wt * n_agg_wt * -1.0

-    wt_mat_QK = np.sum(wt_mat_QK, axis=(0, 1))
     return wt_mat_QK


-def
+def calculate_wt_attention_output_projection(wts, proj_output, w):
+    wt_mat_proj_output = np.zeros(proj_output.shape)
+
+    if 'b_d' in w:
+        bias_d = w['b_d']
+    else:
+        bias_d = 0
+
+    for i in range(wts.shape[0]):
+        for j in range(wts.shape[1]):
+            l1_ind1 = proj_output
+            wt = wts[i, j]
+
+            p_ind = l1_ind1 > 0
+            n_ind = l1_ind1 < 0
+            p_sum = np.sum(l1_ind1[p_ind])
+            n_sum = np.sum(l1_ind1[n_ind]) * -1
+
+            if bias_d[i] > 0:
+                pbias = bias_d[i]
+                nbias = 0
+            else:
+                pbias = 0
+                nbias = bias_d[i] * -1
+
+            if p_sum > 0:
+                p_agg_wt = (p_sum + pbias) / (p_sum + n_sum + pbias + nbias)
+                p_agg_wt = p_agg_wt * (p_sum / (p_sum + pbias))
+            else:
+                p_agg_wt = 0
+            if n_sum > 0:
+                n_agg_wt = (n_sum + nbias) / (p_sum + n_sum + pbias + nbias)
+                n_agg_wt = n_agg_wt * (n_sum / (n_sum + nbias))
+            else:
+                n_agg_wt = 0
+
+            if p_sum == 0:
+                p_sum = 1
+            if n_sum == 0:
+                n_sum = 1
+
+            wt_mat_proj_output[p_ind] += (l1_ind1[p_ind] / p_sum) * wt * p_agg_wt
+            wt_mat_proj_output[n_ind] += (l1_ind1[n_ind] / n_sum) * wt * n_agg_wt * -1.0
+
+    return wt_mat_proj_output
+
+
+def calculate_wt_self_attention(wts, inp, w, config):
     '''
     Input:
         wts: relevance score of the layer
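All of these helpers index the layer weights through a plain dict `w`; the projection matrices are required, while each bias is looked up with an `in` check and defaults to 0 when absent. Collecting the key names that appear across this diff gives roughly the layout sketched below. The key names come from the diff (after `HP.rename_self_attention_keys` / `rename_cross_attention_keys` for the attention case); the shapes and toy sizes are assumptions for illustration only.

```python
import numpy as np

hidden, intermediate = 8, 16   # illustrative sizes, not taken from the package

# Attention weights as consumed by calculate_wt_self_attention / calculate_wt_cross_attention:
attention_w = {
    'W_q': np.random.randn(hidden, hidden), 'b_q': np.random.randn(hidden),
    'W_k': np.random.randn(hidden, hidden), 'b_k': np.random.randn(hidden),
    'W_v': np.random.randn(hidden, hidden), 'b_v': np.random.randn(hidden),
    'W_d': np.random.randn(hidden, hidden), 'b_d': np.random.randn(hidden),  # output projection
}

# Feed-forward weights as consumed by calculate_wt_feed_forward:
feed_forward_w = {
    'W_int': np.random.randn(hidden, intermediate), 'b_int': np.random.randn(intermediate),
    'W_out': np.random.randn(intermediate, hidden), 'b_out': np.random.randn(hidden),
}
```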
@@ -1267,25 +1341,76 @@ def calculate_wt_self_attention(wts, inp, w):
     query_output = np.einsum('ij,kj->ik', inp, w['W_q'].T)
     key_output = np.einsum('ij,kj->ik', inp, w['W_k'].T)
     value_output = np.einsum('ij,kj->ik', inp, w['W_v'].T)
+
+    # --------------- Reshape for Multi-Head Attention ----------------------
+    num_heads = getattr(config, 'num_attention_heads', getattr(config, 'num_heads', None))  # will work for BERT as well as T5/ Llama
+    hidden_size = getattr(config, 'hidden_size', getattr(config, 'd_model', None))  # will work for BERT as well as T5/Llama
+    if hasattr(config, 'num_key_value_heads'):
+        num_key_value_heads = config.num_key_value_heads
+    else:
+        num_key_value_heads = num_heads
+    head_dim = hidden_size // num_heads  # dimension of each attention head
+
+    query_states = np.einsum('thd->htd', query_output.reshape(query_output.shape[0], num_heads, head_dim))  # (num_heads, num_tokens, head_dim)
+    key_states = np.einsum('thd->htd', key_output.reshape(key_output.shape[0], num_key_value_heads, head_dim))  # (num_key_value_heads, num_tokens, head_dim)
+    value_states = np.einsum('thd->htd', value_output.reshape(value_output.shape[0], num_key_value_heads, head_dim))  # (num_key_value_heads, num_tokens, head_dim)
+
+    # calculate how many times we need to repeat the key/value heads
+    n_rep = num_heads // num_key_value_heads
+    key_states = np.repeat(key_states, n_rep, axis=0)
+    value_states = np.repeat(value_states, n_rep, axis=0)
+
+    QK_output = np.einsum('hqd,hkd->hqk', query_states, key_states)  # (num_heads, num_tokens, num_tokens)
+    attn_weights = QK_output / np.sqrt(head_dim)
+
+    # Apply softmax along the last dimension (softmax over key dimension)
+    attn_weights = np.exp(attn_weights - np.max(attn_weights, axis=-1, keepdims=True))  # Numerically stable softmax
+    attn_weights = attn_weights / np.sum(attn_weights, axis=-1, keepdims=True)
+
+    # Weighted sum of values (num_heads, num_tokens, head_dim)
+    attn_output = np.einsum('hqk,hkl->hql', attn_weights, value_states)
+
+    transposed_attn_output = np.einsum('hqd->qhd', attn_output)
+    reshaped_attn_output = transposed_attn_output.reshape(transposed_attn_output.shape[0], num_heads * head_dim)
+
+    # Perform final linear projection (num_tokens, hidden_size)
+    final_output = np.einsum('qd,dh->qh', reshaped_attn_output, w['W_d'])
+
+    # ------------- Relevance calculation for Final Linear Projection -------------
+    wt_mat_attn_proj = calculate_wt_attention_output_projection(wts, final_output, w)

     # --------------- Relevance Calculation for Step-3 -----------------------
-
-
+    # divide the relevance among `attn_weights` and `value_states`
+    wt_mat_attn_proj = wt_mat_attn_proj.reshape(-1, num_heads, head_dim)
+    wt_mat_attn_proj = np.einsum('qhd->hqd', wt_mat_attn_proj)

-
-
+    stabilized_attn_output = stabilize(attn_output * 2)
+    norm_wt_mat_attn_proj = wt_mat_attn_proj / stabilized_attn_output
+    relevance_QK = np.einsum('htd,hbd->htb', norm_wt_mat_attn_proj, value_states) * attn_weights
+    relevance_V = np.einsum('hdt,hdb->htb', attn_weights, norm_wt_mat_attn_proj) * value_states

+    # --------------- Relevance Calculation for V --------------------------------
+    relevance_V = np.einsum('hqd->qhd', relevance_V)
+    relevance_V = relevance_V.reshape(-1, num_heads * head_dim)
+    wt_mat_V = calculate_relevance_V(relevance_V, value_states, w)
+
     # --------------- Transformed Relevance QK ----------------------------------
-
-
+    relevance_QK = np.einsum('hqd->qhd', relevance_QK)
+    relevance_QK = relevance_QK.reshape(-1, relevance_QK.shape[1] * relevance_QK.shape[2])
+    wt_mat_QK = calculate_relevance_QK(relevance_QK, QK_output, w)

     # --------------- Relevance Calculation for K and Q --------------------------------
     stabilized_QK_output = stabilize(QK_output * 2)
     norm_wt_mat_QK = wt_mat_QK / stabilized_QK_output
-    wt_mat_Q = np.einsum('
-    wt_mat_K = np.einsum('
+    wt_mat_Q = np.einsum('htd,hdb->htb', norm_wt_mat_QK, key_states) * query_states
+    wt_mat_K = np.einsum('htd,htb->hbd', query_states, norm_wt_mat_QK) * key_states

     wt_mat = wt_mat_V + wt_mat_K + wt_mat_Q
+
+    # Reshape wt_mat
+    wt_mat = np.einsum('htd->thd', wt_mat)
+    wt_mat = wt_mat.reshape(wt_mat.shape[0], wt_mat.shape[1] * wt_mat.shape[2])  # reshaped_array = array.reshape(8, 32 * 128)
+
     return wt_mat

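The added block above re-runs the attention forward pass in plain NumPy (head split, optional grouped-query key/value repetition, scaled dot product, numerically stable softmax, output projection) so that the incoming relevance can then be divided per head. The toy walk-through below mirrors those array operations with made-up dimensions to show the intermediate shapes; the sizes (5 tokens, 4 query heads, 2 key/value heads, head_dim 3) are assumptions, while the einsum/reshape/repeat steps follow the added code.

```python
import numpy as np

tokens, num_heads, num_kv_heads, head_dim = 5, 4, 2, 3
rng = np.random.default_rng(0)

# (tokens, heads * head_dim) projections, standing in for the W_q / W_k / W_v einsums
q = rng.standard_normal((tokens, num_heads * head_dim))
k = rng.standard_normal((tokens, num_kv_heads * head_dim))
v = rng.standard_normal((tokens, num_kv_heads * head_dim))

# split into heads and move the head axis first, as in the 'thd->htd' einsums above
q = np.einsum('thd->htd', q.reshape(tokens, num_heads, head_dim))
k = np.einsum('thd->htd', k.reshape(tokens, num_kv_heads, head_dim))
v = np.einsum('thd->htd', v.reshape(tokens, num_kv_heads, head_dim))

# grouped-query attention: repeat key/value heads until the counts match
n_rep = num_heads // num_kv_heads
k = np.repeat(k, n_rep, axis=0)
v = np.repeat(v, n_rep, axis=0)

scores = np.einsum('hqd,hkd->hqk', q, k) / np.sqrt(head_dim)
scores = np.exp(scores - scores.max(axis=-1, keepdims=True))   # numerically stable softmax
attn = scores / scores.sum(axis=-1, keepdims=True)
out = np.einsum('hqk,hkl->hql', attn, v)

print(q.shape, k.shape, attn.shape, out.shape)
# (4, 5, 3) (4, 5, 3) (4, 5, 5) (4, 5, 3)
```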
@@ -1301,6 +1426,8 @@ def calculate_wt_feed_forward(wts, inp, w):
         R2 = wts[i]
         contribution_matrix2 = np.einsum('ij,j->ij', w['W_out'].T, intermediate_output[i])
         wt_mat2 = np.zeros(contribution_matrix2.shape)
+
+        bias_out = w['b_out'] if 'b_out' in w else 0

         for j in range(contribution_matrix2.shape[0]):
             l1_ind1 = contribution_matrix2[j]
@@ -1312,13 +1439,22 @@ def calculate_wt_feed_forward(wts, inp, w):
             p_sum = np.sum(l1_ind1[p_ind])
             n_sum = np.sum(l1_ind1[n_ind]) * -1

+            # Handle positive and negative bias contributions
+            if bias_out[i] > 0:
+                pbias = bias_out[i]
+                nbias = 0
+            else:
+                pbias = 0
+                nbias = -bias_out[i]
+
             if p_sum > 0:
-                p_agg_wt = p_sum / (p_sum + n_sum)
+                p_agg_wt = (p_sum + pbias) / (p_sum + n_sum + pbias + nbias)
+                p_agg_wt = p_agg_wt * (p_sum / (p_sum + pbias))
             else:
                 p_agg_wt = 0
-
             if n_sum > 0:
-                n_agg_wt = n_sum / (p_sum + n_sum)
+                n_agg_wt = (n_sum + nbias) / (p_sum + n_sum + pbias + nbias)
+                n_agg_wt = n_agg_wt * (n_sum / (n_sum + nbias))
             else:
                 n_agg_wt = 0

@@ -1337,6 +1473,9 @@ def calculate_wt_feed_forward(wts, inp, w):
         R1 = relevance_out[i]
         contribution_matrix1 = np.einsum('ij,j->ij', w['W_int'].T, inp[i])
         wt_mat1 = np.zeros(contribution_matrix1.shape)
+
+        # Check if bias 'b_int' exists, default to 0 if not
+        bias_int = w['b_int'] if 'b_int' in w else 0

         for j in range(contribution_matrix1.shape[0]):
             l1_ind1 = contribution_matrix1[j]
@@ -1348,7 +1487,15 @@ def calculate_wt_feed_forward(wts, inp, w):
             p_sum = np.sum(l1_ind1[p_ind])
             n_sum = np.sum(l1_ind1[n_ind]) * -1

-
+            # Handle positive and negative bias
+            if bias_int[i] > 0:
+                pbias = bias_int[i]
+                nbias = 0
+            else:
+                pbias = 0
+                nbias = -bias_int[i]
+
+            t_sum = p_sum + pbias - n_sum - nbias

             # This layer has a ReLU activation function
             act = {
@@ -1367,12 +1514,13 @@ def calculate_wt_feed_forward(wts, inp, w):
                 n_sum = 0

             if p_sum > 0:
-                p_agg_wt = p_sum / (p_sum + n_sum)
+                p_agg_wt = (p_sum + pbias) / (p_sum + n_sum + pbias + nbias)
+                p_agg_wt = p_agg_wt * (p_sum / (p_sum + pbias))
             else:
                 p_agg_wt = 0
-
             if n_sum > 0:
-                n_agg_wt = n_sum / (p_sum + n_sum)
+                n_agg_wt = (n_sum + nbias) / (p_sum + n_sum + pbias + nbias)
+                n_agg_wt = n_agg_wt * (n_sum / (n_sum + nbias))
             else:
                 n_agg_wt = 0

@@ -1461,8 +1609,8 @@ def calculate_wt_pooler(wts, inp, w):
         # Calculate relevance for each token
         relevance_inp[i] = wt_mat.sum(axis=0)

-    relevance_inp *= (
-    return relevance_inp
+    relevance_inp *= (np.sum(wts) / np.sum(relevance_inp))
+    return relevance_inp


 def calculate_wt_classifier(wts, inp, w):
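The completed line in the pooler hunk rescales the per-token relevance so that its total matches the relevance that arrived at the pooler output, i.e. the layer conserves relevance instead of letting the totals drift. A minimal numeric illustration of that normalisation (toy values, not library data):

```python
import numpy as np

wts = np.array([0.6, 0.4])                 # relevance arriving at the pooler output
relevance_inp = np.array([0.9, 0.3, 0.3])  # unnormalised per-token relevance (toy values)

relevance_inp = relevance_inp * (np.sum(wts) / np.sum(relevance_inp))
print(relevance_inp, relevance_inp.sum())  # rescaled so the total equals wts.sum() == 1.0
```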
@@ -1595,7 +1743,7 @@ def calculate_wt_lm_head(wts, inp, w):
     return relevance_input


-def calculate_wt_cross_attention(wts, inp, w):
+def calculate_wt_cross_attention(wts, inp, w, config):
     '''
     Input:
         wts: relevance score of the layer
@@ -1613,23 +1761,77 @@ def calculate_wt_cross_attention(wts, inp, w):
     key_output = np.einsum('ij,kj->ik', k_v_inp, w['W_k'].T)
     value_output = np.einsum('ij,kj->ik', k_v_inp, w['W_v'].T)

+    # --------------- Reshape for Multi-Head Attention ----------------------
+    num_heads = getattr(config, 'num_attention_heads', getattr(config, 'num_heads', None))  # will work for BERT as well as T5/ Llama
+    hidden_size = getattr(config, 'hidden_size', getattr(config, 'd_model', None))  # will work for BERT as well as T5/Llama
+    if hasattr(config, 'num_key_value_heads'):
+        num_key_value_heads = config.num_key_value_heads
+    else:
+        num_key_value_heads = num_heads
+    head_dim = hidden_size // num_heads  # dimension of each attention head
+
+    query_states = np.einsum('thd->htd', query_output.reshape(query_output.shape[0], num_heads, head_dim))  # (num_heads, num_tokens, head_dim)
+    key_states = np.einsum('thd->htd', key_output.reshape(key_output.shape[0], num_key_value_heads, head_dim))  # (num_key_value_heads, num_tokens, head_dim)
+    value_states = np.einsum('thd->htd', value_output.reshape(value_output.shape[0], num_key_value_heads, head_dim))  # (num_key_value_heads, num_tokens, head_dim)
+
+    # calculate how many times we need to repeat the key/value heads
+    n_rep = num_heads // num_key_value_heads
+    key_states = np.repeat(key_states, n_rep, axis=0)
+    value_states = np.repeat(value_states, n_rep, axis=0)
+
+    QK_output = np.einsum('hqd,hkd->hqk', query_states, key_states)  # (num_heads, num_tokens, num_tokens)
+    attn_weights = QK_output / np.sqrt(head_dim)
+
+    # Apply softmax along the last dimension (softmax over key dimension)
+    attn_weights = np.exp(attn_weights - np.max(attn_weights, axis=-1, keepdims=True))  # Numerically stable softmax
+    attn_weights = attn_weights / np.sum(attn_weights, axis=-1, keepdims=True)
+
+    # Weighted sum of values (num_heads, num_tokens, head_dim)
+    attn_output = np.einsum('hqk,hkl->hql', attn_weights, value_states)
+
+    transposed_attn_output = np.einsum('hqd->qhd', attn_output)
+    reshaped_attn_output = transposed_attn_output.reshape(transposed_attn_output.shape[0], num_heads * head_dim)
+
+    # Perform final linear projection (num_tokens, hidden_size)
+    final_output = np.einsum('qd,dh->qh', reshaped_attn_output, w['W_d'])
+
+    # ------------- Relevance calculation for Final Linear Projection -------------
+    wt_mat_attn_proj = calculate_wt_attention_output_projection(wts, final_output)
+
     # --------------- Relevance Calculation for Step-3 -----------------------
-
-
+    # divide the relevance among `attn_weights` and `value_states`
+    wt_mat_attn_proj = wt_mat_attn_proj.reshape(-1, num_heads, head_dim)
+    wt_mat_attn_proj = np.einsum('qhd->hqd', wt_mat_attn_proj)

-
-
+    stabilized_attn_output = stabilize(attn_output * 2)
+    norm_wt_mat_attn_proj = wt_mat_attn_proj / stabilized_attn_output
+    relevance_QK = np.einsum('htd,hbd->htb', norm_wt_mat_attn_proj, value_states) * attn_weights
+    relevance_V = np.einsum('hdt,hdb->htb', attn_weights, norm_wt_mat_attn_proj) * value_states

+    # --------------- Relevance Calculation for V --------------------------------
+    relevance_V = np.einsum('hqd->qhd', relevance_V)
+    relevance_V = relevance_V.reshape(-1, num_heads * head_dim)
+    wt_mat_V = calculate_relevance_V(relevance_V, value_states)
+
     # --------------- Transformed Relevance QK ----------------------------------
-
+    relevance_QK = np.einsum('hqd->qhd', relevance_QK)
+    relevance_QK = relevance_QK.reshape(-1, relevance_QK.shape[1] * relevance_QK.shape[2])
     wt_mat_QK = calculate_relevance_QK(relevance_QK, QK_output)

     # --------------- Relevance Calculation for K and Q --------------------------------
     stabilized_QK_output = stabilize(QK_output * 2)
     norm_wt_mat_QK = wt_mat_QK / stabilized_QK_output
-    wt_mat_Q = np.einsum('
-    wt_mat_K = np.einsum('
+    wt_mat_Q = np.einsum('htd,hdb->htb', norm_wt_mat_QK, key_states) * query_states
+    wt_mat_K = np.einsum('htd,htb->hbd', query_states, norm_wt_mat_QK) * key_states

+    # Relevance of KV input
     wt_mat_KV = wt_mat_V + wt_mat_K
+
+    # Reshape wt_mat_Q and wt_mat_KV
+    wt_mat_Q = np.einsum('htd->thd', wt_mat_Q)
+    wt_mat_KV = np.einsum('htd->thd', wt_mat_KV)
+    wt_mat_Q = wt_mat_Q.reshape(wt_mat_Q.shape[0], wt_mat_Q.shape[1] * wt_mat_Q.shape[2])
+    wt_mat_KV = wt_mat_KV.reshape(wt_mat_KV.shape[0], wt_mat_KV.shape[1] * wt_mat_KV.shape[2])
+
     wt_mat = [wt_mat_KV, wt_mat_Q]
     return wt_mat
dl_backtrace/version.py
CHANGED
@@ -12,5 +12,5 @@ __version__: str
 __version_tuple__: VERSION_TUPLE
 version_tuple: VERSION_TUPLE

-__version__ = version = '0.0.
-__version_tuple__ = version_tuple = (0, 0,
+__version__ = version = '0.0.20.dev36'
+__version_tuple__ = version_tuple = (0, 0, 20, 'dev36')
{dl_backtrace-0.0.18.dist-info → dl_backtrace-0.0.20.dev36.dist-info}/RECORD
CHANGED
@@ -1,5 +1,5 @@
 dl_backtrace/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-dl_backtrace/version.py,sha256=
+dl_backtrace/version.py,sha256=tYHVV4mIeOCumN5OCSD_xV6vt2LqOBp8qrLBhN4xnyw,428
 dl_backtrace/old_backtrace/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
 dl_backtrace/old_backtrace/pytorch_backtrace/__init__.py,sha256=TDhKQIj1INyQq7cqTvpjpnBDhoMeWSoqVdx_mPAV3Sw,24
 dl_backtrace/old_backtrace/pytorch_backtrace/backtrace/__init__.py,sha256=AuR7uMTbf7rrFl-9sMIQ3lQmhu_ATSFBZmfs7R66jAc,76
@@ -17,18 +17,18 @@ dl_backtrace/old_backtrace/tf_backtrace/backtrace/utils/contrast.py,sha256=p-0zj
 dl_backtrace/old_backtrace/tf_backtrace/backtrace/utils/prop.py,sha256=-h7nHEvsoEwfksHsT52VfnZy334DvqqP8g6fMFVNnAM,25670
 dl_backtrace/pytorch_backtrace/__init__.py,sha256=TDhKQIj1INyQq7cqTvpjpnBDhoMeWSoqVdx_mPAV3Sw,24
 dl_backtrace/pytorch_backtrace/backtrace/__init__.py,sha256=AuR7uMTbf7rrFl-9sMIQ3lQmhu_ATSFBZmfs7R66jAc,76
-dl_backtrace/pytorch_backtrace/backtrace/backtrace.py,sha256=
+dl_backtrace/pytorch_backtrace/backtrace/backtrace.py,sha256=bdb8eFzTE7ms34ne-VBJxXY8AMHhu-vTenUNwCMKbRo,44410
 dl_backtrace/pytorch_backtrace/backtrace/config.py,sha256=ODrgOC74ojzGLnEto_ah-WPA8_MyRE7cZkZlU15239s,983
 dl_backtrace/pytorch_backtrace/backtrace/utils/__init__.py,sha256=KffAJVu7NsgfMHEZaY7lND2LQwZamVIquqx8POwOaLg,120
-dl_backtrace/pytorch_backtrace/backtrace/utils/contrast.py,sha256=
+dl_backtrace/pytorch_backtrace/backtrace/utils/contrast.py,sha256=gcvDRZstcgAkjzyyxy6dagSV2EDFGqh9_SMn9P1-2lo,52795
 dl_backtrace/pytorch_backtrace/backtrace/utils/encoder.py,sha256=1dUKEL_LAuFcUEldYdyQVsV_P-KH3gTWEOgROfXPwyc,7469
 dl_backtrace/pytorch_backtrace/backtrace/utils/encoder_decoder.py,sha256=fxbXwUl7KQKwWSdKMO-bYvDC7NDct71F_U7uZzh1orw,25658
 dl_backtrace/pytorch_backtrace/backtrace/utils/helper.py,sha256=IG0XkPEfbPpWBm4aVHgv-GkgFFrl2wu5xMrz6DeH_xQ,3512
-dl_backtrace/pytorch_backtrace/backtrace/utils/prop.py,sha256=
+dl_backtrace/pytorch_backtrace/backtrace/utils/prop.py,sha256=oEwWmRsLwoYdCXJaEEW7-UL1CLVOsseUAcOasTkyQeE,67294
 dl_backtrace/tf_backtrace/__init__.py,sha256=TDhKQIj1INyQq7cqTvpjpnBDhoMeWSoqVdx_mPAV3Sw,24
 dl_backtrace/tf_backtrace/backtrace/__init__.py,sha256=KkU7X_wXxwYR4HYQqQ3kWtxlJK3Ytaa84e1Jbzc2_ZA,84
 dl_backtrace/tf_backtrace/backtrace/activation_info.py,sha256=3Ppw4_6rJV16YPXnKjd2WPaULgUUL6U01bh8Foa-3Yg,1334
-dl_backtrace/tf_backtrace/backtrace/backtrace.py,sha256=
+dl_backtrace/tf_backtrace/backtrace/backtrace.py,sha256=k6irwiOu68EpRG_W9wk5UF7QS50TKM5JqUQfj4qmvWw,41635
 dl_backtrace/tf_backtrace/backtrace/models.py,sha256=wPaeRuEvZL2xTdj6I6hF-ZCKC2c8EAuF9PoOIZkxkR4,466
 dl_backtrace/tf_backtrace/backtrace/server.py,sha256=jphibvI46QpQcqnpXVIYFq_M2CtRVONgItZe9iWOd54,567
 dl_backtrace/tf_backtrace/backtrace/utils/__init__.py,sha256=ci_RAYYnqyAWa_rcIEycnqCghQ4aZtvaGQ7oDUb_k_0,131
@@ -36,9 +36,9 @@ dl_backtrace/tf_backtrace/backtrace/utils/encoder.py,sha256=WeGLjIRHNqjwIK-8UB0x
 dl_backtrace/tf_backtrace/backtrace/utils/encoder_decoder.py,sha256=qbS34WswNiT1xgON5ayNAIewJMRfDdeGel6l1XrjXms,24247
 dl_backtrace/tf_backtrace/backtrace/utils/helper.py,sha256=QB21kPB5iJfRpy8khYJnzojaKf5ACnAFYh5XxYBcnXA,3419
 dl_backtrace/tf_backtrace/backtrace/utils/utils_contrast.py,sha256=rQwManW0d6Td6V_A1qGezcZ19Tgr34zbFcieQ0rqAAc,48415
-dl_backtrace/tf_backtrace/backtrace/utils/utils_prop.py,sha256=
-dl_backtrace-0.0.
-dl_backtrace-0.0.
-dl_backtrace-0.0.
-dl_backtrace-0.0.
-dl_backtrace-0.0.
+dl_backtrace/tf_backtrace/backtrace/utils/utils_prop.py,sha256=kduJpbN2FIrrhgrUxSFkX3gnTfXLuQa4qx-LH0dDB7A,68291
+dl_backtrace-0.0.20.dev36.dist-info/LICENSE,sha256=RTqAU0MFv1q3ZXKewNobKxIIPzRHgImom7e6ORV7X6o,1064
+dl_backtrace-0.0.20.dev36.dist-info/METADATA,sha256=H5E00w4t5F6O67roa2Y6tiGEW_SziKjsgMssLEoDM_U,7843
+dl_backtrace-0.0.20.dev36.dist-info/WHEEL,sha256=R06PA3UVYHThwHvxuRWMqaGcr-PuniXahwjmQRFMEkY,91
+dl_backtrace-0.0.20.dev36.dist-info/top_level.txt,sha256=gvGVYScJfW6c4aO5WMo4Aqa6NLEfmLK7VWXVx_GeiIk,13
+dl_backtrace-0.0.20.dev36.dist-info/RECORD,,
{dl_backtrace-0.0.18.dist-info → dl_backtrace-0.0.20.dev36.dist-info}/LICENSE
File without changes

{dl_backtrace-0.0.18.dist-info → dl_backtrace-0.0.20.dev36.dist-info}/top_level.txt
File without changes