deepliif 1.1.11__py3-none-any.whl → 1.1.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two package versions as they appear in their respective public registries.
deepliif/data/__init__.py CHANGED
@@ -55,28 +55,28 @@ def get_option_setter(dataset_name):
55
55
  return dataset_class.modify_commandline_options
56
56
 
57
57
 
58
- def create_dataset(opt):
58
+ def create_dataset(opt, phase=None, batch_size=None):
59
59
  """Create a dataset given the option.
60
60
 
61
61
  This function wraps the class CustomDatasetDataLoader.
62
62
  This is the main interface between this package and 'train.py'/'test.py'
63
63
  """
64
- return CustomDatasetDataLoader(opt)
64
+ return CustomDatasetDataLoader(opt, phase=phase if phase else opt.phase, batch_size=batch_size if batch_size else opt.batch_size)
65
65
 
66
66
 
67
67
  class CustomDatasetDataLoader(object):
68
68
  """Wrapper class of Dataset class that performs multi-threaded data loading"""
69
69
 
70
- def __init__(self, opt):
70
+ def __init__(self, opt, phase=None, batch_size=None):
71
71
  """Initialize this class
72
72
 
73
73
  Step 1: create a dataset instance given the name [dataset_mode]
74
74
  Step 2: create a multi-threaded data loader.
75
75
  """
76
- self.batch_size = opt.batch_size
76
+ self.batch_size = batch_size if batch_size else opt.batch_size
77
77
  self.max_dataset_size = opt.max_dataset_size
78
78
  dataset_class = find_dataset_using_name(opt.dataset_mode)
79
- self.dataset = dataset_class(opt)
79
+ self.dataset = dataset_class(opt, phase=phase if phase else opt.phase)
80
80
  print("dataset [%s] was created" % type(self.dataset).__name__)
81
81
 
82
82
  sampler = None
@@ -95,7 +95,7 @@ class CustomDatasetDataLoader(object):
95
95
  self.dataloader = torch.utils.data.DataLoader(
96
96
  self.dataset,
97
97
  sampler=sampler,
98
- batch_size=opt.batch_size,
98
+ batch_size=batch_size,
99
99
  shuffle=not opt.serial_batches if sampler is None else False,
100
100
  num_workers=int(opt.num_threads)
101
101
  )
