deepliif 1.1.7-py3-none-any.whl → 1.1.8-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
deepliif/models/__init__.py:
@@ -31,15 +31,32 @@ import numpy as np
  from dask import delayed, compute
 
  from deepliif.util import *
- from deepliif.util.util import tensor_to_pil
+ from deepliif.util.util import tensor_to_pil, check_multi_scale
  from deepliif.data import transform
- from deepliif.postprocessing import adjust_marker, adjust_dapi, compute_IHC_scoring, \
-     overlay_final_segmentation_mask, create_final_segmentation_mask_with_boundaries, create_basic_segmentation_mask
+ from deepliif.postprocessing import compute_results
+ from deepliif.options import Options, print_options
 
  from .base_model import BaseModel
+
+ # import for init purpose, not used in this script
  from .DeepLIIF_model import DeepLIIFModel
- from .networks import get_norm_layer, ResnetGenerator, UnetGenerator
+ from .DeepLIIFExt_model import DeepLIIFExtModel
+
 
+ @lru_cache
+ def get_opt(model_dir, mode='test'):
+     """
+     mode: test or train, currently only functions used for inference utilize get_opt so it
+     defaults to test
+     """
+     if mode == 'train':
+         opt = Options(path_file=os.path.join(model_dir,'train_opt.txt'), mode=mode)
+     elif mode == 'test':
+         try:
+             opt = Options(path_file=os.path.join(model_dir,'test_opt.txt'), mode=mode)
+         except:
+             opt = Options(path_file=os.path.join(model_dir,'train_opt.txt'), mode=mode)
+     return opt
 
 
  def find_model_using_name(model_name):
      """Import the module "models/[model_name]_model.py".
@@ -94,87 +111,75 @@ def load_torchscript_model(model_pt_path, device):
      return net
 
 
- def read_model_params(file_addr):
-     with open(file_addr) as f:
-         lines = f.readlines()
-     param_dict = {}
-     for line in lines:
-         if ':' in line:
-             key = line.split(':')[0].strip()
-             val = line.split(':')[1].split('[')[0].strip()
-             param_dict[key] = val
-     print(param_dict)
-     return param_dict
-
-
- def load_eager_models(model_dir, devices):
-     input_nc = 3
-     output_nc = 3
-     ngf = 64
-     norm = 'batch'
-     use_dropout = True
-     padding_type = 'zero'
-
-     files = os.listdir(model_dir)
-     for f in files:
-         if 'train_opt.txt' in f:
-             param_dict = read_model_params(os.path.join(model_dir, f))
-             input_nc = int(param_dict['input_nc'])
-             output_nc = int(param_dict['output_nc'])
-             ngf = int(param_dict['ngf'])
-             norm = param_dict['norm']
-             use_dropout = False if param_dict['no_dropout'] == 'True' else True
-             padding_type = param_dict['padding']
-
-     norm_layer = get_norm_layer(norm_type=norm)
 
-     nets = {}
-     for n in ['G1', 'G2', 'G3', 'G4']:
-         net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9, padding_type=padding_type)
-         net.load_state_dict(torch.load(
-             os.path.join(model_dir, f'latest_net_{n}.pth'),
-             map_location=devices[n]
-         ))
-         nets[n] = disable_batchnorm_tracking_stats(net)
-         nets[n].eval()
-
-     for n in ['G51', 'G52', 'G53', 'G54', 'G55']:
-         net = UnetGenerator(input_nc, output_nc, 9, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
-         net.load_state_dict(torch.load(
-             os.path.join(model_dir, f'latest_net_{n}.pth'),
-             map_location=devices[n]
-         ))
-         nets[n] = disable_batchnorm_tracking_stats(net)
-         nets[n].eval()
+ def load_eager_models(opt, devices):
+     # create a model given model and other options
+     model = create_model(opt)
+     # regular setup: load and print networks; create schedulers
+     model.setup(opt)
 
+     nets = {}
+     for name in model.model_names:
+         if isinstance(name, str):
+             if '_' in name:
+                 net = getattr(model, 'net' + name.split('_')[0])[int(name.split('_')[-1]) - 1]
+             else:
+                 net = getattr(model, 'net' + name)
+
+             if opt.phase != 'train':
+                 net.eval()
+                 net = disable_batchnorm_tracking_stats(net)
+
+             nets[name] = net
+             nets[name].to(devices[name])
+
      return nets
 
 
  @lru_cache
- def init_nets(model_dir, eager_mode=False):
+ def init_nets(model_dir, eager_mode=False, opt=None, phase='test'):
      """
      Init DeepLIIF networks so that every net in
      the same group is deployed on the same GPU
-     """
-     net_groups = [
-         ('G1', 'G52'),
-         ('G2', 'G53'),
-         ('G3', 'G54'),
-         ('G4', 'G55'),
-         ('G51',)
-     ]
-
-     number_of_gpus = torch.cuda.device_count()
-     # number_of_gpus = 0
-     if number_of_gpus:
+
+     opt_args: to overwrite opt arguments in train_opt.txt, typically used in inference stage
+     for example, opt_args={'phase':'test'}
+     """
+     if opt is None:
+         opt = get_opt(model_dir, mode=phase)
+         opt.use_dp = False
+         print_options(opt)
+
+     if opt.model == 'DeepLIIF':
+         net_groups = [
+             ('G1', 'G52'),
+             ('G2', 'G53'),
+             ('G3', 'G54'),
+             ('G4', 'G55'),
+             ('G51',)
+         ]
+     elif opt.model == 'DeepLIIFExt':
+         if opt.seg_gen:
+             net_groups = [(f'G_{i+1}', f'GS_{i+1}') for i in range(opt.modalities_no)]
+         else:
+             net_groups = [(f'G_{i+1}',) for i in range(opt.modalities_no)]
+     else:
+         raise Exception(f'init_nets() not implemented for model {opt.model}')
+
+     number_of_gpus_all = torch.cuda.device_count()
+     number_of_gpus = len(opt.gpu_ids)
+     print(number_of_gpus)
+     if number_of_gpus > 0:
+         mapping_gpu_ids = {i: idx for i, idx in enumerate(opt.gpu_ids)}
          chunks = [itertools.chain.from_iterable(c) for c in chunker(net_groups, number_of_gpus)]
          # chunks = chunks[1:]
-         devices = {n: torch.device(f'cuda:{i}') for i, g in enumerate(chunks) for n in g}
+         devices = {n: torch.device(f'cuda:{mapping_gpu_ids[i]}') for i, g in enumerate(chunks) for n in g}
+         # devices = {n: torch.device(f'cuda:{i}') for i, g in enumerate(chunks) for n in g}
      else:
          devices = {n: torch.device('cpu') for n in itertools.chain.from_iterable(net_groups)}
 
      if eager_mode:
-         return load_eager_models(model_dir, devices)
+         return load_eager_models(opt, devices)
 
      return {
          n: load_torchscript_model(os.path.join(model_dir, f'{n}.pt'), device=d)
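The device-assignment logic above chunks `net_groups` across the GPUs listed in `opt.gpu_ids` so that each generator and its downstream segmentation net stay co-located. A standalone sketch of the idea, with a simplified `chunker` (the real one lives in `deepliif.util` and may differ in detail):

```python
# Illustration of group-to-device assignment; chunker here is a stand-in.
import itertools

def chunker(seq, n):
    """Split seq into n contiguous, nearly equal chunks."""
    k, m = divmod(len(seq), n)
    return [seq[i * k + min(i, m):(i + 1) * k + min(i + 1, m)] for i in range(n)]

net_groups = [('G1', 'G52'), ('G2', 'G53'), ('G3', 'G54'), ('G4', 'G55'), ('G51',)]
gpu_ids = [0, 1]  # assumed opt.gpu_ids
mapping_gpu_ids = dict(enumerate(gpu_ids))

chunks = [itertools.chain.from_iterable(c) for c in chunker(net_groups, len(gpu_ids))]
devices = {n: f'cuda:{mapping_gpu_ids[i]}' for i, g in enumerate(chunks) for n in g}
print(devices)  # G1/G52/G2/G53/G3/G54 -> cuda:0, G4/G55/G51 -> cuda:1
```

The new `mapping_gpu_ids` indirection matters when `opt.gpu_ids` is non-contiguous (e.g. `[2, 3]`): the chunk index no longer equals the CUDA device ordinal, which is exactly what the old `cuda:{i}` formatting assumed.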
@@ -190,14 +195,18 @@ def compute_overlap(img_size, tile_size):
      return tile_size // 4
 
 
- def run_torchserve(img, model_path=None, eager_mode=False):
+ def run_torchserve(img, model_path=None, eager_mode=False, opt=None):
      """
      eager_mode: not used in this function; put in place to be consistent with run_dask
          so that run_wrapper() could call either this function or run_dask with
          same syntax
+     opt: same as eager_mode
      """
      buffer = BytesIO()
-     torch.save(transform(img.resize((512, 512))), buffer)
+     if opt.model == 'DeepLIIFExt':
+         torch.save(transform(img.resize((1024, 1024))), buffer)
+     else:
+         torch.save(transform(img.resize((512, 512))), buffer)
 
      torchserve_host = os.getenv('TORCHSERVE_HOST', 'http://localhost')
      res = requests.post(
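The payload here is a serialized tensor rather than raw image bytes. A sketch of the round trip, assuming `transform` yields a tensor and the service returns serialized tensors, as the surrounding code suggests (the helper names are hypothetical):

```python
# Sketch: serialize a tensor into a request body and deserialize a response.
from io import BytesIO
import torch

def serialize_tensor(ts: torch.Tensor) -> bytes:
    buffer = BytesIO()
    torch.save(ts, buffer)  # same mechanism run_torchserve uses above
    return buffer.getvalue()

def deserialize_tensor_bytes(raw: bytes) -> torch.Tensor:
    return torch.load(BytesIO(raw), map_location='cpu')
```

The 512 vs. 1024 resize reflects the native tile size each architecture expects; `run_dask` below applies the same branch.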
@@ -213,33 +222,55 @@ def run_torchserve(img, model_path=None, eager_mode=False):
      return {k: tensor_to_pil(deserialize_tensor(v)) for k, v in res.json().items()}
 
 
- def run_dask(img, model_path, eager_mode=False):
+ def run_dask(img, model_path, eager_mode=False, opt=None):
      model_dir = os.getenv('DEEPLIIF_MODEL_DIR', model_path)
-     nets = init_nets(model_dir, eager_mode)
-
-     ts = transform(img.resize((512, 512)))
+     nets = init_nets(model_dir, eager_mode, opt)
+
+     if opt.model == 'DeepLIIFExt':
+         ts = transform(img.resize((1024, 1024)))
+     else:
+         ts = transform(img.resize((512, 512)))
 
      @delayed
      def forward(input, model):
          with torch.no_grad():
              return model(input.to(next(model.parameters()).device))
+
+     if opt.model == 'DeepLIIF':
+         seg_map = {'G1': 'G52', 'G2': 'G53', 'G3': 'G54', 'G4': 'G55'}
+
+         lazy_gens = {k: forward(ts, nets[k]) for k in seg_map}
+         gens = compute(lazy_gens)[0]
+
+         lazy_segs = {v: forward(gens[k], nets[v]).to(torch.device('cpu')) for k, v in seg_map.items()}
+         lazy_segs['G51'] = forward(ts, nets['G51']).to(torch.device('cpu'))
+         segs = compute(lazy_segs)[0]
+
+         seg_weights = [0.25, 0.25, 0.25, 0, 0.25]
+         seg = torch.stack([torch.mul(n, w) for n, w in zip(segs.values(), seg_weights)]).sum(dim=0)
+
+         res = {k: tensor_to_pil(v) for k, v in gens.items()}
+         res['G5'] = tensor_to_pil(seg)
+
+         return res
+     elif opt.model == 'DeepLIIFExt':
+         seg_map = {'G_' + str(i): 'GS_' + str(i) for i in range(1, opt.modalities_no + 1)}
+
+         lazy_gens = {k: forward(ts, nets[k]) for k in seg_map}
+         gens = compute(lazy_gens)[0]
+
+         res = {k: tensor_to_pil(v) for k, v in gens.items()}
+
+         if opt.seg_gen:
+             lazy_segs = {v: forward(torch.cat([ts.to(torch.device('cpu')), gens[next(iter(seg_map))].to(torch.device('cpu')), gens[k].to(torch.device('cpu'))], 1), nets[v]).to(torch.device('cpu')) for k, v in seg_map.items()}
+             segs = compute(lazy_segs)[0]
+             res.update({k: tensor_to_pil(v) for k, v in segs.items()})
+
+         return res
+     else:
+         raise Exception(f'run_dask() not implemented for {opt.model}')
 
-     seg_map = {'G1': 'G52', 'G2': 'G53', 'G3': 'G54', 'G4': 'G55'}
-
-     lazy_gens = {k: forward(ts, nets[k]) for k in seg_map}
-     gens = compute(lazy_gens)[0]
-
-     lazy_segs = {v: forward(gens[k], nets[v]).to(torch.device('cpu')) for k, v in seg_map.items()}
-     lazy_segs['G51'] = forward(ts, nets['G51']).to(torch.device('cpu'))
-     segs = compute(lazy_segs)[0]
-
-     seg_weights = [0.25, 0.25, 0.25, 0, 0.25]
-     seg = torch.stack([torch.mul(n, w) for n, w in zip(segs.values(), seg_weights)]).sum(dim=0)
-
-     res = {k: tensor_to_pil(v) for k, v in gens.items()}
-     res['G5'] = tensor_to_pil(seg)
-
-     return res
+
 
 
  def is_empty(tile):
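The DeepLIIF branch reconstructs the final segmentation as a fixed-weight average of the parallel seg outputs. Isolated for clarity (dummy tensors; the weight order follows the insertion order of `lazy_segs`, i.e. G52, G53, G54, G55, G51):

```python
# The ensemble step in isolation: note the fourth weight (for G55) is 0.
import torch

segs = {n: torch.rand(3, 512, 512) for n in ['G52', 'G53', 'G54', 'G55', 'G51']}  # dummy outputs
seg_weights = [0.25, 0.25, 0.25, 0, 0.25]
seg = torch.stack([torch.mul(s, w) for s, w in zip(segs.values(), seg_weights)]).sum(dim=0)
print(seg.shape)  # torch.Size([3, 512, 512])
```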
@@ -247,17 +278,27 @@ def is_empty(tile):
      return True if calculate_background_area(tile) > 98 else False
 
 
- def run_wrapper(tile, run_fn, model_path, eager_mode=False):
-     if is_empty(tile):
-         return {
-             'G1': Image.new(mode='RGB', size=(512, 512), color=(201, 211, 208)),
-             'G2': Image.new(mode='RGB', size=(512, 512), color=(10, 10, 10)),
-             'G3': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
-             'G4': Image.new(mode='RGB', size=(512, 512), color=(10, 10, 10)),
-             'G5': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0))
-         }
+ def run_wrapper(tile, run_fn, model_path, eager_mode=False, opt=None):
+     if opt.model == 'DeepLIIF':
+         if is_empty(tile):
+             return {
+                 'G1': Image.new(mode='RGB', size=(512, 512), color=(201, 211, 208)),
+                 'G2': Image.new(mode='RGB', size=(512, 512), color=(10, 10, 10)),
+                 'G3': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
+                 'G4': Image.new(mode='RGB', size=(512, 512), color=(10, 10, 10)),
+                 'G5': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0))
+             }
+         else:
+             return run_fn(tile, model_path, eager_mode, opt)
+     elif opt.model == 'DeepLIIFExt':
+         if is_empty(tile):
+             res = {'G_' + str(i): Image.new(mode='RGB', size=(512, 512)) for i in range(1, opt.modalities_no + 1)}
+             res.update({'GS_' + str(i): Image.new(mode='RGB', size=(512, 512)) for i in range(1, opt.modalities_no + 1)})
+             return res
+         else:
+             return run_fn(tile, model_path, eager_mode, opt)
      else:
-         return run_fn(tile, model_path, eager_mode)
+         raise Exception(f'run_wrapper() not implemented for model {opt.model}')
 
 
  def inference_old(img, tile_size, overlap_size, model_path, use_torchserve=False, eager_mode=False,
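run_wrapper's empty-tile shortcut skips model inference on background-only tiles and substitutes constant-color placeholders. The DeepLIIF placeholders in isolation (colors copied from the diff; `Image.new` defaults to black when no color is given, which is what the DeepLIIFExt branch relies on):

```python
# Placeholder outputs for a blank DeepLIIF tile (no model inference needed).
from PIL import Image

def placeholder_outputs(size=(512, 512)):
    colors = {'G1': (201, 211, 208),  # Hema
              'G2': (10, 10, 10),     # DAPI
              'G3': (0, 0, 0),        # Lap2
              'G4': (10, 10, 10),     # Marker
              'G5': (0, 0, 0)}        # Seg
    return {k: Image.new(mode='RGB', size=size, color=c) for k, c in colors.items()}
```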
@@ -307,71 +348,115 @@ def inference_old(img, tile_size, overlap_size, model_path, use_torchserve=False
 
 
  def inference(img, tile_size, overlap_size, model_path, use_torchserve=False, eager_mode=False,
-               color_dapi=False, color_marker=False):
-
-     rescaled, rows, cols = format_image_for_tiling(img, tile_size, overlap_size)
-
-     run_fn = run_torchserve if use_torchserve else run_dask
-
-     images = {}
-     images['Hema'] = create_image_for_stitching(tile_size, rows, cols)
-     images['DAPI'] = create_image_for_stitching(tile_size, rows, cols)
-     images['Lap2'] = create_image_for_stitching(tile_size, rows, cols)
-     images['Marker'] = create_image_for_stitching(tile_size, rows, cols)
-     images['Seg'] = create_image_for_stitching(tile_size, rows, cols)
-
-     for i in range(cols):
-         for j in range(rows):
-             tile = extract_tile(rescaled, tile_size, overlap_size, i, j)
-             res = run_wrapper(tile, run_fn, model_path, eager_mode)
-
-             stitch_tile(images['Hema'], res['G1'], tile_size, overlap_size, i, j)
-             stitch_tile(images['DAPI'], res['G2'], tile_size, overlap_size, i, j)
-             stitch_tile(images['Lap2'], res['G3'], tile_size, overlap_size, i, j)
-             stitch_tile(images['Marker'], res['G4'], tile_size, overlap_size, i, j)
-             stitch_tile(images['Seg'], res['G5'], tile_size, overlap_size, i, j)
-
-     images['Hema'] = images['Hema'].resize(img.size)
-     images['DAPI'] = images['DAPI'].resize(img.size)
-     images['Lap2'] = images['Lap2'].resize(img.size)
-     images['Marker'] = images['Marker'].resize(img.size)
-     images['Seg'] = images['Seg'].resize(img.size)
-
-     if color_dapi:
-         matrix = (0, 0, 0, 0,
-                   299/1000, 587/1000, 114/1000, 0,
-                   299/1000, 587/1000, 114/1000, 0)
-         images['DAPI'] = images['DAPI'].convert('RGB', matrix)
-
-     if color_marker:
-         matrix = (299/1000, 587/1000, 114/1000, 0,
-                   299/1000, 587/1000, 114/1000, 0,
-                   0, 0, 0, 0)
-         images['Marker'] = images['Marker'].convert('RGB', matrix)
-
-     return images
-
-
- def postprocess(img, seg_img, thresh=80, noise_objects_size=20, small_object_size=50):
-     mask_image = create_basic_segmentation_mask(np.array(img), np.array(seg_img),
-                                                 thresh, noise_objects_size, small_object_size)
-     images = {}
-     images['SegOverlaid'] = Image.fromarray(overlay_final_segmentation_mask(np.array(img), mask_image))
-     images['SegRefined'] = Image.fromarray(create_final_segmentation_mask_with_boundaries(np.array(mask_image)))
-
-     all_cells_no, positive_cells_no, negative_cells_no, IHC_score = compute_IHC_scoring(mask_image)
-     scoring = {
-         'num_total': all_cells_no,
-         'num_pos': positive_cells_no,
-         'num_neg': negative_cells_no,
-         'percent_pos': IHC_score
-     }
+               color_dapi=False, color_marker=False, opt=None):
+     if not opt:
+         opt = get_opt(model_path)
+         print_options(opt)
+
+     if opt.model == 'DeepLIIF':
+         rescaled, rows, cols = format_image_for_tiling(img, tile_size, overlap_size)
+
+         run_fn = run_torchserve if use_torchserve else run_dask
+
+         images = {}
+         images['Hema'] = create_image_for_stitching(tile_size, rows, cols)
+         images['DAPI'] = create_image_for_stitching(tile_size, rows, cols)
+         images['Lap2'] = create_image_for_stitching(tile_size, rows, cols)
+         images['Marker'] = create_image_for_stitching(tile_size, rows, cols)
+         images['Seg'] = create_image_for_stitching(tile_size, rows, cols)
+
+         for i in range(cols):
+             for j in range(rows):
+                 tile = extract_tile(rescaled, tile_size, overlap_size, i, j)
+                 res = run_wrapper(tile, run_fn, model_path, eager_mode, opt)
+
+                 stitch_tile(images['Hema'], res['G1'], tile_size, overlap_size, i, j)
+                 stitch_tile(images['DAPI'], res['G2'], tile_size, overlap_size, i, j)
+                 stitch_tile(images['Lap2'], res['G3'], tile_size, overlap_size, i, j)
+                 stitch_tile(images['Marker'], res['G4'], tile_size, overlap_size, i, j)
+                 stitch_tile(images['Seg'], res['G5'], tile_size, overlap_size, i, j)
+
+         images['Hema'] = images['Hema'].resize(img.size)
+         images['DAPI'] = images['DAPI'].resize(img.size)
+         images['Lap2'] = images['Lap2'].resize(img.size)
+         images['Marker'] = images['Marker'].resize(img.size)
+         images['Seg'] = images['Seg'].resize(img.size)
+
+         if color_dapi:
+             matrix = (0, 0, 0, 0,
+                       299/1000, 587/1000, 114/1000, 0,
+                       299/1000, 587/1000, 114/1000, 0)
+             images['DAPI'] = images['DAPI'].convert('RGB', matrix)
+
+         if color_marker:
+             matrix = (299/1000, 587/1000, 114/1000, 0,
+                       299/1000, 587/1000, 114/1000, 0,
+                       0, 0, 0, 0)
+             images['Marker'] = images['Marker'].convert('RGB', matrix)
+
+         return images
+
+     elif opt.model == 'DeepLIIFExt':
+         #param_dict = read_train_options(model_path)
+         #modalities_no = int(param_dict['modalities_no']) if param_dict else 4
+         #seg_gen = (param_dict['seg_gen'] == 'True') if param_dict else True
+
+         tiles = list(generate_tiles(img, tile_size, overlap_size))
+
+         run_fn = run_torchserve if use_torchserve else run_dask
+         res = [Tile(t.i, t.j, run_wrapper(t.img, run_fn, model_path, eager_mode, opt)) for t in tiles]
+
+         def get_net_tiles(n):
+             return [Tile(t.i, t.j, t.img[n]) for t in res]
+
+         images = {}
+
+         for i in range(1, opt.modalities_no + 1):
+             images['mod' + str(i)] = stitch(get_net_tiles('G_' + str(i)), tile_size, overlap_size).resize(img.size)
+
+         if opt.seg_gen:
+             for i in range(1, opt.modalities_no + 1):
+                 images['Seg' + str(i)] = stitch(get_net_tiles('GS_' + str(i)), tile_size, overlap_size).resize(img.size)
+
+         return images
+
+     else:
+         raise Exception(f'inference() not implemented for model {opt.model}')
+
+
+ def postprocess(orig, images, tile_size, seg_thresh=150, size_thresh='default', marker_thresh='default', size_thresh_upper=None, opt=None):
+     if opt.model == 'DeepLIIF':
+         resolution = '40x' if tile_size > 384 else ('20x' if tile_size > 192 else '10x')
+         overlay, refined, scoring = compute_results(np.array(orig), np.array(images['Seg']),
+                                                     np.array(images['Marker'].convert('L')), resolution,
+                                                     seg_thresh, size_thresh, marker_thresh, size_thresh_upper)
+         processed_images = {}
+         processed_images['SegOverlaid'] = Image.fromarray(overlay)
+         processed_images['SegRefined'] = Image.fromarray(refined)
+         return processed_images, scoring
+
+     elif opt.model == 'DeepLIIFExt':
+         resolution = '40x' if tile_size > 768 else ('20x' if tile_size > 384 else '10x')
+         processed_images = {}
+         scoring = {}
+         for img_name in list(images.keys()):
+             if 'Seg' in img_name:
+                 seg_img = images[img_name]
+                 overlay, refined, score = compute_results(np.array(orig), np.array(images[img_name]),
+                                                           None, resolution,
+                                                           seg_thresh, size_thresh, marker_thresh, size_thresh_upper)
+
+                 processed_images[img_name + '_Overlaid'] = Image.fromarray(overlay)
+                 processed_images[img_name + '_Refined'] = Image.fromarray(refined)
+                 scoring[img_name] = score
+         return processed_images, scoring
 
-     return images, scoring
+     else:
+         raise Exception(f'postprocess() not implemented for model {opt.model}')
 
 
  def infer_modalities(img, tile_size, model_dir, eager_mode=False,
-                      color_dapi=False, color_marker=False):
+                      color_dapi=False, color_marker=False, opt=None):
      """
      This function is used to infer modalities for the given image using a trained model.
      :param img: The input image.
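A note on the pseudo-coloring kept in the DeepLIIF branch: PIL's 12-coefficient matrix maps RGB to RGB row by row, so zeroing a row blanks that output channel while the luma weights (299/1000, 587/1000, 114/1000) grayscale the rest; the DAPI matrix therefore yields a cyan-tinted image, and the Marker matrix a yellow-tinted one. In isolation:

```python
# The color-matrix trick in isolation (dummy image).
from PIL import Image

img = Image.new('RGB', (64, 64), (120, 80, 200))
dapi_matrix = (0, 0, 0, 0,
               299/1000, 587/1000, 114/1000, 0,
               299/1000, 587/1000, 114/1000, 0)
cyan = img.convert('RGB', dapi_matrix)  # R channel zeroed, G/B carry the luminance
print(cyan.getpixel((0, 0)))            # (0, y, y) for some luma value y
```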
@@ -379,6 +464,11 @@ def infer_modalities(img, tile_size, model_dir, eager_mode=False,
      :param model_dir: The directory containing serialized model files.
      :return: The inferred modalities and the segmentation mask.
      """
+     if opt is None:
+         opt = get_opt(model_dir)
+         opt.use_dp = False
+         print_options(opt)
+
      if not tile_size:
          tile_size = check_multi_scale(Image.open('./images/target.png').convert('L'),
                                        img.convert('L'))
@@ -391,12 +481,16 @@ def infer_modalities(img, tile_size, model_dir, eager_mode=False,
          model_path=model_dir,
          eager_mode=eager_mode,
          color_dapi=color_dapi,
-         color_marker=color_marker
+         color_marker=color_marker,
+         opt=opt
      )
-
-     post_images, scoring = postprocess(img, images['Seg'], small_object_size=20)
-     images = {**images, **post_images}
-     return images, scoring
+
+     if not hasattr(opt, 'seg_gen') or (hasattr(opt, 'seg_gen') and opt.seg_gen):  # the first condition accounts for old settings of deepliif; the second refers to deepliifext models
+         post_images, scoring = postprocess(img, images, tile_size, opt=opt)
+         images = {**images, **post_images}
+         return images, scoring
+     else:
+         return images, None
 
 
  def infer_results_for_wsi(input_dir, filename, output_dir, model_dir, tile_size, region_size=20000):
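Putting the refactor together, the high-level entry point can be driven roughly as follows (file names and directory are hypothetical; the model directory is assumed to contain the checkpoints plus their *_opt.txt file):

```python
# End-to-end usage sketch of the refactored entry point (hypothetical paths).
from PIL import Image
from deepliif.models import infer_modalities

img = Image.open('sample_ihc.png').convert('RGB')
images, scoring = infer_modalities(img, tile_size=512, model_dir='./DeepLIIF_Latest_Model')
for name, modality in images.items():
    modality.save(f'{name}.png')
print(scoring)  # None when a DeepLIIFExt model was trained with seg_gen disabled
```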
deepliif/models/base_model.py:
@@ -4,6 +4,8 @@ from collections import OrderedDict
  from abc import ABC, abstractmethod
  from . import networks
  from ..util import disable_batchnorm_tracking_stats
+ from deepliif.util import *
+ import itertools
 
 
  class BaseModel(ABC):
@@ -35,8 +37,9 @@ class BaseModel(ABC):
          self.is_train = opt.is_train
          self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu')  # get device name: CPU or GPU
          self.save_dir = os.path.join(opt.checkpoints_dir, opt.name)  # save all the checkpoints to save_dir
-         if opt.preprocess != 'scale_width':  # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
+         if opt.phase == 'train' and opt.preprocess != 'scale_width':  # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
              torch.backends.cudnn.benchmark = True
+         # especially for inference, cudnn benchmark can cause excessive usage of GPU memory for the first image in the sequence in order to find the best conv alg, which is not necessary.
          self.loss_names = []
          self.model_names = []
          self.visual_names = []
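The intent of this gate, restated: benchmark mode autotunes convolution algorithms per input shape, which costs time and a transient spike of GPU memory on the first forward pass, a trade-off that only pays off over many same-shaped training batches. A sketch of the guarded setup (assuming an `opt` with `phase` and `preprocess` attributes, as in the diff):

```python
# Enable cudnn autotuning only for fixed-size training inputs.
import torch

def configure_cudnn(opt):
    if opt.phase == 'train' and opt.preprocess != 'scale_width':
        torch.backends.cudnn.benchmark = True  # amortized over many training iterations
```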
@@ -89,7 +92,10 @@ class BaseModel(ABC):
          """Make models eval mode during test time"""
          for name in self.model_names:
              if isinstance(name, str):
-                 net = getattr(self, 'net' + name)
+                 if '_' in name:
+                     net = getattr(self, 'net' + name.split('_')[0])[int(name.split('_')[-1]) - 1]
+                 else:
+                     net = getattr(self, 'net' + name)
                  net.eval()
                  net = disable_batchnorm_tracking_stats(net)
 
@@ -127,7 +133,13 @@ class BaseModel(ABC):
          visual_ret = OrderedDict()
          for name in self.visual_names:
              if isinstance(name, str):
-                 visual_ret[name] = getattr(self, name)
+                 if not hasattr(self, name):
+                     if len(name.split('_')) == 2:
+                         visual_ret[name] = getattr(self, name.split('_')[0])[int(name.split('_')[-1]) - 1]
+                     else:
+                         visual_ret[name] = getattr(self, name.split('_')[0] + '_' + name.split('_')[1])[int(name.split('_')[-1]) - 1]
+                 else:
+                     visual_ret[name] = getattr(self, name)
          return visual_ret
 
      def get_current_losses(self):
@@ -135,7 +147,16 @@ class BaseModel(ABC):
          errors_ret = OrderedDict()
          for name in self.loss_names:
              if isinstance(name, str):
-                 errors_ret[name] = float(getattr(self, 'loss_' + name))  # float(...) works for both scalar tensor and float number
+                 if not hasattr(self, 'loss_' + name):  # appears in DeepLIIFExt
+                     if len(name.split('_')) == 2:
+                         errors_ret[name] = float(getattr(self, 'loss_' + name.split('_')[0])[int(
+                             name.split('_')[-1]) - 1])  # float(...) works for both scalar tensor and float number
+                     else:
+                         errors_ret[name] = float(getattr(self, 'loss_' + name.split('_')[0] + '_' + name.split('_')[1])[int(
+                             name.split('_')[-1]) - 1])  # float(...) works for both scalar tensor and float number
+                 else:  # single numeric value
+                     errors_ret[name] = float(getattr(self, 'loss_' + name))
+
          return errors_ret
 
      def save_networks(self, epoch, save_from_one_process=False):
@@ -151,7 +172,10 @@ class BaseModel(ABC):
              if isinstance(name, str):
                  save_filename = '%s_net_%s.pth' % (epoch, name)
                  save_path = os.path.join(self.save_dir, save_filename)
-                 net = getattr(self, 'net' + name)
+                 if '_' in name:
+                     net = getattr(self, 'net' + name.split('_')[0])[int(name.split('_')[-1]) - 1]
+                 else:
+                     net = getattr(self, 'net' + name)
 
                  if len(self.gpu_ids) > 0 and torch.cuda.is_available():
                      torch.save(net.module.cpu().state_dict(), save_path)
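The underscore convention patched into eval(), save_networks(), load_networks(), and print_networks() is the same everywhere: a plain name like 'G1' maps to a single attribute (self.netG1), while 'G_3' or 'GS_2' index into a list attribute (self.netG[2], self.netGS[1]) as DeepLIIFExt uses. A standalone sketch of the resolver, testable on its own:

```python
# The 'net' + name resolution rule, extracted from the diff.
class FakeModel:
    def __init__(self):
        self.netG = ['g1', 'g2', 'g3']  # list-style nets (DeepLIIFExt)
        self.netG1 = 'single'           # plain attribute (DeepLIIF)

    def resolve(self, name):
        if '_' in name:
            return getattr(self, 'net' + name.split('_')[0])[int(name.split('_')[-1]) - 1]
        return getattr(self, 'net' + name)

m = FakeModel()
assert m.resolve('G_2') == 'g2'
assert m.resolve('G1') == 'single'
```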
@@ -219,13 +243,32 @@ class BaseModel(ABC):
              if isinstance(name, str):
                  load_filename = '%s_net_%s.pth' % (epoch, name)
                  load_path = os.path.join(self.save_dir, load_filename)
-                 net = getattr(self, 'net' + name)
+                 if '_' in name:
+                     net = getattr(self, 'net' + name.split('_')[0])[int(name.split('_')[-1]) - 1]
+                 else:
+                     net = getattr(self, 'net' + name)
                  if isinstance(net, torch.nn.DataParallel):
                      net = net.module
+
+                 self.set_requires_grad(net, self.opt.is_train)
+                 # check if gradients are disabled
+                 names_layer_requires_grad = []
+                 for name, param in net.named_parameters():
+                     if param.requires_grad:
+                         names_layer_requires_grad.append(name)
+
                  print('loading the model from %s' % load_path)
                  # if you are using PyTorch newer than 0.4 (e.g., built from
                  # GitHub source), you can remove str() on self.device
-                 state_dict = torch.load(load_path, map_location=str(self.device))
+
+                 if self.opt.is_train or self.opt.use_dp:
+                     device = self.device
+                 else:
+                     device = torch.device('cpu')  # load on CPU first; later in __init__.py::init_nets we will move it to the specified device
+
+                 net.to(device)
+                 state_dict = torch.load(load_path, map_location=str(device))
+
                  if hasattr(state_dict, '_metadata'):
                      del state_dict._metadata
 
@@ -243,7 +286,10 @@ class BaseModel(ABC):
          print('---------- Networks initialized -------------')
          for name in self.model_names:
              if isinstance(name, str):
-                 net = getattr(self, 'net' + name)
+                 if '_' in name:
+                     net = getattr(self, 'net' + name.split('_')[0])[int(name.split('_')[-1]) - 1]
+                 else:
+                     net = getattr(self, 'net' + name)
                  num_params = 0
                  for param in net.parameters():
                      num_params += param.numel()
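Stepping back to the load_networks change above, the device dance is worth spelling out: during inference without DataParallel, weights are materialized on CPU first and only moved to their target device later (by init_nets / load_eager_models), which avoids a transient full-model allocation on a single GPU. A minimal sketch of that pattern, assuming the caller performs the final move:

```python
# CPU-first checkpoint loading; the caller moves the net to its final device.
import torch

def load_state_cpu_first(net: torch.nn.Module, load_path: str) -> torch.nn.Module:
    net.to(torch.device('cpu'))
    state_dict = torch.load(load_path, map_location='cpu')
    net.load_state_dict(state_dict)
    return net  # e.g. later: net.to(devices[name]) in init_nets/load_eager_models
```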