spikezoo 0.1.2__py3-none-any.whl → 0.2.1__py3-none-any.whl
Sign up to get free protection for your applications and to get access to all the features.
- spikezoo/__init__.py +13 -0
- spikezoo/archs/__pycache__/__init__.cpython-39.pyc +0 -0
- spikezoo/archs/base/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/base/nets.py +34 -0
- spikezoo/archs/bsf/README.md +92 -0
- spikezoo/archs/bsf/datasets/datasets.py +328 -0
- spikezoo/archs/bsf/datasets/ds_utils.py +64 -0
- spikezoo/archs/bsf/main.py +398 -0
- spikezoo/archs/bsf/metrics/psnr.py +22 -0
- spikezoo/archs/bsf/metrics/ssim.py +54 -0
- spikezoo/archs/bsf/models/bsf/__pycache__/align.cpython-39.pyc +0 -0
- spikezoo/archs/bsf/models/bsf/__pycache__/bsf.cpython-39.pyc +0 -0
- spikezoo/archs/bsf/models/bsf/__pycache__/rep.cpython-39.pyc +0 -0
- spikezoo/archs/bsf/models/bsf/align.py +154 -0
- spikezoo/archs/bsf/models/bsf/bsf.py +105 -0
- spikezoo/archs/bsf/models/bsf/dsft_convert.py +96 -0
- spikezoo/archs/bsf/models/bsf/rep.py +44 -0
- spikezoo/archs/bsf/models/get_model.py +7 -0
- spikezoo/archs/bsf/prepare_data/DSFT.py +62 -0
- spikezoo/archs/bsf/prepare_data/crop_dataset_train.py +135 -0
- spikezoo/archs/bsf/prepare_data/crop_dataset_val.py +139 -0
- spikezoo/archs/bsf/prepare_data/crop_train.sh +4 -0
- spikezoo/archs/bsf/prepare_data/crop_val.sh +4 -0
- spikezoo/archs/bsf/prepare_data/io_utils.py +64 -0
- spikezoo/archs/bsf/requirements.txt +9 -0
- spikezoo/archs/bsf/test.py +16 -0
- spikezoo/archs/bsf/utils.py +154 -0
- spikezoo/archs/spikeclip/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/spikeclip/nets.py +40 -0
- spikezoo/archs/spikeformer/CheckPoints/readme +1 -0
- spikezoo/archs/spikeformer/DataProcess/DataExtactor.py +60 -0
- spikezoo/archs/spikeformer/DataProcess/DataLoader.py +115 -0
- spikezoo/archs/spikeformer/DataProcess/LoadSpike.py +39 -0
- spikezoo/archs/spikeformer/EvalResults/readme +1 -0
- spikezoo/archs/spikeformer/LICENSE +21 -0
- spikezoo/archs/spikeformer/Metrics/Metrics.py +50 -0
- spikezoo/archs/spikeformer/Metrics/__init__.py +0 -0
- spikezoo/archs/spikeformer/Model/Loss.py +89 -0
- spikezoo/archs/spikeformer/Model/SpikeFormer.py +230 -0
- spikezoo/archs/spikeformer/Model/__init__.py +0 -0
- spikezoo/archs/spikeformer/Model/__pycache__/SpikeFormer.cpython-39.pyc +0 -0
- spikezoo/archs/spikeformer/Model/__pycache__/__init__.cpython-39.pyc +0 -0
- spikezoo/archs/spikeformer/README.md +30 -0
- spikezoo/archs/spikeformer/evaluate.py +87 -0
- spikezoo/archs/spikeformer/recon_real_data.py +97 -0
- spikezoo/archs/spikeformer/requirements.yml +95 -0
- spikezoo/archs/spikeformer/train.py +173 -0
- spikezoo/archs/spikeformer/utils.py +22 -0
- spikezoo/archs/spk2imgnet/.github/workflows/pylint.yml +23 -0
- spikezoo/archs/spk2imgnet/.gitignore +150 -0
- spikezoo/archs/spk2imgnet/DCNv2.py +135 -0
- spikezoo/archs/spk2imgnet/__pycache__/DCNv2.cpython-39.pyc +0 -0
- spikezoo/archs/spk2imgnet/__pycache__/align_arch.cpython-39.pyc +0 -0
- spikezoo/archs/spk2imgnet/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/spk2imgnet/align_arch.py +159 -0
- spikezoo/archs/spk2imgnet/dataset.py +144 -0
- spikezoo/archs/spk2imgnet/nets.py +230 -0
- spikezoo/archs/spk2imgnet/readme.md +86 -0
- spikezoo/archs/spk2imgnet/test_gen_imgseq.py +118 -0
- spikezoo/archs/spk2imgnet/train.py +189 -0
- spikezoo/archs/spk2imgnet/utils.py +64 -0
- spikezoo/archs/ssir/README.md +87 -0
- spikezoo/archs/ssir/configs/SSIR.yml +37 -0
- spikezoo/archs/ssir/configs/yml_parser.py +78 -0
- spikezoo/archs/ssir/datasets/dataset_sreds.py +170 -0
- spikezoo/archs/ssir/datasets/ds_utils.py +66 -0
- spikezoo/archs/ssir/losses.py +21 -0
- spikezoo/archs/ssir/main.py +326 -0
- spikezoo/archs/ssir/metrics/psnr.py +22 -0
- spikezoo/archs/ssir/metrics/ssim.py +54 -0
- spikezoo/archs/ssir/models/Vgg19.py +42 -0
- spikezoo/archs/ssir/models/__pycache__/layers.cpython-39.pyc +0 -0
- spikezoo/archs/ssir/models/__pycache__/networks.cpython-39.pyc +0 -0
- spikezoo/archs/ssir/models/layers.py +110 -0
- spikezoo/archs/ssir/models/networks.py +61 -0
- spikezoo/archs/ssir/requirements.txt +8 -0
- spikezoo/archs/ssir/shells/eval_SREDS.sh +6 -0
- spikezoo/archs/ssir/shells/train_SSIR.sh +12 -0
- spikezoo/archs/ssir/test.py +3 -0
- spikezoo/archs/ssir/utils.py +154 -0
- spikezoo/archs/ssml/__pycache__/cbam.cpython-39.pyc +0 -0
- spikezoo/archs/ssml/__pycache__/model.cpython-39.pyc +0 -0
- spikezoo/archs/ssml/cbam.py +224 -0
- spikezoo/archs/ssml/model.py +290 -0
- spikezoo/archs/ssml/res.png +0 -0
- spikezoo/archs/ssml/test.py +67 -0
- spikezoo/archs/stir/.git-credentials +0 -0
- spikezoo/archs/stir/README.md +65 -0
- spikezoo/archs/stir/ckpt_outputs/Descriptions.txt +1 -0
- spikezoo/archs/stir/configs/STIR.yml +37 -0
- spikezoo/archs/stir/configs/utils.py +155 -0
- spikezoo/archs/stir/configs/yml_parser.py +78 -0
- spikezoo/archs/stir/datasets/dataset_sreds.py +180 -0
- spikezoo/archs/stir/datasets/ds_utils.py +66 -0
- spikezoo/archs/stir/eval_SREDS.sh +5 -0
- spikezoo/archs/stir/main.py +397 -0
- spikezoo/archs/stir/metrics/losses.py +219 -0
- spikezoo/archs/stir/metrics/psnr.py +22 -0
- spikezoo/archs/stir/metrics/ssim.py +54 -0
- spikezoo/archs/stir/models/Vgg19.py +42 -0
- spikezoo/archs/stir/models/__pycache__/networks_STIR.cpython-39.pyc +0 -0
- spikezoo/archs/stir/models/__pycache__/submodules.cpython-39.pyc +0 -0
- spikezoo/archs/stir/models/__pycache__/transformer_new.cpython-39.pyc +0 -0
- spikezoo/archs/stir/models/networks_STIR.py +361 -0
- spikezoo/archs/stir/models/submodules.py +86 -0
- spikezoo/archs/stir/models/transformer_new.py +151 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/__init__.py +0 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/convertions.py +721 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/disp_netS.py +133 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/flow_utils.py +167 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/generic_train_test.py +76 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/geometry.py +458 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/image_proc.py +183 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/linalg.py +40 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/losses.py +198 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/metrics.py +51 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/model_base.py +53 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/net_basics.py +100 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/resnet.py +333 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/transforms.py +123 -0
- spikezoo/archs/stir/package_core/build/lib/package_core/utils.py +72 -0
- spikezoo/archs/stir/package_core/dist/package_core-0.0.0-py3.9.egg +0 -0
- spikezoo/archs/stir/package_core/package_core/__init__.py +0 -0
- spikezoo/archs/stir/package_core/package_core/__pycache__/__init__.cpython-39.pyc +0 -0
- spikezoo/archs/stir/package_core/package_core/__pycache__/net_basics.cpython-39.pyc +0 -0
- spikezoo/archs/stir/package_core/package_core/convertions.py +721 -0
- spikezoo/archs/stir/package_core/package_core/disp_netS.py +133 -0
- spikezoo/archs/stir/package_core/package_core/flow_utils.py +167 -0
- spikezoo/archs/stir/package_core/package_core/generic_train_test.py +76 -0
- spikezoo/archs/stir/package_core/package_core/geometry.py +458 -0
- spikezoo/archs/stir/package_core/package_core/image_proc.py +183 -0
- spikezoo/archs/stir/package_core/package_core/linalg.py +40 -0
- spikezoo/archs/stir/package_core/package_core/losses.py +198 -0
- spikezoo/archs/stir/package_core/package_core/metrics.py +51 -0
- spikezoo/archs/stir/package_core/package_core/model_base.py +53 -0
- spikezoo/archs/stir/package_core/package_core/net_basics.py +100 -0
- spikezoo/archs/stir/package_core/package_core/resnet.py +333 -0
- spikezoo/archs/stir/package_core/package_core/transforms.py +123 -0
- spikezoo/archs/stir/package_core/package_core/utils.py +72 -0
- spikezoo/archs/stir/package_core/package_core.egg-info/PKG-INFO +3 -0
- spikezoo/archs/stir/package_core/package_core.egg-info/SOURCES.txt +20 -0
- spikezoo/archs/stir/package_core/package_core.egg-info/dependency_links.txt +1 -0
- spikezoo/archs/stir/package_core/package_core.egg-info/top_level.txt +1 -0
- spikezoo/archs/stir/package_core/setup.py +5 -0
- spikezoo/archs/stir/requirements.txt +12 -0
- spikezoo/archs/stir/train_STIR.sh +9 -0
- spikezoo/archs/tfi/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/tfi/nets.py +43 -0
- spikezoo/archs/tfp/__pycache__/nets.cpython-39.pyc +0 -0
- spikezoo/archs/tfp/nets.py +13 -0
- spikezoo/archs/wgse/README.md +64 -0
- spikezoo/archs/wgse/__pycache__/dwtnets.cpython-39.pyc +0 -0
- spikezoo/archs/wgse/__pycache__/submodules.cpython-39.pyc +0 -0
- spikezoo/archs/wgse/dataset.py +59 -0
- spikezoo/archs/wgse/demo.png +0 -0
- spikezoo/archs/wgse/demo.py +83 -0
- spikezoo/archs/wgse/dwtnets.py +145 -0
- spikezoo/archs/wgse/eval.py +133 -0
- spikezoo/archs/wgse/logs/WGSE-Dwt1dNet-db8-5-ks3/log.txt +11 -0
- spikezoo/archs/wgse/submodules.py +68 -0
- spikezoo/archs/wgse/train.py +261 -0
- spikezoo/archs/wgse/transform.py +139 -0
- spikezoo/archs/wgse/utils.py +128 -0
- spikezoo/archs/wgse/weights/demo.png +0 -0
- spikezoo/data/base/test/gt/200_part1_key_id151.png +0 -0
- spikezoo/data/base/test/gt/200_part3_key_id151.png +0 -0
- spikezoo/data/base/test/gt/203_part1_key_id151.png +0 -0
- spikezoo/data/base/test/spike/200_part1_key_id151.dat +0 -0
- spikezoo/data/base/test/spike/200_part3_key_id151.dat +0 -0
- spikezoo/data/base/test/spike/203_part1_key_id151.dat +0 -0
- spikezoo/data/base/train/gt/203_part2_key_id151.png +0 -0
- spikezoo/data/base/train/gt/203_part3_key_id151.png +0 -0
- spikezoo/data/base/train/gt/203_part4_key_id151.png +0 -0
- spikezoo/data/base/train/spike/203_part2_key_id151.dat +0 -0
- spikezoo/data/base/train/spike/203_part3_key_id151.dat +0 -0
- spikezoo/data/base/train/spike/203_part4_key_id151.dat +0 -0
- spikezoo/datasets/base_dataset.py +2 -3
- spikezoo/metrics/__init__.py +1 -1
- spikezoo/models/base_model.py +1 -3
- spikezoo/pipeline/base_pipeline.py +7 -5
- spikezoo/pipeline/train_pipeline.py +1 -1
- spikezoo/utils/other_utils.py +16 -6
- spikezoo/utils/spike_utils.py +33 -29
- spikezoo/utils/vidar_loader.cpython-39-x86_64-linux-gnu.so +0 -0
- spikezoo-0.2.1.dist-info/METADATA +167 -0
- spikezoo-0.2.1.dist-info/RECORD +211 -0
- spikezoo/models/spcsnet_model.py +0 -19
- spikezoo-0.1.2.dist-info/METADATA +0 -39
- spikezoo-0.1.2.dist-info/RECORD +0 -36
- {spikezoo-0.1.2.dist-info → spikezoo-0.2.1.dist-info}/LICENSE.txt +0 -0
- {spikezoo-0.1.2.dist-info → spikezoo-0.2.1.dist-info}/WHEEL +0 -0
- {spikezoo-0.1.2.dist-info → spikezoo-0.2.1.dist-info}/top_level.txt +0 -0
@@ -0,0 +1,261 @@
|
|
1
|
+
import argparse
|
2
|
+
import json
|
3
|
+
import os
|
4
|
+
import time
|
5
|
+
|
6
|
+
import numpy as np
|
7
|
+
import torch
|
8
|
+
import torch.nn.functional as F
|
9
|
+
from einops import rearrange
|
10
|
+
from pytorch_wavelets import DWT1DForward
|
11
|
+
|
12
|
+
from transform import Compose, RandomCrop, RandomRotationFlip
|
13
|
+
from dataset import DatasetREDS
|
14
|
+
from dwtnets import Dwt1dResnetX_TCN
|
15
|
+
from utils import calculate_psnr, calculate_ssim, mkdir
|
16
|
+
|
17
|
+
parser = argparse.ArgumentParser(description='AAAI - WGSE - REDS')
parser.add_argument('-c', '--cuda', type=str, default='1', help='select gpu card')
parser.add_argument('-b', '--batch_size', type=int, default=16)
parser.add_argument('-e', '--epoch', type=int, default=600)
parser.add_argument('-w', '--wvl', type=str, default='db8', help='select wavelet base function')
parser.add_argument('-j', '--jlevels', type=int, default=5)
parser.add_argument('-k', '--kernel_size', type=int, default=3)
parser.add_argument('-l', '--logpath', type=str, default='WGSE-Dwt1dNet')
parser.add_argument('-r', '--resume_from', type=str, default=None)
# main() builds '<dataroot>/train' and '<dataroot>/val' with os.path.join, which raises a
# TypeError on None; requiring the flag turns that late crash into an immediate usage error.
parser.add_argument('--dataroot', type=str, required=True,
                    help='dataset root containing the train/ and val/ sub-folders')

