deepliif 1.2.3__py3-none-any.whl → 1.2.5__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -2,7 +2,8 @@ import torch
2
2
  from .base_model import BaseModel
3
3
  from . import networks
4
4
  from .networks import get_optimizer
5
-
5
+ import os
6
+ from ..util.util import get_input_id, init_input_and_mod_id
6
7
 
7
8
  class DeepLIIFModel(BaseModel):
8
9
  """ This class implements the DeepLIIF model, for learning a mapping from input images to modalities given paired data."""
@@ -20,6 +21,8 @@ class DeepLIIFModel(BaseModel):
20
21
  self.seg_weights = opt.seg_weights
21
22
  self.loss_G_weights = opt.loss_G_weights
22
23
  self.loss_D_weights = opt.loss_D_weights
24
+ self.mod_id_seg, self.input_id = init_input_and_mod_id(opt) # creates self.input_id, self.mod_id_seg
25
+ print(f'Initializing model with segmentation modality id {self.mod_id_seg}, input id {self.input_id}')
23
26
 
24
27
  if not opt.is_train:
25
28
  self.gpu_ids = [] # avoid the models being loaded as DP
@@ -29,76 +32,77 @@ class DeepLIIFModel(BaseModel):
29
32
  self.loss_names = []
30
33
  self.visual_names = ['real_A']
31
34
  # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
32
- for i in range(1, self.opt.modalities_no + 1 + 1):
33
- self.loss_names.extend(['G_GAN_' + str(i), 'G_L1_' + str(i), 'D_real_' + str(i), 'D_fake_' + str(i)])
34
- self.visual_names.extend(['fake_B_' + str(i), 'fake_B_5' + str(i), 'real_B_' + str(i)])
35
+ for i in range(self.opt.modalities_no):
36
+ self.loss_names.extend([f'G_GAN_{i+1}', f'G_L1_{i+1}', f'D_real_{i+1}', f'D_fake_{i+1}'])
37
+ self.visual_names.extend([f'fake_B_{i+1}', f'real_B_{i+1}'])
38
+ self.loss_names.extend([f'G_GAN_{self.mod_id_seg}',f'G_L1_{self.mod_id_seg}',f'D_real_{self.mod_id_seg}',f'D_fake_{self.mod_id_seg}'])
39
+
40
+ for i in range(self.opt.modalities_no+1):
41
+ self.visual_names.extend([f'fake_B_{self.mod_id_seg}{i}']) # 0 is used for the base input mod
42
+ self.visual_names.extend([f'fake_B_{self.mod_id_seg}', f'real_B_{self.mod_id_seg}'])
35
43
 
36
44
  # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
37
45
  # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
46
+ self.model_names_g = []
47
+ self.model_names_gs = []
38
48
  if self.is_train:
39
49
  self.model_names = []
40
50
  for i in range(1, self.opt.modalities_no + 1):
41
- self.model_names.extend(['G' + str(i), 'D' + str(i)])
42
-
43
- for i in range(1, self.opt.modalities_no + 1 + 1):
44
- self.model_names.extend(['G5' + str(i), 'D5' + str(i)])
51
+ self.model_names.extend([f'G{i}', f'D{i}'])
52
+ self.model_names_g.append(f'G{i}')
53
+
54
+ for i in range(self.opt.modalities_no + 1): # 0 is used for the base input mod
55
+ if self.input_id == '0':
56
+ self.model_names.extend([f'G{self.mod_id_seg}{i}', f'D{self.mod_id_seg}{i}'])
57
+ self.model_names_gs.append(f'G{self.mod_id_seg}{i}')
58
+ else:
59
+ self.model_names.extend([f'G{self.mod_id_seg}{i+1}', f'D{self.mod_id_seg}{i+1}'])
60
+ self.model_names_gs.append(f'G{self.mod_id_seg}{i+1}')
45
61
  else: # during test time, only load G
46
62
  self.model_names = []
47
63
  for i in range(1, self.opt.modalities_no + 1):
48
- self.model_names.extend(['G' + str(i)])
49
-
50
- for i in range(1, self.opt.modalities_no + 1 + 1):
51
- self.model_names.extend(['G5' + str(i)])
64
+ self.model_names.extend([f'G{i}'])
65
+ self.model_names_g.append(f'G{i}')
66
+
67
+ #input_id = get_input_id(os.path.join(opt.checkpoints_dir, opt.name))
68
+ if self.input_id == '0':
69
+ for i in range(self.opt.modalities_no + 1): # 0 is used for the base input mod
70
+ self.model_names.extend([f'G{self.mod_id_seg}{i}'])
71
+ self.model_names_gs.append(f'G{self.mod_id_seg}{i}')
72
+ else:
73
+ for i in range(self.opt.modalities_no + 1): # old setting, 1 is used for the base input mod
74
+ self.model_names.extend([f'G{self.mod_id_seg}{i+1}'])
75
+ self.model_names_gs.append(f'G{self.mod_id_seg}{i+1}')
52
76
 
53
77
  # define networks (both generator and discriminator)
54
78
  if isinstance(opt.netG, str):
55
- opt.netG = [opt.netG] * 4
79
+ opt.netG = [opt.netG] * self.opt.modalities_no
56
80
  if isinstance(opt.net_gs, str):
57
- opt.net_gs = [opt.net_gs]*5
81
+ opt.net_gs = [opt.net_gs] * (self.opt.modalities_no + 1) # +1 for base input mod
58
82
 
59
-
60
- self.netG1 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG[0], opt.norm,
61
- not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.padding)
62
- self.netG2 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG[1], opt.norm,
63
- not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.padding)
64
- self.netG3 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG[2], opt.norm,
65
- not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.padding)
66
- self.netG4 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG[3], opt.norm,
67
- not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.padding)
83
+
84
+ for i,model_name in enumerate(self.model_names_g):
85
+ setattr(self,f'net{model_name}',networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG[i], opt.norm,
86
+ not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.padding))
68
87
 
69
88
  # DeepLIIF model currently uses one gs arch because there is only one explicit seg mod output
70
- self.netG51 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[0], opt.norm,
71
- not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
72
- self.netG52 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[1], opt.norm,
73
- not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
74
- self.netG53 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[2], opt.norm,
75
- not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
76
- self.netG54 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[3], opt.norm,
77
- not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
78
- self.netG55 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[4], opt.norm,
79
- not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
80
-
89
+ for i,model_name in enumerate(self.model_names_gs):
90
+ setattr(self,f'net{model_name}',networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[i], opt.norm,
91
+ not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids))
81
92
 
82
93
  if self.is_train: # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc
