spikezoo 0.1.2__py3-none-any.whl → 0.2.1__py3-none-any.whl
- spikezoo/__init__.py +13 -0
- spikezoo/archs/__pycache__/__init__.cpython-39.pyc +0 -0
- spikezoo/archs/base/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/base/nets.py +34 -0
- spikezoo/archs/bsf/README.md +92 -0
- spikezoo/archs/bsf/datasets/datasets.py +328 -0
- spikezoo/archs/bsf/datasets/ds_utils.py +64 -0
- spikezoo/archs/bsf/main.py +398 -0
- spikezoo/archs/bsf/metrics/psnr.py +22 -0
- spikezoo/archs/bsf/metrics/ssim.py +54 -0
- spikezoo/archs/bsf/models/bsf/__pycache__/align.cpython-39.pyc +0 -0
- spikezoo/archs/bsf/models/bsf/__pycache__/bsf.cpython-39.pyc +0 -0
- spikezoo/archs/bsf/models/bsf/__pycache__/rep.cpython-39.pyc +0 -0
- spikezoo/archs/bsf/models/bsf/align.py +154 -0
- spikezoo/archs/bsf/models/bsf/bsf.py +105 -0
- spikezoo/archs/bsf/models/bsf/dsft_convert.py +96 -0
- spikezoo/archs/bsf/models/bsf/rep.py +44 -0
- spikezoo/archs/bsf/models/get_model.py +7 -0
- spikezoo/archs/bsf/prepare_data/DSFT.py +62 -0
- spikezoo/archs/bsf/prepare_data/crop_dataset_train.py +135 -0
- spikezoo/archs/bsf/prepare_data/crop_dataset_val.py +139 -0
- spikezoo/archs/bsf/prepare_data/crop_train.sh +4 -0
- spikezoo/archs/bsf/prepare_data/crop_val.sh +4 -0
- spikezoo/archs/bsf/prepare_data/io_utils.py +64 -0
- spikezoo/archs/bsf/requirements.txt +9 -0
- spikezoo/archs/bsf/test.py +16 -0
- spikezoo/archs/bsf/utils.py +154 -0
- spikezoo/archs/spikeclip/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/spikeclip/nets.py +40 -0
- spikezoo/archs/spikeformer/CheckPoints/readme +1 -0
- spikezoo/archs/spikeformer/DataProcess/DataExtactor.py +60 -0
- spikezoo/archs/spikeformer/DataProcess/DataLoader.py +115 -0
- spikezoo/archs/spikeformer/DataProcess/LoadSpike.py +39 -0
- spikezoo/archs/spikeformer/EvalResults/readme +1 -0
- spikezoo/archs/spikeformer/LICENSE +21 -0
- spikezoo/archs/spikeformer/Metrics/Metrics.py +50 -0
- spikezoo/archs/spikeformer/Metrics/__init__.py +0 -0
- spikezoo/archs/spikeformer/Model/Loss.py +89 -0
- spikezoo/archs/spikeformer/Model/SpikeFormer.py +230 -0
- spikezoo/archs/spikeformer/Model/__init__.py +0 -0
- spikezoo/archs/spikeformer/Model/__pycache__/SpikeFormer.cpython-39.pyc +0 -0
- spikezoo/archs/spikeformer/Model/__pycache__/__init__.cpython-39.pyc +0 -0
- spikezoo/archs/spikeformer/README.md +30 -0
- spikezoo/archs/spikeformer/evaluate.py +87 -0
- spikezoo/archs/spikeformer/recon_real_data.py +97 -0
- spikezoo/archs/spikeformer/requirements.yml +95 -0
- spikezoo/archs/spikeformer/train.py +173 -0
- spikezoo/archs/spikeformer/utils.py +22 -0
- spikezoo/archs/spk2imgnet/.github/workflows/pylint.yml +23 -0
- spikezoo/archs/spk2imgnet/.gitignore +150 -0
- spikezoo/archs/spk2imgnet/DCNv2.py +135 -0
- spikezoo/archs/spk2imgnet/__pycache__/DCNv2.cpython-39.pyc +0 -0
- spikezoo/archs/spk2imgnet/__pycache__/align_arch.cpython-39.pyc +0 -0
- spikezoo/archs/spk2imgnet/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/spk2imgnet/align_arch.py +159 -0
- spikezoo/archs/spk2imgnet/dataset.py +144 -0
- spikezoo/archs/spk2imgnet/nets.py +230 -0
- spikezoo/archs/spk2imgnet/readme.md +86 -0
- spikezoo/archs/spk2imgnet/test_gen_imgseq.py +118 -0
- spikezoo/archs/spk2imgnet/train.py +189 -0
- spikezoo/archs/spk2imgnet/utils.py +64 -0
- spikezoo/archs/ssir/README.md +87 -0
- spikezoo/archs/ssir/configs/SSIR.yml +37 -0
- spikezoo/archs/ssir/configs/yml_parser.py +78 -0
- spikezoo/archs/ssir/datasets/dataset_sreds.py +170 -0
- spikezoo/archs/ssir/datasets/ds_utils.py +66 -0
- spikezoo/archs/ssir/losses.py +21 -0
- spikezoo/archs/ssir/main.py +326 -0
- spikezoo/archs/ssir/metrics/psnr.py +22 -0
- spikezoo/archs/ssir/metrics/ssim.py +54 -0
- spikezoo/archs/ssir/models/Vgg19.py +42 -0
- spikezoo/archs/ssir/models/__pycache__/layers.cpython-39.pyc +0 -0
- spikezoo/archs/ssir/models/__pycache__/networks.cpython-39.pyc +0 -0
- spikezoo/archs/ssir/models/layers.py +110 -0
- spikezoo/archs/ssir/models/networks.py +61 -0
- spikezoo/archs/ssir/requirements.txt +8 -0
- spikezoo/archs/ssir/shells/eval_SREDS.sh +6 -0
- spikezoo/archs/ssir/shells/train_SSIR.sh +12 -0
- spikezoo/archs/ssir/test.py +3 -0
- spikezoo/archs/ssir/utils.py +154 -0
- spikezoo/archs/ssml/__pycache__/cbam.cpython-39.pyc +0 -0
- spikezoo/archs/ssml/__pycache__/model.cpython-39.pyc +0 -0
- spikezoo/archs/ssml/cbam.py +224 -0
- spikezoo/archs/ssml/model.py +290 -0
- spikezoo/archs/ssml/res.png +0 -0
- spikezoo/archs/ssml/test.py +67 -0
- spikezoo/archs/stir/.git-credentials +0 -0
- spikezoo/archs/stir/README.md +65 -0
- spikezoo/archs/stir/ckpt_outputs/Descriptions.txt +1 -0
- spikezoo/archs/stir/configs/STIR.yml +37 -0
- spikezoo/archs/stir/configs/utils.py +155 -0
- spikezoo/archs/stir/configs/yml_parser.py +78 -0
- spikezoo/archs/stir/datasets/dataset_sreds.py +180 -0
- spikezoo/archs/stir/datasets/ds_utils.py +66 -0
- spikezoo/archs/stir/eval_SREDS.sh +5 -0
- spikezoo/archs/stir/main.py +397 -0
- spikezoo/archs/stir/metrics/losses.py +219 -0
- spikezoo/archs/stir/metrics/psnr.py +22 -0
- spikezoo/archs/stir/metrics/ssim.py +54 -0
- spikezoo/archs/stir/models/Vgg19.py +42 -0
- spikezoo/archs/stir/models/__pycache__/networks_STIR.cpython-39.pyc +0 -0
- spikezoo/archs/stir/models/__pycache__/submodules.cpython-39.pyc +0 -0
- spikezoo/archs/stir/models/__pycache__/transformer_new.cpython-39.pyc +0 -0
- spikezoo/archs/stir/models/networks_STIR.py +361 -0
- spikezoo/archs/stir/models/submodules.py +86 -0
- spikezoo/archs/stir/models/transformer_new.py +151 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/__init__.py +0 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/convertions.py +721 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/disp_netS.py +133 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/flow_utils.py +167 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/generic_train_test.py +76 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/geometry.py +458 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/image_proc.py +183 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/linalg.py +40 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/losses.py +198 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/metrics.py +51 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/model_base.py +53 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/net_basics.py +100 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/resnet.py +333 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/transforms.py +123 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/utils.py +72 -0
- spikezoo/archs/stir/package_core/dist/package_core-0.0.0-py3.9.egg +0 -0
- spikezoo/archs/stir/package_core/package_core/__init__.py +0 -0
- spikezoo/archs/stir/package_core/package_core/__pycache__/__init__.cpython-39.pyc +0 -0
- spikezoo/archs/stir/package_core/package_core/__pycache__/net_basics.cpython-39.pyc +0 -0
- spikezoo/archs/stir/package_core/package_core/convertions.py +721 -0
- spikezoo/archs/stir/package_core/package_core/disp_netS.py +133 -0
- spikezoo/archs/stir/package_core/package_core/flow_utils.py +167 -0
- spikezoo/archs/stir/package_core/package_core/generic_train_test.py +76 -0
- spikezoo/archs/stir/package_core/package_core/geometry.py +458 -0
- spikezoo/archs/stir/package_core/package_core/image_proc.py +183 -0
- spikezoo/archs/stir/package_core/package_core/linalg.py +40 -0
- spikezoo/archs/stir/package_core/package_core/losses.py +198 -0
- spikezoo/archs/stir/package_core/package_core/metrics.py +51 -0
- spikezoo/archs/stir/package_core/package_core/model_base.py +53 -0
- spikezoo/archs/stir/package_core/package_core/net_basics.py +100 -0
- spikezoo/archs/stir/package_core/package_core/resnet.py +333 -0
- spikezoo/archs/stir/package_core/package_core/transforms.py +123 -0
- spikezoo/archs/stir/package_core/package_core/utils.py +72 -0
- spikezoo/archs/stir/package_core/package_core.egg-info/PKG-INFO +3 -0
- spikezoo/archs/stir/package_core/package_core.egg-info/SOURCES.txt +20 -0
- spikezoo/archs/stir/package_core/package_core.egg-info/dependency_links.txt +1 -0
- spikezoo/archs/stir/package_core/package_core.egg-info/top_level.txt +1 -0
- spikezoo/archs/stir/package_core/setup.py +5 -0
- spikezoo/archs/stir/requirements.txt +12 -0
- spikezoo/archs/stir/train_STIR.sh +9 -0
- spikezoo/archs/tfi/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/tfi/nets.py +43 -0
- spikezoo/archs/tfp/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/tfp/nets.py +13 -0
- spikezoo/archs/wgse/README.md +64 -0
- spikezoo/archs/wgse/__pycache__/dwtnets.cpython-39.pyc +0 -0
- spikezoo/archs/wgse/__pycache__/submodules.cpython-39.pyc +0 -0
- spikezoo/archs/wgse/dataset.py +59 -0
- spikezoo/archs/wgse/demo.png +0 -0
- spikezoo/archs/wgse/demo.py +83 -0
- spikezoo/archs/wgse/dwtnets.py +145 -0
- spikezoo/archs/wgse/eval.py +133 -0
- spikezoo/archs/wgse/logs/WGSE-Dwt1dNet-db8-5-ks3/log.txt +11 -0
- spikezoo/archs/wgse/submodules.py +68 -0
- spikezoo/archs/wgse/train.py +261 -0
- spikezoo/archs/wgse/transform.py +139 -0
- spikezoo/archs/wgse/utils.py +128 -0
- spikezoo/archs/wgse/weights/demo.png +0 -0
- spikezoo/data/base/test/gt/200_part1_key_id151.png +0 -0
- spikezoo/data/base/test/gt/200_part3_key_id151.png +0 -0
- spikezoo/data/base/test/gt/203_part1_key_id151.png +0 -0
- spikezoo/data/base/test/spike/200_part1_key_id151.dat +0 -0
- spikezoo/data/base/test/spike/200_part3_key_id151.dat +0 -0
- spikezoo/data/base/test/spike/203_part1_key_id151.dat +0 -0
- spikezoo/data/base/train/gt/203_part2_key_id151.png +0 -0
- spikezoo/data/base/train/gt/203_part3_key_id151.png +0 -0
- spikezoo/data/base/train/gt/203_part4_key_id151.png +0 -0
- spikezoo/data/base/train/spike/203_part2_key_id151.dat +0 -0
- spikezoo/data/base/train/spike/203_part3_key_id151.dat +0 -0
- spikezoo/data/base/train/spike/203_part4_key_id151.dat +0 -0
- spikezoo/datasets/base_dataset.py +2 -3
- spikezoo/metrics/__init__.py +1 -1
- spikezoo/models/base_model.py +1 -3
- spikezoo/pipeline/base_pipeline.py +7 -5
- spikezoo/pipeline/train_pipeline.py +1 -1
- spikezoo/utils/other_utils.py +16 -6
- spikezoo/utils/spike_utils.py +33 -29
- spikezoo/utils/vidar_loader.cpython-39-x86_64-linux-gnu.so +0 -0
- spikezoo-0.2.1.dist-info/METADATA +167 -0
- spikezoo-0.2.1.dist-info/RECORD +211 -0
- spikezoo/models/spcsnet_model.py +0 -19
- spikezoo-0.1.2.dist-info/METADATA +0 -39
- spikezoo-0.1.2.dist-info/RECORD +0 -36
- {spikezoo-0.1.2.dist-info → spikezoo-0.2.1.dist-info}/LICENSE.txt +0 -0
- {spikezoo-0.1.2.dist-info → spikezoo-0.2.1.dist-info}/WHEEL +0 -0
- {spikezoo-0.1.2.dist-info → spikezoo-0.2.1.dist-info}/top_level.txt +0 -0
spikezoo/archs/spk2imgnet/dataset.py
@@ -0,0 +1,144 @@
+import os
+import os.path
+from typing import List
+
+import numpy as np
+import random
+import h5py
+import torch
+import cv2
+import glob
+import torch.utils.data as udata
+from functools import partial
+
+bytes2num = partial(int.from_bytes, byteorder="little", signed=False)
+
+
+def normalize(data):
+    return data / 255.0
+
+
+def raw_to_spike(video_seq, h, w):
+    """Decode a packed spike stream (8 pixels per byte) into a binary c*h*w array."""
+    video_seq = np.array(video_seq).astype(np.uint8)
+    img_size = h * w
+    img_num = len(video_seq) // (img_size // 8)
+    spike_matrix = np.zeros([img_num, h, w], np.uint8)
+    pix_id = np.arange(0, h * w)
+    pix_id = np.reshape(pix_id, (h, w))
+    comparator = np.left_shift(1, np.mod(pix_id, 8))  # bit mask for each pixel within its byte
+    byte_id = pix_id // 8
+
+    for img_id in np.arange(img_num):
+        id_start = img_id * img_size // 8
+        id_end = id_start + img_size // 8
+        cur_info = video_seq[id_start:id_end]
+        data = cur_info[byte_id]
+        result = np.bitwise_and(data, comparator)
+        spike_matrix[img_id, :, :] = np.flipud(result == comparator)
+
+    return spike_matrix
+
+
+def Im2Patch(img, win, stride=40):
+    """Slice a c*h*w array into win*win patches; returns an array of shape c*win*win*num."""
+    k = 0
+    [endc, endw, endh] = img.shape
+    patch = img[:, 0 : endw - win + 1 : stride, 0 : endh - win + 1 : stride]
+    total_pat_num = patch.shape[1] * patch.shape[2]
+    Y = np.zeros([endc, win * win, total_pat_num], np.float32)
+    for i in range(win):
+        for j in range(win):
+            patch = img[:, i : endw - win + i + 1 : stride, j : endh - win + j + 1 : stride]
+            Y[:, k, :] = np.array(patch[:]).reshape(endc, total_pat_num)
+            k = k + 1
+    return Y.reshape([endc, win, win, total_pat_num])
+
+
+def read_image_and_concat_as_tensor(paths: List[str]):
+    tensors = []
+    for path in paths:
+        img = cv2.imread(path)
+        tensors.append(img.reshape([1, *img.shape]))
+    return np.concatenate(tensors, axis=0)
+
+
+def prepare_data(data_path, patch_size, stride, h5_name, aug_times=1):
+    print("process training data")
+    input_files = glob.glob(os.path.join(data_path, "input", "*.dat"))
+    print(len(input_files))
+    input_files.sort()
+    input_h5f = h5py.File(h5_name + "_input.h5", "w")
+    gt_h5f = h5py.File(h5_name + "_gt.h5", "w")
+    train_num = 0
+    h = 250
+    w = 400
+    for i in range(len(input_files)):
+        with open(input_files[i], "rb") as input_f:
+            video_seq = input_f.read()
+        video_seq = np.frombuffer(video_seq, dtype=np.uint8)  # np.fromstring is deprecated
+        spike_array = raw_to_spike(video_seq, h, w)  # c*h*w
+        print(spike_array.shape)
+        file_name = input_files[i].replace("\\", "/").split("/")[-1]
+        gt = []
+        for num in [7, 14, 21, 28, 35]:  # ground-truth frames sampled along the stream
+            img = cv2.imread(os.path.join(data_path, "gt", file_name[:-6] + str(num) + ".png"), 0)
+            gt.append(img.reshape([1, *img.shape]))
+        gt = np.concatenate(gt, axis=0)
+        print(input_files[i])
+        print(os.path.join(data_path, "gt", file_name[:-3] + "png"))
+        gt = np.float32(normalize(gt))
+        print(gt.shape)
+        input_patches = Im2Patch(spike_array, win=patch_size, stride=stride)
+        gt_patches = Im2Patch(gt, win=patch_size, stride=stride)
+        assert input_patches.shape[3] == gt_patches.shape[3]
+        for n in range(input_patches.shape[3]):
+            inputs = input_patches[:, :, :, n].copy()
+            input_h5f.create_dataset(str(train_num), data=inputs)
+            gt_patch = gt_patches[:, :, :, n].copy()
+            gt_h5f.create_dataset(str(train_num), data=gt_patch)
+            train_num += 1
+
+    input_h5f.close()
+    gt_h5f.close()
+
+
+class Dataset(udata.Dataset):
+    def __init__(self, h5_name):
+        super(Dataset, self).__init__()
+        input_h5f = h5py.File(h5_name + "_input.h5", "r")
+        gt_h5f = h5py.File(h5_name + "_gt.h5", "r")
+        self.h5_name = h5_name
+        self.keys = list(input_h5f.keys())
+        random.shuffle(self.keys)
+        input_h5f.close()
+        gt_h5f.close()
+
+    def __len__(self):
+        return len(self.keys)
+
+    def __getitem__(self, index):
+        input_h5f = h5py.File(self.h5_name + "_input.h5", "r")
+        gt_h5f = h5py.File(self.h5_name + "_gt.h5", "r")
+        key = self.keys[index]
+        inputs = np.array(input_h5f[key])
+        gt = np.array(gt_h5f[key])
+        input_h5f.close()
+        gt_h5f.close()
+        return torch.Tensor(inputs), torch.Tensor(gt)
+
+
+if __name__ == "__main__":
+    prepare_data(
+        data_path="./Spk2ImgNet_train/train2/",
+        patch_size=40,
+        stride=40,
+        h5_name="train",
+    )
+    # prepare_data(data_path="./SpikeDataset/val/", patch_size=40, stride=40, h5_name="val")
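Taken together, `prepare_data` writes paired `train_input.h5` / `train_gt.h5` patch files and `Dataset` serves them one key at a time. A minimal usage sketch (it assumes the directory layout hard-coded above; the batch size and worker count are illustrative, not part of the package):

```python
# Sketch: build the H5 patch files once, then iterate them with a
# standard PyTorch DataLoader. batch_size/num_workers are illustrative.
from torch.utils.data import DataLoader
from dataset import Dataset, prepare_data

prepare_data(data_path="./Spk2ImgNet_train/train2/", patch_size=40, stride=40, h5_name="train")

loader = DataLoader(Dataset("train"), batch_size=16, shuffle=True, num_workers=4)
for spikes, gt in loader:
    # spikes: (16, C, 40, 40) binary spike patches (float32);
    # gt:     (16, 5, 40, 40) ground-truth frames at stream positions
    #         7/14/21/28/35, scaled to [0, 1]
    break
```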
spikezoo/archs/spk2imgnet/nets.py
@@ -0,0 +1,230 @@
+import os
+import sys
+
+import numpy as np
+import torch
+import torch.nn as nn
+
+current_dir = os.path.dirname(os.path.abspath(__file__))
+sys.path.append(current_dir)
+
+from align_arch import *
+
+
+class BasicBlock(nn.Module):
+    """Residual block of three 3x3 convolutions."""
+
+    def __init__(self, features):
+        super().__init__()
+        self.conv1 = nn.Conv2d(in_channels=features, out_channels=features, kernel_size=3, padding=1, bias=True)
+        self.relu1 = nn.ReLU()
+        self.conv2 = nn.Conv2d(in_channels=features, out_channels=features, kernel_size=3, padding=1, bias=True)
+        self.relu2 = nn.ReLU()
+        self.conv3 = nn.Conv2d(in_channels=features, out_channels=features, kernel_size=3, padding=1, bias=True)
+        self.relu3 = nn.ReLU()
+
+    def forward(self, x):
+        out = self.conv1(x)
+        out = self.relu1(out)
+        out = self.conv2(out)
+        out = self.relu2(out)
+        out = self.conv3(out)
+        return self.relu3(x + out)
+
+
+class CALayer2(nn.Module):
+    """Channel-attention weights produced through a Sigmoid."""
+
+    def __init__(self, in_channels):
+        super(CALayer2, self).__init__()
+        self.ca_block = nn.Sequential(
+            nn.Conv2d(in_channels, in_channels * 2, 3, padding=1, bias=True),
+            nn.ReLU(),
+            nn.Conv2d(in_channels * 2, in_channels, 3, padding=1, bias=True),
+            nn.Sigmoid(),
+        )
+
+    def forward(self, x):
+        weight = self.ca_block(x)
+        return weight
+
+
+class FeatureExtractor(nn.Module):
+    """Extracts features from nested temporal windows of the spike stream,
+    fused by the channel attention above."""
+
+    def __init__(self, in_channels, features, out_channels, channel_step, num_of_layers=16):
+        super(FeatureExtractor, self).__init__()
+        self.channel_step = channel_step
+        # four parallel branches over progressively narrower temporal windows
+        self.conv0_0 = nn.Conv2d(in_channels=in_channels, out_channels=16, kernel_size=3, padding=1)
+        self.conv0_1 = nn.Conv2d(in_channels=in_channels - 2 * channel_step, out_channels=16, kernel_size=3, padding=1)
+        self.conv0_2 = nn.Conv2d(in_channels=in_channels - 4 * channel_step, out_channels=16, kernel_size=3, padding=1)
+        self.conv0_3 = nn.Conv2d(in_channels=in_channels - 6 * channel_step, out_channels=16, kernel_size=3, padding=1)
+        self.conv1_0 = nn.Conv2d(in_channels=16, out_channels=1, kernel_size=3, padding=1)
+        self.conv1_1 = nn.Conv2d(in_channels=16, out_channels=1, kernel_size=3, padding=1)
+        self.conv1_2 = nn.Conv2d(in_channels=16, out_channels=1, kernel_size=3, padding=1)
+        self.conv1_3 = nn.Conv2d(in_channels=16, out_channels=1, kernel_size=3, padding=1)
+        self.ca = CALayer2(in_channels=4)
+        self.conv = nn.Conv2d(in_channels=4, out_channels=features, kernel_size=3, padding=1)
+        self.relu = nn.ReLU()
+        layers = []
+        for _ in range(num_of_layers - 2):
+            layers.append(BasicBlock(features=features))
+        self.net = nn.Sequential(*layers)
+
+    def forward(self, x):
+        out_0 = self.conv1_0(self.relu(self.conv0_0(x)))
+        out_1 = self.conv1_1(self.relu(self.conv0_1(x[:, self.channel_step : -self.channel_step, :, :])))
+        out_2 = self.conv1_2(self.relu(self.conv0_2(x[:, 2 * self.channel_step : -2 * self.channel_step, :, :])))
+        out_3 = self.conv1_3(self.relu(self.conv0_3(x[:, 3 * self.channel_step : -3 * self.channel_step, :, :])))
+        out = torch.cat((out_0, out_1, out_2, out_3), 1)
+        est = out  # the four per-window light-intensity estimates
+        weight = self.ca(out)
+        out = weight * out
+        out = self.relu(self.conv(out))
+        tmp = out
+        out = self.net(out)
+        return out + tmp, est
+
+
+class FusionMaskV1(nn.Module):
+    """Predicts a fusion mask between an aligned reference feature and the key feature."""
+
+    def __init__(self, features):
+        super(FusionMaskV1, self).__init__()
+        self.conv0 = nn.Conv2d(in_channels=2 * features, out_channels=features, kernel_size=3, padding=1)
+        self.conv1 = nn.Conv2d(in_channels=features, out_channels=features, kernel_size=3, padding=1)
+        self.conv2 = nn.Conv2d(in_channels=features, out_channels=features, kernel_size=3, padding=1)
+        self.prelu0 = nn.PReLU()
+        self.prelu1 = nn.PReLU()
+        self.sig = nn.Sigmoid()
+
+    def forward(self, ref, key):
+        fea = torch.cat((ref, key), 1)
+        fea = self.conv2(self.prelu1(self.conv1(self.prelu0(self.conv0(fea)))))
+        mask = self.sig(fea)
+        return mask
+
+
+# current best model
+class SpikeNet(nn.Module):
+    def __init__(self, in_channels, features, out_channels, win_r, win_step):
+        super(SpikeNet, self).__init__()
+        self.extractor = FeatureExtractor(
+            in_channels=in_channels,
+            features=features,
+            out_channels=features,
+            channel_step=1,
+            num_of_layers=12,
+        )
+        self.mask0 = FusionMaskV1(features=features)
+        self.mask1 = FusionMaskV1(features=features)
+        self.mask3 = FusionMaskV1(features=features)
+        self.mask4 = FusionMaskV1(features=features)
+        self.rec_conv0 = nn.Conv2d(in_channels=5 * features, out_channels=3 * features, kernel_size=3, padding=1)
+        self.rec_conv1 = nn.Conv2d(in_channels=3 * features, out_channels=features, kernel_size=3, padding=1)
+        self.rec_conv2 = nn.Conv2d(in_channels=features, out_channels=1, kernel_size=3, padding=1)
+        self.rec_relu = nn.ReLU()
+        self.pcd_align = Easy_PCD(nf=features, groups=8)  # deformable alignment from align_arch
+        self.win_r = win_r
+        self.win_step = win_step
+
+    def forward(self, x):
+        # five overlapping temporal blocks of 2*win_r+1 spike frames each
+        block0 = x[:, 0 : 2 * self.win_r + 1, :, :]
+        block1 = x[:, self.win_step : self.win_step + 2 * self.win_r + 1, :, :]
+        block2 = x[:, 2 * self.win_step : 2 * self.win_step + 2 * self.win_r + 1, :, :]
+        block3 = x[:, 3 * self.win_step : 3 * self.win_step + 2 * self.win_r + 1, :, :]
+        block4 = x[:, 4 * self.win_step : 4 * self.win_step + 2 * self.win_r + 1, :, :]
+        block0_out, est0 = self.extractor(block0)
+        block1_out, est1 = self.extractor(block1)
+        block2_out, est2 = self.extractor(block2)
+        block3_out, est3 = self.extractor(block3)
+        block4_out, est4 = self.extractor(block4)
+        # align every block's features to the central (key) block
+        aligned_block0_out = self.pcd_align(block0_out, block2_out)
+        aligned_block1_out = self.pcd_align(block1_out, block2_out)
+        aligned_block3_out = self.pcd_align(block3_out, block2_out)
+        aligned_block4_out = self.pcd_align(block4_out, block2_out)
+        mask0 = self.mask0(aligned_block0_out, block2_out)
+        mask1 = self.mask1(aligned_block1_out, block2_out)
+        mask3 = self.mask3(aligned_block3_out, block2_out)
+        mask4 = self.mask4(aligned_block4_out, block2_out)
+        out = torch.cat(
+            (
+                aligned_block0_out * mask0,
+                aligned_block1_out * mask1,
+                block2_out,
+                aligned_block3_out * mask3,
+                aligned_block4_out * mask4,
+            ),
+            1,
+        )
+        out = self.rec_relu(self.rec_conv0(out))
+        out = self.rec_relu(self.rec_conv1(out))
+        out = self.rec_conv2(out)
+        # return the intermediate estimates alongside the reconstruction;
+        # test_gen_imgseq.py below unpacks all six outputs
+        return out, est0, est1, est2, est3, est4
+
+
+if __name__ == "__main__":
+    print("out")
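The window arithmetic in `SpikeNet.forward` is easy to get wrong: with `win_r=6` and `win_step=7` (the settings used by `test_gen_imgseq.py` below), each of the five blocks spans `2*win_r+1 = 13` frames, so the input must carry `4*win_step + 13 = 41` spike frames. A quick shape-check sketch (it assumes `align_arch.py` from this package is importable so `Easy_PCD` resolves, and that alignment preserves spatial size):

```python
# Sketch: verify SpikeNet's expected input/output shapes on a dummy spike
# stream. Assumes align_arch.py is on the path (provides Easy_PCD).
import torch
from nets import SpikeNet

net = SpikeNet(in_channels=13, features=64, out_channels=1, win_r=6, win_step=7)
x = torch.randn(1, 41, 64, 64)  # 4*win_step + 2*win_r + 1 = 41 frames; H, W divisible by 4
with torch.no_grad():
    out, est0, est1, est2, est3, est4 = net(x)
print(out.shape)   # expected: torch.Size([1, 1, 64, 64])
print(est0.shape)  # expected: torch.Size([1, 4, 64, 64]), one estimate per nested window
```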
spikezoo/archs/spk2imgnet/readme.md
@@ -0,0 +1,86 @@
+## [CVPR 2021] Spk2ImgNet: Learning to Reconstruct Dynamic Scene from Continuous Spike Stream
+
+<h4 align="center"> Jing Zhao, Ruiqin Xiong, Hangfan Liu, Jian Zhang, Tiejun Huang </h4>
+
+This repository contains the official source code for our paper:
+
+Spk2ImgNet: Learning to Reconstruct Dynamic Scene from Continuous Spike Stream. CVPR 2021
+
+Paper:
+[Spk2ImgNet-CVPR2021](https://openaccess.thecvf.com/content/CVPR2021/papers/Zhao_Spk2ImgNet_Learning_To_Reconstruct_Dynamic_Scene_From_Continuous_Spike_Stream_CVPR_2021_paper.pdf)
+
+* [Environments](#Environments)
+* [Download the pretrained models](#Download-the-pretrained-models)
+* [Evaluate](#Evaluate)
+* [Train](#Train)
+* [Citations](#Citations)
+
+## Environments
+
+Choose the cudatoolkit version that matches your compute environment. The code is tested with PyTorch 1.10.2+cu113 and spatial-correlation-sampler 0.3.0, but other versions may also work.
+
+```bash
+conda create -n spk2imgnet python=3.9
+conda activate spk2imgnet
+conda install pytorch torchvision cudatoolkit=11.3 -c pytorch
+pip3 install matplotlib opencv-python h5py
+```
+
+We do not guarantee that every PyTorch version works.
+
+## Prepare the Data
+
+### Download the pretrained models
+
+The pretrained model can be downloaded from the Google Drive link below:
+
+[Link for pretrained model](https://drive.google.com/file/d/1vBTJxlctk4otQKsyRq7lsFYGU4WGRNjt/view?usp=sharing)
+
+Place the downloaded models in `./ckpt`.
+
+### Download the training data
+
+The training data can be downloaded from the Google Drive link below:
+
+[Link for training data](https://drive.google.com/file/d/1ozR2-fNmU10gA_TCYUfJN-ahV6e_8Ke7/view?usp=sharing)
+
+## Evaluate
+
+You can set the data path in the .py files or through the argparser (`--test_data`):
+
+```bash
+python3 test_gen_imgseq.py \
+    --test_data './Spk2ImgNet_test2/test2/' \
+    --model_name 'model_061.pth'
+```
+
+## Train
+
+All the command-line arguments for hyperparameter tuning can be found in `train.py`. You can set the data path in the .py files or through the argparser (`--data`).
+
+```bash
+python3 train.py
+```
+
+## Citations
+
+If you find this code useful in your research, please consider citing our paper:
+
+```
+@inproceedings{zhao2021spike,
+  title={Spk2ImgNet: Learning to Reconstruct Dynamic Scene from Continuous Spike Stream},
+  author={Zhao, Jing and Xiong, Ruiqin and Liu, Hangfan and Zhang, Jian and Huang, Tiejun},
+  booktitle={2021 IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
+  year={2021}
+}
+```
spikezoo/archs/spk2imgnet/test_gen_imgseq.py
@@ -0,0 +1,118 @@
+import argparse
+import time
+
+from skimage.metrics import peak_signal_noise_ratio, structural_similarity
+from torch.autograd import Variable
+
+from dataset import *
+from nets import *
+
+os.environ["CUDA_VISIBLE_DEVICES"] = "0"
+
+parser = argparse.ArgumentParser(description="Spike_Net_Test")
+parser.add_argument("--num_of_layers", type=int, default=17, help="Number of total layers")
+parser.add_argument("--logdir", type=str, default="./ckpt2/", help="path of log files")
+parser.add_argument("--test_data", type=str, default="./Spk2ImgNet_test2/test2/", help="test set")
+parser.add_argument("--save_result", type=bool, default=True, help="save the reconstruction or not")
+parser.add_argument("--result_dir", type=str, default="results/", help="path of results")
+parser.add_argument("--exist_gt", type=bool, default=True, help="exist ground truth or not")
+parser.add_argument("--model_name", type=str, default="model_041.pth", help="Name of ckp")
+opt = parser.parse_args()
+
+
+def normalize(data):
+    return data / 255.0
+
+
+def main():
+    # Build model
+    print("Loading model ... \n")
+    net = SpikeNet(in_channels=13, features=64, out_channels=1, win_r=6, win_step=7)
+    model = nn.DataParallel(net).cuda()
+    model.load_state_dict(torch.load(os.path.join(opt.logdir, opt.model_name)))
+    model.eval()
+
+    # load data info
+    print("Loading data info ...\n")
+    files_source = glob.glob(os.path.join(opt.test_data, "input", "*.dat"))
+    files_source.sort()
+
+    # process data
+    psnr_test = 0
+    ssim_test = 0
+    for i in range(len(files_source)):
+        sub_dir = files_source[i][:-4]
+        # input spike stream
+        with open(files_source[i], "rb") as input_f:
+            video_seq = input_f.read()
+        video_seq = np.frombuffer(video_seq, dtype=np.uint8)  # np.fromstring is deprecated
+        InSpikeArray = raw_to_spike(video_seq, 250, 400)  # c*h*w
+        [c, h, w] = InSpikeArray.shape
+        for key_id in np.arange(151, 152, 1):
+            start_t = time.time()
+            # 41-frame window centered on the key frame
+            SpikeArray = InSpikeArray[key_id - 21 : key_id + 20, :, :]
+            # pad the height so the shape is divisible by 4
+            SpikeArray = np.pad(SpikeArray, ((0, 0), (0, 2), (0, 0)), "symmetric")  # c*252*400
+            SpikeArray = np.expand_dims(SpikeArray, 0)  # n*c*h*w
+            file_name = files_source[i].replace("\\", "/").split("/")[-1]
+
+            SpikeArray = Variable(torch.Tensor(SpikeArray)).cuda()
+            with torch.no_grad():
+                out_rec, est0, est1, est2, est3, est4 = model(SpikeArray)
+                if opt.exist_gt:
+                    # 0.6 is the conversion rate used in the spike camera;
+                    # only necessary for our synthesized data.
+                    out_rec = torch.clamp(out_rec / 0.6, 0, 1).cpu() * 255
+                else:
+                    out_rec = torch.clamp(out_rec, 0, 1).cpu() ** (1 / 2.2) * 255
+            out_rec = out_rec.detach().numpy().astype(np.float32)
+            out_rec = np.squeeze(out_rec).astype(np.uint8)
+            # crop back to the original 250*400 shape
+            out_rec = out_rec[:250, :]
+            if opt.exist_gt:
+                gt = cv2.imread(os.path.join(opt.test_data, "gt", file_name[:-3] + "png"), 0)
+                psnr = peak_signal_noise_ratio(gt, out_rec)
+                ssim = structural_similarity(gt, out_rec)
+                print("%10s: PSNR:%.2f SSIM:%.4f" % (file_name, psnr, ssim))
+                psnr_test += psnr
+                ssim_test += ssim
+            if opt.save_result:
+                if not os.path.exists(os.path.join(opt.result_dir, sub_dir)):
+                    os.makedirs(os.path.join(opt.result_dir, sub_dir))
+                cv2.imwrite(os.path.join(opt.result_dir, sub_dir, str(key_id) + ".png"), out_rec)
+            dur_time = time.time() - start_t
+            print("dur_time:%.2f" % dur_time)
+
+    if opt.exist_gt:
+        avg_psnr = psnr_test / len(files_source)
+        avg_ssim = ssim_test / len(files_source)
+        print("average PSNR: %.2f average SSIM: %.4f" % (avg_psnr, avg_ssim))
+
+
+if __name__ == "__main__":
+    main()
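As an aside, the per-pixel bit arithmetic in `raw_to_spike` (defined in `dataset.py` above and reused by this script) can be cross-checked against NumPy's vectorized `unpackbits`: each frame is a little-endian bitmap of `h*w` pixels, stored row-major and vertically flipped. A sketch, assuming NumPy >= 1.17 for the `bitorder` argument:

```python
# Sketch: cross-check raw_to_spike's mask-and-compare decoding against
# np.unpackbits on one synthetic frame (assumes NumPy >= 1.17).
import numpy as np

h, w = 250, 400
frame_bytes = np.random.randint(0, 256, size=h * w // 8, dtype=np.uint8)

# Reference: the per-pixel mask-and-compare used in dataset.py
pix_id = np.arange(h * w).reshape(h, w)
comparator = np.left_shift(1, np.mod(pix_id, 8))
manual = np.flipud(np.bitwise_and(frame_bytes[pix_id // 8], comparator) == comparator)

# Vectorized equivalent
fast = np.flipud(np.unpackbits(frame_bytes, bitorder="little").reshape(h, w).astype(bool))
assert np.array_equal(manual, fast)
```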