spikezoo 0.1.2__py3-none-any.whl → 0.2.1__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- spikezoo/__init__.py +13 -0
- spikezoo/archs/__pycache__/__init__.cpython-39.pyc +0 -0
- spikezoo/archs/base/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/base/nets.py +34 -0
- spikezoo/archs/bsf/README.md +92 -0
- spikezoo/archs/bsf/datasets/datasets.py +328 -0
- spikezoo/archs/bsf/datasets/ds_utils.py +64 -0
- spikezoo/archs/bsf/main.py +398 -0
- spikezoo/archs/bsf/metrics/psnr.py +22 -0
- spikezoo/archs/bsf/metrics/ssim.py +54 -0
- spikezoo/archs/bsf/models/bsf/__pycache__/align.cpython-39.pyc +0 -0
- spikezoo/archs/bsf/models/bsf/__pycache__/bsf.cpython-39.pyc +0 -0
- spikezoo/archs/bsf/models/bsf/__pycache__/rep.cpython-39.pyc +0 -0
- spikezoo/archs/bsf/models/bsf/align.py +154 -0
- spikezoo/archs/bsf/models/bsf/bsf.py +105 -0
- spikezoo/archs/bsf/models/bsf/dsft_convert.py +96 -0
- spikezoo/archs/bsf/models/bsf/rep.py +44 -0
- spikezoo/archs/bsf/models/get_model.py +7 -0
- spikezoo/archs/bsf/prepare_data/DSFT.py +62 -0
- spikezoo/archs/bsf/prepare_data/crop_dataset_train.py +135 -0
- spikezoo/archs/bsf/prepare_data/crop_dataset_val.py +139 -0
- spikezoo/archs/bsf/prepare_data/crop_train.sh +4 -0
- spikezoo/archs/bsf/prepare_data/crop_val.sh +4 -0
- spikezoo/archs/bsf/prepare_data/io_utils.py +64 -0
- spikezoo/archs/bsf/requirements.txt +9 -0
- spikezoo/archs/bsf/test.py +16 -0
- spikezoo/archs/bsf/utils.py +154 -0
- spikezoo/archs/spikeclip/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/spikeclip/nets.py +40 -0
- spikezoo/archs/spikeformer/CheckPoints/readme +1 -0
- spikezoo/archs/spikeformer/DataProcess/DataExtactor.py +60 -0
- spikezoo/archs/spikeformer/DataProcess/DataLoader.py +115 -0
- spikezoo/archs/spikeformer/DataProcess/LoadSpike.py +39 -0
- spikezoo/archs/spikeformer/EvalResults/readme +1 -0
- spikezoo/archs/spikeformer/LICENSE +21 -0
- spikezoo/archs/spikeformer/Metrics/Metrics.py +50 -0
- spikezoo/archs/spikeformer/Metrics/__init__.py +0 -0
- spikezoo/archs/spikeformer/Model/Loss.py +89 -0
- spikezoo/archs/spikeformer/Model/SpikeFormer.py +230 -0
- spikezoo/archs/spikeformer/Model/__init__.py +0 -0
- spikezoo/archs/spikeformer/Model/__pycache__/SpikeFormer.cpython-39.pyc +0 -0
- spikezoo/archs/spikeformer/Model/__pycache__/__init__.cpython-39.pyc +0 -0
- spikezoo/archs/spikeformer/README.md +30 -0
- spikezoo/archs/spikeformer/evaluate.py +87 -0
- spikezoo/archs/spikeformer/recon_real_data.py +97 -0
- spikezoo/archs/spikeformer/requirements.yml +95 -0
- spikezoo/archs/spikeformer/train.py +173 -0
- spikezoo/archs/spikeformer/utils.py +22 -0
- spikezoo/archs/spk2imgnet/.github/workflows/pylint.yml +23 -0
- spikezoo/archs/spk2imgnet/.gitignore +150 -0
- spikezoo/archs/spk2imgnet/DCNv2.py +135 -0
- spikezoo/archs/spk2imgnet/__pycache__/DCNv2.cpython-39.pyc +0 -0
- spikezoo/archs/spk2imgnet/__pycache__/align_arch.cpython-39.pyc +0 -0
- spikezoo/archs/spk2imgnet/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/spk2imgnet/align_arch.py +159 -0
- spikezoo/archs/spk2imgnet/dataset.py +144 -0
- spikezoo/archs/spk2imgnet/nets.py +230 -0
- spikezoo/archs/spk2imgnet/readme.md +86 -0
- spikezoo/archs/spk2imgnet/test_gen_imgseq.py +118 -0
- spikezoo/archs/spk2imgnet/train.py +189 -0
- spikezoo/archs/spk2imgnet/utils.py +64 -0
- spikezoo/archs/ssir/README.md +87 -0
- spikezoo/archs/ssir/configs/SSIR.yml +37 -0
- spikezoo/archs/ssir/configs/yml_parser.py +78 -0
- spikezoo/archs/ssir/datasets/dataset_sreds.py +170 -0
- spikezoo/archs/ssir/datasets/ds_utils.py +66 -0
- spikezoo/archs/ssir/losses.py +21 -0
- spikezoo/archs/ssir/main.py +326 -0
- spikezoo/archs/ssir/metrics/psnr.py +22 -0
- spikezoo/archs/ssir/metrics/ssim.py +54 -0
- spikezoo/archs/ssir/models/Vgg19.py +42 -0
- spikezoo/archs/ssir/models/__pycache__/layers.cpython-39.pyc +0 -0
- spikezoo/archs/ssir/models/__pycache__/networks.cpython-39.pyc +0 -0
- spikezoo/archs/ssir/models/layers.py +110 -0
- spikezoo/archs/ssir/models/networks.py +61 -0
- spikezoo/archs/ssir/requirements.txt +8 -0
- spikezoo/archs/ssir/shells/eval_SREDS.sh +6 -0
- spikezoo/archs/ssir/shells/train_SSIR.sh +12 -0
- spikezoo/archs/ssir/test.py +3 -0
- spikezoo/archs/ssir/utils.py +154 -0
- spikezoo/archs/ssml/__pycache__/cbam.cpython-39.pyc +0 -0
- spikezoo/archs/ssml/__pycache__/model.cpython-39.pyc +0 -0
- spikezoo/archs/ssml/cbam.py +224 -0
- spikezoo/archs/ssml/model.py +290 -0
- spikezoo/archs/ssml/res.png +0 -0
- spikezoo/archs/ssml/test.py +67 -0
- spikezoo/archs/stir/.git-credentials +0 -0
- spikezoo/archs/stir/README.md +65 -0
- spikezoo/archs/stir/ckpt_outputs/Descriptions.txt +1 -0
- spikezoo/archs/stir/configs/STIR.yml +37 -0
- spikezoo/archs/stir/configs/utils.py +155 -0
- spikezoo/archs/stir/configs/yml_parser.py +78 -0
- spikezoo/archs/stir/datasets/dataset_sreds.py +180 -0
- spikezoo/archs/stir/datasets/ds_utils.py +66 -0
- spikezoo/archs/stir/eval_SREDS.sh +5 -0
- spikezoo/archs/stir/main.py +397 -0
- spikezoo/archs/stir/metrics/losses.py +219 -0
- spikezoo/archs/stir/metrics/psnr.py +22 -0
- spikezoo/archs/stir/metrics/ssim.py +54 -0
- spikezoo/archs/stir/models/Vgg19.py +42 -0
- spikezoo/archs/stir/models/__pycache__/networks_STIR.cpython-39.pyc +0 -0
- spikezoo/archs/stir/models/__pycache__/submodules.cpython-39.pyc +0 -0
- spikezoo/archs/stir/models/__pycache__/transformer_new.cpython-39.pyc +0 -0
- spikezoo/archs/stir/models/networks_STIR.py +361 -0
- spikezoo/archs/stir/models/submodules.py +86 -0
- spikezoo/archs/stir/models/transformer_new.py +151 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/__init__.py +0 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/convertions.py +721 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/disp_netS.py +133 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/flow_utils.py +167 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/generic_train_test.py +76 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/geometry.py +458 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/image_proc.py +183 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/linalg.py +40 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/losses.py +198 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/metrics.py +51 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/model_base.py +53 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/net_basics.py +100 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/resnet.py +333 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/transforms.py +123 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/utils.py +72 -0
- spikezoo/archs/stir/package_core/dist/package_core-0.0.0-py3.9.egg +0 -0
- spikezoo/archs/stir/package_core/package_core/__init__.py +0 -0
- spikezoo/archs/stir/package_core/package_core/__pycache__/__init__.cpython-39.pyc +0 -0
- spikezoo/archs/stir/package_core/package_core/__pycache__/net_basics.cpython-39.pyc +0 -0
- spikezoo/archs/stir/package_core/package_core/convertions.py +721 -0
- spikezoo/archs/stir/package_core/package_core/disp_netS.py +133 -0
- spikezoo/archs/stir/package_core/package_core/flow_utils.py +167 -0
- spikezoo/archs/stir/package_core/package_core/generic_train_test.py +76 -0
- spikezoo/archs/stir/package_core/package_core/geometry.py +458 -0
- spikezoo/archs/stir/package_core/package_core/image_proc.py +183 -0
- spikezoo/archs/stir/package_core/package_core/linalg.py +40 -0
- spikezoo/archs/stir/package_core/package_core/losses.py +198 -0
- spikezoo/archs/stir/package_core/package_core/metrics.py +51 -0
- spikezoo/archs/stir/package_core/package_core/model_base.py +53 -0
- spikezoo/archs/stir/package_core/package_core/net_basics.py +100 -0
- spikezoo/archs/stir/package_core/package_core/resnet.py +333 -0
- spikezoo/archs/stir/package_core/package_core/transforms.py +123 -0
- spikezoo/archs/stir/package_core/package_core/utils.py +72 -0
- spikezoo/archs/stir/package_core/package_core.egg-info/PKG-INFO +3 -0
- spikezoo/archs/stir/package_core/package_core.egg-info/SOURCES.txt +20 -0
- spikezoo/archs/stir/package_core/package_core.egg-info/dependency_links.txt +1 -0
- spikezoo/archs/stir/package_core/package_core.egg-info/top_level.txt +1 -0
- spikezoo/archs/stir/package_core/setup.py +5 -0
- spikezoo/archs/stir/requirements.txt +12 -0
- spikezoo/archs/stir/train_STIR.sh +9 -0
- spikezoo/archs/tfi/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/tfi/nets.py +43 -0
- spikezoo/archs/tfp/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/tfp/nets.py +13 -0
- spikezoo/archs/wgse/README.md +64 -0
- spikezoo/archs/wgse/__pycache__/dwtnets.cpython-39.pyc +0 -0
- spikezoo/archs/wgse/__pycache__/submodules.cpython-39.pyc +0 -0
- spikezoo/archs/wgse/dataset.py +59 -0
- spikezoo/archs/wgse/demo.png +0 -0
- spikezoo/archs/wgse/demo.py +83 -0
- spikezoo/archs/wgse/dwtnets.py +145 -0
- spikezoo/archs/wgse/eval.py +133 -0
- spikezoo/archs/wgse/logs/WGSE-Dwt1dNet-db8-5-ks3/log.txt +11 -0
- spikezoo/archs/wgse/submodules.py +68 -0
- spikezoo/archs/wgse/train.py +261 -0
- spikezoo/archs/wgse/transform.py +139 -0
- spikezoo/archs/wgse/utils.py +128 -0
- spikezoo/archs/wgse/weights/demo.png +0 -0
- spikezoo/data/base/test/gt/200_part1_key_id151.png +0 -0
- spikezoo/data/base/test/gt/200_part3_key_id151.png +0 -0
- spikezoo/data/base/test/gt/203_part1_key_id151.png +0 -0
- spikezoo/data/base/test/spike/200_part1_key_id151.dat +0 -0
- spikezoo/data/base/test/spike/200_part3_key_id151.dat +0 -0
- spikezoo/data/base/test/spike/203_part1_key_id151.dat +0 -0
- spikezoo/data/base/train/gt/203_part2_key_id151.png +0 -0
- spikezoo/data/base/train/gt/203_part3_key_id151.png +0 -0
- spikezoo/data/base/train/gt/203_part4_key_id151.png +0 -0
- spikezoo/data/base/train/spike/203_part2_key_id151.dat +0 -0
- spikezoo/data/base/train/spike/203_part3_key_id151.dat +0 -0
- spikezoo/data/base/train/spike/203_part4_key_id151.dat +0 -0
- spikezoo/datasets/base_dataset.py +2 -3
- spikezoo/metrics/__init__.py +1 -1
- spikezoo/models/base_model.py +1 -3
- spikezoo/pipeline/base_pipeline.py +7 -5
- spikezoo/pipeline/train_pipeline.py +1 -1
- spikezoo/utils/other_utils.py +16 -6
- spikezoo/utils/spike_utils.py +33 -29
- spikezoo/utils/vidar_loader.cpython-39-x86_64-linux-gnu.so +0 -0
- spikezoo-0.2.1.dist-info/METADATA +167 -0
- spikezoo-0.2.1.dist-info/RECORD +211 -0
- spikezoo/models/spcsnet_model.py +0 -19
- spikezoo-0.1.2.dist-info/METADATA +0 -39
- spikezoo-0.1.2.dist-info/RECORD +0 -36
- {spikezoo-0.1.2.dist-info → spikezoo-0.2.1.dist-info}/LICENSE.txt +0 -0
- {spikezoo-0.1.2.dist-info → spikezoo-0.2.1.dist-info}/WHEEL +0 -0
- {spikezoo-0.1.2.dist-info → spikezoo-0.2.1.dist-info}/top_level.txt +0 -0
spikezoo/__init__.py
CHANGED
@@ -0,0 +1,13 @@
|
|
1
|
+
from .utils.spike_utils import load_vidar_dat
|
2
|
+
from .models import model_list
|
3
|
+
from .datasets import dataset_list
|
4
|
+
from .metrics import metric_all_names
|
5
|
+
|
6
|
+
def get_datasets():
|
7
|
+
return dataset_list
|
8
|
+
|
9
|
+
def get_models():
|
10
|
+
return model_list
|
11
|
+
|
12
|
+
def get_metrics():
|
13
|
+
return metric_all_names
|
Binary file
|
Binary file
|
@@ -0,0 +1,34 @@
|
|
1
|
+
import torch.nn as nn
|
2
|
+
|
3
|
+
def conv_layer(inDim, outDim, ks, s, p, norm_layer='none'):
    """Build a ``Conv2d -> (optional norm) -> ReLU(inplace)`` block.

    Args:
        inDim: input channel count.
        outDim: output channel count.
        ks, s, p: kernel size, stride and padding of the convolution.
        norm_layer: one of ``'none'``, ``'instance'``, ``'batch'``.

    Returns:
        ``nn.Sequential`` containing the assembled layers.
    """
    assert norm_layer in ('batch', 'instance', 'none')
    conv = nn.Conv2d(inDim, outDim, kernel_size=ks, stride=s, padding=p)
    relu = nn.ReLU(True)
    if norm_layer == 'none':
        return nn.Sequential(conv, relu)
    if norm_layer == 'instance':
        # Instance norm without affine params or running statistics.
        norm = nn.InstanceNorm2d(outDim, affine=False, track_running_stats=False)
    else:
        norm = nn.BatchNorm2d(outDim, momentum=0.1, affine=True, track_running_stats=True)
    return nn.Sequential(conv, norm, relu)