83
- self.netD1 = networks.define_D(opt.input_nc+opt.output_nc , opt.ndf, opt.netD,
84
- opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
85
- self.netD2 = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
86
- opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
87
- self.netD3 = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
88
- opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
89
- self.netD4 = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
90
- opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
91
-
92
- self.netD51 = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
93
- opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
94
- self.netD52 = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
95
- opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
96
- self.netD53 = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
97
- opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
98
- self.netD54 = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
99
- opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
100
- self.netD55 = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
101
- opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
94
+ self.model_names_d = [f'D{i+1}' for i in range(self.opt.modalities_no)]
95
+ if self.input_id == '0':
96
+ self.model_names_ds = [f'D{self.mod_id_seg}{i}' for i in range(self.opt.modalities_no+1)]
97
+ else:
98
+ self.model_names_ds = [f'D{self.mod_id_seg}{i+1}' for i in range(self.opt.modalities_no+1)]
99
+ for model_name in self.model_names_d:
100
+ setattr(self,f'net{model_name}',networks.define_D(opt.input_nc+opt.output_nc , opt.ndf, opt.netD,
101
+ opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids))
102
+
103
+ for model_name in self.model_names_ds:
104
+ setattr(self,f'net{model_name}',networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
105
+ opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids))
102
106
 
103
107
  if self.is_train:
104
108
  # define loss functions
@@ -107,14 +111,22 @@ class DeepLIIFModel(BaseModel):
107
111
  self.criterionSmoothL1 = torch.nn.SmoothL1Loss()
108
112
 
109
113
  # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
110
- params = list(self.netG1.parameters()) + list(self.netG2.parameters()) + list(self.netG3.parameters()) + list(self.netG4.parameters()) + list(self.netG51.parameters()) + list(self.netG52.parameters()) + list(self.netG53.parameters()) + list(self.netG54.parameters()) + list(self.netG55.parameters())
114
+ #params = list(self.netG1.parameters()) + list(self.netG2.parameters()) + list(self.netG3.parameters()) + list(self.netG4.parameters()) + list(self.netG51.parameters()) + list(self.netG52.parameters()) + list(self.netG53.parameters()) + list(self.netG54.parameters()) + list(self.netG55.parameters())
115
+ params = []
116
+ for model_name in self.model_names_g + self.model_names_gs:
117
+ params += list(getattr(self,f'net{model_name}').parameters())
118
+
111
119
  try:
112
120
  self.optimizer_G = get_optimizer(opt.optimizer)(params, lr=opt.lr_g, betas=(opt.beta1, 0.999))
113
121
  except:
114
122
  print(f'betas are not used for optimizer torch.optim.{opt.optimizer} in generators')
115
123
  self.optimizer_G = get_optimizer(opt.optimizer)(params, lr=opt.lr_g)
116
124
 
117
- params = list(self.netD1.parameters()) + list(self.netD2.parameters()) + list(self.netD3.parameters()) + list(self.netD4.parameters()) + list(self.netD51.parameters()) + list(self.netD52.parameters()) + list(self.netD53.parameters()) + list(self.netD54.parameters()) + list(self.netD55.parameters())
125
+ #params = list(self.netD1.parameters()) + list(self.netD2.parameters()) + list(self.netD3.parameters()) + list(self.netD4.parameters()) + list(self.netD51.parameters()) + list(self.netD52.parameters()) + list(self.netD53.parameters()) + list(self.netD54.parameters()) + list(self.netD55.parameters())
126
+ params = []
127
+ for model_name in self.model_names_d + self.model_names_ds:
128
+ params += list(getattr(self,f'net{model_name}').parameters())
129
+
118
130
  try:
119
131
  self.optimizer_D = get_optimizer(opt.optimizer)(params, lr=opt.lr_d, betas=(opt.beta1, 0.999))
120
132
  except:
@@ -136,167 +148,251 @@ class DeepLIIFModel(BaseModel):
136
148
  self.real_A = input['A'].to(self.device)
137
149
 
138
150
  self.real_B_array = input['B']
139
- self.real_B_1 = self.real_B_array[0].to(self.device)
140
- self.real_B_2 = self.real_B_array[1].to(self.device)
141
- self.real_B_3 = self.real_B_array[2].to(self.device)
142
- self.real_B_4 = self.real_B_array[3].to(self.device)
143
- self.real_B_5 = self.real_B_array[4].to(self.device)
151
+ for i in range(self.opt.modalities_no):
152
+ setattr(self,f'real_B_{i+1}',self.real_B_array[i].to(self.device))
153
+ setattr(self,f'real_B_{self.mod_id_seg}',self.real_B_array[self.opt.modalities_no].to(self.device)) # the last one is seg
154
+
144
155
  self.image_paths = input['A_paths']
145
156
 
146
157
  def forward(self):
147
158
  """Run forward pass; called by both functions <optimize_parameters> and <test>."""
