psaiops 0.3.2__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -1,69 +1,7 @@
1
- import functools
2
-
3
1
  import torch
4
- import transformers
5
-
6
- import deformers.models.openai.gptoss
7
-
8
- # LOAD #########################################################################
9
-
10
- @functools.lru_cache(maxsize=4)
11
- def get_tokenizer(name: str, device: str='cpu'):
12
- return transformers.AutoTokenizer.from_pretrained(
13
- name,
14
- use_fast=True,
15
- dtype='auto',
16
- device_map=device)
17
-
18
- @functools.lru_cache(maxsize=2)
19
- def get_model(name: str, device: str='cpu'):
20
- __model = deformers.models.openai.gptoss.GptOssForCausalInference.from_pretrained(
21
- name,
22
- dtype='auto',
23
- device_map=device)
24
- # toggle the inference mode (not training)
25
- __model.eval()
26
- # transformers model
27
- return __model
28
2
 
29
- # PREPROCESS #####################################################################
30
-
31
- @functools.lru_cache(maxsize=4)
32
- def preprocess_token_ids(
33
- tokenizer_obj: object,
34
- prompt_str: str,
35
- device_str: str='cpu'
36
- ) -> dict:
37
- # tokenize
38
- __inputs = tokenizer_obj(prompt_str, return_tensors='pt')
39
- # move to the main device
40
- return {__k: __v.to(device_str) for __k, __v in __inputs.items()}
41
-
42
- # GENERATE #######################################################################
43
-
44
- def generate_token_ids(
45
- model_obj: object,
46
- input_args: dict,
47
- token_num: int,
48
- topk_num: int = 4,
49
- topp_num: float = 0.9,
50
- ) -> torch.Tensor:
51
- # generate completion
52
- with torch.no_grad():
53
- __outputs = model_obj.generate(
54
- **input_args,
55
- max_new_tokens=token_num,
56
- do_sample=(0.0 < topp_num < 1.0) or (topk_num > 0),
57
- top_k=topk_num if (topk_num > 0) else None,
58
- top_p=topp_num if (0.0 < topp_num < 1.0) else None,
59
- return_dict_in_generate=True,
60
- output_hidden_states=False,
61
- output_attentions=False,
62
- output_scores=False,
63
- # early_stopping=True,
64
- use_cache=True)
65
- # full sequence
66
- return __outputs.sequences # (1, T)
3
+ import psaiops.common.model
4
+ import psaiops.common.tokenizer
67
5
 
68
6
  # COMPUTE ########################################################################
69
7
 
@@ -91,8 +29,8 @@ def reduce_attention_weights(
91
29
  ) -> torch.Tensor:
92
30
  # parse
93
31
  __layer_dim, __batch_dim, __head_dim, __output_dim, __output_dim = tuple(attention_data.shape) # L, B, H, T, T
94
- __layer_idx = min(layer_idx, __layer_dim)
95
- __head_idx = min(head_idx, __head_dim)
32
+ __layer_idx = min(layer_idx, __layer_dim - 1)
33
+ __head_idx = min(head_idx, __head_dim - 1)
96
34
  __token_idx = min(token_idx, __output_dim - input_dim - 1) # T = I + O
97
35
  # select the relevant data along each axis
98
36
  __layer_slice = slice(None) if (__layer_idx < 0) else slice(__layer_idx, __layer_idx + 1)
