konfai-1.1.8-py3-none-any.whl → konfai-1.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of konfai might be problematic (the original page linked to more details).

Files changed (36):
  1. konfai/__init__.py +59 -14
  2. konfai/data/augmentation.py +457 -286
  3. konfai/data/data_manager.py +533 -316
  4. konfai/data/patching.py +300 -183
  5. konfai/data/transform.py +408 -275
  6. konfai/evaluator.py +325 -68
  7. konfai/main.py +71 -22
  8. konfai/metric/measure.py +360 -244
  9. konfai/metric/schedulers.py +24 -13
  10. konfai/models/classification/convNeXt.py +187 -81
  11. konfai/models/classification/resnet.py +272 -58
  12. konfai/models/generation/cStyleGan.py +233 -59
  13. konfai/models/generation/ddpm.py +348 -121
  14. konfai/models/generation/diffusionGan.py +757 -358
  15. konfai/models/generation/gan.py +177 -53
  16. konfai/models/generation/vae.py +140 -40
  17. konfai/models/registration/registration.py +135 -52
  18. konfai/models/representation/representation.py +57 -23
  19. konfai/models/segmentation/NestedUNet.py +339 -68
  20. konfai/models/segmentation/UNet.py +140 -30
  21. konfai/network/blocks.py +331 -187
  22. konfai/network/network.py +795 -427
  23. konfai/predictor.py +644 -238
  24. konfai/trainer.py +509 -222
  25. konfai/utils/ITK.py +191 -106
  26. konfai/utils/config.py +152 -95
  27. konfai/utils/dataset.py +326 -455
  28. konfai/utils/utils.py +497 -249
  29. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/METADATA +1 -3
  30. konfai-1.2.0.dist-info/RECORD +38 -0
  31. konfai/utils/registration.py +0 -199
  32. konfai-1.1.8.dist-info/RECORD +0 -39
  33. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/WHEEL +0 -0
  34. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/entry_points.txt +0 -0
  35. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/licenses/LICENSE +0 -0
  36. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/top_level.txt +0 -0
konfai/data/patching.py CHANGED
@@ -1,41 +1,43 @@
1
+ import copy
2
+ import importlib
3
+ import itertools
1
4
  from abc import ABC, abstractmethod
2
- import SimpleITK as sitk
5
+ from collections.abc import Iterator
6
+ from functools import partial
7
+
3
8
  import numpy as np
9
+ import SimpleITK as sitk # noqa: N813
4
10
  import torch
5
- import os
6
- import torch.nn.functional as F
7
- from typing import Any, Iterator
8
- from typing import Union
9
- import itertools
10
- import copy
11
- from functools import partial
12
- from konfai.utils.config import config
13
- from konfai.utils.utils import get_patch_slices_from_shape
14
- from konfai.utils.dataset import Dataset, Attribute
15
- from konfai.data.transform import Transform, Save
11
+ import torch.nn.functional as F # noqa: N812
12
+
16
13
  from konfai.data.augmentation import DataAugmentationsList
14
+ from konfai.data.transform import Save, Transform
15
+ from konfai.utils.config import config
16
+ from konfai.utils.dataset import Attribute, Dataset
17
+ from konfai.utils.utils import get_module, get_patch_slices_from_shape
18
+
17
19
 
18
20
  class PathCombine(ABC):
19
21
 
20
22
  def __init__(self) -> None:
21
- self.data: torch.Tensor = None
22
- self.overlap: int = None
23
-
23
+ self.data: torch.Tensor
24
+ self.overlap: int
25
+
24
26
  """
25
27
  A = slice(0, overlap)
26
28
  B = slice(-overlap, None)
27
29
  C = slice(overlap, -overlap)
28
-
29
- 1D
30
+
31
+ 1D
30
32
  A+B
31
- 2D :
33
+ 2D :
32
34
  AA+AB+BA+BB
33
-
35
+
34
36
  AC+BC
35
37
  CA+CB
36
- 3D :
38
+ 3D :
37
39
  AAA+AAB+ABA+ABB+BAA+BAB+BBA+BBB
38
-
40
+
39
41
  CAA+CAB+CBA+CBB
40
42
  ACA+ACB+BCA+BCB
41
43
  AAC+ABC+BAC+BBC
@@ -43,181 +45,281 @@ class PathCombine(ABC):
43
45
  CCA+CCB
44
46
  CAC+CBC
45
47
  ACC+BCC
46
-
48
+
47
49
  """