148
- self.fake_B_1 = self.netG1(self.real_A) # Hematoxylin image generator
149
- self.fake_B_2 = self.netG2(self.real_A) # mpIF DAPI image generator
150
- self.fake_B_3 = self.netG3(self.real_A) # mpIF Lap2 image generator
151
- self.fake_B_4 = self.netG4(self.real_A) # mpIF Ki67 image generator
152
-
153
- self.fake_B_5_1 = self.netG51(self.real_A) # Segmentation mask generator from IHC input image
154
- self.fake_B_5_2 = self.netG52(self.fake_B_1) # Segmentation mask generator from Hematoxylin input image
155
- self.fake_B_5_3 = self.netG53(self.fake_B_2) # Segmentation mask generator from mpIF DAPI input image
156
- self.fake_B_5_4 = self.netG54(self.fake_B_3) # Segmentation mask generator from mpIF Lap2 input image
157
- self.fake_B_5_5 = self.netG55(self.fake_B_4) # Segmentation mask generator from mpIF Lap2 input image
158
- self.fake_B_5 = torch.stack([torch.mul(self.fake_B_5_1, self.seg_weights[0]),
159
- torch.mul(self.fake_B_5_2, self.seg_weights[1]),
160
- torch.mul(self.fake_B_5_3, self.seg_weights[2]),
161
- torch.mul(self.fake_B_5_4, self.seg_weights[3]),
162
- torch.mul(self.fake_B_5_5, self.seg_weights[4])]).sum(dim=0)
159
+ # self.fake_B_1 = self.netG1(self.real_A) # Hematoxylin image generator
160
+ # self.fake_B_2 = self.netG2(self.real_A) # mpIF DAPI image generator
161
+ # self.fake_B_3 = self.netG3(self.real_A) # mpIF Lap2 image generator
162
+ # self.fake_B_4 = self.netG4(self.real_A) # mpIF Ki67 image generator
163
+
164
+ for i in range(self.opt.modalities_no):
165
+ setattr(self,f'fake_B_{i+1}',getattr(self,f'netG{i+1}')(self.real_A))
166
+
167
+ # self.fake_B_5_1 = self.netG51(self.real_A) # Segmentation mask generator from IHC input image
168
+ # self.fake_B_5_2 = self.netG52(self.fake_B_1) # Segmentation mask generator from Hematoxylin input image
169
+ # self.fake_B_5_3 = self.netG53(self.fake_B_2) # Segmentation mask generator from mpIF DAPI input image
170
+ # self.fake_B_5_4 = self.netG54(self.fake_B_3) # Segmentation mask generator from mpIF Lap2 input image
171
+ # self.fake_B_5_5 = self.netG55(self.fake_B_4) # Segmentation mask generator from mpIF Lap2 input image
172
+ # self.fake_B_5 = torch.stack([torch.mul(self.fake_B_5_1, self.seg_weights[0]),
173
+ # torch.mul(self.fake_B_5_2, self.seg_weights[1]),
174
+ # torch.mul(self.fake_B_5_3, self.seg_weights[2]),
175
+ # torch.mul(self.fake_B_5_4, self.seg_weights[3]),
176
+ # torch.mul(self.fake_B_5_5, self.seg_weights[4])]).sum(dim=0)
177
+
178
+ for i,model_name in enumerate(self.model_names_gs):
179
+ if i == 0:
180
+ setattr(self,f'fake_B_{self.mod_id_seg}_{i}',getattr(self,f'net{model_name}')(self.real_A))
181
+ else:
182
+ setattr(self,f'fake_B_{self.mod_id_seg}_{i}',getattr(self,f'net{model_name}')(getattr(self,f'fake_B_{i}')))
183
+
184
+ setattr(self,f'fake_B_{self.mod_id_seg}',torch.stack([torch.mul(getattr(self,f'fake_B_{self.mod_id_seg}_{i}'), self.seg_weights[i]) for i in range(self.opt.modalities_no+1)]).sum(dim=0))
163
185
 
164
186
  def backward_D(self):
165
187
  """Calculate GAN loss for the discriminators"""
