spikezoo 0.1.2__py3-none-any.whl → 0.2__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (192) hide show
  1. spikezoo/__init__.py +13 -0
  2. spikezoo/archs/__pycache__/__init__.cpython-39.pyc +0 -0
  3. spikezoo/archs/base/__pycache__/nets.cpython-39.pyc +0 -0
  4. spikezoo/archs/base/nets.py +34 -0
  5. spikezoo/archs/bsf/README.md +92 -0
  6. spikezoo/archs/bsf/datasets/datasets.py +328 -0
  7. spikezoo/archs/bsf/datasets/ds_utils.py +64 -0
  8. spikezoo/archs/bsf/main.py +398 -0
  9. spikezoo/archs/bsf/metrics/psnr.py +22 -0
  10. spikezoo/archs/bsf/metrics/ssim.py +54 -0
  11. spikezoo/archs/bsf/models/bsf/__pycache__/align.cpython-39.pyc +0 -0
  12. spikezoo/archs/bsf/models/bsf/__pycache__/bsf.cpython-39.pyc +0 -0
  13. spikezoo/archs/bsf/models/bsf/__pycache__/rep.cpython-39.pyc +0 -0
  14. spikezoo/archs/bsf/models/bsf/align.py +154 -0
  15. spikezoo/archs/bsf/models/bsf/bsf.py +105 -0
  16. spikezoo/archs/bsf/models/bsf/dsft_convert.py +96 -0
  17. spikezoo/archs/bsf/models/bsf/rep.py +44 -0
  18. spikezoo/archs/bsf/models/get_model.py +7 -0
  19. spikezoo/archs/bsf/prepare_data/DSFT.py +62 -0
  20. spikezoo/archs/bsf/prepare_data/crop_dataset_train.py +135 -0
  21. spikezoo/archs/bsf/prepare_data/crop_dataset_val.py +139 -0
  22. spikezoo/archs/bsf/prepare_data/crop_train.sh +4 -0
  23. spikezoo/archs/bsf/prepare_data/crop_val.sh +4 -0
  24. spikezoo/archs/bsf/prepare_data/io_utils.py +64 -0
  25. spikezoo/archs/bsf/requirements.txt +9 -0
  26. spikezoo/archs/bsf/test.py +16 -0
  27. spikezoo/archs/bsf/utils.py +154 -0
  28. spikezoo/archs/spikeclip/__pycache__/nets.cpython-39.pyc +0 -0
  29. spikezoo/archs/spikeclip/nets.py +40 -0
  30. spikezoo/archs/spikeformer/CheckPoints/readme +1 -0
  31. spikezoo/archs/spikeformer/DataProcess/DataExtactor.py +60 -0
  32. spikezoo/archs/spikeformer/DataProcess/DataLoader.py +115 -0
  33. spikezoo/archs/spikeformer/DataProcess/LoadSpike.py +39 -0
  34. spikezoo/archs/spikeformer/EvalResults/readme +1 -0
  35. spikezoo/archs/spikeformer/LICENSE +21 -0
  36. spikezoo/archs/spikeformer/Metrics/Metrics.py +50 -0
  37. spikezoo/archs/spikeformer/Metrics/__init__.py +0 -0
  38. spikezoo/archs/spikeformer/Model/Loss.py +89 -0
  39. spikezoo/archs/spikeformer/Model/SpikeFormer.py +230 -0
  40. spikezoo/archs/spikeformer/Model/__init__.py +0 -0
  41. spikezoo/archs/spikeformer/Model/__pycache__/SpikeFormer.cpython-39.pyc +0 -0
  42. spikezoo/archs/spikeformer/Model/__pycache__/__init__.cpython-39.pyc +0 -0
  43. spikezoo/archs/spikeformer/README.md +30 -0
  44. spikezoo/archs/spikeformer/evaluate.py +87 -0
  45. spikezoo/archs/spikeformer/recon_real_data.py +97 -0
  46. spikezoo/archs/spikeformer/requirements.yml +95 -0
  47. spikezoo/archs/spikeformer/train.py +173 -0
  48. spikezoo/archs/spikeformer/utils.py +22 -0
  49. spikezoo/archs/spk2imgnet/.github/workflows/pylint.yml +23 -0
  50. spikezoo/archs/spk2imgnet/.gitignore +150 -0
  51. spikezoo/archs/spk2imgnet/DCNv2.py +135 -0
  52. spikezoo/archs/spk2imgnet/__pycache__/DCNv2.cpython-39.pyc +0 -0
  53. spikezoo/archs/spk2imgnet/__pycache__/align_arch.cpython-39.pyc +0 -0
  54. spikezoo/archs/spk2imgnet/__pycache__/nets.cpython-39.pyc +0 -0
  55. spikezoo/archs/spk2imgnet/align_arch.py +159 -0
  56. spikezoo/archs/spk2imgnet/dataset.py +144 -0
  57. spikezoo/archs/spk2imgnet/nets.py +230 -0
  58. spikezoo/archs/spk2imgnet/readme.md +86 -0
  59. spikezoo/archs/spk2imgnet/test_gen_imgseq.py +118 -0
  60. spikezoo/archs/spk2imgnet/train.py +189 -0
  61. spikezoo/archs/spk2imgnet/utils.py +64 -0
  62. spikezoo/archs/ssir/README.md +87 -0
  63. spikezoo/archs/ssir/configs/SSIR.yml +37 -0
  64. spikezoo/archs/ssir/configs/yml_parser.py +78 -0
  65. spikezoo/archs/ssir/datasets/dataset_sreds.py +170 -0
  66. spikezoo/archs/ssir/datasets/ds_utils.py +66 -0
  67. spikezoo/archs/ssir/losses.py +21 -0
  68. spikezoo/archs/ssir/main.py +326 -0
  69. spikezoo/archs/ssir/metrics/psnr.py +22 -0
  70. spikezoo/archs/ssir/metrics/ssim.py +54 -0
  71. spikezoo/archs/ssir/models/Vgg19.py +42 -0
  72. spikezoo/archs/ssir/models/__pycache__/layers.cpython-39.pyc +0 -0
  73. spikezoo/archs/ssir/models/__pycache__/networks.cpython-39.pyc +0 -0
  74. spikezoo/archs/ssir/models/layers.py +110 -0
  75. spikezoo/archs/ssir/models/networks.py +61 -0
  76. spikezoo/archs/ssir/requirements.txt +8 -0
  77. spikezoo/archs/ssir/shells/eval_SREDS.sh +6 -0
  78. spikezoo/archs/ssir/shells/train_SSIR.sh +12 -0
  79. spikezoo/archs/ssir/test.py +3 -0
  80. spikezoo/archs/ssir/utils.py +154 -0
  81. spikezoo/archs/ssml/__pycache__/cbam.cpython-39.pyc +0 -0
  82. spikezoo/archs/ssml/__pycache__/model.cpython-39.pyc +0 -0
  83. spikezoo/archs/ssml/cbam.py +224 -0
  84. spikezoo/archs/ssml/model.py +290 -0
  85. spikezoo/archs/ssml/res.png +0 -0
  86. spikezoo/archs/ssml/test.py +67 -0
  87. spikezoo/archs/stir/.git-credentials +0 -0
  88. spikezoo/archs/stir/README.md +65 -0
  89. spikezoo/archs/stir/ckpt_outputs/Descriptions.txt +1 -0
  90. spikezoo/archs/stir/configs/STIR.yml +37 -0
  91. spikezoo/archs/stir/configs/utils.py +155 -0
  92. spikezoo/archs/stir/configs/yml_parser.py +78 -0
  93. spikezoo/archs/stir/datasets/dataset_sreds.py +180 -0
  94. spikezoo/archs/stir/datasets/ds_utils.py +66 -0
  95. spikezoo/archs/stir/eval_SREDS.sh +5 -0
  96. spikezoo/archs/stir/main.py +397 -0
  97. spikezoo/archs/stir/metrics/losses.py +219 -0
  98. spikezoo/archs/stir/metrics/psnr.py +22 -0
  99. spikezoo/archs/stir/metrics/ssim.py +54 -0
  100. spikezoo/archs/stir/models/Vgg19.py +42 -0
  101. spikezoo/archs/stir/models/__pycache__/networks_STIR.cpython-39.pyc +0 -0
  102. spikezoo/archs/stir/models/__pycache__/submodules.cpython-39.pyc +0 -0
  103. spikezoo/archs/stir/models/__pycache__/transformer_new.cpython-39.pyc +0 -0
  104. spikezoo/archs/stir/models/networks_STIR.py +361 -0
  105. spikezoo/archs/stir/models/submodules.py +86 -0
  106. spikezoo/archs/stir/models/transformer_new.py +151 -0
  107. spikezoo/archs/stir/package_core/build/lib/package_core/__init__.py +0 -0
  108. spikezoo/archs/stir/package_core/build/lib/package_core/convertions.py +721 -0
  109. spikezoo/archs/stir/package_core/build/lib/package_core/disp_netS.py +133 -0
  110. spikezoo/archs/stir/package_core/build/lib/package_core/flow_utils.py +167 -0
  111. spikezoo/archs/stir/package_core/build/lib/package_core/generic_train_test.py +76 -0
  112. spikezoo/archs/stir/package_core/build/lib/package_core/geometry.py +458 -0
  113. spikezoo/archs/stir/package_core/build/lib/package_core/image_proc.py +183 -0
  114. spikezoo/archs/stir/package_core/build/lib/package_core/linalg.py +40 -0
  115. spikezoo/archs/stir/package_core/build/lib/package_core/losses.py +198 -0
  116. spikezoo/archs/stir/package_core/build/lib/package_core/metrics.py +51 -0
  117. spikezoo/archs/stir/package_core/build/lib/package_core/model_base.py +53 -0
  118. spikezoo/archs/stir/package_core/build/lib/package_core/net_basics.py +100 -0
  119. spikezoo/archs/stir/package_core/build/lib/package_core/resnet.py +333 -0
  120. spikezoo/archs/stir/package_core/build/lib/package_core/transforms.py +123 -0
  121. spikezoo/archs/stir/package_core/build/lib/package_core/utils.py +72 -0
  122. spikezoo/archs/stir/package_core/dist/package_core-0.0.0-py3.9.egg +0 -0
  123. spikezoo/archs/stir/package_core/package_core/__init__.py +0 -0
  124. spikezoo/archs/stir/package_core/package_core/__pycache__/__init__.cpython-39.pyc +0 -0
  125. spikezoo/archs/stir/package_core/package_core/__pycache__/net_basics.cpython-39.pyc +0 -0
  126. spikezoo/archs/stir/package_core/package_core/convertions.py +721 -0
  127. spikezoo/archs/stir/package_core/package_core/disp_netS.py +133 -0
  128. spikezoo/archs/stir/package_core/package_core/flow_utils.py +167 -0
  129. spikezoo/archs/stir/package_core/package_core/generic_train_test.py +76 -0
  130. spikezoo/archs/stir/package_core/package_core/geometry.py +458 -0
  131. spikezoo/archs/stir/package_core/package_core/image_proc.py +183 -0
  132. spikezoo/archs/stir/package_core/package_core/linalg.py +40 -0
  133. spikezoo/archs/stir/package_core/package_core/losses.py +198 -0
  134. spikezoo/archs/stir/package_core/package_core/metrics.py +51 -0
  135. spikezoo/archs/stir/package_core/package_core/model_base.py +53 -0
  136. spikezoo/archs/stir/package_core/package_core/net_basics.py +100 -0
  137. spikezoo/archs/stir/package_core/package_core/resnet.py +333 -0
  138. spikezoo/archs/stir/package_core/package_core/transforms.py +123 -0
  139. spikezoo/archs/stir/package_core/package_core/utils.py +72 -0
  140. spikezoo/archs/stir/package_core/package_core.egg-info/PKG-INFO +3 -0
  141. spikezoo/archs/stir/package_core/package_core.egg-info/SOURCES.txt +20 -0
  142. spikezoo/archs/stir/package_core/package_core.egg-info/dependency_links.txt +1 -0
  143. spikezoo/archs/stir/package_core/package_core.egg-info/top_level.txt +1 -0
  144. spikezoo/archs/stir/package_core/setup.py +5 -0
  145. spikezoo/archs/stir/requirements.txt +12 -0
  146. spikezoo/archs/stir/train_STIR.sh +9 -0
  147. spikezoo/archs/tfi/__pycache__/nets.cpython-39.pyc +0 -0
  148. spikezoo/archs/tfi/nets.py +43 -0
  149. spikezoo/archs/tfp/__pycache__/nets.cpython-39.pyc +0 -0
  150. spikezoo/archs/tfp/nets.py +13 -0
  151. spikezoo/archs/wgse/README.md +64 -0
  152. spikezoo/archs/wgse/__pycache__/dwtnets.cpython-39.pyc +0 -0
  153. spikezoo/archs/wgse/__pycache__/submodules.cpython-39.pyc +0 -0
  154. spikezoo/archs/wgse/dataset.py +59 -0
  155. spikezoo/archs/wgse/demo.png +0 -0
  156. spikezoo/archs/wgse/demo.py +83 -0
  157. spikezoo/archs/wgse/dwtnets.py +145 -0
  158. spikezoo/archs/wgse/eval.py +133 -0
  159. spikezoo/archs/wgse/logs/WGSE-Dwt1dNet-db8-5-ks3/log.txt +11 -0
  160. spikezoo/archs/wgse/submodules.py +68 -0
  161. spikezoo/archs/wgse/train.py +261 -0
  162. spikezoo/archs/wgse/transform.py +139 -0
  163. spikezoo/archs/wgse/utils.py +128 -0
  164. spikezoo/archs/wgse/weights/demo.png +0 -0
  165. spikezoo/data/base/test/gt/200_part1_key_id151.png +0 -0
  166. spikezoo/data/base/test/gt/200_part3_key_id151.png +0 -0
  167. spikezoo/data/base/test/gt/203_part1_key_id151.png +0 -0
  168. spikezoo/data/base/test/spike/200_part1_key_id151.dat +0 -0
  169. spikezoo/data/base/test/spike/200_part3_key_id151.dat +0 -0
  170. spikezoo/data/base/test/spike/203_part1_key_id151.dat +0 -0
  171. spikezoo/data/base/train/gt/203_part2_key_id151.png +0 -0
  172. spikezoo/data/base/train/gt/203_part3_key_id151.png +0 -0
  173. spikezoo/data/base/train/gt/203_part4_key_id151.png +0 -0
  174. spikezoo/data/base/train/spike/203_part2_key_id151.dat +0 -0
  175. spikezoo/data/base/train/spike/203_part3_key_id151.dat +0 -0
  176. spikezoo/data/base/train/spike/203_part4_key_id151.dat +0 -0
  177. spikezoo/datasets/base_dataset.py +2 -3
  178. spikezoo/metrics/__init__.py +1 -1
  179. spikezoo/models/base_model.py +1 -3
  180. spikezoo/pipeline/base_pipeline.py +7 -5
  181. spikezoo/pipeline/train_pipeline.py +1 -1
  182. spikezoo/utils/other_utils.py +16 -6
  183. spikezoo/utils/spike_utils.py +33 -29
  184. spikezoo/utils/vidar_loader.cpython-39-x86_64-linux-gnu.so +0 -0
  185. spikezoo-0.2.dist-info/METADATA +163 -0
  186. spikezoo-0.2.dist-info/RECORD +211 -0
  187. spikezoo/models/spcsnet_model.py +0 -19
  188. spikezoo-0.1.2.dist-info/METADATA +0 -39
  189. spikezoo-0.1.2.dist-info/RECORD +0 -36
  190. {spikezoo-0.1.2.dist-info → spikezoo-0.2.dist-info}/LICENSE.txt +0 -0
  191. {spikezoo-0.1.2.dist-info → spikezoo-0.2.dist-info}/WHEEL +0 -0
  192. {spikezoo-0.1.2.dist-info → spikezoo-0.2.dist-info}/top_level.txt +0 -0
