spikezoo 0.2.2__py3-none-any.whl → 0.2.3__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- spikezoo/__init__.py +23 -7
- spikezoo/archs/bsf/models/bsf/bsf.py +37 -25
- spikezoo/archs/bsf/models/bsf/rep.py +2 -2
- spikezoo/archs/spk2imgnet/nets.py +1 -1
- spikezoo/archs/ssir/models/networks.py +1 -1
- spikezoo/archs/ssml/model.py +9 -5
- spikezoo/archs/stir/metrics/losses.py +1 -1
- spikezoo/archs/stir/models/networks_STIR.py +16 -9
- spikezoo/archs/tfi/nets.py +1 -1
- spikezoo/archs/tfp/nets.py +1 -1
- spikezoo/archs/wgse/dwtnets.py +6 -6
- spikezoo/datasets/__init__.py +11 -9
- spikezoo/datasets/base_dataset.py +10 -3
- spikezoo/datasets/realworld_dataset.py +1 -3
- spikezoo/datasets/{reds_small_dataset.py → reds_base_dataset.py} +9 -8
- spikezoo/datasets/reds_ssir_dataset.py +181 -0
- spikezoo/datasets/szdata_dataset.py +5 -15
- spikezoo/datasets/uhsr_dataset.py +4 -3
- spikezoo/models/__init__.py +8 -6
- spikezoo/models/base_model.py +120 -64
- spikezoo/models/bsf_model.py +11 -3
- spikezoo/models/spcsnet_model.py +19 -0
- spikezoo/models/spikeclip_model.py +4 -3
- spikezoo/models/spk2imgnet_model.py +9 -15
- spikezoo/models/ssir_model.py +4 -6
- spikezoo/models/ssml_model.py +44 -2
- spikezoo/models/stir_model.py +26 -5
- spikezoo/models/tfi_model.py +3 -1
- spikezoo/models/tfp_model.py +4 -2
- spikezoo/models/wgse_model.py +8 -14
- spikezoo/pipeline/base_pipeline.py +79 -55
- spikezoo/pipeline/ensemble_pipeline.py +10 -9
- spikezoo/pipeline/train_cfgs.py +89 -0
- spikezoo/pipeline/train_pipeline.py +129 -30
- spikezoo/utils/optimizer_utils.py +22 -0
- spikezoo/utils/other_utils.py +31 -6
- spikezoo/utils/scheduler_utils.py +25 -0
- spikezoo/utils/spike_utils.py +61 -29
- spikezoo-0.2.3.dist-info/METADATA +263 -0
- {spikezoo-0.2.2.dist-info → spikezoo-0.2.3.dist-info}/RECORD +43 -80
- spikezoo/archs/__pycache__/__init__.cpython-39.pyc +0 -0
- spikezoo/archs/base/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/bsf/models/bsf/__pycache__/align.cpython-39.pyc +0 -0
- spikezoo/archs/bsf/models/bsf/__pycache__/bsf.cpython-39.pyc +0 -0
- spikezoo/archs/bsf/models/bsf/__pycache__/rep.cpython-39.pyc +0 -0
- spikezoo/archs/spikeclip/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/spikeformer/CheckPoints/readme +0 -1
- spikezoo/archs/spikeformer/DataProcess/DataExtactor.py +0 -60
- spikezoo/archs/spikeformer/DataProcess/DataLoader.py +0 -115
- spikezoo/archs/spikeformer/DataProcess/LoadSpike.py +0 -39
- spikezoo/archs/spikeformer/EvalResults/readme +0 -1
- spikezoo/archs/spikeformer/LICENSE +0 -21
- spikezoo/archs/spikeformer/Metrics/Metrics.py +0 -50
- spikezoo/archs/spikeformer/Metrics/__init__.py +0 -0
- spikezoo/archs/spikeformer/Model/Loss.py +0 -89
- spikezoo/archs/spikeformer/Model/SpikeFormer.py +0 -230
- spikezoo/archs/spikeformer/Model/__init__.py +0 -0
- spikezoo/archs/spikeformer/Model/__pycache__/SpikeFormer.cpython-39.pyc +0 -0
- spikezoo/archs/spikeformer/Model/__pycache__/__init__.cpython-39.pyc +0 -0
- spikezoo/archs/spikeformer/README.md +0 -30
- spikezoo/archs/spikeformer/evaluate.py +0 -87
- spikezoo/archs/spikeformer/recon_real_data.py +0 -97
- spikezoo/archs/spikeformer/requirements.yml +0 -95
- spikezoo/archs/spikeformer/train.py +0 -173
- spikezoo/archs/spikeformer/utils.py +0 -22
- spikezoo/archs/spk2imgnet/__pycache__/DCNv2.cpython-39.pyc +0 -0
- spikezoo/archs/spk2imgnet/__pycache__/align_arch.cpython-39.pyc +0 -0
- spikezoo/archs/spk2imgnet/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/ssir/models/__pycache__/layers.cpython-39.pyc +0 -0
- spikezoo/archs/ssir/models/__pycache__/networks.cpython-39.pyc +0 -0
- spikezoo/archs/ssml/__pycache__/cbam.cpython-39.pyc +0 -0
- spikezoo/archs/ssml/__pycache__/model.cpython-39.pyc +0 -0
- spikezoo/archs/stir/models/__pycache__/networks_STIR.cpython-39.pyc +0 -0
- spikezoo/archs/stir/models/__pycache__/submodules.cpython-39.pyc +0 -0
- spikezoo/archs/stir/models/__pycache__/transformer_new.cpython-39.pyc +0 -0
- spikezoo/archs/stir/package_core/package_core/__pycache__/__init__.cpython-39.pyc +0 -0
- spikezoo/archs/stir/package_core/package_core/__pycache__/net_basics.cpython-39.pyc +0 -0
- spikezoo/archs/tfi/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/tfp/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/wgse/__pycache__/dwtnets.cpython-39.pyc +0 -0
- spikezoo/archs/wgse/__pycache__/submodules.cpython-39.pyc +0 -0
- spikezoo/models/spikeformer_model.py +0 -50
- spikezoo-0.2.2.dist-info/METADATA +0 -196
- {spikezoo-0.2.2.dist-info → spikezoo-0.2.3.dist-info}/LICENSE.txt +0 -0
- {spikezoo-0.2.2.dist-info → spikezoo-0.2.3.dist-info}/WHEEL +0 -0
- {spikezoo-0.2.2.dist-info → spikezoo-0.2.3.dist-info}/top_level.txt +0 -0
@@ -1,173 +0,0 @@
|
|
1
|
-
import os
|
2
|
-
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
|
3
|
-
import torch
|
4
|
-
from torch import optim
|
5
|
-
import numpy as np
|
6
|
-
from DataProcess import DataLoader as dl
|
7
|
-
from Model.SpikeFormer import SpikeFormer
|
8
|
-
from Metrics.Metrics import Metrics
|
9
|
-
from Model import Loss
|
10
|
-
from utils import SaveModel, LoadModel
|
11
|
-
from PIL import Image
|
12
|
-
|
13
|
-
def _to_uint8(img):
    """Clamp a tensor to [-1, 1] and rescale it to a uint8 numpy array in [0, 255]."""
    img = img.clamp(min=-1., max=1.).detach().cpu().numpy()
    return ((img + 1.) / 2. * 255.).astype(np.uint8)


