spikezoo 0.1.2__py3-none-any.whl → 0.2.1__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (192) hide show
  1. spikezoo/__init__.py +13 -0
  2. spikezoo/archs/__pycache__/__init__.cpython-39.pyc +0 -0
  3. spikezoo/archs/base/__pycache__/nets.cpython-39.pyc +0 -0
  4. spikezoo/archs/base/nets.py +34 -0
  5. spikezoo/archs/bsf/README.md +92 -0
  6. spikezoo/archs/bsf/datasets/datasets.py +328 -0
  7. spikezoo/archs/bsf/datasets/ds_utils.py +64 -0
  8. spikezoo/archs/bsf/main.py +398 -0
  9. spikezoo/archs/bsf/metrics/psnr.py +22 -0
  10. spikezoo/archs/bsf/metrics/ssim.py +54 -0
  11. spikezoo/archs/bsf/models/bsf/__pycache__/align.cpython-39.pyc +0 -0
  12. spikezoo/archs/bsf/models/bsf/__pycache__/bsf.cpython-39.pyc +0 -0
  13. spikezoo/archs/bsf/models/bsf/__pycache__/rep.cpython-39.pyc +0 -0
  14. spikezoo/archs/bsf/models/bsf/align.py +154 -0
  15. spikezoo/archs/bsf/models/bsf/bsf.py +105 -0
  16. spikezoo/archs/bsf/models/bsf/dsft_convert.py +96 -0
  17. spikezoo/archs/bsf/models/bsf/rep.py +44 -0
  18. spikezoo/archs/bsf/models/get_model.py +7 -0
  19. spikezoo/archs/bsf/prepare_data/DSFT.py +62 -0
  20. spikezoo/archs/bsf/prepare_data/crop_dataset_train.py +135 -0
  21. spikezoo/archs/bsf/prepare_data/crop_dataset_val.py +139 -0
  22. spikezoo/archs/bsf/prepare_data/crop_train.sh +4 -0
  23. spikezoo/archs/bsf/prepare_data/crop_val.sh +4 -0
  24. spikezoo/archs/bsf/prepare_data/io_utils.py +64 -0
  25. spikezoo/archs/bsf/requirements.txt +9 -0
  26. spikezoo/archs/bsf/test.py +16 -0
  27. spikezoo/archs/bsf/utils.py +154 -0
  28. spikezoo/archs/spikeclip/__pycache__/nets.cpython-39.pyc +0 -0
  29. spikezoo/archs/spikeclip/nets.py +40 -0
  30. spikezoo/archs/spikeformer/CheckPoints/readme +1 -0
  31. spikezoo/archs/spikeformer/DataProcess/DataExtactor.py +60 -0
  32. spikezoo/archs/spikeformer/DataProcess/DataLoader.py +115 -0
  33. spikezoo/archs/spikeformer/DataProcess/LoadSpike.py +39 -0
  34. spikezoo/archs/spikeformer/EvalResults/readme +1 -0
  35. spikezoo/archs/spikeformer/LICENSE +21 -0
  36. spikezoo/archs/spikeformer/Metrics/Metrics.py +50 -0
  37. spikezoo/archs/spikeformer/Metrics/__init__.py +0 -0
  38. spikezoo/archs/spikeformer/Model/Loss.py +89 -0
  39. spikezoo/archs/spikeformer/Model/SpikeFormer.py +230 -0
  40. spikezoo/archs/spikeformer/Model/__init__.py +0 -0
  41. spikezoo/archs/spikeformer/Model/__pycache__/SpikeFormer.cpython-39.pyc +0 -0
  42. spikezoo/archs/spikeformer/Model/__pycache__/__init__.cpython-39.pyc +0 -0
  43. spikezoo/archs/spikeformer/README.md +30 -0
  44. spikezoo/archs/spikeformer/evaluate.py +87 -0
  45. spikezoo/archs/spikeformer/recon_real_data.py +97 -0
  46. spikezoo/archs/spikeformer/requirements.yml +95 -0
  47. spikezoo/archs/spikeformer/train.py +173 -0
  48. spikezoo/archs/spikeformer/utils.py +22 -0
  49. spikezoo/archs/spk2imgnet/.github/workflows/pylint.yml +23 -0
  50. spikezoo/archs/spk2imgnet/.gitignore +150 -0
  51. spikezoo/archs/spk2imgnet/DCNv2.py +135 -0
  52. spikezoo/archs/spk2imgnet/__pycache__/DCNv2.cpython-39.pyc +0 -0
  53. spikezoo/archs/spk2imgnet/__pycache__/align_arch.cpython-39.pyc +0 -0
  54. spikezoo/archs/spk2imgnet/__pycache__/nets.cpython-39.pyc +0 -0
  55. spikezoo/archs/spk2imgnet/align_arch.py +159 -0
  56. spikezoo/archs/spk2imgnet/dataset.py +144 -0
  57. spikezoo/archs/spk2imgnet/nets.py +230 -0
  58. spikezoo/archs/spk2imgnet/readme.md +86 -0
  59. spikezoo/archs/spk2imgnet/test_gen_imgseq.py +118 -0
  60. spikezoo/archs/spk2imgnet/train.py +189 -0
  61. spikezoo/archs/spk2imgnet/utils.py +64 -0
  62. spikezoo/archs/ssir/README.md +87 -0
  63. spikezoo/archs/ssir/configs/SSIR.yml +37 -0
  64. spikezoo/archs/ssir/configs/yml_parser.py +78 -0
  65. spikezoo/archs/ssir/datasets/dataset_sreds.py +170 -0
  66. spikezoo/archs/ssir/datasets/ds_utils.py +66 -0
  67. spikezoo/archs/ssir/losses.py +21 -0
  68. spikezoo/archs/ssir/main.py +326 -0
  69. spikezoo/archs/ssir/metrics/psnr.py +22 -0
  70. spikezoo/archs/ssir/metrics/ssim.py +54 -0
  71. spikezoo/archs/ssir/models/Vgg19.py +42 -0
  72. spikezoo/archs/ssir/models/__pycache__/layers.cpython-39.pyc +0 -0
  73. spikezoo/archs/ssir/models/__pycache__/networks.cpython-39.pyc +0 -0
  74. spikezoo/archs/ssir/models/layers.py +110 -0
  75. spikezoo/archs/ssir/models/networks.py +61 -0
  76. spikezoo/archs/ssir/requirements.txt +8 -0
  77. spikezoo/archs/ssir/shells/eval_SREDS.sh +6 -0
  78. spikezoo/archs/ssir/shells/train_SSIR.sh +12 -0
  79. spikezoo/archs/ssir/test.py +3 -0
  80. spikezoo/archs/ssir/utils.py +154 -0
  81. spikezoo/archs/ssml/__pycache__/cbam.cpython-39.pyc +0 -0
  82. spikezoo/archs/ssml/__pycache__/model.cpython-39.pyc +0 -0
  83. spikezoo/archs/ssml/cbam.py +224 -0
  84. spikezoo/archs/ssml/model.py +290 -0
  85. spikezoo/archs/ssml/res.png +0 -0
  86. spikezoo/archs/ssml/test.py +67 -0
  87. spikezoo/archs/stir/.git-credentials +0 -0
  88. spikezoo/archs/stir/README.md +65 -0
  89. spikezoo/archs/stir/ckpt_outputs/Descriptions.txt +1 -0
  90. spikezoo/archs/stir/configs/STIR.yml +37 -0
  91. spikezoo/archs/stir/configs/utils.py +155 -0
  92. spikezoo/archs/stir/configs/yml_parser.py +78 -0
  93. spikezoo/archs/stir/datasets/dataset_sreds.py +180 -0
  94. spikezoo/archs/stir/datasets/ds_utils.py +66 -0
  95. spikezoo/archs/stir/eval_SREDS.sh +5 -0
  96. spikezoo/archs/stir/main.py +397 -0
  97. spikezoo/archs/stir/metrics/losses.py +219 -0
  98. spikezoo/archs/stir/metrics/psnr.py +22 -0
  99. spikezoo/archs/stir/metrics/ssim.py +54 -0
  100. spikezoo/archs/stir/models/Vgg19.py +42 -0
  101. spikezoo/archs/stir/models/__pycache__/networks_STIR.cpython-39.pyc +0 -0
  102. spikezoo/archs/stir/models/__pycache__/submodules.cpython-39.pyc +0 -0
  103. spikezoo/archs/stir/models/__pycache__/transformer_new.cpython-39.pyc +0 -0
  104. spikezoo/archs/stir/models/networks_STIR.py +361 -0
  105. spikezoo/archs/stir/models/submodules.py +86 -0
  106. spikezoo/archs/stir/models/transformer_new.py +151 -0
  107. spikezoo/archs/stir/package_core/build/lib/package_core/__init__.py +0 -0
  108. spikezoo/archs/stir/package_core/build/lib/package_core/convertions.py +721 -0
  109. spikezoo/archs/stir/package_core/build/lib/package_core/disp_netS.py +133 -0
  110. spikezoo/archs/stir/package_core/build/lib/package_core/flow_utils.py +167 -0
  111. spikezoo/archs/stir/package_core/build/lib/package_core/generic_train_test.py +76 -0
  112. spikezoo/archs/stir/package_core/build/lib/package_core/geometry.py +458 -0
  113. spikezoo/archs/stir/package_core/build/lib/package_core/image_proc.py +183 -0
  114. spikezoo/archs/stir/package_core/build/lib/package_core/linalg.py +40 -0
  115. spikezoo/archs/stir/package_core/build/lib/package_core/losses.py +198 -0
  116. spikezoo/archs/stir/package_core/build/lib/package_core/metrics.py +51 -0
  117. spikezoo/archs/stir/package_core/build/lib/package_core/model_base.py +53 -0
  118. spikezoo/archs/stir/package_core/build/lib/package_core/net_basics.py +100 -0
  119. spikezoo/archs/stir/package_core/build/lib/package_core/resnet.py +333 -0
  120. spikezoo/archs/stir/package_core/build/lib/package_core/transforms.py +123 -0
  121. spikezoo/archs/stir/package_core/build/lib/package_core/utils.py +72 -0
  122. spikezoo/archs/stir/package_core/dist/package_core-0.0.0-py3.9.egg +0 -0
  123. spikezoo/archs/stir/package_core/package_core/__init__.py +0 -0
  124. spikezoo/archs/stir/package_core/package_core/__pycache__/__init__.cpython-39.pyc +0 -0
  125. spikezoo/archs/stir/package_core/package_core/__pycache__/net_basics.cpython-39.pyc +0 -0
  126. spikezoo/archs/stir/package_core/package_core/convertions.py +721 -0
  127. spikezoo/archs/stir/package_core/package_core/disp_netS.py +133 -0
  128. spikezoo/archs/stir/package_core/package_core/flow_utils.py +167 -0
  129. spikezoo/archs/stir/package_core/package_core/generic_train_test.py +76 -0
  130. spikezoo/archs/stir/package_core/package_core/geometry.py +458 -0
  131. spikezoo/archs/stir/package_core/package_core/image_proc.py +183 -0
  132. spikezoo/archs/stir/package_core/package_core/linalg.py +40 -0
  133. spikezoo/archs/stir/package_core/package_core/losses.py +198 -0
  134. spikezoo/archs/stir/package_core/package_core/metrics.py +51 -0
  135. spikezoo/archs/stir/package_core/package_core/model_base.py +53 -0
  136. spikezoo/archs/stir/package_core/package_core/net_basics.py +100 -0
  137. spikezoo/archs/stir/package_core/package_core/resnet.py +333 -0
  138. spikezoo/archs/stir/package_core/package_core/transforms.py +123 -0
  139. spikezoo/archs/stir/package_core/package_core/utils.py +72 -0
  140. spikezoo/archs/stir/package_core/package_core.egg-info/PKG-INFO +3 -0
  141. spikezoo/archs/stir/package_core/package_core.egg-info/SOURCES.txt +20 -0
  142. spikezoo/archs/stir/package_core/package_core.egg-info/dependency_links.txt +1 -0
  143. spikezoo/archs/stir/package_core/package_core.egg-info/top_level.txt +1 -0
  144. spikezoo/archs/stir/package_core/setup.py +5 -0
  145. spikezoo/archs/stir/requirements.txt +12 -0
  146. spikezoo/archs/stir/train_STIR.sh +9 -0
  147. spikezoo/archs/tfi/__pycache__/nets.cpython-39.pyc +0 -0
  148. spikezoo/archs/tfi/nets.py +43 -0
  149. spikezoo/archs/tfp/__pycache__/nets.cpython-39.pyc +0 -0
  150. spikezoo/archs/tfp/nets.py +13 -0
  151. spikezoo/archs/wgse/README.md +64 -0
  152. spikezoo/archs/wgse/__pycache__/dwtnets.cpython-39.pyc +0 -0
  153. spikezoo/archs/wgse/__pycache__/submodules.cpython-39.pyc +0 -0
  154. spikezoo/archs/wgse/dataset.py +59 -0
  155. spikezoo/archs/wgse/demo.png +0 -0
  156. spikezoo/archs/wgse/demo.py +83 -0
  157. spikezoo/archs/wgse/dwtnets.py +145 -0
  158. spikezoo/archs/wgse/eval.py +133 -0
  159. spikezoo/archs/wgse/logs/WGSE-Dwt1dNet-db8-5-ks3/log.txt +11 -0
  160. spikezoo/archs/wgse/submodules.py +68 -0
  161. spikezoo/archs/wgse/train.py +261 -0
  162. spikezoo/archs/wgse/transform.py +139 -0
  163. spikezoo/archs/wgse/utils.py +128 -0
  164. spikezoo/archs/wgse/weights/demo.png +0 -0
  165. spikezoo/data/base/test/gt/200_part1_key_id151.png +0 -0
  166. spikezoo/data/base/test/gt/200_part3_key_id151.png +0 -0
  167. spikezoo/data/base/test/gt/203_part1_key_id151.png +0 -0
  168. spikezoo/data/base/test/spike/200_part1_key_id151.dat +0 -0
  169. spikezoo/data/base/test/spike/200_part3_key_id151.dat +0 -0
  170. spikezoo/data/base/test/spike/203_part1_key_id151.dat +0 -0
  171. spikezoo/data/base/train/gt/203_part2_key_id151.png +0 -0
  172. spikezoo/data/base/train/gt/203_part3_key_id151.png +0 -0
  173. spikezoo/data/base/train/gt/203_part4_key_id151.png +0 -0
  174. spikezoo/data/base/train/spike/203_part2_key_id151.dat +0 -0
  175. spikezoo/data/base/train/spike/203_part3_key_id151.dat +0 -0
  176. spikezoo/data/base/train/spike/203_part4_key_id151.dat +0 -0
  177. spikezoo/datasets/base_dataset.py +2 -3
  178. spikezoo/metrics/__init__.py +1 -1
  179. spikezoo/models/base_model.py +1 -3
  180. spikezoo/pipeline/base_pipeline.py +7 -5
  181. spikezoo/pipeline/train_pipeline.py +1 -1
  182. spikezoo/utils/other_utils.py +16 -6
  183. spikezoo/utils/spike_utils.py +33 -29
  184. spikezoo/utils/vidar_loader.cpython-39-x86_64-linux-gnu.so +0 -0
  185. spikezoo-0.2.1.dist-info/METADATA +167 -0
  186. spikezoo-0.2.1.dist-info/RECORD +211 -0
  187. spikezoo/models/spcsnet_model.py +0 -19
  188. spikezoo-0.1.2.dist-info/METADATA +0 -39
  189. spikezoo-0.1.2.dist-info/RECORD +0 -36
  190. {spikezoo-0.1.2.dist-info → spikezoo-0.2.1.dist-info}/LICENSE.txt +0 -0
  191. {spikezoo-0.1.2.dist-info → spikezoo-0.2.1.dist-info}/WHEEL +0 -0
  192. {spikezoo-0.1.2.dist-info → spikezoo-0.2.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,105 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from .rep import MODF
5
+ from .align import Multi_Granularity_Align
6
+
7
+
8
class BasicModel(nn.Module):
    """Common base class for networks in this package.

    Provides parameter-grouping helpers (for optimizers with separate
    weight/bias settings), a trainable-parameter counter, and Kaiming
    weight initialization for convolutional layers.
    """

    def __init__(self):
        super().__init__()

    ####################################################################################
    ## Tools functions for neural networks
    def weight_parameters(self):
        """Return all parameters whose registered name contains 'weight'."""
        return [param for name, param in self.named_parameters() if 'weight' in name]

    def bias_parameters(self):
        """Return all parameters whose registered name contains 'bias'."""
        return [param for name, param in self.named_parameters() if 'bias' in name]

    def num_parameters(self):
        """Return the total element count of parameters with requires_grad=True."""
        return sum([p.data.nelement() if p.requires_grad else 0 for p in self.parameters()])

    def init_weights(self):
        """Kaiming-initialize every Conv2d/ConvTranspose2d weight and zero its bias.

        Bug fix: the original iterated ``self.named_modules()``, which yields
        ``(name, module)`` tuples, so the ``isinstance`` checks never matched
        and this method silently initialized nothing. Iterate the modules
        themselves instead.
        """
        for layer in self.modules():
            if isinstance(layer, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(layer.weight)
                if layer.bias is not None:
                    nn.init.constant_(layer.bias, 0)
34
+
35
+
36
def split_and_b_cat(x):
    """Cut five overlapping 21-frame temporal windows out of *x* and stack
    them along the batch axis.

    The windows are centered at frame indices 10, 20, 30, 40 and 50 of
    dim 1, each spanning center-10 .. center+10 inclusive, so the result
    has shape (5*b, 21, h, w) for an input of shape (b, T>=61, h, w).
    """
    half_span = 10
    windows = [x[:, center - half_span:center + half_span + 1].clone()
               for center in (10, 20, 30, 40, 50)]
    return torch.cat(windows, dim=0)
43
+
44
+
45
class Encoder(nn.Module):
    """Residual convolutional encoder.

    Stacks `layers` identical stages; each stage is conv-act-conv whose
    output is added back to its input and passed through the activation,
    keeping the channel count at `base_dim` throughout.
    """

    def __init__(self, base_dim=64, layers=4, act=nn.ReLU()):
        super().__init__()
        stages = [
            nn.Sequential(
                nn.Conv2d(base_dim, base_dim, kernel_size=3, padding=1),
                act,
                nn.Conv2d(base_dim, base_dim, kernel_size=3, padding=1),
            )
            for _ in range(layers)
        ]
        self.conv_list = nn.ModuleList(stages)
        self.act = act

    def forward(self, x):
        out = x
        for stage in self.conv_list:
            out = self.act(stage(out) + out)
        return out
63
+
64
##########################################################################
class BSF(BasicModel):
    """BSF spike-to-image reconstruction network.

    Pipeline: multi-order DSFT fusion (MODF) -> residual Encoder ->
    Multi_Granularity_Align over the five temporal windows produced by
    split_and_b_cat -> convolutional reconstruction head with a
    single-channel output.
    """

    def __init__(self, act=nn.ReLU()):
        super().__init__()
        self.offset_groups = 4  # passed to Multi_Granularity_Align as `groups`
        self.corr_max_disp = 3  # NOTE(review): unused in this file — presumably read by align internals; confirm

        # embeds the four DSFT orders into 64-channel features
        self.rep = MODF(base_dim=64, act=act)

        # residual feature encoder (4 two-conv stages, 64 channels)
        self.encoder = Encoder(base_dim=64, layers=4, act=act)

        # aligns the five per-window feature maps (implementation in .align)
        self.align = Multi_Granularity_Align(base_dim=64, groups=self.offset_groups, act=act, sc=3)

        # decodes the channel-concatenated five aligned 64-ch maps to 1 channel
        self.recons = nn.Sequential(
            nn.Conv2d(64*5, 64*3, kernel_size=3, padding=1),
            act,
            nn.Conv2d(64*3, 64, kernel_size=3, padding=1),
            act,
            nn.Conv2d(64, 1, kernel_size=3, padding=1),
        )

    def forward(self, input_dict):
        """Reconstruct an image from input_dict['dsft_dict'].

        Each DSFT tensor is b x T x h x w with T >= 61 (split_and_b_cat
        reads frames 0..60). Returns the reconstruction head's output
        (b x 1 x h x w, assuming align preserves the batch size — confirm).
        """
        dsft_dict = input_dict['dsft_dict']
        dsft11 = dsft_dict['dsft11']
        dsft12 = dsft_dict['dsft12']
        dsft21 = dsft_dict['dsft21']
        dsft22 = dsft_dict['dsft22']

        # cut five 21-frame windows from each order and stack them on the
        # batch axis so all windows share one pass through rep/encoder
        dsft_b_cat = {
            'dsft11': split_and_b_cat(dsft11),
            'dsft12': split_and_b_cat(dsft12),
            'dsft21': split_and_b_cat(dsft21),
            'dsft22': split_and_b_cat(dsft22),
        }

        feat_b_cat = self.rep(dsft_b_cat)
        feat_b_cat = self.encoder(feat_b_cat)
        # back to one feature map per temporal window
        feat_list = feat_b_cat.chunk(5, dim=0)
        feat_list_align = self.align(feat_list=feat_list)
        # channel-concatenate the aligned windows and decode to an image
        out = self.recons(torch.cat(feat_list_align, dim=1))

        return out
@@ -0,0 +1,96 @@
1
+ import torch
2
+
3
+
4
def convert_dsft4(dsft, spike):
    '''
    Derive the higher-order DSFT maps from the first-order one.

    input: Pytorch Tensor
        dsft:  dsft(1,1), b x T x h x w
        spike: 01 spike,  b x T x h x w
    output: Pytorch Tensor
        dsft_dict: {dsft(1,1), dsft(1,2), dsft(2,1), dsft(2,2)}
        dsft12/dsft21 extend dsft by the neighboring inter-spike interval on
        the left/right respectively; dsft22 extends it on both sides.
    '''

    b, T, h, w = spike.shape

    ## dsft_mask_left_shift -- abbr. --> dmls1, (right-shift: dmrs1)
    dmls1 = -1 * torch.ones(spike.shape, device=spike.device, dtype=torch.float32)
    dmrs1 = -1 * torch.ones(spike.shape, device=spike.device, dtype=torch.float32)

    ## for dmls1 (backward scan over time)
    # 'flag' counts spikes seen so far; while it is < 0 we are past the last
    # spikes of the stream, and the boundary is handled by copy-padding
    flag = -2 * torch.ones([b, h, w], device=spike.device, dtype=torch.float32)
    for ii in range(T-1, 0-1, -1):
        flag += (spike[:,ii]==1)

        copy_pad_coord = (flag < 0)
        dmls1[:,ii][copy_pad_coord] = dsft[:,ii][copy_pad_coord]

        if ii < T-1:
            ## positions where dmls1 must be refreshed (a spike fires at ii+1)
            update_coord = (spike[:,ii+1]==1) * (~copy_pad_coord)
            dmls1[:,ii][update_coord] = dsft[:,ii+1][update_coord]

            ## positions where dmls1 keeps (inherits) the value from frame ii+1
            non_update_coord = (spike[:,ii+1]!=1) * (~copy_pad_coord)
            dmls1[:,ii][non_update_coord] = dmls1[:, ii+1][non_update_coord]


    ## for dmrs1 (forward scan over time, mirrored logic)
    # 'flag' again marks the copy-padding region near the stream boundary
    flag = -2 * torch.ones([b, h, w], device=spike.device, dtype=torch.float32)
    for ii in range(0, T, 1):
        flag += (spike[:,ii]==1)

        ## copy-padding for the boundary
        copy_pad_coord = (flag < 0)
        dmrs1[:,ii][copy_pad_coord] = dsft[:,ii][copy_pad_coord]

        if ii > 0:
            ## positions where dmrs1 must be refreshed (a spike fires at ii)
            update_coord = (spike[:,ii]==1) * (~copy_pad_coord)
            dmrs1[:,ii][update_coord] = dsft[:,ii-1][update_coord]

            ## positions where dmrs1 keeps (inherits) the value from frame ii-1
            non_update_coord = (spike[:,ii]!=1) * (~copy_pad_coord)
            dmrs1[:,ii][non_update_coord] = dmrs1[:, ii-1][non_update_coord]


    # higher orders are sums of the adjacent first-order intervals
    dsft12 = dsft + dmls1
    dsft21 = dsft + dmrs1
    dsft22 = dsft + dmls1 + dmrs1


    dsft_dict = {
        'dsft11': dsft,
        'dsft12': dsft12,
        'dsft21': dsft21,
        'dsft22': dsft22,
    }

    return dsft_dict
71
+
72
+
73
+
74
if __name__ == '__main__':
    # Manual smoke test: a single-pixel spike stream and its matching
    # first-order DSFT, printed for visual inspection.
    # spike = [0,0,1,0,1,0,0,0,0,1,0,0,1,0,1,0,0,0,1,0,0]
    # dsft = [2,2,2,2,2,5,5,5,5,5,3,3,3,2,2,4,4,4,4,4,4]

    spike = [0,0,1,0,0,1,0,0,0,1,0,0,0,1,0,0,1,0,1,0,0,1,0,0,0,1,0,0,1,0,0,1,0]
    dsft = [3,3,3,3,3,4,4,4,4,4,4,4,4,3,3,3,2,2,3,3,3,4,4,4,4,3,3,3,3,3,3,3,3]

    # reshape to b x T x h x w with b = h = w = 1
    spike = torch.tensor(spike, device='cpu', dtype=torch.float32)[None,:,None,None]
    dsft = torch.tensor(dsft , device='cpu', dtype=torch.float32)[None,:,None,None]

    dsft_dict = convert_dsft4(dsft=dsft, spike=spike)
    dsft_11 = dsft_dict['dsft11']
    dsft_12 = dsft_dict['dsft12']
    dsft_21 = dsft_dict['dsft21']
    dsft_22 = dsft_dict['dsft22']

    # temporal profile of the single pixel, one line per DSFT order
    print(dsft_11[0,:,0,0])
    print()
    print(dsft_12[0,:,0,0])
    print()
    print(dsft_21[0,:,0,0])
    print()
    print(dsft_22[0,:,0,0])
@@ -0,0 +1,44 @@
1
+ import torch
2
+ import torch.nn as nn
3
+
4
class MODF(nn.Module):
    """Multi-order DSFT fusion block.

    Embeds the four DSFT tensors into a shared feature space; the
    (1,1)-order feature is the base signal and the three higher orders
    contribute a fused residual correction on top of it.
    """

    def __init__(self, base_dim=64, act=nn.ReLU(), in_dim=21):
        """
        Args:
            base_dim: channel width of all internal features.
            act: activation module shared by all layers.
            in_dim: channel count of each input DSFT tensor. Defaults to 21
                (the temporal window length used by BSF); previously this
                was a hard-coded constant.
        """
        super().__init__()
        self.base_dim = base_dim

        # shared embedding applied to all four DSFT orders at once
        self.conv1 = self._make_layer(input_dim=in_dim, hidden_dim=self.base_dim, output_dim=self.base_dim, act=act)
        # one refinement branch per higher order (12, 21, 22)
        self.conv_for_others = nn.ModuleList([
            self._make_layer(input_dim=self.base_dim, hidden_dim=self.base_dim, output_dim=self.base_dim, act=act) for ii in range(3)
        ])
        self.conv_fuse = self._make_layer(input_dim=self.base_dim*3, hidden_dim=self.base_dim, output_dim=self.base_dim, act=act)

    def _make_layer(self, input_dim, hidden_dim, output_dim, act):
        """Conv-act-conv building block (3x3 kernels, padding 1)."""
        layer = nn.Sequential(
            nn.Conv2d(input_dim, hidden_dim, kernel_size=3, padding=1),
            act,
            nn.Conv2d(hidden_dim, output_dim, kernel_size=3, padding=1),
        )
        return layer

    def forward(self, dsft_dict):
        """Fuse {'dsft11','dsft12','dsft21','dsft22'} into one base_dim feature map."""
        # reciprocal normalization; the numerator matches the number of
        # inter-spike intervals each DSFT order spans (1, 2, 2, 3)
        d11 = 1.0 / dsft_dict['dsft11']
        d12 = 2.0 / dsft_dict['dsft12']
        d21 = 2.0 / dsft_dict['dsft21']
        d22 = 3.0 / dsft_dict['dsft22']

        # run all four orders through the shared conv in one batched pass
        d_list = [d11, d12, d21, d22]
        feat_batch_cat = self.conv1(torch.cat(d_list, dim=0))
        feat_list = feat_batch_cat.chunk(4, dim=0)

        feat_11 = feat_list[0]
        # refine each higher-order feature with its dedicated branch
        refined = [branch(feat) for branch, feat in zip(self.conv_for_others, feat_list[1:])]

        # fuse the refined features into a residual on top of the base feature
        other_feat_res = self.conv_fuse(torch.cat(refined, dim=1))

        return feat_11 + other_feat_res
44
+
@@ -0,0 +1,7 @@
1
+ from .bsf.bsf import BSF
2
+
3
def get_model(args):
    """Build the reconstruction network selected by ``args.arch``.

    Args:
        args: namespace with an ``arch`` attribute (matched case-insensitively).

    Returns:
        The instantiated model.

    Raises:
        ValueError: if ``args.arch`` names an unknown architecture. (The
            original fell through to ``return model`` and raised an opaque
            ``UnboundLocalError`` instead.)
    """
    if args.arch.upper() == 'BSF'.upper():
        return BSF()
    raise ValueError('Unknown architecture: {}'.format(args.arch))
@@ -0,0 +1,62 @@
1
+ import numpy as np
2
+ import torch
3
+
4
class DSFT:
    """Convert a binary spike stream into DSFT frames (per-pixel inter-spike
    interval maps), computed on a torch device."""

    def __init__(self, spike_h, spike_w, device):
        # spatial resolution of the spike stream
        self.spike_h = spike_h
        self.spike_w = spike_w
        # torch device string ('cpu' / 'cuda') used for all intermediate tensors
        self.device = device


    def spikes2images(self, spikes, max_search_half_window=20):
        '''
        Convert the whole spike stream into a DSFT frame sequence.

        Input:
            spikes: T x H x W numpy array; integer or float values
            max_search_half_window: number of frames trimmed from each end of
                the stream, so only timestamps with at least this much context
                on both sides are emitted
        Output:
            ImageMatrix: T' x H x W numpy array, T' = T - 2*max_search_half_window
                dtype uint8, values 0 ~ 255
        '''

        T = spikes.shape[0]
        T_im = T - 2*max_search_half_window

        if T_im < 0:
            raise ValueError('The length of spike stream {:d} is not enough for max_search half window length {:d}'.format(T, max_search_half_window))

        spikes = torch.from_numpy(spikes).to(self.device).float()
        # NOTE(review): this allocation is immediately overwritten below — dead store
        ImageMatrix = torch.zeros([T_im, self.spike_h, self.spike_w]).to(self.device)

        # forward pass: for every frame, cur_idx holds the index of the most
        # recent spike (-1 if none yet) and pre_idx the one before it
        pre_idx = -1 * torch.ones([T, self.spike_h, self.spike_w]).float().to(self.device)
        cur_idx = -1 * torch.ones([T, self.spike_h, self.spike_w]).float().to(self.device)

        for ii in range(T):
            if ii > 0:
                pre_idx[ii] = cur_idx[ii-1]
                cur_idx[ii] = cur_idx[ii-1]
            cur_spk = spikes[ii]
            cur_idx[ii][cur_spk==1] = ii

        # inter-spike interval ending at each frame (0 where nothing changed)
        diff = cur_idx - pre_idx


        # backward pass: propagate each interval to the frames between spikes
        interval = -1 * torch.ones([T, self.spike_h, self.spike_w]).float().to(self.device)
        for ii in range(T-1, 0-1, -1):
            interval[ii][diff[ii]!=0] = diff[ii][diff[ii]!=0]
            if ii < T-1:
                interval[ii][diff[ii]==0] = interval[ii+1][diff[ii]==0]

        # boundary: pixels never bracketed by two spikes saturate to 255
        interval[interval==-1] = 255
        interval[pre_idx==-1] = 255

        # clamp for uint8 storage
        interval = torch.clip(interval, 0, 255)

        # drop the context margin on both ends and move back to numpy uint8
        ImageMatrix = interval[max_search_half_window:-max_search_half_window].cpu().detach().numpy().astype(np.uint8)


        return ImageMatrix
@@ -0,0 +1,135 @@
1
# Crop the REDS 120fps training set into 256x256 patches: RGB frames
# (with --crop_image) or spike streams + their DSFT maps (default).
import os
import os.path as osp
import argparse
import cv2
import numpy as np
from io_utils import *
import h5py
from tqdm import *
from DSFT import DSFT

parser = argparse.ArgumentParser()
parser.add_argument("--root", type=str, default="/data/rzhao/REDS120fps")
parser.add_argument("--output_path", type=str, default="/data/rzhao/REDS120fps/crop")
###### spike-simulation parameters (select which simulated stream to read)
parser.add_argument("--eta", type=float, default=1.0)
parser.add_argument("--gamma", type=int, default=60)
parser.add_argument("--alpha", type=float, default=0.7)

# CUDA device id, exported to CUDA_VISIBLE_DEVICES below
parser.add_argument("--cu", '-c', type=str, default='0')

# when set, only crop the RGB frames; otherwise crop spikes + DSFT
parser.add_argument("--crop_image", action='store_true')
args = parser.parse_args()

os.environ['CUDA_VISIBLE_DEVICES'] = args.cu


if __name__ == '__main__':
    imgs_path = osp.join(args.root, 'imgs', 'train')
    spks_path = osp.join(args.root, 'spikes', 'train',
                         "eta_{:.2f}_gamma_{:d}_alpha_{:.1f}".format(args.eta, args.gamma, args.alpha))

    scene_list = sorted(os.listdir(spks_path))
    for scene in tqdm(scene_list):
        scene_imgs_path = osp.join(imgs_path, scene)
        scene_spks_path = osp.join(spks_path, scene)


        if not args.crop_image:
            # read all the dat files
            dat_path = sorted(os.listdir(scene_spks_path))
            spks_list = []
            #### abandon 00000000.dat, corresponding to the following spike_idx_offset
            for dat_name in dat_path[1:]:
                spks_list.append(dat_to_spmat(dat_path=osp.join(scene_spks_path, dat_name), size=(720, 1280)))
            spikes = np.concatenate(spks_list, axis=0)

            # spikes -> DSFT (100 context frames trimmed on each side)
            dsft_solver = DSFT(spike_h=720, spike_w=1280, device='cuda')
            dsft = dsft_solver.spikes2images(spikes, max_search_half_window=100)


        # crop Image
        if args.crop_image:
            imgs_list = []  # NOTE(review): never filled or read — dead variable
            for im_idx in range(11, 28+1):
                img = cv2.imread(osp.join(scene_imgs_path, '{:08d}.png'.format(im_idx)))
                # 1. central crop: 720x1280 -> 656x1024
                crop_img = img[32:-32, 128:-128]
                # 2. crop into a 3x4 grid of 256x256 patches; the last row /
                # column is anchored to the image border instead of the grid
                for ii in range(3):
                    for jj in range(4):
                        if (ii != 2) and (jj != 3):
                            cur_img = crop_img[256*ii:256*(ii+1), 256*jj:256*(jj+1)]
                        elif ii != 2:
                            cur_img = crop_img[256*ii:256*(ii+1), -256:]
                        elif jj != 3:
                            cur_img = crop_img[-256:, 256*jj:256*(jj+1)]
                        else:
                            cur_img = crop_img[-256:, -256:]
                        cur_save_root = osp.join(args.output_path, 'train', 'imgs', scene, '{:02}'.format(ii*4+jj))
                        os.makedirs(cur_save_root, exist_ok=True)
                        cur_save_path = osp.join(cur_save_root, '{:08d}.png'.format(im_idx))
                        if osp.exists(cur_save_path):
                            os.remove(cur_save_path)
                        cv2.imwrite(cur_save_path, cur_img)
            # image-only mode: skip the spike/DSFT cropping for this scene
            continue

        # crop spikes
        # since 00000000.dat is abandoned
        spike_idx_offset = 10
        # 1. central crop
        spikes = spikes[:, 32:-32, 128:-128]
        # 2. crop
        for spk_idx in range(11, 28+1):
            # 10 spike frames associated with image index spk_idx
            crop_spike = spikes[spk_idx*10-spike_idx_offset : spk_idx*10-spike_idx_offset+10]
            for ii in range(3):
                for jj in range(4):
                    if (ii != 2) and (jj != 3):
                        cur_spk = crop_spike[:, 256*ii:256*(ii+1), 256*jj:256*(jj+1)]
                    elif ii != 2:
                        cur_spk = crop_spike[:, 256*ii:256*(ii+1), -256:]
                    elif jj != 3:
                        cur_spk = crop_spike[:, -256:, 256*jj:256*(jj+1)]
                    else:
                        cur_spk = crop_spike[:, -256:, -256:]

                    cur_save_root = osp.join(args.output_path, 'train',
                                             "eta_{:.2f}_gamma_{:d}_alpha_{:.1f}".format(args.eta, args.gamma, args.alpha),
                                             scene, '{:02}'.format(ii*4+jj), 'spikes')

                    os.makedirs(cur_save_root, exist_ok=True)
                    cur_save_path = osp.join(cur_save_root,'{:08d}.dat'.format(spk_idx))
                    if osp.exists(cur_save_path):
                        os.remove(cur_save_path)
                    SpikeToRaw(SpikeSeq=cur_spk, save_path=cur_save_path)


        # crop dsft
        # extra 100 accounts for the frames trimmed by spikes2images
        dsft_idx_offset = 10 + 100
        # 1. central crop
        dsft = dsft[:, 32:-32, 128:-128]
        # 2. crop
        for dsft_idx in range(11, 28+1):
            crop_dsft = dsft[dsft_idx*10-dsft_idx_offset : dsft_idx*10-dsft_idx_offset+10]
            for ii in range(3):
                for jj in range(4):
                    if (ii != 2) and (jj != 3):
                        cur_dsft = crop_dsft[:, 256*ii:256*(ii+1), 256*jj:256*(jj+1)]
                    elif ii != 2:
                        cur_dsft = crop_dsft[:, 256*ii:256*(ii+1), -256:]
                    elif jj != 3:
                        cur_dsft = crop_dsft[:, -256:, 256*jj:256*(jj+1)]
                    else:
                        cur_dsft = crop_dsft[:, -256:, -256:]

                    cur_save_root = osp.join(args.output_path, 'train',
                                             "eta_{:.2f}_gamma_{:d}_alpha_{:.1f}".format(args.eta, args.gamma, args.alpha),
                                             scene, '{:02}'.format(ii*4+jj), 'dsft')
                    os.makedirs(cur_save_root, exist_ok=True)
                    cur_save_path = osp.join(cur_save_root, '{:08d}.h5'.format(dsft_idx))
                    if osp.exists(cur_save_path):
                        os.remove(cur_save_path)
                    f = h5py.File(cur_save_path, 'w')
                    f['dsft'] = cur_dsft
                    f.close()
@@ -0,0 +1,139 @@
1
# Crop the REDS 120fps validation set: each scene is split into four 384x512
# corner sub-scenes ('<scene>_0'..'<scene>_3') for RGB frames (--crop_image)
# or spike streams + their DSFT maps (default).
import os
import os.path as osp
import argparse
import cv2
import numpy as np
from io_utils import *
import h5py
from tqdm import *
from DSFT import DSFT

parser = argparse.ArgumentParser()
parser.add_argument("--root", type=str, default="/data/rzhao/REDS120fps")
parser.add_argument("--output_path", type=str, default="/data/rzhao/REDS120fps/crop")
###### spike-simulation parameters (select which simulated stream to read)
parser.add_argument("--eta", type=float, default=1.00)
parser.add_argument("--gamma", type=int, default=60)
parser.add_argument("--alpha", type=float, default=0.7)

# CUDA device id, exported to CUDA_VISIBLE_DEVICES below
parser.add_argument("--cu", '-c', type=str, default='0')

# when set, only crop the RGB frames; otherwise crop spikes + DSFT
parser.add_argument("--crop_image", action='store_true')
args = parser.parse_args()

os.environ['CUDA_VISIBLE_DEVICES'] = args.cu

if __name__ == '__main__':
    imgs_path = osp.join(args.root, 'imgs', 'val')
    spks_path = osp.join(args.root, 'spikes', 'val',
                         "eta_{:.2f}_gamma_{:d}_alpha_{:.1f}".format(args.eta, args.gamma, args.alpha))

    scene_list = sorted(os.listdir(spks_path))
    for scene in tqdm(scene_list):
        scene_imgs_path = osp.join(imgs_path, scene)
        scene_spks_path = osp.join(spks_path, scene)

        if not args.crop_image:
            # read all the dat files
            dat_path = sorted(os.listdir(scene_spks_path))
            spks_list = []
            #### abandon 00000000.dat, corresponding to the following spike_idx_offset
            for dat_name in dat_path[1:]:
                spks_list.append(dat_to_spmat(dat_path=osp.join(scene_spks_path, dat_name), size=(720, 1280)))
            spikes = np.concatenate(spks_list, axis=0)

            # spikes -> DSFT (100 context frames trimmed on each side)
            dsft_solver = DSFT(spike_h=720, spike_w=1280, device='cuda')
            dsft = dsft_solver.spikes2images(spikes, max_search_half_window=100)


        # crop Image
        if args.crop_image:
            imgs_list = []  # NOTE(review): never filled or read — dead variable
            for im_idx in range(11, 28+1):
                img = cv2.imread(osp.join(scene_imgs_path, '{:08d}.png'.format(im_idx)))
                # 1. central crop: 720x1280 -> 656x1024
                crop_img = img[32:-32, 128:-128]
                # 2. split into four 384x512 corner sub-scenes
                for sub_scene_idx in range(4):
                    cur_scene = '{:s}_{:d}'.format(scene, sub_scene_idx)
                    if sub_scene_idx == 0:
                        cur_crop_img = crop_img[:384, :512]
                    elif sub_scene_idx == 1:
                        cur_crop_img = crop_img[-384:, :512]
                    elif sub_scene_idx == 2:
                        cur_crop_img = crop_img[:384, -512:]
                    elif sub_scene_idx == 3:
                        cur_crop_img = crop_img[-384:, -512:]

                    cur_save_root = osp.join(args.output_path, 'val_small', 'imgs', cur_scene)
                    os.makedirs(cur_save_root, exist_ok=True)
                    cur_save_path = osp.join(cur_save_root, '{:08d}.png'.format(im_idx))
                    if osp.exists(cur_save_path):
                        os.remove(cur_save_path)
                    cv2.imwrite(cur_save_path, cur_crop_img)
            # image-only mode: skip the spike/DSFT cropping for this scene
            continue


        # crop spikes
        # since 00000000.dat is abandoned
        spike_idx_offset = 10
        # 1. central crop
        spikes = spikes[:, 32:-32, 128:-128]
        # 2. crop
        for spk_idx in range(11, 28+1):
            # 10 spike frames associated with image index spk_idx
            crop_spike = spikes[spk_idx*10-spike_idx_offset : spk_idx*10-spike_idx_offset+10]

            for sub_scene_idx in range(4):
                cur_scene = '{:s}_{:d}'.format(scene, sub_scene_idx)
                if sub_scene_idx == 0:
                    cur_crop_spike = crop_spike[:, :384, :512]
                elif sub_scene_idx == 1:
                    cur_crop_spike = crop_spike[:, -384:, :512]
                elif sub_scene_idx == 2:
                    cur_crop_spike = crop_spike[:, :384, -512:]
                elif sub_scene_idx == 3:
                    cur_crop_spike = crop_spike[:, -384:, -512:]

                cur_save_root = osp.join(args.output_path, 'val_small',
                                         "eta_{:.2f}_gamma_{:d}_alpha_{:.1f}".format(args.eta, args.gamma, args.alpha),
                                         cur_scene,
                                         'spikes')

                os.makedirs(cur_save_root, exist_ok=True)
                cur_save_path = osp.join(cur_save_root,'{:08d}.dat'.format(spk_idx))
                if osp.exists(cur_save_path):
                    os.remove(cur_save_path)
                SpikeToRaw(SpikeSeq=cur_crop_spike, save_path=cur_save_path)


        # crop dsft
        # extra 100 accounts for the frames trimmed by spikes2images
        dsft_idx_offset = 10 + 100
        # 1. central crop
        dsft = dsft[:, 32:-32, 128:-128]
        # 2. crop
        for dsft_idx in range(11, 28+1):
            crop_dsft = dsft[dsft_idx*10-dsft_idx_offset : dsft_idx*10-dsft_idx_offset+10]

            for sub_scene_idx in range(4):
                cur_scene = '{:s}_{:d}'.format(scene, sub_scene_idx)
                if sub_scene_idx == 0:
                    cur_crop_dsft = crop_dsft[:, :384, :512]
                elif sub_scene_idx == 1:
                    cur_crop_dsft = crop_dsft[:, -384:, :512]
                elif sub_scene_idx == 2:
                    cur_crop_dsft = crop_dsft[:, :384, -512:]
                elif sub_scene_idx == 3:
                    cur_crop_dsft = crop_dsft[:, -384:, -512:]

                cur_save_root = osp.join(args.output_path, 'val_small',
                                         "eta_{:.2f}_gamma_{:d}_alpha_{:.1f}".format(args.eta, args.gamma, args.alpha),
                                         cur_scene,
                                         'dsft')
                os.makedirs(cur_save_root, exist_ok=True)
                cur_save_path = osp.join(cur_save_root, '{:08d}.h5'.format(dsft_idx))
                if osp.exists(cur_save_path):
                    os.remove(cur_save_path)
                f = h5py.File(cur_save_path, 'w')
                f['dsft'] = cur_crop_dsft
                f.close()
139
+
@@ -0,0 +1,4 @@
1
+ python3 crop_dataset_train.py -c $1 --eta 1.00 &&
2
+ python3 crop_dataset_train.py -c $1 --eta 0.75 &&
3
+ python3 crop_dataset_train.py -c $1 --eta 0.50 &&
4
+ python3 crop_dataset_train.py --crop_image
@@ -0,0 +1,4 @@
1
+ python3 crop_dataset_val.py -c $1 --eta 1.00 &&
2
+ python3 crop_dataset_val.py -c $1 --eta 0.75 &&
3
+ python3 crop_dataset_val.py -c $1 --eta 0.50 &&
4
+ python3 crop_dataset_val.py --crop_image
@@ -0,0 +1,64 @@
1
+ import numpy as np
2
+ import os
3
+ import os.path as osp
4
+
5
+
6
def RawToSpike(video_seq, h, w, flipud=True):
    """Unpack a bit-packed raw spike byte stream into a binary spike cube.

    Each frame occupies h*w/8 bytes; bit k of a byte encodes one pixel in
    row-major order. Returns a (num_frames, h, w) uint8 array of 0/1 values.
    With flipud=True (default) every frame is flipped vertically, undoing
    the camera's inverted readout.
    """
    video_seq = np.array(video_seq).astype(np.uint8)
    frame_bytes = h * w // 8
    img_num = len(video_seq) // frame_bytes

    # per-pixel byte index and bit mask, computed once for all frames
    pix_id = np.arange(0, h * w).reshape(h, w)
    comparator = np.left_shift(1, np.mod(pix_id, 8))
    byte_id = pix_id // 8

    SpikeMatrix = np.zeros([img_num, h, w], np.uint8)
    for img_id in range(img_num):
        frame_data = video_seq[img_id * frame_bytes:(img_id + 1) * frame_bytes]
        bits = np.bitwise_and(frame_data[byte_id], comparator) == comparator
        SpikeMatrix[img_id, :, :] = np.flipud(bits) if flipud else bits

    return SpikeMatrix
28
+
29
def dat_to_spmat(dat_path, size):
    """Load a .dat spike file and unpack it into a (frames, h, w) spike matrix.

    Args:
        dat_path: path to the raw bit-packed spike file.
        size: (height, width) of each spike frame.

    Returns:
        uint8 numpy array of shape (num_frames, height, width).
    """
    # 'with' guarantees the handle is closed (the original leaked it)
    with open(dat_path, 'rb') as f:
        video_seq = f.read()
    video_seq = np.frombuffer(video_seq, 'b')
    sp_mat = RawToSpike(video_seq, size[0], size[1])
    return sp_mat
35
+
36
+
37
## Save Raw dat files
def SpikeToRaw(SpikeSeq, save_path):
    """
    SpikeSeq: Numpy array (sfn x h x w)
    save_path: full saving path (string)

    Packs each binary frame into h*w/8 bytes (bit k of a byte = one pixel,
    row-major) and appends to *save_path*. Note the file is opened in 'ab'
    mode: callers wanting a fresh file must remove any existing one first.
    """
    sfn, h, w = SpikeSeq.shape
    # bit weights 1, 2, 4, ..., 128 for the 8 pixels sharing one byte
    base = np.power(2, np.linspace(0, 7, 8))
    fid = open(save_path, 'ab')
    for frame_id in range(sfn):
        # mirror vertically to reproduce the camera's inverted readout
        frame = np.flipud(SpikeSeq[frame_id, :, :])
        # numpy flattens row-major, matching the on-disk pixel order
        grouped = frame.flatten().reshape([int(h * w / 8), 8])
        byte_vals = np.sum(grouped * base, axis=1).astype(np.uint8)
        fid.write(byte_vals.tobytes())

    fid.close()

    return
59
+
60
+
61
def save_to_h5(SpikeMatrix, h5path, name):
    """Write *SpikeMatrix* to a new HDF5 file at *h5path* under dataset *name*.

    Overwrites any existing file ('w' mode).

    Bug fix: this module never imported h5py at top level (only numpy/os/osp),
    so the original raised NameError when called; import it here. The context
    manager also guarantees the file is closed even if the write fails.
    """
    import h5py
    with h5py.File(h5path, 'w') as f:
        f[name] = SpikeMatrix
@@ -0,0 +1,9 @@
1
+ torch
2
+ torchvision
3
+ opencv-python
4
+ numpy
5
+ tensorboardX
6
+ scikit-image
7
+ lpips
8
+ tqdm
9
+ h5py