deepliif 1.1.11__py3-none-any.whl → 1.1.12__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
@@ -0,0 +1,440 @@
1
+ import math
2
+ import cv2
3
+ from PIL import Image
4
+ import skimage.measure
5
+ from skimage import feature
6
+ from skimage.morphology import remove_small_objects
7
+ import numpy as np
8
+ import scipy.ndimage as ndi
9
+ from numba import jit
10
+
11
+
12
def remove_small_objects_from_image(img, min_size=100):
    """Zero out small connected foreground regions of ``img`` in place.

    Every nonzero element of ``img`` counts as foreground. Foreground regions that
    ``skimage.morphology.remove_small_objects`` drops for the given ``min_size`` are
    set to 0 directly in ``img``; the same (modified) array is returned.
    """
    # astype(bool) marks exactly the nonzero elements as foreground.
    foreground = img.astype(bool)
    kept = remove_small_objects(foreground, min_size=min_size)
    img[~kept] = 0
    return img
20
+
21
+
22
def remove_background_noise(mask, mask_boundary, padding=5):
    """Erase components of ``mask`` that have too little support in ``mask_boundary``.

    ``mask`` is split into connected components with ``skimage.measure.label``
    (``background=0``). For each component, the window of ``mask_boundary`` that covers
    the component's bounding box grown by ``padding`` pixels (clipped to the image) is
    inspected. If that window holds fewer nonzero elements than one third of the
    component's pixel count, the component is set to 0 in ``mask``.

    Parameters:
        mask (ndarray)          -- 2D mask whose components are tested; modified in place
        mask_boundary (ndarray) -- array whose first two axes align with ``mask``
        padding (int)           -- margin added around each bounding box (default 5)

    Returns:
        mask -- the same array, with rejected components zeroed
    """
    labeled = skimage.measure.label(mask, background=0)
    n_rows, n_cols = mask_boundary.shape[0], mask_boundary.shape[1]
    # Labels run 1..labeled.max(). Iterating to the max label (instead of counting
    # unique values) also covers a mask that has no background pixels at all.
    num_labels = labeled.max() if labeled.size else 0
    for label in range(1, num_labels + 1):
        region = labeled == label
        coords = np.nonzero(region)
        rows, cols = coords[0], coords[1]
        if rows.size == 0:
            continue
        # Clip rows against the row count and columns against the column count.
        window = mask_boundary[max(0, rows.min() - padding): min(n_rows, rows.max() + padding),
                               max(0, cols.min() - padding): min(n_cols, cols.max() + padding)]
        if np.count_nonzero(window) < rows.size / 3:
            mask[region] = 0
    return mask
40
+
41
+
42
def remove_cell_noise(mask1, mask2, padding=2):
    """Move components of ``mask1`` that are well covered by ``mask2`` into ``mask2``.

    ``mask1`` is split into connected components with ``skimage.measure.label``
    (``background=0``). For each component, the window of ``mask2`` that covers the
    component's bounding box grown by ``padding`` pixels (clipped to the image) is
    inspected. If it holds more nonzero elements than one third of the component's
    pixel count:
      - the component is erased from ``mask1`` (set to 0);
      - it is painted into ``mask2`` with value 255.
    Components are processed in label order, so pixels painted into ``mask2`` count
    toward the windows of later components.

    Parameters:
        mask1 (ndarray) -- 2D mask whose components are tested; modified in place
        mask2 (ndarray) -- 2D mask aligned with ``mask1``; modified in place
        padding (int)   -- margin added around each bounding box (default 2)

    Returns:
        (mask1, mask2) -- the same two arrays after modification
    """
    labeled = skimage.measure.label(mask1, background=0)
    n_rows, n_cols = mask2.shape[0], mask2.shape[1]
    # Iterate up to the max label so a mask without background pixels still has its
    # single component processed.
    num_labels = labeled.max() if labeled.size else 0
    for label in range(1, num_labels + 1):
        region = labeled == label
        coords = np.nonzero(region)
        rows, cols = coords[0], coords[1]
        if rows.size == 0:
            continue
        # Clip rows against the row count and columns against the column count.
        window = mask2[max(0, rows.min() - padding): min(n_rows, rows.max() + padding),
                       max(0, cols.min() - padding): min(n_cols, cols.max() + padding)]
        if np.count_nonzero(window) > rows.size / 3:
            mask1[region] = 0
            mask2[region] = 255
    return mask1, mask2
60
+
61
+
62
def create_basic_segmentation_mask(img, seg_img, thresh=80, noise_objects_size=20, small_object_size=50):
    """Build an RGB mask that paints positive cells red and negative cells blue.

    The positive/negative masks come from ``positive_negative_masks_basic``. The result
    has the shape and dtype of ``img``: (255, 0, 0) where the positive mask is > 0,
    (0, 0, 255) where the negative mask is > 0, and zeros elsewhere.

    NOTE(review): ``positive_negative_masks_basic`` is neither defined nor imported in
    this module as shown. Calling this function raises NameError unless that name is
    provided some other way — confirm before use.
    """
    positive_mask, negative_mask = positive_negative_masks_basic(img, seg_img, thresh, noise_objects_size, small_object_size)

    colored = np.zeros_like(img)
    # Negative is painted last, so it wins where both masks are set.
    for region, color in ((positive_mask, (255, 0, 0)), (negative_mask, (0, 0, 255))):
        colored[region > 0] = color
    return colored