args = parser.parse_args()
# Device mask for CUDA; set here, before main() moves anything to the GPU.
os.environ["CUDA_VISIBLE_DEVICES"] = args.cuda

# Run settings read as module globals by main().
resume_folder = args.resume_from
batch_size = args.batch_size
learning_rate = 1e-4
train_epoch = args.epoch
dataroot = args.dataroot

# Optimizer name and its keyword arguments, JSON-encoded (decoded in main()).
opt = 'adam'
opt_param = "{\"beta1\":0.9,\"beta2\":0.99,\"weight_decay\":0}"

# random_seed=True draws a fresh seed per run; otherwise manual_seed is used.
random_seed = True
manual_seed = 123

# Name of a torch.optim.lr_scheduler class and its JSON-encoded keyword arguments.
scheduler = "MultiStepLR"
scheduler_param = "{\"milestones\": [400, 600], \"gamma\": 0.2}"

# Wavelet basis, number of DWT levels and kernel size of the network.
wvlname = args.wvl
j = args.jlevels
ks = args.kernel_size

if_save_model = False  # additionally save a checkpoint at every evaluated epoch
eval_freq = 1  # run validation every `eval_freq` epochs
checkpoints_folder = f'{args.logpath}-{wvlname}-{j}-ks{ks}'
|
53
|
+
|
54
|
+
|
55
|
+
def progress_bar_time(total_time):
    """Render a duration in seconds as ``H:MM:SS``.

    Fractional seconds are truncated; hours are not wrapped (e.g. ``100:00:00``).
    """
    minutes, seconds = divmod(int(total_time), 60)
    hours, minutes = divmod(minutes, 60)
    return '%d:%02d:%02d' % (hours, minutes, seconds)
