spikezoo 0.1.1__py3-none-any.whl → 0.2__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- spikezoo/__init__.py +13 -0
- spikezoo/archs/__pycache__/__init__.cpython-39.pyc +0 -0
- spikezoo/archs/base/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/base/nets.py +34 -0
- spikezoo/archs/bsf/README.md +92 -0
- spikezoo/archs/bsf/datasets/datasets.py +328 -0
- spikezoo/archs/bsf/datasets/ds_utils.py +64 -0
- spikezoo/archs/bsf/main.py +398 -0
- spikezoo/archs/bsf/metrics/psnr.py +22 -0
- spikezoo/archs/bsf/metrics/ssim.py +54 -0
- spikezoo/archs/bsf/models/bsf/__pycache__/align.cpython-39.pyc +0 -0
- spikezoo/archs/bsf/models/bsf/__pycache__/bsf.cpython-39.pyc +0 -0
- spikezoo/archs/bsf/models/bsf/__pycache__/rep.cpython-39.pyc +0 -0
- spikezoo/archs/bsf/models/bsf/align.py +154 -0
- spikezoo/archs/bsf/models/bsf/bsf.py +105 -0
- spikezoo/archs/bsf/models/bsf/dsft_convert.py +96 -0
- spikezoo/archs/bsf/models/bsf/rep.py +44 -0
- spikezoo/archs/bsf/models/get_model.py +7 -0
- spikezoo/archs/bsf/prepare_data/DSFT.py +62 -0
- spikezoo/archs/bsf/prepare_data/crop_dataset_train.py +135 -0
- spikezoo/archs/bsf/prepare_data/crop_dataset_val.py +139 -0
- spikezoo/archs/bsf/prepare_data/crop_train.sh +4 -0
- spikezoo/archs/bsf/prepare_data/crop_val.sh +4 -0
- spikezoo/archs/bsf/prepare_data/io_utils.py +64 -0
- spikezoo/archs/bsf/requirements.txt +9 -0
- spikezoo/archs/bsf/test.py +16 -0
- spikezoo/archs/bsf/utils.py +154 -0
- spikezoo/archs/spikeclip/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/spikeclip/nets.py +40 -0
- spikezoo/archs/spikeformer/CheckPoints/readme +1 -0
- spikezoo/archs/spikeformer/DataProcess/DataExtactor.py +60 -0
- spikezoo/archs/spikeformer/DataProcess/DataLoader.py +115 -0
- spikezoo/archs/spikeformer/DataProcess/LoadSpike.py +39 -0
- spikezoo/archs/spikeformer/EvalResults/readme +1 -0
- spikezoo/archs/spikeformer/LICENSE +21 -0
- spikezoo/archs/spikeformer/Metrics/Metrics.py +50 -0
- spikezoo/archs/spikeformer/Metrics/__init__.py +0 -0
- spikezoo/archs/spikeformer/Model/Loss.py +89 -0
- spikezoo/archs/spikeformer/Model/SpikeFormer.py +230 -0
- spikezoo/archs/spikeformer/Model/__init__.py +0 -0
- spikezoo/archs/spikeformer/Model/__pycache__/SpikeFormer.cpython-39.pyc +0 -0
- spikezoo/archs/spikeformer/Model/__pycache__/__init__.cpython-39.pyc +0 -0
- spikezoo/archs/spikeformer/README.md +30 -0
- spikezoo/archs/spikeformer/evaluate.py +87 -0
- spikezoo/archs/spikeformer/recon_real_data.py +97 -0
- spikezoo/archs/spikeformer/requirements.yml +95 -0
- spikezoo/archs/spikeformer/train.py +173 -0
- spikezoo/archs/spikeformer/utils.py +22 -0
- spikezoo/archs/spk2imgnet/.github/workflows/pylint.yml +23 -0
- spikezoo/archs/spk2imgnet/.gitignore +150 -0
- spikezoo/archs/spk2imgnet/DCNv2.py +135 -0
- spikezoo/archs/spk2imgnet/__pycache__/DCNv2.cpython-39.pyc +0 -0
- spikezoo/archs/spk2imgnet/__pycache__/align_arch.cpython-39.pyc +0 -0
- spikezoo/archs/spk2imgnet/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/spk2imgnet/align_arch.py +159 -0
- spikezoo/archs/spk2imgnet/dataset.py +144 -0
- spikezoo/archs/spk2imgnet/nets.py +230 -0
- spikezoo/archs/spk2imgnet/readme.md +86 -0
- spikezoo/archs/spk2imgnet/test_gen_imgseq.py +118 -0
- spikezoo/archs/spk2imgnet/train.py +189 -0
- spikezoo/archs/spk2imgnet/utils.py +64 -0
- spikezoo/archs/ssir/README.md +87 -0
- spikezoo/archs/ssir/configs/SSIR.yml +37 -0
- spikezoo/archs/ssir/configs/yml_parser.py +78 -0
- spikezoo/archs/ssir/datasets/dataset_sreds.py +170 -0
- spikezoo/archs/ssir/datasets/ds_utils.py +66 -0
- spikezoo/archs/ssir/losses.py +21 -0
- spikezoo/archs/ssir/main.py +326 -0
- spikezoo/archs/ssir/metrics/psnr.py +22 -0
- spikezoo/archs/ssir/metrics/ssim.py +54 -0
- spikezoo/archs/ssir/models/Vgg19.py +42 -0
- spikezoo/archs/ssir/models/__pycache__/layers.cpython-39.pyc +0 -0
- spikezoo/archs/ssir/models/__pycache__/networks.cpython-39.pyc +0 -0
- spikezoo/archs/ssir/models/layers.py +110 -0
- spikezoo/archs/ssir/models/networks.py +61 -0
- spikezoo/archs/ssir/requirements.txt +8 -0
- spikezoo/archs/ssir/shells/eval_SREDS.sh +6 -0
- spikezoo/archs/ssir/shells/train_SSIR.sh +12 -0
- spikezoo/archs/ssir/test.py +3 -0
- spikezoo/archs/ssir/utils.py +154 -0
- spikezoo/archs/ssml/__pycache__/cbam.cpython-39.pyc +0 -0
- spikezoo/archs/ssml/__pycache__/model.cpython-39.pyc +0 -0
- spikezoo/archs/ssml/cbam.py +224 -0
- spikezoo/archs/ssml/model.py +290 -0
- spikezoo/archs/ssml/res.png +0 -0
- spikezoo/archs/ssml/test.py +67 -0
- spikezoo/archs/stir/.git-credentials +0 -0
- spikezoo/archs/stir/README.md +65 -0
- spikezoo/archs/stir/ckpt_outputs/Descriptions.txt +1 -0
- spikezoo/archs/stir/configs/STIR.yml +37 -0
- spikezoo/archs/stir/configs/utils.py +155 -0
- spikezoo/archs/stir/configs/yml_parser.py +78 -0
- spikezoo/archs/stir/datasets/dataset_sreds.py +180 -0
- spikezoo/archs/stir/datasets/ds_utils.py +66 -0
- spikezoo/archs/stir/eval_SREDS.sh +5 -0
- spikezoo/archs/stir/main.py +397 -0
- spikezoo/archs/stir/metrics/losses.py +219 -0
- spikezoo/archs/stir/metrics/psnr.py +22 -0
- spikezoo/archs/stir/metrics/ssim.py +54 -0
- spikezoo/archs/stir/models/Vgg19.py +42 -0
- spikezoo/archs/stir/models/__pycache__/networks_STIR.cpython-39.pyc +0 -0
- spikezoo/archs/stir/models/__pycache__/submodules.cpython-39.pyc +0 -0
- spikezoo/archs/stir/models/__pycache__/transformer_new.cpython-39.pyc +0 -0
- spikezoo/archs/stir/models/networks_STIR.py +361 -0
- spikezoo/archs/stir/models/submodules.py +86 -0
- spikezoo/archs/stir/models/transformer_new.py +151 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/__init__.py +0 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/convertions.py +721 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/disp_netS.py +133 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/flow_utils.py +167 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/generic_train_test.py +76 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/geometry.py +458 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/image_proc.py +183 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/linalg.py +40 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/losses.py +198 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/metrics.py +51 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/model_base.py +53 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/net_basics.py +100 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/resnet.py +333 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/transforms.py +123 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/utils.py +72 -0
- spikezoo/archs/stir/package_core/dist/package_core-0.0.0-py3.9.egg +0 -0
- spikezoo/archs/stir/package_core/package_core/__init__.py +0 -0
- spikezoo/archs/stir/package_core/package_core/__pycache__/__init__.cpython-39.pyc +0 -0
- spikezoo/archs/stir/package_core/package_core/__pycache__/net_basics.cpython-39.pyc +0 -0
- spikezoo/archs/stir/package_core/package_core/convertions.py +721 -0
- spikezoo/archs/stir/package_core/package_core/disp_netS.py +133 -0
- spikezoo/archs/stir/package_core/package_core/flow_utils.py +167 -0
- spikezoo/archs/stir/package_core/package_core/generic_train_test.py +76 -0
- spikezoo/archs/stir/package_core/package_core/geometry.py +458 -0
- spikezoo/archs/stir/package_core/package_core/image_proc.py +183 -0
- spikezoo/archs/stir/package_core/package_core/linalg.py +40 -0
- spikezoo/archs/stir/package_core/package_core/losses.py +198 -0
- spikezoo/archs/stir/package_core/package_core/metrics.py +51 -0
- spikezoo/archs/stir/package_core/package_core/model_base.py +53 -0
- spikezoo/archs/stir/package_core/package_core/net_basics.py +100 -0
- spikezoo/archs/stir/package_core/package_core/resnet.py +333 -0
- spikezoo/archs/stir/package_core/package_core/transforms.py +123 -0
- spikezoo/archs/stir/package_core/package_core/utils.py +72 -0
- spikezoo/archs/stir/package_core/package_core.egg-info/PKG-INFO +3 -0
- spikezoo/archs/stir/package_core/package_core.egg-info/SOURCES.txt +20 -0
- spikezoo/archs/stir/package_core/package_core.egg-info/dependency_links.txt +1 -0
- spikezoo/archs/stir/package_core/package_core.egg-info/top_level.txt +1 -0
- spikezoo/archs/stir/package_core/setup.py +5 -0
- spikezoo/archs/stir/requirements.txt +12 -0
- spikezoo/archs/stir/train_STIR.sh +9 -0
- spikezoo/archs/tfi/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/tfi/nets.py +43 -0
- spikezoo/archs/tfp/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/tfp/nets.py +13 -0
- spikezoo/archs/wgse/README.md +64 -0
- spikezoo/archs/wgse/__pycache__/dwtnets.cpython-39.pyc +0 -0
- spikezoo/archs/wgse/__pycache__/submodules.cpython-39.pyc +0 -0
- spikezoo/archs/wgse/dataset.py +59 -0
- spikezoo/archs/wgse/demo.png +0 -0
- spikezoo/archs/wgse/demo.py +83 -0
- spikezoo/archs/wgse/dwtnets.py +145 -0
- spikezoo/archs/wgse/eval.py +133 -0
- spikezoo/archs/wgse/logs/WGSE-Dwt1dNet-db8-5-ks3/log.txt +11 -0
- spikezoo/archs/wgse/submodules.py +68 -0
- spikezoo/archs/wgse/train.py +261 -0
- spikezoo/archs/wgse/transform.py +139 -0
- spikezoo/archs/wgse/utils.py +128 -0
- spikezoo/archs/wgse/weights/demo.png +0 -0
- spikezoo/data/base/test/gt/200_part1_key_id151.png +0 -0
- spikezoo/data/base/test/gt/200_part3_key_id151.png +0 -0
- spikezoo/data/base/test/gt/203_part1_key_id151.png +0 -0
- spikezoo/data/base/test/spike/200_part1_key_id151.dat +0 -0
- spikezoo/data/base/test/spike/200_part3_key_id151.dat +0 -0
- spikezoo/data/base/test/spike/203_part1_key_id151.dat +0 -0
- spikezoo/data/base/train/gt/203_part2_key_id151.png +0 -0
- spikezoo/data/base/train/gt/203_part3_key_id151.png +0 -0
- spikezoo/data/base/train/gt/203_part4_key_id151.png +0 -0
- spikezoo/data/base/train/spike/203_part2_key_id151.dat +0 -0
- spikezoo/data/base/train/spike/203_part3_key_id151.dat +0 -0
- spikezoo/data/base/train/spike/203_part4_key_id151.dat +0 -0
- spikezoo/datasets/base_dataset.py +2 -3
- spikezoo/metrics/__init__.py +1 -1
- spikezoo/models/base_model.py +1 -3
- spikezoo/pipeline/base_pipeline.py +7 -5
- spikezoo/pipeline/train_pipeline.py +1 -1
- spikezoo/utils/other_utils.py +16 -6
- spikezoo/utils/spike_utils.py +33 -29
- spikezoo/utils/vidar_loader.cpython-39-x86_64-linux-gnu.so +0 -0
- spikezoo-0.2.dist-info/METADATA +163 -0
- spikezoo-0.2.dist-info/RECORD +211 -0
- spikezoo/models/spcsnet_model.py +0 -19
- spikezoo-0.1.1.dist-info/METADATA +0 -39
- spikezoo-0.1.1.dist-info/RECORD +0 -36
- {spikezoo-0.1.1.dist-info → spikezoo-0.2.dist-info}/LICENSE.txt +0 -0
- {spikezoo-0.1.1.dist-info → spikezoo-0.2.dist-info}/WHEEL +0 -0
- {spikezoo-0.1.1.dist-info → spikezoo-0.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,173 @@
|
|
1
|
+
import os
|
2
|
+
os.environ['CUDA_VISIBLE_DEVICES'] = "0"
|
3
|
+
import torch
|
4
|
+
from torch import optim
|
5
|
+
import numpy as np
|
6
|
+
from DataProcess import DataLoader as dl
|
7
|
+
from Model.SpikeFormer import SpikeFormer
|
8
|
+
from Metrics.Metrics import Metrics
|
9
|
+
from Model import Loss
|
10
|
+
from utils import SaveModel, LoadModel
|
11
|
+
from PIL import Image
|
12
|
+
|
13
|
+
def eval(model, validData, epoch, optimizer, metrics, saveRoot="CheckPoints/"):
    """Validate ``model`` on ``validData``, report PSNR/SSIM and save checkpoints.

    Predictions and ground truth are clamped to [-1, 1], mapped to uint8
    [0, 255] and scored with ``metrics.Cal_PSNR`` / ``metrics.Cal_SSIM``.
    ``latest.pth`` is written on every call. When both scores are >= the best
    recorded so far, the best metrics are updated, ``best.pth`` and
    ``eval_best_log.txt`` are written, and side-by-side prediction|GT images
    are saved to ``EvalResults/``.

    Args:
        model: network to evaluate; switched to eval mode and back to train mode.
        validData: iterable of ``(spikes, gtImg)`` tensor batches.
        epoch: epoch index used for printing, logging and the checkpoint.
        optimizer: optimizer whose state is stored in the checkpoint.
        metrics: tracker exposing ``Cal_PSNR``, ``Cal_SSIM``,
            ``GetBestMetrics`` and ``Update``.
        saveRoot: checkpoint directory. It used to be read from a global that
            only existed when this file ran as a script, so importing and
            calling ``eval`` elsewhere raised ``NameError``.

    Note:
        The name shadows the builtin ``eval``; it is kept for existing callers.
    """
    model.eval()
    print('Eval Epoch: %s' % (epoch))

    with torch.no_grad():
        pres = []
        gts = []
        for spikes, gtImg in validData:
            spikes = spikes.cuda()
            gtImg = gtImg.cuda()
            predImg = model(spikes)
            # Assumes the model output has a singleton channel at dim 1 and
            # 3 padding rows on each side of H — TODO confirm with the loader.
            predImg = predImg.squeeze(1)
            predImg = predImg[:, 3:-3, :]

            predImg = predImg.clamp(min=-1., max=1.).detach().cpu().numpy()
            gtImg = gtImg.clamp(min=-1., max=1.).detach().cpu().numpy()

            # Map [-1, 1] -> [0, 255] uint8 for metric computation / saving.
            predImg = ((predImg + 1.) / 2. * 255.).astype(np.uint8)
            gtImg = ((gtImg + 1.) / 2. * 255.).astype(np.uint8)

            pres.append(predImg)
            gts.append(gtImg)
        pres = np.concatenate(pres, axis=0)
        gts = np.concatenate(gts, axis=0)

        psnr = metrics.Cal_PSNR(pres, gts)
        ssim = metrics.Cal_SSIM(pres, gts)
        best_psnr, best_ssim, _ = metrics.GetBestMetrics()

        if saveRoot:
            os.makedirs(saveRoot, exist_ok=True)
        SaveModel(epoch, (psnr, ssim), model, optimizer, saveRoot)
        # A new best requires both metrics to be at least as good as before.
        if psnr >= best_psnr and ssim >= best_ssim:
            metrics.Update(psnr, ssim)
            SaveModel(epoch, (psnr, ssim), model, optimizer, saveRoot, best=True)
            with open('eval_best_log.txt', 'w') as f:
                f.write('epoch: %s; psnr: %s, ssim: %s\n' % (epoch, psnr, ssim))
            _, H, _ = pres.shape
            # 4-pixel black separator between prediction and ground truth.
            divide_line = np.zeros((H, 4)).astype(np.uint8)
            os.makedirs('EvalResults', exist_ok=True)
            for num, (pre, gt) in enumerate(zip(pres, gts), start=1):
                concatImg = np.concatenate([pre, divide_line, gt], axis=1)
                concatImg = Image.fromarray(concatImg)
                concatImg.save('EvalResults/valid_%s.jpg' % (num))

    print('*********************************************************')
    best_psnr, best_ssim, _ = metrics.GetBestMetrics()
    print('Eval Epoch: %s, PSNR: %s, SSIM: %s, Best_PSNR: %s, Best_SSIM: %s'
          % (epoch, psnr, ssim, best_psnr, best_ssim))

    model.train()


def Train(trainData, validData, model, optimizer, epoch, start_epoch, metrics, saveRoot, perIter,
          evalInterval=1):
    """Train ``model`` from ``start_epoch`` up to (excluding) ``epoch``.

    The loss is ``100 * Charbonnier + 1 * VGG + 5 * Edge`` between ground
    truth and prediction. Every ``perIter`` iterations the mean of each loss
    term over that window is printed. Every ``evalInterval`` epochs ``eval``
    runs and checkpoints into ``saveRoot``.

    Args:
        trainData: iterable of ``(spikes, gtImg)`` training batches.
        validData: validation loader forwarded to ``eval``.
        model: network being trained.
        optimizer: optimizer stepping ``model``'s parameters.
        epoch: exclusive upper bound of the epoch range.
        start_epoch: first epoch index (non-zero when resuming).
        metrics: best-metric tracker forwarded to ``eval``.
        saveRoot: checkpoint directory forwarded to ``eval``.
        perIter: logging window size in iterations.
        evalInterval: run validation every this many epochs (default 1,
            i.e. after each epoch as before).
    """
    l2loss = Loss.CharbonnierLoss()
    vggloss = Loss.VGGLoss4('vgg19-low-level4.pth').cuda()
    criterion_edge = Loss.EdgeLoss()
    LAMBDA_L2 = 100.0
    LAMBDA_VGG = 1.0
    LAMBDA_EDGE = 5.0
    for i in range(start_epoch, epoch):
        # Reset running sums every epoch. Otherwise a partial window at the
        # end of one epoch leaks into the next epoch's first report, which is
        # still divided by perIter and so over-reports the loss.
        avg_l2_loss = 0.
        avg_vgg_loss = 0.
        avg_edge_loss = 0.
        avg_total_loss = 0.
        for it, (spikes, gtImg) in enumerate(trainData):
            spikes = spikes.cuda()
            gtImg = gtImg.cuda()
            predImg = model(spikes)
            gtImg = gtImg.unsqueeze(1)
            # Drop 3 rows of (presumed) padding on each side of H to match GT.
            predImg = predImg[:, :, 3:-3, :]

            loss_vgg = vggloss(gtImg, predImg) * LAMBDA_VGG
            loss_l2 = l2loss(gtImg, predImg) * LAMBDA_L2
            loss_edge = criterion_edge(gtImg, predImg) * LAMBDA_EDGE

            totalLoss = loss_l2 + loss_vgg + loss_edge

            optimizer.zero_grad()
            totalLoss.backward()
            optimizer.step()

            avg_l2_loss += loss_l2.detach().cpu()
            avg_vgg_loss += loss_vgg.detach().cpu()
            avg_edge_loss += loss_edge.detach().cpu()
            avg_total_loss += totalLoss.detach().cpu()
            if (it + 1) % perIter == 0:
                avg_l2_loss = avg_l2_loss / perIter
                avg_vgg_loss = avg_vgg_loss / perIter
                avg_edge_loss = avg_edge_loss / perIter
                avg_total_loss = avg_total_loss / perIter
                print('=============================================================')
                print('Epoch: %s, Iter: %s' % (i, it + 1))
                print('L2Loss: %s; VggLoss: %s; EdgeLoss: %s; TotalLoss: %s' % (
                    avg_l2_loss.item(), avg_vgg_loss.item(), avg_edge_loss.item(), avg_total_loss.item()))
                avg_l2_loss = 0.
                avg_vgg_loss = 0.
                avg_edge_loss = 0.
                avg_total_loss = 0.

        if (i + 1) % evalInterval == 0:
            eval(model, validData, i, optimizer, metrics, saveRoot)
|
120
|
+
|
121
|
+
if __name__ == "__main__":

    dataPath = "/home/storage2/shechen/Spike_Sample_250x400"
    spikeRadius = 32  # half length of the input spike sequence, excluding the middle frame
    spikeLen = 2 * spikeRadius + 1  # length of the input spike sequence
    batchSize = 2
    epoch = 200
    start_epoch = 0
    lr = 2e-4
    saveRoot = "CheckPoints/"  # directory where trained checkpoints are saved
    perIter = 20

    reuse = False
    reuseType = 'latest'  # 'latest' or 'best'
    # Resume from the same directory checkpoints are written to (was a
    # separately hard-coded 'CheckPoints' that ignored saveRoot).
    checkPath = os.path.join(saveRoot, '%s.pth' % (reuseType))

    trainContainer = dl.DataContainer(dataPath=dataPath, dataType='train',
                                      spikeRadius=spikeRadius,
                                      batchSize=batchSize)
    trainData = trainContainer.GetLoader()

    validContainer = dl.DataContainer(dataPath=dataPath, dataType='valid',
                                      spikeRadius=spikeRadius,
                                      batchSize=batchSize)
    validData = validContainer.GetLoader()

    metrics = Metrics()

    model = SpikeFormer(
        inputDim=spikeLen,
        dims=(32, 64, 160, 256),  # dimensions of each stage
        heads=(1, 2, 5, 8),  # heads of each stage
        ff_expansion=(8, 8, 4, 4),  # feedforward expansion factor of each stage
        reduction_ratio=(8, 4, 2, 1),  # reduction ratio of each stage for efficient attention
        num_layers=2,  # num layers of each stage
        decoder_dim=256,  # decoder dimension
        out_channel=1  # channel of restored image
    ).cuda()

    optimizer = optim.Adam(model.parameters(), lr=lr, betas=(0.9, 0.999), amsgrad=False)

    if reuse:
        # LoadModel restores model/optimizer in place; the returned state
        # dicts are not needed here.
        preEpoch, prePerformance, _, _ = LoadModel(checkPath, model, optimizer)
        start_epoch = preEpoch + 1
        psnr, ssim = prePerformance[0], prePerformance[1]
        metrics.Update(psnr, ssim)
        # Override the learning rate restored from the checkpoint.
        for para in optimizer.param_groups:
            para['lr'] = lr

    model.train()

    Train(trainData, validData, model, optimizer, epoch, start_epoch,
          metrics, saveRoot, perIter)
|
@@ -0,0 +1,22 @@
|
|
1
|
+
import os
|
2
|
+
import torch
|
3
|
+
|
4
|
+
def SaveModel(epoch, bestPerformance, model, optimizer, saveRoot, best=False):
    """Write a checkpoint to ``<saveRoot>/latest.pth`` or ``<saveRoot>/best.pth``.

    Args:
        epoch: epoch index, stored under ``'pre_epoch'``.
        bestPerformance: metrics stored under ``'performance'`` (the training
            script passes a ``(psnr, ssim)`` tuple).
        model: module whose ``state_dict()`` is saved.
        optimizer: optimizer whose ``state_dict()`` is saved.
        saveRoot: output directory; created if it does not exist yet.
        best: write ``best.pth`` instead of ``latest.pth``.
    """
    saveDict = {
        'pre_epoch': epoch,
        'performance': bestPerformance,
        'model_state_dict': model.state_dict(),
        'optimizer_state_dict': optimizer.state_dict(),
    }
    # torch.save does not create parent directories.
    if saveRoot:
        os.makedirs(saveRoot, exist_ok=True)
    savePath = os.path.join(saveRoot, '%s.pth' % ('latest' if not best else 'best'))
    torch.save(saveDict, savePath)
|
13
|
+
|
14
|
+
def LoadModel(checkPath, model, optimizer=None, map_location=None):
    """Restore a checkpoint dict with the keys written by ``SaveModel``.

    ``model`` (and ``optimizer`` when given) are updated in place.

    Args:
        checkPath: path of the ``.pth`` checkpoint file.
        model: module receiving ``'model_state_dict'``.
        optimizer: optional optimizer receiving ``'optimizer_state_dict'``.
        map_location: forwarded to ``torch.load``; e.g. ``'cpu'`` to load a
            GPU-saved checkpoint on a CPU-only machine. ``None`` keeps the
            previous behaviour.

    Returns:
        ``(pre_epoch, performance, model_state_dict, optimizer_state_dict)``.

    Raises:
        KeyError: if the checkpoint lacks one of the expected keys.
    """
    stateDict = torch.load(checkPath, map_location=map_location)
    pre_epoch = stateDict['pre_epoch']
    model.load_state_dict(stateDict['model_state_dict'])
    if optimizer is not None:
        optimizer.load_state_dict(stateDict['optimizer_state_dict'])

    return pre_epoch, stateDict['performance'], \
        stateDict['model_state_dict'], stateDict['optimizer_state_dict']
|
@@ -0,0 +1,23 @@
|
|
1
|
+
name: Pylint
|
2
|
+
|
3
|
+
on: [push]
|
4
|
+
|
5
|
+
jobs:
|
6
|
+
build:
|
7
|
+
runs-on: ubuntu-latest
|
8
|
+
strategy:
|
9
|
+
matrix:
|
10
|
+
python-version: ["3.8", "3.9", "3.10"]
|
11
|
+
steps:
|
12
|
+
- uses: actions/checkout@v3
|
13
|
+
- name: Set up Python ${{ matrix.python-version }}
|
14
|
+
uses: actions/setup-python@v3
|
15
|
+
with:
|
16
|
+
python-version: ${{ matrix.python-version }}
|
17
|
+
- name: Install dependencies
|
18
|
+
run: |
|
19
|
+
python -m pip install --upgrade pip
|
20
|
+
pip install pylint
|
21
|
+
- name: Analysing the code with pylint
|
22
|
+
run: |
|
23
|
+
pylint $(git ls-files '*.py')
|
@@ -0,0 +1,150 @@
|
|
1
|
+
### Python template
|
2
|
+
# Byte-compiled / optimized / DLL files
|
3
|
+
__pycache__/
|
4
|
+
*.py[cod]
|
5
|
+
*$py.class
|
6
|
+
.idea/
|
7
|
+
|
8
|
+
# C extensions
|
9
|
+
*.so
|
10
|
+
|
11
|
+
# Distribution / packaging
|
12
|
+
.Python
|
13
|
+
build/
|
14
|
+
develop-eggs/
|
15
|
+
dist/
|
16
|
+
downloads/
|
17
|
+
eggs/
|
18
|
+
.eggs/
|
19
|
+
lib/
|
20
|
+
lib64/
|
21
|
+
parts/
|
22
|
+
sdist/
|
23
|
+
var/
|
24
|
+
wheels/
|
25
|
+
share/python-wheels/
|
26
|
+
results/
|
27
|
+
ckpt/
|
28
|
+
ckpt2/
|
29
|
+
old_ckpt/
|
30
|
+
Spk2ImgNet_test2/
|
31
|
+
Spk2ImgNet_train/
|
32
|
+
*.zip
|
33
|
+
*.pth
|
34
|
+
*.h5
|
35
|
+
*.egg-info/
|
36
|
+
.installed.cfg
|
37
|
+
*.egg
|
38
|
+
MANIFEST
|
39
|
+
|
40
|
+
# PyInstaller
|
41
|
+
# Usually these files are written by a python script from a template
|
42
|
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
43
|
+
*.manifest
|
44
|
+
*.spec
|
45
|
+
|
46
|
+
# Installer logs
|
47
|
+
pip-log.txt
|
48
|
+
pip-delete-this-directory.txt
|
49
|
+
|
50
|
+
# Unit test / coverage reports
|
51
|
+
htmlcov/
|
52
|
+
.tox/
|
53
|
+
.nox/
|
54
|
+
.coverage
|
55
|
+
.coverage.*
|
56
|
+
.cache
|
57
|
+
nosetests.xml
|
58
|
+
coverage.xml
|
59
|
+
*.cover
|
60
|
+
*.py,cover
|
61
|
+
.hypothesis/
|
62
|
+
.pytest_cache/
|
63
|
+
cover/
|
64
|
+
|
65
|
+
# Translations
|
66
|
+
*.mo
|
67
|
+
*.pot
|
68
|
+
|
69
|
+
# Django stuff:
|
70
|
+
*.log
|
71
|
+
local_settings.py
|
72
|
+
db.sqlite3
|
73
|
+
db.sqlite3-journal
|
74
|
+
|
75
|
+
# Flask stuff:
|
76
|
+
instance/
|
77
|
+
.webassets-cache
|
78
|
+
|
79
|
+
# Scrapy stuff:
|
80
|
+
.scrapy
|
81
|
+
|
82
|
+
# Sphinx documentation
|
83
|
+
docs/_build/
|
84
|
+
|
85
|
+
# PyBuilder
|
86
|
+
.pybuilder/
|
87
|
+
target/
|
88
|
+
|
89
|
+
# Jupyter Notebook
|
90
|
+
.ipynb_checkpoints
|
91
|
+
|
92
|
+
# IPython
|
93
|
+
profile_default/
|
94
|
+
ipython_config.py
|
95
|
+
|
96
|
+
# pyenv
|
97
|
+
# For a library or package, you might want to ignore these files since the code is
|
98
|
+
# intended to run in multiple environments; otherwise, check them in:
|
99
|
+
# .python-version
|
100
|
+
|
101
|
+
# pipenv
|
102
|
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
103
|
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
104
|
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
105
|
+
# install all needed dependencies.
|
106
|
+
#Pipfile.lock
|
107
|
+
|
108
|
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
109
|
+
__pypackages__/
|
110
|
+
|
111
|
+
# Celery stuff
|
112
|
+
celerybeat-schedule
|
113
|
+
celerybeat.pid
|
114
|
+
|
115
|
+
# SageMath parsed files
|
116
|
+
*.sage.py
|
117
|
+
|
118
|
+
# Environments
|
119
|
+
.env
|
120
|
+
.venv
|
121
|
+
env/
|
122
|
+
venv/
|
123
|
+
ENV/
|
124
|
+
env.bak/
|
125
|
+
venv.bak/
|
126
|
+
|
127
|
+
# Spyder project settings
|
128
|
+
.spyderproject
|
129
|
+
.spyproject
|
130
|
+
|
131
|
+
# Rope project settings
|
132
|
+
.ropeproject
|
133
|
+
|
134
|
+
# mkdocs documentation
|
135
|
+
/site
|
136
|
+
|
137
|
+
# mypy
|
138
|
+
.mypy_cache/
|
139
|
+
.dmypy.json
|
140
|
+
dmypy.json
|
141
|
+
|
142
|
+
# Pyre type checker
|
143
|
+
.pyre/
|
144
|
+
|
145
|
+
# pytype static type analyzer
|
146
|
+
.pytype/
|
147
|
+
|
148
|
+
# Cython debug symbols
|
149
|
+
cython_debug/
|
150
|
+
|
@@ -0,0 +1,135 @@
|
|
1
|
+
#!/usr/bin/env python
|
2
|
+
|
3
|
+
import math
|
4
|
+
import logging
|
5
|
+
import torch
|
6
|
+
from PIL.Image import logger
|
7
|
+
from torch import nn
|
8
|
+
import torchvision
|
9
|
+
from torch.nn.modules.utils import _pair
|
10
|
+
|
11
|
+
|
12
|
+
class DCNv2(nn.Module):
    """Modulated deformable convolution built on ``torchvision.ops.deform_conv2d``.

    This base class owns the convolution ``weight`` and ``bias``; offsets and
    modulation masks are supplied by the caller in :meth:`forward`
    (subclasses predict them themselves).

    Args:
        in_channels: number of input channels.
        out_channels: number of output channels.
        kernel_size, stride, padding, dilation: int or pair, as for ``nn.Conv2d``.
        deformable_groups: number of offset groups; :meth:`forward` checks the
            offset/mask channel counts against it.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation=1,
                 deformable_groups=1):
        super(DCNv2, self).__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = _pair(kernel_size)
        self.stride = _pair(stride)
        self.padding = _pair(padding)
        self.dilation = _pair(dilation)
        self.deformable_groups = deformable_groups

        self.weight = nn.Parameter(torch.Tensor(out_channels, in_channels, *self.kernel_size))
        self.bias = nn.Parameter(torch.Tensor(out_channels))
        self.reset_parameters()

    def reset_parameters(self):
        """Init weights ~ U(-1/sqrt(fan_in), 1/sqrt(fan_in)) and zero the bias."""
        n = self.in_channels
        for k in self.kernel_size:
            n *= k
        stdv = 1. / math.sqrt(n)
        self.weight.data.uniform_(-stdv, stdv)
        self.bias.data.zero_()

    def forward(self, input, offset, mask):
        """Apply the deformable convolution with caller-provided offset and mask.

        Args:
            input: input feature map.
            offset: must have ``2 * deformable_groups * kh * kw`` channels.
            mask: must have ``deformable_groups * kh * kw`` channels.
        """
        kernel_area = self.kernel_size[0] * self.kernel_size[1]
        assert 2 * self.deformable_groups * kernel_area == offset.shape[1], \
            'offset must have 2 * deformable_groups * kh * kw channels'
        assert self.deformable_groups * kernel_area == mask.shape[1], \
            'mask must have deformable_groups * kh * kw channels'

        # torchvision's deform_conv2d has no ``groups`` argument (passing one
        # raised TypeError): the offset-group count is inferred from
        # offset.shape[1], which the assertions above tie to deformable_groups.
        return torchvision.ops.deform_conv2d(
            input=input,
            offset=offset,
            weight=self.weight,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
            mask=mask,
        )
|
54
|
+
|
55
|
+
|
56
|
+
|
57
|
+
|
58
|
+
class DCN(DCNv2):
    """Deformable conv that predicts its own offsets and modulation mask.

    A regular convolution ``conv_offset_mask`` maps the input to
    ``3 * deformable_groups * kh * kw`` channels. Its output is split into
    three equal chunks: the first two are concatenated as the offset and the
    third is passed through a sigmoid to become the mask. The predictor is
    zero-initialised, so the layer starts with zero offsets and a uniform
    mask of 0.5.
    """

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation=1,
                 deformable_groups=1):
        super(DCN, self).__init__(in_channels, out_channels, kernel_size, stride, padding,
                                  dilation, deformable_groups)

        kh, kw = self.kernel_size
        predictor_channels = 3 * self.deformable_groups * kh * kw
        self.conv_offset_mask = nn.Conv2d(
            self.in_channels,
            predictor_channels,
            kernel_size=self.kernel_size,
            stride=self.stride,
            padding=self.padding,
            bias=True,
        )
        self.init_offset()

    def init_offset(self):
        """Zero the offset/mask predictor's weight and bias."""
        for tensor in (self.conv_offset_mask.weight, self.conv_offset_mask.bias):
            tensor.data.zero_()

    def forward(self, input):
        """Predict offset/mask from ``input`` and apply the deformable conv to it."""
        prediction = self.conv_offset_mask(input)
        dx_part, dy_part, mask_logits = torch.chunk(prediction, 3, dim=1)
        return torchvision.ops.deform_conv2d(
            input=input,
            offset=torch.cat((dx_part, dy_part), dim=1),
            mask=torch.sigmoid(mask_logits),
            weight=self.weight,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
        )
|
92
|
+
|
93
|
+
|
94
|
+
class DCN_sep(DCNv2):
    '''Deformable conv whose offsets and masks are predicted from separate features.

    ``conv_offset_mask`` (zero-initialised) maps the guidance features ``fea``
    to ``3 * deformable_groups * kh * kw`` channels: two chunks form the
    offset and the third, after a sigmoid, the modulation mask. The
    deformable convolution itself is applied to ``input``.
    '''

    def __init__(self, in_channels, out_channels, kernel_size, stride, padding, dilation=1,
                 deformable_groups=1):
        super(DCN_sep, self).__init__(in_channels, out_channels, kernel_size, stride, padding,
                                      dilation, deformable_groups)

        channels_ = self.deformable_groups * 3 * self.kernel_size[0] * self.kernel_size[1]
        self.conv_offset_mask = nn.Conv2d(self.in_channels, channels_, kernel_size=self.kernel_size,
                                          stride=self.stride, padding=self.padding, bias=True)
        self.init_offset()

    def init_offset(self):
        """Zero the offset/mask predictor so offsets start at 0 and masks at 0.5."""
        self.conv_offset_mask.weight.data.zero_()
        self.conv_offset_mask.bias.data.zero_()

    def forward(self, input, fea):
        '''Apply the deformable conv to ``input`` using offsets/mask predicted from ``fea``.

        Args:
            input: input features for the deformable conv.
            fea: other features used for generating offsets and mask; must
                have ``in_channels`` channels.
        '''
        out = self.conv_offset_mask(fea)
        o1, o2, mask = torch.chunk(out, 3, dim=1)
        offset = torch.cat((o1, o2), dim=1)

        # Guard against exploding offsets (presumably a sign of unstable
        # training). Log through this module's logger rather than PIL's
        # ``PIL.Image.logger``, and report a plain float, not a tensor repr.
        offset_mean = torch.mean(torch.abs(offset)).item()
        if offset_mean > 100:
            logging.getLogger(__name__).warning(
                'Offset mean is %s, larger than 100.', offset_mean)

        mask = torch.sigmoid(mask)

        return torchvision.ops.deform_conv2d(
            input=input,
            offset=offset,
            mask=mask,
            weight=self.weight,
            bias=self.bias,
            stride=self.stride,
            padding=self.padding,
            dilation=self.dilation,
        )
|
135
|
+
|
Binary file
|
Binary file
|
Binary file
|
@@ -0,0 +1,159 @@
|
|
1
|
+
""" network architecture for Sakuya """
|
2
|
+
import torch
|
3
|
+
from DCNv2 import *
|
4
|
+
import torch.nn as nn
|
5
|
+
import torch.nn.functional as F
|
6
|
+
from torchvision.ops import DeformConv2d
|
7
|
+
|
8
|
+
|
9
|
+
|
10
|
+
|
11
|
+
class PCDAlign(nn.Module):
    """Alignment module using Pyramid, Cascading and Deformable convolution
    with 3 pyramid levels.

    Aligns the features of a neighboring frame (``fea1``) to those of a
    reference frame (``fea2``). Offsets are estimated coarse-to-fine
    (L3 -> L2 -> L1), each level fusing its own offset features with the
    upsampled ones from the coarser level. A final cascading deformable conv
    refines the L1 result.

    Args:
        nf: number of feature channels at every pyramid level.
        groups: deformable groups used by every ``DCN_sep``.
    """

    def __init__(self, nf=64, groups=8):
        super(PCDAlign, self).__init__()

        # fea1
        # L3: level 3, 1/4 spatial size
        self.L3_offset_conv1_1 = nn.Conv2d(
            nf * 2, nf, 3, 1, 1, bias=True
        )  # concat for diff
        self.L3_offset_conv2_1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.L3_dcnpack_1 = DCN_sep(
            nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups
        )
        # L2: level 2, 1/2 spatial size
        self.L2_offset_conv1_1 = nn.Conv2d(
            nf * 2, nf, 3, 1, 1, bias=True
        )  # concat for diff
        self.L2_offset_conv2_1 = nn.Conv2d(
            nf * 2, nf, 3, 1, 1, bias=True
        )  # concat for offset
        self.L2_offset_conv3_1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.L2_dcnpack_1 = DCN_sep(
            nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups
        )
        self.L2_fea_conv_1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for fea
        # L1: level 1, original spatial size
        self.L1_offset_conv1_1 = nn.Conv2d(
            nf * 2, nf, 3, 1, 1, bias=True
        )  # concat for diff
        self.L1_offset_conv2_1 = nn.Conv2d(
            nf * 2, nf, 3, 1, 1, bias=True
        )  # concat for offset
        self.L1_offset_conv3_1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.L1_dcnpack_1 = DCN_sep(
            nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups
        )
        self.L1_fea_conv_1 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=True)  # concat for fea

        # Cascading DCN
        self.cas_offset_conv1 = nn.Conv2d(
            nf * 2, nf, 3, 1, 1, bias=True
        )  # concat for diff
        self.cas_offset_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.cas_dcnpack = DCN_sep(
            nf, nf, 3, stride=1, padding=1, dilation=1, deformable_groups=groups
        )

        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, fea1, fea2):
        """Align the neighboring-frame features to the reference-frame features.

        Args:
            fea1: [L1, L2, L3] features of the neighboring frame, each
                [B, C, H, W], from finest (L1) to coarsest (L3).
            fea2: [L1, L2, L3] features of the reference (key) frame, same shapes.

        Returns:
            The aligned L1 features of the neighboring frame.

        Coarse-to-fine upsampling targets the exact size of the finer level
        instead of using ``scale_factor=2``. This gives identical results when
        each level is exactly twice its coarser neighbor. Odd sizes produced
        by stride-2 pyramids (e.g. 125 -> 63) no longer break the
        concatenations.
        """
        # L3
        L3_offset = torch.cat([fea1[2], fea2[2]], dim=1)
        L3_offset = self.lrelu(self.L3_offset_conv1_1(L3_offset))
        L3_offset = self.lrelu(self.L3_offset_conv2_1(L3_offset))
        L3_fea = self.lrelu(self.L3_dcnpack_1(fea1[2], L3_offset))
        # L2
        L2_offset = torch.cat([fea1[1], fea2[1]], dim=1)
        L2_offset = self.lrelu(self.L2_offset_conv1_1(L2_offset))
        L3_offset = F.interpolate(
            L3_offset, size=L2_offset.shape[-2:], mode="bilinear", align_corners=False
        )
        # Offsets are presumably in coarse-level pixels, hence the x2 after upsampling.
        L2_offset = self.lrelu(
            self.L2_offset_conv2_1(torch.cat([L2_offset, L3_offset * 2], dim=1))
        )
        L2_offset = self.lrelu(self.L2_offset_conv3_1(L2_offset))
        L2_fea = self.L2_dcnpack_1(fea1[1], L2_offset)
        L3_fea = F.interpolate(
            L3_fea, size=L2_fea.shape[-2:], mode="bilinear", align_corners=False
        )
        L2_fea = self.lrelu(self.L2_fea_conv_1(torch.cat([L2_fea, L3_fea], dim=1)))
        # L1
        L1_offset = torch.cat([fea1[0], fea2[0]], dim=1)
        L1_offset = self.lrelu(self.L1_offset_conv1_1(L1_offset))
        L2_offset = F.interpolate(
            L2_offset, size=L1_offset.shape[-2:], mode="bilinear", align_corners=False
        )
        L1_offset = self.lrelu(
            self.L1_offset_conv2_1(torch.cat([L1_offset, L2_offset * 2], dim=1))
        )
        L1_offset = self.lrelu(self.L1_offset_conv3_1(L1_offset))
        L1_fea = self.L1_dcnpack_1(fea1[0], L1_offset)
        L2_fea = F.interpolate(
            L2_fea, size=L1_fea.shape[-2:], mode="bilinear", align_corners=False
        )
        L1_fea = self.L1_fea_conv_1(torch.cat([L1_fea, L2_fea], dim=1))

        # Cascading DCN: refine the aligned L1 features against the reference.
        offset = torch.cat([L1_fea, fea2[0]], dim=1)
        offset = self.lrelu(self.cas_offset_conv1(offset))
        offset = self.lrelu(self.cas_offset_conv2(offset))
        L1_fea = self.lrelu(self.cas_dcnpack(L1_fea, offset))

        return L1_fea
