konfai 1.1.8__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of konfai might be problematic. Click here for more details.
- konfai/__init__.py +59 -14
- konfai/data/augmentation.py +457 -286
- konfai/data/data_manager.py +533 -316
- konfai/data/patching.py +300 -183
- konfai/data/transform.py +408 -275
- konfai/evaluator.py +325 -68
- konfai/main.py +71 -22
- konfai/metric/measure.py +360 -244
- konfai/metric/schedulers.py +24 -13
- konfai/models/classification/convNeXt.py +187 -81
- konfai/models/classification/resnet.py +272 -58
- konfai/models/generation/cStyleGan.py +233 -59
- konfai/models/generation/ddpm.py +348 -121
- konfai/models/generation/diffusionGan.py +757 -358
- konfai/models/generation/gan.py +177 -53
- konfai/models/generation/vae.py +140 -40
- konfai/models/registration/registration.py +135 -52
- konfai/models/representation/representation.py +57 -23
- konfai/models/segmentation/NestedUNet.py +339 -68
- konfai/models/segmentation/UNet.py +140 -30
- konfai/network/blocks.py +331 -187
- konfai/network/network.py +795 -427
- konfai/predictor.py +644 -238
- konfai/trainer.py +509 -222
- konfai/utils/ITK.py +191 -106
- konfai/utils/config.py +152 -95
- konfai/utils/dataset.py +326 -455
- konfai/utils/utils.py +497 -249
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/METADATA +1 -3
- konfai-1.2.0.dist-info/RECORD +38 -0
- konfai/utils/registration.py +0 -199
- konfai-1.1.8.dist-info/RECORD +0 -39
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/WHEEL +0 -0
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/entry_points.txt +0 -0
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/licenses/LICENSE +0 -0
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/top_level.txt +0 -0
konfai/data/patching.py
CHANGED
|
@@ -1,41 +1,43 @@
|
|
|
1
|
+
import copy
|
|
2
|
+
import importlib
|
|
3
|
+
import itertools
|
|
1
4
|
from abc import ABC, abstractmethod
|
|
2
|
-
|
|
5
|
+
from collections.abc import Iterator
|
|
6
|
+
from functools import partial
|
|
7
|
+
|
|
3
8
|
import numpy as np
|
|
9
|
+
import SimpleITK as sitk # noqa: N813
|
|
4
10
|
import torch
|
|
5
|
-
import
|
|
6
|
-
|
|
7
|
-
from typing import Any, Iterator
|
|
8
|
-
from typing import Union
|
|
9
|
-
import itertools
|
|
10
|
-
import copy
|
|
11
|
-
from functools import partial
|
|
12
|
-
from konfai.utils.config import config
|
|
13
|
-
from konfai.utils.utils import get_patch_slices_from_shape
|
|
14
|
-
from konfai.utils.dataset import Dataset, Attribute
|
|
15
|
-
from konfai.data.transform import Transform, Save
|
|
11
|
+
import torch.nn.functional as F # noqa: N812
|
|
12
|
+
|
|
16
13
|
from konfai.data.augmentation import DataAugmentationsList
|
|
14
|
+
from konfai.data.transform import Save, Transform
|
|
15
|
+
from konfai.utils.config import config
|
|
16
|
+
from konfai.utils.dataset import Attribute, Dataset
|
|
17
|
+
from konfai.utils.utils import get_module, get_patch_slices_from_shape
|
|
18
|
+
|
|
17
19
|
|
|
18
20
|
class PathCombine(ABC):
|
|
19
21
|
|
|
20
22
|
def __init__(self) -> None:
|
|
21
|
-
self.data: torch.Tensor
|
|
22
|
-
self.overlap: int
|
|
23
|
-
|
|
23
|
+
self.data: torch.Tensor
|
|
24
|
+
self.overlap: int
|
|
25
|
+
|
|
24
26
|
"""
|
|
25
27
|
A = slice(0, overlap)
|
|
26
28
|
B = slice(-overlap, None)
|
|
27
29
|
C = slice(overlap, -overlap)
|
|
28
|
-
|
|
29
|
-
1D
|
|
30
|
+
|
|
31
|
+
1D
|
|
30
32
|
A+B
|
|
31
|
-
2D :
|
|
33
|
+
2D :
|
|
32
34
|
AA+AB+BA+BB
|
|
33
|
-
|
|
35
|
+
|
|
34
36
|
AC+BC
|
|
35
37
|
CA+CB
|
|
36
|
-
3D :
|
|
38
|
+
3D :
|
|
37
39
|
AAA+AAB+ABA+ABB+BAA+BAB+BBA+BBB
|
|
38
|
-
|
|
40
|
+
|
|
39
41
|
CAA+CAB+CBA+CBB
|
|
40
42
|
ACA+ACB+BCA+BCB
|
|
41
43
|
AAC+ABC+BAC+BBC
|
|
@@ -43,181 +45,281 @@ class PathCombine(ABC):
|
|
|
43
45
|
CCA+CCB
|
|
44
46
|
CAC+CBC
|
|
45
47
|
ACC+BCC
|
|
46
|
-
|
|
48
|
+
|
|
47
49
|
"""
|
|
48
|
-
|
|
49
|
-
|
|
50
|
-
self.data =
|
|
50
|
+
|
|
51
|
+
def set_patch_config(self, patch_size: list[int], overlap: int):
|
|
52
|
+
self.data = F.pad(
|
|
53
|
+
torch.ones([size - overlap * 2 for size in patch_size]),
|
|
54
|
+
[overlap] * 2 * len(patch_size),
|
|
55
|
+
mode="constant",
|
|
56
|
+
value=0,
|
|
57
|
+
)
|
|
58
|
+
self.data = self._set_function(self.data, overlap)
|
|
51
59
|
dim = len(patch_size)
|
|
52
60
|
|
|
53
|
-
|
|
54
|
-
|
|
55
|
-
|
|
56
|
-
|
|
61
|
+
a = slice(0, overlap)
|
|
62
|
+
b = slice(-overlap, None)
|
|
63
|
+
c = slice(overlap, -overlap)
|
|
64
|
+
|
|
57
65
|
for i in range(dim):
|
|
58
|
-
slices_badge = list(itertools.product(*[[
|
|
59
|
-
for indexs in itertools.combinations([0,1,2], i):
|
|
66
|
+
slices_badge = list(itertools.product(*[[a, b] for _ in range(dim - i)]))
|
|
67
|
+
for indexs in itertools.combinations([0, 1, 2], i):
|
|
60
68
|
result = []
|
|
61
|
-
for
|
|
62
|
-
|
|
69
|
+
for slices_tuple in slices_badge:
|
|
70
|
+
slices_list = list(slices_tuple)
|
|
63
71
|
for index in indexs:
|
|
64
|
-
|
|
65
|
-
result.append(tuple(
|
|
72
|
+
slices_list.insert(index, c)
|
|
73
|
+
result.append(tuple(slices_list))
|
|
66
74
|
for patch, s in zip(PathCombine._normalise([self.data[s] for s in result]), result):
|
|
67
75
|
self.data[s] = patch
|
|
68
76
|
|
|
69
|
-
|
|
70
77
|
@staticmethod
|
|
71
78
|
def _normalise(patchs: list[torch.Tensor]) -> list[torch.Tensor]:
|
|
72
79
|
data_sum = torch.sum(torch.concat([patch.unsqueeze(0) for patch in patchs], dim=0), dim=0)
|
|
73
|
-
return [d/data_sum for d in patchs]
|
|
74
|
-
|
|
75
|
-
def __call__(self,
|
|
76
|
-
return self.data.repeat([
|
|
77
|
-
|
|
80
|
+
return [d / data_sum for d in patchs]
|
|
81
|
+
|
|
82
|
+
def __call__(self, tensor: torch.Tensor) -> torch.Tensor:
|
|
83
|
+
return self.data.repeat([tensor.shape[0]] + [1] * (len(tensor.shape) - 1)).to(tensor.device) * tensor
|
|
84
|
+
|
|
78
85
|
@abstractmethod
|
|
79
|
-
def
|
|
86
|
+
def _set_function(self, data: torch.Tensor, overlap: int) -> torch.Tensor:
|
|
80
87
|
pass
|
|
81
88
|
|
|
89
|
+
|
|
82
90
|
class Mean(PathCombine):
|
|
83
91
|
|
|
84
92
|
def __init__(self) -> None:
|
|
85
93
|
super().__init__()
|
|
86
94
|
|
|
87
|
-
def
|
|
95
|
+
def _set_function(self, data: torch.Tensor, overlap: int) -> torch.Tensor:
|
|
88
96
|
return torch.ones_like(self.data)
|
|
89
|
-
|
|
97
|
+
|
|
98
|
+
|
|
90
99
|
class Cosinus(PathCombine):
|
|
91
100
|
|
|
92
101
|
def __init__(self) -> None:
|
|
93
102
|
super().__init__()
|
|
94
103
|
|
|
95
104
|
def _function_sides(self, overlap: int, x: float):
|
|
96
|
-
return np.clip(np.cos(np.pi/(2*(overlap+1))*x), 0, 1)
|
|
105
|
+
return np.clip(np.cos(np.pi / (2 * (overlap + 1)) * x), 0, 1)
|
|
97
106
|
|
|
98
|
-
def
|
|
107
|
+
def _set_function(self, data: torch.Tensor, overlap: int) -> torch.Tensor:
|
|
99
108
|
image = sitk.GetImageFromArray(np.asarray(data, dtype=np.uint8))
|
|
100
|
-
|
|
101
|
-
distance = torch.tensor(sitk.GetArrayFromImage(
|
|
109
|
+
danielsson_distance_map_image_filter = sitk.DanielssonDistanceMapImageFilter()
|
|
110
|
+
distance = torch.tensor(sitk.GetArrayFromImage(danielsson_distance_map_image_filter.Execute(image)))
|
|
102
111
|
return distance.apply_(partial(self._function_sides, overlap))
|
|
103
|
-
|
|
104
|
-
class Accumulator():
|
|
105
112
|
|
|
106
|
-
|
|
107
|
-
|
|
113
|
+
|
|
114
|
+
class Accumulator:
|
|
115
|
+
|
|
116
|
+
def __init__(
|
|
117
|
+
self,
|
|
118
|
+
patch_slices: list[tuple[slice]],
|
|
119
|
+
patch_size: list[int],
|
|
120
|
+
patch_combine: PathCombine | None = None,
|
|
121
|
+
batch: bool = True,
|
|
122
|
+
) -> None:
|
|
123
|
+
self._layer_accumulator: list[torch.Tensor | None] = [None] * len(patch_slices)
|
|
108
124
|
self.patch_slices = []
|
|
109
125
|
for patch in patch_slices:
|
|
110
126
|
slices = []
|
|
111
127
|
for s, shape in zip(patch, patch_size):
|
|
112
|
-
slices.append(slice(s.start, s.start+shape))
|
|
128
|
+
slices.append(slice(s.start, s.start + shape))
|
|
113
129
|
self.patch_slices.append(tuple(slices))
|
|
114
130
|
self.shape = max([[v.stop for v in patch] for patch in patch_slices])
|
|
115
131
|
self.patch_size = patch_size
|
|
116
|
-
self.
|
|
132
|
+
self.patch_combine = patch_combine
|
|
117
133
|
self.batch = batch
|
|
118
134
|
|
|
119
|
-
def
|
|
135
|
+
def add_layer(self, index: int, layer: torch.Tensor) -> None:
|
|
120
136
|
self._layer_accumulator[index] = layer
|
|
121
|
-
|
|
122
|
-
def
|
|
137
|
+
|
|
138
|
+
def is_full(self) -> bool:
|
|
123
139
|
return len(self.patch_slices) == len([v for v in self._layer_accumulator if v is not None])
|
|
124
140
|
|
|
125
141
|
def assemble(self) -> torch.Tensor:
|
|
126
|
-
|
|
127
|
-
|
|
142
|
+
n = 2 if self.batch else 1
|
|
143
|
+
if self._layer_accumulator[0] is not None:
|
|
144
|
+
result = torch.zeros(
|
|
145
|
+
(
|
|
146
|
+
list(self._layer_accumulator[0].shape[:n])
|
|
147
|
+
+ list(max([[v.stop for v in patch] for patch in self.patch_slices]))
|
|
148
|
+
),
|
|
149
|
+
dtype=self._layer_accumulator[0].dtype,
|
|
150
|
+
).to(self._layer_accumulator[0].device)
|
|
128
151
|
for patch_slice, data in zip(self.patch_slices, self._layer_accumulator):
|
|
129
|
-
|
|
130
|
-
|
|
131
|
-
|
|
132
|
-
|
|
133
|
-
|
|
134
|
-
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
152
|
+
if data is not None:
|
|
153
|
+
slices_dest = tuple([slice(result.shape[i]) for i in range(n)] + list(patch_slice))
|
|
154
|
+
|
|
155
|
+
for dim, s in enumerate(patch_slice):
|
|
156
|
+
if s.stop - s.start == 1:
|
|
157
|
+
data = data.unsqueeze(dim=dim + n)
|
|
158
|
+
if self.patch_combine is not None:
|
|
159
|
+
result[slices_dest] += self.patch_combine(data)
|
|
160
|
+
else:
|
|
161
|
+
result[slices_dest] = data
|
|
162
|
+
result = result[tuple([slice(None, None)] + [slice(0, s) for s in self.shape])]
|
|
139
163
|
|
|
140
164
|
self._layer_accumulator.clear()
|
|
141
165
|
return result
|
|
142
166
|
|
|
167
|
+
|
|
143
168
|
class Patch(ABC):
|
|
144
169
|
|
|
145
|
-
|
|
170
|
+
@abstractmethod
|
|
171
|
+
def __init__(
|
|
172
|
+
self,
|
|
173
|
+
patch_size: list[int],
|
|
174
|
+
overlap: int | None,
|
|
175
|
+
pad_value: float = 0,
|
|
176
|
+
extend_slice: int = 0,
|
|
177
|
+
) -> None:
|
|
146
178
|
self.patch_size = patch_size
|
|
147
|
-
self.overlap = overlap
|
|
179
|
+
self.overlap = overlap
|
|
148
180
|
if isinstance(self.overlap, int):
|
|
149
|
-
|
|
150
|
-
|
|
151
|
-
self._patch_slices
|
|
181
|
+
if self.overlap < 0:
|
|
182
|
+
self.overlap = None
|
|
183
|
+
self._patch_slices: dict[int, list[tuple[slice, ...]]] = {}
|
|
152
184
|
self._nb_patch_per_dim: dict[int, list[tuple[int, bool]]] = {}
|
|
153
|
-
self.
|
|
185
|
+
self.pad_value = pad_value
|
|
154
186
|
self.extend_slice = extend_slice
|
|
155
|
-
|
|
156
|
-
def load(self, shape : dict[int, list[int]], a: int = 0) -> None:
|
|
157
|
-
self._patch_slices[a], self._nb_patch_per_dim[a] = get_patch_slices_from_shape(self.patch_size, shape, self.overlap)
|
|
158
187
|
|
|
159
|
-
def
|
|
160
|
-
|
|
161
|
-
|
|
188
|
+
def load(self, shape: list[int], a: int = 0) -> None:
|
|
189
|
+
self._patch_slices[a], self._nb_patch_per_dim[a] = get_patch_slices_from_shape(
|
|
190
|
+
self.patch_size, shape, self.overlap
|
|
191
|
+
)
|
|
192
|
+
|
|
162
193
|
@abstractmethod
|
|
163
|
-
def
|
|
194
|
+
def init(self, key: str):
|
|
164
195
|
pass
|
|
165
|
-
|
|
166
|
-
def
|
|
196
|
+
|
|
197
|
+
def get_patch_slices(self, a: int = 0):
|
|
198
|
+
return self._patch_slices[a]
|
|
199
|
+
|
|
200
|
+
def get_data(self, data: torch.Tensor, index: int, a: int, is_input: bool) -> list[torch.Tensor]:
|
|
167
201
|
slices_pre = []
|
|
168
|
-
for
|
|
169
|
-
slices_pre.append(slice(
|
|
170
|
-
extend_slice = self.extend_slice if
|
|
171
|
-
|
|
172
|
-
bottom = extend_slice//2
|
|
173
|
-
top = int(np.ceil(extend_slice/2))
|
|
174
|
-
s = slice(
|
|
175
|
-
|
|
176
|
-
|
|
177
|
-
|
|
202
|
+
for max_value in data.shape[: -len(self.patch_size)]:
|
|
203
|
+
slices_pre.append(slice(max_value))
|
|
204
|
+
extend_slice = self.extend_slice if is_input else 0
|
|
205
|
+
|
|
206
|
+
bottom = extend_slice // 2
|
|
207
|
+
top = int(np.ceil(extend_slice / 2))
|
|
208
|
+
s = slice(
|
|
209
|
+
(
|
|
210
|
+
self._patch_slices[a][index][0].start - bottom
|
|
211
|
+
if self._patch_slices[a][index][0].start - bottom >= 0
|
|
212
|
+
else 0
|
|
213
|
+
),
|
|
214
|
+
(
|
|
215
|
+
self._patch_slices[a][index][0].stop + top
|
|
216
|
+
if self._patch_slices[a][index][0].stop + top <= data.shape[len(slices_pre)]
|
|
217
|
+
else data.shape[len(slices_pre)]
|
|
218
|
+
),
|
|
219
|
+
)
|
|
220
|
+
slices = [s] + list(self._patch_slices[a][index][1:])
|
|
221
|
+
data_sliced = data[slices_pre + slices]
|
|
222
|
+
if data_sliced.shape[len(slices_pre)] < bottom + top + 1:
|
|
178
223
|
pad_bottom = 0
|
|
179
224
|
pad_top = 0
|
|
180
|
-
if self._patch_slices[a][index][0].start-bottom < 0:
|
|
181
|
-
pad_bottom = bottom-self._patch_slices[a][index][0].start
|
|
182
|
-
if self._patch_slices[a][index][0].stop+top > data.shape[len(slices_pre)]:
|
|
183
|
-
pad_top = self._patch_slices[a][index][0].stop+top-data.shape[len(slices_pre)]
|
|
184
|
-
data_sliced = F.pad(
|
|
225
|
+
if self._patch_slices[a][index][0].start - bottom < 0:
|
|
226
|
+
pad_bottom = bottom - self._patch_slices[a][index][0].start
|
|
227
|
+
if self._patch_slices[a][index][0].stop + top > data.shape[len(slices_pre)]:
|
|
228
|
+
pad_top = self._patch_slices[a][index][0].stop + top - data.shape[len(slices_pre)]
|
|
229
|
+
data_sliced = F.pad(
|
|
230
|
+
data_sliced,
|
|
231
|
+
[0 for _ in range((len(slices) - 1) * 2)] + [pad_bottom, pad_top],
|
|
232
|
+
"reflect",
|
|
233
|
+
)
|
|
185
234
|
|
|
186
235
|
padding = []
|
|
187
236
|
for dim_it, _slice in enumerate(reversed(slices)):
|
|
188
|
-
p =
|
|
237
|
+
p = (
|
|
238
|
+
0
|
|
239
|
+
if _slice.start + self.patch_size[-dim_it - 1] <= data.shape[-dim_it - 1]
|
|
240
|
+
else self.patch_size[-dim_it - 1] - (data.shape[-dim_it - 1] - _slice.start)
|
|
241
|
+
)
|
|
189
242
|
padding.append(0)
|
|
190
243
|
padding.append(p)
|
|
191
244
|
|
|
192
|
-
data_sliced = F.pad(
|
|
193
|
-
|
|
194
|
-
|
|
195
|
-
|
|
196
|
-
|
|
245
|
+
data_sliced = F.pad(
|
|
246
|
+
data_sliced,
|
|
247
|
+
tuple(padding),
|
|
248
|
+
"constant",
|
|
249
|
+
(0 if data_sliced.dtype == torch.uint8 and self.pad_value < 0 else self.pad_value),
|
|
250
|
+
)
|
|
197
251
|
|
|
198
|
-
|
|
252
|
+
for d in [i for i, v in enumerate(reversed(self.patch_size)) if v == 1]:
|
|
253
|
+
data_sliced = torch.squeeze(data_sliced, dim=len(data_sliced.shape) - d - 1)
|
|
254
|
+
return (
|
|
255
|
+
torch.cat([data_sliced[:, i, ...] for i in range(data_sliced.shape[1])], dim=0)
|
|
256
|
+
if extend_slice > 0
|
|
257
|
+
else data_sliced
|
|
258
|
+
)
|
|
259
|
+
|
|
260
|
+
def get_size(self, a: int = 0) -> int:
|
|
199
261
|
return len(self._patch_slices[a])
|
|
200
262
|
|
|
263
|
+
|
|
201
264
|
class DatasetPatch(Patch):
|
|
202
265
|
|
|
203
266
|
@config("Patch")
|
|
204
|
-
def __init__(
|
|
205
|
-
|
|
267
|
+
def __init__(
|
|
268
|
+
self,
|
|
269
|
+
patch_size: list[int] = [128, 128, 128],
|
|
270
|
+
overlap: int | None = None,
|
|
271
|
+
pad_value: float = 0,
|
|
272
|
+
extend_slice: int = 0,
|
|
273
|
+
) -> None:
|
|
274
|
+
super().__init__(patch_size, overlap, pad_value, extend_slice)
|
|
275
|
+
|
|
276
|
+
def init(self, key: str = ""):
|
|
277
|
+
pass
|
|
278
|
+
|
|
206
279
|
|
|
207
280
|
class ModelPatch(Patch):
|
|
208
281
|
|
|
209
282
|
@config("Patch")
|
|
210
|
-
def __init__(
|
|
211
|
-
|
|
212
|
-
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
216
|
-
|
|
217
|
-
|
|
218
|
-
|
|
219
|
-
|
|
220
|
-
|
|
283
|
+
def __init__(
|
|
284
|
+
self,
|
|
285
|
+
patch_size: list[int] = [128, 128, 128],
|
|
286
|
+
overlap: int | None = None,
|
|
287
|
+
patch_combine: str | None = None,
|
|
288
|
+
pad_value: float = 0,
|
|
289
|
+
extend_slice: int = 0,
|
|
290
|
+
) -> None:
|
|
291
|
+
super().__init__(patch_size, overlap, pad_value, extend_slice)
|
|
292
|
+
self._patch_combine = patch_combine
|
|
293
|
+
self.patch_combine: PathCombine | None = None
|
|
294
|
+
|
|
295
|
+
def init(self, key: str):
|
|
296
|
+
if self._patch_combine is not None:
|
|
297
|
+
module, name = get_module(self._patch_combine, "konfai.data.patching")
|
|
298
|
+
self.patch_combine = config(key)(getattr(importlib.import_module(module), name))(config=None)
|
|
299
|
+
if self.patch_size is not None and self.overlap is not None:
|
|
300
|
+
if self.patch_combine is not None:
|
|
301
|
+
self.patch_combine.set_patch_config([i for i in self.patch_size if i > 1], self.overlap)
|
|
302
|
+
else:
|
|
303
|
+
self.patch_combine = None
|
|
304
|
+
|
|
305
|
+
def disassemble(self, *data_list: torch.Tensor) -> Iterator[list[torch.Tensor]]:
|
|
306
|
+
for i in range(self.get_size()):
|
|
307
|
+
yield [self.get_data(data, i, 0, True) for data in data_list]
|
|
308
|
+
|
|
309
|
+
|
|
310
|
+
class DatasetManager:
|
|
311
|
+
|
|
312
|
+
def __init__(
|
|
313
|
+
self,
|
|
314
|
+
index: int,
|
|
315
|
+
group_src: str,
|
|
316
|
+
group_dest: str,
|
|
317
|
+
name: str,
|
|
318
|
+
dataset: Dataset,
|
|
319
|
+
patch: DatasetPatch | None,
|
|
320
|
+
transforms: list[Transform],
|
|
321
|
+
data_augmentations_list: list[DataAugmentationsList],
|
|
322
|
+
) -> None:
|
|
221
323
|
self.group_src = group_src
|
|
222
324
|
self.group_dest = group_dest
|
|
223
325
|
self.name = name
|
|
@@ -226,94 +328,110 @@ class DatasetManager():
|
|
|
226
328
|
self.loaded = False
|
|
227
329
|
self.augmentationLoaded = False
|
|
228
330
|
self.cache_attributes: list[Attribute] = []
|
|
229
|
-
_shape, cache_attribute =
|
|
331
|
+
_shape, cache_attribute = self.dataset.get_infos(self.group_src, name)
|
|
230
332
|
self.cache_attributes.append(cache_attribute)
|
|
231
333
|
_shape = list(_shape[1:])
|
|
232
|
-
|
|
233
|
-
self.data : list[torch.Tensor] = list()
|
|
234
334
|
|
|
235
|
-
|
|
236
|
-
|
|
237
|
-
|
|
238
|
-
|
|
335
|
+
self.data: list[torch.Tensor] = []
|
|
336
|
+
|
|
337
|
+
for transform_function in transforms:
|
|
338
|
+
_shape = transform_function.transform_shape(_shape, cache_attribute)
|
|
339
|
+
|
|
340
|
+
self.patch = (
|
|
341
|
+
DatasetPatch(
|
|
342
|
+
patch_size=patch.patch_size,
|
|
343
|
+
overlap=patch.overlap,
|
|
344
|
+
pad_value=patch.pad_value,
|
|
345
|
+
extend_slice=patch.extend_slice,
|
|
346
|
+
)
|
|
347
|
+
if patch
|
|
348
|
+
else DatasetPatch(_shape)
|
|
349
|
+
)
|
|
239
350
|
self.patch.load(_shape, 0)
|
|
240
351
|
self.shape = _shape
|
|
241
|
-
self.
|
|
242
|
-
self.
|
|
352
|
+
self.data_augmentations_list = data_augmentations_list
|
|
353
|
+
self.reset_augmentation()
|
|
243
354
|
self.cache_attributes_bak = copy.deepcopy(self.cache_attributes)
|
|
244
|
-
|
|
245
|
-
def
|
|
355
|
+
|
|
356
|
+
def reset_augmentation(self):
|
|
246
357
|
self.cache_attributes[:] = self.cache_attributes[:1]
|
|
247
358
|
i = 1
|
|
248
|
-
for
|
|
359
|
+
for data_augmentations in self.data_augmentations_list:
|
|
249
360
|
shape = []
|
|
250
361
|
caches_attribute = []
|
|
251
|
-
for _ in range(
|
|
362
|
+
for _ in range(data_augmentations.nb):
|
|
252
363
|
shape.append(self.shape)
|
|
253
364
|
caches_attribute.append(copy.deepcopy(self.cache_attributes[0]))
|
|
254
365
|
|
|
255
|
-
for
|
|
256
|
-
shape =
|
|
366
|
+
for data_augmentation in data_augmentations.data_augmentations:
|
|
367
|
+
shape = data_augmentation.state_init(self.index, shape, caches_attribute)
|
|
257
368
|
for it, s in enumerate(shape):
|
|
258
369
|
self.cache_attributes.append(caches_attribute[it])
|
|
259
370
|
self.patch.load(s, i)
|
|
260
|
-
i+=1
|
|
261
|
-
|
|
262
|
-
def load(
|
|
371
|
+
i += 1
|
|
372
|
+
|
|
373
|
+
def load(
|
|
374
|
+
self,
|
|
375
|
+
pre_transform: list[Transform],
|
|
376
|
+
data_augmentations_list: list[DataAugmentationsList],
|
|
377
|
+
) -> None:
|
|
263
378
|
if not self.loaded:
|
|
264
379
|
self._load(pre_transform)
|
|
265
380
|
if not self.augmentationLoaded:
|
|
266
|
-
self.
|
|
267
|
-
|
|
268
|
-
def _load(self, pre_transform
|
|
381
|
+
self._load_augmentation(data_augmentations_list)
|
|
382
|
+
|
|
383
|
+
def _load(self, pre_transform: list[Transform]):
|
|
269
384
|
self.cache_attributes = copy.deepcopy(self.cache_attributes_bak)
|
|
270
385
|
i = len(pre_transform)
|
|
271
386
|
data = None
|
|
272
|
-
for
|
|
273
|
-
if isinstance(
|
|
274
|
-
if len(
|
|
275
|
-
filename,
|
|
387
|
+
for transform_function in reversed(pre_transform):
|
|
388
|
+
if isinstance(transform_function, Save):
|
|
389
|
+
if len(transform_function.dataset.split(":")) > 1:
|
|
390
|
+
filename, file_format = transform_function.dataset.split(":")
|
|
276
391
|
else:
|
|
277
|
-
filename =
|
|
278
|
-
|
|
279
|
-
dataset = Dataset(filename,
|
|
280
|
-
if dataset.
|
|
281
|
-
data, attrib = dataset.
|
|
392
|
+
filename = transform_function.dataset
|
|
393
|
+
file_format = "mha"
|
|
394
|
+
dataset = Dataset(filename, file_format)
|
|
395
|
+
if dataset.is_dataset_exist(self.group_dest, self.name):
|
|
396
|
+
data, attrib = dataset.read_data(self.group_dest, self.name)
|
|
282
397
|
self.cache_attributes[0].update(attrib)
|
|
283
398
|
break
|
|
284
|
-
i-=1
|
|
285
|
-
|
|
286
|
-
if i==0:
|
|
287
|
-
data, _ = self.dataset.readData(self.group_src, self.name)
|
|
399
|
+
i -= 1
|
|
288
400
|
|
|
401
|
+
if i == 0:
|
|
402
|
+
data, _ = self.dataset.read_data(self.group_src, self.name)
|
|
289
403
|
|
|
290
404
|
data = torch.from_numpy(data)
|
|
291
405
|
|
|
292
406
|
if len(pre_transform):
|
|
293
|
-
for
|
|
294
|
-
data =
|
|
295
|
-
if isinstance(
|
|
296
|
-
if len(
|
|
297
|
-
filename,
|
|
407
|
+
for transform_function in pre_transform[i:]:
|
|
408
|
+
data = transform_function(self.name, data, self.cache_attributes[0])
|
|
409
|
+
if isinstance(transform_function, Save):
|
|
410
|
+
if len(transform_function.dataset.split(":")) > 1:
|
|
411
|
+
filename, file_format = transform_function.dataset.split(":")
|
|
298
412
|
else:
|
|
299
|
-
filename =
|
|
300
|
-
|
|
301
|
-
dataset = Dataset(filename,
|
|
302
|
-
dataset.write(
|
|
303
|
-
|
|
413
|
+
filename = transform_function.dataset
|
|
414
|
+
file_format = "mha"
|
|
415
|
+
dataset = Dataset(filename, file_format)
|
|
416
|
+
dataset.write(
|
|
417
|
+
self.group_dest,
|
|
418
|
+
self.name,
|
|
419
|
+
data.numpy(),
|
|
420
|
+
self.cache_attributes[0],
|
|
421
|
+
)
|
|
304
422
|
self.data.append(data)
|
|
305
|
-
|
|
306
|
-
for i in range(len(self.cache_attributes)-1):
|
|
307
|
-
self.cache_attributes[i+1].update(self.cache_attributes[0])
|
|
423
|
+
|
|
424
|
+
for i in range(len(self.cache_attributes) - 1):
|
|
425
|
+
self.cache_attributes[i + 1].update(self.cache_attributes[0])
|
|
308
426
|
self.loaded = True
|
|
309
|
-
|
|
310
|
-
def
|
|
311
|
-
for
|
|
312
|
-
a_data = [self.data[0].clone() for _ in range(
|
|
313
|
-
for
|
|
314
|
-
if
|
|
315
|
-
a_data =
|
|
316
|
-
|
|
427
|
+
|
|
428
|
+
def _load_augmentation(self, data_augmentations_list: list[DataAugmentationsList]) -> None:
|
|
429
|
+
for data_augmentations in data_augmentations_list:
|
|
430
|
+
a_data = [self.data[0].clone() for _ in range(data_augmentations.nb)]
|
|
431
|
+
for data_augmentation in data_augmentations.data_augmentations:
|
|
432
|
+
if data_augmentation.groups is None or self.group_dest in data_augmentation.groups:
|
|
433
|
+
a_data = data_augmentation(self.name, self.index, a_data)
|
|
434
|
+
|
|
317
435
|
for d in a_data:
|
|
318
436
|
self.data.append(d)
|
|
319
437
|
self.augmentationLoaded = True
|
|
@@ -322,17 +440,16 @@ class DatasetManager():
|
|
|
322
440
|
self.data.clear()
|
|
323
441
|
self.loaded = False
|
|
324
442
|
self.augmentationLoaded = False
|
|
325
|
-
|
|
326
|
-
def
|
|
443
|
+
|
|
444
|
+
def unload_augmentation(self) -> None:
|
|
327
445
|
self.data[:] = self.data[:1]
|
|
328
446
|
self.augmentationLoaded = False
|
|
329
447
|
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
333
|
-
|
|
334
|
-
data = transformFunction(self.name, data, self.cache_attributes[a])
|
|
448
|
+
def get_data(self, index: int, a: int, patch_transforms: list[Transform], is_input: bool) -> torch.Tensor:
|
|
449
|
+
data = self.patch.get_data(self.data[a], index, a, is_input)
|
|
450
|
+
for transform_function in patch_transforms:
|
|
451
|
+
data = transform_function(self.name, data, self.cache_attributes[a])
|
|
335
452
|
return data
|
|
336
453
|
|
|
337
|
-
def
|
|
338
|
-
return self.patch.
|
|
454
|
+
def get_size(self, a: int) -> int:
|
|
455
|
+
return self.patch.get_size(a)
|