deepliif 1.1.7__py3-none-any.whl → 1.1.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -31,14 +31,34 @@ import numpy as np
31
31
  from dask import delayed, compute
32
32
 
33
33
  from deepliif.util import *
34
- from deepliif.util.util import tensor_to_pil
34
+ from deepliif.util.util import tensor_to_pil, check_multi_scale
35
35
  from deepliif.data import transform
36
- from deepliif.postprocessing import adjust_marker, adjust_dapi, compute_IHC_scoring, \
37
- overlay_final_segmentation_mask, create_final_segmentation_mask_with_boundaries, create_basic_segmentation_mask
36
+ from deepliif.postprocessing import compute_results
37
+ from deepliif.options import Options, print_options
38
38
 
39
39
  from .base_model import BaseModel
40
+
41
+ # import for init purpose, not used in this script
40
42
  from .DeepLIIF_model import DeepLIIFModel
41
- from .networks import get_norm_layer, ResnetGenerator, UnetGenerator
43
+ from .DeepLIIFExt_model import DeepLIIFExtModel
44
+
45
+
46
+ @lru_cache
47
+ def get_opt(model_dir, mode='test'):
48
+ """
49
+ mode: test or train, currently only functions used for inference utilize get_opt so it
50
+ defaults to test
51
+ """
52
+ if mode == 'train':
53
+ opt = Options(path_file=os.path.join(model_dir,'train_opt.txt'), mode=mode)
54
+ elif mode == 'test':
55
+ try:
56
+ opt = Options(path_file=os.path.join(model_dir,'test_opt.txt'), mode=mode)
57
+ except:
58
+ opt = Options(path_file=os.path.join(model_dir,'train_opt.txt'), mode=mode)
59
+ opt.use_dp = False
60
+ opt.gpu_ids = list(range(torch.cuda.device_count()))
61
+ return opt
42
62
 
43
63
 
44
64
  def find_model_using_name(model_name):
@@ -94,87 +114,75 @@ def load_torchscript_model(model_pt_path, device):
94
114
  return net
95
115
 
96
116
 
97
- def read_model_params(file_addr):
98
- with open(file_addr) as f:
99
- lines = f.readlines()
100
- param_dict = {}
101
- for line in lines:
102
- if ':' in line:
103
- key = line.split(':')[0].strip()
104
- val = line.split(':')[1].split('[')[0].strip()
105
- param_dict[key] = val
106
- print(param_dict)
107
- return param_dict
108
-
109
-
110
- def load_eager_models(model_dir, devices):
111
- input_nc = 3
112
- output_nc = 3
113
- ngf = 64
114
- norm = 'batch'
115
- use_dropout = True
116
- padding_type = 'zero'
117
-
118
- files = os.listdir(model_dir)
119
- for f in files:
120
- if 'train_opt.txt' in f:
121
- param_dict = read_model_params(os.path.join(model_dir, f))
122
- input_nc = int(param_dict['input_nc'])
123
- output_nc = int(param_dict['output_nc'])
124
- ngf = int(param_dict['ngf'])
125
- norm = param_dict['norm']
126
- use_dropout = False if param_dict['no_dropout'] == 'True' else True
127
- padding_type = param_dict['padding']
128
-
129
- norm_layer = get_norm_layer(norm_type=norm)
130
117
 
