deepliif 1.1.13__tar.gz → 1.1.14__tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- {deepliif-1.1.13/deepliif.egg-info → deepliif-1.1.14}/PKG-INFO +2 -2
- {deepliif-1.1.13 → deepliif-1.1.14}/README.md +1 -1
- {deepliif-1.1.13 → deepliif-1.1.14}/cli.py +15 -22
- {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/data/aligned_dataset.py +2 -2
- deepliif-1.1.14/deepliif/models/DeepLIIFKD_model.py +409 -0
- deepliif-1.1.13/deepliif/models/__init__.py → deepliif-1.1.14/deepliif/models/__init__ - weights, empty, zarr, tile count.py +37 -5
- deepliif-1.1.14/deepliif/models/__init__.py +817 -0
- {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/models/base_model.py +1 -1
- {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/models/networks.py +7 -5
- {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/postprocessing.py +55 -24
- deepliif-1.1.14/deepliif/util/checks.py +17 -0
- {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/util/util.py +42 -0
- {deepliif-1.1.13 → deepliif-1.1.14/deepliif.egg-info}/PKG-INFO +2 -2
- {deepliif-1.1.13 → deepliif-1.1.14}/deepliif.egg-info/SOURCES.txt +3 -0
- {deepliif-1.1.13 → deepliif-1.1.14}/setup.cfg +1 -1
- {deepliif-1.1.13 → deepliif-1.1.14}/setup.py +1 -1
- {deepliif-1.1.13 → deepliif-1.1.14}/LICENSE.md +0 -0
- {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/__init__.py +0 -0
- {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/data/__init__.py +0 -0
- {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/data/base_dataset.py +0 -0
- {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/data/colorization_dataset.py +0 -0
- {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/data/image_folder.py +0 -0
- {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/data/single_dataset.py +0 -0
- {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/data/template_dataset.py +0 -0
- {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/data/unaligned_dataset.py +0 -0
- {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/models/CycleGAN_model.py +0 -0
- {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/models/DeepLIIFExt_model.py +0 -0
- {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/models/DeepLIIF_model.py +0 -0
- {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/models/SDG_model.py +0 -0
- {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/models/__init__ - different weighted.py +0 -0
- {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/models/__init__ - run_dask_multi dev.py +0 -0
- {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/models/__init__ - time gens.py +0 -0
- {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/models/__init__ - timings.py +0 -0
- {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/models/att_unet.py +0 -0
- {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/options/__init__.py +0 -0
- {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/options/base_options.py +0 -0
- {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/options/processing_options.py +0 -0
- {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/options/test_options.py +0 -0
- {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/options/train_options.py +0 -0
- {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/postprocessing__OLD__DELETE.py +0 -0
- {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/util/__init__.py +0 -0
- {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/util/get_data.py +0 -0
- {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/util/html.py +0 -0
- {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/util/image_pool.py +0 -0
- {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/util/visualizer.py +0 -0
- {deepliif-1.1.13 → deepliif-1.1.14}/deepliif.egg-info/dependency_links.txt +0 -0
- {deepliif-1.1.13 → deepliif-1.1.14}/deepliif.egg-info/entry_points.txt +0 -0
- {deepliif-1.1.13 → deepliif-1.1.14}/deepliif.egg-info/requires.txt +0 -0
- {deepliif-1.1.13 → deepliif-1.1.14}/deepliif.egg-info/top_level.txt +0 -0
- {deepliif-1.1.13 → deepliif-1.1.14}/tests/test_args.py +0 -0
- {deepliif-1.1.13 → deepliif-1.1.14}/tests/test_cli_inference.py +0 -0
- {deepliif-1.1.13 → deepliif-1.1.14}/tests/test_cli_serialize.py +0 -0
- {deepliif-1.1.13 → deepliif-1.1.14}/tests/test_cli_train.py +0 -0
- {deepliif-1.1.13 → deepliif-1.1.14}/tests/test_cli_trainlaunch.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.1
|
|
2
2
|
Name: deepliif
|
|
3
|
-
Version: 1.1.
|
|
3
|
+
Version: 1.1.14
|
|
4
4
|
Summary: DeepLIIF: Deep-Learning Inferred Multiplex Immunofluorescence for Immunohistochemical Image Quantification
|
|
5
5
|
Home-page: https://github.com/nadeemlab/DeepLIIF
|
|
6
6
|
Author: Parmida93
|
|
@@ -53,7 +53,7 @@ segmentation.*
|
|
|
53
53
|
|
|
54
54
|
© This code is made available for non-commercial academic purposes.
|
|
55
55
|
|
|
56
|
-

|
|
57
57
|
[](https://pepy.tech/project/deepliif?&left_text=totalusers)
|
|
58
58
|
|
|
59
59
|
*Overview of DeepLIIF pipeline and sample input IHCs (different
|
|
@@ -42,7 +42,7 @@ segmentation.*
|
|
|
42
42
|
|
|
43
43
|
© This code is made available for non-commercial academic purposes.
|
|
44
44
|
|
|
45
|
-

|
|
46
46
|
[](https://pepy.tech/project/deepliif?&left_text=totalusers)
|
|
47
47
|
|
|
48
48
|
*Overview of DeepLIIF pipeline and sample input IHCs (different
|
|
@@ -14,6 +14,7 @@ from deepliif.data import create_dataset, transform
|
|
|
14
14
|
from deepliif.models import init_nets, infer_modalities, infer_results_for_wsi, create_model, postprocess
|
|
15
15
|
from deepliif.util import allowed_file, Visualizer, get_information, test_diff_original_serialized, disable_batchnorm_tracking_stats
|
|
16
16
|
from deepliif.util.util import mkdirs
|
|
17
|
+
from deepliif.util.checks import check_weights
|
|
17
18
|
# from deepliif.util import infer_results_for_wsi
|
|
18
19
|
from deepliif.options import Options, print_options
|
|
19
20
|
|
|
@@ -78,6 +79,7 @@ def cli():
|
|
|
78
79
|
@click.option('--modalities-no', default=4, type=int, help='number of targets')
|
|
79
80
|
# model parameters
|
|
80
81
|
@click.option('--model', default='DeepLIIF', help='name of model class')
|
|
82
|
+
@click.option('--model-dir-teacher', default='', help='the directory of the teacher model, only applicable if model is DeepLIIFKD')
|
|
81
83
|
@click.option('--seg-weights', default='', type=str, help='weights used to aggregate modality images for the final segmentation image; numbers should add up to 1, and each number corresponds to the modality in order; example: 0.25,0.15,0.25,0.1,0.25')
|
|
82
84
|
@click.option('--loss-weights-g', default='', type=str, help='weights used to aggregate modality-wise losses for the final loss; numbers should add up to 1, and each number corresponds to the modality in order; example: 0.2,0.2,0.2,0.2,0.2')
|
|
83
85
|
@click.option('--loss-weights-d', default='', type=str, help='weights used to aggregate modality-wise losses for the final loss; numbers should add up to 1, and each number corresponds to the modality in order; example: 0.2,0.2,0.2,0.2,0.2')
|
|
@@ -193,7 +195,8 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd
|
|
|
193
195
|
verbose, lambda_l1, is_train, display_freq, display_ncols, display_id, display_server, display_env,
|
|
194
196
|
display_port, update_html_freq, print_freq, no_html, save_latest_freq, save_epoch_freq, save_by_iter,
|
|
195
197
|
continue_train, epoch_count, phase, lr_policy, n_epochs, n_epochs_decay, optimizer, beta1, lr_g, lr_d, lr_decay_iters,
|
|
196
|
-
remote, remote_transfer_cmd, seed, dataset_mode, padding, model,
|
|
198
|
+
remote, remote_transfer_cmd, seed, dataset_mode, padding, model, model_dir_teacher,
|
|
199
|
+
seg_weights, loss_weights_g, loss_weights_d,
|
|
197
200
|
modalities_no, seg_gen, net_ds, net_gs, gan_mode, gan_mode_s, local_rank, with_val, debug, debug_data_size):
|
|
198
201
|
"""General-purpose training script for multi-task image-to-image translation.
|
|
199
202
|
|
|
@@ -206,8 +209,8 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd
|
|
|
206
209
|
plot, and save models.The script supports continue/resume training.
|
|
207
210
|
Use '--continue_train' to resume your previous training.
|
|
208
211
|
"""
|
|
209
|
-
assert model in ['DeepLIIF','DeepLIIFExt','SDG','CycleGAN'], f'model class {model} is not implemented'
|
|
210
|
-
if model
|
|
212
|
+
assert model in ['DeepLIIF','DeepLIIFExt','SDG','CycleGAN','DeepLIIFKD'], f'model class {model} is not implemented'
|
|
213
|
+
if model in ['DeepLIIF','DeepLIIFKD']:
|
|
211
214
|
seg_no = 1
|
|
212
215
|
elif model == 'DeepLIIFExt':
|
|
213
216
|
if seg_gen:
|
|
@@ -221,6 +224,9 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd
|
|
|
221
224
|
if model == 'CycleGAN':
|
|
222
225
|
dataset_mode = "unaligned"
|
|
223
226
|
|
|
227
|
+
if model == 'DeepLIIFKD':
|
|
228
|
+
assert len(model_dir_teacher) > 0 and os.path.isdir(model_dir_teacher), f'Teacher model directory {model_dir_teacher} is not valid.'
|
|
229
|
+
|
|
224
230
|
if optimizer != 'adam':
|
|
225
231
|
print(f'Optimizer torch.optim.{optimizer} is not tested. Be careful about the parameters of the optimizer.')
|
|
226
232
|
|
|
@@ -310,7 +316,7 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd
|
|
|
310
316
|
|
|
311
317
|
net_gs = net_gs.split(',')
|
|
312
318
|
assert len(net_gs) in [1,seg_no], f'net_gs should contain either 1 architecture for all segmentation generators or the same number of architectures as the number of segmentation generators ({seg_no})'
|
|
313
|
-
if len(net_gs) == 1 and model
|
|
319
|
+
if len(net_gs) == 1 and model in ['DeepLIIF','DeepLIIFKD']:
|
|
314
320
|
net_gs = net_gs*(modalities_no + seg_no)
|
|
315
321
|
elif len(net_gs) == 1:
|
|
316
322
|
net_gs = net_gs*seg_no
|
|
@@ -320,34 +326,21 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd
|
|
|
320
326
|
|
|
321
327
|
# check seg weights and loss weights
|
|
322
328
|
if len(d_params['seg_weights']) == 0:
|
|
323
|
-
seg_weights = [0.25,0.15,0.25,0.1,0.25] if d_params['model']
|
|
329
|
+
seg_weights = [0.25,0.15,0.25,0.1,0.25] if d_params['model'] in ['DeepLIIF','DeepLIIFKD'] else [1 / modalities_no] * modalities_no
|
|
324
330
|
else:
|
|
325
331
|
seg_weights = [float(x) for x in seg_weights.split(',')]
|
|
326
332
|
|
|
327
333
|
if len(d_params['loss_weights_g']) == 0:
|
|
328
|
-
loss_weights_g = [0.2]*5 if d_params['model']
|
|
334
|
+
loss_weights_g = [0.2]*5 if d_params['model'] in ['DeepLIIF','DeepLIIFKD'] else [1 / modalities_no] * modalities_no
|
|
329
335
|
else:
|
|
330
336
|
loss_weights_g = [float(x) for x in loss_weights_g.split(',')]
|
|
331
337
|
|
|
332
338
|
if len(d_params['loss_weights_d']) == 0:
|
|
333
|
-
loss_weights_d = [0.2]*5 if d_params['model']
|
|
339
|
+
loss_weights_d = [0.2]*5 if d_params['model'] in ['DeepLIIF','DeepLIIFKD'] else [1 / modalities_no] * modalities_no
|
|
334
340
|
else:
|
|
335
341
|
loss_weights_d = [float(x) for x in loss_weights_d.split(',')]
|
|
336
342
|
|
|
337
|
-
|
|
338
|
-
assert sum(loss_weights_g) == 1, 'loss weights g should add up to 1'
|
|
339
|
-
assert sum(loss_weights_d) == 1, 'loss weights d should add up to 1'
|
|
340
|
-
|
|
341
|
-
if model == 'DeepLIIF':
|
|
342
|
-
# +1 because input becomes an additional modality used in generating the final segmentation
|
|
343
|
-
assert len(seg_weights) == modalities_no+1, 'seg weights should have the same number of elements as number of modalities to be generated'
|
|
344
|
-
assert len(loss_weights_g) == modalities_no+1, 'loss weights g should have the same number of elements as number of modalities to be generated'
|
|
345
|
-
assert len(loss_weights_d) == modalities_no+1, 'loss weights d should have the same number of elements as number of modalities to be generated'
|
|
346
|
-
|
|
347
|
-
else:
|
|
348
|
-
assert len(seg_weights) == modalities_no, 'seg weights should have the same number of elements as number of modalities to be generated'
|
|
349
|
-
assert len(loss_weights_g) == modalities_no, 'loss weights g should have the same number of elements as number of modalities to be generated'
|
|
350
|
-
assert len(loss_weights_d) == modalities_no, 'loss weights d should have the same number of elements as number of modalities to be generated'
|
|
343
|
+
check_weights(model, modalities_no, seg_weights, loss_weights_g, loss_weights_d)
|
|
351
344
|
|
|
352
345
|
d_params['seg_weights'] = seg_weights
|
|
353
346
|
d_params['loss_G_weights'] = loss_weights_g
|
|
@@ -373,7 +366,7 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd
|
|
|
373
366
|
data_val = [batch for batch in dataset_val]
|
|
374
367
|
click.echo('The number of validation images = %d' % len(dataset_val))
|
|
375
368
|
|
|
376
|
-
if model in ['DeepLIIF']:
|
|
369
|
+
if model in ['DeepLIIF', 'DeepLIIFKD']:
|
|
377
370
|
metrics_val = json.load(open(os.path.join(dataset_val.dataset.dir_AB,'metrics.json')))
|
|
378
371
|
|
|
379
372
|
# create a model given model and other options
|
|
@@ -50,7 +50,7 @@ class AlignedDataset(BaseDataset):
|
|
|
50
50
|
AB = Image.open(AB_path).convert('RGB')
|
|
51
51
|
# split AB image into A and B
|
|
52
52
|
w, h = AB.size
|
|
53
|
-
if self.model
|
|
53
|
+
if self.model in ['DeepLIIF','DeepLIIFKD']:
|
|
54
54
|
num_img = self.modalities_no + 1 + 1 # +1 for segmentation channel, +1 for input image
|
|
55
55
|
elif self.model == 'DeepLIIFExt':
|
|
56
56
|
num_img = self.modalities_no * 2 + 1 if self.seg_gen else self.modalities_no + 1 # +1 for segmentation channel
|
|
@@ -68,7 +68,7 @@ class AlignedDataset(BaseDataset):
|
|
|
68
68
|
|
|
69
69
|
A = A_transform(A)
|
|
70
70
|
B_Array = []
|
|
71
|
-
if self.model
|
|
71
|
+
if self.model in ['DeepLIIF','DeepLIIFKD']:
|
|
72
72
|
for i in range(1, num_img):
|
|
73
73
|
B = AB.crop((w2 * i, 0, w2 * (i + 1), h))
|
|
74
74
|
B = B_transform(B)
|
|
@@ -0,0 +1,409 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
from .base_model import BaseModel
|
|
3
|
+
from . import networks
|
|
4
|
+
from .networks import get_optimizer
|
|
5
|
+
from . import init_nets, run_dask, get_opt
|
|
6
|
+
from torch import nn
|
|
7
|
+
|
|
8
|
+
class DeepLIIFKDModel(BaseModel):
|
|
9
|
+
""" This class implements the DeepLIIF model, for learning a mapping from input images to modalities given paired data."""
|
|
10
|
+
|
|
11
|
+
def __init__(self, opt):
|
|
12
|
+
"""Initialize the DeepLIIF class.
|
|
13
|
+
|
|
14
|
+
Parameters:
|
|
15
|
+
opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
|
|
16
|
+
"""
|
|
17
|
+
BaseModel.__init__(self, opt)
|
|
18
|
+
if not hasattr(opt,'net_gs'):
|
|
19
|
+
opt.net_gs = 'unet_512'
|
|
20
|
+
|
|
21
|
+
self.seg_weights = opt.seg_weights
|
|
22
|
+
self.loss_G_weights = opt.loss_G_weights
|
|
23
|
+
self.loss_D_weights = opt.loss_D_weights
|
|
24
|
+
|
|
25
|
+
if not opt.is_train:
|
|
26
|
+
self.gpu_ids = [] # avoid the models being loaded as DP
|
|
27
|
+
else:
|
|
28
|
+
self.gpu_ids = opt.gpu_ids
|
|
29
|
+
|
|
30
|
+
self.loss_names = []
|
|
31
|
+
self.visual_names = ['real_A']
|
|
32
|
+
# specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
|
|
33
|
+
for i in range(1, self.opt.modalities_no + 1 + 1):
|
|
34
|
+
self.loss_names.extend([f'G_GAN_{i}', f'G_L1_{i}', f'D_real_{i}', f'D_fake_{i}', f'G_KLDiv_{i}', f'G_KLDiv_5_{i}'])
|
|
35
|
+
self.visual_names.extend([f'fake_B_{i}', f'fake_B_5_{i}', f'fake_B_{i}_teacher', f'fake_B_5_{i}_teacher', f'real_B_{i}'])
|
|
36
|
+
|
|
37
|
+
# specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
|
|
38
|
+
# specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
|
|
39
|
+
if self.is_train:
|
|
40
|
+
self.model_names = []
|
|
41
|
+
for i in range(1, self.opt.modalities_no + 1):
|
|
42
|
+
self.model_names.extend(['G' + str(i), 'D' + str(i)])
|
|
43
|
+
|
|
44
|
+
for i in range(1, self.opt.modalities_no + 1 + 1):
|
|
45
|
+
self.model_names.extend(['G5' + str(i), 'D5' + str(i)])
|
|
46
|
+
else: # during test time, only load G
|
|
47
|
+
self.model_names = []
|
|
48
|
+
for i in range(1, self.opt.modalities_no + 1):
|
|
49
|
+
self.model_names.extend(['G' + str(i)])
|
|
50
|
+
|
|
51
|
+
for i in range(1, self.opt.modalities_no + 1 + 1):
|
|
52
|
+
self.model_names.extend(['G5' + str(i)])
|
|
53
|
+
|
|
54
|
+
# define networks (both generator and discriminator)
|
|
55
|
+
if isinstance(opt.netG, str):
|
|
56
|
+
opt.netG = [opt.netG] * 4
|
|
57
|
+
if isinstance(opt.net_gs, str):
|
|
58
|
+
opt.net_gs = [opt.net_gs]*5
|
|
59
|
+
|
|
60
|
+
|
|
61
|
+
self.netG1 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG[0], opt.norm,
|
|
62
|
+
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.padding)
|
|
63
|
+
self.netG2 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG[1], opt.norm,
|
|
64
|
+
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.padding)
|
|
65
|
+
self.netG3 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG[2], opt.norm,
|
|
66
|
+
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.padding)
|
|
67
|
+
self.netG4 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG[3], opt.norm,
|
|
68
|
+
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.padding)
|
|
69
|
+
|
|
70
|
+
# DeepLIIF model currently uses one gs arch because there is only one explicit seg mod output
|
|
71
|
+
self.netG51 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[0], opt.norm,
|
|
72
|
+
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
|
|
73
|
+
self.netG52 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[1], opt.norm,
|
|
74
|
+
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
|
|
75
|
+
self.netG53 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[2], opt.norm,
|
|
76
|
+
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
|
|
77
|
+
self.netG54 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[3], opt.norm,
|
|
78
|
+
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
|
|
79
|
+
self.netG55 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[4], opt.norm,
|
|
80
|
+
not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
|
|
81
|
+
|
|
82
|
+
|
|
83
|
+
if self.is_train: # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc
|
|
84
|
+
self.netD1 = networks.define_D(opt.input_nc+opt.output_nc , opt.ndf, opt.netD,
|
|
85
|
+
opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
|
|
86
|
+
self.netD2 = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
|
|
87
|
+
opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
|
|
88
|
+
self.netD3 = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
|
|
89
|
+
opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
|
|
90
|
+
self.netD4 = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
|
|
91
|
+
opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
|
|
92
|
+
|
|
93
|
+
self.netD51 = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
|
|
94
|
+
opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
|
|
95
|
+
self.netD52 = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
|
|
96
|
+
opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
|
|
97
|
+
self.netD53 = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
|
|
98
|
+
opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
|
|
99
|
+
self.netD54 = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
|
|
100
|
+
opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
|
|
101
|
+
self.netD55 = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
|
|
102
|
+
opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
|
|
103
|
+
|
|
104
|
+
# load the teacher model
|
|
105
|
+
self.opt_teacher = get_opt(opt.model_dir_teacher, mode='test')
|
|
106
|
+
self.opt_teacher.gpu_ids = opt.gpu_ids # use student's gpu_ids
|
|
107
|
+
self.nets_teacher = init_nets(opt.model_dir_teacher, eager_mode=True, opt=self.opt_teacher, phase='test')
|
|
108
|
+
|
|
109
|
+
|
|
110
|
+
if self.is_train:
|
|
111
|
+
# define loss functions
|
|
112
|
+
self.criterionGAN_BCE = networks.GANLoss('vanilla').to(self.device)
|
|
113
|
+
self.criterionGAN_lsgan = networks.GANLoss('lsgan').to(self.device)
|
|
114
|
+
self.criterionSmoothL1 = torch.nn.SmoothL1Loss()
|
|
115
|
+
|
|
116
|
+
# initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
|
|
117
|
+
params = list(self.netG1.parameters()) + list(self.netG2.parameters()) + list(self.netG3.parameters()) + list(self.netG4.parameters()) + list(self.netG51.parameters()) + list(self.netG52.parameters()) + list(self.netG53.parameters()) + list(self.netG54.parameters()) + list(self.netG55.parameters())
|
|
118
|
+
try:
|
|
119
|
+
self.optimizer_G = get_optimizer(opt.optimizer)(params, lr=opt.lr_g, betas=(opt.beta1, 0.999))
|
|
120
|
+
except:
|
|
121
|
+
print(f'betas are not used for optimizer torch.optim.{opt.optimizer} in generators')
|
|
122
|
+
self.optimizer_G = get_optimizer(opt.optimizer)(params, lr=opt.lr_g)
|
|
123
|
+
|
|
124
|
+
params = list(self.netD1.parameters()) + list(self.netD2.parameters()) + list(self.netD3.parameters()) + list(self.netD4.parameters()) + list(self.netD51.parameters()) + list(self.netD52.parameters()) + list(self.netD53.parameters()) + list(self.netD54.parameters()) + list(self.netD55.parameters())
|
|
125
|
+
try:
|
|
126
|
+
self.optimizer_D = get_optimizer(opt.optimizer)(params, lr=opt.lr_d, betas=(opt.beta1, 0.999))
|
|
127
|
+
except:
|
|
128
|
+
print(f'betas are not used for optimizer torch.optim.{opt.optimizer} in discriminators')
|
|
129
|
+
self.optimizer_D = get_optimizer(opt.optimizer)(params, lr=opt.lr_d)
|
|
130
|
+
|
|
131
|
+
self.optimizers.append(self.optimizer_G)
|
|
132
|
+
self.optimizers.append(self.optimizer_D)
|
|
133
|
+
|
|
134
|
+
self.criterionVGG = networks.VGGLoss().to(self.device)
|
|
135
|
+
self.criterionKLDiv = torch.nn.KLDivLoss(reduction='batchmean').to(self.device)
|
|
136
|
+
self.softmax = torch.nn.Softmax(dim=-1).to(self.device) # apply softmax on the last dim
|
|
137
|
+
self.logsoftmax = torch.nn.LogSoftmax(dim=-1).to(self.device) # apply log-softmax on the last dim
|
|
138
|
+
|
|
139
|
+
def set_input(self, input):
|
|
140
|
+
"""
|
|
141
|
+
Unpack input data from the dataloader and perform necessary pre-processing steps.
|
|
142
|
+
|
|
143
|
+
:param input (dict): include the input image and the output modalities
|
|
144
|
+
"""
|
|
145
|
+
|
|
146
|
+
self.real_A = input['A'].to(self.device)
|
|
147
|
+
|
|
148
|
+
self.real_B_array = input['B']
|
|
149
|
+
self.real_B_1 = self.real_B_array[0].to(self.device)
|
|
150
|
+
self.real_B_2 = self.real_B_array[1].to(self.device)
|
|
151
|
+
self.real_B_3 = self.real_B_array[2].to(self.device)
|
|
152
|
+
self.real_B_4 = self.real_B_array[3].to(self.device)
|
|
153
|
+
self.real_B_5 = self.real_B_array[4].to(self.device)
|
|
154
|
+
self.image_paths = input['A_paths']
|
|
155
|
+
|
|
156
|
+
def forward(self):
|
|
157
|
+
"""Run forward pass; called by both functions <optimize_parameters> and <test>."""
|
|
158
|
+
self.fake_B_1 = self.netG1(self.real_A) # Hematoxylin image generator
|
|
159
|
+
self.fake_B_2 = self.netG2(self.real_A) # mpIF DAPI image generator
|
|
160
|
+
self.fake_B_3 = self.netG3(self.real_A) # mpIF Lap2 image generator
|
|
161
|
+
self.fake_B_4 = self.netG4(self.real_A) # mpIF Ki67 image generator
|
|
162
|
+
|
|
163
|
+
self.fake_B_5_1 = self.netG51(self.real_A) # Segmentation mask generator from IHC input image
|
|
164
|
+
self.fake_B_5_2 = self.netG52(self.fake_B_1) # Segmentation mask generator from Hematoxylin input image
|
|
165
|
+
self.fake_B_5_3 = self.netG53(self.fake_B_2) # Segmentation mask generator from mpIF DAPI input image
|
|
166
|
+
self.fake_B_5_4 = self.netG54(self.fake_B_3) # Segmentation mask generator from mpIF Lap2 input image
|
|
167
|
+
self.fake_B_5_5 = self.netG55(self.fake_B_4) # Segmentation mask generator from mpIF Lap2 input image
|
|
168
|
+
self.fake_B_5 = torch.stack([torch.mul(self.fake_B_5_1, self.seg_weights[0]),
|
|
169
|
+
torch.mul(self.fake_B_5_2, self.seg_weights[1]),
|
|
170
|
+
torch.mul(self.fake_B_5_3, self.seg_weights[2]),
|
|
171
|
+
torch.mul(self.fake_B_5_4, self.seg_weights[3]),
|
|
172
|
+
torch.mul(self.fake_B_5_5, self.seg_weights[4])]).sum(dim=0)
|
|
173
|
+
|
|
174
|
+
fakes_teacher = run_dask(img=self.real_A, nets=self.nets_teacher, opt=self.opt_teacher, use_dask=False, output_tensor=True)
|
|
175
|
+
for k,v in fakes_teacher.items():
|
|
176
|
+
suffix = k[1:] # starts with G
|
|
177
|
+
suffix = '_'.join(list(suffix)) # 51 -> 5_1
|
|
178
|
+
setattr(self,f'fake_B_{suffix}_teacher',v)
|
|
179
|
+
|
|
180
|
+
def backward_D(self):
|
|
181
|
+
"""Calculate GAN loss for the discriminators"""
|
|
182
|
+
fake_AB_1 = torch.cat((self.real_A, self.fake_B_1), 1) # Conditional GANs; feed IHC input and Hematoxtlin output to the discriminator
|
|
183
|
+
fake_AB_2 = torch.cat((self.real_A, self.fake_B_2), 1) # Conditional GANs; feed IHC input and mpIF DAPI output to the discriminator
|
|
184
|
+
fake_AB_3 = torch.cat((self.real_A, self.fake_B_3), 1) # Conditional GANs; feed IHC input and mpIF Lap2 output to the discriminator
|
|
185
|
+
fake_AB_4 = torch.cat((self.real_A, self.fake_B_4), 1) # Conditional GANs; feed IHC input and mpIF Ki67 output to the discriminator
|
|
186
|
+
|
|
187
|
+
pred_fake_1 = self.netD1(fake_AB_1.detach())
|
|
188
|
+
pred_fake_2 = self.netD2(fake_AB_2.detach())
|
|
189
|
+
pred_fake_3 = self.netD3(fake_AB_3.detach())
|
|
190
|
+
pred_fake_4 = self.netD4(fake_AB_4.detach())
|
|
191
|
+
|
|
192
|
+
fake_AB_5_1 = torch.cat((self.real_A, self.fake_B_5), 1) # Conditional GANs; feed IHC input and Segmentation mask output to the discriminator
|
|
193
|
+
fake_AB_5_2 = torch.cat((self.real_B_1, self.fake_B_5), 1) # Conditional GANs; feed Hematoxylin input and Segmentation mask output to the discriminator
|
|
194
|
+
fake_AB_5_3 = torch.cat((self.real_B_2, self.fake_B_5), 1) # Conditional GANs; feed mpIF DAPI input and Segmentation mask output to the discriminator
|
|
195
|
+
fake_AB_5_4 = torch.cat((self.real_B_3, self.fake_B_5), 1) # Conditional GANs; feed mpIF Lap2 input and Segmentation mask output to the discriminator
|
|
196
|
+
fake_AB_5_5 = torch.cat((self.real_B_4, self.fake_B_5), 1) # Conditional GANs; feed mpIF Lap2 input and Segmentation mask output to the discriminator
|
|
197
|
+
|
|
198
|
+
pred_fake_5_1 = self.netD51(fake_AB_5_1.detach())
|
|
199
|
+
pred_fake_5_2 = self.netD52(fake_AB_5_2.detach())
|
|
200
|
+
pred_fake_5_3 = self.netD53(fake_AB_5_3.detach())
|
|
201
|
+
pred_fake_5_4 = self.netD54(fake_AB_5_4.detach())
|
|
202
|
+
pred_fake_5_5 = self.netD55(fake_AB_5_5.detach())
|
|
203
|
+
|
|
204
|
+
pred_fake_5 = torch.stack(
|
|
205
|
+
[torch.mul(pred_fake_5_1, self.seg_weights[0]),
|
|
206
|
+
torch.mul(pred_fake_5_2, self.seg_weights[1]),
|
|
207
|
+
torch.mul(pred_fake_5_3, self.seg_weights[2]),
|
|
208
|
+
torch.mul(pred_fake_5_4, self.seg_weights[3]),
|
|
209
|
+
torch.mul(pred_fake_5_5, self.seg_weights[4])]).sum(dim=0)
|
|
210
|
+
|
|
211
|
+
self.loss_D_fake_1 = self.criterionGAN_BCE(pred_fake_1, False)
|
|
212
|
+
self.loss_D_fake_2 = self.criterionGAN_BCE(pred_fake_2, False)
|
|
213
|
+
self.loss_D_fake_3 = self.criterionGAN_BCE(pred_fake_3, False)
|
|
214
|
+
self.loss_D_fake_4 = self.criterionGAN_BCE(pred_fake_4, False)
|
|
215
|
+
self.loss_D_fake_5 = self.criterionGAN_lsgan(pred_fake_5, False)
|
|
216
|
+
|
|
217
|
+
|
|
218
|
+
real_AB_1 = torch.cat((self.real_A, self.real_B_1), 1)
|
|
219
|
+
real_AB_2 = torch.cat((self.real_A, self.real_B_2), 1)
|
|
220
|
+
real_AB_3 = torch.cat((self.real_A, self.real_B_3), 1)
|
|
221
|
+
real_AB_4 = torch.cat((self.real_A, self.real_B_4), 1)
|
|
222
|
+
|
|
223
|
+
pred_real_1 = self.netD1(real_AB_1)
|
|
224
|
+
pred_real_2 = self.netD2(real_AB_2)
|
|
225
|
+
pred_real_3 = self.netD3(real_AB_3)
|
|
226
|
+
pred_real_4 = self.netD4(real_AB_4)
|
|
227
|
+
|
|
228
|
+
real_AB_5_1 = torch.cat((self.real_A, self.real_B_5), 1)
|
|
229
|
+
real_AB_5_2 = torch.cat((self.real_B_1, self.real_B_5), 1)
|
|
230
|
+
real_AB_5_3 = torch.cat((self.real_B_2, self.real_B_5), 1)
|
|
231
|
+
real_AB_5_4 = torch.cat((self.real_B_3, self.real_B_5), 1)
|
|
232
|
+
real_AB_5_5 = torch.cat((self.real_B_4, self.real_B_5), 1)
|
|
233
|
+
|
|
234
|
+
pred_real_5_1 = self.netD51(real_AB_5_1)
|
|
235
|
+
pred_real_5_2 = self.netD52(real_AB_5_2)
|
|
236
|
+
pred_real_5_3 = self.netD53(real_AB_5_3)
|
|
237
|
+
pred_real_5_4 = self.netD54(real_AB_5_4)
|
|
238
|
+
pred_real_5_5 = self.netD55(real_AB_5_5)
|
|
239
|
+
|
|
240
|
+
pred_real_5 = torch.stack(
|
|
241
|
+
[torch.mul(pred_real_5_1, self.seg_weights[0]),
|
|
242
|
+
torch.mul(pred_real_5_2, self.seg_weights[1]),
|
|
243
|
+
torch.mul(pred_real_5_3, self.seg_weights[2]),
|
|
244
|
+
torch.mul(pred_real_5_4, self.seg_weights[3]),
|
|
245
|
+
torch.mul(pred_real_5_5, self.seg_weights[4])]).sum(dim=0)
|
|
246
|
+
|
|
247
|
+
self.loss_D_real_1 = self.criterionGAN_BCE(pred_real_1, True)
|
|
248
|
+
self.loss_D_real_2 = self.criterionGAN_BCE(pred_real_2, True)
|
|
249
|
+
self.loss_D_real_3 = self.criterionGAN_BCE(pred_real_3, True)
|
|
250
|
+
self.loss_D_real_4 = self.criterionGAN_BCE(pred_real_4, True)
|
|
251
|
+
self.loss_D_real_5 = self.criterionGAN_lsgan(pred_real_5, True)
|
|
252
|
+
|
|
253
|
+
# combine losses and calculate gradients
|
|
254
|
+
self.loss_D = (self.loss_D_fake_1 + self.loss_D_real_1) * 0.5 * self.loss_D_weights[0] + \
|
|
255
|
+
(self.loss_D_fake_2 + self.loss_D_real_2) * 0.5 * self.loss_D_weights[1] + \
|
|
256
|
+
(self.loss_D_fake_3 + self.loss_D_real_3) * 0.5 * self.loss_D_weights[2] + \
|
|
257
|
+
(self.loss_D_fake_4 + self.loss_D_real_4) * 0.5 * self.loss_D_weights[3] + \
|
|
258
|
+
(self.loss_D_fake_5 + self.loss_D_real_5) * 0.5 * self.loss_D_weights[4]
|
|
259
|
+
|
|
260
|
+
self.loss_D.backward()
|
|
261
|
+
|
|
262
|
+
def backward_G(self):
|
|
263
|
+
"""Calculate GAN and L1 loss for the generator"""
|
|
264
|
+
|
|
265
|
+
fake_AB_1 = torch.cat((self.real_A, self.fake_B_1), 1)
|
|
266
|
+
fake_AB_2 = torch.cat((self.real_A, self.fake_B_2), 1)
|
|
267
|
+
fake_AB_3 = torch.cat((self.real_A, self.fake_B_3), 1)
|
|
268
|
+
fake_AB_4 = torch.cat((self.real_A, self.fake_B_4), 1)
|
|
269
|
+
|
|
270
|
+
fake_AB_5_1 = torch.cat((self.real_A, self.fake_B_5), 1)
|
|
271
|
+
fake_AB_5_2 = torch.cat((self.real_B_1, self.fake_B_5), 1)
|
|
272
|
+
fake_AB_5_3 = torch.cat((self.real_B_2, self.fake_B_5), 1)
|
|
273
|
+
fake_AB_5_4 = torch.cat((self.real_B_3, self.fake_B_5), 1)
|
|
274
|
+
fake_AB_5_5 = torch.cat((self.real_B_4, self.fake_B_5), 1)
|
|
275
|
+
|
|
276
|
+
pred_fake_1 = self.netD1(fake_AB_1)
|
|
277
|
+
pred_fake_2 = self.netD2(fake_AB_2)
|
|
278
|
+
pred_fake_3 = self.netD3(fake_AB_3)
|
|
279
|
+
pred_fake_4 = self.netD4(fake_AB_4)
|
|
280
|
+
|
|
281
|
+
pred_fake_5_1 = self.netD51(fake_AB_5_1)
|
|
282
|
+
pred_fake_5_2 = self.netD52(fake_AB_5_2)
|
|
283
|
+
pred_fake_5_3 = self.netD53(fake_AB_5_3)
|
|
284
|
+
pred_fake_5_4 = self.netD54(fake_AB_5_4)
|
|
285
|
+
pred_fake_5_5 = self.netD55(fake_AB_5_5)
|
|
286
|
+
pred_fake_5 = torch.stack(
|
|
287
|
+
[torch.mul(pred_fake_5_1, self.seg_weights[0]),
|
|
288
|
+
torch.mul(pred_fake_5_2, self.seg_weights[1]),
|
|
289
|
+
torch.mul(pred_fake_5_3, self.seg_weights[2]),
|
|
290
|
+
torch.mul(pred_fake_5_4, self.seg_weights[3]),
|
|
291
|
+
torch.mul(pred_fake_5_5, self.seg_weights[4])]).sum(dim=0)
|
|
292
|
+
|
|
293
|
+
self.loss_G_GAN_1 = self.criterionGAN_BCE(pred_fake_1, True)
|
|
294
|
+
self.loss_G_GAN_2 = self.criterionGAN_BCE(pred_fake_2, True)
|
|
295
|
+
self.loss_G_GAN_3 = self.criterionGAN_BCE(pred_fake_3, True)
|
|
296
|
+
self.loss_G_GAN_4 = self.criterionGAN_BCE(pred_fake_4, True)
|
|
297
|
+
self.loss_G_GAN_5 = self.criterionGAN_lsgan(pred_fake_5, True)
|
|
298
|
+
|
|
299
|
+
# Second, G(A) = B
|
|
300
|
+
self.loss_G_L1_1 = self.criterionSmoothL1(self.fake_B_1, self.real_B_1) * self.opt.lambda_L1
|
|
301
|
+
self.loss_G_L1_2 = self.criterionSmoothL1(self.fake_B_2, self.real_B_2) * self.opt.lambda_L1
|
|
302
|
+
self.loss_G_L1_3 = self.criterionSmoothL1(self.fake_B_3, self.real_B_3) * self.opt.lambda_L1
|
|
303
|
+
self.loss_G_L1_4 = self.criterionSmoothL1(self.fake_B_4, self.real_B_4) * self.opt.lambda_L1
|
|
304
|
+
self.loss_G_L1_5 = self.criterionSmoothL1(self.fake_B_5, self.real_B_5) * self.opt.lambda_L1
|
|
305
|
+
|
|
306
|
+
self.loss_G_VGG_1 = self.criterionVGG(self.fake_B_1, self.real_B_1) * self.opt.lambda_feat
|
|
307
|
+
self.loss_G_VGG_2 = self.criterionVGG(self.fake_B_2, self.real_B_2) * self.opt.lambda_feat
|
|
308
|
+
self.loss_G_VGG_3 = self.criterionVGG(self.fake_B_3, self.real_B_3) * self.opt.lambda_feat
|
|
309
|
+
self.loss_G_VGG_4 = self.criterionVGG(self.fake_B_4, self.real_B_4) * self.opt.lambda_feat
|
|
310
|
+
|
|
311
|
+
|
|
312
|
+
# .view(1,1,-1) reshapes the input (batch_size, 3, 512, 512) to (batch_size, 1, 3*512*512)
|
|
313
|
+
# softmax/log-softmax is then applied on the concatenated vector of size (1, 3*512*512)
|
|
314
|
+
# this normalizes the pixel values across all 3 RGB channels
|
|
315
|
+
# the resulting vectors are then used to compute KL divergence loss
|
|
316
|
+
self.loss_G_KLDiv_1 = self.criterionKLDiv(self.logsoftmax(self.fake_B_1.view(1,1,-1)), self.softmax(self.fake_B_1_teacher.view(1,1,-1)))
|
|
317
|
+
self.loss_G_KLDiv_2 = self.criterionKLDiv(self.logsoftmax(self.fake_B_2.view(1,1,-1)), self.softmax(self.fake_B_2_teacher.view(1,1,-1)))
|
|
318
|
+
self.loss_G_KLDiv_3 = self.criterionKLDiv(self.logsoftmax(self.fake_B_3.view(1,1,-1)), self.softmax(self.fake_B_3_teacher.view(1,1,-1)))
|
|
319
|
+
self.loss_G_KLDiv_4 = self.criterionKLDiv(self.logsoftmax(self.fake_B_4.view(1,1,-1)), self.softmax(self.fake_B_4_teacher.view(1,1,-1)))
|
|
320
|
+
self.loss_G_KLDiv_5 = self.criterionKLDiv(self.logsoftmax(self.fake_B_5.view(1,1,-1)), self.softmax(self.fake_B_5_teacher.view(1,1,-1)))
|
|
321
|
+
self.loss_G_KLDiv_5_1 = self.criterionKLDiv(self.logsoftmax(self.fake_B_5_1.view(1,1,-1)), self.softmax(self.fake_B_5_1_teacher.view(1,1,-1)))
|
|
322
|
+
self.loss_G_KLDiv_5_2 = self.criterionKLDiv(self.logsoftmax(self.fake_B_5_2.view(1,1,-1)), self.softmax(self.fake_B_5_2_teacher.view(1,1,-1)))
|
|
323
|
+
self.loss_G_KLDiv_5_3 = self.criterionKLDiv(self.logsoftmax(self.fake_B_5_3.view(1,1,-1)), self.softmax(self.fake_B_5_3_teacher.view(1,1,-1)))
|
|
324
|
+
self.loss_G_KLDiv_5_4 = self.criterionKLDiv(self.logsoftmax(self.fake_B_5_4.view(1,1,-1)), self.softmax(self.fake_B_5_4_teacher.view(1,1,-1)))
|
|
325
|
+
self.loss_G_KLDiv_5_5 = self.criterionKLDiv(self.logsoftmax(self.fake_B_5_5.view(1,1,-1)), self.softmax(self.fake_B_5_5_teacher.view(1,1,-1)))
|
|
326
|
+
|
|
327
|
+
self.loss_G = (self.loss_G_GAN_1 + self.loss_G_L1_1 + self.loss_G_VGG_1) * self.loss_G_weights[0] + \
|
|
328
|
+
(self.loss_G_GAN_2 + self.loss_G_L1_2 + self.loss_G_VGG_2) * self.loss_G_weights[1] + \
|
|
329
|
+
(self.loss_G_GAN_3 + self.loss_G_L1_3 + self.loss_G_VGG_3) * self.loss_G_weights[2] + \
|
|
330
|
+
(self.loss_G_GAN_4 + self.loss_G_L1_4 + self.loss_G_VGG_4) * self.loss_G_weights[3] + \
|
|
331
|
+
(self.loss_G_GAN_5 + self.loss_G_L1_5) * self.loss_G_weights[4] + \
|
|
332
|
+
(self.loss_G_KLDiv_1 + self.loss_G_KLDiv_2 + self.loss_G_KLDiv_3 + self.loss_G_KLDiv_4 + \
|
|
333
|
+
self.loss_G_KLDiv_5 + self.loss_G_KLDiv_5_1 + self.loss_G_KLDiv_5_2 + self.loss_G_KLDiv_5_3 + \
|
|
334
|
+
self.loss_G_KLDiv_5_4 + self.loss_G_KLDiv_5_5) * 10
|
|
335
|
+
|
|
336
|
+
# combine loss and calculate gradients
|
|
337
|
+
# self.loss_G = (self.loss_G_GAN_1 + self.loss_G_L1_1) * self.loss_G_weights[0] + \
|
|
338
|
+
# (self.loss_G_GAN_2 + self.loss_G_L1_2) * self.loss_G_weights[1] + \
|
|
339
|
+
# (self.loss_G_GAN_3 + self.loss_G_L1_3) * self.loss_G_weights[2] + \
|
|
340
|
+
# (self.loss_G_GAN_4 + self.loss_G_L1_4) * self.loss_G_weights[3] + \
|
|
341
|
+
# (self.loss_G_GAN_5 + self.loss_G_L1_5) * self.loss_G_weights[4]
|
|
342
|
+
self.loss_G.backward()
|
|
343
|
+
|
|
344
|
+
def optimize_parameters(self):
|
|
345
|
+
self.forward() # compute fake images: G(A)
|
|
346
|
+
# update D
|
|
347
|
+
self.set_requires_grad(self.netD1, True) # enable backprop for D1
|
|
348
|
+
self.set_requires_grad(self.netD2, True) # enable backprop for D2
|
|
349
|
+
self.set_requires_grad(self.netD3, True) # enable backprop for D3
|
|
350
|
+
self.set_requires_grad(self.netD4, True) # enable backprop for D4
|
|
351
|
+
self.set_requires_grad(self.netD51, True) # enable backprop for D51
|
|
352
|
+
self.set_requires_grad(self.netD52, True) # enable backprop for D52
|
|
353
|
+
self.set_requires_grad(self.netD53, True) # enable backprop for D53
|
|
354
|
+
self.set_requires_grad(self.netD54, True) # enable backprop for D54
|
|
355
|
+
self.set_requires_grad(self.netD55, True) # enable backprop for D54
|
|
356
|
+
|
|
357
|
+
self.optimizer_D.zero_grad() # set D's gradients to zero
|
|
358
|
+
self.backward_D() # calculate gradients for D
|
|
359
|
+
self.optimizer_D.step() # update D's weights
|
|
360
|
+
|
|
361
|
+
# update G
|
|
362
|
+
self.set_requires_grad(self.netD1, False) # D1 requires no gradients when optimizing G1
|
|
363
|
+
self.set_requires_grad(self.netD2, False) # D2 requires no gradients when optimizing G2
|
|
364
|
+
self.set_requires_grad(self.netD3, False) # D3 requires no gradients when optimizing G3
|
|
365
|
+
self.set_requires_grad(self.netD4, False) # D4 requires no gradients when optimizing G4
|
|
366
|
+
self.set_requires_grad(self.netD51, False) # D51 requires no gradients when optimizing G51
|
|
367
|
+
self.set_requires_grad(self.netD52, False) # D52 requires no gradients when optimizing G52
|
|
368
|
+
self.set_requires_grad(self.netD53, False) # D53 requires no gradients when optimizing G53
|
|
369
|
+
self.set_requires_grad(self.netD54, False) # D54 requires no gradients when optimizing G54
|
|
370
|
+
self.set_requires_grad(self.netD55, False) # D54 requires no gradients when optimizing G54
|
|
371
|
+
|
|
372
|
+
self.optimizer_G.zero_grad() # set G's gradients to zero
|
|
373
|
+
self.backward_G() # calculate graidents for G
|
|
374
|
+
self.optimizer_G.step() # udpate G's weights
|
|
375
|
+
|
|
376
|
+
def calculate_losses(self):
|
|
377
|
+
"""
|
|
378
|
+
Calculate losses but do not optimize parameters. Used in validation loss calculation during training.
|
|
379
|
+
"""
|
|
380
|
+
|
|
381
|
+
self.forward() # compute fake images: G(A)
|
|
382
|
+
# update D
|
|
383
|
+
self.set_requires_grad(self.netD1, True) # enable backprop for D1
|
|
384
|
+
self.set_requires_grad(self.netD2, True) # enable backprop for D2
|
|
385
|
+
self.set_requires_grad(self.netD3, True) # enable backprop for D3
|
|
386
|
+
self.set_requires_grad(self.netD4, True) # enable backprop for D4
|
|
387
|
+
self.set_requires_grad(self.netD51, True) # enable backprop for D51
|
|
388
|
+
self.set_requires_grad(self.netD52, True) # enable backprop for D52
|
|
389
|
+
self.set_requires_grad(self.netD53, True) # enable backprop for D53
|
|
390
|
+
self.set_requires_grad(self.netD54, True) # enable backprop for D54
|
|
391
|
+
self.set_requires_grad(self.netD55, True) # enable backprop for D54
|
|
392
|
+
|
|
393
|
+
self.optimizer_D.zero_grad() # set D's gradients to zero
|
|
394
|
+
self.backward_D() # calculate gradients for D
|
|
395
|
+
|
|
396
|
+
# update G
|
|
397
|
+
self.set_requires_grad(self.netD1, False) # D1 requires no gradients when optimizing G1
|
|
398
|
+
self.set_requires_grad(self.netD2, False) # D2 requires no gradients when optimizing G2
|
|
399
|
+
self.set_requires_grad(self.netD3, False) # D3 requires no gradients when optimizing G3
|
|
400
|
+
self.set_requires_grad(self.netD4, False) # D4 requires no gradients when optimizing G4
|
|
401
|
+
self.set_requires_grad(self.netD51, False) # D51 requires no gradients when optimizing G51
|
|
402
|
+
self.set_requires_grad(self.netD52, False) # D52 requires no gradients when optimizing G52
|
|
403
|
+
self.set_requires_grad(self.netD53, False) # D53 requires no gradients when optimizing G53
|
|
404
|
+
self.set_requires_grad(self.netD54, False) # D54 requires no gradients when optimizing G54
|
|
405
|
+
self.set_requires_grad(self.netD55, False) # D54 requires no gradients when optimizing G54
|
|
406
|
+
|
|
407
|
+
self.optimizer_G.zero_grad() # set G's gradients to zero
|
|
408
|
+
self.backward_G() # calculate graidents for G
|
|
409
|
+
|
|
@@ -265,12 +265,19 @@ def run_dask(img, model_path, eager_mode=False, opt=None, seg_only=False):
|
|
|
265
265
|
return model(input.to(next(model.parameters()).device))
|
|
266
266
|
|
|
267
267
|
if opt.model == 'DeepLIIF':
|
|
268
|
+
#weights = {
|
|
269
|
+
# 'G51': 0.25, # IHC
|
|
270
|
+
# 'G52': 0.25, # Hema
|
|
271
|
+
# 'G53': 0.25, # DAPI
|
|
272
|
+
# 'G54': 0.00, # Lap2
|
|
273
|
+
# 'G55': 0.25, # Marker
|
|
274
|
+
#}
|
|
268
275
|
weights = {
|
|
269
|
-
'G51': 0.
|
|
270
|
-
'G52': 0.
|
|
271
|
-
'G53': 0.
|
|
272
|
-
'G54': 0.
|
|
273
|
-
'G55': 0.
|
|
276
|
+
'G51': 0.5, # IHC
|
|
277
|
+
'G52': 0.0, # Hema
|
|
278
|
+
'G53': 0.0, # DAPI
|
|
279
|
+
'G54': 0.0, # Lap2
|
|
280
|
+
'G55': 0.5, # Marker
|
|
274
281
|
}
|
|
275
282
|
|
|
276
283
|
seg_map = {'G1': 'G52', 'G2': 'G53', 'G3': 'G54', 'G4': 'G55'}
|
|
@@ -330,8 +337,13 @@ def is_empty(tile):
|
|
|
330
337
|
return all([True if np.max(image_variance_rgb(t)) < thresh else False for t in tile])
|
|
331
338
|
else:
|
|
332
339
|
return True if np.max(image_variance_rgb(tile)) < thresh else False
|
|
340
|
+
#return True if image_variance_gray(tile) < thresh else False
|
|
341
|
+
#return True if image_variance_gray(tile.resize((128, 128), resample=Image.NEAREST)) < thresh else False
|
|
342
|
+
#return True if image_variance_gray(tile.convert('L').resize((128, 128), resample=Image.NEAREST)) < thresh else False
|
|
343
|
+
#return True if image_variance_gray(tile.convert('L').resize((64, 64), resample=Image.NEAREST)) < thresh else False
|
|
333
344
|
|
|
334
345
|
|
|
346
|
+
count_run_tiles = 0
|
|
335
347
|
def run_wrapper(tile, run_fn, model_path, eager_mode=False, opt=None, seg_only=False):
|
|
336
348
|
if opt.model == 'DeepLIIF':
|
|
337
349
|
if is_empty(tile):
|
|
@@ -354,6 +366,8 @@ def run_wrapper(tile, run_fn, model_path, eager_mode=False, opt=None, seg_only=F
|
|
|
354
366
|
'G55': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
|
|
355
367
|
}
|
|
356
368
|
else:
|
|
369
|
+
global count_run_tiles
|
|
370
|
+
count_run_tiles += 1
|
|
357
371
|
return run_fn(tile, model_path, eager_mode, opt, seg_only)
|
|
358
372
|
elif opt.model in ['DeepLIIFExt', 'SDG']:
|
|
359
373
|
if is_empty(tile):
|
|
@@ -648,6 +662,10 @@ def get_wsi_resolution(filename):
|
|
|
648
662
|
return None, None
|
|
649
663
|
|
|
650
664
|
|
|
665
|
+
from tifffile import TiffFile
|
|
666
|
+
import zarr
|
|
667
|
+
import time
|
|
668
|
+
|
|
651
669
|
def infer_cells_for_wsi(filename, model_dir, tile_size, region_size=20000, version=3, print_log=False):
|
|
652
670
|
"""
|
|
653
671
|
Perform inference on a slide and get the results individual cell data.
|
|
@@ -702,11 +720,23 @@ def infer_cells_for_wsi(filename, model_dir, tile_size, region_size=20000, versi
|
|
|
702
720
|
region_XYWH = (start_x, start_y, min(stride_x, size_x-start_x), min(stride_y, size_y-start_y))
|
|
703
721
|
print_info('Region:', region_XYWH)
|
|
704
722
|
|
|
723
|
+
tstart_read = time.time()
|
|
724
|
+
|
|
705
725
|
region = reader.read(XYWH=region_XYWH, rescale=rescale)
|
|
706
726
|
print_info(region.shape, region.dtype)
|
|
707
727
|
img = Image.fromarray((region * 255).astype(np.uint8)) if rescale else Image.fromarray(region)
|
|
708
728
|
print_info(img.size, img.mode)
|
|
709
729
|
|
|
730
|
+
#x, y, w, h = region_XYWH
|
|
731
|
+
#print(region_XYWH, x, y, w, h, flush=True)
|
|
732
|
+
#tif = TiffFile(filename)
|
|
733
|
+
#z = zarr.open(tif.pages[0].aszarr(), mode='r')
|
|
734
|
+
#img = Image.fromarray(z[y:y+h, x:x+w])
|
|
735
|
+
#print_info(img.size, img.mode)
|
|
736
|
+
|
|
737
|
+
tend_read = time.time()
|
|
738
|
+
print('Time to read region:', round(tend_read-tstart_read, 1), 'sec.', flush=True)
|
|
739
|
+
|
|
710
740
|
images = inference(
|
|
711
741
|
img,
|
|
712
742
|
tile_size=tile_size,
|
|
@@ -749,6 +779,8 @@ def infer_cells_for_wsi(filename, model_dir, tile_size, region_size=20000, versi
|
|
|
749
779
|
start_y += stride_y
|
|
750
780
|
|
|
751
781
|
javabridge.kill_vm()
|
|
782
|
+
global count_run_tiles
|
|
783
|
+
print('Num tiles run:', count_run_tiles, flush=True)
|
|
752
784
|
|
|
753
785
|
if count_marker_thresh == 0:
|
|
754
786
|
count_marker_thresh = 1
|