konfai-1.1.1-py3-none-any.whl → konfai-1.1.3-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of konfai might be problematic. Click here for more details.

konfai/__init__.py CHANGED
@@ -12,5 +12,5 @@ CONFIG_FILE = lambda : os.environ["KONFAI_CONFIG_FILE"]
12
12
  KONFAI_STATE = lambda : os.environ["KONFAI_STATE"]
13
13
  KONFAI_ROOT = lambda : os.environ["KONFAI_ROOT"]
14
14
  CUDA_VISIBLE_DEVICES = lambda : os.environ["CUDA_VISIBLE_DEVICES"]
15
-
15
+ KONFAI_NB_CORES = lambda : os.environ["KONFAI_NB_CORES"]
16
16
  DATE = lambda : datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
konfai/data/augmentation.py CHANGED
@@ -9,8 +9,7 @@ import os
9
9
  from konfai import KONFAI_ROOT
10
10
  from konfai.utils.config import config
11
11
  from konfai.utils.utils import _getModule, AugmentationError
12
- from konfai.utils.dataset import Attribute, data_to_image
13
-
12
+ from konfai.utils.dataset import Attribute, data_to_image, Dataset
14
13
 
15
14
  def _translate2DMatrix(t: torch.Tensor) -> torch.Tensor:
16
15
  return torch.cat((torch.cat((torch.eye(2), torch.tensor([[t[0]], [t[1]]])), dim=1), torch.Tensor([[0,0,1]])), dim=0)
@@ -56,23 +55,29 @@ class DataAugmentationsList():
56
55
  self.dataAugmentations : list[DataAugmentation] = []
57
56
  self.dataAugmentationsLoader = dataAugmentations
58
57
 
59
- def load(self, key: str):
58
+ def load(self, key: str, datasets: list[Dataset]):
60
59
  for augmentation, prob in self.dataAugmentationsLoader.items():
61
60
  module, name = _getModule(augmentation, "data.augmentation")
62
- dataAugmentation: DataAugmentation = getattr(importlib.import_module(module), name)(config = None, DL_args="{}.Dataset.augmentations.{}.dataAugmentations".format(KONFAI_ROOT(), key))
61
+ dataAugmentation: DataAugmentation = config("{}.Dataset.augmentations.{}.dataAugmentations.{}".format(KONFAI_ROOT(), key, augmentation))(getattr(importlib.import_module(module), name))(config = None)
63
62
  dataAugmentation.load(prob.prob)
63
+ dataAugmentation.setDatasets(datasets)
64
64
  self.dataAugmentations.append(dataAugmentation)
65
-
65
+
66
66
  class DataAugmentation(ABC):
67
67
 
68
- def __init__(self) -> None:
68
+ def __init__(self, groups: Union[list[str], None] = None) -> None:
69
69
  self.who_index: dict[int, list[int]] = {}
70
70
  self.shape_index: dict[int, list[list[int]]] = {}
71
71
  self._prob: float = 0
72
+ self.groups = groups
73
+ self.datasets : list[Dataset] = []
72
74
 
73
75
  def load(self, prob: float):
74
76
  self._prob = prob
75
77
 
