spikezoo 0.1.1__py3-none-any.whl → 0.2__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (192) hide show
  1. spikezoo/__init__.py +13 -0
  2. spikezoo/archs/__pycache__/__init__.cpython-39.pyc +0 -0
  3. spikezoo/archs/base/__pycache__/nets.cpython-39.pyc +0 -0
  4. spikezoo/archs/base/nets.py +34 -0
  5. spikezoo/archs/bsf/README.md +92 -0
  6. spikezoo/archs/bsf/datasets/datasets.py +328 -0
  7. spikezoo/archs/bsf/datasets/ds_utils.py +64 -0
  8. spikezoo/archs/bsf/main.py +398 -0
  9. spikezoo/archs/bsf/metrics/psnr.py +22 -0
  10. spikezoo/archs/bsf/metrics/ssim.py +54 -0
  11. spikezoo/archs/bsf/models/bsf/__pycache__/align.cpython-39.pyc +0 -0
  12. spikezoo/archs/bsf/models/bsf/__pycache__/bsf.cpython-39.pyc +0 -0
  13. spikezoo/archs/bsf/models/bsf/__pycache__/rep.cpython-39.pyc +0 -0
  14. spikezoo/archs/bsf/models/bsf/align.py +154 -0
  15. spikezoo/archs/bsf/models/bsf/bsf.py +105 -0
  16. spikezoo/archs/bsf/models/bsf/dsft_convert.py +96 -0
  17. spikezoo/archs/bsf/models/bsf/rep.py +44 -0
  18. spikezoo/archs/bsf/models/get_model.py +7 -0
  19. spikezoo/archs/bsf/prepare_data/DSFT.py +62 -0
  20. spikezoo/archs/bsf/prepare_data/crop_dataset_train.py +135 -0
  21. spikezoo/archs/bsf/prepare_data/crop_dataset_val.py +139 -0
  22. spikezoo/archs/bsf/prepare_data/crop_train.sh +4 -0
  23. spikezoo/archs/bsf/prepare_data/crop_val.sh +4 -0
  24. spikezoo/archs/bsf/prepare_data/io_utils.py +64 -0
  25. spikezoo/archs/bsf/requirements.txt +9 -0
  26. spikezoo/archs/bsf/test.py +16 -0
  27. spikezoo/archs/bsf/utils.py +154 -0
  28. spikezoo/archs/spikeclip/__pycache__/nets.cpython-39.pyc +0 -0
  29. spikezoo/archs/spikeclip/nets.py +40 -0
  30. spikezoo/archs/spikeformer/CheckPoints/readme +1 -0
  31. spikezoo/archs/spikeformer/DataProcess/DataExtactor.py +60 -0
  32. spikezoo/archs/spikeformer/DataProcess/DataLoader.py +115 -0
  33. spikezoo/archs/spikeformer/DataProcess/LoadSpike.py +39 -0
  34. spikezoo/archs/spikeformer/EvalResults/readme +1 -0
  35. spikezoo/archs/spikeformer/LICENSE +21 -0
  36. spikezoo/archs/spikeformer/Metrics/Metrics.py +50 -0
  37. spikezoo/archs/spikeformer/Metrics/__init__.py +0 -0
  38. spikezoo/archs/spikeformer/Model/Loss.py +89 -0
  39. spikezoo/archs/spikeformer/Model/SpikeFormer.py +230 -0
  40. spikezoo/archs/spikeformer/Model/__init__.py +0 -0
  41. spikezoo/archs/spikeformer/Model/__pycache__/SpikeFormer.cpython-39.pyc +0 -0
  42. spikezoo/archs/spikeformer/Model/__pycache__/__init__.cpython-39.pyc +0 -0
  43. spikezoo/archs/spikeformer/README.md +30 -0
  44. spikezoo/archs/spikeformer/evaluate.py +87 -0
  45. spikezoo/archs/spikeformer/recon_real_data.py +97 -0
  46. spikezoo/archs/spikeformer/requirements.yml +95 -0
  47. spikezoo/archs/spikeformer/train.py +173 -0
  48. spikezoo/archs/spikeformer/utils.py +22 -0
  49. spikezoo/archs/spk2imgnet/.github/workflows/pylint.yml +23 -0
  50. spikezoo/archs/spk2imgnet/.gitignore +150 -0
  51. spikezoo/archs/spk2imgnet/DCNv2.py +135 -0
  52. spikezoo/archs/spk2imgnet/__pycache__/DCNv2.cpython-39.pyc +0 -0
  53. spikezoo/archs/spk2imgnet/__pycache__/align_arch.cpython-39.pyc +0 -0
  54. spikezoo/archs/spk2imgnet/__pycache__/nets.cpython-39.pyc +0 -0
  55. spikezoo/archs/spk2imgnet/align_arch.py +159 -0
  56. spikezoo/archs/spk2imgnet/dataset.py +144 -0
  57. spikezoo/archs/spk2imgnet/nets.py +230 -0
  58. spikezoo/archs/spk2imgnet/readme.md +86 -0
  59. spikezoo/archs/spk2imgnet/test_gen_imgseq.py +118 -0
  60. spikezoo/archs/spk2imgnet/train.py +189 -0
  61. spikezoo/archs/spk2imgnet/utils.py +64 -0
  62. spikezoo/archs/ssir/README.md +87 -0
  63. spikezoo/archs/ssir/configs/SSIR.yml +37 -0
  64. spikezoo/archs/ssir/configs/yml_parser.py +78 -0
  65. spikezoo/archs/ssir/datasets/dataset_sreds.py +170 -0
  66. spikezoo/archs/ssir/datasets/ds_utils.py +66 -0
  67. spikezoo/archs/ssir/losses.py +21 -0
  68. spikezoo/archs/ssir/main.py +326 -0
  69. spikezoo/archs/ssir/metrics/psnr.py +22 -0
  70. spikezoo/archs/ssir/metrics/ssim.py +54 -0
  71. spikezoo/archs/ssir/models/Vgg19.py +42 -0
  72. spikezoo/archs/ssir/models/__pycache__/layers.cpython-39.pyc +0 -0
  73. spikezoo/archs/ssir/models/__pycache__/networks.cpython-39.pyc +0 -0
  74. spikezoo/archs/ssir/models/layers.py +110 -0
  75. spikezoo/archs/ssir/models/networks.py +61 -0
  76. spikezoo/archs/ssir/requirements.txt +8 -0
  77. spikezoo/archs/ssir/shells/eval_SREDS.sh +6 -0
  78. spikezoo/archs/ssir/shells/train_SSIR.sh +12 -0
  79. spikezoo/archs/ssir/test.py +3 -0
  80. spikezoo/archs/ssir/utils.py +154 -0
  81. spikezoo/archs/ssml/__pycache__/cbam.cpython-39.pyc +0 -0
  82. spikezoo/archs/ssml/__pycache__/model.cpython-39.pyc +0 -0
  83. spikezoo/archs/ssml/cbam.py +224 -0
  84. spikezoo/archs/ssml/model.py +290 -0
  85. spikezoo/archs/ssml/res.png +0 -0
  86. spikezoo/archs/ssml/test.py +67 -0
  87. spikezoo/archs/stir/.git-credentials +0 -0
  88. spikezoo/archs/stir/README.md +65 -0
  89. spikezoo/archs/stir/ckpt_outputs/Descriptions.txt +1 -0
  90. spikezoo/archs/stir/configs/STIR.yml +37 -0
  91. spikezoo/archs/stir/configs/utils.py +155 -0
  92. spikezoo/archs/stir/configs/yml_parser.py +78 -0
  93. spikezoo/archs/stir/datasets/dataset_sreds.py +180 -0
  94. spikezoo/archs/stir/datasets/ds_utils.py +66 -0
  95. spikezoo/archs/stir/eval_SREDS.sh +5 -0
  96. spikezoo/archs/stir/main.py +397 -0
  97. spikezoo/archs/stir/metrics/losses.py +219 -0
  98. spikezoo/archs/stir/metrics/psnr.py +22 -0
  99. spikezoo/archs/stir/metrics/ssim.py +54 -0
  100. spikezoo/archs/stir/models/Vgg19.py +42 -0
  101. spikezoo/archs/stir/models/__pycache__/networks_STIR.cpython-39.pyc +0 -0
  102. spikezoo/archs/stir/models/__pycache__/submodules.cpython-39.pyc +0 -0
  103. spikezoo/archs/stir/models/__pycache__/transformer_new.cpython-39.pyc +0 -0
  104. spikezoo/archs/stir/models/networks_STIR.py +361 -0
  105. spikezoo/archs/stir/models/submodules.py +86 -0
  106. spikezoo/archs/stir/models/transformer_new.py +151 -0
  107. spikezoo/archs/stir/package_core/build/lib/package_core/__init__.py +0 -0
  108. spikezoo/archs/stir/package_core/build/lib/package_core/convertions.py +721 -0
  109. spikezoo/archs/stir/package_core/build/lib/package_core/disp_netS.py +133 -0
  110. spikezoo/archs/stir/package_core/build/lib/package_core/flow_utils.py +167 -0
  111. spikezoo/archs/stir/package_core/build/lib/package_core/generic_train_test.py +76 -0
  112. spikezoo/archs/stir/package_core/build/lib/package_core/geometry.py +458 -0
  113. spikezoo/archs/stir/package_core/build/lib/package_core/image_proc.py +183 -0
  114. spikezoo/archs/stir/package_core/build/lib/package_core/linalg.py +40 -0
  115. spikezoo/archs/stir/package_core/build/lib/package_core/losses.py +198 -0
  116. spikezoo/archs/stir/package_core/build/lib/package_core/metrics.py +51 -0
  117. spikezoo/archs/stir/package_core/build/lib/package_core/model_base.py +53 -0
  118. spikezoo/archs/stir/package_core/build/lib/package_core/net_basics.py +100 -0
  119. spikezoo/archs/stir/package_core/build/lib/package_core/resnet.py +333 -0
  120. spikezoo/archs/stir/package_core/build/lib/package_core/transforms.py +123 -0
  121. spikezoo/archs/stir/package_core/build/lib/package_core/utils.py +72 -0
  122. spikezoo/archs/stir/package_core/dist/package_core-0.0.0-py3.9.egg +0 -0
  123. spikezoo/archs/stir/package_core/package_core/__init__.py +0 -0
  124. spikezoo/archs/stir/package_core/package_core/__pycache__/__init__.cpython-39.pyc +0 -0
  125. spikezoo/archs/stir/package_core/package_core/__pycache__/net_basics.cpython-39.pyc +0 -0
  126. spikezoo/archs/stir/package_core/package_core/convertions.py +721 -0
  127. spikezoo/archs/stir/package_core/package_core/disp_netS.py +133 -0
  128. spikezoo/archs/stir/package_core/package_core/flow_utils.py +167 -0
  129. spikezoo/archs/stir/package_core/package_core/generic_train_test.py +76 -0
  130. spikezoo/archs/stir/package_core/package_core/geometry.py +458 -0
  131. spikezoo/archs/stir/package_core/package_core/image_proc.py +183 -0
  132. spikezoo/archs/stir/package_core/package_core/linalg.py +40 -0
  133. spikezoo/archs/stir/package_core/package_core/losses.py +198 -0
  134. spikezoo/archs/stir/package_core/package_core/metrics.py +51 -0
  135. spikezoo/archs/stir/package_core/package_core/model_base.py +53 -0
  136. spikezoo/archs/stir/package_core/package_core/net_basics.py +100 -0
  137. spikezoo/archs/stir/package_core/package_core/resnet.py +333 -0
  138. spikezoo/archs/stir/package_core/package_core/transforms.py +123 -0
  139. spikezoo/archs/stir/package_core/package_core/utils.py +72 -0
  140. spikezoo/archs/stir/package_core/package_core.egg-info/PKG-INFO +3 -0
  141. spikezoo/archs/stir/package_core/package_core.egg-info/SOURCES.txt +20 -0
  142. spikezoo/archs/stir/package_core/package_core.egg-info/dependency_links.txt +1 -0
  143. spikezoo/archs/stir/package_core/package_core.egg-info/top_level.txt +1 -0
  144. spikezoo/archs/stir/package_core/setup.py +5 -0
  145. spikezoo/archs/stir/requirements.txt +12 -0
  146. spikezoo/archs/stir/train_STIR.sh +9 -0
  147. spikezoo/archs/tfi/__pycache__/nets.cpython-39.pyc +0 -0
  148. spikezoo/archs/tfi/nets.py +43 -0
  149. spikezoo/archs/tfp/__pycache__/nets.cpython-39.pyc +0 -0
  150. spikezoo/archs/tfp/nets.py +13 -0
  151. spikezoo/archs/wgse/README.md +64 -0
  152. spikezoo/archs/wgse/__pycache__/dwtnets.cpython-39.pyc +0 -0
  153. spikezoo/archs/wgse/__pycache__/submodules.cpython-39.pyc +0 -0
  154. spikezoo/archs/wgse/dataset.py +59 -0
  155. spikezoo/archs/wgse/demo.png +0 -0
  156. spikezoo/archs/wgse/demo.py +83 -0
  157. spikezoo/archs/wgse/dwtnets.py +145 -0
  158. spikezoo/archs/wgse/eval.py +133 -0
  159. spikezoo/archs/wgse/logs/WGSE-Dwt1dNet-db8-5-ks3/log.txt +11 -0
  160. spikezoo/archs/wgse/submodules.py +68 -0
  161. spikezoo/archs/wgse/train.py +261 -0
  162. spikezoo/archs/wgse/transform.py +139 -0
  163. spikezoo/archs/wgse/utils.py +128 -0
  164. spikezoo/archs/wgse/weights/demo.png +0 -0
  165. spikezoo/data/base/test/gt/200_part1_key_id151.png +0 -0
  166. spikezoo/data/base/test/gt/200_part3_key_id151.png +0 -0
  167. spikezoo/data/base/test/gt/203_part1_key_id151.png +0 -0
  168. spikezoo/data/base/test/spike/200_part1_key_id151.dat +0 -0
  169. spikezoo/data/base/test/spike/200_part3_key_id151.dat +0 -0
  170. spikezoo/data/base/test/spike/203_part1_key_id151.dat +0 -0
  171. spikezoo/data/base/train/gt/203_part2_key_id151.png +0 -0
  172. spikezoo/data/base/train/gt/203_part3_key_id151.png +0 -0
  173. spikezoo/data/base/train/gt/203_part4_key_id151.png +0 -0
  174. spikezoo/data/base/train/spike/203_part2_key_id151.dat +0 -0
  175. spikezoo/data/base/train/spike/203_part3_key_id151.dat +0 -0
  176. spikezoo/data/base/train/spike/203_part4_key_id151.dat +0 -0
  177. spikezoo/datasets/base_dataset.py +2 -3
  178. spikezoo/metrics/__init__.py +1 -1
  179. spikezoo/models/base_model.py +1 -3
  180. spikezoo/pipeline/base_pipeline.py +7 -5
  181. spikezoo/pipeline/train_pipeline.py +1 -1
  182. spikezoo/utils/other_utils.py +16 -6
  183. spikezoo/utils/spike_utils.py +33 -29
  184. spikezoo/utils/vidar_loader.cpython-39-x86_64-linux-gnu.so +0 -0
  185. spikezoo-0.2.dist-info/METADATA +163 -0
  186. spikezoo-0.2.dist-info/RECORD +211 -0
  187. spikezoo/models/spcsnet_model.py +0 -19
  188. spikezoo-0.1.1.dist-info/METADATA +0 -39
  189. spikezoo-0.1.1.dist-info/RECORD +0 -36
  190. {spikezoo-0.1.1.dist-info → spikezoo-0.2.dist-info}/LICENSE.txt +0 -0
  191. {spikezoo-0.1.1.dist-info → spikezoo-0.2.dist-info}/WHEEL +0 -0
  192. {spikezoo-0.1.1.dist-info → spikezoo-0.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,261 @@
1
+ import argparse
2
+ import json
3
+ import os
4
+ import time
5
+
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn.functional as F
9
+ from einops import rearrange
10
+ from pytorch_wavelets import DWT1DForward
11
+
12
+ from transform import Compose, RandomCrop, RandomRotationFlip
13
+ from dataset import DatasetREDS
14
+ from dwtnets import Dwt1dResnetX_TCN
15
+ from utils import calculate_psnr, calculate_ssim, mkdir
16
+
17
# Command-line interface for WGSE training on the REDS spike dataset.
parser = argparse.ArgumentParser(description='AAAI - WGSE - REDS')
parser.add_argument('-c', '--cuda', type=str, default='1', help='select gpu card')
parser.add_argument('-b', '--batch_size', type=int, default=16)
parser.add_argument('-e', '--epoch', type=int, default=600)
parser.add_argument('-w', '--wvl', type=str, default='db8', help='select wavelet base function')
parser.add_argument('-j', '--jlevels', type=int, default=5)
parser.add_argument('-k', '--kernel_size', type=int, default=3)
parser.add_argument('-l', '--logpath', type=str, default='WGSE-Dwt1dNet')
parser.add_argument('-r', '--resume_from', type=str, default=None)
parser.add_argument('--dataroot', type=str, default=None)

args = parser.parse_args()
# Restrict visible GPUs before any CUDA context is created.
os.environ["CUDA_VISIBLE_DEVICES"] = str(args.cuda)

resume_folder = args.resume_from
batch_size = args.batch_size
learning_rate = 1e-4
train_epoch = args.epoch
dataroot = args.dataroot

# Optimizer name and its keyword arguments; the JSON string is decoded
# inside main() into the kwargs passed to torch.optim.
opt = 'adam'
opt_param = "{\"beta1\":0.9,\"beta2\":0.99,\"weight_decay\":0}"

# When random_seed is True a fresh seed is drawn each run; otherwise
# manual_seed is used for reproducibility.
random_seed = True
manual_seed = 123

# LR schedule: decay by `gamma` at the listed epoch milestones.
scheduler = "MultiStepLR"
scheduler_param = "{\"milestones\": [400, 600], \"gamma\": 0.2}"

wvlname = args.wvl  # wavelet family, e.g. 'db8'
j = args.jlevels  # number of DWT decomposition levels
ks = args.kernel_size  # kernel size passed to the network

if_save_model = False  # per-epoch checkpoint dumps are disabled by default
eval_freq = 1  # evaluate every `eval_freq` epochs
# Log/checkpoint directory name, e.g. 'WGSE-Dwt1dNet-db8-5-ks3'.
checkpoints_folder = args.logpath + '-' + args.wvl + '-' + str(args.jlevels) + '-' + 'ks' + str(ks)
53
+
54
+
55
def progress_bar_time(total_time):
    """Format a duration in seconds as an H:MM:SS string."""
    total = int(total_time)
    hours, remainder = divmod(total, 3600)
    minutes, seconds = divmod(remainder, 60)
    return '%d:%02d:%02d' % (hours, minutes, seconds)
60
+
61
def main():
    """Train the WGSE Dwt1dResnetX_TCN model on REDS spikes and evaluate each epoch."""

    global batch_size, learning_rate, random_seed, manual_seed, opt, opt_param, if_save_model, checkpoints_folder

    # Per-run log directory: logs/<logpath>-<wavelet>-<jlevels>-ks<kernel_size>
    mkdir(os.path.join('logs', checkpoints_folder))

    # Seed both torch and numpy so augmentation and init are repeatable.
    # When random_seed is True the seed itself is random, so runs still differ.
    if random_seed:
        seed = np.random.randint(0, 10000)
    else:
        seed = manual_seed
    torch.manual_seed(seed)
    np.random.seed(seed)

    # Optimizer/scheduler hyper-parameters live as module-level JSON strings;
    # decode them here into keyword-argument dicts.
    opt_param_dict = json.loads(opt_param)
    scheduler_param_dict = json.loads(scheduler_param)

    # Training split: 250x400 clips with 41 spike channels, augmented with a
    # random 128x128 crop plus random horizontal/vertical flips (no rotation).
    cfg = {}
    cfg['rootfolder'] = os.path.join(dataroot, 'train')
    cfg['spikefolder'] = 'input'
    cfg['imagefolder'] = 'gt'
    cfg['H'] = 250
    cfg['W'] = 400
    cfg['C'] = 41
    train_set = DatasetREDS(cfg,
                            transform=Compose(
                                [
                                    RandomCrop(128),
                                    RandomRotationFlip(0.0, 0.5, 0.5)
                                ]),
                            )

    # Validation split: same geometry, no augmentation.
    cfg = {}
    cfg['rootfolder'] = os.path.join(dataroot, 'val')
    cfg['spikefolder'] = 'input'
    cfg['imagefolder'] = 'gt'
    cfg['H'] = 250
    cfg['W'] = 400
    cfg['C'] = 41
    test_set = DatasetREDS(cfg)

    print('train_set len', train_set.__len__())
    print('test_set len', test_set.__len__())

    train_data_loader = torch.utils.data.DataLoader(
        dataset=train_set,
        batch_size=batch_size,
        shuffle=True,
        num_workers=16,
        drop_last=True)
    test_data_loader = torch.utils.data.DataLoader(
        dataset=test_set,
        batch_size=1,
        shuffle=True,
        num_workers=1,
        drop_last=False)

    print(train_data_loader)
    print(test_data_loader)

    # Probe the DWT output lengths on a single pixel's temporal signal so the
    # network can be built with matching low-band/high-band feature sizes.
    item0 = train_set[0]
    s = item0['spikes']
    s = s[None, :, 0:1, 0:1]  # 1 x T x 1 x 1: one pixel's time series
    dwt = DWT1DForward(wave=wvlname, J=j)
    B, T, H, W = s.shape
    s_r = rearrange(s, 'b t h w -> b h w t')
    s_r = rearrange(s_r, 'b h w t -> (b h w) 1 t')
    yl, yh = dwt(s_r)
    yl_size = yl.shape[-1]
    yh_size = [yhi.shape[-1] for yhi in yh]

    model = Dwt1dResnetX_TCN(inc=41, wvlname=wvlname, J=j, yl_size=yl_size, yh_size=yh_size, num_residual_blocks=3, norm=None, ks=ks)

    # Optionally warm-start from a previous run's best checkpoint. The saved
    # object is a DataParallel-wrapped model, hence the .module access.
    if args.resume_from:
        print("loading model weights from ", resume_folder)
        saved_state_dict = torch.load(os.path.join(resume_folder, 'model_best.pt'))
        model.load_state_dict(saved_state_dict.module.state_dict())
        print("Weighted loaded.")

    model = torch.nn.DataParallel(model).cuda()

    # optimizer
    # Adam expects beta1/beta2 in the JSON dict and repacks them into the
    # single 'betas' tuple that torch.optim.Adam takes.
    if opt.lower() == 'adam':
        assert ('beta1' in opt_param_dict.keys() and 'beta2' in opt_param_dict.keys() and 'weight_decay' in opt_param_dict.keys())
        betas = (opt_param_dict['beta1'], opt_param_dict['beta2'])
        del opt_param_dict['beta1']
        del opt_param_dict['beta2']
        opt_param_dict['betas'] = betas
        optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate, **opt_param_dict)
    elif opt.lower() == 'sgd':
        assert ('momentum' in opt_param_dict.keys() and 'weight_decay' in opt_param_dict.keys())
        optimizer = torch.optim.SGD(model.parameters(), lr=learning_rate, **opt_param_dict)
    else:
        raise ValueError()

    # Resolve the scheduler class by name (e.g. MultiStepLR) from the string.
    lr_scheduler = getattr(torch.optim.lr_scheduler, scheduler)(optimizer, **scheduler_param_dict)
    best_psnr, best_ssim = 0.0, 0.0

    for epoch in range(train_epoch+1):
        print('Epoch %d/%d ... ' % (epoch, train_epoch))

        model.train()
        total_time = 0
        # Append this epoch's progress to the persistent run log.
        f = open(os.path.join('logs', checkpoints_folder, 'log.txt'), "a")
        for i, item in enumerate(train_data_loader):

            start_time = time.time()

            spikes = item['spikes'].cuda()
            image = item['image'].cuda()
            optimizer.zero_grad()

            pred = model(spikes)

            # Plain L1 reconstruction loss against the ground-truth image.
            loss = F.l1_loss(image, pred)
            loss.backward()
            optimizer.step()

            elapse_time = time.time() - start_time
            total_time += elapse_time

            lr_list = lr_scheduler.get_last_lr()
            lr_str = ""
            for ilr in lr_list:
                lr_str += str(ilr) + ' '
            # Same progress line goes to the console (carriage-return updated)
            # and, newline-terminated, to the log file.
            print('\r[training] %3.2f%% | %6d/%6d [%s<%s, %.2fs/it] | LOSS: %.4f | LR: %s' % (
                float(i + 1) / int(len(train_data_loader)) * 100, i + 1, int(len(train_data_loader)),
                progress_bar_time(total_time),
                progress_bar_time(total_time / (i + 1) * int(len(train_data_loader))),
                total_time / (i + 1),
                loss.item(),
                lr_str), end='')
            f.write('[training] %3.2f%% | %6d/%6d [%s<%s, %.2fs/it] | LOSS: %.4f | LR: %s\n' % (
                float(i + 1) / int(len(train_data_loader)) * 100, i + 1, int(len(train_data_loader)),
                progress_bar_time(total_time),
                progress_bar_time(total_time / (i + 1) * int(len(train_data_loader))),
                total_time / (i + 1),
                loss.item(),
                lr_str))

        lr_scheduler.step()

        print('')
        if epoch % eval_freq == 0:
            model.eval()
            with torch.no_grad():
                sum_ssim = 0.0
                sum_psnr = 0.0
                sum_num = 0
                total_time = 0
                for i, item in enumerate(test_data_loader):
                    start_time = time.time()

                    # Evaluate on a fixed 41-frame window of the spike stream
                    # (frames 130..170) to match the network's input channels.
                    spikes = item['spikes'][:, 130:171, :, :].cuda()
                    image = item['image'].cuda()

                    pred = model(spikes)

                    prediction = pred[0].permute(1,2,0).cpu().numpy()
                    gt = image[0].permute(1,2,0).cpu().numpy()

                    # Metrics are computed in the [0, 255] range.
                    sum_ssim += calculate_ssim(gt * 255.0, prediction * 255.0)
                    sum_psnr += calculate_psnr(gt * 255.0, prediction * 255.0)
                    sum_num += 1
                    elapse_time = time.time() - start_time
                    total_time += elapse_time
                    print('\r[evaluating] %3.2f%% | %6d/%6d [%s<%s, %.2fs/it]' % (
                        float(i + 1) / int(len(test_data_loader)) * 100, i + 1, int(len(test_data_loader)),
                        progress_bar_time(total_time),
                        progress_bar_time(total_time / (i + 1) * int(len(test_data_loader))),
                        total_time / (i + 1)), end='')
                    f.write('[evaluating] %3.2f%% | %6d/%6d [%s<%s, %.2fs/it]\n' % (
                        float(i + 1) / int(len(test_data_loader)) * 100, i + 1, int(len(test_data_loader)),
                        progress_bar_time(total_time),
                        progress_bar_time(total_time / (i + 1) * int(len(test_data_loader))),
                        total_time / (i + 1)))

                sum_psnr /= sum_num
                sum_ssim /= sum_num

                print('')
                print('\r[Evaluation Result] PSNR: %.3f | SSIM: %.3f' % (sum_psnr, sum_ssim))
                f.write('[Evaluation Result] PSNR: %.3f | SSIM: %.3f\n' % (sum_psnr, sum_ssim))

        if if_save_model and epoch % eval_freq == 0:
            print('saving net...')
            torch.save(model, os.path.join('logs', checkpoints_folder) + '/model_epoch%d.pt' % epoch)
            print('saved')

        # NOTE(review): "best" is updated when EITHER metric improves, and both
        # tracked values are overwritten together — confirm this is the
        # intended checkpoint-selection policy.
        if sum_psnr > best_psnr or sum_ssim > best_ssim:
            best_psnr = sum_psnr
            best_ssim = sum_ssim
            print('saving best net...')
            torch.save(model, os.path.join('logs', checkpoints_folder) + '/model_best.pt')
            print('saved')

        f.close()


