deepliif 1.1.11__py3-none-any.whl → 1.1.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli.py +354 -67
- deepliif/data/__init__.py +7 -7
- deepliif/data/aligned_dataset.py +2 -3
- deepliif/data/unaligned_dataset.py +38 -19
- deepliif/models/CycleGAN_model.py +282 -0
- deepliif/models/DeepLIIFExt_model.py +47 -25
- deepliif/models/DeepLIIF_model.py +69 -19
- deepliif/models/SDG_model.py +57 -26
- deepliif/models/__init__ - run_dask_multi dev.py +943 -0
- deepliif/models/__init__ - timings.py +764 -0
- deepliif/models/__init__.py +328 -265
- deepliif/models/att_unet.py +199 -0
- deepliif/models/base_model.py +32 -8
- deepliif/models/networks.py +108 -34
- deepliif/options/__init__.py +49 -5
- deepliif/postprocessing.py +1034 -227
- deepliif/postprocessing__OLD__DELETE.py +440 -0
- deepliif/util/__init__.py +85 -64
- deepliif/util/visualizer.py +106 -19
- {deepliif-1.1.11.dist-info → deepliif-1.1.12.dist-info}/METADATA +75 -23
- deepliif-1.1.12.dist-info/RECORD +40 -0
- deepliif-1.1.11.dist-info/RECORD +0 -35
- {deepliif-1.1.11.dist-info → deepliif-1.1.12.dist-info}/LICENSE.md +0 -0
- {deepliif-1.1.11.dist-info → deepliif-1.1.12.dist-info}/WHEEL +0 -0
- {deepliif-1.1.11.dist-info → deepliif-1.1.12.dist-info}/entry_points.txt +0 -0
- {deepliif-1.1.11.dist-info → deepliif-1.1.12.dist-info}/top_level.txt +0 -0
deepliif/data/__init__.py
CHANGED
@@ -55,28 +55,28 @@ def get_option_setter(dataset_name):
     return dataset_class.modify_commandline_options
 
 
-def create_dataset(opt):
+def create_dataset(opt, phase=None, batch_size=None):
     """Create a dataset given the option.
 
     This function wraps the class CustomDatasetDataLoader.
     This is the main interface between this package and 'train.py'/'test.py'
     """
-    return CustomDatasetDataLoader(opt)
+    return CustomDatasetDataLoader(opt, phase=phase if phase else opt.phase, batch_size=batch_size if batch_size else opt.batch_size)
 
 
 class CustomDatasetDataLoader(object):
     """Wrapper class of Dataset class that performs multi-threaded data loading"""
 
-    def __init__(self, opt):
+    def __init__(self, opt, phase=None, batch_size=None):
         """Initialize this class
 
         Step 1: create a dataset instance given the name [dataset_mode]
         Step 2: create a multi-threaded data loader.
         """
-        self.batch_size = opt.batch_size
+        self.batch_size = batch_size if batch_size else opt.batch_size
         self.max_dataset_size = opt.max_dataset_size
         dataset_class = find_dataset_using_name(opt.dataset_mode)
-        self.dataset = dataset_class(opt)
+        self.dataset = dataset_class(opt, phase=phase if phase else opt.phase)
         print("dataset [%s] was created" % type(self.dataset).__name__)
 
         sampler = None
@@ -95,7 +95,7 @@ class CustomDatasetDataLoader(object):
         self.dataloader = torch.utils.data.DataLoader(
             self.dataset,
             sampler=sampler,
-            batch_size=
+            batch_size=batch_size,
             shuffle=not opt.serial_batches if sampler is None else False,
            num_workers=int(opt.num_threads)
        )
@@ -106,7 +106,7 @@ class CustomDatasetDataLoader(object):
         self.dataloader = torch.utils.data.DataLoader(
             self.dataset,
             sampler=sampler,
-            batch_size=
+            batch_size=batch_size,
             shuffle=not opt.serial_batches if sampler is None else False,
             num_workers=int(opt.num_threads),
             worker_init_fn=seed_worker,
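The new `phase` and `batch_size` parameters let callers build a second loader (for example, a validation loader during training) without mutating the shared options object. A minimal usage sketch, assuming an `opt` produced by the package's option parser; the `'val'` phase name is illustrative:

```python
from deepliif.data import create_dataset

# Training loader: falls back to opt.phase and opt.batch_size, as in 1.1.11.
train_data = create_dataset(opt)

# Validation loader: per-call overrides instead of editing opt in place.
val_data = create_dataset(opt, phase='val', batch_size=1)

for batch in val_data:
    pass  # each batch is a dict emitted by the underlying dataset class
```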
deepliif/data/aligned_dataset.py
CHANGED
@@ -11,7 +11,7 @@ class AlignedDataset(BaseDataset):
     During test time, you need to prepare a directory '/path/to/data/test'.
     """
 
-    def __init__(self, opt):
+    def __init__(self, opt, phase='train'):
         """Initialize this dataset class.
 
         Parameters:
@@ -19,7 +19,7 @@ class AlignedDataset(BaseDataset):
         """
         BaseDataset.__init__(self, opt.dataroot)
         self.preprocess = opt.preprocess
-        self.dir_AB = os.path.join(opt.dataroot,
+        self.dir_AB = os.path.join(opt.dataroot, phase)  # get the image directory
         self.AB_paths = sorted(make_dataset(self.dir_AB, opt.max_dataset_size))  # get image paths
         assert(opt.load_size >= opt.crop_size)  # crop_size should be smaller than the size of loaded image
         self.input_nc = opt.output_nc if opt.direction == 'BtoA' else opt.input_nc
@@ -95,7 +95,6 @@ class AlignedDataset(BaseDataset):
             A = AB.crop((w2 * i, 0, w2 * (i+1), h))
             A = A_transform(A)
             A_Array.append(A)
-
         for i in range(self.input_no, self.input_no + self.modalities_no + 1):
             B = AB.crop((w2 * i, 0, w2 * (i + 1), h))
             B = B_transform(B)
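With `phase` threaded into `AlignedDataset`, the image directory resolves to `<dataroot>/<phase>` instead of a value fixed at option-parsing time. A sketch of what this enables, assuming the usual pix2pix-style split directories (names illustrative):

```python
from deepliif.data.aligned_dataset import AlignedDataset

# <dataroot>/train, <dataroot>/val, <dataroot>/test each hold the
# side-by-side A|B1|B2|... tiles this dataset crops apart.
train_ds = AlignedDataset(opt)               # phase defaults to 'train'
val_ds = AlignedDataset(opt, phase='val')    # reads <dataroot>/val instead
```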
deepliif/data/unaligned_dataset.py
CHANGED
@@ -1,5 +1,5 @@
 import os.path
-from deepliif.data.base_dataset import BaseDataset, get_transform
+from deepliif.data.base_dataset import BaseDataset, get_params, get_transform
 from deepliif.data.image_folder import make_dataset
 from PIL import Image
 import random
@@ -16,25 +16,39 @@ class UnalignedDataset(BaseDataset):
     '/path/to/data/testA' and '/path/to/data/testB' during test time.
     """
 
-    def __init__(self, opt):
+    def __init__(self, opt, phase='train'):
         """Initialize this dataset class.
 
         Parameters:
             opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
         """
         BaseDataset.__init__(self, opt)
-        self.
-        self.
+        self.opt = opt
+        self.input_nc = opt.output_nc if opt.direction == 'BtoA' else opt.input_nc
+        self.output_nc = opt.input_nc if opt.direction == 'BtoA' else opt.output_nc
+        self.preprocess = opt.preprocess
+        self.no_flip = opt.no_flip
+        self.modalities_no = opt.modalities_no
+        self.seg_no = opt.seg_no
+        self.input_no = opt.input_no
+        self.seg_gen = opt.seg_gen
+        self.load_size = opt.load_size
+        self.crop_size = opt.crop_size
+        self.model = opt.model
+
+        self.dir_A = os.path.join(opt.dataroot, phase + 'A')  # create a path '/path/to/data/trainA'
+        # trainB1/trainB2/trainB3... are organized as elements of DATASET B which is a list
+        self.dirs_B = [os.path.join(opt.dataroot, phase + f'B{i}') for i in range(1,self.modalities_no+1)]  # create a list of paths ['/path/to/data/trainB',...]
 
         self.A_paths = sorted(make_dataset(self.dir_A, opt.max_dataset_size))  # load images from '/path/to/data/trainA'
-        self.B_paths = sorted(make_dataset(
+        self.B_paths = [sorted(make_dataset(dir_B, opt.max_dataset_size)) for dir_B in self.dirs_B]  # load images from '/path/to/data/trainB', '/path/to/data/trainC', ...
         self.A_size = len(self.A_paths)  # get the size of dataset A
-        self.
+        self.B_sizes = [len(B_paths) for B_paths in self.B_paths]  # get the size of dataset B1, B2, B3, ...
         btoA = self.opt.direction == 'BtoA'
         input_nc = self.opt.output_nc if btoA else self.opt.input_nc  # get the number of channels of input image
         output_nc = self.opt.input_nc if btoA else self.opt.output_nc  # get the number of channels of output image
-
-
+
+
 
     def __getitem__(self, index):
         """Return a data point and its metadata information.
@@ -50,22 +64,27 @@ class UnalignedDataset(BaseDataset):
         """
         A_path = self.A_paths[index % self.A_size]  # make sure index is within then range
         if self.opt.serial_batches:  # make sure index is within then range
-
+            indice_B = [index % B_size for B_size in self.B_sizes]
         else:  # randomize the index for domain B to avoid fixed pairs.
-
-
+            indice_B = [random.randint(0, B_size - 1) for B_size in self.B_sizes]
+        B_paths = [B_paths[index_B] for B_paths, index_B in zip(self.B_paths, indice_B)]
+
         A_img = Image.open(A_path).convert('RGB')
-
+        B_imgs = [Image.open(B_path).convert('RGB') for B_path in B_paths]
+
         # apply image transformation
-
-
-
-
+        transform_params = get_params(self.preprocess, self.load_size, self.crop_size, A_img.size)
+        A_transform = get_transform(self.preprocess, self.load_size, self.crop_size, self.no_flip, transform_params, grayscale=(self.input_nc == 1))
+        B_transform = get_transform(self.preprocess, self.load_size, self.crop_size, self.no_flip, transform_params, grayscale=(self.output_nc == 1))
+        A = A_transform(A_img)
+        Bs = [B_transform(B_img) for B_img in B_imgs]
+
+        return {'A': A, 'Bs': Bs, 'A_paths': A_path, 'B_paths': B_paths}
 
     def __len__(self):
         """Return the total number of images in the dataset.
 
-
-        we
+
+        The effective size of this dataset will be the size of datasetA through which we loop and grab a random/matching image B1/B2/B3... for
         """
-        return max(self.A_size, self.B_size)
+        return self.A_size  #max(self.A_size, self.B_size)
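The rewritten `UnalignedDataset` generalizes the two-domain CycleGAN layout to one input domain A plus N unpaired target modalities B1..BN, and returns all B modalities per sample. A sketch of the expected directory layout and sample shape, assuming `opt.modalities_no == 3` (paths illustrative):

```python
from deepliif.data.unaligned_dataset import UnalignedDataset

# <dataroot>/trainA              input images
# <dataroot>/trainB1 ... trainB3 one unpaired directory per target modality
ds = UnalignedDataset(opt, phase='train')

sample = ds[0]
sample['A']    # transformed input tensor
sample['Bs']   # list of 3 tensors: index-matched B images when
               # opt.serial_batches is set, random picks otherwise
len(ds)        # == size of trainA, per the new __len__
```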
deepliif/models/CycleGAN_model.py
ADDED
@@ -0,0 +1,282 @@
+import torch
+from packaging import version
+from torch import nn
+import itertools
+from ..util.image_pool import ImagePool
+from .base_model import BaseModel
+from . import networks
+from .networks import get_optimizer
+
+
+class CycleGANModel(BaseModel):
+    """
+    This class implements the CycleGAN model, for learning image-to-image translation without paired data.
+
+    The model training requires '--dataset_mode unaligned' dataset.
+    By default, it uses a '--netG resnet_9blocks' ResNet generator,
+    a '--netD basic' discriminator (PatchGAN introduced by pix2pix),
+    and a least-square GANs objective ('--gan_mode lsgan').
+
+    CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf
+    """
+
+    def __init__(self, opt):
+        """Initialize the CycleGAN class.
+
+        Parameters:
+            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
+        """
+        BaseModel.__init__(self, opt)
+        self.mod_gen_no = self.opt.modalities_no
+        if not hasattr(self.opt,'upsample'):
+            self.opt.upsample = 'convtranspose'
+        if not hasattr(self.opt,'label_smoothing'):
+            self.opt.label_smoothing = 0
+
+        use_spectral_norm = self.opt.norm == 'spectral'
+
+        self.loss_G_weights = opt.loss_G_weights
+        self.loss_D_weights = opt.loss_D_weights
+        self.loss_cyc_weights = [1 / self.mod_gen_no] * self.mod_gen_no
+
+        self.opt.lambda_identity = 0  # do not use lambda identity for the first trial
+        # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
+        self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B']
+        # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
+        l_suffix = range(1, self.opt.modalities_no + 1)
+        visual_names_A = [f'real_As_{i}' for i in l_suffix] + [f'fake_Bs_{i}' for i in l_suffix] + [f'rec_As_{i}' for i in l_suffix]
+        visual_names_B = [f'real_Bs_{i}' for i in l_suffix] + [f'fake_As_{i}' for i in l_suffix] + [f'rec_Bs_{i}' for i in l_suffix]
+
+        # if self.is_train and self.opt.lambda_identity > 0.0:  # if identity loss is used, we also visualize idt_B=G_A(B) ad idt_A=G_A(B)
+        #     visual_names_A.append('idt_B')
+        #     visual_names_B.append('idt_A')
+
+        self.visual_names = visual_names_A + visual_names_B  # combine visualizations for A and B
+        # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>.
+        if self.is_train:
+            self.model_names = [f'GA_{i}' for i in l_suffix] + [f'GB_{i}' for i in l_suffix] + [f'DA_{i}' for i in l_suffix] + [f'DB_{i}' for i in l_suffix]
+        else:  # during test time, only load Gs
+            if self.opt.BtoA:
+                self.model_names = [f'GB_{i}' for i in l_suffix]
+            else:
+                self.model_names = [f'GA_{i}' for i in l_suffix]
+
+        # define networks (both Generators and discriminators)
+        # The naming is different from those used in the paper.
+        # Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
+        if isinstance(opt.net_g, str):
+            self.opt.net_g = [self.opt.net_g] * self.mod_gen_no
+
+        if version.parse(torch.__version__) < version.parse('1.11.0'):
+            self.netGA = list()
+            self.netGB = list()
+        else:
+            self.netGA = nn.ModuleList()
+            self.netGB = nn.ModuleList()
+
+        for i in range(self.mod_gen_no):
+            if self.is_train or not self.opt.BtoA:
+                self.netGA.append(networks.define_G(self.opt.input_nc, self.opt.output_nc, self.opt.ngf, self.opt.net_g[i], self.opt.norm,
+                                                    not self.opt.no_dropout, self.opt.init_type, self.opt.init_gain, self.gpu_ids, self.opt.padding,
+                                                    upsample=self.opt.upsample))
+            if self.is_train or self.opt.BtoA:
+                self.netGB.append(networks.define_G(self.opt.output_nc, self.opt.input_nc, self.opt.ngf, self.opt.net_g[i], self.opt.norm,
+                                                    not self.opt.no_dropout, self.opt.init_type, self.opt.init_gain, self.gpu_ids, self.opt.padding,
+                                                    upsample=self.opt.upsample))
+
+        if self.is_train:  # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc
+            if version.parse(torch.__version__) < version.parse('1.11.0'):
+                self.netDA = list()
+                self.netDB = list()
+            else:
+                self.netDA = nn.ModuleList()
+                self.netDB = nn.ModuleList()
+
+            for i in range(self.mod_gen_no):
+                self.netDA.append(networks.define_D(self.opt.output_nc, self.opt.ndf, self.opt.net_d,
+                                                    self.opt.n_layers_D, self.opt.norm, self.opt.init_type, self.opt.init_gain,
+                                                    self.gpu_ids))
+                self.netDB.append(networks.define_D(self.opt.input_nc, self.opt.ndf, self.opt.net_d,
+                                                    self.opt.n_layers_D, self.opt.norm, self.opt.init_type, self.opt.init_gain,
+                                                    self.gpu_ids))
+
+
+
+        if self.is_train:
+            if opt.lambda_identity > 0.0:  # only works when input and output images have the same number of channels
+                assert(opt.input_nc == opt.output_nc)
+            self.fake_A_pools = [ImagePool(opt.pool_size) for _ in range(self.opt.modalities_no)]  # create image buffer to store previously generated images
+            self.fake_B_pools = [ImagePool(opt.pool_size) for _ in range(self.opt.modalities_no)]  # create image buffer to store previously generated images
+
+            # define loss functions
+            # label smoothing currently only applies to discriminator losses & generatoe of lsgan/vanilla
+            self.criterionGAN = networks.GANLoss(opt.gan_mode, label_smoothing=self.opt.label_smoothing).to(self.device)  # define GAN loss.
+            self.criterionCycle = torch.nn.L1Loss()
+            self.criterionIdt = torch.nn.L1Loss()
+
+            self.criterionSmoothL1 = torch.nn.SmoothL1Loss()
+            self.criterionVGG = networks.VGGLoss().to(self.device)
+
+
+            # initialize optimizers
+            params = []
+            for i in range(len(self.netGA)):
+                params += list(self.netGA[i].parameters())
+            for i in range(len(self.netGB)):
+                params += list(self.netGB[i].parameters())
+            try:
+                self.optimizer_G = get_optimizer(opt.optimizer)(params, lr=opt.lr_g, betas=(opt.beta1, 0.999))
+            except:
+                print(f'betas are not used for optimizer torch.optim.{opt.optimizer} in generators')
+                self.optimizer_G = get_optimizer(opt.optimizer)(params, lr=opt.lr_g)
+
+            params = []
+            for i in range(len(self.netDA)):
+                params += list(self.netDA[i].parameters())
+            for i in range(len(self.netDB)):
+                params += list(self.netDB[i].parameters())
+
+            # a smaller learning rate for discriminators to postpone training failure due to discriminators quickly become too strong
+            try:
+                self.optimizer_D = get_optimizer(opt.optimizer)(params, lr=opt.lr_d, betas=(opt.beta1, 0.999))
+            except:
+                print(f'betas are not used for optimizer torch.optim.{opt.optimizer} in generators')
+                self.optimizer_D = get_optimizer(opt.optimizer)(params, lr=opt.lr_d)
+
+
+            self.optimizers.append(self.optimizer_G)
+            self.optimizers.append(self.optimizer_D)
+
+    def set_input(self, input):
+        """Unpack input data from the dataloader and perform necessary pre-processing steps.
+
+        Parameters:
+            input (dict): include the data itself and its metadata information.
+
+        The option 'direction' can be used to swap domain A and domain B.
+        """
+        self.real_As = [input['A'].to(self.device) for _ in range(self.opt.modalities_no)]
+        self.real_Bs = [x.to(self.device) for x in input['Bs']]
+        self.image_paths = input['A_paths']
+
+    def forward(self):
+        """
+        Run forward pass; called by both functions <optimize_parameters> and <test>.
+        During inference, some output list could be empty. For example, if only netGAs are loaded,
+        there will not be valid elements in self.rec_As and self.fake_As.
+        """
+        self.fake_Bs = [netGA(real_A) for netGA, real_A in zip(self.netGA, self.real_As)]  # G_A(A)
+        self.rec_As = [netGB(fake_B) for netGB, fake_B in zip(self.netGB, self.fake_Bs)]  # G_B(G_A(A))
+
+        self.fake_As = [netGB(real_B) for netGB, real_B in zip(self.netGB, self.real_Bs)]  # G_B(B)
+        self.rec_Bs = [netGA(fake_A) for netGA, fake_A in zip(self.netGA, self.fake_As)]  # G_A(G_B(B))
+
+
+    def backward_D_basic(self, netD, real, fake, scale_factor=1):
+        """Calculate GAN loss for the discriminator
+
+        Parameters:
+            netD (network)      -- the discriminator D
+            real (tensor array) -- real images
+            fake (tensor array) -- images generated by a generator
+
+        Return the discriminator loss.
+        We also call loss_D.backward() to calculate the gradients.
+        """
+        # Real
+        pred_real = netD(real)
+        loss_D_real = self.criterionGAN(pred_real, True)
+        # Fake
+        pred_fake = netD(fake.detach())
+        loss_D_fake = self.criterionGAN(pred_fake, False)
+        # Combined loss and calculate gradients
+        loss_D = (loss_D_real + loss_D_fake) * 0.5 * scale_factor
+        loss_D.backward()
+        return loss_D
+
+    def backward_D_A(self):
+        """Calculate GAN loss for discriminator D_A"""
+        fake_Bs = [fake_B_pool.query(fake_B) for fake_B_pool, fake_B in zip(self.fake_B_pools, self.fake_Bs)]
+        real_Bs = self.real_Bs
+
+        self.loss_D_A = 0
+        for i, (netDA, real_B, fake_B) in enumerate(zip(self.netDA, real_Bs, fake_Bs)):
+            self.loss_D_A += self.backward_D_basic(netDA, real_B, fake_B, scale_factor=self.loss_D_weights[i])
+        #self.loss_D_A.backward()
+
+    def backward_D_B(self):
+        """Calculate GAN loss for discriminator D_B"""
+        fake_As = [fake_A_pool.query(fake_A) for fake_A_pool, fake_A in zip(self.fake_A_pools, self.fake_As)]
+        real_As = self.real_As
+
+        self.loss_D_B = 0
+        for i, (netDB, real_A, fake_A) in enumerate(zip(self.netDB, real_As, fake_As)):
+            self.loss_D_B += self.backward_D_basic(netDB, real_A, fake_A, scale_factor=self.loss_D_weights[i])
+        #self.loss_D_B.backward()
+
+    def backward_G(self):
+        """Calculate the loss for generators G_A and G_B"""
+        # default lambda values from cyclegan implementation:
+        # https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/c3268edd50ec37a81600c9b981841f48929671b8/models/cycle_gan_model.py#L41
+        lambda_idt = 0  #self.opt.lambda_identity  # identity loss is used to preserve color consistency between input and output images, which we do not want to encourage
+        lambda_A = 10  #self.opt.lambda_A
+        lambda_B = 10  #self.opt.lambda_B
+        # Identity loss
+        if lambda_idt > 0:
+            # G_A should be identity if real_B is fed: ||G_A(B) - B||
+            self.idt_A = self.netG_A(self.real_B)
+            self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
+            # G_B should be identity if real_A is fed: ||G_B(A) - A||
+            self.idt_B = self.netG_B(self.real_A)
+            self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
+        else:
+            self.loss_idt_A = 0
+            self.loss_idt_B = 0
+
+        # GAN loss D_A(G_A(A))
+        self.loss_G_A = 0
+        for i, (netDA, fake_B, real_B) in enumerate(zip(self.netDA, self.fake_Bs, self.real_Bs)):
+            self.loss_G_A += self.criterionGAN(netDA(fake_B), True) * self.loss_G_weights[i]
+            self.loss_G_A += self.criterionVGG(fake_B, real_B) * self.loss_G_weights[i]
+
+        # GAN loss D_B(G_B(B))
+        self.loss_G_B = 0
+        for i, (netDB, fake_A, real_A) in enumerate(zip(self.netDB, self.fake_As, self.real_As)):
+            self.loss_G_B += self.criterionGAN(netDB(fake_A), True) * self.loss_G_weights[i]
+            self.loss_G_B += self.criterionVGG(fake_A, real_A) * self.loss_G_weights[i]
+
+        # Forward cycle loss || G_B(G_A(A)) - A||
+        self.loss_cycle_A = 0
+        for i, (rec_A, real_A) in enumerate(zip(self.rec_As, self.real_As)):
+            self.loss_cycle_A += self.criterionCycle(rec_A, real_A) * lambda_A * self.loss_cyc_weights[i]
+        # Backward cycle loss || G_A(G_B(B)) - B||
+        self.loss_cycle_B = 0
+        for i, (rec_B, real_B) in enumerate(zip(self.rec_Bs, self.real_Bs)):
+            self.loss_cycle_B += self.criterionCycle(rec_B, real_B) * lambda_B * self.loss_cyc_weights[i]
+
+        # VGG loss
+        # self.loss_G_VGG = self.criterionVGG(self.fake_B_1, self.real_B_1) * self.opt.lambda_feat
+
+        # smooth L1
+        # self.loss_G_A_L1 = self.criterionSmoothL1(self.fake_B_1, self.real_B_1) * self.opt.lambda_L1
+
+        # combined loss and calculate gradients
+        self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
+        self.loss_G.backward()
+
+    def optimize_parameters(self):
+        """Calculate losses, gradients, and update network weights; called in every training iteration"""
+        # forward
+        self.forward()  # compute fake images and reconstruction images.
+        # G_A and G_B
+        self.set_requires_grad(self.netDA + self.netDB, False)  # Ds require no gradients when optimizing Gs
+        self.optimizer_G.zero_grad()  # set G_A and G_B's gradients to zero
+        self.backward_G()  # calculate gradients for G_A and G_B
+        self.optimizer_G.step()  # update G_A and G_B's weights
+
+        # D_A and D_B
+        self.set_requires_grad(self.netDA + self.netDB, True)
+        self.optimizer_D.zero_grad()  # set D_A and D_B's gradients to zero
+        self.backward_D_A()  # calculate gradients for D_A
+        self.backward_D_B()  # calculate graidents for D_B
+        self.optimizer_D.step()  # update D_A and D_B's weights
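`set_input` is what ties this model to the new `UnalignedDataset`: the single A image is replicated once per modality while `Bs` arrives as a list, so generator/discriminator pair `i` always trains against modality `i`. A condensed sketch of one training step (not package code, just the data flow implied above):

```python
# batch = {'A': tensor, 'Bs': [b1, ..., bN], 'A_paths': ..., 'B_paths': ...}
model.set_input(batch)        # real_As = [A] * N, real_Bs = Bs
model.optimize_parameters()   # per modality i, forward() computes:
#   fake_Bs[i] = GA_i(real_As[i]);  rec_As[i] = GB_i(fake_Bs[i])
#   fake_As[i] = GB_i(real_Bs[i]);  rec_Bs[i] = GA_i(fake_As[i])
# and the G/D/cycle losses are summed over i, scaled by
# loss_G_weights, loss_D_weights, and loss_cyc_weights.
```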
deepliif/models/DeepLIIFExt_model.py
CHANGED
@@ -1,6 +1,7 @@
 import torch
 from .base_model import BaseModel
 from . import networks
+from .networks import get_optimizer
 
 
 class DeepLIIFExtModel(BaseModel):
@@ -18,19 +19,16 @@ class DeepLIIFExtModel(BaseModel):
         # self.seg_gen_no = self.opt.modalities_no + 1
 
         # weights of the modalities in generating segmentation mask
-        self.seg_weights =
-        if opt.seg_gen:
-            self.seg_weights = [0.3] * self.mod_gen_no
-            self.seg_weights[1] = 0.4
+        self.seg_weights = opt.seg_weights
 
         # self.seg_weights = opt.seg_weights
         # assert len(self.seg_weights) == self.seg_gen_no, 'The number of the segmentation weights (seg_weights) is not equal to the number of target images (modalities_no)!'
         # print(self.seg_weights)
         # loss weights in calculating the final loss
-        self.loss_G_weights =
+        self.loss_G_weights = opt.loss_G_weights
         self.loss_GS_weights = [1 / self.mod_gen_no] * self.mod_gen_no
 
-        self.loss_D_weights =
+        self.loss_D_weights = opt.loss_D_weights
         self.loss_DS_weights = [1 / self.mod_gen_no] * self.mod_gen_no
 
         # self.gpu_ids is a possibly modifed one for model initialization
@@ -72,22 +70,19 @@ class DeepLIIFExtModel(BaseModel):
             self.model_names.extend(['GS_' + str(i)])
 
         # define networks (both generator and discriminator)
+        if isinstance(opt.net_g, str):
+            self.opt.net_g = [self.opt.net_g] * self.mod_gen_no
+        if isinstance(opt.net_gs, str):
+            self.opt.net_gs = [self.opt.net_gs]*self.mod_gen_no
         self.netG = [None for _ in range(self.mod_gen_no)]
         self.netGS = [None for _ in range(self.mod_gen_no)]
         for i in range(self.mod_gen_no):
-            self.netG[i] = networks.define_G(self.opt.input_nc, self.opt.output_nc, self.opt.ngf, self.opt.net_g, self.opt.norm,
+            self.netG[i] = networks.define_G(self.opt.input_nc, self.opt.output_nc, self.opt.ngf, self.opt.net_g[i], self.opt.norm,
                                              not self.opt.no_dropout, self.opt.init_type, self.opt.init_gain, self.gpu_ids, self.opt.padding)
-
-            print(self.opt.input_nc, self.opt.output_nc, self.opt.ngf, self.opt.net_g, self.opt.norm,
-                  not self.opt.no_dropout, self.opt.init_type, self.opt.init_gain, self.gpu_ids, self.opt.padding)
-            print('***************************************')
+
         for i in range(self.mod_gen_no):
             if self.opt.seg_gen:
-
-                # self.netGS[i] = networks.define_G(self.opt.input_nc, self.opt.output_nc, self.opt.ngf, self.opt.net_gs, self.opt.norm,
-                #                                   not self.opt.no_dropout, self.opt.init_type, self.opt.init_gain, self.gpu_ids)
-                # else:
-                self.netGS[i] = networks.define_G(self.opt.input_nc * 3, self.opt.output_nc, self.opt.ngf, self.opt.net_gs, self.opt.norm,
+                self.netGS[i] = networks.define_G(self.opt.input_nc * 3, self.opt.output_nc, self.opt.ngf, self.opt.net_gs[i], self.opt.norm,
                                                   not self.opt.no_dropout, self.opt.init_type, self.opt.init_gain, self.gpu_ids)
 
         if self.is_train:  # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc
@@ -99,11 +94,6 @@ class DeepLIIFExtModel(BaseModel):
                                                   self.gpu_ids)
             for i in range(self.mod_gen_no):
                 if self.opt.seg_gen:
-                    # if i == 0:
-                    #     self.netDS[i] = networks.define_D(self.opt.input_nc + self.opt.output_nc, self.opt.ndf, self.opt.net_ds,
-                    #                                       self.opt.n_layers_D, self.opt.norm, self.opt.init_type, self.opt.init_gain,
-                    #                                       self.gpu_ids)
-                    # else:
                     self.netDS[i] = networks.define_D(self.opt.input_nc * 3 + self.opt.output_nc, self.opt.ndf, self.opt.net_ds,
                                                       self.opt.n_layers_D, self.opt.norm, self.opt.init_type, self.opt.init_gain,
                                                       self.gpu_ids)
@@ -113,9 +103,7 @@ class DeepLIIFExtModel(BaseModel):
             # define loss functions
             self.criterionGAN_mod = networks.GANLoss(self.opt.gan_mode).to(self.device)
             self.criterionGAN_seg = networks.GANLoss(self.opt.gan_mode_s).to(self.device)
-
             self.criterionSmoothL1 = torch.nn.SmoothL1Loss()
-
             self.criterionVGG = networks.VGGLoss().to(self.device)
 
             # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
@@ -125,7 +113,11 @@ class DeepLIIFExtModel(BaseModel):
             for i in range(len(self.netGS)):
                 if self.netGS[i]:
                     params += list(self.netGS[i].parameters())
-
+            try:
+                self.optimizer_G = get_optimizer(opt.optimizer)(params, lr=opt.lr_g, betas=(opt.beta1, 0.999))
+            except:
+                print(f'betas are not used for optimizer torch.optim.{opt.optimizer} in generators')
+                self.optimizer_G = get_optimizer(opt.optimizer)(params, lr=opt.lr_g)
 
             params = []
             for i in range(len(self.netD)):
@@ -133,7 +125,11 @@ class DeepLIIFExtModel(BaseModel):
             for i in range(len(self.netDS)):
                 if self.netDS[i]:
                     params += list(self.netDS[i].parameters())
-
+            try:
+                self.optimizer_D = get_optimizer(opt.optimizer)(params, lr=opt.lr_d, betas=(opt.beta1, 0.999))
+            except:
+                print(f'betas are not used for optimizer torch.optim.{opt.optimizer} in discriminators')
+                self.optimizer_D = get_optimizer(opt.optimizer)(params, lr=opt.lr_d)
 
             self.optimizers.append(self.optimizer_G)
             self.optimizers.append(self.optimizer_D)
@@ -295,3 +291,29 @@ class DeepLIIFExtModel(BaseModel):
         self.optimizer_G.zero_grad()  # set G's gradients to zero
         self.backward_G()  # calculate graidents for G
         self.optimizer_G.step()  # udpate G's weights
+
+    def calculate_losses(self):
+        """
+        Calculate losses but do not optimize parameters. Used in validation loss calculation during training.
+        """
+        self.forward()  # compute fake images: G(A)
+        # update D
+        for i in range(self.mod_gen_no):
+            self.set_requires_grad(self.netD[i], True)  # enable backprop for D1
+        for i in range(self.mod_gen_no):
+            if self.netDS[i]:
+                self.set_requires_grad(self.netDS[i], True)
+
+        self.optimizer_D.zero_grad()  # set D's gradients to zero
+        self.backward_D()  # calculate gradients for D
+
+        # update G
+        for i in range(self.mod_gen_no):
+            self.set_requires_grad(self.netD[i], False)
+        for i in range(self.mod_gen_no):
+            if self.netDS[i]:
+                self.set_requires_grad(self.netDS[i], False)
+
+        self.optimizer_G.zero_grad()  # set G's gradients to zero
+        self.backward_G()  # calculate graidents for G
+
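Both this model and the new `CycleGANModel` construct their optimizers with the same fallback: try the Adam-style `betas` keyword first, then retry without it for optimizers that reject it. A standalone sketch of the pattern using `torch.optim` directly, since `get_optimizer` itself is not shown in this diff; it catches `TypeError` rather than the bare `except` used above:

```python
import torch


def build_optimizer(name, params, lr, beta1=0.5):
    """Resolve a torch.optim class by name and construct it, passing
    betas only to optimizers (e.g. Adam) that accept the keyword."""
    opt_cls = getattr(torch.optim, name)  # e.g. 'Adam', 'SGD', 'RMSprop'
    try:
        return opt_cls(params, lr=lr, betas=(beta1, 0.999))
    except TypeError:  # SGD, RMSprop, etc. do not take betas
        print(f'betas are not used for optimizer torch.optim.{name}')
        return opt_cls(params, lr=lr)
```

Narrowing the exception to `TypeError` keeps genuine construction errors (bad learning rate, empty parameter list) from being silently swallowed, which the bare `except` in the diff would hide.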