psaiops 0.0.13__py3-none-any.whl → 0.4.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- psaiops/combine/app.py +366 -0
- psaiops/{elements → common}/data.py +8 -0
- psaiops/common/model.py +45 -0
- psaiops/common/tokenizer.py +41 -0
- psaiops/compose/contrast/app.py +195 -0
- psaiops/compose/contrast/lib.py +143 -0
- psaiops/compose/maths/app.py +323 -0
- psaiops/compose/maths/lib.py +1 -0
- psaiops/reverse/__init__.py +0 -0
- psaiops/score/attention/app.py +106 -72
- psaiops/score/attention/lib.py +9 -84
- psaiops/score/residual/__init__.py +0 -0
- psaiops/score/residual/app.py +290 -0
- psaiops/score/residual/lib.py +134 -0
- psaiops/score/router/__init__.py +0 -0
- psaiops/score/router/app.py +281 -0
- psaiops/score/router/lib.py +59 -0
- psaiops/score/shapley/__init__.py +0 -0
- psaiops/score/shapley/app.py +158 -0
- psaiops/score/shapley/lib.py +1 -0
- psaiops/score/similarity/__init__.py +0 -0
- psaiops/score/similarity/app.py +152 -0
- psaiops/score/similarity/lib.py +1 -0
- {psaiops-0.0.13.dist-info → psaiops-0.4.0.dist-info}/METADATA +14 -16
- psaiops-0.4.0.dist-info/RECORD +36 -0
- {psaiops-0.0.13.dist-info → psaiops-0.4.0.dist-info}/WHEEL +1 -1
- psaiops-0.4.0.dist-info/licenses/.github/LICENSE.md +661 -0
- psaiops-0.0.13.dist-info/RECORD +0 -15
- /psaiops/{elements → common}/__init__.py +0 -0
- /psaiops/{steer → compose/maths}/__init__.py +0 -0
|
@@ -0,0 +1,290 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
|
|
3
|
+
import gradio
|
|
4
|
+
import torch
|
|
5
|
+
import torch.cuda
|
|
6
|
+
import matplotlib.pyplot
|
|
7
|
+
|
|
8
|
+
import psaiops.common.model
|
|
9
|
+
import psaiops.common.tokenizer
|
|
10
|
+
import psaiops.score.residual.lib
|
|
11
|
+
|
|
12
|
+
# META #########################################################################
|
|
13
|
+
|
|
14
|
+
# CSS injected in the page: the output block opts in via elem_classes='white-text' to get a white font
STYLE = '''.white-text span { color: white; }'''
# title of the browser tab
TITLE = '''Residual Scoring'''
# markdown blurb shown at the top of the page
INTRO = '''Plot the hidden states of the model (residual stream) for a given prompt.\nUnder construction, only "openai/gpt-oss-20b" is available for now.'''

# checkpoint loaded by default in create_app
MODEL = 'openai/gpt-oss-20b'
|
|
19
|
+
|
|
20
|
+
# COLORS #######################################################################
|
|
21
|
+
|
|
22
|
+
def create_color_map() -> dict:
    """Return the background color of each highlight class.

    Class '0' is used for the regular tokens and class '1' for the focused token(s).
    """
    __palette = dict()
    __palette['0'] = '#000000'  # black
    __palette['1'] = '#004444'  # dark teal
    return __palette
|
|
26
|
+
|
|
27
|
+
# INTRO ########################################################################
|
|
28
|
+
|
|
29
|
+
def create_intro_block(intro: str) -> dict:
    """Build the markdown header of the page, registered under 'intro_block'."""
    __header = gradio.Markdown(intro, line_breaks=True)
    return dict(intro_block=__header)