@@ -127,19 +65,6 @@ def postprocess_attention_scores(
127
65
  # native list of serialized integers
128
66
  return [str(__i) for __i in __input_scores.tolist() + __output_scores.tolist()] # (I,) + (O,) = (T,)
129
67
 
130
- # POSTPROCESS ####################################################################
131
-
132
- def postprocess_token_ids(
133
- tokenizer_obj: object,
134
- token_obj: torch.Tensor,
135
- ) -> list:
136
- # remove the batch axis
137
- __indices = token_obj.squeeze().tolist()
138
- # back to token strings
139
- __tokens = tokenizer_obj.convert_ids_to_tokens(__indices)
140
- # normalize the tokens
141
- return [__t.replace(chr(0x0120), ' ').replace(chr(0x010a), '\n') for __t in __tokens]
142
-
143
68
  # COMPUTE ########################################################################
144
69
 
145
70
  def score_tokens(
@@ -155,14 +80,14 @@ def score_tokens(
155
80
  tokenizer_obj: object,
156
81
  ) -> list:
157
82
  # dictionary {'input_ids': _, 'attention_mask': _}
158
- __inputs = preprocess_token_ids(
83
+ __inputs = psaiops.common.tokenizer.preprocess_token_ids(
159
84
  tokenizer_obj=tokenizer_obj,
160
85
  prompt_str=prompt_str,
161
86
  device_str=device_str)
162
87
  # parse the inputs
163
88
  __input_dim = int(__inputs['input_ids'].shape[-1])
164
89
  # tensor (1, T)
165
- __outputs = generate_token_ids(
90
+ __outputs = psaiops.common.tokenizer.model.generate_token_ids(
166
91
  model_obj=model_obj,
167
92
  input_args=__inputs,
168
93
  token_num=token_num,
@@ -185,7 +110,7 @@ def score_tokens(
185
110
  input_dim=__input_dim,
186
111
  token_idx=token_idx)
187
112
  # detokenize the IDs
188
- __tokens = postprocess_token_ids(
113
+ __tokens = psaiops.common.tokenizer.postprocess_token_ids(
189
114
  tokenizer_obj=tokenizer_obj,
190
115
  token_obj=__outputs)
191
116
  # match tokens and labels for the HighlightedText field
@@ -0,0 +1,290 @@
1
+ import functools
2
+
3
+ import gradio
4
+ import torch
5
+ import torch.cuda
6
+ import matplotlib.pyplot
7
+
8
+ import psaiops.common.model
9
+ import psaiops.common.tokenizer
10
+ import psaiops.score.residual.lib
11
+
12
+ # META #########################################################################
13
+
14
+ STYLE = '''.white-text span { color: white; }'''
15
+ TITLE = '''Router Scoring'''
16
+ INTRO = '''Plot the logits of the router for a given prompt.\nUnder construction, only "openai/gpt-oss-20b" is available for now.'''
17
+
18
+ MODEL = 'openai/gpt-oss-20b'
19
+
20
+ # COLORS #######################################################################
21
+
22
def create_color_map() -> dict:
    """Return the colors of the highlighted text, keyed by class label.

    The keys are the string labels '0' and '1' attached to each token.
    """
    __colors = {}
    __colors['0'] = '#000000'
    __colors['1'] = '#004444'
    return __colors
26
+
27
+ # INTRO ########################################################################
28
+
29
def create_intro_block(intro: str) -> dict:
    """Create the Markdown component that displays the introduction text.

    Returns a dictionary with the single key 'intro_block'.
    """
    return {'intro_block': gradio.Markdown(intro, line_breaks=True)}
32
+
33
+ # MODEL ########################################################################
34
+
35
def create_model_block() -> dict:
    """Create the dropdown used to select the model.

    Returns a dictionary with the single key 'model_block'.
    """
    # a single checkpoint is listed for now ('openai/gpt-oss-120b' is left out)
    __args = {
        'label': 'Model',
        'value': 'openai/gpt-oss-20b',
        'choices': ['openai/gpt-oss-20b'],
        'scale': 1,
        'allow_custom_value': False,
        'multiselect': False,
        'interactive': True,}
    return {'model_block': gradio.Dropdown(**__args)}
38
+
39
+ # SAMPLING #####################################################################
40
+
41
def create_sampling_block() -> dict:
    """Create the sliders controlling the generation settings.

    Returns a dictionary with the keys 'tokens_block', 'topk_block' and
    'topp_block', created in that order.
    """
    # arguments shared by all the sliders
    __shared = {'scale': 1, 'interactive': True}
    return {
        'tokens_block': gradio.Slider(label='Tokens', value=16, minimum=1, maximum=128, step=1, **__shared),
        'topk_block': gradio.Slider(label='Top K', value=4, minimum=1, maximum=8, step=1, **__shared),
        'topp_block': gradio.Slider(label='Top P', value=0.9, minimum=0.0, maximum=1.0, step=0.1, **__shared),}
49
+
50
+ # INPUTS #######################################################################
51
+
52
def create_inputs_block() -> dict:
    """Create the textbox that receives the prompt.

    Returns a dictionary with the single key 'input_block'.
    """
    __args = {
        'label': 'Prompt',
        'value': '',
        'placeholder': 'A string of tokens to score.',
        'lines': 4,
        'scale': 1,
        'show_copy_button': True,
        'interactive': True,}
    return {'input_block': gradio.Textbox(**__args)}
55
+
56
+ # PLOTS ########################################################################
57
+
58
def create_plot_block() -> dict:
    """Create the plot component.

    Returns a dictionary with the single key 'plot_block'.
    """
    return {'plot_block': gradio.Plot(label='Router', scale=1)}
61
+
62
+ # OUTPUTS ######################################################################
63
+
64
def create_outputs_block() -> dict:
    """Create the read-only highlighted text that displays the tokens.

    The colors come from `create_color_map` and the CSS class 'white-text'
    is attached to the component.
    Returns a dictionary with the single key 'output_block'.
    """
    __args = {
        'label': 'Output',
        'value': '',
        'scale': 1,
        'interactive': False,
        'show_legend': False,
        'show_inline_category': False,
        'combine_adjacent': False,
        'color_map': create_color_map(),
        'elem_classes': 'white-text',}
    return {'output_block': gradio.HighlightedText(**__args)}
67
+
68
+ # SELECT #######################################################################
69
+
70
def create_selection_block() -> dict:
    """Create the slider that selects the token position to inspect.

    The slider starts at -1, which is also its minimum.
    Returns a dictionary with the single key 'position_block'.
    """
    __args = {
        'label': 'Token',
        'value': -1,
        'minimum': -1,
        'maximum': 15,
        'step': 1,
        'scale': 1,
        'interactive': True,}
    return {'position_block': gradio.Slider(**__args)}
74
+
75
+ # ACTIONS ######################################################################
76
+
77
def create_actions_block() -> dict:
    """Create the button that triggers the processing.

    Returns a dictionary with the single key 'process_block'.
    """
    return {
        'process_block': gradio.Button(
            'Process',
            variant='primary',
            size='lg',
            scale=1,
            interactive=True),}
80
+
81
+ # STATE ########################################################################
82
+
83
def create_state() -> dict:
    """Create the state holders shared between the event handlers.

    Both states start empty (None).
    Returns a dictionary with the keys 'output_state' and 'hidden_state'.
    """
    __output = gradio.State(None)
    __hidden = gradio.State(None)
    return {'output_state': __output, 'hidden_state': __hidden}
87
+
88
+ # LAYOUT #######################################################################
89
+
90
def create_layout(intro: str=INTRO) -> dict:
    """Assemble all the components into tabs and return them by name.

    The first tab stacks the prompt, plot, output, selection and action rows.
    The second tab holds the model and sampling settings.
    """
    __fields = dict(create_intro_block(intro=intro))
    # one row per block, in display order
    __main_rows = (
        create_inputs_block,
        create_plot_block,
        create_outputs_block,
        create_selection_block,
        create_actions_block,)
    __settings_rows = (
        create_model_block,
        create_sampling_block,)
    with gradio.Tabs():
        with gradio.Tab('Score Tokens') as __main_tab:
            __fields['main_tab'] = __main_tab
            for __factory in __main_rows:
                with gradio.Row(equal_height=True):
                    __fields.update(__factory())
        with gradio.Tab('Settings') as __settings_tab:
            __fields['settings_tab'] = __settings_tab
            with gradio.Column(scale=1):
                for __factory in __settings_rows:
                    with gradio.Row(equal_height=True):
                        __fields.update(__factory())
    return __fields
114
+
115
+ # EVENTS #######################################################################
116
+
117
def update_position_range(
    current_val: float,
    token_num: float,
    output_data: torch.Tensor,
) -> dict:
    """Adapt the range of the position slider to the available tokens.

    Args:
        current_val: current value of the position slider.
        token_num: number of tokens requested in the settings.
        output_data: generated token IDs, with the tokens on the last axis,
            or None / an empty tensor when nothing has been generated yet.

    Returns:
        A gradio update with the new `maximum` and the clamped `value`.
    """
    # the state holds an empty tensor when the computation was skipped
    __has_output = (output_data is not None) and (len(output_data) > 0)
    # count the tokens that can be selected
    __count = int(output_data.shape[-1]) if __has_output else int(token_num)
    # positions are 0-based, so the last valid index is count - 1
    __max = max(0, __count - 1)
    # keep the previous value if possible
    __val = min(int(current_val), __max)
    # return a gradio update dictionary
    return gradio.update(maximum=__max, value=__val)
128
+
129
def update_computation_state(
    token_num: float,
    topk_num: float,
    topp_num: float,
    token_idx: float,
    prompt_str: str,
    device_str: str,
    model_obj: object,
    tokenizer_obj: object,
) -> tuple:
    """Run the generation and return the data stored in the states.

    Args:
        token_num: number of tokens to generate, clamped to [1, 128].
        topk_num: top-k sampling parameter, clamped to [1, 8].
        topp_num: top-p sampling parameter, clamped to [0.0, 1.0].
        token_idx: selected position; accepted for the event signature but
            not used by the computation.
        prompt_str: text of the prompt; None is treated as empty.
        device_str: 'cpu' or 'cuda', anything else falls back to 'cpu'.
        model_obj: model forwarded to the generation helper.
        tokenizer_obj: tokenizer forwarded to the preprocessing helper.

    Returns:
        A pair (token IDs, merged hidden states) of float tensors on the CPU,
        or a pair of empty tensors when the prompt, model or tokenizer is
        missing.
    """
    # sanitize the inputs
    __token_num = max(1, min(128, int(token_num)))
    __topk_num = max(1, min(8, int(topk_num)))
    __topp_num = max(0.0, min(1.0, float(topp_num)))
    # the textbox may hand over None instead of an empty string
    __prompt_str = (prompt_str or '').strip()
    __device_str = device_str if (device_str in ['cpu', 'cuda']) else 'cpu'
    # exit if some values are missing
    if (not __prompt_str) or (model_obj is None) or (tokenizer_obj is None):
        return (torch.empty(0), torch.empty(0))
    # dictionary {'input_ids': _, 'attention_mask': _}
    __input_data = psaiops.common.tokenizer.preprocess_token_ids(
        tokenizer_obj=tokenizer_obj,
        prompt_str=__prompt_str,
        device_str=__device_str)
    # tensor (1, T) and O * L * (1, I, H)
    __output_data, __hidden_data = psaiops.score.residual.lib.generate_token_ids(
        model_obj=model_obj,
        input_args=__input_data,
        token_num=__token_num,
        topk_num=__topk_num,
        topp_num=__topp_num)
    # tensor (1, L, I + O, H)
    __hidden_data = psaiops.score.residual.lib.merge_hidden_states(
        hidden_data=__hidden_data)
    # NOTE(review): the token IDs are also cast to float here, confirm the
    # detokenization helper expects that
    return (
        __output_data.cpu().float(),
        __hidden_data.cpu().float(),)
168
+
169
def update_hidden_plot(
    token_idx: float,
    hidden_data: torch.Tensor,
) -> object:
    """Render the hidden states of the selected token as a 3D voxel plot.

    Args:
        token_idx: position of the token to plot, forwarded as an integer to
            `reduce_hidden_states`.
        hidden_data: hidden states stored in the state, described as
            (B, L, T, E) by the pipeline comments; None or an empty tensor
            when nothing has been computed.

    Returns:
        A matplotlib figure for the first sample, or None when there is no
        data to plot.
    """
    # exit if some values are missing
    if (hidden_data is None) or (len(hidden_data) == 0):
        return None
    # reduce the token axis (B, L, T, E) => (B, L, E)
    __hidden_data = psaiops.score.residual.lib.reduce_hidden_states(
        hidden_data=hidden_data,
        token_idx=int(token_idx),)
    # rescale the data to [-1; 1] (B, L, E)
    __hidden_data = psaiops.score.residual.lib.rescale_hidden_states(
        hidden_data=__hidden_data)
    # select the largest activations of each layer, on the tensor and before
    # the color mapping turns the data into a numpy RGBA array (B, L, E)
    __mask_data = psaiops.score.residual.lib.mask_hidden_states(
        hidden_data=__hidden_data,
        topk_num=128)
    # fold E so that data and mask share the same layout (B, L, E) => (B, W, H, L)
    __plot_data = psaiops.score.residual.lib.reshape_hidden_states(
        hidden_data=__hidden_data)
    __mask_data = psaiops.score.residual.lib.reshape_hidden_states(
        hidden_data=__mask_data)
    # map the [-1; 1] activations to RGBA colors (B, W, H, L, 4)
    __color_data = psaiops.score.residual.lib.color_hidden_states(
        hidden_data=__plot_data.numpy())
    # the projection is an option of the axes, not of the figure
    __figure = matplotlib.pyplot.figure()
    __axes = __figure.add_subplot(111, projection='3d')
    # plot the first sample
    __axes.voxels(filled=__mask_data[0].numpy(), facecolors=__color_data[0], edgecolors=None)
    __figure.tight_layout()
    # remove the figure from the pyplot registry so that it can be garbage collected
    matplotlib.pyplot.close(__figure)
    return __figure
201
+
202
def update_text_highlight(
    token_idx: float,
    output_data: torch.Tensor,
    tokenizer_obj: object,
) -> list:
    """Pair each token string with its highlight class.

    Args:
        token_idx: position of the focused token.
        output_data: generated token IDs; None or an empty tensor when nothing
            has been computed.
        tokenizer_obj: tokenizer used to turn the IDs back into strings.

    Returns:
        A list of (token, class) pairs, or None when there is no data.
    """
    # exit if some values are missing
    if (output_data is None) or (len(output_data) == 0):
        return None
    # detokenize the IDs
    # NOTE(review): passed positionally because the call sites disagree on the
    # keyword of the second argument (token_obj vs token_data); this assumes
    # the order (tokenizer, tokens) - confirm against psaiops.common.tokenizer
    __token_str = psaiops.common.tokenizer.postprocess_token_ids(
        tokenizer_obj,
        output_data)
    # list of string classes
    __token_cls = psaiops.score.residual.lib.postprocess_token_cls(
        token_idx=int(token_idx),
        token_dim=len(__token_str))
    # pairs of token and class
    return list(zip(__token_str, __token_cls))
220
+
221
+ # APP ##########################################################################
222
+
223
def create_app(title: str=TITLE, intro: str=INTRO, style: str=STYLE, model: str=MODEL) -> gradio.Blocks:
    """Build the gradio application and wire its event handlers.

    The model and tokenizer are loaded once, on 'cuda' when available and on
    'cpu' otherwise, and bound to the handlers with `functools.partial`.
    """
    __fields = {}
    with gradio.Blocks(theme=gradio.themes.Soft(), title=title, css=style) as __app:
        # load the model
        __device = 'cpu'
        if torch.cuda.is_available():
            __device = 'cuda'
        __model = psaiops.common.model.get_model(name=model, device=__device)
        __tokenizer = psaiops.common.tokenizer.get_tokenizer(name=model, device=__device)
        # adapt the event handlers
        __compute = functools.partial(update_computation_state, model_obj=__model, tokenizer_obj=__tokenizer, device_str=__device)
        __highlight = functools.partial(update_text_highlight, tokenizer_obj=__tokenizer)
        # create the UI, then the state
        __fields.update(create_layout(intro=intro))
        __fields.update(create_state())
        # build a fresh list of components for each event
        def __select(*names: str) -> list:
            return [__fields[__n] for __n in names]
        # update the data after clicking process
        __event = __fields['process_block'].click(
            fn=__compute,
            inputs=__select('tokens_block', 'topk_block', 'topp_block', 'position_block', 'input_block'),
            outputs=__select('output_state', 'hidden_state'),
            queue=False,
            show_progress='full')
        # then update the range of the position slider
        __event = __event.then(
            fn=update_position_range,
            inputs=__select('position_block', 'tokens_block', 'output_state'),
            outputs=__fields['position_block'],
            queue=False,
            show_progress='hidden')
        # then update the token highlight
        __event = __event.then(
            fn=__highlight,
            inputs=__select('position_block', 'output_state'),
            outputs=__fields['output_block'],
            queue=False,
            show_progress='full')
        # and finally update the plot
        __event.then(
            fn=update_hidden_plot,
            inputs=__select('position_block', 'hidden_state'),
            outputs=__fields['plot_block'],
            queue=False,
            show_progress='full')
        # update the range of the position slider when the settings change
        __fields['tokens_block'].change(
            fn=update_position_range,
            inputs=__select('position_block', 'tokens_block', 'output_state'),
            outputs=__fields['position_block'],
            queue=False,
            show_progress='hidden')
        # update the plot when the focus changes
        __fields['position_block'].change(
            fn=update_hidden_plot,
            inputs=__select('position_block', 'hidden_state'),
            outputs=__fields['plot_block'],
            queue=False,
            show_progress='full')
        # update the token highlight when the focus changes
        __fields['position_block'].change(
            fn=__highlight,
            inputs=__select('position_block', 'output_state'),
            outputs=__fields['output_block'],
            queue=False,
            show_progress='hidden')
    # gradio application
    return __app
285
+
286
+ # MAIN #########################################################################
287
+
288
+ if __name__ == '__main__':
289
+ __app = create_app()
290
+ __app.launch(share=True, debug=True)
@@ -0,0 +1,134 @@
1
+ import functools
2
+ import math
3
+
4
+ import matplotlib
5
+ import numpy
6
+ import torch
7
+
8
+ # GENERATE #######################################################################
9
+
10
def generate_token_ids(
    model_obj: object,
    input_args: dict,
    token_num: int,
    topk_num: int = 4,
    topp_num: float = 0.9,
) -> tuple:
    """Generate a completion and return the token IDs with the hidden states.

    The function is deliberately not memoized: `input_args` is a dictionary,
    which is not hashable, so a `functools.lru_cache` wrapper raises a
    TypeError on every call.

    Args:
        model_obj: object exposing a `generate` method that accepts the
            keyword arguments passed below.
        input_args: keyword arguments forwarded to `generate`.
        token_num: maximum number of new tokens.
        topk_num: top-k sampling parameter, ignored unless strictly positive.
        topp_num: top-p sampling parameter, ignored outside of ]0; 1[.

    Returns:
        The pair (`sequences`, `hidden_states`) read from the generation output.
    """
    # each sampling filter is only enabled when its parameter is meaningful
    __use_topk = topk_num > 0
    __use_topp = 0.0 < topp_num < 1.0
    # generate completion, without tracking the gradients
    with torch.no_grad():
        __outputs = model_obj.generate(
            **input_args,
            max_new_tokens=token_num,
            do_sample=__use_topp or __use_topk,
            top_k=topk_num if __use_topk else None,
            top_p=topp_num if __use_topp else None,
            return_dict_in_generate=True,
            output_hidden_states=True,
            output_attentions=False,
            output_scores=False,
            early_stopping=True,
            use_cache=True)
    # ((B, T), O * L * (B, I, H))
    return __outputs.sequences, __outputs.hidden_states
34
+
35
+ # MERGE ########################################################################
36
+
37
def merge_hidden_states(
    hidden_data: tuple,
) -> torch.Tensor:
    """Merge the hidden states of all the generation steps into one tensor.

    Args:
        hidden_data: nested sequence indexed by generation step, then by
            layer, of tensors that share every axis except axis 1
            (described as O * L * (B, I, H) by the caller).

    Returns:
        A single tensor with the layers stacked on axis 1 and the tokens of
        all the steps concatenated on axis 2 => (B, L, I + O, H).
    """
    # transpose (step, layer) into (layer, step) to group the data by layer
    __layers = zip(*hidden_data)
    # concatenate the data for all the tokens => (B, I + O, H), then stack the layers
    return torch.stack(
        [torch.cat(list(__steps), dim=1) for __steps in __layers],
        dim=1)
50
+
51
+ # REDUCE #######################################################################
52
+
53
def reduce_hidden_states(
    hidden_data: torch.Tensor,
    token_idx: int, # -1 => avg over all tokens
) -> torch.Tensor:
    """Remove the token axis, by selecting one token or averaging them all.

    Args:
        hidden_data: tensor of rank 4 with the tokens on axis 2 (B, L, T, H).
        token_idx: position of the token to keep, clamped to the last token;
            any negative value averages over all the tokens.

    Returns:
        A tensor of shape (B, L, H).

    Raises:
        ValueError: when `hidden_data` is not of rank 4.
    """
    if hidden_data.dim() != 4:
        raise ValueError('expected a tensor of shape (B, L, T, H), got {}'.format(tuple(hidden_data.shape)))
    # clamp the index to the last token
    __token_dim = int(hidden_data.shape[2])
    __token_idx = min(token_idx, __token_dim - 1)
    # average on all the tokens when the index is negative
    if __token_idx < 0:
        return hidden_data.mean(dim=2, keepdim=False)
    # otherwise pick the token directly, the mean of a single item is the item
    return hidden_data[:, :, __token_idx, :]
66
+
67
+ # RESCALE ######################################################################
68
+
69
def rescale_hidden_states(
    hidden_data: torch.Tensor, # (B, L, H)
) -> torch.Tensor:
    """Rescale the activations to [-1; 1], independently on each row.

    The scale is the 0.9 quantile of the magnitudes along the last axis.
    The data is divided by this scale, squashed with `asinh`, clipped to
    [-3; 3] and divided by 3.

    Args:
        hidden_data: floating point tensor (B, L, H); dtypes other than
            float32 / float64 are cast to float32 because `torch.quantile`
            does not accept them.

    Returns:
        A tensor of the same shape with values in [-1; 1].
    """
    __data = hidden_data if (hidden_data.dtype in (torch.float32, torch.float64)) else hidden_data.float()
    # compute the scale of the data, layer by layer
    __s = torch.quantile(__data.abs(), q=0.9, dim=-1, keepdim=True)
    # log scaling on large values and linear near 0, eps avoids dividing by 0
    __a = torch.asinh(__data / (__s + torch.finfo(__data.dtype).eps))
    # clip and map to [-1; 1], dividing by the clipping bound reaches the limits exactly
    return __a.clamp(min=-3, max=3) / 3.0
78
+
79
+ # RESHAPE ######################################################################
80
+
81
def reshape_hidden_states(
    hidden_data: torch.Tensor, # (B, L, H)
) -> torch.Tensor:
    """Fold the last axis into two and move the layer axis to the end.

    The last dimension is factored as width * height, where the width is the
    greatest power of 2 that divides it.

    Args:
        hidden_data: tensor of rank 3 (B, L, H).

    Returns:
        A tensor of shape (B, width, height, L).

    Raises:
        ValueError: when the tensor is not of rank 3 or its last axis is empty.
    """
    # parse the hidden states (B, L, H)
    __batch_dim, __layer_dim, __hidden_dim = tuple(hidden_data.shape)
    if __hidden_dim <= 0:
        raise ValueError('the last axis of the hidden states must not be empty')
    # isolate the lowest set bit, which is the greatest power of 2 that divides H
    __width_dim = __hidden_dim & -__hidden_dim
    __height_dim = __hidden_dim // __width_dim
    # fold the last axis (B, L, W, H) and move the layers last (B, W, H, L)
    return hidden_data.reshape((__batch_dim, __layer_dim, __width_dim, __height_dim)).permute(0, 2, 3, 1)
91
+
92
+ # MASK #########################################################################
93
+
94
def mask_hidden_states(
    hidden_data: torch.Tensor, # (B, L, H)
    topk_num: int=128,
) -> torch.Tensor:
    """Flag the values with the largest magnitudes along the last axis.

    Args:
        hidden_data: tensor with the values to compare on its last axis.
        topk_num: number of values to keep on each row, clamped to
            [0; size of the last axis].

    Returns:
        A boolean tensor of the same shape, True on the selected values.
    """
    # sanitize
    __k = max(0, min(int(topk_num), int(hidden_data.shape[-1])))
    # indices of the topk values
    __indices = hidden_data.abs().topk(__k, dim=-1, largest=True, sorted=False).indices
    # initialize the mask with False
    __mask = torch.zeros_like(hidden_data, dtype=torch.bool)
    # a scalar is scattered with the `value` argument, `src` only accepts tensors
    return __mask.scatter_(dim=-1, index=__indices, value=True)
106
+
107
+ # FORMAT #######################################################################
108
+
109
def color_hidden_states(
    hidden_data: numpy.ndarray, # (B, H, W, L)
    gamma_val: float=0.7,
    alpha_val: float=0.35,
    color_map: object=None,
) -> numpy.ndarray:
    """Map activations in [-1; 1] to RGBA colors.

    Args:
        hidden_data: array of values in [-1; 1].
        gamma_val: exponent applied to the magnitudes to compute the opacity.
        alpha_val: opacity given to the values of magnitude 1.
        color_map: callable mapping an array of values in [0; 1] to an array
            with a trailing RGBA axis; defaults to the matplotlib colormap
            'coolwarm'.

    Returns:
        An array with one more axis than the input, of size 4 (RGBA).
    """
    # resolve the default lazily so that matplotlib is only needed when used
    if color_map is None:
        import matplotlib
        color_map = matplotlib.colormaps['coolwarm']
    __values = numpy.asarray(hidden_data)
    # [-1; 1] => [0; 1]
    __data = 0.5 * (__values + 1.0)
    # a colormap is called, not indexed; the copy guarantees a writable array
    # (B, W, H, L) => (B, W, H, L, 4)
    __rgba = numpy.array(color_map(__data), dtype=float)
    # compute the transparency from the magnitude
    __rgba[..., 3] = alpha_val * (numpy.abs(__values) ** gamma_val)
    # (B, W, H, L, 4) in [0; 1]
    return __rgba
123
+
124
+ # POSTPROCESS ##################################################################
125
+
126
def postprocess_token_cls(
    token_idx: int,
    token_dim: int,
) -> list:
    """Compute the highlight class of each token.

    Args:
        token_idx: position of the focused token, clamped to the last token;
            a negative value selects all the tokens.
        token_dim: total number of tokens.

    Returns:
        A list of `token_dim` labels, '1' for the focused token(s) and '0'
        for the rest.
    """
    # clamp to the valid positions, -1 stands for all the tokens
    __token_idx = max(-1, min(token_dim - 1, token_idx))
    # all the tokens are selected when the index is negative
    if __token_idx < 0:
        return token_dim * ['1']
    # class 1 for the focused token, 0 for the rest
    return [str(int(__i == __token_idx)) for __i in range(token_dim)]