konfai-1.0.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

@@ -0,0 +1,536 @@
+ import importlib
+ import torch
+ import numpy as np
+ import SimpleITK as sitk
+ from abc import ABC, abstractmethod
+ import torch.nn.functional as F
+ from typing import Any, Union
+ 
+ from KonfAI.konfai.utils.utils import _getModule, NeedDevice, _resample_affine, _affine_matrix
+ from KonfAI.konfai.utils.dataset import Dataset, Attribute, data_to_image, image_to_data
+ from KonfAI.konfai.utils.config import config
+ 
+ class Transform(NeedDevice, ABC):
+ 
+     def __init__(self) -> None:
+         self.datasets: list[Dataset] = []
+ 
+     def setDatasets(self, datasets: list[Dataset]):
+         self.datasets = datasets
+ 
+     def transformShape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
+         return shape
+ 
+     @abstractmethod
+     def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+         pass
+ 
+     @abstractmethod
+     def inverse(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+         pass
+ 
+ class TransformLoader:
+ 
+     @config()
+     def __init__(self) -> None:
+         pass
+ 
+     def getTransform(self, classpath: str, DL_args: str) -> Transform:
+         module, name = _getModule(classpath, "transform")
+         return config("{}.{}".format(DL_args, classpath))(getattr(importlib.import_module(module), name))(config=None)
+ 
+ class Clip(Transform):
+ 
+     def __init__(self, min_value: float = -1024, max_value: float = 1024, saveClip_min: bool = False, saveClip_max: bool = False) -> None:
+         super().__init__()
+         assert max_value > min_value
+         self.min_value = min_value
+         self.max_value = max_value
+         self.saveClip_min = saveClip_min
+         self.saveClip_max = saveClip_max
+ 
+     def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+         input[input < self.min_value] = self.min_value
+         input[input > self.max_value] = self.max_value
+         if self.saveClip_min:
+             cache_attribute["Min"] = self.min_value
+         if self.saveClip_max:
+             cache_attribute["Max"] = self.max_value
+         return input
+ 
+     def inverse(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+         # Clipping is lossy, so it cannot be undone.
+         return input
+ 
+ class Normalize(Transform):
+ 
+     def __init__(self, lazy: bool = False, channels: Union[list[int], None] = None, min_value: float = -1, max_value: float = 1) -> None:
+         super().__init__()
+         assert max_value > min_value
+         self.lazy = lazy
+         self.min_value = min_value
+         self.max_value = max_value
+         self.channels = channels
+ 
+     def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+         if "Min" not in cache_attribute:
+             cache_attribute["Min"] = torch.min(input[self.channels]) if self.channels else torch.min(input)
+         if "Max" not in cache_attribute:
+             cache_attribute["Max"] = torch.max(input[self.channels]) if self.channels else torch.max(input)
+         if not self.lazy:
+             input_min = float(cache_attribute["Min"])
+             input_max = float(cache_attribute["Max"])
+             norm = input_max - input_min
+             assert norm != 0
+             if self.channels:
+                 for channel in self.channels:
+                     input[channel] = (self.max_value - self.min_value) * (input[channel] - input_min) / norm + self.min_value
+             else:
+                 input = (self.max_value - self.min_value) * (input - input_min) / norm + self.min_value
+         return input
+ 
+     def inverse(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+         if self.lazy:
+             return input
+         input_min = float(cache_attribute.pop("Min"))
+         input_max = float(cache_attribute.pop("Max"))
+         return (input - self.min_value) * (input_max - input_min) / (self.max_value - self.min_value) + input_min
+ 
+ class Standardize(Transform):
+ 
+     def __init__(self, lazy: bool = False, mean: Union[list[float], None] = None, std: Union[list[float], None] = None) -> None:
+         super().__init__()
+         self.lazy = lazy
+         self.mean = mean
+         self.std = std
+ 
+     def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+         if "Mean" not in cache_attribute:
+             cache_attribute["Mean"] = torch.mean(input.type(torch.float32), dim=[i + 1 for i in range(len(input.shape) - 1)]) if self.mean is None else torch.tensor([self.mean])
+         if "Std" not in cache_attribute:
+             cache_attribute["Std"] = torch.std(input.type(torch.float32), dim=[i + 1 for i in range(len(input.shape) - 1)]) if self.std is None else torch.tensor([self.std])
+ 
+         if self.lazy:
+             return input
+         mean = cache_attribute.get_tensor("Mean").view(-1, *[1 for _ in range(len(input.shape) - 1)])
+         std = cache_attribute.get_tensor("Std").view(-1, *[1 for _ in range(len(input.shape) - 1)])
+         return (input - mean) / std
+ 
+     def inverse(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+         if self.lazy:
+             return input
+         # Broadcast the cached statistics as in __call__; float() would fail on
+         # multi-channel Mean/Std tensors.
+         mean = cache_attribute.get_tensor("Mean").view(-1, *[1 for _ in range(len(input.shape) - 1)])
+         std = cache_attribute.get_tensor("Std").view(-1, *[1 for _ in range(len(input.shape) - 1)])
+         cache_attribute.pop("Mean")
+         cache_attribute.pop("Std")
+         return input * std + mean
+ 
+ class TensorCast(Transform):
+ 
+     def __init__(self, dtype: str = "default:float32,int64,int16") -> None:
+         super().__init__()
+         self.dtype: torch.dtype = getattr(torch, dtype)
+ 
+     def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+         cache_attribute["dtype"] = input.dtype
+         return input.type(self.dtype)
+ 
+     def inverse(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+         # The cached dtype round-trips as a string such as "torch.float32";
+         # resolve it with getattr rather than eval.
+         return input.to(getattr(torch, str(cache_attribute.pop("dtype")).replace("torch.", "")))
+ 
+ class Padding(Transform):
+ 
+     def __init__(self, padding: list[int] = [0, 0, 0, 0, 0, 0], mode: str = "default:constant,reflect,replicate,circular") -> None:
+         super().__init__()
+         self.padding = padding
+         self.mode = mode
+ 
+     def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+         if "Origin" in cache_attribute and "Spacing" in cache_attribute and "Direction" in cache_attribute:
+             # Shift the origin so the padded voxels stay aligned in physical space.
+             origin = torch.tensor(cache_attribute.get_np_array("Origin"))
+             matrix = torch.tensor(cache_attribute.get_np_array("Direction").reshape((len(origin), len(origin))))
+             origin = torch.matmul(origin, matrix)
+             for dim in range(len(self.padding) // 2):
+                 origin[-dim - 1] -= self.padding[dim * 2] * cache_attribute.get_np_array("Spacing")[-dim - 1]
+             cache_attribute["Origin"] = torch.matmul(origin, torch.inverse(matrix))
+         mode = self.mode.split(":")
+         value = float(mode[1]) if len(mode) == 2 else 0
+         return F.pad(input.unsqueeze(0), tuple(self.padding), mode[0], value).squeeze(0)
+ 
+     def transformShape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
+         for dim in range(len(self.padding) // 2):
+             shape[-dim - 1] += sum(self.padding[dim * 2:dim * 2 + 2])
+         return shape
+ 
+     def inverse(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+         if "Origin" in cache_attribute and "Spacing" in cache_attribute and "Direction" in cache_attribute:
+             cache_attribute.pop("Origin")
+         slices = [slice(0, shape) for shape in input.shape]
+         for dim in range(len(self.padding) // 2):
+             slices[-dim - 1] = slice(self.padding[dim * 2], input.shape[-dim - 1] - self.padding[dim * 2 + 1])
+         # Indexing with a tuple of slices; a plain list is rejected by recent PyTorch.
+         return input[tuple(slices)]
+ 
+ class Squeeze(Transform):
+ 
+     def __init__(self, dim: int) -> None:
+         super().__init__()
+         self.dim = dim
+ 
+     def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+         return input.squeeze(self.dim)
+ 
+     def inverse(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+         return input.unsqueeze(self.dim)
+ 
+ class Resample(Transform, ABC):
+ 
+     def __init__(self) -> None:
+         super().__init__()
+ 
+     def _resample(self, input: torch.Tensor, size: list[int]) -> torch.Tensor:
+         if input.dtype == torch.uint8:
+             mode = "nearest"
+         elif len(input.shape) < 4:
+             mode = "bilinear"
+         else:
+             mode = "trilinear"
+         return F.interpolate(input.type(torch.float32).unsqueeze(0), size=tuple(size), mode=mode).squeeze(0).type(input.dtype).cpu()
+ 
+     @abstractmethod
+     def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+         pass
+ 
+     @abstractmethod
+     def transformShape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
+         pass
+ 
+     def inverse(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+         # "Size" was recorded twice in __call__ (original shape, then resampled shape);
+         # Attribute keeps a history, so the second pop retrieves the original shape.
+         _ = cache_attribute.pop_np_array("Size")
+         size = cache_attribute.pop_np_array("Size")
+         _ = cache_attribute.pop_np_array("Spacing")
+         return self._resample(input, [int(s) for s in size])
+ 
+ class ResampleIsotropic(Resample):
+ 
+     def __init__(self, spacing: list[float] = [1., 1., 1.]) -> None:
+         super().__init__()
+         self.spacing = torch.tensor(spacing, dtype=torch.float64)
+ 
+     def transformShape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
+         assert "Spacing" in cache_attribute, "Error no spacing"
+         resize_factor = self.spacing / cache_attribute.get_tensor("Spacing").flip(0)
+         return [int(x) for x in (torch.tensor(shape) * 1 / resize_factor)]
+ 
+     def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+         assert "Spacing" in cache_attribute, "Error no spacing"
+         resize_factor = self.spacing / cache_attribute.get_tensor("Spacing").flip(0)
+         cache_attribute["Spacing"] = self.spacing.flip(0)
+         # Record both the original and the resampled shape so inverse() can restore them.
+         cache_attribute["Size"] = np.asarray([int(x) for x in torch.tensor(input.shape[1:])])
+         size = [int(x) for x in (torch.tensor(input.shape[1:]) * 1 / resize_factor)]
+         cache_attribute["Size"] = np.asarray(size)
+         return self._resample(input, size)
+ 
+ class ResampleResize(Resample):
+ 
+     def __init__(self, size: list[int] = [100, 512, 512]) -> None:
+         super().__init__()
+         self.size = size
+ 
+     def transformShape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
+         return self.size
+ 
+     def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+         if "Spacing" in cache_attribute:
+             cache_attribute["Spacing"] = torch.flip(torch.tensor(list(input.shape[1:])) / torch.tensor(self.size) * torch.flip(cache_attribute.get_tensor("Spacing"), dims=[0]), dims=[0])
+         # Record both the original and the target shape so inverse() can restore them.
+         cache_attribute["Size"] = np.asarray([int(x) for x in torch.tensor(input.shape[1:])])
+         cache_attribute["Size"] = self.size
+         return self._resample(input, self.size)
+ 
+ class ResampleTransform(Transform):
+ 
+     def __init__(self, transforms: dict[str, bool]) -> None:
+         super().__init__()
+         self.transforms = transforms
+ 
+     def transformShape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
+         return shape
+ 
+     def _load_composite_transform(self, name: str, image: sitk.Image) -> sitk.CompositeTransform:
+         transforms = []
+         for transform_group, invert in self.transforms.items():
+             transform = None
+             for dataset in self.datasets:
+                 if dataset.isDatasetExist(transform_group, name):
+                     transform = dataset.readTransform(transform_group, name)
+                     break
+             if transform is None:
+                 raise NameError("Transform : {}/{} not found".format(transform_group, name))
+             if isinstance(transform, sitk.BSplineTransform):
+                 if invert:
+                     # A B-spline transform has no closed-form inverse: convert it to a
+                     # displacement field and invert that field iteratively.
+                     transformToDisplacementFieldFilter = sitk.TransformToDisplacementFieldFilter()
+                     transformToDisplacementFieldFilter.SetReferenceImage(image)
+                     displacementField = transformToDisplacementFieldFilter.Execute(transform)
+                     iterativeInverseDisplacementFieldImageFilter = sitk.IterativeInverseDisplacementFieldImageFilter()
+                     iterativeInverseDisplacementFieldImageFilter.SetNumberOfIterations(20)
+                     inverseDisplacementField = iterativeInverseDisplacementFieldImageFilter.Execute(displacementField)
+                     transform = sitk.DisplacementFieldTransform(inverseDisplacementField)
+             elif invert:
+                 transform = transform.GetInverse()
+             transforms.append(transform)
+         return sitk.CompositeTransform(transforms)
+ 
+     def _call_sitk(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+         # Legacy variant (formerly __call__V1) that resamples entirely through SimpleITK.
+         image = data_to_image(input, cache_attribute)
+         result_transform = self._load_composite_transform(name, image)
+         result = torch.tensor(sitk.GetArrayFromImage(sitk.Resample(image, image, result_transform, sitk.sitkNearestNeighbor if input.dtype == torch.uint8 else sitk.sitkBSpline, 0 if input.dtype == torch.uint8 else -1024))).unsqueeze(0)
+         return result.type(torch.uint8) if input.dtype == torch.uint8 else result
+ 
+     def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+         assert len(input.shape) == 4, "input should be 4-dimensional (C, D, H, W)"
+         image = data_to_image(input, cache_attribute)
+ 
+         # Identity sampling grid, one coordinate channel per spatial dimension.
+         vectors = [torch.arange(0, s) for s in input.shape[1:]]
+         grids = torch.meshgrid(vectors, indexing='ij')
+         grid = torch.stack(grids)
+         grid = torch.unsqueeze(grid, 0)
+ 
+         result_transform = self._load_composite_transform(name, image)
+ 
+         transformToDisplacementFieldFilter = sitk.TransformToDisplacementFieldFilter()
+         transformToDisplacementFieldFilter.SetReferenceImage(image)
+         transformToDisplacementFieldFilter.SetNumberOfThreads(16)
+         new_locs = grid + torch.tensor(sitk.GetArrayFromImage(transformToDisplacementFieldFilter.Execute(result_transform))).unsqueeze(0).permute(0, 4, 1, 2, 3)
+         # Normalize coordinates to [-1, 1] and reorder to (x, y, z) as grid_sample expects.
+         shape = new_locs.shape[2:]
+         for i in range(len(shape)):
+             new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (shape[i] - 1) - 0.5)
+         new_locs = new_locs.permute(0, 2, 3, 4, 1)
+         new_locs = new_locs[..., [2, 1, 0]]
+         result = F.grid_sample(input.to(self.device).unsqueeze(0).float(), new_locs.to(self.device).float(), align_corners=True, padding_mode="border", mode="nearest" if input.dtype == torch.uint8 else "bilinear").squeeze(0).cpu()
+         return result.type(torch.uint8) if input.dtype == torch.uint8 else result
+ 
+     def inverse(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+         # TODO: not implemented yet.
+         return input
+ 
+ class Mask(Transform):
+ 
+     def __init__(self, path: str = "default:./default.mha", value_outside: int = 0) -> None:
+         super().__init__()
+         self.path = path
+         self.value_outside = value_outside
+ 
+     def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+         if self.path.endswith(".mha"):
+             mask = torch.tensor(sitk.GetArrayFromImage(sitk.ReadImage(self.path))).unsqueeze(0)
+         else:
+             mask = None
+             for dataset in self.datasets:
+                 if dataset.isDatasetExist(self.path, name):
+                     mask, _ = dataset.readData(self.path, name)
+                     break
+             if mask is None:
+                 raise NameError("Mask : {}/{} not found".format(self.path, name))
+         return torch.where(torch.as_tensor(mask) > 0, input, self.value_outside)
+ 
+     def inverse(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+         # Masking is lossy, so it cannot be undone.
+         return input
+ 
+ class Gradient(Transform):
+ 
+     def __init__(self, per_dim: bool = False):
+         super().__init__()
+         self.per_dim = per_dim
+ 
+     @staticmethod
+     def _image_gradient2D(image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+         # Forward differences, zero-padded back to the input shape.
+         dx = image[:, 1:, :] - image[:, :-1, :]
+         dy = image[:, :, 1:] - image[:, :, :-1]
+         return torch.nn.ConstantPad2d((0, 0, 0, 1), 0)(dx), torch.nn.ConstantPad2d((0, 1, 0, 0), 0)(dy)
+ 
+     @staticmethod
+     def _image_gradient3D(image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+         dx = image[:, 1:, :, :] - image[:, :-1, :, :]
+         dy = image[:, :, 1:, :] - image[:, :, :-1, :]
+         dz = image[:, :, :, 1:] - image[:, :, :, :-1]
+         return torch.nn.ConstantPad3d((0, 0, 0, 0, 0, 1), 0)(dx), torch.nn.ConstantPad3d((0, 0, 0, 1, 0, 0), 0)(dy), torch.nn.ConstantPad3d((0, 1, 0, 0, 0, 0), 0)(dz)
+ 
+     def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+         result = torch.stack(Gradient._image_gradient3D(input) if len(input.shape) == 4 else Gradient._image_gradient2D(input), dim=1).squeeze(0)
+         if not self.per_dim:
+             result = torch.sigmoid(result * 3)
+             result = result.norm(dim=0)
+             result = torch.unsqueeze(result, 0)
+         return result
+ 
+     def inverse(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+         return input
+ 
+ class ArgMax(Transform):
+ 
+     def __init__(self, dim: int = 0) -> None:
+         super().__init__()
+         self.dim = dim
+ 
+     def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+         return torch.argmax(input, dim=self.dim).unsqueeze(self.dim)
+ 
+     def inverse(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+         return input
+ 
+ class FlatLabel(Transform):
+ 
+     def __init__(self, labels: Union[list[int], None] = None) -> None:
+         super().__init__()
+         self.labels = labels
+ 
+     def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+         # Collapse the selected labels (or every non-zero label) to a binary mask.
+         data = torch.zeros_like(input)
+         if self.labels:
+             for label in self.labels:
+                 data[input == label] = 1
+         else:
+             data[input > 0] = 1
+         return data
+ 
+     def inverse(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+         return input
+ 
+ class Save(Transform):
+ 
+     def __init__(self, save: str) -> None:
+         super().__init__()
+         self.save = save
+ 
+     def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+         return input
+ 
+     def inverse(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+         return input
+ 
+ class Flatten(Transform):
+ 
+     def __init__(self) -> None:
+         super().__init__()
+ 
+     def transformShape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
+         return [int(np.prod(np.asarray(shape)))]
+ 
+     def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+         return input.flatten()
+ 
+     def inverse(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+         return input
+ 
+ class Permute(Transform):
+ 
+     def __init__(self, dims: str = "1|0|2") -> None:
+         super().__init__()
+         # Dimension 0 is the channel axis; the configured order applies to the spatial axes.
+         self.dims = [0] + [int(d) + 1 for d in dims.split("|")]
+ 
+     def transformShape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
+         return [shape[it - 1] for it in self.dims[1:]]
+ 
+     def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+         return input.permute(tuple(self.dims))
+ 
+     def inverse(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+         return input.permute(tuple(np.argsort(self.dims)))
+ 
+ class Flip(Transform):
+ 
+     def __init__(self, dims: str = "1|0|2") -> None:
+         super().__init__()
+         self.dims = [int(d) + 1 for d in str(dims).split("|")]
+ 
+     def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+         return input.flip(tuple(self.dims))
+ 
+     def inverse(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+         # Flipping is an involution: applying it again restores the input.
+         return input.flip(tuple(self.dims))
+ 
+ class Canonical(Transform):
+ 
+     def __init__(self) -> None:
+         super().__init__()
+         # Target direction matrix for the canonical orientation.
+         self.canonical_direction = torch.diag(torch.tensor([-1, -1, 1])).to(torch.double)
+ 
+     def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+         spacing = cache_attribute.get_tensor("Spacing")
+         initial_matrix = cache_attribute.get_tensor("Direction").reshape(3, 3).to(torch.double)
+         initial_origin = cache_attribute.get_tensor("Origin")
+         cache_attribute["Direction"] = self.canonical_direction.flatten()
+         matrix = _affine_matrix(self.canonical_direction @ initial_matrix.inverse(), torch.tensor([0, 0, 0]))
+         # Re-anchor the origin so the volume's physical center is preserved.
+         center_voxel = torch.tensor([(input.shape[-i - 1] - 1) * spacing[i] / 2 for i in range(3)], dtype=torch.double)
+         center_physical = initial_matrix @ center_voxel + initial_origin
+         cache_attribute["Origin"] = center_physical - (self.canonical_direction @ center_voxel)
+         return _resample_affine(input, matrix.unsqueeze(0))
+ 
+     def inverse(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+         # Popping reveals the direction and origin cached before canonicalization.
+         cache_attribute.pop("Direction")
+         cache_attribute.pop("Origin")
+         matrix = _affine_matrix((self.canonical_direction @ cache_attribute.get_tensor("Direction").to(torch.double).reshape(3, 3).inverse()).inverse(), torch.tensor([0, 0, 0]))
+         return _resample_affine(input, matrix.unsqueeze(0))
+ 
+ class HistogramMatching(Transform):
+ 
+     def __init__(self, reference_group: str) -> None:
+         super().__init__()
+         self.reference_group = reference_group
+ 
+     def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+         image = data_to_image(input, cache_attribute)
+         image_ref = None
+         for dataset in self.datasets:
+             if dataset.isDatasetExist(self.reference_group, name):
+                 image_ref = dataset.readImage(self.reference_group, name)
+                 break
+         if image_ref is None:
+             raise NameError("Image : {}/{} not found".format(self.reference_group, name))
+         matcher = sitk.HistogramMatchingImageFilter()
+         matcher.SetNumberOfHistogramLevels(256)
+         matcher.SetNumberOfMatchPoints(1)
+         matcher.SetThresholdAtMeanIntensity(True)
+         result, _ = image_to_data(matcher.Execute(image, image_ref))
+         return torch.tensor(result)
+ 
+     def inverse(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+         return input
+ 
+ class SelectLabel(Transform):
+ 
+     def __init__(self, labels: list[str]) -> None:
+         super().__init__()
+         # Each entry is an "(old,new)" pair, e.g. "(3,1)" remaps label 3 to 1.
+         self.labels = [l[1:-1].split(",") for l in labels]
+ 
+     def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+         data = torch.zeros_like(input)
+         for old_label, new_label in self.labels:
+             data[input == int(old_label)] = int(new_label)
+         return data
+ 
+     def inverse(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+         return input
+ 
+ class OneHot(Transform):
+ 
+     def __init__(self, num_classes: int) -> None:
+         super().__init__()
+         self.num_classes = num_classes
+ 
+     def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+         # one_hot appends the class axis last; move it next to the leading axis.
+         return F.one_hot(input.type(torch.int64), num_classes=self.num_classes).permute(0, len(input.shape), *[i + 1 for i in range(len(input.shape) - 1)]).float().squeeze(2)
+ 
+     def inverse(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+         return torch.argmax(input, dim=1).unsqueeze(1)
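
The transforms above follow a round-trip contract: __call__ records whatever it needs for restoration in the cache attribute, and inverse pops it back out. Below is a minimal sketch of that contract, under stated assumptions: KonfAI is installed, the import path is a guess (this hunk does not show the file's location in the wheel), and a plain dict stands in for Attribute, which Normalize only uses as a mapping.

import torch
from KonfAI.konfai.data.transform import Normalize  # hypothetical module path

volume = torch.rand(1, 8, 16, 16) * 2000 - 1024     # fake CT-like intensities
cache = {}                                          # stand-in for the Attribute cache
norm = Normalize(min_value=-1, max_value=1)

scaled = norm("case_000", volume, cache)            # records Min/Max in the cache
restored = norm.inverse("case_000", scaled, cache)  # pops Min/Max back out
assert torch.allclose(volume, restored, atol=1e-2)
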
konfai/evaluator.py ADDED
@@ -0,0 +1,146 @@
+ import os
+ from torch.utils.data import DataLoader
+ import torch
+ import tqdm
+ import numpy as np
+ import json
+ import shutil
+ import builtins
+ import importlib
+ from typing import Any
+ from KonfAI.konfai import EVALUATIONS_DIRECTORY, PREDICTIONS_DIRECTORY, DEEP_LEARNING_API_ROOT, CONFIG_FILE
+ from KonfAI.konfai.utils.config import config
+ from KonfAI.konfai.utils.utils import _getModule, DistributedObject, synchronize_data
+ from KonfAI.konfai.data.dataset import DataMetric
+ 
+ class CriterionsAttr():
+ 
+     @config()
+     def __init__(self) -> None:
+         pass
+ 
+ class CriterionsLoader():
+ 
+     @config()
+     def __init__(self, criterionsLoader: dict[str, CriterionsAttr] = {"default:torch_nn_CrossEntropyLoss:Dice:NCC": CriterionsAttr()}) -> None:
+         self.criterionsLoader = criterionsLoader
+ 
+     def getCriterions(self, output_group: str, target_group: str) -> dict[torch.nn.Module, CriterionsAttr]:
+         criterions = {}
+         for module_classpath, criterionsAttr in self.criterionsLoader.items():
+             module, name = _getModule(module_classpath, "measure")
+             criterions[config("{}.metrics.{}.targetsCriterions.{}.criterionsLoader.{}".format(DEEP_LEARNING_API_ROOT(), output_group, target_group, module_classpath))(getattr(importlib.import_module(module), name))(config=None)] = criterionsAttr
+         return criterions
+ 
+ class TargetCriterionsLoader():
+ 
+     @config()
+     def __init__(self, targetsCriterions: dict[str, CriterionsLoader] = {"default": CriterionsLoader()}) -> None:
+         self.targetsCriterions = targetsCriterions
+ 
+     def getTargetsCriterions(self, output_group: str) -> dict[str, dict[torch.nn.Module, CriterionsAttr]]:
+         targetsCriterions = {}
+         for target_group, criterionsLoader in self.targetsCriterions.items():
+             targetsCriterions[target_group] = criterionsLoader.getCriterions(output_group, target_group)
+         return targetsCriterions
+ 
+ class Statistics():
+ 
+     def __init__(self, filename: str) -> None:
+         self.measures: dict[str, dict[str, float]] = {}
+         self.filename = filename
+ 
+     def add(self, values: dict[str, float], name_dataset: str) -> None:
+         if name_dataset not in self.measures:
+             self.measures[name_dataset] = {}
+         for name, value in values.items():
+             self.measures[name_dataset][name] = value
+ 
+     @staticmethod
+     def getStatistic(values: list[float]) -> dict[str, float]:
+         # Cast to float so the result stays JSON-serializable (np.float64 is not).
+         return {"max": float(np.max(values)), "min": float(np.min(values)), "std": float(np.std(values)), "25pc": float(np.percentile(values, 25)), "50pc": float(np.percentile(values, 50)), "75pc": float(np.percentile(values, 75)), "mean": float(np.mean(values)), "count": len(values)}
+ 
+     def write(self, outputs: list[dict[str, Any]]) -> None:
+         measures = {}
+         for output in outputs:
+             measures.update(output)
+ 
+         result = {}
+         result["case"] = {}
+         for name, v in measures.items():
+             for metric_name, value in v.items():
+                 if metric_name not in result["case"]:
+                     result["case"][metric_name] = {}
+                 result["case"][metric_name][name] = value
+ 
+         result["aggregates"] = {}
+         tmp: dict[str, list[float]] = {}
+         for name, v in measures.items():
+             for metric_name, value in v.items():
+                 if metric_name not in tmp:
+                     tmp[metric_name] = []
+                 tmp[metric_name].append(value)
+         for metric_name, values in tmp.items():
+             result["aggregates"][metric_name] = Statistics.getStatistic(values)
+ 
+         with open(self.filename, "w") as f:
+             f.write(json.dumps(result, indent=4))
+ 
+ class Evaluator(DistributedObject):
+ 
+     @config("Evaluator")
+     def __init__(self, train_name: str = "default:name", metrics: dict[str, TargetCriterionsLoader] = {"default": TargetCriterionsLoader()}, dataset: DataMetric = DataMetric()) -> None:
+         if os.environ["DEEP_LEARNING_API_CONFIG_MODE"] != "Done":
+             exit(0)
+         super().__init__(train_name)
+         self.metric_path = EVALUATIONS_DIRECTORY() + self.name + "/"
+         self.predict_path = PREDICTIONS_DIRECTORY() + self.name + "/"
+         self.metricsLoader = metrics
+         self.dataset = dataset
+         self.metrics = {k: v.getTargetsCriterions(k) for k, v in self.metricsLoader.items()}
+         self.statistics_train = Statistics(self.metric_path + "Metric_TRAIN.json")
+         self.statistics_validation = Statistics(self.metric_path + "Metric_VALIDATION.json")
+ 
+     def update(self, data_dict: dict[str, tuple[torch.Tensor, str]], statistics: Statistics) -> dict[str, float]:
+         result = {}
+         for output_group in self.metrics:
+             for target_group in self.metrics[output_group]:
+                 targets = [data_dict[group][0] for group in target_group.split("/") if group in data_dict]
+                 name = data_dict[output_group][1][0]
+                 for metric in self.metrics[output_group][target_group]:
+                     result["{}:{}:{}".format(output_group, target_group, metric.__class__.__name__)] = metric(data_dict[output_group][0], *targets).item()
+         statistics.add(result, name)
+         return result
+ 
+     def setup(self, world_size: int):
+         if os.path.exists(self.metric_path):
+             if os.environ["DL_API_OVERWRITE"] != "True":
+                 accept = builtins.input("The metric {} already exists! Do you want to overwrite it (yes,no): ".format(self.name))
+                 if accept != "yes":
+                     return
+ 
+         if os.path.exists(self.metric_path):
+             shutil.rmtree(self.metric_path)
+ 
+         if not os.path.exists(self.metric_path):
+             os.makedirs(self.metric_path)
+             metric_namefile_src = CONFIG_FILE().replace(".yml", "")
+             shutil.copyfile(metric_namefile_src + ".yml", "{}{}.yml".format(self.metric_path, metric_namefile_src))
+ 
+         self.dataloader = self.dataset.getData(world_size)
+ 
+     def run_process(self, world_size: int, global_rank: int, gpu: int, dataloaders: list[DataLoader]):
+         description = lambda measure: "Metric TRAIN : {} ".format(" | ".join("{}: {:.2f}".format(k, v) for k, v in measure.items()) if measure is not None else "")
+         with tqdm.tqdm(iterable=enumerate(dataloaders[0]), leave=False, desc=description(None), total=len(dataloaders[0])) as batch_iter:
+             for _, data_dict in batch_iter:
+                 batch_iter.set_description(description(self.update({k: (v[0], v[4]) for k, v in data_dict.items()}, self.statistics_train)))
+         outputs = synchronize_data(world_size, gpu, self.statistics_train.measures)
+         if global_rank == 0:
+             self.statistics_train.write(outputs)
+         if len(dataloaders) == 2:
+             description = lambda measure: "Metric VALIDATION : {} ".format(" | ".join("{}: {:.2f}".format(k, v) for k, v in measure.items()) if measure is not None else "")
+             with tqdm.tqdm(iterable=enumerate(dataloaders[1]), leave=False, desc=description(None), total=len(dataloaders[1])) as batch_iter:
+                 for _, data_dict in batch_iter:
+                     batch_iter.set_description(description(self.update({k: (v[0], v[4]) for k, v in data_dict.items()}, self.statistics_validation)))
+             outputs = synchronize_data(world_size, gpu, self.statistics_validation.measures)
+             if global_rank == 0:
+                 self.statistics_validation.write(outputs)
+ 
+ # 69.8936 ± 13.2773 (107) 28.6021 (117) 0.8641 (118)
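
To make the JSON layout produced by Statistics.write concrete, here is a small single-process sketch with hypothetical metric names and values; in the package itself, write receives the list assembled by synchronize_data across ranks.

stats = Statistics("Metric_TRAIN.json")
stats.add({"mri:mask:Dice": 0.91}, "case_000")
stats.add({"mri:mask:Dice": 0.87}, "case_001")
stats.write([stats.measures])   # single process: the list holds one measures dict
# Metric_TRAIN.json now holds per-case values under "case" and summary
# statistics under "aggregates", e.g.:
# {"case": {"mri:mask:Dice": {"case_000": 0.91, "case_001": 0.87}},
#  "aggregates": {"mri:mask:Dice": {"max": 0.91, ..., "mean": 0.89, "count": 2}}}
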