joonmyung 1.5.2__tar.gz → 1.5.5__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {joonmyung-1.5.2 → joonmyung-1.5.5}/PKG-INFO +1 -1
- {joonmyung-1.5.2 → joonmyung-1.5.5}/joonmyung/analysis/analysis.py +85 -59
- {joonmyung-1.5.2 → joonmyung-1.5.5}/joonmyung/analysis/dataset.py +21 -15
- {joonmyung-1.5.2 → joonmyung-1.5.5}/joonmyung/analysis/model.py +24 -2
- {joonmyung-1.5.2 → joonmyung-1.5.5}/joonmyung/draw.py +25 -10
- {joonmyung-1.5.2 → joonmyung-1.5.5}/joonmyung/log.py +3 -3
- joonmyung-1.5.5/joonmyung/models/__init__.py +0 -0
- joonmyung-1.5.5/joonmyung/models/tome.py +386 -0
- {joonmyung-1.5.2 → joonmyung-1.5.5}/joonmyung/script.py +2 -0
- {joonmyung-1.5.2 → joonmyung-1.5.5}/joonmyung.egg-info/PKG-INFO +1 -1
- {joonmyung-1.5.2 → joonmyung-1.5.5}/joonmyung.egg-info/SOURCES.txt +16 -1
- {joonmyung-1.5.2 → joonmyung-1.5.5}/joonmyung.egg-info/top_level.txt +1 -0
- joonmyung-1.5.5/models/SA/MHSA.py +37 -0
- joonmyung-1.5.5/models/SA/PVTSA.py +90 -0
- joonmyung-1.5.5/models/SA/TMSA.py +37 -0
- joonmyung-1.5.5/models/SA/__init__.py +0 -0
- joonmyung-1.5.5/models/__init__.py +0 -0
- joonmyung-1.5.5/models/deit.py +372 -0
- joonmyung-1.5.5/models/evit.py +154 -0
- joonmyung-1.5.5/models/modules/PE.py +139 -0
- joonmyung-1.5.5/models/modules/__init__.py +0 -0
- joonmyung-1.5.5/models/modules/blocks.py +168 -0
- joonmyung-1.5.5/models/pvt.py +307 -0
- joonmyung-1.5.5/models/pvt_v2.py +202 -0
- joonmyung-1.5.5/models/tome.py +285 -0
- {joonmyung-1.5.2 → joonmyung-1.5.5}/setup.py +3 -2
- {joonmyung-1.5.2 → joonmyung-1.5.5}/LICENSE.txt +0 -0
- {joonmyung-1.5.2 → joonmyung-1.5.5}/README.md +0 -0
- {joonmyung-1.5.2 → joonmyung-1.5.5}/joonmyung/__init__.py +0 -0
- {joonmyung-1.5.2 → joonmyung-1.5.5}/joonmyung/analysis/__init__.py +0 -0
- {joonmyung-1.5.2 → joonmyung-1.5.5}/joonmyung/analysis/hook.py +0 -0
- {joonmyung-1.5.2 → joonmyung-1.5.5}/joonmyung/analysis/metric.py +0 -0
- {joonmyung-1.5.2 → joonmyung-1.5.5}/joonmyung/analysis/utils.py +0 -0
- {joonmyung-1.5.2 → joonmyung-1.5.5}/joonmyung/app.py +0 -0
- {joonmyung-1.5.2 → joonmyung-1.5.5}/joonmyung/data.py +0 -0
- {joonmyung-1.5.2 → joonmyung-1.5.5}/joonmyung/dummy.py +0 -0
- {joonmyung-1.5.2 → joonmyung-1.5.5}/joonmyung/file.py +0 -0
- {joonmyung-1.5.2 → joonmyung-1.5.5}/joonmyung/gradcam.py +0 -0
- {joonmyung-1.5.2 → joonmyung-1.5.5}/joonmyung/meta_data/__init__.py +0 -0
- {joonmyung-1.5.2 → joonmyung-1.5.5}/joonmyung/meta_data/label.py +0 -0
- {joonmyung-1.5.2 → joonmyung-1.5.5}/joonmyung/meta_data/utils.py +0 -0
- {joonmyung-1.5.2 → joonmyung-1.5.5}/joonmyung/metric.py +0 -0
- {joonmyung-1.5.2 → joonmyung-1.5.5}/joonmyung/status.py +0 -0
- {joonmyung-1.5.2 → joonmyung-1.5.5}/joonmyung/utils.py +0 -0
- {joonmyung-1.5.2 → joonmyung-1.5.5}/joonmyung.egg-info/dependency_links.txt +0 -0
- {joonmyung-1.5.2 → joonmyung-1.5.5}/joonmyung.egg-info/not-zip-safe +0 -0
- {joonmyung-1.5.2 → joonmyung-1.5.5}/setup.cfg +0 -0

--- joonmyung-1.5.2/joonmyung/analysis/analysis.py
+++ joonmyung-1.5.5/joonmyung/analysis/analysis.py
@@ -5,7 +5,7 @@ from joonmyung.analysis.model import JModel
 from joonmyung.draw import saliency, overlay, drawImgPlot, drawHeatmap, unNormalize
 from joonmyung.meta_data import data2path
 from joonmyung.data import getTransform
-from joonmyung.metric import targetPred
+from joonmyung.metric import targetPred, accuracy
 from joonmyung.log import AverageMeter
 from joonmyung.utils import to_leaf, to_np
 from tqdm import tqdm
@@ -20,13 +20,15 @@ import cv2
 
 def anaModel(transformer_class):
     class VisionTransformer(transformer_class):
+        def has_parameter(self, parameter_name):
+            return parameter_name in self.__init__.__code__.co_varnames
+
         def forward_features(self, x):
             x = self.patch_embed(x)
-
-
+            if self.has_parameter("cls_token"):
+                cls_token = self.cls_token.expand(x.shape[0], -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
                 x = torch.cat((cls_token, x), dim=1)
-
-                x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
+
 
 
             if self.analysis[0] == 1: # PATCH
@@ -43,17 +45,19 @@ def anaModel(transformer_class):
 
             x = self.blocks(x)
             x = self.norm(x)
-            if self.
+            if self.has_parameter("cls_token") and self.has_parameter("dist_token"):
+                return x[:, 0], x[:, 1]
+            elif self.has_parameter("cls_token"):
                 return self.pre_logits(x[:, 0])
             else:
-                return x
+                return self.pre_logits(x.mean(dim=1))
+
 
     return VisionTransformer
 
 class Analysis:
-    def __init__(self, model, analysis = [0], activate = [True, False, False], detach=True, key_name=None, num_classes = 1000
+    def __init__(self, model, analysis = [0], activate = [True, False, False, False], detach=True, key_name=None, num_classes = 1000
                  , cls_start=0, cls_end=1, patch_start=1, patch_end=None
-                 , ks = 5
                  , amp_autocast=suppress, device="cuda"):
         # Section A. Model
         self.num_classes = num_classes
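The new `has_parameter` helper above is what steers `forward_features` between the cls-token, distilled, and pooled paths. A minimal standalone sketch of the same check (the toy `Backbone` class is illustrative, not part of the package): it inspects the argument names of the class's `__init__`, so it answers "was this backbone defined with a cls_token?" even when the attribute was never populated.

```python
class Backbone:
    def __init__(self, cls_token=None, dist_token=None):
        self.cls_token = cls_token

    def has_parameter(self, parameter_name):
        # co_varnames lists __init__'s argument and local-variable names
        return parameter_name in self.__init__.__code__.co_varnames

b = Backbone()
print(b.has_parameter("cls_token"))   # True
print(b.has_parameter("reg_tokens"))  # False
```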
@@ -63,18 +67,17 @@ class Analysis:
         model.__class__ = model_
         model.analysis = analysis
         self.model = model
-        self.ks = ks
         self.detach = detach
 
         # Section B. Attention
         self.kwargs_roll = {"cls_start" : cls_start, "cls_end" : cls_end,
                             "patch_start" : patch_start, "patch_end" : patch_end}
 
-
         # Section C. Setting
         hooks = [{"name_i": 'attn_drop', "name_o": 'decoder', "fn_f": self.attn_forward, "fn_b": self.attn_backward},
                  {"name_i": 'qkv', "name_o": 'decoder', "fn_f": self.qkv_forward, "fn_b": self.qkv_backward},
-                 {"name_i": 'head', "name_o": 'decoder', "fn_f": self.head_forward, "fn_b": self.head_backward}
+                 {"name_i": 'head', "name_o": 'decoder', "fn_f": self.head_forward, "fn_b": self.head_backward},
+                 {"name_i": 'patch_embed.norm', "name_o": 'decoder', "fn_f": self.input_forward, "fn_b": self.input_backward}]
         hooks = [h for h, a in zip(hooks, activate) if a]
 
 
@@ -86,20 +89,18 @@ class Analysis:
             if hook["name_i"] in name and hook["name_o"] not in name:
                 module.register_forward_hook(hook["fn_f"])
                 module.register_backward_hook(hook["fn_b"])
+        self.resetInfo()
 
     def attn_forward(self, module, input, output):
-        # input
-
-        self.info["attn"]["f"].append(output.detach() if self.detach else output)
+        # input/output : 1 * (8, 3, 197, 197) / (8, 3, 197, 197)
+        self.info["attn"]["f"] = output.detach() if self.detach else output
 
     def attn_backward(self, module, grad_input, grad_output):
-        # input
-
-        self.info["attn"]["b"].insert(0, grad_input[0].detach() if self.detach else grad_input[0])
+        # input/output : 1 * (8, 3, 197, 192) / (8, 3, 197, 576)
+        self.info["attn"]["b"] = grad_input[0].detach() if self.detach else grad_input[0]
 
     def qkv_forward(self, module, input, output):
-        # input
-        # output : (8, 197, 576)
+        # input/output : 1 * (8, 197, 192) / (8, 197, 576)
         self.info["qkv"]["f"].append(output.detach())
 
     def qkv_backward(self, module, grad_input, grad_output):
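The hook table above attaches forward/backward hooks by module-name substring ('attn_drop', 'qkv', 'head', 'patch_embed.norm'). A hedged sketch of that pattern on a timm DeiT; the module names are timm conventions and this is not the package's exact hook code:

```python
import timm
import torch

model = timm.create_model("deit_tiny_patch16_224", pretrained=False).eval()
for blk in model.blocks:
    blk.attn.fused_attn = False            # newer timm defaults to fused attention, which skips attn_drop

attn_maps = {}

def save_attn(name):
    def hook(module, inputs, output):      # output of attn_drop: the softmaxed attention map
        attn_maps[name] = output.detach()  # (B, heads, tokens, tokens)
    return hook

for name, module in model.named_modules():
    if "attn_drop" in name and "decoder" not in name:
        module.register_forward_hook(save_attn(name))

_ = model(torch.randn(1, 3, 224, 224))
print(len(attn_maps), next(iter(attn_maps.values())).shape)   # 12 torch.Size([1, 3, 197, 197])
```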
@@ -108,38 +109,50 @@ class Analysis:
 
     def head_forward(self, module, input, output):
         # input : 1 * (8(B), 192(D)), output : (8(B), 1000(C))
-
-
-        self.info["head"]["
+        B = output.shape[0]
+        pred = targetPred(output, self.targets, topk=5)
+        self.info["head"]["TF"] += (pred[:, 0] == pred[:, 1])
+
+        acc1, acc5 = accuracy(output, self.targets, topk=(1,5))
+        self.info["head"]["acc1"].update(acc1.item(), n=B)
+        self.info["head"]["acc5"].update(acc5.item(), n=B)
 
     def head_backward(self, module, grad_input, grad_output):
         pass
 
-    def
-
-
-                 "acc5" : AverageMeter(),
-                 "pred" : None
+    def input_forward(self, module, input, output):
+        norm = F.normalize(output, dim=-1)
+        self.info["input"]["sim"] += (norm @ norm.transpose(-1, -2)).mean(dim=(-1, -2))
 
-
+    def input_backward(self, module, grad_input, grad_output):
+        pass
 
-    def
-        self.
+    def resetInfo(self):
+        self.info = {"attn" : {"f": None, "b": None},
+                     "qkv"  : {"f": None, "b": None},
+                     "head" : {"acc1" : AverageMeter(),
+                               "acc5" : AverageMeter(),
+                               "TF" : [], "pred" : []},
+                     "input": {"sim" : []}
+                     }
+
+    def __call__(self, samples, targets = None, **kwargs):
         self.model.zero_grad()
         self.model.eval()
 
         if type(samples) == torch.Tensor:
+            self.targets = targets
             outputs = self.model(samples, **kwargs)
             return outputs
         else:
-            for sample,
+            for sample, targets in tqdm(samples):
+                self.targets = targets
                 _ = self.model(sample)
             return False
 
     def anaSaliency(self, attn=True, grad=False, output=None, index=None,
                     head_fusion="mean", discard_ratios=[0.], data_from="cls",
-
-                    reshape=False, device="cuda"):
+                    reshape=False, activate= [True, True, False], device="cuda"):
 
         if attn:
             attn = self.info["attn"]["f"]
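`head_forward` now keeps running top-1/top-5 accuracy per batch. A self-contained sketch of that bookkeeping with stand-in `AverageMeter` and `accuracy` helpers; the package's own versions in `joonmyung.log` and `joonmyung.metric` may differ in detail:

```python
import torch

class AverageMeter:
    def __init__(self): self.sum, self.count = 0.0, 0
    def update(self, val, n=1): self.sum += val * n; self.count += n
    @property
    def avg(self): return self.sum / max(self.count, 1)

def accuracy(output, target, topk=(1,)):
    # timm-style top-k accuracy in percent
    maxk = max(topk)
    pred = output.topk(maxk, dim=1).indices            # (B, maxk)
    correct = pred.eq(target.view(-1, 1))              # (B, maxk)
    return [correct[:, :k].any(dim=1).float().mean() * 100 for k in topk]

acc1_m, acc5_m = AverageMeter(), AverageMeter()
logits, target = torch.randn(8, 1000), torch.randint(0, 1000, (8,))
acc1, acc5 = accuracy(logits, target, topk=(1, 5))
acc1_m.update(acc1.item(), n=logits.shape[0])
acc5_m.update(acc5.item(), n=logits.shape[0])
print(acc1_m.avg, acc5_m.avg)
```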
@@ -152,31 +165,29 @@ class Analysis:
             loss.backward(retain_graph=True)
             grad = self.info["attn"]["b"]
 
-        return saliency(attn, grad
-
-
+        return saliency(attn, grad, activate=activate,
+                        head_fusion=head_fusion, discard_ratios=discard_ratios, data_from=data_from,
+                        reshape=reshape, device=device)
 
 
 if __name__ == '__main__':
     # Section A. Data
     dataset_name, device, amp_autocast, debug = "imagenet", 'cuda', torch.cuda.amp.autocast, True
     data_path, num_classes, _, _ = data2path(dataset_name)
-    view, activate = [False, True, False, False,
+    view, activate = [False, True, False, False, True], [True, False, False]
     # VIEW : IMG, SALIENCY:ATTN, SALIENCY:OPENCV, SALIENCY:GRAD, ATTN. MOVEMENT
     # ACTIVATE : ATTN, QKV, HEAD
     analysis = [0] # [0] : INPUT TYPE, [0 : SAMPLE + POS, 1 : SAMPLE, 2 : POS]
 
     dataset = JDataset(data_path, dataset_name, device=device)
-    data_idxs = [[c, i] for i in range(1000) for c in range(50)]
-
+    # data_idxs = [[c, i] for i in range(1000) for c in range(50)]
+    data_idxs = [[1, 0]]
 
     # Section B. Model
     model_number, model_name = 0, "deit_tiny_patch16_224" # deit, vit | tiny, small, base
     # model_number, model_name = 1, "deit_tiny_patch16_224"
 
     # Section C. Setting
-    ls_rollout, ls_attentive, col = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], 12
-
     modelMaker = JModel(num_classes, device=device)
     model = modelMaker.getModel(model_number, model_name)
     model = Analysis(model, analysis = analysis, activate = activate, device=device)
@@ -190,24 +201,36 @@ if __name__ == '__main__':
         drawImgPlot(unNormalize(sample, "imagenet"))
 
     if view[1]: # SALIENCY W/ MODEL
-
-
-        discard_ratios, v_ratio, head_fusion, data_from = [0.0], 0.0, "mean", "patch" # Attention, Gradient
-        rollout, attentive = model.anaSaliency(True, False, output, discard_ratios=discard_ratios,
-                                               ls_attentive = ls_attentive, ls_rollout=ls_rollout,
+        col, discard_ratios, v_ratio, head_fusion, data_from = 12, [0.0], 0.0, "mean", "patch" # Attention, Gradient
+        results = model.anaSaliency(True, False, output, discard_ratios=discard_ratios,
                                     head_fusion = head_fusion, index=target, data_from=data_from,
-                                    reshape = True) # (12(L), 8(B), 14(H), 14(W))
-
-        drawImgPlot(
+                                    reshape = True, activate=[True, True, True]) # (12(L), 8(B), 14(H), 14(W))
+        data_roll = overlay(sample, results["rollout"], dataset_name)
+        drawImgPlot(data_roll, col=col)
 
-
-
-
-
-
-
-
+        data_attn = overlay(sample, results["attentive"], dataset_name)
+        drawImgPlot(data_attn, col=col)
+
+        data_vidTLDR = overlay(sample, results["vidTLDR"], dataset_name)
+        drawImgPlot(data_vidTLDR, col=col)
+
+        discard_ratios, v_ratio, head_fusion, data_from = [0.0], 0.1, "mean", "cls"
+        results = model.anaSaliency(True, False, output, discard_ratios=discard_ratios,
+                                    head_fusion=head_fusion, index=target, data_from=data_from,
+                                    reshape=True, activate=[True, True, True]) # (12(L), 8(B), 14(H), 14(W))
+
+        data_roll = overlay(sample, results["rollout"], dataset_name)
+        drawImgPlot(data_roll, col=col)
+
+        data_attn = overlay(sample, results["attentive"], dataset_name)
+        drawImgPlot(data_attn, col=col)
 
+        data_vidTLDR = overlay(sample, results["vidTLDR"], dataset_name)
+        drawImgPlot(data_vidTLDR, col=col)
+
+        print(1)
+
+        # roll = F.normalize(results["rollout"].reshape(12, 196), dim=-1)
 
         # datas_rollout = overlay(sample, rollout, dataset_name)
         # drawImgPlot(datas_rollout, col=col)
@@ -248,10 +271,13 @@ if __name__ == '__main__':
 
     if view[4]: # ATTENTION MOVEMENT (FROM / TO)
         attn = torch.stack(model.info["attn"]["f"]).mean(dim=2).transpose(0,1) # (8 (B), 12 (L), 197(T_Q), 197(T_K))
+
+        # how much the CLS token attends to CLS vs. the patches
         cls2cls = attn[:, :, :1, 0].mean(dim=2) # (8(B), 12(L))
         patch2cls = attn[:, :, :1, 1:].mean(dim=2).sum(dim=-1) # (8(B), 12(L))
-
+
+        # how much the patch tokens attend to CLS vs. the patches
         cls2patch = attn[:, :, 1:, 0].mean(dim=2)
         patch2patch = attn[:, :, 1:, 1:].mean(dim=2).sum(dim=-1)
         # to_np(torch.stack([cls2cls.mean(dim=0), patch2cls.mean(dim=0), cls2patch.mean(dim=0), patch2patch.mean(dim=0)]))
-        print(1)
+        print(1)
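The attention-movement block splits each (row-stochastic) attention map into CLS-query and patch-query parts. Since every query row of a softmaxed map sums to 1, the two halves of each split are complementary; a quick sanity check with random attention of the same (B, L, T, T) layout:

```python
import torch

B, L, T = 2, 12, 197
attn = torch.softmax(torch.randn(B, L, T, T), dim=-1)            # rows sum to 1

cls2cls     = attn[:, :, :1, 0].mean(dim=2)                      # (B, L)
patch2cls   = attn[:, :, :1, 1:].mean(dim=2).sum(dim=-1)         # (B, L)
cls2patch   = attn[:, :, 1:, 0].mean(dim=2)
patch2patch = attn[:, :, 1:, 1:].mean(dim=2).sum(dim=-1)

print(torch.allclose(cls2cls + patch2cls, torch.ones(B, L)))      # True
print(torch.allclose(cls2patch + patch2patch, torch.ones(B, L)))  # True
```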
--- joonmyung-1.5.2/joonmyung/analysis/dataset.py
+++ joonmyung-1.5.5/joonmyung/analysis/dataset.py
@@ -1,4 +1,6 @@
+from torchvision.transforms import InterpolationMode
 from joonmyung.meta_data.label import imnet_label, cifar_label
+from torchvision.datasets.folder import default_loader
 from timm.data import create_dataset, create_loader
 from torchvision import transforms
 from joonmyung.utils import getDir
@@ -43,30 +45,32 @@ class JDataset():
         size = size if size else setting["size"]
 
         self.transform = [
-            transforms.Compose([transforms.Resize(
-            transforms.Compose([transforms.Resize(
+            transforms.Compose([transforms.Resize((256, 256), interpolation=InterpolationMode.BICUBIC), transforms.CenterCrop(size), transforms.ToTensor(), transforms.Normalize(self.distribution["mean"], self.distribution["std"])]),
+            transforms.Compose([transforms.Resize((256, 256), interpolation=InterpolationMode.BICUBIC), transforms.CenterCrop(size), transforms.ToTensor()]),
+            transforms.Compose([transforms.Resize(size, interpolation=InterpolationMode.BICUBIC), transforms.ToTensor(), transforms.Normalize(self.distribution["mean"], self.distribution["std"])]),
             transforms.Compose([transforms.ToTensor()])
         ]
 
         self.device = device
         self.data_path = data_path
         self.label_paths = sorted(getDir(os.path.join(self.data_path, self.data_type)))
-
+
         # self.img_paths = [sorted(glob.glob(os.path.join(self.data_path, self.data_type, "*", "*")))]
+        self.img_paths = [[path, idx] for idx, label_path in enumerate(self.label_paths) for path in sorted(glob.glob(os.path.join(self.data_path, self.data_type, label_path, "*")))]
 
 
     def __getitem__(self, idx):
-        if
-
+        if type(idx) == tuple:
+            idx, transform_type = idx
+        else:
             transform_type = self.transform_type
-        elif len(idx) == 3:
-            label_num, img_num, transform_type = idx
 
-        img_path = self.img_paths[
-
-
+        [img_path, targets] = [idx, 0] if type(idx) == str else self.img_paths[idx]
+
+        sample = default_loader(img_path)
+        sample = self.transform[transform_type](sample)
 
-        return
+        return sample[None].to(self.device), torch.tensor(targets).to(self.device), self.label_name[targets]
 
     def getItems(self, indexs):
         ds, ls, lns = [], [], []
@@ -77,10 +81,12 @@ class JDataset():
             lns.append(ln)
         return torch.cat(ds, dim=0), torch.stack(ls, dim=0), lns
 
-    def
-
-
-
+    def getIndex(self, c: list = [0, 1000], i: list = [0, 50]):
+        [c_s, c_e], [i_s, i_e] = c, i
+        c = torch.arange(c_s, c_e).reshape(-1, 1).repeat(1, i_e - i_s).reshape(-1)
+        i = torch.arange(i_s, i_e).reshape(1, -1).repeat(c_e - c_s, 1).reshape(-1)
+        c_i = torch.stack([c, i], dim=-1)
+        return c_i
 
     def __len__(self):
         return
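The new `getIndex` helper enumerates every (class, image) index pair in the given ranges as an (N, 2) tensor, ready to feed to `getItems`. A standalone re-implementation just to show what it returns (illustrative only):

```python
import torch

def getIndex(c=[0, 1000], i=[0, 50]):
    [c_s, c_e], [i_s, i_e] = c, i
    c = torch.arange(c_s, c_e).reshape(-1, 1).repeat(1, i_e - i_s).reshape(-1)
    i = torch.arange(i_s, i_e).reshape(1, -1).repeat(c_e - c_s, 1).reshape(-1)
    return torch.stack([c, i], dim=-1)

print(getIndex(c=[0, 2], i=[0, 3]).tolist())
# [[0, 0], [0, 1], [0, 2], [1, 0], [1, 1], [1, 2]]
```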
--- joonmyung-1.5.2/joonmyung/analysis/model.py
+++ joonmyung-1.5.5/joonmyung/analysis/model.py
@@ -1,7 +1,7 @@
 from collections import OrderedDict
-
 from joonmyung.utils import isDir
 from timm import create_model
+import models.deit
 import torch
 import os
 
@@ -11,7 +11,8 @@ class JModel():
         self.num_classes = num_classes
 
         if model_path:
-            self.model_path = os.path.join(model_path, "
+            self.model_path = os.path.join(model_path, "checkpoint.pth")
+
         if p and model_path:
             print("file list : ", sorted(os.listdir(model_path), reverse=True))
         self.device = device
@@ -27,6 +28,27 @@ class JModel():
             model = create_model(model_name, pretrained=True, num_classes=self.num_classes, in_chans=3, global_pool=None, scriptable=False)
         elif model_type == 1:
             model = torch.hub.load('facebookresearch/deit:main', model_name, pretrained=True)
+        elif model_type == 2:
+            checkpoint = torch.load(self.model_path, map_location='cpu')
+            args = checkpoint['args']
+            model = create_model(
+                args.model,
+                pretrained=args.pretrained,
+                num_classes=args.nb_classes,
+                drop_rate=args.drop,
+                drop_path_rate=args.drop_path,
+                drop_block_rate=None,
+                img_size=args.input_size,
+                token_nums=args.token_nums,
+                embed_type=args.embed_type,
+                model_type=args.model_type
+            ).to(self.device)
+            state_dict = []
+            for n, p in checkpoint['model'].items():
+                if "total_ops" not in n and "total_params" not in n:
+                    state_dict.append((n, p))
+            state_dict = dict(state_dict)
+            model.load_state_dict(state_dict)
         else:
             raise ValueError
         model.eval()
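The new `model_type == 2` branch filters `total_ops` / `total_params` entries out of the checkpoint before `load_state_dict`; such keys are typically buffers left behind by FLOPs profilers (e.g. thop) and are not real weights. A minimal sketch of the same filtering:

```python
import torch

def clean_state_dict(state_dict):
    # drop profiler buffers so a strict load_state_dict does not complain
    return {n: p for n, p in state_dict.items()
            if "total_ops" not in n and "total_params" not in n}

ckpt = {"blocks.0.attn.qkv.weight": torch.zeros(192, 64),
        "blocks.0.total_ops": torch.tensor(1.0)}
print(list(clean_state_dict(ckpt)))   # ['blocks.0.attn.qkv.weight']
```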
--- joonmyung-1.5.2/joonmyung/draw.py
+++ joonmyung-1.5.5/joonmyung/draw.py
@@ -157,8 +157,9 @@ def drawBarChart(df, x, y, splitColName, col=1, title=[], fmt=1, p=False, c=False
 
 @torch.no_grad()
 def saliency(attentions=None, gradients=None, head_fusion="mean",
-             discard_ratios = [0.0], data_from="cls",
-
+             discard_ratios = [0.0], data_from="cls", reshape=False,
+             activate = [True, True, True], device="cpu"):
+
     # attentions : L * (B, H, h, w), gradients : L * (B, H, h, w)
     if type(discard_ratios) is not list: discard_ratios = [discard_ratios]
     saliencys = 1.
@@ -180,14 +181,14 @@ def saliency(attentions=None, gradients=None, head_fusion="mean",
 
     saliencys = saliencys.to(device)
 
-
+    L, B, _, T = saliencys.shape # (L(12), B(1), T(197), T(197))
     H = W = int(T ** 0.5)
 
-
-    if
+    result = {}
+    if activate[0]:
         rollouts, I = [], torch.eye(T, device=device).unsqueeze(0).expand(B, -1, -1) # (B, 197, 197)
         for discard_ratio in discard_ratios:
-            for start in
+            for start in range(L):
                 rollout = I
                 for attn in copy.deepcopy(saliencys[start:]): # (L, B, w, h)
                     # TODO NEED TO CORRECT
@@ -207,20 +208,34 @@ def saliency(attentions=None, gradients=None, head_fusion="mean",
                 rollouts.append(rollout)
         rollouts = torch.stack(rollouts, dim=0)
         rollouts = rollouts[:, :, 1:]
+        rollouts = rollouts / rollouts.sum(dim=-1, keepdim=True)
 
         if reshape:
             rollouts = rollouts.reshape(-1, B, H, W) # L, B, H, W
+        result["rollout"] = rollouts
 
-    if
+    if activate[1]:
         # attentive = saliencys[ls_attentive, :, 0] \
         #     if data_from == "cls" else saliencys[ls_attentive, :, 1:].mean(dim=2) # (L, B, T)
-        attentive = saliencys[
-            if data_from == "cls" else saliencys
+        attentive = saliencys[:, :, 0] \
+            if data_from == "cls" else saliencys.mean(dim=2) # (L, B, T)
         attentive = attentive[:, :, 1:]
+        attentive = attentive / attentive.sum(dim=-1, keepdim=True)
+
         if reshape:
             attentive = attentive.reshape(-1, B, H, W)
+        result["attentive"] = attentive
+
+    if activate[2]:
+        entropy = (saliencys * torch.log(saliencys)).sum(dim=-1)[:, :, 1:] # (L(12), B(1), T(196))
+        entropy = entropy - entropy.amin(dim=-1, keepdim=True)
+        entropy = entropy / entropy.sum(dim=-1, keepdim=True)
+        if reshape:
+            entropy = entropy.reshape(L, B, H, W)
+        result["vidTLDR"] = entropy
+
 
-    return
+    return result
 
 
 
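`saliency` now returns a dict with up to three maps: "rollout" (attention rollout), "attentive" (per-layer attention from the CLS row or the token mean), and "vidTLDR" (an entropy-style score). For orientation, a minimal self-contained sketch of the attention-rollout idea it builds on (Abnar & Zuidema, 2020); this is illustrative, not the package's exact implementation:

```python
import torch

def rollout(attentions):                        # list of (B, H, T, T) attention maps, one per layer
    B, _, T, _ = attentions[0].shape
    result = torch.eye(T).expand(B, -1, -1)     # (B, T, T)
    for attn in attentions:
        a = attn.mean(dim=1)                    # fuse heads
        a = a + torch.eye(T)                    # add identity for the residual connection
        a = a / a.sum(dim=-1, keepdim=True)     # keep rows stochastic
        result = a @ result                     # accumulate across layers
    return result[:, 0, 1:]                     # CLS-to-patch saliency, (B, T-1)

attns = [torch.softmax(torch.randn(1, 3, 197, 197), dim=-1) for _ in range(12)]
print(rollout(attns).shape)                     # torch.Size([1, 196])
```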
--- joonmyung-1.5.2/joonmyung/log.py
+++ joonmyung-1.5.5/joonmyung/log.py
@@ -49,14 +49,14 @@ class AverageMeter:
 
 class Logger():
     loggers = {}
-    def __init__(self, use_wandb=True, wandb_entity=None, wandb_project=None, wandb_name=None
+    def __init__(self, use_wandb=True, wandb_entity=None, wandb_project=None, wandb_name=None, wandb_tags=None
                  , wandb_watch=False, main_process=True, wandb_id=None, wandb_dir='./'
                  , args=None, model=False):
         self.use_wandb = use_wandb
         self.main_process = main_process
 
         if self.use_wandb and self.main_process:
-            wandb.init(entity=wandb_entity, project=wandb_project, name=wandb_name
+            wandb.init(entity=wandb_entity, project=wandb_project, name=wandb_name, tags=wandb_tags
                        , save_code=True, resume="allow", id = wandb_id, dir=wandb_dir
                        , config=args, settings=wandb.Settings(code_dir="."))
 
@@ -120,7 +120,7 @@ class Logger():
         wandb.finish()
 
     def save(self, model, args, name):
-        if self.main_process:
+        if self.main_process and self.use_wandb:
             path = os.path.join(wandb.run.dir, f"{name}.pth")
             torch.save({"model" : model, "args" : args}, path)
             wandb.save(path, wandb.run.dir)
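The added `wandb_tags` argument is forwarded straight to `wandb.init`. A minimal usage sketch (entity, project and tag names are placeholders; `mode="offline"` just keeps the example from needing an account):

```python
import wandb

run = wandb.init(entity="my-team", project="joonmyung-analysis",
                 name="deit_tiny_baseline", tags=["deit", "analysis"],
                 save_code=True, resume="allow", mode="offline")
run.finish()
```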