131
- nets = {}
132
- for n in ['G1', 'G2', 'G3', 'G4']:
133
- net = ResnetGenerator(input_nc, output_nc, ngf, norm_layer=norm_layer, use_dropout=use_dropout, n_blocks=9, padding_type=padding_type)
134
- net.load_state_dict(torch.load(
135
- os.path.join(model_dir, f'latest_net_{n}.pth'),
136
- map_location=devices[n]
137
- ))
138
- nets[n] = disable_batchnorm_tracking_stats(net)
139
- nets[n].eval()
140
-
141
- for n in ['G51', 'G52', 'G53', 'G54', 'G55']:
142
- net = UnetGenerator(input_nc, output_nc, 9, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
143
- net.load_state_dict(torch.load(
144
- os.path.join(model_dir, f'latest_net_{n}.pth'),
145
- map_location=devices[n]
146
- ))
147
- nets[n] = disable_batchnorm_tracking_stats(net)
148
- nets[n].eval()
118
+ def load_eager_models(opt, devices):
119
+ # create a model given model and other options
120
+ model = create_model(opt)
121
+ # regular setup: load and print networks; create schedulers
122
+ model.setup(opt)
149
123
 
124
+ nets = {}
125
+ for name in model.model_names:
126
+ if isinstance(name, str):
127
+ if '_' in name:
128
+ net = getattr(model, 'net' + name.split('_')[0])[int(name.split('_')[-1]) - 1]
129
+ else:
130
+ net = getattr(model, 'net' + name)
131
+
132
+ if opt.phase != 'train':
133
+ net.eval()
134
+ net = disable_batchnorm_tracking_stats(net)
135
+
136
+ nets[name] = net
137
+ nets[name].to(devices[name])
138
+
150
139
  return nets
151
140
 
152
141
 
153
142
  @lru_cache
154
- def init_nets(model_dir, eager_mode=False):
143
+ def init_nets(model_dir, eager_mode=False, opt=None, phase='test'):
155
144
  """
156
145
  Init DeepLIIF networks so that every net in
157
146
  the same group is deployed on the same GPU
158
- """
159
- net_groups = [
160
- ('G1', 'G52'),
161
- ('G2', 'G53'),
162
- ('G3', 'G54'),
163
- ('G4', 'G55'),
164
- ('G51',)
165
- ]
166
-
167
- number_of_gpus = torch.cuda.device_count()
168
- # number_of_gpus = 0
169
- if number_of_gpus:
147
+
148
+ opt_args: to overwrite opt arguments in train_opt.txt, typically used in inference stage
149
+ for example, opt_args={'phase':'test'}
150
+ """
151
+ if opt is None:
152
+ opt = get_opt(model_dir, mode=phase)
153
+ opt.use_dp = False
154
+ #print_options(opt)
155
+
156
+ if opt.model == 'DeepLIIF':
157
+ net_groups = [
158
+ ('G1', 'G52'),
159
+ ('G2', 'G53'),
160
+ ('G3', 'G54'),
161
+ ('G4', 'G55'),
162
+ ('G51',)
163
+ ]
164
+ elif opt.model == 'DeepLIIFExt':
165
+ if opt.seg_gen:
166
+ net_groups = [(f'G_{i+1}',f'GS_{i+1}') for i in range(opt.modalities_no)]
167
+ else:
168
+ net_groups = [(f'G_{i+1}',) for i in range(opt.modalities_no)]
169
+ else:
170
+ raise Exception(f'init_nets() not implemented for model {opt.model}')
171
+
172
+ number_of_gpus_all = torch.cuda.device_count()
173
+ number_of_gpus = len(opt.gpu_ids)
174
+ #print(number_of_gpus)
175
+ if number_of_gpus > 0:
176
+ mapping_gpu_ids = {i:idx for i,idx in enumerate(opt.gpu_ids)}
170
177
  chunks = [itertools.chain.from_iterable(c) for c in chunker(net_groups, number_of_gpus)]
171
178
  # chunks = chunks[1:]
172
- devices = {n: torch.device(f'cuda:{i}') for i, g in enumerate(chunks) for n in g}
179
+ devices = {n: torch.device(f'cuda:{mapping_gpu_ids[i]}') for i, g in enumerate(chunks) for n in g}
180
+ # devices = {n: torch.device(f'cuda:{i}') for i, g in enumerate(chunks) for n in g}
173
181
  else:
174
182
  devices = {n: torch.device('cpu') for n in itertools.chain.from_iterable(net_groups)}
175
183
 
176
184
  if eager_mode:
177
- return load_eager_models(model_dir, devices)
185
+ return load_eager_models(opt, devices)
178
186
 
179
187
  return {
180
188
  n: load_torchscript_model(os.path.join(model_dir, f'{n}.pt'), device=d)
@@ -190,14 +198,18 @@ def compute_overlap(img_size, tile_size):
190
198
  return tile_size // 4
191
199
 
192
200
 
193
- def run_torchserve(img, model_path=None, eager_mode=False):
201
+ def run_torchserve(img, model_path=None, eager_mode=False, opt=None):
194
202
  """
195
203
  eager_mode: not used in this function; put in place to be consistent with run_dask
196
204
  so that run_wrapper() could call either this function or run_dask with
197
205
  same syntax
206
+ opt: same as eager_mode
198
207
  """
199
208
  buffer = BytesIO()
200
- torch.save(transform(img.resize((512, 512))), buffer)
209
+ if opt.model == 'DeepLIIFExt':
210
+ torch.save(transform(img.resize((1024, 1024))), buffer)
211
+ else:
212
+ torch.save(transform(img.resize((512, 512))), buffer)
201
213
 
202
214
  torchserve_host = os.getenv('TORCHSERVE_HOST', 'http://localhost')
203
215
  res = requests.post(
@@ -213,33 +225,55 @@ def run_torchserve(img, model_path=None, eager_mode=False):
213
225
  return {k: tensor_to_pil(deserialize_tensor(v)) for k, v in res.json().items()}
214
226
 
215
227
 
216
- def run_dask(img, model_path, eager_mode=False):
228
+ def run_dask(img, model_path, eager_mode=False, opt=None):
217
229
  model_dir = os.getenv('DEEPLIIF_MODEL_DIR', model_path)
218
- nets = init_nets(model_dir, eager_mode)
219
-
220
- ts = transform(img.resize((512, 512)))
230
+ nets = init_nets(model_dir, eager_mode, opt)
231
+
232
+ if opt.model == 'DeepLIIFExt':
233
+ ts = transform(img.resize((1024, 1024)))
234
+ else:
235
+ ts = transform(img.resize((512, 512)))
221
236
 
222
237
  @delayed
223
238
  def forward(input, model):
224
239
  with torch.no_grad():
225
240
  return model(input.to(next(model.parameters()).device))
241
+
242
+ if opt.model == 'DeepLIIF':
243
+ seg_map = {'G1': 'G52', 'G2': 'G53', 'G3': 'G54', 'G4': 'G55'}
244
+
245
+ lazy_gens = {k: forward(ts, nets[k]) for k in seg_map}
246
+ gens = compute(lazy_gens)[0]
247
+
248
+ lazy_segs = {v: forward(gens[k], nets[v]).to(torch.device('cpu')) for k, v in seg_map.items()}
249
+ lazy_segs['G51'] = forward(ts, nets['G51']).to(torch.device('cpu'))
250
+ segs = compute(lazy_segs)[0]
251
+
252
+ seg_weights = [0.25, 0.25, 0.25, 0, 0.25]
253
+ seg = torch.stack([torch.mul(n, w) for n, w in zip(segs.values(), seg_weights)]).sum(dim=0)
254
+
255
+ res = {k: tensor_to_pil(v) for k, v in gens.items()}
256
+ res['G5'] = tensor_to_pil(seg)
257
+
258
+ return res
259
+ elif opt.model == 'DeepLIIFExt':
260
+ seg_map = {'G_' + str(i): 'GS_' + str(i) for i in range(1, opt.modalities_no + 1)}
261
+
262
+ lazy_gens = {k: forward(ts, nets[k]) for k in seg_map}
263
+ gens = compute(lazy_gens)[0]
264
+
265
+ res = {k: tensor_to_pil(v) for k, v in gens.items()}
266
+
267
+ if opt.seg_gen:
268
+ lazy_segs = {v: forward(torch.cat([ts.to(torch.device('cpu')), gens[next(iter(seg_map))].to(torch.device('cpu')), gens[k].to(torch.device('cpu'))], 1), nets[v]).to(torch.device('cpu')) for k, v in seg_map.items()}
269
+ segs = compute(lazy_segs)[0]
270
+ res.update({k: tensor_to_pil(v) for k, v in segs.items()})
271
+
272
+ return res
273
+ else:
274
+ raise Exception(f'run_dask() not implemented for {opt.model}')
226
275
 
227
- seg_map = {'G1': 'G52', 'G2': 'G53', 'G3': 'G54', 'G4': 'G55'}
228
-
229
- lazy_gens = {k: forward(ts, nets[k]) for k in seg_map}
230
- gens = compute(lazy_gens)[0]
231
-
232
- lazy_segs = {v: forward(gens[k], nets[v]).to(torch.device('cpu')) for k, v in seg_map.items()}
233
- lazy_segs['G51'] = forward(ts, nets['G51']).to(torch.device('cpu'))
234
- segs = compute(lazy_segs)[0]
235
-
236
- seg_weights = [0.25, 0.25, 0.25, 0, 0.25]
237
- seg = torch.stack([torch.mul(n, w) for n, w in zip(segs.values(), seg_weights)]).sum(dim=0)
238
-
239
- res = {k: tensor_to_pil(v) for k, v in gens.items()}
240
- res['G5'] = tensor_to_pil(seg)
241
-
242
- return res
276
+
243
277
 
244
278
 
245
279
  def is_empty(tile):
@@ -247,17 +281,27 @@ def is_empty(tile):
247
281
  return True if calculate_background_area(tile) > 98 else False
248
282
 
249
283
 
250
- def run_wrapper(tile, run_fn, model_path, eager_mode=False):
251
- if is_empty(tile):
252
- return {
253
- 'G1': Image.new(mode='RGB', size=(512, 512), color=(201, 211, 208)),
254
- 'G2': Image.new(mode='RGB', size=(512, 512), color=(10, 10, 10)),
255
- 'G3': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
256
- 'G4': Image.new(mode='RGB', size=(512, 512), color=(10, 10, 10)),
257
- 'G5': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0))
258
- }
284
+ def run_wrapper(tile, run_fn, model_path, eager_mode=False, opt=None):
285
+ if opt.model == 'DeepLIIF':
286
+ if is_empty(tile):
287
+ return {
288
+ 'G1': Image.new(mode='RGB', size=(512, 512), color=(201, 211, 208)),
289
+ 'G2': Image.new(mode='RGB', size=(512, 512), color=(10, 10, 10)),
290
+ 'G3': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0)),
291
+ 'G4': Image.new(mode='RGB', size=(512, 512), color=(10, 10, 10)),
292
+ 'G5': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0))
293
+ }
294
+ else:
295
+ return run_fn(tile, model_path, eager_mode, opt)
296
+ elif opt.model == 'DeepLIIFExt':
297
+ if is_empty(tile):
298
+ res = {'G_' + str(i): Image.new(mode='RGB', size=(512, 512)) for i in range(1, opt.modalities_no + 1)}
299
+ res.update({'GS_' + str(i): Image.new(mode='RGB', size=(512, 512)) for i in range(1, opt.modalities_no + 1)})
300
+ return res
301
+ else:
302
+ return run_fn(tile, model_path, eager_mode, opt)
259
303
  else:
260
- return run_fn(tile, model_path, eager_mode)
304
+ raise Exception(f'run_wrapper() not implemented for model {opt.model}')
261
305
 
262
306
 
263
307
  def inference_old(img, tile_size, overlap_size, model_path, use_torchserve=False, eager_mode=False,
@@ -307,71 +351,115 @@ def inference_old(img, tile_size, overlap_size, model_path, use_torchserve=False
307
351
 
308
352
 
309
353
  def inference(img, tile_size, overlap_size, model_path, use_torchserve=False, eager_mode=False,
310
- color_dapi=False, color_marker=False):
311
-
312
- rescaled, rows, cols = format_image_for_tiling(img, tile_size, overlap_size)
313
-
314
- run_fn = run_torchserve if use_torchserve else run_dask
315
-
316
- images = {}
317
- images['Hema'] = create_image_for_stitching(tile_size, rows, cols)
318
- images['DAPI'] = create_image_for_stitching(tile_size, rows, cols)
319
- images['Lap2'] = create_image_for_stitching(tile_size, rows, cols)
320
- images['Marker'] = create_image_for_stitching(tile_size, rows, cols)
321
- images['Seg'] = create_image_for_stitching(tile_size, rows, cols)
322
-
323
- for i in range(cols):
324
- for j in range(rows):
325
- tile = extract_tile(rescaled, tile_size, overlap_size, i, j)
326
- res = run_wrapper(tile, run_fn, model_path, eager_mode)
327
-
328
- stitch_tile(images['Hema'], res['G1'], tile_size, overlap_size, i, j)
329
- stitch_tile(images['DAPI'], res['G2'], tile_size, overlap_size, i, j)
330
- stitch_tile(images['Lap2'], res['G3'], tile_size, overlap_size, i, j)
331
- stitch_tile(images['Marker'], res['G4'], tile_size, overlap_size, i, j)
332
- stitch_tile(images['Seg'], res['G5'], tile_size, overlap_size, i, j)
333
-
334
- images['Hema'] = images['Hema'].resize(img.size)
335
- images['DAPI'] = images['DAPI'].resize(img.size)
336
- images['Lap2'] = images['Lap2'].resize(img.size)
337
- images['Marker'] = images['Marker'].resize(img.size)
338
- images['Seg'] = images['Seg'].resize(img.size)
339
-
340
- if color_dapi:
341
- matrix = ( 0, 0, 0, 0,
342
- 299/1000, 587/1000, 114/1000, 0,
343
- 299/1000, 587/1000, 114/1000, 0)
344
- images['DAPI'] = images['DAPI'].convert('RGB', matrix)
345
-
346
- if color_marker:
347
- matrix = (299/1000, 587/1000, 114/1000, 0,
348
- 299/1000, 587/1000, 114/1000, 0,
349
- 0, 0, 0, 0)
350
- images['Marker'] = images['Marker'].convert('RGB', matrix)
351
-
352
- return images
353
-
354
-
355
- def postprocess(img, seg_img, thresh=80, noise_objects_size=20, small_object_size=50):
356
- mask_image = create_basic_segmentation_mask(np.array(img), np.array(seg_img),
357
- thresh, noise_objects_size, small_object_size)
358
- images = {}
359
- images['SegOverlaid'] = Image.fromarray(overlay_final_segmentation_mask(np.array(img), mask_image))
360
- images['SegRefined'] = Image.fromarray(create_final_segmentation_mask_with_boundaries(np.array(mask_image)))
361
-
362
- all_cells_no, positive_cells_no, negative_cells_no, IHC_score = compute_IHC_scoring(mask_image)
363
- scoring = {
364
- 'num_total': all_cells_no,
365
- 'num_pos': positive_cells_no,
366
- 'num_neg': negative_cells_no,
367
- 'percent_pos': IHC_score
368
- }
354
+ color_dapi=False, color_marker=False, opt=None):
355
+ if not opt:
356
+ opt = get_opt(model_path)
357
+ #print_options(opt)
358
+
359
+ if opt.model == 'DeepLIIF':
360
+ rescaled, rows, cols = format_image_for_tiling(img, tile_size, overlap_size)
361
+
362
+ run_fn = run_torchserve if use_torchserve else run_dask
363
+
364
+ images = {}
365
+ images['Hema'] = create_image_for_stitching(tile_size, rows, cols)
366
+ images['DAPI'] = create_image_for_stitching(tile_size, rows, cols)
367
+ images['Lap2'] = create_image_for_stitching(tile_size, rows, cols)
368
+ images['Marker'] = create_image_for_stitching(tile_size, rows, cols)
369
+ images['Seg'] = create_image_for_stitching(tile_size, rows, cols)
370
+
371
+ for i in range(cols):
372
+ for j in range(rows):
373
+ tile = extract_tile(rescaled, tile_size, overlap_size, i, j)
374
+ res = run_wrapper(tile, run_fn, model_path, eager_mode, opt)
375
+
376
+ stitch_tile(images['Hema'], res['G1'], tile_size, overlap_size, i, j)
377
+ stitch_tile(images['DAPI'], res['G2'], tile_size, overlap_size, i, j)
378
+ stitch_tile(images['Lap2'], res['G3'], tile_size, overlap_size, i, j)
379
+ stitch_tile(images['Marker'], res['G4'], tile_size, overlap_size, i, j)
380
+ stitch_tile(images['Seg'], res['G5'], tile_size, overlap_size, i, j)
381
+
382
+ images['Hema'] = images['Hema'].resize(img.size)
383
+ images['DAPI'] = images['DAPI'].resize(img.size)
384
+ images['Lap2'] = images['Lap2'].resize(img.size)
385
+ images['Marker'] = images['Marker'].resize(img.size)
386
+ images['Seg'] = images['Seg'].resize(img.size)
387
+
388
+ if color_dapi:
389
+ matrix = ( 0, 0, 0, 0,
390
+ 299/1000, 587/1000, 114/1000, 0,
391
+ 299/1000, 587/1000, 114/1000, 0)
392
+ images['DAPI'] = images['DAPI'].convert('RGB', matrix)
393
+
394
+ if color_marker:
395
+ matrix = (299/1000, 587/1000, 114/1000, 0,
396
+ 299/1000, 587/1000, 114/1000, 0,
397
+ 0, 0, 0, 0)
398
+ images['Marker'] = images['Marker'].convert('RGB', matrix)
399
+
400
+ return images
401
+
402
+ elif opt.model == 'DeepLIIFExt':
403
+ #param_dict = read_train_options(model_path)
404
+ #modalities_no = int(param_dict['modalities_no']) if param_dict else 4
405
+ #seg_gen = (param_dict['seg_gen'] == 'True') if param_dict else True
406
+
407
+ tiles = list(generate_tiles(img, tile_size, overlap_size))
408
+
409
+ run_fn = run_torchserve if use_torchserve else run_dask
410
+ res = [Tile(t.i, t.j, run_wrapper(t.img, run_fn, model_path, eager_mode, opt)) for t in tiles]
411
+
412
+ def get_net_tiles(n):
413
+ return [Tile(t.i, t.j, t.img[n]) for t in res]
414
+
415
+ images = {}
416
+
417
+ for i in range(1, opt.modalities_no + 1):
418
+ images['mod' + str(i)] = stitch(get_net_tiles('G_' + str(i)), tile_size, overlap_size).resize(img.size)
419
+
420
+ if opt.seg_gen:
421
+ for i in range(1, opt.modalities_no + 1):
422
+ images['Seg' + str(i)] = stitch(get_net_tiles('GS_' + str(i)), tile_size, overlap_size).resize(img.size)
423
+
424
+ return images
425
+
426
+ else:
427
+ raise Exception(f'inference() not implemented for model {opt.model}')
428
+
429
+
430
+ def postprocess(orig, images, tile_size, model, seg_thresh=150, size_thresh='auto', marker_thresh='auto', size_thresh_upper=None):
431
+ if model == 'DeepLIIF':
432
+ resolution = '40x' if tile_size > 384 else ('20x' if tile_size > 192 else '10x')
433
+ overlay, refined, scoring = compute_results(np.array(orig), np.array(images['Seg']),
434
+ np.array(images['Marker'].convert('L')) if 'Marker' in images else None,
435
+ resolution, seg_thresh, size_thresh, marker_thresh, size_thresh_upper)
436
+ processed_images = {}
437
+ processed_images['SegOverlaid'] = Image.fromarray(overlay)
438
+ processed_images['SegRefined'] = Image.fromarray(refined)
439
+ return processed_images, scoring
440
+
441
+ elif model == 'DeepLIIFExt':
442
+ resolution = '40x' if tile_size > 768 else ('20x' if tile_size > 384 else '10x')
443
+ processed_images = {}
444
+ scoring = {}
445
+ for img_name in list(images.keys()):
446
+ if 'Seg' in img_name:
447
+ seg_img = images[img_name]
448
+ overlay, refined, score = compute_results(np.array(orig), np.array(images[img_name]),
449
+ None, resolution,
450
+ seg_thresh, size_thresh, marker_thresh, size_thresh_upper)
451
+
452
+ processed_images[img_name + '_Overlaid'] = Image.fromarray(overlay)
453
+ processed_images[img_name + '_Refined'] = Image.fromarray(refined)
454
+ scoring[img_name] = score
455
+ return processed_images, scoring
369
456
 
370
- return images, scoring
457
+ else:
458
+ raise Exception(f'postprocess() not implemented for model {model}')
371
459
 
372
460
 
373
461
  def infer_modalities(img, tile_size, model_dir, eager_mode=False,
374
- color_dapi=False, color_marker=False):
462
+ color_dapi=False, color_marker=False, opt=None):
375
463
  """
376
464
  This function is used to infer modalities for the given image using a trained model.
377
465
  :param img: The input image.
@@ -379,6 +467,11 @@ def infer_modalities(img, tile_size, model_dir, eager_mode=False,
379
467
  :param model_dir: The directory containing serialized model files.
380
468
  :return: The inferred modalities and the segmentation mask.
381
469
  """
470
+ if opt is None:
471
+ opt = get_opt(model_dir)
472
+ opt.use_dp = False
473
+ #print_options(opt)
474
+
382
475
  if not tile_size:
383
476
  tile_size = check_multi_scale(Image.open('./images/target.png').convert('L'),
384
477
  img.convert('L'))
@@ -391,12 +484,16 @@ def infer_modalities(img, tile_size, model_dir, eager_mode=False,
391
484
  model_path=model_dir,
392
485
  eager_mode=eager_mode,
393
486
  color_dapi=color_dapi,
394
- color_marker=color_marker
487
+ color_marker=color_marker,
488
+ opt=opt
395
489
  )
396
-
397
- post_images, scoring = postprocess(img, images['Seg'], small_object_size=20)
398
- images = {**images, **post_images}
399
- return images, scoring
490
+
491
+ if not hasattr(opt,'seg_gen') or (hasattr(opt,'seg_gen') and opt.seg_gen): # the first condition accounts for old settings of deepliif; the second refers to deepliifext models
492
+ post_images, scoring = postprocess(img, images, tile_size, opt.model)
493
+ images = {**images, **post_images}
494
+ return images, scoring
495
+ else:
496
+ return images, None
400
497
 
401
498
 
402
499
  def infer_results_for_wsi(input_dir, filename, output_dir, model_dir, tile_size, region_size=20000):
@@ -4,6 +4,8 @@ from collections import OrderedDict
4
4
  from abc import ABC, abstractmethod
5
5
  from . import networks
6
6
  from ..util import disable_batchnorm_tracking_stats
7
+ from deepliif.util import *
8
+ import itertools
7
9
 
8
10
 
9
11
  class BaseModel(ABC):
@@ -35,8 +37,9 @@ class BaseModel(ABC):
35
37
  self.is_train = opt.is_train
36
38
  self.device = torch.device('cuda:{}'.format(self.gpu_ids[0])) if self.gpu_ids else torch.device('cpu') # get device name: CPU or GPU
37
39
  self.save_dir = os.path.join(opt.checkpoints_dir, opt.name) # save all the checkpoints to save_dir
38
- if opt.preprocess != 'scale_width': # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
40
+ if opt.phase == 'train' and opt.preprocess != 'scale_width': # with [scale_width], input images might have different sizes, which hurts the performance of cudnn.benchmark.
39
41
  torch.backends.cudnn.benchmark = True
42
+ # especially for inference, cudnn benchmark can cause excessive usage of GPU memory for the first image in the sequence in order to find the best conv alg which is not necessary.
40
43
  self.loss_names = []
41
44
  self.model_names = []
42
45
  self.visual_names = []
@@ -89,7 +92,10 @@ class BaseModel(ABC):
89
92
  """Make models eval mode during test time"""
90
93
  for name in self.model_names:
91
94
  if isinstance(name, str):
92
- net = getattr(self, 'net' + name)
95
+ if '_' in name:
96
+ net = getattr(self, 'net' + name.split('_')[0])[int(name.split('_')[-1]) - 1]
97
+ else:
98
+ net = getattr(self, 'net' + name)
93
99
  net.eval()
94
100
  net = disable_batchnorm_tracking_stats(net)
95
101
 
@@ -127,7 +133,13 @@ class BaseModel(ABC):
127
133
  visual_ret = OrderedDict()
128
134
  for name in self.visual_names:
129
135
  if isinstance(name, str):
130
- visual_ret[name] = getattr(self, name)
136
+ if not hasattr(self, name):
137
+ if len(name.split('_')) == 2:
138
+ visual_ret[name] = getattr(self, name.split('_')[0])[int(name.split('_')[-1]) -1]
139
+ else:
140
+ visual_ret[name] = getattr(self, name.split('_')[0] + '_' + name.split('_')[1])[int(name.split('_')[-1]) - 1]
141
+ else:
142
+ visual_ret[name] = getattr(self, name)
131
143
  return visual_ret
132
144
 
133
145
  def get_current_losses(self):
@@ -135,7 +147,16 @@ class BaseModel(ABC):
135
147
  errors_ret = OrderedDict()
136
148
  for name in self.loss_names:
137
149
  if isinstance(name, str):
138
- errors_ret[name] = float(getattr(self, 'loss_' + name)) # float(...) works for both scalar tensor and float number
150
+ if not hasattr(self, 'loss_'+name): # appears in DeepLIIFExt
151
+ if len(name.split('_')) == 2:
152
+ errors_ret[name] = float(getattr(self, 'loss_' + name.split('_')[0])[int(
153
+ name.split('_')[-1]) - 1]) # float(...) works for both scalar tensor and float number
154
+ else:
155
+ errors_ret[name] = float(getattr(self, 'loss_' + name.split('_')[0] + '_' + name.split('_')[1])[int(
156
+ name.split('_')[-1]) - 1]) # float(...) works for both scalar tensor and float number
157
+ else: # single numeric value
158
+ errors_ret[name] = float(getattr(self, 'loss_' + name))
159
+
139
160
  return errors_ret
140
161
 
141
162
  def save_networks(self, epoch, save_from_one_process=False):
@@ -151,7 +172,10 @@ class BaseModel(ABC):
151
172
  if isinstance(name, str):
152
173
  save_filename = '%s_net_%s.pth' % (epoch, name)
153
174
  save_path = os.path.join(self.save_dir, save_filename)
154
- net = getattr(self, 'net' + name)
175
+ if '_' in name:
176
+ net = getattr(self, 'net' + name.split('_')[0])[int(name.split('_')[-1]) - 1]
177
+ else:
178
+ net = getattr(self, 'net' + name)
155
179
 
156
180
  if len(self.gpu_ids) > 0 and torch.cuda.is_available():
157
181
  torch.save(net.module.cpu().state_dict(), save_path)
@@ -219,13 +243,32 @@ class BaseModel(ABC):
219
243
  if isinstance(name, str):
220
244
  load_filename = '%s_net_%s.pth' % (epoch, name)
221
245
  load_path = os.path.join(self.save_dir, load_filename)
222
- net = getattr(self, 'net' + name)
246
+ if '_' in name:
247
+ net = getattr(self, 'net' + name.split('_')[0])[int(name.split('_')[-1]) - 1]
248
+ else:
249
+ net = getattr(self, 'net' + name)
223
250
  if isinstance(net, torch.nn.DataParallel):
224
251
  net = net.module
252
+
253
+ self.set_requires_grad(net,self.opt.is_train)
254
+ # check if gradients are disabled
255
+ names_layer_requires_grad = []
256
+ for name, param in net.named_parameters():
257
+ if param.requires_grad:
258
+ names_layer_requires_grad.append(name)
259
+
225
260
  print('loading the model from %s' % load_path)
226
261
  # if you are using PyTorch newer than 0.4 (e.g., built from
227
262
  # GitHub source), you can remove str() on self.device
228
- state_dict = torch.load(load_path, map_location=str(self.device))
263
+
264
+ if self.opt.is_train or self.opt.use_dp:
265
+ device = self.device
266
+ else:
267
+ device = torch.device('cpu') # load in cpu first; later in __inite__.py::init_nets we will move it to the specified device
268
+
269
+ net.to(device)
270
+ state_dict = torch.load(load_path, map_location=str(device))
271
+
229
272
  if hasattr(state_dict, '_metadata'):
230
273
  del state_dict._metadata
231
274
 
@@ -243,7 +286,10 @@ class BaseModel(ABC):
243
286
  print('---------- Networks initialized -------------')
244
287
  for name in self.model_names:
245
288
  if isinstance(name, str):
246
- net = getattr(self, 'net' + name)
289
+ if '_' in name:
290
+ net = getattr(self, 'net' + name.split('_')[0])[int(name.split('_')[-1]) - 1]
291
+ else:
292
+ net = getattr(self, 'net' + name)
247
293
  num_params = 0
248
294
  for param in net.parameters():
249
295
  num_params += param.numel()