dl-backtrace 0.0.14__py3-none-any.whl → 0.0.16__py3-none-any.whl
This diff shows the content of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.
Potentially problematic release.
This version of dl-backtrace might be problematic.
- dl_backtrace/pytorch_backtrace/backtrace/backtrace.py +173 -44
- dl_backtrace/pytorch_backtrace/backtrace/utils/__init__.py +3 -0
- dl_backtrace/pytorch_backtrace/backtrace/utils/encoder.py +183 -0
- dl_backtrace/pytorch_backtrace/backtrace/utils/encoder_decoder.py +489 -0
- dl_backtrace/pytorch_backtrace/backtrace/utils/helper.py +95 -0
- dl_backtrace/pytorch_backtrace/backtrace/utils/prop.py +481 -0
- dl_backtrace/tf_backtrace/backtrace/__init__.py +1 -2
- dl_backtrace/tf_backtrace/backtrace/activation_info.py +33 -0
- dl_backtrace/tf_backtrace/backtrace/backtrace.py +506 -279
- dl_backtrace/tf_backtrace/backtrace/models.py +25 -0
- dl_backtrace/tf_backtrace/backtrace/server.py +27 -0
- dl_backtrace/tf_backtrace/backtrace/utils/__init__.py +5 -2
- dl_backtrace/tf_backtrace/backtrace/utils/encoder.py +206 -0
- dl_backtrace/tf_backtrace/backtrace/utils/encoder_decoder.py +501 -0
- dl_backtrace/tf_backtrace/backtrace/utils/helper.py +99 -0
- dl_backtrace/tf_backtrace/backtrace/utils/utils_contrast.py +1132 -0
- dl_backtrace/tf_backtrace/backtrace/utils/utils_prop.py +1582 -0
- dl_backtrace/version.py +2 -2
- {dl_backtrace-0.0.14.dist-info → dl_backtrace-0.0.16.dist-info}/METADATA +2 -2
- dl_backtrace-0.0.16.dist-info/RECORD +29 -0
- {dl_backtrace-0.0.14.dist-info → dl_backtrace-0.0.16.dist-info}/WHEEL +1 -1
- dl_backtrace/tf_backtrace/backtrace/config.py +0 -41
- dl_backtrace/tf_backtrace/backtrace/utils/contrast.py +0 -834
- dl_backtrace/tf_backtrace/backtrace/utils/prop.py +0 -725
- dl_backtrace-0.0.14.dist-info/RECORD +0 -21
- {dl_backtrace-0.0.14.dist-info → dl_backtrace-0.0.16.dist-info}/LICENSE +0 -0
- {dl_backtrace-0.0.14.dist-info → dl_backtrace-0.0.16.dist-info}/top_level.txt +0 -0
dl_backtrace/tf_backtrace/backtrace/utils/encoder_decoder.py
@@ -0,0 +1,501 @@
+import tensorflow as tf
+from collections import defaultdict
+
+
+def build_enc_dec_tree(model, root='enc-dec'):
+    # Initialize the tree structure
+    ltree = {}
+    layer_tree = {}
+    inputs = []
+    outputs = []
+    intermediates = []
+    layer_stack = []
+
+    # Base component setup
+    def add_component(tree, name, component, child=None):
+        tree[name] = {
+            'name': name,
+            'class': component if isinstance(component, str) else type(component).__name__,
+            'type': str(type(component)),
+            'parent': None,
+            'child': None
+        }
+
+        if isinstance(child, list):
+            tree[name]['child'] = child
+        elif isinstance(child, str):
+            tree[name]['child'] = [child]
+
+        if tree[name]['class'] == 'list':
+            tree[name]['class'] = [type(item).__name__ for item in component]
+            tree[name]['type'] = [str(type(item)) for item in component]
+
+        # Keep track of component type in a separate dictionary
+        layer_tree[name] = component if isinstance(component, str) else tree[name]['type']
+
+        # keep track of layer stack
+        layer_stack.append(name)
+
+        # Link the parent to its children
+        if isinstance(child, list):
+            for ch in child:
+                if ch in tree:
+                    tree[ch]['parent'] = [name]
+
+        elif isinstance(child, str):
+            if child in tree:
+                tree[child]['parent'] = [name]
+
+        return tree[name]
+
+    # Add root and embeddings component
+    encoder_embeddings = add_component(ltree, 'encoder_embedding', model.encoder.embed_tokens, child=None)
+
+    # Add encoder layers dynamically
+    current_child = 'encoder_embedding'
+    for i, layer in enumerate(model.encoder.block):
+        encoder_layer_norm_0 = add_component(ltree, f'encoder_layer_norm_{i}_0', 'Layer_Norm', child=current_child)
+        encoder_self_attention = add_component(ltree, f'encoder_self_attention_{i}', 'Self_Attention', child=f'encoder_layer_norm_{i}_0')
+        encoder_residual_self_attention = add_component(ltree, f'encoder_residual_self_attention_{i}', 'Residual', child=[current_child, f'encoder_self_attention_{i}'])
+
+        encoder_layer_norm_1 = add_component(ltree, f'encoder_layer_norm_{i}_1', 'Layer_Norm', child=f'encoder_residual_self_attention_{i}')
+        encoder_feed_forward = add_component(ltree, f'encoder_feed_forward_{i}', 'Feed_Forward', child=f'encoder_layer_norm_{i}_1')
+        encoder_residual_feed_forward = add_component(ltree, f'encoder_residual_feed_forward_{i}', 'Residual', child=[f'encoder_residual_self_attention_{i}', f'encoder_feed_forward_{i}'])
+
+        current_child = f'encoder_residual_feed_forward_{i}'
+
+    if hasattr(model.encoder, 'final_layer_norm'):
+        encoder_final_layer_norm = add_component(ltree, 'encoder_layer_norm', model.encoder.final_layer_norm, child=current_child)
+        current_child = 'encoder_layer_norm'
+
+    # Add Decoder layers
+    decoder_embeddings = add_component(ltree, 'decoder_embedding', model.decoder.embed_tokens, child=None)
+
+    # Add decoder layers dynamically
+    current_child = 'decoder_embedding'
+    for i, layer in enumerate(model.decoder.block):
+        decoder_layer_norm_0 = add_component(ltree, f'decoder_layer_norm_{i}_0', 'Layer_Norm', child=current_child)
+        decoder_self_attention = add_component(ltree, f'decoder_self_attention_{i}', 'Self_Attention', child=f'decoder_layer_norm_{i}_0')
+        decoder_residual_self_attention = add_component(ltree, f'decoder_residual_self_attention_{i}', 'Residual', child=[current_child, f'decoder_self_attention_{i}'])
+
+        decoder_layer_norm_1 = add_component(ltree, f'decoder_layer_norm_{i}_1', 'Layer_Norm', child=f'decoder_residual_self_attention_{i}')
+        decoder_cross_attention = add_component(ltree, f'decoder_cross_attention_{i}', 'Cross_Attention', child=['encoder_layer_norm', f'decoder_layer_norm_{i}_1'])
+        decoder_residual_cross_attention = add_component(ltree, f'decoder_residual_cross_attention_{i}', 'Residual', child=[f'decoder_residual_self_attention_{i}', f'decoder_cross_attention_{i}'])
+
+        decoder_layer_norm_2 = add_component(ltree, f'decoder_layer_norm_{i}_2', 'Layer_Norm', child=f'decoder_residual_cross_attention_{i}')
+        decoder_feed_forward = add_component(ltree, f'decoder_feed_forward_{i}', 'Feed_Forward', child=f'decoder_layer_norm_{i}_2')
+        decoder_residual_feed_forward = add_component(ltree, f'decoder_residual_feed_forward_{i}', 'Residual', child=[f'decoder_residual_cross_attention_{i}', f'decoder_feed_forward_{i}'])
+
+        current_child = f'decoder_residual_feed_forward_{i}'
+
+    if hasattr(model.decoder, 'final_layer_norm'):
+        decoder_final_layer_norm = add_component(ltree, 'decoder_layer_norm', model.decoder.final_layer_norm, child=current_child)
+        current_child = 'decoder_layer_norm'
+
+    # Decoder LM-Head
+    if hasattr(model, 'lm_head'):
+        decoder_lm_head = add_component(ltree, 'decoder_lm_head', 'LM_Head', child=current_child)
+        current_child = 'decoder_lm_head'
+    else:
+        if model.config.tie_word_embeddings:
+            decoder_lm_head = add_component(ltree, 'decoder_lm_head', 'LM_Head', child=current_child)
+        else:
+            decoder_lm_head = add_component(ltree, 'decoder_lm_head', None, child=current_child)
+
+    # Classify components
+    for name, component in ltree.items():
+        if component['parent'] is None:
+            outputs.append(component['name'])
+        elif component['child'] is None:
+            inputs.append(component['name'])
+        else:
+            intermediates.append(component['name'])
+
+    # reverse the layer_stack
+    layer_stack = list(reversed(layer_stack))
+
+    model_resource = {
+        "layers": layer_tree,
+        "graph": ltree,
+        "outputs": outputs,
+        "inputs": inputs
+    }
+
+    return model_resource, layer_stack
+
+
+def extract_encoder_decoder_weights(model):
+    # Initialize a dictionary to hold the weights
+    weights_dict = {
+        # 'shared_embeddings': {},
+        'encoder_embedding': {},
+        'encoder_layer_norm': {},
+        'decoder_embedding': {},
+        'decoder_layer_norm': {},
+        'decoder_lm_head': {}
+    }
+
+    # Extract the model's parameters and organize them into the dictionary
+    for weight in model.weights:
+        name = weight.name
+        value = weight.numpy()
+
+        if 'shared' in name:
+            weights_dict['encoder_embedding'][name] = value
+            weights_dict['decoder_embedding'][name] = value
+            weights_dict['decoder_lm_head'][name] = value
+
+        elif 'encoder' in name:
+            if 'block' in name:
+                layer = name.split('/')[2].split('_')[-1]
+                sub_layer = name.split('/')[3].split('_')[-1]
+                submodule = name.split('/')[4]
+
+                if 'SelfAttention' in submodule and f'encoder_self_attention_{layer}' not in weights_dict:
+                    weights_dict[f'encoder_self_attention_{layer}'] = {}
+                if 'layer_norm' in submodule and f'encoder_layer_norm_{layer}' not in weights_dict:
+                    weights_dict[f'encoder_layer_norm_{layer}'] = {}
+                if 'DenseReluDense' in submodule and f'encoder_feed_forward_{layer}' not in weights_dict:
+                    weights_dict[f'encoder_feed_forward_{layer}'] = {}
+
+                if 'SelfAttention' in submodule:
+                    weights_dict[f'encoder_self_attention_{layer}'][name] = value
+                elif 'layer_norm' in submodule:
+                    weights_dict[f'encoder_layer_norm_{layer}'][name] = value
+                elif 'DenseReluDense' in submodule:
+                    weights_dict[f'encoder_feed_forward_{layer}'][name] = value
+
+            elif 'final_layer_norm' in name:
+                weights_dict['encoder_layer_norm'][name] = value
+
+        elif 'decoder/block' in name:
+            layer = name.split('/')[2].split('_')[-1]
+            sub_layer = name.split('/')[3].split('_')[-1]
+            submodule = name.split('/')[4]
+
+            if 'SelfAttention' in submodule and f'decoder_self_attention_{layer}' not in weights_dict:
+                weights_dict[f'decoder_self_attention_{layer}'] = {}
+            if 'layer_norm' in submodule and f'decoder_layer_norm_{layer}' not in weights_dict:
+                weights_dict[f'decoder_layer_norm_{layer}'] = {}
+            if 'EncDecAttention' in submodule and f'decoder_cross_attention_{layer}' not in weights_dict:
+                weights_dict[f'decoder_cross_attention_{layer}'] = {}
+            if 'DenseReluDense' in submodule and f'decoder_feed_forward_{layer}' not in weights_dict:
+                weights_dict[f'decoder_feed_forward_{layer}'] = {}
+
+            if 'SelfAttention' in submodule:
+                weights_dict[f'decoder_self_attention_{layer}'][name] = value
+            elif 'layer_norm' in submodule:
+                weights_dict[f'decoder_layer_norm_{layer}'][name] = value
+            elif 'EncDecAttention' in submodule:
+                weights_dict[f'decoder_cross_attention_{layer}'][name] = value
+            elif 'DenseReluDense' in submodule:
+                weights_dict[f'decoder_feed_forward_{layer}'][name] = value
+
+        elif 'decoder/final_layer_norm' in name:
+            weights_dict['decoder_layer_norm'][name] = value
+
+    return weights_dict
+
+
+def calculate_encoder_decoder_output(input_text, model, tokenizer):
+    # Dictionaries to store the inputs and outputs
+    encoder_inputs = {}
+    encoder_outputs = defaultdict(lambda: defaultdict(dict))
+    decoder_inputs = defaultdict(lambda: defaultdict(dict))
+    decoder_outputs = defaultdict(lambda: defaultdict(dict))
+
+    # Hook manager to store hook information
+    hook_manager = []
+
+    # Global variable to keep track of the token index
+    token_idx = 0
+
+    # Function to generate timestamp (token index)
+    def get_timestamp():
+        return str(token_idx)
+
+    # Function to wrap the call method of the layer to add hooks
+    def wrap_call(layer, hook_fn, layer_index=None):
+        original_call = layer.call
+
+        def hooked_call(*args, **kwargs):
+            outputs = original_call(*args, **kwargs)
+            if layer_index is not None:
+                hook_fn(layer, args, outputs, layer_index)
+            else:
+                hook_fn(layer, args, outputs)
+            return outputs
+
+        layer.call = hooked_call
+        hook_manager.append((layer, original_call))
+
+    def capture_encoder_embeddings(model, tokenizer, input_text):
+        # Tokenize the input text
+        encoding = tokenizer(input_text, return_tensors='tf')
+        input_ids = encoding["input_ids"]
+        attention_mask = encoding["attention_mask"]
+
+        # Manually capture the embedding output
+        embedding_output = model.encoder.embed_tokens(input_ids)
+
+        # Return the captured embeddings
+        return embedding_output
+
+    ## ------------ Hook function for Self-Attention Block --------------------
+    def hook_fn__encoder_normalized_hidden_states(layer, inputs, outputs, layer_index):
+        encoder_inputs[f'encoder_layer_norm_{layer_index}_0'] = inputs
+        encoder_outputs[f'encoder_layer_norm_{layer_index}_0'] = outputs
+        encoder_outputs[f'input_to_layer_norm_{layer_index}'] = inputs[0]
+
+    def hook_fn_encoder_self_attention_outputs(layer, inputs, outputs, layer_index):
+        encoder_inputs[f'encoder_self_attention_{layer_index}'] = inputs
+        encoder_outputs[f'encoder_self_attention_{layer_index}'] = outputs[0]
+
+    def hook_fn_encoder_dropout_attention_output(layer, inputs, outputs, layer_index):
+        encoder_outputs[f'encoder_dropout_attention_{layer_index}'] = outputs[0]
+
+    ## ------------ Hook function for Feed-Forward Block --------------
+    def hook_fn_encoder_normalized_forwarded_states(layer, inputs, outputs, layer_index):
+        encoder_inputs[f'encoder_layer_norm_{layer_index}_1'] = inputs
+        encoder_outputs[f'encoder_layer_norm_{layer_index}_1'] = outputs
+        encoder_outputs[f'input_to_ff_layer_norm_{layer_index}'] = inputs[0]
+
+    def hook_fn_encoder_forwarded_states(layer, inputs, outputs, layer_index):
+        encoder_inputs[f'encoder_feed_forward_{layer_index}'] = inputs
+        encoder_outputs[f'encoder_feed_forward_{layer_index}'] = outputs
+
+    def hook_fn_encoder_dropout_forwarded_states(layer, inputs, outputs, layer_index):
+        encoder_outputs[f'dropout_forwarded_states_{layer_index}'] = outputs
+
+    # Custom hooks to calculate residuals
+    def hook_fn_encoder_residual_self_attention(layer, inputs, outputs, layer_index):
+        input_to_layer_norm = encoder_outputs[f'input_to_layer_norm_{layer_index}']
+        encoder_outputs[f'encoder_residual_self_attention_{layer_index}'] = input_to_layer_norm + outputs
+
+    def hook_fn_encoder_residual_feed_forward(layer, inputs, outputs, layer_index):
+        input_to_ff_layer_norm = encoder_outputs[f'input_to_ff_layer_norm_{layer_index}']
+        encoder_outputs[f'encoder_residual_feed_forward_{layer_index}'] = input_to_ff_layer_norm + outputs
+
+    # Hook for Final Layer normalization and dropout for Encoder
+    def hook_fn_normalized_encoder_output(layer, inputs, outputs):
+        encoder_outputs['encoder_layer_norm'] = outputs
+
+    def hook_fn_dropout_normalized_encoder_output(layer, inputs, outputs):
+        encoder_outputs['dropout_normalized_encoder_output'] = outputs
+
+    # Register hooks to the encoder submodules
+    for i, layer in enumerate(model.encoder.block):
+        layer.layer[0].layer_norm.layer_index = i
+        layer.layer[0].SelfAttention.layer_index = i
+        layer.layer[0].dropout.layer_index = i
+        layer.layer[1].layer_norm.layer_index = i
+        layer.layer[1].DenseReluDense.layer_index = i
+        layer.layer[1].dropout.layer_index = i
+
+        wrap_call(layer.layer[0].layer_norm, hook_fn__encoder_normalized_hidden_states, i)
+        wrap_call(layer.layer[0].SelfAttention, hook_fn_encoder_self_attention_outputs, i)
+        wrap_call(layer.layer[0].dropout, hook_fn_encoder_dropout_attention_output, i)
+        wrap_call(layer.layer[0].dropout, hook_fn_encoder_residual_self_attention, i)
+
+        wrap_call(layer.layer[1].layer_norm, hook_fn_encoder_normalized_forwarded_states, i)
+        wrap_call(layer.layer[1].DenseReluDense, hook_fn_encoder_forwarded_states, i)
+        wrap_call(layer.layer[1].dropout, hook_fn_encoder_dropout_forwarded_states, i)
+        wrap_call(layer.layer[1].dropout, hook_fn_encoder_residual_feed_forward, i)
+
+    wrap_call(model.encoder.final_layer_norm, hook_fn_normalized_encoder_output)
+    wrap_call(model.encoder.dropout, hook_fn_dropout_normalized_encoder_output)
+
+    ############################ Hook for Decoder ################################
+    def hook_fn_decoder_embedding(layer, inputs, outputs):
+        global token_idx
+        timestamp = get_timestamp()
+        decoder_outputs[timestamp]['decoder_embedding'] = outputs
+
+    def hook_fn_decoder_normalized_hidden_states(layer, inputs, outputs, layer_index):
+        global token_idx
+        timestamp = get_timestamp()
+        decoder_inputs[timestamp][f'decoder_layer_norm_{layer_index}_0'] = inputs
+        decoder_outputs[timestamp][f'decoder_layer_norm_{layer_index}_0'] = outputs
+        decoder_outputs[timestamp][f'input_to_layer_norm_{layer_index}'] = inputs[0]
+
+    def hook_fn_decoder_self_attention_outputs(layer, inputs, outputs, layer_index):
+        global token_idx
+        timestamp = get_timestamp()
+        decoder_inputs[timestamp][f'decoder_self_attention_{layer_index}'] = inputs
+        decoder_outputs[timestamp][f'decoder_self_attention_{layer_index}'] = outputs[0]
+
+    def hook_fn_decoder_dropout_attention_output(layer, inputs, outputs, layer_index):
+        global token_idx
+        timestamp = get_timestamp()
+        decoder_outputs[timestamp][f'dropout_attention_output_{layer_index}'] = outputs
+
+    def hook_fn_decoder_normalized_cross_attn_hidden_states(layer, inputs, outputs, layer_index):
+        global token_idx
+        timestamp = get_timestamp()
+        decoder_inputs[timestamp][f'decoder_layer_norm_{layer_index}_1'] = inputs
+        decoder_outputs[timestamp][f'decoder_layer_norm_{layer_index}_1'] = outputs
+        decoder_outputs[timestamp][f'input_to_cross_attn_layer_norm_{layer_index}'] = inputs[0]
+
+    def hook_fn_decoder_cross_attention_outputs(layer, inputs, outputs, layer_index):
+        global token_idx
+        timestamp = get_timestamp()
+        decoder_inputs[timestamp][f'decoder_cross_attention_{layer_index}'] = {'query': inputs[0], 'key': outputs, 'value': outputs}
+        decoder_outputs[timestamp][f'decoder_cross_attention_{layer_index}'] = outputs[0]
+
+    def hook_fn_decoder_dropout_cross_attn_output(layer, inputs, outputs, layer_index):
+        global token_idx
+        timestamp = get_timestamp()
+        decoder_outputs[timestamp][f'dropout_cross_attn_output_{layer_index}'] = outputs
+
+    def hook_fn_decoder_normalized_forwarded_states(layer, inputs, outputs, layer_index):
+        global token_idx
+        timestamp = get_timestamp()
+        decoder_inputs[timestamp][f'decoder_layer_norm_{layer_index}_2'] = inputs
+        decoder_outputs[timestamp][f'decoder_layer_norm_{layer_index}_2'] = outputs
+        decoder_outputs[timestamp][f'input_to_ff_layer_norm_{layer_index}'] = inputs[0]
+
+    def hook_fn_decoder_forwarded_states(layer, inputs, outputs, layer_index):
+        global token_idx
+        timestamp = get_timestamp()
+        decoder_inputs[timestamp][f'decoder_feed_forward_{layer_index}'] = inputs[0]
+        decoder_outputs[timestamp][f'decoder_feed_forward_{layer_index}'] = outputs
+
+    def hook_fn_decoder_dropout_forwarded_states(layer, inputs, outputs, layer_index):
+        global token_idx
+        timestamp = get_timestamp()
+        decoder_outputs[timestamp][f'dropout_forwarded_states_{layer_index}'] = outputs
+
+    def hook_fn_decoder_residual_self_attention(layer, inputs, outputs, layer_index):
+        global token_idx
+        timestamp = get_timestamp()
+        input_to_layer_norm = decoder_outputs[timestamp][f'input_to_layer_norm_{layer_index}']
+        decoder_outputs[timestamp][f'decoder_residual_self_attention_{layer_index}'] = input_to_layer_norm + outputs
+
+    def hook_fn_decoder_residual_cross_attention(layer, inputs, outputs, layer_index):
+        global token_idx
+        timestamp = get_timestamp()
+        input_to_layer_norm = decoder_outputs[timestamp][f'input_to_cross_attn_layer_norm_{layer_index}']
+        decoder_outputs[timestamp][f'decoder_residual_cross_attention_{layer_index}'] = input_to_layer_norm + outputs
+
+    def hook_fn_decoder_residual_feed_forward(layer, inputs, outputs, layer_index):
+        global token_idx
+        timestamp = get_timestamp()
+        input_to_ff_layer_norm = decoder_outputs[timestamp][f'input_to_ff_layer_norm_{layer_index}']
+        decoder_outputs[timestamp][f'decoder_residual_feed_forward_{layer_index}'] = input_to_ff_layer_norm + outputs
+
+    def hook_fn_normalized_decoder_output(layer, inputs, outputs):
+        global token_idx
+        timestamp = get_timestamp()
+        decoder_outputs[timestamp]['decoder_layer_norm'] = outputs
+
+    def hook_fn_dropout_normalized_decoder_output(layer, inputs, outputs):
+        global token_idx
+        timestamp = get_timestamp()
+        decoder_outputs[timestamp]['dropout_normalized_decoder_output'] = outputs
+
+    def hook_fn_final_logits(layer, inputs, outputs):
+        global token_idx
+        timestamp = get_timestamp()
+        decoder_outputs[timestamp]['decoder_lm_head'] = outputs.logits
+
+    # Register hooks to the decoder submodules
+    wrap_call(model.decoder.embed_tokens, hook_fn_decoder_embedding)
+
+    for i, layer in enumerate(model.decoder.block):
+        layer.layer[0].layer_norm.layer_index = i
+        layer.layer[0].SelfAttention.layer_index = i
+        layer.layer[0].dropout.layer_index = i
+        layer.layer[1].layer_norm.layer_index = i
+        layer.layer[1].EncDecAttention.layer_index = i
+        layer.layer[1].dropout.layer_index = i
+        layer.layer[2].layer_norm.layer_index = i
+        layer.layer[2].DenseReluDense.layer_index = i
+        layer.layer[2].dropout.layer_index = i
+
+        wrap_call(layer.layer[0].layer_norm, hook_fn_decoder_normalized_hidden_states, i)
+        wrap_call(layer.layer[0].SelfAttention, hook_fn_decoder_self_attention_outputs, i)
+        wrap_call(layer.layer[0].dropout, hook_fn_decoder_dropout_attention_output, i)
+        wrap_call(layer.layer[0].dropout, hook_fn_decoder_residual_self_attention, i)
+
+        wrap_call(layer.layer[1].layer_norm, hook_fn_decoder_normalized_cross_attn_hidden_states, i)
+        wrap_call(layer.layer[1].EncDecAttention, hook_fn_decoder_cross_attention_outputs, i)
+        wrap_call(layer.layer[1].dropout, hook_fn_decoder_dropout_cross_attn_output, i)
+        wrap_call(layer.layer[1].dropout, hook_fn_decoder_residual_cross_attention, i)
+
+        wrap_call(layer.layer[2].layer_norm, hook_fn_decoder_normalized_forwarded_states, i)
+        wrap_call(layer.layer[2].DenseReluDense, hook_fn_decoder_forwarded_states, i)
+        wrap_call(layer.layer[2].dropout, hook_fn_decoder_dropout_forwarded_states, i)
+        wrap_call(layer.layer[2].dropout, hook_fn_decoder_residual_feed_forward, i)
+
+    wrap_call(model.decoder.final_layer_norm, hook_fn_normalized_decoder_output)
+    wrap_call(model.decoder.dropout, hook_fn_dropout_normalized_decoder_output)
+
+    # Register hook for the final logits by wrapping the call method of the model itself
+    original_call = model.call
+    def hooked_call(*args, **kwargs):
+        outputs = original_call(*args, **kwargs)
+        hook_fn_final_logits(model, args, outputs)
+        return outputs
+
+    model.call = hooked_call
+    hook_manager.append((model, original_call))
+
+    # Function to remove hooks
+    def remove_hooks():
+        for layer, original_call in hook_manager:
+            layer.call = original_call
+        hook_manager.clear()
+
+    # Function to get shape
+    def get_shape(value):
+        if isinstance(value, tf.Tensor):
+            return value.shape
+        elif isinstance(value, tuple):
+            return [get_shape(v) for v in value if v is not None]
+        elif isinstance(value, list):
+            return [get_shape(v) for v in value if v is not None]
+        elif isinstance(value, dict):
+            return {k: get_shape(v) for k, v in value.items()}
+        else:
+            return None
+
+    # Function to increment token index
+    def increment_token_idx():
+        global token_idx
+        token_idx += 1
+
+    encoding = tokenizer(input_text, return_tensors='tf')
+    input_ids = encoding["input_ids"]
+    attention_mask = encoding["attention_mask"]
+
+    embedding_output = capture_encoder_embeddings(model, tokenizer, input_text)
+    encoder_outputs['encoder_embedding'] = embedding_output
+
+    # Initialize decoder_input_ids with the start token
+    decoder_start_token_id = model.config.decoder_start_token_id
+    decoder_input_ids = tf.fill((tf.shape(input_ids)[0], 1), decoder_start_token_id)
+
+    # Reset token_idx before generating
+    token_idx = 0
+    max_length = model.config.n_positions if hasattr(model.config, 'n_positions') else model.config.d_model
+    generated_tokens = []
+
+    for _ in range(max_length):
+        outputs = model(input_ids=input_ids, attention_mask=attention_mask, decoder_input_ids=decoder_input_ids)
+        next_token_logits = outputs.logits[:, -1, :]
+        next_token_id = tf.argmax(next_token_logits, axis=-1, output_type=tf.int32)
+        next_token_id = tf.expand_dims(next_token_id, axis=-1)
+        generated_tokens.append(next_token_id.numpy().item())
+        decoder_input_ids = tf.concat([decoder_input_ids, next_token_id], axis=-1)
+        increment_token_idx()

+        if next_token_id.numpy().item() == model.config.eos_token_id:
+            break
+
+    # Merge the encoder_outputs with timestep decoder_outputs to generate timestep wise outputs of the model
+    outputs = {}
+
+    for i in range(len(decoder_outputs)):
+        outputs[f'{i}'] = {**encoder_outputs, **decoder_outputs[f'{i}']}
+
+    return outputs, generated_tokens
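The three utilities above build a traversal graph of an encoder-decoder model, collect its weights per block, and record every intermediate activation during greedy decoding. A minimal usage sketch follows; it assumes the module is importable as dl_backtrace.tf_backtrace.backtrace.utils.encoder_decoder (the path shown in the file list above) and that the intended target is a Hugging Face TFT5ForConditionalGeneration model, an inference drawn from the T5-style attribute names (block, SelfAttention, EncDecAttention, DenseReluDense) used in the code; the checkpoint name is purely illustrative.

# Usage sketch (assumptions: module path and T5 target model inferred from this diff,
# not confirmed by the package documentation).
from transformers import AutoTokenizer, TFT5ForConditionalGeneration

from dl_backtrace.tf_backtrace.backtrace.utils.encoder_decoder import (
    build_enc_dec_tree,
    extract_encoder_decoder_weights,
    calculate_encoder_decoder_output,
)

tokenizer = AutoTokenizer.from_pretrained("t5-small")
model = TFT5ForConditionalGeneration.from_pretrained("t5-small")

# Build the layer graph and stack, collect per-block weights, and run a hooked
# greedy decode that records the intermediate activations for each generated token.
model_resource, layer_stack = build_enc_dec_tree(model)
weights = extract_encoder_decoder_weights(model)
outputs, generated_tokens = calculate_encoder_decoder_output(
    "translate English to German: The house is wonderful.", model, tokenizer
)
print(tokenizer.decode(generated_tokens, skip_special_tokens=True))

Per the merge loop at the end of calculate_encoder_decoder_output, the returned outputs dict is keyed by generated-token index ('0', '1', ...), each entry combining the encoder activations with that step's decoder activations, while generated_tokens holds the greedily decoded ids.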
dl_backtrace/tf_backtrace/backtrace/utils/helper.py
@@ -0,0 +1,99 @@
+import tensorflow as tf
+
+
+# Function to rename the dictionary keys
+def rename_self_attention_keys(attention_weights):
+    renamed_weights = {}
+    for key, value in attention_weights.items():
+        if 'query/kernel' in key or 'SelfAttention/q' in key:
+            new_key = key.replace(key, 'W_q')
+        elif 'query/bias' in key:
+            new_key = key.replace(key, 'b_q')
+        elif 'key/kernel' in key or 'SelfAttention/k' in key:
+            new_key = key.replace(key, 'W_k')
+        elif 'key/bias' in key:
+            new_key = key.replace(key, 'b_k')
+        elif 'value/kernel' in key or 'SelfAttention/v' in key:
+            new_key = key.replace(key, 'W_v')
+        elif 'value/bias' in key:
+            new_key = key.replace(key, 'b_v')
+        elif 'output/dense/kernel' in key or 'SelfAttention/o' in key:
+            new_key = key.replace(key, 'W_d')
+        elif 'output/dense/bias' in key:
+            new_key = key.replace(key, 'b_d')
+        elif 'SelfAttention/relative_attention_bias' in key:
+            new_key = key.replace(key, 'relative_attn_bias')
+
+        renamed_weights[new_key] = value
+    return renamed_weights
+
+
+def rename_cross_attention_keys(cross_attention_weights):
+    renamed_weights = {}
+
+    for key, value in cross_attention_weights.items():
+        if 'EncDecAttention/q' in key:
+            new_key = key.replace(key, 'W_q')
+        elif 'EncDecAttention/k' in key:
+            new_key = key.replace(key, 'W_k')
+        elif 'EncDecAttention/v' in key:
+            new_key = key.replace(key, 'W_v')
+        elif 'EncDecAttention/o' in key:
+            new_key = key.replace(key, 'W_o')
+
+        renamed_weights[new_key] = value
+    return renamed_weights
+
+
+def rename_feed_forward_keys(feed_forward_weights):
+    renamed_weights = {}
+
+    for key, value in feed_forward_weights.items():
+        if 'intermediate/dense/kernel' in key or 'DenseReluDense/wi' in key:
+            new_key = key.replace(key, 'W_int')
+        elif 'intermediate/dense/bias' in key or 'DenseReluDense/bi' in key:
+            new_key = key.replace(key, 'b_int')
+        elif 'output/dense/kernel' in key or 'DenseReluDense/wo' in key:
+            new_key = key.replace(key, 'W_out')
+        elif 'output/dense/bias' in key or 'DenseReluDense/bo' in key:
+            new_key = key.replace(key, 'b_out')
+
+        renamed_weights[new_key] = value
+    return renamed_weights
+
+
+def rename_pooler_keys(pooler_weights):
+    renamed_weights = {}
+
+    for key, value in pooler_weights.items():
+        if 'pooler/dense/kernel' in key:
+            new_key = key.replace(key, 'W_p')
+        elif 'pooler/dense/bias' in key:
+            new_key = key.replace(key, 'b_p')
+
+        renamed_weights[new_key] = value
+    return renamed_weights
+
+
+def rename_classifier_keys(classifier_weights):
+    renamed_weights = {}
+
+    for key, value in classifier_weights.items():
+        if 'classifier/kernel' in key:
+            new_key = key.replace(key, 'W_cls')
+        elif 'classifier/bias' in key:
+            new_key = key.replace(key, 'b_cls')
+
+        renamed_weights[new_key] = value
+    return renamed_weights
+
+
+def rename_decoder_lm_head(lm_head_weights):
+    renamed_weights = {}
+
+    for key, value in lm_head_weights.items():
+        if 'shared/shared/embeddings' in key:
+            new_key = key.replace(key, 'W_lm_head')
+
+        renamed_weights[new_key] = value
+    return renamed_weights
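These helpers collapse full framework weight-name paths into the short canonical keys (W_q, W_k, W_v, W_d, W_int, W_out, and so on) used elsewhere in the package. A small illustrative sketch, assuming the module is importable as dl_backtrace.tf_backtrace.backtrace.utils.helper (per the file list above); the weight names below are hypothetical T5-style key paths invented for the example.

# Illustrative only: hypothetical weight-name keys, dummy zero-valued arrays.
import numpy as np

from dl_backtrace.tf_backtrace.backtrace.utils.helper import (
    rename_self_attention_keys,
    rename_feed_forward_keys,
)

self_attention_weights = {
    "encoder/block_._0/layer_._0/SelfAttention/q/kernel:0": np.zeros((512, 512)),
    "encoder/block_._0/layer_._0/SelfAttention/k/kernel:0": np.zeros((512, 512)),
    "encoder/block_._0/layer_._0/SelfAttention/v/kernel:0": np.zeros((512, 512)),
    "encoder/block_._0/layer_._0/SelfAttention/o/kernel:0": np.zeros((512, 512)),
}
print(sorted(rename_self_attention_keys(self_attention_weights)))
# ['W_d', 'W_k', 'W_q', 'W_v']

feed_forward_weights = {
    "encoder/block_._0/layer_._1/DenseReluDense/wi/kernel:0": np.zeros((512, 2048)),
    "encoder/block_._0/layer_._1/DenseReluDense/wo/kernel:0": np.zeros((2048, 512)),
}
print(sorted(rename_feed_forward_keys(feed_forward_weights)))
# ['W_int', 'W_out']

Note that a key matching none of the branches would leave new_key unset or stale, so these helpers appear to expect dictionaries already filtered down to a single block, such as those produced by extract_encoder_decoder_weights above.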