|
18
|
+
|
19
|
+
class BaseNet(nn.Module):
    """Plain 5-stage convolutional reconstruction head.

    Structure borrowed from SpikeCLIP (https://arxiv.org/abs/2501.04477).
    Maps an ``inDim``-channel spike representation to a 1-channel image;
    no normalisation is used in any stage.
    """

    def __init__(self, inDim=41):
        super(BaseNet, self).__init__()
        stages = [
            conv_layer(inDim, 64, 3, 1, 1),
            conv_layer(64, 128, 3, 1, 1, 'none'),
            conv_layer(128, 64, 3, 1, 1, 'none'),
            conv_layer(64, 16, 3, 1, 1, 'none'),
            nn.Conv2d(16, 1, 3, 1, 1),  # final projection to a single channel
        ]
        self.seq = nn.Sequential(*stages)

    def forward(self, x):
        return self.seq(x)
|
34
|
+
|
@@ -0,0 +1,92 @@
|
|
1
|
+
## [CVPR 2024] Boosting Spike Camera Image Reconstruction from a Perspective of Dealing with Spike Fluctuations
|
2
|
+
|
3
|
+
<h4 align="center"> Rui Zhao<sup>1,2</sup>, Ruiqin Xiong<sup>1,2</sup>, Jing Zhao<sup>1,2</sup>, Jian Zhang<sup>3</sup>, Xiaopeng Fan<sup>4</sup>, Zhaofei Yu<sup>1,2</sup>, Tiejun Huang<sup>1,2</sup> </h4>
|
4
|
+
<h4 align="center">1. School of Computer Science, Peking University<br>
|
5
|
+
2. National Key Laboratory for Multimedia Information Processing, Peking University<br>
|
6
|
+
3. School of Electronic and Computer Engineering, Peking University<br>
|
7
|
+
4. School of Computer Science and Technology, Harbin Institute of Technology
|
8
|
+
</h4><br>
|
9
|
+
|
10
|
+
This repository contains the official source code for our paper:
|
11
|
+
|
12
|
+
Boosting Spike Camera Image Reconstruction from a Perspective of Dealing with Spike Fluctuations
|
13
|
+
|
14
|
+
CVPR 2024
|
15
|
+
|
16
|
+
## Environment
|
17
|
+
|
18
|
+
You can choose cudatoolkit version to match your server. The code is tested on PyTorch 2.0.1+cu120.
|
19
|
+
|
20
|
+
```bash
|
21
|
+
conda create -n bsf python==3.10.9
|
22
|
+
conda activate bsf
|
23
|
+
# You can choose the PyTorch version you like, we recommend version >= 1.10.1
|
24
|
+
# For example
|
25
|
+
pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
|
26
|
+
pip3 install -r requirements.txt
|
27
|
+
```
|
28
|
+
|
29
|
+
## Prepare the Data
|
30
|
+
|
31
|
+
##### 1. Download the dataset (Approximate 50GB)
|
32
|
+
|
33
|
+
[Link of the dataset (BaiduNetDisk)](https://pan.baidu.com/s/1zBp-ed1KtmhAab5Z_62ttw) (Password: 2728)
|
34
|
+
|
35
|
+
##### 2. Deploy the dataset for training faster (Approximate <u>another</u> 125GB)
|
36
|
+
|
37
|
+
firstly modify the data root and output root in `./prepare_data/crop_dataset_train.py` and `./prepare_data/crop_dataset_val.py`
|
38
|
+
|
39
|
+
```shell
|
40
|
+
cd prepare_data &&
|
41
|
+
bash crop_train.sh $your_gpu_id &&
|
42
|
+
bash crop_val.sh $your_gpu_id
|
43
|
+
```
|
44
|
+
|
45
|
+
## Evaluate
|
46
|
+
|
47
|
+
```shell
|
48
|
+
CUDA_VISIBLE_DEVICES=$1 python3 -W ignore main.py \
|
49
|
+
--alpha 0.7 \
|
50
|
+
--vis-path vis/bsf \
|
51
|
+
-evp eval_vis/bsf \
|
52
|
+
--logs_file_name bsf \
|
53
|
+
--compile_model \
|
54
|
+
--test_eval \
|
55
|
+
--arch bsf \
|
56
|
+
--pretrained ckpt/bsf.pth
|
57
|
+
```
|
58
|
+
|
59
|
+
## Train
|
60
|
+
|
61
|
+
```shell
|
62
|
+
CUDA_VISIBLE_DEVICES=$1 python3 -W ignore main.py \
|
63
|
+
-bs 8 \
|
64
|
+
-j 8 \
|
65
|
+
-lr 1e-4 \
|
66
|
+
--epochs 61 \
|
67
|
+
--train-res 96 96 \
|
68
|
+
--lr-scale-factor 0.5 \
|
69
|
+
--milestones 10 20 30 40 50 60 70 80 90 100 \
|
70
|
+
--alpha 0.7 \
|
71
|
+
--vis-path vis/bsf \
|
72
|
+
-evp eval_vis/bsf \
|
73
|
+
--logs_file_name bsf \
|
74
|
+
--compile_model \
|
75
|
+
--weight_decay 0.0 \
|
76
|
+
--eval-interval 10 \
|
77
|
+
--half_reserve 0 \
|
78
|
+
--arch bsf
|
79
|
+
```
|
80
|
+
|
81
|
+
## Citations
|
82
|
+
|
83
|
+
If you find this code useful in your research, please consider citing our paper:
|
84
|
+
|
85
|
+
```
|
86
|
+
@inproceedings{zhao2024boosting,
|
87
|
+
title={Boosting Spike Camera Image Reconstruction from a Perspective of Dealing with Spike Fluctuations},
|
88
|
+
author={Zhao, Rui and Xiong, Ruiqin and Zhao, Jing and Zhang, Jian and Fan, Xiaopeng and Yu, Zhaofei and Huang, Tiejun},
|
89
|
+
booktitle={IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
|
90
|
+
year={2024}
|
91
|
+
}
|
92
|
+
```
|
@@ -0,0 +1,328 @@
|
|
1
|
+
import os
|
2
|
+
import os.path as osp
|
3
|
+
import random
|
4
|
+
import numpy as np
|
5
|
+
import torch
|
6
|
+
import torch.utils.data as data
|
7
|
+
from datasets.ds_utils import *
|
8
|
+
import h5py
|
9
|
+
from tqdm import *
|
10
|
+
|
11
|
+
|
12
|
+
class Augmentor:
    """Joint spatial augmentation (random crop + flip/rotation), applied
    identically to a list of spike tensors and a list of image tensors."""

    def __init__(self, crop_size):
        # (crop_h, crop_w); the rotation modes require crop_h == crop_w.
        self.crop_size = crop_size

    def augment_img(self, img, mode=0):
        '''Kai Zhang (github: https://github.com/cszn)
        W x H x C or W x H
        Note: this augmentation requires crop_h == crop_w.
        '''
        if mode == 0:
            return img
        if mode == 1:
            return np.flipud(np.rot90(img))
        if mode == 2:
            return np.flipud(img)
        if mode == 3:
            return np.rot90(img, k=3)
        if mode == 4:
            return np.flipud(np.rot90(img, k=2))
        if mode == 5:
            return np.rot90(img)
        if mode == 6:
            return np.rot90(img, k=2)
        if mode == 7:
            return np.flipud(np.rot90(img, k=3))

    def spatial_transform(self, spk_list, img_list):
        # One flip/rot mode and one crop window, shared by every tensor.
        mode = random.randint(0, 7)
        spike_h = spk_list[0].shape[1]
        spike_w = spk_list[0].shape[2]
        crop_h, crop_w = self.crop_size[0], self.crop_size[1]

        y0 = np.random.randint(0, spike_h - crop_h) if spike_h > crop_h else 0
        x0 = np.random.randint(0, spike_w - crop_w) if spike_w > crop_w else 0

        def _crop_aug(arr):
            # channel-last for cropping/augmenting, back to channel-first after
            arr = np.transpose(arr, [1, 2, 0])
            arr = arr[y0:y0 + crop_h, x0:x0 + crop_w, :]
            arr = self.augment_img(arr, mode=mode)
            return np.transpose(arr, [2, 0, 1])

        # Mutate the lists in place, matching the original's aliasing.
        for idx in range(len(spk_list)):
            spk_list[idx] = _crop_aug(spk_list[idx])
        for idx in range(len(img_list)):
            img_list[idx] = _crop_aug(img_list[idx])

        return spk_list, img_list

    def __call__(self, spk_list, img_list):
        spk_list, img_list = self.spatial_transform(spk_list, img_list)
        # rot90/flipud return views; make them contiguous for downstream use.
        spk_list = [np.ascontiguousarray(s) for s in spk_list]
        img_list = [np.ascontiguousarray(im) for im in img_list]
        return spk_list, img_list
|
73
|
+
|
74
|
+
|
75
|
+
|
76
|
+
class sreds_train(torch.utils.data.Dataset):
    '''
    Training split of the cropped SREDS spike dataset; crops are loaded
    at 256 x 256 (the original note mentioned "test set 148 x 256",
    apparently copied from the test class).
    '''
    def __init__(self, args):
        self.args = args
        self.input_type = args.input_type
        self.eta_list = args.eta_list    # simulated light-intensity factors to include
        self.gamma = args.gamma
        self.alpha = args.alpha
        self.augmentor = Augmentor(crop_size=args.train_res)

        self.dsft_path_name = 'dsft'
        self.spike_path_name = 'spikes'

        # DSFT tensors are optional auxiliary inputs.
        self.read_dsft = not args.no_dsft

        self.samples = self.collect_samples()
        print('The samples num of training data: {:d}'.format(len(self.samples)))

    def confirm_exist(self, path_list_list):
        # 1 iff every path in every sub-list exists on disk.
        for pl in path_list_list:
            for p in pl:
                if not osp.exists(p):
                    return 0
        return 1

    def collect_samples(self):
        """Scan crop/train/eta_*/scene/crop dirs; one sample dict per crop."""
        samples = []
        root_path = osp.join(self.args.data_root, 'crop', 'train')

        for eta in self.eta_list:
            cur_eta_dir = osp.join(root_path, "eta_{:.2f}_gamma_{:d}_alpha_{:.1f}".format(eta, self.gamma, self.alpha))
            scene_list = sorted(os.listdir(cur_eta_dir))

            for scene in scene_list:
                scene_path = osp.join(cur_eta_dir, scene)
                crop_list = sorted(os.listdir(scene_path))
                for crop in crop_list:
                    crop_path = osp.join(scene_path, crop)
                    spike_dir = osp.join(crop_path, self.spike_path_name)
                    image_dir = osp.join(root_path, 'imgs', scene, crop)
                    dsft_dir = osp.join(crop_path, 'dsft')

                    ## File layout: spike/dsft files numbered 11..28, images 18..21.
                    spikes_path_list = [osp.join(spike_dir, '{:08d}.dat'.format(ii)) for ii in range(11, 28 + 1)]
                    dsft_path_list = [osp.join(dsft_dir, '{:08d}.h5'.format(ii)) for ii in range(11, 28 + 1)]
                    images00_path_list = [osp.join(image_dir, '{:08d}.png'.format(ii)) for ii in range(18, 21 + 1)]

                    if self.confirm_exist([spikes_path_list, images00_path_list]):
                        s = {}
                        s['spikes_paths'] = spikes_path_list
                        s['dsft_paths'] = dsft_path_list
                        s['images_paths'] = images00_path_list
                        s['norm_fac'] = eta * self.alpha
                        samples.append(s)
        return samples

    def _load_sample(self, s):
        """Load one training sample from the sample dict ``s``.

        Each sample offers 4 candidate key frames (images 18..21, which sit
        at indices 7..10 of the spike/dsft path lists); one is drawn at
        random per call.

        BUGFIX: the original assigned the sliced path lists back into ``s``
        — an entry of ``self.samples`` — so revisiting an index (e.g. a
        second epoch without worker re-forking) sliced already-sliced lists
        and produced wrong or empty windows. Slices now go into locals and
        the stored sample is never mutated.
        """
        key_frame_offset = random.choice([0, 1, 2, 3])
        lo = 7 + key_frame_offset - 3 - self.args.half_reserve
        hi = 7 + key_frame_offset + 3 + self.args.half_reserve + 1
        spikes_paths = s['spikes_paths'][lo:hi]
        dsft_paths = s['dsft_paths'][lo:hi]
        images_paths = [s['images_paths'][key_frame_offset]]

        data = {}
        if self.read_dsft:
            # Load the DSFT tensors for the selected window.
            h5files = [h5py.File(p, 'r') for p in dsft_paths]
            data['dsft'] = [np.array(f['dsft']).astype(np.float32) for f in h5files]
            for f in h5files:
                f.close()
        data['spikes'] = [dat_to_spmat(p, size=(256, 256)).astype(np.float32) for p in spikes_paths]

        data['images'] = [read_img_gray(p) for p in images_paths]
        data['norm_fac'] = np.array(s['norm_fac'])

        if self.read_dsft:
            # Augment spikes and dsft jointly (same crop/flip), then split back.
            data['spikes'] = data['spikes'] + data['dsft']
            data['spikes'], data['images'] = self.augmentor(data['spikes'], data['images'])
            half = len(data['spikes']) // 2
            data['spikes'], data['dsft'] = data['spikes'][:half], data['spikes'][half:]
        else:
            data['spikes'], data['images'] = self.augmentor(data['spikes'], data['images'])

        return data

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        data = self._load_sample(self.samples[index])
        return data
|
175
|
+
|
176
|
+
|
177
|
+
class sreds_test(torch.utils.data.Dataset):
    '''
    Test set; spike native resolution 540 x 960.
    '''
    def __init__(self, args, eta):
        self.args = args
        self.input_type = args.input_type
        self.alpha = args.alpha
        self.eta = eta    # one light-intensity factor per dataset instance
        self.gamma = args.gamma
        self.dsft_path_name = 'dsft'
        self.spike_path_name = 'spikes'
        self.samples = self.collect_samples()
        print('The samples num of testing data: {:d}'.format(len(self.samples)))

    def confirm_exist(self, path_list_list):
        # 1 iff every path in every sub-list exists; prints the first
        # missing path to aid debugging.
        for pl in path_list_list:
            for p in pl:
                if not osp.exists(p):
                    print(p)
                    return 0
        return 1

    def collect_samples(self):
        # Scan crop/val/eta_*/scene dirs and build four evaluation
        # samples (one per key frame) per scene.
        root_path = osp.join(self.args.data_root, 'crop', 'val')

        cur_eta_dir = osp.join(root_path, "eta_{:.2f}_gamma_{:d}_alpha_{:.1f}".format(self.eta, self.gamma, self.alpha))
        scene_list = sorted(os.listdir(cur_eta_dir))
        samples = []

        for scene in scene_list:
            scene_path = osp.join(cur_eta_dir, scene)
            spike_dir = osp.join(scene_path, self.spike_path_name)
            image_dir = osp.join(root_path, 'imgs', scene)
            dsft_dir = osp.join(scene_path, 'dsft')

            ## File layout: spike/dsft files numbered 11..28, images 18..21.
            spikes_path_list = [osp.join(spike_dir, '{:08d}.dat'.format(ii)) for ii in range(11, 28+1)]
            dsft_path_list = [osp.join(dsft_dir, '{:08d}.h5'.format(ii)) for ii in range(11, 28+1)]
            images_path_list = [osp.join(image_dir, '{:08d}.png'.format(ii)) for ii in range(18, 21+1)]

            if(self.confirm_exist([spikes_path_list, images_path_list])):
                ## Four evaluation groups per scene: images 18..21 are the
                ## four key frames, sitting at indices {7,8,9,10} of the
                ## spike/dsft path lists; each group takes a symmetric
                ## window of +-(3 + half_reserve) around its key index.
                for ii in range(4):
                    s = {}
                    s['spikes_paths'] = spikes_path_list[7+ii-3-self.args.half_reserve : 7+ii+3+self.args.half_reserve+1]
                    s['dsft_paths'] = dsft_path_list[7+ii-3-self.args.half_reserve : 7+ii+3+self.args.half_reserve+1]
                    s['images_paths'] = [images_path_list[ii]]
                    s['norm_fac'] = self.alpha * self.eta
                    samples.append(s)

        return samples

    def _load_sample(self, s):
        # Load DSFT tensors, decode spike .dat files at native 540 x 960,
        # and read the single ground-truth key-frame image.
        data = {}
        h5files = [h5py.File(p, 'r') for p in s['dsft_paths']]
        data['dsft'] = [np.array(f['dsft']).astype(np.float32) for f in h5files]
        for f in h5files:
            f.close()
        data['spikes'] = [dat_to_spmat(p, size=(540, 960)).astype(np.float32) for p in s['spikes_paths']]

        data['images'] = [read_img_gray(p) for p in s['images_paths']]
        data['norm_fac'] = np.array(s['norm_fac'])
        return data

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        data = self._load_sample(self.samples[index])
        return data
|
251
|
+
|
252
|
+
class sreds_test_small(torch.utils.data.Dataset):
    '''
    Test set (small variant); spike native resolution 384 x 512.
    '''
    def __init__(self, args, eta):
        self.args = args
        self.input_type = args.input_type
        self.alpha = args.alpha
        self.eta = eta    # one light-intensity factor per dataset instance
        self.gamma = args.gamma
        self.dsft_path_name = 'dsft'
        self.spike_path_name = 'spikes'

        # Unlike sreds_test, DSFT loading is optional here.
        self.read_dsft = not args.no_dsft
        self.samples = self.collect_samples()
        print('The samples num of testing data: {:d}'.format(len(self.samples)))

    def confirm_exist(self, path_list_list):
        # 1 iff every path in every sub-list exists; prints the first
        # missing path to aid debugging.
        for pl in path_list_list:
            for p in pl:
                if not osp.exists(p):
                    print(p)
                    return 0
        return 1

    def collect_samples(self):
        # Scan crop/val_small/eta_*/scene dirs and build four evaluation
        # samples (one per key frame) per scene.
        root_path = osp.join(self.args.data_root, 'crop', 'val_small')

        cur_eta_dir = osp.join(root_path, "eta_{:.2f}_gamma_{:d}_alpha_{:.1f}".format(self.eta, self.gamma, self.alpha))
        scene_list = sorted(os.listdir(cur_eta_dir))
        samples = []

        for scene in scene_list:
            scene_path = osp.join(cur_eta_dir, scene)
            spike_dir = osp.join(scene_path, self.spike_path_name)
            image_dir = osp.join(root_path, 'imgs', scene)
            dsft_dir = osp.join(scene_path, 'dsft')

            ## File layout: spike/dsft files numbered 11..28, images 18..21.
            spikes_path_list = [osp.join(spike_dir, '{:08d}.dat'.format(ii)) for ii in range(11, 28+1)]
            dsft_path_list = [osp.join(dsft_dir, '{:08d}.h5'.format(ii)) for ii in range(11, 28+1)]
            images_path_list = [osp.join(image_dir, '{:08d}.png'.format(ii)) for ii in range(18, 21+1)]

            if(self.confirm_exist([spikes_path_list, images_path_list])):
                ## Four groups per scene, keyed on images 18..21 (indices
                ## {7,8,9,10} of the spike/dsft lists), each with a
                ## symmetric window of +-(3 + half_reserve) frames.
                for ii in range(4):
                    s = {}
                    s['spikes_paths'] = spikes_path_list[7+ii-3-self.args.half_reserve : 7+ii+3+self.args.half_reserve+1]
                    s['dsft_paths'] = dsft_path_list[7+ii-3-self.args.half_reserve : 7+ii+3+self.args.half_reserve+1]
                    s['images_paths'] = [images_path_list[ii]]
                    s['norm_fac'] = self.alpha * self.eta
                    samples.append(s)

        return samples

    def _load_sample(self, s):
        # Optionally load DSFT tensors, decode spike .dat files at the
        # small native 384 x 512 resolution, and read the single
        # ground-truth key-frame image for this group.
        data = {}
        if self.read_dsft:
            h5files = [h5py.File(p, 'r') for p in s['dsft_paths']]
            data['dsft'] = [np.array(f['dsft']).astype(np.float32) for f in h5files]
            for f in h5files:
                f.close()
        data['spikes'] = [dat_to_spmat(p, size=(384, 512)).astype(np.float32) for p in s['spikes_paths']]

        data['images'] = [read_img_gray(p) for p in s['images_paths']]
        data['norm_fac'] = np.array(s['norm_fac'])
        return data

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        data = self._load_sample(self.samples[index])
        return data
|
@@ -0,0 +1,64 @@
|
|
1
|
+
import numpy as np
|
2
|
+
import cv2
|
3
|
+
|
4
|
+
def RawToSpike(video_seq, h, w, flipud=True):
    """Decode a packed binary spike stream into a (frames, h, w) 0/1 matrix.

    Each frame occupies h*w/8 bytes; row-major pixel ``p`` lives in bit
    ``p % 8`` of byte ``p // 8``. When ``flipud`` is set, every decoded
    frame is flipped vertically (undoing the camera's inverted imaging).
    """
    video_seq = np.array(video_seq).astype(np.uint8)
    img_size = h * w
    frame_bytes = img_size // 8
    img_num = len(video_seq) // frame_bytes

    # Per-pixel bit mask and source-byte index, computed once.
    flat_ids = np.reshape(np.arange(0, h * w), (h, w))
    bit_mask = np.left_shift(1, np.mod(flat_ids, 8))
    byte_ids = flat_ids // 8

    SpikeMatrix = np.zeros([img_num, h, w], np.uint8)
    for frame in np.arange(img_num):
        chunk = video_seq[frame * frame_bytes:(frame + 1) * frame_bytes]
        bits = np.bitwise_and(chunk[byte_ids], bit_mask)
        decoded = (bits == bit_mask)
        SpikeMatrix[frame, :, :] = np.flipud(decoded) if flipud else decoded

    return SpikeMatrix
|
26
|
+
|
27
|
+
|
28
|
+
def SpikeToRaw(SpikeSeq, save_path):
    """
    Pack a 0/1 spike array into the camera's binary format and append it
    to ``save_path`` (append mode: repeated calls accumulate frames).

    SpikeSeq: Numpy array (sfn x h x w)
    save_path: full saving path (string)
    Rui Zhao
    """
    sfn, h, w = SpikeSeq.shape
    # Bit weights 2^0..2^7: pixel i of each 8-pixel group -> bit i.
    base = np.power(2, np.linspace(0, 7, 8))
    # FIX: 'with' guarantees the handle is closed even if packing raises
    # (the original open/close pair leaked the file on exception).
    with open(save_path, 'ab') as fid:
        for img_id in range(sfn):
            # Flip vertically to mimic the camera's inverted imaging.
            spike = np.flipud(SpikeSeq[img_id, :, :])
            # numpy flattens row-major, matching the on-disk pixel order.
            spike = spike.flatten()
            spike = spike.reshape([int(h * w / 8), 8])
            data = spike * base
            data = np.sum(data, axis=1).astype(np.uint8)
            fid.write(data.tobytes())

    return
|
50
|
+
|
51
|
+
|
52
|
+
def dat_to_spmat(dat_path, size=[720, 1280]):
    """Read a packed .dat spike file and decode it to a (frames, h, w) matrix.

    Args:
        dat_path: path to the packed binary spike file.
        size: (height, width) of the spike frames.

    Returns:
        uint8 array of shape (frames, size[0], size[1]) with 0/1 entries.
    """
    # FIX: 'with' ensures the handle is closed (the original never closed it).
    with open(dat_path, 'rb') as f:
        video_seq = f.read()
    video_seq = np.frombuffer(video_seq, 'b')
    sp_mat = RawToSpike(video_seq, size[0], size[1])
    return sp_mat
|
58
|
+
|
59
|
+
|
60
|
+
def read_img_gray(file_path):
    """Load an image as a float32 grayscale array of shape (1, H, W) in [0, 1].

    Raises:
        FileNotFoundError: if the image cannot be read. (cv2.imread returns
            None on failure; the original then crashed with an opaque
            AttributeError on ``.astype``.)
    """
    im = cv2.imread(file_path)
    if im is None:
        raise FileNotFoundError('cannot read image: {}'.format(file_path))
    im = im.astype(np.float32) / 255.0
    im = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
    # Add a leading channel axis: (H, W) -> (1, H, W).
    im = np.expand_dims(im, axis=0)
    return im
|