deepliif 1.2.2__py3-none-any.whl → 1.2.4__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli.py +79 -24
- deepliif/data/base_dataset.py +2 -0
- deepliif/models/DeepLIIFKD_model.py +243 -255
- deepliif/models/DeepLIIF_model.py +344 -235
- deepliif/models/__init__.py +194 -103
- deepliif/models/base_model.py +7 -2
- deepliif/options/__init__.py +40 -8
- deepliif/postprocessing.py +1 -1
- deepliif/util/__init__.py +98 -1
- deepliif/util/util.py +85 -0
- deepliif/util/visualizer.py +2 -2
- {deepliif-1.2.2.dist-info → deepliif-1.2.4.dist-info}/METADATA +3 -3
- {deepliif-1.2.2.dist-info → deepliif-1.2.4.dist-info}/RECORD +17 -17
- {deepliif-1.2.2.dist-info → deepliif-1.2.4.dist-info}/LICENSE.md +0 -0
- {deepliif-1.2.2.dist-info → deepliif-1.2.4.dist-info}/WHEEL +0 -0
- {deepliif-1.2.2.dist-info → deepliif-1.2.4.dist-info}/entry_points.txt +0 -0
- {deepliif-1.2.2.dist-info → deepliif-1.2.4.dist-info}/top_level.txt +0 -0
deepliif/models/DeepLIIF_model.py

@@ -2,7 +2,8 @@ import torch
 from .base_model import BaseModel
 from . import networks
 from .networks import get_optimizer
-
+import os
+from ..util.util import get_input_id, init_input_and_mod_id
 
 class DeepLIIFModel(BaseModel):
     """ This class implements the DeepLIIF model, for learning a mapping from input images to modalities given paired data."""
@@ -20,6 +21,8 @@ class DeepLIIFModel(BaseModel):
         self.seg_weights = opt.seg_weights
         self.loss_G_weights = opt.loss_G_weights
         self.loss_D_weights = opt.loss_D_weights
+        self.mod_id_seg, self.input_id = init_input_and_mod_id(opt) # creates self.input_id, self.mod_id_seg
+        print(f'Initializing model with segmentation modality id {self.mod_id_seg}, input id {self.input_id}')
 
         if not opt.is_train:
             self.gpu_ids = [] # avoid the models being loaded as DP
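Everything below keys off the two IDs computed here: `mod_id_seg` names the segmentation modality (`'5'` with the default four image modalities), and `input_id` records whether the segmentation branches use the new 0-based naming (`G50`, `D50`, …) or the 1-based naming of older checkpoints (`G51`, …). The helper itself lives in `deepliif/util/util.py` and is not shown in this section; the sketch below is a hypothetical reading of its contract, inferred only from its call sites in this file:

```python
import os

# Hypothetical sketch only -- NOT the packaged helper, which ships in
# deepliif/util/util.py. Inferred from call sites: mod_id_seg names the
# segmentation modality; input_id selects new 0-based vs old 1-based naming.
def init_input_and_mod_id(opt):
    mod_id_seg = str(opt.modalities_no + 1)  # e.g. '5' with the default 4 modalities
    # Assumption: new-style checkpoints contain a G{mod_id_seg}0 network file,
    # old-style ones start the seg branches at G{mod_id_seg}1.
    new_style = os.path.exists(os.path.join(opt.checkpoints_dir, opt.name,
                                            f'latest_net_G{mod_id_seg}0.pth'))
    return mod_id_seg, '0' if new_style else '1'
```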
@@ -29,76 +32,77 @@ class DeepLIIFModel(BaseModel):
         self.loss_names = []
         self.visual_names = ['real_A']
         # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
-        for i in range(
-            self.loss_names.extend(['G_GAN_'
-            self.visual_names.extend(['fake_B_
+        for i in range(self.opt.modalities_no):
+            self.loss_names.extend([f'G_GAN_{i+1}', f'G_L1_{i+1}', f'D_real_{i+1}', f'D_fake_{i+1}'])
+            self.visual_names.extend([f'fake_B_{i+1}', f'real_B_{i+1}'])
+        self.loss_names.extend([f'G_GAN_{self.mod_id_seg}',f'G_L1_{self.mod_id_seg}',f'D_real_{self.mod_id_seg}',f'D_fake_{self.mod_id_seg}'])
+
+        for i in range(self.opt.modalities_no+1):
+            self.visual_names.extend([f'fake_B_{self.mod_id_seg}{i}']) # 0 is used for the base input mod
+        self.visual_names.extend([f'fake_B_{self.mod_id_seg}', f'real_B_{self.mod_id_seg}'])
 
         # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
         # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
+        self.model_names_g = []
+        self.model_names_gs = []
         if self.is_train:
             self.model_names = []
             for i in range(1, self.opt.modalities_no + 1):
-                self.model_names.extend(['G'
-
-
-
+                self.model_names.extend([f'G{i}', f'D{i}'])
+                self.model_names_g.append(f'G{i}')
+
+            for i in range(self.opt.modalities_no + 1): # 0 is used for the base input mod
+                if self.input_id == '0':
+                    self.model_names.extend([f'G{self.mod_id_seg}{i}', f'D{self.mod_id_seg}{i}'])
+                    self.model_names_gs.append(f'G{self.mod_id_seg}{i}')
+                else:
+                    self.model_names.extend([f'G{self.mod_id_seg}{i+1}', f'D{self.mod_id_seg}{i+1}'])
+                    self.model_names_gs.append(f'G{self.mod_id_seg}{i+1}')
         else: # during test time, only load G
             self.model_names = []
             for i in range(1, self.opt.modalities_no + 1):
-                self.model_names.extend(['G'
-
-
-
+                self.model_names.extend([f'G{i}'])
+                self.model_names_g.append(f'G{i}')
+
+            #input_id = get_input_id(os.path.join(opt.checkpoints_dir, opt.name))
+            if self.input_id == '0':
+                for i in range(self.opt.modalities_no + 1): # 0 is used for the base input mod
+                    self.model_names.extend([f'G{self.mod_id_seg}{i}'])
+                    self.model_names_gs.append(f'G{self.mod_id_seg}{i}')
+            else:
+                for i in range(self.opt.modalities_no + 1): # old setting, 1 is used for the base input mod
+                    self.model_names.extend([f'G{self.mod_id_seg}{i+1}'])
+                    self.model_names_gs.append(f'G{self.mod_id_seg}{i+1}')
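Concretely, with the default DeepLIIF configuration (four image modalities, so the segmentation modality id is `5`), the loops above yield `G1…G4` plus `G50…G54` under the new scheme, or `G51…G55` for old checkpoints. A standalone replay of the naming logic (values assumed for illustration):

```python
modalities_no = 4   # default DeepLIIF configuration (assumed for illustration)
mod_id_seg = '5'    # segmentation modality id
input_id = '0'      # new naming scheme; '1' reproduces the old G51..G55

model_names_g = [f'G{i}' for i in range(1, modalities_no + 1)]  # ['G1'..'G4']
if input_id == '0':
    model_names_gs = [f'G{mod_id_seg}{i}' for i in range(modalities_no + 1)]    # ['G50'..'G54']
else:
    model_names_gs = [f'G{mod_id_seg}{i+1}' for i in range(modalities_no + 1)]  # ['G51'..'G55']
print(model_names_g, model_names_gs)
```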
 
         # define networks (both generator and discriminator)
         if isinstance(opt.netG, str):
-            opt.netG = [opt.netG] *
+            opt.netG = [opt.netG] * self.opt.modalities_no
         if isinstance(opt.net_gs, str):
-            opt.net_gs = [opt.net_gs]*
+            opt.net_gs = [opt.net_gs] * (self.opt.modalities_no + 1) # +1 for base input mod
 
-        self.netG1 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG[0], opt.norm,
-                                       not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.padding)
-        self.netG2 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG[1], opt.norm,
-                                       not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.padding)
-        self.netG3 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG[2], opt.norm,
-                                       not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.padding)
-        self.netG4 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG[3], opt.norm,
-                                       not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.padding)
+
+        for i,model_name in enumerate(self.model_names_g):
+            setattr(self,f'net{model_name}',networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG[i], opt.norm,
+                                                              not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.padding))
 
         # DeepLIIF model currently uses one gs arch because there is only one explicit seg mod output
-        self.netG51 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[0], opt.norm,
-                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
-        self.netG52 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[1], opt.norm,
-                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
-        self.netG53 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[2], opt.norm,
-                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
-        self.netG54 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[3], opt.norm,
-                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
-        self.netG55 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[4], opt.norm,
-                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
-
+        for i,model_name in enumerate(self.model_names_gs):
+            setattr(self,f'net{model_name}',networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[i], opt.norm,
+                                                              not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids))
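From here on the file swaps nine hand-written attributes (`netG1…netG4`, `netG51…netG55`) for names computed at runtime, registered with `setattr` and read back with `getattr`. A minimal, self-contained illustration of that pattern with toy modules (not the DeepLIIF networks):

```python
import torch.nn as nn

class Bundle:
    def __init__(self, names):
        # register one module per computed name, e.g. netG1..netG4
        for name in names:
            setattr(self, f'net{name}', nn.Conv2d(3, 3, 1))

    def nets(self, names):
        # fetch them back by the same computed names
        return [getattr(self, f'net{name}') for name in names]

b = Bundle([f'G{i}' for i in range(1, 5)])
print(len(b.nets([f'G{i}' for i in range(1, 5)])))  # 4
```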
 
         if self.is_train: # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc
-            self.netD1 = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
-                                           opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
-            self.netD2 = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
-                                           opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
-            self.netD3 = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
-                                           opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
-            self.netD4 = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
-                                           opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
-
-            self.netD51 = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
-                                            opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
-            self.netD52 = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
-                                            opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
-            self.netD53 = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
-                                            opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
-            self.netD54 = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
-                                            opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
-            self.netD55 = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
-                                            opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
+            self.model_names_d = [f'D{i+1}' for i in range(self.opt.modalities_no)]
+            if self.input_id == '0':
+                self.model_names_ds = [f'D{self.mod_id_seg}{i}' for i in range(self.opt.modalities_no+1)]
+            else:
+                self.model_names_ds = [f'D{self.mod_id_seg}{i+1}' for i in range(self.opt.modalities_no+1)]
+            for model_name in self.model_names_d:
+                setattr(self,f'net{model_name}',networks.define_D(opt.input_nc+opt.output_nc , opt.ndf, opt.netD,
+                                                                  opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids))
+
+            for model_name in self.model_names_ds:
+                setattr(self,f'net{model_name}',networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
+                                                                  opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids))
 
         if self.is_train:
             # define loss functions
@@ -107,14 +111,22 @@ class DeepLIIFModel(BaseModel):
             self.criterionSmoothL1 = torch.nn.SmoothL1Loss()
 
             # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
-            params = list(self.netG1.parameters()) + list(self.netG2.parameters()) + list(self.netG3.parameters()) + list(self.netG4.parameters()) + list(self.netG51.parameters()) + list(self.netG52.parameters()) + list(self.netG53.parameters()) + list(self.netG54.parameters()) + list(self.netG55.parameters())
+            #params = list(self.netG1.parameters()) + list(self.netG2.parameters()) + list(self.netG3.parameters()) + list(self.netG4.parameters()) + list(self.netG51.parameters()) + list(self.netG52.parameters()) + list(self.netG53.parameters()) + list(self.netG54.parameters()) + list(self.netG55.parameters())
+            params = []
+            for model_name in self.model_names_g + self.model_names_gs:
+                params += list(getattr(self,f'net{model_name}').parameters())
+
            try:
                 self.optimizer_G = get_optimizer(opt.optimizer)(params, lr=opt.lr_g, betas=(opt.beta1, 0.999))
             except:
                 print(f'betas are not used for optimizer torch.optim.{opt.optimizer} in generators')
                 self.optimizer_G = get_optimizer(opt.optimizer)(params, lr=opt.lr_g)
 
-            params = list(self.netD1.parameters()) + list(self.netD2.parameters()) + list(self.netD3.parameters()) + list(self.netD4.parameters()) + list(self.netD51.parameters()) + list(self.netD52.parameters()) + list(self.netD53.parameters()) + list(self.netD54.parameters()) + list(self.netD55.parameters())
+            #params = list(self.netD1.parameters()) + list(self.netD2.parameters()) + list(self.netD3.parameters()) + list(self.netD4.parameters()) + list(self.netD51.parameters()) + list(self.netD52.parameters()) + list(self.netD53.parameters()) + list(self.netD54.parameters()) + list(self.netD55.parameters())
+            params = []
+            for model_name in self.model_names_d + self.model_names_ds:
+                params += list(getattr(self,f'net{model_name}').parameters())
+
             try:
                 self.optimizer_D = get_optimizer(opt.optimizer)(params, lr=opt.lr_d, betas=(opt.beta1, 0.999))
             except:
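Feeding the parameters of a variable number of networks to a single optimizer is standard PyTorch; the list concatenation above could equally be an `itertools.chain`, as in this small sketch (toy modules and assumed hyperparameters):

```python
import itertools
import torch
import torch.nn as nn

nets = [nn.Linear(4, 4) for _ in range(3)]  # stand-ins for netG1.., netG50..
params = itertools.chain.from_iterable(net.parameters() for net in nets)
optimizer_G = torch.optim.Adam(params, lr=2e-4, betas=(0.5, 0.999))
print(len(optimizer_G.param_groups[0]['params']))  # 6 tensors: 3 x (weight, bias)
```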
@@ -136,167 +148,251 @@ class DeepLIIFModel(BaseModel):
         self.real_A = input['A'].to(self.device)
 
         self.real_B_array = input['B']
-        self.real_B_1 = self.real_B_array[0].to(self.device)
-        self.real_B_2 = self.real_B_array[1].to(self.device)
-        self.real_B_3 = self.real_B_array[2].to(self.device)
-        self.real_B_4 = self.real_B_array[3].to(self.device)
-        self.real_B_5 = self.real_B_array[4].to(self.device)
+        for i in range(self.opt.modalities_no):
+            setattr(self,f'real_B_{i+1}',self.real_B_array[i].to(self.device))
+        setattr(self,f'real_B_{self.mod_id_seg}',self.real_B_array[self.opt.modalities_no].to(self.device)) # the last one is seg
+
         self.image_paths = input['A_paths']
 
     def forward(self):
         """Run forward pass; called by both functions <optimize_parameters> and <test>."""
-        self.fake_B_1 = self.netG1(self.real_A) # Hematoxylin image generator
-        self.fake_B_2 = self.netG2(self.real_A) # mpIF DAPI image generator
-        self.fake_B_3 = self.netG3(self.real_A) # mpIF Lap2 image generator
-        self.fake_B_4 = self.netG4(self.real_A) # mpIF Ki67 image generator
-
-        self.fake_B_5_1 = self.netG51(self.real_A) # Segmentation mask generator from IHC input image
-        self.fake_B_5_2 = self.netG52(self.fake_B_1) # Segmentation mask generator from Hematoxylin input image
-        self.fake_B_5_3 = self.netG53(self.fake_B_2) # Segmentation mask generator from mpIF DAPI input image
-        self.fake_B_5_4 = self.netG54(self.fake_B_3) # Segmentation mask generator from mpIF Lap2 input image
-        self.fake_B_5_5 = self.netG55(self.fake_B_4) # Segmentation mask generator from mpIF Lap2 input image
-        self.fake_B_5 = torch.stack([torch.mul(self.fake_B_5_1, self.seg_weights[0]),
-                                     torch.mul(self.fake_B_5_2, self.seg_weights[1]),
-                                     torch.mul(self.fake_B_5_3, self.seg_weights[2]),
-                                     torch.mul(self.fake_B_5_4, self.seg_weights[3]),
-                                     torch.mul(self.fake_B_5_5, self.seg_weights[4])]).sum(dim=0)
+        # self.fake_B_1 = self.netG1(self.real_A) # Hematoxylin image generator
+        # self.fake_B_2 = self.netG2(self.real_A) # mpIF DAPI image generator
+        # self.fake_B_3 = self.netG3(self.real_A) # mpIF Lap2 image generator
+        # self.fake_B_4 = self.netG4(self.real_A) # mpIF Ki67 image generator
+
+        for i in range(self.opt.modalities_no):
+            setattr(self,f'fake_B_{i+1}',getattr(self,f'netG{i+1}')(self.real_A))
+
+        # self.fake_B_5_1 = self.netG51(self.real_A) # Segmentation mask generator from IHC input image
+        # self.fake_B_5_2 = self.netG52(self.fake_B_1) # Segmentation mask generator from Hematoxylin input image
+        # self.fake_B_5_3 = self.netG53(self.fake_B_2) # Segmentation mask generator from mpIF DAPI input image
+        # self.fake_B_5_4 = self.netG54(self.fake_B_3) # Segmentation mask generator from mpIF Lap2 input image
+        # self.fake_B_5_5 = self.netG55(self.fake_B_4) # Segmentation mask generator from mpIF Lap2 input image
+        # self.fake_B_5 = torch.stack([torch.mul(self.fake_B_5_1, self.seg_weights[0]),
+        #                              torch.mul(self.fake_B_5_2, self.seg_weights[1]),
+        #                              torch.mul(self.fake_B_5_3, self.seg_weights[2]),
+        #                              torch.mul(self.fake_B_5_4, self.seg_weights[3]),
+        #                              torch.mul(self.fake_B_5_5, self.seg_weights[4])]).sum(dim=0)
+
+        for i,model_name in enumerate(self.model_names_gs):
+            if i == 0:
+                setattr(self,f'fake_B_{self.mod_id_seg}_{i}',getattr(self,f'net{model_name}')(self.real_A))
+            else:
+                setattr(self,f'fake_B_{self.mod_id_seg}_{i}',getattr(self,f'net{model_name}')(getattr(self,f'fake_B_{i}')))
+
+        setattr(self,f'fake_B_{self.mod_id_seg}',torch.stack([torch.mul(getattr(self,f'fake_B_{self.mod_id_seg}_{i}'), self.seg_weights[i]) for i in range(self.opt.modalities_no+1)]).sum(dim=0))
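The segmentation output is a weighted sum of the per-branch masks via `torch.stack(...).sum(dim=0)`. The idiom in isolation, with toy tensors and assumed uniform weights:

```python
import torch

# Five per-branch segmentation predictions (batch 1, 3 channels, 4x4), as in forward()
branches = [torch.randn(1, 3, 4, 4) for _ in range(5)]
seg_weights = [0.2] * 5  # assumed uniform weights, for illustration only

# Weighted sum across branches: stack on a new dim, scale, collapse it again
seg = torch.stack([torch.mul(b, w) for b, w in zip(branches, seg_weights)]).sum(dim=0)
print(seg.shape)  # torch.Size([1, 3, 4, 4])
```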
 
     def backward_D(self):
         """Calculate GAN loss for the discriminators"""
-        fake_AB_1 = torch.cat((self.real_A, self.fake_B_1), 1) # Conditional GANs; feed IHC input and Hematoxtlin output to the discriminator
-        fake_AB_2 = torch.cat((self.real_A, self.fake_B_2), 1) # Conditional GANs; feed IHC input and mpIF DAPI output to the discriminator
-        fake_AB_3 = torch.cat((self.real_A, self.fake_B_3), 1) # Conditional GANs; feed IHC input and mpIF Lap2 output to the discriminator
-        fake_AB_4 = torch.cat((self.real_A, self.fake_B_4), 1) # Conditional GANs; feed IHC input and mpIF Ki67 output to the discriminator
-
-        pred_fake_1 = self.netD1(fake_AB_1.detach())
-        pred_fake_2 = self.netD2(fake_AB_2.detach())
-        pred_fake_3 = self.netD3(fake_AB_3.detach())
-        pred_fake_4 = self.netD4(fake_AB_4.detach())
-
-        self.loss_D_fake_1 = self.criterionGAN_BCE(pred_fake_1, False)
-        self.loss_D_fake_2 = self.criterionGAN_BCE(pred_fake_2, False)
-        self.loss_D_fake_3 = self.criterionGAN_BCE(pred_fake_3, False)
-        self.loss_D_fake_4 = self.criterionGAN_BCE(pred_fake_4, False)
-
-        fake_AB_5_1 = torch.cat((self.real_A, self.fake_B_5), 1) # Conditional GANs; feed IHC input and Segmentation mask output to the discriminator
-        fake_AB_5_2 = torch.cat((self.real_B_1, self.fake_B_5), 1) # Conditional GANs; feed Hematoxylin input and Segmentation mask output to the discriminator
-        fake_AB_5_3 = torch.cat((self.real_B_2, self.fake_B_5), 1) # Conditional GANs; feed mpIF DAPI input and Segmentation mask output to the discriminator
-        fake_AB_5_4 = torch.cat((self.real_B_3, self.fake_B_5), 1) # Conditional GANs; feed mpIF Lap2 input and Segmentation mask output to the discriminator
-        fake_AB_5_5 = torch.cat((self.real_B_4, self.fake_B_5), 1) # Conditional GANs; feed mpIF Lap2 input and Segmentation mask output to the discriminator
-
-        pred_fake_5_1 = self.netD51(fake_AB_5_1.detach())
-        pred_fake_5_2 = self.netD52(fake_AB_5_2.detach())
-        pred_fake_5_3 = self.netD53(fake_AB_5_3.detach())
-        pred_fake_5_4 = self.netD54(fake_AB_5_4.detach())
-        pred_fake_5_5 = self.netD55(fake_AB_5_5.detach())
-
-        pred_fake_5 = torch.stack(
-            [torch.mul(pred_fake_5_1, self.seg_weights[0]),
-             torch.mul(pred_fake_5_2, self.seg_weights[1]),
-             torch.mul(pred_fake_5_3, self.seg_weights[2]),
-             torch.mul(pred_fake_5_4, self.seg_weights[3]),
-             torch.mul(pred_fake_5_5, self.seg_weights[4])]).sum(dim=0)
-
-        self.loss_D_fake_5 = self.criterionGAN_lsgan(pred_fake_5, False)
-
-        real_AB_1 = torch.cat((self.real_A, self.real_B_1), 1)
-        real_AB_2 = torch.cat((self.real_A, self.real_B_2), 1)
-        real_AB_3 = torch.cat((self.real_A, self.real_B_3), 1)
-        real_AB_4 = torch.cat((self.real_A, self.real_B_4), 1)
-
-        pred_real_1 = self.netD1(real_AB_1)
-        pred_real_2 = self.netD2(real_AB_2)
-        pred_real_3 = self.netD3(real_AB_3)
-        pred_real_4 = self.netD4(real_AB_4)
-
-        self.loss_D_real_1 = self.criterionGAN_BCE(pred_real_1, True)
-        self.loss_D_real_2 = self.criterionGAN_BCE(pred_real_2, True)
-        self.loss_D_real_3 = self.criterionGAN_BCE(pred_real_3, True)
-        self.loss_D_real_4 = self.criterionGAN_BCE(pred_real_4, True)
-
-        real_AB_5_1 = torch.cat((self.real_A, self.real_B_5), 1)
-        real_AB_5_2 = torch.cat((self.real_B_1, self.real_B_5), 1)
-        real_AB_5_3 = torch.cat((self.real_B_2, self.real_B_5), 1)
-        real_AB_5_4 = torch.cat((self.real_B_3, self.real_B_5), 1)
-        real_AB_5_5 = torch.cat((self.real_B_4, self.real_B_5), 1)
-
-        pred_real_5_1 = self.netD51(real_AB_5_1)
-        pred_real_5_2 = self.netD52(real_AB_5_2)
-        pred_real_5_3 = self.netD53(real_AB_5_3)
-        pred_real_5_4 = self.netD54(real_AB_5_4)
-        pred_real_5_5 = self.netD55(real_AB_5_5)
-
-        pred_real_5 = torch.stack(
-            [torch.mul(pred_real_5_1, self.seg_weights[0]),
-             torch.mul(pred_real_5_2, self.seg_weights[1]),
-             torch.mul(pred_real_5_3, self.seg_weights[2]),
-             torch.mul(pred_real_5_4, self.seg_weights[3]),
-             torch.mul(pred_real_5_5, self.seg_weights[4])]).sum(dim=0)
-
-        self.loss_D_real_5 = self.criterionGAN_lsgan(pred_real_5, True)
+        # fake_AB_1 = torch.cat((self.real_A, self.fake_B_1), 1) # Conditional GANs; feed IHC input and Hematoxtlin output to the discriminator
+        # fake_AB_2 = torch.cat((self.real_A, self.fake_B_2), 1) # Conditional GANs; feed IHC input and mpIF DAPI output to the discriminator
+        # fake_AB_3 = torch.cat((self.real_A, self.fake_B_3), 1) # Conditional GANs; feed IHC input and mpIF Lap2 output to the discriminator
+        # fake_AB_4 = torch.cat((self.real_A, self.fake_B_4), 1) # Conditional GANs; feed IHC input and mpIF Ki67 output to the discriminator
+
+        # pred_fake_1 = self.netD1(fake_AB_1.detach())
+        # pred_fake_2 = self.netD2(fake_AB_2.detach())
+        # pred_fake_3 = self.netD3(fake_AB_3.detach())
+        # pred_fake_4 = self.netD4(fake_AB_4.detach())
+
+        # self.loss_D_fake_1 = self.criterionGAN_BCE(pred_fake_1, False)
+        # self.loss_D_fake_2 = self.criterionGAN_BCE(pred_fake_2, False)
+        # self.loss_D_fake_3 = self.criterionGAN_BCE(pred_fake_3, False)
+        # self.loss_D_fake_4 = self.criterionGAN_BCE(pred_fake_4, False)
+
+        for i,model_name in enumerate(self.model_names_d):
+            fake_AB = torch.cat((self.real_A, getattr(self,f'fake_B_{i+1}')), 1)
+            pred_fake = getattr(self,f'net{model_name}')(fake_AB.detach())
+            setattr(self,f'loss_D_fake_{i+1}',self.criterionGAN_BCE(pred_fake, False))
+            #setattr(self,f'fake_AB_{i+1}',torch.cat((self.real_A, getattr(self,f'fake_B_{i+1}')), 1))
+            #setattr(self,f'pred_fake_{i+1}',getattr(self,f'netD{i+1}')(getattr))
+
+
+        # fake_AB_5_1 = torch.cat((self.real_A, self.fake_B_5), 1) # Conditional GANs; feed IHC input and Segmentation mask output to the discriminator
+        # fake_AB_5_2 = torch.cat((self.real_B_1, self.fake_B_5), 1) # Conditional GANs; feed Hematoxylin input and Segmentation mask output to the discriminator
+        # fake_AB_5_3 = torch.cat((self.real_B_2, self.fake_B_5), 1) # Conditional GANs; feed mpIF DAPI input and Segmentation mask output to the discriminator
+        # fake_AB_5_4 = torch.cat((self.real_B_3, self.fake_B_5), 1) # Conditional GANs; feed mpIF Lap2 input and Segmentation mask output to the discriminator
+        # fake_AB_5_5 = torch.cat((self.real_B_4, self.fake_B_5), 1) # Conditional GANs; feed mpIF Lap2 input and Segmentation mask output to the discriminator
+        #
+        # pred_fake_5_1 = self.netD51(fake_AB_5_1.detach())
+        # pred_fake_5_2 = self.netD52(fake_AB_5_2.detach())
+        # pred_fake_5_3 = self.netD53(fake_AB_5_3.detach())
+        # pred_fake_5_4 = self.netD54(fake_AB_5_4.detach())
+        # pred_fake_5_5 = self.netD55(fake_AB_5_5.detach())
+        #
+        # pred_fake_5 = torch.stack(
+        #     [torch.mul(pred_fake_5_1, self.seg_weights[0]),
+        #      torch.mul(pred_fake_5_2, self.seg_weights[1]),
+        #      torch.mul(pred_fake_5_3, self.seg_weights[2]),
+        #      torch.mul(pred_fake_5_4, self.seg_weights[3]),
+        #      torch.mul(pred_fake_5_5, self.seg_weights[4])]).sum(dim=0)
+
+        l_pred_fake_seg = []
+        for i,model_name in enumerate(self.model_names_ds):
+            if i == 0:
+                fake_AB_seg_i = torch.cat((self.real_A, getattr(self,f'fake_B_{self.mod_id_seg}')), 1)
+            else:
+                fake_AB_seg_i = torch.cat((getattr(self,f'real_B_{i}'), getattr(self,f'fake_B_{self.mod_id_seg}')), 1)
+
+            pred_fake_seg_i = getattr(self,f'net{model_name}')(fake_AB_seg_i.detach())
+            l_pred_fake_seg.append(torch.mul(pred_fake_seg_i, self.seg_weights[i]))
+        pred_fake_seg = torch.stack(l_pred_fake_seg).sum(dim=0)
+
+        #self.loss_D_fake_5 = self.criterionGAN_lsgan(pred_fake_5, False)
+        setattr(self,f'loss_D_fake_{self.mod_id_seg}',self.criterionGAN_lsgan(pred_fake_seg, False))
+
+
+        # real_AB_1 = torch.cat((self.real_A, self.real_B_1), 1)
+        # real_AB_2 = torch.cat((self.real_A, self.real_B_2), 1)
+        # real_AB_3 = torch.cat((self.real_A, self.real_B_3), 1)
+        # real_AB_4 = torch.cat((self.real_A, self.real_B_4), 1)
+        #
+        # pred_real_1 = self.netD1(real_AB_1)
+        # pred_real_2 = self.netD2(real_AB_2)
+        # pred_real_3 = self.netD3(real_AB_3)
+        # pred_real_4 = self.netD4(real_AB_4)
+        #
+        # self.loss_D_real_1 = self.criterionGAN_BCE(pred_real_1, True)
+        # self.loss_D_real_2 = self.criterionGAN_BCE(pred_real_2, True)
+        # self.loss_D_real_3 = self.criterionGAN_BCE(pred_real_3, True)
+        # self.loss_D_real_4 = self.criterionGAN_BCE(pred_real_4, True)
+
+        for i,model_name in enumerate(self.model_names_d):
+            real_AB = torch.cat((self.real_A, getattr(self,f'real_B_{i+1}')), 1)
+            pred_real = getattr(self,f'net{model_name}')(real_AB)
+            setattr(self,f'loss_D_real_{i+1}',self.criterionGAN_BCE(pred_real, True))
+
+        # real_AB_5_1 = torch.cat((self.real_A, self.real_B_5), 1)
+        # real_AB_5_2 = torch.cat((self.real_B_1, self.real_B_5), 1)
+        # real_AB_5_3 = torch.cat((self.real_B_2, self.real_B_5), 1)
+        # real_AB_5_4 = torch.cat((self.real_B_3, self.real_B_5), 1)
+        # real_AB_5_5 = torch.cat((self.real_B_4, self.real_B_5), 1)
+        #
+        # pred_real_5_1 = self.netD51(real_AB_5_1)
+        # pred_real_5_2 = self.netD52(real_AB_5_2)
+        # pred_real_5_3 = self.netD53(real_AB_5_3)
+        # pred_real_5_4 = self.netD54(real_AB_5_4)
+        # pred_real_5_5 = self.netD55(real_AB_5_5)
+        #
+        # pred_real_5 = torch.stack(
+        #     [torch.mul(pred_real_5_1, self.seg_weights[0]),
+        #      torch.mul(pred_real_5_2, self.seg_weights[1]),
+        #      torch.mul(pred_real_5_3, self.seg_weights[2]),
+        #      torch.mul(pred_real_5_4, self.seg_weights[3]),
+        #      torch.mul(pred_real_5_5, self.seg_weights[4])]).sum(dim=0)
+
+        l_pred_real_seg = []
+        for i,model_name in enumerate(self.model_names_ds):
+            if i == 0:
+                real_AB_seg_i = torch.cat((self.real_A, getattr(self,f'real_B_{self.mod_id_seg}')), 1)
+            else:
+                real_AB_seg_i = torch.cat((getattr(self,f'real_B_{i}'), getattr(self,f'real_B_{self.mod_id_seg}')), 1)
+
+            pred_real_seg_i = getattr(self,f'net{model_name}')(real_AB_seg_i)
+            l_pred_real_seg.append(torch.mul(pred_real_seg_i, self.seg_weights[i]))
+        pred_real_seg = torch.stack(l_pred_real_seg).sum(dim=0)
+
+        #self.loss_D_real_5 = self.criterionGAN_lsgan(pred_real_5, True)
+        setattr(self,f'loss_D_real_{self.mod_id_seg}',self.criterionGAN_lsgan(pred_real_seg, True))
 
         # combine losses and calculate gradients
-        self.loss_D = (self.loss_D_fake_1 + self.loss_D_real_1) * 0.5 * self.loss_D_weights[0] + \
-                      (self.loss_D_fake_2 + self.loss_D_real_2) * 0.5 * self.loss_D_weights[1] + \
-                      (self.loss_D_fake_3 + self.loss_D_real_3) * 0.5 * self.loss_D_weights[2] + \
-                      (self.loss_D_fake_4 + self.loss_D_real_4) * 0.5 * self.loss_D_weights[3] + \
-                      (self.loss_D_fake_5 + self.loss_D_real_5) * 0.5 * self.loss_D_weights[4]
+        # self.loss_D = (self.loss_D_fake_1 + self.loss_D_real_1) * 0.5 * self.loss_D_weights[0] + \
+        #               (self.loss_D_fake_2 + self.loss_D_real_2) * 0.5 * self.loss_D_weights[1] + \
+        #               (self.loss_D_fake_3 + self.loss_D_real_3) * 0.5 * self.loss_D_weights[2] + \
+        #               (self.loss_D_fake_4 + self.loss_D_real_4) * 0.5 * self.loss_D_weights[3] + \
+        #               (self.loss_D_fake_5 + self.loss_D_real_5) * 0.5 * self.loss_D_weights[4]
+
+        self.loss_D = torch.tensor(0., device=self.device)
+        for i in range(self.opt.modalities_no):
+            self.loss_D += (getattr(self,f'loss_D_fake_{i+1}') + getattr(self,f'loss_D_real_{i+1}')) * 0.5 * self.loss_D_weights[i]
+        self.loss_D += (getattr(self,f'loss_D_fake_{self.mod_id_seg}') + getattr(self,f'loss_D_real_{self.mod_id_seg}')) * 0.5 * self.loss_D_weights[self.opt.modalities_no]
 
         self.loss_D.backward()
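Two idioms recur in `backward_D`: the conditional discriminator sees the input and the candidate output concatenated along the channel dimension, and the fake pair is `.detach()`ed so the discriminator loss does not backpropagate into the generator. A minimal sketch with toy shapes (a stand-in for the PatchGAN discriminator):

```python
import torch
import torch.nn as nn

netD = nn.Conv2d(6, 1, 4)             # stand-in discriminator over 3+3 channels
real_A = torch.randn(1, 3, 16, 16)    # conditioning input (e.g. IHC)
fake_B = torch.randn(1, 3, 16, 16, requires_grad=True)  # generator output

fake_AB = torch.cat((real_A, fake_B), 1)       # channel-wise pair: (1, 6, 16, 16)
pred_fake = netD(fake_AB.detach())             # detach: D's loss leaves G untouched
print(pred_fake.shape, fake_AB.requires_grad)  # torch.Size([1, 1, 13, 13]) True
```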
 
     def backward_G(self):
         """Calculate GAN and L1 loss for the generator"""
 
-        fake_AB_1 = torch.cat((self.real_A, self.fake_B_1), 1)
-        fake_AB_2 = torch.cat((self.real_A, self.fake_B_2), 1)
-        fake_AB_3 = torch.cat((self.real_A, self.fake_B_3), 1)
-        fake_AB_4 = torch.cat((self.real_A, self.fake_B_4), 1)
-
-        pred_fake_1 = self.netD1(fake_AB_1)
-        pred_fake_2 = self.netD2(fake_AB_2)
-        pred_fake_3 = self.netD3(fake_AB_3)
-        pred_fake_4 = self.netD4(fake_AB_4)
-
-        self.loss_G_GAN_1 = self.criterionGAN_BCE(pred_fake_1, True)
-        self.loss_G_GAN_2 = self.criterionGAN_BCE(pred_fake_2, True)
-        self.loss_G_GAN_3 = self.criterionGAN_BCE(pred_fake_3, True)
-        self.loss_G_GAN_4 = self.criterionGAN_BCE(pred_fake_4, True)
-
-        fake_AB_5_1 = torch.cat((self.real_A, self.fake_B_5), 1)
-        fake_AB_5_2 = torch.cat((self.real_B_1, self.fake_B_5), 1)
-        fake_AB_5_3 = torch.cat((self.real_B_2, self.fake_B_5), 1)
-        fake_AB_5_4 = torch.cat((self.real_B_3, self.fake_B_5), 1)
-        fake_AB_5_5 = torch.cat((self.real_B_4, self.fake_B_5), 1)
-
-        pred_fake_5_1 = self.netD51(fake_AB_5_1)
-        pred_fake_5_2 = self.netD52(fake_AB_5_2)
-        pred_fake_5_3 = self.netD53(fake_AB_5_3)
-        pred_fake_5_4 = self.netD54(fake_AB_5_4)
-        pred_fake_5_5 = self.netD55(fake_AB_5_5)
-        pred_fake_5 = torch.stack(
-            [torch.mul(pred_fake_5_1, self.seg_weights[0]),
-             torch.mul(pred_fake_5_2, self.seg_weights[1]),
-             torch.mul(pred_fake_5_3, self.seg_weights[2]),
-             torch.mul(pred_fake_5_4, self.seg_weights[3]),
-             torch.mul(pred_fake_5_5, self.seg_weights[4])]).sum(dim=0)
-        self.loss_G_GAN_5 = self.criterionGAN_lsgan(pred_fake_5, True)
+        # fake_AB_1 = torch.cat((self.real_A, self.fake_B_1), 1)
+        # fake_AB_2 = torch.cat((self.real_A, self.fake_B_2), 1)
+        # fake_AB_3 = torch.cat((self.real_A, self.fake_B_3), 1)
+        # fake_AB_4 = torch.cat((self.real_A, self.fake_B_4), 1)
+        #
+        # pred_fake_1 = self.netD1(fake_AB_1)
+        # pred_fake_2 = self.netD2(fake_AB_2)
+        # pred_fake_3 = self.netD3(fake_AB_3)
+        # pred_fake_4 = self.netD4(fake_AB_4)
+        #
+        # self.loss_G_GAN_1 = self.criterionGAN_BCE(pred_fake_1, True)
+        # self.loss_G_GAN_2 = self.criterionGAN_BCE(pred_fake_2, True)
+        # self.loss_G_GAN_3 = self.criterionGAN_BCE(pred_fake_3, True)
+        # self.loss_G_GAN_4 = self.criterionGAN_BCE(pred_fake_4, True)
+
+        for i,model_name in enumerate(self.model_names_d):
+            fake_AB = torch.cat((self.real_A, getattr(self,f'fake_B_{i+1}')), 1)
+            pred_fake = getattr(self,f'net{model_name}')(fake_AB)
+            setattr(self,f'loss_G_GAN_{i+1}',self.criterionGAN_BCE(pred_fake, True))
+
+        # fake_AB_5_1 = torch.cat((self.real_A, self.fake_B_5), 1)
+        # fake_AB_5_2 = torch.cat((self.real_B_1, self.fake_B_5), 1)
+        # fake_AB_5_3 = torch.cat((self.real_B_2, self.fake_B_5), 1)
+        # fake_AB_5_4 = torch.cat((self.real_B_3, self.fake_B_5), 1)
+        # fake_AB_5_5 = torch.cat((self.real_B_4, self.fake_B_5), 1)
+        #
+        # pred_fake_5_1 = self.netD51(fake_AB_5_1)
+        # pred_fake_5_2 = self.netD52(fake_AB_5_2)
+        # pred_fake_5_3 = self.netD53(fake_AB_5_3)
+        # pred_fake_5_4 = self.netD54(fake_AB_5_4)
+        # pred_fake_5_5 = self.netD55(fake_AB_5_5)
+        # pred_fake_5 = torch.stack(
+        #     [torch.mul(pred_fake_5_1, self.seg_weights[0]),
+        #      torch.mul(pred_fake_5_2, self.seg_weights[1]),
+        #      torch.mul(pred_fake_5_3, self.seg_weights[2]),
+        #      torch.mul(pred_fake_5_4, self.seg_weights[3]),
+        #      torch.mul(pred_fake_5_5, self.seg_weights[4])]).sum(dim=0)
+
+        l_pred_fake_seg = []
+        for i,model_name in enumerate(self.model_names_ds):
+            if i == 0:
+                fake_AB_seg_i = torch.cat((self.real_A, getattr(self,f'fake_B_{self.mod_id_seg}')), 1)
+            else:
+                fake_AB_seg_i = torch.cat((getattr(self,f'real_B_{i}'), getattr(self,f'fake_B_{self.mod_id_seg}')), 1)
+
+            pred_fake_seg_i = getattr(self,f'net{model_name}')(fake_AB_seg_i)
+            l_pred_fake_seg.append(torch.mul(pred_fake_seg_i, self.seg_weights[i]))
+        pred_fake_seg = torch.stack(l_pred_fake_seg).sum(dim=0)
+
+        # self.loss_G_GAN_5 = self.criterionGAN_lsgan(pred_fake_5, True)
+        setattr(self,f'loss_G_GAN_{self.mod_id_seg}',self.criterionGAN_lsgan(pred_fake_seg, True))
 
         # Second, G(A) = B
-        self.loss_G_L1_1 = self.criterionSmoothL1(self.fake_B_1, self.real_B_1) * self.opt.lambda_L1
-        self.loss_G_L1_2 = self.criterionSmoothL1(self.fake_B_2, self.real_B_2) * self.opt.lambda_L1
-        self.loss_G_L1_3 = self.criterionSmoothL1(self.fake_B_3, self.real_B_3) * self.opt.lambda_L1
-        self.loss_G_L1_4 = self.criterionSmoothL1(self.fake_B_4, self.real_B_4) * self.opt.lambda_L1
-        self.loss_G_L1_5 = self.criterionSmoothL1(self.fake_B_5, self.real_B_5) * self.opt.lambda_L1
-
-        self.loss_G_VGG_1 = self.criterionVGG(self.fake_B_1, self.real_B_1) * self.opt.lambda_feat
-        self.loss_G_VGG_2 = self.criterionVGG(self.fake_B_2, self.real_B_2) * self.opt.lambda_feat
-        self.loss_G_VGG_3 = self.criterionVGG(self.fake_B_3, self.real_B_3) * self.opt.lambda_feat
-        self.loss_G_VGG_4 = self.criterionVGG(self.fake_B_4, self.real_B_4) * self.opt.lambda_feat
-
-        self.loss_G = (self.loss_G_GAN_1 + self.loss_G_L1_1 + self.loss_G_VGG_1) * self.loss_G_weights[0] + \
-                      (self.loss_G_GAN_2 + self.loss_G_L1_2 + self.loss_G_VGG_2) * self.loss_G_weights[1] + \
-                      (self.loss_G_GAN_3 + self.loss_G_L1_3 + self.loss_G_VGG_3) * self.loss_G_weights[2] + \
-                      (self.loss_G_GAN_4 + self.loss_G_L1_4 + self.loss_G_VGG_4) * self.loss_G_weights[3] + \
-                      (self.loss_G_GAN_5 + self.loss_G_L1_5) * self.loss_G_weights[4]
+        # self.loss_G_L1_1 = self.criterionSmoothL1(self.fake_B_1, self.real_B_1) * self.opt.lambda_L1
+        # self.loss_G_L1_2 = self.criterionSmoothL1(self.fake_B_2, self.real_B_2) * self.opt.lambda_L1
+        # self.loss_G_L1_3 = self.criterionSmoothL1(self.fake_B_3, self.real_B_3) * self.opt.lambda_L1
+        # self.loss_G_L1_4 = self.criterionSmoothL1(self.fake_B_4, self.real_B_4) * self.opt.lambda_L1
+        # self.loss_G_L1_5 = self.criterionSmoothL1(self.fake_B_5, self.real_B_5) * self.opt.lambda_L1
+
+        for i in range(self.opt.modalities_no):
+            setattr(self,f'loss_G_L1_{i+1}',self.criterionSmoothL1(getattr(self,f'fake_B_{i+1}'), getattr(self,f'real_B_{i+1}')) * self.opt.lambda_L1)
+        setattr(self,f'loss_G_L1_{self.mod_id_seg}',self.criterionSmoothL1(getattr(self,f'fake_B_{self.mod_id_seg}'), getattr(self,f'real_B_{self.mod_id_seg}')) * self.opt.lambda_L1)
+
+        # self.loss_G_VGG_1 = self.criterionVGG(self.fake_B_1, self.real_B_1) * self.opt.lambda_feat
+        # self.loss_G_VGG_2 = self.criterionVGG(self.fake_B_2, self.real_B_2) * self.opt.lambda_feat
+        # self.loss_G_VGG_3 = self.criterionVGG(self.fake_B_3, self.real_B_3) * self.opt.lambda_feat
+        # self.loss_G_VGG_4 = self.criterionVGG(self.fake_B_4, self.real_B_4) * self.opt.lambda_feat
+        for i in range(self.opt.modalities_no):
+            setattr(self,f'loss_G_VGG_{i+1}',self.criterionVGG(getattr(self,f'fake_B_{i+1}'), getattr(self,f'real_B_{i+1}')) * self.opt.lambda_feat)
+        setattr(self,f'loss_G_VGG_{self.mod_id_seg}',self.criterionVGG(getattr(self,f'fake_B_{self.mod_id_seg}'), getattr(self,f'real_B_{self.mod_id_seg}')) * self.opt.lambda_feat)
+
+        # self.loss_G = (self.loss_G_GAN_1 + self.loss_G_L1_1 + self.loss_G_VGG_1) * self.loss_G_weights[0] + \
+        #               (self.loss_G_GAN_2 + self.loss_G_L1_2 + self.loss_G_VGG_2) * self.loss_G_weights[1] + \
+        #               (self.loss_G_GAN_3 + self.loss_G_L1_3 + self.loss_G_VGG_3) * self.loss_G_weights[2] + \
+        #               (self.loss_G_GAN_4 + self.loss_G_L1_4 + self.loss_G_VGG_4) * self.loss_G_weights[3] + \
+        #               (self.loss_G_GAN_5 + self.loss_G_L1_5) * self.loss_G_weights[4]
+
+        self.loss_G = torch.tensor(0., device=self.device)
+        for i in range(self.opt.modalities_no):
+            self.loss_G += (getattr(self,f'loss_G_GAN_{i+1}') + getattr(self,f'loss_G_L1_{i+1}') + getattr(self,f'loss_G_VGG_{i+1}')) * self.loss_G_weights[i]
+        self.loss_G += (getattr(self,f'loss_G_GAN_{self.mod_id_seg}') + getattr(self,f'loss_G_L1_{self.mod_id_seg}')) * self.loss_G_weights[i]
 
         # combine loss and calculate gradients
         # self.loss_G = (self.loss_G_GAN_1 + self.loss_G_L1_1) * self.loss_G_weights[0] + \
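Both aggregations now share one accumulate-in-a-loop shape: each modality contributes its loss bundle scaled by a per-modality weight. The arithmetic in isolation (toy scalar losses and assumed weights):

```python
import torch

# Stand-in per-modality losses, two modalities for brevity
losses = {'G_GAN_1': torch.tensor(0.7), 'G_L1_1': torch.tensor(0.2),
          'G_GAN_2': torch.tensor(0.9), 'G_L1_2': torch.tensor(0.4)}
weights = [0.5, 0.5]  # assumed per-modality weights

total = torch.tensor(0.)
for i in range(2):
    total += (losses[f'G_GAN_{i+1}'] + losses[f'G_L1_{i+1}']) * weights[i]
print(total)  # tensor(1.1000) = (0.7+0.2)*0.5 + (0.9+0.4)*0.5
```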
@@ -309,30 +405,36 @@ class DeepLIIFModel(BaseModel):
     def optimize_parameters(self):
         self.forward() # compute fake images: G(A)
         # update D
-        self.set_requires_grad(self.netD1, True) # enable backprop for D1
-        self.set_requires_grad(self.netD2, True) # enable backprop for D2
-        self.set_requires_grad(self.netD3, True) # enable backprop for D3
-        self.set_requires_grad(self.netD4, True) # enable backprop for D4
-        self.set_requires_grad(self.netD51, True) # enable backprop for D51
-        self.set_requires_grad(self.netD52, True) # enable backprop for D52
-        self.set_requires_grad(self.netD53, True) # enable backprop for D53
-        self.set_requires_grad(self.netD54, True) # enable backprop for D54
-        self.set_requires_grad(self.netD55, True) # enable backprop for D54
+        # self.set_requires_grad(self.netD1, True) # enable backprop for D1
+        # self.set_requires_grad(self.netD2, True) # enable backprop for D2
+        # self.set_requires_grad(self.netD3, True) # enable backprop for D3
+        # self.set_requires_grad(self.netD4, True) # enable backprop for D4
+        # self.set_requires_grad(self.netD51, True) # enable backprop for D51
+        # self.set_requires_grad(self.netD52, True) # enable backprop for D52
+        # self.set_requires_grad(self.netD53, True) # enable backprop for D53
+        # self.set_requires_grad(self.netD54, True) # enable backprop for D54
+        # self.set_requires_grad(self.netD55, True) # enable backprop for D54
+
+        for model_name in self.model_names_d + self.model_names_ds:
+            self.set_requires_grad(getattr(self,f'net{model_name}'), True)
 
         self.optimizer_D.zero_grad() # set D's gradients to zero
         self.backward_D() # calculate gradients for D
         self.optimizer_D.step() # update D's weights
 
         # update G
-        self.set_requires_grad(self.netD1, False) # D1 requires no gradients when optimizing G1
-        self.set_requires_grad(self.netD2, False) # D2 requires no gradients when optimizing G2
-        self.set_requires_grad(self.netD3, False) # D3 requires no gradients when optimizing G3
-        self.set_requires_grad(self.netD4, False) # D4 requires no gradients when optimizing G4
-        self.set_requires_grad(self.netD51, False) # D51 requires no gradients when optimizing G51
-        self.set_requires_grad(self.netD52, False) # D52 requires no gradients when optimizing G52
-        self.set_requires_grad(self.netD53, False) # D53 requires no gradients when optimizing G53
-        self.set_requires_grad(self.netD54, False) # D54 requires no gradients when optimizing G54
-        self.set_requires_grad(self.netD55, False) # D54 requires no gradients when optimizing G54
+        # self.set_requires_grad(self.netD1, False) # D1 requires no gradients when optimizing G1
+        # self.set_requires_grad(self.netD2, False) # D2 requires no gradients when optimizing G2
+        # self.set_requires_grad(self.netD3, False) # D3 requires no gradients when optimizing G3
+        # self.set_requires_grad(self.netD4, False) # D4 requires no gradients when optimizing G4
+        # self.set_requires_grad(self.netD51, False) # D51 requires no gradients when optimizing G51
+        # self.set_requires_grad(self.netD52, False) # D52 requires no gradients when optimizing G52
+        # self.set_requires_grad(self.netD53, False) # D53 requires no gradients when optimizing G53
+        # self.set_requires_grad(self.netD54, False) # D54 requires no gradients when optimizing G54
+        # self.set_requires_grad(self.netD55, False) # D54 requires no gradients when optimizing G54
+
+        for model_name in self.model_names_d + self.model_names_ds:
+            self.set_requires_grad(getattr(self,f'net{model_name}'), False)
 
         self.optimizer_G.zero_grad() # set G's gradients to zero
         self.backward_G() # calculate graidents for G
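The per-discriminator `set_requires_grad` calls collapse to one loop over the registered names. The helper inherited from the pix2pix-style `BaseModel` amounts to flipping `requires_grad` on every parameter; a standalone sketch of the freeze/unfreeze dance:

```python
import torch.nn as nn

def set_requires_grad(nets, requires_grad):
    # same role as BaseModel.set_requires_grad in pix2pix-style code
    for net in nets:
        for p in net.parameters():
            p.requires_grad = requires_grad

d_nets = [nn.Conv2d(6, 1, 4) for _ in range(2)]  # stand-ins for netD1, netD2
set_requires_grad(d_nets, True)    # enable backprop for the D update
set_requires_grad(d_nets, False)   # freeze Ds while G is optimized
```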
@@ -345,30 +447,37 @@ class DeepLIIFModel(BaseModel):
 
         self.forward() # compute fake images: G(A)
         # update D
-        self.set_requires_grad(self.netD1, True) # enable backprop for D1
-        self.set_requires_grad(self.netD2, True) # enable backprop for D2
-        self.set_requires_grad(self.netD3, True) # enable backprop for D3
-        self.set_requires_grad(self.netD4, True) # enable backprop for D4
-        self.set_requires_grad(self.netD51, True) # enable backprop for D51
-        self.set_requires_grad(self.netD52, True) # enable backprop for D52
-        self.set_requires_grad(self.netD53, True) # enable backprop for D53
-        self.set_requires_grad(self.netD54, True) # enable backprop for D54
-        self.set_requires_grad(self.netD55, True) # enable backprop for D54
+        # self.set_requires_grad(self.netD1, True) # enable backprop for D1
+        # self.set_requires_grad(self.netD2, True) # enable backprop for D2
+        # self.set_requires_grad(self.netD3, True) # enable backprop for D3
+        # self.set_requires_grad(self.netD4, True) # enable backprop for D4
+        # self.set_requires_grad(self.netD51, True) # enable backprop for D51
+        # self.set_requires_grad(self.netD52, True) # enable backprop for D52
+        # self.set_requires_grad(self.netD53, True) # enable backprop for D53
+        # self.set_requires_grad(self.netD54, True) # enable backprop for D54
+        # self.set_requires_grad(self.netD55, True) # enable backprop for D54
+
+        for model_name in self.model_names_d + self.model_names_ds:
+            self.set_requires_grad(getattr(self,f'net{model_name}'), True)
 
         self.optimizer_D.zero_grad() # set D's gradients to zero
         self.backward_D() # calculate gradients for D
 
         # update G
-        self.set_requires_grad(self.netD1, False) # D1 requires no gradients when optimizing G1
-        self.set_requires_grad(self.netD2, False) # D2 requires no gradients when optimizing G2
-        self.set_requires_grad(self.netD3, False) # D3 requires no gradients when optimizing G3
-        self.set_requires_grad(self.netD4, False) # D4 requires no gradients when optimizing G4
-        self.set_requires_grad(self.netD51, False) # D51 requires no gradients when optimizing G51
-        self.set_requires_grad(self.netD52, False) # D52 requires no gradients when optimizing G52
-        self.set_requires_grad(self.netD53, False) # D53 requires no gradients when optimizing G53
-        self.set_requires_grad(self.netD54, False) # D54 requires no gradients when optimizing G54
-        self.set_requires_grad(self.netD55, False) # D54 requires no gradients when optimizing G54
+        # self.set_requires_grad(self.netD1, False) # D1 requires no gradients when optimizing G1
+        # self.set_requires_grad(self.netD2, False) # D2 requires no gradients when optimizing G2
+        # self.set_requires_grad(self.netD3, False) # D3 requires no gradients when optimizing G3
+        # self.set_requires_grad(self.netD4, False) # D4 requires no gradients when optimizing G4
+        # self.set_requires_grad(self.netD51, False) # D51 requires no gradients when optimizing G51
+        # self.set_requires_grad(self.netD52, False) # D52 requires no gradients when optimizing G52
+        # self.set_requires_grad(self.netD53, False) # D53 requires no gradients when optimizing G53
+        # self.set_requires_grad(self.netD54, False) # D54 requires no gradients when optimizing G54
+        # self.set_requires_grad(self.netD55, False) # D54 requires no gradients when optimizing G54
+
+        for model_name in self.model_names_d + self.model_names_ds:
+            self.set_requires_grad(getattr(self,f'net{model_name}'), False)
 
         self.optimizer_G.zero_grad() # set G's gradients to zero
         self.backward_G() # calculate graidents for G
 
+