|
|
32
|
+
|
|
33
|
+
# MODEL ########################################################################
|
|
34
|
+
|
|
35
|
+
def create_model_block() -> dict:
    """Build the model selector, registered under 'model_block'.

    Only 'openai/gpt-oss-20b' is offered; 'openai/gpt-oss-120b' is a candidate for later.
    NOTE(review): create_app never reads this dropdown, the model is fixed at startup.
    """
    __checkpoint = 'openai/gpt-oss-20b'
    __dropdown = gradio.Dropdown(
        label='Model',
        value=__checkpoint,
        choices=[__checkpoint],
        scale=1,
        allow_custom_value=False,
        multiselect=False,
        interactive=True)
    return dict(model_block=__dropdown)
|
|
38
|
+
|
|
39
|
+
# SAMPLING #####################################################################
|
|
40
|
+
|
|
41
|
+
def create_sampling_block() -> dict:
    """Build the generation settings: number of new tokens, top-k and top-p.

    The slider ranges match the sanitization done in update_computation_state.
    """
    # the dict literal instantiates the sliders in display order
    return {
        'tokens_block': gradio.Slider(minimum=1, maximum=128, step=1, value=16, label='Tokens', scale=1, interactive=True),
        'topk_block': gradio.Slider(minimum=1, maximum=8, step=1, value=4, label='Top K', scale=1, interactive=True),
        'topp_block': gradio.Slider(minimum=0.0, maximum=1.0, step=0.1, value=0.9, label='Top P', scale=1, interactive=True),
    }
|
|
49
|
+
|
|
50
|
+
# INPUTS #######################################################################
|
|
51
|
+
|
|
52
|
+
def create_inputs_block() -> dict:
    """Build the multi-line prompt textbox, registered under 'input_block'."""
    __prompt = gradio.Textbox(
        label='Prompt',
        value='',
        placeholder='A string of tokens to score.',
        lines=4,
        scale=1,
        show_copy_button=True,
        interactive=True)
    return dict(input_block=__prompt)
|
|
55
|
+
|
|
56
|
+
# PLOTS ########################################################################
|
|
57
|
+
|
|
58
|
+
def create_plot_block() -> dict:
    """Build the plot area that shows the 3D view of the hidden states.

    The label used to read 'Router' (copied from the router app); this app plots residual activations.
    """
    __plot = gradio.Plot(label='Residual', scale=1)
    return {'plot_block': __plot}
|
|
61
|
+
|
|
62
|
+
# OUTPUTS ######################################################################
|
|
63
|
+
|
|
64
|
+
def create_outputs_block() -> dict:
    """Build the read-only view of the tokens, each tagged with its highlight class."""
    __tokens = gradio.HighlightedText(
        label='Output',
        value='',
        scale=1,
        interactive=False,
        show_legend=False,
        show_inline_category=False,
        combine_adjacent=False,
        # one background color per class: '0' regular, '1' focused
        color_map=create_color_map(),
        # white font, defined in STYLE
        elem_classes='white-text')
    return dict(output_block=__tokens)
|
|
67
|
+
|
|
68
|
+
# SELECT #######################################################################
|
|
69
|
+
|
|
70
|
+
def create_selection_block() -> dict:
    """Build the slider that selects the inspected token position.

    -1 averages over all the tokens. The initial maximum (15) is the default token count (16)
    minus one; update_position_range adjusts it at runtime.
    """
    __slider = gradio.Slider(
        label='Token',
        value=-1,
        minimum=-1,
        maximum=15,
        step=1,
        scale=1,
        interactive=True)
    return dict(position_block=__slider)
|
|
74
|
+
|
|
75
|
+
# ACTIONS ######################################################################
|
|
76
|
+
|
|
77
|
+
def create_actions_block() -> dict:
    """Build the button that triggers the generation, registered under 'process_block'."""
    return dict(process_block=gradio.Button('Process', variant='primary', size='lg', scale=1, interactive=True))
|
|
80
|
+
|
|
81
|
+
# STATE ########################################################################
|
|
82
|
+
|
|
83
|
+
def create_state() -> dict:
    """Create the session states for the generated token ids and the merged hidden states."""
    return {__key: gradio.State(None) for __key in ('output_state', 'hidden_state')}
