psaiops 0.3.2__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,69 +1,7 @@
1
- import functools
2
-
3
1
  import torch
4
- import transformers
5
-
6
- import deformers.models.openai.gptoss
7
-
8
- # LOAD #########################################################################
9
-
10
- @functools.lru_cache(maxsize=4)
11
- def get_tokenizer(name: str, device: str='cpu'):
12
- return transformers.AutoTokenizer.from_pretrained(
13
- name,
14
- use_fast=True,
15
- dtype='auto',
16
- device_map=device)
17
-
18
- @functools.lru_cache(maxsize=2)
19
- def get_model(name: str, device: str='cpu'):
20
- __model = deformers.models.openai.gptoss.GptOssForCausalInference.from_pretrained(
21
- name,
22
- dtype='auto',
23
- device_map=device)
24
- # toggle the inference mode (not training)
25
- __model.eval()
26
- # transformers model
27
- return __model
28
2
 
29
- # PREPROCESS #####################################################################
30
-
31
- @functools.lru_cache(maxsize=4)
32
- def preprocess_token_ids(
33
- tokenizer_obj: object,
34
- prompt_str: str,
35
- device_str: str='cpu'
36
- ) -> dict:
37
- # tokenize
38
- __inputs = tokenizer_obj(prompt_str, return_tensors='pt')
39
- # move to the main device
40
- return {__k: __v.to(device_str) for __k, __v in __inputs.items()}
41
-
42
- # GENERATE #######################################################################
43
-
44
- def generate_token_ids(
45
- model_obj: object,
46
- input_args: dict,
47
- token_num: int,
48
- topk_num: int = 4,
49
- topp_num: float = 0.9,
50
- ) -> torch.Tensor:
51
- # generate completion
52
- with torch.no_grad():
53
- __outputs = model_obj.generate(
54
- **input_args,
55
- max_new_tokens=token_num,
56
- do_sample=(0.0 < topp_num < 1.0) or (topk_num > 0),
57
- top_k=topk_num if (topk_num > 0) else None,
58
- top_p=topp_num if (0.0 < topp_num < 1.0) else None,
59
- return_dict_in_generate=True,
60
- output_hidden_states=False,
61
- output_attentions=False,
62
- output_scores=False,
63
- # early_stopping=True,
64
- use_cache=True)
65
- # full sequence
66
- return __outputs.sequences # (1, T)
3
+ import psaiops.common.model
4
+ import psaiops.common.tokenizer
67
5
 
68
6
  # COMPUTE ########################################################################
69
7
 
@@ -91,8 +29,8 @@ def reduce_attention_weights(
91
29
  ) -> torch.Tensor:
92
30
  # parse
93
31
  __layer_dim, __batch_dim, __head_dim, __output_dim, __output_dim = tuple(attention_data.shape) # L, B, H, T, T
94
- __layer_idx = min(layer_idx, __layer_dim)
95
- __head_idx = min(head_idx, __head_dim)
32
+ __layer_idx = min(layer_idx, __layer_dim - 1)
33
+ __head_idx = min(head_idx, __head_dim - 1)
96
34
  __token_idx = min(token_idx, __output_dim - input_dim - 1) # T = I + O
97
35
  # select the relevant data along each axis
98
36
  __layer_slice = slice(None) if (__layer_idx < 0) else slice(__layer_idx, __layer_idx + 1)
