spikezoo 0.1.1__py3-none-any.whl → 0.2__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (192) hide show
  1. spikezoo/__init__.py +13 -0
  2. spikezoo/archs/__pycache__/__init__.cpython-39.pyc +0 -0
  3. spikezoo/archs/base/__pycache__/nets.cpython-39.pyc +0 -0
  4. spikezoo/archs/base/nets.py +34 -0
  5. spikezoo/archs/bsf/README.md +92 -0
  6. spikezoo/archs/bsf/datasets/datasets.py +328 -0
  7. spikezoo/archs/bsf/datasets/ds_utils.py +64 -0
  8. spikezoo/archs/bsf/main.py +398 -0
  9. spikezoo/archs/bsf/metrics/psnr.py +22 -0
  10. spikezoo/archs/bsf/metrics/ssim.py +54 -0
  11. spikezoo/archs/bsf/models/bsf/__pycache__/align.cpython-39.pyc +0 -0
  12. spikezoo/archs/bsf/models/bsf/__pycache__/bsf.cpython-39.pyc +0 -0
  13. spikezoo/archs/bsf/models/bsf/__pycache__/rep.cpython-39.pyc +0 -0
  14. spikezoo/archs/bsf/models/bsf/align.py +154 -0
  15. spikezoo/archs/bsf/models/bsf/bsf.py +105 -0
  16. spikezoo/archs/bsf/models/bsf/dsft_convert.py +96 -0
  17. spikezoo/archs/bsf/models/bsf/rep.py +44 -0
  18. spikezoo/archs/bsf/models/get_model.py +7 -0
  19. spikezoo/archs/bsf/prepare_data/DSFT.py +62 -0
  20. spikezoo/archs/bsf/prepare_data/crop_dataset_train.py +135 -0
  21. spikezoo/archs/bsf/prepare_data/crop_dataset_val.py +139 -0
  22. spikezoo/archs/bsf/prepare_data/crop_train.sh +4 -0
  23. spikezoo/archs/bsf/prepare_data/crop_val.sh +4 -0
  24. spikezoo/archs/bsf/prepare_data/io_utils.py +64 -0
  25. spikezoo/archs/bsf/requirements.txt +9 -0
  26. spikezoo/archs/bsf/test.py +16 -0
  27. spikezoo/archs/bsf/utils.py +154 -0
  28. spikezoo/archs/spikeclip/__pycache__/nets.cpython-39.pyc +0 -0
  29. spikezoo/archs/spikeclip/nets.py +40 -0
  30. spikezoo/archs/spikeformer/CheckPoints/readme +1 -0
  31. spikezoo/archs/spikeformer/DataProcess/DataExtactor.py +60 -0
  32. spikezoo/archs/spikeformer/DataProcess/DataLoader.py +115 -0
  33. spikezoo/archs/spikeformer/DataProcess/LoadSpike.py +39 -0
  34. spikezoo/archs/spikeformer/EvalResults/readme +1 -0
  35. spikezoo/archs/spikeformer/LICENSE +21 -0
  36. spikezoo/archs/spikeformer/Metrics/Metrics.py +50 -0
  37. spikezoo/archs/spikeformer/Metrics/__init__.py +0 -0
  38. spikezoo/archs/spikeformer/Model/Loss.py +89 -0
  39. spikezoo/archs/spikeformer/Model/SpikeFormer.py +230 -0
  40. spikezoo/archs/spikeformer/Model/__init__.py +0 -0
  41. spikezoo/archs/spikeformer/Model/__pycache__/SpikeFormer.cpython-39.pyc +0 -0
  42. spikezoo/archs/spikeformer/Model/__pycache__/__init__.cpython-39.pyc +0 -0
  43. spikezoo/archs/spikeformer/README.md +30 -0
  44. spikezoo/archs/spikeformer/evaluate.py +87 -0
  45. spikezoo/archs/spikeformer/recon_real_data.py +97 -0
  46. spikezoo/archs/spikeformer/requirements.yml +95 -0
  47. spikezoo/archs/spikeformer/train.py +173 -0
  48. spikezoo/archs/spikeformer/utils.py +22 -0
  49. spikezoo/archs/spk2imgnet/.github/workflows/pylint.yml +23 -0
  50. spikezoo/archs/spk2imgnet/.gitignore +150 -0
  51. spikezoo/archs/spk2imgnet/DCNv2.py +135 -0
  52. spikezoo/archs/spk2imgnet/__pycache__/DCNv2.cpython-39.pyc +0 -0
  53. spikezoo/archs/spk2imgnet/__pycache__/align_arch.cpython-39.pyc +0 -0
  54. spikezoo/archs/spk2imgnet/__pycache__/nets.cpython-39.pyc +0 -0
  55. spikezoo/archs/spk2imgnet/align_arch.py +159 -0
  56. spikezoo/archs/spk2imgnet/dataset.py +144 -0
  57. spikezoo/archs/spk2imgnet/nets.py +230 -0
  58. spikezoo/archs/spk2imgnet/readme.md +86 -0
  59. spikezoo/archs/spk2imgnet/test_gen_imgseq.py +118 -0
  60. spikezoo/archs/spk2imgnet/train.py +189 -0
  61. spikezoo/archs/spk2imgnet/utils.py +64 -0
  62. spikezoo/archs/ssir/README.md +87 -0
  63. spikezoo/archs/ssir/configs/SSIR.yml +37 -0
  64. spikezoo/archs/ssir/configs/yml_parser.py +78 -0
  65. spikezoo/archs/ssir/datasets/dataset_sreds.py +170 -0
  66. spikezoo/archs/ssir/datasets/ds_utils.py +66 -0
  67. spikezoo/archs/ssir/losses.py +21 -0
  68. spikezoo/archs/ssir/main.py +326 -0
  69. spikezoo/archs/ssir/metrics/psnr.py +22 -0
  70. spikezoo/archs/ssir/metrics/ssim.py +54 -0
  71. spikezoo/archs/ssir/models/Vgg19.py +42 -0
  72. spikezoo/archs/ssir/models/__pycache__/layers.cpython-39.pyc +0 -0
  73. spikezoo/archs/ssir/models/__pycache__/networks.cpython-39.pyc +0 -0
  74. spikezoo/archs/ssir/models/layers.py +110 -0
  75. spikezoo/archs/ssir/models/networks.py +61 -0
  76. spikezoo/archs/ssir/requirements.txt +8 -0
  77. spikezoo/archs/ssir/shells/eval_SREDS.sh +6 -0
  78. spikezoo/archs/ssir/shells/train_SSIR.sh +12 -0
  79. spikezoo/archs/ssir/test.py +3 -0
  80. spikezoo/archs/ssir/utils.py +154 -0
  81. spikezoo/archs/ssml/__pycache__/cbam.cpython-39.pyc +0 -0
  82. spikezoo/archs/ssml/__pycache__/model.cpython-39.pyc +0 -0
  83. spikezoo/archs/ssml/cbam.py +224 -0
  84. spikezoo/archs/ssml/model.py +290 -0
  85. spikezoo/archs/ssml/res.png +0 -0
  86. spikezoo/archs/ssml/test.py +67 -0
  87. spikezoo/archs/stir/.git-credentials +0 -0
  88. spikezoo/archs/stir/README.md +65 -0
  89. spikezoo/archs/stir/ckpt_outputs/Descriptions.txt +1 -0
  90. spikezoo/archs/stir/configs/STIR.yml +37 -0
  91. spikezoo/archs/stir/configs/utils.py +155 -0
  92. spikezoo/archs/stir/configs/yml_parser.py +78 -0
  93. spikezoo/archs/stir/datasets/dataset_sreds.py +180 -0
  94. spikezoo/archs/stir/datasets/ds_utils.py +66 -0
  95. spikezoo/archs/stir/eval_SREDS.sh +5 -0
  96. spikezoo/archs/stir/main.py +397 -0
  97. spikezoo/archs/stir/metrics/losses.py +219 -0
  98. spikezoo/archs/stir/metrics/psnr.py +22 -0
  99. spikezoo/archs/stir/metrics/ssim.py +54 -0
  100. spikezoo/archs/stir/models/Vgg19.py +42 -0
  101. spikezoo/archs/stir/models/__pycache__/networks_STIR.cpython-39.pyc +0 -0
  102. spikezoo/archs/stir/models/__pycache__/submodules.cpython-39.pyc +0 -0
  103. spikezoo/archs/stir/models/__pycache__/transformer_new.cpython-39.pyc +0 -0
  104. spikezoo/archs/stir/models/networks_STIR.py +361 -0
  105. spikezoo/archs/stir/models/submodules.py +86 -0
  106. spikezoo/archs/stir/models/transformer_new.py +151 -0
  107. spikezoo/archs/stir/package_core/build/lib/package_core/__init__.py +0 -0
  108. spikezoo/archs/stir/package_core/build/lib/package_core/convertions.py +721 -0
  109. spikezoo/archs/stir/package_core/build/lib/package_core/disp_netS.py +133 -0
  110. spikezoo/archs/stir/package_core/build/lib/package_core/flow_utils.py +167 -0
  111. spikezoo/archs/stir/package_core/build/lib/package_core/generic_train_test.py +76 -0
  112. spikezoo/archs/stir/package_core/build/lib/package_core/geometry.py +458 -0
  113. spikezoo/archs/stir/package_core/build/lib/package_core/image_proc.py +183 -0
  114. spikezoo/archs/stir/package_core/build/lib/package_core/linalg.py +40 -0
  115. spikezoo/archs/stir/package_core/build/lib/package_core/losses.py +198 -0
  116. spikezoo/archs/stir/package_core/build/lib/package_core/metrics.py +51 -0
  117. spikezoo/archs/stir/package_core/build/lib/package_core/model_base.py +53 -0
  118. spikezoo/archs/stir/package_core/build/lib/package_core/net_basics.py +100 -0
  119. spikezoo/archs/stir/package_core/build/lib/package_core/resnet.py +333 -0
  120. spikezoo/archs/stir/package_core/build/lib/package_core/transforms.py +123 -0
  121. spikezoo/archs/stir/package_core/build/lib/package_core/utils.py +72 -0
  122. spikezoo/archs/stir/package_core/dist/package_core-0.0.0-py3.9.egg +0 -0
  123. spikezoo/archs/stir/package_core/package_core/__init__.py +0 -0
  124. spikezoo/archs/stir/package_core/package_core/__pycache__/__init__.cpython-39.pyc +0 -0
  125. spikezoo/archs/stir/package_core/package_core/__pycache__/net_basics.cpython-39.pyc +0 -0
  126. spikezoo/archs/stir/package_core/package_core/convertions.py +721 -0
  127. spikezoo/archs/stir/package_core/package_core/disp_netS.py +133 -0
  128. spikezoo/archs/stir/package_core/package_core/flow_utils.py +167 -0
  129. spikezoo/archs/stir/package_core/package_core/generic_train_test.py +76 -0
  130. spikezoo/archs/stir/package_core/package_core/geometry.py +458 -0
  131. spikezoo/archs/stir/package_core/package_core/image_proc.py +183 -0
  132. spikezoo/archs/stir/package_core/package_core/linalg.py +40 -0
  133. spikezoo/archs/stir/package_core/package_core/losses.py +198 -0
  134. spikezoo/archs/stir/package_core/package_core/metrics.py +51 -0
  135. spikezoo/archs/stir/package_core/package_core/model_base.py +53 -0
  136. spikezoo/archs/stir/package_core/package_core/net_basics.py +100 -0
  137. spikezoo/archs/stir/package_core/package_core/resnet.py +333 -0
  138. spikezoo/archs/stir/package_core/package_core/transforms.py +123 -0
  139. spikezoo/archs/stir/package_core/package_core/utils.py +72 -0
  140. spikezoo/archs/stir/package_core/package_core.egg-info/PKG-INFO +3 -0
  141. spikezoo/archs/stir/package_core/package_core.egg-info/SOURCES.txt +20 -0
  142. spikezoo/archs/stir/package_core/package_core.egg-info/dependency_links.txt +1 -0
  143. spikezoo/archs/stir/package_core/package_core.egg-info/top_level.txt +1 -0
  144. spikezoo/archs/stir/package_core/setup.py +5 -0
  145. spikezoo/archs/stir/requirements.txt +12 -0
  146. spikezoo/archs/stir/train_STIR.sh +9 -0
  147. spikezoo/archs/tfi/__pycache__/nets.cpython-39.pyc +0 -0
  148. spikezoo/archs/tfi/nets.py +43 -0
  149. spikezoo/archs/tfp/__pycache__/nets.cpython-39.pyc +0 -0
  150. spikezoo/archs/tfp/nets.py +13 -0
  151. spikezoo/archs/wgse/README.md +64 -0
  152. spikezoo/archs/wgse/__pycache__/dwtnets.cpython-39.pyc +0 -0
  153. spikezoo/archs/wgse/__pycache__/submodules.cpython-39.pyc +0 -0
  154. spikezoo/archs/wgse/dataset.py +59 -0
  155. spikezoo/archs/wgse/demo.png +0 -0
  156. spikezoo/archs/wgse/demo.py +83 -0
  157. spikezoo/archs/wgse/dwtnets.py +145 -0
  158. spikezoo/archs/wgse/eval.py +133 -0
  159. spikezoo/archs/wgse/logs/WGSE-Dwt1dNet-db8-5-ks3/log.txt +11 -0
  160. spikezoo/archs/wgse/submodules.py +68 -0
  161. spikezoo/archs/wgse/train.py +261 -0
  162. spikezoo/archs/wgse/transform.py +139 -0
  163. spikezoo/archs/wgse/utils.py +128 -0
  164. spikezoo/archs/wgse/weights/demo.png +0 -0
  165. spikezoo/data/base/test/gt/200_part1_key_id151.png +0 -0
  166. spikezoo/data/base/test/gt/200_part3_key_id151.png +0 -0
  167. spikezoo/data/base/test/gt/203_part1_key_id151.png +0 -0
  168. spikezoo/data/base/test/spike/200_part1_key_id151.dat +0 -0
  169. spikezoo/data/base/test/spike/200_part3_key_id151.dat +0 -0
  170. spikezoo/data/base/test/spike/203_part1_key_id151.dat +0 -0
  171. spikezoo/data/base/train/gt/203_part2_key_id151.png +0 -0
  172. spikezoo/data/base/train/gt/203_part3_key_id151.png +0 -0
  173. spikezoo/data/base/train/gt/203_part4_key_id151.png +0 -0
  174. spikezoo/data/base/train/spike/203_part2_key_id151.dat +0 -0
  175. spikezoo/data/base/train/spike/203_part3_key_id151.dat +0 -0
  176. spikezoo/data/base/train/spike/203_part4_key_id151.dat +0 -0
  177. spikezoo/datasets/base_dataset.py +2 -3
  178. spikezoo/metrics/__init__.py +1 -1
  179. spikezoo/models/base_model.py +1 -3
  180. spikezoo/pipeline/base_pipeline.py +7 -5
  181. spikezoo/pipeline/train_pipeline.py +1 -1
  182. spikezoo/utils/other_utils.py +16 -6
  183. spikezoo/utils/spike_utils.py +33 -29
  184. spikezoo/utils/vidar_loader.cpython-39-x86_64-linux-gnu.so +0 -0
  185. spikezoo-0.2.dist-info/METADATA +163 -0
  186. spikezoo-0.2.dist-info/RECORD +211 -0
  187. spikezoo/models/spcsnet_model.py +0 -19
  188. spikezoo-0.1.1.dist-info/METADATA +0 -39
  189. spikezoo-0.1.1.dist-info/RECORD +0 -36
  190. {spikezoo-0.1.1.dist-info → spikezoo-0.2.dist-info}/LICENSE.txt +0 -0
  191. {spikezoo-0.1.1.dist-info → spikezoo-0.2.dist-info}/WHEEL +0 -0
  192. {spikezoo-0.1.1.dist-info → spikezoo-0.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,21 @@
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from utils import InputPadder
4
+
5
+
6
def compute_l1_loss(img_list, gt):
    """Multi-scale L1 loss.

    For every prediction in `img_list`, the ground truth `gt` is bilinearly
    resized to that prediction's spatial size and the mean absolute error is
    accumulated. Returns 0.0 for an empty list.
    """
    total = 0.0
    for pred in img_list:
        target = F.interpolate(gt, size=pred.shape[-2:], mode="bilinear", align_corners=False)
        total = total + (pred - target).abs().mean()
    return total
14
+
15
+
16
def compute_per_loss_single(img, gt, vgg):
    """Perceptual loss between one prediction and its ground truth.

    Both single-channel tensors are tiled to 3 channels and mapped from
    [-1, 1] to [0, 1] before being fed to `vgg` (a relu5_1 feature
    extractor); the ground-truth branch runs without gradients.
    """
    pred_feat = vgg((img.repeat([1, 3, 1, 1]) + 1.) / 2.)
    with torch.no_grad():
        gt_feat = vgg((gt.repeat([1, 3, 1, 1]).detach() + 1.) / 2.)
    return F.mse_loss(pred_feat, gt_feat)
@@ -0,0 +1,326 @@
1
+ import argparse
2
+ import os
3
+ import os.path as osp
4
+ import time
5
+ import numpy as np
6
+ import torch
7
+ import torch.optim
8
+ import torch.backends.cudnn as cudnn
9
+ from tensorboardX import SummaryWriter
10
+ import pprint
11
+ import datetime
12
+ from configs.yml_parser import *
13
+ from datasets.dataset_sreds import *
14
+ from models.networks import *
15
+ from utils import *
16
+ from metrics.psnr import *
17
+ from metrics.ssim import *
18
+ import lpips
19
+ from losses import *
20
+ from models.Vgg19 import *
21
+ from spikingjelly.clock_driven import functional
22
+
23
+
24
# Command-line configuration for SSIR training / evaluation.
parser = argparse.ArgumentParser()
parser.add_argument('--data-root', '-dr', type=str, default='/home/data/rzhao/REDS_dataset/REDS120fps')
parser.add_argument('--arch', '-a', type=str, default='SSIR')
parser.add_argument('--batch-size', '-b', type=int, default=8)
parser.add_argument('--learning-rate', '-lr', type=float, default=4e-4)
parser.add_argument('--configs', '-cfg', type=str, default='./configs/SSIR.yml')
parser.add_argument('--epochs', '-ep', type=int, default=100)
parser.add_argument('--epoch-size', '-es', type=int, default=1000)
parser.add_argument('--workers', '-j', type=int, default=8)
parser.add_argument('--pretrained', '-prt', type=str, default=None)
parser.add_argument('--start-epoch', '-sep', type=int, default=0)
parser.add_argument('--print-freq', '-pf', type=int, default=200)
parser.add_argument('--save-dir', '-sd', type=str, default='outputs')
parser.add_argument('--save-name', '-sn', type=str, default=None)
parser.add_argument('--vis-path', '-vp', type=str, default='vis')
parser.add_argument('--vis-name', '-vn', type=str, default='SSIR')
parser.add_argument('--eval_path', '-evp', type=str, default='eval_vis')
parser.add_argument('--vis-freq', '-vf', type=int, default=20)
parser.add_argument('--eval', '-e', action='store_true')
parser.add_argument('--w_per', '-wper', type=float, default=0.2)
parser.add_argument('--print_details', '-pd', action='store_true')
parser.add_argument('--milestones', default=[20, 25, 30, 35, 40, 45, 50, 55, 60, 65, 70], metavar='N', nargs='*')
parser.add_argument('--lr-scale-factor', '-lrsf', type=float, default=0.7)
parser.add_argument('--eval-interval', '-ei', type=int, default=5)
parser.add_argument('--save-interval', '-si', type=int, default=5)
parser.add_argument('--no_imwrite', action='store_true', default=False)
args = parser.parse_args()

# Milestones arrive as strings when given on the CLI; normalize to ints.
args.milestones = [int(m) for m in args.milestones]
# Fix: log message typo ('milstones' -> 'milestones').
print('milestones', args.milestones)

# Load the YAML config and overlay the CLI arguments onto it.
cfg_parser = YAMLParser(args.configs)
cfg = cfg_parser.config

cfg['data']['root'] = args.data_root
cfg = add_args_to_cfg(cfg, args, ['batch_size', 'arch', 'learning_rate', 'configs', 'epochs', 'epoch_size', 'workers', 'pretrained', 'start_epoch',
                                  'print_freq', 'save_dir', 'save_name', 'vis_path', 'vis_name', 'eval_path', 'vis_freq', 'w_per'])

# Global iteration counter shared by train() / validation().
n_iter = 0
63
+
64
+
65
+
66
def train(cfg, train_loader, model, optimizer, epoch, train_writer):
    """Run one training epoch.

    Each batch is a spike sequence; the model is unrolled over the sequence,
    accumulating a multi-scale L1 reconstruction loss plus an optional
    VGG perceptual loss, then a single backward/step is taken per batch.
    """
    global n_iter
    batch_time = AverageMeter()
    data_time = AverageMeter()
    losses_name = ['rec_loss', 'per_loss', 'all_loss']
    losses = AverageMeter(precision=6, i=len(losses_name), names=losses_name)
    model.train()
    end = time.time()

    # Frozen VGG feature extractor for the perceptual term.
    vgg19 = Vgg19(requires_grad=False).cuda()
    if torch.cuda.device_count() > 1:
        vgg19 = nn.DataParallel(vgg19, list(range(torch.cuda.device_count())))

    for ww, data in enumerate(train_loader, 0):
        if ww >= args.epoch_size:
            return

        spikes = [spk.cuda() for spk in data['spikes']]
        images = [img.cuda() for img in data['images']]
        data_time.update(time.time() - end)

        # Concatenate all spike chunks along the temporal/channel axis.
        cur_spks = torch.cat(spikes, dim=1)

        rec_loss = 0.0
        per_loss = 0.0

        # Unroll: a 41-frame spike window centered on step jj (stride 20).
        for jj in range(1, 1 + cfg['model']['seq_len']):
            x = cur_spks[:, jj*20-20 : jj*20+21]
            gt = images[jj]

            out_list = model(x)
            rec_list = [torch.clip(out, 0, 1) for out in out_list]

            # Skip the first step; losses are averaged over the remaining ones.
            if jj >= 2:
                rec_loss += compute_l1_loss(rec_list, gt) / (cfg['model']['seq_len'] - 2)
                if cfg['train']['w_per'] > 0:
                    per_loss += cfg['train']['w_per'] * compute_per_loss_single(rec_list[-1], gt, vgg19) / (cfg['model']['seq_len'] - 2)
                else:
                    per_loss = torch.tensor([0.0]).cuda()

        all_loss = rec_loss + per_loss

        # Bookkeeping for console output and tensorboard.
        losses.update([rec_loss.item(), per_loss.item(), all_loss.item()])
        train_writer.add_scalar('rec_loss', rec_loss.item(), n_iter)
        train_writer.add_scalar('per_loss', per_loss.item(), n_iter)
        train_writer.add_scalar('total_loss', all_loss.item(), n_iter)

        # Optimize, then reset the SNN membrane state before the next batch.
        all_loss.backward()
        optimizer.step()
        optimizer.zero_grad()
        functional.reset_net(model)

        batch_time.update(time.time() - end)
        end = time.time()
        n_iter += 1

        if n_iter % cfg['train']['vis_freq'] == 0:
            vis_img(cfg['train']['vis_path'], rec_list[-1], cfg['train']['vis_name'])

        if ww % cfg['train']['print_freq'] == 0:
            out_str = 'Epoch: [{:d}] [{:d}/{:d}], Iter: {:d} '.format(epoch, ww, len(train_loader), n_iter-1)
            out_str += 'Time: {}, Data: {} '.format(batch_time, data_time)
            out_str += ' '.join(map('{:s} {:.4f} ({:.6f}) '.format, losses.names, losses.val, losses.avg))
            out_str += 'lr {:.6f}'.format(optimizer.state_dict()['param_groups'][0]['lr'])
            print(out_str)

        end = time.time()

    return
146
+
147
+
148
def validation(cfg, test_loader, model):
    """Evaluate the model on the test set.

    Reports PSNR / SSIM / LPIPS and the average per-frame inference time;
    optionally writes each reconstructed frame as a PNG under args.eval_path.
    """
    global n_iter
    batch_time = AverageMeter()
    data_time = AverageMeter()
    metrics_name = ['PSNR', 'SSIM', 'LPIPS', 'AvgTime']
    all_metrics = AverageMeter(i=len(metrics_name), precision=4, names=metrics_name)

    model.eval()

    loss_fn_vgg = lpips.LPIPS(net='alex').cuda()

    for ww, data in enumerate(test_loader, 0):
        st1 = time.time()
        spikes = torch.cat([spk.cuda() for spk in data['spikes']], dim=1)
        images = data['images']
        data_time.update(time.time() - st1)

        seq_metrics = AverageMeter(i=len(metrics_name), precision=4, names=metrics_name)

        # First and last steps are excluded from evaluation.
        seq_len = len(data['spikes']) - 2

        rec = []
        for jj in range(1, 1+seq_len):
            # 41-frame spike window centered on step jj (stride 20).
            x = spikes[:, jj*20-20 : jj*20+21]
            gt = images[jj].cuda()

            with torch.no_grad():
                st = time.time()
                out = model(x)
                mtime = time.time() - st
                rec = torch.clip(out, 0, 1)

            cur_rec = torch2numpy255(rec)
            cur_gt = torch2numpy255(gt)

            if not args.no_imwrite:
                cur_vis_path = osp.join(args.eval_path, '{:03d}_{:03d}.png'.format(ww, jj))
                cv2.imwrite(cur_vis_path, cur_rec.astype(np.uint8))

            cur_psnr = calculate_psnr(cur_rec, cur_gt)
            cur_ssim = calculate_ssim(cur_rec, cur_gt)
            with torch.no_grad():
                cur_lpips = loss_fn_vgg(rec, gt)

            cur_metrics_list = [cur_psnr, cur_ssim, cur_lpips.item(), mtime]
            all_metrics.update(cur_metrics_list)
            seq_metrics.update(cur_metrics_list)

        # Reset SNN membrane state between sequences.
        functional.reset_net(model)

        if args.print_details:
            ostr = 'Data{:02d} '.format(ww) + ' '.join(map('{:s} {:.4f} '.format, seq_metrics.names, seq_metrics.avg))
            print(ostr)
            print()

    ostr = 'All ' + ' '.join(map('{:s} {:.4f} '.format, all_metrics.names, all_metrics.avg))
    print(ostr)

    return
210
+
211
+
212
def main():
    """Entry point: build model, optimizer and loaders, then train or evaluate."""
    # Reproducibility.
    set_seeds(cfg['seed'])

    # Save path: <save_dir>/<month-day>/b<batch>_<HHMMSS>[_<name>]
    timestamp1 = datetime.datetime.now().strftime('%m-%d')
    timestamp2 = datetime.datetime.now().strftime('%H%M%S')
    if args.save_name is None:  # idiom fix: identity test against None
        save_folder_name = 'b{:d}_{:s}'.format(args.batch_size, timestamp2)
    else:
        save_folder_name = 'b{:d}_{:s}_{:s}'.format(args.batch_size, timestamp2, args.save_name)

    save_path = osp.join(args.save_dir, timestamp1, save_folder_name)
    print('save path: ', save_path)
    make_dir(save_path)
    make_dir(args.vis_path)
    make_dir(args.eval_path)

    train_writer = SummaryWriter(save_path)

    cfg_str = pprint.pformat(cfg)
    print('=> configurations: ')
    print(cfg_str)

    ## Create model — the class is looked up by name (e.g. 'SSIR').
    # NOTE(review): eval on a CLI-supplied string; acceptable for a research
    # script but worth replacing with an explicit registry.
    model = eval(args.arch)()

    if args.pretrained:
        network_data = torch.load(args.pretrained)
        print('=> using pretrained model {:s}'.format(args.pretrained))
        # Checkpoints are saved from a DataParallel wrapper, so wrap first.
        model = torch.nn.DataParallel(model).cuda()
        model.load_state_dict(network_data)
    else:
        network_data = None
        print('=> train from scratch')
        model.init_weights()
        print('=> model params: {:.6f}M'.format(model.num_parameters()/1e6))
        # (redundant duplicate .cuda() calls removed — DataParallel(...).cuda()
        # already places the model on the GPU)
        model = torch.nn.DataParallel(model).cuda()

    cudnn.benchmark = True

    ## Create optimizer
    cfgopt = cfg['optimizer']
    cfgmdl = cfg['model']
    assert cfgopt['solver'] in ['Adam', 'SGD']
    print('=> settings {:s} solver'.format(cfgopt['solver']))

    param_groups = [{'params': model.parameters(), 'weight_decay': cfgmdl['flow_weight_decay']}]
    if cfgopt['solver'] == 'Adam':
        optimizer = torch.optim.Adam(param_groups, args.learning_rate, betas=(cfgopt['momentum'], cfgopt['beta']))
    elif cfgopt['solver'] == 'SGD':
        optimizer = torch.optim.SGD(param_groups, args.learning_rate, momentum=cfgopt['momentum'])

    ## Datasets
    train_set = sreds_train(cfg)
    train_loader = torch.utils.data.DataLoader(
        train_set,
        drop_last=False,
        batch_size=cfg['train']['batch_size'],
        shuffle=True,
        num_workers=cfg['train']['workers'],
        # pin_memory=True
    )

    test_set = sreds_test(cfg)
    test_loader = torch.utils.data.DataLoader(
        test_set,
        drop_last=False,
        batch_size=1,
        shuffle=False,
        num_workers=cfg['train']['workers']
    )

    ## Train or evaluate
    if args.eval:
        validation(cfg=cfg, test_loader=test_loader, model=model)
    else:
        epoch = cfg['train']['start_epoch']
        while True:
            train(
                cfg=cfg,
                train_loader=train_loader,
                model=model,
                optimizer=optimizer,
                epoch=epoch,
                train_writer=train_writer
            )
            epoch += 1

            # Step-decay LR schedule at the configured milestones.
            if epoch in args.milestones:
                for param_group in optimizer.param_groups:
                    param_group['lr'] = param_group['lr'] * args.lr_scale_factor

            # Periodic checkpoint.
            if epoch % args.save_interval == 0:
                model_save_name = '{:s}_epoch{:03d}.pth'.format(cfg['model']['arch'], epoch)
                torch.save(model.state_dict(), osp.join(save_path, model_save_name))

            # Periodic evaluation.
            if epoch % args.eval_interval == 0:
                validation(cfg=cfg, test_loader=test_loader, model=model)

            if epoch >= cfg['train']['epochs']:
                break

if __name__ == '__main__':
    main()
@@ -0,0 +1,22 @@
1
+ import math
2
+ import numpy as np
3
+
4
+ # --------------------------------------------
5
+ # PSNR
6
+ # --------------------------------------------
7
def calculate_psnr(img1, img2, border=0):
    """PSNR between two images with values in [0, 255].

    `border` pixels are cropped from every edge before comparison.
    Returns inf when the images are identical; raises ValueError on
    mismatched shapes.
    """
    if img1.shape != img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    h, w = img1.shape[:2]
    a = img1[border:h - border, border:w - border].astype(np.float64)
    b = img2[border:h - border, border:w - border].astype(np.float64)

    mse = np.mean((a - b) ** 2)
    if mse == 0:
        return float('inf')
    return 20 * math.log10(255.0 / math.sqrt(mse))
@@ -0,0 +1,54 @@
1
+ import numpy as np
2
+ import cv2
3
+
4
+ # --------------------------------------------
5
+ # SSIM
6
+ # --------------------------------------------
7
def calculate_ssim(img1, img2, border=0):
    '''MATLAB-compatible SSIM for images with values in [0, 255].

    Crops `border` pixels from every edge, then dispatches on dimensionality:
    grayscale directly, 3-channel as the per-channel mean, single-channel
    3-D arrays after squeezing. Raises ValueError for other ranks.
    '''
    if img1.shape != img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    h, w = img1.shape[:2]
    img1 = img1[border:h - border, border:w - border]
    img2 = img2[border:h - border, border:w - border]

    if img1.ndim == 2:
        return ssim(img1, img2)
    elif img1.ndim == 3:
        if img1.shape[2] == 3:
            per_channel = [ssim(img1[:, :, c], img2[:, :, c]) for c in range(3)]
            return np.array(per_channel).mean()
        elif img1.shape[2] == 1:
            return ssim(np.squeeze(img1), np.squeeze(img2))
        # NOTE(review): other channel counts silently return None (kept as-is).
    else:
        raise ValueError('Wrong input image dimensions.')
32
+
33
+
34
def ssim(img1, img2):
    """Single-channel SSIM with an 11x11 Gaussian window (sigma 1.5).

    Uses MATLAB-style 'valid' cropping (5 px off every edge of the filtered
    maps) and the standard stabilizing constants for a 255 dynamic range.
    """
    C1 = (0.01 * 255) ** 2
    C2 = (0.03 * 255) ** 2

    a = img1.astype(np.float64)
    b = img2.astype(np.float64)
    g = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(g, g.transpose())

    mu1 = cv2.filter2D(a, -1, window)[5:-5, 5:-5]  # valid region only
    mu2 = cv2.filter2D(b, -1, window)[5:-5, 5:-5]
    mu1_sq = mu1 ** 2
    mu2_sq = mu2 ** 2
    mu1_mu2 = mu1 * mu2
    sigma1_sq = cv2.filter2D(a * a, -1, window)[5:-5, 5:-5] - mu1_sq
    sigma2_sq = cv2.filter2D(b * b, -1, window)[5:-5, 5:-5] - mu2_sq
    sigma12 = cv2.filter2D(a * b, -1, window)[5:-5, 5:-5] - mu1_mu2

    numerator = (2 * mu1_mu2 + C1) * (2 * sigma12 + C2)
    denominator = (mu1_sq + mu2_sq + C1) * (sigma1_sq + sigma2_sq + C2)
    return (numerator / denominator).mean()
@@ -0,0 +1,42 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision import models
4
+
5
+
6
class Vgg19(torch.nn.Module):
    """VGG-19 feature extractor truncated at relu5_1 (layers 0-29).

    Inputs are normalized with the ImageNet mean/std (scaled by `rgb_range`)
    via a frozen MeanShift conv before feature extraction.
    """

    def __init__(self, requires_grad=False, rgb_range=1):
        super(Vgg19, self).__init__()

        features = models.vgg19(pretrained=True).features

        # Keep the first 30 layers: conv1_1 .. relu5_1.
        self.slice1 = torch.nn.Sequential()
        for idx in range(30):
            self.slice1.add_module(str(idx), features[idx])

        if not requires_grad:
            for p in self.slice1.parameters():
                p.requires_grad = False

        vgg_mean = (0.485, 0.456, 0.406)
        vgg_std = (0.229 * rgb_range, 0.224 * rgb_range, 0.225 * rgb_range)
        self.sub_mean = MeanShift(rgb_range, vgg_mean, vgg_std)

    def forward(self, X):
        # Normalize, then return relu5_1 features.
        return self.slice1(self.sub_mean(X))
28
+
29
class MeanShift(nn.Conv2d):
    """Frozen 1x1 conv that (de)normalizes RGB channels.

    Computes out[c] = (x[c] + sign * rgb_range * mean[c]) / std[c]; with the
    default sign=-1 this subtracts the mean and divides by the std.
    """

    def __init__(self, rgb_range, rgb_mean, rgb_std, sign=-1):
        super(MeanShift, self).__init__(3, 3, kernel_size=1)
        std = torch.Tensor(rgb_std)
        # Identity kernel scaled per-channel by 1/std.
        self.weight.data = torch.eye(3).view(3, 3, 1, 1) / std.view(3, 1, 1, 1)
        self.bias.data = sign * rgb_range * torch.Tensor(rgb_mean) / std
        # Normalization constants are not trainable.
        self.weight.requires_grad = False
        self.bias.requires_grad = False


if __name__ == '__main__':
    vgg19 = Vgg19(requires_grad=False)
@@ -0,0 +1,110 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import math
5
+ from spikingjelly.clock_driven import layer, neuron, surrogate
6
+
7
# Surrogate-gradient backend for the spiking neurons ('torch' or 'cupy').
backend = 'torch'
# backend = 'cupy'
9
+
10
def get_neuron_code(neuron_type):
    """Return the constructor expression (a source string, eval'd by callers)
    for the requested spiking-neuron type.

    Supported types: 'IF', 'LIF', 'PLIF'.
    Raises ValueError for anything else — the original implementation fell
    through and crashed later with UnboundLocalError.
    """
    if neuron_type == 'IF':
        return 'neuron.IFNode(surrogate_function=surrogate.ATan(), detach_reset=True)'
    elif neuron_type == 'LIF':
        return 'neuron.LIFNode(surrogate_function=surrogate.ATan(), detach_reset=True)'
    elif neuron_type == 'PLIF':
        return 'neuron.ParametricLIFNode(init_tau=2.0, surrogate_function=surrogate.ATan(), detach_reset=True)'
    raise ValueError('Unknown neuron_type: {!r} (expected IF, LIF or PLIF)'.format(neuron_type))
19
+
20
+
21
class ConvLayerSNN(nn.Module):
    """Conv2d -> BatchNorm2d -> spiking neuron (selected by `neuron_type`)."""

    def __init__(self, in_ch, out_ch, kernel_size=3, stride=1, padding=1, neuron_type='PLIF'):
        super().__init__()
        # The neuron is instantiated by eval of a hard-coded code string from
        # get_neuron_code — only trusted literals are evaluated here.
        self.layer = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, kernel_size=kernel_size, padding=padding, stride=stride, bias=False),
            nn.BatchNorm2d(out_ch),
            eval(get_neuron_code(neuron_type)),
        )

    def forward(self, x):
        return self.layer(x)