def _save_comparisons(pres, gts):
    """Save every prediction beside its ground truth, split by a 4-pixel black bar."""
    _, H, _ = pres.shape
    divide_line = np.zeros((H, 4), dtype=np.uint8)
    for num, (pre, gt) in enumerate(zip(pres, gts), start=1):
        concatImg = np.concatenate([pre, divide_line, gt], axis=1)
        Image.fromarray(concatImg).save('EvalResults/valid_%s.jpg' % (num))


def eval(model, validData, epoch, optimizer, metrics, saveRoot=None):
    """Evaluate ``model`` on ``validData``, checkpoint it and print the metrics.

    NOTE: the name shadows the builtin ``eval``; it is kept so existing callers
    keep working.

    The "latest" checkpoint is written on every call. When both PSNR and SSIM
    are at least as good as the best recorded so far, the metrics tracker is
    updated, a "best" checkpoint is written, ``eval_best_log.txt`` is
    overwritten and side-by-side comparison images are written to
    ``EvalResults/``. (NOTE(review): indentation was lost in the original
    listing; images are assumed to be saved only for a new best — confirm.)

    Args:
        model: network mapping a spike batch to a predicted image batch. It
            is put in eval mode for the pass and always returned to train mode.
        validData: iterable of ``(spikes, gtImg)`` batches.
        epoch: current epoch index, used for logging and checkpoints.
        optimizer: optimizer whose state is saved with the model.
        metrics: tracker exposing ``Cal_PSNR``, ``Cal_SSIM``,
            ``GetBestMetrics`` and ``Update``.
        saveRoot: checkpoint directory. ``None`` falls back to the
            module-level ``saveRoot`` (the previous implicit behavior) and then
            to ``"CheckPoints/"``.

    Raises:
        ValueError: if ``validData`` yields no batches.
    """
    if saveRoot is None:
        # The original implementation silently read the script-level global.
        saveRoot = globals().get('saveRoot', 'CheckPoints/')

    model.eval()
    print('Eval Epoch: %s' % (epoch))
    try:
        with torch.no_grad():
            pres, gts = [], []
            for spikes, gtImg in validData:
                predImg = model(spikes.cuda())
                # Drop dim 1, then crop 3 rows from each side of the height
                # axis (presumably the padding added for the network).
                predImg = predImg.squeeze(1)[:, 3:-3, :]
                pres.append(_to_uint8(predImg))
                gts.append(_to_uint8(gtImg.cuda()))
            if not pres:
                raise ValueError('validData yielded no batches')
            pres = np.concatenate(pres, axis=0)
            gts = np.concatenate(gts, axis=0)

        psnr = metrics.Cal_PSNR(pres, gts)
        ssim = metrics.Cal_SSIM(pres, gts)
        best_psnr, best_ssim, _ = metrics.GetBestMetrics()

        SaveModel(epoch, (psnr, ssim), model, optimizer, saveRoot)
        if psnr >= best_psnr and ssim >= best_ssim:
            metrics.Update(psnr, ssim)
            SaveModel(epoch, (psnr, ssim), model, optimizer, saveRoot, best=True)
            with open('eval_best_log.txt', 'w') as f:
                f.write('epoch: %s; psnr: %s, ssim: %s\n' % (epoch, psnr, ssim))
            _save_comparisons(pres, gts)

        print('*********************************************************')
        best_psnr, best_ssim, _ = metrics.GetBestMetrics()
        print('Eval Epoch: %s, PSNR: %s, SSIM: %s, Best_PSNR: %s, Best_SSIM: %s'
              % (epoch, psnr, ssim, best_psnr, best_ssim))
    finally:
        # Restore training mode even if evaluation fails part-way.
        model.train()


