psaiops 0.0.13__py3-none-any.whl → 0.4.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,143 @@
1
+ import functools
2
+
3
+ import torch
4
+ import torch.nn.modules
5
+
6
+ import mlable.shapes
7
+ import psaiops.common.tokenizer
8
+
9
+ # HOOK #########################################################################
10
+
11
def capture_hidden_activation(
    module: torch.nn.modules.Module,
    inputs: torch.Tensor,
    outputs: torch.Tensor,
    index: int,
    captured: dict,
) -> None:
    """Forward hook that records the outputs of a module.

    The first three arguments follow the forward hook signature, while `index`
    and `captured` are meant to be bound beforehand (e.g. with `functools.partial`).
    The outputs are stored as they are, without copy, under the key `index`.
    """
    # the caller treats the outputs as (B, S, E) -- nothing is checked here
    captured[index] = outputs
19
+
20
+ # MASKS ########################################################################
21
+
22
def compute_sequence_mask(
    tokens: torch.Tensor, # (B, S)
    # masks: torch.Tensor, # (B, S)
) -> torch.Tensor:
    """Flag the positions where two neighboring samples hold different tokens.

    The samples are paired along the batch axis (0 with 1, 2 with 3, etc) and
    both samples of a pair receive the same boolean mask.
    The result has the same shape as `tokens`.
    """
    # split the batch axis in pairs, presumably (B / 2, 2, S) -- TODO confirm against mlable
    pair_shape = mlable.shapes.divide(tokens.shape, axis=0, factor=2, insert=True)
    pairs = tokens.reshape(pair_shape)
    # first and second sample of each pair, the pair axis is kept with size 1
    first = pairs[:, :1]
    second = pairs[:, 1:]
    differ = first != second
    # duplicate the mask for both samples and restore the original layout
    return differ.expand(pair_shape).reshape(tokens.shape)
33
+
34
+ # REDUCTION ####################################################################
35
+
36
def compute_delta_activation(
    data: torch.Tensor, # (B, S, E)
    masks: torch.Tensor, # (B, S,)
    signs: torch.Tensor, # (B,)
    keepdim: bool=True,
) -> torch.Tensor:
    """Reduce the activations to a single signed and masked average.

    Each sample is weighted by its sign, each position by its mask and the
    total is summed over the batch and sequence axes.
    The axes 0 and 1 are kept with size 1 when `keepdim` is set.
    """
    # the data has to be of rank 3, the unpacking fails otherwise
    batch_dim, _, _ = tuple(data.shape)
    # the signs and masks are moved next to the data
    target = {'dtype': data.dtype, 'device': data.device}
    # one sign per sample, presumably viewed as (B, 1, 1) -- TODO confirm against mlable
    sign_shape = tuple(mlable.shapes.filter(data.shape, axes=[0]))
    weights = signs.to(**target).view(sign_shape)
    # one flag per position, presumably viewed as (B, S, 1) -- TODO confirm against mlable
    mask_shape = tuple(mlable.shapes.filter(data.shape, axes=[0, 1]))
    flags = masks.to(**target).view(mask_shape)
    # number of positions kept in each sample
    kept = flags.sum(dim=1, keepdim=True)
    # averaging factor: half the batch size times the kept positions, never below 1
    norm = (0.5 * float(batch_dim) * kept).clamp(min=1.0)
    # signed sum over the batch axis and average over the sequence axis
    weighted = data * weights * flags / norm
    return weighted.sum(dim=[0, 1], keepdim=keepdim)
55
+
56
+ # DELTA ########################################################################
57
+
58
def add_delta_activation(
    module: torch.nn.modules.Module,
    inputs: torch.Tensor,
    outputs: torch.Tensor,
    delta: torch.Tensor,
    alpha: torch.Tensor,
    beta: torch.Tensor,
) -> torch.Tensor:
    """Forward hook that blends a steering delta into the outputs of a module.

    Returns `alpha * outputs + beta * delta`, with `delta`, `alpha` and `beta`
    meant to be bound beforehand (e.g. with `functools.partial`).
    """
    # shape keeping only the last axis of the outputs, presumably (1, 1, E) -- TODO confirm against mlable
    feature_shape = mlable.shapes.filter(outputs.shape, axes=[-1])
    # weight of the original activations
    scaled_outputs = alpha * outputs
    # weight of the steering direction
    scaled_delta = beta * delta.view(feature_shape)
    return scaled_outputs + scaled_delta