71
+
72
+
73
def imadjust(x, gamma=0.7, c=0, d=1):
    """
    Adjusting the image contrast and brightness

    Stretches ``x`` linearly so its minimum maps to 0 and its maximum to 1, applies the
    gamma curve, then rescales to the output range [c, d]:

        y = ((x - min) / (max - min)) ** gamma * (d - c) + c

    A constant input (max == min) has no contrast to stretch. It is treated as all-zero
    after normalization (so the output is ``c`` everywhere for gamma > 0) instead of
    producing NaN from 0/0.

    :param x: Input array (anything ``np.asarray`` accepts); must be non-empty
    :param gamma: Gamma value
    :param c: Minimum value
    :param d: Maximum value
    :return: Adjusted image as a floating-point array
    """
    x = np.asarray(x)
    a = x.min()
    b = x.max()
    if b == a:
        normalized = np.zeros(x.shape)
    else:
        normalized = (x - a) / (b - a)
    y = (normalized ** gamma) * (d - c) + c
    return y
87
+
88
+
89
def adjust_dapi(inferred_tile, orig_tile):
    """Adjusts the intensity of an inferred mpIF DAPI tile relative to its original tile.

    The inferred tile is passed through ``imadjust`` with output range [5, 255]. The
    base gamma is ``8 / log(max(original))``. When the original tile's mean intensity
    is below 200, the gamma is further scaled by
    ``log(mean(inferred)) / log(mean(original))``.

    Parameters:
        inferred_tile (Image) -- inferred tile image
        orig_tile (Image)     -- original tile image

    Returns:
        processed_tile (Image) -- the adjusted mpIF DAPI image (uint8)

    Note: ``math.log`` raises ValueError for a non-positive argument. A log of exactly 1
    gives 0, so an original tile with max 1 (or, in the dark branch, mean 1) raises
    ZeroDivisionError.
    """
    inferred = np.array(inferred_tile)
    original = np.array(orig_tile)

    gamma = 8 / math.log(np.max(original))
    original_mean = np.mean(original)
    if original_mean < 200:
        # Darker originals: scale gamma by the ratio of log mean intensities.
        gamma = gamma * math.log(np.mean(inferred)) / math.log(original_mean)

    adjusted = imadjust(inferred, gamma=gamma, c=5, d=255).astype(np.uint8)
    return Image.fromarray(adjusted)
115
+
116
+
117
def adjust_marker(inferred_tile, orig_tile):
    """Adjusts the intensity of an inferred mpIF marker tile relative to its original tile.

    The inferred tile is passed through ``imadjust`` with output range [5, 255]. The
    base gamma is ``8 / log(max(original))``. When the original tile's mean intensity
    is below 200, the gamma is further scaled by
    ``log(std(inferred)) / log(std(original))``.

    Parameters:
        inferred_tile (Image) -- inferred tile image
        orig_tile (Image)     -- original tile image

    Returns:
        processed_tile (Image) -- the adjusted marker image (uint8)

    Note: ``math.log`` raises ValueError for a non-positive argument. A log of exactly 1
    gives 0, so an original tile with max 1 (or, in the dark branch, std 1) raises
    ZeroDivisionError.
    """
    inferred = np.array(inferred_tile)
    original = np.array(orig_tile)

    gamma = 8 / math.log(np.max(original))
    if np.mean(original) < 200:
        # Darker originals: scale gamma by the ratio of log standard deviations.
        gamma = gamma * math.log(np.std(inferred)) / math.log(np.std(original))

    adjusted = imadjust(inferred, gamma=gamma, c=5, d=255).astype(np.uint8)
    return Image.fromarray(adjusted)
143
+
144
+
145
# Values for uint8 masks
# Pixel labels written into the uint8 working masks used by the functions below.
MASK_UNKNOWN = 50  # fill value from create_posneg_mask for pixels not classified as cell
MASK_POSITIVE = 200  # positive-cell pixel from create_posneg_mask; also a temporary value in enlarge_cell_boundaries
MASK_NEGATIVE = 150  # negative-cell pixel from create_posneg_mask; also a temporary value in enlarge_cell_boundaries
MASK_BACKGROUND = 0  # set by mark_background, and for cells rejected by the size limits
MASK_CELL = 255  # temporary flood-fill marker in compute_cell_classification / compute_cell_sizes
MASK_CELL_POSITIVE = 201  # interior pixel of a cell classified positive
MASK_CELL_NEGATIVE = 151  # interior pixel of a cell classified negative
MASK_BOUNDARY_POSITIVE = 202  # boundary pixel of a cell classified positive
MASK_BOUNDARY_NEGATIVE = 152  # boundary pixel of a cell classified negative
155
+
156
+
157
@jit(nopython=True)
def in_bounds(array, index):
    """Return True when ``index`` = (row, col) lies within ``array``'s first two axes."""
    row, col = index[0], index[1]
    return 0 <= row < array.shape[0] and 0 <= col < array.shape[1]
