joonmyung 1.4.8.tar.gz → 1.4.9.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (36)
  1. {joonmyung-1.4.8 → joonmyung-1.4.9}/PKG-INFO +1 -5
  2. joonmyung-1.4.9/README.md +87 -0
  3. joonmyung-1.4.9/joonmyung/analysis/analysis.py +245 -0
  4. {joonmyung-1.4.8 → joonmyung-1.4.9}/joonmyung/analysis/dataset.py +11 -8
  5. joonmyung-1.4.9/joonmyung/analysis/model.py +35 -0
  6. joonmyung-1.4.9/joonmyung/data.py +47 -0
  7. {joonmyung-1.4.8 → joonmyung-1.4.9}/joonmyung/draw.py +80 -61
  8. {joonmyung-1.4.8 → joonmyung-1.4.9}/joonmyung/file.py +1 -1
  9. {joonmyung-1.4.8 → joonmyung-1.4.9}/joonmyung/log.py +39 -0
  10. {joonmyung-1.4.8 → joonmyung-1.4.9}/joonmyung/meta_data/utils.py +20 -21
  11. {joonmyung-1.4.8 → joonmyung-1.4.9}/joonmyung/metric.py +7 -6
  12. {joonmyung-1.4.8 → joonmyung-1.4.9}/joonmyung/script.py +46 -21
  13. {joonmyung-1.4.8 → joonmyung-1.4.9}/joonmyung/utils.py +3 -2
  14. {joonmyung-1.4.8 → joonmyung-1.4.9}/joonmyung.egg-info/PKG-INFO +1 -5
  15. {joonmyung-1.4.8 → joonmyung-1.4.9}/setup.py +2 -2
  16. joonmyung-1.4.8/README.md +0 -60
  17. joonmyung-1.4.8/joonmyung/analysis/analysis.py +0 -165
  18. joonmyung-1.4.8/joonmyung/analysis/model.py +0 -52
  19. joonmyung-1.4.8/joonmyung/data.py +0 -27
  20. {joonmyung-1.4.8 → joonmyung-1.4.9}/LICENSE.txt +0 -0
  21. {joonmyung-1.4.8 → joonmyung-1.4.9}/joonmyung/__init__.py +0 -0
  22. {joonmyung-1.4.8 → joonmyung-1.4.9}/joonmyung/analysis/__init__.py +0 -0
  23. {joonmyung-1.4.8 → joonmyung-1.4.9}/joonmyung/analysis/hook.py +0 -0
  24. {joonmyung-1.4.8 → joonmyung-1.4.9}/joonmyung/analysis/metric.py +0 -0
  25. {joonmyung-1.4.8 → joonmyung-1.4.9}/joonmyung/analysis/utils.py +0 -0
  26. {joonmyung-1.4.8 → joonmyung-1.4.9}/joonmyung/app.py +0 -0
  27. {joonmyung-1.4.8 → joonmyung-1.4.9}/joonmyung/dummy.py +0 -0
  28. {joonmyung-1.4.8 → joonmyung-1.4.9}/joonmyung/gradcam.py +0 -0
  29. {joonmyung-1.4.8 → joonmyung-1.4.9}/joonmyung/meta_data/__init__.py +0 -0
  30. {joonmyung-1.4.8 → joonmyung-1.4.9}/joonmyung/meta_data/label.py +0 -0
  31. {joonmyung-1.4.8 → joonmyung-1.4.9}/joonmyung/status.py +0 -0
  32. {joonmyung-1.4.8 → joonmyung-1.4.9}/joonmyung.egg-info/SOURCES.txt +0 -0
  33. {joonmyung-1.4.8 → joonmyung-1.4.9}/joonmyung.egg-info/dependency_links.txt +0 -0
  34. {joonmyung-1.4.8 → joonmyung-1.4.9}/joonmyung.egg-info/not-zip-safe +0 -0
  35. {joonmyung-1.4.8 → joonmyung-1.4.9}/joonmyung.egg-info/top_level.txt +0 -0
  36. {joonmyung-1.4.8 → joonmyung-1.4.9}/setup.cfg +0 -0
{joonmyung-1.4.8 → joonmyung-1.4.9}/PKG-INFO
@@ -1,13 +1,9 @@
  Metadata-Version: 2.1
  Name: joonmyung
- Version: 1.4.8
+ Version: 1.4.9
  Summary: JoonMyung's Library
  Home-page: https://github.com/pizard/JoonMyung.git
  Author: JoonMyung Choi
  Author-email: pizard@korea.ac.kr
  License: MIT
- Platform: UNKNOWN
  License-File: LICENSE.txt
-
- UNKNOWN
-
joonmyung-1.4.9/README.md
@@ -0,0 +1,87 @@
+ # 1. Introduction
+ JoonMyung Choi's Package
+
+
+ # 2. ToDo List
+ ### a. Library
+ 1. joonmyung/Script
+    - [ ] Add more scripts and a queue
+
+ 2. joonmyung/draw
+    - [ ] Fix LinePlot
+
+ ### b. Playground
+
+
+ # 3. Previous
+ ## Version 1.4.1
+ 1. joonmyung/log
+    - [X] Fix wandb table saving bug
+ 2. joonmyung/utils
+    - [X] Fix str2list whitespace bug
+
+ ## Version 1.4.0
+ 1. joonmyung/app.py
+    - [X] Add the ability to add scripts while an experiment is running
+ 2. joonmyung/log
+    - [X] Add model and code saving
+
+ ## Version 1.3.2
+ 1. joonmyung/Logger
+    - [X] wandb_id work
+
+ ## Version 1.3.2
+ 1. joonmyung/Script
+    - [X] Apply multi-GPU
+ 2. joonmyung/draw
+    - [X] Add rollout (Attention, Gradient)
+ 3. joonmyung/log
+    - [X] Check types
+ 4. playground/profiling
+    - [X] Flesh out the speed-measurement comparison methods
+
+
+ ## Version 1.3.1
+ ### a. Library
+ 1. joonmyung/draw
+    - [X] Add the overlay feature
+
+ ### b. Playground
+ 1. playground/analysis
+    - [X] Write data-related analysis code
+    - [X] Write model-related analysis code
+
+ ## Version 1.3.0
+ ### a. Library
+ 1. joonmyung/draw
+    - [X] Add drawImgPlot
+
+ 3. joonmyung/log
+    - [X] Add Wandb Log / Table
+
+
+
+
+
+
+
+
+
+ [//]: # (CUDA_VISIBLE_DEVICES=2 python playground/models/fastsam/model.py --split 0)
+ # CUDA_VISIBLE_DEVICES=2 nohup python playground/models/fastsam/model.py --split 0 > 0.log 2>&1 &
+ # CUDA_VISIBLE_DEVICES=2 nohup python playground/models/fastsam/model.py --split 1 > 1.log 2>&1 &
+ # CUDA_VISIBLE_DEVICES=2 nohup python playground/models/fastsam/model.py --split 2 > 2.log 2>&1 &
+ # CUDA_VISIBLE_DEVICES=2 nohup python playground/models/fastsam/model.py --split 3 > 3.log 2>&1 &
+ # CUDA_VISIBLE_DEVICES=3 nohup python playground/models/fastsam/model.py --split 4 > 4.log 2>&1 &
+ # CUDA_VISIBLE_DEVICES=3 nohup python playground/models/fastsam/model.py --split 5 > 5.log 2>&1 &
+ # CUDA_VISIBLE_DEVICES=3 nohup python playground/models/fastsam/model.py --split 6 > 6.log 2>&1 &
+ # CUDA_VISIBLE_DEVICES=3 nohup python playground/models/fastsam/model.py --split 7 > 7.log 2>&1 &
+
+
+ # nohup python playground/saliency/opencv.py --split 0 > 0.log 2>&1 &
+ # nohup python playground/saliency/opencv.py --split 1 > 1.log 2>&1 &
+ # nohup python playground/saliency/opencv.py --split 2 > 2.log 2>&1 &
+ # nohup python playground/saliency/opencv.py --split 3 > 3.log 2>&1 &
+ # nohup python playground/saliency/opencv.py --split 4 > 4.log 2>&1 &
+ # nohup python playground/saliency/opencv.py --split 5 > 5.log 2>&1 &
+ # nohup python playground/saliency/opencv.py --split 6 > 6.log 2>&1 &
+ # nohup python playground/saliency/opencv.py --split 7 > 7.log 2>&1 &
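Note: the commented-out commands above fan one script out over eight `--split` values across two GPUs. A minimal Python sketch of the same fan-out for reference (the script path and GPU assignment are taken from the comments above; everything else is illustrative):

```python
import os
import subprocess

# Launch playground/models/fastsam/model.py once per split, mirroring the nohup commands above.
for split in range(8):
    gpu = 2 if split < 4 else 3                                  # splits 0-3 on GPU 2, 4-7 on GPU 3
    env = {**os.environ, "CUDA_VISIBLE_DEVICES": str(gpu)}
    with open(f"{split}.log", "w") as log:
        subprocess.Popen(
            ["python", "playground/models/fastsam/model.py", "--split", str(split)],
            env=env, stdout=log, stderr=subprocess.STDOUT,
        )
```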
joonmyung-1.4.9/joonmyung/analysis/analysis.py
@@ -0,0 +1,245 @@
+ import os
+ os.environ["CUDA_VISIBLE_DEVICES"] = "4"
+ from joonmyung.analysis.dataset import JDataset
+ from joonmyung.analysis.model import JModel
+ from joonmyung.draw import saliency, overlay, drawImgPlot, drawHeatmap, unNormalize
+ from joonmyung.meta_data import data2path
+ from joonmyung.data import getTransform
+ from joonmyung.metric import targetPred
+ from joonmyung.log import AverageMeter
+ from joonmyung.utils import to_leaf, to_np
+ from tqdm import tqdm
+ from contextlib import suppress
+ import matplotlib.pyplot as plt
+ import torch.nn.functional as F
+ import numpy as np
+ import torch
+ import PIL
+ import cv2
+
+
+ def anaModel(transformer_class):
+     class VisionTransformer(transformer_class):
+         def forward_features(self, x):
+             x = self.patch_embed(x)
+             cls_token = self.cls_token.expand(x.shape[0], -1, -1)  # stole cls_tokens impl from Phil Wang, thanks
+             if self.dist_token is None:
+                 x = torch.cat((cls_token, x), dim=1)
+             else:
+                 x = torch.cat((cls_token, self.dist_token.expand(x.shape[0], -1, -1), x), dim=1)
+
+
+             if self.analysis[0] == 1:    # PATCH
+                 x = x                    # (8, 197, 192)
+             elif self.analysis[0] == 2:  # POS
+                 x = self.pos_embed       # (1, 197, 192)
+             elif self.analysis[0] == 3:  # PATCH (RANDOM I) + POS
+                 x = torch.rand_like(self.pos_embed, device=x.device) + self.pos_embed
+             elif self.analysis[0] == 4:  # PATCH (RANDOM II) + POS
+                 x = torch.rand_like(self.cls_token, device=x.device).repeat(1, x.shape[1], 1) + self.pos_embed
+             else:                        # PATCH + POS
+                 x = x + self.pos_embed
+             x = self.pos_drop(x)
+
+             x = self.blocks(x)
+             x = self.norm(x)
+             if self.dist_token is None:
+                 return self.pre_logits(x[:, 0])
+             else:
+                 return x[:, 0], x[:, 1]
+
+     return VisionTransformer
+
+ class Analysis:
+     def __init__(self, model, analysis = [0], activate = [True, False, False], detach=True, key_name=None, num_classes = 1000
+                  , cls_start=0, cls_end=1, patch_start=1, patch_end=None
+                  , ks = 5
+                  , amp_autocast=suppress, device="cuda"):
+         # Section A. Model
+         self.num_classes = num_classes
+         self.key_name = key_name
+
+         model_ = anaModel(model.__class__)
+         model.__class__ = model_
+         model.analysis = analysis
+         self.model = model
+         self.ks = ks
+         self.detach = detach
+
+         # Section B. Attention
+         self.kwargs_roll = {"cls_start" : cls_start, "cls_end" : cls_end,
+                             "patch_start" : patch_start, "patch_end" : patch_end}
+
+
+         # Section C. Setting
+         hooks = [{"name_i": 'attn_drop', "name_o": 'decoder', "fn_f": self.attn_forward, "fn_b": self.attn_backward},
+                  {"name_i": 'qkv', "name_o": 'decoder', "fn_f": self.qkv_forward, "fn_b": self.qkv_backward},
+                  {"name_i": 'head', "name_o": 'decoder', "fn_f": self.head_forward, "fn_b": self.head_backward}]
+         hooks = [h for h, a in zip(hooks, activate) if a]
+
+
+         self.amp_autocast = amp_autocast
+         self.device = device
+
+         for name, module in self.model.named_modules():
+             for hook in hooks:
+                 if hook["name_i"] in name and hook["name_o"] not in name:
+                     module.register_forward_hook(hook["fn_f"])
+                     module.register_backward_hook(hook["fn_b"])
+
+     def attn_forward(self, module, input, output):
+         # input  : 1 * (8, 3, 197, 197)
+         # output : (8, 3, 197, 197)
+         self.info["attn"]["f"].append(output.detach() if self.detach else output)
+
+     def attn_backward(self, module, grad_input, grad_output):
+         # input  : 1 * (8, 3, 197, 192)
+         # output : (8, 3, 197, 576)
+         self.info["attn"]["b"].insert(0, grad_input[0].detach() if self.detach else grad_input[0])
+
+     def qkv_forward(self, module, input, output):
+         # input  : 1 * (8, 197, 192)
+         # output : (8, 197, 576)
+         self.info["qkv"]["f"].append(output.detach())
+
+     def qkv_backward(self, module, grad_input, grad_output):
+         # self.info["qkv"]["b"].append(grad_input[0].detach())
+         pass
+
+     def head_forward(self, module, input, output):
+         # input : 1 * (8(B), 192(D)), output : (8(B), 1000(C))
+         # TP = targetPred(output.detach(), self.targets.detach(), topk=self.ks)
+         TP = targetPred(to_leaf(output), to_leaf(self.targets), topk=self.ks)
+         self.info["head"]["TP"] = torch.cat([self.info["head"]["TP"], TP], dim=0) if "TP" in self.info["head"].keys() else TP
+
+     def head_backward(self, module, grad_input, grad_output):
+         pass
+
+     def resetInfo(self):
+         self.info = {"attn": {"f": [], "b": []}, "qkv": {"f": [], "b": []},
+                      "head": {"acc1" : AverageMeter(),
+                               "acc5" : AverageMeter(),
+                               "pred" : None
+
+                               }}
+
+     def __call__(self, samples, index=None, **kwargs):
+         self.resetInfo()
+         self.model.zero_grad()
+         self.model.eval()
+
+         if type(samples) == torch.Tensor:
+             outputs = self.model(samples, **kwargs)
+             return outputs
+         else:
+             for sample, self.targets in tqdm(samples):
+                 _ = self.model(sample)
+             return False
+
+     def anaSaliency(self, attn=True, grad=False, output=None, index=None,
+                     head_fusion="mean", discard_ratios=[0.], data_from="cls",
+                     ls_attentive=[], ls_rollout=[],
+                     reshape=False, device="cuda"):
+
+         if attn:
+             attn = self.info["attn"]["f"]
+         if grad:
+             self.info["attn"]["b"] = []
+             self.model.zero_grad()
+             if index == None: index = output.max(dim=1)[1]
+             index = torch.eye(self.num_classes, device=self.device)[index]
+             loss = (output * index).sum()
+             loss.backward(retain_graph=True)
+             grad = self.info["attn"]["b"]
+
+         return saliency(attn, grad
+                         , head_fusion=head_fusion, discard_ratios=discard_ratios, data_from=data_from
+                         , ls_rollout=ls_rollout, ls_attentive=ls_attentive, reshape=reshape, device=device)
+
+
+ if __name__ == '__main__':
+     # Section A. Data
+     dataset_name, device, amp_autocast, debug = "imagenet", 'cuda', torch.cuda.amp.autocast, True
+     data_path, _, _ = data2path(dataset_name)
+     data_num, batch_size, bs = [[0, 0], [1, 0], [2, 0], [3, 0], [0, 1], [1, 1], [2, 1], [3, 1]], 16, []
+     view, activate = [False, True, False, False, False], [True, False, False] #
+     # VIEW     : IMG, SALIENCY(M), SALIENCY(D), SALIENCY(S), ATTN. MOVEMENT
+     # ACTIVATE : ATTN, QKV, HEAD
+     analysis = [2]
+     # [0] : INPUT TYPE, [0 : SAMPLE + POS, 1 : SAMPLE, 2 : POS]
+     if not debug:
+         dataset = JDataset(data_path, dataset_name, device=device)
+         samples, targets, imgs, label_names = dataset.getItems(data_num)
+         loader = dataset.getAllItems(batch_size)
+         num_classes = dataset.num_classes
+     else:
+         transform = getTransform(False, True)
+         img = PIL.Image.open('/hub_data1/joonmyung/data/imagenet/train/n01440764/n01440764_10026.JPEG')
+         samples, targets, label_names = transform(img)[None].to(device), torch.tensor([0]).to(device)[None].to(device), 'tench, Tinca tinca'
+         num_classes = 1000
+
+     # Section B. Model
+     model_number, model_name = 0, "deit_tiny_patch16_224"   # deit_tiny_patch16_224, deit_small_patch16_224, deit_base_patch16_224
+     # model_number, model_name = 0, "vit_tiny_patch16_224"  # vit_tiny_patch16_224, vit_small_patch16_224, vit_base_patch16_224
+     # model_number, model_name = 1, "deit_tiny_patch16_224"
+
+     modelMaker = JModel(num_classes, device=device)
+     model = modelMaker.getModel(model_number, model_name)
+     model = Analysis(model, analysis = analysis, activate = activate, device=device)
+
+
+     samples_ = samples[bs] if bs else samples
+     targets_ = targets[bs] if bs else targets
+     output = model(samples_)
+     if view[0]:
+         drawImgPlot(unNormalize(samples_, "imagenet"))
+
+     if view[1]:  # SALIENCY W/ MODEL
+         ls_rollout, ls_attentive, col = [0,1,2,3,4,5,6,7,8,9,10,11], [0,1,2,3,4,5,6,7,8,9,10,11], 12
+         # discard_ratios, v_ratio, head_fusion, data_from = [0.0, 0.4, 0.8], 0.1, "mean", "cls"  # Attention, Gradient
+         discard_ratios, v_ratio, head_fusion, data_from = [0.0], 0.0, "mean", "patch"  # Attention, Gradient
+         rollout, attentive = model.anaSaliency(True, False, output, discard_ratios=discard_ratios,
+                                                ls_attentive = ls_attentive, ls_rollout=ls_rollout,
+                                                head_fusion = head_fusion, index=targets_, data_from=data_from,
+                                                reshape = True)  # (12(L), 8(B), 14(H), 14(W))
+         print(1)
+         # datas_rollout = overlay(samples_, rollout, dataset_name)
+         # drawImgPlot(datas_rollout, col=col)
+         # datas_attn = overlay(samples_, attentive, dataset_name)
+         # drawImgPlot(datas_attn, col=col)
+
+     if view[2]:  # SALIENCY W/ DATA
+         img = np.array(imgs[0])
+
+         saliency = cv2.saliency.StaticSaliencySpectralResidual_create()
+         (success, saliencyMap) = saliency.computeSaliency(img)
+         saliencyMap = (saliencyMap * 255).astype("uint8")
+
+         saliency = cv2.saliency.StaticSaliencyFineGrained_create()
+         (success, saliencyFineMap) = saliency.computeSaliency(img)
+         threshMap = cv2.threshold((saliencyFineMap * 255).astype("uint8"), 0, 255, cv2.THRESH_BINARY | cv2.THRESH_OTSU)[1]
+         # plt.imshow(threshMap)
+         # plt.show()
+
+     if view[3]:  # SALIENCY FOR INPUT
+         samples_.requires_grad, model.detach, k = True, False, 3
+         output = model(samples_)
+         attn = torch.stack(model.info["attn"]["f"], dim=1).mean(dim=[2,3])[0,-2]
+         topK = attn[1:].topk(k, -1, True)[1]
+         # a = torch.autograd.grad(attn.sum(), samples, retain_graph=True)[0].sum(dim=1)
+         a = torch.autograd.grad(output[:,3], samples_, retain_graph=True)[0].sum(dim=1)
+         b = F.interpolate(a.unsqueeze(0), scale_factor=0.05, mode='nearest')[0]
+         # drawHeatmap(b)
+         print(1)
+         # to_np(torch.stack([attn[:, :, 0], attn[:, :, 1:].sum(dim=-1)], -1)[0])
+
+     if view[4]:  # ATTENTION MOVEMENT (FROM / TO)
+         attn = torch.stack(model.info["attn"]["f"]).mean(dim=2).transpose(0,1)  # (8 (B), 12 (L), 197(T_Q), 197(T_K))
+         cls2cls = attn[:, :, :1, 0].mean(dim=2)                                 # (8(B), 12(L))
+         patch2cls = attn[:, :, :1, 1:].mean(dim=2).sum(dim=-1)                  # (8(B), 12(L))
+         to_np(torch.stack([cls2cls.mean(dim=0), patch2cls.mean(dim=0)]))
+
+         cls2patch = attn[:, :, 1:, 0].mean(dim=2)
+         patch2patch = attn[:, :, 1:, 1:].mean(dim=2).sum(dim=-1)
+         to_np(torch.stack([cls2patch.mean(dim=0), patch2patch.mean(dim=0)]))
+         print(1)
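Note: for orientation, a minimal sketch of how the new `Analysis` wrapper might be driven outside the `__main__` block above (the model name and the random input are placeholders; only the attention hook is enabled, matching `activate=[True, False, False]`):

```python
import torch
from timm import create_model
from joonmyung.analysis.analysis import Analysis

model = create_model("deit_tiny_patch16_224", pretrained=True).cuda()
wrapped = Analysis(model, analysis=[0], activate=[True, False, False], device="cuda")

x = torch.randn(1, 3, 224, 224, device="cuda")   # stand-in for a normalized ImageNet batch
out = wrapped(x)                                  # forward pass; attention maps are captured by the hooks
rollout, attentive = wrapped.anaSaliency(
    attn=True, grad=False, output=out,
    ls_rollout=list(range(12)), ls_attentive=list(range(12)), reshape=True)
```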
{joonmyung-1.4.8 → joonmyung-1.4.9}/joonmyung/analysis/dataset.py
@@ -1,21 +1,22 @@
  from joonmyung.meta_data.label import imnet_label, cifar_label
  from timm.data import create_dataset, create_loader
  from torchvision import transforms
- from joonmyung.utils import getDir
  import torch
  import copy
  import glob
  import PIL
  import os

+ from joonmyung.utils import getDir
+

  class JDataset():
      distributions = {"imagenet": {"mean": [0.485, 0.456, 0.406], "std": [0.229, 0.224, 0.225]},
                       "cifar": {"mean": [0.4914, 0.4822, 0.4465], "std": [0.2023, 0.1994, 0.2010]}}
-     transform_cifar = transforms.Compose([transforms.ToTensor(), transforms.Normalize(distributions["cifar"]["mean"], distributions["cifar"]["std"])])
-     transform_imagenet_ = transforms.Compose([transforms.ToTensor(), transforms.Normalize(distributions["imagenet"]["mean"], distributions["imagenet"]["std"])])
+     transform_cifar = transforms.Compose([transforms.ToTensor(), transforms.Normalize(distributions["cifar"]["mean"], distributions["cifar"]["std"])])
+     # transform_imagenet_ = transforms.Compose([transforms.ToTensor(), transforms.Normalize(distributions["imagenet"]["mean"], distributions["imagenet"]["std"])])
      # transform_imagenet_vis = transforms.Compose([transforms.Resize(256, interpolation=3), transforms.CenterCrop(224)])
-     transform_imagenet_vis = transforms.Compose([transforms.Resize((224, 224), interpolation=3)])
+     transform_imagenet_vis = transforms.Compose([transforms.Resize((224, 224), interpolation=3)])
      transform_imagenet_norm = transforms.Compose([transforms.ToTensor(), transforms.Normalize(distributions["imagenet"]["mean"], distributions["imagenet"]["std"])])

      # transforms.Resize(int((256 / 224) * input_size), interpolation=InterpolationMode.BICUBIC),
@@ -40,13 +41,15 @@ class JDataset():
              result[:, c].sub_(m).div_(s)
          return result

-     def __init__(self, data_path="/hub_data1/joonmyung/data", dataset="imagenet", device="cuda"):
+     def __init__(self, data_path="/hub_data1/joonmyung/data", dataset="imagenet", device="cuda", train = False):
          dataset = dataset.lower()

          self.d = dataset.lower()
          self.num_classes = 1000 if self.d == "imagenet" else 100
-         [self.d_kind, self.d_type] = ["imagenet", "val"] if self.d == "imagenet" else ["cifar", "test"]
-         # [self.d_kind, self.d_type] = ["imagenet", "train"] if self.d == "imagenet" else ["cifar", "test"]
+         if train:
+             [self.d_kind, self.d_type] = ["imagenet", "val"] if self.d == "imagenet" else ["cifar", "test"]
+         else:
+             [self.d_kind, self.d_type] = ["imagenet", "train"] if self.d == "imagenet" else ["cifar", "train"]
          self.device = device

          self.transform = self.transforms[self.d_kind]
@@ -81,7 +84,7 @@ class JDataset():
          dataset = create_dataset(
              root=self.data_path, name="IMNET" if self.d == "imagenet" else self.d.upper()
              , split='validation', is_training=False
-             , download=False, load_bytes=False, class_map='')
+             , load_bytes=False, class_map='')

          loader = create_loader(
              dataset,
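Note: the main change here is the new `train` flag on `JDataset.__init__`. As committed, `train=True` keeps the previous validation/test split selection, while the new default `train=False` selects the train split. A hypothetical call (the data path is illustrative):

```python
from joonmyung.analysis.dataset import JDataset

# train=True  -> ("imagenet", "val") / ("cifar", "test"), as in 1.4.8
# train=False -> ("imagenet", "train") / ("cifar", "train"), the new default
dataset = JDataset("/path/to/data", "imagenet", device="cuda", train=True)
```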
joonmyung-1.4.9/joonmyung/analysis/model.py
@@ -0,0 +1,35 @@
+ from collections import OrderedDict
+
+ from joonmyung.utils import isDir
+ from timm import create_model
+ import torch
+ import os
+
+ class JModel():
+     def __init__(self, num_classes = None, model_path= None, device="cuda", p=False):
+         # Pretrained_Model
+         self.num_classes = num_classes
+
+         if model_path:
+             self.model_path = os.path.join(model_path, "checkpoint_{}.pth")
+         if p and model_path:
+             print("file list : ", sorted(os.listdir(model_path), reverse=True))
+         self.device = device
+
+     def load_state_dict(self, model, state_dict):
+         state_dict = OrderedDict((k.replace("module.", ""), v) for k, v in state_dict.items())
+         model.load_state_dict(state_dict)
+
+
+     def getModel(self, model_type=0, model_name ="deit_tiny", **kwargs):
+
+         if model_type == 0:
+             model = create_model(model_name, pretrained=True, num_classes=self.num_classes, in_chans=3, global_pool=None, scriptable=False)
+         elif model_type == 1:
+             model = torch.hub.load('facebookresearch/deit:main', model_name, pretrained=True)
+         else:
+             raise ValueError
+         model.eval()
+
+         return model.to(self.device)
+
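Note: a short, hypothetical usage sketch for the new `JModel` helper (the model name mirrors the analysis script above; the two branches in `getModel` are the timm and torch.hub paths):

```python
from joonmyung.analysis.model import JModel

builder = JModel(num_classes=1000, device="cuda")
model = builder.getModel(model_type=0, model_name="deit_tiny_patch16_224")    # timm create_model path
# model = builder.getModel(model_type=1, model_name="deit_tiny_patch16_224")  # torch.hub (facebookresearch/deit) path
```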
joonmyung-1.4.9/joonmyung/data.py
@@ -0,0 +1,47 @@
+ import torch
+
+ def rangeBlock(block, vmin=0, vmax=5):
+     loss = torch.arange(vmin, vmax, (vmax - vmin) / block, requires_grad=False).unsqueeze(dim=1)
+     return loss
+
+ def columnRename(df, ns):
+     for n in ns:
+         if n[0] in df.columns:
+             df.rename(columns = {n[0]: n[1]}, inplace = True)
+     # columnRemove(df, ['c1', 'c2' ... ])
+
+
+ def columnRemove(df, ns):
+     delList = []
+     for n in ns:
+         if n in df.columns:
+             delList.append(n)
+     df.drop(delList, axis=1, inplace=True)
+     # columnRename(df, [['c1_p', 'c1_a'] , ['c2_p', 'c2_a']])
+
+
+ def normalization(t, type = 0):
+     if type == 0:
+         return t / t.max()
+     elif type == 1:
+         return t / t.min()
+
+ from torchvision import transforms
+ from torchvision.transforms import InterpolationMode
+
+ def getTransform(train = False, totensor = False, resize=True):
+
+     if not resize:
+         transform = lambda x: x
+     else:
+         transform = []
+
+         transform.append(transforms.RandomResizedCrop(224, scale=(0.5, 1.0), interpolation=InterpolationMode.BICUBIC)) \
+             if train else transform.append(transforms.Resize((224, 224), interpolation=3))
+
+         if totensor:
+             transform.append(transforms.ToTensor())
+             transform.append(transforms.Normalize([0.485, 0.456, 0.406], [0.229, 0.224, 0.225]))
+         transform = transforms.Compose(transform)
+
+     return transform
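Note: `getTransform` is what the analysis script above uses in debug mode (`getTransform(False, True)`). A minimal sketch of the same usage (the image path is a placeholder):

```python
import PIL.Image
from joonmyung.data import getTransform

transform = getTransform(train=False, totensor=True)   # Resize((224, 224)) + ToTensor + ImageNet normalization
x = transform(PIL.Image.open("example.JPEG"))[None]    # (1, 3, 224, 224) batch
```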