|
|
87
|
+
|
|
88
|
+
# LAYOUT #######################################################################
|
|
89
|
+
|
|
90
|
+
def create_layout(intro: str=INTRO) -> dict:
    """Build the page and collect every component in one flat dict.

    Returns:
        A mapping from block names ('input_block', 'plot_block', ...) to the gradio components,
        plus the two tabs under 'main_tab' and 'settings_tab'.
    """
    __layout = {}
    __layout.update(create_intro_block(intro=intro))
    with gradio.Tabs():
        # main tab: one full-width row per block, top to bottom
        with gradio.Tab('Score Tokens') as __tab:
            __layout['main_tab'] = __tab
            for __build in (create_inputs_block, create_plot_block, create_outputs_block, create_selection_block, create_actions_block):
                with gradio.Row(equal_height=True):
                    __layout.update(__build())
        # settings tab: model and sampling parameters stacked in a single column
        with gradio.Tab('Settings') as __tab:
            __layout['settings_tab'] = __tab
            with gradio.Column(scale=1):
                for __build in (create_model_block, create_sampling_block):
                    with gradio.Row(equal_height=True):
                        __layout.update(__build())
    return __layout
|
|
114
|
+
|
|
115
|
+
# EVENTS #######################################################################
|
|
116
|
+
|
|
117
|
+
def update_position_range(
    current_val: float,
    token_num: float,
    output_data: torch.Tensor,
) -> dict:
    """Resize the token slider so that it only spans valid positions.

    Args:
        current_val: current slider value, -1 meaning "average over all the tokens".
        token_num: requested number of new tokens, used before anything was generated.
        output_data: generated token ids, (1, T) per update_computation_state, or None.

    Returns:
        A gradio update with the new maximum and the value clamped to it.
    """
    if output_data is None:
        # nothing generated yet: index the requested number of tokens
        __max = int(token_num) - 1
    else:
        # positions are 0-based so the last one is T - 1 (was T, an index past the end);
        # an empty output (missing prompt) keeps the maximum at 0
        __max = max(0, int(output_data.shape[-1]) - 1)
    # keep the previous value when it is still in range
    __val = min(int(current_val), __max)
    return gradio.update(maximum=__max, value=__val)
|
|
128
|
+
|
|
129
|
+
def update_computation_state(
    token_num: float,
    topk_num: float,
    topp_num: float,
    token_idx: float,
    prompt_str: str,
    device_str: str,
    model_obj: object,
    tokenizer_obj: object,
) -> tuple:
    """Generate a completion for the prompt and collect the hidden states of every layer.

    `token_idx` is wired as an input by create_app but is not needed for the computation.

    Returns:
        (output_data, hidden_data) on the CPU: the token ids of prompt + completion, kept as
        integers so that they can be decoded, and the merged hidden states as float32.
        Both are empty tensors when the prompt, the model or the tokenizer is missing.
    """
    # sanitize the inputs, with the same ranges as the sliders
    __token_num = max(1, min(128, int(token_num)))
    __topk_num = max(1, min(8, int(topk_num)))
    __topp_num = max(0.0, min(1.0, float(topp_num)))
    __prompt_str = (prompt_str or '').strip()
    __device_str = device_str if (device_str in ['cpu', 'cuda']) else 'cpu'
    # exit if some values are missing
    if (not __prompt_str) or (model_obj is None) or (tokenizer_obj is None):
        return (torch.empty(0), torch.empty(0))
    # dictionary {'input_ids': _, 'attention_mask': _}
    __input_data = psaiops.common.tokenizer.preprocess_token_ids(
        tokenizer_obj=tokenizer_obj,
        prompt_str=__prompt_str,
        device_str=__device_str)
    # tensor (1, T) and O * L * (1, I, H)
    __output_data, __hidden_data = psaiops.score.residual.lib.generate_token_ids(
        model_obj=model_obj,
        input_args=__input_data,
        token_num=__token_num,
        topk_num=__topk_num,
        topp_num=__topp_num)
    # tensor (1, L, T', H), T' being the number of positions that produced hidden states
    __hidden_data = psaiops.score.residual.lib.merge_hidden_states(
        hidden_data=__hidden_data)
    # the ids were cast to float before, which the detokenizer is not expected to handle
    return (
        __output_data.cpu().long(),
        __hidden_data.cpu().float(),)