160
+
161
+
162
def create_posneg_mask(seg, thresh):
    """Create a mask of positive and negative pixels.

    A pixel counts as cell when red + blue (summed as uint16, presumably so uint8
    inputs cannot overflow) exceeds ``thresh`` and green is at most 80.
      - Cell pixels with red >= blue get MASK_POSITIVE.
      - The remaining cell pixels get MASK_NEGATIVE.
      - All other pixels get MASK_UNKNOWN.
    The result is a uint8 array of shape ``seg.shape[:2]``.
    """
    red, green, blue = seg[:, :, 0], seg[:, :, 1], seg[:, :, 2]

    red_plus_blue = np.add(red, blue, dtype=np.uint16)
    is_cell = (red_plus_blue > thresh) & (green <= 80)
    is_positive = is_cell & (red >= blue)
    is_negative = is_cell & ~is_positive

    labels = np.full(seg.shape[:2], MASK_UNKNOWN, dtype=np.uint8)
    labels[is_positive] = MASK_POSITIVE
    labels[is_negative] = MASK_NEGATIVE
    return labels
174
+
175
+
176
@jit(nopython=True)
def mark_background(mask):
    """Mask all background pixels by 4-connected region growing unknown boundary pixels.

    Every MASK_UNKNOWN pixel on the outer border of ``mask`` is used as a seed. The fill
    grows through 4-connected MASK_UNKNOWN pixels, relabelling each one to
    MASK_BACKGROUND. MASK_UNKNOWN pixels that cannot be reached from the border (for
    example, ones enclosed by positive/negative pixels) keep their label.

    ``mask`` is modified in place; nothing is returned.
    """

    # Collect seeds along the left/right columns and top/bottom rows. Corner pixels may
    # be added twice; the MASK_UNKNOWN check in the loop below makes that harmless.
    seeds = []
    for i in range(mask.shape[0]):
        if mask[i, 0] == MASK_UNKNOWN:
            seeds.append((i, 0))
        if mask[i, mask.shape[1]-1] == MASK_UNKNOWN:
            seeds.append((i, mask.shape[1]-1))
    for j in range(mask.shape[1]):
        if mask[0, j] == MASK_UNKNOWN:
            seeds.append((0, j))
        if mask[mask.shape[0]-1, j] == MASK_UNKNOWN:
            seeds.append((mask.shape[0]-1, j))

    neighbors = [(-1, 0), (1, 0), (0, -1), (0, 1)]

    # Iterative flood fill using the list as a stack (no recursion).
    while len(seeds) > 0:
        seed = seeds.pop()
        if mask[seed] == MASK_UNKNOWN:
            mask[seed] = MASK_BACKGROUND
            for n in neighbors:
                idx = (seed[0] + n[0], seed[1] + n[1])
                if in_bounds(mask, idx) and mask[idx] == MASK_UNKNOWN:
                    seeds.append(idx)
