spikezoo 0.1.2__py3-none-any.whl → 0.2__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- spikezoo/__init__.py +13 -0
- spikezoo/archs/__pycache__/__init__.cpython-39.pyc +0 -0
- spikezoo/archs/base/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/base/nets.py +34 -0
- spikezoo/archs/bsf/README.md +92 -0
- spikezoo/archs/bsf/datasets/datasets.py +328 -0
- spikezoo/archs/bsf/datasets/ds_utils.py +64 -0
- spikezoo/archs/bsf/main.py +398 -0
- spikezoo/archs/bsf/metrics/psnr.py +22 -0
- spikezoo/archs/bsf/metrics/ssim.py +54 -0
- spikezoo/archs/bsf/models/bsf/__pycache__/align.cpython-39.pyc +0 -0
- spikezoo/archs/bsf/models/bsf/__pycache__/bsf.cpython-39.pyc +0 -0
- spikezoo/archs/bsf/models/bsf/__pycache__/rep.cpython-39.pyc +0 -0
- spikezoo/archs/bsf/models/bsf/align.py +154 -0
- spikezoo/archs/bsf/models/bsf/bsf.py +105 -0
- spikezoo/archs/bsf/models/bsf/dsft_convert.py +96 -0
- spikezoo/archs/bsf/models/bsf/rep.py +44 -0
- spikezoo/archs/bsf/models/get_model.py +7 -0
- spikezoo/archs/bsf/prepare_data/DSFT.py +62 -0
- spikezoo/archs/bsf/prepare_data/crop_dataset_train.py +135 -0
- spikezoo/archs/bsf/prepare_data/crop_dataset_val.py +139 -0
- spikezoo/archs/bsf/prepare_data/crop_train.sh +4 -0
- spikezoo/archs/bsf/prepare_data/crop_val.sh +4 -0
- spikezoo/archs/bsf/prepare_data/io_utils.py +64 -0
- spikezoo/archs/bsf/requirements.txt +9 -0
- spikezoo/archs/bsf/test.py +16 -0
- spikezoo/archs/bsf/utils.py +154 -0
- spikezoo/archs/spikeclip/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/spikeclip/nets.py +40 -0
- spikezoo/archs/spikeformer/CheckPoints/readme +1 -0
- spikezoo/archs/spikeformer/DataProcess/DataExtactor.py +60 -0
- spikezoo/archs/spikeformer/DataProcess/DataLoader.py +115 -0
- spikezoo/archs/spikeformer/DataProcess/LoadSpike.py +39 -0
- spikezoo/archs/spikeformer/EvalResults/readme +1 -0
- spikezoo/archs/spikeformer/LICENSE +21 -0
- spikezoo/archs/spikeformer/Metrics/Metrics.py +50 -0
- spikezoo/archs/spikeformer/Metrics/__init__.py +0 -0
- spikezoo/archs/spikeformer/Model/Loss.py +89 -0
- spikezoo/archs/spikeformer/Model/SpikeFormer.py +230 -0
- spikezoo/archs/spikeformer/Model/__init__.py +0 -0
- spikezoo/archs/spikeformer/Model/__pycache__/SpikeFormer.cpython-39.pyc +0 -0
- spikezoo/archs/spikeformer/Model/__pycache__/__init__.cpython-39.pyc +0 -0
- spikezoo/archs/spikeformer/README.md +30 -0
- spikezoo/archs/spikeformer/evaluate.py +87 -0
- spikezoo/archs/spikeformer/recon_real_data.py +97 -0
- spikezoo/archs/spikeformer/requirements.yml +95 -0
- spikezoo/archs/spikeformer/train.py +173 -0
- spikezoo/archs/spikeformer/utils.py +22 -0
- spikezoo/archs/spk2imgnet/.github/workflows/pylint.yml +23 -0
- spikezoo/archs/spk2imgnet/.gitignore +150 -0
- spikezoo/archs/spk2imgnet/DCNv2.py +135 -0
- spikezoo/archs/spk2imgnet/__pycache__/DCNv2.cpython-39.pyc +0 -0
- spikezoo/archs/spk2imgnet/__pycache__/align_arch.cpython-39.pyc +0 -0
- spikezoo/archs/spk2imgnet/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/spk2imgnet/align_arch.py +159 -0
- spikezoo/archs/spk2imgnet/dataset.py +144 -0
- spikezoo/archs/spk2imgnet/nets.py +230 -0
- spikezoo/archs/spk2imgnet/readme.md +86 -0
- spikezoo/archs/spk2imgnet/test_gen_imgseq.py +118 -0
- spikezoo/archs/spk2imgnet/train.py +189 -0
- spikezoo/archs/spk2imgnet/utils.py +64 -0
- spikezoo/archs/ssir/README.md +87 -0
- spikezoo/archs/ssir/configs/SSIR.yml +37 -0
- spikezoo/archs/ssir/configs/yml_parser.py +78 -0
- spikezoo/archs/ssir/datasets/dataset_sreds.py +170 -0
- spikezoo/archs/ssir/datasets/ds_utils.py +66 -0
- spikezoo/archs/ssir/losses.py +21 -0
- spikezoo/archs/ssir/main.py +326 -0
- spikezoo/archs/ssir/metrics/psnr.py +22 -0
- spikezoo/archs/ssir/metrics/ssim.py +54 -0
- spikezoo/archs/ssir/models/Vgg19.py +42 -0
- spikezoo/archs/ssir/models/__pycache__/layers.cpython-39.pyc +0 -0
- spikezoo/archs/ssir/models/__pycache__/networks.cpython-39.pyc +0 -0
- spikezoo/archs/ssir/models/layers.py +110 -0
- spikezoo/archs/ssir/models/networks.py +61 -0
- spikezoo/archs/ssir/requirements.txt +8 -0
- spikezoo/archs/ssir/shells/eval_SREDS.sh +6 -0
- spikezoo/archs/ssir/shells/train_SSIR.sh +12 -0
- spikezoo/archs/ssir/test.py +3 -0
- spikezoo/archs/ssir/utils.py +154 -0
- spikezoo/archs/ssml/__pycache__/cbam.cpython-39.pyc +0 -0
- spikezoo/archs/ssml/__pycache__/model.cpython-39.pyc +0 -0
- spikezoo/archs/ssml/cbam.py +224 -0
- spikezoo/archs/ssml/model.py +290 -0
- spikezoo/archs/ssml/res.png +0 -0
- spikezoo/archs/ssml/test.py +67 -0
- spikezoo/archs/stir/.git-credentials +0 -0
- spikezoo/archs/stir/README.md +65 -0
- spikezoo/archs/stir/ckpt_outputs/Descriptions.txt +1 -0
- spikezoo/archs/stir/configs/STIR.yml +37 -0
- spikezoo/archs/stir/configs/utils.py +155 -0
- spikezoo/archs/stir/configs/yml_parser.py +78 -0
- spikezoo/archs/stir/datasets/dataset_sreds.py +180 -0
- spikezoo/archs/stir/datasets/ds_utils.py +66 -0
- spikezoo/archs/stir/eval_SREDS.sh +5 -0
- spikezoo/archs/stir/main.py +397 -0
- spikezoo/archs/stir/metrics/losses.py +219 -0
- spikezoo/archs/stir/metrics/psnr.py +22 -0
- spikezoo/archs/stir/metrics/ssim.py +54 -0
- spikezoo/archs/stir/models/Vgg19.py +42 -0
- spikezoo/archs/stir/models/__pycache__/networks_STIR.cpython-39.pyc +0 -0
- spikezoo/archs/stir/models/__pycache__/submodules.cpython-39.pyc +0 -0
- spikezoo/archs/stir/models/__pycache__/transformer_new.cpython-39.pyc +0 -0
- spikezoo/archs/stir/models/networks_STIR.py +361 -0
- spikezoo/archs/stir/models/submodules.py +86 -0
- spikezoo/archs/stir/models/transformer_new.py +151 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/__init__.py +0 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/convertions.py +721 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/disp_netS.py +133 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/flow_utils.py +167 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/generic_train_test.py +76 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/geometry.py +458 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/image_proc.py +183 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/linalg.py +40 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/losses.py +198 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/metrics.py +51 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/model_base.py +53 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/net_basics.py +100 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/resnet.py +333 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/transforms.py +123 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/utils.py +72 -0
- spikezoo/archs/stir/package_core/dist/package_core-0.0.0-py3.9.egg +0 -0
- spikezoo/archs/stir/package_core/package_core/__init__.py +0 -0
- spikezoo/archs/stir/package_core/package_core/__pycache__/__init__.cpython-39.pyc +0 -0
- spikezoo/archs/stir/package_core/package_core/__pycache__/net_basics.cpython-39.pyc +0 -0
- spikezoo/archs/stir/package_core/package_core/convertions.py +721 -0
- spikezoo/archs/stir/package_core/package_core/disp_netS.py +133 -0
- spikezoo/archs/stir/package_core/package_core/flow_utils.py +167 -0
- spikezoo/archs/stir/package_core/package_core/generic_train_test.py +76 -0
- spikezoo/archs/stir/package_core/package_core/geometry.py +458 -0
- spikezoo/archs/stir/package_core/package_core/image_proc.py +183 -0
- spikezoo/archs/stir/package_core/package_core/linalg.py +40 -0
- spikezoo/archs/stir/package_core/package_core/losses.py +198 -0
- spikezoo/archs/stir/package_core/package_core/metrics.py +51 -0
- spikezoo/archs/stir/package_core/package_core/model_base.py +53 -0
- spikezoo/archs/stir/package_core/package_core/net_basics.py +100 -0
- spikezoo/archs/stir/package_core/package_core/resnet.py +333 -0
- spikezoo/archs/stir/package_core/package_core/transforms.py +123 -0
- spikezoo/archs/stir/package_core/package_core/utils.py +72 -0
- spikezoo/archs/stir/package_core/package_core.egg-info/PKG-INFO +3 -0
- spikezoo/archs/stir/package_core/package_core.egg-info/SOURCES.txt +20 -0
- spikezoo/archs/stir/package_core/package_core.egg-info/dependency_links.txt +1 -0
- spikezoo/archs/stir/package_core/package_core.egg-info/top_level.txt +1 -0
- spikezoo/archs/stir/package_core/setup.py +5 -0
- spikezoo/archs/stir/requirements.txt +12 -0
- spikezoo/archs/stir/train_STIR.sh +9 -0
- spikezoo/archs/tfi/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/tfi/nets.py +43 -0
- spikezoo/archs/tfp/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/tfp/nets.py +13 -0
- spikezoo/archs/wgse/README.md +64 -0
- spikezoo/archs/wgse/__pycache__/dwtnets.cpython-39.pyc +0 -0
- spikezoo/archs/wgse/__pycache__/submodules.cpython-39.pyc +0 -0
- spikezoo/archs/wgse/dataset.py +59 -0
- spikezoo/archs/wgse/demo.png +0 -0
- spikezoo/archs/wgse/demo.py +83 -0
- spikezoo/archs/wgse/dwtnets.py +145 -0
- spikezoo/archs/wgse/eval.py +133 -0
- spikezoo/archs/wgse/logs/WGSE-Dwt1dNet-db8-5-ks3/log.txt +11 -0
- spikezoo/archs/wgse/submodules.py +68 -0
- spikezoo/archs/wgse/train.py +261 -0
- spikezoo/archs/wgse/transform.py +139 -0
- spikezoo/archs/wgse/utils.py +128 -0
- spikezoo/archs/wgse/weights/demo.png +0 -0
- spikezoo/data/base/test/gt/200_part1_key_id151.png +0 -0
- spikezoo/data/base/test/gt/200_part3_key_id151.png +0 -0
- spikezoo/data/base/test/gt/203_part1_key_id151.png +0 -0
- spikezoo/data/base/test/spike/200_part1_key_id151.dat +0 -0
- spikezoo/data/base/test/spike/200_part3_key_id151.dat +0 -0
- spikezoo/data/base/test/spike/203_part1_key_id151.dat +0 -0
- spikezoo/data/base/train/gt/203_part2_key_id151.png +0 -0
- spikezoo/data/base/train/gt/203_part3_key_id151.png +0 -0
- spikezoo/data/base/train/gt/203_part4_key_id151.png +0 -0
- spikezoo/data/base/train/spike/203_part2_key_id151.dat +0 -0
- spikezoo/data/base/train/spike/203_part3_key_id151.dat +0 -0
- spikezoo/data/base/train/spike/203_part4_key_id151.dat +0 -0
- spikezoo/datasets/base_dataset.py +2 -3
- spikezoo/metrics/__init__.py +1 -1
- spikezoo/models/base_model.py +1 -3
- spikezoo/pipeline/base_pipeline.py +7 -5
- spikezoo/pipeline/train_pipeline.py +1 -1
- spikezoo/utils/other_utils.py +16 -6
- spikezoo/utils/spike_utils.py +33 -29
- spikezoo/utils/vidar_loader.cpython-39-x86_64-linux-gnu.so +0 -0
- spikezoo-0.2.dist-info/METADATA +163 -0
- spikezoo-0.2.dist-info/RECORD +211 -0
- spikezoo/models/spcsnet_model.py +0 -19
- spikezoo-0.1.2.dist-info/METADATA +0 -39
- spikezoo-0.1.2.dist-info/RECORD +0 -36
- {spikezoo-0.1.2.dist-info → spikezoo-0.2.dist-info}/LICENSE.txt +0 -0
- {spikezoo-0.1.2.dist-info → spikezoo-0.2.dist-info}/WHEEL +0 -0
- {spikezoo-0.1.2.dist-info → spikezoo-0.2.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,16 @@
|
|
1
|
+
import numpy as np
|
2
|
+
import torch
|
3
|
+
# import cupy as cp
|
4
|
+
|
5
|
+
def compute_dsft_core(spike):
    """Compute the DSFT (spike-interval) map of a spike stream.

    For every pixel and time step ``t`` the result is ``r - l``, clipped at 0:
    ``l`` is the largest firing time at or before ``t`` (0 if none) and ``r``
    is the smallest firing time strictly after ``t`` (``T`` if none).

    Args:
        spike: tensor of shape ``(H, W, T)``; assumed binary (0/1) — TODO
            confirm. All work happens on ``spike.device`` (this used to be
            hard-coded to CUDA, which failed for CPU tensors).

    Returns:
        Tensor of shape ``(H, W, T)`` holding the interval at each step.
    """
    H, W, T = spike.shape
    device = spike.device
    # Firing time of every spike; steps without a spike become 0.
    time = spike * torch.arange(T, device=device).reshape(1, 1, T)
    # Running maximum = most recent firing time (left edge of the interval).
    l_idx, _ = time.cummax(dim=2)
    # Mark "no spike" as firing at T so it never wins the suffix minimum.
    time[time == 0] = T
    r_idx, _ = torch.flip(time, [2]).cummin(dim=2)
    r_idx = torch.flip(r_idx, [2])
    # Shift left by one so r_idx[t] only sees firing times strictly after t.
    r_idx = torch.cat([r_idx[:, :, 1:], torch.ones([H, W, 1], device=device) * T], dim=2)
    res = r_idx - l_idx

    res = torch.clip(res, 0)
    return res
|
@@ -0,0 +1,154 @@
|
|
1
|
+
import numpy as np
|
2
|
+
import torch
|
3
|
+
import torch.nn.functional as F
|
4
|
+
import os
|
5
|
+
import os.path as osp
|
6
|
+
import random
|
7
|
+
import cv2
|
8
|
+
|
9
|
+
def set_seeds(_seed_):
    """Seed the Python, NumPy and PyTorch (CPU + all CUDA devices) RNGs.

    cuDNN is also switched to deterministic kernel selection.

    Args:
        _seed_: integer seed passed to every generator.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed, torch.cuda.manual_seed_all):
        seed_fn(_seed_)

    # Give up cuDNN autotuning in exchange for reproducible kernel choice.
    torch.backends.cudnn.benchmark = False
    torch.backends.cudnn.deterministic = True
    # Fully deterministic cuBLAS would additionally need
    # os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8' (or ':16:8', slower)
    # and torch.use_deterministic_algorithms(True).
|
20
|
+
|
21
|
+
|
22
|
+
def make_dir(path):
    """Create directory *path*, including missing parents, if it is absent.

    ``exist_ok=True`` replaces the previous separate existence check. That
    check raced when two processes created the same directory at once.
    Raises ``FileExistsError`` if *path* exists but is not a directory.
    """
    os.makedirs(path, exist_ok=True)
    return
|
26
|
+
|
27
|
+
|
28
|
+
def add_args_to_cfg(cfg, args, args_list):
    """Copy selected attributes of *args* into ``cfg['train']``.

    Args:
        cfg: dict with a ``'train'`` sub-dict; modified in place.
        args: object holding the values (e.g. an ``argparse.Namespace``).
        args_list: iterable of attribute names to copy.

    Returns:
        The same *cfg* object.
    """
    for name in args_list:
        # getattr performs the same attribute lookup as the former
        # eval('args.<name>') without executing arbitrary source text.
        cfg['train'][name] = getattr(args, name)
    return cfg
|
32
|
+
|
33
|
+
|
34
|
+
# class AverageMeter(object):
|
35
|
+
# """Computes and stores the average and current value"""
|
36
|
+
# def __init__(self, precision=3):
|
37
|
+
# self.precision = precision
|
38
|
+
# self.reset()
|
39
|
+
|
40
|
+
# def reset(self):
|
41
|
+
# self.val = 0
|
42
|
+
# self.avg = 0
|
43
|
+
# self.sum = 0
|
44
|
+
# self.count = 0
|
45
|
+
|
46
|
+
# def update(self, val, n=1):
|
47
|
+
# self.val = val
|
48
|
+
# self.sum += val * n
|
49
|
+
# self.count += n
|
50
|
+
# self.avg = self.sum / self.count
|
51
|
+
|
52
|
+
# def __repr__(self):
|
53
|
+
# return '{:.{}f} ({:.{}f})'.format(self.val, self.precision, self.avg, self.precision)
|
54
|
+
|
55
|
+
|
56
|
+
class AverageMeter(object):
    """Track the latest value and weighted running average of N metrics.

    Args:
        i: number of metrics tracked side by side.
        precision: decimals shown by ``__repr__``.
        names: optional labels, one per metric.
    """

    def __init__(self, i=1, precision=3, names=None):
        self.meters = i
        self.precision = precision
        self.reset(self.meters)
        self.names = names
        if names is not None:
            assert self.meters == len(self.names)
        else:
            self.names = [''] * self.meters

    def reset(self, i=None):
        """Zero every statistic; *i* defaults to the configured meter count."""
        if i is None:
            i = self.meters
        self.val = [0] * i
        self.avg = [0] * i
        self.sum = [0] * i
        self.count = [0] * i

    def update(self, val, n=1):
        """Record value(s) *val* with weight(s) *n*.

        Each argument may be a scalar or a list with one entry per meter; a
        scalar *n* is applied to every meter.
        """
        if not isinstance(val, list):
            val = [val]
        if not isinstance(n, list):
            n = [n] * self.meters
        assert (len(val) == self.meters and len(n) == self.meters)
        for i in range(self.meters):
            self.count[i] += n[i]
        for i, v in enumerate(val):
            self.val[i] = v
            self.sum[i] += v * n[i]
            # A zero total weight (e.g. a first update with n=0) used to
            # raise ZeroDivisionError.
            self.avg[i] = self.sum[i] / self.count[i] if self.count[i] else 0

    def __repr__(self):
        # Format: "<name> <latest> (<average>)" per meter, space separated.
        return ' '.join(
            '{} {:.{}f} ({:.{}f})'.format(name, v, self.precision, a, self.precision)
            for name, v, a in zip(self.names, self.val, self.avg)
        )
|
97
|
+
|
98
|
+
|
99
|
+
def normalize_image_torch(image, percentile_lower=1, percentile_upper=99):
    """Rescale each (batch, channel) plane to [0, 1] using percentiles.

    Values at or below the ``percentile_lower``-th percentile map to 0.
    Values at or above the ``percentile_upper``-th percentile map to 1.
    The percentiles are computed over all ``H * W`` pixels of each plane.

    Args:
        image: float tensor of shape ``(B, C, H, W)``.
        percentile_lower: lower percentile in [0, 100]; default 1.
        percentile_upper: upper percentile in [0, 100]; default 99.

    Returns:
        Tensor with the shape of *image*, clipped to [0, 1].
    """
    b, c, h, w = image.shape
    image_reshape = image.reshape([b, c, h * w])
    # The percentile arguments used to be ignored (0.01 / 0.99 were
    # hard-coded); the defaults reproduce those values exactly.
    mini = torch.quantile(image_reshape, percentile_lower / 100.0, dim=2, keepdim=True).unsqueeze_(dim=3)
    maxi = torch.quantile(image_reshape, percentile_upper / 100.0, dim=2, keepdim=True).unsqueeze_(dim=3)
    # Epsilon keeps flat planes (maxi == mini) from dividing by zero.
    return torch.clip((image - mini) / (maxi - mini + 1e-5), 0, 1)
|
107
|
+
|
108
|
+
def normalize_image_torch2(image):
    """Clamp every value of *image* into the [0, 1] range."""
    lower, upper = 0, 1
    return image.clamp(min=lower, max=upper)
|
110
|
+
|
111
|
+
# --------------------------------------------
|
112
|
+
# Torch to Numpy 0~255
|
113
|
+
# --------------------------------------------
|
114
|
+
def torch2numpy255(im):
    """Convert ``im[0, 0]`` of a tensor into an 8-bit NumPy image.

    Args:
        im: tensor indexed as ``im[0, 0]`` (e.g. ``(B, C, H, W)``) with values
            expected in [0, 1].

    Returns:
        ``np.uint8`` array of ``im[0, 0] * 255``, truncated toward zero.
        Values outside the range are clipped to 0 / 255.
    """
    im = im[0, 0].detach().cpu().numpy()
    # Clip before casting: out-of-range floats would otherwise produce
    # wrapped/garbage uint8 values.
    im = np.clip(im * 255, 0, 255).astype(np.uint8)
    return im
|
118
|
+
|
119
|
+
def torch2torch255(im):
    """Scale a [0, 1] tensor to [0, 255]; no clipping or dtype cast is applied."""
    scale = 255.0
    return im * scale
|
121
|
+
|
122
|
+
class InputPadder:
    """Pad images so their height and width are divisible by ``padsize``.

    Width padding is split between left and right, with the odd column on
    the right. Height padding goes entirely at the bottom. ``unpad`` crops
    a padded tensor back to the original size.
    """

    def __init__(self, dims, padsize=16):
        self.ht, self.wd = dims[-2:]
        # (-n) % p is the smallest non-negative amount that makes n divisible by p.
        extra_h = -self.ht % padsize
        extra_w = -self.wd % padsize
        left = extra_w // 2
        self._pad = [left, extra_w - left, 0, extra_h]

    def pad(self, *inputs):
        """Replicate-pad every tensor in *inputs*; returns a list."""
        padded = []
        for tensor in inputs:
            padded.append(F.pad(tensor, self._pad, mode='replicate'))
        return padded

    def unpad(self, x):
        """Crop the padding added by ``pad`` from the last two dimensions."""
        left, right, top, bottom = self._pad
        height, width = x.shape[-2:]
        return x[..., top:height - bottom, left:width - right]
|
137
|
+
|
138
|
+
|
139
|
+
|
140
|
+
def vis_img(vis_path: str, img: torch.Tensor, vis_name: str = 'vis'):
    """Save channel 0 of a batch as a 4-row PNG mosaic ``<vis_path>/<vis_name>.png``.

    Images are laid out row-major, ``img.shape[0] // 4`` per row; any
    remainder beyond ``4 * (img.shape[0] // 4)`` is not drawn. Values are
    multiplied by 255 and cast to uint8 without clipping, so they are
    presumably in [0, 1]. Fewer than 4 images makes ``np.concatenate``
    raise, because each row would be empty.
    """
    per_row = img.shape[0] // 4
    grid_rows = []
    for r in range(4):
        tiles = [
            (img[r * per_row + c, 0].detach().cpu().numpy() * 255).astype(np.uint8)
            for c in range(per_row)
        ]
        grid_rows.append(np.concatenate(tiles, axis=1))
    mosaic = np.concatenate(grid_rows, axis=0)
    cv2.imwrite(osp.join(vis_path, vis_name + '.png'), mosaic)
    return
|
Binary file
|
@@ -0,0 +1,40 @@
|
|
1
|
+
import torch.nn as nn
|
2
|
+
import torch
|
3
|
+
|
4
|
+
def conv_layer(inDim, outDim, ks, s, p, norm_layer='none'):
    """Build a ``Conv2d -> [norm] -> ReLU`` block.

    Args:
        inDim: input channel count.
        outDim: output channel count.
        ks: convolution kernel size.
        s: convolution stride.
        p: convolution padding.
        norm_layer: ``'none'``, ``'batch'`` (BatchNorm2d, momentum 0.1, affine,
            running stats) or ``'instance'`` (InstanceNorm2d, no affine, no
            running stats).

    Returns:
        ``nn.Sequential`` with the layers in that order.
    """
    conv = nn.Conv2d(inDim, outDim, kernel_size=ks, stride=s, padding=p)
    relu = nn.ReLU(True)
    assert norm_layer in ('batch', 'instance', 'none')
    layers = [conv]
    if norm_layer == 'instance':
        layers.append(nn.InstanceNorm2d(outDim, affine=False, track_running_stats=False))
    elif norm_layer == 'batch':
        layers.append(nn.BatchNorm2d(outDim, momentum=0.1, affine=True, track_running_stats=True))
    layers.append(relu)
    return nn.Sequential(*layers)
|
19
|
+
|
20
|
+
def LRN(inDim=50, outDim=1, norm='none'):
    """Build the small convolutional reconstruction network.

    Channel path: ``inDim -> 64 -> 128 -> 64 -> 16 -> outDim``. Every layer is
    a 3x3 conv with stride 1 and padding 1. The first block has no
    normalization, the next three use *norm*, and the final conv has no ReLU.
    """
    blocks = [conv_layer(inDim, 64, 3, 1, 1)]
    for c_in, c_out in ((64, 128), (128, 64), (64, 16)):
        blocks.append(conv_layer(c_in, c_out, 3, 1, 1, norm))
    blocks.append(nn.Conv2d(16, outDim, 3, 1, 1))
    return nn.Sequential(*blocks)
|
28
|
+
|
29
|
+
|
30
|
+
if __name__ == "__main__":
    # thop is only needed for this profiling entry point. Importing it at
    # module level made `import nets` fail wherever thop is not installed.
    from thop import profile

    net = LRN()
    total = sum(p.numel() for p in net.parameters())
    spike = torch.zeros((1, 50, 250, 400))
    flops, _ = profile(net, inputs=(spike,))
    # Print readable lines instead of the repr of a tuple of strings.
    print("Total params: %.4fM" % (total / 1e6))
    print("FLOPs=" + str(flops / 1e9) + "G")
|
@@ -0,0 +1 @@
|
|
1
|
+
This folder stores the trained models.
|
@@ -0,0 +1,60 @@
|
|
1
|
+
import os
|
2
|
+
import numpy as np
|
3
|
+
|
4
|
+
|
5
|
+
class DataExtractor():
    """List the files of one dataset split.

    Args:
        dataPath: dataset root containing one sub-directory per split.
        type: split name, one of ``'train'``, ``'valid'``, ``'test'``; the
            files of ``<dataPath>/<type>`` are listed.
    """

    _SPLITS = ('train', 'valid', 'test')

    def __init__(self, dataPath='', type='train'):
        self.type = type
        self.rootPath = dataPath

    def GetData(self):
        """Return the sorted full paths of all entries in the selected split.

        Raises:
            ValueError: if ``type`` is not a known split (this previously
                returned ``None`` silently).
        """
        if self.type not in self._SPLITS:
            raise ValueError(
                "Unknown split {!r}; expected one of {}".format(self.type, self._SPLITS))
        return self._list_split(self.type)

    def _list_split(self, split):
        # One shared implementation for the three formerly copy-pasted
        # methods; sorted so the sample order is deterministic.
        root = os.path.join(self.rootPath, split)
        return [os.path.join(root, name) for name in sorted(os.listdir(root))]
|
60
|
+
|
@@ -0,0 +1,115 @@
|
|
1
|
+
import os
|
2
|
+
import torch
|
3
|
+
# from torchvision import transforms
|
4
|
+
from torch.utils import data
|
5
|
+
import numpy as np
|
6
|
+
from PIL import Image
|
7
|
+
import cv2
|
8
|
+
import random
|
9
|
+
|
10
|
+
|
11
|
+
from DataProcess.DataExtactor import DataExtractor
|
12
|
+
from DataProcess.LoadSpike import LoadSpike, load_spike_raw
|
13
|
+
|
14
|
+
class Dataset(data.Dataset):
    """Dataset yielding ``(spike window, centre GT frame)`` float tensors.

    Each item is loaded with ``LoadSpike`` and keeps the ``2 * spikeRadius + 1``
    spike frames centred on the middle of the sequence.

    Args:
        pathList: list of sample file paths.
        dataType: split name; ``"train"`` allows all four rotation choices
            (the rotation itself is currently disabled).
        spikeRadius: half-width of the spike window.
    """

    def __init__(self, pathList, dataType, spikeRadius):
        self.pathList = pathList
        self.dataType = dataType
        self.spikeRadius = spikeRadius

        # Allowed multiples of 90-degree rotation; only the commented-out
        # augmentation in GetItem refers to this.
        if self.dataType == "train":
            self.choice = [0, 1, 2, 3]
        else:
            self.choice = [0]

    def __getitem__(self, index):
        spSeq, gtFrames = self.GetItem(index)
        return spSeq, gtFrames

    def __len__(self):
        return len(self.pathList)

    def GetItem(self, index):
        """Load sample *index* and return ``(spSeq, gtFrame)``.

        spSeq: the spike window, zero-padded by 3 on each side of axis 1 and
            then mapped 0 -> -1, 1 -> 1 (so padded rows are -1).
        gtFrame: the centre ground-truth frame mapped from [0, 255] to [-1, 1].

        Raises:
            ValueError: if the window does not fit inside the spike sequence.
        """
        path = self.pathList[index]
        spSeq, gtFrames = LoadSpike(path)

        spLen, _, _ = spSeq.shape
        gtLen, _, _ = gtFrames.shape
        spCenter = spLen // 2
        gtCenter = gtLen // 2

        spLeft = spCenter - self.spikeRadius
        spRight = spCenter + self.spikeRadius + 1
        # Python slicing would otherwise silently return a wrong (possibly
        # empty or shortened) window for an out-of-range radius.
        if spLeft < 0 or spRight > spLen:
            raise ValueError(
                "spikeRadius={} needs {} frames but {} has only {}".format(
                    self.spikeRadius, 2 * self.spikeRadius + 1, path, spLen))
        spSeq = spSeq[spLeft:spRight]

        gtFrame = gtFrames[gtCenter]

        spSeq = np.pad(spSeq, ((0, 0), (3, 3), (0, 0)), mode='constant')
        spSeq = spSeq.astype(float) * 2 - 1

        gtFrame = gtFrame.astype(float) / 255. * 2.0 - 1.

        spSeq = torch.FloatTensor(spSeq)
        gtFrame = torch.FloatTensor(gtFrame)

        # Random rotation of spike and GT frames by a multiple of 90 degrees
        # drawn from self.choice (disabled):
        # choice = random.choice(self.choice)
        # spSeq = torch.rot90(spSeq, choice, dims=(1,2))
        # gtFrame =torch.rot90(gtFrame, choice, dims=(1,2))
        return spSeq, gtFrame
|
72
|
+
|
73
|
+
|
74
|
+
|
75
|
+
|
76
|
+
|
77
|
+
class DataContainer():
    """Hold one dataset split's file list and the settings for its loader.

    Args:
        dataPath: dataset root directory.
        dataType: split name passed to ``DataExtractor``.
        spikeRadius: half-width of the spike window per sample.
        batchSize: DataLoader batch size.
        numWorks: number of DataLoader worker processes.
    """

    def __init__(self, dataPath='', dataType='train',
                 spikeRadius=16, batchSize=128, numWorks=0):
        self.dataPath = dataPath
        self.dataType = dataType
        self.spikeRadius = spikeRadius
        self.batchSize = batchSize
        self.numWorks = numWorks

        self.__GetData()

    def __GetData(self):
        # Resolve the sample paths once, when the container is built.
        extractor = DataExtractor(dataPath=self.dataPath, type=self.dataType)
        self.pathList = extractor.GetData()

    def GetLoader(self):
        """Return a DataLoader over the split; only ``"train"`` is shuffled."""
        dataset = Dataset(self.pathList, self.dataType, self.spikeRadius)
        return data.DataLoader(
            dataset,
            batch_size=self.batchSize,
            shuffle=(self.dataType == "train"),
            num_workers=self.numWorks,
            pin_memory=False,
        )
|
109
|
+
|
110
|
+
if __name__ == "__main__":
    # No standalone behaviour: this module only defines the data loaders.
    ...
|
113
|
+
|
114
|
+
|
115
|
+
|
@@ -0,0 +1,39 @@
|
|
1
|
+
import numpy as np
|
2
|
+
|
3
|
+
def load_spike_numpy(path: str) -> tuple[np.ndarray, np.ndarray]:
    '''
    Load a spike sequence and its tag from a prepacked `.npz` file.

    The archive holds `seq` (frames bit-packed 8 per element along axis 0,
    least-significant bit first), `tag`, and `length` (the number of frames
    to unpack). The sequence is returned with shape (`length`, `height`,
    `width`) and the tag with shape (`height`, `width`).
    '''
    # Use the NpzFile as a context manager so its file handle gets closed.
    with np.load(path) as data:
        seq, tag, length = data['seq'], data['tag'], int(data['length'])
    # Frame i is bit (i % 8) of packed element i // 8.
    seq = np.array([(seq[i // 8] >> (i & 7)) & 1 for i in range(length)])
    return seq, tag
|
13
|
+
|
14
|
+
def LoadSpike(path: str) -> tuple[np.ndarray, np.ndarray]:
    '''
    Load a spike sequence and the corresponding ground-truth frames.

    The `.npz` archive holds `spSeq` (spike frames bit-packed 8 per element
    along axis 0, least-significant bit first), `gt`, and `length` (the
    number of spike frames to unpack).
    spSeq: an ndarray of shape ('sequence length', 'height', 'width')
    gtFrames: the `gt` array as stored, e.g. ('number of frames', 'height', 'width')
    '''
    # Use the NpzFile as a context manager so its file handle gets closed.
    with np.load(path) as data:
        spSeq, gtFrames, length = data['spSeq'], data['gt'], int(data['length'])
    # Frame i is bit (i % 8) of packed element i // 8.
    spSeq = np.array([(spSeq[i // 8] >> (i & 7)) & 1 for i in range(length)])
    return spSeq, gtFrames
|
25
|
+
|
26
|
+
def load_spike_raw(path: str, width=400, height=250) -> np.ndarray:
    '''
    Load bit-compact raw spike data into a uint8 ndarray of shape
    (`sequence length`, `height`, `width`).

    Each byte holds 8 consecutive pixels, least-significant bit first.
    Every frame is flipped along the height axis after decoding. Trailing
    bits that do not fill a complete frame are ignored.
    '''
    with open(path, 'rb') as f:
        fbytes = f.read()
    fnum = (len(fbytes) * 8) // (width * height)  # number of complete frames
    # unpackbits(bitorder='little') yields the same LSB-first bit stream the
    # old code built manually, and avoids np.bool (removed in NumPy 1.24).
    bits = np.unpackbits(np.frombuffer(fbytes, dtype=np.uint8), bitorder='little')
    # Drop the bits of a trailing partial frame so the reshape cannot fail.
    frames = bits[:fnum * height * width].reshape(fnum, height, width)
    frames = np.flip(frames, 1)
    return frames
|
@@ -0,0 +1 @@
|
|
1
|
+
This folder stores the images reconstructed from the validation/testing sets.
|
@@ -0,0 +1,21 @@
|
|
1
|
+
MIT License
|
2
|
+
|
3
|
+
Copyright (c) 2022 YangChenUcas
|
4
|
+
|
5
|
+
Permission is hereby granted, free of charge, to any person obtaining a copy
|
6
|
+
of this software and associated documentation files (the "Software"), to deal
|
7
|
+
in the Software without restriction, including without limitation the rights
|
8
|
+
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
9
|
+
copies of the Software, and to permit persons to whom the Software is
|
10
|
+
furnished to do so, subject to the following conditions:
|
11
|
+
|
12
|
+
The above copyright notice and this permission notice shall be included in all
|
13
|
+
copies or substantial portions of the Software.
|
14
|
+
|
15
|
+
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
16
|
+
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
17
|
+
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
18
|
+
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
19
|
+
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
20
|
+
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
21
|
+
SOFTWARE.
|
@@ -0,0 +1,50 @@
|
|
1
|
+
import numpy as np
|
2
|
+
from skimage import metrics
|
3
|
+
|
4
|
+
class Metrics():
    """Batch PSNR/SSIM helpers (via skimage) plus storage for best scores."""

    def __init__(self):
        self.best_psnr = 0.
        self.best_ssim = 0.
        self.best_niqe = 0.

    def Update(self, psnr=0., ssim=0., niqe=0.):
        """Overwrite the stored scores; no comparison with old values is made."""
        self.best_psnr = psnr
        self.best_ssim = ssim
        self.best_niqe = niqe

    def GetBestMetrics(self):
        """Return the stored ``(psnr, ssim, niqe)`` tuple."""
        return self.best_psnr, self.best_ssim, self.best_niqe

    def Cal_PSNR(self, preImgs, gtImgs):  # shape: [B, H, W]
        """Return the mean PSNR over the batch; prints each image's PSNR.

        Calls ``skimage.metrics.peak_signal_noise_ratio(gt, pre)`` per pair.
        """
        B, _, _ = preImgs.shape
        total_psnr = 0.
        for i, (pre, gt) in enumerate(zip(preImgs, gtImgs)):
            # Computed once and reused (it was previously evaluated twice).
            psnr = metrics.peak_signal_noise_ratio(gt, pre)
            print(i + 1, psnr)
            total_psnr += psnr
        return total_psnr / B

    def Cal_SSIM(self, preImgs, gtImgs):  # shape: [B, H, W]
        """Return the mean ``skimage.metrics.structural_similarity(pre, gt)``."""
        B, _, _ = preImgs.shape
        total_ssim = 0.
        for pre, gt in zip(preImgs, gtImgs):
            total_ssim += metrics.structural_similarity(pre, gt)
        return total_ssim / B
|
42
|
+
|
43
|
+
|
44
|
+
if __name__ == "__main__":
    a = np.random.random((2, 256, 256))
    b = np.random.random((2, 256, 256))
    # Bound to `metric`: rebinding the global name `metrics` would shadow the
    # skimage module that Cal_PSNR / Cal_SSIM call into.
    metric = Metrics()

    # Metrics has no Cal_NIQE; the previous call raised AttributeError.
    print(metric.Cal_PSNR(a, b))
|
File without changes
|
@@ -0,0 +1,89 @@
|
|
1
|
+
import torch
|
2
|
+
import torch.nn as nn
|
3
|
+
import torch.nn.functional as F
|
4
|
+
|
5
|
+
class CharbonnierLoss(nn.Module):
    """Charbonnier loss (smooth L1): ``mean(sqrt((x - y)^2 + eps^2))``."""

    def __init__(self, eps=1e-3):
        super(CharbonnierLoss, self).__init__()
        self.eps = eps

    def forward(self, x, y):
        diff = x - y
        loss = torch.mean(torch.sqrt((diff * diff) + (self.eps * self.eps)))
        return loss


class EdgeLoss(nn.Module):
    """Charbonnier loss between Laplacian band-pass images of ``x`` and ``y``.

    Each channel is blurred independently with a 5x5 separable Gaussian
    whose taps sum to 1 (a depthwise convolution), so any channel count works.

    Previously a single 3-input-channel kernel was used and inputs were
    repeated to 3 channels. That summed channels instead of blurring each one,
    tripling the gain of every blur so the result was not a band-pass image,
    and inputs with more than one channel failed.
    """

    def __init__(self):
        super(EdgeLoss, self).__init__()
        k = torch.Tensor([[.05, .25, .4, .25, .05]])
        # Outer product of the 1-D taps -> (1, 1, 5, 5) 2-D Gaussian.
        self.kernel = torch.matmul(k.t(), k).unsqueeze(0).unsqueeze(0)
        if torch.cuda.is_available():
            self.kernel = self.kernel.cuda()
        self.loss = CharbonnierLoss()

    def conv_gauss(self, img):
        """Blur every channel of *img* ``(B, C, H, W)`` using replicate padding."""
        n_channels = img.shape[1]
        _, _, kh, kw = self.kernel.shape
        # Follow the input's device/dtype and give each channel its own filter.
        kernel = self.kernel.to(device=img.device, dtype=img.dtype).repeat(n_channels, 1, 1, 1)
        img = F.pad(img, (kw // 2, kw // 2, kh // 2, kh // 2), mode='replicate')
        return F.conv2d(img, kernel, groups=n_channels)

    def laplacian_kernel(self, current):
        """Return *current* minus its blurred 2x down/up-sampled version."""
        filtered = self.conv_gauss(current)      # filter
        down = filtered[:, :, ::2, ::2]          # downsample
        new_filter = torch.zeros_like(filtered)
        new_filter[:, :, ::2, ::2] = down * 4    # upsample; x4 offsets the zeroed samples
        filtered = self.conv_gauss(new_filter)   # filter
        return current - filtered

    def forward(self, x, y):
        return self.loss(self.laplacian_kernel(x), self.laplacian_kernel(y))
|
54
|
+
|
55
|
+
|
56
|
+
class VGGLoss4(nn.Module):
    """Perceptual loss: MSE between early VGG-style features of two images.

    Args:
        path: file containing a state dict for this module (``features.*`` keys).

    Inputs are repeated 3 times along dim 1, so they are presumably
    single-channel ``(B, 1, H, W)`` images. TODO: confirm with callers.
    """

    def __init__(self, path: str):
        super().__init__()
        self.features = nn.Sequential(
            nn.Conv2d(3, 64, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(64, 64, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(64, 128, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.Conv2d(128, 128, 3, 1, 1),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2),
            nn.Conv2d(128, 256, 3, 1, 1),
            nn.ReLU(inplace=True),
            # nn.Conv2d(256, 256, 3, 1, 1),
            # nn.ReLU(inplace=True),
            # nn.Conv2d(256, 256, 3, 1, 1),
            # nn.ReLU(inplace=True),
        )
        # map_location='cpu' lets GPU-saved weights load on CPU-only hosts;
        # load_state_dict copies them into this module's own parameters.
        self.load_state_dict(torch.load(path, map_location='cpu'))
        # The feature extractor is frozen.
        for param in self.parameters():
            param.requires_grad = False

    def forward(self, real_y, fake_y):
        """Return the feature-space MSE between *real_y* and *fake_y*.

        Only the target branch (*real_y*) runs without autograd, so gradients
        reach *fake_y*, presumably the network output. Previously both branches
        ran under ``torch.no_grad()``, which made the loss constant with
        respect to the model being trained.
        """
        real_y = real_y.repeat((1, 3, 1, 1))
        fake_y = fake_y.repeat((1, 3, 1, 1))
        with torch.no_grad():
            real_f = self.features(real_y)
        fake_f = self.features(fake_y)
        return F.mse_loss(real_f, fake_f)
|
88
|
+
|
89
|
+
|