deepliif 1.1.10__py3-none-any.whl → 1.1.12__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- cli.py +354 -67
- deepliif/data/__init__.py +7 -7
- deepliif/data/aligned_dataset.py +2 -3
- deepliif/data/unaligned_dataset.py +38 -19
- deepliif/models/CycleGAN_model.py +282 -0
- deepliif/models/DeepLIIFExt_model.py +47 -25
- deepliif/models/DeepLIIF_model.py +69 -19
- deepliif/models/SDG_model.py +57 -26
- deepliif/models/__init__ - run_dask_multi dev.py +943 -0
- deepliif/models/__init__ - timings.py +764 -0
- deepliif/models/__init__.py +354 -232
- deepliif/models/att_unet.py +199 -0
- deepliif/models/base_model.py +32 -8
- deepliif/models/networks.py +108 -34
- deepliif/options/__init__.py +49 -5
- deepliif/postprocessing.py +1034 -227
- deepliif/postprocessing__OLD__DELETE.py +440 -0
- deepliif/util/__init__.py +290 -64
- deepliif/util/visualizer.py +106 -19
- {deepliif-1.1.10.dist-info → deepliif-1.1.12.dist-info}/METADATA +81 -20
- deepliif-1.1.12.dist-info/RECORD +40 -0
- deepliif-1.1.10.dist-info/RECORD +0 -35
- {deepliif-1.1.10.dist-info → deepliif-1.1.12.dist-info}/LICENSE.md +0 -0
- {deepliif-1.1.10.dist-info → deepliif-1.1.12.dist-info}/WHEEL +0 -0
- {deepliif-1.1.10.dist-info → deepliif-1.1.12.dist-info}/entry_points.txt +0 -0
- {deepliif-1.1.10.dist-info → deepliif-1.1.12.dist-info}/top_level.txt +0 -0
deepliif/util/__init__.py
CHANGED
|
@@ -14,17 +14,18 @@ from .visualizer import Visualizer
|
|
|
14
14
|
from ..postprocessing import imadjust
|
|
15
15
|
import cv2
|
|
16
16
|
|
|
17
|
+
import pickle
|
|
18
|
+
import sys
|
|
19
|
+
|
|
17
20
|
import bioformats
|
|
18
21
|
import javabridge
|
|
19
22
|
import bioformats.omexml as ome
|
|
20
23
|
import tifffile as tf
|
|
21
24
|
|
|
22
|
-
import pickle
|
|
23
|
-
import sys
|
|
24
25
|
|
|
25
26
|
excluding_names = ['Hema', 'DAPI', 'DAPILap2', 'Ki67', 'Seg', 'Marked', 'SegRefined', 'SegOverlaid', 'Marker', 'Lap2']
|
|
26
27
|
# Image extensions to consider
|
|
27
|
-
image_extensions = ['.png', '.jpg', '.tif', '.jpeg'
|
|
28
|
+
image_extensions = ['.png', '.jpg', '.tif', '.jpeg']
|
|
28
29
|
|
|
29
30
|
|
|
30
31
|
def allowed_file(filename):
|
|
@@ -118,6 +119,211 @@ def stitch_tile(img, tile, tile_size, overlap_size, i, j):
|
|
|
118
119
|
img.paste(tile, (i * tile_size, j * tile_size))
|
|
119
120
|
|
|
120
121
|
|
|
122
|
+
class InferenceTiler:
    """
    Iterable class to tile image(s) and stitch result tiles together.

    To perform inference on a large image, that image needs to be tiled
    into smaller tiles that can be run individually and then stitched
    back together. This class wraps that functionality as an iterable
    object that accepts a single image, or a list of equally sized images
    when multiple images are taken as input for inference.

    An overlap size can be specified so that neighboring tiles overlap at
    the edges, helping to reduce seams or other artifacts near the edge of
    a tile. Padding of a solid color around the perimeter of the tile is
    also possible, if needed. The specified tile size includes the overlap
    and pad sizes, so a tile size of 512 with an overlap size of 32 and a
    pad size of 16 has a central area of 416 pixels that is stitched into
    the result image.

    Example Usage
    -------------
        tiler = InferenceTiler(img, 512, 32)
        for tile in tiler:
            result_tiles = infer(tile)
            tiler.stitch(result_tiles)
        images = tiler.results()
    """

    def __init__(self, orig, tile_size, overlap_size=0, pad_size=0, pad_color=(255, 255, 255)):
        """
        Initialize for tiling an image or list of images.

        Parameters
        ----------
        orig : Image | list(Image) | tuple(Image)
            Original image or list of images to be tiled. All images in a
            list must have the same size. The caller's list is never modified.
        tile_size: int
            Size (width and height) of the tiles to be generated.
        overlap_size: int [default: 0]
            Amount of overlap on each side of the tile.
        pad_size: int [default: 0]
            Amount of solid color padding around perimeter of tile.
        pad_color: tuple(int, int, int) [default: (255,255,255)]
            RGB color to use for padding.

        Raises
        ------
        ValueError
            If a size argument is out of range, the image list is empty,
            the images differ in size or are empty, or overlap/pad leave
            no center region.
        """

        if tile_size <= 0:
            raise ValueError('InferenceTiler input tile_size must be positive and non-zero')
        if overlap_size < 0:
            raise ValueError('InferenceTiler input overlap_size must be positive or zero')
        if pad_size < 0:
            raise ValueError('InferenceTiler input pad_size must be positive or zero')

        # Lists and tuples are treated as multiple images; anything else is a single image.
        self.single_orig = not isinstance(orig, (list, tuple))
        # Always work on a fresh list so the padding below never mutates the caller's list.
        orig = [orig] if self.single_orig else list(orig)
        if not orig:
            raise ValueError('InferenceTiler input image list is empty')

        first_size = orig[0].size
        if any(im.size != first_size for im in orig[1:]):
            raise ValueError('InferenceTiler input images do not have the same size.')
        self.orig_width = orig[0].width
        self.orig_height = orig[0].height
        if self.orig_width <= 0 or self.orig_height <= 0:
            # A zero-sized image would make the mirroring loop below spin forever.
            raise ValueError('InferenceTiler input images must have non-zero width and height')

        # patch size to extract from input image, which is then padded to tile size
        patch_size = tile_size - (2 * pad_size)

        # make sure width and height are both at least patch_size
        orig = [self._mirror_to_min_size(im, patch_size) for im in orig]
        self.image_width = orig[0].width
        self.image_height = orig[0].height

        # No overlap is needed along an axis that fits entirely in one patch.
        overlap_width = 0 if patch_size >= self.image_width else overlap_size
        overlap_height = 0 if patch_size >= self.image_height else overlap_size
        center_width = patch_size - (2 * overlap_width)
        center_height = patch_size - (2 * overlap_height)
        if center_width <= 0 or center_height <= 0:
            raise ValueError('InferenceTiler combined overlap_size and pad_size are too large')

        self.c0x = pad_size  # crop offset for left of non-pad content in result tile
        self.c0y = pad_size  # crop offset for top of non-pad content in result tile
        self.c1x = overlap_width + pad_size  # crop offset for left of center region in result tile
        self.c1y = overlap_height + pad_size  # crop offset for top of center region in result tile
        self.c2x = patch_size - overlap_width + pad_size  # crop offset for right of center region in result tile
        self.c2y = patch_size - overlap_height + pad_size  # crop offset for bottom of center region in result tile
        self.c3x = patch_size + pad_size  # crop offset for right of non-pad content in result tile
        self.c3y = patch_size + pad_size  # crop offset for bottom of non-pad content in result tile
        self.p1x = overlap_width  # paste offset for left of center region w.r.t (x,y) coord
        self.p1y = overlap_height  # paste offset for top of center region w.r.t (x,y) coord
        self.p2x = patch_size - overlap_width  # paste offset for right of center region w.r.t (x,y) coord
        self.p2y = patch_size - overlap_height  # paste offset for bottom of center region w.r.t (x,y) coord

        self.overlap_width = overlap_width
        self.overlap_height = overlap_height
        self.patch_size = patch_size
        self.center_width = center_width
        self.center_height = center_height

        self.orig = orig
        self.tile_size = tile_size
        self.pad_size = pad_size
        self.pad_color = pad_color
        self.res = {}

        # Position of the most recently yielded tile; stitch() pastes there.
        self.x = None
        self.y = None
        self._x_starts = self._tile_starts(self.image_width, patch_size, center_width)
        self._y_starts = self._tile_starts(self.image_height, patch_size, center_height)

    @staticmethod
    def _mirror_to_min_size(im, min_size):
        """Grow ``im`` by mirror/flip tiling until each side is >= min_size, cropping grown axes to min_size."""
        if im.width < min_size:
            while im.width < min_size:
                mirrored = ImageOps.mirror(im)
                im = ImageOps.expand(im, (0, 0, im.width, 0))
                im.paste(mirrored, (mirrored.width, 0))
            im = im.crop((0, 0, min_size, im.height))
        if im.height < min_size:
            while im.height < min_size:
                flipped = ImageOps.flip(im)
                im = ImageOps.expand(im, (0, 0, 0, im.height))
                im.paste(flipped, (0, flipped.height))
            im = im.crop((0, 0, im.width, min_size))
        return im

    @staticmethod
    def _tile_starts(image_len, patch_size, step):
        """
        Return the distinct patch start offsets along one axis.

        Offsets advance by ``step``. Any offset whose patch would run past
        the image edge is clamped to ``image_len - patch_size``. Several
        offsets can clamp to that same value, so duplicates are dropped
        (order preserved) to avoid producing identical tiles repeatedly.
        """
        last = image_len - patch_size
        return list(dict.fromkeys(min(start, last) for start in range(0, image_len, step)))

    def __iter__(self):
        """
        Generate the tiles as an iterable.

        Tiles are created and iterated over from top left to bottom
        right, going across the rows. The yielded tile(s) match the
        type of the original input when initialized (either a single
        image or a list of images in the same order as initialized).
        The (x, y) coordinate of the current tile is maintained
        internally for use in the stitch function.
        """

        for y in self._y_starts:
            for x in self._x_starts:
                self.x = x
                self.y = y
                box = (x, y, x + self.patch_size, y + self.patch_size)
                tiles = [im.crop(box) for im in self.orig]
                if self.pad_size != 0:
                    tiles = [ImageOps.expand(t, self.pad_size, self.pad_color) for t in tiles]
                yield tiles[0] if self.single_orig else tiles

    def stitch(self, result_tiles):
        """
        Stitch result tiles into the result images.

        The key names for the dictionary of result tiles are used to
        stitch each tile into its corresponding final image in the
        results attribute. If a result image does not exist for a
        result tile key name, then it will be created. The result tiles
        are stitched at the location from which the most recently
        iterated tile was extracted.

        Parameters
        ----------
        result_tiles : dict(str: Image)
            Dictionary of result tiles from the inference.

        Raises
        ------
        RuntimeError
            If called before any tile has been generated by iteration.
        """

        if self.x is None or self.y is None:
            raise RuntimeError('InferenceTiler.stitch called before a tile was generated')

        x, y = self.x, self.y
        at_left = x == 0
        at_top = y == 0
        at_right = x == self.image_width - self.patch_size
        at_bottom = y == self.image_height - self.patch_size

        for k, tile in result_tiles.items():
            if k not in self.res:
                self.res[k] = Image.new('RGB', (self.image_width, self.image_height))
            if tile.size != (self.tile_size, self.tile_size):
                tile = tile.resize((self.tile_size, self.tile_size))
            res = self.res[k]

            # center region is always stitched
            res.paste(tile.crop((self.c1x, self.c1y, self.c2x, self.c2y)), (x + self.p1x, y + self.p1y))

            # Overlap margins are only kept where the tile touches the image border,
            # since interior margins are covered by neighboring tiles' centers.
            if at_left and at_top:  # top left corner
                res.paste(tile.crop((self.c0x, self.c0y, self.c1x, self.c1y)), (x, y))
            if at_top:  # top row
                res.paste(tile.crop((self.c1x, self.c0y, self.c2x, self.c1y)), (x + self.p1x, y))
            if at_right and at_top:  # top right corner
                res.paste(tile.crop((self.c2x, self.c0y, self.c3x, self.c1y)), (x + self.p2x, y))
            if at_left:  # left column
                res.paste(tile.crop((self.c0x, self.c1y, self.c1x, self.c2y)), (x, y + self.p1y))
            if at_right:  # right column
                res.paste(tile.crop((self.c2x, self.c1y, self.c3x, self.c2y)), (x + self.p2x, y + self.p1y))
            if at_left and at_bottom:  # bottom left corner
                res.paste(tile.crop((self.c0x, self.c2y, self.c1x, self.c3y)), (x, y + self.p2y))
            if at_bottom:  # bottom row
                res.paste(tile.crop((self.c1x, self.c2y, self.c2x, self.c3y)), (x + self.p1x, y + self.p2y))
            if at_right and at_bottom:  # bottom right corner
                res.paste(tile.crop((self.c2x, self.c2y, self.c3x, self.c3y)), (x + self.p2x, y + self.p2y))

    def results(self):
        """
        Return a dictionary of result images.

        The keys for the result images are the same as those used for
        the result tiles in the stitch function. If the original image
        was smaller than the patch size (and was therefore mirror-padded),
        the returned images are new copies cropped back to the original
        size; otherwise the internal result images are returned directly.
        """

        if self.orig_width != self.image_width or self.orig_height != self.image_height:
            return {k: im.crop((0, 0, self.orig_width, self.orig_height)) for k, im in self.res.items()}
        return dict(self.res)
|
|
325
|
+
|
|
326
|
+
|
|
121
327
|
def calculate_background_mean_value(img):
|
|
122
328
|
img = cv2.fastNlMeansDenoisingColored(np.array(img), None, 10, 10, 7, 21)
|
|
123
329
|
img = np.array(img, dtype=float)
|
|
@@ -235,6 +441,87 @@ def get_information(filename):
|
|
|
235
441
|
return size_x, size_y, size_z, size_c, size_t, pixel_type
|
|
236
442
|
|
|
237
443
|
|
|
444
|
+
|
|
445
|
+
|
|
446
|
+
def write_results_to_pickle_file(output_addr, results):
    """
    This function writes data into a pickle file.

    :param output_addr: The address (path) of the pickle file to write data into.
    :param results: The data to be written into the pickle file.
    :return: None
    """
    # The context manager closes the file even if pickling fails part-way.
    with open(output_addr, "wb") as pickle_obj:
        pickle.dump(results, pickle_obj)
|
|
456
|
+
|
|
457
|
+
|
|
458
|
+
def read_results_from_pickle_file(input_addr):
    """
    This function reads data from a pickle file and returns it.

    Security note: unpickling can execute arbitrary code, so only call
    this on files produced by a trusted source.

    :param input_addr: The address (path) to the pickle file.
    :return: The data inside the pickle file.
    """
    # The context manager closes the file even if unpickling raises.
    with open(input_addr, "rb") as pickle_obj:
        return pickle.load(pickle_obj)
|
|
468
|
+
|
|
469
|
+
def test_diff_original_serialized(model_original, model_serialized, example, verbose=0, threshold=10):
    """
    Check that a serialized model reproduces the original model's output.

    Both models are called on ``example``. The sum of absolute element-wise
    differences between the two outputs is compared against ``threshold``.

    :param model_original: Callable (e.g. the original torch module) returning a tensor.
    :param model_serialized: Callable (e.g. the serialized/TorchScript module) returning a tensor.
    :param example: Input passed to both models.
    :param verbose: If > 0, print shapes, leading values and difference statistics.
        The printing indexes ``[0, 0:10]``, so it assumes outputs with at least 2 dims.
    :param threshold: Maximum allowed sum of absolute differences (default 10).
    :raises AssertionError: If the summed difference exceeds ``threshold`` or is NaN.
    """
    orig_res = model_original(example)
    if verbose > 0:
        print('Original:')
        print(orig_res.shape)
        print(orig_res[0, 0:10])
        print('min abs value:{}'.format(torch.min(torch.abs(orig_res))))

    ts_res = model_serialized(example)
    if verbose > 0:
        print('Torchscript:')
        print(ts_res.shape)
        print(ts_res[0, 0:10])
        print('min abs value:{}'.format(torch.min(torch.abs(ts_res))))

    abs_diff = torch.abs(orig_res - ts_res)
    diff_sum = torch.sum(abs_diff)
    if verbose > 0:
        print('Dif sum:')
        print(diff_sum)
        print('max dif:{}'.format(torch.max(abs_diff)))

    # Explicit raise instead of `assert` so the check still runs under `python -O`.
    # `not (<=)` (rather than `>`) keeps NaN differences counted as failures.
    if not (diff_sum <= threshold):
        raise AssertionError(f"Sum of difference in predicted values {diff_sum} is larger than threshold {threshold}")


# This helper is named like a test but needs arguments; keep pytest from collecting it.
test_diff_original_serialized.__test__ = False
|
|
493
|
+
|
|
494
|
+
def disable_batchnorm_tracking_stats(model):
    """
    Make BatchNorm2d layers in ``model`` stop using their running statistics.

    For every ``torch.nn.BatchNorm2d`` child module (exact type match),
    this does three things:

    - sets ``track_running_stats`` to False;
    - stashes ``running_mean``/``running_var`` in
      ``running_mean_backup``/``running_var_backup``;
    - sets the live buffers to None.

    ``enable_batchnorm_tracking_stats`` restores them. Calling this more
    than once is safe: an existing backup is not overwritten with None
    once the live statistics have already been moved aside.

    :param model: torch.nn.Module, modified in place.
    :return: The same ``model`` object.
    """
    # https://discuss.pytorch.org/t/performance-highly-degraded-when-eval-is-activated-in-the-test-phase/3323/16
    # https://discuss.pytorch.org/t/performance-highly-degraded-when-eval-is-activated-in-the-test-phase/3323/67
    # https://github.com/pytorch/pytorch/blob/ca39c5b04e30a67512589cafbd9d063cc17168a5/torch/nn/modules/batchnorm.py#L158
    for m in model.modules():
        for child in m.children():
            if type(child) is torch.nn.BatchNorm2d:
                child.track_running_stats = False
                # Only (re)write a backup when there is a tensor to save, or when no
                # backup exists yet; otherwise a repeated call would replace the saved
                # statistics with None and they could never be restored.
                if child.running_mean is not None or not hasattr(child, 'running_mean_backup'):
                    child.running_mean_backup = child.running_mean
                if child.running_var is not None or not hasattr(child, 'running_var_backup'):
                    child.running_var_backup = child.running_var
                child.running_mean = None
                child.running_var = None
    return model
|
|
507
|
+
|
|
508
|
+
def enable_batchnorm_tracking_stats(model):
    """
    Restore running statistics on BatchNorm2d layers previously disabled.

    This is needed during training when val set loss/metrics calculation is enabled.
    In this case, we need to switch to eval mode for inference, which triggers
    disable_batchnorm_tracking_stats(). After the evaluation, the model should be
    set back to train mode, where running stats are restored for batchnorm layers.

    For every ``torch.nn.BatchNorm2d`` child module (exact type match), this
    sets ``track_running_stats`` to True and copies
    ``running_mean_backup``/``running_var_backup`` back into
    ``running_mean``/``running_var``. All layers are validated before any
    is modified, so a failure leaves the model untouched.

    :param model: torch.nn.Module, modified in place.
    :return: The same ``model`` object.
    :raises AssertionError: If a BatchNorm2d layer has no saved backup, i.e.
        disable_batchnorm_tracking_stats() was not applied first.
    """
    bn_layers = [
        child
        for m in model.modules()
        for child in m.children()
        if type(child) is torch.nn.BatchNorm2d
    ]

    # Explicit raise (not `assert`) so the check survives `python -O`; done up
    # front so no layer is left half-restored when another one is invalid.
    for child in bn_layers:
        if not (hasattr(child, 'running_mean_backup') and hasattr(child, 'running_var_backup')):
            raise AssertionError('enable_batchnorm_tracking_stats() is supposed to be executed after disable_batchnorm_tracking_stats() is applied')

    for child in bn_layers:
        child.track_running_stats = True
        child.running_mean = child.running_mean_backup
        child.running_var = child.running_var_backup
    return model
|
|
523
|
+
|
|
524
|
+
|
|
238
525
|
def write_big_tiff_file(output_addr, img, tile_size):
|
|
239
526
|
"""
|
|
240
527
|
This function write the image into a big tiff file using the tiling and compression.
|
|
@@ -376,64 +663,3 @@ def write_ome_tiff_file_array(results_array, output_addr, size_t, size_z, size_c
|
|
|
376
663
|
output_addr,
|
|
377
664
|
SizeT=size_t, SizeZ=size_z, SizeC=len(channel_names), SizeX=size_x, SizeY=size_y,
|
|
378
665
|
channel_names=channel_names)
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
def write_results_to_pickle_file(output_addr, results):
|
|
382
|
-
"""
|
|
383
|
-
This function writes data into the pickle file.
|
|
384
|
-
:param output_addr: The address of the pickle file to write data into.
|
|
385
|
-
:param results: The data to be written into the pickle file.
|
|
386
|
-
:return:
|
|
387
|
-
"""
|
|
388
|
-
pickle_obj = open(output_addr, "wb")
|
|
389
|
-
pickle.dump(results, pickle_obj)
|
|
390
|
-
pickle_obj.close()
|
|
391
|
-
|
|
392
|
-
|
|
393
|
-
def read_results_from_pickle_file(input_addr):
|
|
394
|
-
"""
|
|
395
|
-
This function reads data from a pickle file and returns it.
|
|
396
|
-
:param input_addr: The address to the pickle file.
|
|
397
|
-
:return: The data inside pickle file.
|
|
398
|
-
"""
|
|
399
|
-
pickle_obj = open(input_addr, "rb")
|
|
400
|
-
results = pickle.load(pickle_obj)
|
|
401
|
-
pickle_obj.close()
|
|
402
|
-
return results
|
|
403
|
-
|
|
404
|
-
def test_diff_original_serialized(model_original,model_serialized,example,verbose=0):
|
|
405
|
-
threshold = 10
|
|
406
|
-
|
|
407
|
-
orig_res = model_original(example)
|
|
408
|
-
if verbose > 0:
|
|
409
|
-
print('Original:')
|
|
410
|
-
print(orig_res.shape)
|
|
411
|
-
print(orig_res[0, 0:10])
|
|
412
|
-
print('min abs value:{}'.format(torch.min(torch.abs(orig_res))))
|
|
413
|
-
|
|
414
|
-
ts_res = model_serialized(example)
|
|
415
|
-
if verbose > 0:
|
|
416
|
-
print('Torchscript:')
|
|
417
|
-
print(ts_res.shape)
|
|
418
|
-
print(ts_res[0, 0:10])
|
|
419
|
-
print('min abs value:{}'.format(torch.min(torch.abs(ts_res))))
|
|
420
|
-
|
|
421
|
-
abs_diff = torch.abs(orig_res-ts_res)
|
|
422
|
-
if verbose > 0:
|
|
423
|
-
print('Dif sum:')
|
|
424
|
-
print(torch.sum(abs_diff))
|
|
425
|
-
print('max dif:{}'.format(torch.max(abs_diff)))
|
|
426
|
-
|
|
427
|
-
assert torch.sum(abs_diff) <= threshold, f"Sum of difference in predicted values {torch.sum(abs_diff)} is larger than threshold {threshold}"
|
|
428
|
-
|
|
429
|
-
def disable_batchnorm_tracking_stats(model):
|
|
430
|
-
# https://discuss.pytorch.org/t/performance-highly-degraded-when-eval-is-activated-in-the-test-phase/3323/16
|
|
431
|
-
# https://discuss.pytorch.org/t/performance-highly-degraded-when-eval-is-activated-in-the-test-phase/3323/67
|
|
432
|
-
# https://github.com/pytorch/pytorch/blob/ca39c5b04e30a67512589cafbd9d063cc17168a5/torch/nn/modules/batchnorm.py#L158
|
|
433
|
-
for m in model.modules():
|
|
434
|
-
for child in m.children():
|
|
435
|
-
if type(child) == torch.nn.BatchNorm2d:
|
|
436
|
-
child.track_running_stats = False
|
|
437
|
-
child.running_mean = None
|
|
438
|
-
child.running_var = None
|
|
439
|
-
return model
|
deepliif/util/visualizer.py
CHANGED
|
@@ -165,6 +165,7 @@ class Visualizer():
|
|
|
165
165
|
if ncols > 0: # show all the images in one visdom panel
|
|
166
166
|
ncols = min(ncols, len(visuals))
|
|
167
167
|
h, w = next(iter(visuals.values())).shape[:2]
|
|
168
|
+
|
|
168
169
|
table_css = """<style>
|
|
169
170
|
table {border-collapse: separate; border-spacing: 4px; white-space: nowrap; text-align: center}
|
|
170
171
|
table td {width: % dpx; height: % dpx; padding: 4px; outline: 4px solid black}
|
|
@@ -176,13 +177,16 @@ class Visualizer():
|
|
|
176
177
|
images = []
|
|
177
178
|
idx = 0
|
|
178
179
|
for label, image in visuals.items():
|
|
179
|
-
|
|
180
|
-
|
|
181
|
-
|
|
182
|
-
|
|
183
|
-
|
|
184
|
-
|
|
185
|
-
|
|
180
|
+
if image.shape[1] != 3:
|
|
181
|
+
pass
|
|
182
|
+
else:
|
|
183
|
+
image_numpy = util.tensor2im(image)
|
|
184
|
+
label_html_row += '<td>%s</td>' % label
|
|
185
|
+
images.append(image_numpy.transpose([2, 0, 1]))
|
|
186
|
+
idx += 1
|
|
187
|
+
if idx % ncols == 0:
|
|
188
|
+
label_html += '<tr>%s</tr>' % label_html_row
|
|
189
|
+
label_html_row = ''
|
|
186
190
|
|
|
187
191
|
white_image = np.ones_like(image_numpy.transpose([2, 0, 1])) * 255
|
|
188
192
|
while idx % ncols != 0:
|
|
@@ -191,6 +195,7 @@ class Visualizer():
|
|
|
191
195
|
idx += 1
|
|
192
196
|
if label_html_row != '':
|
|
193
197
|
label_html += '<tr>%s</tr>' % label_html_row
|
|
198
|
+
|
|
194
199
|
|
|
195
200
|
try:
|
|
196
201
|
self.vis.images(images, nrow=ncols, win=self.display_id + 1,
|
|
@@ -248,6 +253,10 @@ class Visualizer():
|
|
|
248
253
|
# if having 2 processes, each process obtains 50% of the data (effective dataset_size divided by half), the effective counter ratio shall multiply by 2 to compensate that
|
|
249
254
|
n_proc = int(os.getenv('WORLD_SIZE',1))
|
|
250
255
|
counter_ratio = counter_ratio * n_proc
|
|
256
|
+
|
|
257
|
+
self.plot_data_update_train = False
|
|
258
|
+
self.plot_data_update_val = False
|
|
259
|
+
self.plot_data_update_metrics = False
|
|
251
260
|
|
|
252
261
|
if self.remote:
|
|
253
262
|
fn = 'plot_current_losses.pickle'
|
|
@@ -263,20 +272,98 @@ class Visualizer():
|
|
|
263
272
|
exec(f'{self.remote_transfer_cmd_function}("{path_source}")')
|
|
264
273
|
else:
|
|
265
274
|
if not hasattr(self, 'plot_data'):
|
|
266
|
-
self.plot_data = {'X': [], '
|
|
267
|
-
|
|
268
|
-
|
|
275
|
+
self.plot_data = {'X': [], 'X_val':[], 'X_metrics':[],
|
|
276
|
+
'Y': [], 'Y_val':[], 'Y_metrics':[],
|
|
277
|
+
'legend': [], 'legend_val': [], 'legend_metrics':[]}
|
|
278
|
+
for k in list(losses.keys()):
|
|
279
|
+
if k.endswith('_val'):
|
|
280
|
+
self.plot_data['legend_val'].append(k)
|
|
281
|
+
elif k.startswith(('G_','D_')):
|
|
282
|
+
self.plot_data['legend'].append(k)
|
|
283
|
+
else:
|
|
284
|
+
self.plot_data['legend_metrics'].append(k)
|
|
285
|
+
|
|
286
|
+
# check if all names in losses dict have been seen
|
|
287
|
+
# currently we assume the three types of metrics (train loss, val loss, other metrics) can come into the losses dict
|
|
288
|
+
# at any step, but each type will join or leave the dict as a whole (i.e., train loss metrics will either all appear or all be missing)
|
|
289
|
+
for k in list(losses.keys()):
|
|
290
|
+
if k.endswith('_val'):
|
|
291
|
+
if k not in self.plot_data['legend_val']:
|
|
292
|
+
self.plot_data['legend_val'].append(k)
|
|
293
|
+
elif k.startswith(('G_','D_')):
|
|
294
|
+
if k not in self.plot_data['legend']:
|
|
295
|
+
self.plot_data['legend'].append(k)
|
|
296
|
+
else:
|
|
297
|
+
if k not in self.plot_data['legend_metrics']:
|
|
298
|
+
self.plot_data['legend_metrics'].append(k)
|
|
299
|
+
|
|
300
|
+
# update training loss
|
|
301
|
+
print('update training loss')
|
|
302
|
+
if len(self.plot_data['legend']) > 0:
|
|
303
|
+
if self.plot_data['legend'][0] in losses:
|
|
304
|
+
self.plot_data['X'].append(epoch + counter_ratio)
|
|
305
|
+
self.plot_data['Y'].append([losses[k] for k in self.plot_data['legend']])
|
|
306
|
+
self.plot_data_update_train = True
|
|
307
|
+
|
|
308
|
+
# update validation loss
|
|
309
|
+
print('update validation loss')
|
|
310
|
+
if len(self.plot_data['legend_val']) > 0:
|
|
311
|
+
if self.plot_data['legend_val'][0] in losses:
|
|
312
|
+
self.plot_data['X_val'].append(epoch + counter_ratio)
|
|
313
|
+
self.plot_data['Y_val'].append([losses[k] for k in self.plot_data['legend_val']])
|
|
314
|
+
self.plot_data_update_val = True
|
|
315
|
+
|
|
316
|
+
# update other calculated metrics
|
|
317
|
+
print('update other metrics')
|
|
318
|
+
if len(self.plot_data['legend_metrics']) > 0:
|
|
319
|
+
if self.plot_data['legend_metrics'][0] in losses:
|
|
320
|
+
self.plot_data['X_metrics'].append(epoch + counter_ratio)
|
|
321
|
+
self.plot_data['Y_metrics'].append([losses[k] for k in self.plot_data['legend_metrics']])
|
|
322
|
+
self.plot_data_update_metrics = True
|
|
269
323
|
|
|
270
324
|
try:
|
|
271
|
-
self.
|
|
272
|
-
|
|
273
|
-
|
|
274
|
-
|
|
275
|
-
|
|
276
|
-
|
|
277
|
-
|
|
278
|
-
|
|
279
|
-
|
|
325
|
+
if self.plot_data_update_train:
|
|
326
|
+
print('plotting train loss')
|
|
327
|
+
self.vis.line(
|
|
328
|
+
X=np.stack([np.array(self.plot_data['X'])] * len(self.plot_data['legend']), 1),
|
|
329
|
+
Y=np.array(self.plot_data['Y']),
|
|
330
|
+
opts={
|
|
331
|
+
'title': self.name + ' train loss over time',
|
|
332
|
+
'legend': self.plot_data['legend'],
|
|
333
|
+
'xlabel': 'epoch',
|
|
334
|
+
'ylabel': 'loss'},
|
|
335
|
+
win = 'train',
|
|
336
|
+
#env=self.display_id
|
|
337
|
+
)
|
|
338
|
+
|
|
339
|
+
if self.plot_data_update_val:
|
|
340
|
+
print('plotting val loss')
|
|
341
|
+
self.vis.line(
|
|
342
|
+
X=np.stack([np.array(self.plot_data['X_val'])] * len(self.plot_data['legend_val']), 1),
|
|
343
|
+
Y=np.array(self.plot_data['Y_val']),
|
|
344
|
+
opts={
|
|
345
|
+
'title': self.name + ' val loss over time',
|
|
346
|
+
'legend': self.plot_data['legend_val'],
|
|
347
|
+
'xlabel': 'epoch',
|
|
348
|
+
'ylabel': 'loss'},
|
|
349
|
+
win = 'val',
|
|
350
|
+
#env=self.display_id
|
|
351
|
+
)
|
|
352
|
+
|
|
353
|
+
if self.plot_data_update_metrics:
|
|
354
|
+
print('plotting other metrics')
|
|
355
|
+
self.vis.line(
|
|
356
|
+
X=np.stack([np.array(self.plot_data['X_metrics'])] * len(self.plot_data['legend_metrics']), 1),
|
|
357
|
+
Y=np.array(self.plot_data['Y_metrics']),
|
|
358
|
+
opts={
|
|
359
|
+
'title': self.name + ' metrics over time',
|
|
360
|
+
'legend': self.plot_data['legend_metrics'],
|
|
361
|
+
'xlabel': 'epoch',
|
|
362
|
+
'ylabel': 'metrics'},
|
|
363
|
+
win = 'metrics',
|
|
364
|
+
#env=self.display_id
|
|
365
|
+
)
|
|
366
|
+
|
|
280
367
|
except VisdomExceptionBase:
|
|
281
368
|
self.create_visdom_connections()
|
|
282
369
|
|