deepliif 1.1.9__py3-none-any.whl → 1.1.11__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli.py +49 -42
- deepliif/data/aligned_dataset.py +17 -0
- deepliif/models/SDG_model.py +189 -0
- deepliif/models/__init__.py +170 -46
- deepliif/options/__init__.py +62 -29
- deepliif/util/__init__.py +227 -0
- deepliif/util/util.py +17 -1
- {deepliif-1.1.9.dist-info → deepliif-1.1.11.dist-info}/METADATA +181 -27
- {deepliif-1.1.9.dist-info → deepliif-1.1.11.dist-info}/RECORD +13 -12
- {deepliif-1.1.9.dist-info → deepliif-1.1.11.dist-info}/LICENSE.md +0 -0
- {deepliif-1.1.9.dist-info → deepliif-1.1.11.dist-info}/WHEEL +0 -0
- {deepliif-1.1.9.dist-info → deepliif-1.1.11.dist-info}/entry_points.txt +0 -0
- {deepliif-1.1.9.dist-info → deepliif-1.1.11.dist-info}/top_level.txt +0 -0
deepliif/models/__init__.py
CHANGED
|
@@ -133,6 +133,10 @@ def load_eager_models(opt, devices):
|
|
|
133
133
|
net.eval()
|
|
134
134
|
net = disable_batchnorm_tracking_stats(net)
|
|
135
135
|
|
|
136
|
+
# SDG models when loaded are still DP.. not sure why
|
|
137
|
+
if isinstance(net, torch.nn.DataParallel):
|
|
138
|
+
net = net.module
|
|
139
|
+
|
|
136
140
|
nets[name] = net
|
|
137
141
|
nets[name].to(devices[name])
|
|
138
142
|
|
|
@@ -161,7 +165,7 @@ def init_nets(model_dir, eager_mode=False, opt=None, phase='test'):
|
|
|
161
165
|
('G4', 'G55'),
|
|
162
166
|
('G51',)
|
|
163
167
|
]
|
|
164
|
-
elif opt.model
|
|
168
|
+
elif opt.model in ['DeepLIIFExt','SDG']:
|
|
165
169
|
if opt.seg_gen:
|
|
166
170
|
net_groups = [(f'G_{i+1}',f'GS_{i+1}') for i in range(opt.modalities_no)]
|
|
167
171
|
else:
|
|
@@ -172,6 +176,7 @@ def init_nets(model_dir, eager_mode=False, opt=None, phase='test'):
|
|
|
172
176
|
number_of_gpus_all = torch.cuda.device_count()
|
|
173
177
|
number_of_gpus = len(opt.gpu_ids)
|
|
174
178
|
#print(number_of_gpus)
|
|
179
|
+
|
|
175
180
|
if number_of_gpus > 0:
|
|
176
181
|
mapping_gpu_ids = {i:idx for i,idx in enumerate(opt.gpu_ids)}
|
|
177
182
|
chunks = [itertools.chain.from_iterable(c) for c in chunker(net_groups, number_of_gpus)]
|
|
@@ -206,10 +211,7 @@ def run_torchserve(img, model_path=None, eager_mode=False, opt=None):
|
|
|
206
211
|
opt: same as eager_mode
|
|
207
212
|
"""
|
|
208
213
|
buffer = BytesIO()
|
|
209
|
-
|
|
210
|
-
torch.save(transform(img.resize((1024, 1024))), buffer)
|
|
211
|
-
else:
|
|
212
|
-
torch.save(transform(img.resize((512, 512))), buffer)
|
|
214
|
+
torch.save(transform(img.resize((opt.scale_size, opt.scale_size))), buffer)
|
|
213
215
|
|
|
214
216
|
torchserve_host = os.getenv('TORCHSERVE_HOST', 'http://localhost')
|
|
215
217
|
res = requests.post(
|
|
@@ -229,10 +231,12 @@ def run_dask(img, model_path, eager_mode=False, opt=None):
|
|
|
229
231
|
model_dir = os.getenv('DEEPLIIF_MODEL_DIR', model_path)
|
|
230
232
|
nets = init_nets(model_dir, eager_mode, opt)
|
|
231
233
|
|
|
232
|
-
if opt.model == '
|
|
233
|
-
|
|
234
|
+
if opt.input_no > 1 or opt.model == 'SDG':
|
|
235
|
+
l_ts = [transform(img_i.resize((opt.scale_size,opt.scale_size))) for img_i in img]
|
|
236
|
+
ts = torch.cat(l_ts, dim=1)
|
|
234
237
|
else:
|
|
235
|
-
ts = transform(img.resize((
|
|
238
|
+
ts = transform(img.resize((opt.scale_size, opt.scale_size)))
|
|
239
|
+
|
|
236
240
|
|
|
237
241
|
@delayed
|
|
238
242
|
def forward(input, model):
|
|
@@ -249,14 +253,20 @@ def run_dask(img, model_path, eager_mode=False, opt=None):
|
|
|
249
253
|
lazy_segs['G51'] = forward(ts, nets['G51']).to(torch.device('cpu'))
|
|
250
254
|
segs = compute(lazy_segs)[0]
|
|
251
255
|
|
|
252
|
-
|
|
253
|
-
|
|
256
|
+
weights = {
|
|
257
|
+
'G51': 0.25, # IHC
|
|
258
|
+
'G52': 0.25, # Hema
|
|
259
|
+
'G53': 0.25, # DAPI
|
|
260
|
+
'G54': 0.00, # Lap2
|
|
261
|
+
'G55': 0.25, # Marker
|
|
262
|
+
}
|
|
263
|
+
seg = torch.stack([torch.mul(segs[k], weights[k]) for k in segs.keys()]).sum(dim=0)
|
|
254
264
|
|
|
255
265
|
res = {k: tensor_to_pil(v) for k, v in gens.items()}
|
|
256
266
|
res['G5'] = tensor_to_pil(seg)
|
|
257
267
|
|
|
258
268
|
return res
|
|
259
|
-
elif opt.model
|
|
269
|
+
elif opt.model in ['DeepLIIFExt','SDG']:
|
|
260
270
|
seg_map = {'G_' + str(i): 'GS_' + str(i) for i in range(1, opt.modalities_no + 1)}
|
|
261
271
|
|
|
262
272
|
lazy_gens = {k: forward(ts, nets[k]) for k in seg_map}
|
|
@@ -271,14 +281,23 @@ def run_dask(img, model_path, eager_mode=False, opt=None):
|
|
|
271
281
|
|
|
272
282
|
return res
|
|
273
283
|
else:
|
|
274
|
-
raise Exception(f'run_dask() not implemented for {opt.model}')
|
|
284
|
+
raise Exception(f'run_dask() not fully implemented for {opt.model}')
|
|
275
285
|
|
|
276
|
-
|
|
286
|
+
|
|
287
|
+
def is_empty_old(tile):
|
|
288
|
+
# return True if np.mean(np.array(tile) - np.array(mean_background_val)) < 40 else False
|
|
289
|
+
if isinstance(tile, list): # for pair of tiles, only mark it as empty / no need for prediction if ALL tiles are empty
|
|
290
|
+
return all([True if calculate_background_area(t) > 98 else False for t in tile])
|
|
291
|
+
else:
|
|
292
|
+
return True if calculate_background_area(tile) > 98 else False
|
|
277
293
|
|
|
278
294
|
|
|
279
295
|
def is_empty(tile):
|
|
280
|
-
|
|
281
|
-
|
|
296
|
+
thresh = 15
|
|
297
|
+
if isinstance(tile, list): # for pair of tiles, only mark it as empty / no need for prediction if ALL tiles are empty
|
|
298
|
+
return all([True if np.max(image_variance_rgb(t)) < thresh else False for t in tile])
|
|
299
|
+
else:
|
|
300
|
+
return True if np.max(image_variance_rgb(tile)) < thresh else False
|
|
282
301
|
|
|
283
302
|
|
|
284
303
|
def run_wrapper(tile, run_fn, model_path, eager_mode=False, opt=None):
|
|
@@ -293,7 +312,7 @@ def run_wrapper(tile, run_fn, model_path, eager_mode=False, opt=None):
|
|
|
293
312
|
}
|
|
294
313
|
else:
|
|
295
314
|
return run_fn(tile, model_path, eager_mode, opt)
|
|
296
|
-
elif opt.model
|
|
315
|
+
elif opt.model in ['DeepLIIFExt', 'SDG']:
|
|
297
316
|
if is_empty(tile):
|
|
298
317
|
res = {'G_' + str(i): Image.new(mode='RGB', size=(512, 512)) for i in range(1, opt.modalities_no + 1)}
|
|
299
318
|
res.update({'GS_' + str(i): Image.new(mode='RGB', size=(512, 512)) for i in range(1, opt.modalities_no + 1)})
|
|
@@ -350,8 +369,8 @@ def inference_old(img, tile_size, overlap_size, model_path, use_torchserve=False
|
|
|
350
369
|
return images
|
|
351
370
|
|
|
352
371
|
|
|
353
|
-
def
|
|
354
|
-
|
|
372
|
+
def inference_old2(img, tile_size, overlap_size, model_path, use_torchserve=False, eager_mode=False,
|
|
373
|
+
color_dapi=False, color_marker=False, opt=None):
|
|
355
374
|
if not opt:
|
|
356
375
|
opt = get_opt(model_path)
|
|
357
376
|
#print_options(opt)
|
|
@@ -362,28 +381,25 @@ def inference(img, tile_size, overlap_size, model_path, use_torchserve=False, ea
|
|
|
362
381
|
run_fn = run_torchserve if use_torchserve else run_dask
|
|
363
382
|
|
|
364
383
|
images = {}
|
|
365
|
-
|
|
366
|
-
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
384
|
+
d_modality2net = {'Hema':'G1',
|
|
385
|
+
'DAPI':'G2',
|
|
386
|
+
'Lap2':'G3',
|
|
387
|
+
'Marker':'G4',
|
|
388
|
+
'Seg':'G5'}
|
|
389
|
+
|
|
390
|
+
for k in d_modality2net.keys():
|
|
391
|
+
images[k] = create_image_for_stitching(tile_size, rows, cols)
|
|
370
392
|
|
|
371
393
|
for i in range(cols):
|
|
372
394
|
for j in range(rows):
|
|
373
395
|
tile = extract_tile(rescaled, tile_size, overlap_size, i, j)
|
|
374
396
|
res = run_wrapper(tile, run_fn, model_path, eager_mode, opt)
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
images['Hema'] = images['Hema'].resize(img.size)
|
|
383
|
-
images['DAPI'] = images['DAPI'].resize(img.size)
|
|
384
|
-
images['Lap2'] = images['Lap2'].resize(img.size)
|
|
385
|
-
images['Marker'] = images['Marker'].resize(img.size)
|
|
386
|
-
images['Seg'] = images['Seg'].resize(img.size)
|
|
397
|
+
|
|
398
|
+
for modality_name, net_name in d_modality2net.items():
|
|
399
|
+
stitch_tile(images[modality_name], res[net_name], tile_size, overlap_size, i, j)
|
|
400
|
+
|
|
401
|
+
for modality_name, output_img in images.items():
|
|
402
|
+
images[modality_name] = output_img.resize(img.size)
|
|
387
403
|
|
|
388
404
|
if color_dapi:
|
|
389
405
|
matrix = ( 0, 0, 0, 0,
|
|
@@ -403,30 +419,133 @@ def inference(img, tile_size, overlap_size, model_path, use_torchserve=False, ea
|
|
|
403
419
|
#param_dict = read_train_options(model_path)
|
|
404
420
|
#modalities_no = int(param_dict['modalities_no']) if param_dict else 4
|
|
405
421
|
#seg_gen = (param_dict['seg_gen'] == 'True') if param_dict else True
|
|
406
|
-
|
|
407
|
-
|
|
408
|
-
|
|
422
|
+
|
|
423
|
+
|
|
424
|
+
rescaled, rows, cols = format_image_for_tiling(img, tile_size, overlap_size)
|
|
409
425
|
run_fn = run_torchserve if use_torchserve else run_dask
|
|
410
|
-
res = [Tile(t.i, t.j, run_wrapper(t.img, run_fn, model_path, eager_mode, opt)) for t in tiles]
|
|
411
426
|
|
|
412
427
|
def get_net_tiles(n):
|
|
413
428
|
return [Tile(t.i, t.j, t.img[n]) for t in res]
|
|
414
429
|
|
|
415
430
|
images = {}
|
|
416
|
-
|
|
417
|
-
for i in range(1, opt.modalities_no + 1):
|
|
418
|
-
images['mod' + str(i)] = stitch(get_net_tiles('G_' + str(i)), tile_size, overlap_size).resize(img.size)
|
|
419
|
-
|
|
431
|
+
d_modality2net = {f'mod{i}':f'G_{i}' for i in range(1, opt.modalities_no + 1)}
|
|
420
432
|
if opt.seg_gen:
|
|
421
|
-
for i in range(1, opt.modalities_no + 1)
|
|
422
|
-
|
|
433
|
+
d_modality2net.update({f'Seg{i}':f'GS_{i}' for i in range(1, opt.modalities_no + 1)})
|
|
434
|
+
|
|
435
|
+
for k in d_modality2net.keys():
|
|
436
|
+
images[k] = create_image_for_stitching(tile_size, rows, cols)
|
|
423
437
|
|
|
438
|
+
for i in range(cols):
|
|
439
|
+
for j in range(rows):
|
|
440
|
+
tile = extract_tile(rescaled, tile_size, overlap_size, i, j)
|
|
441
|
+
res = run_wrapper(tile, run_fn, model_path, eager_mode, opt)
|
|
442
|
+
|
|
443
|
+
for modality_name, net_name in d_modality2net.items():
|
|
444
|
+
stitch_tile(images[modality_name], res[net_name], tile_size, overlap_size, i, j)
|
|
445
|
+
|
|
446
|
+
for modality_name, output_img in images.items():
|
|
447
|
+
images[modality_name] = output_img.resize(img.size)
|
|
448
|
+
|
|
449
|
+
return images
|
|
450
|
+
|
|
451
|
+
elif opt.model == 'SDG':
|
|
452
|
+
# SDG could have multiple input images / modalities
|
|
453
|
+
# the input hence could be a rectangle
|
|
454
|
+
# we split the input to get each modality image one by one
|
|
455
|
+
# then create tiles for each of the modality images
|
|
456
|
+
# tile_pair is a list that contains the tiles at the given location for each modality image
|
|
457
|
+
# l_tile_pair is a list of tile_pair that covers all locations
|
|
458
|
+
# for inference, each tile_pair is used to get the output at the given location
|
|
459
|
+
w, h = img.size
|
|
460
|
+
w2 = int(w / opt.input_no)
|
|
461
|
+
|
|
462
|
+
l_img = []
|
|
463
|
+
for i in range(opt.input_no):
|
|
464
|
+
img_i = img.crop((w2 * i, 0, w2 * (i+1), h))
|
|
465
|
+
rescaled_img_i, rows, cols = format_image_for_tiling(img_i, tile_size, overlap_size)
|
|
466
|
+
l_img.append(rescaled_img_i)
|
|
467
|
+
|
|
468
|
+
run_fn = run_torchserve if use_torchserve else run_dask
|
|
469
|
+
|
|
470
|
+
images = {}
|
|
471
|
+
d_modality2net = {f'mod{i}':f'G_{i}' for i in range(1, opt.modalities_no + 1)}
|
|
472
|
+
for k in d_modality2net.keys():
|
|
473
|
+
images[k] = create_image_for_stitching(tile_size, rows, cols)
|
|
474
|
+
|
|
475
|
+
for i in range(cols):
|
|
476
|
+
for j in range(rows):
|
|
477
|
+
tile_pair = [extract_tile(rescaled, tile_size, overlap_size, i, j) for rescaled in l_img]
|
|
478
|
+
res = run_wrapper(tile_pair, run_fn, model_path, eager_mode, opt)
|
|
479
|
+
|
|
480
|
+
for modality_name, net_name in d_modality2net.items():
|
|
481
|
+
stitch_tile(images[modality_name], res[net_name], tile_size, overlap_size, i, j)
|
|
482
|
+
|
|
483
|
+
for modality_name, output_img in images.items():
|
|
484
|
+
images[modality_name] = output_img.resize((w2,w2))
|
|
485
|
+
|
|
424
486
|
return images
|
|
425
487
|
|
|
426
488
|
else:
|
|
427
489
|
raise Exception(f'inference() not implemented for model {opt.model}')
|
|
428
490
|
|
|
429
491
|
|
|
492
|
+
def inference(img, tile_size, overlap_size, model_path, use_torchserve=False,
|
|
493
|
+
eager_mode=False, color_dapi=False, color_marker=False, opt=None):
|
|
494
|
+
if not opt:
|
|
495
|
+
opt = get_opt(model_path)
|
|
496
|
+
#print_options(opt)
|
|
497
|
+
|
|
498
|
+
run_fn = run_torchserve if use_torchserve else run_dask
|
|
499
|
+
|
|
500
|
+
if opt.model == 'SDG':
|
|
501
|
+
# SDG could have multiple input images/modalities, hence the input could be a rectangle.
|
|
502
|
+
# We split the input to get each modality image then create tiles for each set of input images.
|
|
503
|
+
w, h = int(img.width / opt.input_no), img.height
|
|
504
|
+
orig = [img.crop((w * i, 0, w * (i+1), h)) for i in range(opt.input_no)]
|
|
505
|
+
else:
|
|
506
|
+
# Otherwise expect a single input image, which is used directly.
|
|
507
|
+
orig = img
|
|
508
|
+
|
|
509
|
+
tiler = InferenceTiler(orig, tile_size, overlap_size)
|
|
510
|
+
for tile in tiler:
|
|
511
|
+
tiler.stitch(run_wrapper(tile, run_fn, model_path, eager_mode, opt))
|
|
512
|
+
results = tiler.results()
|
|
513
|
+
|
|
514
|
+
if opt.model == 'DeepLIIF':
|
|
515
|
+
images = {
|
|
516
|
+
'Hema': results['G1'],
|
|
517
|
+
'DAPI': results['G2'],
|
|
518
|
+
'Lap2': results['G3'],
|
|
519
|
+
'Marker': results['G4'],
|
|
520
|
+
'Seg': results['G5'],
|
|
521
|
+
}
|
|
522
|
+
if color_dapi:
|
|
523
|
+
matrix = ( 0, 0, 0, 0,
|
|
524
|
+
299/1000, 587/1000, 114/1000, 0,
|
|
525
|
+
299/1000, 587/1000, 114/1000, 0)
|
|
526
|
+
images['DAPI'] = images['DAPI'].convert('RGB', matrix)
|
|
527
|
+
if color_marker:
|
|
528
|
+
matrix = (299/1000, 587/1000, 114/1000, 0,
|
|
529
|
+
299/1000, 587/1000, 114/1000, 0,
|
|
530
|
+
0, 0, 0, 0)
|
|
531
|
+
images['Marker'] = images['Marker'].convert('RGB', matrix)
|
|
532
|
+
return images
|
|
533
|
+
|
|
534
|
+
elif opt.model == 'DeepLIIFExt':
|
|
535
|
+
images = {f'mod{i}': results[f'G_{i}'] for i in range(1, opt.modalities_no + 1)}
|
|
536
|
+
if opt.seg_gen:
|
|
537
|
+
images.update({f'Seg{i}': results[f'GS_{i}'] for i in range(1, opt.modalities_no + 1)})
|
|
538
|
+
return images
|
|
539
|
+
|
|
540
|
+
elif opt.model == 'SDG':
|
|
541
|
+
images = {f'mod{i}': results[f'G_{i}'] for i in range(1, opt.modalities_no + 1)}
|
|
542
|
+
return images
|
|
543
|
+
|
|
544
|
+
else:
|
|
545
|
+
#raise Exception(f'inference() not implemented for model {opt.model}')
|
|
546
|
+
return results # return result images with default key names (i.e., net names)
|
|
547
|
+
|
|
548
|
+
|
|
430
549
|
def postprocess(orig, images, tile_size, model, seg_thresh=150, size_thresh='auto', marker_thresh='auto', size_thresh_upper=None):
|
|
431
550
|
if model == 'DeepLIIF':
|
|
432
551
|
resolution = '40x' if tile_size > 384 else ('20x' if tile_size > 192 else '10x')
|
|
@@ -476,11 +595,16 @@ def infer_modalities(img, tile_size, model_dir, eager_mode=False,
|
|
|
476
595
|
tile_size = check_multi_scale(Image.open('./images/target.png').convert('L'),
|
|
477
596
|
img.convert('L'))
|
|
478
597
|
tile_size = int(tile_size)
|
|
598
|
+
|
|
599
|
+
# for those with multiple input modalities, find the correct size to calculate overlap_size
|
|
600
|
+
input_no = opt.input_no if hasattr(opt, 'input_no') else 1
|
|
601
|
+
img_size = (img.size[0] / input_no, img.size[1]) # (width, height)
|
|
479
602
|
|
|
480
603
|
images = inference(
|
|
481
604
|
img,
|
|
482
605
|
tile_size=tile_size,
|
|
483
|
-
overlap_size=compute_overlap(
|
|
606
|
+
#overlap_size=compute_overlap(img_size, tile_size),
|
|
607
|
+
overlap_size=tile_size//16,
|
|
484
608
|
model_path=model_dir,
|
|
485
609
|
eager_mode=eager_mode,
|
|
486
610
|
color_dapi=color_dapi,
|
deepliif/options/__init__.py
CHANGED
|
@@ -13,7 +13,6 @@ def read_model_params(file_addr):
|
|
|
13
13
|
key = line.split(':')[0].strip()
|
|
14
14
|
val = line.split(':')[1].split('[')[0].strip()
|
|
15
15
|
param_dict[key] = val
|
|
16
|
-
print(param_dict)
|
|
17
16
|
return param_dict
|
|
18
17
|
|
|
19
18
|
class Options:
|
|
@@ -32,19 +31,7 @@ class Options:
|
|
|
32
31
|
else:
|
|
33
32
|
setattr(self,k,v)
|
|
34
33
|
except:
|
|
35
|
-
setattr(self,k,v)
|
|
36
|
-
|
|
37
|
-
if mode != 'train':
|
|
38
|
-
# to account for old settings where gpu_ids value is an integer, not a tuple
|
|
39
|
-
if isinstance(self.gpu_ids,int):
|
|
40
|
-
self.gpu_ids = (self.gpu_ids,)
|
|
41
|
-
|
|
42
|
-
# to account for old settings before modalities_no was introduced
|
|
43
|
-
if not hasattr(self,'modalities_no') and hasattr(self,'targets_no'):
|
|
44
|
-
self.modalities_no = self.targets_no - 1
|
|
45
|
-
del self.targets_no
|
|
46
|
-
|
|
47
|
-
|
|
34
|
+
setattr(self,k,v)
|
|
48
35
|
|
|
49
36
|
if mode == 'train':
|
|
50
37
|
self.is_train = True
|
|
@@ -71,32 +58,78 @@ class Options:
|
|
|
71
58
|
self.checkpoints_dir = str(model_dir.parent)
|
|
72
59
|
self.name = str(model_dir.name)
|
|
73
60
|
|
|
74
|
-
self.gpu_ids = [] # gpu_ids is only used by eager mode, set to empty / cpu to be the same as the old settings; non-eager mode will use all gpus
|
|
61
|
+
#self.gpu_ids = [] # gpu_ids is only used by eager mode, set to empty / cpu to be the same as the old settings; non-eager mode will use all gpus
|
|
62
|
+
|
|
63
|
+
# to account for old settings where gpu_ids value is an integer, not a tuple
|
|
64
|
+
if isinstance(self.gpu_ids,int):
|
|
65
|
+
self.gpu_ids = (self.gpu_ids,)
|
|
66
|
+
|
|
67
|
+
# to account for old settings before modalities_no was introduced
|
|
68
|
+
if not hasattr(self,'modalities_no') and hasattr(self,'targets_no'):
|
|
69
|
+
self.modalities_no = self.targets_no - 1
|
|
70
|
+
del self.targets_no
|
|
71
|
+
|
|
72
|
+
# to account for old settings: same as in cli.py train
|
|
73
|
+
if not hasattr(self,'seg_no'):
|
|
74
|
+
if self.model == 'DeepLIIF':
|
|
75
|
+
self.seg_no = 1
|
|
76
|
+
elif self.model == 'DeepLIIFExt':
|
|
77
|
+
if self.seg_gen:
|
|
78
|
+
self.seg_no = self.modalities_no
|
|
79
|
+
else:
|
|
80
|
+
self.seg_no = 0
|
|
81
|
+
elif self.model == 'SDG':
|
|
82
|
+
self.seg_no = 0
|
|
83
|
+
self.seg_gen = False
|
|
84
|
+
else:
|
|
85
|
+
raise Exception(f'seg_gen cannot be automatically determined for {opt.model}')
|
|
86
|
+
|
|
87
|
+
# to account for old settings: prior to SDG, our models only have 1 input image
|
|
88
|
+
if not hasattr(self,'input_no'):
|
|
89
|
+
self.input_no = 1
|
|
90
|
+
|
|
91
|
+
# to account for old settings: before adding scale_size
|
|
92
|
+
if not hasattr(self, 'scale_size'):
|
|
93
|
+
if self.model in ['DeepLIIF','SDG']:
|
|
94
|
+
self.scale_size = 512
|
|
95
|
+
elif self.model == 'DeepLIIFExt':
|
|
96
|
+
self.scale_size = 1024
|
|
97
|
+
else:
|
|
98
|
+
raise Exception(f'scale_size cannot be automatically determined for {opt.model}')
|
|
99
|
+
|
|
100
|
+
|
|
75
101
|
|
|
76
102
|
def _get_kwargs(self):
|
|
77
103
|
common_attr = ['__class__', '__delattr__', '__dict__', '__dir__', '__doc__', '__eq__', '__format__', '__ge__', '__getattribute__', '__gt__', '__hash__', '__init__', '__init_subclass__', '__le__', '__lt__', '__module__', '__ne__', '__new__', '__reduce__', '__reduce_ex__', '__repr__', '__setattr__', '__sizeof__', '__str__', '__subclasshook__', '__weakref__']
|
|
78
104
|
l_args = [x for x in dir(self) if x not in common_attr]
|
|
79
105
|
return {k:getattr(self,k) for k in l_args}
|
|
80
|
-
|
|
81
|
-
def print_options(opt):
|
|
82
|
-
"""Print and save options
|
|
83
106
|
|
|
84
|
-
|
|
85
|
-
It will save options into a text file / [checkpoints_dir] / opt.txt
|
|
86
|
-
"""
|
|
107
|
+
def format_options(opt):
|
|
87
108
|
message = ''
|
|
88
109
|
message += '----------------- Options ---------------\n'
|
|
89
110
|
for k, v in sorted(vars(opt).items()):
|
|
90
111
|
comment = ''
|
|
91
112
|
message += '{:>25}: {:<30}{}\n'.format(str(k), str(v), comment)
|
|
92
113
|
message += '----------------- End -------------------'
|
|
93
|
-
|
|
114
|
+
return message
|
|
115
|
+
|
|
116
|
+
def print_options(opt, save=False):
|
|
117
|
+
"""Print (and save) options
|
|
94
118
|
|
|
119
|
+
It will print both current options and default values(if different).
|
|
120
|
+
If save=True, it will save options into a text file / [checkpoints_dir] / opt.txt
|
|
121
|
+
"""
|
|
122
|
+
message = format_options(opt)
|
|
123
|
+
print(message)
|
|
124
|
+
|
|
95
125
|
# save to the disk
|
|
96
|
-
if
|
|
97
|
-
|
|
98
|
-
|
|
99
|
-
|
|
100
|
-
|
|
101
|
-
|
|
102
|
-
|
|
126
|
+
if save:
|
|
127
|
+
save_options(opt)
|
|
128
|
+
|
|
129
|
+
def save_options(opt):
|
|
130
|
+
message = format_options(opt)
|
|
131
|
+
expr_dir = os.path.join(opt.checkpoints_dir, opt.name)
|
|
132
|
+
mkdirs(expr_dir)
|
|
133
|
+
file_name = os.path.join(expr_dir, '{}_opt.txt'.format(opt.phase))
|
|
134
|
+
with open(file_name, 'wt') as opt_file:
|
|
135
|
+
opt_file.write(message+'\n')
|