70
+
71
+ # MAIN #########################################################################
72
+
73
def steer_model_output(
    positive_str: str,
    negative_str: str,
    prompt_str: str,
    positive_rate: float,
    negative_rate: float,
    prompt_rate: float,
    token_num: int,
    topk_num: int,
    topp_num: float,
    layer_idx: int,
    device_str: str,
    model_obj: object,
    tokenizer_obj: object,
) -> str:
    """Generate a completion while steering the activations of a single layer.

    The positive and negative prompts go through the model once and the
    difference of their activations at layer `layer_idx` gives a steering delta.
    That delta is then added to the outputs of the same layer during the
    generation on the user prompt.

    Args:
        positive_str: prompt on the positive side of the delta.
        negative_str: prompt on the negative side of the delta.
        prompt_str: user prompt to complete.
        positive_rate: weight of the positive prompt, clamped to 0 and above.
        negative_rate: weight of the negative prompt, clamped to 0 and above.
        prompt_rate: weight of the original activations, clamped to 0 and above.
        token_num: number of new tokens, at least 1.
        topk_num: top-k sampling parameter, at least 1.
        topp_num: top-p sampling parameter, only used in the range ]0, 1].
        layer_idx: index in `model_obj.model.layers`, clamped to 0 and above.
        device_str: device for the model and the token IDs.
        model_obj: model exposing `model.layers`, `eval`, `to`, `generate` and a forward call.
        tokenizer_obj: tokenizer used to encode the prompts and decode the output.

    Returns:
        The decoded first sequence, or an empty string when any prompt is blank.
    """
    # parse & sanitize
    __prompt0 = positive_str.strip()
    __prompt1 = negative_str.strip()
    __prompt2 = prompt_str.strip()
    __alpha0 = max(0.0, float(positive_rate))
    __alpha1 = max(0.0, float(negative_rate))
    __alpha2 = max(0.0, float(prompt_rate))
    __count = max(1, int(token_num))
    __topk = max(1, int(topk_num))
    __topp = max(0.0, float(topp_num))
    __index = max(0, int(layer_idx))
    # store hidden states
    __captured = {}
    # stop if inputs are missing
    if not (__prompt0 and __prompt1 and __prompt2):
        return ''
    # tokenize the 2 prompts and pad to same length
    __inputs = psaiops.common.tokenizer.preprocess_token_ids(tokenizer=tokenizer_obj, prompts=(__prompt0, __prompt1), device=device_str)
    # forward hook to capture output hidden state
    __hook = functools.partial(capture_hidden_activation, index=__index, captured=__captured)
    # attach to the model
    __handle = model_obj.model.layers[__index].register_forward_hook(__hook)
    try:
        with torch.no_grad():
            # inference mode
            model_obj.eval().to(device_str)
            # prefill with a single forward, only the captured activations are used
            model_obj(**__inputs, use_cache=True, output_attentions=False, output_hidden_states=False, return_dict=True)
    finally:
        # stop capturing activations, even when the forward pass fails
        __handle.remove()
    # select only the positions where the tokens differ
    __masks = compute_sequence_mask(tokens=__inputs['input_ids'])
    # activation delta at layer L
    __delta = compute_delta_activation(data=__captured[__index], masks=__masks, signs=torch.Tensor([1, -1]), keepdim=False)
    # add the delta on every forward pass
    __hook = functools.partial(add_delta_activation, alpha=__alpha2, beta=0.5 * (__alpha0 + __alpha1), delta=__delta)
    # attach to the model
    __handle = model_obj.model.layers[__index].register_forward_hook(__hook)
    try:
        # now process the user input, using the sanitized prompt like the other two
        __inputs = psaiops.common.tokenizer.preprocess_token_ids(tokenizer=tokenizer_obj, prompts=(__prompt2,), device=device_str)
        # generate the new tokens with tampered activations
        with torch.no_grad():
            __outputs = model_obj.generate(
                **__inputs,
                max_new_tokens=__count,
                # NOTE: the top-k is clamped to 1 and above, so the sampling is always enabled
                do_sample=(0.0 < __topp < 1.0) or (__topk > 0),
                top_k=__topk if (__topk > 0) else None,
                top_p=__topp if (0.0 < __topp <= 1.0) else None,
                return_dict_in_generate=True,
                output_hidden_states=False,
                output_attentions=False,
                output_scores=False,
                use_cache=True)
    finally:
        # stop altering the activations, the model must not stay tampered after an error
        __handle.remove()
    # single string
    return tokenizer_obj.decode(__outputs.sequences[0], skip_special_tokens=True)
