spikezoo 0.1.2__py3-none-any.whl → 0.2__py3-none-any.whl

Sign up to get free protection for your applications and to get access to all the features.
Files changed (192) hide show
  1. spikezoo/__init__.py +13 -0
  2. spikezoo/archs/__pycache__/__init__.cpython-39.pyc +0 -0
  3. spikezoo/archs/base/__pycache__/nets.cpython-39.pyc +0 -0
  4. spikezoo/archs/base/nets.py +34 -0
  5. spikezoo/archs/bsf/README.md +92 -0
  6. spikezoo/archs/bsf/datasets/datasets.py +328 -0
  7. spikezoo/archs/bsf/datasets/ds_utils.py +64 -0
  8. spikezoo/archs/bsf/main.py +398 -0
  9. spikezoo/archs/bsf/metrics/psnr.py +22 -0
  10. spikezoo/archs/bsf/metrics/ssim.py +54 -0
  11. spikezoo/archs/bsf/models/bsf/__pycache__/align.cpython-39.pyc +0 -0
  12. spikezoo/archs/bsf/models/bsf/__pycache__/bsf.cpython-39.pyc +0 -0
  13. spikezoo/archs/bsf/models/bsf/__pycache__/rep.cpython-39.pyc +0 -0
  14. spikezoo/archs/bsf/models/bsf/align.py +154 -0
  15. spikezoo/archs/bsf/models/bsf/bsf.py +105 -0
  16. spikezoo/archs/bsf/models/bsf/dsft_convert.py +96 -0
  17. spikezoo/archs/bsf/models/bsf/rep.py +44 -0
  18. spikezoo/archs/bsf/models/get_model.py +7 -0
  19. spikezoo/archs/bsf/prepare_data/DSFT.py +62 -0
  20. spikezoo/archs/bsf/prepare_data/crop_dataset_train.py +135 -0
  21. spikezoo/archs/bsf/prepare_data/crop_dataset_val.py +139 -0
  22. spikezoo/archs/bsf/prepare_data/crop_train.sh +4 -0
  23. spikezoo/archs/bsf/prepare_data/crop_val.sh +4 -0
  24. spikezoo/archs/bsf/prepare_data/io_utils.py +64 -0
  25. spikezoo/archs/bsf/requirements.txt +9 -0
  26. spikezoo/archs/bsf/test.py +16 -0
  27. spikezoo/archs/bsf/utils.py +154 -0
  28. spikezoo/archs/spikeclip/__pycache__/nets.cpython-39.pyc +0 -0
  29. spikezoo/archs/spikeclip/nets.py +40 -0
  30. spikezoo/archs/spikeformer/CheckPoints/readme +1 -0
  31. spikezoo/archs/spikeformer/DataProcess/DataExtactor.py +60 -0
  32. spikezoo/archs/spikeformer/DataProcess/DataLoader.py +115 -0
  33. spikezoo/archs/spikeformer/DataProcess/LoadSpike.py +39 -0
  34. spikezoo/archs/spikeformer/EvalResults/readme +1 -0
  35. spikezoo/archs/spikeformer/LICENSE +21 -0
  36. spikezoo/archs/spikeformer/Metrics/Metrics.py +50 -0
  37. spikezoo/archs/spikeformer/Metrics/__init__.py +0 -0
  38. spikezoo/archs/spikeformer/Model/Loss.py +89 -0
  39. spikezoo/archs/spikeformer/Model/SpikeFormer.py +230 -0
  40. spikezoo/archs/spikeformer/Model/__init__.py +0 -0
  41. spikezoo/archs/spikeformer/Model/__pycache__/SpikeFormer.cpython-39.pyc +0 -0
  42. spikezoo/archs/spikeformer/Model/__pycache__/__init__.cpython-39.pyc +0 -0
  43. spikezoo/archs/spikeformer/README.md +30 -0
  44. spikezoo/archs/spikeformer/evaluate.py +87 -0
  45. spikezoo/archs/spikeformer/recon_real_data.py +97 -0
  46. spikezoo/archs/spikeformer/requirements.yml +95 -0
  47. spikezoo/archs/spikeformer/train.py +173 -0
  48. spikezoo/archs/spikeformer/utils.py +22 -0
  49. spikezoo/archs/spk2imgnet/.github/workflows/pylint.yml +23 -0
  50. spikezoo/archs/spk2imgnet/.gitignore +150 -0
  51. spikezoo/archs/spk2imgnet/DCNv2.py +135 -0
  52. spikezoo/archs/spk2imgnet/__pycache__/DCNv2.cpython-39.pyc +0 -0
  53. spikezoo/archs/spk2imgnet/__pycache__/align_arch.cpython-39.pyc +0 -0
  54. spikezoo/archs/spk2imgnet/__pycache__/nets.cpython-39.pyc +0 -0
  55. spikezoo/archs/spk2imgnet/align_arch.py +159 -0
  56. spikezoo/archs/spk2imgnet/dataset.py +144 -0
  57. spikezoo/archs/spk2imgnet/nets.py +230 -0
  58. spikezoo/archs/spk2imgnet/readme.md +86 -0
  59. spikezoo/archs/spk2imgnet/test_gen_imgseq.py +118 -0
  60. spikezoo/archs/spk2imgnet/train.py +189 -0
  61. spikezoo/archs/spk2imgnet/utils.py +64 -0
  62. spikezoo/archs/ssir/README.md +87 -0
  63. spikezoo/archs/ssir/configs/SSIR.yml +37 -0
  64. spikezoo/archs/ssir/configs/yml_parser.py +78 -0
  65. spikezoo/archs/ssir/datasets/dataset_sreds.py +170 -0
  66. spikezoo/archs/ssir/datasets/ds_utils.py +66 -0
  67. spikezoo/archs/ssir/losses.py +21 -0
  68. spikezoo/archs/ssir/main.py +326 -0
  69. spikezoo/archs/ssir/metrics/psnr.py +22 -0
  70. spikezoo/archs/ssir/metrics/ssim.py +54 -0
  71. spikezoo/archs/ssir/models/Vgg19.py +42 -0
  72. spikezoo/archs/ssir/models/__pycache__/layers.cpython-39.pyc +0 -0
  73. spikezoo/archs/ssir/models/__pycache__/networks.cpython-39.pyc +0 -0
  74. spikezoo/archs/ssir/models/layers.py +110 -0
  75. spikezoo/archs/ssir/models/networks.py +61 -0
  76. spikezoo/archs/ssir/requirements.txt +8 -0
  77. spikezoo/archs/ssir/shells/eval_SREDS.sh +6 -0
  78. spikezoo/archs/ssir/shells/train_SSIR.sh +12 -0
  79. spikezoo/archs/ssir/test.py +3 -0
  80. spikezoo/archs/ssir/utils.py +154 -0
  81. spikezoo/archs/ssml/__pycache__/cbam.cpython-39.pyc +0 -0
  82. spikezoo/archs/ssml/__pycache__/model.cpython-39.pyc +0 -0
  83. spikezoo/archs/ssml/cbam.py +224 -0
  84. spikezoo/archs/ssml/model.py +290 -0
  85. spikezoo/archs/ssml/res.png +0 -0
  86. spikezoo/archs/ssml/test.py +67 -0
  87. spikezoo/archs/stir/.git-credentials +0 -0
  88. spikezoo/archs/stir/README.md +65 -0
  89. spikezoo/archs/stir/ckpt_outputs/Descriptions.txt +1 -0
  90. spikezoo/archs/stir/configs/STIR.yml +37 -0
  91. spikezoo/archs/stir/configs/utils.py +155 -0
  92. spikezoo/archs/stir/configs/yml_parser.py +78 -0
  93. spikezoo/archs/stir/datasets/dataset_sreds.py +180 -0
  94. spikezoo/archs/stir/datasets/ds_utils.py +66 -0
  95. spikezoo/archs/stir/eval_SREDS.sh +5 -0
  96. spikezoo/archs/stir/main.py +397 -0
  97. spikezoo/archs/stir/metrics/losses.py +219 -0
  98. spikezoo/archs/stir/metrics/psnr.py +22 -0
  99. spikezoo/archs/stir/metrics/ssim.py +54 -0
  100. spikezoo/archs/stir/models/Vgg19.py +42 -0
  101. spikezoo/archs/stir/models/__pycache__/networks_STIR.cpython-39.pyc +0 -0
  102. spikezoo/archs/stir/models/__pycache__/submodules.cpython-39.pyc +0 -0
  103. spikezoo/archs/stir/models/__pycache__/transformer_new.cpython-39.pyc +0 -0
  104. spikezoo/archs/stir/models/networks_STIR.py +361 -0
  105. spikezoo/archs/stir/models/submodules.py +86 -0
  106. spikezoo/archs/stir/models/transformer_new.py +151 -0
  107. spikezoo/archs/stir/package_core/build/lib/package_core/__init__.py +0 -0
  108. spikezoo/archs/stir/package_core/build/lib/package_core/convertions.py +721 -0
  109. spikezoo/archs/stir/package_core/build/lib/package_core/disp_netS.py +133 -0
  110. spikezoo/archs/stir/package_core/build/lib/package_core/flow_utils.py +167 -0
  111. spikezoo/archs/stir/package_core/build/lib/package_core/generic_train_test.py +76 -0
  112. spikezoo/archs/stir/package_core/build/lib/package_core/geometry.py +458 -0
  113. spikezoo/archs/stir/package_core/build/lib/package_core/image_proc.py +183 -0
  114. spikezoo/archs/stir/package_core/build/lib/package_core/linalg.py +40 -0
  115. spikezoo/archs/stir/package_core/build/lib/package_core/losses.py +198 -0
  116. spikezoo/archs/stir/package_core/build/lib/package_core/metrics.py +51 -0
  117. spikezoo/archs/stir/package_core/build/lib/package_core/model_base.py +53 -0
  118. spikezoo/archs/stir/package_core/build/lib/package_core/net_basics.py +100 -0
  119. spikezoo/archs/stir/package_core/build/lib/package_core/resnet.py +333 -0
  120. spikezoo/archs/stir/package_core/build/lib/package_core/transforms.py +123 -0
  121. spikezoo/archs/stir/package_core/build/lib/package_core/utils.py +72 -0
  122. spikezoo/archs/stir/package_core/dist/package_core-0.0.0-py3.9.egg +0 -0
  123. spikezoo/archs/stir/package_core/package_core/__init__.py +0 -0
  124. spikezoo/archs/stir/package_core/package_core/__pycache__/__init__.cpython-39.pyc +0 -0
  125. spikezoo/archs/stir/package_core/package_core/__pycache__/net_basics.cpython-39.pyc +0 -0
  126. spikezoo/archs/stir/package_core/package_core/convertions.py +721 -0
  127. spikezoo/archs/stir/package_core/package_core/disp_netS.py +133 -0
  128. spikezoo/archs/stir/package_core/package_core/flow_utils.py +167 -0
  129. spikezoo/archs/stir/package_core/package_core/generic_train_test.py +76 -0
  130. spikezoo/archs/stir/package_core/package_core/geometry.py +458 -0
  131. spikezoo/archs/stir/package_core/package_core/image_proc.py +183 -0
  132. spikezoo/archs/stir/package_core/package_core/linalg.py +40 -0
  133. spikezoo/archs/stir/package_core/package_core/losses.py +198 -0
  134. spikezoo/archs/stir/package_core/package_core/metrics.py +51 -0
  135. spikezoo/archs/stir/package_core/package_core/model_base.py +53 -0
  136. spikezoo/archs/stir/package_core/package_core/net_basics.py +100 -0
  137. spikezoo/archs/stir/package_core/package_core/resnet.py +333 -0
  138. spikezoo/archs/stir/package_core/package_core/transforms.py +123 -0
  139. spikezoo/archs/stir/package_core/package_core/utils.py +72 -0
  140. spikezoo/archs/stir/package_core/package_core.egg-info/PKG-INFO +3 -0
  141. spikezoo/archs/stir/package_core/package_core.egg-info/SOURCES.txt +20 -0
  142. spikezoo/archs/stir/package_core/package_core.egg-info/dependency_links.txt +1 -0
  143. spikezoo/archs/stir/package_core/package_core.egg-info/top_level.txt +1 -0
  144. spikezoo/archs/stir/package_core/setup.py +5 -0
  145. spikezoo/archs/stir/requirements.txt +12 -0
  146. spikezoo/archs/stir/train_STIR.sh +9 -0
  147. spikezoo/archs/tfi/__pycache__/nets.cpython-39.pyc +0 -0
  148. spikezoo/archs/tfi/nets.py +43 -0
  149. spikezoo/archs/tfp/__pycache__/nets.cpython-39.pyc +0 -0
  150. spikezoo/archs/tfp/nets.py +13 -0
  151. spikezoo/archs/wgse/README.md +64 -0
  152. spikezoo/archs/wgse/__pycache__/dwtnets.cpython-39.pyc +0 -0
  153. spikezoo/archs/wgse/__pycache__/submodules.cpython-39.pyc +0 -0
  154. spikezoo/archs/wgse/dataset.py +59 -0
  155. spikezoo/archs/wgse/demo.png +0 -0
  156. spikezoo/archs/wgse/demo.py +83 -0
  157. spikezoo/archs/wgse/dwtnets.py +145 -0
  158. spikezoo/archs/wgse/eval.py +133 -0
  159. spikezoo/archs/wgse/logs/WGSE-Dwt1dNet-db8-5-ks3/log.txt +11 -0
  160. spikezoo/archs/wgse/submodules.py +68 -0
  161. spikezoo/archs/wgse/train.py +261 -0
  162. spikezoo/archs/wgse/transform.py +139 -0
  163. spikezoo/archs/wgse/utils.py +128 -0
  164. spikezoo/archs/wgse/weights/demo.png +0 -0
  165. spikezoo/data/base/test/gt/200_part1_key_id151.png +0 -0
  166. spikezoo/data/base/test/gt/200_part3_key_id151.png +0 -0
  167. spikezoo/data/base/test/gt/203_part1_key_id151.png +0 -0
  168. spikezoo/data/base/test/spike/200_part1_key_id151.dat +0 -0
  169. spikezoo/data/base/test/spike/200_part3_key_id151.dat +0 -0
  170. spikezoo/data/base/test/spike/203_part1_key_id151.dat +0 -0
  171. spikezoo/data/base/train/gt/203_part2_key_id151.png +0 -0
  172. spikezoo/data/base/train/gt/203_part3_key_id151.png +0 -0
  173. spikezoo/data/base/train/gt/203_part4_key_id151.png +0 -0
  174. spikezoo/data/base/train/spike/203_part2_key_id151.dat +0 -0
  175. spikezoo/data/base/train/spike/203_part3_key_id151.dat +0 -0
  176. spikezoo/data/base/train/spike/203_part4_key_id151.dat +0 -0
  177. spikezoo/datasets/base_dataset.py +2 -3
  178. spikezoo/metrics/__init__.py +1 -1
  179. spikezoo/models/base_model.py +1 -3
  180. spikezoo/pipeline/base_pipeline.py +7 -5
  181. spikezoo/pipeline/train_pipeline.py +1 -1
  182. spikezoo/utils/other_utils.py +16 -6
  183. spikezoo/utils/spike_utils.py +33 -29
  184. spikezoo/utils/vidar_loader.cpython-39-x86_64-linux-gnu.so +0 -0
  185. spikezoo-0.2.dist-info/METADATA +163 -0
  186. spikezoo-0.2.dist-info/RECORD +211 -0
  187. spikezoo/models/spcsnet_model.py +0 -19
  188. spikezoo-0.1.2.dist-info/METADATA +0 -39
  189. spikezoo-0.1.2.dist-info/RECORD +0 -36
  190. {spikezoo-0.1.2.dist-info → spikezoo-0.2.dist-info}/LICENSE.txt +0 -0
  191. {spikezoo-0.1.2.dist-info → spikezoo-0.2.dist-info}/WHEEL +0 -0
  192. {spikezoo-0.1.2.dist-info → spikezoo-0.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,173 @@
1
+ import os
2
+ os.environ['CUDA_VISIBLE_DEVICES'] = "0"
3
+ import torch
4
+ from torch import optim
5
+ import numpy as np
6
+ from DataProcess import DataLoader as dl
7
+ from Model.SpikeFormer import SpikeFormer
8
+ from Metrics.Metrics import Metrics
9
+ from Model import Loss
10
+ from utils import SaveModel, LoadModel
11
+ from PIL import Image
12
+
13
def eval(model, validData, epoch, optimizer, metrics):
    """Evaluate `model` on the validation loader and checkpoint progress.

    Reconstructs images from spike tensors, compares them with ground truth
    at uint8 scale, saves the rolling 'latest' checkpoint every call and a
    'best' checkpoint when both PSNR and SSIM improve, and writes
    side-by-side prediction/GT previews to ``EvalResults/``.

    NOTE(review): shadows the builtin ``eval`` (name kept — Train() calls it
    by this name) and reads the module-level ``saveRoot`` defined in the
    __main__ block — confirm before importing this function elsewhere.
    """
    model.eval()
    print('Eval Epoch: %s' % (epoch))

    with torch.no_grad():
        pres = []
        gts = []
        for i, (spikes, gtImg) in enumerate(validData):
            spikes = spikes.cuda()
            gtImg = gtImg.cuda()
            predImg = model(spikes)
            predImg = predImg.squeeze(1)
            # Crop the 3-pixel vertical padding added by the network.
            predImg = predImg[:, 3:-3, :]

            predImg = predImg.clamp(min=-1., max=1.)
            predImg = predImg.detach().cpu().numpy()
            gtImg = gtImg.clamp(min=-1., max=1.)
            gtImg = gtImg.detach().cpu().numpy()

            # Map [-1, 1] -> [0, 255] uint8 for metric computation.
            predImg = (predImg + 1.) / 2. * 255.
            predImg = predImg.astype(np.uint8)

            gtImg = (gtImg + 1.) / 2. * 255.
            gtImg = gtImg.astype(np.uint8)

            pres.append(predImg)
            gts.append(gtImg)

        pres = np.concatenate(pres, axis=0)
        gts = np.concatenate(gts, axis=0)

        psnr = metrics.Cal_PSNR(pres, gts)
        ssim = metrics.Cal_SSIM(pres, gts)
        best_psnr, best_ssim, _ = metrics.GetBestMetrics()

        SaveModel(epoch, (psnr, ssim), model, optimizer, saveRoot)
        if psnr >= best_psnr and ssim >= best_ssim:
            metrics.Update(psnr, ssim)
            SaveModel(epoch, (psnr, ssim), model, optimizer, saveRoot, best=True)
            with open('eval_best_log.txt', 'w') as f:
                f.write('epoch: %s; psnr: %s, ssim: %s\n' % (epoch, psnr, ssim))

        # Batch size is unused; only H is needed for the divider strip.
        _, H, W = pres.shape
        divide_line = np.zeros((H, 4)).astype(np.uint8)
        # Fix: make sure the preview directory exists before writing images.
        os.makedirs('EvalResults', exist_ok=True)
        num = 0
        for pre, gt in zip(pres, gts):
            num += 1
            concatImg = np.concatenate([pre, divide_line, gt], axis=1)
            concatImg = Image.fromarray(concatImg)
            concatImg.save('EvalResults/valid_%s.jpg' % (num))

        print('*********************************************************')
        best_psnr, best_ssim, _ = metrics.GetBestMetrics()
        print('Eval Epoch: %s, PSNR: %s, SSIM: %s, Best_PSNR: %s, Best_SSIM: %s'
              % (epoch, psnr, ssim, best_psnr, best_ssim))

    model.train()
70
+
71
def Train(trainData, validData, model, optimizer, epoch, start_epoch, metrics, saveRoot, perIter):
    """Run the training loop, logging every `perIter` iterations and
    evaluating (with checkpointing) after every epoch.

    Args:
        trainData / validData: loaders yielding (spikes, gtImg) batches.
        model: SpikeFormer network, already on GPU.
        optimizer: optimizer stepped once per batch.
        epoch: exclusive upper bound on the epoch index.
        start_epoch: epoch index to resume from.
        metrics: PSNR/SSIM bookkeeping object shared with eval().
        saveRoot: checkpoint directory (consumed by eval via SaveModel).
        perIter: console-logging period in iterations.
    """
    avg_l2_loss = 0.
    avg_vgg_loss = 0.
    avg_edge_loss = 0.
    avg_total_loss = 0.
    l2loss = Loss.CharbonnierLoss()
    vggloss = Loss.VGGLoss4('vgg19-low-level4.pth').cuda()
    criterion_edge = Loss.EdgeLoss()
    # Loss weights.
    LAMBDA_L2 = 100.0
    LAMBDA_VGG = 1.0
    LAMBDA_EDGE = 5.0
    for i in range(start_epoch, epoch):
        # Fix: `batch_idx` instead of `iter`, which shadowed the builtin.
        for batch_idx, (spikes, gtImg) in enumerate(trainData):
            spikes = spikes.cuda()
            gtImg = gtImg.cuda()
            predImg = model(spikes)
            gtImg = gtImg.unsqueeze(1)
            # Crop the network's 3-pixel vertical padding.
            predImg = predImg[:, :, 3:-3, :]

            loss_vgg = vggloss(gtImg, predImg) * LAMBDA_VGG
            loss_l2 = l2loss(gtImg, predImg) * LAMBDA_L2
            loss_edge = criterion_edge(gtImg, predImg) * LAMBDA_EDGE

            totalLoss = loss_l2 + loss_vgg + loss_edge

            optimizer.zero_grad()
            totalLoss.backward()
            optimizer.step()

            # Accumulate plain floats rather than detached tensors so no
            # tensor objects are kept alive between logging windows.
            avg_l2_loss += loss_l2.item()
            avg_vgg_loss += loss_vgg.item()
            avg_edge_loss += loss_edge.item()
            avg_total_loss += totalLoss.item()
            if (batch_idx + 1) % perIter == 0:
                avg_l2_loss = avg_l2_loss / perIter
                avg_vgg_loss = avg_vgg_loss / perIter
                avg_edge_loss = avg_edge_loss / perIter
                avg_total_loss = avg_total_loss / perIter
                print('=============================================================')
                print('Epoch: %s, Iter: %s' % (i, batch_idx + 1))
                print('L2Loss: %s; VggLoss: %s; EdgeLoss: %s; TotalLoss: %s' % (
                    avg_l2_loss, avg_vgg_loss, avg_edge_loss, avg_total_loss))
                avg_l2_loss = 0.
                avg_vgg_loss = 0.
                avg_edge_loss = 0.
                avg_total_loss = 0.

        # Evaluate (and checkpoint) after every epoch.
        if (i + 1) % 1 == 0:
            eval(model, validData, i, optimizer, metrics)
120
+
121
if __name__ == "__main__":

    # ----- experiment configuration -----
    dataPath = "/home/storage2/shechen/Spike_Sample_250x400"
    spikeRadius = 32  # half length of input spike sequence except for the middle frame
    spikeLen = 2 * spikeRadius + 1  # length of input spike sequence
    batchSize = 2
    epoch = 200
    start_epoch = 0
    lr = 2e-4
    saveRoot = "CheckPoints/"  # path to save the trained model
    perIter = 20  # console-logging period in iterations

    # ----- optional checkpoint resumption -----
    reuse = False
    reuseType = 'latest'  # 'latest' or 'best'
    # Fix: derive the checkpoint path from saveRoot instead of repeating the
    # hard-coded 'CheckPoints' directory name (same resulting path).
    checkPath = os.path.join(saveRoot, '%s.pth' % (reuseType))

    # ----- data -----
    trainContainer = dl.DataContainer(dataPath=dataPath, dataType='train',
                                      spikeRadius=spikeRadius,
                                      batchSize=batchSize)
    trainData = trainContainer.GetLoader()

    validContainer = dl.DataContainer(dataPath=dataPath, dataType='valid',
                                      spikeRadius=spikeRadius,
                                      batchSize=batchSize)
    validData = validContainer.GetLoader()

    metrics = Metrics()

    # ----- model & optimizer -----
    model = SpikeFormer(
        inputDim=spikeLen,
        dims=(32, 64, 160, 256),       # dimensions of each stage
        heads=(1, 2, 5, 8),            # heads of each stage
        ff_expansion=(8, 8, 4, 4),     # feedforward expansion factor of each stage
        reduction_ratio=(8, 4, 2, 1),  # reduction ratio of each stage for efficient attention
        num_layers=2,                  # num layers of each stage
        decoder_dim=256,               # decoder dimension
        out_channel=1                  # channel of restored image
    ).cuda()

    optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999), amsgrad=False)

    if reuse:
        preEpoch, prePerformance, modelDict, optDict = LoadModel(checkPath, model, optimizer)
        start_epoch = preEpoch + 1
        psnr, ssim = prePerformance[0], prePerformance[1]
        metrics.Update(psnr, ssim)
        # Restart the schedule at the configured learning rate.
        for para in optimizer.param_groups:
            para['lr'] = lr

    model.train()

    Train(trainData, validData, model, optimizer, epoch, start_epoch,
          metrics, saveRoot, perIter)
