deepliif 1.1.10__py3-none-any.whl → 1.1.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli.py +354 -67
- deepliif/data/__init__.py +7 -7
- deepliif/data/aligned_dataset.py +2 -3
- deepliif/data/unaligned_dataset.py +38 -19
- deepliif/models/CycleGAN_model.py +282 -0
- deepliif/models/DeepLIIFExt_model.py +47 -25
- deepliif/models/DeepLIIF_model.py +69 -19
- deepliif/models/SDG_model.py +57 -26
- deepliif/models/__init__ - run_dask_multi dev.py +943 -0
- deepliif/models/__init__ - timings.py +764 -0
- deepliif/models/__init__.py +354 -232
- deepliif/models/att_unet.py +199 -0
- deepliif/models/base_model.py +32 -8
- deepliif/models/networks.py +108 -34
- deepliif/options/__init__.py +49 -5
- deepliif/postprocessing.py +1034 -227
- deepliif/postprocessing__OLD__DELETE.py +440 -0
- deepliif/util/__init__.py +290 -64
- deepliif/util/visualizer.py +106 -19
- {deepliif-1.1.10.dist-info → deepliif-1.1.12.dist-info}/METADATA +81 -20
- deepliif-1.1.12.dist-info/RECORD +40 -0
- deepliif-1.1.10.dist-info/RECORD +0 -35
- {deepliif-1.1.10.dist-info → deepliif-1.1.12.dist-info}/LICENSE.md +0 -0
- {deepliif-1.1.10.dist-info → deepliif-1.1.12.dist-info}/WHEEL +0 -0
- {deepliif-1.1.10.dist-info → deepliif-1.1.12.dist-info}/entry_points.txt +0 -0
- {deepliif-1.1.10.dist-info → deepliif-1.1.12.dist-info}/top_level.txt +0 -0
|
@@ -0,0 +1,199 @@
|
|
|
1
|
+
# adapted from https://github.com/LeeJunHyun/Image_Segmentation/blob/master/network.py
|
|
2
|
+
|
|
3
|
+
import torch
|
|
4
|
+
import torch.nn as nn
|
|
5
|
+
import torch.nn.functional as F
|
|
6
|
+
from torch.nn import init
|
|
7
|
+
|
|
8
|
+
def init_weights(net, init_type='normal', gain=0.02):
    """Initialize the weights of ``net`` and all of its sub-modules in place.

    Modules whose class name contains ``Conv`` or ``Linear`` (and that have a
    ``weight`` attribute) get their weight filled by the scheme selected with
    ``init_type``; their bias, when present, is set to zero.
    ``BatchNorm2d`` modules get weight ~ N(1.0, gain) and bias 0; layers built
    with ``affine=False`` (weight/bias are ``None``) are left untouched.

    Parameters:
        net (nn.Module)   -- network to initialize
        init_type (str)   -- 'normal' | 'xavier' | 'kaiming' | 'orthogonal'
        gain (float)      -- std for 'normal', gain for 'xavier'/'orthogonal'
                             (not used by 'kaiming'); also the std of the
                             BatchNorm2d weight initialization

    Raises:
        NotImplementedError -- when a Conv/Linear module is visited and
                               ``init_type`` is not one of the names above
    """
    weight_initializers = {
        'normal': lambda weight: init.normal_(weight, 0.0, gain),
        'xavier': lambda weight: init.xavier_normal_(weight, gain=gain),
        'kaiming': lambda weight: init.kaiming_normal_(weight, a=0, mode='fan_in'),
        'orthogonal': lambda weight: init.orthogonal_(weight, gain=gain),
    }

    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and (classname.find('Conv') != -1 or classname.find('Linear') != -1):
            if init_type not in weight_initializers:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
            weight_initializers[init_type](m.weight.data)
            if hasattr(m, 'bias') and m.bias is not None:
                init.constant_(m.bias.data, 0.0)
        elif classname.find('BatchNorm2d') != -1:
            # BatchNorm2d(affine=False) registers weight and bias as None;
            # dereferencing .data on them would raise AttributeError.
            if m.weight is not None:
                init.normal_(m.weight.data, 1.0, gain)
            if m.bias is not None:
                init.constant_(m.bias.data, 0.0)

    print('initialize network with %s' % init_type)
    net.apply(init_func)
