psaiops 0.3.2__py3-none-any.whl → 0.4.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- psaiops/combine/app.py +366 -0
- psaiops/common/data.py +8 -0
- psaiops/common/model.py +45 -0
- psaiops/common/tokenizer.py +41 -0
- psaiops/compose/contrast/app.py +4 -2
- psaiops/compose/contrast/lib.py +3 -38
- psaiops/compose/maths/app.py +8 -5
- psaiops/compose/maths/lib.py +0 -41
- psaiops/score/attention/app.py +8 -6
- psaiops/score/attention/lib.py +8 -83
- psaiops/score/residual/app.py +291 -0
- psaiops/score/residual/lib.py +134 -0
- psaiops/score/router/app.py +161 -13
- psaiops/score/router/lib.py +50 -57
- psaiops/score/shapley/app.py +5 -4
- psaiops/score/shapley/lib.py +0 -65
- psaiops/score/similarity/__init__.py +0 -0
- psaiops/score/similarity/app.py +152 -0
- psaiops/score/similarity/lib.py +1 -0
- {psaiops-0.3.2.dist-info → psaiops-0.4.2.dist-info}/METADATA +14 -19
- psaiops-0.4.2.dist-info/RECORD +36 -0
- {psaiops-0.3.2.dist-info → psaiops-0.4.2.dist-info}/WHEEL +1 -1
- psaiops-0.4.2.dist-info/licenses/.github/LICENSE.md +661 -0
- psaiops/common/dropdown.py +0 -19
- psaiops-0.3.2.dist-info/RECORD +0 -28
- /psaiops/{steer → score/residual}/__init__.py +0 -0
psaiops/score/attention/lib.py
CHANGED
```diff
@@ -1,69 +1,7 @@
-import functools
-
 import torch
-import transformers
-
-import deformers.models.openai.gptoss
-
-# LOAD #########################################################################
-
-@functools.lru_cache(maxsize=4)
-def get_tokenizer(name: str, device: str='cpu'):
-    return transformers.AutoTokenizer.from_pretrained(
-        name,
-        use_fast=True,
-        dtype='auto',
-        device_map=device)
-
-@functools.lru_cache(maxsize=2)
-def get_model(name: str, device: str='cpu'):
-    __model = deformers.models.openai.gptoss.GptOssForCausalInference.from_pretrained(
-        name,
-        dtype='auto',
-        device_map=device)
-    # toggle the inference mode (not training)
-    __model.eval()
-    # transformers model
-    return __model
 
-
-
-@functools.lru_cache(maxsize=4)
-def preprocess_token_ids(
-    tokenizer_obj: object,
-    prompt_str: str,
-    device_str: str='cpu'
-) -> dict:
-    # tokenize
-    __inputs = tokenizer_obj(prompt_str, return_tensors='pt')
-    # move to the main device
-    return {__k: __v.to(device_str) for __k, __v in __inputs.items()}
-
-# GENERATE #######################################################################
-
-def generate_token_ids(
-    model_obj: object,
-    input_args: dict,
-    token_num: int,
-    topk_num: int = 4,
-    topp_num: float = 0.9,
-) -> torch.Tensor:
-    # generate completion
-    with torch.no_grad():
-        __outputs = model_obj.generate(
-            **input_args,
-            max_new_tokens=token_num,
-            do_sample=(0.0 < topp_num < 1.0) or (topk_num > 0),
-            top_k=topk_num if (topk_num > 0) else None,
-            top_p=topp_num if (0.0 < topp_num < 1.0) else None,
-            return_dict_in_generate=True,
-            output_hidden_states=False,
-            output_attentions=False,
-            output_scores=False,
-            # early_stopping=True,
-            use_cache=True)
-    # full sequence
-    return __outputs.sequences # (1, T)
+import psaiops.common.model
+import psaiops.common.tokenizer
 
 # COMPUTE ########################################################################
 
@@ -91,8 +29,8 @@ def reduce_attention_weights(
 ) -> torch.Tensor:
     # parse
     __layer_dim, __batch_dim, __head_dim, __output_dim, __output_dim = tuple(attention_data.shape) # L, B, H, T, T
-    __layer_idx = min(layer_idx, __layer_dim)
-    __head_idx = min(head_idx, __head_dim)
+    __layer_idx = min(layer_idx, __layer_dim - 1)
+    __head_idx = min(head_idx, __head_dim - 1)
     __token_idx = min(token_idx, __output_dim - input_dim - 1) # T = I + O
     # select the relevant data along each axis
     __layer_slice = slice(None) if (__layer_idx < 0) else slice(__layer_idx, __layer_idx + 1)
@@ -127,19 +65,6 @@ def postprocess_attention_scores(
     # native list of serialized integers
     return [str(__i) for __i in __input_scores.tolist() + __output_scores.tolist()] # (I,) + (O,) = (T,)
 
-# POSTPROCESS ####################################################################
-
-def postprocess_token_ids(
-    tokenizer_obj: object,
-    token_obj: torch.Tensor,
-) -> list:
-    # remove the batch axis
-    __indices = token_obj.squeeze().tolist()
-    # back to token strings
-    __tokens = tokenizer_obj.convert_ids_to_tokens(__indices)
-    # normalize the tokens
-    return [__t.replace(chr(0x0120), ' ').replace(chr(0x010a), '\n') for __t in __tokens]
-
 # COMPUTE ########################################################################
 
 def score_tokens(
@@ -155,16 +80,16 @@ def score_tokens(
     tokenizer_obj: object,
 ) -> list:
     # dictionary {'input_ids': _, 'attention_mask': _}
-    __inputs = preprocess_token_ids(
+    __inputs = psaiops.common.tokenizer.preprocess_token_ids(
        tokenizer_obj=tokenizer_obj,
         prompt_str=prompt_str,
         device_str=device_str)
     # parse the inputs
     __input_dim = int(__inputs['input_ids'].shape[-1])
     # tensor (1, T)
-    __outputs = generate_token_ids(
+    __outputs = psaiops.common.tokenizer.model.generate_token_ids(
         model_obj=model_obj,
-
+        input_ids=__inputs['input_ids'],
         token_num=token_num,
         topk_num=topk_num,
         topp_num=topp_num)
@@ -185,7 +110,7 @@ def score_tokens(
         input_dim=__input_dim,
         token_idx=token_idx)
     # detokenize the IDs
-    __tokens = postprocess_token_ids(
+    __tokens = psaiops.common.tokenizer.postprocess_token_ids(
         tokenizer_obj=tokenizer_obj,
         token_obj=__outputs)
     # match tokens and labels for the HighlightedText field
```
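The loading, preprocessing, and detokenization helpers removed above now live in `psaiops.common.model` and `psaiops.common.tokenizer`, and `score_tokens` calls them through those modules. A minimal sketch of how the shared helpers compose, assuming only the signatures visible in this diff (the prompt string is illustrative):

```python
import torch

import psaiops.common.model
import psaiops.common.tokenizer

# pick a device and load the cached model / tokenizer pair
device = 'cuda' if torch.cuda.is_available() else 'cpu'
model = psaiops.common.model.get_model(name='openai/gpt-oss-20b', device=device)
tokenizer = psaiops.common.tokenizer.get_tokenizer(name='openai/gpt-oss-20b', device=device)

# tokenize a prompt into {'input_ids': ..., 'attention_mask': ...} on the device
inputs = psaiops.common.tokenizer.preprocess_token_ids(
    tokenizer_obj=tokenizer,
    prompt_str='What is the capital of France?',
    device_str=device)

# map token ids back to readable strings, as score_tokens does after generation
tokens = psaiops.common.tokenizer.postprocess_token_ids(
    tokenizer_obj=tokenizer,
    token_obj=inputs['input_ids'])
```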
psaiops/score/residual/app.py
ADDED

```diff
@@ -0,0 +1,291 @@
+import functools
+
+import gradio
+import torch
+import torch.cuda
+import matplotlib.pyplot
+
+import psaiops.common.model
+import psaiops.common.tokenizer
+import psaiops.score.residual.lib
+
+# META #########################################################################
+
+STYLE = '''.white-text span { color: white; }'''
+TITLE = '''Router Scoring'''
+INTRO = '''Plot the logits of the router for a given prompt.\nUnder construction, only "openai/gpt-oss-20b" is available for now.'''
+
+MODEL = 'openai/gpt-oss-20b'
+
+# COLORS #######################################################################
+
+def create_color_map() -> dict:
+    return {
+        '0': '#000000',
+        '1': '#004444',}
+
+# INTRO ########################################################################
+
+def create_intro_block(intro: str) -> dict:
+    __intro = gradio.Markdown(intro, line_breaks=True)
+    return {'intro_block': __intro}
+
+# MODEL ########################################################################
+
+def create_model_block() -> dict:
+    __model = gradio.Dropdown(label='Model', value='openai/gpt-oss-20b', choices=['openai/gpt-oss-20b'], scale=1, allow_custom_value=False, multiselect=False, interactive=True) # 'openai/gpt-oss-120b'
+    return {'model_block': __model,}
+
+# SAMPLING #####################################################################
+
+def create_sampling_block() -> dict:
+    __tokens = gradio.Slider(label='Tokens', value=16, minimum=1, maximum=128, step=1, scale=1, interactive=True)
+    __topk = gradio.Slider(label='Top K', value=4, minimum=1, maximum=8, step=1, scale=1, interactive=True)
+    __topp = gradio.Slider(label='Top P', value=0.9, minimum=0.0, maximum=1.0, step=0.1, scale=1, interactive=True)
+    return {
+        'tokens_block': __tokens,
+        'topk_block': __topk,
+        'topp_block': __topp,}
+
+# INPUTS #######################################################################
+
+def create_inputs_block() -> dict:
+    __input = gradio.Textbox(label='Prompt', value='', placeholder='A string of tokens to score.', lines=4, scale=1, show_copy_button=True, interactive=True)
+    return {'input_block': __input}
+
+# PLOTS ########################################################################
+
+def create_plot_block() -> dict:
+    __plot = gradio.Plot(label='Router', scale=1)
+    return {'plot_block': __plot,}
+
+# OUTPUTS ######################################################################
+
+def create_outputs_block() -> dict:
+    __output = gradio.HighlightedText(label='Output', value='', scale=1, interactive=False, show_legend=False, show_inline_category=False, combine_adjacent=False, color_map=create_color_map(), elem_classes='white-text')
+    return {'output_block': __output}
+
+# SELECT #######################################################################
+
+def create_selection_block() -> dict:
+    # __play = gradio.Button('>', variant='primary', size='lg', scale=1, interactive=True)
+    __position = gradio.Slider(label='Token', value=-1, minimum=-1, maximum=15, step=1, scale=1, interactive=True) # info='-1 to average on all tokens'
+    return {'position_block': __position,}
+
+# ACTIONS ######################################################################
+
+def create_actions_block() -> dict:
+    __process = gradio.Button('Process', variant='primary', size='lg', scale=1, interactive=True)
+    return {'process_block': __process,}
+
+# STATE ########################################################################
+
+def create_state() -> dict:
+    return {
+        'output_state': gradio.State(None),
+        'hidden_state': gradio.State(None),}
+
+# LAYOUT #######################################################################
+
+def create_layout(intro: str=INTRO) -> dict:
+    __fields = {}
+    __fields.update(create_intro_block(intro=intro))
+    with gradio.Tabs():
+        with gradio.Tab('Score Tokens') as __main_tab:
+            __fields.update({'main_tab': __main_tab})
+            with gradio.Row(equal_height=True):
+                __fields.update(create_inputs_block())
+            with gradio.Row(equal_height=True):
+                __fields.update(create_plot_block())
+            with gradio.Row(equal_height=True):
+                __fields.update(create_outputs_block())
+            with gradio.Row(equal_height=True):
+                __fields.update(create_selection_block())
+            with gradio.Row(equal_height=True):
+                __fields.update(create_actions_block())
+        with gradio.Tab('Settings') as __settings_tab:
+            __fields.update({'settings_tab': __settings_tab})
+            with gradio.Column(scale=1):
+                with gradio.Row(equal_height=True):
+                    __fields.update(create_model_block())
+                with gradio.Row(equal_height=True):
+                    __fields.update(create_sampling_block())
+    return __fields
+
+# EVENTS #######################################################################
+
+def update_position_range(
+    current_val: float,
+    token_num: float,
+    output_data: torch.Tensor,
+) -> dict:
+    # take the generated tokens into account
+    __max = int(token_num) - 1 if (output_data is None) else int(output_data.shape[-1])
+    # keep the previous value if possible
+    __val = min(int(current_val), __max)
+    # return a gradio update dictionary
+    return gradio.update(maximum=__max, value=__val)
+
+def update_computation_state(
+    token_num: float,
+    topk_num: float,
+    topp_num: float,
+    token_idx: float,
+    prompt_str: str,
+    device_str: str,
+    model_obj: object,
+    tokenizer_obj: object,
+) -> tuple:
+    # sanitize the inputs
+    __token_num = max(1, min(128, int(token_num)))
+    __topk_num = max(1, min(8, int(topk_num)))
+    __topp_num = max(0.0, min(1.0, float(topp_num)))
+    __token_idx = max(-1, min(__token_num, int(token_idx)))
+    __prompt_str = prompt_str.strip()
+    __device_str = device_str if (device_str in ['cpu', 'cuda']) else 'cpu'
+    # exit if some values are missing
+    if (not __prompt_str) or (model_obj is None) or (tokenizer_obj is None):
+        return (torch.empty(0), torch.empty(0))
+    # dictionary {'input_ids': _, 'attention_mask': _}
+    __input_data = psaiops.common.tokenizer.preprocess_token_ids(
+        tokenizer_obj=tokenizer_obj,
+        prompt_str=__prompt_str,
+        device_str=__device_str)
+    # tensor (1, T) and O * L * (1, I, H)
+    __output_data, __hidden_data = psaiops.score.residual.lib.generate_token_ids(
+        model_obj=model_obj,
+        input_ids=__input_data['input_ids'],
+        token_num=__token_num,
+        topk_num=__topk_num,
+        topp_num=__topp_num)
+    # tensor (1, L, I + O, H)
+    __hidden_data = psaiops.score.residual.lib.merge_hidden_states(
+        hidden_data=__hidden_data)
+    # update each component => (highlight, plot) states
+    return (
+        __output_data.cpu().float(),
+        __hidden_data.cpu().float(),)
+
+def update_hidden_plot(
+    token_idx: float,
+    hidden_data: torch.Tensor,
+) -> tuple:
+    # exit if some values are missing
+    if (hidden_data is None) or (len(hidden_data) == 0):
+        return None
+    # reduce the token axis (B, L, T, E) => (B, L, E)
+    __plot_data = psaiops.score.residual.lib.reduce_hidden_states(
+        hidden_data=hidden_data,
+        token_idx=int(token_idx),)
+    # rescale the data to [-1; 1] (B, L, E)
+    __plot_data = psaiops.score.residual.lib.rescale_hidden_states(
+        hidden_data=__plot_data)
+    # reshape into a 3D tensor by folding E (B, L, E) => (B, W, H, L)
+    __plot_data = psaiops.score.residual.lib.reshape_hidden_states(
+        hidden_data=__plot_data)
+    # mask the small activations to improve the plot readability
+    __mask_data = psaiops.score.residual.lib.mask_hidden_states(
+        hidden_data=__plot_data,
+        topk_num=128).numpy()
+    # map the [-1; 1] activations to RGBA colors
+    __plot_data = psaiops.score.residual.lib.color_hidden_states(
+        hidden_data=__plot_data.numpy())
+    # plot the first sample
+    __figure = matplotlib.pyplot.figure()
+    __axes = __figure.add_subplot(1, 1, 1, projection='3d')
+    __axes.voxels(filled=__mask_data[0], facecolors=__plot_data[0], edgecolor=None)
+    __figure.tight_layout()
+    # remove the figure for the pyplot register for garbage collection
+    matplotlib.pyplot.close(__figure)
+    # update each component => (highlight, plot) states
+    return __figure
+
+def update_text_highlight(
+    token_idx: float,
+    output_data: torch.Tensor,
+    tokenizer_obj: object,
+) -> list:
+    # exit if some values are missing
+    if (output_data is None) or (len(output_data) == 0):
+        return None
+    # detokenize the IDs
+    __token_str = psaiops.common.tokenizer.postprocess_token_ids(
+        tokenizer_obj=tokenizer_obj,
+        token_data=output_data)
+    # list of string classes
+    __token_cls = psaiops.score.residual.lib.postprocess_token_cls(
+        token_idx=int(token_idx),
+        token_dim=len(__token_str))
+    # pairs of token and class
+    return list(zip(__token_str, __token_cls))
+
+# APP ##########################################################################
+
+def create_app(title: str=TITLE, intro: str=INTRO, style: str=STYLE, model: str=MODEL) -> gradio.Blocks:
+    __fields = {}
+    with gradio.Blocks(theme=gradio.themes.Soft(), title=title, css=style) as __app:
+        # load the model
+        __device = 'cuda' if torch.cuda.is_available() else 'cpu'
+        __model = psaiops.common.model.get_model(name=model, device=__device)
+        __tokenizer = psaiops.common.tokenizer.get_tokenizer(name=model, device=__device)
+        # adapt the event handlers
+        __compute = functools.partial(update_computation_state, model_obj=__model, tokenizer_obj=__tokenizer, device_str=__device)
+        __highlight = functools.partial(update_text_highlight, tokenizer_obj=__tokenizer)
+        # create the UI
+        __fields.update(create_layout(intro=intro))
+        # init the state
+        __fields.update(create_state())
+        # update the data after clicking process
+        __fields['process_block'].click(
+            fn=__compute,
+            inputs=[__fields[__k] for __k in ['tokens_block', 'topk_block', 'topp_block', 'position_block', 'input_block']],
+            outputs=[__fields[__k] for __k in ['output_state', 'hidden_state']],
+            queue=False,
+            show_progress='full').then(
+            # update the range of the position slider when the output changes
+            fn=update_position_range,
+            inputs=[__fields[__k] for __k in ['position_block', 'tokens_block', 'output_state']],
+            outputs=__fields['position_block'],
+            queue=False,
+            show_progress='hidden').then(
+            # update the token highlight when the output data changes
+            fn=__highlight,
+            inputs=[__fields[__k] for __k in ['position_block', 'output_state']],
+            outputs=__fields['output_block'],
+            queue=False,
+            show_progress='full').then(
+            # update the plot when the router data changes
+            fn=update_hidden_plot,
+            inputs=[__fields[__k] for __k in ['position_block', 'hidden_state']],
+            outputs=__fields['plot_block'],
+            queue=False,
+            show_progress='full')
+        # update the range of the position slider when the settings change
+        __fields['tokens_block'].change(
+            fn=update_position_range,
+            inputs=[__fields[__k] for __k in ['position_block', 'tokens_block', 'output_state']],
+            outputs=__fields['position_block'],
+            queue=False,
+            show_progress='hidden')
+        # update the plot when the focus changes
+        __fields['position_block'].change(
+            fn=update_hidden_plot,
+            inputs=[__fields[__k] for __k in ['position_block', 'hidden_state']],
+            outputs=__fields['plot_block'],
+            queue=False,
+            show_progress='full')
+        # update the token highlight when the token focus changes
+        __fields['position_block'].change(
+            fn=__highlight,
+            inputs=[__fields[__k] for __k in ['position_block', 'output_state']],
+            outputs=__fields['output_block'],
+            queue=False,
+            show_progress='hidden')
+    # gradio application
+    return __app
+
+# MAIN #########################################################################
+
+if __name__ == '__main__':
+    __app = create_app()
+    __app.launch(share=True, debug=True)
```
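The whole page hangs off one Process click: it fills the output and hidden states, then chains dependent refreshes for the slider range, the token highlight, and the voxel plot. A stripped-down sketch of that `.click(...).then(...)` wiring, with stand-in functions in place of the model calls:

```python
import gradio

def compute(prompt: str) -> str:
    # stand-in for the heavy generation step that fills the shared state
    return prompt.upper()

def render(state: str) -> str:
    # stand-in for the plot / highlight refreshes that read the state
    return f'rendered: {state}'

with gradio.Blocks() as demo:
    prompt = gradio.Textbox(label='Prompt')
    state = gradio.State(None)
    output = gradio.Textbox(label='Output')
    button = gradio.Button('Process')
    # one click recomputes the state, then the dependent view refreshes in order
    button.click(fn=compute, inputs=prompt, outputs=state).then(
        fn=render, inputs=state, outputs=output)

if __name__ == '__main__':
    demo.launch()
```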
psaiops/score/residual/lib.py
ADDED

```diff
@@ -0,0 +1,134 @@
+import functools
+import math
+
+import matplotlib
+import numpy
+import torch
+
+# GENERATE #######################################################################
+
+@functools.lru_cache(maxsize=32)
+def generate_token_ids(
+    model_obj: object,
+    input_ids: torch.Tensor,
+    token_num: int,
+    topk_num: int = 4,
+    topp_num: float = 0.9,
+) -> tuple:
+    # generate completion
+    with torch.no_grad():
+        __outputs = model_obj.generate(
+            input_ids=input_ids,
+            max_new_tokens=token_num,
+            do_sample=(0.0 < topp_num < 1.0) or (topk_num > 0),
+            top_k=topk_num if (topk_num > 0) else None,
+            top_p=topp_num if (0.0 < topp_num < 1.0) else None,
+            return_dict_in_generate=True,
+            output_hidden_states=True,
+            output_attentions=False,
+            output_scores=False,
+            early_stopping=True,
+            use_cache=True)
+    # ((B, T), O * L * (B, I, H))
+    return __outputs.sequences, __outputs.hidden_states
+
+# MERGE ########################################################################
+
+def merge_hidden_states(
+    hidden_data: torch.Tensor,
+) -> torch.Tensor:
+    # parse the inputs
+    __token_dim = len(hidden_data)
+    __layer_dim = len(hidden_data[0])
+    # stack the data for each layer => (B, L, I + O, H)
+    return torch.stack(
+        [
+            # concatenate the data for all the tokens => (B, I + O, H)
+            torch.concatenate([hidden_data[__t][__l] for __t in range(__token_dim)], dim=1)
+            for __l in range(__layer_dim)],
+        dim=1)
+
+# REDUCE #######################################################################
+
+def reduce_hidden_states(
+    hidden_data: torch.Tensor,
+    token_idx: int, # -1 => avg over all tokens
+) -> torch.Tensor:
+    # parse the hidden states (B, L, T, H)
+    __batch_dim, __layer_dim, __token_dim, __hidden_dim = tuple(hidden_data.shape)
+    __token_idx = min(token_idx, __token_dim - 1)
+    # select the relevant data along each axis
+    __token_slice = slice(0, __token_dim) if (__token_idx < 0) else slice(__token_idx, __token_idx + 1)
+    # filter the data
+    __data = hidden_data[slice(None), slice(None), __token_slice, slice(None)]
+    # reduce the token axis => (B, L, H)
+    return __data.mean(dim=2, keepdim=False)
+
+# RESCALE ######################################################################
+
+def rescale_hidden_states(
+    hidden_data: torch.Tensor, # (B, L, H)
+) -> torch.Tensor:
+    # compute the scale of the data, layer by layer
+    __s = torch.quantile(hidden_data.abs(), q=0.9, dim=-1, keepdim=True)
+    # log scaling on large values and linear near 0
+    __a = torch.asinh(hidden_data / (__s + torch.finfo().eps))
+    # clip and map to [-1; 1]
+    return 0.33 * __a.clamp(min=-3, max=3)
+
+# RESHAPE ######################################################################
+
+def reshape_hidden_states(
+    hidden_data: torch.Tensor, # (B, L, H)
+) -> torch.Tensor:
+    # parse the hidden states (B, L, H)
+    __batch_dim, __layer_dim, __hidden_dim = tuple(hidden_data.shape)
+    # factor the hidden dimension
+    __width_dim = math.gcd(__hidden_dim, 2 ** int(math.log2(__hidden_dim))) # greatest power of 2 that divides H
+    __height_dim = __hidden_dim // __width_dim
+    # reshape into (B, W, H, L)
+    return hidden_data.reshape((__batch_dim, __layer_dim, __width_dim, __height_dim)).permute(0, 2, 3, 1)
+
+# MASK #########################################################################
+
+def mask_hidden_states(
+    hidden_data: torch.Tensor, # (B, L, H)
+    topk_num: int=128,
+) -> torch.Tensor:
+    # sanitize
+    __k = min(topk_num, int(hidden_data.shape[-1]))
+    # indices of the topk values
+    __indices = hidden_data.abs().topk(__k, dim=-1, largest=True, sorted=False).indices
+    # initialize the mask with False
+    __mask = torch.zeros_like(hidden_data, dtype=torch.bool)
+    # (B, L, H) mask of the topk values
+    return __mask.scatter_(dim=-1, index=__indices, value=True)
+
+# FORMAT #######################################################################
+
+def color_hidden_states(
+    hidden_data: numpy.array, # (B, H, W, L)
+    gamma_val: float=0.7,
+    alpha_val: float=0.35,
+    color_map: callable=matplotlib.colormaps['coolwarm'],
+) -> list:
+    # [-1; 1] => [0; 1]
+    __data = 0.5 * (hidden_data + 1.0)
+    # (B, W, H, L) => (B, W, H, L, 4)
+    __rgba = color_map(__data)
+    # compute the transparency from the magnitude
+    __rgba[..., 3] = alpha_val * (numpy.abs(hidden_data) ** gamma_val)
+    # (B, W, H, L, 4) in [0; 1]
+    return __rgba
+
+# POSTPROCESS ##################################################################
+
+def postprocess_token_cls(
+    token_idx: int,
+    token_dim: int,
+) -> list:
+    __token_idx = max(-1, min(token_dim, token_idx))
+    # class 1 for the focused token(s) 0 for the rest
+    __token_cls = [str(int(__i == token_idx)) for __i in range(token_dim)]
+    # average on all the tokens when the idx is negative
+    return token_dim * ['1'] if (token_idx < 0) else __token_cls
```