deepliif 1.1.11__py3-none-any.whl → 1.1.13__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli.py +354 -67
- deepliif/data/__init__.py +7 -7
- deepliif/data/aligned_dataset.py +2 -3
- deepliif/data/unaligned_dataset.py +38 -19
- deepliif/models/CycleGAN_model.py +282 -0
- deepliif/models/DeepLIIFExt_model.py +47 -25
- deepliif/models/DeepLIIF_model.py +69 -19
- deepliif/models/SDG_model.py +57 -26
- deepliif/models/__init__ - different weighted.py +762 -0
- deepliif/models/__init__ - run_dask_multi dev.py +943 -0
- deepliif/models/__init__ - time gens.py +792 -0
- deepliif/models/__init__ - timings.py +764 -0
- deepliif/models/__init__.py +359 -265
- deepliif/models/att_unet.py +199 -0
- deepliif/models/base_model.py +32 -8
- deepliif/models/networks.py +108 -34
- deepliif/options/__init__.py +49 -5
- deepliif/postprocessing.py +1034 -227
- deepliif/postprocessing__OLD__DELETE.py +440 -0
- deepliif/util/__init__.py +86 -65
- deepliif/util/visualizer.py +106 -19
- {deepliif-1.1.11.dist-info → deepliif-1.1.13.dist-info}/METADATA +75 -24
- deepliif-1.1.13.dist-info/RECORD +42 -0
- deepliif-1.1.11.dist-info/RECORD +0 -35
- {deepliif-1.1.11.dist-info → deepliif-1.1.13.dist-info}/LICENSE.md +0 -0
- {deepliif-1.1.11.dist-info → deepliif-1.1.13.dist-info}/WHEEL +0 -0
- {deepliif-1.1.11.dist-info → deepliif-1.1.13.dist-info}/entry_points.txt +0 -0
- {deepliif-1.1.11.dist-info → deepliif-1.1.13.dist-info}/top_level.txt +0 -0
cli.py
CHANGED
@@ -8,11 +8,12 @@ import cv2
 import torch
 import numpy as np
 from PIL import Image
+from torchvision.transforms import ToPILImage
 
 from deepliif.data import create_dataset, transform
-from deepliif.models import init_nets, infer_modalities, infer_results_for_wsi, create_model
+from deepliif.models import init_nets, infer_modalities, infer_results_for_wsi, create_model, postprocess
 from deepliif.util import allowed_file, Visualizer, get_information, test_diff_original_serialized, disable_batchnorm_tracking_stats
-from deepliif.util.util import mkdirs
+from deepliif.util.util import mkdirs
 # from deepliif.util import infer_results_for_wsi
 from deepliif.options import Options, print_options
 
@@ -77,6 +78,9 @@ def cli():
 @click.option('--modalities-no', default=4, type=int, help='number of targets')
 # model parameters
 @click.option('--model', default='DeepLIIF', help='name of model class')
+@click.option('--seg-weights', default='', type=str, help='weights used to aggregate modality images for the final segmentation image; numbers should add up to 1, and each number corresponds to the modality in order; example: 0.25,0.15,0.25,0.1,0.25')
+@click.option('--loss-weights-g', default='', type=str, help='weights used to aggregate modality-wise losses for the final loss; numbers should add up to 1, and each number corresponds to the modality in order; example: 0.2,0.2,0.2,0.2,0.2')
+@click.option('--loss-weights-d', default='', type=str, help='weights used to aggregate modality-wise losses for the final loss; numbers should add up to 1, and each number corresponds to the modality in order; example: 0.2,0.2,0.2,0.2,0.2')
 @click.option('--input-nc', default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')
 @click.option('--output-nc', default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')
 @click.option('--ngf', default=64, help='# of gen filters in the last conv layer')
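The three weight options added above share one contract: a comma-separated list of floats, one per modality, summing to 1. A minimal sketch of that parsing rule (hypothetical helper, not code from this diff):

```python
# Hypothetical helper illustrating the --seg-weights / --loss-weights-* contract.
def parse_weights(spec: str, expected: int):
    weights = [float(x) for x in spec.split(',')]          # '0.25,0.15,...' -> floats
    assert len(weights) == expected, f'expected {expected} weights, got {len(weights)}'
    assert abs(sum(weights) - 1) < 1e-9, 'weights should add up to 1'
    return weights

print(parse_weights('0.25,0.15,0.25,0.1,0.25', 5))
```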
@@ -85,7 +89,7 @@ def cli():
               help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 '
                    'PatchGAN. n_layers allows you to specify the layers in the discriminator')
 @click.option('--net-g', default='resnet_9blocks',
-              help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_512 | unet_256 | unet_128]')
+              help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_512 | unet_256 | unet_128 | unet_512_attention]; to specify different arch for generators, list arch for each generator separated by comma, e.g., --net-g=resnet_9blocks,resnet_9blocks,resnet_9blocks,unet_512_attention,unet_512_attention')
 @click.option('--n-layers-d', default=4, help='only used if netD==n_layers')
 @click.option('--norm', default='batch',
               help='instance normalization or batch normalization [instance | batch | none]')
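The comma-list form of --net-g means one architecture per generator, with a single value broadcast to all of them. The expansion, restated from the train() hunk further down in this diff:

```python
# Restates the net_g broadcast logic added later in this diff.
modalities_no = 4
net_g = 'resnet_9blocks'.split(',')   # single arch -> ['resnet_9blocks']
if len(net_g) == 1:
    net_g = net_g * modalities_no     # one copy per translation generator
print(net_g)
```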
@@ -93,6 +97,8 @@ def cli():
               help='network initialization [normal | xavier | kaiming | orthogonal]')
 @click.option('--init-gain', default=0.02, help='scaling factor for normal, xavier and orthogonal.')
 @click.option('--no-dropout', is_flag=True, help='no dropout for the generator')
+@click.option('--upsample', default='convtranspose', help='use upsampling instead of convtranspose [convtranspose | resize_conv | pixel_shuffle]')
+@click.option('--label-smoothing', type=float,default=0.0, help='label smoothing factor to prevent the discriminator from being too confident')
 # dataset parameters
 @click.option('--direction', default='AtoB', help='AtoB or BtoA')
 @click.option('--serial-batches', is_flag=True,
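--label-smoothing softens the discriminator's targets so it cannot become perfectly confident. The diff only adds the option here; a minimal sketch of the usual one-sided GAN formulation it suggests (assumed, not taken from the model code):

```python
import torch
import torch.nn.functional as F

# Assumed formulation: with smoothing eps, 'real' targets become 1 - eps.
eps = 0.1
d_logits_on_real = torch.randn(4, 1)   # placeholder discriminator output
target_real = torch.full_like(d_logits_on_real, 1.0 - eps)
loss_d_real = F.binary_cross_entropy_with_logits(d_logits_on_real, target_real)
```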
@@ -128,12 +134,17 @@ def cli():
               help='number of epochs with the initial learning rate')
 @click.option('--n-epochs-decay', type=int, default=100,
               help='number of epochs to linearly decay learning rate to zero')
+@click.option('--optimizer', type=str, default='adam',
+              help='optimizer from torch.optim to use, applied to both generators and discriminators [adam | sgd | adamw | ...]; the current parameters however are set up for adam, so other optimziers may encounter issue')
 @click.option('--beta1', default=0.5, help='momentum term of adam')
-@click.option('--lr', default=0.0002, help='initial learning rate for adam')
+#@click.option('--lr', default=0.0002, help='initial learning rate for adam')
+@click.option('--lr-g', default=0.0002, help='initial learning rate for generator adam optimizer')
+@click.option('--lr-d', default=0.0002, help='initial learning rate for discriminator adam optimizer')
 @click.option('--lr-policy', default='linear',
               help='learning rate policy. [linear | step | plateau | cosine]')
 @click.option('--lr-decay-iters', type=int, default=50,
               help='multiply by a gamma every lr_decay_iters iterations')
+@click.option('--seed', type=int, default=None, help='basic seed to be used for deterministic training, default to None (non-deterministic)')
 # visdom and HTML visualization parameters
 @click.option('--display-freq', default=400, help='frequency of showing training results on screen')
 @click.option('--display-ncols', default=4,
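--optimizer plus the split --lr-g/--lr-d imply one optimizer per network family. A sketch of how such options typically map onto torch.optim (placeholder nets; the model classes in this release do the real wiring):

```python
import torch

netG, netD = torch.nn.Linear(8, 8), torch.nn.Linear(8, 1)  # placeholders
lr_g, lr_d, beta1 = 2e-4, 2e-4, 0.5                        # --lr-g, --lr-d, --beta1
opt_class = torch.optim.Adam                               # --optimizer adam
optimizer_G = opt_class(netG.parameters(), lr=lr_g, betas=(beta1, 0.999))
optimizer_D = opt_class(netD.parameters(), lr=lr_d, betas=(beta1, 0.999))
```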
@@ -158,26 +169,32 @@ def cli():
               help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]')
 @click.option('--padding', type=str, default='zero',
               help='chooses the type of padding used by resnet generator. [reflect | zero]')
-@click.option('--local-rank', type=int, default=None, help='placeholder argument for torchrun, no need for manual setup')
-@click.option('--seed', type=int, default=None, help='basic seed to be used for deterministic training, default to None (non-deterministic)')
 # DeepLIIFExt params
 @click.option('--seg-gen', type=bool, default=True, help='True (Translation and Segmentation), False (Only Translation).')
 @click.option('--net-ds', type=str, default='n_layers',
               help='specify discriminator architecture for segmentation task [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')
 @click.option('--net-gs', type=str, default='unet_512',
-              help='specify generator architecture for segmentation task [resnet_9blocks | resnet_6blocks | unet_512 | unet_256 | unet_128]')
+              help='specify generator architecture for segmentation task [resnet_9blocks | resnet_6blocks | unet_512 | unet_256 | unet_128 | unet_512_attention]; to specify different arch for generators, list arch for each generator separated by comma, e.g., --net-g=resnet_9blocks,resnet_9blocks,resnet_9blocks,unet_512_attention,unet_512_attention')
 @click.option('--gan-mode', type=str, default='vanilla',
               help='the type of GAN objective for translation task. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')
 @click.option('--gan-mode-s', type=str, default='lsgan',
               help='the type of GAN objective for segmentation task. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')
+# DDP related arguments
+@click.option('--local-rank', type=int, default=None, help='placeholder argument for torchrun, no need for manual setup')
+# Others
+@click.option('--with-val', is_flag=True,
+              help='use validation set to evaluate model performance at the end of each epoch')
+@click.option('--debug', is_flag=True,
+              help='debug mode, limits the number of data points per epoch to a small value')
+@click.option('--debug-data-size', default=10, type=int, help='data size per epoch used in debug mode; due to batch size, the epoch will be passed once the completed no. data points is greater than this value (e.g., for batch size 3, debug data size 10, the effective size used in training will be 12)')
 def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, ndf, net_d, net_g,
-          n_layers_d, norm, init_type, init_gain, no_dropout, direction, serial_batches, num_threads,
+          n_layers_d, norm, init_type, init_gain, no_dropout, upsample, label_smoothing, direction, serial_batches, num_threads,
           batch_size, load_size, crop_size, max_dataset_size, preprocess, no_flip, display_winsize, epoch, load_iter,
           verbose, lambda_l1, is_train, display_freq, display_ncols, display_id, display_server, display_env,
           display_port, update_html_freq, print_freq, no_html, save_latest_freq, save_epoch_freq, save_by_iter,
-          continue_train, epoch_count, phase, lr_policy, n_epochs, n_epochs_decay, beta1,
-          remote,
-          modalities_no, seg_gen, net_ds, net_gs, gan_mode, gan_mode_s):
+          continue_train, epoch_count, phase, lr_policy, n_epochs, n_epochs_decay, optimizer, beta1, lr_g, lr_d, lr_decay_iters,
+          remote, remote_transfer_cmd, seed, dataset_mode, padding, model, seg_weights, loss_weights_g, loss_weights_d,
+          modalities_no, seg_gen, net_ds, net_gs, gan_mode, gan_mode_s, local_rank, with_val, debug, debug_data_size):
     """General-purpose training script for multi-task image-to-image translation.
 
     This script works for various models (with option '--model': e.g., DeepLIIF) and
@@ -189,7 +206,7 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd
     plot, and save models.The script supports continue/resume training.
     Use '--continue_train' to resume your previous training.
     """
-    assert model in ['DeepLIIF','DeepLIIFExt','SDG'], f'model class {model} is not implemented'
+    assert model in ['DeepLIIF','DeepLIIFExt','SDG','CycleGAN'], f'model class {model} is not implemented'
     if model == 'DeepLIIF':
         seg_no = 1
     elif model == 'DeepLIIFExt':
@@ -197,10 +214,16 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd
             seg_no = modalities_no
         else:
             seg_no = 0
-    else: # SDG
+    else: # SDG, CycleGAN
         seg_no = 0
         seg_gen = False
 
+    if model == 'CycleGAN':
+        dataset_mode = "unaligned"
+
+    if optimizer != 'adam':
+        print(f'Optimizer torch.optim.{optimizer} is not tested. Be careful about the parameters of the optimizer.')
+
     d_params = locals()
 
     if gpu_ids and gpu_ids[0] == -1:
@@ -213,12 +236,12 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd
     if local_rank is not None:
         local_rank = int(local_rank)
         torch.cuda.set_device(gpu_ids[local_rank])
-        gpu_ids=[
+        gpu_ids=[local_rank]
     else:
         torch.cuda.set_device(gpu_ids[0])
 
     if local_rank is not None: # LOCAL_RANK will be assigned a rank number if torchrun ddp is used
-        dist.init_process_group(backend=
+        dist.init_process_group(backend="nccl", rank=int(os.environ['RANK']), world_size=int(os.environ['WORLD_SIZE']))
         print('local rank:',local_rank)
         flag_deterministic = set_seed(seed,local_rank)
     elif rank is not None:
@@ -231,29 +254,127 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd
         print('padding type is forced to zero padding, because neither refection pad2d or replication pad2d has a deterministic implementation')
 
     # infer number of input images
-    dir_data_train = dataroot + '/train'
-    fns = os.listdir(dir_data_train)
-    fns = [x for x in fns if x.endswith('.png')]
-    img = Image.open(f"{dir_data_train}/{fns[0]}")
 
-    num_img = img.size[0] / img.size[1]
-    assert int(num_img) == num_img, f'img size {img.size[0]} / {img.size[1]} = {num_img} is not an integer'
-    num_img = int(num_img)
 
-
-
+    if dataset_mode == 'unaligned':
+        dir_data_train = dataroot + '/trainA'
+        fns = os.listdir(dir_data_train)
+        fns = [x for x in fns if x.endswith('.png')]
+        print(f'{len(fns)} images found in trainA')
+        img = Image.open(f"{dir_data_train}/{fns[0]}")
+        print(f'image shape:',img.size)
+
+        for i in range(1, modalities_no + 1):
+            dir_data_train = dataroot + f'/trainB{i}'
+            fns = os.listdir(dir_data_train)
+            fns = [x for x in fns if x.endswith('.png')]
+            print(f'{len(fns)} images found in trainB{i}')
+            img = Image.open(f"{dir_data_train}/{fns[0]}")
+            print(f'image shape:',img.size)
+
+        input_no = 1
+        num_img = None
+
+        lambda_identity = 0
+        pool_size = 50 # https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/scripts/train_cyclegan.sh
+
+    else:
+        dir_data_train = dataroot + '/train'
+        fns = os.listdir(dir_data_train)
+        fns = [x for x in fns if x.endswith('.png')]
+        print(f'{len(fns)} images found')
+        img = Image.open(f"{dir_data_train}/{fns[0]}")
+        print(f'image shape:',img.size)
+
+        num_img = img.size[0] / img.size[1]
+        assert int(num_img) == num_img, f'img size {img.size[0]} / {img.size[1]} = {num_img} is not an integer'
+        num_img = int(num_img)
+
+        input_no = num_img - modalities_no - seg_no
+        assert input_no > 0, f'inferred number of input images is {input_no} (modalities_no {modalities_no}, seg_no {seg_no}); should be greater than 0'
+
+        pool_size = 0
+
     d_params['input_no'] = input_no
     d_params['scale_size'] = img.size[1]
+    d_params['gpu_ids'] = gpu_ids
+    d_params['lambda_identity'] = 0
+    d_params['pool_size'] = pool_size
+
+
+    # update generator arch
+    net_g = net_g.split(',')
+    assert len(net_g) in [1,modalities_no], f'net_g should contain either 1 architecture for all translation generators or the same number of architectures as the number of translation generators ({modalities_no})'
+    if len(net_g) == 1:
+        net_g = net_g*modalities_no
+
+    net_gs = net_gs.split(',')
+    assert len(net_gs) in [1,seg_no], f'net_gs should contain either 1 architecture for all segmentation generators or the same number of architectures as the number of segmentation generators ({seg_no})'
+    if len(net_gs) == 1 and model == 'DeepLIIF':
+        net_gs = net_gs*(modalities_no + seg_no)
+    elif len(net_gs) == 1:
+        net_gs = net_gs*seg_no
+
+    d_params['net_g'] = net_g
+    d_params['net_gs'] = net_gs
+
+    # check seg weights and loss weights
+    if len(d_params['seg_weights']) == 0:
+        seg_weights = [0.25,0.15,0.25,0.1,0.25] if d_params['model'] == 'DeepLIIF' else [1 / modalities_no] * modalities_no
+    else:
+        seg_weights = [float(x) for x in seg_weights.split(',')]
+
+    if len(d_params['loss_weights_g']) == 0:
+        loss_weights_g = [0.2]*5 if d_params['model'] == 'DeepLIIF' else [1 / modalities_no] * modalities_no
+    else:
+        loss_weights_g = [float(x) for x in loss_weights_g.split(',')]
+
+    if len(d_params['loss_weights_d']) == 0:
+        loss_weights_d = [0.2]*5 if d_params['model'] == 'DeepLIIF' else [1 / modalities_no] * modalities_no
+    else:
+        loss_weights_d = [float(x) for x in loss_weights_d.split(',')]
+
+    assert sum(seg_weights) == 1, 'seg weights should add up to 1'
+    assert sum(loss_weights_g) == 1, 'loss weights g should add up to 1'
+    assert sum(loss_weights_d) == 1, 'loss weights d should add up to 1'
+
+    if model == 'DeepLIIF':
+        # +1 because input becomes an additional modality used in generating the final segmentation
+        assert len(seg_weights) == modalities_no+1, 'seg weights should have the same number of elements as number of modalities to be generated'
+        assert len(loss_weights_g) == modalities_no+1, 'loss weights g should have the same number of elements as number of modalities to be generated'
+        assert len(loss_weights_d) == modalities_no+1, 'loss weights d should have the same number of elements as number of modalities to be generated'
 
+    else:
+        assert len(seg_weights) == modalities_no, 'seg weights should have the same number of elements as number of modalities to be generated'
+        assert len(loss_weights_g) == modalities_no, 'loss weights g should have the same number of elements as number of modalities to be generated'
+        assert len(loss_weights_d) == modalities_no, 'loss weights d should have the same number of elements as number of modalities to be generated'
+
+    d_params['seg_weights'] = seg_weights
+    d_params['loss_G_weights'] = loss_weights_g
+    d_params['loss_D_weights'] = loss_weights_d
+
+    del d_params['loss_weights_g']
+    del d_params['loss_weights_d']
+
     # create a dataset given dataset_mode and other options
     # dataset = AlignedDataset(opt)
 
     opt = Options(d_params=d_params)
     print_options(opt, save=True)
 
+    # set dir for train and val
     dataset = create_dataset(opt)
+
     # get the number of images in the dataset.
     click.echo('The number of training images = %d' % len(dataset))
+
+    if with_val:
+        dataset_val = create_dataset(opt,phase='val')
+        data_val = [batch for batch in dataset_val]
+        click.echo('The number of validation images = %d' % len(dataset_val))
+
+        if model in ['DeepLIIF']:
+            metrics_val = json.load(open(os.path.join(dataset_val.dataset.dir_AB,'metrics.json')))
 
     # create a model given model and other options
     model = create_model(opt)
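For aligned data, the hunk above infers how many panels a training tile holds from its aspect ratio, then derives the input count. A worked example of that arithmetic:

```python
# A 3072x512 tile holds 3072/512 = 6 side-by-side 512x512 panels;
# with modalities_no=4 and seg_no=1, input_no = 6 - 4 - 1 = 1 input image.
width, height = 3072, 512
num_img = width // height
modalities_no, seg_no = 4, 1
input_no = num_img - modalities_no - seg_no
assert input_no > 0
print(num_img, input_no)  # 6 1
```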
@@ -299,15 +420,15 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd
             if total_iters % display_freq == 0:
                 save_result = total_iters % update_html_freq == 0
                 model.compute_visuals()
-                visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)
+                visualizer.display_current_results({**model.get_current_visuals()}, epoch, save_result)
 
             # print training losses and save logging information to the disk
             if total_iters % print_freq == 0:
-                losses = model.get_current_losses()
+                losses = model.get_current_losses() # get training losses
                 t_comp = (time.time() - iter_start_time) / batch_size
-                visualizer.print_current_losses(epoch, epoch_iter, losses, t_comp, t_data)
+                visualizer.print_current_losses(epoch, epoch_iter, {**losses}, t_comp, t_data)
                 if display_id > 0:
-                    visualizer.plot_current_losses(epoch, float(epoch_iter) / len(dataset), losses)
+                    visualizer.plot_current_losses(epoch, float(epoch_iter) / len(dataset), {**losses})
 
             # cache our latest model every <save_latest_freq> iterations
             if total_iters % save_latest_freq == 0:
@@ -315,7 +436,11 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd
                 save_suffix = 'iter_%d' % total_iters if save_by_iter else 'latest'
                 model.save_networks(save_suffix)
 
+
             iter_data_time = time.time()
+            if debug and epoch_iter >= debug_data_size:
+                print(f'debug mode, epoch {epoch} stopped at epoch iter {epoch_iter} (>= {debug_data_size})')
+                break
 
         # cache our model every <save_epoch_freq> epochs
         if epoch % save_epoch_freq == 0:
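The break above fires only after a full batch completes, which is why the --debug-data-size help text rounds up. The effective epoch size:

```python
import math

# --batch-size 3, --debug-data-size 10: epoch_iter grows in steps of 3
# and first reaches >= 10 after 4 batches, i.e., 12 images.
batch_size, debug_data_size = 3, 10
effective = math.ceil(debug_data_size / batch_size) * batch_size
print(effective)  # 12
```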
@@ -323,6 +448,77 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd
             model.save_networks('latest')
             model.save_networks(epoch)
 
+
+
+        # validation loss and metrics calculation
+        if with_val:
+            losses = model.get_current_losses() # get training losses to print
+
+            model.eval()
+            l_losses_val = []
+            l_metrics_val = []
+
+            # for each val image, calculate validation loss and cell count metrics
+            for j, data_val_batch in enumerate(data_val):
+                # batch size is effectively 1 for validation
+                model.set_input(data_val_batch)
+                model.calculate_losses() # this does not optimize parameters
+                visuals = model.get_current_visuals() # get image results
+
+                # val losses
+                losses_val_batch = model.get_current_losses()
+                l_losses_val += [(k,v) for k,v in losses_val_batch.items()]
+
+                # calculate cell count metrics
+                if type(model).__name__ == 'DeepLIIFModel':
+                    l_seg_names = ['fake_B_5']
+                    assert l_seg_names[0] in visuals.keys(), f'Cannot find {l_seg_names[0]} in generated image names ({list(visuals.keys())})'
+                    seg_mod_suffix = l_seg_names[0].split('_')[-1]
+                    l_seg_names += [x for x in visuals.keys() if x.startswith('fake') and x.split('_')[-1].startswith(seg_mod_suffix) and x != l_seg_names[0]]
+                    # print(f'Running postprocess for {len(l_seg_names)} generated images ({l_seg_names})')
+
+                    img_name_current = data_val_batch['A_paths'][0].split('/')[-1][:-4] # remove .png
+                    metrics_gt = metrics_val[img_name_current]
+
+                    for seg_name in l_seg_names:
+                        images = {'Seg':ToPILImage()((visuals[seg_name][0].cpu()+1)/2),
+                                  #'Marker':ToPILImage()((visuals['fake_B_4'][0].cpu()+1)/2)
+                                  }
+                        _, scoring = postprocess(ToPILImage()((data['A'][0]+1)/2), images, opt.scale_size, opt.model)
+
+                        for k,v in scoring.items():
+                            if k.startswith('num') or k.startswith('percent'):
+                                # to calculate the rmse, here we calculate (x_pred - x_true) ** 2
+                                l_metrics_val.append((k+'_'+seg_name,(v - metrics_gt[k])**2))
+
+                if debug and epoch_iter >= debug_data_size:
+                    print(f'debug mode, epoch {epoch} stopped at epoch iter {epoch_iter} (>= {debug_data_size})')
+                    break
+
+            d_losses_val = {k+'_val':0 for k in losses_val_batch.keys()}
+            for k,v in l_losses_val:
+                d_losses_val[k+'_val'] += v
+            for k in d_losses_val:
+                d_losses_val[k] = d_losses_val[k] / len(data_val)
+
+            d_metrics_val = {}
+            for k,v in l_metrics_val:
+                try:
+                    d_metrics_val[k] += v
+                except:
+                    d_metrics_val[k] = v
+            for k in d_metrics_val:
+                # to calculate the rmse, this is the second part, where d_metrics_val[k] now represents sum((x_pred - x_true) ** 2)
+                d_metrics_val[k] = np.sqrt(d_metrics_val[k] / len(data_val))
+
+
+            model.train()
+            t_comp = (time.time() - iter_start_time) / batch_size
+            visualizer.print_current_losses(epoch, epoch_iter, {**losses,**d_losses_val, **d_metrics_val}, t_comp, t_data)
+            if display_id > 0:
+                visualizer.plot_current_losses(epoch, float(epoch_iter) / len(dataset), {**losses,**d_losses_val,**d_metrics_val})
+
+
         print('End of epoch %d / %d \t Time Taken: %d sec' % (
             epoch, n_epochs + n_epochs_decay, time.time() - epoch_start_time))
         # update learning rates at the end of every epoch.
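The validation block accumulates per-image squared errors for each cell-count metric, then averages and square-roots them, i.e., an RMSE computed in two passes. Compactly:

```python
import numpy as np

# e.g., predicted vs. ground-truth cell counts over three validation images
preds, truths = [3, 4, 7], [5, 4, 4]
sq_errors = [(p - t) ** 2 for p, t in zip(preds, truths)]  # first pass: (x_pred - x_true)**2
rmse = np.sqrt(sum(sq_errors) / len(sq_errors))            # second pass: sqrt(mean)
print(round(float(rmse), 3))                               # 2.082
```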
@@ -336,8 +532,12 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd
               help='name of the experiment. It decides where to store samples and models')
 @click.option('--gpu-ids', type=int, multiple=True, help='gpu-ids 0 gpu-ids 1 or gpu-ids -1 for CPU')
 @click.option('--checkpoints-dir', default='./checkpoints', help='models are saved here')
-@click.option('--
+@click.option('--modalities-no', default=4, type=int, help='number of targets')
 # model parameters
+@click.option('--model', default='DeepLIIF', help='name of model class')
+@click.option('--seg-weights', default='', type=str, help='weights used to aggregate modality images for the final segmentation image; numbers should add up to 1, and each number corresponds to the modality in order; example: 0.25,0.15,0.25,0.1,0.25')
+@click.option('--loss-weights-g', default='', type=str, help='weights used to aggregate modality-wise losses for the final loss; numbers should add up to 1, and each number corresponds to the modality in order; example: 0.2,0.2,0.2,0.2,0.2')
+@click.option('--loss-weights-d', default='', type=str, help='weights used to aggregate modality-wise losses for the final loss; numbers should add up to 1, and each number corresponds to the modality in order; example: 0.2,0.2,0.2,0.2,0.2')
 @click.option('--input-nc', default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')
 @click.option('--output-nc', default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')
 @click.option('--ngf', default=64, help='# of gen filters in the last conv layer')
@@ -346,15 +546,16 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd
               help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 '
                    'PatchGAN. n_layers allows you to specify the layers in the discriminator')
 @click.option('--net-g', default='resnet_9blocks',
-              help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_512 | unet_256 | unet_128]')
+              help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_512 | unet_256 | unet_128 | unet_512_attention]; to specify different arch for generators, list arch for each generator separated by comma, e.g., --net-g=resnet_9blocks,resnet_9blocks,resnet_9blocks,unet_512_attention,unet_512_attention')
 @click.option('--n-layers-d', default=4, help='only used if netD==n_layers')
 @click.option('--norm', default='batch',
               help='instance normalization or batch normalization [instance | batch | none]')
 @click.option('--init-type', default='normal',
               help='network initialization [normal | xavier | kaiming | orthogonal]')
 @click.option('--init-gain', default=0.02, help='scaling factor for normal, xavier and orthogonal.')
-@click.option('--padding-type', default='reflect', help='network padding type.')
 @click.option('--no-dropout', is_flag=True, help='no dropout for the generator')
+@click.option('--upsample', default='convtranspose', help='use upsampling instead of convtranspose [convtranspose | resize_conv | pixel_shuffle]')
+@click.option('--label-smoothing', type=float,default=0.0, help='label smoothing factor to prevent the discriminator from being too confident')
 # dataset parameters
 @click.option('--direction', default='AtoB', help='AtoB or BtoA')
 @click.option('--serial-batches', is_flag=True,
@@ -390,12 +591,17 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd
               help='number of epochs with the initial learning rate')
 @click.option('--n-epochs-decay', type=int, default=100,
               help='number of epochs to linearly decay learning rate to zero')
+@click.option('--optimizer', type=str, default='adam',
+              help='optimizer from torch.optim to use, applied to both generators and discriminators [adam | sgd | adamw | ...]; the current parameters however are set up for adam, so other optimziers may encounter issue')
 @click.option('--beta1', default=0.5, help='momentum term of adam')
-@click.option('--lr', default=0.0002, help='initial learning rate for adam')
+#@click.option('--lr', default=0.0002, help='initial learning rate for adam')
+@click.option('--lr-g', default=0.0002, help='initial learning rate for generator adam optimizer')
+@click.option('--lr-d', default=0.0002, help='initial learning rate for discriminator adam optimizer')
 @click.option('--lr-policy', default='linear',
               help='learning rate policy. [linear | step | plateau | cosine]')
 @click.option('--lr-decay-iters', type=int, default=50,
               help='multiply by a gamma every lr_decay_iters iterations')
+@click.option('--seed', type=int, default=None, help='basic seed to be used for deterministic training, default to None (non-deterministic)')
 # visdom and HTML visualization parameters
 @click.option('--display-freq', default=400, help='frequency of showing training results on screen')
 @click.option('--display-ncols', default=4,
@@ -416,8 +622,29 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd
 @click.option('--save-by-iter', is_flag=True, help='whether saves model by iteration')
 @click.option('--remote', type=bool, default=False, help='whether isolate visdom checkpoints or not; if False, you can run a separate visdom server anywhere that consumes the checkpoints')
 @click.option('--remote-transfer-cmd', type=str, default=None, help='module and function to be used to transfer remote files to target storage location, for example mymodule.myfunction')
+@click.option('--dataset-mode', type=str, default='aligned',
+              help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]')
+@click.option('--padding', type=str, default='zero',
+              help='chooses the type of padding used by resnet generator. [reflect | zero]')
+# DeepLIIFExt params
+@click.option('--seg-gen', type=bool, default=True, help='True (Translation and Segmentation), False (Only Translation).')
+@click.option('--net-ds', type=str, default='n_layers',
+              help='specify discriminator architecture for segmentation task [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')
+@click.option('--net-gs', type=str, default='unet_512',
+              help='specify generator architecture for segmentation task [resnet_9blocks | resnet_6blocks | unet_512 | unet_256 | unet_128 | unet_512_attention]; to specify different arch for generators, list arch for each generator separated by comma, e.g., --net-g=resnet_9blocks,resnet_9blocks,resnet_9blocks,unet_512_attention,unet_512_attention')
+@click.option('--gan-mode', type=str, default='vanilla',
+              help='the type of GAN objective for translation task. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')
+@click.option('--gan-mode-s', type=str, default='lsgan',
+              help='the type of GAN objective for segmentation task. [vanilla| lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')
+# DDP related arguments
 @click.option('--local-rank', type=int, default=None, help='placeholder argument for torchrun, no need for manual setup')
-
+# Others
+@click.option('--with-val', is_flag=True,
+              help='use validation set to evaluate model performance at the end of each epoch')
+@click.option('--debug', is_flag=True,
+              help='debug mode, limits the number of data points per epoch to a small value')
+@click.option('--debug-data-size', default=10, type=int, help='data size per epoch used in debug mode; due to batch size, the epoch will be passed once the completed no. data points is greater than this value (e.g., for batch size 3, debug data size 10, the effective size used in training will be 12)')
+# trainlaunch DDP related arguments
 @click.option('--use-torchrun', type=str, default=None, help='provide torchrun options, all in one string, for example "-t3 --log_dir ~/log/ --nproc_per_node 1"; if your pytorch version is older than 1.10, torch.distributed.launch will be called instead of torchrun')
 def trainlaunch(**kwargs):
     """
@@ -448,6 +675,7 @@ def trainlaunch(**kwargs):
         elif args[i-1] not in l_arg_skip and arg not in l_arg_skip:
             # if the previous element is not an option name to skip AND if the current element is not an option to remove
             args_final.append(arg)
+
 
     ## add quotes back to the input arg that had quotes, e.g., experiment name
     args_final = [f'"{arg}"' if ' ' in arg else arg for arg in args_final]
@@ -457,16 +685,29 @@ def trainlaunch(**kwargs):
 
     #### locate train.py
     import deepliif
-    path_train_py = deepliif.__path__[0]+'/train.py'
+    path_train_py = deepliif.__path__[0]+'/scripts/train.py'
+
+    #### find out GPUs to use
+    gpu_ids = [args_final[i+1] for i,v in enumerate(args_final) if v=='--gpu-ids']
+    if len(gpu_ids) > 0 and gpu_ids[0] == -1:
+        gpu_ids = []
+
+    if len(gpu_ids) > 0:
+        opt_env = f"CUDA_VISIBLE_DEVICES=\"{','.join(gpu_ids)}\""
+    else:
+        opt_env = ''
 
     #### execute train.py
     if kwargs['use_torchrun']:
         if version.parse(torch.__version__) >= version.parse('1.10.0'):
-
+            cmd = f'{opt_env} torchrun {kwargs["use_torchrun"]} {path_train_py} {options}'
         else:
-
+            cmd = f'{opt_env} python -m torch.distributed.launch {kwargs["use_torchrun"]} {path_train_py} {options}'
     else:
-
+        cmd = f'{opt_env} python {path_train_py} {options}'
+
+    print('Executing command:',cmd)
+    subprocess.run(cmd,shell=True)
 
 
 
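Putting the hunk together, the assembled command has this shape (illustrative values; the train.py path depends on where the package is installed):

```python
# Illustrative assembly mirroring the f-strings above (values are examples).
opt_env = 'CUDA_VISIBLE_DEVICES="0,1"'
use_torchrun = '--nproc_per_node 2'
path_train_py = '/path/to/site-packages/deepliif/scripts/train.py'
options = '--dataroot ./datasets/DeepLIIF --name my_experiment'
cmd = f'{opt_env} torchrun {use_torchrun} {path_train_py} {options}'
print(cmd)
```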
@@ -475,9 +716,10 @@ def trainlaunch(**kwargs):
 @click.option('--model-dir', default='./model-server/DeepLIIF_Latest_Model', help='reads models from here')
 @click.option('--output-dir', help='saves results here.')
 #@click.option('--tile-size', type=int, default=None, help='tile size')
-@click.option('--device', default='cpu', type=str, help='device to load model for the similarity test, either cpu or gpu')
+@click.option('--device', default='cpu', type=str, help='device to run serialization as well as load model for the similarity test, either cpu or gpu')
+@click.option('--epoch', default='latest', type=str, help='epoch to load and serialize')
 @click.option('--verbose', default=0, type=int,help='saves results here.')
-def serialize(model_dir, output_dir, device, verbose):
+def serialize(model_dir, output_dir, device, epoch, verbose):
     """Serialize DeepLIIF models using Torchscript
     """
     #if tile_size is None:
@@ -490,12 +732,20 @@ def serialize(model_dir, output_dir, device, verbose):
     if model_dir != output_dir:
         shutil.copy(f'{model_dir}/train_opt.txt',f'{output_dir}/train_opt.txt')
 
+    # load and update opt for serialization
     opt = Options(path_file=os.path.join(model_dir,'train_opt.txt'), mode='test')
+    opt.epoch = epoch
+    if device == 'gpu':
+        opt.gpu_ids = [0] # use gpu 0, in case training was done on larger machines
+    else:
+        opt.gpu_ids = [] # use cpu
+
+    print_options(opt)
     sample = transform(Image.new('RGB', (opt.scale_size, opt.scale_size)))
     sample = torch.cat([sample]*opt.input_no, 1)
 
     with click.progressbar(
-            init_nets(model_dir, eager_mode=True, phase='test').items(),
+            init_nets(model_dir, eager_mode=True, opt=opt, phase='test').items(),
             label='Tracing nets',
             item_show_func=lambda n: n[0] if n else n
     ) as bar:
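For context, the 'Tracing nets' step serializes each network with Torchscript tracing against the blank sample built above. A self-contained sketch with a toy net standing in for a DeepLIIF generator:

```python
import torch

net = torch.nn.Conv2d(3, 3, kernel_size=3, padding=1).eval()  # stand-in network
sample = torch.rand(1, 3, 512, 512)    # analogous to transform(Image.new('RGB', ...))
traced = torch.jit.trace(net, sample)  # record ops for the fixed input shape
traced.save('net.pt')                  # roughly what the serialized model dir holds
```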
@@ -514,8 +764,9 @@ def serialize(model_dir, output_dir, device, verbose):
 
     # test: whether the original and the serialized model produces highly similar predictions
     print('testing similarity between prediction from original vs serialized models...')
-    models_original = init_nets(model_dir,eager_mode=True,phase='test')
-    models_serialized = init_nets(output_dir,eager_mode=False,phase='test')
+    models_original = init_nets(model_dir,eager_mode=True,opt=opt,phase='test')
+    models_serialized = init_nets(output_dir,eager_mode=False,opt=opt,phase='test')
+
     if device == 'gpu':
         sample = sample.cuda()
     else:
@@ -523,7 +774,7 @@ def serialize(model_dir, output_dir, device, verbose):
     for name in models_serialized.keys():
         print(name,':')
         model_original = models_original[name].cuda().eval() if device=='gpu' else models_original[name].cpu().eval()
-        model_serialized = models_serialized[name].cuda() if device=='gpu' else models_serialized[name].cpu().eval()
+        model_serialized = models_serialized[name].cuda().eval() if device=='gpu' else models_serialized[name].cpu().eval()
         if name.startswith('GS'):
             test_diff_original_serialized(model_original,model_serialized,torch.cat([sample, sample, sample], 1),verbose)
         else:
@@ -534,28 +785,44 @@ def serialize(model_dir, output_dir, device, verbose):
 @cli.command()
 @click.option('--input-dir', default='./Sample_Large_Tissues/', help='reads images from here')
 @click.option('--output-dir', help='saves results here.')
-@click.option('--tile-size',
+@click.option('--tile-size', type=click.IntRange(min=1, max=None), required=True, help='tile size')
 @click.option('--model-dir', default='./model-server/DeepLIIF_Latest_Model/', help='load models from here.')
+@click.option('--filename-pattern', default='*', help='run inference on files of which the name matches the pattern.')
 @click.option('--gpu-ids', type=int, multiple=True, help='gpu-ids 0 gpu-ids 1 or gpu-ids -1 for CPU')
-@click.option('--region-size', default=20000, help='Due to limits in the resources, the whole slide image cannot be processed in whole.'
-              'So the WSI image is read region by region. '
-              'This parameter specifies the size each region to be read into GPU for inferrence.')
 @click.option('--eager-mode', is_flag=True, help='use eager mode (loading original models, otherwise serialized ones)')
+@click.option('--epoch', default='latest',
+              help='for eager mode, which epoch to load? set to latest to use latest cached model')
+@click.option('--seg-intermediate', is_flag=True, help='also save intermediate segmentation images (currently only applies to DeepLIIF model)')
+@click.option('--seg-only', is_flag=True, help='save only the final segmentation image (currently only applies to DeepLIIF model); overwrites --seg-intermediate')
 @click.option('--color-dapi', is_flag=True, help='color dapi image to produce the same coloring as in the paper')
 @click.option('--color-marker', is_flag=True, help='color marker image to produce the same coloring as in the paper')
-
-
+@click.option('--BtoA', is_flag=True, help='for models trained with unaligned dataset, this flag instructs to load generatorB instead of generatorA')
+def test(input_dir, output_dir, tile_size, model_dir, filename_pattern, gpu_ids, eager_mode, epoch,
+         seg_intermediate, seg_only, color_dapi, color_marker, btoa):
 
     """Test trained models
     """
     output_dir = output_dir or input_dir
     ensure_exists(output_dir)
+
+    if seg_intermediate and seg_only:
+        seg_intermediate = False
 
-
+    if filename_pattern == '*':
+        print('use all alowed files')
+        image_files = [fn for fn in os.listdir(input_dir) if allowed_file(fn)]
+    else:
+        import glob
+        print('match files using filename pattern',filename_pattern)
+        image_files = [os.path.basename(f) for f in glob.glob(os.path.join(input_dir, filename_pattern))]
+    print(len(image_files),'image files')
+
     files = os.listdir(model_dir)
     assert 'train_opt.txt' in files, f'file train_opt.txt is missing from model directory {model_dir}'
     opt = Options(path_file=os.path.join(model_dir,'train_opt.txt'), mode='test')
     opt.use_dp = False
+    opt.BtoA = btoa
+    opt.epoch = epoch
 
     number_of_gpus_all = torch.cuda.device_count()
     if number_of_gpus_all < len(gpu_ids) and -1 not in gpu_ids:
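The new --filename-pattern option falls back to glob matching when it is not '*'; for example (directory contents assumed):

```python
import glob
import os

# e.g., --filename-pattern '*_IHC.*' keeps only the IHC inputs in input_dir
input_dir, pattern = './Sample_Large_Tissues', '*_IHC.*'
image_files = [os.path.basename(f) for f in glob.glob(os.path.join(input_dir, pattern))]
print(len(image_files), 'image files')
```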
@@ -582,26 +849,41 @@ def test(input_dir, output_dir, tile_size, model_dir, gpu_ids, region_size, eage
             item_show_func=lambda fn: fn
     ) as bar:
         for filename in bar:
-
-
-
-
-
-
-
+            img = Image.open(os.path.join(input_dir, filename)).convert('RGB')
+            images, scoring = infer_modalities(img, tile_size, model_dir, eager_mode, color_dapi, color_marker, opt, return_seg_intermediate=seg_intermediate, seg_only=seg_only)
+
+            for name, i in images.items():
+                i.save(os.path.join(
+                    output_dir,
+                    filename.replace('.' + filename.split('.')[-1], f'_{name}.png')
+                ))
 
-
-
+            if scoring is not None:
+                with open(os.path.join(
                     output_dir,
-                    filename.replace('.' + filename.split('.')[-1], f'
-
+                    filename.replace('.' + filename.split('.')[-1], f'.json')
+                ), 'w') as f:
+                    json.dump(scoring, f, indent=2)
+
+
+@cli.command()
+@click.option('--input-dir', required=True, help='directory containing WSI file')
+@click.option('--filename', required=True, help='name of WSI to read')
+@click.option('--output-dir', required=True, help='saves results here.')
+@click.option('--tile-size', type=click.IntRange(min=1, max=None), required=True, help='tile size')
+@click.option('--model-dir', default='./model-server/DeepLIIF_Latest_Model/', help='load models from here.')
+@click.option('--region-size', default=20000, help='Due to limits in the resources, the whole slide image cannot be processed in whole.'
+              'So the WSI image is read region by region. '
+              'This parameter specifies the size each region to be read into GPU for inferrence.')
+@click.option('--seg-intermediate', is_flag=True, help='also save intermediate segmentation images (currently only applies to DeepLIIF model)')
+@click.option('--seg-only', is_flag=True, help='save only the final segmentation image (currently only applies to DeepLIIF model)')
+@click.option('--color-dapi', is_flag=True, help='color dapi image to produce the same coloring as in the paper')
+@click.option('--color-marker', is_flag=True, help='color marker image to produce the same coloring as in the paper')
+def test_wsi(input_dir, filename, output_dir, tile_size, model_dir, region_size, seg_intermediate, seg_only, color_dapi, color_marker):
+    infer_results_for_wsi(input_dir, filename, output_dir, model_dir, tile_size, region_size,
+                          color_dapi=color_dapi, color_marker=color_marker,
+                          seg_intermediate=seg_intermediate, seg_only=seg_only)
 
-        if scoring is not None:
-            with open(os.path.join(
-                output_dir,
-                filename.replace('.' + filename.split('.')[-1], f'.json')
-            ), 'w') as f:
-                json.dump(scoring, f, indent=2)
 
 @cli.command()
 @click.option('--input-dir', type=str, required=True, help='Path to input images')
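--region-size bounds how much of a slide is loaded at once, so the number of regions read grows with slide dimensions; roughly:

```python
import math

# Sketch: a 100,000 x 80,000 WSI at --region-size 20000 is read in 5 x 4 regions.
wsi_w, wsi_h, region_size = 100_000, 80_000, 20_000
n_regions = math.ceil(wsi_w / region_size) * math.ceil(wsi_h / region_size)
print(n_regions)  # 20
```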
@@ -721,4 +1003,9 @@ def visualize(pickle_dir, display_env):
 
 
 if __name__ == '__main__':
+    # tensor float 32 is available on nvidia ampere cards (e.g, a100, a40) and provides better performance at the cost of a bit lower precision
+    # in 1.7-1.11, pytorch by default enables tf32 when possible
+    # currently convolutions still uses tf32 by default while matmul does not and needs to be enabled manually
+    # see this issue for a discussion: https://github.com/pytorch/pytorch/issues/67384
+    torch.backends.cuda.matmul.allow_tf32 = True
     cli()
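The matmul flag set above has a cuDNN counterpart for convolutions; in the PyTorch versions the comment cites, the convolution flag already defaults to True:

```python
import torch

torch.backends.cuda.matmul.allow_tf32 = True  # set by this release at CLI startup
torch.backends.cudnn.allow_tf32 = True        # default True; shown for completeness
```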