|
|
168
|
+
|
|
169
|
+
def update_hidden_plot(
    token_idx: float,
    hidden_data: torch.Tensor,
) -> object:
    """Render the hidden states at the selected position as a 3D voxel plot.

    Args:
        token_idx: inspected position, -1 to average over all the tokens.
        hidden_data: merged hidden states (B, L, T, E) from update_computation_state.

    Returns:
        A matplotlib figure for the first sample, or None when there is no data.
    """
    # exit if some values are missing
    if (hidden_data is None) or (len(hidden_data) == 0):
        return None
    # reduce the token axis (B, L, T, E) => (B, L, E)
    __data = psaiops.score.residual.lib.reduce_hidden_states(
        hidden_data=hidden_data,
        token_idx=int(token_idx),)
    # rescale the data to [-1; 1] (B, L, E)
    __data = psaiops.score.residual.lib.rescale_hidden_states(
        hidden_data=__data)
    # keep the 128 largest activations of each layer; computed on (B, L, E) as the helper expects,
    # not on the RGBA colors (numpy, where the last axis is the 4 channels)
    __mask = psaiops.score.residual.lib.mask_hidden_states(
        hidden_data=__data,
        topk_num=128)
    # fold E into a grid the same way for the data and the mask => (B, W, H, L)
    __data = psaiops.score.residual.lib.reshape_hidden_states(
        hidden_data=__data)
    __mask = psaiops.score.residual.lib.reshape_hidden_states(
        hidden_data=__mask)
    # map the [-1; 1] activations to RGBA colors => (B, W, H, L, 4)
    __colors = psaiops.score.residual.lib.color_hidden_states(
        hidden_data=__data.numpy())
    # the 3D projection goes through subplot_kw: subplots(111, projection=...) raised a TypeError
    __figure, __axes = matplotlib.pyplot.subplots(subplot_kw={'projection': '3d'})
    # plot the first sample
    __axes.voxels(filled=__mask[0].numpy(), facecolors=__colors[0], edgecolor=None)
    __figure.tight_layout()
    # remove the figure from the pyplot register for garbage collection
    matplotlib.pyplot.close(__figure)
    return __figure
|
|
201
|
+
|
|
202
|
+
def update_text_highlight(
    token_idx: float,
    output_data: torch.Tensor,
    tokenizer_obj: object,
) -> list:
    """Pair every token of the output with its highlight class.

    Returns:
        A list of (token string, class) tuples for the HighlightedText widget,
        or None when no output is available yet.
    """
    __missing = (output_data is None) or (len(output_data) == 0)
    if __missing:
        return None
    # ids => strings
    __tokens = psaiops.common.tokenizer.postprocess_token_ids(
        tokenizer_obj=tokenizer_obj,
        token_data=output_data)
    # '1' on the focused token(s), '0' elsewhere
    __classes = psaiops.score.residual.lib.postprocess_token_cls(
        token_idx=int(token_idx),
        token_dim=len(__tokens))
    return [(__token, __cls) for __token, __cls in zip(__tokens, __classes)]
