spikezoo 0.1.2__py3-none-any.whl → 0.2__py3-none-any.whl
- spikezoo/__init__.py +13 -0
- spikezoo/archs/__pycache__/__init__.cpython-39.pyc +0 -0
- spikezoo/archs/base/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/base/nets.py +34 -0
- spikezoo/archs/bsf/README.md +92 -0
- spikezoo/archs/bsf/datasets/datasets.py +328 -0
- spikezoo/archs/bsf/datasets/ds_utils.py +64 -0
- spikezoo/archs/bsf/main.py +398 -0
- spikezoo/archs/bsf/metrics/psnr.py +22 -0
- spikezoo/archs/bsf/metrics/ssim.py +54 -0
- spikezoo/archs/bsf/models/bsf/__pycache__/align.cpython-39.pyc +0 -0
- spikezoo/archs/bsf/models/bsf/__pycache__/bsf.cpython-39.pyc +0 -0
- spikezoo/archs/bsf/models/bsf/__pycache__/rep.cpython-39.pyc +0 -0
- spikezoo/archs/bsf/models/bsf/align.py +154 -0
- spikezoo/archs/bsf/models/bsf/bsf.py +105 -0
- spikezoo/archs/bsf/models/bsf/dsft_convert.py +96 -0
- spikezoo/archs/bsf/models/bsf/rep.py +44 -0
- spikezoo/archs/bsf/models/get_model.py +7 -0
- spikezoo/archs/bsf/prepare_data/DSFT.py +62 -0
- spikezoo/archs/bsf/prepare_data/crop_dataset_train.py +135 -0
- spikezoo/archs/bsf/prepare_data/crop_dataset_val.py +139 -0
- spikezoo/archs/bsf/prepare_data/crop_train.sh +4 -0
- spikezoo/archs/bsf/prepare_data/crop_val.sh +4 -0
- spikezoo/archs/bsf/prepare_data/io_utils.py +64 -0
- spikezoo/archs/bsf/requirements.txt +9 -0
- spikezoo/archs/bsf/test.py +16 -0
- spikezoo/archs/bsf/utils.py +154 -0
- spikezoo/archs/spikeclip/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/spikeclip/nets.py +40 -0
- spikezoo/archs/spikeformer/CheckPoints/readme +1 -0
- spikezoo/archs/spikeformer/DataProcess/DataExtactor.py +60 -0
- spikezoo/archs/spikeformer/DataProcess/DataLoader.py +115 -0
- spikezoo/archs/spikeformer/DataProcess/LoadSpike.py +39 -0
- spikezoo/archs/spikeformer/EvalResults/readme +1 -0
- spikezoo/archs/spikeformer/LICENSE +21 -0
- spikezoo/archs/spikeformer/Metrics/Metrics.py +50 -0
- spikezoo/archs/spikeformer/Metrics/__init__.py +0 -0
- spikezoo/archs/spikeformer/Model/Loss.py +89 -0
- spikezoo/archs/spikeformer/Model/SpikeFormer.py +230 -0
- spikezoo/archs/spikeformer/Model/__init__.py +0 -0
- spikezoo/archs/spikeformer/Model/__pycache__/SpikeFormer.cpython-39.pyc +0 -0
- spikezoo/archs/spikeformer/Model/__pycache__/__init__.cpython-39.pyc +0 -0
- spikezoo/archs/spikeformer/README.md +30 -0
- spikezoo/archs/spikeformer/evaluate.py +87 -0
- spikezoo/archs/spikeformer/recon_real_data.py +97 -0
- spikezoo/archs/spikeformer/requirements.yml +95 -0
- spikezoo/archs/spikeformer/train.py +173 -0
- spikezoo/archs/spikeformer/utils.py +22 -0
- spikezoo/archs/spk2imgnet/.github/workflows/pylint.yml +23 -0
- spikezoo/archs/spk2imgnet/.gitignore +150 -0
- spikezoo/archs/spk2imgnet/DCNv2.py +135 -0
- spikezoo/archs/spk2imgnet/__pycache__/DCNv2.cpython-39.pyc +0 -0
- spikezoo/archs/spk2imgnet/__pycache__/align_arch.cpython-39.pyc +0 -0
- spikezoo/archs/spk2imgnet/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/spk2imgnet/align_arch.py +159 -0
- spikezoo/archs/spk2imgnet/dataset.py +144 -0
- spikezoo/archs/spk2imgnet/nets.py +230 -0
- spikezoo/archs/spk2imgnet/readme.md +86 -0
- spikezoo/archs/spk2imgnet/test_gen_imgseq.py +118 -0
- spikezoo/archs/spk2imgnet/train.py +189 -0
- spikezoo/archs/spk2imgnet/utils.py +64 -0
- spikezoo/archs/ssir/README.md +87 -0
- spikezoo/archs/ssir/configs/SSIR.yml +37 -0
- spikezoo/archs/ssir/configs/yml_parser.py +78 -0
- spikezoo/archs/ssir/datasets/dataset_sreds.py +170 -0
- spikezoo/archs/ssir/datasets/ds_utils.py +66 -0
- spikezoo/archs/ssir/losses.py +21 -0
- spikezoo/archs/ssir/main.py +326 -0
- spikezoo/archs/ssir/metrics/psnr.py +22 -0
- spikezoo/archs/ssir/metrics/ssim.py +54 -0
- spikezoo/archs/ssir/models/Vgg19.py +42 -0
- spikezoo/archs/ssir/models/__pycache__/layers.cpython-39.pyc +0 -0
- spikezoo/archs/ssir/models/__pycache__/networks.cpython-39.pyc +0 -0
- spikezoo/archs/ssir/models/layers.py +110 -0
- spikezoo/archs/ssir/models/networks.py +61 -0
- spikezoo/archs/ssir/requirements.txt +8 -0
- spikezoo/archs/ssir/shells/eval_SREDS.sh +6 -0
- spikezoo/archs/ssir/shells/train_SSIR.sh +12 -0
- spikezoo/archs/ssir/test.py +3 -0
- spikezoo/archs/ssir/utils.py +154 -0
- spikezoo/archs/ssml/__pycache__/cbam.cpython-39.pyc +0 -0
- spikezoo/archs/ssml/__pycache__/model.cpython-39.pyc +0 -0
- spikezoo/archs/ssml/cbam.py +224 -0
- spikezoo/archs/ssml/model.py +290 -0
- spikezoo/archs/ssml/res.png +0 -0
- spikezoo/archs/ssml/test.py +67 -0
- spikezoo/archs/stir/.git-credentials +0 -0
- spikezoo/archs/stir/README.md +65 -0
- spikezoo/archs/stir/ckpt_outputs/Descriptions.txt +1 -0
- spikezoo/archs/stir/configs/STIR.yml +37 -0
- spikezoo/archs/stir/configs/utils.py +155 -0
- spikezoo/archs/stir/configs/yml_parser.py +78 -0
- spikezoo/archs/stir/datasets/dataset_sreds.py +180 -0
- spikezoo/archs/stir/datasets/ds_utils.py +66 -0
- spikezoo/archs/stir/eval_SREDS.sh +5 -0
- spikezoo/archs/stir/main.py +397 -0
- spikezoo/archs/stir/metrics/losses.py +219 -0
- spikezoo/archs/stir/metrics/psnr.py +22 -0
- spikezoo/archs/stir/metrics/ssim.py +54 -0
- spikezoo/archs/stir/models/Vgg19.py +42 -0
- spikezoo/archs/stir/models/__pycache__/networks_STIR.cpython-39.pyc +0 -0
- spikezoo/archs/stir/models/__pycache__/submodules.cpython-39.pyc +0 -0
- spikezoo/archs/stir/models/__pycache__/transformer_new.cpython-39.pyc +0 -0
- spikezoo/archs/stir/models/networks_STIR.py +361 -0
- spikezoo/archs/stir/models/submodules.py +86 -0
- spikezoo/archs/stir/models/transformer_new.py +151 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/__init__.py +0 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/convertions.py +721 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/disp_netS.py +133 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/flow_utils.py +167 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/generic_train_test.py +76 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/geometry.py +458 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/image_proc.py +183 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/linalg.py +40 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/losses.py +198 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/metrics.py +51 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/model_base.py +53 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/net_basics.py +100 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/resnet.py +333 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/transforms.py +123 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/utils.py +72 -0
- spikezoo/archs/stir/package_core/dist/package_core-0.0.0-py3.9.egg +0 -0
- spikezoo/archs/stir/package_core/package_core/__init__.py +0 -0
- spikezoo/archs/stir/package_core/package_core/__pycache__/__init__.cpython-39.pyc +0 -0
- spikezoo/archs/stir/package_core/package_core/__pycache__/net_basics.cpython-39.pyc +0 -0
- spikezoo/archs/stir/package_core/package_core/convertions.py +721 -0
- spikezoo/archs/stir/package_core/package_core/disp_netS.py +133 -0
- spikezoo/archs/stir/package_core/package_core/flow_utils.py +167 -0
- spikezoo/archs/stir/package_core/package_core/generic_train_test.py +76 -0
- spikezoo/archs/stir/package_core/package_core/geometry.py +458 -0
- spikezoo/archs/stir/package_core/package_core/image_proc.py +183 -0
- spikezoo/archs/stir/package_core/package_core/linalg.py +40 -0
- spikezoo/archs/stir/package_core/package_core/losses.py +198 -0
- spikezoo/archs/stir/package_core/package_core/metrics.py +51 -0
- spikezoo/archs/stir/package_core/package_core/model_base.py +53 -0
- spikezoo/archs/stir/package_core/package_core/net_basics.py +100 -0
- spikezoo/archs/stir/package_core/package_core/resnet.py +333 -0
- spikezoo/archs/stir/package_core/package_core/transforms.py +123 -0
- spikezoo/archs/stir/package_core/package_core/utils.py +72 -0
- spikezoo/archs/stir/package_core/package_core.egg-info/PKG-INFO +3 -0
- spikezoo/archs/stir/package_core/package_core.egg-info/SOURCES.txt +20 -0
- spikezoo/archs/stir/package_core/package_core.egg-info/dependency_links.txt +1 -0
- spikezoo/archs/stir/package_core/package_core.egg-info/top_level.txt +1 -0
- spikezoo/archs/stir/package_core/setup.py +5 -0
- spikezoo/archs/stir/requirements.txt +12 -0
- spikezoo/archs/stir/train_STIR.sh +9 -0
- spikezoo/archs/tfi/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/tfi/nets.py +43 -0
- spikezoo/archs/tfp/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/tfp/nets.py +13 -0
- spikezoo/archs/wgse/README.md +64 -0
- spikezoo/archs/wgse/__pycache__/dwtnets.cpython-39.pyc +0 -0
- spikezoo/archs/wgse/__pycache__/submodules.cpython-39.pyc +0 -0
- spikezoo/archs/wgse/dataset.py +59 -0
- spikezoo/archs/wgse/demo.png +0 -0
- spikezoo/archs/wgse/demo.py +83 -0
- spikezoo/archs/wgse/dwtnets.py +145 -0
- spikezoo/archs/wgse/eval.py +133 -0
- spikezoo/archs/wgse/logs/WGSE-Dwt1dNet-db8-5-ks3/log.txt +11 -0
- spikezoo/archs/wgse/submodules.py +68 -0
- spikezoo/archs/wgse/train.py +261 -0
- spikezoo/archs/wgse/transform.py +139 -0
- spikezoo/archs/wgse/utils.py +128 -0
- spikezoo/archs/wgse/weights/demo.png +0 -0
- spikezoo/data/base/test/gt/200_part1_key_id151.png +0 -0
- spikezoo/data/base/test/gt/200_part3_key_id151.png +0 -0
- spikezoo/data/base/test/gt/203_part1_key_id151.png +0 -0
- spikezoo/data/base/test/spike/200_part1_key_id151.dat +0 -0
- spikezoo/data/base/test/spike/200_part3_key_id151.dat +0 -0
- spikezoo/data/base/test/spike/203_part1_key_id151.dat +0 -0
- spikezoo/data/base/train/gt/203_part2_key_id151.png +0 -0
- spikezoo/data/base/train/gt/203_part3_key_id151.png +0 -0
- spikezoo/data/base/train/gt/203_part4_key_id151.png +0 -0
- spikezoo/data/base/train/spike/203_part2_key_id151.dat +0 -0
- spikezoo/data/base/train/spike/203_part3_key_id151.dat +0 -0
- spikezoo/data/base/train/spike/203_part4_key_id151.dat +0 -0
- spikezoo/datasets/base_dataset.py +2 -3
- spikezoo/metrics/__init__.py +1 -1
- spikezoo/models/base_model.py +1 -3
- spikezoo/pipeline/base_pipeline.py +7 -5
- spikezoo/pipeline/train_pipeline.py +1 -1
- spikezoo/utils/other_utils.py +16 -6
- spikezoo/utils/spike_utils.py +33 -29
- spikezoo/utils/vidar_loader.cpython-39-x86_64-linux-gnu.so +0 -0
- spikezoo-0.2.dist-info/METADATA +163 -0
- spikezoo-0.2.dist-info/RECORD +211 -0
- spikezoo/models/spcsnet_model.py +0 -19
- spikezoo-0.1.2.dist-info/METADATA +0 -39
- spikezoo-0.1.2.dist-info/RECORD +0 -36
- {spikezoo-0.1.2.dist-info → spikezoo-0.2.dist-info}/LICENSE.txt +0 -0
- {spikezoo-0.1.2.dist-info → spikezoo-0.2.dist-info}/WHEEL +0 -0
- {spikezoo-0.1.2.dist-info → spikezoo-0.2.dist-info}/top_level.txt +0 -0
spikezoo/archs/spikeformer/Model/SpikeFormer.py
@@ -0,0 +1,230 @@
+'''
+Ref: To implement SpikeFormer, we referred to the code of "segformer-pytorch" published on GitHub
+(link: https://github.com/lucidrains/segformer-pytorch.git)
+'''
+from math import sqrt
+from functools import partial
+import torch
+from torch import nn, einsum
+from einops import rearrange, reduce
+
+# helpers
+
+def exists(val):
+    return val is not None
+
+def cast_tuple(val, depth):
+    return val if isinstance(val, tuple) else (val,) * depth
+
+LayerNorm = partial(nn.InstanceNorm2d, affine = True)
+
+# classes
+
+class DsConv2d(nn.Module):
+    def __init__(self, dim_in, dim_out, kernel_size, padding, stride = 1, bias = True):
+        super().__init__()
+        self.net = nn.Sequential(
+            nn.Conv2d(dim_in, dim_in, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias),
+            nn.GELU(),
+            nn.Conv2d(dim_in, dim_out, kernel_size = 1, bias = bias),
+            nn.GELU(),
+        )
+    def forward(self, x):
+        return self.net(x)
+
+class PreNorm(nn.Module):
+    def __init__(self, dim, fn):
+        super().__init__()
+        self.fn = fn
+        self.norm = LayerNorm(dim)
+
+    def forward(self, x):
+        # return self.fn(x)
+        return self.fn(self.norm(x))
+
+class EfficientSelfAttention(nn.Module):
+    def __init__(
+        self,
+        *,
+        dim,
+        heads,
+        reduction_ratio
+    ):
+        super().__init__()
+        self.scale = (dim // heads) ** -0.5
+        self.heads = heads
+
+        self.to_q = nn.Conv2d(dim, dim, 1, bias = False)
+        self.to_kv = nn.Conv2d(dim, dim * 2, reduction_ratio, stride = reduction_ratio, bias = False)
+        self.to_out = nn.Sequential(
+            nn.Conv2d(dim, dim, 1, bias = False),
+            nn.GELU()
+        )
+
+    def forward(self, x):
+        h, w = x.shape[-2:]
+        heads = self.heads
+
+        q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = 1))
+        q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) (x y) c', h = heads), (q, k, v))
+
+        sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+        attn = sim.softmax(dim = -1)
+
+        out = einsum('b i j, b j d -> b i d', attn, v)
+        out = rearrange(out, '(b h) (x y) c -> b (h c) x y', h = heads, x = h, y = w)
+        return self.to_out(out)
+
+class MixFeedForward(nn.Module):
+    def __init__(
+        self,
+        *,
+        dim,
+        expansion_factor
+    ):
+        super().__init__()
+        hidden_dim = dim * expansion_factor
+        self.net = nn.Sequential(
+            nn.Conv2d(dim, hidden_dim, 1),
+            nn.GELU(),
+            DsConv2d(hidden_dim, hidden_dim, 3, padding = 1),
+            # nn.GELU(),
+            nn.Conv2d(hidden_dim, dim, 1),
+            nn.GELU(),
+        )
+
+    def forward(self, x):
+        return self.net(x)
+
+class MiT(nn.Module):
+    def __init__(
+        self,
+        *,
+        channels,
+        dims,
+        heads,
+        ff_expansion,
+        reduction_ratio,
+        num_layers
+    ):
+        super().__init__()
+        stage_kernel_stride_pad = ((7, 4, 3), (3, 2, 1), (3, 2, 1), (3, 2, 1))
+        channels_mod = channels / 16
+        dims = (channels_mod, *dims)
+        dim_pairs = list(zip(dims[:-1], dims[1:]))
+
+        self.stages = nn.ModuleList([])
+
+        for (dim_in, dim_out), (kernel, stride, padding), num_layers, ff_expansion, heads, reduction_ratio in zip(dim_pairs, stage_kernel_stride_pad, num_layers, ff_expansion, heads, reduction_ratio):
+            get_overlap_patches = nn.Unfold(kernel, stride = stride, padding = padding)
+            # if count == 0:
+            #     overlap_patch_embed = nn.Conv2d(int((dim_in * kernel ** 2) / 16), dim_out, 1)
+            # else:
+            overlap_patch_embed = nn.Conv2d(int(dim_in * kernel ** 2), dim_out, 1)
+            # count += 1
+
+            layers = nn.ModuleList([])
+
+            for _ in range(num_layers):
+                layers.append(nn.ModuleList([
+                    PreNorm(dim_out, EfficientSelfAttention(dim = dim_out, heads = heads, reduction_ratio = reduction_ratio)),
+                    PreNorm(dim_out, MixFeedForward(dim = dim_out, expansion_factor = ff_expansion)),
+                ]))
+
+            self.stages.append(nn.ModuleList([
+                get_overlap_patches,
+                overlap_patch_embed,
+                layers
+            ]))
+
+    def forward(
+        self,
+        x,
+        return_layer_outputs = False
+    ):
+        h, w = x.shape[-2:]
+        # h = int(h / 4)
+        # w = int(w / 4)
+        # print(h)
+
+        layer_outputs = []
+        for (get_overlap_patches, overlap_embed, layers) in self.stages:
+            x = get_overlap_patches(x)
+            # print(x.shape)
+            num_patches = int(x.shape[-1])
+
+            # num_patches = int(x.shape[-1] / 16)
+            # print(num_patches)
+            ratio = int(sqrt((h * w) / num_patches))
+            # print(ratio)
+            x = rearrange(x, 'b c (h w) -> b c h w', h = h // ratio)
+            # print(x.shape)
+            x = x.type(torch.cuda.FloatTensor)
+
+            x = overlap_embed(x)
+            for (attn, ff) in layers:
+                x = attn(x) + x
+                x = ff(x) + x
+
+            layer_outputs.append(x)
+
+        ret = x if not return_layer_outputs else layer_outputs
+        return ret
+
+class SpikeFormer(nn.Module):
+    def __init__(
+        self,
+        inputDim = 64,
+        dims = (32, 64, 160, 256),
+        heads = (1, 2, 5, 8),
+        ff_expansion = (8, 8, 4, 4),
+        reduction_ratio = (8, 4, 2, 1),
+        num_layers = 2,
+        channels = 64,
+        decoder_dim = 256,
+        out_channel = 1
+    ):
+        super().__init__()
+        dims, heads, ff_expansion, reduction_ratio, num_layers = map(partial(cast_tuple, depth = 4), (dims, heads, ff_expansion, reduction_ratio, num_layers))
+        assert all([*map(lambda t: len(t) == 4, (dims, heads, ff_expansion, reduction_ratio, num_layers))]), 'only four stages are allowed, all keyword arguments must be either a single value or a tuple of 4 values'
+
+        self.mit = MiT(
+            channels = channels,
+            dims = dims,
+            heads = heads,
+            ff_expansion = ff_expansion,
+            reduction_ratio = reduction_ratio,
+            num_layers = num_layers
+        )
+        self.channel_transform = nn.Sequential(
+            nn.Conv2d(inputDim, 64, 3, 1, 1),
+            nn.GELU()
+        )
+
+        self.to_fused = nn.ModuleList([nn.Sequential(
+            nn.Conv2d(dim, decoder_dim, 1),
+            nn.PixelShuffle(2 ** i),
+            nn.GELU(),
+        ) for i, dim in enumerate(dims)])
+
+        self.to_restore = nn.Sequential(
+            nn.Conv2d(256 + 64 + 16 + 4, decoder_dim, 1),
+            nn.GELU(),
+            nn.Conv2d(decoder_dim, out_channel, 1),
+        )
+        self.fournew = nn.PixelShuffle(4)
+
+    def forward(self, x):
+        x = self.channel_transform(x)
+        x = self.fournew(x)
+        layer_outputs = self.mit(x, return_layer_outputs = True)
+
+        fused = [to_fused(output) for output, to_fused in zip(layer_outputs, self.to_fused)]
+
+        fused = torch.cat(fused, dim = 1)
+
+        return self.to_restore(fused)
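For orientation (editorial note, not part of the wheel): a minimal smoke test for the SpikeFormer class above, using the same constructor arguments as evaluate.py further down. It assumes a CUDA device, since MiT.forward hard-casts with torch.cuda.FloatTensor, and an input whose height and width are multiples of 8, which appears to be required for the four stage strides to divide evenly. The 64x64 input size is an arbitrary choice for a cheap shape check.

```python
# Hypothetical smoke test, not shipped in spikezoo; the import path follows
# the package layout above (Model/SpikeFormer.py). Requires a CUDA GPU.
import torch
from Model.SpikeFormer import SpikeFormer

model = SpikeFormer(
    inputDim=65,                    # 2 * spikeRadius + 1 spike frames, as in evaluate.py
    dims=(32, 64, 160, 256),
    heads=(1, 2, 5, 8),
    ff_expansion=(8, 8, 4, 4),
    reduction_ratio=(8, 4, 2, 1),
    num_layers=2,
    decoder_dim=256,
    out_channel=1,
).cuda().eval()

with torch.no_grad():
    spikes = torch.randn(1, 65, 64, 64).cuda()  # (batch, spike frames, H, W)
    out = model(spikes)

print(out.shape)  # torch.Size([1, 1, 64, 64]): one restored image per sample
```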
spikezoo/archs/spikeformer/Model/__init__.py
File without changes
spikezoo/archs/spikeformer/Model/__pycache__/SpikeFormer.cpython-39.pyc
Binary file
spikezoo/archs/spikeformer/Model/__pycache__/__init__.cpython-39.pyc
Binary file
spikezoo/archs/spikeformer/README.md
@@ -0,0 +1,30 @@
+# SpikeFormer [](https://opensource.org/licenses/mit-license.php)
+PyTorch implementation of "SpikeFormer: Image Reconstruction from the Sequence of Spike Camera Based on Transformer" [[Paper]](https://dl.acm.org/doi/abs/10.1145/3512388.3512399)
+
+## Prerequisites
+* Create a conda environment by running `conda env create -f requirements.yml`
+
+## Dataset Structure
+* To train SpikeFormer, please organize the dataset files as follows:
+```
+Dataset
+├── test
+│   └── c.npz
+├── train
+│   └── a.npz
+└── valid
+    └── b.npz
+```
+
+## Pretrained Model
+* Download the pretrained model [here](https://pan.baidu.com/s/1aeW15vQh0GXgRJtfStBHDg) using password: nwh5.
+* Put the model at the path ./CheckPoints/
+
+## Training
+* Run `python train.py` to train SpikeFormer on the training set.
+
+## Validation
+* Run `python evaluate.py` to evaluate the trained model on the test set.
+
+## Reconstruct Images from Real Spike Data
+* Run `python recon_real_data.py`.
spikezoo/archs/spikeformer/evaluate.py
@@ -0,0 +1,87 @@
+import os
+os.environ['CUDA_VISIBLE_DEVICES'] = "0"
+import torch
+import numpy as np
+from DataProcess import DataLoader as dl
+from Model import Loss
+from PIL import Image
+from Metrics.Metrics import Metrics
+from Model.SpikeFormer import SpikeFormer
+from utils import LoadModel
+
+if __name__ == "__main__":
+
+    dataPath = "/home/storage2/shechen/Spike_Sample_250x400"
+    spikeRadius = 32  # half length of the input spike sequence, not counting the middle frame
+    spikeLen = 2 * spikeRadius + 1  # length of the input spike sequence
+    batchSize = 4
+
+    reuse = True
+    checkPath = "CheckPoints/best.pth"
+
+    validContainer = dl.DataContainer(dataPath=dataPath, dataType='valid',
+                                      spikeRadius=spikeRadius, batchSize=batchSize)
+    validData = validContainer.GetLoader()
+
+    metrics = Metrics()
+    # model = Spk2Img(spikeRadius, frameRadius, frameStride).cuda()
+
+    model = SpikeFormer(
+        inputDim=spikeLen,
+        dims=(32, 64, 160, 256),       # dimensions of each stage
+        heads=(1, 2, 5, 8),            # heads of each stage
+        ff_expansion=(8, 8, 4, 4),     # feedforward expansion factor of each stage
+        reduction_ratio=(8, 4, 2, 1),  # reduction ratio of each stage for efficient attention
+        num_layers=2,                  # num layers of each stage
+        decoder_dim=256,               # decoder dimension
+        out_channel=1                  # channels of the restored image
+    ).cuda()
+
+    if reuse:
+        _, _, modelDict, _ = LoadModel(checkPath, model)
+
+    model.eval()
+    with torch.no_grad():
+        num = 0
+        pres = []
+        gts = []
+        for i, (spikes, gtImg) in enumerate(validData):
+            B, D, H, W = spikes.size()
+            spikes = spikes.cuda()
+            gtImg = gtImg.cuda()
+            predImg = model(spikes)
+            predImg = predImg.squeeze(1)
+
+            predImg = predImg.clamp(min=-1., max=1.)
+            predImg = predImg.detach().cpu().numpy()
+            gtImg = gtImg.clamp(min=-1., max=1.)
+            gtImg = gtImg.detach().cpu().numpy()
+
+            predImg = (predImg + 1.) / 2. * 255.
+            predImg = predImg.astype(np.uint8)
+            predImg = predImg[:, 3:-3]  # crop the 3-pixel vertical padding
+
+            gtImg = (gtImg + 1.) / 2. * 255.
+            gtImg = gtImg.astype(np.uint8)
+
+            pres.append(predImg)
+            gts.append(gtImg)
+        pres = np.concatenate(pres, axis=0)
+        gts = np.concatenate(gts, axis=0)
+
+        psnr = metrics.Cal_PSNR(pres, gts)
+        ssim = metrics.Cal_SSIM(pres, gts)
+        best_psnr, best_ssim, _ = metrics.GetBestMetrics()
+
+        B, H, W = pres.shape
+        divide_line = np.zeros((H, 4)).astype(np.uint8)
+        for pre, gt in zip(pres, gts):
+            num += 1
+            concatImg = np.concatenate([pre, divide_line, gt], axis=1)
+            concatImg = Image.fromarray(concatImg)
+            concatImg.save('EvalResults/test_%s.jpg' % (num))
+
+        print('*********************************************************')
+        print('PSNR: %s, SSIM: %s' % (psnr, ssim))
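Cal_PSNR and Cal_SSIM above come from Metrics/Metrics.py, which this wheel adds (50 lines) but which is not expanded on this page. Below is a hedged sketch of equivalent metrics over the uint8 stacks evaluate.py builds, using the scikit-image release pinned in requirements.yml; mean_psnr_ssim is a hypothetical helper, and the real Metrics class may differ in averaging or windowing.

```python
# Illustrative stand-in for Metrics.Cal_PSNR / Cal_SSIM, not spikezoo API.
import numpy as np
from skimage.metrics import peak_signal_noise_ratio, structural_similarity

def mean_psnr_ssim(pres: np.ndarray, gts: np.ndarray):
    """Average PSNR/SSIM over an (N, H, W) stack of uint8 images."""
    psnrs = [peak_signal_noise_ratio(gt, pre, data_range=255) for pre, gt in zip(pres, gts)]
    ssims = [structural_similarity(gt, pre, data_range=255) for pre, gt in zip(pres, gts)]
    return float(np.mean(psnrs)), float(np.mean(ssims))
```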
spikezoo/archs/spikeformer/recon_real_data.py
@@ -0,0 +1,97 @@
+import os
+os.environ['CUDA_VISIBLE_DEVICES'] = "0"
+import torch
+import numpy as np
+from PIL import Image
+from Metrics.Metrics import Metrics
+from Model.SpikeFormer import SpikeFormer
+from DataProcess.LoadSpike import load_spike_raw
+from utils import LoadModel
+import shutil
+
+def PredictImg(model, inputs):
+    inputs = torch.FloatTensor(inputs)
+    inputs = inputs.cuda()
+
+    predImg = model(inputs).squeeze(dim=1)
+
+    predImg = predImg.clamp(min=-1., max=1.)
+    predImg = predImg.detach().cpu().numpy()
+    predImg = (predImg + 1.) / 2. * 255.
+    predImg = np.clip(predImg, 0., 255.)
+    predImg = predImg.astype(np.uint8)
+    predImg = predImg[:, 3:-3]  # crop the 3-pixel vertical padding
+
+    return predImg
+
+if __name__ == "__main__":
+
+    dataName = "reds"
+    spikeRadius = 32
+    spikeLen = 2 * spikeRadius + 1
+    stride = 32
+    batchSize = 8
+    reuse = True
+    checkPath = "best.pth"
+    sceneClass = {
+        1: 'ballon.dat', 2: 'car-100kmh.dat',
+        3: 'forest.dat', 4: 'railway.dat',
+        5: 'rotation1.dat', 6: 'rotation2.dat',
+        7: 'train-350kmh.dat', 8: 'viaduct-bridge.dat'
+    }
+    sceneName = sceneClass[2]
+    dataPath = "/home/storage1/Dataset/SpikeImageData/RealData/%s" % (sceneName)
+    resultPath = sceneName + "_stride_" + str(stride) + "/"
+    if os.path.exists(resultPath):
+        shutil.rmtree(resultPath)
+    os.mkdir(resultPath)
+    spikes = load_spike_raw(dataPath)
+    totalLen = spikes.shape[0]
+    metrics = Metrics()
+    model = SpikeFormer(
+        inputDim=spikeLen,
+        dims=(32, 64, 160, 256),       # dimensions of each stage
+        heads=(1, 2, 5, 8),            # heads of each stage
+        ff_expansion=(8, 8, 4, 4),     # feedforward expansion factor of each stage
+        reduction_ratio=(8, 4, 2, 1),  # reduction ratio of each stage for efficient attention
+        num_layers=2,                  # num layers of each stage
+        decoder_dim=256,               # decoder dimension
+        out_channel=1                  # channels of the restored image
+    ).cuda()
+
+    if reuse:
+        _, _, modelDict, _ = LoadModel(checkPath, model)
+
+    model.eval()
+    with torch.no_grad():
+        num = 0
+        pres = []
+        batchFlag = 1
+        inputs = np.zeros((batchSize, spikeLen, 256, 400))  # spikeLen == 65
+        for i in range(32, totalLen - 32, stride):
+            batchFlag = 1
+            spike = spikes[i - spikeRadius: i + spikeRadius + 1]
+            spike = np.pad(spike, ((0, 0), (3, 3), (0, 0)), mode='constant')
+            spike = spike.astype(float) * 2 - 1
+            inputs[num % batchSize] = spike
+            num += 1
+
+            if num % batchSize == 0:
+                predImg = PredictImg(model, inputs)
+                inputs = np.zeros((batchSize, spikeLen, 256, 400))
+                pres.append(predImg)
+                batchFlag = 0
+
+        if batchFlag == 1:  # flush the final partial batch
+            imgNum = num % batchSize
+            inputs = inputs[0: imgNum]
+            predImg = PredictImg(model, inputs)
+            inputs = np.zeros((batchSize, spikeLen, 256, 400))
+            pres.append(predImg)
+
+        predImgs = np.concatenate(pres, axis=0)
+        count = 0
+        for img in predImgs:
+            count += 1
+            img = Image.fromarray(img)
+            img.save(resultPath + '%s.jpg' % (count))
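load_spike_raw (from DataProcess/LoadSpike.py, 39 lines in this wheel, not expanded here) turns a Vidar-style .dat stream into a (T, 250, 400) binary array, which the loop above pads to 256 rows. The sketch below shows roughly how such a decoder can work; the one-bit-per-pixel layout and little-endian bit order are assumptions for illustration, not facts taken from this diff.

```python
# Hypothetical decoder sketch; the real load_spike_raw may use a different
# bit order or frame layout.
import numpy as np

def load_spike_raw_sketch(path: str, height: int = 250, width: int = 400) -> np.ndarray:
    raw = np.fromfile(path, dtype=np.uint8)
    frame_bytes = height * width // 8                 # 12500 bytes per 250x400 frame
    num_frames = raw.size // frame_bytes
    frames = raw[: num_frames * frame_bytes].reshape(num_frames, frame_bytes)
    bits = np.unpackbits(frames, axis=1, bitorder="little")
    return bits.reshape(num_frames, height, width)    # spike values in {0, 1}
```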
spikezoo/archs/spikeformer/requirements.yml
@@ -0,0 +1,95 @@
+name: SpikeFormer
+channels:
+  - pytorch
+  - conda-forge
+  - defaults
+dependencies:
+  - _libgcc_mutex=0.1=main
+  - _openmp_mutex=5.1=1_gnu
+  - blas=1.0=mkl
+  - brotlipy=0.7.0=py38h0a891b7_1004
+  - bzip2=1.0.8=h7f98852_4
+  - ca-certificates=2022.07.19=h06a4308_0
+  - certifi=2022.6.15=py38h06a4308_0
+  - cffi=1.14.6=py38ha65f79e_0
+  - charset-normalizer=2.1.1=pyhd8ed1ab_0
+  - cloudpickle=2.0.0=pyhd3eb1b0_0
+  - cryptography=37.0.2=py38h2b5fc30_0
+  - cudatoolkit=11.6.0=hecad31d_10
+  - cytoolz=0.11.0=py38h7b6447c_0
+  - dask-core=2022.7.0=py38h06a4308_0
+  - ffmpeg=4.3=hf484d3e_0
+  - fftw=3.3.9=h27cfd23_1
+  - freetype=2.10.4=h0708190_1
+  - fsspec=2022.7.1=py38h06a4308_0
+  - gmp=6.2.1=h58526e2_0
+  - gnutls=3.6.13=h85f3911_1
+  - idna=3.3=pyhd8ed1ab_0
+  - imageio=2.9.0=pyhd3eb1b0_0
+  - intel-openmp=2021.4.0=h06a4308_3561
+  - jpeg=9e=h166bdaf_1
+  - lame=3.100=h7f98852_1001
+  - lcms2=2.12=hddcbb42_0
+  - ld_impl_linux-64=2.38=h1181459_1
+  - libffi=3.3=he6710b0_2
+  - libgcc-ng=11.2.0=h1234567_1
+  - libgfortran-ng=11.2.0=h00389a5_1
+  - libgfortran5=11.2.0=h1234567_1
+  - libgomp=11.2.0=h1234567_1
+  - libiconv=1.17=h166bdaf_0
+  - libpng=1.6.37=h21135ba_2
+  - libstdcxx-ng=11.2.0=h1234567_1
+  - libtiff=4.2.0=hf544144_3
+  - libwebp-base=1.2.2=h7f98852_1
+  - locket=1.0.0=py38h06a4308_0
+  - lz4-c=1.9.3=h9c3ff4c_1
+  - mkl=2021.4.0=h06a4308_640
+  - mkl-service=2.4.0=py38h95df7f1_0
+  - mkl_fft=1.3.1=py38h8666266_1
+  - mkl_random=1.2.2=py38h1abd341_0
+  - natsort=7.1.1=pyhd3eb1b0_0
+  - ncurses=6.3=h5eee18b_3
+  - nettle=3.6=he412f7d_0
+  - networkx=2.8.4=py38h06a4308_0
+  - numpy=1.21.5=py38h6c91a56_3
+  - numpy-base=1.21.5=py38ha15fc14_3
+  - olefile=0.46=pyh9f0ad1d_1
+  - openh264=2.1.1=h780b84a_0
+  - openjpeg=2.4.0=hb52868f_1
+  - openssl=1.1.1q=h7f8727e_0
+  - packaging=21.3=pyhd3eb1b0_0
+  - partd=1.2.0=pyhd3eb1b0_1
+  - pillow=8.2.0=py38ha0e1e83_1
+  - pip=22.1.2=py38h06a4308_0
+  - pycparser=2.21=pyhd8ed1ab_0
+  - pyopenssl=22.0.0=pyhd8ed1ab_0
+  - pyparsing=3.0.9=py38h06a4308_0
+  - pysocks=1.7.1=pyha2e5f31_6
+  - python=3.8.13=h12debd9_0
+  - python_abi=3.8=2_cp38
+  - pytorch=1.12.1=py3.8_cuda11.6_cudnn8.3.2_0
+  - pytorch-mutex=1.0=cuda
+  - pywavelets=1.3.0=py38h7f8727e_0
+  - pyyaml=6.0=py38h7f8727e_1
+  - readline=8.1.2=h7f8727e_1
+  - requests=2.28.1=pyhd8ed1ab_1
+  - scikit-image=0.19.2=py38h51133e4_0
+  - scipy=1.7.3=py38h6c91a56_2
+  - setuptools=63.4.1=py38h06a4308_0
+  - six=1.16.0=pyh6c4a22f_0
+  - sqlite=3.39.2=h5082296_0
+  - tifffile=2020.10.1=py38hdd07704_2
+  - tk=8.6.12=h1ccaba5_0
+  - toolz=0.11.2=pyhd3eb1b0_0
+  - torchaudio=0.12.1=py38_cu116
+  - torchvision=0.13.1=py38_cu116
+  - typing_extensions=4.3.0=pyha770c72_0
+  - urllib3=1.26.11=pyhd8ed1ab_0
+  - wheel=0.37.1=pyhd3eb1b0_0
+  - xz=5.2.5=h7f8727e_1
+  - yaml=0.2.5=h7b6447c_0
+  - zlib=1.2.12=h7f8727e_2
+  - zstd=1.5.0=ha95c52a_0
+  - pip:
+    - einops==0.4.1
+    - opencv-python==4.6.0.66