spikezoo 0.1.1__py3-none-any.whl → 0.2__py3-none-any.whl

Files changed (192)
  1. spikezoo/__init__.py +13 -0
  2. spikezoo/archs/__pycache__/__init__.cpython-39.pyc +0 -0
  3. spikezoo/archs/base/__pycache__/nets.cpython-39.pyc +0 -0
  4. spikezoo/archs/base/nets.py +34 -0
  5. spikezoo/archs/bsf/README.md +92 -0
  6. spikezoo/archs/bsf/datasets/datasets.py +328 -0
  7. spikezoo/archs/bsf/datasets/ds_utils.py +64 -0
  8. spikezoo/archs/bsf/main.py +398 -0
  9. spikezoo/archs/bsf/metrics/psnr.py +22 -0
  10. spikezoo/archs/bsf/metrics/ssim.py +54 -0
  11. spikezoo/archs/bsf/models/bsf/__pycache__/align.cpython-39.pyc +0 -0
  12. spikezoo/archs/bsf/models/bsf/__pycache__/bsf.cpython-39.pyc +0 -0
  13. spikezoo/archs/bsf/models/bsf/__pycache__/rep.cpython-39.pyc +0 -0
  14. spikezoo/archs/bsf/models/bsf/align.py +154 -0
  15. spikezoo/archs/bsf/models/bsf/bsf.py +105 -0
  16. spikezoo/archs/bsf/models/bsf/dsft_convert.py +96 -0
  17. spikezoo/archs/bsf/models/bsf/rep.py +44 -0
  18. spikezoo/archs/bsf/models/get_model.py +7 -0
  19. spikezoo/archs/bsf/prepare_data/DSFT.py +62 -0
  20. spikezoo/archs/bsf/prepare_data/crop_dataset_train.py +135 -0
  21. spikezoo/archs/bsf/prepare_data/crop_dataset_val.py +139 -0
  22. spikezoo/archs/bsf/prepare_data/crop_train.sh +4 -0
  23. spikezoo/archs/bsf/prepare_data/crop_val.sh +4 -0
  24. spikezoo/archs/bsf/prepare_data/io_utils.py +64 -0
  25. spikezoo/archs/bsf/requirements.txt +9 -0
  26. spikezoo/archs/bsf/test.py +16 -0
  27. spikezoo/archs/bsf/utils.py +154 -0
  28. spikezoo/archs/spikeclip/__pycache__/nets.cpython-39.pyc +0 -0
  29. spikezoo/archs/spikeclip/nets.py +40 -0
  30. spikezoo/archs/spikeformer/CheckPoints/readme +1 -0
  31. spikezoo/archs/spikeformer/DataProcess/DataExtactor.py +60 -0
  32. spikezoo/archs/spikeformer/DataProcess/DataLoader.py +115 -0
  33. spikezoo/archs/spikeformer/DataProcess/LoadSpike.py +39 -0
  34. spikezoo/archs/spikeformer/EvalResults/readme +1 -0
  35. spikezoo/archs/spikeformer/LICENSE +21 -0
  36. spikezoo/archs/spikeformer/Metrics/Metrics.py +50 -0
  37. spikezoo/archs/spikeformer/Metrics/__init__.py +0 -0
  38. spikezoo/archs/spikeformer/Model/Loss.py +89 -0
  39. spikezoo/archs/spikeformer/Model/SpikeFormer.py +230 -0
  40. spikezoo/archs/spikeformer/Model/__init__.py +0 -0
  41. spikezoo/archs/spikeformer/Model/__pycache__/SpikeFormer.cpython-39.pyc +0 -0
  42. spikezoo/archs/spikeformer/Model/__pycache__/__init__.cpython-39.pyc +0 -0
  43. spikezoo/archs/spikeformer/README.md +30 -0
  44. spikezoo/archs/spikeformer/evaluate.py +87 -0
  45. spikezoo/archs/spikeformer/recon_real_data.py +97 -0
  46. spikezoo/archs/spikeformer/requirements.yml +95 -0
  47. spikezoo/archs/spikeformer/train.py +173 -0
  48. spikezoo/archs/spikeformer/utils.py +22 -0
  49. spikezoo/archs/spk2imgnet/.github/workflows/pylint.yml +23 -0
  50. spikezoo/archs/spk2imgnet/.gitignore +150 -0
  51. spikezoo/archs/spk2imgnet/DCNv2.py +135 -0
  52. spikezoo/archs/spk2imgnet/__pycache__/DCNv2.cpython-39.pyc +0 -0
  53. spikezoo/archs/spk2imgnet/__pycache__/align_arch.cpython-39.pyc +0 -0
  54. spikezoo/archs/spk2imgnet/__pycache__/nets.cpython-39.pyc +0 -0
  55. spikezoo/archs/spk2imgnet/align_arch.py +159 -0
  56. spikezoo/archs/spk2imgnet/dataset.py +144 -0
  57. spikezoo/archs/spk2imgnet/nets.py +230 -0
  58. spikezoo/archs/spk2imgnet/readme.md +86 -0
  59. spikezoo/archs/spk2imgnet/test_gen_imgseq.py +118 -0
  60. spikezoo/archs/spk2imgnet/train.py +189 -0
  61. spikezoo/archs/spk2imgnet/utils.py +64 -0
  62. spikezoo/archs/ssir/README.md +87 -0
  63. spikezoo/archs/ssir/configs/SSIR.yml +37 -0
  64. spikezoo/archs/ssir/configs/yml_parser.py +78 -0
  65. spikezoo/archs/ssir/datasets/dataset_sreds.py +170 -0
  66. spikezoo/archs/ssir/datasets/ds_utils.py +66 -0
  67. spikezoo/archs/ssir/losses.py +21 -0
  68. spikezoo/archs/ssir/main.py +326 -0
  69. spikezoo/archs/ssir/metrics/psnr.py +22 -0
  70. spikezoo/archs/ssir/metrics/ssim.py +54 -0
  71. spikezoo/archs/ssir/models/Vgg19.py +42 -0
  72. spikezoo/archs/ssir/models/__pycache__/layers.cpython-39.pyc +0 -0
  73. spikezoo/archs/ssir/models/__pycache__/networks.cpython-39.pyc +0 -0
  74. spikezoo/archs/ssir/models/layers.py +110 -0
  75. spikezoo/archs/ssir/models/networks.py +61 -0
  76. spikezoo/archs/ssir/requirements.txt +8 -0
  77. spikezoo/archs/ssir/shells/eval_SREDS.sh +6 -0
  78. spikezoo/archs/ssir/shells/train_SSIR.sh +12 -0
  79. spikezoo/archs/ssir/test.py +3 -0
  80. spikezoo/archs/ssir/utils.py +154 -0
  81. spikezoo/archs/ssml/__pycache__/cbam.cpython-39.pyc +0 -0
  82. spikezoo/archs/ssml/__pycache__/model.cpython-39.pyc +0 -0
  83. spikezoo/archs/ssml/cbam.py +224 -0
  84. spikezoo/archs/ssml/model.py +290 -0
  85. spikezoo/archs/ssml/res.png +0 -0
  86. spikezoo/archs/ssml/test.py +67 -0
  87. spikezoo/archs/stir/.git-credentials +0 -0
  88. spikezoo/archs/stir/README.md +65 -0
  89. spikezoo/archs/stir/ckpt_outputs/Descriptions.txt +1 -0
  90. spikezoo/archs/stir/configs/STIR.yml +37 -0
  91. spikezoo/archs/stir/configs/utils.py +155 -0
  92. spikezoo/archs/stir/configs/yml_parser.py +78 -0
  93. spikezoo/archs/stir/datasets/dataset_sreds.py +180 -0
  94. spikezoo/archs/stir/datasets/ds_utils.py +66 -0
  95. spikezoo/archs/stir/eval_SREDS.sh +5 -0
  96. spikezoo/archs/stir/main.py +397 -0
  97. spikezoo/archs/stir/metrics/losses.py +219 -0
  98. spikezoo/archs/stir/metrics/psnr.py +22 -0
  99. spikezoo/archs/stir/metrics/ssim.py +54 -0
  100. spikezoo/archs/stir/models/Vgg19.py +42 -0
  101. spikezoo/archs/stir/models/__pycache__/networks_STIR.cpython-39.pyc +0 -0
  102. spikezoo/archs/stir/models/__pycache__/submodules.cpython-39.pyc +0 -0
  103. spikezoo/archs/stir/models/__pycache__/transformer_new.cpython-39.pyc +0 -0
  104. spikezoo/archs/stir/models/networks_STIR.py +361 -0
  105. spikezoo/archs/stir/models/submodules.py +86 -0
  106. spikezoo/archs/stir/models/transformer_new.py +151 -0
  107. spikezoo/archs/stir/package_core/build/lib/package_core/__init__.py +0 -0
  108. spikezoo/archs/stir/package_core/build/lib/package_core/convertions.py +721 -0
  109. spikezoo/archs/stir/package_core/build/lib/package_core/disp_netS.py +133 -0
  110. spikezoo/archs/stir/package_core/build/lib/package_core/flow_utils.py +167 -0
  111. spikezoo/archs/stir/package_core/build/lib/package_core/generic_train_test.py +76 -0
  112. spikezoo/archs/stir/package_core/build/lib/package_core/geometry.py +458 -0
  113. spikezoo/archs/stir/package_core/build/lib/package_core/image_proc.py +183 -0
  114. spikezoo/archs/stir/package_core/build/lib/package_core/linalg.py +40 -0
  115. spikezoo/archs/stir/package_core/build/lib/package_core/losses.py +198 -0
  116. spikezoo/archs/stir/package_core/build/lib/package_core/metrics.py +51 -0
  117. spikezoo/archs/stir/package_core/build/lib/package_core/model_base.py +53 -0
  118. spikezoo/archs/stir/package_core/build/lib/package_core/net_basics.py +100 -0
  119. spikezoo/archs/stir/package_core/build/lib/package_core/resnet.py +333 -0
  120. spikezoo/archs/stir/package_core/build/lib/package_core/transforms.py +123 -0
  121. spikezoo/archs/stir/package_core/build/lib/package_core/utils.py +72 -0
  122. spikezoo/archs/stir/package_core/dist/package_core-0.0.0-py3.9.egg +0 -0
  123. spikezoo/archs/stir/package_core/package_core/__init__.py +0 -0
  124. spikezoo/archs/stir/package_core/package_core/__pycache__/__init__.cpython-39.pyc +0 -0
  125. spikezoo/archs/stir/package_core/package_core/__pycache__/net_basics.cpython-39.pyc +0 -0
  126. spikezoo/archs/stir/package_core/package_core/convertions.py +721 -0
  127. spikezoo/archs/stir/package_core/package_core/disp_netS.py +133 -0
  128. spikezoo/archs/stir/package_core/package_core/flow_utils.py +167 -0
  129. spikezoo/archs/stir/package_core/package_core/generic_train_test.py +76 -0
  130. spikezoo/archs/stir/package_core/package_core/geometry.py +458 -0
  131. spikezoo/archs/stir/package_core/package_core/image_proc.py +183 -0
  132. spikezoo/archs/stir/package_core/package_core/linalg.py +40 -0
  133. spikezoo/archs/stir/package_core/package_core/losses.py +198 -0
  134. spikezoo/archs/stir/package_core/package_core/metrics.py +51 -0
  135. spikezoo/archs/stir/package_core/package_core/model_base.py +53 -0
  136. spikezoo/archs/stir/package_core/package_core/net_basics.py +100 -0
  137. spikezoo/archs/stir/package_core/package_core/resnet.py +333 -0
  138. spikezoo/archs/stir/package_core/package_core/transforms.py +123 -0
  139. spikezoo/archs/stir/package_core/package_core/utils.py +72 -0
  140. spikezoo/archs/stir/package_core/package_core.egg-info/PKG-INFO +3 -0
  141. spikezoo/archs/stir/package_core/package_core.egg-info/SOURCES.txt +20 -0
  142. spikezoo/archs/stir/package_core/package_core.egg-info/dependency_links.txt +1 -0
  143. spikezoo/archs/stir/package_core/package_core.egg-info/top_level.txt +1 -0
  144. spikezoo/archs/stir/package_core/setup.py +5 -0
  145. spikezoo/archs/stir/requirements.txt +12 -0
  146. spikezoo/archs/stir/train_STIR.sh +9 -0
  147. spikezoo/archs/tfi/__pycache__/nets.cpython-39.pyc +0 -0
  148. spikezoo/archs/tfi/nets.py +43 -0
  149. spikezoo/archs/tfp/__pycache__/nets.cpython-39.pyc +0 -0
  150. spikezoo/archs/tfp/nets.py +13 -0
  151. spikezoo/archs/wgse/README.md +64 -0
  152. spikezoo/archs/wgse/__pycache__/dwtnets.cpython-39.pyc +0 -0
  153. spikezoo/archs/wgse/__pycache__/submodules.cpython-39.pyc +0 -0
  154. spikezoo/archs/wgse/dataset.py +59 -0
  155. spikezoo/archs/wgse/demo.png +0 -0
  156. spikezoo/archs/wgse/demo.py +83 -0
  157. spikezoo/archs/wgse/dwtnets.py +145 -0
  158. spikezoo/archs/wgse/eval.py +133 -0
  159. spikezoo/archs/wgse/logs/WGSE-Dwt1dNet-db8-5-ks3/log.txt +11 -0
  160. spikezoo/archs/wgse/submodules.py +68 -0
  161. spikezoo/archs/wgse/train.py +261 -0
  162. spikezoo/archs/wgse/transform.py +139 -0
  163. spikezoo/archs/wgse/utils.py +128 -0
  164. spikezoo/archs/wgse/weights/demo.png +0 -0
  165. spikezoo/data/base/test/gt/200_part1_key_id151.png +0 -0
  166. spikezoo/data/base/test/gt/200_part3_key_id151.png +0 -0
  167. spikezoo/data/base/test/gt/203_part1_key_id151.png +0 -0
  168. spikezoo/data/base/test/spike/200_part1_key_id151.dat +0 -0
  169. spikezoo/data/base/test/spike/200_part3_key_id151.dat +0 -0
  170. spikezoo/data/base/test/spike/203_part1_key_id151.dat +0 -0
  171. spikezoo/data/base/train/gt/203_part2_key_id151.png +0 -0
  172. spikezoo/data/base/train/gt/203_part3_key_id151.png +0 -0
  173. spikezoo/data/base/train/gt/203_part4_key_id151.png +0 -0
  174. spikezoo/data/base/train/spike/203_part2_key_id151.dat +0 -0
  175. spikezoo/data/base/train/spike/203_part3_key_id151.dat +0 -0
  176. spikezoo/data/base/train/spike/203_part4_key_id151.dat +0 -0
  177. spikezoo/datasets/base_dataset.py +2 -3
  178. spikezoo/metrics/__init__.py +1 -1
  179. spikezoo/models/base_model.py +1 -3
  180. spikezoo/pipeline/base_pipeline.py +7 -5
  181. spikezoo/pipeline/train_pipeline.py +1 -1
  182. spikezoo/utils/other_utils.py +16 -6
  183. spikezoo/utils/spike_utils.py +33 -29
  184. spikezoo/utils/vidar_loader.cpython-39-x86_64-linux-gnu.so +0 -0
  185. spikezoo-0.2.dist-info/METADATA +163 -0
  186. spikezoo-0.2.dist-info/RECORD +211 -0
  187. spikezoo/models/spcsnet_model.py +0 -19
  188. spikezoo-0.1.1.dist-info/METADATA +0 -39
  189. spikezoo-0.1.1.dist-info/RECORD +0 -36
  190. {spikezoo-0.1.1.dist-info → spikezoo-0.2.dist-info}/LICENSE.txt +0 -0
  191. {spikezoo-0.1.1.dist-info → spikezoo-0.2.dist-info}/WHEEL +0 -0
  192. {spikezoo-0.1.1.dist-info → spikezoo-0.2.dist-info}/top_level.txt +0 -0
spikezoo/archs/spikeformer/Model/SpikeFormer.py
@@ -0,0 +1,230 @@
+ '''
+ Ref: To implement SpikeFormer, we referred to the code of "segformer-pytorch"
+ published on GitHub (link: https://github.com/lucidrains/segformer-pytorch.git)
+ '''
+ from math import sqrt
+ from functools import partial
+ import torch
+ from torch import nn, einsum
+ from einops import rearrange, reduce
+
+ # helpers
+
+ def exists(val):
+     return val is not None
+
+ def cast_tuple(val, depth):
+     return val if isinstance(val, tuple) else (val,) * depth
+
+ LayerNorm = partial(nn.InstanceNorm2d, affine = True)
+
+ # classes
+
+ class DsConv2d(nn.Module):
+     def __init__(self, dim_in, dim_out, kernel_size, padding, stride = 1, bias = True):
+         super().__init__()
+         self.net = nn.Sequential(
+             nn.Conv2d(dim_in, dim_in, kernel_size = kernel_size, padding = padding, groups = dim_in, stride = stride, bias = bias),
+             nn.GELU(),
+             nn.Conv2d(dim_in, dim_out, kernel_size = 1, bias = bias),
+             nn.GELU(),
+         )
+     def forward(self, x):
+         return self.net(x)
+
+ class PreNorm(nn.Module):
+     def __init__(self, dim, fn):
+         super().__init__()
+         self.fn = fn
+         self.norm = LayerNorm(dim)
+
+     def forward(self, x):
+         # return self.fn(x)
+         return self.fn(self.norm(x))
+
+ class EfficientSelfAttention(nn.Module):
+     def __init__(
+         self,
+         *,
+         dim,
+         heads,
+         reduction_ratio
+     ):
+         super().__init__()
+         self.scale = (dim // heads) ** -0.5
+         self.heads = heads
+
+         self.to_q = nn.Conv2d(dim, dim, 1, bias = False)
+         self.to_kv = nn.Conv2d(dim, dim * 2, reduction_ratio, stride = reduction_ratio, bias = False)
+         self.to_out = nn.Sequential(
+             nn.Conv2d(dim, dim, 1, bias=False),
+             nn.GELU()
+         )
+
+     def forward(self, x):
+         h, w = x.shape[-2:]
+         heads = self.heads
+
+         q, k, v = (self.to_q(x), *self.to_kv(x).chunk(2, dim = 1))
+         q, k, v = map(lambda t: rearrange(t, 'b (h c) x y -> (b h) (x y) c', h = heads), (q, k, v))
+
+         sim = einsum('b i d, b j d -> b i j', q, k) * self.scale
+         attn = sim.softmax(dim = -1)
+
+         out = einsum('b i j, b j d -> b i d', attn, v)
+         out = rearrange(out, '(b h) (x y) c -> b (h c) x y', h = heads, x = h, y = w)
+         return self.to_out(out)
+
+ class MixFeedForward(nn.Module):
+     def __init__(
+         self,
+         *,
+         dim,
+         expansion_factor
+     ):
+         super().__init__()
+         hidden_dim = dim * expansion_factor
+         self.net = nn.Sequential(
+             nn.Conv2d(dim, hidden_dim, 1),
+             nn.GELU(),
+             DsConv2d(hidden_dim, hidden_dim, 3, padding = 1),
+             # nn.GELU(),
+             nn.Conv2d(hidden_dim, dim, 1),
+             nn.GELU(),
+         )
+
+     def forward(self, x):
+         return self.net(x)
+
+ class MiT(nn.Module):
+     def __init__(
+         self,
+         *,
+         channels,
+         dims,
+         heads,
+         ff_expansion,
+         reduction_ratio,
+         num_layers
+     ):
+         super().__init__()
+         stage_kernel_stride_pad = ((7, 4, 3), (3, 2, 1), (3, 2, 1), (3, 2, 1))
+         channels_mod = channels/16
+         dims = (channels_mod, *dims)
+         dim_pairs = list(zip(dims[:-1], dims[1:]))
+
+         self.stages = nn.ModuleList([])
+
+         for (dim_in, dim_out), (kernel, stride, padding), num_layers, ff_expansion, heads, reduction_ratio in zip(dim_pairs, stage_kernel_stride_pad, num_layers, ff_expansion, heads, reduction_ratio):
+             get_overlap_patches = nn.Unfold(kernel, stride = stride, padding = padding)
+             # if count == 0:
+             #     overlap_patch_embed = nn.Conv2d(int((dim_in * kernel ** 2)/16), dim_out, 1)
+             # else:
+             overlap_patch_embed = nn.Conv2d(int(dim_in * kernel ** 2), dim_out, 1)
+             # count+=1
+
+             layers = nn.ModuleList([])
+
+             for _ in range(num_layers):
+                 layers.append(nn.ModuleList([
+                     PreNorm(dim_out, EfficientSelfAttention(dim = dim_out, heads = heads, reduction_ratio = reduction_ratio)),
+                     PreNorm(dim_out, MixFeedForward(dim = dim_out, expansion_factor = ff_expansion)),
+                 ]))
+
+             self.stages.append(nn.ModuleList([
+                 get_overlap_patches,
+                 overlap_patch_embed,
+                 layers
+             ]))
+
+     def forward(
+         self,
+         x,
+         return_layer_outputs = False
+     ):
+         h, w = x.shape[-2:]
+         # h = int(h/4)
+         # w = int(w/4)
+         # print(h)
+
+         layer_outputs = []
+         for (get_overlap_patches, overlap_embed, layers) in self.stages:
+             x = get_overlap_patches(x)
+             # print('aaa')
+             # print(x.shape)
+             num_patches = int(x.shape[-1])
+
+             # num_patches = int(x.shape[-1]/16)
+             # print(num_patches)
+             ratio = int(sqrt((h * w) / num_patches))
+             # print(ratio)
+             x = rearrange(x, 'b c (h w) -> b c h w', h = h // ratio)
+             # print(x.shape)
+             x = x.type(torch.cuda.FloatTensor)
+
+             x = overlap_embed(x)
+             for (attn, ff) in layers:
+                 x = attn(x) + x
+                 x = ff(x) + x
+
+             layer_outputs.append(x)
+
+         ret = x if not return_layer_outputs else layer_outputs
+         return ret
+
+ class SpikeFormer(nn.Module):
+     def __init__(
+         self,
+         inputDim=64,
+         dims = (32, 64, 160, 256),
+         heads = (1, 2, 5, 8),
+         ff_expansion = (8, 8, 4, 4),
+         reduction_ratio = (8, 4, 2, 1),
+         num_layers = 2,
+         channels = 64,
+         decoder_dim = 256,
+         out_channel = 1
+     ):
+         super().__init__()
+         dims, heads, ff_expansion, reduction_ratio, num_layers = map(partial(cast_tuple, depth = 4), (dims, heads, ff_expansion, reduction_ratio, num_layers))
+         assert all([*map(lambda t: len(t) == 4, (dims, heads, ff_expansion, reduction_ratio, num_layers))]), 'only four stages are allowed, all keyword arguments must be either a single value or a tuple of 4 values'
+
+         self.mit = MiT(
+             channels = channels,
+             dims = dims,
+             heads = heads,
+             ff_expansion = ff_expansion,
+             reduction_ratio = reduction_ratio,
+             num_layers = num_layers
+         )
+         self.channel_transform = nn.Sequential(
+             nn.Conv2d(inputDim, 64, 3, 1, 1),
+             nn.GELU()
+         )
+
+         self.to_fused = nn.ModuleList([nn.Sequential(
+             nn.Conv2d(dim, decoder_dim, 1),
+             nn.PixelShuffle(2 ** i),
+             nn.GELU(),
+         ) for i, dim in enumerate(dims)])
+
+         self.to_restore = nn.Sequential(
+             nn.Conv2d(256+64+16+4, decoder_dim, 1),
+             nn.GELU(),
+             nn.Conv2d(decoder_dim, out_channel, 1),
+         )
+         self.fournew = nn.PixelShuffle(4)
+
+     def forward(self, x):
+         x = self.channel_transform(x)
+         x = self.fournew(x)
+         layer_outputs = self.mit(x, return_layer_outputs = True)
+
+         fused = [to_fused(output) for output, to_fused in zip(layer_outputs, self.to_fused)]
+
+         fused = torch.cat(fused, dim = 1)
+
+         return self.to_restore(fused)
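For orientation, here is a minimal usage sketch (not part of the diff) of the `SpikeFormer` module above, instantiated with the same arguments `evaluate.py` uses below; the dummy input shape, a 65-frame spike window at 256×400, is an assumption inferred from `recon_real_data.py`. A CUDA device is required because `MiT.forward` hard-codes `torch.cuda.FloatTensor`.

```python
# Sketch only: constructor arguments mirror evaluate.py; the input shape is assumed.
import torch
from Model.SpikeFormer import SpikeFormer

model = SpikeFormer(
    inputDim=65,                    # 2 * spikeRadius + 1 spike frames
    dims=(32, 64, 160, 256),
    heads=(1, 2, 5, 8),
    ff_expansion=(8, 8, 4, 4),
    reduction_ratio=(8, 4, 2, 1),
    num_layers=2,
    decoder_dim=256,
    out_channel=1,
).cuda()

spikes = torch.randn(1, 65, 256, 400).cuda()  # (B, T, H, W), values in [-1, 1]
with torch.no_grad():
    recon = model(spikes)                     # (B, 1, 256, 400), also in [-1, 1]
```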
spikezoo/archs/spikeformer/README.md
@@ -0,0 +1,30 @@
+ # SpikeFormer [![MIT Licence](https://badges.frapsoft.com/os/mit/mit.svg?v=103)](https://opensource.org/licenses/mit-license.php)
+ PyTorch implementation of "SpikeFormer: Image Reconstruction from the Sequence of Spike Camera Based on Transformer" [[Paper]](https://dl.acm.org/doi/abs/10.1145/3512388.3512399)
+
+ ## Prerequisites
+ * Create a conda environment by running `conda env create -f requirements.yml`.
+
+ ## Dataset Structure
+ * To train SpikeFormer, organize the dataset's file structure as follows:
+ ```
+ Dataset
+ ├── test
+ │   └── c.npz
+ ├── train
+ │   └── a.npz
+ └── valid
+     └── b.npz
+ ```
+
+ ## Pretrained Model
+ * Download the pretrained model [here](https://pan.baidu.com/s/1aeW15vQh0GXgRJtfStBHDg) using password `nwh5`.
+ * Put the model in `./CheckPoints/`.
+
+ ## Training
+ * Run `python train.py` to train SpikeFormer on the training set.
+
+ ## Validation
+ * Run `python evaluate.py` to evaluate the trained model on the test set.
+
+ ## Reconstruct Images from Real Spike Data
+ * Run `python recon_real_data.py`.
spikezoo/archs/spikeformer/evaluate.py
@@ -0,0 +1,87 @@
+ import os
+ os.environ['CUDA_VISIBLE_DEVICES'] = "0"
+ import torch
+ import numpy as np
+ from DataProcess import DataLoader as dl
+ from Model import Loss
+ from PIL import Image
+ from Metrics.Metrics import Metrics
+ from Model.SpikeFormer import SpikeFormer
+ from utils import LoadModel
+
+ if __name__ == "__main__":
+
+     dataPath = "/home/storage2/shechen/Spike_Sample_250x400"
+     spikeRadius = 32  # half length of the input spike sequence, excluding the middle frame
+     spikeLen = 2 * spikeRadius + 1  # length of the input spike sequence
+     batchSize = 4
+
+     reuse = True
+     checkPath = "CheckPoints/best.pth"
+
+     validContainer = dl.DataContainer(dataPath=dataPath, dataType='valid',
+                                       spikeRadius=spikeRadius, batchSize=batchSize)
+     validData = validContainer.GetLoader()
+
+     metrics = Metrics()
+     # model = Spk2Img(spikeRadius, frameRadius, frameStride).cuda()
+
+     model = SpikeFormer(
+         inputDim=spikeLen,
+         dims=(32, 64, 160, 256),       # dimensions of each stage
+         heads=(1, 2, 5, 8),            # heads of each stage
+         ff_expansion=(8, 8, 4, 4),     # feedforward expansion factor of each stage
+         reduction_ratio=(8, 4, 2, 1),  # reduction ratio of each stage for efficient attention
+         num_layers=2,                  # num layers of each stage
+         decoder_dim=256,               # decoder dimension
+         out_channel=1                  # channel of restored image
+     ).cuda()
+
+     if reuse:
+         _, _, modelDict, _ = LoadModel(checkPath, model)
+
+     model.eval()
+     with torch.no_grad():
+         num = 0
+         pres = []
+         gts = []
+         for i, (spikes, gtImg) in enumerate(validData):
+             B, D, H, W = spikes.size()
+             spikes = spikes.cuda()
+             gtImg = gtImg.cuda()
+             predImg = model(spikes)
+             predImg = predImg.squeeze(1)
+
+             predImg = predImg.clamp(min=-1., max=1.)
+             predImg = predImg.detach().cpu().numpy()
+             gtImg = gtImg.clamp(min=-1., max=1.)
+             gtImg = gtImg.detach().cpu().numpy()
+
+             predImg = (predImg + 1.) / 2. * 255.
+             predImg = predImg.astype(np.uint8)
+             predImg = predImg[:, 3:-3]
+
+             gtImg = (gtImg + 1.) / 2. * 255.
+             gtImg = gtImg.astype(np.uint8)
+
+             pres.append(predImg)
+             gts.append(gtImg)
+         pres = np.concatenate(pres, axis=0)
+         gts = np.concatenate(gts, axis=0)
+
+         psnr = metrics.Cal_PSNR(pres, gts)
+         ssim = metrics.Cal_SSIM(pres, gts)
+         best_psnr, best_ssim, _ = metrics.GetBestMetrics()
+
+         B, H, W = pres.shape
+         divide_line = np.zeros((H, 4)).astype(np.uint8)
+         for pre, gt in zip(pres, gts):
+             num += 1
+             concatImg = np.concatenate([pre, divide_line, gt], axis=1)
+             concatImg = Image.fromarray(concatImg)
+             concatImg.save('EvalResults/test_%s.jpg' % (num))
+
+         print('*********************************************************')
+         print('PSNR: %s, SSIM: %s' % (psnr, ssim))
+
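The script above converts both predictions and ground truth from [-1, 1] to uint8 in [0, 255] before scoring. For reference, the standard PSNR on uint8 arrays is sketched below; this is the textbook formula, not necessarily the exact implementation of `Metrics.Cal_PSNR`.

```python
# Textbook PSNR for uint8 images (an illustration; Metrics.Cal_PSNR may differ).
import numpy as np

def psnr_uint8(pred: np.ndarray, gt: np.ndarray) -> float:
    mse = np.mean((pred.astype(np.float64) - gt.astype(np.float64)) ** 2)
    if mse == 0:
        return float("inf")                   # identical images
    return 10.0 * np.log10(255.0 ** 2 / mse)  # peak value of uint8 is 255
```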
spikezoo/archs/spikeformer/recon_real_data.py
@@ -0,0 +1,97 @@
+ import os
+ os.environ['CUDA_VISIBLE_DEVICES'] = "0"
+ import torch
+ import numpy as np
+ from PIL import Image
+ from Metrics.Metrics import Metrics
+ from Model.SpikeFormer import SpikeFormer
+ from DataProcess.LoadSpike import load_spike_raw
+ from utils import LoadModel
+ import shutil
+
+ def PredictImg(model, inputs):
+     inputs = torch.FloatTensor(inputs)
+     inputs = inputs.cuda()
+
+     predImg = model(inputs).squeeze(dim=1)
+
+     predImg = predImg.clamp(min=-1., max=1.)
+     predImg = predImg.detach().cpu().numpy()
+     predImg = (predImg + 1.) / 2. * 255.
+     predImg = np.clip(predImg, 0., 255.)
+     predImg = predImg.astype(np.uint8)
+     predImg = predImg[:, 3:-3]
+
+     return predImg
+
+ if __name__ == "__main__":
+
+     dataName = "reds"
+     spikeRadius = 32
+     spikeLen = 2 * spikeRadius + 1
+     stride = 32
+     batchSize = 8
+     reuse = True
+     checkPath = "best.pth"
+     sceneClass = {
+         1: 'ballon.dat', 2: 'car-100kmh.dat',
+         3: 'forest.dat', 4: 'railway.dat',
+         5: 'rotation1.dat', 6: 'rotation2.dat',
+         7: 'train-350kmh.dat', 8: 'viaduct-bridge.dat'
+     }
+     sceneName = sceneClass[2]
+     dataPath = "/home/storage1/Dataset/SpikeImageData/RealData/%s" % (sceneName)
+     resultPath = sceneName + "_stride_" + str(stride) + "/"
+     if os.path.exists(resultPath):  # start from an empty result directory
+         shutil.rmtree(resultPath)
+     os.mkdir(resultPath)
+     spikes = load_spike_raw(dataPath)
+     totalLen = spikes.shape[0]
+     metrics = Metrics()
+     model = SpikeFormer(
+         inputDim=spikeLen,
+         dims=(32, 64, 160, 256),       # dimensions of each stage
+         heads=(1, 2, 5, 8),            # heads of each stage
+         ff_expansion=(8, 8, 4, 4),     # feedforward expansion factor of each stage
+         reduction_ratio=(8, 4, 2, 1),  # reduction ratio of each stage for efficient attention
+         num_layers=2,                  # num layers of each stage
+         decoder_dim=256,               # decoder dimension
+         out_channel=1                  # channel of restored image
+     ).cuda()
+
+     if reuse:
+         _, _, modelDict, _ = LoadModel(checkPath, model)
+
+     model.eval()
+     with torch.no_grad():
+         num = 0
+         pres = []
+         batchFlag = 1
+         inputs = np.zeros((batchSize, spikeLen, 256, 400))  # spikeLen = 65
+         for i in range(32, totalLen - 32, stride):
+             batchFlag = 1
+             spike = spikes[i - spikeRadius: i + spikeRadius + 1]
+             spike = np.pad(spike, ((0, 0), (3, 3), (0, 0)), mode='constant')
+             spike = spike.astype(float) * 2 - 1
+             inputs[num % batchSize] = spike
+             num += 1
+
+             if num % batchSize == 0:
+                 predImg = PredictImg(model, inputs)
+                 inputs = np.zeros((batchSize, spikeLen, 256, 400))  # spikeLen = 65
+                 pres.append(predImg)
+                 batchFlag = 0
+
+         if batchFlag == 1:
+             imgNum = num % batchSize
+             inputs = inputs[0: imgNum]
+             predImg = PredictImg(model, inputs)
+             inputs = np.zeros((batchSize, spikeLen, 256, 400))
+             pres.append(predImg)
+
+         predImgs = np.concatenate(pres, axis=0)
+         count = 0
+         for img in predImgs:
+             count += 1
+             img = Image.fromarray(img)
+             img.save(resultPath + '%s.jpg' % (count))
+
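As a reading aid, the windowing logic above reduces to the following sketch; the (T, 250, 400) binary layout of `load_spike_raw`'s return value is an assumption inferred from the slicing and padding in the script, not a documented contract.

```python
# Illustration of one inference window in recon_real_data.py (shapes assumed).
import numpy as np
from DataProcess.LoadSpike import load_spike_raw

spikes = load_spike_raw("car-100kmh.dat")  # assumed (T, 250, 400), values in {0, 1}
i = 32                                     # first valid center frame
window = spikes[i - 32: i + 33]            # 65 frames centered on frame i
window = np.pad(window, ((0, 0), (3, 3), (0, 0)), mode='constant')  # 250 -> 256 rows
window = window.astype(float) * 2 - 1      # {0, 1} -> {-1, 1}, matching training
```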
spikezoo/archs/spikeformer/requirements.yml
@@ -0,0 +1,95 @@
+ name: SpikeFormer
+ channels:
+   - pytorch
+   - conda-forge
+   - defaults
+ dependencies:
+   - _libgcc_mutex=0.1=main
+   - _openmp_mutex=5.1=1_gnu
+   - blas=1.0=mkl
+   - brotlipy=0.7.0=py38h0a891b7_1004
+   - bzip2=1.0.8=h7f98852_4
+   - ca-certificates=2022.07.19=h06a4308_0
+   - certifi=2022.6.15=py38h06a4308_0
+   - cffi=1.14.6=py38ha65f79e_0
+   - charset-normalizer=2.1.1=pyhd8ed1ab_0
+   - cloudpickle=2.0.0=pyhd3eb1b0_0
+   - cryptography=37.0.2=py38h2b5fc30_0
+   - cudatoolkit=11.6.0=hecad31d_10
+   - cytoolz=0.11.0=py38h7b6447c_0
+   - dask-core=2022.7.0=py38h06a4308_0
+   - ffmpeg=4.3=hf484d3e_0
+   - fftw=3.3.9=h27cfd23_1
+   - freetype=2.10.4=h0708190_1
+   - fsspec=2022.7.1=py38h06a4308_0
+   - gmp=6.2.1=h58526e2_0
+   - gnutls=3.6.13=h85f3911_1
+   - idna=3.3=pyhd8ed1ab_0
+   - imageio=2.9.0=pyhd3eb1b0_0
+   - intel-openmp=2021.4.0=h06a4308_3561
+   - jpeg=9e=h166bdaf_1
+   - lame=3.100=h7f98852_1001
+   - lcms2=2.12=hddcbb42_0
+   - ld_impl_linux-64=2.38=h1181459_1
+   - libffi=3.3=he6710b0_2
+   - libgcc-ng=11.2.0=h1234567_1
+   - libgfortran-ng=11.2.0=h00389a5_1
+   - libgfortran5=11.2.0=h1234567_1
+   - libgomp=11.2.0=h1234567_1
+   - libiconv=1.17=h166bdaf_0
+   - libpng=1.6.37=h21135ba_2
+   - libstdcxx-ng=11.2.0=h1234567_1
+   - libtiff=4.2.0=hf544144_3
+   - libwebp-base=1.2.2=h7f98852_1
+   - locket=1.0.0=py38h06a4308_0
+   - lz4-c=1.9.3=h9c3ff4c_1
+   - mkl=2021.4.0=h06a4308_640
+   - mkl-service=2.4.0=py38h95df7f1_0
+   - mkl_fft=1.3.1=py38h8666266_1
+   - mkl_random=1.2.2=py38h1abd341_0
+   - natsort=7.1.1=pyhd3eb1b0_0
+   - ncurses=6.3=h5eee18b_3
+   - nettle=3.6=he412f7d_0
+   - networkx=2.8.4=py38h06a4308_0
+   - numpy=1.21.5=py38h6c91a56_3
+   - numpy-base=1.21.5=py38ha15fc14_3
+   - olefile=0.46=pyh9f0ad1d_1
+   - openh264=2.1.1=h780b84a_0
+   - openjpeg=2.4.0=hb52868f_1
+   - openssl=1.1.1q=h7f8727e_0
+   - packaging=21.3=pyhd3eb1b0_0
+   - partd=1.2.0=pyhd3eb1b0_1
+   - pillow=8.2.0=py38ha0e1e83_1
+   - pip=22.1.2=py38h06a4308_0
+   - pycparser=2.21=pyhd8ed1ab_0
+   - pyopenssl=22.0.0=pyhd8ed1ab_0
+   - pyparsing=3.0.9=py38h06a4308_0
+   - pysocks=1.7.1=pyha2e5f31_6
+   - python=3.8.13=h12debd9_0
+   - python_abi=3.8=2_cp38
+   - pytorch=1.12.1=py3.8_cuda11.6_cudnn8.3.2_0
+   - pytorch-mutex=1.0=cuda
+   - pywavelets=1.3.0=py38h7f8727e_0
+   - pyyaml=6.0=py38h7f8727e_1
+   - readline=8.1.2=h7f8727e_1
+   - requests=2.28.1=pyhd8ed1ab_1
+   - scikit-image=0.19.2=py38h51133e4_0
+   - scipy=1.7.3=py38h6c91a56_2
+   - setuptools=63.4.1=py38h06a4308_0
+   - six=1.16.0=pyh6c4a22f_0
+   - sqlite=3.39.2=h5082296_0
+   - tifffile=2020.10.1=py38hdd07704_2
+   - tk=8.6.12=h1ccaba5_0
+   - toolz=0.11.2=pyhd3eb1b0_0
+   - torchaudio=0.12.1=py38_cu116
+   - torchvision=0.13.1=py38_cu116
+   - typing_extensions=4.3.0=pyha770c72_0
+   - urllib3=1.26.11=pyhd8ed1ab_0
+   - wheel=0.37.1=pyhd3eb1b0_0
+   - xz=5.2.5=h7f8727e_1
+   - yaml=0.2.5=h7b6447c_0
+   - zlib=1.2.12=h7f8727e_2
+   - zstd=1.5.0=ha95c52a_0
+   - pip:
+     - einops==0.4.1
+     - opencv-python==4.6.0.66