konfai 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of konfai might be problematic. Click here for more details.

konfai/__init__.py ADDED
@@ -0,0 +1,16 @@
1
+ import os
2
+ import datetime
3
+
4
def MODELS_DIRECTORY() -> str:
    """Return the value of ``DL_API_MODELS_DIRECTORY`` (KeyError if unset)."""
    return os.environ["DL_API_MODELS_DIRECTORY"]


def CHECKPOINTS_DIRECTORY() -> str:
    """Return the value of ``DL_API_CHECKPOINTS_DIRECTORY`` (KeyError if unset)."""
    return os.environ["DL_API_CHECKPOINTS_DIRECTORY"]


def MODEL() -> str:
    """Return the value of ``DL_API_MODEL`` (KeyError if unset)."""
    return os.environ["DL_API_MODEL"]


def PREDICTIONS_DIRECTORY() -> str:
    """Return the value of ``DL_API_PREDICTIONS_DIRECTORY`` (KeyError if unset)."""
    return os.environ["DL_API_PREDICTIONS_DIRECTORY"]


def EVALUATIONS_DIRECTORY() -> str:
    """Return the value of ``DL_API_EVALUATIONS_DIRECTORY`` (KeyError if unset)."""
    return os.environ["DL_API_EVALUATIONS_DIRECTORY"]


def STATISTICS_DIRECTORY() -> str:
    """Return the value of ``DL_API_STATISTICS_DIRECTORY`` (KeyError if unset)."""
    return os.environ["DL_API_STATISTICS_DIRECTORY"]


def SETUPS_DIRECTORY() -> str:
    """Return the value of ``DL_API_SETUPS_DIRECTORY`` (KeyError if unset)."""
    return os.environ["DL_API_SETUPS_DIRECTORY"]


def CONFIG_FILE() -> str:
    """Return the value of ``DEEP_LEARNING_API_CONFIG_FILE`` (KeyError if unset)."""
    return os.environ["DEEP_LEARNING_API_CONFIG_FILE"]


def DL_API_STATE() -> str:
    """Return the value of ``DL_API_STATE`` (KeyError if unset)."""
    return os.environ["DL_API_STATE"]


def DEEP_LEARNING_API_ROOT() -> str:
    """Return the value of ``DEEP_LEARNING_API_ROOT`` (KeyError if unset)."""
    return os.environ["DEEP_LEARNING_API_ROOT"]


def CUDA_VISIBLE_DEVICES() -> str:
    """Return the value of ``CUDA_VISIBLE_DEVICES`` (KeyError if unset)."""
    return os.environ["CUDA_VISIBLE_DEVICES"]


def DATE() -> str:
    """Return the current local time formatted as ``YYYY_MM_DD_HH_MM_SS``."""
    return datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