|
115
|
+
|
116
|
+
|
117
|
+
class Easy_PCD(nn.Module):
    """Build a 3-level feature pyramid for two frames and align them with PCDAlign.

    Args:
        nf: channel count of the input features and of every pyramid level.
        groups: deformable groups forwarded to ``PCDAlign``.
    """

    def __init__(self, nf=64, groups=8):
        super(Easy_PCD, self).__init__()

        # Two stride-2 stages produce the L2 and L3 pyramid levels.
        self.fea_L2_conv1 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
        self.fea_L2_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.fea_L3_conv1 = nn.Conv2d(nf, nf, 3, 2, 1, bias=True)
        self.fea_L3_conv2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=True)
        self.pcd_align = PCDAlign(nf=nf, groups=groups)
        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, f1, f2):
        """Align ``f1`` to ``f2``.

        Args:
            f1: extracted features of the neighboring frame, [B, C, H, W].
            f2: extracted features of the key (reference) frame, same shape.

        Returns:
            The output of ``PCDAlign`` for (f1 -> f2).
        """
        # Process both frames in one batch through the shared pyramid convs.
        L1_fea = torch.stack([f1, f2], dim=1)  # [B, 2, C, H, W]
        B, N, C, H, W = L1_fea.size()
        L1_fea = L1_fea.view(-1, C, H, W)
        # L2
        L2_fea = self.lrelu(self.fea_L2_conv1(L1_fea))
        L2_fea = self.lrelu(self.fea_L2_conv2(L2_fea))
        # L3
        L3_fea = self.lrelu(self.fea_L3_conv1(L2_fea))
        L3_fea = self.lrelu(self.fea_L3_conv2(L3_fea))

        # Use each level's real spatial size. A stride-2 conv yields
        # ceil(H / 2), so the former hard-coded H // 2 and H // 4 mis-reshaped
        # (or failed) whenever H or W was not a multiple of 4.
        L1_fea = L1_fea.view(B, N, -1, H, W)
        L2_fea = L2_fea.view(B, N, -1, *L2_fea.shape[-2:])
        L3_fea = L3_fea.view(B, N, -1, *L3_fea.shape[-2:])

        levels = (L1_fea, L2_fea, L3_fea)
        fea1 = [lvl[:, 0, :, :, :].clone() for lvl in levels]
        fea2 = [lvl[:, 1, :, :, :].clone() for lvl in levels]
        aligned_fea = self.pcd_align(fea1, fea2)
        return aligned_fea
|