48
- def setPatchConfig(self, patch_size: list[int], overlap: int):
49
- self.data = F.pad(torch.ones([size-overlap*2 for size in patch_size]), [overlap]*2*len(patch_size), mode="constant", value=0)
50
- self.data = self._setFunction(self.data, overlap)
50
+
51
+ def set_patch_config(self, patch_size: list[int], overlap: int):
52
+ self.data = F.pad(
53
+ torch.ones([size - overlap * 2 for size in patch_size]),
54
+ [overlap] * 2 * len(patch_size),
55
+ mode="constant",
56
+ value=0,
57
+ )
58
+ self.data = self._set_function(self.data, overlap)
51
59
  dim = len(patch_size)
52
60
 
53
- A = slice(0, overlap)
54
- B = slice(-overlap, None)
55
- C = slice(overlap, -overlap)
56
-
61
+ a = slice(0, overlap)
62
+ b = slice(-overlap, None)
63
+ c = slice(overlap, -overlap)
64
+
57
65
  for i in range(dim):
58
- slices_badge = list(itertools.product(*[[A, B] for _ in range(dim-i)]))
59
- for indexs in itertools.combinations([0,1,2], i):
66
+ slices_badge = list(itertools.product(*[[a, b] for _ in range(dim - i)]))
67
+ for indexs in itertools.combinations([0, 1, 2], i):
60
68
  result = []
61
- for slices in slices_badge:
62
- slices = list(slices)
69
+ for slices_tuple in slices_badge:
70
+ slices_list = list(slices_tuple)
63
71
  for index in indexs:
64
- slices.insert(index, C)
65
- result.append(tuple(slices))
72
+ slices_list.insert(index, c)
73
+ result.append(tuple(slices_list))
66
74
  for patch, s in zip(PathCombine._normalise([self.data[s] for s in result]), result):
67
75
  self.data[s] = patch
68
76
 
69
-
70
77
  @staticmethod
71
78
  def _normalise(patchs: list[torch.Tensor]) -> list[torch.Tensor]:
72
79
  data_sum = torch.sum(torch.concat([patch.unsqueeze(0) for patch in patchs], dim=0), dim=0)
73
- return [d/data_sum for d in patchs]
74
-
75
- def __call__(self, input: torch.Tensor) -> torch.Tensor:
76
- return self.data.repeat([input.shape[0]]+[1]*(len(input.shape)-1)).to(input.device)*input
77
-
80
+ return [d / data_sum for d in patchs]
81
+
82
+ def __call__(self, tensor: torch.Tensor) -> torch.Tensor:
83
+ return self.data.repeat([tensor.shape[0]] + [1] * (len(tensor.shape) - 1)).to(tensor.device) * tensor
84
+
78
85
  @abstractmethod
79
- def _setFunction(self, data: torch.Tensor, overlap: int) -> torch.Tensor:
86
+ def _set_function(self, data: torch.Tensor, overlap: int) -> torch.Tensor:
80
87
  pass
81
88
 
89
+
82
90
  class Mean(PathCombine):
83
91
 
84
92
  def __init__(self) -> None:
85
93
  super().__init__()
86
94
 
87
- def _setFunction(self, data: torch.Tensor, overlap: int) -> torch.Tensor:
95
+ def _set_function(self, data: torch.Tensor, overlap: int) -> torch.Tensor:
88
96
  return torch.ones_like(self.data)
89
-
97
+
98
+
90
99
  class Cosinus(PathCombine):
91
100
 
92
101
  def __init__(self) -> None:
93
102
  super().__init__()
94
103
 
95
104
  def _function_sides(self, overlap: int, x: float):
96
- return np.clip(np.cos(np.pi/(2*(overlap+1))*x), 0, 1)
105
+ return np.clip(np.cos(np.pi / (2 * (overlap + 1)) * x), 0, 1)
97
106
 