konfai/data/HDF5.py ADDED
@@ -0,0 +1,326 @@
1
+ from abc import ABC, abstractmethod
2
+ import SimpleITK as sitk
3
+ import numpy as np
4
+ import torch
5
+ import os
6
+ import torch.nn.functional as F
7
+ from typing import Any, Iterator
8
+ from typing import Union
9
+ import itertools
10
+ import copy
11
+ from functools import partial
12
+ from KonfAI.konfai.utils.config import config
13
+ from KonfAI.konfai.utils.utils import get_patch_slices_from_shape
14
+ from KonfAI.konfai.utils.dataset import Dataset, Attribute
15
+ from KonfAI.konfai.data.transform import Transform, Save
16
+ from KonfAI.konfai.data.augmentation import DataAugmentationsList
17
+
18
+
19
class PathCombine(ABC):
    """Base class for patch-blending weight maps.

    ``setPatchConfig`` builds ``self.data``, a weight map of ``patch_size``.
    Subclasses shape it through ``_setFunction``. The border regions are then
    normalised so that co-located border voxels of the grouped regions sum to
    one. ``__call__`` multiplies a batch of patches by that map.
    """

    def __init__(self) -> None:
        self.data: torch.Tensor = None
        self.overlap: int = None

    def setPatchConfig(self, patch_size: list[int], overlap: int):
        """Build the weight map for patches of ``patch_size`` overlapping by ``overlap``.

        With A = slice(0, overlap), B = slice(-overlap, None) and
        C = slice(overlap, -overlap), the following region groups are each
        normalised together (entries of a group are divided by their sum)::

            1D : A+B
            2D : AA+AB+BA+BB
                 AC+BC
                 CA+CB
            3D : AAA+AAB+ABA+ABB+BAA+BAB+BBA+BBB
                 AAC+ABC+BAC+BBC
                 ACA+ACB+BCA+BCB
                 CAA+CAB+CBA+CBB
                 ACC+BCC
                 CAC+CBC
                 CCA+CCB

        In general, for every choice of ``i`` axes pinned to C (0 <= i < dim),
        the remaining axes take all A/B combinations. This works for any
        dimensionality.

        An ``overlap`` of ``None`` or ``<= 0`` means neighbouring patches do
        not overlap, so every weight is 1.
        """
        self.overlap = overlap
        if overlap is None or overlap <= 0:
            # No overlap: nothing to blend. The general path below would build
            # empty/whole slices and fail to concatenate them.
            self.data = torch.ones(patch_size)
            return

        # Ones in the interior, zeros in a border of width `overlap` on every side.
        self.data = F.pad(torch.ones([size - overlap * 2 for size in patch_size]), [overlap] * 2 * len(patch_size), mode="constant", value=0)
        self.data = self._setFunction(self.data, overlap)
        dim = len(patch_size)

        A = slice(0, overlap)
        B = slice(-overlap, None)
        C = slice(overlap, -overlap)

        for i in range(dim):
            slices_badge = list(itertools.product(*[[A, B] for _ in range(dim - i)]))
            # Choose which `i` axes are pinned to C. Using range(dim) instead of a
            # fixed [0, 1, 2] keeps 1D/2D free of duplicates and covers 4D+.
            for indexs in itertools.combinations(range(dim), i):
                result = []
                for slices in slices_badge:
                    slices = list(slices)
                    # indexs is ascending, so each insert lands at its final position.
                    for index in indexs:
                        slices.insert(index, C)
                    result.append(tuple(slices))
                for patch, s in zip(PathCombine._normalise([self.data[s] for s in result]), result):
                    self.data[s] = patch

    @staticmethod
    def _normalise(patchs: list[torch.Tensor]) -> list[torch.Tensor]:
        """Divide each tensor by the element-wise sum of all tensors in ``patchs``."""
        data_sum = torch.sum(torch.concat([patch.unsqueeze(0) for patch in patchs], dim=0), dim=0)
        return [d / data_sum for d in patchs]

    def __call__(self, input: torch.Tensor) -> torch.Tensor:
        """Multiply ``input`` by the weight map, repeated along its first axis and moved to its device."""
        return self.data.repeat([input.shape[0]] + [1] * (len(input.shape) - 1)).to(input.device) * input

    @abstractmethod
    def _setFunction(self, data: torch.Tensor, overlap: int) -> torch.Tensor:
        """Return the un-normalised weight map derived from the padded mask ``data``."""
        pass
82
+
83
class Mean(PathCombine):
    """Uniform blending: every voxel gets weight 1 before border normalisation."""

    @config("Mean")
    def __init__(self) -> None:
        super().__init__()

    def _setFunction(self, data: torch.Tensor, overlap: int) -> torch.Tensor:
        """Return a tensor of ones shaped like ``data``."""
        # Honour the argument instead of reading self.data, so the override
        # works for whatever tensor the caller passes in.
        return torch.ones_like(data)
