konfai 1.1.0__py3-none-any.whl → 1.1.2__py3-none-any.whl

This diff shows the contents of publicly available package versions released to a supported registry. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release.



konfai/__init__.py CHANGED
@@ -12,5 +12,5 @@ CONFIG_FILE = lambda : os.environ["KONFAI_CONFIG_FILE"]
  KONFAI_STATE = lambda : os.environ["KONFAI_STATE"]
  KONFAI_ROOT = lambda : os.environ["KONFAI_ROOT"]
  CUDA_VISIBLE_DEVICES = lambda : os.environ["CUDA_VISIBLE_DEVICES"]
-
+ KONFAI_NB_CORES = lambda : os.environ["KONFAI_NB_CORES"]
  DATE = lambda : datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
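For context, the new `KONFAI_NB_CORES` accessor follows the same pattern as the existing module-level lambdas: the environment variable is read lazily, at call time rather than at import time. A minimal sketch (the value set below is hypothetical):

```python
import os

# Lazy accessor, as in konfai/__init__.py: reading happens at call time,
# so the variable can be set or changed after import.
KONFAI_NB_CORES = lambda: os.environ["KONFAI_NB_CORES"]

os.environ["KONFAI_NB_CORES"] = "8"   # hypothetical value
print(KONFAI_NB_CORES())              # -> "8" (a string; cast with int() if needed)
```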
@@ -8,7 +8,7 @@ from typing import Union
  import os
  from konfai import KONFAI_ROOT
  from konfai.utils.config import config
- from konfai.utils.utils import _getModule
+ from konfai.utils.utils import _getModule, AugmentationError
  from konfai.utils.dataset import Attribute, data_to_image
 
 
@@ -222,7 +222,7 @@ class ColorTransform(DataAugmentation):
  matrix = matrix[:, :3, :].mean(dim=1, keepdims=True).to(input.device)
  result = result.float() * matrix[:, :, :3].sum(dim=2, keepdims=True) + matrix[:, :, 3:]
  else:
- raise ValueError('Image must be RGB (3 channels) or L (1 channel)')
+ raise AugmentationError('Image must be RGB (3 channels) or L (1 channel)')
  results.append(result.reshape(input.shape))
  return results
 
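The change above swaps the generic `ValueError` for a library-specific `AugmentationError`. A standalone sketch of the channel check, with a stand-in exception class since the real one lives in `konfai.utils.utils`:

```python
import torch

class AugmentationError(Exception):
    """Stand-in for konfai.utils.utils.AugmentationError."""

def check_channels(image: torch.Tensor) -> None:
    # Channel-first layout assumed: (C, H, W) with C == 3 (RGB) or C == 1 (L).
    if image.shape[0] not in (1, 3):
        raise AugmentationError('Image must be RGB (3 channels) or L (1 channel)')

check_channels(torch.rand(3, 64, 64))    # OK
# check_channels(torch.rand(2, 64, 64))  # raises AugmentationError
```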
@@ -11,11 +11,12 @@ from typing import Union, Iterator
  from concurrent.futures import ThreadPoolExecutor, as_completed
  import threading
  from torch.cuda import device_count
+ import SimpleITK as sitk
 
  from konfai import KONFAI_STATE, KONFAI_ROOT
  from konfai.data.patching import DatasetPatch, DatasetManager
  from konfai.utils.config import config
- from konfai.utils.utils import memoryInfo, cpuInfo, memoryForecast, getMemory, State
+ from konfai.utils.utils import memoryInfo, cpuInfo, memoryForecast, getMemory, State, SUPPORTED_EXTENSIONS, DatasetManagerError
  from konfai.utils.dataset import Dataset, Attribute
  from konfai.data.transform import TransformLoader, Transform
  from konfai.data.augmentation import DataAugmentationsList
@@ -61,12 +62,25 @@ class GroupTransform:
  for transform in self.post_transforms:
  transform.setDevice(device)
 
+ class GroupTransformMetric(GroupTransform):
+
+ @config()
+ def __init__(self, pre_transforms : Union[dict[str, TransformLoader], list[Transform]] = {"default:Normalize:Standardize:Unsqueeze:TensorCast:ResampleIsotropic:ResampleResize": TransformLoader()},
+ post_transforms : Union[dict[str, TransformLoader], list[Transform]] = {"default:Normalize:Standardize:Unsqueeze:TensorCast:ResampleIsotropic:ResampleResize": TransformLoader()}):
+ super().__init__(pre_transforms, post_transforms)
+
  class Group(dict[str, GroupTransform]):
 
  @config()
  def __init__(self, groups_dest: dict[str, GroupTransform] = {"default:group_dest": GroupTransform()}):
  super().__init__(groups_dest)
 
+ class GroupMetric(dict[str, GroupTransformMetric]):
+
+ @config()
+ def __init__(self, groups_dest: dict[str, GroupTransformMetric] = {"default:group_dest": GroupTransformMetric()}):
+ super().__init__(groups_dest)
+
  class CustomSampler(Sampler[int]):
 
  def __init__(self, size: int, shuffle: bool = False) -> None:
@@ -108,32 +122,33 @@ class DatasetIter(data.Dataset):
  def getDatasetFromIndex(self, group_dest: str, index: int) -> DatasetManager:
  return self.data[group_dest][index]
 
- def resetAugmentation(self):
- if self.inlineAugmentations:
+ def resetAugmentation(self, label):
+ if self.inlineAugmentations and len(self.dataAugmentationsList) > 0:
  for index in range(self.nb_dataset):
- self._unloadData(index)
  for group_src in self.groups_src:
  for group_dest in self.groups_src[group_src]:
+ self.data[group_dest][index].unloadAugmentation()
  self.data[group_dest][index].resetAugmentation()
+ self.load(label + " Augmentation")
 
- def load(self):
+ def load(self, label: str):
  if self.use_cache:
  memory_init = getMemory()
 
- indexs = [index for index in range(self.nb_dataset) if index not in self._index_cache]
+ indexs = [index for index in range(self.nb_dataset)]
  if len(indexs) > 0:
  memory_lock = threading.Lock()
+ desc = lambda : "Caching "+ label +": {} | {} | {}".format(memoryInfo(), memoryForecast(memory_init, 0, self.nb_dataset), cpuInfo())
  pbar = tqdm.tqdm(
  total=len(indexs),
- desc="Caching : init | {} | {}".format(memoryForecast(memory_init, 0, self.nb_dataset), cpuInfo()),
- leave=False,
- disable=self.rank != 0 and "KONFAI_CLUSTER" not in os.environ
+ desc=desc(),
+ leave=False
  )
 
  def process(index):
  self._loadData(index)
  with memory_lock:
- pbar.set_description("Caching : {} | {} | {}".format(memoryInfo(), memoryForecast(memory_init, index, self.nb_dataset), cpuInfo()))
+ pbar.set_description(desc())
  pbar.update(1)
  with ThreadPoolExecutor(max_workers=os.cpu_count()//(device_count() if device_count() > 0 else 1)) as executor:
  futures = [executor.submit(process, index) for index in indexs]
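The caching loop above fans the per-index `_loadData` calls out over a thread pool while a shared tqdm bar is updated under a lock. A self-contained sketch of the same pattern (names are stand-ins, not the konfai API):

```python
import os
import threading
import tqdm
from concurrent.futures import ThreadPoolExecutor, as_completed

def cache_all(n_items: int, load_item) -> None:
    lock = threading.Lock()
    pbar = tqdm.tqdm(total=n_items, desc="Caching", leave=False)

    def process(index: int) -> None:
        load_item(index)                 # stands in for DatasetIter._loadData
        with lock:                       # tqdm updates are not thread-safe
            pbar.update(1)

    with ThreadPoolExecutor(max_workers=os.cpu_count() or 1) as executor:
        futures = [executor.submit(process, i) for i in range(n_items)]
        for future in as_completed(futures):
            future.result()              # re-raise any worker exception
    pbar.close()

cache_all(10, lambda i: None)
```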
@@ -236,6 +251,8 @@ class Subset():
  index = random.sample(index, len(index))
  return set([names_filtred[i] for i in index])
 
+ def __str__(self):
+ return "Subset : " + str(self.subset) + " shuffle : "+ str(self.shuffle) + " filter : "+ str(self.filter)
  class TrainSubset(Subset):
 
  @config()
@@ -254,10 +271,10 @@ class Data(ABC):
  groups_src : dict[str, Group],
  patch : Union[DatasetPatch, None],
  use_cache : bool,
- subset : Union[Subset, dict[str, Subset]],
+ subset : Subset,
  num_workers : int,
  batch_size : int,
- validation: Union[float, str, list[int], list[str]] = 1,
+ validation: Union[float, str, list[int], list[str], None] = None,
  inlineAugmentations: bool = False,
  dataAugmentationsList: dict[str, DataAugmentationsList]= {}) -> None:
  self.dataset_filenames = dataset_filenames
@@ -290,6 +307,13 @@ class Data(ABC):
  map.append((x, y, z))
  return data, map
 
+ def getGroupsDest(self):
+ groupsDest = []
+ for group_src in self.groups_src:
+ for group_dest in self.groups_src[group_src]:
+ groupsDest.append(group_dest)
+ return groupsDest
+
  def _split(map: list[tuple[int, int, int]], world_size: int) -> list[list[tuple[int, int, int]]]:
  if len(map) == 0:
  return [[] for _ in range(world_size)]
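The new `getGroupsDest` helper simply flattens the nested `groups_src` mapping down to its destination-group keys. An equivalent one-liner with hypothetical group names:

```python
# groups_src maps each source group to {destination group: GroupTransform};
# only the destination keys are collected. Names here are hypothetical.
groups_src = {"image": {"image_masked": None, "image_norm": None}}

groups_dest = [dest for src in groups_src for dest in groups_src[src]]
print(groups_dest)  # ['image_masked', 'image_norm']
```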
@@ -313,7 +337,14 @@ class Data(ABC):
 
  def getData(self, world_size: int) -> list[list[DataLoader]]:
  datasets: dict[str, list[(str, bool)]] = {}
+ if self.dataset_filenames is None or len(self.dataset_filenames) == 0:
+ raise DatasetManagerError("No dataset filenames were provided")
  for dataset_filename in self.dataset_filenames:
+ if dataset_filename is None:
+ raise DatasetManagerError("Invalid dataset entry: 'None' received.",
+ "Each dataset must be a valid path string (e.g., './Dataset/', './Dataset/:mha', './Dataset/:a:mha', './Dataset/:i:mha').",
+ "Please check your 'dataset_filenames' list for missing or null entries."
+ )
  if len(dataset_filename.split(":")) == 1:
  filename = dataset_filename
  format = "mha"
@@ -324,9 +355,13 @@ class Data(ABC):
  else:
  filename, flag, format = dataset_filename.split(":")
  append = flag == "a"
-
- dataset = Dataset(filename, format)
+
+ if format not in SUPPORTED_EXTENSIONS:
+ raise DatasetManagerError(f"Unsupported file format '{format}'.",
+ f"Supported extensions are: {', '.join(SUPPORTED_EXTENSIONS)}")
 
+ dataset = Dataset(filename, format)
+
  self.datasets[filename] = dataset
  for group in self.groups_src:
  if dataset.isGroupExist(group):
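The filename handling above implies a small `path[:flag][:format]` grammar, where the flag `'a'` marks an append dataset and the format defaults to `'mha'`. A sketch of the parse; the default flag value `'i'` is an assumption based on the `'./Dataset/:i:mha'` example in the error message:

```python
def parse_dataset_entry(entry: str) -> tuple[str, bool, str]:
    # "path", "path:format", or "path:flag:format"
    parts = entry.split(":")
    if len(parts) == 1:
        filename, flag, fmt = parts[0], "i", "mha"
    elif len(parts) == 2:
        filename, fmt = parts
        flag = "i"                       # assumed default flag
    else:
        filename, flag, fmt = parts
    return filename, flag == "a", fmt    # append=True only for flag 'a'

print(parse_dataset_entry("./Dataset/"))        # ('./Dataset/', False, 'mha')
print(parse_dataset_entry("./Dataset/:a:mha"))  # ('./Dataset/', True, 'mha')
```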
@@ -334,47 +369,88 @@ class Data(ABC):
  datasets[group].append((filename, append))
  else:
  datasets[group] = [(filename, append)]
+ modelHaveInput = False
  for group_src in self.groups_src:
  if group_src not in datasets:
- raise ValueError("[DatasetManager] Error: group source {} not found. Available groups: {}".format(group_src, list(datasets.keys())))
 
+ raise DatasetManagerError(
+ f"Group source '{group_src}' not found in any dataset.",
+ f"Dataset filenames provided: {self.dataset_filenames}",
+ "Available groups across all datasets: {}".format(["{} {}".format(f, d.getGroup()) for f, d in self.datasets.items()]),
+ f"Please check that an entry in the dataset with the name '{group_src}.{format}' exists."
+ )
+
  for group_dest in self.groups_src[group_src]:
  self.groups_src[group_src][group_dest].load(group_src, group_dest, [self.datasets[filename] for filename, _ in datasets[group_src]])
+ modelHaveInput |= self.groups_src[group_src][group_dest].isInput
+
+ if not modelHaveInput:
+ raise DatasetManagerError(
+ "At least one group must be defined with 'isInput: true' to provide input to the network."
+ )
+
  for key, dataAugmentations in self.dataAugmentationsList.items():
  dataAugmentations.load(key)
 
- names : list[list[str]] = []
+ names = set()
  dataset_name : dict[str, dict[str, list[str]]] = {}
  dataset_info : dict[str, dict[str, dict[str, Attribute]]] = {}
  for group in self.groups_src:
+ namesByGroup = set()
  if group not in dataset_name:
  dataset_name[group] = {}
  dataset_info[group] = {}
  for filename, _ in datasets[group]:
- names.append(self.datasets[filename].getNames(group))
+ namesByGroup.update(self.datasets[filename].getNames(group))
  dataset_name[group][filename] = self.datasets[filename].getNames(group)
  dataset_info[group][filename] = {name: self.datasets[filename].getInfos(group, name) for name in dataset_name[group][filename]}
+ if len(names) == 0:
+ names.update(namesByGroup)
+ else:
+ names = names.intersection(namesByGroup)
+ if len(names) == 0:
+ raise DatasetManagerError(
+ f"No data was found for groups {list(self.groups_src.keys())}: although each group contains data from a dataset, there are no common dataset names shared across all groups; the intersection is empty."
+ )
+
  subset_names = set()
- if isinstance(self.subset, dict):
- for filename, subset in self.subset.items():
- subset_names.update(subset([dataset_name[group][filename] for group in dataset_name], [dataset_info[group][filename] for group in dataset_name]))
- else:
- for group in dataset_name:
- for filename, append in datasets[group]:
- if append:
- subset_names.update(self.subset([dataset_name[group][filename]], [dataset_info[group][filename]]))
+ for group in dataset_name:
+ subset_names_bygroup = set()
+ for filename, append in datasets[group]:
+ if append:
+ subset_names_bygroup.update(self.subset([dataset_name[group][filename]], [dataset_info[group][filename]]))
+ else:
+ if len(subset_names_bygroup) == 0:
+ subset_names_bygroup.update(self.subset([dataset_name[group][filename]], [dataset_info[group][filename]]))
  else:
- if len(subset_names) == 0:
- subset_names.update(self.subset([dataset_name[group][filename]], [dataset_info[group][filename]]))
- else:
- subset_names.intersection(self.subset([dataset_name[group][filename]], [dataset_info[group][filename]]))
+ subset_names_bygroup = subset_names_bygroup.intersection(self.subset([dataset_name[group][filename]], [dataset_info[group][filename]]))
+ if len(subset_names) == 0:
+ subset_names.update(subset_names_bygroup)
+ else:
+ subset_names = subset_names.intersection(subset_names_bygroup)
+ if len(subset_names) == 0:
+ raise DatasetManagerError("All data entries were excluded by the subset filter.",
+ f"Dataset entries found: {', '.join(names)}",
+ f"Subset object applied: {self.subset}",
+ f"Subset requested: {', '.join(subset_names)}",
+ "None of the dataset entries matched the given subset.",
+ "Please check your 'subset' configuration — it may be too restrictive or incorrectly formatted.",
+ "Examples of valid subset formats:",
+ "\tsubset: [0, 1] # explicit indices",
+ "\tsubset: 0:10 # slice notation",
+ "\tsubset: ./Validation.txt # external file",
+ "\tsubset: None # to disable filtering"
+ )
+
  data, map = self._getDatasets(list(subset_names), dataset_name)
-
+
  train_map = map
  validate_map = []
- if isinstance(self.validation, float):
- if self.validation < 1.0 and int(math.floor(len(map)*(1-self.validation))) > 0:
- train_map, validate_map = map[:int(math.floor(len(map)*self.validation))], map[int(math.floor(len(map)*self.validation)):]
+ if isinstance(self.validation, float) or isinstance(self.validation, int):
+ if self.validation <= 0 or self.validation >= 1:
+ raise DatasetManagerError("Validation must be a float between 0 and 1.", f"Received: {self.validation}", "Example: validation = 0.2 # for a 20% validation split")
+
+ train_map, validate_map = map[:int(math.floor(len(map)*(1-self.validation)))], map[int(math.floor(len(map)*(1-self.validation))):]
  elif isinstance(self.validation, str):
  if ":" in self.validation:
  index = list(range(int(self.validation.split(":")[0]), int(self.validation.split(":")[1])))
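The reworked selection logic intersects entry names per group before applying the subset filter, so only names present in every group survive. A toy illustration with hypothetical entries:

```python
# Names common to all groups are kept; an entry missing from any group
# drops out of the final set. Entry names here are hypothetical.
names_by_group = {
    "image": {"patient01", "patient02", "patient03"},
    "label": {"patient02", "patient03", "patient04"},
}

common: set = set()
for i, group_names in enumerate(names_by_group.values()):
    common = set(group_names) if i == 0 else common & group_names
print(sorted(common))  # ['patient02', 'patient03']
```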
@@ -389,7 +465,17 @@ class Data(ABC):
  train_map = [m for m in map if m[0] not in index]
  validate_map = [m for m in map if m[0] in index]
  else:
- validate_map = train_map
+ raise DatasetManagerError(
+ f"Invalid string value for 'validation': '{self.validation}'",
+ "Expected one of the following formats:",
+ "\t• A slice string like '0:10'",
+ "\t• A path to a text file listing validation sample names (e.g., './val.txt')",
+ "\t• A float between 0 and 1 (e.g., 0.2)",
+ "\t• A list of sample names or indices",
+ "The provided value is neither a valid slice nor a readable file.",
+ "Please fix your 'validation' setting in the configuration."
+ )
+
  elif isinstance(self.validation, list):
  if len(self.validation) > 0:
  if isinstance(self.validation[0], int):
@@ -399,10 +485,29 @@ class Data(ABC):
  index = [i for i, n in enumerate(subset_names) if n in self.validation]
  train_map = [m for m in map if m[0] not in index]
  validate_map = [m for m in map if m[0] in index]
-
+ else:
+ raise DatasetManagerError(f"Invalid list type for 'validation': elements of type '{type(self.validation[0]).__name__}' are not supported.",
+ "Supported list element types are:",
+ "\t• int → list of indices (e.g., [0, 1, 2])",
+ "\t• str → list of sample names (e.g., ['patient01', 'patient02'])",
+ f"Received list: {self.validation}"
+ )
+ if len(train_map) == 0:
+ raise DatasetManagerError("No data left for training after applying the validation split.",
+ f"Dataset size: {len(map)}",
+ f"Validation setting: {self.validation}",
+ "Please reduce the validation size, increase the dataset, or disable validation."
+ )
+
+ if self.validation is not None and len(validate_map) == 0:
+ raise DatasetManagerError("No data left for validation after applying the validation split.",
+ f"Dataset size: {len(map)}",
+ f"Validation setting: {self.validation}",
+ "Please increase the validation size, increase the dataset, or disable validation."
+ )
  train_maps = Data._split(train_map, world_size)
  validate_maps = Data._split(validate_map, world_size)
-
+
  for i, (train_map, validate_map) in enumerate(zip(train_maps, validate_maps)):
  maps = [train_map]
  if len(validate_map):
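Note the semantics change for float values: `validation` is now the held-out fraction (the diff also flips `DataTrain`'s default from 0.8 to 0.2 below), so the first `floor(len(map) * (1 - validation))` entries go to training. A worked example of that arithmetic:

```python
import math

samples = list(range(10))
validation = 0.2                         # hold out 20% for validation

cut = int(math.floor(len(samples) * (1 - validation)))
train_map, validate_map = samples[:cut], samples[cut:]
print(train_map)     # [0, 1, 2, 3, 4, 5, 6, 7]
print(validate_map)  # [8, 9]
```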
@@ -436,7 +541,7 @@ class DataTrain(Data):
  subset : Union[TrainSubset, dict[str, TrainSubset]] = TrainSubset(),
  num_workers : int = 4,
  batch_size : int = 1,
- validation : Union[float, str, list[int], list[str]] = 0.8) -> None:
+ validation : Union[float, str, list[int], list[str]] = 0.2) -> None:
  super().__init__(dataset_filenames, groups_src, patch, use_cache, subset, num_workers, batch_size, validation, inlineAugmentations, augmentations if augmentations else {})
 
  class DataPrediction(Data):
@@ -445,22 +550,20 @@ class DataPrediction(Data):
  def __init__(self, dataset_filenames : list[str] = ["default:./Dataset"],
  groups_src : dict[str, Group] = {"default" : Group()},
  augmentations : Union[dict[str, DataAugmentationsList], None] = {"DataAugmentation_0" : DataAugmentationsList()},
- inlineAugmentations: bool = False,
  patch : Union[DatasetPatch, None] = DatasetPatch(),
- use_cache : bool = True,
  subset : Union[PredictionSubset, dict[str, PredictionSubset]] = PredictionSubset(),
  num_workers : int = 4,
  batch_size : int = 1) -> None:
 
- super().__init__(dataset_filenames, groups_src, patch, use_cache, subset, num_workers, batch_size, inlineAugmentations=inlineAugmentations, dataAugmentationsList=augmentations if augmentations else {})
+ super().__init__(dataset_filenames, groups_src, patch, False, subset, num_workers, batch_size, dataAugmentationsList=augmentations if augmentations else {})
 
  class DataMetric(Data):
 
  @config("Dataset")
  def __init__(self, dataset_filenames : list[str] = ["default:./Dataset"],
- groups_src : dict[str, Group] = {"default" : Group()},
+ groups_src : dict[str, GroupMetric] = {"default" : GroupMetric()},
  subset : Union[PredictionSubset, dict[str, PredictionSubset]] = PredictionSubset(),
  validation: Union[str, None] = None,
  num_workers : int = 4) -> None:
 
- super().__init__(dataset_filenames=dataset_filenames, groups_src=groups_src, patch=None, use_cache=False, subset=subset, num_workers=num_workers, batch_size=1, validation=1 if validation is None else validation)
+ super().__init__(dataset_filenames=dataset_filenames, groups_src=groups_src, patch=None, use_cache=False, subset=subset, num_workers=num_workers, batch_size=1, validation=validation)
konfai/data/patching.py CHANGED
@@ -15,7 +15,6 @@ from konfai.utils.dataset import Dataset, Attribute
  from konfai.data.transform import Transform, Save
  from konfai.data.augmentation import DataAugmentationsList
 
-
  class PathCombine(ABC):
 
  def __init__(self) -> None:
@@ -150,7 +149,10 @@ class Patch(ABC):
 
  def __init__(self, patch_size: list[int], overlap: Union[int, None], path_mask: Union[str, None] = None, padValue: float = 0, extend_slice: int = 0) -> None:
  self.patch_size = patch_size
- self.overlap = overlap
+ self.overlap = overlap
+ if isinstance(self.overlap, int):
+ if self.overlap < 0:
+ self.overlap = None
  self._patch_slices : dict[int, list[tuple[slice]]] = {}
  self._nb_patch_per_dim: dict[int, list[tuple[int, bool]]] = {}
  self.path_mask = path_mask
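The added guard treats a negative integer `overlap` as disabled (`None`) instead of letting a nonsensical value reach the patch-slicing logic. An equivalent standalone check:

```python
from typing import Union

def normalize_overlap(overlap: Union[int, None]) -> Union[int, None]:
    # Negative overlaps make no sense for patch extraction; map them to None.
    if isinstance(overlap, int) and overlap < 0:
        return None
    return overlap

print(normalize_overlap(-1))  # None
print(normalize_overlap(16))  # 16
```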
@@ -237,6 +239,7 @@ class DatasetManager():
  self.index = index
  self.dataset = dataset
  self.loaded = False
+ self.augmentationLoaded = False
  self.cache_attributes: list[Attribute] = []
  _shape, cache_attribute = self.dataset.getInfos(self.group_src, name)
  self.cache_attributes.append(cache_attribute)
@@ -255,6 +258,7 @@ class DatasetManager():
  self.cache_attributes_bak = copy.deepcopy(self.cache_attributes)
 
  def resetAugmentation(self):
+ self.cache_attributes[:] = self.cache_attributes[:1]
  i = 1
  for dataAugmentations in self.dataAugmentationsList:
  shape = []
@@ -269,15 +273,24 @@ class DatasetManager():
  self.cache_attributes.append(caches_attribute[it])
  self.patch.load(s, i)
  i+=1
-
+
  def load(self, pre_transform : list[Transform], dataAugmentationsList : list[DataAugmentationsList], device: torch.device) -> None:
- if self.loaded:
- return
+ if not self.loaded:
+ self._load(pre_transform)
+ if not self.augmentationLoaded:
+ self._loadAugmentation(dataAugmentationsList, device)
+
+ def _load(self, pre_transform : list[Transform]):
+ self.cache_attributes = copy.deepcopy(self.cache_attributes_bak)
  i = len(pre_transform)
  data = None
  for transformFunction in reversed(pre_transform):
  if isinstance(transformFunction, Save):
- filename, format = transformFunction.save.split(":")
+ if len(transformFunction.dataset.split(":")) > 1:
+ filename, format = transformFunction.dataset.split(":")
+ else:
+ filename = transformFunction.dataset
+ format = "mha"
  dataset = Dataset(filename, format)
  if dataset.isDatasetExist(self.group_dest, self.name):
  data, attrib = dataset.readData(self.group_dest, self.name)
@@ -295,27 +308,40 @@ class DatasetManager():
  for transformFunction in pre_transform[i:]:
  data = transformFunction(self.name, data, self.cache_attributes[0])
  if isinstance(transformFunction, Save):
- filename, format = transformFunction.save.split(":")
+ if len(transformFunction.dataset.split(":")) > 1:
+ filename, format = transformFunction.dataset.split(":")
+ else:
+ filename = transformFunction.dataset
+ format = "mha"
  dataset = Dataset(filename, format)
  dataset.write(self.group_dest, self.name, data.numpy(), self.cache_attributes[0])
  self.data : list[torch.Tensor] = list()
  self.data.append(data)
 
+ for i in range(len(self.cache_attributes)-1):
+ self.cache_attributes[i+1].update(self.cache_attributes[0])
+ self.loaded = True
+
+ def _loadAugmentation(self, dataAugmentationsList : list[DataAugmentationsList], device: torch.device) -> None:
  for dataAugmentations in dataAugmentationsList:
- a_data = [data.clone() for _ in range(dataAugmentations.nb)]
+ a_data = [self.data[0].clone() for _ in range(dataAugmentations.nb)]
  for dataAugmentation in dataAugmentations.dataAugmentations:
  a_data = dataAugmentation(self.index, a_data, device)
 
  for d in a_data:
  self.data.append(d)
- self.loaded = True
+ self.augmentationLoaded = True
 
  def unload(self) -> None:
- if hasattr(self, "data"):
- del self.data
- self.cache_attributes = copy.deepcopy(self.cache_attributes_bak)
+ self.data.clear()
  self.loaded = False
-
+ self.augmentationLoaded = False
+
+ def unloadAugmentation(self) -> None:
+ self.data[:] = self.data[:1]
+ self.augmentationLoaded = False
+
+
  def getData(self, index : int, a : int, post_transforms : list[Transform], isInput: bool) -> torch.Tensor:
  data = self.patch.getData(self.data[a], index, a, isInput)
  for transformFunction in post_transforms:
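`DatasetManager` now tracks base data and augmented copies with separate flags (`loaded` / `augmentationLoaded`), so augmentations can be rebuilt without re-running the pre-transforms. A minimal sketch of that lifecycle (a stand-in class, not the konfai API):

```python
class LazyItem:
    def __init__(self) -> None:
        self.data: list[str] = []
        self.loaded = False
        self.augmentationLoaded = False

    def load(self, n_augmentations: int) -> None:
        if not self.loaded:
            self.data = ["base"]          # stands in for _load(pre_transform)
            self.loaded = True
        if not self.augmentationLoaded:   # stands in for _loadAugmentation(...)
            self.data += [f"aug_{i}" for i in range(n_augmentations)]
            self.augmentationLoaded = True

    def unloadAugmentation(self) -> None:
        self.data[:] = self.data[:1]      # keep only the base tensor
        self.augmentationLoaded = False

item = LazyItem()
item.load(2)
item.unloadAugmentation()
item.load(2)                              # base reused, augmentations rebuilt
print(item.data)  # ['base', 'aug_0', 'aug_1']
```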
konfai/data/transform.py CHANGED
@@ -6,7 +6,7 @@ from abc import ABC, abstractmethod
  import torch.nn.functional as F
  from typing import Any, Union
 
- from konfai.utils.utils import _getModule, NeedDevice, _resample_affine, _affine_matrix
+ from konfai.utils.utils import _getModule, NeedDevice, _resample_affine, _affine_matrix, TransformError
  from konfai.utils.dataset import Dataset, Attribute, data_to_image, image_to_data
  from konfai.utils.config import config
 
@@ -211,39 +211,67 @@ class Resample(Transform, ABC):
  _ = cache_attribute.pop_np_array("Spacing")
  return self._resample(input, [int(size) for size in size_1])
 
- class ResampleIsotropic(Resample):
+ class ResampleToResolution(Resample):
+
+ def __init__(self, spacing : list[Union[float, None]] = [1., 1., 1.]) -> None:
+ self.spacing = torch.tensor([0 if s is None or s < 0 else s for s in spacing])
 
- def __init__(self, spacing : list[float] = [1., 1., 1.]) -> None:
- self.spacing = torch.tensor(spacing, dtype=torch.float64)
-
  def transformShape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
- assert "Spacing" in cache_attribute, "Error no spacing"
- resize_factor = self.spacing/cache_attribute.get_tensor("Spacing").flip(0)
- return [int(x) for x in (torch.tensor(shape) * 1/resize_factor)]
+ if "Spacing" not in cache_attribute:
+ raise TransformError("Missing 'Spacing' in cache attributes, the data is likely not a valid image.",
+ "Make sure your input is an image (e.g., .nii, .mha) with proper metadata.")
+ if len(shape) != len(self.spacing):
+ raise TransformError(f"Shape and spacing dimensions do not match: shape={shape}, spacing={self.spacing}")
+ image_spacing = cache_attribute.get_tensor("Spacing").flip(0)
+ spacing = self.spacing.clone()
+
+ for i, s in enumerate(self.spacing):
+ if s == 0:
+ spacing[i] = image_spacing[i]
+ resize_factor = spacing/cache_attribute.get_tensor("Spacing").flip(0)
+ return [int(x) for x in (torch.tensor(shape) * 1/resize_factor)]
 
  def __call__(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
- assert "Spacing" in cache_attribute, "Error no spacing"
- resize_factor = self.spacing/cache_attribute.get_tensor("Spacing").flip(0)
- cache_attribute["Spacing"] = self.spacing.flip(0)
+ image_spacing = cache_attribute.get_tensor("Spacing").flip(0)
+ spacing = self.spacing.clone()
+ for i, s in enumerate(self.spacing):
+ if s == 0:
+ spacing[i] = image_spacing[i]
+ resize_factor = spacing/cache_attribute.get_tensor("Spacing").flip(0)
+ cache_attribute["Spacing"] = spacing.flip(0)
  cache_attribute["Size"] = np.asarray([int(x) for x in torch.tensor(input.shape[1:])])
  size = [int(x) for x in (torch.tensor(input.shape[1:]) * 1/resize_factor)]
  cache_attribute["Size"] = np.asarray(size)
  return self._resample(input, size)
 
- class ResampleResize(Resample):
+ class ResampleToSize(Resample):
 
  def __init__(self, size : list[int] = [100,512,512]) -> None:
  self.size = size
 
  def transformShape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
- return self.size
+ if "Spacing" not in cache_attribute:
+ raise TransformError("Missing 'Spacing' in cache attributes, the data is likely not a valid image.",
+ "Make sure your input is an image (e.g., .nii, .mha) with proper metadata.")
+ if len(shape) != len(self.size):
+ raise TransformError(f"Shape and size dimensions do not match: shape={shape}, size={self.size}")
+ size = list(self.size)
+ for i, s in enumerate(self.size):
+ if s == -1:
+ size[i] = shape[i]
+ return size
 
  def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+ size = list(self.size)
+ image_size = np.asarray([int(x) for x in torch.tensor(input.shape[1:])])
+ for i, s in enumerate(self.size):
+ if s == -1:
+ size[i] = image_size[i]
  if "Spacing" in cache_attribute:
- cache_attribute["Spacing"] = torch.flip(torch.tensor(list(input.shape[1:]))/torch.tensor(self.size)*torch.flip(cache_attribute.get_tensor("Spacing"), dims=[0]), dims=[0])
- cache_attribute["Size"] = np.asarray([int(x) for x in torch.tensor(input.shape[1:])])
- cache_attribute["Size"] = self.size
- return self._resample(input, self.size)
+ cache_attribute["Spacing"] = torch.flip(torch.tensor(image_size)/torch.tensor(size)*torch.flip(cache_attribute.get_tensor("Spacing"), dims=[0]), dims=[0])
+ cache_attribute["Size"] = image_size
+ cache_attribute["Size"] = size
+ return self._resample(input, size)
 
  class ResampleTransform(Transform):
 
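Reading the new code, the renamed transforms use sentinel values to keep an axis unchanged: a spacing entry of `0` (negatives are normalized to `0`) keeps that axis' current spacing in `ResampleToResolution`, and a size entry of `-1` keeps that axis' current size in `ResampleToSize`. A worked example of the spacing arithmetic, with hypothetical values:

```python
import torch

image_spacing = torch.tensor([3.0, 1.0, 1.0])   # (z, y, x) spacing of the input
target_spacing = torch.tensor([1.0, 0.0, 0.0])  # resample z only; 0 = keep axis

spacing = target_spacing.clone()                 # avoid mutating the target
spacing[spacing == 0] = image_spacing[spacing == 0]

shape = torch.tensor([100, 512, 512])
# new_shape = shape / resize_factor, with resize_factor = spacing / image_spacing
new_shape = (shape * image_spacing / spacing).long()
print(new_shape.tolist())  # [300, 512, 512]
```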
@@ -412,8 +440,8 @@ class FlatLabel(Transform):
 
  class Save(Transform):
 
- def __init__(self, save: str) -> None:
- self.save = save
+ def __init__(self, dataset: str) -> None:
+ self.dataset = dataset
 
  def __call__(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
  return input
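`Save` now takes a `dataset` string in the same `path[:format]` form that the `DatasetManager` load path parses, with the format defaulting to `'mha'` when omitted. A sketch of that parse (paths are hypothetical):

```python
def parse_save_target(dataset: str) -> tuple[str, str]:
    # "path" or "path:format"; format falls back to 'mha'
    parts = dataset.split(":")
    if len(parts) > 1:
        filename, fmt = parts
    else:
        filename, fmt = dataset, "mha"
    return filename, fmt

print(parse_save_target("./Cache/pre"))      # ('./Cache/pre', 'mha')
print(parse_save_target("./Cache/pre:nii"))  # ('./Cache/pre', 'nii')
```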
@@ -528,8 +556,7 @@ class OneHot(Transform):
  self.num_classes = num_classes
 
  def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
- result = F.one_hot(input.type(torch.int64), num_classes=self.num_classes).permute(0, len(input.shape), *[i+1 for i in range(len(input.shape)-1)]).float().squeeze(2)
- print(result.shape)
+ result = F.one_hot(input.type(torch.int64), num_classes=self.num_classes).permute(0, len(input.shape), *[i+1 for i in range(len(input.shape)-1)]).float().squeeze(0)
  return result
 
  def inverse(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
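The `squeeze(2)` → `squeeze(0)` fix matters for the output layout: after `F.one_hot` the class axis sits last, the `permute` moves it into channel position, and squeezing axis 0 removes the singleton leading dimension. A shape walk-through, assuming a `(1, D, H, W)` integer label tensor:

```python
import torch
import torch.nn.functional as F

input = torch.randint(0, 3, (1, 2, 3, 4))  # 3 classes; shape (1, D, H, W)
num_classes = 3

onehot = F.one_hot(input.type(torch.int64), num_classes=num_classes)
print(onehot.shape)  # torch.Size([1, 2, 3, 4, 3]): class axis is last

# permute(0, 4, 1, 2, 3) moves the class axis into channel position,
# and squeeze(0) drops the singleton leading dimension.
result = onehot.permute(0, 4, 1, 2, 3).float().squeeze(0)
print(result.shape)  # torch.Size([3, 2, 3, 4]): classes as channels
```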