|
60
|
+
|
61
|
+
def _reds_cfg(split):
    """Return the DatasetREDS config dict for ``<dataroot>/<split>``."""
    return {
        'rootfolder': os.path.join(dataroot, split),
        'spikefolder': 'input',
        'imagefolder': 'gt',
        'H': 250,
        'W': 400,
        'C': 41,
    }


def _build_optimizer(params):
    """Build the optimizer named by module-level ``opt`` using the JSON ``opt_param`` kwargs.

    Raises:
        AssertionError: if the kwargs required by the chosen optimizer are missing.
        ValueError: if ``opt`` is neither 'adam' nor 'sgd'.
    """
    opt_param_dict = json.loads(opt_param)
    if opt.lower() == 'adam':
        assert ('beta1' in opt_param_dict and 'beta2' in opt_param_dict and 'weight_decay' in opt_param_dict)
        # torch.optim.Adam takes both betas as a single ``betas`` tuple.
        opt_param_dict['betas'] = (opt_param_dict.pop('beta1'), opt_param_dict.pop('beta2'))
        return torch.optim.Adam(params, lr=learning_rate, **opt_param_dict)
    if opt.lower() == 'sgd':
        assert ('momentum' in opt_param_dict and 'weight_decay' in opt_param_dict)
        return torch.optim.SGD(params, lr=learning_rate, **opt_param_dict)
    raise ValueError('Unsupported optimizer: %r' % opt)