@@ -127,19 +65,6 @@ def postprocess_attention_scores(
127
65
  # native list of serialized integers
128
66
  return [str(__i) for __i in __input_scores.tolist() + __output_scores.tolist()] # (I,) + (O,) = (T,)
129
67
 
130
- # POSTPROCESS ####################################################################
131
-
132
- def postprocess_token_ids(
133
- tokenizer_obj: object,
134
- token_obj: torch.Tensor,
135
- ) -> list:
136
- # remove the batch axis
137
- __indices = token_obj.squeeze().tolist()
138
- # back to token strings
139
- __tokens = tokenizer_obj.convert_ids_to_tokens(__indices)
140
- # normalize the tokens
141
- return [__t.replace(chr(0x0120), ' ').replace(chr(0x010a), '\n') for __t in __tokens]
142
-
143
68
  # COMPUTE ########################################################################
144
69
 
145
70
  def score_tokens(
@@ -155,16 +80,16 @@ def score_tokens(
155
80
  tokenizer_obj: object,
156
81
  ) -> list:
157
82
  # dictionary {'input_ids': _, 'attention_mask': _}
158
- __inputs = preprocess_token_ids(
83
+ __inputs = psaiops.common.tokenizer.preprocess_token_ids(
159
84
  tokenizer_obj=tokenizer_obj,
160
85
  prompt_str=prompt_str,
161
86
  device_str=device_str)
162
87
  # parse the inputs
163
88
  __input_dim = int(__inputs['input_ids'].shape[-1])
164
89
  # tensor (1, T)
165
- __outputs = generate_token_ids(
90
+ __outputs = psaiops.common.tokenizer.model.generate_token_ids(
166
91
  model_obj=model_obj,
167
- input_args=__inputs,
92
+ input_ids=__inputs['input_ids'],
168
93
  token_num=token_num,
169
94
  topk_num=topk_num,
170
95
  topp_num=topp_num)
@@ -185,7 +110,7 @@ def score_tokens(
185
110
  input_dim=__input_dim,
186
111
  token_idx=token_idx)
187
112
  # detokenize the IDs
188
- __tokens = postprocess_token_ids(
113
+ __tokens = psaiops.common.tokenizer.postprocess_token_ids(
189
114
  tokenizer_obj=tokenizer_obj,
190
115
  token_obj=__outputs)
191
116
  # match tokens and labels for the HighlightedText field
@@ -0,0 +1,291 @@
1
+ import functools
2
+
3
+ import gradio
4
+ import torch
5
+ import torch.cuda
6
+ import matplotlib.pyplot
7
+
8
+ import psaiops.common.model
9
+ import psaiops.common.tokenizer
10
+ import psaiops.score.residual.lib
11
+
12
+ # META #########################################################################
13
+
14
+ STYLE = '''.white-text span { color: white; }'''
15
+ TITLE = '''Router Scoring'''
16
+ INTRO = '''Plot the logits of the router for a given prompt.\nUnder construction, only "openai/gpt-oss-20b" is available for now.'''
17
+
18
+ MODEL = 'openai/gpt-oss-20b'
19
+
20
+ # COLORS #######################################################################
21
+
22
+ def create_color_map() -> dict:
23
+ return {
24
+ '0': '#000000',
25
+ '1': '#004444',}
26
+
27
+ # INTRO ########################################################################
28
+
29
+ def create_intro_block(intro: str) -> dict:
30
+ __intro = gradio.Markdown(intro, line_breaks=True)
31
+ return {'intro_block': __intro}
32
+
33
+ # MODEL ########################################################################
34
+
35
+ def create_model_block() -> dict:
36
+ __model = gradio.Dropdown(label='Model', value='openai/gpt-oss-20b', choices=['openai/gpt-oss-20b'], scale=1, allow_custom_value=False, multiselect=False, interactive=True) # 'openai/gpt-oss-120b'
37
+ return {'model_block': __model,}
38
+
39
+ # SAMPLING #####################################################################
40
+
41
+ def create_sampling_block() -> dict:
42
+ __tokens = gradio.Slider(label='Tokens', value=16, minimum=1, maximum=128, step=1, scale=1, interactive=True)
43
+ __topk = gradio.Slider(label='Top K', value=4, minimum=1, maximum=8, step=1, scale=1, interactive=True)
44
+ __topp = gradio.Slider(label='Top P', value=0.9, minimum=0.0, maximum=1.0, step=0.1, scale=1, interactive=True)
45
+ return {
46
+ 'tokens_block': __tokens,
47
+ 'topk_block': __topk,
48
+ 'topp_block': __topp,}
49
+
50
+ # INPUTS #######################################################################
51
+
52
+ def create_inputs_block() -> dict:
53
+ __input = gradio.Textbox(label='Prompt', value='', placeholder='A string of tokens to score.', lines=4, scale=1, show_copy_button=True, interactive=True)
54
+ return {'input_block': __input}
55
+
56
+ # PLOTS ########################################################################
57
+
58
+ def create_plot_block() -> dict:
59
+ __plot = gradio.Plot(label='Router', scale=1)
60
+ return {'plot_block': __plot,}
61
+
62
+ # OUTPUTS ######################################################################
63
+
64
+ def create_outputs_block() -> dict:
65
+ __output = gradio.HighlightedText(label='Output', value='', scale=1, interactive=False, show_legend=False, show_inline_category=False, combine_adjacent=False, color_map=create_color_map(), elem_classes='white-text')
66
+ return {'output_block': __output}
67
+
68
+ # SELECT #######################################################################
69
+
70
+ def create_selection_block() -> dict:
71
+ # __play = gradio.Button('>', variant='primary', size='lg', scale=1, interactive=True)
72
+ __position = gradio.Slider(label='Token', value=-1, minimum=-1, maximum=15, step=1, scale=1, interactive=True) # info='-1 to average on all tokens'
73
+ return {'position_block': __position,}
74
+
75
+ # ACTIONS ######################################################################
76
+
77
+ def create_actions_block() -> dict:
78
+ __process = gradio.Button('Process', variant='primary', size='lg', scale=1, interactive=True)
79
+ return {'process_block': __process,}
80
+
81
+ # STATE ########################################################################
82
+
83
+ def create_state() -> dict:
84
+ return {
85
+ 'output_state': gradio.State(None),
86
+ 'hidden_state': gradio.State(None),}
87
+
88
+ # LAYOUT #######################################################################
89
+
90
+ def create_layout(intro: str=INTRO) -> dict:
91
+ __fields = {}
92
+ __fields.update(create_intro_block(intro=intro))
93
+ with gradio.Tabs():
94
+ with gradio.Tab('Score Tokens') as __main_tab:
95
+ __fields.update({'main_tab': __main_tab})
96
+ with gradio.Row(equal_height=True):
97
+ __fields.update(create_inputs_block())
98
+ with gradio.Row(equal_height=True):
99
+ __fields.update(create_plot_block())
100
+ with gradio.Row(equal_height=True):
101
+ __fields.update(create_outputs_block())
102
+ with gradio.Row(equal_height=True):
103
+ __fields.update(create_selection_block())
104
+ with gradio.Row(equal_height=True):
105
+ __fields.update(create_actions_block())
106
+ with gradio.Tab('Settings') as __settings_tab:
107
+ __fields.update({'settings_tab': __settings_tab})
108
+ with gradio.Column(scale=1):
109
+ with gradio.Row(equal_height=True):
110
+ __fields.update(create_model_block())
111
+ with gradio.Row(equal_height=True):
112
+ __fields.update(create_sampling_block())
113
+ return __fields
114
+
115
+ # EVENTS #######################################################################
116
+
117
+ def update_position_range(
118
+ current_val: float,
119
+ token_num: float,
120
+ output_data: torch.Tensor,
121
+ ) -> dict:
122
+ # take the generated tokens into account
123
+ __max = int(token_num) - 1 if (output_data is None) else int(output_data.shape[-1])
124
+ # keep the previous value if possible
125
+ __val = min(int(current_val), __max)
126
+ # return a gradio update dictionary
127
+ return gradio.update(maximum=__max, value=__val)
128
+
129
+ def update_computation_state(
130
+ token_num: float,
131
+ topk_num: float,
132
+ topp_num: float,
133
+ token_idx: float,
134
+ prompt_str: str,
135
+ device_str: str,
136
+ model_obj: object,
137
+ tokenizer_obj: object,
138
+ ) -> tuple:
139
+ # sanitize the inputs
140
+ __token_num = max(1, min(128, int(token_num)))
141
+ __topk_num = max(1, min(8, int(topk_num)))
142
+ __topp_num = max(0.0, min(1.0, float(topp_num)))
143
+ __token_idx = max(-1, min(__token_num, int(token_idx)))
144
+ __prompt_str = prompt_str.strip()
145
+ __device_str = device_str if (device_str in ['cpu', 'cuda']) else 'cpu'
146
+ # exit if some values are missing
147
+ if (not __prompt_str) or (model_obj is None) or (tokenizer_obj is None):
148
+ return (torch.empty(0), torch.empty(0))
149
+ # dictionary {'input_ids': _, 'attention_mask': _}
150
+ __input_data = psaiops.common.tokenizer.preprocess_token_ids(
151
+ tokenizer_obj=tokenizer_obj,
152
+ prompt_str=__prompt_str,
153
+ device_str=__device_str)
154
+ # tensor (1, T) and O * L * (1, I, H)
155
+ __output_data, __hidden_data = psaiops.score.residual.lib.generate_token_ids(
156
+ model_obj=model_obj,
157
+ input_ids=__input_data['input_ids'],
158
+ token_num=__token_num,
159
+ topk_num=__topk_num,
160
+ topp_num=__topp_num)
161
+ # tensor (1, L, I + O, H)
162
+ __hidden_data = psaiops.score.residual.lib.merge_hidden_states(
163
+ hidden_data=__hidden_data)
164
+ # update each component => (highlight, plot) states
165
+ return (
166
+ __output_data.cpu().float(),
167
+ __hidden_data.cpu().float(),)
168
+
169
+ def update_hidden_plot(
170
+ token_idx: float,
171
+ hidden_data: torch.Tensor,
172
+ ) -> tuple:
173
+ # exit if some values are missing
174
+ if (hidden_data is None) or (len(hidden_data) == 0):
175
+ return None
176
+ # reduce the token axis (B, L, T, E) => (B, L, E)
177
+ __plot_data = psaiops.score.residual.lib.reduce_hidden_states(
178
+ hidden_data=hidden_data,
179
+ token_idx=int(token_idx),)
180
+ # rescale the data to [-1; 1] (B, L, E)
181
+ __plot_data = psaiops.score.residual.lib.rescale_hidden_states(
182
+ hidden_data=__plot_data)
183
+ # reshape into a 3D tensor by folding E (B, L, E) => (B, W, H, L)
184
+ __plot_data = psaiops.score.residual.lib.reshape_hidden_states(
185
+ hidden_data=__plot_data)
186
+ # mask the small activations to improve the plot readability
187
+ __mask_data = psaiops.score.residual.lib.mask_hidden_states(
188
+ hidden_data=__plot_data,
189
+ topk_num=128).numpy()
190
+ # map the [-1; 1] activations to RGBA colors
191
+ __plot_data = psaiops.score.residual.lib.color_hidden_states(
192
+ hidden_data=__plot_data.numpy())
193
+ # plot the first sample
194
+ __figure = matplotlib.pyplot.figure()
195
+ __axes = __figure.add_subplot(1, 1, 1, projection='3d')
196
+ __axes.voxels(filled=__mask_data[0], facecolors=__plot_data[0], edgecolor=None)
197
+ __figure.tight_layout()
198
+ # remove the figure for the pyplot register for garbage collection
199
+ matplotlib.pyplot.close(__figure)
200
+ # update each component => (highlight, plot) states
201
+ return __figure
202
+
203
+ def update_text_highlight(
204
+ token_idx: float,
205
+ output_data: torch.Tensor,
206
+ tokenizer_obj: object,
207
+ ) -> list:
208
+ # exit if some values are missing
209
+ if (output_data is None) or (len(output_data) == 0):
210
+ return None
211
+ # detokenize the IDs
212
+ __token_str = psaiops.common.tokenizer.postprocess_token_ids(
213
+ tokenizer_obj=tokenizer_obj,
214
+ token_data=output_data)
215
+ # list of string classes
216
+ __token_cls = psaiops.score.residual.lib.postprocess_token_cls(
217
+ token_idx=int(token_idx),
218
+ token_dim=len(__token_str))
219
+ # pairs of token and class
220
+ return list(zip(__token_str, __token_cls))
221
+
222
+ # APP ##########################################################################
223
+
224
+ def create_app(title: str=TITLE, intro: str=INTRO, style: str=STYLE, model: str=MODEL) -> gradio.Blocks:
225
+ __fields = {}
226
+ with gradio.Blocks(theme=gradio.themes.Soft(), title=title, css=style) as __app:
227
+ # load the model
228
+ __device = 'cuda' if torch.cuda.is_available() else 'cpu'
229
+ __model = psaiops.common.model.get_model(name=model, device=__device)
230
+ __tokenizer = psaiops.common.tokenizer.get_tokenizer(name=model, device=__device)
231
+ # adapt the event handlers
232
+ __compute = functools.partial(update_computation_state, model_obj=__model, tokenizer_obj=__tokenizer, device_str=__device)
233
+ __highlight = functools.partial(update_text_highlight, tokenizer_obj=__tokenizer)
234
+ # create the UI
235
+ __fields.update(create_layout(intro=intro))
236
+ # init the state
237
+ __fields.update(create_state())
238
+ # update the data after clicking process
239
+ __fields['process_block'].click(
240
+ fn=__compute,
241
+ inputs=[__fields[__k] for __k in ['tokens_block', 'topk_block', 'topp_block', 'position_block', 'input_block']],
242
+ outputs=[__fields[__k] for __k in ['output_state', 'hidden_state']],
243
+ queue=False,
244
+ show_progress='full').then(
245
+ # update the range of the position slider when the output changes
246
+ fn=update_position_range,
247
+ inputs=[__fields[__k] for __k in ['position_block', 'tokens_block', 'output_state']],
248
+ outputs=__fields['position_block'],
249
+ queue=False,
250
+ show_progress='hidden').then(
251
+ # update the token highlight when the output data changes
252
+ fn=__highlight,
253
+ inputs=[__fields[__k] for __k in ['position_block', 'output_state']],
254
+ outputs=__fields['output_block'],
255
+ queue=False,
256
+ show_progress='full').then(
257
+ # update the plot when the router data changes
258
+ fn=update_hidden_plot,
259
+ inputs=[__fields[__k] for __k in ['position_block', 'hidden_state']],
260
+ outputs=__fields['plot_block'],
261
+ queue=False,
262
+ show_progress='full')
263
+ # update the range of the position slider when the settings change
264
+ __fields['tokens_block'].change(
265
+ fn=update_position_range,
266
+ inputs=[__fields[__k] for __k in ['position_block', 'tokens_block', 'output_state']],
267
+ outputs=__fields['position_block'],
268
+ queue=False,
269
+ show_progress='hidden')
270
+ # update the plot when the focus changes
271
+ __fields['position_block'].change(
272
+ fn=update_hidden_plot,
273
+ inputs=[__fields[__k] for __k in ['position_block', 'hidden_state']],
274
+ outputs=__fields['plot_block'],
275
+ queue=False,
276
+ show_progress='full')
277
+ # update the token highlight when the token focus changes
278
+ __fields['position_block'].change(
279
+ fn=__highlight,
280
+ inputs=[__fields[__k] for __k in ['position_block', 'output_state']],
281
+ outputs=__fields['output_block'],
282
+ queue=False,
283
+ show_progress='hidden')
284
+ # gradio application
285
+ return __app
286
+
287
+ # MAIN #########################################################################
288
+
289
+ if __name__ == '__main__':
290
+ __app = create_app()
291
+ __app.launch(share=True, debug=True)
@@ -0,0 +1,134 @@
1
+ import functools
2
+ import math
3
+
4
+ import matplotlib
5
+ import numpy
6
+ import torch
7
+
8
+ # GENERATE #######################################################################
9
+
10
+ @functools.lru_cache(maxsize=32)
11
+ def generate_token_ids(
12
+ model_obj: object,
13
+ input_ids: torch.Tensor,
14
+ token_num: int,
15
+ topk_num: int = 4,
16
+ topp_num: float = 0.9,
17
+ ) -> tuple:
18
+ # generate completion
19
+ with torch.no_grad():
20
+ __outputs = model_obj.generate(
21
+ input_ids=input_ids,
22
+ max_new_tokens=token_num,
23
+ do_sample=(0.0 < topp_num < 1.0) or (topk_num > 0),
24
+ top_k=topk_num if (topk_num > 0) else None,
25
+ top_p=topp_num if (0.0 < topp_num < 1.0) else None,
26
+ return_dict_in_generate=True,
27
+ output_hidden_states=True,
28
+ output_attentions=False,
29
+ output_scores=False,
30
+ early_stopping=True,
31
+ use_cache=True)
32
+ # ((B, T), O * L * (B, I, H))
33
+ return __outputs.sequences, __outputs.hidden_states
34
+
35
+ # MERGE ########################################################################
36
+
37
+ def merge_hidden_states(
38
+ hidden_data: torch.Tensor,
39
+ ) -> torch.Tensor:
40
+ # parse the inputs
41
+ __token_dim = len(hidden_data)
42
+ __layer_dim = len(hidden_data[0])
43
+ # stack the data for each layer => (B, L, I + O, H)
44
+ return torch.stack(
45
+ [
46
+ # concatenate the data for all the tokens => (B, I + O, H)
47
+ torch.concatenate([hidden_data[__t][__l] for __t in range(__token_dim)], dim=1)
48
+ for __l in range(__layer_dim)],
49
+ dim=1)
50
+
51
+ # REDUCE #######################################################################
52
+
53
+ def reduce_hidden_states(
54
+ hidden_data: torch.Tensor,
55
+ token_idx: int, # -1 => avg over all tokens
56
+ ) -> torch.Tensor:
57
+ # parse the hidden states (B, L, T, H)
58
+ __batch_dim, __layer_dim, __token_dim, __hidden_dim = tuple(hidden_data.shape)
59
+ __token_idx = min(token_idx, __token_dim - 1)
60
+ # select the relevant data along each axis
61
+ __token_slice = slice(0, __token_dim) if (__token_idx < 0) else slice(__token_idx, __token_idx + 1)
62
+ # filter the data
63
+ __data = hidden_data[slice(None), slice(None), __token_slice, slice(None)]
64
+ # reduce the token axis => (B, L, H)
65
+ return __data.mean(dim=2, keepdim=False)
66
+
67
+ # RESCALE ######################################################################
68
+
69
+ def rescale_hidden_states(
70
+ hidden_data: torch.Tensor, # (B, L, H)
71
+ ) -> torch.Tensor:
72
+ # compute the scale of the data, layer by layer
73
+ __s = torch.quantile(hidden_data.abs(), q=0.9, dim=-1, keepdim=True)
74
+ # log scaling on large values and linear near 0
75
+ __a = torch.asinh(hidden_data / (__s + torch.finfo().eps))
76
+ # clip and map to [-1; 1]
77
+ return 0.33 * __a.clamp(min=-3, max=3)
78
+
79
+ # RESHAPE ######################################################################
80
+
81
+ def reshape_hidden_states(
82
+ hidden_data: torch.Tensor, # (B, L, H)
83
+ ) -> torch.Tensor:
84
+ # parse the hidden states (B, L, H)
85
+ __batch_dim, __layer_dim, __hidden_dim = tuple(hidden_data.shape)
86
+ # factor the hidden dimension
87
+ __width_dim = math.gcd(__hidden_dim, 2 ** int(math.log2(__hidden_dim))) # greatest power of 2 that divides H
88
+ __height_dim = __hidden_dim // __width_dim
89
+ # reshape into (B, W, H, L)
90
+ return hidden_data.reshape((__batch_dim, __layer_dim, __width_dim, __height_dim)).permute(0, 2, 3, 1)
91
+
92
+ # MASK #########################################################################
93
+
94
+ def mask_hidden_states(
95
+ hidden_data: torch.Tensor, # (B, L, H)
96
+ topk_num: int=128,
97
+ ) -> torch.Tensor:
98
+ # sanitize
99
+ __k = min(topk_num, int(hidden_data.shape[-1]))
100
+ # indices of the topk values
101
+ __indices = hidden_data.abs().topk(__k, dim=-1, largest=True, sorted=False).indices
102
+ # initialize the mask with False
103
+ __mask = torch.zeros_like(hidden_data, dtype=torch.bool)
104
+ # (B, L, H) mask of the topk values
105
+ return __mask.scatter_(dim=-1, index=__indices, value=True)
106
+
107
+ # FORMAT #######################################################################
108
+
109
+ def color_hidden_states(
110
+ hidden_data: numpy.array, # (B, H, W, L)
111
+ gamma_val: float=0.7,
112
+ alpha_val: float=0.35,
113
+ color_map: callable=matplotlib.colormaps['coolwarm'],
114
+ ) -> list:
115
+ # [-1; 1] => [0; 1]
116
+ __data = 0.5 * (hidden_data + 1.0)
117
+ # (B, W, H, L) => (B, W, H, L, 4)
118
+ __rgba = color_map(__data)
119
+ # compute the transparency from the magnitude
120
+ __rgba[..., 3] = alpha_val * (numpy.abs(hidden_data) ** gamma_val)
121
+ # (B, W, H, L, 4) in [0; 1]
122
+ return __rgba
123
+
124
+ # POSTPROCESS ##################################################################
125
+
126
+ def postprocess_token_cls(
127
+ token_idx: int,
128
+ token_dim: int,
129
+ ) -> list:
130
+ __token_idx = max(-1, min(token_dim, token_idx))
131
+ # class 1 for the focused token(s) 0 for the rest
132
+ __token_cls = [str(int(__i == token_idx)) for __i in range(token_dim)]
133
+ # average on all the tokens when the idx is negative
134
+ return token_dim * ['1'] if (token_idx < 0) else __token_cls