dl-backtrace 0.0.19__py3-none-any.whl → 0.0.20.dev36__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of dl-backtrace might be problematic.

@@ -1,6 +1,7 @@
  import numpy as np
  import torch
  import torch.nn as nn
+ from tqdm import tqdm
  from dl_backtrace.pytorch_backtrace.backtrace.utils import contrast as UC
  from dl_backtrace.pytorch_backtrace.backtrace.utils import prop as UP
  from dl_backtrace.pytorch_backtrace.backtrace.config import activation_master
@@ -26,6 +27,7 @@ class Backtrace(object):
  self.model_weights = EN.extract_encoder_weights(model)
  # # calculate the output of each submodule of the encoder model
  # self.all_out_model = EN.create_encoder_output(model)
+ self.activation_dict = None

  elif model_type == 'encoder_decoder':
  self.model = model
@@ -36,7 +38,7 @@ class Backtrace(object):
  self.model_weights = ED.extract_encoder_decoder_weights(model)
  # # calculate the output of each submodule of the encoder-decoder model
  # self.all_out_model = ED.calculate_encoder_decoder_output(model)
-
+ self.activation_dict = None

  else:
  self.model_type = model_type
@@ -317,10 +319,10 @@ class Backtrace(object):
  all_out=all_out,
  start_wt=start_wt,
  multiplier=multiplier,
- scaler=0,
+ scaler=scaler,
  max_unit=0,
  predicted_token=predicted_token,
- thresholding=0.5,
+ thresholding=thresholding,
  task="binary-classification",
  )
  return output
@@ -351,11 +353,11 @@ class Backtrace(object):
  all_wt = {}
  if len(start_wt) == 0:
  if self.model_type == 'encoder':
- start_wt = UP.calculate_start_wt(all_out[out_layer].detach().numpy())
+ start_wt = UP.calculate_start_wt(all_out[out_layer].detach().numpy(), scaler=scaler)
  all_wt[out_layer] = start_wt * multiplier
- layer_stack = self.layer_stack
- all_wts = self.model_weights
- if self.model_type == 'encoder_decoder':
+ layer_stack = self.layer_stack
+ all_wts = self.model_weights
+ elif self.model_type == 'encoder_decoder':
  start_wt = UP.calculate_enc_dec_start_wt(all_out[out_layer][0].detach().numpy(), predicted_token)
  all_wt[out_layer] = start_wt * multiplier
  layer_stack = self.layer_stack
@@ -365,7 +367,7 @@ class Backtrace(object):
  all_wt[out_layer] = start_wt * multiplier
  layer_stack = self.layer_stack

- for start_layer in layer_stack:
+ for start_layer in tqdm(layer_stack):
  if model_resource[1][start_layer]["child"]:
  child_nodes = model_resource[1][start_layer]["child"]
  for ch in child_nodes:
@@ -538,21 +540,25 @@ class Backtrace(object):
  elif model_resource[1][start_layer]["class"] == "Self_Attention":
  weights = all_wts[start_layer]
  self_attention_weights = HP.rename_self_attention_keys(weights)
+ config = self.model.config

  temp_wt = UP.calculate_wt_self_attention(
  all_wt[start_layer],
  all_out[child_nodes[0]][0].detach().numpy(),
  self_attention_weights,
+ config
  )
  all_wt[child_nodes[0]] += temp_wt
+
  elif model_resource[1][start_layer]["class"] == 'Residual':
- temp_wt = UP.calculate_wt_add(
+ temp_wt = UP.calculate_wt_residual(
  all_wt[start_layer],
  [all_out[ch].detach().numpy() for ch in child_nodes],
  )

  for ind, ch in enumerate(child_nodes):
  all_wt[ch] += temp_wt[ind]
+
  elif model_resource[1][start_layer]["class"] == 'Feed_Forward':
  weights = all_wts[start_layer]
  feed_forward_weights = HP.rename_feed_forward_keys(weights)
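With the dispatch changes above, the Self_Attention (and, later, Cross_Attention) handlers now pass the model's config object (`self.model.config`) into the propagation utilities so they can recover the attention geometry. A minimal, self-contained sketch of the attribute fallbacks those utilities rely on; the `SimpleNamespace` stand-in is illustrative only, since the package reads a real transformers-style config:

    # Local stand-in for a transformers config; the package itself uses `self.model.config`.
    from types import SimpleNamespace

    config = SimpleNamespace(num_attention_heads=12, hidden_size=768)  # BERT-base-like values (assumed)
    num_heads = getattr(config, "num_attention_heads", getattr(config, "num_heads", None))  # BERT vs T5/Llama naming
    hidden_size = getattr(config, "hidden_size", getattr(config, "d_model", None))
    num_kv_heads = getattr(config, "num_key_value_heads", num_heads)  # only grouped-query models define this
    head_dim = hidden_size // num_heads
    print(num_heads, hidden_size, num_kv_heads, head_dim)  # 12 768 12 64

The `num_attention_heads`/`hidden_size` names cover BERT-style configs, `num_heads`/`d_model` cover T5/Llama-style ones, and `num_key_value_heads` only appears on grouped-query models, which is exactly the set of fallbacks the new code in this release probes.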
@@ -1183,27 +1183,85 @@ def stabilize(matrix, epsilon=1e-6):
  return matrix + epsilon * np.sign(matrix)


- def calculate_relevance_V(wts, value_output):
- # Initialize wt_mat with zeros
- wt_mat_V = np.zeros((wts.shape[0], wts.shape[1], *value_output.shape))
+ def calculate_wt_residual(wts, inp=None):
+ wt_mat = []
+ inp_list = []
+ expanded_wts = as_strided(
+ wts,
+ shape=(np.prod(wts.shape),),
+ strides=(wts.strides[-1],),
+ writeable=False, # totally use this to avoid writing to memory in weird places
+ )
+
+ for x in inp:
+ expanded_input = as_strided(
+ x,
+ shape=(np.prod(x.shape),),
+ strides=(x.strides[-1],),
+ writeable=False, # totally use this to avoid writing to memory in weird places
+ )
+ inp_list.append(expanded_input)
+ wt_mat.append(np.zeros_like(expanded_input))
+ wt_mat = np.array(wt_mat)
+ inp_list = np.array(inp_list)
+ for i in range(wt_mat.shape[1]):
+ wt_ind1 = wt_mat[:, i]
+ wt = expanded_wts[i]
+ l1_ind1 = inp_list[:, i]
+ p_ind = l1_ind1 > 0
+ n_ind = l1_ind1 < 0
+ p_sum = np.sum(l1_ind1[p_ind])
+ n_sum = np.sum(l1_ind1[n_ind]) * -1
+ t_sum = p_sum - n_sum
+ p_agg_wt = 0
+ n_agg_wt = 0
+ if p_sum + n_sum > 0:
+ p_agg_wt = p_sum / (p_sum + n_sum)
+ n_agg_wt = n_sum / (p_sum + n_sum)
+ if p_sum == 0:
+ p_sum = 1
+ if n_sum == 0:
+ n_sum = 1
+ wt_ind1[p_ind] = (l1_ind1[p_ind] / p_sum) * wt * p_agg_wt
+ wt_ind1[n_ind] = (l1_ind1[n_ind] / n_sum) * wt * n_agg_wt * -1.0
+ wt_mat[:, i] = wt_ind1
+ wt_mat = [i.reshape(wts.shape) for i in list(wt_mat)]
+ return wt_mat
+
+
+ def calculate_relevance_V(wts, value_output, w):
+ wt_mat_V = np.zeros(value_output.shape)
+
+ if 'b_v' in w:
+ bias_v = w['b_v']
+ else:
+ bias_v = 0

  for i in range(wts.shape[0]):
  for j in range(wts.shape[1]):
  l1_ind1 = value_output
- wt_ind1 = wt_mat_V[i, j]
  wt = wts[i, j]

  p_ind = l1_ind1 > 0
  n_ind = l1_ind1 < 0
  p_sum = np.sum(l1_ind1[p_ind])
  n_sum = np.sum(l1_ind1[n_ind]) * -1
+
+ if bias_v[i] > 0:
+ pbias = bias_v[i]
+ nbias = 0
+ else:
+ pbias = 0
+ nbias = bias_v[i] * -1

  if p_sum > 0:
- p_agg_wt = p_sum / (p_sum + n_sum)
+ p_agg_wt = (p_sum + pbias) / (p_sum + n_sum + pbias + nbias)
+ p_agg_wt = p_agg_wt * (p_sum / (p_sum + pbias))
  else:
  p_agg_wt = 0
  if n_sum > 0:
- n_agg_wt = n_sum / (p_sum + n_sum)
+ n_agg_wt = (n_sum + nbias) / (p_sum + n_sum + pbias + nbias)
+ n_agg_wt = n_agg_wt * (n_sum / (n_sum + nbias))
  else:
  n_agg_wt = 0
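The rewritten `calculate_relevance_V` above keeps the positive/negative splitting of contributions but now folds the value-projection bias `b_v` into the aggregate weights. A compact sketch of that splitting rule on a toy vector; it mirrors the formula in the diff but is not the packaged implementation, and the zero-sum guards are simplified:

    import numpy as np

    def split_relevance(contrib, relevance, bias=0.0):
        """Toy version of the positive/negative split with a bias term."""
        p_ind, n_ind = contrib > 0, contrib < 0
        p_sum = contrib[p_ind].sum()
        n_sum = -contrib[n_ind].sum()
        pbias, nbias = (bias, 0.0) if bias > 0 else (0.0, -bias)

        total = p_sum + n_sum + pbias + nbias
        # the bias enlarges the pool totals, but relevance is only spread over real contributions
        p_agg = (p_sum + pbias) / total * (p_sum / (p_sum + pbias)) if p_sum > 0 else 0.0
        n_agg = (n_sum + nbias) / total * (n_sum / (n_sum + nbias)) if n_sum > 0 else 0.0

        out = np.zeros_like(contrib, dtype=float)
        out[p_ind] = contrib[p_ind] / max(p_sum, 1e-12) * relevance * p_agg
        out[n_ind] = contrib[n_ind] / max(n_sum, 1e-12) * relevance * -n_agg
        return out

    print(split_relevance(np.array([2.0, -1.0, 3.0]), 1.0, bias=0.5))

For contributions [2, -1, 3] with relevance 1.0 and bias 0.5, the positive entries receive roughly 0.31 and 0.46 and the negative entry roughly 0.15.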
@@ -1212,21 +1270,22 @@ def calculate_relevance_V(wts, value_output):
  if n_sum == 0:
  n_sum = 1

- wt_ind1[p_ind] = (l1_ind1[p_ind] / p_sum) * wt * p_agg_wt
- wt_ind1[n_ind] = (l1_ind1[n_ind] / n_sum) * wt * n_agg_wt * -1.0
+ wt_mat_V[p_ind] += (l1_ind1[p_ind] / p_sum) * wt * p_agg_wt
+ wt_mat_V[n_ind] += (l1_ind1[n_ind] / n_sum) * wt * n_agg_wt * -1.0

- wt_mat_V = np.sum(wt_mat_V, axis=(0,1))
  return wt_mat_V


- def calculate_relevance_QK(wts, QK_output):
- # Initialize wt_mat with zeros
- wt_mat_QK = np.zeros((wts.shape[0], wts.shape[1], *QK_output.shape))
+ def calculate_relevance_QK(wts, QK_output, w):
+ wt_mat_QK = np.zeros(QK_output.shape)
+
+ # Check if 'b_q' and 'b_k' exist in the weights, default to 0 if not
+ b_q = w['b_q'] if 'b_q' in w else 0
+ b_k = w['b_k'] if 'b_k' in w else 0

  for i in range(wts.shape[0]):
  for j in range(wts.shape[1]):
  l1_ind1 = QK_output
- wt_ind1 = wt_mat_QK[i, j]
  wt = wts[i, j]

  p_ind = l1_ind1 > 0
@@ -1234,7 +1293,21 @@ def calculate_relevance_QK(wts, QK_output):
  p_sum = np.sum(l1_ind1[p_ind])
  n_sum = np.sum(l1_ind1[n_ind]) * -1

- t_sum = p_sum - n_sum
+ if b_q[i] > 0 and b_k[i] > 0:
+ pbias = b_q[i] + b_k[i]
+ nbias = 0
+ elif b_q[i] > 0 and b_k[i] < 0:
+ pbias = b_q[i]
+ nbias = b_k[i] * -1
+ elif b_q[i] < 0 and b_k[i] > 0:
+ pbias = b_k[i]
+ nbias = b_q[i] * -1
+ else:
+ pbias = 0
+ nbias = b_q[i] + b_k[i]
+ nbias *= -1
+
+ t_sum = p_sum + pbias - n_sum - nbias

  # This layer has a softmax activation function
  act = {
@@ -1253,12 +1326,13 @@ def calculate_relevance_QK(wts, QK_output):
  n_sum = 0

  if p_sum > 0:
- p_agg_wt = p_sum / (p_sum + n_sum)
+ p_agg_wt = (p_sum + pbias) / (p_sum + n_sum + pbias + nbias)
+ p_agg_wt = p_agg_wt * (p_sum / (p_sum + pbias))
  else:
  p_agg_wt = 0
-
  if n_sum > 0:
- n_agg_wt = n_sum / (p_sum + n_sum)
+ n_agg_wt = (n_sum + nbias) / (p_sum + n_sum + pbias + nbias)
+ n_agg_wt = n_agg_wt * (n_sum / (n_sum + nbias))
  else:
  n_agg_wt = 0

@@ -1267,14 +1341,60 @@ def calculate_relevance_QK(wts, QK_output):
  if n_sum == 0:
  n_sum = 1

- wt_ind1[p_ind] = (l1_ind1[p_ind] / p_sum) * wt * p_agg_wt
- wt_ind1[n_ind] = (l1_ind1[n_ind] / n_sum) * wt * n_agg_wt * -1.0
+ wt_mat_QK[p_ind] += (l1_ind1[p_ind] / p_sum) * wt * p_agg_wt
+ wt_mat_QK[n_ind] += (l1_ind1[n_ind] / n_sum) * wt * n_agg_wt * -1.0

- wt_mat_QK = np.sum(wt_mat_QK, axis=(0, 1))
  return wt_mat_QK


- def calculate_wt_self_attention(wts, inp, w):
+ def calculate_wt_attention_output_projection(wts, proj_output, w):
+ wt_mat_proj_output = np.zeros(proj_output.shape)
+
+ if 'b_d' in w:
+ bias_d = w['b_d']
+ else:
+ bias_d = 0
+
+ for i in range(wts.shape[0]):
+ for j in range(wts.shape[1]):
+ l1_ind1 = proj_output
+ wt = wts[i, j]
+
+ p_ind = l1_ind1 > 0
+ n_ind = l1_ind1 < 0
+ p_sum = np.sum(l1_ind1[p_ind])
+ n_sum = np.sum(l1_ind1[n_ind]) * -1
+
+ if bias_d[i] > 0:
+ pbias = bias_d[i]
+ nbias = 0
+ else:
+ pbias = 0
+ nbias = bias_d[i] * -1
+
+ if p_sum > 0:
+ p_agg_wt = (p_sum + pbias) / (p_sum + n_sum + pbias + nbias)
+ p_agg_wt = p_agg_wt * (p_sum / (p_sum + pbias))
+ else:
+ p_agg_wt = 0
+ if n_sum > 0:
+ n_agg_wt = (n_sum + nbias) / (p_sum + n_sum + pbias + nbias)
+ n_agg_wt = n_agg_wt * (n_sum / (n_sum + nbias))
+ else:
+ n_agg_wt = 0
+
+ if p_sum == 0:
+ p_sum = 1
+ if n_sum == 0:
+ n_sum = 1
+
+ wt_mat_proj_output[p_ind] += (l1_ind1[p_ind] / p_sum) * wt * p_agg_wt
+ wt_mat_proj_output[n_ind] += (l1_ind1[n_ind] / n_sum) * wt * n_agg_wt * -1.0
+
+ return wt_mat_proj_output
+
+
+ def calculate_wt_self_attention(wts, inp, w, config):
  '''
  Input:
  wts: relevance score of the layer
@@ -1286,28 +1406,82 @@ def calculate_wt_self_attention(wts, inp, w):
  Step-2: outputs = F.softmax(inputs, dim=dim, dtype=dtype)
  Step-3: outputs = input_a * input_b
  '''
+ # print(f"inp: {inp.shape}, wts: {wts.shape}") # (1, 512)
+ # print(f"w['W_q']: {w['W_q'].shape}, w['W_k']: {w['W_k'].shape}, w['W_v']: {w['W_v'].shape}")
+
  query_output = np.einsum('ij,kj->ik', inp, w['W_q'])
  key_output = np.einsum('ij,kj->ik', inp, w['W_k'])
  value_output = np.einsum('ij,kj->ik', inp, w['W_v'])

+ # --------------- Reshape for Multi-Head Attention ----------------------
+ num_heads = getattr(config, 'num_attention_heads', getattr(config, 'num_heads', None)) # will work for BERT as well as T5/ Llama
+ hidden_size = getattr(config, 'hidden_size', getattr(config, 'd_model', None)) # will work for BERT as well as T5/Llama
+ if hasattr(config, 'num_key_value_heads'):
+ num_key_value_heads = config.num_key_value_heads
+ else:
+ num_key_value_heads = num_heads
+ head_dim = hidden_size // num_heads # dimension of each attention head
+
+ query_states = np.einsum('thd->htd', query_output.reshape(query_output.shape[0], num_heads, head_dim)) # (num_heads, num_tokens, head_dim)
+ key_states = np.einsum('thd->htd', key_output.reshape(key_output.shape[0], num_key_value_heads, head_dim)) # (num_key_value_heads, num_tokens, head_dim)
+ value_states = np.einsum('thd->htd', value_output.reshape(value_output.shape[0], num_key_value_heads, head_dim)) # (num_key_value_heads, num_tokens, head_dim)
+
+ # calculate how many times we need to repeat the key/value heads
+ n_rep = num_heads // num_key_value_heads
+ key_states = np.repeat(key_states, n_rep, axis=0)
+ value_states = np.repeat(value_states, n_rep, axis=0)
+
+ QK_output = np.einsum('hqd,hkd->hqk', query_states, key_states) # (num_heads, num_tokens, num_tokens)
+ attn_weights = QK_output / np.sqrt(head_dim)
+
+ # Apply softmax along the last dimension (softmax over key dimension)
+ attn_weights = np.exp(attn_weights - np.max(attn_weights, axis=-1, keepdims=True)) # Numerically stable softmax
+ attn_weights = attn_weights / np.sum(attn_weights, axis=-1, keepdims=True)
+
+ # Weighted sum of values (num_heads, num_tokens, head_dim)
+ attn_output = np.einsum('hqk,hkl->hql', attn_weights, value_states)
+
+ transposed_attn_output = np.einsum('hqd->qhd', attn_output)
+ reshaped_attn_output = transposed_attn_output.reshape(transposed_attn_output.shape[0], num_heads * head_dim)
+
+ # Perform final linear projection (num_tokens, hidden_size)
+ final_output = np.einsum('qd,dh->qh', reshaped_attn_output, w['W_d'])
+
+ # ------------- Relevance calculation for Final Linear Projection -------------
+ wt_mat_attn_proj = calculate_wt_attention_output_projection(wts, final_output, w)
+
  # --------------- Relevance Calculation for Step-3 -----------------------
- relevance_V = wts / 2
- relevance_QK = wts / 2
+ # divide the relevance among `attn_weights` and `value_states`
+ wt_mat_attn_proj = wt_mat_attn_proj.reshape(-1, num_heads, head_dim)
+ wt_mat_attn_proj = np.einsum('qhd->hqd', wt_mat_attn_proj)
+
+ stabilized_attn_output = stabilize(attn_output * 2)
+ norm_wt_mat_attn_proj = wt_mat_attn_proj / stabilized_attn_output
+ relevance_QK = np.einsum('htd,hbd->htb', norm_wt_mat_attn_proj, value_states) * attn_weights
+ relevance_V = np.einsum('hdt,hdb->htb', attn_weights, norm_wt_mat_attn_proj) * value_states

  # --------------- Relevance Calculation for V --------------------------------
- wt_mat_V = calculate_relevance_V(relevance_V, value_output)
+ relevance_V = np.einsum('hqd->qhd', relevance_V)
+ relevance_V = relevance_V.reshape(-1, num_heads * head_dim)
+ wt_mat_V = calculate_relevance_V(relevance_V, value_states, w)

  # --------------- Transformed Relevance QK ----------------------------------
- QK_output = np.einsum('ij,kj->ik', query_output, key_output)
- wt_mat_QK = calculate_relevance_QK(relevance_QK, QK_output)
+ relevance_QK = np.einsum('hqd->qhd', relevance_QK)
+ relevance_QK = relevance_QK.reshape(-1, relevance_QK.shape[1] * relevance_QK.shape[2])
+ wt_mat_QK = calculate_relevance_QK(relevance_QK, QK_output, w)

  # --------------- Relevance Calculation for K and Q --------------------------------
  stabilized_QK_output = stabilize(QK_output * 2)
  norm_wt_mat_QK = wt_mat_QK / stabilized_QK_output
- wt_mat_Q = np.einsum('ij,jk->ik', norm_wt_mat_QK, key_output) * query_output
- wt_mat_K = np.einsum('ij,ik->kj', query_output, norm_wt_mat_QK) * key_output
+ wt_mat_Q = np.einsum('htd,hdb->htb', norm_wt_mat_QK, key_states) * query_states
+ wt_mat_K = np.einsum('htd,htb->hbd', query_states, norm_wt_mat_QK) * key_states

  wt_mat = wt_mat_V + wt_mat_K + wt_mat_Q
+
+ # Reshape wt_mat
+ wt_mat = np.einsum('htd->thd', wt_mat)
+ wt_mat = wt_mat.reshape(wt_mat.shape[0], wt_mat.shape[1] * wt_mat.shape[2]) # reshaped_array = array.reshape(8, 32 * 128)
+
  return wt_mat
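The new multi-head section above recomputes the attention forward pass in NumPy before distributing relevance: the Q/K/V projections are split into heads, key/value heads are repeated for grouped-query models, and a numerically stable softmax is applied over the key axis. A standalone sketch of just that forward reconstruction, with illustrative shapes and names rather than the package API:

    import numpy as np

    def reconstruct_attention(query_states, key_states, value_states, n_rep):
        """query_states: (num_heads, T, head_dim); key/value_states: (num_kv_heads, T, head_dim)."""
        # repeat key/value heads so every query head has a matching key/value head
        key_states = np.repeat(key_states, n_rep, axis=0)
        value_states = np.repeat(value_states, n_rep, axis=0)

        head_dim = query_states.shape[-1]
        scores = np.einsum('hqd,hkd->hqk', query_states, key_states) / np.sqrt(head_dim)

        # numerically stable softmax over the key dimension
        scores = np.exp(scores - scores.max(axis=-1, keepdims=True))
        attn_weights = scores / scores.sum(axis=-1, keepdims=True)

        # weighted sum of values: (num_heads, T, head_dim)
        return np.einsum('hqk,hkd->hqd', attn_weights, value_states), attn_weights

    q = np.random.randn(8, 5, 16)   # 8 query heads (toy sizes)
    k = np.random.randn(2, 5, 16)   # 2 key/value heads, i.e. grouped-query attention
    v = np.random.randn(2, 5, 16)
    out, attn = reconstruct_attention(q, k, v, n_rep=8 // 2)
    print(out.shape, attn.shape)    # (8, 5, 16) (8, 5, 5)

With `num_key_value_heads == num_heads` (BERT-style models) `n_rep` is 1 and the `np.repeat` calls are no-ops, so the same code path covers both cases.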
@@ -1323,7 +1497,9 @@ def calculate_wt_feed_forward(wts, inp, w):
  R2 = wts[i]
  contribution_matrix2 = np.einsum('ij,j->ij', w['W_out'], intermediate_output[i])
  wt_mat2 = np.zeros(contribution_matrix2.shape)
-
+
+ bias_out = w['b_out'] if 'b_out' in w else 0
+
  for j in range(contribution_matrix2.shape[0]):
  l1_ind1 = contribution_matrix2[j]
  wt_ind1 = wt_mat2[j]
@@ -1333,14 +1509,23 @@ def calculate_wt_feed_forward(wts, inp, w):
  n_ind = l1_ind1 < 0
  p_sum = np.sum(l1_ind1[p_ind])
  n_sum = np.sum(l1_ind1[n_ind]) * -1
+
+ # Handle positive and negative bias contributions
+ if bias_out[i] > 0:
+ pbias = bias_out[i]
+ nbias = 0
+ else:
+ pbias = 0
+ nbias = -bias_out[i]

  if p_sum > 0:
- p_agg_wt = p_sum / (p_sum + n_sum)
+ p_agg_wt = (p_sum + pbias) / (p_sum + n_sum + pbias + nbias)
+ p_agg_wt = p_agg_wt * (p_sum / (p_sum + pbias))
  else:
  p_agg_wt = 0
-
  if n_sum > 0:
- n_agg_wt = n_sum / (p_sum + n_sum)
+ n_agg_wt = (n_sum + nbias) / (p_sum + n_sum + pbias + nbias)
+ n_agg_wt = n_agg_wt * (n_sum / (n_sum + nbias))
  else:
  n_agg_wt = 0

@@ -1359,6 +1544,9 @@ def calculate_wt_feed_forward(wts, inp, w):
  R1 = relevance_out[i]
  contribution_matrix1 = np.einsum('ij,j->ij', w['W_int'], inp[i])
  wt_mat1 = np.zeros(contribution_matrix1.shape)
+
+ # Check if bias 'b_int' exists, default to 0 if not
+ bias_int = w['b_int'] if 'b_int' in w else 0

  for j in range(contribution_matrix1.shape[0]):
  l1_ind1 = contribution_matrix1[j]
@@ -1370,7 +1558,15 @@ def calculate_wt_feed_forward(wts, inp, w):
  p_sum = np.sum(l1_ind1[p_ind])
  n_sum = np.sum(l1_ind1[n_ind]) * -1

- t_sum = p_sum - n_sum
+ # Handle positive and negative bias
+ if bias_int[i] > 0:
+ pbias = bias_int[i]
+ nbias = 0
+ else:
+ pbias = 0
+ nbias = -bias_int[i]
+
+ t_sum = p_sum + pbias - n_sum - nbias

  # This layer has a ReLU activation function
  act = {
@@ -1389,12 +1585,13 @@ def calculate_wt_feed_forward(wts, inp, w):
  n_sum = 0

  if p_sum > 0:
- p_agg_wt = p_sum / (p_sum + n_sum)
+ p_agg_wt = (p_sum + pbias) / (p_sum + n_sum + pbias + nbias)
+ p_agg_wt = p_agg_wt * (p_sum / (p_sum + pbias))
  else:
  p_agg_wt = 0
-
  if n_sum > 0:
- n_agg_wt = n_sum / (p_sum + n_sum)
+ n_agg_wt = (n_sum + nbias) / (p_sum + n_sum + pbias + nbias)
+ n_agg_wt = n_agg_wt * (n_sum / (n_sum + nbias))
  else:
  n_agg_wt = 0

@@ -1551,7 +1748,7 @@ def calculate_wt_pooler(wts, inp, w):
  # Calculate relevance for each token
  relevance_inp[i] = wt_mat.sum(axis=0)

- relevance_inp *= (100 / np.sum(relevance_inp))
+ relevance_inp *= (np.sum(wts) / np.sum(relevance_inp))
  return relevance_inp
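The pooler change above swaps the fixed rescaling to a total of 100 for a rescaling to the incoming relevance mass `np.sum(wts)`, so the layer conserves whatever relevance it receives. A toy illustration of the difference:

    import numpy as np

    wts = np.array([0.2, 0.3])            # incoming relevance, sums to 0.5
    relevance_inp = np.array([1.0, 3.0])  # unnormalised per-token relevance

    old = relevance_inp * (100 / relevance_inp.sum())        # always sums to 100
    new = relevance_inp * (wts.sum() / relevance_inp.sum())  # sums to 0.5, conserving relevance
    print(old.sum(), new.sum())  # 100.0 0.5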
@@ -192,15 +192,15 @@ class Backtrace(object):
  return temp_out

  def eval(self, all_out, start_wt=[], mode="default",multiplier=100.0,
- scaler=None, max_unit=0,thresholding=0.5,
+ scaler=0, max_unit=0,thresholding=0.5,
  task="binary-classification",predicted_token=None):

  if mode=="default":
  output = self.proportional_eval(all_out=all_out,
- start_wt=start_wt ,
- multiplier=multiplier,
- scaler=scaler,
- max_unit=max_unit,
+ start_wt=start_wt ,
+ multiplier=multiplier,
+ scaler=scaler,
+ max_unit=max_unit,
  thresholding=thresholding,
  task=task,
  predicted_token=predicted_token)
@@ -219,7 +219,7 @@ class Backtrace(object):
  return output

  def proportional_eval(self, all_out, start_wt=[] ,
- multiplier=100.0, scaler=None, max_unit=0,
+ multiplier=100.0, scaler=0, max_unit=0,
  predicted_token=None, thresholding=0.5,
  task="binary-classification"):
  model_resource = self.model_resource
@@ -229,7 +229,7 @@ class Backtrace(object):
  all_wt = {}
  if len(start_wt) == 0:
  if self.model_type == 'encoder':
- start_wt = UP.calculate_start_wt(all_out[out_layer])
+ start_wt = UP.calculate_start_wt(all_out[out_layer], scaler=scaler)
  all_wt[out_layer] = start_wt * multiplier
  layer_stack = self.layer_stack
  all_wts = self.model_weights
@@ -442,10 +442,12 @@ class Backtrace(object):
  elif model_resource["graph"][start_layer]["class"] == "Self_Attention":
  weights = all_wts[start_layer]
  self_attention_weights = HP.rename_self_attention_keys(weights)
+ config = self.model.config
  temp_wt = UP.calculate_wt_self_attention(
  all_wt[start_layer],
  all_out[child_nodes[0]][0],
  self_attention_weights,
+ config
  )
  all_wt[child_nodes[0]] += temp_wt
  elif model_resource["graph"][start_layer]["class"] == 'Residual':
@@ -502,10 +504,12 @@ class Backtrace(object):
  elif model_resource["graph"][start_layer]["class"] == 'Cross_Attention':
  weights = all_wts[start_layer]
  cross_attention_weights = HP.rename_cross_attention_keys(weights)
+ config = self.model.config
  temp_wt = UP.calculate_wt_cross_attention(
  all_wt[start_layer],
  [all_out[ch][0] for ch in child_nodes],
  cross_attention_weights,
+ config
  )
  for ind, ch in enumerate(child_nodes):
  all_wt[ch] += temp_wt[ind]
@@ -1161,14 +1161,17 @@ def calculate_wt_residual(wts, inp=None):
  return wt_mat


- def calculate_relevance_V(wts, value_output):
- # Initialize wt_mat with zeros
- wt_mat_V = np.zeros((wts.shape[0], wts.shape[1], *value_output.shape))
+ def calculate_relevance_V(wts, value_output, w):
+ wt_mat_V = np.zeros(value_output.shape)
+
+ if 'b_v' in w:
+ bias_v = w['b_v']
+ else:
+ bias_v = 0

  for i in range(wts.shape[0]):
  for j in range(wts.shape[1]):
  l1_ind1 = value_output
- wt_ind1 = wt_mat_V[i, j]
  wt = wts[i, j]

  p_ind = l1_ind1 > 0
@@ -1176,12 +1179,21 @@ def calculate_relevance_V(wts, value_output):
  p_sum = np.sum(l1_ind1[p_ind])
  n_sum = np.sum(l1_ind1[n_ind]) * -1

+ if bias_v[i] > 0:
+ pbias = bias_v[i]
+ nbias = 0
+ else:
+ pbias = 0
+ nbias = bias_v[i] * -1
+
  if p_sum > 0:
- p_agg_wt = p_sum / (p_sum + n_sum)
+ p_agg_wt = (p_sum + pbias) / (p_sum + n_sum + pbias + nbias)
+ p_agg_wt = p_agg_wt * (p_sum / (p_sum + pbias))
  else:
  p_agg_wt = 0
  if n_sum > 0:
- n_agg_wt = n_sum / (p_sum + n_sum)
+ n_agg_wt = (n_sum + nbias) / (p_sum + n_sum + pbias + nbias)
+ n_agg_wt = n_agg_wt * (n_sum / (n_sum + nbias))
  else:
  n_agg_wt = 0

@@ -1190,21 +1202,22 @@ def calculate_relevance_V(wts, value_output):
  if n_sum == 0:
  n_sum = 1

- wt_ind1[p_ind] = (l1_ind1[p_ind] / p_sum) * wt * p_agg_wt
- wt_ind1[n_ind] = (l1_ind1[n_ind] / n_sum) * wt * n_agg_wt * -1.0
+ wt_mat_V[p_ind] += (l1_ind1[p_ind] / p_sum) * wt * p_agg_wt
+ wt_mat_V[n_ind] += (l1_ind1[n_ind] / n_sum) * wt * n_agg_wt * -1.0

- wt_mat_V = np.sum(wt_mat_V, axis=(0,1))
- return wt_mat_V
+ return wt_mat_V


- def calculate_relevance_QK(wts, QK_output):
- # Initialize wt_mat with zeros
- wt_mat_QK = np.zeros((wts.shape[0], wts.shape[1], *QK_output.shape))
+ def calculate_relevance_QK(wts, QK_output, w):
+ wt_mat_QK = np.zeros(QK_output.shape)
+
+ # Check if 'b_q' and 'b_k' exist in the weights, default to 0 if not
+ b_q = w['b_q'] if 'b_q' in w else 0
+ b_k = w['b_k'] if 'b_k' in w else 0

  for i in range(wts.shape[0]):
  for j in range(wts.shape[1]):
  l1_ind1 = QK_output
- wt_ind1 = wt_mat_QK[i, j]
  wt = wts[i, j]

  p_ind = l1_ind1 > 0
@@ -1212,7 +1225,21 @@ def calculate_relevance_QK(wts, QK_output):
  p_sum = np.sum(l1_ind1[p_ind])
  n_sum = np.sum(l1_ind1[n_ind]) * -1

- t_sum = p_sum - n_sum
+ if b_q[i] > 0 and b_k[i] > 0:
+ pbias = b_q[i] + b_k[i]
+ nbias = 0
+ elif b_q[i] > 0 and b_k[i] < 0:
+ pbias = b_q[i]
+ nbias = b_k[i] * -1
+ elif b_q[i] < 0 and b_k[i] > 0:
+ pbias = b_k[i]
+ nbias = b_q[i] * -1
+ else:
+ pbias = 0
+ nbias = b_q[i] + b_k[i]
+ nbias *= -1
+
+ t_sum = p_sum + pbias - n_sum - nbias

  # This layer has a softmax activation function
  act = {
@@ -1231,12 +1258,13 @@ def calculate_relevance_QK(wts, QK_output):
  n_sum = 0

  if p_sum > 0:
- p_agg_wt = p_sum / (p_sum + n_sum)
+ p_agg_wt = (p_sum + pbias) / (p_sum + n_sum + pbias + nbias)
+ p_agg_wt = p_agg_wt * (p_sum / (p_sum + pbias))
  else:
  p_agg_wt = 0
-
  if n_sum > 0:
- n_agg_wt = n_sum / (p_sum + n_sum)
+ n_agg_wt = (n_sum + nbias) / (p_sum + n_sum + pbias + nbias)
+ n_agg_wt = n_agg_wt * (n_sum / (n_sum + nbias))
  else:
  n_agg_wt = 0

@@ -1245,14 +1273,60 @@ def calculate_relevance_QK(wts, QK_output):
  if n_sum == 0:
  n_sum = 1

- wt_ind1[p_ind] = (l1_ind1[p_ind] / p_sum) * wt * p_agg_wt
- wt_ind1[n_ind] = (l1_ind1[n_ind] / n_sum) * wt * n_agg_wt * -1.0
+ wt_mat_QK[p_ind] += (l1_ind1[p_ind] / p_sum) * wt * p_agg_wt
+ wt_mat_QK[n_ind] += (l1_ind1[n_ind] / n_sum) * wt * n_agg_wt * -1.0

- wt_mat_QK = np.sum(wt_mat_QK, axis=(0, 1))
  return wt_mat_QK


- def calculate_wt_self_attention(wts, inp, w):
+ def calculate_wt_attention_output_projection(wts, proj_output, w):
+ wt_mat_proj_output = np.zeros(proj_output.shape)
+
+ if 'b_d' in w:
+ bias_d = w['b_d']
+ else:
+ bias_d = 0
+
+ for i in range(wts.shape[0]):
+ for j in range(wts.shape[1]):
+ l1_ind1 = proj_output
+ wt = wts[i, j]
+
+ p_ind = l1_ind1 > 0
+ n_ind = l1_ind1 < 0
+ p_sum = np.sum(l1_ind1[p_ind])
+ n_sum = np.sum(l1_ind1[n_ind]) * -1
+
+ if bias_d[i] > 0:
+ pbias = bias_d[i]
+ nbias = 0
+ else:
+ pbias = 0
+ nbias = bias_d[i] * -1
+
+ if p_sum > 0:
+ p_agg_wt = (p_sum + pbias) / (p_sum + n_sum + pbias + nbias)
+ p_agg_wt = p_agg_wt * (p_sum / (p_sum + pbias))
+ else:
+ p_agg_wt = 0
+ if n_sum > 0:
+ n_agg_wt = (n_sum + nbias) / (p_sum + n_sum + pbias + nbias)
+ n_agg_wt = n_agg_wt * (n_sum / (n_sum + nbias))
+ else:
+ n_agg_wt = 0
+
+ if p_sum == 0:
+ p_sum = 1
+ if n_sum == 0:
+ n_sum = 1
+
+ wt_mat_proj_output[p_ind] += (l1_ind1[p_ind] / p_sum) * wt * p_agg_wt
+ wt_mat_proj_output[n_ind] += (l1_ind1[n_ind] / n_sum) * wt * n_agg_wt * -1.0
+
+ return wt_mat_proj_output
+
+
+ def calculate_wt_self_attention(wts, inp, w, config):
  '''
  Input:
  wts: relevance score of the layer
@@ -1267,25 +1341,76 @@ def calculate_wt_self_attention(wts, inp, w):
  query_output = np.einsum('ij,kj->ik', inp, w['W_q'].T)
  key_output = np.einsum('ij,kj->ik', inp, w['W_k'].T)
  value_output = np.einsum('ij,kj->ik', inp, w['W_v'].T)
+
+ # --------------- Reshape for Multi-Head Attention ----------------------
+ num_heads = getattr(config, 'num_attention_heads', getattr(config, 'num_heads', None)) # will work for BERT as well as T5/ Llama
+ hidden_size = getattr(config, 'hidden_size', getattr(config, 'd_model', None)) # will work for BERT as well as T5/Llama
+ if hasattr(config, 'num_key_value_heads'):
+ num_key_value_heads = config.num_key_value_heads
+ else:
+ num_key_value_heads = num_heads
+ head_dim = hidden_size // num_heads # dimension of each attention head
+
+ query_states = np.einsum('thd->htd', query_output.reshape(query_output.shape[0], num_heads, head_dim)) # (num_heads, num_tokens, head_dim)
+ key_states = np.einsum('thd->htd', key_output.reshape(key_output.shape[0], num_key_value_heads, head_dim)) # (num_key_value_heads, num_tokens, head_dim)
+ value_states = np.einsum('thd->htd', value_output.reshape(value_output.shape[0], num_key_value_heads, head_dim)) # (num_key_value_heads, num_tokens, head_dim)
+
+ # calculate how many times we need to repeat the key/value heads
+ n_rep = num_heads // num_key_value_heads
+ key_states = np.repeat(key_states, n_rep, axis=0)
+ value_states = np.repeat(value_states, n_rep, axis=0)
+
+ QK_output = np.einsum('hqd,hkd->hqk', query_states, key_states) # (num_heads, num_tokens, num_tokens)
+ attn_weights = QK_output / np.sqrt(head_dim)
+
+ # Apply softmax along the last dimension (softmax over key dimension)
+ attn_weights = np.exp(attn_weights - np.max(attn_weights, axis=-1, keepdims=True)) # Numerically stable softmax
+ attn_weights = attn_weights / np.sum(attn_weights, axis=-1, keepdims=True)
+
+ # Weighted sum of values (num_heads, num_tokens, head_dim)
+ attn_output = np.einsum('hqk,hkl->hql', attn_weights, value_states)
+
+ transposed_attn_output = np.einsum('hqd->qhd', attn_output)
+ reshaped_attn_output = transposed_attn_output.reshape(transposed_attn_output.shape[0], num_heads * head_dim)
+
+ # Perform final linear projection (num_tokens, hidden_size)
+ final_output = np.einsum('qd,dh->qh', reshaped_attn_output, w['W_d'])
+
+ # ------------- Relevance calculation for Final Linear Projection -------------
+ wt_mat_attn_proj = calculate_wt_attention_output_projection(wts, final_output, w)

  # --------------- Relevance Calculation for Step-3 -----------------------
- relevance_V = wts / 2
- relevance_QK = wts / 2
+ # divide the relevance among `attn_weights` and `value_states`
+ wt_mat_attn_proj = wt_mat_attn_proj.reshape(-1, num_heads, head_dim)
+ wt_mat_attn_proj = np.einsum('qhd->hqd', wt_mat_attn_proj)

- # --------------- Relevance Calculation for V --------------------------------
- wt_mat_V = calculate_relevance_V(relevance_V, value_output)
+ stabilized_attn_output = stabilize(attn_output * 2)
+ norm_wt_mat_attn_proj = wt_mat_attn_proj / stabilized_attn_output
+ relevance_QK = np.einsum('htd,hbd->htb', norm_wt_mat_attn_proj, value_states) * attn_weights
+ relevance_V = np.einsum('hdt,hdb->htb', attn_weights, norm_wt_mat_attn_proj) * value_states

+ # --------------- Relevance Calculation for V --------------------------------
+ relevance_V = np.einsum('hqd->qhd', relevance_V)
+ relevance_V = relevance_V.reshape(-1, num_heads * head_dim)
+ wt_mat_V = calculate_relevance_V(relevance_V, value_states, w)
+
  # --------------- Transformed Relevance QK ----------------------------------
- QK_output = np.einsum('ij,kj->ik', query_output, key_output)
- wt_mat_QK = calculate_relevance_QK(relevance_QK, QK_output)
+ relevance_QK = np.einsum('hqd->qhd', relevance_QK)
+ relevance_QK = relevance_QK.reshape(-1, relevance_QK.shape[1] * relevance_QK.shape[2])
+ wt_mat_QK = calculate_relevance_QK(relevance_QK, QK_output, w)

  # --------------- Relevance Calculation for K and Q --------------------------------
  stabilized_QK_output = stabilize(QK_output * 2)
  norm_wt_mat_QK = wt_mat_QK / stabilized_QK_output
- wt_mat_Q = np.einsum('ij,jk->ik', norm_wt_mat_QK, key_output) * query_output
- wt_mat_K = np.einsum('ij,ik->kj', query_output, norm_wt_mat_QK) * key_output
+ wt_mat_Q = np.einsum('htd,hdb->htb', norm_wt_mat_QK, key_states) * query_states
+ wt_mat_K = np.einsum('htd,htb->hbd', query_states, norm_wt_mat_QK) * key_states

  wt_mat = wt_mat_V + wt_mat_K + wt_mat_Q
+
+ # Reshape wt_mat
+ wt_mat = np.einsum('htd->thd', wt_mat)
+ wt_mat = wt_mat.reshape(wt_mat.shape[0], wt_mat.shape[1] * wt_mat.shape[2]) # reshaped_array = array.reshape(8, 32 * 128)
+
  return wt_mat


@@ -1301,6 +1426,8 @@ def calculate_wt_feed_forward(wts, inp, w):
  R2 = wts[i]
  contribution_matrix2 = np.einsum('ij,j->ij', w['W_out'].T, intermediate_output[i])
  wt_mat2 = np.zeros(contribution_matrix2.shape)
+
+ bias_out = w['b_out'] if 'b_out' in w else 0

  for j in range(contribution_matrix2.shape[0]):
  l1_ind1 = contribution_matrix2[j]
@@ -1312,13 +1439,22 @@ def calculate_wt_feed_forward(wts, inp, w):
  p_sum = np.sum(l1_ind1[p_ind])
  n_sum = np.sum(l1_ind1[n_ind]) * -1

+ # Handle positive and negative bias contributions
+ if bias_out[i] > 0:
+ pbias = bias_out[i]
+ nbias = 0
+ else:
+ pbias = 0
+ nbias = -bias_out[i]
+
  if p_sum > 0:
- p_agg_wt = p_sum / (p_sum + n_sum)
+ p_agg_wt = (p_sum + pbias) / (p_sum + n_sum + pbias + nbias)
+ p_agg_wt = p_agg_wt * (p_sum / (p_sum + pbias))
  else:
  p_agg_wt = 0
-
  if n_sum > 0:
- n_agg_wt = n_sum / (p_sum + n_sum)
+ n_agg_wt = (n_sum + nbias) / (p_sum + n_sum + pbias + nbias)
+ n_agg_wt = n_agg_wt * (n_sum / (n_sum + nbias))
  else:
  n_agg_wt = 0

@@ -1337,6 +1473,9 @@ def calculate_wt_feed_forward(wts, inp, w):
  R1 = relevance_out[i]
  contribution_matrix1 = np.einsum('ij,j->ij', w['W_int'].T, inp[i])
  wt_mat1 = np.zeros(contribution_matrix1.shape)
+
+ # Check if bias 'b_int' exists, default to 0 if not
+ bias_int = w['b_int'] if 'b_int' in w else 0

  for j in range(contribution_matrix1.shape[0]):
  l1_ind1 = contribution_matrix1[j]
@@ -1348,7 +1487,15 @@ def calculate_wt_feed_forward(wts, inp, w):
  p_sum = np.sum(l1_ind1[p_ind])
  n_sum = np.sum(l1_ind1[n_ind]) * -1

- t_sum = p_sum - n_sum
+ # Handle positive and negative bias
+ if bias_int[i] > 0:
+ pbias = bias_int[i]
+ nbias = 0
+ else:
+ pbias = 0
+ nbias = -bias_int[i]
+
+ t_sum = p_sum + pbias - n_sum - nbias

  # This layer has a ReLU activation function
  act = {
@@ -1367,12 +1514,13 @@ def calculate_wt_feed_forward(wts, inp, w):
  n_sum = 0

  if p_sum > 0:
- p_agg_wt = p_sum / (p_sum + n_sum)
+ p_agg_wt = (p_sum + pbias) / (p_sum + n_sum + pbias + nbias)
+ p_agg_wt = p_agg_wt * (p_sum / (p_sum + pbias))
  else:
  p_agg_wt = 0
-
  if n_sum > 0:
- n_agg_wt = n_sum / (p_sum + n_sum)
+ n_agg_wt = (n_sum + nbias) / (p_sum + n_sum + pbias + nbias)
+ n_agg_wt = n_agg_wt * (n_sum / (n_sum + nbias))
  else:
  n_agg_wt = 0

@@ -1461,8 +1609,8 @@ def calculate_wt_pooler(wts, inp, w):
  # Calculate relevance for each token
  relevance_inp[i] = wt_mat.sum(axis=0)

- relevance_inp *= (100 / np.sum(relevance_inp))
- return relevance_inp
+ relevance_inp *= (np.sum(wts) / np.sum(relevance_inp))
+ return relevance_inp


  def calculate_wt_classifier(wts, inp, w):
@@ -1595,7 +1743,7 @@ def calculate_wt_lm_head(wts, inp, w):
  return relevance_input


- def calculate_wt_cross_attention(wts, inp, w):
+ def calculate_wt_cross_attention(wts, inp, w, config):
  '''
  Input:
  wts: relevance score of the layer
@@ -1613,23 +1761,77 @@ def calculate_wt_cross_attention(wts, inp, w):
  key_output = np.einsum('ij,kj->ik', k_v_inp, w['W_k'].T)
  value_output = np.einsum('ij,kj->ik', k_v_inp, w['W_v'].T)

+ # --------------- Reshape for Multi-Head Attention ----------------------
+ num_heads = getattr(config, 'num_attention_heads', getattr(config, 'num_heads', None)) # will work for BERT as well as T5/ Llama
+ hidden_size = getattr(config, 'hidden_size', getattr(config, 'd_model', None)) # will work for BERT as well as T5/Llama
+ if hasattr(config, 'num_key_value_heads'):
+ num_key_value_heads = config.num_key_value_heads
+ else:
+ num_key_value_heads = num_heads
+ head_dim = hidden_size // num_heads # dimension of each attention head
+
+ query_states = np.einsum('thd->htd', query_output.reshape(query_output.shape[0], num_heads, head_dim)) # (num_heads, num_tokens, head_dim)
+ key_states = np.einsum('thd->htd', key_output.reshape(key_output.shape[0], num_key_value_heads, head_dim)) # (num_key_value_heads, num_tokens, head_dim)
+ value_states = np.einsum('thd->htd', value_output.reshape(value_output.shape[0], num_key_value_heads, head_dim)) # (num_key_value_heads, num_tokens, head_dim)
+
+ # calculate how many times we need to repeat the key/value heads
+ n_rep = num_heads // num_key_value_heads
+ key_states = np.repeat(key_states, n_rep, axis=0)
+ value_states = np.repeat(value_states, n_rep, axis=0)
+
+ QK_output = np.einsum('hqd,hkd->hqk', query_states, key_states) # (num_heads, num_tokens, num_tokens)
+ attn_weights = QK_output / np.sqrt(head_dim)
+
+ # Apply softmax along the last dimension (softmax over key dimension)
+ attn_weights = np.exp(attn_weights - np.max(attn_weights, axis=-1, keepdims=True)) # Numerically stable softmax
+ attn_weights = attn_weights / np.sum(attn_weights, axis=-1, keepdims=True)
+
+ # Weighted sum of values (num_heads, num_tokens, head_dim)
+ attn_output = np.einsum('hqk,hkl->hql', attn_weights, value_states)
+
+ transposed_attn_output = np.einsum('hqd->qhd', attn_output)
+ reshaped_attn_output = transposed_attn_output.reshape(transposed_attn_output.shape[0], num_heads * head_dim)
+
+ # Perform final linear projection (num_tokens, hidden_size)
+ final_output = np.einsum('qd,dh->qh', reshaped_attn_output, w['W_d'])
+
+ # ------------- Relevance calculation for Final Linear Projection -------------
+ wt_mat_attn_proj = calculate_wt_attention_output_projection(wts, final_output)
+
  # --------------- Relevance Calculation for Step-3 -----------------------
- relevance_V = wts / 2
- relevance_QK = wts / 2
+ # divide the relevance among `attn_weights` and `value_states`
+ wt_mat_attn_proj = wt_mat_attn_proj.reshape(-1, num_heads, head_dim)
+ wt_mat_attn_proj = np.einsum('qhd->hqd', wt_mat_attn_proj)

- # --------------- Relevance Calculation for V --------------------------------
- wt_mat_V = calculate_relevance_V(relevance_V, value_output)
+ stabilized_attn_output = stabilize(attn_output * 2)
+ norm_wt_mat_attn_proj = wt_mat_attn_proj / stabilized_attn_output
+ relevance_QK = np.einsum('htd,hbd->htb', norm_wt_mat_attn_proj, value_states) * attn_weights
+ relevance_V = np.einsum('hdt,hdb->htb', attn_weights, norm_wt_mat_attn_proj) * value_states

+ # --------------- Relevance Calculation for V --------------------------------
+ relevance_V = np.einsum('hqd->qhd', relevance_V)
+ relevance_V = relevance_V.reshape(-1, num_heads * head_dim)
+ wt_mat_V = calculate_relevance_V(relevance_V, value_states)
+
  # --------------- Transformed Relevance QK ----------------------------------
- QK_output = np.einsum('ij,kj->ik', query_output, key_output)
+ relevance_QK = np.einsum('hqd->qhd', relevance_QK)
+ relevance_QK = relevance_QK.reshape(-1, relevance_QK.shape[1] * relevance_QK.shape[2])
  wt_mat_QK = calculate_relevance_QK(relevance_QK, QK_output)

  # --------------- Relevance Calculation for K and Q --------------------------------
  stabilized_QK_output = stabilize(QK_output * 2)
  norm_wt_mat_QK = wt_mat_QK / stabilized_QK_output
- wt_mat_Q = np.einsum('ij,jk->ik', norm_wt_mat_QK, key_output) * query_output
- wt_mat_K = np.einsum('ij,ik->kj', query_output, norm_wt_mat_QK) * key_output
+ wt_mat_Q = np.einsum('htd,hdb->htb', norm_wt_mat_QK, key_states) * query_states
+ wt_mat_K = np.einsum('htd,htb->hbd', query_states, norm_wt_mat_QK) * key_states

+ # Relevance of KV input
  wt_mat_KV = wt_mat_V + wt_mat_K
+
+ # Reshape wt_mat_Q and wt_mat_KV
+ wt_mat_Q = np.einsum('htd->thd', wt_mat_Q)
+ wt_mat_KV = np.einsum('htd->thd', wt_mat_KV)
+ wt_mat_Q = wt_mat_Q.reshape(wt_mat_Q.shape[0], wt_mat_Q.shape[1] * wt_mat_Q.shape[2])
+ wt_mat_KV = wt_mat_KV.reshape(wt_mat_KV.shape[0], wt_mat_KV.shape[1] * wt_mat_KV.shape[2])
+
  wt_mat = [wt_mat_KV, wt_mat_Q]
  return wt_mat
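Both attention variants now finish by folding the per-head relevance back to a token-major layout, going from `(num_heads, num_tokens, head_dim)` to `(num_tokens, num_heads * head_dim)`; the `array.reshape(8, 32 * 128)` in the inline comments earlier in this diff is just the author's example of those sizes. A quick standalone check of that head merge with toy dimensions:

    import numpy as np

    num_heads, num_tokens, head_dim = 2, 3, 4
    per_head = np.arange(num_heads * num_tokens * head_dim).reshape(num_heads, num_tokens, head_dim)
    merged = np.einsum('htd->thd', per_head).reshape(num_tokens, num_heads * head_dim)
    print(merged.shape)  # (3, 8)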
dl_backtrace/version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
  __version_tuple__: VERSION_TUPLE
  version_tuple: VERSION_TUPLE

- __version__ = version = '0.0.19'
- __version_tuple__ = version_tuple = (0, 0, 19)
+ __version__ = version = '0.0.20.dev36'
+ __version_tuple__ = version_tuple = (0, 0, 20, 'dev36')
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: dl_backtrace
- Version: 0.0.19
+ Version: 0.0.20.dev36
  Summary: A python SDK for Deep Learning Backtrace
  Home-page: https://xai.arya.ai/docs/introduction
  License: MIT
@@ -1,5 +1,5 @@
  dl_backtrace/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- dl_backtrace/version.py,sha256=CYabGzkNwriz1Zjt5kNvBOZD6wtqQ_twYh4s5xzmT-I,413
+ dl_backtrace/version.py,sha256=tYHVV4mIeOCumN5OCSD_xV6vt2LqOBp8qrLBhN4xnyw,428
  dl_backtrace/old_backtrace/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
  dl_backtrace/old_backtrace/pytorch_backtrace/__init__.py,sha256=TDhKQIj1INyQq7cqTvpjpnBDhoMeWSoqVdx_mPAV3Sw,24
  dl_backtrace/old_backtrace/pytorch_backtrace/backtrace/__init__.py,sha256=AuR7uMTbf7rrFl-9sMIQ3lQmhu_ATSFBZmfs7R66jAc,76
@@ -17,18 +17,18 @@ dl_backtrace/old_backtrace/tf_backtrace/backtrace/utils/contrast.py,sha256=p-0zj
  dl_backtrace/old_backtrace/tf_backtrace/backtrace/utils/prop.py,sha256=-h7nHEvsoEwfksHsT52VfnZy334DvqqP8g6fMFVNnAM,25670
  dl_backtrace/pytorch_backtrace/__init__.py,sha256=TDhKQIj1INyQq7cqTvpjpnBDhoMeWSoqVdx_mPAV3Sw,24
  dl_backtrace/pytorch_backtrace/backtrace/__init__.py,sha256=AuR7uMTbf7rrFl-9sMIQ3lQmhu_ATSFBZmfs7R66jAc,76
- dl_backtrace/pytorch_backtrace/backtrace/backtrace.py,sha256=eeq1tj934j91bsXqZ3FbggFbBqo51bBxXWW6VuHVaMk,44153
+ dl_backtrace/pytorch_backtrace/backtrace/backtrace.py,sha256=bdb8eFzTE7ms34ne-VBJxXY8AMHhu-vTenUNwCMKbRo,44410
  dl_backtrace/pytorch_backtrace/backtrace/config.py,sha256=ODrgOC74ojzGLnEto_ah-WPA8_MyRE7cZkZlU15239s,983
  dl_backtrace/pytorch_backtrace/backtrace/utils/__init__.py,sha256=KffAJVu7NsgfMHEZaY7lND2LQwZamVIquqx8POwOaLg,120
  dl_backtrace/pytorch_backtrace/backtrace/utils/contrast.py,sha256=gcvDRZstcgAkjzyyxy6dagSV2EDFGqh9_SMn9P1-2lo,52795
  dl_backtrace/pytorch_backtrace/backtrace/utils/encoder.py,sha256=1dUKEL_LAuFcUEldYdyQVsV_P-KH3gTWEOgROfXPwyc,7469
  dl_backtrace/pytorch_backtrace/backtrace/utils/encoder_decoder.py,sha256=fxbXwUl7KQKwWSdKMO-bYvDC7NDct71F_U7uZzh1orw,25658
  dl_backtrace/pytorch_backtrace/backtrace/utils/helper.py,sha256=IG0XkPEfbPpWBm4aVHgv-GkgFFrl2wu5xMrz6DeH_xQ,3512
- dl_backtrace/pytorch_backtrace/backtrace/utils/prop.py,sha256=40fJRTsI5cvb-dbezrklSEWWghu9QobCe8TxaJQWtqU,58982
+ dl_backtrace/pytorch_backtrace/backtrace/utils/prop.py,sha256=oEwWmRsLwoYdCXJaEEW7-UL1CLVOsseUAcOasTkyQeE,67294
  dl_backtrace/tf_backtrace/__init__.py,sha256=TDhKQIj1INyQq7cqTvpjpnBDhoMeWSoqVdx_mPAV3Sw,24
  dl_backtrace/tf_backtrace/backtrace/__init__.py,sha256=KkU7X_wXxwYR4HYQqQ3kWtxlJK3Ytaa84e1Jbzc2_ZA,84
  dl_backtrace/tf_backtrace/backtrace/activation_info.py,sha256=3Ppw4_6rJV16YPXnKjd2WPaULgUUL6U01bh8Foa-3Yg,1334
- dl_backtrace/tf_backtrace/backtrace/backtrace.py,sha256=eQa7wz3MsfjkgZE72voKGlZGjxbf42kHkXcwB7Ve3qI,41474
+ dl_backtrace/tf_backtrace/backtrace/backtrace.py,sha256=k6irwiOu68EpRG_W9wk5UF7QS50TKM5JqUQfj4qmvWw,41635
  dl_backtrace/tf_backtrace/backtrace/models.py,sha256=wPaeRuEvZL2xTdj6I6hF-ZCKC2c8EAuF9PoOIZkxkR4,466
  dl_backtrace/tf_backtrace/backtrace/server.py,sha256=jphibvI46QpQcqnpXVIYFq_M2CtRVONgItZe9iWOd54,567
  dl_backtrace/tf_backtrace/backtrace/utils/__init__.py,sha256=ci_RAYYnqyAWa_rcIEycnqCghQ4aZtvaGQ7oDUb_k_0,131
@@ -36,9 +36,9 @@ dl_backtrace/tf_backtrace/backtrace/utils/encoder.py,sha256=WeGLjIRHNqjwIK-8UB0x
  dl_backtrace/tf_backtrace/backtrace/utils/encoder_decoder.py,sha256=qbS34WswNiT1xgON5ayNAIewJMRfDdeGel6l1XrjXms,24247
  dl_backtrace/tf_backtrace/backtrace/utils/helper.py,sha256=QB21kPB5iJfRpy8khYJnzojaKf5ACnAFYh5XxYBcnXA,3419
  dl_backtrace/tf_backtrace/backtrace/utils/utils_contrast.py,sha256=rQwManW0d6Td6V_A1qGezcZ19Tgr34zbFcieQ0rqAAc,48415
- dl_backtrace/tf_backtrace/backtrace/utils/utils_prop.py,sha256=dC6QxbwpA5JhE8009vnZVQyY8eIE3RFx6t7I-N17-k0,58320
- dl_backtrace-0.0.19.dist-info/LICENSE,sha256=RTqAU0MFv1q3ZXKewNobKxIIPzRHgImom7e6ORV7X6o,1064
- dl_backtrace-0.0.19.dist-info/METADATA,sha256=tSJfzmAVuRLRLUjdXyJ-nBTis8L6yY-TKtV6YVBULao,7837
- dl_backtrace-0.0.19.dist-info/WHEEL,sha256=cVxcB9AmuTcXqmwrtPhNK88dr7IR_b6qagTj0UvIEbY,91
- dl_backtrace-0.0.19.dist-info/top_level.txt,sha256=gvGVYScJfW6c4aO5WMo4Aqa6NLEfmLK7VWXVx_GeiIk,13
- dl_backtrace-0.0.19.dist-info/RECORD,,
+ dl_backtrace/tf_backtrace/backtrace/utils/utils_prop.py,sha256=kduJpbN2FIrrhgrUxSFkX3gnTfXLuQa4qx-LH0dDB7A,68291
+ dl_backtrace-0.0.20.dev36.dist-info/LICENSE,sha256=RTqAU0MFv1q3ZXKewNobKxIIPzRHgImom7e6ORV7X6o,1064
+ dl_backtrace-0.0.20.dev36.dist-info/METADATA,sha256=H5E00w4t5F6O67roa2Y6tiGEW_SziKjsgMssLEoDM_U,7843
+ dl_backtrace-0.0.20.dev36.dist-info/WHEEL,sha256=R06PA3UVYHThwHvxuRWMqaGcr-PuniXahwjmQRFMEkY,91
+ dl_backtrace-0.0.20.dev36.dist-info/top_level.txt,sha256=gvGVYScJfW6c4aO5WMo4Aqa6NLEfmLK7VWXVx_GeiIk,13
+ dl_backtrace-0.0.20.dev36.dist-info/RECORD,,
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (74.1.2)
+ Generator: setuptools (75.5.0)
  Root-Is-Purelib: true
  Tag: py3-none-any
