joonmyung 1.5.13__tar.gz → 1.5.14__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (47)
  1. {joonmyung-1.5.13 → joonmyung-1.5.14}/PKG-INFO +1 -1
  2. {joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung/draw.py +10 -16
  3. {joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung.egg-info/PKG-INFO +1 -1
  4. {joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung.egg-info/SOURCES.txt +1 -14
  5. {joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung.egg-info/top_level.txt +0 -1
  6. {joonmyung-1.5.13 → joonmyung-1.5.14}/setup.py +1 -1
  7. joonmyung-1.5.13/models/SA/MHSA.py +0 -37
  8. joonmyung-1.5.13/models/SA/PVTSA.py +0 -90
  9. joonmyung-1.5.13/models/SA/TMSA.py +0 -37
  10. joonmyung-1.5.13/models/SA/__init__.py +0 -0
  11. joonmyung-1.5.13/models/__init__.py +0 -0
  12. joonmyung-1.5.13/models/deit.py +0 -372
  13. joonmyung-1.5.13/models/evit.py +0 -154
  14. joonmyung-1.5.13/models/modules/PE.py +0 -139
  15. joonmyung-1.5.13/models/modules/__init__.py +0 -0
  16. joonmyung-1.5.13/models/modules/blocks.py +0 -168
  17. joonmyung-1.5.13/models/pvt.py +0 -307
  18. joonmyung-1.5.13/models/pvt_v2.py +0 -202
  19. joonmyung-1.5.13/models/tome.py +0 -285
  20. {joonmyung-1.5.13 → joonmyung-1.5.14}/LICENSE.txt +0 -0
  21. {joonmyung-1.5.13 → joonmyung-1.5.14}/README.md +0 -0
  22. {joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung/__init__.py +0 -0
  23. {joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung/analysis/__init__.py +0 -0
  24. {joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung/analysis/analysis.py +0 -0
  25. {joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung/analysis/dataset.py +0 -0
  26. {joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung/analysis/hook.py +0 -0
  27. {joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung/analysis/metric.py +0 -0
  28. {joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung/analysis/model.py +0 -0
  29. {joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung/analysis/utils.py +0 -0
  30. {joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung/app.py +0 -0
  31. {joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung/data.py +0 -0
  32. {joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung/dummy.py +0 -0
  33. {joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung/file.py +0 -0
  34. {joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung/gradcam.py +0 -0
  35. {joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung/log.py +0 -0
  36. {joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung/meta_data/__init__.py +0 -0
  37. {joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung/meta_data/label.py +0 -0
  38. {joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung/meta_data/utils.py +0 -0
  39. {joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung/metric.py +0 -0
  40. {joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung/models/__init__.py +0 -0
  41. {joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung/models/tome.py +0 -0
  42. {joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung/script.py +0 -0
  43. {joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung/status.py +0 -0
  44. {joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung/utils.py +0 -0
  45. {joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung.egg-info/dependency_links.txt +0 -0
  46. {joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung.egg-info/not-zip-safe +0 -0
  47. {joonmyung-1.5.13 → joonmyung-1.5.14}/setup.cfg +0 -0
{joonmyung-1.5.13 → joonmyung-1.5.14}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: joonmyung
-Version: 1.5.13
+Version: 1.5.14
 Summary: JoonMyung's Library
 Home-page: https://github.com/pizard/JoonMyung.git
 Author: JoonMyung Choi
{joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung/draw.py
@@ -241,28 +241,22 @@ def saliency(attentions=None, gradients=None, head_fusion="mean",
 
 
 
+
 def data2PIL(datas):
-    if type(datas) == torch.Tensor: # MAKE TO (..., H, W, 3)
-        if datas.shape[-1] == 3:
-            pils = datas
-        elif len(datas.shape) == 2:
+    if type(datas) == torch.Tensor:
+        if len(datas.shape) == 2:
             pils = datas.unsqueeze(-1).detach().cpu()
-        elif len(datas.shape) == 3:
+        if len(datas.shape) == 3:
             pils = datas.permute(1, 2, 0).detach().cpu()
         elif len(datas.shape) == 4:
             pils = datas.permute(0, 2, 3, 1).detach().cpu()
     elif type(datas) == np.ndarray:
-        if len(datas.shape) == 2:
-            datas = np.expand_dims(datas, axis=-1)
-        # TODO NEED TO CHECK
-        # if datas.max() <= 1:
-        #     pils = cv2.cvtColor(datas, cv2.COLOR_BGR2RGB) # 0.29ms
-        # else:
-        #     pils = datas
-
-        # image = Image.fromarray(image) # 0.32ms
-        pils = cv2.cvtColor(datas, cv2.COLOR_BGR2RGB) # 0.29ms
-
+        if len(datas.shape) == 3: datas = np.expand_dims(datas, axis=0)
+        if datas.max() <= 1:
+            # image = Image.fromarray(image) # 0.32ms
+            pils = cv2.cvtColor(datas, cv2.COLOR_BGR2RGB) # 0.29ms
+        else:
+            pils = datas
     elif type(datas) == PIL.JpegImagePlugin.JpegImageFile \
             or type(datas) == PIL.Image.Image:
         pils = datas
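For orientation, a minimal sketch of how the revised data2PIL branches behave, assuming the imports already used in draw.py (torch, numpy as np, cv2, PIL), that the helper is importable as joonmyung.draw.data2PIL, and that it returns the converted pils value; shapes and values below are illustrative only:

    import numpy as np
    import torch

    from joonmyung.draw import data2PIL  # assumed import path

    # 4D float tensor (B, C, H, W): permuted to (B, H, W, C) and detached, as before.
    imgs = torch.rand(2, 3, 224, 224)
    out = data2PIL(imgs)                        # -> shape (2, 224, 224, 3)

    # ndarray input: a 3D array is now expanded to (1, H, W, C); arrays with values > 1
    # (e.g. uint8) are returned as-is, while [0, 1]-ranged arrays still pass through
    # cv2.cvtColor(..., cv2.COLOR_BGR2RGB).
    arr = (np.random.rand(224, 224, 3) * 255).astype(np.uint8)
    out = data2PIL(arr)                         # -> shape (1, 224, 224, 3), values unchanged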
{joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.1
 Name: joonmyung
-Version: 1.5.13
+Version: 1.5.14
 Summary: JoonMyung's Library
 Home-page: https://github.com/pizard/JoonMyung.git
 Author: JoonMyung Choi
{joonmyung-1.5.13 → joonmyung-1.5.14}/joonmyung.egg-info/SOURCES.txt
@@ -29,17 +29,4 @@ joonmyung/meta_data/__init__.py
 joonmyung/meta_data/label.py
 joonmyung/meta_data/utils.py
 joonmyung/models/__init__.py
-joonmyung/models/tome.py
-models/__init__.py
-models/deit.py
-models/evit.py
-models/pvt.py
-models/pvt_v2.py
-models/tome.py
-models/SA/MHSA.py
-models/SA/PVTSA.py
-models/SA/TMSA.py
-models/SA/__init__.py
-models/modules/PE.py
-models/modules/__init__.py
-models/modules/blocks.py
+joonmyung/models/tome.py
{joonmyung-1.5.13 → joonmyung-1.5.14}/setup.py
@@ -3,7 +3,7 @@ from setuptools import find_packages
 
 setuptools.setup(
     name="joonmyung",
-    version="1.5.13",
+    version="1.5.14",
    author="JoonMyung Choi",
    author_email="pizard@korea.ac.kr",
    description="JoonMyung's Library",
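The only substantive change in setup.py is the version bump. If you want to confirm which build ends up installed after upgrading, a quick standard-library check is enough:

    import importlib.metadata

    print(importlib.metadata.version("joonmyung"))   # expected: 1.5.14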
joonmyung-1.5.13/models/SA/MHSA.py
@@ -1,37 +0,0 @@
-import torch.nn as nn
-class MHSA(nn.Module):
-    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.,
-                 talking_head=False):
-        super().__init__()
-        self.num_heads = num_heads
-        head_dim = dim // num_heads
-        self.scale = head_dim ** -0.5
-
-        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
-        self.attn_drop = nn.Dropout(attn_drop)
-        self.proj = nn.Linear(dim, dim)
-        self.proj_drop = nn.Dropout(proj_drop)
-
-
-        self.proj_l = nn.Linear(num_heads, num_heads) if talking_head else nn.Identity()
-        self.proj_w = nn.Linear(num_heads, num_heads) if talking_head else nn.Identity()
-
-
-    def forward(self, x):
-        B, N, C = x.shape
-        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
-        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
-
-        attn = (q @ k.transpose(-2, -1)) * self.scale # (B, H, T, D)
-        # attn = self.proj_l(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
-        attn = self.proj_l(attn.transpose(1, -1)).transpose(1, -1)
-
-        attn = attn.softmax(dim=-1)
-        attn = self.attn_drop(attn)
-        # attn = self.proj_w(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
-        attn = self.proj_w(attn.transpose(1, -1)).transpose(1, -1)
-
-        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
-        x = self.proj(x)
-        x = self.proj_drop(x)
-        return x
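For readers tracking the removals: the deleted MHSA class was a plain multi-head self-attention block over token sequences, with optional talking-head projections across the head dimension. A hypothetical usage sketch (the class no longer ships in 1.5.14, so this only illustrates the interface shown above):

    import torch

    attn = MHSA(dim=192, num_heads=3, qkv_bias=True)   # as defined in the removed file
    x = torch.rand(8, 197, 192)                        # (batch, tokens, embed_dim)
    y = attn(x)                                        # output keeps the shape (8, 197, 192)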
joonmyung-1.5.13/models/SA/PVTSA.py
@@ -1,90 +0,0 @@
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-from functools import partial
-
-from timm.models.layers import DropPath, to_2tuple, trunc_normal_
-from timm.models.registry import register_model
-from timm.models.vision_transformer import _cfg
-import math
-
-class PVTSA(nn.Module):
-    def __init__(self, dim, num_heads=8, qkv_bias=False, qk_scale=None, attn_drop=0., proj_drop=0., sr_ratio=1, linear=False,
-                 talking_head=False):
-        super().__init__()
-        assert dim % num_heads == 0, f"dim {dim} should be divided by num_heads {num_heads}."
-
-        self.dim = dim
-        self.num_heads = num_heads
-        head_dim = dim // num_heads
-        self.scale = qk_scale or head_dim ** -0.5
-
-        self.q = nn.Linear(dim, dim, bias=qkv_bias)
-        self.kv = nn.Linear(dim, dim * 2, bias=qkv_bias)
-        self.attn_drop = nn.Dropout(attn_drop)
-        self.proj = nn.Linear(dim, dim)
-        self.proj_drop = nn.Dropout(proj_drop)
-
-        self.proj_l = nn.Linear(num_heads, num_heads) if talking_head else nn.Identity()
-        self.proj_w = nn.Linear(num_heads, num_heads) if talking_head else nn.Identity()
-
-        self.linear = linear
-        self.sr_ratio = sr_ratio
-        if not linear:
-            if sr_ratio > 1:
-                self.sr = nn.Conv2d(dim, dim, kernel_size=sr_ratio, stride=sr_ratio)
-                self.norm = nn.LayerNorm(dim)
-        else:
-            self.pool = nn.AdaptiveAvgPool2d(7)
-            self.sr = nn.Conv2d(dim, dim, kernel_size=1, stride=1)
-            self.norm = nn.LayerNorm(dim)
-            self.act = nn.GELU()
-        self.apply(self._init_weights)
-
-    def _init_weights(self, m):
-        if isinstance(m, nn.Linear):
-            trunc_normal_(m.weight, std=.02)
-            if isinstance(m, nn.Linear) and m.bias is not None:
-                nn.init.constant_(m.bias, 0)
-        elif isinstance(m, nn.LayerNorm):
-            nn.init.constant_(m.bias, 0)
-            nn.init.constant_(m.weight, 1.0)
-        elif isinstance(m, nn.Conv2d):
-            fan_out = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
-            fan_out //= m.groups
-            m.weight.data.normal_(0, math.sqrt(2.0 / fan_out))
-            if m.bias is not None:
-                m.bias.data.zero_()
-
-    def forward(self, x, H, W):
-        B, N, C = x.shape
-        q = self.q(x).reshape(B, N, self.num_heads, C // self.num_heads).permute(0, 2, 1, 3)
-
-        if not self.linear:
-            if self.sr_ratio > 1:
-                x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
-                x_ = self.sr(x_).reshape(B, C, -1).permute(0, 2, 1)
-                x_ = self.norm(x_)
-                kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
-            else:
-                kv = self.kv(x).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
-        else:
-            x_ = x.permute(0, 2, 1).reshape(B, C, H, W)
-            x_ = self.sr(self.pool(x_)).reshape(B, C, -1).permute(0, 2, 1)
-            x_ = self.norm(x_)
-            x_ = self.act(x_)
-            kv = self.kv(x_).reshape(B, -1, 2, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
-        k, v = kv[0], kv[1]
-
-        attn = (q @ k.transpose(-2, -1)) * self.scale
-
-        attn = self.proj_l(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
-        attn = attn.softmax(dim=-1)
-        attn = self.attn_drop(attn)
-        attn = self.proj_w(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
-
-        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
-        x = self.proj(x)
-        x = self.proj_drop(x)
-
-        return x
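Unlike MHSA, the removed PVTSA block (a PVT-v2-style attention with spatial reduction) also takes the token grid size in its forward pass, so the sequence can be reshaped back to a feature map for the sr_ratio / pooling branch. A hypothetical sketch of that interface, with illustrative shapes:

    import torch

    attn = PVTSA(dim=64, num_heads=1, sr_ratio=8)   # as defined in the removed file
    H, W = 56, 56
    x = torch.rand(2, H * W, 64)                    # (batch, H*W tokens, dim)
    y = attn(x, H, W)                               # keys/values reduced to a 7x7 grid; y is (2, 3136, 64)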
joonmyung-1.5.13/models/SA/TMSA.py
@@ -1,37 +0,0 @@
-import torch.nn as nn
-class MHSA(nn.Module):
-    def __init__(self, dim, num_heads=8, qkv_bias=False, attn_drop=0., proj_drop=0.,
-                 talking_head=False):
-        super().__init__()
-        self.num_heads = num_heads
-        head_dim = dim // num_heads
-        self.scale = head_dim ** -0.5
-
-        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
-        self.attn_drop = nn.Dropout(attn_drop)
-        self.proj = nn.Linear(dim, dim)
-        self.proj_drop = nn.Dropout(proj_drop)
-
-
-        self.proj_l = nn.Linear(num_heads, num_heads) if talking_head else nn.Identity()
-        self.proj_w = nn.Linear(num_heads, num_heads) if talking_head else nn.Identity()
-
-
-    def forward(self, x):
-        B, N, C = x.shape
-        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
-        q, k, v = qkv[0], qkv[1], qkv[2]  # make torchscript happy (cannot use tensor as tuple)
-
-        attn = (q @ k.transpose(-2, -1)) * self.scale # (B, H, T, D)
-        # attn = self.proj_l(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
-        attn = self.proj_l(attn.transpose(1, -1)).transpose(1, -1)
-
-        attn = attn.softmax(dim=-1)
-        attn = self.attn_drop(attn)
-        # attn = self.proj_w(attn.permute(0, 2, 3, 1)).permute(0, 3, 1, 2)
-        attn = self.proj_w(attn.transpose(1, -1)).transpose(1, -1)
-
-        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
-        x = self.proj(x)
-        x = self.proj_drop(x)
-        return x, k.mean(dim=1)
joonmyung-1.5.13/models/SA/__init__.py: File without changes
joonmyung-1.5.13/models/__init__.py: File without changes
joonmyung-1.5.13/models/deit.py
@@ -1,372 +0,0 @@
-""" Vision Transformer (ViT) in PyTorch
-
-A PyTorch implement of Vision Transformers as described in:
-
-'An Image Is Worth 16 x 16 Words: Transformers for Image Recognition at Scale'
-    - https://arxiv.org/abs/2010.11929
-
-`How to train your ViT? Data, Augmentation, and Regularization in Vision Transformers`
-    - https://arxiv.org/abs/2106.10270
-
-The official jax code is released and available at https://github.com/google-research/vision_transformer
-
-DeiT model defs and weights from https://github.com/facebookresearch/deit,
-paper `DeiT: Data-efficient Image Transformers` - https://arxiv.org/abs/2012.12877
-
-Acknowledgments:
-* The paper authors for releasing code and weights, thanks!
-* I fixed my class token impl based on Phil Wang's https://github.com/lucidrains/vit-pytorch ... check it out
-for some einops/einsum fun
-* Simple transformer style inspired by Andrej Karpathy's https://github.com/karpathy/minGPT
-* Bert reference code checks against Huggingface Transformers and Tensorflow Bert
-
-Hacked together by / Copyright 2021 Ross Wightman
-"""
-import math
-import logging
-from functools import partial
-
-import torch
-import torch.nn as nn
-import torch.nn.functional as F
-
-from timm.models.helpers import named_apply, adapt_input_conv
-from timm.models.layers import PatchEmbed, trunc_normal_, lecun_normal_
-from timm.models.registry import register_model
-
-from models.modules.PE import PositionalEncodingFourier, get_2d_sincos_pos_embed
-from models.modules.blocks import Block_DEIT
-
-_logger = logging.getLogger(__name__)
-
-class VisionTransformer(nn.Module):
-    """ Vision Transformer
-
-    A PyTorch impl of : `An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale`
-        - https://arxiv.org/abs/2010.11929
-
-    Includes distillation token & head support for `DeiT: Data-efficient Image Transformers`
-        - https://arxiv.org/abs/2012.12877
-    """
-
-    def __init__(self, input_size=224, patch_size=16, in_chans=3, num_classes=1000, embed_dim=768, depth=12,
-                 num_heads=12, mlp_ratio=4., qkv_bias=True,
-                 drop_rate=0., attn_drop_rate=0., drop_path_rate=0., embed_layer=PatchEmbed, norm_layer=None,
-                 act_layer=None, weight_init='', **kwargs):
-        super().__init__()
-        self.num_classes = num_classes
-        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
-        self.cls_token_num = kwargs["token_nums"][0]
-        self.dist_token_num = kwargs["token_nums"][1]
-
-        self.patch_embed_type = kwargs["embed_type"][0]
-        self.pos_embed_type = kwargs["embed_type"][1]
-        self.pos_embed_applied_type = kwargs["embed_type"][2]
-
-
-        self.layer_scale = kwargs["model_type"][0]
-        self.talking_head = kwargs["model_type"][1]
-
-        self.num_tokens = self.cls_token_num + self.dist_token_num
-
-        act_layer = act_layer or nn.GELU
-
-        img_size = input_size
-        if self.patch_embed_type == 0:
-            self.patch_embed = embed_layer(
-                img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
-        else:
-            raise ValueError
-        num_patches = self.patch_embed.num_patches
-
-        if self.cls_token_num: self.cls_token = nn.Parameter(torch.zeros(1, self.cls_token_num, embed_dim))
-        if self.dist_token_num: self.dist_token = nn.Parameter(torch.zeros(1, self.dist_token_num, embed_dim))
-
-        pos_embed_token_num = num_patches + self.num_tokens if self.pos_embed_applied_type == 0 else num_patches
-        if self.pos_embed_type == 0:
-            self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_token_num, embed_dim))
-        elif self.pos_embed_type == 1:
-            self.pos_embed = nn.Parameter(torch.zeros(1, pos_embed_token_num, embed_dim), requires_grad=False)
-            pos_embed = get_2d_sincos_pos_embed(embed_dim, int(pos_embed_token_num **.5))
-            self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
-        elif self.pos_embed_type == 2:
-            self.pos_embed = PositionalEncodingFourier(dim=embed_dim)
-        else:
-            raise ValueError
-        self.pos_drop = nn.Dropout(p=drop_rate)
-        dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
-        self.blocks = nn.Sequential(*[
-            Block_DEIT(
-                dim=embed_dim, num_heads=num_heads, mlp_ratio=mlp_ratio, qkv_bias=qkv_bias, drop=drop_rate,
-                attn_drop=attn_drop_rate, drop_path=dpr[i], norm_layer=norm_layer, act_layer=act_layer,
-                # OURS
-                layer_scale=self.layer_scale, talking_head=self.talking_head)
-            for i in range(depth)])
-        self.norm = norm_layer(embed_dim)
-        self.pre_logits = nn.Identity()
-
-        # Classifier head(s)
-        self.head = nn.Linear(self.num_features, num_classes) if num_classes > 0 else nn.Identity()
-        if self.dist_token_num:
-            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
-        else:
-            self.head_dist = None
-
-        self.init_weights(weight_init)
-
-    def init_weights(self, mode=''):
-        assert mode in ('jax', 'jax_nlhb', 'nlhb', '')
-        head_bias = -math.log(self.num_classes) if 'nlhb' in mode else 0.
-
-        if self.pos_embed_type not in [1,2]: trunc_normal_(self.pos_embed, std=.02)
-        if self.dist_token_num:
-            trunc_normal_(self.dist_token, std=.02)
-        if mode.startswith('jax'):
-            # leave cls token as zeros to match jax impl
-            named_apply(partial(_init_vit_weights, head_bias=head_bias, jax_impl=True), self)
-        else:
-            if self.cls_token_num:
-                trunc_normal_(self.cls_token, std=.02)
-            self.apply(_init_vit_weights)
-
-    def _init_weights(self, m):
-        # this fn left here for compat with downstream users
-        _init_vit_weights(m)
-
-    @torch.jit.ignore()
-    def load_pretrained(self, checkpoint_path, prefix=''):
-        _load_weights(self, checkpoint_path, prefix)
-
-    @torch.jit.ignore
-    def no_weight_decay(self):
-        return {'pos_embed', 'cls_token', 'dist_token'}
-
-    def get_classifier(self):
-        if self.dist_token is None:
-            return self.head
-        else:
-            return self.head, self.head_dist
-
-    def reset_classifier(self, num_classes, global_pool=''):
-        self.num_classes = num_classes
-        self.head = nn.Linear(self.embed_dim, num_classes) if num_classes > 0 else nn.Identity()
-        if self.dist_token_num:
-            self.head_dist = nn.Linear(self.embed_dim, self.num_classes) if num_classes > 0 else nn.Identity()
-
-    def forward_features(self, x):
-
-        x = self.patch_embed(x)
-        B, T, D = x.shape
-        cls_token = self.cls_token.expand(B, -1, -1)
-
-        pos_embed = self.pos_embed if self.pos_embed_type not in [2] else self.pos_embed(B, int(T**0.5), int(T**0.5)).reshape(B, -1, T).transpose(1, 2)
-        if self.pos_embed_applied_type == 0:
-            x = torch.cat((cls_token, x), dim=1)
-            x = self.pos_drop(x + pos_embed)
-        else:
-            x = torch.cat((cls_token, x + pos_embed), dim=1)
-            x = self.pos_drop(x)
-
-        x = self.blocks(x)
-        x = self.norm(x)
-        return self.pre_logits(x[:, :self.cls_token_num].mean(dim=1))
-
-    def forward(self, x):
-        x = self.forward_features(x)
-        x = self.head(x)
-        return x
-
-
-def _init_vit_weights(module: nn.Module, name: str = '', head_bias: float = 0., jax_impl: bool = False):
-    """ ViT weight initialization
-    * When called without n, head_bias, jax_impl args it will behave exactly the same
-      as my original init for compatibility with prev hparam / downstream use cases (ie DeiT).
-    * When called w/ valid n (module name) and jax_impl=True, will (hopefully) match JAX impl
-    """
-    if isinstance(module, nn.Linear):
-        if name.startswith('head'):
-            nn.init.zeros_(module.weight)
-            nn.init.constant_(module.bias, head_bias)
-        elif name.startswith('pre_logits'):
-            lecun_normal_(module.weight)
-            nn.init.zeros_(module.bias)
-        else:
-            if jax_impl:
-                nn.init.xavier_uniform_(module.weight)
-                if module.bias is not None:
-                    if 'mlp' in name:
-                        nn.init.normal_(module.bias, std=1e-6)
-                    else:
-                        nn.init.zeros_(module.bias)
-            else:
-                trunc_normal_(module.weight, std=.02)
-                if module.bias is not None:
-                    nn.init.zeros_(module.bias)
-    elif jax_impl and isinstance(module, nn.Conv2d):
-        # NOTE conv was left to pytorch default in my original init
-        lecun_normal_(module.weight)
-        if module.bias is not None:
-            nn.init.zeros_(module.bias)
-    elif isinstance(module, (nn.LayerNorm, nn.GroupNorm, nn.BatchNorm2d)):
-        nn.init.zeros_(module.bias)
-        nn.init.ones_(module.weight)
-
-
-@torch.no_grad()
-def _load_weights(model: VisionTransformer, checkpoint_path: str, prefix: str = ''):
-    """ Load weights from .npz checkpoints for official Google Brain Flax implementation
-    """
-    import numpy as np
-
-    def _n2p(w, t=True):
-        if w.ndim == 4 and w.shape[0] == w.shape[1] == w.shape[2] == 1:
-            w = w.flatten()
-        if t:
-            if w.ndim == 4:
-                w = w.transpose([3, 2, 0, 1])
-            elif w.ndim == 3:
-                w = w.transpose([2, 0, 1])
-            elif w.ndim == 2:
-                w = w.transpose([1, 0])
-        return torch.from_numpy(w)
-
-    w = np.load(checkpoint_path)
-    if not prefix and 'opt/target/embedding/kernel' in w:
-        prefix = 'opt/target/'
-
-    if hasattr(model.patch_embed, 'backbone'):
-        # hybrid
-        backbone = model.patch_embed.backbone
-        stem_only = not hasattr(backbone, 'stem')
-        stem = backbone if stem_only else backbone.stem
-        stem.conv.weight.copy_(adapt_input_conv(stem.conv.weight.shape[1], _n2p(w[f'{prefix}conv_root/kernel'])))
-        stem.norm.weight.copy_(_n2p(w[f'{prefix}gn_root/scale']))
-        stem.norm.bias.copy_(_n2p(w[f'{prefix}gn_root/bias']))
-        if not stem_only:
-            for i, stage in enumerate(backbone.stages):
-                for j, block in enumerate(stage.blocks):
-                    bp = f'{prefix}block{i + 1}/unit{j + 1}/'
-                    for r in range(3):
-                        getattr(block, f'conv{r + 1}').weight.copy_(_n2p(w[f'{bp}conv{r + 1}/kernel']))
-                        getattr(block, f'norm{r + 1}').weight.copy_(_n2p(w[f'{bp}gn{r + 1}/scale']))
-                        getattr(block, f'norm{r + 1}').bias.copy_(_n2p(w[f'{bp}gn{r + 1}/bias']))
-                    if block.downsample is not None:
-                        block.downsample.conv.weight.copy_(_n2p(w[f'{bp}conv_proj/kernel']))
-                        block.downsample.norm.weight.copy_(_n2p(w[f'{bp}gn_proj/scale']))
-                        block.downsample.norm.bias.copy_(_n2p(w[f'{bp}gn_proj/bias']))
-        embed_conv_w = _n2p(w[f'{prefix}embedding/kernel'])
-    else:
-        embed_conv_w = adapt_input_conv(
-            model.patch_embed.proj.weight.shape[1], _n2p(w[f'{prefix}embedding/kernel']))
-    model.patch_embed.proj.weight.copy_(embed_conv_w)
-    model.patch_embed.proj.bias.copy_(_n2p(w[f'{prefix}embedding/bias']))
-    model.cls_token.copy_(_n2p(w[f'{prefix}cls'], t=False))
-    pos_embed_w = _n2p(w[f'{prefix}Transformer/posembed_input/pos_embedding'], t=False)
-    if pos_embed_w.shape != model.pos_embed.shape:
-        pos_embed_w = resize_pos_embed(  # resize pos embedding when different size from pretrained weights
-            pos_embed_w, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
-    model.pos_embed.copy_(pos_embed_w)
-    model.norm.weight.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/scale']))
-    model.norm.bias.copy_(_n2p(w[f'{prefix}Transformer/encoder_norm/bias']))
-    if isinstance(model.head, nn.Linear) and model.head.bias.shape[0] == w[f'{prefix}head/bias'].shape[-1]:
-        model.head.weight.copy_(_n2p(w[f'{prefix}head/kernel']))
-        model.head.bias.copy_(_n2p(w[f'{prefix}head/bias']))
-    if isinstance(getattr(model.pre_logits, 'fc', None), nn.Linear) and f'{prefix}pre_logits/bias' in w:
-        model.pre_logits.fc.weight.copy_(_n2p(w[f'{prefix}pre_logits/kernel']))
-        model.pre_logits.fc.bias.copy_(_n2p(w[f'{prefix}pre_logits/bias']))
-    for i, block in enumerate(model.blocks.children()):
-        block_prefix = f'{prefix}Transformer/encoderblock_{i}/'
-        mha_prefix = block_prefix + 'MultiHeadDotProductAttention_1/'
-        block.norm1.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/scale']))
-        block.norm1.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_0/bias']))
-        block.attn.qkv.weight.copy_(torch.cat([
-            _n2p(w[f'{mha_prefix}{n}/kernel'], t=False).flatten(1).T for n in ('query', 'key', 'value')]))
-        block.attn.qkv.bias.copy_(torch.cat([
-            _n2p(w[f'{mha_prefix}{n}/bias'], t=False).reshape(-1) for n in ('query', 'key', 'value')]))
-        block.attn.proj.weight.copy_(_n2p(w[f'{mha_prefix}out/kernel']).flatten(1))
-        block.attn.proj.bias.copy_(_n2p(w[f'{mha_prefix}out/bias']))
-        for r in range(2):
-            getattr(block.mlp, f'fc{r + 1}').weight.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/kernel']))
-            getattr(block.mlp, f'fc{r + 1}').bias.copy_(_n2p(w[f'{block_prefix}MlpBlock_3/Dense_{r}/bias']))
-        block.norm2.weight.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/scale']))
-        block.norm2.bias.copy_(_n2p(w[f'{block_prefix}LayerNorm_2/bias']))
-
-
-def resize_pos_embed(posemb, posemb_new, num_tokens=1, gs_new=()):
-    # Rescale the grid of position embeddings when loading from state_dict. Adapted from
-    # https://github.com/google-research/vision_transformer/blob/00883dd691c63a6830751563748663526e811cee/vit_jax/checkpoint.py#L224
-    _logger.info('Resized position embedding: %s to %s', posemb.shape, posemb_new.shape)
-    ntok_new = posemb_new.shape[1]
-    if num_tokens:
-        posemb_tok, posemb_grid = posemb[:, :num_tokens], posemb[0, num_tokens:]
-        ntok_new -= num_tokens
-    else:
-        posemb_tok, posemb_grid = posemb[:, :0], posemb[0]
-    gs_old = int(math.sqrt(len(posemb_grid)))
-    if not len(gs_new):  # backwards compatibility
-        gs_new = [int(math.sqrt(ntok_new))] * 2
-    assert len(gs_new) >= 2
-    _logger.info('Position embedding grid-size from %s to %s', [gs_old, gs_old], gs_new)
-    posemb_grid = posemb_grid.reshape(1, gs_old, gs_old, -1).permute(0, 3, 1, 2)
-    posemb_grid = F.interpolate(posemb_grid, size=gs_new, mode='bilinear')
-    posemb_grid = posemb_grid.permute(0, 2, 3, 1).reshape(1, gs_new[0] * gs_new[1], -1)
-    posemb = torch.cat([posemb_tok, posemb_grid], dim=1)
-    return posemb
-
-
-def checkpoint_filter_fn(state_dict, model):
-    """ convert patch embedding weight from manual patchify + linear proj to conv"""
-    out_dict = {}
-    if 'model' in state_dict:
-        # For deit models
-        state_dict = state_dict['model']
-    for k, v in state_dict.items():
-        if 'patch_embed.proj.weight' in k and len(v.shape) < 4:
-            # For old models that I trained prior to conv based patchification
-            O, I, H, W = model.patch_embed.proj.weight.shape
-            v = v.reshape(O, -1, H, W)
-        elif k == 'pos_embed' and v.shape != model.pos_embed.shape:
-            # To resize pos embedding when using model at different size from pretrained weights
-            v = resize_pos_embed(
-                v, model.pos_embed, getattr(model, 'num_tokens', 1), model.patch_embed.grid_size)
-        out_dict[k] = v
-    return out_dict
-
-@register_model
-def deit_tiny(pretrained=False, **kwargs):
-    model_kwargs = dict(patch_size=16, embed_dim=192, depth=12, num_heads=3,
-                        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
-    model = VisionTransformer(**model_kwargs)
-    if pretrained:
-        checkpoint = torch.hub.load_state_dict_from_url(
-            url="https://dl.fbaipublicfiles.com/deit/deit_tiny_patch16_224-a1311bcf.pth",
-            map_location="cpu", check_hash=True
-        )
-        model.load_state_dict(checkpoint["model"])
-    return model
-
-@register_model
-def deit_small(pretrained=False, **kwargs):
-    model_kwargs = dict(patch_size=16, embed_dim=384, depth=12, num_heads=6,
-                        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
-    model = VisionTransformer(**model_kwargs)
-    if pretrained:
-        checkpoint = torch.hub.load_state_dict_from_url(
-            url="https://dl.fbaipublicfiles.com/deit/deit_small_patch16_224-cd65a155.pth",
-            map_location="cpu", check_hash=True
-        )
-        model.load_state_dict(checkpoint["model"])
-    return model
-
-@register_model
-def deit_base(pretrained=False, **kwargs):
-    model_kwargs = dict(patch_size=16, embed_dim=768, depth=12, num_heads=12,
-                        norm_layer=partial(nn.LayerNorm, eps=1e-6), **kwargs)
-    model = VisionTransformer(**model_kwargs)
-    if pretrained:
-        checkpoint = torch.hub.load_state_dict_from_url(
-            url="https://dl.fbaipublicfiles.com/deit/deit_base_patch16_224-b5f2ef4d.pth",
-            map_location="cpu", check_hash=True
-        )
-        model.load_state_dict(checkpoint["model"])
-    return model
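The removed deit.py variants (deit_tiny / deit_small / deit_base) wrap a customized VisionTransformer that reads extra keyword arguments (token_nums, embed_type, model_type) in __init__, so any caller had to pass them explicitly. A hypothetical construction sketch for the code shown above; the tuple values below are illustrative guesses, not taken from this package, and the model also depends on the removed models.modules helpers:

    import torch

    model = deit_tiny(
        pretrained=False,
        input_size=224,
        token_nums=(1, 0),          # (cls tokens, distillation tokens) -- illustrative
        embed_type=(0, 0, 0),       # (patch-embed, pos-embed, pos-embed application) -- illustrative
        model_type=(False, False),  # (layer_scale, talking_head) -- illustrative
    )
    logits = model(torch.rand(1, 3, 224, 224))   # (1, 1000) class logits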