@@ -106,7 +106,7 @@ class CustomDatasetDataLoader(object):
106
106
  self.dataloader = torch.utils.data.DataLoader(
107
107
  self.dataset,
108
108
  sampler=sampler,
109
- batch_size=opt.batch_size,
109
+ batch_size=batch_size,
110
110
  shuffle=not opt.serial_batches if sampler is None else False,
111
111
  num_workers=int(opt.num_threads),
112
112
  worker_init_fn=seed_worker,
@@ -11,7 +11,7 @@ class AlignedDataset(BaseDataset):
11
11
  During test time, you need to prepare a directory '/path/to/data/test'.
12
12
  """
13
13
 
14
- def __init__(self, opt):
14
+ def __init__(self, opt, phase='train'):
15
15
  """Initialize this dataset class.
16
16
 
17
17
  Parameters:
@@ -19,7 +19,7 @@ class AlignedDataset(BaseDataset):
19
19
  """
20
20
  BaseDataset.__init__(self, opt.dataroot)
21
21
  self.preprocess = opt.preprocess
22
- self.dir_AB = os.path.join(opt.dataroot, opt.phase) # get the image directory
22
+ self.dir_AB = os.path.join(opt.dataroot, phase) # get the image directory
23
23
  self.AB_paths = sorted(make_dataset(self.dir_AB, opt.max_dataset_size)) # get image paths
24
24
  assert(opt.load_size >= opt.crop_size) # crop_size should be smaller than the size of loaded image
25
25
  self.input_nc = opt.output_nc if opt.direction == 'BtoA' else opt.input_nc
@@ -95,7 +95,6 @@ class AlignedDataset(BaseDataset):
95
95
  A = AB.crop((w2 * i, 0, w2 * (i+1), h))
96
96
  A = A_transform(A)
97
97
  A_Array.append(A)
98
-
99
98
  for i in range(self.input_no, self.input_no + self.modalities_no + 1):
100
99
  B = AB.crop((w2 * i, 0, w2 * (i + 1), h))
101
100
  B = B_transform(B)
@@ -1,5 +1,5 @@
1
1
  import os.path
2
- from deepliif.data.base_dataset import BaseDataset, get_transform
2
+ from deepliif.data.base_dataset import BaseDataset, get_params, get_transform
3
3
  from deepliif.data.image_folder import make_dataset
4
4
  from PIL import Image
5
5
  import random
@@ -16,25 +16,39 @@ class UnalignedDataset(BaseDataset):
16
16
  '/path/to/data/testA' and '/path/to/data/testB' during test time.
17
17
  """
18
18
 
19
- def __init__(self, opt):
19
+ def __init__(self, opt, phase='train'):
20
20
  """Initialize this dataset class.
21
21
 
22
22
  Parameters:
23
23
  opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
24
24
  """
25
25
  BaseDataset.__init__(self, opt)
26
- self.dir_A = os.path.join(opt.dataroot, opt.phase + 'A') # create a path '/path/to/data/trainA'
27
- self.dir_B = os.path.join(opt.dataroot, opt.phase + 'B') # create a path '/path/to/data/trainB'
26
+ self.opt = opt
27
+ self.input_nc = opt.output_nc if opt.direction == 'BtoA' else opt.input_nc
28
+ self.output_nc = opt.input_nc if opt.direction == 'BtoA' else opt.output_nc
29
+ self.preprocess = opt.preprocess
30
+ self.no_flip = opt.no_flip
31
+ self.modalities_no = opt.modalities_no
32
+ self.seg_no = opt.seg_no
33
+ self.input_no = opt.input_no
34
+ self.seg_gen = opt.seg_gen
35
+ self.load_size = opt.load_size
36
+ self.crop_size = opt.crop_size
37
+ self.model = opt.model
38
+
39
+ self.dir_A = os.path.join(opt.dataroot, phase + 'A') # create a path '/path/to/data/trainA'
40
+ # trainB1/trainB2/trainB3... are organized as elements of DATASET B which is a list
41
+ self.dirs_B = [os.path.join(opt.dataroot, phase + f'B{i}') for i in range(1,self.modalities_no+1)] # create a list of paths ['/path/to/data/trainB',...]
28
42
 
29
43
  self.A_paths = sorted(make_dataset(self.dir_A, opt.max_dataset_size)) # load images from '/path/to/data/trainA'
30
- self.B_paths = sorted(make_dataset(self.dir_B, opt.max_dataset_size)) # load images from '/path/to/data/trainB'
44
+ self.B_paths = [sorted(make_dataset(dir_B, opt.max_dataset_size)) for dir_B in self.dirs_B] # load images from '/path/to/data/trainB', '/path/to/data/trainC', ...
31
45
  self.A_size = len(self.A_paths) # get the size of dataset A
32
- self.B_size = len(self.B_paths) # get the size of dataset B
46
+ self.B_sizes = [len(B_paths) for B_paths in self.B_paths] # get the size of dataset B1, B2, B3, ...
33
47
  btoA = self.opt.direction == 'BtoA'
34
48
  input_nc = self.opt.output_nc if btoA else self.opt.input_nc # get the number of channels of input image
35
49
  output_nc = self.opt.input_nc if btoA else self.opt.output_nc # get the number of channels of output image
36
- self.transform_A = get_transform(self.opt, grayscale=(input_nc == 1))
37
- self.transform_B = get_transform(self.opt, grayscale=(output_nc == 1))
50
+
51
+
38
52
 
39
53
  def __getitem__(self, index):
40
54
  """Return a data point and its metadata information.
@@ -50,22 +64,27 @@ class UnalignedDataset(BaseDataset):
50
64
  """
51
65
  A_path = self.A_paths[index % self.A_size] # make sure index is within then range
52
66
  if self.opt.serial_batches: # make sure index is within then range
53
- index_B = index % self.B_size
67
+ indice_B = [index % B_size for B_size in self.B_sizes]
54
68
  else: # randomize the index for domain B to avoid fixed pairs.
55
- index_B = random.randint(0, self.B_size - 1)
56
- B_path = self.B_paths[index_B]
69
+ indice_B = [random.randint(0, B_size - 1) for B_size in self.B_sizes]
70
+ B_paths = [B_paths[index_B] for B_paths, index_B in zip(self.B_paths, indice_B)]
71
+
57
72
  A_img = Image.open(A_path).convert('RGB')
58
- B_img = Image.open(B_path).convert('RGB')
73
+ B_imgs = [Image.open(B_path).convert('RGB') for B_path in B_paths]
74
+
59
75
  # apply image transformation
60
- A = self.transform_A(A_img)
61
- B = self.transform_B(B_img)
62
-
63
- return {'A': A, 'B': B, 'A_paths': A_path, 'B_paths': B_path}
76
+ transform_params = get_params(self.preprocess, self.load_size, self.crop_size, A_img.size)
77
+ A_transform = get_transform(self.preprocess, self.load_size, self.crop_size, self.no_flip, transform_params, grayscale=(self.input_nc == 1))
78
+ B_transform = get_transform(self.preprocess, self.load_size, self.crop_size, self.no_flip, transform_params, grayscale=(self.output_nc == 1))
79
+ A = A_transform(A_img)
80
+ Bs = [B_transform(B_img) for B_img in B_imgs]
81
+
82
+ return {'A': A, 'Bs': Bs, 'A_paths': A_path, 'B_paths': B_paths}
64
83
 
65
84
  def __len__(self):
66
85
  """Return the total number of images in the dataset.
67
86
 
68
- As we have two datasets with potentially different number of images,
69
- we take a maximum of
87
+
88
+ The effective size of this dataset will be the size of datasetA through which we loop and grab a random/matching image B1/B2/B3... for
70
89
  """
71
- return max(self.A_size, self.B_size)
90
+ return self.A_size #max(self.A_size, self.B_size)
@@ -0,0 +1,282 @@
1
+ import torch
2
+ from packaging import version
3
+ from torch import nn
4
+ import itertools
5
+ from ..util.image_pool import ImagePool
6
+ from .base_model import BaseModel
7
+ from . import networks
8
+ from .networks import get_optimizer
9
+
10
+
11
+ class CycleGANModel(BaseModel):
12
+ """
13
+ This class implements the CycleGAN model, for learning image-to-image translation without paired data.
14
+
15
+ The model training requires '--dataset_mode unaligned' dataset.
16
+ By default, it uses a '--netG resnet_9blocks' ResNet generator,
17
+ a '--netD basic' discriminator (PatchGAN introduced by pix2pix),
18
+ and a least-square GANs objective ('--gan_mode lsgan').
19
+
20
+ CycleGAN paper: https://arxiv.org/pdf/1703.10593.pdf
21
+ """
22
+
23
+ def __init__(self, opt):
24
+ """Initialize the CycleGAN class.
25
+
26
+ Parameters:
27
+ opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
28
+ """
29
+ BaseModel.__init__(self, opt)
30
+ self.mod_gen_no = self.opt.modalities_no
31
+ if not hasattr(self.opt,'upsample'):
32
+ self.opt.upsample = 'convtranspose'
33
+ if not hasattr(self.opt,'label_smoothing'):
34
+ self.opt.label_smoothing = 0
35
+
36
+ use_spectral_norm = self.opt.norm == 'spectral'
37
+
38
+ self.loss_G_weights = opt.loss_G_weights
39
+ self.loss_D_weights = opt.loss_D_weights
40
+ self.loss_cyc_weights = [1 / self.mod_gen_no] * self.mod_gen_no
41
+
42
+ self.opt.lambda_identity = 0 # do not use lambda identity for the first trial
43
+ # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
44
+ self.loss_names = ['D_A', 'G_A', 'cycle_A', 'idt_A', 'D_B', 'G_B', 'cycle_B', 'idt_B']
45
+ # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
46
+ l_suffix = range(1, self.opt.modalities_no + 1)
47
+ visual_names_A = [f'real_As_{i}' for i in l_suffix] + [f'fake_Bs_{i}' for i in l_suffix] + [f'rec_As_{i}' for i in l_suffix]
48
+ visual_names_B = [f'real_Bs_{i}' for i in l_suffix] + [f'fake_As_{i}' for i in l_suffix] + [f'rec_Bs_{i}' for i in l_suffix]
49
+
50
+ # if self.is_train and self.opt.lambda_identity > 0.0: # if identity loss is used, we also visualize idt_B=G_A(B) ad idt_A=G_A(B)
51
+ # visual_names_A.append('idt_B')
52
+ # visual_names_B.append('idt_A')
53
+
54
+ self.visual_names = visual_names_A + visual_names_B # combine visualizations for A and B
55
+ # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>.
56
+ if self.is_train:
57
+ self.model_names = [f'GA_{i}' for i in l_suffix] + [f'GB_{i}' for i in l_suffix] + [f'DA_{i}' for i in l_suffix] + [f'DB_{i}' for i in l_suffix]
58
+ else: # during test time, only load Gs
59
+ if self.opt.BtoA:
60
+ self.model_names = [f'GB_{i}' for i in l_suffix]
61
+ else:
62
+ self.model_names = [f'GA_{i}' for i in l_suffix]
63
+
64
+ # define networks (both Generators and discriminators)
65
+ # The naming is different from those used in the paper.
66
+ # Code (vs. paper): G_A (G), G_B (F), D_A (D_Y), D_B (D_X)
67
+ if isinstance(opt.net_g, str):
68
+ self.opt.net_g = [self.opt.net_g] * self.mod_gen_no
69
+
70
+ if version.parse(torch.__version__) < version.parse('1.11.0'):
71
+ self.netGA = list()
72
+ self.netGB = list()
73
+ else:
74
+ self.netGA = nn.ModuleList()
75
+ self.netGB = nn.ModuleList()
76
+
77
+ for i in range(self.mod_gen_no):
78
+ if self.is_train or not self.opt.BtoA:
79
+ self.netGA.append(networks.define_G(self.opt.input_nc, self.opt.output_nc, self.opt.ngf, self.opt.net_g[i], self.opt.norm,
80
+ not self.opt.no_dropout, self.opt.init_type, self.opt.init_gain, self.gpu_ids, self.opt.padding,
81
+ upsample=self.opt.upsample))
82
+ if self.is_train or self.opt.BtoA:
83
+ self.netGB.append(networks.define_G(self.opt.output_nc, self.opt.input_nc, self.opt.ngf, self.opt.net_g[i], self.opt.norm,
84
+ not self.opt.no_dropout, self.opt.init_type, self.opt.init_gain, self.gpu_ids, self.opt.padding,
85
+ upsample=self.opt.upsample))
86
+
87
+ if self.is_train: # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc
88
+ if version.parse(torch.__version__) < version.parse('1.11.0'):
89
+ self.netDA = list()
90
+ self.netDB = list()
91
+ else:
92
+ self.netDA = nn.ModuleList()
93
+ self.netDB = nn.ModuleList()
94
+
95
+ for i in range(self.mod_gen_no):
96
+ self.netDA.append(networks.define_D(self.opt.output_nc, self.opt.ndf, self.opt.net_d,
97
+ self.opt.n_layers_D, self.opt.norm, self.opt.init_type, self.opt.init_gain,
98
+ self.gpu_ids))
99
+ self.netDB.append(networks.define_D(self.opt.input_nc, self.opt.ndf, self.opt.net_d,
100
+ self.opt.n_layers_D, self.opt.norm, self.opt.init_type, self.opt.init_gain,
101
+ self.gpu_ids))
102
+
103
+
104
+
105
+ if self.is_train:
106
+ if opt.lambda_identity > 0.0: # only works when input and output images have the same number of channels
107
+ assert(opt.input_nc == opt.output_nc)
108
+ self.fake_A_pools = [ImagePool(opt.pool_size) for _ in range(self.opt.modalities_no)] # create image buffer to store previously generated images
109
+ self.fake_B_pools = [ImagePool(opt.pool_size) for _ in range(self.opt.modalities_no)] # create image buffer to store previously generated images
110
+
111
+ # define loss functions
112
+ # label smoothing currently only applies to discriminator losses & generatoe of lsgan/vanilla
113
+ self.criterionGAN = networks.GANLoss(opt.gan_mode, label_smoothing=self.opt.label_smoothing).to(self.device) # define GAN loss.
114
+ self.criterionCycle = torch.nn.L1Loss()
115
+ self.criterionIdt = torch.nn.L1Loss()
116
+
117
+ self.criterionSmoothL1 = torch.nn.SmoothL1Loss()
118
+ self.criterionVGG = networks.VGGLoss().to(self.device)
119
+
120
+
121
+ # initialize optimizers
122
+ params = []
123
+ for i in range(len(self.netGA)):
124
+ params += list(self.netGA[i].parameters())
125
+ for i in range(len(self.netGB)):
126
+ params += list(self.netGB[i].parameters())
127
+ try:
128
+ self.optimizer_G = get_optimizer(opt.optimizer)(params, lr=opt.lr_g, betas=(opt.beta1, 0.999))
129
+ except:
130
+ print(f'betas are not used for optimizer torch.optim.{opt.optimizer} in generators')
131
+ self.optimizer_G = get_optimizer(opt.optimizer)(params, lr=opt.lr_g)
132
+
133
+ params = []
134
+ for i in range(len(self.netDA)):
135
+ params += list(self.netDA[i].parameters())
136
+ for i in range(len(self.netDB)):
137
+ params += list(self.netDB[i].parameters())
138
+
139
+ # a smaller learning rate for discriminators to postpone training failure due to discriminators quickly become too strong
140
+ try:
141
+ self.optimizer_D = get_optimizer(opt.optimizer)(params, lr=opt.lr_d, betas=(opt.beta1, 0.999))
142
+ except:
143
+ print(f'betas are not used for optimizer torch.optim.{opt.optimizer} in generators')
144
+ self.optimizer_D = get_optimizer(opt.optimizer)(params, lr=opt.lr_d)
145
+
146
+
147
+ self.optimizers.append(self.optimizer_G)
148
+ self.optimizers.append(self.optimizer_D)
149
+
150
+ def set_input(self, input):
151
+ """Unpack input data from the dataloader and perform necessary pre-processing steps.
152
+
153
+ Parameters:
154
+ input (dict): include the data itself and its metadata information.
155
+
156
+ The option 'direction' can be used to swap domain A and domain B.
157
+ """
158
+ self.real_As = [input['A'].to(self.device) for _ in range(self.opt.modalities_no)]
159
+ self.real_Bs = [x.to(self.device) for x in input['Bs']]
160
+ self.image_paths = input['A_paths']
161
+
162
+ def forward(self):
163
+ """
164
+ Run forward pass; called by both functions <optimize_parameters> and <test>.
165
+ During inference, some output list could be empty. For example, if only netGAs are loaded,
166
+ there will not be valid elements in self.rec_As and self.fake_As.
167
+ """
168
+ self.fake_Bs = [netGA(real_A) for netGA, real_A in zip(self.netGA, self.real_As)] # G_A(A)
169
+ self.rec_As = [netGB(fake_B) for netGB, fake_B in zip(self.netGB, self.fake_Bs)] # G_B(G_A(A))
170
+
171
+ self.fake_As = [netGB(real_B) for netGB, real_B in zip(self.netGB, self.real_Bs)] # G_B(B)
172
+ self.rec_Bs = [netGA(fake_A) for netGA, fake_A in zip(self.netGA, self.fake_As)] # G_A(G_B(B))
173
+
174
+
175
+ def backward_D_basic(self, netD, real, fake, scale_factor=1):
176
+ """Calculate GAN loss for the discriminator
177
+
178
+ Parameters:
179
+ netD (network) -- the discriminator D
180
+ real (tensor array) -- real images
181
+ fake (tensor array) -- images generated by a generator
182
+
183
+ Return the discriminator loss.
184
+ We also call loss_D.backward() to calculate the gradients.
185
+ """
186
+ # Real
187
+ pred_real = netD(real)
188
+ loss_D_real = self.criterionGAN(pred_real, True)
189
+ # Fake
190
+ pred_fake = netD(fake.detach())
191
+ loss_D_fake = self.criterionGAN(pred_fake, False)
192
+ # Combined loss and calculate gradients
193
+ loss_D = (loss_D_real + loss_D_fake) * 0.5 * scale_factor
194
+ loss_D.backward()
195
+ return loss_D
196
+
197
+ def backward_D_A(self):
198
+ """Calculate GAN loss for discriminator D_A"""
199
+ fake_Bs = [fake_B_pool.query(fake_B) for fake_B_pool, fake_B in zip(self.fake_B_pools, self.fake_Bs)]
200
+ real_Bs = self.real_Bs
201
+
202
+ self.loss_D_A = 0
203
+ for i, (netDA, real_B, fake_B) in enumerate(zip(self.netDA, real_Bs, fake_Bs)):
204
+ self.loss_D_A += self.backward_D_basic(netDA, real_B, fake_B, scale_factor=self.loss_D_weights[i])
205
+ #self.loss_D_A.backward()
206
+
207
+ def backward_D_B(self):
208
+ """Calculate GAN loss for discriminator D_B"""
209
+ fake_As = [fake_A_pool.query(fake_A) for fake_A_pool, fake_A in zip(self.fake_A_pools, self.fake_As)]
210
+ real_As = self.real_As
211
+
212
+ self.loss_D_B = 0
213
+ for i, (netDB, real_A, fake_A) in enumerate(zip(self.netDB, real_As, fake_As)):
214
+ self.loss_D_B += self.backward_D_basic(netDB, real_A, fake_A, scale_factor=self.loss_D_weights[i])
215
+ #self.loss_D_B.backward()
216
+
217
+ def backward_G(self):
218
+ """Calculate the loss for generators G_A and G_B"""
219
+ # default lambda values from cyclegan implementation:
220
+ # https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/c3268edd50ec37a81600c9b981841f48929671b8/models/cycle_gan_model.py#L41
221
+ lambda_idt = 0#self.opt.lambda_identity # identity loss is used to preserve color consistency between input and output images, which we do not want to encourage
222
+ lambda_A = 10#self.opt.lambda_A
223
+ lambda_B = 10#self.opt.lambda_B
224
+ # Identity loss
225
+ if lambda_idt > 0:
226
+ # G_A should be identity if real_B is fed: ||G_A(B) - B||
227
+ self.idt_A = self.netG_A(self.real_B)
228
+ self.loss_idt_A = self.criterionIdt(self.idt_A, self.real_B) * lambda_B * lambda_idt
229
+ # G_B should be identity if real_A is fed: ||G_B(A) - A||
230
+ self.idt_B = self.netG_B(self.real_A)
231
+ self.loss_idt_B = self.criterionIdt(self.idt_B, self.real_A) * lambda_A * lambda_idt
232
+ else:
233
+ self.loss_idt_A = 0
234
+ self.loss_idt_B = 0
235
+
236
+ # GAN loss D_A(G_A(A))
237
+ self.loss_G_A = 0
238
+ for i, (netDA, fake_B, real_B) in enumerate(zip(self.netDA, self.fake_Bs, self.real_Bs)):
239
+ self.loss_G_A += self.criterionGAN(netDA(fake_B), True) * self.loss_G_weights[i]
240
+ self.loss_G_A += self.criterionVGG(fake_B, real_B) * self.loss_G_weights[i]
241
+
242
+ # GAN loss D_B(G_B(B))
243
+ self.loss_G_B = 0
244
+ for i, (netDB, fake_A, real_A) in enumerate(zip(self.netDB, self.fake_As, self.real_As)):
245
+ self.loss_G_B += self.criterionGAN(netDB(fake_A), True) * self.loss_G_weights[i]
246
+ self.loss_G_B += self.criterionVGG(fake_A, real_A) * self.loss_G_weights[i]
247
+
248
+ # Forward cycle loss || G_B(G_A(A)) - A||
249
+ self.loss_cycle_A = 0
250
+ for i, (rec_A, real_A) in enumerate(zip(self.rec_As, self.real_As)):
251
+ self.loss_cycle_A += self.criterionCycle(rec_A, real_A) * lambda_A * self.loss_cyc_weights[i]
252
+ # Backward cycle loss || G_A(G_B(B)) - B||
253
+ self.loss_cycle_B = 0
254
+ for i, (rec_B, real_B) in enumerate(zip(self.rec_Bs, self.real_Bs)):
255
+ self.loss_cycle_B += self.criterionCycle(rec_B, real_B) * lambda_B * self.loss_cyc_weights[i]
256
+
257
+ # VGG loss
258
+ # self.loss_G_VGG = self.criterionVGG(self.fake_B_1, self.real_B_1) * self.opt.lambda_feat
259
+
260
+ # smooth L1
261
+ # self.loss_G_A_L1 = self.criterionSmoothL1(self.fake_B_1, self.real_B_1) * self.opt.lambda_L1
262
+
263
+ # combined loss and calculate gradients
264
+ self.loss_G = self.loss_G_A + self.loss_G_B + self.loss_cycle_A + self.loss_cycle_B + self.loss_idt_A + self.loss_idt_B
265
+ self.loss_G.backward()
266
+
267
+ def optimize_parameters(self):
268
+ """Calculate losses, gradients, and update network weights; called in every training iteration"""
269
+ # forward
270
+ self.forward() # compute fake images and reconstruction images.
271
+ # G_A and G_B
272
+ self.set_requires_grad(self.netDA + self.netDB, False) # Ds require no gradients when optimizing Gs
273
+ self.optimizer_G.zero_grad() # set G_A and G_B's gradients to zero
274
+ self.backward_G() # calculate gradients for G_A and G_B
275
+ self.optimizer_G.step() # update G_A and G_B's weights
276
+
277
+ # D_A and D_B
278
+ self.set_requires_grad(self.netDA + self.netDB, True)
279
+ self.optimizer_D.zero_grad() # set D_A and D_B's gradients to zero
280
+ self.backward_D_A() # calculate gradients for D_A
281
+ self.backward_D_B() # calculate graidents for D_B
282
+ self.optimizer_D.step() # update D_A and D_B's weights
@@ -1,6 +1,7 @@
1
1
  import torch
2
2
  from .base_model import BaseModel
3
3
  from . import networks
4
+ from .networks import get_optimizer
4
5
 
5
6
 
6
7
  class DeepLIIFExtModel(BaseModel):
@@ -18,19 +19,16 @@ class DeepLIIFExtModel(BaseModel):
18
19
  # self.seg_gen_no = self.opt.modalities_no + 1
19
20
 
20
21
  # weights of the modalities in generating segmentation mask
21
- self.seg_weights = [0, 0, 0]
22
- if opt.seg_gen:
23
- self.seg_weights = [0.3] * self.mod_gen_no
24
- self.seg_weights[1] = 0.4
22
+ self.seg_weights = opt.seg_weights
25
23
 
26
24
  # self.seg_weights = opt.seg_weights
27
25
  # assert len(self.seg_weights) == self.seg_gen_no, 'The number of the segmentation weights (seg_weights) is not equal to the number of target images (modalities_no)!'
28
26
  # print(self.seg_weights)
29
27
  # loss weights in calculating the final loss
30
- self.loss_G_weights = [1 / self.mod_gen_no] * self.mod_gen_no
28
+ self.loss_G_weights = opt.loss_G_weights
31
29
  self.loss_GS_weights = [1 / self.mod_gen_no] * self.mod_gen_no
32
30
 
33
- self.loss_D_weights = [1 / self.mod_gen_no] * self.mod_gen_no
31
+ self.loss_D_weights = opt.loss_D_weights
34
32
  self.loss_DS_weights = [1 / self.mod_gen_no] * self.mod_gen_no
35
33
 
36
34
  # self.gpu_ids is a possibly modifed one for model initialization
@@ -72,22 +70,19 @@ class DeepLIIFExtModel(BaseModel):
72
70
  self.model_names.extend(['GS_' + str(i)])
73
71
 
74
72
  # define networks (both generator and discriminator)
73
+ if isinstance(opt.net_g, str):
74
+ self.opt.net_g = [self.opt.net_g] * self.mod_gen_no
75
+ if isinstance(opt.net_gs, str):
76
+ self.opt.net_gs = [self.opt.net_gs]*self.mod_gen_no
75
77
  self.netG = [None for _ in range(self.mod_gen_no)]
76
78
  self.netGS = [None for _ in range(self.mod_gen_no)]
77
79
  for i in range(self.mod_gen_no):
78
- self.netG[i] = networks.define_G(self.opt.input_nc, self.opt.output_nc, self.opt.ngf, self.opt.net_g, self.opt.norm,
80
+ self.netG[i] = networks.define_G(self.opt.input_nc, self.opt.output_nc, self.opt.ngf, self.opt.net_g[i], self.opt.norm,
79
81
  not self.opt.no_dropout, self.opt.init_type, self.opt.init_gain, self.gpu_ids, self.opt.padding)
80
- print('***************************************')
81
- print(self.opt.input_nc, self.opt.output_nc, self.opt.ngf, self.opt.net_g, self.opt.norm,
82
- not self.opt.no_dropout, self.opt.init_type, self.opt.init_gain, self.gpu_ids, self.opt.padding)
83
- print('***************************************')
82
+
84
83
  for i in range(self.mod_gen_no):
85
84
  if self.opt.seg_gen:
86
- # if i == 0:
87
- # self.netGS[i] = networks.define_G(self.opt.input_nc, self.opt.output_nc, self.opt.ngf, self.opt.net_gs, self.opt.norm,
88
- # not self.opt.no_dropout, self.opt.init_type, self.opt.init_gain, self.gpu_ids)
89
- # else:
90
- self.netGS[i] = networks.define_G(self.opt.input_nc * 3, self.opt.output_nc, self.opt.ngf, self.opt.net_gs, self.opt.norm,
85
+ self.netGS[i] = networks.define_G(self.opt.input_nc * 3, self.opt.output_nc, self.opt.ngf, self.opt.net_gs[i], self.opt.norm,
91
86
  not self.opt.no_dropout, self.opt.init_type, self.opt.init_gain, self.gpu_ids)
92
87
 
93
88
  if self.is_train: # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc
@@ -99,11 +94,6 @@ class DeepLIIFExtModel(BaseModel):
99
94
  self.gpu_ids)
100
95
  for i in range(self.mod_gen_no):
101
96
  if self.opt.seg_gen:
102
- # if i == 0:
103
- # self.netDS[i] = networks.define_D(self.opt.input_nc + self.opt.output_nc, self.opt.ndf, self.opt.net_ds,
104
- # self.opt.n_layers_D, self.opt.norm, self.opt.init_type, self.opt.init_gain,
105
- # self.gpu_ids)
106
- # else:
107
97
  self.netDS[i] = networks.define_D(self.opt.input_nc * 3 + self.opt.output_nc, self.opt.ndf, self.opt.net_ds,
108
98
  self.opt.n_layers_D, self.opt.norm, self.opt.init_type, self.opt.init_gain,
109
99
  self.gpu_ids)
@@ -113,9 +103,7 @@ class DeepLIIFExtModel(BaseModel):
113
103
  # define loss functions
114
104
  self.criterionGAN_mod = networks.GANLoss(self.opt.gan_mode).to(self.device)
115
105
  self.criterionGAN_seg = networks.GANLoss(self.opt.gan_mode_s).to(self.device)
116
-
117
106
  self.criterionSmoothL1 = torch.nn.SmoothL1Loss()
118
-
119
107
  self.criterionVGG = networks.VGGLoss().to(self.device)
120
108
 
121
109
  # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
@@ -125,7 +113,11 @@ class DeepLIIFExtModel(BaseModel):
125
113
  for i in range(len(self.netGS)):
126
114
  if self.netGS[i]:
127
115
  params += list(self.netGS[i].parameters())
128
- self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
116
+ try:
117
+ self.optimizer_G = get_optimizer(opt.optimizer)(params, lr=opt.lr_g, betas=(opt.beta1, 0.999))
118
+ except:
119
+ print(f'betas are not used for optimizer torch.optim.{opt.optimizer} in generators')
120
+ self.optimizer_G = get_optimizer(opt.optimizer)(params, lr=opt.lr_g)
129
121
 
130
122
  params = []
131
123
  for i in range(len(self.netD)):
@@ -133,7 +125,11 @@ class DeepLIIFExtModel(BaseModel):
133
125
  for i in range(len(self.netDS)):
134
126
  if self.netDS[i]:
135
127
  params += list(self.netDS[i].parameters())
136
- self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
128
+ try:
129
+ self.optimizer_D = get_optimizer(opt.optimizer)(params, lr=opt.lr_d, betas=(opt.beta1, 0.999))
130
+ except:
131
+ print(f'betas are not used for optimizer torch.optim.{opt.optimizer} in discriminators')
132
+ self.optimizer_D = get_optimizer(opt.optimizer)(params, lr=opt.lr_d)
137
133
 
138
134
  self.optimizers.append(self.optimizer_G)
139
135
  self.optimizers.append(self.optimizer_D)
@@ -295,3 +291,29 @@ class DeepLIIFExtModel(BaseModel):
295
291
  self.optimizer_G.zero_grad() # set G's gradients to zero
296
292
  self.backward_G() # calculate graidents for G
297
293
  self.optimizer_G.step() # udpate G's weights
294
+
295
+ def calculate_losses(self):
296
+ """
297
+ Calculate losses but do not optimize parameters. Used in validation loss calculation during training.
298
+ """
299
+ self.forward() # compute fake images: G(A)
300
+ # update D
301
+ for i in range(self.mod_gen_no):
302
+ self.set_requires_grad(self.netD[i], True) # enable backprop for D1
303
+ for i in range(self.mod_gen_no):
304
+ if self.netDS[i]:
305
+ self.set_requires_grad(self.netDS[i], True)
306
+
307
+ self.optimizer_D.zero_grad() # set D's gradients to zero
308
+ self.backward_D() # calculate gradients for D
309
+
310
+ # update G
311
+ for i in range(self.mod_gen_no):
312
+ self.set_requires_grad(self.netD[i], False)
313
+ for i in range(self.mod_gen_no):
314
+ if self.netDS[i]:
315
+ self.set_requires_grad(self.netDS[i], False)
316
+
317
+ self.optimizer_G.zero_grad() # set G's gradients to zero
318
+ self.backward_G() # calculate graidents for G
319
+