deepliif 1.2.2__py3-none-any.whl → 1.2.4__py3-none-any.whl

This diff shows the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their public registries.
cli.py CHANGED
@@ -11,9 +11,9 @@ from PIL import Image
 from torchvision.transforms import ToPILImage
 
 from deepliif.data import create_dataset, transform
-from deepliif.models import init_nets, infer_modalities, infer_results_for_wsi, create_model, postprocess
-from deepliif.util import allowed_file, Visualizer, get_information, test_diff_original_serialized, disable_batchnorm_tracking_stats
-from deepliif.util.util import mkdirs
+from deepliif.models import init_nets, infer_modalities, infer_results_for_wsi, create_model, postprocess, get_opt
+from deepliif.util import allowed_file, Visualizer, test_diff_original_serialized, disable_batchnorm_tracking_stats, infer_background_colors, get_information
+from deepliif.util.util import mkdirs, get_mod_id_seg
 from deepliif.util.checks import check_weights
 # from deepliif.util import infer_results_for_wsi
 from deepliif.options import Options, print_options
@@ -77,6 +77,7 @@ def cli():
 @click.option('--gpu-ids', type=int, multiple=True, help='gpu-ids 0 gpu-ids 1 or gpu-ids -1 for CPU')
 @click.option('--checkpoints-dir', default='./checkpoints', help='models are saved here')
 @click.option('--modalities-no', default=4, type=int, help='number of targets')
+@click.option('--modalities-names', default='', type=str, help='an optional note of the name of each modality (input mod(s) and the mod(s) to learn stain transfer from), separated by comma; this helps document the modalities using the train opt file and will also be used to name the inference output; example: --modalities-names IHC,Hematoxylin,DAPI,Lap2,Marker')
 # model parameters
 @click.option('--model', default='DeepLIIF', help='name of model class')
 @click.option('--model-dir-teacher', default='', help='the directory of the teacher model, only applicable if model is DeepLIIFKD')
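
The new --modalities-names value reaches train() as a single comma-separated string; the actual split happens later in the function body (see the parsing hunk below). A minimal, self-contained click sketch of the same pattern (the command name here is illustrative, not part of the package):

    import click

    @click.command()
    @click.option('--modalities-names', default='', type=str)
    def demo(modalities_names):
        # click only delivers the raw string, e.g. 'IHC,Hematoxylin,DAPI,Lap2,Marker'
        print(modalities_names)

    if __name__ == '__main__':
        demo()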
@@ -189,6 +190,7 @@ def cli():
 @click.option('--debug', is_flag=True,
               help='debug mode, limits the number of data points per epoch to a small value')
 @click.option('--debug-data-size', default=10, type=int, help='data size per epoch used in debug mode; due to batch size, the epoch will be passed once the completed no. data points is greater than this value (e.g., for batch size 3, debug data size 10, the effective size used in training will be 12)')
+@click.option('--monitor-image', default=None, help='a filename in the training dataset, if set, used for the visualization of model results; this overwrites --display-freq because we now focus on viewing the training progress on one fixed image')
 def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, ndf, net_d, net_g,
           n_layers_d, norm, init_type, init_gain, no_dropout, upsample, label_smoothing, direction, serial_batches, num_threads,
           batch_size, load_size, crop_size, max_dataset_size, preprocess, no_flip, display_winsize, epoch, load_iter,
@@ -197,7 +199,7 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd
           continue_train, epoch_count, phase, lr_policy, n_epochs, n_epochs_decay, optimizer, beta1, lr_g, lr_d, lr_decay_iters,
           remote, remote_transfer_cmd, seed, dataset_mode, padding, model, model_dir_teacher,
           seg_weights, loss_weights_g, loss_weights_d,
-          modalities_no, seg_gen, net_ds, net_gs, gan_mode, gan_mode_s, local_rank, with_val, debug, debug_data_size):
+          modalities_no, modalities_names, seg_gen, net_ds, net_gs, gan_mode, gan_mode_s, local_rank, with_val, debug, debug_data_size, monitor_image):
     """General-purpose training script for multi-task image-to-image translation.
 
     This script works for various models (with option '--model': e.g., DeepLIIF) and
@@ -221,6 +223,7 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd
         seg_no = 0
         seg_gen = False
 
+
     if model == 'CycleGAN':
         dataset_mode = "unaligned"
 
@@ -260,8 +263,6 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd
         print('padding type is forced to zero padding, because neither refection pad2d or replication pad2d has a deterministic implementation')
 
     # infer number of input images
-
-
     if dataset_mode == 'unaligned':
         dir_data_train = dataroot + '/trainA'
         fns = os.listdir(dir_data_train)
@@ -300,13 +301,28 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd
     assert input_no > 0, f'inferred number of input images is {input_no} (modalities_no {modalities_no}, seg_no {seg_no}); should be greater than 0'
 
     pool_size = 0
-
+
+    modalities_names = [name.strip() for name in modalities_names.split(',') if len(name) > 0]
+    assert len(modalities_names) == 0 or len(modalities_names) == input_no + modalities_no, f'--modalities-names has {len(modalities_names)} entries ({modalities_names}), expecting 0 or {input_no + modalities_no} entries'
+
+    if len(modalities_names) == 0 and model == 'DeepLIIFKD':
+        # inherit this property from teacher model
+        opt_teacher = get_opt(model_dir_teacher, mode='test')
+        modalities_names = opt_teacher.modalities_names
+
     d_params['input_no'] = input_no
+    d_params['modalities_names'] = modalities_names
     d_params['scale_size'] = img.size[1]
     d_params['gpu_ids'] = gpu_ids
     d_params['lambda_identity'] = 0
     d_params['pool_size'] = pool_size
 
+    if seg_gen:
+        # estimate background color for output modalities
+        background_colors = infer_background_colors(os.path.join(dataroot,'train'), sample_size=10, input_no=input_no, modalities_no=modalities_no, seg_no=seg_no, tile_size=32, return_list=True)
+        if background_colors is not None:
+            d_params['background_colors'] = background_colors
+
 
     # update generator arch
     net_g = net_g.split(',')
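
The parsing above accepts either an empty list or exactly input_no + modalities_no names, and a DeepLIIFKD student given no names inherits them from the teacher's stored options; infer_background_colors (internal to deepliif.util) samples training tiles to estimate per-modality background colors. A worked example of the split and the length check, with counts illustrative of a single-input, four-target setup:

    # e.g. --modalities-names 'IHC, Hematoxylin, DAPI, Lap2, Marker'
    raw = 'IHC, Hematoxylin, DAPI, Lap2, Marker'
    input_no, modalities_no = 1, 4

    modalities_names = [name.strip() for name in raw.split(',') if len(name) > 0]
    # ['IHC', 'Hematoxylin', 'DAPI', 'Lap2', 'Marker'] -> 5 == 1 + 4, so the assert passes
    assert len(modalities_names) in (0, input_no + modalities_no)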
@@ -324,19 +340,27 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd
     d_params['net_g'] = net_g
     d_params['net_gs'] = net_gs
 
+    def get_weights(model='DeepLIIF', modalities_no=4, default=[0.25,0.15,0.25,0.1,0.25]):
+        if model in ['DeepLIIF','DeepLIIFKD'] and modalities_no == 4:
+            return default
+        elif model in ['DeepLIIF','DeepLIIFKD']:
+            return [1 / (modalities_no + 1)] * (modalities_no + 1)
+        else:
+            return [1 / modalities_no] * modalities_no
+
     # check seg weights and loss weights
     if len(d_params['seg_weights']) == 0:
-        seg_weights = [0.25,0.15,0.25,0.1,0.25] if d_params['model'] in ['DeepLIIF','DeepLIIFKD'] else [1 / modalities_no] * modalities_no
+        seg_weights = get_weights(d_params['model'], modalities_no, default=[0.25,0.15,0.25,0.1,0.25])
     else:
         seg_weights = [float(x) for x in seg_weights.split(',')]
 
     if len(d_params['loss_weights_g']) == 0:
-        loss_weights_g = [0.2]*5 if d_params['model'] in ['DeepLIIF','DeepLIIFKD'] else [1 / modalities_no] * modalities_no
+        loss_weights_g = get_weights(d_params['model'], modalities_no, default=[0.2]*5)
     else:
         loss_weights_g = [float(x) for x in loss_weights_g.split(',')]
 
     if len(d_params['loss_weights_d']) == 0:
-        loss_weights_d = [0.2]*5 if d_params['model'] in ['DeepLIIF','DeepLIIFKD'] else [1 / modalities_no] * modalities_no
+        loss_weights_d = get_weights(d_params['model'], modalities_no, default=[0.2]*5)
     else:
         loss_weights_d = [float(x) for x in loss_weights_d.split(',')]
 
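
get_weights generalizes the old hard-coded ternaries: DeepLIIF/DeepLIIFKD with exactly four modalities keep the hand-tuned defaults, other DeepLIIF(-KD) modality counts spread weight uniformly over modalities_no + 1 outputs (apparently one extra slot for the segmentation branch), and all other models spread it over modalities_no. Worked values:

    get_weights('DeepLIIF', 4)     # [0.25, 0.15, 0.25, 0.1, 0.25] (the default)
    get_weights('DeepLIIFKD', 2)   # [1/3, 1/3, 1/3]
    get_weights('DeepLIIFExt', 2)  # [0.5, 0.5]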
@@ -378,6 +402,15 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd
     visualizer = Visualizer(opt)
     # the total number of training iterations
    total_iters = 0
+
+    # infer base epoch number, used for checkpoint filename
+    if not continue_train:
+        epoch_base = 0
+    else:
+        try:
+            epoch_base = int(epoch)
+        except:
+            epoch_base = 0
 
     # outer loop for different epochs; we save the model by <epoch_count>, <epoch_count>+<save_latest_freq>
     for epoch in range(epoch_count, n_epochs + n_epochs_decay + 1):
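
epoch_base lets a resumed run continue the old checkpoint numbering: --continue-train --epoch 100 yields epoch_base == 100, while --epoch latest (not an int) falls back to 0 via the bare except. A sketch of the same logic:

    def infer_epoch_base(continue_train, epoch):
        # mirrors the block above; 'epoch' is the CLI string, e.g. '100' or 'latest'
        if not continue_train:
            return 0
        try:
            return int(epoch)
        except ValueError:
            return 0

    infer_epoch_base(True, '100')     # 100
    infer_epoch_base(True, 'latest')  # 0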
@@ -410,10 +443,16 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd
             model.optimize_parameters()
 
             # display images on visdom and save images to a HTML file
-            if total_iters % display_freq == 0:
-                save_result = total_iters % update_html_freq == 0
-                model.compute_visuals()
-                visualizer.display_current_results({**model.get_current_visuals()}, epoch, save_result)
+            if monitor_image is not None:
+                if data['A_paths'][0].endswith(monitor_image):
+                    save_result = total_iters % update_html_freq == 0
+                    model.compute_visuals()
+                    visualizer.display_current_results({**model.get_current_visuals()}, epoch, save_result, filename=monitor_image)
+            else:
+                if total_iters % display_freq == 0:
+                    save_result = total_iters % update_html_freq == 0
+                    model.compute_visuals()
+                    visualizer.display_current_results({**model.get_current_visuals()}, epoch, save_result, filename=data['A_paths'][0])
 
             # print training losses and save logging information to the disk
             if total_iters % print_freq == 0:
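
With --monitor-image set, display_freq is bypassed and visuals refresh only on batches whose first image path ends with the monitored filename, so the visdom/HTML panel tracks one fixed image across epochs. The match is a plain suffix test:

    data = {'A_paths': ['/data/train/sample_42.png']}  # illustrative batch dict
    monitor_image = 'sample_42.png'
    data['A_paths'][0].endswith(monitor_image)         # True -> refresh visuals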
@@ -437,9 +476,12 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd
 
         # cache our model every <save_epoch_freq> epochs
         if epoch % save_epoch_freq == 0:
-            print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters))
-            model.save_networks('latest')
-            model.save_networks(epoch)
+            if continue_train and epoch == 0:  # to not overwrite the loaded epoch
+                pass
+            else:
+                print('saving the model at the end of epoch %d, iters %d' % (epoch, total_iters))
+                model.save_networks('latest')
+                model.save_networks(epoch+epoch_base)
 
 
 
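
The epoch == 0 guard keeps a resumed run from immediately re-saving over the checkpoint it just loaded; every later save is offset by epoch_base. Assuming the usual <key>_net_<name>.pth checkpoint naming (an assumption about the save format, not confirmed by this diff), resuming from epoch 100 produces:

    epoch_base = 100
    for epoch in (5, 10):
        print(f'{epoch + epoch_base}_net_G.pth')  # 105_net_G.pth, 110_net_G.pth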
@@ -463,8 +505,12 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd
                 l_losses_val += [(k,v) for k,v in losses_val_batch.items()]
 
             # calculate cell count metrics
-            if type(model).__name__ == 'DeepLIIFModel':
-                l_seg_names = ['fake_B_5']
+            if type(model).__name__ in ['DeepLIIFModel','DeepLIIFKDModel']:
+                if continue_train:
+                    mod_id_seg = get_mod_id_seg(os.path.join(opt.checkpoints_dir, opt.name))
+                else:
+                    mod_id_seg = 'S'
+                l_seg_names = [f'fake_B_{mod_id_seg}']
                 assert l_seg_names[0] in visuals.keys(), f'Cannot find {l_seg_names[0]} in generated image names ({list(visuals.keys())})'
                 seg_mod_suffix = l_seg_names[0].split('_')[-1]
                 l_seg_names += [x for x in visuals.keys() if x.startswith('fake') and x.split('_')[-1].startswith(seg_mod_suffix) and x != l_seg_names[0]]
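
Validation metrics no longer assume the segmentation visual is named fake_B_5: fresh runs use the suffix 'S', and resumed runs recover the suffix from the checkpoint directory via get_mod_id_seg (its lookup is internal to deepliif.util.util). The suffix filter then collects any related outputs; with a hypothetical visuals dict:

    visuals = {'fake_B_1': None, 'fake_B_S': None, 'fake_B_S2': None, 'real_A': None}
    l_seg_names = ['fake_B_S']
    seg_mod_suffix = l_seg_names[0].split('_')[-1]  # 'S'
    l_seg_names += [x for x in visuals if x.startswith('fake')
                    and x.split('_')[-1].startswith(seg_mod_suffix)
                    and x != l_seg_names[0]]
    # l_seg_names == ['fake_B_S', 'fake_B_S2']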
@@ -526,8 +572,10 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd
 @click.option('--gpu-ids', type=int, multiple=True, help='gpu-ids 0 gpu-ids 1 or gpu-ids -1 for CPU')
 @click.option('--checkpoints-dir', default='./checkpoints', help='models are saved here')
 @click.option('--modalities-no', default=4, type=int, help='number of targets')
+@click.option('--modalities-names', default='', type=str, help='an optional note of the name of each modality (input mod(s) and the mod(s) to learn stain transfer from), separated by comma; this helps document the modalities using the train opt file and will also be used to name the inference output; example: --modalities-names IHC,Hematoxylin,DAPI,Lap2,Marker')
 # model parameters
 @click.option('--model', default='DeepLIIF', help='name of model class')
+@click.option('--model-dir-teacher', default='', help='the directory of the teacher model, only applicable if model is DeepLIIFKD')
 @click.option('--seg-weights', default='', type=str, help='weights used to aggregate modality images for the final segmentation image; numbers should add up to 1, and each number corresponds to the modality in order; example: 0.25,0.15,0.25,0.1,0.25')
 @click.option('--loss-weights-g', default='', type=str, help='weights used to aggregate modality-wise losses for the final loss; numbers should add up to 1, and each number corresponds to the modality in order; example: 0.2,0.2,0.2,0.2,0.2')
 @click.option('--loss-weights-d', default='', type=str, help='weights used to aggregate modality-wise losses for the final loss; numbers should add up to 1, and each number corresponds to the modality in order; example: 0.2,0.2,0.2,0.2,0.2')
@@ -637,6 +685,7 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd
 @click.option('--debug', is_flag=True,
               help='debug mode, limits the number of data points per epoch to a small value')
 @click.option('--debug-data-size', default=10, type=int, help='data size per epoch used in debug mode; due to batch size, the epoch will be passed once the completed no. data points is greater than this value (e.g., for batch size 3, debug data size 10, the effective size used in training will be 12)')
+@click.option('--monitor-image', default=None, help='a filename in the training dataset, if set, used for the visualization of model results; this overwrites --display-freq because we now focus on viewing the training progress on one fixed image')
 # trainlaunch DDP related arguments
 @click.option('--use-torchrun', type=str, default=None, help='provide torchrun options, all in one string, for example "-t3 --log_dir ~/log/ --nproc_per_node 1"; if your pytorch version is older than 1.10, torch.distributed.launch will be called instead of torchrun')
 def trainlaunch(**kwargs):
@@ -748,7 +797,7 @@ def serialize(model_dir, output_dir, device, epoch, verbose):
         net = net.eval()
         net = disable_batchnorm_tracking_stats(net)
         net = net.cpu()
-        if name.startswith('GS'):
+        if opt.model in ['DeepLIIFExt'] and name[1]=='S':
             traced_net = torch.jit.trace(net, torch.cat([sample, sample, sample], 1))
         else:
             traced_net = torch.jit.trace(net, sample)
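
The tracing condition is now explicit about the model class: only DeepLIIFExt seg generators (second character 'S' in the net name) take the three modalities concatenated along the channel axis, so their example input must be widened accordingly. Roughly, with an illustrative tile size:

    import torch

    sample = torch.rand(1, 3, 512, 512)                 # illustrative example input
    seg_input = torch.cat([sample, sample, sample], 1)  # (1, 9, 512, 512) for Ext seg nets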
@@ -768,7 +817,7 @@ def serialize(model_dir, output_dir, device, epoch, verbose):
         print(name,':')
         model_original = models_original[name].cuda().eval() if device=='gpu' else models_original[name].cpu().eval()
         model_serialized = models_serialized[name].cuda().eval() if device=='gpu' else models_serialized[name].cpu().eval()
-        if name.startswith('GS'):
+        if opt.model in ['DeepLIIFExt'] and name[1]=='S':
             test_diff_original_serialized(model_original,model_serialized,torch.cat([sample, sample, sample], 1),verbose)
         else:
             test_diff_original_serialized(model_original,model_serialized,sample,verbose)
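
Note the behavioral change relative to the old test: a net named, say, 'GS1' in a non-Ext model now goes down the single-sample branch (net names here are assumptions for illustration):

    name, model_name = 'GS1', 'DeepLIIF'
    name.startswith('GS')                             # old test: True
    model_name in ['DeepLIIFExt'] and name[1] == 'S'  # new test: False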
@@ -787,18 +836,22 @@ def serialize(model_dir, output_dir, device, epoch, verbose):
               help='for eager mode, which epoch to load? set to latest to use latest cached model')
 @click.option('--seg-intermediate', is_flag=True, help='also save intermediate segmentation images (currently only applies to DeepLIIF model)')
 @click.option('--seg-only', is_flag=True, help='save only the final segmentation image (currently only applies to DeepLIIF model); overwrites --seg-intermediate')
+@click.option('--mod-only', is_flag=True, help='save only the translated modality image; overwrites --seg-only and --seg-intermediate')
 @click.option('--color-dapi', is_flag=True, help='color dapi image to produce the same coloring as in the paper')
 @click.option('--color-marker', is_flag=True, help='color marker image to produce the same coloring as in the paper')
 @click.option('--BtoA', is_flag=True, help='for models trained with unaligned dataset, this flag instructs to load generatorB instead of generatorA')
 def test(input_dir, output_dir, tile_size, model_dir, filename_pattern, gpu_ids, eager_mode, epoch,
-         seg_intermediate, seg_only, color_dapi, color_marker, btoa):
+         seg_intermediate, seg_only, mod_only, color_dapi, color_marker, btoa):
 
     """Test trained models
     """
     output_dir = output_dir or input_dir
     ensure_exists(output_dir)
 
-    if seg_intermediate and seg_only:
+    if mod_only:
+        seg_only = False
+        seg_intermediate = False
+    elif seg_intermediate and seg_only:
         seg_intermediate = False
 
     if filename_pattern == '*':
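
--mod-only now tops the flag precedence: it clears both segmentation flags, and only then does the old rule (seg-only beats seg-intermediate) apply. Condensed:

    # mod_only > seg_only > seg_intermediate
    if mod_only:
        seg_only = seg_intermediate = False
    elif seg_intermediate and seg_only:
        seg_intermediate = False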
@@ -817,6 +870,8 @@ def test(input_dir, output_dir, tile_size, model_dir, filename_pattern, gpu_ids,
     opt.BtoA = btoa
     opt.epoch = epoch
 
+    seg_weights = opt.seg_weights if hasattr(opt,'seg_weights') else None
+
     number_of_gpus_all = torch.cuda.device_count()
     if number_of_gpus_all < len(gpu_ids) and -1 not in gpu_ids:
         number_of_gpus = 0
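
Reading seg_weights through hasattr keeps test compatible with older checkpoints whose stored opt predates the attribute; passing None presumably lets infer_modalities fall back to its internal default. An equivalent getattr form:

    seg_weights = getattr(opt, 'seg_weights', None)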
@@ -843,7 +898,7 @@ def test(input_dir, output_dir, tile_size, model_dir, filename_pattern, gpu_ids,
     ) as bar:
         for filename in bar:
             img = Image.open(os.path.join(input_dir, filename)).convert('RGB')
-            images, scoring = infer_modalities(img, tile_size, model_dir, eager_mode, color_dapi, color_marker, opt, return_seg_intermediate=seg_intermediate, seg_only=seg_only)
+            images, scoring = infer_modalities(img, tile_size, model_dir, eager_mode, color_dapi, color_marker, opt, return_seg_intermediate=seg_intermediate, seg_only=seg_only, mod_only=mod_only, seg_weights=seg_weights)
 
             for name, i in images.items():
                 i.save(os.path.join(
@@ -99,7 +99,9 @@ def get_transform(preprocess, load_size, crop_size, no_flip, params=None, graysc
 
     if not no_flip:
         if params is None:
+            # default p=0.5
             transform_list.append(transforms.RandomHorizontalFlip())
+            transform_list.append(transforms.RandomVerticalFlip())
         elif params['flip']:
             transform_list.append(transforms.Lambda(lambda img: __flip(img, params['flip'])))
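
With the change above, random augmentation flips tiles vertically as well as horizontally (each independently, with torchvision's default p=0.5) whenever no precomputed params are given; the params branch still applies a single deterministic flip. A minimal standalone sketch with torchvision:

    from torchvision import transforms

    transform_list = []
    transform_list.append(transforms.RandomHorizontalFlip())  # p=0.5 by default
    transform_list.append(transforms.RandomVerticalFlip())    # p=0.5 by default
    aug = transforms.Compose(transform_list)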