91
+
92
class Cosinus(PathCombine):
    """Cosine blending: weights decrease with the distance map of the padded mask."""

    @config("Cosinus")
    def __init__(self) -> None:
        super().__init__()

    def _function_sides(self, overlap: int, x):
        """Return clip(cos(pi / (2 * (overlap + 1)) * x), 0, 1); ``x`` may be a scalar or an array."""
        return np.clip(np.cos(np.pi / (2 * (overlap + 1)) * x), 0, 1)

    def _setFunction(self, data: torch.Tensor, overlap: int) -> torch.Tensor:
        """Map the SimpleITK Danielsson distance map of ``data`` through ``_function_sides``.

        ``data`` is cast to uint8 before the distance map is computed.
        NOTE(review): the distance is presumably to the nearest nonzero voxel
        (the interior of the padded mask) — confirm against SimpleITK defaults.
        """
        image = sitk.GetImageFromArray(np.asarray(data, dtype=np.uint8))
        distance = sitk.GetArrayFromImage(sitk.DanielssonDistanceMapImageFilter().Execute(image))
        # Evaluate the weight function on the whole array at once. The previous
        # Tensor.apply_ made one Python call per voxel. The distance map's
        # dtype is kept.
        weights = np.asarray(self._function_sides(overlap, distance), dtype=distance.dtype)
        return torch.from_numpy(weights)
106
+
107
class Accumulator():
    """Collect per-patch outputs and stitch them back into a single tensor."""

    def __init__(self, patch_slices: list[tuple[slice]], patch_size: list[int], patchCombine: Union["PathCombine", None] = None, batch: bool = True) -> None:
        """
        Args:
            patch_slices: Per-patch slices into the full (spatial) volume.
            patch_size: Size of every patch along each spatial axis.
            patchCombine: Optional weighting applied to each patch before it is
                summed into the output. Without it, patches overwrite each other.
            batch: If True, the outputs carry two leading axes (assumed batch
                and channel); otherwise one.
        """
        self._layer_accumulator: list[Union[torch.Tensor, None]] = [None] * len(patch_slices)
        # Re-anchor every patch to span the full patch_size from its start
        # (edge patches may be shorter in patch_slices).
        self.patch_slices = [
            tuple(slice(s.start, s.start + size) for s, size in zip(patch, patch_size))
            for patch in patch_slices
        ]
        # True output extent: per-axis maximum stop of the original slices.
        self.shape = Accumulator._max_stop(patch_slices)
        self.patch_size = patch_size
        self.patchCombine = patchCombine
        self.batch = batch

    @staticmethod
    def _max_stop(patch_slices: list[tuple[slice]]) -> list[int]:
        """Return the per-axis maximum ``stop`` over all patches (not a lexicographic max)."""
        return [max(stops) for stops in zip(*[[s.stop for s in patch] for patch in patch_slices])]

    def addLayer(self, index: int, layer: torch.Tensor) -> None:
        """Store the output for patch ``index``."""
        self._layer_accumulator[index] = layer

    def isFull(self) -> bool:
        """Return True once every patch has an output stored."""
        return len(self.patch_slices) == len([v for v in self._layer_accumulator if v is not None])

    def assemble(self) -> torch.Tensor:
        """Stitch the stored outputs into one tensor and reset the accumulator.

        If the trailing axes of the first output match ``patch_size``, outputs
        are placed (or weighted and summed) at their patch positions and then
        cropped to the true extent. Otherwise they are concatenated along axis 0.
        """
        first = self._layer_accumulator[0]
        if all(first.shape[i - len(self.patch_size)] == size for i, size in enumerate(self.patch_size)):
            N = 2 if self.batch else 1

            result = torch.zeros(list(first.shape[:N]) + Accumulator._max_stop(self.patch_slices), dtype=first.dtype, device=first.device)
            for patch_slice, data in zip(self.patch_slices, self._layer_accumulator):
                slices_dest = tuple([slice(result.shape[i]) for i in range(N)] + list(patch_slice))

                # Restore axes of size 1 that were squeezed out when the patch was extracted.
                for dim, s in enumerate(patch_slice):
                    if s.stop - s.start == 1:
                        data = data.unsqueeze(dim=dim + N)
                if self.patchCombine is not None:
                    result[slices_dest] += self.patchCombine(data)
                else:
                    result[slices_dest] = data
            # Keep all N leading axes whole and crop only the spatial axes.
            # (Using a single leading slice here misaligned the crop when N == 2.)
            result = result[tuple([slice(None)] * N + [slice(0, s) for s in self.shape])]
        else:
            result = torch.cat(tuple(self._layer_accumulator), dim=0)
        # Drop the references but keep one slot per patch so the accumulator can be refilled.
        self._layer_accumulator = [None] * len(self.patch_slices)
        return result
