deepliif-1.1.5-py3-none-any.whl → deepliif-1.1.7-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli.py +52 -9
- deepliif/models/__init__.py +77 -15
- deepliif/models/base_model.py +2 -0
- deepliif/util/__init__.py +66 -0
- {deepliif-1.1.5.dist-info → deepliif-1.1.7.dist-info}/LICENSE.md +0 -0
- {deepliif-1.1.5.dist-info → deepliif-1.1.7.dist-info}/METADATA +408 -403
- {deepliif-1.1.5.dist-info → deepliif-1.1.7.dist-info}/RECORD +10 -10
- {deepliif-1.1.5.dist-info → deepliif-1.1.7.dist-info}/WHEEL +1 -1
- {deepliif-1.1.5.dist-info → deepliif-1.1.7.dist-info}/entry_points.txt +0 -0
- {deepliif-1.1.5.dist-info → deepliif-1.1.7.dist-info}/top_level.txt +0 -0
cli.py
CHANGED
@@ -11,7 +11,7 @@ from PIL import Image
 
 from deepliif.data import create_dataset, transform
 from deepliif.models import inference, postprocess, compute_overlap, init_nets, DeepLIIFModel, infer_modalities, infer_results_for_wsi
-from deepliif.util import allowed_file, Visualizer, get_information
+from deepliif.util import allowed_file, Visualizer, get_information, test_diff_original_serialized, disable_batchnorm_tracking_stats
 from deepliif.util.util import mkdirs, check_multi_scale
 # from deepliif.util import infer_results_for_wsi
 
@@ -461,21 +461,51 @@ def trainlaunch(**kwargs):
 @cli.command()
 @click.option('--models-dir', default='./model-server/DeepLIIF_Latest_Model', help='reads models from here')
 @click.option('--output-dir', help='saves results here.')
-def serialize(models_dir, output_dir):
+@click.option('--device', default='cpu', type=str, help='device to load model, either cpu or gpu')
+@click.option('--verbose', default=0, type=int,help='saves results here.')
+def serialize(models_dir, output_dir, device, verbose):
     """Serialize DeepLIIF models using Torchscript
     """
     output_dir = output_dir or models_dir
+    ensure_exists(output_dir)
 
     sample = transform(Image.new('RGB', (512, 512)))
-
+
     with click.progressbar(
             init_nets(models_dir, eager_mode=True).items(),
             label='Tracing nets',
             item_show_func=lambda n: n[0] if n else n
     ) as bar:
         for name, net in bar:
-
+            # the model should be in eval model so that there won't be randomness in tracking brought by dropout etc. layers
+            # https://github.com/pytorch/pytorch/issues/23999#issuecomment-747832122
+            net = net.eval()
+            net = disable_batchnorm_tracking_stats(net)
+            net = net.cpu()
+            if name.startswith('GS'):
+                traced_net = torch.jit.trace(net, torch.cat([sample, sample, sample], 1))
+            else:
+                traced_net = torch.jit.trace(net, sample)
+            # traced_net = torch.jit.script(net)
             traced_net.save(f'{output_dir}/{name}.pt')
+
+    # test: whether the original and the serialized model produces highly similar predictions
+    print('testing similarity between prediction from original vs serialized models...')
+    models_original = init_nets(models_dir,eager_mode=True)
+    models_serialized = init_nets(output_dir,eager_mode=False)
+    if device == 'gpu':
+        sample = sample.cuda()
+    else:
+        sample = sample.cpu()
+    for name in models_serialized.keys():
+        print(name,':')
+        model_original = models_original[name].cuda().eval() if device=='gpu' else models_original[name].cpu().eval()
+        model_serialized = models_serialized[name].cuda() if device=='gpu' else models_serialized[name].cpu().eval()
+        if name.startswith('GS'):
+            test_diff_original_serialized(model_original,model_serialized,torch.cat([sample, sample, sample], 1),verbose)
+        else:
+            test_diff_original_serialized(model_original,model_serialized,sample,verbose)
+        print('PASS')
 
 
 @cli.command()
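Note the order of operations in the tracing loop above: each net is switched to eval mode, has BatchNorm running-stats tracking disabled, and is moved to the CPU before torch.jit.trace records a forward pass on a fixed 512x512 example. A minimal sketch of that step on a hypothetical toy network (not a DeepLIIF generator):

    import torch
    import torch.nn as nn

    # Toy stand-in for one generator; the real nets come from init_nets(models_dir, eager_mode=True)
    net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU())
    net = net.eval()                        # freeze dropout/batchnorm behaviour before tracing
    sample = torch.rand(1, 3, 512, 512)     # same spatial size as the transform() input used above
    traced = torch.jit.trace(net, sample)   # records the forward pass for this input shape
    traced.save('toy_net.pt')               # same per-net .pt layout the serialize command writes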
@@ -486,7 +516,11 @@ def serialize(models_dir, output_dir):
 @click.option('--region-size', default=20000, help='Due to limits in the resources, the whole slide image cannot be processed in whole.'
                                                    'So the WSI image is read region by region. '
                                                    'This parameter specifies the size each region to be read into GPU for inferrence.')
-def test(input_dir, output_dir, tile_size, model_dir, region_size):
+@click.option('--eager-mode', is_flag=True, help='use eager mode (loading original models, otherwise serialized ones)')
+@click.option('--color-dapi', is_flag=True, help='color dapi image to produce the same coloring as in the paper')
+@click.option('--color-marker', is_flag=True, help='color marker image to produce the same coloring as in the paper')
+def test(input_dir, output_dir, tile_size, model_dir, region_size, eager_mode,
+         color_dapi, color_marker):
 
     """Test trained models
     """
@@ -507,7 +541,7 @@ def test(input_dir, output_dir, tile_size, model_dir, region_size):
             print(time.time() - start_time)
         else:
             img = Image.open(os.path.join(input_dir, filename)).convert('RGB')
-            images, scoring = infer_modalities(img, tile_size, model_dir)
+            images, scoring = infer_modalities(img, tile_size, model_dir, eager_mode, color_dapi, color_marker)
 
             for name, i in images.items():
                 i.save(os.path.join(
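The test command passes the new flags straight through to infer_modalities, which is also importable for scripting. A usage sketch with placeholder paths:

    from PIL import Image
    from deepliif.models import infer_modalities

    img = Image.open('sample_ihc.png').convert('RGB')            # placeholder input image
    images, scoring = infer_modalities(img, 512, './DeepLIIF_Latest_Model',
                                       eager_mode=False, color_dapi=True, color_marker=True)
    for name, modality in images.items():                        # each value is a PIL Image
        modality.save(f'sample_ihc_{name}.png')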
@@ -589,6 +623,15 @@ def prepare_testing_data(input_dir, dataset_dir):
         cv2.imwrite(os.path.join(test_dir, img), np.concatenate([image, image, image, image, image, image], 1))
 
 
+# to load pickle file saved from gpu in a cpu environment: https://github.com/pytorch/pytorch/issues/16797#issuecomment-633423219
+from io import BytesIO
+class CPU_Unpickler(pickle.Unpickler):
+    def find_class(self, module, name):
+        if module == 'torch.storage' and name == '_load_from_bytes':
+            return lambda b: torch.load(BytesIO(b), map_location='cpu')
+        else: return super().find_class(module, name)
+
+
 @cli.command()
 @click.option('--pickle-dir', required=True, help='directory where the pickled snapshots are stored')
 def visualize(pickle_dir):
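The CPU_Unpickler helper intercepts torch.storage._load_from_bytes and reroutes it through torch.load(..., map_location='cpu'), so visualization snapshots pickled on a GPU machine can be opened without CUDA. A self-contained usage sketch (restating the helper; the pickle path is a placeholder):

    import pickle
    from io import BytesIO

    import torch

    class CPU_Unpickler(pickle.Unpickler):
        # identical in spirit to the helper added above: remap CUDA storages onto the CPU while unpickling
        def find_class(self, module, name):
            if module == 'torch.storage' and name == '_load_from_bytes':
                return lambda b: torch.load(BytesIO(b), map_location='cpu')
            return super().find_class(module, name)

    with open('display_current_results.pickle', 'rb') as f:      # placeholder path
        params = CPU_Unpickler(f).load()                          # loads even without a GPU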
@@ -599,8 +642,8 @@ def visualize(pickle_dir):
         time.sleep(1)
 
     params_opt = pickle.load(open(path_init,'rb'))
-    params_opt
-    visualizer = Visualizer(
+    params_opt.remote = False
+    visualizer = Visualizer(params_opt) # create a visualizer that display/save images and plots
 
     paths_plot = {'display_current_results':os.path.join(pickle_dir,'display_current_results.pickle'),
                   'plot_current_losses':os.path.join(pickle_dir,'plot_current_losses.pickle')}
@@ -612,7 +655,7 @@ def visualize(pickle_dir):
         try:
             last_modified_time_plot = os.path.getmtime(path_plot)
             if last_modified_time_plot > last_modified_time[method]:
-                params_plot =
+                params_plot = CPU_Unpickler(open(path_plot,'rb')).load()
                 last_modified_time[method] = last_modified_time_plot
                 getattr(visualizer,method)(**params_plot)
                 print(f'{method} refreshed, last modified time {time.ctime(last_modified_time[method])}')
deepliif/models/__init__.py
CHANGED
@@ -88,7 +88,10 @@ def create_model(opt):
 
 
 def load_torchscript_model(model_pt_path, device):
-
+    net = torch.jit.load(model_pt_path, map_location=device)
+    net = disable_batchnorm_tracking_stats(net)
+    net.eval()
+    return net
 
 
 def read_model_params(file_addr):
@@ -132,7 +135,8 @@ def load_eager_models(model_dir, devices):
             os.path.join(model_dir, f'latest_net_{n}.pth'),
             map_location=devices[n]
         ))
-        nets[n] = net
+        nets[n] = disable_batchnorm_tracking_stats(net)
+        nets[n].eval()
 
     for n in ['G51', 'G52', 'G53', 'G54', 'G55']:
         net = UnetGenerator(input_nc, output_nc, 9, ngf, norm_layer=norm_layer, use_dropout=use_dropout)
@@ -140,7 +144,8 @@ def load_eager_models(model_dir, devices):
             os.path.join(model_dir, f'latest_net_{n}.pth'),
             map_location=devices[n]
         ))
-        nets[n] = net
+        nets[n] = disable_batchnorm_tracking_stats(net)
+        nets[n].eval()
 
     return nets
 
@@ -185,7 +190,12 @@ def compute_overlap(img_size, tile_size):
     return tile_size // 4
 
 
-def run_torchserve(img, model_path=None):
+def run_torchserve(img, model_path=None, eager_mode=False):
+    """
+    eager_mode: not used in this function; put in place to be consistent with run_dask
+                so that run_wrapper() could call either this function or run_dask with
+                same syntax
+    """
     buffer = BytesIO()
     torch.save(transform(img.resize((512, 512))), buffer)
 
@@ -203,9 +213,9 @@ def run_torchserve(img, model_path=None):
     return {k: tensor_to_pil(deserialize_tensor(v)) for k, v in res.json().items()}
 
 
-def run_dask(img, model_path):
+def run_dask(img, model_path, eager_mode=False):
     model_dir = os.getenv('DEEPLIIF_MODEL_DIR', model_path)
-    nets = init_nets(model_dir)
+    nets = init_nets(model_dir, eager_mode)
 
     ts = transform(img.resize((512, 512)))
 
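run_dask now forwards eager_mode to init_nets, so the same inference path can use either the original eager networks or the TorchScript files written by the serialize command. A small sketch of the two loading modes (model directory is a placeholder; the file naming follows the hunks above):

    from deepliif.models import init_nets

    # eager mode loads the latest_net_*.pth weights; the default loads traced *.pt files
    eager_nets = init_nets('./DeepLIIF_Latest_Model', eager_mode=True)
    scripted_nets = init_nets('./DeepLIIF_Latest_Model', eager_mode=False)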
@@ -237,7 +247,7 @@ def is_empty(tile):
     return True if calculate_background_area(tile) > 98 else False
 
 
-def run_wrapper(tile, run_fn, model_path):
+def run_wrapper(tile, run_fn, model_path, eager_mode=False):
     if is_empty(tile):
         return {
             'G1': Image.new(mode='RGB', size=(512, 512), color=(201, 211, 208)),
@@ -247,17 +257,17 @@ def run_wrapper(tile, run_fn, model_path):
             'G5': Image.new(mode='RGB', size=(512, 512), color=(0, 0, 0))
         }
     else:
-        return run_fn(tile, model_path)
-
+        return run_fn(tile, model_path, eager_mode)
 
-def inference(img, tile_size, overlap_size, model_path, use_torchserve=False):
 
+def inference_old(img, tile_size, overlap_size, model_path, use_torchserve=False, eager_mode=False,
+                  color_dapi=False, color_marker=False):
 
     tiles = list(generate_tiles(img, tile_size, overlap_size))
 
     run_fn = run_torchserve if use_torchserve else run_dask
     # res = [Tile(t.i, t.j, run_fn(t.img, model_path)) for t in tiles]
-    res = [Tile(t.i, t.j, run_wrapper(t.img, run_fn, model_path)) for t in tiles]
+    res = [Tile(t.i, t.j, run_wrapper(t.img, run_fn, model_path, eager_mode)) for t in tiles]
 
     def get_net_tiles(n):
         return [Tile(t.i, t.j, t.img[n]) for t in res]
@@ -276,12 +286,14 @@ def inference(img, tile_size, overlap_size, model_path, use_torchserve=False):
 
     images['DAPI'] = stitch(get_net_tiles('G2'), tile_size, overlap_size).resize(img.size)
     dapi_pix = np.array(images['DAPI'].convert('L').convert('RGB'))
-
+    if color_dapi:
+        dapi_pix[:, :, 0] = 0
     images['DAPI'] = Image.fromarray(dapi_pix)
     images['Lap2'] = stitch(get_net_tiles('G3'), tile_size, overlap_size).resize(img.size)
     images['Marker'] = stitch(get_net_tiles('G4'), tile_size, overlap_size).resize(img.size)
     marker_pix = np.array(images['Marker'].convert('L').convert('RGB'))
-
+    if color_marker:
+        marker_pix[:, :, 2] = 0
     images['Marker'] = Image.fromarray(marker_pix)
 
     # images['Marker'] = stitch(
@@ -294,6 +306,52 @@ def inference(img, tile_size, overlap_size, model_path, use_torchserve=False):
     return images
 
 
+def inference(img, tile_size, overlap_size, model_path, use_torchserve=False, eager_mode=False,
+              color_dapi=False, color_marker=False):
+
+    rescaled, rows, cols = format_image_for_tiling(img, tile_size, overlap_size)
+
+    run_fn = run_torchserve if use_torchserve else run_dask
+
+    images = {}
+    images['Hema'] = create_image_for_stitching(tile_size, rows, cols)
+    images['DAPI'] = create_image_for_stitching(tile_size, rows, cols)
+    images['Lap2'] = create_image_for_stitching(tile_size, rows, cols)
+    images['Marker'] = create_image_for_stitching(tile_size, rows, cols)
+    images['Seg'] = create_image_for_stitching(tile_size, rows, cols)
+
+    for i in range(cols):
+        for j in range(rows):
+            tile = extract_tile(rescaled, tile_size, overlap_size, i, j)
+            res = run_wrapper(tile, run_fn, model_path, eager_mode)
+
+            stitch_tile(images['Hema'], res['G1'], tile_size, overlap_size, i, j)
+            stitch_tile(images['DAPI'], res['G2'], tile_size, overlap_size, i, j)
+            stitch_tile(images['Lap2'], res['G3'], tile_size, overlap_size, i, j)
+            stitch_tile(images['Marker'], res['G4'], tile_size, overlap_size, i, j)
+            stitch_tile(images['Seg'], res['G5'], tile_size, overlap_size, i, j)
+
+    images['Hema'] = images['Hema'].resize(img.size)
+    images['DAPI'] = images['DAPI'].resize(img.size)
+    images['Lap2'] = images['Lap2'].resize(img.size)
+    images['Marker'] = images['Marker'].resize(img.size)
+    images['Seg'] = images['Seg'].resize(img.size)
+
+    if color_dapi:
+        matrix = ( 0,        0,        0,        0,
+                   299/1000, 587/1000, 114/1000, 0,
+                   299/1000, 587/1000, 114/1000, 0)
+        images['DAPI'] = images['DAPI'].convert('RGB', matrix)
+
+    if color_marker:
+        matrix = (299/1000, 587/1000, 114/1000, 0,
+                  299/1000, 587/1000, 114/1000, 0,
+                  0,        0,        0,        0)
+        images['Marker'] = images['Marker'].convert('RGB', matrix)
+
+    return images
+
+
 def postprocess(img, seg_img, thresh=80, noise_objects_size=20, small_object_size=50):
     mask_image = create_basic_segmentation_mask(np.array(img), np.array(seg_img),
                                                 thresh, noise_objects_size, small_object_size)
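The 12-value tuples passed to Image.convert('RGB', matrix) above are 3x4 channel-mixing matrices: each output channel becomes a weighted sum of the input R, G, B plus an offset, and 299/1000, 587/1000, 114/1000 are the standard ITU-R 601 luma weights. Zeroing the first row therefore renders the grayscale DAPI signal in cyan, while zeroing the last row renders the marker in yellow. A small sketch of the same trick (input file is a placeholder):

    from PIL import Image

    luma = (299/1000, 587/1000, 114/1000, 0)
    cyan_matrix = (0, 0, 0, 0) + luma + luma           # R := 0, G := luma, B := luma
    gray = Image.open('dapi_gray.png').convert('RGB')  # placeholder grayscale DAPI image
    cyan = gray.convert('RGB', cyan_matrix)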
@@ -312,7 +370,8 @@ def postprocess(img, seg_img, thresh=80, noise_objects_size=20, small_object_siz
     return images, scoring
 
 
-def infer_modalities(img, tile_size, model_dir):
+def infer_modalities(img, tile_size, model_dir, eager_mode=False,
+                     color_dapi=False, color_marker=False):
     """
     This function is used to infer modalities for the given image using a trained model.
     :param img: The input image.
@@ -329,7 +388,10 @@ def infer_modalities(img, tile_size, model_dir):
         img,
         tile_size=tile_size,
         overlap_size=compute_overlap(img.size, tile_size),
-        model_path=model_dir
+        model_path=model_dir,
+        eager_mode=eager_mode,
+        color_dapi=color_dapi,
+        color_marker=color_marker
     )
 
     post_images, scoring = postprocess(img, images['Seg'], small_object_size=20)
deepliif/models/base_model.py
CHANGED
@@ -3,6 +3,7 @@ import torch
 from collections import OrderedDict
 from abc import ABC, abstractmethod
 from . import networks
+from ..util import disable_batchnorm_tracking_stats
 
 
 class BaseModel(ABC):
@@ -90,6 +91,7 @@ class BaseModel(ABC):
             if isinstance(name, str):
                 net = getattr(self, 'net' + name)
                 net.eval()
+                net = disable_batchnorm_tracking_stats(net)
 
     def test(self):
         """Forward function used in test time.
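The reason for disabling BatchNorm stat tracking alongside eval() is that eval-mode BatchNorm normalises with stored running statistics, which can noticeably degrade generator outputs when those statistics are unrepresentative; clearing them forces per-batch statistics even in eval mode. A toy illustration of the behavioural difference (not DeepLIIF code):

    import torch
    import torch.nn as nn

    bn = nn.BatchNorm2d(3).eval()
    x = torch.rand(1, 3, 4, 4) * 10
    y_running = bn(x)                 # eval(): normalise with the stored running statistics

    bn.track_running_stats = False    # the same three assignments the helper applies per BatchNorm2d child
    bn.running_mean = None
    bn.running_var = None
    y_batch = bn(x)                   # now normalised with this batch's own statistics

    print(torch.allclose(y_running, y_batch))   # generally False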
deepliif/util/__init__.py
CHANGED
@@ -88,6 +88,36 @@ def stitch(tiles, tile_size, overlap_size):
     return new_im
 
 
+def format_image_for_tiling(img, tile_size, overlap_size):
+    mean_background_val = calculate_background_mean_value(img)
+    img = img.resize(output_size(img, tile_size))
+    # Adding borders with size of given overlap around the whole slide image
+    img = ImageOps.expand(img, border=overlap_size, fill=tuple(mean_background_val))
+    rows = int(img.height / tile_size)
+    cols = int(img.width / tile_size)
+    return img, rows, cols
+
+
+def extract_tile(img, tile_size, overlap_size, i, j):
+    return img.crop((
+        i * tile_size, j * tile_size,
+        i * tile_size + tile_size + 2 * overlap_size,
+        j * tile_size + tile_size + 2 * overlap_size
+    ))
+
+
+def create_image_for_stitching(tile_size, rows, cols):
+    width = tile_size * cols
+    height = tile_size * rows
+    return Image.new('RGB', (width, height))
+
+
+def stitch_tile(img, tile, tile_size, overlap_size, i, j):
+    tile = tile.resize((tile_size + 2 * overlap_size, tile_size + 2 * overlap_size))
+    tile = tile.crop((overlap_size, overlap_size, overlap_size + tile_size, overlap_size + tile_size))
+    img.paste(tile, (i * tile_size, j * tile_size))
+
+
 def calculate_background_mean_value(img):
     img = cv2.fastNlMeansDenoisingColored(np.array(img), None, 10, 10, 7, 21)
     img = np.array(img, dtype=float)
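Together, these four helpers implement a streaming tile-and-stitch loop: pad the rescaled image, crop overlapping tiles one at a time, and paste the de-overlapped centre of each result into a pre-allocated canvas, which is what the rewritten inference() in deepliif/models/__init__.py relies on. A round-trip sketch with no model involved (assuming deepliif 1.1.7 exposes these names from deepliif.util, as the hunk above suggests) is a quick way to sanity-check tile and overlap sizes:

    from PIL import Image
    from deepliif.util import (format_image_for_tiling, extract_tile,
                               create_image_for_stitching, stitch_tile)

    tile_size, overlap_size = 512, 128
    img = Image.open('region.png').convert('RGB')      # placeholder input
    rescaled, rows, cols = format_image_for_tiling(img, tile_size, overlap_size)
    canvas = create_image_for_stitching(tile_size, rows, cols)
    for i in range(cols):
        for j in range(rows):
            tile = extract_tile(rescaled, tile_size, overlap_size, i, j)
            stitch_tile(canvas, tile, tile_size, overlap_size, i, j)   # crops away the overlap border
    restitched = canvas.resize(img.size)                # should closely match the original image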
@@ -349,3 +379,39 @@ def read_results_from_pickle_file(input_addr):
     pickle_obj.close()
     return results
 
+def test_diff_original_serialized(model_original,model_serialized,example,verbose=0):
+    threshold = 10
+
+    orig_res = model_original(example)
+    if verbose > 0:
+        print('Original:')
+        print(orig_res.shape)
+        print(orig_res[0, 0:10])
+        print('min abs value:{}'.format(torch.min(torch.abs(orig_res))))
+
+    ts_res = model_serialized(example)
+    if verbose > 0:
+        print('Torchscript:')
+        print(ts_res.shape)
+        print(ts_res[0, 0:10])
+        print('min abs value:{}'.format(torch.min(torch.abs(ts_res))))
+
+    abs_diff = torch.abs(orig_res-ts_res)
+    if verbose > 0:
+        print('Dif sum:')
+        print(torch.sum(abs_diff))
+        print('max dif:{}'.format(torch.max(abs_diff)))
+
+    assert torch.sum(abs_diff) <= threshold, f"Sum of difference in predicted values {torch.sum(abs_diff)} is larger than threshold {threshold}"
+
+def disable_batchnorm_tracking_stats(model):
+    # https://discuss.pytorch.org/t/performance-highly-degraded-when-eval-is-activated-in-the-test-phase/3323/16
+    # https://discuss.pytorch.org/t/performance-highly-degraded-when-eval-is-activated-in-the-test-phase/3323/67
+    # https://github.com/pytorch/pytorch/blob/ca39c5b04e30a67512589cafbd9d063cc17168a5/torch/nn/modules/batchnorm.py#L158
+    for m in model.modules():
+        for child in m.children():
+            if type(child) == torch.nn.BatchNorm2d:
+                child.track_running_stats = False
+                child.running_mean = None
+                child.running_var = None
+    return model
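test_diff_original_serialized is a coarse consistency check: it asserts that the summed absolute difference between the eager and TorchScript outputs stays under a fixed threshold of 10. A sketch of the check on a hypothetical toy network (assuming both helpers are importable from deepliif.util in 1.1.7, as the import change in cli.py indicates):

    import torch
    import torch.nn as nn
    from deepliif.util import disable_batchnorm_tracking_stats, test_diff_original_serialized

    net = nn.Sequential(nn.Conv2d(3, 8, 3, padding=1), nn.BatchNorm2d(8), nn.ReLU()).eval()
    net = disable_batchnorm_tracking_stats(net)      # BatchNorm2d children now use batch statistics
    example = torch.rand(1, 3, 512, 512)
    traced = torch.jit.trace(net, example)
    test_diff_original_serialized(net, traced, example, verbose=0)   # raises AssertionError on drift
    print('PASS')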