98
- def _setFunction(self, data: torch.Tensor, overlap: int) -> torch.Tensor:
107
+ def _set_function(self, data: torch.Tensor, overlap: int) -> torch.Tensor:
99
108
  image = sitk.GetImageFromArray(np.asarray(data, dtype=np.uint8))
100
- danielssonDistanceMapImageFilter = sitk.DanielssonDistanceMapImageFilter()
101
- distance = torch.tensor(sitk.GetArrayFromImage(danielssonDistanceMapImageFilter.Execute(image)))
109
+ danielsson_distance_map_image_filter = sitk.DanielssonDistanceMapImageFilter()
110
+ distance = torch.tensor(sitk.GetArrayFromImage(danielsson_distance_map_image_filter.Execute(image)))
102
111
  return distance.apply_(partial(self._function_sides, overlap))
103
-
104
- class Accumulator():
105
112
 
106
- def __init__(self, patch_slices: list[tuple[slice]], patch_size: list[int], patchCombine: Union[PathCombine, None] = None, batch: bool = True) -> None:
107
- self._layer_accumulator: list[Union[torch.Tensor, None]] = [None for i in range(len(patch_slices))]
113
+
114
+ class Accumulator:
115
+
116
+ def __init__(
117
+ self,
118
+ patch_slices: list[tuple[slice]],
119
+ patch_size: list[int],
120
+ patch_combine: PathCombine | None = None,
121
+ batch: bool = True,
122
+ ) -> None:
123
+ self._layer_accumulator: list[torch.Tensor | None] = [None] * len(patch_slices)
108
124
  self.patch_slices = []
109
125
  for patch in patch_slices:
110
126
  slices = []
111
127
  for s, shape in zip(patch, patch_size):
112
- slices.append(slice(s.start, s.start+shape))
128
+ slices.append(slice(s.start, s.start + shape))
113
129
  self.patch_slices.append(tuple(slices))
114
130
  self.shape = max([[v.stop for v in patch] for patch in patch_slices])
115
131
  self.patch_size = patch_size
116
- self.patchCombine = patchCombine
132
+ self.patch_combine = patch_combine
117
133
  self.batch = batch
118
134
 
119
- def addLayer(self, index: int, layer: torch.Tensor) -> None:
135
+ def add_layer(self, index: int, layer: torch.Tensor) -> None:
120
136
  self._layer_accumulator[index] = layer
121
-
122
- def isFull(self) -> bool:
137
+
138
+ def is_full(self) -> bool:
123
139
  return len(self.patch_slices) == len([v for v in self._layer_accumulator if v is not None])
124
140
 
125
141
  def assemble(self) -> torch.Tensor:
126
- N = 2 if self.batch else 1
127
- result = torch.zeros((list(self._layer_accumulator[0].shape[:N])+list(max([[v.stop for v in patch] for patch in self.patch_slices]))), dtype=self._layer_accumulator[0].dtype).to(self._layer_accumulator[0].device)
142
+ n = 2 if self.batch else 1
143
+ if self._layer_accumulator[0] is not None:
144
+ result = torch.zeros(
145
+ (
146
+ list(self._layer_accumulator[0].shape[:n])
147
+ + list(max([[v.stop for v in patch] for patch in self.patch_slices]))
148
+ ),
149
+ dtype=self._layer_accumulator[0].dtype,
150
+ ).to(self._layer_accumulator[0].device)
128
151
  for patch_slice, data in zip(self.patch_slices, self._layer_accumulator):
129
- slices_dest = tuple([slice(result.shape[i]) for i in range(N)] + list(patch_slice))
130
-
131
- for dim, s in enumerate(patch_slice):
132
- if s.stop-s.start == 1:
133
- data = data.unsqueeze(dim=dim+N)
134
- if self.patchCombine is not None:
135
- result[slices_dest] += self.patchCombine(data)
136
- else:
137
- result[slices_dest] = data
138
- result = result[tuple([slice(None, None)]+[slice(0, s) for s in self.shape])]
152
+ if data is not None:
153
+ slices_dest = tuple([slice(result.shape[i]) for i in range(n)] + list(patch_slice))
154
+
155
+ for dim, s in enumerate(patch_slice):
156
+ if s.stop - s.start == 1:
157
+ data = data.unsqueeze(dim=dim + n)
158
+ if self.patch_combine is not None:
159
+ result[slices_dest] += self.patch_combine(data)
160
+ else:
161
+ result[slices_dest] = data
162
+ result = result[tuple([slice(None, None)] + [slice(0, s) for s in self.shape])]
139
163
 
