psaiops 0.4.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of psaiops might be problematic. Click here for more details.
- psaiops/__init__.py +0 -0
- psaiops/combine/__init__.py +0 -0
- psaiops/combine/app.py +366 -0
- psaiops/common/__init__.py +0 -0
- psaiops/common/data.py +31 -0
- psaiops/common/model.py +73 -0
- psaiops/common/tokenizer.py +41 -0
- psaiops/compose/__init__.py +0 -0
- psaiops/compose/contrast/__init__.py +0 -0
- psaiops/compose/contrast/app.py +195 -0
- psaiops/compose/contrast/lib.py +143 -0
- psaiops/compose/maths/__init__.py +0 -0
- psaiops/compose/maths/app.py +323 -0
- psaiops/compose/maths/lib.py +1 -0
- psaiops/edit/__init__.py +0 -0
- psaiops/reverse/__init__.py +0 -0
- psaiops/score/__init__.py +0 -0
- psaiops/score/attention/__init__.py +0 -0
- psaiops/score/attention/app.py +303 -0
- psaiops/score/attention/lib.py +118 -0
- psaiops/score/residual/__init__.py +0 -0
- psaiops/score/residual/app.py +507 -0
- psaiops/score/residual/lib.py +187 -0
- psaiops/score/router/__init__.py +0 -0
- psaiops/score/router/app.py +282 -0
- psaiops/score/router/lib.py +59 -0
- psaiops/score/shapley/__init__.py +0 -0
- psaiops/score/shapley/app.py +158 -0
- psaiops/score/shapley/lib.py +1 -0
- psaiops/score/similarity/__init__.py +0 -0
- psaiops/score/similarity/app.py +152 -0
- psaiops/score/similarity/lib.py +1 -0
- psaiops-0.4.7.dist-info/METADATA +34 -0
- psaiops-0.4.7.dist-info/RECORD +36 -0
- psaiops-0.4.7.dist-info/WHEEL +4 -0
- psaiops-0.4.7.dist-info/licenses/.github/LICENSE.md +661 -0
|
@@ -0,0 +1,507 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
|
|
3
|
+
import gradio
|
|
4
|
+
import numpy
|
|
5
|
+
import torch
|
|
6
|
+
import torch.cuda
|
|
7
|
+
import matplotlib.pyplot
|
|
8
|
+
|
|
9
|
+
import psaiops.common.model
|
|
10
|
+
import psaiops.common.tokenizer
|
|
11
|
+
import psaiops.score.residual.lib
|
|
12
|
+
|
|
13
|
+
# META #########################################################################
|
|
14
|
+
|
|
15
|
+
# CSS injected into the page: the highlighted token text is drawn in white
STYLE = '.white-text span { color: white; }'
# title of the gradio page
TITLE = 'Visualization Of Residuals'
# markdown displayed above the tabs
INTRO = (
    'Plot the hidden states for a given prompt.\n'
    'Under construction, only "openai/gpt-oss-20b" is available for now.')

# checkpoint loaded by default by create_app
MODEL = 'openai/gpt-oss-20b'
|
|
20
|
+
|
|
21
|
+
# COLORS #######################################################################
|
|
22
|
+
|
|
23
|
+
def create_selection_cmap() -> dict:
    """Return the background color of each token focus class.

    '0' is an unselected token, '1' the left focus, '2' the right focus
    and '3' a token focused on both sides.
    """
    __palette = ('#000000', '#004444', '#444400', '#440044')
    return {str(__cls): __color for __cls, __color in enumerate(__palette)}
