konfai 1.1.0__tar.gz → 1.1.1__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of konfai might be problematic.

Files changed (46)
  1. {konfai-1.1.0 → konfai-1.1.1}/PKG-INFO +1 -1
  2. {konfai-1.1.0 → konfai-1.1.1}/konfai/data/augmentation.py +2 -2
  3. {konfai-1.1.0 → konfai-1.1.1}/konfai/data/data_manager.py +96 -16
  4. {konfai-1.1.0 → konfai-1.1.1}/konfai/data/patching.py +4 -1
  5. {konfai-1.1.0 → konfai-1.1.1}/konfai/data/transform.py +0 -1
  6. {konfai-1.1.0 → konfai-1.1.1}/konfai/main.py +4 -3
  7. {konfai-1.1.0 → konfai-1.1.1}/konfai/models/registration/registration.py +0 -1
  8. {konfai-1.1.0 → konfai-1.1.1}/konfai/network/blocks.py +0 -1
  9. {konfai-1.1.0 → konfai-1.1.1}/konfai/network/network.py +29 -15
  10. {konfai-1.1.0 → konfai-1.1.1}/konfai/trainer.py +5 -5
  11. {konfai-1.1.0 → konfai-1.1.1}/konfai/utils/config.py +12 -12
  12. {konfai-1.1.0 → konfai-1.1.1}/konfai/utils/utils.py +61 -14
  13. {konfai-1.1.0 → konfai-1.1.1}/konfai.egg-info/PKG-INFO +1 -1
  14. {konfai-1.1.0 → konfai-1.1.1}/pyproject.toml +1 -1
  15. konfai-1.1.1/tests/test_config.py +110 -0
  16. konfai-1.1.0/tests/test_config.py +0 -18
  17. {konfai-1.1.0 → konfai-1.1.1}/LICENSE +0 -0
  18. {konfai-1.1.0 → konfai-1.1.1}/README.md +0 -0
  19. {konfai-1.1.0 → konfai-1.1.1}/konfai/__init__.py +0 -0
  20. {konfai-1.1.0 → konfai-1.1.1}/konfai/data/__init__.py +0 -0
  21. {konfai-1.1.0 → konfai-1.1.1}/konfai/evaluator.py +0 -0
  22. {konfai-1.1.0 → konfai-1.1.1}/konfai/metric/__init__.py +0 -0
  23. {konfai-1.1.0 → konfai-1.1.1}/konfai/metric/measure.py +0 -0
  24. {konfai-1.1.0 → konfai-1.1.1}/konfai/metric/schedulers.py +0 -0
  25. {konfai-1.1.0 → konfai-1.1.1}/konfai/models/classification/convNeXt.py +0 -0
  26. {konfai-1.1.0 → konfai-1.1.1}/konfai/models/classification/resnet.py +0 -0
  27. {konfai-1.1.0 → konfai-1.1.1}/konfai/models/generation/cStyleGan.py +0 -0
  28. {konfai-1.1.0 → konfai-1.1.1}/konfai/models/generation/ddpm.py +0 -0
  29. {konfai-1.1.0 → konfai-1.1.1}/konfai/models/generation/diffusionGan.py +0 -0
  30. {konfai-1.1.0 → konfai-1.1.1}/konfai/models/generation/gan.py +0 -0
  31. {konfai-1.1.0 → konfai-1.1.1}/konfai/models/generation/vae.py +0 -0
  32. {konfai-1.1.0 → konfai-1.1.1}/konfai/models/representation/representation.py +0 -0
  33. {konfai-1.1.0 → konfai-1.1.1}/konfai/models/segmentation/NestedUNet.py +0 -0
  34. {konfai-1.1.0 → konfai-1.1.1}/konfai/models/segmentation/UNet.py +0 -0
  35. {konfai-1.1.0 → konfai-1.1.1}/konfai/network/__init__.py +0 -0
  36. {konfai-1.1.0 → konfai-1.1.1}/konfai/predictor.py +0 -0
  37. {konfai-1.1.0 → konfai-1.1.1}/konfai/utils/ITK.py +0 -0
  38. {konfai-1.1.0 → konfai-1.1.1}/konfai/utils/__init__.py +0 -0
  39. {konfai-1.1.0 → konfai-1.1.1}/konfai/utils/dataset.py +0 -0
  40. {konfai-1.1.0 → konfai-1.1.1}/konfai/utils/registration.py +0 -0
  41. {konfai-1.1.0 → konfai-1.1.1}/konfai.egg-info/SOURCES.txt +0 -0
  42. {konfai-1.1.0 → konfai-1.1.1}/konfai.egg-info/dependency_links.txt +0 -0
  43. {konfai-1.1.0 → konfai-1.1.1}/konfai.egg-info/entry_points.txt +0 -0
  44. {konfai-1.1.0 → konfai-1.1.1}/konfai.egg-info/requires.txt +0 -0
  45. {konfai-1.1.0 → konfai-1.1.1}/konfai.egg-info/top_level.txt +0 -0
  46. {konfai-1.1.0 → konfai-1.1.1}/setup.cfg +0 -0
PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: konfai
-Version: 1.1.0
+Version: 1.1.1
 Summary: Modular and configurable Deep Learning framework with YAML and PyTorch
 Author-email: Valentin Boussot <boussot.v@gmail.com>
 License-Expression: Apache-2.0
konfai/data/augmentation.py
@@ -8,7 +8,7 @@ from typing import Union
 import os
 from konfai import KONFAI_ROOT
 from konfai.utils.config import config
-from konfai.utils.utils import _getModule
+from konfai.utils.utils import _getModule, AugmentationError
 from konfai.utils.dataset import Attribute, data_to_image


@@ -222,7 +222,7 @@ class ColorTransform(DataAugmentation):
                 matrix = matrix[:, :3, :].mean(dim=1, keepdims=True).to(input.device)
                 result = result.float() * matrix[:, :, :3].sum(dim=2, keepdims=True) + matrix[:, :, 3:]
             else:
-                raise ValueError('Image must be RGB (3 channels) or L (1 channel)')
+                raise AugmentationError('Image must be RGB (3 channels) or L (1 channel)')
             results.append(result.reshape(input.shape))
         return results

konfai/data/data_manager.py
@@ -11,11 +11,12 @@ from typing import Union, Iterator
 from concurrent.futures import ThreadPoolExecutor, as_completed
 import threading
 from torch.cuda import device_count
+import SimpleITK as sitk

 from konfai import KONFAI_STATE, KONFAI_ROOT
 from konfai.data.patching import DatasetPatch, DatasetManager
 from konfai.utils.config import config
-from konfai.utils.utils import memoryInfo, cpuInfo, memoryForecast, getMemory, State
+from konfai.utils.utils import memoryInfo, cpuInfo, memoryForecast, getMemory, State, SUPPORTED_EXTENSIONS, DatasetManagerError
 from konfai.utils.dataset import Dataset, Attribute
 from konfai.data.transform import TransformLoader, Transform
 from konfai.data.augmentation import DataAugmentationsList
@@ -236,6 +237,8 @@ class Subset():
             index = random.sample(index, len(index))
         return set([names_filtred[i] for i in index])

+    def __str__(self):
+        return "Subset : " + str(self.subset) + " shuffle : "+ str(self.shuffle) + " filter : "+ str(self.filter)
 class TrainSubset(Subset):

     @config()
@@ -290,6 +293,13 @@ class Data(ABC):
                 map.append((x, y, z))
         return data, map

+    def getGroupsDest(self):
+        groupsDest = []
+        for group_src in self.groups_src:
+            for group_dest in self.groups_src[group_src]:
+                groupsDest.append(group_dest)
+        return groupsDest
+
     def _split(map: list[tuple[int, int, int]], world_size: int) -> list[list[tuple[int, int, int]]]:
         if len(map) == 0:
             return [[] for _ in range(world_size)]
@@ -313,7 +323,14 @@ class Data(ABC):

     def getData(self, world_size: int) -> list[list[DataLoader]]:
         datasets: dict[str, list[(str, bool)]] = {}
+        if self.dataset_filenames is None or len(self.dataset_filenames) == 0:
+            raise DatasetManagerError("No dataset filenames were provided")
         for dataset_filename in self.dataset_filenames:
+            if dataset_filename is None:
+                raise DatasetManagerError("Invalid dataset entry: 'None' received.",
+                    "Each dataset must be a valid path string (e.g., './Dataset/', './Dataset/:mha, './Dataset/:a:mha', './Dataset/:i:mha').",
+                    "Please check your 'dataset_filenames' list for missing or null entries."
+                )
             if len(dataset_filename.split(":")) == 1:
                 filename = dataset_filename
                 format = "mha"
@@ -324,9 +341,13 @@ class Data(ABC):
             else:
                 filename, flag, format = dataset_filename.split(":")
                 append = flag == "a"
-
-            dataset = Dataset(filename, format)
+
+            if format not in SUPPORTED_EXTENSIONS:
+                raise DatasetManagerError(f"Unsupported file format '{format}'.",
+                    f"Supported extensions are: {', '.join(SUPPORTED_EXTENSIONS)}")

+            dataset = Dataset(filename, format)
+
             self.datasets[filename] = dataset
             for group in self.groups_src:
                 if dataset.isGroupExist(group):
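Note on the entry syntax validated above: based on the branches in this hunk, a 'dataset_filenames' entry is 'path', 'path:format', or 'path:flag:format', where the format defaults to "mha" and the flag "a" selects append mode. A minimal illustrative sketch (not part of the package; the two-part form defaulting to no append flag is an assumption):

def parse_dataset_entry(entry: str) -> tuple[str, bool, str]:
    # "path" -> default format "mha"; "path:format"; "path:flag:format".
    parts = entry.split(":")
    if len(parts) == 1:
        return parts[0], False, "mha"
    if len(parts) == 2:
        return parts[0], False, parts[1]  # assumed: no append flag given
    filename, flag, fmt = parts
    return filename, flag == "a", fmt

print(parse_dataset_entry("./Dataset/:a:mha"))  # ('./Dataset/', True, 'mha')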
@@ -334,16 +355,30 @@ class Data(ABC):
                     datasets[group].append((filename, append))
                 else:
                     datasets[group] = [(filename, append)]
+        modelHaveInput = False
         for group_src in self.groups_src:
             if group_src not in datasets:
-                raise ValueError("[DatasetManager] Error: group source {} not found. Available groups: {}".format(group_src, list(datasets.keys())))

+                raise DatasetManagerError(
+                    f"Group source '{group_src}' not found in any dataset.",
+                    f"Dataset filenames provided: {self.dataset_filenames}",
+                    f"Available groups across all datasets: {sorted(list(datasets.keys()))}",
+                    f"Please check that an entry in the dataset with the name '{group_src}.{format}' exists."
+                )
+
             for group_dest in self.groups_src[group_src]:
                 self.groups_src[group_src][group_dest].load(group_src, group_dest, [self.datasets[filename] for filename, _ in datasets[group_src]])
+                modelHaveInput |= self.groups_src[group_src][group_dest].isInput
+
+        if not modelHaveInput:
+            raise DatasetManagerError(
+                "At least one group must be defined with 'isInput: true' to provide input to the network."
+            )
+
         for key, dataAugmentations in self.dataAugmentationsList.items():
             dataAugmentations.load(key)

-        names : list[list[str]] = []
+        names = set()
         dataset_name : dict[str, dict[str, list[str]]] = {}
         dataset_info : dict[str, dict[str, dict[str, Attribute]]] = {}
         for group in self.groups_src:
@@ -351,10 +386,11 @@ class Data(ABC):
             dataset_name[group] = {}
             dataset_info[group] = {}
             for filename, _ in datasets[group]:
-                names.append(self.datasets[filename].getNames(group))
+                names.update(self.datasets[filename].getNames(group))
                 dataset_name[group][filename] = self.datasets[filename].getNames(group)
                 dataset_info[group][filename] = {name: self.datasets[filename].getInfos(group, name) for name in dataset_name[group][filename]}
         subset_names = set()
+
         if isinstance(self.subset, dict):
             for filename, subset in self.subset.items():
                 subset_names.update(subset([dataset_name[group][filename] for group in dataset_name], [dataset_info[group][filename] for group in dataset_name]))
@@ -368,13 +404,29 @@ class Data(ABC):
                     subset_names.update(self.subset([dataset_name[group][filename]], [dataset_info[group][filename]]))
                 else:
                     subset_names.intersection(self.subset([dataset_name[group][filename]], [dataset_info[group][filename]]))
+        if len(subset_names) == 0:
+            raise DatasetManagerError("All data entries were excluded by the subset filter.",
+                f"Dataset entries found: {', '.join(names)}",
+                f"Subset object applied: {self.subset}",
+                f"Subset requested : {', '.join(subset_names)}",
+                "None of the dataset entries matched the given subset.",
+                "Please check your 'subset' configuration — it may be too restrictive or incorrectly formatted.",
+                "Examples of valid subset formats:",
+                "\tsubset: [0, 1] # explicit indices",
+                "\tsubset: 0:10 # slice notation",
+                "\tsubset: ./Validation.txt # external file",
+                "\tsubset: None # to disable filtering"
+            )
+
         data, map = self._getDatasets(list(subset_names), dataset_name)
-
+
         train_map = map
         validate_map = []
-        if isinstance(self.validation, float):
-            if self.validation < 1.0 and int(math.floor(len(map)*(1-self.validation))) > 0:
-                train_map, validate_map = map[:int(math.floor(len(map)*self.validation))], map[int(math.floor(len(map)*self.validation)):]
+        if isinstance(self.validation, float) or isinstance(self.validation, int):
+            if self.validation <= 0 or self.validation >= 1:
+                raise DatasetManagerError("validation must be a float between 0 and 1.", f"→ Received: {self.validation}", "→ Example: validation = 0.2 # for a 20% validation split")
+
+            train_map, validate_map = map[:int(math.floor(len(map)*(1-self.validation)))], map[int(math.floor(len(map)*(1-self.validation))):]
         elif isinstance(self.validation, str):
             if ":" in self.validation:
                 index = list(range(int(self.subset.split(":")[0]), int(self.subset.split(":")[1])))
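The rewritten split above rejects out-of-range values and now reserves the trailing 'validation' fraction of the map for validation. A small worked example of the new arithmetic (illustrative only):

import math

map_entries = list(range(10))   # pretend dataset map with 10 entries
validation = 0.2                # the new DataTrain default
cut = int(math.floor(len(map_entries) * (1 - validation)))
train_map, validate_map = map_entries[:cut], map_entries[cut:]
print(len(train_map), len(validate_map))  # 8 2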
@@ -389,7 +441,17 @@ class Data(ABC):
                 train_map = [m for m in map if m[0] not in index]
                 validate_map = [m for m in map if m[0] in index]
             else:
-                validate_map = train_map
+                raise DatasetManagerError(
+                    f"Invalid string value for 'validation': '{self.validation}'",
+                    "Expected one of the following formats:",
+                    "\t• A slice string like '0:10'",
+                    "\t• A path to a text file listing validation sample names (e.g., './val.txt')",
+                    "\t• A float between 0 and 1 (e.g., 0.2)",
+                    "\t• A list of sample names or indices",
+                    "The provided value is neither a valid slice nor a readable file.",
+                    "Please fix your 'validation' setting in the configuration."
+                )
+
         elif isinstance(self.validation, list):
             if len(self.validation) > 0:
                 if isinstance(self.validation[0], int):
@@ -399,10 +461,29 @@ class Data(ABC):
                     index = [i for i, n in enumerate(subset_names) if n in self.validation]
                     train_map = [m for m in map if m[0] not in index]
                     validate_map = [m for m in map if m[0] in index]
-
+                else:
+                    raise DatasetManagerError(f"Invalid list type for 'validation': elements of type '{type(self.validation[0]).__name__}' are not supported.",
+                        "Supported list element types are:",
+                        "\t• int → list of indices (e.g., [0, 1, 2])",
+                        "\t• str → list of sample names (e.g., ['patient01', 'patient02'])",
+                        f"Received list: {self.validation}"
+                    )
+        if len(train_map) == 0:
+            raise DatasetManagerError("No data left for training after applying the validation split.",
+                f"Dataset size: {len(map)}",
+                f"Validation setting: {self.validation}",
+                "Please reduce the validation size, increase the dataset, or disable validation."
+            )
+
+        if self.validation is not None and len(validate_map) == 0:
+            raise DatasetManagerError("No data left for validation after applying the validation split.",
+                f"Dataset size: {len(map)}",
+                f"Validation setting: {self.validation}",
+                "Please increase the validation size, increase the dataset, or disable validation."
+            )
         train_maps = Data._split(train_map, world_size)
         validate_maps = Data._split(validate_map, world_size)
-
+
         for i, (train_map, validate_map) in enumerate(zip(train_maps, validate_maps)):
             maps = [train_map]
             if len(validate_map):
@@ -436,7 +517,7 @@ class DataTrain(Data):
                 subset : Union[TrainSubset, dict[str, TrainSubset]] = TrainSubset(),
                 num_workers : int = 4,
                 batch_size : int = 1,
-                validation : Union[float, str, list[int], list[str]] = 0.8) -> None:
+                validation : Union[float, str, list[int], list[str]] = 0.2) -> None:
         super().__init__(dataset_filenames, groups_src, patch, use_cache, subset, num_workers, batch_size, validation, inlineAugmentations, augmentations if augmentations else {})

 class DataPrediction(Data):
@@ -445,14 +526,13 @@ class DataPrediction(Data):
     def __init__(self, dataset_filenames : list[str] = ["default:./Dataset"],
                 groups_src : dict[str, Group] = {"default" : Group()},
                 augmentations : Union[dict[str, DataAugmentationsList], None] = {"DataAugmentation_0" : DataAugmentationsList()},
-                inlineAugmentations: bool = False,
                 patch : Union[DatasetPatch, None] = DatasetPatch(),
                 use_cache : bool = True,
                 subset : Union[PredictionSubset, dict[str, PredictionSubset]] = PredictionSubset(),
                 num_workers : int = 4,
                 batch_size : int = 1) -> None:

-        super().__init__(dataset_filenames, groups_src, patch, use_cache, subset, num_workers, batch_size, inlineAugmentations=inlineAugmentations, dataAugmentationsList=augmentations if augmentations else {})
+        super().__init__(dataset_filenames, groups_src, patch, use_cache, subset, num_workers, batch_size, dataAugmentationsList=augmentations if augmentations else {})

 class DataMetric(Data):

konfai/data/patching.py
@@ -150,7 +150,10 @@ class Patch(ABC):

     def __init__(self, patch_size: list[int], overlap: Union[int, None], path_mask: Union[str, None] = None, padValue: float = 0, extend_slice: int = 0) -> None:
         self.patch_size = patch_size
-        self.overlap = overlap
+        self.overlap = overlap
+        if isinstance(self.overlap, int):
+            if self.overlap < 0:
+                self.overlap = None
         self._patch_slices : dict[int, list[tuple[slice]]] = {}
         self._nb_patch_per_dim: dict[int, list[tuple[int, bool]]] = {}
         self.path_mask = path_mask
konfai/data/transform.py
@@ -529,7 +529,6 @@ class OneHot(Transform):

     def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
         result = F.one_hot(input.type(torch.int64), num_classes=self.num_classes).permute(0, len(input.shape), *[i+1 for i in range(len(input.shape)-1)]).float().squeeze(2)
-        print(result.shape)
         return result

     def inverse(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
konfai/main.py
@@ -18,9 +18,8 @@ def main():
         distributedObject.setup(world_size)
         with TensorBoard(distributedObject.name):
             mp.spawn(distributedObject, nprocs=world_size)
-    except Exception as e:
-        print(e)
-        exit(1)
+    except KeyboardInterrupt:
+        print("\n[KonfAI] Manual interruption (Ctrl+C)")


 def cluster():
@@ -47,6 +46,8 @@ def cluster():
         executor.update_parameters(name=config["name"], mem_gb=config["memory"], gpus_per_node=n_gpu, tasks_per_node=n_gpu//distributedObject.size, cpus_per_task=config["num_workers"], nodes=config["num_nodes"], timeout_min=config["time_limit"])
         with TensorBoard(distributedObject.name):
             executor.submit(distributedObject)
+    except KeyboardInterrupt:
+        print("\n[KonfAI] Manual interruption (Ctrl+C)")
     except Exception as e:
         print(e)
         exit(1)
konfai/models/registration/registration.py
@@ -94,7 +94,6 @@ class SpatialTransformer(torch.nn.Module):
             new_locs[:, 1,1] = 1
             new_locs[:, 0,2] = flow[:, 0]
             new_locs[:, 1,2] = flow[:, 1]
-            print(new_locs)
             return F.grid_sample(src, F.affine_grid(new_locs, src.size()), align_corners=True, mode="bilinear")
         else:
             new_locs = self.grid + flow
konfai/network/blocks.py
@@ -176,7 +176,6 @@ class Print(torch.nn.Module):
         super().__init__()

     def forward(self, input: torch.Tensor) -> torch.Tensor:
-        print(input.shape)
         return input

 class Write(torch.nn.Module):
konfai/network/network.py
@@ -16,7 +16,7 @@ from enum import Enum
 from konfai import KONFAI_ROOT
 from konfai.metric.schedulers import Scheduler
 from konfai.utils.config import config
-from konfai.utils.utils import State, _getModule, getDevice, getGPUMemory
+from konfai.utils.utils import State, _getModule, getDevice, getGPUMemory, MeasureError
 from konfai.data.patching import Accumulator, ModelPatch

 class NetState(Enum):
@@ -54,12 +54,12 @@ class LRSchedulersLoader():
     def __init__(self, params: dict[str, SchedulerStep] = {"default:ReduceLROnPlateau" : SchedulerStep(0)}) -> None:
         self.params = params

-    def getShedulers(self, key: str, optimizer: torch.optim.Optimizer) -> dict[torch.optim.lr_scheduler._LRScheduler, int]:
-        shedulers : dict[torch.optim.lr_scheduler._LRScheduler, int] = {}
+    def getschedulers(self, key: str, optimizer: torch.optim.Optimizer) -> dict[torch.optim.lr_scheduler._LRScheduler, int]:
+        schedulers : dict[torch.optim.lr_scheduler._LRScheduler, int] = {}
         for name, step in self.params.items():
             if name:
-                shedulers[config("Trainer.Model.{}.Schedulers.{}".format(key, name))(getattr(importlib.import_module('torch.optim.lr_scheduler'), name))(optimizer, config = None)] = step.nb_step
-        return shedulers
+                schedulers[config("Trainer.Model.{}.Schedulers.{}".format(key, name))(getattr(importlib.import_module('torch.optim.lr_scheduler'), name))(optimizer, config = None)] = step.nb_step
+        return schedulers

 class SchedulersLoader():

@@ -67,12 +67,12 @@ class SchedulersLoader():
     def __init__(self, params: dict[str, SchedulerStep] = {"default:Constant" : SchedulerStep(0)}) -> None:
         self.params = params

-    def getShedulers(self, key: str) -> dict[torch.optim.lr_scheduler._LRScheduler, int]:
-        shedulers : dict[Scheduler, int] = {}
+    def getschedulers(self, key: str) -> dict[torch.optim.lr_scheduler._LRScheduler, int]:
+        schedulers : dict[Scheduler, int] = {}
         for name, step in self.params.items():
             if name:
-                shedulers[getattr(importlib.import_module("konfai.metric.schedulers"), name)(config = None, DL_args = key)] = step.nb_step
-        return shedulers
+                schedulers[getattr(importlib.import_module("konfai.metric.schedulers"), name)(config = None, DL_args = key)] = step.nb_step
+        return schedulers

 class CriterionsAttr():

@@ -98,7 +98,7 @@ class CriterionsLoader():
         for module_classpath, criterionsAttr in self.criterionsLoader.items():
             module, name = _getModule(module_classpath, "metric.measure")
             criterionsAttr.isTorchCriterion = module.startswith("torch")
-            criterionsAttr.sheduler = criterionsAttr.l.getShedulers("{}.Model.{}.outputsCriterions.{}.targetsCriterions.{}.criterionsLoader.{}".format(KONFAI_ROOT(), model_classname, output_group, target_group, module_classpath))
+            criterionsAttr.sheduler = criterionsAttr.l.getschedulers("{}.Model.{}.outputsCriterions.{}.targetsCriterions.{}.criterionsLoader.{}".format(KONFAI_ROOT(), model_classname, output_group, target_group, module_classpath))
             criterions[config("{}.Model.{}.outputsCriterions.{}.targetsCriterions.{}.criterionsLoader.{}".format(KONFAI_ROOT(), model_classname, output_group, target_group, module_classpath))(getattr(importlib.import_module(module), name))(config = None)] = criterionsAttr
         return criterions

@@ -154,10 +154,25 @@ class Measure():
             self.outputsCriterions[output_group.replace(":", ".")] = targetCriterionsLoader.getTargetsCriterions(output_group, model_classname)
         self._loss : dict[int, dict[str, Measure.Loss]] = {}

-    def init(self, model : torch.nn.Module) -> None:
+    def init(self, model : torch.nn.Module, group_dest: list[str]) -> None:
         outputs_group_rename = {}
+
+        modules = []
+        for i,_ in model.named_modules():
+            modules.append(i)
+
         for output_group in self.outputsCriterions.keys():
+            if output_group not in modules:
+                raise MeasureError(f"The output group '{output_group}' defined in 'outputsCriterions' does not correspond to any module in the model.",
+                    f"Available modules: {modules}",
+                    "Please check that the name matches exactly a submodule or output of your model architecture."
+                )
             for target_group in self.outputsCriterions[output_group]:
+                if target_group not in group_dest:
+                    raise MeasureError(
+                        f"The target_group '{target_group}' defined in 'outputsCriterions.{output_group}.targetsCriterions' was not found in the available destination groups.",
+                        "This target_group is expected for loss or metric computation, but was not loaded in 'group_dest'.",
+                        f"Please make sure that the group '{target_group}' is defined in 'Dataset:groups_src:...:groups_dest:'{target_group}'' and correctly loaded from the dataset.")
                 for criterion in self.outputsCriterions[output_group][target_group]:
                     if not self.outputsCriterions[output_group][target_group][criterion].isTorchCriterion:
                         outputs_group_rename[output_group] = criterion.init(model, output_group, target_group)
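Measure.init now validates every configured output group against the names yielded by model.named_modules(). An illustrative snippet (plain PyTorch, not package code) showing which names that iterator produces:

import torch

model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU())
# named_modules() yields (name, module) pairs, starting with '' for the root.
print([name for name, _ in model.named_modules()])  # ['', '0', '1']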
@@ -703,10 +718,10 @@ class Network(ModuleArgsDict, ABC):
         return in_channels, in_is_channel, out_channels, out_is_channel

     @_function_network()
-    def init(self, autocast : bool, state : State, key: str) -> None:
+    def init(self, autocast : bool, state : State, group_dest: list[str], key: str) -> None:
         if self.outputsCriterionsLoader:
             self.measure = Measure(key, self.outputsCriterionsLoader)
-            self.measure.init(self)
+            self.measure.init(self, group_dest)
         if state != State.PREDICTION:
             self.scaler = torch.amp.GradScaler("cuda", enabled=autocast)
             if self.optimizerLoader:
@@ -714,7 +729,7 @@ class Network(ModuleArgsDict, ABC):
                 self.optimizer.zero_grad()

         if self.LRSchedulersLoader and self.optimizer:
-            self.schedulers = self.LRSchedulersLoader.getShedulers(key, self.optimizer)
+            self.schedulers = self.LRSchedulersLoader.getschedulers(key, self.optimizer)

     def initialized(self):
         pass
@@ -880,7 +895,6 @@ class Network(ModuleArgsDict, ABC):
         if scheduler:
             if scheduler.__class__.__name__ == 'ReduceLROnPlateau':
                 if self.measure:
-                    print(sum(self.measure.getLastValues(0).values()))
                     scheduler.step(sum(self.measure.getLastValues(0).values()))
             else:
                 scheduler.step()
konfai/trainer.py
@@ -14,7 +14,7 @@ import torch.distributed as dist
 from konfai import MODELS_DIRECTORY, CHECKPOINTS_DIRECTORY, STATISTICS_DIRECTORY, SETUPS_DIRECTORY, CONFIG_FILE, MODEL, DATE, KONFAI_STATE
 from konfai.data.data_manager import DataTrain
 from konfai.utils.config import config
-from konfai.utils.utils import State, DataLog, DistributedObject, description
+from konfai.utils.utils import State, DataLog, DistributedObject, description, TrainerError
 from konfai.network.network import Network, ModelLoader, NetState, CPU_Model


@@ -53,8 +53,8 @@ class EarlyStopping(EarlyStoppingBase):
             return super().getScore(values)
         for v in self.monitor:
             if v not in values.keys():
-                raise ValueError(
-                    "[EarlyStopping] Metric '{}' specified in `monitor` not found in logged values. "
+                raise TrainerError(
+                    "Metric '{}' specified in EarlyStopping.monitor not found in logged values. ",
                     "Available keys: {}. Please check your configuration.".format(v, list(values.keys())))
         return sum([i for v, i in values.items() if v in self.monitor])

@@ -68,7 +68,7 @@ class EarlyStopping(EarlyStoppingBase):
         elif self.mode == "max":
             improvement = current_score - self.best_score
         else:
-            raise ValueError("Mode must be 'min' or 'max'.")
+            raise TrainerError("Mode must be 'min' or 'max'.")

         if improvement > self.min_delta:
             self.best_score = current_score
@@ -391,7 +391,7 @@ class Trainer(DistributedObject):
         if state != State.TRAIN:
             state_dict = self._load()

-        self.model.init(self.autocast, state)
+        self.model.init(self.autocast, state, self.dataset.getGroupsDest())
         self.model.init_outputsGroup()
         self.model._compute_channels_trace(self.model, self.model.in_channels, self.gradient_checkpoints, self.gpu_checkpoints)
         self.model.load(state_dict, init=True, ema=False)
konfai/utils/config.py
@@ -6,14 +6,10 @@ from copy import deepcopy
 from typing import Union, Literal, get_origin, get_args
 import torch
 from konfai import CONFIG_FILE
+from konfai.utils.utils import ConfigError

 yaml = ruamel.yaml.YAML()

-class ConfigError(Exception):
-
-    def __init__(self, message : str = "The config only supports types : config(Object), int, str, bool, float, list[int], list[str], list[bool], list[float], dict[str, Object]") -> None:
-        self.message = message
-        super().__init__(self.message)

 class Config():

@@ -187,9 +183,8 @@ def config(key : Union[str, None] = None):
                 default_value = param.default if param.default != inspect._empty else allowed_values[0]
                 value = config.getValue(param.name, f"default:{default_value}")
                 if value not in allowed_values:
-                    raise ValueError(
-                        f"[Config] Invalid value '{value}' for parameter '{param.name}'. "
-                        f"Expected one of: {allowed_values}."
+                    raise ConfigError(
+                        f"Invalid value '{value}' for parameter '{param.name} expected one of: {allowed_values}."
                     )
                 kwargs[param.name] = value
                 continue
@@ -222,21 +217,26 @@ def config(key : Union[str, None] = None):
                     values = config.getValue(param.name, param.default)
                     kwargs[param.name] = values
                 else:
-                    raise ConfigError()
+                    raise ConfigError("Config: The config only supports types : config(Object), int, str, bool, float, list[int], list[str], list[bool], list[float], dict[str, Object]")
             elif str(annotation).startswith("dict"):
                 if annotation.__args__[0] == str:
                     values = config.getValue(param.name, param.default)
                     if values is not None and annotation.__args__[1] not in [int, str, bool, float]:
-                        kwargs[param.name] = {value : annotation.__args__[1](config = filename, DL_args = key_tmp+"."+param.name+"."+value) for value in values}
+                        try:
+                            kwargs[param.name] = {value : annotation.__args__[1](config = filename, DL_args = key_tmp+"."+param.name+"."+value) for value in values}
+                        except ValueError as e:
+                            raise ValueError(e)
+                        except Exception as e:
+                            raise ConfigError("{} {}".format(values, e))
                     else:
                         kwargs[param.name] = values
                 else:
-                    raise ConfigError()
+                    raise ConfigError("Config: The config only supports types : config(Object), int, str, bool, float, list[int], list[str], list[bool], list[float], dict[str, Object]")
             else:
                 try:
                     kwargs[param.name] = annotation(config = filename, DL_args = key_tmp)
                 except Exception as e:
-                    raise ValueError("[Config] Failed to instantiate {} with type {}, error {} ".format(param.name, annotation.__name__, e))
+                    raise ConfigError("Failed to instantiate {} with type {}, error {} ".format(param.name, annotation.__name__, e))

         if os.environ['KONFAI_CONFIG_VARIABLE'] == "True":
             os.environ['KONFAI_CONFIG_VARIABLE'] = "False"
konfai/utils/utils.py
@@ -145,8 +145,12 @@ def get_patch_slices_from_nb_patch_per_dim(patch_size_tmp: list[int], nb_patch_p

 def get_patch_slices_from_shape(patch_size: list[int], shape : list[int], overlap: Union[int, None]) -> tuple[list[tuple[slice]], list[tuple[int, bool]]]:
     if len(shape) != len(patch_size):
-        return [tuple([slice(0, s) for s in shape])], [(1, True)]*len(shape)
-
+        raise DatasetManagerError(
+            f"Dimension mismatch: 'patch_size' has {len(patch_size)} dimensions, but 'shape' has {len(shape)}.",
+            f"patch_size: {patch_size}",
+            f"shape: {shape}",
+            "Both must have the same number of dimensions (e.g., 3D patch for 3D volume)."
+        )
     patch_slices = []
     nb_patch_per_dim = []
     slices : list[list[slice]] = []
@@ -192,12 +196,8 @@ def _logImageFormat(input : np.ndarray):

     if len(input.shape) == 3 and input.shape[0] != 1:
         input = np.expand_dims(input, axis=0)
-
     if len(input.shape) == 4:
         input = input[:, input.shape[1]//2]
-
-    if input.dtype == np.uint8:
-        return input

     input = input.astype(float)
     b = -np.min(input)
@@ -325,7 +325,7 @@ class DistributedObject():
         return self

     def __exit__(self, type, value, traceback):
-        pass
+        cleanup()

     @abstractmethod
     def run_process(self, world_size: int, global_rank: int, local_rank: int, dataloaders: list[DataLoader]):
@@ -370,11 +370,12 @@ class DistributedObject():
         dataloaders = self.dataloader[global_rank]
         if torch.cuda.is_available():
             torch.cuda.set_device(local_rank)
-
-        self.run_process(world_size, global_rank, local_rank, dataloaders)
-        if torch.cuda.is_available():
-            pynvml.nvmlShutdown()
-        cleanup()
+        try:
+            self.run_process(world_size, global_rank, local_rank, dataloaders)
+        finally:
+            cleanup()
+            if torch.cuda.is_available():
+                pynvml.nvmlShutdown()

 def setup(parser: argparse.ArgumentParser) -> DistributedObject:
     # KONFAI arguments
@@ -460,7 +461,7 @@ def setupGPU(world_size: int, port: int, rank: Union[int, None] = None) -> tuple
     local_rank = rank
     if global_rank >= world_size:
         return None, None
-    print("tcp://{}:{}".format(host_name, port))
+    #print("tcp://{}:{}".format(host_name, port))
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
         dist.init_process_group("nccl", rank=global_rank, init_method="tcp://{}:{}".format(host_name, port), world_size=world_size)
@@ -476,7 +477,7 @@ def find_free_port():
     return s.getsockname()[1]

 def cleanup():
-    if torch.cuda.is_available():
+    if dist.is_initialized():
         dist.destroy_process_group()

 def synchronize_data(world_size: int, gpu: int, data: any) -> list[Any]:
@@ -506,3 +507,49 @@ def _resample_affine(data: torch.Tensor, matrix: torch.Tensor):
     else:
         mode = "bilinear"
     return F.grid_sample(data.unsqueeze(0).type(torch.float32), F.affine_grid(matrix[:, :-1,...].type(torch.float32), [1]+list(data.shape), align_corners=True), align_corners=True, mode=mode, padding_mode="reflection").squeeze(0).type(data.dtype)
+
+
+SUPPORTED_EXTENSIONS = [
+    "mha", "mhd",                 # MetaImage
+    "nii", "nii.gz",              # NIfTI
+    "nrrd", "nrrd.gz",            # NRRD
+    "gipl", "gipl.gz",            # GIPL
+    "hdr", "img",                 # Analyze
+    "dcm",                        # DICOM (if GDCM is enabled)
+    "tif", "tiff",                # TIFF
+    "png", "jpg", "jpeg", "bmp",  # 2D formats
+    "h5", "itk.txt", ".fcsv", ".xml", ".vtk", ".npy"
+]
+
+class KonfAIError(Exception):
+
+    def __init__(self, typeError: str, message: list[str]) -> None:
+        super().__init__("\n[{}] {}".format(typeError, message[0])+"\n→\t".join(message[1:]))
+
+
+class ConfigError(KonfAIError):
+
+    def __init__(self, *message) -> None:
+        super().__init__("Config", message)
+
+
+class DatasetManagerError(KonfAIError):
+
+    def __init__(self, *message) -> None:
+        super().__init__("DatasetManager", message)
+
+class MeasureError(KonfAIError):
+
+    def __init__(self, *message) -> None:
+        super().__init__("Measure", message)
+
+class TrainerError(KonfAIError):
+
+    def __init__(self, *message) -> None:
+        super().__init__("Trainer", message)
+
+class AugmentationError(KonfAIError):
+
+    def __init__(self, *message) -> None:
+        super().__init__("Augmentation", message)
konfai.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: konfai
-Version: 1.1.0
+Version: 1.1.1
 Summary: Modular and configurable Deep Learning framework with YAML and PyTorch
 Author-email: Valentin Boussot <boussot.v@gmail.com>
 License-Expression: Apache-2.0
pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

 [project]
 name = "konfai"
-version = "1.1.0"
+version = "1.1.1"
 description = "Modular and configurable Deep Learning framework with YAML and PyTorch"
 readme = "README.md"
 requires-python = ">=3.8"
konfai-1.1.1/tests/test_config.py (new file)
@@ -0,0 +1,110 @@
+import pytest
+from pathlib import Path
+from konfai.utils.config import config
+
+# ---- Child class for nested objects ----
+class SubObject:
+
+    @config("SubObject")
+    def __init__(self, z: float = 0.5):
+        self.z = z
+
+# ---- Main class exercising every supported type ----
+class DummyAllTypes:
+    @config("MainTest")
+    def __init__(
+        self,
+        a: int = 1,
+        b: str = "hello",
+        c: bool = True,
+        d: float = 3.14,
+        e: list[int] = [1, 2, 3],
+        f: list[str] = ["a", "b"],
+        g: list[bool] = [True, False],
+        h: list[float] = [1.0, 2.0],
+        i: dict[str, int] = {"x": 1},
+        j: SubObject = None,
+        l: dict[str, SubObject] = {"A" : SubObject(), "B" : SubObject()}
+    ):
+        self.a = a
+        self.b = b
+        self.c = c
+        self.d = d
+        self.e = e
+        self.f = f
+        self.g = g
+        self.h = h
+        self.i = i
+        self.j = j
+        self.l = l
+
+@pytest.fixture
+def dummy_config_file(tmp_path: Path):
+    yaml_path = tmp_path / "config.yml"
+    yaml_content = """
+MainTest:
+  a: 42
+  b: test
+  c: false
+  d: 2.718
+  e: [10, 20]
+  f: [x, y]
+  g: [true, false, true]
+  h: [0.1, 0.2]
+  i: {"foo": 9}
+  j:
+    z: 9.99
+  l:
+    key1:
+      z: 7.7
+    key2:
+      z: 8.8
+"""
+    yaml_path.write_text(yaml_content)
+    return yaml_path
+
+def test_config_instantiation(monkeypatch):
+    dummy_config_file = "./tests/dummy_data/dummy_config.yml"
+    monkeypatch.setenv("KONFAI_CONFIG_FILE", str(dummy_config_file))
+    monkeypatch.setenv("KONFAI_CONFIG_MODE", "Done")
+
+    obj = DummyAllTypes(config=str(dummy_config_file))
+
+    assert obj.a == 42
+    assert obj.b == "test"
+    assert obj.c is False
+    assert abs(obj.d - 2.718) < 1e-6
+    assert obj.e == [10, 20]
+    assert obj.f == ["x", "y"]
+    assert obj.g == [True, False, True]
+    assert obj.h == [0.1, 0.2]
+    assert obj.i == {"foo": 9}
+    assert isinstance(obj.j, SubObject)
+    assert obj.j.z == 9.99
+    assert isinstance(obj.l, dict)
+    assert set(obj.l.keys()) == {"key1", "key2"}
+    assert obj.l["key1"].z == 7.7
+    assert obj.l["key2"].z == 8.8
+
+def test_config_create(monkeypatch):
+    dummy_config_file = "./tests/dummy_data/dummy_default_config.yml"
+    monkeypatch.setenv("KONFAI_CONFIG_FILE", str(dummy_config_file))
+    monkeypatch.setenv("KONFAI_CONFIG_MODE", "Done")
+
+    obj = DummyAllTypes(config=str(dummy_config_file))
+
+    assert obj.a == 1
+    assert obj.b == "hello"
+    assert obj.c is True
+    assert abs(obj.d - 3.14) < 1e-6
+    assert obj.e == [1,2,3]
+    assert obj.f == ["a", "b"]
+    assert obj.g == [True, False]
+    assert obj.h == [1, 2]
+    assert obj.i == {}
+    assert isinstance(obj.j, SubObject)
+    assert obj.j.z == 0.5
+    assert isinstance(obj.l, dict)
+    assert set(obj.l.keys()) == {"A", "B"}
+    assert obj.l["A"].z == 0.5
+    assert obj.l["B"].z == 0.5
konfai-1.1.0/tests/test_config.py (deleted)
@@ -1,18 +0,0 @@
-import pytest
-from konfai.utils.config import config
-
-class Dummy:
-
-    @config("Test")
-    def __init__(self, a: int = 1, b: str = "ok"):
-        self.a = a
-        self.b = b
-
-def test_config_instantiation(monkeypatch):
-    import os
-    os.environ['KONFAI_CONFIG_FILE'] = "./tests/dummy_data/dummy_config.yml"
-    os.environ['KONFAI_CONFIG_PATH'] = "Test"
-
-    dummy = Dummy()
-    assert dummy.a == 42
-    assert dummy.b == "hello"