spikezoo 0.1.2__py3-none-any.whl → 0.2__py3-none-any.whl

Files changed (192)
  1. spikezoo/__init__.py +13 -0
  2. spikezoo/archs/__pycache__/__init__.cpython-39.pyc +0 -0
  3. spikezoo/archs/base/__pycache__/nets.cpython-39.pyc +0 -0
  4. spikezoo/archs/base/nets.py +34 -0
  5. spikezoo/archs/bsf/README.md +92 -0
  6. spikezoo/archs/bsf/datasets/datasets.py +328 -0
  7. spikezoo/archs/bsf/datasets/ds_utils.py +64 -0
  8. spikezoo/archs/bsf/main.py +398 -0
  9. spikezoo/archs/bsf/metrics/psnr.py +22 -0
  10. spikezoo/archs/bsf/metrics/ssim.py +54 -0
  11. spikezoo/archs/bsf/models/bsf/__pycache__/align.cpython-39.pyc +0 -0
  12. spikezoo/archs/bsf/models/bsf/__pycache__/bsf.cpython-39.pyc +0 -0
  13. spikezoo/archs/bsf/models/bsf/__pycache__/rep.cpython-39.pyc +0 -0
  14. spikezoo/archs/bsf/models/bsf/align.py +154 -0
  15. spikezoo/archs/bsf/models/bsf/bsf.py +105 -0
  16. spikezoo/archs/bsf/models/bsf/dsft_convert.py +96 -0
  17. spikezoo/archs/bsf/models/bsf/rep.py +44 -0
  18. spikezoo/archs/bsf/models/get_model.py +7 -0
  19. spikezoo/archs/bsf/prepare_data/DSFT.py +62 -0
  20. spikezoo/archs/bsf/prepare_data/crop_dataset_train.py +135 -0
  21. spikezoo/archs/bsf/prepare_data/crop_dataset_val.py +139 -0
  22. spikezoo/archs/bsf/prepare_data/crop_train.sh +4 -0
  23. spikezoo/archs/bsf/prepare_data/crop_val.sh +4 -0
  24. spikezoo/archs/bsf/prepare_data/io_utils.py +64 -0
  25. spikezoo/archs/bsf/requirements.txt +9 -0
  26. spikezoo/archs/bsf/test.py +16 -0
  27. spikezoo/archs/bsf/utils.py +154 -0
  28. spikezoo/archs/spikeclip/__pycache__/nets.cpython-39.pyc +0 -0
  29. spikezoo/archs/spikeclip/nets.py +40 -0
  30. spikezoo/archs/spikeformer/CheckPoints/readme +1 -0
  31. spikezoo/archs/spikeformer/DataProcess/DataExtactor.py +60 -0
  32. spikezoo/archs/spikeformer/DataProcess/DataLoader.py +115 -0
  33. spikezoo/archs/spikeformer/DataProcess/LoadSpike.py +39 -0
  34. spikezoo/archs/spikeformer/EvalResults/readme +1 -0
  35. spikezoo/archs/spikeformer/LICENSE +21 -0
  36. spikezoo/archs/spikeformer/Metrics/Metrics.py +50 -0
  37. spikezoo/archs/spikeformer/Metrics/__init__.py +0 -0
  38. spikezoo/archs/spikeformer/Model/Loss.py +89 -0
  39. spikezoo/archs/spikeformer/Model/SpikeFormer.py +230 -0
  40. spikezoo/archs/spikeformer/Model/__init__.py +0 -0
  41. spikezoo/archs/spikeformer/Model/__pycache__/SpikeFormer.cpython-39.pyc +0 -0
  42. spikezoo/archs/spikeformer/Model/__pycache__/__init__.cpython-39.pyc +0 -0
  43. spikezoo/archs/spikeformer/README.md +30 -0
  44. spikezoo/archs/spikeformer/evaluate.py +87 -0
  45. spikezoo/archs/spikeformer/recon_real_data.py +97 -0
  46. spikezoo/archs/spikeformer/requirements.yml +95 -0
  47. spikezoo/archs/spikeformer/train.py +173 -0
  48. spikezoo/archs/spikeformer/utils.py +22 -0
  49. spikezoo/archs/spk2imgnet/.github/workflows/pylint.yml +23 -0
  50. spikezoo/archs/spk2imgnet/.gitignore +150 -0
  51. spikezoo/archs/spk2imgnet/DCNv2.py +135 -0
  52. spikezoo/archs/spk2imgnet/__pycache__/DCNv2.cpython-39.pyc +0 -0
  53. spikezoo/archs/spk2imgnet/__pycache__/align_arch.cpython-39.pyc +0 -0
  54. spikezoo/archs/spk2imgnet/__pycache__/nets.cpython-39.pyc +0 -0
  55. spikezoo/archs/spk2imgnet/align_arch.py +159 -0
  56. spikezoo/archs/spk2imgnet/dataset.py +144 -0
  57. spikezoo/archs/spk2imgnet/nets.py +230 -0
  58. spikezoo/archs/spk2imgnet/readme.md +86 -0
  59. spikezoo/archs/spk2imgnet/test_gen_imgseq.py +118 -0
  60. spikezoo/archs/spk2imgnet/train.py +189 -0
  61. spikezoo/archs/spk2imgnet/utils.py +64 -0
  62. spikezoo/archs/ssir/README.md +87 -0
  63. spikezoo/archs/ssir/configs/SSIR.yml +37 -0
  64. spikezoo/archs/ssir/configs/yml_parser.py +78 -0
  65. spikezoo/archs/ssir/datasets/dataset_sreds.py +170 -0
  66. spikezoo/archs/ssir/datasets/ds_utils.py +66 -0
  67. spikezoo/archs/ssir/losses.py +21 -0
  68. spikezoo/archs/ssir/main.py +326 -0
  69. spikezoo/archs/ssir/metrics/psnr.py +22 -0
  70. spikezoo/archs/ssir/metrics/ssim.py +54 -0
  71. spikezoo/archs/ssir/models/Vgg19.py +42 -0
  72. spikezoo/archs/ssir/models/__pycache__/layers.cpython-39.pyc +0 -0
  73. spikezoo/archs/ssir/models/__pycache__/networks.cpython-39.pyc +0 -0
  74. spikezoo/archs/ssir/models/layers.py +110 -0
  75. spikezoo/archs/ssir/models/networks.py +61 -0
  76. spikezoo/archs/ssir/requirements.txt +8 -0
  77. spikezoo/archs/ssir/shells/eval_SREDS.sh +6 -0
  78. spikezoo/archs/ssir/shells/train_SSIR.sh +12 -0
  79. spikezoo/archs/ssir/test.py +3 -0
  80. spikezoo/archs/ssir/utils.py +154 -0
  81. spikezoo/archs/ssml/__pycache__/cbam.cpython-39.pyc +0 -0
  82. spikezoo/archs/ssml/__pycache__/model.cpython-39.pyc +0 -0
  83. spikezoo/archs/ssml/cbam.py +224 -0
  84. spikezoo/archs/ssml/model.py +290 -0
  85. spikezoo/archs/ssml/res.png +0 -0
  86. spikezoo/archs/ssml/test.py +67 -0
  87. spikezoo/archs/stir/.git-credentials +0 -0
  88. spikezoo/archs/stir/README.md +65 -0
  89. spikezoo/archs/stir/ckpt_outputs/Descriptions.txt +1 -0
  90. spikezoo/archs/stir/configs/STIR.yml +37 -0
  91. spikezoo/archs/stir/configs/utils.py +155 -0
  92. spikezoo/archs/stir/configs/yml_parser.py +78 -0
  93. spikezoo/archs/stir/datasets/dataset_sreds.py +180 -0
  94. spikezoo/archs/stir/datasets/ds_utils.py +66 -0
  95. spikezoo/archs/stir/eval_SREDS.sh +5 -0
  96. spikezoo/archs/stir/main.py +397 -0
  97. spikezoo/archs/stir/metrics/losses.py +219 -0
  98. spikezoo/archs/stir/metrics/psnr.py +22 -0
  99. spikezoo/archs/stir/metrics/ssim.py +54 -0
  100. spikezoo/archs/stir/models/Vgg19.py +42 -0
  101. spikezoo/archs/stir/models/__pycache__/networks_STIR.cpython-39.pyc +0 -0
  102. spikezoo/archs/stir/models/__pycache__/submodules.cpython-39.pyc +0 -0
  103. spikezoo/archs/stir/models/__pycache__/transformer_new.cpython-39.pyc +0 -0
  104. spikezoo/archs/stir/models/networks_STIR.py +361 -0
  105. spikezoo/archs/stir/models/submodules.py +86 -0
  106. spikezoo/archs/stir/models/transformer_new.py +151 -0
  107. spikezoo/archs/stir/package_core/build/lib/package_core/__init__.py +0 -0
  108. spikezoo/archs/stir/package_core/build/lib/package_core/convertions.py +721 -0
  109. spikezoo/archs/stir/package_core/build/lib/package_core/disp_netS.py +133 -0
  110. spikezoo/archs/stir/package_core/build/lib/package_core/flow_utils.py +167 -0
  111. spikezoo/archs/stir/package_core/build/lib/package_core/generic_train_test.py +76 -0
  112. spikezoo/archs/stir/package_core/build/lib/package_core/geometry.py +458 -0
  113. spikezoo/archs/stir/package_core/build/lib/package_core/image_proc.py +183 -0
  114. spikezoo/archs/stir/package_core/build/lib/package_core/linalg.py +40 -0
  115. spikezoo/archs/stir/package_core/build/lib/package_core/losses.py +198 -0
  116. spikezoo/archs/stir/package_core/build/lib/package_core/metrics.py +51 -0
  117. spikezoo/archs/stir/package_core/build/lib/package_core/model_base.py +53 -0
  118. spikezoo/archs/stir/package_core/build/lib/package_core/net_basics.py +100 -0
  119. spikezoo/archs/stir/package_core/build/lib/package_core/resnet.py +333 -0
  120. spikezoo/archs/stir/package_core/build/lib/package_core/transforms.py +123 -0
  121. spikezoo/archs/stir/package_core/build/lib/package_core/utils.py +72 -0
  122. spikezoo/archs/stir/package_core/dist/package_core-0.0.0-py3.9.egg +0 -0
  123. spikezoo/archs/stir/package_core/package_core/__init__.py +0 -0
  124. spikezoo/archs/stir/package_core/package_core/__pycache__/__init__.cpython-39.pyc +0 -0
  125. spikezoo/archs/stir/package_core/package_core/__pycache__/net_basics.cpython-39.pyc +0 -0
  126. spikezoo/archs/stir/package_core/package_core/convertions.py +721 -0
  127. spikezoo/archs/stir/package_core/package_core/disp_netS.py +133 -0
  128. spikezoo/archs/stir/package_core/package_core/flow_utils.py +167 -0
  129. spikezoo/archs/stir/package_core/package_core/generic_train_test.py +76 -0
  130. spikezoo/archs/stir/package_core/package_core/geometry.py +458 -0
  131. spikezoo/archs/stir/package_core/package_core/image_proc.py +183 -0
  132. spikezoo/archs/stir/package_core/package_core/linalg.py +40 -0
  133. spikezoo/archs/stir/package_core/package_core/losses.py +198 -0
  134. spikezoo/archs/stir/package_core/package_core/metrics.py +51 -0
  135. spikezoo/archs/stir/package_core/package_core/model_base.py +53 -0
  136. spikezoo/archs/stir/package_core/package_core/net_basics.py +100 -0
  137. spikezoo/archs/stir/package_core/package_core/resnet.py +333 -0
  138. spikezoo/archs/stir/package_core/package_core/transforms.py +123 -0
  139. spikezoo/archs/stir/package_core/package_core/utils.py +72 -0
  140. spikezoo/archs/stir/package_core/package_core.egg-info/PKG-INFO +3 -0
  141. spikezoo/archs/stir/package_core/package_core.egg-info/SOURCES.txt +20 -0
  142. spikezoo/archs/stir/package_core/package_core.egg-info/dependency_links.txt +1 -0
  143. spikezoo/archs/stir/package_core/package_core.egg-info/top_level.txt +1 -0
  144. spikezoo/archs/stir/package_core/setup.py +5 -0
  145. spikezoo/archs/stir/requirements.txt +12 -0
  146. spikezoo/archs/stir/train_STIR.sh +9 -0
  147. spikezoo/archs/tfi/__pycache__/nets.cpython-39.pyc +0 -0
  148. spikezoo/archs/tfi/nets.py +43 -0
  149. spikezoo/archs/tfp/__pycache__/nets.cpython-39.pyc +0 -0
  150. spikezoo/archs/tfp/nets.py +13 -0
  151. spikezoo/archs/wgse/README.md +64 -0
  152. spikezoo/archs/wgse/__pycache__/dwtnets.cpython-39.pyc +0 -0
  153. spikezoo/archs/wgse/__pycache__/submodules.cpython-39.pyc +0 -0
  154. spikezoo/archs/wgse/dataset.py +59 -0
  155. spikezoo/archs/wgse/demo.png +0 -0
  156. spikezoo/archs/wgse/demo.py +83 -0
  157. spikezoo/archs/wgse/dwtnets.py +145 -0
  158. spikezoo/archs/wgse/eval.py +133 -0
  159. spikezoo/archs/wgse/logs/WGSE-Dwt1dNet-db8-5-ks3/log.txt +11 -0
  160. spikezoo/archs/wgse/submodules.py +68 -0
  161. spikezoo/archs/wgse/train.py +261 -0
  162. spikezoo/archs/wgse/transform.py +139 -0
  163. spikezoo/archs/wgse/utils.py +128 -0
  164. spikezoo/archs/wgse/weights/demo.png +0 -0
  165. spikezoo/data/base/test/gt/200_part1_key_id151.png +0 -0
  166. spikezoo/data/base/test/gt/200_part3_key_id151.png +0 -0
  167. spikezoo/data/base/test/gt/203_part1_key_id151.png +0 -0
  168. spikezoo/data/base/test/spike/200_part1_key_id151.dat +0 -0
  169. spikezoo/data/base/test/spike/200_part3_key_id151.dat +0 -0
  170. spikezoo/data/base/test/spike/203_part1_key_id151.dat +0 -0
  171. spikezoo/data/base/train/gt/203_part2_key_id151.png +0 -0
  172. spikezoo/data/base/train/gt/203_part3_key_id151.png +0 -0
  173. spikezoo/data/base/train/gt/203_part4_key_id151.png +0 -0
  174. spikezoo/data/base/train/spike/203_part2_key_id151.dat +0 -0
  175. spikezoo/data/base/train/spike/203_part3_key_id151.dat +0 -0
  176. spikezoo/data/base/train/spike/203_part4_key_id151.dat +0 -0
  177. spikezoo/datasets/base_dataset.py +2 -3
  178. spikezoo/metrics/__init__.py +1 -1
  179. spikezoo/models/base_model.py +1 -3
  180. spikezoo/pipeline/base_pipeline.py +7 -5
  181. spikezoo/pipeline/train_pipeline.py +1 -1
  182. spikezoo/utils/other_utils.py +16 -6
  183. spikezoo/utils/spike_utils.py +33 -29
  184. spikezoo/utils/vidar_loader.cpython-39-x86_64-linux-gnu.so +0 -0
  185. spikezoo-0.2.dist-info/METADATA +163 -0
  186. spikezoo-0.2.dist-info/RECORD +211 -0
  187. spikezoo/models/spcsnet_model.py +0 -19
  188. spikezoo-0.1.2.dist-info/METADATA +0 -39
  189. spikezoo-0.1.2.dist-info/RECORD +0 -36
  190. {spikezoo-0.1.2.dist-info → spikezoo-0.2.dist-info}/LICENSE.txt +0 -0
  191. {spikezoo-0.1.2.dist-info → spikezoo-0.2.dist-info}/WHEEL +0 -0
  192. {spikezoo-0.1.2.dist-info → spikezoo-0.2.dist-info}/top_level.txt +0 -0
spikezoo/archs/spk2imgnet/train.py
@@ -0,0 +1,189 @@
+ import argparse
+ import glob
+ import os
+ import re
+ from collections import OrderedDict
+
+ import torch
+ import torch.nn as nn
+ import torch.optim as optim
+ from torch.autograd import Variable
+ from torch.utils.data import DataLoader
+ from dataset import *
+ from nets import *
+ from utils import *
+
+ os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+
+ parser = argparse.ArgumentParser(description="Spk2ImgNet")
+ parser.add_argument(
+     "--preprocess", type=bool, default=False, help="run prepare_data or not"
+ )
+ parser.add_argument("--batchSize", type=int, default=16, help="Training batch size")
+ parser.add_argument(
+     "--num_of_layers", type=int, default=17, help="Number of total layers"
+ )
+ parser.add_argument("--epochs", type=int, default=61, help="Number of training epochs")
+ parser.add_argument(
+     "--milestone",
+     type=int,
+     default=20,
+     help="When to decay learning rate; should be less than epochs",
+ )
+ parser.add_argument(
+     "--lr",
+     type=float,
+     default=1e-4,
+     help="Initial learning rate",
+ )
+ parser.add_argument(
+     "--outf",
+     type=str,
+     default="./ckpt2",
+     help="path of log files",
+ )
+ parser.add_argument(
+     "--load_model", type=bool, default=False, help="resume from the latest checkpoint in outf"
+ )
+ opt = parser.parse_args()
+
+ if not os.path.exists(opt.outf):
+     os.mkdir(opt.outf)
+
+
+ def find_last_checkpoint(save_dir):
+     file_list = glob.glob(os.path.join(save_dir, "model_*.pth"))
+     if file_list:
+         epoch_exist = []
+         for file_ in file_list:
+             result = re.findall(".*model_(.*).pth.*", file_)
+             epoch_exist.append(int(result[0]))
+         initial_epoch = max(epoch_exist)
+     else:
+         initial_epoch = 0
+     return initial_epoch
+
+
+ def main():
+     # Load dataset
+     print("Loading dataset ...\n")
+     dataset_train = Dataset("train")
+     loader_train = DataLoader(
+         dataset=dataset_train, num_workers=4, batch_size=opt.batchSize, shuffle=True
+     )
+     print("# of training samples: %d\n" % int(len(dataset_train)))
+     '''
+     dataset_val = Dataset("val_stack")
+     loader_val = DataLoader(
+         dataset=dataset_val, num_workers=4, batch_size=opt.batchSize, shuffle=False
+     )
+     '''
+     # Build model
+     model = SpikeNet(in_channels=13, features=64, out_channels=1, win_r=6, win_step=7)
+     if not opt.load_model:
+         initial_epoch = 0
+         print("training from scratch")
+     else:
+         # resume from the latest checkpoint
+         initial_epoch = find_last_checkpoint(save_dir=opt.outf)
+         print("load model from model_%03d.pth" % initial_epoch)
+         state_dict = torch.load(
+             os.path.join(opt.outf, "model_%03d.pth" % initial_epoch)
+         )
+         # strip the "module." prefix added by nn.DataParallel
+         new_state_dict = OrderedDict()
+         for k, v in state_dict.items():
+             name = k[7:]
+             new_state_dict[name] = v
+         model.load_state_dict(new_state_dict)
+     criterion = nn.L1Loss(reduction="mean")  # 'size_average=True' is deprecated
+     # Move to GPU
+     device_ids = [0]
+     model = nn.DataParallel(model, device_ids=device_ids).cuda()
+     criterion = criterion.cuda()
+     # Optimizer
+     optimizer = optim.Adam(model.parameters(), lr=opt.lr)
+     # training
+     model.train()
+     step = 0
+     for epoch in range(initial_epoch, opt.epochs):
+         avg_psnr = 0
+         if epoch < opt.milestone:
+             current_lr = opt.lr
+         else:
+             current_lr = opt.lr / 10.0
+         # set learning rate
+         for param_group in optimizer.param_groups:
+             param_group["lr"] = current_lr
+         print("learning rate %f" % current_lr)
+         # train
+         for i, (inputs, gt) in enumerate(loader_train, 0):
+             inputs = Variable(inputs).cuda()
+             gt = Variable(gt).cuda()
+             # training step
+             model.train()
+             model.zero_grad()
+             optimizer.zero_grad()
+             rec, est0, est1, est2, est3, est4 = model(inputs)
+             est0 = est0 / 0.6
+             est1 = est1 / 0.6
+             est2 = est2 / 0.6
+             est3 = est3 / 0.6
+             est4 = est4 / 0.6
+             rec = rec / 0.6
+             loss = criterion(gt[:, 2:3, :, :], rec)
+             for slice_id in range(4):
+                 loss = loss + 0.02 * (
+                     criterion(gt[:, 0:1, :, :], est0[:, slice_id : slice_id + 1, :, :])
+                     + criterion(gt[:, 1:2, :, :], est1[:, slice_id : slice_id + 1, :, :])
+                     + criterion(gt[:, 2:3, :, :], est2[:, slice_id : slice_id + 1, :, :])
+                     + criterion(gt[:, 3:4, :, :], est3[:, slice_id : slice_id + 1, :, :])
+                     + criterion(gt[:, 4:5, :, :], est4[:, slice_id : slice_id + 1, :, :])
+                 )
+             loss.backward()
+             optimizer.step()
+             rec = torch.clamp(rec, 0, 1)
+             psnr_train = batch_psnr(rec, gt[:, 2:3, :, :], 1.0)
+             avg_psnr += psnr_train
+             if i % 10 == 0:
+                 print(
+                     "[epoch %d][%d | %d] loss: %.4f PSNR_train: %.4f"
+                     % (epoch + 1, i + 1, len(loader_train), loss.item(), psnr_train)
+                 )
+             step += 1
+         avg_psnr = avg_psnr / len(loader_train)
+         print("avg_psnr: %.2f" % avg_psnr)
+
+         if epoch % 5 == 0:
+             '''
+             # validate
+             model.eval()
+             psnr_val = 0
+             for i, (inputs, gt) in enumerate(loader_val, 0):
+                 inputs = Variable(inputs).cuda()
+                 gt = Variable(gt).cuda()
+                 rec = model(inputs)
+                 rec = rec / 0.6
+                 rec = torch.clamp(rec, 0, 1)
+                 psnr_val += batch_psnr(rec, gt, 1.0)
+             print("[epoch %d] PSNR_val: %.4f" % (epoch + 1, psnr_val / len(loader_val)))
+             '''
+             # save model
+             torch.save(
+                 model.state_dict(),
+                 os.path.join(opt.outf, "model_%03d.pth" % (epoch + 1)),
+             )
+
+
+ if __name__ == "__main__":
+     if opt.preprocess:
+         prepare_data(data_path="./Spk2ImgNet_train/train2/", patch_size=40, stride=40, h5_name='train')
+     main()
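The script above is driven entirely by its argparse flags. A minimal sketch of typical invocations (paths are illustrative; note that argparse's `type=bool` converts any non-empty string to `True`, so the only reliable way to keep `--preprocess` or `--load_model` off is to omit them):

```shell
# First run: build the training patches, then train from scratch
python train.py --preprocess 1 --batchSize 16 --epochs 61 --lr 1e-4 --outf ./ckpt2

# Later run: resume from the newest model_*.pth found in --outf
python train.py --load_model 1 --outf ./ckpt2
```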
spikezoo/archs/spk2imgnet/utils.py
@@ -0,0 +1,64 @@
+ import math
+
+ import numpy as np
+ import torch.nn as nn
+ from skimage.metrics import peak_signal_noise_ratio
+
+
+ def weights_init_kaiming(m):
+     classname = m.__class__.__name__
+     if classname.find("Conv") != -1:
+         nn.init.kaiming_normal_(m.weight.data, a=0, mode="fan_in")
+     elif classname.find("Linear") != -1:
+         nn.init.kaiming_normal_(m.weight.data, a=0, mode="fan_in")
+     elif classname.find("BatchNorm") != -1:
+         m.weight.data.normal_(mean=0, std=math.sqrt(2.0 / 9.0 / 64.0)).clamp_(
+             -0.025, 0.025
+         )
+         nn.init.constant_(m.bias.data, 0.0)
+
+
+ def batch_psnr(img, imclean, data_range):
+     img = img.data.cpu().numpy().astype(np.float32)
+     imclean = imclean.data.cpu().numpy().astype(np.float32)
+     psnr = peak_signal_noise_ratio(img, imclean, data_range=data_range)
+     return psnr
+
+
+ def data_augmentation(image, mode):
+     out = np.transpose(image, (1, 2, 0))
+     if mode == 0:
+         # original
+         out = out
+     elif mode == 1:
+         # flip up and down
+         out = np.flipud(out)
+     elif mode == 2:
+         # rotate counterclockwise 90 degrees
+         out = np.rot90(out)
+     elif mode == 3:
+         # rotate 90 degrees and flip up and down
+         out = np.rot90(out)
+         out = np.flipud(out)
+     elif mode == 4:
+         # rotate 180 degrees
+         out = np.rot90(out, k=2)
+     elif mode == 5:
+         # rotate 180 degrees and flip
+         out = np.rot90(out, k=2)
+         out = np.flipud(out)
+     elif mode == 6:
+         # rotate 270 degrees
+         out = np.rot90(out, k=3)
+     elif mode == 7:
+         # rotate 270 degrees and flip
+         out = np.rot90(out, k=3)
+         out = np.flipud(out)
+     return np.transpose(out, (2, 0, 1))
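The eight `mode` values of `data_augmentation` enumerate the symmetries of a square (identity, three rotations, and their flipped variants) applied to a C x H x W array. A minimal usage sketch, assuming the `utils.py` shown above is importable:

```python
import numpy as np
from utils import data_augmentation  # module shown above

patch = np.random.rand(1, 40, 40).astype(np.float32)  # C x H x W training patch
mode = np.random.randint(0, 8)                        # pick one of the 8 symmetries
augmented = data_augmentation(patch, mode)            # still C x H x W
assert augmented.shape == patch.shape
```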
spikezoo/archs/ssir/README.md
@@ -0,0 +1,87 @@
+ ## [TCSVT 2023] Spike Camera Image Reconstruction Using Deep Spiking Neural Networks
+
+ <h4 align="center"> Rui Zhao<sup>1</sup>, Ruiqin Xiong<sup>1</sup>, Jian Zhang<sup>2</sup>, Zhaofei Yu<sup>1</sup>, Shuyuan Zhu<sup>3</sup>, Lei Ma<sup>1</sup>, Tiejun Huang<sup>1</sup> </h4>
+ <h4 align="center">1. National Engineering Research Center of Visual Technology, School of Computer Science, Peking University<br>
+ 2. School of Electronic and Computer Engineering, Peking University Shenzhen Graduate School<br>
+ 3. School of Information and Communication Engineering, UESTC</h4><br>
+
+ This repository contains the official source code for our paper:
+
+ Spike Camera Image Reconstruction Using Deep Spiking Neural Networks
+
+ TCSVT 2023
+
+ [Paper](https://ieeexplore.ieee.org/document/10288531)
+
+ ## Environment
+
+ Choose a cudatoolkit version that matches your server. The code is tested with PyTorch 2.0.1 + CUDA 12.0.
+
+ ```shell
+ conda create -n ssir python==3.10
+ conda activate ssir
+ # You can choose the PyTorch version you like; we recommend version >= 1.10.1
+ # For example
+ pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2
+ pip install -r requirements.txt
+ ```
+
+ ## Prepare the Data
+
+ #### 1. Download and deploy the SREDS dataset
+
+ [BaiduNetDisk](https://pan.baidu.com/s/1clA43FcxjOibL1zGTaU82g) (Password: 2728)
+
+ `train.tar` contains the training data, and `test.tar` contains the testing data.
+
+ Move the two `.tar` files to your data root directory and extract them there:
+
+ ```
+ file directory:
+ train:
+     your_data_root/crop_mini/spike/...
+     your_data_root/crop_mini/image/...
+ test:
+     your_data_root/spike/...
+     your_data_root/imgs/...
+ ```
+
+ #### 2. Set the path of the SREDS dataset on your server
+
+ Set it in line 25 of `main.py`, or pass it on the command line when running `main.py`.
+
+ ## Evaluate
+ ```shell
+ cd shells
+ bash eval_SREDS.sh
+ ```
+
+ ## Train
+ ```shell
+ cd shells
+ bash train_SSIR.sh
+ ```
+ We recommend redirecting the output logs for easier management by appending `>> SSIR.txt 2>&1` to the commands above, e.g. `bash train_SSIR.sh >> SSIR.txt 2>&1`.
+
+ ## Citation
+
+ If you find this code useful in your research, please consider citing our paper.
+
+ ```
+ @article{zhao2023spike,
+   title={Spike Camera Image Reconstruction Using Deep Spiking Neural Networks},
+   author={Zhao, Rui and Xiong, Ruiqin and Zhang, Jian and Yu, Zhaofei and Zhu, Shuyuan and Ma, Lei and Huang, Tiejun},
+   journal={IEEE Transactions on Circuits and Systems for Video Technology (TCSVT)},
+   year={2023},
+ }
+ ```
+
+ If you have any questions, please contact: ruizhao@stu.pku.edu.cn
spikezoo/archs/ssir/configs/SSIR.yml
@@ -0,0 +1,37 @@
+ data:
+   interp: 20
+   alpha: 0.4
+
+ seed: 6666
+
+ loader:
+   # crop_size: [128, 128]
+   crop_size: [96, 96]
+   pair_step: 4
+
+ model:
+   arch: 'sunet'
+   seq_len: 8
+   flow_weight_decay: 0.0004
+   flow_bias_decay: 0.0
+   #########################
+   kwargs:
+     activation_type: 'lif'
+     mp_activation_type: 'amp_lif'
+     spike_connection: 'concat'
+     num_encoders: 3
+     num_resblocks: 1
+     v_threshold: 1.0
+     v_reset: None
+     tau: 2.0
+
+ train:
+   print_freq: 100
+   mixed_precision: True
+   vis_freq: 20
+
+ optimizer:
+   solver: Adam
+   momentum: 0.9
+   beta: 0.999
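One detail worth flagging in this config: YAML has no keyword `None`, so `v_reset: None` parses as the string `"None"` rather than a Python `None` (YAML spells null as `null` or `~`). A small sketch of how the file loads, assuming PyYAML:

```python
import yaml

with open("configs/SSIR.yml") as fid:
    cfg = yaml.load(fid, Loader=yaml.FullLoader)  # same loader as yml_parser.py below

print(cfg["model"]["kwargs"]["v_reset"])  # -> 'None' (a string, not None)
print(cfg["loader"]["crop_size"])         # -> [96, 96]
```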
spikezoo/archs/ssir/configs/yml_parser.py
@@ -0,0 +1,78 @@
+ import numpy as np
+ import torch
+ import yaml
+
+
+ class YAMLParser:
+     """
+     Modified from code from tudelft ssl-evflow
+     """
+
+     def __init__(self, config):
+         self.reset_config()
+         self.parse_config(config)
+         # self.init_seeds()
+
+     def parse_config(self, file):
+         with open(file) as fid:
+             yaml_config = yaml.load(fid, Loader=yaml.FullLoader)
+         self.parse_dict(yaml_config)
+
+     @property
+     def config(self):
+         return self._config
+
+     @property
+     def device(self):
+         return self._device
+
+     @property
+     def loader_kwargs(self):
+         return self._loader_kwargs
+
+     def reset_config(self):
+         self._config = {}
+
+     def update(self, config):
+         self.reset_config()
+         self.parse_config(config)
+
+     def parse_dict(self, input_dict, parent=None):
+         if parent is None:
+             parent = self._config
+         for key, val in input_dict.items():
+             if isinstance(val, dict):
+                 if key not in parent.keys():
+                     parent[key] = {}
+                 self.parse_dict(val, parent[key])
+             else:
+                 parent[key] = val
+
+     @staticmethod
+     def worker_init_fn(worker_id):
+         np.random.seed(np.random.get_state()[1][0] + worker_id)
+
+     # def init_seeds(self):
+     #     torch.manual_seed(self._config["loader"]["seed"])
+     #     if torch.cuda.is_available():
+     #         torch.cuda.manual_seed(self._config["loader"]["seed"])
+     #         torch.cuda.manual_seed_all(self._config["loader"]["seed"])
+
+     def merge_configs(self, run):
+         """
+         Overwrites mlflow metadata with configs.
+         """
+
+         # parse mlflow settings
+         config = {}
+         for key in run.keys():
+             if len(run[key]) > 0 and run[key][0] == "{":  # assume dictionary
+                 config[key] = eval(run[key])
+             else:  # string
+                 config[key] = run[key]
+
+         # overwrite with config settings
+         self.parse_dict(self._config, config)
+         self.combine_entries(config)
+
+         return config
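A minimal sketch of how `YAMLParser` is typically used to turn the config file into the nested dict consumed by the datasets below (the config path is an assumption based on the repo layout):

```python
from configs.yml_parser import YAMLParser

parser = YAMLParser("configs/SSIR.yml")
cfg = parser.config               # nested dict built by parse_dict
seq_len = cfg["model"]["seq_len"]  # 8, per the config above
```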
spikezoo/archs/ssir/datasets/dataset_sreds.py
@@ -0,0 +1,170 @@
+ import os
+ import os.path as osp
+ import random
+ import numpy as np
+ import torch
+ import torch.utils.data as data
+ from datasets.ds_utils import *
+ import time
+
+
+ class Augmentor:
+     def __init__(self, crop_size):
+         # spatial augmentation params
+         self.crop_size = crop_size
+
+     def augment_img(self, img, mode=0):
+         '''Kai Zhang (github: https://github.com/cszn)
+         W x H x C or W x H
+         '''
+         if mode == 0:
+             return img
+         elif mode == 1:
+             return np.flipud(np.rot90(img))
+         elif mode == 2:
+             return np.flipud(img)
+         elif mode == 3:
+             return np.rot90(img, k=3)
+         elif mode == 4:
+             return np.flipud(np.rot90(img, k=2))
+         elif mode == 5:
+             return np.rot90(img)
+         elif mode == 6:
+             return np.rot90(img, k=2)
+         elif mode == 7:
+             return np.flipud(np.rot90(img, k=3))
+
+     def spatial_transform(self, spk_list, img_list):
+         mode = random.randint(0, 7)
+
+         for ii, spk in enumerate(spk_list):
+             spk = np.transpose(spk, [1, 2, 0])
+             spk = self.augment_img(spk, mode=mode)
+             spk_list[ii] = np.transpose(spk, [2, 0, 1])
+
+         for ii, img in enumerate(img_list):
+             img = np.transpose(img, [1, 2, 0])
+             img = self.augment_img(img, mode=mode)
+             img_list[ii] = np.transpose(img, [2, 0, 1])
+
+         return spk_list, img_list
+
+     def __call__(self, spk_list, img_list):
+         spk_list, img_list = self.spatial_transform(spk_list, img_list)
+         spk_list = [np.ascontiguousarray(spk) for spk in spk_list]
+         img_list = [np.ascontiguousarray(img) for img in img_list]
+         return spk_list, img_list
+
+
+ class sreds_train(torch.utils.data.Dataset):
+     def __init__(self, cfg):
+         self.cfg = cfg
+         self.pair_step = self.cfg['loader']['pair_step']
+         self.augmentor = Augmentor(crop_size=self.cfg['loader']['crop_size'])
+         self.samples = self.collect_samples()
+         print('The number of training samples: {:d}'.format(len(self.samples)))
+
+     def confirm_exist(self, path_list_list):
+         for pl in path_list_list:
+             for p in pl:
+                 if not osp.exists(p):
+                     return 0
+         return 1
+
+     def collect_samples(self):
+         spike_path = osp.join(self.cfg['data']['root'], 'crop_mini', 'spike', 'train', 'interp_{:d}_alpha_{:.2f}'.format(self.cfg['data']['interp'], self.cfg['data']['alpha']))
+         image_path = osp.join(self.cfg['data']['root'], 'crop_mini', 'image', 'train', 'train_orig')
+         scene_list = sorted(os.listdir(spike_path))
+         samples = []
+
+         for scene in scene_list:
+             spike_dir = osp.join(spike_path, scene)
+             image_dir = osp.join(image_path, scene)
+             spk_path_list = sorted(os.listdir(spike_dir))
+
+             spklen = len(spk_path_list)
+             seq_len = self.cfg['model']['seq_len'] + 2
+
+             for st in range(0, spklen - ((spklen - self.pair_step) % seq_len) - seq_len, self.pair_step):
+                 # pair spike files with images by file name
+                 spikes_path_list = [osp.join(spike_dir, spk_path_list[ii]) for ii in range(st, st + seq_len)]
+                 images_path_list = [osp.join(image_dir, spk_path_list[ii][:-4] + '.png') for ii in range(st, st + seq_len)]
+
+                 if self.confirm_exist([spikes_path_list, images_path_list]):
+                     s = {}
+                     s['spikes_paths'] = spikes_path_list
+                     s['images_paths'] = images_path_list
+                     samples.append(s)
+         return samples
+
+     def _load_sample(self, s):
+         data = {}
+
+         data['spikes'] = [np.array(dat_to_spmat(p, size=(96, 96)), dtype=np.float32) for p in s['spikes_paths']]
+         data['images'] = [read_img_gray(p) for p in s['images_paths']]
+
+         data['spikes'], data['images'] = self.augmentor(data['spikes'], data['images'])
+
+         return data
+
+     def __len__(self):
+         return len(self.samples)
+
+     def __getitem__(self, index):
+         data = self._load_sample(self.samples[index])
+         return data
+
+
+ class sreds_test(torch.utils.data.Dataset):
+     def __init__(self, cfg):
+         self.cfg = cfg
+         self.samples = self.collect_samples()
+         print('The number of testing samples: {:d}'.format(len(self.samples)))
+
+     def confirm_exist(self, path_list_list):
+         for pl in path_list_list:
+             for p in pl:
+                 if not osp.exists(p):
+                     return 0
+         return 1
+
+     def collect_samples(self):
+         spike_path = osp.join(self.cfg['data']['root'], 'spike', 'val', 'interp_{:d}_alpha_{:.2f}'.format(self.cfg['data']['interp'], self.cfg['data']['alpha']))
+         image_path = osp.join(self.cfg['data']['root'], 'imgs', 'val', 'val_orig')
+         scene_list = sorted(os.listdir(spike_path))
+         samples = []
+
+         for scene in scene_list:
+             spike_dir = osp.join(spike_path, scene)
+             image_dir = osp.join(image_path, scene)
+             spk_path_list = sorted(os.listdir(spike_dir))
+
+             spklen = len(spk_path_list)
+
+             # pair spike files with images by file name
+             spikes_path_list = [osp.join(spike_dir, spk_path_list[ii]) for ii in range(spklen)]
+             images_path_list = [osp.join(image_dir, spk_path_list[ii][:-4] + '.png') for ii in range(spklen)]
+
+             if self.confirm_exist([spikes_path_list, images_path_list]):
+                 s = {}
+                 s['spikes_paths'] = spikes_path_list
+                 s['images_paths'] = images_path_list
+                 samples.append(s)
+
+         return samples
+
+     def _load_sample(self, s):
+         data = {}
+         data['spikes'] = [np.array(dat_to_spmat(p, size=(720, 1280)), dtype=np.float32) for p in s['spikes_paths']]
+         data['images'] = [read_img_gray(p) for p in s['images_paths']]
+         return data
+
+     def __len__(self):
+         return len(self.samples)
+
+     def __getitem__(self, index):
+         data = self._load_sample(self.samples[index])
+         return data
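A hedged sketch of wiring `sreds_train` into a PyTorch `DataLoader` (assumes `cfg` is the parsed config from the previous sketch, with `cfg['data']['root']` pointing at the extracted SREDS data):

```python
import torch.utils.data as data
from datasets.dataset_sreds import sreds_train

cfg["data"]["root"] = "/path/to/your_data_root"  # placeholder; see the README above
train_set = sreds_train(cfg)
loader = data.DataLoader(train_set, batch_size=4, shuffle=True, num_workers=4)

batch = next(iter(loader))  # dict with 'spikes' and 'images' lists of float32 tensors
```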
spikezoo/archs/ssir/datasets/ds_utils.py
@@ -0,0 +1,66 @@
+ import numpy as np
+ import os
+ import cv2
+ import os.path as osp
+
+
+ def RawToSpike(video_seq, h, w, flipud=True):
+     video_seq = np.array(video_seq).astype(np.uint8)
+     img_size = h * w
+     img_num = len(video_seq) // (img_size // 8)
+     SpikeMatrix = np.zeros([img_num, h, w], np.uint8)
+     pix_id = np.arange(0, h * w)
+     pix_id = np.reshape(pix_id, (h, w))
+     comparator = np.left_shift(1, np.mod(pix_id, 8))
+     byte_id = pix_id // 8
+
+     for img_id in np.arange(img_num):
+         id_start = img_id * img_size // 8
+         id_end = id_start + img_size // 8
+         cur_info = video_seq[id_start:id_end]
+         data = cur_info[byte_id]
+         result = np.bitwise_and(data, comparator)
+         if flipud:
+             SpikeMatrix[img_id, :, :] = np.flipud((result == comparator))
+         else:
+             SpikeMatrix[img_id, :, :] = (result == comparator)
+
+     return SpikeMatrix
+
+
+ def SpikeToRaw(SpikeSeq, save_path):
+     """
+     SpikeSeq: Numpy array (sfn x h x w)
+     save_path: full saving path (string)
+     Rui Zhao
+     """
+     sfn, h, w = SpikeSeq.shape
+     base = np.power(2, np.linspace(0, 7, 8))
+     fid = open(save_path, 'ab')
+     for img_id in range(sfn):
+         # simulate the vertically flipped image produced by the camera
+         spike = np.flipud(SpikeSeq[img_id, :, :])
+         # numpy flattens in row-major order, matching how the data is stored
+         spike = spike.flatten()
+         spike = spike.reshape([int(h * w / 8), 8])
+         data = spike * base
+         data = np.sum(data, axis=1).astype(np.uint8)
+         fid.write(data.tobytes())
+
+     fid.close()
+
+     return
+
+
+ def dat_to_spmat(dat_path, size=[720, 1280]):
+     f = open(dat_path, 'rb')
+     video_seq = f.read()
+     video_seq = np.frombuffer(video_seq, 'b')
+     sp_mat = RawToSpike(video_seq, size[0], size[1])
+     return sp_mat
+
+
+ def read_img_gray(file_path):
+     im = cv2.imread(file_path).astype(np.float32) / 255.0
+     im = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
+     im = np.expand_dims(im, axis=0)
+     return im
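`SpikeToRaw` and `RawToSpike` are exact inverses: both apply the vertical flip, and both pack/unpack eight pixels per byte in row-major order (the first pixel of each group lands in the least significant bit). A round-trip sketch that makes the `.dat` layout concrete, assuming `ds_utils.py` is importable and `h*w` is divisible by 8; note that `SpikeToRaw` opens the file in append mode, so the target must not already exist:

```python
import os
import numpy as np
from datasets.ds_utils import SpikeToRaw, RawToSpike

h, w = 16, 16
spikes = (np.random.rand(4, h, w) > 0.5).astype(np.uint8)  # 4 binary spike frames

path = "roundtrip.dat"
if os.path.exists(path):
    os.remove(path)              # 'ab' mode would otherwise append to old data
SpikeToRaw(spikes, path)         # 8 pixels -> 1 byte, LSB = first pixel of group

raw = np.frombuffer(open(path, "rb").read(), "b")
decoded = RawToSpike(raw, h, w)  # default flipud=True matches SpikeToRaw
assert np.array_equal(decoded, spikes)
```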