def Train(trainData, validData, model, optimizer, epoch, start_epoch, metrics, saveRoot, perIter):
    """Train ``model`` for epochs ``start_epoch .. epoch - 1``, evaluating after each.

    The loss is ``100 * Charbonnier + 1 * VGG + 5 * Edge`` between the ground
    truth and the prediction. The prediction is first cropped by 3 rows on each
    side of dim 2, and the ground truth gets an extra dim 1.

    Args:
        trainData: iterable of ``(spikes, gtImg)`` training batches.
        validData: validation batches passed to :func:`eval`.
        model: network being trained.
        optimizer: optimizer stepping ``model``'s parameters.
        epoch: exclusive upper bound of the epoch range.
        start_epoch: first epoch index (non-zero when resuming).
        metrics: metrics tracker passed to :func:`eval`.
        saveRoot: checkpoint directory passed to :func:`eval`.
        perIter: print running loss averages every ``perIter`` iterations.
    """
    l2loss = Loss.CharbonnierLoss()
    vggloss = Loss.VGGLoss4('vgg19-low-level4.pth').cuda()
    criterion_edge = Loss.EdgeLoss()
    LAMBDA_L2 = 100.0
    LAMBDA_VGG = 1.0
    LAMBDA_EDGE = 5.0

    for i in range(start_epoch, epoch):
        # Reset every epoch, so that a partial logging window at the end of an
        # epoch is not folded into (and does not inflate) the next epoch's
        # first average.
        running = dict.fromkeys(('l2', 'vgg', 'edge', 'total'), 0.)
        for it, (spikes, gtImg) in enumerate(trainData):
            spikes = spikes.cuda()
            gtImg = gtImg.cuda().unsqueeze(1)
            predImg = model(spikes)[:, :, 3:-3, :]

            loss_vgg = vggloss(gtImg, predImg) * LAMBDA_VGG
            loss_l2 = l2loss(gtImg, predImg) * LAMBDA_L2
            loss_edge = criterion_edge(gtImg, predImg) * LAMBDA_EDGE
            totalLoss = loss_l2 + loss_vgg + loss_edge

            optimizer.zero_grad()
            totalLoss.backward()
            optimizer.step()

            running['l2'] += loss_l2.item()
            running['vgg'] += loss_vgg.item()
            running['edge'] += loss_edge.item()
            running['total'] += totalLoss.item()
            if (it + 1) % perIter == 0:
                print('=============================================================')
                print('Epoch: %s, Iter: %s' % (i, it + 1))
                print('L2Loss: %s; VggLoss: %s; EdgeLoss: %s; TotalLoss: %s' % (
                    running['l2'] / perIter, running['vgg'] / perIter,
                    running['edge'] / perIter, running['total'] / perIter))
                running = dict.fromkeys(running, 0.)

        eval(model, validData, i, optimizer, metrics, saveRoot)