35
+
36
+
37
class BottleneckBlockSNN(nn.Module):
    """Residual bottleneck: 1x1 reduce -> kxk -> 1x1 expand, each conv+BN
    followed by a spiking neuron, with an identity skip connection.

    Requires in_ch == out_ch for the residual addition.
    """

    def __init__(self, in_ch, out_ch, expansion=4, kernel_size=3, padding=1, neuron_type='PLIF'):
        super().__init__()
        neuron_code = get_neuron_code(neuron_type)
        mid_ch = out_ch // expansion

        # Helper: conv (stride 1, no bias) followed by batch norm.
        def conv_bn(ci, co, k, p):
            return nn.Sequential(
                nn.Conv2d(ci, co, kernel_size=k, padding=p, stride=1, bias=False),
                nn.BatchNorm2d(co),
            )

        self.conv1 = conv_bn(in_ch, mid_ch, 1, 0)
        self.sn1 = eval(neuron_code)
        self.conv2 = conv_bn(mid_ch, mid_ch, kernel_size, padding)
        self.sn2 = eval(neuron_code)
        self.conv3 = conv_bn(mid_ch, out_ch, 1, 0)
        self.sn3 = eval(neuron_code)

    def forward(self, x):
        h = self.sn1(self.conv1(x))
        h = self.sn2(self.conv2(h))
        h = self.sn3(self.conv3(h))
        return h + x  # residual connection
