spikezoo 0.1.2__py3-none-any.whl → 0.2__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (192) hide show
  1. spikezoo/__init__.py +13 -0
  2. spikezoo/archs/__pycache__/__init__.cpython-39.pyc +0 -0
  3. spikezoo/archs/base/__pycache__/nets.cpython-39.pyc +0 -0
  4. spikezoo/archs/base/nets.py +34 -0
  5. spikezoo/archs/bsf/README.md +92 -0
  6. spikezoo/archs/bsf/datasets/datasets.py +328 -0
  7. spikezoo/archs/bsf/datasets/ds_utils.py +64 -0
  8. spikezoo/archs/bsf/main.py +398 -0
  9. spikezoo/archs/bsf/metrics/psnr.py +22 -0
  10. spikezoo/archs/bsf/metrics/ssim.py +54 -0
  11. spikezoo/archs/bsf/models/bsf/__pycache__/align.cpython-39.pyc +0 -0
  12. spikezoo/archs/bsf/models/bsf/__pycache__/bsf.cpython-39.pyc +0 -0
  13. spikezoo/archs/bsf/models/bsf/__pycache__/rep.cpython-39.pyc +0 -0
  14. spikezoo/archs/bsf/models/bsf/align.py +154 -0
  15. spikezoo/archs/bsf/models/bsf/bsf.py +105 -0
  16. spikezoo/archs/bsf/models/bsf/dsft_convert.py +96 -0
  17. spikezoo/archs/bsf/models/bsf/rep.py +44 -0
  18. spikezoo/archs/bsf/models/get_model.py +7 -0
  19. spikezoo/archs/bsf/prepare_data/DSFT.py +62 -0
  20. spikezoo/archs/bsf/prepare_data/crop_dataset_train.py +135 -0
  21. spikezoo/archs/bsf/prepare_data/crop_dataset_val.py +139 -0
  22. spikezoo/archs/bsf/prepare_data/crop_train.sh +4 -0
  23. spikezoo/archs/bsf/prepare_data/crop_val.sh +4 -0
  24. spikezoo/archs/bsf/prepare_data/io_utils.py +64 -0
  25. spikezoo/archs/bsf/requirements.txt +9 -0
  26. spikezoo/archs/bsf/test.py +16 -0
  27. spikezoo/archs/bsf/utils.py +154 -0
  28. spikezoo/archs/spikeclip/__pycache__/nets.cpython-39.pyc +0 -0
  29. spikezoo/archs/spikeclip/nets.py +40 -0
  30. spikezoo/archs/spikeformer/CheckPoints/readme +1 -0
  31. spikezoo/archs/spikeformer/DataProcess/DataExtactor.py +60 -0
  32. spikezoo/archs/spikeformer/DataProcess/DataLoader.py +115 -0
  33. spikezoo/archs/spikeformer/DataProcess/LoadSpike.py +39 -0
  34. spikezoo/archs/spikeformer/EvalResults/readme +1 -0
  35. spikezoo/archs/spikeformer/LICENSE +21 -0
  36. spikezoo/archs/spikeformer/Metrics/Metrics.py +50 -0
  37. spikezoo/archs/spikeformer/Metrics/__init__.py +0 -0
  38. spikezoo/archs/spikeformer/Model/Loss.py +89 -0
  39. spikezoo/archs/spikeformer/Model/SpikeFormer.py +230 -0
  40. spikezoo/archs/spikeformer/Model/__init__.py +0 -0
  41. spikezoo/archs/spikeformer/Model/__pycache__/SpikeFormer.cpython-39.pyc +0 -0
  42. spikezoo/archs/spikeformer/Model/__pycache__/__init__.cpython-39.pyc +0 -0
  43. spikezoo/archs/spikeformer/README.md +30 -0
  44. spikezoo/archs/spikeformer/evaluate.py +87 -0
  45. spikezoo/archs/spikeformer/recon_real_data.py +97 -0
  46. spikezoo/archs/spikeformer/requirements.yml +95 -0
  47. spikezoo/archs/spikeformer/train.py +173 -0
  48. spikezoo/archs/spikeformer/utils.py +22 -0
  49. spikezoo/archs/spk2imgnet/.github/workflows/pylint.yml +23 -0
  50. spikezoo/archs/spk2imgnet/.gitignore +150 -0
  51. spikezoo/archs/spk2imgnet/DCNv2.py +135 -0
  52. spikezoo/archs/spk2imgnet/__pycache__/DCNv2.cpython-39.pyc +0 -0
  53. spikezoo/archs/spk2imgnet/__pycache__/align_arch.cpython-39.pyc +0 -0
  54. spikezoo/archs/spk2imgnet/__pycache__/nets.cpython-39.pyc +0 -0
  55. spikezoo/archs/spk2imgnet/align_arch.py +159 -0
  56. spikezoo/archs/spk2imgnet/dataset.py +144 -0
  57. spikezoo/archs/spk2imgnet/nets.py +230 -0
  58. spikezoo/archs/spk2imgnet/readme.md +86 -0
  59. spikezoo/archs/spk2imgnet/test_gen_imgseq.py +118 -0
  60. spikezoo/archs/spk2imgnet/train.py +189 -0
  61. spikezoo/archs/spk2imgnet/utils.py +64 -0
  62. spikezoo/archs/ssir/README.md +87 -0
  63. spikezoo/archs/ssir/configs/SSIR.yml +37 -0
  64. spikezoo/archs/ssir/configs/yml_parser.py +78 -0
  65. spikezoo/archs/ssir/datasets/dataset_sreds.py +170 -0
  66. spikezoo/archs/ssir/datasets/ds_utils.py +66 -0
  67. spikezoo/archs/ssir/losses.py +21 -0
  68. spikezoo/archs/ssir/main.py +326 -0
  69. spikezoo/archs/ssir/metrics/psnr.py +22 -0
  70. spikezoo/archs/ssir/metrics/ssim.py +54 -0
  71. spikezoo/archs/ssir/models/Vgg19.py +42 -0
  72. spikezoo/archs/ssir/models/__pycache__/layers.cpython-39.pyc +0 -0
  73. spikezoo/archs/ssir/models/__pycache__/networks.cpython-39.pyc +0 -0
  74. spikezoo/archs/ssir/models/layers.py +110 -0
  75. spikezoo/archs/ssir/models/networks.py +61 -0
  76. spikezoo/archs/ssir/requirements.txt +8 -0
  77. spikezoo/archs/ssir/shells/eval_SREDS.sh +6 -0
  78. spikezoo/archs/ssir/shells/train_SSIR.sh +12 -0
  79. spikezoo/archs/ssir/test.py +3 -0
  80. spikezoo/archs/ssir/utils.py +154 -0
  81. spikezoo/archs/ssml/__pycache__/cbam.cpython-39.pyc +0 -0
  82. spikezoo/archs/ssml/__pycache__/model.cpython-39.pyc +0 -0
  83. spikezoo/archs/ssml/cbam.py +224 -0
  84. spikezoo/archs/ssml/model.py +290 -0
  85. spikezoo/archs/ssml/res.png +0 -0
  86. spikezoo/archs/ssml/test.py +67 -0
  87. spikezoo/archs/stir/.git-credentials +0 -0
  88. spikezoo/archs/stir/README.md +65 -0
  89. spikezoo/archs/stir/ckpt_outputs/Descriptions.txt +1 -0
  90. spikezoo/archs/stir/configs/STIR.yml +37 -0
  91. spikezoo/archs/stir/configs/utils.py +155 -0
  92. spikezoo/archs/stir/configs/yml_parser.py +78 -0
  93. spikezoo/archs/stir/datasets/dataset_sreds.py +180 -0
  94. spikezoo/archs/stir/datasets/ds_utils.py +66 -0
  95. spikezoo/archs/stir/eval_SREDS.sh +5 -0
  96. spikezoo/archs/stir/main.py +397 -0
  97. spikezoo/archs/stir/metrics/losses.py +219 -0
  98. spikezoo/archs/stir/metrics/psnr.py +22 -0
  99. spikezoo/archs/stir/metrics/ssim.py +54 -0
  100. spikezoo/archs/stir/models/Vgg19.py +42 -0
  101. spikezoo/archs/stir/models/__pycache__/networks_STIR.cpython-39.pyc +0 -0
  102. spikezoo/archs/stir/models/__pycache__/submodules.cpython-39.pyc +0 -0
  103. spikezoo/archs/stir/models/__pycache__/transformer_new.cpython-39.pyc +0 -0
  104. spikezoo/archs/stir/models/networks_STIR.py +361 -0
  105. spikezoo/archs/stir/models/submodules.py +86 -0
  106. spikezoo/archs/stir/models/transformer_new.py +151 -0
  107. spikezoo/archs/stir/package_core/build/lib/package_core/__init__.py +0 -0
  108. spikezoo/archs/stir/package_core/build/lib/package_core/convertions.py +721 -0
  109. spikezoo/archs/stir/package_core/build/lib/package_core/disp_netS.py +133 -0
  110. spikezoo/archs/stir/package_core/build/lib/package_core/flow_utils.py +167 -0
  111. spikezoo/archs/stir/package_core/build/lib/package_core/generic_train_test.py +76 -0
  112. spikezoo/archs/stir/package_core/build/lib/package_core/geometry.py +458 -0
  113. spikezoo/archs/stir/package_core/build/lib/package_core/image_proc.py +183 -0
  114. spikezoo/archs/stir/package_core/build/lib/package_core/linalg.py +40 -0
  115. spikezoo/archs/stir/package_core/build/lib/package_core/losses.py +198 -0
  116. spikezoo/archs/stir/package_core/build/lib/package_core/metrics.py +51 -0
  117. spikezoo/archs/stir/package_core/build/lib/package_core/model_base.py +53 -0
  118. spikezoo/archs/stir/package_core/build/lib/package_core/net_basics.py +100 -0
  119. spikezoo/archs/stir/package_core/build/lib/package_core/resnet.py +333 -0
  120. spikezoo/archs/stir/package_core/build/lib/package_core/transforms.py +123 -0
  121. spikezoo/archs/stir/package_core/build/lib/package_core/utils.py +72 -0
  122. spikezoo/archs/stir/package_core/dist/package_core-0.0.0-py3.9.egg +0 -0
  123. spikezoo/archs/stir/package_core/package_core/__init__.py +0 -0
  124. spikezoo/archs/stir/package_core/package_core/__pycache__/__init__.cpython-39.pyc +0 -0
  125. spikezoo/archs/stir/package_core/package_core/__pycache__/net_basics.cpython-39.pyc +0 -0
  126. spikezoo/archs/stir/package_core/package_core/convertions.py +721 -0
  127. spikezoo/archs/stir/package_core/package_core/disp_netS.py +133 -0
  128. spikezoo/archs/stir/package_core/package_core/flow_utils.py +167 -0
  129. spikezoo/archs/stir/package_core/package_core/generic_train_test.py +76 -0
  130. spikezoo/archs/stir/package_core/package_core/geometry.py +458 -0
  131. spikezoo/archs/stir/package_core/package_core/image_proc.py +183 -0
  132. spikezoo/archs/stir/package_core/package_core/linalg.py +40 -0
  133. spikezoo/archs/stir/package_core/package_core/losses.py +198 -0
  134. spikezoo/archs/stir/package_core/package_core/metrics.py +51 -0
  135. spikezoo/archs/stir/package_core/package_core/model_base.py +53 -0
  136. spikezoo/archs/stir/package_core/package_core/net_basics.py +100 -0
  137. spikezoo/archs/stir/package_core/package_core/resnet.py +333 -0
  138. spikezoo/archs/stir/package_core/package_core/transforms.py +123 -0
  139. spikezoo/archs/stir/package_core/package_core/utils.py +72 -0
  140. spikezoo/archs/stir/package_core/package_core.egg-info/PKG-INFO +3 -0
  141. spikezoo/archs/stir/package_core/package_core.egg-info/SOURCES.txt +20 -0
  142. spikezoo/archs/stir/package_core/package_core.egg-info/dependency_links.txt +1 -0
  143. spikezoo/archs/stir/package_core/package_core.egg-info/top_level.txt +1 -0
  144. spikezoo/archs/stir/package_core/setup.py +5 -0
  145. spikezoo/archs/stir/requirements.txt +12 -0
  146. spikezoo/archs/stir/train_STIR.sh +9 -0
  147. spikezoo/archs/tfi/__pycache__/nets.cpython-39.pyc +0 -0
  148. spikezoo/archs/tfi/nets.py +43 -0
  149. spikezoo/archs/tfp/__pycache__/nets.cpython-39.pyc +0 -0
  150. spikezoo/archs/tfp/nets.py +13 -0
  151. spikezoo/archs/wgse/README.md +64 -0
  152. spikezoo/archs/wgse/__pycache__/dwtnets.cpython-39.pyc +0 -0
  153. spikezoo/archs/wgse/__pycache__/submodules.cpython-39.pyc +0 -0
  154. spikezoo/archs/wgse/dataset.py +59 -0
  155. spikezoo/archs/wgse/demo.png +0 -0
  156. spikezoo/archs/wgse/demo.py +83 -0
  157. spikezoo/archs/wgse/dwtnets.py +145 -0
  158. spikezoo/archs/wgse/eval.py +133 -0
  159. spikezoo/archs/wgse/logs/WGSE-Dwt1dNet-db8-5-ks3/log.txt +11 -0
  160. spikezoo/archs/wgse/submodules.py +68 -0
  161. spikezoo/archs/wgse/train.py +261 -0
  162. spikezoo/archs/wgse/transform.py +139 -0
  163. spikezoo/archs/wgse/utils.py +128 -0
  164. spikezoo/archs/wgse/weights/demo.png +0 -0
  165. spikezoo/data/base/test/gt/200_part1_key_id151.png +0 -0
  166. spikezoo/data/base/test/gt/200_part3_key_id151.png +0 -0
  167. spikezoo/data/base/test/gt/203_part1_key_id151.png +0 -0
  168. spikezoo/data/base/test/spike/200_part1_key_id151.dat +0 -0
  169. spikezoo/data/base/test/spike/200_part3_key_id151.dat +0 -0
  170. spikezoo/data/base/test/spike/203_part1_key_id151.dat +0 -0
  171. spikezoo/data/base/train/gt/203_part2_key_id151.png +0 -0
  172. spikezoo/data/base/train/gt/203_part3_key_id151.png +0 -0
  173. spikezoo/data/base/train/gt/203_part4_key_id151.png +0 -0
  174. spikezoo/data/base/train/spike/203_part2_key_id151.dat +0 -0
  175. spikezoo/data/base/train/spike/203_part3_key_id151.dat +0 -0
  176. spikezoo/data/base/train/spike/203_part4_key_id151.dat +0 -0
  177. spikezoo/datasets/base_dataset.py +2 -3
  178. spikezoo/metrics/__init__.py +1 -1
  179. spikezoo/models/base_model.py +1 -3
  180. spikezoo/pipeline/base_pipeline.py +7 -5
  181. spikezoo/pipeline/train_pipeline.py +1 -1
  182. spikezoo/utils/other_utils.py +16 -6
  183. spikezoo/utils/spike_utils.py +33 -29
  184. spikezoo/utils/vidar_loader.cpython-39-x86_64-linux-gnu.so +0 -0
  185. spikezoo-0.2.dist-info/METADATA +163 -0
  186. spikezoo-0.2.dist-info/RECORD +211 -0
  187. spikezoo/models/spcsnet_model.py +0 -19
  188. spikezoo-0.1.2.dist-info/METADATA +0 -39
  189. spikezoo-0.1.2.dist-info/RECORD +0 -36
  190. {spikezoo-0.1.2.dist-info → spikezoo-0.2.dist-info}/LICENSE.txt +0 -0
  191. {spikezoo-0.1.2.dist-info → spikezoo-0.2.dist-info}/WHEEL +0 -0
  192. {spikezoo-0.1.2.dist-info → spikezoo-0.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,16 @@
1
+ import numpy as np
2
+ import torch
3
+ # import cupy as cp
4
+
5
def compute_dsft_core(spike):
    """Compute, per pixel and time step, the interval between surrounding spikes.

    ``spike`` is indexed as ``(H, W, T)``. For each pixel and time ``t`` the
    result is ``next - prev``:
    - ``prev`` is the index of the last spike at or before ``t`` (0 if none);
    - ``next`` is the index of the first spike strictly after ``t`` (``T`` if none).
    Negative values are clipped to 0.

    NOTE(review): spike times are encoded as ``spike * t``. This presumably
    assumes a binary (0/1) ``spike`` tensor, and a spike at ``t == 0`` is
    indistinguishable from "no spike" — confirm this is intended.

    Args:
        spike: tensor of shape (H, W, T).

    Returns:
        Tensor of shape (H, W, T) on the same device as ``spike``.
    """
    H, W, T = spike.shape
    # Allocate helper tensors on the input's device instead of hard-coding
    # 'cuda', so CPU tensors and non-default GPUs work.
    device = spike.device
    time = spike * torch.arange(T, device=device).reshape(1, 1, T)
    # Index of the most recent spike at or before each t.
    l_idx, _ = time.cummax(dim=2)
    # Treat "no spike" as time T so it never wins the reverse cummin below.
    time[time == 0] = T
    r_idx, _ = torch.flip(time, [2]).cummin(dim=2)
    r_idx = torch.flip(r_idx, [2])
    # Shift left by one so r_idx[t] is the first spike strictly after t;
    # the final step has no later spike and gets T.
    r_idx = torch.cat([r_idx[:, :, 1:], torch.ones([H, W, 1], device=device) * T], dim=2)
    res = r_idx - l_idx

    res = torch.clip(res, 0)
    return res
@@ -0,0 +1,154 @@
1
+ import numpy as np
2
+ import torch
3
+ import torch.nn.functional as F
4
+ import os
5
+ import os.path as osp
6
+ import random
7
+ import cv2
8
+
9
def set_seeds(_seed_):
    """Seed every RNG used in training and make cuDNN deterministic.

    Seeds Python's ``random``, NumPy's global RNG, PyTorch's RNG and every
    CUDA device. It then turns cuDNN autotuning off and deterministic
    kernels on.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,  # seeds the RNG for all devices (both CPU and CUDA)
        torch.cuda.manual_seed_all,
    )
    for seed_fn in seeders:
        seed_fn(_seed_)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False
    # Stricter determinism, currently disabled:
    # os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
    # (":16:8" may limit overall performance; ":4096:8" increases the library
    # footprint in GPU memory by approximately 24MiB.)
    # torch.use_deterministic_algorithms(True)
20
+
21
+
22
def make_dir(path):
    """Create ``path`` (and any missing parents) if it does not already exist.

    An existing path (directory or not) is left untouched and the call
    returns quietly. ``exist_ok=True`` closes the gap between the existence
    check and the creation. Without it, concurrent processes creating the
    same directory could both pass the check, and the slower one would
    raise FileExistsError.
    """
    if not osp.exists(path):
        os.makedirs(path, exist_ok=True)
    return
26
+
27
+
28
def add_args_to_cfg(cfg, args, args_list):
    """Copy the named attributes of ``args`` into ``cfg['train']``.

    Args:
        cfg: dict with a ``'train'`` sub-dict; it is modified in place.
        args: object holding the values (presumably an argparse namespace).
        args_list: attribute names to copy. Dotted names such as ``'a.b'``
            resolve to nested attributes; the key stored is the name as given.

    Returns:
        The same ``cfg`` object.
    """
    # attrgetter instead of eval(): the same lookup for plain and dotted
    # names, but a string in args_list can no longer execute arbitrary code.
    from operator import attrgetter

    for aa in args_list:
        cfg['train'][aa] = attrgetter(aa)(args)
    return cfg
32
+
33
+
34
+ # class AverageMeter(object):
35
+ # """Computes and stores the average and current value"""
36
+ # def __init__(self, precision=3):
37
+ # self.precision = precision
38
+ # self.reset()
39
+
40
+ # def reset(self):
41
+ # self.val = 0
42
+ # self.avg = 0
43
+ # self.sum = 0
44
+ # self.count = 0
45
+
46
+ # def update(self, val, n=1):
47
+ # self.val = val
48
+ # self.sum += val * n
49
+ # self.count += n
50
+ # self.avg = self.sum / self.count
51
+
52
+ # def __repr__(self):
53
+ # return '{:.{}f} ({:.{}f})'.format(self.val, self.precision, self.avg, self.precision)
54
+
55
+
56
class AverageMeter(object):
    """Compute and store the current value and running average of one or more meters.

    Args:
        i: number of meters tracked in parallel.
        precision: number of decimals shown by ``repr``.
        names: optional labels, one per meter.
    """

    def __init__(self, i=1, precision=3, names=None):
        self.meters = i
        self.precision = precision
        self.reset(self.meters)
        self.names = names
        if names is not None:
            assert self.meters == len(self.names)
        else:
            self.names = [''] * self.meters

    def reset(self, i=None):
        """Zero all statistics for ``i`` meters.

        ``i`` defaults to the configured meter count, so ``meter.reset()``
        works without arguments.
        """
        if i is None:
            i = self.meters
        self.val = [0] * i
        self.avg = [0] * i
        self.sum = [0] * i
        self.count = [0] * i

    def update(self, val, n=1):
        """Record new value(s).

        Args:
            val: a scalar (single meter) or a list with one value per meter.
            n: weight of each value; a scalar applies to every meter.
        """
        if not isinstance(val, list):
            val = [val]
        if not isinstance(n, list):
            n = [n] * self.meters
        assert (len(val) == self.meters and len(n) == self.meters)
        for i in range(self.meters):
            self.count[i] += n[i]
        for i, v in enumerate(val):
            self.val[i] = v
            self.sum[i] += v * n[i]
            self.avg[i] = self.sum[i] / self.count[i]

    def __repr__(self):
        # "<name> <current> (<average>)" for each meter, space-separated.
        return ' '.join(
            '{} {:.{}f} ({:.{}f})'.format(n, v, self.precision, a, self.precision)
            for n, v, a in zip(self.names, self.val, self.avg))
97
+
98
+
99
def normalize_image_torch(image, percentile_lower=1, percentile_upper=99):
    """Robustly rescale each (batch, channel) plane of ``image`` to [0, 1].

    Values at or below the ``percentile_lower``-th percentile of their plane
    map to 0. Values at or above the ``percentile_upper``-th percentile map
    to 1. Everything in between is scaled linearly.

    Args:
        image: floating-point tensor of shape (B, C, H, W).
        percentile_lower: lower percentile in [0, 100] (previously ignored
            and fixed at 1).
        percentile_upper: upper percentile in [0, 100] (previously ignored
            and fixed at 99).

    Returns:
        Tensor of the same shape, clipped to [0, 1].
    """
    b, c, h, w = image.shape
    flat = image.reshape([b, c, h * w])
    mini = torch.quantile(flat, percentile_lower / 100.0, dim=2, keepdim=True).unsqueeze_(dim=3)
    maxi = torch.quantile(flat, percentile_upper / 100.0, dim=2, keepdim=True).unsqueeze_(dim=3)
    # The 1e-5 avoids division by zero on constant planes.
    return torch.clip((image - mini) / (maxi - mini + 1e-5), 0, 1)
107
+
108
def normalize_image_torch2(image):
    """Clamp ``image`` to the [0, 1] range (no rescaling)."""
    return torch.clamp(image, 0, 1)
110
+
111
+ # --------------------------------------------
112
+ # Torch to Numpy 0~255
113
+ # --------------------------------------------
114
def torch2numpy255(im):
    """Convert ``im[0, 0]`` (first image, first channel) to a uint8 NumPy array.

    Values are scaled by 255, clipped to [0, 255] and then truncated to
    uint8. Without the clip, inputs outside [0, 1] would wrap around
    instead of saturating.
    """
    im = im[0, 0].detach().cpu().numpy()
    im = np.clip(im * 255, 0, 255).astype(np.uint8)
    return im
118
+
119
def torch2torch255(im):
    """Scale ``im`` from the [0, 1] range to [0, 255], keeping it a tensor."""
    return 255.0 * im
121
+
122
class InputPadder:
    """Pad images so that height and width become multiples of ``padsize``, then crop back.

    Width padding is split between left and right, with the odd column going
    to the right. Height padding is added only at the bottom. Padding uses
    replicate mode.
    """

    def __init__(self, dims, padsize=16):
        self.ht, self.wd = dims[-2:]
        extra_h = -self.ht % padsize
        extra_w = -self.wd % padsize
        left = extra_w // 2
        # Order follows F.pad: (left, right, top, bottom).
        self._pad = [left, extra_w - left, 0, extra_h]

    def pad(self, *inputs):
        """Return a list with every input padded by the same amounts."""
        padded = []
        for tensor in inputs:
            padded.append(F.pad(tensor, self._pad, mode='replicate'))
        return padded

    def unpad(self, x):
        """Crop ``x`` back to the size it had before ``pad``."""
        left, right, top, bottom = self._pad
        h, w = x.shape[-2:]
        return x[..., top:h - bottom, left:w - right]
137
+
138
+
139
+
140
def vis_img(vis_path: str, img: torch.Tensor, vis_name: str = 'vis'):
    """Tile channel 0 of a batch of images into a 4-row grid and write it as a PNG.

    Each row holds ``img.shape[0] // 4`` images. Trailing images that do not
    fill a complete row are not drawn. Pixel values are multiplied by 255
    and cast to uint8 (no clipping). The file is written to
    ``<vis_path>/<vis_name>.png``.
    """
    per_row = img.shape[0] // 4
    tiles = []
    for idx in range(4 * per_row):
        tile = img[idx, 0].detach().cpu().numpy() * 255
        tiles.append(tile.astype(np.uint8))
    rows = [np.concatenate(tiles[r * per_row:(r + 1) * per_row], axis=1) for r in range(4)]
    grid = np.concatenate(rows, axis=0)
    cv2.imwrite(osp.join(vis_path, vis_name + '.png'), grid)
    return
@@ -0,0 +1,40 @@
1
+ import torch.nn as nn
2
+ import torch
3
+
4
def conv_layer(inDim, outDim, ks, s, p, norm_layer='none'):
    """Build a ``Conv2d -> [norm] -> ReLU`` block.

    Args:
        inDim, outDim: input/output channel counts.
        ks, s, p: kernel size, stride and padding of the convolution.
        norm_layer: normalisation inserted between the conv and the ReLU.
            - 'none': no normalisation.
            - 'instance': InstanceNorm2d, no affine params, no running stats.
            - 'batch': BatchNorm2d, momentum 0.1.
    """
    layers = [nn.Conv2d(inDim, outDim, kernel_size=ks, stride=s, padding=p)]
    activation = nn.ReLU(True)
    assert norm_layer in ('batch', 'instance', 'none')
    if norm_layer == 'instance':
        layers.append(nn.InstanceNorm2d(outDim, affine=False, track_running_stats=False))
    elif norm_layer == 'batch':
        layers.append(nn.BatchNorm2d(outDim, momentum=0.1, affine=True, track_running_stats=True))
    layers.append(activation)
    return nn.Sequential(*layers)
19
+
20
def LRN(inDim=50, outDim=1, norm='none'):
    """Build a plain 5-layer CNN mapping ``inDim`` channels to ``outDim`` channels.

    Channel widths are inDim -> 64 -> 128 -> 64 -> 16 -> outDim, all with
    3x3 kernels, stride 1 and padding 1. The first block never uses
    normalisation. The next three use ``norm``. The final conv has no
    activation.
    """
    specs = [
        (inDim, 64, 'none'),
        (64, 128, norm),
        (128, 64, norm),
        (64, 16, norm),
    ]
    blocks = [conv_layer(c_in, c_out, 3, 1, 1, nl) for c_in, c_out, nl in specs]
    return nn.Sequential(*blocks, nn.Conv2d(16, outDim, 3, 1, 1))
28
+
29
+
30
if __name__ == "__main__":
    # thop is only needed for this profiling demo. Importing it here keeps
    # the module importable when thop is not installed.
    from thop import profile

    net = LRN()
    total = sum(p.numel() for p in net.parameters())
    spike = torch.zeros((1, 50, 250, 400))
    # NOTE(review): thop's profile() is generally understood to count MACs;
    # the "FLOPs" label below may be off by ~2x — confirm.
    flops, _ = profile(net, inputs=(spike,))
    re_msg = (
        "Total params: %.4fM" % (total / 1e6),
        "FLOPs=" + str(flops / 1e9) + '{}'.format("G"),
    )
    print(re_msg)
@@ -0,0 +1 @@
1
+ This folder is for saving trained models.
@@ -0,0 +1,60 @@
1
+ import os
2
+ import numpy as np
3
+
4
+
5
class DataExtractor():
    """List the files of one dataset split: ``train``, ``valid`` or ``test``.

    The split directory is ``<dataPath>/<split>``. ``GetData`` returns the
    full paths of its entries, sorted by file name.
    """

    _SPLITS = ("train", "valid", "test")

    def __init__(self, dataPath='', type='train'):
        self.type = type
        self.rootPath = dataPath

    def GetData(self):
        """Return the sorted entry paths of the configured split.

        Raises:
            ValueError: if ``type`` is not one of train/valid/test. This
                previously returned None silently.
        """
        if self.type not in self._SPLITS:
            raise ValueError(
                "unknown data type {!r}; expected one of {}".format(self.type, self._SPLITS))
        return self._list_split(self.type)

    def _list_split(self, split):
        """Return full paths of the entries in ``<rootPath>/<split>``, sorted by name."""
        root = os.path.join(self.rootPath, split)
        return [os.path.join(root, name) for name in sorted(os.listdir(root))]
60
+
@@ -0,0 +1,115 @@
1
+ import os
2
+ import torch
3
+ # from torchvision import transforms
4
+ from torch.utils import data
5
+ import numpy as np
6
+ from PIL import Image
7
+ import cv2
8
+ import random
9
+
10
+
11
+ from DataProcess.DataExtactor import DataExtractor
12
+ from DataProcess.LoadSpike import LoadSpike, load_spike_raw
13
+
14
class Dataset(data.Dataset):
    """Map-style dataset yielding a centred spike window and the centre GT frame.

    Each item is built as follows:
    1. Load one file via ``LoadSpike``.
    2. Slice ``2 * spikeRadius + 1`` spike frames around the centre of the
       spike sequence.
    3. Pick the centre frame of the ground-truth sequence.
    4. Rescale both to [-1, 1] and return them as float tensors.
    """

    def __init__(self, pathList, dataType, spikeRadius):
        self.pathList = pathList
        self.dataType = dataType
        self.spikeRadius = spikeRadius

        # Candidate numbers of 90-degree rotations for augmentation. Only the
        # (currently disabled) rotation code in GetItem would use them.
        if self.dataType == "train":
            self.choice = [0, 1, 2, 3]
        else:
            self.choice = [0]

    def __getitem__(self, index):
        spSeq, gtFrames = self.GetItem(index)
        return spSeq, gtFrames

    def __len__(self):
        return len(self.pathList)

    def GetItem(self, index):
        """Load item ``index`` and return ``(spike window, GT frame)`` tensors.

        Raises:
            ValueError: if the spike sequence is too short for a centred
                window of ``2 * spikeRadius + 1`` frames.
        """
        path = self.pathList[index]
        spSeq, gtFrames = LoadSpike(path)

        spLen, _, _ = spSeq.shape
        gtLen, _, _ = gtFrames.shape
        spCenter = spLen // 2
        gtCenter = gtLen // 2

        spLeft = spCenter - self.spikeRadius
        spRight = spCenter + self.spikeRadius + 1
        # Python slicing would silently wrap a negative left index and
        # truncate an oversized right index, yielding a wrong window.
        if spLeft < 0 or spRight > spLen:
            raise ValueError(
                "spikeRadius={} needs {} spike frames around index {}, but {} has only {}".format(
                    self.spikeRadius, 2 * self.spikeRadius + 1, spCenter, path, spLen))
        spSeq = spSeq[spLeft:spRight]

        gtFrame = gtFrames[gtCenter]

        # Zero-pad 3 rows above and below the spike frames (height axis only;
        # the GT frame is not padded).
        spSeq = np.pad(spSeq, ((0, 0), (3, 3), (0, 0)), mode='constant')
        # Map spikes {0, 1} -> {-1, 1} and 8-bit GT [0, 255] -> [-1, 1].
        spSeq = spSeq.astype(float) * 2 - 1

        gtFrame = gtFrame.astype(float) / 255. * 2.0 - 1.

        spSeq = torch.FloatTensor(spSeq)
        gtFrame = torch.FloatTensor(gtFrame)

        # Random-rotation augmentation (disabled). It would rotate the spike
        # window and GT frame by a random multiple of 90 degrees drawn from
        # self.choice:
        # choice = random.choice(self.choice)
        # spSeq = torch.rot90(spSeq, choice, dims=(1,2))
        # gtFrame =torch.rot90(gtFrame, choice, dims=(1,2))
        return spSeq, gtFrame
72
+
73
+
74
+
75
+
76
+
77
class DataContainer():
    """Resolve the file list of one dataset split and build a DataLoader over it.

    Args:
        dataPath: dataset root passed to DataExtractor.
        dataType: split name ('train', 'valid' or 'test'); only 'train' is shuffled.
        spikeRadius: half-width of the spike window passed to Dataset.
        batchSize: DataLoader batch size.
        numWorks: DataLoader worker count.
    """

    def __init__(self, dataPath='', dataType='train',
                 spikeRadius=16, batchSize=128, numWorks=0):
        self.dataPath = dataPath
        self.dataType = dataType
        self.spikeRadius = spikeRadius
        self.batchSize = batchSize
        self.numWorks = numWorks

        self.__GetData()

    def __GetData(self):
        # Store the split's file paths as produced by DataExtractor.GetData().
        self.pathList = DataExtractor(dataPath=self.dataPath, type=self.dataType).GetData()

    def GetLoader(self):
        """Return a DataLoader over this split; shuffled only for 'train'."""
        is_train = self.dataType == "train"
        return data.DataLoader(
            Dataset(self.pathList, self.dataType, self.spikeRadius),
            batch_size=self.batchSize,
            shuffle=is_train,
            num_workers=self.numWorks,
            pin_memory=False,
        )
109
+
110
if __name__ == "__main__":
    # No standalone entry point: this module only defines Dataset and
    # DataContainer for import by other scripts.
    pass
113
+
114
+
115
+
@@ -0,0 +1,39 @@
1
+ import numpy as np
2
+
3
def load_spike_numpy(path: str) -> (np.ndarray, np.ndarray):
    '''
    Load a spike sequence and its tag from a prepacked `.npz` file.

    The file must contain the arrays `seq` (bit-packed frames), `tag` and
    `length`. Frame `i` is bit `i % 8` of `seq[i // 8]` (least-significant
    bit first). The returned sequence has shape (`length`, `height`,
    `width`) and the tag has shape (`height`, `width`).
    '''
    # Context manager closes the NpzFile, which otherwise keeps the
    # underlying file handle open after returning.
    with np.load(path) as data:
        seq, tag, length = data['seq'], data['tag'], int(data['length'])
    seq = np.array([(seq[i // 8] >> (i & 7)) & 1 for i in range(length)])
    return seq, tag
13
+
14
def LoadSpike(path: str) -> (np.ndarray, np.ndarray):
    '''
    Load a spike sequence and the corresponding ground-truth frame sequence.

    The `.npz` file must contain `spSeq` (bit-packed spike frames, frame `i`
    is bit `i % 8` of `spSeq[i // 8]`), `gt` and `length`.
    Returns:
        spSeq: ndarray of shape ('sequence length', 'height', 'width')
        gtFrames: the stored `gt` array, shape ('number of GT frames', 'height', 'width')
    '''
    # Context manager closes the NpzFile, which otherwise keeps the
    # underlying file handle open after returning.
    with np.load(path) as data:
        spSeq, gtFrames, length = data['spSeq'], data['gt'], int(data['length'])
    spSeq = np.array([(spSeq[i // 8] >> (i & 7)) & 1 for i in range(length)])
    return spSeq, gtFrames
25
+
26
def load_spike_raw(path: str, width=400, height=250) -> np.ndarray:
    '''
    Load bit-compact raw spike data into an ndarray of shape
    (`sequence length`, `height`, `width`).

    Bit `i` of byte `k` (least-significant bit first) is pixel `8 * k + i`
    of the flattened stream. Each frame is flipped vertically after
    reshaping. Trailing bits that do not fill a whole frame are ignored.
    '''
    with open(path, 'rb') as f:
        fbytes = f.read()
    frame_size = width * height
    fnum = (len(fbytes) * 8) // frame_size  # number of complete frames
    raw = np.frombuffer(fbytes, dtype=np.uint8)
    # unpackbits replaces the removed `np.bool` alias (AttributeError on
    # NumPy >= 1.24) and the manual per-bit mask/transpose; slicing drops
    # a trailing partial frame that previously made reshape() fail.
    bits = np.unpackbits(raw, bitorder='little')[: fnum * frame_size]
    frames = bits.reshape(fnum, height, width)
    frames = np.flip(frames, 1)
    return frames
@@ -0,0 +1 @@
1
+ This folder is for saving images reconstructed from the validation/test sets.
@@ -0,0 +1,21 @@
1
+ MIT License
2
+
3
+ Copyright (c) 2022 YangChenUcas
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
@@ -0,0 +1,50 @@
1
+ import numpy as np
2
+ from skimage import metrics
3
+
4
class Metrics():
    """Hold "best" PSNR/SSIM/NIQE values and compute batch-averaged PSNR/SSIM.

    PSNR and SSIM are computed with the module-level ``metrics`` name
    (``skimage.metrics``).
    """

    def __init__(self):
        self.best_psnr = 0.
        self.best_ssim = 0.
        self.best_niqe = 0.

    def Update(self, psnr=0., ssim=0., niqe=0.):
        """Overwrite the stored values; no comparison with previous values is made."""
        self.best_psnr = psnr
        self.best_ssim = ssim
        self.best_niqe = niqe

    def GetBestMetrics(self):
        """Return ``(best_psnr, best_ssim, best_niqe)``."""
        return self.best_psnr, self.best_ssim, self.best_niqe

    def Cal_PSNR(self, preImgs, gtImgs, verbose=True):  # shape: [B, H, W]
        """Return the mean PSNR over a batch of 2-D images.

        Args:
            preImgs: predicted images, indexed as [B, H, W].
            gtImgs: ground-truth images, same layout.
            verbose: print "<1-based index> <psnr>" per image (previous,
                always-on behaviour; pass False to silence).
        """
        B, _, _ = preImgs.shape
        total_psnr = 0.
        for i, (pre, gt) in enumerate(zip(preImgs, gtImgs)):
            # Computed once; it was previously computed a second time for the print.
            psnr = metrics.peak_signal_noise_ratio(gt, pre)
            if verbose:
                print(i + 1, psnr)
            total_psnr += psnr

        avg_psnr = total_psnr / B

        return avg_psnr

    def Cal_SSIM(self, preImgs, gtImgs):  # shape: [B, H, W]
        """Return the mean SSIM over a batch of 2-D images (same layout as Cal_PSNR).

        NOTE(review): recent scikit-image releases require ``data_range`` for
        floating-point inputs — confirm the dtype of images passed here.
        """
        B, _, _ = preImgs.shape
        total_ssim = 0.
        for pre, gt in zip(preImgs, gtImgs):
            total_ssim += metrics.structural_similarity(pre, gt)

        avg_ssim = total_ssim / B

        return avg_ssim
42
+
43
+
44
if __name__ == "__main__":
    # Smoke test on random data. The instance is deliberately not named
    # `metrics`: that would shadow the skimage `metrics` module Cal_PSNR uses.
    # The previous call to Cal_NIQE referred to a method that does not exist.
    a = np.random.random((2, 256, 256))
    b = np.random.random((2, 256, 256))
    evaluator = Metrics()

    print(evaluator.Cal_PSNR(a, b))
File without changes
@@ -0,0 +1,89 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
class CharbonnierLoss(nn.Module):
    """Charbonnier loss, a smooth L1 surrogate: ``mean(sqrt((x - y)^2 + eps^2))``."""

    def __init__(self, eps=1e-3):
        super().__init__()
        self.eps = eps

    def forward(self, x, y):
        residual = x - y
        smoothed = torch.sqrt(residual * residual + self.eps * self.eps)
        return smoothed.mean()
17
+
18
class EdgeLoss(nn.Module):
    """Charbonnier distance between Laplacian-pyramid-style residuals of two images.

    Inputs are repeated along dim 1 to three channels. The Gaussian kernel
    has shape (1, 3, 5, 5), so this only works for single-channel inputs
    of shape (B, 1, H, W).

    NOTE(review): the kernel has one output channel that *sums* three input
    channels. Each blur therefore scales by 3, and the low-pass image is
    about 9x larger than in a standard Laplacian-pyramid step. A
    per-channel (depthwise, ``groups=``) convolution would avoid this.
    Changing it alters loss values, so confirm before "fixing".
    """

    def __init__(self):
        super(EdgeLoss, self).__init__()
        k = torch.Tensor([[.05, .25, .4, .25, .05]])
        # Outer product -> 5x5 kernel, shaped (1, 3, 5, 5). The repeat over
        # input channels was added when adapting the loss to 1-channel images.
        self.kernel = torch.matmul(k.t(), k).unsqueeze(0).unsqueeze(0).repeat(1, 3, 1, 1)
        if torch.cuda.is_available():
            self.kernel = self.kernel.cuda()
        self.loss = CharbonnierLoss()

    def conv_gauss(self, img):
        """Blur ``img`` (3 input channels) with replicate padding; returns 1 channel."""
        _, _, kw, kh = self.kernel.shape
        img = F.pad(img, (kw // 2, kh // 2, kw // 2, kh // 2), mode='replicate')
        # The kernel is a plain attribute, not a registered buffer, so it does
        # not follow .to()/.cuda(). Match the input's device and dtype here.
        kernel = self.kernel.to(device=img.device, dtype=img.dtype)
        return F.conv2d(img, kernel)

    def laplacian_kernel(self, current):
        """Return ``current`` minus a blurred, down/up-sampled copy of itself."""
        filtered = self.conv_gauss(current)  # filter
        down = filtered[:, :, ::2, ::2]  # downsample
        new_filter = torch.zeros_like(filtered)
        new_filter[:, :, ::2, ::2] = down * 4  # upsample (zero insertion)
        # conv_gauss expects 3 input channels, hence the repeat; without it
        # the convolution fails on the 1-channel `filtered` tensor.
        filtered = self.conv_gauss(new_filter.repeat(1, 3, 1, 1))  # filter
        diff = current - filtered
        return diff

    def forward(self, x, y):
        y = y.repeat(1, 3, 1, 1)
        x = x.repeat(1, 3, 1, 1)
        loss = self.loss(self.laplacian_kernel(x), self.laplacian_kernel(y))
        return loss
54
+
55
+
56
class VGGLoss4(nn.Module):
    """Perceptual loss: MSE between frozen VGG-style features of two images.

    The feature extractor is five 3x3 convs with two max-pools. Its weights
    are loaded from ``path``, a state dict with ``features.*`` keys, and
    frozen. Inputs are repeated along dim 1 to 3 channels, so 1-channel
    images are expected.
    """

    def __init__(self, path: str):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(128, 256, 3, 1, 1),
            nn.ReLU(inplace=True),
        )
        # map_location='cpu' lets a checkpoint saved on GPU load on any
        # machine. load_state_dict copies the values into this module's
        # existing parameters either way.
        self.load_state_dict(torch.load(path, map_location='cpu'))
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, real_y, fake_y):
        real_y = real_y.repeat((1, 3, 1, 1))
        fake_y = fake_y.repeat((1, 3, 1, 1))
        # Only the target features skip autograd. The prediction's features
        # must keep their graph, otherwise the loss has no gradient w.r.t.
        # fake_y and cannot train the generator.
        with torch.no_grad():
            real_f = self.features(real_y)
        fake_f = self.features(fake_y)
        return F.mse_loss(fake_f, real_f)
88
+
89
+