|
120
|
-
|
121
|
-
if __name__ == "__main__":
    # Training configuration; adjust the paths for your environment.
    dataPath = "/home/storage2/shechen/Spike_Sample_250x400"
    spikeRadius = 32  # half length of input spike sequence except for the middle frame
    spikeLen = 2 * spikeRadius + 1  # length of input spike sequence
    batchSize = 2
    epoch = 200
    start_epoch = 0
    lr = 2e-4
    saveRoot = "CheckPoints/"  # path to save the trained model
    perIter = 20

    reuse = False
    reuseType = 'latest'  # 'latest' or 'best'
    # Resume from the same directory the checkpoints are written to.
    checkPath = os.path.join(saveRoot, '%s.pth' % (reuseType))

    # torch.save and PIL's Image.save do not create missing directories.
    os.makedirs(saveRoot, exist_ok=True)
    os.makedirs('EvalResults', exist_ok=True)

    trainContainer = dl.DataContainer(dataPath=dataPath, dataType='train',
                                      spikeRadius=spikeRadius,
                                      batchSize=batchSize)
    trainData = trainContainer.GetLoader()

    validContainer = dl.DataContainer(dataPath=dataPath, dataType='valid',
                                      spikeRadius=spikeRadius,
                                      batchSize=batchSize)
    validData = validContainer.GetLoader()

    metrics = Metrics()

    model = SpikeFormer(
        inputDim=spikeLen,
        dims=(32, 64, 160, 256),  # dimensions of each stage
        heads=(1, 2, 5, 8),  # heads of each stage
        ff_expansion=(8, 8, 4, 4),  # feedforward expansion factor of each stage
        reduction_ratio=(8, 4, 2, 1),  # reduction ratio of each stage for efficient attention
        num_layers=2,  # num layers of each stage
        decoder_dim=256,  # decoder dimension
        out_channel=1,  # channel of restored image
    ).cuda()

    optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999), amsgrad=False)

    if reuse:
        preEpoch, prePerformance, _, _ = LoadModel(checkPath, model, optimizer)
        start_epoch = preEpoch + 1
        # NOTE(review): with reuseType == 'latest' this seeds the "best"
        # metrics with the latest (not necessarily best) performance.
        psnr, ssim = prePerformance[0], prePerformance[1]
        metrics.Update(psnr, ssim)
        # Override the learning rate restored from the optimizer state.
        for para in optimizer.param_groups:
            para['lr'] = lr

    model.train()

    Train(trainData, validData, model, optimizer, epoch, start_epoch,
          metrics, saveRoot, perIter)
|
@@ -1,22 +0,0 @@
|
|
1
|
-
import os
|
2
|
-
import torch
|
3
|
-
|
4
|
-
def SaveModel(epoch, bestPerformance, model, optimizer, saveRoot, best=False):
    """Write a training checkpoint into ``saveRoot``.

    The file is ``<saveRoot>/latest.pth``, or ``<saveRoot>/best.pth`` when
    ``best`` is true. It replaces any earlier file of the same name.
    ``saveRoot`` is created if it does not exist yet.

    Args:
        epoch: epoch index, stored under ``'pre_epoch'``.
        bestPerformance: metrics stored under ``'performance'`` (for example a
            ``(psnr, ssim)`` tuple).
        model: module whose ``state_dict()`` is stored.
        optimizer: optimizer whose ``state_dict()`` is stored.
        saveRoot: output directory.
        best: write ``best.pth`` instead of ``latest.pth``.
    """
    saveDict = {
        'pre_epoch': epoch,
        'performance': bestPerformance,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }
    # torch.save does not create missing directories. An empty saveRoot means
    # the current directory, which always exists (os.makedirs('') would raise).
    if saveRoot:
        os.makedirs(saveRoot, exist_ok=True)
    savePath = os.path.join(saveRoot, '%s.pth' % ('latest' if not best else 'best'))
    torch.save(saveDict, savePath)