69
+
70
+
71
class DeConvLayerSNN(nn.Module):
    """2x upsampling: ConvTranspose2d -> BatchNorm2d -> spiking neuron.

    NOTE(review): the `kernel_size` and `padding` parameters are accepted but
    ignored — the transposed conv is hard-coded to k=4, p=1, s=2 (exact 2x
    upsample); confirm this is intentional.
    """

    def __init__(self, in_ch, out_ch, kernel_size=3, padding=1, neuron_type='PLIF'):
        super().__init__()
        self.layer = nn.Sequential(
            nn.ConvTranspose2d(in_ch, out_ch, kernel_size=4, padding=1, stride=2, bias=False),
            nn.BatchNorm2d(out_ch),
            eval(get_neuron_code(neuron_type)),
        )

    def forward(self, x):
        return self.layer(x)
85
+
86
+
87
class PredHead2(nn.Module):
    """Prediction head: one padded 3x3 conv (p=4) followed by three 'valid'
    3x3 convs down to a single channel.

    The padding arithmetic cancels out (+6 spatial, then -2 three times), so
    the output spatial size equals the input's. NOTE(review): `out_ch`,
    `kernel_size` and `padding` parameters are ignored — the head always
    emits exactly 1 channel.
    """

    def __init__(self, in_ch, out_ch, kernel_size=3, padding=1):
        super().__init__()

        # Helper: 3x3 conv (stride 1, no bias) + batch norm.
        def conv_bn(ci, co, p):
            return nn.Sequential(
                nn.Conv2d(ci, co, kernel_size=3, padding=p, stride=1, bias=False),
                nn.BatchNorm2d(co),
            )

        self.conv1 = conv_bn(in_ch, 32, 4)
        self.conv2 = conv_bn(32, 32, 0)
        self.conv3 = conv_bn(32, 32, 0)
        self.conv4 = nn.Conv2d(32, 1, kernel_size=3, padding=0, stride=1, bias=False)

    def forward(self, x):
        return self.conv4(self.conv3(self.conv2(self.conv1(x))))
