spikezoo 0.1.1__py3-none-any.whl → 0.2__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (192)
  1. spikezoo/__init__.py +13 -0
  2. spikezoo/archs/__pycache__/__init__.cpython-39.pyc +0 -0
  3. spikezoo/archs/base/__pycache__/nets.cpython-39.pyc +0 -0
  4. spikezoo/archs/base/nets.py +34 -0
  5. spikezoo/archs/bsf/README.md +92 -0
  6. spikezoo/archs/bsf/datasets/datasets.py +328 -0
  7. spikezoo/archs/bsf/datasets/ds_utils.py +64 -0
  8. spikezoo/archs/bsf/main.py +398 -0
  9. spikezoo/archs/bsf/metrics/psnr.py +22 -0
  10. spikezoo/archs/bsf/metrics/ssim.py +54 -0
  11. spikezoo/archs/bsf/models/bsf/__pycache__/align.cpython-39.pyc +0 -0
  12. spikezoo/archs/bsf/models/bsf/__pycache__/bsf.cpython-39.pyc +0 -0
  13. spikezoo/archs/bsf/models/bsf/__pycache__/rep.cpython-39.pyc +0 -0
  14. spikezoo/archs/bsf/models/bsf/align.py +154 -0
  15. spikezoo/archs/bsf/models/bsf/bsf.py +105 -0
  16. spikezoo/archs/bsf/models/bsf/dsft_convert.py +96 -0
  17. spikezoo/archs/bsf/models/bsf/rep.py +44 -0
  18. spikezoo/archs/bsf/models/get_model.py +7 -0
  19. spikezoo/archs/bsf/prepare_data/DSFT.py +62 -0
  20. spikezoo/archs/bsf/prepare_data/crop_dataset_train.py +135 -0
  21. spikezoo/archs/bsf/prepare_data/crop_dataset_val.py +139 -0
  22. spikezoo/archs/bsf/prepare_data/crop_train.sh +4 -0
  23. spikezoo/archs/bsf/prepare_data/crop_val.sh +4 -0
  24. spikezoo/archs/bsf/prepare_data/io_utils.py +64 -0
  25. spikezoo/archs/bsf/requirements.txt +9 -0
  26. spikezoo/archs/bsf/test.py +16 -0
  27. spikezoo/archs/bsf/utils.py +154 -0
  28. spikezoo/archs/spikeclip/__pycache__/nets.cpython-39.pyc +0 -0
  29. spikezoo/archs/spikeclip/nets.py +40 -0
  30. spikezoo/archs/spikeformer/CheckPoints/readme +1 -0
  31. spikezoo/archs/spikeformer/DataProcess/DataExtactor.py +60 -0
  32. spikezoo/archs/spikeformer/DataProcess/DataLoader.py +115 -0
  33. spikezoo/archs/spikeformer/DataProcess/LoadSpike.py +39 -0
  34. spikezoo/archs/spikeformer/EvalResults/readme +1 -0
  35. spikezoo/archs/spikeformer/LICENSE +21 -0
  36. spikezoo/archs/spikeformer/Metrics/Metrics.py +50 -0
  37. spikezoo/archs/spikeformer/Metrics/__init__.py +0 -0
  38. spikezoo/archs/spikeformer/Model/Loss.py +89 -0
  39. spikezoo/archs/spikeformer/Model/SpikeFormer.py +230 -0
  40. spikezoo/archs/spikeformer/Model/__init__.py +0 -0
  41. spikezoo/archs/spikeformer/Model/__pycache__/SpikeFormer.cpython-39.pyc +0 -0
  42. spikezoo/archs/spikeformer/Model/__pycache__/__init__.cpython-39.pyc +0 -0
  43. spikezoo/archs/spikeformer/README.md +30 -0
  44. spikezoo/archs/spikeformer/evaluate.py +87 -0
  45. spikezoo/archs/spikeformer/recon_real_data.py +97 -0
  46. spikezoo/archs/spikeformer/requirements.yml +95 -0
  47. spikezoo/archs/spikeformer/train.py +173 -0
  48. spikezoo/archs/spikeformer/utils.py +22 -0
  49. spikezoo/archs/spk2imgnet/.github/workflows/pylint.yml +23 -0
  50. spikezoo/archs/spk2imgnet/.gitignore +150 -0
  51. spikezoo/archs/spk2imgnet/DCNv2.py +135 -0
  52. spikezoo/archs/spk2imgnet/__pycache__/DCNv2.cpython-39.pyc +0 -0
  53. spikezoo/archs/spk2imgnet/__pycache__/align_arch.cpython-39.pyc +0 -0
  54. spikezoo/archs/spk2imgnet/__pycache__/nets.cpython-39.pyc +0 -0
  55. spikezoo/archs/spk2imgnet/align_arch.py +159 -0
  56. spikezoo/archs/spk2imgnet/dataset.py +144 -0
  57. spikezoo/archs/spk2imgnet/nets.py +230 -0
  58. spikezoo/archs/spk2imgnet/readme.md +86 -0
  59. spikezoo/archs/spk2imgnet/test_gen_imgseq.py +118 -0
  60. spikezoo/archs/spk2imgnet/train.py +189 -0
  61. spikezoo/archs/spk2imgnet/utils.py +64 -0
  62. spikezoo/archs/ssir/README.md +87 -0
  63. spikezoo/archs/ssir/configs/SSIR.yml +37 -0
  64. spikezoo/archs/ssir/configs/yml_parser.py +78 -0
  65. spikezoo/archs/ssir/datasets/dataset_sreds.py +170 -0
  66. spikezoo/archs/ssir/datasets/ds_utils.py +66 -0
  67. spikezoo/archs/ssir/losses.py +21 -0
  68. spikezoo/archs/ssir/main.py +326 -0
  69. spikezoo/archs/ssir/metrics/psnr.py +22 -0
  70. spikezoo/archs/ssir/metrics/ssim.py +54 -0
  71. spikezoo/archs/ssir/models/Vgg19.py +42 -0
  72. spikezoo/archs/ssir/models/__pycache__/layers.cpython-39.pyc +0 -0
  73. spikezoo/archs/ssir/models/__pycache__/networks.cpython-39.pyc +0 -0
  74. spikezoo/archs/ssir/models/layers.py +110 -0
  75. spikezoo/archs/ssir/models/networks.py +61 -0
  76. spikezoo/archs/ssir/requirements.txt +8 -0
  77. spikezoo/archs/ssir/shells/eval_SREDS.sh +6 -0
  78. spikezoo/archs/ssir/shells/train_SSIR.sh +12 -0
  79. spikezoo/archs/ssir/test.py +3 -0
  80. spikezoo/archs/ssir/utils.py +154 -0
  81. spikezoo/archs/ssml/__pycache__/cbam.cpython-39.pyc +0 -0
  82. spikezoo/archs/ssml/__pycache__/model.cpython-39.pyc +0 -0
  83. spikezoo/archs/ssml/cbam.py +224 -0
  84. spikezoo/archs/ssml/model.py +290 -0
  85. spikezoo/archs/ssml/res.png +0 -0
  86. spikezoo/archs/ssml/test.py +67 -0
  87. spikezoo/archs/stir/.git-credentials +0 -0
  88. spikezoo/archs/stir/README.md +65 -0
  89. spikezoo/archs/stir/ckpt_outputs/Descriptions.txt +1 -0
  90. spikezoo/archs/stir/configs/STIR.yml +37 -0
  91. spikezoo/archs/stir/configs/utils.py +155 -0
  92. spikezoo/archs/stir/configs/yml_parser.py +78 -0
  93. spikezoo/archs/stir/datasets/dataset_sreds.py +180 -0
  94. spikezoo/archs/stir/datasets/ds_utils.py +66 -0
  95. spikezoo/archs/stir/eval_SREDS.sh +5 -0
  96. spikezoo/archs/stir/main.py +397 -0
  97. spikezoo/archs/stir/metrics/losses.py +219 -0
  98. spikezoo/archs/stir/metrics/psnr.py +22 -0
  99. spikezoo/archs/stir/metrics/ssim.py +54 -0
  100. spikezoo/archs/stir/models/Vgg19.py +42 -0
  101. spikezoo/archs/stir/models/__pycache__/networks_STIR.cpython-39.pyc +0 -0
  102. spikezoo/archs/stir/models/__pycache__/submodules.cpython-39.pyc +0 -0
  103. spikezoo/archs/stir/models/__pycache__/transformer_new.cpython-39.pyc +0 -0
  104. spikezoo/archs/stir/models/networks_STIR.py +361 -0
  105. spikezoo/archs/stir/models/submodules.py +86 -0
  106. spikezoo/archs/stir/models/transformer_new.py +151 -0
  107. spikezoo/archs/stir/package_core/build/lib/package_core/__init__.py +0 -0
  108. spikezoo/archs/stir/package_core/build/lib/package_core/convertions.py +721 -0
  109. spikezoo/archs/stir/package_core/build/lib/package_core/disp_netS.py +133 -0
  110. spikezoo/archs/stir/package_core/build/lib/package_core/flow_utils.py +167 -0
  111. spikezoo/archs/stir/package_core/build/lib/package_core/generic_train_test.py +76 -0
  112. spikezoo/archs/stir/package_core/build/lib/package_core/geometry.py +458 -0
  113. spikezoo/archs/stir/package_core/build/lib/package_core/image_proc.py +183 -0
  114. spikezoo/archs/stir/package_core/build/lib/package_core/linalg.py +40 -0
  115. spikezoo/archs/stir/package_core/build/lib/package_core/losses.py +198 -0
  116. spikezoo/archs/stir/package_core/build/lib/package_core/metrics.py +51 -0
  117. spikezoo/archs/stir/package_core/build/lib/package_core/model_base.py +53 -0
  118. spikezoo/archs/stir/package_core/build/lib/package_core/net_basics.py +100 -0
  119. spikezoo/archs/stir/package_core/build/lib/package_core/resnet.py +333 -0
  120. spikezoo/archs/stir/package_core/build/lib/package_core/transforms.py +123 -0
  121. spikezoo/archs/stir/package_core/build/lib/package_core/utils.py +72 -0
  122. spikezoo/archs/stir/package_core/dist/package_core-0.0.0-py3.9.egg +0 -0
  123. spikezoo/archs/stir/package_core/package_core/__init__.py +0 -0
  124. spikezoo/archs/stir/package_core/package_core/__pycache__/__init__.cpython-39.pyc +0 -0
  125. spikezoo/archs/stir/package_core/package_core/__pycache__/net_basics.cpython-39.pyc +0 -0
  126. spikezoo/archs/stir/package_core/package_core/convertions.py +721 -0
  127. spikezoo/archs/stir/package_core/package_core/disp_netS.py +133 -0
  128. spikezoo/archs/stir/package_core/package_core/flow_utils.py +167 -0
  129. spikezoo/archs/stir/package_core/package_core/generic_train_test.py +76 -0
  130. spikezoo/archs/stir/package_core/package_core/geometry.py +458 -0
  131. spikezoo/archs/stir/package_core/package_core/image_proc.py +183 -0
  132. spikezoo/archs/stir/package_core/package_core/linalg.py +40 -0
  133. spikezoo/archs/stir/package_core/package_core/losses.py +198 -0
  134. spikezoo/archs/stir/package_core/package_core/metrics.py +51 -0
  135. spikezoo/archs/stir/package_core/package_core/model_base.py +53 -0
  136. spikezoo/archs/stir/package_core/package_core/net_basics.py +100 -0
  137. spikezoo/archs/stir/package_core/package_core/resnet.py +333 -0
  138. spikezoo/archs/stir/package_core/package_core/transforms.py +123 -0
  139. spikezoo/archs/stir/package_core/package_core/utils.py +72 -0
  140. spikezoo/archs/stir/package_core/package_core.egg-info/PKG-INFO +3 -0
  141. spikezoo/archs/stir/package_core/package_core.egg-info/SOURCES.txt +20 -0
  142. spikezoo/archs/stir/package_core/package_core.egg-info/dependency_links.txt +1 -0
  143. spikezoo/archs/stir/package_core/package_core.egg-info/top_level.txt +1 -0
  144. spikezoo/archs/stir/package_core/setup.py +5 -0
  145. spikezoo/archs/stir/requirements.txt +12 -0
  146. spikezoo/archs/stir/train_STIR.sh +9 -0
  147. spikezoo/archs/tfi/__pycache__/nets.cpython-39.pyc +0 -0
  148. spikezoo/archs/tfi/nets.py +43 -0
  149. spikezoo/archs/tfp/__pycache__/nets.cpython-39.pyc +0 -0
  150. spikezoo/archs/tfp/nets.py +13 -0
  151. spikezoo/archs/wgse/README.md +64 -0
  152. spikezoo/archs/wgse/__pycache__/dwtnets.cpython-39.pyc +0 -0
  153. spikezoo/archs/wgse/__pycache__/submodules.cpython-39.pyc +0 -0
  154. spikezoo/archs/wgse/dataset.py +59 -0
  155. spikezoo/archs/wgse/demo.png +0 -0
  156. spikezoo/archs/wgse/demo.py +83 -0
  157. spikezoo/archs/wgse/dwtnets.py +145 -0
  158. spikezoo/archs/wgse/eval.py +133 -0
  159. spikezoo/archs/wgse/logs/WGSE-Dwt1dNet-db8-5-ks3/log.txt +11 -0
  160. spikezoo/archs/wgse/submodules.py +68 -0
  161. spikezoo/archs/wgse/train.py +261 -0
  162. spikezoo/archs/wgse/transform.py +139 -0
  163. spikezoo/archs/wgse/utils.py +128 -0
  164. spikezoo/archs/wgse/weights/demo.png +0 -0
  165. spikezoo/data/base/test/gt/200_part1_key_id151.png +0 -0
  166. spikezoo/data/base/test/gt/200_part3_key_id151.png +0 -0
  167. spikezoo/data/base/test/gt/203_part1_key_id151.png +0 -0
  168. spikezoo/data/base/test/spike/200_part1_key_id151.dat +0 -0
  169. spikezoo/data/base/test/spike/200_part3_key_id151.dat +0 -0
  170. spikezoo/data/base/test/spike/203_part1_key_id151.dat +0 -0
  171. spikezoo/data/base/train/gt/203_part2_key_id151.png +0 -0
  172. spikezoo/data/base/train/gt/203_part3_key_id151.png +0 -0
  173. spikezoo/data/base/train/gt/203_part4_key_id151.png +0 -0
  174. spikezoo/data/base/train/spike/203_part2_key_id151.dat +0 -0
  175. spikezoo/data/base/train/spike/203_part3_key_id151.dat +0 -0
  176. spikezoo/data/base/train/spike/203_part4_key_id151.dat +0 -0
  177. spikezoo/datasets/base_dataset.py +2 -3
  178. spikezoo/metrics/__init__.py +1 -1
  179. spikezoo/models/base_model.py +1 -3
  180. spikezoo/pipeline/base_pipeline.py +7 -5
  181. spikezoo/pipeline/train_pipeline.py +1 -1
  182. spikezoo/utils/other_utils.py +16 -6
  183. spikezoo/utils/spike_utils.py +33 -29
  184. spikezoo/utils/vidar_loader.cpython-39-x86_64-linux-gnu.so +0 -0
  185. spikezoo-0.2.dist-info/METADATA +163 -0
  186. spikezoo-0.2.dist-info/RECORD +211 -0
  187. spikezoo/models/spcsnet_model.py +0 -19
  188. spikezoo-0.1.1.dist-info/METADATA +0 -39
  189. spikezoo-0.1.1.dist-info/RECORD +0 -36
  190. {spikezoo-0.1.1.dist-info → spikezoo-0.2.dist-info}/LICENSE.txt +0 -0
  191. {spikezoo-0.1.1.dist-info → spikezoo-0.2.dist-info}/WHEEL +0 -0
  192. {spikezoo-0.1.1.dist-info → spikezoo-0.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,154 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import os
5
+ import os.path as osp
6
+ import random
7
+ import cv2
8
+
9
def set_seeds(_seed_):
    """Seed every random number generator used in training and pin cuDNN behaviour.

    Seeds Python's ``random``, NumPy's global RNG and torch (``manual_seed``
    plus ``cuda.manual_seed_all`` for every CUDA device). It then turns off
    cuDNN benchmarking and requests deterministic cuDNN kernels.
    """
    seeders = (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all)
    for seeder in seeders:
        seeder(_seed_)

    cudnn = torch.backends.cudnn
    cudnn.benchmark = False
    cudnn.deterministic = True
    # Stricter determinism is possible but left disabled: set
    # os.environ['CUBLAS_WORKSPACE_CONFIG'] to ':4096:8' (about 24 MiB extra
    # GPU memory) or ':16:8' (may limit performance), and call
    # torch.use_deterministic_algorithms(True).
20
+
21
+
22
def make_dir(path):
    """Create directory ``path``, including missing parents, if it does not exist yet.

    ``exist_ok=True`` replaces the previous exists-then-create check. That
    check was racy: another process could create the directory between the
    two calls and make ``os.makedirs`` raise ``FileExistsError``.
    """
    os.makedirs(path, exist_ok=True)
    return
26
+
27
+
28
def add_args_to_cfg(cfg, args, args_list):
    """Copy selected attributes of ``args`` into ``cfg['train']``.

    Args:
        cfg: dict with a ``'train'`` sub-dict; it is mutated in place.
        args: object holding the values (e.g. an argparse namespace).
        args_list: attribute names to copy. Dotted names such as
            ``'opt.lr'`` are followed attribute by attribute, and the full
            dotted string is used as the key.

    Returns:
        The same ``cfg`` object.
    """
    train_cfg = cfg['train']
    for name in args_list:
        # getattr instead of eval('args.' + name): the same lookup for
        # attribute names, without executing arbitrary expressions.
        value = args
        for part in name.split('.'):
            value = getattr(value, part)
        train_cfg[name] = value
    return cfg
32
+
33
+
34
+ # class AverageMeter(object):
35
+ # """Computes and stores the average and current value"""
36
+ # def __init__(self, precision=3):
37
+ # self.precision = precision
38
+ # self.reset()
39
+
40
+ # def reset(self):
41
+ # self.val = 0
42
+ # self.avg = 0
43
+ # self.sum = 0
44
+ # self.count = 0
45
+
46
+ # def update(self, val, n=1):
47
+ # self.val = val
48
+ # self.sum += val * n
49
+ # self.count += n
50
+ # self.avg = self.sum / self.count
51
+
52
+ # def __repr__(self):
53
+ # return '{:.{}f} ({:.{}f})'.format(self.val, self.precision, self.avg, self.precision)
54
+
55
+
56
class AverageMeter(object):
    """Track the latest value and the running weighted average of one or more metrics.

    Args:
        i: number of metrics tracked in parallel.
        precision: number of decimals shown by ``__repr__``.
        names: optional labels, one per metric, shown by ``__repr__``.
    """

    def __init__(self, i=1, precision=3, names=None):
        self.meters = i
        self.precision = precision
        self.reset(self.meters)
        self.names = names
        if names is not None:
            assert self.meters == len(self.names)
        else:
            self.names = [''] * self.meters

    def reset(self, i=None):
        """Zero all statistics; ``i`` defaults to the configured number of meters."""
        if i is None:
            i = self.meters
        self.val = [0] * i
        self.avg = [0] * i
        self.sum = [0] * i
        self.count = [0] * i

    def update(self, val, n=1):
        """Record new values.

        Args:
            val: one value, or a list with one value per meter.
            n: weight (e.g. batch size); a scalar applies to every meter.
        """
        if not isinstance(val, list):
            val = [val]
        if not isinstance(n, list):
            n = [n] * self.meters
        assert (len(val) == self.meters and len(n) == self.meters)
        for k, (v, weight) in enumerate(zip(val, n)):
            self.count[k] += weight
            self.val[k] = v
            self.sum[k] += v * weight
            # Guard against zero total weight (e.g. update(..., n=0) first).
            self.avg[k] = self.sum[k] / self.count[k] if self.count[k] else 0

    def __repr__(self):
        out = ' '.join(['{} {:.{}f} ({:.{}f})'.format(n, v, self.precision, a, self.precision) for n, v, a in
                        zip(self.names, self.val, self.avg)])
        return '{}'.format(out)
97
+
98
+
99
def normalize_image_torch(image, percentile_lower=1, percentile_upper=99):
    """Contrast-stretch each (batch, channel) plane of a 4-D tensor into [0, 1].

    Within every plane, the ``percentile_lower``-th percentile maps to 0 and
    the ``percentile_upper``-th percentile maps to 1; both are on a 0-100
    scale. Values outside that range are clipped. ``image`` must be a
    floating-point tensor because ``torch.quantile`` requires one.
    """
    b, c, h, w = image.shape
    image_reshape = image.reshape([b, c, h * w])
    # torch.quantile expects q in [0, 1]; the percentiles were previously
    # ignored in favour of hard-coded 0.01 / 0.99.
    mini = torch.quantile(image_reshape, percentile_lower / 100.0, dim=2, keepdim=True).unsqueeze_(dim=3)
    maxi = torch.quantile(image_reshape, percentile_upper / 100.0, dim=2, keepdim=True).unsqueeze_(dim=3)
    # The epsilon keeps constant planes (maxi == mini) from dividing by zero.
    return torch.clip((image - mini) / (maxi - mini + 1e-5), 0, 1)
107
+
108
def normalize_image_torch2(image):
    """Clamp ``image`` into the closed range [0, 1] without any rescaling."""
    clamped = image.clamp(min=0, max=1)
    return clamped
110
+
111
+ # --------------------------------------------
112
+ # Torch to Numpy 0~255
113
+ # --------------------------------------------
114
def torch2numpy255(im):
    """Return channel 0 of batch item 0 as a float64 NumPy array multiplied by 255."""
    first_plane = im[0, 0].detach().cpu().numpy()
    scaled = first_plane * 255
    return scaled.astype(np.float64)
118
+
119
def torch2torch255(im):
    """Scale ``im`` by 255 (e.g. [0, 1] intensities to the 0-255 range)."""
    scaled = 255.0 * im
    return scaled
121
+
122
class InputPadder:
    """Replicate-pad images so that height and width become multiples of ``padsize``.

    Width padding is split between left and right, with the extra column on
    the right. Height padding is added at the bottom only. ``unpad`` crops a
    padded tensor back to the original size.
    """

    def __init__(self, dims, padsize=16):
        self.ht, self.wd = dims[-2:]
        # (-n) % p is the distance from n up to the next multiple of p (0 if already aligned).
        pad_ht = -self.ht % padsize
        pad_wd = -self.wd % padsize
        left = pad_wd // 2
        self._pad = [left, pad_wd - left, 0, pad_ht]

    def pad(self, *inputs):
        """Pad every tensor in ``inputs`` and return the results as a list."""
        padded = []
        for tensor in inputs:
            padded.append(F.pad(tensor, self._pad, mode='replicate'))
        return padded

    def unpad(self, x):
        """Crop the padding added by ``pad`` off the last two dimensions of ``x``."""
        left, right, top, bottom = self._pad
        height, width = x.shape[-2:]
        return x[..., top:height - bottom, left:width - right]
137
+
138
+
139
+
140
def vis_img(vis_path: str, img: torch.Tensor, vis_name: str = 'vis'):
    """Save channel 0 of a batch as a 4-row grayscale mosaic at ``vis_path/vis_name.png``.

    Images are placed row-major with ``B // 4`` images per row; the
    remaining ``B % 4`` images are not drawn. Values are assumed to lie in
    [0, 1]; they are scaled by 255 and clipped before the uint8 cast.

    Raises:
        ValueError: if the batch holds fewer than 4 images.
    """
    n_rows = 4
    per_row = img.shape[0] // n_rows
    if per_row == 0:
        raise ValueError('vis_img needs at least {} images, got {}'.format(n_rows, img.shape[0]))
    # Clip before casting: out-of-range values would otherwise wrap around
    # (e.g. 1.2 * 255 -> 50) instead of saturating.
    planes = np.clip(img[:, 0].detach().cpu().numpy() * 255, 0, 255).astype(np.uint8)
    rows = [np.concatenate(planes[r * per_row:(r + 1) * per_row], axis=1) for r in range(n_rows)]
    out_img = np.concatenate(rows, axis=0)
    cv2.imwrite(osp.join(vis_path, vis_name + '.png'), out_img)
    return
@@ -0,0 +1,224 @@
1
+ import torch
2
+ import math
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+
6
class crop(nn.Module):
    """Remove the bottom row of a 4-D (N, C, H, W) tensor."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        _, _, height, _ = x.shape
        return x[:, :, :height - 1, :]
14
+
15
class shift(nn.Module):
    """Shift a 4-D tensor down by one row: a zero row enters at the top and the bottom row is dropped."""

    def __init__(self):
        super().__init__()
        self.shift_down = nn.ZeroPad2d((0, 0, 1, 0))
        self.crop = crop()

    def forward(self, x):
        padded = self.shift_down(x)
        return self.crop(padded)
25
+
26
class Conv(nn.Module):
    """Replicate-pad, convolve, then LeakyReLU(0.1), with an optional one-row "blind" shift.

    With ``blind=True`` a zero row is added at the top before the replicate
    padding, and the surplus bottom row is cropped from the output. With the
    default 3x3 kernel, stride 1 and padding 0, output row i then depends
    only on input rows i-2..i.
    """

    def __init__(self, in_channels, out_channels, bias=False, blind=True, stride=1, padding=0, kernel_size=3):
        super().__init__()
        self.blind = blind
        if blind:
            self.shift_down = nn.ZeroPad2d((0, 0, 1, 0))
            self.crop = crop()
        self.replicate = nn.ReplicationPad2d(1)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                              stride=stride, padding=padding, bias=bias)
        self.relu = nn.LeakyReLU(0.1, inplace=True)

    def forward(self, x):
        pre = self.shift_down(x) if self.blind else x
        out = self.relu(self.conv(self.replicate(pre)))
        return self.crop(out) if self.blind else out
49
+
50
class BasicConv(nn.Module):
    """Conv2d, optionally followed by BatchNorm2d and then ReLU.

    ``blind`` is accepted but not used.
    """

    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1,
                 relu=False, bn=False, bias=True, blind=False):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride,
                              padding=padding, dilation=dilation, groups=groups, bias=bias)
        if bn:
            self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True)
        else:
            self.bn = None
        if relu:
            self.relu = nn.ReLU()
        else:
            self.relu = None

    def forward(self, x):
        out = self.conv(x)
        for post in (self.bn, self.relu):
            if post is not None:
                out = post(out)
        return out
65
+
66
class Flatten(nn.Module):
    """Collapse every non-batch dimension: (B, ...) -> (B, prod(...)) via ``view``."""

    def forward(self, x):
        batch = x.shape[0]
        return x.view(batch, -1)
69
+
70
class ChannelGate(nn.Module):
    """CBAM channel attention: re-weight the channels of an NCHW tensor.

    Each pooling type in ``pool_types`` reduces the spatial dims. A shared MLP
    turns each pooled vector into per-channel logits, the logits are summed
    over pooling types, and ``x`` is multiplied by their sigmoid.

    Args:
        gate_channels: number of input channels C.
        reduction_ratio: the MLP hidden width is ``C // reduction_ratio``.
        pool_types: any of ``'avg'``, ``'max'``, ``'lp'`` (power-2 LP pooling)
            and ``'lse'`` (log-sum-exp).

    Raises:
        ValueError: if ``pool_types`` is empty or contains an unknown name.
    """

    _SUPPORTED_POOLS = ('avg', 'max', 'lp', 'lse')

    def __init__(self, gate_channels, reduction_ratio=16, pool_types=('avg', 'max')):
        super(ChannelGate, self).__init__()
        # Copy into a fresh list so instances never share the caller's or the default's sequence.
        pool_types = list(pool_types)
        unknown = [p for p in pool_types if p not in self._SUPPORTED_POOLS]
        if not pool_types or unknown:
            raise ValueError('pool_types must be a non-empty subset of {}, got {}'.format(
                self._SUPPORTED_POOLS, pool_types))
        self.gate_channels = gate_channels
        self.mlp = nn.Sequential(
            Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels)
        )
        self.pool_types = pool_types

    def _pool(self, x, pool_type):
        """Reduce the spatial dims of ``x`` with the named pooling."""
        h, w = x.size(2), x.size(3)
        if pool_type == 'avg':
            return F.avg_pool2d(x, (h, w), stride=(h, w))
        if pool_type == 'max':
            return F.max_pool2d(x, (h, w), stride=(h, w))
        if pool_type == 'lp':
            return F.lp_pool2d(x, 2, (h, w), stride=(h, w))
        if pool_type == 'lse':
            return logsumexp_2d(x)
        # Never fall through: that would silently reuse a stale pooling result.
        raise ValueError('unknown pool type: {!r}'.format(pool_type))

    def forward(self, x):
        channel_att_sum = None
        for pool_type in self.pool_types:
            channel_att_raw = self.mlp(self._pool(x, pool_type))
            if channel_att_sum is None:
                channel_att_sum = channel_att_raw
            else:
                channel_att_sum = channel_att_sum + channel_att_raw
        # torch.sigmoid: F.sigmoid is deprecated.
        scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
        return x * scale
105
+
106
def logsumexp_2d(tensor):
    """Per-channel log-sum-exp over the spatial dims of an NCHW tensor.

    Returns a tensor of shape (N, C, 1).

    ``reshape`` (rather than ``view``) also accepts non-contiguous inputs.
    ``torch.logsumexp`` is numerically stable and, unlike the manual
    max-shift, does not produce NaN for planes that are entirely ``-inf``.
    """
    tensor_flatten = tensor.reshape(tensor.size(0), tensor.size(1), -1)
    return torch.logsumexp(tensor_flatten, dim=2, keepdim=True)
111
+
112
class ChannelPool(nn.Module):
    """Stack the channel-wise max and mean: (N, C, H, W) -> (N, 2, H, W)."""

    def forward(self, x):
        channel_max = x.max(dim=1, keepdim=True).values
        channel_mean = x.mean(dim=1, keepdim=True)
        return torch.cat([channel_max, channel_mean], dim=1)
115
+
116
class SpatialGate(nn.Module):
    """CBAM spatial attention.

    Compresses channels to [max, mean] with ``ChannelPool``, applies a 7x7
    ``BasicConv`` producing one channel, and multiplies the input by the
    sigmoid of that map.

    Args:
        bias: passed to the 7x7 convolution.
        blind: accepted but unused; the conv is always built with ``blind=False``.
    """

    def __init__(self, bias=False, blind=False):
        super(SpatialGate, self).__init__()
        kernel_size = 7
        self.compress = ChannelPool()
        self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2,
                                 relu=False, bias=bias, blind=False)

    def forward(self, x):
        x_compress = self.compress(x)
        x_out = self.spatial(x_compress)
        # torch.sigmoid replaces the deprecated F.sigmoid. The single-channel
        # scale broadcasts over all channels of x.
        scale = torch.sigmoid(x_out)
        return x * scale
127
+
128
class CBAM(nn.Module):
    """Convolutional Block Attention Module: a channel gate, optionally followed by a spatial gate.

    Args:
        gate_channels: number of input channels.
        reduction_ratio: MLP reduction used by the channel gate.
        pool_types: pooling types for the channel gate; ``None`` means
            ``['avg', 'max']`` (a fresh list per instance).
        no_spatial: skip the spatial gate when True.
    """

    def __init__(self, gate_channels, reduction_ratio=16, pool_types=None, no_spatial=False):
        super(CBAM, self).__init__()
        if pool_types is None:
            # A list default would be one object shared by every CBAM (and by
            # every ChannelGate that stores it); build a fresh one instead.
            pool_types = ['avg', 'max']
        self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
        self.no_spatial = no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate()

    def forward(self, x):
        x_out = self.ChannelGate(x)
        if not self.no_spatial:
            x_out = self.SpatialGate(x_out)
        return x_out
140
+
141
def weights_init_rcan(m):
    """Custom weight initialisation, intended for ``model.apply(weights_init_rcan)``.

    Adapted from https://github.com/pytorch/examples/blob/master/dcgan/main.py.

    * Modules whose class name contains ``'BasicConv'``: inner ``conv.weight``
      is drawn from N(0, 0.02), and ``bn.bias`` (if present) is zeroed.
    * Other modules whose class name contains ``'Conv'`` and that own a
      ``weight``: the weight is drawn from N(0, 0.02). Wrapper modules with no
      weight of their own (e.g. a ``Conv`` holding a ``.conv`` child) are
      skipped; ``nn.Module.apply`` visits their inner ``nn.Conv2d`` separately.
    * Modules whose class name contains ``'BatchNorm'``: the weight is drawn
      from N(1, 0.02) and the bias is zeroed. Non-affine layers have no
      weight or bias and are skipped.
    """
    classname = m.__class__.__name__
    if 'Conv' in classname:
        if 'BasicConv' in classname:
            m.conv.weight.data.normal_(0.0, 0.02)
            if m.bn is not None and m.bn.bias is not None:
                m.bn.bias.data.fill_(0)
        elif getattr(m, 'weight', None) is not None:
            m.weight.data.normal_(0.0, 0.02)
    elif 'BatchNorm' in classname:
        if m.weight is not None:
            m.weight.data.normal_(1.0, 0.02)
        if m.bias is not None:
            m.bias.data.fill_(0)
157
+
158
class Temporal_Fusion(nn.Module):
    """Fuse N per-frame feature maps into one map using temporal and spatial attention.

    Input: ``(B, N, C, H, W)`` with ``N == nframes`` and ``C == nf``.
    Output: ``(B, nf, H, W)``.

    Temporal part: every frame is weighted per pixel by the sigmoid of its
    embedding's dot product with the embedding of frame ``center``.
    Spatial part: a two-level max/avg-pooling pyramid yields a multiplicative
    (sigmoid) mask and an additive term for the fused features.
    """

    def __init__(self, nf=64, nframes=3, center=1, bias=False):
        super(Temporal_Fusion, self).__init__()
        self.center = center

        self.tAtt_1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=bias)
        self.tAtt_2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=bias)

        self.fea_fusion = nn.Conv2d(nframes * nf, nf, 1, 1, bias=bias)

        self.sAtt_1 = nn.Conv2d(nframes * nf, nf, 1, 1, bias=bias)
        self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
        self.avgpool = nn.AvgPool2d(3, stride=2, padding=1)
        self.sAtt_2 = nn.Conv2d(nf * 2, nf, 1, 1, bias=bias)
        self.sAtt_3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=bias)
        self.sAtt_4 = nn.Conv2d(nf, nf, 1, 1, bias=bias)
        self.sAtt_5 = nn.Conv2d(nf, nf, 3, 1, 1, bias=bias)
        self.sAtt_L1 = nn.Conv2d(nf, nf, 1, 1, bias=bias)
        self.sAtt_L2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=bias)
        self.sAtt_L3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=bias)
        self.sAtt_add_1 = nn.Conv2d(nf, nf, 1, 1, bias=bias)
        self.sAtt_add_2 = nn.Conv2d(nf, nf, 1, 1, bias=bias)

        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, nonlocal_fea):
        B, N, C, H, W = nonlocal_fea.size()

        # Temporal attention against the reference (center) frame.
        # reshape instead of view so non-contiguous inputs work too.
        emb_ref = self.tAtt_2(nonlocal_fea[:, self.center, :, :, :].clone())
        emb = self.tAtt_1(nonlocal_fea.reshape(-1, C, H, W)).view(B, N, -1, H, W)

        cor_l = []
        for i in range(N):
            emb_nbr = emb[:, i, :, :, :]
            cor_tmp = torch.sum(emb_nbr * emb_ref, 1).unsqueeze(1)
            cor_l.append(cor_tmp)
        cor_prob = torch.sigmoid(torch.cat(cor_l, dim=1))  # (B, N, H, W)
        cor_prob = cor_prob.unsqueeze(2).repeat(1, 1, C, 1, 1)
        cor_prob = cor_prob.view(B, -1, H, W)
        nonlocal_fea = nonlocal_fea.reshape(B, -1, H, W) * cor_prob

        fea = self.lrelu(self.fea_fusion(nonlocal_fea))

        # Spatial attention pyramid: full -> 1/2 -> 1/4 resolution and back.
        att = self.lrelu(self.sAtt_1(nonlocal_fea))
        att_max = self.maxpool(att)
        att_avg = self.avgpool(att)
        att = self.lrelu(self.sAtt_2(torch.cat([att_max, att_avg], dim=1)))

        att_L = self.lrelu(self.sAtt_L1(att))
        att_max = self.maxpool(att_L)
        att_avg = self.avgpool(att_L)
        att_L = self.lrelu(self.sAtt_L2(torch.cat([att_max, att_avg], dim=1)))
        att_L = self.lrelu(self.sAtt_L3(att_L))
        # Upsample to the exact target size rather than a fixed 2x. The stride-2
        # pooling rounds odd sizes up, so 2x would mismatch unless H and W are
        # multiples of 4. For such sizes the result is identical to 2x.
        att_L = F.interpolate(att_L, size=att.shape[-2:], mode='bilinear', align_corners=False)

        att = self.lrelu(self.sAtt_3(att))
        att = att + att_L
        att = self.lrelu(self.sAtt_4(att))
        att = F.interpolate(att, size=(H, W), mode='bilinear', align_corners=False)
        att = self.sAtt_5(att)
        att_add = self.sAtt_add_2(self.lrelu(self.sAtt_add_1(att)))
        att = torch.sigmoid(att)

        fea = fea * att * 2 + att_add

        return fea
@@ -0,0 +1,290 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ # from utils import *
5
+ import numpy as np
6
+ import os
7
+ import sys
8
+ current_dir = os.path.dirname(os.path.abspath(__file__))
9
+ sys.path.append(current_dir)
10
+
11
+ from cbam import SpatialGate,ChannelGate,Temporal_Fusion
12
+
13
class crop(nn.Module):
    """Drop the last row (along H) of an (N, C, H, W) tensor."""

    def __init__(self):
        super().__init__()

    def forward(self, x):
        _, _, rows, _ = x.shape
        kept = rows - 1
        return x[:, :, :kept]
21
+
22
class shift(nn.Module):
    """Move a 4-D tensor down by one row: zero-fill the top row and drop the bottom row."""

    def __init__(self):
        super().__init__()
        self.shift_down = nn.ZeroPad2d((0, 0, 1, 0))
        self.crop = crop()

    def forward(self, x):
        return self.crop(self.shift_down(x))
32
+
33
class Conv(nn.Module):
    """Replicate-pad, convolve, then LeakyReLU(0.1); optionally shifted by one row.

    When ``blind`` is set, a zero row is prepended at the top before the
    replicate padding, and the extra bottom row is cropped off the result.
    """

    def __init__(self, in_channels, out_channels, bias=False, blind=True, stride=1, padding=0, kernel_size=3):
        super().__init__()
        self.blind = blind
        if blind:
            self.shift_down = nn.ZeroPad2d((0, 0, 1, 0))
            self.crop = crop()
        self.replicate = nn.ReplicationPad2d(1)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size,
                              stride=stride, padding=padding, bias=bias)
        self.relu = nn.LeakyReLU(0.1, inplace=True)

    def forward(self, x):
        if not self.blind:
            return self.relu(self.conv(self.replicate(x)))
        shifted = self.shift_down(x)
        activated = self.relu(self.conv(self.replicate(shifted)))
        return self.crop(activated)
54
+
55
class Pool(nn.Module):
    """2x2 max-pooling, optionally preceded by a one-row downward ``shift`` when ``blind``."""

    def __init__(self, blind=True):
        super().__init__()
        self.blind = blind
        if blind:
            self.shift = shift()
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        source = self.shift(x) if self.blind else x
        return self.pool(source)
68
+
69
class rotate(nn.Module):
    """Concatenate four rotations of a square NCHW batch along dim 0.

    Order: original, 90 degrees clockwise, 180 degrees, 90 degrees
    counter-clockwise (i.e. 270 degrees clockwise).
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        spatial = (2, 3)
        # rot90 with k > 0 turns from dim 2 towards dim 3 (counter-clockwise).
        quarter_turns = [x, torch.rot90(x, -1, spatial), torch.rot90(x, 2, spatial), torch.rot90(x, 1, spatial)]
        return torch.cat(quarter_turns, dim=0)
79
+
80
class unrotate(nn.Module):
    """Undo ``rotate``: split dim 0 into four, rotate each part back, and concatenate along channels.

    The input batch must be 4*B (the four rotations stacked in ``rotate``
    order). The output is (B, 4*C, H, W).
    """

    def __init__(self):
        super().__init__()

    def forward(self, x):
        spatial = (2, 3)
        upright, cw90, half, ccw90 = torch.chunk(x, 4, dim=0)
        restored = (upright,
                    torch.rot90(cw90, 1, spatial),
                    torch.rot90(half, 2, spatial),
                    torch.rot90(ccw90, -1, spatial))
        return torch.cat(restored, dim=1)
91
+
92
class ENC_Conv(nn.Module):
    """Encoder stage: three ``Conv`` layers, then an optional 2x2 ``Pool`` downsampling."""

    def __init__(self, in_channels, mid_channels, out_channels, bias=False, reduce=True, blind=True):
        super().__init__()
        self.reduce = reduce
        self.conv1 = Conv(in_channels, mid_channels, bias=bias, blind=blind)
        self.conv2 = Conv(mid_channels, mid_channels, bias=bias, blind=blind)
        self.conv3 = Conv(mid_channels, out_channels, bias=bias, blind=blind)
        if reduce:
            self.pool = Pool(blind=blind)

    def forward(self, x):
        for layer in (self.conv1, self.conv2, self.conv3):
            x = layer(x)
        return self.pool(x) if self.reduce else x
109
+
110
class DEC_Conv(nn.Module):
    """Decoder stage: 2x nearest upsampling, pad to the skip tensor's size, concatenate, then four ``Conv`` layers.

    ``in_channels`` must equal the upsampled channels plus the skip
    connection's channels.
    """

    def __init__(self, in_channels, mid_channels, out_channels, bias=False, blind=True):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv1 = Conv(in_channels, mid_channels, bias=bias, blind=blind)
        self.conv2 = Conv(mid_channels, mid_channels, bias=bias, blind=blind)
        self.conv3 = Conv(mid_channels, mid_channels, bias=bias, blind=blind)
        self.conv4 = Conv(mid_channels, out_channels, bias=bias, blind=blind)

    def forward(self, x, x_in):
        up = self.upsample(x)

        # Pad the upsampled map so it matches the skip tensor exactly
        # (sizes can differ by a pixel, e.g. after pooling an odd size).
        pad_h = x_in.size(2) - up.size(2)
        pad_w = x_in.size(3) - up.size(3)
        up = F.pad(up, [pad_w // 2, pad_w - pad_w // 2, pad_h // 2, pad_h - pad_h // 2])

        out = torch.cat((up, x_in), dim=1)
        for layer in (self.conv1, self.conv2, self.conv3, self.conv4):
            out = layer(out)
        return out
134
+
135
class Blind_UNet(nn.Module):
    """Three-level U-Net built from ``ENC_Conv``/``DEC_Conv``.

    Maps ``n_channels`` input channels to ``n_output`` channels. The final
    decoder pads to the input's spatial size, so the output has the input's
    resolution.
    """

    def __init__(self, n_channels=3, n_output=96, bias=False, blind=True):
        super().__init__()
        self.n_channels = n_channels
        self.bias = bias
        self.enc1 = ENC_Conv(n_channels, 48, 48, bias=bias, blind=blind)
        self.enc2 = ENC_Conv(48, 48, 48, bias=bias, blind=blind)
        self.enc3 = ENC_Conv(48, 96, 48, bias=bias, reduce=False, blind=blind)
        self.dec2 = DEC_Conv(96, 96, 96, bias=bias, blind=blind)
        self.dec1 = DEC_Conv(96 + n_channels, 96, n_output, bias=bias, blind=blind)

    def forward(self, input):
        skip_half = self.enc1(input)         # 48 ch, 1/2 resolution
        skip_quarter = self.enc2(skip_half)  # 48 ch, 1/4 resolution
        bottleneck = self.enc3(skip_quarter)
        decoded = self.dec2(bottleneck, skip_half)
        return self.dec1(decoded, input)
153
+
154
def middleTFI(spike, middle, window=50):
    """Inter-spike-interval image (TFI) around frame ``middle``.

    For every pixel, search two ranges for the nearest spike:

    * backwards from ``middle`` (inclusive) over ``window`` frames, giving
      the left index;
    * forwards from ``middle + 1`` over ``window`` frames, giving the right
      index.

    The result is ``1 / (right - left)``. When a side has no spike, the
    bound ``middle - window`` (left) or ``middle + window`` (right) is used.

    Args:
        spike: (C, H, W) or (C, 1, H, W) torch tensor or array. Any nonzero
            entry counts as a spike (assumed binary; confirm with data source).
        middle: reference frame index.
        window: number of frames searched on each side.

    Returns:
        float64 ndarray of shape (H, W).
    """
    if isinstance(spike, torch.Tensor):
        # detach() so inputs that require grad can still be converted.
        spike = spike.detach().cpu().numpy()
    spike = np.asarray(spike)
    if spike.ndim == 4:
        # (C, 1, H, W) -> (C, H, W). Only squeeze a real 4-D input, so that
        # H == 1 inputs are not collapsed by mistake.
        spike = spike.squeeze(1)
    C, H, W = spike.shape

    lindex = np.full((H, W), middle - window, dtype=np.float64)
    rindex = np.full((H, W), middle + window, dtype=np.float64)
    # Explicit found-masks: using index 0 as the "not found" marker would
    # ignore a genuine spike at frame 0.
    lfound = np.zeros((H, W), dtype=bool)
    rfound = np.zeros((H, W), dtype=bool)

    for step in range(window):
        left = middle - step       # middle, middle-1, ..., middle-window+1
        right = middle + 1 + step  # middle+1, ..., middle+window
        if 0 <= left < C:
            hit = (spike[left] != 0) & ~lfound
            lindex[hit] = left
            lfound |= hit
        if 0 <= right < C:
            hit = (spike[right] != 0) & ~rfound
            rindex[hit] = right
            rfound |= hit

    interval = rindex - lindex
    tfi = 1.0 / interval
    return tfi
177
+
178
class MotionInference(nn.Module):
    """Fuse multi-window TFP features with a TFI feature around the middle frame.

    Args:
        n_frame: number of spike frames along the channel axis. The centre
            index is ``n_frame // 2``, and the widest window (27 frames)
            must fit inside it.
        bias: passed to every convolution.
        blind: if False, spatial and channel attention are applied to the
            features; if True the two feature maps are simply summed.

    ``forward`` returns ``(fusion_fea, tfi_label, tfp_label)``:

    * ``fusion_fea``: (N, 16, H, W) fused features;
    * ``tfi_label``: (N, 1, H, W) ``middleTFI`` map with window 12;
    * ``tfp_label``: (N, 1, H, W) mean of the 7 frames centred on the middle frame.

    Both labels live on the same device as the input.
    """

    def __init__(self, n_frame=41, bias=False, blind=False):
        super().__init__()
        self.middle = n_frame // 2
        # 1x1 convs over windows of 11, 19 and 27 frames centred on the
        # middle frame (presumably learned TFP / firing-rate estimates).
        self.conv0 = nn.Conv2d(5 * 2 + 1, 1, 1, bias=bias)
        self.conv1 = nn.Conv2d(9 * 2 + 1, 1, 1, bias=bias)
        self.conv2 = nn.Conv2d(13 * 2 + 1, 1, 1, bias=bias)
        self.tfpconv = Conv(in_channels=3, out_channels=16, bias=bias, blind=blind)
        self.tficonv = Conv(in_channels=1, out_channels=16, bias=bias, blind=blind)
        self.ChannelGate = ChannelGate(gate_channels=16, reduction_ratio=4)
        self.SpatialGate = SpatialGate(bias=bias, blind=blind)
        self.blind = blind

    def forward(self, x):
        N, _, _, _ = x.shape
        mid = self.middle

        tfi_maps = []
        for j in range(N):
            # middleTFI works on CPU/NumPy data; detach so inputs that
            # require grad can be converted.
            tfi = middleTFI(x[j].detach().cpu(), mid, window=12)
            tfi_maps.append(torch.tensor(tfi, dtype=torch.float32).unsqueeze_(dim=0))
        # Use x's device instead of a hard-coded .cuda(), which fails on
        # CPU-only hosts and picks the wrong GPU when x is elsewhere.
        tfi_label = torch.stack(tfi_maps, 0).to(x.device)
        tfp_label = x[:, mid - 3:mid + 3 + 1, :, :].mean(dim=1, keepdim=True)

        tfps = torch.cat([
            self.conv0(x[:, mid - 5:mid + 5 + 1, :, :]),
            self.conv1(x[:, mid - 9:mid + 9 + 1, :, :]),
            self.conv2(x[:, mid - 13:mid + 13 + 1, :, :]),
        ], dim=1)  # (N, 3, H, W)

        tfp_fea = self.tfpconv(tfps)
        tfi_fea = self.tficonv(tfi_label)

        if not self.blind:
            # The same SpatialGate (shared weights) is applied to both branches.
            tfp_fea = self.SpatialGate(tfp_fea)
            tfi_fea = self.SpatialGate(tfi_fea)
            fusion_fea = self.ChannelGate(tfp_fea + tfi_fea)
        else:
            fusion_fea = tfp_fea + tfi_fea
        return fusion_fea, tfi_label, tfp_label
223
+
224
+
225
class BSN(nn.Module):
    """Rotation-ensemble U-Net reconstructor conditioned on ``MotionInference`` features.

    Steps:

    1. Compute TFI/TFP labels on the raw input.
    2. Reflect-pad the input to a square.
    3. Stack its four rotations along the batch dim.
    4. Append 16 motion-feature channels and run ``Blind_UNet`` (96 output
       channels).
    5. Un-rotate into 4 * 96 = 384 channels.
    6. Reduce with three 1x1 convs and crop back to (H, W).

    ``forward`` returns ``(output, tfi_label, tfp_label)``. Reflect padding
    needs the padding to be smaller than the shorter side.
    """

    def __init__(self, n_channels=3, n_output=3, bias=False, blind=True, sigma_known=True):
        super().__init__()
        self.n_channels = n_channels
        self.c = n_channels
        self.n_output = n_output
        self.bias = bias
        self.blind = blind
        self.sigma_known = sigma_known
        self.rotate = rotate()
        self.unet = Blind_UNet(n_channels=n_channels + 16, bias=bias, blind=blind)
        self.shift = shift()
        self.unrotate = unrotate()
        self.nin_A = nn.Conv2d(384, 384, 1, bias=bias)
        self.nin_B = nn.Conv2d(384, 96, 1, bias=bias)
        self.nin_C = nn.Conv2d(96, n_output, 1, bias=bias)
        self.MotionInference = MotionInference(n_frame=41, bias=bias, blind=blind)

    def forward(self, x):
        _, _, H, W = x.shape
        # Labels come from the un-padded, un-rotated input.
        _, tfi_label, tfp_label = self.MotionInference(x)

        # Square the input so the four rotations can share one batch.
        if H > W:
            extra = H - W
            x = F.pad(x, [extra // 2, extra - extra // 2, 0, 0], mode='reflect')
        elif W > H:
            extra = W - H
            x = F.pad(x, [0, 0, extra // 2, extra - extra // 2], mode='reflect')

        stacked = self.rotate(x)
        motion_fea, _, _ = self.MotionInference(stacked)
        feats = self.unet(torch.cat([stacked, motion_fea], 1))
        if self.blind:
            feats = self.shift(feats)
        feats = self.unrotate(feats)  # (N, 384, S, S)

        out = F.leaky_relu_(self.nin_A(feats), negative_slope=0.1)
        out = F.leaky_relu_(self.nin_B(out), negative_slope=0.1)
        out = self.nin_C(out)

        # Remove the square padding again.
        if H > W:
            start = (H - W) // 2
            out = out[:, :, 0:H, start:start + W]
        elif W > H:
            start = (W - H) // 2
            out = out[:, :, start:start + H, 0:W]

        return out, tfi_label, tfp_label
276
+
277
class DoubleNet(nn.Module):
    """Thin wrapper around a non-blind ``BSN`` for 41-channel spike input.

    ``forward`` returns only the reconstruction and discards the TFI/TFP labels.
    """

    def __init__(self):
        super().__init__()
        self.nbsn = BSN(n_channels=41, n_output=1, blind=False)

    def forward(self, x):
        reconstruction, _tfi, _tfp = self.nbsn(x)
        return reconstruction
287
+
288
if __name__ == '__main__':
    # Smoke test: run a 2-sample, 41-frame, 40x40 batch through DoubleNet.
    # Fall back to the CPU when CUDA is unavailable instead of crashing.
    device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
    a = DoubleNet().to(device)
    print(a(torch.ones(2, 41, 40, 40, device=device)))
Binary file