|
13
|
-
|
14
|
-
def LoadModel(checkPath, model, optimizer=None, map_location=None):
    """Restore ``model`` (and optionally ``optimizer``) from a checkpoint.

    The checkpoint is expected to be a dict written by ``SaveModel``, with keys
    ``'pre_epoch'``, ``'performance'``, ``'model_state_dict'`` and
    ``'optimizer_state_dict'``.

    Security: ``torch.load`` may unpickle arbitrary objects. Only load
    checkpoints from trusted sources.

    Args:
        checkPath: path to the checkpoint file.
        model: module that receives the stored weights (in place).
        optimizer: if given, receives the stored optimizer state (in place).
        map_location: forwarded to ``torch.load``. For example, pass ``'cpu'``
            to load a GPU-saved checkpoint on a machine without CUDA. ``None``
            keeps the previous behavior.

    Returns:
        Tuple ``(pre_epoch, performance, model_state_dict, optimizer_state_dict)``.
    """
    stateDict = torch.load(checkPath, map_location=map_location)
    pre_epoch = stateDict['pre_epoch']
    model.load_state_dict(stateDict['model_state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(stateDict['optimizer_state_dict'])

    return (pre_epoch, stateDict['performance'],
            stateDict['model_state_dict'], stateDict['optimizer_state_dict'])
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
@@ -1,50 +0,0 @@
|
|
1
|
-
import torch
|
2
|
-
from dataclasses import dataclass, field
|
3
|
-
from spikezoo.models.base_model import BaseModel, BaseModelConfig
|
4
|
-
|
5
|
-
|
6
|
-
@dataclass
class SpikeFormerConfig(BaseModelConfig):
    """Default settings for the SpikeFormer spike-to-image model."""

    # Identifiers for the SpikeFormer implementation and its checkpoint.
    model_name: str = "spikeformer"
    model_file_name: str = "Model.SpikeFormer"
    model_cls_name: str = "SpikeFormer"
    # Spike window length; the same value (65) is used as inputDim below.
    model_win_length: int = 65
    require_params: bool = True
    ckpt_path: str = "weights/spikeformer.pth"
    # Keyword arguments for the network; default_factory gives each config
    # instance its own dict.
    model_params: dict = field(
        default_factory=lambda: dict(
            inputDim=65,
            dims=(32, 64, 160, 256),
            heads=(1, 2, 5, 8),
            ff_expansion=(8, 8, 4, 4),
            reduction_ratio=(8, 4, 2, 1),
            num_layers=2,
            decoder_dim=256,
            out_channel=1,
        )
    )
|
27
|
-
|
28
|
-
|
29
|
-
class SpikeFormer(BaseModel):
    """Spike-Zoo wrapper around the SpikeFormer network.

    For two known spike resolutions the input is padded before inference, and
    the reconstructed image is cropped back afterwards.
    """

    def __init__(self, cfg: BaseModelConfig):
        super().__init__(cfg)

    def preprocess_spike(self, spike):
        """Crop the spike length, pad known resolutions and rescale with ``2*x - 1``.

        For binary spikes the rescale maps 0 -> -1 and 1 -> 1.
        """
        spike = self.crop_spike_length(spike)
        size = self.spike_size
        if size == (250, 400):
            # Repeat the first and last 3 entries along dim 2 (250 -> 256).
            head, tail = spike[:, :, :3, :], spike[:, :, -3:, :]
            spike = torch.cat((head, spike, tail), dim=2)
        elif size == (480, 854):
            # Repeat the last 2 entries along dim 3 (854 -> 856).
            spike = torch.cat((spike, spike[:, :, :, -2:]), dim=3)
        return spike * 2 - 1

    def postprocess_img(self, image):
        """Remove the padding that ``preprocess_spike`` added."""
        size = self.spike_size
        if size == (250, 400):
            return image[:, :, 3:-3, :]
        if size == (480, 854):
            return image[:, :, :, :854]
        return image
|
@@ -1,196 +0,0 @@
|
|
1
|
-
Metadata-Version: 2.2
|
2
|
-
Name: spikezoo
|
3
|
-
Version: 0.2.2
|
4
|
-
Summary: A deep learning toolbox for spike-to-image models.
|
5
|
-
Home-page: https://github.com/chenkang455/Spike-Zoo
|
6
|
-
Author: Kang Chen
|
7
|
-
Author-email: mrchenkang@stu.pku.edu.cn
|
8
|
-
Requires-Python: >=3.7
|
9
|
-
Description-Content-Type: text/markdown
|
10
|
-
License-File: LICENSE.txt
|
11
|
-
Requires-Dist: torch
|
12
|
-
Requires-Dist: requests
|
13
|
-
Requires-Dist: numpy
|
14
|
-
Requires-Dist: tqdm
|
15
|
-
Requires-Dist: scikit-image
|
16
|
-
Requires-Dist: lpips
|
17
|
-
Requires-Dist: pyiqa
|
18
|
-
Requires-Dist: opencv-python
|
19
|
-
Requires-Dist: thop
|
20
|
-
Requires-Dist: pytorch-wavelets
|
21
|
-
Requires-Dist: pytz
|
22
|
-
Requires-Dist: PyWavelets
|
23
|
-
Requires-Dist: pandas
|
24
|
-
Requires-Dist: pillow
|
25
|
-
Requires-Dist: scikit-learn
|
26
|
-
Requires-Dist: scipy
|
27
|
-
Requires-Dist: spikingjelly
|
28
|
-
Requires-Dist: setuptools
|
29
|
-
Dynamic: author
|
30
|
-
Dynamic: author-email
|
31
|
-
Dynamic: description
|
32
|
-
Dynamic: description-content-type
|
33
|
-
Dynamic: home-page
|
34
|
-
Dynamic: requires-dist
|
35
|
-
Dynamic: requires-python
|
36
|
-
Dynamic: summary
|
37
|
-
|
38
|
-
<h2 align="center">
|
39
|
-
<a href="">⚡Spike-Zoo: A Toolbox for Spike-to-Image Reconstruction
|
40
|
-
</a>
|
41
|
-
</h2>
|
42
|
-
|
43
|
-
## 📖 About
|
44
|
-
⚡Spike-Zoo is the go-to library for state-of-the-art pretrained **spike-to-image** models designed to reconstruct images from spike streams. Whether you're looking for a simple inference solution or aiming to train your own spike-to-image models, ⚡Spike-Zoo is a modular toolbox that supports both, with key features including:
|
45
|
-
|
46
|
-
- Fast inference with pre-trained models.
|
47
|
-
- Training support for custom-designed spike-to-image models.
|
48
|
-
- Specialized functions for processing spike data.
|
49
|
-
|
50
|
-
|
51
|
-
|
52
|
-
## 🚩 Updates/Changelog
|
53
|
-
* **25-02-02:** Release the `Spike-Zoo v0.2` code, which supports more methods and provides more usage examples.
|
54
|
-
* **24-08-26:** Update the `SpikeFormer` and `RSIR` methods, the `UHSR` dataset and the `piqe` non-reference metric.
|
55
|
-
|
56
|
-
* **24-07-19:** Release the `Spike-Zoo v0.1` base code.
|
57
|
-
|
58
|
-
## 🍾 Quick Start
|
59
|
-
### 1. Installation
|
60
|
-
For users focused on **utilizing pretrained models for spike-to-image conversion**, we recommend installing SpikeZoo using one of the following methods:
|
61
|
-
|
62
|
-
* Install the latest stable version from PyPI:
|
63
|
-
```
|
64
|
-
pip install spikezoo
|
65
|
-
```
|
66
|
-
* Install the latest developing version from the source code:
|
67
|
-
```
|
68
|
-
git clone https://github.com/chenkang455/Spike-Zoo
|
69
|
-
cd Spike-Zoo
|
70
|
-
python setup.py install
|
71
|
-
```
|
72
|
-
|
73
|
-
For users interested in **training their own spike-to-image model based on our framework**, we recommend cloning the repository and modifying the related code directly.
|
74
|
-
|
75
|
-
### 2. Inference
|
76
|
-
Reconstructing images from the spike input is super easy with Spike-Zoo. Try the following code for a single model:
|
77
|
-
``` python
|
78
|
-
from spikezoo.pipeline import Pipeline, PipelineConfig
|
79
|
-
pipeline = Pipeline(
|
80
|
-
cfg = PipelineConfig(save_folder="results"),
|
81
|
-
model_cfg="spk2imgnet",
|
82
|
-
dataset_cfg="base"
|
83
|
-
)
|
84
|
-
```
|
85
|
-
You can also run multiple models at once by changing the pipeline:
|
86
|
-
``` python
|
87
|
-
from spikezoo.pipeline import EnsemblePipeline, EnsemblePipelineConfig
|
88
|
-
pipeline = EnsemblePipeline(
|
89
|
-
cfg = EnsemblePipelineConfig(save_folder="results"),
|
90
|
-
model_cfg_list=['tfp','tfi', 'spk2imgnet', 'wgse', 'ssml', 'bsf', 'stir', 'spikeclip','spikeformer'],
|
91
|
-
dataset_cfg="base"
|
92
|
-
)
|
93
|
-
```
|
94
|
-
* Having established the pipeline, run the following code to obtain the metric and save the reconstructed image from the given spike:
|
95
|
-
``` python
|
96
|
-
# 1. spike-to-image from the given dataset
|
97
|
-
pipeline.spk2img_from_dataset(idx = 0)
|
98
|
-
|
99
|
-
# 2. spike-to-image from the given .dat file
|
100
|
-
pipeline.spk2img_from_file(file_path = 'data/scissor.dat',width = 400,height=250)
|
101
|
-
|
102
|
-
# 3. spike-to-image from the given spike
|
103
|
-
import spikezoo as sz
|
104
|
-
spike = sz.load_vidar_dat("data/scissor.dat",width = 400,height = 250,version='cpp')
|
105
|
-
pipeline.spk2img_from_spk(spike)
|
106
|
-
```
|
107
|
-
For detailed usage, please check [test_single.ipynb](examples/test_single.ipynb) and [test_multi.ipynb](examples/test_multi.ipynb) 😊😊😊.
|
108
|
-
|
109
|
-
* Save all images of the given dataset.
|
110
|
-
``` python
|
111
|
-
pipeline.save_imgs_from_dataset()
|
112
|
-
```
|
113
|
-
|
114
|
-
* Calculate the metrics for the specified dataset.
|
115
|
-
``` python
|
116
|
-
pipeline.cal_metrics()
|
117
|
-
```
|
118
|
-
|
119
|
-
* Calculate the model statistics (params, FLOPs, latency) based on the established pipeline.
|
120
|
-
``` python
|
121
|
-
pipeline.cal_params()
|
122
|
-
```
|
123
|
-
|
124
|
-
### 3. Training
|
125
|
-
We provide user-friendly code for training our base model (modified from `SpikeCLIP`) on the classic `REDS` dataset introduced in `Spk2ImgNet`:
|
126
|
-
``` python
|
127
|
-
from spikezoo.pipeline import TrainPipelineConfig, TrainPipeline
|
128
|
-
from spikezoo.datasets.reds_small_dataset import REDS_Small_Config
|
129
|
-
pipeline = TrainPipeline(
|
130
|
-
cfg=TrainPipelineConfig(save_folder="results", epochs = 10),
|
131
|
-
dataset_cfg=REDS_Small_Config(root_dir = "path/REDS_Small"),
|
132
|
-
model_cfg="base",
|
133
|
-
)
|
134
|
-
pipeline.train()
|
135
|
-
```
|
136
|
-
With a single 4090 GPU, training finishes in `2 minutes`, achieving `34.7dB` in PSNR and `0.94` in SSIM.
|
137
|
-
|
138
|
-
> 🌟 We encourage users to develop their models using our framework; a tutorial will be released soon.
|
139
|
-
|
140
|
-
### 4. Others
|
141
|
-
We provide a faster `load_vidar_dat` function implemented with `cpp` (by [@zeal-ye](https://github.com/zeal-ye)):
|
142
|
-
``` python
|
143
|
-
import spikezoo as sz
|
144
|
-
spike = sz.load_vidar_dat("data/scissor.dat",width = 400,height = 250,version='cpp')
|
145
|
-
```
|
146
|
-
🚀 Results on [examples/test_load_dat.py](examples/test_load_dat.py) show that the `cpp` version is more than 10 times faster than the `python` version.
|
147
|
-
|
148
|
-
## 📅 TODO
|
149
|
-
- [ ] Provide the tutorials.
|
150
|
-
- [ ] Support more training settings.
|
151
|
-
- [ ] Support more spike-based image reconstruction methods and datasets.
|
152
|
-
- [ ] Support the overall pipeline for spike simulation.
|
153
|
-
|
154
|
-
## 🤗 Supports
|
155
|
-
Run the following code to find our supported models, datasets and metrics:
|
156
|
-
``` python
|
157
|
-
import spikezoo as sz
|
158
|
-
print(sz.get_models())
|
159
|
-
print(sz.get_datasets())
|
160
|
-
print(sz.get_metrics())
|
161
|
-
```
|
162
|
-
**Supported Models:**
|
163
|
-
| Models | Source |
|
164
|
-
| ---- | ---- |
|
165
|
-
| `tfp`,`tfi` | Spike camera and its coding methods |
|
166
|
-
| `spk2imgnet` | Spk2ImgNet: Learning to Reconstruct Dynamic Scene from Continuous Spike Stream |
|
167
|
-
| `wgse` | Learning Temporal-Ordered Representation for Spike Streams Based on Discrete Wavelet Transforms |
|
168
|
-
| `ssml` | Self-Supervised Mutual Learning for Dynamic Scene Reconstruction of Spiking Camera |
|
169
|
-
| `spikeformer` | SpikeFormer: Image Reconstruction from the Sequence of Spike Camera Based on Transformer |
|
170
|
-
| `ssir` | Spike Camera Image Reconstruction Using Deep Spiking Neural Networks |
|
171
|
-
| `bsf` | Boosting Spike Camera Image Reconstruction from a Perspective of Dealing with Spike Fluctuations |
|
172
|
-
| `stir` | Spatio-Temporal Interactive Learning for Efficient Image Reconstruction of Spiking Cameras |
|
173
|
-
| `spikeclip` | Rethinking High-speed Image Reconstruction Framework with Spike Camera |
|
174
|
-
|
175
|
-
**Supported Datasets:**
|
176
|
-
| Datasets | Source |
|
177
|
-
| ---- | ---- |
|
178
|
-
| `reds_small` | Spk2ImgNet: Learning to Reconstruct Dynamic Scene from Continuous Spike Stream |
|
179
|
-
| `uhsr` | Recognizing Ultra-High-Speed Moving Objects with Bio-Inspired Spike Camera |
|
180
|
-
| `realworld` | `recVidarReal2019`,`momVidarReal2021` in [SpikeCV](https://github.com/Zyj061/SpikeCV) |
|
181
|
-
| `szdata` | SpikeReveal: Unlocking Temporal Sequences from Real Blurry Inputs with Spike Streams |
|
182
|
-
|
183
|
-
|
184
|
-
## ✨ Acknowledgment
|
185
|
-
Our code is built on the open-source projects of [SpikeCV](https://spikecv.github.io/), [IQA-Pytorch](https://github.com/chaofengc/IQA-PyTorch), [BasicSR](https://github.com/XPixelGroup/BasicSR) and [NeRFStudio](https://github.com/nerfstudio-project/nerfstudio). We appreciate the effort of the contributors to these repositories. Thanks to [@ruizhao26](https://github.com/ruizhao26) and [@Leozhangjiyuan](https://github.com/Leozhangjiyuan) for their help in building this project.
|
186
|
-
|
187
|
-
## 📑 Citation
|
188
|
-
If you find our code helpful for your research, please consider citing it as follows:
|
189
|
-
```
|
190
|
-
@misc{spikezoo,
|
191
|
-
title={{Spike-Zoo}: A Toolbox for Spike-to-Image Reconstruction},
|
192
|
-
author={Kang Chen and Zhiyuan Ye},
|
193
|
-
year={2025},
|
194
|
-
howpublished = "[Online]. Available: \url{https://github.com/chenkang455/Spike-Zoo}"
|
195
|
-
}
|
196
|
-
```
|
File without changes
|
File without changes
|
File without changes
|