202
+
203
+
204
@jit(nopython=True)
def compute_cell_classification(mask, marker, size_thresh, marker_thresh, size_thresh_upper = None):
    """
    Compute the mapping of the mask to positive and negative cell classification.

    Each cell is an 8-connected region grown from a MASK_POSITIVE/MASK_NEGATIVE pixel
    through MASK_POSITIVE, MASK_NEGATIVE and MASK_UNKNOWN pixels.

    A cell is kept when its pixel count is > size_thresh and (if given) < size_thresh_upper.
    A kept cell is classified positive when either:
      - at least half of its positive/negative pixels are positive, or
      - its maximum marker value exceeds marker_thresh.
    Otherwise it is negative.

    Pixels of a kept cell that 4-touch a MASK_BACKGROUND pixel become boundary pixels;
    the rest become interior pixels. Rejected cells are set to MASK_BACKGROUND.

    Parameters
    ==========
    mask: 2D uint8 numpy array with pixels labeled as positive, negative, background, or unknown.
          After the function executes, the pixels will be labeled as background or cell/boundary pos/neg.
    marker: 2D uint8 numpy array with the restained marker values, or None to ignore markers.
    size_thresh: Lower size threshold in pixels. Only include cells larger than this count.
    marker_thresh: Classify cell as positive if any marker value within the cell is above this threshold.
    size_thresh_upper: Upper size threshold in pixels, or None. Only include cells smaller than this count.

    Returns
    =======
    Dictionary with the following values:
        num_total (integer) -- total number of cells in the image
        num_pos (integer) -- number of positive cells in the image
        num_neg (integer) -- number of negative cells in the image
    (The percentage of positive cells is not computed here.)
    """

    neighbors = [(-1, -1), (0, -1), (1, -1), (-1, 0), (1, 0), (-1, 1), (0, 1), (1, 1)]
    border_neighbors = [(0, -1), (-1, 0), (1, 0), (0, 1)]
    positive_cell_count, negative_cell_count = 0, 0

    for y in range(mask.shape[0]):
        for x in range(mask.shape[1]):
            if mask[y, x] == MASK_POSITIVE or mask[y, x] == MASK_NEGATIVE:
                # Start a new cell at this unvisited positive/negative pixel.
                seeds = [(y, x)]
                cell_coords = []
                count = 1
                # Always 1 here, because the seed pixel is positive or negative.
                count_posneg = 1 if mask[y, x] != MASK_UNKNOWN else 0
                count_positive = 1 if mask[y, x] == MASK_POSITIVE else 0
                max_marker = marker[y, x] if marker is not None else 0
                # MASK_CELL marks visited pixels so they are not added twice.
                mask[y, x] = MASK_CELL
                cell_coords.append((y, x))

                # 8-connected flood fill; MASK_UNKNOWN pixels join the cell but do not
                # count toward the positive/negative ratio.
                while len(seeds) > 0:
                    seed = seeds.pop()
                    for n in neighbors:
                        idx = (seed[0] + n[0], seed[1] + n[1])
                        if in_bounds(mask, idx) and (mask[idx] == MASK_POSITIVE or mask[idx] == MASK_NEGATIVE or mask[idx] == MASK_UNKNOWN):
                            seeds.append(idx)
                            if mask[idx] == MASK_POSITIVE:
                                count_positive += 1
                            if mask[idx] != MASK_UNKNOWN:
                                count_posneg += 1
                            if marker is not None and marker[idx] > max_marker:
                                max_marker = marker[idx]
                            mask[idx] = MASK_CELL
                            cell_coords.append(idx)
                            count += 1

                if count > size_thresh and (size_thresh_upper is None or count < size_thresh_upper):
                    if (count_positive/count_posneg) >= 0.5 or max_marker > marker_thresh:
                        fill_value = MASK_CELL_POSITIVE
                        border_value = MASK_BOUNDARY_POSITIVE
                        positive_cell_count += 1
                    else:
                        fill_value = MASK_CELL_NEGATIVE
                        border_value = MASK_BOUNDARY_NEGATIVE
                        negative_cell_count += 1
                else:
                    # Outside the size limits: erase the cell into background.
                    fill_value = MASK_BACKGROUND
                    border_value = MASK_BACKGROUND

                # A pixel is a boundary pixel when a 4-neighbor is background. Pixels on
                # the image edge are not boundary unless they also touch background.
                for coord in cell_coords:
                    is_boundary = False
                    for n in border_neighbors:
                        idx = (coord[0] + n[0], coord[1] + n[1])
                        if in_bounds(mask, idx) and mask[idx] == MASK_BACKGROUND:
                            is_boundary = True
                            break
                    if is_boundary:
                        mask[coord] = border_value
                    else:
                        mask[coord] = fill_value

    counts = {
        'num_total': positive_cell_count + negative_cell_count,
        'num_pos': positive_cell_count,
        'num_neg': negative_cell_count,
    }
    return counts
290
+
291
+
292
@jit(nopython=True)
def enlarge_cell_boundaries(mask):
    """Thicken cell boundaries in ``mask`` by one pixel, in place.

    Pass 1: every 8-neighbor of a MASK_BOUNDARY_POSITIVE/NEGATIVE pixel that is not
    itself a boundary pixel is set to the temporary value MASK_POSITIVE/MASK_NEGATIVE.
    Those temporary values are not boundary values, so newly marked pixels do not
    propagate further. Where positive and negative boundaries compete for a pixel, the
    last write in row-major scan order wins.

    Pass 2: all MASK_POSITIVE/MASK_NEGATIVE pixels are converted to the corresponding
    boundary value. NOTE(review): this also converts any MASK_POSITIVE/NEGATIVE pixels
    that were present before the call; in compute_results the preceding classification
    presumably leaves none — confirm for other callers.
    """
    neighbors = [(-1, -1), (0, -1), (1, -1), (-1, 0), (1, 0), (-1, 1), (0, 1), (1, 1)]
    for y in range(mask.shape[0]):
        for x in range(mask.shape[1]):
            if mask[y, x] == MASK_BOUNDARY_POSITIVE or mask[y, x] == MASK_BOUNDARY_NEGATIVE:
                value = MASK_POSITIVE if mask[y, x] == MASK_BOUNDARY_POSITIVE else MASK_NEGATIVE
                for n in neighbors:
                    idx = (y + n[0], x + n[1])
                    if in_bounds(mask, idx) and mask[idx] != MASK_BOUNDARY_POSITIVE and mask[idx] != MASK_BOUNDARY_NEGATIVE:
                        mask[idx] = value
    # Promote the temporary values from pass 1 to real boundary labels.
    for y in range(mask.shape[0]):
        for x in range(mask.shape[1]):
            if mask[y, x] == MASK_POSITIVE:
                mask[y, x] = MASK_BOUNDARY_POSITIVE
            elif mask[y, x] == MASK_NEGATIVE:
                mask[y, x] = MASK_BOUNDARY_NEGATIVE
