dl-backtrace 0.0.18__py3-none-any.whl → 0.0.20.dev36__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of dl-backtrace might be problematic.

@@ -192,15 +192,15 @@ class Backtrace(object):
  return temp_out

  def eval(self, all_out, start_wt=[], mode="default",multiplier=100.0,
- scaler=None, max_unit=0,thresholding=0.5,
+ scaler=0, max_unit=0,thresholding=0.5,
  task="binary-classification",predicted_token=None):

  if mode=="default":
  output = self.proportional_eval(all_out=all_out,
- start_wt=start_wt ,
- multiplier=multiplier,
- scaler=scaler,
- max_unit=max_unit,
+ start_wt=start_wt ,
+ multiplier=multiplier,
+ scaler=scaler,
+ max_unit=max_unit,
  thresholding=thresholding,
  task=task,
  predicted_token=predicted_token)
@@ -219,7 +219,7 @@ class Backtrace(object):
  return output

  def proportional_eval(self, all_out, start_wt=[] ,
- multiplier=100.0, scaler=None, max_unit=0,
+ multiplier=100.0, scaler=0, max_unit=0,
  predicted_token=None, thresholding=0.5,
  task="binary-classification"):
  model_resource = self.model_resource
@@ -229,7 +229,7 @@ class Backtrace(object):
  all_wt = {}
  if len(start_wt) == 0:
  if self.model_type == 'encoder':
- start_wt = UP.calculate_start_wt(all_out[out_layer])
+ start_wt = UP.calculate_start_wt(all_out[out_layer], scaler=scaler)
  all_wt[out_layer] = start_wt * multiplier
  layer_stack = self.layer_stack
  all_wts = self.model_weights
@@ -442,10 +442,12 @@ class Backtrace(object):
  elif model_resource["graph"][start_layer]["class"] == "Self_Attention":
  weights = all_wts[start_layer]
  self_attention_weights = HP.rename_self_attention_keys(weights)
+ config = self.model.config
  temp_wt = UP.calculate_wt_self_attention(
  all_wt[start_layer],
  all_out[child_nodes[0]][0],
  self_attention_weights,
+ config
  )
  all_wt[child_nodes[0]] += temp_wt
  elif model_resource["graph"][start_layer]["class"] == 'Residual':
@@ -502,10 +504,12 @@ class Backtrace(object):
  elif model_resource["graph"][start_layer]["class"] == 'Cross_Attention':
  weights = all_wts[start_layer]
  cross_attention_weights = HP.rename_cross_attention_keys(weights)
+ config = self.model.config
  temp_wt = UP.calculate_wt_cross_attention(
  all_wt[start_layer],
  [all_out[ch][0] for ch in child_nodes],
  cross_attention_weights,
+ config
  )
  for ind, ch in enumerate(child_nodes):
  all_wt[ch] += temp_wt[ind]
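The hunks above change the public entry points in two ways: `eval` and `proportional_eval` now default `scaler` to 0 instead of None (and forward it to `UP.calculate_start_wt`), and the Self_Attention / Cross_Attention branches pass the model's config into the attention relevance helpers. A minimal usage sketch based only on the signatures visible in these hunks; `bt` (an already-constructed Backtrace instance) and `all_out` (a mapping of layer name to recorded output) are assumed names, not taken from the diff:

    # Hypothetical call, assuming `bt` is a Backtrace object and `all_out` holds
    # the recorded per-layer outputs expected by eval().
    relevance = bt.eval(
        all_out,
        mode="default",               # routes to proportional_eval
        multiplier=100.0,
        scaler=0,                     # new default in 0.0.20.dev36 (was None in 0.0.18)
        max_unit=0,
        thresholding=0.5,
        task="binary-classification",
    )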
@@ -1161,14 +1161,17 @@ def calculate_wt_residual(wts, inp=None):
  return wt_mat


- def calculate_relevance_V(wts, value_output):
- # Initialize wt_mat with zeros
- wt_mat_V = np.zeros((wts.shape[0], wts.shape[1], *value_output.shape))
+ def calculate_relevance_V(wts, value_output, w):
+ wt_mat_V = np.zeros(value_output.shape)
+
+ if 'b_v' in w:
+ bias_v = w['b_v']
+ else:
+ bias_v = 0

  for i in range(wts.shape[0]):
  for j in range(wts.shape[1]):
  l1_ind1 = value_output
- wt_ind1 = wt_mat_V[i, j]
  wt = wts[i, j]

  p_ind = l1_ind1 > 0
@@ -1176,12 +1179,21 @@ def calculate_relevance_V(wts, value_output):
  p_sum = np.sum(l1_ind1[p_ind])
  n_sum = np.sum(l1_ind1[n_ind]) * -1

+ if bias_v[i] > 0:
+ pbias = bias_v[i]
+ nbias = 0
+ else:
+ pbias = 0
+ nbias = bias_v[i] * -1
+
  if p_sum > 0:
- p_agg_wt = p_sum / (p_sum + n_sum)
+ p_agg_wt = (p_sum + pbias) / (p_sum + n_sum + pbias + nbias)
+ p_agg_wt = p_agg_wt * (p_sum / (p_sum + pbias))
  else:
  p_agg_wt = 0
  if n_sum > 0:
- n_agg_wt = n_sum / (p_sum + n_sum)
+ n_agg_wt = (n_sum + nbias) / (p_sum + n_sum + pbias + nbias)
+ n_agg_wt = n_agg_wt * (n_sum / (n_sum + nbias))
  else:
  n_agg_wt = 0

@@ -1190,21 +1202,22 @@ def calculate_relevance_V(wts, value_output):
  if n_sum == 0:
  n_sum = 1

- wt_ind1[p_ind] = (l1_ind1[p_ind] / p_sum) * wt * p_agg_wt
- wt_ind1[n_ind] = (l1_ind1[n_ind] / n_sum) * wt * n_agg_wt * -1.0
+ wt_mat_V[p_ind] += (l1_ind1[p_ind] / p_sum) * wt * p_agg_wt
+ wt_mat_V[n_ind] += (l1_ind1[n_ind] / n_sum) * wt * n_agg_wt * -1.0

- wt_mat_V = np.sum(wt_mat_V, axis=(0,1))
- return wt_mat_V
+ return wt_mat_V


- def calculate_relevance_QK(wts, QK_output):
- # Initialize wt_mat with zeros
- wt_mat_QK = np.zeros((wts.shape[0], wts.shape[1], *QK_output.shape))
+ def calculate_relevance_QK(wts, QK_output, w):
+ wt_mat_QK = np.zeros(QK_output.shape)
+
+ # Check if 'b_q' and 'b_k' exist in the weights, default to 0 if not
+ b_q = w['b_q'] if 'b_q' in w else 0
+ b_k = w['b_k'] if 'b_k' in w else 0

  for i in range(wts.shape[0]):
  for j in range(wts.shape[1]):
  l1_ind1 = QK_output
- wt_ind1 = wt_mat_QK[i, j]
  wt = wts[i, j]

  p_ind = l1_ind1 > 0
@@ -1212,7 +1225,21 @@ def calculate_relevance_QK(wts, QK_output):
  p_sum = np.sum(l1_ind1[p_ind])
  n_sum = np.sum(l1_ind1[n_ind]) * -1

- t_sum = p_sum - n_sum
+ if b_q[i] > 0 and b_k[i] > 0:
+ pbias = b_q[i] + b_k[i]
+ nbias = 0
+ elif b_q[i] > 0 and b_k[i] < 0:
+ pbias = b_q[i]
+ nbias = b_k[i] * -1
+ elif b_q[i] < 0 and b_k[i] > 0:
+ pbias = b_k[i]
+ nbias = b_q[i] * -1
+ else:
+ pbias = 0
+ nbias = b_q[i] + b_k[i]
+ nbias *= -1
+
+ t_sum = p_sum + pbias - n_sum - nbias

  # This layer has a softmax activation function
  act = {
@@ -1231,12 +1258,13 @@ def calculate_relevance_QK(wts, QK_output):
  n_sum = 0

  if p_sum > 0:
- p_agg_wt = p_sum / (p_sum + n_sum)
+ p_agg_wt = (p_sum + pbias) / (p_sum + n_sum + pbias + nbias)
+ p_agg_wt = p_agg_wt * (p_sum / (p_sum + pbias))
  else:
  p_agg_wt = 0
-
  if n_sum > 0:
- n_agg_wt = n_sum / (p_sum + n_sum)
+ n_agg_wt = (n_sum + nbias) / (p_sum + n_sum + pbias + nbias)
+ n_agg_wt = n_agg_wt * (n_sum / (n_sum + nbias))
  else:
  n_agg_wt = 0

@@ -1245,14 +1273,60 @@ def calculate_relevance_QK(wts, QK_output):
  if n_sum == 0:
  n_sum = 1

- wt_ind1[p_ind] = (l1_ind1[p_ind] / p_sum) * wt * p_agg_wt
- wt_ind1[n_ind] = (l1_ind1[n_ind] / n_sum) * wt * n_agg_wt * -1.0
+ wt_mat_QK[p_ind] += (l1_ind1[p_ind] / p_sum) * wt * p_agg_wt
+ wt_mat_QK[n_ind] += (l1_ind1[n_ind] / n_sum) * wt * n_agg_wt * -1.0

- wt_mat_QK = np.sum(wt_mat_QK, axis=(0, 1))
  return wt_mat_QK


- def calculate_wt_self_attention(wts, inp, w):
+ def calculate_wt_attention_output_projection(wts, proj_output, w):
+ wt_mat_proj_output = np.zeros(proj_output.shape)
+
+ if 'b_d' in w:
+ bias_d = w['b_d']
+ else:
+ bias_d = 0
+
+ for i in range(wts.shape[0]):
+ for j in range(wts.shape[1]):
+ l1_ind1 = proj_output
+ wt = wts[i, j]
+
+ p_ind = l1_ind1 > 0
+ n_ind = l1_ind1 < 0
+ p_sum = np.sum(l1_ind1[p_ind])
+ n_sum = np.sum(l1_ind1[n_ind]) * -1
+
+ if bias_d[i] > 0:
+ pbias = bias_d[i]
+ nbias = 0
+ else:
+ pbias = 0
+ nbias = bias_d[i] * -1
+
+ if p_sum > 0:
+ p_agg_wt = (p_sum + pbias) / (p_sum + n_sum + pbias + nbias)
+ p_agg_wt = p_agg_wt * (p_sum / (p_sum + pbias))
+ else:
+ p_agg_wt = 0
+ if n_sum > 0:
+ n_agg_wt = (n_sum + nbias) / (p_sum + n_sum + pbias + nbias)
+ n_agg_wt = n_agg_wt * (n_sum / (n_sum + nbias))
+ else:
+ n_agg_wt = 0
+
+ if p_sum == 0:
+ p_sum = 1
+ if n_sum == 0:
+ n_sum = 1
+
+ wt_mat_proj_output[p_ind] += (l1_ind1[p_ind] / p_sum) * wt * p_agg_wt
+ wt_mat_proj_output[n_ind] += (l1_ind1[n_ind] / n_sum) * wt * n_agg_wt * -1.0
+
+ return wt_mat_proj_output
+
+
+ def calculate_wt_self_attention(wts, inp, w, config):
  '''
  Input:
  wts: relevance score of the layer
@@ -1267,25 +1341,76 @@ def calculate_wt_self_attention(wts, inp, w):
  query_output = np.einsum('ij,kj->ik', inp, w['W_q'].T)
  key_output = np.einsum('ij,kj->ik', inp, w['W_k'].T)
  value_output = np.einsum('ij,kj->ik', inp, w['W_v'].T)
+
+ # --------------- Reshape for Multi-Head Attention ----------------------
+ num_heads = getattr(config, 'num_attention_heads', getattr(config, 'num_heads', None)) # will work for BERT as well as T5/ Llama
+ hidden_size = getattr(config, 'hidden_size', getattr(config, 'd_model', None)) # will work for BERT as well as T5/Llama
+ if hasattr(config, 'num_key_value_heads'):
+ num_key_value_heads = config.num_key_value_heads
+ else:
+ num_key_value_heads = num_heads
+ head_dim = hidden_size // num_heads # dimension of each attention head
+
+ query_states = np.einsum('thd->htd', query_output.reshape(query_output.shape[0], num_heads, head_dim)) # (num_heads, num_tokens, head_dim)
+ key_states = np.einsum('thd->htd', key_output.reshape(key_output.shape[0], num_key_value_heads, head_dim)) # (num_key_value_heads, num_tokens, head_dim)
+ value_states = np.einsum('thd->htd', value_output.reshape(value_output.shape[0], num_key_value_heads, head_dim)) # (num_key_value_heads, num_tokens, head_dim)
+
+ # calculate how many times we need to repeat the key/value heads
+ n_rep = num_heads // num_key_value_heads
+ key_states = np.repeat(key_states, n_rep, axis=0)
+ value_states = np.repeat(value_states, n_rep, axis=0)
+
+ QK_output = np.einsum('hqd,hkd->hqk', query_states, key_states) # (num_heads, num_tokens, num_tokens)
+ attn_weights = QK_output / np.sqrt(head_dim)
+
+ # Apply softmax along the last dimension (softmax over key dimension)
+ attn_weights = np.exp(attn_weights - np.max(attn_weights, axis=-1, keepdims=True)) # Numerically stable softmax
+ attn_weights = attn_weights / np.sum(attn_weights, axis=-1, keepdims=True)
+
+ # Weighted sum of values (num_heads, num_tokens, head_dim)
+ attn_output = np.einsum('hqk,hkl->hql', attn_weights, value_states)
+
+ transposed_attn_output = np.einsum('hqd->qhd', attn_output)
+ reshaped_attn_output = transposed_attn_output.reshape(transposed_attn_output.shape[0], num_heads * head_dim)
+
+ # Perform final linear projection (num_tokens, hidden_size)
+ final_output = np.einsum('qd,dh->qh', reshaped_attn_output, w['W_d'])
+
+ # ------------- Relevance calculation for Final Linear Projection -------------
+ wt_mat_attn_proj = calculate_wt_attention_output_projection(wts, final_output, w)

  # --------------- Relevance Calculation for Step-3 -----------------------
- relevance_V = wts / 2
- relevance_QK = wts / 2
+ # divide the relevance among `attn_weights` and `value_states`
+ wt_mat_attn_proj = wt_mat_attn_proj.reshape(-1, num_heads, head_dim)
+ wt_mat_attn_proj = np.einsum('qhd->hqd', wt_mat_attn_proj)

- # --------------- Relevance Calculation for V --------------------------------
- wt_mat_V = calculate_relevance_V(relevance_V, value_output)
+ stabilized_attn_output = stabilize(attn_output * 2)
+ norm_wt_mat_attn_proj = wt_mat_attn_proj / stabilized_attn_output
+ relevance_QK = np.einsum('htd,hbd->htb', norm_wt_mat_attn_proj, value_states) * attn_weights
+ relevance_V = np.einsum('hdt,hdb->htb', attn_weights, norm_wt_mat_attn_proj) * value_states

+ # --------------- Relevance Calculation for V --------------------------------
+ relevance_V = np.einsum('hqd->qhd', relevance_V)
+ relevance_V = relevance_V.reshape(-1, num_heads * head_dim)
+ wt_mat_V = calculate_relevance_V(relevance_V, value_states, w)
+
  # --------------- Transformed Relevance QK ----------------------------------
- QK_output = np.einsum('ij,kj->ik', query_output, key_output)
- wt_mat_QK = calculate_relevance_QK(relevance_QK, QK_output)
+ relevance_QK = np.einsum('hqd->qhd', relevance_QK)
+ relevance_QK = relevance_QK.reshape(-1, relevance_QK.shape[1] * relevance_QK.shape[2])
+ wt_mat_QK = calculate_relevance_QK(relevance_QK, QK_output, w)

  # --------------- Relevance Calculation for K and Q --------------------------------
  stabilized_QK_output = stabilize(QK_output * 2)
  norm_wt_mat_QK = wt_mat_QK / stabilized_QK_output
- wt_mat_Q = np.einsum('ij,jk->ik', norm_wt_mat_QK, key_output) * query_output
- wt_mat_K = np.einsum('ij,ik->kj', query_output, norm_wt_mat_QK) * key_output
+ wt_mat_Q = np.einsum('htd,hdb->htb', norm_wt_mat_QK, key_states) * query_states
+ wt_mat_K = np.einsum('htd,htb->hbd', query_states, norm_wt_mat_QK) * key_states

  wt_mat = wt_mat_V + wt_mat_K + wt_mat_Q
+
+ # Reshape wt_mat
+ wt_mat = np.einsum('htd->thd', wt_mat)
+ wt_mat = wt_mat.reshape(wt_mat.shape[0], wt_mat.shape[1] * wt_mat.shape[2]) # reshaped_array = array.reshape(8, 32 * 128)
+
  return wt_mat


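Before propagating relevance, the new `calculate_wt_self_attention` re-runs the attention block in NumPy: it reads the head layout from the model config, reshapes the Q/K/V projections to (heads, tokens, head_dim), repeats grouped key/value heads, and applies a numerically stable softmax. A standalone sketch of that reshape-and-score step, using made-up shapes (the config fallbacks mirror the diff; nothing here is taken from the package beyond the pattern shown above):

    import numpy as np

    # Illustrative sizes only; in the diff these come from config attributes
    # (num_attention_heads/num_heads, hidden_size/d_model, num_key_value_heads).
    num_tokens, num_heads, num_key_value_heads, head_dim = 8, 12, 4, 64
    hidden_size = num_heads * head_dim

    q = np.random.randn(num_tokens, hidden_size)                      # query projection
    kv = np.random.randn(num_tokens, num_key_value_heads * head_dim)  # key or value projection

    # (tokens, hidden) -> (heads, tokens, head_dim)
    q_states = np.einsum('thd->htd', q.reshape(num_tokens, num_heads, head_dim))
    kv_states = np.einsum('thd->htd', kv.reshape(num_tokens, num_key_value_heads, head_dim))

    # Grouped-query attention: repeat K/V heads to match the query heads.
    n_rep = num_heads // num_key_value_heads
    kv_states = np.repeat(kv_states, n_rep, axis=0)                   # (num_heads, tokens, head_dim)

    # Scaled dot-product scores and a numerically stable softmax over keys.
    scores = np.einsum('hqd,hkd->hqk', q_states, kv_states) / np.sqrt(head_dim)
    scores = np.exp(scores - scores.max(axis=-1, keepdims=True))
    attn_weights = scores / scores.sum(axis=-1, keepdims=True)        # rows sum to 1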
@@ -1301,6 +1426,8 @@ def calculate_wt_feed_forward(wts, inp, w):
  R2 = wts[i]
  contribution_matrix2 = np.einsum('ij,j->ij', w['W_out'].T, intermediate_output[i])
  wt_mat2 = np.zeros(contribution_matrix2.shape)
+
+ bias_out = w['b_out'] if 'b_out' in w else 0

  for j in range(contribution_matrix2.shape[0]):
  l1_ind1 = contribution_matrix2[j]
@@ -1312,13 +1439,22 @@ def calculate_wt_feed_forward(wts, inp, w):
  p_sum = np.sum(l1_ind1[p_ind])
  n_sum = np.sum(l1_ind1[n_ind]) * -1

+ # Handle positive and negative bias contributions
+ if bias_out[i] > 0:
+ pbias = bias_out[i]
+ nbias = 0
+ else:
+ pbias = 0
+ nbias = -bias_out[i]
+
  if p_sum > 0:
- p_agg_wt = p_sum / (p_sum + n_sum)
+ p_agg_wt = (p_sum + pbias) / (p_sum + n_sum + pbias + nbias)
+ p_agg_wt = p_agg_wt * (p_sum / (p_sum + pbias))
  else:
  p_agg_wt = 0
-
  if n_sum > 0:
- n_agg_wt = n_sum / (p_sum + n_sum)
+ n_agg_wt = (n_sum + nbias) / (p_sum + n_sum + pbias + nbias)
+ n_agg_wt = n_agg_wt * (n_sum / (n_sum + nbias))
  else:
  n_agg_wt = 0

@@ -1337,6 +1473,9 @@ def calculate_wt_feed_forward(wts, inp, w):
  R1 = relevance_out[i]
  contribution_matrix1 = np.einsum('ij,j->ij', w['W_int'].T, inp[i])
  wt_mat1 = np.zeros(contribution_matrix1.shape)
+
+ # Check if bias 'b_int' exists, default to 0 if not
+ bias_int = w['b_int'] if 'b_int' in w else 0

  for j in range(contribution_matrix1.shape[0]):
  l1_ind1 = contribution_matrix1[j]
@@ -1348,7 +1487,15 @@ def calculate_wt_feed_forward(wts, inp, w):
  p_sum = np.sum(l1_ind1[p_ind])
  n_sum = np.sum(l1_ind1[n_ind]) * -1

- t_sum = p_sum - n_sum
+ # Handle positive and negative bias
+ if bias_int[i] > 0:
+ pbias = bias_int[i]
+ nbias = 0
+ else:
+ pbias = 0
+ nbias = -bias_int[i]
+
+ t_sum = p_sum + pbias - n_sum - nbias

  # This layer has a ReLU activation function
  act = {
@@ -1367,12 +1514,13 @@ def calculate_wt_feed_forward(wts, inp, w):
  n_sum = 0

  if p_sum > 0:
- p_agg_wt = p_sum / (p_sum + n_sum)
+ p_agg_wt = (p_sum + pbias) / (p_sum + n_sum + pbias + nbias)
+ p_agg_wt = p_agg_wt * (p_sum / (p_sum + pbias))
  else:
  p_agg_wt = 0
-
  if n_sum > 0:
- n_agg_wt = n_sum / (p_sum + n_sum)
+ n_agg_wt = (n_sum + nbias) / (p_sum + n_sum + pbias + nbias)
+ n_agg_wt = n_agg_wt * (n_sum / (n_sum + nbias))
  else:
  n_agg_wt = 0

@@ -1461,8 +1609,8 @@ def calculate_wt_pooler(wts, inp, w):
  # Calculate relevance for each token
  relevance_inp[i] = wt_mat.sum(axis=0)

- relevance_inp *= (100 / np.sum(relevance_inp))
- return relevance_inp
+ relevance_inp *= (np.sum(wts) / np.sum(relevance_inp))
+ return relevance_inp


  def calculate_wt_classifier(wts, inp, w):
@@ -1595,7 +1743,7 @@ def calculate_wt_lm_head(wts, inp, w):
  return relevance_input


- def calculate_wt_cross_attention(wts, inp, w):
+ def calculate_wt_cross_attention(wts, inp, w, config):
  '''
  Input:
  wts: relevance score of the layer
@@ -1613,23 +1761,77 @@ def calculate_wt_cross_attention(wts, inp, w):
  key_output = np.einsum('ij,kj->ik', k_v_inp, w['W_k'].T)
  value_output = np.einsum('ij,kj->ik', k_v_inp, w['W_v'].T)

+ # --------------- Reshape for Multi-Head Attention ----------------------
+ num_heads = getattr(config, 'num_attention_heads', getattr(config, 'num_heads', None)) # will work for BERT as well as T5/ Llama
+ hidden_size = getattr(config, 'hidden_size', getattr(config, 'd_model', None)) # will work for BERT as well as T5/Llama
+ if hasattr(config, 'num_key_value_heads'):
+ num_key_value_heads = config.num_key_value_heads
+ else:
+ num_key_value_heads = num_heads
+ head_dim = hidden_size // num_heads # dimension of each attention head
+
+ query_states = np.einsum('thd->htd', query_output.reshape(query_output.shape[0], num_heads, head_dim)) # (num_heads, num_tokens, head_dim)
+ key_states = np.einsum('thd->htd', key_output.reshape(key_output.shape[0], num_key_value_heads, head_dim)) # (num_key_value_heads, num_tokens, head_dim)
+ value_states = np.einsum('thd->htd', value_output.reshape(value_output.shape[0], num_key_value_heads, head_dim)) # (num_key_value_heads, num_tokens, head_dim)
+
+ # calculate how many times we need to repeat the key/value heads
+ n_rep = num_heads // num_key_value_heads
+ key_states = np.repeat(key_states, n_rep, axis=0)
+ value_states = np.repeat(value_states, n_rep, axis=0)
+
+ QK_output = np.einsum('hqd,hkd->hqk', query_states, key_states) # (num_heads, num_tokens, num_tokens)
+ attn_weights = QK_output / np.sqrt(head_dim)
+
+ # Apply softmax along the last dimension (softmax over key dimension)
+ attn_weights = np.exp(attn_weights - np.max(attn_weights, axis=-1, keepdims=True)) # Numerically stable softmax
+ attn_weights = attn_weights / np.sum(attn_weights, axis=-1, keepdims=True)
+
+ # Weighted sum of values (num_heads, num_tokens, head_dim)
+ attn_output = np.einsum('hqk,hkl->hql', attn_weights, value_states)
+
+ transposed_attn_output = np.einsum('hqd->qhd', attn_output)
+ reshaped_attn_output = transposed_attn_output.reshape(transposed_attn_output.shape[0], num_heads * head_dim)
+
+ # Perform final linear projection (num_tokens, hidden_size)
+ final_output = np.einsum('qd,dh->qh', reshaped_attn_output, w['W_d'])
+
+ # ------------- Relevance calculation for Final Linear Projection -------------
+ wt_mat_attn_proj = calculate_wt_attention_output_projection(wts, final_output)
+
  # --------------- Relevance Calculation for Step-3 -----------------------
- relevance_V = wts / 2
- relevance_QK = wts / 2
+ # divide the relevance among `attn_weights` and `value_states`
+ wt_mat_attn_proj = wt_mat_attn_proj.reshape(-1, num_heads, head_dim)
+ wt_mat_attn_proj = np.einsum('qhd->hqd', wt_mat_attn_proj)

- # --------------- Relevance Calculation for V --------------------------------
- wt_mat_V = calculate_relevance_V(relevance_V, value_output)
+ stabilized_attn_output = stabilize(attn_output * 2)
+ norm_wt_mat_attn_proj = wt_mat_attn_proj / stabilized_attn_output
+ relevance_QK = np.einsum('htd,hbd->htb', norm_wt_mat_attn_proj, value_states) * attn_weights
+ relevance_V = np.einsum('hdt,hdb->htb', attn_weights, norm_wt_mat_attn_proj) * value_states

+ # --------------- Relevance Calculation for V --------------------------------
+ relevance_V = np.einsum('hqd->qhd', relevance_V)
+ relevance_V = relevance_V.reshape(-1, num_heads * head_dim)
+ wt_mat_V = calculate_relevance_V(relevance_V, value_states)
+
  # --------------- Transformed Relevance QK ----------------------------------
- QK_output = np.einsum('ij,kj->ik', query_output, key_output)
+ relevance_QK = np.einsum('hqd->qhd', relevance_QK)
+ relevance_QK = relevance_QK.reshape(-1, relevance_QK.shape[1] * relevance_QK.shape[2])
  wt_mat_QK = calculate_relevance_QK(relevance_QK, QK_output)

  # --------------- Relevance Calculation for K and Q --------------------------------
  stabilized_QK_output = stabilize(QK_output * 2)
  norm_wt_mat_QK = wt_mat_QK / stabilized_QK_output
- wt_mat_Q = np.einsum('ij,jk->ik', norm_wt_mat_QK, key_output) * query_output
- wt_mat_K = np.einsum('ij,ik->kj', query_output, norm_wt_mat_QK) * key_output
+ wt_mat_Q = np.einsum('htd,hdb->htb', norm_wt_mat_QK, key_states) * query_states
+ wt_mat_K = np.einsum('htd,htb->hbd', query_states, norm_wt_mat_QK) * key_states

+ # Relevance of KV input
  wt_mat_KV = wt_mat_V + wt_mat_K
+
+ # Reshape wt_mat_Q and wt_mat_KV
+ wt_mat_Q = np.einsum('htd->thd', wt_mat_Q)
+ wt_mat_KV = np.einsum('htd->thd', wt_mat_KV)
+ wt_mat_Q = wt_mat_Q.reshape(wt_mat_Q.shape[0], wt_mat_Q.shape[1] * wt_mat_Q.shape[2])
+ wt_mat_KV = wt_mat_KV.reshape(wt_mat_KV.shape[0], wt_mat_KV.shape[1] * wt_mat_KV.shape[2])
+
  wt_mat = [wt_mat_KV, wt_mat_Q]
  return wt_mat
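A pattern repeated throughout the new helpers above (`calculate_relevance_V`, `calculate_relevance_QK`, `calculate_wt_attention_output_projection`, and the feed-forward and pooler updates) is a bias-aware split of relevance between positive and negative contributions: the layer bias is folded into the aggregation weights as `pbias`/`nbias`, and the share attributable to the bias itself is then scaled back out. A condensed, hypothetical standalone function that distils that rule for a single relevance value over one activation vector (not part of the package; the epsilon guard replaces the diff's `p_sum = 1` / `n_sum = 1` fallback):

    import numpy as np

    def split_relevance(wt, acts, bias=0.0):
        """Distribute one relevance value `wt` over `acts`, bias-aware (sketch)."""
        out = np.zeros_like(acts, dtype=float)
        p_ind, n_ind = acts > 0, acts < 0
        p_sum = acts[p_ind].sum()
        n_sum = -acts[n_ind].sum()
        pbias, nbias = (bias, 0.0) if bias > 0 else (0.0, -bias)

        p_agg = n_agg = 0.0
        if p_sum > 0:
            p_agg = (p_sum + pbias) / (p_sum + n_sum + pbias + nbias)
            p_agg *= p_sum / (p_sum + pbias)   # remove the share claimed by the bias
        if n_sum > 0:
            n_agg = (n_sum + nbias) / (p_sum + n_sum + pbias + nbias)
            n_agg *= n_sum / (n_sum + nbias)

        p_sum, n_sum = max(p_sum, 1e-12), max(n_sum, 1e-12)  # guard against 0-division
        out[p_ind] += (acts[p_ind] / p_sum) * wt * p_agg
        out[n_ind] += (acts[n_ind] / n_sum) * wt * n_agg * -1.0
        return out

    # Example: 100 units of relevance over a 3-unit activation with a small positive bias.
    print(split_relevance(100.0, np.array([2.0, -1.0, 3.0]), bias=0.5))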
dl_backtrace/version.py CHANGED
@@ -12,5 +12,5 @@ __version__: str
  __version_tuple__: VERSION_TUPLE
  version_tuple: VERSION_TUPLE

- __version__ = version = '0.0.18'
- __version_tuple__ = version_tuple = (0, 0, 18)
+ __version__ = version = '0.0.20.dev36'
+ __version_tuple__ = version_tuple = (0, 0, 20, 'dev36')
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: dl_backtrace
- Version: 0.0.18
+ Version: 0.0.20.dev36
  Summary: A python SDK for Deep Learning Backtrace
  Home-page: https://xai.arya.ai/docs/introduction
  License: MIT
@@ -1,5 +1,5 @@
  dl_backtrace/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
- dl_backtrace/version.py,sha256=k900Q8XjzRKO6ZOHY0wFLzfzTGArI0sGircauDDJhu0,413
+ dl_backtrace/version.py,sha256=tYHVV4mIeOCumN5OCSD_xV6vt2LqOBp8qrLBhN4xnyw,428
  dl_backtrace/old_backtrace/__init__.py,sha256=AbpHGcgLb-kRsJGnwFEktk7uzpZOCcBY74-YBdrKVGs,1
  dl_backtrace/old_backtrace/pytorch_backtrace/__init__.py,sha256=TDhKQIj1INyQq7cqTvpjpnBDhoMeWSoqVdx_mPAV3Sw,24
  dl_backtrace/old_backtrace/pytorch_backtrace/backtrace/__init__.py,sha256=AuR7uMTbf7rrFl-9sMIQ3lQmhu_ATSFBZmfs7R66jAc,76
@@ -17,18 +17,18 @@ dl_backtrace/old_backtrace/tf_backtrace/backtrace/utils/contrast.py,sha256=p-0zj
  dl_backtrace/old_backtrace/tf_backtrace/backtrace/utils/prop.py,sha256=-h7nHEvsoEwfksHsT52VfnZy334DvqqP8g6fMFVNnAM,25670
  dl_backtrace/pytorch_backtrace/__init__.py,sha256=TDhKQIj1INyQq7cqTvpjpnBDhoMeWSoqVdx_mPAV3Sw,24
  dl_backtrace/pytorch_backtrace/backtrace/__init__.py,sha256=AuR7uMTbf7rrFl-9sMIQ3lQmhu_ATSFBZmfs7R66jAc,76
- dl_backtrace/pytorch_backtrace/backtrace/backtrace.py,sha256=wU5J7QkTnkQ0Jri8Xe6WctqeKFO4hTcL8MQpLrXbvdY,35211
+ dl_backtrace/pytorch_backtrace/backtrace/backtrace.py,sha256=bdb8eFzTE7ms34ne-VBJxXY8AMHhu-vTenUNwCMKbRo,44410
  dl_backtrace/pytorch_backtrace/backtrace/config.py,sha256=ODrgOC74ojzGLnEto_ah-WPA8_MyRE7cZkZlU15239s,983
  dl_backtrace/pytorch_backtrace/backtrace/utils/__init__.py,sha256=KffAJVu7NsgfMHEZaY7lND2LQwZamVIquqx8POwOaLg,120
- dl_backtrace/pytorch_backtrace/backtrace/utils/contrast.py,sha256=owsncnnz-j7UameRxx5uL9Q1AqOcA-uxhpyyfs0DqBw,32228
+ dl_backtrace/pytorch_backtrace/backtrace/utils/contrast.py,sha256=gcvDRZstcgAkjzyyxy6dagSV2EDFGqh9_SMn9P1-2lo,52795
  dl_backtrace/pytorch_backtrace/backtrace/utils/encoder.py,sha256=1dUKEL_LAuFcUEldYdyQVsV_P-KH3gTWEOgROfXPwyc,7469
  dl_backtrace/pytorch_backtrace/backtrace/utils/encoder_decoder.py,sha256=fxbXwUl7KQKwWSdKMO-bYvDC7NDct71F_U7uZzh1orw,25658
  dl_backtrace/pytorch_backtrace/backtrace/utils/helper.py,sha256=IG0XkPEfbPpWBm4aVHgv-GkgFFrl2wu5xMrz6DeH_xQ,3512
- dl_backtrace/pytorch_backtrace/backtrace/utils/prop.py,sha256=hHYFLGR3hsGtxlzdA4qXhgFFsBqq5AhCvxRUxx6KwZY,41284
+ dl_backtrace/pytorch_backtrace/backtrace/utils/prop.py,sha256=oEwWmRsLwoYdCXJaEEW7-UL1CLVOsseUAcOasTkyQeE,67294
  dl_backtrace/tf_backtrace/__init__.py,sha256=TDhKQIj1INyQq7cqTvpjpnBDhoMeWSoqVdx_mPAV3Sw,24
  dl_backtrace/tf_backtrace/backtrace/__init__.py,sha256=KkU7X_wXxwYR4HYQqQ3kWtxlJK3Ytaa84e1Jbzc2_ZA,84
  dl_backtrace/tf_backtrace/backtrace/activation_info.py,sha256=3Ppw4_6rJV16YPXnKjd2WPaULgUUL6U01bh8Foa-3Yg,1334
- dl_backtrace/tf_backtrace/backtrace/backtrace.py,sha256=eQa7wz3MsfjkgZE72voKGlZGjxbf42kHkXcwB7Ve3qI,41474
+ dl_backtrace/tf_backtrace/backtrace/backtrace.py,sha256=k6irwiOu68EpRG_W9wk5UF7QS50TKM5JqUQfj4qmvWw,41635
  dl_backtrace/tf_backtrace/backtrace/models.py,sha256=wPaeRuEvZL2xTdj6I6hF-ZCKC2c8EAuF9PoOIZkxkR4,466
  dl_backtrace/tf_backtrace/backtrace/server.py,sha256=jphibvI46QpQcqnpXVIYFq_M2CtRVONgItZe9iWOd54,567
  dl_backtrace/tf_backtrace/backtrace/utils/__init__.py,sha256=ci_RAYYnqyAWa_rcIEycnqCghQ4aZtvaGQ7oDUb_k_0,131
@@ -36,9 +36,9 @@ dl_backtrace/tf_backtrace/backtrace/utils/encoder.py,sha256=WeGLjIRHNqjwIK-8UB0x
  dl_backtrace/tf_backtrace/backtrace/utils/encoder_decoder.py,sha256=qbS34WswNiT1xgON5ayNAIewJMRfDdeGel6l1XrjXms,24247
  dl_backtrace/tf_backtrace/backtrace/utils/helper.py,sha256=QB21kPB5iJfRpy8khYJnzojaKf5ACnAFYh5XxYBcnXA,3419
  dl_backtrace/tf_backtrace/backtrace/utils/utils_contrast.py,sha256=rQwManW0d6Td6V_A1qGezcZ19Tgr34zbFcieQ0rqAAc,48415
- dl_backtrace/tf_backtrace/backtrace/utils/utils_prop.py,sha256=dC6QxbwpA5JhE8009vnZVQyY8eIE3RFx6t7I-N17-k0,58320
- dl_backtrace-0.0.18.dist-info/LICENSE,sha256=RTqAU0MFv1q3ZXKewNobKxIIPzRHgImom7e6ORV7X6o,1064
- dl_backtrace-0.0.18.dist-info/METADATA,sha256=YuIoncn6l2OscEaxrC5VA7OaFre9ddTh3pFGqZPJzsI,7837
- dl_backtrace-0.0.18.dist-info/WHEEL,sha256=Mdi9PDNwEZptOjTlUcAth7XJDFtKrHYaQMPulZeBCiQ,91
- dl_backtrace-0.0.18.dist-info/top_level.txt,sha256=gvGVYScJfW6c4aO5WMo4Aqa6NLEfmLK7VWXVx_GeiIk,13
- dl_backtrace-0.0.18.dist-info/RECORD,,
+ dl_backtrace/tf_backtrace/backtrace/utils/utils_prop.py,sha256=kduJpbN2FIrrhgrUxSFkX3gnTfXLuQa4qx-LH0dDB7A,68291
+ dl_backtrace-0.0.20.dev36.dist-info/LICENSE,sha256=RTqAU0MFv1q3ZXKewNobKxIIPzRHgImom7e6ORV7X6o,1064
+ dl_backtrace-0.0.20.dev36.dist-info/METADATA,sha256=H5E00w4t5F6O67roa2Y6tiGEW_SziKjsgMssLEoDM_U,7843
+ dl_backtrace-0.0.20.dev36.dist-info/WHEEL,sha256=R06PA3UVYHThwHvxuRWMqaGcr-PuniXahwjmQRFMEkY,91
+ dl_backtrace-0.0.20.dev36.dist-info/top_level.txt,sha256=gvGVYScJfW6c4aO5WMo4Aqa6NLEfmLK7VWXVx_GeiIk,13
+ dl_backtrace-0.0.20.dev36.dist-info/RECORD,,
@@ -1,5 +1,5 @@
  Wheel-Version: 1.0
- Generator: setuptools (73.0.1)
+ Generator: setuptools (75.5.0)
  Root-Is-Purelib: true
  Tag: py3-none-any