deepliif 1.1.7__py3-none-any.whl → 1.1.8__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,297 @@
1
+ import torch
2
+ from .base_model import BaseModel
3
+ from . import networks
4
+
5
+
6
+ class DeepLIIFExtModel(BaseModel):
7
+ """ This class implements the DeepLIIF model, for learning a mapping from input images to modalities given paired data."""
8
+
9
+ def __init__(self, opt):
10
+ """Initialize the DeepLIIF class.
11
+
12
+ Parameters:
13
+ opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
14
+ """
15
+ BaseModel.__init__(self, opt)
16
+
17
+ self.mod_gen_no = self.opt.modalities_no
18
+ # self.seg_gen_no = self.opt.modalities_no + 1
19
+
20
+ # weights of the modalities in generating segmentation mask
21
+ self.seg_weights = [0, 0, 0]
22
+ if opt.seg_gen:
23
+ self.seg_weights = [0.3] * self.mod_gen_no
24
+ self.seg_weights[1] = 0.4
25
+
26
+ # self.seg_weights = opt.seg_weights
27
+ # assert len(self.seg_weights) == self.seg_gen_no, 'The number of the segmentation weights (seg_weights) is not equal to the number of target images (modalities_no)!'
28
+ # print(self.seg_weights)
29
+ # loss weights in calculating the final loss
30
+ self.loss_G_weights = [1 / self.mod_gen_no] * self.mod_gen_no
31
+ self.loss_GS_weights = [1 / self.mod_gen_no] * self.mod_gen_no
32
+
33
+ self.loss_D_weights = [1 / self.mod_gen_no] * self.mod_gen_no
34
+ self.loss_DS_weights = [1 / self.mod_gen_no] * self.mod_gen_no
35
+
36
+ # self.gpu_ids is a possibly modifed one for model initialization
37
+ # self.opt.gpu_ids is the original one received in the command
38
+ if not opt.is_train:
39
+ self.gpu_ids = [] # avoid the models being loaded as DP
40
+ else:
41
+ self.gpu_ids = opt.gpu_ids
42
+
43
+ self.loss_names = []
44
+ self.visual_names = ['real_A']
45
+ # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
46
+ for i in range(1, self.mod_gen_no + 1):
47
+ self.loss_names.extend(['G_GAN_' + str(i), 'G_L1_' + str(i), 'D_real_' + str(i), 'D_fake_' + str(i)])
48
+ self.visual_names.extend(['fake_B_' + str(i), 'real_B_' + str(i)])
49
+ if self.opt.seg_gen:
50
+ for i in range(1, self.mod_gen_no + 1):
51
+ self.loss_names.extend(['GS_GAN_' + str(i), 'GS_L1_' + str(i), 'DS_real_' + str(i), 'DS_fake_' + str(i)])
52
+ self.visual_names.extend(['fake_BS_' + str(i), 'real_BS_' + str(i)])
53
+
54
+ # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
55
+ # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
56
+ if self.is_train:
57
+ self.model_names = []
58
+ for i in range(1, self.mod_gen_no + 1):
59
+ self.model_names.extend(['G_' + str(i), 'D_' + str(i)])
60
+
61
+ if self.opt.seg_gen:
62
+ for i in range(1, self.mod_gen_no + 1):
63
+ self.model_names.extend(['GS_' + str(i), 'DS_' + str(i)])
64
+
65
+ else: # during test time, only load G
66
+ self.model_names = []
67
+ for i in range(1, self.mod_gen_no + 1):
68
+ self.model_names.extend(['G_' + str(i)])
69
+
70
+ if self.opt.seg_gen:
71
+ for i in range(1, self.mod_gen_no + 1):
72
+ self.model_names.extend(['GS_' + str(i)])
73
+
74
+ # define networks (both generator and discriminator)
75
+ self.netG = [None for _ in range(self.mod_gen_no)]
76
+ self.netGS = [None for _ in range(self.mod_gen_no)]
77
+ for i in range(self.mod_gen_no):
78
+ self.netG[i] = networks.define_G(self.opt.input_nc, self.opt.output_nc, self.opt.ngf, self.opt.net_g, self.opt.norm,
79
+ not self.opt.no_dropout, self.opt.init_type, self.opt.init_gain, self.gpu_ids, self.opt.padding)
80
+ print('***************************************')
81
+ print(self.opt.input_nc, self.opt.output_nc, self.opt.ngf, self.opt.net_g, self.opt.norm,
82
+ not self.opt.no_dropout, self.opt.init_type, self.opt.init_gain, self.gpu_ids, self.opt.padding)
83
+ print('***************************************')
84
+ for i in range(self.mod_gen_no):
85
+ if self.opt.seg_gen:
86
+ # if i == 0:
87
+ # self.netGS[i] = networks.define_G(self.opt.input_nc, self.opt.output_nc, self.opt.ngf, self.opt.net_gs, self.opt.norm,
88
+ # not self.opt.no_dropout, self.opt.init_type, self.opt.init_gain, self.gpu_ids)
89
+ # else:
90
+ self.netGS[i] = networks.define_G(self.opt.input_nc * 3, self.opt.output_nc, self.opt.ngf, self.opt.net_gs, self.opt.norm,
91
+ not self.opt.no_dropout, self.opt.init_type, self.opt.init_gain, self.gpu_ids)
92
+
93
+ if self.is_train: # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc
94
+ self.netD = [None for _ in range(self.mod_gen_no)]
95
+ self.netDS = [None for _ in range(self.mod_gen_no)]
96
+ for i in range(self.mod_gen_no):
97
+ self.netD[i] = networks.define_D(self.opt.input_nc + self.opt.output_nc, self.opt.ndf, self.opt.net_d,
98
+ self.opt.n_layers_D, self.opt.norm, self.opt.init_type, self.opt.init_gain,
99
+ self.gpu_ids)
100
+ for i in range(self.mod_gen_no):
101
+ if self.opt.seg_gen:
102
+ # if i == 0:
103
+ # self.netDS[i] = networks.define_D(self.opt.input_nc + self.opt.output_nc, self.opt.ndf, self.opt.net_ds,
104
+ # self.opt.n_layers_D, self.opt.norm, self.opt.init_type, self.opt.init_gain,
105
+ # self.gpu_ids)
106
+ # else:
107
+ self.netDS[i] = networks.define_D(self.opt.input_nc * 3 + self.opt.output_nc, self.opt.ndf, self.opt.net_ds,
108
+ self.opt.n_layers_D, self.opt.norm, self.opt.init_type, self.opt.init_gain,
109
+ self.gpu_ids)
110
+
111
+
112
+ if self.is_train:
113
+ # define loss functions
114
+ self.criterionGAN_mod = networks.GANLoss(self.opt.gan_mode).to(self.device)
115
+ self.criterionGAN_seg = networks.GANLoss(self.opt.gan_mode_s).to(self.device)
116
+
117
+ self.criterionSmoothL1 = torch.nn.SmoothL1Loss()
118
+
119
+ self.criterionVGG = networks.VGGLoss().to(self.device)
120
+
121
+ # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
122
+ params = []
123
+ for i in range(len(self.netG)):
124
+ params += list(self.netG[i].parameters())
125
+ for i in range(len(self.netGS)):
126
+ if self.netGS[i]:
127
+ params += list(self.netGS[i].parameters())
128
+ self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
129
+
130
+ params = []
131
+ for i in range(len(self.netD)):
132
+ params += list(self.netD[i].parameters())
133
+ for i in range(len(self.netDS)):
134
+ if self.netDS[i]:
135
+ params += list(self.netDS[i].parameters())
136
+ self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
137
+
138
+ self.optimizers.append(self.optimizer_G)
139
+ self.optimizers.append(self.optimizer_D)
140
+
141
+ def set_input(self, input):
142
+ """
143
+ Unpack input data from the dataloader and perform necessary pre-processing steps.
144
+
145
+ :param input (dict): include the input image and the output modalities
146
+ """
147
+ self.real_A = input['A'].to(self.device)
148
+
149
+ self.real_B_array = input['B']
150
+ self.real_BS_array = input['BS']
151
+ self.real_B = []
152
+ self.real_BS = []
153
+ for i in range(len(self.real_B_array)):
154
+ self.real_B.append(self.real_B_array[i].to(self.device))
155
+ for i in range(len(self.real_BS_array)):
156
+ self.real_BS.append(self.real_BS_array[i].to(self.device))
157
+
158
+ self.real_concatenated = []
159
+ if self.opt.seg_gen:
160
+ for i in range(self.mod_gen_no):
161
+ self.real_concatenated.append(torch.cat([self.real_A, self.real_B[0], self.real_B[i]], 1))
162
+ self.image_paths = input['A_paths']
163
+
164
+ def forward(self):
165
+ """Run forward pass; called by both functions <optimize_parameters> and <test>."""
166
+ self.fake_B = []
167
+ for i in range(self.mod_gen_no):
168
+ self.fake_B.append(self.netG[i](self.real_A))
169
+
170
+ self.fake_BS = []
171
+
172
+ for i in range(self.mod_gen_no):
173
+ if self.netGS[i]:
174
+ # if i == 0:
175
+ # self.fake_BS.append(self.netGS[i](self.fake_B[0]))
176
+ # else:
177
+ self.fake_BS.append(self.netGS[i](torch.cat([self.real_A, self.fake_B[0], self.fake_B[i]], 1)))
178
+
179
+
180
+ def backward_D(self):
181
+ """Calculate GAN loss for the discriminators"""
182
+
183
+ pred_fake = []
184
+ for i in range(self.mod_gen_no):
185
+ pred_fake.append(self.netD[i](torch.cat((self.real_A, self.fake_B[i]), 1).detach()))
186
+
187
+ pred_fake_s = []
188
+ for i in range(self.mod_gen_no):
189
+ if self.netDS[i]:
190
+ pred_fake_s.append(self.netDS[i](torch.cat((self.real_concatenated[i], self.fake_BS[i]), 1).detach()))
191
+
192
+
193
+ self.loss_D_fake = []
194
+ for i in range(self.mod_gen_no):
195
+ self.loss_D_fake.append(self.criterionGAN_mod(pred_fake[i], False))
196
+
197
+ self.loss_DS_fake = []
198
+ if self.opt.seg_gen:
199
+ for i in range(self.mod_gen_no):
200
+ self.loss_DS_fake.append(self.criterionGAN_seg(pred_fake_s[i], False))
201
+
202
+ pred_real = []
203
+ for i in range(self.mod_gen_no):
204
+ pred_real.append(self.netD[i](torch.cat((self.real_A, self.real_B[i]), 1)))
205
+
206
+ pred_real_s = []
207
+ for i in range(self.mod_gen_no):
208
+ if self.netDS[i]:
209
+ pred_real_s.append(self.netDS[i](torch.cat((self.real_concatenated[i], self.real_BS[i]), 1)))
210
+
211
+
212
+ self.loss_D_real = []
213
+ for i in range(self.mod_gen_no):
214
+ self.loss_D_real.append(self.criterionGAN_mod(pred_real[i], True))
215
+
216
+ self.loss_DS_real = []
217
+ if self.opt.seg_gen:
218
+ for i in range(self.mod_gen_no):
219
+ self.loss_DS_real.append(self.criterionGAN_seg(pred_real_s[i], True))
220
+
221
+ # combine losses and calculate gradients
222
+ # self.loss_D = (self.loss_D_fake[0] + self.loss_D_real[0]) * 0.5 * self.loss_D_weights[0]
223
+ self.loss_D = torch.tensor(0., device=self.device)
224
+ for i in range(0, self.mod_gen_no):
225
+ self.loss_D += (self.loss_D_fake[i] + self.loss_D_real[i]) * 0.5 * self.loss_D_weights[i]
226
+ if self.opt.seg_gen:
227
+ for i in range(0, self.mod_gen_no):
228
+ self.loss_D += (self.loss_DS_fake[i] + self.loss_DS_real[i]) * 0.5 * self.loss_DS_weights[i]
229
+
230
+ self.loss_D.backward()
231
+
232
+ def backward_G(self):
233
+ """Calculate GAN and L1 loss for the generator"""
234
+ pred_fake = []
235
+ for i in range(self.mod_gen_no):
236
+ pred_fake.append(self.netD[i](torch.cat((self.real_A, self.fake_B[i]), 1)))
237
+
238
+ pred_fake_s = []
239
+ for i in range(self.mod_gen_no):
240
+ if self.netDS[i]:
241
+ pred_fake_s.append(self.netDS[i](torch.cat((self.real_concatenated[i], self.fake_BS[i]), 1)))
242
+
243
+
244
+ self.loss_G_GAN = []
245
+ self.loss_GS_GAN = []
246
+ for i in range(self.mod_gen_no):
247
+ self.loss_G_GAN.append(self.criterionGAN_mod(pred_fake[i], True))
248
+ if self.opt.seg_gen:
249
+ for i in range(self.mod_gen_no):
250
+ self.loss_GS_GAN.append(self.criterionGAN_mod(pred_fake_s[i], True))
251
+
252
+ # Second, G(A) = B
253
+ self.loss_G_L1 = []
254
+ self.loss_GS_L1 = []
255
+ for i in range(self.mod_gen_no):
256
+ self.loss_G_L1.append(self.criterionSmoothL1(self.fake_B[i], self.real_B[i]) * self.opt.lambda_L1)
257
+ if self.opt.seg_gen:
258
+ for i in range(self.mod_gen_no):
259
+ self.loss_GS_L1.append(self.criterionSmoothL1(self.fake_BS[i], self.real_BS[i]) * self.opt.lambda_L1)
260
+
261
+ #self.loss_G_VGG = []
262
+ #for i in range(self.mod_gen_no):
263
+ # self.loss_G_VGG.append(self.criterionVGG(self.fake_B[i], self.real_B[i]) * self.opt.lambda_feat)
264
+
265
+ # self.loss_G = (self.loss_G_GAN[0] + self.loss_G_L1[0]) * self.loss_G_weights[0]
266
+ self.loss_G = torch.tensor(0., device=self.device)
267
+ for i in range(0, self.mod_gen_no):
268
+ self.loss_G += (self.loss_G_GAN[i] + self.loss_G_L1[i]) * self.loss_G_weights[i]
269
+ # self.loss_G += (self.loss_G_GAN[i] + self.loss_G_L1[i] + self.loss_G_VGG[i]) * self.loss_G_weights[i]
270
+ if self.opt.seg_gen:
271
+ for i in range(0, self.mod_gen_no):
272
+ self.loss_G += (self.loss_GS_GAN[i] + self.loss_GS_L1[i]) * self.loss_GS_weights[i]
273
+ self.loss_G.backward()
274
+
275
+ def optimize_parameters(self):
276
+ self.forward() # compute fake images: G(A)
277
+ # update D
278
+ for i in range(self.mod_gen_no):
279
+ self.set_requires_grad(self.netD[i], True) # enable backprop for D1
280
+ for i in range(self.mod_gen_no):
281
+ if self.netDS[i]:
282
+ self.set_requires_grad(self.netDS[i], True)
283
+
284
+ self.optimizer_D.zero_grad() # set D's gradients to zero
285
+ self.backward_D() # calculate gradients for D
286
+ self.optimizer_D.step() # update D's weights
287
+
288
+ # update G
289
+ for i in range(self.mod_gen_no):
290
+ self.set_requires_grad(self.netD[i], False)
291
+ for i in range(self.mod_gen_no):
292
+ if self.netDS[i]:
293
+ self.set_requires_grad(self.netDS[i], False)
294
+
295
+ self.optimizer_G.zero_grad() # set G's gradients to zero
296
+ self.backward_G() # calculate graidents for G
297
+ self.optimizer_G.step() # udpate G's weights
@@ -20,11 +20,16 @@ class DeepLIIFModel(BaseModel):
20
20
  # loss weights in calculating the final loss