|
|
220
|
+
|
|
221
|
+
# APP ##########################################################################
|
|
222
|
+
|
|
223
|
+
def create_app(title: str=TITLE, intro: str=INTRO, style: str=STYLE, model: str=MODEL) -> gradio.Blocks:
    """Assemble the app: load the model, build the UI and wire the events.

    Args:
        title: browser tab title.
        intro: markdown shown at the top of the page.
        style: CSS injected in the page.
        model: checkpoint name given to the model and tokenizer loaders.

    Returns:
        The gradio Blocks application, ready to be launched.
    """
    __fields = {}
    with gradio.Blocks(theme=gradio.themes.Soft(), title=title, css=style) as __app:
        # run on the GPU when one is visible
        __device = 'cuda' if torch.cuda.is_available() else 'cpu'
        __model = psaiops.common.model.get_model(name=model, device=__device)
        __tokenizer = psaiops.common.tokenizer.get_tokenizer(name=model, device=__device)
        # bind the heavy objects once, the handlers then only receive UI values
        __compute = functools.partial(update_computation_state, model_obj=__model, tokenizer_obj=__tokenizer, device_str=__device)
        __highlight = functools.partial(update_text_highlight, tokenizer_obj=__tokenizer)
        # widgets first, then the per-session states
        __fields.update(create_layout(intro=intro))
        __fields.update(create_state())

        def __pick(*keys: str) -> list:
            # a fresh list of components for each event registration
            return [__fields[__key] for __key in keys]

        # process: generate => resize the slider => highlight the tokens => plot the hidden states
        __fields['process_block'].click(
            fn=__compute,
            inputs=__pick('tokens_block', 'topk_block', 'topp_block', 'position_block', 'input_block'),
            outputs=__pick('output_state', 'hidden_state'),
            queue=False,
            show_progress='full',
        ).then(
            fn=update_position_range,
            inputs=__pick('position_block', 'tokens_block', 'output_state'),
            outputs=__fields['position_block'],
            queue=False,
            show_progress='hidden',
        ).then(
            fn=__highlight,
            inputs=__pick('position_block', 'output_state'),
            outputs=__fields['output_block'],
            queue=False,
            show_progress='full',
        ).then(
            fn=update_hidden_plot,
            inputs=__pick('position_block', 'hidden_state'),
            outputs=__fields['plot_block'],
            queue=False,
            show_progress='full')
        # a new token count changes the valid range of the position slider
        __fields['tokens_block'].change(
            fn=update_position_range,
            inputs=__pick('position_block', 'tokens_block', 'output_state'),
            outputs=__fields['position_block'],
            queue=False,
            show_progress='hidden')
        # moving the focus redraws the plot...
        __fields['position_block'].change(
            fn=update_hidden_plot,
            inputs=__pick('position_block', 'hidden_state'),
            outputs=__fields['plot_block'],
            queue=False,
            show_progress='full')
        # ...and moves the token highlight
        __fields['position_block'].change(
            fn=__highlight,
            inputs=__pick('position_block', 'output_state'),
            outputs=__fields['output_block'],
            queue=False,
            show_progress='hidden')
    return __app
|
|
285
|
+
|
|
286
|
+
# MAIN #########################################################################
|
|
287
|
+
|
|
288
|
+
if __name__ == '__main__':
    # build the UI and serve it with a public share link and debug output
    create_app().launch(share=True, debug=True)