166
- fake_AB_1 = torch.cat((self.real_A, self.fake_B_1), 1) # Conditional GANs; feed IHC input and Hematoxtlin output to the discriminator
167
- fake_AB_2 = torch.cat((self.real_A, self.fake_B_2), 1) # Conditional GANs; feed IHC input and mpIF DAPI output to the discriminator
168
- fake_AB_3 = torch.cat((self.real_A, self.fake_B_3), 1) # Conditional GANs; feed IHC input and mpIF Lap2 output to the discriminator
169
- fake_AB_4 = torch.cat((self.real_A, self.fake_B_4), 1) # Conditional GANs; feed IHC input and mpIF Ki67 output to the discriminator
170
-
171
- pred_fake_1 = self.netD1(fake_AB_1.detach())
172
- pred_fake_2 = self.netD2(fake_AB_2.detach())
173
- pred_fake_3 = self.netD3(fake_AB_3.detach())
174
- pred_fake_4 = self.netD4(fake_AB_4.detach())
175
-
176
- fake_AB_5_1 = torch.cat((self.real_A, self.fake_B_5), 1) # Conditional GANs; feed IHC input and Segmentation mask output to the discriminator
177
- fake_AB_5_2 = torch.cat((self.real_B_1, self.fake_B_5), 1) # Conditional GANs; feed Hematoxylin input and Segmentation mask output to the discriminator
178
- fake_AB_5_3 = torch.cat((self.real_B_2, self.fake_B_5), 1) # Conditional GANs; feed mpIF DAPI input and Segmentation mask output to the discriminator
179
- fake_AB_5_4 = torch.cat((self.real_B_3, self.fake_B_5), 1) # Conditional GANs; feed mpIF Lap2 input and Segmentation mask output to the discriminator
180
- fake_AB_5_5 = torch.cat((self.real_B_4, self.fake_B_5), 1) # Conditional GANs; feed mpIF Lap2 input and Segmentation mask output to the discriminator
181
-
182
- pred_fake_5_1 = self.netD51(fake_AB_5_1.detach())
183
- pred_fake_5_2 = self.netD52(fake_AB_5_2.detach())
184
- pred_fake_5_3 = self.netD53(fake_AB_5_3.detach())
185
- pred_fake_5_4 = self.netD54(fake_AB_5_4.detach())
186
- pred_fake_5_5 = self.netD55(fake_AB_5_5.detach())
187
-
188
- pred_fake_5 = torch.stack(
189
- [torch.mul(pred_fake_5_1, self.seg_weights[0]),
190
- torch.mul(pred_fake_5_2, self.seg_weights[1]),
191
- torch.mul(pred_fake_5_3, self.seg_weights[2]),
192
- torch.mul(pred_fake_5_4, self.seg_weights[3]),
193
- torch.mul(pred_fake_5_5, self.seg_weights[4])]).sum(dim=0)
194
-
195
- self.loss_D_fake_1 = self.criterionGAN_BCE(pred_fake_1, False)
196
- self.loss_D_fake_2 = self.criterionGAN_BCE(pred_fake_2, False)
197
- self.loss_D_fake_3 = self.criterionGAN_BCE(pred_fake_3, False)
198
- self.loss_D_fake_4 = self.criterionGAN_BCE(pred_fake_4, False)
199
- self.loss_D_fake_5 = self.criterionGAN_lsgan(pred_fake_5, False)
200
-
201
-
202
- real_AB_1 = torch.cat((self.real_A, self.real_B_1), 1)
203
- real_AB_2 = torch.cat((self.real_A, self.real_B_2), 1)
204
- real_AB_3 = torch.cat((self.real_A, self.real_B_3), 1)
205
- real_AB_4 = torch.cat((self.real_A, self.real_B_4), 1)
206
-
207
- pred_real_1 = self.netD1(real_AB_1)
208
- pred_real_2 = self.netD2(real_AB_2)
209
- pred_real_3 = self.netD3(real_AB_3)
210
- pred_real_4 = self.netD4(real_AB_4)
211
-
212
- real_AB_5_1 = torch.cat((self.real_A, self.real_B_5), 1)
213
- real_AB_5_2 = torch.cat((self.real_B_1, self.real_B_5), 1)
214
- real_AB_5_3 = torch.cat((self.real_B_2, self.real_B_5), 1)
215
- real_AB_5_4 = torch.cat((self.real_B_3, self.real_B_5), 1)
216
- real_AB_5_5 = torch.cat((self.real_B_4, self.real_B_5), 1)
217
-
218
- pred_real_5_1 = self.netD51(real_AB_5_1)
219
- pred_real_5_2 = self.netD52(real_AB_5_2)
220
- pred_real_5_3 = self.netD53(real_AB_5_3)
221
- pred_real_5_4 = self.netD54(real_AB_5_4)
222
- pred_real_5_5 = self.netD55(real_AB_5_5)
223
-
224
- pred_real_5 = torch.stack(
225
- [torch.mul(pred_real_5_1, self.seg_weights[0]),
226
- torch.mul(pred_real_5_2, self.seg_weights[1]),
227
- torch.mul(pred_real_5_3, self.seg_weights[2]),
228
- torch.mul(pred_real_5_4, self.seg_weights[3]),
229
- torch.mul(pred_real_5_5, self.seg_weights[4])]).sum(dim=0)
230
-
231
- self.loss_D_real_1 = self.criterionGAN_BCE(pred_real_1, True)
232
- self.loss_D_real_2 = self.criterionGAN_BCE(pred_real_2, True)
233
- self.loss_D_real_3 = self.criterionGAN_BCE(pred_real_3, True)
234
- self.loss_D_real_4 = self.criterionGAN_BCE(pred_real_4, True)
235
- self.loss_D_real_5 = self.criterionGAN_lsgan(pred_real_5, True)
188
+ # fake_AB_1 = torch.cat((self.real_A, self.fake_B_1), 1) # Conditional GANs; feed IHC input and Hematoxtlin output to the discriminator
189
+ # fake_AB_2 = torch.cat((self.real_A, self.fake_B_2), 1) # Conditional GANs; feed IHC input and mpIF DAPI output to the discriminator
190
+ # fake_AB_3 = torch.cat((self.real_A, self.fake_B_3), 1) # Conditional GANs; feed IHC input and mpIF Lap2 output to the discriminator
191
+ # fake_AB_4 = torch.cat((self.real_A, self.fake_B_4), 1) # Conditional GANs; feed IHC input and mpIF Ki67 output to the discriminator
192
+
193
+ # pred_fake_1 = self.netD1(fake_AB_1.detach())
194
+ # pred_fake_2 = self.netD2(fake_AB_2.detach())
195
+ # pred_fake_3 = self.netD3(fake_AB_3.detach())
196
+ # pred_fake_4 = self.netD4(fake_AB_4.detach())
197
+
198
+ # self.loss_D_fake_1 = self.criterionGAN_BCE(pred_fake_1, False)
199
+ # self.loss_D_fake_2 = self.criterionGAN_BCE(pred_fake_2, False)
200
+ # self.loss_D_fake_3 = self.criterionGAN_BCE(pred_fake_3, False)
201
+ # self.loss_D_fake_4 = self.criterionGAN_BCE(pred_fake_4, False)
202
+
203
+ for i,model_name in enumerate(self.model_names_d):
204
+ fake_AB = torch.cat((self.real_A, getattr(self,f'fake_B_{i+1}')), 1)
205
+ pred_fake = getattr(self,f'net{model_name}')(fake_AB.detach())
206
+ setattr(self,f'loss_D_fake_{i+1}',self.criterionGAN_BCE(pred_fake, False))
207
+ #setattr(self,f'fake_AB_{i+1}',torch.cat((self.real_A, getattr(self,f'fake_B_{i+1}')), 1))
208
+ #setattr(self,f'pred_fake_{i+1}',getattr(self,f'netD{i+1}')(getattr))
209
+
210
+
211
+ # fake_AB_5_1 = torch.cat((self.real_A, self.fake_B_5), 1) # Conditional GANs; feed IHC input and Segmentation mask output to the discriminator
212
+ # fake_AB_5_2 = torch.cat((self.real_B_1, self.fake_B_5), 1) # Conditional GANs; feed Hematoxylin input and Segmentation mask output to the discriminator
213
+ # fake_AB_5_3 = torch.cat((self.real_B_2, self.fake_B_5), 1) # Conditional GANs; feed mpIF DAPI input and Segmentation mask output to the discriminator
214
+ # fake_AB_5_4 = torch.cat((self.real_B_3, self.fake_B_5), 1) # Conditional GANs; feed mpIF Lap2 input and Segmentation mask output to the discriminator
215
+ # fake_AB_5_5 = torch.cat((self.real_B_4, self.fake_B_5), 1) # Conditional GANs; feed mpIF Lap2 input and Segmentation mask output to the discriminator
216
+ #
217
+ # pred_fake_5_1 = self.netD51(fake_AB_5_1.detach())
218
+ # pred_fake_5_2 = self.netD52(fake_AB_5_2.detach())
219
+ # pred_fake_5_3 = self.netD53(fake_AB_5_3.detach())
220
+ # pred_fake_5_4 = self.netD54(fake_AB_5_4.detach())
221
+ # pred_fake_5_5 = self.netD55(fake_AB_5_5.detach())
222
+ #
223
+ # pred_fake_5 = torch.stack(
224
+ # [torch.mul(pred_fake_5_1, self.seg_weights[0]),
225
+ # torch.mul(pred_fake_5_2, self.seg_weights[1]),
226
+ # torch.mul(pred_fake_5_3, self.seg_weights[2]),
227
+ # torch.mul(pred_fake_5_4, self.seg_weights[3]),
228
+ # torch.mul(pred_fake_5_5, self.seg_weights[4])]).sum(dim=0)
229
+
230
+ l_pred_fake_seg = []
231
+ for i,model_name in enumerate(self.model_names_ds):
232
+ if i == 0:
233
+ fake_AB_seg_i = torch.cat((self.real_A, getattr(self,f'fake_B_{self.mod_id_seg}')), 1)
234
+ else:
235
+ fake_AB_seg_i = torch.cat((getattr(self,f'real_B_{i}'), getattr(self,f'fake_B_{self.mod_id_seg}')), 1)
236
+
237
+ pred_fake_seg_i = getattr(self,f'net{model_name}')(fake_AB_seg_i.detach())
238
+ l_pred_fake_seg.append(torch.mul(pred_fake_seg_i, self.seg_weights[i]))
239
+ pred_fake_seg = torch.stack(l_pred_fake_seg).sum(dim=0)
240
+
241
+ #self.loss_D_fake_5 = self.criterionGAN_lsgan(pred_fake_5, False)
242
+ setattr(self,f'loss_D_fake_{self.mod_id_seg}',self.criterionGAN_lsgan(pred_fake_seg, False))
243
+
244
+
245
+ # real_AB_1 = torch.cat((self.real_A, self.real_B_1), 1)
246
+ # real_AB_2 = torch.cat((self.real_A, self.real_B_2), 1)
247
+ # real_AB_3 = torch.cat((self.real_A, self.real_B_3), 1)
248
+ # real_AB_4 = torch.cat((self.real_A, self.real_B_4), 1)
249
+ #
250
+ # pred_real_1 = self.netD1(real_AB_1)
251
+ # pred_real_2 = self.netD2(real_AB_2)
252
+ # pred_real_3 = self.netD3(real_AB_3)
253
+ # pred_real_4 = self.netD4(real_AB_4)
254
+ #
255
+ # self.loss_D_real_1 = self.criterionGAN_BCE(pred_real_1, True)
256
+ # self.loss_D_real_2 = self.criterionGAN_BCE(pred_real_2, True)
257
+ # self.loss_D_real_3 = self.criterionGAN_BCE(pred_real_3, True)
258
+ # self.loss_D_real_4 = self.criterionGAN_BCE(pred_real_4, True)
259
+
260
+ for i,model_name in enumerate(self.model_names_d):
261
+ real_AB = torch.cat((self.real_A, getattr(self,f'real_B_{i+1}')), 1)
262
+ pred_real = getattr(self,f'net{model_name}')(real_AB)
263
+ setattr(self,f'loss_D_real_{i+1}',self.criterionGAN_BCE(pred_real, True))
264
+
265
+ # real_AB_5_1 = torch.cat((self.real_A, self.real_B_5), 1)
266
+ # real_AB_5_2 = torch.cat((self.real_B_1, self.real_B_5), 1)
267
+ # real_AB_5_3 = torch.cat((self.real_B_2, self.real_B_5), 1)
268
+ # real_AB_5_4 = torch.cat((self.real_B_3, self.real_B_5), 1)
269
+ # real_AB_5_5 = torch.cat((self.real_B_4, self.real_B_5), 1)
270
+ #
271
+ # pred_real_5_1 = self.netD51(real_AB_5_1)
272
+ # pred_real_5_2 = self.netD52(real_AB_5_2)
273
+ # pred_real_5_3 = self.netD53(real_AB_5_3)
274
+ # pred_real_5_4 = self.netD54(real_AB_5_4)
275
+ # pred_real_5_5 = self.netD55(real_AB_5_5)
276
+ #
277
+ # pred_real_5 = torch.stack(
278
+ # [torch.mul(pred_real_5_1, self.seg_weights[0]),
279
+ # torch.mul(pred_real_5_2, self.seg_weights[1]),
280
+ # torch.mul(pred_real_5_3, self.seg_weights[2]),
281
+ # torch.mul(pred_real_5_4, self.seg_weights[3]),
282
+ # torch.mul(pred_real_5_5, self.seg_weights[4])]).sum(dim=0)
283
+
284
+ l_pred_real_seg = []
285
+ for i,model_name in enumerate(self.model_names_ds):
286
+ if i == 0:
287
+ real_AB_seg_i = torch.cat((self.real_A, getattr(self,f'real_B_{self.mod_id_seg}')), 1)
288
+ else:
289
+ real_AB_seg_i = torch.cat((getattr(self,f'real_B_{i}'), getattr(self,f'real_B_{self.mod_id_seg}')), 1)
290
+
291
+ pred_real_seg_i = getattr(self,f'net{model_name}')(real_AB_seg_i)
292
+ l_pred_real_seg.append(torch.mul(pred_real_seg_i, self.seg_weights[i]))
293
+ pred_real_seg = torch.stack(l_pred_real_seg).sum(dim=0)
294
+
295
+ #self.loss_D_real_5 = self.criterionGAN_lsgan(pred_real_5, True)
296
+ setattr(self,f'loss_D_real_{self.mod_id_seg}',self.criterionGAN_lsgan(pred_real_seg, True))
236
297
 
