spikezoo 0.1.2__py3-none-any.whl → 0.2.1__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (192) hide show
  1. spikezoo/__init__.py +13 -0
  2. spikezoo/archs/__pycache__/__init__.cpython-39.pyc +0 -0
  3. spikezoo/archs/base/__pycache__/nets.cpython-39.pyc +0 -0
  4. spikezoo/archs/base/nets.py +34 -0
  5. spikezoo/archs/bsf/README.md +92 -0
  6. spikezoo/archs/bsf/datasets/datasets.py +328 -0
  7. spikezoo/archs/bsf/datasets/ds_utils.py +64 -0
  8. spikezoo/archs/bsf/main.py +398 -0
  9. spikezoo/archs/bsf/metrics/psnr.py +22 -0
  10. spikezoo/archs/bsf/metrics/ssim.py +54 -0
  11. spikezoo/archs/bsf/models/bsf/__pycache__/align.cpython-39.pyc +0 -0
  12. spikezoo/archs/bsf/models/bsf/__pycache__/bsf.cpython-39.pyc +0 -0
  13. spikezoo/archs/bsf/models/bsf/__pycache__/rep.cpython-39.pyc +0 -0
  14. spikezoo/archs/bsf/models/bsf/align.py +154 -0
  15. spikezoo/archs/bsf/models/bsf/bsf.py +105 -0
  16. spikezoo/archs/bsf/models/bsf/dsft_convert.py +96 -0
  17. spikezoo/archs/bsf/models/bsf/rep.py +44 -0
  18. spikezoo/archs/bsf/models/get_model.py +7 -0
  19. spikezoo/archs/bsf/prepare_data/DSFT.py +62 -0
  20. spikezoo/archs/bsf/prepare_data/crop_dataset_train.py +135 -0
  21. spikezoo/archs/bsf/prepare_data/crop_dataset_val.py +139 -0
  22. spikezoo/archs/bsf/prepare_data/crop_train.sh +4 -0
  23. spikezoo/archs/bsf/prepare_data/crop_val.sh +4 -0
  24. spikezoo/archs/bsf/prepare_data/io_utils.py +64 -0
  25. spikezoo/archs/bsf/requirements.txt +9 -0
  26. spikezoo/archs/bsf/test.py +16 -0
  27. spikezoo/archs/bsf/utils.py +154 -0
  28. spikezoo/archs/spikeclip/__pycache__/nets.cpython-39.pyc +0 -0
  29. spikezoo/archs/spikeclip/nets.py +40 -0
  30. spikezoo/archs/spikeformer/CheckPoints/readme +1 -0
  31. spikezoo/archs/spikeformer/DataProcess/DataExtactor.py +60 -0
  32. spikezoo/archs/spikeformer/DataProcess/DataLoader.py +115 -0
  33. spikezoo/archs/spikeformer/DataProcess/LoadSpike.py +39 -0
  34. spikezoo/archs/spikeformer/EvalResults/readme +1 -0
  35. spikezoo/archs/spikeformer/LICENSE +21 -0
  36. spikezoo/archs/spikeformer/Metrics/Metrics.py +50 -0
  37. spikezoo/archs/spikeformer/Metrics/__init__.py +0 -0
  38. spikezoo/archs/spikeformer/Model/Loss.py +89 -0
  39. spikezoo/archs/spikeformer/Model/SpikeFormer.py +230 -0
  40. spikezoo/archs/spikeformer/Model/__init__.py +0 -0
  41. spikezoo/archs/spikeformer/Model/__pycache__/SpikeFormer.cpython-39.pyc +0 -0
  42. spikezoo/archs/spikeformer/Model/__pycache__/__init__.cpython-39.pyc +0 -0
  43. spikezoo/archs/spikeformer/README.md +30 -0
  44. spikezoo/archs/spikeformer/evaluate.py +87 -0
  45. spikezoo/archs/spikeformer/recon_real_data.py +97 -0
  46. spikezoo/archs/spikeformer/requirements.yml +95 -0
  47. spikezoo/archs/spikeformer/train.py +173 -0
  48. spikezoo/archs/spikeformer/utils.py +22 -0
  49. spikezoo/archs/spk2imgnet/.github/workflows/pylint.yml +23 -0
  50. spikezoo/archs/spk2imgnet/.gitignore +150 -0
  51. spikezoo/archs/spk2imgnet/DCNv2.py +135 -0
  52. spikezoo/archs/spk2imgnet/__pycache__/DCNv2.cpython-39.pyc +0 -0
  53. spikezoo/archs/spk2imgnet/__pycache__/align_arch.cpython-39.pyc +0 -0
  54. spikezoo/archs/spk2imgnet/__pycache__/nets.cpython-39.pyc +0 -0
  55. spikezoo/archs/spk2imgnet/align_arch.py +159 -0
  56. spikezoo/archs/spk2imgnet/dataset.py +144 -0
  57. spikezoo/archs/spk2imgnet/nets.py +230 -0
  58. spikezoo/archs/spk2imgnet/readme.md +86 -0
  59. spikezoo/archs/spk2imgnet/test_gen_imgseq.py +118 -0
  60. spikezoo/archs/spk2imgnet/train.py +189 -0
  61. spikezoo/archs/spk2imgnet/utils.py +64 -0
  62. spikezoo/archs/ssir/README.md +87 -0
  63. spikezoo/archs/ssir/configs/SSIR.yml +37 -0
  64. spikezoo/archs/ssir/configs/yml_parser.py +78 -0
  65. spikezoo/archs/ssir/datasets/dataset_sreds.py +170 -0
  66. spikezoo/archs/ssir/datasets/ds_utils.py +66 -0
  67. spikezoo/archs/ssir/losses.py +21 -0
  68. spikezoo/archs/ssir/main.py +326 -0
  69. spikezoo/archs/ssir/metrics/psnr.py +22 -0
  70. spikezoo/archs/ssir/metrics/ssim.py +54 -0
  71. spikezoo/archs/ssir/models/Vgg19.py +42 -0
  72. spikezoo/archs/ssir/models/__pycache__/layers.cpython-39.pyc +0 -0
  73. spikezoo/archs/ssir/models/__pycache__/networks.cpython-39.pyc +0 -0
  74. spikezoo/archs/ssir/models/layers.py +110 -0
  75. spikezoo/archs/ssir/models/networks.py +61 -0
  76. spikezoo/archs/ssir/requirements.txt +8 -0
  77. spikezoo/archs/ssir/shells/eval_SREDS.sh +6 -0
  78. spikezoo/archs/ssir/shells/train_SSIR.sh +12 -0
  79. spikezoo/archs/ssir/test.py +3 -0
  80. spikezoo/archs/ssir/utils.py +154 -0
  81. spikezoo/archs/ssml/__pycache__/cbam.cpython-39.pyc +0 -0
  82. spikezoo/archs/ssml/__pycache__/model.cpython-39.pyc +0 -0
  83. spikezoo/archs/ssml/cbam.py +224 -0
  84. spikezoo/archs/ssml/model.py +290 -0
  85. spikezoo/archs/ssml/res.png +0 -0
  86. spikezoo/archs/ssml/test.py +67 -0
  87. spikezoo/archs/stir/.git-credentials +0 -0
  88. spikezoo/archs/stir/README.md +65 -0
  89. spikezoo/archs/stir/ckpt_outputs/Descriptions.txt +1 -0
  90. spikezoo/archs/stir/configs/STIR.yml +37 -0
  91. spikezoo/archs/stir/configs/utils.py +155 -0
  92. spikezoo/archs/stir/configs/yml_parser.py +78 -0
  93. spikezoo/archs/stir/datasets/dataset_sreds.py +180 -0
  94. spikezoo/archs/stir/datasets/ds_utils.py +66 -0
  95. spikezoo/archs/stir/eval_SREDS.sh +5 -0
  96. spikezoo/archs/stir/main.py +397 -0
  97. spikezoo/archs/stir/metrics/losses.py +219 -0
  98. spikezoo/archs/stir/metrics/psnr.py +22 -0
  99. spikezoo/archs/stir/metrics/ssim.py +54 -0
  100. spikezoo/archs/stir/models/Vgg19.py +42 -0
  101. spikezoo/archs/stir/models/__pycache__/networks_STIR.cpython-39.pyc +0 -0
  102. spikezoo/archs/stir/models/__pycache__/submodules.cpython-39.pyc +0 -0
  103. spikezoo/archs/stir/models/__pycache__/transformer_new.cpython-39.pyc +0 -0
  104. spikezoo/archs/stir/models/networks_STIR.py +361 -0
  105. spikezoo/archs/stir/models/submodules.py +86 -0
  106. spikezoo/archs/stir/models/transformer_new.py +151 -0
  107. spikezoo/archs/stir/package_core/build/lib/package_core/__init__.py +0 -0
  108. spikezoo/archs/stir/package_core/build/lib/package_core/convertions.py +721 -0
  109. spikezoo/archs/stir/package_core/build/lib/package_core/disp_netS.py +133 -0
  110. spikezoo/archs/stir/package_core/build/lib/package_core/flow_utils.py +167 -0
  111. spikezoo/archs/stir/package_core/build/lib/package_core/generic_train_test.py +76 -0
  112. spikezoo/archs/stir/package_core/build/lib/package_core/geometry.py +458 -0
  113. spikezoo/archs/stir/package_core/build/lib/package_core/image_proc.py +183 -0
  114. spikezoo/archs/stir/package_core/build/lib/package_core/linalg.py +40 -0
  115. spikezoo/archs/stir/package_core/build/lib/package_core/losses.py +198 -0
  116. spikezoo/archs/stir/package_core/build/lib/package_core/metrics.py +51 -0
  117. spikezoo/archs/stir/package_core/build/lib/package_core/model_base.py +53 -0
  118. spikezoo/archs/stir/package_core/build/lib/package_core/net_basics.py +100 -0
  119. spikezoo/archs/stir/package_core/build/lib/package_core/resnet.py +333 -0
  120. spikezoo/archs/stir/package_core/build/lib/package_core/transforms.py +123 -0
  121. spikezoo/archs/stir/package_core/build/lib/package_core/utils.py +72 -0
  122. spikezoo/archs/stir/package_core/dist/package_core-0.0.0-py3.9.egg +0 -0
  123. spikezoo/archs/stir/package_core/package_core/__init__.py +0 -0
  124. spikezoo/archs/stir/package_core/package_core/__pycache__/__init__.cpython-39.pyc +0 -0
  125. spikezoo/archs/stir/package_core/package_core/__pycache__/net_basics.cpython-39.pyc +0 -0
  126. spikezoo/archs/stir/package_core/package_core/convertions.py +721 -0
  127. spikezoo/archs/stir/package_core/package_core/disp_netS.py +133 -0
  128. spikezoo/archs/stir/package_core/package_core/flow_utils.py +167 -0
  129. spikezoo/archs/stir/package_core/package_core/generic_train_test.py +76 -0
  130. spikezoo/archs/stir/package_core/package_core/geometry.py +458 -0
  131. spikezoo/archs/stir/package_core/package_core/image_proc.py +183 -0
  132. spikezoo/archs/stir/package_core/package_core/linalg.py +40 -0
  133. spikezoo/archs/stir/package_core/package_core/losses.py +198 -0
  134. spikezoo/archs/stir/package_core/package_core/metrics.py +51 -0
  135. spikezoo/archs/stir/package_core/package_core/model_base.py +53 -0
  136. spikezoo/archs/stir/package_core/package_core/net_basics.py +100 -0
  137. spikezoo/archs/stir/package_core/package_core/resnet.py +333 -0
  138. spikezoo/archs/stir/package_core/package_core/transforms.py +123 -0
  139. spikezoo/archs/stir/package_core/package_core/utils.py +72 -0
  140. spikezoo/archs/stir/package_core/package_core.egg-info/PKG-INFO +3 -0
  141. spikezoo/archs/stir/package_core/package_core.egg-info/SOURCES.txt +20 -0
  142. spikezoo/archs/stir/package_core/package_core.egg-info/dependency_links.txt +1 -0
  143. spikezoo/archs/stir/package_core/package_core.egg-info/top_level.txt +1 -0
  144. spikezoo/archs/stir/package_core/setup.py +5 -0
  145. spikezoo/archs/stir/requirements.txt +12 -0
  146. spikezoo/archs/stir/train_STIR.sh +9 -0
  147. spikezoo/archs/tfi/__pycache__/nets.cpython-39.pyc +0 -0
  148. spikezoo/archs/tfi/nets.py +43 -0
  149. spikezoo/archs/tfp/__pycache__/nets.cpython-39.pyc +0 -0
  150. spikezoo/archs/tfp/nets.py +13 -0
  151. spikezoo/archs/wgse/README.md +64 -0
  152. spikezoo/archs/wgse/__pycache__/dwtnets.cpython-39.pyc +0 -0
  153. spikezoo/archs/wgse/__pycache__/submodules.cpython-39.pyc +0 -0
  154. spikezoo/archs/wgse/dataset.py +59 -0
  155. spikezoo/archs/wgse/demo.png +0 -0
  156. spikezoo/archs/wgse/demo.py +83 -0
  157. spikezoo/archs/wgse/dwtnets.py +145 -0
  158. spikezoo/archs/wgse/eval.py +133 -0
  159. spikezoo/archs/wgse/logs/WGSE-Dwt1dNet-db8-5-ks3/log.txt +11 -0
  160. spikezoo/archs/wgse/submodules.py +68 -0
  161. spikezoo/archs/wgse/train.py +261 -0
  162. spikezoo/archs/wgse/transform.py +139 -0
  163. spikezoo/archs/wgse/utils.py +128 -0
  164. spikezoo/archs/wgse/weights/demo.png +0 -0
  165. spikezoo/data/base/test/gt/200_part1_key_id151.png +0 -0
  166. spikezoo/data/base/test/gt/200_part3_key_id151.png +0 -0
  167. spikezoo/data/base/test/gt/203_part1_key_id151.png +0 -0
  168. spikezoo/data/base/test/spike/200_part1_key_id151.dat +0 -0
  169. spikezoo/data/base/test/spike/200_part3_key_id151.dat +0 -0
  170. spikezoo/data/base/test/spike/203_part1_key_id151.dat +0 -0
  171. spikezoo/data/base/train/gt/203_part2_key_id151.png +0 -0
  172. spikezoo/data/base/train/gt/203_part3_key_id151.png +0 -0
  173. spikezoo/data/base/train/gt/203_part4_key_id151.png +0 -0
  174. spikezoo/data/base/train/spike/203_part2_key_id151.dat +0 -0
  175. spikezoo/data/base/train/spike/203_part3_key_id151.dat +0 -0
  176. spikezoo/data/base/train/spike/203_part4_key_id151.dat +0 -0
  177. spikezoo/datasets/base_dataset.py +2 -3
  178. spikezoo/metrics/__init__.py +1 -1
  179. spikezoo/models/base_model.py +1 -3
  180. spikezoo/pipeline/base_pipeline.py +7 -5
  181. spikezoo/pipeline/train_pipeline.py +1 -1
  182. spikezoo/utils/other_utils.py +16 -6
  183. spikezoo/utils/spike_utils.py +33 -29
  184. spikezoo/utils/vidar_loader.cpython-39-x86_64-linux-gnu.so +0 -0
  185. spikezoo-0.2.1.dist-info/METADATA +167 -0
  186. spikezoo-0.2.1.dist-info/RECORD +211 -0
  187. spikezoo/models/spcsnet_model.py +0 -19
  188. spikezoo-0.1.2.dist-info/METADATA +0 -39
  189. spikezoo-0.1.2.dist-info/RECORD +0 -36
  190. {spikezoo-0.1.2.dist-info → spikezoo-0.2.1.dist-info}/LICENSE.txt +0 -0
  191. {spikezoo-0.1.2.dist-info → spikezoo-0.2.1.dist-info}/WHEEL +0 -0
  192. {spikezoo-0.1.2.dist-info → spikezoo-0.2.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,144 @@
1
+ import os
2
+ import os.path
3
+ from typing import List
4
+
5
+ import numpy as np
6
+ import random
7
+ import h5py
8
+ import torch
9
+ import cv2
10
+ import glob
11
+ import torch.utils.data as udata
12
+ from functools import partial
13
+
14
# Decode an unsigned little-endian byte string into a Python int,
# e.g. bytes2num(b"\x01\x02") == 0x0201.
bytes2num = partial(int.from_bytes, signed=False, byteorder="little")
15
+
16
+
17
def normalize(data):
    """Scale 8-bit intensity values (0..255) into the unit range by dividing by 255."""
    max_intensity = 255.0
    return data / max_intensity
19
+
20
+
21
def raw_to_spike(video_seq, h, w):
    """Unpack a bit-packed spike stream into binary spike frames.

    Each frame takes ``h * w // 8`` bytes. Pixel ``p`` (row-major index) is
    bit ``p % 8`` (least significant bit first) of byte ``p // 8`` within the
    frame. Decoded frames are flipped upside down. Trailing bytes that do not
    make up a whole frame are ignored.

    Args:
        video_seq: sequence of byte values (anything ``np.array`` accepts).
        h, w: frame height and width in pixels.

    Returns:
        uint8 array of shape (num_frames, h, w) holding 0/1 spikes.
    """
    stream = np.array(video_seq).astype(np.uint8)
    pixel_count = h * w
    bytes_per_frame = pixel_count // 8
    num_frames = len(stream) // bytes_per_frame

    pixel_index = np.arange(pixel_count).reshape(h, w)
    bit_mask = np.left_shift(1, pixel_index % 8)  # which bit inside the byte
    byte_index = pixel_index // 8  # which byte inside the frame

    frames = np.zeros((num_frames, h, w), dtype=np.uint8)
    for frame_no in range(num_frames):
        offset = frame_no * pixel_count // 8
        packed = stream[offset:offset + bytes_per_frame]
        fired = (packed[byte_index] & bit_mask) == bit_mask
        frames[frame_no] = fired[::-1, :]  # vertical flip
    return frames
40
+
41
+
42
def Im2Patch(img, win, stride=40):
    """Cut a (C, H, W) array into ``win`` x ``win`` patches taken every ``stride`` pixels.

    Returns:
        float32 array of shape (C, win, win, num_patches). Patch ``p`` is the
        window whose top-left corner is the ``p``-th start position in
        row-major order.
    """
    channels, rows, cols = img.shape
    row_stop = rows - win + 1
    col_stop = cols - win + 1
    corners = img[:, 0:row_stop:stride, 0:col_stop:stride]
    num_patches = corners.shape[1] * corners.shape[2]

    columns = np.zeros([channels, win * win, num_patches], np.float32)
    for offset in range(win * win):
        di, dj = divmod(offset, win)
        shifted = img[
            :, di:rows - win + di + 1:stride, dj:cols - win + dj + 1:stride
        ]
        columns[:, offset, :] = np.reshape(shifted, (channels, num_patches))
    return columns.reshape([channels, win, win, num_patches])
56
+
57
+
58
def read_image_and_concat_as_tensor(paths: List[str]) -> np.ndarray:
    """Read images with ``cv2.imread`` (default flags) and stack them on a new leading axis.

    Despite the name, the result is a NumPy array (not a torch tensor) of shape
    ``(len(paths), *image_shape)``. All images must have the same shape.

    Raises:
        ValueError: if ``paths`` is empty.
        FileNotFoundError: if an image cannot be read (``cv2.imread`` returned None).
    """
    if not paths:
        raise ValueError("read_image_and_concat_as_tensor needs at least one path")
    frames = []
    for path in paths:
        img = cv2.imread(path)
        # cv2.imread signals failure by returning None rather than raising.
        if img is None:
            raise FileNotFoundError(f"cannot read image: {path}")
        frames.append(img[np.newaxis])
    return np.concatenate(frames, axis=0)
64
+
65
+
66
def prepare_data(
    data_path,
    patch_size,
    stride,
    h5_name,
    aug_times=1,
    h=250,
    w=400,
    gt_indices=(7, 14, 21, 28, 35),
):
    """Cut spike streams and their ground-truth frames into aligned patches stored in HDF5.

    Every ``<data_path>/input/*.dat`` stream is decoded with ``raw_to_spike``.
    For each index in ``gt_indices``, the matching grayscale PNG under
    ``<data_path>/gt/`` is loaded, and the frames are scaled to [0, 1].
    Both stacks are split with ``Im2Patch``. Patch number ``n`` (counted across
    all files) is written under key ``str(n)`` to ``<h5_name>_input.h5`` and
    ``<h5_name>_gt.h5``.

    Args:
        data_path: directory containing ``input/`` and ``gt/`` sub-folders.
        patch_size: side length of the square patches.
        stride: step between patch origins.
        h5_name: prefix of the two output HDF5 files.
        aug_times: unused; kept for backward compatibility.
        h, w: spike frame size (defaults match the previously hard-coded 250 x 400).
        gt_indices: frame numbers of the ground-truth PNGs paired with each stream.

    Raises:
        FileNotFoundError: if a ground-truth image cannot be read.
        ValueError: if a stream and its ground truth yield different patch counts.
    """
    print("process training data")
    input_files = sorted(glob.glob(os.path.join(data_path, "input", "*.dat")))
    print(len(input_files))
    train_num = 0
    # Context managers flush and close both HDF5 files even if a later file fails.
    with h5py.File(h5_name + "_input.h5", "w") as input_h5f, h5py.File(
        h5_name + "_gt.h5", "w"
    ) as gt_h5f:
        for input_path in input_files:
            # Read-only binary mode: write access to the source data is not needed.
            with open(input_path, "rb") as input_f:
                video_seq = np.frombuffer(input_f.read(), dtype=np.uint8)
            spike_array = raw_to_spike(video_seq, h, w)  # frames x h x w
            file_name = input_path.replace("\\", "/").split("/")[-1]

            gt_frames = []
            for num in gt_indices:
                # NOTE(review): file_name[:-6] strips ".dat" plus two more characters
                # before appending the frame number; confirm against the dataset naming.
                gt_path = os.path.join(
                    data_path, "gt", file_name[:-6] + str(num) + ".png"
                )
                img = cv2.imread(gt_path, 0)  # flag 0 -> single-channel grayscale
                if img is None:
                    raise FileNotFoundError(f"cannot read ground-truth image: {gt_path}")
                gt_frames.append(img[np.newaxis])
            gt = np.float32(normalize(np.concatenate(gt_frames, axis=0)))
            print(input_path, spike_array.shape, gt.shape)

            input_patches = Im2Patch(spike_array, win=patch_size, stride=stride)
            gt_patches = Im2Patch(gt, win=patch_size, stride=stride)
            if input_patches.shape[3] != gt_patches.shape[3]:
                raise ValueError(
                    f"{input_path}: {input_patches.shape[3]} spike patches vs "
                    f"{gt_patches.shape[3]} ground-truth patches"
                )
            for n in range(input_patches.shape[3]):
                key = str(train_num)
                input_h5f.create_dataset(key, data=input_patches[:, :, :, n].copy())
                gt_h5f.create_dataset(key, data=gt_patches[:, :, :, n].copy())
                train_num += 1
109
+
110
+
111
class Dataset(udata.Dataset):
    """Patch dataset backed by the ``<h5_name>_input.h5`` / ``<h5_name>_gt.h5`` pair.

    Keys are read once from the input file and shuffled in place with the
    ``random`` module. Each item re-opens both files, so no HDF5 handle is
    held between calls.
    """

    def __init__(self, h5_name):
        super(Dataset, self).__init__()
        self.h5_name = h5_name
        # Opening the ground-truth file too makes a missing file fail at construction time.
        with h5py.File(h5_name + "_input.h5", "r") as input_h5f, h5py.File(
            h5_name + "_gt.h5", "r"
        ):
            self.keys = list(input_h5f.keys())
        random.shuffle(self.keys)

    def __len__(self):
        return len(self.keys)

    def __getitem__(self, index):
        """Return ``(spike_patch, gt_patch)`` as float tensors for the ``index``-th shuffled key."""
        # Look up the key before opening any file: an out-of-range index (which
        # ends plain iteration over the dataset) then raises without leaking handles.
        key = self.keys[index]
        with h5py.File(self.h5_name + "_input.h5", "r") as input_h5f, h5py.File(
            self.h5_name + "_gt.h5", "r"
        ) as gt_h5f:
            inputs = np.array(input_h5f[key])
            gt = np.array(gt_h5f[key])
        return torch.Tensor(inputs), torch.Tensor(gt)
135
+
136
+
137
if __name__ == "__main__":
    # Build train_input.h5 / train_gt.h5 from the raw training spike streams.
    train_settings = dict(
        data_path="./Spk2ImgNet_train/train2/",
        patch_size=40,
        stride=40,
        h5_name="train",
    )
    prepare_data(**train_settings)
    # Validation set:
    # prepare_data(data_path="./SpikeDataset/val/", patch_size=40, stride=40, h5_name="val")
@@ -0,0 +1,230 @@
1
+ import numpy as np
2
+ import sys
3
+ import os
4
# Put this file's directory on sys.path so the sibling module `align_arch`
# (star-imported below) resolves regardless of the current working directory.
current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(current_dir)
6
+
7
+ from align_arch import *
8
+
9
class BasicBlock(nn.Module):
    """Residual block of three 3x3 same-padding convolutions.

    Computes ``relu(x + conv3(relu(conv2(relu(conv1(x))))))``. Input and output
    both have ``features`` channels and the same spatial size.
    """

    def __init__(self, features):
        super().__init__()

        def make_conv():
            return nn.Conv2d(features, features, kernel_size=3, padding=1, bias=True)

        # Keep attribute names/order (conv1, relu1, ...): they are the state_dict keys.
        self.conv1 = make_conv()
        self.relu1 = nn.ReLU()
        self.conv2 = make_conv()
        self.relu2 = nn.ReLU()
        self.conv3 = make_conv()
        self.relu3 = nn.ReLU()

    def forward(self, x):
        residual = self.relu1(self.conv1(x))
        residual = self.relu2(self.conv2(residual))
        residual = self.conv3(residual)
        return self.relu3(residual + x)
44
+
45
+
46
class CALayer2(nn.Module):
    """Sigmoid-gated attention map.

    Two 3x3 convolutions (expanding to ``2 * in_channels`` and back) followed
    by a sigmoid. The returned weights have the input's shape and lie in
    [0, 1]; callers multiply them with the input themselves.
    """

    def __init__(self, in_channels):
        super().__init__()
        hidden = in_channels * 2
        layers = [
            nn.Conv2d(in_channels, hidden, 3, padding=1, bias=True),
            nn.ReLU(),
            nn.Conv2d(hidden, in_channels, 3, padding=1, bias=True),
            nn.Sigmoid(),
        ]
        self.ca_block = nn.Sequential(*layers)

    def forward(self, x):
        return self.ca_block(x)
60
+
61
+
62
class FeatureExtractor(nn.Module):
    """Multi-window feature extractor for one block of spike frames (uses CALayer2).

    Four conv branches see progressively narrower, centred channel windows of
    the input: branch ``k`` drops ``k * channel_step`` channels from each side.
    Each branch yields a 1-channel map. The four maps are concatenated into a
    4-channel estimate, re-weighted by ``CALayer2`` attention, lifted to
    ``features`` channels, and refined by ``num_of_layers - 2`` residual
    ``BasicBlock``s with a global skip.

    Args:
        in_channels: number of input channels (spike frames in the block).
        features: channel width of the returned feature map.
        out_channels: accepted for interface compatibility; not used.
        channel_step: channels trimmed per side per narrower branch; the
            narrowest branch sees ``in_channels - 6 * channel_step`` channels.
        num_of_layers: ``num_of_layers - 2`` residual blocks are stacked.

    Returns (forward):
        ``(features_map, est)``. ``features_map`` has ``features`` channels;
        ``est`` is the 4-channel pre-attention estimate.
    """

    def __init__(
        self, in_channels, features, out_channels, channel_step, num_of_layers=16
    ):
        super(FeatureExtractor, self).__init__()
        self.channel_step = channel_step
        # First stage: branch k sees in_channels - 2 * k * channel_step channels.
        self.conv0_0 = nn.Conv2d(
            in_channels=in_channels, out_channels=16, kernel_size=3, padding=1
        )
        self.conv0_1 = nn.Conv2d(
            in_channels=in_channels - 2 * channel_step,
            out_channels=16,
            kernel_size=3,
            padding=1,
        )
        self.conv0_2 = nn.Conv2d(
            in_channels=in_channels - 4 * channel_step,
            out_channels=16,
            kernel_size=3,
            padding=1,
        )
        self.conv0_3 = nn.Conv2d(
            in_channels=in_channels - 6 * channel_step,
            out_channels=16,
            kernel_size=3,
            padding=1,
        )
        # Second stage: collapse each branch to a single map.
        self.conv1_0 = nn.Conv2d(in_channels=16, out_channels=1, kernel_size=3, padding=1)
        self.conv1_1 = nn.Conv2d(in_channels=16, out_channels=1, kernel_size=3, padding=1)
        self.conv1_2 = nn.Conv2d(in_channels=16, out_channels=1, kernel_size=3, padding=1)
        self.conv1_3 = nn.Conv2d(in_channels=16, out_channels=1, kernel_size=3, padding=1)
        self.ca = CALayer2(in_channels=4)
        self.conv = nn.Conv2d(
            in_channels=4, out_channels=features, kernel_size=3, padding=1
        )
        self.relu = nn.ReLU()
        self.net = nn.Sequential(
            *[BasicBlock(features=features) for _ in range(num_of_layers - 2)]
        )

    def forward(self, x):
        num_channels = x.shape[1]
        branches = (
            (self.conv0_0, self.conv1_0),
            (self.conv0_1, self.conv1_1),
            (self.conv0_2, self.conv1_2),
            (self.conv0_3, self.conv1_3),
        )
        maps = []
        for level, (conv_in, conv_out) in enumerate(branches):
            trim = level * self.channel_step
            # Explicit end index: the former `-trim` slice was empty when trim == 0.
            window = x[:, trim:num_channels - trim, :, :]
            maps.append(conv_out(self.relu(conv_in(window))))
        est = torch.cat(maps, 1)
        weighted = self.ca(est) * est
        shallow = self.relu(self.conv(weighted))
        return self.net(shallow) + shallow, est
143
+
144
+
145
class FusionMaskV1(nn.Module):
    """Per-pixel fusion mask computed from a (ref, key) feature pair.

    ``ref`` and ``key`` (``features`` channels each) are concatenated and passed
    through conv -> PReLU -> conv -> PReLU -> conv -> sigmoid. The result has
    ``features`` channels with values in [0, 1].
    """

    def __init__(self, features):
        super().__init__()
        self.conv0 = nn.Conv2d(2 * features, features, kernel_size=3, padding=1)
        self.conv1 = nn.Conv2d(features, features, kernel_size=3, padding=1)
        self.conv2 = nn.Conv2d(features, features, kernel_size=3, padding=1)
        self.prelu0 = nn.PReLU()
        self.prelu1 = nn.PReLU()
        self.sig = nn.Sigmoid()

    def forward(self, ref, key):
        stacked = torch.cat((ref, key), 1)
        hidden = self.prelu0(self.conv0(stacked))
        hidden = self.prelu1(self.conv1(hidden))
        return self.sig(self.conv2(hidden))
+ return mask
166
+
167
+
168
class SpikeNet(nn.Module):
    """Spk2ImgNet reconstruction network (the variant marked "current best model").

    The input channels are split into five windows of ``2 * win_r + 1``
    channels whose starts are ``win_step`` apart. Each window goes through the
    shared ``FeatureExtractor``. Windows 0, 1, 3 and 4 are aligned to the
    centre window (index 2) with ``Easy_PCD``, gated by their own
    ``FusionMaskV1``, and concatenated with the centre features. A three-layer
    conv decoder then produces a single-channel image.

    ``out_channels`` is accepted for interface compatibility but unused; the
    decoder always outputs one channel. ``forward`` returns only that
    reconstruction; the extractor's per-window estimates are discarded.
    """

    def __init__(self, in_channels, features, out_channels, win_r, win_step):
        super().__init__()
        self.extractor = FeatureExtractor(
            in_channels=in_channels,
            features=features,
            out_channels=features,
            channel_step=1,
            num_of_layers=12,
        )
        # One fusion mask per non-centre window, registered in the original order
        # (mask0, mask1, mask3, mask4) so state_dict keys are unchanged.
        for idx in (0, 1, 3, 4):
            setattr(self, f"mask{idx}", FusionMaskV1(features=features))
        self.rec_conv0 = nn.Conv2d(5 * features, 3 * features, kernel_size=3, padding=1)
        self.rec_conv1 = nn.Conv2d(3 * features, features, kernel_size=3, padding=1)
        self.rec_conv2 = nn.Conv2d(features, 1, kernel_size=3, padding=1)
        self.rec_relu = nn.ReLU()
        self.pcd_align = Easy_PCD(nf=features, groups=8)
        self.win_r = win_r
        self.win_step = win_step

    def forward(self, x):
        span = 2 * self.win_r + 1
        block_feats = []
        for idx in range(5):
            start = idx * self.win_step
            feats, _est = self.extractor(x[:, start:start + span, :, :])
            block_feats.append(feats)
        centre = block_feats[2]
        neighbours = (0, 1, 3, 4)
        aligned = {idx: self.pcd_align(block_feats[idx], centre) for idx in neighbours}
        gated = {
            idx: aligned[idx] * getattr(self, f"mask{idx}")(aligned[idx], centre)
            for idx in neighbours
        }
        fused = torch.cat((gated[0], gated[1], centre, gated[3], gated[4]), 1)
        hidden = self.rec_relu(self.rec_conv0(fused))
        hidden = self.rec_relu(self.rec_conv1(hidden))
        return self.rec_conv2(hidden)
227
+
228
+
229
if __name__ == "__main__":
    # Running this module directly only prints a marker; no model is built.
    marker = "out"
    print(marker)
@@ -0,0 +1,86 @@
1
+ ## [CVPR 2021] Spk2ImgNet: Learning to Reconstruct Dynamic Scene from Continuous Spike Stream
2
+
3
+
4
+ <h4 align="center"> Jing Zhao, Ruiqin Xiong, Hangfan Liu, Jian Zhang, Tiejun Huang </h4>
5
+
6
+ This repository contains the official source code for our paper:
7
+
8
+ Spk2ImgNet: Learning to Reconstruct Dynamic Scene from Continuous Spike Stream. CVPR 2021
9
+
10
+ Paper:
11
+ [Spk2ImgNet-CVPR2021](https://openaccess.thecvf.com/content/CVPR2021/papers/Zhao_Spk2ImgNet_Learning_To_Reconstruct_Dynamic_Scene_From_Continuous_Spike_Stream_CVPR_2021_paper.pdf)
12
+
13
+ * [Spk2ImgNet](#cvpr-2021-spk2imgnet-learning-to-reconstruct-dynamic-scene-from-continuous-spike-stream)
14
+ * [Environments](#Environments)
15
+ * [Download the pretrained models](#Download-the-pretrained-models)
16
+ * [Evaluate](#Evaluate)
17
+ * [Train](#Train)
18
+ * [Citation](#Citations)
19
+
20
+
21
+ ## Environments
22
+
23
+ Choose the cudatoolkit version that matches your compute environment. The code has been tested with PyTorch 1.10.2+cu113 and spatial-correlation-sampler 0.3.0, but other versions may also work.
24
+
25
+ ```bash
26
+ conda create -n steflow python==3.9
27
+ conda activate steflow
28
+ conda install pytorch torchvision cudatoolkit=11.3 -c pytorch
29
+ pip3 install matplotlib opencv-python h5py
30
+ ```
31
+
32
+ We cannot guarantee that every PyTorch version works.
33
+
34
+ ## Prepare the Data
35
+
36
+ ### Download the pretrained models
37
+
38
+ The pretrained model can be downloaded from the Google Drive link below:
39
+
40
+ [Link for pretrained model](https://drive.google.com/file/d/1vBTJxlctk4otQKsyRq7lsFYGU4WGRNjt/view?usp=sharing)
41
+
42
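+ Download the pretrained models into `./ckpt`. The evaluation script reads checkpoints from `--logdir` (default `./ckpt2/`), so pass `--logdir ./ckpt/` or move the files accordingly.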
43
+
44
+ ### Download the training data
45
+
46
+ The training data can be downloaded from the Google Drive link below:
47
+
48
+ [Link for training data](https://drive.google.com/file/d/1ozR2-fNmU10gA_TCYUfJN-ahV6e_8Ke7/view?usp=sharing)
49
+
50
+ ## Evaluate
51
+
52
+ You can set the test data path in `test_gen_imgseq.py` or through the `--test_data` argument (the checkpoint directory is set with `--logdir`).
52
+
53
+ ```bash
54
+ python3 test_gen_imgseq.py \
56
+ --test_data 'Spk2ImgNet_test2' \
57
+ --model_name 'model_061.pth'
58
+
59
+ ```
60
+
61
+
62
+ ## Train
63
+
64
+
65
+ All the command line arguments for hyperparameter tuning can be found in the `train.py` file.
66
+ You can set the data path in the .py files or through argparser (--data)
67
+
68
+ ```bash
69
+ python3 train.py
70
+ ```
71
+
72
+ ## Citations
73
+
74
+ If you find this code useful in your research, please consider citing our paper:
75
+
76
+ ```
77
+ @inproceedings{zhao2021spike,
78
+ title={Spk2ImgNet: Learning to Reconstruct Dynamic Scene from Continuous Spike Stream},
79
+ author={Zhao, Jing and Xiong, Ruiqin and Liu, Hangfan and Zhang, Jian and Huang, Tiejun},
80
+ booktitle={2021 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
81
+ year={2021}
82
+ }
83
+ ```
84
+
85
+
86
+
@@ -0,0 +1,118 @@
1
+ import argparse
2
+ import time
3
+
4
+ from skimage.metrics import peak_signal_noise_ratio, structural_similarity
5
+ from torch.autograd import Variable
6
+
7
+ from dataset import *
8
+ from nets import *
9
+
10
os.environ["CUDA_VISIBLE_DEVICES"] = "0"  # pin the script to the first GPU


def _str2bool(value):
    """Parse a boolean command-line value such as "true"/"False"/"1"/"no".

    argparse's ``type=bool`` is a trap: ``bool("False")`` is ``True`` because any
    non-empty string is truthy, so ``--exist_gt False`` was silently ignored.
    """
    if isinstance(value, bool):
        return value
    lowered = value.strip().lower()
    if lowered in {"1", "true", "t", "yes", "y", "on"}:
        return True
    if lowered in {"0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


parser = argparse.ArgumentParser(description="Spike_Net_Test")
# NOTE(review): --num_of_layers is not read by main(), which builds SpikeNet with fixed settings.
parser.add_argument(
    "--num_of_layers", type=int, default=17, help="Number of toatal layers"
)
parser.add_argument(
    "--logdir",
    type=str,
    default="./ckpt2/",
    help="path of log files",
)
parser.add_argument("--test_data", type=str, default="./Spk2ImgNet_test2/test2/", help="test set")
parser.add_argument(
    "--save_result", type=_str2bool, default=True, help="save the reconstruction or not"
)
parser.add_argument(
    "--result_dir", type=str, default="results/", help="path of results"
)
parser.add_argument(
    "--exist_gt", type=_str2bool, default=True, help="exist ground truth or not"
)
parser.add_argument("--model_name", type=str, default="model_041.pth", help="Name of ckp")
opt = parser.parse_args()
34
+
35
+
36
def normalize(data):
    """Map 8-bit pixel values onto [0, 1] (divide by 255)."""
    full_scale = 255.0
    scaled = data / full_scale
    return scaled
38
+
39
+
40
def main():
    """Evaluate a pretrained SpikeNet on the ``.dat`` spike streams in ``opt.test_data``.

    Each ``<test_data>/input/*.dat`` file is decoded as 250x400 spike frames.
    The 41 frames around key frame 151 are padded to 252 rows and fed to the
    model, and the output is cropped back to 250 rows as uint8.

    With ``opt.exist_gt``, the result is scored (PSNR/SSIM) against
    ``<test_data>/gt/<name>.png``. With ``opt.save_result``, it is written to
    ``<result_dir>/<name>/<key_id>.png``.

    Raises:
        FileNotFoundError: if a ground-truth image is expected but cannot be read.
    """
    # Build model
    print("Loading model ... \n")
    net = SpikeNet(in_channels=13, features=64, out_channels=1, win_r=6, win_step=7)
    model = nn.DataParallel(net).cuda()
    # torch.load unpickles the checkpoint; only load checkpoints you trust.
    model.load_state_dict(torch.load(os.path.join(opt.logdir, opt.model_name)))
    model.eval()

    # load data info
    print("Loading data info ...\n")
    files_source = sorted(glob.glob(os.path.join(opt.test_data, "input", "*.dat")))

    psnr_test = 0.0
    ssim_test = 0.0
    num_scored = 0
    for source_path in files_source:
        file_name = source_path.replace("\\", "/").split("/")[-1]
        # Output folder per stream: the file name without ".dat". Joining the full
        # input path instead would escape result_dir whenever test_data is absolute.
        sub_dir = file_name[:-4]
        with open(source_path, "rb") as input_f:
            video_seq = np.frombuffer(input_f.read(), dtype=np.uint8)
        in_spikes = raw_to_spike(video_seq, 250, 400)  # frames x h x w
        for key_id in np.arange(151, 152, 1):
            start_t = time.time()
            # 41 frames = 5 windows of 13 frames spaced 7 apart (win_r=6, win_step=7).
            spikes = in_spikes[key_id - 21 : key_id + 20, :, :]
            # Pad 250 -> 252 rows so the height is divisible by 4.
            spikes = np.pad(spikes, ((0, 0), (0, 2), (0, 0)), "symmetric")
            spikes = torch.Tensor(np.expand_dims(spikes, 0)).cuda()  # n x c x h x w
            with torch.no_grad():
                outputs = model(spikes)
            # SpikeNet.forward returns a single tensor; a tuple whose first item is
            # the reconstruction is accepted as well.
            out_rec = outputs[0] if isinstance(outputs, (tuple, list)) else outputs
            if opt.exist_gt:
                # 0.6 is the conversion rate used in the spike camera; only
                # necessary for the synthesized data.
                out_rec = torch.clamp(out_rec / 0.6, 0, 1).cpu() * 255
            else:
                out_rec = torch.clamp(out_rec, 0, 1).cpu() ** (1 / 2.2) * 255
            out_rec = out_rec.detach().numpy().astype(np.float32)
            out_rec = np.squeeze(out_rec).astype(np.uint8)
            # Crop back to the original 250 x 400 frame.
            out_rec = out_rec[:250, :]
            if opt.exist_gt:
                gt_path = os.path.join(opt.test_data, "gt", file_name[:-3] + "png")
                gt = cv2.imread(gt_path, 0)
                if gt is None:
                    raise FileNotFoundError(f"cannot read ground-truth image: {gt_path}")
                psnr = peak_signal_noise_ratio(gt, out_rec)
                ssim = structural_similarity(gt, out_rec)
                print("%10s: PSNR:%.2f SSIM:%.4f" % (file_name, psnr, ssim))
                psnr_test += psnr
                ssim_test += ssim
                num_scored += 1
            if opt.save_result:
                out_dir = os.path.join(opt.result_dir, sub_dir)
                os.makedirs(out_dir, exist_ok=True)
                cv2.imwrite(os.path.join(out_dir, str(key_id) + ".png"), out_rec)
            dur_time = time.time() - start_t
            print("dur_time:%.2f" % dur_time)

    if opt.exist_gt:
        if num_scored == 0:
            print("no test files found in %s" % opt.test_data)
            return
        avg_psnr = psnr_test / num_scored
        avg_ssim = ssim_test / num_scored
        print("average PSNR: %.2f average SSIM: %.4f" % (avg_psnr, avg_ssim))
115
+
116
+
117
if __name__ == "__main__":
    # Script entry point; options were already parsed into `opt` at import time.
    main()