140
164
  self._layer_accumulator.clear()
141
165
  return result
142
166
 
167
+
143
168
  class Patch(ABC):
144
169
 
145
- def __init__(self, patch_size: list[int], overlap: Union[int, None], padValue: float = 0, extend_slice: int = 0) -> None:
170
+ @abstractmethod
171
+ def __init__(
172
+ self,
173
+ patch_size: list[int],
174
+ overlap: int | None,
175
+ pad_value: float = 0,
176
+ extend_slice: int = 0,
177
+ ) -> None:
146
178
  self.patch_size = patch_size
147
- self.overlap = overlap
179
+ self.overlap = overlap
148
180
  if isinstance(self.overlap, int):
149
- if self.overlap < 0:
150
- self.overlap = None
151
- self._patch_slices : dict[int, list[tuple[slice]]] = {}
181
+ if self.overlap < 0:
182
+ self.overlap = None
183
+ self._patch_slices: dict[int, list[tuple[slice, ...]]] = {}
152
184
  self._nb_patch_per_dim: dict[int, list[tuple[int, bool]]] = {}
153
- self.padValue = padValue
185
+ self.pad_value = pad_value
154
186
  self.extend_slice = extend_slice
155
-
156
- def load(self, shape : dict[int, list[int]], a: int = 0) -> None:
157
- self._patch_slices[a], self._nb_patch_per_dim[a] = get_patch_slices_from_shape(self.patch_size, shape, self.overlap)
158
187
 
159
- def getPatch_slices(self, a: int = 0):
160
- return self._patch_slices[a]
161
-
188
+ def load(self, shape: list[int], a: int = 0) -> None:
189
+ self._patch_slices[a], self._nb_patch_per_dim[a] = get_patch_slices_from_shape(
190
+ self.patch_size, shape, self.overlap
191
+ )
192
+
162
193
  @abstractmethod
163
- def getData(self, data : torch.Tensor, index : int, a: int, isInput: bool) -> torch.Tensor:
194
+ def init(self, key: str):
164
195
  pass
165
-
166
- def getData(self, data : torch.Tensor, index : int, a: int, isInput: bool) -> list[torch.Tensor]:
196
+
197
+ def get_patch_slices(self, a: int = 0):
198
+ return self._patch_slices[a]
199
+
200
+ def get_data(self, data: torch.Tensor, index: int, a: int, is_input: bool) -> list[torch.Tensor]:
167
201
  slices_pre = []
168
- for max in data.shape[:-len(self.patch_size)]:
169
- slices_pre.append(slice(max))
170
- extend_slice = self.extend_slice if isInput else 0
171
-
172
- bottom = extend_slice//2
173
- top = int(np.ceil(extend_slice/2))
174
- s = slice(self._patch_slices[a][index][0].start-bottom if self._patch_slices[a][index][0].start-bottom >= 0 else 0, self._patch_slices[a][index][0].stop+top if self._patch_slices[a][index][0].stop+top <= data.shape[len(slices_pre)] else data.shape[len(slices_pre)])
175
- slices = [s] + list(self._patch_slices[a][index][1:])
176
- data_sliced = data[slices_pre+slices]
177
- if data_sliced.shape[len(slices_pre)] < bottom+top+1:
202
+ for max_value in data.shape[: -len(self.patch_size)]:
203
+ slices_pre.append(slice(max_value))
204
+ extend_slice = self.extend_slice if is_input else 0
205
+
206
+ bottom = extend_slice // 2
207
+ top = int(np.ceil(extend_slice / 2))
208
+ s = slice(
209
+ (
210
+ self._patch_slices[a][index][0].start - bottom
211
+ if self._patch_slices[a][index][0].start - bottom >= 0
212
+ else 0
213
+ ),
214
+ (
215
+ self._patch_slices[a][index][0].stop + top
216
+ if self._patch_slices[a][index][0].stop + top <= data.shape[len(slices_pre)]
217
+ else data.shape[len(slices_pre)]
218
+ ),
219
+ )
220
+ slices = [s] + list(self._patch_slices[a][index][1:])
221
+ data_sliced = data[slices_pre + slices]
222
+ if data_sliced.shape[len(slices_pre)] < bottom + top + 1:
178
223
  pad_bottom = 0
