deepliif 1.2.3__py3-none-any.whl → 1.2.5__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli.py +79 -24
- deepliif/data/base_dataset.py +2 -0
- deepliif/models/DeepLIIFKD_model.py +243 -255
- deepliif/models/DeepLIIF_model.py +344 -235
- deepliif/models/__init__.py +194 -103
- deepliif/models/base_model.py +7 -2
- deepliif/options/__init__.py +40 -8
- deepliif/postprocessing.py +1 -1
- deepliif/util/__init__.py +168 -8
- deepliif/util/util.py +85 -0
- deepliif/util/visualizer.py +2 -2
- {deepliif-1.2.3.dist-info → deepliif-1.2.5.dist-info}/METADATA +2 -2
- {deepliif-1.2.3.dist-info → deepliif-1.2.5.dist-info}/RECORD +17 -17
- {deepliif-1.2.3.dist-info → deepliif-1.2.5.dist-info}/LICENSE.md +0 -0
- {deepliif-1.2.3.dist-info → deepliif-1.2.5.dist-info}/WHEEL +0 -0
- {deepliif-1.2.3.dist-info → deepliif-1.2.5.dist-info}/entry_points.txt +0 -0
- {deepliif-1.2.3.dist-info → deepliif-1.2.5.dist-info}/top_level.txt +0 -0
|
@@ -4,6 +4,7 @@ from . import networks
|
|
|
4
4
|
from .networks import get_optimizer
|
|
5
5
|
from . import init_nets, run_dask, get_opt
|
|
6
6
|
from torch import nn
|
|
7
|
+
from ..util.util import get_input_id, init_input_and_mod_id, map_model_names
|
|
7
8
|
|
|
8
9
|
class DeepLIIFKDModel(BaseModel):
|
|
9
10
|
""" This class implements the DeepLIIF model, for learning a mapping from input images to modalities given paired data."""
|
|
@@ -21,91 +22,108 @@ class DeepLIIFKDModel(BaseModel):
|
|
|
21
22
|
self.seg_weights = opt.seg_weights
|
|
22
23
|
self.loss_G_weights = opt.loss_G_weights
|
|
23
24
|
self.loss_D_weights = opt.loss_D_weights
|
|
25
|
+
self.mod_id_seg, self.input_id = init_input_and_mod_id(opt) # creates self.input_id, self.mod_id_seg
|
|
26
|
+
print(f'Initializing model with segmentation modality id {self.mod_id_seg}, input id {self.input_id}')
|
|
24
27
|
|
|
25
28
|
if not opt.is_train:
|
|
26
29
|
self.gpu_ids = [] # avoid the models being loaded as DP
|
|
27
30
|
else:
|
|
28
31
|
self.gpu_ids = opt.gpu_ids
|
|
29
|
-
|
|
32
|
+
|
|
30
33
|
self.loss_names = []
|
|
31
34
|
self.visual_names = ['real_A']
|
|
32
35
|
# specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
|
|
33
|
-
for i in range(
|
|
34
|
-
self.loss_names.extend([f'G_GAN_{i}', f'G_L1_{i}', f'D_real_{i}', f'D_fake_{i}', f'G_KLDiv_{i}', f'
|
|
35
|
-
self.visual_names.extend([f'fake_B_{i}', f'
|
|
36
|
+
for i in range(self.opt.modalities_no):
|
|
37
|
+
self.loss_names.extend([f'G_GAN_{i+1}', f'G_L1_{i+1}', f'D_real_{i+1}', f'D_fake_{i+1}', f'G_KLDiv_{i+1}', f'G_KLDiv_{self.mod_id_seg}{i+1}'])
|
|
38
|
+
self.visual_names.extend([f'fake_B_{i+1}', f'fake_B_{i+1}_teacher', f'real_B_{i+1}'])
|
|
39
|
+
self.loss_names.extend([f'G_GAN_{self.mod_id_seg}',f'G_L1_{self.mod_id_seg}',f'D_real_{self.mod_id_seg}',f'D_fake_{self.mod_id_seg}',
|
|
40
|
+
f'G_KLDiv_{self.mod_id_seg}',f'G_KLDiv_{self.mod_id_seg}{self.opt.modalities_no}'])
|
|
41
|
+
|
|
42
|
+
for i in range(self.opt.modalities_no+1):
|
|
43
|
+
self.visual_names.extend([f'fake_B_{self.mod_id_seg}{i}', f'fake_B_{self.mod_id_seg}{i}_teacher']) # 0 is used for the base input mod
|
|
44
|
+
self.visual_names.extend([f'fake_B_{self.mod_id_seg}', f'fake_B_{self.mod_id_seg}_teacher', f'real_B_{self.mod_id_seg}',])
|
|
36
45
|
|
|
37
46
|
# specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
|
|
38
47
|
# specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
|
|
48
|
+
self.model_names_g = []
|
|
49
|
+
self.model_names_gs = []
|
|
39
50
|
if self.is_train:
|
|
40
51
|
self.model_names = []
|
|
41
52
|
for i in range(1, self.opt.modalities_no + 1):
|
|
42
|
-
self.model_names.extend(['G'
|
|
43
|
-
|
|
44
|
-
|
|
45
|
-
|
|
53
|
+
self.model_names.extend([f'G{i}', f'D{i}'])
|
|
54
|
+
self.model_names_g.append(f'G{i}')
|
|
55
|
+
|
|
56
|
+
for i in range(self.opt.modalities_no + 1): # 0 is used for the base input mod
|
|
57
|
+
if self.input_id == '0':
|
|
58
|
+
self.model_names.extend([f'G{self.mod_id_seg}{i}', f'D{self.mod_id_seg}{i}'])
|
|
59
|
+
self.model_names_gs.append(f'G{self.mod_id_seg}{i}')
|
|
60
|
+
else:
|
|
61
|
+
self.model_names.extend([f'G{self.mod_id_seg}{i+1}', f'D{self.mod_id_seg}{i+1}'])
|
|
62
|
+
self.model_names_gs.append(f'G{self.mod_id_seg}{i+1}')
|
|
46
63
|
else: # during test time, only load G
|
|
47
64
|
self.model_names = []
|
|
48
65
|
for i in range(1, self.opt.modalities_no + 1):
|
|
49
|
-
self.model_names.extend(['G'
|
|
50
|
-
|
|
51
|
-
|
|
52
|
-
|
|
53
|
-
|
|
66
|
+
self.model_names.extend([f'G{i}'])
|
|
67
|
+
self.model_names_g.append(f'G{i}')
|
|
68
|
+
|
|
69
|
+
#input_id = get_input_id(os.path.join(opt.checkpoints_dir, opt.name))
|
|
70
|
+
if self.input_id == '0':
|
|
71
|
+
for i in range(self.opt.modalities_no + 1): # 0 is used for the base input mod
|
|
72
|
+
self.model_names.extend([f'G{self.mod_id_seg}{i}'])
|
|
73
|
+
self.model_names_gs.append(f'G{self.mod_id_seg}{i}')
|
|
74
|
+
else:
|
|
75
|
+
for i in range(self.opt.modalities_no + 1): # old setting, 1 is used for the base input mod
|
|
76
|
+
self.model_names.extend([f'G{self.mod_id_seg}{i+1}'])
|
|
77
|
+
self.model_names_gs.append(f'G{self.mod_id_seg}{i+1}')
|
|
78
|
+
|
|
54
79
|
# define networks (both generator and discriminator)
|
|
55
80
|
if isinstance(opt.netG, str):
|
|
56
|
-
opt.netG = [opt.netG] *
|
|
81
|
+
opt.netG = [opt.netG] * self.opt.modalities_no
|
|
57
82
|
if isinstance(opt.net_gs, str):
|
|
58
|
-
opt.net_gs = [opt.net_gs]*
|
|
83
|
+
opt.net_gs = [opt.net_gs] * (self.opt.modalities_no + 1) # +1 for base input mod
|
|
59
84
|
|
|
60
85
|
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
|
|
64
|
-
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.padding)
|
|
65
|
-
self.netG3 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG[2], opt.norm,
|
|
66
|
-
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.padding)
|
|
67
|
-
self.netG4 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG[3], opt.norm,
|
|
68
|
-
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.padding)
|
|
86
|
+
for i,model_name in enumerate(self.model_names_g):
|
|
87
|
+
setattr(self,f'net{model_name}',networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG[i], opt.norm,
|
|
88
|
+
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.padding))
|
|
69
89
|
|
|
70
90
|
# DeepLIIF model currently uses one gs arch because there is only one explicit seg mod output
|
|
71
|
-
|
|
72
|
-
|
|
73
|
-
|
|
74
|
-
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
|
|
75
|
-
self.netG53 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[2], opt.norm,
|
|
76
|
-
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
|
|
77
|
-
self.netG54 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[3], opt.norm,
|
|
78
|
-
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
|
|
79
|
-
self.netG55 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[4], opt.norm,
|
|
80
|
-
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
|
|
81
|
-
|
|
91
|
+
for i,model_name in enumerate(self.model_names_gs):
|
|
92
|
+
setattr(self,f'net{model_name}',networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[i], opt.norm,
|
|
93
|
+
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids))
|
|
82
94
|
|
|
83
95
|
if self.is_train: # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc
|
|
84
|
-
self.
|
|
85
|
-
|
|
86
|
-
|
|
87
|
-
|
|
88
|
-
|
|
89
|
-
|
|
90
|
-
|
|
91
|
-
opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
96
|
-
|
|
97
|
-
self.netD53 = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
|
|
98
|
-
opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
|
|
99
|
-
self.netD54 = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
|
|
100
|
-
opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
|
|
101
|
-
self.netD55 = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
|
|
102
|
-
opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
|
|
103
|
-
|
|
96
|
+
self.model_names_d = [f'D{i+1}' for i in range(self.opt.modalities_no)]
|
|
97
|
+
if self.input_id == '0':
|
|
98
|
+
self.model_names_ds = [f'D{self.mod_id_seg}{i}' for i in range(self.opt.modalities_no+1)]
|
|
99
|
+
else:
|
|
100
|
+
self.model_names_ds = [f'D{self.mod_id_seg}{i+1}' for i in range(self.opt.modalities_no+1)]
|
|
101
|
+
for model_name in self.model_names_d:
|
|
102
|
+
setattr(self,f'net{model_name}',networks.define_D(opt.input_nc+opt.output_nc , opt.ndf, opt.netD,
|
|
103
|
+
opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids))
|
|
104
|
+
|
|
105
|
+
for model_name in self.model_names_ds:
|
|
106
|
+
setattr(self,f'net{model_name}',networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
|
|
107
|
+
opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids))
|
|
108
|
+
|
|
104
109
|
# load the teacher model
|
|
105
110
|
self.opt_teacher = get_opt(opt.model_dir_teacher, mode='test')
|
|
106
111
|
self.opt_teacher.gpu_ids = opt.gpu_ids # use student's gpu_ids
|
|
107
112
|
self.nets_teacher = init_nets(opt.model_dir_teacher, eager_mode=True, opt=self.opt_teacher, phase='test')
|
|
108
113
|
|
|
114
|
+
# modify model names to be consistent with the current deepliifkd model names
|
|
115
|
+
# otherwise it may be tricky to pair the loss terms?
|
|
116
|
+
self.opt_teacher.mod_id_seg, self.opt_teacher.input_id = init_input_and_mod_id(self.opt_teacher)
|
|
117
|
+
d_mapping_model_name = map_model_names(list(self.nets_teacher.keys()),self.opt_teacher.mod_id_seg,self.opt_teacher.input_id,
|
|
118
|
+
self.mod_id_seg,self.input_id)
|
|
119
|
+
self.d_mapping_model_name = d_mapping_model_name
|
|
120
|
+
print('Model name mapping, teacher model to student model:',self.d_mapping_model_name)
|
|
121
|
+
else:
|
|
122
|
+
# remove all model names for the teacher model
|
|
123
|
+
self.model_names = [name for name in self.model_names if not name.endswith('_teacher')]
|
|
124
|
+
self.loss_names = [name for name in self.loss_names if not name.endswith('_teacher')]
|
|
125
|
+
self.visual_names = [name for name in self.visual_names if not name.endswith('_teacher')]
|
|
126
|
+
|
|
109
127
|
|
|
110
128
|
if self.is_train:
|
|
111
129
|
# define loss functions
|
|
@@ -114,14 +132,22 @@ class DeepLIIFKDModel(BaseModel):
|
|
|
114
132
|
self.criterionSmoothL1 = torch.nn.SmoothL1Loss()
|
|
115
133
|
|
|
116
134
|
# initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
|
|
117
|
-
params = list(self.netG1.parameters()) + list(self.netG2.parameters()) + list(self.netG3.parameters()) + list(self.netG4.parameters()) + list(self.netG51.parameters()) + list(self.netG52.parameters()) + list(self.netG53.parameters()) + list(self.netG54.parameters()) + list(self.netG55.parameters())
|
|
135
|
+
#params = list(self.netG1.parameters()) + list(self.netG2.parameters()) + list(self.netG3.parameters()) + list(self.netG4.parameters()) + list(self.netG51.parameters()) + list(self.netG52.parameters()) + list(self.netG53.parameters()) + list(self.netG54.parameters()) + list(self.netG55.parameters())
|
|
136
|
+
params = []
|
|
137
|
+
for model_name in self.model_names_g + self.model_names_gs:
|
|
138
|
+
params += list(getattr(self,f'net{model_name}').parameters())
|
|
139
|
+
|
|
118
140
|
try:
|
|
119
141
|
self.optimizer_G = get_optimizer(opt.optimizer)(params, lr=opt.lr_g, betas=(opt.beta1, 0.999))
|
|
120
142
|
except:
|
|
121
143
|
print(f'betas are not used for optimizer torch.optim.{opt.optimizer} in generators')
|
|
122
144
|
self.optimizer_G = get_optimizer(opt.optimizer)(params, lr=opt.lr_g)
|
|
123
145
|
|
|
124
|
-
params = list(self.netD1.parameters()) + list(self.netD2.parameters()) + list(self.netD3.parameters()) + list(self.netD4.parameters()) + list(self.netD51.parameters()) + list(self.netD52.parameters()) + list(self.netD53.parameters()) + list(self.netD54.parameters()) + list(self.netD55.parameters())
|
|
146
|
+
#params = list(self.netD1.parameters()) + list(self.netD2.parameters()) + list(self.netD3.parameters()) + list(self.netD4.parameters()) + list(self.netD51.parameters()) + list(self.netD52.parameters()) + list(self.netD53.parameters()) + list(self.netD54.parameters()) + list(self.netD55.parameters())
|
|
147
|
+
params = []
|
|
148
|
+
for model_name in self.model_names_d + self.model_names_ds:
|
|
149
|
+
params += list(getattr(self,f'net{model_name}').parameters())
|
|
150
|
+
|
|
125
151
|
try:
|
|
126
152
|
self.optimizer_D = get_optimizer(opt.optimizer)(params, lr=opt.lr_d, betas=(opt.beta1, 0.999))
|
|
127
153
|
except:
|
|
@@ -131,6 +157,7 @@ class DeepLIIFKDModel(BaseModel):
|
|
|
131
157
|
self.optimizers.append(self.optimizer_G)
|
|
132
158
|
self.optimizers.append(self.optimizer_D)
|
|
133
159
|
|
|
160
|
+
print('self.device',self.device)
|
|
134
161
|
self.criterionVGG = networks.VGGLoss().to(self.device)
|
|
135
162
|
self.criterionKLDiv = torch.nn.KLDivLoss(reduction='batchmean').to(self.device)
|
|
136
163
|
self.softmax = torch.nn.Softmax(dim=-1).to(self.device) # apply softmax on the last dim
|
|
@@ -146,192 +173,181 @@ class DeepLIIFKDModel(BaseModel):
|
|
|
146
173
|
self.real_A = input['A'].to(self.device)
|
|
147
174
|
|
|
148
175
|
self.real_B_array = input['B']
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
self.
|
|
152
|
-
|
|
153
|
-
self.real_B_5 = self.real_B_array[4].to(self.device)
|
|
176
|
+
for i in range(self.opt.modalities_no):
|
|
177
|
+
setattr(self,f'real_B_{i+1}',self.real_B_array[i].to(self.device))
|
|
178
|
+
setattr(self,f'real_B_{self.mod_id_seg}',self.real_B_array[self.opt.modalities_no].to(self.device)) # the last one is seg
|
|
179
|
+
|
|
154
180
|
self.image_paths = input['A_paths']
|
|
155
181
|
|
|
156
182
|
def forward(self):
|
|
157
183
|
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
|
|
158
|
-
self.fake_B_1 = self.netG1(self.real_A) # Hematoxylin image generator
|
|
159
|
-
self.fake_B_2 = self.netG2(self.real_A) # mpIF DAPI image generator
|
|
160
|
-
self.fake_B_3 = self.netG3(self.real_A) # mpIF Lap2 image generator
|
|
161
|
-
self.fake_B_4 = self.netG4(self.real_A) # mpIF Ki67 image generator
|
|
162
|
-
|
|
163
|
-
|
|
164
|
-
|
|
165
|
-
|
|
166
|
-
self.
|
|
167
|
-
self.
|
|
168
|
-
self.
|
|
169
|
-
|
|
170
|
-
|
|
171
|
-
|
|
172
|
-
|
|
184
|
+
# self.fake_B_1 = self.netG1(self.real_A) # Hematoxylin image generator
|
|
185
|
+
# self.fake_B_2 = self.netG2(self.real_A) # mpIF DAPI image generator
|
|
186
|
+
# self.fake_B_3 = self.netG3(self.real_A) # mpIF Lap2 image generator
|
|
187
|
+
# self.fake_B_4 = self.netG4(self.real_A) # mpIF Ki67 image generator
|
|
188
|
+
|
|
189
|
+
for i in range(self.opt.modalities_no):
|
|
190
|
+
setattr(self,f'fake_B_{i+1}',getattr(self,f'netG{i+1}')(self.real_A))
|
|
191
|
+
|
|
192
|
+
# self.fake_B_5_1 = self.netG51(self.real_A) # Segmentation mask generator from IHC input image
|
|
193
|
+
# self.fake_B_5_2 = self.netG52(self.fake_B_1) # Segmentation mask generator from Hematoxylin input image
|
|
194
|
+
# self.fake_B_5_3 = self.netG53(self.fake_B_2) # Segmentation mask generator from mpIF DAPI input image
|
|
195
|
+
# self.fake_B_5_4 = self.netG54(self.fake_B_3) # Segmentation mask generator from mpIF Lap2 input image
|
|
196
|
+
# self.fake_B_5_5 = self.netG55(self.fake_B_4) # Segmentation mask generator from mpIF Lap2 input image
|
|
197
|
+
# self.fake_B_5 = torch.stack([torch.mul(self.fake_B_5_1, self.seg_weights[0]),
|
|
198
|
+
# torch.mul(self.fake_B_5_2, self.seg_weights[1]),
|
|
199
|
+
# torch.mul(self.fake_B_5_3, self.seg_weights[2]),
|
|
200
|
+
# torch.mul(self.fake_B_5_4, self.seg_weights[3]),
|
|
201
|
+
# torch.mul(self.fake_B_5_5, self.seg_weights[4])]).sum(dim=0)
|
|
173
202
|
|
|
174
|
-
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
178
|
-
|
|
203
|
+
for i,model_name in enumerate(self.model_names_gs):
|
|
204
|
+
if i == 0:
|
|
205
|
+
setattr(self,f'fake_B_{self.mod_id_seg}_{i}',getattr(self,f'net{model_name}')(self.real_A))
|
|
206
|
+
else:
|
|
207
|
+
setattr(self,f'fake_B_{self.mod_id_seg}_{i}',getattr(self,f'net{model_name}')(getattr(self,f'fake_B_{i}')))
|
|
208
|
+
|
|
209
|
+
setattr(self,f'fake_B_{self.mod_id_seg}',torch.stack([torch.mul(getattr(self,f'fake_B_{self.mod_id_seg}_{i}'), self.seg_weights[i]) for i in range(self.opt.modalities_no+1)]).sum(dim=0))
|
|
210
|
+
|
|
211
|
+
if self.is_train:
|
|
212
|
+
fakes_teacher = run_dask(img=self.real_A, nets=self.nets_teacher, opt=self.opt_teacher, use_dask=False, output_tensor=True)
|
|
213
|
+
#print(f'Checking seg mod id for teacher model: current id is {self.opt_teacher.mod_id_seg}, id to map to is {self.mod_id_seg}')
|
|
214
|
+
for k,v in fakes_teacher.items():
|
|
215
|
+
suffix = self.d_mapping_model_name[k][1:] # starts with G
|
|
216
|
+
l_suffix = list(suffix)
|
|
217
|
+
if l_suffix[0] == str(self.opt_teacher.mod_id_seg): # mod_id_seg might be integer
|
|
218
|
+
if l_suffix[0] != str(self.mod_id_seg):
|
|
219
|
+
l_suffix[0] = str(self.mod_id_seg)
|
|
220
|
+
#suffix = '_'.join(list(suffix)) # 51 -> 5_1
|
|
221
|
+
suffix = '_'.join(l_suffix) # 51 -> 5_1
|
|
222
|
+
#print(f'Loaded teacher model: fake_B_{suffix}_teacher')
|
|
223
|
+
setattr(self,f'fake_B_{suffix}_teacher',v.to(self.device))
|
|
179
224
|
|
|
180
225
|
def backward_D(self):
|
|
181
226
|
"""Calculate GAN loss for the discriminators"""
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
|
|
189
|
-
|
|
190
|
-
|
|
191
|
-
|
|
192
|
-
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
197
|
-
|
|
198
|
-
|
|
199
|
-
|
|
200
|
-
|
|
201
|
-
|
|
202
|
-
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
real_AB_2 = torch.cat((self.real_A, self.real_B_2), 1)
|
|
220
|
-
real_AB_3 = torch.cat((self.real_A, self.real_B_3), 1)
|
|
221
|
-
real_AB_4 = torch.cat((self.real_A, self.real_B_4), 1)
|
|
222
|
-
|
|
223
|
-
pred_real_1 = self.netD1(real_AB_1)
|
|
224
|
-
pred_real_2 = self.netD2(real_AB_2)
|
|
225
|
-
pred_real_3 = self.netD3(real_AB_3)
|
|
226
|
-
pred_real_4 = self.netD4(real_AB_4)
|
|
227
|
-
|
|
228
|
-
real_AB_5_1 = torch.cat((self.real_A, self.real_B_5), 1)
|
|
229
|
-
real_AB_5_2 = torch.cat((self.real_B_1, self.real_B_5), 1)
|
|
230
|
-
real_AB_5_3 = torch.cat((self.real_B_2, self.real_B_5), 1)
|
|
231
|
-
real_AB_5_4 = torch.cat((self.real_B_3, self.real_B_5), 1)
|
|
232
|
-
real_AB_5_5 = torch.cat((self.real_B_4, self.real_B_5), 1)
|
|
233
|
-
|
|
234
|
-
pred_real_5_1 = self.netD51(real_AB_5_1)
|
|
235
|
-
pred_real_5_2 = self.netD52(real_AB_5_2)
|
|
236
|
-
pred_real_5_3 = self.netD53(real_AB_5_3)
|
|
237
|
-
pred_real_5_4 = self.netD54(real_AB_5_4)
|
|
238
|
-
pred_real_5_5 = self.netD55(real_AB_5_5)
|
|
239
|
-
|
|
240
|
-
pred_real_5 = torch.stack(
|
|
241
|
-
[torch.mul(pred_real_5_1, self.seg_weights[0]),
|
|
242
|
-
torch.mul(pred_real_5_2, self.seg_weights[1]),
|
|
243
|
-
torch.mul(pred_real_5_3, self.seg_weights[2]),
|
|
244
|
-
torch.mul(pred_real_5_4, self.seg_weights[3]),
|
|
245
|
-
torch.mul(pred_real_5_5, self.seg_weights[4])]).sum(dim=0)
|
|
246
|
-
|
|
247
|
-
self.loss_D_real_1 = self.criterionGAN_BCE(pred_real_1, True)
|
|
248
|
-
self.loss_D_real_2 = self.criterionGAN_BCE(pred_real_2, True)
|
|
249
|
-
self.loss_D_real_3 = self.criterionGAN_BCE(pred_real_3, True)
|
|
250
|
-
self.loss_D_real_4 = self.criterionGAN_BCE(pred_real_4, True)
|
|
251
|
-
self.loss_D_real_5 = self.criterionGAN_lsgan(pred_real_5, True)
|
|
227
|
+
|
|
228
|
+
for i,model_name in enumerate(self.model_names_d):
|
|
229
|
+
fake_AB = torch.cat((self.real_A, getattr(self,f'fake_B_{i+1}')), 1)
|
|
230
|
+
pred_fake = getattr(self,f'net{model_name}')(fake_AB.detach())
|
|
231
|
+
setattr(self,f'loss_D_fake_{i+1}',self.criterionGAN_BCE(pred_fake, False))
|
|
232
|
+
|
|
233
|
+
l_pred_fake_seg = []
|
|
234
|
+
for i,model_name in enumerate(self.model_names_ds):
|
|
235
|
+
if i == 0:
|
|
236
|
+
fake_AB_seg_i = torch.cat((self.real_A, getattr(self,f'fake_B_{self.mod_id_seg}')), 1)
|
|
237
|
+
else:
|
|
238
|
+
fake_AB_seg_i = torch.cat((getattr(self,f'real_B_{i}'), getattr(self,f'fake_B_{self.mod_id_seg}')), 1)
|
|
239
|
+
|
|
240
|
+
pred_fake_seg_i = getattr(self,f'net{model_name}')(fake_AB_seg_i.detach())
|
|
241
|
+
l_pred_fake_seg.append(torch.mul(pred_fake_seg_i, self.seg_weights[i]))
|
|
242
|
+
pred_fake_seg = torch.stack(l_pred_fake_seg).sum(dim=0)
|
|
243
|
+
|
|
244
|
+
setattr(self,f'loss_D_fake_{self.mod_id_seg}',self.criterionGAN_lsgan(pred_fake_seg, False))
|
|
245
|
+
|
|
246
|
+
for i,model_name in enumerate(self.model_names_d):
|
|
247
|
+
real_AB = torch.cat((self.real_A, getattr(self,f'real_B_{i+1}')), 1)
|
|
248
|
+
pred_real = getattr(self,f'net{model_name}')(real_AB)
|
|
249
|
+
setattr(self,f'loss_D_real_{i+1}',self.criterionGAN_BCE(pred_real, True))
|
|
250
|
+
|
|
251
|
+
l_pred_real_seg = []
|
|
252
|
+
for i,model_name in enumerate(self.model_names_ds):
|
|
253
|
+
if i == 0:
|
|
254
|
+
real_AB_seg_i = torch.cat((self.real_A, getattr(self,f'real_B_{self.mod_id_seg}')), 1)
|
|
255
|
+
else:
|
|
256
|
+
real_AB_seg_i = torch.cat((getattr(self,f'real_B_{i}'), getattr(self,f'real_B_{self.mod_id_seg}')), 1)
|
|
257
|
+
|
|
258
|
+
pred_real_seg_i = getattr(self,f'net{model_name}')(real_AB_seg_i)
|
|
259
|
+
l_pred_real_seg.append(torch.mul(pred_real_seg_i, self.seg_weights[i]))
|
|
260
|
+
pred_real_seg = torch.stack(l_pred_real_seg).sum(dim=0)
|
|
261
|
+
|
|
262
|
+
#self.loss_D_real_5 = self.criterionGAN_lsgan(pred_real_5, True)
|
|
263
|
+
setattr(self,f'loss_D_real_{self.mod_id_seg}',self.criterionGAN_lsgan(pred_real_seg, True))
|
|
252
264
|
|
|
253
265
|
# combine losses and calculate gradients
|
|
254
|
-
self.loss_D = (
|
|
255
|
-
|
|
256
|
-
|
|
257
|
-
|
|
258
|
-
(self.loss_D_fake_5 + self.loss_D_real_5) * 0.5 * self.loss_D_weights[4]
|
|
266
|
+
self.loss_D = torch.tensor(0., device=self.device)
|
|
267
|
+
for i in range(self.opt.modalities_no):
|
|
268
|
+
self.loss_D += (getattr(self,f'loss_D_fake_{i+1}') + getattr(self,f'loss_D_real_{i+1}')) * 0.5 * self.loss_D_weights[i]
|
|
269
|
+
self.loss_D += (getattr(self,f'loss_D_fake_{self.mod_id_seg}') + getattr(self,f'loss_D_real_{self.mod_id_seg}')) * 0.5 * self.loss_D_weights[self.opt.modalities_no]
|
|
259
270
|
|
|
260
271
|
self.loss_D.backward()
|
|
261
272
|
|
|
262
273
|
def backward_G(self):
|
|
263
274
|
"""Calculate GAN and L1 loss for the generator"""
|
|
264
|
-
|
|
265
|
-
|
|
266
|
-
|
|
267
|
-
|
|
268
|
-
|
|
269
|
-
|
|
270
|
-
|
|
271
|
-
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
280
|
-
|
|
281
|
-
|
|
282
|
-
pred_fake_5_2 = self.netD52(fake_AB_5_2)
|
|
283
|
-
pred_fake_5_3 = self.netD53(fake_AB_5_3)
|
|
284
|
-
pred_fake_5_4 = self.netD54(fake_AB_5_4)
|
|
285
|
-
pred_fake_5_5 = self.netD55(fake_AB_5_5)
|
|
286
|
-
pred_fake_5 = torch.stack(
|
|
287
|
-
[torch.mul(pred_fake_5_1, self.seg_weights[0]),
|
|
288
|
-
torch.mul(pred_fake_5_2, self.seg_weights[1]),
|
|
289
|
-
torch.mul(pred_fake_5_3, self.seg_weights[2]),
|
|
290
|
-
torch.mul(pred_fake_5_4, self.seg_weights[3]),
|
|
291
|
-
torch.mul(pred_fake_5_5, self.seg_weights[4])]).sum(dim=0)
|
|
292
|
-
|
|
293
|
-
self.loss_G_GAN_1 = self.criterionGAN_BCE(pred_fake_1, True)
|
|
294
|
-
self.loss_G_GAN_2 = self.criterionGAN_BCE(pred_fake_2, True)
|
|
295
|
-
self.loss_G_GAN_3 = self.criterionGAN_BCE(pred_fake_3, True)
|
|
296
|
-
self.loss_G_GAN_4 = self.criterionGAN_BCE(pred_fake_4, True)
|
|
297
|
-
self.loss_G_GAN_5 = self.criterionGAN_lsgan(pred_fake_5, True)
|
|
275
|
+
|
|
276
|
+
for i,model_name in enumerate(self.model_names_d):
|
|
277
|
+
fake_AB = torch.cat((self.real_A, getattr(self,f'fake_B_{i+1}')), 1)
|
|
278
|
+
pred_fake = getattr(self,f'net{model_name}')(fake_AB)
|
|
279
|
+
setattr(self,f'loss_G_GAN_{i+1}',self.criterionGAN_BCE(pred_fake, True))
|
|
280
|
+
|
|
281
|
+
l_pred_fake_seg = []
|
|
282
|
+
for i,model_name in enumerate(self.model_names_ds):
|
|
283
|
+
if i == 0:
|
|
284
|
+
fake_AB_seg_i = torch.cat((self.real_A, getattr(self,f'fake_B_{self.mod_id_seg}')), 1)
|
|
285
|
+
else:
|
|
286
|
+
fake_AB_seg_i = torch.cat((getattr(self,f'real_B_{i}'), getattr(self,f'fake_B_{self.mod_id_seg}')), 1)
|
|
287
|
+
|
|
288
|
+
pred_fake_seg_i = getattr(self,f'net{model_name}')(fake_AB_seg_i)
|
|
289
|
+
l_pred_fake_seg.append(torch.mul(pred_fake_seg_i, self.seg_weights[i]))
|
|
290
|
+
pred_fake_seg = torch.stack(l_pred_fake_seg).sum(dim=0)
|
|
291
|
+
|
|
292
|
+
setattr(self,f'loss_G_GAN_{self.mod_id_seg}',self.criterionGAN_lsgan(pred_fake_seg, True))
|
|
298
293
|
|
|
299
294
|
# Second, G(A) = B
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
self.
|
|
303
|
-
self.loss_G_L1_4 = self.criterionSmoothL1(self.fake_B_4, self.real_B_4) * self.opt.lambda_L1
|
|
304
|
-
self.loss_G_L1_5 = self.criterionSmoothL1(self.fake_B_5, self.real_B_5) * self.opt.lambda_L1
|
|
305
|
-
|
|
306
|
-
self.loss_G_VGG_1 = self.criterionVGG(self.fake_B_1, self.real_B_1) * self.opt.lambda_feat
|
|
307
|
-
self.loss_G_VGG_2 = self.criterionVGG(self.fake_B_2, self.real_B_2) * self.opt.lambda_feat
|
|
308
|
-
self.loss_G_VGG_3 = self.criterionVGG(self.fake_B_3, self.real_B_3) * self.opt.lambda_feat
|
|
309
|
-
self.loss_G_VGG_4 = self.criterionVGG(self.fake_B_4, self.real_B_4) * self.opt.lambda_feat
|
|
295
|
+
for i in range(self.opt.modalities_no):
|
|
296
|
+
setattr(self,f'loss_G_L1_{i+1}',self.criterionSmoothL1(getattr(self,f'fake_B_{i+1}'), getattr(self,f'real_B_{i+1}')) * self.opt.lambda_L1)
|
|
297
|
+
setattr(self,f'loss_G_L1_{self.mod_id_seg}',self.criterionSmoothL1(getattr(self,f'fake_B_{self.mod_id_seg}'), getattr(self,f'real_B_{self.mod_id_seg}')) * self.opt.lambda_L1)
|
|
310
298
|
|
|
299
|
+
for i in range(self.opt.modalities_no):
|
|
300
|
+
setattr(self,f'loss_G_VGG_{i+1}',self.criterionVGG(getattr(self,f'fake_B_{i+1}'), getattr(self,f'real_B_{i+1}')) * self.opt.lambda_feat)
|
|
301
|
+
setattr(self,f'loss_G_VGG_{self.mod_id_seg}',self.criterionVGG(getattr(self,f'fake_B_{self.mod_id_seg}'), getattr(self,f'real_B_{self.mod_id_seg}')) * self.opt.lambda_feat)
|
|
311
302
|
|
|
312
303
|
# .view(1,1,-1) reshapes the input (batch_size, 3, 512, 512) to (batch_size, 1, 3*512*512)
|
|
313
304
|
# softmax/log-softmax is then applied on the concatenated vector of size (1, 3*512*512)
|
|
314
305
|
# this normalizes the pixel values across all 3 RGB channels
|
|
315
306
|
# the resulting vectors are then used to compute KL divergence loss
|
|
316
|
-
self.loss_G_KLDiv_1 = self.criterionKLDiv(self.logsoftmax(self.fake_B_1.view(1,1,-1)), self.softmax(self.fake_B_1_teacher.view(1,1,-1)))
|
|
317
|
-
self.loss_G_KLDiv_2 = self.criterionKLDiv(self.logsoftmax(self.fake_B_2.view(1,1,-1)), self.softmax(self.fake_B_2_teacher.view(1,1,-1)))
|
|
318
|
-
self.loss_G_KLDiv_3 = self.criterionKLDiv(self.logsoftmax(self.fake_B_3.view(1,1,-1)), self.softmax(self.fake_B_3_teacher.view(1,1,-1)))
|
|
319
|
-
self.loss_G_KLDiv_4 = self.criterionKLDiv(self.logsoftmax(self.fake_B_4.view(1,1,-1)), self.softmax(self.fake_B_4_teacher.view(1,1,-1)))
|
|
320
|
-
self.loss_G_KLDiv_5 = self.criterionKLDiv(self.logsoftmax(self.fake_B_5.view(1,1,-1)), self.softmax(self.fake_B_5_teacher.view(1,1,-1)))
|
|
321
|
-
|
|
322
|
-
|
|
323
|
-
|
|
324
|
-
self.
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
self.
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
|
|
307
|
+
# self.loss_G_KLDiv_1 = self.criterionKLDiv(self.logsoftmax(self.fake_B_1.view(1,1,-1)), self.softmax(self.fake_B_1_teacher.view(1,1,-1)))
|
|
308
|
+
# self.loss_G_KLDiv_2 = self.criterionKLDiv(self.logsoftmax(self.fake_B_2.view(1,1,-1)), self.softmax(self.fake_B_2_teacher.view(1,1,-1)))
|
|
309
|
+
# self.loss_G_KLDiv_3 = self.criterionKLDiv(self.logsoftmax(self.fake_B_3.view(1,1,-1)), self.softmax(self.fake_B_3_teacher.view(1,1,-1)))
|
|
310
|
+
# self.loss_G_KLDiv_4 = self.criterionKLDiv(self.logsoftmax(self.fake_B_4.view(1,1,-1)), self.softmax(self.fake_B_4_teacher.view(1,1,-1)))
|
|
311
|
+
# self.loss_G_KLDiv_5 = self.criterionKLDiv(self.logsoftmax(self.fake_B_5.view(1,1,-1)), self.softmax(self.fake_B_5_teacher.view(1,1,-1)))
|
|
312
|
+
|
|
313
|
+
for i in range(self.opt.modalities_no):
|
|
314
|
+
setattr(self,f'loss_G_KLDiv_{i+1}',self.criterionKLDiv(self.logsoftmax(getattr(self,f'fake_B_{i+1}').view(1,1,-1)), self.softmax(getattr(self,f'fake_B_{i+1}_teacher').view(1,1,-1))))
|
|
315
|
+
setattr(self,f'loss_G_KLDiv_{self.mod_id_seg}',self.criterionKLDiv(self.logsoftmax(getattr(self,f'fake_B_{self.mod_id_seg}').view(1,1,-1)), self.softmax(getattr(self,f'fake_B_{self.mod_id_seg}_teacher').view(1,1,-1))))
|
|
316
|
+
|
|
317
|
+
# self.loss_G_KLDiv_5_1 = self.criterionKLDiv(self.logsoftmax(self.fake_B_5_1.view(1,1,-1)), self.softmax(self.fake_B_5_1_teacher.view(1,1,-1)))
|
|
318
|
+
# self.loss_G_KLDiv_5_2 = self.criterionKLDiv(self.logsoftmax(self.fake_B_5_2.view(1,1,-1)), self.softmax(self.fake_B_5_2_teacher.view(1,1,-1)))
|
|
319
|
+
# self.loss_G_KLDiv_5_3 = self.criterionKLDiv(self.logsoftmax(self.fake_B_5_3.view(1,1,-1)), self.softmax(self.fake_B_5_3_teacher.view(1,1,-1)))
|
|
320
|
+
# self.loss_G_KLDiv_5_4 = self.criterionKLDiv(self.logsoftmax(self.fake_B_5_4.view(1,1,-1)), self.softmax(self.fake_B_5_4_teacher.view(1,1,-1)))
|
|
321
|
+
# self.loss_G_KLDiv_5_5 = self.criterionKLDiv(self.logsoftmax(self.fake_B_5_5.view(1,1,-1)), self.softmax(self.fake_B_5_5_teacher.view(1,1,-1)))
|
|
322
|
+
|
|
323
|
+
for i in range(self.opt.modalities_no+1):
|
|
324
|
+
setattr(self,f'loss_G_KLDiv_{self.mod_id_seg}{i}',self.criterionKLDiv(self.logsoftmax(getattr(self,f'fake_B_{self.mod_id_seg}_{i}').view(1,1,-1)), self.softmax(getattr(self,f'fake_B_{self.mod_id_seg}_{i}_teacher').view(1,1,-1))))
|
|
325
|
+
#setattr(self,f'loss_G_KLDiv_{self.mod_id_seg}{self.opt.modalities_no+1}',self.criterionKLDiv(self.logsoftmax(getattr(self,f'fake_B_{self.mod_id_seg}_{self.opt.modalities_no+1}').view(1,1,-1)), self.softmax(getattr(self,f'fake_B_{self.mod_id_seg}_teacher').view(1,1,-1))))
|
|
326
|
+
|
|
327
|
+
|
|
328
|
+
# self.loss_G = (self.loss_G_GAN_1 + self.loss_G_L1_1 + self.loss_G_VGG_1) * self.loss_G_weights[0] + \
|
|
329
|
+
# (self.loss_G_GAN_2 + self.loss_G_L1_2 + self.loss_G_VGG_2) * self.loss_G_weights[1] + \
|
|
330
|
+
# (self.loss_G_GAN_3 + self.loss_G_L1_3 + self.loss_G_VGG_3) * self.loss_G_weights[2] + \
|
|
331
|
+
# (self.loss_G_GAN_4 + self.loss_G_L1_4 + self.loss_G_VGG_4) * self.loss_G_weights[3] + \
|
|
332
|
+
# (self.loss_G_GAN_5 + self.loss_G_L1_5) * self.loss_G_weights[4] + \
|
|
333
|
+
# (self.loss_G_KLDiv_1 + self.loss_G_KLDiv_2 + self.loss_G_KLDiv_3 + self.loss_G_KLDiv_4 + \
|
|
334
|
+
# self.loss_G_KLDiv_5 + self.loss_G_KLDiv_5_1 + self.loss_G_KLDiv_5_2 + self.loss_G_KLDiv_5_3 + \
|
|
335
|
+
# self.loss_G_KLDiv_5_4 + self.loss_G_KLDiv_5_5) * 10
|
|
336
|
+
|
|
337
|
+
self.loss_G = torch.tensor(0., device=self.device)
|
|
338
|
+
for i in range(self.opt.modalities_no):
|
|
339
|
+
self.loss_G += (getattr(self,f'loss_G_GAN_{i+1}') + getattr(self,f'loss_G_L1_{i+1}') + getattr(self,f'loss_G_VGG_{i+1}')) * self.loss_G_weights[i]
|
|
340
|
+
self.loss_G += (getattr(self,f'loss_G_GAN_{self.mod_id_seg}') + getattr(self,f'loss_G_L1_{self.mod_id_seg}')) * self.loss_G_weights[i]
|
|
341
|
+
|
|
342
|
+
factor_KLDiv = 10
|
|
343
|
+
for i in range(self.opt.modalities_no):
|
|
344
|
+
self.loss_G += (getattr(self,f'loss_G_KLDiv_{i+1}') + getattr(self,f'loss_G_KLDiv_{self.mod_id_seg}{i+1}')) * factor_KLDiv
|
|
345
|
+
self.loss_G += getattr(self,f'loss_G_KLDiv_{self.mod_id_seg}') * factor_KLDiv
|
|
346
|
+
if self.input_id == '0':
|
|
347
|
+
self.loss_G += getattr(self,f'loss_G_KLDiv_{self.mod_id_seg}0') * factor_KLDiv
|
|
348
|
+
else:
|
|
349
|
+
self.loss_G += getattr(self,f'loss_G_KLDiv_{self.mod_id_seg}{self.opt.modalities_no+1}') * factor_KLDiv
|
|
350
|
+
|
|
335
351
|
|
|
336
352
|
# combine loss and calculate gradients
|
|
337
353
|
# self.loss_G = (self.loss_G_GAN_1 + self.loss_G_L1_1) * self.loss_G_weights[0] + \
|
|
@@ -344,30 +360,16 @@ class DeepLIIFKDModel(BaseModel):
|
|
|
344
360
|
def optimize_parameters(self):
|
|
345
361
|
self.forward() # compute fake images: G(A)
|
|
346
362
|
# update D
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
self.set_requires_grad(self.netD3, True) # enable backprop for D3
|
|
350
|
-
self.set_requires_grad(self.netD4, True) # enable backprop for D4
|
|
351
|
-
self.set_requires_grad(self.netD51, True) # enable backprop for D51
|
|
352
|
-
self.set_requires_grad(self.netD52, True) # enable backprop for D52
|
|
353
|
-
self.set_requires_grad(self.netD53, True) # enable backprop for D53
|
|
354
|
-
self.set_requires_grad(self.netD54, True) # enable backprop for D54
|
|
355
|
-
self.set_requires_grad(self.netD55, True) # enable backprop for D54
|
|
363
|
+
for model_name in self.model_names_d + self.model_names_ds:
|
|
364
|
+
self.set_requires_grad(getattr(self,f'net{model_name}'), True)
|
|
356
365
|
|
|
357
366
|
self.optimizer_D.zero_grad() # set D's gradients to zero
|
|
358
367
|
self.backward_D() # calculate gradients for D
|
|
359
368
|
self.optimizer_D.step() # update D's weights
|
|
360
369
|
|
|
361
370
|
# update G
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
self.set_requires_grad(self.netD3, False) # D3 requires no gradients when optimizing G3
|
|
365
|
-
self.set_requires_grad(self.netD4, False) # D4 requires no gradients when optimizing G4
|
|
366
|
-
self.set_requires_grad(self.netD51, False) # D51 requires no gradients when optimizing G51
|
|
367
|
-
self.set_requires_grad(self.netD52, False) # D52 requires no gradients when optimizing G52
|
|
368
|
-
self.set_requires_grad(self.netD53, False) # D53 requires no gradients when optimizing G53
|
|
369
|
-
self.set_requires_grad(self.netD54, False) # D54 requires no gradients when optimizing G54
|
|
370
|
-
self.set_requires_grad(self.netD55, False) # D54 requires no gradients when optimizing G54
|
|
371
|
+
for model_name in self.model_names_d + self.model_names_ds:
|
|
372
|
+
self.set_requires_grad(getattr(self,f'net{model_name}'), False)
|
|
371
373
|
|
|
372
374
|
self.optimizer_G.zero_grad() # set G's gradients to zero
|
|
373
375
|
self.backward_G() # calculate gradients for G
|
|
@@ -380,29 +382,15 @@ class DeepLIIFKDModel(BaseModel):
|
|
|
380
382
|
|
|
381
383
|
self.forward() # compute fake images: G(A)
|
|
382
384
|
# update D
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
self.set_requires_grad(self.netD3, True) # enable backprop for D3
|
|
386
|
-
self.set_requires_grad(self.netD4, True) # enable backprop for D4
|
|
387
|
-
self.set_requires_grad(self.netD51, True) # enable backprop for D51
|
|
388
|
-
self.set_requires_grad(self.netD52, True) # enable backprop for D52
|
|
389
|
-
self.set_requires_grad(self.netD53, True) # enable backprop for D53
|
|
390
|
-
self.set_requires_grad(self.netD54, True) # enable backprop for D54
|
|
391
|
-
self.set_requires_grad(self.netD55, True) # enable backprop for D54
|
|
385
|
+
for model_name in self.model_names_d + self.model_names_ds:
|
|
386
|
+
self.set_requires_grad(getattr(self,f'net{model_name}'), True)
|
|
392
387
|
|
|
393
388
|
self.optimizer_D.zero_grad() # set D's gradients to zero
|
|
394
389
|
self.backward_D() # calculate gradients for D
|
|
395
390
|
|
|
396
391
|
# update G
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
self.set_requires_grad(self.netD3, False) # D3 requires no gradients when optimizing G3
|
|
400
|
-
self.set_requires_grad(self.netD4, False) # D4 requires no gradients when optimizing G4
|
|
401
|
-
self.set_requires_grad(self.netD51, False) # D51 requires no gradients when optimizing G51
|
|
402
|
-
self.set_requires_grad(self.netD52, False) # D52 requires no gradients when optimizing G52
|
|
403
|
-
self.set_requires_grad(self.netD53, False) # D53 requires no gradients when optimizing G53
|
|
404
|
-
self.set_requires_grad(self.netD54, False) # D54 requires no gradients when optimizing G54
|
|
405
|
-
self.set_requires_grad(self.netD55, False) # D54 requires no gradients when optimizing G54
|
|
392
|
+
for model_name in self.model_names_d + self.model_names_ds:
|
|
393
|
+
self.set_requires_grad(getattr(self,f'net{model_name}'), False)
|
|
406
394
|
|
|
407
395
|
self.optimizer_G.zero_grad() # set G's gradients to zero
|
|
408
396
|
self.backward_G() # calculate gradients for G
|