237
298
  # combine losses and calculate gradients
238
- self.loss_D = (self.loss_D_fake_1 + self.loss_D_real_1) * 0.5 * self.loss_D_weights[0] + \
239
- (self.loss_D_fake_2 + self.loss_D_real_2) * 0.5 * self.loss_D_weights[1] + \
240
- (self.loss_D_fake_3 + self.loss_D_real_3) * 0.5 * self.loss_D_weights[2] + \
241
- (self.loss_D_fake_4 + self.loss_D_real_4) * 0.5 * self.loss_D_weights[3] + \
242
- (self.loss_D_fake_5 + self.loss_D_real_5) * 0.5 * self.loss_D_weights[4]
299
+ # self.loss_D = (self.loss_D_fake_1 + self.loss_D_real_1) * 0.5 * self.loss_D_weights[0] + \
300
+ # (self.loss_D_fake_2 + self.loss_D_real_2) * 0.5 * self.loss_D_weights[1] + \
301
+ # (self.loss_D_fake_3 + self.loss_D_real_3) * 0.5 * self.loss_D_weights[2] + \
302
+ # (self.loss_D_fake_4 + self.loss_D_real_4) * 0.5 * self.loss_D_weights[3] + \
303
+ # (self.loss_D_fake_5 + self.loss_D_real_5) * 0.5 * self.loss_D_weights[4]
304
+
305
+ self.loss_D = torch.tensor(0., device=self.device)
306
+ for i in range(self.opt.modalities_no):
307
+ self.loss_D += (getattr(self,f'loss_D_fake_{i+1}') + getattr(self,f'loss_D_real_{i+1}')) * 0.5 * self.loss_D_weights[i]
308
+ self.loss_D += (getattr(self,f'loss_D_fake_{self.mod_id_seg}') + getattr(self,f'loss_D_real_{self.mod_id_seg}')) * 0.5 * self.loss_D_weights[self.opt.modalities_no]
243
309
 
244
310
  self.loss_D.backward()
245
311
 
246
312
  def backward_G(self):
247
313
  """Calculate GAN and L1 loss for the generator"""
248
314
 