179
224
  pad_top = 0
180
- if self._patch_slices[a][index][0].start-bottom < 0:
181
- pad_bottom = bottom-self._patch_slices[a][index][0].start
182
- if self._patch_slices[a][index][0].stop+top > data.shape[len(slices_pre)]:
183
- pad_top = self._patch_slices[a][index][0].stop+top-data.shape[len(slices_pre)]
184
- data_sliced = F.pad(data_sliced, [0 for _ in range((len(slices)-1)*2)]+[pad_bottom, pad_top], 'reflect')
225
+ if self._patch_slices[a][index][0].start - bottom < 0:
226
+ pad_bottom = bottom - self._patch_slices[a][index][0].start
227
+ if self._patch_slices[a][index][0].stop + top > data.shape[len(slices_pre)]:
228
+ pad_top = self._patch_slices[a][index][0].stop + top - data.shape[len(slices_pre)]
229
+ data_sliced = F.pad(
230
+ data_sliced,
231
+ [0 for _ in range((len(slices) - 1) * 2)] + [pad_bottom, pad_top],
232
+ "reflect",
233
+ )
185
234
 
186
235
  padding = []
187
236
  for dim_it, _slice in enumerate(reversed(slices)):
188
- p = 0 if _slice.start+self.patch_size[-dim_it-1] <= data.shape[-dim_it-1] else self.patch_size[-dim_it-1]-(data.shape[-dim_it-1]-_slice.start)
237
+ p = (
238
+ 0
239
+ if _slice.start + self.patch_size[-dim_it - 1] <= data.shape[-dim_it - 1]
240
+ else self.patch_size[-dim_it - 1] - (data.shape[-dim_it - 1] - _slice.start)
241
+ )
189
242
  padding.append(0)
190
243
  padding.append(p)
191
244
 
192
- data_sliced = F.pad(data_sliced, tuple(padding), "constant", 0 if data_sliced.dtype == torch.uint8 and self.padValue < 0 else self.padValue)
193
-
194
- for d in [i for i, v in enumerate(reversed(self.patch_size)) if v == 1]:
195
- data_sliced = torch.squeeze(data_sliced, dim = len(data_sliced.shape)-d-1)
196
- return torch.cat([data_sliced[:, i, ...] for i in range(data_sliced.shape[1])], dim=0) if extend_slice > 0 else data_sliced
245
+ data_sliced = F.pad(
246
+ data_sliced,
247
+ tuple(padding),
248
+ "constant",
249
+ (0 if data_sliced.dtype == torch.uint8 and self.pad_value < 0 else self.pad_value),
250
+ )
197
251
 
198
- def getSize(self, a: int = 0) -> int:
252
+ for d in [i for i, v in enumerate(reversed(self.patch_size)) if v == 1]:
253
+ data_sliced = torch.squeeze(data_sliced, dim=len(data_sliced.shape) - d - 1)
254
+ return (
255
+ torch.cat([data_sliced[:, i, ...] for i in range(data_sliced.shape[1])], dim=0)
256
+ if extend_slice > 0
257
+ else data_sliced
258
+ )
259
+
260
+ def get_size(self, a: int = 0) -> int:
199
261
  return len(self._patch_slices[a])
200
262
 
263
+
201
264
  class DatasetPatch(Patch):
202
265
 
203
266
  @config("Patch")
