konfai 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of konfai might be problematic. Click here for more details.
- konfai/__init__.py +16 -0
- konfai/data/HDF5.py +326 -0
- konfai/data/__init__.py +0 -0
- konfai/data/augmentation.py +597 -0
- konfai/data/dataset.py +470 -0
- konfai/data/transform.py +536 -0
- konfai/evaluator.py +146 -0
- konfai/main.py +43 -0
- konfai/metric/__init__.py +0 -0
- konfai/metric/measure.py +488 -0
- konfai/metric/schedulers.py +49 -0
- konfai/models/classification/convNeXt.py +175 -0
- konfai/models/classification/resnet.py +116 -0
- konfai/models/generation/cStyleGan.py +137 -0
- konfai/models/generation/ddpm.py +218 -0
- konfai/models/generation/diffusionGan.py +557 -0
- konfai/models/generation/gan.py +134 -0
- konfai/models/generation/vae.py +72 -0
- konfai/models/registration/registration.py +136 -0
- konfai/models/representation/representation.py +57 -0
- konfai/models/segmentation/NestedUNet.py +53 -0
- konfai/models/segmentation/UNet.py +58 -0
- konfai/network/__init__.py +0 -0
- konfai/network/blocks.py +348 -0
- konfai/network/network.py +950 -0
- konfai/predictor.py +366 -0
- konfai/trainer.py +330 -0
- konfai/utils/ITK.py +269 -0
- konfai/utils/Registration.py +199 -0
- konfai/utils/__init__.py +0 -0
- konfai/utils/config.py +218 -0
- konfai/utils/dataset.py +764 -0
- konfai/utils/utils.py +493 -0
- konfai-1.0.0.dist-info/METADATA +68 -0
- konfai-1.0.0.dist-info/RECORD +39 -0
- konfai-1.0.0.dist-info/WHEEL +5 -0
- konfai-1.0.0.dist-info/entry_points.txt +3 -0
- konfai-1.0.0.dist-info/licenses/LICENSE +201 -0
- konfai-1.0.0.dist-info/top_level.txt +1 -0
konfai/data/transform.py
ADDED
|
@@ -0,0 +1,536 @@
|
|
|
1
|
+
import importlib
|
|
2
|
+
import torch
|
|
3
|
+
import numpy as np
|
|
4
|
+
import SimpleITK as sitk
|
|
5
|
+
from abc import ABC, abstractmethod
|
|
6
|
+
import torch.nn.functional as F
|
|
7
|
+
from typing import Any, Union
|
|
8
|
+
|
|
9
|
+
from KonfAI.konfai.utils.utils import _getModule, NeedDevice, _resample_affine, _affine_matrix
|
|
10
|
+
from KonfAI.konfai.utils.dataset import Dataset, Attribute, data_to_image, image_to_data
|
|
11
|
+
from KonfAI.konfai.utils.config import config
|
|
12
|
+
|
|
13
|
+
class Transform(NeedDevice, ABC):
    """Abstract base for a tensor transform with a matching inverse.

    Subclasses implement ``__call__`` (forward pass) and ``inverse``. Both get the
    case ``name``, the tensor, and a per-case ``Attribute`` cache through which
    the forward pass can leave metadata for the inverse pass.
    """

    def __init__(self) -> None:
        # Datasets searched by subclasses that read auxiliary per-case data.
        self.datasets: list[Dataset] = []

    def setDatasets(self, datasets: list[Dataset]):
        """Attach the datasets that auxiliary lookups search."""
        self.datasets = datasets

    def transformShape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
        """Return the output shape for an input of ``shape`` (identity by default)."""
        return shape

    @abstractmethod
    def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Apply the forward transform to ``input`` for case ``name``."""

    @abstractmethod
    def inverse(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Undo the transform on ``input`` for case ``name``."""
|
|
31
|
+
|
|
32
|
+
class TransformLoader:
    """Factory that builds ``Transform`` instances from a classpath via ``config``."""

    @config()
    def __init__(self) -> None:
        pass

    def getTransform(self, classpath: str, DL_args: str) -> Transform:
        """Import the class named by ``classpath`` and instantiate it through ``config``.

        ``_getModule`` is given ``"transform"`` as its module hint. The class is
        wrapped by ``config("<DL_args>.<classpath>")`` (presumably so constructor
        arguments come from that config section — confirm in utils.config) and
        then instantiated with ``config=None``.
        """
        module_path, class_name = _getModule(classpath, "transform")
        # Create the config wrapper before importing, matching the original evaluation order.
        configure = config("{}.{}".format(DL_args, classpath))
        transform_cls = getattr(importlib.import_module(module_path), class_name)
        return configure(transform_cls)(config=None)
