spikezoo-0.1.1-py3-none-any.whl → spikezoo-0.2-py3-none-any.whl

Files changed (192)
  1. spikezoo/__init__.py +13 -0
  2. spikezoo/archs/__pycache__/__init__.cpython-39.pyc +0 -0
  3. spikezoo/archs/base/__pycache__/nets.cpython-39.pyc +0 -0
  4. spikezoo/archs/base/nets.py +34 -0
  5. spikezoo/archs/bsf/README.md +92 -0
  6. spikezoo/archs/bsf/datasets/datasets.py +328 -0
  7. spikezoo/archs/bsf/datasets/ds_utils.py +64 -0
  8. spikezoo/archs/bsf/main.py +398 -0
  9. spikezoo/archs/bsf/metrics/psnr.py +22 -0
  10. spikezoo/archs/bsf/metrics/ssim.py +54 -0
  11. spikezoo/archs/bsf/models/bsf/__pycache__/align.cpython-39.pyc +0 -0
  12. spikezoo/archs/bsf/models/bsf/__pycache__/bsf.cpython-39.pyc +0 -0
  13. spikezoo/archs/bsf/models/bsf/__pycache__/rep.cpython-39.pyc +0 -0
  14. spikezoo/archs/bsf/models/bsf/align.py +154 -0
  15. spikezoo/archs/bsf/models/bsf/bsf.py +105 -0
  16. spikezoo/archs/bsf/models/bsf/dsft_convert.py +96 -0
  17. spikezoo/archs/bsf/models/bsf/rep.py +44 -0
  18. spikezoo/archs/bsf/models/get_model.py +7 -0
  19. spikezoo/archs/bsf/prepare_data/DSFT.py +62 -0
  20. spikezoo/archs/bsf/prepare_data/crop_dataset_train.py +135 -0
  21. spikezoo/archs/bsf/prepare_data/crop_dataset_val.py +139 -0
  22. spikezoo/archs/bsf/prepare_data/crop_train.sh +4 -0
  23. spikezoo/archs/bsf/prepare_data/crop_val.sh +4 -0
  24. spikezoo/archs/bsf/prepare_data/io_utils.py +64 -0
  25. spikezoo/archs/bsf/requirements.txt +9 -0
  26. spikezoo/archs/bsf/test.py +16 -0
  27. spikezoo/archs/bsf/utils.py +154 -0
  28. spikezoo/archs/spikeclip/__pycache__/nets.cpython-39.pyc +0 -0
  29. spikezoo/archs/spikeclip/nets.py +40 -0
  30. spikezoo/archs/spikeformer/CheckPoints/readme +1 -0
  31. spikezoo/archs/spikeformer/DataProcess/DataExtactor.py +60 -0
  32. spikezoo/archs/spikeformer/DataProcess/DataLoader.py +115 -0
  33. spikezoo/archs/spikeformer/DataProcess/LoadSpike.py +39 -0
  34. spikezoo/archs/spikeformer/EvalResults/readme +1 -0
  35. spikezoo/archs/spikeformer/LICENSE +21 -0
  36. spikezoo/archs/spikeformer/Metrics/Metrics.py +50 -0
  37. spikezoo/archs/spikeformer/Metrics/__init__.py +0 -0
  38. spikezoo/archs/spikeformer/Model/Loss.py +89 -0
  39. spikezoo/archs/spikeformer/Model/SpikeFormer.py +230 -0
  40. spikezoo/archs/spikeformer/Model/__init__.py +0 -0
  41. spikezoo/archs/spikeformer/Model/__pycache__/SpikeFormer.cpython-39.pyc +0 -0
  42. spikezoo/archs/spikeformer/Model/__pycache__/__init__.cpython-39.pyc +0 -0
  43. spikezoo/archs/spikeformer/README.md +30 -0
  44. spikezoo/archs/spikeformer/evaluate.py +87 -0
  45. spikezoo/archs/spikeformer/recon_real_data.py +97 -0
  46. spikezoo/archs/spikeformer/requirements.yml +95 -0
  47. spikezoo/archs/spikeformer/train.py +173 -0
  48. spikezoo/archs/spikeformer/utils.py +22 -0
  49. spikezoo/archs/spk2imgnet/.github/workflows/pylint.yml +23 -0
  50. spikezoo/archs/spk2imgnet/.gitignore +150 -0
  51. spikezoo/archs/spk2imgnet/DCNv2.py +135 -0
  52. spikezoo/archs/spk2imgnet/__pycache__/DCNv2.cpython-39.pyc +0 -0
  53. spikezoo/archs/spk2imgnet/__pycache__/align_arch.cpython-39.pyc +0 -0
  54. spikezoo/archs/spk2imgnet/__pycache__/nets.cpython-39.pyc +0 -0
  55. spikezoo/archs/spk2imgnet/align_arch.py +159 -0
  56. spikezoo/archs/spk2imgnet/dataset.py +144 -0
  57. spikezoo/archs/spk2imgnet/nets.py +230 -0
  58. spikezoo/archs/spk2imgnet/readme.md +86 -0
  59. spikezoo/archs/spk2imgnet/test_gen_imgseq.py +118 -0
  60. spikezoo/archs/spk2imgnet/train.py +189 -0
  61. spikezoo/archs/spk2imgnet/utils.py +64 -0
  62. spikezoo/archs/ssir/README.md +87 -0
  63. spikezoo/archs/ssir/configs/SSIR.yml +37 -0
  64. spikezoo/archs/ssir/configs/yml_parser.py +78 -0
  65. spikezoo/archs/ssir/datasets/dataset_sreds.py +170 -0
  66. spikezoo/archs/ssir/datasets/ds_utils.py +66 -0
  67. spikezoo/archs/ssir/losses.py +21 -0
  68. spikezoo/archs/ssir/main.py +326 -0
  69. spikezoo/archs/ssir/metrics/psnr.py +22 -0
  70. spikezoo/archs/ssir/metrics/ssim.py +54 -0
  71. spikezoo/archs/ssir/models/Vgg19.py +42 -0
  72. spikezoo/archs/ssir/models/__pycache__/layers.cpython-39.pyc +0 -0
  73. spikezoo/archs/ssir/models/__pycache__/networks.cpython-39.pyc +0 -0
  74. spikezoo/archs/ssir/models/layers.py +110 -0
  75. spikezoo/archs/ssir/models/networks.py +61 -0
  76. spikezoo/archs/ssir/requirements.txt +8 -0
  77. spikezoo/archs/ssir/shells/eval_SREDS.sh +6 -0
  78. spikezoo/archs/ssir/shells/train_SSIR.sh +12 -0
  79. spikezoo/archs/ssir/test.py +3 -0
  80. spikezoo/archs/ssir/utils.py +154 -0
  81. spikezoo/archs/ssml/__pycache__/cbam.cpython-39.pyc +0 -0
  82. spikezoo/archs/ssml/__pycache__/model.cpython-39.pyc +0 -0
  83. spikezoo/archs/ssml/cbam.py +224 -0
  84. spikezoo/archs/ssml/model.py +290 -0
  85. spikezoo/archs/ssml/res.png +0 -0
  86. spikezoo/archs/ssml/test.py +67 -0
  87. spikezoo/archs/stir/.git-credentials +0 -0
  88. spikezoo/archs/stir/README.md +65 -0
  89. spikezoo/archs/stir/ckpt_outputs/Descriptions.txt +1 -0
  90. spikezoo/archs/stir/configs/STIR.yml +37 -0
  91. spikezoo/archs/stir/configs/utils.py +155 -0
  92. spikezoo/archs/stir/configs/yml_parser.py +78 -0
  93. spikezoo/archs/stir/datasets/dataset_sreds.py +180 -0
  94. spikezoo/archs/stir/datasets/ds_utils.py +66 -0
  95. spikezoo/archs/stir/eval_SREDS.sh +5 -0
  96. spikezoo/archs/stir/main.py +397 -0
  97. spikezoo/archs/stir/metrics/losses.py +219 -0
  98. spikezoo/archs/stir/metrics/psnr.py +22 -0
  99. spikezoo/archs/stir/metrics/ssim.py +54 -0
  100. spikezoo/archs/stir/models/Vgg19.py +42 -0
  101. spikezoo/archs/stir/models/__pycache__/networks_STIR.cpython-39.pyc +0 -0
  102. spikezoo/archs/stir/models/__pycache__/submodules.cpython-39.pyc +0 -0
  103. spikezoo/archs/stir/models/__pycache__/transformer_new.cpython-39.pyc +0 -0
  104. spikezoo/archs/stir/models/networks_STIR.py +361 -0
  105. spikezoo/archs/stir/models/submodules.py +86 -0
  106. spikezoo/archs/stir/models/transformer_new.py +151 -0
  107. spikezoo/archs/stir/package_core/build/lib/package_core/__init__.py +0 -0
  108. spikezoo/archs/stir/package_core/build/lib/package_core/convertions.py +721 -0
  109. spikezoo/archs/stir/package_core/build/lib/package_core/disp_netS.py +133 -0
  110. spikezoo/archs/stir/package_core/build/lib/package_core/flow_utils.py +167 -0
  111. spikezoo/archs/stir/package_core/build/lib/package_core/generic_train_test.py +76 -0
  112. spikezoo/archs/stir/package_core/build/lib/package_core/geometry.py +458 -0
  113. spikezoo/archs/stir/package_core/build/lib/package_core/image_proc.py +183 -0
  114. spikezoo/archs/stir/package_core/build/lib/package_core/linalg.py +40 -0
  115. spikezoo/archs/stir/package_core/build/lib/package_core/losses.py +198 -0
  116. spikezoo/archs/stir/package_core/build/lib/package_core/metrics.py +51 -0
  117. spikezoo/archs/stir/package_core/build/lib/package_core/model_base.py +53 -0
  118. spikezoo/archs/stir/package_core/build/lib/package_core/net_basics.py +100 -0
  119. spikezoo/archs/stir/package_core/build/lib/package_core/resnet.py +333 -0
  120. spikezoo/archs/stir/package_core/build/lib/package_core/transforms.py +123 -0
  121. spikezoo/archs/stir/package_core/build/lib/package_core/utils.py +72 -0
  122. spikezoo/archs/stir/package_core/dist/package_core-0.0.0-py3.9.egg +0 -0
  123. spikezoo/archs/stir/package_core/package_core/__init__.py +0 -0
  124. spikezoo/archs/stir/package_core/package_core/__pycache__/__init__.cpython-39.pyc +0 -0
  125. spikezoo/archs/stir/package_core/package_core/__pycache__/net_basics.cpython-39.pyc +0 -0
  126. spikezoo/archs/stir/package_core/package_core/convertions.py +721 -0
  127. spikezoo/archs/stir/package_core/package_core/disp_netS.py +133 -0
  128. spikezoo/archs/stir/package_core/package_core/flow_utils.py +167 -0
  129. spikezoo/archs/stir/package_core/package_core/generic_train_test.py +76 -0
  130. spikezoo/archs/stir/package_core/package_core/geometry.py +458 -0
  131. spikezoo/archs/stir/package_core/package_core/image_proc.py +183 -0
  132. spikezoo/archs/stir/package_core/package_core/linalg.py +40 -0
  133. spikezoo/archs/stir/package_core/package_core/losses.py +198 -0
  134. spikezoo/archs/stir/package_core/package_core/metrics.py +51 -0
  135. spikezoo/archs/stir/package_core/package_core/model_base.py +53 -0
  136. spikezoo/archs/stir/package_core/package_core/net_basics.py +100 -0
  137. spikezoo/archs/stir/package_core/package_core/resnet.py +333 -0
  138. spikezoo/archs/stir/package_core/package_core/transforms.py +123 -0
  139. spikezoo/archs/stir/package_core/package_core/utils.py +72 -0
  140. spikezoo/archs/stir/package_core/package_core.egg-info/PKG-INFO +3 -0
  141. spikezoo/archs/stir/package_core/package_core.egg-info/SOURCES.txt +20 -0
  142. spikezoo/archs/stir/package_core/package_core.egg-info/dependency_links.txt +1 -0
  143. spikezoo/archs/stir/package_core/package_core.egg-info/top_level.txt +1 -0
  144. spikezoo/archs/stir/package_core/setup.py +5 -0
  145. spikezoo/archs/stir/requirements.txt +12 -0
  146. spikezoo/archs/stir/train_STIR.sh +9 -0
  147. spikezoo/archs/tfi/__pycache__/nets.cpython-39.pyc +0 -0
  148. spikezoo/archs/tfi/nets.py +43 -0
  149. spikezoo/archs/tfp/__pycache__/nets.cpython-39.pyc +0 -0
  150. spikezoo/archs/tfp/nets.py +13 -0
  151. spikezoo/archs/wgse/README.md +64 -0
  152. spikezoo/archs/wgse/__pycache__/dwtnets.cpython-39.pyc +0 -0
  153. spikezoo/archs/wgse/__pycache__/submodules.cpython-39.pyc +0 -0
  154. spikezoo/archs/wgse/dataset.py +59 -0
  155. spikezoo/archs/wgse/demo.png +0 -0
  156. spikezoo/archs/wgse/demo.py +83 -0
  157. spikezoo/archs/wgse/dwtnets.py +145 -0
  158. spikezoo/archs/wgse/eval.py +133 -0
  159. spikezoo/archs/wgse/logs/WGSE-Dwt1dNet-db8-5-ks3/log.txt +11 -0
  160. spikezoo/archs/wgse/submodules.py +68 -0
  161. spikezoo/archs/wgse/train.py +261 -0
  162. spikezoo/archs/wgse/transform.py +139 -0
  163. spikezoo/archs/wgse/utils.py +128 -0
  164. spikezoo/archs/wgse/weights/demo.png +0 -0
  165. spikezoo/data/base/test/gt/200_part1_key_id151.png +0 -0
  166. spikezoo/data/base/test/gt/200_part3_key_id151.png +0 -0
  167. spikezoo/data/base/test/gt/203_part1_key_id151.png +0 -0
  168. spikezoo/data/base/test/spike/200_part1_key_id151.dat +0 -0
  169. spikezoo/data/base/test/spike/200_part3_key_id151.dat +0 -0
  170. spikezoo/data/base/test/spike/203_part1_key_id151.dat +0 -0
  171. spikezoo/data/base/train/gt/203_part2_key_id151.png +0 -0
  172. spikezoo/data/base/train/gt/203_part3_key_id151.png +0 -0
  173. spikezoo/data/base/train/gt/203_part4_key_id151.png +0 -0
  174. spikezoo/data/base/train/spike/203_part2_key_id151.dat +0 -0
  175. spikezoo/data/base/train/spike/203_part3_key_id151.dat +0 -0
  176. spikezoo/data/base/train/spike/203_part4_key_id151.dat +0 -0
  177. spikezoo/datasets/base_dataset.py +2 -3
  178. spikezoo/metrics/__init__.py +1 -1
  179. spikezoo/models/base_model.py +1 -3
  180. spikezoo/pipeline/base_pipeline.py +7 -5
  181. spikezoo/pipeline/train_pipeline.py +1 -1
  182. spikezoo/utils/other_utils.py +16 -6
  183. spikezoo/utils/spike_utils.py +33 -29
  184. spikezoo/utils/vidar_loader.cpython-39-x86_64-linux-gnu.so +0 -0
  185. spikezoo-0.2.dist-info/METADATA +163 -0
  186. spikezoo-0.2.dist-info/RECORD +211 -0
  187. spikezoo/models/spcsnet_model.py +0 -19
  188. spikezoo-0.1.1.dist-info/METADATA +0 -39
  189. spikezoo-0.1.1.dist-info/RECORD +0 -36
  190. {spikezoo-0.1.1.dist-info → spikezoo-0.2.dist-info}/LICENSE.txt +0 -0
  191. {spikezoo-0.1.1.dist-info → spikezoo-0.2.dist-info}/WHEEL +0 -0
  192. {spikezoo-0.1.1.dist-info → spikezoo-0.2.dist-info}/top_level.txt +0 -0
--- /dev/null
+++ b/spikezoo/archs/stir/models/Vgg19.py
@@ -0,0 +1,42 @@
+import torch
+import torch.nn as nn
+from torchvision import models
+
+
+class Vgg19(torch.nn.Module):
+    def __init__(self, requires_grad=False, rgb_range=1):
+        super(Vgg19, self).__init__()
+
+        vgg_pretrained_features = models.vgg19(pretrained=True).features
+
+        self.slice1 = torch.nn.Sequential()
+        for x in range(30):
+            self.slice1.add_module(str(x), vgg_pretrained_features[x])
+
+        if not requires_grad:
+            for param in self.slice1.parameters():
+                param.requires_grad = False
+
+        vgg_mean = (0.485, 0.456, 0.406)
+        vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)
+        self.sub_mean = MeanShift(rgb_range, vgg_mean, vgg_std)
+
+    def forward(self, X):
+        h = self.sub_mean(X)
+        h_relu5_1 = self.slice1(h)
+        return h_relu5_1
+
+class MeanShift(nn.Conv2d):
+    def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
+        super(MeanShift, self).__init__(3, 3, kernel_size=1)
+        std = torch.Tensor(rgb_std)
+        self.weight.data = torch.eye(3).view(3, 3, 1, 1)
+        self.weight.data.div_(std.view(3, 1, 1, 1))
+        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean)
+        self.bias.data.div_(std)
+        # self.requires_grad = False
+        self.weight.requires_grad = False
+        self.bias.requires_grad = False
+
+if __name__ == '__main__':
+    vgg19 = Vgg19(requires_grad=False)
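The Vgg19 module above is the usual feature extractor for a perceptual loss: MeanShift folds ImageNet mean/std normalization into a fixed 1x1 conv, and slice1 covers torchvision's vgg19 feature layers 0-29, i.e. up to relu5_1. A minimal sanity check, not part of the diff, assuming the file is importable at its packaged path and the pretrained weights can be downloaded:

import torch
from spikezoo.archs.stir.models.Vgg19 import Vgg19  # assumed import path

vgg = Vgg19(requires_grad=False, rgb_range=1).eval()
x = torch.rand(1, 3, 64, 64)   # RGB image scaled to [0, rgb_range]
with torch.no_grad():
    feat = vgg(x)              # relu5_1 features: 4 maxpools -> 1/16 resolution
print(feat.shape)              # torch.Size([1, 512, 4, 4])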
--- /dev/null
+++ b/spikezoo/archs/stir/models/networks_STIR.py
@@ -0,0 +1,361 @@
+import os
+from re import S
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+import time
+
+from torch.autograd import Variable
+
+from ..package_core.package_core.net_basics import *
+from ..models.transformer_new import *
+
+class BasicModel(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    ####################################################################################
+    ## Tools functions for neural networks
+    def weight_parameters(self):
+        return [param for name, param in self.named_parameters() if 'weight' in name]
+
+    def bias_parameters(self):
+        return [param for name, param in self.named_parameters() if 'bias' in name]
+
+    def num_parameters(self):
+        return sum([p.data.nelement() if p.requires_grad else 0 for p in self.parameters()])
+
+    def init_weights(self):
+        for layer in self.named_modules():
+            if isinstance(layer, nn.Conv2d):
+                nn.init.kaiming_normal_(layer.weight)
+                if layer.bias is not None:
+                    nn.init.constant_(layer.bias, 0)
+
+            elif isinstance(layer, nn.ConvTranspose2d):
+                nn.init.kaiming_normal_(layer.weight)
+                if layer.bias is not None:
+                    nn.init.constant_(layer.bias, 0)
+
+
+def TFP(spk, channel_step=1):
+    num = spk.size(1) // 2
+    rep_spk = torch.mean(spk, dim=1).unsqueeze(1)
+
+    for i in range(1, num):
+        if i*channel_step < num:
+            rep_spk = torch.cat((rep_spk, torch.mean(spk[:, i*channel_step : -i*channel_step, :, :], 1).unsqueeze(1)), 1)
+
+    return rep_spk
+
+class ResidualBlock(nn.Module):
+    def __init__(self, in_channles, num_channles, use_1x1conv=False, strides=1):
+        super(ResidualBlock, self).__init__()
+        self.conv1 = nn.Conv2d(
+            in_channles, num_channles, kernel_size=3, stride=strides, padding=1)
+        self.conv2 = nn.Conv2d(
+            num_channles, num_channles, kernel_size=3, padding=1)
+        if use_1x1conv:
+            self.conv3 = nn.Conv2d(
+                in_channles, num_channles, kernel_size=1, stride=strides)
+        else:
+            self.conv3 = None
+        self.bn1 = nn.BatchNorm2d(num_channles)
+        self.bn2 = nn.BatchNorm2d(num_channles)
+        self.relu = nn.ReLU(inplace=True)
+    def forward(self, x):
+        y = F.relu(self.bn1(self.conv1(x)))
+        y = self.bn2(self.conv2(y))
+        if self.conv3:
+            x = self.conv3(x)
+        y += x
+        return F.relu(y)
+
+class DimReduceConv(nn.Module):
+    def __init__(self, in_channels, out_channels, bias=True):
+        super(DimReduceConv, self).__init__()
+        self.conv1 = nn.Sequential(
+            nn.Conv2d(in_channels, in_channels, kernel_size=3, stride=1, padding=1, bias=bias),
+            nn.PReLU(in_channels)
+        )
+        self.conv2 = nn.Sequential(
+            nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=1, padding=1, bias=bias),
+            nn.PReLU(out_channels)
+        )
+
+    def forward(self, x):
+        out = self.conv1(x)
+        out = self.conv2(out)
+        return out
+
+def convrelu(in_channels, out_channels, kernel_size=3, stride=1, padding=1, dilation=1, groups=1, bias=True):
+    return nn.Sequential(
+        nn.Conv2d(in_channels, out_channels, kernel_size, stride, padding, dilation, groups, bias=bias),
+        nn.PReLU(out_channels)
+    )
+
+class ImageEncoder(nn.Module):
+    def __init__(self, in_chs, init_chs, num_resblock=1):
+        super(ImageEncoder, self).__init__()
+        self.conv0 = conv2d(
+            in_planes=in_chs,
+            out_planes=init_chs[0],
+            batch_norm=False,
+            activation=nn.PReLU(),
+            kernel_size=7,
+            stride=1
+        )
+
+        self.conv1 = conv2d(
+            in_planes=init_chs[0],
+            out_planes=init_chs[1],
+            batch_norm=False,
+            activation=nn.PReLU(),
+            kernel_size=3,
+            stride=2
+        )
+        self.resblocks1 = Cascade_resnet_blocks(in_planes=init_chs[1], n_blocks=num_resblock)
+        self.conv2 = conv2d(
+            in_planes=init_chs[1],
+            out_planes=init_chs[2],
+            batch_norm=False,
+            activation=nn.PReLU(),
+            kernel_size=3,
+            stride=2
+        )
+        self.resblocks2 = Cascade_resnet_blocks(in_planes=init_chs[2], n_blocks=num_resblock)
+        self.conv3 = conv2d(
+            in_planes=init_chs[2],
+            out_planes=init_chs[3],
+            batch_norm=False,
+            activation=nn.PReLU(),
+            kernel_size=3,
+            stride=2
+        )
+        self.resblocks3 = Cascade_resnet_blocks(in_planes=init_chs[3], n_blocks=num_resblock)
+        self.conv4 = conv2d(
+            in_planes=init_chs[3],
+            out_planes=init_chs[4],
+            batch_norm=False,
+            activation=nn.PReLU(),
+            kernel_size=3,
+            stride=2
+        )
+        self.resblocks4 = Cascade_resnet_blocks(in_planes=init_chs[4], n_blocks=num_resblock)
+
+    def forward(self, x):
+        x0 = self.conv0(x)
+        x1 = self.resblocks1(self.conv1(x0))
+        x2 = self.resblocks2(self.conv2(x1))
+        x3 = self.resblocks3(self.conv3(x2))
+        x4 = self.resblocks4(self.conv4(x3))
+
+        return x4, x3, x2, x1
+
+def predict_img(in_channels):
+    return nn.Conv2d(in_channels, 1, kernel_size=3, stride=1, padding=1, bias=True)
+
+def predict_img_flow(in_channels):
+    return nn.Conv2d(in_channels, 5, kernel_size=3, stride=1, padding=1, bias=True)#first 4: flow; last 1: img
+
+def deconv(in_channels, out_channels, kernel_size=4, stride=2, padding=1):
+    return nn.ConvTranspose2d(in_channels, out_channels, kernel_size, stride, padding, bias=True)
+
+class STIRDecorder_top_level(nn.Module):#top level
+    def __init__(self, in_chs, hidd_chs):
+        super(STIRDecorder_top_level, self).__init__()
+        self.hidd_chs = hidd_chs
+
+        self.convrelu = convrelu(in_chs*3, in_chs*3)
+
+        self.resblocks = Cascade_resnet_blocks(in_planes=in_chs*3, n_blocks=1)#3
+
+        self.predict_img_flow = predict_img_flow(in_chs*3)
+
+    def forward(self, c_cat):
+        x0 = c_cat
+        x0 = self.convrelu(x0)
+        x_hidd = self.resblocks(x0)
+
+        img_flow_curr = self.predict_img_flow(x_hidd)
+        flow_0, flow_1 = img_flow_curr[:,:2], img_flow_curr[:,2:4]
+        img_pred = img_flow_curr[:,4:5]
+
+        return img_pred, x_hidd, flow_0, flow_1
+
+class STIRDecorder_bottom_level(nn.Module):#second and third levels
+    def __init__(self, in_chs_last, in_chs, hidd_chs, N_group):
+        super(STIRDecorder_bottom_level, self).__init__()
+        self.hidd_chs = hidd_chs
+        self.N_group = N_group
+
+        if self.N_group > 1:
+            self.predict_flow_group = nn.Conv2d(in_chs_last*3, 4*(self.N_group-1), kernel_size=3, stride=1, padding=1, bias=True)
+            self.deconv_flow_group = deconv(4*(self.N_group-1), 4*(self.N_group-1), kernel_size=4, stride=2, padding=1)
+
+        self.deconv_flow = deconv(4, 4, kernel_size=4, stride=2, padding=1)
+        self.deconv_hidden = deconv(3*in_chs_last, self.hidd_chs, kernel_size=4, stride=2, padding=1)
+
+        self.convrelu = DimReduceConv(in_chs*2*self.N_group + in_chs + 4*self.N_group + 1 + self.hidd_chs, in_chs*3)
+        self.resblocks = Cascade_resnet_blocks(in_planes=in_chs*3, n_blocks=1)#3
+
+        self.predict_img = predict_img(in_chs*3)
+
+    def warp(self, img, flow):
+        B, _, H, W = flow.shape
+        xx = torch.linspace(-1.0, 1.0, W).view(1, 1, 1, W).expand(B, -1, H, -1)
+        yy = torch.linspace(-1.0, 1.0, H).view(1, 1, H, 1).expand(B, -1, -1, W)
+        grid = torch.cat([xx, yy], 1).to(img)
+        flow_ = torch.cat([flow[:, 0:1, :, :] / ((W - 1.0) / 2.0), flow[:, 1:2, :, :] / ((H - 1.0) / 2.0)], 1)
+        grid_ = (grid + flow_).permute(0, 2, 3, 1)
+        output = F.grid_sample(input=img, grid=grid_, mode='bilinear', padding_mode='border', align_corners=True)
+        return output
+
+    def forward(self, img_last, hidden_last, flow_0_last, flow_1_last, upflow_last, c_0, c_1, c_2):
+
+        warped_group = []
+        if self.N_group > 1:
+            flow_group_last = self.predict_flow_group(hidden_last) + torch.cat([upflow_last for _ in range(self.N_group-1)], dim=1) #flow residual
+            upflow_group_last = self.deconv_flow_group(flow_group_last)
+            for i in range(self.N_group-1):
+                warped_group_0 = self.warp(c_0, upflow_group_last[:, 4*i : 4*i+2])
+                warped_group_2 = self.warp(c_2, upflow_group_last[:, 4*i+2 : 4*i+4])
+                warped_group.append(warped_group_0)
+                warped_group.append(warped_group_2)
+
+        upflow = self.deconv_flow(torch.cat([flow_0_last, flow_1_last], dim=1))
+        uphidden = self.deconv_hidden(hidden_last)
+        upimg = F.interpolate(img_last, scale_factor=2.0, mode='bilinear')
+
+        upflow_0, upflow_1 = upflow[:,0:2], upflow[:,2:4]
+
+        warp_0 = self.warp(c_0, upflow_0)
+        warp_2 = self.warp(c_2, upflow_1)
+
+        x0 = torch.cat([c_1, warp_0, warp_2] + warped_group, dim=1)
+        if self.N_group > 1:
+            x0 = torch.cat([upimg, x0, uphidden, upflow_0, upflow_1, upflow_group_last], dim=1)
+        else:
+            x0 = torch.cat([upimg, x0, uphidden, upflow_0, upflow_1], dim=1)
+        x0 = self.convrelu(x0)
+        x_hidd = self.resblocks(x0)
+
+        img_pred = self.predict_img(x_hidd)
+
+        return img_pred, x_hidd, upflow_0, upflow_1, upflow_0, upflow_1
+
+class STIRDecorder(nn.Module):#second and third levels
+    def __init__(self, in_chs_last, in_chs, hidd_chs):
+        super(STIRDecorder, self).__init__()
+        self.hidd_chs = hidd_chs
+
+        self.deconv_flow = deconv(4, 4, kernel_size=4, stride=2, padding=1)
+        self.deconv_hidden = deconv(3*in_chs_last, self.hidd_chs, kernel_size=4, stride=2, padding=1)
+
+        self.convrelu = DimReduceConv(in_chs*3 + 4 + 1 + self.hidd_chs, in_chs*3)
+        self.resblocks = Cascade_resnet_blocks(in_planes=in_chs*3, n_blocks=1)#3
+
+        self.predict_img_flow = predict_img_flow(in_chs*3)
+
+    def warp(self, img, flow):
+        B, _, H, W = flow.shape
+        xx = torch.linspace(-1.0, 1.0, W).view(1, 1, 1, W).expand(B, -1, H, -1)
+        yy = torch.linspace(-1.0, 1.0, H).view(1, 1, H, 1).expand(B, -1, -1, W)
+        grid = torch.cat([xx, yy], 1).to(img)
+        flow_ = torch.cat([flow[:, 0:1, :, :] / ((W - 1.0) / 2.0), flow[:, 1:2, :, :] / ((H - 1.0) / 2.0)], 1)
+        grid_ = (grid + flow_).permute(0, 2, 3, 1)
+        output = F.grid_sample(input=img, grid=grid_, mode='bilinear', padding_mode='border', align_corners=True)
+        return output
+
+    def forward(self, img_last, hidden_last, flow_0_last, flow_1_last, c_0, c_1, c_2):
+        upflow = self.deconv_flow(torch.cat([flow_0_last, flow_1_last], dim=1))
+        uphidden = self.deconv_hidden(hidden_last)
+        upimg = F.interpolate(img_last, scale_factor=2.0, mode='bilinear')
+
+        upflow_0, upflow_1 = upflow[:,0:2], upflow[:,2:4]
+
+        warp_0 = self.warp(c_0, upflow_0)
+        warp_2 = self.warp(c_2, upflow_1)
+
+        x0 = torch.cat([c_1, warp_0, warp_2], dim=1)
+        x0 = torch.cat([upimg, x0, uphidden, upflow_0, upflow_1], dim=1)
+        x0 = self.convrelu(x0)
+        x_hidd = self.resblocks(x0)
+
+        img_flow_curr = self.predict_img_flow(x_hidd)
+        flow_0, flow_1 = img_flow_curr[:,:2]+upflow_0, img_flow_curr[:,2:4]+upflow_1
+        img_pred = img_flow_curr[:,4:5]
+
+        return img_pred, x_hidd, flow_0, flow_1, upflow_0, upflow_1
+
+
+##############################Our Model####################################
+class STIR(BasicModel):
+    def __init__(self, hidd_chs=8, win_r=6, win_step=7):
+        super().__init__()
+
+        self.init_chs = [16, 24, 32, 64, 96]
+        self.hidd_chs = hidd_chs
+        self.attn_num_splits = 1
+
+        self.N_group = 3
+
+        dim_tfp = 16
+        self.encoder = ImageEncoder(in_chs=dim_tfp, init_chs=self.init_chs)
+
+        self.transformer = CrossTransformerBlock(dim=self.init_chs[-1], num_heads=4, ffn_expansion_factor=2.66, bias=False, LayerNorm_type='WithBias')
+
+        self.decorder_5nd = STIRDecorder_top_level(self.init_chs[-1], self.hidd_chs)
+        self.decorder_4nd = STIRDecorder(self.init_chs[-1], self.init_chs[-2], self.hidd_chs)
+        self.decorder_3rd = STIRDecorder(self.init_chs[-2], self.init_chs[-3], self.hidd_chs)
+        self.decorder_2nd = STIRDecorder(self.init_chs[-3], self.init_chs[-4], self.hidd_chs)
+        self.decorder_1st = STIRDecorder_bottom_level(self.init_chs[-4], dim_tfp, self.hidd_chs, self.N_group)
+        self.win_r = win_r
+        self.win_step = win_step
+
+        self.resnet = ResidualBlock(in_channles=21, num_channles=11, use_1x1conv=True)
+
+    def forward(self, x):
+        b,_,h,w = x.size()
+
+        block1 = x[:, 0 : 21, :, :]
+        block2 = x[:, 20 : 41, :, :]
+        block3 = x[:, 40 : 61, :, :]
+
+        repre1 = TFP(block1, channel_step=2)#C: 5
+        repre2 = TFP(block2, channel_step=2)
+        repre3 = TFP(block3, channel_step=2)
+
+        repre_resnet = self.resnet(torch.cat((block1, block2, block3), dim=0)) #[3B, 11, H, W]
+        repre1_resnet, repre2_resnet, repre3_resnet = repre_resnet[:b], repre_resnet[b:2*b], repre_resnet[2*b:]
+
+        repre1 = torch.cat((repre1, repre1_resnet), 1)#C: 16
+        repre2 = torch.cat((repre2, repre2_resnet), 1)
+        repre3 = torch.cat((repre3, repre3_resnet), 1)
+
+        concat = torch.cat((repre1, repre2, repre3), dim=0)
+        feature_4, feature_3, feature_2, feature_1 = self.encoder(concat)
+        c0_4, c0_3, c0_2, c0_1 = feature_4[:b], feature_3[:b], feature_2[:b], feature_1[:b]
+        c1_4, c1_3, c1_2, c1_1 = feature_4[b:2*b], feature_3[b:2*b], feature_2[b:2*b], feature_1[b:2*b]
+        c2_4, c2_3, c2_2, c2_1 = feature_4[2*b:], feature_3[2*b:], feature_2[2*b:], feature_1[2*b:]
+
+        c_cat = self.transformer(c1_4, c0_4, c2_4)
+        img_pred_4, x_hidd_4, flow_0_4, flow_1_4 = self.decorder_5nd(c_cat)
+        img_pred_3, x_hidd_3, flow_0_3, flow_1_3, upflow_0_3, upflow_1_3 = self.decorder_4nd(img_pred_4, x_hidd_4, flow_0_4, flow_1_4, c0_3, c1_3, c2_3)
+        img_pred_2, x_hidd_2, flow_0_2, flow_1_2, upflow_0_2, upflow_1_2 = self.decorder_3rd(img_pred_3, x_hidd_3, flow_0_3, flow_1_3, c0_2, c1_2, c2_2)
+        img_pred_1, x_hidd_1, flow_0_1, flow_1_1, upflow_0_1, upflow_1_1 = self.decorder_2nd(img_pred_2, x_hidd_2, flow_0_2, flow_1_2, c0_1, c1_1, c2_1)
+        img_pred_0, _, _, _, upflow_0_0, upflow_1_0 = self.decorder_1st(img_pred_1, x_hidd_1, flow_0_1, flow_1_1, torch.cat((upflow_0_1, upflow_1_1), dim=1), repre1, repre2, repre3)
+
+        if self.training:
+            return torch.clamp(img_pred_0, 0, 1),\
+                [torch.clamp(img_pred_0, 0, 1), upflow_0_0, upflow_1_0],\
+                [torch.clamp(img_pred_1, 0, 1), upflow_0_1, upflow_1_1],\
+                [torch.clamp(img_pred_2, 0, 1), upflow_0_2, upflow_1_2],\
+                [torch.clamp(img_pred_3, 0, 1), upflow_0_3, upflow_1_3],\
+                [torch.clamp(img_pred_4, 0, 1)],\
+                [img_pred_0, img_pred_0, img_pred_0]
+        else:
+            return img_pred_0
+
+
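networks_STIR.py implements a coarse-to-fine reconstruction network: the 61-frame spike input is split into three overlapping 21-frame blocks, each turned into a TFP (multi-window temporal average) representation plus learned ResidualBlock features, encoded at four scales, fused at the coarsest scale by a CrossTransformerBlock, and decoded level by level with flow-guided warping. A rough smoke test, not part of the diff; it assumes the relative imports (package_core.net_basics, transformer_new) resolve inside the installed package:

import torch
from spikezoo.archs.stir.models.networks_STIR import STIR  # assumed import path

# forward() slices channels 0:21, 20:41, 40:61, so the input needs 61 spike
# frames; H and W must be divisible by 16 (the encoder downsamples 4 times).
model = STIR().eval()
spikes = torch.rand(1, 61, 64, 64)
with torch.no_grad():
    img = model(spikes)        # eval mode returns only the full-resolution prediction
print(img.shape)               # torch.Size([1, 1, 64, 64])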
--- /dev/null
+++ b/spikezoo/archs/stir/models/submodules.py
@@ -0,0 +1,86 @@
+import torch
+import torch.nn.functional as F
+import torch.nn as nn
+
+def actFunc(act, *args, **kwargs):
+    act = act.lower()
+    if act == 'relu':
+        return nn.ReLU()
+    elif act == 'relu6':
+        return nn.ReLU6()
+    elif act == 'leakyrelu':
+        return nn.LeakyReLU(0.1)
+    elif act == 'prelu':
+        return nn.PReLU()
+    elif act == 'rrelu':
+        return nn.RReLU(0.1, 0.3)
+    elif act == 'selu':
+        return nn.SELU()
+    elif act == 'celu':
+        return nn.CELU()
+    elif act == 'elu':
+        return nn.ELU()
+    elif act == 'gelu':
+        return nn.GELU()
+    elif act == 'tanh':
+        return nn.Tanh()
+    else:
+        raise NotImplementedError
+
+class ResBlock(nn.Module):
+    """
+    Residual block
+    """
+    def __init__(self, in_chs, activation='relu', batch_norm=False):
+        super(ResBlock, self).__init__()
+        op = []
+        for i in range(2):
+            op.append(conv3x3(in_chs, in_chs))
+            if batch_norm:
+                op.append(nn.BatchNorm2d(in_chs))
+            if i == 0:
+                op.append(actFunc(activation))
+        self.main_branch = nn.Sequential(*op)
+
+    def forward(self, x):
+        out = self.main_branch(x)
+        out += x
+        return out
+
+# conv blocks
+def conv1x1(in_channels, out_channels, stride=1):
+    return nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0, bias=True)
+
+def conv3x3(in_channels, out_channels, stride=1):
+    return nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=True)
+
+def conv5x5(in_channels, out_channels, stride=1):
+    return nn.Conv2d(in_channels, out_channels, kernel_size=5, stride=stride, padding=2, bias=True)
+
+def deconv4x4(in_channels, out_channels, stride=2):
+    return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=4, stride=stride, padding=1)
+
+def deconv5x5(in_channels, out_channels, stride=2):
+    return nn.ConvTranspose2d(in_channels, out_channels, kernel_size=5, stride=stride, padding=2, output_padding=1)
+
+# conv resblock
+def conv_resblock_three(in_channels, out_channels, stride=1):
+    return nn.Sequential(conv3x3(in_channels, out_channels, stride), nn.ReLU(), ResBlock(out_channels), ResBlock(out_channels), ResBlock(out_channels))
+
+def conv_resblock_two(in_channels, out_channels, stride=1):
+    return nn.Sequential(conv3x3(in_channels, out_channels, stride), nn.ReLU(), ResBlock(out_channels), ResBlock(out_channels))
+
+def conv_resblock_one(in_channels, out_channels, stride=1):
+    return nn.Sequential(conv3x3(in_channels, out_channels, stride), nn.ReLU(), ResBlock(out_channels))
+
+def conv_1x1_resblock_one(in_channels, out_channels, stride=1):
+    return nn.Sequential(conv1x1(in_channels, out_channels, stride), nn.ReLU(), ResBlock(out_channels))
+
+def conv_resblock_two_DS(in_channels, out_channels, stride=2):
+    return nn.Sequential(conv3x3(in_channels, out_channels, stride), nn.ReLU(), ResBlock(out_channels), ResBlock(out_channels))
+
+def conv3x3_leaky_relu(in_channels, out_channels, stride=1):
+    return nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=3, stride=stride, padding=1, bias=True), nn.LeakyReLU(0.1))
+
+def conv1x1_leaky_relu(in_channels, out_channels, stride=1):
+    return nn.Sequential(nn.Conv2d(in_channels, out_channels, kernel_size=1, stride=stride, padding=0, bias=True), nn.LeakyReLU(0.1))
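submodules.py collects small factory helpers: actFunc maps a string to an activation, ResBlock is a plain two-conv residual block, and the conv_resblock_* helpers chain a (possibly strided) 3x3 conv, a ReLU, and one to three ResBlocks. An illustrative shape check, not part of the diff:

import torch
from spikezoo.archs.stir.models.submodules import conv_resblock_one  # assumed import path

# The 3x3 conv with stride=2 halves the spatial size and sets the channel
# count; the trailing ResBlock leaves both unchanged.
block = conv_resblock_one(16, 32, stride=2)
y = block(torch.rand(1, 16, 64, 64))
print(y.shape)                 # torch.Size([1, 32, 32, 32])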
--- /dev/null
+++ b/spikezoo/archs/stir/models/transformer_new.py
@@ -0,0 +1,151 @@
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from einops import rearrange, repeat
+from einops.layers.torch import Rearrange
+from torch.autograd import Variable
+import numbers
+from torch.autograd import Variable
+from .submodules import *
+
+################# Restormer #####################
+
+##########################################################################
+## Layer Norm
+def to_3d(x):
+    return rearrange(x, 'b c h w -> b (h w) c')
+
+def to_4d(x, h, w):
+    return rearrange(x, 'b (h w) c -> b c h w', h=h, w=w)
+
+class BiasFree_LayerNorm(nn.Module):
+    def __init__(self, normalized_shape):
+        super(BiasFree_LayerNorm, self).__init__()
+        if isinstance(normalized_shape, numbers.Integral):
+            normalized_shape = (normalized_shape,)
+        normalized_shape = torch.Size(normalized_shape)
+
+        assert len(normalized_shape) == 1
+
+        self.weight = nn.Parameter(torch.ones(normalized_shape))
+        self.normalized_shape = normalized_shape
+
+    def forward(self, x):
+        sigma = x.var(-1, keepdim=True, unbiased=False)
+        return x / torch.sqrt(sigma+1e-5) * self.weight
+
+
+class WithBias_LayerNorm(nn.Module):
+    def __init__(self, normalized_shape):
+        super(WithBias_LayerNorm, self).__init__()
+        if isinstance(normalized_shape, numbers.Integral):
+            normalized_shape = (normalized_shape,)
+        normalized_shape = torch.Size(normalized_shape)
+
+        assert len(normalized_shape) == 1
+
+        self.weight = nn.Parameter(torch.ones(normalized_shape))
+        self.bias = nn.Parameter(torch.zeros(normalized_shape))
+        self.normalized_shape = normalized_shape
+
+    def forward(self, x):
+        mu = x.mean(-1, keepdim=True)
+        sigma = x.var(-1, keepdim=True, unbiased=False)
+        return (x - mu) / torch.sqrt(sigma+1e-5) * self.weight + self.bias
+
+
+class LayerNorm(nn.Module):
+    def __init__(self, dim, LayerNorm_type):
+        super(LayerNorm, self).__init__()
+        if LayerNorm_type == 'BiasFree':
+            self.body = BiasFree_LayerNorm(dim)
+        else:
+            self.body = WithBias_LayerNorm(dim)
+
+    def forward(self, x):
+        h, w = x.shape[-2:]
+        return to_4d(self.body(to_3d(x)), h, w)
+
+##########################################################################
+## Gated-Dconv Feed-Forward Network (GDFN)
+class FeedForward(nn.Module):
+    def __init__(self, dim, ffn_expansion_factor, bias):
+        super(FeedForward, self).__init__()
+        hidden_features = int(dim*ffn_expansion_factor)
+        self.project_in = nn.Conv2d(dim, hidden_features*2, kernel_size=1, bias=bias)
+        self.dwconv = nn.Conv2d(hidden_features*2, hidden_features*2, kernel_size=3, stride=1, padding=1, groups=hidden_features*2, bias=bias)
+        self.project_out = nn.Conv2d(hidden_features, dim, kernel_size=1, bias=bias)
+
+    def forward(self, x):
+        x = self.project_in(x)
+        x1, x2 = self.dwconv(x).chunk(2, dim=1)
+        x = F.gelu(x1) * x2
+        x = self.project_out(x)
+        return x
+
+class CrossAttention(nn.Module):
+    def __init__(self, dim, num_heads, bias):
+        super(CrossAttention, self).__init__()
+        self.num_heads = num_heads
+        self.temperature1 = nn.Parameter(torch.ones(num_heads, 1, 1))
+        self.temperature2 = nn.Parameter(torch.ones(num_heads, 1, 1))
+
+        self.q = nn.Conv2d(dim, dim*2, kernel_size=1, bias=bias)
+        self.kv1 = nn.Conv2d(dim, dim*2, kernel_size=1, bias=bias)
+        self.kv2 = nn.Conv2d(dim, dim*2, kernel_size=1, bias=bias)
+        self.q_dwconv = nn.Conv2d(dim*2, dim*2, kernel_size=3, stride=1, padding=1, groups=dim, bias=bias)
+        self.kv1_dwconv = nn.Conv2d(dim*2, dim*2, kernel_size=3, stride=1, padding=1, groups=dim*2, bias=bias)
+        self.kv2_dwconv = nn.Conv2d(dim*2, dim*2, kernel_size=3, stride=1, padding=1, groups=dim*2, bias=bias)
+        # self.project_out = nn.Conv2d(dim*2, dim, kernel_size=1, bias=bias)
+
+    def forward(self, x, attn_kv1, attn_kv2):
+        b,c,h,w = x.shape
+
+        q_ = self.q_dwconv(self.q(x))
+        kv1 = self.kv1_dwconv(self.kv1(attn_kv1))
+        kv2 = self.kv2_dwconv(self.kv2(attn_kv2))
+        q1,q2 = q_.chunk(2, dim=1)
+        k1,v1 = kv1.chunk(2, dim=1)
+        k2,v2 = kv2.chunk(2, dim=1)
+
+        q1 = rearrange(q1, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
+        q2 = rearrange(q2, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
+        k1 = rearrange(k1, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
+        v1 = rearrange(v1, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
+        k2 = rearrange(k2, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
+        v2 = rearrange(v2, 'b (head c) h w -> b head c (h w)', head=self.num_heads)
+
+        q1 = torch.nn.functional.normalize(q1, dim=-1)
+        q2 = torch.nn.functional.normalize(q2, dim=-1)
+        k1 = torch.nn.functional.normalize(k1, dim=-1)
+        k2 = torch.nn.functional.normalize(k2, dim=-1)
+
+        attn = (q1 @ k1.transpose(-2, -1)) * self.temperature1
+        attn = attn.softmax(dim=-1)
+        out1 = (attn @ v1)
+
+        attn = (q2 @ k2.transpose(-2, -1)) * self.temperature2
+        attn = attn.softmax(dim=-1)
+        out2 = (attn @ v2)
+
+        out1 = rearrange(out1, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
+        out2 = rearrange(out2, 'b head c (h w) -> b (head c) h w', head=self.num_heads, h=h, w=w)
+        return out1, out2
+
+
+##########################################################################
+class CrossTransformerBlock(nn.Module):
+    def __init__(self, dim, num_heads, ffn_expansion_factor, bias, LayerNorm_type):
+        super(CrossTransformerBlock, self).__init__()
+        self.norm1 = LayerNorm(dim, LayerNorm_type)
+        self.norm_kv1 = LayerNorm(dim, LayerNorm_type)
+        self.norm_kv2 = LayerNorm(dim, LayerNorm_type)
+        self.attn = CrossAttention(dim, num_heads, bias)
+        self.norm2 = LayerNorm(dim, LayerNorm_type)
+        self.ffn = FeedForward(dim, ffn_expansion_factor, bias)
+
+    def forward(self, x, attn_kv1, attn_kv2):
+        out1, out2 = self.attn(self.norm1(x), self.norm_kv1(attn_kv1), self.norm_kv2(attn_kv2))
+        out = torch.cat((self.ffn(self.norm2(out1)), x, self.ffn(self.norm2(out2))), dim=1)
+        return out
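transformer_new.py adapts Restormer-style "transposed" attention (attention over channels rather than pixels, so the attention map is c-by-c instead of hw-by-hw) to a cross-attention setting: the query stream comes from the center block's features, the two key/value streams from its neighbors, and the block returns cat(ffn(out1), x, ffn(out2)), tripling the channel count to match the in_chs*3 widths used throughout networks_STIR.py. A small shape check, not part of the diff:

import torch
from spikezoo.archs.stir.models.transformer_new import CrossTransformerBlock  # assumed import path

# dim must be divisible by num_heads; dim=96 matches init_chs[-1] in STIR.
block = CrossTransformerBlock(dim=96, num_heads=4, ffn_expansion_factor=2.66,
                              bias=False, LayerNorm_type='WithBias')
center = torch.rand(1, 96, 8, 8)
prev, nxt = torch.rand(1, 96, 8, 8), torch.rand(1, 96, 8, 8)
out = block(center, prev, nxt)
print(out.shape)               # torch.Size([1, 288, 8, 8])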