148
+
149
class Patch(ABC):
    """Extract fixed-size, optionally slice-extended, patches from a tensor.

    Patch grids are stored per key ``a`` (see ``load``). The last
    ``len(patch_size)`` axes of a tensor are the patch axes; any leading axes
    are taken whole.
    """

    def __init__(self, patch_size: list[int], overlap: Union[int, None], path_mask: Union[str, None] = None, padValue: float = 0, extend_slice: int = 0) -> None:
        self.patch_size = patch_size
        self.overlap = overlap
        self._patch_slices: dict[int, list[tuple[slice]]] = {}
        self._nb_patch_per_dim: dict[int, list[tuple[int, bool]]] = {}
        self.path_mask = path_mask
        self.mask = None
        self.padValue = padValue
        # Number of extra slices along the first patch axis added around input patches.
        self.extend_slice = extend_slice
        if self.path_mask is not None:
            if os.path.exists(self.path_mask):
                self.mask = torch.tensor(sitk.GetArrayFromImage(sitk.ReadImage(self.path_mask)))
            else:
                # NameError kept (rather than FileNotFoundError) for existing callers.
                raise NameError('Mask file not found')

    def load(self, shape: list[int], a: int = 0) -> None:
        """Compute and store the patch grid for a volume of ``shape`` under key ``a``."""
        self._patch_slices[a], self._nb_patch_per_dim[a] = get_patch_slices_from_shape(self.patch_size, shape, self.overlap)

    def getPatch_slices(self, a: int = 0):
        """Return the stored patch slices for key ``a``."""
        return self._patch_slices[a]

    def getData(self, data: torch.Tensor, index: int, a: int, isInput: bool) -> torch.Tensor:
        """Return patch ``index`` of grid ``a`` cut out of ``data``.

        For inputs (``isInput``) with ``extend_slice > 0``, the first patch axis
        is widened by ``extend_slice // 2`` slices below and
        ``ceil(extend_slice / 2)`` above. Out-of-volume slices are filled by
        reflection, and the widened slices are concatenated along axis 0.
        Patches running past the volume edge are constant-padded with
        ``padValue`` (0 for uint8 data when ``padValue`` is negative). Patch
        axes of size 1 are squeezed out.
        """
        patch_slice = self._patch_slices[a][index]
        slices_pre = [slice(n) for n in data.shape[:-len(self.patch_size)]]
        axis = len(slices_pre)  # position of the first patch axis in `data`
        extend_slice = self.extend_slice if isInput else 0

        bottom = extend_slice // 2
        top = int(np.ceil(extend_slice / 2))
        start = patch_slice[0].start - bottom
        stop = patch_slice[0].stop + top
        s = slice(max(start, 0), min(stop, data.shape[axis]))
        slices = [s] + list(patch_slice[1:])
        data_sliced = data[tuple(slices_pre + slices)]
        if data_sliced.shape[axis] < bottom + top + 1:
            # Reflect-pad exactly the part of the widened range that lies outside
            # the volume (previously `bottom - s.start`, which over-padded
            # whenever 0 < patch start < bottom).
            pad_bottom = -start if start < 0 else 0
            pad_top = stop - data.shape[axis] if stop > data.shape[axis] else 0
            data_sliced = F.pad(data_sliced, [0] * ((len(slices) - 1) * 2) + [pad_bottom, pad_top], 'reflect')

        pad_value = 0 if data_sliced.dtype == torch.uint8 and self.padValue < 0 else self.padValue
        padding = []
        for dim_it, _slice in enumerate(reversed(slices)):
            size = self.patch_size[-dim_it - 1]
            extent = data.shape[-dim_it - 1]
            p = 0 if _slice.start + size <= extent else size - (extent - _slice.start)
            padding += [0, p]

        data_sliced = F.pad(data_sliced, tuple(padding), "constant", pad_value)
        if self.mask is not None:
            # NOTE(review): the mask is compared with the patch as-is; this assumes
            # it already has (or broadcasts to) the patch shape — confirm.
            outside = torch.ones_like(data_sliced) * pad_value
            data_sliced = torch.where(self.mask == 0, outside, data_sliced)

        # Squeeze size-1 patch axes, last first, so earlier axis indices stay valid.
        for j in reversed(range(len(self.patch_size))):
            if self.patch_size[j] == 1:
                data_sliced = torch.squeeze(data_sliced, dim=axis + j)
        if extend_slice > 0:
            return torch.cat([data_sliced[:, i, ...] for i in range(data_sliced.shape[1])], dim=0)
        return data_sliced

    def getSize(self, a: int = 0) -> int:
        """Return the number of patches in grid ``a``."""
        return len(self._patch_slices[a])
