deepliif 1.1.9__py3-none-any.whl → 1.1.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli.py +49 -42
- deepliif/data/aligned_dataset.py +17 -0
- deepliif/models/SDG_model.py +189 -0
- deepliif/models/__init__.py +170 -46
- deepliif/options/__init__.py +62 -29
- deepliif/util/__init__.py +227 -0
- deepliif/util/util.py +17 -1
- {deepliif-1.1.9.dist-info → deepliif-1.1.11.dist-info}/METADATA +181 -27
- {deepliif-1.1.9.dist-info → deepliif-1.1.11.dist-info}/RECORD +13 -12
- {deepliif-1.1.9.dist-info → deepliif-1.1.11.dist-info}/LICENSE.md +0 -0
- {deepliif-1.1.9.dist-info → deepliif-1.1.11.dist-info}/WHEEL +0 -0
- {deepliif-1.1.9.dist-info → deepliif-1.1.11.dist-info}/entry_points.txt +0 -0
- {deepliif-1.1.9.dist-info → deepliif-1.1.11.dist-info}/top_level.txt +0 -0
cli.py
CHANGED
```diff
@@ -14,7 +14,7 @@ from deepliif.models import init_nets, infer_modalities, infer_results_for_wsi,
 from deepliif.util import allowed_file, Visualizer, get_information, test_diff_original_serialized, disable_batchnorm_tracking_stats
 from deepliif.util.util import mkdirs, check_multi_scale
 # from deepliif.util import infer_results_for_wsi
-from deepliif.options import Options
+from deepliif.options import Options, print_options
 
 import torch.distributed as dist
 
```
```diff
@@ -59,29 +59,6 @@ def set_seed(seed=0,rank=None):
 def ensure_exists(d):
     if not os.path.exists(d):
         os.makedirs(d)
-
-def print_options(opt):
-    """Print and save options
-
-    It will print both current options and default values(if different).
-    It will save options into a text file / [checkpoints_dir] / opt.txt
-    """
-    message = ''
-    message += '----------------- Options ---------------\n'
-    for k, v in sorted(vars(opt).items()):
-        comment = ''
-        message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
-    message += '----------------- End -------------------'
-    print(message)
-
-    # save to the disk
-    if opt.phase == 'train':
-        expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
-        mkdirs(expr_dir)
-        file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase))
-        with open(file_name, 'wt') as opt_file:
-            opt_file.write(message)
-            opt_file.write('\n')
 
 
 @click.group()
```
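The `print_options` helper removed here moves into `deepliif.options` (note the widened import at the top of `cli.py`) and is called later in this diff with `save=True`. A minimal standalone sketch of the same banner-formatting logic, using a plain namespace in place of the real `Options` object:

```python
from types import SimpleNamespace

def format_options(opt):
    # rebuilds the banner produced by the removed print_options body above
    lines = ['----------------- Options ---------------']
    for k, v in sorted(vars(opt).items()):
        lines.append('{:>25}: {:<30}{}'.format(str(k), str(v), ''))
    lines.append('----------------- End -------------------')
    return '\n'.join(lines)

# hypothetical stand-in for a deepliif Options instance
opt = SimpleNamespace(phase='train', name='exp1', checkpoints_dir='./checkpoints')
print(format_options(opt))
```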
```diff
@@ -212,6 +189,18 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd
     plot, and save models.The script supports continue/resume training.
     Use '--continue_train' to resume your previous training.
     """
+    assert model in ['DeepLIIF','DeepLIIFExt','SDG'], f'model class {model} is not implemented'
+    if model == 'DeepLIIF':
+        seg_no = 1
+    elif model == 'DeepLIIFExt':
+        if seg_gen:
+            seg_no = modalities_no
+        else:
+            seg_no = 0
+    else: # SDG
+        seg_no = 0
+        seg_gen = False
+
     d_params = locals()
 
     if gpu_ids and gpu_ids[0] == -1:
```
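The new block derives `seg_no`, the number of segmentation tiles expected per training image, from the model class. The same branching as a worked example (a standalone restatement of the logic above, not deepliif API):

```python
def infer_seg_no(model, modalities_no, seg_gen):
    # mirrors the branching added to train() above
    assert model in ['DeepLIIF', 'DeepLIIFExt', 'SDG'], f'model class {model} is not implemented'
    if model == 'DeepLIIF':
        return 1                                 # a single segmentation image
    if model == 'DeepLIIFExt':
        return modalities_no if seg_gen else 0   # one mask per modality when seg_gen is on
    return 0                                     # SDG generates no segmentation tiles

print(infer_seg_no('DeepLIIF', 4, True))      # -> 1
print(infer_seg_no('DeepLIIFExt', 4, True))   # -> 4
print(infer_seg_no('SDG', 3, True))           # -> 0
```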
```diff
@@ -241,11 +230,26 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd
         d_params['padding'] = 'zero'
         print('padding type is forced to zero padding, because neither refection pad2d or replication pad2d has a deterministic implementation')
 
+    # infer number of input images
+    dir_data_train = dataroot + '/train'
+    fns = os.listdir(dir_data_train)
+    fns = [x for x in fns if x.endswith('.png')]
+    img = Image.open(f"{dir_data_train}/{fns[0]}")
+
+    num_img = img.size[0] / img.size[1]
+    assert int(num_img) == num_img, f'img size {img.size[0]} / {img.size[1]} = {num_img} is not an integer'
+    num_img = int(num_img)
+
+    input_no = num_img - modalities_no - seg_no
+    assert input_no > 0, f'inferred number of input images is {input_no}; should be greater than 0'
+    d_params['input_no'] = input_no
+    d_params['scale_size'] = img.size[1]
+
     # create a dataset given dataset_mode and other options
     # dataset = AlignedDataset(opt)
 
     opt = Options(d_params=d_params)
-    print_options(opt)
+    print_options(opt, save=True)
 
     dataset = create_dataset(opt)
     # get the number of images in the dataset.
```
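The inferred-input block assumes each training PNG is a horizontal strip of square tiles, so `width / height` gives the tile count; subtracting the modality and segmentation tiles leaves `input_no`, and the tile height becomes `scale_size`. The arithmetic in isolation, on a synthetic image (the five-tile layout is an illustrative assumption):

```python
from PIL import Image

# synthetic stand-in for one training strip: five square 512x512 tiles side by side
img = Image.new('RGB', (512 * 5, 512))

modalities_no, seg_no = 3, 0              # e.g. the SDG branch above
num_img = img.size[0] / img.size[1]       # 2560 / 512 = 5.0 tiles
assert int(num_img) == num_img            # width must be an exact multiple of height
input_no = int(num_img) - modalities_no - seg_no
print(input_no, img.size[1])              # -> 2 512  (two input images, scale_size 512)
```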
```diff
@@ -468,28 +472,30 @@ def trainlaunch(**kwargs):
 
 
 @cli.command()
-@click.option('--
+@click.option('--model-dir', default='./model-server/DeepLIIF_Latest_Model', help='reads models from here')
 @click.option('--output-dir', help='saves results here.')
-
-@click.option('--device', default='cpu', type=str, help='device to load model, either cpu or gpu')
+#@click.option('--tile-size', type=int, default=None, help='tile size')
+@click.option('--device', default='cpu', type=str, help='device to load model for the similarity test, either cpu or gpu')
 @click.option('--verbose', default=0, type=int,help='saves results here.')
-def serialize(models_dir, output_dir, tile_size, device, verbose):
+def serialize(model_dir, output_dir, device, verbose):
     """Serialize DeepLIIF models using Torchscript
     """
-    if tile_size is None:
-
-    output_dir = output_dir or
+    #if tile_size is None:
+    #    tile_size = 512
+    output_dir = output_dir or model_dir
     ensure_exists(output_dir)
 
     # copy train_opt.txt to the target location
     import shutil
-    if
-    shutil.copy(f'{
+    if model_dir != output_dir:
+        shutil.copy(f'{model_dir}/train_opt.txt',f'{output_dir}/train_opt.txt')
 
-
+    opt = Options(path_file=os.path.join(model_dir,'train_opt.txt'), mode='test')
+    sample = transform(Image.new('RGB', (opt.scale_size, opt.scale_size)))
+    sample = torch.cat([sample]*opt.input_no, 1)
 
     with click.progressbar(
-        init_nets(
+        init_nets(model_dir, eager_mode=True, phase='test').items(),
         label='Tracing nets',
         item_show_func=lambda n: n[0] if n else n
     ) as bar:
```
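`serialize` now rebuilds the training `Options` from `train_opt.txt` and synthesizes a dummy batch whose channel count matches `input_no`, so multi-input (SDG-style) models can be traced. The tracing loop body is elided in this diff; a minimal sketch of tracing one net with such a sample, assuming `torch.jit.trace` under the hood and a `ToTensor`-style transform (both are assumptions, not shown above):

```python
import torch
from PIL import Image
from torchvision.transforms import ToTensor

scale_size, input_no = 512, 2        # would come from the restored opt in the real command
to_tensor = ToTensor()               # stand-in for deepliif's transform

sample = to_tensor(Image.new('RGB', (scale_size, scale_size))).unsqueeze(0)
sample = torch.cat([sample] * input_no, 1)   # shape: (1, 3 * input_no, H, W)

net = torch.nn.Conv2d(3 * input_no, 3, 3, padding=1)  # hypothetical stand-in for one generator
traced = torch.jit.trace(net, sample)                 # TorchScript serialization
traced.save('/tmp/G_1.pt')
```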
```diff
@@ -508,7 +514,7 @@ def serialize(models_dir, output_dir, tile_size, device, verbose):
 
     # test: whether the original and the serialized model produces highly similar predictions
     print('testing similarity between prediction from original vs serialized models...')
-    models_original = init_nets(
+    models_original = init_nets(model_dir,eager_mode=True,phase='test')
     models_serialized = init_nets(output_dir,eager_mode=False,phase='test')
     if device == 'gpu':
         sample = sample.cuda()
```
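The similarity test feeds the same `sample` to the eager and serialized nets; the actual comparison lives in `test_diff_original_serialized` (imported at the top of `cli.py`, body not shown here). A sketch of such a check, assuming it bounds the worst-case deviation between predictions:

```python
import torch

def max_abs_diff(model_a, model_b, sample):
    # run both models on the same input and report the largest per-element deviation
    with torch.no_grad():
        out_a, out_b = model_a(sample), model_b(sample)
    return (out_a - out_b).abs().max().item()

net = torch.nn.Conv2d(3, 3, 3, padding=1)            # hypothetical eager/serialized pair
traced = torch.jit.trace(net, torch.rand(1, 3, 64, 64))
diff = max_abs_diff(net, traced, torch.rand(1, 3, 64, 64))
assert diff < 1e-5, f'serialized model diverges: {diff}'
```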
```diff
@@ -590,11 +596,12 @@ def test(input_dir, output_dir, tile_size, model_dir, gpu_ids, region_size, eage
             filename.replace('.' + filename.split('.')[-1], f'_{name}.png')
         ))
 
-
-
-
-
-
+    if scoring is not None:
+        with open(os.path.join(
+            output_dir,
+            filename.replace('.' + filename.split('.')[-1], f'.json')
+        ), 'w') as f:
+            json.dump(scoring, f, indent=2)
 
 @cli.command()
 @click.option('--input-dir', type=str, required=True, help='Path to input images')
```
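The new block in `test` persists the scoring dict next to the output images, swapping the input filename's extension for `.json`. The filename rewrite and dump in isolation (the scoring payload here is hypothetical):

```python
import json
import os

filename = 'sample_IHC.png'
json_name = filename.replace('.' + filename.split('.')[-1], '.json')
print(json_name)   # -> sample_IHC.json

scoring = {'num_total': 100, 'num_pos': 42}   # hypothetical scoring payload
with open(os.path.join('/tmp', json_name), 'w') as f:
    json.dump(scoring, f, indent=2)
```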
deepliif/data/aligned_dataset.py
CHANGED
```diff
@@ -26,6 +26,8 @@ class AlignedDataset(BaseDataset):
         self.output_nc = opt.input_nc if opt.direction == 'BtoA' else opt.output_nc
         self.no_flip = opt.no_flip
         self.modalities_no = opt.modalities_no
+        self.seg_no = opt.seg_no
+        self.input_no = opt.input_no
         self.seg_gen = opt.seg_gen
         self.load_size = opt.load_size
         self.crop_size = opt.crop_size
```
```diff
@@ -52,6 +54,8 @@ class AlignedDataset(BaseDataset):
             num_img = self.modalities_no + 1 + 1 # +1 for segmentation channel, +1 for input image
         elif self.model == 'DeepLIIFExt':
             num_img = self.modalities_no * 2 + 1 if self.seg_gen else self.modalities_no + 1 # +1 for segmentation channel
+        elif self.model == 'SDG':
+            num_img = self.modalities_no + self.seg_no + self.input_no
         else:
             raise Exception(f'model class {self.model} does not have corresponding implementation in deepliif/data/aligned_dataset.py')
         w2 = int(w / num_img)
```
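Each branch fixes how many square tiles the combined A|B image is split into, and `w2` is then the width of one tile. A compact restatement with a worked SDG example:

```python
def tiles_per_image(model, modalities_no, seg_gen=False, seg_no=0, input_no=1):
    # mirrors the branching in AlignedDataset above
    if model == 'DeepLIIF':
        return modalities_no + 1 + 1                 # + segmentation + input image
    if model == 'DeepLIIFExt':
        return modalities_no * 2 + 1 if seg_gen else modalities_no + 1
    if model == 'SDG':
        return modalities_no + seg_no + input_no
    raise Exception(f'model class {model} not implemented')

w = 2560                                             # width of the combined image
num_img = tiles_per_image('SDG', modalities_no=3, seg_no=0, input_no=2)
print(num_img, w // num_img)                         # -> 5 512
```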
```diff
@@ -85,6 +89,19 @@ class AlignedDataset(BaseDataset):
                 BS_Array.append(BS)
 
             return {'A': A, 'B': B_Array, 'BS': BS_Array,'A_paths': AB_path, 'B_paths': AB_path}
+        elif self.model == 'SDG':
+            A_Array = []
+            for i in range(self.input_no):
+                A = AB.crop((w2 * i, 0, w2 * (i+1), h))
+                A = A_transform(A)
+                A_Array.append(A)
+
+            for i in range(self.input_no, self.input_no + self.modalities_no + 1):
+                B = AB.crop((w2 * i, 0, w2 * (i + 1), h))
+                B = B_transform(B)
+                B_Array.append(B)
+
+            return {'A': A_Array, 'B': B_Array, 'A_paths': AB_path, 'B_paths': AB_path}
         else:
             raise Exception(f'model class {self.model} does not have corresponding implementation in deepliif/data/aligned_dataset.py')
 
```
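For SDG the leftmost `input_no` tiles become the list-valued `A` side, and the following tiles (including the extra `+ 1` tile the loop above reads) the `B` side. The crop geometry on a synthetic strip, with Pillow:

```python
from PIL import Image

input_no, modalities_no, h, w2 = 2, 3, 512, 512
AB = Image.new('RGB', (w2 * (input_no + modalities_no + 1), h))  # synthetic A|B strip

A_tiles = [AB.crop((w2 * i, 0, w2 * (i + 1), h)) for i in range(input_no)]
B_tiles = [AB.crop((w2 * i, 0, w2 * (i + 1), h))
           for i in range(input_no, input_no + modalities_no + 1)]
print(len(A_tiles), len(B_tiles))   # -> 2 4
```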
deepliif/models/SDG_model.py
ADDED
```diff
@@ -0,0 +1,189 @@
+import torch
+from .base_model import BaseModel
+from . import networks
+
+
+class SDGModel(BaseModel):
+    """ This class implements the Synthetic Data Generation model (based on DeepLIIFExt), for learning a mapping from input images to modalities given paired data."""
+
+    def __init__(self, opt):
+        """Initialize the DeepLIIF class.
+
+        Parameters:
+            opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
+        """
+        BaseModel.__init__(self, opt)
+
+        self.mod_gen_no = self.opt.modalities_no
+
+        # weights of the modalities in generating segmentation mask
+        self.seg_weights = [0, 0, 0]
+        if opt.seg_gen:
+            self.seg_weights = [0.3] * self.mod_gen_no
+            self.seg_weights[1] = 0.4
+
+        # self.seg_weights = opt.seg_weights
+        # assert len(self.seg_weights) == self.seg_gen_no, 'The number of the segmentation weights (seg_weights) is not equal to the number of target images (modalities_no)!'
+        # print(self.seg_weights)
+        # loss weights in calculating the final loss
+        self.loss_G_weights = [1 / self.mod_gen_no] * self.mod_gen_no
+        self.loss_GS_weights = [1 / self.mod_gen_no] * self.mod_gen_no
+
+        self.loss_D_weights = [1 / self.mod_gen_no] * self.mod_gen_no
+        self.loss_DS_weights = [1 / self.mod_gen_no] * self.mod_gen_no
+
+        self.loss_names = []
+        self.visual_names = ['real_A']
+        # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
+        for i in range(1, self.mod_gen_no + 1):
+            self.loss_names.extend(['G_GAN_' + str(i), 'G_L1_' + str(i), 'D_real_' + str(i), 'D_fake_' + str(i)])
+            self.visual_names.extend(['fake_B_' + str(i), 'real_B_' + str(i)])
+
+        # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
+        # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
+        if self.is_train:
+            self.model_names = []
+            for i in range(1, self.mod_gen_no + 1):
+                self.model_names.extend(['G_' + str(i), 'D_' + str(i)])
+
+        else: # during test time, only load G
+            self.model_names = []
+            for i in range(1, self.mod_gen_no + 1):
+                self.model_names.extend(['G_' + str(i)])
+
+        # define networks (both generator and discriminator)
+        self.netG = [None for _ in range(self.mod_gen_no)]
+        for i in range(self.mod_gen_no):
+            self.netG[i] = networks.define_G(self.opt.input_nc * self.opt.input_no, self.opt.output_nc, self.opt.ngf, self.opt.net_g, self.opt.norm,
+                                             not self.opt.no_dropout, self.opt.init_type, self.opt.init_gain, self.opt.gpu_ids, self.opt.padding)
+        print('***************************************')
+        print(self.opt.input_nc, self.opt.output_nc, self.opt.ngf, self.opt.net_g, self.opt.norm,
+              not self.opt.no_dropout, self.opt.init_type, self.opt.init_gain, self.opt.gpu_ids, self.opt.padding)
+        print('***************************************')
+
+        if self.is_train: # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc
+            self.netD = [None for _ in range(self.mod_gen_no)]
+            for i in range(self.mod_gen_no):
+                self.netD[i] = networks.define_D(self.opt.input_nc * self.opt.input_no + self.opt.output_nc, self.opt.ndf, self.opt.net_d,
+                                                 self.opt.n_layers_D, self.opt.norm, self.opt.init_type, self.opt.init_gain,
+                                                 self.opt.gpu_ids)
+
+        if self.is_train:
+            # define loss functions
+            self.criterionGAN_mod = networks.GANLoss(self.opt.gan_mode).to(self.device)
+            self.criterionGAN_seg = networks.GANLoss(self.opt.gan_mode_s).to(self.device)
+
+            self.criterionSmoothL1 = torch.nn.SmoothL1Loss()
+
+            self.criterionVGG = networks.VGGLoss().to(self.device)
+
+            # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
+            params = []
+            for i in range(len(self.netG)):
+                params += list(self.netG[i].parameters())
+            self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
+
+            params = []
+            for i in range(len(self.netD)):
+                params += list(self.netD[i].parameters())
+            self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
+
+            self.optimizers.append(self.optimizer_G)
+            self.optimizers.append(self.optimizer_D)
+
+    def set_input(self, input):
+        """
+        Unpack input data from the dataloader and perform necessary pre-processing steps.
+
+        :param input (dict): include the input image and the output modalities
+        """
+        self.real_A_array = input['A']
+        As = [A.to(self.device) for A in self.real_A_array]
+        self.real_A = torch.cat(As, dim=1) # shape: 1, (3 x input_no), 512, 512
+
+        self.real_B_array = input['B']
+        self.real_B = []
+        for i in range(len(self.real_B_array)):
+            self.real_B.append(self.real_B_array[i].to(self.device))
+
+        self.real_concatenated = []
+        self.image_paths = input['A_paths']
+
+    def forward(self):
+        """Run forward pass; called by both functions <optimize_parameters> and <test>."""
+        self.fake_B = []
+        for i in range(self.mod_gen_no):
+            self.fake_B.append(self.netG[i](self.real_A))
+
+
+    def backward_D(self):
+        """Calculate GAN loss for the discriminators"""
+
+        pred_fake = []
+        for i in range(self.mod_gen_no):
+            pred_fake.append(self.netD[i](torch.cat((self.real_A, self.fake_B[i]), 1).detach()))
+
+        self.loss_D_fake = []
+        for i in range(self.mod_gen_no):
+            self.loss_D_fake.append(self.criterionGAN_mod(pred_fake[i], False))
+
+        pred_real = []
+        for i in range(self.mod_gen_no):
+            pred_real.append(self.netD[i](torch.cat((self.real_A, self.real_B[i]), 1)))
+
+        self.loss_D_real = []
+        for i in range(self.mod_gen_no):
+            self.loss_D_real.append(self.criterionGAN_mod(pred_real[i], True))
+
+        # combine losses and calculate gradients
+        # self.loss_D = (self.loss_D_fake[0] + self.loss_D_real[0]) * 0.5 * self.loss_D_weights[0]
+        self.loss_D = torch.tensor(0., device=self.device)
+        for i in range(0, self.mod_gen_no):
+            self.loss_D += (self.loss_D_fake[i] + self.loss_D_real[i]) * 0.5 * self.loss_D_weights[i]
+        self.loss_D.backward()
+
+    def backward_G(self):
+        """Calculate GAN and L1 loss for the generator"""
+        pred_fake = []
+        for i in range(self.mod_gen_no):
+            pred_fake.append(self.netD[i](torch.cat((self.real_A, self.fake_B[i]), 1)))
+
+        self.loss_G_GAN = []
+        self.loss_GS_GAN = []
+        for i in range(self.mod_gen_no):
+            self.loss_G_GAN.append(self.criterionGAN_mod(pred_fake[i], True))
+
+        # Second, G(A) = B
+        self.loss_G_L1 = []
+        self.loss_GS_L1 = []
+        for i in range(self.mod_gen_no):
+            self.loss_G_L1.append(self.criterionSmoothL1(self.fake_B[i], self.real_B[i]) * self.opt.lambda_L1)
+
+        #self.loss_G_VGG = []
+        #for i in range(self.mod_gen_no):
+        #    self.loss_G_VGG.append(self.criterionVGG(self.fake_B[i], self.real_B[i]) * self.opt.lambda_feat)
+
+        # self.loss_G = (self.loss_G_GAN[0] + self.loss_G_L1[0]) * self.loss_G_weights[0]
+        self.loss_G = torch.tensor(0., device=self.device)
+        for i in range(0, self.mod_gen_no):
+            self.loss_G += (self.loss_G_GAN[i] + self.loss_G_L1[i]) * self.loss_G_weights[i]
+            # self.loss_G += (self.loss_G_GAN[i] + self.loss_G_L1[i] + self.loss_G_VGG[i]) * self.loss_G_weights[i]
+        self.loss_G.backward()
+
+    def optimize_parameters(self):
+        self.forward() # compute fake images: G(A)
+        # update D
+        for i in range(self.mod_gen_no):
+            self.set_requires_grad(self.netD[i], True) # enable backprop for D1
+
+        self.optimizer_D.zero_grad() # set D's gradients to zero
+        self.backward_D() # calculate gradients for D
+        self.optimizer_D.step() # update D's weights
+
+        # update G
+        for i in range(self.mod_gen_no):
+            self.set_requires_grad(self.netD[i], False)
+
+        self.optimizer_G.zero_grad() # set G's gradients to zero
+        self.backward_G() # calculate graidents for G
+        self.optimizer_G.step() # udpate G's weights
```