dl-backtrace 0.0.12__py3-none-any.whl → 0.0.16__py3-none-any.whl
This diff shows the published contents of the two package versions as they appear in their public registry. It is provided for informational purposes only and reflects the changes between those versions.
Potentially problematic release.
This version of dl-backtrace might be problematic.
- dl_backtrace/pytorch_backtrace/backtrace/backtrace.py +173 -44
- dl_backtrace/pytorch_backtrace/backtrace/utils/__init__.py +3 -0
- dl_backtrace/pytorch_backtrace/backtrace/utils/encoder.py +183 -0
- dl_backtrace/pytorch_backtrace/backtrace/utils/encoder_decoder.py +489 -0
- dl_backtrace/pytorch_backtrace/backtrace/utils/helper.py +95 -0
- dl_backtrace/pytorch_backtrace/backtrace/utils/prop.py +481 -0
- dl_backtrace/tf_backtrace/backtrace/__init__.py +1 -2
- dl_backtrace/tf_backtrace/backtrace/activation_info.py +33 -0
- dl_backtrace/tf_backtrace/backtrace/backtrace.py +506 -279
- dl_backtrace/tf_backtrace/backtrace/models.py +25 -0
- dl_backtrace/tf_backtrace/backtrace/server.py +27 -0
- dl_backtrace/tf_backtrace/backtrace/utils/__init__.py +5 -2
- dl_backtrace/tf_backtrace/backtrace/utils/encoder.py +206 -0
- dl_backtrace/tf_backtrace/backtrace/utils/encoder_decoder.py +501 -0
- dl_backtrace/tf_backtrace/backtrace/utils/helper.py +99 -0
- dl_backtrace/tf_backtrace/backtrace/utils/utils_contrast.py +1132 -0
- dl_backtrace/tf_backtrace/backtrace/utils/utils_prop.py +1582 -0
- dl_backtrace/version.py +2 -2
- {dl_backtrace-0.0.12.dist-info → dl_backtrace-0.0.16.dist-info}/METADATA +3 -2
- dl_backtrace-0.0.16.dist-info/RECORD +29 -0
- {dl_backtrace-0.0.12.dist-info → dl_backtrace-0.0.16.dist-info}/WHEEL +1 -1
- dl_backtrace/tf_backtrace/backtrace/config.py +0 -41
- dl_backtrace/tf_backtrace/backtrace/utils/contrast.py +0 -834
- dl_backtrace/tf_backtrace/backtrace/utils/prop.py +0 -725
- dl_backtrace-0.0.12.dist-info/RECORD +0 -21
- {dl_backtrace-0.0.12.dist-info → dl_backtrace-0.0.16.dist-info}/LICENSE +0 -0
- {dl_backtrace-0.0.12.dist-info → dl_backtrace-0.0.16.dist-info}/top_level.txt +0 -0
dl_backtrace/pytorch_backtrace/backtrace/utils/encoder_decoder.py

@@ -0,0 +1,489 @@
+import torch
+from collections import defaultdict
+
+
+def build_enc_dec_tree(model, root='enc-dec'):
+    # Initialize the tree structure
+    ltree = {}
+    layer_tree = {}
+    inputs = []
+    outputs = []
+    intermediates = []
+    layer_stack = []
+
+    # Base component setup
+    def add_component(tree, name, component, child=None):
+        tree[name] = {
+            'name': name,
+            'class': component if type(component).__name__ == 'str' else type(component).__name__,
+            'type': str(type(component)),
+            'parent': None,
+            'child': None
+        }
+
+        if isinstance(child, list):
+            tree[name]['child'] = child
+        elif isinstance(child, str):
+            tree[name]['child'] = [child]
+
+        if tree[name]['class'] == 'list':
+            tree[name]['class'] = [type(item).__name__ for item in component]
+            tree[name]['type'] = [str(type(item)) for item in component]
+
+        # Keep track of component type in a separate dictionary
+        layer_tree[name] = component if type(component).__name__ == 'str' else tree[name]['type']
+
+        # keep track of layer stack
+        layer_stack.append(name)
+
+        # Link the parent to its children
+        if isinstance(child, list):
+            for ch in child:
+                if ch in tree:
+                    tree[ch]['parent'] = [name]
+
+        elif isinstance(child, str):
+            if child in tree:
+                tree[child]['parent'] = [name]
+
+        return tree[name]
+
+    # Add root and embeddings component
+    encoder_embeddings = add_component(ltree, 'encoder_embedding', model.encoder.embed_tokens, child=None)
+
+    # Add encoder layers dynamically
+    current_child = 'encoder_embedding'
+    for i, layer in enumerate(model.encoder.block):
+        encoder_layer_norm_0 = add_component(ltree, f'encoder_layer_norm_{i}_0', 'Layer_Norm', child=current_child)
+        encoder_self_attention = add_component(ltree, f'encoder_self_attention_{i}', 'Self_Attention', child=f'encoder_layer_norm_{i}_0')
+        encoder_residual_self_attention = add_component(ltree, f'encoder_residual_self_attention_{i}', 'Residual', child=[current_child, f'encoder_self_attention_{i}'])
+
+        encoder_layer_norm_1 = add_component(ltree, f'encoder_layer_norm_{i}_1', 'Layer_Norm', child=f'encoder_residual_self_attention_{i}')
+        encoder_feed_forward = add_component(ltree, f'encoder_feed_forward_{i}', 'Feed_Forward', child=f'encoder_layer_norm_{i}_1')
+        encoder_residual_feed_forward = add_component(ltree, f'encoder_residual_feed_forward_{i}', 'Residual', child=[f'encoder_residual_self_attention_{i}', f'encoder_feed_forward_{i}'])
+
+        current_child = f'encoder_residual_feed_forward_{i}'
+
+    if hasattr(model.encoder, 'final_layer_norm'):
+        encoder_final_layer_norm = add_component(ltree, 'encoder_layer_norm', model.encoder.final_layer_norm, child=current_child)
+        current_child = 'encoder_layer_norm'
+
+    # Add Decoder layers
+    decoder_embeddings = add_component(ltree, 'decoder_embedding', model.decoder.embed_tokens, child=None)
+
+    # Add decoder layers dynamically
+    current_child = 'decoder_embedding'
+    for i, layer in enumerate(model.decoder.block):
+        decoder_layer_norm_0 = add_component(ltree, f'decoder_layer_norm_{i}_0', 'Layer_Norm', child=current_child)
+        decoder_self_attention = add_component(ltree, f'decoder_self_attention_{i}', 'Self_Attention', child=f'decoder_layer_norm_{i}_0')
+        decoder_residual_self_attention = add_component(ltree, f'decoder_residual_self_attention_{i}', 'Residual', child=[current_child, f'decoder_self_attention_{i}'])
+
+        decoder_layer_norm_1 = add_component(ltree, f'decoder_layer_norm_{i}_1', 'Layer_Norm', child=f'decoder_residual_self_attention_{i}')
+        decoder_cross_attention = add_component(ltree, f'decoder_cross_attention_{i}', 'Cross_Attention', child=['encoder_layer_norm', f'decoder_layer_norm_{i}_1'])
+        decoder_residual_cross_attention = add_component(ltree, f'decoder_residual_cross_attention_{i}', 'Residual', child=[f'decoder_residual_self_attention_{i}', f'decoder_cross_attention_{i}'])
+
+        decoder_layer_norm_2 = add_component(ltree, f'decoder_layer_norm_{i}_2', 'Layer_Norm', child=f'decoder_residual_cross_attention_{i}')
+        decoder_feed_forward = add_component(ltree, f'decoder_feed_forward_{i}', 'Feed_Forward', child=f'decoder_layer_norm_{i}_2')
+        decoder_residual_feed_forward = add_component(ltree, f'decoder_residual_feed_forward_{i}', 'Residual', child=[f'decoder_residual_cross_attention_{i}', f'decoder_feed_forward_{i}'])
+
+        current_child = f'decoder_residual_feed_forward_{i}'
+
+    if hasattr(model.decoder, 'final_layer_norm'):
+        decoder_final_layer_norm = add_component(ltree, 'decoder_layer_norm', model.decoder.final_layer_norm, child=current_child)
+        current_child = 'decoder_layer_norm'
+
+    # Decoder LM-Head
+    if hasattr(model, 'lm_head'):
+        decoder_lm_head = add_component(ltree, 'decoder_lm_head', 'LM_Head', child=current_child)
+        current_child = 'decoder_lm_head'
+
+    # Classify components
+    for name, component in ltree.items():
+        if component['parent'] is None:
+            outputs.append(component['name'])
+        elif component['child'] is None:
+            inputs.append(component['name'])
+        else:
+            intermediates.append(component['name'])
+
+    # reverse the layer_stack
+    layer_stack = list(reversed(layer_stack))
+    model_resource = (layer_tree, ltree, outputs, inputs)
+
+    return model_resource, layer_stack
+
+
+def extract_encoder_decoder_weights(model):
+    # Initialize a dictionary to hold the weights
+    weights_dict = {
+        # 'shared_embeddings': {},
+        'encoder_embedding': {},
+        'encoder_layer_norm': {},
+        'decoder_embedding': {},
+        'decoder_layer_norm': {},
+        'decoder_lm_head': {}
+    }
+
+    # Extract the model's parameters and organize them into the dictionary
+    for name, param in model.named_parameters():
+        if 'shared' in name:
+            weights_dict['encoder_embedding'][name] = param.data.cpu().numpy()
+            weights_dict['decoder_embedding'][name] = param.data.cpu().numpy()
+            weights_dict['decoder_lm_head'][name] = param.data.cpu().numpy()
+
+        elif 'encoder.block' in name:
+            layer = name.split('.')[2]
+            sub_layer = name.split('.')[4]
+            submodule = name.split('.')[5]
+
+            if 'SelfAttention' in submodule and f'encoder_self_attention_{layer}' not in weights_dict:
+                weights_dict[f'encoder_self_attention_{layer}'] = {}
+            if 'layer_norm' in submodule and f'encoder_layer_norm_{layer}' not in weights_dict:
+                weights_dict[f'encoder_layer_norm_{layer}'] = {}
+            if 'DenseReluDense' in submodule and f'encoder_feed_forward_{layer}' not in weights_dict:
+                weights_dict[f'encoder_feed_forward_{layer}'] = {}
+
+            if 'SelfAttention' in submodule:
+                weights_dict[f'encoder_self_attention_{layer}'][name] = param.data.cpu().numpy()
+            elif 'layer_norm' in submodule:
+                weights_dict[f'encoder_layer_norm_{layer}'][name] = param.data.cpu().numpy()
+            elif 'DenseReluDense' in submodule:
+                weights_dict[f'encoder_feed_forward_{layer}'][name] = param.data.cpu().numpy()
+
+        elif 'encoder.final_layer_norm.weight' in name:
+            weights_dict['encoder_layer_norm'][name] = param.data.cpu().numpy()
+
+        elif 'decoder.block' in name:
+            layer = name.split('.')[2]
+            sub_layer = name.split('.')[4]
+            submodule = name.split('.')[5]
+
+            if 'SelfAttention' in submodule and f'decoder_self_attention_{layer}' not in weights_dict:
+                weights_dict[f'decoder_self_attention_{layer}'] = {}
+            if 'layer_norm' in submodule and f'decoder_layer_norm_{layer}' not in weights_dict:
+                weights_dict[f'decoder_layer_norm_{layer}'] = {}
+            if 'EncDecAttention' in submodule and f'decoder_cross_attention_{layer}' not in weights_dict:
+                weights_dict[f'decoder_cross_attention_{layer}'] = {}
+            if 'DenseReluDense' in submodule and f'decoder_feed_forward_{layer}' not in weights_dict:
+                weights_dict[f'decoder_feed_forward_{layer}'] = {}
+
+            if 'SelfAttention' in submodule:
+                weights_dict[f'decoder_self_attention_{layer}'][name] = param.data.cpu().numpy()
+            elif 'layer_norm' in submodule:
+                weights_dict[f'decoder_layer_norm_{layer}'][name] = param.data.cpu().numpy()
+            elif 'EncDecAttention' in submodule:
+                weights_dict[f'decoder_cross_attention_{layer}'][name] = param.data.cpu().numpy()
+            elif 'DenseReluDense' in submodule:
+                weights_dict[f'decoder_feed_forward_{layer}'][name] = param.data.cpu().numpy()
+
+        elif 'decoder.final_layer_norm.weight' in name:
+            weights_dict['decoder_layer_norm'][name] = param.data.cpu().numpy()
+
+    return weights_dict
+
+
+def calculate_encoder_decoder_output(input_text, model, tokenizer):
+    encoder_inputs = {}
+    encoder_outputs = {}
+    encoder_hooks = []
+    decoder_outputs = defaultdict(lambda: defaultdict(dict))
+    decoder_inputs = defaultdict(lambda: defaultdict(dict))
+    decoder_hooks = []
+
+    def capture_encoder_embeddings(model, tokenizer, input_text):
+        # Ensure the model is in evaluation mode
+        model.eval()
+
+        # Tokenize the input text
+        encoding = tokenizer(input_text, return_tensors='pt')
+        input_ids = encoding["input_ids"]
+        attention_mask = encoding["attention_mask"]
+
+        # Manually capture the embedding output
+        with torch.no_grad():
+            embedding_output = model.encoder.embed_tokens(input_ids)
+
+        # Return the captured embeddings
+        return embedding_output
+
+    ## ------------ Hook function for Self-Attention Block --------------------
+    def hook_fn__encoder_normalized_hidden_states(module, input, output):
+        encoder_inputs[f'encoder_layer_norm_{module.layer_index}_0'] = input
+        encoder_outputs[f'encoder_layer_norm_{module.layer_index}_0'] = output
+        encoder_outputs[f'input_to_layer_norm_{module.layer_index}'] = input[0]
+
+    def hook_fn_encoder_self_attention_outputs(module, input, output):
+        encoder_inputs[f'encoder_self_attention_{module.layer_index}'] = input
+        encoder_outputs[f'encoder_self_attention_{module.layer_index}'] = output[0]
+
+    def hook_fn_encoder_dropout_attention_output(module, input, output):
+        encoder_outputs[f'encoder_dropout_attention_{module.layer_index}'] = output[0]
+
+    ## ------------ Hook function for Feed-Forward Block --------------
+    def hook_fn_encoder_normalized_forwarded_states(module, input, output):
+        encoder_inputs[f'encoder_layer_norm_{module.layer_index}_1'] = input
+        encoder_outputs[f'encoder_layer_norm_{module.layer_index}_1'] = output
+        encoder_outputs[f'input_to_ff_layer_norm_{module.layer_index}'] = input[0]
+
+    def hook_fn_encoder_forwarded_states(module, input, output):
+        encoder_inputs[f'encoder_feed_forward_{module.layer_index}'] = input
+        encoder_outputs[f'encoder_feed_forward_{module.layer_index}'] = output
+
+    def hook_fn_encoder_dropout_forwarded_states(module, input, output):
+        encoder_outputs[f'dropout_forwarded_states_{module.layer_index}'] = output
+
+    # Custom hooks to calculate residuals
+    def hook_fn_encoder_residual_self_attention(layer_index):
+        def hook(module, input, output):
+            input_to_layer_norm = encoder_outputs[f'input_to_layer_norm_{layer_index}']
+            encoder_outputs[f'encoder_residual_self_attention_{layer_index}'] = input_to_layer_norm + output
+        return hook
+
+    def hook_fn_encoder_residual_feed_forward(layer_index):
+        def hook(module, input, output):
+            input_to_ff_layer_norm = encoder_outputs[f'input_to_ff_layer_norm_{layer_index}']
+            encoder_outputs[f'encoder_residual_feed_forward_{layer_index}'] = input_to_ff_layer_norm + output
+        return hook
+
+    # Hook for Final Layer normalization and dropout for Encoder
+    def hook_fn_normalized_encoder_output(module, input, output):
+        encoder_outputs['encoder_layer_norm'] = output
+
+    def hook_fn_dropout_normalized_encoder_output(module, input, output):
+        encoder_outputs['dropout_normalized_encoder_output'] = output
+
+    # Register hooks to the encoder submodules
+    for i, layer in enumerate(model.encoder.block):
+        # Set layer_index attribute to all relevant submodules
+        layer.layer[0].layer_norm.layer_index = i
+        layer.layer[0].SelfAttention.layer_index = i
+        layer.layer[0].dropout.layer_index = i
+        layer.layer[1].layer_norm.layer_index = i
+        layer.layer[1].DenseReluDense.layer_index = i
+        layer.layer[1].dropout.layer_index = i
+
+        encoder_hooks.append(layer.layer[0].layer_norm.register_forward_hook(hook_fn__encoder_normalized_hidden_states))
+        encoder_hooks.append(layer.layer[0].SelfAttention.register_forward_hook(hook_fn_encoder_self_attention_outputs))
+        encoder_hooks.append(layer.layer[0].dropout.register_forward_hook(hook_fn_encoder_dropout_attention_output))
+        encoder_hooks.append(layer.layer[0].dropout.register_forward_hook(hook_fn_encoder_residual_self_attention(i)))  # Custom hook for residual self-attention
+
+        encoder_hooks.append(layer.layer[1].layer_norm.register_forward_hook(hook_fn_encoder_normalized_forwarded_states))
+        encoder_hooks.append(layer.layer[1].DenseReluDense.register_forward_hook(hook_fn_encoder_forwarded_states))
+        encoder_hooks.append(layer.layer[1].dropout.register_forward_hook(hook_fn_encoder_dropout_forwarded_states))
+        encoder_hooks.append(layer.layer[1].dropout.register_forward_hook(hook_fn_encoder_residual_feed_forward(i)))  # Custom hook for residual feed-forward
+
+    # Register hook for Final Layer Normalization and dropout for Encoder
+    encoder_hooks.append(model.encoder.final_layer_norm.register_forward_hook(hook_fn_normalized_encoder_output))
+    encoder_hooks.append(model.encoder.dropout.register_forward_hook(hook_fn_dropout_normalized_encoder_output))
+
+    ############################ Hook for Decoder ################################
+    # Global variable to keep track of the token index
+    token_idx = 0
+
+    # Function to generate timestamp (token index)
+    def get_timestamp():
+        return str(token_idx)
+
+    # Hook functions to capture input embedding
+    def hook_fn_decoder_embedding(module, input, output):
+        global token_idx
+        timestamp = get_timestamp()
+        decoder_outputs[timestamp]['decoder_embedding'] = output.detach().clone()
+
+    ## ------------ Hook function for Self-Attention Block --------------------
+    def hook_fn_decoder_normalized_hidden_states(module, input, output, layer_index):
+        global token_idx
+        timestamp = get_timestamp()
+        decoder_inputs[timestamp][f'decoder_layer_norm_{layer_index}_0'] = input
+        decoder_outputs[timestamp][f'decoder_layer_norm_{layer_index}_0'] = output
+        decoder_outputs[timestamp][f'input_to_layer_norm_{layer_index}'] = input[0]
+
+    def hook_fn_decoder_self_attention_outputs(module, input, output, layer_index):
+        global token_idx
+        timestamp = get_timestamp()
+        decoder_inputs[timestamp][f'decoder_self_attention_{layer_index}'] = input
+        decoder_outputs[timestamp][f'decoder_self_attention_{layer_index}'] = output[0]
+
+    def hook_fn_decoder_dropout_attention_output(module, input, output, layer_index):
+        global token_idx
+        timestamp = get_timestamp()
+        decoder_outputs[timestamp][f'dropout_attention_output_{layer_index}'] = output
+
+    ## ------------ Hook function for Cross-Attention Block --------------------
+    def hook_fn_decoder_normalized_cross_attn_hidden_states(module, input, output, layer_index):
+        global token_idx
+        timestamp = get_timestamp()
+        decoder_inputs[timestamp][f'decoder_layer_norm_{layer_index}_1'] = input
+        decoder_outputs[timestamp][f'decoder_layer_norm_{layer_index}_1'] = output
+        decoder_outputs[timestamp][f'input_to_cross_attn_layer_norm_{layer_index}'] = input[0]

+    def hook_fn_decoder_cross_attention_outputs(module, input, output, layer_index):
+        global token_idx
+        timestamp = get_timestamp()
+        key_value_states = encoder_outputs['dropout_normalized_encoder_output']
+        query_state = input[0]
+
+        inputs = {
+            'query': input[0],
+            'key': key_value_states,
+            'value': key_value_states
+        }
+
+        decoder_inputs[timestamp][f'decoder_cross_attention_{layer_index}'] = inputs
+        decoder_outputs[timestamp][f'decoder_cross_attention_{layer_index}'] = output[0]
+
+    def hook_fn_decoder_dropout_cross_attn_output(module, input, output, layer_index):
+        global token_idx
+        timestamp = get_timestamp()
+        decoder_outputs[timestamp][f'dropout_cross_attn_output_{layer_index}'] = output
+
+    ## ------------ Hook function for Feed-Forward Block --------------------
+    def hook_fn_decoder_normalized_forwarded_states(module, input, output, layer_index):
+        global token_idx
+        timestamp = get_timestamp()
+        decoder_inputs[timestamp][f'decoder_layer_norm_{layer_index}_2'] = input
+        decoder_outputs[timestamp][f'decoder_layer_norm_{layer_index}_2'] = output
+        decoder_outputs[timestamp][f'input_to_ff_layer_norm_{layer_index}'] = input[0]
+
+    def hook_fn_decoder_forwarded_states(module, input, output, layer_index):
+        global token_idx
+        timestamp = get_timestamp()
+        decoder_inputs[timestamp][f'decoder_feed_forward_{layer_index}'] = input[0]
+        decoder_outputs[timestamp][f'decoder_feed_forward_{layer_index}'] = output
+
+    def hook_fn_decoder_dropout_forwarded_states(module, input, output, layer_index):
+        global token_idx
+        timestamp = get_timestamp()
+        decoder_outputs[timestamp][f'dropout_forwarded_states_{layer_index}'] = output
+
+    # Custom hooks to calculate residuals
+    def hook_fn_decoder_residual_self_attention(layer_index):
+        def hook(module, input, output):
+            global token_idx
+            timestamp = get_timestamp()
+            input_to_layer_norm = decoder_outputs[timestamp][f'input_to_layer_norm_{layer_index}']
+            decoder_outputs[timestamp][f'decoder_residual_self_attention_{layer_index}'] = input_to_layer_norm + output
+        return hook
+
+    def hook_fn_decoder_residual_cross_attention(layer_index):
+        def hook(module, input, output):
+            global token_idx
+            timestamp = get_timestamp()
+            input_to_layer_norm = decoder_outputs[timestamp][f'input_to_cross_attn_layer_norm_{layer_index}']
+            decoder_outputs[timestamp][f'decoder_residual_cross_attention_{layer_index}'] = input_to_layer_norm + output
+        return hook
+
+    def hook_fn_decoder_residual_feed_forward(layer_index):
+        def hook(module, input, output):
+            global token_idx
+            timestamp = get_timestamp()
+            input_to_ff_layer_norm = decoder_outputs[timestamp][f'input_to_ff_layer_norm_{layer_index}']
+            decoder_outputs[timestamp][f'decoder_residual_feed_forward_{layer_index}'] = input_to_ff_layer_norm + output
+        return hook
+
+    # Hook for Final Layer normalization and dropout for Decoder
+    def hook_fn_normalized_decoder_output(module, input, output):
+        global token_idx
+        timestamp = get_timestamp()
+        decoder_outputs[timestamp]['decoder_layer_norm'] = output
+
+    def hook_fn_dropout_normalized_decoder_output(module, input, output):
+        global token_idx
+        timestamp = get_timestamp()
+        decoder_outputs[timestamp]['dropout_normalized_decoder_output'] = output
+
+    # Hook for the Decoder LM-Head
+    def hook_fn_lm_head(module, input, output):
+        global token_idx
+        timestamp = get_timestamp()
+        decoder_outputs[timestamp]['decoder_lm_head'] = output
+
+
+    # Register hook for embedding
+    decoder_hooks.append(model.decoder.embed_tokens.register_forward_hook(lambda module, input, output: hook_fn_decoder_embedding(module, input, output)))
+
+    # Register hooks to the decoder submodules
+    for i, layer in enumerate(model.decoder.block):
+        # Set layer_index attribute to all relevant submodules
+        layer.layer[0].layer_norm.layer_index = i
+        layer.layer[0].SelfAttention.layer_index = i
+        layer.layer[0].dropout.layer_index = i
+        layer.layer[1].layer_norm.layer_index = i
+        layer.layer[1].EncDecAttention.layer_index = i
+        layer.layer[1].dropout.layer_index = i
+        layer.layer[2].layer_norm.layer_index = i
+        layer.layer[2].DenseReluDense.layer_index = i
+        layer.layer[2].dropout.layer_index = i
+
+        decoder_hooks.append(layer.layer[0].layer_norm.register_forward_hook(lambda module, input, output, i=i: hook_fn_decoder_normalized_hidden_states(module, input, output, layer_index=i)))
+        decoder_hooks.append(layer.layer[0].SelfAttention.register_forward_hook(lambda module, input, output, i=i: hook_fn_decoder_self_attention_outputs(module, input, output, layer_index=i)))
+        decoder_hooks.append(layer.layer[0].dropout.register_forward_hook(lambda module, input, output, i=i: hook_fn_decoder_dropout_attention_output(module, input, output, layer_index=i)))
+        decoder_hooks.append(layer.layer[0].dropout.register_forward_hook(lambda module, input, output, i=i: hook_fn_decoder_residual_self_attention(i)(module, input, output)))
+
+        decoder_hooks.append(layer.layer[1].layer_norm.register_forward_hook(lambda module, input, output, i=i: hook_fn_decoder_normalized_cross_attn_hidden_states(module, input, output, layer_index=i)))
+        decoder_hooks.append(layer.layer[1].EncDecAttention.register_forward_hook(lambda module, input, output, i=i: hook_fn_decoder_cross_attention_outputs(module, input, output, layer_index=i)))
+        decoder_hooks.append(layer.layer[1].dropout.register_forward_hook(lambda module, input, output, i=i: hook_fn_decoder_dropout_cross_attn_output(module, input, output, layer_index=i)))
+        decoder_hooks.append(layer.layer[1].dropout.register_forward_hook(lambda module, input, output, i=i: hook_fn_decoder_residual_cross_attention(i)(module, input, output)))
+
+        decoder_hooks.append(layer.layer[2].layer_norm.register_forward_hook(lambda module, input, output, i=i: hook_fn_decoder_normalized_forwarded_states(module, input, output, layer_index=i)))
+        decoder_hooks.append(layer.layer[2].DenseReluDense.register_forward_hook(lambda module, input, output, i=i: hook_fn_decoder_forwarded_states(module, input, output, layer_index=i)))
+        decoder_hooks.append(layer.layer[2].dropout.register_forward_hook(lambda module, input, output, i=i: hook_fn_decoder_dropout_forwarded_states(module, input, output, layer_index=i)))
+        decoder_hooks.append(layer.layer[2].dropout.register_forward_hook(lambda module, input, output, i=i: hook_fn_decoder_residual_feed_forward(i)(module, input, output)))
+
+    # Register hook for Final Layer Normalization and dropout for Decoder
+    decoder_hooks.append(model.decoder.final_layer_norm.register_forward_hook(lambda module, input, output: hook_fn_normalized_decoder_output(module, input, output)))
+    decoder_hooks.append(model.decoder.dropout.register_forward_hook(lambda module, input, output: hook_fn_dropout_normalized_decoder_output(module, input, output)))
+
+    # Register hook for the Decoder LM-Head
+    decoder_hooks.append(model.lm_head.register_forward_hook(lambda module, input, output: hook_fn_lm_head(module, input, output)))
+
+
+    # Function to increment token_idx
+    def increment_token_idx():
+        global token_idx
+        token_idx += 1
+
+    encoding = tokenizer(input_text, return_tensors='pt')
+    input_ids = encoding["input_ids"]
+    attention_mask = encoding["attention_mask"]
+
+    embedding_output = capture_encoder_embeddings(model, tokenizer, input_text)
+    encoder_outputs['encoder_embedding'] = embedding_output
+
+    # Initialize decoder_input_ids with the start token
+    decoder_input_ids = torch.full(
+        (input_ids.shape[0], 1), model.config.decoder_start_token_id, dtype=torch.long
+    )
+
+    # Reset token_idx before generating
+    token_idx = 0
+    max_length = model.config.max_position_embeddings  # Set the maximum length for generation
+    generated_tokens = []
+
+    for _ in range(max_length):
+        outputs = model(input_ids=input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids)
+        next_token_logits = outputs.logits[:, -1, :]
+        next_token_id = next_token_logits.argmax(dim=-1, keepdim=True)
+        generated_tokens.append(next_token_id.item())
+        decoder_input_ids = torch.cat([decoder_input_ids, next_token_id], dim=-1)
+        increment_token_idx()
+
+        if next_token_id.item() == model.config.eos_token_id:
+            break
+
+    # Deregister hooks
+    for handle in encoder_hooks:
+        handle.remove()
+
+    # Deregister hooks
+    for handle in decoder_hooks:
+        handle.remove()
+
+    # Merge the encoder_outputs with timestep decoder_outputs to generate timestep wise outputs of the model
+    outputs = {}
+
+    for i in range(len(decoder_outputs)):
+        outputs[f'{i}'] = {**encoder_outputs, **decoder_outputs[f'{i}']}
+
+    return outputs, generated_tokens
+
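The three functions in this new module are written against a T5-style encoder-decoder (SelfAttention / EncDecAttention / DenseReluDense submodules, a shared embedding, and an lm_head). The following is a minimal usage sketch only, assuming a Hugging Face T5 checkpoint ("t5-small" is an illustrative choice, not pinned by this release) whose config exposes decoder_start_token_id, eos_token_id, and max_position_embeddings as the code expects:

# Illustrative sketch; the imported functions come from the hunk above.
from transformers import T5ForConditionalGeneration, T5Tokenizer
from dl_backtrace.pytorch_backtrace.backtrace.utils.encoder_decoder import (
    build_enc_dec_tree,
    extract_encoder_decoder_weights,
    calculate_encoder_decoder_output,
)

model = T5ForConditionalGeneration.from_pretrained("t5-small")
tokenizer = T5Tokenizer.from_pretrained("t5-small")

# Layer graph plus reversed layer stack describing the encoder-decoder topology.
model_resource, layer_stack = build_enc_dec_tree(model)

# Per-component weights as NumPy arrays, keyed by the same node names as the graph.
weights = extract_encoder_decoder_weights(model)

# Greedy decoding with forward hooks that capture per-token intermediate outputs.
outputs, generated_tokens = calculate_encoder_decoder_output(
    "translate English to German: The house is wonderful.", model, tokenizer
)
print(tokenizer.decode(generated_tokens, skip_special_tokens=True))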
dl_backtrace/pytorch_backtrace/backtrace/utils/helper.py

@@ -0,0 +1,95 @@
+import torch
+from collections import defaultdict
+
+
+# Function to rename the dictionary keys
+def rename_self_attention_keys(attention_weights):
+    renamed_weights = {}
+    for key, value in attention_weights.items():
+        if 'query.weight' in key or 'SelfAttention.q.weight' in key:
+            new_key = key.replace(key, 'W_q')
+        elif 'query.bias' in key or 'SelfAttention.q.bias' in key:
+            new_key = key.replace(key, 'b_q')
+        elif 'key.weight' in key or 'SelfAttention.k.weight' in key:
+            new_key = key.replace(key, 'W_k')
+        elif 'key.bias' in key or 'SelfAttention.k.bias' in key:
+            new_key = key.replace(key, 'b_k')
+        elif 'value.weight' in key or 'SelfAttention.v.weight' in key:
+            new_key = key.replace(key, 'W_v')
+        elif 'value.bias' in key or 'SelfAttention.v.bias' in key:
+            new_key = key.replace(key, 'b_v')
+        elif 'output.dense.weight' in key or 'SelfAttention.o.weight' in key:
+            new_key = key.replace(key, 'W_d')
+        elif 'output.dense.bias' in key or 'SelfAttention.o.bias' in key:
+            new_key = key.replace(key, 'b_d')
+
+        renamed_weights[new_key] = value
+    return renamed_weights
+
+
+def rename_cross_attention_keys(cross_attention_weights):
+    renamed_weights = {}
+
+    for key, value in cross_attention_weights.items():
+        if 'EncDecAttention.q.weight' in key:
+            new_key = key.replace(key, 'W_q')
+        elif 'EncDecAttention.k.weight' in key:
+            new_key = key.replace(key, 'W_k')
+        elif 'EncDecAttention.v.weight' in key:
+            new_key = key.replace(key, 'W_v')
+        elif 'EncDecAttention.o.weight' in key:
+            new_key = key.replace(key, 'W_o')
+
+        renamed_weights[new_key] = value
+    return renamed_weights
+
+
+def rename_feed_forward_keys(feed_forward_weights):
+    renamed_weights = {}
+
+    for key, value in feed_forward_weights.items():
+        if 'intermediate.dense.weight' in key or 'DenseReluDense.wi.weight' in key:
+            new_key = key.replace(key, 'W_int')
+        elif 'intermediate.dense.bias' in key or 'DenseReluDense.wi.bias' in key:
+            new_key = key.replace(key, 'b_int')
+        elif 'output.dense.weight' in key or 'DenseReluDense.wo.weight' in key:
+            new_key = key.replace(key, 'W_out')
+        elif 'output.dense.bias' in key or 'DenseReluDense.wo.bias' in key:
+            new_key = key.replace(key, 'b_out')
+
+        renamed_weights[new_key] = value
+    return renamed_weights
+
+
+def rename_pooler_keys(pooler_weights):
+    renamed_weights = {}
+    for key, value in pooler_weights.items():
+        if 'pooler.dense.weight' in key:
+            new_key = key.replace(key, 'W_p')
+        elif 'pooler.dense.bias' in key:
+            new_key = key.replace(key, 'b_p')
+
+        renamed_weights[new_key] = value
+    return renamed_weights
+
+
+def rename_classifier_keys(classifier_weights):
+    renamed_weights = {}
+    for key, value in classifier_weights.items():
+        if 'classifier.weight' in key:
+            new_key = key.replace(key, 'W_cls')
+        elif 'classifier.bias' in key:
+            new_key = key.replace(key, 'b_cls')
+
+        renamed_weights[new_key] = value
+    return renamed_weights
+
+def rename_decoder_lm_head(lm_head_weights):
+    renamed_weights = {}
+
+    for key, value in lm_head_weights.items():
+        if 'shared.weight' in key:
+            new_key = key.replace(key, 'W_lm_head')
+
+        renamed_weights[new_key] = value
+    return renamed_weights
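These helpers map raw Hugging Face parameter names onto the canonical W_*/b_* keys used elsewhere in the package. A brief sketch of how they slot in, reusing the weights dictionary from the encoder_decoder.py sketch above (the model and weights variables are carried over from that hypothetical example):

# Illustrative continuation of the previous sketch: rename one decoder
# cross-attention block's parameters (keys such as
# 'decoder.block.0.layer.1.EncDecAttention.q.weight' for a T5 checkpoint).
from dl_backtrace.pytorch_backtrace.backtrace.utils.helper import rename_cross_attention_keys

cross_attn_0 = weights['decoder_cross_attention_0']
renamed = rename_cross_attention_keys(cross_attn_0)
print(sorted(renamed))  # ['W_k', 'W_o', 'W_q', 'W_v']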