213
+
214
class DatasetPatch(Patch):
    """Patch extractor configured from the ``Patch`` config section (dataset side)."""

    @config("Patch")
    def __init__(self, patch_size: list[int] = [128, 256, 256], overlap: Union[int, None] = None, mask: Union[str, None] = None, padValue: float = 0, extend_slice: int = 0) -> None:
        # The literal default is kept for the config system, but a copy is stored
        # so no instance aliases the shared default list.
        super().__init__(list(patch_size), overlap, mask, padValue, extend_slice)
219
+
220
class ModelPatch(Patch):
    """Patch extractor configured from the ``Patch`` config section (model side)."""

    @config("Patch")
    def __init__(self, patch_size: list[int] = [128, 256, 256], overlap: Union[int, None] = None, patchCombine: Union[str, None] = None, mask: Union[str, None] = None, padValue: float = 0, extend_slice: int = 0) -> None:
        # The literal default is kept for the config system, but a copy is stored
        # so no instance aliases the shared default list.
        super().__init__(list(patch_size), overlap, mask, padValue, extend_slice)
        # Name of the blending strategy; stored only, not resolved here.
        self.patchCombine = patchCombine

    def disassemble(self, *dataList: torch.Tensor) -> Iterator[list[torch.Tensor]]:
        """Yield, for each patch of grid 0, that patch cut from every tensor in ``dataList`` (as input patches)."""
        for i in range(self.getSize()):
            yield [self.getData(data, i, 0, True) for data in dataList]
