deepliif 1.1.13__tar.gz → 1.1.14__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Files changed (54)
  1. {deepliif-1.1.13/deepliif.egg-info → deepliif-1.1.14}/PKG-INFO +2 -2
  2. {deepliif-1.1.13 → deepliif-1.1.14}/README.md +1 -1
  3. {deepliif-1.1.13 → deepliif-1.1.14}/cli.py +15 -22
  4. {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/data/aligned_dataset.py +2 -2
  5. deepliif-1.1.14/deepliif/models/DeepLIIFKD_model.py +409 -0
  6. deepliif-1.1.13/deepliif/models/__init__.py → deepliif-1.1.14/deepliif/models/__init__ - weights, empty, zarr, tile count.py +37 -5
  7. deepliif-1.1.14/deepliif/models/__init__.py +817 -0
  8. {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/models/base_model.py +1 -1
  9. {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/models/networks.py +7 -5
  10. {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/postprocessing.py +55 -24
  11. deepliif-1.1.14/deepliif/util/checks.py +17 -0
  12. {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/util/util.py +42 -0
  13. {deepliif-1.1.13 → deepliif-1.1.14/deepliif.egg-info}/PKG-INFO +2 -2
  14. {deepliif-1.1.13 → deepliif-1.1.14}/deepliif.egg-info/SOURCES.txt +3 -0
  15. {deepliif-1.1.13 → deepliif-1.1.14}/setup.cfg +1 -1
  16. {deepliif-1.1.13 → deepliif-1.1.14}/setup.py +1 -1
  17. {deepliif-1.1.13 → deepliif-1.1.14}/LICENSE.md +0 -0
  18. {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/__init__.py +0 -0
  19. {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/data/__init__.py +0 -0
  20. {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/data/base_dataset.py +0 -0
  21. {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/data/colorization_dataset.py +0 -0
  22. {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/data/image_folder.py +0 -0
  23. {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/data/single_dataset.py +0 -0
  24. {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/data/template_dataset.py +0 -0
  25. {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/data/unaligned_dataset.py +0 -0
  26. {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/models/CycleGAN_model.py +0 -0
  27. {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/models/DeepLIIFExt_model.py +0 -0
  28. {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/models/DeepLIIF_model.py +0 -0
  29. {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/models/SDG_model.py +0 -0
  30. {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/models/__init__ - different weighted.py +0 -0
  31. {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/models/__init__ - run_dask_multi dev.py +0 -0
  32. {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/models/__init__ - time gens.py +0 -0
  33. {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/models/__init__ - timings.py +0 -0
  34. {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/models/att_unet.py +0 -0
  35. {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/options/__init__.py +0 -0
  36. {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/options/base_options.py +0 -0
  37. {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/options/processing_options.py +0 -0
  38. {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/options/test_options.py +0 -0
  39. {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/options/train_options.py +0 -0
  40. {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/postprocessing__OLD__DELETE.py +0 -0
  41. {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/util/__init__.py +0 -0
  42. {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/util/get_data.py +0 -0
  43. {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/util/html.py +0 -0
  44. {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/util/image_pool.py +0 -0
  45. {deepliif-1.1.13 → deepliif-1.1.14}/deepliif/util/visualizer.py +0 -0
  46. {deepliif-1.1.13 → deepliif-1.1.14}/deepliif.egg-info/dependency_links.txt +0 -0
  47. {deepliif-1.1.13 → deepliif-1.1.14}/deepliif.egg-info/entry_points.txt +0 -0
  48. {deepliif-1.1.13 → deepliif-1.1.14}/deepliif.egg-info/requires.txt +0 -0
  49. {deepliif-1.1.13 → deepliif-1.1.14}/deepliif.egg-info/top_level.txt +0 -0
  50. {deepliif-1.1.13 → deepliif-1.1.14}/tests/test_args.py +0 -0
  51. {deepliif-1.1.13 → deepliif-1.1.14}/tests/test_cli_inference.py +0 -0
  52. {deepliif-1.1.13 → deepliif-1.1.14}/tests/test_cli_serialize.py +0 -0
  53. {deepliif-1.1.13 → deepliif-1.1.14}/tests/test_cli_train.py +0 -0
  54. {deepliif-1.1.13 → deepliif-1.1.14}/tests/test_cli_trainlaunch.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.1
2
2
  Name: deepliif
3
- Version: 1.1.13
3
+ Version: 1.1.14
4
4
  Summary: DeepLIIF: Deep-Learning Inferred Multiplex Immunofluorescence for Immunohistochemical Image Quantification
5
5
  Home-page: https://github.com/nadeemlab/DeepLIIF
6
6
  Author: Parmida93
@@ -53,7 +53,7 @@ segmentation.*
53
53
 
54
54
  © This code is made available for non-commercial academic purposes.
55
55
 
56
- ![Version](https://img.shields.io/static/v1?label=latest&message=v1.1.13&color=darkgreen)
56
+ ![Version](https://img.shields.io/static/v1?label=latest&message=v1.1.14&color=darkgreen)
57
57
  [![Total Downloads](https://static.pepy.tech/personalized-badge/deepliif?period=total&units=international_system&left_color=grey&right_color=blue&left_text=total%20downloads)](https://pepy.tech/project/deepliif?&left_text=totalusers)
58
58
 
59
59
  ![overview_image](./images/overview.png)*Overview of DeepLIIF pipeline and sample input IHCs (different
@@ -42,7 +42,7 @@ segmentation.*
42
42
 
43
43
  © This code is made available for non-commercial academic purposes.
44
44
 
45
- ![Version](https://img.shields.io/static/v1?label=latest&message=v1.1.13&color=darkgreen)
45
+ ![Version](https://img.shields.io/static/v1?label=latest&message=v1.1.14&color=darkgreen)
46
46
  [![Total Downloads](https://static.pepy.tech/personalized-badge/deepliif?period=total&units=international_system&left_color=grey&right_color=blue&left_text=total%20downloads)](https://pepy.tech/project/deepliif?&left_text=totalusers)
47
47
 
48
48
  ![overview_image](./images/overview.png)*Overview of DeepLIIF pipeline and sample input IHCs (different
@@ -14,6 +14,7 @@ from deepliif.data import create_dataset, transform
14
14
  from deepliif.models import init_nets, infer_modalities, infer_results_for_wsi, create_model, postprocess
15
15
  from deepliif.util import allowed_file, Visualizer, get_information, test_diff_original_serialized, disable_batchnorm_tracking_stats
16
16
  from deepliif.util.util import mkdirs
17
+ from deepliif.util.checks import check_weights
17
18
  # from deepliif.util import infer_results_for_wsi
18
19
  from deepliif.options import Options, print_options
19
20
 
@@ -78,6 +79,7 @@ def cli():
78
79
  @click.option('--modalities-no', default=4, type=int, help='number of targets')
79
80
  # model parameters
80
81
  @click.option('--model', default='DeepLIIF', help='name of model class')
82
+ @click.option('--model-dir-teacher', default='', help='the directory of the teacher model, only applicable if model is DeepLIIFKD')
81
83
  @click.option('--seg-weights', default='', type=str, help='weights used to aggregate modality images for the final segmentation image; numbers should add up to 1, and each number corresponds to the modality in order; example: 0.25,0.15,0.25,0.1,0.25')
82
84
  @click.option('--loss-weights-g', default='', type=str, help='weights used to aggregate modality-wise losses for the final loss; numbers should add up to 1, and each number corresponds to the modality in order; example: 0.2,0.2,0.2,0.2,0.2')
83
85
  @click.option('--loss-weights-d', default='', type=str, help='weights used to aggregate modality-wise losses for the final loss; numbers should add up to 1, and each number corresponds to the modality in order; example: 0.2,0.2,0.2,0.2,0.2')
@@ -193,7 +195,8 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd
193
195
  verbose, lambda_l1, is_train, display_freq, display_ncols, display_id, display_server, display_env,
194
196
  display_port, update_html_freq, print_freq, no_html, save_latest_freq, save_epoch_freq, save_by_iter,
195
197
  continue_train, epoch_count, phase, lr_policy, n_epochs, n_epochs_decay, optimizer, beta1, lr_g, lr_d, lr_decay_iters,
196
- remote, remote_transfer_cmd, seed, dataset_mode, padding, model, seg_weights, loss_weights_g, loss_weights_d,
198
+ remote, remote_transfer_cmd, seed, dataset_mode, padding, model, model_dir_teacher,
199
+ seg_weights, loss_weights_g, loss_weights_d,
197
200
  modalities_no, seg_gen, net_ds, net_gs, gan_mode, gan_mode_s, local_rank, with_val, debug, debug_data_size):
198
201
  """General-purpose training script for multi-task image-to-image translation.
199
202
 
@@ -206,8 +209,8 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd
206
209
  plot, and save models.The script supports continue/resume training.
207
210
  Use '--continue_train' to resume your previous training.
208
211
  """
209
- assert model in ['DeepLIIF','DeepLIIFExt','SDG','CycleGAN'], f'model class {model} is not implemented'
210
- if model == 'DeepLIIF':
212
+ assert model in ['DeepLIIF','DeepLIIFExt','SDG','CycleGAN','DeepLIIFKD'], f'model class {model} is not implemented'
213
+ if model in ['DeepLIIF','DeepLIIFKD']:
211
214
  seg_no = 1
212
215
  elif model == 'DeepLIIFExt':
213
216
  if seg_gen:
@@ -221,6 +224,9 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd
221
224
  if model == 'CycleGAN':
222
225
  dataset_mode = "unaligned"
223
226
 
227
+ if model == 'DeepLIIFKD':
228
+ assert len(model_dir_teacher) > 0 and os.path.isdir(model_dir_teacher), f'Teacher model directory {model_dir_teacher} is not valid.'
229
+
224
230
  if optimizer != 'adam':
225
231
  print(f'Optimizer torch.optim.{optimizer} is not tested. Be careful about the parameters of the optimizer.')
226
232
 
@@ -310,7 +316,7 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd
310
316
 
311
317
  net_gs = net_gs.split(',')
312
318
  assert len(net_gs) in [1,seg_no], f'net_gs should contain either 1 architecture for all segmentation generators or the same number of architectures as the number of segmentation generators ({seg_no})'
313
- if len(net_gs) == 1 and model == 'DeepLIIF':
319
+ if len(net_gs) == 1 and model in ['DeepLIIF','DeepLIIFKD']:
314
320
  net_gs = net_gs*(modalities_no + seg_no)
315
321
  elif len(net_gs) == 1:
316
322
  net_gs = net_gs*seg_no
@@ -320,34 +326,21 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd
320
326
 
321
327
  # check seg weights and loss weights
322
328
  if len(d_params['seg_weights']) == 0:
323
- seg_weights = [0.25,0.15,0.25,0.1,0.25] if d_params['model'] == 'DeepLIIF' else [1 / modalities_no] * modalities_no
329
+ seg_weights = [0.25,0.15,0.25,0.1,0.25] if d_params['model'] in ['DeepLIIF','DeepLIIFKD'] else [1 / modalities_no] * modalities_no
324
330
  else:
325
331
  seg_weights = [float(x) for x in seg_weights.split(',')]
326
332
 
327
333
  if len(d_params['loss_weights_g']) == 0:
328
- loss_weights_g = [0.2]*5 if d_params['model'] == 'DeepLIIF' else [1 / modalities_no] * modalities_no
334
+ loss_weights_g = [0.2]*5 if d_params['model'] in ['DeepLIIF','DeepLIIFKD'] else [1 / modalities_no] * modalities_no
329
335
  else:
330
336
  loss_weights_g = [float(x) for x in loss_weights_g.split(',')]
331
337
 
332
338
  if len(d_params['loss_weights_d']) == 0:
333
- loss_weights_d = [0.2]*5 if d_params['model'] == 'DeepLIIF' else [1 / modalities_no] * modalities_no
339
+ loss_weights_d = [0.2]*5 if d_params['model'] in ['DeepLIIF','DeepLIIFKD'] else [1 / modalities_no] * modalities_no
334
340
  else:
335
341
  loss_weights_d = [float(x) for x in loss_weights_d.split(',')]
336
342
 
337
- assert sum(seg_weights) == 1, 'seg weights should add up to 1'
338
- assert sum(loss_weights_g) == 1, 'loss weights g should add up to 1'
339
- assert sum(loss_weights_d) == 1, 'loss weights d should add up to 1'
340
-
341
- if model == 'DeepLIIF':
342
- # +1 because input becomes an additional modality used in generating the final segmentation
343
- assert len(seg_weights) == modalities_no+1, 'seg weights should have the same number of elements as number of modalities to be generated'
344
- assert len(loss_weights_g) == modalities_no+1, 'loss weights g should have the same number of elements as number of modalities to be generated'
345
- assert len(loss_weights_d) == modalities_no+1, 'loss weights d should have the same number of elements as number of modalities to be generated'
346
-
347
- else:
348
- assert len(seg_weights) == modalities_no, 'seg weights should have the same number of elements as number of modalities to be generated'
349
- assert len(loss_weights_g) == modalities_no, 'loss weights g should have the same number of elements as number of modalities to be generated'
350
- assert len(loss_weights_d) == modalities_no, 'loss weights d should have the same number of elements as number of modalities to be generated'
343
+ check_weights(model, modalities_no, seg_weights, loss_weights_g, loss_weights_d)
351
344
 
352
345
  d_params['seg_weights'] = seg_weights
353
346
  d_params['loss_G_weights'] = loss_weights_g
@@ -373,7 +366,7 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd
373
366
  data_val = [batch for batch in dataset_val]
374
367
  click.echo('The number of validation images = %d' % len(dataset_val))
375
368
 
376
- if model in ['DeepLIIF']:
369
+ if model in ['DeepLIIF', 'DeepLIIFKD']:
377
370
  metrics_val = json.load(open(os.path.join(dataset_val.dataset.dir_AB,'metrics.json')))
378
371
 
379
372
  # create a model given model and other options
@@ -50,7 +50,7 @@ class AlignedDataset(BaseDataset):
50
50
  AB = Image.open(AB_path).convert('RGB')
51
51
  # split AB image into A and B
52
52
  w, h = AB.size
53
- if self.model == 'DeepLIIF':
53
+ if self.model in ['DeepLIIF','DeepLIIFKD']:
54
54
  num_img = self.modalities_no + 1 + 1 # +1 for segmentation channel, +1 for input image
55
55
  elif self.model == 'DeepLIIFExt':
56
56
  num_img = self.modalities_no * 2 + 1 if self.seg_gen else self.modalities_no + 1 # +1 for segmentation channel
@@ -68,7 +68,7 @@ class AlignedDataset(BaseDataset):
68
68
 
69
69
  A = A_transform(A)
70
70
  B_Array = []
71
- if self.model == 'DeepLIIF':
71
+ if self.model in ['DeepLIIF','DeepLIIFKD']:
72
72
  for i in range(1, num_img):
73
73
  B = AB.crop((w2 * i, 0, w2 * (i + 1), h))
74
74
  B = B_transform(B)
@@ -0,0 +1,409 @@
1
+ import torch
2
+ from .base_model import BaseModel
3
+ from . import networks
4
+ from .networks import get_optimizer
5
+ from . import init_nets, run_dask, get_opt
6
+ from torch import nn
7
+
8
+ class DeepLIIFKDModel(BaseModel):
9
+ """ This class implements the DeepLIIF model, for learning a mapping from input images to modalities given paired data."""
10
+
11
+ def __init__(self, opt):
12
+ """Initialize the DeepLIIFKD (knowledge-distillation) model: student generators/discriminators plus a teacher model.
13
+
14
+ Parameters:
15
+ opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions. opt.model_dir_teacher is the directory from which the teacher model's options and networks are loaded.
16
+ """
17
+ BaseModel.__init__(self, opt)
18
+ if not hasattr(opt,'net_gs'): # fall back to a single segmentation-generator architecture name
19
+ opt.net_gs = 'unet_512'
20
+
21
+ self.seg_weights = opt.seg_weights # weights that blend the five segmentation outputs (indexed 0..4 in forward/backward)
22
+ self.loss_G_weights = opt.loss_G_weights
23
+ self.loss_D_weights = opt.loss_D_weights # per-output weights combining the discriminator losses (indexed 0..4)
24
+
25
+ if not opt.is_train:
26
+ self.gpu_ids = [] # avoid the models being loaded as DP
27
+ else:
28
+ self.gpu_ids = opt.gpu_ids
29
+
30
+ self.loss_names = []
31
+ self.visual_names = ['real_A']
32
+ # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
33
+ for i in range(1, self.opt.modalities_no + 1 + 1): # one entry per translated modality plus one for the segmentation output
34
+ self.loss_names.extend([f'G_GAN_{i}', f'G_L1_{i}', f'D_real_{i}', f'D_fake_{i}', f'G_KLDiv_{i}', f'G_KLDiv_5_{i}'])
35
+ self.visual_names.extend([f'fake_B_{i}', f'fake_B_5_{i}', f'fake_B_{i}_teacher', f'fake_B_5_{i}_teacher', f'real_B_{i}']) # also tracks the corresponding teacher outputs
36
+
37
+ # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
38
+ # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
39
+ if self.is_train: # G*/D* are the translation generators/discriminators; G5*/D5* belong to the segmentation branch
40
+ self.model_names = []
41
+ for i in range(1, self.opt.modalities_no + 1):
42
+ self.model_names.extend(['G' + str(i), 'D' + str(i)])
43
+
44
+ for i in range(1, self.opt.modalities_no + 1 + 1):
45
+ self.model_names.extend(['G5' + str(i), 'D5' + str(i)])
46
+ else: # during test time, only load G
47
+ self.model_names = []
48
+ for i in range(1, self.opt.modalities_no + 1):
49
+ self.model_names.extend(['G' + str(i)])
50
+
51
+ for i in range(1, self.opt.modalities_no + 1 + 1):
52
+ self.model_names.extend(['G5' + str(i)])
53
+
54
+ # define networks (both generator and discriminator)
55
+ if isinstance(opt.netG, str): # a single architecture name is reused for all four translation generators
56
+ opt.netG = [opt.netG] * 4
57
+ if isinstance(opt.net_gs, str): # a single architecture name is reused for all five segmentation generators
58
+ opt.net_gs = [opt.net_gs]*5
59
+
60
+ # translation generators netG1-netG4 (architectures from opt.netG; these receive opt.padding)
61
+ self.netG1 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG[0], opt.norm,
62
+ not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.padding)
63
+ self.netG2 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG[1], opt.norm,
64
+ not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.padding)
65
+ self.netG3 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG[2], opt.norm,
66
+ not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.padding)
67
+ self.netG4 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.netG[3], opt.norm,
68
+ not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids, opt.padding)
69
+
70
+ # DeepLIIF model currently uses one gs arch because there is only one explicit seg mod output
71
+ self.netG51 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[0], opt.norm, # NOTE(review): unlike netG1-netG4, opt.padding is not passed to netG51-netG55; confirm this is intended
72
+ not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
73
+ self.netG52 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[1], opt.norm,
74
+ not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
75
+ self.netG53 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[2], opt.norm,
76
+ not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
77
+ self.netG54 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[3], opt.norm,
78
+ not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
79
+ self.netG55 = networks.define_G(opt.input_nc, opt.output_nc, opt.ngf, opt.net_gs[4], opt.norm,
80
+ not opt.no_dropout, opt.init_type, opt.init_gain, self.gpu_ids)
81
+
82
+
83
+ if self.is_train: # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc
84
+ self.netD1 = networks.define_D(opt.input_nc+opt.output_nc , opt.ndf, opt.netD,
85
+ opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
86
+ self.netD2 = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
87
+ opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
88
+ self.netD3 = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
89
+ opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
90
+ self.netD4 = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
91
+ opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
92
+
93
+ self.netD51 = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
94
+ opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
95
+ self.netD52 = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
96
+ opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
97
+ self.netD53 = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
98
+ opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
99
+ self.netD54 = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
100
+ opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
101
+ self.netD55 = networks.define_D(opt.input_nc + opt.output_nc, opt.ndf, opt.netD,
102
+ opt.n_layers_D, opt.norm, opt.init_type, opt.init_gain, self.gpu_ids)
103
+
104
+ # load the teacher model: its options are read from its own directory in test mode, and its nets are built eagerly
105
+ self.opt_teacher = get_opt(opt.model_dir_teacher, mode='test')
106
+ self.opt_teacher.gpu_ids = opt.gpu_ids # use student's gpu_ids
107
+ self.nets_teacher = init_nets(opt.model_dir_teacher, eager_mode=True, opt=self.opt_teacher, phase='test') # teacher nets are not added to the optimizers below
108
+
109
+
110
+ if self.is_train:
111
+ # define loss functions
112
+ self.criterionGAN_BCE = networks.GANLoss('vanilla').to(self.device)
113
+ self.criterionGAN_lsgan = networks.GANLoss('lsgan').to(self.device)
114
+ self.criterionSmoothL1 = torch.nn.SmoothL1Loss()
115
+
116
+ # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
117
+ params = list(self.netG1.parameters()) + list(self.netG2.parameters()) + list(self.netG3.parameters()) + list(self.netG4.parameters()) + list(self.netG51.parameters()) + list(self.netG52.parameters()) + list(self.netG53.parameters()) + list(self.netG54.parameters()) + list(self.netG55.parameters())
118
+ try:
119
+ self.optimizer_G = get_optimizer(opt.optimizer)(params, lr=opt.lr_g, betas=(opt.beta1, 0.999))
120
+ except: # NOTE(review): bare except also hides unrelated errors (e.g. an invalid lr); `except TypeError` would cover only the unsupported-`betas` case
121
+ print(f'betas are not used for optimizer torch.optim.{opt.optimizer} in generators')
122
+ self.optimizer_G = get_optimizer(opt.optimizer)(params, lr=opt.lr_g)
123
+
124
+ params = list(self.netD1.parameters()) + list(self.netD2.parameters()) + list(self.netD3.parameters()) + list(self.netD4.parameters()) + list(self.netD51.parameters()) + list(self.netD52.parameters()) + list(self.netD53.parameters()) + list(self.netD54.parameters()) + list(self.netD55.parameters())
125
+ try:
126
+ self.optimizer_D = get_optimizer(opt.optimizer)(params, lr=opt.lr_d, betas=(opt.beta1, 0.999))
127
+ except: # NOTE(review): bare except also hides unrelated errors; `except TypeError` would cover only the unsupported-`betas` case
128
+ print(f'betas are not used for optimizer torch.optim.{opt.optimizer} in discriminators')
129
+ self.optimizer_D = get_optimizer(opt.optimizer)(params, lr=opt.lr_d)
130
+
131
+ self.optimizers.append(self.optimizer_G)
132
+ self.optimizers.append(self.optimizer_D)
133
+
134
+ self.criterionVGG = networks.VGGLoss().to(self.device)
135
+ self.criterionKLDiv = torch.nn.KLDivLoss(reduction='batchmean').to(self.device) # distillation loss between student and teacher outputs
136
+ self.softmax = torch.nn.Softmax(dim=-1).to(self.device) # apply softmax on the last dim
137
+ self.logsoftmax = torch.nn.LogSoftmax(dim=-1).to(self.device) # apply log-softmax on the last dim
138
+
139
def set_input(self, input):
    """Move one dataloader batch onto ``self.device`` and cache it on the model.

    :param input (dict): batch with key ``'A'`` (input image tensor), ``'B'``
        (sequence of target tensors, of which the first five are moved to the
        device as ``real_B_1`` .. ``real_B_5``) and ``'A_paths'`` (source paths).
    """
    device = self.device
    self.real_A = input['A'].to(device)

    targets = input['B']
    self.real_B_array = targets
    # targets[0] .. targets[4] become real_B_1 .. real_B_5
    for idx in range(5):
        setattr(self, f'real_B_{idx + 1}', targets[idx].to(device))

    self.image_paths = input['A_paths']
155
+
156
+ def forward(self):
157
+ """Run forward pass; called by both functions <optimize_parameters> and <test>."""
158
+ self.fake_B_1 = self.netG1(self.real_A) # Hematoxylin image generator
159
+ self.fake_B_2 = self.netG2(self.real_A) # mpIF DAPI image generator
160
+ self.fake_B_3 = self.netG3(self.real_A) # mpIF Lap2 image generator
161
+ self.fake_B_4 = self.netG4(self.real_A) # mpIF Ki67 image generator
162
+
163
+ self.fake_B_5_1 = self.netG51(self.real_A) # Segmentation mask generator from IHC input image
164
+ self.fake_B_5_2 = self.netG52(self.fake_B_1) # Segmentation mask generator from Hematoxylin input image
165
+ self.fake_B_5_3 = self.netG53(self.fake_B_2) # Segmentation mask generator from mpIF DAPI input image
166
+ self.fake_B_5_4 = self.netG54(self.fake_B_3) # Segmentation mask generator from mpIF Lap2 input image
167
+ self.fake_B_5_5 = self.netG55(self.fake_B_4) # Segmentation mask generator from mpIF Lap2 input image
168
+ self.fake_B_5 = torch.stack([torch.mul(self.fake_B_5_1, self.seg_weights[0]),
169
+ torch.mul(self.fake_B_5_2, self.seg_weights[1]),
170
+ torch.mul(self.fake_B_5_3, self.seg_weights[2]),
171
+ torch.mul(self.fake_B_5_4, self.seg_weights[3]),
172
+ torch.mul(self.fake_B_5_5, self.seg_weights[4])]).sum(dim=0)
173
+
174
+ fakes_teacher = run_dask(img=self.real_A, nets=self.nets_teacher, opt=self.opt_teacher, use_dask=False, output_tensor=True)
175
+ for k,v in fakes_teacher.items():
176
+ suffix = k[1:] # starts with G
177
+ suffix = '_'.join(list(suffix)) # 51 -> 5_1
178
+ setattr(self,f'fake_B_{suffix}_teacher',v)
179
+
180
def backward_D(self):
    """Compute the discriminator losses for the current batch and backpropagate them.

    netD1-netD4 score (input image, modality) pairs with ``criterionGAN_BCE``.
    netD51-netD55 score (conditioning image, segmentation) pairs, where the
    conditioning images are the input image followed by real_B_1..real_B_4. Their
    predictions are blended with ``self.seg_weights`` and scored with
    ``criterionGAN_lsgan``. Fake pairs are detached, so no gradient reaches the
    generators.

    Sets ``loss_D_fake_1..5``, ``loss_D_real_1..5`` and ``loss_D`` (each fake/real
    pair is averaged, then weighted by ``self.loss_D_weights``), and calls
    ``loss_D.backward()``.
    """
    translation_nets = (self.netD1, self.netD2, self.netD3, self.netD4)
    segmentation_nets = (self.netD51, self.netD52, self.netD53, self.netD54, self.netD55)
    # image each segmentation discriminator is conditioned on (IHC input, then real Hematoxylin, DAPI, Lap2, Ki67)
    seg_conditions = (self.real_A, self.real_B_1, self.real_B_2, self.real_B_3, self.real_B_4)
    generated = (self.fake_B_1, self.fake_B_2, self.fake_B_3, self.fake_B_4)
    targets = (self.real_B_1, self.real_B_2, self.real_B_3, self.real_B_4)

    def blend_seg(preds):
        # weighted sum of the five segmentation-discriminator predictions
        return torch.stack([torch.mul(p, self.seg_weights[k]) for k, p in enumerate(preds)]).sum(dim=0)

    # fake pairs: translation discriminators first, then the segmentation ones
    fake_preds = [net(torch.cat((self.real_A, img), 1).detach())
                  for net, img in zip(translation_nets, generated)]
    fake_seg_pred = blend_seg([net(torch.cat((cond, self.fake_B_5), 1).detach())
                               for net, cond in zip(segmentation_nets, seg_conditions)])

    for k, pred in enumerate(fake_preds, start=1):
        setattr(self, f'loss_D_fake_{k}', self.criterionGAN_BCE(pred, False))
    self.loss_D_fake_5 = self.criterionGAN_lsgan(fake_seg_pred, False)

    # real pairs, same discriminator order
    real_preds = [net(torch.cat((self.real_A, img), 1))
                  for net, img in zip(translation_nets, targets)]
    real_seg_pred = blend_seg([net(torch.cat((cond, self.real_B_5), 1))
                               for net, cond in zip(segmentation_nets, seg_conditions)])

    for k, pred in enumerate(real_preds, start=1):
        setattr(self, f'loss_D_real_{k}', self.criterionGAN_BCE(pred, True))
    self.loss_D_real_5 = self.criterionGAN_lsgan(real_seg_pred, True)

    # average each fake/real pair, weight it, and add the terms left to right
    weighted_terms = [
        (getattr(self, f'loss_D_fake_{k}') + getattr(self, f'loss_D_real_{k}')) * 0.5 * self.loss_D_weights[k - 1]
        for k in range(1, 6)
    ]
    self.loss_D = sum(weighted_terms[1:], weighted_terms[0])

    self.loss_D.backward()
261
+
262
+ def backward_G(self):
263
+ """Calculate GAN and L1 loss for the generator"""
264
+
265
+ fake_AB_1 = torch.cat((self.real_A, self.fake_B_1), 1)
266
+ fake_AB_2 = torch.cat((self.real_A, self.fake_B_2), 1)
267
+ fake_AB_3 = torch.cat((self.real_A, self.fake_B_3), 1)
268
+ fake_AB_4 = torch.cat((self.real_A, self.fake_B_4), 1)
269
+
270
+ fake_AB_5_1 = torch.cat((self.real_A, self.fake_B_5), 1)
271
+ fake_AB_5_2 = torch.cat((self.real_B_1, self.fake_B_5), 1)
272
+ fake_AB_5_3 = torch.cat((self.real_B_2, self.fake_B_5), 1)
273
+ fake_AB_5_4 = torch.cat((self.real_B_3, self.fake_B_5), 1)
274
+ fake_AB_5_5 = torch.cat((self.real_B_4, self.fake_B_5), 1)
275
+
276
+ pred_fake_1 = self.netD1(fake_AB_1)
277
+ pred_fake_2 = self.netD2(fake_AB_2)
278
+ pred_fake_3 = self.netD3(fake_AB_3)
279
+ pred_fake_4 = self.netD4(fake_AB_4)
280
+
281
+ pred_fake_5_1 = self.netD51(fake_AB_5_1)
282
+ pred_fake_5_2 = self.netD52(fake_AB_5_2)
283
+ pred_fake_5_3 = self.netD53(fake_AB_5_3)
284
+ pred_fake_5_4 = self.netD54(fake_AB_5_4)
285
+ pred_fake_5_5 = self.netD55(fake_AB_5_5)
286
+ pred_fake_5 = torch.stack(
287
+ [torch.mul(pred_fake_5_1, self.seg_weights[0]),
288
+ torch.mul(pred_fake_5_2, self.seg_weights[1]),
289
+ torch.mul(pred_fake_5_3, self.seg_weights[2]),
290
+ torch.mul(pred_fake_5_4, self.seg_weights[3]),
291
+ torch.mul(pred_fake_5_5, self.seg_weights[4])]).sum(dim=0)
292
+
293
+ self.loss_G_GAN_1 = self.criterionGAN_BCE(pred_fake_1, True)
294
+ self.loss_G_GAN_2 = self.criterionGAN_BCE(pred_fake_2, True)
295
+ self.loss_G_GAN_3 = self.criterionGAN_BCE(pred_fake_3, True)
296
+ self.loss_G_GAN_4 = self.criterionGAN_BCE(pred_fake_4, True)
297
+ self.loss_G_GAN_5 = self.criterionGAN_lsgan(pred_fake_5, True)
298
+
299
+ # Second, G(A) = B
300
+ self.loss_G_L1_1 = self.criterionSmoothL1(self.fake_B_1, self.real_B_1) * self.opt.lambda_L1
301
+ self.loss_G_L1_2 = self.criterionSmoothL1(self.fake_B_2, self.real_B_2) * self.opt.lambda_L1
302
+ self.loss_G_L1_3 = self.criterionSmoothL1(self.fake_B_3, self.real_B_3) * self.opt.lambda_L1
303
+ self.loss_G_L1_4 = self.criterionSmoothL1(self.fake_B_4, self.real_B_4) * self.opt.lambda_L1
304
+ self.loss_G_L1_5 = self.criterionSmoothL1(self.fake_B_5, self.real_B_5) * self.opt.lambda_L1
305
+
306
+ self.loss_G_VGG_1 = self.criterionVGG(self.fake_B_1, self.real_B_1) * self.opt.lambda_feat
307
+ self.loss_G_VGG_2 = self.criterionVGG(self.fake_B_2, self.real_B_2) * self.opt.lambda_feat
308
+ self.loss_G_VGG_3 = self.criterionVGG(self.fake_B_3, self.real_B_3) * self.opt.lambda_feat
309
+ self.loss_G_VGG_4 = self.criterionVGG(self.fake_B_4, self.real_B_4) * self.opt.lambda_feat
310
+
311
+
312
+ # .view(1,1,-1) reshapes the input (batch_size, 3, 512, 512) to (batch_size, 1, 3*512*512)
313
+ # softmax/log-softmax is then applied on the concatenated vector of size (1, 3*512*512)
314
+ # this normalizes the pixel values across all 3 RGB channels
315
+ # the resulting vectors are then used to compute KL divergence loss
316
+ self.loss_G_KLDiv_1 = self.criterionKLDiv(self.logsoftmax(self.fake_B_1.view(1,1,-1)), self.softmax(self.fake_B_1_teacher.view(1,1,-1)))
317
+ self.loss_G_KLDiv_2 = self.criterionKLDiv(self.logsoftmax(self.fake_B_2.view(1,1,-1)), self.softmax(self.fake_B_2_teacher.view(1,1,-1)))
318
+ self.loss_G_KLDiv_3 = self.criterionKLDiv(self.logsoftmax(self.fake_B_3.view(1,1,-1)), self.softmax(self.fake_B_3_teacher.view(1,1,-1)))
319
+ self.loss_G_KLDiv_4 = self.criterionKLDiv(self.logsoftmax(self.fake_B_4.view(1,1,-1)), self.softmax(self.fake_B_4_teacher.view(1,1,-1)))
320
+ self.loss_G_KLDiv_5 = self.criterionKLDiv(self.logsoftmax(self.fake_B_5.view(1,1,-1)), self.softmax(self.fake_B_5_teacher.view(1,1,-1)))
321
+ self.loss_G_KLDiv_5_1 = self.criterionKLDiv(self.logsoftmax(self.fake_B_5_1.view(1,1,-1)), self.softmax(self.fake_B_5_1_teacher.view(1,1,-1)))
322
+ self.loss_G_KLDiv_5_2 = self.criterionKLDiv(self.logsoftmax(self.fake_B_5_2.view(1,1,-1)), self.softmax(self.fake_B_5_2_teacher.view(1,1,-1)))
323
+ self.loss_G_KLDiv_5_3 = self.criterionKLDiv(self.logsoftmax(self.fake_B_5_3.view(1,1,-1)), self.softmax(self.fake_B_5_3_teacher.view(1,1,-1)))
324
+ self.loss_G_KLDiv_5_4 = self.criterionKLDiv(self.logsoftmax(self.fake_B_5_4.view(1,1,-1)), self.softmax(self.fake_B_5_4_teacher.view(1,1,-1)))
325
+ self.loss_G_KLDiv_5_5 = self.criterionKLDiv(self.logsoftmax(self.fake_B_5_5.view(1,1,-1)), self.softmax(self.fake_B_5_5_teacher.view(1,1,-1)))
326
+
327
+ self.loss_G = (self.loss_G_GAN_1 + self.loss_G_L1_1 + self.loss_G_VGG_1) * self.loss_G_weights[0] + \
328
+ (self.loss_G_GAN_2 + self.loss_G_L1_2 + self.loss_G_VGG_2) * self.loss_G_weights[1] + \
329
+ (self.loss_G_GAN_3 + self.loss_G_L1_3 + self.loss_G_VGG_3) * self.loss_G_weights[2] + \
330
+ (self.loss_G_GAN_4 + self.loss_G_L1_4 + self.loss_G_VGG_4) * self.loss_G_weights[3] + \
331
+ (self.loss_G_GAN_5 + self.loss_G_L1_5) * self.loss_G_weights[4] + \
332
+ (self.loss_G_KLDiv_1 + self.loss_G_KLDiv_2 + self.loss_G_KLDiv_3 + self.loss_G_KLDiv_4 + \
333
+ self.loss_G_KLDiv_5 + self.loss_G_KLDiv_5_1 + self.loss_G_KLDiv_5_2 + self.loss_G_KLDiv_5_3 + \
334
+ self.loss_G_KLDiv_5_4 + self.loss_G_KLDiv_5_5) * 10
335
+
336
+ # combine loss and calculate gradients
337
+ # self.loss_G = (self.loss_G_GAN_1 + self.loss_G_L1_1) * self.loss_G_weights[0] + \
338
+ # (self.loss_G_GAN_2 + self.loss_G_L1_2) * self.loss_G_weights[1] + \
339
+ # (self.loss_G_GAN_3 + self.loss_G_L1_3) * self.loss_G_weights[2] + \
340
+ # (self.loss_G_GAN_4 + self.loss_G_L1_4) * self.loss_G_weights[3] + \
341
+ # (self.loss_G_GAN_5 + self.loss_G_L1_5) * self.loss_G_weights[4]
342
+ self.loss_G.backward()
343
+
344
def optimize_parameters(self):
    """Run one training iteration: a discriminator update followed by a generator update.

    Steps, in order:
      1. ``self.forward()`` produces the fake images G(A).
      2. Gradients are enabled on all nine discriminators (netD1-netD4 and
         netD51-netD55). ``optimizer_D`` is zeroed, ``backward_D()`` computes the
         D gradients, and ``optimizer_D.step()`` applies them.
      3. Gradients are disabled on the same discriminators, so they receive no
         gradients while the generators are optimized. ``optimizer_G`` is zeroed,
         ``backward_G()`` computes the G gradients, and ``optimizer_G.step()``
         applies them.
    """
    self.forward()  # compute fake images: G(A)

    # Same nine discriminators, same order, for both phases.
    discriminators = (
        self.netD1, self.netD2, self.netD3, self.netD4,
        self.netD51, self.netD52, self.netD53, self.netD54, self.netD55,
    )

    # ---- update D ----
    for net_d in discriminators:
        self.set_requires_grad(net_d, True)  # enable backprop for this discriminator
    self.optimizer_D.zero_grad()  # clear stale D gradients
    self.backward_D()             # calculate gradients for D
    self.optimizer_D.step()       # update D's weights

    # ---- update G ----
    for net_d in discriminators:
        self.set_requires_grad(net_d, False)  # D needs no gradients while optimizing G
    self.optimizer_G.zero_grad()  # clear stale G gradients
    self.backward_G()             # calculate gradients for G
    self.optimizer_G.step()       # update G's weights
375
+
376
def calculate_losses(self):
    """Compute discriminator and generator losses without updating any weights.

    Used for validation-loss reporting during training. The call sequence
    mirrors ``optimize_parameters`` but omits the ``optimizer_D.step()`` and
    ``optimizer_G.step()`` calls, so parameters are left unchanged.

    Note: ``backward_D()`` and ``backward_G()`` are still invoked.
    ``backward_G`` ends with ``self.loss_G.backward()``, so generator gradients
    are computed and left in ``.grad`` even though they are never applied.
    ``backward_D`` presumably does the same for the discriminators (TODO confirm).
    On return the discriminators have gradients disabled, exactly as after
    ``optimize_parameters``.
    """
    self.forward()  # compute fake images: G(A)

    discriminators = [
        self.netD1, self.netD2, self.netD3, self.netD4,
        self.netD51, self.netD52, self.netD53, self.netD54, self.netD55,
    ]

    # Discriminator losses: enable D gradients, then evaluate (no optimizer step).
    for disc in discriminators:
        self.set_requires_grad(disc, True)
    self.optimizer_D.zero_grad()
    self.backward_D()

    # Generator losses: freeze D so backward_G produces only G gradients (no step).
    for disc in discriminators:
        self.set_requires_grad(disc, False)
    self.optimizer_G.zero_grad()
    self.backward_G()
409
+
@@ -265,12 +265,19 @@ def run_dask(img, model_path, eager_mode=False, opt=None, seg_only=False):
265
265
  return model(input.to(next(model.parameters()).device))
266
266
 
267
267
  if opt.model == 'DeepLIIF':
268
+ #weights = {
269
+ # 'G51': 0.25, # IHC
270
+ # 'G52': 0.25, # Hema
271
+ # 'G53': 0.25, # DAPI
272
+ # 'G54': 0.00, # Lap2
273
+ # 'G55': 0.25, # Marker
274
+ #}
268
275
  weights = {
269
- 'G51': 0.25, # IHC
270
- 'G52': 0.25, # Hema
271
- 'G53': 0.25, # DAPI
272
- 'G54': 0.00, # Lap2
273
- 'G55': 0.25, # Marker
276
+ 'G51': 0.5, # IHC
277
+ 'G52': 0.0, # Hema
278
+ 'G53': 0.0, # DAPI
279
+ 'G54': 0.0, # Lap2
280
+ 'G55': 0.5, # Marker
274
281
  }
275
282
 
276
283
  seg_map = {'G1': 'G52', 'G2': 'G53', 'G3': 'G54', 'G4': 'G55'}
@@ -330,8 +337,13 @@ def is_empty(tile):
330
337
  return all([True if np.max(image_variance_rgb(t)) < thresh else False for t in tile])
331
338
  else:
332
339
  return True if np.max(image_variance_rgb(tile)) < thresh else False
340
+ #return True if image_variance_gray(tile) < thresh else False
341
+ #return True if image_variance_gray(tile.resize((128, 128), resample=Image.NEAREST)) < thresh else False
342
+ #return True if image_variance_gray(tile.convert('L').resize((128, 128), resample=Image.NEAREST)) < thresh else False
343
+ #return True if image_variance_gray(tile.convert('L').resize((64, 64), resample=Image.NEAREST)) < thresh else False
333
344
 
334
345
 
346
# Debug/profiling counter. run_wrapper increments it each time a DeepLIIF tile
# is judged non-empty and is actually passed to run_fn. infer_cells_for_wsi
# prints it as 'Num tiles run:'. Nothing in the visible code resets it, so it
# presumably accumulates across calls within one process.
# NOTE(review): `+=` on a module global is not atomic. If tiles are processed
# concurrently (e.g. threaded dask), the count may be approximate; confirm.
count_run_tiles = 0
335
347
  def run_wrapper(tile, run_fn, model_path, eager_mode=False, opt=None, seg_only=False):
336
348
  if opt.model == 'DeepLIIF':
337
349
  if is_empty(tile):
@@ -354,6 +366,8 @@ def run_wrapper(tile, run_fn, model_path, eager_mode=False, opt=None, seg_only=F
354
366
  'G55': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
355
367
  }
356
368
  else:
369
+ global count_run_tiles
370
+ count_run_tiles += 1
357
371
  return run_fn(tile, model_path, eager_mode, opt, seg_only)
358
372
  elif opt.model in ['DeepLIIFExt', 'SDG']:
359
373
  if is_empty(tile):
@@ -648,6 +662,10 @@ def get_wsi_resolution(filename):
648
662
  return None, None
649
663
 
650
664
 
665
+ from tifffile import TiffFile
666
+ import zarr
667
+ import time
668
+
651
669
  def infer_cells_for_wsi(filename, model_dir, tile_size, region_size=20000, version=3, print_log=False):
652
670
  """
653
671
  Perform inference on a slide and get the results individual cell data.
@@ -702,11 +720,23 @@ def infer_cells_for_wsi(filename, model_dir, tile_size, region_size=20000, versi
702
720
  region_XYWH = (start_x, start_y, min(stride_x, size_x-start_x), min(stride_y, size_y-start_y))
703
721
  print_info('Region:', region_XYWH)
704
722
 
723
+ tstart_read = time.time()
724
+
705
725
  region = reader.read(XYWH=region_XYWH, rescale=rescale)
706
726
  print_info(region.shape, region.dtype)
707
727
  img = Image.fromarray((region * 255).astype(np.uint8)) if rescale else Image.fromarray(region)
708
728
  print_info(img.size, img.mode)
709
729
 
730
+ #x, y, w, h = region_XYWH
731
+ #print(region_XYWH, x, y, w, h, flush=True)
732
+ #tif = TiffFile(filename)
733
+ #z = zarr.open(tif.pages[0].aszarr(), mode='r')
734
+ #img = Image.fromarray(z[y:y+h, x:x+w])
735
+ #print_info(img.size, img.mode)
736
+
737
+ tend_read = time.time()
738
+ print('Time to read region:', round(tend_read-tstart_read, 1), 'sec.', flush=True)
739
+
710
740
  images = inference(
711
741
  img,
712
742
  tile_size=tile_size,
@@ -749,6 +779,8 @@ def infer_cells_for_wsi(filename, model_dir, tile_size, region_size=20000, versi
749
779
  start_y += stride_y
750
780
 
751
781
  javabridge.kill_vm()
782
+ global count_run_tiles
783
+ print('Num tiles run:', count_run_tiles, flush=True)
752
784
 
753
785
  if count_marker_thresh == 0:
754
786
  count_marker_thresh = 1