|
|
30
|
+
|
|
31
|
+
class conv_block(nn.Module):
    """Encoder stage built around one 4x4, stride-2, padding-1 convolution.

    What follows the convolution depends on the position of the stage:
        outermost -- LeakyReLU(0.2) only
        innermost -- ReLU only
        otherwise -- BatchNorm2d followed by LeakyReLU(0.2)
    ``outermost`` takes precedence when both flags are set.

    Parameters:
        ch_in (int)      -- number of input channels
        ch_out (int)     -- number of output channels
        innermost (bool) -- build the bottleneck variant
        outermost (bool) -- build the first-layer variant
    """

    def __init__(self, ch_in, ch_out, innermost=False, outermost=False):
        super().__init__()
        # The convolution is identical in every variant; only its tail differs.
        layers = [nn.Conv2d(ch_in, ch_out, kernel_size=4, stride=2, padding=1, bias=True)]
        if outermost:
            layers.append(nn.LeakyReLU(0.2, True))
        elif innermost:
            layers.append(nn.ReLU(inplace=True))
        else:
            layers.extend([nn.BatchNorm2d(ch_out), nn.LeakyReLU(0.2, True)])
        self.conv = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the stage to ``x`` and return the result."""
        return self.conv(x)
|
|
54
|
+
|
|
55
|
+
class up_conv(nn.Module):
    """Decoder stage built around one 4x4, stride-2, padding-1 transposed convolution.

    Variants:
        outermost -- ConvTranspose2d(ch_in * 2 -> ch_out, default bias) + Tanh
        innermost -- ConvTranspose2d(ch_in -> ch_out, no bias) + BatchNorm2d + ReLU
        otherwise -- ConvTranspose2d(ch_in * 2 -> ch_out, no bias) + BatchNorm2d + ReLU
    ``outermost`` takes precedence when both flags are set. The doubled input
    width of the non-innermost variants is what the constructor requests;
    the stage itself does no concatenation.

    Parameters:
        ch_in (int)      -- base number of input channels (see variants above)
        ch_out (int)     -- number of output channels
        innermost (bool) -- build the bottleneck variant
        outermost (bool) -- build the final-layer variant
    """

    def __init__(self, ch_in, ch_out, innermost=False, outermost=False):
        super().__init__()
        geometry = dict(kernel_size=4, stride=2, padding=1)
        if outermost:
            stages = [nn.ConvTranspose2d(ch_in * 2, ch_out, **geometry),
                      nn.Tanh()]
        else:
            in_channels = ch_in if innermost else ch_in * 2
            stages = [nn.ConvTranspose2d(in_channels, ch_out, bias=False, **geometry),
                      nn.BatchNorm2d(ch_out),
                      nn.ReLU(True)]
        self.up = nn.Sequential(*stages)

    def forward(self, x):
        """Apply the stage to ``x`` and return the result."""
        return self.up(x)
|
|
84
|
+
|
|
85
|
+
|
|
86
|
+
class Attention_block(nn.Module):
    """Attention gate that rescales a skip tensor ``x`` using a gating tensor ``g``.

    Both inputs are projected to ``F_int`` channels with a 1x1 convolution and
    BatchNorm2d, summed and passed through ReLU. A further 1x1 convolution,
    BatchNorm2d and Sigmoid reduce the sum to a single-channel map, which
    multiplies ``x``.

    Parameters:
        F_g (int)   -- number of channels of the gating tensor ``g``
        F_l (int)   -- number of channels of the skip tensor ``x``
        F_int (int) -- number of channels of the intermediate projections
    """

    def __init__(self, F_g, F_l, F_int):
        super().__init__()

        def pointwise(in_ch, out_ch):
            # 1x1 convolution followed by batch normalization
            return [nn.Conv2d(in_ch, out_ch, kernel_size=1, stride=1, padding=0, bias=True),
                    nn.BatchNorm2d(out_ch)]

        self.W_g = nn.Sequential(*pointwise(F_g, F_int))
        self.W_x = nn.Sequential(*pointwise(F_l, F_int))
        self.psi = nn.Sequential(*pointwise(F_int, 1), nn.Sigmoid())
        self.relu = nn.ReLU(inplace=True)

    def forward(self, g, x):
        """Return ``x`` multiplied by the attention map computed from ``g`` and ``x``."""
        gate_proj = self.W_g(g)
        skip_proj = self.W_x(x)
        attention = self.psi(self.relu(gate_proj + skip_proj))
        return x * attention
|
|
114
|
+
|
|
115
|
+
|
|
116
|
+
|
|
117
|
+
class AttU_Net(nn.Module):
    """U-Net style generator with attention-gated skip connections.

    The encoder is eight ``conv_block`` stages (Conv1..Conv8), each halving the
    spatial size. The decoder is eight ``up_conv`` stages (Up8..Up1), each
    doubling it. Before every skip connection the encoder feature map is
    rescaled by an ``Attention_block`` gated by the decoder feature map of
    the same level, then concatenated with it along the channel axis.

    Parameters:
        img_ch (int)    -- number of channels of the input image
        output_ch (int) -- number of channels of the generated image

    Note: input height and width that are multiples of 256 are sufficient for
    every skip tensor to match its upsampled counterpart in ``torch.cat``.
    """

    def __init__(self, img_ch=3, output_ch=1):
        super().__init__()

        # --- encoder: channel width grows 64 -> 512, then stays at 512 ---
        self.Conv1 = conv_block(img_ch, 64, outermost=True)
        self.Conv2 = conv_block(64, 128)
        self.Conv3 = conv_block(128, 256)
        self.Conv4 = conv_block(256, 512)
        self.Conv5 = conv_block(512, 512)
        self.Conv6 = conv_block(512, 512)
        self.Conv7 = conv_block(512, 512)
        self.Conv8 = conv_block(512, 512, innermost=True)

        # --- decoder: every up stage is paired with the gate of its skip ---
        self.Up8 = up_conv(512, 512, innermost=True)
        self.Att8 = Attention_block(F_g=512, F_l=512, F_int=512)

        self.Up7 = up_conv(512, 512)
        self.Att7 = Attention_block(F_g=512, F_l=512, F_int=512)

        self.Up6 = up_conv(512, 512)
        self.Att6 = Attention_block(F_g=512, F_l=512, F_int=512)

        self.Up5 = up_conv(512, 512)
        self.Att5 = Attention_block(F_g=512, F_l=512, F_int=512)

        self.Up4 = up_conv(512, 256)
        self.Att4 = Attention_block(F_g=256, F_l=256, F_int=128)

        self.Up3 = up_conv(256, 128)
        self.Att3 = Attention_block(F_g=128, F_l=128, F_int=64)

        self.Up2 = up_conv(128, 64)
        self.Att2 = Attention_block(F_g=64, F_l=64, F_int=32)

        self.Up1 = up_conv(64, output_ch, outermost=True)

    def forward(self, x):
        """Run the encoder, then the attention-gated decoder, and return the output."""
        # encoding path
        enc1 = self.Conv1(x)
        enc2 = self.Conv2(enc1)
        enc3 = self.Conv3(enc2)
        enc4 = self.Conv4(enc3)
        enc5 = self.Conv5(enc4)
        enc6 = self.Conv6(enc5)
        enc7 = self.Conv7(enc6)
        enc8 = self.Conv8(enc7)

        # decoding path: upsample, gate the matching encoder map, concatenate
        up = self.Up8(enc8)
        dec = torch.cat((self.Att8(g=up, x=enc7), up), dim=1)

        up = self.Up7(dec)
        dec = torch.cat((self.Att7(g=up, x=enc6), up), dim=1)

        up = self.Up6(dec)
        dec = torch.cat((self.Att6(g=up, x=enc5), up), dim=1)

        up = self.Up5(dec)
        dec = torch.cat((self.Att5(g=up, x=enc4), up), dim=1)

        up = self.Up4(dec)
        dec = torch.cat((self.Att4(g=up, x=enc3), up), dim=1)

        up = self.Up3(dec)
        dec = torch.cat((self.Att3(g=up, x=enc2), up), dim=1)

        up = self.Up2(dec)
        dec = torch.cat((self.Att2(g=up, x=enc1), up), dim=1)

        return self.Up1(dec)
|
|
199
|
+
|
deepliif/models/base_model.py
CHANGED
|
@@ -3,7 +3,7 @@ import torch
|
|
|
3
3
|
from collections import OrderedDict
|
|
4
4
|
from abc import ABC, abstractmethod
|
|
5
5
|
from . import networks
|
|
6
|
-
from ..util import disable_batchnorm_tracking_stats
|
|
6
|
+
from ..util import disable_batchnorm_tracking_stats, enable_batchnorm_tracking_stats
|
|
7
7
|
from deepliif.util import *
|
|
8
8
|
import itertools
|
|
9
9
|
|
|
@@ -30,7 +30,7 @@ class BaseModel(ABC):
|
|
|
30
30
|
-- self.loss_names (str list): specify the training losses that you want to plot and save.
|
|
31
31
|
-- self.model_names (str list): define networks used in our training.
|
|
32
32
|
-- self.visual_names (str list): specify the images that you want to display and save.
|
|
33
|
-
-- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See
|
|
33
|
+
-- self.optimizers (optimizer list): define and initialize optimizers. You can define one optimizer for each network. If two networks are updated at the same time, you can use itertools.chain to group them. See CycleGAN_model.py for an example.
|
|
34
34
|
"""
|
|
35
35
|
self.opt = opt
|
|
36
36
|
self.gpu_ids = opt.gpu_ids
|
|
@@ -81,6 +81,7 @@ class BaseModel(ABC):
|
|
|
81
81
|
Parameters:
|
|
82
82
|
opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
|
|
83
83
|
"""
|
|
84
|
+
self.opt = opt
|
|
84
85
|
if self.is_train:
|
|
85
86
|
self.schedulers = [networks.get_scheduler(optimizer, opt) for optimizer in self.optimizers]
|
|
86
87
|
if not self.is_train or opt.continue_train:
|
|
@@ -88,6 +89,17 @@ class BaseModel(ABC):
|
|
|
88
89
|
self.load_networks(load_suffix)
|
|
89
90
|
self.print_networks(opt.verbose)
|
|
90
91
|
|
|
92
|
+
def train(self):
|
|
93
|
+
"""Make models train mode """
|
|
94
|
+
for name in self.model_names:
|
|
95
|
+
if isinstance(name, str):
|
|
96
|
+
if '_' in name:
|
|
97
|
+
net = getattr(self, 'net' + name.split('_')[0])[int(name.split('_')[-1]) - 1]
|
|
98
|
+
else:
|
|
99
|
+
net = getattr(self, 'net' + name)
|
|
100
|
+
net.train()
|
|
101
|
+
net = enable_batchnorm_tracking_stats(net)
|
|
102
|
+
|
|
91
103
|
def eval(self):
|
|
92
104
|
"""Make models eval mode during test time"""
|
|
93
105
|
for name in self.model_names:
|
|
@@ -134,10 +146,21 @@ class BaseModel(ABC):
|
|
|
134
146
|
for name in self.visual_names:
|
|
135
147
|
if isinstance(name, str):
|
|
136
148
|
if not hasattr(self, name):
|
|
137
|
-
if len(name.split('_'))
|
|
138
|
-
|
|
149
|
+
if len(name.split('_')) != 2:
|
|
150
|
+
if self.opt.model == 'DeepLIIF':
|
|
151
|
+
img_name = name[:-1] + '_' + name[-1]
|
|
152
|
+
visual_ret[name] = getattr(self, img_name)
|
|
153
|
+
else:
|
|
154
|
+
if self.opt.model == 'CycleGAN':
|
|
155
|
+
l_output = getattr(self, name.split('_')[0] + '_' + name.split('_')[1])
|
|
156
|
+
if len(l_output) > 0:
|
|
157
|
+
visual_ret[name] = getattr(self, name.split('_')[0] + '_' + name.split('_')[1])[int(name.split('_')[-1]) - 1]
|
|
158
|
+
else:
|
|
159
|
+
print('No output for',name)
|
|
160
|
+
else:
|
|
161
|
+
visual_ret[name] = getattr(self, name.split('_')[0] + '_' + name.split('_')[1])[int(name.split('_')[-1]) - 1]
|
|
139
162
|
else:
|
|
140
|
-
visual_ret[name] = getattr(self, name.split('_')[0]
|
|
163
|
+
visual_ret[name] = getattr(self, name.split('_')[0])[int(name.split('_')[-1]) -1]
|
|
141
164
|
else:
|
|
142
165
|
visual_ret[name] = getattr(self, name)
|
|
143
166
|
return visual_ret
|
|
@@ -240,6 +263,7 @@ class BaseModel(ABC):
|
|
|
240
263
|
epoch (int) -- current epoch; used in the file name '%s_net_%s.pth' % (epoch, name)
|
|
241
264
|
"""
|
|
242
265
|
for name in self.model_names:
|
|
266
|
+
|
|
243
267
|
if isinstance(name, str):
|
|
244
268
|
load_filename = '%s_net_%s.pth' % (epoch, name)
|
|
245
269
|
load_path = os.path.join(self.save_dir, load_filename)
|
|
@@ -272,9 +296,9 @@ class BaseModel(ABC):
|
|
|
272
296
|
if hasattr(state_dict, '_metadata'):
|
|
273
297
|
del state_dict._metadata
|
|
274
298
|
|
|
275
|
-
# patch InstanceNorm checkpoints prior to 0.4
|
|
276
|
-
for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop
|
|
277
|
-
|
|
299
|
+
# # patch InstanceNorm checkpoints prior to 0.4
|
|
300
|
+
# for key in list(state_dict.keys()): # need to copy keys here because we mutate in loop
|
|
301
|
+
# self.__patch_instance_norm_state_dict(state_dict, net, key.split('.'))
|
|
278
302
|
net.load_state_dict(state_dict)
|
|
279
303
|
|
|
280
304
|
def print_networks(self, verbose):
|
deepliif/models/networks.py
CHANGED
|
@@ -3,15 +3,19 @@ import torch.nn as nn
|
|
|
3
3
|
from torch.nn import init
|
|
4
4
|
import functools
|
|
5
5
|
from torch.optim import lr_scheduler
|
|
6
|
+
|
|
6
7
|
import os
|
|
7
8
|
|
|
8
9
|
from torchvision import models
|
|
9
|
-
|
|
10
|
+
from .att_unet import AttU_Net
|
|
10
11
|
###############################################################################
|
|
11
12
|
# Helper Functions
|
|
12
13
|
###############################################################################
|
|
13
14
|
from deepliif.util import util
|
|
14
15
|
|
|
16
|
+
# as of pytorch 2.4, all optimizers start with an uppercase letter
|
|
17
|
+
OPTIMIZER_MAPPING = {optimizer_name.lower():optimizer_name for optimizer_name in dir(torch.optim) if optimizer_name[0].isupper()}
|
|
18
|
+
|
|
15
19
|
|
|
16
20
|
class Identity(nn.Module):
|
|
17
21
|
def forward(self, x):
|
|
@@ -31,12 +35,22 @@ def get_norm_layer(norm_type='instance'):
|
|
|
31
35
|
norm_layer = functools.partial(nn.BatchNorm2d, affine=True, track_running_stats=True)
|
|
32
36
|
elif norm_type == 'instance':
|
|
33
37
|
norm_layer = functools.partial(nn.InstanceNorm2d, affine=False, track_running_stats=False)
|
|
34
|
-
elif norm_type
|
|
38
|
+
elif norm_type in ['none','spectral']:
|
|
35
39
|
def norm_layer(x): return Identity()
|
|
40
|
+
# elif norm_type == 'spectral':
|
|
41
|
+
# norm_layer = torch.nn.utils.parametrizations.spectral_norm # this needs to be called on nn modules individually, not on nn.Sequential
|
|
36
42
|
else:
|
|
37
43
|
raise NotImplementedError('normalization layer [%s] is not found' % norm_type)
|
|
38
44
|
return norm_layer
|
|
39
45
|
|
|
46
|
+
def get_optimizer(optimizer_name):
    """Return the ``torch.optim`` attribute that matches ``optimizer_name``.

    The name is first tried verbatim (e.g. 'Adam', 'SGD'). If ``torch.optim``
    has no such attribute, the lower-cased name is looked up in
    OPTIMIZER_MAPPING (lower-case name -> real attribute name), so that
    'adam', 'ADAM' and 'adamW' all resolve.

    Parameters:
        optimizer_name (str) -- name of the optimizer, in any letter case

    Raises:
        NotImplementedError -- if neither lookup succeeds
    """
    try:
        return getattr(torch.optim, optimizer_name)
    except (AttributeError, TypeError):
        # not an exact attribute name (or not a string); try the mapping below
        pass

    try:
        # the mapping keys are lower-case, so normalize before indexing
        return getattr(torch.optim, OPTIMIZER_MAPPING[optimizer_name.lower()])
    except (AttributeError, KeyError, TypeError):
        raise NotImplementedError('optimizer [%s] is not found' % optimizer_name) from None