230
+
231
class DatasetManager():
    """Hold one named case of a dataset group: its data, augmented copies, attributes and patch grids.

    Index ``a`` in ``getData``/``getSize`` selects the variant. 0 is the
    pre-transformed data; 1, 2, ... are the augmented copies, in the order
    produced by ``dataAugmentationsList``.
    """

    def __init__(self, index: int, group_src: str, group_dest: str, name: str, dataset: Dataset, patch: Union[DatasetPatch, None], pre_transforms: list[Transform], dataAugmentationsList: list[DataAugmentationsList]) -> None:
        self.group_src = group_src
        self.group_dest = group_dest
        self.name = name
        self.index = index
        self.dataset = dataset
        self.loaded = False
        self.cache_attributes: list[Attribute] = []
        _shape, cache_attribute = self.dataset.getInfos(self.group_src, name)
        self.cache_attributes.append(cache_attribute)
        # The first reported axis is dropped; presumably the channel axis — TODO confirm.
        _shape = list(_shape[1:])

        self.data: list[torch.Tensor] = []

        # Predict the shape after all pre-transforms without loading the data.
        for transformFunction in pre_transforms:
            _shape = transformFunction.transformShape(_shape, cache_attribute)

        if patch:
            self.patch = DatasetPatch(patch_size=patch.patch_size, overlap=patch.overlap, mask=patch.path_mask, padValue=patch.padValue, extend_slice=patch.extend_slice)
        else:
            # No patching configured: a single patch covering the whole volume.
            self.patch = DatasetPatch(_shape)
        self.patch.load(_shape, 0)
        self.shape = _shape
        self.dataAugmentationsList = dataAugmentationsList
        self.resetAugmentation()
        self.cache_attributes_bak = copy.deepcopy(self.cache_attributes)

    @staticmethod
    def _open_save_dataset(transformFunction: Save) -> Dataset:
        """Open the ``Dataset`` described by a Save transform's ``"<filename>:<format>"`` string."""
        # rsplit so a filename that itself contains ':' (e.g. a Windows drive) still parses.
        filename, file_format = transformFunction.save.rsplit(":", 1)
        return Dataset(filename, file_format)

    def resetAugmentation(self):
        """Re-initialise augmentation state and the patch grids of every augmented copy."""
        # Keep only the base attributes. Appending on top of a previous call's
        # entries would misalign index `a` with self.data and the patch grids.
        del self.cache_attributes[1:]
        i = 1
        for dataAugmentations in self.dataAugmentationsList:
            shape = [self.shape for _ in range(dataAugmentations.nb)]
            caches_attribute = [copy.deepcopy(self.cache_attributes[0]) for _ in range(dataAugmentations.nb)]

            for dataAugmentation in dataAugmentations.dataAugmentations:
                shape = dataAugmentation.state_init(self.index, shape, caches_attribute)
            for it, s in enumerate(shape):
                self.cache_attributes.append(caches_attribute[it])
                self.patch.load(s, i)
                i += 1

    def load(self, pre_transform: list[Transform], dataAugmentationsList: list[DataAugmentationsList], device: torch.device) -> None:
        """Load the data (once), apply pre-transforms and build the augmented copies.

        Processing resumes after the last ``Save`` transform whose output already
        exists in its dataset. Otherwise the source data is read and every
        pre-transform is applied. Each ``Save`` reached along the way writes its
        result.
        """
        if self.loaded:
            return
        i = len(pre_transform)
        data = None
        for transformFunction in reversed(pre_transform):
            if isinstance(transformFunction, Save):
                dataset = self._open_save_dataset(transformFunction)
                if dataset.isDatasetExist(self.group_dest, self.name):
                    data, attrib = dataset.readData(self.group_dest, self.name)
                    self.cache_attributes[0].update(attrib)
                    break
            i -= 1

        if i == 0:
            data, _ = self.dataset.readData(self.group_src, self.name)

        data = torch.from_numpy(data)

        for transformFunction in pre_transform[i:]:
            data = transformFunction(self.name, data, self.cache_attributes[0])
            if isinstance(transformFunction, Save):
                self._open_save_dataset(transformFunction).write(self.group_dest, self.name, data.numpy(), self.cache_attributes[0])

        self.data = [data]
        for dataAugmentations in dataAugmentationsList:
            a_data = [data.clone() for _ in range(dataAugmentations.nb)]
            for dataAugmentation in dataAugmentations.dataAugmentations:
                a_data = dataAugmentation(self.index, a_data, device)
            self.data.extend(a_data)
        self.loaded = True

    def unload(self) -> None:
        """Drop the loaded tensors and restore the attribute snapshot taken at construction."""
        if hasattr(self, "data"):
            del self.data
        self.cache_attributes = copy.deepcopy(self.cache_attributes_bak)
        self.loaded = False

    def getData(self, index: int, a: int, post_transforms: list[Transform], isInput: bool) -> torch.Tensor:
        """Return patch ``index`` of variant ``a`` after applying ``post_transforms``."""
        data = self.patch.getData(self.data[a], index, a, isInput)
        for transformFunction in post_transforms:
            data = transformFunction(self.name, data, self.cache_attributes[a])
        return data

    def getSize(self, a: int) -> int:
        """Return the number of patches of variant ``a``."""
        return self.patch.getSize(a)
File without changes