def _wavelet_coeff_sizes(sample_spikes):
    """Return ``(yl_size, yh_size)``: lengths of the DWT1D coefficients along time.

    The temporal signal of one pixel of a training sample is decomposed with the
    configured wavelet (``wvlname``, ``j`` levels) to size the network's layers.
    Assumes ``sample_spikes`` is indexed ``[t, h, w]`` — TODO confirm against DatasetREDS.
    """
    s = sample_spikes[None, :, 0:1, 0:1]
    dwt = DWT1DForward(wave=wvlname, J=j)
    yl, yh = dwt(rearrange(s, 'b t h w -> (b h w) 1 t'))
    return yl.shape[-1], [yhi.shape[-1] for yhi in yh]


def _progress_message(tag, i, n_iters, total_time):
    """Format '[tag] pct | i/n [elapsed<projected total, s/it]' for 0-based iteration ``i``."""
    return '[%s] %3.2f%% | %6d/%6d [%s<%s, %.2fs/it]' % (
        tag,
        float(i + 1) / n_iters * 100, i + 1, n_iters,
        progress_bar_time(total_time),
        progress_bar_time(total_time / (i + 1) * n_iters),
        total_time / (i + 1))


def _train_one_epoch(model, loader, optimizer, lr_scheduler, log_file):
    """Run one epoch of L1-loss training, echoing progress to stdout and ``log_file``."""
    model.train()
    n_iters = len(loader)
    total_time = 0.0
    for i, item in enumerate(loader):
        start_time = time.time()

        spikes = item['spikes'].cuda()
        image = item['image'].cuda()
        optimizer.zero_grad()
        pred = model(spikes)
        loss = F.l1_loss(image, pred)
        loss.backward()
        optimizer.step()

        total_time += time.time() - start_time

        lr_str = ''.join(str(ilr) + ' ' for ilr in lr_scheduler.get_last_lr())
        msg = _progress_message('training', i, n_iters, total_time) + \
            ' | LOSS: %.4f | LR: %s' % (loss.item(), lr_str)
        print('\r' + msg, end='')
        log_file.write(msg + '\n')


def _evaluate(model, loader, log_file):
    """Return the mean (PSNR, SSIM) of the model over ``loader``.

    Raises:
        RuntimeError: if the loader yields no samples.
    """
    model.eval()
    n_iters = len(loader)
    sum_ssim = 0.0
    sum_psnr = 0.0
    sum_num = 0
    total_time = 0.0
    with torch.no_grad():
        for i, item in enumerate(loader):
            start_time = time.time()

            # Evaluate on spike planes [130, 171) (41 planes, matching inc=41 of the
            # network); assumes validation streams have at least 171 planes — TODO confirm.
            spikes = item['spikes'][:, 130:171, :, :].cuda()
            image = item['image'].cuda()
            pred = model(spikes)

            prediction = pred[0].permute(1, 2, 0).cpu().numpy()
            gt = image[0].permute(1, 2, 0).cpu().numpy()
            # The metric helpers expect values in [0, 255].
            sum_ssim += calculate_ssim(gt * 255.0, prediction * 255.0)
            sum_psnr += calculate_psnr(gt * 255.0, prediction * 255.0)
            sum_num += 1

            total_time += time.time() - start_time
            msg = _progress_message('evaluating', i, n_iters, total_time)
            print('\r' + msg, end='')
            log_file.write(msg + '\n')

    if sum_num == 0:
        raise RuntimeError('Validation set is empty; check --dataroot.')
    sum_psnr /= sum_num
    sum_ssim /= sum_num

    print('')
    print('\r[Evaluation Result] PSNR: %.3f | SSIM: %.3f' % (sum_psnr, sum_ssim))
    log_file.write('[Evaluation Result] PSNR: %.3f | SSIM: %.3f\n' % (sum_psnr, sum_ssim))
    return sum_psnr, sum_ssim


