konfai 1.1.5-py3-none-any.whl → 1.1.7-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of konfai might be problematic.

konfai/network/network.py CHANGED
@@ -16,7 +16,7 @@ from enum import Enum
  from konfai import KONFAI_ROOT
  from konfai.metric.schedulers import Scheduler
  from konfai.utils.config import config
- from konfai.utils.utils import State, _getModule, getDevice, getGPUMemory, MeasureError
+ from konfai.utils.utils import State, _getModule, getDevice, getGPUMemory, MeasureError, TrainerError
  from konfai.data.patching import Accumulator, ModelPatch

  class NetState(Enum):
@@ -41,50 +41,41 @@ class OptimizerLoader():
  def getOptimizer(self, key: str, parameter: Iterator[torch.nn.parameter.Parameter]) -> torch.optim.Optimizer:
  return config("{}.Model.{}.Optimizer".format(KONFAI_ROOT(), key))(getattr(importlib.import_module('torch.optim'), self.name))(parameter, config = None)

- class SchedulerStep():
-
- @config(None)
- def __init__(self, nb_step : int = 0) -> None:
- self.nb_step = nb_step

  class LRSchedulersLoader():

- @config("Schedulers")
- def __init__(self, params: dict[str, SchedulerStep] = {"default:ReduceLROnPlateau" : SchedulerStep(0)}) -> None:
- self.params = params
-
- def getschedulers(self, key: str, optimizer: torch.optim.Optimizer) -> dict[torch.optim.lr_scheduler._LRScheduler, int]:
- schedulers : dict[torch.optim.lr_scheduler._LRScheduler, int] = {}
- for name, step in self.params.items():
- if name:
- schedulers[config("Trainer.Model.{}.Schedulers.{}".format(key, name))(getattr(importlib.import_module('torch.optim.lr_scheduler'), name))(optimizer, config = None)] = step.nb_step
- return schedulers
-
- class SchedulersLoader():
+ @config()
+ def __init__(self, nb_step : int = 0) -> None:
+ self.nb_step = nb_step
+
+ def getschedulers(self, key: str, scheduler_classname: str, optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler._LRScheduler:
+ for m in ["torch.optim.lr_scheduler", "konfai.metric.schedulers"]:
+ module, name = _getModule(scheduler_classname, m)
+ if hasattr(importlib.import_module(module), name):
+ return config("{}.Model.{}.schedulers.{}".format(KONFAI_ROOT(), key, scheduler_classname))(getattr(importlib.import_module(module), name))(optimizer, config = None)
+ raise TrainerError("Unknown scheduler {}, tried importing from: 'torch.optim.lr_scheduler' and 'konfai.metric.schedulers', but no valid match was found. Check your YAML config or scheduler name spelling.".format(scheduler_classname))
+
+ class LossSchedulersLoader():

- @config("Schedulers")
- def __init__(self, params: dict[str, SchedulerStep] = {"default:Constant" : SchedulerStep(0)}) -> None:
- self.params = params
-
- def getschedulers(self, key: str) -> dict[torch.optim.lr_scheduler._LRScheduler, int]:
- schedulers : dict[Scheduler, int] = {}
- for name, step in self.params.items():
- if name:
- schedulers[getattr(importlib.import_module("konfai.metric.schedulers"), name)(config = None, DL_args = key)] = step.nb_step
- return schedulers
+ @config()
+ def __init__(self, nb_step : int = 0) -> None:
+ self.nb_step = nb_step
+
+ def getschedulers(self, key: str, scheduler_classname: str) -> torch.optim.lr_scheduler._LRScheduler:
+ return config("{}.{}".format(key, scheduler_classname))(getattr(importlib.import_module('konfai.metric.schedulers'), scheduler_classname))(config = None)

  class CriterionsAttr():

  @config()
- def __init__(self, l: SchedulersLoader = SchedulersLoader(), isLoss: bool = True, group: int = 0, stepStart: int = 0, stepStop: Union[int, None] = None, accumulation: bool = False) -> None:
- self.l = l
+ def __init__(self, schedulers: dict[str, LossSchedulersLoader] = {"default:Constant": LossSchedulersLoader(0)}, isLoss: bool = True, group: int = 0, stepStart: int = 0, stepStop: Union[int, None] = None, accumulation: bool = False) -> None:
+ self.schedulersLoader = schedulers
  self.isTorchCriterion = True
  self.isLoss = isLoss
  self.stepStart = stepStart
  self.stepStop = stepStop
  self.group = group
  self.accumulation = accumulation
- self.sheduler = None
+ self.schedulers: dict[Scheduler, int] = {}

  class CriterionsLoader():

@@ -95,9 +86,10 @@ class CriterionsLoader():
  def getCriterions(self, model_classname : str, output_group : str, target_group : str) -> dict[torch.nn.Module, CriterionsAttr]:
  criterions = {}
  for module_classpath, criterionsAttr in self.criterionsLoader.items():
- module, name = _getModule(module_classpath, "metric.measure")
+ module, name = _getModule(module_classpath, "konfai.metric.measure")
  criterionsAttr.isTorchCriterion = module.startswith("torch")
- criterionsAttr.sheduler = criterionsAttr.l.getschedulers("{}.Model.{}.outputsCriterions.{}.targetsCriterions.{}.criterionsLoader.{}".format(KONFAI_ROOT(), model_classname, output_group, target_group, module_classpath))
+ for scheduler_classname, schedulers in criterionsAttr.schedulersLoader.items():
+ criterionsAttr.schedulers[schedulers.getschedulers("{}.Model.{}.outputsCriterions.{}.targetsCriterions.{}.criterionsLoader.{}.schedulers".format(KONFAI_ROOT(), model_classname, output_group, target_group, module_classpath), scheduler_classname)] = schedulers.nb_step
  criterions[config("{}.Model.{}.outputsCriterions.{}.targetsCriterions.{}.criterionsLoader.{}".format(KONFAI_ROOT(), model_classname, output_group, target_group, module_classpath))(getattr(importlib.import_module(module), name))(config = None)] = criterionsAttr
  return criterions

@@ -195,7 +187,7 @@ class Measure():

  for criterion, criterionsAttr in self.outputsCriterions[output_group][target_group].items():
  if it >= criterionsAttr.stepStart and (criterionsAttr.stepStop is None or it <= criterionsAttr.stepStop):
- scheduler = self.update_scheduler(criterionsAttr.sheduler, it)
+ scheduler = self.update_scheduler(criterionsAttr.schedulers, it)
  self._loss[criterionsAttr.group]["{}:{}:{}".format(output_group, target_group, criterion.__class__.__name__)].add(scheduler.get_value(), criterion(output, *target))
  if training and len(np.unique([len(l) for l in self._loss[criterionsAttr.group].values() if l.accumulation and l.isLoss])) == 1:
  if criterionsAttr.isLoss:
@@ -535,7 +527,7 @@ class Network(ModuleArgsDict, ABC):
  def __init__( self,
  in_channels : int = 1,
  optimizer: Union[OptimizerLoader, None] = None,
- schedulers: Union[LRSchedulersLoader, None] = None,
+ schedulers: Union[dict[str, LRSchedulersLoader], None] = None,
  outputsCriterions: Union[dict[str, TargetCriterionsLoader], None] = None,
  patch : Union[ModelPatch, None] = None,
  nb_batch_per_step : int = 1,
@@ -549,7 +541,7 @@ class Network(ModuleArgsDict, ABC):
  self.optimizer : Union[torch.optim.Optimizer, None] = None

  self.LRSchedulersLoader = schedulers
- self.schedulers : Union[dict[torch.optim.lr_scheduler._LRScheduler, int], None] = None
+ self.schedulers : Union[dict[torch.optim.lr_scheduler._LRScheduler, int], None] = {}

  self.outputsCriterionsLoader = outputsCriterions
  self.measure : Union[Measure, None] = None
@@ -561,6 +553,7 @@ class Network(ModuleArgsDict, ABC):
  self.init_gain = init_gain
  self.dim = dim
  self._it = 0
+ self._nb_lr_update = 0
  self.outputsGroup : list[OutputsGroup]= []

  @_function_network()
@@ -655,9 +648,15 @@ class Network(ModuleArgsDict, ABC):
  else:
  model_state_dict[alias] = model_state_dict_tmp[alias]
  self.load_state_dict(model_state_dict)
-
- if "{}_optimizer_state_dict".format(name) in state_dict and self.optimizer:
- self.optimizer.load_state_dict(state_dict['{}_optimizer_state_dict'.format(name)])
+ if "{}_optimizer_state_dict".format(self.getName()) in state_dict and self.optimizer:
+ self.optimizer.load_state_dict(state_dict['{}_optimizer_state_dict'.format(self.getName())])
+ if "{}_it".format(self.getName()) in state_dict:
+ self._it = int(state_dict["{}_it".format(self.getName())])
+ if "{}_nb_lr_update".format(self.getName()) in state_dict:
+ self._nb_lr_update = int(state_dict["{}_nb_lr_update".format(self.getName())])
+ for scheduler in self.schedulers:
+ if scheduler.last_epoch == -1:
+ scheduler.last_epoch = self._nb_lr_update
  self.initialized()

  def _compute_channels_trace(self, module : ModuleArgsDict, in_channels : int, gradient_checkpoints: Union[list[str], None], gpu_checkpoints: Union[list[str], None], name: Union[str, None] = None, in_is_channel = True, out_channels : Union[int, None] = None, out_is_channel = True) -> tuple[int, bool, int, bool]:
@@ -730,7 +729,8 @@ class Network(ModuleArgsDict, ABC):
  self.optimizer.zero_grad()

  if self.LRSchedulersLoader and self.optimizer:
- self.schedulers = self.LRSchedulersLoader.getschedulers(key, self.optimizer)
+ for schedulers_classname, schedulers in self.LRSchedulersLoader.items():
+ self.schedulers[schedulers.getschedulers(key, schedulers_classname, self.optimizer)] = schedulers.nb_step

  def initialized(self):
  pass
@@ -878,6 +878,7 @@ class Network(ModuleArgsDict, ABC):
  model._requires_grad(list(self.measure.outputsCriterions.keys()))
  for loss in self.measure.getLoss():
  self.scaler.scale(loss / self.nb_batch_per_step).backward()
+
  if self._it % self.nb_batch_per_step == 0:
  self.scaler.step(self.optimizer)
  self.scaler.update()
@@ -886,19 +887,19 @@ class Network(ModuleArgsDict, ABC):

  @_function_network()
  def update_lr(self):
+ self._nb_lr_update+=1
  step = 0
  scheduler = None
- if self.schedulers:
- for scheduler, value in self.schedulers.items():
- if value is None or (self._it >= step and self._it < step+value):
- break
- step += value
+ for scheduler, value in self.schedulers.items():
+ if value is None or (self._nb_lr_update >= step and self._nb_lr_update < step+value):
+ break
+ step += value
  if scheduler:
  if scheduler.__class__.__name__ == 'ReduceLROnPlateau':
  if self.measure:
  scheduler.step(sum(self.measure.getLastValues(0).values()))
  else:
- scheduler.step()
+ scheduler.step()

  @_function_network()
  def getNetworks(self) -> Self:
@@ -916,6 +917,12 @@ class Network(ModuleArgsDict, ABC):
  v = Network.to(v, int(os.environ["device"]))
  else:
  v = v.to(getDevice(int(os.environ["device"])))
+ if isinstance(module, Network):
+ if module.optimizer is not None:
+ for state in module.optimizer.state.values():
+ for k, v in state.items():
+ if isinstance(v, torch.Tensor):
+ state[k] = v.to(getDevice(int(os.environ["device"])))
  return module

  def getName(self) -> str:
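Note: the hunk above adds device handling for the optimizer state when a Network is moved to a GPU. A minimal, self-contained sketch of the same pattern in plain PyTorch (the helper name and toy model below are stand-ins, not konfai code):

import torch

def move_optimizer_state(optimizer: torch.optim.Optimizer, device: torch.device) -> None:
    # Mirror the new branch: relocate every tensor held in the optimizer state.
    for state in optimizer.state.values():
        for k, v in state.items():
            if isinstance(v, torch.Tensor):
                state[k] = v.to(device)

model = torch.nn.Linear(4, 2)
opt = torch.optim.Adam(model.parameters())
model(torch.randn(1, 4)).sum().backward()
opt.step()  # populates Adam's exp_avg / exp_avg_sq buffers
move_optimizer_state(opt, torch.device("cpu"))

This pairs with the trainer change further down that loads checkpoints with map_location="cpu": optimizer buffers restored on CPU then have to follow the model to its device.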
@@ -944,7 +951,7 @@ class ModelLoader():

  @config("Model")
  def __init__(self, classpath : str = "default:segmentation.UNet.UNet") -> None:
- self.module, self.name = _getModule(classpath, "models")
+ self.module, self.name = _getModule(classpath, "konfai.models")

  def getModel(self, train : bool = True, DL_args: Union[str, None] = None, DL_without=["optimizer", "schedulers", "nb_batch_per_step", "init_type", "init_gain"]) -> Network:
  if not DL_args:
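Note: with this release the schedulers entry becomes a mapping from scheduler class name to a loader that only carries nb_step, and update_lr() walks that mapping to pick the scheduler active for the current LR update. A minimal sketch of that selection logic (the dict keys are string stand-ins for scheduler instances, not konfai code):

def pick_active_scheduler(schedulers: dict, nb_lr_update: int):
    # schedulers maps scheduler -> nb_step; None means "run until the end".
    step = 0
    active = None
    for scheduler, nb_step in schedulers.items():
        active = scheduler
        if nb_step is None or (step <= nb_lr_update < step + nb_step):
            break
        step += nb_step
    return active

# e.g. 10 LR updates on a warm-up scheduler, then ReduceLROnPlateau for the rest
chain = {"LinearLR": 10, "ReduceLROnPlateau": None}
assert pick_active_scheduler(chain, 3) == "LinearLR"
assert pick_active_scheduler(chain, 12) == "ReduceLROnPlateau"

This mirrors update_lr() above, including the new _nb_lr_update counter that is saved in the checkpoint and used to resume schedulers at the right position.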
konfai/predictor.py CHANGED
@@ -27,17 +27,19 @@ from collections import defaultdict

  class OutDataset(Dataset, NeedDevice, ABC):

- def __init__(self, filename: str, group: str, pre_transforms : dict[str, TransformLoader], post_transforms : dict[str, TransformLoader], final_transforms : dict[str, TransformLoader], patchCombine: Union[str, None]) -> None:
+ def __init__(self, filename: str, group: str, before_reduction_transforms : dict[str, TransformLoader], after_reduction_transforms : dict[str, TransformLoader], final_transforms : dict[str, TransformLoader], patchCombine: Union[str, None], reduction: str) -> None:
  filename, format = filename.split(":")
  super().__init__(filename, format)
  self.group = group
- self._pre_transforms = pre_transforms
- self._post_transforms = post_transforms
+ self._before_reduction_transforms = before_reduction_transforms
+ self._after_reduction_transforms = after_reduction_transforms
  self._final_transforms = final_transforms
  self._patchCombine = patchCombine
+ self.reduction_classpath = reduction
+ self.reduction = None

- self.pre_transforms : list[Transform] = []
- self.post_transforms : list[Transform] = []
+ self.before_reduction_transforms : list[Transform] = []
+ self.after_reduction_transforms : list[Transform] = []
  self.final_transforms : list[Transform] = []
  self.patchCombine: PathCombine = None

@@ -47,7 +49,7 @@ class OutDataset(Dataset, NeedDevice, ABC):
  self.nb_data_augmentation = 0

  def load(self, name_layer: str, datasets: list[Dataset], groups: dict[str, str]):
- transforms_type = ["pre_transforms", "post_transforms", "final_transforms"]
+ transforms_type = ["before_reduction_transforms", "after_reduction_transforms", "final_transforms"]
  for name, _transform_type, transform_type in [(k, getattr(self, "_{}".format(k)), getattr(self, k)) for k in transforms_type]:

  if _transform_type is not None:
@@ -57,8 +59,15 @@ class OutDataset(Dataset, NeedDevice, ABC):
  transform_type.append(transform)

  if self._patchCombine is not None:
- module, name = _getModule(self._patchCombine, "data.patching")
- self.patchCombine = getattr(importlib.import_module(module), name)(config = None, DL_args = "{}.outsDataset.{}.OutDataset".format(KONFAI_ROOT(), name_layer))
+ module, name = _getModule(self._patchCombine, "konfai.data.patching")
+ self.patchCombine = config("{}.outsDataset.{}.OutDataset".format(KONFAI_ROOT(), name_layer))(getattr(importlib.import_module(module), name))(config = None)
+
+ module, name = _getModule(self.reduction_classpath, "konfai.predictor")
+ if module == "konfai.predictor":
+ self.reduction = getattr(importlib.import_module(module), name)
+ else:
+ self.reduction = config("{}.outsDataset.{}.OutDataset.{}".format(KONFAI_ROOT(), name_layer, self.reduction_classpath))(getattr(importlib.import_module(module), name))(config = None)
+

  def setPatchConfig(self, patchSize: Union[list[int], None], overlap: Union[int, None], nb_data_augmentation: int) -> None:
  if patchSize is not None and overlap is not None:
@@ -70,14 +79,14 @@ class OutDataset(Dataset, NeedDevice, ABC):

  def setDevice(self, device: torch.device):
  super().setDevice(device)
- transforms_type = ["pre_transforms", "post_transforms", "final_transforms"]
+ transforms_type = ["before_reduction_transforms", "after_reduction_transforms", "final_transforms"]
  for transform_type in [(getattr(self, k)) for k in transforms_type]:
  if transform_type is not None:
  for transform in transform_type:
  transform.setDevice(device)

  @abstractmethod
- def addLayer(self, index: int, index_patch: int, layer: torch.Tensor, dataset: DatasetIter):
+ def addLayer(self, index_dataset: int, index_augmentation: int, index_patch: int, layer: torch.Tensor, dataset: DatasetIter):
  pass

  def isDone(self, index: int) -> bool:
@@ -96,25 +105,28 @@ class Reduction():
  def __init__(self):
  pass

- class ReductionMean():
+ class Mean(Reduction):

  def __init__(self):
  pass
-
- class ReductionMedian():
+
+ def __call__(self, result: torch.Tensor) -> torch.Tensor:
+ return torch.mean(result.float(), dim=0)
+
+ class Median(Reduction):

  def __init__(self):
  pass

-
+ def __call__(self, result: torch.Tensor) -> torch.Tensor:
+ return torch.median(result.float(), dim=0).values

  class OutSameAsGroupDataset(OutDataset):

  @config("OutDataset")
- def __init__(self, dataset_filename: str = "./Dataset:mha", group: str = "default", sameAsGroup: str = "default", pre_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, post_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, final_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, patchCombine: Union[str, None] = None, reduction: str = "mean", inverse_transform: bool = True) -> None:
- super().__init__(dataset_filename, group, pre_transforms, post_transforms, final_transforms, patchCombine)
+ def __init__(self, sameAsGroup: str = "default", dataset_filename: str = "default:./Dataset:mha", group: str = "default", before_reduction_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, after_reduction_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, final_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, patchCombine: Union[str, None] = None, reduction: str = "mean", inverse_transform: bool = True) -> None:
+ super().__init__(dataset_filename, group, before_reduction_transforms, after_reduction_transforms, final_transforms, patchCombine, reduction)
  self.group_src, self.group_dest = sameAsGroup.split(":")
- self.reduction = reduction
  self.inverse_transform = inverse_transform

  def addLayer(self, index_dataset: int, index_augmentation: int, index_patch: int, layer: torch.Tensor, dataset: DatasetIter):
@@ -131,9 +143,6 @@ class OutSameAsGroupDataset(OutDataset):
  for i in range(len(input_dataset.patch.getPatch_slices(index_augmentation))):
  self.attributes[index_dataset][index_augmentation][i] = Attribute(input_dataset.cache_attributes[0])

- for transform in self.pre_transforms:
- layer = transform(self.names[index_dataset], layer, self.attributes[index_dataset][index_augmentation][index_patch])
-
  if self.inverse_transform:
  for transform in reversed(dataset.groups_src[self.group_src][self.group_dest].post_transforms):
  layer = transform.inverse(self.names[index_dataset], layer, self.attributes[index_dataset][index_augmentation][index_patch])
@@ -142,7 +151,7 @@ class OutSameAsGroupDataset(OutDataset):

  def load(self, name_layer: str, datasets: list[Dataset], groups: dict[str, str]):
  super().load(name_layer, datasets, groups)
-
+
  if self.group_src not in groups.keys():
  raise PredictorError(
  f"Source group '{self.group_src}' not found. Available groups: {list(groups.keys())}."
@@ -155,7 +164,6 @@ class OutSameAsGroupDataset(OutDataset):

  def _getOutput(self, index: int, index_augmentation: int, dataset: DatasetIter) -> torch.Tensor:
  layer = self.output_layer_accumulator[index][index_augmentation].assemble()
- name = self.names[index]
  if index_augmentation > 0:

  i = 0
@@ -167,59 +175,27 @@ class OutSameAsGroupDataset(OutDataset):
  break
  i += dataAugmentations.nb

- for transform in self.post_transforms:
- layer = transform(name, layer, self.attributes[index][index_augmentation][0])
-
- if self.inverse_transform:
- for transform in reversed(dataset.groups_src[self.group_src][self.group_dest].pre_transforms):
- layer = transform.inverse(name, layer, self.attributes[index][index_augmentation][0])
+ for transform in self.before_reduction_transforms:
+ layer = transform(self.names[index], layer, self.attributes[index][index_augmentation][0])
+
  return layer

  def getOutput(self, index: int, dataset: DatasetIter) -> torch.Tensor:
  result = torch.cat([self._getOutput(index, index_augmentation, dataset).unsqueeze(0) for index_augmentation in self.output_layer_accumulator[index].keys()], dim=0)
- name = self.names[index]
  self.output_layer_accumulator.pop(index)
- dtype = result.dtype
+ result = self.reduction(result.float()).to(result.dtype)
+
+ for transform in self.after_reduction_transforms:
+ result = transform(self.names[index], result, self.attributes[index][0][0])

- if self.reduction == "mean":
- result = torch.mean(result.float(), dim=0).to(dtype)
- elif self.reduction == "median":
- result, _ = torch.median(result.float(), dim=0)
- else:
- raise NameError("Reduction method does not exist (mean, median)")
+ if self.inverse_transform:
+ for transform in reversed(dataset.groups_src[self.group_src][self.group_dest].pre_transforms):
+ result = transform.inverse(self.names[index], result, self.attributes[index][0][0])
+
  for transform in self.final_transforms:
- result = transform(name, result, self.attributes[index][0][0])
+ result = transform(self.names[index], result, self.attributes[index][0][0])
  return result

- class OutLayerDataset(OutDataset):
-
- @config("OutDataset")
- def __init__(self, dataset_filename: str = "Dataset.h5", group: str = "default", overlap : Union[list[int], None] = None, pre_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, post_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, final_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, patchCombine: Union[str, None] = None) -> None:
- super().__init__(dataset_filename, group, pre_transforms, post_transforms, final_transforms, patchCombine)
- self.overlap = overlap
-
- def addLayer(self, index: int, index_patch: int, layer: torch.Tensor, dataset: DatasetIter):
- if index not in self.output_layer_accumulator:
- group = list(dataset.groups.keys())[0]
- patch_slices = get_patch_slices_from_nb_patch_per_dim(list(layer.shape[2:]), dataset.getDatasetFromIndex(group, index).patch.nb_patch_per_dim, self.overlap)
- self.output_layer_accumulator[index] = Accumulator(patch_slices, self.patchCombine, batch=False)
- self.attributes[index] = Attribute()
- self.names[index] = dataset.getDatasetFromIndex(group, index).name
-
-
- for transform in self.pre_transforms:
- layer = transform(layer, self.attributes[index])
- self.output_layer_accumulator[index].addLayer(index_patch, layer)
-
- def getOutput(self, index: int, dataset: DatasetIter) -> torch.Tensor:
- layer = self.output_layer_accumulator[index].assemble()
- name = self.names[index]
- for transform in self.post_transforms:
- layer = transform(name, layer, self.attributes[index])
-
- self.output_layer_accumulator.pop(index)
- return layer
-
  class OutDatasetLoader():

  @config("OutDataset")
@@ -307,9 +283,9 @@ class _Predictor():

  class ModelComposite(Network):

- def __init__(self, model: Network, nb_models: int, method: str):
+ def __init__(self, model: Network, nb_models: int, combine: Reduction):
  super().__init__(model.in_channels, model.optimizer, model.schedulers, model.outputsCriterionsLoader, model.patch, model.nb_batch_per_step, model.init_type, model.init_gain, model.dim)
- self.method = method
+ self.combine = combine
  for i in range(nb_models):
  self.add_module("Model_{}".format(i), copy.deepcopy(model), in_branch=[0], out_branch=["output_{}".format(i)])

@@ -330,12 +306,7 @@ class ModelComposite(Network):

  final_outputs = []
  for key, tensors in aggregated.items():
- stacked = torch.stack(tensors, dim=0)
- if self.method == 'mean':
- agg = torch.mean(stacked, dim=0)
- elif self.method == 'median':
- agg = torch.median(stacked, dim=0).values
- final_outputs.append((key, agg))
+ final_outputs.append((key, self.combine(torch.stack(tensors, dim=0))))

  return final_outputs

@@ -358,7 +329,7 @@ class Predictor(DistributedObject):
  super().__init__(train_name)
  self.manual_seed = manual_seed
  self.dataset = dataset
- self.combine = combine
+ self.combine_classpath = combine
  self.autocast = autocast

  self.model = model.getModel(train=False)
@@ -430,7 +401,15 @@ class Predictor(DistributedObject):
  "Available modules: {}".format(modules),
  "Please check that the name matches exactly a submodule or output of your model architecture."
  )
- self.modelComposite = ModelComposite(self.model, len(MODEL().split(":")), self.combine)
+
+ module, name = _getModule(self.combine_classpath, "konfai.predictor")
+ if module == "konfai.predictor":
+ combine = getattr(importlib.import_module(module), name)
+ else:
+ combine = config("{}.{}".format(KONFAI_ROOT(), self.combine_classpath))(getattr(importlib.import_module(module), name))(config = None)
+
+
+ self.modelComposite = ModelComposite(self.model, len(MODEL().split(":")), combine)
  self.modelComposite.load(self._load())

  if len(list(self.outsDataset.keys())) == 0 and len([network for network in self.modelComposite.getNetworks().values() if network.measure is not None]) == 0:
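Note: the "mean"/"median" string handling is replaced by Reduction callables (Mean, Median) resolved from konfai.predictor or from the YAML config, and the same objects are reused by ModelComposite to combine the outputs of several models. A minimal sketch of what such a callable does (local stand-in classes, not the package's):

import torch

class MeanReduction:
    def __call__(self, result: torch.Tensor) -> torch.Tensor:
        # Collapse the stacked predictions (dim 0) element-wise.
        return torch.mean(result.float(), dim=0)

class MedianReduction:
    def __call__(self, result: torch.Tensor) -> torch.Tensor:
        return torch.median(result.float(), dim=0).values

stacked = torch.stack([torch.full((2, 2), float(i)) for i in range(3)])  # 3 predictions
print(MeanReduction()(stacked))    # element-wise mean over the 3 predictions
print(MedianReduction()(stacked))  # element-wise median over the 3 predictions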
konfai/trainer.py CHANGED
@@ -108,9 +108,7 @@ class _Trainer():
  if data_log is not None:
  for data in data_log:
  self.data_log[data.split("/")[0].replace(":", ".")] = (DataLog.__getitem__(data.split("/")[1]).value[0], int(data.split("/")[2]))
-
-
-
+
  def __enter__(self):
  return self

@@ -212,6 +210,9 @@ class _Trainer():
  save_dict["Model_EMA"] = self.modelEMA.module.state_dict()

  save_dict.update({'{}_optimizer_state_dict'.format(name): network.optimizer.state_dict() for name, network in self.model.module.getNetworks().items() if network.optimizer is not None})
+ save_dict.update({'{}_it'.format(name): network._it for name, network in self.model.module.getNetworks().items() if network.optimizer is not None})
+ save_dict.update({'{}_nb_lr_update'.format(name): network._nb_lr_update for name, network in self.model.module.getNetworks().items() if network.optimizer is not None})
+
  torch.save(save_dict, save_path)

  if self.save_checkpoint_mode == "BEST":
@@ -356,7 +357,7 @@ class Trainer(DistributedObject):
  name = sorted(os.listdir(path))[-1]

  if os.path.exists(path+name):
- state_dict = torch.load(path+name, weights_only=False)
+ state_dict = torch.load(path+name, weights_only=False, map_location="cpu")
  else:
  raise Exception("Model : {} does not exist !".format(self.name))

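Note: checkpoints now also store a per-network iteration counter and LR-update counter next to the optimizer state, and Network.load reads them back (see the network.py hunk above). A minimal sketch for inspecting these keys (the checkpoint file name and the network name "UNet" are hypothetical):

import torch

state_dict = torch.load("Checkpoint.pt", weights_only=False, map_location="cpu")
for key in ("UNet_it", "UNet_nb_lr_update", "UNet_optimizer_state_dict"):
    print(key, "present" if key in state_dict else "missing")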
konfai/utils/utils.py CHANGED
@@ -20,6 +20,11 @@ import torch.nn.functional as F
  import sys
  import re

+ import requests
+ from tqdm import tqdm
+ import importlib
+ from pathlib import Path
+ import shutil

  def description(model, modelEMA = None, showMemory: bool = True, train: bool = True) -> str:
  values_desc = lambda weights, values: " ".join(["{}({:.2f}) : {:.6f}".format(name.split(":")[-1], weight, value) for (name, value), weight in zip(values.items(), weights.values())])
@@ -35,11 +40,11 @@ def description(model, modelEMA = None, showMemory: bool = True, train: bool = T
  def _getModule(classpath : str, type : str) -> tuple[str, str]:
  if len(classpath.split(":")) > 1:
  module = ".".join(classpath.split(":")[:-1])
- name = classpath.split(":")[-1]
+ name = classpath.split(":")[-1]
  else:
- module = "konfai."+type+("." if len(classpath.split(".")) > 2 else "")+".".join(classpath.split(".")[:-1])
+ module = type+("." if len(classpath.split(".")) > 2 else "")+".".join(classpath.split(".")[:-1])
  name = classpath.split(".")[-1]
- return module, name
+ return module, name.split("/")[0]

  def cpuInfo() -> str:
  return "CPU ({:.2f} %)".format(psutil.cpu_percent(interval=0.5))
@@ -420,6 +425,8 @@ def setup(parser: argparse.ArgumentParser) -> DistributedObject:

  os.environ["CUDA_VISIBLE_DEVICES"] = config["gpu"]
  os.environ["KONFAI_NB_CORES"] = config["cpu"]
+
+ os.environ["KONFAI_WORKERS"] = str(config["num_workers"])
  os.environ["KONFAI_MODELS_DIRECTORY"] = config["MODELS_DIRECTORY"]
  os.environ["KONFAI_CHECKPOINTS_DIRECTORY"] = config["CHECKPOINTS_DIRECTORY"]
  os.environ["KONFAI_PREDICTIONS_DIRECTORY"] = config["PREDICTIONS_DIRECTORY"]
@@ -536,6 +543,37 @@ def _resample_affine(data: torch.Tensor, matrix: torch.Tensor):
  return F.grid_sample(data.unsqueeze(0).type(torch.float32), F.affine_grid(matrix[:, :-1,...].type(torch.float32), [1]+list(data.shape), align_corners=True), align_corners=True, mode=mode, padding_mode="reflection").squeeze(0).type(data.dtype)


+
+ def download_url(model_name: str, url: str) -> str:
+ spec = importlib.util.find_spec("konfai")
+ base_path = Path(spec.submodule_search_locations[0]) / "metric" / "models"
+ subdirs = Path(model_name).parent
+ model_dir = base_path / subdirs
+ model_dir.mkdir(exist_ok=True)
+ filetmp = model_dir / ("tmp_"+str(Path(model_name).name))
+ file = model_dir / Path(model_name).name
+ if file.exists():
+ return str(file)
+
+ try:
+ print(f"[FOCUS] Downloading {model_name} to {file}")
+ with requests.get(url+model_name, stream=True) as r:
+ r.raise_for_status()
+ total = int(r.headers.get('content-length', 0))
+ with open(filetmp, 'wb') as f:
+ with tqdm(total=total, unit='B', unit_scale=True, desc=f"Downloading {model_name}") as pbar:
+ for chunk in r.iter_content(chunk_size=8192):
+ f.write(chunk)
+ pbar.update(len(chunk))
+ shutil.copy2(filetmp, file)
+ print("Download finished.")
+ except Exception as e:
+ raise e
+ finally:
+ if filetmp.exists():
+ os.remove(filetmp)
+ return str(file)
+
  SUPPORTED_EXTENSIONS = [
  "mha", "mhd", # MetaImage
  "nii", "nii.gz", # NIfTI
@@ -1,6 +1,6 @@
  Metadata-Version: 2.4
  Name: konfai
- Version: 1.1.5
+ Version: 1.1.7
  Summary: Modular and configurable Deep Learning framework with YAML and PyTorch
  Author-email: Valentin Boussot <boussot.v@gmail.com>
  License-Expression: Apache-2.0
@@ -0,0 +1,39 @@
+ konfai/__init__.py,sha256=YXG-wpSEXWs6Jt3BDI77V4r89gEUNX-6lxW9btj5VYI,851
+ konfai/evaluator.py,sha256=WM78NGydV0iqKElahr-WpOCZLEJXHjjPq-LS-gi23rk,8401
+ konfai/main.py,sha256=kr7Iie_f67NF6G3dAAj9G6Z9dhn9RzbdLYpzy2WvIh8,2573
+ konfai/predictor.py,sha256=6zFGjQhG5-3hF10pvCn58UDLB6ZQ_UhMlxkOCzpM4tM,23053
+ konfai/trainer.py,sha256=Edi8l8OOfCewYM-Cd5C5rCqCaprvlfxzohd4iLkK5u0,20632
+ konfai/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ konfai/data/augmentation.py,sha256=mFVMpbJ8WBKGbMILdmTZYBA8k7kRDQVPOEy1A9t5QP4,32281
+ konfai/data/data_manager.py,sha256=yphkTjk4_gyr_vOGodfhu9ImDHplRe2KR6dL2ORXzCw,29000
+ konfai/data/patching.py,sha256=zAm6jjUW--lsqTBlDFICVlj7O_QOlSxStFAHS9S9H8I,14753
+ konfai/data/transform.py,sha256=yZ6aALtVEqTYghREPH8Z1deglU7M4t4OiQqMY6gpjjA,26559
+ konfai/metric/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ konfai/metric/measure.py,sha256=S65WGBRYMoalkwI_Le3C6K0tQmT4Qft82nUTBwhpvFU,26907
+ konfai/metric/schedulers.py,sha256=VPp7zEwD0AtQz51XG0TlutD_NrsTZs4fstC5h8A8f8U,1309
+ konfai/models/classification/convNeXt.py,sha256=hZQ_9I1lJW7DX8QnoeZ9L1br-98apV9VXGrsTAN25ds,9261
+ konfai/models/classification/resnet.py,sha256=eBoS_zNczfa1EnECKKLr0Ow26ekR_4W5eji6p9cPfHk,7978
+ konfai/models/generation/cStyleGan.py,sha256=7D8zZveDEZapFeaaDTe3wVhuDCCHLqq6cpXtl2-QaJA,8059
+ konfai/models/generation/ddpm.py,sha256=jXb0eOU3i_gMHvj9pawVAWQDjMGl6SPfCc5m2Jzdrik,13175
+ konfai/models/generation/diffusionGan.py,sha256=ZFKdRHvlajEPyw9_HCLo0Q92iChCtZd2dvi2nEF_tBI,33218
+ konfai/models/generation/gan.py,sha256=RzpGNu8BlBVDCxFR5CCmEFYUzVBs2YrDl7trIos3ssw,7861
+ konfai/models/generation/vae.py,sha256=zH8qk04z2lXDhXAVxTHiRIVMfRi3brB61lHGmEsE3kM,4675
+ konfai/models/registration/registration.py,sha256=EAE3w8aic2fPWiJz0ilqrs2kCGUQD6NWvysVfHXxA_g,6329
+ konfai/models/representation/representation.py,sha256=TiYcBBqZYySpwsRlnnBQh0QVW29Rcvb9GUjsqnCKKLM,2689
+ konfai/models/segmentation/NestedUNet.py,sha256=hDSE7BJ17IqaiHWAD5zlVG7G_KvQzuQmVnAVsiYVE6E,10248
+ konfai/models/segmentation/UNet.py,sha256=BktCRfAcCDtvGCw8wGfyZvBtT4G0Oy8teIcVgDFOurk,4078
+ konfai/network/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ konfai/network/blocks.py,sha256=C_dOLmXovxANWCx7P439FzR_95Zisy2I1F7FEwwz7AE,14412
+ konfai/network/network.py,sha256=LHeA7HtsVYO7BJu2_kqh23q2GIANn5ZSa4LhKMt7dJg,48642
+ konfai/utils/ITK.py,sha256=OxTieDNNYHGkn7zxJsAG-6ecRG1VYMvn1dlBbBe1DOs,13955
+ konfai/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
+ konfai/utils/config.py,sha256=f7o83ix5_oNbr2pki-Czqr-yHi-8n92ZL64nlo0XGwA,12514
+ konfai/utils/dataset.py,sha256=6ZzevdhJ7e5zlXATAVwSh9O6acKXM7gYNxkMAa5DrmM,36351
+ konfai/utils/registration.py,sha256=v1srEBOcgDnHrx0YtsK6bcj0yCMH7wNeaQ3wC7gEvOw,8898
+ konfai/utils/utils.py,sha256=mQQ6FB0Jw7Odg5GxYeTApR6lvFhXn_iz-GVGZRngGQA,24934
+ konfai-1.1.7.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+ konfai-1.1.7.dist-info/METADATA,sha256=DJ2n6Rc9NF018KsJPQxFHpQmlDazUnrrP4zHtgE4_Ew,2515
+ konfai-1.1.7.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+ konfai-1.1.7.dist-info/entry_points.txt,sha256=fG82HRN5-g39ACSOCtij_I3N6EHxfYnMR0D7TI_8pW8,81
+ konfai-1.1.7.dist-info/top_level.txt,sha256=xF470dkIlFoFqTZEOlRehKJr4WU_8OKGXrJqYm9vWKs,7
+ konfai-1.1.7.dist-info/RECORD,,