joonmyung 1.4.8.tar.gz → 1.4.9.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {joonmyung-1.4.8 → joonmyung-1.4.9}/PKG-INFO +1 -5
- joonmyung-1.4.9/README.md +87 -0
- joonmyung-1.4.9/joonmyung/analysis/analysis.py +245 -0
- {joonmyung-1.4.8 → joonmyung-1.4.9}/joonmyung/analysis/dataset.py +11 -8
- joonmyung-1.4.9/joonmyung/analysis/model.py +35 -0
- joonmyung-1.4.9/joonmyung/data.py +47 -0
- {joonmyung-1.4.8 → joonmyung-1.4.9}/joonmyung/draw.py +80 -61
- {joonmyung-1.4.8 → joonmyung-1.4.9}/joonmyung/file.py +1 -1
- {joonmyung-1.4.8 → joonmyung-1.4.9}/joonmyung/log.py +39 -0
- {joonmyung-1.4.8 → joonmyung-1.4.9}/joonmyung/meta_data/utils.py +20 -21
- {joonmyung-1.4.8 → joonmyung-1.4.9}/joonmyung/metric.py +7 -6
- {joonmyung-1.4.8 → joonmyung-1.4.9}/joonmyung/script.py +46 -21
- {joonmyung-1.4.8 → joonmyung-1.4.9}/joonmyung/utils.py +3 -2
- {joonmyung-1.4.8 → joonmyung-1.4.9}/joonmyung.egg-info/PKG-INFO +1 -5
- {joonmyung-1.4.8 → joonmyung-1.4.9}/setup.py +2 -2
- joonmyung-1.4.8/README.md +0 -60
- joonmyung-1.4.8/joonmyung/analysis/analysis.py +0 -165
- joonmyung-1.4.8/joonmyung/analysis/model.py +0 -52
- joonmyung-1.4.8/joonmyung/data.py +0 -27
- {joonmyung-1.4.8 → joonmyung-1.4.9}/LICENSE.txt +0 -0
- {joonmyung-1.4.8 → joonmyung-1.4.9}/joonmyung/__init__.py +0 -0
- {joonmyung-1.4.8 → joonmyung-1.4.9}/joonmyung/analysis/__init__.py +0 -0
- {joonmyung-1.4.8 → joonmyung-1.4.9}/joonmyung/analysis/hook.py +0 -0
- {joonmyung-1.4.8 → joonmyung-1.4.9}/joonmyung/analysis/metric.py +0 -0
- {joonmyung-1.4.8 → joonmyung-1.4.9}/joonmyung/analysis/utils.py +0 -0
- {joonmyung-1.4.8 → joonmyung-1.4.9}/joonmyung/app.py +0 -0
- {joonmyung-1.4.8 → joonmyung-1.4.9}/joonmyung/dummy.py +0 -0
- {joonmyung-1.4.8 → joonmyung-1.4.9}/joonmyung/gradcam.py +0 -0
- {joonmyung-1.4.8 → joonmyung-1.4.9}/joonmyung/meta_data/__init__.py +0 -0
- {joonmyung-1.4.8 → joonmyung-1.4.9}/joonmyung/meta_data/label.py +0 -0
- {joonmyung-1.4.8 → joonmyung-1.4.9}/joonmyung/status.py +0 -0
- {joonmyung-1.4.8 → joonmyung-1.4.9}/joonmyung.egg-info/SOURCES.txt +0 -0
- {joonmyung-1.4.8 → joonmyung-1.4.9}/joonmyung.egg-info/dependency_links.txt +0 -0
- {joonmyung-1.4.8 → joonmyung-1.4.9}/joonmyung.egg-info/not-zip-safe +0 -0
- {joonmyung-1.4.8 → joonmyung-1.4.9}/joonmyung.egg-info/top_level.txt +0 -0
- {joonmyung-1.4.8 → joonmyung-1.4.9}/setup.cfg +0 -0

{joonmyung-1.4.8 → joonmyung-1.4.9}/PKG-INFO
@@ -1,13 +1,9 @@
 Metadata-Version: 2.1
 Name: joonmyung
-Version: 1.4.8
+Version: 1.4.9
 Summary: JoonMyung's Library
 Home-page: https://github.com/pizard/JoonMyung.git
 Author: JoonMyung Choi
 Author-email: pizard@korea.ac.kr
 License: MIT
-Platform: UNKNOWN
 License-File: LICENSE.txt
-
-UNKNOWN
-

joonmyung-1.4.9/README.md
@@ -0,0 +1,87 @@
+# 1. Introduction
+JoonMyung Choi's Package
+
+
+# 2. ToDo List
+### a. Library
+1. joonmyung/Script
+   - [ ] Add additional scripts and a Queue
+
+2. joonmyung/draw
+   - [ ] Fix LinePlot
+
+### b. Playground
+
+
+# 3. Previous
+## Version 1.4.1
+1. joonmyung/log
+   - [X] Fix wandb table saving error
+2. joonmyung/utils
+   - [X] Fix str2list whitespace error
+
+## Version 1.4.0
+1. joonmyung/app.py
+   - [X] Support adding scripts while an experiment is running
+2. joonmyung/log
+   - [X] Add model and code saving
+
+## Version 1.3.2
+1. joonmyung/Logger
+   - [X] Work on wandb_id
+
+## Version 1.3.2
+1. joonmyung/Script
+   - [X] Apply Multi-GPU support
+2. joonmyung/draw
+   - [X] Add rollout (Attention, Gradient)
+3. joonmyung/log
+   - [X] Check handling of types
+4. playground/profiling
+   - [X] Flesh out the speed-measurement comparison method
+
+
+## Version 1.3.1
+### a. Library
+1. joonmyung/draw
+   - [X] Add overlay feature
+
+### b. Playground
+1. playground/analysis
+   - [X] Write data-related analysis code
+   - [X] Write model-related analysis code
+
+## Version 1.3.0
+### a. Library
+1. joonmyung/draw
+   - [X] Add drawImgPlot
+
+3. joonmyung/log
+   - [X] Add Wandb Log / Table
+
+
+
+
+
+
+
+
+
+[//]: # (CUDA_VISIBLE_DEVICES=2 python playground/models/fastsam/model.py --split 0)
+# CUDA_VISIBLE_DEVICES=2 nohup python playground/models/fastsam/model.py --split 0 > 0.log 2>&1 &
+# CUDA_VISIBLE_DEVICES=2 nohup python playground/models/fastsam/model.py --split 1 > 1.log 2>&1 &
+# CUDA_VISIBLE_DEVICES=2 nohup python playground/models/fastsam/model.py --split 2 > 2.log 2>&1 &
+# CUDA_VISIBLE_DEVICES=2 nohup python playground/models/fastsam/model.py --split 3 > 3.log 2>&1 &
+# CUDA_VISIBLE_DEVICES=3 nohup python playground/models/fastsam/model.py --split 4 > 4.log 2>&1 &
+# CUDA_VISIBLE_DEVICES=3 nohup python playground/models/fastsam/model.py --split 5 > 5.log 2>&1 &
+# CUDA_VISIBLE_DEVICES=3 nohup python playground/models/fastsam/model.py --split 6 > 6.log 2>&1 &
+# CUDA_VISIBLE_DEVICES=3 nohup python playground/models/fastsam/model.py --split 7 > 7.log 2>&1 &
+
+
+# nohup python playground/saliency/opencv.py --split 0 > 0.log 2>&1 &
+# nohup python playground/saliency/opencv.py --split 1 > 1.log 2>&1 &
+# nohup python playground/saliency/opencv.py --split 2 > 2.log 2>&1 &
+# nohup python playground/saliency/opencv.py --split 3 > 3.log 2>&1 &
+# nohup python playground/saliency/opencv.py --split 4 > 4.log 2>&1 &
+# nohup python playground/saliency/opencv.py --split 5 > 5.log 2>&1 &
+# nohup python playground/saliency/opencv.py --split 6 > 6.log 2>&1 &
+# nohup python playground/saliency/opencv.py --split 7 > 7.log 2>&1 &

joonmyung-1.4.9/joonmyung/analysis/analysis.py
@@ -0,0 +1,245 @@
+import os
+os.environ["CUDA_VISIBLE_DEVICES"] = "4"
+from joonmyung.analysis.dataset import JDataset
+from joonmyung.analysis.model import JModel
+from joonmyung.draw import saliency, overlay, drawImgPlot, drawHeatmap, unNormalize
+from joonmyung.meta_data import data2path
+from joonmyung.data import getTransform
+from joonmyung.metric import targetPred
+from joonmyung.log import AverageMeter
+from joonmyung.utils import to_leaf, to_np
+from tqdm import tqdm
+from contextlib import suppress
+import matplotlib.pyplot as plt
+import torch.nn.functional as F
+import numpy as np
+import torch
+import PIL
+import cv2
+
+
+def anaModel(transformer_class):
+    class VisionTransformer(transformer_class):
+        def forward_features(self, x):
+            x = self.patch_embed(x)
+            cls_token = self.cls_token.expand(x.shape[0], -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
+            if self.dist_token is None:
+                x = torch.cat((cls_token, x), dim=1)
+            else:
+                x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
+
+
+            if self.analysis[0] == 1:    # PATCH
+                x = x                    # (8, 197, 192)
+            elif self.analysis[0] == 2:  # POS
+                x = self.pos_embed       # (1, 197, 192)
+            elif self.analysis[0] == 3:  # PATCH (RANDOM I) + POS
+                x = torch.rand_like(self.pos_embed, device=x.device) + self.pos_embed
+            elif self.analysis[0] == 4:  # PATCH (RANDOM II) + POS
+                x = torch.rand_like(self.cls_token, device=x.device).repeat(1, x.shape[1], 1) + self.pos_embed
+            else:                        # PATCH + POS
+                x = x + self.pos_embed
+            x = self.pos_drop(x)
+
+            x = self.blocks(x)
+            x = self.norm(x)
+            if self.dist_token is None:
+                return self.pre_logits(x[:, 0])
+            else:
+                return x[:, 0], x[:, 1]
+
+    return VisionTransformer
+
+class Analysis:
+    def __init__(self, model, analysis = [0], activate = [True, False, False], detach=True, key_name=None, num_classes = 1000
+                 , cls_start=0, cls_end=1, patch_start=1, patch_end=None
+                 , ks = 5
+                 , amp_autocast=suppress, device="cuda"):
+        # Section A. Model
+        self.num_classes = num_classes
+        self.key_name = key_name
+
+        model_ = anaModel(model.__class__)
+        model.__class__ = model_
+        model.analysis = analysis
+        self.model = model
+        self.ks = ks
+        self.detach = detach
+
+        # Section B. Attention
+        self.kwargs_roll = {"cls_start" : cls_start, "cls_end" : cls_end,
+                            "patch_start" : patch_start, "patch_end" : patch_end}
+
+
+        # Section C. Setting
+        hooks = [{"name_i": 'attn_drop', "name_o": 'decoder', "fn_f": self.attn_forward, "fn_b": self.attn_backward},
+                 {"name_i": 'qkv', "name_o": 'decoder', "fn_f": self.qkv_forward, "fn_b": self.qkv_backward},
+                 {"name_i": 'head', "name_o": 'decoder', "fn_f": self.head_forward, "fn_b": self.head_backward}]
+        hooks = [h for h, a in zip(hooks, activate) if a]
+
+
+        self.amp_autocast = amp_autocast
+        self.device = device
+
+        for name, module in self.model.named_modules():
+            for hook in hooks:
+                if hook["name_i"] in name and hook["name_o"] not in name:
+                    module.register_forward_hook(hook["fn_f"])
+                    module.register_backward_hook(hook["fn_b"])
+
+    def attn_forward(self, module, input, output):
+        # input  : 1 * (8, 3, 197, 197)
+        # output : (8, 3, 197, 197)
+        self.info["attn"]["f"].append(output.detach() if self.detach else output)
+
+    def attn_backward(self, module, grad_input, grad_output):
+        # input  : 1 * (8, 3, 197, 192)
+        # output : (8, 3, 197, 576)
+        self.info["attn"]["b"].insert(0, grad_input[0].detach() if self.detach else grad_input[0])
+
+    def qkv_forward(self, module, input, output):
+        # input  : 1 * (8, 197, 192)
+        # output : (8, 197, 576)
+        self.info["qkv"]["f"].append(output.detach())
+
+    def qkv_backward(self, module, grad_input, grad_output):
+        # self.info["qkv"]["b"].append(grad_input[0].detach())
+        pass
+
+    def head_forward(self, module, input, output):
+        # input : 1 * (8(B), 192(D)), output : (8(B), 1000(C))
+        # TP = targetPred(output.detach(), self.targets.detach(), topk=self.ks)
+        TP = targetPred(to_leaf(output), to_leaf(self.targets), topk=self.ks)
+        self.info["head"]["TP"] = torch.cat([self.info["head"]["TP"], TP], dim=0) if "TP" in self.info["head"].keys() else TP
+
+    def head_backward(self, module, grad_input, grad_output):
+        pass
+
+    def resetInfo(self):
+        self.info = {"attn": {"f": [], "b": []}, "qkv": {"f": [], "b": []},
+                     "head": {"acc1" : AverageMeter(),
+                              "acc5" : AverageMeter(),
+                              "pred" : None
+
+                              }}
+
+    def __call__(self, samples, index=None, **kwargs):
+        self.resetInfo()
+        self.model.zero_grad()
+        self.model.eval()
+
+        if type(samples) == torch.Tensor:
+            outputs = self.model(samples, **kwargs)
+            return outputs
+        else:
+            for sample, self.targets in tqdm(samples):
+                _ = self.model(sample)
+            return False
+
+    def anaSaliency(self, attn=True, grad=False, output=None, index=None,
+                    head_fusion="mean", discard_ratios=[0.], data_from="cls",
+                    ls_attentive=[], ls_rollout=[],
+                    reshape=False, device="cuda"):
+
+        if attn:
+            attn = self.info["attn"]["f"]
+        if grad:
+            self.info["attn"]["b"] = []
+            self.model.zero_grad()
+            if index == None: index = output.max(dim=1)[1]
+            index = torch.eye(self.num_classes, device=self.device)[index]
+            loss = (output * index).sum()
+            loss.backward(retain_graph=True)
+            grad = self.info["attn"]["b"]
+
+        return saliency(attn, grad
+                        , head_fusion=head_fusion, discard_ratios=discard_ratios, data_from=data_from
+                        , ls_rollout=ls_rollout, ls_attentive=ls_attentive, reshape=reshape, device=device)
+
+
+if __name__ == '__main__':
+    # Section A. Data
+    dataset_name, device, amp_autocast, debug = "imagenet", 'cuda', torch.cuda.amp.autocast, True
+    data_path, _, _ = data2path(dataset_name)
+    data_num, batch_size, bs = [[0, 0], [1, 0], [2, 0], [3, 0], [0, 1], [1, 1], [2, 1], [3, 1]], 16, []
+    view, activate = [False, True, False, False, False], [True, False, False]
+    # VIEW     : IMG, SALIENCY(M), SALIENCY(D), SALIENCY(S), ATTN. MOVEMENT
+    # ACTIVATE : ATTN, QKV, HEAD
+    analysis = [2]
+    # [0] : INPUT TYPE, [0 : SAMPLE + POS, 1 : SAMPLE, 2 : POS]
+    if not debug:
+        dataset = JDataset(data_path, dataset_name, device=device)
+        samples, targets, imgs, label_names = dataset.getItems(data_num)
+        loader = dataset.getAllItems(batch_size)
+        num_classes = dataset.num_classes
+    else:
+        transform = getTransform(False, True)
+        img = PIL.Image.open('/hub_data1/joonmyung/data/imagenet/train/n01440764/n01440764_10026.JPEG')
+        samples, targets, label_names = transform(img)[None].to(device), torch.tensor([0]).to(device)[None].to(device), 'tench, Tinca tinca'
+        num_classes = 1000
+
+    # Section B. Model
+    model_number, model_name = 0, "deit_tiny_patch16_224"    # deit_tiny_patch16_224, deit_small_patch16_224, deit_base_patch16_224
+    # model_number, model_name = 0, "vit_tiny_patch16_224"   # vit_tiny_patch16_224, vit_small_patch16_224, vit_base_patch16_224
+    # model_number, model_name = 1, "deit_tiny_patch16_224"
+
+    modelMaker = JModel(num_classes, device=device)
+    model = modelMaker.getModel(model_number, model_name)
+    model = Analysis(model, analysis = analysis, activate = activate, device=device)
+
+
+    samples_ = samples[bs] if bs else samples
+    targets_ = targets[bs] if bs else targets
+    output = model(samples_)
+    if view[0]:
+        drawImgPlot(unNormalize(samples_, "imagenet"))
+
+    if view[1]:  # SALIENCY W/ MODEL
+        ls_rollout, ls_attentive, col = [0,1,2,3,4,5,6,7,8,9,10,11], [0,1,2,3,4,5,6,7,8,9,10,11], 12
+        # discard_ratios, v_ratio, head_fusion, data_from = [0.0, 0.4, 0.8], 0.1, "mean", "cls"  # Attention, Gradient
+        discard_ratios, v_ratio, head_fusion, data_from = [0.0], 0.0, "mean", "patch"            # Attention, Gradient
+        rollout, attentive = model.anaSaliency(True, False, output, discard_ratios=discard_ratios,
+                                               ls_attentive = ls_attentive, ls_rollout=ls_rollout,
+                                               head_fusion = head_fusion, index=targets_, data_from=data_from,
+                                               reshape = True)  # (12(L), 8(B), 14(H), 14(W))
+        print(1)
+        # datas_rollout = overlay(samples_, rollout, dataset_name)
+        # drawImgPlot(datas_rollout, col=col)
+        # datas_attn = overlay(samples_, attentive, dataset_name)
+        # drawImgPlot(datas_attn, col=col)
+
+    if view[2]:  # SALIENCY W/ DATA
+        img = np.array(imgs[0])
+
+        saliency = cv2.saliency.StaticSaliencySpectralResidual_create()
+        (success, saliencyMap) = saliency.computeSaliency(img)
+        saliencyMap = (saliencyMap * 255).astype("uint8")
+
+        saliency = cv2.saliency.StaticSaliencyFineGrained_create()
+        (success, saliencyFineMap) = saliency.computeSaliency(img)
+        threshMap = cv2.threshold((saliencyFineMap * 255).astype("uint8"), 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)[1]
+        # plt.imshow(threshMap)
+        # plt.show()
+
+    if view[3]:  # SALIENCY FOR INPUT
+        samples_.requires_grad, model.detach, k = True, False, 3
+        output = model(samples_)
+        attn = torch.stack(model.info["attn"]["f"], dim=1).mean(dim=[2,3])[0,-2]
+        topK = attn[1:].topk(k, -1, True)[1]
+        # a = torch.autograd.grad(attn.sum(), samples, retain_graph=True)[0].sum(dim=1)
+        a = torch.autograd.grad(output[:,3], samples_, retain_graph=True)[0].sum(dim=1)
+        b = F.interpolate(a.unsqueeze(0), scale_factor=0.05, mode='nearest')[0]
+        # drawHeatmap(b)
+        print(1)
+        # to_np(torch.stack([attn[:, :, 0], attn[:, :, 1:].sum(dim=-1)], -1)[0])
+
+    if view[4]:  # ATTENTION MOVEMENT (FROM / TO)
+        attn = torch.stack(model.info["attn"]["f"]).mean(dim=2).transpose(0,1)  # (8(B), 12(L), 197(T_Q), 197(T_K))
+        cls2cls = attn[:, :, :1, 0].mean(dim=2)                                 # (8(B), 12(L))
+        patch2cls = attn[:, :, :1, 1:].mean(dim=2).sum(dim=-1)                  # (8(B), 12(L))
+        to_np(torch.stack([cls2cls.mean(dim=0), patch2cls.mean(dim=0)]))
+
+        cls2patch = attn[:, :, 1:, 0].mean(dim=2)
+        patch2patch = attn[:, :, 1:, 1:].mean(dim=2).sum(dim=-1)
+        to_np(torch.stack([cls2patch.mean(dim=0), patch2patch.mean(dim=0)]))
+        print(1)
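
For orientation, the new analysis.py collects per-layer attention by registering forward and backward hooks on every submodule whose name contains `attn_drop` (and optionally `qkv` / `head`), then feeds the captured maps to `joonmyung.draw.saliency`. Below is a minimal, hedged sketch of that hook pattern in isolation. It assumes `timm` is installed, uses `deit_tiny_patch16_224` as the `__main__` block does, and only captures maps when the model runs the unfused attention path where `attn_drop` actually sees the attention matrix; it is an illustration, not part of the package.

```python
# Minimal sketch of the hook pattern used by Analysis.__init__ (assumptions noted above).
import timm
import torch

model = timm.create_model("deit_tiny_patch16_224", pretrained=False).eval()

attn_maps = []  # one (B, heads, tokens, tokens) tensor per transformer block

def save_attn(module, inputs, output):
    # attn_drop is applied to the softmaxed attention matrix, so its output is the map
    attn_maps.append(output.detach())

for name, module in model.named_modules():
    if "attn_drop" in name and "decoder" not in name:  # same name filter as the diff
        module.register_forward_hook(save_attn)

with torch.no_grad():
    _ = model(torch.randn(1, 3, 224, 224))

print(len(attn_maps))  # expect 12 blocks for deit_tiny if the unfused attention path ran
```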

{joonmyung-1.4.8 → joonmyung-1.4.9}/joonmyung/analysis/dataset.py
@@ -1,21 +1,22 @@
 from joonmyung.meta_data.label import imnet_label, cifar_label
 from timm.data import create_dataset, create_loader
 from torchvision import transforms
-from joonmyung.utils import getDir
 import torch
 import copy
 import glob
 import PIL
 import os
 
+from joonmyung.utils import getDir
+
 
 class JDataset():
     distributions = {"imagenet": {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]},
                      "cifar": {"mean": [0.4914, 0.4822, 0.4465], "std": [0.2023, 0.1994, 0.2010]}}
-    transform_cifar
-    transform_imagenet_
+    transform_cifar = transforms.Compose([transforms.ToTensor(), transforms.Normalize(distributions["cifar"]["mean"], distributions["cifar"]["std"])])
+    # transform_imagenet_ = transforms.Compose([transforms.ToTensor(), transforms.Normalize(distributions["imagenet"]["mean"], distributions["imagenet"]["std"])])
     # transform_imagenet_vis = transforms.Compose([transforms.Resize(256, interpolation=3), transforms.CenterCrop(224)])
-    transform_imagenet_vis
+    transform_imagenet_vis = transforms.Compose([transforms.Resize((224, 224), interpolation=3)])
     transform_imagenet_norm = transforms.Compose([transforms.ToTensor(), transforms.Normalize(distributions["imagenet"]["mean"], distributions["imagenet"]["std"])])
 
     # transforms.Resize(int((256 / 224) * input_size), interpolation=InterpolationMode.BICUBIC),

@@ -40,13 +41,15 @@ class JDataset():
             result[:, c].sub_(m).div_(s)
         return result
 
-    def __init__(self, data_path="/hub_data1/joonmyung/data", dataset="imagenet", device="cuda"):
+    def __init__(self, data_path="/hub_data1/joonmyung/data", dataset="imagenet", device="cuda", train = False):
         dataset = dataset.lower()
 
         self.d = dataset.lower()
         self.num_classes = 1000 if self.d == "imagenet" else 100
-
-
+        if train:
+            [self.d_kind, self.d_type] = ["imagenet", "val"] if self.d == "imagenet" else ["cifar", "test"]
+        else:
+            [self.d_kind, self.d_type] = ["imagenet", "train"] if self.d == "imagenet" else ["cifar", "train"]
         self.device = device
 
         self.transform = self.transforms[self.d_kind]

@@ -81,7 +84,7 @@ class JDataset():
         dataset = create_dataset(
             root=self.data_path, name="IMNET" if self.d == "imagenet" else self.d.upper()
             , split='validation', is_training=False
-            ,
+            , load_bytes=False, class_map='')
 
         loader = create_loader(
             dataset,
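
A hedged usage sketch of the new `train` flag follows (construction mirrors the `__main__` block of analysis.py; the default data path is the author's and purely illustrative). Note that, as released, `train=True` maps `d_type` to the validation/test split and `train=False` to the training split.

```python
# Hypothetical usage; only attributes set in the hunk above are read back.
from joonmyung.analysis.dataset import JDataset

dataset = JDataset(data_path="/hub_data1/joonmyung/data", dataset="imagenet",
                   device="cuda", train=False)
print(dataset.num_classes, dataset.d_kind, dataset.d_type)  # 1000 imagenet train
```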

joonmyung-1.4.9/joonmyung/analysis/model.py
@@ -0,0 +1,35 @@
+from collections import OrderedDict
+
+from joonmyung.utils import isDir
+from timm import create_model
+import torch
+import os
+
+class JModel():
+    def __init__(self, num_classes = None, model_path= None, device="cuda", p=False):
+        # Pretrained_Model
+        self.num_classes = num_classes
+
+        if model_path:
+            self.model_path = os.path.join(model_path, "checkpoint_{}.pth")
+        if p and model_path:
+            print("file list : ", sorted(os.listdir(model_path), reverse=True))
+        self.device = device
+
+    def load_state_dict(self, model, state_dict):
+        state_dict = OrderedDict((k.replace("module.", ""), v) for k, v in state_dict.items())
+        model.load_state_dict(state_dict)
+
+
+    def getModel(self, model_type=0, model_name ="deit_tiny", **kwargs):
+
+        if model_type == 0:
+            model = create_model(model_name, pretrained=True, num_classes=self.num_classes, in_chans=3, global_pool=None, scriptable=False)
+        elif model_type == 1:
+            model = torch.hub.load('facebookresearch/deit:main', model_name, pretrained=True)
+        else:
+            raise ValueError
+        model.eval()
+
+        return model.to(self.device)
+
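
A hedged usage sketch of JModel, following the `__main__` block of analysis.py: `model_type=0` builds the network with `timm.create_model`, while `model_type=1` pulls a pretrained DeiT from torch.hub.

```python
# Sketch only; assumes timm is installed and network access for pretrained weights.
from joonmyung.analysis.model import JModel

builder = JModel(num_classes=1000, device="cuda")
model = builder.getModel(model_type=0, model_name="deit_tiny_patch16_224")
print(type(model).__name__)  # a timm VisionTransformer, in eval mode, on the requested device
```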

joonmyung-1.4.9/joonmyung/data.py
@@ -0,0 +1,47 @@
+import torch
+
+def rangeBlock(block, vmin=0, vmax=5):
+    loss = torch.arange(vmin, vmax, (vmax - vmin) / block, requires_grad=False).unsqueeze(dim=1)
+    return loss
+
+def columnRename(df, ns):
+    for n in ns:
+        if n[0] in df.columns:
+            df.rename(columns = {n[0]: n[1]}, inplace = True)
+    # columnRemove(df, ['c1', 'c2' ... ])
+
+
+def columnRemove(df, ns):
+    delList = []
+    for n in ns:
+        if n in df.columns:
+            delList.append(n)
+    df.drop(delList, axis=1, inplace=True)
+    # columnRename(df, [['c1_p', 'c1_a'] , ['c2_p', 'c2_a']])
+
+
+def normalization(t, type = 0):
+    if type == 0:
+        return t / t.max()
+    elif type == 1:
+        return t / t.min()
+
+from torchvision import transforms
+from torchvision.transforms import InterpolationMode
+
+def getTransform(train = False, totensor = False, resize=True):
+
+    if not resize:
+        transform = lambda x: x
+    else:
+        transform = []
+
+        transform.append(transforms.RandomResizedCrop(224, scale=(0.5, 1.0), interpolation=InterpolationMode.BICUBIC)) \
+            if train else transform.append(transforms.Resize((224, 224), interpolation=3))
+
+    if totensor:
+        transform.append(transforms.ToTensor())
+        transform.append(transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))
+        transform = transforms.Compose(transform)
+
+    return transform
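
Finally, a hedged usage sketch of the new `getTransform` helper: with `train=False`, `totensor=True`, `resize=True` it resizes to 224×224, converts to a tensor, and normalizes with ImageNet statistics, which is how the debug branch of analysis.py prepares its single image. The image path below is illustrative.

```python
# Sketch of the eval-time transform pipeline returned by getTransform.
import PIL.Image
from joonmyung.data import getTransform

transform = getTransform(train=False, totensor=True, resize=True)
img = PIL.Image.open("example.jpg").convert("RGB")   # illustrative path
x = transform(img)      # tensor of shape (3, 224, 224)
batch = x[None]         # add a batch dimension, as analysis.py does
print(batch.shape)
```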