204
- def __init__(self, patch_size : list[int] = [128, 128, 128], overlap : Union[int, None] = None, padValue: float = 0, extend_slice: int = 0) -> None:
205
- super().__init__(patch_size, overlap, padValue, extend_slice)
267
+ def __init__(
268
+ self,
269
+ patch_size: list[int] = [128, 128, 128],
270
+ overlap: int | None = None,
271
+ pad_value: float = 0,
272
+ extend_slice: int = 0,
273
+ ) -> None:
274
+ super().__init__(patch_size, overlap, pad_value, extend_slice)
275
+
276
+ def init(self, key: str = ""):
277
+ pass
278
+
206
279
 
207
280
  class ModelPatch(Patch):
208
281
 
209
282
  @config("Patch")
210
- def __init__(self, patch_size : list[int] = [128, 128, 128], overlap : Union[int, None] = None, patchCombine: Union[str, None] = None, padValue: float = 0, extend_slice: int = 0) -> None:
211
- super().__init__(patch_size, overlap, padValue, extend_slice)
212
- self.patchCombine = patchCombine
213
-
214
- def disassemble(self, *dataList: torch.Tensor) -> Iterator[list[torch.Tensor]]:
215
- for i in range(self.getSize()):
216
- yield [self.getData(data, i, 0, True) for data in dataList]
217
-
218
- class DatasetManager():
219
-
220
- def __init__(self, index: int, group_src: str, group_dest : str, name: str, dataset : Dataset, patch : Union[DatasetPatch, None], pre_transforms : list[Transform], dataAugmentationsList : list[DataAugmentationsList]) -> None:
283
+ def __init__(
284
+ self,
285
+ patch_size: list[int] = [128, 128, 128],
286
+ overlap: int | None = None,
287
+ patch_combine: str | None = None,
288
+ pad_value: float = 0,
289
+ extend_slice: int = 0,
290
+ ) -> None:
291
+ super().__init__(patch_size, overlap, pad_value, extend_slice)
292
+ self._patch_combine = patch_combine
293
+ self.patch_combine: PathCombine | None = None
294
+
295
+ def init(self, key: str):
296
+ if self._patch_combine is not None:
297
+ module, name = get_module(self._patch_combine, "konfai.data.patching")
298
+ self.patch_combine = config(key)(getattr(importlib.import_module(module), name))(config=None)
299
+ if self.patch_size is not None and self.overlap is not None:
300
+ if self.patch_combine is not None:
301
+ self.patch_combine.set_patch_config([i for i in self.patch_size if i > 1], self.overlap)
302
+ else:
303
+ self.patch_combine = None
304
+
305
+ def disassemble(self, *data_list: torch.Tensor) -> Iterator[list[torch.Tensor]]:
306
+ for i in range(self.get_size()):
307
+ yield [self.get_data(data, i, 0, True) for data in data_list]
308
+
309
+
310
+ class DatasetManager:
311
+
312
+ def __init__(
313
+ self,
314
+ index: int,
315
+ group_src: str,
316
+ group_dest: str,
317
+ name: str,
318
+ dataset: Dataset,
319
+ patch: DatasetPatch | None,
320
+ transforms: list[Transform],
321
+ data_augmentations_list: list[DataAugmentationsList],
322
+ ) -> None:
221
323
  self.group_src = group_src
222
324
  self.group_dest = group_dest
223
325
  self.name = name
@@ -226,94 +328,110 @@ class DatasetManager():
226
328
  self.loaded = False
227
329
  self.augmentationLoaded = False
228
330
  self.cache_attributes: list[Attribute] = []
229
- _shape, cache_attribute = self.dataset.getInfos(self.group_src, name)
331
+ _shape, cache_attribute = self.dataset.get_infos(self.group_src, name)
230
332
  self.cache_attributes.append(cache_attribute)
231
333
  _shape = list(_shape[1:])
232
-
233
- self.data : list[torch.Tensor] = list()
234
334
 