spikezoo/__init__.py CHANGED
@@ -0,0 +1,13 @@
1
+ from .utils.spike_utils import load_vidar_dat
2
+ from .models import model_list
3
+ from .datasets import dataset_list
4
+ from .metrics import metric_all_names
5
+
6
def get_datasets():
    """Return ``dataset_list`` as re-exported from the ``datasets`` subpackage."""
    registry = dataset_list
    return registry
8
+
9
def get_models():
    """Return ``model_list`` as re-exported from the ``models`` subpackage."""
    registry = model_list
    return registry
11
+
12
def get_metrics():
    """Return ``metric_all_names`` as re-exported from the ``metrics`` subpackage."""
    names = metric_all_names
    return names
@@ -0,0 +1,34 @@
1
+ import torch.nn as nn
2
+
3
def conv_layer(inDim, outDim, ks, s, p, norm_layer='none'):
    """Build a ``Conv2d -> [norm] -> ReLU`` block.

    Args:
        inDim: Number of input channels.
        outDim: Number of output channels.
        ks: Convolution kernel size.
        s: Convolution stride.
        p: Convolution padding.
        norm_layer: ``'none'``, ``'batch'`` or ``'instance'``; selects the
            normalization inserted between the convolution and the ReLU.

    Returns:
        nn.Sequential: ``[conv, relu]`` for ``'none'``, otherwise
        ``[conv, norm, relu]``.

    Raises:
        ValueError: If ``norm_layer`` is not one of the supported values.
    """
    # Explicit check rather than ``assert``: asserts are stripped under
    # ``python -O``, after which an unknown value silently became BatchNorm.
    if norm_layer not in ('batch', 'instance', 'none'):
        raise ValueError(
            "norm_layer must be 'batch', 'instance' or 'none', got {!r}".format(norm_layer)
        )
    conv = nn.Conv2d(inDim, outDim, kernel_size=ks, stride=s, padding=p)
    relu = nn.ReLU(True)  # in-place ReLU
    if norm_layer == 'none':
        return nn.Sequential(conv, relu)
    if norm_layer == 'instance':
        # Non-affine instance norm without running statistics.
        norm = nn.InstanceNorm2d(outDim, affine=False, track_running_stats=False)
    else:
        norm = nn.BatchNorm2d(outDim, momentum=0.1, affine=True, track_running_stats=True)
    return nn.Sequential(conv, norm, relu)
