spikezoo 0.1.1__py3-none-any.whl → 0.2__py3-none-any.whl
- spikezoo/__init__.py +13 -0
- spikezoo/archs/__pycache__/__init__.cpython-39.pyc +0 -0
- spikezoo/archs/base/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/base/nets.py +34 -0
- spikezoo/archs/bsf/README.md +92 -0
- spikezoo/archs/bsf/datasets/datasets.py +328 -0
- spikezoo/archs/bsf/datasets/ds_utils.py +64 -0
- spikezoo/archs/bsf/main.py +398 -0
- spikezoo/archs/bsf/metrics/psnr.py +22 -0
- spikezoo/archs/bsf/metrics/ssim.py +54 -0
- spikezoo/archs/bsf/models/bsf/__pycache__/align.cpython-39.pyc +0 -0
- spikezoo/archs/bsf/models/bsf/__pycache__/bsf.cpython-39.pyc +0 -0
- spikezoo/archs/bsf/models/bsf/__pycache__/rep.cpython-39.pyc +0 -0
- spikezoo/archs/bsf/models/bsf/align.py +154 -0
- spikezoo/archs/bsf/models/bsf/bsf.py +105 -0
- spikezoo/archs/bsf/models/bsf/dsft_convert.py +96 -0
- spikezoo/archs/bsf/models/bsf/rep.py +44 -0
- spikezoo/archs/bsf/models/get_model.py +7 -0
- spikezoo/archs/bsf/prepare_data/DSFT.py +62 -0
- spikezoo/archs/bsf/prepare_data/crop_dataset_train.py +135 -0
- spikezoo/archs/bsf/prepare_data/crop_dataset_val.py +139 -0
- spikezoo/archs/bsf/prepare_data/crop_train.sh +4 -0
- spikezoo/archs/bsf/prepare_data/crop_val.sh +4 -0
- spikezoo/archs/bsf/prepare_data/io_utils.py +64 -0
- spikezoo/archs/bsf/requirements.txt +9 -0
- spikezoo/archs/bsf/test.py +16 -0
- spikezoo/archs/bsf/utils.py +154 -0
- spikezoo/archs/spikeclip/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/spikeclip/nets.py +40 -0
- spikezoo/archs/spikeformer/CheckPoints/readme +1 -0
- spikezoo/archs/spikeformer/DataProcess/DataExtactor.py +60 -0
- spikezoo/archs/spikeformer/DataProcess/DataLoader.py +115 -0
- spikezoo/archs/spikeformer/DataProcess/LoadSpike.py +39 -0
- spikezoo/archs/spikeformer/EvalResults/readme +1 -0
- spikezoo/archs/spikeformer/LICENSE +21 -0
- spikezoo/archs/spikeformer/Metrics/Metrics.py +50 -0
- spikezoo/archs/spikeformer/Metrics/__init__.py +0 -0
- spikezoo/archs/spikeformer/Model/Loss.py +89 -0
- spikezoo/archs/spikeformer/Model/SpikeFormer.py +230 -0
- spikezoo/archs/spikeformer/Model/__init__.py +0 -0
- spikezoo/archs/spikeformer/Model/__pycache__/SpikeFormer.cpython-39.pyc +0 -0
- spikezoo/archs/spikeformer/Model/__pycache__/__init__.cpython-39.pyc +0 -0
- spikezoo/archs/spikeformer/README.md +30 -0
- spikezoo/archs/spikeformer/evaluate.py +87 -0
- spikezoo/archs/spikeformer/recon_real_data.py +97 -0
- spikezoo/archs/spikeformer/requirements.yml +95 -0
- spikezoo/archs/spikeformer/train.py +173 -0
- spikezoo/archs/spikeformer/utils.py +22 -0
- spikezoo/archs/spk2imgnet/.github/workflows/pylint.yml +23 -0
- spikezoo/archs/spk2imgnet/.gitignore +150 -0
- spikezoo/archs/spk2imgnet/DCNv2.py +135 -0
- spikezoo/archs/spk2imgnet/__pycache__/DCNv2.cpython-39.pyc +0 -0
- spikezoo/archs/spk2imgnet/__pycache__/align_arch.cpython-39.pyc +0 -0
- spikezoo/archs/spk2imgnet/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/spk2imgnet/align_arch.py +159 -0
- spikezoo/archs/spk2imgnet/dataset.py +144 -0
- spikezoo/archs/spk2imgnet/nets.py +230 -0
- spikezoo/archs/spk2imgnet/readme.md +86 -0
- spikezoo/archs/spk2imgnet/test_gen_imgseq.py +118 -0
- spikezoo/archs/spk2imgnet/train.py +189 -0
- spikezoo/archs/spk2imgnet/utils.py +64 -0
- spikezoo/archs/ssir/README.md +87 -0
- spikezoo/archs/ssir/configs/SSIR.yml +37 -0
- spikezoo/archs/ssir/configs/yml_parser.py +78 -0
- spikezoo/archs/ssir/datasets/dataset_sreds.py +170 -0
- spikezoo/archs/ssir/datasets/ds_utils.py +66 -0
- spikezoo/archs/ssir/losses.py +21 -0
- spikezoo/archs/ssir/main.py +326 -0
- spikezoo/archs/ssir/metrics/psnr.py +22 -0
- spikezoo/archs/ssir/metrics/ssim.py +54 -0
- spikezoo/archs/ssir/models/Vgg19.py +42 -0
- spikezoo/archs/ssir/models/__pycache__/layers.cpython-39.pyc +0 -0
- spikezoo/archs/ssir/models/__pycache__/networks.cpython-39.pyc +0 -0
- spikezoo/archs/ssir/models/layers.py +110 -0
- spikezoo/archs/ssir/models/networks.py +61 -0
- spikezoo/archs/ssir/requirements.txt +8 -0
- spikezoo/archs/ssir/shells/eval_SREDS.sh +6 -0
- spikezoo/archs/ssir/shells/train_SSIR.sh +12 -0
- spikezoo/archs/ssir/test.py +3 -0
- spikezoo/archs/ssir/utils.py +154 -0
- spikezoo/archs/ssml/__pycache__/cbam.cpython-39.pyc +0 -0
- spikezoo/archs/ssml/__pycache__/model.cpython-39.pyc +0 -0
- spikezoo/archs/ssml/cbam.py +224 -0
- spikezoo/archs/ssml/model.py +290 -0
- spikezoo/archs/ssml/res.png +0 -0
- spikezoo/archs/ssml/test.py +67 -0
- spikezoo/archs/stir/.git-credentials +0 -0
- spikezoo/archs/stir/README.md +65 -0
- spikezoo/archs/stir/ckpt_outputs/Descriptions.txt +1 -0
- spikezoo/archs/stir/configs/STIR.yml +37 -0
- spikezoo/archs/stir/configs/utils.py +155 -0
- spikezoo/archs/stir/configs/yml_parser.py +78 -0
- spikezoo/archs/stir/datasets/dataset_sreds.py +180 -0
- spikezoo/archs/stir/datasets/ds_utils.py +66 -0
- spikezoo/archs/stir/eval_SREDS.sh +5 -0
- spikezoo/archs/stir/main.py +397 -0
- spikezoo/archs/stir/metrics/losses.py +219 -0
- spikezoo/archs/stir/metrics/psnr.py +22 -0
- spikezoo/archs/stir/metrics/ssim.py +54 -0
- spikezoo/archs/stir/models/Vgg19.py +42 -0
- spikezoo/archs/stir/models/__pycache__/networks_STIR.cpython-39.pyc +0 -0
- spikezoo/archs/stir/models/__pycache__/submodules.cpython-39.pyc +0 -0
- spikezoo/archs/stir/models/__pycache__/transformer_new.cpython-39.pyc +0 -0
- spikezoo/archs/stir/models/networks_STIR.py +361 -0
- spikezoo/archs/stir/models/submodules.py +86 -0
- spikezoo/archs/stir/models/transformer_new.py +151 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/__init__.py +0 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/convertions.py +721 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/disp_netS.py +133 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/flow_utils.py +167 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/generic_train_test.py +76 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/geometry.py +458 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/image_proc.py +183 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/linalg.py +40 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/losses.py +198 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/metrics.py +51 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/model_base.py +53 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/net_basics.py +100 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/resnet.py +333 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/transforms.py +123 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/utils.py +72 -0
- spikezoo/archs/stir/package_core/dist/package_core-0.0.0-py3.9.egg +0 -0
- spikezoo/archs/stir/package_core/package_core/__init__.py +0 -0
- spikezoo/archs/stir/package_core/package_core/__pycache__/__init__.cpython-39.pyc +0 -0
- spikezoo/archs/stir/package_core/package_core/__pycache__/net_basics.cpython-39.pyc +0 -0
- spikezoo/archs/stir/package_core/package_core/convertions.py +721 -0
- spikezoo/archs/stir/package_core/package_core/disp_netS.py +133 -0
- spikezoo/archs/stir/package_core/package_core/flow_utils.py +167 -0
- spikezoo/archs/stir/package_core/package_core/generic_train_test.py +76 -0
- spikezoo/archs/stir/package_core/package_core/geometry.py +458 -0
- spikezoo/archs/stir/package_core/package_core/image_proc.py +183 -0
- spikezoo/archs/stir/package_core/package_core/linalg.py +40 -0
- spikezoo/archs/stir/package_core/package_core/losses.py +198 -0
- spikezoo/archs/stir/package_core/package_core/metrics.py +51 -0
- spikezoo/archs/stir/package_core/package_core/model_base.py +53 -0
- spikezoo/archs/stir/package_core/package_core/net_basics.py +100 -0
- spikezoo/archs/stir/package_core/package_core/resnet.py +333 -0
- spikezoo/archs/stir/package_core/package_core/transforms.py +123 -0
- spikezoo/archs/stir/package_core/package_core/utils.py +72 -0
- spikezoo/archs/stir/package_core/package_core.egg-info/PKG-INFO +3 -0
- spikezoo/archs/stir/package_core/package_core.egg-info/SOURCES.txt +20 -0
- spikezoo/archs/stir/package_core/package_core.egg-info/dependency_links.txt +1 -0
- spikezoo/archs/stir/package_core/package_core.egg-info/top_level.txt +1 -0
- spikezoo/archs/stir/package_core/setup.py +5 -0
- spikezoo/archs/stir/requirements.txt +12 -0
- spikezoo/archs/stir/train_STIR.sh +9 -0
- spikezoo/archs/tfi/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/tfi/nets.py +43 -0
- spikezoo/archs/tfp/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/tfp/nets.py +13 -0
- spikezoo/archs/wgse/README.md +64 -0
- spikezoo/archs/wgse/__pycache__/dwtnets.cpython-39.pyc +0 -0
- spikezoo/archs/wgse/__pycache__/submodules.cpython-39.pyc +0 -0
- spikezoo/archs/wgse/dataset.py +59 -0
- spikezoo/archs/wgse/demo.png +0 -0
- spikezoo/archs/wgse/demo.py +83 -0
- spikezoo/archs/wgse/dwtnets.py +145 -0
- spikezoo/archs/wgse/eval.py +133 -0
- spikezoo/archs/wgse/logs/WGSE-Dwt1dNet-db8-5-ks3/log.txt +11 -0
- spikezoo/archs/wgse/submodules.py +68 -0
- spikezoo/archs/wgse/train.py +261 -0
- spikezoo/archs/wgse/transform.py +139 -0
- spikezoo/archs/wgse/utils.py +128 -0
- spikezoo/archs/wgse/weights/demo.png +0 -0
- spikezoo/data/base/test/gt/200_part1_key_id151.png +0 -0
- spikezoo/data/base/test/gt/200_part3_key_id151.png +0 -0
- spikezoo/data/base/test/gt/203_part1_key_id151.png +0 -0
- spikezoo/data/base/test/spike/200_part1_key_id151.dat +0 -0
- spikezoo/data/base/test/spike/200_part3_key_id151.dat +0 -0
- spikezoo/data/base/test/spike/203_part1_key_id151.dat +0 -0
- spikezoo/data/base/train/gt/203_part2_key_id151.png +0 -0
- spikezoo/data/base/train/gt/203_part3_key_id151.png +0 -0
- spikezoo/data/base/train/gt/203_part4_key_id151.png +0 -0
- spikezoo/data/base/train/spike/203_part2_key_id151.dat +0 -0
- spikezoo/data/base/train/spike/203_part3_key_id151.dat +0 -0
- spikezoo/data/base/train/spike/203_part4_key_id151.dat +0 -0
- spikezoo/datasets/base_dataset.py +2 -3
- spikezoo/metrics/__init__.py +1 -1
- spikezoo/models/base_model.py +1 -3
- spikezoo/pipeline/base_pipeline.py +7 -5
- spikezoo/pipeline/train_pipeline.py +1 -1
- spikezoo/utils/other_utils.py +16 -6
- spikezoo/utils/spike_utils.py +33 -29
- spikezoo/utils/vidar_loader.cpython-39-x86_64-linux-gnu.so +0 -0
- spikezoo-0.2.dist-info/METADATA +163 -0
- spikezoo-0.2.dist-info/RECORD +211 -0
- spikezoo/models/spcsnet_model.py +0 -19
- spikezoo-0.1.1.dist-info/METADATA +0 -39
- spikezoo-0.1.1.dist-info/RECORD +0 -36
- {spikezoo-0.1.1.dist-info → spikezoo-0.2.dist-info}/LICENSE.txt +0 -0
- {spikezoo-0.1.1.dist-info → spikezoo-0.2.dist-info}/WHEEL +0 -0
- {spikezoo-0.1.1.dist-info → spikezoo-0.2.dist-info}/top_level.txt +0 -0
spikezoo/archs/ssml/test.py
@@ -0,0 +1,67 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+from model import DoubleNet
+import cv2
+
+
+def load_vidar_dat(filename, left_up=(0, 0), window=None, frame_cnt=None, **kwargs):
+    if isinstance(filename, str):
+        array = np.fromfile(filename, dtype=np.uint8)
+    elif isinstance(filename, (list, tuple)):
+        l = []
+        for name in filename:
+            a = np.fromfile(name, dtype=np.uint8)
+            l.append(a)
+        array = np.concatenate(l)
+    else:
+        raise NotImplementedError
+
+    # Vidar sensor resolution; each byte of the .dat stream packs 8 binary spikes.
+    height = 250
+    width = 400
+
+    if window is None:
+        window = (height - left_up[0], width - left_up[1])
+
+    len_per_frame = height * width // 8
+    framecnt = frame_cnt if frame_cnt is not None else len(array) // len_per_frame
+
+    spikes = []
+
+    for i in range(framecnt):
+        compr_frame = array[i * len_per_frame: (i + 1) * len_per_frame]
+        blist = []
+        for b in range(8):
+            # unpack bit b of every byte in the frame
+            blist.append(np.right_shift(np.bitwise_and(compr_frame, np.left_shift(1, b)), b))
+
+        frame_ = np.stack(blist).transpose()
+        frame_ = np.flipud(frame_.reshape((height, width), order='C'))
+
+        spk = frame_[left_up[0]:left_up[0] + window[0], left_up[1]:left_up[1] + window[1]]
+
+        spk = torch.from_numpy(spk.copy().astype(np.float32)).unsqueeze(dim=0)
+
+        spikes.append(spk)
+
+    return torch.cat(spikes)
+
+
+if __name__ == '__main__':
+    model = DoubleNet()
+    model_path = "./fin3g-best-lucky.pt"
+    model = nn.DataParallel(model)
+    model.load_state_dict(torch.load(model_path))
+    model = model.cuda()
+
+    spike_path = "./rotation1.dat"
+    spike = load_vidar_dat(spike_path)[200:200 + 41].unsqueeze(0).cuda()
+
+    res = model(spike)
+    res = res[0].detach().cpu().permute(1, 2, 0).numpy() * 255
+    res_path = "./res.png"
+    cv2.imwrite(res_path, res)
+
+    print("done.")
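A quick way to sanity-check `load_vidar_dat` above without the network is a TFP-style reconstruction: the temporal mean of the unpacked binary frames approximates scene brightness. A minimal sketch, assuming the script above is importable as `test` and that a 250x400 `rotation1.dat` stream sits in the working directory:

```python
import cv2
from test import load_vidar_dat  # the script above; the import path is an assumption

spikes = load_vidar_dat("./rotation1.dat")  # (T, 250, 400) float tensor of 0/1 spikes
tfp = spikes[200:241].mean(dim=0)           # mean firing rate over the same 41-frame window
cv2.imwrite("./tfp.png", (tfp * 255).clamp(0, 255).byte().numpy())
```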
spikezoo/archs/stir/.git-credentials
File without changes
spikezoo/archs/stir/README.md
@@ -0,0 +1,65 @@
+<!---
+# Spatio-Temporal Interactive Learning for Efficient Image Reconstruction of Spiking Cameras
+
+This repository contains the source code for the paper: [Spatio-Temporal Interactive Learning for Efficient Image Reconstruction of Spiking Cameras (NeurIPS 2024)](https://openreview.net/pdf?id=S4ZqnMywcM).
+The spiking camera is an emerging neuromorphic vision sensor that records high-speed motion scenes by asynchronously firing continuous binary spike streams. Prevailing image reconstruction methods, generating intermediate frames from these spike streams, often rely on complex step-by-step network architectures that overlook the intrinsic collaboration of spatio-temporal complementary information. In this paper, we propose an efficient spatio-temporal interactive reconstruction network to jointly perform inter-frame feature alignment and intra-frame feature filtering in a coarse-to-fine manner. Specifically, it starts by extracting hierarchical features from a concise hybrid spike representation, then refines the motion fields and target frames scale by scale, ultimately obtaining a full-resolution output. Meanwhile, we introduce a symmetric interactive attention block and a multi-motion field estimation block to further enhance the interaction capability of the overall network. Experiments on synthetic and real-captured data show that our approach exhibits excellent performance while maintaining low model complexity.
+
+<img src="picture/performance-speed.png" width="75%"/>
+<img src="picture/overview.png" width="80%"/>
+<img src="picture/results_visual.png" width="82%"/>
+-->
+## Installation
+Choose a cudatoolkit version that matches your server. The code was tested with PyTorch 1.9.1 and CUDA 11.4.
+
+```shell
+conda create -n stir python==3.8.12
+conda activate stir
+# You can choose the PyTorch version you like, for example
+pip install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2
+```
+
+Install the dependent packages:
+```
+pip install -r requirements.txt
+```
+
+Install the core package:
+```
+cd ./package_core
+python setup.py install
+```
+
+Our implementation borrows the code framework of [SSIR](https://github.com/ruizhao26/SSIR).
+
+## Prepare the Data
+
+#### 1. Download and deploy the SREDS dataset to your local computer from [SSIR](https://github.com/ruizhao26/SSIR).
+
+#### 2. Set the path of the SREDS dataset on your server
+
+Pass it via `--data_root` when running train_STIR.sh or eval_SREDS.sh.
+
+## Evaluate
+```
+sh eval_SREDS.sh
+```
+
+## Train
+```
+sh train_STIR.sh
+```
+<!---
+## Citations
+If you find our approach useful in your research or applications, please consider giving us a star 🌟 and citing it with the following BibTeX entry.
+```
+@article{fan2024spatio,
+  title={Spatio-Temporal Interactive Learning for Efficient Image Reconstruction of Spiking Cameras},
+  author={Fan, Bin and Yin, Jiaoyang and Dai, Yuchao and Xu, Chao and Huang, Tiejun and Shi, Boxin},
+  journal={Proceedings of the Advances in Neural Information Processing Systems (NeurIPS)},
+  volume={},
+  year={2024}
+}
+```
+-->
+## Statement
+This project is for research purposes only; please contact us for a commercial-use license. For any other questions or discussion, please contact: binfan@mail.nwpu.edu.cn
spikezoo/archs/stir/ckpt_outputs/Descriptions.txt
@@ -0,0 +1 @@
+This folder is used to store the trained model.
spikezoo/archs/stir/configs/STIR.yml
@@ -0,0 +1,37 @@
+data:
+  interp: 20
+  alpha: 0.4
+
+seed: 6666
+
+loader:
+  # crop_size: [128, 128]
+  crop_size: [96, 96]
+  pair_step: 4
+
+model:
+  arch: 'STIR'
+  seq_len: 8
+  flow_weight_decay: 0.0004
+  flow_bias_decay: 0.0
+  #########################
+  kwargs:
+    activation_type: 'lif'
+    mp_activation_type: 'amp_lif'
+    spike_connection: 'concat'
+    num_encoders: 3
+    num_resblocks: 1
+    v_threshold: 1.0
+    v_reset: None
+    tau: 2.0
+
+
+train:
+  print_freq: 1
+  mixed_precision: True
+  vis_freq: 20
+
+optimizer:
+  solver: Adam
+  momentum: 0.9
+  beta: 0.999
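How these settings reach the model is easiest to see by loading the file through the `YAMLParser` defined in configs/yml_parser.py further below; a minimal sketch, with the config path (relative to the repo root) as an assumption:

```python
from configs.yml_parser import YAMLParser

cfg = YAMLParser("configs/STIR.yml").config   # plain nested dict built from the YAML
print(cfg["model"]["arch"])                   # 'STIR'
print(cfg["model"]["kwargs"]["tau"])          # 2.0
# caveat: YAML has no `None` literal, so `v_reset: None` parses as the string 'None'
print(repr(cfg["model"]["kwargs"]["v_reset"]))
```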
spikezoo/archs/stir/configs/utils.py
@@ -0,0 +1,155 @@
+import numpy as np
+import torch
+import torch.nn.functional as F
+import os
+import os.path as osp
+import random
+import cv2
+
+
+def set_seeds(_seed_):
+    random.seed(_seed_)
+    np.random.seed(_seed_)
+    torch.manual_seed(_seed_)  # torch.manual_seed() seeds the RNG for all devices (both CPU and CUDA)
+    torch.cuda.manual_seed_all(_seed_)
+
+    torch.backends.cudnn.deterministic = True
+    torch.backends.cudnn.benchmark = False
+    # os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
+    # set the debug environment variable CUBLAS_WORKSPACE_CONFIG to ":16:8" (may limit overall
+    # performance) or ":4096:8" (increases the library footprint in GPU memory by roughly 24 MiB)
+    # torch.use_deterministic_algorithms(True)
+
+
+def make_dir(path):
+    if not osp.exists(path):
+        os.makedirs(path)
+    return
+
+
+def add_args_to_cfg(cfg, args, args_list):
+    # copy the listed command-line arguments into the 'train' section of the config
+    for aa in args_list:
+        cfg['train'][aa] = getattr(args, aa)
+    return cfg
+
+
+class AverageMeter(object):
+    """Computes and stores the average and current value of one or more meters"""
+
+    def __init__(self, i=1, precision=3, names=None):
+        self.meters = i
+        self.precision = precision
+        self.reset(self.meters)
+        self.names = names
+        if names is not None:
+            assert self.meters == len(self.names)
+        else:
+            self.names = [''] * self.meters
+
+    def reset(self, i):
+        self.val = [0] * i
+        self.avg = [0] * i
+        self.sum = [0] * i
+        self.count = [0] * i
+
+    def update(self, val, n=1):
+        if not isinstance(val, list):
+            val = [val]
+        if not isinstance(n, list):
+            n = [n] * self.meters
+        assert (len(val) == self.meters and len(n) == self.meters)
+        for i in range(self.meters):
+            self.count[i] += n[i]
+        for i, v in enumerate(val):
+            self.val[i] = v
+            self.sum[i] += v * n[i]
+            self.avg[i] = self.sum[i] / self.count[i]
+
+    def __repr__(self):
+        out = ' '.join(['{} {:.{}f} ({:.{}f})'.format(n, v, self.precision, a, self.precision)
+                        for n, v, a in zip(self.names, self.val, self.avg)])
+        return '{}'.format(out)
+
+
+def normalize_image_torch(image, percentile_lower=1, percentile_upper=99):
+    # robust per-channel min-max normalization between the given percentiles
+    b, c, h, w = image.shape
+    image_reshape = image.reshape([b, c, h * w])
+    mini = torch.quantile(image_reshape, percentile_lower / 100.0, dim=2, keepdim=True).unsqueeze_(dim=3)
+    maxi = torch.quantile(image_reshape, percentile_upper / 100.0, dim=2, keepdim=True).unsqueeze_(dim=3)
+    return torch.clip((image - mini) / (maxi - mini + 1e-5), 0, 1)
+
+
+def normalize_image_torch2(image):
+    return torch.clip(image, 0, 1)
+
+
+# --------------------------------------------
+# Torch to Numpy 0~255
+# --------------------------------------------
+def torch2numpy255(im):
+    im = im[0, 0].detach().cpu().numpy()
+    im = (im * 255).astype(np.float64)
+    return im
+
+
+def torch2torch255(im):
+    return im * 255.0
+
+
+class InputPadder:
+    """Pads images such that their dimensions are divisible by padsize"""
+
+    def __init__(self, dims, padsize=16):
+        self.ht, self.wd = dims[-2:]
+        pad_ht = (((self.ht // padsize) + 1) * padsize - self.ht) % padsize
+        pad_wd = (((self.wd // padsize) + 1) * padsize - self.wd) % padsize
+        self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, pad_ht - pad_ht // 2]
+
+    def pad(self, *inputs):
+        return [F.pad(x, self._pad, mode='replicate') for x in inputs]
+
+    def unpad(self, x):
+        ht, wd = x.shape[-2:]
+        c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
+        return x[..., c[0]:c[1], c[2]:c[3]]
+
+
+def vis_img(vis_path: str, img: torch.Tensor, vis_name: str = 'vis'):
+    # tile the batch into a 4-row grid and save it as a single PNG
+    ww = 0
+    rows = []
+    for ii in range(4):
+        cur_row = []
+        for jj in range(img.shape[0] // 4):
+            cur_img = img[ww, 0].detach().cpu().numpy() * 255
+            cur_img = cur_img.astype(np.uint8)
+            cur_row.append(cur_img)
+            ww += 1
+        cur_row_cat = np.concatenate(cur_row, axis=1)
+        rows.append(cur_row_cat)
+    out_img = np.concatenate(rows, axis=0)
+    cv2.imwrite(osp.join(vis_path, vis_name + '.png'), out_img)
+    return
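`InputPadder` is the usual trick for feeding arbitrary-resolution inputs through a network whose encoder downsamples by a fixed factor: replicate-pad up to the next multiple of `padsize`, run the model, then crop back. A minimal round-trip sketch, with the tensor size chosen arbitrarily:

```python
import torch
from configs.utils import InputPadder

x = torch.rand(1, 1, 250, 400)             # 250 is not divisible by 16
padder = InputPadder(x.shape, padsize=16)
(x_pad,) = padder.pad(x)                   # -> (1, 1, 256, 400), replicate-padded
out = padder.unpad(x_pad)                  # -> (1, 1, 250, 400), original size restored
assert out.shape == x.shape
```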
spikezoo/archs/stir/configs/yml_parser.py
@@ -0,0 +1,78 @@
+import numpy as np
+import torch
+import yaml
+
+
+class YAMLParser:
+    """
+    Modified from code from tudelft ssl-evflow
+    """
+
+    def __init__(self, config):
+        self.reset_config()
+        self.parse_config(config)
+        # self.init_seeds()
+
+    def parse_config(self, file):
+        with open(file) as fid:
+            yaml_config = yaml.load(fid, Loader=yaml.FullLoader)
+        self.parse_dict(yaml_config)
+
+    @property
+    def config(self):
+        return self._config
+
+    @property
+    def device(self):
+        return self._device
+
+    @property
+    def loader_kwargs(self):
+        return self._loader_kwargs
+
+    def reset_config(self):
+        self._config = {}
+
+    def update(self, config):
+        self.reset_config()
+        self.parse_config(config)
+
+    def parse_dict(self, input_dict, parent=None):
+        # recursively merge input_dict into parent (default: self._config)
+        if parent is None:
+            parent = self._config
+        for key, val in input_dict.items():
+            if isinstance(val, dict):
+                if key not in parent.keys():
+                    parent[key] = {}
+                self.parse_dict(val, parent[key])
+            else:
+                parent[key] = val
+
+    @staticmethod
+    def worker_init_fn(worker_id):
+        np.random.seed(np.random.get_state()[1][0] + worker_id)
+
+    # def init_seeds(self):
+    #     torch.manual_seed(self._config["loader"]["seed"])
+    #     if torch.cuda.is_available():
+    #         torch.cuda.manual_seed(self._config["loader"]["seed"])
+    #         torch.cuda.manual_seed_all(self._config["loader"]["seed"])
+
+    def merge_configs(self, run):
+        """
+        Overwrites mlflow metadata with configs.
+        """
+
+        # parse mlflow settings
+        config = {}
+        for key in run.keys():
+            if len(run[key]) > 0 and run[key][0] == "{":  # assume dictionary
+                config[key] = eval(run[key])
+            else:  # string
+                config[key] = run[key]
+
+        # overwrite with config settings
+        self.parse_dict(self._config, config)
+        self.combine_entries(config)  # note: combine_entries is not defined in this trimmed copy
+
+        return config
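The `worker_init_fn` staticmethod matters because forked DataLoader workers otherwise inherit identical numpy RNG state and would produce identical augmentations; re-seeding with the worker id breaks the tie. A self-contained sketch with a throwaway dataset (the dataset itself is illustrative only):

```python
import numpy as np
import torch.utils.data as data
from configs.yml_parser import YAMLParser

class RandomToy(data.Dataset):  # stand-in for sreds_train below
    def __len__(self):
        return 8
    def __getitem__(self, idx):
        return np.random.rand()  # identical across workers without per-worker seeding

loader = data.DataLoader(RandomToy(), batch_size=4, num_workers=2,
                         worker_init_fn=YAMLParser.worker_init_fn)
for batch in loader:
    print(batch)
```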
spikezoo/archs/stir/datasets/dataset_sreds.py
@@ -0,0 +1,180 @@
+import os
+import os.path as osp
+import random
+import numpy as np
+import torch
+import torch.utils.data as data
+from datasets.ds_utils import *
+
+
+class Augmentor:
+    def __init__(self, crop_size):
+        # spatial augmentation params
+        self.crop_size = crop_size
+
+    def augment_img(self, img, mode=0):
+        '''Kai Zhang (github: https://github.com/cszn)
+        W x H x C or W x H
+        '''
+        if mode == 0:
+            return img
+        elif mode == 1:
+            return np.flipud(np.rot90(img))
+        elif mode == 2:
+            return np.flipud(img)
+        elif mode == 3:
+            return np.rot90(img, k=3)
+        elif mode == 4:
+            return np.flipud(np.rot90(img, k=2))
+        elif mode == 5:
+            return np.rot90(img)
+        elif mode == 6:
+            return np.rot90(img, k=2)
+        elif mode == 7:
+            return np.flipud(np.rot90(img, k=3))
+
+    def spatial_transform(self, spk_list, img_list):
+        # apply the same flip/rotation to every spike block and image
+        mode = random.randint(0, 7)
+
+        for ii, spk in enumerate(spk_list):
+            spk = np.transpose(spk, [1, 2, 0])
+            spk = self.augment_img(spk, mode=mode)
+            spk_list[ii] = np.transpose(spk, [2, 0, 1])
+
+        for ii, img in enumerate(img_list):
+            img = np.transpose(img, [1, 2, 0])
+            img = self.augment_img(img, mode=mode)
+            img_list[ii] = np.transpose(img, [2, 0, 1])
+
+        return spk_list, img_list
+
+    def __call__(self, spk_list, img_list):
+        spk_list, img_list = self.spatial_transform(spk_list, img_list)
+        spk_list = [np.ascontiguousarray(spk) for spk in spk_list]
+        img_list = [np.ascontiguousarray(img) for img in img_list]
+        return spk_list, img_list
+
+
+class sreds_train(torch.utils.data.Dataset):
+    def __init__(self, cfg):
+        self.cfg = cfg
+        self.pair_step = self.cfg['loader']['pair_step']
+        self.augmentor = Augmentor(crop_size=self.cfg['loader']['crop_size'])
+        self.samples = self.collect_samples()
+        print('The number of training samples: {:d}'.format(len(self.samples)))
+
+    def confirm_exist(self, path_list_list):
+        for pl in path_list_list:
+            for p in pl:
+                if not osp.exists(p):
+                    return 0
+        return 1
+
+    def collect_samples(self):
+        spike_path = osp.join(self.cfg['data']['root'], 'crop_mini', 'spike', 'train',
+                              'interp_{:d}_alpha_{:.2f}'.format(self.cfg['data']['interp'], self.cfg['data']['alpha']))
+        image_path = osp.join(self.cfg['data']['root'], 'crop_mini', 'image', 'train', 'train_orig')
+        scene_list = sorted(os.listdir(spike_path))
+        samples = []
+
+        for scene in scene_list:
+            spike_dir = osp.join(spike_path, scene)
+            image_dir = osp.join(image_path, scene)
+            spk_path_list = sorted(os.listdir(spike_dir))
+
+            spklen = len(spk_path_list)
+            seq_len = self.cfg['model']['seq_len'] + 2
+            '''
+            for st in range(0, spklen - ((spklen - self.pair_step) % seq_len) - seq_len, self.pair_step):
+                # read by file name
+                spikes_path_list = [osp.join(spike_dir, spk_path_list[ii]) for ii in range(st, st + seq_len)]
+                images_path_list = [osp.join(image_dir, spk_path_list[ii][:-4] + '.png') for ii in range(st, st + seq_len)]
+
+                if self.confirm_exist([spikes_path_list, images_path_list]):
+                    s = {}
+                    s['spikes_paths'] = spikes_path_list
+                    s['images_paths'] = images_path_list
+                    samples.append(s)
+            '''
+            # read by file name
+            spikes_path_list = [osp.join(spike_dir, spk_path_list[ii]) for ii in range(spklen)]
+            images_path_list = [osp.join(image_dir, spk_path_list[ii][:-4] + '.png') for ii in range(spklen)]
+
+            if self.confirm_exist([spikes_path_list, images_path_list]):
+                s = {}
+                s['spikes_paths'] = spikes_path_list
+                s['images_paths'] = images_path_list
+                samples.append(s)
+
+        return samples
+
+    def _load_sample(self, s):
+        data = {}
+
+        data['spikes'] = [np.array(dat_to_spmat(p, size=(96, 96)), dtype=np.float32) for p in s['spikes_paths']]
+        data['images'] = [read_img_gray(p) for p in s['images_paths']]
+
+        data['spikes'], data['images'] = self.augmentor(data['spikes'], data['images'])
+
+        return data
+
+    def __len__(self):
+        return len(self.samples)
+
+    def __getitem__(self, index):
+        data = self._load_sample(self.samples[index])
+        return data
+
+
+class sreds_test(torch.utils.data.Dataset):
+    def __init__(self, cfg):
+        self.cfg = cfg
+        self.samples = self.collect_samples()
+        print('The number of testing samples: {:d}'.format(len(self.samples)))
+
+    def confirm_exist(self, path_list_list):
+        for pl in path_list_list:
+            for p in pl:
+                if not osp.exists(p):
+                    return 0
+        return 1
+
+    def collect_samples(self):
+        spike_path = osp.join(self.cfg['data']['root'], 'spike', 'val',
+                              'interp_{:d}_alpha_{:.2f}'.format(self.cfg['data']['interp'], self.cfg['data']['alpha']))
+        image_path = osp.join(self.cfg['data']['root'], 'imgs', 'val', 'val_orig')
+        scene_list = sorted(os.listdir(spike_path))
+        samples = []
+
+        for scene in scene_list:
+            spike_dir = osp.join(spike_path, scene)
+            image_dir = osp.join(image_path, scene)
+            spk_path_list = sorted(os.listdir(spike_dir))
+
+            spklen = len(spk_path_list)
+
+            # read by file name
+            spikes_path_list = [osp.join(spike_dir, spk_path_list[ii]) for ii in range(spklen)]
+            images_path_list = [osp.join(image_dir, spk_path_list[ii][:-4] + '.png') for ii in range(spklen)]
+
+            if self.confirm_exist([spikes_path_list, images_path_list]):
+                s = {}
+                s['spikes_paths'] = spikes_path_list
+                s['images_paths'] = images_path_list
+                samples.append(s)
+
+        return samples
+
+    def _load_sample(self, s):
+        data = {}
+        data['spikes'] = [np.array(dat_to_spmat(p, size=(720, 1280)), dtype=np.float32) for p in s['spikes_paths']]
+        data['images'] = [read_img_gray(p) for p in s['images_paths']]
+        return data
+
+    def __len__(self):
+        return len(self.samples)
+
+    def __getitem__(self, index):
+        data = self._load_sample(self.samples[index])
+        return data
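Putting the pieces together, the intended call pattern is: parse `STIR.yml`, point `cfg['data']['root']` at the deployed SREDS directory, and hand the dataset to a DataLoader. A minimal sketch, with the data root and batch size as assumptions:

```python
import torch.utils.data as data
from configs.yml_parser import YAMLParser
from datasets.dataset_sreds import sreds_train

cfg = YAMLParser("configs/STIR.yml").config
cfg['data']['root'] = "/path/to/SREDS"     # normally supplied via --data_root
train_loader = data.DataLoader(sreds_train(cfg), batch_size=8, shuffle=True,
                               num_workers=4,
                               worker_init_fn=YAMLParser.worker_init_fn)
batch = next(iter(train_loader))           # dict with 'spikes' and 'images' lists,
                                           # each entry batched by the default collate
```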
spikezoo/archs/stir/datasets/ds_utils.py
@@ -0,0 +1,66 @@
+import numpy as np
+import os
+import cv2
+import os.path as osp
+
+
+def RawToSpike(video_seq, h, w, flipud=True):
+    video_seq = np.array(video_seq).astype(np.uint8)
+    img_size = h * w
+    img_num = len(video_seq) // (img_size // 8)
+    SpikeMatrix = np.zeros([img_num, h, w], np.uint8)
+    pix_id = np.arange(0, h * w)
+    pix_id = np.reshape(pix_id, (h, w))
+    comparator = np.left_shift(1, np.mod(pix_id, 8))
+    byte_id = pix_id // 8
+
+    for img_id in np.arange(img_num):
+        id_start = img_id * img_size // 8
+        id_end = id_start + img_size // 8
+        cur_info = video_seq[id_start:id_end]
+        data = cur_info[byte_id]
+        result = np.bitwise_and(data, comparator)
+        if flipud:
+            SpikeMatrix[img_id, :, :] = np.flipud((result == comparator))
+        else:
+            SpikeMatrix[img_id, :, :] = (result == comparator)
+
+    return SpikeMatrix
+
+
+def SpikeToRaw(SpikeSeq, save_path):
+    """
+    SpikeSeq: Numpy array (sfn x h x w)
+    save_path: full saving path (string)
+    Rui Zhao
+    """
+    sfn, h, w = SpikeSeq.shape
+    base = np.power(2, np.linspace(0, 7, 8))
+    fid = open(save_path, 'ab')
+    for img_id in range(sfn):
+        # flip vertically to simulate the camera's inverted image
+        spike = np.flipud(SpikeSeq[img_id, :, :])
+        # numpy flattens in row-major order by default, and the data is stored row by row as well
+        spike = spike.flatten()
+        spike = spike.reshape([int(h * w / 8), 8])
+        data = spike * base
+        data = np.sum(data, axis=1).astype(np.uint8)
+        fid.write(data.tobytes())
+
+    fid.close()
+
+    return
+
+
+def dat_to_spmat(dat_path, size=(720, 1280)):
+    with open(dat_path, 'rb') as f:
+        video_seq = f.read()
+    video_seq = np.frombuffer(video_seq, 'b')
+    sp_mat = RawToSpike(video_seq, size[0], size[1])
+    return sp_mat
+
+
+def read_img_gray(file_path):
+    im = cv2.imread(file_path).astype(np.float32) / 255.0
+    im = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
+    im = np.expand_dims(im, axis=0)
+    return im
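`RawToSpike` and `SpikeToRaw` are exact inverses (bit-unpacking vs. bit-packing, with the same vertical flip on both sides), which a round trip on random data confirms; a minimal sketch, with the temporary file path as an assumption:

```python
import os
import numpy as np
from datasets.ds_utils import SpikeToRaw, dat_to_spmat

tmp = "/tmp/roundtrip.dat"
if os.path.exists(tmp):
    os.remove(tmp)                          # SpikeToRaw appends, so start from a clean file
spikes = np.random.randint(0, 2, size=(5, 720, 1280), dtype=np.uint8)
SpikeToRaw(spikes, tmp)                     # pack 8 binary spikes per byte
recovered = dat_to_spmat(tmp, size=(720, 1280))
assert np.array_equal(spikes, recovered)    # lossless round trip
```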