deepliif 1.1.11__py3-none-any.whl → 1.1.13__py3-none-any.whl

This diff compares the contents of two publicly released versions of the package, as published to one of the supported registries. It is provided for informational purposes only and reflects the packages as they appear in their public registry.
cli.py CHANGED
@@ -8,11 +8,12 @@ import cv2
  import torch
  import numpy as np
  from PIL import Image
+ from torchvision.transforms import ToPILImage

  from deepliif.data import create_dataset, transform
- from deepliif.models import init_nets, infer_modalities, infer_results_for_wsi, create_model
+ from deepliif.models import init_nets, infer_modalities, infer_results_for_wsi, create_model, postprocess
  from deepliif.util import allowed_file, Visualizer, get_information, test_diff_original_serialized, disable_batchnorm_tracking_stats
- from deepliif.util.util import mkdirs, check_multi_scale
+ from deepliif.util.util import mkdirs
  # from deepliif.util import infer_results_for_wsi
  from deepliif.options import Options, print_options

@@ -77,6 +78,9 @@ def cli():
  @click.option('--modalities-no', default=4, type=int, help='number of targets')
  # model parameters
  @click.option('--model', default='DeepLIIF', help='name of model class')
+ @click.option('--seg-weights', default='', type=str, help='weights used to aggregate modality images into the final segmentation image; the numbers should add up to 1, and each corresponds to a modality in order; example: 0.25,0.15,0.25,0.1,0.25')
+ @click.option('--loss-weights-g', default='', type=str, help='weights used to aggregate modality-wise generator losses into the final generator loss; the numbers should add up to 1, and each corresponds to a modality in order; example: 0.2,0.2,0.2,0.2,0.2')
+ @click.option('--loss-weights-d', default='', type=str, help='weights used to aggregate modality-wise discriminator losses into the final discriminator loss; the numbers should add up to 1, and each corresponds to a modality in order; example: 0.2,0.2,0.2,0.2,0.2')
  @click.option('--input-nc', default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')
  @click.option('--output-nc', default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')
  @click.option('--ngf', default=64, help='# of gen filters in the last conv layer')
@@ -85,7 +89,7 @@ def cli():
  help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 '
  'PatchGAN. n_layers allows you to specify the layers in the discriminator')
  @click.option('--net-g', default='resnet_9blocks',
- help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_512 | unet_256 | unet_128]')
+ help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_512 | unet_256 | unet_128 | unet_512_attention]; to use a different architecture per generator, list one per generator separated by commas, e.g., --net-g=resnet_9blocks,resnet_9blocks,resnet_9blocks,unet_512_attention,unet_512_attention')
  @click.option('--n-layers-d', default=4, help='only used if netD==n_layers')
  @click.option('--norm', default='batch',
  help='instance normalization or batch normalization [instance | batch | none]')
@@ -93,6 +97,8 @@ def cli():
  help='network initialization [normal | xavier | kaiming | orthogonal]')
  @click.option('--init-gain', default=0.02, help='scaling factor for normal, xavier and orthogonal.')
  @click.option('--no-dropout', is_flag=True, help='no dropout for the generator')
+ @click.option('--upsample', default='convtranspose', help='upsampling method used in the generator [convtranspose | resize_conv | pixel_shuffle]')
+ @click.option('--label-smoothing', type=float, default=0.0, help='label smoothing factor to prevent the discriminator from being too confident')
  # dataset parameters
  @click.option('--direction', default='AtoB', help='AtoB or BtoA')
  @click.option('--serial-batches', is_flag=True,
@@ -128,12 +134,17 @@ def cli():
  help='number of epochs with the initial learning rate')
  @click.option('--n-epochs-decay', type=int, default=100,
  help='number of epochs to linearly decay learning rate to zero')
+ @click.option('--optimizer', type=str, default='adam',
+ help='optimizer from torch.optim to use, applied to both generators and discriminators [adam | sgd | adamw | ...]; the current parameters are tuned for adam, so other optimizers may encounter issues')
  @click.option('--beta1', default=0.5, help='momentum term of adam')
- @click.option('--lr', default=0.0002, help='initial learning rate for adam')
+ #@click.option('--lr', default=0.0002, help='initial learning rate for adam')
+ @click.option('--lr-g', default=0.0002, help='initial learning rate for generator adam optimizer')
+ @click.option('--lr-d', default=0.0002, help='initial learning rate for discriminator adam optimizer')
  @click.option('--lr-policy', default='linear',
  help='learning rate policy. [linear | step | plateau | cosine]')
  @click.option('--lr-decay-iters', type=int, default=50,
  help='multiply by a gamma every lr_decay_iters iterations')
+ @click.option('--seed', type=int, default=None, help='basic seed to be used for deterministic training, default to None (non-deterministic)')
  # visdom and HTML visualization parameters
  @click.option('--display-freq', default=400, help='frequency of showing training results on screen')
  @click.option('--display-ncols', default=4,
@@ -158,26 +169,32 @@ def cli():
  help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]')
  @click.option('--padding', type=str, default='zero',
  help='chooses the type of padding used by resnet generator. [reflect | zero]')
- @click.option('--local-rank', type=int, default=None, help='placeholder argument for torchrun, no need for manual setup')
- @click.option('--seed', type=int, default=None, help='basic seed to be used for deterministic training, default to None (non-deterministic)')
  # DeepLIIFExt params
  @click.option('--seg-gen', type=bool, default=True, help='True (Translation and Segmentation), False (Only Translation).')
  @click.option('--net-ds', type=str, default='n_layers',
  help='specify discriminator architecture for segmentation task [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')
  @click.option('--net-gs', type=str, default='unet_512',
- help='specify generator architecture for segmentation task [resnet_9blocks | resnet_6blocks | unet_512 | unet_256 | unet_128]')
+ help='specify generator architecture for segmentation task [resnet_9blocks | resnet_6blocks | unet_512 | unet_256 | unet_128 | unet_512_attention]; to use a different architecture per generator, list one per generator separated by commas, e.g., --net-gs=resnet_9blocks,resnet_9blocks,resnet_9blocks,unet_512_attention,unet_512_attention')
  @click.option('--gan-mode', type=str, default='vanilla',
  help='the type of GAN objective for translation task. [vanilla | lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')
  @click.option('--gan-mode-s', type=str, default='lsgan',
  help='the type of GAN objective for segmentation task. [vanilla | lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')
+ # DDP related arguments
+ @click.option('--local-rank', type=int, default=None, help='placeholder argument for torchrun, no need for manual setup')
+ # Others
+ @click.option('--with-val', is_flag=True,
+ help='use validation set to evaluate model performance at the end of each epoch')
+ @click.option('--debug', is_flag=True,
+ help='debug mode, limits the number of data points per epoch to a small value')
+ @click.option('--debug-data-size', default=10, type=int, help='data size per epoch used in debug mode; because of batching, the epoch ends once the number of completed data points exceeds this value (e.g., with batch size 3 and debug data size 10, the effective size used in training is 12)')
  def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, ndf, net_d, net_g,
- n_layers_d, norm, init_type, init_gain, no_dropout, direction, serial_batches, num_threads,
+ n_layers_d, norm, init_type, init_gain, no_dropout, upsample, label_smoothing, direction, serial_batches, num_threads,
  batch_size, load_size, crop_size, max_dataset_size, preprocess, no_flip, display_winsize, epoch, load_iter,
  verbose, lambda_l1, is_train, display_freq, display_ncols, display_id, display_server, display_env,
  display_port, update_html_freq, print_freq, no_html, save_latest_freq, save_epoch_freq, save_by_iter,
- continue_train, epoch_count, phase, lr_policy, n_epochs, n_epochs_decay, beta1, lr, lr_decay_iters,
- remote, local_rank, remote_transfer_cmd, seed, dataset_mode, padding, model,
- modalities_no, seg_gen, net_ds, net_gs, gan_mode, gan_mode_s):
+ continue_train, epoch_count, phase, lr_policy, n_epochs, n_epochs_decay, optimizer, beta1, lr_g, lr_d, lr_decay_iters,
+ remote, remote_transfer_cmd, seed, dataset_mode, padding, model, seg_weights, loss_weights_g, loss_weights_d,
+ modalities_no, seg_gen, net_ds, net_gs, gan_mode, gan_mode_s, local_rank, with_val, debug, debug_data_size):
  """General-purpose training script for multi-task image-to-image translation.

  This script works for various models (with option '--model': e.g., DeepLIIF) and
@@ -189,7 +206,7 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd
  plot, and save models. The script supports continue/resume training.
  Use '--continue_train' to resume your previous training.
  """
- assert model in ['DeepLIIF','DeepLIIFExt','SDG'], f'model class {model} is not implemented'
+ assert model in ['DeepLIIF','DeepLIIFExt','SDG','CycleGAN'], f'model class {model} is not implemented'
  if model == 'DeepLIIF':
  seg_no = 1
  elif model == 'DeepLIIFExt':
@@ -197,10 +214,16 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd
  seg_no = modalities_no
  else:
  seg_no = 0
- else: # SDG
+ else: # SDG, CycleGAN
  seg_no = 0
  seg_gen = False

+ if model == 'CycleGAN':
+ dataset_mode = "unaligned"
+
+ if optimizer != 'adam':
+ print(f'Optimizer torch.optim.{optimizer} is not tested. Be careful about the parameters of the optimizer.')
+
  d_params = locals()

  if gpu_ids and gpu_ids[0] == -1:
@@ -213,12 +236,12 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd
  if local_rank is not None:
  local_rank = int(local_rank)
  torch.cuda.set_device(gpu_ids[local_rank])
- gpu_ids=[gpu_ids[local_rank]]
+ gpu_ids=[local_rank]
  else:
  torch.cuda.set_device(gpu_ids[0])

  if local_rank is not None: # LOCAL_RANK will be assigned a rank number if torchrun ddp is used
- dist.init_process_group(backend='nccl')
+ dist.init_process_group(backend="nccl", rank=int(os.environ['RANK']), world_size=int(os.environ['WORLD_SIZE']))
  print('local rank:',local_rank)
  flag_deterministic = set_seed(seed,local_rank)
  elif rank is not None:
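Note on the DDP change above: torchrun exports the RANK, WORLD_SIZE, and LOCAL_RANK environment variables for every worker, which the new init_process_group call reads explicitly. A minimal standalone sketch of this initialization pattern (illustrative, not DeepLIIF's actual helper):

    import os
    import torch
    import torch.distributed as dist

    def init_ddp():
        # torchrun sets RANK / WORLD_SIZE / LOCAL_RANK for each worker process
        rank = int(os.environ['RANK'])
        world_size = int(os.environ['WORLD_SIZE'])
        local_rank = int(os.environ['LOCAL_RANK'])

        # bind this process to its GPU before creating the process group
        torch.cuda.set_device(local_rank)
        dist.init_process_group(backend='nccl', rank=rank, world_size=world_size)
        return local_rank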
@@ -231,29 +254,127 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd
  print('padding type is forced to zero padding, because neither reflection pad2d nor replication pad2d has a deterministic implementation')

  # infer number of input images
- dir_data_train = dataroot + '/train'
- fns = os.listdir(dir_data_train)
- fns = [x for x in fns if x.endswith('.png')]
- img = Image.open(f"{dir_data_train}/{fns[0]}")

- num_img = img.size[0] / img.size[1]
- assert int(num_img) == num_img, f'img size {img.size[0]} / {img.size[1]} = {num_img} is not an integer'
- num_img = int(num_img)

- input_no = num_img - modalities_no - seg_no
- assert input_no > 0, f'inferred number of input images is {input_no}; should be greater than 0'
+ if dataset_mode == 'unaligned':
+ dir_data_train = dataroot + '/trainA'
+ fns = os.listdir(dir_data_train)
+ fns = [x for x in fns if x.endswith('.png')]
+ print(f'{len(fns)} images found in trainA')
+ img = Image.open(f"{dir_data_train}/{fns[0]}")
+ print('image shape:', img.size)
+
+ for i in range(1, modalities_no + 1):
+ dir_data_train = dataroot + f'/trainB{i}'
+ fns = os.listdir(dir_data_train)
+ fns = [x for x in fns if x.endswith('.png')]
+ print(f'{len(fns)} images found in trainB{i}')
+ img = Image.open(f"{dir_data_train}/{fns[0]}")
+ print('image shape:', img.size)
+
+ input_no = 1
+ num_img = None
+
+ lambda_identity = 0
+ pool_size = 50 # https://github.com/junyanz/pytorch-CycleGAN-and-pix2pix/blob/master/scripts/train_cyclegan.sh
+
+ else:
+ dir_data_train = dataroot + '/train'
+ fns = os.listdir(dir_data_train)
+ fns = [x for x in fns if x.endswith('.png')]
+ print(f'{len(fns)} images found')
+ img = Image.open(f"{dir_data_train}/{fns[0]}")
+ print('image shape:', img.size)
+
+ num_img = img.size[0] / img.size[1]
+ assert int(num_img) == num_img, f'img size {img.size[0]} / {img.size[1]} = {num_img} is not an integer'
+ num_img = int(num_img)
+
+ input_no = num_img - modalities_no - seg_no
+ assert input_no > 0, f'inferred number of input images is {input_no} (modalities_no {modalities_no}, seg_no {seg_no}); should be greater than 0'
+
+ pool_size = 0
+
  d_params['input_no'] = input_no
  d_params['scale_size'] = img.size[1]
+ d_params['gpu_ids'] = gpu_ids
+ d_params['lambda_identity'] = 0
+ d_params['pool_size'] = pool_size
+
+
+ # update generator arch
+ net_g = net_g.split(',')
+ assert len(net_g) in [1,modalities_no], f'net_g should contain either 1 architecture for all translation generators or the same number of architectures as the number of translation generators ({modalities_no})'
+ if len(net_g) == 1:
+ net_g = net_g*modalities_no
+
+ net_gs = net_gs.split(',')
+ assert len(net_gs) in [1,seg_no], f'net_gs should contain either 1 architecture for all segmentation generators or the same number of architectures as the number of segmentation generators ({seg_no})'
+ if len(net_gs) == 1 and model == 'DeepLIIF':
+ net_gs = net_gs*(modalities_no + seg_no)
+ elif len(net_gs) == 1:
+ net_gs = net_gs*seg_no
+
+ d_params['net_g'] = net_g
+ d_params['net_gs'] = net_gs
+
+ # check seg weights and loss weights
+ if len(d_params['seg_weights']) == 0:
+ seg_weights = [0.25,0.15,0.25,0.1,0.25] if d_params['model'] == 'DeepLIIF' else [1 / modalities_no] * modalities_no
+ else:
+ seg_weights = [float(x) for x in seg_weights.split(',')]
+
+ if len(d_params['loss_weights_g']) == 0:
+ loss_weights_g = [0.2]*5 if d_params['model'] == 'DeepLIIF' else [1 / modalities_no] * modalities_no
+ else:
+ loss_weights_g = [float(x) for x in loss_weights_g.split(',')]
+
+ if len(d_params['loss_weights_d']) == 0:
+ loss_weights_d = [0.2]*5 if d_params['model'] == 'DeepLIIF' else [1 / modalities_no] * modalities_no
+ else:
+ loss_weights_d = [float(x) for x in loss_weights_d.split(',')]
+
+ assert sum(seg_weights) == 1, 'seg weights should add up to 1'
+ assert sum(loss_weights_g) == 1, 'loss weights g should add up to 1'
+ assert sum(loss_weights_d) == 1, 'loss weights d should add up to 1'
+
+ if model == 'DeepLIIF':
+ # +1 because input becomes an additional modality used in generating the final segmentation
+ assert len(seg_weights) == modalities_no+1, 'seg weights should have the same number of elements as number of modalities to be generated + 1'
+ assert len(loss_weights_g) == modalities_no+1, 'loss weights g should have the same number of elements as number of modalities to be generated + 1'
+ assert len(loss_weights_d) == modalities_no+1, 'loss weights d should have the same number of elements as number of modalities to be generated + 1'

+ else:
+ assert len(seg_weights) == modalities_no, 'seg weights should have the same number of elements as number of modalities to be generated'
+ assert len(loss_weights_g) == modalities_no, 'loss weights g should have the same number of elements as number of modalities to be generated'
+ assert len(loss_weights_d) == modalities_no, 'loss weights d should have the same number of elements as number of modalities to be generated'
+
+ d_params['seg_weights'] = seg_weights
+ d_params['loss_G_weights'] = loss_weights_g
+ d_params['loss_D_weights'] = loss_weights_d
+
+ del d_params['loss_weights_g']
+ del d_params['loss_weights_d']
+
  # create a dataset given dataset_mode and other options
  # dataset = AlignedDataset(opt)

  opt = Options(d_params=d_params)
  print_options(opt, save=True)

+ # set dir for train and val
  dataset = create_dataset(opt)
+
  # get the number of images in the dataset.
  click.echo('The number of training images = %d' % len(dataset))
+
+ if with_val:
+ dataset_val = create_dataset(opt,phase='val')
+ data_val = [batch for batch in dataset_val]
+ click.echo('The number of validation images = %d' % len(dataset_val))
+
+ if model in ['DeepLIIF']:
+ metrics_val = json.load(open(os.path.join(dataset_val.dataset.dir_AB,'metrics.json')))

  # create a model given model and other options
  model = create_model(opt)
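Note on the weight handling above: --seg-weights and --loss-weights-g/-d arrive as comma-separated strings, default to the paper's DeepLIIF weights (or uniform weights for other models), and must sum to 1; for DeepLIIF the input image counts as an extra modality, so modalities_no + 1 weights are expected. A small standalone sketch of that parsing and validation (illustrative names, not DeepLIIF's API):

    def parse_weights(spec: str, modalities_no: int, is_deepliif: bool) -> list:
        # DeepLIIF uses the input as an additional modality for the final segmentation
        n = modalities_no + 1 if is_deepliif else modalities_no
        if not spec:
            # defaults: paper weights for DeepLIIF, uniform otherwise
            return [0.25, 0.15, 0.25, 0.1, 0.25] if is_deepliif else [1 / n] * n
        weights = [float(x) for x in spec.split(',')]
        assert len(weights) == n, f'expected {n} weights, got {len(weights)}'
        # the CLI checks sum(weights) == 1 exactly; a tolerance is safer with floats
        assert abs(sum(weights) - 1) < 1e-6, 'weights should add up to 1'
        return weights

    print(parse_weights('0.25,0.15,0.25,0.1,0.25', modalities_no=4, is_deepliif=True))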
@@ -299,15 +420,15 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd
  if total_iters % display_freq == 0:
  save_result = total_iters % update_html_freq == 0
  model.compute_visuals()
- visualizer.display_current_results(model.get_current_visuals(), epoch, save_result)
+ visualizer.display_current_results({**model.get_current_visuals()}, epoch, save_result)

  # print training losses and save logging information to the disk
  if total_iters % print_freq == 0:
- losses = model.get_current_losses()
+ losses = model.get_current_losses() # get training losses
  t_comp = (time.time() - iter_start_time) / batch_size
- visualizer.print_current_losses(epoch, epoch_iter, losses, t_comp, t_data)
+ visualizer.print_current_losses(epoch, epoch_iter, {**losses}, t_comp, t_data)
  if display_id > 0:
- visualizer.plot_current_losses(epoch, float(epoch_iter) / len(dataset), losses)
+ visualizer.plot_current_losses(epoch, float(epoch_iter) / len(dataset), {**losses})

  # cache our latest model every <save_latest_freq> iterations
  if total_iters % save_latest_freq == 0:
@@ -315,7 +436,11 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd
  save_suffix = 'iter_%d' % total_iters if save_by_iter else 'latest'
  model.save_networks(save_suffix)

+
  iter_data_time = time.time()
+ if debug and epoch_iter >= debug_data_size:
+ print(f'debug mode, epoch {epoch} stopped at epoch iter {epoch_iter} (>= {debug_data_size})')
+ break

  # cache our model every <save_epoch_freq> epochs
  if epoch % save_epoch_freq == 0:
@@ -323,6 +448,77 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd
  model.save_networks('latest')
  model.save_networks(epoch)

+
+
+ # validation loss and metrics calculation
+ if with_val:
+ losses = model.get_current_losses() # get training losses to print
+
+ model.eval()
+ l_losses_val = []
+ l_metrics_val = []
+
+ # for each val image, calculate validation loss and cell count metrics
+ for j, data_val_batch in enumerate(data_val):
+ # batch size is effectively 1 for validation
+ model.set_input(data_val_batch)
+ model.calculate_losses() # this does not optimize parameters
+ visuals = model.get_current_visuals() # get image results
+
+ # val losses
+ losses_val_batch = model.get_current_losses()
+ l_losses_val += [(k,v) for k,v in losses_val_batch.items()]
+
+ # calculate cell count metrics
+ if type(model).__name__ == 'DeepLIIFModel':
+ l_seg_names = ['fake_B_5']
+ assert l_seg_names[0] in visuals.keys(), f'Cannot find {l_seg_names[0]} in generated image names ({list(visuals.keys())})'
+ seg_mod_suffix = l_seg_names[0].split('_')[-1]
+ l_seg_names += [x for x in visuals.keys() if x.startswith('fake') and x.split('_')[-1].startswith(seg_mod_suffix) and x != l_seg_names[0]]
+ # print(f'Running postprocess for {len(l_seg_names)} generated images ({l_seg_names})')
+
+ img_name_current = data_val_batch['A_paths'][0].split('/')[-1][:-4] # remove .png
+ metrics_gt = metrics_val[img_name_current]
+
+ for seg_name in l_seg_names:
+ images = {'Seg':ToPILImage()((visuals[seg_name][0].cpu()+1)/2),
+ #'Marker':ToPILImage()((visuals['fake_B_4'][0].cpu()+1)/2)
+ }
+ _, scoring = postprocess(ToPILImage()((data_val_batch['A'][0]+1)/2), images, opt.scale_size, opt.model)
+
+ for k,v in scoring.items():
+ if k.startswith('num') or k.startswith('percent'):
+ # to calculate the rmse, here we calculate (x_pred - x_true) ** 2
+ l_metrics_val.append((k+'_'+seg_name,(v - metrics_gt[k])**2))
+
+ if debug and epoch_iter >= debug_data_size:
+ print(f'debug mode, epoch {epoch} stopped at epoch iter {epoch_iter} (>= {debug_data_size})')
+ break
+
+ d_losses_val = {k+'_val':0 for k in losses_val_batch.keys()}
+ for k,v in l_losses_val:
+ d_losses_val[k+'_val'] += v
+ for k in d_losses_val:
+ d_losses_val[k] = d_losses_val[k] / len(data_val)
+
+ d_metrics_val = {}
+ for k,v in l_metrics_val:
+ try:
+ d_metrics_val[k] += v
+ except KeyError:
+ d_metrics_val[k] = v
+ for k in d_metrics_val:
+ # to calculate the rmse, this is the second part, where d_metrics_val[k] now represents sum((x_pred - x_true) ** 2)
+ d_metrics_val[k] = np.sqrt(d_metrics_val[k] / len(data_val))
+
+
+ model.train()
+ t_comp = (time.time() - iter_start_time) / batch_size
+ visualizer.print_current_losses(epoch, epoch_iter, {**losses,**d_losses_val, **d_metrics_val}, t_comp, t_data)
+ if display_id > 0:
+ visualizer.plot_current_losses(epoch, float(epoch_iter) / len(dataset), {**losses,**d_losses_val,**d_metrics_val})
+
+
  print('End of epoch %d / %d \t Time Taken: %d sec' % (
  epoch, n_epochs + n_epochs_decay, time.time() - epoch_start_time))
  # update learning rates at the end of every epoch.
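Note on the validation block above: for each cell-count metric it accumulates squared differences between predicted and ground-truth scores, then takes the square root of the mean, i.e. an RMSE over the validation set. A compact standalone equivalent of that accumulation (example values):

    import numpy as np

    # (metric_name, (x_pred - x_true) ** 2) pairs collected over the validation set
    l_metrics_val = [('num_pos_fake_B_5', 4.0), ('num_pos_fake_B_5', 16.0)]
    n_val = 2

    d_metrics_val = {}
    for k, sq_err in l_metrics_val:
        d_metrics_val[k] = d_metrics_val.get(k, 0.0) + sq_err
    for k in d_metrics_val:
        d_metrics_val[k] = np.sqrt(d_metrics_val[k] / n_val)  # RMSE

    print(d_metrics_val)  # {'num_pos_fake_B_5': 3.1622776601683795}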
@@ -336,8 +532,12 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd
  help='name of the experiment. It decides where to store samples and models')
  @click.option('--gpu-ids', type=int, multiple=True, help='gpu-ids 0 gpu-ids 1 or gpu-ids -1 for CPU')
  @click.option('--checkpoints-dir', default='./checkpoints', help='models are saved here')
- @click.option('--targets-no', default=5, help='number of targets')
+ @click.option('--modalities-no', default=4, type=int, help='number of targets')
  # model parameters
+ @click.option('--model', default='DeepLIIF', help='name of model class')
+ @click.option('--seg-weights', default='', type=str, help='weights used to aggregate modality images into the final segmentation image; the numbers should add up to 1, and each corresponds to a modality in order; example: 0.25,0.15,0.25,0.1,0.25')
+ @click.option('--loss-weights-g', default='', type=str, help='weights used to aggregate modality-wise generator losses into the final generator loss; the numbers should add up to 1, and each corresponds to a modality in order; example: 0.2,0.2,0.2,0.2,0.2')
+ @click.option('--loss-weights-d', default='', type=str, help='weights used to aggregate modality-wise discriminator losses into the final discriminator loss; the numbers should add up to 1, and each corresponds to a modality in order; example: 0.2,0.2,0.2,0.2,0.2')
  @click.option('--input-nc', default=3, help='# of input image channels: 3 for RGB and 1 for grayscale')
  @click.option('--output-nc', default=3, help='# of output image channels: 3 for RGB and 1 for grayscale')
  @click.option('--ngf', default=64, help='# of gen filters in the last conv layer')
@@ -346,15 +546,16 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd
  help='specify discriminator architecture [basic | n_layers | pixel]. The basic model is a 70x70 '
  'PatchGAN. n_layers allows you to specify the layers in the discriminator')
  @click.option('--net-g', default='resnet_9blocks',
- help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_512 | unet_256 | unet_128]')
+ help='specify generator architecture [resnet_9blocks | resnet_6blocks | unet_512 | unet_256 | unet_128 | unet_512_attention]; to use a different architecture per generator, list one per generator separated by commas, e.g., --net-g=resnet_9blocks,resnet_9blocks,resnet_9blocks,unet_512_attention,unet_512_attention')
  @click.option('--n-layers-d', default=4, help='only used if netD==n_layers')
  @click.option('--norm', default='batch',
  help='instance normalization or batch normalization [instance | batch | none]')
  @click.option('--init-type', default='normal',
  help='network initialization [normal | xavier | kaiming | orthogonal]')
  @click.option('--init-gain', default=0.02, help='scaling factor for normal, xavier and orthogonal.')
- @click.option('--padding-type', default='reflect', help='network padding type.')
  @click.option('--no-dropout', is_flag=True, help='no dropout for the generator')
+ @click.option('--upsample', default='convtranspose', help='upsampling method used in the generator [convtranspose | resize_conv | pixel_shuffle]')
+ @click.option('--label-smoothing', type=float, default=0.0, help='label smoothing factor to prevent the discriminator from being too confident')
  # dataset parameters
  @click.option('--direction', default='AtoB', help='AtoB or BtoA')
  @click.option('--serial-batches', is_flag=True,
@@ -390,12 +591,17 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd
  help='number of epochs with the initial learning rate')
  @click.option('--n-epochs-decay', type=int, default=100,
  help='number of epochs to linearly decay learning rate to zero')
+ @click.option('--optimizer', type=str, default='adam',
+ help='optimizer from torch.optim to use, applied to both generators and discriminators [adam | sgd | adamw | ...]; the current parameters are tuned for adam, so other optimizers may encounter issues')
  @click.option('--beta1', default=0.5, help='momentum term of adam')
- @click.option('--lr', default=0.0002, help='initial learning rate for adam')
+ #@click.option('--lr', default=0.0002, help='initial learning rate for adam')
+ @click.option('--lr-g', default=0.0002, help='initial learning rate for generator adam optimizer')
+ @click.option('--lr-d', default=0.0002, help='initial learning rate for discriminator adam optimizer')
  @click.option('--lr-policy', default='linear',
  help='learning rate policy. [linear | step | plateau | cosine]')
  @click.option('--lr-decay-iters', type=int, default=50,
  help='multiply by a gamma every lr_decay_iters iterations')
+ @click.option('--seed', type=int, default=None, help='basic seed to be used for deterministic training, default to None (non-deterministic)')
  # visdom and HTML visualization parameters
  @click.option('--display-freq', default=400, help='frequency of showing training results on screen')
  @click.option('--display-ncols', default=4,
@@ -416,8 +622,29 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd
  @click.option('--save-by-iter', is_flag=True, help='whether saves model by iteration')
  @click.option('--remote', type=bool, default=False, help='whether isolate visdom checkpoints or not; if False, you can run a separate visdom server anywhere that consumes the checkpoints')
  @click.option('--remote-transfer-cmd', type=str, default=None, help='module and function to be used to transfer remote files to target storage location, for example mymodule.myfunction')
+ @click.option('--dataset-mode', type=str, default='aligned',
+ help='chooses how datasets are loaded. [unaligned | aligned | single | colorization]')
+ @click.option('--padding', type=str, default='zero',
+ help='chooses the type of padding used by resnet generator. [reflect | zero]')
+ # DeepLIIFExt params
+ @click.option('--seg-gen', type=bool, default=True, help='True (Translation and Segmentation), False (Only Translation).')
+ @click.option('--net-ds', type=str, default='n_layers',
+ help='specify discriminator architecture for segmentation task [basic | n_layers | pixel]. The basic model is a 70x70 PatchGAN. n_layers allows you to specify the layers in the discriminator')
+ @click.option('--net-gs', type=str, default='unet_512',
+ help='specify generator architecture for segmentation task [resnet_9blocks | resnet_6blocks | unet_512 | unet_256 | unet_128 | unet_512_attention]; to use a different architecture per generator, list one per generator separated by commas, e.g., --net-gs=resnet_9blocks,resnet_9blocks,resnet_9blocks,unet_512_attention,unet_512_attention')
+ @click.option('--gan-mode', type=str, default='vanilla',
+ help='the type of GAN objective for translation task. [vanilla | lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')
+ @click.option('--gan-mode-s', type=str, default='lsgan',
+ help='the type of GAN objective for segmentation task. [vanilla | lsgan | wgangp]. vanilla GAN loss is the cross-entropy objective used in the original GAN paper.')
+ # DDP related arguments
  @click.option('--local-rank', type=int, default=None, help='placeholder argument for torchrun, no need for manual setup')
- @click.option('--seed', type=int, default=None, help='basic seed to be used for deterministic training, default to None (non-deterministic)')
+ # Others
+ @click.option('--with-val', is_flag=True,
+ help='use validation set to evaluate model performance at the end of each epoch')
+ @click.option('--debug', is_flag=True,
+ help='debug mode, limits the number of data points per epoch to a small value')
+ @click.option('--debug-data-size', default=10, type=int, help='data size per epoch used in debug mode; because of batching, the epoch ends once the number of completed data points exceeds this value (e.g., with batch size 3 and debug data size 10, the effective size used in training is 12)')
+ # trainlaunch DDP related arguments
  @click.option('--use-torchrun', type=str, default=None, help='provide torchrun options, all in one string, for example "-t3 --log_dir ~/log/ --nproc_per_node 1"; if your pytorch version is older than 1.10, torch.distributed.launch will be called instead of torchrun')
  def trainlaunch(**kwargs):
  """
@@ -448,6 +675,7 @@ def trainlaunch(**kwargs):
  elif args[i-1] not in l_arg_skip and arg not in l_arg_skip:
  # if the previous element is not an option name to skip AND if the current element is not an option to remove
  args_final.append(arg)
+

  ## add quotes back to the input arg that had quotes, e.g., experiment name
  args_final = [f'"{arg}"' if ' ' in arg else arg for arg in args_final]
@@ -457,16 +685,29 @@ def trainlaunch(**kwargs):

  #### locate train.py
  import deepliif
- path_train_py = deepliif.__path__[0]+'/train.py'
+ path_train_py = deepliif.__path__[0]+'/scripts/train.py'
+
+ #### find out GPUs to use
+ gpu_ids = [args_final[i+1] for i,v in enumerate(args_final) if v=='--gpu-ids']
+ if len(gpu_ids) > 0 and gpu_ids[0] == '-1':
+ gpu_ids = []
+
+ if len(gpu_ids) > 0:
+ opt_env = f"CUDA_VISIBLE_DEVICES=\"{','.join(gpu_ids)}\""
+ else:
+ opt_env = ''

  #### execute train.py
  if kwargs['use_torchrun']:
  if version.parse(torch.__version__) >= version.parse('1.10.0'):
- subprocess.run(f'torchrun {kwargs["use_torchrun"]} {path_train_py} {options}',shell=True)
+ cmd = f'{opt_env} torchrun {kwargs["use_torchrun"]} {path_train_py} {options}'
  else:
- subprocess.run(f'python -m torch.distributed.launch {kwargs["use_torchrun"]} {path_train_py} {options}',shell=True)
+ cmd = f'{opt_env} python -m torch.distributed.launch {kwargs["use_torchrun"]} {path_train_py} {options}'
  else:
- subprocess.run(f'python {path_train_py} {options}',shell=True)
+ cmd = f'{opt_env} python {path_train_py} {options}'
+
+ print('Executing command:',cmd)
+ subprocess.run(cmd,shell=True)


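Note on the trainlaunch change above: GPU ids passed via --gpu-ids are now turned into a CUDA_VISIBLE_DEVICES prefix for the spawned training process. An illustrative reconstruction of the assembled command (paths and options are example values, not actual CLI output):

    # example values; in the CLI these come from the parsed arguments
    gpu_ids = ['0', '1']  # collected from repeated --gpu-ids flags, as strings
    opt_env = f"CUDA_VISIBLE_DEVICES=\"{','.join(gpu_ids)}\"" if gpu_ids else ''
    options = '--dataroot ./data --name exp1'
    path_train_py = '/path/to/deepliif/scripts/train.py'
    cmd = f'{opt_env} torchrun --nproc_per_node 2 {path_train_py} {options}'
    print(cmd)
    # CUDA_VISIBLE_DEVICES="0,1" torchrun --nproc_per_node 2 /path/to/deepliif/scripts/train.py --dataroot ./data --name exp1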
@@ -475,9 +716,10 @@ def trainlaunch(**kwargs):
  @click.option('--model-dir', default='./model-server/DeepLIIF_Latest_Model', help='reads models from here')
  @click.option('--output-dir', help='saves results here.')
  #@click.option('--tile-size', type=int, default=None, help='tile size')
- @click.option('--device', default='cpu', type=str, help='device to load model for the similarity test, either cpu or gpu')
+ @click.option('--device', default='cpu', type=str, help='device used to run serialization and to load the model for the similarity test, either cpu or gpu')
+ @click.option('--epoch', default='latest', type=str, help='epoch to load and serialize')
  @click.option('--verbose', default=0, type=int,help='saves results here.')
- def serialize(model_dir, output_dir, device, verbose):
+ def serialize(model_dir, output_dir, device, epoch, verbose):
  """Serialize DeepLIIF models using Torchscript
  """
  #if tile_size is None:
@@ -490,12 +732,20 @@ def serialize(model_dir, output_dir, device, verbose):
  if model_dir != output_dir:
  shutil.copy(f'{model_dir}/train_opt.txt',f'{output_dir}/train_opt.txt')

+ # load and update opt for serialization
  opt = Options(path_file=os.path.join(model_dir,'train_opt.txt'), mode='test')
+ opt.epoch = epoch
+ if device == 'gpu':
+ opt.gpu_ids = [0] # use gpu 0, in case training was done on larger machines
+ else:
+ opt.gpu_ids = [] # use cpu
+
+ print_options(opt)
  sample = transform(Image.new('RGB', (opt.scale_size, opt.scale_size)))
  sample = torch.cat([sample]*opt.input_no, 1)

  with click.progressbar(
- init_nets(model_dir, eager_mode=True, phase='test').items(),
+ init_nets(model_dir, eager_mode=True, opt=opt, phase='test').items(),
  label='Tracing nets',
  item_show_func=lambda n: n[0] if n else n
  ) as bar:
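Note on serialization: serialize traces each network with a dummy input at the training resolution. A minimal standalone sketch of the TorchScript tracing pattern used here (generic torch.jit.trace, not DeepLIIF's exact loop):

    import torch
    import torch.nn as nn

    # stand-in for one of the loaded generator networks
    net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.ReLU()).eval()

    # dummy input at the expected resolution, e.g. a 512x512 RGB tile
    sample = torch.zeros(1, 3, 512, 512)

    traced = torch.jit.trace(net, sample)  # record the ops into a TorchScript graph
    traced.save('net_traced.pt')           # self-contained artifact; no class definition needed to load
    reloaded = torch.jit.load('net_traced.pt')
    assert torch.allclose(net(sample), reloaded(sample))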
@@ -514,8 +764,9 @@ def serialize(model_dir, output_dir, device, verbose):

  # test: whether the original and the serialized models produce highly similar predictions
  print('testing similarity between prediction from original vs serialized models...')
- models_original = init_nets(model_dir,eager_mode=True,phase='test')
- models_serialized = init_nets(output_dir,eager_mode=False,phase='test')
+ models_original = init_nets(model_dir,eager_mode=True,opt=opt,phase='test')
+ models_serialized = init_nets(output_dir,eager_mode=False,opt=opt,phase='test')
+
  if device == 'gpu':
  sample = sample.cuda()
  else:
@@ -523,7 +774,7 @@ def serialize(model_dir, output_dir, device, verbose):
  for name in models_serialized.keys():
  print(name,':')
  model_original = models_original[name].cuda().eval() if device=='gpu' else models_original[name].cpu().eval()
- model_serialized = models_serialized[name].cuda() if device=='gpu' else models_serialized[name].cpu().eval()
+ model_serialized = models_serialized[name].cuda().eval() if device=='gpu' else models_serialized[name].cpu().eval()
  if name.startswith('GS'):
  test_diff_original_serialized(model_original,model_serialized,torch.cat([sample, sample, sample], 1),verbose)
  else:
@@ -534,28 +785,44 @@ def serialize(model_dir, output_dir, device, verbose):
  @cli.command()
  @click.option('--input-dir', default='./Sample_Large_Tissues/', help='reads images from here')
  @click.option('--output-dir', help='saves results here.')
- @click.option('--tile-size', default=None, help='tile size')
+ @click.option('--tile-size', type=click.IntRange(min=1, max=None), required=True, help='tile size')
  @click.option('--model-dir', default='./model-server/DeepLIIF_Latest_Model/', help='load models from here.')
+ @click.option('--filename-pattern', default='*', help='run inference only on files whose names match the pattern.')
  @click.option('--gpu-ids', type=int, multiple=True, help='gpu-ids 0 gpu-ids 1 or gpu-ids -1 for CPU')
- @click.option('--region-size', default=20000, help='Due to limits in the resources, the whole slide image cannot be processed in whole.'
- 'So the WSI image is read region by region. '
- 'This parameter specifies the size each region to be read into GPU for inferrence.')
  @click.option('--eager-mode', is_flag=True, help='use eager mode (loading original models, otherwise serialized ones)')
+ @click.option('--epoch', default='latest',
+ help='for eager mode, which epoch to load? set to latest to use latest cached model')
+ @click.option('--seg-intermediate', is_flag=True, help='also save intermediate segmentation images (currently only applies to DeepLIIF model)')
+ @click.option('--seg-only', is_flag=True, help='save only the final segmentation image (currently only applies to DeepLIIF model); overrides --seg-intermediate')
  @click.option('--color-dapi', is_flag=True, help='color dapi image to produce the same coloring as in the paper')
  @click.option('--color-marker', is_flag=True, help='color marker image to produce the same coloring as in the paper')
- def test(input_dir, output_dir, tile_size, model_dir, gpu_ids, region_size, eager_mode,
- color_dapi, color_marker):
+ @click.option('--BtoA', is_flag=True, help='for models trained with unaligned dataset, this flag instructs to load generatorB instead of generatorA')
+ def test(input_dir, output_dir, tile_size, model_dir, filename_pattern, gpu_ids, eager_mode, epoch,
+ seg_intermediate, seg_only, color_dapi, color_marker, btoa):

  """Test trained models
  """
  output_dir = output_dir or input_dir
  ensure_exists(output_dir)
+
+ if seg_intermediate and seg_only:
+ seg_intermediate = False

- image_files = [fn for fn in os.listdir(input_dir) if allowed_file(fn)]
+ if filename_pattern == '*':
+ print('use all allowed files')
+ image_files = [fn for fn in os.listdir(input_dir) if allowed_file(fn)]
+ else:
+ import glob
+ print('match files using filename pattern',filename_pattern)
+ image_files = [os.path.basename(f) for f in glob.glob(os.path.join(input_dir, filename_pattern))]
+ print(len(image_files),'image files')
+
  files = os.listdir(model_dir)
  assert 'train_opt.txt' in files, f'file train_opt.txt is missing from model directory {model_dir}'
  opt = Options(path_file=os.path.join(model_dir,'train_opt.txt'), mode='test')
  opt.use_dp = False
+ opt.BtoA = btoa
+ opt.epoch = epoch

  number_of_gpus_all = torch.cuda.device_count()
  if number_of_gpus_all < len(gpu_ids) and -1 not in gpu_ids:
@@ -582,26 +849,41 @@ def test(input_dir, output_dir, tile_size, model_dir, gpu_ids, region_size, eage
  item_show_func=lambda fn: fn
  ) as bar:
  for filename in bar:
- if '.svs' in filename:
- start_time = time.time()
- infer_results_for_wsi(input_dir, filename, output_dir, model_dir, tile_size, region_size)
- print(time.time() - start_time)
- else:
- img = Image.open(os.path.join(input_dir, filename)).convert('RGB')
- images, scoring = infer_modalities(img, tile_size, model_dir, eager_mode, color_dapi, color_marker, opt)
+ img = Image.open(os.path.join(input_dir, filename)).convert('RGB')
+ images, scoring = infer_modalities(img, tile_size, model_dir, eager_mode, color_dapi, color_marker, opt, return_seg_intermediate=seg_intermediate, seg_only=seg_only)
+
+ for name, i in images.items():
+ i.save(os.path.join(
+ output_dir,
+ filename.replace('.' + filename.split('.')[-1], f'_{name}.png')
+ ))

- for name, i in images.items():
- i.save(os.path.join(
+ if scoring is not None:
+ with open(os.path.join(
  output_dir,
- filename.replace('.' + filename.split('.')[-1], f'_{name}.png')
- ))
+ filename.replace('.' + filename.split('.')[-1], '.json')
+ ), 'w') as f:
+ json.dump(scoring, f, indent=2)
+
+
+ @cli.command()
+ @click.option('--input-dir', required=True, help='directory containing WSI file')
+ @click.option('--filename', required=True, help='name of WSI to read')
+ @click.option('--output-dir', required=True, help='saves results here.')
+ @click.option('--tile-size', type=click.IntRange(min=1, max=None), required=True, help='tile size')
+ @click.option('--model-dir', default='./model-server/DeepLIIF_Latest_Model/', help='load models from here.')
+ @click.option('--region-size', default=20000, help='Because of resource limits, the whole slide image cannot be processed at once. '
+ 'So the WSI is read region by region. '
+ 'This parameter specifies the size of each region read onto the GPU for inference.')
+ @click.option('--seg-intermediate', is_flag=True, help='also save intermediate segmentation images (currently only applies to DeepLIIF model)')
+ @click.option('--seg-only', is_flag=True, help='save only the final segmentation image (currently only applies to DeepLIIF model)')
+ @click.option('--color-dapi', is_flag=True, help='color dapi image to produce the same coloring as in the paper')
+ @click.option('--color-marker', is_flag=True, help='color marker image to produce the same coloring as in the paper')
+ def test_wsi(input_dir, filename, output_dir, tile_size, model_dir, region_size, seg_intermediate, seg_only, color_dapi, color_marker):
+ infer_results_for_wsi(input_dir, filename, output_dir, model_dir, tile_size, region_size,
+ color_dapi=color_dapi, color_marker=color_marker,
+ seg_intermediate=seg_intermediate, seg_only=seg_only)

- if scoring is not None:
- with open(os.path.join(
- output_dir,
- filename.replace('.' + filename.split('.')[-1], f'.json')
- ), 'w') as f:
- json.dump(scoring, f, indent=2)

  @cli.command()
  @click.option('--input-dir', type=str, required=True, help='Path to input images')
@@ -721,4 +1003,9 @@ def visualize(pickle_dir, display_env):


  if __name__ == '__main__':
+ # tensor float 32 is available on nvidia ampere cards (e.g., a100, a40) and provides better performance at the cost of slightly lower precision
+ # in 1.7-1.11, pytorch by default enables tf32 when possible
+ # currently convolutions still use tf32 by default while matmul does not and needs to be enabled manually
+ # see this issue for a discussion: https://github.com/pytorch/pytorch/issues/67384
+ torch.backends.cuda.matmul.allow_tf32 = True
  cli()
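Note on the TF32 change above: PyTorch exposes two separate switches, one for matmuls and one for cuDNN convolutions. An illustrative snippet showing both flags (the cuDNN flag defaults to True on recent versions):

    import torch

    torch.backends.cuda.matmul.allow_tf32 = True  # allow TF32 for matmuls on Ampere+
    torch.backends.cudnn.allow_tf32 = True        # allow TF32 for cuDNN convolutions
    print(torch.backends.cuda.matmul.allow_tf32, torch.backends.cudnn.allow_tf32)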