|
|
29
|
+
|
|
30
|
+
def create_score_cmap() -> dict:
    """Return a color map from the integer score labels '0'..'100' to shades of red.

    The red channel grows linearly from 0x00 for '0' up to 0xff for '100'.
    Integer arithmetic is used on purpose: in floating point `int(2.55 * 100)`
    evaluates to 254, so the highest score never reached full intensity.
    """
    return {str(__i): '#{:02x}0000'.format((255 * __i) // 100) for __i in range(101)}
|
|
32
|
+
|
|
33
|
+
# INTRO ########################################################################
|
|
34
|
+
|
|
35
|
+
def create_intro_block(intro: str) -> dict:
    """Render `intro` as markdown, keeping its newlines as line breaks."""
    return {'intro_block': gradio.Markdown(intro, line_breaks=True)}
|
|
38
|
+
|
|
39
|
+
# MODEL ########################################################################
|
|
40
|
+
|
|
41
|
+
def create_model_block() -> dict:
    """Dropdown choosing the checkpoint; a single model is offered for now."""
    # other checkpoint considered but disabled: 'openai/gpt-oss-120b'
    __dropdown = gradio.Dropdown(
        label='Model',
        value='openai/gpt-oss-20b',
        choices=['openai/gpt-oss-20b'],
        scale=1,
        allow_custom_value=False,
        multiselect=False,
        interactive=True)
    return {'model_block': __dropdown}
|
|
44
|
+
|
|
45
|
+
# SAMPLING #####################################################################
|
|
46
|
+
|
|
47
|
+
def create_sampling_block() -> dict:
    """Sliders for the generation settings: number of new tokens, top-k and top-p."""
    # the sliders are instantiated in this order, which is also their display order
    return {
        'tokens_block': gradio.Slider(
            label='Tokens', value=16, minimum=1, maximum=128, step=1, scale=1, interactive=True),
        'topk_block': gradio.Slider(
            label='Top K', value=4, minimum=1, maximum=8, step=1, scale=1, interactive=True),
        'topp_block': gradio.Slider(
            label='Top P', value=0.9, minimum=0.0, maximum=1.0, step=0.1, scale=1, interactive=True),}
|
|
55
|
+
|
|
56
|
+
# DATAVIZ ######################################################################
|
|
57
|
+
|
|
58
|
+
def create_visualization_block() -> dict:
    """Plot settings: a 0/1 switch for the 3D view and the number of points drawn."""
    return {
        # 0 => 2D heatmap, 1 => 3D scatter
        'axes_block': gradio.Slider(
            label='3D', value=1, minimum=0, maximum=1, step=1, scale=1, interactive=True),
        'points_block': gradio.Slider(
            label='Points', value=128, minimum=32, maximum=2880, step=32, scale=1, interactive=True),}
|
|
64
|
+
|
|
65
|
+
# INPUTS #######################################################################
|
|
66
|
+
|
|
67
|
+
def create_inputs_block(label: str='Prompt') -> dict:
    """Multi-line textbox where the user types the prompt."""
    __textbox = gradio.Textbox(
        label=label,
        value='',
        placeholder='A string of tokens to score.',
        lines=4,
        scale=1,
        interactive=True)
    return {'input_block': __textbox}
|
|
70
|
+
|
|
71
|
+
# PLOTS ########################################################################
|
|
72
|
+
|
|
73
|
+
def create_plot_block(label: str='Residuals', prefix: str='') -> dict:
    """Plot component, stored under the key `prefix + 'plot_block'`."""
    __key = prefix + 'plot_block'
    return {__key: gradio.Plot(label=label, scale=1)}
|
|
76
|
+
|
|
77
|
+
# OUTPUTS ######################################################################
|
|
78
|
+
|
|
79
|
+
def create_highlight_block(label: str='Output', prefix: str='', cmap: dict=None) -> dict:
    """Read-only HighlightedText that colors each token according to its class.

    Args:
        label: caption of the component.
        prefix: prepended to the 'highlight_block' key of the returned dict.
        cmap: mapping class label -> color; defaults to a fresh `create_selection_cmap()`.

    Returns:
        {prefix + 'highlight_block': component}
    """
    # build the default per call instead of sharing a single dict evaluated at import time
    __cmap = create_selection_cmap() if (cmap is None) else cmap
    __output = gradio.HighlightedText(
        label=label,
        value='',
        scale=1,
        interactive=False,
        show_legend=False,
        show_inline_category=False,
        combine_adjacent=False,
        color_map=__cmap,
        # STYLE renders the spans of this class in white
        elem_classes='white-text')
    return {prefix + 'highlight_block': __output}
|
|
82
|
+
|
|
83
|
+
# SELECT #######################################################################
|
|
84
|
+
|
|
85
|
+
def create_token_selection_block(label: str='Token', prefix: str='') -> dict:
    """Slider selecting the token position to inspect (-1 stands for all tokens).

    The upper bound 15 is only the initial value: it is updated after each generation.
    """
    __slider = gradio.Slider(
        label=label,
        value=-1,
        minimum=-1,
        maximum=15,
        step=1,
        scale=1,
        interactive=True)
    return {prefix + 'position_block': __slider}
|
|
89
|
+
|
|
90
|
+
def create_layer_selection_block(label: str='Layer', prefix: str='') -> dict:
    """Slider selecting the layer to inspect (-1 stands for all layers).

    NOTE(review): the upper bound 23 is hard-coded, presumably from the layer
    count of the default model — TODO confirm against the model config.
    """
    __slider = gradio.Slider(
        label=label,
        value=-1,
        minimum=-1,
        maximum=23,
        step=1,
        scale=1,
        interactive=True)
    return {prefix + 'layer_block': __slider}
|
|
93
|
+
|
|
94
|
+
# ACTIONS ######################################################################
|
|
95
|
+
|
|
96
|
+
def create_actions_block() -> dict:
    """Primary button that triggers the generation and every refresh after it."""
    return {'process_block': gradio.Button('Process', variant='primary', size='lg', scale=1, interactive=True)}
|
|
99
|
+
|
|
100
|
+
# STATE ########################################################################
|
|
101
|
+
|
|
102
|
+
def create_state() -> dict:
    """Session states holding the generated token ids and the merged hidden states."""
    return {__name: gradio.State(None) for __name in ('output_state', 'hidden_state')}
|
|
106
|
+
|
|
107
|
+
# LAYOUT #######################################################################
|
|
108
|
+
|
|
109
|
+
def create_layout(intro: str=INTRO) -> dict:
    """Build the UI: an intro, then a 'Residuals' tab and a 'Settings' tab.

    The components are created inside nested gradio contexts, so the order
    of the statements below is the on-screen order.

    Returns:
        a flat dict {name: component} merging the dicts of every create_* helper,
        plus the two tabs under 'main_tab' and 'settings_tab'.
    """
    __fields = {}
    __fields.update(create_intro_block(intro=intro))
    with gradio.Tabs():
        with gradio.Tab('Residuals') as __main_tab:
            __fields.update({'main_tab': __main_tab})
            # prompt
            with gradio.Row(equal_height=True):
                __fields.update(create_inputs_block())
            # generated tokens, colored by left / right focus
            with gradio.Row(equal_height=True):
                __fields.update(create_highlight_block())
            # side-by-side residual plots
            with gradio.Row(equal_height=True):
                __fields.update(create_plot_block(label='Left', prefix='left_'))
                __fields.update(create_plot_block(label='Right', prefix='right_'))
            # per-token scores for the left / right layer selection
            with gradio.Row(equal_height=True):
                __fields.update(create_highlight_block(label='Score', prefix='left_', cmap=create_score_cmap()))
                __fields.update(create_highlight_block(label='Score', prefix='right_', cmap=create_score_cmap()))
            # token position selectors
            with gradio.Row(equal_height=True):
                __fields.update(create_token_selection_block(label='Token', prefix='left_'))
                __fields.update(create_token_selection_block(label='Token', prefix='right_'))
            # layer selectors
            with gradio.Row(equal_height=True):
                __fields.update(create_layer_selection_block(label='Layer', prefix='left_'))
                __fields.update(create_layer_selection_block(label='Layer', prefix='right_'))
            with gradio.Row(equal_height=True):
                __fields.update(create_actions_block())
        with gradio.Tab('Settings') as __settings_tab:
            __fields.update({'settings_tab': __settings_tab})
            with gradio.Row(equal_height=True):
                __fields.update(create_model_block())
            # generation settings
            with gradio.Row(equal_height=True):
                __fields.update(create_sampling_block())
            # plot settings
            with gradio.Row(equal_height=True):
                __fields.update(create_visualization_block())
    return __fields
|
|
142
|
+
|
|
143
|
+
# EVENTS #######################################################################
|
|
144
|
+
|
|
145
|
+
def update_position_range(
    current_val: float,
    token_num: float,
    output_data: torch.Tensor,
) -> dict:
    """Refresh the bounds of a token position slider.

    Before any generation (`output_data` is None), the upper bound is estimated
    from the requested number of new tokens. Afterwards it is the last valid
    index of the generated sequence, i.e. `output_data.shape[-1] - 1`. The
    current value is kept when it still fits, otherwise it is clamped.

    Args:
        current_val: current slider value.
        token_num: requested number of new tokens.
        output_data: generated token ids, or None.

    Returns:
        a gradio update setting `maximum` and `value`.
    """
    if output_data is None:
        __max = int(token_num) - 1
    else:
        # the last valid index is length - 1; keep 0 as the floor for an empty output
        __max = max(0, int(output_data.shape[-1]) - 1)
    # keep the previous value if possible
    __val = min(int(current_val), __max)
    return gradio.update(maximum=__max, value=__val)
|
|
156
|
+
|
|
157
|
+
# GENERATE #####################################################################
|
|
158
|
+
|
|
159
|
+
def update_computation_state(
    token_num: float,
    topk_num: float,
    topp_num: float,
    prompt_str: str,
    device_str: str,
    model_obj: object,
    tokenizer_obj: object,
) -> tuple:
    """Generate a completion for the prompt and gather its hidden states.

    The slider values are clamped to their UI ranges first. When the prompt is
    blank or the model / tokenizer is missing, two empty tensors are returned.

    Returns:
        (token ids, merged hidden states), both moved to the CPU as float32.
    """
    def __clamp(value, lower, upper):
        return max(lower, min(upper, value))

    __tokens = __clamp(int(token_num), 1, 128)
    __topk = __clamp(int(topk_num), 1, 8)
    __topp = __clamp(float(topp_num), 0.0, 1.0)
    __prompt = prompt_str.strip()
    __device = device_str if device_str in ('cpu', 'cuda') else 'cpu'
    # nothing to compute
    if not __prompt or model_obj is None or tokenizer_obj is None:
        return (torch.empty(0), torch.empty(0))
    # {'input_ids': _, 'attention_mask': _}
    __inputs = psaiops.common.tokenizer.preprocess_token_ids(
        tokenizer_obj=tokenizer_obj,
        prompt_str=__prompt,
        device_str=__device)
    # token ids (1, T) and the per-step, per-layer hidden states
    __ids, __hidden = psaiops.score.residual.lib.generate_token_ids(
        model_obj=model_obj,
        input_ids=__inputs['input_ids'],
        attention_mask=__inputs['attention_mask'],
        token_num=__tokens,
        topk_num=__topk,
        topp_num=__topp)
    # stack into a single (1, L, T', H) tensor
    __hidden = psaiops.score.residual.lib.merge_hidden_states(hidden_data=__hidden)
    return (__ids.cpu().float(), __hidden.cpu().float())
|
|
197
|
+
|
|
198
|
+
# HIGHLIGHT ####################################################################
|
|
199
|
+
|
|
200
|
+
def update_token_focus(
    left_idx: float,
    right_idx: float,
    output_data: torch.Tensor,
    tokenizer_obj: object,
) -> list:
    """Pair every generated token with its focus class for the highlight widget.

    Returns:
        a list of (token string, class) tuples, or None when there is no output yet.
    """
    __missing = (output_data is None) or (len(output_data) == 0)
    if __missing:
        return None
    __tokens = psaiops.common.tokenizer.postprocess_token_ids(
        tokenizer_obj=tokenizer_obj,
        token_data=output_data)
    # '1' left focus, '2' right focus, '3' both, '0' neither
    __classes = psaiops.score.residual.lib.postprocess_focus_cls(
        left_idx=int(left_idx),
        right_idx=int(right_idx),
        token_dim=len(__tokens))
    return [(__tok, __cls) for __tok, __cls in zip(__tokens, __classes)]
|
|
220
|
+
|
|
221
|
+
# SCORES #######################################################################
|
|
222
|
+
|
|
223
|
+
def update_token_scores(
    layer_idx: float,
    output_data: torch.Tensor,
    hidden_data: torch.Tensor,
    tokenizer_obj: object,
    model_obj: object,
) -> list:
    """Score each token by how far the prediction at `layer_idx` is from the final one.

    The hidden states of the selected layer are passed through the final norm
    and the LM head, then compared with the logits of the last layer using the
    Jensen-Shannon divergence in [0; 1]. The divergence is turned into an
    integer label '0'..'100', which matches the keys of `create_score_cmap`.

    Returns:
        a list of (token string, label) pairs, or None when data is missing.
    """
    # exit if some values are missing
    if (output_data is None) or (len(output_data) == 0) or (hidden_data is None) or (len(hidden_data) == 0):
        return None
    # run the projection where the LM head lives, in its dtype
    __device_obj = model_obj.lm_head.weight.device
    __dtype_obj = model_obj.lm_head.weight.dtype
    # detokenize the IDs
    __token_str = psaiops.common.tokenizer.postprocess_token_ids(
        tokenizer_obj=tokenizer_obj,
        token_data=output_data)
    # hidden_data is indexed as (batch, layer, token, hidden)
    __final_states = hidden_data[0, -1, :, :].to(device=__device_obj, dtype=__dtype_obj)
    __layer_states = hidden_data[0, int(layer_idx), :, :].to(device=__device_obj, dtype=__dtype_obj)
    # inference only: skip the autograd graph of the (tokens, vocabulary) logits
    with torch.no_grad():
        # the last hidden state is treated as already normalized
        __final_logits = model_obj.lm_head(__final_states).detach().cpu()
        __layer_logits = model_obj.lm_head(model_obj.model.norm(__layer_states)).detach().cpu()
    # the scoring helpers live in the lib module: the bare names are not defined in this module
    __token_jsd = psaiops.score.residual.lib.jsd_from_logits(
        final_logits=__final_logits,
        prefix_logits=__layer_logits)
    # scale [0; 1] into the integer labels '0'..'100' used as color map keys
    __token_pct = (100.0 * __token_jsd).round().clamp(0, 100).to(dtype=torch.int64)
    __token_cls = psaiops.score.residual.lib.postprocess_score_cls(score_data=__token_pct)
    # NOTE(review): zip truncates to the shortest of tokens / scores if their lengths differ
    return list(zip(__token_str, __token_cls))
|
|
252
|
+
|
|
253
|
+
# PLOT #########################################################################
|
|
254
|
+
|
|
255
|
+
def update_2d_plot(
    token_idx: float,
    layer_idx: float,
    hidden_data: torch.Tensor,
) -> tuple:
    """Draw one averaged residual vector as a 2D heatmap.

    The hidden states are averaged over the selected layers and tokens (-1
    keeps all of them), squashed into [-1; 1], folded into a 2D grid and colored.

    Returns:
        the matplotlib figure of the first sample, already closed in pyplot.
    """
    __lib = psaiops.score.residual.lib
    # (B, L, T, E) => (B, E): average over both the layer and token axes
    __grid = __lib.reduce_hidden_states(
        hidden_data=hidden_data,
        layer_idx=int(layer_idx),
        token_idx=int(token_idx),
        axes_idx=(1, 2))
    # squash the activations into [-1; 1]
    __grid = __lib.rescale_hidden_states(hidden_data=__grid)
    # fold E into two axes (B, E) => (B, W, H); no layer axis is left, hence -1
    __grid = __lib.reshape_hidden_states(hidden_data=__grid, layer_idx=-1)
    # [-1; 1] => RGB
    __pixels = __lib.color_hidden_states(hidden_data=__grid.numpy())
    __figure = matplotlib.pyplot.figure()
    __figure.add_subplot(1, 1, 1).imshow(__pixels[0], vmin=0.0, vmax=1.0, cmap='viridis')
    __figure.tight_layout()
    # drop the figure from the pyplot registry so it can be garbage collected later
    matplotlib.pyplot.close(__figure)
    return __figure
|
|
285
|
+
|
|
286
|
+
def update_3d_plot(
    token_idx: float,
    layer_idx: float,
    points_num: float,
    hidden_data: torch.Tensor,
) -> tuple:
    """Scatter the strongest residual activations of every selected layer in 3D.

    Each layer's vector is folded into a (W, H) grid and the layers are stacked
    along the third axis. Only the `points_num` largest activations per layer are
    drawn; the budget is doubled when `layer_idx` is not -1.

    Returns:
        the matplotlib figure of the first sample, already closed in pyplot.
    """
    __lib = psaiops.score.residual.lib
    __layer = int(layer_idx)
    __points = int(points_num)
    # average over the selected tokens only: (B, L, T, E) => (B, L, E)
    __values = __lib.reduce_hidden_states(
        hidden_data=hidden_data,
        token_idx=int(token_idx),
        layer_idx=__layer,
        axes_idx=2)
    __values = __lib.rescale_hidden_states(hidden_data=__values)
    # hide the small activations to keep the scatter readable
    __keep = __lib.mask_hidden_states(
        hidden_data=__values,
        topk_num=__points if __layer == -1 else 2 * __points)
    # (B, L, E) => (B, W, H, L) for both the values and their mask
    __values = __lib.reshape_hidden_states(hidden_data=__values, layer_idx=1).numpy()
    __keep = __lib.reshape_hidden_states(hidden_data=__keep, layer_idx=1).numpy()
    # [-1; 1] => colors and point areas
    __colors = __lib.color_hidden_states(hidden_data=__values)
    __areas = __lib.size_hidden_states(hidden_data=__values, area_min=0.01, area_max=16.0, gamma_val=1.6)
    # coordinates of the kept points of the first sample
    __x, __y, __z = numpy.nonzero(__keep[0])
    __figure = matplotlib.pyplot.figure()
    __axes = __figure.add_subplot(1, 1, 1, projection='3d')
    __axes.scatter(
        __x, __y, __z,
        c=__colors[0, __x, __y, __z],
        s=__areas[0, __x, __y, __z],
        marker='o',
        linewidths=0)
    __figure.tight_layout()
    # drop the figure from the pyplot registry so it can be garbage collected later
    matplotlib.pyplot.close(__figure)
    return __figure
|
|
337
|
+
|
|
338
|
+
def update_hidden_plot(
    token_idx: float,
    layer_idx: float,
    axes_num: float,
    points_num: float,
    hidden_data: torch.Tensor,
) -> tuple:
    """Dispatch to the 3D scatter (truthy `axes_num`) or the 2D heatmap (0 / 0.0).

    Returns None when there are no hidden states yet.
    """
    if hidden_data is None or len(hidden_data) == 0:
        return None
    if axes_num:
        # default view: the residuals of the selected layers in 3D
        return update_3d_plot(
            token_idx=token_idx,
            layer_idx=layer_idx,
            points_num=points_num,
            hidden_data=hidden_data)
    # single averaged residual vector as a heatmap
    return update_2d_plot(
        token_idx=token_idx,
        layer_idx=layer_idx,
        hidden_data=hidden_data)
|
|
360
|
+
|
|
361
|
+
# APP ##########################################################################
|
|
362
|
+
|
|
363
|
+
def create_app(title: str=TITLE, intro: str=INTRO, style: str=STYLE, model: str=MODEL) -> gradio.Blocks:
    """Build the Gradio app that generates a completion and plots its residual stream.

    The model and tokenizer named `model` are loaded once, on CUDA when it is
    available, and bound into the event handlers with functools.partial.
    Clicking "Process" runs the generation, then refreshes in sequence: both
    position sliders, the token highlight, both score rows and both plots.
    Later slider changes only refresh the widgets wired to them below.

    Args:
        title: page title.
        intro: markdown shown above the tabs.
        style: custom CSS.
        model: checkpoint name passed to the model and tokenizer loaders.

    Returns:
        the gradio.Blocks application (not launched).
    """
    __fields = {}
    with gradio.Blocks(theme=gradio.themes.Soft(), title=title, css=style) as __app:
        # load the model
        __device = 'cuda' if torch.cuda.is_available() else 'cpu'
        __model = psaiops.common.model.get_model(name=model, device=__device)
        __tokenizer = psaiops.common.tokenizer.get_tokenizer(name=model, device=__device)
        # bind the model / tokenizer so the handlers only receive UI values
        __compute = functools.partial(update_computation_state, model_obj=__model, tokenizer_obj=__tokenizer, device_str=__device)
        __highlight = functools.partial(update_token_focus, tokenizer_obj=__tokenizer)
        __score = functools.partial(update_token_scores, tokenizer_obj=__tokenizer, model_obj=__model)
        # create the UI
        __fields.update(create_layout(intro=intro))
        # init the state (generated ids and merged hidden states)
        __fields.update(create_state())
        # update the data after clicking process
        # NOTE(review): generation runs with queue=False, i.e. outside gradio's request queue — confirm this is intended for a long model call
        __fields['process_block'].click(
            fn=__compute,
            inputs=[__fields[__k] for __k in ['tokens_block', 'topk_block', 'topp_block', 'input_block']],
            outputs=[__fields[__k] for __k in ['output_state', 'hidden_state']],
            queue=False,
            show_progress='full'
        ).then(
            # update the range of the position sliders when the output changes
            fn=update_position_range,
            inputs=[__fields[__k] for __k in ['left_position_block', 'tokens_block', 'output_state']],
            outputs=__fields['left_position_block'],
            queue=False,
            show_progress='hidden'
        ).then(
            fn=update_position_range,
            inputs=[__fields[__k] for __k in ['right_position_block', 'tokens_block', 'output_state']],
            outputs=__fields['right_position_block'],
            queue=False,
            show_progress='hidden'
        ).then(
            # update the token highlight when the output data changes
            fn=__highlight,
            inputs=[__fields[__k] for __k in ['left_position_block', 'right_position_block', 'output_state']],
            outputs=__fields['highlight_block'],
            queue=False,
            show_progress='hidden'
        ).then(
            # update the left token scores when the output data changes
            fn=__score,
            inputs=[__fields[__k] for __k in ['left_layer_block', 'output_state', 'hidden_state']],
            outputs=__fields['left_highlight_block'],
            queue=False,
            show_progress='hidden'
        ).then(
            # update the right token scores when the output data changes
            fn=__score,
            inputs=[__fields[__k] for __k in ['right_layer_block', 'output_state', 'hidden_state']],
            outputs=__fields['right_highlight_block'],
            queue=False,
            show_progress='hidden'
        ).then(
            # update the left plot when the hidden data changes
            fn=update_hidden_plot,
            inputs=[__fields[__k] for __k in ['left_position_block', 'left_layer_block', 'axes_block', 'points_block', 'hidden_state']],
            outputs=__fields['left_plot_block'],
            queue=False,
            show_progress='hidden'
        ).then(
            # update the right plot when the hidden data changes
            fn=update_hidden_plot,
            inputs=[__fields[__k] for __k in ['right_position_block', 'right_layer_block', 'axes_block', 'points_block', 'hidden_state']],
            outputs=__fields['right_plot_block'],
            queue=False,
            show_progress='hidden')
        # update the range of the position slider when the settings change
        __fields['tokens_block'].change(
            fn=update_position_range,
            inputs=[__fields[__k] for __k in ['left_position_block', 'tokens_block', 'output_state']],
            outputs=__fields['left_position_block'],
            queue=False,
            show_progress='hidden'
        ).then(
            fn=update_position_range,
            inputs=[__fields[__k] for __k in ['right_position_block', 'tokens_block', 'output_state']],
            outputs=__fields['right_position_block'],
            queue=False,
            show_progress='hidden')
        # NOTE(review): axes_block and points_block have no listener; the plots pick up
        # their new values only on the next process click or slider change
        # update the left plot when the focus changes
        __fields['left_position_block'].change(
            fn=update_hidden_plot,
            inputs=[__fields[__k] for __k in ['left_position_block', 'left_layer_block', 'axes_block', 'points_block', 'hidden_state']],
            outputs=__fields['left_plot_block'],
            queue=False,
            show_progress='hidden')
        __fields['left_layer_block'].change(
            fn=update_hidden_plot,
            inputs=[__fields[__k] for __k in ['left_position_block', 'left_layer_block', 'axes_block', 'points_block', 'hidden_state']],
            outputs=__fields['left_plot_block'],
            queue=False,
            show_progress='hidden'
        ).then(
            # update the left token scores when the layer changes
            fn=__score,
            inputs=[__fields[__k] for __k in ['left_layer_block', 'output_state', 'hidden_state']],
            outputs=__fields['left_highlight_block'],
            queue=False,
            show_progress='hidden'
        )
        # update the right plot when the focus changes
        __fields['right_position_block'].change(
            fn=update_hidden_plot,
            inputs=[__fields[__k] for __k in ['right_position_block', 'right_layer_block', 'axes_block', 'points_block', 'hidden_state']],
            outputs=__fields['right_plot_block'],
            queue=False,
            show_progress='hidden')
        __fields['right_layer_block'].change(
            fn=update_hidden_plot,
            inputs=[__fields[__k] for __k in ['right_position_block', 'right_layer_block', 'axes_block', 'points_block', 'hidden_state']],
            outputs=__fields['right_plot_block'],
            queue=False,
            show_progress='hidden'
        ).then(
            # update the right token scores when the layer changes
            fn=__score,
            inputs=[__fields[__k] for __k in ['right_layer_block', 'output_state', 'hidden_state']],
            outputs=__fields['right_highlight_block'],
            queue=False,
            show_progress='hidden'
        )
        # update the token highlight when the token focus changes
        __fields['left_position_block'].change(
            fn=__highlight,
            inputs=[__fields[__k] for __k in ['left_position_block', 'right_position_block', 'output_state']],
            outputs=__fields['highlight_block'],
            queue=False,
            show_progress='hidden')
        __fields['right_position_block'].change(
            fn=__highlight,
            inputs=[__fields[__k] for __k in ['left_position_block', 'right_position_block', 'output_state']],
            outputs=__fields['highlight_block'],
            queue=False,
            show_progress='hidden')
    # gradio application
    return __app
|
|
502
|
+
|
|
503
|
+
# MAIN #########################################################################
|
|
504
|
+
|
|
505
|
+
if __name__ == '__main__':
    # build the UI, then serve it with a public share link and debug mode on
    __app = create_app()
    __app.launch(debug=True, share=True)
|
|
@@ -0,0 +1,187 @@
|
|
|
1
|
+
import functools
|
|
2
|
+
import math
|
|
3
|
+
|
|
4
|
+
import matplotlib
|
|
5
|
+
import numpy
|
|
6
|
+
import torch
|
|
7
|
+
|
|
8
|
+
import mlable.shapes
|
|
9
|
+
|
|
10
|
+
# GENERATE #######################################################################
|
|
11
|
+
|
|
12
|
+
def generate_token_ids(
    model_obj: object,
    input_ids: torch.Tensor,
    token_num: int,
    topk_num: int = 4,
    topp_num: float = 0.9,
    attention_mask: torch.Tensor=None,
) -> tuple:
    """Generate up to `token_num` new tokens and return the ids with their hidden states.

    Sampling is enabled as soon as top-k (`topk_num > 0`) or top-p
    (`0 < topp_num < 1`) filtering is active; otherwise decoding is greedy.

    No memoization is applied. An `lru_cache` here keyed on tensor and model
    identity never hit for freshly tokenized inputs, kept up to 32 generations
    (sequences plus every hidden state) alive in memory, and replayed a stale
    sample whenever the same tensor was passed again.

    Returns:
        (sequences, hidden_states) from `model_obj.generate`, with
        hidden_states indexed as [step][layer].
    """
    __use_top_k = topk_num > 0
    __use_top_p = 0.0 < topp_num < 1.0
    with torch.no_grad():
        __outputs = model_obj.generate(
            input_ids=input_ids,
            attention_mask=attention_mask,
            max_new_tokens=token_num,
            do_sample=__use_top_p or __use_top_k,
            top_k=topk_num if __use_top_k else None,
            top_p=topp_num if __use_top_p else None,
            return_dict_in_generate=True,
            output_hidden_states=True,
            output_attentions=False,
            output_scores=False,
            use_cache=True)
    return __outputs.sequences, __outputs.hidden_states
|
|
38
|
+
|
|
39
|
+
# MERGE ########################################################################
|
|
40
|
+
|
|
41
|
+
def merge_hidden_states(
    hidden_data: torch.Tensor,
) -> torch.Tensor:
    """Merge the per-step, per-layer hidden states into a single tensor.

    `hidden_data` is indexed as hidden_data[step][layer], and each item is
    concatenated along its axis 1 (the token axis). The number of layers is
    read from the first step.

    Returns:
        a tensor (B, L, sum of the step lengths, E).
    """
    __steps = range(len(hidden_data))
    __layers = []
    for __layer in range(len(hidden_data[0])):
        # all the tokens of this layer, step after step => (B, T, E)
        __layers.append(torch.concatenate([hidden_data[__step][__layer] for __step in __steps], dim=1))
    # insert the layer axis => (B, L, T, E)
    return torch.stack(__layers, dim=1)
|
|
54
|
+
|
|
55
|
+
# REDUCE #######################################################################
|
|
56
|
+
|
|
57
|
+
def reduce_hidden_states(
    hidden_data: torch.Tensor, # (B, L, T, E)
    layer_idx: int, # -1 => select all layers
    token_idx: int, # -1 => select all tokens
    axes_idx: int=2, # token sequence axis
) -> torch.Tensor:
    """Select one or all layers / tokens, then average over `axes_idx`.

    A negative index keeps the whole axis. An index past the end is clamped
    to the last position.

    Returns:
        the mean over `axes_idx`, e.g. (B, L, E) with the default token axis.
    """
    _, __layers, __tokens, _ = tuple(hidden_data.shape)
    __l = min(layer_idx, __layers - 1)
    __t = min(token_idx, __tokens - 1)
    __layer_sel = slice(None) if __l < 0 else slice(__l, __l + 1)
    __token_sel = slice(None) if __t < 0 else slice(__t, __t + 1)
    # keepdim=False drops the averaged axes
    return hidden_data[:, __layer_sel, __token_sel, :].mean(dim=axes_idx, keepdim=False)
|
|
74
|
+
|
|
75
|
+
# RESCALE ######################################################################
|
|
76
|
+
|
|
77
|
+
def rescale_hidden_states(
|
|
78
|
+
hidden_data: torch.Tensor, # (B, L, E) or (B, E)
|
|
79
|
+
) -> torch.Tensor:
|
|
80
|
+
# compute the scale of the data, layer by layer
|
|
81
|
+
__s = torch.quantile(hidden_data.abs(), q=0.9, dim=-1, keepdim=True)
|
|
82
|
+
# log scaling on large values and linear near 0
|
|
83
|
+
__a = torch.asinh(hidden_data / (__s + torch.finfo().eps))
|
|
84
|
+
# clip and map to [-1; 1]
|
|
85
|
+
return 0.33 * __a.clamp(min=-3, max=3)
|
|
86
|
+
|
|
87
|
+
# RESHAPE ######################################################################
|
|
88
|
+
|
|
89
|
+
def reshape_hidden_states(
    hidden_data: torch.Tensor, # (B, L, E) or (B, E)
    layer_idx: int=1,
) -> torch.Tensor:
    """Fold the hidden axis into a 2D grid and move axis `layer_idx` to the end.

    The split factor is a power of two whose exponent is round(log2(E) / 2),
    i.e. the power of two closest to sqrt(E) on a log scale.
    NOTE(review): how a non-divisible E is handled depends on
    mlable.shapes.divide — TODO confirm.

    Returns:
        (B, W, H, L) for a (B, L, E) input, or (B, W, H) when layer_idx=-1.
    """
    __before = tuple(hidden_data.shape)
    __factor = 2 ** round(math.log2(__before[-1]) / 2)
    # same shape with the last axis split in two
    __after = mlable.shapes.divide(shape=__before, axis=-1, factor=__factor, insert=True, right=True)
    # axis permutation sending `layer_idx` to the last position
    __order = mlable.shapes.move(shape=range(len(__after)), before=layer_idx, after=-1)
    return hidden_data.reshape(__after).permute(*__order)
|
|
103
|
+
|
|
104
|
+
# MASK #########################################################################
|
|
105
|
+
|
|
106
|
+
def mask_hidden_states(
    hidden_data: torch.Tensor, # (B, L, E)
    topk_num: int=128,
) -> torch.Tensor:
    """Boolean mask flagging the `topk_num` largest magnitudes along the last axis.

    `topk_num` is capped to the size of the last axis.
    """
    __k = min(topk_num, int(hidden_data.shape[-1]))
    __top = torch.topk(hidden_data.abs(), __k, dim=-1, largest=True, sorted=False).indices
    # start from an all-False mask and switch on the selected positions
    return torch.zeros_like(hidden_data, dtype=torch.bool).scatter_(dim=-1, index=__top, value=True)
|
|
118
|
+
|
|
119
|
+
# FORMAT #######################################################################
|
|
120
|
+
|
|
121
|
+
def color_hidden_states(
    hidden_data: numpy.array, # (B, H, W, L)
    color_map: callable=matplotlib.colormaps['coolwarm'],
) -> list:
    """Map activations in [-1; 1] to RGB colors with `color_map`.

    Returns:
        an array with an extra trailing axis of size 3 (the alpha channel is dropped).
    """
    # shift [-1; 1] onto the [0; 1] domain of the colormap, then drop alpha
    return color_map((hidden_data + 1.0) / 2.0)[..., :3]
|
|
131
|
+
|
|
132
|
+
def size_hidden_states(
    hidden_data: numpy.array, # (B, H, W, L)
    area_min: float=0.01,
    area_max: float=16.0,
    gamma_val: float=1.6,
) -> list:
    """Map activations in [-1; 1] to scatter point areas in [area_min; area_max].

    The magnitude is raised to `gamma_val`: below 1 it boosts small values,
    above 1 it emphasizes the large ones.
    """
    __eps = numpy.finfo(numpy.float32).eps
    __weight = numpy.power(numpy.abs(hidden_data) + __eps, gamma_val)
    return area_min + (area_max - area_min) * __weight
|
|
144
|
+
|
|
145
|
+
# KL SCORES ####################################################################
|
|
146
|
+
|
|
147
|
+
def kl_from_logprobs(
    p_log: torch.Tensor,
    q_log: torch.Tensor,
) -> torch.Tensor:
    """KL(p || q) summed over the last axis, from log-probabilities."""
    __p = torch.exp(p_log)
    return torch.sum(__p * (p_log - q_log), dim=-1)

def jsd_from_logits(
    final_logits: torch.Tensor,
    prefix_logits: torch.Tensor,
) -> torch.Tensor:
    """Jensen-Shannon divergence between two logit tensors, normalized to [0; 1].

    The logits are cast to float32 and normalized along the last axis. The JSD
    (in nats) is divided by log(2), its upper bound, and clamped to [0; 1].
    """
    __p_log = final_logits.float().log_softmax(dim=-1)
    __q_log = prefix_logits.float().log_softmax(dim=-1)
    # log of the mixture m = (p + q) / 2, computed stably with logsumexp
    __m_log = torch.stack((__p_log, __q_log), dim=0).logsumexp(dim=0) - math.log(2.0)
    __jsd = 0.5 * kl_from_logprobs(__p_log, __m_log) + 0.5 * kl_from_logprobs(__q_log, __m_log)
    return torch.clamp(__jsd / math.log(2.0), 0.0, 1.0)
|
|
167
|
+
|
|
168
|
+
# POSTPROCESS ##################################################################
|
|
169
|
+
|
|
170
|
+
def postprocess_focus_cls(
    left_idx: int,
    right_idx: int,
    token_dim: int,
) -> list:
    """Label every token with its focus class as a string.

    The left focus adds 1 and the right focus adds 2, so '3' marks a token
    selected on both sides and '0' an unselected one. A negative index selects
    every token. Indices are clamped to [-1; token_dim], so token_dim itself
    selects nothing.
    """
    __left = max(-1, min(token_dim, left_idx))
    __right = max(-1, min(token_dim, right_idx))
    __labels = []
    for __i in range(token_dim):
        __cls = 0
        if __left < 0 or __i == __left:
            __cls += 1
        if __right < 0 or __i == __right:
            __cls += 2
        __labels.append(str(__cls))
    return __labels
|
|
183
|
+
|
|
184
|
+
def postprocess_score_cls(
    score_data: torch.Tensor
) -> list:
    """Convert each score of a CPU tensor into its string label (str of the Python value)."""
    return list(map(str, score_data.numpy().tolist()))
|