|
|
@@ -0,0 +1,134 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
import math
|
|
3
|
+
|
|
4
|
+
import matplotlib
|
|
5
|
+
import numpy
|
|
6
|
+
import torch
|
|
7
|
+
|
|
8
|
+
# GENERATE #######################################################################
|
|
9
|
+
|
|
10
|
+
def generate_token_ids(
    model_obj: object,
    input_args: dict,
    token_num: int,
    topk_num: int = 4,
    topp_num: float = 0.9,
) -> tuple:
    """Generate a completion and collect the hidden states at every generation step.

    Not memoized: `input_args` is a dict, which `functools.lru_cache` cannot hash (the
    decorator raised TypeError on every call), and caching would keep the large
    hidden-state tensors alive anyway.

    Args:
        model_obj: model exposing a `generate` method.
        input_args: keyword arguments for `generate`, e.g. {'input_ids': _, 'attention_mask': _}.
        token_num: maximum number of new tokens.
        topk_num: top-k filter, disabled when <= 0.
        topp_num: top-p filter, only used when strictly between 0 and 1.

    Returns:
        (sequences, hidden_states) as returned by `generate`: ((B, T), O * L * (B, I, H)).
    """
    # sample whenever a top-k or a non trivial top-p filter is active, otherwise decode greedily
    __sample = (0.0 < topp_num < 1.0) or (topk_num > 0)
    with torch.no_grad():
        __outputs = model_obj.generate(
            **input_args,
            max_new_tokens=token_num,
            do_sample=__sample,
            top_k=topk_num if (topk_num > 0) else None,
            top_p=topp_num if (0.0 < topp_num < 1.0) else None,
            return_dict_in_generate=True,
            output_hidden_states=True,
            output_attentions=False,
            output_scores=False,
            early_stopping=True,
            use_cache=True)
    return __outputs.sequences, __outputs.hidden_states
|
|
34
|
+
|
|
35
|
+
# MERGE ########################################################################
|
|
36
|
+
|
|
37
|
+
def merge_hidden_states(
    hidden_data: torch.Tensor,
) -> torch.Tensor:
    """Merge the per-step hidden states of `generate` into a single tensor.

    Args:
        hidden_data: one entry per generation step, each holding one (B, t, H) tensor per layer.

    Returns:
        The tensor (B, L, T, H), the tokens of all the steps being concatenated layer by layer.
    """
    __steps = range(len(hidden_data))
    __layers = range(len(hidden_data[0]))
    # for each layer, join the chunks of every step along the token axis => (B, T, H)
    __per_layer = [torch.cat([hidden_data[__s][__l] for __s in __steps], dim=1) for __l in __layers]
    # insert the layer axis after the batch axis
    return torch.stack(__per_layer, dim=1)
|
|
50
|
+
|
|
51
|
+
# REDUCE #######################################################################
|
|
52
|
+
|
|
53
|
+
def reduce_hidden_states(
    hidden_data: torch.Tensor,
    token_idx: int, # -1 => avg over all tokens
) -> torch.Tensor:
    """Collapse the token axis of the hidden states (B, L, T, H) => (B, L, H).

    A negative `token_idx` averages over every token; an index past the end
    selects the last token.
    """
    _, _, __count, _ = tuple(hidden_data.shape)
    __idx = min(token_idx, __count - 1)
    if __idx < 0:
        __selected = hidden_data
    else:
        # keep the axis (length 1) so that both branches reduce the same way
        __selected = hidden_data[:, :, __idx:__idx + 1, :]
    return __selected.mean(dim=2)
|
|
66
|
+
|
|
67
|
+
# RESCALE ######################################################################
|
|
68
|
+
|
|
69
|
+
def rescale_hidden_states(
    hidden_data: torch.Tensor, # (B, L, H)
) -> torch.Tensor:
    """Squash the activations of each layer into roughly [-1; 1].

    Each row is divided by its 90th percentile of magnitude. The result goes through
    asinh (linear near 0, logarithmic for large values), is clipped to [-3; 3], and is
    scaled by 0.33.
    """
    __scale = hidden_data.abs().quantile(0.9, dim=-1, keepdim=True)
    __eps = torch.finfo().eps
    __squashed = torch.asinh(hidden_data / (__scale + __eps))
    return __squashed.clamp(-3, 3).mul(0.33)