110
+
@@ -0,0 +1,61 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from .layers import *
5
+
6
class BasicModel(nn.Module):
    """Base network class providing parameter-grouping and init helpers."""

    def __init__(self):
        super().__init__()

    ####################################################################################
    ## Tools functions for neural networks
    def weight_parameters(self):
        """Parameters whose name contains 'weight' (conv/linear/BN weights)."""
        return [param for name, param in self.named_parameters() if 'weight' in name]

    def bias_parameters(self):
        """Parameters whose name contains 'bias'."""
        return [param for name, param in self.named_parameters() if 'bias' in name]

    def num_parameters(self):
        """Total element count across all trainable parameters."""
        return sum(p.data.nelement() if p.requires_grad else 0 for p in self.parameters())

    def init_weights(self):
        """Kaiming-initialize every (transposed) conv layer; zero its bias.

        Bug fix: the original iterated `self.named_modules()` and called
        `isinstance` on the yielded (name, module) tuples, so no module ever
        matched and the method silently initialized nothing. Iterating
        `self.modules()` restores the intended behavior.
        """
        for module in self.modules():
            if isinstance(module, (nn.Conv2d, nn.ConvTranspose2d)):
                nn.init.kaiming_normal_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
32
+
33
+
34
+ ########################################################################
35
class SSIR(BasicModel):
    """Spiking single-image reconstruction network.

    Pipeline: static conv on the 41-frame spike window -> stride-2 encoder
    with a bottleneck residual block -> 2x upsampling decoder -> prediction
    head emitting one channel at the input resolution.
    """

    def __init__(self):
        super().__init__()
        base_ch = 128

        self.static_conv = ConvLayerSNN(in_ch=41, out_ch=base_ch, stride=1)
        self.enc1 = ConvLayerSNN(in_ch=base_ch, out_ch=base_ch, stride=2)
        self.eres1 = BottleneckBlockSNN(in_ch=base_ch, out_ch=base_ch)
        self.dec3 = DeConvLayerSNN(in_ch=base_ch, out_ch=base_ch // 2)
        self.pred3 = PredHead2(in_ch=base_ch // 2, out_ch=1)

    def forward(self, x):
        # x: B x 41 x H x W spike window
        feat = self.static_conv(x)
        enc = self.eres1(self.enc1(feat))
        dec = self.dec3(enc)
        out = self.pred3(dec)

        # Training returns a list (multi-scale loss hook); eval returns the tensor.
        return [out] if self.training else out
@@ -0,0 +1,8 @@
1
+ numpy
2
+ torch
3
+ torchvision
4
+ spikingjelly==0.0.0.0.14
5
+ opencv-python
6
+ tensorboardX
7
+ lpips
8
+ pyyaml
@@ -0,0 +1,6 @@
1
+ cd ../ &&
2
+ python3 main.py \
3
+ --data-root your_data_root \
4
+ --arch SSIR \
5
+ --pretrained ./ckpt/SSIR_e80.pth \
6
+ --eval
@@ -0,0 +1,12 @@
1
+ cd ../ &&
2
+ python3 main.py \
3
+ --data-root your_data_root \
4
+ --arch SSIR \
5
+ --batch-size 8 \
6
+ --learning-rate 4e-4 \
7
+ --configs ./configs/SSIR.yml \
8
+ --epochs 80 \
9
+ --workers 8 \
10
+ --w_per 0.2 \
11
+ --milestones 20 25 30 35 40 45 50 55 65 70 \
12
+ --lr-scale-factor 0.7 -pf 1
@@ -0,0 +1,3 @@
1
# Smoke test: constructing the SSIR network must not raise.
from models.networks import SSIR

net = SSIR()
3
+