deepliif 1.1.9__py3-none-any.whl → 1.1.11__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
cli.py CHANGED
@@ -14,7 +14,7 @@ from deepliif.models import init_nets, infer_modalities, infer_results_for_wsi,
14
14
  from deepliif.util import allowed_file, Visualizer, get_information, test_diff_original_serialized, disable_batchnorm_tracking_stats
15
15
  from deepliif.util.util import mkdirs, check_multi_scale
16
16
  # from deepliif.util import infer_results_for_wsi
17
- from deepliif.options import Options
17
+ from deepliif.options import Options, print_options
18
18
 
19
19
  import torch.distributed as dist
20
20
 
@@ -59,29 +59,6 @@ def set_seed(seed=0,rank=None):
59
59
  def ensure_exists(d):
60
60
  if not os.path.exists(d):
61
61
  os.makedirs(d)
62
-
63
- def print_options(opt):
64
- """Print and save options
65
-
66
- It will print both current options and default values(if different).
67
- It will save options into a text file / [checkpoints_dir] / opt.txt
68
- """
69
- message = ''
70
- message += '----------------- Options ---------------\n'
71
- for k, v in sorted(vars(opt).items()):
72
- comment = ''
73
- message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
74
- message += '----------------- End -------------------'
75
- print(message)
76
-
77
- # save to the disk
78
- if opt.phase == 'train':
79
- expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
80
- mkdirs(expr_dir)
81
- file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase))
82
- with open(file_name, 'wt') as opt_file:
83
- opt_file.write(message)
84
- opt_file.write('\n')
85
62
 
86
63
 
87
64
  @click.group()
@@ -212,6 +189,18 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd
212
189
  plot, and save models.The script supports continue/resume training.
213
190
  Use '--continue_train' to resume your previous training.
214
191
  """
192
+ assert model in ['DeepLIIF','DeepLIIFExt','SDG'], f'model class {model} is not implemented'
193
+ if model == 'DeepLIIF':
194
+ seg_no = 1
195
+ elif model == 'DeepLIIFExt':
196
+ if seg_gen:
197
+ seg_no = modalities_no
198
+ else:
199
+ seg_no = 0
200
+ else: # SDG
201
+ seg_no = 0
202
+ seg_gen = False
203
+
215
204
  d_params = locals()
216
205
 
217
206
  if gpu_ids and gpu_ids[0] == -1:
@@ -241,11 +230,26 @@ def train(dataroot, name, gpu_ids, checkpoints_dir, input_nc, output_nc, ngf, nd
241
230
  d_params['padding'] = 'zero'
242
231
  print('padding type is forced to zero padding, because neither refection pad2d or replication pad2d has a deterministic implementation')
243
232
 
233
+ # infer number of input images
234
+ dir_data_train = dataroot + '/train'
235
+ fns = os.listdir(dir_data_train)
236
+ fns = [x for x in fns if x.endswith('.png')]
237
+ img = Image.open(f"{dir_data_train}/{fns[0]}")
238
+
239
+ num_img = img.size[0] / img.size[1]
240
+ assert int(num_img) == num_img, f'img size {img.size[0]} / {img.size[1]} = {num_img} is not an integer'
241
+ num_img = int(num_img)
242
+
243
+ input_no = num_img - modalities_no - seg_no
244
+ assert input_no > 0, f'inferred number of input images is {input_no}; should be greater than 0'
245
+ d_params['input_no'] = input_no
246
+ d_params['scale_size'] = img.size[1]
247
+
244
248
  # create a dataset given dataset_mode and other options
245
249
  # dataset = AlignedDataset(opt)
246
250
 
247
251
  opt = Options(d_params=d_params)
248
- print_options(opt)
252
+ print_options(opt, save=True)
249
253
 
250
254
  dataset = create_dataset(opt)
251
255
  # get the number of images in the dataset.
@@ -468,28 +472,30 @@ def trainlaunch(**kwargs):
468
472
 
469
473
 
470
474
  @cli.command()
471
- @click.option('--models-dir', default='./model-server/DeepLIIF_Latest_Model', help='reads models from here')
475
+ @click.option('--model-dir', default='./model-server/DeepLIIF_Latest_Model', help='reads models from here')
472
476
  @click.option('--output-dir', help='saves results here.')
473
- @click.option('--tile-size', type=int, default=None, help='tile size')
474
- @click.option('--device', default='cpu', type=str, help='device to load model, either cpu or gpu')
477
+ #@click.option('--tile-size', type=int, default=None, help='tile size')
478
+ @click.option('--device', default='cpu', type=str, help='device to load model for the similarity test, either cpu or gpu')
475
479
  @click.option('--verbose', default=0, type=int,help='saves results here.')
476
- def serialize(models_dir, output_dir, tile_size, device, verbose):
480
+ def serialize(model_dir, output_dir, device, verbose):
477
481
  """Serialize DeepLIIF models using Torchscript
