deepliif-1.1.6-py3-none-any.whl → deepliif-1.1.8-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli.py +76 -102
- deepliif/data/aligned_dataset.py +33 -7
- deepliif/models/DeepLIIFExt_model.py +297 -0
- deepliif/models/DeepLIIF_model.py +10 -5
- deepliif/models/__init__.py +262 -168
- deepliif/models/base_model.py +54 -8
- deepliif/options/__init__.py +101 -0
- deepliif/options/base_options.py +7 -6
- deepliif/postprocessing.py +285 -246
- {deepliif-1.1.6.dist-info → deepliif-1.1.8.dist-info}/METADATA +26 -12
- {deepliif-1.1.6.dist-info → deepliif-1.1.8.dist-info}/RECORD +15 -14
- {deepliif-1.1.6.dist-info → deepliif-1.1.8.dist-info}/LICENSE.md +0 -0
- {deepliif-1.1.6.dist-info → deepliif-1.1.8.dist-info}/WHEEL +0 -0
- {deepliif-1.1.6.dist-info → deepliif-1.1.8.dist-info}/entry_points.txt +0 -0
- {deepliif-1.1.6.dist-info → deepliif-1.1.8.dist-info}/top_level.txt +0 -0
deepliif/models/DeepLIIFExt_model.py (new file)
@@ -0,0 +1,297 @@
+import torch
+from .base_model import BaseModel
+from . import networks
+
+
+class DeepLIIFExtModel(BaseModel):
+    """This class implements the DeepLIIF model, for learning a mapping from input images to modalities given paired data."""
+
+    def __init__(self, opt):
+        """Initialize the DeepLIIF class.
+
+        Parameters:
+            opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
+        """
+        BaseModel.__init__(self, opt)
+
+        self.mod_gen_no = self.opt.modalities_no
+        # self.seg_gen_no = self.opt.modalities_no + 1
+
+        # weights of the modalities in generating segmentation mask
+        self.seg_weights = [0, 0, 0]
+        if opt.seg_gen:
+            self.seg_weights = [0.3] * self.mod_gen_no
+            self.seg_weights[1] = 0.4
+
+        # self.seg_weights = opt.seg_weights
+        # assert len(self.seg_weights) == self.seg_gen_no, 'The number of the segmentation weights (seg_weights) is not equal to the number of target images (modalities_no)!'
+        # print(self.seg_weights)
+        # loss weights in calculating the final loss
+        self.loss_G_weights = [1 / self.mod_gen_no] * self.mod_gen_no
+        self.loss_GS_weights = [1 / self.mod_gen_no] * self.mod_gen_no
+
+        self.loss_D_weights = [1 / self.mod_gen_no] * self.mod_gen_no
+        self.loss_DS_weights = [1 / self.mod_gen_no] * self.mod_gen_no
+
+        # self.gpu_ids is a possibly modified one for model initialization
+        # self.opt.gpu_ids is the original one received in the command
+        if not opt.is_train:
+            self.gpu_ids = []  # avoid the models being loaded as DP
+        else:
+            self.gpu_ids = opt.gpu_ids
+
+        self.loss_names = []
+        self.visual_names = ['real_A']
+        # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
+        for i in range(1, self.mod_gen_no + 1):
+            self.loss_names.extend(['G_GAN_' + str(i), 'G_L1_' + str(i), 'D_real_' + str(i), 'D_fake_' + str(i)])
+            self.visual_names.extend(['fake_B_' + str(i), 'real_B_' + str(i)])
+        if self.opt.seg_gen:
+            for i in range(1, self.mod_gen_no + 1):
+                self.loss_names.extend(['GS_GAN_' + str(i), 'GS_L1_' + str(i), 'DS_real_' + str(i), 'DS_fake_' + str(i)])
+                self.visual_names.extend(['fake_BS_' + str(i), 'real_BS_' + str(i)])
+
+        # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
+        # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
+        if self.is_train:
+            self.model_names = []
+            for i in range(1, self.mod_gen_no + 1):
+                self.model_names.extend(['G_' + str(i), 'D_' + str(i)])
+
+            if self.opt.seg_gen:
+                for i in range(1, self.mod_gen_no + 1):
+                    self.model_names.extend(['GS_' + str(i), 'DS_' + str(i)])
+
+        else:  # during test time, only load G
+            self.model_names = []
+            for i in range(1, self.mod_gen_no + 1):
+                self.model_names.extend(['G_' + str(i)])
+
+            if self.opt.seg_gen:
+                for i in range(1, self.mod_gen_no + 1):
+                    self.model_names.extend(['GS_' + str(i)])
+
+        # define networks (both generator and discriminator)
+        self.netG = [None for _ in range(self.mod_gen_no)]
+        self.netGS = [None for _ in range(self.mod_gen_no)]
+        for i in range(self.mod_gen_no):
+            self.netG[i] = networks.define_G(self.opt.input_nc, self.opt.output_nc, self.opt.ngf, self.opt.net_g, self.opt.norm,
+                                             not self.opt.no_dropout, self.opt.init_type, self.opt.init_gain, self.gpu_ids, self.opt.padding)
+        print('***************************************')
+        print(self.opt.input_nc, self.opt.output_nc, self.opt.ngf, self.opt.net_g, self.opt.norm,
+              not self.opt.no_dropout, self.opt.init_type, self.opt.init_gain, self.gpu_ids, self.opt.padding)
+        print('***************************************')
+        for i in range(self.mod_gen_no):
+            if self.opt.seg_gen:
+                # if i == 0:
+                #     self.netGS[i] = networks.define_G(self.opt.input_nc, self.opt.output_nc, self.opt.ngf, self.opt.net_gs, self.opt.norm,
+                #                                       not self.opt.no_dropout, self.opt.init_type, self.opt.init_gain, self.gpu_ids)
+                # else:
+                self.netGS[i] = networks.define_G(self.opt.input_nc * 3, self.opt.output_nc, self.opt.ngf, self.opt.net_gs, self.opt.norm,
+                                                  not self.opt.no_dropout, self.opt.init_type, self.opt.init_gain, self.gpu_ids)
+
+        if self.is_train:  # define a discriminator; conditional GANs need to take both input and output images; therefore, #channels for D is input_nc + output_nc
+            self.netD = [None for _ in range(self.mod_gen_no)]
+            self.netDS = [None for _ in range(self.mod_gen_no)]
+            for i in range(self.mod_gen_no):
+                self.netD[i] = networks.define_D(self.opt.input_nc + self.opt.output_nc, self.opt.ndf, self.opt.net_d,
+                                                 self.opt.n_layers_D, self.opt.norm, self.opt.init_type, self.opt.init_gain,
+                                                 self.gpu_ids)
+            for i in range(self.mod_gen_no):
+                if self.opt.seg_gen:
+                    # if i == 0:
+                    #     self.netDS[i] = networks.define_D(self.opt.input_nc + self.opt.output_nc, self.opt.ndf, self.opt.net_ds,
+                    #                                       self.opt.n_layers_D, self.opt.norm, self.opt.init_type, self.opt.init_gain,
+                    #                                       self.gpu_ids)
+                    # else:
+                    self.netDS[i] = networks.define_D(self.opt.input_nc * 3 + self.opt.output_nc, self.opt.ndf, self.opt.net_ds,
+                                                      self.opt.n_layers_D, self.opt.norm, self.opt.init_type, self.opt.init_gain,
+                                                      self.gpu_ids)
+
+
+        if self.is_train:
+            # define loss functions
+            self.criterionGAN_mod = networks.GANLoss(self.opt.gan_mode).to(self.device)
+            self.criterionGAN_seg = networks.GANLoss(self.opt.gan_mode_s).to(self.device)
+
+            self.criterionSmoothL1 = torch.nn.SmoothL1Loss()
+
+            self.criterionVGG = networks.VGGLoss().to(self.device)
+
+            # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
+            params = []
+            for i in range(len(self.netG)):
+                params += list(self.netG[i].parameters())
+            for i in range(len(self.netGS)):
+                if self.netGS[i]:
+                    params += list(self.netGS[i].parameters())
+            self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
+
+            params = []
+            for i in range(len(self.netD)):
+                params += list(self.netD[i].parameters())
+            for i in range(len(self.netDS)):
+                if self.netDS[i]:
+                    params += list(self.netDS[i].parameters())
+            self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
+
+            self.optimizers.append(self.optimizer_G)
+            self.optimizers.append(self.optimizer_D)
+
+    def set_input(self, input):
+        """
+        Unpack input data from the dataloader and perform necessary pre-processing steps.
+
+        :param input (dict): include the input image and the output modalities
+        """
+        self.real_A = input['A'].to(self.device)
+
+        self.real_B_array = input['B']
+        self.real_BS_array = input['BS']
+        self.real_B = []
+        self.real_BS = []
+        for i in range(len(self.real_B_array)):
+            self.real_B.append(self.real_B_array[i].to(self.device))
+        for i in range(len(self.real_BS_array)):
+            self.real_BS.append(self.real_BS_array[i].to(self.device))
+
+        self.real_concatenated = []
+        if self.opt.seg_gen:
+            for i in range(self.mod_gen_no):
+                self.real_concatenated.append(torch.cat([self.real_A, self.real_B[0], self.real_B[i]], 1))
+        self.image_paths = input['A_paths']
+
+    def forward(self):
+        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
+        self.fake_B = []
+        for i in range(self.mod_gen_no):
+            self.fake_B.append(self.netG[i](self.real_A))
+
+        self.fake_BS = []
+
+        for i in range(self.mod_gen_no):
+            if self.netGS[i]:
+                # if i == 0:
+                #     self.fake_BS.append(self.netGS[i](self.fake_B[0]))
+                # else:
+                self.fake_BS.append(self.netGS[i](torch.cat([self.real_A, self.fake_B[0], self.fake_B[i]], 1)))
+
+
+    def backward_D(self):
+        """Calculate GAN loss for the discriminators"""
+
+        pred_fake = []
+        for i in range(self.mod_gen_no):
+            pred_fake.append(self.netD[i](torch.cat((self.real_A, self.fake_B[i]), 1).detach()))
+
+        pred_fake_s = []
+        for i in range(self.mod_gen_no):
+            if self.netDS[i]:
+                pred_fake_s.append(self.netDS[i](torch.cat((self.real_concatenated[i], self.fake_BS[i]), 1).detach()))
+
+
+        self.loss_D_fake = []
+        for i in range(self.mod_gen_no):
+            self.loss_D_fake.append(self.criterionGAN_mod(pred_fake[i], False))
+
+        self.loss_DS_fake = []
+        if self.opt.seg_gen:
+            for i in range(self.mod_gen_no):
+                self.loss_DS_fake.append(self.criterionGAN_seg(pred_fake_s[i], False))
+
+        pred_real = []
+        for i in range(self.mod_gen_no):
+            pred_real.append(self.netD[i](torch.cat((self.real_A, self.real_B[i]), 1)))
+
+        pred_real_s = []
+        for i in range(self.mod_gen_no):
+            if self.netDS[i]:
+                pred_real_s.append(self.netDS[i](torch.cat((self.real_concatenated[i], self.real_BS[i]), 1)))
+
+
+        self.loss_D_real = []
+        for i in range(self.mod_gen_no):
+            self.loss_D_real.append(self.criterionGAN_mod(pred_real[i], True))
+
+        self.loss_DS_real = []
+        if self.opt.seg_gen:
+            for i in range(self.mod_gen_no):
+                self.loss_DS_real.append(self.criterionGAN_seg(pred_real_s[i], True))
+
+        # combine losses and calculate gradients
+        # self.loss_D = (self.loss_D_fake[0] + self.loss_D_real[0]) * 0.5 * self.loss_D_weights[0]
+        self.loss_D = torch.tensor(0., device=self.device)
+        for i in range(0, self.mod_gen_no):
+            self.loss_D += (self.loss_D_fake[i] + self.loss_D_real[i]) * 0.5 * self.loss_D_weights[i]
+        if self.opt.seg_gen:
+            for i in range(0, self.mod_gen_no):
+                self.loss_D += (self.loss_DS_fake[i] + self.loss_DS_real[i]) * 0.5 * self.loss_DS_weights[i]
+
+        self.loss_D.backward()
+
+    def backward_G(self):
+        """Calculate GAN and L1 loss for the generator"""
+        pred_fake = []
+        for i in range(self.mod_gen_no):
+            pred_fake.append(self.netD[i](torch.cat((self.real_A, self.fake_B[i]), 1)))
+
+        pred_fake_s = []
+        for i in range(self.mod_gen_no):
+            if self.netDS[i]:
+                pred_fake_s.append(self.netDS[i](torch.cat((self.real_concatenated[i], self.fake_BS[i]), 1)))
+
+
+        self.loss_G_GAN = []
+        self.loss_GS_GAN = []
+        for i in range(self.mod_gen_no):
+            self.loss_G_GAN.append(self.criterionGAN_mod(pred_fake[i], True))
+        if self.opt.seg_gen:
+            for i in range(self.mod_gen_no):
+                self.loss_GS_GAN.append(self.criterionGAN_mod(pred_fake_s[i], True))
+
+        # Second, G(A) = B
+        self.loss_G_L1 = []
+        self.loss_GS_L1 = []
+        for i in range(self.mod_gen_no):
+            self.loss_G_L1.append(self.criterionSmoothL1(self.fake_B[i], self.real_B[i]) * self.opt.lambda_L1)
+        if self.opt.seg_gen:
+            for i in range(self.mod_gen_no):
+                self.loss_GS_L1.append(self.criterionSmoothL1(self.fake_BS[i], self.real_BS[i]) * self.opt.lambda_L1)
+
+        # self.loss_G_VGG = []
+        # for i in range(self.mod_gen_no):
+        #     self.loss_G_VGG.append(self.criterionVGG(self.fake_B[i], self.real_B[i]) * self.opt.lambda_feat)
+
+        # self.loss_G = (self.loss_G_GAN[0] + self.loss_G_L1[0]) * self.loss_G_weights[0]
+        self.loss_G = torch.tensor(0., device=self.device)
+        for i in range(0, self.mod_gen_no):
+            self.loss_G += (self.loss_G_GAN[i] + self.loss_G_L1[i]) * self.loss_G_weights[i]
+            # self.loss_G += (self.loss_G_GAN[i] + self.loss_G_L1[i] + self.loss_G_VGG[i]) * self.loss_G_weights[i]
+        if self.opt.seg_gen:
+            for i in range(0, self.mod_gen_no):
+                self.loss_G += (self.loss_GS_GAN[i] + self.loss_GS_L1[i]) * self.loss_GS_weights[i]
+        self.loss_G.backward()
+
+    def optimize_parameters(self):
+        self.forward()                   # compute fake images: G(A)
+        # update D
+        for i in range(self.mod_gen_no):
+            self.set_requires_grad(self.netD[i], True)  # enable backprop for the discriminators
+        for i in range(self.mod_gen_no):
+            if self.netDS[i]:
+                self.set_requires_grad(self.netDS[i], True)
+
+        self.optimizer_D.zero_grad()     # set D's gradients to zero
+        self.backward_D()                # calculate gradients for D
+        self.optimizer_D.step()          # update D's weights
+
+        # update G
+        for i in range(self.mod_gen_no):
+            self.set_requires_grad(self.netD[i], False)
+        for i in range(self.mod_gen_no):
+            if self.netDS[i]:
+                self.set_requires_grad(self.netDS[i], False)
+
+        self.optimizer_G.zero_grad()     # set G's gradients to zero
+        self.backward_G()                # calculate gradients for G
+        self.optimizer_G.step()          # update G's weights
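The new DeepLIIFExt model above wires one generator/discriminator pair per modality (plus optional segmentation branches) into a pix2pix-style train step: `forward()` produces the fake modalities, `backward_D()` updates the discriminators, and `backward_G()` updates the generators. A minimal driving-loop sketch, assuming `opt` comes from the package's own option parsing (with training enabled) and `dataloader` yields the `{'A', 'B', 'BS', 'A_paths'}` dictionaries produced by `deepliif/data/aligned_dataset.py`; the helper name `train_deepliif_ext` is ours, not part of the package:

```python
from deepliif.models.DeepLIIFExt_model import DeepLIIFExtModel


def train_deepliif_ext(opt, dataloader, n_epochs):
    """Illustrative sketch only: drive the model the way the train/test scripts referenced in its comments do."""
    model = DeepLIIFExtModel(opt)        # builds netG/netGS (and netD/netDS when opt.is_train)
    model.setup(opt)                     # BaseModel.setup creates the schedulers, per the comment in __init__
    for epoch in range(n_epochs):
        for batch in dataloader:         # batch: {'A': tensor, 'B': [tensors], 'BS': [tensors], 'A_paths': ...}
            model.set_input(batch)       # move the input tile and its target modalities to the device
            model.optimize_parameters()  # forward(), then the D update, then the G update shown above
        print(epoch, model.get_current_losses())
    return model
```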
deepliif/models/DeepLIIF_model.py
@@ -20,11 +20,16 @@ class DeepLIIFModel(BaseModel):
         # loss weights in calculating the final loss
         self.loss_G_weights = [0.2, 0.2, 0.2, 0.2, 0.2]
         self.loss_D_weights = [0.2, 0.2, 0.2, 0.2, 0.2]
+
+        if not opt.is_train:
+            self.gpu_ids = []  # avoid the models being loaded as DP
+        else:
+            self.gpu_ids = opt.gpu_ids
 
         self.loss_names = []
         self.visual_names = ['real_A']
         # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
-        for i in range(1, self.opt.
+        for i in range(1, self.opt.modalities_no + 1 + 1):
             self.loss_names.extend(['G_GAN_' + str(i), 'G_L1_' + str(i), 'D_real_' + str(i), 'D_fake_' + str(i)])
             self.visual_names.extend(['fake_B_' + str(i), 'real_B_' + str(i)])
 
@@ -32,17 +37,17 @@ class DeepLIIFModel(BaseModel):
         # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
         if self.is_train:
             self.model_names = []
-            for i in range(1, self.opt.
+            for i in range(1, self.opt.modalities_no + 1):
                 self.model_names.extend(['G' + str(i), 'D' + str(i)])
 
-            for i in range(1, self.opt.
+            for i in range(1, self.opt.modalities_no + 1 + 1):
                 self.model_names.extend(['G5' + str(i), 'D5' + str(i)])
         else:  # during test time, only load G
             self.model_names = []
-            for i in range(1, self.opt.
+            for i in range(1, self.opt.modalities_no + 1):
                 self.model_names.extend(['G' + str(i)])
 
-            for i in range(1, self.opt.
+            for i in range(1, self.opt.modalities_no + 1 + 1):
                 self.model_names.extend(['G5' + str(i)])
 
         # define networks (both generator and discriminator)
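The DeepLIIF_model.py hunks replace previously hard-coded loop bounds with `self.opt.modalities_no`, so the registered loss and network names now scale with that option. A tiny stand-alone illustration of what the name-registration loops produce during training, assuming `modalities_no = 4` (consistent with the five-entry `loss_G_weights`/`loss_D_weights` lists above); the extra `+ 1` in the second loop accounts for the additional segmentation generators/discriminators (`G5*`/`D5*`):

```python
# Stand-alone illustration of the loops above; modalities_no = 4 is an assumption
# implied by the five-entry loss-weight lists, not read from the package options.
modalities_no = 4

model_names = []
for i in range(1, modalities_no + 1):        # one G/D pair per modality
    model_names.extend(['G' + str(i), 'D' + str(i)])
for i in range(1, modalities_no + 1 + 1):    # modalities_no + 1 segmentation branches
    model_names.extend(['G5' + str(i), 'D5' + str(i)])

print(model_names)
# ['G1', 'D1', 'G2', 'D2', 'G3', 'D3', 'G4', 'D4',
#  'G51', 'D51', 'G52', 'D52', 'G53', 'D53', 'G54', 'D54', 'G55', 'D55']
```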