@@ -0,0 +1,323 @@
1
+ import functools
2
+ import itertools
3
+
4
+ import gradio
5
+ import pandas
6
+ import torch
7
+ import torch.cuda
8
+
9
+ import psaiops.common.model
10
+ import psaiops.common.tokenizer
11
+
12
+ # META #########################################################################
13
+
14
# default model ID
MODEL = 'openai/gpt-oss-20b'

# CSS, title and introduction of the application
STYLE = '.giga-text input { font-size: 32px; }'
TITLE = 'Activation Maths'
INTRO = (
    'Compose prompts in the latent space.\n'
    'Under construction, only "openai/gpt-oss-20b" is available for now.')

# number of input rows created in the layout
COUNT = 8
21
+
22
+ # COLORS #######################################################################
23
+
24
def create_color_map() -> dict:
    """Map the labels '-1' and '0' to '100' to hexadecimal colors.

    The label '-1' has a fixed teal color, while the labels '0' to '100' go
    from black to pure red, with a red channel proportional to the label.
    """
    # special color for the label -1
    __colors = {'-1': '#004444'}
    # integer arithmetic: the float product 2.55 * 100 falls just under 255 and truncated to 0xfe
    __colors.update({str(__i): '#{:02x}0000'.format((255 * __i) // 100) for __i in range(101)})
    return __colors
28
+
29
+ # INTRO ########################################################################
30
+
31
def create_intro_block(intro: str) -> dict:
    """Create the markdown component holding the introduction text."""
    return {'intro_block': gradio.Markdown(intro, line_breaks=True)}
34
+
35
+ # MODEL ########################################################################
36
+
37
def create_model_block() -> dict:
    """Create the model selector and the layer depth slider."""
    # only one model is listed for now ('openai/gpt-oss-120b' is left out)
    model_ids = ['openai/gpt-oss-20b']
    return {
        'model_block': gradio.Dropdown(
            label='Model ID',
            value='openai/gpt-oss-20b',
            choices=model_ids,
            scale=1,
            allow_custom_value=False,
            multiselect=False,
            interactive=True),
        'layer_block': gradio.Slider(
            label='Layer Depth',
            value=12,
            minimum=0,
            maximum=23,
            step=1,
            scale=1,
            interactive=True),}
43
+
44
+ # SAMPLING #####################################################################
45
+
46
def create_sampling_block() -> dict:
    """Create the sliders for the token count, the top-k and the top-p."""
    # options shared by the three sliders
    shared = {'scale': 1, 'interactive': True}
    return {
        'tokens_block': gradio.Slider(label='Tokens', value=32, minimum=1, maximum=128, step=1, **shared),
        'topk_block': gradio.Slider(label='Top K', value=4, minimum=1, maximum=8, step=1, **shared),
        'topp_block': gradio.Slider(label='Top P', value=0.9, minimum=0.0, maximum=1.0, step=0.1, **shared),}
54
+
55
+ # REDUCTION ####################################################################
56
+
57
def create_reduction_block() -> dict:
    """Create the two sliders bounding the averaging range."""
    # both sliders cover the same range of 0 to 256
    shared = {'minimum': 0, 'maximum': 256, 'step': 1, 'scale': 1, 'interactive': True}
    return {
        'from_block': gradio.Slider(label='Average From', value=0, **shared),
        'to_block': gradio.Slider(label='Average To', value=256, **shared),}
63
+
64
+ # INPUTS #######################################################################
65
+
66
def create_inputs_row(operation: str='', index: int=0) -> dict:
    """Create one row of the equation: operation, factor, prompt and delete button.

    Only the row of index 0 starts visible and with labels. Its operation is
    empty and locked and its delete button is disabled.
    The components are returned in a dict, with keys suffixed by the row index.
    """
    # the first row is the only one with a specific configuration
    is_first = (index == 0)
    # the empty operation is only available on the first row
    operations = ([''] if is_first else []) + ['+', '-', 'x', '.', 'µ', '=']
    with gradio.Row(equal_height=True, visible=is_first) as row_block:
        operation_block = gradio.Dropdown(
            label='Operation',
            value='' if is_first else operation,
            choices=operations,
            elem_classes='giga-text',
            scale=1,
            show_label=is_first,
            allow_custom_value=False,
            multiselect=False,
            interactive=(not is_first),
            visible=is_first)
        factor_block = gradio.Slider(
            label='Factor',
            value=1.0,
            minimum=0.0,
            maximum=8.0,
            step=0.1,
            scale=1,
            show_label=is_first,
            interactive=True,
            visible=is_first)
        prompt_block = gradio.Textbox(
            label='Prompt',
            value='',
            placeholder='Some text.',
            lines=2,
            max_lines=2,
            scale=8,
            show_label=is_first,
            show_copy_button=True,
            interactive=True,
            visible=is_first)
        button_block = gradio.Button(
            value='✖',
            variant='secondary',
            size='lg',
            scale=1,
            interactive=(not is_first),
            visible=is_first)
    return {
        f'row_{index}_block': row_block,
        f'operation_{index}_block': operation_block,
        f'factor_{index}_block': factor_block,
        f'prompt_{index}_block': prompt_block,
        f'button_{index}_block': button_block,}
113
+
114
+ # OUTPUTS ######################################################################
115
+
116
def create_outputs_block() -> dict:
    """Create the read-only textbox displaying the result of the equation."""
    return {
        'output_block': gradio.Textbox(
            label='= Total',
            value='',
            placeholder='Some text.',
            lines=2,
            max_lines=8,
            scale=1,
            show_label=True,
            show_copy_button=True,
            interactive=False)}
119
+
120
+ # ACTIONS ######################################################################
121
+
122
def create_actions_block() -> dict:
    """Create the buttons to add a row and to process the equation."""
    # both buttons have the same look
    shared = {'variant': 'primary', 'size': 'lg', 'scale': 1, 'interactive': True}
    return {
        'show_block': gradio.Button(value='Add', **shared),
        'process_block': gradio.Button(value='Process', **shared),}
128
+
129
+ # TABLE ########################################################################
130
+
131
def create_table_block() -> dict:
    """Create the read-only table summarizing the tokens."""
    return {
        'table_block': gradio.DataFrame(
            label='Summary',
            type='numpy',
            headers=None,
            row_count=4,
            col_count=256,
            scale=1,
            interactive=False),}
134
+
135
+ # STATE ########################################################################
136
+
137
def default_state(visible: bool=False) -> dict:
    """Build the cached state of a blank input row, with the given visibility."""
    # a new dict on each call, so that the rows never share their state
    blank = dict(operation='+', factor=1.0, prompt='')
    return {'visible': visible, **blank}
139
+
140
def create_state(limit: int=COUNT) -> dict:
    """Create the state component caching the content of all the input rows.

    The first row is visible and the `limit - 1` following rows are hidden.
    """
    # there is always a first row, whatever the limit
    rows = [default_state(True)]
    rows.extend(default_state(False) for _ in range(limit - 1))
    return {'cache_block': gradio.State(rows)}
144
+
145
+ # LAYOUT #######################################################################
146
+
147
def create_layout(intro: str=INTRO, limit: int=COUNT) -> dict:
    """Create all the components of the interface and index them by name.

    The layout has an introduction followed by three tabs: the equation
    (input rows, output and actions), the details (table) and the settings
    (model, sampling and reduction).
    """
    fields = dict(create_intro_block(intro=intro))
    with gradio.Tabs():
        # inputs, output and actions
        with gradio.Tab('Equation') as main_tab:
            fields['main_tab'] = main_tab
            for index in range(limit):
                fields.update(create_inputs_row(operation='+', index=index))
            with gradio.Row(equal_height=True):
                fields.update(create_outputs_block())
            with gradio.Row(equal_height=True):
                fields.update(create_actions_block())
        # table of tokens
        with gradio.Tab('Details') as details_tab:
            fields['details_tab'] = details_tab
            with gradio.Row(equal_height=True):
                fields.update(create_table_block())
        # model, sampling and reduction options
        with gradio.Tab('Settings') as settings_tab:
            fields['settings_tab'] = settings_tab
            with gradio.Column(scale=1):
                with gradio.Row(equal_height=True):
                    fields.update(create_model_block())
                with gradio.Row(equal_height=True):
                    fields.update(create_sampling_block())
                with gradio.Row(equal_height=True):
                    fields.update(create_reduction_block())
                    # the display block is not part of the layout for now
    return fields
174
+
175
+ # DYNAMIC ######################################################################
176
+
177
def get_input_rows(inputs: dict, limit: int=COUNT) -> list:
    """List the components of the first `limit` input rows, as a flat list.

    Each row contributes 5 items in a fixed order: row, operation, factor,
    prompt and button. The missing components are replaced by `None`.
    """
    # key templates, in the order of the components inside a row
    templates = (
        'row_{}_block',
        'operation_{}_block',
        'factor_{}_block',
        'prompt_{}_block',
        'button_{}_block',)
    components = []
    for index in range(limit):
        components.extend(inputs.get(template.format(index), None) for template in templates)
    return components
186
+
187
def render_input_rows(rows: list) -> list:
    """Turn the cached rows into a flat list of component updates.

    Each row produces 5 updates in a fixed order: row, operation, factor,
    prompt and button. The missing fields fall back to a hidden blank row.
    """
    updates = []
    for row in rows:
        # the same visibility is applied to the row and all its components
        visible = row.get('visible', False)
        updates.extend([
            gradio.update(visible=visible),
            gradio.update(visible=visible, value=row.get('operation', '')),
            gradio.update(visible=visible, value=row.get('factor', 1.0)),
            gradio.update(visible=visible, value=row.get('prompt', '')),
            gradio.update(visible=visible)])
    return updates
196
+
197
def show_input_row(rows: list) -> tuple:
    """Reveal the first hidden row and return the state with the component updates.

    The list is copied but the row dicts are modified in place.
    """
    state = list(rows)
    # number of hidden rows met so far
    hidden = 0
    for row in state:
        # the row is counted before its visibility changes
        hidden += int(not row['visible'])
        # the visible rows stay the same and only the first hidden row is revealed
        row['visible'] = row['visible'] or (hidden < 2)
    # update state and components
    return state, *render_input_rows(state)
207
+
208
def hide_input_row(rows: list, index: int) -> tuple:
    """Delete the row at `index` and return the state with the component updates.

    The following rows shift up and a hidden blank row is added at the end.
    The first row and the indices out of range are ignored.
    """
    state = list(rows)
    # the first row can never be deleted
    if 0 < index < len(state):
        del state[index]
        # pad with a hidden blank row, to keep the number of rows constant
        state.append({'visible': False, 'operation': '+', 'factor': 1.0, 'prompt': ''})
    # update state and components
    return state, *render_input_rows(state)
218
+
219
+ # EVENTS #######################################################################
220
+
221
def update_layer_range(value: float, model: str) -> dict:
    """Fit the range of the layer slider to the selected model.

    The maximum is 35 when the model ID contains '120b' and 23 otherwise.
    The current value is capped at the new maximum.
    """
    upper = 35 if '120b' in model else 23
    return gradio.update(maximum=upper, value=min(upper, int(value)))
223
+
224
def update_input_cache(cache: list, value: any, index: int, field: str) -> list:
    """Set one field of one cached row and return the cache as a new list.

    The list is copied but the row itself is modified in place.
    """
    entries = list(cache)
    # the rows are shared with the input list
    entry = entries[index]
    entry[field] = value
    return entries
228
+
229
def update_operation_cache(cache: list, index: int, value: any) -> list:
    """Save the operation of the row `index` in the cache, as a string."""
    position = int(index)
    operation = str(value)
    return update_input_cache(cache, operation, position, 'operation')
231
+
232
def update_factor_cache(cache: list, index: int, value: any) -> list:
    """Save the factor of the row `index` in the cache, as a float."""
    position = int(index)
    factor = float(value)
    return update_input_cache(cache, factor, position, 'factor')
234
+
235
def update_prompt_cache(cache: list, index: int, value: any) -> list:
    """Save the prompt of the row `index` in the cache, as a string."""
    position = int(index)
    prompt = str(value)
    return update_input_cache(cache, prompt, position, 'prompt')
237
+
238
def update_table_data(tokenizer: object) -> callable:
    """Build the function that splits prompts into rows of token strings.

    The returned function takes the prompts as separate arguments and returns
    one list of tokens per prompt.
    """
    # display the characters U+0120 and U+010A as a space and an escaped newline
    markers = str.maketrans({chr(0x0120): ' ', chr(0x010a): '\\n'})
    # called with unpacked arguments
    def __update_table_data(*prompts: list) -> list:
        # token IDs of all the prompts, padded as a single batch
        encoded = tokenizer(prompts, return_tensors='pt', padding=True)
        rows = []
        for ids in encoded['input_ids']:
            # token strings, with the special characters replaced
            rows.append([token.translate(markers) for token in tokenizer.convert_ids_to_tokens(ids)])
        return rows
    # fixed to a given tokenizer
    return __update_table_data
249
+
250
+ # APP ##########################################################################
251
+
252
def create_app(title: str=TITLE, intro: str=INTRO, style: str=STYLE, limit: int=COUNT, model: str=MODEL) -> gradio.Blocks:
    """Build the application: load the tokenizer, create the layout and wire the events.

    Only the tokenizer is loaded, the loading of the model itself is disabled
    for now.
    """
    fields = {}
    # options shared by all the events
    quiet = {'queue': False, 'show_progress': 'hidden'}
    # handlers keeping the cache in sync, in the order of the components in a row
    handlers = (
        ('operation', update_operation_cache),
        ('factor', update_factor_cache),
        ('prompt', update_prompt_cache),)
    with gradio.Blocks(theme=gradio.themes.Soft(), title=title, css=style) as app:
        # load the tokenizer (the model is not loaded)
        device = 'cuda' if torch.cuda.is_available() else 'cpu'
        tokenizer = psaiops.common.tokenizer.get_tokenizer(name=model, device=device)
        # create the UI
        fields.update(create_layout(intro=intro, limit=limit))
        # init the state
        fields.update(create_state(limit=limit))
        cache = fields['cache_block']
        # apply the configuration
        format_table = update_table_data(tokenizer=tokenizer)
        # change the depth of the model
        fields['model_block'].change(
            fn=update_layer_range,
            inputs=[fields['layer_block'], fields['model_block']],
            outputs=fields['layer_block'],
            **quiet)
        # show hidden row
        fields['show_block'].click(
            fn=show_input_row,
            inputs=[cache],
            outputs=[cache] + get_input_rows(inputs=fields, limit=limit),
            **quiet)
        # update the table with all the prompts followed by the output
        fields['details_tab'].select(
            fn=format_table,
            inputs=[fields[f'prompt_{index}_block'] for index in range(limit)] + [fields['output_block']],
            outputs=fields['table_block'],
            **quiet)
        # link each row of inputs to the cache
        for index in range(limit):
            # save the operation, factor and prompt on each change
            for name, handler in handlers:
                block = fields[f'{name}_{index}_block']
                block.change(
                    fn=handler,
                    inputs=[cache, gradio.State(index), block],
                    outputs=cache,
                    **quiet)
            # hide the target row
            fields[f'button_{index}_block'].click(
                fn=hide_input_row,
                inputs=[cache, gradio.State(index)],
                outputs=[cache] + get_input_rows(inputs=fields, limit=limit),
                **quiet)
    # gradio application
    return app
318
+
319
+ # MAIN #########################################################################
320
+
321
if __name__ == '__main__':
    # build the interface with the default settings
    __app = create_app()
    # serve it with the sharing and debugging options enabled
    __app.launch(share=True, debug=True)
@@ -0,0 +1 @@
1
+ import torch
File without changes