joonmyung 1.5.2__tar.gz → 1.5.5__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. {joonmyung-1.5.2 → joonmyung-1.5.5}/PKG-INFO +1 -1
  2. {joonmyung-1.5.2 → joonmyung-1.5.5}/joonmyung/analysis/analysis.py +85 -59
  3. {joonmyung-1.5.2 → joonmyung-1.5.5}/joonmyung/analysis/dataset.py +21 -15
  4. {joonmyung-1.5.2 → joonmyung-1.5.5}/joonmyung/analysis/model.py +24 -2
  5. {joonmyung-1.5.2 → joonmyung-1.5.5}/joonmyung/draw.py +25 -10
  6. {joonmyung-1.5.2 → joonmyung-1.5.5}/joonmyung/log.py +3 -3
  7. joonmyung-1.5.5/joonmyung/models/__init__.py +0 -0
  8. joonmyung-1.5.5/joonmyung/models/tome.py +386 -0
  9. {joonmyung-1.5.2 → joonmyung-1.5.5}/joonmyung/script.py +2 -0
  10. {joonmyung-1.5.2 → joonmyung-1.5.5}/joonmyung.egg-info/PKG-INFO +1 -1
  11. {joonmyung-1.5.2 → joonmyung-1.5.5}/joonmyung.egg-info/SOURCES.txt +16 -1
  12. {joonmyung-1.5.2 → joonmyung-1.5.5}/joonmyung.egg-info/top_level.txt +1 -0
  13. joonmyung-1.5.5/models/SA/MHSA.py +37 -0
  14. joonmyung-1.5.5/models/SA/PVTSA.py +90 -0
  15. joonmyung-1.5.5/models/SA/TMSA.py +37 -0
  16. joonmyung-1.5.5/models/SA/__init__.py +0 -0
  17. joonmyung-1.5.5/models/__init__.py +0 -0
  18. joonmyung-1.5.5/models/deit.py +372 -0
  19. joonmyung-1.5.5/models/evit.py +154 -0
  20. joonmyung-1.5.5/models/modules/PE.py +139 -0
  21. joonmyung-1.5.5/models/modules/__init__.py +0 -0
  22. joonmyung-1.5.5/models/modules/blocks.py +168 -0
  23. joonmyung-1.5.5/models/pvt.py +307 -0
  24. joonmyung-1.5.5/models/pvt_v2.py +202 -0
  25. joonmyung-1.5.5/models/tome.py +285 -0
  26. {joonmyung-1.5.2 → joonmyung-1.5.5}/setup.py +3 -2
  27. {joonmyung-1.5.2 → joonmyung-1.5.5}/LICENSE.txt +0 -0
  28. {joonmyung-1.5.2 → joonmyung-1.5.5}/README.md +0 -0
  29. {joonmyung-1.5.2 → joonmyung-1.5.5}/joonmyung/__init__.py +0 -0
  30. {joonmyung-1.5.2 → joonmyung-1.5.5}/joonmyung/analysis/__init__.py +0 -0
  31. {joonmyung-1.5.2 → joonmyung-1.5.5}/joonmyung/analysis/hook.py +0 -0
  32. {joonmyung-1.5.2 → joonmyung-1.5.5}/joonmyung/analysis/metric.py +0 -0
  33. {joonmyung-1.5.2 → joonmyung-1.5.5}/joonmyung/analysis/utils.py +0 -0
  34. {joonmyung-1.5.2 → joonmyung-1.5.5}/joonmyung/app.py +0 -0
  35. {joonmyung-1.5.2 → joonmyung-1.5.5}/joonmyung/data.py +0 -0
  36. {joonmyung-1.5.2 → joonmyung-1.5.5}/joonmyung/dummy.py +0 -0
  37. {joonmyung-1.5.2 → joonmyung-1.5.5}/joonmyung/file.py +0 -0
  38. {joonmyung-1.5.2 → joonmyung-1.5.5}/joonmyung/gradcam.py +0 -0
  39. {joonmyung-1.5.2 → joonmyung-1.5.5}/joonmyung/meta_data/__init__.py +0 -0
  40. {joonmyung-1.5.2 → joonmyung-1.5.5}/joonmyung/meta_data/label.py +0 -0
  41. {joonmyung-1.5.2 → joonmyung-1.5.5}/joonmyung/meta_data/utils.py +0 -0
  42. {joonmyung-1.5.2 → joonmyung-1.5.5}/joonmyung/metric.py +0 -0
  43. {joonmyung-1.5.2 → joonmyung-1.5.5}/joonmyung/status.py +0 -0
  44. {joonmyung-1.5.2 → joonmyung-1.5.5}/joonmyung/utils.py +0 -0
  45. {joonmyung-1.5.2 → joonmyung-1.5.5}/joonmyung.egg-info/dependency_links.txt +0 -0
  46. {joonmyung-1.5.2 → joonmyung-1.5.5}/joonmyung.egg-info/not-zip-safe +0 -0
  47. {joonmyung-1.5.2 → joonmyung-1.5.5}/setup.cfg +0 -0
{joonmyung-1.5.2 → joonmyung-1.5.5}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: joonmyung
-Version: 1.5.2
+Version: 1.5.5
 Summary: JoonMyung's Library
 Home-page: https://github.com/pizard/JoonMyung.git
 Author: JoonMyung Choi
{joonmyung-1.5.2 → joonmyung-1.5.5}/joonmyung/analysis/analysis.py
@@ -5,7 +5,7 @@ from joonmyung.analysis.model import JModel
 from joonmyung.draw import saliency, overlay, drawImgPlot, drawHeatmap, unNormalize
 from joonmyung.meta_data import data2path
 from joonmyung.data import getTransform
-from joonmyung.metric import targetPred
+from joonmyung.metric import targetPred, accuracy
 from joonmyung.log import AverageMeter
 from joonmyung.utils import to_leaf, to_np
 from tqdm import tqdm
@@ -20,13 +20,15 @@ import cv2
 
 def anaModel(transformer_class):
     class VisionTransformer(transformer_class):
+        def has_parameter(self, parameter_name):
+            return parameter_name in self.__init__.__code__.co_varnames
+
         def forward_features(self, x):
             x = self.patch_embed(x)
-            cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks
-            if self.dist_token is None:
+            if self.has_parameter("cls_token"):
+                cls_token = self.cls_token.expand(x.shape[0], -1, -1) # stole cls_tokens impl from Phil Wang, thanks
                 x = torch.cat((cls_token, x), dim=1)
-            else:
-                x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
+
 
 
             if self.analysis[0] == 1: # PATCH
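For context on the new `has_parameter` helper: it inspects the names bound in the backbone's `__init__` (arguments and locals, via `co_varnames`), not the registered `nn.Parameter`s. A minimal, self-contained sketch of that behavior (the toy classes here are hypothetical, not part of the package):

    import torch
    import torch.nn as nn

    class WithCls(nn.Module):
        def __init__(self, cls_token=True):   # "cls_token" appears in co_varnames
            super().__init__()
            self.cls_token = nn.Parameter(torch.zeros(1, 1, 192))

    class WithoutCls(nn.Module):
        def __init__(self):
            super().__init__()

    def has_parameter(model, parameter_name):
        # Mirrors VisionTransformer.has_parameter: a signature check, not a parameter check.
        return parameter_name in model.__init__.__code__.co_varnames

    print(has_parameter(WithCls(), "cls_token"))     # True
    print(has_parameter(WithoutCls(), "cls_token"))  # False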
@@ -43,17 +45,19 @@ def anaModel(transformer_class):
 
             x = self.blocks(x)
             x = self.norm(x)
-            if self.dist_token is None:
+            if self.has_parameter("cls_token") and self.has_parameter("dist_token"):
+                return x[:, 0], x[:, 1]
+            elif self.has_parameter("cls_token"):
                 return self.pre_logits(x[:, 0])
             else:
-                return x[:, 0], x[:, 1]
+                return self.pre_logits(x.mean(dim=1))
+
 
     return VisionTransformer
 
class Analysis:
-    def __init__(self, model, analysis = [0], activate = [True, False, False], detach=True, key_name=None, num_classes = 1000
+    def __init__(self, model, analysis = [0], activate = [True, False, False, False], detach=True, key_name=None, num_classes = 1000
                  , cls_start=0, cls_end=1, patch_start=1, patch_end=None
-                 , ks = 5
                  , amp_autocast=suppress, device="cuda"):
         # Section A. Model
         self.num_classes = num_classes
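The revised return path now covers three token layouts. A runnable sketch of just that branching (hypothetical `pool` helper with an identity `pre_logits`, not the package's code):

    import torch

    def pool(x, has_cls, has_dist, pre_logits=lambda t: t):
        if has_cls and has_dist:
            return x[:, 0], x[:, 1]          # DeiT-style: class + distillation tokens
        elif has_cls:
            return pre_logits(x[:, 0])       # ViT-style: class token only
        return pre_logits(x.mean(dim=1))     # no class token: global average pooling

    x = torch.randn(8, 197, 192)             # (B, T, D) tokens after the final norm
    print(pool(x, True, False).shape)        # torch.Size([8, 192])
    print(pool(x, False, False).shape)       # torch.Size([8, 192])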
@@ -63,18 +67,17 @@ class Analysis:
         model.__class__ = model_
         model.analysis = analysis
         self.model = model
-        self.ks = ks
         self.detach = detach
 
         # Section B. Attention
         self.kwargs_roll = {"cls_start" : cls_start, "cls_end" : cls_end,
                             "patch_start" : patch_start, "patch_end" : patch_end}
 
-
         # Section C. Setting
         hooks = [{"name_i": 'attn_drop', "name_o": 'decoder', "fn_f": self.attn_forward, "fn_b": self.attn_backward},
                  {"name_i": 'qkv', "name_o": 'decoder', "fn_f": self.qkv_forward, "fn_b": self.qkv_backward},
-                 {"name_i": 'head', "name_o": 'decoder', "fn_f": self.head_forward, "fn_b": self.head_backward}]
+                 {"name_i": 'head', "name_o": 'decoder', "fn_f": self.head_forward, "fn_b": self.head_backward},
+                 {"name_i": 'patch_embed.norm', "name_o": 'decoder', "fn_f": self.input_forward, "fn_b": self.input_backward}]
         hooks = [h for h, a in zip(hooks, activate) if a]
 
 
@@ -86,20 +89,18 @@ class Analysis:
             if hook["name_i"] in name and hook["name_o"] not in name:
                 module.register_forward_hook(hook["fn_f"])
                 module.register_backward_hook(hook["fn_b"])
+        self.resetInfo()
 
     def attn_forward(self, module, input, output):
-        # input : 1 * (8, 3, 197, 197)
-        # output : (8, 3, 197, 197)
-        self.info["attn"]["f"].append(output.detach() if self.detach else output)
+        # input/output : 1 * (8, 3, 197, 197) / (8, 3, 197, 197)
+        self.info["attn"]["f"] = output.detach() if self.detach else output
 
     def attn_backward(self, module, grad_input, grad_output):
-        # input : 1 * (8, 3, 197, 192)
-        # output : (8, 3, 197, 576)
-        self.info["attn"]["b"].insert(0, grad_input[0].detach() if self.detach else grad_input[0])
+        # input/output : 1 * (8, 3, 197, 192) / (8, 3, 197, 576)
+        self.info["attn"]["b"] = grad_input[0].detach() if self.detach else grad_input[0]
 
     def qkv_forward(self, module, input, output):
-        # input : 1 * (8, 197, 192)
-        # output : (8, 197, 576)
+        # input/output : 1 * (8, 197, 192) / (8, 197, 576)
         self.info["qkv"]["f"].append(output.detach())
 
     def qkv_backward(self, module, grad_input, grad_output):
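The hooks above are attached by substring-matching module names from `named_modules()`, skipping anything on the decoder side. A self-contained sketch of the same pattern (toy module names, not the package's):

    import torch
    import torch.nn as nn

    model = nn.Sequential()
    model.add_module("attn_drop", nn.Dropout(0.0))
    model.add_module("decoder_attn_drop", nn.Dropout(0.0))    # filtered out by "decoder"

    captured = {}
    def attn_forward(module, inputs, output):
        captured["attn"] = output.detach()

    for name, module in model.named_modules():
        if "attn_drop" in name and "decoder" not in name:     # same filter as Analysis.__init__
            module.register_forward_hook(attn_forward)

    _ = model(torch.randn(2, 4))
    print(captured["attn"].shape)                             # torch.Size([2, 4])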
@@ -108,38 +109,50 @@ class Analysis:
 
     def head_forward(self, module, input, output):
         # input : 1 * (8(B), 192(D)), output : (8(B), 1000(C))
-        # TP = targetPred(output.detach(), self.targets.detach(), topk=self.ks)
-        TP = targetPred(to_leaf(output), to_leaf(self.targets), topk=self.ks)
-        self.info["head"]["TP"] = torch.cat([self.info["head"]["TP"], TP], dim=0) if "TP" in self.info["head"].keys() else TP
+        B = output.shape[0]
+        pred = targetPred(output, self.targets, topk=5)
+        self.info["head"]["TF"] += (pred[:, 0] == pred[:, 1])
+
+        acc1, acc5 = accuracy(output, self.targets, topk=(1,5))
+        self.info["head"]["acc1"].update(acc1.item(), n=B)
+        self.info["head"]["acc5"].update(acc5.item(), n=B)
 
     def head_backward(self, module, grad_input, grad_output):
         pass
 
-    def resetInfo(self):
-        self.info = {"attn": {"f": [], "b": []}, "qkv": {"f": [], "b": []},
-                     "head": {"acc1" : AverageMeter(),
-                              "acc5" : AverageMeter(),
-                              "pred" : None
+    def input_forward(self, module, input, output):
+        norm = F.normalize(output, dim=-1)
+        self.info["input"]["sim"] += (norm @ norm.transpose(-1, -2)).mean(dim=(-1, -2))
 
-                     }}
+    def input_backward(self, module, grad_input, grad_output):
+        pass
 
-    def __call__(self, samples, index=None, **kwargs):
-        self.resetInfo()
+    def resetInfo(self):
+        self.info = {"attn" : {"f": None, "b": None},
+                     "qkv" : {"f": None, "b": None},
+                     "head" : {"acc1" : AverageMeter(),
+                               "acc5" : AverageMeter(),
+                               "TF" : [], "pred" : []},
+                     "input": {"sim" : []}
+                     }
+
+    def __call__(self, samples, targets = None, **kwargs):
         self.model.zero_grad()
         self.model.eval()
 
         if type(samples) == torch.Tensor:
+            self.targets = targets
             outputs = self.model(samples, **kwargs)
             return outputs
         else:
-            for sample, self.targets in tqdm(samples):
+            for sample, targets in tqdm(samples):
+                self.targets = targets
                 _ = self.model(sample)
             return False
 
     def anaSaliency(self, attn=True, grad=False, output=None, index=None,
                     head_fusion="mean", discard_ratios=[0.], data_from="cls",
-                    ls_attentive=[], ls_rollout=[],
-                    reshape=False, device="cuda"):
+                    reshape=False, activate= [True, True, False], device="cuda"):
 
         if attn:
             attn = self.info["attn"]["f"]
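`head_forward` now folds per-batch top-1/top-5 into `AverageMeter`s. The `accuracy` imported from `joonmyung.metric` is not shown in this diff; a sketch of the standard timm-style top-k computation such a helper usually performs (assumed behavior, not the package's verbatim code):

    import torch

    def accuracy(output, target, topk=(1,)):
        maxk = max(topk)
        _, pred = output.topk(maxk, dim=1)           # (B, maxk) predicted class indices
        correct = pred.t().eq(target.view(1, -1))    # (maxk, B) hit mask
        return [correct[:k].float().sum() * 100. / target.size(0) for k in topk]

    output = torch.tensor([[0.10, 0.70, 0.20],
                           [0.80, 0.05, 0.15]])
    target = torch.tensor([1, 2])
    acc1, acc2 = accuracy(output, target, topk=(1, 2))
    print(acc1.item(), acc2.item())                  # 50.0 100.0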
@@ -152,31 +165,29 @@ class Analysis:
             loss.backward(retain_graph=True)
             grad = self.info["attn"]["b"]
 
-        return saliency(attn, grad
-                        , head_fusion=head_fusion, discard_ratios=discard_ratios, data_from=data_from
-                        , ls_rollout=ls_rollout, ls_attentive=ls_attentive, reshape=reshape, device=device)
+        return saliency(attn, grad, activate=activate,
+                        head_fusion=head_fusion, discard_ratios=discard_ratios, data_from=data_from,
+                        reshape=reshape, device=device)
 
 
 if __name__ == '__main__':
     # Section A. Data
     dataset_name, device, amp_autocast, debug = "imagenet", 'cuda', torch.cuda.amp.autocast, True
     data_path, num_classes, _, _ = data2path(dataset_name)
-    view, activate = [False, True, False, False, False], [True, False, False]
+    view, activate = [False, True, False, False, True], [True, False, False]
     # VIEW : IMG, SALIENCY:ATTN, SALIENCY:OPENCV, SALIENCY:GRAD, ATTN. MOVEMENT
     # ACTIVATE : ATTN, QKV, HEAD
     analysis = [0] # [0] : INPUT TYPE, [0 : SAMPLE + POS, 1 : SAMPLE, 2 : POS]
 
     dataset = JDataset(data_path, dataset_name, device=device)
-    data_idxs = [[c, i] for i in range(1000) for c in range(50)]
-    # data_idxs = [[21, 0], [22, 0], [2, 0], [0, 0], [0, 1], [1, 1], [2, 1], [3, 1]]
+    # data_idxs = [[c, i] for i in range(1000) for c in range(50)]
+    data_idxs = [[1, 0]]
 
     # Section B. Model
     model_number, model_name = 0, "deit_tiny_patch16_224" # deit, vit | tiny, small, base
     # model_number, model_name = 1, "deit_tiny_patch16_224"
 
     # Section C. Setting
-    ls_rollout, ls_attentive, col = [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11], 12
-
     modelMaker = JModel(num_classes, device=device)
     model = modelMaker.getModel(model_number, model_name)
     model = Analysis(model, analysis = analysis, activate = activate, device=device)
@@ -190,24 +201,36 @@ if __name__ == '__main__':
     drawImgPlot(unNormalize(sample, "imagenet"))
 
     if view[1]: # SALIENCY W/ MODEL
-        # ls_rollout, ls_attentive, col = [], [0,2,4,6,8,10], 6
-        # discard_ratios, v_ratio, head_fusion, data_from = [0.0, 0.4, 0.8], 0.1, "mean", "cls"
-        discard_ratios, v_ratio, head_fusion, data_from = [0.0], 0.0, "mean", "patch" # Attention, Gradient
-        rollout, attentive = model.anaSaliency(True, False, output, discard_ratios=discard_ratios,
-                                               ls_attentive = ls_attentive, ls_rollout=ls_rollout,
+        col, discard_ratios, v_ratio, head_fusion, data_from = 12, [0.0], 0.0, "mean", "patch" # Attention, Gradient
+        results = model.anaSaliency(True, False, output, discard_ratios=discard_ratios,
                                     head_fusion = head_fusion, index=target, data_from=data_from,
-                                    reshape = True) # (12(L), 8(B), 14(H), 14(W))
-        datas_attn = overlay(sample, attentive, dataset_name)
-        drawImgPlot(datas_attn, col=col)
+                                    reshape = True, activate=[True, True, True]) # (12(L), 8(B), 14(H), 14(W))
+        data_roll = overlay(sample, results["rollout"], dataset_name)
+        drawImgPlot(data_roll, col=col)
 
-        discard_ratios, v_ratio, head_fusion, data_from = [0.0], 0.0, "mean", "cls" # Attention, Gradient
-        rollout, attentive = model.anaSaliency(True, False, output, discard_ratios=discard_ratios,
-                                               ls_attentive = ls_attentive, ls_rollout=ls_rollout,
-                                               head_fusion = head_fusion, index=target, data_from=data_from,
-                                               reshape = True) # (12(L), 8(B), 14(H), 14(W))
-        datas_attn = overlay(sample, attentive, dataset_name)
-        drawImgPlot(datas_attn, col=col)
+        data_attn = overlay(sample, results["attentive"], dataset_name)
+        drawImgPlot(data_attn, col=col)
+
+        data_vidTLDR = overlay(sample, results["vidTLDR"], dataset_name)
+        drawImgPlot(data_vidTLDR, col=col)
+
+        discard_ratios, v_ratio, head_fusion, data_from = [0.0], 0.1, "mean", "cls"
+        results = model.anaSaliency(True, False, output, discard_ratios=discard_ratios,
+                                    head_fusion=head_fusion, index=target, data_from=data_from,
+                                    reshape=True, activate=[True, True, True]) # (12(L), 8(B), 14(H), 14(W))
+
+        data_roll = overlay(sample, results["rollout"], dataset_name)
+        drawImgPlot(data_roll, col=col)
+
+        data_attn = overlay(sample, results["attentive"], dataset_name)
+        drawImgPlot(data_attn, col=col)
 
+        data_vidTLDR = overlay(sample, results["vidTLDR"], dataset_name)
+        drawImgPlot(data_vidTLDR, col=col)
+
+        print(1)
+
+        # roll = F.normalize(results["rollout"].reshape(12, 196), dim=-1)
 
         # datas_rollout = overlay(sample, rollout, dataset_name)
         # drawImgPlot(datas_rollout, col=col)
@@ -248,10 +271,13 @@ if __name__ == '__main__':
 
     if view[4]: # ATTENTION MOVEMENT (FROM / TO)
         attn = torch.stack(model.info["attn"]["f"]).mean(dim=2).transpose(0,1) # (8 (B), 12 (L), 197(T_Q), 197(T_K))
+
+        # how much the CLS query attends to each source
         cls2cls = attn[:, :, :1, 0].mean(dim=2) # (8(B), 12(L))
         patch2cls = attn[:, :, :1, 1:].mean(dim=2).sum(dim=-1) # (8(B), 12(L))
-        # how much the patches receive
+
+        # how much the patch queries attend to each source
         cls2patch = attn[:, :, 1:, 0].mean(dim=2)
         patch2patch = attn[:, :, 1:, 1:].mean(dim=2).sum(dim=-1)
         # to_np(torch.stack([cls2cls.mean(dim=0), patch2cls.mean(dim=0), cls2patch.mean(dim=0), patch2patch.mean(dim=0)]))
-        print(1)
+        print(1)
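Since each attention row is a probability distribution over keys, the four quantities pair up: cls2cls + patch2cls ≈ 1 for the CLS query row, and cls2patch + patch2patch ≈ 1 for the (averaged) patch query rows. A quick sanity-check sketch with random attention of the commented shape:

    import torch

    B, L, T = 8, 12, 197
    attn = torch.softmax(torch.randn(B, L, T, T), dim=-1)        # rows sum to 1

    cls2cls     = attn[:, :, :1, 0].mean(dim=2)                  # CLS -> CLS
    patch2cls   = attn[:, :, :1, 1:].mean(dim=2).sum(dim=-1)     # patches -> CLS query
    cls2patch   = attn[:, :, 1:, 0].mean(dim=2)                  # CLS -> patch queries
    patch2patch = attn[:, :, 1:, 1:].mean(dim=2).sum(dim=-1)     # patches -> patch queries

    ones = torch.ones(B, L)
    print(torch.allclose(cls2cls + patch2cls, ones, atol=1e-5))      # True
    print(torch.allclose(cls2patch + patch2patch, ones, atol=1e-5))  # True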
{joonmyung-1.5.2 → joonmyung-1.5.5}/joonmyung/analysis/dataset.py
@@ -1,4 +1,6 @@
+from torchvision.transforms import InterpolationMode
 from joonmyung.meta_data.label import imnet_label, cifar_label
+from torchvision.datasets.folder import default_loader
 from timm.data import create_dataset, create_loader
 from torchvision import transforms
 from joonmyung.utils import getDir
@@ -43,30 +45,32 @@ class JDataset():
         size = size if size else setting["size"]
 
         self.transform = [
-            transforms.Compose([transforms.Resize(size, interpolation=3), transforms.ToTensor(), transforms.Normalize(self.distribution["mean"], self.distribution["std"])]),
-            transforms.Compose([transforms.Resize(size, interpolation=3), transforms.ToTensor()]),
+            transforms.Compose([transforms.Resize((256, 256), interpolation=InterpolationMode.BICUBIC), transforms.CenterCrop(size), transforms.ToTensor(), transforms.Normalize(self.distribution["mean"], self.distribution["std"])]),
+            transforms.Compose([transforms.Resize((256, 256), interpolation=InterpolationMode.BICUBIC), transforms.CenterCrop(size), transforms.ToTensor()]),
+            transforms.Compose([transforms.Resize(size, interpolation=InterpolationMode.BICUBIC), transforms.ToTensor(), transforms.Normalize(self.distribution["mean"], self.distribution["std"])]),
             transforms.Compose([transforms.ToTensor()])
         ]
 
         self.device = device
         self.data_path = data_path
         self.label_paths = sorted(getDir(os.path.join(self.data_path, self.data_type)))
-        self.img_paths = [sorted(glob.glob(os.path.join(self.data_path, self.data_type, label_path, "*"))) for label_path in self.label_paths]
+
         # self.img_paths = [sorted(glob.glob(os.path.join(self.data_path, self.data_type, "*", "*")))]
+        self.img_paths = [[path, idx] for idx, label_path in enumerate(self.label_paths) for path in sorted(glob.glob(os.path.join(self.data_path, self.data_type, label_path, "*")))]
 
 
     def __getitem__(self, idx):
-        if len(idx) == 2:
-            label_num, img_num = idx
+        if type(idx) == tuple:
+            idx, transform_type = idx
+        else:
             transform_type = self.transform_type
-        elif len(idx) == 3:
-            label_num, img_num, transform_type = idx
 
-        img_path = self.img_paths[label_num][img_num]
-        img = PIL.Image.open(img_path)
-        data = self.transform[transform_type](img)
+        [img_path, targets] = [idx, 0] if type(idx) == str else self.img_paths[idx]
+
+        sample = default_loader(img_path)
+        sample = self.transform[transform_type](sample)
 
-        return data.unsqueeze(0).to(self.device), torch.tensor(label_num).to(self.device), self.label_name[int(label_num)]
+        return sample[None].to(self.device), torch.tensor(targets).to(self.device), self.label_name[targets]
 
     def getItems(self, indexs):
         ds, ls, lns = [], [], []
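With `img_paths` flattened to `[path, class_idx]` pairs, `__getitem__` now accepts a flat integer index, an `(index, transform_type)` tuple, or a raw image path. Hedged usage sketch (the data path and file name are placeholders):

    dataset = JDataset(data_path, "imagenet", device="cuda")

    sample, target, label = dataset[0]                   # flat index, default transform
    sample, target, label = dataset[0, 1]                # flat index + explicit transform_type
    sample, target, label = dataset["/data/val/x.JPEG"]  # raw path; target falls back to 0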
@@ -77,10 +81,12 @@ class JDataset():
             lns.append(ln)
         return torch.cat(ds, dim=0), torch.stack(ls, dim=0), lns
 
-    def getItemPath(self, img_path):
-        img = PIL.Image.open(img_path)
-        data = self.transform(img)
-        return data.unsqueeze(0).to(self.device)
+    def getIndex(self, c: list = [0, 1000], i: list = [0, 50]):
+        [c_s, c_e], [i_s, i_e] = c, i
+        c = torch.arange(c_s, c_e).reshape(-1, 1).repeat(1, i_e - i_s).reshape(-1)
+        i = torch.arange(i_s, i_e).reshape(1, -1).repeat(c_e - c_s, 1).reshape(-1)
+        c_i = torch.stack([c, i], dim=-1)
+        return c_i
 
     def __len__(self):
         return
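`getIndex` enumerates every (class, image) pair over the given ranges, i.e. the Cartesian product of the two index ranges. A quick equivalence check (sketch):

    import torch

    c_s, c_e, i_s, i_e = 0, 3, 0, 2
    c = torch.arange(c_s, c_e).reshape(-1, 1).repeat(1, i_e - i_s).reshape(-1)  # 0 0 1 1 2 2
    i = torch.arange(i_s, i_e).reshape(1, -1).repeat(c_e - c_s, 1).reshape(-1)  # 0 1 0 1 0 1
    c_i = torch.stack([c, i], dim=-1)

    print(c_i.tolist())  # [[0, 0], [0, 1], [1, 0], [1, 1], [2, 0], [2, 1]]
    assert torch.equal(c_i, torch.cartesian_prod(torch.arange(c_s, c_e), torch.arange(i_s, i_e)))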
{joonmyung-1.5.2 → joonmyung-1.5.5}/joonmyung/analysis/model.py
@@ -1,7 +1,7 @@
 from collections import OrderedDict
-
 from joonmyung.utils import isDir
 from timm import create_model
+import models.deit
 import torch
 import os
 
@@ -11,7 +11,8 @@ class JModel():
         self.num_classes = num_classes
 
         if model_path:
-            self.model_path = os.path.join(model_path, "checkpoint_{}.pth")
+            self.model_path = os.path.join(model_path, "checkpoint.pth")
+
         if p and model_path:
             print("file list : ", sorted(os.listdir(model_path), reverse=True))
         self.device = device
@@ -27,6 +28,27 @@ class JModel():
             model = create_model(model_name, pretrained=True, num_classes=self.num_classes, in_chans=3, global_pool=None, scriptable=False)
         elif model_type == 1:
             model = torch.hub.load('facebookresearch/deit:main', model_name, pretrained=True)
+        elif model_type == 2:
+            checkpoint = torch.load(self.model_path, map_location='cpu')
+            args = checkpoint['args']
+            model = create_model(
+                args.model,
+                pretrained=args.pretrained,
+                num_classes=args.nb_classes,
+                drop_rate=args.drop,
+                drop_path_rate=args.drop_path,
+                drop_block_rate=None,
+                img_size=args.input_size,
+                token_nums=args.token_nums,
+                embed_type=args.embed_type,
+                model_type=args.model_type
+            ).to(self.device)
+            state_dict = []
+            for n, p in checkpoint['model'].items():
+                if "total_ops" not in n and "total_params" not in n:
+                    state_dict.append((n, p))
+            state_dict = dict(state_dict)
+            model.load_state_dict(state_dict)
         else:
             raise ValueError
         model.eval()
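The new `model_type == 2` branch rebuilds the model from the arguments stored in the checkpoint, then drops profiling buffers (`total_ops`/`total_params`, as left behind by FLOPs counters such as thop) before loading weights. The same filter written as a dict comprehension (equivalent sketch; the checkpoint path and a built `model` are assumed to exist):

    import torch

    checkpoint = torch.load("checkpoint.pth", map_location="cpu")
    state_dict = {n: p for n, p in checkpoint["model"].items()
                  if "total_ops" not in n and "total_params" not in n}
    model.load_state_dict(state_dict)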
{joonmyung-1.5.2 → joonmyung-1.5.5}/joonmyung/draw.py
@@ -157,8 +157,9 @@ def drawBarChart(df, x, y, splitColName, col=1, title=[], fmt=1, p=False, c=Fals
 
 @torch.no_grad()
 def saliency(attentions=None, gradients=None, head_fusion="mean",
-             discard_ratios = [0.0], data_from="cls", ls_rollout=[], ls_attentive=[],
-             reshape=False, device="cpu"):
+             discard_ratios = [0.0], data_from="cls", reshape=False,
+             activate = [True, True, True], device="cpu"):
+
     # attentions : L * (B, H, h, w), gradients : L * (B, H, h, w)
     if type(discard_ratios) is not list: discard_ratios = [discard_ratios]
     saliencys = 1.
@@ -180,14 +181,14 @@ def saliency(attentions=None, gradients=None, head_fusion="mean",
 
     saliencys = saliencys.to(device)
 
-    _, B, _, T = saliencys.shape
+    L, B, _, T = saliencys.shape # (L(12), B(1), T(197), T(197))
     H = W = int(T ** 0.5)
 
-    rollouts, attentive = None, None
-    if ls_rollout:
+    result = {}
+    if activate[0]:
         rollouts, I = [], torch.eye(T, device=device).unsqueeze(0).expand(B, -1, -1) # (B, 197, 197)
         for discard_ratio in discard_ratios:
-            for start in ls_rollout:
+            for start in range(L):
                 rollout = I
                 for attn in copy.deepcopy(saliencys[start:]): # (L, B, w, h)
                     # TODO NEED TO CORRECT
@@ -207,20 +208,34 @@ def saliency(attentions=None, gradients=None, head_fusion="mean",
                 rollouts.append(rollout)
         rollouts = torch.stack(rollouts, dim=0)
         rollouts = rollouts[:, :, 1:]
+        rollouts = rollouts / rollouts.sum(dim=-1, keepdim=True)
 
         if reshape:
             rollouts = rollouts.reshape(-1, B, H, W) # L, B, H, W
+        result["rollout"] = rollouts
 
-    if ls_attentive:
+    if activate[1]:
         # attentive = saliencys[ls_attentive, :, 0] \
         #     if data_from == "cls" else saliencys[ls_attentive, :, 1:].mean(dim=2) # (L, B, T)
-        attentive = saliencys[ls_attentive, :, 0] \
-            if data_from == "cls" else saliencys[ls_attentive].mean(dim=2) # (L, B, T)
+        attentive = saliencys[:, :, 0] \
+            if data_from == "cls" else saliencys.mean(dim=2) # (L, B, T)
         attentive = attentive[:, :, 1:]
+        attentive = attentive / attentive.sum(dim=-1, keepdim=True)
+
         if reshape:
             attentive = attentive.reshape(-1, B, H, W)
+        result["attentive"] = attentive
+
+    if activate[2]:
+        entropy = (saliencys * torch.log(saliencys)).sum(dim=-1)[:, :, 1:] # (L(12), B(1), T(196))
+        entropy = entropy - entropy.amin(dim=-1, keepdim=True)
+        entropy = entropy / entropy.sum(dim=-1, keepdim=True)
+        if reshape:
+            entropy = entropy.reshape(L, B, H, W)
+        result["vidTLDR"] = entropy
+
 
-    return rollouts, attentive
+    return result
 
 
 
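For reference, the `activate[0]` branch follows the usual attention-rollout recipe (Abnar & Zuidema): fuse heads, add the identity to model residual connections, renormalize rows, and chain matrix products across layers; the `activate[2]` branch instead turns per-row attention entropy into a normalized saliency map, exposed under the "vidTLDR" key. A compact sketch of the core rollout recurrence (simplified; discard-ratio masking omitted):

    import torch

    def attention_rollout(attns):                    # attns: (L, B, T, T), head-fused
        L, B, T, _ = attns.shape
        I = torch.eye(T, device=attns.device).expand(B, T, T)
        rollout = I
        for attn in attns:
            a = (attn + I) / 2                       # account for residual connections
            a = a / a.sum(dim=-1, keepdim=True)      # keep rows stochastic
            rollout = a @ rollout                    # accumulate layer by layer
        return rollout                               # (B, T, T)

    attns = torch.softmax(torch.randn(12, 1, 197, 197), dim=-1)
    print(attention_rollout(attns).shape)            # torch.Size([1, 197, 197])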
{joonmyung-1.5.2 → joonmyung-1.5.5}/joonmyung/log.py
@@ -49,14 +49,14 @@ class AverageMeter:
 
 class Logger():
     loggers = {}
-    def __init__(self, use_wandb=True, wandb_entity=None, wandb_project=None, wandb_name=None
+    def __init__(self, use_wandb=True, wandb_entity=None, wandb_project=None, wandb_name=None, wandb_tags=None
                  , wandb_watch=False, main_process=True, wandb_id=None, wandb_dir='./'
                  , args=None, model=False):
         self.use_wandb = use_wandb
         self.main_process = main_process
 
         if self.use_wandb and self.main_process:
-            wandb.init(entity=wandb_entity, project=wandb_project, name=wandb_name
+            wandb.init(entity=wandb_entity, project=wandb_project, name=wandb_name, tags=wandb_tags
                        , save_code=True, resume="allow", id = wandb_id, dir=wandb_dir
                        , config=args, settings=wandb.Settings(code_dir="."))
 
@@ -120,7 +120,7 @@ class Logger():
         wandb.finish()
 
     def save(self, model, args, name):
-        if self.main_process:
+        if self.main_process and self.use_wandb:
             path = os.path.join(wandb.run.dir, f"{name}.pth")
             torch.save({"model" : model, "args" : args}, path)
             wandb.save(path, wandb.run.dir)
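Hedged usage of the extended constructor (entity, project, and tag strings are placeholders; `model` and `args` are assumed to be defined):

    logger = Logger(use_wandb=True, wandb_entity="my-entity", wandb_project="my-project",
                    wandb_name="deit_tiny_run", wandb_tags=["deit", "saliency"])
    logger.save(model, args, "best")   # now a no-op unless wandb is active on the main process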