spikezoo 0.1.1__py3-none-any.whl → 0.2__py3-none-any.whl

Files changed (192)
  1. spikezoo/__init__.py +13 -0
  2. spikezoo/archs/__pycache__/__init__.cpython-39.pyc +0 -0
  3. spikezoo/archs/base/__pycache__/nets.cpython-39.pyc +0 -0
  4. spikezoo/archs/base/nets.py +34 -0
  5. spikezoo/archs/bsf/README.md +92 -0
  6. spikezoo/archs/bsf/datasets/datasets.py +328 -0
  7. spikezoo/archs/bsf/datasets/ds_utils.py +64 -0
  8. spikezoo/archs/bsf/main.py +398 -0
  9. spikezoo/archs/bsf/metrics/psnr.py +22 -0
  10. spikezoo/archs/bsf/metrics/ssim.py +54 -0
  11. spikezoo/archs/bsf/models/bsf/__pycache__/align.cpython-39.pyc +0 -0
  12. spikezoo/archs/bsf/models/bsf/__pycache__/bsf.cpython-39.pyc +0 -0
  13. spikezoo/archs/bsf/models/bsf/__pycache__/rep.cpython-39.pyc +0 -0
  14. spikezoo/archs/bsf/models/bsf/align.py +154 -0
  15. spikezoo/archs/bsf/models/bsf/bsf.py +105 -0
  16. spikezoo/archs/bsf/models/bsf/dsft_convert.py +96 -0
  17. spikezoo/archs/bsf/models/bsf/rep.py +44 -0
  18. spikezoo/archs/bsf/models/get_model.py +7 -0
  19. spikezoo/archs/bsf/prepare_data/DSFT.py +62 -0
  20. spikezoo/archs/bsf/prepare_data/crop_dataset_train.py +135 -0
  21. spikezoo/archs/bsf/prepare_data/crop_dataset_val.py +139 -0
  22. spikezoo/archs/bsf/prepare_data/crop_train.sh +4 -0
  23. spikezoo/archs/bsf/prepare_data/crop_val.sh +4 -0
  24. spikezoo/archs/bsf/prepare_data/io_utils.py +64 -0
  25. spikezoo/archs/bsf/requirements.txt +9 -0
  26. spikezoo/archs/bsf/test.py +16 -0
  27. spikezoo/archs/bsf/utils.py +154 -0
  28. spikezoo/archs/spikeclip/__pycache__/nets.cpython-39.pyc +0 -0
  29. spikezoo/archs/spikeclip/nets.py +40 -0
  30. spikezoo/archs/spikeformer/CheckPoints/readme +1 -0
  31. spikezoo/archs/spikeformer/DataProcess/DataExtactor.py +60 -0
  32. spikezoo/archs/spikeformer/DataProcess/DataLoader.py +115 -0
  33. spikezoo/archs/spikeformer/DataProcess/LoadSpike.py +39 -0
  34. spikezoo/archs/spikeformer/EvalResults/readme +1 -0
  35. spikezoo/archs/spikeformer/LICENSE +21 -0
  36. spikezoo/archs/spikeformer/Metrics/Metrics.py +50 -0
  37. spikezoo/archs/spikeformer/Metrics/__init__.py +0 -0
  38. spikezoo/archs/spikeformer/Model/Loss.py +89 -0
  39. spikezoo/archs/spikeformer/Model/SpikeFormer.py +230 -0
  40. spikezoo/archs/spikeformer/Model/__init__.py +0 -0
  41. spikezoo/archs/spikeformer/Model/__pycache__/SpikeFormer.cpython-39.pyc +0 -0
  42. spikezoo/archs/spikeformer/Model/__pycache__/__init__.cpython-39.pyc +0 -0
  43. spikezoo/archs/spikeformer/README.md +30 -0
  44. spikezoo/archs/spikeformer/evaluate.py +87 -0
  45. spikezoo/archs/spikeformer/recon_real_data.py +97 -0
  46. spikezoo/archs/spikeformer/requirements.yml +95 -0
  47. spikezoo/archs/spikeformer/train.py +173 -0
  48. spikezoo/archs/spikeformer/utils.py +22 -0
  49. spikezoo/archs/spk2imgnet/.github/workflows/pylint.yml +23 -0
  50. spikezoo/archs/spk2imgnet/.gitignore +150 -0
  51. spikezoo/archs/spk2imgnet/DCNv2.py +135 -0
  52. spikezoo/archs/spk2imgnet/__pycache__/DCNv2.cpython-39.pyc +0 -0
  53. spikezoo/archs/spk2imgnet/__pycache__/align_arch.cpython-39.pyc +0 -0
  54. spikezoo/archs/spk2imgnet/__pycache__/nets.cpython-39.pyc +0 -0
  55. spikezoo/archs/spk2imgnet/align_arch.py +159 -0
  56. spikezoo/archs/spk2imgnet/dataset.py +144 -0
  57. spikezoo/archs/spk2imgnet/nets.py +230 -0
  58. spikezoo/archs/spk2imgnet/readme.md +86 -0
  59. spikezoo/archs/spk2imgnet/test_gen_imgseq.py +118 -0
  60. spikezoo/archs/spk2imgnet/train.py +189 -0
  61. spikezoo/archs/spk2imgnet/utils.py +64 -0
  62. spikezoo/archs/ssir/README.md +87 -0
  63. spikezoo/archs/ssir/configs/SSIR.yml +37 -0
  64. spikezoo/archs/ssir/configs/yml_parser.py +78 -0
  65. spikezoo/archs/ssir/datasets/dataset_sreds.py +170 -0
  66. spikezoo/archs/ssir/datasets/ds_utils.py +66 -0
  67. spikezoo/archs/ssir/losses.py +21 -0
  68. spikezoo/archs/ssir/main.py +326 -0
  69. spikezoo/archs/ssir/metrics/psnr.py +22 -0
  70. spikezoo/archs/ssir/metrics/ssim.py +54 -0
  71. spikezoo/archs/ssir/models/Vgg19.py +42 -0
  72. spikezoo/archs/ssir/models/__pycache__/layers.cpython-39.pyc +0 -0
  73. spikezoo/archs/ssir/models/__pycache__/networks.cpython-39.pyc +0 -0
  74. spikezoo/archs/ssir/models/layers.py +110 -0
  75. spikezoo/archs/ssir/models/networks.py +61 -0
  76. spikezoo/archs/ssir/requirements.txt +8 -0
  77. spikezoo/archs/ssir/shells/eval_SREDS.sh +6 -0
  78. spikezoo/archs/ssir/shells/train_SSIR.sh +12 -0
  79. spikezoo/archs/ssir/test.py +3 -0
  80. spikezoo/archs/ssir/utils.py +154 -0
  81. spikezoo/archs/ssml/__pycache__/cbam.cpython-39.pyc +0 -0
  82. spikezoo/archs/ssml/__pycache__/model.cpython-39.pyc +0 -0
  83. spikezoo/archs/ssml/cbam.py +224 -0
  84. spikezoo/archs/ssml/model.py +290 -0
  85. spikezoo/archs/ssml/res.png +0 -0
  86. spikezoo/archs/ssml/test.py +67 -0
  87. spikezoo/archs/stir/.git-credentials +0 -0
  88. spikezoo/archs/stir/README.md +65 -0
  89. spikezoo/archs/stir/ckpt_outputs/Descriptions.txt +1 -0
  90. spikezoo/archs/stir/configs/STIR.yml +37 -0
  91. spikezoo/archs/stir/configs/utils.py +155 -0
  92. spikezoo/archs/stir/configs/yml_parser.py +78 -0
  93. spikezoo/archs/stir/datasets/dataset_sreds.py +180 -0
  94. spikezoo/archs/stir/datasets/ds_utils.py +66 -0
  95. spikezoo/archs/stir/eval_SREDS.sh +5 -0
  96. spikezoo/archs/stir/main.py +397 -0
  97. spikezoo/archs/stir/metrics/losses.py +219 -0
  98. spikezoo/archs/stir/metrics/psnr.py +22 -0
  99. spikezoo/archs/stir/metrics/ssim.py +54 -0
  100. spikezoo/archs/stir/models/Vgg19.py +42 -0
  101. spikezoo/archs/stir/models/__pycache__/networks_STIR.cpython-39.pyc +0 -0
  102. spikezoo/archs/stir/models/__pycache__/submodules.cpython-39.pyc +0 -0
  103. spikezoo/archs/stir/models/__pycache__/transformer_new.cpython-39.pyc +0 -0
  104. spikezoo/archs/stir/models/networks_STIR.py +361 -0
  105. spikezoo/archs/stir/models/submodules.py +86 -0
  106. spikezoo/archs/stir/models/transformer_new.py +151 -0
  107. spikezoo/archs/stir/package_core/build/lib/package_core/__init__.py +0 -0
  108. spikezoo/archs/stir/package_core/build/lib/package_core/convertions.py +721 -0
  109. spikezoo/archs/stir/package_core/build/lib/package_core/disp_netS.py +133 -0
  110. spikezoo/archs/stir/package_core/build/lib/package_core/flow_utils.py +167 -0
  111. spikezoo/archs/stir/package_core/build/lib/package_core/generic_train_test.py +76 -0
  112. spikezoo/archs/stir/package_core/build/lib/package_core/geometry.py +458 -0
  113. spikezoo/archs/stir/package_core/build/lib/package_core/image_proc.py +183 -0
  114. spikezoo/archs/stir/package_core/build/lib/package_core/linalg.py +40 -0
  115. spikezoo/archs/stir/package_core/build/lib/package_core/losses.py +198 -0
  116. spikezoo/archs/stir/package_core/build/lib/package_core/metrics.py +51 -0
  117. spikezoo/archs/stir/package_core/build/lib/package_core/model_base.py +53 -0
  118. spikezoo/archs/stir/package_core/build/lib/package_core/net_basics.py +100 -0
  119. spikezoo/archs/stir/package_core/build/lib/package_core/resnet.py +333 -0
  120. spikezoo/archs/stir/package_core/build/lib/package_core/transforms.py +123 -0
  121. spikezoo/archs/stir/package_core/build/lib/package_core/utils.py +72 -0
  122. spikezoo/archs/stir/package_core/dist/package_core-0.0.0-py3.9.egg +0 -0
  123. spikezoo/archs/stir/package_core/package_core/__init__.py +0 -0
  124. spikezoo/archs/stir/package_core/package_core/__pycache__/__init__.cpython-39.pyc +0 -0
  125. spikezoo/archs/stir/package_core/package_core/__pycache__/net_basics.cpython-39.pyc +0 -0
  126. spikezoo/archs/stir/package_core/package_core/convertions.py +721 -0
  127. spikezoo/archs/stir/package_core/package_core/disp_netS.py +133 -0
  128. spikezoo/archs/stir/package_core/package_core/flow_utils.py +167 -0
  129. spikezoo/archs/stir/package_core/package_core/generic_train_test.py +76 -0
  130. spikezoo/archs/stir/package_core/package_core/geometry.py +458 -0
  131. spikezoo/archs/stir/package_core/package_core/image_proc.py +183 -0
  132. spikezoo/archs/stir/package_core/package_core/linalg.py +40 -0
  133. spikezoo/archs/stir/package_core/package_core/losses.py +198 -0
  134. spikezoo/archs/stir/package_core/package_core/metrics.py +51 -0
  135. spikezoo/archs/stir/package_core/package_core/model_base.py +53 -0
  136. spikezoo/archs/stir/package_core/package_core/net_basics.py +100 -0
  137. spikezoo/archs/stir/package_core/package_core/resnet.py +333 -0
  138. spikezoo/archs/stir/package_core/package_core/transforms.py +123 -0
  139. spikezoo/archs/stir/package_core/package_core/utils.py +72 -0
  140. spikezoo/archs/stir/package_core/package_core.egg-info/PKG-INFO +3 -0
  141. spikezoo/archs/stir/package_core/package_core.egg-info/SOURCES.txt +20 -0
  142. spikezoo/archs/stir/package_core/package_core.egg-info/dependency_links.txt +1 -0
  143. spikezoo/archs/stir/package_core/package_core.egg-info/top_level.txt +1 -0
  144. spikezoo/archs/stir/package_core/setup.py +5 -0
  145. spikezoo/archs/stir/requirements.txt +12 -0
  146. spikezoo/archs/stir/train_STIR.sh +9 -0
  147. spikezoo/archs/tfi/__pycache__/nets.cpython-39.pyc +0 -0
  148. spikezoo/archs/tfi/nets.py +43 -0
  149. spikezoo/archs/tfp/__pycache__/nets.cpython-39.pyc +0 -0
  150. spikezoo/archs/tfp/nets.py +13 -0
  151. spikezoo/archs/wgse/README.md +64 -0
  152. spikezoo/archs/wgse/__pycache__/dwtnets.cpython-39.pyc +0 -0
  153. spikezoo/archs/wgse/__pycache__/submodules.cpython-39.pyc +0 -0
  154. spikezoo/archs/wgse/dataset.py +59 -0
  155. spikezoo/archs/wgse/demo.png +0 -0
  156. spikezoo/archs/wgse/demo.py +83 -0
  157. spikezoo/archs/wgse/dwtnets.py +145 -0
  158. spikezoo/archs/wgse/eval.py +133 -0
  159. spikezoo/archs/wgse/logs/WGSE-Dwt1dNet-db8-5-ks3/log.txt +11 -0
  160. spikezoo/archs/wgse/submodules.py +68 -0
  161. spikezoo/archs/wgse/train.py +261 -0
  162. spikezoo/archs/wgse/transform.py +139 -0
  163. spikezoo/archs/wgse/utils.py +128 -0
  164. spikezoo/archs/wgse/weights/demo.png +0 -0
  165. spikezoo/data/base/test/gt/200_part1_key_id151.png +0 -0
  166. spikezoo/data/base/test/gt/200_part3_key_id151.png +0 -0
  167. spikezoo/data/base/test/gt/203_part1_key_id151.png +0 -0
  168. spikezoo/data/base/test/spike/200_part1_key_id151.dat +0 -0
  169. spikezoo/data/base/test/spike/200_part3_key_id151.dat +0 -0
  170. spikezoo/data/base/test/spike/203_part1_key_id151.dat +0 -0
  171. spikezoo/data/base/train/gt/203_part2_key_id151.png +0 -0
  172. spikezoo/data/base/train/gt/203_part3_key_id151.png +0 -0
  173. spikezoo/data/base/train/gt/203_part4_key_id151.png +0 -0
  174. spikezoo/data/base/train/spike/203_part2_key_id151.dat +0 -0
  175. spikezoo/data/base/train/spike/203_part3_key_id151.dat +0 -0
  176. spikezoo/data/base/train/spike/203_part4_key_id151.dat +0 -0
  177. spikezoo/datasets/base_dataset.py +2 -3
  178. spikezoo/metrics/__init__.py +1 -1
  179. spikezoo/models/base_model.py +1 -3
  180. spikezoo/pipeline/base_pipeline.py +7 -5
  181. spikezoo/pipeline/train_pipeline.py +1 -1
  182. spikezoo/utils/other_utils.py +16 -6
  183. spikezoo/utils/spike_utils.py +33 -29
  184. spikezoo/utils/vidar_loader.cpython-39-x86_64-linux-gnu.so +0 -0
  185. spikezoo-0.2.dist-info/METADATA +163 -0
  186. spikezoo-0.2.dist-info/RECORD +211 -0
  187. spikezoo/models/spcsnet_model.py +0 -19
  188. spikezoo-0.1.1.dist-info/METADATA +0 -39
  189. spikezoo-0.1.1.dist-info/RECORD +0 -36
  190. {spikezoo-0.1.1.dist-info → spikezoo-0.2.dist-info}/LICENSE.txt +0 -0
  191. {spikezoo-0.1.1.dist-info → spikezoo-0.2.dist-info}/WHEEL +0 -0
  192. {spikezoo-0.1.1.dist-info → spikezoo-0.2.dist-info}/top_level.txt +0 -0
spikezoo/archs/ssml/test.py
@@ -0,0 +1,67 @@
+ import torch
+ import torch.nn as nn
+ import numpy as np
+ from model import DoubleNet
+ import cv2
+
+ def load_vidar_dat(filename, left_up=(0, 0), window=None, frame_cnt=None, **kwargs):
+     """Decode a Vidar .dat spike stream into a (frames, H, W) float32 tensor."""
+     if isinstance(filename, str):
+         array = np.fromfile(filename, dtype=np.uint8)
+     elif isinstance(filename, (list, tuple)):
+         l = []
+         for name in filename:
+             a = np.fromfile(name, dtype=np.uint8)
+             l.append(a)
+         array = np.concatenate(l)
+     else:
+         raise NotImplementedError
+
+     height = 250
+     width = 400
+
+     if window is None:
+         # default to the full frame below and to the right of left_up
+         window = (height - left_up[0], width - left_up[1])
+
+     len_per_frame = height * width // 8
+     framecnt = frame_cnt if frame_cnt is not None else len(array) // len_per_frame
+
+     spikes = []
+
+     for i in range(framecnt):
+         compr_frame = array[i * len_per_frame: (i + 1) * len_per_frame]
+         blist = []
+         for b in range(8):
+             # extract bit b of every byte; pixels are packed LSB-first
+             blist.append(np.right_shift(np.bitwise_and(compr_frame, np.left_shift(1, b)), b))
+
+         frame_ = np.stack(blist).transpose()
+         frame_ = np.flipud(frame_.reshape((height, width), order='C'))
+
+         spk = frame_[left_up[0]:left_up[0] + window[0], left_up[1]:left_up[1] + window[1]]
+         spk = torch.from_numpy(spk.copy().astype(np.float32)).unsqueeze(dim=0)
+
+         spikes.append(spk)
+
+     return torch.cat(spikes)
+
+ if __name__ == '__main__':
+     model = DoubleNet()
+     model_path = "./fin3g-best-lucky.pt"
+     model = nn.DataParallel(model)
+     model.load_state_dict(torch.load(model_path))
+     model = model.cuda()
+
+     spike_path = "./rotation1.dat"
+     spike = load_vidar_dat(spike_path)[200:200 + 41].unsqueeze(0).cuda()
+
+     res = model(spike)
+     res = res[0].detach().cpu().permute(1, 2, 0).numpy() * 255
+     res_path = "./res.png"
+     cv2.imwrite(res_path, res)
+
+     print("done.")
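The per-bit loop above is correct but slow in pure Python. As a minimal sketch (assuming NumPy ≥ 1.17 for the `bitorder` argument; `unpack_frame` is a name introduced here for illustration), the same per-frame unpacking can be vectorized with `np.unpackbits`:

```python
import numpy as np

def unpack_frame(compr_frame: np.ndarray, height: int = 250, width: int = 400) -> np.ndarray:
    """Vectorized equivalent of the bit loop in load_vidar_dat for one frame."""
    bits = np.unpackbits(compr_frame, bitorder="little")  # LSB-first, matching the 1 << b masks
    return np.flipud(bits.reshape(height, width))         # same vertical flip as above
```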
spikezoo/archs/stir/.git-credentials — file without changes
spikezoo/archs/stir/README.md
@@ -0,0 +1,65 @@
+ <!---
+ # Spatio-Temporal Interactive Learning for Efficient Image Reconstruction of Spiking Cameras
+
+ This repository contains the source code for the paper: [Spatio-Temporal Interactive Learning for Efficient Image Reconstruction of Spiking Cameras (NeurIPS 2024)](https://openreview.net/pdf?id=S4ZqnMywcM).
+ The spiking camera is an emerging neuromorphic vision sensor that records high-speed motion scenes by asynchronously firing continuous binary spike streams. Prevailing image reconstruction methods, generating intermediate frames from these spike streams, often rely on complex step-by-step network architectures that overlook the intrinsic collaboration of spatio-temporal complementary information. In this paper, we propose an efficient spatio-temporal interactive reconstruction network to jointly perform inter-frame feature alignment and intra-frame feature filtering in a coarse-to-fine manner. Specifically, it starts by extracting hierarchical features from a concise hybrid spike representation, then refines the motion fields and target frames scale-by-scale, ultimately obtaining a full-resolution output. Meanwhile, we introduce a symmetric interactive attention block and a multi-motion field estimation block to further enhance the interaction capability of the overall network. Experiments on synthetic and real-captured data show that our approach exhibits excellent performance while maintaining low model complexity.
+
+ <img src="picture/performance-speed.png" width="75%"/>
+ <img src="picture/overview.png" width="80%"/>
+ <img src="picture/results_visual.png" width="82%"/>
+ -->
+ ## Installation
+ Choose a cudatoolkit version that matches your server; the code was tested with PyTorch 1.9.1 and CUDA 11.4.
+
+ ```shell
+ conda create -n stir python=3.8.12
+ conda activate stir
+ # You can choose the PyTorch version you like, for example
+ pip install torch==2.2.2 torchvision==0.17.2 torchaudio==2.2.2
+ ```
+
+ Install the dependent packages:
+ ```
+ pip install -r requirements.txt
+ ```
+
+ Install the core package:
+ ```
+ cd ./package_core
+ python setup.py install
+ ```
+
+ Our implementation borrows the code framework of [SSIR](https://github.com/ruizhao26/SSIR).
+
+ ## Prepare the Data
+
+ #### 1. Download and deploy the SREDS dataset to your local computer from [SSIR](https://github.com/ruizhao26/SSIR).
+
+ #### 2. Set the path of the SREDS dataset on your server
+
+ Pass it via `--data_root` when running train_STIR.sh or eval_SREDS.sh.
+
+ ## Evaluate
+ ```
+ sh eval_SREDS.sh
+ ```
+
+ ## Train
+ ```
+ sh train_STIR.sh
+ ```
+ <!---
+ ## Citations
+ If you find our approach useful in your research or applications, please consider giving us a star 🌟 and citing it with the following BibTeX entry.
+ ```
+ @article{fan2024spatio,
+   title={Spatio-Temporal Interactive Learning for Efficient Image Reconstruction of Spiking Cameras},
+   author={Fan, Bin and Yin, Jiaoyang and Dai, Yuchao and Xu, Chao and Huang, Tiejun and Shi, Boxin},
+   journal={Proceedings of the Advances in Neural Information Processing Systems (NeurIPS)},
+   year={2024}
+ }
+ ```
+ -->
+ ## Statement
+ This project is for research purposes only; please contact us for a license for commercial use. For any other questions or discussion, please contact: binfan@mail.nwpu.edu.cn
spikezoo/archs/stir/ckpt_outputs/Descriptions.txt
@@ -0,0 +1 @@
+ This folder is used to store the trained model.
spikezoo/archs/stir/configs/STIR.yml
@@ -0,0 +1,37 @@
+ data:
+   interp: 20
+   alpha: 0.4
+
+ seed: 6666
+
+ loader:
+   # crop_size: [128, 128]
+   crop_size: [96, 96]
+   pair_step: 4
+
+ model:
+   arch: 'STIR'
+   seq_len: 8
+   flow_weight_decay: 0.0004
+   flow_bias_decay: 0.0
+   #########################
+   kwargs:
+     activation_type: 'lif'
+     mp_activation_type: 'amp_lif'
+     spike_connection: 'concat'
+     num_encoders: 3
+     num_resblocks: 1
+     v_threshold: 1.0
+     v_reset: None
+     tau: 2.0
+
+ train:
+   print_freq: 1
+   mixed_precision: True
+   vis_freq: 20
+
+ optimizer:
+   solver: Adam
+   momentum: 0.9
+   beta: 0.999
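One caveat about the `v_reset: None` entry above: PyYAML does not treat the bare token `None` as a null value, so it is loaded as the string `'None'`. A quick check (standard PyYAML behavior, shown here as a hedged aside rather than anything this config documents):

```python
import yaml

print(yaml.safe_load("v_reset: None"))  # {'v_reset': 'None'} -- a string, not Python None
print(yaml.safe_load("v_reset: null"))  # {'v_reset': None}   -- use null or ~ for a real None
```

Downstream code that expects an actual `None` therefore has to handle the string form, or the YAML should say `null`.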
spikezoo/archs/stir/configs/utils.py
@@ -0,0 +1,155 @@
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ import os
+ import os.path as osp
+ import random
+ import cv2
+
+ def set_seeds(_seed_):
+     random.seed(_seed_)
+     np.random.seed(_seed_)
+     torch.manual_seed(_seed_)  # seeds the RNG for all devices (both CPU and CUDA)
+     torch.cuda.manual_seed_all(_seed_)
+
+     torch.backends.cudnn.deterministic = True
+     torch.backends.cudnn.benchmark = False
+     # For full determinism, additionally set CUBLAS_WORKSPACE_CONFIG to ":16:8"
+     # (may limit overall performance) or ":4096:8" (adds roughly 24 MiB of GPU
+     # memory), and enable torch.use_deterministic_algorithms(True).
+
+
+ def make_dir(path):
+     if not osp.exists(path):
+         os.makedirs(path)
+     return
+
+
+ def add_args_to_cfg(cfg, args, args_list):
+     for aa in args_list:
+         cfg['train'][aa] = getattr(args, aa)
+     return cfg
+
+
+ class AverageMeter(object):
+     """Computes and stores the average and current value of one or more meters."""
+
+     def __init__(self, i=1, precision=3, names=None):
+         self.meters = i
+         self.precision = precision
+         self.reset(self.meters)
+         self.names = names
+         if names is not None:
+             assert self.meters == len(self.names)
+         else:
+             self.names = [''] * self.meters
+
+     def reset(self, i):
+         self.val = [0] * i
+         self.avg = [0] * i
+         self.sum = [0] * i
+         self.count = [0] * i
+
+     def update(self, val, n=1):
+         if not isinstance(val, list):
+             val = [val]
+         if not isinstance(n, list):
+             n = [n] * self.meters
+         assert (len(val) == self.meters and len(n) == self.meters)
+         for i in range(self.meters):
+             self.count[i] += n[i]
+         for i, v in enumerate(val):
+             self.val[i] = v
+             self.sum[i] += v * n[i]
+             self.avg[i] = self.sum[i] / self.count[i]
+
+     def __repr__(self):
+         return ' '.join(['{} {:.{}f} ({:.{}f})'.format(n, v, self.precision, a, self.precision)
+                          for n, v, a in zip(self.names, self.val, self.avg)])
+
+
+ def normalize_image_torch(image, percentile_lower=1, percentile_upper=99):
+     """Percentile-normalize a (B, C, H, W) image batch into [0, 1]."""
+     b, c, h, w = image.shape
+     image_reshape = image.reshape([b, c, h * w])
+     mini = torch.quantile(image_reshape, percentile_lower / 100.0, dim=2, keepdim=True).unsqueeze(dim=3)
+     maxi = torch.quantile(image_reshape, percentile_upper / 100.0, dim=2, keepdim=True).unsqueeze(dim=3)
+     return torch.clip((image - mini) / (maxi - mini + 1e-5), 0, 1)
+
+ def normalize_image_torch2(image):
+     return torch.clip(image, 0, 1)
+
+ # --------------------------------------------
+ # Torch to Numpy 0~255
+ # --------------------------------------------
+ def torch2numpy255(im):
+     im = im[0, 0].detach().cpu().numpy()
+     im = (im * 255).astype(np.float64)
+     return im
+
+ def torch2torch255(im):
+     return im * 255.0
+
+ class InputPadder:
+     """Pads images such that dimensions are divisible by padsize."""
+     def __init__(self, dims, padsize=16):
+         self.ht, self.wd = dims[-2:]
+         pad_ht = (((self.ht // padsize) + 1) * padsize - self.ht) % padsize
+         pad_wd = (((self.wd // padsize) + 1) * padsize - self.wd) % padsize
+         self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, pad_ht // 2, pad_ht - pad_ht // 2]
+
+     def pad(self, *inputs):
+         return [F.pad(x, self._pad, mode='replicate') for x in inputs]
+
+     def unpad(self, x):
+         ht, wd = x.shape[-2:]
+         c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
+         return x[..., c[0]:c[1], c[2]:c[3]]
+
+
+ def vis_img(vis_path: str, img: torch.Tensor, vis_name: str = 'vis'):
+     """Tile a batch of single-channel images into a 4-row grid and save it as a PNG."""
+     ww = 0
+     rows = []
+     for ii in range(4):
+         cur_row = []
+         for jj in range(img.shape[0] // 4):
+             cur_img = img[ww, 0].detach().cpu().numpy() * 255
+             cur_img = cur_img.astype(np.uint8)
+             cur_row.append(cur_img)
+             ww += 1
+         cur_row_cat = np.concatenate(cur_row, axis=1)
+         rows.append(cur_row_cat)
+     out_img = np.concatenate(rows, axis=0)
+     cv2.imwrite(osp.join(vis_path, vis_name + '.png'), out_img)
+     return
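A quick usage sketch for `InputPadder` (the 250×400 resolution matches the Vidar frame size used elsewhere in this package; the network call is a hypothetical stand-in):

```python
import torch

x = torch.randn(1, 1, 250, 400)            # e.g. one normalized spike frame
padder = InputPadder(x.shape, padsize=16)
x_pad, = padder.pad(x)                     # -> (1, 1, 256, 400), replicate-padded
# y_pad = net(x_pad)                       # hypothetical model call
y = padder.unpad(x_pad)                    # crop back to (1, 1, 250, 400)
assert y.shape == x.shape
```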
spikezoo/archs/stir/configs/yml_parser.py
@@ -0,0 +1,78 @@
+ import numpy as np
+ import yaml
+
+
+ class YAMLParser:
+     """
+     Modified from the tudelft ssl-evflow code.
+     """
+
+     def __init__(self, config):
+         self.reset_config()
+         self.parse_config(config)
+
+     def parse_config(self, file):
+         with open(file) as fid:
+             yaml_config = yaml.load(fid, Loader=yaml.FullLoader)
+         self.parse_dict(yaml_config)
+
+     @property
+     def config(self):
+         return self._config
+
+     @property
+     def device(self):
+         return self._device
+
+     @property
+     def loader_kwargs(self):
+         return self._loader_kwargs
+
+     def reset_config(self):
+         self._config = {}
+
+     def update(self, config):
+         self.reset_config()
+         self.parse_config(config)
+
+     def parse_dict(self, input_dict, parent=None):
+         if parent is None:
+             parent = self._config
+         for key, val in input_dict.items():
+             if isinstance(val, dict):
+                 if key not in parent.keys():
+                     parent[key] = {}
+                 self.parse_dict(val, parent[key])
+             else:
+                 parent[key] = val
+
+     @staticmethod
+     def worker_init_fn(worker_id):
+         np.random.seed(np.random.get_state()[1][0] + worker_id)
+
+     def merge_configs(self, run):
+         """
+         Overwrites mlflow metadata with configs.
+         """
+         # parse mlflow settings
+         config = {}
+         for key in run.keys():
+             if len(run[key]) > 0 and run[key][0] == "{":  # assume dictionary
+                 config[key] = eval(run[key])
+             else:  # string
+                 config[key] = run[key]
+
+         # overwrite with config settings
+         self.parse_dict(self._config, config)
+         self.combine_entries(config)
+
+         return config
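A minimal sketch of how this parser pairs with the STIR.yml shown earlier (the config path is a placeholder):

```python
from configs.yml_parser import YAMLParser

parser = YAMLParser("configs/STIR.yml")
cfg = parser.config                   # plain nested dict built by parse_dict
print(cfg["model"]["arch"])           # 'STIR'
print(cfg["loader"]["crop_size"])     # [96, 96]
```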
spikezoo/archs/stir/datasets/dataset_sreds.py
@@ -0,0 +1,180 @@
+ import os
+ import os.path as osp
+ import random
+ import numpy as np
+ import torch
+ import torch.utils.data as data
+ from datasets.ds_utils import *
+ import time
+
+
+ class Augmentor:
+     def __init__(self, crop_size):
+         # spatial augmentation params
+         self.crop_size = crop_size
+
+     def augment_img(self, img, mode=0):
+         '''Kai Zhang (github: https://github.com/cszn)
+         W x H x C or W x H
+         '''
+         if mode == 0:
+             return img
+         elif mode == 1:
+             return np.flipud(np.rot90(img))
+         elif mode == 2:
+             return np.flipud(img)
+         elif mode == 3:
+             return np.rot90(img, k=3)
+         elif mode == 4:
+             return np.flipud(np.rot90(img, k=2))
+         elif mode == 5:
+             return np.rot90(img)
+         elif mode == 6:
+             return np.rot90(img, k=2)
+         elif mode == 7:
+             return np.flipud(np.rot90(img, k=3))
+
+     def spatial_transform(self, spk_list, img_list):
+         # apply the same random flip/rotation to every spike volume and image
+         mode = random.randint(0, 7)
+
+         for ii, spk in enumerate(spk_list):
+             spk = np.transpose(spk, [1, 2, 0])
+             spk = self.augment_img(spk, mode=mode)
+             spk_list[ii] = np.transpose(spk, [2, 0, 1])
+
+         for ii, img in enumerate(img_list):
+             img = np.transpose(img, [1, 2, 0])
+             img = self.augment_img(img, mode=mode)
+             img_list[ii] = np.transpose(img, [2, 0, 1])
+
+         return spk_list, img_list
+
+     def __call__(self, spk_list, img_list):
+         spk_list, img_list = self.spatial_transform(spk_list, img_list)
+         spk_list = [np.ascontiguousarray(spk) for spk in spk_list]
+         img_list = [np.ascontiguousarray(img) for img in img_list]
+         return spk_list, img_list
+
+
+ class sreds_train(torch.utils.data.Dataset):
+     def __init__(self, cfg):
+         self.cfg = cfg
+         self.pair_step = self.cfg['loader']['pair_step']
+         self.augmentor = Augmentor(crop_size=self.cfg['loader']['crop_size'])
+         self.samples = self.collect_samples()
+         print('The number of training samples: {:d}'.format(len(self.samples)))
+
+     def confirm_exist(self, path_list_list):
+         for pl in path_list_list:
+             for p in pl:
+                 if not osp.exists(p):
+                     return False
+         return True
+
+     def collect_samples(self):
+         spike_path = osp.join(self.cfg['data']['root'], 'crop_mini', 'spike', 'train',
+                               'interp_{:d}_alpha_{:.2f}'.format(self.cfg['data']['interp'], self.cfg['data']['alpha']))
+         image_path = osp.join(self.cfg['data']['root'], 'crop_mini', 'image', 'train', 'train_orig')
+         scene_list = sorted(os.listdir(spike_path))
+         samples = []
+
+         for scene in scene_list:
+             spike_dir = osp.join(spike_path, scene)
+             image_dir = osp.join(image_path, scene)
+             spk_path_list = sorted(os.listdir(spike_dir))
+
+             spklen = len(spk_path_list)
+
+             # read in filename order
+             spikes_path_list = [osp.join(spike_dir, spk_path_list[ii]) for ii in range(spklen)]
+             images_path_list = [osp.join(image_dir, spk_path_list[ii][:-4] + '.png') for ii in range(spklen)]
+
+             if self.confirm_exist([spikes_path_list, images_path_list]):
+                 s = {}
+                 s['spikes_paths'] = spikes_path_list
+                 s['images_paths'] = images_path_list
+                 samples.append(s)
+
+         return samples
+
+     def _load_sample(self, s):
+         data = {}
+         data['spikes'] = [np.array(dat_to_spmat(p, size=(96, 96)), dtype=np.float32) for p in s['spikes_paths']]
+         data['images'] = [read_img_gray(p) for p in s['images_paths']]
+         data['spikes'], data['images'] = self.augmentor(data['spikes'], data['images'])
+         return data
+
+     def __len__(self):
+         return len(self.samples)
+
+     def __getitem__(self, index):
+         data = self._load_sample(self.samples[index])
+         return data
+
+
+ class sreds_test(torch.utils.data.Dataset):
+     def __init__(self, cfg):
+         self.cfg = cfg
+         self.samples = self.collect_samples()
+         print('The number of testing samples: {:d}'.format(len(self.samples)))
+
+     def confirm_exist(self, path_list_list):
+         for pl in path_list_list:
+             for p in pl:
+                 if not osp.exists(p):
+                     return False
+         return True
+
+     def collect_samples(self):
+         spike_path = osp.join(self.cfg['data']['root'], 'spike', 'val',
+                               'interp_{:d}_alpha_{:.2f}'.format(self.cfg['data']['interp'], self.cfg['data']['alpha']))
+         image_path = osp.join(self.cfg['data']['root'], 'imgs', 'val', 'val_orig')
+         scene_list = sorted(os.listdir(spike_path))
+         samples = []
+
+         for scene in scene_list:
+             spike_dir = osp.join(spike_path, scene)
+             image_dir = osp.join(image_path, scene)
+             spk_path_list = sorted(os.listdir(spike_dir))
+
+             spklen = len(spk_path_list)
+
+             # read in filename order
+             spikes_path_list = [osp.join(spike_dir, spk_path_list[ii]) for ii in range(spklen)]
+             images_path_list = [osp.join(image_dir, spk_path_list[ii][:-4] + '.png') for ii in range(spklen)]
+
+             if self.confirm_exist([spikes_path_list, images_path_list]):
+                 s = {}
+                 s['spikes_paths'] = spikes_path_list
+                 s['images_paths'] = images_path_list
+                 samples.append(s)
+
+         return samples
+
+     def _load_sample(self, s):
+         data = {}
+         data['spikes'] = [np.array(dat_to_spmat(p, size=(720, 1280)), dtype=np.float32) for p in s['spikes_paths']]
+         data['images'] = [read_img_gray(p) for p in s['images_paths']]
+         return data
+
+     def __len__(self):
+         return len(self.samples)
+
+     def __getitem__(self, index):
+         data = self._load_sample(self.samples[index])
+         return data
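A minimal wiring sketch for the dataset above (the `cfg` dict comes from YAMLParser; the dataset root is a placeholder). `batch_size=1` is the safe choice here because each sample holds one whole scene, and the number of spike files can differ between scenes:

```python
from torch.utils.data import DataLoader

cfg["data"]["root"] = "/path/to/REDS120fps"   # placeholder dataset root
train_set = sreds_train(cfg)
train_loader = DataLoader(train_set, batch_size=1, shuffle=True, num_workers=4)

sample = next(iter(train_loader))
spikes, images = sample["spikes"], sample["images"]  # lists of tensors, one per .dat file
```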
spikezoo/archs/stir/datasets/ds_utils.py
@@ -0,0 +1,66 @@
+ import numpy as np
+ import os
+ import cv2
+ import os.path as osp
+
+ def RawToSpike(video_seq, h, w, flipud=True):
+     video_seq = np.array(video_seq).astype(np.uint8)
+     img_size = h * w
+     img_num = len(video_seq) // (img_size // 8)
+     SpikeMatrix = np.zeros([img_num, h, w], np.uint8)
+     pix_id = np.arange(0, h * w)
+     pix_id = np.reshape(pix_id, (h, w))
+     comparator = np.left_shift(1, np.mod(pix_id, 8))
+     byte_id = pix_id // 8
+
+     for img_id in np.arange(img_num):
+         id_start = img_id * img_size // 8
+         id_end = id_start + img_size // 8
+         cur_info = video_seq[id_start:id_end]
+         data = cur_info[byte_id]
+         result = np.bitwise_and(data, comparator)
+         if flipud:
+             SpikeMatrix[img_id, :, :] = np.flipud((result == comparator))
+         else:
+             SpikeMatrix[img_id, :, :] = (result == comparator)
+
+     return SpikeMatrix
+
+
+ def SpikeToRaw(SpikeSeq, save_path):
+     """
+     SpikeSeq: Numpy array (sfn x h x w)
+     save_path: full saving path (string)
+     Rui Zhao
+     """
+     sfn, h, w = SpikeSeq.shape
+     base = np.power(2, np.linspace(0, 7, 8))
+     fid = open(save_path, 'ab')
+     for img_id in range(sfn):
+         # flip vertically to mimic the camera's inverted image
+         spike = np.flipud(SpikeSeq[img_id, :, :])
+         # numpy flattens row-major, matching how the data is stored
+         spike = spike.flatten()
+         spike = spike.reshape([int(h * w / 8), 8])
+         data = spike * base
+         data = np.sum(data, axis=1).astype(np.uint8)
+         fid.write(data.tobytes())
+
+     fid.close()
+     return
+
+
+ def dat_to_spmat(dat_path, size=[720, 1280]):
+     with open(dat_path, 'rb') as f:
+         video_seq = f.read()
+     video_seq = np.frombuffer(video_seq, 'b')
+     sp_mat = RawToSpike(video_seq, size[0], size[1])
+     return sp_mat
+
+
+ def read_img_gray(file_path):
+     im = cv2.imread(file_path).astype(np.float32) / 255.0
+     im = cv2.cvtColor(im, cv2.COLOR_BGR2GRAY)
+     im = np.expand_dims(im, axis=0)
+     return im
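A quick round-trip sketch for the two codecs above (file names are placeholders; note that SpikeToRaw opens the output in append mode, so re-running it accumulates frames):

```python
spikes = dat_to_spmat("scene_0000.dat", size=(720, 1280))  # (frames, 720, 1280) array of 0/1
print(spikes.shape, spikes.dtype)
SpikeToRaw(spikes, "scene_0000_copy.dat")  # inverse of RawToSpike: re-packs 8 pixels per byte
```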
spikezoo/archs/stir/eval_SREDS.sh
@@ -0,0 +1,5 @@
+ python3 main.py \
+     --data_root /data/local_userdata/fanbin/REDS_dataset/REDS120fps \
+     --arch STIR \
+     --pretrained ./ckpt/STIR_pretrain.pth \
+     --eval