dl-backtrace 0.0.12__py3-none-any.whl → 0.0.16__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of dl-backtrace might be problematic.

Files changed (27)
  1. dl_backtrace/pytorch_backtrace/backtrace/backtrace.py +173 -44
  2. dl_backtrace/pytorch_backtrace/backtrace/utils/__init__.py +3 -0
  3. dl_backtrace/pytorch_backtrace/backtrace/utils/encoder.py +183 -0
  4. dl_backtrace/pytorch_backtrace/backtrace/utils/encoder_decoder.py +489 -0
  5. dl_backtrace/pytorch_backtrace/backtrace/utils/helper.py +95 -0
  6. dl_backtrace/pytorch_backtrace/backtrace/utils/prop.py +481 -0
  7. dl_backtrace/tf_backtrace/backtrace/__init__.py +1 -2
  8. dl_backtrace/tf_backtrace/backtrace/activation_info.py +33 -0
  9. dl_backtrace/tf_backtrace/backtrace/backtrace.py +506 -279
  10. dl_backtrace/tf_backtrace/backtrace/models.py +25 -0
  11. dl_backtrace/tf_backtrace/backtrace/server.py +27 -0
  12. dl_backtrace/tf_backtrace/backtrace/utils/__init__.py +5 -2
  13. dl_backtrace/tf_backtrace/backtrace/utils/encoder.py +206 -0
  14. dl_backtrace/tf_backtrace/backtrace/utils/encoder_decoder.py +501 -0
  15. dl_backtrace/tf_backtrace/backtrace/utils/helper.py +99 -0
  16. dl_backtrace/tf_backtrace/backtrace/utils/utils_contrast.py +1132 -0
  17. dl_backtrace/tf_backtrace/backtrace/utils/utils_prop.py +1582 -0
  18. dl_backtrace/version.py +2 -2
  19. {dl_backtrace-0.0.12.dist-info → dl_backtrace-0.0.16.dist-info}/METADATA +3 -2
  20. dl_backtrace-0.0.16.dist-info/RECORD +29 -0
  21. {dl_backtrace-0.0.12.dist-info → dl_backtrace-0.0.16.dist-info}/WHEEL +1 -1
  22. dl_backtrace/tf_backtrace/backtrace/config.py +0 -41
  23. dl_backtrace/tf_backtrace/backtrace/utils/contrast.py +0 -834
  24. dl_backtrace/tf_backtrace/backtrace/utils/prop.py +0 -725
  25. dl_backtrace-0.0.12.dist-info/RECORD +0 -21
  26. {dl_backtrace-0.0.12.dist-info → dl_backtrace-0.0.16.dist-info}/LICENSE +0 -0
  27. {dl_backtrace-0.0.12.dist-info → dl_backtrace-0.0.16.dist-info}/top_level.txt +0 -0
dl_backtrace/tf_backtrace/backtrace/utils/encoder_decoder.py
@@ -0,0 +1,501 @@
+import tensorflow as tf
+from collections import defaultdict
+
+
+def build_enc_dec_tree(model, root='enc-dec'):
+    # Initialize the tree structure
+    ltree = {}
+    layer_tree = {}
+    inputs = []
+    outputs = []
+    intermediates = []
+    layer_stack = []
+
+    # Base component setup
+    def add_component(tree, name, component, child=None):
+        tree[name] = {
+            'name': name,
+            'class': component if isinstance(component, str) else type(component).__name__,
+            'type': str(type(component)),
+            'parent': None,
+            'child': None
+        }
+
+        if isinstance(child, list):
+            tree[name]['child'] = child
+        elif isinstance(child, str):
+            tree[name]['child'] = [child]
+
+        if tree[name]['class'] == 'list':
+            tree[name]['class'] = [type(item).__name__ for item in component]
+            tree[name]['type'] = [str(type(item)) for item in component]
+
+        # Keep track of component type in a separate dictionary
+        layer_tree[name] = component if isinstance(component, str) else tree[name]['type']
+
+        # Keep track of the layer stack
+        layer_stack.append(name)
+
+        # Link the parent to its children
+        if isinstance(child, list):
+            for ch in child:
+                if ch in tree:
+                    tree[ch]['parent'] = [name]
+
+        elif isinstance(child, str):
+            if child in tree:
+                tree[child]['parent'] = [name]
+
+        return tree[name]
+
+    # Add root and embeddings component
+    encoder_embeddings = add_component(ltree, 'encoder_embedding', model.encoder.embed_tokens, child=None)
+
+    # Add encoder layers dynamically
+    current_child = 'encoder_embedding'
+    for i, layer in enumerate(model.encoder.block):
+        encoder_layer_norm_0 = add_component(ltree, f'encoder_layer_norm_{i}_0', 'Layer_Norm', child=current_child)
+        encoder_self_attention = add_component(ltree, f'encoder_self_attention_{i}', 'Self_Attention', child=f'encoder_layer_norm_{i}_0')
+        encoder_residual_self_attention = add_component(ltree, f'encoder_residual_self_attention_{i}', 'Residual', child=[current_child, f'encoder_self_attention_{i}'])
+
+        encoder_layer_norm_1 = add_component(ltree, f'encoder_layer_norm_{i}_1', 'Layer_Norm', child=f'encoder_residual_self_attention_{i}')
+        encoder_feed_forward = add_component(ltree, f'encoder_feed_forward_{i}', 'Feed_Forward', child=f'encoder_layer_norm_{i}_1')
+        encoder_residual_feed_forward = add_component(ltree, f'encoder_residual_feed_forward_{i}', 'Residual', child=[f'encoder_residual_self_attention_{i}', f'encoder_feed_forward_{i}'])
+
+        current_child = f'encoder_residual_feed_forward_{i}'
+
+    if hasattr(model.encoder, 'final_layer_norm'):
+        encoder_final_layer_norm = add_component(ltree, 'encoder_layer_norm', model.encoder.final_layer_norm, child=current_child)
+        current_child = 'encoder_layer_norm'
+
+    # Add decoder embeddings
+    decoder_embeddings = add_component(ltree, 'decoder_embedding', model.decoder.embed_tokens, child=None)
+
+    # Add decoder layers dynamically
+    current_child = 'decoder_embedding'
+    for i, layer in enumerate(model.decoder.block):
+        decoder_layer_norm_0 = add_component(ltree, f'decoder_layer_norm_{i}_0', 'Layer_Norm', child=current_child)
+        decoder_self_attention = add_component(ltree, f'decoder_self_attention_{i}', 'Self_Attention', child=f'decoder_layer_norm_{i}_0')
+        decoder_residual_self_attention = add_component(ltree, f'decoder_residual_self_attention_{i}', 'Residual', child=[current_child, f'decoder_self_attention_{i}'])
+
+        decoder_layer_norm_1 = add_component(ltree, f'decoder_layer_norm_{i}_1', 'Layer_Norm', child=f'decoder_residual_self_attention_{i}')
+        decoder_cross_attention = add_component(ltree, f'decoder_cross_attention_{i}', 'Cross_Attention', child=['encoder_layer_norm', f'decoder_layer_norm_{i}_1'])
+        decoder_residual_cross_attention = add_component(ltree, f'decoder_residual_cross_attention_{i}', 'Residual', child=[f'decoder_residual_self_attention_{i}', f'decoder_cross_attention_{i}'])
+
+        decoder_layer_norm_2 = add_component(ltree, f'decoder_layer_norm_{i}_2', 'Layer_Norm', child=f'decoder_residual_cross_attention_{i}')
+        decoder_feed_forward = add_component(ltree, f'decoder_feed_forward_{i}', 'Feed_Forward', child=f'decoder_layer_norm_{i}_2')
+        decoder_residual_feed_forward = add_component(ltree, f'decoder_residual_feed_forward_{i}', 'Residual', child=[f'decoder_residual_cross_attention_{i}', f'decoder_feed_forward_{i}'])
+
+        current_child = f'decoder_residual_feed_forward_{i}'
+
+    if hasattr(model.decoder, 'final_layer_norm'):
+        decoder_final_layer_norm = add_component(ltree, 'decoder_layer_norm', model.decoder.final_layer_norm, child=current_child)
+        current_child = 'decoder_layer_norm'
+
+    # Decoder LM-Head
+    if hasattr(model, 'lm_head'):
+        decoder_lm_head = add_component(ltree, 'decoder_lm_head', 'LM_Head', child=current_child)
+        current_child = 'decoder_lm_head'
+    else:
+        if model.config.tie_word_embeddings:
+            decoder_lm_head = add_component(ltree, 'decoder_lm_head', 'LM_Head', child=current_child)
+        else:
+            decoder_lm_head = add_component(ltree, 'decoder_lm_head', None, child=current_child)
+
+    # Classify components
+    for name, component in ltree.items():
+        if component['parent'] is None:
+            outputs.append(component['name'])
+        elif component['child'] is None:
+            inputs.append(component['name'])
+        else:
+            intermediates.append(component['name'])
+
+    # Reverse the layer_stack
+    layer_stack = list(reversed(layer_stack))
+
+    model_resource = {
+        "layers": layer_tree,
+        "graph": ltree,
+        "outputs": outputs,
+        "inputs": inputs
+    }
+
+    return model_resource, layer_stack
+
+
+def extract_encoder_decoder_weights(model):
+    # Initialize a dictionary to hold the weights
+    weights_dict = {
+        # 'shared_embeddings': {},
+        'encoder_embedding': {},
+        'encoder_layer_norm': {},
+        'decoder_embedding': {},
+        'decoder_layer_norm': {},
+        'decoder_lm_head': {}
+    }
+
+    # Extract the model's parameters and organize them into the dictionary
+    for weight in model.weights:
+        name = weight.name
+        value = weight.numpy()
+
+        if 'shared' in name:
+            weights_dict['encoder_embedding'][name] = value
+            weights_dict['decoder_embedding'][name] = value
+            weights_dict['decoder_lm_head'][name] = value
+
+        elif 'encoder' in name:
+            if 'block' in name:
+                layer = name.split('/')[2].split('_')[-1]
+                sub_layer = name.split('/')[3].split('_')[-1]
+                submodule = name.split('/')[4]
+
+                if 'SelfAttention' in submodule and f'encoder_self_attention_{layer}' not in weights_dict:
+                    weights_dict[f'encoder_self_attention_{layer}'] = {}
+                if 'layer_norm' in submodule and f'encoder_layer_norm_{layer}' not in weights_dict:
+                    weights_dict[f'encoder_layer_norm_{layer}'] = {}
+                if 'DenseReluDense' in submodule and f'encoder_feed_forward_{layer}' not in weights_dict:
+                    weights_dict[f'encoder_feed_forward_{layer}'] = {}
+
+                if 'SelfAttention' in submodule:
+                    weights_dict[f'encoder_self_attention_{layer}'][name] = value
+                elif 'layer_norm' in submodule:
+                    weights_dict[f'encoder_layer_norm_{layer}'][name] = value
+                elif 'DenseReluDense' in submodule:
+                    weights_dict[f'encoder_feed_forward_{layer}'][name] = value
+
+            elif 'final_layer_norm' in name:
+                weights_dict['encoder_layer_norm'][name] = value
+
+        elif 'decoder/block' in name:
+            layer = name.split('/')[2].split('_')[-1]
+            sub_layer = name.split('/')[3].split('_')[-1]
+            submodule = name.split('/')[4]
+
+            if 'SelfAttention' in submodule and f'decoder_self_attention_{layer}' not in weights_dict:
+                weights_dict[f'decoder_self_attention_{layer}'] = {}
+            if 'layer_norm' in submodule and f'decoder_layer_norm_{layer}' not in weights_dict:
+                weights_dict[f'decoder_layer_norm_{layer}'] = {}
+            if 'EncDecAttention' in submodule and f'decoder_cross_attention_{layer}' not in weights_dict:
+                weights_dict[f'decoder_cross_attention_{layer}'] = {}
+            if 'DenseReluDense' in submodule and f'decoder_feed_forward_{layer}' not in weights_dict:
+                weights_dict[f'decoder_feed_forward_{layer}'] = {}
+
+            if 'SelfAttention' in submodule:
+                weights_dict[f'decoder_self_attention_{layer}'][name] = value
+            elif 'layer_norm' in submodule:
+                weights_dict[f'decoder_layer_norm_{layer}'][name] = value
+            elif 'EncDecAttention' in submodule:
+                weights_dict[f'decoder_cross_attention_{layer}'][name] = value
+            elif 'DenseReluDense' in submodule:
+                weights_dict[f'decoder_feed_forward_{layer}'][name] = value
+
+        elif 'decoder/final_layer_norm' in name:
+            weights_dict['decoder_layer_norm'][name] = value
+
+    return weights_dict
+
+
+def calculate_encoder_decoder_output(input_text, model, tokenizer):
+    # Dictionaries to store the inputs and outputs
+    encoder_inputs = {}
+    encoder_outputs = defaultdict(lambda: defaultdict(dict))
+    decoder_inputs = defaultdict(lambda: defaultdict(dict))
+    decoder_outputs = defaultdict(lambda: defaultdict(dict))
+
+    # Hook manager to store hook information
+    hook_manager = []
+
+    # Token index shared with the nested hook functions (incremented per generated token)
+    token_idx = 0
+
+    # Function to generate timestamp (token index)
+    def get_timestamp():
+        return str(token_idx)
+
+    # Function to wrap the call method of the layer to add hooks
+    def wrap_call(layer, hook_fn, layer_index=None):
+        original_call = layer.call
+
+        def hooked_call(*args, **kwargs):
+            outputs = original_call(*args, **kwargs)
+            if layer_index is not None:
+                hook_fn(layer, args, outputs, layer_index)
+            else:
+                hook_fn(layer, args, outputs)
+            return outputs
+
+        layer.call = hooked_call
+        hook_manager.append((layer, original_call))
+
+    def capture_encoder_embeddings(model, tokenizer, input_text):
+        # Tokenize the input text
+        encoding = tokenizer(input_text, return_tensors='tf')
+        input_ids = encoding["input_ids"]
+        attention_mask = encoding["attention_mask"]
+
+        # Manually capture the embedding output
+        embedding_output = model.encoder.embed_tokens(input_ids)
+
+        # Return the captured embeddings
+        return embedding_output
+
+    ## ------------ Hook function for Self-Attention Block --------------------
+    def hook_fn__encoder_normalized_hidden_states(layer, inputs, outputs, layer_index):
+        encoder_inputs[f'encoder_layer_norm_{layer_index}_0'] = inputs
+        encoder_outputs[f'encoder_layer_norm_{layer_index}_0'] = outputs
+        encoder_outputs[f'input_to_layer_norm_{layer_index}'] = inputs[0]
+
+    def hook_fn_encoder_self_attention_outputs(layer, inputs, outputs, layer_index):
+        encoder_inputs[f'encoder_self_attention_{layer_index}'] = inputs
+        encoder_outputs[f'encoder_self_attention_{layer_index}'] = outputs[0]
+
+    def hook_fn_encoder_dropout_attention_output(layer, inputs, outputs, layer_index):
+        encoder_outputs[f'encoder_dropout_attention_{layer_index}'] = outputs[0]
+
+    ## ------------ Hook function for Feed-Forward Block --------------
+    def hook_fn_encoder_normalized_forwarded_states(layer, inputs, outputs, layer_index):
+        encoder_inputs[f'encoder_layer_norm_{layer_index}_1'] = inputs
+        encoder_outputs[f'encoder_layer_norm_{layer_index}_1'] = outputs
+        encoder_outputs[f'input_to_ff_layer_norm_{layer_index}'] = inputs[0]
+
+    def hook_fn_encoder_forwarded_states(layer, inputs, outputs, layer_index):
+        encoder_inputs[f'encoder_feed_forward_{layer_index}'] = inputs
+        encoder_outputs[f'encoder_feed_forward_{layer_index}'] = outputs
+
+    def hook_fn_encoder_dropout_forwarded_states(layer, inputs, outputs, layer_index):
+        encoder_outputs[f'dropout_forwarded_states_{layer_index}'] = outputs
+
+    # Custom hooks to calculate residuals
+    def hook_fn_encoder_residual_self_attention(layer, inputs, outputs, layer_index):
+        input_to_layer_norm = encoder_outputs[f'input_to_layer_norm_{layer_index}']
+        encoder_outputs[f'encoder_residual_self_attention_{layer_index}'] = input_to_layer_norm + outputs
+
+    def hook_fn_encoder_residual_feed_forward(layer, inputs, outputs, layer_index):
+        input_to_ff_layer_norm = encoder_outputs[f'input_to_ff_layer_norm_{layer_index}']
+        encoder_outputs[f'encoder_residual_feed_forward_{layer_index}'] = input_to_ff_layer_norm + outputs
+
+    # Hook for Final Layer normalization and dropout for Encoder
+    def hook_fn_normalized_encoder_output(layer, inputs, outputs):
+        encoder_outputs['encoder_layer_norm'] = outputs
+
+    def hook_fn_dropout_normalized_encoder_output(layer, inputs, outputs):
+        encoder_outputs['dropout_normalized_encoder_output'] = outputs
+
+    # Register hooks to the encoder submodules
+    for i, layer in enumerate(model.encoder.block):
+        layer.layer[0].layer_norm.layer_index = i
+        layer.layer[0].SelfAttention.layer_index = i
+        layer.layer[0].dropout.layer_index = i
+        layer.layer[1].layer_norm.layer_index = i
+        layer.layer[1].DenseReluDense.layer_index = i
+        layer.layer[1].dropout.layer_index = i
+
+        wrap_call(layer.layer[0].layer_norm, hook_fn__encoder_normalized_hidden_states, i)
+        wrap_call(layer.layer[0].SelfAttention, hook_fn_encoder_self_attention_outputs, i)
+        wrap_call(layer.layer[0].dropout, hook_fn_encoder_dropout_attention_output, i)
+        wrap_call(layer.layer[0].dropout, hook_fn_encoder_residual_self_attention, i)
+
+        wrap_call(layer.layer[1].layer_norm, hook_fn_encoder_normalized_forwarded_states, i)
+        wrap_call(layer.layer[1].DenseReluDense, hook_fn_encoder_forwarded_states, i)
+        wrap_call(layer.layer[1].dropout, hook_fn_encoder_dropout_forwarded_states, i)
+        wrap_call(layer.layer[1].dropout, hook_fn_encoder_residual_feed_forward, i)
+
+    wrap_call(model.encoder.final_layer_norm, hook_fn_normalized_encoder_output)
+    wrap_call(model.encoder.dropout, hook_fn_dropout_normalized_encoder_output)
+
+    ############################ Hook for Decoder ################################
+    def hook_fn_decoder_embedding(layer, inputs, outputs):
+        global token_idx
+        timestamp = get_timestamp()
+        decoder_outputs[timestamp]['decoder_embedding'] = outputs
+
+    def hook_fn_decoder_normalized_hidden_states(layer, inputs, outputs, layer_index):
+        global token_idx
+        timestamp = get_timestamp()
+        decoder_inputs[timestamp][f'decoder_layer_norm_{layer_index}_0'] = inputs
+        decoder_outputs[timestamp][f'decoder_layer_norm_{layer_index}_0'] = outputs
+        decoder_outputs[timestamp][f'input_to_layer_norm_{layer_index}'] = inputs[0]
+
+    def hook_fn_decoder_self_attention_outputs(layer, inputs, outputs, layer_index):
+        global token_idx
+        timestamp = get_timestamp()
+        decoder_inputs[timestamp][f'decoder_self_attention_{layer_index}'] = inputs
+        decoder_outputs[timestamp][f'decoder_self_attention_{layer_index}'] = outputs[0]
+
+    def hook_fn_decoder_dropout_attention_output(layer, inputs, outputs, layer_index):
+        global token_idx
+        timestamp = get_timestamp()
+        decoder_outputs[timestamp][f'dropout_attention_output_{layer_index}'] = outputs
+
+    def hook_fn_decoder_normalized_cross_attn_hidden_states(layer, inputs, outputs, layer_index):
+        global token_idx
+        timestamp = get_timestamp()
+        decoder_inputs[timestamp][f'decoder_layer_norm_{layer_index}_1'] = inputs
+        decoder_outputs[timestamp][f'decoder_layer_norm_{layer_index}_1'] = outputs
+        decoder_outputs[timestamp][f'input_to_cross_attn_layer_norm_{layer_index}'] = inputs[0]
+
+    def hook_fn_decoder_cross_attention_outputs(layer, inputs, outputs, layer_index):
+        global token_idx
+        timestamp = get_timestamp()
+        decoder_inputs[timestamp][f'decoder_cross_attention_{layer_index}'] = {'query': inputs[0], 'key': outputs, 'value': outputs}
+        decoder_outputs[timestamp][f'decoder_cross_attention_{layer_index}'] = outputs[0]
+
+    def hook_fn_decoder_dropout_cross_attn_output(layer, inputs, outputs, layer_index):
+        global token_idx
+        timestamp = get_timestamp()
+        decoder_outputs[timestamp][f'dropout_cross_attn_output_{layer_index}'] = outputs
+
+    def hook_fn_decoder_normalized_forwarded_states(layer, inputs, outputs, layer_index):
+        global token_idx
+        timestamp = get_timestamp()
+        decoder_inputs[timestamp][f'decoder_layer_norm_{layer_index}_2'] = inputs
+        decoder_outputs[timestamp][f'decoder_layer_norm_{layer_index}_2'] = outputs
+        decoder_outputs[timestamp][f'input_to_ff_layer_norm_{layer_index}'] = inputs[0]
+
+    def hook_fn_decoder_forwarded_states(layer, inputs, outputs, layer_index):
+        global token_idx
+        timestamp = get_timestamp()
+        decoder_inputs[timestamp][f'decoder_feed_forward_{layer_index}'] = inputs[0]
+        decoder_outputs[timestamp][f'decoder_feed_forward_{layer_index}'] = outputs
+
+    def hook_fn_decoder_dropout_forwarded_states(layer, inputs, outputs, layer_index):
+        global token_idx
+        timestamp = get_timestamp()
+        decoder_outputs[timestamp][f'dropout_forwarded_states_{layer_index}'] = outputs
+
+    def hook_fn_decoder_residual_self_attention(layer, inputs, outputs, layer_index):
+        global token_idx
+        timestamp = get_timestamp()
+        input_to_layer_norm = decoder_outputs[timestamp][f'input_to_layer_norm_{layer_index}']
+        decoder_outputs[timestamp][f'decoder_residual_self_attention_{layer_index}'] = input_to_layer_norm + outputs
+
+    def hook_fn_decoder_residual_cross_attention(layer, inputs, outputs, layer_index):
+        global token_idx
+        timestamp = get_timestamp()
+        input_to_layer_norm = decoder_outputs[timestamp][f'input_to_cross_attn_layer_norm_{layer_index}']
+        decoder_outputs[timestamp][f'decoder_residual_cross_attention_{layer_index}'] = input_to_layer_norm + outputs
+
+    def hook_fn_decoder_residual_feed_forward(layer, inputs, outputs, layer_index):
+        global token_idx
+        timestamp = get_timestamp()
+        input_to_ff_layer_norm = decoder_outputs[timestamp][f'input_to_ff_layer_norm_{layer_index}']
+        decoder_outputs[timestamp][f'decoder_residual_feed_forward_{layer_index}'] = input_to_ff_layer_norm + outputs
+
+    def hook_fn_normalized_decoder_output(layer, inputs, outputs):
+        global token_idx
+        timestamp = get_timestamp()
+        decoder_outputs[timestamp]['decoder_layer_norm'] = outputs
+
+    def hook_fn_dropout_normalized_decoder_output(layer, inputs, outputs):
+        global token_idx
+        timestamp = get_timestamp()
+        decoder_outputs[timestamp]['dropout_normalized_decoder_output'] = outputs
+
+    def hook_fn_final_logits(layer, inputs, outputs):
+        global token_idx
+        timestamp = get_timestamp()
+        decoder_outputs[timestamp]['decoder_lm_head'] = outputs.logits
+
+    # Register hooks to the decoder submodules
+    wrap_call(model.decoder.embed_tokens, hook_fn_decoder_embedding)
+
+    for i, layer in enumerate(model.decoder.block):
+        layer.layer[0].layer_norm.layer_index = i
+        layer.layer[0].SelfAttention.layer_index = i
+        layer.layer[0].dropout.layer_index = i
+        layer.layer[1].layer_norm.layer_index = i
+        layer.layer[1].EncDecAttention.layer_index = i
+        layer.layer[1].dropout.layer_index = i
+        layer.layer[2].layer_norm.layer_index = i
+        layer.layer[2].DenseReluDense.layer_index = i
+        layer.layer[2].dropout.layer_index = i
+
+        wrap_call(layer.layer[0].layer_norm, hook_fn_decoder_normalized_hidden_states, i)
+        wrap_call(layer.layer[0].SelfAttention, hook_fn_decoder_self_attention_outputs, i)
+        wrap_call(layer.layer[0].dropout, hook_fn_decoder_dropout_attention_output, i)
+        wrap_call(layer.layer[0].dropout, hook_fn_decoder_residual_self_attention, i)
+
+        wrap_call(layer.layer[1].layer_norm, hook_fn_decoder_normalized_cross_attn_hidden_states, i)
+        wrap_call(layer.layer[1].EncDecAttention, hook_fn_decoder_cross_attention_outputs, i)
+        wrap_call(layer.layer[1].dropout, hook_fn_decoder_dropout_cross_attn_output, i)
+        wrap_call(layer.layer[1].dropout, hook_fn_decoder_residual_cross_attention, i)
+
+        wrap_call(layer.layer[2].layer_norm, hook_fn_decoder_normalized_forwarded_states, i)
+        wrap_call(layer.layer[2].DenseReluDense, hook_fn_decoder_forwarded_states, i)
+        wrap_call(layer.layer[2].dropout, hook_fn_decoder_dropout_forwarded_states, i)
+        wrap_call(layer.layer[2].dropout, hook_fn_decoder_residual_feed_forward, i)
+
+    wrap_call(model.decoder.final_layer_norm, hook_fn_normalized_decoder_output)
+    wrap_call(model.decoder.dropout, hook_fn_dropout_normalized_decoder_output)
+
+    # Register hook for the final logits by wrapping the call method of the model itself
+    original_call = model.call
+    def hooked_call(*args, **kwargs):
+        outputs = original_call(*args, **kwargs)
+        hook_fn_final_logits(model, args, outputs)
+        return outputs
+
+    model.call = hooked_call
+    hook_manager.append((model, original_call))
+
+    # Function to remove hooks
+    def remove_hooks():
+        for layer, original_call in hook_manager:
+            layer.call = original_call
+        hook_manager.clear()
+
+    # Function to get shape
+    def get_shape(value):
+        if isinstance(value, tf.Tensor):
+            return value.shape
+        elif isinstance(value, tuple):
+            return [get_shape(v) for v in value if v is not None]
+        elif isinstance(value, list):
+            return [get_shape(v) for v in value if v is not None]
+        elif isinstance(value, dict):
+            return {k: get_shape(v) for k, v in value.items()}
+        else:
+            return None
+
+    # Function to increment token index
+    def increment_token_idx():
+        nonlocal token_idx  # token_idx lives in the enclosing function, not at module scope
+        token_idx += 1
+
+    encoding = tokenizer(input_text, return_tensors='tf')
+    input_ids = encoding["input_ids"]
+    attention_mask = encoding["attention_mask"]
+
+    embedding_output = capture_encoder_embeddings(model, tokenizer, input_text)
+    encoder_outputs['encoder_embedding'] = embedding_output
+
+    # Initialize decoder_input_ids with the start token
+    decoder_start_token_id = model.config.decoder_start_token_id
+    decoder_input_ids = tf.fill((tf.shape(input_ids)[0], 1), decoder_start_token_id)
+
+    # Reset token_idx before generating
+    token_idx = 0
+    max_length = model.config.n_positions if hasattr(model.config, 'n_positions') else model.config.d_model
+    generated_tokens = []
+
+    for _ in range(max_length):
+        outputs = model(input_ids=input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids)
+        next_token_logits = outputs.logits[:, -1, :]
+        next_token_id = tf.argmax(next_token_logits, axis=-1, output_type=tf.int32)
+        next_token_id = tf.expand_dims(next_token_id, axis=-1)
+        generated_tokens.append(next_token_id.numpy().item())
+        decoder_input_ids = tf.concat([decoder_input_ids, next_token_id], axis=-1)
+        increment_token_idx()
+
+        if next_token_id.numpy().item() == model.config.eos_token_id:
+            break
+
+    # Merge the encoder_outputs with timestep decoder_outputs to generate timestep-wise outputs of the model
+    outputs = {}
+
+    for i in range(len(decoder_outputs)):
+        outputs[f'{i}'] = {**encoder_outputs, **decoder_outputs[f'{i}']}
+
+    return outputs, generated_tokens
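The new encoder_decoder utilities build a named layer graph of a T5-style encoder-decoder, collect per-component weights, and wrap the model's call methods with hooks so that every intermediate activation is recorded per generated token. A minimal usage sketch follows; it is not part of the diff, and it assumes the module is importable at the path shown in the file list and that the model is a Hugging Face TFT5ForConditionalGeneration checkpoint such as t5-small (the encoder/decoder attribute layout these hooks rely on).

# Hypothetical usage sketch, not part of the released package.
from transformers import AutoTokenizer, TFT5ForConditionalGeneration

from dl_backtrace.tf_backtrace.backtrace.utils.encoder_decoder import (
    build_enc_dec_tree,
    calculate_encoder_decoder_output,
    extract_encoder_decoder_weights,
)

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")

# Layer graph plus reversed evaluation order for the backtrace pass.
model_resource, layer_stack = build_enc_dec_tree(model)

# Per-component weights keyed by the same node names as the graph.
weights = extract_encoder_decoder_weights(model)

# Greedy generation with hooks attached; activations are keyed by token step.
outputs, generated_tokens = calculate_encoder_decoder_output(
    "translate English to German: The house is wonderful.", model, tokenizer
)
print(tokenizer.decode(generated_tokens, skip_special_tokens=True))
print(list(outputs["0"].keys())[:5])  # e.g. encoder_embedding, encoder_layer_norm_0_0, ...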
dl_backtrace/tf_backtrace/backtrace/utils/helper.py
@@ -0,0 +1,99 @@
+import tensorflow as tf
+
+
+# Function to rename the dictionary keys
+def rename_self_attention_keys(attention_weights):
+    renamed_weights = {}
+    for key, value in attention_weights.items():
+        if 'query/kernel' in key or 'SelfAttention/q' in key:
+            new_key = key.replace(key, 'W_q')
+        elif 'query/bias' in key:
+            new_key = key.replace(key, 'b_q')
+        elif 'key/kernel' in key or 'SelfAttention/k' in key:
+            new_key = key.replace(key, 'W_k')
+        elif 'key/bias' in key:
+            new_key = key.replace(key, 'b_k')
+        elif 'value/kernel' in key or 'SelfAttention/v' in key:
+            new_key = key.replace(key, 'W_v')
+        elif 'value/bias' in key:
+            new_key = key.replace(key, 'b_v')
+        elif 'output/dense/kernel' in key or 'SelfAttention/o' in key:
+            new_key = key.replace(key, 'W_d')
+        elif 'output/dense/bias' in key:
+            new_key = key.replace(key, 'b_d')
+        elif 'SelfAttention/relative_attention_bias' in key:
+            new_key = key.replace(key, 'relative_attn_bias')
+
+        renamed_weights[new_key] = value
+    return renamed_weights
+
+
+def rename_cross_attention_keys(cross_attention_weights):
+    renamed_weights = {}
+
+    for key, value in cross_attention_weights.items():
+        if 'EncDecAttention/q' in key:
+            new_key = key.replace(key, 'W_q')
+        elif 'EncDecAttention/k' in key:
+            new_key = key.replace(key, 'W_k')
+        elif 'EncDecAttention/v' in key:
+            new_key = key.replace(key, 'W_v')
+        elif 'EncDecAttention/o' in key:
+            new_key = key.replace(key, 'W_o')
+
+        renamed_weights[new_key] = value
+    return renamed_weights
+
+
+def rename_feed_forward_keys(feed_forward_weights):
+    renamed_weights = {}
+
+    for key, value in feed_forward_weights.items():
+        if 'intermediate/dense/kernel' in key or 'DenseReluDense/wi' in key:
+            new_key = key.replace(key, 'W_int')
+        elif 'intermediate/dense/bias' in key or 'DenseReluDense/bi' in key:
+            new_key = key.replace(key, 'b_int')
+        elif 'output/dense/kernel' in key or 'DenseReluDense/wo' in key:
+            new_key = key.replace(key, 'W_out')
+        elif 'output/dense/bias' in key or 'DenseReluDense/bo' in key:
+            new_key = key.replace(key, 'b_out')
+
+        renamed_weights[new_key] = value
+    return renamed_weights
+
+
+def rename_pooler_keys(pooler_weights):
+    renamed_weights = {}
+
+    for key, value in pooler_weights.items():
+        if 'pooler/dense/kernel' in key:
+            new_key = key.replace(key, 'W_p')
+        elif 'pooler/dense/bias' in key:
+            new_key = key.replace(key, 'b_p')
+
+        renamed_weights[new_key] = value
+    return renamed_weights
+
+
+def rename_classifier_keys(classifier_weights):
+    renamed_weights = {}
+
+    for key, value in classifier_weights.items():
+        if 'classifier/kernel' in key:
+            new_key = key.replace(key, 'W_cls')
+        elif 'classifier/bias' in key:
+            new_key = key.replace(key, 'b_cls')
+
+        renamed_weights[new_key] = value
+    return renamed_weights
+
+
+def rename_decoder_lm_head(lm_head_weights):
+    renamed_weights = {}
+
+    for key, value in lm_head_weights.items():
+        if 'shared/shared/embeddings' in key:
+            new_key = key.replace(key, 'W_lm_head')
+
+        renamed_weights[new_key] = value
+    return renamed_weights
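The helper functions above map raw TF variable names (both BERT-style '.../query/kernel' paths and T5-style 'SelfAttention/q', 'EncDecAttention/...', 'DenseReluDense/...' paths) to short canonical keys such as 'W_q', 'W_k', 'W_v'. A small illustrative sketch, again assuming the module paths from the file list and a Hugging Face t5-small checkpoint (both assumptions, not taken from the diff):

# Hypothetical sketch, not part of the released package.
from transformers import TFT5ForConditionalGeneration

from dl_backtrace.tf_backtrace.backtrace.utils.encoder_decoder import extract_encoder_decoder_weights
from dl_backtrace.tf_backtrace.backtrace.utils.helper import (
    rename_feed_forward_keys,
    rename_self_attention_keys,
)

model = TFT5ForConditionalGeneration.from_pretrained("t5-small")
weights = extract_encoder_decoder_weights(model)

# Raw names such as ".../encoder/block_._0/layer_._0/SelfAttention/q/kernel:0"
# become short keys such as "W_q".
self_attn = rename_self_attention_keys(weights["encoder_self_attention_0"])
feed_forward = rename_feed_forward_keys(weights["encoder_feed_forward_0"])

print(sorted(self_attn))     # e.g. ['W_d', 'W_k', 'W_q', 'W_v', 'relative_attn_bias']
print(sorted(feed_forward))  # e.g. ['W_int', 'W_out']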