18
+
19
class BaseNet(nn.Module):
    """Plain five-convolution network; structure borrowed from SpikeCLIP.

    Reference: https://arxiv.org/abs/2501.04477

    Four conv+ReLU stages (``inDim -> 64 -> 128 -> 64 -> 16`` channels, no
    normalization) are followed by a final 3x3 convolution to one output
    channel. Every convolution uses kernel 3, stride 1 and padding 1.

    Args:
        inDim: Number of input channels (default 41).
    """

    def __init__(self, inDim=41):
        super(BaseNet, self).__init__()
        widths = [inDim, 64, 128, 64, 16]
        stages = [
            conv_layer(c_in, c_out, 3, 1, 1, 'none')
            for c_in, c_out in zip(widths[:-1], widths[1:])
        ]
        head = nn.Conv2d(widths[-1], 1, 3, 1, 1)
        # A single ``seq`` container keeps checkpoint keys as ``seq.<i>.*``.
        self.seq = nn.Sequential(*stages, head)

    def forward(self, x):
        """Apply the convolution stack to ``x``."""
        return self.seq(x)
34
+
@@ -0,0 +1,92 @@
1
+ ## [CVPR 2024] Boosting Spike Camera Image Reconstruction from a Perspective of Dealing with Spike Fluctuations
2
+
3
+ <h4 align="center"> Rui Zhao<sup>1,2</sup>, Ruiqin Xiong<sup>1,2</sup>, Jing Zhao<sup>1,2</sup>, Jian Zhang<sup>3</sup>, Xiaopeng Fan<sup>4</sup>, Zhaofei Yu<sup>1,2</sup>, Tiejun Huang<sup>1,2</sup> </h4>
4
+ <h4 align="center">1. School of Computer Science, Peking University<br>
5
+ 2. National Key Laboratory for Multimedia Information Processing, Peking University<br>
6
+ 3. School of Electronic and Computer Engineering, Peking University<br>
7
+ 4. School of Computer Science and Technology, Harbin Institute of Technology
8
+ </h4><br>
9
+
10
+ This repository contains the official source code for our paper:
11
+
12
+ Boosting Spike Camera Image Reconstruction from a Perspective of Dealing with Spike Fluctuations
13
+
14
+ CVPR 2024
15
+
16
+ ## Environment
17
+
18
+ You can choose the cudatoolkit version that matches your server. The code is tested on PyTorch 2.0.1+cu120.
19
+
20
+ ```bash
21
+ conda create -n bsf python==3.10.9
22
+ conda activate bsf
23
+ # You can choose any PyTorch version you like; we recommend version >= 1.10.1
24
+ # For example
25
+ pip3 install torch torchvision torchaudio --index-url https://download.pytorch.org/whl/cu118
26
+ pip3 install -r requirements.txt
27
+ ```
28
+
29
+ ## Prepare the Data
30
+
31
+ ##### 1. Download the dataset (approximately 50 GB)
32
+
33
+ [Link of the dataset (BaiduNetDisk)](https://pan.baidu.com/s/1zBp-ed1KtmhAab5Z_62ttw) (Password: 2728)
34
+
35
+ ##### 2. Prepare the dataset for faster training (requires approximately <u>another</u> 125 GB)
36
+
37
+ First, modify the data root and output root in `./prepare_data/crop_dataset_train.py` and `./prepare_data/crop_dataset_val.py`, then run:
38
+
39
+ ```shell
40
+ cd prepare_data &&
41
+ bash crop_train.sh $your_gpu_id &&
42
+ bash crop_val.sh $your_gpu_id
43
+ ```
44
+
45
+ ## Evaluate
46
+
47
+ ```shell
48
+ CUDA_VISIBLE_DEVICES=$1 python3 -W ignore main.py \
49
+ --alpha 0.7 \
50
+ --vis-path vis/bsf \
51
+ -evp eval_vis/bsf \
52
+ --logs_file_name bsf \
53
+ --compile_model \
54
+ --test_eval \
55
+ --arch bsf \
56
+ --pretrained ckpt/bsf.pth
57
+ ```
58
+
59
+ ## Train
60
+
61
+ ```shell
62
+ CUDA_VISIBLE_DEVICES=$1 python3 -W ignore main.py \
63
+ -bs 8 \
64
+ -j 8 \
65
+ -lr 1e-4 \
66
+ --epochs 61 \
67
+ --train-res 96 96 \
68
+ --lr-scale-factor 0.5 \
69
+ --milestones 10 20 30 40 50 60 70 80 90 100 \
70
+ --alpha 0.7 \
71
+ --vis-path vis/bsf \
72
+ -evp eval_vis/bsf \
73
+ --logs_file_name bsf \
74
+ --compile_model \
75
+ --weight_decay 0.0 \
76
+ --eval-interval 10 \
77
+ --half_reserve 0 \
78
+ --arch bsf
79
+ ```
80
+
81
+ ## Citations
82
+
83
+ If you find this code useful in your research, please consider citing our paper:
84
+
85
+ ```
86
+ @inproceedings{zhao2024boosting,
87
+ title={Boosting Spike Camera Image Reconstruction from a Perspective of Dealing with Spike Fluctuations},
88
+ author={Zhao, Rui and Xiong, Ruiqin and Zhao, Jing and Zhang, Jian and Fan, Xiaopeng and Yu, Zhaofei and Huang, Tiejun},
89
+ booktitle={IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)},
90
+ year={2024}
91
+ }
92
+ ```
@@ -0,0 +1,328 @@
1
+ import os
2
+ import os.path as osp
3
+ import random
4
+ import numpy as np
5
+ import torch
6
+ import torch.utils.data as data
7
+ from datasets.ds_utils import *
8
+ import h5py
9
+ from tqdm import *
10
+
11
+
12
class Augmentor:
    """Joint random crop + flip/rotation augmentation for spike/image stacks.

    For each call, one crop window and one of eight flip/rotation modes are
    drawn. They are applied identically to every array, so spikes, DSFT maps
    and ground-truth images stay aligned.

    Args:
        crop_size: ``(crop_h, crop_w)`` of the random crop. Rotating modes swap
            H and W, so outputs only share one shape when ``crop_h == crop_w``.
    """

    def __init__(self, crop_size):
        # spatial augmentation params
        self.crop_size = crop_size

    def augment_img(self, img, mode=0):
        """Return *img* (``H x W x C`` or ``H x W``) transformed by one of 8 modes.

        Adapted from Kai Zhang (github: https://github.com/cszn).
        Modes 1, 3, 5 and 7 rotate by 90/270 degrees and therefore swap H and W.
        Use square crops (crop_h == crop_w) with this augmentation.

        Raises:
            ValueError: If ``mode`` is not in ``0..7``.
        """
        if mode == 0:
            return img
        elif mode == 1:
            return np.flipud(np.rot90(img))
        elif mode == 2:
            return np.flipud(img)
        elif mode == 3:
            return np.rot90(img, k=3)
        elif mode == 4:
            return np.flipud(np.rot90(img, k=2))
        elif mode == 5:
            return np.rot90(img)
        elif mode == 6:
            return np.rot90(img, k=2)
        elif mode == 7:
            return np.flipud(np.rot90(img, k=3))
        # Previously fell through and returned None, failing later and far away.
        raise ValueError('augmentation mode must be in 0..7, got {!r}'.format(mode))

    def _random_offset(self, size, crop):
        """Draw a crop start uniformly from ``[0, size - crop]`` (0 if no crop is needed)."""
        if size > crop:
            # np.random.randint excludes its upper bound, hence the +1 so that
            # the last valid start ``size - crop`` is reachable as well.
            return np.random.randint(0, size - crop + 1)
        return 0

    def _crop_augment(self, arr, y0, x0, mode):
        """Crop a ``C x H x W`` array at ``(y0, x0)``, apply ``mode``, return ``C x h x w``."""
        hwc = np.transpose(arr, [1, 2, 0])
        hwc = hwc[y0:y0 + self.crop_size[0], x0:x0 + self.crop_size[1], :]
        hwc = self.augment_img(hwc, mode=mode)
        return np.transpose(hwc, [2, 0, 1])

    def spatial_transform(self, spk_list, img_list):
        """Crop and augment every array of both lists with one shared random draw.

        Arrays are ``C x H x W``. The crop range is derived from ``spk_list[0]``.
        Both lists are modified in place and also returned.
        """
        mode = random.randint(0, 7)
        spike_h = spk_list[0].shape[1]
        spike_w = spk_list[0].shape[2]
        y0 = self._random_offset(spike_h, self.crop_size[0])
        x0 = self._random_offset(spike_w, self.crop_size[1])

        for ii, spk in enumerate(spk_list):
            spk_list[ii] = self._crop_augment(spk, y0, x0, mode)
        for ii, img in enumerate(img_list):
            img_list[ii] = self._crop_augment(img, y0, x0, mode)
        return spk_list, img_list

    def __call__(self, spk_list, img_list):
        """Augment both lists and return C-contiguous copies.

        Flips and rotations yield negatively-strided views; copying to
        contiguous memory removes them.
        """
        spk_list, img_list = self.spatial_transform(spk_list, img_list)
        spk_list = [np.ascontiguousarray(spk) for spk in spk_list]
        img_list = [np.ascontiguousarray(img) for img in img_list]
        return spk_list, img_list
73
+
74
+
75
+
76
class sreds_train(torch.utils.data.Dataset):
    """Training split of the cropped SREDS spike dataset.

    Expected layout under ``<data_root>/crop/train``::

        eta_{eta:.2f}_gamma_{gamma:d}_alpha_{alpha:.1f}/<scene>/<crop>/spikes/000000NN.dat   NN = 11..28
        eta_{eta:.2f}_gamma_{gamma:d}_alpha_{alpha:.1f}/<scene>/<crop>/dsft/000000NN.h5      NN = 11..28
        imgs/<scene>/<crop>/000000NN.png                                                  NN = 18..21

    Spike crops are decoded at 256 x 256. Each draw picks one of the four key
    frames 18..21 at random and returns the spike (and optionally DSFT) window
    of ``3 + args.half_reserve`` frames on each side of it, plus its GT image.
    """

    def __init__(self, args):
        self.args = args
        self.input_type = args.input_type
        self.eta_list = args.eta_list
        self.gamma = args.gamma
        self.alpha = args.alpha
        self.augmentor = Augmentor(crop_size=args.train_res)

        self.dsft_path_name = 'dsft'
        self.spike_path_name = 'spikes'

        self.read_dsft = not args.no_dsft

        self.samples = self.collect_samples()
        print('The samples num of training data: {:d}'.format(len(self.samples)))

    def confirm_exist(self, path_list_list):
        """Return 1 if every path in every list exists, otherwise 0."""
        for pl in path_list_list:
            for p in pl:
                if not osp.exists(p):
                    return 0
        return 1

    def collect_samples(self):
        """Scan ``<data_root>/crop/train`` and return one sample dict per complete crop.

        Only spike and image files are checked for existence; DSFT files are
        assumed to be present whenever they are read.
        """
        samples = []
        root_path = osp.join(self.args.data_root, 'crop', 'train')

        for eta in self.eta_list:
            cur_eta_dir = osp.join(root_path, "eta_{:.2f}_gamma_{:d}_alpha_{:.1f}".format(eta, self.gamma, self.alpha))
            for scene in sorted(os.listdir(cur_eta_dir)):
                scene_path = osp.join(cur_eta_dir, scene)
                for crop in sorted(os.listdir(scene_path)):
                    crop_path = osp.join(scene_path, crop)
                    spike_dir = osp.join(crop_path, self.spike_path_name)
                    dsft_dir = osp.join(crop_path, self.dsft_path_name)
                    image_dir = osp.join(root_path, 'imgs', scene, crop)

                    # Spikes / DSFT cover frames 11..28 (list index i <-> frame 11 + i);
                    # GT images exist only for the four key frames 18..21.
                    spikes_path_list = [osp.join(spike_dir, '{:08d}.dat'.format(ii)) for ii in range(11, 28+1)]
                    dsft_path_list = [osp.join(dsft_dir, '{:08d}.h5'.format(ii)) for ii in range(11, 28+1)]
                    images00_path_list = [osp.join(image_dir, '{:08d}.png'.format(ii)) for ii in range(18, 21+1)]

                    if self.confirm_exist([spikes_path_list, images00_path_list]):
                        samples.append({
                            'spikes_paths': spikes_path_list,
                            'dsft_paths': dsft_path_list,
                            'images_paths': images00_path_list,
                            'norm_fac': eta * self.alpha,
                        })
        return samples

    def _window_paths(self, s, key_frame_offset):
        """Return ``(spike_paths, dsft_paths, image_paths)`` for one key frame.

        Key frame ``18 + key_frame_offset`` sits at index ``7 + key_frame_offset``
        of the 11..28 path lists. The window spans ``3 + half_reserve`` frames on
        each side. New lists are returned and *s* itself is left untouched, so
        later epochs still see the full path lists.

        Raises:
            ValueError: If ``half_reserve`` pushes the window outside frames 11..28.
        """
        half = 3 + self.args.half_reserve
        center = 7 + key_frame_offset
        lo, hi = center - half, center + half + 1
        if half < 0 or lo < 0 or hi > len(s['spikes_paths']):
            raise ValueError(
                'half_reserve={} gives spike window [{}, {}) outside the {} available frames'.format(
                    self.args.half_reserve, lo, hi, len(s['spikes_paths'])))
        return (s['spikes_paths'][lo:hi],
                s['dsft_paths'][lo:hi],
                [s['images_paths'][key_frame_offset]])

    def _load_sample(self, s):
        """Load one randomly chosen key-frame window of sample *s* and augment it."""
        # Four key frames (18, 19, 20, 21) are available; sample one per draw.
        key_frame_offset = random.choice([0, 1, 2, 3])
        spikes_paths, dsft_paths, images_paths = self._window_paths(s, key_frame_offset)

        data = {}
        if self.read_dsft:
            dsft = []
            for p in dsft_paths:
                # ``with`` closes the file even if reading the dataset fails.
                with h5py.File(p, 'r') as f:
                    dsft.append(np.array(f['dsft']).astype(np.float32))
            data['dsft'] = dsft
        data['spikes'] = [dat_to_spmat(p, size=(256, 256)).astype(np.float32) for p in spikes_paths]

        data['images'] = [read_img_gray(p) for p in images_paths]
        data['norm_fac'] = np.array(s['norm_fac'])

        if self.read_dsft:
            # Augment spikes and DSFT together so both get the same crop/mode,
            # then split them apart again at the spike count.
            n_spk = len(data['spikes'])
            merged, data['images'] = self.augmentor(data['spikes'] + data['dsft'], data['images'])
            data['spikes'], data['dsft'] = merged[:n_spk], merged[n_spk:]
        else:
            data['spikes'], data['images'] = self.augmentor(data['spikes'], data['images'])

        return data

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        data = self._load_sample(self.samples[index])
        return data
175
+
176
+
177
class sreds_test(torch.utils.data.Dataset):
    """Validation split of the SREDS spike dataset at original resolution.

    Spike frames are decoded at 540 x 960. Expected layout under
    ``<data_root>/crop/val``::

        eta_{eta:.2f}_gamma_{gamma:d}_alpha_{alpha:.1f}/<scene>/spikes/000000NN.dat   NN = 11..28
        eta_{eta:.2f}_gamma_{gamma:d}_alpha_{alpha:.1f}/<scene>/dsft/000000NN.h5      NN = 11..28
        imgs/<scene>/000000NN.png                                                   NN = 18..21

    Every complete scene yields four samples, one per key frame 18, 19, 20, 21.
    """

    def __init__(self, args, eta):
        self.args = args
        self.input_type = args.input_type
        self.alpha = args.alpha
        self.eta = eta
        self.gamma = args.gamma
        self.dsft_path_name = 'dsft'
        self.spike_path_name = 'spikes'
        # Skip DSFT loading when ``args.no_dsft`` is set.
        self.read_dsft = not args.no_dsft
        self.samples = self.collect_samples()
        print('The samples num of testing data: {:d}'.format(len(self.samples)))

    def confirm_exist(self, path_list_list):
        """Return 1 if every path exists; otherwise print the first missing one and return 0."""
        for pl in path_list_list:
            for p in pl:
                if not osp.exists(p):
                    print(p)
                    return 0
        return 1

    def collect_samples(self):
        """Build four key-frame samples for every complete scene.

        Raises:
            ValueError: If ``half_reserve`` pushes a window outside frames 11..28.
        """
        root_path = osp.join(self.args.data_root, 'crop', 'val')

        cur_eta_dir = osp.join(root_path, "eta_{:.2f}_gamma_{:d}_alpha_{:.1f}".format(self.eta, self.gamma, self.alpha))
        # Path lists cover frames 11..28, so key frame 18 + ii sits at index 7 + ii.
        # Windows span ``half`` frames on each side and must fit inside 0..17.
        half = 3 + self.args.half_reserve
        if not 0 <= half <= 7:
            raise ValueError('half_reserve must be in [-3, 4], got {}'.format(self.args.half_reserve))
        samples = []

        for scene in sorted(os.listdir(cur_eta_dir)):
            scene_path = osp.join(cur_eta_dir, scene)
            spike_dir = osp.join(scene_path, self.spike_path_name)
            image_dir = osp.join(root_path, 'imgs', scene)
            dsft_dir = osp.join(scene_path, self.dsft_path_name)

            spikes_path_list = [osp.join(spike_dir, '{:08d}.dat'.format(ii)) for ii in range(11, 28+1)]
            dsft_path_list = [osp.join(dsft_dir, '{:08d}.h5'.format(ii)) for ii in range(11, 28+1)]
            images_path_list = [osp.join(image_dir, '{:08d}.png'.format(ii)) for ii in range(18, 21+1)]

            if self.confirm_exist([spikes_path_list, images_path_list]):
                for ii in range(4):
                    lo, hi = 7 + ii - half, 7 + ii + half + 1
                    samples.append({
                        'spikes_paths': spikes_path_list[lo:hi],
                        'dsft_paths': dsft_path_list[lo:hi],
                        'images_paths': [images_path_list[ii]],
                        'norm_fac': self.alpha * self.eta,
                    })

        return samples

    def _load_sample(self, s):
        """Read the DSFT maps (unless disabled), spikes, GT image and norm factor of *s*."""
        data = {}
        if self.read_dsft:
            dsft = []
            for p in s['dsft_paths']:
                # ``with`` closes the file even if reading the dataset fails.
                with h5py.File(p, 'r') as f:
                    dsft.append(np.array(f['dsft']).astype(np.float32))
            data['dsft'] = dsft
        data['spikes'] = [dat_to_spmat(p, size=(540, 960)).astype(np.float32) for p in s['spikes_paths']]

        data['images'] = [read_img_gray(p) for p in s['images_paths']]
        data['norm_fac'] = np.array(s['norm_fac'])
        return data

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        data = self._load_sample(self.samples[index])
        return data
251
+
252
class sreds_test_small(torch.utils.data.Dataset):
    """Reduced-resolution validation split of the SREDS spike dataset.

    Spike frames are decoded at 384 x 512. Expected layout under
    ``<data_root>/crop/val_small``::

        eta_{eta:.2f}_gamma_{gamma:d}_alpha_{alpha:.1f}/<scene>/spikes/000000NN.dat   NN = 11..28
        eta_{eta:.2f}_gamma_{gamma:d}_alpha_{alpha:.1f}/<scene>/dsft/000000NN.h5      NN = 11..28
        imgs/<scene>/000000NN.png                                                   NN = 18..21

    Every complete scene yields four samples, one per key frame 18, 19, 20, 21.
    """

    def __init__(self, args, eta):
        self.args = args
        self.input_type = args.input_type
        self.alpha = args.alpha
        self.eta = eta
        self.gamma = args.gamma
        self.dsft_path_name = 'dsft'
        self.spike_path_name = 'spikes'

        self.read_dsft = not args.no_dsft
        self.samples = self.collect_samples()
        print('The samples num of testing data: {:d}'.format(len(self.samples)))

    def confirm_exist(self, path_list_list):
        """Return 1 if every path exists; otherwise print the first missing one and return 0."""
        for pl in path_list_list:
            for p in pl:
                if not osp.exists(p):
                    print(p)
                    return 0
        return 1

    def collect_samples(self):
        """Build four key-frame samples for every complete scene.

        Raises:
            ValueError: If ``half_reserve`` pushes a window outside frames 11..28.
        """
        root_path = osp.join(self.args.data_root, 'crop', 'val_small')

        cur_eta_dir = osp.join(root_path, "eta_{:.2f}_gamma_{:d}_alpha_{:.1f}".format(self.eta, self.gamma, self.alpha))
        # Path lists cover frames 11..28, so key frame 18 + ii sits at index 7 + ii.
        # Windows span ``half`` frames on each side and must fit inside 0..17.
        half = 3 + self.args.half_reserve
        if not 0 <= half <= 7:
            raise ValueError('half_reserve must be in [-3, 4], got {}'.format(self.args.half_reserve))
        samples = []

        for scene in sorted(os.listdir(cur_eta_dir)):
            scene_path = osp.join(cur_eta_dir, scene)
            spike_dir = osp.join(scene_path, self.spike_path_name)
            image_dir = osp.join(root_path, 'imgs', scene)
            dsft_dir = osp.join(scene_path, self.dsft_path_name)

            spikes_path_list = [osp.join(spike_dir, '{:08d}.dat'.format(ii)) for ii in range(11, 28+1)]
            dsft_path_list = [osp.join(dsft_dir, '{:08d}.h5'.format(ii)) for ii in range(11, 28+1)]
            images_path_list = [osp.join(image_dir, '{:08d}.png'.format(ii)) for ii in range(18, 21+1)]

            if self.confirm_exist([spikes_path_list, images_path_list]):
                for ii in range(4):
                    lo, hi = 7 + ii - half, 7 + ii + half + 1
                    samples.append({
                        'spikes_paths': spikes_path_list[lo:hi],
                        'dsft_paths': dsft_path_list[lo:hi],
                        'images_paths': [images_path_list[ii]],
                        'norm_fac': self.alpha * self.eta,
                    })

        return samples

    def _load_sample(self, s):
        """Read the DSFT maps (unless disabled), spikes, GT image (one of frames 18..21) and norm factor."""
        data = {}
        if self.read_dsft:
            dsft = []
            for p in s['dsft_paths']:
                # ``with`` closes the file even if reading the dataset fails.
                with h5py.File(p, 'r') as f:
                    dsft.append(np.array(f['dsft']).astype(np.float32))
            data['dsft'] = dsft
        data['spikes'] = [dat_to_spmat(p, size=(384, 512)).astype(np.float32) for p in s['spikes_paths']]

        data['images'] = [read_img_gray(p) for p in s['images_paths']]
        data['norm_fac'] = np.array(s['norm_fac'])
        return data

    def __len__(self):
        return len(self.samples)

    def __getitem__(self, index):
        data = self._load_sample(self.samples[index])
        return data
@@ -0,0 +1,64 @@
1
+ import numpy as np
2
+ import cv2
3
+
4
def RawToSpike(video_seq, h, w, flipud=True):
    """Decode a bit-packed raw spike stream into a binary spike matrix.

    Each frame occupies ``h * w / 8`` bytes. Pixel ``k`` of a frame (row-major)
    is bit ``k % 8``, least-significant bit first, of byte ``k // 8``. Trailing
    bytes that do not form a complete frame are ignored.

    Args:
        video_seq: 1-D sequence/array of raw bytes; values are reinterpreted
            as ``uint8`` (so signed int8 input decodes correctly).
        h: Frame height.
        w: Frame width; ``h * w`` must be a positive multiple of 8.
        flipud: If True, flip every decoded frame vertically.

    Returns:
        ``uint8`` array of shape ``(num_frames, h, w)`` holding 0/1 values.

    Raises:
        ValueError: If ``h * w`` is not a positive multiple of 8.
    """
    n_pix = h * w
    if n_pix <= 0 or n_pix % 8:
        # Previously surfaced as an obscure IndexError / ZeroDivisionError.
        raise ValueError('h * w must be a positive multiple of 8, got {} x {}'.format(h, w))
    bytes_per_frame = n_pix // 8
    raw = np.array(video_seq).astype(np.uint8)
    img_num = len(raw) // bytes_per_frame
    SpikeMatrix = np.zeros([img_num, h, w], np.uint8)

    for img_id in range(img_num):
        start = img_id * bytes_per_frame
        # bitorder='little' maps bit k % 8 (LSB first) of byte k // 8 to pixel k.
        frame = np.unpackbits(raw[start:start + bytes_per_frame], bitorder='little').reshape(h, w)
        SpikeMatrix[img_id] = np.flipud(frame) if flipud else frame

    return SpikeMatrix
26
+
27
+
28
def SpikeToRaw(SpikeSeq, save_path):
    """Append a binary spike sequence to *save_path* in the bit-packed raw format.

    SpikeSeq: numpy array (sfn x h x w) of 0/1 values; h * w must be a
        multiple of 8.
    save_path: full saving path (string). The file is opened in append mode,
        so repeated calls concatenate frames onto any existing content.

    Each frame is flipped vertically (mimicking the camera's inverted image),
    flattened row-major and packed 8 pixels per byte, LSB first.
    Author: Rui Zhao
    """
    sfn, h, w = SpikeSeq.shape
    base = np.power(2, np.linspace(0, 7, 8))  # bit weights 1, 2, 4, ..., 128
    # ``with`` closes the handle even if encoding a frame raises.
    with open(save_path, 'ab') as fid:
        for img_id in range(sfn):
            # Flip vertically to mimic the camera's inverted image.
            spike = np.flipud(SpikeSeq[img_id, :, :])
            # NumPy flattens row-major, matching the on-disk row order.
            spike = spike.flatten().reshape([int(h * w / 8), 8])
            data = np.sum(spike * base, axis=1).astype(np.uint8)
            fid.write(data.tobytes())
50
+
51
+
52
def dat_to_spmat(dat_path, size=(720, 1280)):
    """Read a raw ``.dat`` spike file and decode it with ``RawToSpike``.

    Args:
        dat_path: Path to the bit-packed spike file.
        size: ``(height, width)`` of one frame; defaults to 720 x 1280.

    Returns:
        The ``(num_frames, height, width)`` array produced by ``RawToSpike``
        with its default ``flipud=True``.
    """
    # ``with`` closes the handle; the original left it open until GC.
    with open(dat_path, 'rb') as f:
        video_seq = np.frombuffer(f.read(), 'b')
    return RawToSpike(video_seq, size[0], size[1])
58
+
59
+
60
def read_img_gray(file_path):
    """Load an image as a single-channel float32 array scaled to [0, 1].

    The file is read as 8-bit BGR, scaled by 1/255 and converted to grayscale.

    Returns:
        ``float32`` array of shape ``(1, H, W)``.

    Raises:
        FileNotFoundError: If OpenCV cannot read *file_path*.
    """
    im = cv2.imread(file_path)
    if im is None:
        # cv2.imread reports failure by returning None rather than raising,
        # which previously surfaced as an opaque AttributeError on ``.astype``.
        raise FileNotFoundError('Could not read image: {}'.format(file_path))
    im = im.astype(np.float32) / 255.0
    im = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
    return np.expand_dims(im, axis=0)