|
|
40
54
|
|
|
41
55
|
def get_scheduler(optimizer, opt):
|
|
42
56
|
"""Return a learning rate scheduler
|
|
@@ -125,7 +139,9 @@ def init_net(net, init_type='normal', init_gain=0.02, gpu_ids=[]):
|
|
|
125
139
|
return net
|
|
126
140
|
|
|
127
141
|
|
|
128
|
-
def define_G(
|
|
142
|
+
def define_G(
|
|
143
|
+
input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, init_type='normal', init_gain=0.02, gpu_ids=[], padding_type='reflect',
|
|
144
|
+
upsample='convtranspose'):
|
|
129
145
|
"""Create a generator
|
|
130
146
|
|
|
131
147
|
Parameters:
|
|
@@ -154,17 +170,22 @@ def define_G(input_nc, output_nc, ngf, netG, norm='batch', use_dropout=False, in
|
|
|
154
170
|
"""
|
|
155
171
|
net = None
|
|
156
172
|
norm_layer = get_norm_layer(norm_type=norm)
|
|
173
|
+
use_spectral_norm = norm == 'spectral'
|
|
157
174
|
|
|
158
175
|
if netG == 'resnet_9blocks':
|
|
159
|
-
net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9,
|
|
176
|
+
net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9,
|
|
177
|
+
padding_type=padding_type, upsample=upsample, use_spectral_norm=use_spectral_norm)
|
|
160
178
|
elif netG == 'resnet_6blocks':
|
|
161
|
-
net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6,
|
|
179
|
+
net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=6,
|
|
180
|
+
padding_type=padding_type, upsample=upsample, use_spectral_norm=use_spectral_norm)
|
|
162
181
|
elif netG == 'unet_128':
|
|
163
182
|
net = UnetGenerator(input_nc, output_nc, 7, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
|
|
164
183
|
elif netG == 'unet_256':
|
|
165
184
|
net = UnetGenerator(input_nc, output_nc, 8, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
|
|
166
185
|
elif netG == 'unet_512':
|
|
167
186
|
net = UnetGenerator(input_nc, output_nc, 9, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
|
|
187
|
+
elif netG == 'unet_512_attention':
|
|
188
|
+
net = AttU_Net(img_ch=input_nc,output_ch=output_nc)
|
|
168
189
|
else:
|
|
169
190
|
raise NotImplementedError('Generator model name [%s] is not recognized' % netG)
|
|
170
191
|
return init_net(net, init_type, init_gain, gpu_ids)
|
|
@@ -202,11 +223,12 @@ def define_D(input_nc, ndf, netD, n_layers_D=3, norm='batch', init_type='normal'
|
|
|
202
223
|
"""
|
|
203
224
|
net = None
|
|
204
225
|
norm_layer = get_norm_layer(norm_type=norm)
|
|
226
|
+
use_spectral_norm = norm == 'spectral'
|
|
205
227
|
|
|
206
228
|
if netD == 'basic': # default PatchGAN classifier
|
|
207
|
-
net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer)
|
|
229
|
+
net = NLayerDiscriminator(input_nc, ndf, n_layers=3, norm_layer=norm_layer, use_spectral_norm=use_spectral_norm)
|
|
208
230
|
elif netD == 'n_layers': # more options
|
|
209
|
-
net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer)
|
|
231
|
+
net = NLayerDiscriminator(input_nc, ndf, n_layers_D, norm_layer=norm_layer, use_spectral_norm=use_spectral_norm)
|
|
210
232
|
elif netD == 'pixel': # classify if each pixel is real or fake
|
|
211
233
|
net = PixelDiscriminator(input_nc, ndf, norm_layer=norm_layer)
|
|
212
234
|
else:
|
|
@@ -224,7 +246,7 @@ class GANLoss(nn.Module):
|
|
|
224
246
|
that has the same size as the input.
|
|
225
247
|
"""
|
|
226
248
|
|
|
227
|
-
def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0):
|
|
249
|
+
def __init__(self, gan_mode, target_real_label=1.0, target_fake_label=0.0, label_smoothing=0.0):
|
|
228
250
|
""" Initialize the GANLoss class.
|
|
229
251
|
|
|
230
252
|
Parameters:
|
|
@@ -239,6 +261,7 @@ class GANLoss(nn.Module):
|
|
|
239
261
|
self.register_buffer('real_label', torch.tensor(target_real_label))
|
|
240
262
|
self.register_buffer('fake_label', torch.tensor(target_fake_label))
|
|
241
263
|
self.gan_mode = gan_mode
|
|
264
|
+
self.label_smoothing = label_smoothing
|
|
242
265
|
if gan_mode == 'lsgan':
|
|
243
266
|
self.loss = nn.MSELoss()
|
|
244
267
|
elif gan_mode == 'vanilla':
|
|
@@ -261,9 +284,10 @@ class GANLoss(nn.Module):
|
|
|
261
284
|
|
|
262
285
|
if target_is_real:
|
|
263
286
|
target_tensor = self.real_label
|
|
287
|
+
return target_tensor.expand_as(prediction) * (1 - self.label_smoothing)
|
|
264
288
|
else:
|
|
265
289
|
target_tensor = self.fake_label
|
|
266
|
-
|
|
290
|
+
return target_tensor.expand_as(prediction) * self.label_smoothing
|
|
267
291
|
|
|
268
292
|
def __call__(self, prediction, target_is_real, epsilon=1.0):
|
|
269
293
|
"""Calculate loss given Discriminator's output and grount truth labels.
|
|
@@ -332,9 +356,13 @@ class ResnetGenerator(nn.Module):
|
|
|
332
356
|
"""Resnet-based generator that consists of Resnet blocks between a few downsampling/upsampling operations.
|
|
333
357
|
|
|
334
358
|
We adapt Torch code and idea from Justin Johnson's neural style transfer project(https://github.com/jcjohnson/fast-neural-style)
|
|
359
|
+
|
|
360
|
+
Resize-conv: optional replacement of ConvTranspose2d
|
|
361
|
+
https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/issues/190#issuecomment-358546675
|
|
335
362
|
"""
|
|
336
363
|
|
|
337
|
-
def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='zero'
|
|
364
|
+
def __init__(self, input_nc, output_nc, ngf=64, norm_layer=nn.BatchNorm2d, use_dropout=False, n_blocks=6, padding_type='zero',
|
|
365
|
+
upsample='convtranspose', use_spectral_norm=False):
|
|
338
366
|
"""Construct a Resnet-based generator
|
|
339
367
|
|
|
340
368
|
Parameters:
|
|
@@ -355,34 +383,54 @@ class ResnetGenerator(nn.Module):
|
|
|
355
383
|
|
|
356
384
|
if padding_type == 'reflect':
|
|
357
385
|
model = [nn.ReflectionPad2d(3),
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
386
|
+
SpectralNorm(nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
|
|
387
|
+
use_spectral_norm=use_spectral_norm),
|
|
388
|
+
norm_layer(ngf),
|
|
389
|
+
nn.ReLU(True)]
|
|
361
390
|
else:
|
|
362
391
|
model = [nn.ZeroPad2d(3),
|
|
363
|
-
|
|
364
|
-
|
|
365
|
-
|
|
392
|
+
SpectralNorm(nn.Conv2d(input_nc, ngf, kernel_size=7, padding=0, bias=use_bias),
|
|
393
|
+
use_spectral_norm=use_spectral_norm),
|
|
394
|
+
norm_layer(ngf),
|
|
395
|
+
nn.ReLU(True)]
|
|
366
396
|
|
|
367
397
|
n_downsampling = 2
|
|
368
398
|
for i in range(n_downsampling): # add downsampling layers
|
|
369
399
|
mult = 2 ** i
|
|
370
|
-
model += [nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),
|
|
400
|
+
model += [SpectralNorm(nn.Conv2d(ngf * mult, ngf * mult * 2, kernel_size=3, stride=2, padding=1, bias=use_bias),use_spectral_norm=use_spectral_norm),
|
|
371
401
|
norm_layer(ngf * mult * 2),
|
|
372
402
|
nn.ReLU(True)]
|
|
373
403
|
|
|
374
404
|
mult = 2 ** n_downsampling
|
|
375
405
|
for i in range(n_blocks): # add ResNet blocks
|
|
376
406
|
|
|
377
|
-
model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias
|
|
407
|
+
model += [ResnetBlock(ngf * mult, padding_type=padding_type, norm_layer=norm_layer, use_dropout=use_dropout, use_bias=use_bias,
|
|
408
|
+
use_spectral_norm=use_spectral_norm)]
|
|
378
409
|
|
|
379
410
|
for i in range(n_downsampling): # add upsampling layers
|
|
380
411
|
mult = 2 ** (n_downsampling - i)
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
412
|
+
if upsample == 'resize_conv':
|
|
413
|
+
upsample_layer = [#nn.Upsample(scale_factor = 2, mode='bilinear',align_corners=True),
|
|
414
|
+
nn.Upsample(scale_factor = 2, mode='nearest'),
|
|
415
|
+
nn.ReflectionPad2d(1),
|
|
416
|
+
SpectralNorm(nn.Conv2d(ngf * mult, int(ngf * mult / 2),
|
|
417
|
+
kernel_size=3, stride=1, padding=0),use_spectral_norm=use_spectral_norm)]
|
|
418
|
+
elif upsample == 'pixel_shuffle':
|
|
419
|
+
upsample_layer = [SpectralNorm(nn.Conv2d(ngf * mult, int(ngf * mult * 2),use_spectral_norm=use_spectral_norm),
|
|
420
|
+
kernel_size=3, padding=1),
|
|
421
|
+
nn.PixelShuffle(2),
|
|
422
|
+
nn.ReLU()]
|
|
423
|
+
elif upsample == 'convtranspose':
|
|
424
|
+
upsample_layer = [SpectralNorm(nn.ConvTranspose2d(ngf * mult, int(ngf * mult / 2),
|
|
425
|
+
kernel_size=3, stride=2,
|
|
426
|
+
padding=1, output_padding=1,
|
|
427
|
+
bias=use_bias),
|
|
428
|
+
use_spectral_norm=use_spectral_norm)]
|
|
429
|
+
else:
|
|
430
|
+
raise Exception(f'upsample layer type {upsample} not implemented')
|
|
431
|
+
|
|
432
|
+
model += upsample_layer
|
|
433
|
+
model += [norm_layer(int(ngf * mult / 2)),
|
|
386
434
|
nn.ReLU(True)]
|
|
387
435
|
|
|
388
436
|
if padding_type == 'reflect':
|
|
@@ -390,7 +438,7 @@ class ResnetGenerator(nn.Module):
|
|
|
390
438
|
else:
|
|
391
439
|
model += [nn.ZeroPad2d(3)]
|
|
392
440
|
|
|
393
|
-
model += [nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0)]
|
|
441
|
+
model += [SpectralNorm(nn.Conv2d(ngf, output_nc, kernel_size=7, padding=0),use_spectral_norm=use_spectral_norm)]
|
|
394
442
|
model += [nn.Tanh()]
|
|
395
443
|
|
|
396
444
|
self.model = nn.Sequential(*model)
|
|
@@ -403,7 +451,7 @@ class ResnetGenerator(nn.Module):
|
|
|
403
451
|
class ResnetBlock(nn.Module):
|
|
404
452
|
"""Define a Resnet block"""
|
|
405
453
|
|
|
406
|
-
def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias):
|
|
454
|
+
def __init__(self, dim, padding_type, norm_layer, use_dropout, use_bias, use_spectral_norm=False):
|
|
407
455
|
"""Initialize the Resnet block
|
|
408
456
|
|
|
409
457
|
A resnet block is a conv block with skip connections
|
|
@@ -412,9 +460,9 @@ class ResnetBlock(nn.Module):
|
|
|
412
460
|
Original Resnet paper: https://arxiv.org/pdf/1512.03385.pdf
|
|
413
461
|
"""
|
|
414
462
|
super(ResnetBlock, self).__init__()
|
|
415
|
-
self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias)
|
|
463
|
+
self.conv_block = self.build_conv_block(dim, padding_type, norm_layer, use_dropout, use_bias, use_spectral_norm)
|
|
416
464
|
|
|
417
|
-
def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias):
|
|
465
|
+
def build_conv_block(self, dim, padding_type, norm_layer, use_dropout, use_bias, use_spectral_norm):
|
|
418
466
|
"""Construct a convolutional block.
|
|
419
467
|
|
|
420
468
|
Parameters:
|
|
@@ -437,7 +485,9 @@ class ResnetBlock(nn.Module):
|
|
|
437
485
|
else:
|
|
438
486
|
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
|
|
439
487
|
|
|
440
|
-
conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
|
|
488
|
+
conv_block += [SpectralNorm(nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),use_spectral_norm=use_spectral_norm),
|
|
489
|
+
norm_layer(dim),
|
|
490
|
+
nn.ReLU(True)]
|
|
441
491
|
if use_dropout:
|
|
442
492
|
conv_block += [nn.Dropout(0.5)]
|
|
443
493
|
|
|
@@ -450,7 +500,8 @@ class ResnetBlock(nn.Module):
|
|
|
450
500
|
p = 1
|
|
451
501
|
else:
|
|
452
502
|
raise NotImplementedError('padding [%s] is not implemented' % padding_type)
|
|
453
|
-
conv_block += [nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),
|
|
503
|
+
conv_block += [SpectralNorm(nn.Conv2d(dim, dim, kernel_size=3, padding=p, bias=use_bias),use_spectral_norm=use_spectral_norm),
|
|
504
|
+
norm_layer(dim)]
|
|
454
505
|
|
|
455
506
|
return nn.Sequential(*conv_block)
|
|
456
507
|
|
|
@@ -565,7 +616,7 @@ class UnetSkipConnectionBlock(nn.Module):
|
|
|
565
616
|
class NLayerDiscriminator(nn.Module):
|
|
566
617
|
"""Defines a PatchGAN discriminator"""
|
|
567
618
|
|
|
568
|
-
def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d):
|
|
619
|
+
def __init__(self, input_nc, ndf=64, n_layers=3, norm_layer=nn.BatchNorm2d, use_spectral_norm=False):
|
|
569
620
|
"""Construct a PatchGAN discriminator
|
|
570
621
|
|
|
571
622
|
Parameters:
|
|
@@ -582,14 +633,15 @@ class NLayerDiscriminator(nn.Module):
|
|
|
582
633
|
|
|
583
634
|
kw = 4
|
|
584
635
|
padw = 1
|
|
585
|
-
sequence = [nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),
|
|
636
|
+
sequence = [SpectralNorm(nn.Conv2d(input_nc, ndf, kernel_size=kw, stride=2, padding=padw),use_spectral_norm=use_spectral_norm),
|
|
637
|
+
nn.LeakyReLU(0.2, True)]
|
|
586
638
|
nf_mult = 1
|
|
587
639
|
nf_mult_prev = 1
|
|
588
640
|
for n in range(1, n_layers): # gradually increase the number of filters
|
|
589
641
|
nf_mult_prev = nf_mult
|
|
590
642
|
nf_mult = min(2 ** n, 8)
|
|
591
643
|
sequence += [
|
|
592
|
-
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),
|
|
644
|
+
SpectralNorm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=2, padding=padw, bias=use_bias),use_spectral_norm=use_spectral_norm),
|
|
593
645
|
norm_layer(ndf * nf_mult),
|
|
594
646
|
nn.LeakyReLU(0.2, True)
|
|
595
647
|
]
|
|
@@ -597,12 +649,12 @@ class NLayerDiscriminator(nn.Module):
|
|
|
597
649
|
nf_mult_prev = nf_mult
|
|
598
650
|
nf_mult = min(2 ** n_layers, 8)
|
|
599
651
|
sequence += [
|
|
600
|
-
nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),
|
|
652
|
+
SpectralNorm(nn.Conv2d(ndf * nf_mult_prev, ndf * nf_mult, kernel_size=kw, stride=1, padding=padw, bias=use_bias),use_spectral_norm=use_spectral_norm),
|
|
601
653
|
norm_layer(ndf * nf_mult),
|
|
602
654
|
nn.LeakyReLU(0.2, True)
|
|
603
655
|
]
|
|
604
656
|
|
|
605
|
-
sequence += [nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw)] # output 1 channel prediction map
|
|
657
|
+
sequence += [SpectralNorm(nn.Conv2d(ndf * nf_mult, 1, kernel_size=kw, stride=1, padding=padw),use_spectral_norm=use_spectral_norm)] # output 1 channel prediction map
|
|
606
658
|
self.model = nn.Sequential(*sequence)
|
|
607
659
|
|
|
608
660
|
def forward(self, input):
|
|
@@ -687,3 +739,25 @@ class VGGLoss(nn.Module):
|
|
|
687
739
|
for i in range(len(x_vgg)):
|
|
688
740
|
loss += self.weights[i] * self.criterion(x_vgg[i], y_vgg[i].detach())
|
|
689
741
|
return loss
|
|
742
|
+
|
|
743
|
+
|
|
744
|
+
class TotalVariationLoss(nn.Module):
    """
    Total variation of a 4-D tensor: the summed absolute differences between
    neighbours along dim 2 (i,j vs i+1,j) and along dim 3 (i,j vs i,j+1),
    divided by the total number of elements of the input.
    """
    def __init__(self):
        super().__init__()

    def forward(self, x):
        """Return the total variation of ``x`` averaged over all of its elements."""
        along_dim2 = (x[:, :, 1:, :] - x[:, :, :-1, :]).abs().sum()
        along_dim3 = (x[:, :, :, 1:] - x[:, :, :, :-1]).abs().sum()
        return (along_dim2 + along_dim3) / x.numel()
|
|
754
|
+
|
|
755
|
+
def SpectralNorm(x, use_spectral_norm=False):
    """
    Optionally wrap a module with nn.utils.parametrizations.spectral_norm.

    Parameters:
        x (nn.Module)            -- the module to wrap
        use_spectral_norm (bool) -- when falsy, ``x`` is returned untouched

    Returns the result of spectral_norm(x) when the flag is set, else ``x``.
    """
    if not use_spectral_norm:
        return x
    return nn.utils.parametrizations.spectral_norm(x)
|