spikezoo 0.1.2__py3-none-any.whl → 0.2.1__py3-none-any.whl
- spikezoo/__init__.py +13 -0
- spikezoo/archs/__pycache__/__init__.cpython-39.pyc +0 -0
- spikezoo/archs/base/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/base/nets.py +34 -0
- spikezoo/archs/bsf/README.md +92 -0
- spikezoo/archs/bsf/datasets/datasets.py +328 -0
- spikezoo/archs/bsf/datasets/ds_utils.py +64 -0
- spikezoo/archs/bsf/main.py +398 -0
- spikezoo/archs/bsf/metrics/psnr.py +22 -0
- spikezoo/archs/bsf/metrics/ssim.py +54 -0
- spikezoo/archs/bsf/models/bsf/__pycache__/align.cpython-39.pyc +0 -0
- spikezoo/archs/bsf/models/bsf/__pycache__/bsf.cpython-39.pyc +0 -0
- spikezoo/archs/bsf/models/bsf/__pycache__/rep.cpython-39.pyc +0 -0
- spikezoo/archs/bsf/models/bsf/align.py +154 -0
- spikezoo/archs/bsf/models/bsf/bsf.py +105 -0
- spikezoo/archs/bsf/models/bsf/dsft_convert.py +96 -0
- spikezoo/archs/bsf/models/bsf/rep.py +44 -0
- spikezoo/archs/bsf/models/get_model.py +7 -0
- spikezoo/archs/bsf/prepare_data/DSFT.py +62 -0
- spikezoo/archs/bsf/prepare_data/crop_dataset_train.py +135 -0
- spikezoo/archs/bsf/prepare_data/crop_dataset_val.py +139 -0
- spikezoo/archs/bsf/prepare_data/crop_train.sh +4 -0
- spikezoo/archs/bsf/prepare_data/crop_val.sh +4 -0
- spikezoo/archs/bsf/prepare_data/io_utils.py +64 -0
- spikezoo/archs/bsf/requirements.txt +9 -0
- spikezoo/archs/bsf/test.py +16 -0
- spikezoo/archs/bsf/utils.py +154 -0
- spikezoo/archs/spikeclip/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/spikeclip/nets.py +40 -0
- spikezoo/archs/spikeformer/CheckPoints/readme +1 -0
- spikezoo/archs/spikeformer/DataProcess/DataExtactor.py +60 -0
- spikezoo/archs/spikeformer/DataProcess/DataLoader.py +115 -0
- spikezoo/archs/spikeformer/DataProcess/LoadSpike.py +39 -0
- spikezoo/archs/spikeformer/EvalResults/readme +1 -0
- spikezoo/archs/spikeformer/LICENSE +21 -0
- spikezoo/archs/spikeformer/Metrics/Metrics.py +50 -0
- spikezoo/archs/spikeformer/Metrics/__init__.py +0 -0
- spikezoo/archs/spikeformer/Model/Loss.py +89 -0
- spikezoo/archs/spikeformer/Model/SpikeFormer.py +230 -0
- spikezoo/archs/spikeformer/Model/__init__.py +0 -0
- spikezoo/archs/spikeformer/Model/__pycache__/SpikeFormer.cpython-39.pyc +0 -0
- spikezoo/archs/spikeformer/Model/__pycache__/__init__.cpython-39.pyc +0 -0
- spikezoo/archs/spikeformer/README.md +30 -0
- spikezoo/archs/spikeformer/evaluate.py +87 -0
- spikezoo/archs/spikeformer/recon_real_data.py +97 -0
- spikezoo/archs/spikeformer/requirements.yml +95 -0
- spikezoo/archs/spikeformer/train.py +173 -0
- spikezoo/archs/spikeformer/utils.py +22 -0
- spikezoo/archs/spk2imgnet/.github/workflows/pylint.yml +23 -0
- spikezoo/archs/spk2imgnet/.gitignore +150 -0
- spikezoo/archs/spk2imgnet/DCNv2.py +135 -0
- spikezoo/archs/spk2imgnet/__pycache__/DCNv2.cpython-39.pyc +0 -0
- spikezoo/archs/spk2imgnet/__pycache__/align_arch.cpython-39.pyc +0 -0
- spikezoo/archs/spk2imgnet/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/spk2imgnet/align_arch.py +159 -0
- spikezoo/archs/spk2imgnet/dataset.py +144 -0
- spikezoo/archs/spk2imgnet/nets.py +230 -0
- spikezoo/archs/spk2imgnet/readme.md +86 -0
- spikezoo/archs/spk2imgnet/test_gen_imgseq.py +118 -0
- spikezoo/archs/spk2imgnet/train.py +189 -0
- spikezoo/archs/spk2imgnet/utils.py +64 -0
- spikezoo/archs/ssir/README.md +87 -0
- spikezoo/archs/ssir/configs/SSIR.yml +37 -0
- spikezoo/archs/ssir/configs/yml_parser.py +78 -0
- spikezoo/archs/ssir/datasets/dataset_sreds.py +170 -0
- spikezoo/archs/ssir/datasets/ds_utils.py +66 -0
- spikezoo/archs/ssir/losses.py +21 -0
- spikezoo/archs/ssir/main.py +326 -0
- spikezoo/archs/ssir/metrics/psnr.py +22 -0
- spikezoo/archs/ssir/metrics/ssim.py +54 -0
- spikezoo/archs/ssir/models/Vgg19.py +42 -0
- spikezoo/archs/ssir/models/__pycache__/layers.cpython-39.pyc +0 -0
- spikezoo/archs/ssir/models/__pycache__/networks.cpython-39.pyc +0 -0
- spikezoo/archs/ssir/models/layers.py +110 -0
- spikezoo/archs/ssir/models/networks.py +61 -0
- spikezoo/archs/ssir/requirements.txt +8 -0
- spikezoo/archs/ssir/shells/eval_SREDS.sh +6 -0
- spikezoo/archs/ssir/shells/train_SSIR.sh +12 -0
- spikezoo/archs/ssir/test.py +3 -0
- spikezoo/archs/ssir/utils.py +154 -0
- spikezoo/archs/ssml/__pycache__/cbam.cpython-39.pyc +0 -0
- spikezoo/archs/ssml/__pycache__/model.cpython-39.pyc +0 -0
- spikezoo/archs/ssml/cbam.py +224 -0
- spikezoo/archs/ssml/model.py +290 -0
- spikezoo/archs/ssml/res.png +0 -0
- spikezoo/archs/ssml/test.py +67 -0
- spikezoo/archs/stir/.git-credentials +0 -0
- spikezoo/archs/stir/README.md +65 -0
- spikezoo/archs/stir/ckpt_outputs/Descriptions.txt +1 -0
- spikezoo/archs/stir/configs/STIR.yml +37 -0
- spikezoo/archs/stir/configs/utils.py +155 -0
- spikezoo/archs/stir/configs/yml_parser.py +78 -0
- spikezoo/archs/stir/datasets/dataset_sreds.py +180 -0
- spikezoo/archs/stir/datasets/ds_utils.py +66 -0
- spikezoo/archs/stir/eval_SREDS.sh +5 -0
- spikezoo/archs/stir/main.py +397 -0
- spikezoo/archs/stir/metrics/losses.py +219 -0
- spikezoo/archs/stir/metrics/psnr.py +22 -0
- spikezoo/archs/stir/metrics/ssim.py +54 -0
- spikezoo/archs/stir/models/Vgg19.py +42 -0
- spikezoo/archs/stir/models/__pycache__/networks_STIR.cpython-39.pyc +0 -0
- spikezoo/archs/stir/models/__pycache__/submodules.cpython-39.pyc +0 -0
- spikezoo/archs/stir/models/__pycache__/transformer_new.cpython-39.pyc +0 -0
- spikezoo/archs/stir/models/networks_STIR.py +361 -0
- spikezoo/archs/stir/models/submodules.py +86 -0
- spikezoo/archs/stir/models/transformer_new.py +151 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/__init__.py +0 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/convertions.py +721 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/disp_netS.py +133 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/flow_utils.py +167 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/generic_train_test.py +76 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/geometry.py +458 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/image_proc.py +183 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/linalg.py +40 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/losses.py +198 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/metrics.py +51 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/model_base.py +53 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/net_basics.py +100 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/resnet.py +333 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/transforms.py +123 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/utils.py +72 -0
- spikezoo/archs/stir/package_core/dist/package_core-0.0.0-py3.9.egg +0 -0
- spikezoo/archs/stir/package_core/package_core/__init__.py +0 -0
- spikezoo/archs/stir/package_core/package_core/__pycache__/__init__.cpython-39.pyc +0 -0
- spikezoo/archs/stir/package_core/package_core/__pycache__/net_basics.cpython-39.pyc +0 -0
- spikezoo/archs/stir/package_core/package_core/convertions.py +721 -0
- spikezoo/archs/stir/package_core/package_core/disp_netS.py +133 -0
- spikezoo/archs/stir/package_core/package_core/flow_utils.py +167 -0
- spikezoo/archs/stir/package_core/package_core/generic_train_test.py +76 -0
- spikezoo/archs/stir/package_core/package_core/geometry.py +458 -0
- spikezoo/archs/stir/package_core/package_core/image_proc.py +183 -0
- spikezoo/archs/stir/package_core/package_core/linalg.py +40 -0
- spikezoo/archs/stir/package_core/package_core/losses.py +198 -0
- spikezoo/archs/stir/package_core/package_core/metrics.py +51 -0
- spikezoo/archs/stir/package_core/package_core/model_base.py +53 -0
- spikezoo/archs/stir/package_core/package_core/net_basics.py +100 -0
- spikezoo/archs/stir/package_core/package_core/resnet.py +333 -0
- spikezoo/archs/stir/package_core/package_core/transforms.py +123 -0
- spikezoo/archs/stir/package_core/package_core/utils.py +72 -0
- spikezoo/archs/stir/package_core/package_core.egg-info/PKG-INFO +3 -0
- spikezoo/archs/stir/package_core/package_core.egg-info/SOURCES.txt +20 -0
- spikezoo/archs/stir/package_core/package_core.egg-info/dependency_links.txt +1 -0
- spikezoo/archs/stir/package_core/package_core.egg-info/top_level.txt +1 -0
- spikezoo/archs/stir/package_core/setup.py +5 -0
- spikezoo/archs/stir/requirements.txt +12 -0
- spikezoo/archs/stir/train_STIR.sh +9 -0
- spikezoo/archs/tfi/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/tfi/nets.py +43 -0
- spikezoo/archs/tfp/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/tfp/nets.py +13 -0
- spikezoo/archs/wgse/README.md +64 -0
- spikezoo/archs/wgse/__pycache__/dwtnets.cpython-39.pyc +0 -0
- spikezoo/archs/wgse/__pycache__/submodules.cpython-39.pyc +0 -0
- spikezoo/archs/wgse/dataset.py +59 -0
- spikezoo/archs/wgse/demo.png +0 -0
- spikezoo/archs/wgse/demo.py +83 -0
- spikezoo/archs/wgse/dwtnets.py +145 -0
- spikezoo/archs/wgse/eval.py +133 -0
- spikezoo/archs/wgse/logs/WGSE-Dwt1dNet-db8-5-ks3/log.txt +11 -0
- spikezoo/archs/wgse/submodules.py +68 -0
- spikezoo/archs/wgse/train.py +261 -0
- spikezoo/archs/wgse/transform.py +139 -0
- spikezoo/archs/wgse/utils.py +128 -0
- spikezoo/archs/wgse/weights/demo.png +0 -0
- spikezoo/data/base/test/gt/200_part1_key_id151.png +0 -0
- spikezoo/data/base/test/gt/200_part3_key_id151.png +0 -0
- spikezoo/data/base/test/gt/203_part1_key_id151.png +0 -0
- spikezoo/data/base/test/spike/200_part1_key_id151.dat +0 -0
- spikezoo/data/base/test/spike/200_part3_key_id151.dat +0 -0
- spikezoo/data/base/test/spike/203_part1_key_id151.dat +0 -0
- spikezoo/data/base/train/gt/203_part2_key_id151.png +0 -0
- spikezoo/data/base/train/gt/203_part3_key_id151.png +0 -0
- spikezoo/data/base/train/gt/203_part4_key_id151.png +0 -0
- spikezoo/data/base/train/spike/203_part2_key_id151.dat +0 -0
- spikezoo/data/base/train/spike/203_part3_key_id151.dat +0 -0
- spikezoo/data/base/train/spike/203_part4_key_id151.dat +0 -0
- spikezoo/datasets/base_dataset.py +2 -3
- spikezoo/metrics/__init__.py +1 -1
- spikezoo/models/base_model.py +1 -3
- spikezoo/pipeline/base_pipeline.py +7 -5
- spikezoo/pipeline/train_pipeline.py +1 -1
- spikezoo/utils/other_utils.py +16 -6
- spikezoo/utils/spike_utils.py +33 -29
- spikezoo/utils/vidar_loader.cpython-39-x86_64-linux-gnu.so +0 -0
- spikezoo-0.2.1.dist-info/METADATA +167 -0
- spikezoo-0.2.1.dist-info/RECORD +211 -0
- spikezoo/models/spcsnet_model.py +0 -19
- spikezoo-0.1.2.dist-info/METADATA +0 -39
- spikezoo-0.1.2.dist-info/RECORD +0 -36
- {spikezoo-0.1.2.dist-info → spikezoo-0.2.1.dist-info}/LICENSE.txt +0 -0
- {spikezoo-0.1.2.dist-info → spikezoo-0.2.1.dist-info}/WHEEL +0 -0
- {spikezoo-0.1.2.dist-info → spikezoo-0.2.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,154 @@
import numpy as np
import torch
import torch.nn.functional as F
import os
import os.path as osp
import random
import cv2


def set_seeds(_seed_):
    random.seed(_seed_)
    np.random.seed(_seed_)
    torch.manual_seed(_seed_)  # seeds the RNG for all devices (both CPU and CUDA)
    torch.cuda.manual_seed_all(_seed_)

    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False
    # For full determinism, additionally set the environment variable
    # CUBLAS_WORKSPACE_CONFIG to ":16:8" (may limit overall performance) or
    # ":4096:8" (increases the library footprint in GPU memory by roughly
    # 24 MiB), and enable torch.use_deterministic_algorithms(True).


def make_dir(path):
    if not osp.exists(path):
        os.makedirs(path)
    return


def add_args_to_cfg(cfg, args, args_list):
    for aa in args_list:
        cfg['train'][aa] = getattr(args, aa)
    return cfg


class AverageMeter(object):
    """Computes and stores the average and current value of one or more meters."""

    def __init__(self, i=1, precision=3, names=None):
        self.meters = i
        self.precision = precision
        self.reset(self.meters)
        self.names = names
        if names is not None:
            assert self.meters == len(self.names)
        else:
            self.names = [''] * self.meters

    def reset(self, i):
        self.val = [0] * i
        self.avg = [0] * i
        self.sum = [0] * i
        self.count = [0] * i

    def update(self, val, n=1):
        if not isinstance(val, list):
            val = [val]
        if not isinstance(n, list):
            n = [n] * self.meters
        assert (len(val) == self.meters and len(n) == self.meters)
        for i in range(self.meters):
            self.count[i] += n[i]
        for i, v in enumerate(val):
            self.val[i] = v
            self.sum[i] += v * n[i]
            self.avg[i] = self.sum[i] / self.count[i]

    def __repr__(self):
        return ' '.join(['{} {:.{}f} ({:.{}f})'.format(n, v, self.precision, a, self.precision)
                         for n, v, a in zip(self.names, self.val, self.avg)])


def normalize_image_torch(image, percentile_lower=1, percentile_upper=99):
    """Rescales an image batch so the given percentiles map to [0, 1]."""
    b, c, h, w = image.shape
    image_reshape = image.reshape([b, c, h * w])
    mini = torch.quantile(image_reshape, percentile_lower / 100.0, dim=2, keepdim=True).unsqueeze_(dim=3)
    maxi = torch.quantile(image_reshape, percentile_upper / 100.0, dim=2, keepdim=True).unsqueeze_(dim=3)
    return torch.clip((image - mini) / (maxi - mini + 1e-5), 0, 1)


def normalize_image_torch2(image):
    return torch.clip(image, 0, 1)


# --------------------------------------------
# Torch to Numpy 0~255
# --------------------------------------------
def torch2numpy255(im):
    im = im[0, 0].detach().cpu().numpy()
    im = (im * 255).astype(np.float64)
    return im


def torch2torch255(im):
    return im * 255.0


class InputPadder:
    """Pads images such that their dimensions are divisible by padsize."""

    def __init__(self, dims, padsize=16):
        self.ht, self.wd = dims[-2:]
        pad_ht = (((self.ht // padsize) + 1) * padsize - self.ht) % padsize
        pad_wd = (((self.wd // padsize) + 1) * padsize - self.wd) % padsize
        self._pad = [pad_wd // 2, pad_wd - pad_wd // 2, 0, pad_ht]

    def pad(self, *inputs):
        return [F.pad(x, self._pad, mode='replicate') for x in inputs]

    def unpad(self, x):
        ht, wd = x.shape[-2:]
        c = [self._pad[2], ht - self._pad[3], self._pad[0], wd - self._pad[1]]
        return x[..., c[0]:c[1], c[2]:c[3]]


def vis_img(vis_path: str, img: torch.Tensor, vis_name: str = 'vis'):
    """Tiles the first channel of a batch into a 4-row image grid and saves it as a PNG."""
    ww = 0
    rows = []
    for ii in range(4):
        cur_row = []
        for jj in range(img.shape[0] // 4):
            cur_img = img[ww, 0].detach().cpu().numpy() * 255
            cur_img = cur_img.astype(np.uint8)
            cur_row.append(cur_img)
            ww += 1
        cur_row_cat = np.concatenate(cur_row, axis=1)
        rows.append(cur_row_cat)
    out_img = np.concatenate(rows, axis=0)
    cv2.imwrite(osp.join(vis_path, vis_name + '.png'), out_img)
    return
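set_seeds, AverageMeter, and InputPadder above are generic training helpers. A minimal usage sketch follows; it assumes the hunk is saved as an importable utils module, and the tensor sizes are hypothetical:

import torch
# assuming: from utils import set_seeds, AverageMeter, InputPadder

set_seeds(42)
x = torch.rand(1, 1, 250, 400)           # H, W deliberately not multiples of 16
padder = InputPadder(x.shape, padsize=16)
x_pad, = padder.pad(x)                   # pad() returns one padded tensor per input
assert x_pad.shape[-2] % 16 == 0 and x_pad.shape[-1] % 16 == 0
restored = padder.unpad(x_pad)           # crops back to the original 250x400

meter = AverageMeter(i=2, names=['psnr', 'ssim'])
meter.update([30.1, 0.91])
print(meter)                             # psnr 30.100 (30.100) ssim 0.910 (0.910)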
Binary file
Binary file
@@ -0,0 +1,224 @@
import torch
import math
import torch.nn as nn
import torch.nn.functional as F


class crop(nn.Module):
    """Drops the last row of the feature map (undoes a one-pixel downward shift)."""
    def __init__(self):
        super().__init__()

    def forward(self, x):
        N, C, H, W = x.shape
        x = x[0:N, 0:C, 0:H - 1, 0:W]
        return x


class shift(nn.Module):
    """Shifts the feature map down by one pixel (zero-pads the top, crops the bottom)."""
    def __init__(self):
        super().__init__()
        self.shift_down = nn.ZeroPad2d((0, 0, 1, 0))
        self.crop = crop()

    def forward(self, x):
        x = self.shift_down(x)
        x = self.crop(x)
        return x


class Conv(nn.Module):
    def __init__(self, in_channels, out_channels, bias=False, blind=True, stride=1, padding=0, kernel_size=3):
        super().__init__()
        self.blind = blind
        if blind:
            self.shift_down = nn.ZeroPad2d((0, 0, 1, 0))
            self.crop = crop()
        self.replicate = nn.ReplicationPad2d(1)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
        self.relu = nn.LeakyReLU(0.1, inplace=True)

    def forward(self, x):
        if self.blind:
            x = self.shift_down(x)
        x = self.replicate(x)
        x = self.conv(x)
        x = self.relu(x)
        if self.blind:
            x = self.crop(x)
        return x


class BasicConv(nn.Module):
    def __init__(self, in_planes, out_planes, kernel_size, stride=1, padding=0, dilation=1, groups=1, relu=False, bn=False, bias=True, blind=False):
        super(BasicConv, self).__init__()
        self.out_channels = out_planes
        self.conv = nn.Conv2d(in_planes, out_planes, kernel_size=kernel_size, stride=stride, padding=padding, dilation=dilation, groups=groups, bias=bias)
        self.bn = nn.BatchNorm2d(out_planes, eps=1e-5, momentum=0.01, affine=True) if bn else None
        self.relu = nn.ReLU() if relu else None

    def forward(self, x):
        x = self.conv(x)
        if self.bn is not None:
            x = self.bn(x)
        if self.relu is not None:
            x = self.relu(x)
        return x


class Flatten(nn.Module):
    def forward(self, x):
        return x.view(x.size(0), -1)


class ChannelGate(nn.Module):
    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max']):
        super(ChannelGate, self).__init__()
        self.gate_channels = gate_channels
        self.mlp = nn.Sequential(
            Flatten(),
            nn.Linear(gate_channels, gate_channels // reduction_ratio),
            nn.ReLU(),
            nn.Linear(gate_channels // reduction_ratio, gate_channels)
        )
        self.pool_types = pool_types

    def forward(self, x):
        channel_att_sum = None
        for pool_type in self.pool_types:
            if pool_type == 'avg':
                avg_pool = F.avg_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp(avg_pool)
            elif pool_type == 'max':
                max_pool = F.max_pool2d(x, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp(max_pool)
            elif pool_type == 'lp':
                lp_pool = F.lp_pool2d(x, 2, (x.size(2), x.size(3)), stride=(x.size(2), x.size(3)))
                channel_att_raw = self.mlp(lp_pool)
            elif pool_type == 'lse':
                # LSE pool only
                lse_pool = logsumexp_2d(x)
                channel_att_raw = self.mlp(lse_pool)

            if channel_att_sum is None:
                channel_att_sum = channel_att_raw
            else:
                channel_att_sum = channel_att_sum + channel_att_raw

        scale = torch.sigmoid(channel_att_sum).unsqueeze(2).unsqueeze(3).expand_as(x)
        return x * scale


def logsumexp_2d(tensor):
    tensor_flatten = tensor.view(tensor.size(0), tensor.size(1), -1)
    s, _ = torch.max(tensor_flatten, dim=2, keepdim=True)
    outputs = s + (tensor_flatten - s).exp().sum(dim=2, keepdim=True).log()
    return outputs


class ChannelPool(nn.Module):
    def forward(self, x):
        return torch.cat((torch.max(x, 1)[0].unsqueeze(1), torch.mean(x, 1).unsqueeze(1)), dim=1)


class SpatialGate(nn.Module):
    def __init__(self, bias=False, blind=False):
        super(SpatialGate, self).__init__()
        kernel_size = 7
        self.compress = ChannelPool()
        self.spatial = BasicConv(2, 1, kernel_size, stride=1, padding=(kernel_size - 1) // 2, relu=False, bias=bias, blind=False)

    def forward(self, x):
        x_compress = self.compress(x)
        x_out = self.spatial(x_compress)
        scale = torch.sigmoid(x_out)  # broadcasting
        return x * scale


class CBAM(nn.Module):
    def __init__(self, gate_channels, reduction_ratio=16, pool_types=['avg', 'max'], no_spatial=False):
        super(CBAM, self).__init__()
        self.ChannelGate = ChannelGate(gate_channels, reduction_ratio, pool_types)
        self.no_spatial = no_spatial
        if not no_spatial:
            self.SpatialGate = SpatialGate()

    def forward(self, x):
        x_out = self.ChannelGate(x)
        if not self.no_spatial:
            x_out = self.SpatialGate(x_out)
        return x_out


def weights_init_rcan(m):
    """
    custom weights initialization called on netG and netD
    https://github.com/pytorch/examples/blob/master/dcgan/main.py
    """
    classname = m.__class__.__name__
    if classname.find('Conv') != -1:
        if classname.find('BasicConv') != -1:
            m.conv.weight.data.normal_(0.0, 0.02)
            if m.bn is not None:
                m.bn.bias.data.fill_(0)
        else:
            m.weight.data.normal_(0.0, 0.02)
    elif classname.find('BatchNorm') != -1:
        m.weight.data.normal_(1.0, 0.02)
        m.bias.data.fill_(0)


class Temporal_Fusion(nn.Module):

    def __init__(self, nf=64, nframes=3, center=1, bias=False):
        super(Temporal_Fusion, self).__init__()
        self.center = center

        self.tAtt_1 = nn.Conv2d(nf, nf, 3, 1, 1, bias=bias)
        self.tAtt_2 = nn.Conv2d(nf, nf, 3, 1, 1, bias=bias)

        self.fea_fusion = nn.Conv2d(nframes * nf, nf, 1, 1, bias=bias)

        self.sAtt_1 = nn.Conv2d(nframes * nf, nf, 1, 1, bias=bias)
        self.maxpool = nn.MaxPool2d(3, stride=2, padding=1)
        self.avgpool = nn.AvgPool2d(3, stride=2, padding=1)
        self.sAtt_2 = nn.Conv2d(nf * 2, nf, 1, 1, bias=bias)
        self.sAtt_3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=bias)
        self.sAtt_4 = nn.Conv2d(nf, nf, 1, 1, bias=bias)
        self.sAtt_5 = nn.Conv2d(nf, nf, 3, 1, 1, bias=bias)
        self.sAtt_L1 = nn.Conv2d(nf, nf, 1, 1, bias=bias)
        self.sAtt_L2 = nn.Conv2d(nf * 2, nf, 3, 1, 1, bias=bias)
        self.sAtt_L3 = nn.Conv2d(nf, nf, 3, 1, 1, bias=bias)
        self.sAtt_add_1 = nn.Conv2d(nf, nf, 1, 1, bias=bias)
        self.sAtt_add_2 = nn.Conv2d(nf, nf, 1, 1, bias=bias)

        self.lrelu = nn.LeakyReLU(negative_slope=0.1, inplace=True)

    def forward(self, nonlocal_fea):
        B, N, C, H, W = nonlocal_fea.size()

        # Temporal attention: correlate each frame's embedding with the center frame.
        emb_ref = self.tAtt_2(nonlocal_fea[:, self.center, :, :, :].clone())
        emb = self.tAtt_1(nonlocal_fea.view(-1, C, H, W)).view(B, N, -1, H, W)

        cor_l = []
        for i in range(N):
            emb_nbr = emb[:, i, :, :, :]
            cor_tmp = torch.sum(emb_nbr * emb_ref, 1).unsqueeze(1)
            cor_l.append(cor_tmp)
        cor_prob = torch.sigmoid(torch.cat(cor_l, dim=1))
        cor_prob = cor_prob.unsqueeze(2).repeat(1, 1, C, 1, 1)
        cor_prob = cor_prob.view(B, -1, H, W)
        nonlocal_fea = nonlocal_fea.view(B, -1, H, W) * cor_prob

        fea = self.lrelu(self.fea_fusion(nonlocal_fea))

        # Pyramid spatial attention over the fused features.
        att = self.lrelu(self.sAtt_1(nonlocal_fea))
        att_max = self.maxpool(att)
        att_avg = self.avgpool(att)
        att = self.lrelu(self.sAtt_2(torch.cat([att_max, att_avg], dim=1)))

        att_L = self.lrelu(self.sAtt_L1(att))
        att_max = self.maxpool(att_L)
        att_avg = self.avgpool(att_L)
        att_L = self.lrelu(self.sAtt_L2(torch.cat([att_max, att_avg], dim=1)))
        att_L = self.lrelu(self.sAtt_L3(att_L))
        att_L = F.interpolate(att_L, scale_factor=2, mode='bilinear', align_corners=False)

        att = self.lrelu(self.sAtt_3(att))
        att = att + att_L
        att = self.lrelu(self.sAtt_4(att))
        att = F.interpolate(att, scale_factor=2, mode='bilinear', align_corners=False)
        att = self.sAtt_5(att)
        att_add = self.sAtt_add_2(self.lrelu(self.sAtt_add_1(att)))
        att = torch.sigmoid(att)

        fea = fea * att * 2 + att_add

        return fea
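ChannelGate squeezes each channel to a scalar by global pooling, pushes it through a shared MLP, and sigmoid-scales the channels; SpatialGate does the analogous thing along space with a 7x7 convolution over the max/mean channel projections. A minimal shape-preserving sketch (hypothetical tensor sizes, assuming the hunk is importable as cbam):

import torch
# assuming: from cbam import CBAM

cbam = CBAM(gate_channels=64, reduction_ratio=16)
feat = torch.rand(2, 64, 32, 32)   # B x C x H x W feature map
out = cbam(feat)                   # channel gate first, then spatial gate
assert out.shape == feat.shape     # attention rescales values, never shapes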
@@ -0,0 +1,290 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
import os
import sys

current_dir = os.path.dirname(os.path.abspath(__file__))
sys.path.append(current_dir)

from cbam import SpatialGate, ChannelGate, Temporal_Fusion


class crop(nn.Module):
    """Drops the last row of the feature map (undoes a one-pixel downward shift)."""
    def __init__(self):
        super().__init__()

    def forward(self, x):
        N, C, H, W = x.shape
        x = x[0:N, 0:C, 0:H - 1, 0:W]
        return x


class shift(nn.Module):
    """Shifts the feature map down by one pixel (zero-pads the top, crops the bottom)."""
    def __init__(self):
        super().__init__()
        self.shift_down = nn.ZeroPad2d((0, 0, 1, 0))
        self.crop = crop()

    def forward(self, x):
        x = self.shift_down(x)
        x = self.crop(x)
        return x


class Conv(nn.Module):
    def __init__(self, in_channels, out_channels, bias=False, blind=True, stride=1, padding=0, kernel_size=3):
        super().__init__()
        self.blind = blind
        if blind:
            self.shift_down = nn.ZeroPad2d((0, 0, 1, 0))
            self.crop = crop()
        self.replicate = nn.ReplicationPad2d(1)
        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size=kernel_size, stride=stride, padding=padding, bias=bias)
        self.relu = nn.LeakyReLU(0.1, inplace=True)

    def forward(self, x):
        if self.blind:
            x = self.shift_down(x)
        x = self.replicate(x)
        x = self.conv(x)
        x = self.relu(x)
        if self.blind:
            x = self.crop(x)
        return x


class Pool(nn.Module):
    def __init__(self, blind=True):
        super().__init__()
        self.blind = blind
        if blind:
            self.shift = shift()
        self.pool = nn.MaxPool2d(2)

    def forward(self, x):
        if self.blind:
            x = self.shift(x)
        x = self.pool(x)
        return x


class rotate(nn.Module):
    """Stacks the 0/90/180/270-degree rotations of the input along the batch dimension."""
    def __init__(self):
        super().__init__()

    def forward(self, x):
        x90 = x.transpose(2, 3).flip(3)
        x180 = x.flip(2).flip(3)
        x270 = x.transpose(2, 3).flip(2)
        x = torch.cat((x, x90, x180, x270), dim=0)
        return x


class unrotate(nn.Module):
    """Rotates the four batch groups back and concatenates them along channels."""
    def __init__(self):
        super().__init__()

    def forward(self, x):
        x0, x90, x180, x270 = torch.chunk(x, 4, dim=0)
        x90 = x90.transpose(2, 3).flip(2)
        x180 = x180.flip(2).flip(3)
        x270 = x270.transpose(2, 3).flip(3)
        x = torch.cat((x0, x90, x180, x270), dim=1)
        return x


class ENC_Conv(nn.Module):
    def __init__(self, in_channels, mid_channels, out_channels, bias=False, reduce=True, blind=True):
        super().__init__()
        self.reduce = reduce
        self.conv1 = Conv(in_channels, mid_channels, bias=bias, blind=blind)
        self.conv2 = Conv(mid_channels, mid_channels, bias=bias, blind=blind)
        self.conv3 = Conv(mid_channels, out_channels, bias=bias, blind=blind)
        if reduce:
            self.pool = Pool(blind=blind)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        if self.reduce:
            x = self.pool(x)
        return x


class DEC_Conv(nn.Module):
    def __init__(self, in_channels, mid_channels, out_channels, bias=False, blind=True):
        super().__init__()
        self.upsample = nn.Upsample(scale_factor=2, mode='nearest')
        self.conv1 = Conv(in_channels, mid_channels, bias=bias, blind=blind)
        self.conv2 = Conv(mid_channels, mid_channels, bias=bias, blind=blind)
        self.conv3 = Conv(mid_channels, mid_channels, bias=bias, blind=blind)
        self.conv4 = Conv(mid_channels, out_channels, bias=bias, blind=blind)

    def forward(self, x, x_in):
        x = self.upsample(x)

        # Smart padding: match the skip connection's spatial size exactly.
        diffY = x_in.size()[2] - x.size()[2]
        diffX = x_in.size()[3] - x.size()[3]
        x = F.pad(x, [diffX // 2, diffX - diffX // 2,
                      diffY // 2, diffY - diffY // 2])

        x = torch.cat((x, x_in), dim=1)
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        return x


class Blind_UNet(nn.Module):
    def __init__(self, n_channels=3, n_output=96, bias=False, blind=True):
        super().__init__()
        self.n_channels = n_channels
        self.bias = bias
        self.enc1 = ENC_Conv(n_channels, 48, 48, bias=bias, blind=blind)
        self.enc2 = ENC_Conv(48, 48, 48, bias=bias, blind=blind)
        self.enc3 = ENC_Conv(48, 96, 48, bias=bias, reduce=False, blind=blind)
        self.dec2 = DEC_Conv(96, 96, 96, bias=bias, blind=blind)
        self.dec1 = DEC_Conv(96 + n_channels, 96, n_output, bias=bias, blind=blind)

    def forward(self, input):
        x1 = self.enc1(input)
        x2 = self.enc2(x1)
        x = self.enc3(x2)
        x = self.dec2(x, x1)
        x = self.dec1(x, input)
        return x


def middleTFI(spike, middle, window=50):
    # For every pixel, search left and right of the middle frame for the
    # nearest spike (a 1); the reciprocal of the resulting inter-spike
    # interval is the intensity estimate. The loop covers `window` frames to
    # the left including the middle frame itself, and `window` frames to the
    # right excluding it.
    spike = spike.squeeze(1).numpy()
    C, H, W = spike.shape
    lindex, rindex = np.zeros([H, W]), np.zeros([H, W])
    l, r = middle + 1, middle + 1
    for r in range(middle + 1, middle + window + 1):
        l = l - 1
        if l >= 0:
            newpos = spike[l, :, :] * (1 - np.sign(lindex))
            distance = l * newpos
            lindex += distance
        if r < C:
            newpos = spike[r, :, :] * (1 - np.sign(rindex))
            distance = r * newpos
            rindex += distance

    rindex[rindex == 0] = window + middle
    lindex[lindex == 0] = middle - window
    interval = rindex - lindex
    tfi = 1.0 / interval

    return tfi


class MotionInference(nn.Module):
    def __init__(self, n_frame=41, bias=False, blind=False):
        super().__init__()
        self.middle = n_frame // 2
        self.conv0 = nn.Conv2d(5 * 2 + 1, 1, 1, bias=bias)
        self.conv1 = nn.Conv2d(9 * 2 + 1, 1, 1, bias=bias)
        self.conv2 = nn.Conv2d(13 * 2 + 1, 1, 1, bias=bias)
        self.tfpconv = Conv(in_channels=3, out_channels=16, bias=bias, blind=blind)
        self.tficonv = Conv(in_channels=1, out_channels=16, bias=bias, blind=blind)
        self.ChannelGate = ChannelGate(gate_channels=16, reduction_ratio=4)
        self.SpatialGate = SpatialGate(bias=bias, blind=blind)
        self.blind = blind

    def forward(self, x):
        N, C, H, W = x.shape
        tmp = []
        ttt = []
        for j in range(N):
            tmp2 = middleTFI(x[j].cpu(), self.middle, window=12)
            tmp2 = torch.tensor(tmp2, dtype=torch.float32).unsqueeze_(dim=0)
            tmp.append(tmp2)  # 1 x H x W
            ttt5 = torch.mean(x[j, self.middle - 3:self.middle + 3 + 1, :, :].cpu(), dim=0).unsqueeze_(0)
            ttt.append(ttt5)
        tfi_label = torch.stack(tmp, 0).cuda()
        tfp_label = torch.stack(ttt, 0).cuda()

        # Windowed firing-rate features at three temporal scales.
        tfp0 = self.conv0(x[:, self.middle - 5:self.middle + 5 + 1, :, :])  # b 1 h w
        tfp1 = self.conv1(x[:, self.middle - 9:self.middle + 9 + 1, :, :])
        tfp2 = self.conv2(x[:, self.middle - 13:self.middle + 13 + 1, :, :])
        tfps = torch.cat([tfp0, tfp1, tfp2], dim=1)  # b 3 h w

        tfp_fea = self.tfpconv(tfps)
        tfi_fea = self.tficonv(tfi_label)

        if not self.blind:
            tfp_fea = self.SpatialGate(tfp_fea)  # b 16 h w
            tfi_fea = self.SpatialGate(tfi_fea)
            fusion_fea = self.ChannelGate(tfp_fea + tfi_fea)  # b 16 h w
        else:
            fusion_fea = tfp_fea + tfi_fea
        return fusion_fea, tfi_label, tfp_label


class BSN(nn.Module):
    def __init__(self, n_channels=3, n_output=3, bias=False, blind=True, sigma_known=True):
        super().__init__()
        self.n_channels = n_channels
        self.c = n_channels
        self.n_output = n_output
        self.bias = bias
        self.blind = blind
        self.sigma_known = sigma_known
        self.rotate = rotate()
        self.unet = Blind_UNet(n_channels=n_channels + 16, bias=bias, blind=blind)
        self.shift = shift()
        self.unrotate = unrotate()
        self.nin_A = nn.Conv2d(384, 384, 1, bias=bias)
        self.nin_B = nn.Conv2d(384, 96, 1, bias=bias)
        self.nin_C = nn.Conv2d(96, n_output, 1, bias=bias)
        self.MotionInference = MotionInference(n_frame=41, bias=bias, blind=blind)

    def forward(self, x):
        N, C, H, W = x.shape
        _, tfi_label, tfp_label = self.MotionInference(x)

        # Square the input by reflection padding so it can be rotated.
        if H > W:
            diff = H - W
            x = F.pad(x, [diff // 2, diff - diff // 2, 0, 0], mode='reflect')
        elif W > H:
            diff = W - H
            x = F.pad(x, [0, 0, diff // 2, diff - diff // 2], mode='reflect')

        x = self.rotate(x)

        fea1, _, _ = self.MotionInference(x)
        x = torch.cat([x, fea1], 1)

        x = self.unet(x)  # 4 3 100 100 -> 4 96 100 100
        if self.blind:
            x = self.shift(x)
        x = self.unrotate(x)  # 4 96 100 100 -> 1 384 100 100

        x0 = F.leaky_relu_(self.nin_A(x), negative_slope=0.1)
        x0 = F.leaky_relu_(self.nin_B(x0), negative_slope=0.1)
        x0 = self.nin_C(x0)

        # Unsquare
        if H > W:
            diff = H - W
            x0 = x0[:, :, 0:H, (diff // 2):(diff // 2 + W)]
        elif W > H:
            diff = W - H
            x0 = x0[:, :, (diff // 2):(diff // 2 + H), 0:W]

        return x0, tfi_label, tfp_label


class DoubleNet(nn.Module):
    def __init__(self):
        super().__init__()
        self.nbsn = BSN(n_channels=41, n_output=1, blind=False)
        # self.bsn = BSN(n_channels=41, n_output=1, blind=True)

    def forward(self, x):
        out1, _, _ = self.nbsn(x)
        return out1


if __name__ == '__main__':
    a = DoubleNet().cuda()
    print(a(torch.ones(2, 41, 40, 40).cuda()))
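The blind-spot behaviour above hinges on the shift/crop pair: zero-padding the top row and cropping the bottom moves every feature one pixel down, so stacked convolutions only ever see pixels strictly above the output location, and rotate/unrotate assembles four such half-plane views per pixel. A tiny CPU-only sketch of the primitive (hypothetical input, assuming the hunk is importable as model):

import torch
# assuming: from model import shift

s = shift()
x = torch.arange(16.).reshape(1, 1, 4, 4)
y = s(x)
# Row i of the output is row i-1 of the input; row 0 becomes zeros.
assert torch.equal(y[0, 0, 1:], x[0, 0, :-1])
assert torch.equal(y[0, 0, 0], torch.zeros(4))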
Binary file