konfai 1.0.8__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of konfai might be problematic. Click here for more details.
- konfai/__init__.py +10 -10
- konfai/data/augmentation.py +4 -4
- konfai/data/{dataset.py → data_manager.py} +39 -43
- konfai/data/{HDF5.py → patching.py} +7 -7
- konfai/data/transform.py +3 -3
- konfai/evaluator.py +6 -6
- konfai/main.py +29 -22
- konfai/metric/measure.py +7 -4
- konfai/models/classification/convNeXt.py +1 -1
- konfai/models/classification/resnet.py +1 -1
- konfai/models/generation/cStyleGan.py +1 -1
- konfai/models/generation/ddpm.py +1 -1
- konfai/models/generation/diffusionGan.py +1 -1
- konfai/models/generation/gan.py +1 -1
- konfai/models/segmentation/NestedUNet.py +1 -1
- konfai/models/segmentation/UNet.py +1 -1
- konfai/network/network.py +12 -12
- konfai/predictor.py +10 -10
- konfai/trainer.py +133 -48
- konfai/utils/config.py +52 -19
- konfai/utils/dataset.py +1 -1
- konfai/utils/utils.py +74 -59
- {konfai-1.0.8.dist-info → konfai-1.1.0.dist-info}/METADATA +1 -1
- konfai-1.1.0.dist-info/RECORD +39 -0
- konfai-1.0.8.dist-info/RECORD +0 -39
- {konfai-1.0.8.dist-info → konfai-1.1.0.dist-info}/WHEEL +0 -0
- {konfai-1.0.8.dist-info → konfai-1.1.0.dist-info}/entry_points.txt +0 -0
- {konfai-1.0.8.dist-info → konfai-1.1.0.dist-info}/licenses/LICENSE +0 -0
- {konfai-1.0.8.dist-info → konfai-1.1.0.dist-info}/top_level.txt +0 -0
konfai/network/network.py
CHANGED
|
@@ -13,11 +13,11 @@ from torch.utils.checkpoint import checkpoint
|
|
|
13
13
|
from typing import Union
|
|
14
14
|
from enum import Enum
|
|
15
15
|
|
|
16
|
-
from konfai import
|
|
16
|
+
from konfai import KONFAI_ROOT
|
|
17
17
|
from konfai.metric.schedulers import Scheduler
|
|
18
18
|
from konfai.utils.config import config
|
|
19
19
|
from konfai.utils.utils import State, _getModule, getDevice, getGPUMemory
|
|
20
|
-
from konfai.data.
|
|
20
|
+
from konfai.data.patching import Accumulator, ModelPatch
|
|
21
21
|
|
|
22
22
|
class NetState(Enum):
|
|
23
23
|
TRAIN = 0,
|
|
@@ -40,7 +40,7 @@ class OptimizerLoader():
|
|
|
40
40
|
|
|
41
41
|
def getOptimizer(self, key: str, parameter: Iterator[torch.nn.parameter.Parameter]) -> torch.optim.Optimizer:
|
|
42
42
|
torch.optim.AdamW
|
|
43
|
-
return config("{}.Model.{}.Optimizer".format(
|
|
43
|
+
return config("{}.Model.{}.Optimizer".format(KONFAI_ROOT(), key))(getattr(importlib.import_module('torch.optim'), self.name))(parameter, config = None)
|
|
44
44
|
|
|
45
45
|
class SchedulerStep():
|
|
46
46
|
|
|
@@ -98,8 +98,8 @@ class CriterionsLoader():
|
|
|
98
98
|
for module_classpath, criterionsAttr in self.criterionsLoader.items():
|
|
99
99
|
module, name = _getModule(module_classpath, "metric.measure")
|
|
100
100
|
criterionsAttr.isTorchCriterion = module.startswith("torch")
|
|
101
|
-
criterionsAttr.sheduler = criterionsAttr.l.getShedulers("{}.Model.{}.outputsCriterions.{}.targetsCriterions.{}.criterionsLoader.{}".format(
|
|
102
|
-
criterions[config("{}.Model.{}.outputsCriterions.{}.targetsCriterions.{}.criterionsLoader.{}".format(
|
|
101
|
+
criterionsAttr.sheduler = criterionsAttr.l.getShedulers("{}.Model.{}.outputsCriterions.{}.targetsCriterions.{}.criterionsLoader.{}".format(KONFAI_ROOT(), model_classname, output_group, target_group, module_classpath))
|
|
102
|
+
criterions[config("{}.Model.{}.outputsCriterions.{}.targetsCriterions.{}.criterionsLoader.{}".format(KONFAI_ROOT(), model_classname, output_group, target_group, module_classpath))(getattr(importlib.import_module(module), name))(config = None)] = criterionsAttr
|
|
103
103
|
return criterions
|
|
104
104
|
|
|
105
105
|
class TargetCriterionsLoader():
|
|
@@ -753,14 +753,14 @@ class Network(ModuleArgsDict, ABC):
|
|
|
753
753
|
output_layer_accumulator : dict[str, Accumulator] = {}
|
|
754
754
|
output_layer_patch_indexed : dict[str, Patch_Indexed] = {}
|
|
755
755
|
it = 0
|
|
756
|
-
debug = "
|
|
756
|
+
debug = "KONFAI_DEBUG" in os.environ
|
|
757
757
|
for (nameTmp, output_layer) in self.named_forward(*inputs):
|
|
758
758
|
name = nameTmp.replace(";accu;", "")
|
|
759
759
|
if debug:
|
|
760
|
-
if "
|
|
761
|
-
os.environ["
|
|
760
|
+
if "KONFAI_DEBUG_LAST_LAYER" in os.environ:
|
|
761
|
+
os.environ["KONFAI_DEBUG_LAST_LAYER"] = "{}|{}:{}:{}".format(os.environ["KONFAI_DEBUG_LAST_LAYER"], name, getGPUMemory(output_layer.device), str(output_layer.device).replace("cuda:", ""))
|
|
762
762
|
else:
|
|
763
|
-
os.environ["
|
|
763
|
+
os.environ["KONFAI_DEBUG_LAST_LAYER"] = "{}:{}:{}".format(name, getGPUMemory(output_layer.device), str(output_layer.device).replace("cuda:", ""))
|
|
764
764
|
it += 1
|
|
765
765
|
if name in layers_name or nameTmp in layers_name:
|
|
766
766
|
if ";accu;" in nameTmp:
|
|
@@ -918,12 +918,12 @@ class Network(ModuleArgsDict, ABC):
|
|
|
918
918
|
class ModelLoader():
|
|
919
919
|
|
|
920
920
|
@config("Model")
|
|
921
|
-
def __init__(self, classpath : str = "default:segmentation.UNet") -> None:
|
|
922
|
-
self.module, self.name = _getModule(classpath
|
|
921
|
+
def __init__(self, classpath : str = "default:segmentation.UNet.UNet") -> None:
|
|
922
|
+
self.module, self.name = _getModule(classpath, "models")
|
|
923
923
|
|
|
924
924
|
def getModel(self, train : bool = True, DL_args: Union[str, None] = None, DL_without=["optimizer", "schedulers", "nb_batch_per_step", "init_type", "init_gain"]) -> Network:
|
|
925
925
|
if not DL_args:
|
|
926
|
-
DL_args="{}.Model".format(
|
|
926
|
+
DL_args="{}.Model".format(KONFAI_ROOT())
|
|
927
927
|
model = partial(getattr(importlib.import_module(self.module), self.name), config = None, DL_args=DL_args)
|
|
928
928
|
if not train:
|
|
929
929
|
model = partial(model, DL_without = DL_without)
|
konfai/predictor.py
CHANGED
|
@@ -6,12 +6,12 @@ import torch
|
|
|
6
6
|
import tqdm
|
|
7
7
|
import os
|
|
8
8
|
|
|
9
|
-
from konfai import MODELS_DIRECTORY, PREDICTIONS_DIRECTORY, CONFIG_FILE, MODEL,
|
|
9
|
+
from konfai import MODELS_DIRECTORY, PREDICTIONS_DIRECTORY, CONFIG_FILE, MODEL, KONFAI_ROOT
|
|
10
10
|
from konfai.utils.config import config
|
|
11
11
|
from konfai.utils.utils import State, get_patch_slices_from_nb_patch_per_dim, NeedDevice, _getModule, DistributedObject, DataLog, description
|
|
12
12
|
from konfai.utils.dataset import Dataset, Attribute
|
|
13
|
-
from konfai.data.
|
|
14
|
-
from konfai.data.
|
|
13
|
+
from konfai.data.data_manager import DataPrediction, DatasetIter
|
|
14
|
+
from konfai.data.patching import Accumulator, PathCombine
|
|
15
15
|
from konfai.network.network import ModelLoader, Network, NetState, CPU_Model
|
|
16
16
|
from konfai.data.transform import Transform, TransformLoader
|
|
17
17
|
|
|
@@ -52,13 +52,13 @@ class OutDataset(Dataset, NeedDevice, ABC):
|
|
|
52
52
|
|
|
53
53
|
if _transform_type is not None:
|
|
54
54
|
for classpath, transform in _transform_type.items():
|
|
55
|
-
transform = transform.getTransform(classpath, DL_args = "{}.outsDataset.{}.OutDataset.{}".format(
|
|
55
|
+
transform = transform.getTransform(classpath, DL_args = "{}.outsDataset.{}.OutDataset.{}".format(KONFAI_ROOT(), name_layer, name))
|
|
56
56
|
transform.setDatasets(datasets)
|
|
57
57
|
transform_type.append(transform)
|
|
58
58
|
|
|
59
59
|
if self._patchCombine is not None:
|
|
60
|
-
module, name = _getModule(self._patchCombine, "data.
|
|
61
|
-
self.patchCombine = getattr(importlib.import_module(module), name)(config = None, DL_args = "{}.outsDataset.{}.OutDataset".format(
|
|
60
|
+
module, name = _getModule(self._patchCombine, "data.patching")
|
|
61
|
+
self.patchCombine = getattr(importlib.import_module(module), name)(config = None, DL_args = "{}.outsDataset.{}.OutDataset".format(KONFAI_ROOT(), name_layer))
|
|
62
62
|
|
|
63
63
|
def setPatchConfig(self, patchSize: Union[list[int], None], overlap: Union[int, None], nb_data_augmentation: int) -> None:
|
|
64
64
|
if patchSize is not None and overlap is not None:
|
|
@@ -94,7 +94,7 @@ class OutDataset(Dataset, NeedDevice, ABC):
|
|
|
94
94
|
class OutSameAsGroupDataset(OutDataset):
|
|
95
95
|
|
|
96
96
|
@config("OutDataset")
|
|
97
|
-
def __init__(self, dataset_filename: str = "Dataset:
|
|
97
|
+
def __init__(self, dataset_filename: str = "./Dataset:mha", group: str = "default", sameAsGroup: str = "default", pre_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, post_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, final_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, patchCombine: Union[str, None] = None, redution: str = "mean", inverse_transform: bool = True) -> None:
|
|
98
98
|
super().__init__(dataset_filename, group, pre_transforms, post_transforms, final_transforms, patchCombine)
|
|
99
99
|
self.group_src, self.group_dest = sameAsGroup.split(":")
|
|
100
100
|
self.redution = redution
|
|
@@ -240,7 +240,7 @@ class _Predictor():
|
|
|
240
240
|
self.modelComposite.module.setState(NetState.PREDICTION)
|
|
241
241
|
desc = lambda : "Prediction : {}".format(description(self.modelComposite))
|
|
242
242
|
self.dataloader_prediction.dataset.load()
|
|
243
|
-
with tqdm.tqdm(iterable = enumerate(self.dataloader_prediction), leave=False, desc = desc(), total=len(self.dataloader_prediction), disable=self.global_rank != 0 and "
|
|
243
|
+
with tqdm.tqdm(iterable = enumerate(self.dataloader_prediction), leave=False, desc = desc(), total=len(self.dataloader_prediction), disable=self.global_rank != 0 and "KONFAI_CLUSTER" not in os.environ) as batch_iter:
|
|
244
244
|
dist.barrier()
|
|
245
245
|
for it, data_dict in batch_iter:
|
|
246
246
|
input = self.getInput(data_dict)
|
|
@@ -322,7 +322,7 @@ class Predictor(DistributedObject):
|
|
|
322
322
|
gpu_checkpoints: Union[list[str], None] = None,
|
|
323
323
|
outsDataset: Union[dict[str, OutDatasetLoader], None] = {"default:Default" : OutDatasetLoader()},
|
|
324
324
|
images_log: list[str] = []) -> None:
|
|
325
|
-
if os.environ["
|
|
325
|
+
if os.environ["KONFAI_CONFIG_MODE"] != "Done":
|
|
326
326
|
exit(0)
|
|
327
327
|
super().__init__(train_name)
|
|
328
328
|
self.manual_seed = manual_seed
|
|
@@ -374,7 +374,7 @@ class Predictor(DistributedObject):
|
|
|
374
374
|
for dataset_filename in self.datasets_filename:
|
|
375
375
|
path = self.predict_path +dataset_filename
|
|
376
376
|
if os.path.exists(path):
|
|
377
|
-
if os.environ["
|
|
377
|
+
if os.environ["KONFAI_OVERWRITE"] != "True":
|
|
378
378
|
accept = builtins.input("The prediction {} already exists ! Do you want to overwrite it (yes,no) : ".format(path))
|
|
379
379
|
if accept != "yes":
|
|
380
380
|
return
|
konfai/trainer.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
|
1
1
|
import torch
|
|
2
|
-
import torch.optim.adamw
|
|
3
2
|
from torch.utils.data import DataLoader
|
|
4
3
|
import tqdm
|
|
5
4
|
import numpy as np
|
|
@@ -12,16 +11,79 @@ from torch.utils.tensorboard.writer import SummaryWriter
|
|
|
12
11
|
from torch.optim.swa_utils import AveragedModel
|
|
13
12
|
import torch.distributed as dist
|
|
14
13
|
|
|
15
|
-
from konfai import MODELS_DIRECTORY, CHECKPOINTS_DIRECTORY, STATISTICS_DIRECTORY, SETUPS_DIRECTORY, CONFIG_FILE, MODEL, DATE,
|
|
16
|
-
from konfai.data.
|
|
14
|
+
from konfai import MODELS_DIRECTORY, CHECKPOINTS_DIRECTORY, STATISTICS_DIRECTORY, SETUPS_DIRECTORY, CONFIG_FILE, MODEL, DATE, KONFAI_STATE
|
|
15
|
+
from konfai.data.data_manager import DataTrain
|
|
17
16
|
from konfai.utils.config import config
|
|
18
17
|
from konfai.utils.utils import State, DataLog, DistributedObject, description
|
|
19
18
|
from konfai.network.network import Network, ModelLoader, NetState, CPU_Model
|
|
20
19
|
|
|
21
20
|
|
|
21
|
+
class EarlyStoppingBase:
|
|
22
|
+
|
|
23
|
+
def __init__(self):
|
|
24
|
+
pass
|
|
25
|
+
|
|
26
|
+
def isStopped(self) -> bool:
|
|
27
|
+
return False
|
|
28
|
+
|
|
29
|
+
def getScore(self, values: dict[str, float]):
|
|
30
|
+
return sum([i for i in values.values()])
|
|
31
|
+
|
|
32
|
+
def __call__(self, current_score: float) -> bool:
|
|
33
|
+
return False
|
|
34
|
+
|
|
35
|
+
class EarlyStopping(EarlyStoppingBase):
|
|
36
|
+
|
|
37
|
+
@config("EarlyStopping")
|
|
38
|
+
def __init__(self, monitor: Union[list[str], None] = [], patience=10, min_delta=0.0, mode="min"):
|
|
39
|
+
super().__init__()
|
|
40
|
+
self.monitor = [] if monitor is None else monitor
|
|
41
|
+
self.patience = patience
|
|
42
|
+
self.min_delta = min_delta
|
|
43
|
+
self.mode = mode
|
|
44
|
+
self.counter = 0
|
|
45
|
+
self.best_score = None
|
|
46
|
+
self.early_stop = False
|
|
47
|
+
|
|
48
|
+
def isStopped(self) -> bool:
|
|
49
|
+
return self.early_stop
|
|
50
|
+
|
|
51
|
+
def getScore(self, values: dict[str, float]):
|
|
52
|
+
if len(self.monitor) == 0:
|
|
53
|
+
return super().getScore(values)
|
|
54
|
+
for v in self.monitor:
|
|
55
|
+
if v not in values.keys():
|
|
56
|
+
raise ValueError(
|
|
57
|
+
"[EarlyStopping] Metric '{}' specified in `monitor` not found in logged values. "
|
|
58
|
+
"Available keys: {}. Please check your configuration.".format(v, list(values.keys())))
|
|
59
|
+
return sum([i for v, i in values.items() if v in self.monitor])
|
|
60
|
+
|
|
61
|
+
def __call__(self, current_score: float) -> bool:
|
|
62
|
+
if self.best_score is None:
|
|
63
|
+
self.best_score = current_score
|
|
64
|
+
return False
|
|
65
|
+
|
|
66
|
+
if self.mode == "min":
|
|
67
|
+
improvement = self.best_score - current_score
|
|
68
|
+
elif self.mode == "max":
|
|
69
|
+
improvement = current_score - self.best_score
|
|
70
|
+
else:
|
|
71
|
+
raise ValueError("Mode must be 'min' or 'max'.")
|
|
72
|
+
|
|
73
|
+
if improvement > self.min_delta:
|
|
74
|
+
self.best_score = current_score
|
|
75
|
+
self.counter = 0
|
|
76
|
+
else:
|
|
77
|
+
self.counter += 1
|
|
78
|
+
|
|
79
|
+
if self.counter >= self.patience:
|
|
80
|
+
self.early_stop = True
|
|
81
|
+
|
|
82
|
+
return self.early_stop
|
|
83
|
+
|
|
22
84
|
class _Trainer():
|
|
23
85
|
|
|
24
|
-
def __init__(self, world_size: int, global_rank: int, local_rank: int, size: int, train_name: str, data_log: Union[list[str], None] , save_checkpoint_mode: str, epochs: int, epoch: int, autocast: bool, it_validation: Union[int, None], it: int, model: Union[DDP, CPU_Model], modelEMA: AveragedModel, dataloader_training: DataLoader, dataloader_validation: Union[DataLoader, None] = None) -> None:
|
|
86
|
+
def __init__(self, world_size: int, global_rank: int, local_rank: int, size: int, train_name: str, early_stopping: EarlyStopping, data_log: Union[list[str], None] , save_checkpoint_mode: str, epochs: int, epoch: int, autocast: bool, it_validation: Union[int, None], it: int, model: Union[DDP, CPU_Model], modelEMA: AveragedModel, dataloader_training: DataLoader, dataloader_validation: Union[DataLoader, None] = None) -> None:
|
|
25
87
|
self.world_size = world_size
|
|
26
88
|
self.global_rank = global_rank
|
|
27
89
|
self.local_rank = local_rank
|
|
@@ -36,7 +98,8 @@ class _Trainer():
|
|
|
36
98
|
self.dataloader_validation = dataloader_validation
|
|
37
99
|
self.autocast = autocast
|
|
38
100
|
self.modelEMA = modelEMA
|
|
39
|
-
|
|
101
|
+
self.early_stopping = EarlyStoppingBase() if early_stopping is None else early_stopping
|
|
102
|
+
|
|
40
103
|
self.it_validation = it_validation
|
|
41
104
|
if self.it_validation is None:
|
|
42
105
|
self.it_validation = len(dataloader_training)
|
|
@@ -59,6 +122,8 @@ class _Trainer():
|
|
|
59
122
|
for self.epoch in epoch_tqdm:
|
|
60
123
|
self.dataloader_training.dataset.load()
|
|
61
124
|
self.train()
|
|
125
|
+
if self.early_stopping.isStopped():
|
|
126
|
+
break
|
|
62
127
|
self.dataloader_training.dataset.resetAugmentation()
|
|
63
128
|
|
|
64
129
|
def getInput(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int, str, bool]]) -> dict[tuple[str, bool], torch.Tensor]:
|
|
@@ -72,7 +137,7 @@ class _Trainer():
|
|
|
72
137
|
self.modelEMA.module.setState(NetState.TRAIN)
|
|
73
138
|
|
|
74
139
|
desc = lambda : "Training : {}".format(description(self.model, self.modelEMA))
|
|
75
|
-
with tqdm.tqdm(iterable = enumerate(self.dataloader_training), desc = desc(), total=len(self.dataloader_training), leave=False, disable=self.global_rank != 0 and "
|
|
140
|
+
with tqdm.tqdm(iterable = enumerate(self.dataloader_training), desc = desc(), total=len(self.dataloader_training), leave=False, disable=self.global_rank != 0 and "KONFAI_CLUSTER" not in os.environ) as batch_iter:
|
|
76
141
|
for _, data_dict in batch_iter:
|
|
77
142
|
with torch.amp.autocast('cuda', enabled=self.autocast):
|
|
78
143
|
input = self.getInput(data_dict)
|
|
@@ -88,7 +153,11 @@ class _Trainer():
|
|
|
88
153
|
if self.dataloader_validation is not None:
|
|
89
154
|
loss = self._validate()
|
|
90
155
|
self.model.module.update_lr()
|
|
91
|
-
self.
|
|
156
|
+
score = self.early_stopping.getScore(loss)
|
|
157
|
+
self.checkpoint_save(score)
|
|
158
|
+
if self.early_stopping(score):
|
|
159
|
+
break
|
|
160
|
+
|
|
92
161
|
|
|
93
162
|
batch_iter.set_description(desc())
|
|
94
163
|
|
|
@@ -103,7 +172,7 @@ class _Trainer():
|
|
|
103
172
|
desc = lambda : "Validation : {}".format(description(self.model, self.modelEMA))
|
|
104
173
|
data_dict = None
|
|
105
174
|
self.dataloader_validation.dataset.load()
|
|
106
|
-
with tqdm.tqdm(iterable = enumerate(self.dataloader_validation), desc = desc(), total=len(self.dataloader_validation), leave=False, disable=self.global_rank != 0 and "
|
|
175
|
+
with tqdm.tqdm(iterable = enumerate(self.dataloader_validation), desc = desc(), total=len(self.dataloader_validation), leave=False, disable=self.global_rank != 0 and "KONFAI_CLUSTER" not in os.environ) as batch_iter:
|
|
107
176
|
for _, data_dict in batch_iter:
|
|
108
177
|
input = self.getInput(data_dict)
|
|
109
178
|
self.model(input)
|
|
@@ -120,41 +189,53 @@ class _Trainer():
|
|
|
120
189
|
return self._validation_log(data_dict)
|
|
121
190
|
|
|
122
191
|
def checkpoint_save(self, loss: float) -> None:
|
|
123
|
-
if self.global_rank
|
|
124
|
-
|
|
125
|
-
last_loss = None
|
|
126
|
-
if os.path.exists(path) and os.listdir(path):
|
|
127
|
-
name = sorted(os.listdir(path))[-1]
|
|
128
|
-
state_dict = torch.load(path+name, weights_only=False)
|
|
129
|
-
last_loss = state_dict["loss"]
|
|
130
|
-
if self.save_checkpoint_mode == "BEST":
|
|
131
|
-
if last_loss >= loss:
|
|
132
|
-
os.remove(path+name)
|
|
133
|
-
|
|
134
|
-
if self.save_checkpoint_mode != "BEST" or (last_loss is None or last_loss >= loss):
|
|
135
|
-
name = DATE()+".pt"
|
|
136
|
-
if not os.path.exists(path):
|
|
137
|
-
os.makedirs(path)
|
|
138
|
-
|
|
139
|
-
save_dict = {
|
|
140
|
-
"epoch": self.epoch,
|
|
141
|
-
"it": self.it,
|
|
142
|
-
"loss": loss,
|
|
143
|
-
"Model": self.model.module.state_dict()}
|
|
192
|
+
if self.global_rank != 0:
|
|
193
|
+
return
|
|
144
194
|
|
|
145
|
-
|
|
146
|
-
|
|
195
|
+
path = CHECKPOINTS_DIRECTORY()+self.train_name+"/"
|
|
196
|
+
os.makedirs(path, exist_ok=True)
|
|
197
|
+
|
|
198
|
+
name = DATE() + ".pt"
|
|
199
|
+
save_path = os.path.join(path, name)
|
|
147
200
|
|
|
148
|
-
|
|
149
|
-
|
|
201
|
+
save_dict = {
|
|
202
|
+
"epoch": self.epoch,
|
|
203
|
+
"it": self.it,
|
|
204
|
+
"loss": loss,
|
|
205
|
+
"Model": self.model.module.state_dict()
|
|
206
|
+
}
|
|
207
|
+
|
|
208
|
+
if self.modelEMA is not None:
|
|
209
|
+
save_dict["Model_EMA"] = self.modelEMA.module.state_dict()
|
|
210
|
+
|
|
211
|
+
save_dict.update({'{}_optimizer_state_dict'.format(name): network.optimizer.state_dict() for name, network in self.model.module.getNetworks().items() if network.optimizer is not None})
|
|
212
|
+
torch.save(save_dict, save_path)
|
|
213
|
+
|
|
214
|
+
if self.save_checkpoint_mode == "BEST":
|
|
215
|
+
all_checkpoints = sorted([
|
|
216
|
+
os.path.join(path, f)
|
|
217
|
+
for f in os.listdir(path) if f.endswith(".pt")
|
|
218
|
+
])
|
|
219
|
+
best_ckpt = None
|
|
220
|
+
best_loss = float('inf')
|
|
221
|
+
|
|
222
|
+
for f in all_checkpoints:
|
|
223
|
+
d = torch.load(f, weights_only=False)
|
|
224
|
+
if d.get("loss", float("inf")) < best_loss:
|
|
225
|
+
best_loss = d["loss"]
|
|
226
|
+
best_ckpt = f
|
|
227
|
+
|
|
228
|
+
for f in all_checkpoints:
|
|
229
|
+
if f != best_ckpt and f != save_path:
|
|
230
|
+
os.remove(f)
|
|
150
231
|
|
|
151
232
|
@torch.no_grad()
|
|
152
|
-
def _log(self, type_log: str, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]) -> float:
|
|
233
|
+
def _log(self, type_log: str, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]) -> dict[str, float]:
|
|
153
234
|
models: dict[str, Network] = {"" : self.model.module}
|
|
154
235
|
if self.modelEMA is not None:
|
|
155
236
|
models["_EMA"] = self.modelEMA.module
|
|
156
237
|
|
|
157
|
-
measures = DistributedObject.getMeasure(self.world_size, self.global_rank, self.local_rank*self.size+self.size-1, models, self.it_validation if type_log == "
|
|
238
|
+
measures = DistributedObject.getMeasure(self.world_size, self.global_rank, self.local_rank*self.size+self.size-1, models, self.it_validation if type_log == "Training" else len(self.dataloader_validation))
|
|
158
239
|
|
|
159
240
|
if self.global_rank == 0:
|
|
160
241
|
images_log = []
|
|
@@ -178,34 +259,36 @@ class _Trainer():
|
|
|
178
259
|
for name, layer, _ in model.get_layers([v.to(0) for k, v in self.getInput(data_dict).items() if k[1]], images_log):
|
|
179
260
|
self.data_log[name][0](self.tb, "{}/{}{}".format(type_log, name, label), layer[:self.data_log[name][1]].detach().cpu().numpy(), self.it)
|
|
180
261
|
|
|
181
|
-
if type_log == "
|
|
262
|
+
if type_log == "Training":
|
|
182
263
|
for name, network in self.model.module.getNetworks().items():
|
|
183
264
|
if network.optimizer is not None:
|
|
184
265
|
self.tb.add_scalar("{}/{}/Learning Rate".format(type_log, name), network.optimizer.param_groups[0]['lr'], self.it)
|
|
185
266
|
|
|
186
267
|
if self.global_rank == 0:
|
|
187
|
-
loss =
|
|
268
|
+
loss = {}
|
|
188
269
|
for name, network in self.model.module.getNetworks().items():
|
|
189
270
|
if network.measure is not None:
|
|
190
|
-
loss.
|
|
191
|
-
|
|
271
|
+
loss.update({k : v[1] for k, v in measures["{}{}".format(name, label)][0].items()})
|
|
272
|
+
loss.update({k : v[1] for k, v in measures["{}{}".format(name, label)][1].items()})
|
|
273
|
+
return loss
|
|
192
274
|
return None
|
|
193
275
|
|
|
194
276
|
@torch.no_grad()
|
|
195
|
-
def _train_log(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]) -> float:
|
|
196
|
-
return self._log("
|
|
277
|
+
def _train_log(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]) -> dict[str, float]:
|
|
278
|
+
return self._log("Training", data_dict)
|
|
197
279
|
|
|
198
280
|
@torch.no_grad()
|
|
199
|
-
def _validation_log(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]) -> float:
|
|
281
|
+
def _validation_log(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]) -> dict[str, float]:
|
|
200
282
|
return self._log("Validation", data_dict)
|
|
201
283
|
|
|
284
|
+
|
|
202
285
|
class Trainer(DistributedObject):
|
|
203
286
|
|
|
204
287
|
@config("Trainer")
|
|
205
288
|
def __init__( self,
|
|
206
289
|
model : ModelLoader = ModelLoader(),
|
|
207
290
|
dataset : DataTrain = DataTrain(),
|
|
208
|
-
train_name : str = "default:
|
|
291
|
+
train_name : str = "default:TRAIN_01",
|
|
209
292
|
manual_seed : Union[int, None] = None,
|
|
210
293
|
epochs: int = 100,
|
|
211
294
|
it_validation : Union[int, None] = None,
|
|
@@ -214,8 +297,9 @@ class Trainer(DistributedObject):
|
|
|
214
297
|
gpu_checkpoints: Union[list[str], None] = None,
|
|
215
298
|
ema_decay : float = 0,
|
|
216
299
|
data_log: Union[list[str], None] = None,
|
|
300
|
+
early_stopping: Union[EarlyStopping, None] = None,
|
|
217
301
|
save_checkpoint_mode: str= "BEST") -> None:
|
|
218
|
-
if os.environ["
|
|
302
|
+
if os.environ["KONFAI_CONFIG_MODE"] != "Done":
|
|
219
303
|
exit(0)
|
|
220
304
|
super().__init__(train_name)
|
|
221
305
|
self.manual_seed = manual_seed
|
|
@@ -223,6 +307,7 @@ class Trainer(DistributedObject):
|
|
|
223
307
|
self.autocast = autocast
|
|
224
308
|
self.epochs = epochs
|
|
225
309
|
self.epoch = 0
|
|
310
|
+
self.early_stopping = early_stopping
|
|
226
311
|
self.it = 0
|
|
227
312
|
self.it_validation = it_validation
|
|
228
313
|
self.model = model.getModel(train=True)
|
|
@@ -292,9 +377,9 @@ class Trainer(DistributedObject):
|
|
|
292
377
|
return (1-self.ema_decay) * averaged_model_parameter + self.ema_decay * model_parameter
|
|
293
378
|
|
|
294
379
|
def setup(self, world_size: int):
|
|
295
|
-
state = State._member_map_[
|
|
296
|
-
if state != State.RESUME and os.path.exists(
|
|
297
|
-
if os.environ["
|
|
380
|
+
state = State._member_map_[KONFAI_STATE()]
|
|
381
|
+
if state != State.RESUME and os.path.exists(CHECKPOINTS_DIRECTORY()+self.name+"/"):
|
|
382
|
+
if os.environ["KONFAI_OVERWRITE"] != "True":
|
|
298
383
|
accept = input("The model {} already exists ! Do you want to overwrite it (yes,no) : ".format(self.name))
|
|
299
384
|
if accept != "yes":
|
|
300
385
|
return
|
|
@@ -326,5 +411,5 @@ class Trainer(DistributedObject):
|
|
|
326
411
|
model = DDP(model, static_graph=True) if torch.cuda.is_available() else CPU_Model(model)
|
|
327
412
|
if self.modelEMA is not None:
|
|
328
413
|
self.modelEMA.module = Network.to(self.modelEMA.module, local_rank)
|
|
329
|
-
with _Trainer(world_size, global_rank, local_rank, self.size, self.name, self.data_log, self.save_checkpoint_mode, self.epochs, self.epoch, self.autocast, self.it_validation, self.it, model, self.modelEMA, *dataloaders) as t:
|
|
414
|
+
with _Trainer(world_size, global_rank, local_rank, self.size, self.name, self.early_stopping, self.data_log, self.save_checkpoint_mode, self.epochs, self.epoch, self.autocast, self.it_validation, self.it, model, self.modelEMA, *dataloaders) as t:
|
|
330
415
|
t.run()
|
konfai/utils/config.py
CHANGED
|
@@ -3,9 +3,8 @@ import ruamel.yaml
|
|
|
3
3
|
import inspect
|
|
4
4
|
import collections
|
|
5
5
|
from copy import deepcopy
|
|
6
|
-
from typing import Union
|
|
6
|
+
from typing import Union, Literal, get_origin, get_args
|
|
7
7
|
import torch
|
|
8
|
-
|
|
9
8
|
from konfai import CONFIG_FILE
|
|
10
9
|
|
|
11
10
|
yaml = ruamel.yaml.YAML()
|
|
@@ -26,7 +25,7 @@ class Config():
|
|
|
26
25
|
if not os.path.exists(self.filename):
|
|
27
26
|
result = input("Create a new config file ? [no,yes,interactive] : ")
|
|
28
27
|
if result in ["yes", "interactive"]:
|
|
29
|
-
os.environ["
|
|
28
|
+
os.environ["KONFAI_CONFIG_MODE"] = "interactive" if result == "interactive" else "default"
|
|
30
29
|
else:
|
|
31
30
|
exit(0)
|
|
32
31
|
with open(self.filename, "w") as f:
|
|
@@ -69,7 +68,7 @@ class Config():
|
|
|
69
68
|
|
|
70
69
|
def __exit__(self, type, value, traceback) -> None:
|
|
71
70
|
self.yml.close()
|
|
72
|
-
if os.environ["
|
|
71
|
+
if os.environ["KONFAI_CONFIG_MODE"] == "remove":
|
|
73
72
|
if os.path.exists(CONFIG_FILE()):
|
|
74
73
|
os.remove(CONFIG_FILE())
|
|
75
74
|
return
|
|
@@ -87,22 +86,22 @@ class Config():
|
|
|
87
86
|
except:
|
|
88
87
|
result = input("\nKeep a default configuration file ? (yes,no) : ")
|
|
89
88
|
if result == "yes":
|
|
90
|
-
os.environ["
|
|
89
|
+
os.environ["KONFAI_CONFIG_MODE"] = "default"
|
|
91
90
|
else:
|
|
92
|
-
os.environ["
|
|
91
|
+
os.environ["KONFAI_CONFIG_MODE"] = "remove"
|
|
93
92
|
exit(0)
|
|
94
93
|
return default.split(":")[1] if len(default.split(":")) > 1 else default
|
|
95
94
|
|
|
96
95
|
@staticmethod
|
|
97
96
|
def _getInputDefault(name : str, default : Union[str, None], isList : bool = False) -> Union[list[Union[str, None]], str, None]:
|
|
98
97
|
if isinstance(default, str) and (default == "default" or (len(default.split(":")) > 1 and default.split(":")[0] == "default")):
|
|
99
|
-
if os.environ["
|
|
98
|
+
if os.environ["KONFAI_CONFIG_MODE"] == "interactive":
|
|
100
99
|
if isList:
|
|
101
100
|
list_tmp = []
|
|
102
101
|
key_tmp = "OK"
|
|
103
|
-
while key_tmp != "!" and os.environ["
|
|
102
|
+
while (key_tmp != "!" and key_tmp != " ") and os.environ["KONFAI_CONFIG_MODE"] == "interactive":
|
|
104
103
|
key_tmp = Config._getInput(name, default)
|
|
105
|
-
if key_tmp != "!":
|
|
104
|
+
if (key_tmp != "!" and key_tmp != " "):
|
|
106
105
|
if key_tmp == "":
|
|
107
106
|
key_tmp = default.split(":")[1] if len(default.split(":")) > 1 else default
|
|
108
107
|
list_tmp.append(key_tmp)
|
|
@@ -134,6 +133,7 @@ class Config():
|
|
|
134
133
|
list_tmp = []
|
|
135
134
|
for key in value_config:
|
|
136
135
|
list_tmp.extend(Config._getInputDefault(name, key, isList=True))
|
|
136
|
+
|
|
137
137
|
|
|
138
138
|
value = list_tmp
|
|
139
139
|
value_config = list_tmp
|
|
@@ -155,7 +155,7 @@ class Config():
|
|
|
155
155
|
dict_value[key] = value_tmp
|
|
156
156
|
value = dict_value
|
|
157
157
|
if isinstance(self.config, str):
|
|
158
|
-
os.environ['
|
|
158
|
+
os.environ['KONFAI_CONFIG_VARIABLE'] = "True"
|
|
159
159
|
return None
|
|
160
160
|
|
|
161
161
|
self.config[name] = value_config if value_config is not None else "None"
|
|
@@ -169,17 +169,30 @@ def config(key : Union[str, None] = None):
|
|
|
169
169
|
if "config" in kwargs:
|
|
170
170
|
filename = kwargs["config"]
|
|
171
171
|
if filename == None:
|
|
172
|
-
filename = os.environ['
|
|
172
|
+
filename = os.environ['KONFAI_CONFIG_FILE']
|
|
173
173
|
else:
|
|
174
|
-
os.environ['
|
|
174
|
+
os.environ['KONFAI_CONFIG_FILE'] = filename
|
|
175
175
|
key_tmp = kwargs["DL_args"]+("."+key if key is not None else "") if "DL_args" in kwargs else key
|
|
176
176
|
without = kwargs["DL_without"] if "DL_without" in kwargs else []
|
|
177
|
-
os.environ['
|
|
177
|
+
os.environ['KONFAI_CONFIG_PATH'] = key_tmp
|
|
178
178
|
with Config(filename, key_tmp) as config:
|
|
179
|
-
os.environ['
|
|
179
|
+
os.environ['KONFAI_CONFIG_VARIABLE'] = "False"
|
|
180
180
|
kwargs = {}
|
|
181
181
|
for param in list(inspect.signature(function).parameters.values())[len(args):]:
|
|
182
|
+
|
|
182
183
|
annotation = param.annotation
|
|
184
|
+
# --- support Literal ---
|
|
185
|
+
if get_origin(annotation) is Literal:
|
|
186
|
+
allowed_values = get_args(annotation)
|
|
187
|
+
default_value = param.default if param.default != inspect._empty else allowed_values[0]
|
|
188
|
+
value = config.getValue(param.name, f"default:{default_value}")
|
|
189
|
+
if value not in allowed_values:
|
|
190
|
+
raise ValueError(
|
|
191
|
+
f"[Config] Invalid value '{value}' for parameter '{param.name}'. "
|
|
192
|
+
f"Expected one of: {allowed_values}."
|
|
193
|
+
)
|
|
194
|
+
kwargs[param.name] = value
|
|
195
|
+
continue
|
|
183
196
|
if str(annotation).startswith("typing.Union") or str(annotation).startswith("typing.Optional"):
|
|
184
197
|
for i in annotation.__args__:
|
|
185
198
|
annotation = i
|
|
@@ -188,8 +201,24 @@ def config(key : Union[str, None] = None):
|
|
|
188
201
|
continue
|
|
189
202
|
if not annotation == inspect._empty:
|
|
190
203
|
if annotation not in [int, str, bool, float, torch.Tensor]:
|
|
191
|
-
if str(annotation).startswith("list") or str(annotation).startswith("tuple") or str(annotation).startswith("typing.Tuple"):
|
|
192
|
-
|
|
204
|
+
if str(annotation).startswith("list") or str(annotation).startswith("tuple") or str(annotation).startswith("typing.Tuple") or str(annotation).startswith("typing.List"):
|
|
205
|
+
elem_type = annotation.__args__[0]
|
|
206
|
+
values = config.getValue(param.name, param.default)
|
|
207
|
+
if getattr(elem_type, '__origin__', None) is Union:
|
|
208
|
+
valid_types = elem_type.__args__
|
|
209
|
+
result = []
|
|
210
|
+
for v in values:
|
|
211
|
+
for t in valid_types:
|
|
212
|
+
try:
|
|
213
|
+
if t == torch.Tensor and not isinstance(v, torch.Tensor):
|
|
214
|
+
v = torch.tensor(v)
|
|
215
|
+
result.append(t(v) if t != torch.Tensor else v)
|
|
216
|
+
break
|
|
217
|
+
except Exception:
|
|
218
|
+
continue
|
|
219
|
+
kwargs[param.name] = result
|
|
220
|
+
|
|
221
|
+
elif annotation.__args__[0] in [int, str, bool, float]:
|
|
193
222
|
values = config.getValue(param.name, param.default)
|
|
194
223
|
kwargs[param.name] = values
|
|
195
224
|
else:
|
|
@@ -204,9 +233,13 @@ def config(key : Union[str, None] = None):
|
|
|
204
233
|
else:
|
|
205
234
|
raise ConfigError()
|
|
206
235
|
else:
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
236
|
+
try:
|
|
237
|
+
kwargs[param.name] = annotation(config = filename, DL_args = key_tmp)
|
|
238
|
+
except Exception as e:
|
|
239
|
+
raise ValueError("[Config] Failed to instantiate {} with type {}, error {} ".format(param.name, annotation.__name__, e))
|
|
240
|
+
|
|
241
|
+
if os.environ['KONFAI_CONFIG_VARIABLE'] == "True":
|
|
242
|
+
os.environ['KONFAI_CONFIG_VARIABLE'] = "False"
|
|
210
243
|
kwargs[param.name] = None
|
|
211
244
|
else:
|
|
212
245
|
kwargs[param.name] = config.getValue(param.name, param.default)
|
konfai/utils/dataset.py
CHANGED
|
@@ -750,7 +750,7 @@ class Dataset():
|
|
|
750
750
|
else:
|
|
751
751
|
with Dataset.File(self.filename, True, self.format) as file:
|
|
752
752
|
names = file.getNames(groups)
|
|
753
|
-
return [name for i, name in enumerate(names) if index is None or i in index]
|
|
753
|
+
return [name for i, name in enumerate(sorted(names)) if index is None or i in index]
|
|
754
754
|
|
|
755
755
|
def getInfos(self, groups: str, name: str) -> tuple[list[int], Attribute]:
|
|
756
756
|
if self.is_directory:
|