joonmyung 1.5.14.tar.gz → 1.5.16.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- joonmyung-1.5.16/PKG-INFO +20 -0
- {joonmyung-1.5.14 → joonmyung-1.5.16}/joonmyung/analysis/__init__.py +0 -1
- joonmyung-1.5.16/joonmyung/analysis/analysis.py +145 -0
- joonmyung-1.5.14/joonmyung/analysis/analysis.py → joonmyung-1.5.16/joonmyung/analysis/analysis_bak.py +44 -100
- joonmyung-1.5.16/joonmyung/analysis/analysis_ㅠㅏ.py +218 -0
- {joonmyung-1.5.14 → joonmyung-1.5.16}/joonmyung/analysis/dataset.py +13 -19
- joonmyung-1.5.16/joonmyung/analysis/evaluate.py +39 -0
- joonmyung-1.5.16/joonmyung/analysis/model.py +109 -0
- joonmyung-1.5.16/joonmyung/clip/__init__.py +1 -0
- joonmyung-1.5.16/joonmyung/clip/clip.py +221 -0
- joonmyung-1.5.16/joonmyung/clip/model.py +445 -0
- joonmyung-1.5.16/joonmyung/clip/simple_tokenizer.py +131 -0
- joonmyung-1.5.16/joonmyung/compression/apply.py +139 -0
- joonmyung-1.5.16/joonmyung/compression/compression.py +202 -0
- {joonmyung-1.5.14 → joonmyung-1.5.16}/joonmyung/draw.py +72 -6
- {joonmyung-1.5.14 → joonmyung-1.5.16}/joonmyung/metric.py +3 -1
- joonmyung-1.5.16/joonmyung/model/__init__.py +0 -0
- joonmyung-1.5.16/joonmyung/model/compression.py +202 -0
- joonmyung-1.5.16/joonmyung/model.py +0 -0
- joonmyung-1.5.16/joonmyung/models/__init__.py +0 -0
- {joonmyung-1.5.14 → joonmyung-1.5.16}/joonmyung/script.py +21 -14
- {joonmyung-1.5.14 → joonmyung-1.5.16}/joonmyung/utils.py +16 -1
- joonmyung-1.5.16/joonmyung.egg-info/PKG-INFO +20 -0
- {joonmyung-1.5.14 → joonmyung-1.5.16}/joonmyung.egg-info/SOURCES.txt +14 -1
- joonmyung-1.5.16/joonmyung.egg-info/requires.txt +11 -0
- {joonmyung-1.5.14 → joonmyung-1.5.16}/setup.py +6 -1
- joonmyung-1.5.14/PKG-INFO +0 -9
- joonmyung-1.5.14/joonmyung/analysis/metric.py +0 -35
- joonmyung-1.5.14/joonmyung/analysis/model.py +0 -55
- joonmyung-1.5.14/joonmyung.egg-info/PKG-INFO +0 -9
- {joonmyung-1.5.14 → joonmyung-1.5.16}/LICENSE.txt +0 -0
- {joonmyung-1.5.14 → joonmyung-1.5.16}/README.md +0 -0
- {joonmyung-1.5.14 → joonmyung-1.5.16}/joonmyung/__init__.py +0 -0
- {joonmyung-1.5.14 → joonmyung-1.5.16}/joonmyung/analysis/hook.py +0 -0
- {joonmyung-1.5.14 → joonmyung-1.5.16}/joonmyung/analysis/utils.py +0 -0
- {joonmyung-1.5.14 → joonmyung-1.5.16}/joonmyung/app.py +0 -0
- {joonmyung-1.5.14/joonmyung/models → joonmyung-1.5.16/joonmyung/compression}/__init__.py +0 -0
- {joonmyung-1.5.14 → joonmyung-1.5.16}/joonmyung/data.py +0 -0
- {joonmyung-1.5.14 → joonmyung-1.5.16}/joonmyung/dummy.py +0 -0
- {joonmyung-1.5.14 → joonmyung-1.5.16}/joonmyung/file.py +0 -0
- {joonmyung-1.5.14 → joonmyung-1.5.16}/joonmyung/gradcam.py +0 -0
- {joonmyung-1.5.14 → joonmyung-1.5.16}/joonmyung/log.py +0 -0
- {joonmyung-1.5.14 → joonmyung-1.5.16}/joonmyung/meta_data/__init__.py +0 -0
- {joonmyung-1.5.14 → joonmyung-1.5.16}/joonmyung/meta_data/label.py +0 -0
- {joonmyung-1.5.14 → joonmyung-1.5.16}/joonmyung/meta_data/utils.py +0 -0
- {joonmyung-1.5.14 → joonmyung-1.5.16}/joonmyung/models/tome.py +0 -0
- {joonmyung-1.5.14 → joonmyung-1.5.16}/joonmyung/status.py +0 -0
- {joonmyung-1.5.14 → joonmyung-1.5.16}/joonmyung.egg-info/dependency_links.txt +0 -0
- {joonmyung-1.5.14 → joonmyung-1.5.16}/joonmyung.egg-info/not-zip-safe +0 -0
- {joonmyung-1.5.14 → joonmyung-1.5.16}/joonmyung.egg-info/top_level.txt +0 -0
- {joonmyung-1.5.14 → joonmyung-1.5.16}/setup.cfg +0 -0
@@ -0,0 +1,20 @@
+Metadata-Version: 2.1
+Name: joonmyung
+Version: 1.5.16
+Summary: JoonMyung's Library
+Home-page: https://github.com/pizard/JoonMyung.git
+Author: JoonMyung Choi
+Author-email: pizard@korea.ac.kr
+License: MIT
+License-File: LICENSE.txt
+Requires-Dist: fvcore
+Requires-Dist: timm
+Requires-Dist: torchprofile
+Requires-Dist: thop
+Requires-Dist: wandb
+Requires-Dist: scipy
+Requires-Dist: matplotlib
+Requires-Dist: seaborn
+Requires-Dist: opencv-python
+Requires-Dist: ftfy
+Requires-Dist: regex
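
The eleven Requires-Dist entries above line up with the new joonmyung.egg-info/requires.txt (+11) and the setup.py change (+6 -1) in the file list. The actual setup.py contents are not shown in this view, but a minimal sketch of the kind of setuptools call that emits this metadata looks like:

    # Hypothetical sketch only; the real joonmyung-1.5.16 setup.py is not included in this diff.
    from setuptools import setup, find_packages

    setup(
        name="joonmyung",
        version="1.5.16",
        packages=find_packages(),
        install_requires=[          # each entry becomes a Requires-Dist line in PKG-INFO
            "fvcore", "timm", "torchprofile", "thop", "wandb", "scipy",
            "matplotlib", "seaborn", "opencv-python", "ftfy", "regex",
        ],
    )
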
@@ -0,0 +1,145 @@
+from joonmyung.draw import saliency, overlay, drawImgPlot, unNormalize, drawHeatmap
+from joonmyung.analysis.model import JModel, ZeroShotInference
+from timm.models.vision_transformer import Attention
+from joonmyung.metric import targetPred, accuracy
+from joonmyung.analysis.dataset import JDataset
+from joonmyung.utils import read_classnames
+from joonmyung.meta_data import data2path
+from joonmyung.log import AverageMeter
+import torch.nn.functional as F
+from tqdm import tqdm
+import numpy as np
+import torch
+import cv2
+
+def anaModel(transformer_class):
+    class VisionTransformer(transformer_class):
+        info_key = []
+        def resetInfo(self):
+            self.info = {n: [] for n in self.info_key}
+
+        def createHook(self, hooks):
+            [self.info_key.append(hook[3]) for hook in hooks]
+            for name, module in self.named_modules():
+                for idx, hook in enumerate(hooks):
+                    if hook[1] in name and hook[2] not in name:
+                        if hook[0] == "f":
+                            module.register_forward_hook(lambda mod, inp, out, hook_info=hook:
+                                                         self.forward_hook(hook_info, mod, inp, out))
+                        else:
+                            module.register_backward_hook(lambda mod, inp, out, hook_info=hook:
+                                                          self.backward_hook(hook_info, mod, inp, out))
+        def forward_hook(self, hook_info, module, input, output):
+            self.info[hook_info[3]].append(output.detach())
+
+        def backward_hook(self, hook_info, module, input, output):
+            self.info[hook_info[3]].append(input[0].detach())
+
+        def forward(self, *args, **kwdargs):
+            self.resetInfo()
+            return super().forward(*args, **kwdargs)
+        def encode_image(self, *args, **kwdargs):
+            self.resetInfo()
+            return super().encode_image(*args, **kwdargs)
+
+    return VisionTransformer
+
+def Analysis(model, hook_info= [["f", "attn_drop", "decoder", "attn"]]):
+    model.__class__ = anaModel(model.__class__)
+    model.createHook(hook_info)
+    return model
+
+if __name__ == '__main__':
+    dataset_name, device, debug = "imagenet", 'cuda', True
+    data_path, num_classes, _, _ = data2path(dataset_name)
+    analysis = [0] # [0] : INPUT TYPE, [0 : SAMPLE + POS, 1 : SAMPLE, 2 : POS]
+
+    dataset = JDataset(data_path, dataset_name, device=device)
+    data_idxs = [[c, i] for i in range(1000) for c in range(50)]
+
+    modelMaker = JModel(num_classes, device=device)
+    model = modelMaker.getModel(2, "ViT-B/16")
+    classnames = read_classnames("/hub_data1/joonmyung/data/imagenet/classnames.txt")
+    model = ZeroShotInference(model, classnames, prompt="a photo of a {}.", device=device)
+    hook_info = [["b", "attn_drop", "decoder", "grad"],
+                 ["f", "attn_drop", "decoder", "attn"],
+                 ["f", "ln_pre", "decoder", "feat_1"],
+                 ["f", "ln_1", "decoder", "feat_2"],
+                 ["f", "ln_2", "decoder", "feat_3"],
+                 ["f", "ln_post", "decoder", "feat_4"]]
+    model.model = Analysis(model.model, hook_info)
+    view = [False, False, True, True, True, True] # [IMG, SALIENCY:ATTN, SALIENCY:OPENCV, SALIENCY:GRAD, ATTN. MOVEMENT]
+    for idx, data_idx in enumerate(data_idxs):
+        print(f"------------------------- [{data_idx[0]}]/[{data_idx[1]}] -------------------------")
+        sample, target, label_name = dataset[data_idx[0], data_idx[1]]
+        sample.requires_grad = True
+
+        if view[0]:
+            drawImgPlot(unNormalize(sample, "imagenet"))
+
+        output = model(sample)
+        index = torch.eye(num_classes, device=device)[target]
+        (output * index).sum().backward(retain_graph=True)
+
+        attns = model.model.info["attn"]
+        grads = model.model.info["grad"]
+        if view[1]:
+            col, discard_ratios, v_ratio, head_fusion, data_from = 12, [0.0], 0.0, "mean", "patch"
+            results = saliency(attns, False, head_fusion=head_fusion, discard_ratios=discard_ratios, data_from=data_from, reshape=True, device=device)
+
+            data_roll = overlay(sample, results["rollout"], dataset_name)
+            drawImgPlot(data_roll, col=col)
+
+            data_attn = overlay(sample, results["attentive"], dataset_name)
+            drawImgPlot(data_attn, col=col)
+
+            data_vidTLDR = overlay(sample, results["vidTLDR"], dataset_name)
+            drawImgPlot(data_vidTLDR, col=col)
+
+            discard_ratios, v_ratio, head_fusion, data_from = [0.0], 0.1, "mean", "cls"
+            results = saliency(attns, grads, head_fusion=head_fusion, discard_ratios=discard_ratios, data_from=data_from, reshape=True, device=device)
+
+            data_roll = overlay(sample, results["rollout"], dataset_name)
+            drawImgPlot(data_roll, col=col)
+
+            data_attn = overlay(sample, results["attentive"], dataset_name)
+            drawImgPlot(data_attn, col=col)
+
+            data_vidTLDR = overlay(sample, results["vidTLDR"], dataset_name)
+            drawImgPlot(data_vidTLDR, col=col)
+
+        if view[2]: # SALIENCY W/ DATA
+            img = (dataset.unNormalize(sample)[0].permute(1, 2, 0).detach().cpu().numpy() * 255)
+            img_saliency = cv2.saliency.StaticSaliencySpectralResidual_create()
+            (success, saliencyMap) = img_saliency.computeSaliency(img)
+            saliencyMap = (saliencyMap * 255).astype("uint8")
+
+            img_saliency = cv2.saliency.StaticSaliencyFineGrained_create()
+            (success, saliencyFineMap) = img_saliency.computeSaliency(img)
+            threshMap = cv2.threshold((saliencyFineMap * 255).astype("uint8"), 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)[1]
+
+        if view[3]: # SALIENCY FOR INPUT
+            output = model(sample)
+            attn = torch.stack(attns, dim=1).mean(dim=[2, 3])[0, -2]
+            a = torch.autograd.grad(output[:, 3], sample, retain_graph=True)[0].sum(dim=1)
+            b = F.interpolate(a.unsqueeze(0), scale_factor=1.0, mode='nearest')[0]
+
+        if view[4]: # ATTENTION MOVEMENT (FROM / TO)
+            attn = torch.stack(attns).mean(dim=2).transpose(0,1) # (8 (B), 12 (L), 197(T_Q), 197(T_K))
+
+            cls2cls = attn[:, :, :1, 0].mean(dim=2) # (8(B), 12(L))
+            patch2cls = attn[:, :, :1, 1:].mean(dim=2).sum(dim=-1) # (8(B), 12(L))
+            cls2patch = attn[:, :, 1:, 0].mean(dim=2)
+            patch2patch = attn[:, :, 1:, 1:].mean(dim=2).sum(dim=-1)
+            # to_np(torch.stack([cls2cls.mean(dim=0), patch2cls.mean(dim=0), cls2patch.mean(dim=0), patch2patch.mean(dim=0)]))
+        if view[5]:
+            feats = {k: v for k, v in model.model.info if "feat" in k}
+            for name, feat in feats.items():
+                print(f"Feature Position : {name}")
+                image_feat = (torch.stack(feat)[:, :, 1:] @ model.model.visual.proj) # (1, 1, 196, 512)
+                L = image_feat.shape[0]
+                image_feat = image_feat / image_feat.norm(dim=-1, keepdim=True)
+
+                text_feat = model.text_features[1][None].t()
+                sim = (image_feat @ text_feat).reshape(L, 14, 14)
+                drawHeatmap(sim, col = L)
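
For orientation, the anaModel/Analysis pattern above swaps a model's __class__ at runtime and registers forward/backward hooks by substring-matching module names. A minimal, self-contained sketch of the same idea against a stock timm ViT (the model name, hook target, and helper below are illustrative assumptions, not code from this package):

    import timm
    import torch

    def attach_attn_hooks(model, target="attn_drop", exclude="decoder"):
        model.info = {"attn": []}

        def save_attn(module, inputs, output):
            # attn_drop sees the post-softmax attention map, shaped (B, heads, tokens, tokens)
            model.info["attn"].append(output.detach())

        for name, module in model.named_modules():
            if target in name and exclude not in name:
                module.register_forward_hook(save_attn)
        return model

    model = attach_attn_hooks(timm.create_model("deit_tiny_patch16_224", pretrained=False))
    with torch.no_grad():
        _ = model(torch.randn(1, 3, 224, 224))
    # Expect one captured map per transformer block; note that recent timm releases may
    # take a fused-attention path in which attn_drop is never called and nothing is captured.
    print(len(model.info["attn"]))
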
@@ -1,16 +1,17 @@
-from joonmyung.analysis.dataset import JDataset
-from joonmyung.analysis.model import JModel
 from joonmyung.draw import saliency, overlay, drawImgPlot, unNormalize
-from joonmyung.meta_data import data2path
 from joonmyung.metric import targetPred, accuracy
+from joonmyung.analysis.dataset import JDataset
+from joonmyung.analysis.model import JModel, ZeroShotInference
+from joonmyung.meta_data import data2path
 from joonmyung.log import AverageMeter
-from tqdm import tqdm
-from contextlib import suppress
 import torch.nn.functional as F
+from tqdm import tqdm
 import numpy as np
 import torch
 import cv2

+from joonmyung.utils import read_classnames
+

 def anaModel(transformer_class):
     class VisionTransformer(transformer_class):
@@ -44,78 +45,46 @@ def anaModel(transformer_class):
     return VisionTransformer

 class Analysis:
-    def __init__(self, model, analysis = [0], activate = [True, False, False, False],
-
-                 , amp_autocast=suppress, device="cuda"):
-        # Section A. Model
-        self.num_classes = num_classes
-        self.key_name = key_name
-        if wrapping:
+    def __init__(self, model, analysis = [0], activate = [True, False, False, False], num_classes = 1000, device="cuda"):
+        if sum(analysis):
             model_ = anaModel(model.__class__)
             model.__class__ = model_
             model.analysis = analysis

-        self.
-        self.
-
-        # Section B. Attention
-        self.kwargs_roll = {"cls_start" : cls_start, "cls_end" : cls_end,
-                            "patch_start" : patch_start, "patch_end" : patch_end}
-
-        # Section C. Setting
+        self.num_classes = num_classes
+        self.model = model.to(device)
         hooks = [{"name_i": 'attn_drop', "name_o": 'decoder', "fn_f": self.attn_forward, "fn_b": self.attn_backward},
-                 {"name_i": 'qkv', "name_o": 'decoder', "fn_f": self.qkv_forward, "fn_b":
-                 {"name_i": 'head', "name_o": 'decoder', "fn_f": self.head_forward, "fn_b":
-                 {"name_i": 'patch_embed.norm', "name_o": 'decoder', "fn_f": self.input_forward, "fn_b":
-        self.activate = activate
-
-        self.amp_autocast = amp_autocast
-        self.device = device
+                 {"name_i": 'qkv', "name_o": 'decoder', "fn_f": self.qkv_forward, "fn_b": None},
+                 {"name_i": 'head', "name_o": 'decoder', "fn_f": self.head_forward, "fn_b": None},
+                 {"name_i": 'patch_embed.norm', "name_o": 'decoder', "fn_f": self.input_forward, "fn_b": None}]

         for name, module in self.model.named_modules():
-            for hook in hooks:
-                if hook["name_i"] in name and hook["name_o"] not in name:
-                    module.register_forward_hook(hook["fn_f"])
-                    module.register_backward_hook(hook["fn_b"])
-        self.resetInfo()
+            for idx, hook in enumerate(hooks):
+                if hook["name_i"] in name and hook["name_o"] not in name and activate[idx]:
+                    if hook["fn_f"]: module.register_forward_hook(hook["fn_f"])
+                    if hook["fn_b"]: module.register_backward_hook(hook["fn_b"])

-    def attn_forward(self, module, input, output):
-
-        if self.activate[0]: self.info["attn"]["f"].append(output.detach())
+    def attn_forward(self, module, input, output): # input/output : 1 * (8, 3, 197, 197) / (8, 3, 197, 197)
+        self.info["attn"]["f"].append(output.detach())

-    def attn_backward(self, module, grad_input, grad_output):
-
-        if self.activate[0]: self.info["attn"]["b"].append(grad_input[0].detach())
+    def attn_backward(self, module, grad_input, grad_output): # # input/output : 1 * (8, 3, 197, 192) / (8, 3, 197, 576)
+        self.info["attn"]["b"].append(grad_input[0].detach())

-    def qkv_forward(self, module, input, output):
-
-        if self.activate[1]: self.info["qkv"]["f"].append(output.detach())
+    def qkv_forward(self, module, input, output): # # input/output : 1 * (8, 197, 192) / (8, 197, 576)
+        self.info["qkv"]["f"].append(output.detach())

-    def
-
-
+    def head_forward(self, module, input, output): # input : 1 * (8(B), 192(D)), output : (8(B), 1000(C))
+        B = output.shape[0]
+        pred = targetPred(output, self.targets, topk=5)
+        self.info["head"]["TF"] += (pred[:, 0] == pred[:, 1])

-
-
-
-        B = output.shape[0]
-        pred = targetPred(output, self.targets, topk=5)
-        self.info["head"]["TF"] += (pred[:, 0] == pred[:, 1])
-
-        acc1, acc5 = accuracy(output, self.targets, topk=(1,5))
-        self.info["head"]["acc1"].update(acc1.item(), n=B)
-        self.info["head"]["acc5"].update(acc5.item(), n=B)
-
-    def head_backward(self, module, grad_input, grad_output):
-        pass
+        acc1, acc5 = accuracy(output, self.targets, topk=(1,5))
+        self.info["head"]["acc1"].update(acc1.item(), n=B)
+        self.info["head"]["acc5"].update(acc5.item(), n=B)

     def input_forward(self, module, input, output):
-
-
-        self.info["input"]["sim"] += (norm @ norm.transpose(-1, -2)).mean(dim=(-1, -2))
-
-    def input_backward(self, module, grad_input, grad_output):
-        pass
+        norm = F.normalize(output, dim=-1)
+        self.info["input"]["sim"] += (norm @ norm.transpose(-1, -2)).mean(dim=(-1, -2))

     def resetInfo(self):
         self.info = {"attn" : {"f": [], "b": []},
@@ -151,7 +120,7 @@ class Analysis:
             self.info["attn"]["b"] = []
             self.model.zero_grad()
             if index == None: index = output.max(dim=1)[1]
-            index = torch.eye(self.num_classes, device=
+            index = torch.eye(self.num_classes, device=device)[index]
             loss = (output * index).sum()
             loss.backward(retain_graph=True)
             grad = self.info["attn"]["b"]
@@ -162,31 +131,27 @@ class Analysis:


 if __name__ == '__main__':
-
-    dataset_name, device, amp_autocast, debug = "imagenet", 'cuda', torch.cuda.amp.autocast, True
+    dataset_name, device, debug = "imagenet", 'cuda', True
     data_path, num_classes, _, _ = data2path(dataset_name)
-
-    # VIEW : IMG, SALIENCY:ATTN, SALIENCY:OPENCV, SALIENCY:GRAD, ATTN. MOVEMENT
-    # ACTIVATE : ATTN, QKV, HEAD
+    activate = [True, False, False, False] # [ATTN, QKV, HEAD]
     analysis = [0] # [0] : INPUT TYPE, [0 : SAMPLE + POS, 1 : SAMPLE, 2 : POS]

     dataset = JDataset(data_path, dataset_name, device=device)
-
-    data_idxs = [[1, 0]]
+    data_idxs = [[c, i] for i in range(1000) for c in range(50)]

-    # Section B. Model
-    model_number, model_name = 0, "deit_tiny_patch16_224" # deit, vit | tiny, small, base
-    # model_number, model_name = 1, "deit_tiny_patch16_224"
-
-    # Section C. Setting
     modelMaker = JModel(num_classes, device=device)
-    model = modelMaker.getModel(
+    model = modelMaker.getModel(2, "ViT-B/16")
+
+    classnames = read_classnames("/hub_data1/joonmyung/data/imagenet/classnames.txt")
+    model = ZeroShotInference(model, classnames, prompt="a photo of a {}.", device=device)
+
     model = Analysis(model, analysis = analysis, activate = activate, device=device)
+
+    view = [False, True, False, False, True] # [IMG, SALIENCY:ATTN, SALIENCY:OPENCV, SALIENCY:GRAD, ATTN. MOVEMENT]
     for idx, data_idx in enumerate(data_idxs):
         print(f"------------------------- [{data_idx[0]}]/[{data_idx[1]}] -------------------------")

         sample, target, label_name = dataset[data_idx[0], data_idx[1]]
-        # sample, _, img, _ = dataset.getItemPath('/hub_data1/joonmyung/data/imagenet/train/n01440764/n01440764_39.JPEG')
         output = model(sample)
         if view[0]:
             drawImgPlot(unNormalize(sample, "imagenet"))
@@ -219,22 +184,6 @@ if __name__ == '__main__':
             data_vidTLDR = overlay(sample, results["vidTLDR"], dataset_name)
             drawImgPlot(data_vidTLDR, col=col)

-            print(1)
-
-            # roll = F.normalize(results["rollout"].reshape(12, 196), dim=-1)
-
-            # datas_rollout = overlay(sample, rollout, dataset_name)
-            # drawImgPlot(datas_rollout, col=col)
-
-            # datas_attn = overlay(sample, attentive, dataset_name)
-            # drawImgPlot(datas_attn, col=col)
-
-            # a = attentive[5]
-            # b = torch.stack([a.clamp(max=a.quantile(1 - v_ratio)) for v_ratio in [0.0, 0.05, 0.1, 0.15, 0.2, 0.25, 0.3, 0.35, 0.4, 0.45, 0.5, 0.55]])
-            # datas_attn = overlay(sample, b, dataset_name)
-            # drawImgPlot(datas_attn, col=col)
-            # print(1)
-
         if view[2]: # SALIENCY W/ DATA
             img = np.array(dataset[data_idx[0], data_idx[1], 2][0])

@@ -253,12 +202,8 @@ if __name__ == '__main__':
             output = model(sample)
             attn = torch.stack(model.info["attn"]["f"], dim=1).mean(dim=[2,3])[0,-2]
             topK = attn[1:].topk(k, -1, True)[1]
-            # a = torch.autograd.grad(attn.sum(), samples, retain_graph=True)[0].sum(dim=1)
             a = torch.autograd.grad(output[:,3], sample, retain_graph=True)[0].sum(dim=1)
             b = F.interpolate(a.unsqueeze(0), scale_factor=0.05, mode='nearest')[0]
-            # drawHeatmap(b)
-            print(1)
-            # to_np(torch.stack([attn[:, :, 0], attn[:, :, 1:].sum(dim=-1)], -1)[0])

         if view[4]: # ATTENTION MOVEMENT (FROM / TO)
             attn = torch.stack(model.info["attn"]["f"]).mean(dim=2).transpose(0,1) # (8 (B), 12 (L), 197(T_Q), 197(T_K))
@@ -270,5 +215,4 @@ if __name__ == '__main__':
             # how much each patch attends to other tokens
             cls2patch = attn[:, :, 1:, 0].mean(dim=2)
             patch2patch = attn[:, :, 1:, 1:].mean(dim=2).sum(dim=-1)
-            # to_np(torch.stack([cls2cls.mean(dim=0), patch2cls.mean(dim=0), cls2patch.mean(dim=0), patch2patch.mean(dim=0)]))
-            print(1)
+            # to_np(torch.stack([cls2cls.mean(dim=0), patch2cls.mean(dim=0), cls2patch.mean(dim=0), patch2patch.mean(dim=0)]))
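
The results["rollout"] and results["attentive"] keys above come from joonmyung.draw.saliency, whose source is not included in this diff. As a rough reference for what a rollout-style aggregation of the hooked attention maps typically looks like (an assumption about what "rollout" denotes here, not the package's implementation):

    import torch

    def attention_rollout(attns, head_fusion="mean", discard_ratio=0.0):
        # attns: list of per-layer attention maps, each shaped (B, heads, T, T)
        B, _, T, _ = attns[0].shape
        result = torch.eye(T, device=attns[0].device).repeat(B, 1, 1)
        for attn in attns:
            a = attn.mean(dim=1) if head_fusion == "mean" else attn.max(dim=1).values
            if discard_ratio > 0:
                flat = a.flatten(1)
                k = int(flat.shape[1] * discard_ratio)
                weakest = flat.topk(k, dim=1, largest=False).indices
                flat.scatter_(1, weakest, 0.0)      # crude variant: drop the weakest links
            a = a + torch.eye(T, device=a.device)   # account for the residual connection
            a = a / a.sum(dim=-1, keepdim=True)     # re-normalize each row
            result = torch.bmm(a, result)           # accumulate attention across layers
        return result[:, 0, 1:]                     # CLS-to-patch relevance, (B, T - 1)

    # e.g. attention_rollout(captured_attns).reshape(-1, 14, 14) for a 224x224 input with patch size 16
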
@@ -0,0 +1,218 @@
+from joonmyung.draw import saliency, overlay, drawImgPlot, unNormalize
+from joonmyung.metric import targetPred, accuracy
+from joonmyung.analysis.dataset import JDataset
+from joonmyung.analysis.model import JModel, ZeroShotInference
+from joonmyung.meta_data import data2path
+from joonmyung.log import AverageMeter
+import torch.nn.functional as F
+from tqdm import tqdm
+import numpy as np
+import torch
+import cv2
+
+from joonmyung.utils import read_classnames
+
+
+def anaModel(transformer_class):
+    class VisionTransformer(transformer_class):
+        def forward_features(self, x):
+            x = self.patch_embed(x)
+            if hasattr(self, "cls_token"):
+                cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks
+                x = torch.cat((cls_token, x), dim=1)
+
+            if self.analysis[0] == 1: # PATCH
+                x = x # (8, 197, 192)
+            elif self.analysis[0] == 2: # POS
+                x = self.pos_embed # (1, 197, 192)
+            elif self.analysis[0] == 3: # PATCH (RANDOM I) + POS
+                x = torch.rand_like(self.pos_embed, device=x.device) + self.pos_embed
+            elif self.analysis[0] == 4: # PATCH (RANDOM II) + POS
+                x = torch.rand_like(self.cls_token, device=x.device).repeat(1, x.shape[1], 1) + self.pos_embed
+            else: # PATCH + POS
+                x = x + self.pos_embed
+            x = self.pos_drop(x)
+
+            x = self.blocks(x)
+            x = self.norm(x)
+            if hasattr(self, "cls_token") and hasattr(self, "cls_token"):
+                return x[:, 0], x[:, 1]
+            elif hasattr(self, "cls_token"):
+                return self.pre_logits(x[:, 0])
+            else:
+                return self.pre_logits(x.mean(dim=1))
+
+    return VisionTransformer
+
+class Analysis:
+    def __init__(self, model, analysis = [0], activate = [True, False, False, False], num_classes = 1000, device="cuda"):
+        if sum(analysis):
+            model_ = anaModel(model.__class__)
+            model.__class__ = model_
+            model.analysis = analysis
+
+        self.num_classes = num_classes
+        self.model = model.to(device)
+        hooks = [{"name_i": 'attn_drop', "name_o": 'decoder', "fn_f": self.attn_forward, "fn_b": self.attn_backward},
+                 {"name_i": 'qkv', "name_o": 'decoder', "fn_f": self.qkv_forward, "fn_b": None},
+                 {"name_i": 'head', "name_o": 'decoder', "fn_f": self.head_forward, "fn_b": None},
+                 {"name_i": 'patch_embed.norm', "name_o": 'decoder', "fn_f": self.input_forward, "fn_b": None}]
+
+        for name, module in self.model.named_modules():
+            for idx, hook in enumerate(hooks):
+                if hook["name_i"] in name and hook["name_o"] not in name and activate[idx]:
+                    if hook["fn_f"]: module.register_forward_hook(hook["fn_f"])
+                    if hook["fn_b"]: module.register_backward_hook(hook["fn_b"])
+
+    def attn_forward(self, module, input, output): # input/output : 1 * (8, 3, 197, 197) / (8, 3, 197, 197)
+        self.info["attn"]["f"].append(output.detach())
+
+    def attn_backward(self, module, grad_input, grad_output): # # input/output : 1 * (8, 3, 197, 192) / (8, 3, 197, 576)
+        self.info["attn"]["b"].append(grad_input[0].detach())
+
+    def qkv_forward(self, module, input, output): # # input/output : 1 * (8, 197, 192) / (8, 197, 576)
+        self.info["qkv"]["f"].append(output.detach())
+
+    def head_forward(self, module, input, output): # input : 1 * (8(B), 192(D)), output : (8(B), 1000(C))
+        B = output.shape[0]
+        pred = targetPred(output, self.targets, topk=5)
+        self.info["head"]["TF"] += (pred[:, 0] == pred[:, 1])
+
+        acc1, acc5 = accuracy(output, self.targets, topk=(1,5))
+        self.info["head"]["acc1"].update(acc1.item(), n=B)
+        self.info["head"]["acc5"].update(acc5.item(), n=B)
+
+    def input_forward(self, module, input, output):
+        norm = F.normalize(output, dim=-1)
+        self.info["input"]["sim"] += (norm @ norm.transpose(-1, -2)).mean(dim=(-1, -2))
+
+    def resetInfo(self):
+        self.info = {"attn" : {"f": [], "b": []},
+                     "qkv" : {"f": [], "b": []},
+                     "head" : {"acc1" : AverageMeter(),
+                               "acc5" : AverageMeter(),
+                               "TF" : [], "pred" : []},
+                     "input": {"sim" : []}
+                     }
+
+    def __call__(self, samples, targets = None, **kwargs):
+        self.resetInfo()
+        self.model.zero_grad()
+        self.model.eval()
+
+        if type(samples) == torch.Tensor:
+            self.targets = targets
+            outputs = self.model(samples, **kwargs)
+            return outputs
+        else:
+            for sample, targets in tqdm(samples):
+                self.targets = targets
+                _ = self.model(sample)
+            return False
+
+    def anaSaliency(self, attn=True, grad=False, output=None, index=None,
+                    head_fusion="mean", discard_ratios=[0.], data_from="cls",
+                    reshape=False, activate= [True, True, False], device="cuda"):
+
+        if attn:
+            attn = self.info["attn"]["f"]
+        if grad:
+            self.info["attn"]["b"] = []
+            self.model.zero_grad()
+            if index == None: index = output.max(dim=1)[1]
+            index = torch.eye(self.num_classes, device=device)[index]
+            loss = (output * index).sum()
+            loss.backward(retain_graph=True)
+            grad = self.info["attn"]["b"]
+
+        return saliency(attn, grad, activate=activate,
+                        head_fusion=head_fusion, discard_ratios=discard_ratios, data_from=data_from,
+                        reshape=reshape, device=device)
+
+
+if __name__ == '__main__':
+    dataset_name, device, debug = "imagenet", 'cuda', True
+    data_path, num_classes, _, _ = data2path(dataset_name)
+    activate = [True, False, False, False] # [ATTN, QKV, HEAD]
+    analysis = [0] # [0] : INPUT TYPE, [0 : SAMPLE + POS, 1 : SAMPLE, 2 : POS]
+
+    dataset = JDataset(data_path, dataset_name, device=device)
+    data_idxs = [[c, i] for i in range(1000) for c in range(50)]
+
+    modelMaker = JModel(num_classes, device=device)
+    model = modelMaker.getModel(2, "ViT-B/16")
+
+    classnames = read_classnames("/hub_data1/joonmyung/data/imagenet/classnames.txt")
+    model = ZeroShotInference(model, classnames, prompt="a photo of a {}.", device=device)
+
+    model = Analysis(model, analysis = analysis, activate = activate, device=device)
+
+    view = [False, True, False, False, True] # [IMG, SALIENCY:ATTN, SALIENCY:OPENCV, SALIENCY:GRAD, ATTN. MOVEMENT]
+    for idx, data_idx in enumerate(data_idxs):
+        print(f"------------------------- [{data_idx[0]}]/[{data_idx[1]}] -------------------------")
+
+        sample, target, label_name = dataset[data_idx[0], data_idx[1]]
+        output = model(sample)
+        if view[0]:
+            drawImgPlot(unNormalize(sample, "imagenet"))
+
+        if view[1]: # SALIENCY W/ MODEL
+            col, discard_ratios, v_ratio, head_fusion, data_from = 12, [0.0], 0.0, "mean", "patch" # Attention, Gradient
+            results = model.anaSaliency(True, False, output, discard_ratios=discard_ratios,
+                                        head_fusion = head_fusion, index=target, data_from=data_from,
+                                        reshape = True, activate=[True, True, True]) # (12(L), 8(B), 14(H), 14(W))
+            data_roll = overlay(sample, results["rollout"], dataset_name)
+            drawImgPlot(data_roll, col=col)
+
+            data_attn = overlay(sample, results["attentive"], dataset_name)
+            drawImgPlot(data_attn, col=col)
+
+            data_vidTLDR = overlay(sample, results["vidTLDR"], dataset_name)
+            drawImgPlot(data_vidTLDR, col=col)
+
+            discard_ratios, v_ratio, head_fusion, data_from = [0.0], 0.1, "mean", "cls"
+            results = model.anaSaliency(True, False, output, discard_ratios=discard_ratios,
+                                        head_fusion=head_fusion, index=target, data_from=data_from,
+                                        reshape=True, activate=[True, True, True]) # (12(L), 8(B), 14(H), 14(W))
+
+            data_roll = overlay(sample, results["rollout"], dataset_name)
+            drawImgPlot(data_roll, col=col)
+
+            data_attn = overlay(sample, results["attentive"], dataset_name)
+            drawImgPlot(data_attn, col=col)
+
+            data_vidTLDR = overlay(sample, results["vidTLDR"], dataset_name)
+            drawImgPlot(data_vidTLDR, col=col)
+
+        if view[2]: # SALIENCY W/ DATA
+            img = np.array(dataset[data_idx[0], data_idx[1], 2][0])
+
+            saliency = cv2.saliency.StaticSaliencySpectralResidual_create()
+            (success, saliencyMap) = saliency.computeSaliency(img)
+            saliencyMap = (saliencyMap * 255).astype("uint8")
+
+            saliency = cv2.saliency.StaticSaliencyFineGrained_create()
+            (success, saliencyFineMap) = saliency.computeSaliency(img)
+            threshMap = cv2.threshold((saliencyFineMap * 255).astype("uint8"), 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)[1]
+            # plt.imshow(threshMap)
+            # plt.show()
+
+        if view[3]: # SALIENCY FOR INPUT
+            sample.requires_grad, model.detach, k = True, False, 3
+            output = model(sample)
+            attn = torch.stack(model.info["attn"]["f"], dim=1).mean(dim=[2,3])[0,-2]
+            topK = attn[1:].topk(k, -1, True)[1]
+            a = torch.autograd.grad(output[:,3], sample, retain_graph=True)[0].sum(dim=1)
+            b = F.interpolate(a.unsqueeze(0), scale_factor=0.05, mode='nearest')[0]
+
+        if view[4]: # ATTENTION MOVEMENT (FROM / TO)
+            attn = torch.stack(model.info["attn"]["f"]).mean(dim=2).transpose(0,1) # (8 (B), 12 (L), 197(T_Q), 197(T_K))
+
+            # how much the CLS token attends to other tokens
+            cls2cls = attn[:, :, :1, 0].mean(dim=2) # (8(B), 12(L))
+            patch2cls = attn[:, :, :1, 1:].mean(dim=2).sum(dim=-1) # (8(B), 12(L))
+
+            # how much each patch attends to other tokens
+            cls2patch = attn[:, :, 1:, 0].mean(dim=2)
+            patch2patch = attn[:, :, 1:, 1:].mean(dim=2).sum(dim=-1)
+            # to_np(torch.stack([cls2cls.mean(dim=0), patch2cls.mean(dim=0), cls2patch.mean(dim=0), patch2patch.mean(dim=0)]))