309
+
310
+
311
@jit(nopython=True)
def compute_cell_sizes(mask):
    """Return the pixel count of every 8-connected cell in ``mask``.

    A cell is grown from each remaining MASK_POSITIVE/MASK_NEGATIVE pixel through
    MASK_POSITIVE, MASK_NEGATIVE and MASK_UNKNOWN pixels. Other labels (e.g.
    MASK_BACKGROUND) stop the fill.

    Side effect: ``mask`` is relabelled in place while filling:
      - MASK_POSITIVE -> MASK_CELL_POSITIVE
      - MASK_NEGATIVE -> MASK_CELL_NEGATIVE
      - reached MASK_UNKNOWN -> MASK_CELL
    calc_default_size_thresh reverts these labels after calling this function.
    """
    neighbors = [(-1, -1), (0, -1), (1, -1), (-1, 0), (1, 0), (-1, 1), (0, 1), (1, 1)]
    sizes = []

    for y in range(mask.shape[0]):
        for x in range(mask.shape[1]):
            if mask[y, x] == MASK_POSITIVE or mask[y, x] == MASK_NEGATIVE:
                seeds = [(y, x)]
                count = 1
                mask[y, x] = MASK_CELL_POSITIVE if mask[y, x] == MASK_POSITIVE else MASK_CELL_NEGATIVE

                # Stack-based flood fill; relabelling marks pixels as visited.
                while len(seeds) > 0:
                    seed = seeds.pop()
                    for n in neighbors:
                        idx = (seed[0] + n[0], seed[1] + n[1])
                        if in_bounds(mask, idx) and (mask[idx] == MASK_POSITIVE or mask[idx] == MASK_NEGATIVE or mask[idx] == MASK_UNKNOWN):
                            seeds.append(idx)
                            if mask[idx] == MASK_POSITIVE:
                                mask[idx] = MASK_CELL_POSITIVE
                            elif mask[idx] == MASK_NEGATIVE:
                                mask[idx] = MASK_CELL_NEGATIVE
                            else:
                                mask[idx] = MASK_CELL
                            count += 1

                sizes.append(count)

    return sizes
340
+
341
+
342
@jit(nopython=True)
def create_kde(values, count, bandwidth=1.0):
    """Evaluate a Gaussian kernel density estimate of ``values`` on a uniform grid.

    The grid has ``count`` points starting at 0 with spacing
    ``step = (max(values) + 1) / count``.

    Returns:
        (kde, step) -- kde is a float32 array of length ``count``; step is the grid spacing.
    """
    inv_sqrt_two_pi = 1 / math.sqrt(2 * math.pi)
    grid_end = max(values) + 1
    step = grid_end / count
    n = values.shape[0]
    h = bandwidth
    inv_h = 1 / h
    kde = np.zeros(count, dtype=np.float32)

    for grid_idx in range(count):
        grid_x = grid_idx * step
        density_sum = 0.0
        for j in range(n):
            u = (grid_x - values[j]) * inv_h
            density_sum += math.exp(-(u * u / 2)) * inv_sqrt_two_pi  # standard normal kernel
        kde[grid_idx] = density_sum / (n * h)

    return kde, step
361
+
362
+
363
def calc_default_size_thresh(mask, resolution):
    """Estimate a lower cell-size threshold (in pixels) from the cells found in ``mask``.

    Steps:
      1. Measure cell sizes with compute_cell_sizes, then restore the labels it changed.
      2. Build a 500-point KDE of sqrt(size).
      3. Take the first interior local minimum of the KDE (falling back to index 1)
         and convert it to a sqrt-size.
      4. Clamp the result to a resolution-dependent (min, default, max) range: values
         below min become min, values above max fall back to default.
    The squared, rounded value is returned. With no cells the threshold is 0.

    ``resolution`` selects the range: '20x' -> (3, 4, 6), '10x' -> (2, 2, 3),
    anything else (40x) -> (4, 7, 10).
    """
    cell_sizes = compute_cell_sizes(mask)
    # compute_cell_sizes relabels pixels while filling; put the original labels back.
    mask[mask == MASK_CELL_POSITIVE] = MASK_POSITIVE
    mask[mask == MASK_CELL_NEGATIVE] = MASK_NEGATIVE
    mask[mask == MASK_CELL] = MASK_UNKNOWN

    if len(cell_sizes) == 0:
        return 0

    density, step = create_kde(np.sqrt(cell_sizes), 500)
    first_min = next(
        (i for i in range(1, density.shape[0] - 1)
         if density[i] < density[i - 1] and density[i] < density[i + 1]),
        1,
    )
    thresh_sqrt = (first_min - 1) * step

    if resolution == '20x':
        low, default, high = 3, 4, 6
    elif resolution == '10x':
        low, default, high = 2, 2, 3
    else:
        low, default, high = 4, 7, 10

    if thresh_sqrt < low:
        thresh_sqrt = low
    elif thresh_sqrt > high:
        thresh_sqrt = default

    return round(thresh_sqrt * thresh_sqrt)
