psaiops-0.4.7-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of psaiops has been flagged.

@@ -0,0 +1,507 @@
+ import functools
+
+ import gradio
+ import numpy
+ import torch
+ import torch.cuda
+ import matplotlib.pyplot
+
+ import psaiops.common.model
+ import psaiops.common.tokenizer
+ import psaiops.score.residual.lib
+
+ # META #########################################################################
+
+ STYLE = '''.white-text span { color: white; }'''
+ TITLE = '''Visualization Of Residuals'''
+ INTRO = '''Plot the hidden states for a given prompt.\nUnder construction, only "openai/gpt-oss-20b" is available for now.'''
+
+ MODEL = 'openai/gpt-oss-20b'
+
+ # COLORS #######################################################################
+
+ def create_selection_cmap() -> dict:
+     return {
+         '0': '#000000',
+         '1': '#004444',
+         '2': '#444400',
+         '3': '#440044',}
+
+ def create_score_cmap() -> dict:
+     return {str(__i): '#{:02x}0000'.format(int(2.55 * __i)) for __i in range(101)}
+
+ # INTRO ########################################################################
+
+ def create_intro_block(intro: str) -> dict:
+     __intro = gradio.Markdown(intro, line_breaks=True)
+     return {'intro_block': __intro}
+
+ # MODEL ########################################################################
+
+ def create_model_block() -> dict:
+     __model = gradio.Dropdown(label='Model', value='openai/gpt-oss-20b', choices=['openai/gpt-oss-20b'], scale=1, allow_custom_value=False, multiselect=False, interactive=True) # 'openai/gpt-oss-120b'
+     return {'model_block': __model,}
+
+ # SAMPLING #####################################################################
+
+ def create_sampling_block() -> dict:
+     __tokens = gradio.Slider(label='Tokens', value=16, minimum=1, maximum=128, step=1, scale=1, interactive=True)
+     __topk = gradio.Slider(label='Top K', value=4, minimum=1, maximum=8, step=1, scale=1, interactive=True)
+     __topp = gradio.Slider(label='Top P', value=0.9, minimum=0.0, maximum=1.0, step=0.1, scale=1, interactive=True)
+     return {
+         'tokens_block': __tokens,
+         'topk_block': __topk,
+         'topp_block': __topp,}
+
+ # DATAVIZ ######################################################################
+
+ def create_visualization_block() -> dict:
+     __3d = gradio.Slider(label='3D', value=1, minimum=0, maximum=1, step=1, scale=1, interactive=True)
+     __points = gradio.Slider(label='Points', value=128, minimum=32, maximum=2880, step=32, scale=1, interactive=True)
+     return {
+         'axes_block': __3d,
+         'points_block': __points,}
+
+ # INPUTS #######################################################################
+
+ def create_inputs_block(label: str='Prompt') -> dict:
+     __input = gradio.Textbox(label=label, value='', placeholder='A string of tokens to score.', lines=4, scale=1, interactive=True)
+     return {'input_block': __input}
+
+ # PLOTS ########################################################################
+
+ def create_plot_block(label: str='Residuals', prefix: str='') -> dict:
+     __plot = gradio.Plot(label=label, scale=1)
+     return {prefix + 'plot_block': __plot,}
+
+ # OUTPUTS ######################################################################
+
+ def create_highlight_block(label: str='Output', prefix: str='', cmap: dict=create_selection_cmap()) -> dict:
+     __output = gradio.HighlightedText(label=label, value='', scale=1, interactive=False, show_legend=False, show_inline_category=False, combine_adjacent=False, color_map=cmap, elem_classes='white-text')
+     return {prefix + 'highlight_block': __output}
+
+ # SELECT #######################################################################
+
+ def create_token_selection_block(label: str='Token', prefix: str='') -> dict:
+     # __play = gradio.Button('>', variant='primary', size='lg', scale=1, interactive=True)
+     __position = gradio.Slider(label=label, value=-1, minimum=-1, maximum=15, step=1, scale=1, interactive=True) # info='-1 to average on all tokens'
+     return {prefix + 'position_block': __position,}
+
+ def create_layer_selection_block(label: str='Layer', prefix: str='') -> dict:
+     __layer = gradio.Slider(label=label, value=-1, minimum=-1, maximum=23, step=1, scale=1, interactive=True) # info='-1 to average on all layers'
+     return {prefix + 'layer_block': __layer,}
+
+ # ACTIONS ######################################################################
+
+ def create_actions_block() -> dict:
+     __process = gradio.Button('Process', variant='primary', size='lg', scale=1, interactive=True)
+     return {'process_block': __process,}
+
+ # STATE ########################################################################
+
+ def create_state() -> dict:
+     return {
+         'output_state': gradio.State(None),
+         'hidden_state': gradio.State(None),}
+
+ # LAYOUT #######################################################################
+
+ def create_layout(intro: str=INTRO) -> dict:
+     __fields = {}
+     __fields.update(create_intro_block(intro=intro))
+     with gradio.Tabs():
+         with gradio.Tab('Residuals') as __main_tab:
+             __fields.update({'main_tab': __main_tab})
+             with gradio.Row(equal_height=True):
+                 __fields.update(create_inputs_block())
+             with gradio.Row(equal_height=True):
+                 __fields.update(create_highlight_block())
+             with gradio.Row(equal_height=True):
+                 __fields.update(create_plot_block(label='Left', prefix='left_'))
+                 __fields.update(create_plot_block(label='Right', prefix='right_'))
+             with gradio.Row(equal_height=True):
+                 __fields.update(create_highlight_block(label='Score', prefix='left_', cmap=create_score_cmap()))
+                 __fields.update(create_highlight_block(label='Score', prefix='right_', cmap=create_score_cmap()))
+             with gradio.Row(equal_height=True):
+                 __fields.update(create_token_selection_block(label='Token', prefix='left_'))
+                 __fields.update(create_token_selection_block(label='Token', prefix='right_'))
+             with gradio.Row(equal_height=True):
+                 __fields.update(create_layer_selection_block(label='Layer', prefix='left_'))
+                 __fields.update(create_layer_selection_block(label='Layer', prefix='right_'))
+             with gradio.Row(equal_height=True):
+                 __fields.update(create_actions_block())
+         with gradio.Tab('Settings') as __settings_tab:
+             __fields.update({'settings_tab': __settings_tab})
+             with gradio.Row(equal_height=True):
+                 __fields.update(create_model_block())
+             with gradio.Row(equal_height=True):
+                 __fields.update(create_sampling_block())
+             with gradio.Row(equal_height=True):
+                 __fields.update(create_visualization_block())
+     return __fields
+
+ # EVENTS #######################################################################
+
+ def update_position_range(
+         current_val: float,
+         token_num: float,
+         output_data: torch.Tensor,
+ ) -> dict:
+     # take the generated tokens into account
+     __max = int(token_num) - 1 if (output_data is None) else int(output_data.shape[-1])
+     # keep the previous value if possible
+     __val = min(int(current_val), __max)
+     # return a gradio update dictionary
+     return gradio.update(maximum=__max, value=__val)
+
+ # GENERATE #####################################################################
+
+ def update_computation_state(
+         token_num: float,
+         topk_num: float,
+         topp_num: float,
+         prompt_str: str,
+         device_str: str,
+         model_obj: object,
+         tokenizer_obj: object,
+ ) -> tuple:
+     # sanitize the inputs
+     __token_num = max(1, min(128, int(token_num)))
+     __topk_num = max(1, min(8, int(topk_num)))
+     __topp_num = max(0.0, min(1.0, float(topp_num)))
+     __prompt_str = prompt_str.strip()
+     __device_str = device_str if (device_str in ['cpu', 'cuda']) else 'cpu'
+     # exit if some values are missing
+     if (not __prompt_str) or (model_obj is None) or (tokenizer_obj is None):
+         return (torch.empty(0), torch.empty(0))
+     # dictionary {'input_ids': _, 'attention_mask': _}
+     __input_data = psaiops.common.tokenizer.preprocess_token_ids(
+         tokenizer_obj=tokenizer_obj,
+         prompt_str=__prompt_str,
+         device_str=__device_str)
+     # tensor (1, T) and O * L * (1, I, H)
+     __output_data, __hidden_data = psaiops.score.residual.lib.generate_token_ids(
+         model_obj=model_obj,
+         input_ids=__input_data['input_ids'],
+         attention_mask=__input_data['attention_mask'],
+         token_num=__token_num,
+         topk_num=__topk_num,
+         topp_num=__topp_num)
+     # tensor (1, L, I + O, H)
+     __hidden_data = psaiops.score.residual.lib.merge_hidden_states(
+         hidden_data=__hidden_data)
+     # update each component => (highlight, plot) states
+     return (
+         __output_data.cpu().float(),
+         __hidden_data.cpu().float(),)
+
+ # HIGHLIGHT ####################################################################
+
+ def update_token_focus(
+         left_idx: float,
+         right_idx: float,
+         output_data: torch.Tensor,
+         tokenizer_obj: object,
+ ) -> list:
+     # exit if some values are missing
+     if (output_data is None) or (len(output_data) == 0):
+         return None
+     # detokenize the IDs
+     __token_str = psaiops.common.tokenizer.postprocess_token_ids(
+         tokenizer_obj=tokenizer_obj,
+         token_data=output_data)
+     # list of string classes
+     __token_cls = psaiops.score.residual.lib.postprocess_focus_cls(
+         left_idx=int(left_idx),
+         right_idx=int(right_idx),
+         token_dim=len(__token_str))
+     # pairs of token and class
+     return list(zip(__token_str, __token_cls))
+
+ # SCORES #######################################################################
+
+ def update_token_scores(
+         layer_idx: float,
+         output_data: torch.Tensor,
+         hidden_data: torch.Tensor,
+         tokenizer_obj: object,
+         model_obj: object,
+ ) -> list:
+     # exit if some values are missing
+     if (output_data is None) or (len(output_data) == 0) or (hidden_data is None) or (len(hidden_data) == 0):
+         return None
+     # parse the model meta
+     __device_str = model_obj.lm_head.weight.device
+     __dtype_obj = model_obj.lm_head.weight.dtype
+     # detokenize the IDs
+     __token_str = psaiops.common.tokenizer.postprocess_token_ids(
+         tokenizer_obj=tokenizer_obj,
+         token_data=output_data)
+     # select the relevant hidden states
+     __final_states = hidden_data[0, -1, :, :].to(device=__device_str, dtype=__dtype_obj)
+     __layer_states = hidden_data[0, int(layer_idx), :, :].to(device=__device_str, dtype=__dtype_obj)
+     # compute the logits
+     __final_logits = model_obj.lm_head(__final_states).detach().cpu() # already normalized
+     __layer_logits = model_obj.lm_head(model_obj.model.norm(__layer_states)).detach().cpu()
+     # compute the JSD metric
+     __token_jsd = psaiops.score.residual.lib.jsd_from_logits(final_logits=__final_logits, prefix_logits=__layer_logits)
+     # scale into a [0; 100] label
+     __token_cls = psaiops.score.residual.lib.postprocess_score_cls(score_data=__token_jsd)
+     # color each token according to the distance between the distribution at layer L and the final distribution
+     return list(zip(__token_str, __token_cls))
+
+ # PLOT #########################################################################
+
+ def update_2d_plot(
+         token_idx: float,
+         layer_idx: float,
+         hidden_data: torch.Tensor,
+ ) -> tuple:
+     # reduce the layer and token axes (B, L, T, E) => (B, E)
+     __plot_data = psaiops.score.residual.lib.reduce_hidden_states(
+         hidden_data=hidden_data,
+         layer_idx=int(layer_idx),
+         token_idx=int(token_idx),
+         axes_idx=(1, 2))
+     # rescale the data to [-1; 1] (B, E)
+     __plot_data = psaiops.score.residual.lib.rescale_hidden_states(
+         hidden_data=__plot_data)
+     # reshape into a 3D tensor by folding E (B, E) => (B, W, H)
+     __plot_data = psaiops.score.residual.lib.reshape_hidden_states(
+         hidden_data=__plot_data,
+         layer_idx=-1) # there is no layer axis
+     # map the [-1; 1] activations to RGBA colors
+     __plot_data = psaiops.score.residual.lib.color_hidden_states(
+         hidden_data=__plot_data.numpy())
+     # plot the first sample
+     __figure = matplotlib.pyplot.figure()
+     __axes = __figure.add_subplot(1, 1, 1)
+     __axes.imshow(__plot_data[0], vmin=0.0, vmax=1.0, cmap='viridis')
+     __figure.tight_layout()
+     # remove the figure from the pyplot registry so it can be garbage collected
+     matplotlib.pyplot.close(__figure)
+     # return the matplotlib figure to the gradio plot component
+     return __figure
+
+ def update_3d_plot(
+         token_idx: float,
+         layer_idx: float,
+         points_num: float,
+         hidden_data: torch.Tensor,
+ ) -> tuple:
+     # reduce the token axis (B, L, T, E) => (B, L, E)
+     __plot_data = psaiops.score.residual.lib.reduce_hidden_states(
+         hidden_data=hidden_data,
+         token_idx=int(token_idx),
+         layer_idx=int(layer_idx),
+         axes_idx=2)
+     # rescale the data to [-1; 1] (B, L, E)
+     __plot_data = psaiops.score.residual.lib.rescale_hidden_states(
+         hidden_data=__plot_data)
+     # mask the small activations to improve the plot readability
+     __mask_data = psaiops.score.residual.lib.mask_hidden_states(
+         hidden_data=__plot_data,
+         topk_num=int(points_num) if int(layer_idx) == -1 else 2 * int(points_num))
+     # reshape into a 3D tensor by folding E (B, L, E) => (B, W, H, L)
+     __plot_data = psaiops.score.residual.lib.reshape_hidden_states(
+         hidden_data=__plot_data,
+         layer_idx=1)
+     __mask_data = psaiops.score.residual.lib.reshape_hidden_states(
+         hidden_data=__mask_data,
+         layer_idx=1)
+     # convert to numpy ndarrays
+     __plot_data = __plot_data.numpy()
+     __mask_data = __mask_data.numpy()
+     # map the [-1; 1] activations to RGBA colors
+     __rgb_data = psaiops.score.residual.lib.color_hidden_states(
+         hidden_data=__plot_data)
+     # map the [-1; 1] activations to point areas
+     __area_data = psaiops.score.residual.lib.size_hidden_states(
+         hidden_data=__plot_data,
+         area_min=0.01,
+         area_max=16.0,
+         gamma_val=1.6)
+     # format the first sample for a scatter plot
+     __x, __y, __z = numpy.nonzero(__mask_data[0])
+     __c = __rgb_data[0, __x, __y, __z]
+     __s = __area_data[0, __x, __y, __z]
+     # plot the first sample
+     __figure = matplotlib.pyplot.figure()
+     __axes = __figure.add_subplot(1, 1, 1, projection='3d')
+     __axes.scatter(__x, __y, __z, c=__c, s=__s, marker='o', linewidths=0)
+     __figure.tight_layout()
+     # remove the figure from the pyplot registry so it can be garbage collected
+     matplotlib.pyplot.close(__figure)
+     # return the matplotlib figure to the gradio plot component
+     return __figure
+
+ def update_hidden_plot(
+         token_idx: float,
+         layer_idx: float,
+         axes_num: float,
+         points_num: float,
+         hidden_data: torch.Tensor,
+ ) -> tuple:
+     # exit if some values are missing
+     if (hidden_data is None) or (len(hidden_data) == 0):
+         return None
+     # plot the residuals of a given layer on a 2D heatmap
+     if not axes_num: # 0.0 or 0
+         return update_2d_plot(
+             token_idx=token_idx,
+             layer_idx=layer_idx,
+             hidden_data=hidden_data)
+     # by default, plot the residuals for all the layers in 3D
+     return update_3d_plot(
+         token_idx=token_idx,
+         layer_idx=layer_idx,
+         points_num=points_num,
+         hidden_data=hidden_data)
+
+ # APP ##########################################################################
+
+ def create_app(title: str=TITLE, intro: str=INTRO, style: str=STYLE, model: str=MODEL) -> gradio.Blocks:
+     __fields = {}
+     with gradio.Blocks(theme=gradio.themes.Soft(), title=title, css=style) as __app:
+         # load the model
+         __device = 'cuda' if torch.cuda.is_available() else 'cpu'
+         __model = psaiops.common.model.get_model(name=model, device=__device)
+         __tokenizer = psaiops.common.tokenizer.get_tokenizer(name=model, device=__device)
+         # adapt the event handlers
+         __compute = functools.partial(update_computation_state, model_obj=__model, tokenizer_obj=__tokenizer, device_str=__device)
+         __highlight = functools.partial(update_token_focus, tokenizer_obj=__tokenizer)
+         __score = functools.partial(update_token_scores, tokenizer_obj=__tokenizer, model_obj=__model)
+         # create the UI
+         __fields.update(create_layout(intro=intro))
+         # init the state
+         __fields.update(create_state())
+         # update the data after clicking process
+         __fields['process_block'].click(
+             fn=__compute,
+             inputs=[__fields[__k] for __k in ['tokens_block', 'topk_block', 'topp_block', 'input_block']],
+             outputs=[__fields[__k] for __k in ['output_state', 'hidden_state']],
+             queue=False,
+             show_progress='full'
+         ).then(
+             # update the range of the position sliders when the output changes
+             fn=update_position_range,
+             inputs=[__fields[__k] for __k in ['left_position_block', 'tokens_block', 'output_state']],
+             outputs=__fields['left_position_block'],
+             queue=False,
+             show_progress='hidden'
+         ).then(
+             fn=update_position_range,
+             inputs=[__fields[__k] for __k in ['right_position_block', 'tokens_block', 'output_state']],
+             outputs=__fields['right_position_block'],
+             queue=False,
+             show_progress='hidden'
+         ).then(
+             # update the token highlight when the output data changes
+             fn=__highlight,
+             inputs=[__fields[__k] for __k in ['left_position_block', 'right_position_block', 'output_state']],
+             outputs=__fields['highlight_block'],
+             queue=False,
+             show_progress='hidden'
+         ).then(
+             # update the left token scores when the output data changes
+             fn=__score,
+             inputs=[__fields[__k] for __k in ['left_layer_block', 'output_state', 'hidden_state']],
+             outputs=__fields['left_highlight_block'],
+             queue=False,
+             show_progress='hidden'
+         ).then(
+             # update the right token scores when the output data changes
+             fn=__score,
+             inputs=[__fields[__k] for __k in ['right_layer_block', 'output_state', 'hidden_state']],
+             outputs=__fields['right_highlight_block'],
+             queue=False,
+             show_progress='hidden'
+         ).then(
+             # update the plot when the hidden data changes
+             fn=update_hidden_plot,
+             inputs=[__fields[__k] for __k in ['left_position_block', 'left_layer_block', 'axes_block', 'points_block', 'hidden_state']],
+             outputs=__fields['left_plot_block'],
+             queue=False,
+             show_progress='hidden'
+         ).then(
+             fn=update_hidden_plot,
+             inputs=[__fields[__k] for __k in ['right_position_block', 'right_layer_block', 'axes_block', 'points_block', 'hidden_state']],
+             outputs=__fields['right_plot_block'],
+             queue=False,
+             show_progress='hidden')
+         # update the range of the position slider when the settings change
+         __fields['tokens_block'].change(
+             fn=update_position_range,
+             inputs=[__fields[__k] for __k in ['left_position_block', 'tokens_block', 'output_state']],
+             outputs=__fields['left_position_block'],
+             queue=False,
+             show_progress='hidden'
+         ).then(
+             fn=update_position_range,
+             inputs=[__fields[__k] for __k in ['right_position_block', 'tokens_block', 'output_state']],
+             outputs=__fields['right_position_block'],
+             queue=False,
+             show_progress='hidden')
+         # update the left plot when the focus changes
+         __fields['left_position_block'].change(
+             fn=update_hidden_plot,
+             inputs=[__fields[__k] for __k in ['left_position_block', 'left_layer_block', 'axes_block', 'points_block', 'hidden_state']],
+             outputs=__fields['left_plot_block'],
+             queue=False,
+             show_progress='hidden')
+         __fields['left_layer_block'].change(
+             fn=update_hidden_plot,
+             inputs=[__fields[__k] for __k in ['left_position_block', 'left_layer_block', 'axes_block', 'points_block', 'hidden_state']],
+             outputs=__fields['left_plot_block'],
+             queue=False,
+             show_progress='hidden'
+         ).then(
+             # update the left token scores when the focus changes
+             fn=__score,
+             inputs=[__fields[__k] for __k in ['left_layer_block', 'output_state', 'hidden_state']],
+             outputs=__fields['left_highlight_block'],
+             queue=False,
+             show_progress='hidden'
+         )
+         # update the right plot when the focus changes
+         __fields['right_position_block'].change(
+             fn=update_hidden_plot,
+             inputs=[__fields[__k] for __k in ['right_position_block', 'right_layer_block', 'axes_block', 'points_block', 'hidden_state']],
+             outputs=__fields['right_plot_block'],
+             queue=False,
+             show_progress='hidden')
+         __fields['right_layer_block'].change(
+             fn=update_hidden_plot,
+             inputs=[__fields[__k] for __k in ['right_position_block', 'right_layer_block', 'axes_block', 'points_block', 'hidden_state']],
+             outputs=__fields['right_plot_block'],
+             queue=False,
+             show_progress='hidden'
+         ).then(
+             # update the right token scores when the focus changes
+             fn=__score,
+             inputs=[__fields[__k] for __k in ['right_layer_block', 'output_state', 'hidden_state']],
+             outputs=__fields['right_highlight_block'],
+             queue=False,
+             show_progress='hidden'
+         )
+         # update the token highlight when the token focus changes
+         __fields['left_position_block'].change(
+             fn=__highlight,
+             inputs=[__fields[__k] for __k in ['left_position_block', 'right_position_block', 'output_state']],
+             outputs=__fields['highlight_block'],
+             queue=False,
+             show_progress='hidden')
+         __fields['right_position_block'].change(
+             fn=__highlight,
+             inputs=[__fields[__k] for __k in ['left_position_block', 'right_position_block', 'output_state']],
+             outputs=__fields['highlight_block'],
+             queue=False,
+             show_progress='hidden')
+     # gradio application
+     return __app
+
+ # MAIN #########################################################################
+
+ if __name__ == '__main__':
+     __app = create_app()
+     __app.launch(share=True, debug=True)
@@ -0,0 +1,187 @@
+ import functools
+ import math
+
+ import matplotlib
+ import numpy
+ import torch
+
+ import mlable.shapes
+
+ # GENERATE #####################################################################
+
+ @functools.lru_cache(maxsize=32)
+ def generate_token_ids(
+         model_obj: object,
+         input_ids: torch.Tensor,
+         token_num: int,
+         topk_num: int=4,
+         topp_num: float=0.9,
+         attention_mask: torch.Tensor=None,
+ ) -> tuple:
+     # generate completion
+     with torch.no_grad():
+         __outputs = model_obj.generate(
+             input_ids=input_ids,
+             attention_mask=attention_mask,
+             max_new_tokens=token_num,
+             do_sample=(0.0 < topp_num < 1.0) or (topk_num > 0),
+             top_k=topk_num if (topk_num > 0) else None,
+             top_p=topp_num if (0.0 < topp_num < 1.0) else None,
+             return_dict_in_generate=True,
+             output_hidden_states=True,
+             output_attentions=False,
+             output_scores=False,
+             # early_stopping=True,
+             use_cache=True)
+     # ((B, T), O * L * (B, I, E))
+     return __outputs.sequences, __outputs.hidden_states
+
+ # MERGE ########################################################################
+
+ def merge_hidden_states(
+         hidden_data: tuple, # O * L * (B, I, E), as returned by generate_token_ids
+ ) -> torch.Tensor:
+     # parse the inputs
+     __token_dim = len(hidden_data)
+     __layer_dim = len(hidden_data[0])
+     # stack the data for each layer => (B, L, I + O, E)
+     return torch.stack(
+         [
+             # concatenate the data for all the tokens => (B, I + O, E)
+             torch.concatenate([hidden_data[__t][__l] for __t in range(__token_dim)], dim=1)
+             for __l in range(__layer_dim)],
+         dim=1)
+
+ # REDUCE #######################################################################
+
+ def reduce_hidden_states(
+         hidden_data: torch.Tensor, # (B, L, T, E)
+         layer_idx: int, # -1 => select all layers
+         token_idx: int, # -1 => select all tokens
+         axes_idx: int=2, # token sequence axis
+ ) -> torch.Tensor:
+     # parse the hidden states (B, L, T, E)
+     __batch_dim, __layer_dim, __token_dim, __hidden_dim = tuple(hidden_data.shape)
+     __layer_idx = min(layer_idx, __layer_dim - 1)
+     __token_idx = min(token_idx, __token_dim - 1)
+     # select the relevant data along each axis
+     __layer_slice = slice(0, __layer_dim) if (__layer_idx < 0) else slice(__layer_idx, __layer_idx + 1)
+     __token_slice = slice(0, __token_dim) if (__token_idx < 0) else slice(__token_idx, __token_idx + 1)
+     # filter the data
+     __data = hidden_data[slice(None), __layer_slice, __token_slice, slice(None)]
+     # reduce the token axis => (B, L, E)
+     return __data.mean(dim=axes_idx, keepdim=False)
+
+ # RESCALE ######################################################################
+
+ def rescale_hidden_states(
+         hidden_data: torch.Tensor, # (B, L, E) or (B, E)
+ ) -> torch.Tensor:
+     # compute the scale of the data, layer by layer
+     __s = torch.quantile(hidden_data.abs(), q=0.9, dim=-1, keepdim=True)
+     # log scaling on large values and linear near 0
+     __a = torch.asinh(hidden_data / (__s + torch.finfo().eps))
+     # clip and map to [-1; 1]
+     return 0.33 * __a.clamp(min=-3, max=3)
+
+ # RESHAPE ######################################################################
+
+ def reshape_hidden_states(
+         hidden_data: torch.Tensor, # (B, L, E) or (B, E)
+         layer_idx: int=1,
+ ) -> torch.Tensor:
+     # parse the shape
+     __shape = tuple(hidden_data.shape)
+     # factor the hidden dimension
+     __factor = 2 ** round(0.5 * math.log2(__shape[-1]))
+     # compute the shape with the last axis split
+     __shape = mlable.shapes.divide(shape=__shape, axis=-1, factor=__factor, insert=True, right=True)
+     # move the layer axis at the end
+     __perm = mlable.shapes.move(shape=range(len(__shape)), before=layer_idx, after=-1)
+     # reshape into (B, W, H, L) or (B, W, H)
+     return hidden_data.reshape(__shape).permute(*__perm)
+
+ # MASK #########################################################################
+
+ def mask_hidden_states(
+         hidden_data: torch.Tensor, # (B, L, E)
+         topk_num: int=128,
+ ) -> torch.Tensor:
+     # sanitize
+     __k = min(topk_num, int(hidden_data.shape[-1]))
+     # indices of the topk values
+     __indices = hidden_data.abs().topk(__k, dim=-1, largest=True, sorted=False).indices
+     # initialize the mask with False
+     __mask = torch.zeros_like(hidden_data, dtype=torch.bool)
+     # (B, L, E) mask of the topk values
+     return __mask.scatter_(dim=-1, index=__indices, value=True)
+
+ # FORMAT #######################################################################
+
+ def color_hidden_states(
+         hidden_data: numpy.ndarray, # (B, W, H, L)
+         color_map: callable=matplotlib.colormaps['coolwarm'],
+ ) -> numpy.ndarray:
+     # [-1; 1] => [0; 1]
+     __data = 0.5 * (hidden_data + 1.0)
+     # (B, W, H, L) => (B, W, H, L, 4)
+     __rgba = color_map(__data)
+     # (B, W, H, L, 3) in [0; 1]
+     return __rgba[..., :3]
+
+ def size_hidden_states(
+         hidden_data: numpy.ndarray, # (B, W, H, L)
+         area_min: float=0.01,
+         area_max: float=16.0,
+         gamma_val: float=1.6,
+ ) -> numpy.ndarray:
+     # [-1; 1] => [0; 1]
+     __data = numpy.abs(hidden_data)
+     # gamma < 1 boosts small values, gamma > 1 emphasizes larger ones
+     __data = (__data + numpy.finfo(numpy.float32).eps) ** gamma_val
+     # map to point area
+     return area_min + (area_max - area_min) * __data
+
+ # KL SCORES ####################################################################
+
+ def kl_from_logprobs(
+         p_log: torch.Tensor,
+         q_log: torch.Tensor,
+ ) -> torch.Tensor:
+     # compute the KL div from log probabilities (B, T, E) or (T, E)
+     return (p_log.exp() * (p_log - q_log)).sum(dim=-1)
+
+ def jsd_from_logits(
+         final_logits: torch.Tensor,
+         prefix_logits: torch.Tensor,
+ ) -> torch.Tensor:
+     # compute the log probs from logits (B, T, E) or (T, E)
+     __p = torch.log_softmax(final_logits.float(), dim=-1)
+     __q = torch.log_softmax(prefix_logits.float(), dim=-1)
+     # m = 0.5 * (p + q) in log-space (logsumexp trick)
+     __m = torch.logsumexp(torch.stack([__p, __q], dim=0), dim=0) - math.log(2.0)
+     # compute the JSD metric
+     __jsd = 0.5 * kl_from_logprobs(__p, __m) + 0.5 * kl_from_logprobs(__q, __m)
+     # scale to [0; 1]
+     return (__jsd / math.log(2.0)).clamp(0.0, 1.0)
+
+ # POSTPROCESS ##################################################################
+
+ def postprocess_focus_cls(
+         left_idx: int,
+         right_idx: int,
+         token_dim: int,
+ ) -> list:
+     __left_idx = max(-1, min(token_dim, left_idx))
+     __right_idx = max(-1, min(token_dim, right_idx))
+     # class 1 for the token(s) focused on the left, 0 for the rest
+     __left_cls = token_dim * [1] if (__left_idx < 0) else [int(__i == __left_idx) for __i in range(token_dim)]
+     # class 2 for the token(s) focused on the right, 0 for the rest
+     __right_cls = token_dim * [2] if (__right_idx < 0) else [2 * int(__i == __right_idx) for __i in range(token_dim)]
+     # sum the classes so that the overlap has class 3
+     return [str(__l + __r) for __l, __r in zip(__left_cls, __right_cls)]
+
+ def postprocess_score_cls(
+         score_data: torch.Tensor,
+ ) -> list:
+     # map each score in [0; 1] to an integer label in [0; 100], matching the keys of the score cmap
+     return [str(round(100.0 * float(__s))) for __s in score_data.numpy().tolist()]