deepliif 1.2.2-py3-none-any.whl → 1.2.4-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -4,6 +4,7 @@ from . import networks
 from .networks import get_optimizer
 from . import init_nets, run_dask, get_opt
 from torch import nn
+from ..util.util import get_input_id, init_input_and_mod_id, map_model_names
 
 class DeepLIIFKDModel(BaseModel):
     """ This class implements the DeepLIIF model, for learning a mapping from input images to modalities given paired data."""
@@ -21,91 +22,108 @@ class DeepLIIFKDModel(BaseModel):
         self.seg_weights = opt.seg_weights
         self.loss_G_weights = opt.loss_G_weights
         self.loss_D_weights = opt.loss_D_weights
+        self.mod_id_seg, self.input_id = init_input_and_mod_id(opt) # creates self.input_id, self.mod_id_seg
+        print(f'Initializing model with segmentation modality id {self.mod_id_seg}, input id {self.input_id}')
 
         if not opt.is_train:
             self.gpu_ids = [] # avoid the models being loaded as DP
         else:
             self.gpu_ids = opt.gpu_ids
-
+
         self.loss_names = []
         self.visual_names = ['real_A']
         # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
-        for i in range(1, self.opt.modalities_no + 1 + 1):
-            self.loss_names.extend([f'G_GAN_{i}', f'G_L1_{i}', f'D_real_{i}', f'D_fake_{i}', f'G_KLDiv_{i}', f'G_KLDiv_5_{i}'])
-            self.visual_names.extend([f'fake_B_{i}', f'fake_B_5_{i}', f'fake_B_{i}_teacher', f'fake_B_5_{i}_teacher', f'real_B_{i}'])
+        for i in range(self.opt.modalities_no):
+            self.loss_names.extend([f'G_GAN_{i+1}', f'G_L1_{i+1}', f'D_real_{i+1}', f'D_fake_{i+1}', f'G_KLDiv_{i+1}', f'G_KLDiv_{self.mod_id_seg}{i+1}'])
+            self.visual_names.extend([f'fake_B_{i+1}', f'fake_B_{i+1}_teacher', f'real_B_{i+1}'])
+        self.loss_names.extend([f'G_GAN_{self.mod_id_seg}',f'G_L1_{self.mod_id_seg}',f'D_real_{self.mod_id_seg}',f'D_fake_{self.mod_id_seg}',
+                                f'G_KLDiv_{self.mod_id_seg}',f'G_KLDiv_{self.mod_id_seg}{self.opt.modalities_no}'])
+
+        for i in range(self.opt.modalities_no+1):
+            self.visual_names.extend([f'fake_B_{self.mod_id_seg}{i}', f'fake_B_{self.mod_id_seg}{i}_teacher']) # 0 is used for the base input mod
+        self.visual_names.extend([f'fake_B_{self.mod_id_seg}', f'fake_B_{self.mod_id_seg}_teacher', f'real_B_{self.mod_id_seg}',])
 
         # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
         # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
+        self.model_names_g = []
+        self.model_names_gs = []
         if self.is_train:
             self.model_names = []
             for i in range(1, self.opt.modalities_no + 1):
-                self.model_names.extend(['G' + str(i), 'D' + str(i)])
-
-            for i in range(1, self.opt.modalities_no + 1 + 1):
-                self.model_names.extend(['G5' + str(i), 'D5' + str(i)])
+                self.model_names.extend([f'G{i}', f'D{i}'])
+                self.model_names_g.append(f'G{i}')
+
+            for i in range(self.opt.modalities_no + 1): # 0 is used for the base input mod
+                if self.input_id == '0':
+                    self.model_names.extend([f'G{self.mod_id_seg}{i}', f'D{self.mod_id_seg}{i}'])
+                    self.model_names_gs.append(f'G{self.mod_id_seg}{i}')
+                else:
+                    self.model_names.extend([f'G{self.mod_id_seg}{i+1}', f'D{self.mod_id_seg}{i+1}'])
+                    self.model_names_gs.append(f'G{self.mod_id_seg}{i+1}')
         else: # during test time, only load G
             self.model_names = []
             for i in range(1, self.opt.modalities_no + 1):
-                self.model_names.extend(['G' + str(i)])
-
-            for i in range(1, self.opt.modalities_no + 1 + 1):
-                self.model_names.extend(['G5' + str(i)])
-
+                self.model_names.extend([f'G{i}'])
+                self.model_names_g.append(f'G{i}')
+
+            #input_id = get_input_id(os.path.join(opt.checkpoints_dir, opt.name))
+            if self.input_id == '0':
+                for i in range(self.opt.modalities_no + 1): # 0 is used for the base input mod
+                    self.model_names.extend([f'G{self.mod_id_seg}{i}'])
+                    self.model_names_gs.append(f'G{self.mod_id_seg}{i}')
+            else:
+                for i in range(self.opt.modalities_no + 1): # old setting, 1 is used for the base input mod
+                    self.model_names.extend([f'G{self.mod_id_seg}{i+1}'])
+                    self.model_names_gs.append(f'G{self.mod_id_seg}{i+1}')
+
         # define networks (both generator and discriminator)
         if isinstance(opt.netG, str):
-            opt.netG = [opt.netG] * 4
+            opt.netG = [opt.netG] * self.opt.modalities_no
         if isinstance(opt.net_gs, str):
-            opt.net_gs = [opt.net_gs]*5
+            opt.net_gs = [opt.net_gs] * (self.opt.modalities_no + 1) # +1 for base input mod
 
 
-        self.netG1 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG[0], opt.norm,
-                                       not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.padding)
-        self.netG2 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG[1], opt.norm,
-                                       not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.padding)
-        self.netG3 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG[2], opt.norm,
-                                       not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.padding)
-        self.netG4 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG[3], opt.norm,
-                                       not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.padding)
+        for i,model_name in enumerate(self.model_names_g):
+            setattr(self,f'net{model_name}',networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG[i], opt.norm,
+                                                              not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.padding))
 
         # DeepLIIF model currently uses one gs arch because there is only one explicit seg mod output
-        self.netG51 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[0], opt.norm,
-                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
-        self.netG52 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[1], opt.norm,
-                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
-        self.netG53 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[2], opt.norm,
-                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
-        self.netG54 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[3], opt.norm,
-                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
-        self.netG55 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[4], opt.norm,
-                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
-
+        for i,model_name in enumerate(self.model_names_gs):
+            setattr(self,f'net{model_name}',networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[i], opt.norm,
+                                                              not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids))
 
         if self.is_train: # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc
-            self.netD1 = networks.define_D(opt.input_nc+opt.output_nc , opt.ndf, opt.netD,
-                                           opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
-            self.netD2 = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
-                                           opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
-            self.netD3 = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
-                                           opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
-            self.netD4 = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
-                                           opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
-
-            self.netD51 = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
-                                            opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
-            self.netD52 = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
-                                            opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
-            self.netD53 = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
-                                            opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
-            self.netD54 = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
-                                            opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
-            self.netD55 = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
-                                            opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
-
+            self.model_names_d = [f'D{i+1}' for i in range(self.opt.modalities_no)]
+            if self.input_id == '0':
+                self.model_names_ds = [f'D{self.mod_id_seg}{i}' for i in range(self.opt.modalities_no+1)]
+            else:
+                self.model_names_ds = [f'D{self.mod_id_seg}{i+1}' for i in range(self.opt.modalities_no+1)]
+            for model_name in self.model_names_d:
+                setattr(self,f'net{model_name}',networks.define_D(opt.input_nc+opt.output_nc , opt.ndf, opt.netD,
+                                                                  opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids))
+
+            for model_name in self.model_names_ds:
+                setattr(self,f'net{model_name}',networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
+                                                                  opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids))
+
             # load the teacher model
             self.opt_teacher = get_opt(opt.model_dir_teacher, mode='test')
             self.opt_teacher.gpu_ids = opt.gpu_ids # use student's gpu_ids
             self.nets_teacher = init_nets(opt.model_dir_teacher, eager_mode=True, opt=self.opt_teacher, phase='test')
 
+            # modify model names to be consistent with the current deepliifkd model names
+            # otherwise it may be tricky to pair the loss terms?
+            self.opt_teacher.mod_id_seg, self.opt_teacher.input_id = init_input_and_mod_id(self.opt_teacher)
+            d_mapping_model_name = map_model_names(list(self.nets_teacher.keys()),self.opt_teacher.mod_id_seg,self.opt_teacher.input_id,
+                                                   self.mod_id_seg,self.input_id)
+            self.d_mapping_model_name = d_mapping_model_name
+            print('Model name mapping, teacher model to student model:',self.d_mapping_model_name)
+        else:
+            # remove all model names for the teacher model
+            self.model_names = [name for name in self.model_names if not name.endswith('_teacher')]
+            self.loss_names = [name for name in self.loss_names if not name.endswith('_teacher')]
+            self.visual_names = [name for name in self.visual_names if not name.endswith('_teacher')]
+
 
         if self.is_train:
             # define loss functions
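
The recurring change in this hunk: the nine hand-written generator attributes (netG1-netG4, netG51-netG55) become names computed from modalities_no and mod_id_seg, registered with setattr and fetched with getattr, so the network count follows the options instead of being fixed at 4+5. A minimal, self-contained sketch of that pattern (a toy layer stands in for networks.define_G; the class and values here are illustrative only):

    import torch
    from torch import nn

    class Registry:
        def __init__(self, modalities_no=4, mod_id_seg=5):
            self.model_names_g = [f'G{i+1}' for i in range(modalities_no)]
            # one seg generator per input; index 0 is the base input modality
            self.model_names_gs = [f'G{mod_id_seg}{i}' for i in range(modalities_no + 1)]
            for name in self.model_names_g + self.model_names_gs:
                setattr(self, f'net{name}', nn.Conv2d(3, 3, 1))  # stand-in for define_G

    r = Registry()
    out = getattr(r, 'netG1')(torch.randn(1, 3, 8, 8))  # look nets up by name
    print(r.model_names_gs)  # ['G50', 'G51', 'G52', 'G53', 'G54']
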
@@ -114,14 +132,22 @@ class DeepLIIFKDModel(BaseModel):
             self.criterionSmoothL1 = torch.nn.SmoothL1Loss()
 
             # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
-            params = list(self.netG1.parameters()) + list(self.netG2.parameters()) + list(self.netG3.parameters()) + list(self.netG4.parameters()) + list(self.netG51.parameters()) + list(self.netG52.parameters()) + list(self.netG53.parameters()) + list(self.netG54.parameters()) + list(self.netG55.parameters())
+            #params = list(self.netG1.parameters()) + list(self.netG2.parameters()) + list(self.netG3.parameters()) + list(self.netG4.parameters()) + list(self.netG51.parameters()) + list(self.netG52.parameters()) + list(self.netG53.parameters()) + list(self.netG54.parameters()) + list(self.netG55.parameters())
+            params = []
+            for model_name in self.model_names_g + self.model_names_gs:
+                params += list(getattr(self,f'net{model_name}').parameters())
+
            try:
                 self.optimizer_G = get_optimizer(opt.optimizer)(params, lr=opt.lr_g, betas=(opt.beta1, 0.999))
             except:
                 print(f'betas are not used for optimizer torch.optim.{opt.optimizer} in generators')
                 self.optimizer_G = get_optimizer(opt.optimizer)(params, lr=opt.lr_g)
 
-            params = list(self.netD1.parameters()) + list(self.netD2.parameters()) + list(self.netD3.parameters()) + list(self.netD4.parameters()) + list(self.netD51.parameters()) + list(self.netD52.parameters()) + list(self.netD53.parameters()) + list(self.netD54.parameters()) + list(self.netD55.parameters())
+            #params = list(self.netD1.parameters()) + list(self.netD2.parameters()) + list(self.netD3.parameters()) + list(self.netD4.parameters()) + list(self.netD51.parameters()) + list(self.netD52.parameters()) + list(self.netD53.parameters()) + list(self.netD54.parameters()) + list(self.netD55.parameters())
+            params = []
+            for model_name in self.model_names_d + self.model_names_ds:
+                params += list(getattr(self,f'net{model_name}').parameters())
+
             try:
                 self.optimizer_D = get_optimizer(opt.optimizer)(params, lr=opt.lr_d, betas=(opt.beta1, 0.999))
             except:
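
Both sides keep the original design of one optimizer per half of the GAN: all generator parameters are concatenated into a single list so one optimizer_G.step() updates the whole ensemble, and likewise for the discriminators. The idiom in isolation (toy modules stand in for the G/GS families; Adam is used here as an example, while get_optimizer in the package can map to other torch.optim classes):

    import torch
    from torch import nn

    nets = {'G1': nn.Linear(4, 4), 'G2': nn.Linear(4, 4)}  # stand-ins
    params = []
    for net in nets.values():
        params += list(net.parameters())
    optimizer_G = torch.optim.Adam(params, lr=2e-4, betas=(0.5, 0.999))

    loss = sum(net(torch.randn(2, 4)).sum() for net in nets.values())
    loss.backward()
    optimizer_G.step()  # one step updates every registered generator
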
@@ -131,6 +157,7 @@ class DeepLIIFKDModel(BaseModel):
             self.optimizers.append(self.optimizer_G)
             self.optimizers.append(self.optimizer_D)
 
+            print('self.device',self.device)
             self.criterionVGG = networks.VGGLoss().to(self.device)
             self.criterionKLDiv = torch.nn.KLDivLoss(reduction='batchmean').to(self.device)
             self.softmax = torch.nn.Softmax(dim=-1).to(self.device) # apply softmax on the last dim
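
torch.nn.KLDivLoss expects its input already in log-space and its target as plain probabilities, which is why the model keeps both a Softmax and a LogSoftmax module; reduction='batchmean' divides the summed divergence by the batch size, matching the mathematical definition of KL divergence. A quick standalone check of the convention used by the distillation losses later in this diff (shapes here are arbitrary):

    import torch

    kldiv = torch.nn.KLDivLoss(reduction='batchmean')
    logsoftmax = torch.nn.LogSoftmax(dim=-1)
    softmax = torch.nn.Softmax(dim=-1)

    student = torch.randn(1, 1, 12)  # e.g. a flattened fake_B tensor
    teacher = torch.randn(1, 1, 12)
    loss = kldiv(logsoftmax(student), softmax(teacher))  # log-probs vs probs
    print(loss.item())  # >= 0, approaches 0 as the distributions match
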
@@ -146,192 +173,181 @@ class DeepLIIFKDModel(BaseModel):
         self.real_A = input['A'].to(self.device)
 
         self.real_B_array = input['B']
-        self.real_B_1 = self.real_B_array[0].to(self.device)
-        self.real_B_2 = self.real_B_array[1].to(self.device)
-        self.real_B_3 = self.real_B_array[2].to(self.device)
-        self.real_B_4 = self.real_B_array[3].to(self.device)
-        self.real_B_5 = self.real_B_array[4].to(self.device)
+        for i in range(self.opt.modalities_no):
+            setattr(self,f'real_B_{i+1}',self.real_B_array[i].to(self.device))
+        setattr(self,f'real_B_{self.mod_id_seg}',self.real_B_array[self.opt.modalities_no].to(self.device)) # the last one is seg
+
         self.image_paths = input['A_paths']
 
     def forward(self):
         """Run forward pass; called by both functions <optimize_parameters> and <test>."""
-        self.fake_B_1 = self.netG1(self.real_A) # Hematoxylin image generator
-        self.fake_B_2 = self.netG2(self.real_A) # mpIF DAPI image generator
-        self.fake_B_3 = self.netG3(self.real_A) # mpIF Lap2 image generator
-        self.fake_B_4 = self.netG4(self.real_A) # mpIF Ki67 image generator
-
-        self.fake_B_5_1 = self.netG51(self.real_A) # Segmentation mask generator from IHC input image
-        self.fake_B_5_2 = self.netG52(self.fake_B_1) # Segmentation mask generator from Hematoxylin input image
-        self.fake_B_5_3 = self.netG53(self.fake_B_2) # Segmentation mask generator from mpIF DAPI input image
-        self.fake_B_5_4 = self.netG54(self.fake_B_3) # Segmentation mask generator from mpIF Lap2 input image
-        self.fake_B_5_5 = self.netG55(self.fake_B_4) # Segmentation mask generator from mpIF Lap2 input image
-        self.fake_B_5 = torch.stack([torch.mul(self.fake_B_5_1, self.seg_weights[0]),
-                                     torch.mul(self.fake_B_5_2, self.seg_weights[1]),
-                                     torch.mul(self.fake_B_5_3, self.seg_weights[2]),
-                                     torch.mul(self.fake_B_5_4, self.seg_weights[3]),
-                                     torch.mul(self.fake_B_5_5, self.seg_weights[4])]).sum(dim=0)
+        # self.fake_B_1 = self.netG1(self.real_A) # Hematoxylin image generator
+        # self.fake_B_2 = self.netG2(self.real_A) # mpIF DAPI image generator
+        # self.fake_B_3 = self.netG3(self.real_A) # mpIF Lap2 image generator
+        # self.fake_B_4 = self.netG4(self.real_A) # mpIF Ki67 image generator
+
+        for i in range(self.opt.modalities_no):
+            setattr(self,f'fake_B_{i+1}',getattr(self,f'netG{i+1}')(self.real_A))
+
+        # self.fake_B_5_1 = self.netG51(self.real_A) # Segmentation mask generator from IHC input image
+        # self.fake_B_5_2 = self.netG52(self.fake_B_1) # Segmentation mask generator from Hematoxylin input image
+        # self.fake_B_5_3 = self.netG53(self.fake_B_2) # Segmentation mask generator from mpIF DAPI input image
+        # self.fake_B_5_4 = self.netG54(self.fake_B_3) # Segmentation mask generator from mpIF Lap2 input image
+        # self.fake_B_5_5 = self.netG55(self.fake_B_4) # Segmentation mask generator from mpIF Lap2 input image
+        # self.fake_B_5 = torch.stack([torch.mul(self.fake_B_5_1, self.seg_weights[0]),
+        #                              torch.mul(self.fake_B_5_2, self.seg_weights[1]),
+        #                              torch.mul(self.fake_B_5_3, self.seg_weights[2]),
+        #                              torch.mul(self.fake_B_5_4, self.seg_weights[3]),
+        #                              torch.mul(self.fake_B_5_5, self.seg_weights[4])]).sum(dim=0)
 
-        fakes_teacher = run_dask(img=self.real_A, nets=self.nets_teacher, opt=self.opt_teacher, use_dask=False, output_tensor=True)
-        for k,v in fakes_teacher.items():
-            suffix = k[1:] # starts with G
-            suffix = '_'.join(list(suffix)) # 51 -> 5_1
-            setattr(self,f'fake_B_{suffix}_teacher',v)
+        for i,model_name in enumerate(self.model_names_gs):
+            if i == 0:
+                setattr(self,f'fake_B_{self.mod_id_seg}_{i}',getattr(self,f'net{model_name}')(self.real_A))
+            else:
+                setattr(self,f'fake_B_{self.mod_id_seg}_{i}',getattr(self,f'net{model_name}')(getattr(self,f'fake_B_{i}')))
+
+        setattr(self,f'fake_B_{self.mod_id_seg}',torch.stack([torch.mul(getattr(self,f'fake_B_{self.mod_id_seg}_{i}'), self.seg_weights[i]) for i in range(self.opt.modalities_no+1)]).sum(dim=0))
+
+        if self.is_train:
+            fakes_teacher = run_dask(img=self.real_A, nets=self.nets_teacher, opt=self.opt_teacher, use_dask=False, output_tensor=True)
+            #print(f'Checking seg mod id for teacher model: current id is {self.opt_teacher.mod_id_seg}, id to map to is {self.mod_id_seg}')
+            for k,v in fakes_teacher.items():
+                suffix = self.d_mapping_model_name[k][1:] # starts with G
+                l_suffix = list(suffix)
+                if l_suffix[0] == str(self.opt_teacher.mod_id_seg): # mod_id_seg might be integer
+                    if l_suffix[0] != str(self.mod_id_seg):
+                        l_suffix[0] = str(self.mod_id_seg)
+                #suffix = '_'.join(list(suffix)) # 51 -> 5_1
+                suffix = '_'.join(l_suffix) # 51 -> 5_1
+                #print(f'Loaded teacher model: fake_B_{suffix}_teacher')
+                setattr(self,f'fake_B_{suffix}_teacher',v.to(self.device))
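
Two behavioral changes ride along with the loop rewrite above: the teacher forward pass is now gated behind self.is_train (so inference no longer runs the teacher), and teacher outputs are renamed onto the student's scheme via d_mapping_model_name before being stored. Otherwise forward() still emits one segmentation prediction per available input (the IHC image itself plus each synthesized modality) and blends them into a single fake_B with the per-input seg_weights. The blending idiom, isolated with toy tensors standing in for the generator outputs (weights here are arbitrary):

    import torch

    seg_weights = [0.25, 0.15, 0.25, 0.1, 0.25]  # one weight per input
    preds = [torch.randn(1, 3, 8, 8) for _ in seg_weights]  # fake_B_5_i stand-ins
    # scale each prediction, stack along a new dim 0, then sum that dim away
    fused = torch.stack([w * p for p, w in zip(preds, seg_weights)]).sum(dim=0)
    print(fused.shape)  # torch.Size([1, 3, 8, 8]), same shape as one prediction
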
 
     def backward_D(self):
         """Calculate GAN loss for the discriminators"""
-        fake_AB_1 = torch.cat((self.real_A, self.fake_B_1), 1) # Conditional GANs; feed IHC input and Hematoxtlin output to the discriminator
-        fake_AB_2 = torch.cat((self.real_A, self.fake_B_2), 1) # Conditional GANs; feed IHC input and mpIF DAPI output to the discriminator
-        fake_AB_3 = torch.cat((self.real_A, self.fake_B_3), 1) # Conditional GANs; feed IHC input and mpIF Lap2 output to the discriminator
-        fake_AB_4 = torch.cat((self.real_A, self.fake_B_4), 1) # Conditional GANs; feed IHC input and mpIF Ki67 output to the discriminator
-
-        pred_fake_1 = self.netD1(fake_AB_1.detach())
-        pred_fake_2 = self.netD2(fake_AB_2.detach())
-        pred_fake_3 = self.netD3(fake_AB_3.detach())
-        pred_fake_4 = self.netD4(fake_AB_4.detach())
-
-        fake_AB_5_1 = torch.cat((self.real_A, self.fake_B_5), 1) # Conditional GANs; feed IHC input and Segmentation mask output to the discriminator
-        fake_AB_5_2 = torch.cat((self.real_B_1, self.fake_B_5), 1) # Conditional GANs; feed Hematoxylin input and Segmentation mask output to the discriminator
-        fake_AB_5_3 = torch.cat((self.real_B_2, self.fake_B_5), 1) # Conditional GANs; feed mpIF DAPI input and Segmentation mask output to the discriminator
-        fake_AB_5_4 = torch.cat((self.real_B_3, self.fake_B_5), 1) # Conditional GANs; feed mpIF Lap2 input and Segmentation mask output to the discriminator
-        fake_AB_5_5 = torch.cat((self.real_B_4, self.fake_B_5), 1) # Conditional GANs; feed mpIF Lap2 input and Segmentation mask output to the discriminator
-
-        pred_fake_5_1 = self.netD51(fake_AB_5_1.detach())
-        pred_fake_5_2 = self.netD52(fake_AB_5_2.detach())
-        pred_fake_5_3 = self.netD53(fake_AB_5_3.detach())
-        pred_fake_5_4 = self.netD54(fake_AB_5_4.detach())
-        pred_fake_5_5 = self.netD55(fake_AB_5_5.detach())
-
-        pred_fake_5 = torch.stack(
-            [torch.mul(pred_fake_5_1, self.seg_weights[0]),
-             torch.mul(pred_fake_5_2, self.seg_weights[1]),
-             torch.mul(pred_fake_5_3, self.seg_weights[2]),
-             torch.mul(pred_fake_5_4, self.seg_weights[3]),
-             torch.mul(pred_fake_5_5, self.seg_weights[4])]).sum(dim=0)
-
-        self.loss_D_fake_1 = self.criterionGAN_BCE(pred_fake_1, False)
-        self.loss_D_fake_2 = self.criterionGAN_BCE(pred_fake_2, False)
-        self.loss_D_fake_3 = self.criterionGAN_BCE(pred_fake_3, False)
-        self.loss_D_fake_4 = self.criterionGAN_BCE(pred_fake_4, False)
-        self.loss_D_fake_5 = self.criterionGAN_lsgan(pred_fake_5, False)
-
-
-        real_AB_1 = torch.cat((self.real_A, self.real_B_1), 1)
-        real_AB_2 = torch.cat((self.real_A, self.real_B_2), 1)
-        real_AB_3 = torch.cat((self.real_A, self.real_B_3), 1)
-        real_AB_4 = torch.cat((self.real_A, self.real_B_4), 1)
-
-        pred_real_1 = self.netD1(real_AB_1)
-        pred_real_2 = self.netD2(real_AB_2)
-        pred_real_3 = self.netD3(real_AB_3)
-        pred_real_4 = self.netD4(real_AB_4)
-
-        real_AB_5_1 = torch.cat((self.real_A, self.real_B_5), 1)
-        real_AB_5_2 = torch.cat((self.real_B_1, self.real_B_5), 1)
-        real_AB_5_3 = torch.cat((self.real_B_2, self.real_B_5), 1)
-        real_AB_5_4 = torch.cat((self.real_B_3, self.real_B_5), 1)
-        real_AB_5_5 = torch.cat((self.real_B_4, self.real_B_5), 1)
-
-        pred_real_5_1 = self.netD51(real_AB_5_1)
-        pred_real_5_2 = self.netD52(real_AB_5_2)
-        pred_real_5_3 = self.netD53(real_AB_5_3)
-        pred_real_5_4 = self.netD54(real_AB_5_4)
-        pred_real_5_5 = self.netD55(real_AB_5_5)
-
-        pred_real_5 = torch.stack(
-            [torch.mul(pred_real_5_1, self.seg_weights[0]),
-             torch.mul(pred_real_5_2, self.seg_weights[1]),
-             torch.mul(pred_real_5_3, self.seg_weights[2]),
-             torch.mul(pred_real_5_4, self.seg_weights[3]),
-             torch.mul(pred_real_5_5, self.seg_weights[4])]).sum(dim=0)
-
-        self.loss_D_real_1 = self.criterionGAN_BCE(pred_real_1, True)
-        self.loss_D_real_2 = self.criterionGAN_BCE(pred_real_2, True)
-        self.loss_D_real_3 = self.criterionGAN_BCE(pred_real_3, True)
-        self.loss_D_real_4 = self.criterionGAN_BCE(pred_real_4, True)
-        self.loss_D_real_5 = self.criterionGAN_lsgan(pred_real_5, True)
+
+        for i,model_name in enumerate(self.model_names_d):
+            fake_AB = torch.cat((self.real_A, getattr(self,f'fake_B_{i+1}')), 1)
+            pred_fake = getattr(self,f'net{model_name}')(fake_AB.detach())
+            setattr(self,f'loss_D_fake_{i+1}',self.criterionGAN_BCE(pred_fake, False))
+
+        l_pred_fake_seg = []
+        for i,model_name in enumerate(self.model_names_ds):
+            if i == 0:
+                fake_AB_seg_i = torch.cat((self.real_A, getattr(self,f'fake_B_{self.mod_id_seg}')), 1)
+            else:
+                fake_AB_seg_i = torch.cat((getattr(self,f'real_B_{i}'), getattr(self,f'fake_B_{self.mod_id_seg}')), 1)
+
+            pred_fake_seg_i = getattr(self,f'net{model_name}')(fake_AB_seg_i.detach())
+            l_pred_fake_seg.append(torch.mul(pred_fake_seg_i, self.seg_weights[i]))
+        pred_fake_seg = torch.stack(l_pred_fake_seg).sum(dim=0)
+
+        setattr(self,f'loss_D_fake_{self.mod_id_seg}',self.criterionGAN_lsgan(pred_fake_seg, False))
+
+        for i,model_name in enumerate(self.model_names_d):
+            real_AB = torch.cat((self.real_A, getattr(self,f'real_B_{i+1}')), 1)
+            pred_real = getattr(self,f'net{model_name}')(real_AB)
+            setattr(self,f'loss_D_real_{i+1}',self.criterionGAN_BCE(pred_real, True))
+
+        l_pred_real_seg = []
+        for i,model_name in enumerate(self.model_names_ds):
+            if i == 0:
+                real_AB_seg_i = torch.cat((self.real_A, getattr(self,f'real_B_{self.mod_id_seg}')), 1)
+            else:
+                real_AB_seg_i = torch.cat((getattr(self,f'real_B_{i}'), getattr(self,f'real_B_{self.mod_id_seg}')), 1)
+
+            pred_real_seg_i = getattr(self,f'net{model_name}')(real_AB_seg_i)
+            l_pred_real_seg.append(torch.mul(pred_real_seg_i, self.seg_weights[i]))
+        pred_real_seg = torch.stack(l_pred_real_seg).sum(dim=0)
+
+        #self.loss_D_real_5 = self.criterionGAN_lsgan(pred_real_5, True)
+        setattr(self,f'loss_D_real_{self.mod_id_seg}',self.criterionGAN_lsgan(pred_real_seg, True))
 
         # combine losses and calculate gradients
-        self.loss_D = (self.loss_D_fake_1 + self.loss_D_real_1) * 0.5 * self.loss_D_weights[0] + \
-                      (self.loss_D_fake_2 + self.loss_D_real_2) * 0.5 * self.loss_D_weights[1] + \
-                      (self.loss_D_fake_3 + self.loss_D_real_3) * 0.5 * self.loss_D_weights[2] + \
-                      (self.loss_D_fake_4 + self.loss_D_real_4) * 0.5 * self.loss_D_weights[3] + \
-                      (self.loss_D_fake_5 + self.loss_D_real_5) * 0.5 * self.loss_D_weights[4]
+        self.loss_D = torch.tensor(0., device=self.device)
+        for i in range(self.opt.modalities_no):
+            self.loss_D += (getattr(self,f'loss_D_fake_{i+1}') + getattr(self,f'loss_D_real_{i+1}')) * 0.5 * self.loss_D_weights[i]
+        self.loss_D += (getattr(self,f'loss_D_fake_{self.mod_id_seg}') + getattr(self,f'loss_D_real_{self.mod_id_seg}')) * 0.5 * self.loss_D_weights[self.opt.modalities_no]
 
         self.loss_D.backward()
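
backward_D keeps the pix2pix conditional-GAN recipe while collapsing the per-network boilerplate into loops: the discriminator scores the conditioning image concatenated channel-wise with either a generated output (detached, labeled fake) or the ground truth (labeled real), and the two terms are averaged via the 0.5 factor. The skeleton of that recipe, reduced to one toy discriminator with BCE-with-logits standing in for criterionGAN_BCE:

    import torch
    from torch import nn

    D = nn.Conv2d(6, 1, 4)  # toy conditional discriminator (3+3 input channels)
    bce = nn.BCEWithLogitsLoss()
    real_A, fake_B, real_B = (torch.randn(1, 3, 16, 16) for _ in range(3))

    pred_fake = D(torch.cat((real_A, fake_B), 1).detach())  # detach: no grad to G
    pred_real = D(torch.cat((real_A, real_B), 1))
    loss_D = (bce(pred_fake, torch.zeros_like(pred_fake)) +
              bce(pred_real, torch.ones_like(pred_real))) * 0.5
    loss_D.backward()
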
 
     def backward_G(self):
         """Calculate GAN and L1 loss for the generator"""
-
-        fake_AB_1 = torch.cat((self.real_A, self.fake_B_1), 1)
-        fake_AB_2 = torch.cat((self.real_A, self.fake_B_2), 1)
-        fake_AB_3 = torch.cat((self.real_A, self.fake_B_3), 1)
-        fake_AB_4 = torch.cat((self.real_A, self.fake_B_4), 1)
-
-        fake_AB_5_1 = torch.cat((self.real_A, self.fake_B_5), 1)
-        fake_AB_5_2 = torch.cat((self.real_B_1, self.fake_B_5), 1)
-        fake_AB_5_3 = torch.cat((self.real_B_2, self.fake_B_5), 1)
-        fake_AB_5_4 = torch.cat((self.real_B_3, self.fake_B_5), 1)
-        fake_AB_5_5 = torch.cat((self.real_B_4, self.fake_B_5), 1)
-
-        pred_fake_1 = self.netD1(fake_AB_1)
-        pred_fake_2 = self.netD2(fake_AB_2)
-        pred_fake_3 = self.netD3(fake_AB_3)
-        pred_fake_4 = self.netD4(fake_AB_4)
-
-        pred_fake_5_1 = self.netD51(fake_AB_5_1)
-        pred_fake_5_2 = self.netD52(fake_AB_5_2)
-        pred_fake_5_3 = self.netD53(fake_AB_5_3)
-        pred_fake_5_4 = self.netD54(fake_AB_5_4)
-        pred_fake_5_5 = self.netD55(fake_AB_5_5)
-        pred_fake_5 = torch.stack(
-            [torch.mul(pred_fake_5_1, self.seg_weights[0]),
-             torch.mul(pred_fake_5_2, self.seg_weights[1]),
-             torch.mul(pred_fake_5_3, self.seg_weights[2]),
-             torch.mul(pred_fake_5_4, self.seg_weights[3]),
-             torch.mul(pred_fake_5_5, self.seg_weights[4])]).sum(dim=0)
-
-        self.loss_G_GAN_1 = self.criterionGAN_BCE(pred_fake_1, True)
-        self.loss_G_GAN_2 = self.criterionGAN_BCE(pred_fake_2, True)
-        self.loss_G_GAN_3 = self.criterionGAN_BCE(pred_fake_3, True)
-        self.loss_G_GAN_4 = self.criterionGAN_BCE(pred_fake_4, True)
-        self.loss_G_GAN_5 = self.criterionGAN_lsgan(pred_fake_5, True)
+
+        for i,model_name in enumerate(self.model_names_d):
+            fake_AB = torch.cat((self.real_A, getattr(self,f'fake_B_{i+1}')), 1)
+            pred_fake = getattr(self,f'net{model_name}')(fake_AB)
+            setattr(self,f'loss_G_GAN_{i+1}',self.criterionGAN_BCE(pred_fake, True))
+
+        l_pred_fake_seg = []
+        for i,model_name in enumerate(self.model_names_ds):
+            if i == 0:
+                fake_AB_seg_i = torch.cat((self.real_A, getattr(self,f'fake_B_{self.mod_id_seg}')), 1)
+            else:
+                fake_AB_seg_i = torch.cat((getattr(self,f'real_B_{i}'), getattr(self,f'fake_B_{self.mod_id_seg}')), 1)
+
+            pred_fake_seg_i = getattr(self,f'net{model_name}')(fake_AB_seg_i)
+            l_pred_fake_seg.append(torch.mul(pred_fake_seg_i, self.seg_weights[i]))
+        pred_fake_seg = torch.stack(l_pred_fake_seg).sum(dim=0)
+
+        setattr(self,f'loss_G_GAN_{self.mod_id_seg}',self.criterionGAN_lsgan(pred_fake_seg, True))
 
         # Second, G(A) = B
-        self.loss_G_L1_1 = self.criterionSmoothL1(self.fake_B_1, self.real_B_1) * self.opt.lambda_L1
-        self.loss_G_L1_2 = self.criterionSmoothL1(self.fake_B_2, self.real_B_2) * self.opt.lambda_L1
-        self.loss_G_L1_3 = self.criterionSmoothL1(self.fake_B_3, self.real_B_3) * self.opt.lambda_L1
-        self.loss_G_L1_4 = self.criterionSmoothL1(self.fake_B_4, self.real_B_4) * self.opt.lambda_L1
-        self.loss_G_L1_5 = self.criterionSmoothL1(self.fake_B_5, self.real_B_5) * self.opt.lambda_L1
-
-        self.loss_G_VGG_1 = self.criterionVGG(self.fake_B_1, self.real_B_1) * self.opt.lambda_feat
-        self.loss_G_VGG_2 = self.criterionVGG(self.fake_B_2, self.real_B_2) * self.opt.lambda_feat
-        self.loss_G_VGG_3 = self.criterionVGG(self.fake_B_3, self.real_B_3) * self.opt.lambda_feat
-        self.loss_G_VGG_4 = self.criterionVGG(self.fake_B_4, self.real_B_4) * self.opt.lambda_feat
+        for i in range(self.opt.modalities_no):
+            setattr(self,f'loss_G_L1_{i+1}',self.criterionSmoothL1(getattr(self,f'fake_B_{i+1}'), getattr(self,f'real_B_{i+1}')) * self.opt.lambda_L1)
+        setattr(self,f'loss_G_L1_{self.mod_id_seg}',self.criterionSmoothL1(getattr(self,f'fake_B_{self.mod_id_seg}'), getattr(self,f'real_B_{self.mod_id_seg}')) * self.opt.lambda_L1)
 
+        for i in range(self.opt.modalities_no):
+            setattr(self,f'loss_G_VGG_{i+1}',self.criterionVGG(getattr(self,f'fake_B_{i+1}'), getattr(self,f'real_B_{i+1}')) * self.opt.lambda_feat)
+        setattr(self,f'loss_G_VGG_{self.mod_id_seg}',self.criterionVGG(getattr(self,f'fake_B_{self.mod_id_seg}'), getattr(self,f'real_B_{self.mod_id_seg}')) * self.opt.lambda_feat)
 
         # .view(1,1,-1) reshapes the input (batch_size, 3, 512, 512) to (batch_size, 1, 3*512*512)
         # softmax/log-softmax is then applied on the concatenated vector of size (1, 3*512*512)
         # this normalizes the pixel values across all 3 RGB channels
         # the resulting vectors are then used to compute KL divergence loss
-        self.loss_G_KLDiv_1 = self.criterionKLDiv(self.logsoftmax(self.fake_B_1.view(1,1,-1)), self.softmax(self.fake_B_1_teacher.view(1,1,-1)))
-        self.loss_G_KLDiv_2 = self.criterionKLDiv(self.logsoftmax(self.fake_B_2.view(1,1,-1)), self.softmax(self.fake_B_2_teacher.view(1,1,-1)))
-        self.loss_G_KLDiv_3 = self.criterionKLDiv(self.logsoftmax(self.fake_B_3.view(1,1,-1)), self.softmax(self.fake_B_3_teacher.view(1,1,-1)))
-        self.loss_G_KLDiv_4 = self.criterionKLDiv(self.logsoftmax(self.fake_B_4.view(1,1,-1)), self.softmax(self.fake_B_4_teacher.view(1,1,-1)))
-        self.loss_G_KLDiv_5 = self.criterionKLDiv(self.logsoftmax(self.fake_B_5.view(1,1,-1)), self.softmax(self.fake_B_5_teacher.view(1,1,-1)))
-        self.loss_G_KLDiv_5_1 = self.criterionKLDiv(self.logsoftmax(self.fake_B_5_1.view(1,1,-1)), self.softmax(self.fake_B_5_1_teacher.view(1,1,-1)))
-        self.loss_G_KLDiv_5_2 = self.criterionKLDiv(self.logsoftmax(self.fake_B_5_2.view(1,1,-1)), self.softmax(self.fake_B_5_2_teacher.view(1,1,-1)))
-        self.loss_G_KLDiv_5_3 = self.criterionKLDiv(self.logsoftmax(self.fake_B_5_3.view(1,1,-1)), self.softmax(self.fake_B_5_3_teacher.view(1,1,-1)))
-        self.loss_G_KLDiv_5_4 = self.criterionKLDiv(self.logsoftmax(self.fake_B_5_4.view(1,1,-1)), self.softmax(self.fake_B_5_4_teacher.view(1,1,-1)))
-        self.loss_G_KLDiv_5_5 = self.criterionKLDiv(self.logsoftmax(self.fake_B_5_5.view(1,1,-1)), self.softmax(self.fake_B_5_5_teacher.view(1,1,-1)))
-
-        self.loss_G = (self.loss_G_GAN_1 + self.loss_G_L1_1 + self.loss_G_VGG_1) * self.loss_G_weights[0] + \
-                      (self.loss_G_GAN_2 + self.loss_G_L1_2 + self.loss_G_VGG_2) * self.loss_G_weights[1] + \
-                      (self.loss_G_GAN_3 + self.loss_G_L1_3 + self.loss_G_VGG_3) * self.loss_G_weights[2] + \
-                      (self.loss_G_GAN_4 + self.loss_G_L1_4 + self.loss_G_VGG_4) * self.loss_G_weights[3] + \
-                      (self.loss_G_GAN_5 + self.loss_G_L1_5) * self.loss_G_weights[4] + \
-                      (self.loss_G_KLDiv_1 + self.loss_G_KLDiv_2 + self.loss_G_KLDiv_3 + self.loss_G_KLDiv_4 + \
-                       self.loss_G_KLDiv_5 + self.loss_G_KLDiv_5_1 + self.loss_G_KLDiv_5_2 + self.loss_G_KLDiv_5_3 + \
-                       self.loss_G_KLDiv_5_4 + self.loss_G_KLDiv_5_5) * 10
+        # self.loss_G_KLDiv_1 = self.criterionKLDiv(self.logsoftmax(self.fake_B_1.view(1,1,-1)), self.softmax(self.fake_B_1_teacher.view(1,1,-1)))
+        # self.loss_G_KLDiv_2 = self.criterionKLDiv(self.logsoftmax(self.fake_B_2.view(1,1,-1)), self.softmax(self.fake_B_2_teacher.view(1,1,-1)))
+        # self.loss_G_KLDiv_3 = self.criterionKLDiv(self.logsoftmax(self.fake_B_3.view(1,1,-1)), self.softmax(self.fake_B_3_teacher.view(1,1,-1)))
+        # self.loss_G_KLDiv_4 = self.criterionKLDiv(self.logsoftmax(self.fake_B_4.view(1,1,-1)), self.softmax(self.fake_B_4_teacher.view(1,1,-1)))
+        # self.loss_G_KLDiv_5 = self.criterionKLDiv(self.logsoftmax(self.fake_B_5.view(1,1,-1)), self.softmax(self.fake_B_5_teacher.view(1,1,-1)))
+
+        for i in range(self.opt.modalities_no):
+            setattr(self,f'loss_G_KLDiv_{i+1}',self.criterionKLDiv(self.logsoftmax(getattr(self,f'fake_B_{i+1}').view(1,1,-1)), self.softmax(getattr(self,f'fake_B_{i+1}_teacher').view(1,1,-1))))
+        setattr(self,f'loss_G_KLDiv_{self.mod_id_seg}',self.criterionKLDiv(self.logsoftmax(getattr(self,f'fake_B_{self.mod_id_seg}').view(1,1,-1)), self.softmax(getattr(self,f'fake_B_{self.mod_id_seg}_teacher').view(1,1,-1))))
+
+        # self.loss_G_KLDiv_5_1 = self.criterionKLDiv(self.logsoftmax(self.fake_B_5_1.view(1,1,-1)), self.softmax(self.fake_B_5_1_teacher.view(1,1,-1)))
+        # self.loss_G_KLDiv_5_2 = self.criterionKLDiv(self.logsoftmax(self.fake_B_5_2.view(1,1,-1)), self.softmax(self.fake_B_5_2_teacher.view(1,1,-1)))
+        # self.loss_G_KLDiv_5_3 = self.criterionKLDiv(self.logsoftmax(self.fake_B_5_3.view(1,1,-1)), self.softmax(self.fake_B_5_3_teacher.view(1,1,-1)))
+        # self.loss_G_KLDiv_5_4 = self.criterionKLDiv(self.logsoftmax(self.fake_B_5_4.view(1,1,-1)), self.softmax(self.fake_B_5_4_teacher.view(1,1,-1)))
+        # self.loss_G_KLDiv_5_5 = self.criterionKLDiv(self.logsoftmax(self.fake_B_5_5.view(1,1,-1)), self.softmax(self.fake_B_5_5_teacher.view(1,1,-1)))
+
+        for i in range(self.opt.modalities_no+1):
+            setattr(self,f'loss_G_KLDiv_{self.mod_id_seg}{i}',self.criterionKLDiv(self.logsoftmax(getattr(self,f'fake_B_{self.mod_id_seg}_{i}').view(1,1,-1)), self.softmax(getattr(self,f'fake_B_{self.mod_id_seg}_{i}_teacher').view(1,1,-1))))
+        #setattr(self,f'loss_G_KLDiv_{self.mod_id_seg}{self.opt.modalities_no+1}',self.criterionKLDiv(self.logsoftmax(getattr(self,f'fake_B_{self.mod_id_seg}_{self.opt.modalities_no+1}').view(1,1,-1)), self.softmax(getattr(self,f'fake_B_{self.mod_id_seg}_teacher').view(1,1,-1))))
+
+
+        # self.loss_G = (self.loss_G_GAN_1 + self.loss_G_L1_1 + self.loss_G_VGG_1) * self.loss_G_weights[0] + \
+        #               (self.loss_G_GAN_2 + self.loss_G_L1_2 + self.loss_G_VGG_2) * self.loss_G_weights[1] + \
+        #               (self.loss_G_GAN_3 + self.loss_G_L1_3 + self.loss_G_VGG_3) * self.loss_G_weights[2] + \
+        #               (self.loss_G_GAN_4 + self.loss_G_L1_4 + self.loss_G_VGG_4) * self.loss_G_weights[3] + \
+        #               (self.loss_G_GAN_5 + self.loss_G_L1_5) * self.loss_G_weights[4] + \
+        #               (self.loss_G_KLDiv_1 + self.loss_G_KLDiv_2 + self.loss_G_KLDiv_3 + self.loss_G_KLDiv_4 + \
+        #               self.loss_G_KLDiv_5 + self.loss_G_KLDiv_5_1 + self.loss_G_KLDiv_5_2 + self.loss_G_KLDiv_5_3 + \
+        #               self.loss_G_KLDiv_5_4 + self.loss_G_KLDiv_5_5) * 10
+
+        self.loss_G = torch.tensor(0., device=self.device)
+        for i in range(self.opt.modalities_no):
+            self.loss_G += (getattr(self,f'loss_G_GAN_{i+1}') + getattr(self,f'loss_G_L1_{i+1}') + getattr(self,f'loss_G_VGG_{i+1}')) * self.loss_G_weights[i]
+        self.loss_G += (getattr(self,f'loss_G_GAN_{self.mod_id_seg}') + getattr(self,f'loss_G_L1_{self.mod_id_seg}')) * self.loss_G_weights[i]
+
+        factor_KLDiv = 10
+        for i in range(self.opt.modalities_no):
+            self.loss_G += (getattr(self,f'loss_G_KLDiv_{i+1}') + getattr(self,f'loss_G_KLDiv_{self.mod_id_seg}{i+1}')) * factor_KLDiv
+        self.loss_G += getattr(self,f'loss_G_KLDiv_{self.mod_id_seg}') * factor_KLDiv
+        if self.input_id == '0':
+            self.loss_G += getattr(self,f'loss_G_KLDiv_{self.mod_id_seg}0') * factor_KLDiv
+        else:
+            self.loss_G += getattr(self,f'loss_G_KLDiv_{self.mod_id_seg}{self.opt.modalities_no+1}') * factor_KLDiv
+
 
         # combine loss and calculate gradients
         # self.loss_G = (self.loss_G_GAN_1 + self.loss_G_L1_1) * self.loss_G_weights[0] + \
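
Rebuilding self.loss_G (and self.loss_D above) as a zero scalar on the right device plus in-place additions keeps the total differentiable while letting the number of terms follow modalities_no; the distillation terms still enter with the fixed weight of 10, now named factor_KLDiv. A minimal check of the accumulation pattern (weights and terms here are arbitrary stand-ins for the loss components):

    import torch

    loss_weights = [1.0, 0.5, 0.25]
    terms = [torch.randn((), requires_grad=True) for _ in loss_weights]

    loss_G = torch.tensor(0., device='cpu')  # scalar seed for the running sum
    for w, t in zip(loss_weights, terms):
        loss_G = loss_G + t * w  # the diff uses +=, which behaves the same here
    loss_G.backward()
    print([t.grad.item() for t in terms])  # each gradient equals its weight
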
@@ -344,30 +360,16 @@ class DeepLIIFKDModel(BaseModel):
     def optimize_parameters(self):
         self.forward() # compute fake images: G(A)
         # update D
-        self.set_requires_grad(self.netD1, True) # enable backprop for D1
-        self.set_requires_grad(self.netD2, True) # enable backprop for D2
-        self.set_requires_grad(self.netD3, True) # enable backprop for D3
-        self.set_requires_grad(self.netD4, True) # enable backprop for D4
-        self.set_requires_grad(self.netD51, True) # enable backprop for D51
-        self.set_requires_grad(self.netD52, True) # enable backprop for D52
-        self.set_requires_grad(self.netD53, True) # enable backprop for D53
-        self.set_requires_grad(self.netD54, True) # enable backprop for D54
-        self.set_requires_grad(self.netD55, True) # enable backprop for D54
+        for model_name in self.model_names_d + self.model_names_ds:
+            self.set_requires_grad(getattr(self,f'net{model_name}'), True)
 
         self.optimizer_D.zero_grad() # set D's gradients to zero
         self.backward_D() # calculate gradients for D
         self.optimizer_D.step() # update D's weights
 
         # update G
-        self.set_requires_grad(self.netD1, False) # D1 requires no gradients when optimizing G1
-        self.set_requires_grad(self.netD2, False) # D2 requires no gradients when optimizing G2
-        self.set_requires_grad(self.netD3, False) # D3 requires no gradients when optimizing G3
-        self.set_requires_grad(self.netD4, False) # D4 requires no gradients when optimizing G4
-        self.set_requires_grad(self.netD51, False) # D51 requires no gradients when optimizing G51
-        self.set_requires_grad(self.netD52, False) # D52 requires no gradients when optimizing G52
-        self.set_requires_grad(self.netD53, False) # D53 requires no gradients when optimizing G53
-        self.set_requires_grad(self.netD54, False) # D54 requires no gradients when optimizing G54
-        self.set_requires_grad(self.netD55, False) # D54 requires no gradients when optimizing G54
+        for model_name in self.model_names_d + self.model_names_ds:
+            self.set_requires_grad(getattr(self,f'net{model_name}'), False)
 
         self.optimizer_G.zero_grad() # set G's gradients to zero
         self.backward_G() # calculate graidents for G
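
The alternating update relies on set_requires_grad (inherited from the pix2pix-style BaseModel) to freeze the discriminators while the generators step, and vice versa; the loop over model_names_d + model_names_ds replaces the nine hand-written toggles. The underlying mechanism in plain PyTorch (the helper body below is an assumption based on the usual BaseModel implementation, not copied from deepliif):

    import torch
    from torch import nn

    def set_requires_grad(net, flag):
        # assumed equivalent of BaseModel.set_requires_grad
        for p in net.parameters():
            p.requires_grad = flag

    D = nn.Linear(4, 1)
    set_requires_grad(D, False)  # during the G update, D is frozen
    x = torch.randn(2, 4, requires_grad=True)  # stands in for G's output
    D(x).sum().backward()  # gradients flow back to x (the generator side)...
    print(D.weight.grad)   # ...but stay None for D: the toggle worked
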
@@ -380,29 +382,15 @@ class DeepLIIFKDModel(BaseModel):
 
         self.forward() # compute fake images: G(A)
         # update D
-        self.set_requires_grad(self.netD1, True) # enable backprop for D1
-        self.set_requires_grad(self.netD2, True) # enable backprop for D2
-        self.set_requires_grad(self.netD3, True) # enable backprop for D3
-        self.set_requires_grad(self.netD4, True) # enable backprop for D4
-        self.set_requires_grad(self.netD51, True) # enable backprop for D51
-        self.set_requires_grad(self.netD52, True) # enable backprop for D52
-        self.set_requires_grad(self.netD53, True) # enable backprop for D53
-        self.set_requires_grad(self.netD54, True) # enable backprop for D54
-        self.set_requires_grad(self.netD55, True) # enable backprop for D54
+        for model_name in self.model_names_d + self.model_names_ds:
+            self.set_requires_grad(getattr(self,f'net{model_name}'), True)
 
         self.optimizer_D.zero_grad() # set D's gradients to zero
         self.backward_D() # calculate gradients for D
 
         # update G
-        self.set_requires_grad(self.netD1, False) # D1 requires no gradients when optimizing G1
-        self.set_requires_grad(self.netD2, False) # D2 requires no gradients when optimizing G2
-        self.set_requires_grad(self.netD3, False) # D3 requires no gradients when optimizing G3
-        self.set_requires_grad(self.netD4, False) # D4 requires no gradients when optimizing G4
-        self.set_requires_grad(self.netD51, False) # D51 requires no gradients when optimizing G51
-        self.set_requires_grad(self.netD52, False) # D52 requires no gradients when optimizing G52
-        self.set_requires_grad(self.netD53, False) # D53 requires no gradients when optimizing G53
-        self.set_requires_grad(self.netD54, False) # D54 requires no gradients when optimizing G54
-        self.set_requires_grad(self.netD55, False) # D54 requires no gradients when optimizing G54
+        for model_name in self.model_names_d + self.model_names_ds:
+            self.set_requires_grad(getattr(self,f'net{model_name}'), False)
 
         self.optimizer_G.zero_grad() # set G's gradients to zero
         self.backward_G() # calculate graidents for G