249
- fake_AB_1 = torch.cat((self.real_A, self.fake_B_1), 1)
250
- fake_AB_2 = torch.cat((self.real_A, self.fake_B_2), 1)
251
- fake_AB_3 = torch.cat((self.real_A, self.fake_B_3), 1)
252
- fake_AB_4 = torch.cat((self.real_A, self.fake_B_4), 1)
253
-
254
- fake_AB_5_1 = torch.cat((self.real_A, self.fake_B_5), 1)
255
- fake_AB_5_2 = torch.cat((self.real_B_1, self.fake_B_5), 1)
256
- fake_AB_5_3 = torch.cat((self.real_B_2, self.fake_B_5), 1)
257
- fake_AB_5_4 = torch.cat((self.real_B_3, self.fake_B_5), 1)
258
- fake_AB_5_5 = torch.cat((self.real_B_4, self.fake_B_5), 1)
259
-
260
- pred_fake_1 = self.netD1(fake_AB_1)
261
- pred_fake_2 = self.netD2(fake_AB_2)
262
- pred_fake_3 = self.netD3(fake_AB_3)
263
- pred_fake_4 = self.netD4(fake_AB_4)
264
-
265
- pred_fake_5_1 = self.netD51(fake_AB_5_1)
266
- pred_fake_5_2 = self.netD52(fake_AB_5_2)
267
- pred_fake_5_3 = self.netD53(fake_AB_5_3)
268
- pred_fake_5_4 = self.netD54(fake_AB_5_4)
269
- pred_fake_5_5 = self.netD55(fake_AB_5_5)
270
- pred_fake_5 = torch.stack(
271
- [torch.mul(pred_fake_5_1, self.seg_weights[0]),
272
- torch.mul(pred_fake_5_2, self.seg_weights[1]),
273
- torch.mul(pred_fake_5_3, self.seg_weights[2]),
274
- torch.mul(pred_fake_5_4, self.seg_weights[3]),
275
- torch.mul(pred_fake_5_5, self.seg_weights[4])]).sum(dim=0)
276
-
277
- self.loss_G_GAN_1 = self.criterionGAN_BCE(pred_fake_1, True)
278
- self.loss_G_GAN_2 = self.criterionGAN_BCE(pred_fake_2, True)
279
- self.loss_G_GAN_3 = self.criterionGAN_BCE(pred_fake_3, True)
280
- self.loss_G_GAN_4 = self.criterionGAN_BCE(pred_fake_4, True)
281
- self.loss_G_GAN_5 = self.criterionGAN_lsgan(pred_fake_5, True)
315
+ # fake_AB_1 = torch.cat((self.real_A, self.fake_B_1), 1)
316
+ # fake_AB_2 = torch.cat((self.real_A, self.fake_B_2), 1)
317
+ # fake_AB_3 = torch.cat((self.real_A, self.fake_B_3), 1)
318
+ # fake_AB_4 = torch.cat((self.real_A, self.fake_B_4), 1)
319
+ #
320
+ # pred_fake_1 = self.netD1(fake_AB_1)
321
+ # pred_fake_2 = self.netD2(fake_AB_2)
322
+ # pred_fake_3 = self.netD3(fake_AB_3)
323
+ # pred_fake_4 = self.netD4(fake_AB_4)
324
+ #
325
+ # self.loss_G_GAN_1 = self.criterionGAN_BCE(pred_fake_1, True)
326
+ # self.loss_G_GAN_2 = self.criterionGAN_BCE(pred_fake_2, True)
327
+ # self.loss_G_GAN_3 = self.criterionGAN_BCE(pred_fake_3, True)
328
+ # self.loss_G_GAN_4 = self.criterionGAN_BCE(pred_fake_4, True)
329
+
330
+ for i,model_name in enumerate(self.model_names_d):
331
+ fake_AB = torch.cat((self.real_A, getattr(self,f'fake_B_{i+1}')), 1)
332
+ pred_fake = getattr(self,f'net{model_name}')(fake_AB)
333
+ setattr(self,f'loss_G_GAN_{i+1}',self.criterionGAN_BCE(pred_fake, True))
334
+
335
+ # fake_AB_5_1 = torch.cat((self.real_A, self.fake_B_5), 1)
336
+ # fake_AB_5_2 = torch.cat((self.real_B_1, self.fake_B_5), 1)
337
+ # fake_AB_5_3 = torch.cat((self.real_B_2, self.fake_B_5), 1)
338
+ # fake_AB_5_4 = torch.cat((self.real_B_3, self.fake_B_5), 1)
339
+ # fake_AB_5_5 = torch.cat((self.real_B_4, self.fake_B_5), 1)
340
+ #
341
+ # pred_fake_5_1 = self.netD51(fake_AB_5_1)
342
+ # pred_fake_5_2 = self.netD52(fake_AB_5_2)
343
+ # pred_fake_5_3 = self.netD53(fake_AB_5_3)
344
+ # pred_fake_5_4 = self.netD54(fake_AB_5_4)
345
+ # pred_fake_5_5 = self.netD55(fake_AB_5_5)
346
+ # pred_fake_5 = torch.stack(
347
+ # [torch.mul(pred_fake_5_1, self.seg_weights[0]),
348
+ # torch.mul(pred_fake_5_2, self.seg_weights[1]),
349
+ # torch.mul(pred_fake_5_3, self.seg_weights[2]),
350
+ # torch.mul(pred_fake_5_4, self.seg_weights[3]),
351
+ # torch.mul(pred_fake_5_5, self.seg_weights[4])]).sum(dim=0)
352
+
353
+ l_pred_fake_seg = []
354
+ for i,model_name in enumerate(self.model_names_ds):
355
+ if i == 0:
356
+ fake_AB_seg_i = torch.cat((self.real_A, getattr(self,f'fake_B_{self.mod_id_seg}')), 1)
357
+ else:
358
+ fake_AB_seg_i = torch.cat((getattr(self,f'real_B_{i}'), getattr(self,f'fake_B_{self.mod_id_seg}')), 1)
359
+
360
+ pred_fake_seg_i = getattr(self,f'net{model_name}')(fake_AB_seg_i)
361
+ l_pred_fake_seg.append(torch.mul(pred_fake_seg_i, self.seg_weights[i]))
362
+ pred_fake_seg = torch.stack(l_pred_fake_seg).sum(dim=0)
363
+
364
+ # self.loss_G_GAN_5 = self.criterionGAN_lsgan(pred_fake_5, True)
365
+ setattr(self,f'loss_G_GAN_{self.mod_id_seg}',self.criterionGAN_lsgan(pred_fake_seg, True))
282
366
 
283
367
  # Second, G(A) = B
