konfai 1.0.9-py3-none-any.whl → 1.1.1-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of konfai might be problematic.

konfai/data/augmentation.py CHANGED
@@ -8,7 +8,7 @@ from typing import Union
8
8
  import os
9
9
  from konfai import KONFAI_ROOT
10
10
  from konfai.utils.config import config
11
- from konfai.utils.utils import _getModule
11
+ from konfai.utils.utils import _getModule, AugmentationError
12
12
  from konfai.utils.dataset import Attribute, data_to_image
13
13
 
14
14
 
@@ -222,7 +222,7 @@ class ColorTransform(DataAugmentation):
222
222
  matrix = matrix[:, :3, :].mean(dim=1, keepdims=True).to(input.device)
223
223
  result = result.float() * matrix[:, :, :3].sum(dim=2, keepdims=True) + matrix[:, :, 3:]
224
224
  else:
225
- raise ValueError('Image must be RGB (3 channels) or L (1 channel)')
225
+ raise AugmentationError('Image must be RGB (3 channels) or L (1 channel)')
226
226
  results.append(result.reshape(input.shape))
227
227
  return results
228
228
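A brief note on the change above: raising a dedicated AugmentationError instead of a generic ValueError lets callers handle augmentation failures specifically. The sketch below is illustrative only; it re-declares a stand-in exception rather than importing KonfAI's class.

```python
# Illustrative only: a dedicated exception type (KonfAI defines its own AugmentationError
# in konfai.utils.utils) lets callers catch augmentation failures without also catching
# every other ValueError raised in the pipeline.
class AugmentationError(Exception):
    pass

def color_transform_channels_check(channels: int) -> None:
    if channels not in (1, 3):
        raise AugmentationError("Image must be RGB (3 channels) or L (1 channel)")

try:
    color_transform_channels_check(4)
except AugmentationError as exc:
    print(f"augmentation rejected: {exc}")
```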
 
konfai/data/data_manager.py CHANGED
@@ -11,11 +11,12 @@ from typing import Union, Iterator
11
11
  from concurrent.futures import ThreadPoolExecutor, as_completed
12
12
  import threading
13
13
  from torch.cuda import device_count
14
+ import SimpleITK as sitk
14
15
 
15
16
  from konfai import KONFAI_STATE, KONFAI_ROOT
16
17
  from konfai.data.patching import DatasetPatch, DatasetManager
17
18
  from konfai.utils.config import config
18
- from konfai.utils.utils import memoryInfo, cpuInfo, memoryForecast, getMemory, State
19
+ from konfai.utils.utils import memoryInfo, cpuInfo, memoryForecast, getMemory, State, SUPPORTED_EXTENSIONS, DatasetManagerError
19
20
  from konfai.utils.dataset import Dataset, Attribute
20
21
  from konfai.data.transform import TransformLoader, Transform
21
22
  from konfai.data.augmentation import DataAugmentationsList
@@ -236,6 +237,8 @@ class Subset():
236
237
  index = random.sample(index, len(index))
237
238
  return set([names_filtred[i] for i in index])
238
239
 
240
+ def __str__(self):
241
+ return "Subset : " + str(self.subset) + " shuffle : "+ str(self.shuffle) + " filter : "+ str(self.filter)
239
242
  class TrainSubset(Subset):
240
243
 
241
244
  @config()
@@ -257,14 +260,14 @@ class Data(ABC):
257
260
  subset : Union[Subset, dict[str, Subset]],
258
261
  num_workers : int,
259
262
  batch_size : int,
260
- train_size: Union[float, str, list[int], list[str]] = 1,
263
+ validation: Union[float, str, list[int], list[str]] = 1,
261
264
  inlineAugmentations: bool = False,
262
265
  dataAugmentationsList: dict[str, DataAugmentationsList]= {}) -> None:
263
266
  self.dataset_filenames = dataset_filenames
264
267
  self.subset = subset
265
268
  self.groups_src = groups_src
266
269
  self.patch = patch
267
- self.train_size = train_size
270
+ self.validation = validation
268
271
  self.dataAugmentationsList = dataAugmentationsList
269
272
  self.batch_size = batch_size
270
273
  self.dataSet_args = dict(groups_src=self.groups_src, inlineAugmentations=inlineAugmentations, dataAugmentationsList = list(self.dataAugmentationsList.values()), use_cache = use_cache, buffer_size=batch_size+1, patch_size=self.patch.patch_size if self.patch is not None else None, overlap=self.patch.overlap if self.patch is not None else None)
@@ -290,6 +293,13 @@ class Data(ABC):
290
293
  map.append((x, y, z))
291
294
  return data, map
292
295
 
296
+ def getGroupsDest(self):
297
+ groupsDest = []
298
+ for group_src in self.groups_src:
299
+ for group_dest in self.groups_src[group_src]:
300
+ groupsDest.append(group_dest)
301
+ return groupsDest
302
+
293
303
  def _split(map: list[tuple[int, int, int]], world_size: int) -> list[list[tuple[int, int, int]]]:
294
304
  if len(map) == 0:
295
305
  return [[] for _ in range(world_size)]
@@ -313,7 +323,14 @@ class Data(ABC):
313
323
 
314
324
  def getData(self, world_size: int) -> list[list[DataLoader]]:
315
325
  datasets: dict[str, list[(str, bool)]] = {}
326
+ if self.dataset_filenames is None or len(self.dataset_filenames) == 0:
327
+ raise DatasetManagerError("No dataset filenames were provided")
316
328
  for dataset_filename in self.dataset_filenames:
329
+ if dataset_filename is None:
330
+ raise DatasetManagerError("Invalid dataset entry: 'None' received.",
331
+ "Each dataset must be a valid path string (e.g., './Dataset/', './Dataset/:mha, './Dataset/:a:mha', './Dataset/:i:mha').",
332
+ "Please check your 'dataset_filenames' list for missing or null entries."
333
+ )
317
334
  if len(dataset_filename.split(":")) == 1:
318
335
  filename = dataset_filename
319
336
  format = "mha"
@@ -324,9 +341,13 @@ class Data(ABC):
324
341
  else:
325
342
  filename, flag, format = dataset_filename.split(":")
326
343
  append = flag == "a"
327
-
328
- dataset = Dataset(filename, format)
344
+
345
+ if format not in SUPPORTED_EXTENSIONS:
346
+ raise DatasetManagerError(f"Unsupported file format '{format}'.",
347
+ f"Supported extensions are: {', '.join(SUPPORTED_EXTENSIONS)}")
329
348
 
349
+ dataset = Dataset(filename, format)
350
+
330
351
  self.datasets[filename] = dataset
331
352
  for group in self.groups_src:
332
353
  if dataset.isGroupExist(group):
@@ -334,16 +355,30 @@ class Data(ABC):
334
355
  datasets[group].append((filename, append))
335
356
  else:
336
357
  datasets[group] = [(filename, append)]
358
+ modelHaveInput = False
337
359
  for group_src in self.groups_src:
338
360
  if group_src not in datasets:
339
- raise ValueError("[DatasetManager] Error: group source {} not found. Available groups: {}".format(group_src, list(datasets.keys())))
340
361
 
362
+ raise DatasetManagerError(
363
+ f"Group source '{group_src}' not found in any dataset.",
364
+ f"Dataset filenames provided: {self.dataset_filenames}",
365
+ f"Available groups across all datasets: {sorted(list(datasets.keys()))}",
366
+ f"Please check that an entry in the dataset with the name '{group_src}.{format}' exists."
367
+ )
368
+
341
369
  for group_dest in self.groups_src[group_src]:
342
370
  self.groups_src[group_src][group_dest].load(group_src, group_dest, [self.datasets[filename] for filename, _ in datasets[group_src]])
371
+ modelHaveInput |= self.groups_src[group_src][group_dest].isInput
372
+
373
+ if not modelHaveInput:
374
+ raise DatasetManagerError(
375
+ "At least one group must be defined with 'isInput: true' to provide input to the network."
376
+ )
377
+
343
378
  for key, dataAugmentations in self.dataAugmentationsList.items():
344
379
  dataAugmentations.load(key)
345
380
 
346
- names : list[list[str]] = []
381
+ names = set()
347
382
  dataset_name : dict[str, dict[str, list[str]]] = {}
348
383
  dataset_info : dict[str, dict[str, dict[str, Attribute]]] = {}
349
384
  for group in self.groups_src:
@@ -351,10 +386,11 @@ class Data(ABC):
351
386
  dataset_name[group] = {}
352
387
  dataset_info[group] = {}
353
388
  for filename, _ in datasets[group]:
354
- names.append(self.datasets[filename].getNames(group))
389
+ names.update(self.datasets[filename].getNames(group))
355
390
  dataset_name[group][filename] = self.datasets[filename].getNames(group)
356
391
  dataset_info[group][filename] = {name: self.datasets[filename].getInfos(group, name) for name in dataset_name[group][filename]}
357
392
  subset_names = set()
393
+
358
394
  if isinstance(self.subset, dict):
359
395
  for filename, subset in self.subset.items():
360
396
  subset_names.update(subset([dataset_name[group][filename] for group in dataset_name], [dataset_info[group][filename] for group in dataset_name]))
@@ -368,41 +404,86 @@ class Data(ABC):
368
404
  subset_names.update(self.subset([dataset_name[group][filename]], [dataset_info[group][filename]]))
369
405
  else:
370
406
  subset_names.intersection(self.subset([dataset_name[group][filename]], [dataset_info[group][filename]]))
407
+ if len(subset_names) == 0:
408
+ raise DatasetManagerError("All data entries were excluded by the subset filter.",
409
+ f"Dataset entries found: {', '.join(names)}",
410
+ f"Subset object applied: {self.subset}",
411
+ f"Subset requested : {', '.join(subset_names)}",
412
+ "None of the dataset entries matched the given subset.",
413
+ "Please check your 'subset' configuration — it may be too restrictive or incorrectly formatted.",
414
+ "Examples of valid subset formats:",
415
+ "\tsubset: [0, 1] # explicit indices",
416
+ "\tsubset: 0:10 # slice notation",
417
+ "\tsubset: ./Validation.txt # external file",
418
+ "\tsubset: None # to disable filtering"
419
+ )
420
+
371
421
  data, map = self._getDatasets(list(subset_names), dataset_name)
372
-
422
+
373
423
  train_map = map
374
424
  validate_map = []
375
- if isinstance(self.train_size, float):
376
- if self.train_size < 1.0 and int(math.floor(len(map)*(1-self.train_size))) > 0:
377
- train_map, validate_map = map[:int(math.floor(len(map)*self.train_size))], map[int(math.floor(len(map)*self.train_size)):]
378
- elif isinstance(self.train_size, str):
379
- if ":" in self.train_size:
425
+ if isinstance(self.validation, float) or isinstance(self.validation, int):
426
+ if self.validation <= 0 or self.validation >= 1:
427
+ raise DatasetManagerError("validation must be a float between 0 and 1.", f"→ Received: {self.validation}", "→ Example: validation = 0.2 # for a 20% validation split")
428
+
429
+ train_map, validate_map = map[:int(math.floor(len(map)*(1-self.validation)))], map[int(math.floor(len(map)*(1-self.validation))):]
430
+ elif isinstance(self.validation, str):
431
+ if ":" in self.validation:
380
432
  index = list(range(int(self.subset.split(":")[0]), int(self.subset.split(":")[1])))
381
433
  train_map = [m for m in map if m[0] not in index]
382
434
  validate_map = [m for m in map if m[0] in index]
383
- elif os.path.exists(self.train_size):
435
+ elif os.path.exists(self.validation):
384
436
  validation_names = []
385
- with open(self.train_size, "r") as f:
437
+ with open(self.validation, "r") as f:
386
438
  for name in f:
387
439
  validation_names.append(name.strip())
388
440
  index = [i for i, n in enumerate(subset_names) if n in validation_names]
389
441
  train_map = [m for m in map if m[0] not in index]
390
442
  validate_map = [m for m in map if m[0] in index]
391
443
  else:
392
- validate_map = train_map
393
- elif isinstance(self.train_size, list):
394
- if len(self.train_size) > 0:
395
- if isinstance(self.train_size[0], int):
396
- train_map = [m for m in map if m[0] not in self.train_size]
397
- validate_map = [m for m in map if m[0] in self.train_size]
398
- elif isinstance(self.train_size[0], str):
399
- index = [i for i, n in enumerate(subset_names) if n in self.train_size]
444
+ raise DatasetManagerError(
445
+ f"Invalid string value for 'validation': '{self.validation}'",
446
+ "Expected one of the following formats:",
447
+ "\t• A slice string like '0:10'",
448
+ "\t• A path to a text file listing validation sample names (e.g., './val.txt')",
449
+ "\t• A float between 0 and 1 (e.g., 0.2)",
450
+ "\t• A list of sample names or indices",
451
+ "The provided value is neither a valid slice nor a readable file.",
452
+ "Please fix your 'validation' setting in the configuration."
453
+ )
454
+
455
+ elif isinstance(self.validation, list):
456
+ if len(self.validation) > 0:
457
+ if isinstance(self.validation[0], int):
458
+ train_map = [m for m in map if m[0] not in self.validation]
459
+ validate_map = [m for m in map if m[0] in self.validation]
460
+ elif isinstance(self.validation[0], str):
461
+ index = [i for i, n in enumerate(subset_names) if n in self.validation]
400
462
  train_map = [m for m in map if m[0] not in index]
401
463
  validate_map = [m for m in map if m[0] in index]
402
-
464
+ else:
465
+ raise DatasetManagerError(f"Invalid list type for 'validation': elements of type '{type(self.validation[0]).__name__}' are not supported.",
466
+ "Supported list element types are:",
467
+ "\t• int → list of indices (e.g., [0, 1, 2])",
468
+ "\t• str → list of sample names (e.g., ['patient01', 'patient02'])",
469
+ f"Received list: {self.validation}"
470
+ )
471
+ if len(train_map) == 0:
472
+ raise DatasetManagerError("No data left for training after applying the validation split.",
473
+ f"Dataset size: {len(map)}",
474
+ f"Validation setting: {self.validation}",
475
+ "Please reduce the validation size, increase the dataset, or disable validation."
476
+ )
477
+
478
+ if self.validation is not None and len(validate_map) == 0:
479
+ raise DatasetManagerError("No data left for validation after applying the validation split.",
480
+ f"Dataset size: {len(map)}",
481
+ f"Validation setting: {self.validation}",
482
+ "Please increase the validation size, increase the dataset, or disable validation."
483
+ )
403
484
  train_maps = Data._split(train_map, world_size)
404
485
  validate_maps = Data._split(validate_map, world_size)
405
-
486
+
406
487
  for i, (train_map, validate_map) in enumerate(zip(train_maps, validate_maps)):
407
488
  maps = [train_map]
408
489
  if len(validate_map):
@@ -436,8 +517,8 @@ class DataTrain(Data):
436
517
  subset : Union[TrainSubset, dict[str, TrainSubset]] = TrainSubset(),
437
518
  num_workers : int = 4,
438
519
  batch_size : int = 1,
439
- train_size : Union[float, str, list[int], list[str]] = 0.8) -> None:
440
- super().__init__(dataset_filenames, groups_src, patch, use_cache, subset, num_workers, batch_size, train_size, inlineAugmentations, augmentations if augmentations else {})
520
+ validation : Union[float, str, list[int], list[str]] = 0.2) -> None:
521
+ super().__init__(dataset_filenames, groups_src, patch, use_cache, subset, num_workers, batch_size, validation, inlineAugmentations, augmentations if augmentations else {})
441
522
 
442
523
  class DataPrediction(Data):
443
524
 
@@ -445,14 +526,13 @@ class DataPrediction(Data):
445
526
  def __init__(self, dataset_filenames : list[str] = ["default:./Dataset"],
446
527
  groups_src : dict[str, Group] = {"default" : Group()},
447
528
  augmentations : Union[dict[str, DataAugmentationsList], None] = {"DataAugmentation_0" : DataAugmentationsList()},
448
- inlineAugmentations: bool = False,
449
529
  patch : Union[DatasetPatch, None] = DatasetPatch(),
450
530
  use_cache : bool = True,
451
531
  subset : Union[PredictionSubset, dict[str, PredictionSubset]] = PredictionSubset(),
452
532
  num_workers : int = 4,
453
533
  batch_size : int = 1) -> None:
454
534
 
455
- super().__init__(dataset_filenames, groups_src, patch, use_cache, subset, num_workers, batch_size, inlineAugmentations=inlineAugmentations, dataAugmentationsList=augmentations if augmentations else {})
535
+ super().__init__(dataset_filenames, groups_src, patch, use_cache, subset, num_workers, batch_size, dataAugmentationsList=augmentations if augmentations else {})
456
536
 
457
537
  class DataMetric(Data):
458
538
 
@@ -463,4 +543,4 @@ class DataMetric(Data):
463
543
  validation: Union[str, None] = None,
464
544
  num_workers : int = 4) -> None:
465
545
 
466
- super().__init__(dataset_filenames=dataset_filenames, groups_src=groups_src, patch=None, use_cache=False, subset=subset, num_workers=num_workers, batch_size=1, train_size=1 if validation is None else validation)
546
+ super().__init__(dataset_filenames=dataset_filenames, groups_src=groups_src, patch=None, use_cache=False, subset=subset, num_workers=num_workers, batch_size=1, validation=1 if validation is None else validation)
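The data_manager.py changes above rename the train_size parameter to validation and tighten its semantics: a float must lie strictly between 0 and 1, lists select validation samples by index or by name, and invalid settings raise DatasetManagerError. The standalone sketch below illustrates those split rules with made-up sample names; it is not KonfAI's API, and the slice-string and file-path forms it also accepts are omitted.

```python
# Hedged sketch of the validation semantics introduced above: a float is the validation
# fraction, a list selects samples by index or by name.
import math

def split_train_validation(samples: list[str], validation=0.2):
    if isinstance(validation, (int, float)):
        if validation <= 0 or validation >= 1:
            raise ValueError(f"validation must be between 0 and 1, received {validation}")
        cut = int(math.floor(len(samples) * (1 - validation)))
        return samples[:cut], samples[cut:]
    if isinstance(validation, list) and validation:
        if isinstance(validation[0], int):
            val = [s for i, s in enumerate(samples) if i in validation]
        else:
            val = [s for s in samples if s in validation]
        train = [s for s in samples if s not in val]
        return train, val
    raise ValueError(f"Unsupported validation setting: {validation!r}")

train, val = split_train_validation([f"patient{i:02d}" for i in range(10)], validation=0.2)
print(len(train), len(val))  # 8 2
```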
konfai/data/patching.py CHANGED
@@ -150,7 +150,10 @@ class Patch(ABC):
150
150
 
151
151
  def __init__(self, patch_size: list[int], overlap: Union[int, None], path_mask: Union[str, None] = None, padValue: float = 0, extend_slice: int = 0) -> None:
152
152
  self.patch_size = patch_size
153
- self.overlap = overlap
153
+ self.overlap = overlap
154
+ if isinstance(self.overlap, int):
155
+ if self.overlap < 0:
156
+ self.overlap = None
154
157
  self._patch_slices : dict[int, list[tuple[slice]]] = {}
155
158
  self._nb_patch_per_dim: dict[int, list[tuple[int, bool]]] = {}
156
159
  self.path_mask = path_mask
@@ -188,7 +191,7 @@ class Patch(ABC):
188
191
  pad_bottom = 0
189
192
  pad_top = 0
190
193
  if self._patch_slices[a][index][0].start-bottom < 0:
191
- pad_bottom = bottom-s.start
194
+ pad_bottom = bottom-self._patch_slices[a][index][0].start
192
195
  if self._patch_slices[a][index][0].stop+top > data.shape[len(slices_pre)]:
193
196
  pad_top = self._patch_slices[a][index][0].stop+top-data.shape[len(slices_pre)]
194
197
  data_sliced = F.pad(data_sliced, [0 for _ in range((len(slices)-1)*2)]+[pad_bottom, pad_top], 'reflect')
konfai/data/transform.py CHANGED
@@ -529,7 +529,6 @@ class OneHot(Transform):
529
529
 
530
530
  def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
531
531
  result = F.one_hot(input.type(torch.int64), num_classes=self.num_classes).permute(0, len(input.shape), *[i+1 for i in range(len(input.shape)-1)]).float().squeeze(2)
532
- print(result.shape)
533
532
  return result
534
533
 
535
534
  def inverse(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
konfai/main.py CHANGED
@@ -18,9 +18,8 @@ def main():
18
18
  distributedObject.setup(world_size)
19
19
  with TensorBoard(distributedObject.name):
20
20
  mp.spawn(distributedObject, nprocs=world_size)
21
- except Exception as e:
22
- print(e)
23
- exit(1)
21
+ except KeyboardInterrupt:
22
+ print("\n[KonfAI] Manual interruption (Ctrl+C)")
24
23
 
25
24
 
26
25
  def cluster():
@@ -47,6 +46,8 @@ def cluster():
47
46
  executor.update_parameters(name=config["name"], mem_gb=config["memory"], gpus_per_node=n_gpu, tasks_per_node=n_gpu//distributedObject.size, cpus_per_task=config["num_workers"], nodes=config["num_nodes"], timeout_min=config["time_limit"])
48
47
  with TensorBoard(distributedObject.name):
49
48
  executor.submit(distributedObject)
49
+ except KeyboardInterrupt:
50
+ print("\n[KonfAI] Manual interruption (Ctrl+C)")
50
51
  except Exception as e:
51
52
  print(e)
52
53
  exit(1)
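main() and cluster() now report a manual Ctrl+C cleanly instead of printing the exception and exiting with status 1. A minimal sketch of that pattern, with worker and world_size as placeholders:

```python
# Sketch of the interruption handling adopted above; `worker` and `world_size` are
# placeholders, not KonfAI objects.
import torch.multiprocessing as mp

def launch(worker, world_size: int) -> None:
    try:
        mp.spawn(worker, nprocs=world_size)
    except KeyboardInterrupt:
        # A manual Ctrl+C is reported explicitly rather than treated as a failure.
        print("\n[KonfAI] Manual interruption (Ctrl+C)")
```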
konfai/metric/measure.py CHANGED
@@ -147,7 +147,7 @@ class Dice(Criterion):
147
147
  def forward(self, output: torch.Tensor, *targets : list[torch.Tensor]) -> torch.Tensor:
148
148
  target = targets[0]
149
149
  if output.shape[1] == 1:
150
- output = F.one_hot(output.type(torch.int64), num_classes=torch.max(output).item()+1).permute(0, len(target.shape), *[i+1 for i in range(len(target.shape)-1)]).float()
150
+ output = F.one_hot(output.type(torch.int64), num_classes=int(torch.max(output).item()+1)).permute(0, len(target.shape), *[i+1 for i in range(len(target.shape)-1)]).float()
151
151
  target = F.one_hot(target.type(torch.int64), num_classes=output.shape[1]).permute(0, len(target.shape), *[i+1 for i in range(len(target.shape)-1)]).float().squeeze(2)
152
152
  return 1-torch.mean(self.dice_per_channel(output, target))
153
153
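The Dice change wraps num_classes in int(). A small self-contained reproduction of why the cast is needed, independent of KonfAI:

```python
# torch.max(t).item() returns a Python float for floating-point tensors, and
# F.one_hot requires an int num_classes; without the int() cast this raises a TypeError.
import torch
import torch.nn.functional as F

labels = torch.tensor([[0.0, 1.0, 2.0]])            # float-typed label map
num_classes = int(torch.max(labels).item() + 1)      # 3
one_hot = F.one_hot(labels.type(torch.int64), num_classes=num_classes)
print(one_hot.shape)                                 # torch.Size([1, 3, 3])
```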
 
konfai/models/registration/registration.py CHANGED
@@ -94,7 +94,6 @@ class SpatialTransformer(torch.nn.Module):
94
94
  new_locs[:, 1,1] = 1
95
95
  new_locs[:, 0,2] = flow[:, 0]
96
96
  new_locs[:, 1,2] = flow[:, 1]
97
- print(new_locs)
98
97
  return F.grid_sample(src, F.affine_grid(new_locs, src.size()), align_corners=True, mode="bilinear")
99
98
  else:
100
99
  new_locs = self.grid + flow
konfai/network/blocks.py CHANGED
@@ -176,7 +176,6 @@ class Print(torch.nn.Module):
176
176
  super().__init__()
177
177
 
178
178
  def forward(self, input: torch.Tensor) -> torch.Tensor:
179
- print(input.shape)
180
179
  return input
181
180
 
182
181
  class Write(torch.nn.Module):
konfai/network/network.py CHANGED
@@ -16,7 +16,7 @@ from enum import Enum
16
16
  from konfai import KONFAI_ROOT
17
17
  from konfai.metric.schedulers import Scheduler
18
18
  from konfai.utils.config import config
19
- from konfai.utils.utils import State, _getModule, getDevice, getGPUMemory
19
+ from konfai.utils.utils import State, _getModule, getDevice, getGPUMemory, MeasureError
20
20
  from konfai.data.patching import Accumulator, ModelPatch
21
21
 
22
22
  class NetState(Enum):
@@ -54,12 +54,12 @@ class LRSchedulersLoader():
54
54
  def __init__(self, params: dict[str, SchedulerStep] = {"default:ReduceLROnPlateau" : SchedulerStep(0)}) -> None:
55
55
  self.params = params
56
56
 
57
- def getShedulers(self, key: str, optimizer: torch.optim.Optimizer) -> dict[torch.optim.lr_scheduler._LRScheduler, int]:
58
- shedulers : dict[torch.optim.lr_scheduler._LRScheduler, int] = {}
57
+ def getschedulers(self, key: str, optimizer: torch.optim.Optimizer) -> dict[torch.optim.lr_scheduler._LRScheduler, int]:
58
+ schedulers : dict[torch.optim.lr_scheduler._LRScheduler, int] = {}
59
59
  for name, step in self.params.items():
60
60
  if name:
61
- shedulers[config("Trainer.Model.{}.Schedulers.{}".format(key, name))(getattr(importlib.import_module('torch.optim.lr_scheduler'), name))(optimizer, config = None)] = step.nb_step
62
- return shedulers
61
+ schedulers[config("Trainer.Model.{}.Schedulers.{}".format(key, name))(getattr(importlib.import_module('torch.optim.lr_scheduler'), name))(optimizer, config = None)] = step.nb_step
62
+ return schedulers
63
63
 
64
64
  class SchedulersLoader():
65
65
 
@@ -67,12 +67,12 @@ class SchedulersLoader():
67
67
  def __init__(self, params: dict[str, SchedulerStep] = {"default:Constant" : SchedulerStep(0)}) -> None:
68
68
  self.params = params
69
69
 
70
- def getShedulers(self, key: str) -> dict[torch.optim.lr_scheduler._LRScheduler, int]:
71
- shedulers : dict[Scheduler, int] = {}
70
+ def getschedulers(self, key: str) -> dict[torch.optim.lr_scheduler._LRScheduler, int]:
71
+ schedulers : dict[Scheduler, int] = {}
72
72
  for name, step in self.params.items():
73
73
  if name:
74
- shedulers[getattr(importlib.import_module("konfai.metric.schedulers"), name)(config = None, DL_args = key)] = step.nb_step
75
- return shedulers
74
+ schedulers[getattr(importlib.import_module("konfai.metric.schedulers"), name)(config = None, DL_args = key)] = step.nb_step
75
+ return schedulers
76
76
 
77
77
  class CriterionsAttr():
78
78
 
@@ -98,7 +98,7 @@ class CriterionsLoader():
98
98
  for module_classpath, criterionsAttr in self.criterionsLoader.items():
99
99
  module, name = _getModule(module_classpath, "metric.measure")
100
100
  criterionsAttr.isTorchCriterion = module.startswith("torch")
101
- criterionsAttr.sheduler = criterionsAttr.l.getShedulers("{}.Model.{}.outputsCriterions.{}.targetsCriterions.{}.criterionsLoader.{}".format(KONFAI_ROOT(), model_classname, output_group, target_group, module_classpath))
101
+ criterionsAttr.sheduler = criterionsAttr.l.getschedulers("{}.Model.{}.outputsCriterions.{}.targetsCriterions.{}.criterionsLoader.{}".format(KONFAI_ROOT(), model_classname, output_group, target_group, module_classpath))
102
102
  criterions[config("{}.Model.{}.outputsCriterions.{}.targetsCriterions.{}.criterionsLoader.{}".format(KONFAI_ROOT(), model_classname, output_group, target_group, module_classpath))(getattr(importlib.import_module(module), name))(config = None)] = criterionsAttr
103
103
  return criterions
104
104
 
@@ -154,10 +154,25 @@ class Measure():
154
154
  self.outputsCriterions[output_group.replace(":", ".")] = targetCriterionsLoader.getTargetsCriterions(output_group, model_classname)
155
155
  self._loss : dict[int, dict[str, Measure.Loss]] = {}
156
156
 
157
- def init(self, model : torch.nn.Module) -> None:
157
+ def init(self, model : torch.nn.Module, group_dest: list[str]) -> None:
158
158
  outputs_group_rename = {}
159
+
160
+ modules = []
161
+ for i,_ in model.named_modules():
162
+ modules.append(i)
163
+
159
164
  for output_group in self.outputsCriterions.keys():
165
+ if output_group not in modules:
166
+ raise MeasureError(f"The output group '{output_group}' defined in 'outputsCriterions' does not correspond to any module in the model.",
167
+ f"Available modules: {modules}",
168
+ "Please check that the name matches exactly a submodule or output of your model architecture."
169
+ )
160
170
  for target_group in self.outputsCriterions[output_group]:
171
+ if target_group not in group_dest:
172
+ raise MeasureError(
173
+ f"The target_group '{target_group}' defined in 'outputsCriterions.{output_group}.targetsCriterions' was not found in the available destination groups.",
174
+ "This target_group is expected for loss or metric computation, but was not loaded in 'group_dest'.",
175
+ f"Please make sure that the group '{target_group}' is defined in 'Dataset:groups_src:...:groups_dest:'{target_group}'' and correctly loaded from the dataset.")
161
176
  for criterion in self.outputsCriterions[output_group][target_group]:
162
177
  if not self.outputsCriterions[output_group][target_group][criterion].isTorchCriterion:
163
178
  outputs_group_rename[output_group] = criterion.init(model, output_group, target_group)
@@ -703,10 +718,10 @@ class Network(ModuleArgsDict, ABC):
703
718
  return in_channels, in_is_channel, out_channels, out_is_channel
704
719
 
705
720
  @_function_network()
706
- def init(self, autocast : bool, state : State, key: str) -> None:
721
+ def init(self, autocast : bool, state : State, group_dest: list[str], key: str) -> None:
707
722
  if self.outputsCriterionsLoader:
708
723
  self.measure = Measure(key, self.outputsCriterionsLoader)
709
- self.measure.init(self)
724
+ self.measure.init(self, group_dest)
710
725
  if state != State.PREDICTION:
711
726
  self.scaler = torch.amp.GradScaler("cuda", enabled=autocast)
712
727
  if self.optimizerLoader:
@@ -714,7 +729,7 @@ class Network(ModuleArgsDict, ABC):
714
729
  self.optimizer.zero_grad()
715
730
 
716
731
  if self.LRSchedulersLoader and self.optimizer:
717
- self.schedulers = self.LRSchedulersLoader.getShedulers(key, self.optimizer)
732
+ self.schedulers = self.LRSchedulersLoader.getschedulers(key, self.optimizer)
718
733
 
719
734
  def initialized(self):
720
735
  pass
@@ -880,7 +895,6 @@ class Network(ModuleArgsDict, ABC):
880
895
  if scheduler:
881
896
  if scheduler.__class__.__name__ == 'ReduceLROnPlateau':
882
897
  if self.measure:
883
- print(sum(self.measure.getLastValues(0).values()))
884
898
  scheduler.step(sum(self.measure.getLastValues(0).values()))
885
899
  else:
886
900
  scheduler.step()
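Measure.init now verifies that each configured output group names an actual submodule of the model (via named_modules()) and that each target group was loaded into the destination groups. The snippet below sketches the module-name side of that check with a toy model; the group name 'head' is an arbitrary example.

```python
# Illustrative check, mirroring the idea above: every configured output group must match
# a name returned by model.named_modules(). The tiny model and the group name are examples.
import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU())
modules = [name for name, _ in model.named_modules()]   # ['', '0', '1']

output_group = "head"                                    # not a module name -> rejected
if output_group not in modules:
    print(f"'{output_group}' is not a module of the model; available: {modules}")
```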
konfai/trainer.py CHANGED
@@ -1,5 +1,4 @@
1
1
  import torch
2
- import torch.optim.adamw
3
2
  from torch.utils.data import DataLoader
4
3
  import tqdm
5
4
  import numpy as np
@@ -15,13 +14,76 @@ import torch.distributed as dist
15
14
  from konfai import MODELS_DIRECTORY, CHECKPOINTS_DIRECTORY, STATISTICS_DIRECTORY, SETUPS_DIRECTORY, CONFIG_FILE, MODEL, DATE, KONFAI_STATE
16
15
  from konfai.data.data_manager import DataTrain
17
16
  from konfai.utils.config import config
18
- from konfai.utils.utils import State, DataLog, DistributedObject, description
17
+ from konfai.utils.utils import State, DataLog, DistributedObject, description, TrainerError
19
18
  from konfai.network.network import Network, ModelLoader, NetState, CPU_Model
20
19
 
21
20
 
21
+ class EarlyStoppingBase:
22
+
23
+ def __init__(self):
24
+ pass
25
+
26
+ def isStopped(self) -> bool:
27
+ return False
28
+
29
+ def getScore(self, values: dict[str, float]):
30
+ return sum([i for i in values.values()])
31
+
32
+ def __call__(self, current_score: float) -> bool:
33
+ return False
34
+
35
+ class EarlyStopping(EarlyStoppingBase):
36
+
37
+ @config("EarlyStopping")
38
+ def __init__(self, monitor: Union[list[str], None] = [], patience=10, min_delta=0.0, mode="min"):
39
+ super().__init__()
40
+ self.monitor = [] if monitor is None else monitor
41
+ self.patience = patience
42
+ self.min_delta = min_delta
43
+ self.mode = mode
44
+ self.counter = 0
45
+ self.best_score = None
46
+ self.early_stop = False
47
+
48
+ def isStopped(self) -> bool:
49
+ return self.early_stop
50
+
51
+ def getScore(self, values: dict[str, float]):
52
+ if len(self.monitor) == 0:
53
+ return super().getScore(values)
54
+ for v in self.monitor:
55
+ if v not in values.keys():
56
+ raise TrainerError(
57
+ "Metric '{}' specified in EarlyStopping.monitor not found in logged values. ",
58
+ "Available keys: {}. Please check your configuration.".format(v, list(values.keys())))
59
+ return sum([i for v, i in values.items() if v in self.monitor])
60
+
61
+ def __call__(self, current_score: float) -> bool:
62
+ if self.best_score is None:
63
+ self.best_score = current_score
64
+ return False
65
+
66
+ if self.mode == "min":
67
+ improvement = self.best_score - current_score
68
+ elif self.mode == "max":
69
+ improvement = current_score - self.best_score
70
+ else:
71
+ raise TrainerError("Mode must be 'min' or 'max'.")
72
+
73
+ if improvement > self.min_delta:
74
+ self.best_score = current_score
75
+ self.counter = 0
76
+ else:
77
+ self.counter += 1
78
+
79
+ if self.counter >= self.patience:
80
+ self.early_stop = True
81
+
82
+ return self.early_stop
83
+
22
84
  class _Trainer():
23
85
 
24
- def __init__(self, world_size: int, global_rank: int, local_rank: int, size: int, train_name: str, data_log: Union[list[str], None] , save_checkpoint_mode: str, epochs: int, epoch: int, autocast: bool, it_validation: Union[int, None], it: int, model: Union[DDP, CPU_Model], modelEMA: AveragedModel, dataloader_training: DataLoader, dataloader_validation: Union[DataLoader, None] = None) -> None:
86
+ def __init__(self, world_size: int, global_rank: int, local_rank: int, size: int, train_name: str, early_stopping: EarlyStopping, data_log: Union[list[str], None] , save_checkpoint_mode: str, epochs: int, epoch: int, autocast: bool, it_validation: Union[int, None], it: int, model: Union[DDP, CPU_Model], modelEMA: AveragedModel, dataloader_training: DataLoader, dataloader_validation: Union[DataLoader, None] = None) -> None:
25
87
  self.world_size = world_size
26
88
  self.global_rank = global_rank
27
89
  self.local_rank = local_rank
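The hunk above introduces an EarlyStopping helper that _Trainer consults after each validation round. The sketch below re-implements its stopping rule for illustration (it is not the KonfAI class): in 'min' mode, training stops once the monitored score has failed to improve by more than min_delta for patience consecutive rounds.

```python
# Re-implementation for illustration only, showing the stopping rule in "min" mode.
class EarlyStoppingSketch:
    def __init__(self, patience: int = 10, min_delta: float = 0.0, mode: str = "min"):
        self.patience, self.min_delta, self.mode = patience, min_delta, mode
        self.best, self.counter = None, 0

    def __call__(self, score: float) -> bool:
        if self.best is None:
            self.best = score
            return False
        improvement = self.best - score if self.mode == "min" else score - self.best
        if improvement > self.min_delta:
            self.best, self.counter = score, 0
        else:
            self.counter += 1
        return self.counter >= self.patience

stopper = EarlyStoppingSketch(patience=2)
for loss in [1.0, 0.9, 0.91, 0.92]:
    if stopper(loss):
        print("early stop triggered")   # fires on the second round without improvement
```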
@@ -36,7 +98,8 @@ class _Trainer():
36
98
  self.dataloader_validation = dataloader_validation
37
99
  self.autocast = autocast
38
100
  self.modelEMA = modelEMA
39
-
101
+ self.early_stopping = EarlyStoppingBase() if early_stopping is None else early_stopping
102
+
40
103
  self.it_validation = it_validation
41
104
  if self.it_validation is None:
42
105
  self.it_validation = len(dataloader_training)
@@ -59,6 +122,8 @@ class _Trainer():
59
122
  for self.epoch in epoch_tqdm:
60
123
  self.dataloader_training.dataset.load()
61
124
  self.train()
125
+ if self.early_stopping.isStopped():
126
+ break
62
127
  self.dataloader_training.dataset.resetAugmentation()
63
128
 
64
129
  def getInput(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int, str, bool]]) -> dict[tuple[str, bool], torch.Tensor]:
@@ -88,7 +153,11 @@ class _Trainer():
88
153
  if self.dataloader_validation is not None:
89
154
  loss = self._validate()
90
155
  self.model.module.update_lr()
91
- self.checkpoint_save(loss)
156
+ score = self.early_stopping.getScore(loss)
157
+ self.checkpoint_save(score)
158
+ if self.early_stopping(score):
159
+ break
160
+
92
161
 
93
162
  batch_iter.set_description(desc())
94
163
 
@@ -120,41 +189,53 @@ class _Trainer():
120
189
  return self._validation_log(data_dict)
121
190
 
122
191
  def checkpoint_save(self, loss: float) -> None:
123
- if self.global_rank == 0:
124
- path = CHECKPOINTS_DIRECTORY()+self.train_name+"/"
125
- last_loss = None
126
- if os.path.exists(path) and os.listdir(path):
127
- name = sorted(os.listdir(path))[-1]
128
- state_dict = torch.load(path+name, weights_only=False)
129
- last_loss = state_dict["loss"]
130
- if self.save_checkpoint_mode == "BEST":
131
- if last_loss >= loss:
132
- os.remove(path+name)
133
-
134
- if self.save_checkpoint_mode != "BEST" or (last_loss is None or last_loss >= loss):
135
- name = DATE()+".pt"
136
- if not os.path.exists(path):
137
- os.makedirs(path)
138
-
139
- save_dict = {
140
- "epoch": self.epoch,
141
- "it": self.it,
142
- "loss": loss,
143
- "Model": self.model.module.state_dict()}
192
+ if self.global_rank != 0:
193
+ return
144
194
 
145
- if self.modelEMA is not None:
146
- save_dict.update({"Model_EMA" : self.modelEMA.module.state_dict()})
195
+ path = CHECKPOINTS_DIRECTORY()+self.train_name+"/"
196
+ os.makedirs(path, exist_ok=True)
197
+
198
+ name = DATE() + ".pt"
199
+ save_path = os.path.join(path, name)
147
200
 
148
- save_dict.update({'{}_optimizer_state_dict'.format(name): network.optimizer.state_dict() for name, network in self.model.module.getNetworks().items() if network.optimizer is not None})
149
- torch.save(save_dict, path+name)
201
+ save_dict = {
202
+ "epoch": self.epoch,
203
+ "it": self.it,
204
+ "loss": loss,
205
+ "Model": self.model.module.state_dict()
206
+ }
207
+
208
+ if self.modelEMA is not None:
209
+ save_dict["Model_EMA"] = self.modelEMA.module.state_dict()
210
+
211
+ save_dict.update({'{}_optimizer_state_dict'.format(name): network.optimizer.state_dict() for name, network in self.model.module.getNetworks().items() if network.optimizer is not None})
212
+ torch.save(save_dict, save_path)
213
+
214
+ if self.save_checkpoint_mode == "BEST":
215
+ all_checkpoints = sorted([
216
+ os.path.join(path, f)
217
+ for f in os.listdir(path) if f.endswith(".pt")
218
+ ])
219
+ best_ckpt = None
220
+ best_loss = float('inf')
221
+
222
+ for f in all_checkpoints:
223
+ d = torch.load(f, weights_only=False)
224
+ if d.get("loss", float("inf")) < best_loss:
225
+ best_loss = d["loss"]
226
+ best_ckpt = f
227
+
228
+ for f in all_checkpoints:
229
+ if f != best_ckpt and f != save_path:
230
+ os.remove(f)
150
231
 
151
232
  @torch.no_grad()
152
- def _log(self, type_log: str, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]) -> float:
233
+ def _log(self, type_log: str, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]) -> dict[str, float]:
153
234
  models: dict[str, Network] = {"" : self.model.module}
154
235
  if self.modelEMA is not None:
155
236
  models["_EMA"] = self.modelEMA.module
156
237
 
157
- measures = DistributedObject.getMeasure(self.world_size, self.global_rank, self.local_rank*self.size+self.size-1, models, self.it_validation if type_log == "Trainning" else len(self.dataloader_validation))
238
+ measures = DistributedObject.getMeasure(self.world_size, self.global_rank, self.local_rank*self.size+self.size-1, models, self.it_validation if type_log == "Training" else len(self.dataloader_validation))
158
239
 
159
240
  if self.global_rank == 0:
160
241
  images_log = []
@@ -178,27 +259,29 @@ class _Trainer():
178
259
  for name, layer, _ in model.get_layers([v.to(0) for k, v in self.getInput(data_dict).items() if k[1]], images_log):
179
260
  self.data_log[name][0](self.tb, "{}/{}{}".format(type_log, name, label), layer[:self.data_log[name][1]].detach().cpu().numpy(), self.it)
180
261
 
181
- if type_log == "Trainning":
262
+ if type_log == "Training":
182
263
  for name, network in self.model.module.getNetworks().items():
183
264
  if network.optimizer is not None:
184
265
  self.tb.add_scalar("{}/{}/Learning Rate".format(type_log, name), network.optimizer.param_groups[0]['lr'], self.it)
185
266
 
186
267
  if self.global_rank == 0:
187
- loss = []
268
+ loss = {}
188
269
  for name, network in self.model.module.getNetworks().items():
189
270
  if network.measure is not None:
190
- loss.append(sum([v[1] for v in measures["{}".format(name)][0].values()]))
191
- return np.mean(loss)
271
+ loss.update({k : v[1] for k, v in measures["{}{}".format(name, label)][0].items()})
272
+ loss.update({k : v[1] for k, v in measures["{}{}".format(name, label)][1].items()})
273
+ return loss
192
274
  return None
193
275
 
194
276
  @torch.no_grad()
195
- def _train_log(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]) -> float:
196
- return self._log("Trainning", data_dict)
277
+ def _train_log(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]) -> dict[str, float]:
278
+ return self._log("Training", data_dict)
197
279
 
198
280
  @torch.no_grad()
199
- def _validation_log(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]) -> float:
281
+ def _validation_log(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]) -> dict[str, float]:
200
282
  return self._log("Validation", data_dict)
201
283
 
284
+
202
285
  class Trainer(DistributedObject):
203
286
 
204
287
  @config("Trainer")
@@ -214,6 +297,7 @@ class Trainer(DistributedObject):
214
297
  gpu_checkpoints: Union[list[str], None] = None,
215
298
  ema_decay : float = 0,
216
299
  data_log: Union[list[str], None] = None,
300
+ early_stopping: Union[EarlyStopping, None] = None,
217
301
  save_checkpoint_mode: str= "BEST") -> None:
218
302
  if os.environ["KONFAI_CONFIG_MODE"] != "Done":
219
303
  exit(0)
@@ -223,6 +307,7 @@ class Trainer(DistributedObject):
223
307
  self.autocast = autocast
224
308
  self.epochs = epochs
225
309
  self.epoch = 0
310
+ self.early_stopping = early_stopping
226
311
  self.it = 0
227
312
  self.it_validation = it_validation
228
313
  self.model = model.getModel(train=True)
@@ -306,7 +391,7 @@ class Trainer(DistributedObject):
306
391
  if state != State.TRAIN:
307
392
  state_dict = self._load()
308
393
 
309
- self.model.init(self.autocast, state)
394
+ self.model.init(self.autocast, state, self.dataset.getGroupsDest())
310
395
  self.model.init_outputsGroup()
311
396
  self.model._compute_channels_trace(self.model, self.model.in_channels, self.gradient_checkpoints, self.gpu_checkpoints)
312
397
  self.model.load(state_dict, init=True, ema=False)
@@ -326,5 +411,5 @@ class Trainer(DistributedObject):
326
411
  model = DDP(model, static_graph=True) if torch.cuda.is_available() else CPU_Model(model)
327
412
  if self.modelEMA is not None:
328
413
  self.modelEMA.module = Network.to(self.modelEMA.module, local_rank)
329
- with _Trainer(world_size, global_rank, local_rank, self.size, self.name, self.data_log, self.save_checkpoint_mode, self.epochs, self.epoch, self.autocast, self.it_validation, self.it, model, self.modelEMA, *dataloaders) as t:
414
+ with _Trainer(world_size, global_rank, local_rank, self.size, self.name, self.early_stopping, self.data_log, self.save_checkpoint_mode, self.epochs, self.epoch, self.autocast, self.it_validation, self.it, model, self.modelEMA, *dataloaders) as t:
330
415
  t.run()
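checkpoint_save was also restructured: the new checkpoint is always written first, and in 'BEST' mode every other checkpoint except the one with the lowest recorded loss is then deleted. A hedged sketch of that pruning step follows; the directory layout and the 'loss' key are assumptions taken from the diff, not a guaranteed KonfAI API.

```python
# Sketch of the "BEST" pruning policy described above; paths and the "loss" key are
# assumptions taken from the diff.
import os
import torch

def prune_to_best(checkpoint_dir: str, just_saved: str) -> None:
    checkpoints = sorted(
        os.path.join(checkpoint_dir, f)
        for f in os.listdir(checkpoint_dir) if f.endswith(".pt")
    )
    best, best_loss = None, float("inf")
    for path in checkpoints:
        loss = torch.load(path, weights_only=False).get("loss", float("inf"))
        if loss < best_loss:
            best, best_loss = path, loss
    for path in checkpoints:
        if path not in (best, just_saved):
            os.remove(path)
```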
konfai/utils/config.py CHANGED
@@ -6,14 +6,10 @@ from copy import deepcopy
6
6
  from typing import Union, Literal, get_origin, get_args
7
7
  import torch
8
8
  from konfai import CONFIG_FILE
9
+ from konfai.utils.utils import ConfigError
9
10
 
10
11
  yaml = ruamel.yaml.YAML()
11
12
 
12
- class ConfigError(Exception):
13
-
14
- def __init__(self, message : str = "The config only supports types : config(Object), int, str, bool, float, list[int], list[str], list[bool], list[float], dict[str, Object]") -> None:
15
- self.message = message
16
- super().__init__(self.message)
17
13
 
18
14
  class Config():
19
15
 
@@ -187,9 +183,8 @@ def config(key : Union[str, None] = None):
187
183
  default_value = param.default if param.default != inspect._empty else allowed_values[0]
188
184
  value = config.getValue(param.name, f"default:{default_value}")
189
185
  if value not in allowed_values:
190
- raise ValueError(
191
- f"[Config] Invalid value '{value}' for parameter '{param.name}'. "
192
- f"Expected one of: {allowed_values}."
186
+ raise ConfigError(
187
+ f"Invalid value '{value}' for parameter '{param.name} expected one of: {allowed_values}."
193
188
  )
194
189
  kwargs[param.name] = value
195
190
  continue
@@ -222,21 +217,26 @@ def config(key : Union[str, None] = None):
222
217
  values = config.getValue(param.name, param.default)
223
218
  kwargs[param.name] = values
224
219
  else:
225
- raise ConfigError()
220
+ raise ConfigError("Config: The config only supports types : config(Object), int, str, bool, float, list[int], list[str], list[bool], list[float], dict[str, Object]")
226
221
  elif str(annotation).startswith("dict"):
227
222
  if annotation.__args__[0] == str:
228
223
  values = config.getValue(param.name, param.default)
229
224
  if values is not None and annotation.__args__[1] not in [int, str, bool, float]:
230
- kwargs[param.name] = {value : annotation.__args__[1](config = filename, DL_args = key_tmp+"."+param.name+"."+value) for value in values}
225
+ try:
226
+ kwargs[param.name] = {value : annotation.__args__[1](config = filename, DL_args = key_tmp+"."+param.name+"."+value) for value in values}
227
+ except ValueError as e:
228
+ raise ValueError(e)
229
+ except Exception as e:
230
+ raise ConfigError("{} {}".format(values, e))
231
231
  else:
232
232
  kwargs[param.name] = values
233
233
  else:
234
- raise ConfigError()
234
+ raise ConfigError("Config: The config only supports types : config(Object), int, str, bool, float, list[int], list[str], list[bool], list[float], dict[str, Object]")
235
235
  else:
236
236
  try:
237
237
  kwargs[param.name] = annotation(config = filename, DL_args = key_tmp)
238
238
  except Exception as e:
239
- raise ValueError("[Config] Failed to instantiate {} with type {}".format(param.name, annotation.__name__))
239
+ raise ConfigError("Failed to instantiate {} with type {}, error {} ".format(param.name, annotation.__name__, e))
240
240
 
241
241
  if os.environ['KONFAI_CONFIG_VARIABLE'] == "True":
242
242
  os.environ['KONFAI_CONFIG_VARIABLE'] = "False"
konfai/utils/utils.py CHANGED
@@ -145,8 +145,12 @@ def get_patch_slices_from_nb_patch_per_dim(patch_size_tmp: list[int], nb_patch_p
145
145
 
146
146
  def get_patch_slices_from_shape(patch_size: list[int], shape : list[int], overlap: Union[int, None]) -> tuple[list[tuple[slice]], list[tuple[int, bool]]]:
147
147
  if len(shape) != len(patch_size):
148
- return [tuple([slice(0, s) for s in shape])], [(1, True)]*len(shape)
149
-
148
+ raise DatasetManagerError(
149
+ f"Dimension mismatch: 'patch_size' has {len(patch_size)} dimensions, but 'shape' has {len(shape)}.",
150
+ f"patch_size: {patch_size}",
151
+ f"shape: {shape}",
152
+ "Both must have the same number of dimensions (e.g., 3D patch for 3D volume)."
153
+ )
150
154
  patch_slices = []
151
155
  nb_patch_per_dim = []
152
156
  slices : list[list[slice]] = []
@@ -192,12 +196,8 @@ def _logImageFormat(input : np.ndarray):
192
196
 
193
197
  if len(input.shape) == 3 and input.shape[0] != 1:
194
198
  input = np.expand_dims(input, axis=0)
195
-
196
199
  if len(input.shape) == 4:
197
200
  input = input[:, input.shape[1]//2]
198
-
199
- if input.dtype == np.uint8:
200
- return input
201
201
 
202
202
  input = input.astype(float)
203
203
  b = -np.min(input)
@@ -325,7 +325,7 @@ class DistributedObject():
325
325
  return self
326
326
 
327
327
  def __exit__(self, type, value, traceback):
328
- pass
328
+ cleanup()
329
329
 
330
330
  @abstractmethod
331
331
  def run_process(self, world_size: int, global_rank: int, local_rank: int, dataloaders: list[DataLoader]):
@@ -370,11 +370,12 @@ class DistributedObject():
370
370
  dataloaders = self.dataloader[global_rank]
371
371
  if torch.cuda.is_available():
372
372
  torch.cuda.set_device(local_rank)
373
-
374
- self.run_process(world_size, global_rank, local_rank, dataloaders)
375
- if torch.cuda.is_available():
376
- pynvml.nvmlShutdown()
377
- cleanup()
373
+ try:
374
+ self.run_process(world_size, global_rank, local_rank, dataloaders)
375
+ finally:
376
+ cleanup()
377
+ if torch.cuda.is_available():
378
+ pynvml.nvmlShutdown()
378
379
 
379
380
  def setup(parser: argparse.ArgumentParser) -> DistributedObject:
380
381
  # KONFAI arguments
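Two robustness changes appear above: run_process is now wrapped in try/finally so teardown happens even when a worker raises, and cleanup() checks dist.is_initialized() rather than CUDA availability before destroying the process group. A minimal sketch of the combined pattern (the worker argument is a placeholder):

```python
# Minimal sketch of the shutdown pattern adopted above.
import torch.distributed as dist

def cleanup() -> None:
    # Only destroy a process group that was actually created.
    if dist.is_initialized():
        dist.destroy_process_group()

def run_guarded(worker, *args) -> None:
    try:
        worker(*args)
    finally:
        cleanup()
```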
@@ -460,7 +461,7 @@ def setupGPU(world_size: int, port: int, rank: Union[int, None] = None) -> tuple
460
461
  local_rank = rank
461
462
  if global_rank >= world_size:
462
463
  return None, None
463
- print("tcp://{}:{}".format(host_name, port))
464
+ #print("tcp://{}:{}".format(host_name, port))
464
465
  if torch.cuda.is_available():
465
466
  torch.cuda.empty_cache()
466
467
  dist.init_process_group("nccl", rank=global_rank, init_method="tcp://{}:{}".format(host_name, port), world_size=world_size)
@@ -476,7 +477,7 @@ def find_free_port():
476
477
  return s.getsockname()[1]
477
478
 
478
479
  def cleanup():
479
- if torch.cuda.is_available():
480
+ if dist.is_initialized():
480
481
  dist.destroy_process_group()
481
482
 
482
483
  def synchronize_data(world_size: int, gpu: int, data: any) -> list[Any]:
@@ -506,3 +507,49 @@ def _resample_affine(data: torch.Tensor, matrix: torch.Tensor):
506
507
  else:
507
508
  mode = "bilinear"
508
509
  return F.grid_sample(data.unsqueeze(0).type(torch.float32), F.affine_grid(matrix[:, :-1,...].type(torch.float32), [1]+list(data.shape), align_corners=True), align_corners=True, mode=mode, padding_mode="reflection").squeeze(0).type(data.dtype)
510
+
511
+
512
+ SUPPORTED_EXTENSIONS = [
513
+ "mha", "mhd", # MetaImage
514
+ "nii", "nii.gz", # NIfTI
515
+ "nrrd", "nrrd.gz", # NRRD
516
+ "gipl", "gipl.gz", # GIPL
517
+ "hdr", "img", # Analyze
518
+ "dcm", # DICOM (if GDCM is enabled)
519
+ "tif", "tiff", # TIFF
520
+ "png", "jpg", "jpeg", "bmp", # 2D formats
521
+ "h5", "itk.txt", ".fcsv", ".xml", ".vtk", ".npy"
522
+
523
+ ]
524
+
525
+ class KonfAIError(Exception):
526
+
527
+ def __init__(self, typeError: str, message: list[str]) -> None:
528
+ super().__init__("\n[{}] {}".format(typeError, message[0])+"\n→\t".join(message[1:]))
529
+
530
+
531
+ class ConfigError(KonfAIError):
532
+
533
+ def __init__(self, *message) -> None:
534
+ super().__init__("Config", message)
535
+
536
+
537
+ class DatasetManagerError(KonfAIError):
538
+
539
+ def __init__(self, *message) -> None:
540
+ super().__init__("DatasetManager", message)
541
+
542
+ class MeasureError(KonfAIError):
543
+
544
+ def __init__(self, *message) -> None:
545
+ super().__init__("Measure", message)
546
+
547
+ class TrainerError(KonfAIError):
548
+
549
+ def __init__(self, *message) -> None:
550
+ super().__init__("Trainer", message)
551
+
552
+ class AugmentationError(KonfAIError):
553
+
554
+ def __init__(self, *message) -> None:
555
+ super().__init__("Augmentation", message)
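The new KonfAIError base class builds one formatted message from variadic parts: the first part becomes the headline under a bracketed category, and later parts are joined with '→' separators. The snippet below copies the constructor verbatim so it runs standalone and shows how a DatasetManagerError renders; note that, as written, no separator precedes the first appended part.

```python
# Standalone copy of the error hierarchy above, just to show how the message parts render.
class KonfAIError(Exception):
    def __init__(self, typeError: str, message) -> None:
        super().__init__("\n[{}] {}".format(typeError, message[0]) + "\n→\t".join(message[1:]))

class DatasetManagerError(KonfAIError):
    def __init__(self, *message) -> None:
        super().__init__("DatasetManager", message)

try:
    raise DatasetManagerError(
        "No dataset filenames were provided",
        "Each dataset must be a valid path string (e.g., './Dataset/').",
        "Please check your 'dataset_filenames' list.",
    )
except DatasetManagerError as exc:
    print(exc)
```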
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: konfai
3
- Version: 1.0.9
3
+ Version: 1.1.1
4
4
  Summary: Modular and configurable Deep Learning framework with YAML and PyTorch
5
5
  Author-email: Valentin Boussot <boussot.v@gmail.com>
6
6
  License-Expression: Apache-2.0
@@ -1,15 +1,15 @@
1
1
  konfai/__init__.py,sha256=fZM9EjrMcdJZE-jQAsNu3ZcWfa9Ylbly2q4l9STfpBI,795
2
2
  konfai/evaluator.py,sha256=TsHQNN4DhP33Sal9-exbOouZAHe7gNEk7RsBaasgkEQ,7373
3
- konfai/main.py,sha256=g-6TQ2fbRyD6ZJSJ0KwutGNmLTMgjF8eMKdzWA4gHYM,2401
3
+ konfai/main.py,sha256=OPAUWees1j6NNmggyI7lM2abZctCqk1OsvQuc9FpePY,2515
4
4
  konfai/predictor.py,sha256=FGn6_Nt4PCdxNFoIOt8nbBjrAMTwc55uZqy6OFXUjhY,22320
5
- konfai/trainer.py,sha256=RQhR1vtnzZ53jPAS5LkoBZOVeWTv3fQQ19uprzeoznw,16866
5
+ konfai/trainer.py,sha256=22IKMiPeHNJMGcJdVq9BH-Hsopp6xF4QSYyPabz4LGs,19604
6
6
  konfai/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
7
- konfai/data/augmentation.py,sha256=QNwwl02PDozhrBvg5nZgYpOcj41H_Nyee6qz27nQca0,31692
8
- konfai/data/data_manager.py,sha256=knkR6hfZSeZlAobAGfWqxHKhx-TfAr-L1ZDjJLVXOg8,23802
9
- konfai/data/patching.py,sha256=MuhOmauFEy8jarZyskGY0GuxCSY9QByKMNolGKF-gqo,14355
10
- konfai/data/transform.py,sha256=rOIFrVnNwxmltqKNI3y_u868MBPdGe1lMJPVqugqQXI,25092
7
+ konfai/data/augmentation.py,sha256=ASmKWBpykLBHDB_YeTgoBTlqvJ06v5OUG7n7ugPN6NU,31718
8
+ konfai/data/data_manager.py,sha256=AfZpQaHEd4QmA8Sxd6SK2LK3Nv658wlZVbivETq4Y4k,28439
9
+ konfai/data/patching.py,sha256=JFgb4qgS_gqlcnFV_Rm0JvCzKQleXNZyOWxqqM8H01M,14499
10
+ konfai/data/transform.py,sha256=ttALCEvvqIcwcYJTp4KCnf87e2hzvkXlwrtksfpWRUU,25064
11
11
  konfai/metric/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
12
- konfai/metric/measure.py,sha256=qDFlytlMoRo-K3l-0ot7dCmGf-y7GpoT29rid5mxOvk,21996
12
+ konfai/metric/measure.py,sha256=K7s15wNJm1_Iaav9y7Oe_UhSnbLUiybhLTVumHZ8ZaY,22001
13
13
  konfai/metric/schedulers.py,sha256=UoSr1TW_hrus3DhvOEbDefxCSUGz7lJS_8vbz0GEye8,1370
14
14
  konfai/models/classification/convNeXt.py,sha256=Phj1hO8TCItVhBXoFXQcIkA85DZxurGLyHIPWBfXJ0Y,9243
15
15
  konfai/models/classification/resnet.py,sha256=t8KJEgGudWBpGJM9YS1lH_5eX0-wGbK7ZW5uZW214eM,7958
@@ -18,22 +18,22 @@ konfai/models/generation/ddpm.py,sha256=awvuRo-vk8M80N93NWF4i0-WWfaycBxSOmdYJNJv
18
18
  konfai/models/generation/diffusionGan.py,sha256=KnJyV-tx4CiE_ag-5IXwiYLCuC2yFHX16k2CtASdecg,33199
19
19
  konfai/models/generation/gan.py,sha256=-GoKxHm3W9NdD4U77UcJrG5TfOZ3NWFUZG663kt2XPo,7854
20
20
  konfai/models/generation/vae.py,sha256=_3JYVT2ojZ0P98tYcD2ny7a-gWVUmnByLDhY7i-n_4g,4719
21
- konfai/models/registration/registration.py,sha256=CymMfHlkE2pYa5Kfv5lNtp_lAK9OHq6GXA-tR-eHBM8,6341
21
+ konfai/models/registration/registration.py,sha256=18EiWt4RJIXLyFtqU-kHjV1sMnQRm9mxAA6_-2B1YqI,6313
22
22
  konfai/models/representation/representation.py,sha256=RwQYoxtdph440-t_ZLelykl0hkUAD1zdspQaLkgxb-0,2677
23
23
  konfai/models/segmentation/NestedUNet.py,sha256=6XGizAIc4bDL8vx4AHW8BBFjUvovRYcjdMBHsN4ViNo,4301
24
24
  konfai/models/segmentation/UNet.py,sha256=9-g63oNqaxSlGmrD1-IVzJ-kac9QphmM7i-3rEsH23I,4218
25
25
  konfai/network/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
26
- konfai/network/blocks.py,sha256=EN1JN929zzV0FfecgFnVvv8u-WwENFRxloOpfCBkqeU,13529
27
- konfai/network/network.py,sha256=psDlOiAAaBgCOgwB8ld38YzMnw80DlGJorJoOn7yslM,45463
26
+ konfai/network/blocks.py,sha256=RaTI0Lrvq1V-GIFei-WTUB6wlg4LydZksAyJ8DMk40M,13502
27
+ konfai/network/network.py,sha256=DTROmwwPJZQllJx32GwAixfy1E23F1mtYy1bpF23TjU,46538
28
28
  konfai/utils/ITK.py,sha256=OxTieDNNYHGkn7zxJsAG-6ecRG1VYMvn1dlBbBe1DOs,13955
29
29
  konfai/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
30
- konfai/utils/config.py,sha256=8b6gCo6lh4ckdQI1aXReOtBjA3elAm4XfcsO545ssjE,12062
30
+ konfai/utils/config.py,sha256=OsCUqgNaPX56nD_MpJ2WRpiQ7xxqHonfPYaS9VGtsl8,12413
31
31
  konfai/utils/dataset.py,sha256=RNULLVauhYiyeQkpVhOVq4VRwI4KRyoQA2VCmOh8VbE,35525
32
32
  konfai/utils/registration.py,sha256=v1srEBOcgDnHrx0YtsK6bcj0yCMH7wNeaQ3wC7gEvOw,8898
33
- konfai/utils/utils.py,sha256=lirj-04jCNTIpZieCFSalK-iVsKQacvddOjF9sIf-Do,20558
34
- konfai-1.0.9.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
35
- konfai-1.0.9.dist-info/METADATA,sha256=_nzxf8r20-oXYHIIuyzvn1FyIlL07DNuSc_mnZlV34g,2515
36
- konfai-1.0.9.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
37
- konfai-1.0.9.dist-info/entry_points.txt,sha256=fG82HRN5-g39ACSOCtij_I3N6EHxfYnMR0D7TI_8pW8,81
38
- konfai-1.0.9.dist-info/top_level.txt,sha256=xF470dkIlFoFqTZEOlRehKJr4WU_8OKGXrJqYm9vWKs,7
39
- konfai-1.0.9.dist-info/RECORD,,
33
+ konfai/utils/utils.py,sha256=GPzQQ7kBmraHytsUdPOqwW0Qu-kAHGA-wwMd4obdVLE,22018
34
+ konfai-1.1.1.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
35
+ konfai-1.1.1.dist-info/METADATA,sha256=mMT6fvoAbwoMkChW6G_Y2vDhrtiMss86xezEnylzLrE,2515
36
+ konfai-1.1.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
37
+ konfai-1.1.1.dist-info/entry_points.txt,sha256=fG82HRN5-g39ACSOCtij_I3N6EHxfYnMR0D7TI_8pW8,81
38
+ konfai-1.1.1.dist-info/top_level.txt,sha256=xF470dkIlFoFqTZEOlRehKJr4WU_8OKGXrJqYm9vWKs,7
39
+ konfai-1.1.1.dist-info/RECORD,,