spikezoo 0.1.2__py3-none-any.whl → 0.2__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- spikezoo/__init__.py +13 -0
- spikezoo/archs/__pycache__/__init__.cpython-39.pyc +0 -0
- spikezoo/archs/base/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/base/nets.py +34 -0
- spikezoo/archs/bsf/README.md +92 -0
- spikezoo/archs/bsf/datasets/datasets.py +328 -0
- spikezoo/archs/bsf/datasets/ds_utils.py +64 -0
- spikezoo/archs/bsf/main.py +398 -0
- spikezoo/archs/bsf/metrics/psnr.py +22 -0
- spikezoo/archs/bsf/metrics/ssim.py +54 -0
- spikezoo/archs/bsf/models/bsf/__pycache__/align.cpython-39.pyc +0 -0
- spikezoo/archs/bsf/models/bsf/__pycache__/bsf.cpython-39.pyc +0 -0
- spikezoo/archs/bsf/models/bsf/__pycache__/rep.cpython-39.pyc +0 -0
- spikezoo/archs/bsf/models/bsf/align.py +154 -0
- spikezoo/archs/bsf/models/bsf/bsf.py +105 -0
- spikezoo/archs/bsf/models/bsf/dsft_convert.py +96 -0
- spikezoo/archs/bsf/models/bsf/rep.py +44 -0
- spikezoo/archs/bsf/models/get_model.py +7 -0
- spikezoo/archs/bsf/prepare_data/DSFT.py +62 -0
- spikezoo/archs/bsf/prepare_data/crop_dataset_train.py +135 -0
- spikezoo/archs/bsf/prepare_data/crop_dataset_val.py +139 -0
- spikezoo/archs/bsf/prepare_data/crop_train.sh +4 -0
- spikezoo/archs/bsf/prepare_data/crop_val.sh +4 -0
- spikezoo/archs/bsf/prepare_data/io_utils.py +64 -0
- spikezoo/archs/bsf/requirements.txt +9 -0
- spikezoo/archs/bsf/test.py +16 -0
- spikezoo/archs/bsf/utils.py +154 -0
- spikezoo/archs/spikeclip/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/spikeclip/nets.py +40 -0
- spikezoo/archs/spikeformer/CheckPoints/readme +1 -0
- spikezoo/archs/spikeformer/DataProcess/DataExtactor.py +60 -0
- spikezoo/archs/spikeformer/DataProcess/DataLoader.py +115 -0
- spikezoo/archs/spikeformer/DataProcess/LoadSpike.py +39 -0
- spikezoo/archs/spikeformer/EvalResults/readme +1 -0
- spikezoo/archs/spikeformer/LICENSE +21 -0
- spikezoo/archs/spikeformer/Metrics/Metrics.py +50 -0
- spikezoo/archs/spikeformer/Metrics/__init__.py +0 -0
- spikezoo/archs/spikeformer/Model/Loss.py +89 -0
- spikezoo/archs/spikeformer/Model/SpikeFormer.py +230 -0
- spikezoo/archs/spikeformer/Model/__init__.py +0 -0
- spikezoo/archs/spikeformer/Model/__pycache__/SpikeFormer.cpython-39.pyc +0 -0
- spikezoo/archs/spikeformer/Model/__pycache__/__init__.cpython-39.pyc +0 -0
- spikezoo/archs/spikeformer/README.md +30 -0
- spikezoo/archs/spikeformer/evaluate.py +87 -0
- spikezoo/archs/spikeformer/recon_real_data.py +97 -0
- spikezoo/archs/spikeformer/requirements.yml +95 -0
- spikezoo/archs/spikeformer/train.py +173 -0
- spikezoo/archs/spikeformer/utils.py +22 -0
- spikezoo/archs/spk2imgnet/.github/workflows/pylint.yml +23 -0
- spikezoo/archs/spk2imgnet/.gitignore +150 -0
- spikezoo/archs/spk2imgnet/DCNv2.py +135 -0
- spikezoo/archs/spk2imgnet/__pycache__/DCNv2.cpython-39.pyc +0 -0
- spikezoo/archs/spk2imgnet/__pycache__/align_arch.cpython-39.pyc +0 -0
- spikezoo/archs/spk2imgnet/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/spk2imgnet/align_arch.py +159 -0
- spikezoo/archs/spk2imgnet/dataset.py +144 -0
- spikezoo/archs/spk2imgnet/nets.py +230 -0
- spikezoo/archs/spk2imgnet/readme.md +86 -0
- spikezoo/archs/spk2imgnet/test_gen_imgseq.py +118 -0
- spikezoo/archs/spk2imgnet/train.py +189 -0
- spikezoo/archs/spk2imgnet/utils.py +64 -0
- spikezoo/archs/ssir/README.md +87 -0
- spikezoo/archs/ssir/configs/SSIR.yml +37 -0
- spikezoo/archs/ssir/configs/yml_parser.py +78 -0
- spikezoo/archs/ssir/datasets/dataset_sreds.py +170 -0
- spikezoo/archs/ssir/datasets/ds_utils.py +66 -0
- spikezoo/archs/ssir/losses.py +21 -0
- spikezoo/archs/ssir/main.py +326 -0
- spikezoo/archs/ssir/metrics/psnr.py +22 -0
- spikezoo/archs/ssir/metrics/ssim.py +54 -0
- spikezoo/archs/ssir/models/Vgg19.py +42 -0
- spikezoo/archs/ssir/models/__pycache__/layers.cpython-39.pyc +0 -0
- spikezoo/archs/ssir/models/__pycache__/networks.cpython-39.pyc +0 -0
- spikezoo/archs/ssir/models/layers.py +110 -0
- spikezoo/archs/ssir/models/networks.py +61 -0
- spikezoo/archs/ssir/requirements.txt +8 -0
- spikezoo/archs/ssir/shells/eval_SREDS.sh +6 -0
- spikezoo/archs/ssir/shells/train_SSIR.sh +12 -0
- spikezoo/archs/ssir/test.py +3 -0
- spikezoo/archs/ssir/utils.py +154 -0
- spikezoo/archs/ssml/__pycache__/cbam.cpython-39.pyc +0 -0
- spikezoo/archs/ssml/__pycache__/model.cpython-39.pyc +0 -0
- spikezoo/archs/ssml/cbam.py +224 -0
- spikezoo/archs/ssml/model.py +290 -0
- spikezoo/archs/ssml/res.png +0 -0
- spikezoo/archs/ssml/test.py +67 -0
- spikezoo/archs/stir/.git-credentials +0 -0
- spikezoo/archs/stir/README.md +65 -0
- spikezoo/archs/stir/ckpt_outputs/Descriptions.txt +1 -0
- spikezoo/archs/stir/configs/STIR.yml +37 -0
- spikezoo/archs/stir/configs/utils.py +155 -0
- spikezoo/archs/stir/configs/yml_parser.py +78 -0
- spikezoo/archs/stir/datasets/dataset_sreds.py +180 -0
- spikezoo/archs/stir/datasets/ds_utils.py +66 -0
- spikezoo/archs/stir/eval_SREDS.sh +5 -0
- spikezoo/archs/stir/main.py +397 -0
- spikezoo/archs/stir/metrics/losses.py +219 -0
- spikezoo/archs/stir/metrics/psnr.py +22 -0
- spikezoo/archs/stir/metrics/ssim.py +54 -0
- spikezoo/archs/stir/models/Vgg19.py +42 -0
- spikezoo/archs/stir/models/__pycache__/networks_STIR.cpython-39.pyc +0 -0
- spikezoo/archs/stir/models/__pycache__/submodules.cpython-39.pyc +0 -0
- spikezoo/archs/stir/models/__pycache__/transformer_new.cpython-39.pyc +0 -0
- spikezoo/archs/stir/models/networks_STIR.py +361 -0
- spikezoo/archs/stir/models/submodules.py +86 -0
- spikezoo/archs/stir/models/transformer_new.py +151 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/__init__.py +0 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/convertions.py +721 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/disp_netS.py +133 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/flow_utils.py +167 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/generic_train_test.py +76 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/geometry.py +458 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/image_proc.py +183 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/linalg.py +40 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/losses.py +198 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/metrics.py +51 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/model_base.py +53 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/net_basics.py +100 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/resnet.py +333 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/transforms.py +123 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/utils.py +72 -0
- spikezoo/archs/stir/package_core/dist/package_core-0.0.0-py3.9.egg +0 -0
- spikezoo/archs/stir/package_core/package_core/__init__.py +0 -0
- spikezoo/archs/stir/package_core/package_core/__pycache__/__init__.cpython-39.pyc +0 -0
- spikezoo/archs/stir/package_core/package_core/__pycache__/net_basics.cpython-39.pyc +0 -0
- spikezoo/archs/stir/package_core/package_core/convertions.py +721 -0
- spikezoo/archs/stir/package_core/package_core/disp_netS.py +133 -0
- spikezoo/archs/stir/package_core/package_core/flow_utils.py +167 -0
- spikezoo/archs/stir/package_core/package_core/generic_train_test.py +76 -0
- spikezoo/archs/stir/package_core/package_core/geometry.py +458 -0
- spikezoo/archs/stir/package_core/package_core/image_proc.py +183 -0
- spikezoo/archs/stir/package_core/package_core/linalg.py +40 -0
- spikezoo/archs/stir/package_core/package_core/losses.py +198 -0
- spikezoo/archs/stir/package_core/package_core/metrics.py +51 -0
- spikezoo/archs/stir/package_core/package_core/model_base.py +53 -0
- spikezoo/archs/stir/package_core/package_core/net_basics.py +100 -0
- spikezoo/archs/stir/package_core/package_core/resnet.py +333 -0
- spikezoo/archs/stir/package_core/package_core/transforms.py +123 -0
- spikezoo/archs/stir/package_core/package_core/utils.py +72 -0
- spikezoo/archs/stir/package_core/package_core.egg-info/PKG-INFO +3 -0
- spikezoo/archs/stir/package_core/package_core.egg-info/SOURCES.txt +20 -0
- spikezoo/archs/stir/package_core/package_core.egg-info/dependency_links.txt +1 -0
- spikezoo/archs/stir/package_core/package_core.egg-info/top_level.txt +1 -0
- spikezoo/archs/stir/package_core/setup.py +5 -0
- spikezoo/archs/stir/requirements.txt +12 -0
- spikezoo/archs/stir/train_STIR.sh +9 -0
- spikezoo/archs/tfi/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/tfi/nets.py +43 -0
- spikezoo/archs/tfp/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/tfp/nets.py +13 -0
- spikezoo/archs/wgse/README.md +64 -0
- spikezoo/archs/wgse/__pycache__/dwtnets.cpython-39.pyc +0 -0
- spikezoo/archs/wgse/__pycache__/submodules.cpython-39.pyc +0 -0
- spikezoo/archs/wgse/dataset.py +59 -0
- spikezoo/archs/wgse/demo.png +0 -0
- spikezoo/archs/wgse/demo.py +83 -0
- spikezoo/archs/wgse/dwtnets.py +145 -0
- spikezoo/archs/wgse/eval.py +133 -0
- spikezoo/archs/wgse/logs/WGSE-Dwt1dNet-db8-5-ks3/log.txt +11 -0
- spikezoo/archs/wgse/submodules.py +68 -0
- spikezoo/archs/wgse/train.py +261 -0
- spikezoo/archs/wgse/transform.py +139 -0
- spikezoo/archs/wgse/utils.py +128 -0
- spikezoo/archs/wgse/weights/demo.png +0 -0
- spikezoo/data/base/test/gt/200_part1_key_id151.png +0 -0
- spikezoo/data/base/test/gt/200_part3_key_id151.png +0 -0
- spikezoo/data/base/test/gt/203_part1_key_id151.png +0 -0
- spikezoo/data/base/test/spike/200_part1_key_id151.dat +0 -0
- spikezoo/data/base/test/spike/200_part3_key_id151.dat +0 -0
- spikezoo/data/base/test/spike/203_part1_key_id151.dat +0 -0
- spikezoo/data/base/train/gt/203_part2_key_id151.png +0 -0
- spikezoo/data/base/train/gt/203_part3_key_id151.png +0 -0
- spikezoo/data/base/train/gt/203_part4_key_id151.png +0 -0
- spikezoo/data/base/train/spike/203_part2_key_id151.dat +0 -0
- spikezoo/data/base/train/spike/203_part3_key_id151.dat +0 -0
- spikezoo/data/base/train/spike/203_part4_key_id151.dat +0 -0
- spikezoo/datasets/base_dataset.py +2 -3
- spikezoo/metrics/__init__.py +1 -1
- spikezoo/models/base_model.py +1 -3
- spikezoo/pipeline/base_pipeline.py +7 -5
- spikezoo/pipeline/train_pipeline.py +1 -1
- spikezoo/utils/other_utils.py +16 -6
- spikezoo/utils/spike_utils.py +33 -29
- spikezoo/utils/vidar_loader.cpython-39-x86_64-linux-gnu.so +0 -0
- spikezoo-0.2.dist-info/METADATA +163 -0
- spikezoo-0.2.dist-info/RECORD +211 -0
- spikezoo/models/spcsnet_model.py +0 -19
- spikezoo-0.1.2.dist-info/METADATA +0 -39
- spikezoo-0.1.2.dist-info/RECORD +0 -36
- {spikezoo-0.1.2.dist-info → spikezoo-0.2.dist-info}/LICENSE.txt +0 -0
- {spikezoo-0.1.2.dist-info → spikezoo-0.2.dist-info}/WHEEL +0 -0
- {spikezoo-0.1.2.dist-info → spikezoo-0.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,189 @@
|
|
1
|
+
import argparse
|
2
|
+
import glob
|
3
|
+
import os
|
4
|
+
import re
|
5
|
+
from collections import OrderedDict
|
6
|
+
|
7
|
+
import torch.optim as optim
|
8
|
+
from torch.autograd import Variable
|
9
|
+
from torch.utils.data import DataLoader
|
10
|
+
from dataset import *
|
11
|
+
from nets import *
|
12
|
+
from utils import *
|
13
|
+
|
14
|
+
os.environ["CUDA_VISIBLE_DEVICES"] = "0"


def _str2bool(value):
    """Convert a command-line string such as ``True``/``false``/``1``/``no`` to bool.

    ``argparse`` with ``type=bool`` returns ``bool(string)``, so any non-empty
    string -- including ``"False"`` -- became ``True``. This converter keeps
    the ``--flag True`` / ``--flag False`` spelling working as intended.

    Raises:
        argparse.ArgumentTypeError: if ``value`` is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    lowered = str(value).strip().lower()
    if lowered in ("1", "true", "t", "yes", "y"):
        return True
    if lowered in ("0", "false", "f", "no", "n"):
        return False
    raise argparse.ArgumentTypeError("expected a boolean value, got %r" % (value,))


parser = argparse.ArgumentParser(description="Spk2ImgNet")
parser.add_argument(
    "--preprocess", type=_str2bool, default=False, help="run prepare_data or not"
)
parser.add_argument("--batchSize", type=int, default=16, help="Trainning batch size")
parser.add_argument(
    "--num_of_layers", type=int, default=17, help="Number of total layers"
)
parser.add_argument("--epochs", type=int, default=61, help="Number of trainning epochs")
parser.add_argument(
    "--milestone",
    type=int,
    default=20,
    help="When to decay learning rate: should be less than epochs",
)
parser.add_argument(
    "--lr",
    type=float,
    default=1e-4,
    help="Initial learning rate",
)
parser.add_argument(
    "--outf",
    type=str,
    default="./ckpt2",
    help="path of log files",
)
parser.add_argument(
    "--load_model",
    type=_str2bool,
    default=False,
    help="resume from the latest model_XXX.pth in --outf",
)
opt = parser.parse_args()

# Checkpoints are written into opt.outf; create it (and any parents) if missing.
os.makedirs(opt.outf, exist_ok=True)
|
50
|
+
|
51
|
+
|
52
|
+
def find_last_checkpoint(save_dir):
    """Return the highest epoch number among ``model_<epoch>.pth`` files in ``save_dir``.

    Only file names that are exactly ``model_`` + digits + ``.pth`` count.
    Other files matching the ``model_*.pth`` glob (for example
    ``model_best.pth``) are ignored instead of raising ``ValueError``.

    Args:
        save_dir: directory that holds the checkpoints.

    Returns:
        The largest epoch found, or 0 when there is no numbered checkpoint.
    """
    pattern = re.compile(r"model_(\d+)\.pth")
    epochs = []
    for path in glob.glob(os.path.join(save_dir, "model_*.pth")):
        # Match on the basename so directory names cannot confuse the regex.
        match = pattern.fullmatch(os.path.basename(path))
        if match:
            epochs.append(int(match.group(1)))
    return max(epochs, default=0)
|
63
|
+
|
64
|
+
|
65
|
+
def _strip_module_prefix(state_dict):
    """Return a copy of ``state_dict`` with a leading ``module.`` removed from keys.

    ``main`` saves ``nn.DataParallel(model).state_dict()``, whose keys carry a
    ``module.`` prefix, while the bare model expects keys without it. Keys
    that do not start with the prefix are kept unchanged (the old ``k[7:]``
    truncated every key unconditionally).
    """
    prefix = "module."
    stripped = OrderedDict()
    for key, value in state_dict.items():
        if key.startswith(prefix):
            key = key[len(prefix):]
        stripped[key] = value
    return stripped


def _multi_frame_loss(criterion, gt, rec, estimates):
    """Combine the main reconstruction loss with the weighted auxiliary losses.

    ``rec`` is compared with ``gt[:, 2:3]``. For each of the 4 slices along
    dim 1 of the estimates, estimate ``j`` is compared with ``gt[:, j:j+1]``,
    and the per-slice sum is weighted by 0.02.
    """
    loss = criterion(gt[:, 2:3, :, :], rec)
    for slice_id in range(4):
        aux = sum(
            criterion(gt[:, j : j + 1, :, :], est[:, slice_id : slice_id + 1, :, :])
            for j, est in enumerate(estimates)
        )
        loss = loss + 0.02 * aux
    return loss


def main():
    """Train Spk2ImgNet on the ``train`` split and save a checkpoint every epoch.

    Hyper-parameters come from the module-level ``opt``. With
    ``--load_model`` the newest ``model_XXX.pth`` in ``opt.outf`` is loaded
    and training resumes from that epoch; if none exists, training starts
    from scratch instead of failing on a missing ``model_000.pth``.
    """
    # Load dataset
    print("Loading dataset ...\n")
    dataset_train = Dataset("train")
    loader_train = DataLoader(
        dataset=dataset_train, num_workers=4, batch_size=opt.batchSize, shuffle=True
    )
    print("# of training samples: %d\n" % int(len(dataset_train)))
    # NOTE: validation on Dataset("val_stack") is disabled (old dead code removed);
    # re-enabling it needs the same 6-output unpacking as the training loop.

    # Build model
    model = SpikeNet(in_channels=13, features=64, out_channels=1, win_r=6, win_step=7)
    initial_epoch = 0
    if opt.load_model:
        initial_epoch = find_last_checkpoint(save_dir=opt.outf)
        if initial_epoch > 0:
            ckpt_path = os.path.join(opt.outf, "model_%03d.pth" % initial_epoch)
            print("load model from %s" % ckpt_path)
            state_dict = torch.load(ckpt_path, map_location="cpu")
            model.load_state_dict(_strip_module_prefix(state_dict))
        else:
            print("no checkpoint found in %s, training from scratch" % opt.outf)

    criterion = nn.L1Loss(reduction="mean").cuda()
    # Move to GPU
    model = nn.DataParallel(model).cuda()
    optimizer = optim.Adam(model.parameters(), lr=opt.lr)

    num_batches = len(loader_train)
    for epoch in range(initial_epoch, opt.epochs):
        # Step schedule: opt.lr before the milestone epoch, opt.lr / 10 after it.
        current_lr = opt.lr if epoch < opt.milestone else opt.lr / 10.0
        for param_group in optimizer.param_groups:
            param_group["lr"] = current_lr
        print("learning rate %f" % current_lr)

        model.train()
        avg_psnr = 0
        for i, (inputs, gt) in enumerate(loader_train):
            inputs = inputs.cuda()
            gt = gt.cuda()
            optimizer.zero_grad()
            rec, est0, est1, est2, est3, est4 = model(inputs)
            # Every output is divided by 0.6 before comparison with gt
            # (presumably undoing a 0.6 scaling elsewhere -- TODO confirm).
            rec = rec / 0.6
            estimates = [est / 0.6 for est in (est0, est1, est2, est3, est4)]
            loss = _multi_frame_loss(criterion, gt, rec, estimates)
            loss.backward()
            optimizer.step()

            rec = torch.clamp(rec, 0, 1)
            psnr_train = batch_psnr(rec, gt[:, 2:3, :, :], 1.0)
            avg_psnr += psnr_train
            if i % 10 == 0:
                print(
                    "[epoch %d][%d | %d] loss: %.4f PSNR_train: %.4f"
                    % (epoch + 1, i + 1, num_batches, loss.item(), psnr_train)
                )
        # max(..., 1) guards against an empty loader.
        avg_psnr = avg_psnr / max(num_batches, 1)
        print("avg_psnr: %.2f" % avg_psnr)

        # save model
        torch.save(
            model.state_dict(),
            os.path.join(opt.outf, "model_%03d.pth" % (epoch + 1)),
        )
|
184
|
+
|
185
|
+
|
186
|
+
if __name__ == "__main__":
    # Optionally build the training data before training starts. prepare_data is
    # not defined in this file, so it is presumably provided by one of the star
    # imports above -- confirm.
    if opt.preprocess:
        prepare_data(
            data_path="./Spk2ImgNet_train/train2/",
            patch_size=40,
            stride=40,
            h5_name='train',
        )
    main()
|
@@ -0,0 +1,64 @@
|
|
1
|
+
import math
|
2
|
+
|
3
|
+
import numpy as np
|
4
|
+
import torch.nn as nn
|
5
|
+
from skimage.metrics import peak_signal_noise_ratio
|
6
|
+
|
7
|
+
|
8
|
+
def weights_init_kaiming(m):
    """Initialise the parameters of a single module in place.

    Takes one module, so it can be passed to ``nn.Module.apply``.

    * Class name containing ``Conv`` or ``Linear``: Kaiming-normal weights
      (``a=0``, ``mode="fan_in"``).
    * Class name containing ``BatchNorm``: weights drawn from
      N(0, sqrt(2 / 9 / 64)) and clamped to [-0.025, 0.025]; bias set to 0.

    Parameters that are ``None`` are skipped rather than raising
    ``AttributeError``. This happens with ``bias=False`` convolutions or
    ``affine=False`` batch norms. All other modules are left untouched.
    """
    classname = m.__class__.__name__
    weight = getattr(m, "weight", None)
    if classname.find("Conv") != -1 or classname.find("Linear") != -1:
        if weight is not None:
            nn.init.kaiming_normal_(weight.data, a=0, mode="fan_in")
    elif classname.find("BatchNorm") != -1:
        if weight is not None:
            weight.data.normal_(mean=0, std=math.sqrt(2.0 / 9.0 / 64.0)).clamp_(
                -0.025, 0.025
            )
        bias = getattr(m, "bias", None)
        if bias is not None:
            # constant_ replaces the deprecated nn.init.constant alias.
            nn.init.constant_(bias.data, 0.0)
|
20
|
+
|
21
|
+
|
22
|
+
def batch_psnr(img, imclean, data_range):
    """Compute one PSNR value over an entire batch of tensors.

    Both tensors are taken via ``.data`` (no autograd history), copied to the
    CPU and cast to float32, then handed to skimage's
    ``peak_signal_noise_ratio``. The result is a single PSNR over all
    elements of the batch, not the mean of per-image PSNRs.

    Args:
        img: reconstructed tensor.
        imclean: reference tensor with the same shape.
        data_range: dynamic range of the data (e.g. 1.0 for [0, 1] images).
    """

    def _to_float32(tensor):
        return tensor.data.cpu().numpy().astype(np.float32)

    restored = _to_float32(img)
    reference = _to_float32(imclean)
    return peak_signal_noise_ratio(restored, reference, data_range=data_range)
|
33
|
+
|
34
|
+
|
35
|
+
def data_augmentation(image, mode):
    """Apply one of 8 rotate/flip transforms to a channel-first array.

    The array is moved to channel-last, transformed, and moved back.
    ``divmod(mode, 2)`` gives the number of ``np.rot90`` quarter turns (0-3)
    and whether to apply ``np.flipud`` after the rotation:

    * 0: original
    * 1: flip
    * 2: rot90
    * 3: rot90 + flip
    * 4: rot180
    * 5: rot180 + flip
    * 6: rot270
    * 7: rot270 + flip

    Any other mode leaves the image unchanged.
    """
    hwc = np.transpose(image, (1, 2, 0))
    if mode in range(8):
        quarter_turns, flip = divmod(int(mode), 2)
        if quarter_turns:
            hwc = np.rot90(hwc, k=quarter_turns)
        if flip:
            hwc = np.flipud(hwc)
    return np.transpose(hwc, (2, 0, 1))
|
@@ -0,0 +1,87 @@
|
|
1
|
+
## [TCSVT 2023] Spike Camera Image Reconstruction Using Deep Spiking Neural Networks
|
2
|
+
|
3
|
+
<h4 align="center"> Rui Zhao<sup>1</sup>, Ruiqin Xiong<sup>1</sup>, Jian Zhang<sup>2</sup>, Zhaofei Yu<sup>1</sup>, Shuyuan Zhu<sup>3</sup>, Lei Ma <sup>1</sup>, Tiejun Huang<sup>1</sup> </h4>
|
4
|
+
<h4 align="center">1. National Engineering Research Center of Visual Technology, School of Computer Science, Peking University<br>
|
5
|
+
2. School of Electronic and Computer Engineering, Peking University Shenzhen Graduate School<br>
|
6
|
+
3. School of Information and Communication Engineering, UESTC</h4><br>
|
7
|
+
|
8
|
+
This repository contains the official source code for our paper:
|
9
|
+
|
10
|
+
Spike Camera Image Reconstruction Using Deep Spiking Neural Networks
|
11
|
+
|
12
|
+
TCSVT 2023
|
13
|
+
|
14
|
+
[Paper](https://ieeexplore.ieee.org/document/10288531)
|
15
|
+
|
16
|
+
|
17
|
+
|
18
|
+
## Environment
|
19
|
+
|
20
|
+
You can choose the cudatoolkit version that matches your server. The code has been tested with PyTorch 2.0.1 and CUDA 12.0.
|
21
|
+
|
22
|
+
```shell
|
23
|
+
conda create -n ssir python==3.10
|
24
|
+
conda activate ssir
|
25
|
+
# You can choose any PyTorch version you like; we recommend version >= 1.10.1
|
26
|
+
# For example
|
27
|
+
pip install torch==2.0.1 torchvision==0.15.2 torchaudio==2.0.2
|
28
|
+
pip install -r requirements.txt
|
29
|
+
```
|
30
|
+
|
31
|
+
## Prepare the Data
|
32
|
+
|
33
|
+
#### 1. Download and deploy the SREDS dataset
|
34
|
+
|
35
|
+
[BaiduNetDisk](https://pan.baidu.com/s/1clA43FcxjOibL1zGTaU82g) (Password: 2728)
|
36
|
+
|
37
|
+
`train.tar` corresponds to the training data, and `test.tar` corresponds to the testing data.
|
38
|
+
|
39
|
+
Move the two `.tar` files above into your data root directory and extract them there.
|
40
|
+
|
41
|
+
```
|
42
|
+
file directory:
|
43
|
+
train:
|
44
|
+
your_data_root/crop_mini/spike/...
|
45
|
+
your_data_root/crop_mini/image/...
|
46
|
+
test:
|
47
|
+
your_data_root/spike/...
|
48
|
+
your_data_root/imgs/...
|
49
|
+
```
|
50
|
+
|
51
|
+
#### 2. Set the path of the SREDS dataset on your server
|
52
|
+
|
53
|
+
Set it at line 25 of `main.py`, or pass it on the command line when running `main.py`.
|
54
|
+
|
55
|
+
## Evaluate
|
56
|
+
```shell
|
57
|
+
cd shells
|
58
|
+
bash eval_SREDS.sh
|
59
|
+
```
|
60
|
+
|
61
|
+
## Train
|
62
|
+
```shell
|
63
|
+
cd shells
|
64
|
+
bash train_SSIR.sh
|
65
|
+
```
|
66
|
+
We recommend redirecting the output logs by appending
|
67
|
+
`>> SSIR.txt 2>&1`
|
68
|
+
to the end of the above command, which makes the logs easier to manage.
|
69
|
+
|
70
|
+
|
71
|
+
## Citation
|
72
|
+
|
73
|
+
If you find this code useful in your research, please consider citing our paper.
|
74
|
+
|
75
|
+
```
|
76
|
+
@article{zhao2023spike,
|
77
|
+
title={Spike Camera Image Reconstruction Using Deep Spiking Neural Networks},
|
78
|
+
author={Zhao, Rui and Xiong, Ruiqin and Zhang, Jian and Yu, Zhaofei and Zhu, Shuyuan and Ma, Lei and Huang, Tiejun},
|
79
|
+
journal={IEEE Transactions on Circuits and Systems for Video Technology (TCSVT)},
|
80
|
+
year={2023},
|
81
|
+
}
|
82
|
+
```
|
83
|
+
|
84
|
+
If you have any questions, please contact:
|
85
|
+
ruizhao@stu.pku.edu.cn
|
86
|
+
|
87
|
+
|
@@ -0,0 +1,37 @@
|
|
1
|
+
data:
|
2
|
+
interp: 20
|
3
|
+
alpha: 0.4
|
4
|
+
|
5
|
+
seed: 6666
|
6
|
+
|
7
|
+
loader:
|
8
|
+
# crop_size: [128, 128]
|
9
|
+
crop_size: [96, 96]
|
10
|
+
pair_step: 4
|
11
|
+
|
12
|
+
model:
|
13
|
+
arch: 'sunet'
|
14
|
+
seq_len: 8
|
15
|
+
flow_weight_decay: 0.0004
|
16
|
+
flow_bias_decay: 0.0
|
17
|
+
#########################
|
18
|
+
kwargs:
|
19
|
+
activation_type: 'lif'
|
20
|
+
mp_activation_type: 'amp_lif'
|
21
|
+
spike_connection: 'concat'
|
22
|
+
num_encoders: 3
|
23
|
+
num_resblocks: 1
|
24
|
+
v_threshold: 1.0
|
25
|
+
v_reset: None
|
26
|
+
tau: 2.0
|
27
|
+
|
28
|
+
|
29
|
+
train:
|
30
|
+
print_freq: 100
|
31
|
+
mixed_precision: True
|
32
|
+
vis_freq: 20
|
33
|
+
|
34
|
+
optimizer:
|
35
|
+
solver: Adam
|
36
|
+
momentum: 0.9
|
37
|
+
beta: 0.999
|
@@ -0,0 +1,78 @@
|
|
1
|
+
import numpy as np
|
2
|
+
import torch
|
3
|
+
import yaml
|
4
|
+
|
5
|
+
|
6
|
+
class YAMLParser:
    """Load a YAML configuration file into a nested dictionary.

    Modified from code from tudelft ssl-evflow
    """

    def __init__(self, config):
        """Parse the YAML file at path ``config`` into ``self._config``."""
        self.reset_config()
        self.parse_config(config)

    def parse_config(self, file):
        """Read the YAML file ``file`` and merge its contents into the config."""
        with open(file) as fid:
            yaml_config = yaml.load(fid, Loader=yaml.FullLoader)
        self.parse_dict(yaml_config)

    @property
    def config(self):
        """The parsed configuration as a nested dict."""
        return self._config

    @property
    def device(self):
        # NOTE(review): ``_device`` is never assigned in this class, so this
        # raises AttributeError unless a caller or subclass sets it first.
        return self._device

    @property
    def loader_kwargs(self):
        # NOTE(review): ``_loader_kwargs`` is never assigned in this class either.
        return self._loader_kwargs

    def reset_config(self):
        """Discard any previously parsed configuration."""
        self._config = {}

    def update(self, config):
        """Replace the current configuration with the YAML file ``config``."""
        self.reset_config()
        self.parse_config(config)

    def parse_dict(self, input_dict, parent=None):
        """Recursively copy ``input_dict`` into ``parent`` (default: the config).

        Nested dicts are merged key by key. Any non-dict value overwrites the
        existing entry.
        """
        if parent is None:
            parent = self._config
        for key, val in input_dict.items():
            if isinstance(val, dict):
                if key not in parent.keys():
                    parent[key] = {}
                self.parse_dict(val, parent[key])
            else:
                parent[key] = val

    @staticmethod
    def worker_init_fn(worker_id):
        """Reseed NumPy's global RNG with an offset of ``worker_id``.

        Presumably used as a DataLoader ``worker_init_fn``. The sum is reduced
        modulo 2**32 because ``np.random.seed`` only accepts values in
        [0, 2**32 - 1], and a large state word plus ``worker_id`` could
        otherwise exceed that range.
        """
        base_seed = int(np.random.get_state()[1][0])
        np.random.seed((base_seed + worker_id) % 2**32)

    def merge_configs(self, run):
        """
        Overwrites mlflow metadata with configs.

        Values in ``run`` that start with ``{`` are evaluated as dicts; all
        other values are kept as strings. The parsed YAML config is then
        merged on top, and the resulting dict is returned.
        """

        # parse mlflow settings
        config = {}
        for key in run.keys():
            if len(run[key]) > 0 and run[key][0] == "{":  # assume dictionary
                # NOTE(security): eval executes arbitrary code; only use with
                # trusted run metadata (ast.literal_eval suffices for plain literals).
                config[key] = eval(run[key])
            else:  # string
                config[key] = run[key]

        # overwrite with config settings
        self.parse_dict(self._config, config)
        self.combine_entries(config)

        return config

    def combine_entries(self, config):
        """Post-process a merged config (hook; no-op by default).

        ``merge_configs`` calls this method, but no implementation existed in
        this class, so every call raised AttributeError. This default returns
        ``config`` unchanged; override it if entries need combining.
        """
        return config
|
@@ -0,0 +1,170 @@
|
|
1
|
+
import os
|
2
|
+
import os.path as osp
|
3
|
+
import random
|
4
|
+
import numpy as np
|
5
|
+
import torch
|
6
|
+
import torch.utils.data as data
|
7
|
+
from datasets.ds_utils import *
|
8
|
+
import time
|
9
|
+
|
10
|
+
|
11
|
+
class Augmentor:
    """Apply one random rotate/flip transform jointly to spike and image arrays.

    A single mode is drawn per call and used for every array, so spike frames
    and ground-truth images stay spatially aligned.
    """

    def __init__(self, crop_size):
        # Stored for reference; the transforms below do not read it.
        self.crop_size = crop_size

    def augment_img(self, img, mode=0):
        '''Kai Zhang (github: https://github.com/cszn)

        Rotate/flip ``img`` over its first two axes according to ``mode`` (0-7).
        Mode 0 returns ``img`` itself; an unknown mode returns None.
        '''
        if mode == 0:
            return img
        # (mode, np.rot90 quarter turns, apply np.flipud afterwards)
        recipes = (
            (1, 1, True),
            (2, 0, True),
            (3, 3, False),
            (4, 2, True),
            (5, 1, False),
            (6, 2, False),
            (7, 3, True),
        )
        for key, quarter_turns, flip in recipes:
            if mode == key:
                out = np.rot90(img, k=quarter_turns) if quarter_turns else img
                return np.flipud(out) if flip else out
        return None

    def spatial_transform(self, spk_list, img_list):
        """Transform every channel-first array in both lists in place and return them."""
        mode = random.randint(0, 7)
        for arrays in (spk_list, img_list):
            for idx, chw in enumerate(arrays):
                hwc = self.augment_img(np.transpose(chw, [1, 2, 0]), mode=mode)
                arrays[idx] = np.transpose(hwc, [2, 0, 1])
        return spk_list, img_list

    def __call__(self, spk_list, img_list):
        """Augment both lists and make every resulting array C-contiguous."""
        spk_list, img_list = self.spatial_transform(spk_list, img_list)
        contiguous_spikes = [np.ascontiguousarray(arr) for arr in spk_list]
        contiguous_images = [np.ascontiguousarray(arr) for arr in img_list]
        return contiguous_spikes, contiguous_images
|
57
|
+
|
58
|
+
|
59
|
+
class sreds_train(torch.utils.data.Dataset):
    """Training split of SREDS as sliding windows of spike/image path lists.

    Each sample covers ``cfg['model']['seq_len'] + 2`` consecutive files of a
    scene. Consecutive windows start ``cfg['loader']['pair_step']`` files
    apart. Spikes are decoded at 96x96 and augmented jointly with the
    grayscale images.
    """

    def __init__(self, cfg):
        self.cfg = cfg
        self.pair_step = cfg['loader']['pair_step']
        self.augmentor = Augmentor(crop_size=cfg['loader']['crop_size'])
        self.samples = self.collect_samples()
        print('The samples num of training data: {:d}'.format(len(self.samples)))

    def confirm_exist(self, path_list_list):
        """Return 1 if every path in every inner list exists, otherwise 0."""
        everything_present = all(
            osp.exists(path) for group in path_list_list for path in group
        )
        return int(everything_present)

    def collect_samples(self):
        """Build the list of windows whose spike and image files all exist."""
        data_cfg = self.cfg['data']
        spike_root = osp.join(
            data_cfg['root'], 'crop_mini', 'spike', 'train',
            'interp_{:d}_alpha_{:.2f}'.format(data_cfg['interp'], data_cfg['alpha']),
        )
        image_root = osp.join(data_cfg['root'], 'crop_mini', 'image', 'train', 'train_orig')

        samples = []
        for scene in sorted(os.listdir(spike_root)):
            spike_dir = osp.join(spike_root, scene)
            image_dir = osp.join(image_root, scene)
            names = sorted(os.listdir(spike_dir))
            total = len(names)
            window = self.cfg['model']['seq_len'] + 2
            # For a positive pair_step this stop value keeps every window inside ``names``.
            stop = total - ((total - self.pair_step) % window) - window
            for start in range(0, stop, self.pair_step):
                # Pair files by name: spike "xxx.dat" <-> image "xxx.png".
                picked = [names[i] for i in range(start, start + window)]
                spike_paths = [osp.join(spike_dir, name) for name in picked]
                image_paths = [osp.join(image_dir, name[:-4] + '.png') for name in picked]
                if self.confirm_exist([spike_paths, image_paths]):
                    samples.append({'spikes_paths': spike_paths, 'images_paths': image_paths})
        return samples

    def _load_sample(self, s):
        """Decode the spikes and images listed in ``s`` and augment them together."""
        spikes = [
            np.array(dat_to_spmat(path, size=(96, 96)), dtype=np.float32)
            for path in s['spikes_paths']
        ]
        images = [read_img_gray(path) for path in s['images_paths']]
        spikes, images = self.augmentor(spikes, images)
        return {'spikes': spikes, 'images': images}

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        return self._load_sample(self.samples[index])
|
118
|
+
|
119
|
+
|
120
|
+
class sreds_test(torch.utils.data.Dataset):
    """Test split of SREDS: one full-length spike/image sequence per scene.

    Spikes are decoded at 720x1280 and no augmentation is applied.
    """

    def __init__(self, cfg):
        self.cfg = cfg
        self.samples = self.collect_samples()
        print('The samples num of testing data: {:d}'.format(len(self.samples)))

    def confirm_exist(self, path_list_list):
        """Return 1 if every path in every inner list exists, otherwise 0."""
        everything_present = all(
            osp.exists(path) for group in path_list_list for path in group
        )
        return int(everything_present)

    def collect_samples(self):
        """Build one sample per scene, skipping scenes with any missing file."""
        data_cfg = self.cfg['data']
        spike_root = osp.join(
            data_cfg['root'], 'spike', 'val',
            'interp_{:d}_alpha_{:.2f}'.format(data_cfg['interp'], data_cfg['alpha']),
        )
        image_root = osp.join(data_cfg['root'], 'imgs', 'val', 'val_orig')

        samples = []
        for scene in sorted(os.listdir(spike_root)):
            spike_dir = osp.join(spike_root, scene)
            image_dir = osp.join(image_root, scene)
            names = sorted(os.listdir(spike_dir))
            # Pair files by name: spike "xxx.dat" <-> image "xxx.png".
            spike_paths = [osp.join(spike_dir, name) for name in names]
            image_paths = [osp.join(image_dir, name[:-4] + '.png') for name in names]
            if self.confirm_exist([spike_paths, image_paths]):
                samples.append({'spikes_paths': spike_paths, 'images_paths': image_paths})
        return samples

    def _load_sample(self, s):
        """Decode every spike file and read every image listed in ``s``."""
        return {
            'spikes': [
                np.array(dat_to_spmat(path, size=(720, 1280)), dtype=np.float32)
                for path in s['spikes_paths']
            ],
            'images': [read_img_gray(path) for path in s['images_paths']],
        }

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        return self._load_sample(self.samples[index])
|
@@ -0,0 +1,66 @@
|
|
1
|
+
import numpy as np
|
2
|
+
import os
|
3
|
+
import cv2
|
4
|
+
import os.path as osp
|
5
|
+
|
6
|
+
def RawToSpike(video_seq, h, w, flipud=True):
    """Unpack a 1-bit-per-pixel spike byte stream into a (frames, h, w) uint8 array.

    Each frame occupies ``h * w // 8`` bytes. Pixel ``p`` (row-major) is bit
    ``p % 8`` of byte ``p // 8`` within its frame. Trailing bytes that do not
    fill a whole frame are ignored. When ``flipud`` is true, each decoded frame
    is flipped vertically.
    """
    stream = np.array(video_seq).astype(np.uint8)
    img_size = h * w
    bytes_per_frame = img_size // 8
    n_frames = len(stream) // bytes_per_frame

    pixel_index = np.arange(img_size).reshape(h, w)
    bit_mask = 1 << (pixel_index % 8)
    byte_index = pixel_index // 8

    frames = np.zeros((n_frames, h, w), dtype=np.uint8)
    for k in range(n_frames):
        start = k * img_size // 8
        chunk = stream[start:start + bytes_per_frame]
        fired = (chunk[byte_index] & bit_mask) == bit_mask
        frames[k] = np.flipud(fired) if flipud else fired
    return frames
|
28
|
+
|
29
|
+
|
30
|
+
def SpikeToRaw(SpikeSeq, save_path):
    """
    SpikeSeq: Numpy array (sfn x h x w)
    save_path: full saving path (string)
    Rui Zhao

    Pack binary spike frames into bytes and append them to ``save_path``.

    Each frame is flipped vertically, to mimic the camera's inverted image.
    It is then flattened row-major, and every 8 consecutive pixels form one
    byte, with the first pixel in the least-significant bit. ``h * w`` must
    be a multiple of 8.

    The file is opened in append mode, so repeated calls extend an existing
    file. A ``with`` block guarantees it is closed even if a frame fails to
    encode.
    """
    sfn, h, w = SpikeSeq.shape
    base = np.power(2, np.linspace(0, 7, 8))  # bit weights 1, 2, ..., 128
    with open(save_path, 'ab') as fid:
        for img_id in range(sfn):
            # Simulate the camera's inverted image.
            spike = np.flipud(SpikeSeq[img_id, :, :])
            # numpy flattens row by row, matching how the data is stored.
            spike = spike.flatten().reshape([h * w // 8, 8])
            data = np.sum(spike * base, axis=1).astype(np.uint8)
            fid.write(data.tobytes())
|
52
|
+
|
53
|
+
|
54
|
+
def dat_to_spmat(dat_path, size=(720, 1280)):
    """Read a packed spike ``.dat`` file and decode it with ``RawToSpike``.

    Args:
        dat_path: path of the binary spike file.
        size: ``(h, w)`` of each frame. The default is a tuple, replacing the
            former shared mutable list; any indexable pair is still accepted.

    Returns:
        The ``(frames, h, w)`` uint8 array produced by ``RawToSpike`` with its
        default vertical flip.
    """
    # ``with`` closes the file handle, which the previous version leaked.
    with open(dat_path, 'rb') as f:
        video_seq = np.frombuffer(f.read(), 'b')
    return RawToSpike(video_seq, size[0], size[1])
|
60
|
+
|
61
|
+
|
62
|
+
def read_img_gray(file_path):
    """Load an image as a float32 grayscale array with a leading channel axis.

    The image is read with ``cv2.imread`` (BGR), scaled to [0, 1] by dividing
    by 255, and converted to gray. The result has shape ``(1, H, W)``.

    Raises:
        FileNotFoundError: if OpenCV cannot read ``file_path``. ``cv2.imread``
            returns None in that case, which previously surfaced as an
            unrelated ``AttributeError``.
    """
    im = cv2.imread(file_path)
    if im is None:
        raise FileNotFoundError("cv2.imread could not read image: {}".format(file_path))
    im = im.astype(np.float32) / 255.0
    im = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
    return np.expand_dims(im, axis=0)
|