478
482
  """
479
- if tile_size is None:
480
- tile_size = 512
481
- output_dir = output_dir or models_dir
483
+ #if tile_size is None:
484
+ # tile_size = 512
485
+ output_dir = output_dir or model_dir
482
486
  ensure_exists(output_dir)
483
487
 
484
488
  # copy train_opt.txt to the target location
485
489
  import shutil
486
- if models_dir != output_dir:
487
- shutil.copy(f'{models_dir}/train_opt.txt',f'{output_dir}/train_opt.txt')
490
+ if model_dir != output_dir:
491
+ shutil.copy(f'{model_dir}/train_opt.txt',f'{output_dir}/train_opt.txt')
488
492
 
489
- sample = transform(Image.new('RGB', (tile_size, tile_size)))
493
+ opt = Options(path_file=os.path.join(model_dir,'train_opt.txt'), mode='test')
494
+ sample = transform(Image.new('RGB', (opt.scale_size, opt.scale_size)))
495
+ sample = torch.cat([sample]*opt.input_no, 1)
490
496
 
491
497
  with click.progressbar(
492
- init_nets(models_dir, eager_mode=True, phase='test').items(),
498
+ init_nets(model_dir, eager_mode=True, phase='test').items(),
493
499
  label='Tracing nets',
494
500
  item_show_func=lambda n: n[0] if n else n
495
501
  ) as bar:
@@ -508,7 +514,7 @@ def serialize(models_dir, output_dir, tile_size, device, verbose):
508
514
 
509
515
  # test: whether the original and the serialized model produces highly similar predictions
510
516
  print('testing similarity between prediction from original vs serialized models...')
511
- models_original = init_nets(models_dir,eager_mode=True,phase='test')
517
+ models_original = init_nets(model_dir,eager_mode=True,phase='test')
512
518
  models_serialized = init_nets(output_dir,eager_mode=False,phase='test')
513
519
  if device == 'gpu':
514
520
  sample = sample.cuda()
@@ -590,11 +596,12 @@ def test(input_dir, output_dir, tile_size, model_dir, gpu_ids, region_size, eage
590
596
  filename.replace('.' + filename.split('.')[-1], f'_{name}.png')
591
597
  ))
592
598
 
593
- with open(os.path.join(
594
- output_dir,
595
- filename.replace('.' + filename.split('.')[-1], f'.json')
596
- ), 'w') as f:
597
- json.dump(scoring, f, indent=2)
599
+ if scoring is not None:
600
+ with open(os.path.join(
601
+ output_dir,
602
+ filename.replace('.' + filename.split('.')[-1], f'.json')
603
+ ), 'w') as f:
604
+ json.dump(scoring, f, indent=2)
598
605
 
599
606
  @cli.command()
600
607
  @click.option('--input-dir', type=str, required=True, help='Path to input images')
@@ -26,6 +26,8 @@ class AlignedDataset(BaseDataset):
26
26
  self.output_nc = opt.input_nc if opt.direction == 'BtoA' else opt.output_nc
27
27
  self.no_flip = opt.no_flip
28
28
  self.modalities_no = opt.modalities_no
29
+ self.seg_no = opt.seg_no
30
+ self.input_no = opt.input_no
29
31
  self.seg_gen = opt.seg_gen
30
32
  self.load_size = opt.load_size
31
33
  self.crop_size = opt.crop_size
@@ -52,6 +54,8 @@ class AlignedDataset(BaseDataset):
52
54
  num_img = self.modalities_no + 1 + 1 # +1 for segmentation channel, +1 for input image
53
55
  elif self.model == 'DeepLIIFExt':
54
56
  num_img = self.modalities_no * 2 + 1 if self.seg_gen else self.modalities_no + 1 # +1 for segmentation channel
57
+ elif self.model == 'SDG':
58
+ num_img = self.modalities_no + self.seg_no + self.input_no
55
59
  else:
56
60
  raise Exception(f'model class {self.model} does not have corresponding implementation in deepliif/data/aligned_dataset.py')
57
61
  w2 = int(w / num_img)
@@ -85,6 +89,19 @@ class AlignedDataset(BaseDataset):
85
89
  BS_Array.append(BS)
86
90
 
87
91
  return {'A': A, 'B': B_Array, 'BS': BS_Array,'A_paths': AB_path, 'B_paths': AB_path}
92
+ elif self.model == 'SDG':
93
+ A_Array = []
94
+ for i in range(self.input_no):
95
+ A = AB.crop((w2 * i, 0, w2 * (i+1), h))
96
+ A = A_transform(A)
97
+ A_Array.append(A)
98
+
99
+ for i in range(self.input_no, self.input_no + self.modalities_no + 1):
100
+ B = AB.crop((w2 * i, 0, w2 * (i + 1), h))
101
+ B = B_transform(B)
102
+ B_Array.append(B)
103
+
104
+ return {'A': A_Array, 'B': B_Array, 'A_paths': AB_path, 'B_paths': AB_path}
88
105
  else:
89
106
  raise Exception(f'model class {self.model} does not have corresponding implementation in deepliif/data/aligned_dataset.py')
90
107
 
@@ -0,0 +1,189 @@
1
+ import torch
2
+ from .base_model import BaseModel
3
+ from . import networks
4
+
5
+
6
+ class SDGModel(BaseModel):
7
+ """ This class implements the Synthetic Data Generation model (based on DeepLIIFExt), for learning a mapping from input images to modalities given paired data."""
8
+
9
+ def __init__(self, opt):
10
+ """Initialize the DeepLIIF class.
11
+
12
+ Parameters:
13
+ opt (Option class)-- stores all the experiment flags; needs to be a subclass of BaseOptions
14
+ """
15
+ BaseModel.__init__(self, opt)
16
+
17
+ self.mod_gen_no = self.opt.modalities_no
18
+
19
+ # weights of the modalities in generating segmentation mask
20
+ self.seg_weights = [0, 0, 0]
21
+ if opt.seg_gen:
22
+ self.seg_weights = [0.3] * self.mod_gen_no
23
+ self.seg_weights[1] = 0.4
24
+
25
+ # self.seg_weights = opt.seg_weights
26
+ # assert len(self.seg_weights) == self.seg_gen_no, 'The number of the segmentation weights (seg_weights) is not equal to the number of target images (modalities_no)!'
27
+ # print(self.seg_weights)
28
+ # loss weights in calculating the final loss
29
+ self.loss_G_weights = [1 / self.mod_gen_no] * self.mod_gen_no
30
+ self.loss_GS_weights = [1 / self.mod_gen_no] * self.mod_gen_no
31
+
32
+ self.loss_D_weights = [1 / self.mod_gen_no] * self.mod_gen_no
33
+ self.loss_DS_weights = [1 / self.mod_gen_no] * self.mod_gen_no
34
+
35
+ self.loss_names = []
36
+ self.visual_names = ['real_A']
37
+ # specify the training losses you want to print out. The training/test scripts will call <BaseModel.get_current_losses>
38
+ for i in range(1, self.mod_gen_no + 1):
39
+ self.loss_names.extend(['G_GAN_' + str(i), 'G_L1_' + str(i), 'D_real_' + str(i), 'D_fake_' + str(i)])
40
+ self.visual_names.extend(['fake_B_' + str(i), 'real_B_' + str(i)])
41
+
42
+ # specify the images you want to save/display. The training/test scripts will call <BaseModel.get_current_visuals>
43
+ # specify the models you want to save to the disk. The training/test scripts will call <BaseModel.save_networks> and <BaseModel.load_networks>
44
+ if self.is_train:
45
+ self.model_names = []
46
+ for i in range(1, self.mod_gen_no + 1):
47
+ self.model_names.extend(['G_' + str(i), 'D_' + str(i)])
48
+
49
+ else: # during test time, only load G
50
+ self.model_names = []
51
+ for i in range(1, self.mod_gen_no + 1):
52
+ self.model_names.extend(['G_' + str(i)])
53
+
54
+ # define networks (both generator and discriminator)
55
+ self.netG = [None for _ in range(self.mod_gen_no)]
56
+ for i in range(self.mod_gen_no):
57
+ self.netG[i] = networks.define_G(self.opt.input_nc * self.opt.input_no, self.opt.output_nc, self.opt.ngf, self.opt.net_g, self.opt.norm,
58
+ not self.opt.no_dropout, self.opt.init_type, self.opt.init_gain, self.opt.gpu_ids, self.opt.padding)
59
+ print('***************************************')
60
+ print(self.opt.input_nc, self.opt.output_nc, self.opt.ngf, self.opt.net_g, self.opt.norm,
61
+ not self.opt.no_dropout, self.opt.init_type, self.opt.init_gain, self.opt.gpu_ids, self.opt.padding)
62
+ print('***************************************')
63
+
64
+ if self.is_train: # define a discriminator; conditional GANs need to take both input and output images; Therefore, #channels for D is input_nc + output_nc
65
+ self.netD = [None for _ in range(self.mod_gen_no)]
66
+ for i in range(self.mod_gen_no):
67
+ self.netD[i] = networks.define_D(self.opt.input_nc * self.opt.input_no + self.opt.output_nc, self.opt.ndf, self.opt.net_d,
68
+ self.opt.n_layers_D, self.opt.norm, self.opt.init_type, self.opt.init_gain,
69
+ self.opt.gpu_ids)
70
+
71
+ if self.is_train:
72
+ # define loss functions
73
+ self.criterionGAN_mod = networks.GANLoss(self.opt.gan_mode).to(self.device)
74
+ self.criterionGAN_seg = networks.GANLoss(self.opt.gan_mode_s).to(self.device)
75
+
76
+ self.criterionSmoothL1 = torch.nn.SmoothL1Loss()
77
+
78
+ self.criterionVGG = networks.VGGLoss().to(self.device)
79
+
80
+ # initialize optimizers; schedulers will be automatically created by function <BaseModel.setup>.
81
+ params = []
82
+ for i in range(len(self.netG)):
83
+ params += list(self.netG[i].parameters())
84
+ self.optimizer_G = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
85
+
86
+ params = []
87
+ for i in range(len(self.netD)):
88
+ params += list(self.netD[i].parameters())
89
+ self.optimizer_D = torch.optim.Adam(params, lr=opt.lr, betas=(opt.beta1, 0.999))
90
+
91
+ self.optimizers.append(self.optimizer_G)
92
+ self.optimizers.append(self.optimizer_D)
93
+
94
+ def set_input(self, input):
95
+ """
96
+ Unpack input data from the dataloader and perform necessary pre-processing steps.
97
+
98
+ :param input (dict): include the input image and the output modalities
99
+ """
100
+ self.real_A_array = input['A']
101
+ As = [A.to(self.device) for A in self.real_A_array]
102
+ self.real_A = torch.cat(As, dim=1) # shape: 1, (3 x input_no), 512, 512
103
+
104
+ self.real_B_array = input['B']
105
+ self.real_B = []
106
+ for i in range(len(self.real_B_array)):
107
+ self.real_B.append(self.real_B_array[i].to(self.device))
108
+
109
+ self.real_concatenated = []
110
+ self.image_paths = input['A_paths']
111
+
112
+ def forward(self):
113
+ """Run forward pass; called by both functions <optimize_parameters> and <test>."""
114
+ self.fake_B = []
115
+ for i in range(self.mod_gen_no):
116
+ self.fake_B.append(self.netG[i](self.real_A))
117
+
118
+
119
+ def backward_D(self):
120
+ """Calculate GAN loss for the discriminators"""
121
+
122
+ pred_fake = []
123
+ for i in range(self.mod_gen_no):
124
+ pred_fake.append(self.netD[i](torch.cat((self.real_A, self.fake_B[i]), 1).detach()))
125
+
126
+ self.loss_D_fake = []
127
+ for i in range(self.mod_gen_no):
128
+ self.loss_D_fake.append(self.criterionGAN_mod(pred_fake[i], False))
129
+
130
+ pred_real = []
131
+ for i in range(self.mod_gen_no):
132
+ pred_real.append(self.netD[i](torch.cat((self.real_A, self.real_B[i]), 1)))
133
+
134
+ self.loss_D_real = []
135
+ for i in range(self.mod_gen_no):
136
+ self.loss_D_real.append(self.criterionGAN_mod(pred_real[i], True))
137
+
138
+ # combine losses and calculate gradients
139
+ # self.loss_D = (self.loss_D_fake[0] + self.loss_D_real[0]) * 0.5 * self.loss_D_weights[0]
140
+ self.loss_D = torch.tensor(0., device=self.device)
141
+ for i in range(0, self.mod_gen_no):
142
+ self.loss_D += (self.loss_D_fake[i] + self.loss_D_real[i]) * 0.5 * self.loss_D_weights[i]
143
+ self.loss_D.backward()
144
+
145
+ def backward_G(self):
146
+ """Calculate GAN and L1 loss for the generator"""
147
+ pred_fake = []
148
+ for i in range(self.mod_gen_no):
149
+ pred_fake.append(self.netD[i](torch.cat((self.real_A, self.fake_B[i]), 1)))
150
+
151
+ self.loss_G_GAN = []
152
+ self.loss_GS_GAN = []
153
+ for i in range(self.mod_gen_no):
154
+ self.loss_G_GAN.append(self.criterionGAN_mod(pred_fake[i], True))
155
+
156
+ # Second, G(A) = B
157
+ self.loss_G_L1 = []
158
+ self.loss_GS_L1 = []
159
+ for i in range(self.mod_gen_no):
160
+ self.loss_G_L1.append(self.criterionSmoothL1(self.fake_B[i], self.real_B[i]) * self.opt.lambda_L1)
161
+
162
+ #self.loss_G_VGG = []
163
+ #for i in range(self.mod_gen_no):
164
+ # self.loss_G_VGG.append(self.criterionVGG(self.fake_B[i], self.real_B[i]) * self.opt.lambda_feat)
165
+
166
+ # self.loss_G = (self.loss_G_GAN[0] + self.loss_G_L1[0]) * self.loss_G_weights[0]
167
+ self.loss_G = torch.tensor(0., device=self.device)
168
+ for i in range(0, self.mod_gen_no):
169
+ self.loss_G += (self.loss_G_GAN[i] + self.loss_G_L1[i]) * self.loss_G_weights[i]
170
+ # self.loss_G += (self.loss_G_GAN[i] + self.loss_G_L1[i] + self.loss_G_VGG[i]) * self.loss_G_weights[i]
171
+ self.loss_G.backward()
172
+
173
+ def optimize_parameters(self):
174
+ self.forward() # compute fake images: G(A)
175
+ # update D
176
+ for i in range(self.mod_gen_no):
177
+ self.set_requires_grad(self.netD[i], True) # enable backprop for D1
178
+
179
+ self.optimizer_D.zero_grad() # set D's gradients to zero
180
+ self.backward_D() # calculate gradients for D
181
+ self.optimizer_D.step() # update D's weights
182
+
183
+ # update G
184
+ for i in range(self.mod_gen_no):
185
+ self.set_requires_grad(self.netD[i], False)
186
+
187
+ self.optimizer_G.zero_grad() # set G's gradients to zero
188
+ self.backward_G() # calculate graidents for G
189
+ self.optimizer_G.step() # udpate G's weights