joonmyung 1.5.14.tar.gz → 1.5.16.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (51)
  1. joonmyung-1.5.16/PKG-INFO +20 -0
  2. {joonmyung-1.5.14 → joonmyung-1.5.16}/joonmyung/analysis/__init__.py +0 -1
  3. joonmyung-1.5.16/joonmyung/analysis/analysis.py +145 -0
  4. joonmyung-1.5.14/joonmyung/analysis/analysis.py → joonmyung-1.5.16/joonmyung/analysis/analysis_bak.py +44 -100
  5. joonmyung-1.5.16/joonmyung/analysis/analysis_ㅠㅏ.py +218 -0
  6. {joonmyung-1.5.14 → joonmyung-1.5.16}/joonmyung/analysis/dataset.py +13 -19
  7. joonmyung-1.5.16/joonmyung/analysis/evaluate.py +39 -0
  8. joonmyung-1.5.16/joonmyung/analysis/model.py +109 -0
  9. joonmyung-1.5.16/joonmyung/clip/__init__.py +1 -0
  10. joonmyung-1.5.16/joonmyung/clip/clip.py +221 -0
  11. joonmyung-1.5.16/joonmyung/clip/model.py +445 -0
  12. joonmyung-1.5.16/joonmyung/clip/simple_tokenizer.py +131 -0
  13. joonmyung-1.5.16/joonmyung/compression/apply.py +139 -0
  14. joonmyung-1.5.16/joonmyung/compression/compression.py +202 -0
  15. {joonmyung-1.5.14 → joonmyung-1.5.16}/joonmyung/draw.py +72 -6
  16. {joonmyung-1.5.14 → joonmyung-1.5.16}/joonmyung/metric.py +3 -1
  17. joonmyung-1.5.16/joonmyung/model/__init__.py +0 -0
  18. joonmyung-1.5.16/joonmyung/model/compression.py +202 -0
  19. joonmyung-1.5.16/joonmyung/model.py +0 -0
  20. joonmyung-1.5.16/joonmyung/models/__init__.py +0 -0
  21. {joonmyung-1.5.14 → joonmyung-1.5.16}/joonmyung/script.py +21 -14
  22. {joonmyung-1.5.14 → joonmyung-1.5.16}/joonmyung/utils.py +16 -1
  23. joonmyung-1.5.16/joonmyung.egg-info/PKG-INFO +20 -0
  24. {joonmyung-1.5.14 → joonmyung-1.5.16}/joonmyung.egg-info/SOURCES.txt +14 -1
  25. joonmyung-1.5.16/joonmyung.egg-info/requires.txt +11 -0
  26. {joonmyung-1.5.14 → joonmyung-1.5.16}/setup.py +6 -1
  27. joonmyung-1.5.14/PKG-INFO +0 -9
  28. joonmyung-1.5.14/joonmyung/analysis/metric.py +0 -35
  29. joonmyung-1.5.14/joonmyung/analysis/model.py +0 -55
  30. joonmyung-1.5.14/joonmyung.egg-info/PKG-INFO +0 -9
  31. {joonmyung-1.5.14 → joonmyung-1.5.16}/LICENSE.txt +0 -0
  32. {joonmyung-1.5.14 → joonmyung-1.5.16}/README.md +0 -0
  33. {joonmyung-1.5.14 → joonmyung-1.5.16}/joonmyung/__init__.py +0 -0
  34. {joonmyung-1.5.14 → joonmyung-1.5.16}/joonmyung/analysis/hook.py +0 -0
  35. {joonmyung-1.5.14 → joonmyung-1.5.16}/joonmyung/analysis/utils.py +0 -0
  36. {joonmyung-1.5.14 → joonmyung-1.5.16}/joonmyung/app.py +0 -0
  37. {joonmyung-1.5.14/joonmyung/models → joonmyung-1.5.16/joonmyung/compression}/__init__.py +0 -0
  38. {joonmyung-1.5.14 → joonmyung-1.5.16}/joonmyung/data.py +0 -0
  39. {joonmyung-1.5.14 → joonmyung-1.5.16}/joonmyung/dummy.py +0 -0
  40. {joonmyung-1.5.14 → joonmyung-1.5.16}/joonmyung/file.py +0 -0
  41. {joonmyung-1.5.14 → joonmyung-1.5.16}/joonmyung/gradcam.py +0 -0
  42. {joonmyung-1.5.14 → joonmyung-1.5.16}/joonmyung/log.py +0 -0
  43. {joonmyung-1.5.14 → joonmyung-1.5.16}/joonmyung/meta_data/__init__.py +0 -0
  44. {joonmyung-1.5.14 → joonmyung-1.5.16}/joonmyung/meta_data/label.py +0 -0
  45. {joonmyung-1.5.14 → joonmyung-1.5.16}/joonmyung/meta_data/utils.py +0 -0
  46. {joonmyung-1.5.14 → joonmyung-1.5.16}/joonmyung/models/tome.py +0 -0
  47. {joonmyung-1.5.14 → joonmyung-1.5.16}/joonmyung/status.py +0 -0
  48. {joonmyung-1.5.14 → joonmyung-1.5.16}/joonmyung.egg-info/dependency_links.txt +0 -0
  49. {joonmyung-1.5.14 → joonmyung-1.5.16}/joonmyung.egg-info/not-zip-safe +0 -0
  50. {joonmyung-1.5.14 → joonmyung-1.5.16}/joonmyung.egg-info/top_level.txt +0 -0
  51. {joonmyung-1.5.14 → joonmyung-1.5.16}/setup.cfg +0 -0
@@ -0,0 +1,20 @@
+ Metadata-Version: 2.1
+ Name: joonmyung
+ Version: 1.5.16
+ Summary: JoonMyung's Library
+ Home-page: https://github.com/pizard/JoonMyung.git
+ Author: JoonMyung Choi
+ Author-email: pizard@korea.ac.kr
+ License: MIT
+ License-File: LICENSE.txt
+ Requires-Dist: fvcore
+ Requires-Dist: timm
+ Requires-Dist: torchprofile
+ Requires-Dist: thop
+ Requires-Dist: wandb
+ Requires-Dist: scipy
+ Requires-Dist: matplotlib
+ Requires-Dist: seaborn
+ Requires-Dist: opencv-python
+ Requires-Dist: ftfy
+ Requires-Dist: regex
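
The eleven Requires-Dist lines above are new in 1.5.16 (the 1.5.14 PKG-INFO carried none; see the -9 removal further down). Such entries are normally generated from install_requires in setup.py, which this release also touches (setup.py +6 -1, plus a new joonmyung.egg-info/requires.txt). A minimal hypothetical sketch of how a setup.py would produce them; the package's full setup.py is not shown in this diff, so treat names and layout here as illustrative only:

# Hypothetical sketch: setuptools emits one Requires-Dist line per
# install_requires entry into PKG-INFO when the sdist is built.
from setuptools import setup, find_packages

setup(
    name="joonmyung",
    version="1.5.16",
    packages=find_packages(),
    install_requires=[
        "fvcore", "timm", "torchprofile", "thop", "wandb", "scipy",
        "matplotlib", "seaborn", "opencv-python", "ftfy", "regex",
    ],
)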
@@ -1,4 +1,3 @@
  from .analysis import *
  from .dataset import *
- from .metric import *
  from .model import *
@@ -0,0 +1,145 @@
+ from joonmyung.draw import saliency, overlay, drawImgPlot, unNormalize, drawHeatmap
+ from joonmyung.analysis.model import JModel, ZeroShotInference
+ from timm.models.vision_transformer import Attention
+ from joonmyung.metric import targetPred, accuracy
+ from joonmyung.analysis.dataset import JDataset
+ from joonmyung.utils import read_classnames
+ from joonmyung.meta_data import data2path
+ from joonmyung.log import AverageMeter
+ import torch.nn.functional as F
+ from tqdm import tqdm
+ import numpy as np
+ import torch
+ import cv2
+
+ def anaModel(transformer_class):
+     class VisionTransformer(transformer_class):
+         info_key = []
+         def resetInfo(self):
+             self.info = {n: [] for n in self.info_key}
+
+         def createHook(self, hooks):
+             [self.info_key.append(hook[3]) for hook in hooks]
+             for name, module in self.named_modules():
+                 for idx, hook in enumerate(hooks):
+                     if hook[1] in name and hook[2] not in name:
+                         if hook[0] == "f":
+                             module.register_forward_hook(lambda mod, inp, out, hook_info=hook:
+                                                          self.forward_hook(hook_info, mod, inp, out))
+                         else:
+                             module.register_backward_hook(lambda mod, inp, out, hook_info=hook:
+                                                           self.backward_hook(hook_info, mod, inp, out))
+         def forward_hook(self, hook_info, module, input, output):
+             self.info[hook_info[3]].append(output.detach())
+
+         def backward_hook(self, hook_info, module, input, output):
+             self.info[hook_info[3]].append(input[0].detach())
+
+         def forward(self, *args, **kwdargs):
+             self.resetInfo()
+             return super().forward(*args, **kwdargs)
+         def encode_image(self, *args, **kwdargs):
+             self.resetInfo()
+             return super().encode_image(*args, **kwdargs)
+
+     return VisionTransformer
+
+ def Analysis(model, hook_info=[["f", "attn_drop", "decoder", "attn"]]):
+     model.__class__ = anaModel(model.__class__)
+     model.createHook(hook_info)
+     return model
+
+ if __name__ == '__main__':
+     dataset_name, device, debug = "imagenet", 'cuda', True
+     data_path, num_classes, _, _ = data2path(dataset_name)
+     analysis = [0]  # [0] : INPUT TYPE, [0 : SAMPLE + POS, 1 : SAMPLE, 2 : POS]
+
+     dataset = JDataset(data_path, dataset_name, device=device)
+     data_idxs = [[c, i] for i in range(1000) for c in range(50)]
+
+     modelMaker = JModel(num_classes, device=device)
+     model = modelMaker.getModel(2, "ViT-B/16")
+     classnames = read_classnames("/hub_data1/joonmyung/data/imagenet/classnames.txt")
+     model = ZeroShotInference(model, classnames, prompt="a photo of a {}.", device=device)
+     hook_info = [["b", "attn_drop", "decoder", "grad"],
+                  ["f", "attn_drop", "decoder", "attn"],
+                  ["f", "ln_pre", "decoder", "feat_1"],
+                  ["f", "ln_1", "decoder", "feat_2"],
+                  ["f", "ln_2", "decoder", "feat_3"],
+                  ["f", "ln_post", "decoder", "feat_4"]]
+     model.model = Analysis(model.model, hook_info)
+     view = [False, False, True, True, True, True]  # [IMG, SALIENCY:ATTN, SALIENCY:OPENCV, SALIENCY:GRAD, ATTN. MOVEMENT, FEAT:TEXT SIM]
+     for idx, data_idx in enumerate(data_idxs):
+         print(f"------------------------- [{data_idx[0]}]/[{data_idx[1]}] -------------------------")
+         sample, target, label_name = dataset[data_idx[0], data_idx[1]]
+         sample.requires_grad = True
+
+         if view[0]:
+             drawImgPlot(unNormalize(sample, "imagenet"))
+
+         output = model(sample)
+         index = torch.eye(num_classes, device=device)[target]
+         (output * index).sum().backward(retain_graph=True)
+
+         attns = model.model.info["attn"]
+         grads = model.model.info["grad"]
+         if view[1]:
+             col, discard_ratios, v_ratio, head_fusion, data_from = 12, [0.0], 0.0, "mean", "patch"
+             results = saliency(attns, False, head_fusion=head_fusion, discard_ratios=discard_ratios, data_from=data_from, reshape=True, device=device)
+
+             data_roll = overlay(sample, results["rollout"], dataset_name)
+             drawImgPlot(data_roll, col=col)
+
+             data_attn = overlay(sample, results["attentive"], dataset_name)
+             drawImgPlot(data_attn, col=col)
+
+             data_vidTLDR = overlay(sample, results["vidTLDR"], dataset_name)
+             drawImgPlot(data_vidTLDR, col=col)
+
+             discard_ratios, v_ratio, head_fusion, data_from = [0.0], 0.1, "mean", "cls"
+             results = saliency(attns, grads, head_fusion=head_fusion, discard_ratios=discard_ratios, data_from=data_from, reshape=True, device=device)
+
+             data_roll = overlay(sample, results["rollout"], dataset_name)
+             drawImgPlot(data_roll, col=col)
+
+             data_attn = overlay(sample, results["attentive"], dataset_name)
+             drawImgPlot(data_attn, col=col)
+
+             data_vidTLDR = overlay(sample, results["vidTLDR"], dataset_name)
+             drawImgPlot(data_vidTLDR, col=col)
+
+         if view[2]:  # SALIENCY W/ DATA
+             img = (dataset.unNormalize(sample)[0].permute(1, 2, 0).detach().cpu().numpy() * 255)
+             img_saliency = cv2.saliency.StaticSaliencySpectralResidual_create()
+             (success, saliencyMap) = img_saliency.computeSaliency(img)
+             saliencyMap = (saliencyMap * 255).astype("uint8")
+
+             img_saliency = cv2.saliency.StaticSaliencyFineGrained_create()
+             (success, saliencyFineMap) = img_saliency.computeSaliency(img)
+             threshMap = cv2.threshold((saliencyFineMap * 255).astype("uint8"), 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)[1]
+
+         if view[3]:  # SALIENCY FOR INPUT
+             output = model(sample)
+             attn = torch.stack(attns, dim=1).mean(dim=[2, 3])[0, -2]
+             a = torch.autograd.grad(output[:, 3], sample, retain_graph=True)[0].sum(dim=1)
+             b = F.interpolate(a.unsqueeze(0), scale_factor=1.0, mode='nearest')[0]
+
+         if view[4]:  # ATTENTION MOVEMENT (FROM / TO)
+             attn = torch.stack(attns).mean(dim=2).transpose(0, 1)  # (8(B), 12(L), 197(T_Q), 197(T_K))
+
+             cls2cls = attn[:, :, :1, 0].mean(dim=2)  # (8(B), 12(L))
+             patch2cls = attn[:, :, :1, 1:].mean(dim=2).sum(dim=-1)  # (8(B), 12(L))
+             cls2patch = attn[:, :, 1:, 0].mean(dim=2)
+             patch2patch = attn[:, :, 1:, 1:].mean(dim=2).sum(dim=-1)
+             # to_np(torch.stack([cls2cls.mean(dim=0), patch2cls.mean(dim=0), cls2patch.mean(dim=0), patch2patch.mean(dim=0)]))
+         if view[5]:
+             feats = {k: v for k, v in model.model.info.items() if "feat" in k}
+             for name, feat in feats.items():
+                 print(f"Feature Position : {name}")
+                 image_feat = (torch.stack(feat)[:, :, 1:] @ model.model.visual.proj)  # (1, 1, 196, 512)
+                 L = image_feat.shape[0]
+                 image_feat = image_feat / image_feat.norm(dim=-1, keepdim=True)
+
+                 text_feat = model.text_features[1][None].t()
+                 sim = (image_feat @ text_feat).reshape(L, 14, 14)
+                 drawHeatmap(sim, col=L)
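
The new analysis.py drives all instrumentation from four-element hook specs, [pass_type ("f" forward / "b" backward), include-substring, exclude-substring, output key], matched against named_modules() in createHook. A minimal self-contained sketch of that convention on a toy module (ToyNet and the tensor shapes here are illustrative, not part of the package):

import torch
import torch.nn as nn

specs = [["f", "attn_drop", "decoder", "attn"]]  # same spec format as createHook

class ToyNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.attn_drop = nn.Dropout(0.0)  # module name contains "attn_drop" -> matched

    def forward(self, x):
        return self.attn_drop(x)

model, info = ToyNet(), {s[3]: [] for s in specs}
for name, module in model.named_modules():
    for s in specs:
        if s[1] in name and s[2] not in name and s[0] == "f":
            # the default argument freezes the current spec, exactly as the
            # hook_info=hook default does in createHook above
            module.register_forward_hook(
                lambda mod, inp, out, key=s[3]: info[key].append(out.detach()))

model(torch.randn(2, 4))
print(len(info["attn"]))  # -> 1 captured activation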
@@ -1,16 +1,17 @@
- from joonmyung.analysis.dataset import JDataset
- from joonmyung.analysis.model import JModel
  from joonmyung.draw import saliency, overlay, drawImgPlot, unNormalize
- from joonmyung.meta_data import data2path
  from joonmyung.metric import targetPred, accuracy
+ from joonmyung.analysis.dataset import JDataset
+ from joonmyung.analysis.model import JModel, ZeroShotInference
+ from joonmyung.meta_data import data2path
  from joonmyung.log import AverageMeter
- from tqdm import tqdm
- from contextlib import suppress
  import torch.nn.functional as F
+ from tqdm import tqdm
  import numpy as np
  import torch
  import cv2

+ from joonmyung.utils import read_classnames
+

  def anaModel(transformer_class):
      class VisionTransformer(transformer_class):
@@ -44,78 +45,46 @@ def anaModel(transformer_class):
      return VisionTransformer

  class Analysis:
-     def __init__(self, model, analysis = [0], activate = [True, False, False, False], detach=True, key_name=None, num_classes = 1000
-                  , cls_start=0, cls_end=1, patch_start=1, patch_end=None, wrapping=False
-                  , amp_autocast=suppress, device="cuda"):
-         # Section A. Model
-         self.num_classes = num_classes
-         self.key_name = key_name
-         if wrapping:
+     def __init__(self, model, analysis = [0], activate = [True, False, False, False], num_classes = 1000, device="cuda"):
+         if sum(analysis):
              model_ = anaModel(model.__class__)
              model.__class__ = model_
              model.analysis = analysis

-         self.model = model
-         self.detach = detach
-
-         # Section B. Attention
-         self.kwargs_roll = {"cls_start" : cls_start, "cls_end" : cls_end,
-                             "patch_start" : patch_start, "patch_end" : patch_end}
-
-         # Section C. Setting
+         self.num_classes = num_classes
+         self.model = model.to(device)
          hooks = [{"name_i": 'attn_drop', "name_o": 'decoder', "fn_f": self.attn_forward, "fn_b": self.attn_backward},
-                  {"name_i": 'qkv', "name_o": 'decoder', "fn_f": self.qkv_forward, "fn_b": self.qkv_backward},
-                  {"name_i": 'head', "name_o": 'decoder', "fn_f": self.head_forward, "fn_b": self.head_backward},
-                  {"name_i": 'patch_embed.norm', "name_o": 'decoder', "fn_f": self.input_forward, "fn_b": self.input_backward}]
-         self.activate = activate
-
-         self.amp_autocast = amp_autocast
-         self.device = device
+                  {"name_i": 'qkv', "name_o": 'decoder', "fn_f": self.qkv_forward, "fn_b": None},
+                  {"name_i": 'head', "name_o": 'decoder', "fn_f": self.head_forward, "fn_b": None},
+                  {"name_i": 'patch_embed.norm', "name_o": 'decoder', "fn_f": self.input_forward, "fn_b": None}]

          for name, module in self.model.named_modules():
-             for hook in hooks:
-                 if hook["name_i"] in name and hook["name_o"] not in name:
-                     module.register_forward_hook(hook["fn_f"])
-                     module.register_backward_hook(hook["fn_b"])
-         self.resetInfo()
+             for idx, hook in enumerate(hooks):
+                 if hook["name_i"] in name and hook["name_o"] not in name and activate[idx]:
+                     if hook["fn_f"]: module.register_forward_hook(hook["fn_f"])
+                     if hook["fn_b"]: module.register_backward_hook(hook["fn_b"])

-     def attn_forward(self, module, input, output):
-         # input/output : 1 * (8, 3, 197, 197) / (8, 3, 197, 197)
-         if self.activate[0]: self.info["attn"]["f"].append(output.detach())
+     def attn_forward(self, module, input, output):  # input/output : 1 * (8, 3, 197, 197) / (8, 3, 197, 197)
+         self.info["attn"]["f"].append(output.detach())

-     def attn_backward(self, module, grad_input, grad_output):
-         # input/output : 1 * (8, 3, 197, 192) / (8, 3, 197, 576)
-         if self.activate[0]: self.info["attn"]["b"].append(grad_input[0].detach())
+     def attn_backward(self, module, grad_input, grad_output):  # input/output : 1 * (8, 3, 197, 192) / (8, 3, 197, 576)
+         self.info["attn"]["b"].append(grad_input[0].detach())

-     def qkv_forward(self, module, input, output):
-         # input/output : 1 * (8, 197, 192) / (8, 197, 576)
-         if self.activate[1]: self.info["qkv"]["f"].append(output.detach())
+     def qkv_forward(self, module, input, output):  # input/output : 1 * (8, 197, 192) / (8, 197, 576)
+         self.info["qkv"]["f"].append(output.detach())

-     def qkv_backward(self, module, grad_input, grad_output):
-         self.info["qkv"]["b"].append(grad_input[0].detach())
-         # pass
+     def head_forward(self, module, input, output):  # input : 1 * (8(B), 192(D)), output : (8(B), 1000(C))
+         B = output.shape[0]
+         pred = targetPred(output, self.targets, topk=5)
+         self.info["head"]["TF"] += (pred[:, 0] == pred[:, 1])

-     def head_forward(self, module, input, output):
-         # input : 1 * (8(B), 192(D)), output : (8(B), 1000(C))
-         if self.activate[2]:
-             B = output.shape[0]
-             pred = targetPred(output, self.targets, topk=5)
-             self.info["head"]["TF"] += (pred[:, 0] == pred[:, 1])
-
-             acc1, acc5 = accuracy(output, self.targets, topk=(1,5))
-             self.info["head"]["acc1"].update(acc1.item(), n=B)
-             self.info["head"]["acc5"].update(acc5.item(), n=B)
-
-     def head_backward(self, module, grad_input, grad_output):
-         pass
+         acc1, acc5 = accuracy(output, self.targets, topk=(1,5))
+         self.info["head"]["acc1"].update(acc1.item(), n=B)
+         self.info["head"]["acc5"].update(acc5.item(), n=B)

      def input_forward(self, module, input, output):
-         if self.activate[3]:
-             norm = F.normalize(output, dim=-1)
-             self.info["input"]["sim"] += (norm @ norm.transpose(-1, -2)).mean(dim=(-1, -2))
-
-     def input_backward(self, module, grad_input, grad_output):
-         pass
+         norm = F.normalize(output, dim=-1)
+         self.info["input"]["sim"] += (norm @ norm.transpose(-1, -2)).mean(dim=(-1, -2))

      def resetInfo(self):
          self.info = {"attn" : {"f": [], "b": []},
@@ -151,7 +120,7 @@ class Analysis:
              self.info["attn"]["b"] = []
              self.model.zero_grad()
              if index == None: index = output.max(dim=1)[1]
-             index = torch.eye(self.num_classes, device=self.device)[index]
+             index = torch.eye(self.num_classes, device=device)[index]
              loss = (output * index).sum()
              loss.backward(retain_graph=True)
              grad = self.info["attn"]["b"]
@@ -162,31 +131,27 @@


  if __name__ == '__main__':
-     # Section A. Data
-     dataset_name, device, amp_autocast, debug = "imagenet", 'cuda', torch.cuda.amp.autocast, True
+     dataset_name, device, debug = "imagenet", 'cuda', True
      data_path, num_classes, _, _ = data2path(dataset_name)
-     view, activate = [False, True, False, False, True], [True, False, False]
-     # VIEW : IMG, SALIENCY:ATTN, SALIENCY:OPENCV, SALIENCY:GRAD, ATTN. MOVEMENT
-     # ACTIVATE : ATTN, QKV, HEAD
+     activate = [True, False, False, False]  # [ATTN, QKV, HEAD]
      analysis = [0]  # [0] : INPUT TYPE, [0 : SAMPLE + POS, 1 : SAMPLE, 2 : POS]

      dataset = JDataset(data_path, dataset_name, device=device)
-     # data_idxs = [[c, i] for i in range(1000) for c in range(50)]
-     data_idxs = [[1, 0]]
+     data_idxs = [[c, i] for i in range(1000) for c in range(50)]

-     # Section B. Model
-     model_number, model_name = 0, "deit_tiny_patch16_224"  # deit, vit | tiny, small, base
-     # model_number, model_name = 1, "deit_tiny_patch16_224"
-
-     # Section C. Setting
      modelMaker = JModel(num_classes, device=device)
-     model = modelMaker.getModel(model_number, model_name)
+     model = modelMaker.getModel(2, "ViT-B/16")
+
+     classnames = read_classnames("/hub_data1/joonmyung/data/imagenet/classnames.txt")
+     model = ZeroShotInference(model, classnames, prompt="a photo of a {}.", device=device)
+
      model = Analysis(model, analysis = analysis, activate = activate, device=device)
+
+     view = [False, True, False, False, True]  # [IMG, SALIENCY:ATTN, SALIENCY:OPENCV, SALIENCY:GRAD, ATTN. MOVEMENT]
      for idx, data_idx in enumerate(data_idxs):
          print(f"------------------------- [{data_idx[0]}]/[{data_idx[1]}] -------------------------")

          sample, target, label_name = dataset[data_idx[0], data_idx[1]]
-         # sample, _, img, _ = dataset.getItemPath('/hub_data1/joonmyung/data/imagenet/train/n01440764/n01440764_39.JPEG')
          output = model(sample)
          if view[0]:
              drawImgPlot(unNormalize(sample, "imagenet"))
@@ -219,22 +184,6 @@ if __name__ == '__main__':
              data_vidTLDR = overlay(sample, results["vidTLDR"], dataset_name)
              drawImgPlot(data_vidTLDR, col=col)

-             print(1)
-
-             # roll = F.normalize(results["rollout"].reshape(12, 196), dim=-1)
-
-             # datas_rollout = overlay(sample, rollout, dataset_name)
-             # drawImgPlot(datas_rollout, col=col)
-
-             # datas_attn = overlay(sample, attentive, dataset_name)
-             # drawImgPlot(datas_attn, col=col)
-
-             # a = attentive[5]
-             # b = torch.stack([a.clamp(max=a.quantile(1 - v_ratio)) for v_ratio in [0.0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55]])
-             # datas_attn = overlay(sample, b, dataset_name)
-             # drawImgPlot(datas_attn, col=col)
-             # print(1)
-
          if view[2]:  # SALIENCY W/ DATA
              img = np.array(dataset[data_idx[0], data_idx[1], 2][0])
@@ -253,12 +202,8 @@
              output = model(sample)
              attn = torch.stack(model.info["attn"]["f"], dim=1).mean(dim=[2,3])[0,-2]
              topK = attn[1:].topk(k, -1, True)[1]
-             # a = torch.autograd.grad(attn.sum(), samples, retain_graph=True)[0].sum(dim=1)
              a = torch.autograd.grad(output[:,3], sample, retain_graph=True)[0].sum(dim=1)
              b = F.interpolate(a.unsqueeze(0), scale_factor=0.05, mode='nearest')[0]
-             # drawHeatmap(b)
-             print(1)
-             # to_np(torch.stack([attn[:, :, 0], attn[:, :, 1:].sum(dim=-1)], -1)[0])

          if view[4]:  # ATTENTION MOVEMENT (FROM / TO)
              attn = torch.stack(model.info["attn"]["f"]).mean(dim=2).transpose(0,1)  # (8(B), 12(L), 197(T_Q), 197(T_K))
@@ -270,5 +215,4 @@
              # How much the patch tokens attend (patch queries over CLS / patch keys)
              cls2patch = attn[:, :, 1:, 0].mean(dim=2)
              patch2patch = attn[:, :, 1:, 1:].mean(dim=2).sum(dim=-1)
-             # to_np(torch.stack([cls2cls.mean(dim=0), patch2cls.mean(dim=0), cls2patch.mean(dim=0), patch2patch.mean(dim=0)]))
-             print(1)
+             # to_np(torch.stack([cls2cls.mean(dim=0), patch2cls.mean(dim=0), cls2patch.mean(dim=0), patch2patch.mean(dim=0)]))
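
Both versions of anaSaliency select class-specific gradients with a one-hot mask built from torch.eye rather than indexing the logits directly. A short sketch of why that works; batch and class sizes here are illustrative:

import torch

B, C = 4, 10
output = torch.randn(B, C, requires_grad=True)
index = output.max(dim=1)[1]            # predicted class per sample, as in anaSaliency
one_hot = torch.eye(C)[index]           # (B, C) rows with a single 1
loss = (output * one_hot).sum()

# masking + summing yields the same scalar as gathering the chosen logits...
assert torch.allclose(loss, output.gather(1, index[:, None]).sum())

# ...so backward() propagates gradient only through those logits, which is
# what fills info["attn"]["b"] via the registered backward hooks
loss.backward(retain_graph=True)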
@@ -0,0 +1,218 @@
+ from joonmyung.draw import saliency, overlay, drawImgPlot, unNormalize
+ from joonmyung.metric import targetPred, accuracy
+ from joonmyung.analysis.dataset import JDataset
+ from joonmyung.analysis.model import JModel, ZeroShotInference
+ from joonmyung.meta_data import data2path
+ from joonmyung.log import AverageMeter
+ import torch.nn.functional as F
+ from tqdm import tqdm
+ import numpy as np
+ import torch
+ import cv2
+
+ from joonmyung.utils import read_classnames
+
+
+ def anaModel(transformer_class):
+     class VisionTransformer(transformer_class):
+         def forward_features(self, x):
+             x = self.patch_embed(x)
+             if hasattr(self, "cls_token"):
+                 cls_token = self.cls_token.expand(x.shape[0], -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
+                 x = torch.cat((cls_token, x), dim=1)
+
+             if self.analysis[0] == 1:    # PATCH
+                 x = x                    # (8, 197, 192)
+             elif self.analysis[0] == 2:  # POS
+                 x = self.pos_embed       # (1, 197, 192)
+             elif self.analysis[0] == 3:  # PATCH (RANDOM I) + POS
+                 x = torch.rand_like(self.pos_embed, device=x.device) + self.pos_embed
+             elif self.analysis[0] == 4:  # PATCH (RANDOM II) + POS
+                 x = torch.rand_like(self.cls_token, device=x.device).repeat(1, x.shape[1], 1) + self.pos_embed
+             else:                        # PATCH + POS
+                 x = x + self.pos_embed
+             x = self.pos_drop(x)
+
+             x = self.blocks(x)
+             x = self.norm(x)
+             if hasattr(self, "cls_token") and hasattr(self, "dist_token"):
+                 return x[:, 0], x[:, 1]
+             elif hasattr(self, "cls_token"):
+                 return self.pre_logits(x[:, 0])
+             else:
+                 return self.pre_logits(x.mean(dim=1))
+
+     return VisionTransformer
+
+ class Analysis:
+     def __init__(self, model, analysis = [0], activate = [True, False, False, False], num_classes = 1000, device="cuda"):
+         if sum(analysis):
+             model_ = anaModel(model.__class__)
+             model.__class__ = model_
+             model.analysis = analysis
+
+         self.num_classes = num_classes
+         self.model = model.to(device)
+         hooks = [{"name_i": 'attn_drop', "name_o": 'decoder', "fn_f": self.attn_forward, "fn_b": self.attn_backward},
+                  {"name_i": 'qkv', "name_o": 'decoder', "fn_f": self.qkv_forward, "fn_b": None},
+                  {"name_i": 'head', "name_o": 'decoder', "fn_f": self.head_forward, "fn_b": None},
+                  {"name_i": 'patch_embed.norm', "name_o": 'decoder', "fn_f": self.input_forward, "fn_b": None}]
+
+         for name, module in self.model.named_modules():
+             for idx, hook in enumerate(hooks):
+                 if hook["name_i"] in name and hook["name_o"] not in name and activate[idx]:
+                     if hook["fn_f"]: module.register_forward_hook(hook["fn_f"])
+                     if hook["fn_b"]: module.register_backward_hook(hook["fn_b"])
+
+     def attn_forward(self, module, input, output):  # input/output : 1 * (8, 3, 197, 197) / (8, 3, 197, 197)
+         self.info["attn"]["f"].append(output.detach())
+
+     def attn_backward(self, module, grad_input, grad_output):  # input/output : 1 * (8, 3, 197, 192) / (8, 3, 197, 576)
+         self.info["attn"]["b"].append(grad_input[0].detach())
+
+     def qkv_forward(self, module, input, output):  # input/output : 1 * (8, 197, 192) / (8, 197, 576)
+         self.info["qkv"]["f"].append(output.detach())
+
+     def head_forward(self, module, input, output):  # input : 1 * (8(B), 192(D)), output : (8(B), 1000(C))
+         B = output.shape[0]
+         pred = targetPred(output, self.targets, topk=5)
+         self.info["head"]["TF"] += (pred[:, 0] == pred[:, 1])
+
+         acc1, acc5 = accuracy(output, self.targets, topk=(1,5))
+         self.info["head"]["acc1"].update(acc1.item(), n=B)
+         self.info["head"]["acc5"].update(acc5.item(), n=B)
+
+     def input_forward(self, module, input, output):
+         norm = F.normalize(output, dim=-1)
+         self.info["input"]["sim"] += (norm @ norm.transpose(-1, -2)).mean(dim=(-1, -2))
+
+     def resetInfo(self):
+         self.info = {"attn" : {"f": [], "b": []},
+                      "qkv" : {"f": [], "b": []},
+                      "head" : {"acc1" : AverageMeter(),
+                                "acc5" : AverageMeter(),
+                                "TF" : [], "pred" : []},
+                      "input": {"sim" : []}
+                      }
+
+     def __call__(self, samples, targets = None, **kwargs):
+         self.resetInfo()
+         self.model.zero_grad()
+         self.model.eval()
+
+         if type(samples) == torch.Tensor:
+             self.targets = targets
+             outputs = self.model(samples, **kwargs)
+             return outputs
+         else:
+             for sample, targets in tqdm(samples):
+                 self.targets = targets
+                 _ = self.model(sample)
+             return False
+
+     def anaSaliency(self, attn=True, grad=False, output=None, index=None,
+                     head_fusion="mean", discard_ratios=[0.], data_from="cls",
+                     reshape=False, activate=[True, True, False], device="cuda"):
+
+         if attn:
+             attn = self.info["attn"]["f"]
+         if grad:
+             self.info["attn"]["b"] = []
+             self.model.zero_grad()
+             if index == None: index = output.max(dim=1)[1]
+             index = torch.eye(self.num_classes, device=device)[index]
+             loss = (output * index).sum()
+             loss.backward(retain_graph=True)
+             grad = self.info["attn"]["b"]
+
+         return saliency(attn, grad, activate=activate,
+                         head_fusion=head_fusion, discard_ratios=discard_ratios, data_from=data_from,
+                         reshape=reshape, device=device)
+
+
+ if __name__ == '__main__':
+     dataset_name, device, debug = "imagenet", 'cuda', True
+     data_path, num_classes, _, _ = data2path(dataset_name)
+     activate = [True, False, False, False]  # [ATTN, QKV, HEAD]
+     analysis = [0]  # [0] : INPUT TYPE, [0 : SAMPLE + POS, 1 : SAMPLE, 2 : POS]
+
+     dataset = JDataset(data_path, dataset_name, device=device)
+     data_idxs = [[c, i] for i in range(1000) for c in range(50)]
+
+     modelMaker = JModel(num_classes, device=device)
+     model = modelMaker.getModel(2, "ViT-B/16")
+
+     classnames = read_classnames("/hub_data1/joonmyung/data/imagenet/classnames.txt")
+     model = ZeroShotInference(model, classnames, prompt="a photo of a {}.", device=device)
+
+     model = Analysis(model, analysis = analysis, activate = activate, device=device)
+
+     view = [False, True, False, False, True]  # [IMG, SALIENCY:ATTN, SALIENCY:OPENCV, SALIENCY:GRAD, ATTN. MOVEMENT]
+     for idx, data_idx in enumerate(data_idxs):
+         print(f"------------------------- [{data_idx[0]}]/[{data_idx[1]}] -------------------------")
+
+         sample, target, label_name = dataset[data_idx[0], data_idx[1]]
+         output = model(sample)
+         if view[0]:
+             drawImgPlot(unNormalize(sample, "imagenet"))
+
+         if view[1]:  # SALIENCY W/ MODEL
+             col, discard_ratios, v_ratio, head_fusion, data_from = 12, [0.0], 0.0, "mean", "patch"  # Attention, Gradient
+             results = model.anaSaliency(True, False, output, discard_ratios=discard_ratios,
+                                         head_fusion=head_fusion, index=target, data_from=data_from,
+                                         reshape=True, activate=[True, True, True])  # (12(L), 8(B), 14(H), 14(W))
+             data_roll = overlay(sample, results["rollout"], dataset_name)
+             drawImgPlot(data_roll, col=col)
+
+             data_attn = overlay(sample, results["attentive"], dataset_name)
+             drawImgPlot(data_attn, col=col)
+
+             data_vidTLDR = overlay(sample, results["vidTLDR"], dataset_name)
+             drawImgPlot(data_vidTLDR, col=col)
+
+             discard_ratios, v_ratio, head_fusion, data_from = [0.0], 0.1, "mean", "cls"
+             results = model.anaSaliency(True, False, output, discard_ratios=discard_ratios,
+                                         head_fusion=head_fusion, index=target, data_from=data_from,
+                                         reshape=True, activate=[True, True, True])  # (12(L), 8(B), 14(H), 14(W))
+
+             data_roll = overlay(sample, results["rollout"], dataset_name)
+             drawImgPlot(data_roll, col=col)
+
+             data_attn = overlay(sample, results["attentive"], dataset_name)
+             drawImgPlot(data_attn, col=col)
+
+             data_vidTLDR = overlay(sample, results["vidTLDR"], dataset_name)
+             drawImgPlot(data_vidTLDR, col=col)
+
+         if view[2]:  # SALIENCY W/ DATA
+             img = np.array(dataset[data_idx[0], data_idx[1], 2][0])
+
+             img_saliency = cv2.saliency.StaticSaliencySpectralResidual_create()
+             (success, saliencyMap) = img_saliency.computeSaliency(img)
+             saliencyMap = (saliencyMap * 255).astype("uint8")
+
+             img_saliency = cv2.saliency.StaticSaliencyFineGrained_create()
+             (success, saliencyFineMap) = img_saliency.computeSaliency(img)
+             threshMap = cv2.threshold((saliencyFineMap * 255).astype("uint8"), 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)[1]
+             # plt.imshow(threshMap)
+             # plt.show()
+
+         if view[3]:  # SALIENCY FOR INPUT
+             sample.requires_grad, model.detach, k = True, False, 3
+             output = model(sample)
+             attn = torch.stack(model.info["attn"]["f"], dim=1).mean(dim=[2,3])[0,-2]
+             topK = attn[1:].topk(k, -1, True)[1]
+             a = torch.autograd.grad(output[:,3], sample, retain_graph=True)[0].sum(dim=1)
+             b = F.interpolate(a.unsqueeze(0), scale_factor=0.05, mode='nearest')[0]
+
+         if view[4]:  # ATTENTION MOVEMENT (FROM / TO)
+             attn = torch.stack(model.info["attn"]["f"]).mean(dim=2).transpose(0,1)  # (8(B), 12(L), 197(T_Q), 197(T_K))
+
+             # How much the CLS token attends (CLS query over CLS / patch keys)
+             cls2cls = attn[:, :, :1, 0].mean(dim=2)  # (8(B), 12(L))
+             patch2cls = attn[:, :, :1, 1:].mean(dim=2).sum(dim=-1)  # (8(B), 12(L))
+
+             # How much the patch tokens attend (patch queries over CLS / patch keys)
+             cls2patch = attn[:, :, 1:, 0].mean(dim=2)
+             patch2patch = attn[:, :, 1:, 1:].mean(dim=2).sum(dim=-1)
+             # to_np(torch.stack([cls2cls.mean(dim=0), patch2cls.mean(dim=0), cls2patch.mean(dim=0), patch2patch.mean(dim=0)]))
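
The ATTENTION MOVEMENT block splits each softmax row into CLS-key and patch-key mass. Since every query row sums to 1 over keys, the two CLS-query terms and the two patch-query terms must each sum to 1; a quick self-check on a random tensor with the shapes from the comment above (token 0 = CLS, shapes illustrative):

import torch

B, L, T = 8, 12, 197
attn = torch.rand(B, L, T, T).softmax(dim=-1)  # (B, L, T_Q, T_K), rows sum to 1

cls2cls     = attn[:, :, :1, 0].mean(dim=2)               # CLS query -> CLS key
patch2cls   = attn[:, :, :1, 1:].mean(dim=2).sum(dim=-1)  # CLS query -> patch keys
cls2patch   = attn[:, :, 1:, 0].mean(dim=2)               # patch queries -> CLS key
patch2patch = attn[:, :, 1:, 1:].mean(dim=2).sum(dim=-1)  # patch queries -> patch keys

# each query group's shares over keys sum to 1
assert torch.allclose(cls2cls + patch2cls, torch.ones(B, L))
assert torch.allclose(cls2patch + patch2patch, torch.ones(B, L))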