deepliif 1.1.11__py3-none-any.whl → 1.1.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli.py +354 -67
- deepliif/data/__init__.py +7 -7
- deepliif/data/aligned_dataset.py +2 -3
- deepliif/data/unaligned_dataset.py +38 -19
- deepliif/models/CycleGAN_model.py +282 -0
- deepliif/models/DeepLIIFExt_model.py +47 -25
- deepliif/models/DeepLIIF_model.py +69 -19
- deepliif/models/SDG_model.py +57 -26
- deepliif/models/__init__ - run_dask_multi dev.py +943 -0
- deepliif/models/__init__ - timings.py +764 -0
- deepliif/models/__init__.py +328 -265
- deepliif/models/att_unet.py +199 -0
- deepliif/models/base_model.py +32 -8
- deepliif/models/networks.py +108 -34
- deepliif/options/__init__.py +49 -5
- deepliif/postprocessing.py +1034 -227
- deepliif/postprocessing__OLD__DELETE.py +440 -0
- deepliif/util/__init__.py +85 -64
- deepliif/util/visualizer.py +106 -19
- {deepliif-1.1.11.dist-info → deepliif-1.1.12.dist-info}/METADATA +75 -23
- deepliif-1.1.12.dist-info/RECORD +40 -0
- deepliif-1.1.11.dist-info/RECORD +0 -35
- {deepliif-1.1.11.dist-info → deepliif-1.1.12.dist-info}/LICENSE.md +0 -0
- {deepliif-1.1.11.dist-info → deepliif-1.1.12.dist-info}/WHEEL +0 -0
- {deepliif-1.1.11.dist-info → deepliif-1.1.12.dist-info}/entry_points.txt +0 -0
- {deepliif-1.1.11.dist-info → deepliif-1.1.12.dist-info}/top_level.txt +0 -0
deepliif/models/DeepLIIF_model.py
CHANGED

@@ -1,6 +1,7 @@
 import torch
 from .base_model import BaseModel
 from . import networks
+from .networks import get_optimizer
 
 
 class DeepLIIFModel(BaseModel):
@@ -13,13 +14,12 @@ class DeepLIIFModel(BaseModel):
             opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
         """
         BaseModel.__init__(self, opt)
-
-
-
-
-
-        self.
-        self.loss_D_weights = [0.2, 0.2, 0.2, 0.2, 0.2]
+        if not hasattr(opt,'net_gs'):
+            opt.net_gs = 'unet_512'
+
+        self.seg_weights = opt.seg_weights
+        self.loss_G_weights = opt.loss_G_weights
+        self.loss_D_weights = opt.loss_D_weights
 
         if not opt.is_train:
             self.gpu_ids = [] # avoid the models being loaded as DP
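The hunk above drops the hard-coded `seg_weights`/`loss_G_weights`/`loss_D_weights` lists and reads them from `opt` instead, so they can be set per experiment. How deepliif actually parses these options is not part of this diff; the sketch below only illustrates one plausible way such weights could be supplied on the command line, using a hypothetical `weight_list` converter and illustrative default values.

```python
# Hypothetical sketch (not deepliif's actual option code): turning
# comma-separated weight strings into the float lists that would end up in
# opt.seg_weights, opt.loss_G_weights and opt.loss_D_weights.
from argparse import ArgumentParser

def weight_list(text):
    """Convert a comma-separated string such as '0.2,0.2,0.2,0.2,0.2' into a list of floats."""
    return [float(v) for v in text.split(',') if v.strip()]

parser = ArgumentParser()
parser.add_argument('--seg-weights', type=weight_list, default='0.2,0.2,0.2,0.2,0.2')
parser.add_argument('--loss-G-weights', type=weight_list, default='0.2,0.2,0.2,0.2,0.2')
parser.add_argument('--loss-D-weights', type=weight_list, default='0.2,0.2,0.2,0.2,0.2')

opt = parser.parse_args([])      # defaults only, for illustration
print(opt.seg_weights)           # [0.2, 0.2, 0.2, 0.2, 0.2]
print(opt.loss_D_weights)        # [0.2, 0.2, 0.2, 0.2, 0.2]
```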
@@ -31,7 +31,7 @@ class DeepLIIFModel(BaseModel):
         # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
         for i in range(1, self.opt.modalities_no + 1 + 1):
             self.loss_names.extend(['G_GAN_' + str(i), 'G_L1_' + str(i), 'D_real_' + str(i), 'D_fake_' + str(i)])
-            self.visual_names.extend(['fake_B_' + str(i), 'real_B_' + str(i)])
+            self.visual_names.extend(['fake_B_' + str(i), 'fake_B_5' + str(i), 'real_B_' + str(i)])
 
         # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
         # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
@@ -51,24 +51,31 @@ class DeepLIIFModel(BaseModel):
             self.model_names.extend(['G5' + str(i)])
 
         # define networks (both generator and discriminator)
-
+        if isinstance(opt.netG, str):
+            opt.netG = [opt.netG] * 4
+        if isinstance(opt.net_gs, str):
+            opt.net_gs = [opt.net_gs]*5
+
+
+        self.netG1 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG[0], opt.norm,
                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.padding)
-        self.netG2 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
+        self.netG2 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG[1], opt.norm,
                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.padding)
-        self.netG3 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
+        self.netG3 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG[2], opt.norm,
                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.padding)
-        self.netG4 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG, opt.norm,
+        self.netG4 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG[3], opt.norm,
                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.padding)
 
-
+        # DeepLIIF model currently uses one gs arch because there is only one explicit seg mod output
+        self.netG51 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[0], opt.norm,
                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
-        self.netG52 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
+        self.netG52 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[1], opt.norm,
                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
-        self.netG53 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
+        self.netG53 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[2], opt.norm,
                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
-        self.netG54 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
+        self.netG54 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[3], opt.norm,
                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
-        self.netG55 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf,
+        self.netG55 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[4], opt.norm,
                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
 
 
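In the hunk above, `opt.netG` and `opt.net_gs` may now arrive either as a single architecture name or as a list with one entry per generator; the `isinstance` checks broadcast a plain string into a list so that each `define_G` call can index its own entry. The same pattern in isolation (architecture names here are only placeholders):

```python
# Standalone sketch of the broadcast pattern above: a single architecture name
# is expanded to one entry per generator, while an explicit list is used as-is.
def expand_arch(arch, n):
    """Return a list of n architecture names from either a string or a list."""
    return [arch] * n if isinstance(arch, str) else list(arch)

print(expand_arch('resnet_9blocks', 4))   # ['resnet_9blocks', 'resnet_9blocks', 'resnet_9blocks', 'resnet_9blocks']
print(expand_arch(['unet_512', 'unet_512', 'resnet_9blocks', 'unet_512'], 4))
```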
@@ -101,10 +108,18 @@ class DeepLIIFModel(BaseModel):
 
             # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
             params = list(self.netG1.parameters()) + list(self.netG2.parameters()) + list(self.netG3.parameters()) + list(self.netG4.parameters()) + list(self.netG51.parameters()) + list(self.netG52.parameters()) + list(self.netG53.parameters()) + list(self.netG54.parameters()) + list(self.netG55.parameters())
-
+            try:
+                self.optimizer_G = get_optimizer(opt.optimizer)(params, lr=opt.lr_g, betas=(opt.beta1, 0.999))
+            except:
+                print(f'betas are not used for optimizer torch.optim.{opt.optimizer} in generators')
+                self.optimizer_G = get_optimizer(opt.optimizer)(params, lr=opt.lr_g)
 
             params = list(self.netD1.parameters()) + list(self.netD2.parameters()) + list(self.netD3.parameters()) + list(self.netD4.parameters()) + list(self.netD51.parameters()) + list(self.netD52.parameters()) + list(self.netD53.parameters()) + list(self.netD54.parameters()) + list(self.netD55.parameters())
-
+            try:
+                self.optimizer_D = get_optimizer(opt.optimizer)(params, lr=opt.lr_d, betas=(opt.beta1, 0.999))
+            except:
+                print(f'betas are not used for optimizer torch.optim.{opt.optimizer} in discriminators')
+                self.optimizer_D = get_optimizer(opt.optimizer)(params, lr=opt.lr_d)
 
             self.optimizers.append(self.optimizer_G)
             self.optimizers.append(self.optimizer_D)
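Both optimizers are now built through `get_optimizer(opt.optimizer)` with a fallback when the chosen optimizer does not accept `betas`. The body of `get_optimizer` lives in `deepliif/models/networks.py` and is not shown in this diff; the sketch below assumes a simple name lookup in `torch.optim`, which would explain the try/except: `Adam`-style optimizers take `betas`, while e.g. `SGD` raises a `TypeError` and is retried with `lr` only.

```python
# Assumed behaviour only -- the real deepliif get_optimizer is not part of
# this diff. A name-based lookup in torch.optim reproduces the pattern above.
import torch

def get_optimizer(name):
    """Resolve an optimizer class in torch.optim by name, e.g. 'Adam' or 'SGD'."""
    return getattr(torch.optim, name)

params = [torch.nn.Parameter(torch.zeros(3))]
try:
    optimizer = get_optimizer('SGD')(params, lr=2e-4, betas=(0.5, 0.999))
except TypeError:
    # SGD's constructor has no 'betas' argument, so fall back to lr only,
    # mirroring the except branch in the hunk above.
    optimizer = get_optimizer('SGD')(params, lr=2e-4)
print(type(optimizer).__name__)   # SGD
```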
@@ -322,3 +337,38 @@ class DeepLIIFModel(BaseModel):
         self.optimizer_G.zero_grad() # set G's gradients to zero
         self.backward_G() # calculate graidents for G
         self.optimizer_G.step() # udpate G's weights
+
+    def calculate_losses(self):
+        """
+        Calculate losses but do not optimize parameters. Used in validation loss calculation during training.
+        """
+
+        self.forward() # compute fake images: G(A)
+        # update D
+        self.set_requires_grad(self.netD1, True) # enable backprop for D1
+        self.set_requires_grad(self.netD2, True) # enable backprop for D2
+        self.set_requires_grad(self.netD3, True) # enable backprop for D3
+        self.set_requires_grad(self.netD4, True) # enable backprop for D4
+        self.set_requires_grad(self.netD51, True) # enable backprop for D51
+        self.set_requires_grad(self.netD52, True) # enable backprop for D52
+        self.set_requires_grad(self.netD53, True) # enable backprop for D53
+        self.set_requires_grad(self.netD54, True) # enable backprop for D54
+        self.set_requires_grad(self.netD55, True) # enable backprop for D54
+
+        self.optimizer_D.zero_grad() # set D's gradients to zero
+        self.backward_D() # calculate gradients for D
+
+        # update G
+        self.set_requires_grad(self.netD1, False) # D1 requires no gradients when optimizing G1
+        self.set_requires_grad(self.netD2, False) # D2 requires no gradients when optimizing G2
+        self.set_requires_grad(self.netD3, False) # D3 requires no gradients when optimizing G3
+        self.set_requires_grad(self.netD4, False) # D4 requires no gradients when optimizing G4
+        self.set_requires_grad(self.netD51, False) # D51 requires no gradients when optimizing G51
+        self.set_requires_grad(self.netD52, False) # D52 requires no gradients when optimizing G52
+        self.set_requires_grad(self.netD53, False) # D53 requires no gradients when optimizing G53
+        self.set_requires_grad(self.netD54, False) # D54 requires no gradients when optimizing G54
+        self.set_requires_grad(self.netD55, False) # D54 requires no gradients when optimizing G54
+
+        self.optimizer_G.zero_grad() # set G's gradients to zero
+        self.backward_G() # calculate graidents for G
+
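The new `calculate_losses` runs the forward pass and both backward passes but never calls `optimizer.step()`, so the per-term loss values can be read out on a validation set without changing any weights. A hedged usage sketch, assuming the `set_input` and `get_current_losses` methods these models inherit from `BaseModel` (the validation loader and model construction are placeholders):

```python
# Usage sketch only: averaging validation losses with calculate_losses.
# set_input/get_current_losses come from the BaseModel API; val_loader and the
# constructed model are placeholders, not code from this package diff.
from collections import defaultdict

def validation_losses(model, val_loader):
    """Average each named loss over a validation loader without stepping the optimizers."""
    totals, batches = defaultdict(float), 0
    for data in val_loader:
        model.set_input(data)       # unpack a validation batch
        model.calculate_losses()    # forward + backward_D/backward_G, no optimizer.step()
        for name, value in model.get_current_losses().items():
            totals[name] += float(value)
        batches += 1
    return {name: total / max(batches, 1) for name, total in totals.items()}
```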
deepliif/models/SDG_model.py
CHANGED

@@ -1,6 +1,7 @@
 import torch
 from .base_model import BaseModel
 from . import networks
+from .networks import get_optimizer
 
 
 class SDGModel(BaseModel):
@@ -15,28 +16,34 @@ class SDGModel(BaseModel):
         BaseModel.__init__(self, opt)
 
         self.mod_gen_no = self.opt.modalities_no
+
 
         # weights of the modalities in generating segmentation mask
-        self.seg_weights =
-        if opt.seg_gen:
-            self.seg_weights = [0.3] * self.mod_gen_no
-            self.seg_weights[1] = 0.4
+        self.seg_weights = opt.seg_weights
 
         # self.seg_weights = opt.seg_weights
         # assert len(self.seg_weights) == self.seg_gen_no, 'The number of the segmentation weights (seg_weights) is not equal to the number of target images (modalities_no)!'
-
+
         # loss weights in calculating the final loss
-        self.loss_G_weights =
+        self.loss_G_weights = opt.loss_G_weights
         self.loss_GS_weights = [1 / self.mod_gen_no] * self.mod_gen_no
 
-        self.loss_D_weights =
+        self.loss_D_weights = opt.loss_D_weights
         self.loss_DS_weights = [1 / self.mod_gen_no] * self.mod_gen_no
 
+        # self.gpu_ids is a possibly modifed one for model initialization
+        # self.opt.gpu_ids is the original one received in the command
+        if not opt.is_train:
+            self.gpu_ids = [] # avoid the models being loaded as DP
+        else:
+            self.gpu_ids = opt.gpu_ids
+
         self.loss_names = []
         self.visual_names = ['real_A']
         # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
        for i in range(1, self.mod_gen_no + 1):
-            self.loss_names.extend(['G_GAN_' + str(i), 'G_L1_' + str(i), '
+            self.loss_names.extend(['G_GAN_' + str(i), 'G_L1_' + str(i), 'G_VGG_'+ str(i),
+                                    'D_real_' + str(i), 'D_fake_' + str(i)])
             self.visual_names.extend(['fake_B_' + str(i), 'real_B_' + str(i)])
 
         # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
@@ -52,41 +59,48 @@ class SDGModel(BaseModel):
             self.model_names.extend(['G_' + str(i)])
 
         # define networks (both generator and discriminator)
+        if isinstance(opt.net_g, str):
+            self.opt.net_g = [self.opt.net_g] * self.mod_gen_no
+        if isinstance(opt.net_gs, str):
+            self.opt.net_gs = [self.opt.net_gs]*self.mod_gen_no
         self.netG = [None for _ in range(self.mod_gen_no)]
+        self.netGS = [None for _ in range(self.mod_gen_no)]
         for i in range(self.mod_gen_no):
-            self.netG[i] = networks.define_G(self.opt.input_nc * self.opt.input_no, self.opt.output_nc, self.opt.ngf, self.opt.net_g, self.opt.norm,
-                                             not self.opt.no_dropout, self.opt.init_type, self.opt.init_gain, self.
-            print('***************************************')
-            print(self.opt.input_nc, self.opt.output_nc, self.opt.ngf, self.opt.net_g, self.opt.norm,
-                  not self.opt.no_dropout, self.opt.init_type, self.opt.init_gain, self.opt.gpu_ids, self.opt.padding)
-            print('***************************************')
+            self.netG[i] = networks.define_G(self.opt.input_nc * self.opt.input_no, self.opt.output_nc, self.opt.ngf, self.opt.net_g[i], self.opt.norm,
+                                             not self.opt.no_dropout, self.opt.init_type, self.opt.init_gain, self.gpu_ids, self.opt.padding)
 
         if self.is_train: # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc
             self.netD = [None for _ in range(self.mod_gen_no)]
             for i in range(self.mod_gen_no):
                 self.netD[i] = networks.define_D(self.opt.input_nc * self.opt.input_no + self.opt.output_nc, self.opt.ndf, self.opt.net_d,
                                                  self.opt.n_layers_D, self.opt.norm, self.opt.init_type, self.opt.init_gain,
-                                                 self.
+                                                 self.gpu_ids)
 
         if self.is_train:
             # define loss functions
             self.criterionGAN_mod = networks.GANLoss(self.opt.gan_mode).to(self.device)
             self.criterionGAN_seg = networks.GANLoss(self.opt.gan_mode_s).to(self.device)
-
             self.criterionSmoothL1 = torch.nn.SmoothL1Loss()
-
             self.criterionVGG = networks.VGGLoss().to(self.device)
 
             # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
             params = []
             for i in range(len(self.netG)):
                 params += list(self.netG[i].parameters())
-
+            try:
+                self.optimizer_G = get_optimizer(opt.optimizer)(params, lr=opt.lr_g, betas=(opt.beta1, 0.999))
+            except:
+                print(f'betas are not used for optimizer torch.optim.{opt.optimizer} in generators')
+                self.optimizer_G = get_optimizer(opt.optimizer)(params, lr=opt.lr_g)
 
             params = []
             for i in range(len(self.netD)):
                 params += list(self.netD[i].parameters())
-
+            try:
+                self.optimizer_D = get_optimizer(opt.optimizer)(params, lr=opt.lr_d, betas=(opt.beta1, 0.999))
+            except:
+                print(f'betas are not used for optimizer torch.optim.{opt.optimizer} in discriminators')
+                self.optimizer_D = get_optimizer(opt.optimizer)(params, lr=opt.lr_d)
 
             self.optimizers.append(self.optimizer_G)
             self.optimizers.append(self.optimizer_D)
@@ -136,7 +150,6 @@ class SDGModel(BaseModel):
             self.loss_D_real.append(self.criterionGAN_mod(pred_real[i], True))
 
         # combine losses and calculate gradients
-        # self.loss_D = (self.loss_D_fake[0] + self.loss_D_real[0]) * 0.5 * self.loss_D_weights[0]
         self.loss_D = torch.tensor(0., device=self.device)
         for i in range(0, self.mod_gen_no):
             self.loss_D += (self.loss_D_fake[i] + self.loss_D_real[i]) * 0.5 * self.loss_D_weights[i]
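The discriminator objective above is the weighted sum over modalities of the usual 0.5 * (fake + real) GAN terms. A tiny numeric sketch of that aggregation (the loss values and weights are made up):

```python
# Made-up numbers, purely to show the weighted aggregation used above.
import torch

loss_D_fake = [torch.tensor(0.70), torch.tensor(0.65), torch.tensor(0.72)]
loss_D_real = [torch.tensor(0.60), torch.tensor(0.58), torch.tensor(0.66)]
loss_D_weights = [0.4, 0.3, 0.3]

loss_D = torch.tensor(0.)
for fake, real, w in zip(loss_D_fake, loss_D_real, loss_D_weights):
    loss_D += (fake + real) * 0.5 * w
print(loss_D)   # tensor(0.6515), up to float rounding
```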
@@ -159,15 +172,14 @@ class SDGModel(BaseModel):
         for i in range(self.mod_gen_no):
             self.loss_G_L1.append(self.criterionSmoothL1(self.fake_B[i], self.real_B[i]) * self.opt.lambda_L1)
 
-
-
-
+        self.loss_G_VGG = []
+        for i in range(self.mod_gen_no):
+            self.loss_G_VGG.append(self.criterionVGG(self.fake_B[i], self.real_B[i]) * self.opt.lambda_feat)
+
 
-        # self.loss_G = (self.loss_G_GAN[0] + self.loss_G_L1[0]) * self.loss_G_weights[0]
         self.loss_G = torch.tensor(0., device=self.device)
         for i in range(0, self.mod_gen_no):
-            self.loss_G += (self.loss_G_GAN[i] + self.loss_G_L1[i]) * self.loss_G_weights[i]
-            # self.loss_G += (self.loss_G_GAN[i] + self.loss_G_L1[i] + self.loss_G_VGG[i]) * self.loss_G_weights[i]
+            self.loss_G += (self.loss_G_GAN[i] + self.loss_G_L1[i] + self.loss_G_VGG[i]) * self.loss_G_weights[i]
         self.loss_G.backward()
 
     def optimize_parameters(self):
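The generator objective now adds a `criterionVGG` perceptual term, scaled by `opt.lambda_feat`, to the GAN and SmoothL1 terms for every modality. `networks.VGGLoss` itself is outside this diff; the sketch below is a typical VGG19 feature-matching loss of the kind it likely wraps (the layer split and weights follow the common pix2pixHD recipe and are illustrative, not confirmed from the deepliif source).

```python
# Illustrative only -- not the deepliif networks.VGGLoss implementation.
# A standard VGG19 feature-matching (perceptual) loss in the pix2pixHD style.
import torch
import torch.nn as nn
from torchvision import models

class PerceptualVGGLoss(nn.Module):
    def __init__(self):
        super().__init__()
        vgg = models.vgg19(weights=models.VGG19_Weights.DEFAULT).features.eval()
        # split VGG19 features into blocks whose activations are compared
        self.slices = nn.ModuleList([vgg[:2], vgg[2:7], vgg[7:12], vgg[12:21], vgg[21:30]])
        self.weights = [1 / 32, 1 / 16, 1 / 8, 1 / 4, 1.0]
        self.l1 = nn.L1Loss()
        for p in self.parameters():
            p.requires_grad = False

    def forward(self, fake, real):
        loss = torch.tensor(0., device=fake.device)
        x, y = fake, real
        for w, block in zip(self.weights, self.slices):
            x, y = block(x), block(y)
            loss = loss + w * self.l1(x, y.detach())
        return loss

# e.g. loss_G_VGG_i = PerceptualVGGLoss()(fake_rgb, real_rgb) * lambda_feat  (placeholder names)
```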
@@ -187,3 +199,22 @@ class SDGModel(BaseModel):
         self.optimizer_G.zero_grad() # set G's gradients to zero
         self.backward_G() # calculate graidents for G
         self.optimizer_G.step() # udpate G's weights
+
+    def calculate_losses(self):
+        """
+        Calculate losses but do not optimize parameters. Used in validation loss calculation during training.
+        """
+        self.forward() # compute fake images: G(A)
+        # update D
+        for i in range(self.mod_gen_no):
+            self.set_requires_grad(self.netD[i], True) # enable backprop for D1
+
+        self.optimizer_D.zero_grad() # set D's gradients to zero
+        self.backward_D() # calculate gradients for D
+
+        # update G
+        for i in range(self.mod_gen_no):
+            self.set_requires_grad(self.netD[i], False)
+
+        self.optimizer_G.zero_grad() # set G's gradients to zero
+        self.backward_G() # calculate graidents for G