def main():
    """Train Dwt1dResnetX_TCN on REDS spike data, validating every ``eval_freq`` epochs.

    Configuration comes from the module-level settings parsed from the command line.
    Progress is appended to ``logs/<checkpoints_folder>/log.txt``. Whenever the
    validation PSNR or SSIM beats its best so far, the whole (pickled) model is saved
    as ``model_best.pt`` in the same folder.
    """
    ckpt_dir = os.path.join('logs', checkpoints_folder)
    mkdir(ckpt_dir)

    seed = np.random.randint(0, 10000) if random_seed else manual_seed
    print('random seed', seed)  # recorded so a random_seed=True run can be reproduced
    torch.manual_seed(seed)
    np.random.seed(seed)

    scheduler_param_dict = json.loads(scheduler_param)

    train_set = DatasetREDS(_reds_cfg('train'),
                            transform=Compose([
                                RandomCrop(128),
                                RandomRotationFlip(0.0, 0.5, 0.5),
                            ]))
    test_set = DatasetREDS(_reds_cfg('val'))
    print('train_set len', len(train_set))
    print('test_set len', len(test_set))

    train_data_loader = torch.utils.data.DataLoader(
        dataset=train_set, batch_size=batch_size, shuffle=True, num_workers=16, drop_last=True)
    # The averaged metrics do not depend on order, so the validation set is not shuffled.
    test_data_loader = torch.utils.data.DataLoader(
        dataset=test_set, batch_size=1, shuffle=False, num_workers=1, drop_last=False)
    print(train_data_loader)
    print(test_data_loader)

    yl_size, yh_size = _wavelet_coeff_sizes(train_set[0]['spikes'])
    model = Dwt1dResnetX_TCN(inc=41, wvlname=wvlname, J=j, yl_size=yl_size, yh_size=yh_size,
                             num_residual_blocks=3, norm=None, ks=ks)

    if resume_folder:
        print("loading model weights from ", resume_folder)
        # Checkpoints are whole pickled DataParallel modules (see torch.save below);
        # unpickling runs arbitrary code, so only resume from trusted files. Newer torch
        # versions default to weights_only=True, which cannot load such files.
        saved_model = torch.load(os.path.join(resume_folder, 'model_best.pt'), map_location='cpu')
        model.load_state_dict(saved_model.module.state_dict())
        print("Weights loaded.")

    model = torch.nn.DataParallel(model).cuda()

    optimizer = _build_optimizer(model.parameters())
    lr_scheduler = getattr(torch.optim.lr_scheduler, scheduler)(optimizer, **scheduler_param_dict)
    best_psnr, best_ssim = 0.0, 0.0

    # range(train_epoch + 1) runs epochs 0..train_epoch inclusive.
    for epoch in range(train_epoch + 1):
        print('Epoch %d/%d ... ' % (epoch, train_epoch))
        with open(os.path.join(ckpt_dir, 'log.txt'), "a") as f:
            _train_one_epoch(model, train_data_loader, optimizer, lr_scheduler, f)
            lr_scheduler.step()
            print('')

            # Checkpointing lives inside the evaluation branch so it never compares
            # metrics left over from an earlier epoch.
            if epoch % eval_freq == 0:
                sum_psnr, sum_ssim = _evaluate(model, test_data_loader, f)

                if if_save_model:
                    print('saving net...')
                    torch.save(model, os.path.join(ckpt_dir, 'model_epoch%d.pt' % epoch))
                    print('saved')

                if sum_psnr > best_psnr or sum_ssim > best_ssim:
                    # Track each best independently so an improvement in one metric
                    # does not lower the recorded best of the other.
                    best_psnr = max(best_psnr, sum_psnr)
                    best_ssim = max(best_ssim, sum_ssim)
                    print('saving best net...')
                    torch.save(model, os.path.join(ckpt_dir, 'model_best.pt'))
                    print('saved')


if __name__ == '__main__':
    main()