393
+
394
+
395
def calc_default_marker_thresh(marker):
    """Return a default marker-intensity threshold for positive classification.

    The threshold lies 90% of the way from the 0.1th to the 99.9th percentile of the
    nonzero marker values (each percentile rounded to an integer). It is 0 when
    ``marker`` is None, and also when ``marker`` has no nonzero values.
    """
    if marker is None:
        return 0

    signal = marker[marker != 0]
    if signal.shape[0] > 0:
        low = round(np.percentile(signal, 0.1))
        high = round(np.percentile(signal, 99.9))
    else:
        low = high = 0
    return round((high - low) * 0.9) + low
402
+
403
+
404
def compute_results(orig, seg, marker, resolution=None, seg_thresh=150, size_thresh='auto', marker_thresh='auto', size_thresh_upper=None):
    """Classify cells in a segmentation image and build overlay/refined visualizations.

    Parameters:
        orig -- original RGB image; copied and annotated for the overlay
        seg -- RGB segmentation image; thresholded with ``seg_thresh``
        marker -- marker image, or None
        resolution -- passed to calc_default_size_thresh when ``size_thresh`` is 'auto'
        seg_thresh -- threshold on red + blue used by create_posneg_mask
        size_thresh -- lower cell size in pixels, or 'auto'
        marker_thresh -- marker threshold, 'auto', or None (None disables the marker)
        size_thresh_upper -- upper cell size in pixels, or None

    Returns:
        overlay -- copy of ``orig`` with positive boundaries (255, 0, 0) and negative boundaries (0, 0, 255)
        refined -- zeros shaped like ``seg``, with 255 set in:
                   channel 0 on positive interiors,
                   channel 2 on negative interiors,
                   channel 1 on all boundaries
        scoring -- dict with num_total, num_pos, num_neg, percent_pos, prob_thresh,
                   size_thresh, size_thresh_upper and marker_thresh (None when no marker is used)
    """
    mask = create_posneg_mask(seg, seg_thresh)
    mark_background(mask)

    if size_thresh == 'auto':
        size_thresh = calc_default_size_thresh(mask, resolution)

    if marker_thresh is None:
        # No marker threshold requested: ignore the marker image entirely.
        marker_thresh, marker = 0, None
    elif marker_thresh == 'auto':
        marker_thresh = calc_default_marker_thresh(marker)

    counts = compute_cell_classification(mask, marker, size_thresh, marker_thresh, size_thresh_upper)
    enlarge_cell_boundaries(mask)

    num_total = counts['num_total']
    num_pos = counts['num_pos']
    num_neg = counts['num_neg']
    percent_pos = round(num_pos / num_total * 100, 1) if num_pos > 0 else 0

    scoring = {
        'num_total': num_total,
        'num_pos': num_pos,
        'num_neg': num_neg,
        'percent_pos': percent_pos,
        'prob_thresh': seg_thresh,
        'size_thresh': size_thresh,
        'size_thresh_upper': size_thresh_upper,
        'marker_thresh': None if marker is None else marker_thresh,
    }

    positive_border = mask == MASK_BOUNDARY_POSITIVE
    negative_border = mask == MASK_BOUNDARY_NEGATIVE

    overlay = np.copy(orig)
    overlay[positive_border] = (255, 0, 0)
    overlay[negative_border] = (0, 0, 255)

    refined = np.zeros_like(seg)
    refined[mask == MASK_CELL_POSITIVE, 0] = 255
    refined[mask == MASK_CELL_NEGATIVE, 2] = 255
    refined[positive_border | negative_border, 1] = 255

    return overlay, refined, scoring
deepliif/util/__init__.py CHANGED
@@ -14,17 +14,18 @@ from .visualizer import Visualizer
14
14
  from ..postprocessing import imadjust
15
15
  import cv2
16
16
 
17
+ import pickle
18
+ import sys
19
+
17
20
  import bioformats
18
21
  import javabridge
19
22
  import bioformats.omexml as ome
20
23
  import tifffile as tf
21
24
 
22
- import pickle
23
- import sys
24
25
 
25
26
  excluding_names = ['Hema', 'DAPI', 'DAPILap2', 'Ki67', 'Seg', 'Marked', 'SegRefined', 'SegOverlaid', 'Marker', 'Lap2']
26
27
  # Image extensions to consider
27
- image_extensions = ['.png', '.jpg', '.tif', '.jpeg', '.svs']
28
+ image_extensions = ['.png', '.jpg', '.tif', '.jpeg']
28
29
 
29
30
 
30
31
  def allowed_file(filename):
@@ -440,6 +441,87 @@ def get_information(filename):
440
441
  return size_x, size_y, size_z, size_c, size_t, pixel_type
441
442
 
442
443
 