21
21
  self.loss_G_weights = [0.2, 0.2, 0.2, 0.2, 0.2]
22
22
  self.loss_D_weights = [0.2, 0.2, 0.2, 0.2, 0.2]
23
+
24
+ if not opt.is_train:
25
+ self.gpu_ids = [] # avoid the models being loaded as DP
26
+ else:
27
+ self.gpu_ids = opt.gpu_ids
23
28
 
24
29
  self.loss_names = []
25
30
  self.visual_names = ['real_A']
26
31
  # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
27
- for i in range(1, self.opt.targets_no + 1):
32
+ for i in range(1, self.opt.modalities_no + 1 + 1):
28
33
  self.loss_names.extend(['G_GAN_' + str(i), 'G_L1_' + str(i), 'D_real_' + str(i), 'D_fake_' + str(i)])
29
34
  self.visual_names.extend(['fake_B_' + str(i), 'real_B_' + str(i)])
30
35
 
@@ -32,17 +37,17 @@ class DeepLIIFModel(BaseModel):
32
37
  # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
33
38
  if self.is_train:
34
39
  self.model_names = []
35
- for i in range(1, self.opt.targets_no):
40
+ for i in range(1, self.opt.modalities_no + 1):
36
41
  self.model_names.extend(['G' + str(i), 'D' + str(i)])
37
42
 
38
- for i in range(1, self.opt.targets_no+1):
43
+ for i in range(1, self.opt.modalities_no + 1 + 1):
39
44
  self.model_names.extend(['G5' + str(i), 'D5' + str(i)])
40
45
  else: # during test time, only load G
41
46
  self.model_names = []
42
- for i in range(1, self.opt.targets_no):
47
+ for i in range(1, self.opt.modalities_no + 1):
43
48
  self.model_names.extend(['G' + str(i)])
44
49
 
45
- for i in range(1, self.opt.targets_no+1):
50
+ for i in range(1, self.opt.modalities_no + 1 + 1):
46
51
  self.model_names.extend(['G5' + str(i)])
47
52
 
48
53
  # define networks (both generator and discriminator)