deepliif 1.1.12__tar.gz → 1.1.14__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. {deepliif-1.1.12/deepliif.egg-info → deepliif-1.1.14}/PKG-INFO +2 -2
  2. {deepliif-1.1.12 → deepliif-1.1.14}/README.md +1 -1
  3. {deepliif-1.1.12 → deepliif-1.1.14}/cli.py +15 -22
  4. {deepliif-1.1.12 → deepliif-1.1.14}/deepliif/data/aligned_dataset.py +2 -2
  5. deepliif-1.1.14/deepliif/models/DeepLIIFKD_model.py +409 -0
  6. deepliif-1.1.12/deepliif/models/__init__.py → deepliif-1.1.14/deepliif/models/__init__ - different weighted.py +33 -0
  7. deepliif-1.1.14/deepliif/models/__init__ - time gens.py +792 -0
  8. deepliif-1.1.14/deepliif/models/__init__ - weights, empty, zarr, tile count.py +792 -0
  9. deepliif-1.1.14/deepliif/models/__init__.py +817 -0
  10. {deepliif-1.1.12 → deepliif-1.1.14}/deepliif/models/base_model.py +1 -1
  11. {deepliif-1.1.12 → deepliif-1.1.14}/deepliif/models/networks.py +7 -5
  12. {deepliif-1.1.12 → deepliif-1.1.14}/deepliif/postprocessing.py +55 -24
  13. {deepliif-1.1.12 → deepliif-1.1.14}/deepliif/util/__init__.py +1 -1
  14. deepliif-1.1.14/deepliif/util/checks.py +17 -0
  15. {deepliif-1.1.12 → deepliif-1.1.14}/deepliif/util/util.py +42 -0
  16. {deepliif-1.1.12 → deepliif-1.1.14/deepliif.egg-info}/PKG-INFO +2 -2
  17. {deepliif-1.1.12 → deepliif-1.1.14}/deepliif.egg-info/SOURCES.txt +5 -0
  18. {deepliif-1.1.12 → deepliif-1.1.14}/deepliif.egg-info/requires.txt +2 -3
  19. {deepliif-1.1.12 → deepliif-1.1.14}/setup.cfg +1 -1
  20. {deepliif-1.1.12 → deepliif-1.1.14}/setup.py +3 -4
  21. {deepliif-1.1.12 → deepliif-1.1.14}/LICENSE.md +0 -0
  22. {deepliif-1.1.12 → deepliif-1.1.14}/deepliif/__init__.py +0 -0
  23. {deepliif-1.1.12 → deepliif-1.1.14}/deepliif/data/__init__.py +0 -0
  24. {deepliif-1.1.12 → deepliif-1.1.14}/deepliif/data/base_dataset.py +0 -0
  25. {deepliif-1.1.12 → deepliif-1.1.14}/deepliif/data/colorization_dataset.py +0 -0
  26. {deepliif-1.1.12 → deepliif-1.1.14}/deepliif/data/image_folder.py +0 -0
  27. {deepliif-1.1.12 → deepliif-1.1.14}/deepliif/data/single_dataset.py +0 -0
  28. {deepliif-1.1.12 → deepliif-1.1.14}/deepliif/data/template_dataset.py +0 -0
  29. {deepliif-1.1.12 → deepliif-1.1.14}/deepliif/data/unaligned_dataset.py +0 -0
  30. {deepliif-1.1.12 → deepliif-1.1.14}/deepliif/models/CycleGAN_model.py +0 -0
  31. {deepliif-1.1.12 → deepliif-1.1.14}/deepliif/models/DeepLIIFExt_model.py +0 -0
  32. {deepliif-1.1.12 → deepliif-1.1.14}/deepliif/models/DeepLIIF_model.py +0 -0
  33. {deepliif-1.1.12 → deepliif-1.1.14}/deepliif/models/SDG_model.py +0 -0
  34. {deepliif-1.1.12 → deepliif-1.1.14}/deepliif/models/__init__ - run_dask_multi dev.py +0 -0
  35. {deepliif-1.1.12 → deepliif-1.1.14}/deepliif/models/__init__ - timings.py +0 -0
  36. {deepliif-1.1.12 → deepliif-1.1.14}/deepliif/models/att_unet.py +0 -0
  37. {deepliif-1.1.12 → deepliif-1.1.14}/deepliif/options/__init__.py +0 -0
  38. {deepliif-1.1.12 → deepliif-1.1.14}/deepliif/options/base_options.py +0 -0
  39. {deepliif-1.1.12 → deepliif-1.1.14}/deepliif/options/processing_options.py +0 -0
  40. {deepliif-1.1.12 → deepliif-1.1.14}/deepliif/options/test_options.py +0 -0
  41. {deepliif-1.1.12 → deepliif-1.1.14}/deepliif/options/train_options.py +0 -0
  42. {deepliif-1.1.12 → deepliif-1.1.14}/deepliif/postprocessing__OLD__DELETE.py +0 -0
  43. {deepliif-1.1.12 → deepliif-1.1.14}/deepliif/util/get_data.py +0 -0
  44. {deepliif-1.1.12 → deepliif-1.1.14}/deepliif/util/html.py +0 -0
  45. {deepliif-1.1.12 → deepliif-1.1.14}/deepliif/util/image_pool.py +0 -0
  46. {deepliif-1.1.12 → deepliif-1.1.14}/deepliif/util/visualizer.py +0 -0
  47. {deepliif-1.1.12 → deepliif-1.1.14}/deepliif.egg-info/dependency_links.txt +0 -0
  48. {deepliif-1.1.12 → deepliif-1.1.14}/deepliif.egg-info/entry_points.txt +0 -0
  49. {deepliif-1.1.12 → deepliif-1.1.14}/deepliif.egg-info/top_level.txt +0 -0
  50. {deepliif-1.1.12 → deepliif-1.1.14}/tests/test_args.py +0 -0
  51. {deepliif-1.1.12 → deepliif-1.1.14}/tests/test_cli_inference.py +0 -0
  52. {deepliif-1.1.12 → deepliif-1.1.14}/tests/test_cli_serialize.py +0 -0
  53. {deepliif-1.1.12 → deepliif-1.1.14}/tests/test_cli_train.py +0 -0
  54. {deepliif-1.1.12 → deepliif-1.1.14}/tests/test_cli_trainlaunch.py +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
  Metadata-Version: 2.1
  Name: deepliif
- Version: 1.1.12
+ Version: 1.1.14
  Summary: DeepLIIF: Deep-Learning Inferred Multiplex Immunofluorescence for Immunohistochemical Image Quantification
  Home-page: https://github.com/nadeemlab/DeepLIIF
  Author: Parmida93
@@ -53,7 +53,7 @@ segmentation.*

  © This code is made available for non-commercial academic purposes.

- ![Version](https://img.shields.io/static/v1?label=latest&message=v1.1.12&color=darkgreen)
+ ![Version](https://img.shields.io/static/v1?label=latest&message=v1.1.14&color=darkgreen)
  [![Total Downloads](https://static.pepy.tech/personalized-badge/deepliif?period=total&units=international_system&left_color=grey&right_color=blue&left_text=total%20downloads)](https://pepy.tech/project/deepliif?&left_text=totalusers)

  ![overview_image](./images/overview.png)*Overview of DeepLIIF pipeline and sample input IHCs (different
README.md
@@ -42,7 +42,7 @@ segmentation.*

  © This code is made available for non-commercial academic purposes.

- ![Version](https://img.shields.io/static/v1?label=latest&message=v1.1.12&color=darkgreen)
+ ![Version](https://img.shields.io/static/v1?label=latest&message=v1.1.14&color=darkgreen)
  [![Total Downloads](https://static.pepy.tech/personalized-badge/deepliif?period=total&units=international_system&left_color=grey&right_color=blue&left_text=total%20downloads)](https://pepy.tech/project/deepliif?&left_text=totalusers)

  ![overview_image](./images/overview.png)*Overview of DeepLIIF pipeline and sample input IHCs (different
cli.py
@@ -14,6 +14,7 @@ from deepliif.data import create_dataset, transform
  from deepliif.models import init_nets, infer_modalities, infer_results_for_wsi, create_model, postprocess
  from deepliif.util import allowed_file, Visualizer, get_information, test_diff_original_serialized, disable_batchnorm_tracking_stats
  from deepliif.util.util import mkdirs
+ from deepliif.util.checks import check_weights
  # from deepliif.util import infer_results_for_wsi
  from deepliif.options import Options, print_options

@@ -78,6 +79,7 @@ def cli():
  @click.option('--modalities-no', default=4, type=int, help='number of targets')
  # model parameters
  @click.option('--model', default='DeepLIIF', help='name of model class')
+ @click.option('--model-dir-teacher', default='', help='the directory of the teacher model, only applicable if model is DeepLIIFKD')
  @click.option('--seg-weights', default='', type=str, help='weights used to aggregate modality images for the final segmentation image; numbers should add up to 1, and each number corresponds to the modality in order; example: 0.25,0.15,0.25,0.1,0.25')
  @click.option('--loss-weights-g', default='', type=str, help='weights used to aggregate modality-wise losses for the final loss; numbers should add up to 1, and each number corresponds to the modality in order; example: 0.2,0.2,0.2,0.2,0.2')
  @click.option('--loss-weights-d', default='', type=str, help='weights used to aggregate modality-wise losses for the final loss; numbers should add up to 1, and each number corresponds to the modality in order; example: 0.2,0.2,0.2,0.2,0.2')
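
Each of the weight flags above takes a comma-separated string; the training command later splits it into floats (the same `[float(x) for x in ...split(',')]` pattern visible further down in this diff). A minimal round-trip sketch, with an illustrative value:

    # Sketch only: mirrors the parsing used later in cli.py for --seg-weights.
    seg_weights_flag = '0.25,0.15,0.25,0.1,0.25'   # value passed on the command line
    seg_weights = [float(x) for x in seg_weights_flag.split(',')]
    assert abs(sum(seg_weights) - 1.0) < 1e-9      # the help text requires the weights to add up to 1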
@@ -193,7 +195,8 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd
            verbose, lambda_l1, is_train, display_freq, display_ncols, display_id, display_server, display_env,
            display_port, update_html_freq, print_freq, no_html, save_latest_freq, save_epoch_freq, save_by_iter,
            continue_train, epoch_count, phase, lr_policy, n_epochs, n_epochs_decay, optimizer, beta1, lr_g, lr_d, lr_decay_iters,
-           remote, remote_transfer_cmd, seed, dataset_mode, padding, model, seg_weights, loss_weights_g, loss_weights_d,
+           remote, remote_transfer_cmd, seed, dataset_mode, padding, model, model_dir_teacher,
+           seg_weights, loss_weights_g, loss_weights_d,
            modalities_no, seg_gen, net_ds, net_gs, gan_mode, gan_mode_s, local_rank, with_val, debug, debug_data_size):
      """General-purpose training script for multi-task image-to-image translation.

@@ -206,8 +209,8 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd
      plot, and save models. The script supports continue/resume training.
      Use '--continue_train' to resume your previous training.
      """
-     assert model in ['DeepLIIF','DeepLIIFExt','SDG','CycleGAN'], f'model class {model} is not implemented'
-     if model == 'DeepLIIF':
+     assert model in ['DeepLIIF','DeepLIIFExt','SDG','CycleGAN','DeepLIIFKD'], f'model class {model} is not implemented'
+     if model in ['DeepLIIF','DeepLIIFKD']:
          seg_no = 1
      elif model == 'DeepLIIFExt':
          if seg_gen:
@@ -221,6 +224,9 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd
      if model == 'CycleGAN':
          dataset_mode = "unaligned"

+     if model == 'DeepLIIFKD':
+         assert len(model_dir_teacher) > 0 and os.path.isdir(model_dir_teacher), f'Teacher model directory {model_dir_teacher} is not valid.'
+
      if optimizer != 'adam':
          print(f'Optimizer torch.optim.{optimizer} is not tested. Be careful about the parameters of the optimizer.')

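Once validated, the teacher directory is consumed inside the new DeepLIIFKD_model.py (shown below) through get_opt and init_nets. A minimal loading sketch, assuming the call signatures that appear in that file; the path is illustrative:

    # Sketch: how the validated teacher directory is used (per DeepLIIFKD_model.py below).
    from deepliif.models import get_opt, init_nets

    model_dir_teacher = './checkpoints/DeepLIIF_teacher'  # hypothetical path
    opt_teacher = get_opt(model_dir_teacher, mode='test')
    nets_teacher = init_nets(model_dir_teacher, eager_mode=True, opt=opt_teacher, phase='test')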
@@ -310,7 +316,7 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd

      net_gs = net_gs.split(',')
      assert len(net_gs) in [1,seg_no], f'net_gs should contain either 1 architecture for all segmentation generators or the same number of architectures as the number of segmentation generators ({seg_no})'
-     if len(net_gs) == 1 and model == 'DeepLIIF':
+     if len(net_gs) == 1 and model in ['DeepLIIF','DeepLIIFKD']:
          net_gs = net_gs*(modalities_no + seg_no)
      elif len(net_gs) == 1:
          net_gs = net_gs*seg_no
@@ -320,34 +326,21 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd

      # check seg weights and loss weights
      if len(d_params['seg_weights']) == 0:
-         seg_weights = [0.25,0.15,0.25,0.1,0.25] if d_params['model'] == 'DeepLIIF' else [1 / modalities_no] * modalities_no
+         seg_weights = [0.25,0.15,0.25,0.1,0.25] if d_params['model'] in ['DeepLIIF','DeepLIIFKD'] else [1 / modalities_no] * modalities_no
      else:
          seg_weights = [float(x) for x in seg_weights.split(',')]

      if len(d_params['loss_weights_g']) == 0:
-         loss_weights_g = [0.2]*5 if d_params['model'] == 'DeepLIIF' else [1 / modalities_no] * modalities_no
+         loss_weights_g = [0.2]*5 if d_params['model'] in ['DeepLIIF','DeepLIIFKD'] else [1 / modalities_no] * modalities_no
      else:
          loss_weights_g = [float(x) for x in loss_weights_g.split(',')]

      if len(d_params['loss_weights_d']) == 0:
-         loss_weights_d = [0.2]*5 if d_params['model'] == 'DeepLIIF' else [1 / modalities_no] * modalities_no
+         loss_weights_d = [0.2]*5 if d_params['model'] in ['DeepLIIF','DeepLIIFKD'] else [1 / modalities_no] * modalities_no
      else:
          loss_weights_d = [float(x) for x in loss_weights_d.split(',')]

-     assert sum(seg_weights) == 1, 'seg weights should add up to 1'
-     assert sum(loss_weights_g) == 1, 'loss weights g should add up to 1'
-     assert sum(loss_weights_d) == 1, 'loss weights d should add up to 1'
-
-     if model == 'DeepLIIF':
-         # +1 because input becomes an additional modality used in generating the final segmentation
-         assert len(seg_weights) == modalities_no+1, 'seg weights should have the same number of elements as number of modalities to be generated'
-         assert len(loss_weights_g) == modalities_no+1, 'loss weights g should have the same number of elements as number of modalities to be generated'
-         assert len(loss_weights_d) == modalities_no+1, 'loss weights d should have the same number of elements as number of modalities to be generated'
-
-     else:
-         assert len(seg_weights) == modalities_no, 'seg weights should have the same number of elements as number of modalities to be generated'
-         assert len(loss_weights_g) == modalities_no, 'loss weights g should have the same number of elements as number of modalities to be generated'
-         assert len(loss_weights_d) == modalities_no, 'loss weights d should have the same number of elements as number of modalities to be generated'
+     check_weights(model, modalities_no, seg_weights, loss_weights_g, loss_weights_d)

      d_params['seg_weights'] = seg_weights
      d_params['loss_G_weights'] = loss_weights_g
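
The assertions deleted above are consolidated into the new deepliif/util/checks.py (17 added lines per the file list), whose body is not shown in this diff. A plausible reconstruction from the deleted assertions, with DeepLIIFKD treated like DeepLIIF, might be:

    # Hypothetical reconstruction of check_weights; the real implementation
    # lives in deepliif/util/checks.py and is not shown in this diff.
    def check_weights(model, modalities_no, seg_weights, loss_weights_g, loss_weights_d):
        assert sum(seg_weights) == 1, 'seg weights should add up to 1'
        assert sum(loss_weights_g) == 1, 'loss weights g should add up to 1'
        assert sum(loss_weights_d) == 1, 'loss weights d should add up to 1'
        # for DeepLIIF/DeepLIIFKD the input acts as one more modality in the final segmentation
        expected = modalities_no + 1 if model in ['DeepLIIF', 'DeepLIIFKD'] else modalities_no
        for name, weights in [('seg weights', seg_weights),
                              ('loss weights g', loss_weights_g),
                              ('loss weights d', loss_weights_d)]:
            assert len(weights) == expected, f'{name} should have the same number of elements as number of modalities to be generated'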
@@ -373,7 +366,7 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd
      data_val = [batch for batch in dataset_val]
      click.echo('The number of validation images = %d' % len(dataset_val))

-     if model in ['DeepLIIF']:
+     if model in ['DeepLIIF', 'DeepLIIFKD']:
          metrics_val = json.load(open(os.path.join(dataset_val.dataset.dir_AB,'metrics.json')))

      # create a model given model and other options
deepliif/data/aligned_dataset.py
@@ -50,7 +50,7 @@ class AlignedDataset(BaseDataset):
          AB = Image.open(AB_path).convert('RGB')
          # split AB image into A and B
          w, h = AB.size
-         if self.model == 'DeepLIIF':
+         if self.model in ['DeepLIIF','DeepLIIFKD']:
              num_img = self.modalities_no + 1 + 1  # +1 for segmentation channel, +1 for input image
          elif self.model == 'DeepLIIFExt':
              num_img = self.modalities_no * 2 + 1 if self.seg_gen else self.modalities_no + 1  # +1 for segmentation channel
@@ -68,7 +68,7 @@ class AlignedDataset(BaseDataset):

          A = A_transform(A)
          B_Array = []
-         if self.model == 'DeepLIIF':
+         if self.model in ['DeepLIIF','DeepLIIFKD']:
              for i in range(1, num_img):
                  B = AB.crop((w2 * i, 0, w2 * (i + 1), h))
                  B = B_transform(B)
deepliif/models/DeepLIIFKD_model.py (new file)
@@ -0,0 +1,409 @@
+ import torch
+ from .base_model import BaseModel
+ from . import networks
+ from .networks import get_optimizer
+ from . import init_nets, run_dask, get_opt
+ from torch import nn
+
+ class DeepLIIFKDModel(BaseModel):
+     """This class implements the DeepLIIFKD model, for learning a mapping from input images to modalities given paired data, with knowledge distillation from a teacher DeepLIIF model."""
+
+     def __init__(self, opt):
+         """Initialize the DeepLIIFKD class.
+
+         Parameters:
+             opt (Option class) -- stores all the experiment flags; needs to be a subclass of BaseOptions
+         """
+         BaseModel.__init__(self, opt)
+         if not hasattr(opt, 'net_gs'):
+             opt.net_gs = 'unet_512'
+
+         self.seg_weights = opt.seg_weights
+         self.loss_G_weights = opt.loss_G_weights
+         self.loss_D_weights = opt.loss_D_weights
+
+         if not opt.is_train:
+             self.gpu_ids = []  # avoid the models being loaded as DP
+         else:
+             self.gpu_ids = opt.gpu_ids
+
+         self.loss_names = []
+         self.visual_names = ['real_A']
+         # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
+         for i in range(1, self.opt.modalities_no + 1 + 1):
+             self.loss_names.extend([f'G_GAN_{i}', f'G_L1_{i}', f'D_real_{i}', f'D_fake_{i}', f'G_KLDiv_{i}', f'G_KLDiv_5_{i}'])
+             self.visual_names.extend([f'fake_B_{i}', f'fake_B_5_{i}', f'fake_B_{i}_teacher', f'fake_B_5_{i}_teacher', f'real_B_{i}'])
+
+         # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
+         # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
+         if self.is_train:
+             self.model_names = []
+             for i in range(1, self.opt.modalities_no + 1):
+                 self.model_names.extend(['G' + str(i), 'D' + str(i)])
+
+             for i in range(1, self.opt.modalities_no + 1 + 1):
+                 self.model_names.extend(['G5' + str(i), 'D5' + str(i)])
+         else:  # during test time, only load G
+             self.model_names = []
+             for i in range(1, self.opt.modalities_no + 1):
+                 self.model_names.extend(['G' + str(i)])
+
+             for i in range(1, self.opt.modalities_no + 1 + 1):
+                 self.model_names.extend(['G5' + str(i)])
+
+         # define networks (both generator and discriminator)
+         if isinstance(opt.netG, str):
+             opt.netG = [opt.netG] * 4
+         if isinstance(opt.net_gs, str):
+             opt.net_gs = [opt.net_gs] * 5
+
+         self.netG1 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG[0], opt.norm,
+                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.padding)
+         self.netG2 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG[1], opt.norm,
+                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.padding)
+         self.netG3 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG[2], opt.norm,
+                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.padding)
+         self.netG4 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG[3], opt.norm,
+                                        not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.padding)
+
+         # segmentation generators: one per input to the final mask; net_gs supplies an architecture for each
+         self.netG51 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[0], opt.norm,
+                                         not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
+         self.netG52 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[1], opt.norm,
+                                         not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
+         self.netG53 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[2], opt.norm,
+                                         not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
+         self.netG54 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[3], opt.norm,
+                                         not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
+         self.netG55 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[4], opt.norm,
+                                         not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
+
+         if self.is_train:  # define a discriminator; conditional GANs need to take both input and output images; therefore, #channels for D is input_nc + output_nc
+             self.netD1 = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
+                                            opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
+             self.netD2 = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
+                                            opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
+             self.netD3 = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
+                                            opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
+             self.netD4 = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
+                                            opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
+
+             self.netD51 = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
+                                             opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
+             self.netD52 = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
+                                             opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
+             self.netD53 = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
+                                             opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
+             self.netD54 = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
+                                             opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
+             self.netD55 = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
+                                             opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
+
+         # load the teacher model
+         self.opt_teacher = get_opt(opt.model_dir_teacher, mode='test')
+         self.opt_teacher.gpu_ids = opt.gpu_ids  # use student's gpu_ids
+         self.nets_teacher = init_nets(opt.model_dir_teacher, eager_mode=True, opt=self.opt_teacher, phase='test')
+
+         if self.is_train:
+             # define loss functions
+             self.criterionGAN_BCE = networks.GANLoss('vanilla').to(self.device)
+             self.criterionGAN_lsgan = networks.GANLoss('lsgan').to(self.device)
+             self.criterionSmoothL1 = torch.nn.SmoothL1Loss()
+
+             # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
+             params = list(self.netG1.parameters()) + list(self.netG2.parameters()) + list(self.netG3.parameters()) + list(self.netG4.parameters()) + list(self.netG51.parameters()) + list(self.netG52.parameters()) + list(self.netG53.parameters()) + list(self.netG54.parameters()) + list(self.netG55.parameters())
+             try:
+                 self.optimizer_G = get_optimizer(opt.optimizer)(params, lr=opt.lr_g, betas=(opt.beta1, 0.999))
+             except:
+                 print(f'betas are not used for optimizer torch.optim.{opt.optimizer} in generators')
+                 self.optimizer_G = get_optimizer(opt.optimizer)(params, lr=opt.lr_g)
+
+             params = list(self.netD1.parameters()) + list(self.netD2.parameters()) + list(self.netD3.parameters()) + list(self.netD4.parameters()) + list(self.netD51.parameters()) + list(self.netD52.parameters()) + list(self.netD53.parameters()) + list(self.netD54.parameters()) + list(self.netD55.parameters())
+             try:
+                 self.optimizer_D = get_optimizer(opt.optimizer)(params, lr=opt.lr_d, betas=(opt.beta1, 0.999))
+             except:
+                 print(f'betas are not used for optimizer torch.optim.{opt.optimizer} in discriminators')
+                 self.optimizer_D = get_optimizer(opt.optimizer)(params, lr=opt.lr_d)
+
+             self.optimizers.append(self.optimizer_G)
+             self.optimizers.append(self.optimizer_D)
+
+             self.criterionVGG = networks.VGGLoss().to(self.device)
+             self.criterionKLDiv = torch.nn.KLDivLoss(reduction='batchmean').to(self.device)
+             self.softmax = torch.nn.Softmax(dim=-1).to(self.device)  # apply softmax on the last dim
+             self.logsoftmax = torch.nn.LogSoftmax(dim=-1).to(self.device)  # apply log-softmax on the last dim
+
+     def set_input(self, input):
+         """
+         Unpack input data from the dataloader and perform necessary pre-processing steps.
+
+         :param input (dict): includes the input image and the output modalities
+         """
+         self.real_A = input['A'].to(self.device)
+
+         self.real_B_array = input['B']
+         self.real_B_1 = self.real_B_array[0].to(self.device)
+         self.real_B_2 = self.real_B_array[1].to(self.device)
+         self.real_B_3 = self.real_B_array[2].to(self.device)
+         self.real_B_4 = self.real_B_array[3].to(self.device)
+         self.real_B_5 = self.real_B_array[4].to(self.device)
+         self.image_paths = input['A_paths']
+
+     def forward(self):
+         """Run forward pass; called by both functions <optimize_parameters> and <test>."""
+         self.fake_B_1 = self.netG1(self.real_A)  # Hematoxylin image generator
+         self.fake_B_2 = self.netG2(self.real_A)  # mpIF DAPI image generator
+         self.fake_B_3 = self.netG3(self.real_A)  # mpIF Lap2 image generator
+         self.fake_B_4 = self.netG4(self.real_A)  # mpIF Ki67 image generator
+
+         self.fake_B_5_1 = self.netG51(self.real_A)    # Segmentation mask generator from IHC input image
+         self.fake_B_5_2 = self.netG52(self.fake_B_1)  # Segmentation mask generator from Hematoxylin input image
+         self.fake_B_5_3 = self.netG53(self.fake_B_2)  # Segmentation mask generator from mpIF DAPI input image
+         self.fake_B_5_4 = self.netG54(self.fake_B_3)  # Segmentation mask generator from mpIF Lap2 input image
+         self.fake_B_5_5 = self.netG55(self.fake_B_4)  # Segmentation mask generator from mpIF Ki67 input image
+         self.fake_B_5 = torch.stack([torch.mul(self.fake_B_5_1, self.seg_weights[0]),
+                                      torch.mul(self.fake_B_5_2, self.seg_weights[1]),
+                                      torch.mul(self.fake_B_5_3, self.seg_weights[2]),
+                                      torch.mul(self.fake_B_5_4, self.seg_weights[3]),
+                                      torch.mul(self.fake_B_5_5, self.seg_weights[4])]).sum(dim=0)
+
+         fakes_teacher = run_dask(img=self.real_A, nets=self.nets_teacher, opt=self.opt_teacher, use_dask=False, output_tensor=True)
+         for k, v in fakes_teacher.items():
+             suffix = k[1:]  # keys start with G
+             suffix = '_'.join(list(suffix))  # 51 -> 5_1
+             setattr(self, f'fake_B_{suffix}_teacher', v)
+
+     def backward_D(self):
+         """Calculate GAN loss for the discriminators"""
+         fake_AB_1 = torch.cat((self.real_A, self.fake_B_1), 1)  # Conditional GANs; feed IHC input and Hematoxylin output to the discriminator
+         fake_AB_2 = torch.cat((self.real_A, self.fake_B_2), 1)  # Conditional GANs; feed IHC input and mpIF DAPI output to the discriminator
+         fake_AB_3 = torch.cat((self.real_A, self.fake_B_3), 1)  # Conditional GANs; feed IHC input and mpIF Lap2 output to the discriminator
+         fake_AB_4 = torch.cat((self.real_A, self.fake_B_4), 1)  # Conditional GANs; feed IHC input and mpIF Ki67 output to the discriminator
+
+         pred_fake_1 = self.netD1(fake_AB_1.detach())
+         pred_fake_2 = self.netD2(fake_AB_2.detach())
+         pred_fake_3 = self.netD3(fake_AB_3.detach())
+         pred_fake_4 = self.netD4(fake_AB_4.detach())
+
+         fake_AB_5_1 = torch.cat((self.real_A, self.fake_B_5), 1)    # Conditional GANs; feed IHC input and Segmentation mask output to the discriminator
+         fake_AB_5_2 = torch.cat((self.real_B_1, self.fake_B_5), 1)  # Conditional GANs; feed Hematoxylin input and Segmentation mask output to the discriminator
+         fake_AB_5_3 = torch.cat((self.real_B_2, self.fake_B_5), 1)  # Conditional GANs; feed mpIF DAPI input and Segmentation mask output to the discriminator
+         fake_AB_5_4 = torch.cat((self.real_B_3, self.fake_B_5), 1)  # Conditional GANs; feed mpIF Lap2 input and Segmentation mask output to the discriminator
+         fake_AB_5_5 = torch.cat((self.real_B_4, self.fake_B_5), 1)  # Conditional GANs; feed mpIF Ki67 input and Segmentation mask output to the discriminator
+
+         pred_fake_5_1 = self.netD51(fake_AB_5_1.detach())
+         pred_fake_5_2 = self.netD52(fake_AB_5_2.detach())
+         pred_fake_5_3 = self.netD53(fake_AB_5_3.detach())
+         pred_fake_5_4 = self.netD54(fake_AB_5_4.detach())
+         pred_fake_5_5 = self.netD55(fake_AB_5_5.detach())
+
+         pred_fake_5 = torch.stack(
+             [torch.mul(pred_fake_5_1, self.seg_weights[0]),
+              torch.mul(pred_fake_5_2, self.seg_weights[1]),
+              torch.mul(pred_fake_5_3, self.seg_weights[2]),
+              torch.mul(pred_fake_5_4, self.seg_weights[3]),
+              torch.mul(pred_fake_5_5, self.seg_weights[4])]).sum(dim=0)
+
+         self.loss_D_fake_1 = self.criterionGAN_BCE(pred_fake_1, False)
+         self.loss_D_fake_2 = self.criterionGAN_BCE(pred_fake_2, False)
+         self.loss_D_fake_3 = self.criterionGAN_BCE(pred_fake_3, False)
+         self.loss_D_fake_4 = self.criterionGAN_BCE(pred_fake_4, False)
+         self.loss_D_fake_5 = self.criterionGAN_lsgan(pred_fake_5, False)
+
+         real_AB_1 = torch.cat((self.real_A, self.real_B_1), 1)
+         real_AB_2 = torch.cat((self.real_A, self.real_B_2), 1)
+         real_AB_3 = torch.cat((self.real_A, self.real_B_3), 1)
+         real_AB_4 = torch.cat((self.real_A, self.real_B_4), 1)
+
+         pred_real_1 = self.netD1(real_AB_1)
+         pred_real_2 = self.netD2(real_AB_2)
+         pred_real_3 = self.netD3(real_AB_3)
+         pred_real_4 = self.netD4(real_AB_4)
+
+         real_AB_5_1 = torch.cat((self.real_A, self.real_B_5), 1)
+         real_AB_5_2 = torch.cat((self.real_B_1, self.real_B_5), 1)
+         real_AB_5_3 = torch.cat((self.real_B_2, self.real_B_5), 1)
+         real_AB_5_4 = torch.cat((self.real_B_3, self.real_B_5), 1)
+         real_AB_5_5 = torch.cat((self.real_B_4, self.real_B_5), 1)
+
+         pred_real_5_1 = self.netD51(real_AB_5_1)
+         pred_real_5_2 = self.netD52(real_AB_5_2)
+         pred_real_5_3 = self.netD53(real_AB_5_3)
+         pred_real_5_4 = self.netD54(real_AB_5_4)
+         pred_real_5_5 = self.netD55(real_AB_5_5)
+
+         pred_real_5 = torch.stack(
+             [torch.mul(pred_real_5_1, self.seg_weights[0]),
+              torch.mul(pred_real_5_2, self.seg_weights[1]),
+              torch.mul(pred_real_5_3, self.seg_weights[2]),
+              torch.mul(pred_real_5_4, self.seg_weights[3]),
+              torch.mul(pred_real_5_5, self.seg_weights[4])]).sum(dim=0)
+
+         self.loss_D_real_1 = self.criterionGAN_BCE(pred_real_1, True)
+         self.loss_D_real_2 = self.criterionGAN_BCE(pred_real_2, True)
+         self.loss_D_real_3 = self.criterionGAN_BCE(pred_real_3, True)
+         self.loss_D_real_4 = self.criterionGAN_BCE(pred_real_4, True)
+         self.loss_D_real_5 = self.criterionGAN_lsgan(pred_real_5, True)
+
+         # combine losses and calculate gradients
+         self.loss_D = (self.loss_D_fake_1 + self.loss_D_real_1) * 0.5 * self.loss_D_weights[0] + \
+                       (self.loss_D_fake_2 + self.loss_D_real_2) * 0.5 * self.loss_D_weights[1] + \
+                       (self.loss_D_fake_3 + self.loss_D_real_3) * 0.5 * self.loss_D_weights[2] + \
+                       (self.loss_D_fake_4 + self.loss_D_real_4) * 0.5 * self.loss_D_weights[3] + \
+                       (self.loss_D_fake_5 + self.loss_D_real_5) * 0.5 * self.loss_D_weights[4]
+
+         self.loss_D.backward()
+
+     def backward_G(self):
+         """Calculate GAN, L1, VGG and KL-divergence (distillation) losses for the generators"""
+
+         fake_AB_1 = torch.cat((self.real_A, self.fake_B_1), 1)
+         fake_AB_2 = torch.cat((self.real_A, self.fake_B_2), 1)
+         fake_AB_3 = torch.cat((self.real_A, self.fake_B_3), 1)
+         fake_AB_4 = torch.cat((self.real_A, self.fake_B_4), 1)
+
+         fake_AB_5_1 = torch.cat((self.real_A, self.fake_B_5), 1)
+         fake_AB_5_2 = torch.cat((self.real_B_1, self.fake_B_5), 1)
+         fake_AB_5_3 = torch.cat((self.real_B_2, self.fake_B_5), 1)
+         fake_AB_5_4 = torch.cat((self.real_B_3, self.fake_B_5), 1)
+         fake_AB_5_5 = torch.cat((self.real_B_4, self.fake_B_5), 1)
+
+         pred_fake_1 = self.netD1(fake_AB_1)
+         pred_fake_2 = self.netD2(fake_AB_2)
+         pred_fake_3 = self.netD3(fake_AB_3)
+         pred_fake_4 = self.netD4(fake_AB_4)
+
+         pred_fake_5_1 = self.netD51(fake_AB_5_1)
+         pred_fake_5_2 = self.netD52(fake_AB_5_2)
+         pred_fake_5_3 = self.netD53(fake_AB_5_3)
+         pred_fake_5_4 = self.netD54(fake_AB_5_4)
+         pred_fake_5_5 = self.netD55(fake_AB_5_5)
+         pred_fake_5 = torch.stack(
+             [torch.mul(pred_fake_5_1, self.seg_weights[0]),
+              torch.mul(pred_fake_5_2, self.seg_weights[1]),
+              torch.mul(pred_fake_5_3, self.seg_weights[2]),
+              torch.mul(pred_fake_5_4, self.seg_weights[3]),
+              torch.mul(pred_fake_5_5, self.seg_weights[4])]).sum(dim=0)
+
+         self.loss_G_GAN_1 = self.criterionGAN_BCE(pred_fake_1, True)
+         self.loss_G_GAN_2 = self.criterionGAN_BCE(pred_fake_2, True)
+         self.loss_G_GAN_3 = self.criterionGAN_BCE(pred_fake_3, True)
+         self.loss_G_GAN_4 = self.criterionGAN_BCE(pred_fake_4, True)
+         self.loss_G_GAN_5 = self.criterionGAN_lsgan(pred_fake_5, True)
+
+         # Second, G(A) = B
+         self.loss_G_L1_1 = self.criterionSmoothL1(self.fake_B_1, self.real_B_1) * self.opt.lambda_L1
+         self.loss_G_L1_2 = self.criterionSmoothL1(self.fake_B_2, self.real_B_2) * self.opt.lambda_L1
+         self.loss_G_L1_3 = self.criterionSmoothL1(self.fake_B_3, self.real_B_3) * self.opt.lambda_L1
+         self.loss_G_L1_4 = self.criterionSmoothL1(self.fake_B_4, self.real_B_4) * self.opt.lambda_L1
+         self.loss_G_L1_5 = self.criterionSmoothL1(self.fake_B_5, self.real_B_5) * self.opt.lambda_L1
+
+         self.loss_G_VGG_1 = self.criterionVGG(self.fake_B_1, self.real_B_1) * self.opt.lambda_feat
+         self.loss_G_VGG_2 = self.criterionVGG(self.fake_B_2, self.real_B_2) * self.opt.lambda_feat
+         self.loss_G_VGG_3 = self.criterionVGG(self.fake_B_3, self.real_B_3) * self.opt.lambda_feat
+         self.loss_G_VGG_4 = self.criterionVGG(self.fake_B_4, self.real_B_4) * self.opt.lambda_feat
+
+         # .view(1,1,-1) reshapes the input (batch_size, 3, 512, 512) to (batch_size, 1, 3*512*512)
+         # softmax/log-softmax is then applied on the concatenated vector of size (1, 3*512*512)
+         # this normalizes the pixel values across all 3 RGB channels
+         # the resulting vectors are then used to compute KL divergence loss
+         self.loss_G_KLDiv_1 = self.criterionKLDiv(self.logsoftmax(self.fake_B_1.view(1,1,-1)), self.softmax(self.fake_B_1_teacher.view(1,1,-1)))
+         self.loss_G_KLDiv_2 = self.criterionKLDiv(self.logsoftmax(self.fake_B_2.view(1,1,-1)), self.softmax(self.fake_B_2_teacher.view(1,1,-1)))
+         self.loss_G_KLDiv_3 = self.criterionKLDiv(self.logsoftmax(self.fake_B_3.view(1,1,-1)), self.softmax(self.fake_B_3_teacher.view(1,1,-1)))
+         self.loss_G_KLDiv_4 = self.criterionKLDiv(self.logsoftmax(self.fake_B_4.view(1,1,-1)), self.softmax(self.fake_B_4_teacher.view(1,1,-1)))
+         self.loss_G_KLDiv_5 = self.criterionKLDiv(self.logsoftmax(self.fake_B_5.view(1,1,-1)), self.softmax(self.fake_B_5_teacher.view(1,1,-1)))
+         self.loss_G_KLDiv_5_1 = self.criterionKLDiv(self.logsoftmax(self.fake_B_5_1.view(1,1,-1)), self.softmax(self.fake_B_5_1_teacher.view(1,1,-1)))
+         self.loss_G_KLDiv_5_2 = self.criterionKLDiv(self.logsoftmax(self.fake_B_5_2.view(1,1,-1)), self.softmax(self.fake_B_5_2_teacher.view(1,1,-1)))
+         self.loss_G_KLDiv_5_3 = self.criterionKLDiv(self.logsoftmax(self.fake_B_5_3.view(1,1,-1)), self.softmax(self.fake_B_5_3_teacher.view(1,1,-1)))
+         self.loss_G_KLDiv_5_4 = self.criterionKLDiv(self.logsoftmax(self.fake_B_5_4.view(1,1,-1)), self.softmax(self.fake_B_5_4_teacher.view(1,1,-1)))
+         self.loss_G_KLDiv_5_5 = self.criterionKLDiv(self.logsoftmax(self.fake_B_5_5.view(1,1,-1)), self.softmax(self.fake_B_5_5_teacher.view(1,1,-1)))
+
+         self.loss_G = (self.loss_G_GAN_1 + self.loss_G_L1_1 + self.loss_G_VGG_1) * self.loss_G_weights[0] + \
+                       (self.loss_G_GAN_2 + self.loss_G_L1_2 + self.loss_G_VGG_2) * self.loss_G_weights[1] + \
+                       (self.loss_G_GAN_3 + self.loss_G_L1_3 + self.loss_G_VGG_3) * self.loss_G_weights[2] + \
+                       (self.loss_G_GAN_4 + self.loss_G_L1_4 + self.loss_G_VGG_4) * self.loss_G_weights[3] + \
+                       (self.loss_G_GAN_5 + self.loss_G_L1_5) * self.loss_G_weights[4] + \
+                       (self.loss_G_KLDiv_1 + self.loss_G_KLDiv_2 + self.loss_G_KLDiv_3 + self.loss_G_KLDiv_4 +
+                        self.loss_G_KLDiv_5 + self.loss_G_KLDiv_5_1 + self.loss_G_KLDiv_5_2 + self.loss_G_KLDiv_5_3 +
+                        self.loss_G_KLDiv_5_4 + self.loss_G_KLDiv_5_5) * 10
+
+         # combine loss and calculate gradients
+         # self.loss_G = (self.loss_G_GAN_1 + self.loss_G_L1_1) * self.loss_G_weights[0] + \
+         #               (self.loss_G_GAN_2 + self.loss_G_L1_2) * self.loss_G_weights[1] + \
+         #               (self.loss_G_GAN_3 + self.loss_G_L1_3) * self.loss_G_weights[2] + \
+         #               (self.loss_G_GAN_4 + self.loss_G_L1_4) * self.loss_G_weights[3] + \
+         #               (self.loss_G_GAN_5 + self.loss_G_L1_5) * self.loss_G_weights[4]
+         self.loss_G.backward()
+
+     def optimize_parameters(self):
+         self.forward()  # compute fake images: G(A)
+         # update D
+         self.set_requires_grad(self.netD1, True)   # enable backprop for D1
+         self.set_requires_grad(self.netD2, True)   # enable backprop for D2
+         self.set_requires_grad(self.netD3, True)   # enable backprop for D3
+         self.set_requires_grad(self.netD4, True)   # enable backprop for D4
+         self.set_requires_grad(self.netD51, True)  # enable backprop for D51
+         self.set_requires_grad(self.netD52, True)  # enable backprop for D52
+         self.set_requires_grad(self.netD53, True)  # enable backprop for D53
+         self.set_requires_grad(self.netD54, True)  # enable backprop for D54
+         self.set_requires_grad(self.netD55, True)  # enable backprop for D55
+
+         self.optimizer_D.zero_grad()  # set D's gradients to zero
+         self.backward_D()             # calculate gradients for D
+         self.optimizer_D.step()       # update D's weights
+
+         # update G
+         self.set_requires_grad(self.netD1, False)   # D1 requires no gradients when optimizing G1
+         self.set_requires_grad(self.netD2, False)   # D2 requires no gradients when optimizing G2
+         self.set_requires_grad(self.netD3, False)   # D3 requires no gradients when optimizing G3
+         self.set_requires_grad(self.netD4, False)   # D4 requires no gradients when optimizing G4
+         self.set_requires_grad(self.netD51, False)  # D51 requires no gradients when optimizing G51
+         self.set_requires_grad(self.netD52, False)  # D52 requires no gradients when optimizing G52
+         self.set_requires_grad(self.netD53, False)  # D53 requires no gradients when optimizing G53
+         self.set_requires_grad(self.netD54, False)  # D54 requires no gradients when optimizing G54
+         self.set_requires_grad(self.netD55, False)  # D55 requires no gradients when optimizing G55
+
+         self.optimizer_G.zero_grad()  # set G's gradients to zero
+         self.backward_G()             # calculate gradients for G
+         self.optimizer_G.step()       # update G's weights
+
+     def calculate_losses(self):
+         """
+         Calculate losses but do not optimize parameters. Used in validation loss calculation during training.
+         """
+         self.forward()  # compute fake images: G(A)
+         # compute D losses
+         self.set_requires_grad(self.netD1, True)   # enable backprop for D1
+         self.set_requires_grad(self.netD2, True)   # enable backprop for D2
+         self.set_requires_grad(self.netD3, True)   # enable backprop for D3
+         self.set_requires_grad(self.netD4, True)   # enable backprop for D4
+         self.set_requires_grad(self.netD51, True)  # enable backprop for D51
+         self.set_requires_grad(self.netD52, True)  # enable backprop for D52
+         self.set_requires_grad(self.netD53, True)  # enable backprop for D53
+         self.set_requires_grad(self.netD54, True)  # enable backprop for D54
+         self.set_requires_grad(self.netD55, True)  # enable backprop for D55
+
+         self.optimizer_D.zero_grad()  # set D's gradients to zero
+         self.backward_D()             # calculate gradients for D
+
+         # compute G losses
+         self.set_requires_grad(self.netD1, False)   # D1 requires no gradients when optimizing G1
+         self.set_requires_grad(self.netD2, False)   # D2 requires no gradients when optimizing G2
+         self.set_requires_grad(self.netD3, False)   # D3 requires no gradients when optimizing G3
+         self.set_requires_grad(self.netD4, False)   # D4 requires no gradients when optimizing G4
+         self.set_requires_grad(self.netD51, False)  # D51 requires no gradients when optimizing G51
+         self.set_requires_grad(self.netD52, False)  # D52 requires no gradients when optimizing G52
+         self.set_requires_grad(self.netD53, False)  # D53 requires no gradients when optimizing G53
+         self.set_requires_grad(self.netD54, False)  # D54 requires no gradients when optimizing G54
+         self.set_requires_grad(self.netD55, False)  # D55 requires no gradients when optimizing G55
+
+         self.optimizer_G.zero_grad()  # set G's gradients to zero
+         self.backward_G()             # calculate gradients for G
+
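
The distillation term in backward_G flattens each output image, applies log-softmax to the student prediction and softmax to the teacher prediction over the concatenated RGB pixel vector, and compares them with batch-mean KL divergence. A self-contained sketch of that computation with dummy tensors:

    import torch

    kldiv = torch.nn.KLDivLoss(reduction='batchmean')
    logsoftmax = torch.nn.LogSoftmax(dim=-1)
    softmax = torch.nn.Softmax(dim=-1)

    fake_student = torch.randn(1, 3, 512, 512)  # stand-in for a student generator output
    fake_teacher = torch.randn(1, 3, 512, 512)  # stand-in for the matching teacher output

    # .view(1, 1, -1) concatenates all RGB pixel values into one vector,
    # so the (log-)softmax normalizes across the whole image at once
    loss_kd = kldiv(logsoftmax(fake_student.view(1, 1, -1)),
                    softmax(fake_teacher.view(1, 1, -1)))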
deepliif/models/__init__ - different weighted.py (renamed from deepliif/models/__init__.py; the +33 lines below match this entry in the file list)
@@ -273,6 +273,27 @@ def run_dask(img, model_path, eager_mode=False, opt=None, seg_only=False):
              'G54': 0.00,  # Lap2
              'G55': 0.25,  # Marker
          }
+         weights1 = {
+             'G51': 0.5,   # IHC
+             'G52': 0.0,   # Hema
+             'G53': 0.0,   # DAPI
+             'G54': 0.0,   # Lap2
+             'G55': 0.5,   # Marker
+         }
+         weights2 = {
+             'G51': 0.34,  # IHC
+             'G52': 0.00,  # Hema
+             'G53': 0.33,  # DAPI
+             'G54': 0.00,  # Lap2
+             'G55': 0.33,  # Marker
+         }
+         weights3 = {
+             'G51': 0.34,  # IHC
+             'G52': 0.33,  # Hema
+             'G53': 0.00,  # DAPI
+             'G54': 0.00,  # Lap2
+             'G55': 0.33,  # Marker
+         }

          seg_map = {'G1': 'G52', 'G2': 'G53', 'G3': 'G54', 'G4': 'G55'}
@@ -289,6 +310,9 @@ def run_dask(img, model_path, eager_mode=False, opt=None, seg_only=False):
          segs = compute(lazy_segs)[0]

          seg = torch.stack([torch.mul(segs[k], weights[k]) for k in segs.keys()]).sum(dim=0)
+         seg1 = torch.stack([torch.mul(segs[k], weights1[k]) for k in segs.keys()]).sum(dim=0)
+         seg2 = torch.stack([torch.mul(segs[k], weights2[k]) for k in segs.keys()]).sum(dim=0)
+         seg3 = torch.stack([torch.mul(segs[k], weights3[k]) for k in segs.keys()]).sum(dim=0)

          if seg_only:
              res = {'G4': tensor_to_pil(gens['G4'])} if 'G4' in gens else {}
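
Each of these stacks is a convex combination of the five per-input segmentation tensors; for example, weights1 blends only the IHC- and Marker-derived masks, 50/50. A small worked sketch with dummy tensors:

    import torch

    segs = {k: torch.rand(3, 512, 512) for k in ['G51', 'G52', 'G53', 'G54', 'G55']}
    weights1 = {'G51': 0.5, 'G52': 0.0, 'G53': 0.0, 'G54': 0.0, 'G55': 0.5}

    # same pattern as run_dask above: scale each mask by its weight, then sum
    seg1 = torch.stack([torch.mul(segs[k], weights1[k]) for k in segs.keys()]).sum(dim=0)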
@@ -296,6 +320,9 @@ def run_dask(img, model_path, eager_mode=False, opt=None, seg_only=False):
              res = {k: tensor_to_pil(v) for k, v in gens.items()}
              res.update({k: tensor_to_pil(v) for k, v in segs.items()})
              res['G5'] = tensor_to_pil(seg)
+             res['G5a'] = tensor_to_pil(seg1)
+             res['G5b'] = tensor_to_pil(seg2)
+             res['G5c'] = tensor_to_pil(seg3)

          return res
      elif opt.model in ['DeepLIIFExt','SDG','CycleGAN']:
@@ -348,6 +375,9 @@ def run_wrapper(tile, run_fn, model_path, eager_mode=False, opt=None, seg_only=F
          'G3': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
          'G4': Image.new(mode='RGB', size=(512, 512), color=(10, 10, 10)),
          'G5': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
+         'G5a': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
+         'G5b': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
+         'G5c': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
          'G51': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
          'G52': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
          'G53': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
@@ -409,6 +439,9 @@ def inference(img, tile_size, overlap_size, model_path, use_torchserve=False,
          'Lap2': results['G3'],
          'Marker': results['G4'],
          'Seg': results['G5'],
+         'Seg1': results['G5a'],
+         'Seg2': results['G5b'],
+         'Seg3': results['G5c'],
      }

      if return_seg_intermediate and not seg_only:
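
Callers of inference therefore receive three additional ensemble masks next to the default 'Seg'. A hedged usage sketch (the full inference signature is truncated in the hunk header above, so argument values beyond those shown are illustrative):

    from PIL import Image
    from deepliif.models import inference

    img = Image.open('sample_ihc.png')  # hypothetical input tile
    results = inference(img, tile_size=512, overlap_size=64,
                        model_path='./checkpoints/DeepLIIF_Latest_Model')
    seg = results['Seg']                                          # default 0.25/0.15/0.25/0.10/0.25 blend
    seg_alts = results['Seg1'], results['Seg2'], results['Seg3']  # the new alternative blends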