78
+ def setDatasets(self, datasets: list[Dataset]):
79
+ self.datasets = datasets
80
+
76
81
  def state_init(self, index: Union[None, int], shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
77
82
  if index is not None:
78
83
  if index not in self.who_index:
@@ -93,14 +98,14 @@ class DataAugmentation(ABC):
93
98
  def _state_init(self, index : int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
94
99
  pass
95
100
 
96
- def __call__(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
101
+ def __call__(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
97
102
  if len(self.who_index[index]) > 0:
98
- for i, result in enumerate(self._compute(index, [inputs[i] for i in self.who_index[index]], device)):
103
+ for i, result in enumerate(self._compute(name, index, [inputs[i] for i in self.who_index[index]], device)):
99
104
  inputs[self.who_index[index][i]] = result if device is None else result.cpu()
100
105
  return inputs
101
106
 
102
107
  @abstractmethod
103
- def _compute(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
108
+ def _compute(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
104
109
  pass
105
110
 
106
111
  def inverse(self, index: int, a: int, input : torch.Tensor) -> torch.Tensor:
@@ -118,7 +123,7 @@ class EulerTransform(DataAugmentation):
118
123
  super().__init__()
119
124
  self.matrix: dict[int, list[torch.Tensor]] = {}
120
125
 
121
- def _compute(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
126
+ def _compute(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
122
127
  results = []
123
128
  for input, matrix in zip(inputs, self.matrix[index]):
124
129
  results.append(F.grid_sample(input.unsqueeze(0).type(torch.float32), F.affine_grid(matrix[:, :-1,...], [1]+list(input.shape), align_corners=True).to(input.device), align_corners=True, mode="bilinear", padding_mode="reflection").type(input.dtype).squeeze(0))
@@ -126,7 +131,7 @@ class EulerTransform(DataAugmentation):
126
131
 
127
132
  def _inverse(self, index: int, a: int, input : torch.Tensor) -> torch.Tensor:
128
133
  return F.grid_sample(input.unsqueeze(0).type(torch.float32), F.affine_grid(self.matrix[index][a].inverse()[:, :-1,...], [1]+list(input.shape), align_corners=True).to(input.device), align_corners=True, mode="bilinear", padding_mode="reflection").type(input.dtype).squeeze(0)
129
-
134
+
130
135
  class Translate(EulerTransform):
131
136
 
132
137
  @config("Translate")
@@ -194,7 +199,7 @@ class Flip(DataAugmentation):
194
199
  self.flip[index] = [dims[mask].tolist() for mask in prob]
195
200
  return shapes
196
201
 
197
- def _compute(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
202
+ def _compute(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
198
203
  results = []
199
204
  for input, flip in zip(inputs, self.flip[index]):
200
205
  results.append(torch.flip(input, dims=flip))
@@ -207,11 +212,11 @@ class Flip(DataAugmentation):
207
212
  class ColorTransform(DataAugmentation):
208
213
 
209
214
  @config("ColorTransform")
210
- def __init__(self) -> None:
211
- super().__init__()
215
+ def __init__(self, groups: Union[list[str], None] = None) -> None:
216
+ super().__init__(groups)
212
217
  self.matrix: dict[int, list[torch.Tensor]] = {}
213
218
 
214
- def _compute(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
219
+ def _compute(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
215
220
  results = []
216
221
  for input, matrix in zip(inputs, self.matrix[index]):
217
222
  result = input.reshape([*input.shape[:1], int(np.prod(input.shape[1:]))])
@@ -232,8 +237,8 @@ class ColorTransform(DataAugmentation):
232
237
  class Brightness(ColorTransform):
233
238
 
234
239
  @config("Brightness")
235
- def __init__(self, b_std: float) -> None:
236
- super().__init__()
240
+ def __init__(self, b_std: float, groups: Union[list[str], None] = None) -> None:
241
+ super().__init__(groups)
237
242
  self.b_std = b_std
238
243
 
239
244
  def _state_init(self, index : int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
@@ -244,8 +249,8 @@ class Brightness(ColorTransform):
244
249
  class Contrast(ColorTransform):
245
250
 
246
251
  @config("Contrast")
247
- def __init__(self, c_std: float) -> None:
248
- super().__init__()
252
+ def __init__(self, c_std: float, groups: Union[list[str], None] = None) -> None:
253
+ super().__init__(groups)
249
254
  self.c_std = c_std
250
255
 
251
256
  def _state_init(self, index : int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
@@ -256,8 +261,8 @@ class Contrast(ColorTransform):
256
261
  class LumaFlip(ColorTransform):
257
262
 
258
263
  @config("LumaFlip")
259
- def __init__(self) -> None:
260
- super().__init__()
264
+ def __init__(self, groups: Union[list[str], None] = None) -> None:
265
+ super().__init__(groups)
261
266
  self.v = torch.tensor([1, 1, 1, 0])/torch.sqrt(torch.tensor(3))
262
267
 
263
268
  def _state_init(self, index : int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
@@ -268,8 +273,8 @@ class LumaFlip(ColorTransform):
268
273
  class HUE(ColorTransform):
269
274
 
270
275
  @config("HUE")
271
- def __init__(self, hue_max: float) -> None:
272
- super().__init__()
276
+ def __init__(self, hue_max: float, groups: Union[list[str], None] = None) -> None:
277
+ super().__init__(groups)
273
278
  self.hue_max = hue_max
274
279
  self.v = torch.tensor([1, 1, 1])/torch.sqrt(torch.tensor(3))
275
280
 
@@ -281,8 +286,8 @@ class HUE(ColorTransform):
281
286
  class Saturation(ColorTransform):
282
287
 
283
288
  @config("Saturation")
284
- def __init__(self, s_std: float) -> None:
285
- super().__init__()
289
+ def __init__(self, s_std: float, groups: Union[list[str], None] = None) -> None:
290
+ super().__init__(groups)
286
291
  self.s_std = s_std
287
292
  self.v = torch.tensor([1, 1, 1, 0])/torch.sqrt(torch.tensor(3))
288
293
 
@@ -368,8 +373,8 @@ class Saturation(ColorTransform):
368
373
  class Noise(DataAugmentation):
369
374
 
370
375
  @config("Noise")
371
- def __init__(self, n_std: float, noise_step: int=1000, beta_start: float = 1e-4, beta_end: float = 0.02) -> None:
372
- super().__init__()
376
+ def __init__(self, n_std: float, noise_step: int=1000, beta_start: float = 1e-4, beta_end: float = 0.02, groups: Union[list[str], None] = None) -> None:
377
+ super().__init__(groups)
373
378
  self.n_std = n_std
374
379
  self.noise_step = noise_step
375
380
 
@@ -410,7 +415,7 @@ class Noise(DataAugmentation):
410
415
  self.ts[index] = [torch.randint(0, int(self.max_T), (1,)) for _ in shapes]
411
416
  return shapes
412
417
 
413
- def _compute(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
418
+ def _compute(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
414
419
  results = []
415
420
  for input, t in zip(inputs, self.ts[index]):
416
421
  alpha_hat_t = self.alpha_hat[t].to(input.device).reshape(*[1 for _ in range(len(input.shape))])
@@ -423,8 +428,8 @@ class Noise(DataAugmentation):
423
428
  class CutOUT(DataAugmentation):
424
429
 
425
430
  @config("CutOUT")
426
- def __init__(self, c_prob: float, cutout_size: int, value: float) -> None:
427
- super().__init__()
431
+ def __init__(self, c_prob: float, cutout_size: int, value: float, groups: Union[list[str], None] = None) -> None:
432
+ super().__init__(groups)
428
433
  self.c_prob = c_prob
429
434
  self.cutout_size = cutout_size
430
435
  self.centers: dict[int, list[torch.Tensor]] = {}
@@ -434,7 +439,7 @@ class CutOUT(DataAugmentation):
434
439
  self.centers[index] = [torch.rand((3) if len(shape) == 3 else (2)) for shape in shapes]
435
440
  return shapes
436
441
 
437
- def _compute(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
442
+ def _compute(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
438
443
  results = []
439
444
  for input, center in zip(inputs, self.centers[index]):
440
445
  masks = []
@@ -513,7 +518,7 @@ class Elastix(DataAugmentation):
513
518
  print("Compute in progress : {:.2f} %".format((i+1)/len(shapes)*100))
514
519
  return shapes
515
520
 
516
- def _compute(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
521
+ def _compute(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
517
522
  results = []
518
523
  for input, displacement_field in zip(inputs, self.displacement_fields[index]):
519
524
  results.append(F.grid_sample(input.type(torch.float32).unsqueeze(0), displacement_field.to(input.device), align_corners=True, mode="bilinear", padding_mode="border").type(input.dtype).squeeze(0))
@@ -546,7 +551,7 @@ class Permute(DataAugmentation):
546
551
  shapes[i] = [shapes[i][dim-1] for dim in permute[1:]]
547
552
  return shapes
548
553
 
549
- def _compute(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
554
+ def _compute(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
550
555
  results = []
551
556
  for input, prob in zip(inputs, self.permute[index]):
552
557
  res = input
@@ -563,8 +568,8 @@ class Permute(DataAugmentation):
563
568
  class Mask(DataAugmentation):
564
569
 
565
570
  @config("Mask")
566
- def __init__(self, mask: str, value: float) -> None:
567
- super().__init__()
571
+ def __init__(self, mask: str, value: float, groups: Union[list[str], None] = None) -> None:
572
+ super().__init__(groups)
568
573
  if mask is not None:
569
574
  if os.path.exists(mask):
570
575
  self.mask = torch.tensor(sitk.GetArrayFromImage(sitk.ReadImage(mask)))
@@ -577,7 +582,7 @@ class Mask(DataAugmentation):
577
582
  self.positions[index] = [torch.rand((3) if len(shape) == 3 else (2))*(torch.tensor([max(s1-s2, 0) for s1, s2 in zip(torch.tensor(shape), torch.tensor(self.mask.shape))])) for shape in shapes]
578
583
  return [self.mask.shape for _ in shapes]
579
584
 
580
- def _compute(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
585
+ def _compute(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
581
586
  results = []
582
587
  for input, position in zip(inputs, self.positions[index]):
583
588
  slices = [slice(None, None)]+[slice(int(s1), int(s1)+s2) for s1, s2 in zip(position, self.mask.shape)]
@@ -62,12 +62,25 @@ class GroupTransform:
62
62
  for transform in self.post_transforms:
63
63
  transform.setDevice(device)
64
64
 
65
+ class GroupTransformMetric(GroupTransform):
66
+
67
+ @config()
68
+ def __init__(self, pre_transforms : Union[dict[str, TransformLoader], list[Transform]] = {"default:Normalize:Standardize:Unsqueeze:TensorCast:ResampleIsotropic:ResampleResize": TransformLoader()},
69
+ post_transforms : Union[dict[str, TransformLoader], list[Transform]] = {"default:Normalize:Standardize:Unsqueeze:TensorCast:ResampleIsotropic:ResampleResize": TransformLoader()}):
70
+ super().__init__(pre_transforms, post_transforms)
71
+
65
72
  class Group(dict[str, GroupTransform]):
66
73
 
67
74
  @config()
68
75
  def __init__(self, groups_dest: dict[str, GroupTransform] = {"default:group_dest": GroupTransform()}):
69
76
  super().__init__(groups_dest)
70
77
 
78
+ class GroupMetric(dict[str, GroupTransformMetric]):
79
+
80
+ @config()
81
+ def __init__(self, groups_dest: dict[str, GroupTransformMetric] = {"default:group_dest": GroupTransformMetric()}):
82
+ super().__init__(groups_dest)
83
+
71
84
  class CustomSampler(Sampler[int]):
72
85
 
73
86
  def __init__(self, size: int, shuffle: bool = False) -> None:
@@ -109,32 +122,33 @@ class DatasetIter(data.Dataset):
109
122
  def getDatasetFromIndex(self, group_dest: str, index: int) -> DatasetManager:
110
123
  return self.data[group_dest][index]
111
124
 
112
- def resetAugmentation(self):
113
- if self.inlineAugmentations:
125
+ def resetAugmentation(self, label):
126
+ if self.inlineAugmentations and len(self.dataAugmentationsList) > 0:
114
127
  for index in range(self.nb_dataset):
115
- self._unloadData(index)
116
128
  for group_src in self.groups_src:
117
129
  for group_dest in self.groups_src[group_src]:
130
+ self.data[group_dest][index].unloadAugmentation()
118
131
  self.data[group_dest][index].resetAugmentation()
132
+ self.load(label + " Augmentation")
119
133
 
120
- def load(self):
134
+ def load(self, label: str):
121
135
  if self.use_cache:
122
136
  memory_init = getMemory()
123
137
 
124
- indexs = [index for index in range(self.nb_dataset) if index not in self._index_cache]
138
+ indexs = [index for index in range(self.nb_dataset)]
125
139
  if len(indexs) > 0:
126
140
  memory_lock = threading.Lock()
141
+ desc = lambda : "Caching "+ label +": {} | {} | {}".format(memoryInfo(), memoryForecast(memory_init, 0, self.nb_dataset), cpuInfo())
127
142
  pbar = tqdm.tqdm(
128
143
  total=len(indexs),
129
- desc="Caching : init | {} | {}".format(memoryForecast(memory_init, 0, self.nb_dataset), cpuInfo()),
130
- leave=False,
131
- disable=self.rank != 0 and "KONFAI_CLUSTER" not in os.environ
144
+ desc=desc(),
145
+ leave=False
132
146
  )
133
147
 
134
148
  def process(index):
135
149
  self._loadData(index)
136
150
  with memory_lock:
137
- pbar.set_description("Caching : {} | {} | {}".format(memoryInfo(), memoryForecast(memory_init, index, self.nb_dataset), cpuInfo()))
151
+ pbar.set_description(desc())
138
152
  pbar.update(1)
139
153
  with ThreadPoolExecutor(max_workers=os.cpu_count()//(device_count() if device_count() > 0 else 1)) as executor:
140
154
  futures = [executor.submit(process, index) for index in indexs]
@@ -170,7 +184,7 @@ class DatasetIter(data.Dataset):
170
184
  data = {}
171
185
  x, a, p = self.map[index]
172
186
  if x not in self._index_cache:
173
- if len(self._index_cache) >= self.buffer_size and not self.use_cache:
187
+ if x not in self._index_cache and len(self._index_cache) >= self.buffer_size and not self.use_cache:
174
188
  self._unloadData(self._index_cache[0])
175
189
  self._loadData(x)
176
190
 
@@ -257,10 +271,10 @@ class Data(ABC):
257
271
  groups_src : dict[str, Group],
258
272
  patch : Union[DatasetPatch, None],
259
273
  use_cache : bool,
260
- subset : Union[Subset, dict[str, Subset]],
274
+ subset : Subset,
261
275
  num_workers : int,
262
276
  batch_size : int,
263
- validation: Union[float, str, list[int], list[str]] = 1,
277
+ validation: Union[float, str, list[int], list[str], None] = None,
264
278
  inlineAugmentations: bool = False,
265
279
  dataAugmentationsList: dict[str, DataAugmentationsList]= {}) -> None:
266
280
  self.dataset_filenames = dataset_filenames
@@ -343,7 +357,7 @@ class Data(ABC):
343
357
  append = flag == "a"
344
358
 
345
359
  if format not in SUPPORTED_EXTENSIONS:
346
- raise DatasetManagerError(f"Unsupported file format '{format}'.",
360
+ raise DatasetManagerError(f"Unsupported file format '{format}'.",
347
361
  f"Supported extensions are: {', '.join(SUPPORTED_EXTENSIONS)}")
348
362
 
349
363
  dataset = Dataset(filename, format)
@@ -362,7 +376,7 @@ class Data(ABC):
362
376
  raise DatasetManagerError(
363
377
  f"Group source '{group_src}' not found in any dataset.",
364
378
  f"Dataset filenames provided: {self.dataset_filenames}",
365
- f"Available groups across all datasets: {sorted(list(datasets.keys()))}",
379
+ "Available groups across all datasets: {}".format(["{} {}".format(f, d.getGroup()) for f, d in self.datasets.items()]),
366
380
  f"Please check that an entry in the dataset with the name '{group_src}.{format}' exists."
367
381
  )
368
382
 
@@ -376,34 +390,44 @@ class Data(ABC):
376
390
  )
377
391
 
378
392
  for key, dataAugmentations in self.dataAugmentationsList.items():
379
- dataAugmentations.load(key)
393
+ dataAugmentations.load(key, [self.datasets[filename] for filename, _ in datasets[group_src]])
380
394
 
381
395
  names = set()
382
396
  dataset_name : dict[str, dict[str, list[str]]] = {}
383
397
  dataset_info : dict[str, dict[str, dict[str, Attribute]]] = {}
384
398
  for group in self.groups_src:
399
+ namesByGroup = set()
385
400
  if group not in dataset_name:
386
401
  dataset_name[group] = {}
387
402
  dataset_info[group] = {}
388
403
  for filename, _ in datasets[group]:
389
- names.update(self.datasets[filename].getNames(group))
404
+ namesByGroup.update(self.datasets[filename].getNames(group))
390
405
  dataset_name[group][filename] = self.datasets[filename].getNames(group)
391
406
  dataset_info[group][filename] = {name: self.datasets[filename].getInfos(group, name) for name in dataset_name[group][filename]}
407
+ if len(names) == 0:
408
+ names.update(namesByGroup)
409
+ else:
410
+ names = names.intersection(namesByGroup)
411
+ if len(names) == 0:
412
+ raise DatasetManagerError(
413
+ f"No data was found for groups {list(self.groups_src.keys())}: although each group contains data from a dataset, there are no common dataset names shared across all groups, the intersection is empty."
414
+ )
415
+
392
416
  subset_names = set()
393
-
394
- if isinstance(self.subset, dict):
395
- for filename, subset in self.subset.items():
396
- subset_names.update(subset([dataset_name[group][filename] for group in dataset_name], [dataset_info[group][filename] for group in dataset_name]))
397
- else:
398
- for group in dataset_name:
399
- for filename, append in datasets[group]:
400
- if append:
401
- subset_names.update(self.subset([dataset_name[group][filename]], [dataset_info[group][filename]]))
417
+ for group in dataset_name:
418
+ subset_names_bygroup = set()
419
+ for filename, append in datasets[group]:
420
+ if append:
421
+ subset_names_bygroup.update(self.subset([dataset_name[group][filename]], [dataset_info[group][filename]]))
422
+ else:
423
+ if len(subset_names_bygroup) == 0:
424
+ subset_names_bygroup.update(self.subset([dataset_name[group][filename]], [dataset_info[group][filename]]))
402
425
  else:
403
- if len(subset_names) == 0:
404
- subset_names.update(self.subset([dataset_name[group][filename]], [dataset_info[group][filename]]))
405
- else:
406
- subset_names.intersection(self.subset([dataset_name[group][filename]], [dataset_info[group][filename]]))
426
+ subset_names_bygroup = subset_names_bygroup.intersection(self.subset([dataset_name[group][filename]], [dataset_info[group][filename]]))
427
+ if len(subset_names) == 0:
428
+ subset_names.update(subset_names_bygroup)
429
+ else:
430
+ subset_names = subset_names.intersection(subset_names_bygroup)
407
431
  if len(subset_names) == 0:
408
432
  raise DatasetManagerError("All data entries were excluded by the subset filter.",
409
433
  f"Dataset entries found: {', '.join(names)}",
@@ -424,7 +448,7 @@ class Data(ABC):
424
448
  validate_map = []
425
449
  if isinstance(self.validation, float) or isinstance(self.validation, int):
426
450
  if self.validation <= 0 or self.validation >= 1:
427
- raise DatasetManagerError("validation must be a float between 0 and 1.", f"Received: {self.validation}", "Example: validation = 0.2 # for a 20% validation split")
451
+ raise DatasetManagerError("Validation must be a float between 0 and 1.", f"Received: {self.validation}", "Example: validation = 0.2 # for a 20% validation split")
428
452
 
429
453
  train_map, validate_map = map[:int(math.floor(len(map)*(1-self.validation)))], map[int(math.floor(len(map)*(1-self.validation))):]
430
454
  elif isinstance(self.validation, str):
@@ -527,20 +551,19 @@ class DataPrediction(Data):
527
551
  groups_src : dict[str, Group] = {"default" : Group()},
528
552
  augmentations : Union[dict[str, DataAugmentationsList], None] = {"DataAugmentation_0" : DataAugmentationsList()},
529
553
  patch : Union[DatasetPatch, None] = DatasetPatch(),
530
- use_cache : bool = True,
531
554
  subset : Union[PredictionSubset, dict[str, PredictionSubset]] = PredictionSubset(),
532
555
  num_workers : int = 4,
533
556
  batch_size : int = 1) -> None:
534
557
 
535
- super().__init__(dataset_filenames, groups_src, patch, use_cache, subset, num_workers, batch_size, dataAugmentationsList=augmentations if augmentations else {})
558
+ super().__init__(dataset_filenames, groups_src, patch, False, subset, num_workers, batch_size, dataAugmentationsList=augmentations if augmentations else {})
536
559
 
537
560
  class DataMetric(Data):
538
561
 
539
562
  @config("Dataset")
540
563
  def __init__(self, dataset_filenames : list[str] = ["default:./Dataset"],
541
- groups_src : dict[str, Group] = {"default" : Group()},
564
+ groups_src : dict[str, GroupMetric] = {"default" : GroupMetric()},
542
565
  subset : Union[PredictionSubset, dict[str, PredictionSubset]] = PredictionSubset(),
543
566
  validation: Union[str, None] = None,
544
567
  num_workers : int = 4) -> None:
545
568
 
546
- super().__init__(dataset_filenames=dataset_filenames, groups_src=groups_src, patch=None, use_cache=False, subset=subset, num_workers=num_workers, batch_size=1, validation=1 if validation is None else validation)
569
+ super().__init__(dataset_filenames=dataset_filenames, groups_src=groups_src, patch=None, use_cache=False, subset=subset, num_workers=num_workers, batch_size=1, validation=validation)
konfai/data/patching.py CHANGED
@@ -15,7 +15,6 @@ from konfai.utils.dataset import Dataset, Attribute
15
15
  from konfai.data.transform import Transform, Save
16
16
  from konfai.data.augmentation import DataAugmentationsList
17
17
 
18
-
19
18
  class PathCombine(ABC):
20
19
 
21
20
  def __init__(self) -> None:
@@ -240,6 +239,7 @@ class DatasetManager():
240
239
  self.index = index
241
240
  self.dataset = dataset
242
241
  self.loaded = False
242
+ self.augmentationLoaded = False
243
243
  self.cache_attributes: list[Attribute] = []
244
244
  _shape, cache_attribute = self.dataset.getInfos(self.group_src, name)
245
245
  self.cache_attributes.append(cache_attribute)
@@ -258,6 +258,7 @@ class DatasetManager():
258
258
  self.cache_attributes_bak = copy.deepcopy(self.cache_attributes)
259
259
 
260
260
  def resetAugmentation(self):
261
+ self.cache_attributes[:] = self.cache_attributes[:1]
261
262
  i = 1
262
263
  for dataAugmentations in self.dataAugmentationsList:
263
264
  shape = []
@@ -272,15 +273,24 @@ class DatasetManager():
272
273
  self.cache_attributes.append(caches_attribute[it])
273
274
  self.patch.load(s, i)
274
275
  i+=1
275
-
276
+
276
277
  def load(self, pre_transform : list[Transform], dataAugmentationsList : list[DataAugmentationsList], device: torch.device) -> None:
277
- if self.loaded:
278
- return
278
+ if not self.loaded:
279
+ self._load(pre_transform)
280
+ if not self.augmentationLoaded:
281
+ self._loadAugmentation(dataAugmentationsList, device)
282
+
283
+ def _load(self, pre_transform : list[Transform]):
284
+ self.cache_attributes = copy.deepcopy(self.cache_attributes_bak)
279
285
  i = len(pre_transform)
280
286
  data = None
281
287
  for transformFunction in reversed(pre_transform):
282
288
  if isinstance(transformFunction, Save):
283
- filename, format = transformFunction.save.split(":")
289
+ if len(transformFunction.dataset.split(":")) > 1:
290
+ filename, format = transformFunction.dataset.split(":")
291
+ else:
292
+ filename = transformFunction.dataset.split(":")
293
+ format = "mha"
284
294
  dataset = Dataset(filename, format)
285
295
  if dataset.isDatasetExist(self.group_dest, self.name):
286
296
  data, attrib = dataset.readData(self.group_dest, self.name)
@@ -298,27 +308,41 @@ class DatasetManager():
298
308
  for transformFunction in pre_transform[i:]:
299
309
  data = transformFunction(self.name, data, self.cache_attributes[0])
300
310
  if isinstance(transformFunction, Save):
301
- filename, format = transformFunction.save.split(":")
311
+ if len(transformFunction.dataset.split(":")) > 1:
312
+ filename, format = transformFunction.dataset.split(":")
313
+ else:
314
+ filename = transformFunction.dataset.split(":")
315
+ format = "mha"
302
316
  dataset = Dataset(filename, format)
303
317
  dataset.write(self.group_dest, self.name, data.numpy(), self.cache_attributes[0])
304
318
  self.data : list[torch.Tensor] = list()
305
319
  self.data.append(data)
306
320
 
321
+ for i in range(len(self.cache_attributes)-1):
322
+ self.cache_attributes[i+1].update(self.cache_attributes[0])
323
+ self.loaded = True
324
+
325
+ def _loadAugmentation(self, dataAugmentationsList : list[DataAugmentationsList], device: torch.device) -> None:
307
326
  for dataAugmentations in dataAugmentationsList:
308
- a_data = [data.clone() for _ in range(dataAugmentations.nb)]
327
+ a_data = [self.data[0].clone() for _ in range(dataAugmentations.nb)]
309
328
  for dataAugmentation in dataAugmentations.dataAugmentations:
310
- a_data = dataAugmentation(self.index, a_data, device)
329
+ if dataAugmentation.groups is None or self.group_dest in dataAugmentation.groups:
330
+ a_data = dataAugmentation(self.name, self.index, a_data, device)
311
331
 
312
332
  for d in a_data:
313
333
  self.data.append(d)
314
- self.loaded = True
334
+ self.augmentationLoaded = True
315
335
 
316
336
  def unload(self) -> None:
317
- if hasattr(self, "data"):
318
- del self.data
319
- self.cache_attributes = copy.deepcopy(self.cache_attributes_bak)
337
+ self.data.clear()
320
338
  self.loaded = False
321
-
339
+ self.augmentationLoaded = False
340
+
341
+ def unloadAugmentation(self) -> None:
342
+ self.data[:] = self.data[:1]
343
+ self.augmentationLoaded = False
344
+
345
+
322
346
  def getData(self, index : int, a : int, post_transforms : list[Transform], isInput: bool) -> torch.Tensor:
323
347
  data = self.patch.getData(self.data[a], index, a, isInput)
324
348
  for transformFunction in post_transforms: