psaiops 0.4.0__py3-none-any.whl → 0.4.2__py3-none-any.whl

This diff shows the contents of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between package versions as they appear in their respective public registries.
psaiops/common/model.py CHANGED
@@ -22,7 +22,7 @@ def get_model(name: str, device: str='cpu'):
22
22
  @functools.lru_cache(maxsize=32)
23
23
  def generate_token_ids(
24
24
  model_obj: object,
25
- input_args: dict,
25
+ input_ids: torch.Tensor,
26
26
  token_num: int,
27
27
  topk_num: int = 4,
28
28
  topp_num: float = 0.9,
@@ -30,7 +30,7 @@ def generate_token_ids(
30
30
  # generate completion
31
31
  with torch.no_grad():
32
32
  __outputs = model_obj.generate(
33
- **input_args,
33
+ input_ids=input_ids,
34
34
  max_new_tokens=token_num,
35
35
  do_sample=(0.0 < topp_num < 1.0) or (topk_num > 0),
36
36
  top_k=topk_num if (topk_num > 0) else None,
@@ -166,7 +166,7 @@ def update_computation_state(
166
166
  # tensor (1, T)
167
167
  __output_data = psaiops.common.model.generate_token_ids(
168
168
  model_obj=model_obj,
169
- input_args=__input_data,
169
+ input_ids=__input_data['input_ids'],
170
170
  token_num=__token_num,
171
171
  topk_num=__topk_num,
172
172
  topp_num=__topp_num)
@@ -89,7 +89,7 @@ def score_tokens(
89
89
  # tensor (1, T)
90
90
  __outputs = psaiops.common.tokenizer.model.generate_token_ids(
91
91
  model_obj=model_obj,
92
- input_args=__inputs,
92
+ input_ids=__inputs['input_ids'],
93
93
  token_num=token_num,
94
94
  topk_num=topk_num,
95
95
  topp_num=topp_num)
@@ -154,7 +154,7 @@ def update_computation_state(
154
154
  # tensor (1, T) and O * L * (1, I, H)
155
155
  __output_data, __hidden_data = psaiops.score.residual.lib.generate_token_ids(
156
156
  model_obj=model_obj,
157
- input_args=__input_data,
157
+ input_ids=__input_data['input_ids'],
158
158
  token_num=__token_num,
159
159
  topk_num=__topk_num,
160
160
  topp_num=__topp_num)
@@ -183,16 +183,17 @@ def update_hidden_plot(
183
183
  # reshape into a 3D tensor by folding E (B, L, E) => (B, W, H, L)
184
184
  __plot_data = psaiops.score.residual.lib.reshape_hidden_states(
185
185
  hidden_data=__plot_data)
186
- # map the [-1; 1] activations to RGBA colors
187
- __plot_data = psaiops.score.residual.lib.color_hidden_states(
188
- hidden_data=__plot_data.numpy())
189
186
  # mask the small activations to improve the plot readability
190
187
  __mask_data = psaiops.score.residual.lib.mask_hidden_states(
191
188
  hidden_data=__plot_data,
192
189
  topk_num=128).numpy()
190
+ # map the [-1; 1] activations to RGBA colors
191
+ __plot_data = psaiops.score.residual.lib.color_hidden_states(
192
+ hidden_data=__plot_data.numpy())
193
193
  # plot the first sample
194
- __figure, __axes = matplotlib.pyplot.subplots(111, projection='3d')
195
- __axes.voxels(filled=__mask_data[0].numpy(), facecolors=__plot_data[0], edgecolor=None)
194
+ __figure = matplotlib.pyplot.figure()
195
+ __axes = __figure.add_subplot(1, 1, 1, projection='3d')
196
+ __axes.voxels(filled=__mask_data[0], facecolors=__plot_data[0], edgecolor=None)
196
197
  __figure.tight_layout()
197
198
  # remove the figure for the pyplot register for garbage collection
198
199
  matplotlib.pyplot.close(__figure)
@@ -10,7 +10,7 @@ import torch
10
10
  @functools.lru_cache(maxsize=32)
11
11
  def generate_token_ids(
12
12
  model_obj: object,
13
- input_args: dict,
13
+ input_ids: torch.Tensor,
14
14
  token_num: int,
15
15
  topk_num: int = 4,
16
16
  topp_num: float = 0.9,
@@ -18,7 +18,7 @@ def generate_token_ids(
18
18
  # generate completion
19
19
  with torch.no_grad():
20
20
  __outputs = model_obj.generate(
21
- **input_args,
21
+ input_ids=input_ids,
22
22
  max_new_tokens=token_num,
23
23
  do_sample=(0.0 < topp_num < 1.0) or (topk_num > 0),
24
24
  top_k=topk_num if (topk_num > 0) else None,
@@ -102,7 +102,7 @@ def mask_hidden_states(
102
102
  # initialize the mask with False
103
103
  __mask = torch.zeros_like(hidden_data, dtype=torch.bool)
104
104
  # (B, L, H) mask of the topk values
105
- return __mask.scatter_(dim=-1, index=__indices, src=True)
105
+ return __mask.scatter_(dim=-1, index=__indices, value=True)
106
106
 
107
107
  # FORMAT #######################################################################
108
108
 
@@ -110,14 +110,14 @@ def color_hidden_states(
110
110
  hidden_data: numpy.array, # (B, H, W, L)
111
111
  gamma_val: float=0.7,
112
112
  alpha_val: float=0.35,
113
- color_map: matplotlib.colormaps['coolwarm'],
113
+ color_map: callable=matplotlib.colormaps['coolwarm'],
114
114
  ) -> list:
115
115
  # [-1; 1] => [0; 1]
116
116
  __data = 0.5 * (hidden_data + 1.0)
117
117
  # (B, W, H, L) => (B, W, H, L, 4)
118
- __rgba = color_map[__data]
118
+ __rgba = color_map(__data)
119
119
  # compute the transparency from the magnitude
120
- __rgba[..., 3] = alpha_val * (np.abs(hidden_data) ** gamma_val)
120
+ __rgba[..., 3] = alpha_val * (numpy.abs(hidden_data) ** gamma_val)
121
121
  # (B, W, H, L, 4) in [0; 1]
122
122
  return __rgba
123
123
 
@@ -154,7 +154,7 @@ def update_computation_state(
154
154
  # tensor (1, T)
155
155
  __output_data = psaiops.common.model.generate_token_ids(
156
156
  model_obj=model_obj,
157
- input_args=__input_data,
157
+ input_ids=__input_data['input_ids'],
158
158
  token_num=__token_num,
159
159
  topk_num=__topk_num,
160
160
  topp_num=__topp_num)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: psaiops
3
- Version: 0.4.0
3
+ Version: 0.4.2
4
4
  Summary: Web apps to inspect & engineer NN activations.
5
5
  Author-email: apehex <17317183+apehex@users.noreply.github.com>
6
6
  License-Expression: AGPL-3.0
@@ -3,7 +3,7 @@ psaiops/combine/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
3
3
  psaiops/combine/app.py,sha256=BhnUkbp6Gf6StAGIfPGOqpds2dlvNTMh_5KTu8x-TT0,15804
4
4
  psaiops/common/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
5
5
  psaiops/common/data.py,sha256=WzBoLXg7o9RwSG5OjeJ0cZlPCzOlqjTG-o93nAPNwfg,1248
6
- psaiops/common/model.py,sha256=OkZZh_P9rkhxfqkvpM9uSdfNpqemaYxLZzLpaNNMz9Q,1379
6
+ psaiops/common/model.py,sha256=Fog_yH-sNLdMK6ReEpECT5xe_egryvJsMYkMnx3UTaQ,1393
7
7
  psaiops/common/tokenizer.py,sha256=hmysSL9DthYBbb8ulU2yPj_ETaVdBy_l3VpTah3YxCw,1256
8
8
  psaiops/compose/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
9
9
  psaiops/compose/contrast/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
@@ -16,13 +16,13 @@ psaiops/edit/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
16
16
  psaiops/reverse/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
17
17
  psaiops/score/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
18
18
  psaiops/score/attention/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
19
- psaiops/score/attention/app.py,sha256=zhHVjFhJlhWJCfRxwgkQ5tRngo0oGCn1kvWvaHJOPL8,13516
20
- psaiops/score/attention/lib.py,sha256=uVI7Tc5jT9XxJyM4behvcmwKAfgcrlN-Lki3dJOw0nA,4640
19
+ psaiops/score/attention/app.py,sha256=7Mi57jzfp7LJTTdmv4U7B79cPtSPF6b_G3-YLE9NX4U,13528
20
+ psaiops/score/attention/lib.py,sha256=Cn6279QjDpGVPqkpWh7nFsFMGwfXoYjQFu9euM6-_Bk,4652
21
21
  psaiops/score/residual/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
22
- psaiops/score/residual/app.py,sha256=3OpefnkiQgMrFp7joC5sBmVTYTKRXmTM32SGmlrUWCE,12409
23
- psaiops/score/residual/lib.py,sha256=kyr7gTBSnmd3yho0CFDdyfjvz4Qm4ucF3ipMFbXYcS0,4982
22
+ psaiops/score/residual/app.py,sha256=7X1mRZZ-raSu81FYRQRjzv7hazH1ipWNHwAhN4UjzPU,12443
23
+ psaiops/score/residual/lib.py,sha256=tqBrv7BLC9b_zwYHkfn1lq5YsXRagtGuGrdGEE1LMfc,5010
24
24
  psaiops/score/router/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
25
- psaiops/score/router/app.py,sha256=dJhspK43AYFe6IGv2LZKwCDLGannY2bruNjmW_42n_0,11824
25
+ psaiops/score/router/app.py,sha256=yR1zyb0Z9DpZvgDIu6Q1-qBZ5Y1rQQjnHgSlfJhc3s8,11836
26
26
  psaiops/score/router/lib.py,sha256=0cm3m-COGRgEv9EIxiK8wGYDoGojXiGycKxS0b2N4lE,2180
27
27
  psaiops/score/shapley/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
28
28
  psaiops/score/shapley/app.py,sha256=PC5VK7_z5zwMg57o-I9MLKPet-W6WldPpwxO1i4Mats,7170
@@ -30,7 +30,7 @@ psaiops/score/shapley/lib.py,sha256=uRo7xeIGTQJaiuEiVVyWlI7FtIldU9gJxLZrcaSRUTM,
30
30
  psaiops/score/similarity/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
31
31
  psaiops/score/similarity/app.py,sha256=JDYCiZOAHczhPnmjmOq_Pb08EdaTBipOBac8pwdOeP0,7085
32
32
  psaiops/score/similarity/lib.py,sha256=uRo7xeIGTQJaiuEiVVyWlI7FtIldU9gJxLZrcaSRUTM,13
33
- psaiops-0.4.0.dist-info/METADATA,sha256=QTAA83vwbcgyeXPrh1LEBzIfYBihPi8bBqls95pmMIM,1110
34
- psaiops-0.4.0.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
35
- psaiops-0.4.0.dist-info/licenses/.github/LICENSE.md,sha256=TfPDBt3ar0uv_f9cqCDMZ5rIzW3CY8anRRd4PkL6ejs,34522
36
- psaiops-0.4.0.dist-info/RECORD,,
33
+ psaiops-0.4.2.dist-info/METADATA,sha256=-5-sXpxgBCu1f7defwmAaDv3QwCIRp4Sx1PrFHRhRl8,1110
34
+ psaiops-0.4.2.dist-info/WHEEL,sha256=WLgqFyCfm_KASv4WHyYy0P3pM_m7J5L9k2skdKLirC8,87
35
+ psaiops-0.4.2.dist-info/licenses/.github/LICENSE.md,sha256=TfPDBt3ar0uv_f9cqCDMZ5rIzW3CY8anRRd4PkL6ejs,34522
36
+ psaiops-0.4.2.dist-info/RECORD,,