@@ -0,0 +1,22 @@
1
+ import os
2
+ import torch
3
+
4
def SaveModel(epoch, bestPerformance, model, optimizer, saveRoot, best=False):
    """Serialize a training checkpoint to `saveRoot`.

    Args:
        epoch: epoch index the checkpoint was produced at.
        bestPerformance: (psnr, ssim) tuple recorded with the checkpoint.
        model: network whose state_dict is saved.
        optimizer: optimizer whose state_dict is saved.
        saveRoot: directory the checkpoint file is written into.
        best: write 'best.pth' instead of the rolling 'latest.pth'.
    """
    saveDict = {
        'pre_epoch': epoch,
        'performance': bestPerformance,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict()
    }
    # Fix: create the target directory on first use so saving never fails.
    os.makedirs(saveRoot, exist_ok=True)
    savePath = os.path.join(saveRoot, '%s.pth' % ('best' if best else 'latest'))
    torch.save(saveDict, savePath)
13
+
14
def LoadModel(checkPath, model, optimizer=None):
    """Restore model (and optionally optimizer) state from a checkpoint.

    Args:
        checkPath: path to a .pth file written by SaveModel.
        model: network to load the saved state_dict into.
        optimizer: if given, its state is restored as well.

    Returns:
        (pre_epoch, performance, model_state_dict, optimizer_state_dict)
    """
    # Fix: map_location lets GPU-trained checkpoints load on CPU-only hosts;
    # load_state_dict then copies tensors onto the parameters' own devices.
    stateDict = torch.load(checkPath, map_location='cpu')
    pre_epoch = stateDict['pre_epoch']
    model.load_state_dict(stateDict['model_state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(stateDict['optimizer_state_dict'])

    return pre_epoch, stateDict['performance'], \
        stateDict['model_state_dict'], stateDict['optimizer_state_dict']
@@ -0,0 +1,23 @@
1
# Lint every pushed commit with pylint across the supported Python versions.
name: Pylint

on: [push]

jobs:
  build:
    runs-on: ubuntu-latest
    strategy:
      matrix:
        python-version: ["3.8", "3.9", "3.10"]
    steps:
      - uses: actions/checkout@v3
      - name: Set up Python ${{ matrix.python-version }}
        uses: actions/setup-python@v3
        with:
          python-version: ${{ matrix.python-version }}
      - name: Install dependencies
        run: |
          python -m pip install --upgrade pip
          pip install pylint
      - name: Analysing the code with pylint
        run: |
          pylint $(git ls-files '*.py')
@@ -0,0 +1,150 @@
1
+ ### Python template
2
+ # Byte-compiled / optimized / DLL files
3
+ __pycache__/
4
+ *.py[cod]
5
+ *$py.class
6
+ .idea/
7
+
8
+ # C extensions
9
+ *.so
10
+
11
+ # Distribution / packaging
12
+ .Python
13
+ build/
14
+ develop-eggs/
15
+ dist/
16
+ downloads/
17
+ eggs/
18
+ .eggs/
19
+ lib/
20
+ lib64/
21
+ parts/
22
+ sdist/
23
+ var/
24
+ wheels/
25
+ share/python-wheels/
26
+ results/
27
+ ckpt/
28
+ ckpt2/
29
+ old_ckpt/
30
+ Spk2ImgNet_test2/
31
+ Spk2ImgNet_train/
32
+ *.zip
33
+ *.pth
34
+ *.h5
35
+ *.egg-info/
36
+ .installed.cfg
37
+ *.egg
38
+ MANIFEST
39
+
40
+ # PyInstaller
41
+ # Usually these files are written by a python script from a template
42
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
43
+ *.manifest
44
+ *.spec
45
+
46
+ # Installer logs
47
+ pip-log.txt
48
+ pip-delete-this-directory.txt
49
+
50
+ # Unit test / coverage reports
51
+ htmlcov/
52
+ .tox/
53
+ .nox/
54
+ .coverage
55
+ .coverage.*
56
+ .cache
57
+ nosetests.xml
58
+ coverage.xml
59
+ *.cover
60
+ *.py,cover
61
+ .hypothesis/
62
+ .pytest_cache/
63
+ cover/
64
+
65
+ # Translations
66
+ *.mo
67
+ *.pot
68
+
69
+ # Django stuff:
70
+ *.log
71
+ local_settings.py
72
+ db.sqlite3
73
+ db.sqlite3-journal
74
+
75
+ # Flask stuff:
76
+ instance/
77
+ .webassets-cache
78
+
79
+ # Scrapy stuff:
80
+ .scrapy
81
+
82
+ # Sphinx documentation
83
+ docs/_build/
84
+
85
+ # PyBuilder
86
+ .pybuilder/
87
+ target/
88
+
89
+ # Jupyter Notebook
90
+ .ipynb_checkpoints
91
+
92
+ # IPython
93
+ profile_default/
94
+ ipython_config.py
95
+
96
+ # pyenv
97
+ # For a library or package, you might want to ignore these files since the code is
98
+ # intended to run in multiple environments; otherwise, check them in:
99
+ # .python-version
100
+
101
+ # pipenv
102
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
103
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
104
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
105
+ # install all needed dependencies.
106
+ #Pipfile.lock
107
+
108
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
109
+ __pypackages__/
110
+
111
+ # Celery stuff
112
+ celerybeat-schedule
113
+ celerybeat.pid
114
+
115
+ # SageMath parsed files
116
+ *.sage.py
117
+
118
+ # Environments
119
+ .env
120
+ .venv
121
+ env/
122
+ venv/
123
+ ENV/
124
+ env.bak/
125
+ venv.bak/
126
+
127
+ # Spyder project settings
128
+ .spyderproject
129
+ .spyproject
130
+
131
+ # Rope project settings
132
+ .ropeproject
133
+
134
+ # mkdocs documentation
135
+ /site
136
+
137
+ # mypy
138
+ .mypy_cache/
139
+ .dmypy.json
140
+ dmypy.json
141
+
142
+ # Pyre type checker
143
+ .pyre/
144
+
145
+ # pytype static type analyzer
146
+ .pytype/
147
+
148
+ # Cython debug symbols
149
+ cython_debug/
150
+
@@ -0,0 +1,135 @@
1
+ #!/usr/bin/env python
2
+
3
+ import math
4
+ import logging
5
+ import torch
6
+ from PIL.Image import logger
7
+ from torch import nn
8
+ import torchvision
9
+ from torch.nn.modules.utils import _pair
10
+
11
+
12
class DCNv2(nn.Module):
    """Modulated deformable convolution (DCNv2) backed by torchvision.

    Holds the convolution weight/bias; the caller supplies per-position
    sampling offsets and modulation masks in `forward`.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation=1,
                 deformable_groups=1):
        super(DCNv2, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _pair(padding)
        self.dilation = _pair(dilation)
        self.deformable_groups = deformable_groups

        self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels, *self.kernel_size))
        self.bias = nn.Parameter(torch.Tensor(out_channels))
        self.reset_parameters()

    def reset_parameters(self):
        """Uniform init scaled by fan-in, matching classic DCN implementations."""
        n = self.in_channels
        for k in self.kernel_size:
            n *= k
        stdv = 1. / math.sqrt(n)
        self.weight.data.uniform_(-stdv, stdv)
        self.bias.data.zero_()

    def forward(self, input, offset, mask):
        # offset carries an (x, y) pair per sampling location, hence the factor 2.
        assert 2 * self.deformable_groups * self.kernel_size[0] * self.kernel_size[1] == \
            offset.shape[1]
        assert self.deformable_groups * self.kernel_size[0] * self.kernel_size[1] == \
            mask.shape[1]

        # Fix: torchvision.ops.deform_conv2d has no `groups` keyword — the
        # deformable-group count is inferred from offset.shape[1]; passing
        # `groups=self.deformable_groups` raised TypeError. The sibling DCN
        # and DCN_sep classes already call it without that argument.
        return torchvision.ops.deform_conv2d(
            input=input,
            offset=offset,
            mask=mask,
            weight=self.weight,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
        )
54
+
55
+
56
+
57
+
58
class DCN(DCNv2):
    """Deformable conv that derives offsets and masks from its own input."""

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation=1,
                 deformable_groups=1):
        super(DCN, self).__init__(in_channels, out_channels, kernel_size, stride, padding,
                                  dilation, deformable_groups)

        # One conv jointly predicts x-offsets, y-offsets and masks (3 chunks).
        channels_ = self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1]
        self.conv_offset_mask = nn.Conv2d(self.in_channels, channels_,
                                          kernel_size=self.kernel_size,
                                          stride=self.stride,
                                          padding=self.padding,
                                          bias=True)
        self.init_offset()

    def init_offset(self):
        # Zero init: the layer starts out behaving like a plain convolution.
        self.conv_offset_mask.weight.data.zero_()
        self.conv_offset_mask.bias.data.zero_()

    def forward(self, input):
        o1, o2, mask = torch.chunk(self.conv_offset_mask(input), 3, dim=1)
        offset = torch.cat((o1, o2), dim=1)
        mask = torch.sigmoid(mask)

        return torchvision.ops.deform_conv2d(
            input=input,
            offset=offset,
            mask=mask,
            weight=self.weight,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
        )
92
+
93
+
94
class DCN_sep(DCNv2):
    '''Use other features to generate offsets and masks.'''

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation=1,
                 deformable_groups=1):
        super(DCN_sep, self).__init__(in_channels, out_channels, kernel_size, stride, padding,
                                      dilation, deformable_groups)

        # One conv jointly predicts x-offsets, y-offsets and masks (3 chunks).
        channels_ = self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1]
        self.conv_offset_mask = nn.Conv2d(self.in_channels, channels_, kernel_size=self.kernel_size,
                                          stride=self.stride, padding=self.padding, bias=True)
        self.init_offset()

    def init_offset(self):
        # Zero init: the layer starts out behaving like a plain convolution.
        self.conv_offset_mask.weight.data.zero_()
        self.conv_offset_mask.bias.data.zero_()

    def forward(self, input, fea):
        '''input: input features for deformable conv
        fea: other features used for generating offsets and mask'''
        out = self.conv_offset_mask(fea)
        o1, o2, mask = torch.chunk(out, 3, dim=1)
        offset = torch.cat((o1, o2), dim=1)

        # Large mean offsets usually signal training divergence.
        # Fix: use this module's own logger instead of the logger imported
        # from PIL.Image, which wrote warnings under PIL's logging namespace.
        offset_mean = torch.mean(torch.abs(offset))
        if offset_mean > 100:
            logging.getLogger(__name__).warning(
                'Offset mean is {}, larger than 100.'.format(offset_mean))

        mask = torch.sigmoid(mask)

        return torchvision.ops.deform_conv2d(
            input=input,
            offset=offset,
            mask=mask,
            weight=self.weight,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
        )
135
+
@@ -0,0 +1,159 @@
1
+ """ network architecture for Sakuya """
2
+ import torch
3
+ from DCNv2 import *
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torchvision.ops import DeformConv2d
7
+
8
+
9
+
10
+
11
class PCDAlign(nn.Module):
    """Alignment module using Pyramid, Cascading and Deformable convolution
    with 3 pyramid levels.

    Offsets estimated at 1/4 resolution guide the 1/2-resolution estimate,
    which in turn guides full resolution; a final cascading DCN refines the
    aligned full-resolution features against the reference frame.
    """

    def __init__(self, nf=64, groups=8):
        super(PCDAlign, self).__init__()

        def conv3x3(in_ch, out_ch):
            # 3x3 stride-1 padded conv used throughout the pyramid.
            return nn.Conv2d(in_ch, out_ch, 3, 1, 1, bias=True)

        def dcn():
            return DCN_sep(nf, nf, 3, stride=1, padding=1, dilation=1,
                           deformable_groups=groups)

        # L3: level 3, 1/4 spatial size
        self.L3_offset_conv1_1 = conv3x3(nf * 2, nf)  # concat for diff
        self.L3_offset_conv2_1 = conv3x3(nf, nf)
        self.L3_dcnpack_1 = dcn()
        # L2: level 2, 1/2 spatial size
        self.L2_offset_conv1_1 = conv3x3(nf * 2, nf)  # concat for diff
        self.L2_offset_conv2_1 = conv3x3(nf * 2, nf)  # concat for offset
        self.L2_offset_conv3_1 = conv3x3(nf, nf)
        self.L2_dcnpack_1 = dcn()
        self.L2_fea_conv_1 = conv3x3(nf * 2, nf)  # concat for fea
        # L1: level 1, original spatial size
        self.L1_offset_conv1_1 = conv3x3(nf * 2, nf)  # concat for diff
        self.L1_offset_conv2_1 = conv3x3(nf * 2, nf)  # concat for offset
        self.L1_offset_conv3_1 = conv3x3(nf, nf)
        self.L1_dcnpack_1 = dcn()
        self.L1_fea_conv_1 = conv3x3(nf * 2, nf)  # concat for fea

        # Cascading DCN refinement at full resolution.
        self.cas_offset_conv1 = conv3x3(nf * 2, nf)  # concat for diff
        self.cas_offset_conv2 = conv3x3(nf, nf)
        self.cas_dcnpack = dcn()

        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, fea1, fea2):
        """Align the neighboring frame's features to the reference frame.

        fea1, fea2: [L1, L2, L3] pyramids, each level a [B, C, H, W] tensor;
        fea1 is the neighboring frame, fea2 the reference (key) frame.
        Returns the aligned full-resolution (L1) feature map.
        """
        def up2x(t):
            return F.interpolate(t, scale_factor=2, mode="bilinear",
                                 align_corners=False)

        # L3 (coarsest): estimate offsets and align.
        L3_offset = self.lrelu(self.L3_offset_conv1_1(torch.cat([fea1[2], fea2[2]], dim=1)))
        L3_offset = self.lrelu(self.L3_offset_conv2_1(L3_offset))
        L3_fea = self.lrelu(self.L3_dcnpack_1(fea1[2], L3_offset))

        # L2: fuse with upsampled L3 offsets (scaled x2 to match resolution).
        L2_offset = self.lrelu(self.L2_offset_conv1_1(torch.cat([fea1[1], fea2[1]], dim=1)))
        L3_offset = up2x(L3_offset)
        L2_offset = self.lrelu(
            self.L2_offset_conv2_1(torch.cat([L2_offset, L3_offset * 2], dim=1)))
        L2_offset = self.lrelu(self.L2_offset_conv3_1(L2_offset))
        L2_fea = self.L2_dcnpack_1(fea1[1], L2_offset)
        L3_fea = up2x(L3_fea)
        L2_fea = self.lrelu(self.L2_fea_conv_1(torch.cat([L2_fea, L3_fea], dim=1)))

        # L1: fuse with upsampled L2 offsets.
        L1_offset = self.lrelu(self.L1_offset_conv1_1(torch.cat([fea1[0], fea2[0]], dim=1)))
        L2_offset = up2x(L2_offset)
        L1_offset = self.lrelu(
            self.L1_offset_conv2_1(torch.cat([L1_offset, L2_offset * 2], dim=1)))
        L1_offset = self.lrelu(self.L1_offset_conv3_1(L1_offset))
        L1_fea = self.L1_dcnpack_1(fea1[0], L1_offset)
        L2_fea = up2x(L2_fea)
        L1_fea = self.L1_fea_conv_1(torch.cat([L1_fea, L2_fea], dim=1))

        # Cascading DCN against the reference frame's L1 features.
        offset = torch.cat([L1_fea, fea2[0]], dim=1)
        offset = self.lrelu(self.cas_offset_conv1(offset))
        offset = self.lrelu(self.cas_offset_conv2(offset))
        L1_fea = self.lrelu(self.cas_dcnpack(L1_fea, offset))

        return L1_fea
115
+
116
+
117
class Easy_PCD(nn.Module):
    """Build a 3-level feature pyramid for a frame pair and PCD-align them."""

    def __init__(self, nf=64, groups=8):
        super(Easy_PCD, self).__init__()

        # Strided convs build the 1/2 (L2) and 1/4 (L3) pyramid levels.
        self.fea_L2_conv1 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
        self.fea_L2_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.fea_L3_conv1 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
        self.fea_L3_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.pcd_align = PCDAlign(nf=nf, groups=groups)
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, f1, f2):
        """f1: features of the neighboring frame; f2: features of the key
        (reference) frame. Both [B, C, H, W]. Returns f1 aligned to f2."""
        # Batch the two frames together so each pyramid conv runs only once.
        stacked = torch.stack([f1, f2], dim=1)  # [B, 2, C, H, W]
        B, N, C, H, W = stacked.size()
        L1_fea = stacked.view(-1, C, H, W)

        # L2 then L3 via successive stride-2 convs.
        L2_fea = self.lrelu(self.fea_L2_conv2(self.lrelu(self.fea_L2_conv1(L1_fea))))
        L3_fea = self.lrelu(self.fea_L3_conv2(self.lrelu(self.fea_L3_conv1(L2_fea))))

        # Un-batch back into per-frame pyramids.
        L1_fea = L1_fea.view(B, N, -1, H, W)
        L2_fea = L2_fea.view(B, N, -1, H // 2, W // 2)
        L3_fea = L3_fea.view(B, N, -1, H // 4, W // 4)

        fea1 = [
            L1_fea[:, 0, :, :, :].clone(),
            L2_fea[:, 0, :, :, :].clone(),
            L3_fea[:, 0, :, :, :].clone(),
        ]
        fea2 = [
            L1_fea[:, 1, :, :, :].clone(),
            L2_fea[:, 1, :, :, :].clone(),
            L3_fea[:, 1, :, :, :].clone(),
        ]
        return self.pcd_align(fea1, fea2)