235
- for transformFunction in pre_transforms:
236
- _shape = transformFunction.transformShape(_shape, cache_attribute)
237
-
238
- self.patch = DatasetPatch(patch_size=patch.patch_size, overlap=patch.overlap, padValue=patch.padValue, extend_slice=patch.extend_slice) if patch else DatasetPatch(_shape)
335
+ self.data: list[torch.Tensor] = []
336
+
337
+ for transform_function in transforms:
338
+ _shape = transform_function.transform_shape(_shape, cache_attribute)
339
+
340
+ self.patch = (
341
+ DatasetPatch(
342
+ patch_size=patch.patch_size,
343
+ overlap=patch.overlap,
344
+ pad_value=patch.pad_value,
345
+ extend_slice=patch.extend_slice,
346
+ )
347
+ if patch
348
+ else DatasetPatch(_shape)
349
+ )
239
350
  self.patch.load(_shape, 0)
240
351
  self.shape = _shape
241
- self.dataAugmentationsList = dataAugmentationsList
242
- self.resetAugmentation()
352
+ self.data_augmentations_list = data_augmentations_list
353
+ self.reset_augmentation()
243
354
  self.cache_attributes_bak = copy.deepcopy(self.cache_attributes)
244
-
245
- def resetAugmentation(self):
355
+
356
+ def reset_augmentation(self):
246
357
  self.cache_attributes[:] = self.cache_attributes[:1]
247
358
  i = 1
248
- for dataAugmentations in self.dataAugmentationsList:
359
+ for data_augmentations in self.data_augmentations_list:
249
360
  shape = []
250
361
  caches_attribute = []
251
- for _ in range(dataAugmentations.nb):
362
+ for _ in range(data_augmentations.nb):
252
363
  shape.append(self.shape)
253
364
  caches_attribute.append(copy.deepcopy(self.cache_attributes[0]))
254
365
 
255
- for dataAugmentation in dataAugmentations.dataAugmentations:
256
- shape = dataAugmentation.state_init(self.index, shape, caches_attribute)
366
+ for data_augmentation in data_augmentations.data_augmentations:
367
+ shape = data_augmentation.state_init(self.index, shape, caches_attribute)
257
368
  for it, s in enumerate(shape):
258
369
  self.cache_attributes.append(caches_attribute[it])
259
370
  self.patch.load(s, i)
260
- i+=1
261
-
262
- def load(self, pre_transform : list[Transform], dataAugmentationsList : list[DataAugmentationsList], device: torch.device) -> None:
371
+ i += 1
372
+
373
+ def load(
374
+ self,
375
+ pre_transform: list[Transform],
376
+ data_augmentations_list: list[DataAugmentationsList],
377
+ ) -> None:
263
378
  if not self.loaded:
264
379
  self._load(pre_transform)
265
380
  if not self.augmentationLoaded:
266
- self._loadAugmentation(dataAugmentationsList, device)
267
-
268
- def _load(self, pre_transform : list[Transform]):
381
+ self._load_augmentation(data_augmentations_list)
382
+
383
+ def _load(self, pre_transform: list[Transform]):
269
384
  self.cache_attributes = copy.deepcopy(self.cache_attributes_bak)
270
385
  i = len(pre_transform)
271
386
  data = None
272
- for transformFunction in reversed(pre_transform):
273
- if isinstance(transformFunction, Save):
274
- if len(transformFunction.dataset.split(":")) > 1:
275
- filename, format = transformFunction.dataset.split(":")
387
+ for transform_function in reversed(pre_transform):
388
+ if isinstance(transform_function, Save):
389
+ if len(transform_function.dataset.split(":")) > 1:
390
+ filename, file_format = transform_function.dataset.split(":")
276
391
  else:
277
- filename = transformFunction.dataset.split(":")
278
- format = "mha"
279
- dataset = Dataset(filename, format)
280
- if dataset.isDatasetExist(self.group_dest, self.name):
281
- data, attrib = dataset.readData(self.group_dest, self.name)
392
+ filename = transform_function.dataset
393
+ file_format = "mha"
394
+ dataset = Dataset(filename, file_format)
395
+ if dataset.is_dataset_exist(self.group_dest, self.name):
396
+ data, attrib = dataset.read_data(self.group_dest, self.name)
282
397
  self.cache_attributes[0].update(attrib)
283
398
  break
284
- i-=1
285
-
286
- if i==0:
287
- data, _ = self.dataset.readData(self.group_src, self.name)
399
+ i -= 1
288
400
 