|
@@ -0,0 +1,139 @@
|
|
1
|
+
import torch
|
2
|
+
import torch.nn.functional as F
|
3
|
+
from math import sin, cos, pi
|
4
|
+
import numbers
|
5
|
+
import random
|
6
|
+
|
7
|
+
|
8
|
+
class Compose(object):
    """Chain paired transforms, each mapping ``(x, y)`` to a new ``(x, y)``.

    Args:
        transforms (list of callables): applied in list order; each one receives the
            current pair as two positional arguments and returns the transformed pair.

    Example:
        >>> Compose([
        >>>     RandomCrop(128),
        >>>     RandomRotationFlip(0.0, 0.5, 0.5),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, x, y):
        for transform in self.transforms:
            # Unpack every step so a transform returning the wrong arity fails here.
            x, y = transform(x, y)
        return x, y

    def __repr__(self):
        lines = [self.__class__.__name__ + '(']
        lines.extend(' {0}'.format(t) for t in self.transforms)
        return '\n'.join(lines) + '\n)'
|
34
|
+
|
35
|
+
|
36
|
+
class RandomCrop(object):
    """Crop an input/target pair of ``[C x H x W]`` tensors at one shared random location.

    Args:
        size (int or (int, int)): output ``(height, width)``; a number gives a square crop.
    """

    def __init__(self, size):
        self.size = (int(size), int(size)) if isinstance(size, numbers.Number) else size

    @staticmethod
    def get_params(x, output_size):
        """Choose the crop window for ``x``.

        Returns:
            ``(top, left, height, width)``; the whole frame when it already has the
            requested size.

        Raises:
            AssertionError: if the requested crop is larger than ``x``.
        """
        h, w = x.shape[1], x.shape[2]
        th, tw = output_size
        assert(th <= h)
        assert(tw <= w)
        if (h, w) == (th, tw):
            return 0, 0, h, w
        top = random.randint(0, h - th)
        left = random.randint(0, w - tw)
        return top, left, th, tw

    def __call__(self, x, y):
        """Crop ``x`` and ``y`` (both ``[C x H x W]``) with the same window; returns the pair."""
        top, left, height, width = self.get_params(x, self.size)
        rows = slice(top, top + height)
        cols = slice(left, left + width)
        return x[:, rows, cols], y[:, rows, cols]

    def __repr__(self):
        return '{0}(size={1})'.format(self.__class__.__name__, self.size)
|
72
|
+
|
73
|
+
|
74
|
+
class RandomRotationFlip(object):
    """Apply one random rotation plus optional horizontal/vertical flips to an input/target pair.

    Both tensors are resampled with the same affine matrix via ``F.affine_grid`` /
    ``F.grid_sample`` (bilinear, zero padding, ``align_corners=False``).

    Args:
        degrees (number or (min, max)): rotation range in degrees; a number ``d`` means
            ``(-d, d)``.
        p_hflip (float): probability of a horizontal flip.
        p_vflip (float): probability of a vertical flip.
    """

    def __init__(self, degrees, p_hflip=0.5, p_vflip=0.5):
        if isinstance(degrees, numbers.Number):
            if degrees < 0:
                raise ValueError("If degrees is a single number, it must be positive.")
            self.degrees = (-degrees, degrees)
        else:
            if len(degrees) != 2:
                raise ValueError("If degrees is a sequence, it must be of len 2.")
            self.degrees = degrees

        self.p_hflip = p_hflip
        self.p_vflip = p_vflip

    @staticmethod
    def get_params(degrees, p_hflip, p_vflip):
        """Sample a rotation angle and flips.

        Returns:
            ``(M_original_transformed, M_transformed_original)``: the sampled affine
            matrix and its inverse, each as a ``1 x 2 x 3`` float32 tensor.
        """
        angle = random.uniform(degrees[0], degrees[1])
        angle_rad = angle * pi / 180.0

        M_original_transformed = torch.FloatTensor([[cos(angle_rad), -sin(angle_rad), 0],
                                                    [sin(angle_rad), cos(angle_rad), 0],
                                                    [0, 0, 1]])

        if random.random() < p_hflip:
            M_original_transformed[:, 0] *= -1

        if random.random() < p_vflip:
            M_original_transformed[:, 1] *= -1

        M_transformed_original = torch.inverse(M_original_transformed)

        M_original_transformed = M_original_transformed[:2, :].unsqueeze(dim=0)  # 3 x 3 -> N x 2 x 3
        M_transformed_original = M_transformed_original[:2, :].unsqueeze(dim=0)

        return M_original_transformed, M_transformed_original

    @staticmethod
    def _warp(t, theta):
        """Resample a ``[C x H x W]`` tensor with the ``1 x 2 x 3`` affine matrix ``theta``."""
        batch = t.unsqueeze(dim=0)
        # grid_sample requires input and grid to share dtype/device; the sampled matrix is
        # float32 on CPU, so move it to the tensor's device and floating dtype first.
        dtype = t.dtype if t.is_floating_point() else theta.dtype
        theta = theta.to(device=t.device, dtype=dtype)
        grid = F.affine_grid(theta, batch.shape, align_corners=False)
        return F.grid_sample(batch, grid, align_corners=False).squeeze(dim=0)

    def __call__(self, x, y):
        """
        x, y: [C x H x W] tensors transformed with the same sampled matrix.
        Returns:
            (Tensor, Tensor): the rotated/flipped x and y.
        """
        assert(len(x.shape) == 3)

        M_original_transformed, _ = self.get_params(self.degrees, self.p_hflip, self.p_vflip)
        return self._warp(x, M_original_transformed), self._warp(y, M_original_transformed)

    def __repr__(self):
        format_string = self.__class__.__name__ + '(degrees={0}'.format(self.degrees)
        format_string += ', p_hflip={:.2f}'.format(self.p_hflip)
        format_string += ', p_vflip={:.2f}'.format(self.p_vflip)
        format_string += ')'
        return format_string
|
@@ -0,0 +1,128 @@
|
|
1
|
+
import os
|
2
|
+
import math
|
3
|
+
import numpy as np
|
4
|
+
import cv2
|
5
|
+
os.environ["KMP_DUPLICATE_LIB_OK"] = "TRUE"
|
6
|
+
|
7
|
+
|
8
|
+
def RawToSpike(video_seq, h, w, flipud=True):
    """Unpack a bit-packed spike stream into a ``(T, h, w)`` uint8 array of 0/1 values.

    Each frame occupies ``h*w // 8`` consecutive bytes. Pixel ``p`` (row-major index)
    is bit ``p % 8`` (least-significant bit first) of byte ``p // 8`` of its frame.
    Trailing bytes that do not fill a whole frame are ignored.

    Args:
        video_seq: 1-D sequence/array of byte values (converted to uint8).
        h, w: frame height and width; ``h*w`` must be a multiple of 8.
        flipud: if True, reverse the row order of every frame.

    Returns:
        np.ndarray of shape ``(len(video_seq) // (h*w // 8), h, w)``, dtype uint8.
    """
    video_seq = np.array(video_seq).astype(np.uint8)
    bytes_per_frame = h * w // 8
    img_num = len(video_seq) // bytes_per_frame

    pix_id = np.arange(h * w).reshape(h, w)
    byte_id = pix_id // 8
    # uint8 masks keep the bitwise AND in uint8; the int64 masks produced by
    # np.left_shift would promote the whole clip to int64 (8x the memory).
    comparator = np.left_shift(1, pix_id % 8).astype(np.uint8)

    # Decode all frames at once instead of looping frame by frame.
    frames = video_seq[:img_num * bytes_per_frame].reshape(img_num, bytes_per_frame)
    spikes = np.bitwise_and(frames[:, byte_id], comparator) == comparator
    if flipud:
        spikes = spikes[:, ::-1, :]
    return np.ascontiguousarray(spikes, dtype=np.uint8)
|
31
|
+
|
32
|
+
'''
|
33
|
+
# --------------------------------------------
|
34
|
+
# Kai Zhang (github: https://github.com/cszn)
|
35
|
+
# 03/Mar/2019
|
36
|
+
# --------------------------------------------
|
37
|
+
# https://github.com/twhui/SRGAN-pyTorch
|
38
|
+
# https://github.com/xinntao/BasicSR
|
39
|
+
# --------------------------------------------
|
40
|
+
'''
|
41
|
+
|
42
|
+
|
43
|
+
def mkdir(path):
    """Create directory ``path`` (with missing parents) unless something already exists there.

    An existing path — even a regular file — is left untouched, as before. ``exist_ok=True``
    closes the race where another process creates the directory between the existence
    check and ``os.makedirs`` (which would otherwise raise FileExistsError).
    """
    if not os.path.exists(path):
        os.makedirs(path, exist_ok=True)
|
46
|
+
|
47
|
+
|
48
|
+
def mkdirs(paths):
    """Create a single directory (``str``) or every directory in an iterable of paths via ``mkdir``."""
    targets = [paths] if isinstance(paths, str) else paths
    for target in targets:
        mkdir(target)
|
54
|
+
|
55
|
+
|
56
|
+
# --------------------------------------------
|
57
|
+
# PSNR
|
58
|
+
# --------------------------------------------
|
59
|
+
def calculate_psnr(img1, img2, border=0):
    """Return the PSNR in dB between two images with values in [0, 255].

    ``border`` pixels are trimmed from each side of the first two axes first.
    Identical images give ``inf``.

    Raises:
        ValueError: if the two images differ in shape.
    """
    if img1.shape != img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    h, w = img1.shape[:2]
    crop = (slice(border, h - border), slice(border, w - border))

    diff = img1[crop].astype(np.float64) - img2[crop].astype(np.float64)
    mse = np.mean(diff ** 2)
    return float('inf') if mse == 0 else 20 * math.log10(255.0 / math.sqrt(mse))
|
75
|
+
|
76
|
+
|
77
|
+
# --------------------------------------------
|
78
|
+
# SSIM
|
79
|
+
# --------------------------------------------
|
80
|
+
def calculate_ssim(img1, img2, border=0):
    '''Return the mean SSIM between two images with values in [0, 255].

    Intended to give the same outputs as MATLAB's implementation.

    Args:
        img1, img2: ``H x W`` or ``H x W x C`` arrays of identical shape.
        border: pixels trimmed from each side of the first two axes.

    Returns:
        SSIM of a 2-D image, or the mean of the per-channel SSIMs for ``H x W x C``.

    Raises:
        ValueError: on a shape mismatch, or when the input is neither 2-D nor 3-D
            with at least one channel.
    '''
    if not img1.shape == img2.shape:
        raise ValueError('Input images must have the same dimensions.')
    h, w = img1.shape[:2]
    img1 = img1[border:h-border, border:w-border]
    img2 = img2[border:h-border, border:w-border]

    if img1.ndim == 2:
        return ssim(img1, img2)
    if img1.ndim == 3:
        channels = img1.shape[2]
        if channels == 1:
            return ssim(np.squeeze(img1), np.squeeze(img2))
        if channels > 1:
            # Channel-wise mean: identical to the previous RGB path for C == 3, and other
            # channel counts no longer fall through and silently return None.
            return np.array([ssim(img1[:, :, c], img2[:, :, c]) for c in range(channels)]).mean()
    raise ValueError('Wrong input image dimensions.')
|
105
|
+
|
106
|
+
|
107
|
+
def ssim(img1, img2):
    """Return the mean SSIM map of two 2-D images with values in [0, 255].

    Local statistics use an 11x11 Gaussian window (sigma 1.5) with the stabilising
    constants ``(0.01*255)**2`` and ``(0.03*255)**2``. Filter outputs are cropped by
    5 pixels per side, so only positions where the whole window lies inside the
    image ("valid" region) contribute.
    """
    c1 = (0.01 * 255)**2
    c2 = (0.03 * 255)**2

    a = img1.astype(np.float64)
    b = img2.astype(np.float64)
    gauss = cv2.getGaussianKernel(11, 1.5)
    window = np.outer(gauss, gauss.transpose())

    def local_mean(arr):
        # Gaussian-weighted local average restricted to the valid region.
        return cv2.filter2D(arr, -1, window)[5:-5, 5:-5]

    mu_a = local_mean(a)
    mu_b = local_mean(b)
    mu_a_sq = mu_a**2
    mu_b_sq = mu_b**2
    mu_ab = mu_a * mu_b
    var_a = local_mean(a**2) - mu_a_sq
    var_b = local_mean(b**2) - mu_b_sq
    cov_ab = local_mean(a * b) - mu_ab

    numerator = (2 * mu_ab + c1) * (2 * cov_ab + c2)
    denominator = (mu_a_sq + mu_b_sq + c1) * (var_a + var_b + c2)
    return (numerator / denominator).mean()
|
128
|
+
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
Binary file
|
@@ -6,7 +6,7 @@ import numpy as np
|
|
6
6
|
from spikezoo.utils.spike_utils import load_vidar_dat
|
7
7
|
import re
|
8
8
|
from dataclasses import dataclass, replace
|
9
|
-
from typing import Literal
|
9
|
+
from typing import Literal,Union
|
10
10
|
import warnings
|
11
11
|
import torch
|
12
12
|
from tqdm import tqdm
|
@@ -19,7 +19,7 @@ class BaseDatasetConfig:
|
|
19
19
|
"Dataset name."
|
20
20
|
dataset_name: str = "base"
|
21
21
|
"Directory specifying location of data."
|
22
|
-
root_dir: Path = Path(__file__).parent.parent / Path("data/base")
|
22
|
+
root_dir: Union[str,Path] = Path(__file__).parent.parent / Path("data/base")
|
23
23
|
"Image width."
|
24
24
|
width: int = 400
|
25
25
|
"Image height."
|
@@ -108,7 +108,6 @@ class BaseDataset(Dataset):
|
|
108
108
|
spike_name,
|
109
109
|
height=self.cfg.height,
|
110
110
|
width=self.cfg.width,
|
111
|
-
out_type="float",
|
112
111
|
out_format="tensor",
|
113
112
|
)
|
114
113
|
return spike
|
spikezoo/metrics/__init__.py
CHANGED
@@ -11,7 +11,7 @@ import torch.nn.functional as F
|
|
11
11
|
|
12
12
|
# todo with the union type
|
13
13
|
metric_pair_names = ["psnr", "ssim", "lpips", "mse"]
|
14
|
-
metric_single_names = ["niqe", "brisque", "piqe"]
|
14
|
+
metric_single_names = ["niqe", "brisque", "piqe", "liqe_mix", "clipiqa"]
|
15
15
|
metric_all_names = metric_pair_names + metric_single_names
|
16
16
|
|
17
17
|
metric_single_list = {}
|
spikezoo/models/base_model.py
CHANGED
@@ -45,7 +45,6 @@ class BaseModel(nn.Module):
|
|
45
45
|
self.device = "cuda" if torch.cuda.is_available() else "cpu"
|
46
46
|
self.net = self.build_network().to(self.device)
|
47
47
|
self.net = nn.DataParallel(self.net) if cfg.multi_gpu == True else self.net
|
48
|
-
self.spike_size = None
|
49
48
|
self.model_half_win_length: int = cfg.model_win_length // 2
|
50
49
|
|
51
50
|
# ! Might lead to low speed training on the BSF.
|
@@ -104,8 +103,7 @@ class BaseModel(nn.Module):
|
|
104
103
|
:,
|
105
104
|
spike_mid - self.model_half_win_length : spike_mid + self.model_half_win_length + 1,
|
106
105
|
]
|
107
|
-
|
108
|
-
self.spike_size = (spike.shape[2], spike.shape[3])
|
106
|
+
self.spike_size = (spike.shape[2], spike.shape[3])
|
109
107
|
return spike
|
110
108
|
|
111
109
|
def preprocess_spike(self, spike):
|
@@ -67,7 +67,7 @@ class Pipeline:
|
|
67
67
|
"""Pipeline setup."""
|
68
68
|
# save folder
|
69
69
|
self.thistime = datetime.now().strftime("%Y-%m-%d_%H-%M-%S-%f")[:23]
|
70
|
-
self.save_folder = Path(
|
70
|
+
self.save_folder = Path(f"results") if len(self.cfg.save_folder) == 0 else self.cfg.save_folder
|
71
71
|
mode_name = "train" if self.cfg._mode == "train_mode" else "detect"
|
72
72
|
self.save_folder = (
|
73
73
|
self.save_folder / Path(f"{mode_name}/{self.thistime}")
|
@@ -93,6 +93,7 @@ class Pipeline:
|
|
93
93
|
def spk2img_from_dataset(self, idx=0):
|
94
94
|
"""Func---Save the recoverd image and calculate the metric from the given dataset."""
|
95
95
|
# save folder
|
96
|
+
self.logger.info("*********************** spk2img_from_dataset ***********************")
|
96
97
|
save_folder = self.save_folder / Path(f"spk2img_from_dataset/{self.dataset.cfg.dataset_name}_dataset/{self.dataset.cfg.split}/{idx:06d}")
|
97
98
|
os.makedirs(str(save_folder), exist_ok=True)
|
98
99
|
|
@@ -106,9 +107,10 @@ class Pipeline:
|
|
106
107
|
img = None
|
107
108
|
return self._spk2img(spike, img, save_folder)
|
108
109
|
|
109
|
-
def spk2img_from_file(self, file_path, height, width, img_path=None, remove_head=False):
|
110
|
+
def spk2img_from_file(self, file_path, height = -1, width = -1, img_path=None, remove_head=False):
|
110
111
|
"""Func---Save the recoverd image and calculate the metric from the given input file."""
|
111
112
|
# save folder
|
113
|
+
self.logger.info("*********************** spk2img_from_file ***********************")
|
112
114
|
save_folder = self.save_folder / Path(f"spk2img_from_file/{os.path.basename(file_path)}")
|
113
115
|
os.makedirs(str(save_folder), exist_ok=True)
|
114
116
|
|
@@ -135,6 +137,7 @@ class Pipeline:
|
|
135
137
|
def spk2img_from_spk(self, spike, img=None):
|
136
138
|
"""Func---Save the recoverd image and calculate the metric from the given spike stream."""
|
137
139
|
# save folder
|
140
|
+
self.logger.info("*********************** spk2img_from_spk ***********************")
|
138
141
|
save_folder = self.save_folder / Path(f"spk2img_from_spk/{self.thistime}")
|
139
142
|
os.makedirs(str(save_folder), exist_ok=True)
|
140
143
|
|
@@ -188,7 +191,7 @@ class Pipeline:
|
|
188
191
|
if self.cfg.save_metric == True:
|
189
192
|
self.logger.info(f"----------------------Method: {model_name.upper()}----------------------")
|
190
193
|
# paired metric
|
191
|
-
for metric_name in
|
194
|
+
for metric_name in self.cfg.metric_names:
|
192
195
|
if img is not None and metric_name in metric_pair_names:
|
193
196
|
self.logger.info(f"{metric_name.upper()}: {cal_metric_pair(recon_img,img,metric_name)}")
|
194
197
|
elif metric_name in metric_single_names:
|
@@ -203,8 +206,7 @@ class Pipeline:
|
|
203
206
|
if img is not None:
|
204
207
|
img = tensor2npy(img[0, 0])
|
205
208
|
cv2.imwrite(f"{save_folder}/sharp_img.png", img)
|
206
|
-
|
207
|
-
|
209
|
+
self.logger.info(f"Images are saved on the {save_folder}")
|
208
210
|
return recon_img_copy
|
209
211
|
|
210
212
|
def _post_process_img(self, model_name, recon_img, gt_img):
|
@@ -66,7 +66,7 @@ class TrainPipeline(Pipeline):
|
|
66
66
|
save_folder = self.save_folder / Path("imgs") / Path(f"{epoch:06d}")
|
67
67
|
os.makedirs(save_folder, exist_ok=True)
|
68
68
|
for batch_idx, batch in enumerate(tqdm(self.dataloader)):
|
69
|
-
if batch_idx % (len(self.dataloader) // 4)
|
69
|
+
if batch_idx % (len(self.dataloader) // 4) != 0:
|
70
70
|
continue
|
71
71
|
batch = self.model.feed_to_device(batch)
|
72
72
|
outputs = self.model.get_outputs_dict(batch)
|