joonmyung 1.5.13.tar.gz → 1.5.14.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {joonmyung-1.5.13 → joonmyung-1.5.14}/PKG-INFO +1 -1
- {joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung/draw.py +10 -16
- {joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung.egg-info/PKG-INFO +1 -1
- {joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung.egg-info/SOURCES.txt +1 -14
- {joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung.egg-info/top_level.txt +0 -1
- {joonmyung-1.5.13 → joonmyung-1.5.14}/setup.py +1 -1
- joonmyung-1.5.13/models/SA/MHSA.py +0 -37
- joonmyung-1.5.13/models/SA/PVTSA.py +0 -90
- joonmyung-1.5.13/models/SA/TMSA.py +0 -37
- joonmyung-1.5.13/models/SA/__init__.py +0 -0
- joonmyung-1.5.13/models/__init__.py +0 -0
- joonmyung-1.5.13/models/deit.py +0 -372
- joonmyung-1.5.13/models/evit.py +0 -154
- joonmyung-1.5.13/models/modules/PE.py +0 -139
- joonmyung-1.5.13/models/modules/__init__.py +0 -0
- joonmyung-1.5.13/models/modules/blocks.py +0 -168
- joonmyung-1.5.13/models/pvt.py +0 -307
- joonmyung-1.5.13/models/pvt_v2.py +0 -202
- joonmyung-1.5.13/models/tome.py +0 -285
- {joonmyung-1.5.13 → joonmyung-1.5.14}/LICENSE.txt +0 -0
- {joonmyung-1.5.13 → joonmyung-1.5.14}/README.md +0 -0
- {joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung/__init__.py +0 -0
- {joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung/analysis/__init__.py +0 -0
- {joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung/analysis/analysis.py +0 -0
- {joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung/analysis/dataset.py +0 -0
- {joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung/analysis/hook.py +0 -0
- {joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung/analysis/metric.py +0 -0
- {joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung/analysis/model.py +0 -0
- {joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung/analysis/utils.py +0 -0
- {joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung/app.py +0 -0
- {joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung/data.py +0 -0
- {joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung/dummy.py +0 -0
- {joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung/file.py +0 -0
- {joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung/gradcam.py +0 -0
- {joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung/log.py +0 -0
- {joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung/meta_data/__init__.py +0 -0
- {joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung/meta_data/label.py +0 -0
- {joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung/meta_data/utils.py +0 -0
- {joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung/metric.py +0 -0
- {joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung/models/__init__.py +0 -0
- {joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung/models/tome.py +0 -0
- {joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung/script.py +0 -0
- {joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung/status.py +0 -0
- {joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung/utils.py +0 -0
- {joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung.egg-info/dependency_links.txt +0 -0
- {joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung.egg-info/not-zip-safe +0 -0
- {joonmyung-1.5.13 → joonmyung-1.5.14}/setup.cfg +0 -0
{joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung/draw.py

@@ -241,28 +241,22 @@ def saliency(attentions=None, gradients=None, head_fusion="mean",
 
 
 
+
 def data2PIL(datas):
-    if type(datas) == torch.Tensor:
-        if datas.shape
-            pils = datas
-        elif len(datas.shape) == 2:
+    if type(datas) == torch.Tensor:
+        if len(datas.shape) == 2:
             pils = datas.unsqueeze(-1).detach().cpu()
-
+        if len(datas.shape) == 3:
             pils = datas.permute(1, 2, 0).detach().cpu()
         elif len(datas.shape) == 4:
             pils = datas.permute(0, 2, 3, 1).detach().cpu()
     elif type(datas) == np.ndarray:
-        if len(datas.shape) ==
-
-
-
-
-
-        # pils = datas
-
-        # image = Image.fromarray(image) # 0.32ms
-        pils = cv2.cvtColor(datas, cv2.COLOR_BGR2RGB) # 0.29ms
-
+        if len(datas.shape) == 3: datas = np.expand_dims(datas, axis=0)
+        if datas.max() <= 1:
+            # image = Image.fromarray(image) # 0.32ms
+            pils = cv2.cvtColor(datas, cv2.COLOR_BGR2RGB) # 0.29ms
+        else:
+            pils = datas
     elif type(datas) == PIL.JpegImagePlugin.JpegImageFile \
             or type(datas) == PIL.Image.Image:
         pils = datas
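For orientation, a sketch of how the updated data2PIL branches on its input; the joonmyung.draw import path is inferred from the file name, and the sample values are illustrative, not taken from the package:

    import numpy as np
    import torch
    from joonmyung.draw import data2PIL   # module path inferred from joonmyung/draw.py

    t = torch.rand(3, 32, 32)              # 3-D tensor: permuted to (H, W, C) and moved to CPU
    pils = data2PIL(t)

    arr = np.random.randint(0, 256, (32, 32, 3), dtype=np.uint8)
    pils = data2PIL(arr)                   # ndarray with max() > 1: kept as-is, with a batch dim prepended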
{joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung.egg-info/SOURCES.txt

@@ -29,17 +29,4 @@ joonmyung/meta_data/__init__.py
 joonmyung/meta_data/label.py
 joonmyung/meta_data/utils.py
 joonmyung/models/__init__.py
-joonmyung/models/tome.py
-models/__init__.py
-models/deit.py
-models/evit.py
-models/pvt.py
-models/pvt_v2.py
-models/tome.py
-models/SA/MHSA.py
-models/SA/PVTSA.py
-models/SA/TMSA.py
-models/SA/__init__.py
-models/modules/PE.py
-models/modules/__init__.py
-models/modules/blocks.py
+joonmyung/models/tome.py
joonmyung-1.5.13/models/SA/MHSA.py DELETED

@@ -1,37 +0,0 @@
-import torch.nn as nn
-class MHSA(nn.Module):
-    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.,
-                 talking_head=False):
-        super().__init__()
-        self.num_heads = num_heads
-        head_dim = dim // num_heads
-        self.scale = head_dim ** -0.5
-
-        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
-        self.attn_drop = nn.Dropout(attn_drop)
-        self.proj = nn.Linear(dim, dim)
-        self.proj_drop = nn.Dropout(proj_drop)
-
-
-        self.proj_l = nn.Linear(num_heads, num_heads) if talking_head else nn.Identity()
-        self.proj_w = nn.Linear(num_heads, num_heads) if talking_head else nn.Identity()
-
-
-    def forward(self, x):
-        B, N, C = x.shape
-        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
-        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)
-
-        attn = (q @ k.transpose(-2, -1)) * self.scale # (B, H, T, D)
-        # attn = self.proj_l(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
-        attn = self.proj_l(attn.transpose(1, -1)).transpose(1, -1)
-
-        attn = attn.softmax(dim=-1)
-        attn = self.attn_drop(attn)
-        # attn = self.proj_w(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
-        attn = self.proj_w(attn.transpose(1, -1)).transpose(1, -1)
-
-        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
-        x = self.proj(x)
-        x = self.proj_drop(x)
-        return x
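For reference, a minimal usage sketch of the deleted MHSA block; the constructor arguments and the (B, N, C) input convention come from the code above, while the sizes and the 1.5.13-era import path are assumptions:

    import torch
    from models.SA.MHSA import MHSA   # import path as the file existed in joonmyung 1.5.13

    attn = MHSA(dim=192, num_heads=3, talking_head=True)   # talking-head projections around the softmax
    x = torch.randn(2, 197, 192)                           # (B, N, C): batch, tokens, channels
    y = attn(x)                                            # output keeps the (B, N, C) shape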
joonmyung-1.5.13/models/SA/PVTSA.py DELETED

@@ -1,90 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from functools import partial
-
-from timm.models.layers import DropPath, to_2tuple, trunc_normal_
-from timm.models.registry import register_model
-from timm.models.vision_transformer import _cfg
-import math
-
-class PVTSA(nn.Module):
-    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1, linear=False,
-                 talking_head=False):
-        super().__init__()
-        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
-
-        self.dim = dim
-        self.num_heads = num_heads
-        head_dim = dim // num_heads
-        self.scale = qk_scale or head_dim ** -0.5
-
-        self.q = nn.Linear(dim, dim, bias=qkv_bias)
-        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
-        self.attn_drop = nn.Dropout(attn_drop)
-        self.proj = nn.Linear(dim, dim)
-        self.proj_drop = nn.Dropout(proj_drop)
-
-        self.proj_l = nn.Linear(num_heads, num_heads) if talking_head else nn.Identity()
-        self.proj_w = nn.Linear(num_heads, num_heads) if talking_head else nn.Identity()
-
-        self.linear = linear
-        self.sr_ratio = sr_ratio
-        if not linear:
-            if sr_ratio > 1:
-                self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
-                self.norm = nn.LayerNorm(dim)
-        else:
-            self.pool = nn.AdaptiveAvgPool2d(7)
-            self.sr = nn.Conv2d(dim, dim, kernel_size=1, stride=1)
-            self.norm = nn.LayerNorm(dim)
-            self.act = nn.GELU()
-        self.apply(self._init_weights)
-
-    def _init_weights(self, m):
-        if isinstance(m, nn.Linear):
-            trunc_normal_(m.weight, std=.02)
-            if isinstance(m, nn.Linear) and m.bias is not None:
-                nn.init.constant_(m.bias, 0)
-        elif isinstance(m, nn.LayerNorm):
-            nn.init.constant_(m.bias, 0)
-            nn.init.constant_(m.weight, 1.0)
-        elif isinstance(m, nn.Conv2d):
-            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
-            fan_out //= m.groups
-            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
-            if m.bias is not None:
-                m.bias.data.zero_()
-
-    def forward(self, x, H, W):
-        B, N, C = x.shape
-        q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
-
-        if not self.linear:
-            if self.sr_ratio > 1:
-                x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
-                x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
-                x_ = self.norm(x_)
-                kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
-            else:
-                kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
-        else:
-            x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
-            x_ = self.sr(self.pool(x_)).reshape(B, C, -1).permute(0, 2, 1)
-            x_ = self.norm(x_)
-            x_ = self.act(x_)
-            kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
-        k, v = kv[0], kv[1]
-
-        attn = (q @ k.transpose(-2, -1)) * self.scale
-
-        attn = self.proj_l(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
-        attn = attn.softmax(dim=-1)
-        attn = self.attn_drop(attn)
-        attn = self.proj_w(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
-
-        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
-        x = self.proj(x)
-        x = self.proj_drop(x)
-
-        return x
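Likewise, a sketch of how the deleted PVTSA block is called; forward takes the token grid size (H, W) so keys and values can be computed on a spatially reduced map. The sizes and import path below are illustrative assumptions:

    import torch
    from models.SA.PVTSA import PVTSA   # import path as the file existed in joonmyung 1.5.13

    attn = PVTSA(dim=64, num_heads=1, sr_ratio=8, linear=False)
    H, W = 56, 56
    x = torch.randn(2, H * W, 64)       # (B, N, C) with N = H * W
    y = attn(x, H, W)                   # q from all N tokens, k/v from the sr_ratio-reduced map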
joonmyung-1.5.13/models/SA/TMSA.py DELETED

@@ -1,37 +0,0 @@
-import torch.nn as nn
-class MHSA(nn.Module):
-    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.,
-                 talking_head=False):
-        super().__init__()
-        self.num_heads = num_heads
-        head_dim = dim // num_heads
-        self.scale = head_dim ** -0.5
-
-        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
-        self.attn_drop = nn.Dropout(attn_drop)
-        self.proj = nn.Linear(dim, dim)
-        self.proj_drop = nn.Dropout(proj_drop)
-
-
-        self.proj_l = nn.Linear(num_heads, num_heads) if talking_head else nn.Identity()
-        self.proj_w = nn.Linear(num_heads, num_heads) if talking_head else nn.Identity()
-
-
-    def forward(self, x):
-        B, N, C = x.shape
-        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
-        q, k, v = qkv[0], qkv[1], qkv[2]   # make torchscript happy (cannot use tensor as tuple)
-
-        attn = (q @ k.transpose(-2, -1)) * self.scale # (B, H, T, D)
-        # attn = self.proj_l(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
-        attn = self.proj_l(attn.transpose(1, -1)).transpose(1, -1)
-
-        attn = attn.softmax(dim=-1)
-        attn = self.attn_drop(attn)
-        # attn = self.proj_w(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
-        attn = self.proj_w(attn.transpose(1, -1)).transpose(1, -1)
-
-        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
-        x = self.proj(x)
-        x = self.proj_drop(x)
-        return x, k.mean(dim=1)
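The TMSA variant above is identical to the MHSA.py block except for its last line: forward also returns the keys averaged over heads. A sketch of the call (sizes and import path are assumptions):

    import torch
    from models.SA.TMSA import MHSA as TMSA   # TMSA.py also named its class MHSA in 1.5.13

    attn = TMSA(dim=192, num_heads=3)
    x = torch.randn(2, 197, 192)               # (B, N, C)
    y, k_mean = attn(x)                        # k_mean: (B, N, head_dim), keys averaged over heads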
joonmyung-1.5.13/models/SA/__init__.py: File without changes
joonmyung-1.5.13/models/__init__.py: File without changes
joonmyung-1.5.13/models/deit.py DELETED

@@ -1,372 +0,0 @@
-""" Vision Transformer (ViT) in PyTorch
-
-A PyTorch implement of Vision Transformers as described in:
-
-'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale'
-    - https://arxiv.org/abs/2010.11929
-
-`How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers`
-    - https://arxiv.org/abs/2106.10270
-
-The official jax code is released and available at https://github.com/google-research/vision_transformer
-
-DeiT model defs and weights from https://github.com/facebookresearch/deit,
-paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877
-
-Acknowledgments:
-* The paper authors for releasing code and weights, thanks!
-* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
-for some einops/einsum fun
-* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
-* Bert reference code checks against Huggingface Transformers and Tensorflow Bert
-
-Hacked together by / Copyright 2021 Ross Wightman
-"""
-import math
-import logging
-from functools import partial
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from timm.models.helpers import named_apply, adapt_input_conv
-from timm.models.layers import PatchEmbed, trunc_normal_, lecun_normal_
-from timm.models.registry import register_model
-
-from models.modules.PE import PositionalEncodingFourier, get_2d_sincos_pos_embed
-from models.modules.blocks import Block_DEIT
-
-_logger = logging.getLogger(__name__)
-
-class VisionTransformer(nn.Module):
-    """ Vision Transformer
-
-    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
-        - https://arxiv.org/abs/2010.11929
-
-    Includes distillation token & head support for `DeiT: Data-efficient Image Transformers`
-        - https://arxiv.org/abs/2012.12877
-    """
-
-    def __init__(self, input_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
-                 num_heads=12, mlp_ratio=4., qkv_bias=True,
-                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None,
-                 act_layer=None, weight_init='', **kwargs):
-        super().__init__()
-        self.num_classes = num_classes
-        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
-        self.cls_token_num = kwargs["token_nums"][0]
-        self.dist_token_num = kwargs["token_nums"][1]
-
-        self.patch_embed_type = kwargs["embed_type"][0]
-        self.pos_embed_type = kwargs["embed_type"][1]
-        self.pos_embed_applied_type = kwargs["embed_type"][2]
-
-
-        self.layer_scale = kwargs["model_type"][0]
-        self.talking_head = kwargs["model_type"][1]
-
-        self.num_tokens = self.cls_token_num + self.dist_token_num
-
-        act_layer = act_layer or nn.GELU
-
-        img_size = input_size
-        if self.patch_embed_type == 0:
-            self.patch_embed = embed_layer(
-                img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
-        else:
-            raise ValueError
-        num_patches = self.patch_embed.num_patches
-
-        if self.cls_token_num: self.cls_token = nn.Parameter(torch.zeros(1, self.cls_token_num, embed_dim))
-        if self.dist_token_num: self.dist_token = nn.Parameter(torch.zeros(1, self.dist_token_num, embed_dim))
-
-        pos_embed_token_num = num_patches + self.num_tokens if self.pos_embed_applied_type == 0 else num_patches
-        if self.pos_embed_type == 0:
-            self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_token_num, embed_dim))
-        elif self.pos_embed_type == 1:
-            self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_token_num, embed_dim), requires_grad=False)
-            pos_embed = get_2d_sincos_pos_embed(embed_dim, int(pos_embed_token_num **.5))
-            self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
-        elif self.pos_embed_type == 2:
-            self.pos_embed = PositionalEncodingFourier(dim=embed_dim)
-        else:
-            raise ValueError
-        self.pos_drop = nn.Dropout(p=drop_rate)
-        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
-        self.blocks = nn.Sequential(*[
-            Block_DEIT(
-                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
-                attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer,
-                # OURS
-                layer_scale=self.layer_scale, talking_head=self.talking_head)
-            for i in range(depth)])
-        self.norm = norm_layer(embed_dim)
-        self.pre_logits = nn.Identity()
-
-        # Classifier head(s)
-        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
-        if self.dist_token_num:
-            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
-        else:
-            self.head_dist = None
-
-        self.init_weights(weight_init)
-
-    def init_weights(self, mode=''):
-        assert mode in ('jax', 'jax_nlhb', 'nlhb', '')
-        head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
-
-        if self.pos_embed_type not in [1,2]: trunc_normal_(self.pos_embed, std=.02)
-        if self.dist_token_num:
-            trunc_normal_(self.dist_token, std=.02)
-        if mode.startswith('jax'):
-            # leave cls token as zeros to match jax impl
-            named_apply(partial(_init_vit_weights, head_bias=head_bias, jax_impl=True), self)
-        else:
-            if self.cls_token_num:
-                trunc_normal_(self.cls_token, std=.02)
-            self.apply(_init_vit_weights)
-
-    def _init_weights(self, m):
-        # this fn left here for compat with downstream users
-        _init_vit_weights(m)
-
-    @torch.jit.ignore()
-    def load_pretrained(self, checkpoint_path, prefix=''):
-        _load_weights(self, checkpoint_path, prefix)
-
-    @torch.jit.ignore
-    def no_weight_decay(self):
-        return {'pos_embed', 'cls_token', 'dist_token'}
-
-    def get_classifier(self):
-        if self.dist_token is None:
-            return self.head
-        else:
-            return self.head, self.head_dist
-
-    def reset_classifier(self, num_classes, global_pool=''):
-        self.num_classes = num_classes
-        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
-        if self.dist_token_num:
-            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
-
-    def forward_features(self, x):
-
-        x = self.patch_embed(x)
-        B, T, D = x.shape
-        cls_token = self.cls_token.expand(B, -1, -1)
-
-        pos_embed = self.pos_embed if self.pos_embed_type not in [2] else self.pos_embed(B, int(T**0.5), int(T**0.5)).reshape(B, -1, T).transpose(1, 2)
-        if self.pos_embed_applied_type == 0:
-            x = torch.cat((cls_token, x), dim=1)
-            x = self.pos_drop(x + pos_embed)
-        else:
-            x = torch.cat((cls_token, x + pos_embed), dim=1)
-            x = self.pos_drop(x)
-
-        x = self.blocks(x)
-        x = self.norm(x)
-        return self.pre_logits(x[:, :self.cls_token_num].mean(dim=1))
-
-    def forward(self, x):
-        x = self.forward_features(x)
-        x = self.head(x)
-        return x
-
-
-def _init_vit_weights(module: nn.Module, name: str = '', head_bias: float = 0., jax_impl: bool = False):
-    """ ViT weight initialization
-    * When called without n, head_bias, jax_impl args it will behave exactly the same
-      as my original init for compatibility with prev hparam / downstream use cases (ie DeiT).
-    * When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl
-    """
-    if isinstance(module, nn.Linear):
-        if name.startswith('head'):
-            nn.init.zeros_(module.weight)
-            nn.init.constant_(module.bias, head_bias)
-        elif name.startswith('pre_logits'):
-            lecun_normal_(module.weight)
-            nn.init.zeros_(module.bias)
-        else:
-            if jax_impl:
-                nn.init.xavier_uniform_(module.weight)
-                if module.bias is not None:
-                    if 'mlp' in name:
-                        nn.init.normal_(module.bias, std=1e-6)
-                    else:
-                        nn.init.zeros_(module.bias)
-            else:
-                trunc_normal_(module.weight, std=.02)
-                if module.bias is not None:
-                    nn.init.zeros_(module.bias)
-    elif jax_impl and isinstance(module, nn.Conv2d):
-        # NOTE conv was left to pytorch default in my original init
-        lecun_normal_(module.weight)
-        if module.bias is not None:
-            nn.init.zeros_(module.bias)
-    elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
-        nn.init.zeros_(module.bias)
-        nn.init.ones_(module.weight)
-
-
-@torch.no_grad()
-def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
-    """ Load weights from .npz checkpoints for official Google Brain Flax implementation
-    """
-    import numpy as np
-
-    def _n2p(w, t=True):
-        if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
-            w = w.flatten()
-        if t:
-            if w.ndim == 4:
-                w = w.transpose([3, 2, 0, 1])
-            elif w.ndim == 3:
-                w = w.transpose([2, 0, 1])
-            elif w.ndim == 2:
-                w = w.transpose([1, 0])
-        return torch.from_numpy(w)
-
-    w = np.load(checkpoint_path)
-    if not prefix and 'opt/target/embedding/kernel' in w:
-        prefix = 'opt/target/'
-
-    if hasattr(model.patch_embed, 'backbone'):
-        # hybrid
-        backbone = model.patch_embed.backbone
-        stem_only = not hasattr(backbone, 'stem')
-        stem = backbone if stem_only else backbone.stem
-        stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
-        stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
-        stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
-        if not stem_only:
-            for i, stage in enumerate(backbone.stages):
-                for j, block in enumerate(stage.blocks):
-                    bp = f'{prefix}block{i + 1}/unit{j + 1}/'
-                    for r in range(3):
-                        getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
-                        getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
-                        getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
-                    if block.downsample is not None:
-                        block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
-                        block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
-                        block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
-        embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
-    else:
-        embed_conv_w = adapt_input_conv(
-            model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
-    model.patch_embed.proj.weight.copy_(embed_conv_w)
-    model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
-    model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
-    pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
-    if pos_embed_w.shape != model.pos_embed.shape:
-        pos_embed_w = resize_pos_embed(  # resize pos embedding when different size from pretrained weights
-            pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
-    model.pos_embed.copy_(pos_embed_w)
-    model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
-    model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
-    if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
-        model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
-        model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
-    if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
-        model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
-        model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
-    for i, block in enumerate(model.blocks.children()):
-        block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
-        mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
-        block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
-        block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
-        block.attn.qkv.weight.copy_(torch.cat([
-            _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
-        block.attn.qkv.bias.copy_(torch.cat([
-            _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
-        block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
-        block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
-        for r in range(2):
-            getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
-            getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
-        block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
-        block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
-
-
-def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
-    # Rescale the grid of position embeddings when loading from state_dict. Adapted from
-    # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
-    _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
-    ntok_new = posemb_new.shape[1]
-    if num_tokens:
-        posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
-        ntok_new -= num_tokens
-    else:
-        posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
-    gs_old = int(math.sqrt(len(posemb_grid)))
-    if not len(gs_new):  # backwards compatibility
-        gs_new = [int(math.sqrt(ntok_new))] * 2
-    assert len(gs_new) >= 2
-    _logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new)
-    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
-    posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bilinear')
-    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
-    posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
-    return posemb
-
-
-def checkpoint_filter_fn(state_dict, model):
-    """ convert patch embedding weight from manual patchify + linear proj to conv"""
-    out_dict = {}
-    if 'model' in state_dict:
-        # For deit models
-        state_dict = state_dict['model']
-    for k, v in state_dict.items():
-        if 'patch_embed.proj.weight' in k and len(v.shape) < 4:
-            # For old models that I trained prior to conv based patchification
-            O, I, H, W = model.patch_embed.proj.weight.shape
-            v = v.reshape(O, -1, H, W)
-        elif k == 'pos_embed' and v.shape != model.pos_embed.shape:
-            # To resize pos embedding when using model at different size from pretrained weights
-            v = resize_pos_embed(
-                v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
-        out_dict[k] = v
-    return out_dict
-
-@register_model
-def deit_tiny(pretrained=False, **kwargs):
-    model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3,
-                        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
-    model = VisionTransformer(**model_kwargs)
-    if pretrained:
-        checkpoint = torch.hub.load_state_dict_from_url(
-            url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth",
-            map_location="cpu", check_hash=True
-        )
-        model.load_state_dict(checkpoint["model"])
-    return model
-
-@register_model
-def deit_small(pretrained=False, **kwargs):
-    model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6,
-                        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
-    model = VisionTransformer(**model_kwargs)
-    if pretrained:
-        checkpoint = torch.hub.load_state_dict_from_url(
-            url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth",
-            map_location="cpu", check_hash=True
-        )
-        model.load_state_dict(checkpoint["model"])
-    return model
-
-@register_model
-def deit_base(pretrained=False, **kwargs):
-    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12,
-                        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
-    model = VisionTransformer(**model_kwargs)
-    if pretrained:
-        checkpoint = torch.hub.load_state_dict_from_url(
-            url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
-            map_location="cpu", check_hash=True
-        )
-        model.load_state_dict(checkpoint["model"])
-    return model