if __name__ == '__main__':
    main()
@@ -0,0 +1,139 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from math import sin, cos, pi
4
+ import numbers
5
+ import random
6
+
7
+
8
class Compose(object):
    """Chain several paired transforms into a single callable.

    Every transform must accept and return a pair ``(x, y)`` so the input
    tensor and its target receive identical augmentation.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, x, y):
        # Thread the (x, y) pair through each transform in order.
        for transform in self.transforms:
            x, y = transform(x, y)
        return x, y

    def __repr__(self):
        parts = [self.__class__.__name__ + '(']
        for transform in self.transforms:
            parts.append('\n')
            parts.append('    {0}'.format(transform))
        parts.append('\n)')
        return ''.join(parts)
34
+
35
+
36
class RandomCrop(object):
    """Crop an (x, y) tensor pair at one shared random location.

    ``size`` may be a single number (square crop) or an (h, w) pair.
    """

    def __init__(self, size):
        self.size = (int(size), int(size)) if isinstance(size, numbers.Number) else size

    @staticmethod
    def get_params(x, output_size):
        """Pick a random top-left corner for a crop of ``output_size``.

        Returns (top, left, height, width) of the crop window.
        """
        width, height = x.shape[2], x.shape[1]
        target_h, target_w = output_size
        assert(target_h <= height)
        assert(target_w <= width)
        # Nothing to choose when the crop covers the whole tensor.
        if width == target_w and height == target_h:
            return 0, 0, height, width

        top = random.randint(0, height - target_h)
        left = random.randint(0, width - target_w)

        return top, left, target_h, target_w

    def __call__(self, x, y):
        """
        x: [C x H x W] Tensor to be cropped.
        Returns:
            Tensor: both tensors cropped with the same window.
        """
        top, left, crop_h, crop_w = self.get_params(x, self.size)
        rows = slice(top, top + crop_h)
        cols = slice(left, left + crop_w)
        return x[:, rows, cols], y[:, rows, cols]

    def __repr__(self):
        return '{0}(size={1})'.format(self.__class__.__name__, self.size)
72
+
73
+
74
class RandomRotationFlip(object):
    """Randomly rotate and horizontally/vertically flip an (x, y) tensor pair.

    Args:
        degrees: rotation range in degrees; a single non-negative number d
            means (-d, d), otherwise a (min, max) pair.
        p_hflip: probability of a horizontal flip.
        p_vflip: probability of a vertical flip.

    Raises:
        ValueError: if ``degrees`` is a negative number or a sequence whose
            length is not 2.
    """

    def __init__(self, degrees, p_hflip=0.5, p_vflip=0.5):
        if isinstance(degrees, numbers.Number):
            if degrees < 0:
                raise ValueError("If degrees is a single number, it must be positive.")
            self.degrees = (-degrees, degrees)
        else:
            if len(degrees) != 2:
                raise ValueError("If degrees is a sequence, it must be of len 2.")
            self.degrees = degrees

        self.p_hflip = p_hflip
        self.p_vflip = p_vflip

    @staticmethod
    def get_params(degrees, p_hflip, p_vflip):
        """Sample a random rotation/flip as a pair of 1 x 2 x 3 affine matrices.

        Returns:
            (forward, inverse): the sampled affine map and its inverse, in the
            shape ``F.affine_grid`` expects.
        """
        angle = random.uniform(degrees[0], degrees[1])
        angle_rad = angle * pi / 180.0

        M_original_transformed = torch.FloatTensor([[cos(angle_rad), -sin(angle_rad), 0],
                                                    [sin(angle_rad), cos(angle_rad), 0],
                                                    [0, 0, 1]])

        # A flip is expressed by negating one column of the homogeneous matrix.
        if random.random() < p_hflip:
            M_original_transformed[:, 0] *= -1

        if random.random() < p_vflip:
            M_original_transformed[:, 1] *= -1

        M_transformed_original = torch.inverse(M_original_transformed)

        M_original_transformed = M_original_transformed[:2, :].unsqueeze(dim=0)  # 3 x 3 -> N x 2 x 3
        M_transformed_original = M_transformed_original[:2, :].unsqueeze(dim=0)

        return M_original_transformed, M_transformed_original

    def __call__(self, x, y):
        """
        x: [C x H x W] Tensor to be rotated.
        Returns:
            Tensor: both tensors transformed with the same sampled map.
        """
        assert(len(x.shape) == 3)

        M_original_transformed, M_transformed_original = self.get_params(self.degrees, self.p_hflip, self.p_vflip)
        affine_gridx = F.affine_grid(M_original_transformed, x.unsqueeze(dim=0).shape, align_corners=False)
        transformedx = F.grid_sample(x.unsqueeze(dim=0), affine_gridx, align_corners=False)

        affine_gridy = F.affine_grid(M_original_transformed, y.unsqueeze(dim=0).shape, align_corners=False)
        transformedy = F.grid_sample(y.unsqueeze(dim=0), affine_gridy, align_corners=False)

        return transformedx.squeeze(dim=0), transformedy.squeeze(dim=0)

    def __repr__(self):
        # Fixed: repr previously printed 'p_flip' and the typo 'p_vlip',
        # misreporting the actual constructor parameter names.
        format_string = self.__class__.__name__ + '(degrees={0}'.format(self.degrees)
        format_string += ', p_hflip={:.2f}'.format(self.p_hflip)
        format_string += ', p_vflip={:.2f}'.format(self.p_vflip)
        format_string += ')'
        return format_string
@@ -0,0 +1,128 @@
1
+ import os
2
+ import math
3
+ import numpy as np
4
+ import cv2
5
+ os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
6
+
7
+
8
def RawToSpike(video_seq, h, w, flipud=True):
    """Unpack a raw vidar byte stream into an (img_num, h, w) uint8 spike cube.

    Each frame occupies h*w/8 bytes: pixel index ``idx`` (row-major) lives in
    bit ``idx % 8`` of byte ``idx // 8``. When ``flipud`` is True each frame
    is flipped vertically after decoding.
    """
    video_seq = np.array(video_seq).astype(np.uint8)
    frame_pixels = h * w
    frame_bytes = frame_pixels // 8
    frame_count = len(video_seq) // frame_bytes

    # Per-pixel byte offset and bit mask, shared by every frame.
    pixel_index = np.reshape(np.arange(0, h * w), (h, w))
    bit_mask = np.left_shift(1, np.mod(pixel_index, 8))
    byte_index = pixel_index // 8

    spike_matrix = np.zeros([frame_count, h, w], np.uint8)
    for frame_id in np.arange(frame_count):
        begin = int(frame_id) * int(frame_pixels) // 8
        frame = video_seq[begin:int(begin) + frame_bytes]
        decoded = (np.bitwise_and(frame[byte_index], bit_mask) == bit_mask)
        spike_matrix[frame_id, :, :] = np.flipud(decoded) if flipud else decoded

    return spike_matrix
31
+
32
+ '''
33
+ # --------------------------------------------
34
+ # Kai Zhang (github: https://github.com/cszn)
35
+ # 03/Mar/2019
36
+ # --------------------------------------------
37
+ # https://github.com/twhui/SRGAN-pyTorch
38
+ # https://github.com/xinntao/BasicSR
39
+ # --------------------------------------------
40
+ '''
41
+
42
+
43
def mkdir(path):
    """Create *path* (including missing parents) if it does not already exist.

    Uses ``exist_ok=True`` so concurrent callers cannot race between the
    existence check and the creation (the original check-then-create was
    TOCTOU-prone).
    """
    os.makedirs(path, exist_ok=True)
46
+
47
+
48
def mkdirs(paths):
    """Create one directory (str) or each directory in an iterable of paths."""
    path_list = [paths] if isinstance(paths, str) else paths
    for path in path_list:
        mkdir(path)
54
+
55
+
56
+ # --------------------------------------------
57
+ # PSNR
58
+ # --------------------------------------------
59
def calculate_psnr(img1, img2, border=0):
    """PSNR in dB between two images with values in [0, 255].

    ``border`` pixels on every side are excluded from the comparison.
    Returns float('inf') for identical inputs; raises ValueError on a
    shape mismatch.
    """
    if not img1.shape == img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    h, w = img1.shape[:2]
    region = (slice(border, h - border), slice(border, w - border))
    a = img1[region].astype(np.float64)
    b = img2[region].astype(np.float64)

    mse = np.mean((a - b) ** 2)
    if mse == 0:
        return float('inf')
    return 20 * math.log10(255.0 / math.sqrt(mse))
75
+
76
+
77
+ # --------------------------------------------
78
+ # SSIM
79
+ # --------------------------------------------
80
def calculate_ssim(img1, img2, border=0):
    """SSIM between two images in [0, 255] (matches MATLAB's implementation).

    Grayscale (2-D) input is scored directly; 3-channel input is scored per
    channel and averaged; single-channel 3-D input is squeezed first.
    ``border`` pixels on every side are excluded.
    """
    if not img1.shape == img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    h, w = img1.shape[:2]
    region = (slice(border, h - border), slice(border, w - border))
    img1 = img1[region]
    img2 = img2[region]

    if img1.ndim == 2:
        return ssim(img1, img2)
    elif img1.ndim == 3:
        if img1.shape[2] == 3:
            # Average the per-channel SSIM scores.
            per_channel = [ssim(img1[:, :, c], img2[:, :, c]) for c in range(3)]
            return np.array(per_channel).mean()
        elif img1.shape[2] == 1:
            return ssim(np.squeeze(img1), np.squeeze(img2))
    else:
        raise ValueError('Wrong input image dimensions.')
105
+
106
+
107
def ssim(img1, img2):
    """Single-channel SSIM over the valid (fully-windowed) region.

    Uses an 11x11 Gaussian window (sigma 1.5) and the standard C1/C2
    stabilizers for 8-bit dynamic range.
    """
    C1 = (0.01 * 255)**2
    C2 = (0.03 * 255)**2

    a = img1.astype(np.float64)
    b = img2.astype(np.float64)
    kernel = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(kernel, kernel.transpose())

    def _blur(img):
        # Gaussian-filter and crop 5 px per side so only windows with full
        # support remain ("valid" region).
        return cv2.filter2D(img, -1, window)[5:-5, 5:-5]

    mu1 = _blur(a)
    mu2 = _blur(b)
    mu1_sq = mu1**2
    mu2_sq = mu2**2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = _blur(a**2) - mu1_sq
    sigma2_sq = _blur(b**2) - mu2_sq
    sigma12 = _blur(a * b) - mu1_mu2

    numerator = (2 * mu1_mu2 + C1) * (2 * sigma12 + C2)
    denominator = (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
    return (numerator / denominator).mean()
128
+
Binary file
@@ -6,7 +6,7 @@ import numpy as np
6
6
  from spikezoo.utils.spike_utils import load_vidar_dat
7
7
  import re
8
8
  from dataclasses import dataclass, replace
9
- from typing import Literal
9
+ from typing import Literal,Union
10
10
  import warnings
11
11
  import torch
12
12
  from tqdm import tqdm
@@ -19,7 +19,7 @@ class BaseDatasetConfig:
19
19
  "Dataset name."
20
20
  dataset_name: str = "base"
21
21
  "Directory specifying location of data."
22
- root_dir: Path = Path(__file__).parent.parent / Path("data/base")
22
+ root_dir: Union[str,Path] = Path(__file__).parent.parent / Path("data/base")
23
23
  "Image width."
24
24
  width: int = 400
25
25
  "Image height."
@@ -108,7 +108,6 @@ class BaseDataset(Dataset):
108
108
  spike_name,
109
109
  height=self.cfg.height,
110
110
  width=self.cfg.width,
111
- out_type="float",
112
111
  out_format="tensor",
113
112
  )
114
113
  return spike
@@ -11,7 +11,7 @@ import torch.nn.functional as F
11
11
 
12
12
  # todo with the union type
13
13
  metric_pair_names = ["psnr", "ssim", "lpips", "mse"]
14
- metric_single_names = ["niqe", "brisque", "piqe"]
14
+ metric_single_names = ["niqe", "brisque", "piqe", "liqe_mix", "clipiqa"]
15
15
  metric_all_names = metric_pair_names + metric_single_names
16
16
 
17
17
  metric_single_list = {}
@@ -45,7 +45,6 @@ class BaseModel(nn.Module):
45
45
  self.device = "cuda" if torch.cuda.is_available() else "cpu"
46
46
  self.net = self.build_network().to(self.device)
47
47
  self.net = nn.DataParallel(self.net) if cfg.multi_gpu == True else self.net
48
- self.spike_size = None
49
48
  self.model_half_win_length: int = cfg.model_win_length // 2
50
49
 
51
50
  # ! Might lead to low speed training on the BSF.
@@ -104,8 +103,7 @@ class BaseModel(nn.Module):
104
103
  :,
105
104
  spike_mid - self.model_half_win_length : spike_mid + self.model_half_win_length + 1,
106
105
  ]
107
- if self.spike_size == None:
108
- self.spike_size = (spike.shape[2], spike.shape[3])
106
+ self.spike_size = (spike.shape[2], spike.shape[3])
109
107
  return spike
110
108
 
111
109
  def preprocess_spike(self, spike):
@@ -67,7 +67,7 @@ class Pipeline:
67
67
  """Pipeline setup."""
68
68
  # save folder
69
69
  self.thistime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S-%f")[:23]
70
- self.save_folder = Path(__file__).parent.parent / Path(f"results") if len(self.cfg.save_folder) == 0 else self.cfg.save_folder
70
+ self.save_folder = Path(f"results") if len(self.cfg.save_folder) == 0 else self.cfg.save_folder
71
71
  mode_name = "train" if self.cfg._mode == "train_mode" else "detect"
72
72
  self.save_folder = (
73
73
  self.save_folder / Path(f"{mode_name}/{self.thistime}")
@@ -93,6 +93,7 @@ class Pipeline:
93
93
  def spk2img_from_dataset(self, idx=0):
94
94
  """Func---Save the recoverd image and calculate the metric from the given dataset."""
95
95
  # save folder
96
+ self.logger.info("*********************** spk2img_from_dataset ***********************")
96
97
  save_folder = self.save_folder / Path(f"spk2img_from_dataset/{self.dataset.cfg.dataset_name}_dataset/{self.dataset.cfg.split}/{idx:06d}")
97
98
  os.makedirs(str(save_folder), exist_ok=True)
98
99
 
@@ -106,9 +107,10 @@ class Pipeline:
106
107
  img = None
107
108
  return self._spk2img(spike, img, save_folder)
108
109
 
109
- def spk2img_from_file(self, file_path, height, width, img_path=None, remove_head=False):
110
+ def spk2img_from_file(self, file_path, height = -1, width = -1, img_path=None, remove_head=False):
110
111
  """Func---Save the recoverd image and calculate the metric from the given input file."""
111
112
  # save folder
113
+ self.logger.info("*********************** spk2img_from_file ***********************")
112
114
  save_folder = self.save_folder / Path(f"spk2img_from_file/{os.path.basename(file_path)}")
113
115
  os.makedirs(str(save_folder), exist_ok=True)
114
116
 
@@ -135,6 +137,7 @@ class Pipeline:
135
137
  def spk2img_from_spk(self, spike, img=None):
136
138
  """Func---Save the recoverd image and calculate the metric from the given spike stream."""
137
139
  # save folder
140
+ self.logger.info("*********************** spk2img_from_spk ***********************")
138
141
  save_folder = self.save_folder / Path(f"spk2img_from_spk/{self.thistime}")
139
142
  os.makedirs(str(save_folder), exist_ok=True)
140
143
 
@@ -188,7 +191,7 @@ class Pipeline:
188
191
  if self.cfg.save_metric == True:
189
192
  self.logger.info(f"----------------------Method: {model_name.upper()}----------------------")
190
193
  # paired metric
191
- for metric_name in metric_all_names:
194
+ for metric_name in self.cfg.metric_names:
192
195
  if img is not None and metric_name in metric_pair_names:
193
196
  self.logger.info(f"{metric_name.upper()}: {cal_metric_pair(recon_img,img,metric_name)}")
194
197
  elif metric_name in metric_single_names:
@@ -203,8 +206,7 @@ class Pipeline:
203
206
  if img is not None:
204
207
  img = tensor2npy(img[0, 0])
205
208
  cv2.imwrite(f"{save_folder}/sharp_img.png", img)
206
- self.logger.info(f"Images are saved on the {save_folder}")
207
-
209
+ self.logger.info(f"Images are saved on the {save_folder}")
208
210
  return recon_img_copy
209
211
 
210
212
  def _post_process_img(self, model_name, recon_img, gt_img):
@@ -66,7 +66,7 @@ class TrainPipeline(Pipeline):
66
66
  save_folder = self.save_folder / Path("imgs") / Path(f"{epoch:06d}")
67
67
  os.makedirs(save_folder, exist_ok=True)
68
68
  for batch_idx, batch in enumerate(tqdm(self.dataloader)):
69
- if batch_idx % (len(self.dataloader) // 4) == 0:
69
+ if batch_idx % (len(self.dataloader) // 4) != 0:
70
70
  continue
71
71
  batch = self.model.feed_to_device(batch)
72
72
  outputs = self.model.get_outputs_dict(batch)