|
|
78
|
+
|
|
79
|
+
# RESHAPE ######################################################################
|
|
80
|
+
|
|
81
|
+
def reshape_hidden_states(
    hidden_data: torch.Tensor, # (B, L, H)
) -> torch.Tensor:
    """Fold the hidden axis into a 2D grid and move the layers last: (B, L, E) => (B, W, H, L).

    W is the largest power of two dividing E and H = E // W.
    """
    __b, __l, __e = tuple(hidden_data.shape)
    # 2 ** floor(log2(E)) is at least the largest power of two dividing E, so the gcd isolates it
    __w = math.gcd(__e, 2 ** int(math.log2(__e)))
    __grid = hidden_data.reshape((__b, __l, __w, __e // __w))
    return __grid.permute(0, 2, 3, 1)
|
|
91
|
+
|
|
92
|
+
# MASK #########################################################################
|
|
93
|
+
|
|
94
|
+
def mask_hidden_states(
    hidden_data: torch.Tensor, # (B, L, H)
    topk_num: int=128,
) -> torch.Tensor:
    """Flag the `topk_num` largest activations (in magnitude) along the last axis.

    Args:
        hidden_data: activations, typically (B, L, H); numpy arrays are accepted too.
        topk_num: number of values kept per row, capped to [0; size of the last axis].

    Returns:
        A boolean tensor shaped like `hidden_data`, True on the kept values.
    """
    __data = torch.as_tensor(hidden_data)
    # sanitize
    __k = max(0, min(int(topk_num), int(__data.shape[-1])))
    # indices of the topk values
    __indices = __data.abs().topk(__k, dim=-1, largest=True, sorted=False).indices
    # initialize the mask with False
    __mask = torch.zeros_like(__data, dtype=torch.bool)
    # `src` must be a tensor (a bare True is not accepted): scatter a block of True at the indices
    return __mask.scatter_(dim=-1, index=__indices, src=torch.ones_like(__indices, dtype=torch.bool))
|
|
106
|
+
|
|
107
|
+
# FORMAT #######################################################################
|
|
108
|
+
|
|
109
|
+
def color_hidden_states(
    hidden_data: numpy.ndarray, # (B, W, H, L)
    gamma_val: float=0.7,
    alpha_val: float=0.35,
    color_map: object=None,
) -> numpy.ndarray:
    """Map activations in [-1; 1] to RGBA colors.

    The hue comes from `color_map` and the transparency grows with the magnitude:
    alpha = alpha_val * |x| ** gamma_val.

    Args:
        hidden_data: activations in [-1; 1], any shape.
        gamma_val: exponent applied to the magnitude for the transparency.
        alpha_val: maximum opacity.
        color_map: a callable colormap, a registered colormap name, or None for 'coolwarm'.
            The default used to be an annotation without value, a SyntaxError that broke the module.

    Returns:
        A float array shaped like `hidden_data` plus a trailing RGBA axis, in [0; 1].
    """
    # resolve the colormap lazily so that matplotlib is only needed for the named maps
    if (color_map is None) or isinstance(color_map, str):
        color_map = matplotlib.colormaps[color_map or 'coolwarm']
    __values = numpy.asarray(hidden_data)
    # [-1; 1] => [0; 1], then (..., 4); colormaps are called, not indexed; copy to get a writable array
    __rgba = numpy.array(color_map(0.5 * (__values + 1.0)))
    # compute the transparency from the magnitude
    __rgba[..., 3] = alpha_val * (numpy.abs(__values) ** gamma_val)
    return __rgba
|
|
123
|
+
|
|
124
|
+
# POSTPROCESS ##################################################################
|
|
125
|
+
|
|
126
|
+
def postprocess_token_cls(
    token_idx: int,
    token_dim: int,
) -> list:
    """Label every token with its highlight class.

    Args:
        token_idx: focused position; negative values focus every token.
        token_dim: number of tokens.

    Returns:
        A list of `token_dim` strings: '1' for the focused token(s), '0' for the rest.
    """
    # clamp like reduce_hidden_states: past the end => last token, below -1 => all tokens
    # (the clamped value used to be computed, bounded by token_dim instead of token_dim - 1, and ignored)
    __token_idx = max(-1, min(token_dim - 1, token_idx))
    # average on all the tokens when the idx is negative
    if __token_idx < 0:
        return token_dim * ['1']
    return [str(int(__i == __token_idx)) for __i in range(token_dim)]
|
|
File without changes
|