deepliif-1.1.11-py3-none-any.whl → deepliif-1.1.12-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,199 @@
+ # adapted from https://github.com/LeeJunHyun/Image_Segmentation/blob/master/network.py
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.nn import init
+
+ def init_weights(net, init_type='normal', gain=0.02):
+     def init_func(m):
+         classname = m.__class__.__name__
+         if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
+             if init_type == 'normal':
+                 init.normal_(m.weight.data, 0.0, gain)
+             elif init_type == 'xavier':
+                 init.xavier_normal_(m.weight.data, gain=gain)
+             elif init_type == 'kaiming':
+                 init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
+             elif init_type == 'orthogonal':
+                 init.orthogonal_(m.weight.data, gain=gain)
+             else:
+                 raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
+             if hasattr(m, 'bias') and m.bias is not None:
+                 init.constant_(m.bias.data, 0.0)
+         elif classname.find('BatchNorm2d') != -1:
+             init.normal_(m.weight.data, 1.0, gain)
+             init.constant_(m.bias.data, 0.0)
+
+     print('initialize network with %s' % init_type)
+     net.apply(init_func)
+
+ class conv_block(nn.Module):
+     def __init__(self,ch_in,ch_out,innermost=False,outermost=False):
+         super(conv_block,self).__init__()
+         if outermost:
+             self.conv = nn.Sequential(
+                 nn.Conv2d(ch_in, ch_out, kernel_size=4,stride=2,padding=1,bias=True),
+                 nn.LeakyReLU(0.2, True),
+             )
+         elif innermost:
+             self.conv = nn.Sequential(
+                 nn.Conv2d(ch_in, ch_out, kernel_size=4,stride=2,padding=1,bias=True),
+                 nn.ReLU(inplace=True),
+             )
+         else:
+             self.conv = nn.Sequential(
+                 nn.Conv2d(ch_in, ch_out, kernel_size=4,stride=2,padding=1,bias=True),
+                 nn.BatchNorm2d(ch_out),
+                 nn.LeakyReLU(0.2, True),
+             )
+
+     def forward(self,x):
+         x = self.conv(x)
+         return x
+
+ class up_conv(nn.Module):
+     def __init__(self,ch_in,ch_out,innermost=False,outermost=False):
+         super(up_conv,self).__init__()
+         use_bias=False
+         if outermost:
+             self.up = nn.Sequential(
+                 nn.ConvTranspose2d(ch_in * 2, ch_out,
+                                    kernel_size=4, stride=2,
+                                    padding=1),
+                 nn.Tanh())
+         elif innermost:
+             self.up = nn.Sequential(
+                 nn.ConvTranspose2d(ch_in, ch_out,
+                                    kernel_size=4, stride=2,
+                                    padding=1, bias=use_bias),
+                 nn.BatchNorm2d(ch_out),
+                 nn.ReLU(True))
+         else:
+             self.up = nn.Sequential(
+                 nn.ConvTranspose2d(ch_in * 2, ch_out,
+                                    kernel_size=4, stride=2,
+                                    padding=1, bias=use_bias),
+                 nn.BatchNorm2d(ch_out),
+                 nn.ReLU(True))
+
+
+     def forward(self,x):
+         x = self.up(x)
+         return x
+
+
+ class Attention_block(nn.Module):
+     def __init__(self,F_g,F_l,F_int):
+         super(Attention_block,self).__init__()
+         self.W_g = nn.Sequential(
+             nn.Conv2d(F_g, F_int, kernel_size=1,stride=1,padding=0,bias=True),
+             nn.BatchNorm2d(F_int)
+         )
+
+         self.W_x = nn.Sequential(
+             nn.Conv2d(F_l, F_int, kernel_size=1,stride=1,padding=0,bias=True),
+             nn.BatchNorm2d(F_int)
+         )
+
+         self.psi = nn.Sequential(
+             nn.Conv2d(F_int, 1, kernel_size=1,stride=1,padding=0,bias=True),
+             nn.BatchNorm2d(1),
+             nn.Sigmoid()
+         )
+
+         self.relu = nn.ReLU(inplace=True)
+
+     def forward(self,g,x):
+         g1 = self.W_g(g)
+         x1 = self.W_x(x)
+         psi = self.relu(g1+x1)
+         psi = self.psi(psi)
+
+         return x*psi
+
+
+
+ class AttU_Net(nn.Module):
+     def __init__(self,img_ch=3,output_ch=1):
+         super(AttU_Net,self).__init__()
+
+         self.Conv1 = conv_block(ch_in=img_ch,ch_out=64,outermost=True)
+         self.Conv2 = conv_block(ch_in=64,ch_out=128)
+         self.Conv3 = conv_block(ch_in=128,ch_out=256)
+         self.Conv4 = conv_block(ch_in=256,ch_out=512)
+         self.Conv5 = conv_block(ch_in=512,ch_out=512)
+         self.Conv6 = conv_block(ch_in=512,ch_out=512)
+         self.Conv7 = conv_block(ch_in=512,ch_out=512)
+         self.Conv8 = conv_block(ch_in=512,ch_out=512,innermost=True)
+         #self.Conv9 = conv_block(ch_in=512,ch_out=512,innermost=True)
+
+         self.Up8 = up_conv(ch_in=512,ch_out=512,innermost=True)
+         self.Att8 = Attention_block(F_g=512,F_l=512,F_int=512)
+
+         self.Up7 = up_conv(ch_in=512,ch_out=512)
+         self.Att7 = Attention_block(F_g=512,F_l=512,F_int=512)
+
+         self.Up6 = up_conv(ch_in=512,ch_out=512)
+         self.Att6 = Attention_block(F_g=512,F_l=512,F_int=512)
+
+         self.Up5 = up_conv(ch_in=512,ch_out=512)
+         self.Att5 = Attention_block(F_g=512,F_l=512,F_int=512)
+
+         self.Up4 = up_conv(ch_in=512,ch_out=256)
+         self.Att4 = Attention_block(F_g=256,F_l=256,F_int=128)
+
+         self.Up3 = up_conv(ch_in=256,ch_out=128)
+         self.Att3 = Attention_block(F_g=128,F_l=128,F_int=64)
+
+         self.Up2 = up_conv(ch_in=128,ch_out=64)
+         self.Att2 = Attention_block(F_g=64,F_l=64,F_int=32)
+
+         self.Up1 = up_conv(ch_in=64,ch_out=output_ch,outermost=True)
+
+     def forward(self,x):
+         # encoding path
+         x1 = self.Conv1(x)
+         x2 = self.Conv2(x1)
+         x3 = self.Conv3(x2)
+         x4 = self.Conv4(x3)
+         x5 = self.Conv5(x4)
+         x6 = self.Conv6(x5)
+         x7 = self.Conv7(x6)
+         x8 = self.Conv8(x7)
+         #x9 = self.Conv9(x8)
+
+         #d9 = self.Up
+         d8 = self.Up8(x8)
+         x7 = self.Att8(g=d8,x=x7)
+         d8 = torch.cat((x7,d8),dim=1)
+
+         d7 = self.Up7(d8)
+         x6 = self.Att7(g=d7,x=x6)
+         d7 = torch.cat((x6,d7),dim=1)
+
+         d6 = self.Up6(d7)
+         x5 = self.Att6(g=d6,x=x5)
+         d6 = torch.cat((x5,d6),dim=1)
+
+         d5 = self.Up5(d6)
+         x4 = self.Att5(g=d5,x=x4)
+         d5 = torch.cat((x4,d5),dim=1)
+
+
+         d4 = self.Up4(d5) # x4: [2, 512, 4, 4], d4: [2, 256, 4, 4]
+         x3 = self.Att4(g=d4,x=x3)
+         d4 = torch.cat((x3,d4),dim=1)
+
+         d3 = self.Up3(d4)
+         x2 = self.Att3(g=d3,x=x2)
+         d3 = torch.cat((x2,d3),dim=1)
+
+         d2 = self.Up2(d3)
+         x1 = self.Att2(g=d2,x=x1)
+         d2 = torch.cat((x1,d2),dim=1)
+
+         d1 = self.Up1(d2)
+
+         return d1
+
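
The new module defines an eight-level, attention-gated U-Net: the encoder halves the spatial size eight times, so input height and width must be divisible by 2**8 = 256, and the outermost up_conv ends in a Tanh. A minimal shape check, assuming (from the relative import added to networks.py below) that the file lands at deepliif/models/att_unet.py:

    import torch
    from deepliif.models.att_unet import AttU_Net  # assumed module path

    net = AttU_Net(img_ch=3, output_ch=3)
    x = torch.randn(1, 3, 512, 512)  # spatial size must be a multiple of 256
    with torch.no_grad():
        y = net(x)
    print(y.shape)  # torch.Size([1, 3, 512, 512]); values in [-1, 1] from the final Tanh
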
@@ -3,7 +3,7 @@ import torch
  from collections import OrderedDict
  from abc import ABC, abstractmethod
  from . import networks
- from ..util import disable_batchnorm_tracking_stats
+ from ..util import disable_batchnorm_tracking_stats, enable_batchnorm_tracking_stats
  from deepliif.util import *
  import itertools

@@ -30,7 +30,7 @@ class BaseModel(ABC):
          -- self.loss_names (str list): specify the training losses that you want to plot and save.
          -- self.model_names (str list): define networks used in our training.
          -- self.visual_names (str list): specify the images that you want to display and save.
-         -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See cycle_gan_model.py for an example.
+         -- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See CycleGAN_model.py for an example.
          """
          self.opt = opt
          self.gpu_ids = opt.gpu_ids
@@ -81,6 +81,7 @@ class BaseModel(ABC):
          Parameters:
              opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
          """
+         self.opt = opt
          if self.is_train:
              self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
          if not self.is_train or opt.continue_train:
@@ -88,6 +89,17 @@ class BaseModel(ABC):
              self.load_networks(load_suffix)
          self.print_networks(opt.verbose)

+     def train(self):
+         """Make models train mode """
+         for name in self.model_names:
+             if isinstance(name, str):
+                 if '_' in name:
+                     net = getattr(self, 'net' + name.split('_')[0])[int(name.split('_')[-1]) - 1]
+                 else:
+                     net = getattr(self, 'net' + name)
+                 net.train()
+                 net = enable_batchnorm_tracking_stats(net)
+
      def eval(self):
          """Make models eval mode during test time"""
          for name in self.model_names:
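
The new train() mirrors eval(): names like 'G_1' index into a list attribute (getattr(self, 'netG')[0]), plain names resolve to net<name>, and BatchNorm running-statistics tracking is re-enabled, undoing what disable_batchnorm_tracking_stats switches off for inference. The helper pair lives in deepliif.util; a rough sketch of what such a helper typically does (hypothetical, the package's actual implementation may differ):

    import torch.nn as nn

    def enable_batchnorm_tracking_stats(net: nn.Module) -> nn.Module:
        # hypothetical sketch, not the deepliif.util source
        for m in net.modules():
            if isinstance(m, nn.modules.batchnorm._BatchNorm):
                m.track_running_stats = True
        return net
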
@@ -134,10 +146,21 @@ class BaseModel(ABC):
          for name in self.visual_names:
              if isinstance(name, str):
                  if not hasattr(self, name):
-                     if len(name.split('_')) == 2:
-                         visual_ret[name] = getattr(self, name.split('_')[0])[int(name.split('_')[-1]) -1]
+                     if len(name.split('_')) != 2:
+                         if self.opt.model == 'DeepLIIF':
+                             img_name = name[:-1] + '_' + name[-1]
+                             visual_ret[name] = getattr(self, img_name)
+                         else:
+                             if self.opt.model == 'CycleGAN':
+                                 l_output = getattr(self, name.split('_')[0] + '_' + name.split('_')[1])
+                                 if len(l_output) > 0:
+                                     visual_ret[name] = getattr(self, name.split('_')[0] + '_' + name.split('_')[1])[int(name.split('_')[-1]) - 1]
+                                 else:
+                                     print('No output for',name)
+                             else:
+                                 visual_ret[name] = getattr(self, name.split('_')[0] + '_' + name.split('_')[1])[int(name.split('_')[-1]) - 1]
                      else:
-                         visual_ret[name] = getattr(self, name.split('_')[0] + '_' + name.split('_')[1])[int(name.split('_')[-1]) - 1]
+                         visual_ret[name] = getattr(self, name.split('_')[0])[int(name.split('_')[-1]) -1]
                  else:
                      visual_ret[name] = getattr(self, name)
          return visual_ret
@@ -240,6 +263,7 @@ class BaseModel(ABC):
              epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
          """
          for name in self.model_names:
+
              if isinstance(name, str):
                  load_filename = '%s_net_%s.pth' % (epoch, name)
                  load_path = os.path.join(self.save_dir, load_filename)
@@ -272,9 +296,9 @@ class BaseModel(ABC):
                  if hasattr(state_dict, '_metadata'):
                      del state_dict._metadata

-                 # patch InstanceNorm checkpoints prior to 0.4
-                 for key in list(state_dict.keys()):  # need to copy keys here because we mutate in loop
-                     self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
+                 # # patch InstanceNorm checkpoints prior to 0.4
+                 # for key in list(state_dict.keys()):  # need to copy keys here because we mutate in loop
+                 #     self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
                  net.load_state_dict(state_dict)

      def print_networks(self, verbose):
@@ -3,15 +3,19 @@ import torch.nn as nn
  from torch.nn import init
  import functools
  from torch.optim import lr_scheduler
+
  import os

  from torchvision import models
-
+ from .att_unet import AttU_Net
  ###############################################################################
  # Helper Functions
  ###############################################################################
  from deepliif.util import util

+
+ # as of pytorch 2.4, all optimizers start with an uppercase letter
+ OPTIMIZER_MAPPING = {optimizer_name.lower():optimizer_name for optimizer_name in dir(torch.optim) if optimizer_name[0].isupper()}
+

  class Identity(nn.Module):
      def forward(self, x):
@@ -31,12 +35,22 @@ def get_norm_layer(norm_type='instance'):
          norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
      elif norm_type == 'instance':
          norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
-     elif norm_type == 'none':
+     elif norm_type in ['none','spectral']:
          def norm_layer(x): return Identity()
+     # elif norm_type == 'spectral':
+     #     norm_layer = torch.nn.utils.parametrizations.spectral_norm # this needs to be called on nn modules individually, not on nn.Sequential
      else:
          raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
      return norm_layer

+ def get_optimizer(optimizer_name):
+     try:
+         return getattr(torch.optim, optimizer_name)
+     except:
+         try:
+             return getattr(torch.optim, OPTIMIZER_MAPPING[optimizer_name])
+         except:
+             raise NotImplementedError('optimizer [%s] is not found' % optimizer_name)

  def get_scheduler(optimizer, opt):
      """Return a learning rate scheduler
@@ -125,7 +139,9 @@ def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
      return net


- def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[], padding_type='reflect'):
+ def define_G(
+         input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[], padding_type='reflect',
+         upsample='convtranspose'):
      """Create a generator

      Parameters:
@@ -154,17 +170,22 @@ def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, in
      """
      net = None
      norm_layer = get_norm_layer(norm_type=norm)
+     use_spectral_norm = norm == 'spectral'

      if netG == 'resnet_9blocks':
-         net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9, padding_type=padding_type)
+         net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9,
+                               padding_type=padding_type, upsample=upsample, use_spectral_norm=use_spectral_norm)
      elif netG == 'resnet_6blocks':
-         net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6, padding_type=padding_type)
+         net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6,
+                               padding_type=padding_type, upsample=upsample, use_spectral_norm=use_spectral_norm)
      elif netG == 'unet_128':
          net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
      elif netG == 'unet_256':
          net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
      elif netG == 'unet_512':
          net = UnetGenerator(input_nc, output_nc, 9, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
+     elif netG == 'unet_512_attention':
+         net = AttU_Net(img_ch=input_nc,output_ch=output_nc)
      else:
          raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
      return init_net(net, init_type, init_gain, gpu_ids)
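
The attention U-Net is wired in as a new generator name; a hypothetical call:

    net_g = define_G(input_nc=3, output_nc=3, ngf=64, netG='unet_512_attention')
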
@@ -202,11 +223,12 @@ def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal'
      """
      net = None
      norm_layer = get_norm_layer(norm_type=norm)
+     use_spectral_norm = norm == 'spectral'

      if netD == 'basic': # default PatchGAN classifier
-         net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer)
+         net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, use_spectral_norm=use_spectral_norm)
      elif netD == 'n_layers': # more options
-         net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer)
+         net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_spectral_norm=use_spectral_norm)
      elif netD == 'pixel': # classify if each pixel is real or fake
          net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer)
      else:
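
Passing norm='spectral' now works end to end: get_norm_layer returns Identity for that value, and the separate use_spectral_norm flag wraps each conv in spectral normalization instead. An illustrative call:

    net_d = define_D(input_nc=6, ndf=64, netD='basic', norm='spectral')
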
@@ -224,7 +246,7 @@ class GANLoss(nn.Module):
      that has the same size as the input.
      """

-     def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
+     def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0, label_smoothing=0.0):
          """ Initialize the GANLoss class.

          Parameters:
@@ -239,6 +261,7 @@ class GANLoss(nn.Module):
          self.register_buffer('real_label', torch.tensor(target_real_label))
          self.register_buffer('fake_label', torch.tensor(target_fake_label))
          self.gan_mode = gan_mode
+         self.label_smoothing = label_smoothing
          if gan_mode == 'lsgan':
              self.loss = nn.MSELoss()
          elif gan_mode == 'vanilla':
@@ -261,9 +284,10 @@ class GANLoss(nn.Module):

          if target_is_real:
              target_tensor = self.real_label
+             return target_tensor.expand_as(prediction) * (1 - self.label_smoothing)
          else:
              target_tensor = self.fake_label
-         return target_tensor.expand_as(prediction)
+             return target_tensor.expand_as(prediction) * self.label_smoothing

      def __call__(self, prediction, target_is_real, epsilon=1.0):
          """Calculate loss given Discriminator's output and ground truth labels.
@@ -332,9 +356,13 @@ class ResnetGenerator(nn.Module):
      """Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.

      We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
+
+     Resize-conv: optional replacement of ConvTranspose2d
+     https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/190#issuecomment-358546675
      """

-     def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='zero'):
+     def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='zero',
+                  upsample='convtranspose', use_spectral_norm=False):
          """Construct a Resnet-based generator

          Parameters:
@@ -355,34 +383,54 @@ class ResnetGenerator(nn.Module):

          if padding_type == 'reflect':
              model = [nn.ReflectionPad2d(3),
-                      nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
-                      norm_layer(ngf),
-                      nn.ReLU(True)]
+                      SpectralNorm(nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
+                                   use_spectral_norm=use_spectral_norm),
+                      norm_layer(ngf),
+                      nn.ReLU(True)]
          else:
              model = [nn.ZeroPad2d(3),
-                      nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
-                      norm_layer(ngf),
-                      nn.ReLU(True)]
+                      SpectralNorm(nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
+                                   use_spectral_norm=use_spectral_norm),
+                      norm_layer(ngf),
+                      nn.ReLU(True)]

          n_downsampling = 2
          for i in range(n_downsampling):  # add downsampling layers
              mult = 2 ** i
-             model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
+             model += [SpectralNorm(nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),use_spectral_norm=use_spectral_norm),
                        norm_layer(ngf * mult * 2),
                        nn.ReLU(True)]

          mult = 2 ** n_downsampling
          for i in range(n_blocks):  # add ResNet blocks

-             model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias)]
+             model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias,
+                                   use_spectral_norm=use_spectral_norm)]

          for i in range(n_downsampling):  # add upsampling layers
              mult = 2 ** (n_downsampling - i)
-             model += [nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
-                                          kernel_size=3, stride=2,
-                                          padding=1, output_padding=1,
-                                          bias=use_bias),
-                       norm_layer(int(ngf * mult / 2)),
+             if upsample == 'resize_conv':
+                 upsample_layer = [#nn.Upsample(scale_factor = 2, mode='bilinear',align_corners=True),
+                                   nn.Upsample(scale_factor = 2, mode='nearest'),
+                                   nn.ReflectionPad2d(1),
+                                   SpectralNorm(nn.Conv2d(ngf * mult, int(ngf * mult / 2),
+                                                          kernel_size=3, stride=1, padding=0),use_spectral_norm=use_spectral_norm)]
+             elif upsample == 'pixel_shuffle':
+                 upsample_layer = [SpectralNorm(nn.Conv2d(ngf * mult, int(ngf * mult * 2),
+                                                          kernel_size=3, padding=1),
+                                                use_spectral_norm=use_spectral_norm),
+                                   nn.PixelShuffle(2),
+                                   nn.ReLU()]
+             elif upsample == 'convtranspose':
+                 upsample_layer = [SpectralNorm(nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
+                                                                   kernel_size=3, stride=2,
+                                                                   padding=1, output_padding=1,
+                                                                   bias=use_bias),
+                                                use_spectral_norm=use_spectral_norm)]
+             else:
+                 raise Exception(f'upsample layer type {upsample} not implemented')
+
+             model += upsample_layer
+             model += [norm_layer(int(ngf * mult / 2)),
                        nn.ReLU(True)]


@@ -390,7 +438,7 @@ class ResnetGenerator(nn.Module):
          else:
              model += [nn.ZeroPad2d(3)]

-         model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
+         model += [SpectralNorm(nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),use_spectral_norm=use_spectral_norm)]
          model += [nn.Tanh()]

          self.model = nn.Sequential(*model)
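
The resize_conv path (nearest-neighbour upsample, reflection pad, then a stride-1 conv) is the standard remedy for the checkerboard artifacts that ConvTranspose2d can produce, per the pix2pix/CycleGAN issue linked in the docstring; pixel_shuffle quadruples the channels and rearranges them into a 2x upscale. A hypothetical call selecting the former:

    net_g = define_G(input_nc=3, output_nc=3, ngf=64, netG='resnet_9blocks',
                     norm='batch', upsample='resize_conv')
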
@@ -403,7 +451,7 @@ class ResnetGenerator(nn.Module):
  class ResnetBlock(nn.Module):
      """Define a Resnet block"""

-     def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
+     def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias, use_spectral_norm=False):
          """Initialize the Resnet block

          A resnet block is a conv block with skip connections
@@ -412,9 +460,9 @@ class ResnetBlock(nn.Module):
          Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
          """
          super(ResnetBlock, self).__init__()
-         self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)
+         self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias, use_spectral_norm)

-     def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
+     def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias, use_spectral_norm):
          """Construct a convolutional block.

          Parameters:
@@ -437,7 +485,9 @@ class ResnetBlock(nn.Module):
          else:
              raise NotImplementedError('padding [%s] is not implemented' % padding_type)

-         conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim), nn.ReLU(True)]
+         conv_block += [SpectralNorm(nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),use_spectral_norm=use_spectral_norm),
+                        norm_layer(dim),
+                        nn.ReLU(True)]
          if use_dropout:
              conv_block += [nn.Dropout(0.5)]

@@ -450,7 +500,8 @@ class ResnetBlock(nn.Module):
              p = 1
          else:
              raise NotImplementedError('padding [%s] is not implemented' % padding_type)
-         conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias), norm_layer(dim)]
+         conv_block += [SpectralNorm(nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),use_spectral_norm=use_spectral_norm),
+                        norm_layer(dim)]

          return nn.Sequential(*conv_block)

@@ -565,7 +616,7 @@ class UnetSkipConnectionBlock(nn.Module):
  class NLayerDiscriminator(nn.Module):
      """Defines a PatchGAN discriminator"""

-     def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
+     def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_spectral_norm=False):
          """Construct a PatchGAN discriminator

          Parameters:
@@ -582,14 +633,15 @@ class NLayerDiscriminator(nn.Module):

          kw = 4
          padw = 1
-         sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw), nn.LeakyReLU(0.2, True)]
+         sequence = [SpectralNorm(nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),use_spectral_norm=use_spectral_norm),
+                     nn.LeakyReLU(0.2, True)]
          nf_mult = 1
          nf_mult_prev = 1
          for n in range(1, n_layers):  # gradually increase the number of filters
              nf_mult_prev = nf_mult
              nf_mult = min(2 ** n, 8)
              sequence += [
-                 nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
+                 SpectralNorm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),use_spectral_norm=use_spectral_norm),
                  norm_layer(ndf * nf_mult),
                  nn.LeakyReLU(0.2, True)
              ]
@@ -597,12 +649,12 @@ class NLayerDiscriminator(nn.Module):
          nf_mult_prev = nf_mult
          nf_mult = min(2 ** n_layers, 8)
          sequence += [
-             nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
+             SpectralNorm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),use_spectral_norm=use_spectral_norm),
              norm_layer(ndf * nf_mult),
              nn.LeakyReLU(0.2, True)
          ]

-         sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)]  # output 1 channel prediction map
+         sequence += [SpectralNorm(nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw),use_spectral_norm=use_spectral_norm)]  # output 1 channel prediction map
          self.model = nn.Sequential(*sequence)

      def forward(self, input):
@@ -687,3 +739,25 @@ class VGGLoss(nn.Module):
          for i in range(len(x_vgg)):
              loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
          return loss
+
+
+ class TotalVariationLoss(nn.Module):
+     """
+     Absolute difference for neighbouring pixels (i,j to i+1,j, then i,j to i,j+1), averaged on pixel level
+     """
+     def __init__(self):
+         super(TotalVariationLoss, self).__init__()
+
+     def forward(self, x):
+         tv = torch.abs(x[:,:,1:,:]-x[:,:,:-1,:]).sum() + torch.abs(x[:,:,:,1:]-x[:,:,:,:-1]).sum()
+         return tv / torch.prod(torch.tensor(x.size()))
+
+ def SpectralNorm(x, use_spectral_norm=False):
+     """
+     A custom wrapper for nn.utils.parametrizations.spectral_norm,
+     with a flag to turn it on or off.
+     """
+     if use_spectral_norm:
+         return nn.utils.parametrizations.spectral_norm(x)
+     else:
+         return x
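
Illustrative behaviour of the two new helpers:

    conv = nn.Conv2d(3, 64, kernel_size=3, padding=1)
    SpectralNorm(conv, use_spectral_norm=False)  # returns conv unchanged
    SpectralNorm(conv, use_spectral_norm=True)   # same module, weight reparametrized to unit spectral norm

    tv = TotalVariationLoss()
    tv(torch.rand(2, 3, 64, 64))  # scalar: summed neighbour differences divided by the number of elements
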