444
+
445
+
446
def write_results_to_pickle_file(output_addr, results):
    """
    This function writes data into the pickle file.

    The file is opened in a ``with`` block so it is closed even if pickling fails.

    :param output_addr: The address of the pickle file to write data into (overwritten if it exists).
    :param results: The data to be written into the pickle file.
    :return: None
    """
    with open(output_addr, "wb") as pickle_obj:
        pickle.dump(results, pickle_obj)
456
+
457
+
458
def read_results_from_pickle_file(input_addr):
    """
    This function reads data from a pickle file and returns it.

    The file is opened in a ``with`` block so it is closed even if unpickling fails.

    Security: unpickling can execute arbitrary code. Only use this on files this
    application wrote itself, never on user-supplied or downloaded files.

    :param input_addr: The address to the pickle file.
    :return: The data inside pickle file.
    """
    with open(input_addr, "rb") as pickle_obj:
        return pickle.load(pickle_obj)
468
+
469
def test_diff_original_serialized(model_original, model_serialized, example, verbose=0, threshold=10):
    """Check that a serialized (e.g. TorchScript) model reproduces the original model.

    Both models are called on ``example``. The sum of absolute element-wise differences
    between the two outputs must not exceed ``threshold``.

    Parameters:
        model_original -- callable returning a tensor for ``example``
        model_serialized -- callable returning a tensor of the same shape
        example -- input passed unchanged to both models
        verbose (int) -- when > 0, print shapes, the first ten entries of row 0, minimum
                         absolute values, and the sum/max of the differences
        threshold (float) -- maximum allowed sum of absolute differences (default 10)

    Raises:
        AssertionError -- if the summed absolute difference exceeds ``threshold``.
    """
    orig_res = model_original(example)
    if verbose > 0:
        print('Original:')
        print(orig_res.shape)
        print(orig_res[0, 0:10])
        print('min abs value:{}'.format(torch.min(torch.abs(orig_res))))

    ts_res = model_serialized(example)
    if verbose > 0:
        print('Torchscript:')
        print(ts_res.shape)
        print(ts_res[0, 0:10])
        print('min abs value:{}'.format(torch.min(torch.abs(ts_res))))

    abs_diff = torch.abs(orig_res - ts_res)
    diff_sum = torch.sum(abs_diff)
    if verbose > 0:
        print('Dif sum:')
        print(diff_sum)
        print('max dif:{}'.format(torch.max(abs_diff)))

    # Kept as ``assert`` so callers that expect AssertionError keep working.
    assert diff_sum <= threshold, f"Sum of difference in predicted values {diff_sum} is larger than threshold {threshold}"


# The name starts with ``test_``. Without this, pytest would collect the helper as a
# test in any test module that imports it, and then fail to find fixtures for its
# arguments.
test_diff_original_serialized.__test__ = False
493
+
494
def disable_batchnorm_tracking_stats(model):
    """Make every ``torch.nn.BatchNorm2d`` child of ``model`` stop using running stats.

    Affected layers are BatchNorm2d modules (exact type; subclasses are untouched) that
    are children of some module in ``model.modules()``. For each one, this function:
      - sets ``track_running_stats`` to False;
      - stashes ``running_mean`` / ``running_var`` in ``running_mean_backup`` /
        ``running_var_backup``;
      - sets the running buffers to None.
    Per the PyTorch source linked below, BatchNorm then normalizes with batch statistics
    even in eval mode. ``enable_batchnorm_tracking_stats`` restores the stashed buffers.

    Calling this more than once (or reaching a layer twice through a shared submodule)
    keeps the saved statistics. A backup is only replaced when there is a real
    (non-None) buffer to store.

    Returns the same ``model`` object, modified in place.
    """
    # https://discuss.pytorch.org/t/performance-highly-degraded-when-eval-is-activated-in-the-test-phase/3323/16
    # https://discuss.pytorch.org/t/performance-highly-degraded-when-eval-is-activated-in-the-test-phase/3323/67
    # https://github.com/pytorch/pytorch/blob/ca39c5b04e30a67512589cafbd9d063cc17168a5/torch/nn/modules/batchnorm.py#L158
    for module in model.modules():
        for child in module.children():
            if type(child) is not torch.nn.BatchNorm2d:
                continue
            child.track_running_stats = False
            # Without these guards a second call would overwrite the backup with None.
            if child.running_mean is not None or not hasattr(child, 'running_mean_backup'):
                child.running_mean_backup = child.running_mean
            child.running_mean = None
            if child.running_var is not None or not hasattr(child, 'running_var_backup'):
                child.running_var_backup = child.running_var
            child.running_var = None
    return model