401
+ if i == 0:
402
+ data, _ = self.dataset.read_data(self.group_src, self.name)
289
403
 
290
404
  data = torch.from_numpy(data)
291
405
 
292
406
  if len(pre_transform):
293
- for transformFunction in pre_transform[i:]:
294
- data = transformFunction(self.name, data, self.cache_attributes[0])
295
- if isinstance(transformFunction, Save):
296
- if len(transformFunction.dataset.split(":")) > 1:
297
- filename, format = transformFunction.dataset.split(":")
407
+ for transform_function in pre_transform[i:]:
408
+ data = transform_function(self.name, data, self.cache_attributes[0])
409
+ if isinstance(transform_function, Save):
410
+ if len(transform_function.dataset.split(":")) > 1:
411
+ filename, file_format = transform_function.dataset.split(":")
298
412
  else:
299
- filename = transformFunction.dataset.split(":")
300
- format = "mha"
301
- dataset = Dataset(filename, format)
302
- dataset.write(self.group_dest, self.name, data.numpy(), self.cache_attributes[0])
303
- self.data : list[torch.Tensor] = list()
413
+ filename = transform_function.dataset
414
+ file_format = "mha"
415
+ dataset = Dataset(filename, file_format)
416
+ dataset.write(
417
+ self.group_dest,
418
+ self.name,
419
+ data.numpy(),
420
+ self.cache_attributes[0],
421
+ )
304
422
  self.data.append(data)
305
-
306
- for i in range(len(self.cache_attributes)-1):
307
- self.cache_attributes[i+1].update(self.cache_attributes[0])
423
+
424
+ for i in range(len(self.cache_attributes) - 1):
425
+ self.cache_attributes[i + 1].update(self.cache_attributes[0])
308
426
  self.loaded = True
309
-
310
- def _loadAugmentation(self, dataAugmentationsList : list[DataAugmentationsList], device: torch.device) -> None:
311
- for dataAugmentations in dataAugmentationsList:
312
- a_data = [self.data[0].clone() for _ in range(dataAugmentations.nb)]
313
- for dataAugmentation in dataAugmentations.dataAugmentations:
314
- if dataAugmentation.groups is None or self.group_dest in dataAugmentation.groups:
315
- a_data = dataAugmentation(self.name, self.index, a_data, device)
316
-
427
+
428
+ def _load_augmentation(self, data_augmentations_list: list[DataAugmentationsList]) -> None:
429
+ for data_augmentations in data_augmentations_list:
430
+ a_data = [self.data[0].clone() for _ in range(data_augmentations.nb)]
431
+ for data_augmentation in data_augmentations.data_augmentations:
432
+ if data_augmentation.groups is None or self.group_dest in data_augmentation.groups:
433
+ a_data = data_augmentation(self.name, self.index, a_data)
434
+
317
435
  for d in a_data:
318
436
  self.data.append(d)
319
437
  self.augmentationLoaded = True
@@ -322,17 +440,16 @@ class DatasetManager():
322
440
  self.data.clear()
323
441
  self.loaded = False
324
442
  self.augmentationLoaded = False
325
-
326
- def unloadAugmentation(self) -> None:
443
+
444
+ def unload_augmentation(self) -> None:
327
445
  self.data[:] = self.data[:1]
328
446
  self.augmentationLoaded = False
329
447
 
330
-
331
- def getData(self, index : int, a : int, post_transforms : list[Transform], isInput: bool) -> torch.Tensor:
332
- data = self.patch.getData(self.data[a], index, a, isInput)
333
- for transformFunction in post_transforms:
334
- data = transformFunction(self.name, data, self.cache_attributes[a])
448
+ def get_data(self, index: int, a: int, patch_transforms: list[Transform], is_input: bool) -> torch.Tensor:
449
+ data = self.patch.get_data(self.data[a], index, a, is_input)
450
+ for transform_function in patch_transforms:
451
+ data = transform_function(self.name, data, self.cache_attributes[a])
335
452
  return data
336
453
 
337
- def getSize(self, a: int) -> int:
338
- return self.patch.getSize(a)
454
+ def get_size(self, a: int) -> int:
455
+ return self.patch.get_size(a)