284
- self.loss_G_L1_1 = self.criterionSmoothL1(self.fake_B_1, self.real_B_1) * self.opt.lambda_L1
285
- self.loss_G_L1_2 = self.criterionSmoothL1(self.fake_B_2, self.real_B_2) * self.opt.lambda_L1
286
- self.loss_G_L1_3 = self.criterionSmoothL1(self.fake_B_3, self.real_B_3) * self.opt.lambda_L1
287
- self.loss_G_L1_4 = self.criterionSmoothL1(self.fake_B_4, self.real_B_4) * self.opt.lambda_L1
288
- self.loss_G_L1_5 = self.criterionSmoothL1(self.fake_B_5, self.real_B_5) * self.opt.lambda_L1
289
-
290
- self.loss_G_VGG_1 = self.criterionVGG(self.fake_B_1, self.real_B_1) * self.opt.lambda_feat
291
- self.loss_G_VGG_2 = self.criterionVGG(self.fake_B_2, self.real_B_2) * self.opt.lambda_feat
292
- self.loss_G_VGG_3 = self.criterionVGG(self.fake_B_3, self.real_B_3) * self.opt.lambda_feat
293
- self.loss_G_VGG_4 = self.criterionVGG(self.fake_B_4, self.real_B_4) * self.opt.lambda_feat
294
-
295
- self.loss_G = (self.loss_G_GAN_1 + self.loss_G_L1_1 + self.loss_G_VGG_1) * self.loss_G_weights[0] + \
296
- (self.loss_G_GAN_2 + self.loss_G_L1_2 + self.loss_G_VGG_2) * self.loss_G_weights[1] + \
297
- (self.loss_G_GAN_3 + self.loss_G_L1_3 + self.loss_G_VGG_3) * self.loss_G_weights[2] + \
298
- (self.loss_G_GAN_4 + self.loss_G_L1_4 + self.loss_G_VGG_4) * self.loss_G_weights[3] + \
299
- (self.loss_G_GAN_5 + self.loss_G_L1_5) * self.loss_G_weights[4]
368
+ # self.loss_G_L1_1 = self.criterionSmoothL1(self.fake_B_1, self.real_B_1) * self.opt.lambda_L1
369
+ # self.loss_G_L1_2 = self.criterionSmoothL1(self.fake_B_2, self.real_B_2) * self.opt.lambda_L1
370
+ # self.loss_G_L1_3 = self.criterionSmoothL1(self.fake_B_3, self.real_B_3) * self.opt.lambda_L1
371
+ # self.loss_G_L1_4 = self.criterionSmoothL1(self.fake_B_4, self.real_B_4) * self.opt.lambda_L1
372
+ # self.loss_G_L1_5 = self.criterionSmoothL1(self.fake_B_5, self.real_B_5) * self.opt.lambda_L1
373
+
374
+ for i in range(self.opt.modalities_no):
375
+ setattr(self,f'loss_G_L1_{i+1}',self.criterionSmoothL1(getattr(self,f'fake_B_{i+1}'), getattr(self,f'real_B_{i+1}')) * self.opt.lambda_L1)
376
+ setattr(self,f'loss_G_L1_{self.mod_id_seg}',self.criterionSmoothL1(getattr(self,f'fake_B_{self.mod_id_seg}'), getattr(self,f'real_B_{self.mod_id_seg}')) * self.opt.lambda_L1)
377
+
378
+ # self.loss_G_VGG_1 = self.criterionVGG(self.fake_B_1, self.real_B_1) * self.opt.lambda_feat
379
+ # self.loss_G_VGG_2 = self.criterionVGG(self.fake_B_2, self.real_B_2) * self.opt.lambda_feat
380
+ # self.loss_G_VGG_3 = self.criterionVGG(self.fake_B_3, self.real_B_3) * self.opt.lambda_feat
381
+ # self.loss_G_VGG_4 = self.criterionVGG(self.fake_B_4, self.real_B_4) * self.opt.lambda_feat
382
+ for i in range(self.opt.modalities_no):
383
+ setattr(self,f'loss_G_VGG_{i+1}',self.criterionVGG(getattr(self,f'fake_B_{i+1}'), getattr(self,f'real_B_{i+1}')) * self.opt.lambda_feat)
384
+ setattr(self,f'loss_G_VGG_{self.mod_id_seg}',self.criterionVGG(getattr(self,f'fake_B_{self.mod_id_seg}'), getattr(self,f'real_B_{self.mod_id_seg}')) * self.opt.lambda_feat)
385
+
386
+ # self.loss_G = (self.loss_G_GAN_1 + self.loss_G_L1_1 + self.loss_G_VGG_1) * self.loss_G_weights[0] + \
387
+ # (self.loss_G_GAN_2 + self.loss_G_L1_2 + self.loss_G_VGG_2) * self.loss_G_weights[1] + \
388
+ # (self.loss_G_GAN_3 + self.loss_G_L1_3 + self.loss_G_VGG_3) * self.loss_G_weights[2] + \
389
+ # (self.loss_G_GAN_4 + self.loss_G_L1_4 + self.loss_G_VGG_4) * self.loss_G_weights[3] + \
390
+ # (self.loss_G_GAN_5 + self.loss_G_L1_5) * self.loss_G_weights[4]
391
+
392
+ self.loss_G = torch.tensor(0., device=self.device)
393
+ for i in range(self.opt.modalities_no):
394
+ self.loss_G += (getattr(self,f'loss_G_GAN_{i+1}') + getattr(self,f'loss_G_L1_{i+1}') + getattr(self,f'loss_G_VGG_{i+1}')) * self.loss_G_weights[i]
395
+ self.loss_G += (getattr(self,f'loss_G_GAN_{self.mod_id_seg}') + getattr(self,f'loss_G_L1_{self.mod_id_seg}')) * self.loss_G_weights[i]
300
396
 
301
397
  # combine loss and calculate gradients
302
398
  # self.loss_G = (self.loss_G_GAN_1 + self.loss_G_L1_1) * self.loss_G_weights[0] + \
@@ -309,30 +405,36 @@ class DeepLIIFModel(BaseModel):
309
405
  def optimize_parameters(self):
310
406
  self.forward() # compute fake images: G(A)
311
407
  # update D
312
- self.set_requires_grad(self.netD1, True) # enable backprop for D1
313
- self.set_requires_grad(self.netD2, True) # enable backprop for D2
314
- self.set_requires_grad(self.netD3, True) # enable backprop for D3
315
- self.set_requires_grad(self.netD4, True) # enable backprop for D4
316
- self.set_requires_grad(self.netD51, True) # enable backprop for D51
317
- self.set_requires_grad(self.netD52, True) # enable backprop for D52
318
- self.set_requires_grad(self.netD53, True) # enable backprop for D53
319
- self.set_requires_grad(self.netD54, True) # enable backprop for D54
320
- self.set_requires_grad(self.netD55, True) # enable backprop for D54
408
+ # self.set_requires_grad(self.netD1, True) # enable backprop for D1
409
+ # self.set_requires_grad(self.netD2, True) # enable backprop for D2
410
+ # self.set_requires_grad(self.netD3, True) # enable backprop for D3
411
+ # self.set_requires_grad(self.netD4, True) # enable backprop for D4
412
+ # self.set_requires_grad(self.netD51, True) # enable backprop for D51
413
+ # self.set_requires_grad(self.netD52, True) # enable backprop for D52
414
+ # self.set_requires_grad(self.netD53, True) # enable backprop for D53
415
+ # self.set_requires_grad(self.netD54, True) # enable backprop for D54
416
+ # self.set_requires_grad(self.netD55, True) # enable backprop for D55
417
+
418
+ for model_name in self.model_names_d + self.model_names_ds:
419
+ self.set_requires_grad(getattr(self,f'net{model_name}'), True)
321
420
 
