dl-backtrace 0.0.12__py3-none-any.whl → 0.0.16__py3-none-any.whl

This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as published.

Potentially problematic release: this version of dl-backtrace might be problematic.

Files changed (27)
  1. dl_backtrace/pytorch_backtrace/backtrace/backtrace.py +173 -44
  2. dl_backtrace/pytorch_backtrace/backtrace/utils/__init__.py +3 -0
  3. dl_backtrace/pytorch_backtrace/backtrace/utils/encoder.py +183 -0
  4. dl_backtrace/pytorch_backtrace/backtrace/utils/encoder_decoder.py +489 -0
  5. dl_backtrace/pytorch_backtrace/backtrace/utils/helper.py +95 -0
  6. dl_backtrace/pytorch_backtrace/backtrace/utils/prop.py +481 -0
  7. dl_backtrace/tf_backtrace/backtrace/__init__.py +1 -2
  8. dl_backtrace/tf_backtrace/backtrace/activation_info.py +33 -0
  9. dl_backtrace/tf_backtrace/backtrace/backtrace.py +506 -279
  10. dl_backtrace/tf_backtrace/backtrace/models.py +25 -0
  11. dl_backtrace/tf_backtrace/backtrace/server.py +27 -0
  12. dl_backtrace/tf_backtrace/backtrace/utils/__init__.py +5 -2
  13. dl_backtrace/tf_backtrace/backtrace/utils/encoder.py +206 -0
  14. dl_backtrace/tf_backtrace/backtrace/utils/encoder_decoder.py +501 -0
  15. dl_backtrace/tf_backtrace/backtrace/utils/helper.py +99 -0
  16. dl_backtrace/tf_backtrace/backtrace/utils/utils_contrast.py +1132 -0
  17. dl_backtrace/tf_backtrace/backtrace/utils/utils_prop.py +1582 -0
  18. dl_backtrace/version.py +2 -2
  19. {dl_backtrace-0.0.12.dist-info → dl_backtrace-0.0.16.dist-info}/METADATA +3 -2
  20. dl_backtrace-0.0.16.dist-info/RECORD +29 -0
  21. {dl_backtrace-0.0.12.dist-info → dl_backtrace-0.0.16.dist-info}/WHEEL +1 -1
  22. dl_backtrace/tf_backtrace/backtrace/config.py +0 -41
  23. dl_backtrace/tf_backtrace/backtrace/utils/contrast.py +0 -834
  24. dl_backtrace/tf_backtrace/backtrace/utils/prop.py +0 -725
  25. dl_backtrace-0.0.12.dist-info/RECORD +0 -21
  26. {dl_backtrace-0.0.12.dist-info → dl_backtrace-0.0.16.dist-info}/LICENSE +0 -0
  27. {dl_backtrace-0.0.12.dist-info → dl_backtrace-0.0.16.dist-info}/top_level.txt +0 -0
dl_backtrace/pytorch_backtrace/backtrace/utils/encoder_decoder.py (new file)
@@ -0,0 +1,489 @@
+ import torch
+ from collections import defaultdict
+
+
+ def build_enc_dec_tree(model, root='enc-dec'):
+     # Initialize the tree structure
+     ltree = {}
+     layer_tree = {}
+     inputs = []
+     outputs = []
+     intermediates = []
+     layer_stack = []
+
+     # Base component setup
+     def add_component(tree, name, component, child=None):
+         tree[name] = {
+             'name': name,
+             'class': component if type(component).__name__ == 'str' else type(component).__name__,
+             'type': str(type(component)),
+             'parent': None,
+             'child': None
+         }
+
+         if isinstance(child, list):
+             tree[name]['child'] = child
+         elif isinstance(child, str):
+             tree[name]['child'] = [child]
+
+         if tree[name]['class'] == 'list':
+             tree[name]['class'] = [type(item).__name__ for item in component]
+             tree[name]['type'] = [str(type(item)) for item in component]
+
+         # Keep track of component type in a separate dictionary
+         layer_tree[name] = component if type(component).__name__ == 'str' else tree[name]['type']
+
+         # keep track of layer stack
+         layer_stack.append(name)
+
+         # Link the parent to its children
+         if isinstance(child, list):
+             for ch in child:
+                 if ch in tree:
+                     tree[ch]['parent'] = [name]
+
+         elif isinstance(child, str):
+             if child in tree:
+                 tree[child]['parent'] = [name]
+
+         return tree[name]
+
+     # Add root and embeddings component
+     encoder_embeddings = add_component(ltree, 'encoder_embedding', model.encoder.embed_tokens, child=None)
+
+     # Add encoder layers dynamically
+     current_child = 'encoder_embedding'
+     for i, layer in enumerate(model.encoder.block):
+         encoder_layer_norm_0 = add_component(ltree, f'encoder_layer_norm_{i}_0', 'Layer_Norm', child=current_child)
+         encoder_self_attention = add_component(ltree, f'encoder_self_attention_{i}', 'Self_Attention', child=f'encoder_layer_norm_{i}_0')
+         encoder_residual_self_attention = add_component(ltree, f'encoder_residual_self_attention_{i}', 'Residual', child=[current_child, f'encoder_self_attention_{i}'])
+
+         encoder_layer_norm_1 = add_component(ltree, f'encoder_layer_norm_{i}_1', 'Layer_Norm', child=f'encoder_residual_self_attention_{i}')
+         encoder_feed_forward = add_component(ltree, f'encoder_feed_forward_{i}', 'Feed_Forward', child=f'encoder_layer_norm_{i}_1')
+         encoder_residual_feed_forward = add_component(ltree, f'encoder_residual_feed_forward_{i}', 'Residual', child=[f'encoder_residual_self_attention_{i}', f'encoder_feed_forward_{i}'])
+
+         current_child = f'encoder_residual_feed_forward_{i}'
+
+     if hasattr(model.encoder, 'final_layer_norm'):
+         encoder_final_layer_norm = add_component(ltree, 'encoder_layer_norm', model.encoder.final_layer_norm, child=current_child)
+         current_child = 'encoder_layer_norm'
+
+     # Add Decoder layers
+     decoder_embeddings = add_component(ltree, 'decoder_embedding', model.decoder.embed_tokens, child=None)
+
+     # Add decoder layers dynamically
+     current_child = 'decoder_embedding'
+     for i, layer in enumerate(model.decoder.block):
+         decoder_layer_norm_0 = add_component(ltree, f'decoder_layer_norm_{i}_0', 'Layer_Norm', child=current_child)
+         decoder_self_attention = add_component(ltree, f'decoder_self_attention_{i}', 'Self_Attention', child=f'decoder_layer_norm_{i}_0')
+         decoder_residual_self_attention = add_component(ltree, f'decoder_residual_self_attention_{i}', 'Residual', child=[current_child, f'decoder_self_attention_{i}'])
+
+         decoder_layer_norm_1 = add_component(ltree, f'decoder_layer_norm_{i}_1', 'Layer_Norm', child=f'decoder_residual_self_attention_{i}')
+         decoder_cross_attention = add_component(ltree, f'decoder_cross_attention_{i}', 'Cross_Attention', child=['encoder_layer_norm', f'decoder_layer_norm_{i}_1'])
+         decoder_residual_cross_attention = add_component(ltree, f'decoder_residual_cross_attention_{i}', 'Residual', child=[f'decoder_residual_self_attention_{i}', f'decoder_cross_attention_{i}'])
+
+         decoder_layer_norm_2 = add_component(ltree, f'decoder_layer_norm_{i}_2', 'Layer_Norm', child=f'decoder_residual_cross_attention_{i}')
+         decoder_feed_forward = add_component(ltree, f'decoder_feed_forward_{i}', 'Feed_Forward', child=f'decoder_layer_norm_{i}_2')
+         decoder_residual_feed_forward = add_component(ltree, f'decoder_residual_feed_forward_{i}', 'Residual', child=[f'decoder_residual_cross_attention_{i}', f'decoder_feed_forward_{i}'])
+
+         current_child = f'decoder_residual_feed_forward_{i}'
+
+     if hasattr(model.decoder, 'final_layer_norm'):
+         decoder_final_layer_norm = add_component(ltree, 'decoder_layer_norm', model.decoder.final_layer_norm, child=current_child)
+         current_child = 'decoder_layer_norm'
+
+     # Decoder LM-Head
+     if hasattr(model, 'lm_head'):
+         decoder_lm_head = add_component(ltree, 'decoder_lm_head', 'LM_Head', child=current_child)
+         current_child = 'decoder_lm_head'
+
+     # Classify components
+     for name, component in ltree.items():
+         if component['parent'] is None:
+             outputs.append(component['name'])
+         elif component['child'] is None:
+             inputs.append(component['name'])
+         else:
+             intermediates.append(component['name'])
+
+     # reverse the layer_stack
+     layer_stack = list(reversed(layer_stack))
+     model_resource = (layer_tree, ltree, outputs, inputs)
+
+     return model_resource, layer_stack
+
+
+ def extract_encoder_decoder_weights(model):
+     # Initialize a dictionary to hold the weights
+     weights_dict = {
+         # 'shared_embeddings': {},
+         'encoder_embedding': {},
+         'encoder_layer_norm': {},
+         'decoder_embedding': {},
+         'decoder_layer_norm': {},
+         'decoder_lm_head': {}
+     }
+
+     # Extract the model's parameters and organize them into the dictionary
+     for name, param in model.named_parameters():
+         if 'shared' in name:
+             weights_dict['encoder_embedding'][name] = param.data.cpu().numpy()
+             weights_dict['decoder_embedding'][name] = param.data.cpu().numpy()
+             weights_dict['decoder_lm_head'][name] = param.data.cpu().numpy()
+
+         elif 'encoder.block' in name:
+             layer = name.split('.')[2]
+             sub_layer = name.split('.')[4]
+             submodule = name.split('.')[5]
+
+             if 'SelfAttention' in submodule and f'encoder_self_attention_{layer}' not in weights_dict:
+                 weights_dict[f'encoder_self_attention_{layer}'] = {}
+             if 'layer_norm' in submodule and f'encoder_layer_norm_{layer}' not in weights_dict:
+                 weights_dict[f'encoder_layer_norm_{layer}'] = {}
+             if 'DenseReluDense' in submodule and f'encoder_feed_forward_{layer}' not in weights_dict:
+                 weights_dict[f'encoder_feed_forward_{layer}'] = {}
+
+             if 'SelfAttention' in submodule:
+                 weights_dict[f'encoder_self_attention_{layer}'][name] = param.data.cpu().numpy()
+             elif 'layer_norm' in submodule:
+                 weights_dict[f'encoder_layer_norm_{layer}'][name] = param.data.cpu().numpy()
+             elif 'DenseReluDense' in submodule:
+                 weights_dict[f'encoder_feed_forward_{layer}'][name] = param.data.cpu().numpy()
+
+         elif 'encoder.final_layer_norm.weight' in name:
+             weights_dict['encoder_layer_norm'][name] = param.data.cpu().numpy()
+
+         elif 'decoder.block' in name:
+             layer = name.split('.')[2]
+             sub_layer = name.split('.')[4]
+             submodule = name.split('.')[5]
+
+             if 'SelfAttention' in submodule and f'decoder_self_attention_{layer}' not in weights_dict:
+                 weights_dict[f'decoder_self_attention_{layer}'] = {}
+             if 'layer_norm' in submodule and f'decoder_layer_norm_{layer}' not in weights_dict:
+                 weights_dict[f'decoder_layer_norm_{layer}'] = {}
+             if 'EncDecAttention' in submodule and f'decoder_cross_attention_{layer}' not in weights_dict:
+                 weights_dict[f'decoder_cross_attention_{layer}'] = {}
+             if 'DenseReluDense' in submodule and f'decoder_feed_forward_{layer}' not in weights_dict:
+                 weights_dict[f'decoder_feed_forward_{layer}'] = {}
+
+             if 'SelfAttention' in submodule:
+                 weights_dict[f'decoder_self_attention_{layer}'][name] = param.data.cpu().numpy()
+             elif 'layer_norm' in submodule:
+                 weights_dict[f'decoder_layer_norm_{layer}'][name] = param.data.cpu().numpy()
+             elif 'EncDecAttention' in submodule:
+                 weights_dict[f'decoder_cross_attention_{layer}'][name] = param.data.cpu().numpy()
+             elif 'DenseReluDense' in submodule:
+                 weights_dict[f'decoder_feed_forward_{layer}'][name] = param.data.cpu().numpy()
+
+         elif 'decoder.final_layer_norm.weight' in name:
+             weights_dict['decoder_layer_norm'][name] = param.data.cpu().numpy()
+
+     return weights_dict
+
+
+ def calculate_encoder_decoder_output(input_text, model, tokenizer):
+     encoder_inputs = {}
+     encoder_outputs = {}
+     encoder_hooks = []
+     decoder_outputs = defaultdict(lambda: defaultdict(dict))
+     decoder_inputs = defaultdict(lambda: defaultdict(dict))
+     decoder_hooks = []
+
+     def capture_encoder_embeddings(model, tokenizer, input_text):
+         # Ensure the model is in evaluation mode
+         model.eval()
+
+         # Tokenize the input text
+         encoding = tokenizer(input_text, return_tensors='pt')
+         input_ids = encoding["input_ids"]
+         attention_mask = encoding["attention_mask"]
+
+         # Manually capture the embedding output
+         with torch.no_grad():
+             embedding_output = model.encoder.embed_tokens(input_ids)
+
+         # Return the captured embeddings
+         return embedding_output
+
+     ## ------------ Hook function for Self-Attention Block --------------------
+     def hook_fn__encoder_normalized_hidden_states(module, input, output):
+         encoder_inputs[f'encoder_layer_norm_{module.layer_index}_0'] = input
+         encoder_outputs[f'encoder_layer_norm_{module.layer_index}_0'] = output
+         encoder_outputs[f'input_to_layer_norm_{module.layer_index}'] = input[0]
+
+     def hook_fn_encoder_self_attention_outputs(module, input, output):
+         encoder_inputs[f'encoder_self_attention_{module.layer_index}'] = input
+         encoder_outputs[f'encoder_self_attention_{module.layer_index}'] = output[0]
+
+     def hook_fn_encoder_dropout_attention_output(module, input, output):
+         encoder_outputs[f'encoder_dropout_attention_{module.layer_index}'] = output[0]
+
+     ## ------------ Hook function for Feed-Forward Block --------------
+     def hook_fn_encoder_normalized_forwarded_states(module, input, output):
+         encoder_inputs[f'encoder_layer_norm_{module.layer_index}_1'] = input
+         encoder_outputs[f'encoder_layer_norm_{module.layer_index}_1'] = output
+         encoder_outputs[f'input_to_ff_layer_norm_{module.layer_index}'] = input[0]
+
+     def hook_fn_encoder_forwarded_states(module, input, output):
+         encoder_inputs[f'encoder_feed_forward_{module.layer_index}'] = input
+         encoder_outputs[f'encoder_feed_forward_{module.layer_index}'] = output
+
+     def hook_fn_encoder_dropout_forwarded_states(module, input, output):
+         encoder_outputs[f'dropout_forwarded_states_{module.layer_index}'] = output
+
+     # Custom hooks to calculate residuals
+     def hook_fn_encoder_residual_self_attention(layer_index):
+         def hook(module, input, output):
+             input_to_layer_norm = encoder_outputs[f'input_to_layer_norm_{layer_index}']
+             encoder_outputs[f'encoder_residual_self_attention_{layer_index}'] = input_to_layer_norm + output
+         return hook
+
+     def hook_fn_encoder_residual_feed_forward(layer_index):
+         def hook(module, input, output):
+             input_to_ff_layer_norm = encoder_outputs[f'input_to_ff_layer_norm_{layer_index}']
+             encoder_outputs[f'encoder_residual_feed_forward_{layer_index}'] = input_to_ff_layer_norm + output
+         return hook
+
+     # Hook for Final Layer normalization and dropout for Encoder
+     def hook_fn_normalized_encoder_output(module, input, output):
+         encoder_outputs['encoder_layer_norm'] = output
+
+     def hook_fn_dropout_normalized_encoder_output(module, input, output):
+         encoder_outputs['dropout_normalized_encoder_output'] = output
+
+     # Register hooks to the encoder submodules
+     for i, layer in enumerate(model.encoder.block):
+         # Set layer_index attribute to all relevant submodules
+         layer.layer[0].layer_norm.layer_index = i
+         layer.layer[0].SelfAttention.layer_index = i
+         layer.layer[0].dropout.layer_index = i
+         layer.layer[1].layer_norm.layer_index = i
+         layer.layer[1].DenseReluDense.layer_index = i
+         layer.layer[1].dropout.layer_index = i
+
+         encoder_hooks.append(layer.layer[0].layer_norm.register_forward_hook(hook_fn__encoder_normalized_hidden_states))
+         encoder_hooks.append(layer.layer[0].SelfAttention.register_forward_hook(hook_fn_encoder_self_attention_outputs))
+         encoder_hooks.append(layer.layer[0].dropout.register_forward_hook(hook_fn_encoder_dropout_attention_output))
+         encoder_hooks.append(layer.layer[0].dropout.register_forward_hook(hook_fn_encoder_residual_self_attention(i)))  # Custom hook for residual self-attention
+
+         encoder_hooks.append(layer.layer[1].layer_norm.register_forward_hook(hook_fn_encoder_normalized_forwarded_states))
+         encoder_hooks.append(layer.layer[1].DenseReluDense.register_forward_hook(hook_fn_encoder_forwarded_states))
+         encoder_hooks.append(layer.layer[1].dropout.register_forward_hook(hook_fn_encoder_dropout_forwarded_states))
+         encoder_hooks.append(layer.layer[1].dropout.register_forward_hook(hook_fn_encoder_residual_feed_forward(i)))  # Custom hook for residual feed-forward
+
+     # Register hook for Final Layer Normalization and dropout for Encoder
+     encoder_hooks.append(model.encoder.final_layer_norm.register_forward_hook(hook_fn_normalized_encoder_output))
+     encoder_hooks.append(model.encoder.dropout.register_forward_hook(hook_fn_dropout_normalized_encoder_output))
+
+     ############################ Hook for Decoder ################################
+     # Token index shared with the decoder hooks via closure; advanced once per generated token
+     token_idx = 0
+
+     # Function to generate timestamp (token index)
+     def get_timestamp():
+         return str(token_idx)
+
+     # Hook functions to capture input embedding
+     def hook_fn_decoder_embedding(module, input, output):
+         nonlocal token_idx
+         timestamp = get_timestamp()
+         decoder_outputs[timestamp]['decoder_embedding'] = output.detach().clone()
+
+     ## ------------ Hook function for Self-Attention Block --------------------
+     def hook_fn_decoder_normalized_hidden_states(module, input, output, layer_index):
+         nonlocal token_idx
+         timestamp = get_timestamp()
+         decoder_inputs[timestamp][f'decoder_layer_norm_{layer_index}_0'] = input
+         decoder_outputs[timestamp][f'decoder_layer_norm_{layer_index}_0'] = output
+         decoder_outputs[timestamp][f'input_to_layer_norm_{layer_index}'] = input[0]
+
+     def hook_fn_decoder_self_attention_outputs(module, input, output, layer_index):
+         nonlocal token_idx
+         timestamp = get_timestamp()
+         decoder_inputs[timestamp][f'decoder_self_attention_{layer_index}'] = input
+         decoder_outputs[timestamp][f'decoder_self_attention_{layer_index}'] = output[0]
+
+     def hook_fn_decoder_dropout_attention_output(module, input, output, layer_index):
+         nonlocal token_idx
+         timestamp = get_timestamp()
+         decoder_outputs[timestamp][f'dropout_attention_output_{layer_index}'] = output
+
+     ## ------------ Hook function for Cross-Attention Block --------------------
+     def hook_fn_decoder_normalized_cross_attn_hidden_states(module, input, output, layer_index):
+         nonlocal token_idx
+         timestamp = get_timestamp()
+         decoder_inputs[timestamp][f'decoder_layer_norm_{layer_index}_1'] = input
+         decoder_outputs[timestamp][f'decoder_layer_norm_{layer_index}_1'] = output
+         decoder_outputs[timestamp][f'input_to_cross_attn_layer_norm_{layer_index}'] = input[0]
+
+     def hook_fn_decoder_cross_attention_outputs(module, input, output, layer_index):
+         nonlocal token_idx
+         timestamp = get_timestamp()
+         key_value_states = encoder_outputs['dropout_normalized_encoder_output']
+         query_state = input[0]
+
+         inputs = {
+             'query': input[0],
+             'key': key_value_states,
+             'value': key_value_states
+         }
+
+         decoder_inputs[timestamp][f'decoder_cross_attention_{layer_index}'] = inputs
+         decoder_outputs[timestamp][f'decoder_cross_attention_{layer_index}'] = output[0]
+
+     def hook_fn_decoder_dropout_cross_attn_output(module, input, output, layer_index):
+         nonlocal token_idx
+         timestamp = get_timestamp()
+         decoder_outputs[timestamp][f'dropout_cross_attn_output_{layer_index}'] = output
+
+     ## ------------ Hook function for Feed-Forward Block --------------------
+     def hook_fn_decoder_normalized_forwarded_states(module, input, output, layer_index):
+         nonlocal token_idx
+         timestamp = get_timestamp()
+         decoder_inputs[timestamp][f'decoder_layer_norm_{layer_index}_2'] = input
+         decoder_outputs[timestamp][f'decoder_layer_norm_{layer_index}_2'] = output
+         decoder_outputs[timestamp][f'input_to_ff_layer_norm_{layer_index}'] = input[0]
+
+     def hook_fn_decoder_forwarded_states(module, input, output, layer_index):
+         nonlocal token_idx
+         timestamp = get_timestamp()
+         decoder_inputs[timestamp][f'decoder_feed_forward_{layer_index}'] = input[0]
+         decoder_outputs[timestamp][f'decoder_feed_forward_{layer_index}'] = output
+
+     def hook_fn_decoder_dropout_forwarded_states(module, input, output, layer_index):
+         nonlocal token_idx
+         timestamp = get_timestamp()
+         decoder_outputs[timestamp][f'dropout_forwarded_states_{layer_index}'] = output
+
+     # Custom hooks to calculate residuals
+     def hook_fn_decoder_residual_self_attention(layer_index):
+         def hook(module, input, output):
+             nonlocal token_idx
+             timestamp = get_timestamp()
+             input_to_layer_norm = decoder_outputs[timestamp][f'input_to_layer_norm_{layer_index}']
+             decoder_outputs[timestamp][f'decoder_residual_self_attention_{layer_index}'] = input_to_layer_norm + output
+         return hook
+
+     def hook_fn_decoder_residual_cross_attention(layer_index):
+         def hook(module, input, output):
+             nonlocal token_idx
+             timestamp = get_timestamp()
+             input_to_layer_norm = decoder_outputs[timestamp][f'input_to_cross_attn_layer_norm_{layer_index}']
+             decoder_outputs[timestamp][f'decoder_residual_cross_attention_{layer_index}'] = input_to_layer_norm + output
+         return hook
+
+     def hook_fn_decoder_residual_feed_forward(layer_index):
+         def hook(module, input, output):
+             nonlocal token_idx
+             timestamp = get_timestamp()
+             input_to_ff_layer_norm = decoder_outputs[timestamp][f'input_to_ff_layer_norm_{layer_index}']
+             decoder_outputs[timestamp][f'decoder_residual_feed_forward_{layer_index}'] = input_to_ff_layer_norm + output
+         return hook
+
+     # Hook for Final Layer normalization and dropout for Decoder
+     def hook_fn_normalized_decoder_output(module, input, output):
+         nonlocal token_idx
+         timestamp = get_timestamp()
+         decoder_outputs[timestamp]['decoder_layer_norm'] = output
+
+     def hook_fn_dropout_normalized_decoder_output(module, input, output):
+         nonlocal token_idx
+         timestamp = get_timestamp()
+         decoder_outputs[timestamp]['dropout_normalized_decoder_output'] = output
+
+     # Hook for the Decoder LM-Head
+     def hook_fn_lm_head(module, input, output):
+         nonlocal token_idx
+         timestamp = get_timestamp()
+         decoder_outputs[timestamp]['decoder_lm_head'] = output
+
+
+     # Register hook for embedding
+     decoder_hooks.append(model.decoder.embed_tokens.register_forward_hook(lambda module, input, output: hook_fn_decoder_embedding(module, input, output)))
+
+     # Register hooks to the decoder submodules
+     for i, layer in enumerate(model.decoder.block):
+         # Set layer_index attribute to all relevant submodules
+         layer.layer[0].layer_norm.layer_index = i
+         layer.layer[0].SelfAttention.layer_index = i
+         layer.layer[0].dropout.layer_index = i
+         layer.layer[1].layer_norm.layer_index = i
+         layer.layer[1].EncDecAttention.layer_index = i
+         layer.layer[1].dropout.layer_index = i
+         layer.layer[2].layer_norm.layer_index = i
+         layer.layer[2].DenseReluDense.layer_index = i
+         layer.layer[2].dropout.layer_index = i
+
+         decoder_hooks.append(layer.layer[0].layer_norm.register_forward_hook(lambda module, input, output, i=i: hook_fn_decoder_normalized_hidden_states(module, input, output, layer_index=i)))
+         decoder_hooks.append(layer.layer[0].SelfAttention.register_forward_hook(lambda module, input, output, i=i: hook_fn_decoder_self_attention_outputs(module, input, output, layer_index=i)))
+         decoder_hooks.append(layer.layer[0].dropout.register_forward_hook(lambda module, input, output, i=i: hook_fn_decoder_dropout_attention_output(module, input, output, layer_index=i)))
+         decoder_hooks.append(layer.layer[0].dropout.register_forward_hook(lambda module, input, output, i=i: hook_fn_decoder_residual_self_attention(i)(module, input, output)))
+
+         decoder_hooks.append(layer.layer[1].layer_norm.register_forward_hook(lambda module, input, output, i=i: hook_fn_decoder_normalized_cross_attn_hidden_states(module, input, output, layer_index=i)))
+         decoder_hooks.append(layer.layer[1].EncDecAttention.register_forward_hook(lambda module, input, output, i=i: hook_fn_decoder_cross_attention_outputs(module, input, output, layer_index=i)))
+         decoder_hooks.append(layer.layer[1].dropout.register_forward_hook(lambda module, input, output, i=i: hook_fn_decoder_dropout_cross_attn_output(module, input, output, layer_index=i)))
+         decoder_hooks.append(layer.layer[1].dropout.register_forward_hook(lambda module, input, output, i=i: hook_fn_decoder_residual_cross_attention(i)(module, input, output)))
+
+         decoder_hooks.append(layer.layer[2].layer_norm.register_forward_hook(lambda module, input, output, i=i: hook_fn_decoder_normalized_forwarded_states(module, input, output, layer_index=i)))
+         decoder_hooks.append(layer.layer[2].DenseReluDense.register_forward_hook(lambda module, input, output, i=i: hook_fn_decoder_forwarded_states(module, input, output, layer_index=i)))
+         decoder_hooks.append(layer.layer[2].dropout.register_forward_hook(lambda module, input, output, i=i: hook_fn_decoder_dropout_forwarded_states(module, input, output, layer_index=i)))
+         decoder_hooks.append(layer.layer[2].dropout.register_forward_hook(lambda module, input, output, i=i: hook_fn_decoder_residual_feed_forward(i)(module, input, output)))
+
+     # Register hook for Final Layer Normalization and dropout for Decoder
+     decoder_hooks.append(model.decoder.final_layer_norm.register_forward_hook(lambda module, input, output: hook_fn_normalized_decoder_output(module, input, output)))
+     decoder_hooks.append(model.decoder.dropout.register_forward_hook(lambda module, input, output: hook_fn_dropout_normalized_decoder_output(module, input, output)))
+
+     # Register hook for the Decoder LM-Head
+     decoder_hooks.append(model.lm_head.register_forward_hook(lambda module, input, output: hook_fn_lm_head(module, input, output)))
+
+
+     # Function to increment token_idx
+     def increment_token_idx():
+         nonlocal token_idx  # rebind the counter defined in the enclosing function (no module-level global exists)
+         token_idx += 1
+
+     encoding = tokenizer(input_text, return_tensors='pt')
+     input_ids = encoding["input_ids"]
+     attention_mask = encoding["attention_mask"]
+
+     embedding_output = capture_encoder_embeddings(model, tokenizer, input_text)
+     encoder_outputs['encoder_embedding'] = embedding_output
+
+     # Initialize decoder_input_ids with the start token
+     decoder_input_ids = torch.full(
+         (input_ids.shape[0], 1), model.config.decoder_start_token_id, dtype=torch.long
+     )
+
+     # Reset token_idx before generating
+     token_idx = 0
+     max_length = getattr(model.config, 'max_position_embeddings', 512)  # maximum generation length; fall back to 512 for configs (e.g. T5) that do not define it
+     generated_tokens = []
+
+     for _ in range(max_length):
+         outputs = model(input_ids=input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids)
+         next_token_logits = outputs.logits[:, -1, :]
+         next_token_id = next_token_logits.argmax(dim=-1, keepdim=True)
+         generated_tokens.append(next_token_id.item())
+         decoder_input_ids = torch.cat([decoder_input_ids, next_token_id], dim=-1)
+         increment_token_idx()
+
+         if next_token_id.item() == model.config.eos_token_id:
+             break
+
+     # Deregister hooks
+     for handle in encoder_hooks:
+         handle.remove()
+
+     # Deregister hooks
+     for handle in decoder_hooks:
+         handle.remove()
+
+     # Merge the encoder_outputs with timestep decoder_outputs to generate timestep wise outputs of the model
+     outputs = {}
+
+     for i in range(len(decoder_outputs)):
+         outputs[f'{i}'] = {**encoder_outputs, **decoder_outputs[f'{i}']}
+
+     return outputs, generated_tokens
+
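For orientation, here is a minimal sketch of how these new encoder-decoder utilities might be driven end to end. It assumes a Hugging Face T5-style checkpoint (t5-small is used purely as an illustrative choice), since the module paths the hooks traverse (model.encoder.block, DenseReluDense, EncDecAttention, lm_head) match that architecture, and it assumes the modules are importable from the installed wheel under the paths listed above; the package itself does not document this exact entry point, so treat it as a sketch rather than the official API.

from transformers import AutoTokenizer, T5ForConditionalGeneration

from dl_backtrace.pytorch_backtrace.backtrace.utils.encoder_decoder import (
    build_enc_dec_tree,
    calculate_encoder_decoder_output,
    extract_encoder_decoder_weights,
)

# Assumption: any T5-style seq2seq checkpoint with the block/DenseReluDense/EncDecAttention layout.
model = T5ForConditionalGeneration.from_pretrained("t5-small")
tokenizer = AutoTokenizer.from_pretrained("t5-small")

# Static view of the network: layer graph plus per-component weights.
model_resource, layer_stack = build_enc_dec_tree(model)
layer_tree, ltree, tree_outputs, tree_inputs = model_resource
weights = extract_encoder_decoder_weights(model)

# Dynamic view: forward hooks record per-token activations while greedy decoding.
outputs, generated_tokens = calculate_encoder_decoder_output(
    "translate English to German: The house is wonderful.", model, tokenizer
)

print(tokenizer.decode(generated_tokens, skip_special_tokens=True))
print(sorted(outputs["0"].keys())[:5])  # activations recorded for the first generated token

The returned outputs dictionary is keyed by generated-token index, and each entry merges the encoder activations (computed once per forward pass) with that step's decoder activations, which is what the relevance-propagation code downstream consumes.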
dl_backtrace/pytorch_backtrace/backtrace/utils/helper.py (new file)
@@ -0,0 +1,95 @@
+ import torch
+ from collections import defaultdict
+
+
+ # Function to rename the dictionary keys
+ def rename_self_attention_keys(attention_weights):
+     renamed_weights = {}
+     for key, value in attention_weights.items():
+         if 'query.weight' in key or 'SelfAttention.q.weight' in key:
+             new_key = key.replace(key, 'W_q')
+         elif 'query.bias' in key or 'SelfAttention.q.bias' in key:
+             new_key = key.replace(key, 'b_q')
+         elif 'key.weight' in key or 'SelfAttention.k.weight' in key:
+             new_key = key.replace(key, 'W_k')
+         elif 'key.bias' in key or 'SelfAttention.k.bias' in key:
+             new_key = key.replace(key, 'b_k')
+         elif 'value.weight' in key or 'SelfAttention.v.weight' in key:
+             new_key = key.replace(key, 'W_v')
+         elif 'value.bias' in key or 'SelfAttention.v.bias' in key:
+             new_key = key.replace(key, 'b_v')
+         elif 'output.dense.weight' in key or 'SelfAttention.o.weight' in key:
+             new_key = key.replace(key, 'W_d')
+         elif 'output.dense.bias' in key or 'SelfAttention.o.bias' in key:
+             new_key = key.replace(key, 'b_d')
+         else:
+             new_key = key  # keep keys that match none of the patterns (e.g. relative_attention_bias) unchanged
+
+         renamed_weights[new_key] = value
+     return renamed_weights
+
+
+ def rename_cross_attention_keys(cross_attention_weights):
+     renamed_weights = {}
+
+     for key, value in cross_attention_weights.items():
+         if 'EncDecAttention.q.weight' in key:
+             new_key = key.replace(key, 'W_q')
+         elif 'EncDecAttention.k.weight' in key:
+             new_key = key.replace(key, 'W_k')
+         elif 'EncDecAttention.v.weight' in key:
+             new_key = key.replace(key, 'W_v')
+         elif 'EncDecAttention.o.weight' in key:
+             new_key = key.replace(key, 'W_o')
+         else:
+             new_key = key  # unmatched keys pass through unchanged
+
+         renamed_weights[new_key] = value
+     return renamed_weights
+
+
+ def rename_feed_forward_keys(feed_forward_weights):
+     renamed_weights = {}
+
+     for key, value in feed_forward_weights.items():
+         if 'intermediate.dense.weight' in key or 'DenseReluDense.wi.weight' in key:
+             new_key = key.replace(key, 'W_int')
+         elif 'intermediate.dense.bias' in key or 'DenseReluDense.wi.bias' in key:
+             new_key = key.replace(key, 'b_int')
+         elif 'output.dense.weight' in key or 'DenseReluDense.wo.weight' in key:
+             new_key = key.replace(key, 'W_out')
+         elif 'output.dense.bias' in key or 'DenseReluDense.wo.bias' in key:
+             new_key = key.replace(key, 'b_out')
+         else:
+             new_key = key  # unmatched keys pass through unchanged
+
+         renamed_weights[new_key] = value
+     return renamed_weights
+
+
+ def rename_pooler_keys(pooler_weights):
+     renamed_weights = {}
+     for key, value in pooler_weights.items():
+         if 'pooler.dense.weight' in key:
+             new_key = key.replace(key, 'W_p')
+         elif 'pooler.dense.bias' in key:
+             new_key = key.replace(key, 'b_p')
+         else:
+             new_key = key  # unmatched keys pass through unchanged
+
+         renamed_weights[new_key] = value
+     return renamed_weights
+
+
+ def rename_classifier_keys(classifier_weights):
+     renamed_weights = {}
+     for key, value in classifier_weights.items():
+         if 'classifier.weight' in key:
+             new_key = key.replace(key, 'W_cls')
+         elif 'classifier.bias' in key:
+             new_key = key.replace(key, 'b_cls')
+         else:
+             new_key = key  # unmatched keys pass through unchanged
+
+         renamed_weights[new_key] = value
+     return renamed_weights
+
+
+ def rename_decoder_lm_head(lm_head_weights):
+     renamed_weights = {}
+
+     for key, value in lm_head_weights.items():
+         if 'shared.weight' in key:
+             new_key = key.replace(key, 'W_lm_head')
+         else:
+             new_key = key  # unmatched keys pass through unchanged
+
+         renamed_weights[new_key] = value
+     return renamed_weights
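To make the renaming convention concrete, here is a small hypothetical example of rename_self_attention_keys applied to one encoder block's self-attention weights. The raw keys follow T5's parameter naming, which is what the substring checks above target; the placeholder string values stand in for the NumPy arrays that extract_encoder_decoder_weights would actually produce.

from dl_backtrace.pytorch_backtrace.backtrace.utils.helper import rename_self_attention_keys

# Hypothetical keys for layer 0 of a T5 encoder; values are placeholders for weight arrays.
raw = {
    "encoder.block.0.layer.0.SelfAttention.q.weight": "q-array",
    "encoder.block.0.layer.0.SelfAttention.k.weight": "k-array",
    "encoder.block.0.layer.0.SelfAttention.v.weight": "v-array",
    "encoder.block.0.layer.0.SelfAttention.o.weight": "o-array",
}

print(rename_self_attention_keys(raw))
# {'W_q': 'q-array', 'W_k': 'k-array', 'W_v': 'v-array', 'W_d': 'o-array'}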