|
|
41
|
+
|
|
42
|
+
class Clip(Transform):
    """Clamp values into ``[min_value, max_value]``, modifying the input tensor in place.

    When ``saveClip_min`` / ``saveClip_max`` are set, the bounds are stored in the
    cache as "Min" / "Max" (the keys a later ``Normalize`` reads instead of
    measuring the data).
    """

    def __init__(self, min_value: float = -1024, max_value: float = 1024, saveClip_min: bool = False, saveClip_max: bool = False) -> None:
        assert max_value > min_value
        self.min_value = min_value
        self.max_value = max_value
        self.saveClip_min = saveClip_min
        self.saveClip_max = saveClip_max

    def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        too_low = input < self.min_value
        input[too_low] = self.min_value
        # Evaluated after the lower clamp, exactly as before.
        too_high = input > self.max_value
        input[too_high] = self.max_value
        if self.saveClip_min:
            cache_attribute["Min"] = self.min_value
        if self.saveClip_max:
            cache_attribute["Max"] = self.max_value
        return input

    def inverse(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        # Clipping discards information; the inverse is the identity.
        return input
|
|
62
|
+
|
|
63
|
+
class Normalize(Transform):
    """Linearly rescale intensities from ``[Min, Max]`` to ``[min_value, max_value]``.

    "Min"/"Max" are taken from the cache if already present. Otherwise they are
    measured on the data (restricted to ``channels`` when given) and cached.
    With ``lazy=True`` only the bounds are recorded and the data is left unchanged.
    """

    def __init__(self, lazy: bool = False, channels: Union[list[int], None] = None, min_value: float = -1, max_value: float = 1) -> None:
        assert max_value > min_value
        self.lazy = lazy
        self.min_value = min_value
        self.max_value = max_value
        self.channels = channels

    def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        def measured() -> torch.Tensor:
            # Only index channels when a bound actually has to be measured.
            return input[self.channels] if self.channels else input

        if "Min" not in cache_attribute:
            cache_attribute["Min"] = torch.min(measured())
        if "Max" not in cache_attribute:
            cache_attribute["Max"] = torch.max(measured())
        if self.lazy:
            return input

        low = float(cache_attribute["Min"])
        high = float(cache_attribute["Max"])
        span = high - low
        assert span != 0
        scale = self.max_value - self.min_value
        if not self.channels:
            return scale * (input - low) / span + self.min_value
        # Per-channel path rewrites the selected channels in place.
        for channel in self.channels:
            input[channel] = scale * (input[channel] - low) / span + self.min_value
        return input

    def inverse(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        if self.lazy:
            return input
        low = float(cache_attribute.pop("Min"))
        high = float(cache_attribute.pop("Max"))
        return (input - self.min_value) * (high - low) / (self.max_value - self.min_value) + low
|
|
102
|
+
|
|
103
|
+
class Standardize(Transform):
    """Z-score standardisation ``(x - mean) / std`` with per-channel statistics.

    The mean/std are computed per channel over all other dims unless given
    explicitly, and are cached as "Mean"/"Std" for the inverse pass.
    """

    def __init__(self, lazy: bool = False, mean: Union[list[float], None] = None, std: Union[list[float], None] = None) -> None:
        self.lazy = lazy
        self.mean = mean
        self.std = std

    def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        reduce_dims = [i + 1 for i in range(len(input.shape) - 1)]
        if "Mean" not in cache_attribute:
            cache_attribute["Mean"] = torch.mean(input.type(torch.float32), dim=reduce_dims) if self.mean is None else torch.tensor([self.mean])
        if "Std" not in cache_attribute:
            cache_attribute["Std"] = torch.std(input.type(torch.float32), dim=reduce_dims) if self.std is None else torch.tensor([self.std])

        if self.lazy:
            return input
        broadcast_shape = (-1, *[1 for _ in range(len(input.shape) - 1)])
        mean = cache_attribute.get_tensor("Mean").view(broadcast_shape)
        std = cache_attribute.get_tensor("Std").view(broadcast_shape)
        return (input - mean) / std

    def inverse(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        if self.lazy:
            return input
        # Read the statistics as tensors, like the forward pass, so per-channel
        # mean/std invert correctly (float() only accepted a single value).
        dtype = input.dtype if input.is_floating_point() else torch.float32
        broadcast_shape = (-1, *[1 for _ in range(input.dim() - 1)])
        mean = cache_attribute.get_tensor("Mean").to(device=input.device, dtype=dtype).view(broadcast_shape)
        std = cache_attribute.get_tensor("Std").to(device=input.device, dtype=dtype).view(broadcast_shape)
        # Remove the entries pushed by the forward pass (same order as before).
        cache_attribute.pop("Mean")
        cache_attribute.pop("Std")
        return input * std + mean
|
|
130
|
+
|
|
131
|
+
class TensorCast(Transform):
    """Cast the tensor to ``dtype`` (a ``torch`` attribute name such as "float32").

    The original dtype is cached as "dtype" and restored by ``inverse``.
    """

    def __init__(self, dtype: str = "default:float32,int64,int16") -> None:
        self.dtype: torch.dtype = getattr(torch, dtype)

    def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        cache_attribute["dtype"] = input.dtype
        return input.type(self.dtype)

    @staticmethod
    def _parse_dtype(value: Any) -> torch.dtype:
        """Resolve a cached dtype (a ``torch.dtype`` or a string like "torch.float32").

        Raises:
            ValueError: if ``value`` does not name a ``torch`` dtype.
        """
        if isinstance(value, torch.dtype):
            return value
        attr_name = str(value).strip().rpartition(".")[2]
        dtype = getattr(torch, attr_name, None)
        if not isinstance(dtype, torch.dtype):
            raise ValueError("Unknown dtype in cache attribute: {!r}".format(value))
        return dtype

    def inverse(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        # Resolve the dtype explicitly instead of eval(): cached attributes may be
        # read back from files, and eval would execute arbitrary text.
        return input.to(TensorCast._parse_dtype(cache_attribute.pop("dtype")))
|
|
142
|
+
|
|
143
|
+
class Padding(Transform):
    """Pad the trailing tensor dims with ``F.pad`` and shift the image origin to match.

    ``padding`` follows ``F.pad``: ``(before, after)`` pairs starting from the last
    tensor dim. ``mode`` is an ``F.pad`` mode, optionally followed by
    ``":<fill value>"`` for the constant value.
    """

    def __init__(self, padding: list[int] = [0,0,0,0,0,0], mode: str = "default:constant,reflect,replicate,circular") -> None:
        self.padding = padding
        self.mode = mode

    def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        if "Origin" in cache_attribute and "Spacing" in cache_attribute and "Direction" in cache_attribute:
            origin = torch.tensor(cache_attribute.get_np_array("Origin"))
            spacing = cache_attribute.get_np_array("Spacing")
            matrix = torch.tensor(cache_attribute.get_np_array("Direction").reshape((len(origin), len(origin))))
            # Express the origin in the index-aligned frame, shift it, then map back.
            origin = torch.matmul(origin, matrix)
            # Pad pair ``dim`` acts on tensor axis -dim-1. Tensor axes are ordered
            # (z, y, x) while Origin/Spacing are (x, y, z) (ResampleIsotropic flips
            # Spacing for this reason), so the matching image axis is ``dim``.
            for dim in range(min(len(self.padding) // 2, len(origin))):
                origin[dim] -= self.padding[dim * 2] * spacing[dim]
            cache_attribute["Origin"] = torch.matmul(origin, torch.inverse(matrix))
        mode_parts = self.mode.split(":")
        fill_value = float(mode_parts[1]) if len(mode_parts) == 2 else 0
        return F.pad(input.unsqueeze(0), tuple(self.padding), mode_parts[0], fill_value).squeeze(0)

    def transformShape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
        # Work on a copy so the caller's list is not modified.
        new_shape = list(shape)
        for dim in range(len(self.padding) // 2):
            new_shape[-dim - 1] += sum(self.padding[dim * 2:dim * 2 + 2])
        return new_shape

    def inverse(self, name: str, input: torch.Tensor, cache_attribute: dict[str, torch.Tensor]) -> torch.Tensor:
        if "Origin" in cache_attribute and "Spacing" in cache_attribute and "Direction" in cache_attribute:
            # Drop the shifted origin pushed by __call__.
            cache_attribute.pop("Origin")
        slices = [slice(0, size) for size in input.shape]
        for dim in range(len(self.padding) // 2):
            slices[-dim - 1] = slice(self.padding[dim * 2], input.shape[-dim - 1] - self.padding[dim * 2 + 1])
        # Multi-dim indexing must use a tuple; a list of slices is a deprecated form.
        return input[tuple(slices)]
|
|
173
|
+
|
|
174
|
+
class Squeeze(Transform):
    """Remove axis ``dim`` (only if it has size 1); ``inverse`` re-inserts it."""

    def __init__(self, dim: int) -> None:
        self.dim = dim

    def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        return torch.squeeze(input, self.dim)

    def inverse(self, name: str, input: torch.Tensor, cache_attribute: dict[str, Any]) -> torch.Tensor:
        return torch.unsqueeze(input, self.dim)
|
|
184
|
+
|
|
185
|
+
class Resample(Transform, ABC):
    """Base class for transforms that change the spatial grid by interpolation.

    Subclasses push "Size" twice in the forward pass (original, then new size)
    and may push "Spacing". ``inverse`` pops those entries and resamples back to
    the original size.
    """

    def __init__(self) -> None:
        pass

    def _resample(self, input: torch.Tensor, size: list[int]) -> torch.Tensor:
        """Interpolate ``input`` (no batch dim) to spatial ``size``.

        uint8 tensors use nearest neighbour. Other tensors use bilinear when they
        have fewer than 4 dims, otherwise trilinear. The result keeps the input
        dtype and is returned on the CPU.
        """
        if input.dtype == torch.uint8:
            mode = "nearest"
        elif len(input.shape) < 4:
            mode = "bilinear"
        else:
            mode = "trilinear"
        return F.interpolate(input.type(torch.float32).unsqueeze(0), size=tuple(size), mode=mode).squeeze(0).type(input.dtype).cpu()

    @abstractmethod
    def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        pass

    @abstractmethod
    def transformShape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
        pass

    def inverse(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        # First pop is the resampled size; the second is the original size to restore.
        cache_attribute.pop_np_array("Size")
        original_size = cache_attribute.pop_np_array("Size")
        # ResampleResize only pushes Spacing when the input had one, so guard the pop.
        if "Spacing" in cache_attribute:
            cache_attribute.pop_np_array("Spacing")
        return self._resample(input, [int(size) for size in original_size])
|
|
213
|
+
|
|
214
|
+
class ResampleIsotropic(Resample):
    """Resample to a target ``spacing``, given in the same order as the cached "Spacing".

    The cached "Spacing" is flipped to tensor-axis order before dividing, and the
    new shape is ``old_shape * old_spacing / new_spacing`` truncated to int.
    """

    def __init__(self, spacing: list[float] = [1., 1., 1.]) -> None:
        self.spacing = torch.tensor(spacing, dtype=torch.float64)

    def transformShape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
        assert "Spacing" in cache_attribute, "Error no spacing"
        factor = self.spacing / cache_attribute.get_tensor("Spacing").flip(0)
        scaled = torch.tensor(shape) / factor
        return [int(extent) for extent in scaled]

    def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        assert "Spacing" in cache_attribute, "Error no spacing"
        # The factor must use the old spacing, before it is overwritten below.
        factor = self.spacing / cache_attribute.get_tensor("Spacing").flip(0)
        cache_attribute["Spacing"] = self.spacing.flip(0)
        spatial_shape = torch.tensor(input.shape[1:])
        # Push the original size, then the new one (Resample.inverse pops both).
        cache_attribute["Size"] = np.asarray([int(extent) for extent in spatial_shape])
        new_size = [int(extent) for extent in spatial_shape / factor]
        cache_attribute["Size"] = np.asarray(new_size)
        return self._resample(input, new_size)
|
|
232
|
+
|
|
233
|
+
class ResampleResize(Resample):
    """Resample the spatial dims to the fixed ``size``, updating "Spacing" if present."""

    def __init__(self, size: list[int] = [100,512,512]) -> None:
        self.size = size

    def transformShape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
        # Return a copy: callers may modify the returned list in place, which
        # would otherwise silently change self.size.
        return list(self.size)

    def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        if "Spacing" in cache_attribute:
            # Spacing scales by old_extent / new_extent (computed in tensor-axis order).
            old_spacing = torch.flip(cache_attribute.get_tensor("Spacing"), dims=[0])
            ratio = torch.tensor(list(input.shape[1:])) / torch.tensor(self.size)
            cache_attribute["Spacing"] = torch.flip(ratio * old_spacing, dims=[0])
        # Push the original size, then the new one (Resample.inverse pops both).
        cache_attribute["Size"] = np.asarray([int(x) for x in torch.tensor(input.shape[1:])])
        # Store as an array, like ResampleIsotropic, since the inverse reads it with pop_np_array.
        cache_attribute["Size"] = np.asarray(self.size)
        return self._resample(input, self.size)
|
|
247
|
+
|
|
248
|
+
class ResampleTransform(Transform):
    """Warp a 3-D volume with SimpleITK transforms read from the attached datasets.

    ``transforms`` maps a dataset group name to a flag saying whether the
    transform stored for the current case in that group must be inverted. The
    transforms are chained, in dict order, into one ``sitk.CompositeTransform``.
    """

    def __init__(self, transforms: dict[str, bool]) -> None:
        self.transforms = transforms

    def transformShape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
        # Warping resamples onto the input grid, so the shape is unchanged.
        return shape

    def _read_transform(self, transform_group: str, name: str):
        """Return case ``name``'s transform from the first dataset holding ``transform_group``.

        Raises:
            NameError: if no attached dataset contains ``transform_group/name``.
        """
        transform = None
        for dataset in self.datasets:
            if dataset.isDatasetExist(transform_group, name):
                transform = dataset.readTransform(transform_group, name)
                break
        if transform is None:
            raise NameError("Transform : {}/{} not found".format(transform_group, name))
        return transform

    @staticmethod
    def _invert(transform, reference):
        """Invert ``transform``.

        B-spline transforms are inverted numerically: their displacement field on
        ``reference``'s grid is inverted iteratively (20 iterations).
        """
        if isinstance(transform, sitk.BSplineTransform):
            to_field = sitk.TransformToDisplacementFieldFilter()
            to_field.SetReferenceImage(reference)
            field = to_field.Execute(transform)
            inverter = sitk.IterativeInverseDisplacementFieldImageFilter()
            inverter.SetNumberOfIterations(20)
            return sitk.DisplacementFieldTransform(inverter.Execute(field))
        return transform.GetInverse()

    def _composite_transform(self, name: str, reference):
        """Chain the configured transforms for case ``name``, inverting flagged ones."""
        transforms = []
        for transform_group, invert in self.transforms.items():
            transform = self._read_transform(transform_group, name)
            if invert:
                transform = self._invert(transform, reference)
            transforms.append(transform)
        return sitk.CompositeTransform(transforms)

    def __call__V1(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Legacy CPU path: warp with ``sitk.Resample``.

        Uses B-spline interpolation with a default value of -1024, or nearest
        neighbour with a default value of 0 for uint8.
        """
        image = data_to_image(input, cache_attribute)
        result_transform = self._composite_transform(name, image)
        is_label = input.dtype == torch.uint8
        warped = sitk.Resample(image, image, result_transform, sitk.sitkNearestNeighbor if is_label else sitk.sitkBSpline, 0 if is_label else -1024)
        result = torch.tensor(sitk.GetArrayFromImage(warped)).unsqueeze(0)
        return result.type(torch.uint8) if is_label else result

    def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Warp ``input`` (C, z, y, x) with ``F.grid_sample`` on ``self.device``."""
        assert len(input.shape) == 4, "input size should be 4 dim"
        image = data_to_image(input, cache_attribute)

        # Identity sampling grid in voxel indices; channel i indexes tensor axis i+1 (z, y, x).
        vectors = [torch.arange(0, s) for s in input.shape[1:]]
        grid = torch.stack(torch.meshgrid(vectors, indexing="ij")).unsqueeze(0)

        result_transform = self._composite_transform(name, image)

        to_field = sitk.TransformToDisplacementFieldFilter()
        to_field.SetReferenceImage(image)
        to_field.SetNumberOfThreads(16)
        # Array of shape (z, y, x, 3): physical displacements with components ordered (x, y, z).
        displacement = sitk.GetArrayFromImage(to_field.Execute(result_transform)).astype(np.float64)
        dimension = image.GetDimension()
        direction = np.asarray(image.GetDirection(), dtype=np.float64).reshape(dimension, dimension)
        spacing = np.asarray(image.GetSpacing(), dtype=np.float64)
        # Convert physical offsets to continuous-index offsets (D^-1 d / spacing, for
        # row vectors). Then reorder components (x, y, z) -> (z, y, x) to match the
        # grid channels; previously dx was added to the z index and vice versa.
        index_displacement = (displacement @ np.linalg.inv(direction).T) / spacing
        index_displacement = np.ascontiguousarray(index_displacement[..., ::-1])
        new_locs = grid + torch.from_numpy(index_displacement).unsqueeze(0).permute(0, 4, 1, 2, 3)

        shape = new_locs.shape[2:]
        for i, extent in enumerate(shape):
            # Map [0, extent - 1] to [-1, 1] (align_corners=True); avoid dividing by zero on size-1 axes.
            new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / max(extent - 1, 1) - 0.5)
        # grid_sample expects the last grid axis ordered (x, y, z).
        new_locs = new_locs.permute(0, 2, 3, 4, 1)[..., [2, 1, 0]]
        is_label = input.dtype == torch.uint8
        result = F.grid_sample(input.to(self.device).unsqueeze(0).float(), new_locs.to(self.device).float(), align_corners=True, padding_mode="border", mode="nearest" if is_label else "bilinear").squeeze(0).cpu()
        return result.type(torch.uint8) if is_label else result

    def inverse(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        # TODO: inverse warping is not implemented; returns the input unchanged.
        return input
|
|
332
|
+
|
|
333
|
+
class Mask(Transform):
    """Replace values outside a mask (mask <= 0) with ``value_outside``.

    ``path`` is either a ``.mha`` file, read on every call, or a dataset group
    name looked up per case in the attached datasets.
    """

    def __init__(self, path: str = "default:./default.mha", value_outside: int = 0) -> None:
        self.path = path
        self.value_outside = value_outside

    def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        if self.path.endswith(".mha"):
            mask = torch.tensor(sitk.GetArrayFromImage(sitk.ReadImage(self.path))).unsqueeze(0)
        else:
            mask = None
            for dataset in self.datasets:
                if dataset.isDatasetExist(self.path, name):
                    mask, _ = dataset.readData(self.path, name)
                    break
            if mask is None:
                raise NameError("Mask : {}/{} not found".format(self.path, name))
        # as_tensor: mask may already be a tensor; torch.tensor() would copy-construct it (with a warning).
        return torch.where(torch.as_tensor(mask) > 0, input, self.value_outside)

    def inverse(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        # Masked-out values cannot be recovered; identity.
        return input
|
|
354
|
+
|
|
355
|
+
class Gradient(Transform):
    """Forward-difference image gradients.

    With ``per_dim=True`` each axis's gradient is returned separately. Otherwise
    the per-axis gradients go through ``sigmoid(3 * g)``, their norm is taken
    over the stacked axis, and a leading dim is re-added.
    """

    def __init__(self, per_dim: bool = False):
        self.per_dim = per_dim

    @staticmethod
    def _image_gradient2D(image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """Differences along dims 1 and 2, zero-padded at the end to keep the input size."""
        grad_a = image[:, 1:, :] - image[:, :-1, :]
        grad_b = image[:, :, 1:] - image[:, :, :-1]
        pad_a = torch.nn.ConstantPad2d((0, 0, 0, 1), 0)
        pad_b = torch.nn.ConstantPad2d((0, 1, 0, 0), 0)
        return pad_a(grad_a), pad_b(grad_b)

    @staticmethod
    def _image_gradient3D(image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """Differences along dims 1, 2 and 3, zero-padded at the end to keep the input size."""
        grad_a = image[:, 1:, :, :] - image[:, :-1, :, :]
        grad_b = image[:, :, 1:, :] - image[:, :, :-1, :]
        grad_c = image[:, :, :, 1:] - image[:, :, :, :-1]
        pad_a = torch.nn.ConstantPad3d((0, 0, 0, 0, 0, 1), 0)
        pad_b = torch.nn.ConstantPad3d((0, 0, 0, 1, 0, 0), 0)
        pad_c = torch.nn.ConstantPad3d((0, 1, 0, 0, 0, 0), 0)
        return pad_a(grad_a), pad_b(grad_b), pad_c(grad_c)

    def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        if len(input.shape) == 4:
            components = Gradient._image_gradient3D(input)
        else:
            components = Gradient._image_gradient2D(input)
        stacked = torch.stack(components, dim=1).squeeze(0)
        if self.per_dim:
            return stacked
        magnitude = torch.sigmoid(stacked * 3).norm(dim=0)
        return torch.unsqueeze(magnitude, 0)

    def inverse(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        return input
|
|
384
|
+
|
|
385
|
+
class ArgMax(Transform):
    """Replace axis ``dim`` by the index of its maximum, keeping the axis with size 1."""

    def __init__(self, dim: int = 0) -> None:
        self.dim = dim

    def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        return input.argmax(dim=self.dim, keepdim=True)

    def inverse(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        return input
|
|
395
|
+
|
|
396
|
+
class FlatLabel(Transform):
    """Binarise a label map.

    Voxels equal to any of ``labels`` become 1, or every voxel > 0 when
    ``labels`` is empty/None. All other voxels become 0.
    """

    def __init__(self, labels: Union[list[int], None] = None) -> None:
        self.labels = labels

    def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        binary = torch.zeros_like(input)
        if not self.labels:
            binary[input > 0] = 1
            return binary
        for label in self.labels:
            binary[input == label] = 1
        return binary

    def inverse(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        return input
|
|
412
|
+
|
|
413
|
+
class Save(Transform):
    """Pass-through transform that only stores a ``save`` target.

    Neither direction touches the data here. NOTE(review): ``save`` is
    presumably consumed elsewhere — confirm.
    """

    def __init__(self, save: str) -> None:
        self.save = save

    def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        unchanged = input
        return unchanged

    def inverse(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        unchanged = input
        return unchanged
|
|
423
|
+
|
|
424
|
+
class Flatten(Transform):
    """Flatten the tensor to 1-D. ``inverse`` does not restore the shape."""

    def __init__(self) -> None:
        super().__init__()

    def transformShape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
        # Cast to a plain int: np.prod returns a NumPy scalar, and the float 1.0 for an empty shape.
        return [int(np.prod(np.asarray(shape)))]

    def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        return input.flatten()

    def inverse(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        # The original shape is not cached, so the input is returned as-is.
        return input
|
|
437
|
+
|
|
438
|
+
class Permute(Transform):
    """Permute the non-leading axes. ``dims`` is a "|"-separated order such as "1|0|2".

    Axis 0 stays in place; the listed indices refer to the remaining axes.
    """

    def __init__(self, dims: str = "1|0|2") -> None:
        super().__init__()
        trailing = [int(token) + 1 for token in dims.split("|")]
        self.dims = [0, *trailing]

    def transformShape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
        # ``shape`` excludes the leading axis, hence the -1 offset.
        return [shape[axis - 1] for axis in self.dims[1:]]

    def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        return input.permute(tuple(self.dims))

    def inverse(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        # Inverse permutation: positions sorted by the axis each one holds.
        undo = sorted(range(len(self.dims)), key=self.dims.__getitem__)
        return input.permute(tuple(undo))
|
|
452
|
+
|
|
453
|
+
class Flip(Transform):
    """Flip the non-leading axes listed in ``dims`` ("|"-separated, 0-based after axis 0).

    Flipping is its own inverse.
    """

    def __init__(self, dims: str = "1|0|2") -> None:
        super().__init__()
        tokens = str(dims).split("|")
        self.dims = [int(token) + 1 for token in tokens]

    def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        return torch.flip(input, tuple(self.dims))

    def inverse(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        return torch.flip(input, tuple(self.dims))
|
|
465
|
+
|
|
466
|
+
class Canonical(Transform):
    """Resample a 3-D volume so its direction matrix becomes diag(-1, -1, 1).

    The forward pass rewrites the cached "Direction" and "Origin". The origin is
    chosen so the physical position of the volume centre is preserved.
    """

    def __init__(self) -> None:
        # Target direction matrix, in double precision to match the cached direction.
        self.canonical_direction = torch.diag(torch.tensor([-1, -1, 1])).to(torch.double)

    def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        spacing = cache_attribute.get_tensor("Spacing")
        initial_matrix = cache_attribute.get_tensor("Direction").reshape(3,3).to(torch.double)
        initial_origin = cache_attribute.get_tensor("Origin")
        cache_attribute["Direction"] = (self.canonical_direction).flatten()
        # Rotation from the current orientation to the canonical one, with zero translation.
        matrix = _affine_matrix(self.canonical_direction @ initial_matrix.inverse(), torch.tensor([0, 0, 0]))
        # Physical offset from the first voxel to the volume centre, in index-aligned axes;
        # spacing[i] pairs with tensor axis -i-1.
        center_voxel = torch.tensor([(input.shape[-i-1] - 1) * spacing[i] / 2 for i in range(3)], dtype=torch.double)
        # World position of the centre under the original orientation.
        center_physical = initial_matrix @ center_voxel + initial_origin
        # New origin keeps that centre fixed under the canonical direction.
        cache_attribute["Origin"] = center_physical - (self.canonical_direction @ center_voxel)
        # NOTE(review): the rotation centre used by _resample_affine is not visible here — confirm it rotates about the volume centre.
        return _resample_affine(input, matrix.unsqueeze(0))

    def inverse(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        # Drop the canonical Direction/Origin pushed by __call__. This presumably
        # leaves the pre-transform values current, which the get_tensor below relies on.
        cache_attribute.pop("Direction")
        cache_attribute.pop("Origin")
        # Inverse of the forward rotation, rebuilt from the original Direction.
        matrix = _affine_matrix((self.canonical_direction @ cache_attribute.get_tensor("Direction").to(torch.double).reshape(3,3).inverse()).inverse(), torch.tensor([0, 0, 0]))
        return _resample_affine(input, matrix.unsqueeze(0))
|
|
487
|
+
|
|
488
|
+
class HistogramMatching(Transform):
    """Match the intensity histogram to the case's image in ``reference_group``.

    Uses ``sitk.HistogramMatchingImageFilter`` with 256 levels, 1 match point and
    thresholding at the mean intensity.
    """

    def __init__(self, reference_group: str) -> None:
        self.reference_group = reference_group

    def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        image = data_to_image(input, cache_attribute)
        image_ref = None
        for dataset in self.datasets:
            if dataset.isDatasetExist(self.reference_group, name):
                image_ref = dataset.readImage(self.reference_group, name)
                # Stop at the first match, like the other dataset lookups in this file;
                # previously every dataset was read and the last match won.
                break
        if image_ref is None:
            raise NameError("Image : {}/{} not found".format(self.reference_group, name))
        matcher = sitk.HistogramMatchingImageFilter()
        matcher.SetNumberOfHistogramLevels(256)
        matcher.SetNumberOfMatchPoints(1)
        matcher.SetThresholdAtMeanIntensity(True)
        result, _ = image_to_data(matcher.Execute(image, image_ref))
        return torch.tensor(result)

    def inverse(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        return input
|
|
510
|
+
|
|
511
|
+
class SelectLabel(Transform):
    """Remap selected labels; every voxel not matching an old label becomes 0.

    Each entry of ``labels`` is a two-item string such as "(old,new)". The first
    and last characters are stripped (presumably brackets) and the rest is split on ",".
    """

    def __init__(self, labels: list[str]) -> None:
        self.labels = [entry[1:-1].split(",") for entry in labels]

    def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        remapped = torch.zeros_like(input)
        for source, target in self.labels:
            remapped[input == int(source)] = int(target)
        return remapped

    def inverse(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        return input
|
|
523
|
+
|
|
524
|
+
class OneHot(Transform):
    """One-hot encode an integer label tensor into ``num_classes`` float channels."""

    def __init__(self, num_classes: int) -> None:
        super().__init__()
        self.num_classes = num_classes

    def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        rank = len(input.shape)
        # one_hot appends the class axis last; move it to position 1.
        order = (0, rank, *[i + 1 for i in range(rank - 1)])
        # NOTE(review): squeeze(2) removes axis 2 only when it has size 1 — confirm the intended layout.
        # The leftover debug print(result.shape) has been removed.
        return F.one_hot(input.type(torch.int64), num_classes=self.num_classes).permute(*order).float().squeeze(2)

    def inverse(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        # Collapse the class axis (dim 1) back to label indices.
        return torch.argmax(input, dim=1).unsqueeze(1)
|
konfai/evaluator.py
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
1
|
+
import os
|
|
2
|
+
from torch.utils.data import DataLoader
|
|
3
|
+
import torch
|
|
4
|
+
import tqdm
|
|
5
|
+
import numpy as np
|
|
6
|
+
import json
|
|
7
|
+
import shutil
|
|
8
|
+
import builtins
|
|
9
|
+
import importlib
|
|
10
|
+
from KonfAI.konfai import EVALUATIONS_DIRECTORY, PREDICTIONS_DIRECTORY, DEEP_LEARNING_API_ROOT, CONFIG_FILE
|
|
11
|
+
from KonfAI.konfai.utils.config import config
|
|
12
|
+
from KonfAI.konfai.utils.utils import _getModule, DistributedObject, synchronize_data
|
|
13
|
+
from KonfAI.konfai.data.dataset import DataMetric
|
|
14
|
+
|
|
15
|
+
class CriterionsAttr:
    """Per-criterion options, configured through ``@config``; currently has no fields."""

    @config()
    def __init__(self) -> None:
        pass
|
|
20
|
+
|
|
21
|
+
class CriterionsLoader:
    """Build the metric modules for one (output group, target group) pair from config."""

    @config()
    def __init__(self, criterionsLoader: dict[str, CriterionsAttr] = {"default:torch_nn_CrossEntropyLoss:Dice:NCC": CriterionsAttr()}) -> None:
        self.criterionsLoader = criterionsLoader

    def getCriterions(self, output_group: str, target_group: str) -> dict[torch.nn.Module, CriterionsAttr]:
        """Instantiate each configured criterion, keyed by module, with its ``CriterionsAttr``.

        ``_getModule`` is given ``"measure"`` as its module hint. Each class is wrapped
        with ``config`` under
        ``<root>.metrics.<output>.targetsCriterions.<target>.criterionsLoader.<classpath>``.
        """
        built = {}
        for classpath, attributes in self.criterionsLoader.items():
            module_path, class_name = _getModule(classpath, "measure")
            # Create the config wrapper before importing, as the original expression did.
            configure = config("{}.metrics.{}.targetsCriterions.{}.criterionsLoader.{}".format(DEEP_LEARNING_API_ROOT(), output_group, target_group, classpath))
            criterion_cls = getattr(importlib.import_module(module_path), class_name)
            built[configure(criterion_cls)(config=None)] = attributes
        return built
|
|
33
|
+
|
|
34
|
+
class TargetCriterionsLoader():
    """Map each target group to its configured criteria for one output group."""

    @config()
    def __init__(self, targetsCriterions: dict[str, CriterionsLoader] = {"default": CriterionsLoader()}) -> None:
        self.targetsCriterions = targetsCriterions

    def getTargetsCriterions(self, output_group: str) -> dict[str, dict[torch.nn.Module, CriterionsAttr]]:
        """Return ``{target_group: {criterion_module: CriterionsAttr}}`` for ``output_group``.

        The return annotation previously claimed float values; the values are the
        ``CriterionsAttr`` objects produced by ``CriterionsLoader.getCriterions``.
        """
        return {
            target_group: loader.getCriterions(output_group, target_group)
            for target_group, loader in self.targetsCriterions.items()
        }
|
|
45
|
+
|
|
46
|
+
class Statistics:
    """Collect per-case metric values and dump them, with aggregate statistics, to a JSON file."""

    def __init__(self, filename: str) -> None:
        # Case name -> {metric name -> value}
        self.measures: dict[str, dict[str, float]] = {}
        self.filename = filename

    def add(self, values: dict[str, float], name_dataset: str) -> None:
        """Record (overwriting existing keys) the metric ``values`` of case ``name_dataset``.

        An empty ``values`` mapping leaves ``measures`` unchanged.
        """
        if not values:
            return
        self.measures.setdefault(name_dataset, {}).update(values)

    @staticmethod
    def getStatistic(values: list[float]) -> dict[str, float]:
        """Summarize ``values`` (must be non-empty) as native Python numbers.

        NumPy scalars are converted with ``.item()`` / ``float()`` so the result is
        JSON-serializable even for integer-valued metrics (``np.int64`` is not).

        Returns:
            Keys ``max``, ``min``, ``std``, ``25pc``, ``50pc``, ``75pc``, ``mean``, ``count``.
        """
        arr = np.asarray(values)
        return {
            "max": arr.max().item(),
            "min": arr.min().item(),
            "std": float(np.std(arr)),
            "25pc": float(np.percentile(arr, 25)),
            "50pc": float(np.percentile(arr, 50)),
            "75pc": float(np.percentile(arr, 75)),
            "mean": float(np.mean(arr)),
            "count": len(values),
        }

    def write(self, outputs: list[dict[str, dict[str, float]]]) -> None:
        """Merge the ``measures`` dicts in ``outputs`` and write them to ``self.filename``.

        The JSON document has two sections:
            ``case``: metric name -> {case name -> value}
            ``aggregates``: metric name -> summary from ``getStatistic``
        When the same case name appears in several outputs, the later one wins.
        """
        measures: dict[str, dict[str, float]] = {}
        for output in outputs:
            measures.update(output)

        per_metric: dict[str, dict[str, float]] = {}
        for case_name, case_values in measures.items():
            for metric_name, value in case_values.items():
                per_metric.setdefault(metric_name, {})[case_name] = value

        result = {
            "case": per_metric,
            "aggregates": {metric_name: Statistics.getStatistic(list(cases.values())) for metric_name, cases in per_metric.items()},
        }

        with open(self.filename, "w", encoding="utf-8") as f:
            f.write(json.dumps(result, indent=4))
|
|
85
|
+
|
|
86
|
+
class Evaluator(DistributedObject):
    """Evaluate the configured metrics over a dataset and write the results to JSON.

    Per-case values and aggregate statistics go to
    ``<EVALUATIONS_DIRECTORY()><name>/Metric_TRAIN.json`` and, when a second
    dataloader is supplied, ``Metric_VALIDATION.json``.
    """

    @config("Evaluator")
    def __init__(self, train_name: str = "default:name", metrics: dict[str, TargetCriterionsLoader] = {"default": TargetCriterionsLoader()}, dataset : DataMetric = DataMetric(),) -> None:
        """Build the metric modules and the statistics collectors.

        Args:
            train_name: Run name forwarded to ``DistributedObject``; ``self.name`` is
                used as the output sub-directory.
            metrics: Output group -> loader of the criterions applied to it.
            dataset: Provides the dataloaders evaluated in ``run_process``.
        """
        # Stop the process unless the configuration pass reports it is done.
        if os.environ["DEEP_LEANING_API_CONFIG_MODE"] != "Done":
            raise SystemExit(0)
        super().__init__(train_name)
        # EVALUATIONS_DIRECTORY is the output root imported by this module.
        self.metric_path = EVALUATIONS_DIRECTORY()+self.name+"/"
        self.predict_path = PREDICTIONS_DIRECTORY()+self.name+"/"
        self.metricsLoader = metrics
        self.dataset = dataset
        # Output group -> target group -> {criterion instance: CriterionsAttr}
        self.metrics = {k: v.getTargetsCriterions(k) for k, v in self.metricsLoader.items()}
        self.statistics_train = Statistics(self.metric_path+"Metric_TRAIN.json")
        self.statistics_validation = Statistics(self.metric_path+"Metric_VALIDATION.json")

    def update(self, data_dict: dict[str, tuple[torch.Tensor, str]], statistics : Statistics) -> dict[str, float]:
        """Evaluate every configured criterion on one batch and record the values.

        Args:
            data_dict: Group name -> ``(tensor, names)``; ``names[0]`` is the case key.
            statistics: Collector that receives the computed values.

        Returns:
            ``"output_group:target_group:CriterionClass"`` -> scalar value.
        """
        result: dict[str, float] = {}
        name = None
        for output_group, targets_criterions in self.metrics.items():
            for target_group, criterions in targets_criterions.items():
                # A target group may list several groups joined by "/"; groups missing from the batch are skipped.
                targets = [data_dict[group][0] for group in target_group.split("/") if group in data_dict]
                output = data_dict[output_group][0]
                # NOTE(review): only the first name of the batch is used as the case key — presumably batch size 1; confirm.
                name = data_dict[output_group][1][0]
                for metric in criterions:
                    result["{}:{}:{}".format(output_group, target_group, metric.__class__.__name__)] = metric(output, *targets).item()
        # With no configured metric there is no case name: record nothing instead of raising UnboundLocalError.
        if name is not None:
            statistics.add(result, name)
        return result

    def setup(self, world_size: int):
        """Prepare the output directory, copy the config file into it and build the dataloaders.

        If the output directory already exists and ``DL_API_OVERWRITE`` is not "True",
        the user is asked for confirmation; any answer other than "yes" returns early
        without building ``self.dataloader``.
        """
        if os.path.exists(self.metric_path):
            if os.environ.get("DL_API_OVERWRITE") != "True":
                accept = builtins.input("The metric {} already exists ! Do you want to overwrite it (yes,no) : ".format(self.name))
                if accept != "yes":
                    return
            shutil.rmtree(self.metric_path)

        os.makedirs(self.metric_path, exist_ok=True)
        metric_namefile_src = CONFIG_FILE().replace(".yml", "")
        # Copy by file name only, so a config path containing directories still lands inside metric_path.
        shutil.copyfile(metric_namefile_src+".yml", "{}{}.yml".format(self.metric_path, os.path.basename(metric_namefile_src)))

        self.dataloader = self.dataset.getData(world_size)

    def _evaluate(self, world_size: int, global_rank: int, gpu: int, dataloader: DataLoader, statistics: Statistics, label: str) -> None:
        """Run ``update`` over ``dataloader``, gather the results across ranks, and write them on rank 0."""
        def description(measure):
            return "Metric {} : {} ".format(label, " | ".join("{}: {:.2f}".format(k, v) for k, v in measure.items()) if measure is not None else "")

        with tqdm.tqdm(iterable = enumerate(dataloader), leave=False, desc = description(None), total=len(dataloader)) as batch_iter:
            for _, data_dict in batch_iter:
                # NOTE(review): index 0 is the data tensor and index 4 the case names — assumed from DataMetric's loader output; confirm.
                batch_iter.set_description(description(self.update({k: (v[0], v[4]) for k,v in data_dict.items()}, statistics)))
        outputs = synchronize_data(world_size, gpu, statistics.measures)
        if global_rank == 0:
            statistics.write(outputs)

    def run_process(self, world_size: int, global_rank: int, gpu: int, dataloaders: list[DataLoader]):
        """Evaluate ``dataloaders[0]`` as TRAIN and, when there are exactly two, ``dataloaders[1]`` as VALIDATION."""
        self._evaluate(world_size, global_rank, gpu, dataloaders[0], self.statistics_train, "TRAIN")
        if len(dataloaders) == 2:
            self._evaluate(world_size, global_rank, gpu, dataloaders[1], self.statistics_validation, "VALIDATION")
|
|
145
|
+
|
|
146
|
+
# 69.8936 ± 13.2773 (107) 28.6021 (117) 0.8641 (118)
|