322
421
  self.optimizer_D.zero_grad() # set D's gradients to zero
323
422
  self.backward_D() # calculate gradients for D
324
423
  self.optimizer_D.step() # update D's weights
325
424
 
326
425
  # update G
327
- self.set_requires_grad(self.netD1, False) # D1 requires no gradients when optimizing G1
328
- self.set_requires_grad(self.netD2, False) # D2 requires no gradients when optimizing G2
329
- self.set_requires_grad(self.netD3, False) # D3 requires no gradients when optimizing G3
330
- self.set_requires_grad(self.netD4, False) # D4 requires no gradients when optimizing G4
331
- self.set_requires_grad(self.netD51, False) # D51 requires no gradients when optimizing G51
332
- self.set_requires_grad(self.netD52, False) # D52 requires no gradients when optimizing G52
333
- self.set_requires_grad(self.netD53, False) # D53 requires no gradients when optimizing G53
334
- self.set_requires_grad(self.netD54, False) # D54 requires no gradients when optimizing G54
335
- self.set_requires_grad(self.netD55, False) # D54 requires no gradients when optimizing G54
426
+ # self.set_requires_grad(self.netD1, False) # D1 requires no gradients when optimizing G1
427
+ # self.set_requires_grad(self.netD2, False) # D2 requires no gradients when optimizing G2
428
+ # self.set_requires_grad(self.netD3, False) # D3 requires no gradients when optimizing G3
429
+ # self.set_requires_grad(self.netD4, False) # D4 requires no gradients when optimizing G4
430
+ # self.set_requires_grad(self.netD51, False) # D51 requires no gradients when optimizing G51
431
+ # self.set_requires_grad(self.netD52, False) # D52 requires no gradients when optimizing G52
432
+ # self.set_requires_grad(self.netD53, False) # D53 requires no gradients when optimizing G53
433
+ # self.set_requires_grad(self.netD54, False) # D54 requires no gradients when optimizing G54
434
+ # self.set_requires_grad(self.netD55, False) # D55 requires no gradients when optimizing G55
435
+
436
+ for model_name in self.model_names_d + self.model_names_ds:
437
+ self.set_requires_grad(getattr(self,f'net{model_name}'), False)
336
438
 
337
439
  self.optimizer_G.zero_grad() # set G's gradients to zero
338
440
  self.backward_G() # calculate gradients for G
@@ -345,30 +447,37 @@ class DeepLIIFModel(BaseModel):
345
447
 
346
448
  self.forward() # compute fake images: G(A)
347
449
  # update D
348
- self.set_requires_grad(self.netD1, True) # enable backprop for D1
349
- self.set_requires_grad(self.netD2, True) # enable backprop for D2
350
- self.set_requires_grad(self.netD3, True) # enable backprop for D3
351
- self.set_requires_grad(self.netD4, True) # enable backprop for D4
352
- self.set_requires_grad(self.netD51, True) # enable backprop for D51
353
- self.set_requires_grad(self.netD52, True) # enable backprop for D52
354
- self.set_requires_grad(self.netD53, True) # enable backprop for D53
355
- self.set_requires_grad(self.netD54, True) # enable backprop for D54
356
- self.set_requires_grad(self.netD55, True) # enable backprop for D54
450
+ # self.set_requires_grad(self.netD1, True) # enable backprop for D1
451
+ # self.set_requires_grad(self.netD2, True) # enable backprop for D2
452
+ # self.set_requires_grad(self.netD3, True) # enable backprop for D3
453
+ # self.set_requires_grad(self.netD4, True) # enable backprop for D4
454
+ # self.set_requires_grad(self.netD51, True) # enable backprop for D51
455
+ # self.set_requires_grad(self.netD52, True) # enable backprop for D52
456
+ # self.set_requires_grad(self.netD53, True) # enable backprop for D53
457
+ # self.set_requires_grad(self.netD54, True) # enable backprop for D54
458
+ # self.set_requires_grad(self.netD55, True) # enable backprop for D55
459
+
460
+ for model_name in self.model_names_d + self.model_names_ds:
461
+ self.set_requires_grad(getattr(self,f'net{model_name}'), True)
357
462
 
358
463
  self.optimizer_D.zero_grad() # set D's gradients to zero
359
464
  self.backward_D() # calculate gradients for D
360
465
 
361
466
  # update G
362
- self.set_requires_grad(self.netD1, False) # D1 requires no gradients when optimizing G1
363
- self.set_requires_grad(self.netD2, False) # D2 requires no gradients when optimizing G2
364
- self.set_requires_grad(self.netD3, False) # D3 requires no gradients when optimizing G3
365
- self.set_requires_grad(self.netD4, False) # D4 requires no gradients when optimizing G4
366
- self.set_requires_grad(self.netD51, False) # D51 requires no gradients when optimizing G51
367
- self.set_requires_grad(self.netD52, False) # D52 requires no gradients when optimizing G52
368
- self.set_requires_grad(self.netD53, False) # D53 requires no gradients when optimizing G53
369
- self.set_requires_grad(self.netD54, False) # D54 requires no gradients when optimizing G54
370
- self.set_requires_grad(self.netD55, False) # D54 requires no gradients when optimizing G54
467
+ # self.set_requires_grad(self.netD1, False) # D1 requires no gradients when optimizing G1
468
+ # self.set_requires_grad(self.netD2, False) # D2 requires no gradients when optimizing G2
469
+ # self.set_requires_grad(self.netD3, False) # D3 requires no gradients when optimizing G3
470
+ # self.set_requires_grad(self.netD4, False) # D4 requires no gradients when optimizing G4
471
+ # self.set_requires_grad(self.netD51, False) # D51 requires no gradients when optimizing G51
472
+ # self.set_requires_grad(self.netD52, False) # D52 requires no gradients when optimizing G52
473
+ # self.set_requires_grad(self.netD53, False) # D53 requires no gradients when optimizing G53
474
+ # self.set_requires_grad(self.netD54, False) # D54 requires no gradients when optimizing G54
475
+ # self.set_requires_grad(self.netD55, False) # D55 requires no gradients when optimizing G55
476
+
477
+ for model_name in self.model_names_d + self.model_names_ds:
478
+ self.set_requires_grad(getattr(self,f'net{model_name}'), False)
371
479
 
372
480
  self.optimizer_G.zero_grad() # set G's gradients to zero
373
481
  self.backward_G() # calculate gradients for G
374
482
 
483
+