konfai 1.1.5__py3-none-any.whl → 1.1.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of konfai might be problematic. Click here for more details.
- konfai/data/augmentation.py +1 -2
- konfai/data/data_manager.py +26 -40
- konfai/data/patching.py +7 -19
- konfai/data/transform.py +23 -21
- konfai/evaluator.py +1 -1
- konfai/main.py +0 -2
- konfai/metric/measure.py +86 -17
- konfai/metric/schedulers.py +7 -12
- konfai/models/classification/convNeXt.py +1 -2
- konfai/models/classification/resnet.py +1 -2
- konfai/models/generation/cStyleGan.py +1 -2
- konfai/models/generation/ddpm.py +1 -2
- konfai/models/generation/diffusionGan.py +10 -23
- konfai/models/generation/gan.py +3 -6
- konfai/models/generation/vae.py +0 -2
- konfai/models/registration/registration.py +1 -2
- konfai/models/representation/representation.py +1 -2
- konfai/models/segmentation/NestedUNet.py +61 -62
- konfai/models/segmentation/UNet.py +1 -2
- konfai/network/blocks.py +8 -6
- konfai/network/network.py +54 -47
- konfai/predictor.py +55 -76
- konfai/trainer.py +5 -4
- konfai/utils/utils.py +41 -3
- {konfai-1.1.5.dist-info → konfai-1.1.7.dist-info}/METADATA +1 -1
- konfai-1.1.7.dist-info/RECORD +39 -0
- konfai-1.1.5.dist-info/RECORD +0 -39
- {konfai-1.1.5.dist-info → konfai-1.1.7.dist-info}/WHEEL +0 -0
- {konfai-1.1.5.dist-info → konfai-1.1.7.dist-info}/entry_points.txt +0 -0
- {konfai-1.1.5.dist-info → konfai-1.1.7.dist-info}/licenses/LICENSE +0 -0
- {konfai-1.1.5.dist-info → konfai-1.1.7.dist-info}/top_level.txt +0 -0
konfai/network/network.py
CHANGED
|
@@ -16,7 +16,7 @@ from enum import Enum
|
|
|
16
16
|
from konfai import KONFAI_ROOT
|
|
17
17
|
from konfai.metric.schedulers import Scheduler
|
|
18
18
|
from konfai.utils.config import config
|
|
19
|
-
from konfai.utils.utils import State, _getModule, getDevice, getGPUMemory, MeasureError
|
|
19
|
+
from konfai.utils.utils import State, _getModule, getDevice, getGPUMemory, MeasureError, TrainerError
|
|
20
20
|
from konfai.data.patching import Accumulator, ModelPatch
|
|
21
21
|
|
|
22
22
|
class NetState(Enum):
|
|
@@ -41,50 +41,41 @@ class OptimizerLoader():
|
|
|
41
41
|
def getOptimizer(self, key: str, parameter: Iterator[torch.nn.parameter.Parameter]) -> torch.optim.Optimizer:
|
|
42
42
|
return config("{}.Model.{}.Optimizer".format(KONFAI_ROOT(), key))(getattr(importlib.import_module('torch.optim'), self.name))(parameter, config = None)
|
|
43
43
|
|
|
44
|
-
class SchedulerStep():
|
|
45
|
-
|
|
46
|
-
@config(None)
|
|
47
|
-
def __init__(self, nb_step : int = 0) -> None:
|
|
48
|
-
self.nb_step = nb_step
|
|
49
44
|
|
|
50
45
|
class LRSchedulersLoader():
|
|
51
46
|
|
|
52
|
-
@config(
|
|
53
|
-
def __init__(self,
|
|
54
|
-
self.
|
|
55
|
-
|
|
56
|
-
def getschedulers(self, key: str, optimizer: torch.optim.Optimizer) ->
|
|
57
|
-
|
|
58
|
-
|
|
59
|
-
if name:
|
|
60
|
-
|
|
61
|
-
|
|
62
|
-
|
|
63
|
-
class
|
|
47
|
+
@config()
|
|
48
|
+
def __init__(self, nb_step : int = 0) -> None:
|
|
49
|
+
self.nb_step = nb_step
|
|
50
|
+
|
|
51
|
+
def getschedulers(self, key: str, scheduler_classname: str, optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler._LRScheduler:
|
|
52
|
+
for m in ["torch.optim.lr_scheduler", "konfai.metric.schedulers"]:
|
|
53
|
+
module, name = _getModule(scheduler_classname, m)
|
|
54
|
+
if hasattr(importlib.import_module(module), name):
|
|
55
|
+
return config("{}.Model.{}.schedulers.{}".format(KONFAI_ROOT(), key, scheduler_classname))(getattr(importlib.import_module(module), name))(optimizer, config = None)
|
|
56
|
+
raise TrainerError("Unknown scheduler {}, tried importing from: 'torch.optim.lr_scheduler' and 'konfai.metric.schedulers', but no valid match was found. Check your YAML config or scheduler name spelling.".format(scheduler_classname))
|
|
57
|
+
|
|
58
|
+
class LossSchedulersLoader():
|
|
64
59
|
|
|
65
|
-
@config(
|
|
66
|
-
def __init__(self,
|
|
67
|
-
self.
|
|
68
|
-
|
|
69
|
-
def getschedulers(self, key: str) ->
|
|
70
|
-
|
|
71
|
-
for name, step in self.params.items():
|
|
72
|
-
if name:
|
|
73
|
-
schedulers[getattr(importlib.import_module("konfai.metric.schedulers"), name)(config = None, DL_args = key)] = step.nb_step
|
|
74
|
-
return schedulers
|
|
60
|
+
@config()
|
|
61
|
+
def __init__(self, nb_step : int = 0) -> None:
|
|
62
|
+
self.nb_step = nb_step
|
|
63
|
+
|
|
64
|
+
def getschedulers(self, key: str, scheduler_classname: str) -> torch.optim.lr_scheduler._LRScheduler:
|
|
65
|
+
return config("{}.{}".format(key, scheduler_classname))(getattr(importlib.import_module('konfai.metric.schedulers'), scheduler_classname))(config = None)
|
|
75
66
|
|
|
76
67
|
class CriterionsAttr():
|
|
77
68
|
|
|
78
69
|
@config()
|
|
79
|
-
def __init__(self,
|
|
80
|
-
self.
|
|
70
|
+
def __init__(self, schedulers: dict[str, LossSchedulersLoader] = {"default:Constant": LossSchedulersLoader(0)}, isLoss: bool = True, group: int = 0, stepStart: int = 0, stepStop: Union[int, None] = None, accumulation: bool = False) -> None:
|
|
71
|
+
self.schedulersLoader = schedulers
|
|
81
72
|
self.isTorchCriterion = True
|
|
82
73
|
self.isLoss = isLoss
|
|
83
74
|
self.stepStart = stepStart
|
|
84
75
|
self.stepStop = stepStop
|
|
85
76
|
self.group = group
|
|
86
77
|
self.accumulation = accumulation
|
|
87
|
-
self.
|
|
78
|
+
self.schedulers: dict[Scheduler, int] = {}
|
|
88
79
|
|
|
89
80
|
class CriterionsLoader():
|
|
90
81
|
|
|
@@ -95,9 +86,10 @@ class CriterionsLoader():
|
|
|
95
86
|
def getCriterions(self, model_classname : str, output_group : str, target_group : str) -> dict[torch.nn.Module, CriterionsAttr]:
|
|
96
87
|
criterions = {}
|
|
97
88
|
for module_classpath, criterionsAttr in self.criterionsLoader.items():
|
|
98
|
-
module, name = _getModule(module_classpath, "metric.measure")
|
|
89
|
+
module, name = _getModule(module_classpath, "konfai.metric.measure")
|
|
99
90
|
criterionsAttr.isTorchCriterion = module.startswith("torch")
|
|
100
|
-
|
|
91
|
+
for scheduler_classname, schedulers in criterionsAttr.schedulersLoader.items():
|
|
92
|
+
criterionsAttr.schedulers[schedulers.getschedulers("{}.Model.{}.outputsCriterions.{}.targetsCriterions.{}.criterionsLoader.{}.schedulers".format(KONFAI_ROOT(), model_classname, output_group, target_group, module_classpath), scheduler_classname)] = schedulers.nb_step
|
|
101
93
|
criterions[config("{}.Model.{}.outputsCriterions.{}.targetsCriterions.{}.criterionsLoader.{}".format(KONFAI_ROOT(), model_classname, output_group, target_group, module_classpath))(getattr(importlib.import_module(module), name))(config = None)] = criterionsAttr
|
|
102
94
|
return criterions
|
|
103
95
|
|
|
@@ -195,7 +187,7 @@ class Measure():
|
|
|
195
187
|
|
|
196
188
|
for criterion, criterionsAttr in self.outputsCriterions[output_group][target_group].items():
|
|
197
189
|
if it >= criterionsAttr.stepStart and (criterionsAttr.stepStop is None or it <= criterionsAttr.stepStop):
|
|
198
|
-
scheduler = self.update_scheduler(criterionsAttr.
|
|
190
|
+
scheduler = self.update_scheduler(criterionsAttr.schedulers, it)
|
|
199
191
|
self._loss[criterionsAttr.group]["{}:{}:{}".format(output_group, target_group, criterion.__class__.__name__)].add(scheduler.get_value(), criterion(output, *target))
|
|
200
192
|
if training and len(np.unique([len(l) for l in self._loss[criterionsAttr.group].values() if l.accumulation and l.isLoss])) == 1:
|
|
201
193
|
if criterionsAttr.isLoss:
|
|
@@ -535,7 +527,7 @@ class Network(ModuleArgsDict, ABC):
|
|
|
535
527
|
def __init__( self,
|
|
536
528
|
in_channels : int = 1,
|
|
537
529
|
optimizer: Union[OptimizerLoader, None] = None,
|
|
538
|
-
schedulers: Union[LRSchedulersLoader, None] = None,
|
|
530
|
+
schedulers: Union[dict[str, LRSchedulersLoader], None] = None,
|
|
539
531
|
outputsCriterions: Union[dict[str, TargetCriterionsLoader], None] = None,
|
|
540
532
|
patch : Union[ModelPatch, None] = None,
|
|
541
533
|
nb_batch_per_step : int = 1,
|
|
@@ -549,7 +541,7 @@ class Network(ModuleArgsDict, ABC):
|
|
|
549
541
|
self.optimizer : Union[torch.optim.Optimizer, None] = None
|
|
550
542
|
|
|
551
543
|
self.LRSchedulersLoader = schedulers
|
|
552
|
-
self.schedulers : Union[dict[torch.optim.lr_scheduler._LRScheduler, int], None] =
|
|
544
|
+
self.schedulers : Union[dict[torch.optim.lr_scheduler._LRScheduler, int], None] = {}
|
|
553
545
|
|
|
554
546
|
self.outputsCriterionsLoader = outputsCriterions
|
|
555
547
|
self.measure : Union[Measure, None] = None
|
|
@@ -561,6 +553,7 @@ class Network(ModuleArgsDict, ABC):
|
|
|
561
553
|
self.init_gain = init_gain
|
|
562
554
|
self.dim = dim
|
|
563
555
|
self._it = 0
|
|
556
|
+
self._nb_lr_update = 0
|
|
564
557
|
self.outputsGroup : list[OutputsGroup]= []
|
|
565
558
|
|
|
566
559
|
@_function_network()
|
|
@@ -655,9 +648,15 @@ class Network(ModuleArgsDict, ABC):
|
|
|
655
648
|
else:
|
|
656
649
|
model_state_dict[alias] = model_state_dict_tmp[alias]
|
|
657
650
|
self.load_state_dict(model_state_dict)
|
|
658
|
-
|
|
659
|
-
|
|
660
|
-
|
|
651
|
+
if "{}_optimizer_state_dict".format(self.getName()) in state_dict and self.optimizer:
|
|
652
|
+
self.optimizer.load_state_dict(state_dict['{}_optimizer_state_dict'.format(self.getName())])
|
|
653
|
+
if "{}_it".format(self.getName()) in state_dict:
|
|
654
|
+
self._it = int(state_dict["{}_it".format(self.getName())])
|
|
655
|
+
if "{}_nb_lr_update".format(self.getName()) in state_dict:
|
|
656
|
+
self._nb_lr_update = int(state_dict["{}_nb_lr_update".format(self.getName())])
|
|
657
|
+
for scheduler in self.schedulers:
|
|
658
|
+
if scheduler.last_epoch == -1:
|
|
659
|
+
scheduler.last_epoch = self._nb_lr_update
|
|
661
660
|
self.initialized()
|
|
662
661
|
|
|
663
662
|
def _compute_channels_trace(self, module : ModuleArgsDict, in_channels : int, gradient_checkpoints: Union[list[str], None], gpu_checkpoints: Union[list[str], None], name: Union[str, None] = None, in_is_channel = True, out_channels : Union[int, None] = None, out_is_channel = True) -> tuple[int, bool, int, bool]:
|
|
@@ -730,7 +729,8 @@ class Network(ModuleArgsDict, ABC):
|
|
|
730
729
|
self.optimizer.zero_grad()
|
|
731
730
|
|
|
732
731
|
if self.LRSchedulersLoader and self.optimizer:
|
|
733
|
-
|
|
732
|
+
for schedulers_classname, schedulers in self.LRSchedulersLoader.items():
|
|
733
|
+
self.schedulers[schedulers.getschedulers(key, schedulers_classname, self.optimizer)] = schedulers.nb_step
|
|
734
734
|
|
|
735
735
|
def initialized(self):
|
|
736
736
|
pass
|
|
@@ -878,6 +878,7 @@ class Network(ModuleArgsDict, ABC):
|
|
|
878
878
|
model._requires_grad(list(self.measure.outputsCriterions.keys()))
|
|
879
879
|
for loss in self.measure.getLoss():
|
|
880
880
|
self.scaler.scale(loss / self.nb_batch_per_step).backward()
|
|
881
|
+
|
|
881
882
|
if self._it % self.nb_batch_per_step == 0:
|
|
882
883
|
self.scaler.step(self.optimizer)
|
|
883
884
|
self.scaler.update()
|
|
@@ -886,19 +887,19 @@ class Network(ModuleArgsDict, ABC):
|
|
|
886
887
|
|
|
887
888
|
@_function_network()
|
|
888
889
|
def update_lr(self):
|
|
890
|
+
self._nb_lr_update+=1
|
|
889
891
|
step = 0
|
|
890
892
|
scheduler = None
|
|
891
|
-
|
|
892
|
-
|
|
893
|
-
|
|
894
|
-
|
|
895
|
-
step += value
|
|
893
|
+
for scheduler, value in self.schedulers.items():
|
|
894
|
+
if value is None or (self._nb_lr_update >= step and self._nb_lr_update < step+value):
|
|
895
|
+
break
|
|
896
|
+
step += value
|
|
896
897
|
if scheduler:
|
|
897
898
|
if scheduler.__class__.__name__ == 'ReduceLROnPlateau':
|
|
898
899
|
if self.measure:
|
|
899
900
|
scheduler.step(sum(self.measure.getLastValues(0).values()))
|
|
900
901
|
else:
|
|
901
|
-
scheduler.step()
|
|
902
|
+
scheduler.step()
|
|
902
903
|
|
|
903
904
|
@_function_network()
|
|
904
905
|
def getNetworks(self) -> Self:
|
|
@@ -916,6 +917,12 @@ class Network(ModuleArgsDict, ABC):
|
|
|
916
917
|
v = Network.to(v, int(os.environ["device"]))
|
|
917
918
|
else:
|
|
918
919
|
v = v.to(getDevice(int(os.environ["device"])))
|
|
920
|
+
if isinstance(module, Network):
|
|
921
|
+
if module.optimizer is not None:
|
|
922
|
+
for state in module.optimizer.state.values():
|
|
923
|
+
for k, v in state.items():
|
|
924
|
+
if isinstance(v, torch.Tensor):
|
|
925
|
+
state[k] = v.to(getDevice(int(os.environ["device"])))
|
|
919
926
|
return module
|
|
920
927
|
|
|
921
928
|
def getName(self) -> str:
|
|
@@ -944,7 +951,7 @@ class ModelLoader():
|
|
|
944
951
|
|
|
945
952
|
@config("Model")
|
|
946
953
|
def __init__(self, classpath : str = "default:segmentation.UNet.UNet") -> None:
|
|
947
|
-
self.module, self.name = _getModule(classpath, "models")
|
|
954
|
+
self.module, self.name = _getModule(classpath, "konfai.models")
|
|
948
955
|
|
|
949
956
|
def getModel(self, train : bool = True, DL_args: Union[str, None] = None, DL_without=["optimizer", "schedulers", "nb_batch_per_step", "init_type", "init_gain"]) -> Network:
|
|
950
957
|
if not DL_args:
|
konfai/predictor.py
CHANGED
|
@@ -27,17 +27,19 @@ from collections import defaultdict
|
|
|
27
27
|
|
|
28
28
|
class OutDataset(Dataset, NeedDevice, ABC):
|
|
29
29
|
|
|
30
|
-
def __init__(self, filename: str, group: str,
|
|
30
|
+
def __init__(self, filename: str, group: str, before_reduction_transforms : dict[str, TransformLoader], after_reduction_transforms : dict[str, TransformLoader], final_transforms : dict[str, TransformLoader], patchCombine: Union[str, None], reduction: str) -> None:
|
|
31
31
|
filename, format = filename.split(":")
|
|
32
32
|
super().__init__(filename, format)
|
|
33
33
|
self.group = group
|
|
34
|
-
self.
|
|
35
|
-
self.
|
|
34
|
+
self._before_reduction_transforms = before_reduction_transforms
|
|
35
|
+
self._after_reduction_transforms = after_reduction_transforms
|
|
36
36
|
self._final_transforms = final_transforms
|
|
37
37
|
self._patchCombine = patchCombine
|
|
38
|
+
self.reduction_classpath = reduction
|
|
39
|
+
self.reduction = None
|
|
38
40
|
|
|
39
|
-
self.
|
|
40
|
-
self.
|
|
41
|
+
self.before_reduction_transforms : list[Transform] = []
|
|
42
|
+
self.after_reduction_transforms : list[Transform] = []
|
|
41
43
|
self.final_transforms : list[Transform] = []
|
|
42
44
|
self.patchCombine: PathCombine = None
|
|
43
45
|
|
|
@@ -47,7 +49,7 @@ class OutDataset(Dataset, NeedDevice, ABC):
|
|
|
47
49
|
self.nb_data_augmentation = 0
|
|
48
50
|
|
|
49
51
|
def load(self, name_layer: str, datasets: list[Dataset], groups: dict[str, str]):
|
|
50
|
-
transforms_type = ["
|
|
52
|
+
transforms_type = ["before_reduction_transforms", "after_reduction_transforms", "final_transforms"]
|
|
51
53
|
for name, _transform_type, transform_type in [(k, getattr(self, "_{}".format(k)), getattr(self, k)) for k in transforms_type]:
|
|
52
54
|
|
|
53
55
|
if _transform_type is not None:
|
|
@@ -57,8 +59,15 @@ class OutDataset(Dataset, NeedDevice, ABC):
|
|
|
57
59
|
transform_type.append(transform)
|
|
58
60
|
|
|
59
61
|
if self._patchCombine is not None:
|
|
60
|
-
module, name = _getModule(self._patchCombine, "data.patching")
|
|
61
|
-
self.patchCombine =
|
|
62
|
+
module, name = _getModule(self._patchCombine, "konfai.data.patching")
|
|
63
|
+
self.patchCombine = config("{}.outsDataset.{}.OutDataset".format(KONFAI_ROOT(), name_layer))(getattr(importlib.import_module(module), name))(config = None)
|
|
64
|
+
|
|
65
|
+
module, name = _getModule(self.reduction_classpath, "konfai.predictor")
|
|
66
|
+
if module == "konfai.predictor":
|
|
67
|
+
self.reduction = getattr(importlib.import_module(module), name)
|
|
68
|
+
else:
|
|
69
|
+
self.reduction = config("{}.outsDataset.{}.OutDataset.{}".format(KONFAI_ROOT(), name_layer, self.reduction_classpath))(getattr(importlib.import_module(module), name))(config = None)
|
|
70
|
+
|
|
62
71
|
|
|
63
72
|
def setPatchConfig(self, patchSize: Union[list[int], None], overlap: Union[int, None], nb_data_augmentation: int) -> None:
|
|
64
73
|
if patchSize is not None and overlap is not None:
|
|
@@ -70,14 +79,14 @@ class OutDataset(Dataset, NeedDevice, ABC):
|
|
|
70
79
|
|
|
71
80
|
def setDevice(self, device: torch.device):
|
|
72
81
|
super().setDevice(device)
|
|
73
|
-
transforms_type = ["
|
|
82
|
+
transforms_type = ["before_reduction_transforms", "after_reduction_transforms", "final_transforms"]
|
|
74
83
|
for transform_type in [(getattr(self, k)) for k in transforms_type]:
|
|
75
84
|
if transform_type is not None:
|
|
76
85
|
for transform in transform_type:
|
|
77
86
|
transform.setDevice(device)
|
|
78
87
|
|
|
79
88
|
@abstractmethod
|
|
80
|
-
def addLayer(self,
|
|
89
|
+
def addLayer(self, index_dataset: int, index_augmentation: int, index_patch: int, layer: torch.Tensor, dataset: DatasetIter):
|
|
81
90
|
pass
|
|
82
91
|
|
|
83
92
|
def isDone(self, index: int) -> bool:
|
|
@@ -96,25 +105,28 @@ class Reduction():
|
|
|
96
105
|
def __init__(self):
|
|
97
106
|
pass
|
|
98
107
|
|
|
99
|
-
class
|
|
108
|
+
class Mean(Reduction):
|
|
100
109
|
|
|
101
110
|
def __init__(self):
|
|
102
111
|
pass
|
|
103
|
-
|
|
104
|
-
|
|
112
|
+
|
|
113
|
+
def __call__(self, result: torch.Tensor) -> torch.Tensor:
|
|
114
|
+
return torch.mean(result.float(), dim=0)
|
|
115
|
+
|
|
116
|
+
class Median(Reduction):
|
|
105
117
|
|
|
106
118
|
def __init__(self):
|
|
107
119
|
pass
|
|
108
120
|
|
|
109
|
-
|
|
121
|
+
def __call__(self, result: torch.Tensor) -> torch.Tensor:
|
|
122
|
+
return torch.median(result.float(), dim=0).values
|
|
110
123
|
|
|
111
124
|
class OutSameAsGroupDataset(OutDataset):
|
|
112
125
|
|
|
113
126
|
@config("OutDataset")
|
|
114
|
-
def __init__(self,
|
|
115
|
-
super().__init__(dataset_filename, group,
|
|
127
|
+
def __init__(self, sameAsGroup: str = "default", dataset_filename: str = "default:./Dataset:mha", group: str = "default", before_reduction_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, after_reduction_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, final_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, patchCombine: Union[str, None] = None, reduction: str = "mean", inverse_transform: bool = True) -> None:
|
|
128
|
+
super().__init__(dataset_filename, group, before_reduction_transforms, after_reduction_transforms, final_transforms, patchCombine, reduction)
|
|
116
129
|
self.group_src, self.group_dest = sameAsGroup.split(":")
|
|
117
|
-
self.reduction = reduction
|
|
118
130
|
self.inverse_transform = inverse_transform
|
|
119
131
|
|
|
120
132
|
def addLayer(self, index_dataset: int, index_augmentation: int, index_patch: int, layer: torch.Tensor, dataset: DatasetIter):
|
|
@@ -131,9 +143,6 @@ class OutSameAsGroupDataset(OutDataset):
|
|
|
131
143
|
for i in range(len(input_dataset.patch.getPatch_slices(index_augmentation))):
|
|
132
144
|
self.attributes[index_dataset][index_augmentation][i] = Attribute(input_dataset.cache_attributes[0])
|
|
133
145
|
|
|
134
|
-
for transform in self.pre_transforms:
|
|
135
|
-
layer = transform(self.names[index_dataset], layer, self.attributes[index_dataset][index_augmentation][index_patch])
|
|
136
|
-
|
|
137
146
|
if self.inverse_transform:
|
|
138
147
|
for transform in reversed(dataset.groups_src[self.group_src][self.group_dest].post_transforms):
|
|
139
148
|
layer = transform.inverse(self.names[index_dataset], layer, self.attributes[index_dataset][index_augmentation][index_patch])
|
|
@@ -142,7 +151,7 @@ class OutSameAsGroupDataset(OutDataset):
|
|
|
142
151
|
|
|
143
152
|
def load(self, name_layer: str, datasets: list[Dataset], groups: dict[str, str]):
|
|
144
153
|
super().load(name_layer, datasets, groups)
|
|
145
|
-
|
|
154
|
+
|
|
146
155
|
if self.group_src not in groups.keys():
|
|
147
156
|
raise PredictorError(
|
|
148
157
|
f"Source group '{self.group_src}' not found. Available groups: {list(groups.keys())}."
|
|
@@ -155,7 +164,6 @@ class OutSameAsGroupDataset(OutDataset):
|
|
|
155
164
|
|
|
156
165
|
def _getOutput(self, index: int, index_augmentation: int, dataset: DatasetIter) -> torch.Tensor:
|
|
157
166
|
layer = self.output_layer_accumulator[index][index_augmentation].assemble()
|
|
158
|
-
name = self.names[index]
|
|
159
167
|
if index_augmentation > 0:
|
|
160
168
|
|
|
161
169
|
i = 0
|
|
@@ -167,59 +175,27 @@ class OutSameAsGroupDataset(OutDataset):
|
|
|
167
175
|
break
|
|
168
176
|
i += dataAugmentations.nb
|
|
169
177
|
|
|
170
|
-
for transform in self.
|
|
171
|
-
layer = transform(
|
|
172
|
-
|
|
173
|
-
if self.inverse_transform:
|
|
174
|
-
for transform in reversed(dataset.groups_src[self.group_src][self.group_dest].pre_transforms):
|
|
175
|
-
layer = transform.inverse(name, layer, self.attributes[index][index_augmentation][0])
|
|
178
|
+
for transform in self.before_reduction_transforms:
|
|
179
|
+
layer = transform(self.names[index], layer, self.attributes[index][index_augmentation][0])
|
|
180
|
+
|
|
176
181
|
return layer
|
|
177
182
|
|
|
178
183
|
def getOutput(self, index: int, dataset: DatasetIter) -> torch.Tensor:
|
|
179
184
|
result = torch.cat([self._getOutput(index, index_augmentation, dataset).unsqueeze(0) for index_augmentation in self.output_layer_accumulator[index].keys()], dim=0)
|
|
180
|
-
name = self.names[index]
|
|
181
185
|
self.output_layer_accumulator.pop(index)
|
|
182
|
-
|
|
186
|
+
result = self.reduction(result.float()).to(result.dtype)
|
|
187
|
+
|
|
188
|
+
for transform in self.after_reduction_transforms:
|
|
189
|
+
result = transform(self.names[index], result, self.attributes[index][0][0])
|
|
183
190
|
|
|
184
|
-
if self.
|
|
185
|
-
|
|
186
|
-
|
|
187
|
-
|
|
188
|
-
else:
|
|
189
|
-
raise NameError("Reduction method does not exist (mean, median)")
|
|
191
|
+
if self.inverse_transform:
|
|
192
|
+
for transform in reversed(dataset.groups_src[self.group_src][self.group_dest].pre_transforms):
|
|
193
|
+
result = transform.inverse(self.names[index], result, self.attributes[index][0][0])
|
|
194
|
+
|
|
190
195
|
for transform in self.final_transforms:
|
|
191
|
-
result = transform(
|
|
196
|
+
result = transform(self.names[index], result, self.attributes[index][0][0])
|
|
192
197
|
return result
|
|
193
198
|
|
|
194
|
-
class OutLayerDataset(OutDataset):
|
|
195
|
-
|
|
196
|
-
@config("OutDataset")
|
|
197
|
-
def __init__(self, dataset_filename: str = "Dataset.h5", group: str = "default", overlap : Union[list[int], None] = None, pre_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, post_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, final_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, patchCombine: Union[str, None] = None) -> None:
|
|
198
|
-
super().__init__(dataset_filename, group, pre_transforms, post_transforms, final_transforms, patchCombine)
|
|
199
|
-
self.overlap = overlap
|
|
200
|
-
|
|
201
|
-
def addLayer(self, index: int, index_patch: int, layer: torch.Tensor, dataset: DatasetIter):
|
|
202
|
-
if index not in self.output_layer_accumulator:
|
|
203
|
-
group = list(dataset.groups.keys())[0]
|
|
204
|
-
patch_slices = get_patch_slices_from_nb_patch_per_dim(list(layer.shape[2:]), dataset.getDatasetFromIndex(group, index).patch.nb_patch_per_dim, self.overlap)
|
|
205
|
-
self.output_layer_accumulator[index] = Accumulator(patch_slices, self.patchCombine, batch=False)
|
|
206
|
-
self.attributes[index] = Attribute()
|
|
207
|
-
self.names[index] = dataset.getDatasetFromIndex(group, index).name
|
|
208
|
-
|
|
209
|
-
|
|
210
|
-
for transform in self.pre_transforms:
|
|
211
|
-
layer = transform(layer, self.attributes[index])
|
|
212
|
-
self.output_layer_accumulator[index].addLayer(index_patch, layer)
|
|
213
|
-
|
|
214
|
-
def getOutput(self, index: int, dataset: DatasetIter) -> torch.Tensor:
|
|
215
|
-
layer = self.output_layer_accumulator[index].assemble()
|
|
216
|
-
name = self.names[index]
|
|
217
|
-
for transform in self.post_transforms:
|
|
218
|
-
layer = transform(name, layer, self.attributes[index])
|
|
219
|
-
|
|
220
|
-
self.output_layer_accumulator.pop(index)
|
|
221
|
-
return layer
|
|
222
|
-
|
|
223
199
|
class OutDatasetLoader():
|
|
224
200
|
|
|
225
201
|
@config("OutDataset")
|
|
@@ -307,9 +283,9 @@ class _Predictor():
|
|
|
307
283
|
|
|
308
284
|
class ModelComposite(Network):
|
|
309
285
|
|
|
310
|
-
def __init__(self, model: Network, nb_models: int,
|
|
286
|
+
def __init__(self, model: Network, nb_models: int, combine: Reduction):
|
|
311
287
|
super().__init__(model.in_channels, model.optimizer, model.schedulers, model.outputsCriterionsLoader, model.patch, model.nb_batch_per_step, model.init_type, model.init_gain, model.dim)
|
|
312
|
-
self.
|
|
288
|
+
self.combine = combine
|
|
313
289
|
for i in range(nb_models):
|
|
314
290
|
self.add_module("Model_{}".format(i), copy.deepcopy(model), in_branch=[0], out_branch=["output_{}".format(i)])
|
|
315
291
|
|
|
@@ -330,12 +306,7 @@ class ModelComposite(Network):
|
|
|
330
306
|
|
|
331
307
|
final_outputs = []
|
|
332
308
|
for key, tensors in aggregated.items():
|
|
333
|
-
|
|
334
|
-
if self.method == 'mean':
|
|
335
|
-
agg = torch.mean(stacked, dim=0)
|
|
336
|
-
elif self.method == 'median':
|
|
337
|
-
agg = torch.median(stacked, dim=0).values
|
|
338
|
-
final_outputs.append((key, agg))
|
|
309
|
+
final_outputs.append((key, self.combine(torch.stack(tensors, dim=0))))
|
|
339
310
|
|
|
340
311
|
return final_outputs
|
|
341
312
|
|
|
@@ -358,7 +329,7 @@ class Predictor(DistributedObject):
|
|
|
358
329
|
super().__init__(train_name)
|
|
359
330
|
self.manual_seed = manual_seed
|
|
360
331
|
self.dataset = dataset
|
|
361
|
-
self.
|
|
332
|
+
self.combine_classpath = combine
|
|
362
333
|
self.autocast = autocast
|
|
363
334
|
|
|
364
335
|
self.model = model.getModel(train=False)
|
|
@@ -430,7 +401,15 @@ class Predictor(DistributedObject):
|
|
|
430
401
|
"Available modules: {}".format(modules),
|
|
431
402
|
"Please check that the name matches exactly a submodule or output of your model architecture."
|
|
432
403
|
)
|
|
433
|
-
|
|
404
|
+
|
|
405
|
+
module, name = _getModule(self.combine_classpath, "konfai.predictor")
|
|
406
|
+
if module == "konfai.predictor":
|
|
407
|
+
combine = getattr(importlib.import_module(module), name)
|
|
408
|
+
else:
|
|
409
|
+
combine = config("{}.{}".format(KONFAI_ROOT(), self.combine_classpath))(getattr(importlib.import_module(module), name))(config = None)
|
|
410
|
+
|
|
411
|
+
|
|
412
|
+
self.modelComposite = ModelComposite(self.model, len(MODEL().split(":")), combine)
|
|
434
413
|
self.modelComposite.load(self._load())
|
|
435
414
|
|
|
436
415
|
if len(list(self.outsDataset.keys())) == 0 and len([network for network in self.modelComposite.getNetworks().values() if network.measure is not None]) == 0:
|
konfai/trainer.py
CHANGED
|
@@ -108,9 +108,7 @@ class _Trainer():
|
|
|
108
108
|
if data_log is not None:
|
|
109
109
|
for data in data_log:
|
|
110
110
|
self.data_log[data.split("/")[0].replace(":", ".")] = (DataLog.__getitem__(data.split("/")[1]).value[0], int(data.split("/")[2]))
|
|
111
|
-
|
|
112
|
-
|
|
113
|
-
|
|
111
|
+
|
|
114
112
|
def __enter__(self):
|
|
115
113
|
return self
|
|
116
114
|
|
|
@@ -212,6 +210,9 @@ class _Trainer():
|
|
|
212
210
|
save_dict["Model_EMA"] = self.modelEMA.module.state_dict()
|
|
213
211
|
|
|
214
212
|
save_dict.update({'{}_optimizer_state_dict'.format(name): network.optimizer.state_dict() for name, network in self.model.module.getNetworks().items() if network.optimizer is not None})
|
|
213
|
+
save_dict.update({'{}_it'.format(name): network._it for name, network in self.model.module.getNetworks().items() if network.optimizer is not None})
|
|
214
|
+
save_dict.update({'{}_nb_lr_update'.format(name): network._nb_lr_update for name, network in self.model.module.getNetworks().items() if network.optimizer is not None})
|
|
215
|
+
|
|
215
216
|
torch.save(save_dict, save_path)
|
|
216
217
|
|
|
217
218
|
if self.save_checkpoint_mode == "BEST":
|
|
@@ -356,7 +357,7 @@ class Trainer(DistributedObject):
|
|
|
356
357
|
name = sorted(os.listdir(path))[-1]
|
|
357
358
|
|
|
358
359
|
if os.path.exists(path+name):
|
|
359
|
-
state_dict = torch.load(path+name, weights_only=False)
|
|
360
|
+
state_dict = torch.load(path+name, weights_only=False, map_location="cpu")
|
|
360
361
|
else:
|
|
361
362
|
raise Exception("Model : {} does not exist !".format(self.name))
|
|
362
363
|
|
konfai/utils/utils.py
CHANGED
|
@@ -20,6 +20,11 @@ import torch.nn.functional as F
|
|
|
20
20
|
import sys
|
|
21
21
|
import re
|
|
22
22
|
|
|
23
|
+
import requests
|
|
24
|
+
from tqdm import tqdm
|
|
25
|
+
import importlib
|
|
26
|
+
from pathlib import Path
|
|
27
|
+
import shutil
|
|
23
28
|
|
|
24
29
|
def description(model, modelEMA = None, showMemory: bool = True, train: bool = True) -> str:
|
|
25
30
|
values_desc = lambda weights, values: " ".join(["{}({:.2f}) : {:.6f}".format(name.split(":")[-1], weight, value) for (name, value), weight in zip(values.items(), weights.values())])
|
|
@@ -35,11 +40,11 @@ def description(model, modelEMA = None, showMemory: bool = True, train: bool = T
|
|
|
35
40
|
def _getModule(classpath : str, type : str) -> tuple[str, str]:
|
|
36
41
|
if len(classpath.split(":")) > 1:
|
|
37
42
|
module = ".".join(classpath.split(":")[:-1])
|
|
38
|
-
name = classpath.split(":")[-1]
|
|
43
|
+
name = classpath.split(":")[-1]
|
|
39
44
|
else:
|
|
40
|
-
module =
|
|
45
|
+
module = type+("." if len(classpath.split(".")) > 2 else "")+".".join(classpath.split(".")[:-1])
|
|
41
46
|
name = classpath.split(".")[-1]
|
|
42
|
-
return module, name
|
|
47
|
+
return module, name.split("/")[0]
|
|
43
48
|
|
|
44
49
|
def cpuInfo() -> str:
|
|
45
50
|
return "CPU ({:.2f} %)".format(psutil.cpu_percent(interval=0.5))
|
|
@@ -420,6 +425,8 @@ def setup(parser: argparse.ArgumentParser) -> DistributedObject:
|
|
|
420
425
|
|
|
421
426
|
os.environ["CUDA_VISIBLE_DEVICES"] = config["gpu"]
|
|
422
427
|
os.environ["KONFAI_NB_CORES"] = config["cpu"]
|
|
428
|
+
|
|
429
|
+
os.environ["KONFAI_WORKERS"] = str(config["num_workers"])
|
|
423
430
|
os.environ["KONFAI_MODELS_DIRECTORY"] = config["MODELS_DIRECTORY"]
|
|
424
431
|
os.environ["KONFAI_CHECKPOINTS_DIRECTORY"] = config["CHECKPOINTS_DIRECTORY"]
|
|
425
432
|
os.environ["KONFAI_PREDICTIONS_DIRECTORY"] = config["PREDICTIONS_DIRECTORY"]
|
|
@@ -536,6 +543,37 @@ def _resample_affine(data: torch.Tensor, matrix: torch.Tensor):
|
|
|
536
543
|
return F.grid_sample(data.unsqueeze(0).type(torch.float32), F.affine_grid(matrix[:, :-1,...].type(torch.float32), [1]+list(data.shape), align_corners=True), align_corners=True, mode=mode, padding_mode="reflection").squeeze(0).type(data.dtype)
|
|
537
544
|
|
|
538
545
|
|
|
546
|
+
|
|
547
|
+
def download_url(model_name: str, url: str) -> str:
|
|
548
|
+
spec = importlib.util.find_spec("konfai")
|
|
549
|
+
base_path = Path(spec.submodule_search_locations[0]) / "metric" / "models"
|
|
550
|
+
subdirs = Path(model_name).parent
|
|
551
|
+
model_dir = base_path / subdirs
|
|
552
|
+
model_dir.mkdir(exist_ok=True)
|
|
553
|
+
filetmp = model_dir / ("tmp_"+str(Path(model_name).name))
|
|
554
|
+
file = model_dir / Path(model_name).name
|
|
555
|
+
if file.exists():
|
|
556
|
+
return str(file)
|
|
557
|
+
|
|
558
|
+
try:
|
|
559
|
+
print(f"[FOCUS] Downloading {model_name} to {file}")
|
|
560
|
+
with requests.get(url+model_name, stream=True) as r:
|
|
561
|
+
r.raise_for_status()
|
|
562
|
+
total = int(r.headers.get('content-length', 0))
|
|
563
|
+
with open(filetmp, 'wb') as f:
|
|
564
|
+
with tqdm(total=total, unit='B', unit_scale=True, desc=f"Downloading {model_name}") as pbar:
|
|
565
|
+
for chunk in r.iter_content(chunk_size=8192):
|
|
566
|
+
f.write(chunk)
|
|
567
|
+
pbar.update(len(chunk))
|
|
568
|
+
shutil.copy2(filetmp, file)
|
|
569
|
+
print("Download finished.")
|
|
570
|
+
except Exception as e:
|
|
571
|
+
raise e
|
|
572
|
+
finally:
|
|
573
|
+
if filetmp.exists():
|
|
574
|
+
os.remove(filetmp)
|
|
575
|
+
return str(file)
|
|
576
|
+
|
|
539
577
|
SUPPORTED_EXTENSIONS = [
|
|
540
578
|
"mha", "mhd", # MetaImage
|
|
541
579
|
"nii", "nii.gz", # NIfTI
|
|
@@ -0,0 +1,39 @@
|
|
|
1
|
+
konfai/__init__.py,sha256=YXG-wpSEXWs6Jt3BDI77V4r89gEUNX-6lxW9btj5VYI,851
|
|
2
|
+
konfai/evaluator.py,sha256=WM78NGydV0iqKElahr-WpOCZLEJXHjjPq-LS-gi23rk,8401
|
|
3
|
+
konfai/main.py,sha256=kr7Iie_f67NF6G3dAAj9G6Z9dhn9RzbdLYpzy2WvIh8,2573
|
|
4
|
+
konfai/predictor.py,sha256=6zFGjQhG5-3hF10pvCn58UDLB6ZQ_UhMlxkOCzpM4tM,23053
|
|
5
|
+
konfai/trainer.py,sha256=Edi8l8OOfCewYM-Cd5C5rCqCaprvlfxzohd4iLkK5u0,20632
|
|
6
|
+
konfai/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
7
|
+
konfai/data/augmentation.py,sha256=mFVMpbJ8WBKGbMILdmTZYBA8k7kRDQVPOEy1A9t5QP4,32281
|
|
8
|
+
konfai/data/data_manager.py,sha256=yphkTjk4_gyr_vOGodfhu9ImDHplRe2KR6dL2ORXzCw,29000
|
|
9
|
+
konfai/data/patching.py,sha256=zAm6jjUW--lsqTBlDFICVlj7O_QOlSxStFAHS9S9H8I,14753
|
|
10
|
+
konfai/data/transform.py,sha256=yZ6aALtVEqTYghREPH8Z1deglU7M4t4OiQqMY6gpjjA,26559
|
|
11
|
+
konfai/metric/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
12
|
+
konfai/metric/measure.py,sha256=S65WGBRYMoalkwI_Le3C6K0tQmT4Qft82nUTBwhpvFU,26907
|
|
13
|
+
konfai/metric/schedulers.py,sha256=VPp7zEwD0AtQz51XG0TlutD_NrsTZs4fstC5h8A8f8U,1309
|
|
14
|
+
konfai/models/classification/convNeXt.py,sha256=hZQ_9I1lJW7DX8QnoeZ9L1br-98apV9VXGrsTAN25ds,9261
|
|
15
|
+
konfai/models/classification/resnet.py,sha256=eBoS_zNczfa1EnECKKLr0Ow26ekR_4W5eji6p9cPfHk,7978
|
|
16
|
+
konfai/models/generation/cStyleGan.py,sha256=7D8zZveDEZapFeaaDTe3wVhuDCCHLqq6cpXtl2-QaJA,8059
|
|
17
|
+
konfai/models/generation/ddpm.py,sha256=jXb0eOU3i_gMHvj9pawVAWQDjMGl6SPfCc5m2Jzdrik,13175
|
|
18
|
+
konfai/models/generation/diffusionGan.py,sha256=ZFKdRHvlajEPyw9_HCLo0Q92iChCtZd2dvi2nEF_tBI,33218
|
|
19
|
+
konfai/models/generation/gan.py,sha256=RzpGNu8BlBVDCxFR5CCmEFYUzVBs2YrDl7trIos3ssw,7861
|
|
20
|
+
konfai/models/generation/vae.py,sha256=zH8qk04z2lXDhXAVxTHiRIVMfRi3brB61lHGmEsE3kM,4675
|
|
21
|
+
konfai/models/registration/registration.py,sha256=EAE3w8aic2fPWiJz0ilqrs2kCGUQD6NWvysVfHXxA_g,6329
|
|
22
|
+
konfai/models/representation/representation.py,sha256=TiYcBBqZYySpwsRlnnBQh0QVW29Rcvb9GUjsqnCKKLM,2689
|
|
23
|
+
konfai/models/segmentation/NestedUNet.py,sha256=hDSE7BJ17IqaiHWAD5zlVG7G_KvQzuQmVnAVsiYVE6E,10248
|
|
24
|
+
konfai/models/segmentation/UNet.py,sha256=BktCRfAcCDtvGCw8wGfyZvBtT4G0Oy8teIcVgDFOurk,4078
|
|
25
|
+
konfai/network/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
26
|
+
konfai/network/blocks.py,sha256=C_dOLmXovxANWCx7P439FzR_95Zisy2I1F7FEwwz7AE,14412
|
|
27
|
+
konfai/network/network.py,sha256=LHeA7HtsVYO7BJu2_kqh23q2GIANn5ZSa4LhKMt7dJg,48642
|
|
28
|
+
konfai/utils/ITK.py,sha256=OxTieDNNYHGkn7zxJsAG-6ecRG1VYMvn1dlBbBe1DOs,13955
|
|
29
|
+
konfai/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
30
|
+
konfai/utils/config.py,sha256=f7o83ix5_oNbr2pki-Czqr-yHi-8n92ZL64nlo0XGwA,12514
|
|
31
|
+
konfai/utils/dataset.py,sha256=6ZzevdhJ7e5zlXATAVwSh9O6acKXM7gYNxkMAa5DrmM,36351
|
|
32
|
+
konfai/utils/registration.py,sha256=v1srEBOcgDnHrx0YtsK6bcj0yCMH7wNeaQ3wC7gEvOw,8898
|
|
33
|
+
konfai/utils/utils.py,sha256=mQQ6FB0Jw7Odg5GxYeTApR6lvFhXn_iz-GVGZRngGQA,24934
|
|
34
|
+
konfai-1.1.7.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
|
35
|
+
konfai-1.1.7.dist-info/METADATA,sha256=DJ2n6Rc9NF018KsJPQxFHpQmlDazUnrrP4zHtgE4_Ew,2515
|
|
36
|
+
konfai-1.1.7.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
37
|
+
konfai-1.1.7.dist-info/entry_points.txt,sha256=fG82HRN5-g39ACSOCtij_I3N6EHxfYnMR0D7TI_8pW8,81
|
|
38
|
+
konfai-1.1.7.dist-info/top_level.txt,sha256=xF470dkIlFoFqTZEOlRehKJr4WU_8OKGXrJqYm9vWKs,7
|
|
39
|
+
konfai-1.1.7.dist-info/RECORD,,
|