507
+
508
def enable_batchnorm_tracking_stats(model):
    """Undo disable_batchnorm_tracking_stats() for every BatchNorm2d child of ``model``.

    This is needed during training when validation loss/metrics are computed. Eval mode
    triggers disable_batchnorm_tracking_stats(); when switching back to train mode, this
    function:
      - re-enables ``track_running_stats``;
      - puts the running statistics saved in ``running_mean_backup`` /
        ``running_var_backup`` back into the layer.

    Returns the same ``model`` object, modified in place.

    Raises:
        AssertionError -- if a BatchNorm2d layer has no saved backup.
    """
    batchnorm_layers = (
        layer
        for parent in model.modules()
        for layer in parent.children()
        if type(layer) == torch.nn.BatchNorm2d
    )
    for layer in batchnorm_layers:
        layer.track_running_stats = True
        assert hasattr(layer, 'running_mean_backup') and hasattr(layer, 'running_var_backup'), 'enable_batchnorm_tracking_stats() is supposed to be executed after disable_batchnorm_tracking_stats() is applied'
        layer.running_mean = layer.running_mean_backup
        layer.running_var = layer.running_var_backup
    return model
523
+
524
+
443
525
  def write_big_tiff_file(output_addr, img, tile_size):
444
526
  """
445
527
  This function write the image into a big tiff file using the tiling and compression.
@@ -581,64 +663,3 @@ def write_ome_tiff_file_array(results_array, output_addr, size_t, size_z, size_c
581
663
  output_addr,
582
664
  SizeT=size_t, SizeZ=size_z, SizeC=len(channel_names), SizeX=size_x, SizeY=size_y,
583
665
  channel_names=channel_names)
584
-
585
-
586
def write_results_to_pickle_file(output_addr, results):
    """
    This function writes data into the pickle file.

    This is the pre-1.1.12 copy of the helper. The diff deletes it here and adds the
    same definition after get_information().

    NOTE(review): the file is not closed if pickle.dump raises; a ``with open(...)``
    block would guarantee it is closed.

    :param output_addr: The address of the pickle file to write data into.
    :param results: The data to be written into the pickle file.
    :return: None
    """
    pickle_obj = open(output_addr, "wb")
    pickle.dump(results, pickle_obj)
    pickle_obj.close()
596
-
597
-
598
def read_results_from_pickle_file(input_addr):
    """
    This function reads data from a pickle file and returns it.

    This is the pre-1.1.12 copy of the helper. The diff deletes it here and adds the
    same definition after get_information().

    NOTE(review): the file is not closed if pickle.load raises. Unpickling can also
    execute arbitrary code, so only use this on trusted files.

    :param input_addr: The address to the pickle file.
    :return: The data inside pickle file.
    """
    pickle_obj = open(input_addr, "rb")
    results = pickle.load(pickle_obj)
    pickle_obj.close()
    return results
608
-
609
def test_diff_original_serialized(model_original,model_serialized,example,verbose=0):
    """Assert that a serialized model's output matches the original model's output.

    This is the pre-1.1.12 copy of the helper. The diff deletes it here and adds the
    same definition after get_information().

    Both models are called on ``example``. The sum of absolute differences between
    their outputs must be at most ``threshold`` (10). With ``verbose > 0``, shapes,
    the first ten entries of row 0, and difference statistics are printed.

    NOTE(review): the ``test_`` prefix makes pytest treat this helper as a test when it
    is imported into a test module.
    """
    threshold = 10

    orig_res = model_original(example)
    if verbose > 0:
        print('Original:')
        print(orig_res.shape)
        print(orig_res[0, 0:10])
        print('min abs value:{}'.format(torch.min(torch.abs(orig_res))))

    ts_res = model_serialized(example)
    if verbose > 0:
        print('Torchscript:')
        print(ts_res.shape)
        print(ts_res[0, 0:10])
        print('min abs value:{}'.format(torch.min(torch.abs(ts_res))))

    abs_diff = torch.abs(orig_res-ts_res)
    if verbose > 0:
        print('Dif sum:')
        print(torch.sum(abs_diff))
        print('max dif:{}'.format(torch.max(abs_diff)))

    assert torch.sum(abs_diff) <= threshold, f"Sum of difference in predicted values {torch.sum(abs_diff)} is larger than threshold {threshold}"
633
-
634
def disable_batchnorm_tracking_stats(model):
    """Turn off running-statistics tracking for every BatchNorm2d child of ``model``.

    This is the pre-1.1.12 version. It discards the running buffers (sets them to None)
    without keeping a copy. The 1.1.12 replacement additionally saves them in
    ``running_mean_backup`` / ``running_var_backup``, which
    enable_batchnorm_tracking_stats() asserts are present.

    Returns the same ``model`` object, modified in place.
    """
    # https://discuss.pytorch.org/t/performance-highly-degraded-when-eval-is-activated-in-the-test-phase/3323/16
    # https://discuss.pytorch.org/t/performance-highly-degraded-when-eval-is-activated-in-the-test-phase/3323/67
    # https://github.com/pytorch/pytorch/blob/ca39c5b04e30a67512589cafbd9d063cc17168a5/torch/nn/modules/batchnorm.py#L158
    for m in model.modules():
        for child in m.children():
            if type(child) == torch.nn.BatchNorm2d:
                child.track_running_stats = False
                child.running_mean = None
                child.running_var = None
    return model