konfai 1.1.0__py3-none-any.whl → 1.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of konfai might be problematic. Click here for more details.
- konfai/__init__.py +1 -1
- konfai/data/augmentation.py +2 -2
- konfai/data/data_manager.py +145 -42
- konfai/data/patching.py +39 -13
- konfai/data/transform.py +48 -21
- konfai/evaluator.py +24 -7
- konfai/main.py +7 -5
- konfai/models/registration/registration.py +0 -1
- konfai/network/blocks.py +0 -1
- konfai/network/network.py +29 -16
- konfai/predictor.py +24 -21
- konfai/trainer.py +15 -15
- konfai/utils/config.py +12 -12
- konfai/utils/dataset.py +27 -2
- konfai/utils/utils.py +108 -24
- {konfai-1.1.0.dist-info → konfai-1.1.2.dist-info}/METADATA +1 -1
- {konfai-1.1.0.dist-info → konfai-1.1.2.dist-info}/RECORD +21 -21
- {konfai-1.1.0.dist-info → konfai-1.1.2.dist-info}/WHEEL +0 -0
- {konfai-1.1.0.dist-info → konfai-1.1.2.dist-info}/entry_points.txt +0 -0
- {konfai-1.1.0.dist-info → konfai-1.1.2.dist-info}/licenses/LICENSE +0 -0
- {konfai-1.1.0.dist-info → konfai-1.1.2.dist-info}/top_level.txt +0 -0
konfai/evaluator.py
CHANGED
|
@@ -9,7 +9,7 @@ import builtins
|
|
|
9
9
|
import importlib
|
|
10
10
|
from konfai import EVALUATIONS_DIRECTORY, PREDICTIONS_DIRECTORY, KONFAI_ROOT, CONFIG_FILE
|
|
11
11
|
from konfai.utils.config import config
|
|
12
|
-
from konfai.utils.utils import _getModule, DistributedObject, synchronize_data
|
|
12
|
+
from konfai.utils.utils import _getModule, DistributedObject, synchronize_data, EvaluatorError
|
|
13
13
|
from konfai.data.data_manager import DataMetric
|
|
14
14
|
|
|
15
15
|
class CriterionsAttr():
|
|
@@ -54,7 +54,8 @@ class Statistics():
|
|
|
54
54
|
if name_dataset not in self.measures:
|
|
55
55
|
self.measures[name_dataset] = {}
|
|
56
56
|
self.measures[name_dataset][name] = value
|
|
57
|
-
|
|
57
|
+
|
|
58
|
+
@staticmethod
|
|
58
59
|
def getStatistic(values: list[float]) -> dict[str, float]:
|
|
59
60
|
return {"max": np.max(values), "min": np.min(values), "std": np.std(values), "25pc": np.percentile(values, 25), "50pc": np.percentile(values, 50), "75pc": np.percentile(values, 75), "mean": np.mean(values), "count": len(values)}
|
|
60
61
|
|
|
@@ -91,7 +92,6 @@ class Evaluator(DistributedObject):
|
|
|
91
92
|
exit(0)
|
|
92
93
|
super().__init__(train_name)
|
|
93
94
|
self.metric_path = EVALUATIONS_DIRECTORY()+self.name+"/"
|
|
94
|
-
self.predict_path = PREDICTIONS_DIRECTORY()+self.name+"/"
|
|
95
95
|
self.metricsLoader = metrics
|
|
96
96
|
self.dataset = dataset
|
|
97
97
|
self.metrics = {k: v.getTargetsCriterions(k) for k, v in self.metricsLoader.items()}
|
|
@@ -102,10 +102,10 @@ class Evaluator(DistributedObject):
|
|
|
102
102
|
result = {}
|
|
103
103
|
for output_group in self.metrics:
|
|
104
104
|
for target_group in self.metrics[output_group]:
|
|
105
|
-
targets = [data_dict[group][0] for group in target_group.split("/") if group in data_dict]
|
|
105
|
+
targets = [data_dict[group][0].to(0) if torch.cuda.is_available() else data_dict[group][0] for group in target_group.split("/") if group in data_dict]
|
|
106
106
|
name = data_dict[output_group][1][0]
|
|
107
107
|
for metric in self.metrics[output_group][target_group]:
|
|
108
|
-
result["{}:{}:{}".format(output_group, target_group, metric.__class__.__name__)] = metric(data_dict[output_group][0], *targets).item()
|
|
108
|
+
result["{}:{}:{}".format(output_group, target_group, metric.__class__.__name__)] = metric(data_dict[output_group][0].to(0) if torch.cuda.is_available() else data_dict[output_group][0], *targets).item()
|
|
109
109
|
statistics.add(result, name)
|
|
110
110
|
return result
|
|
111
111
|
|
|
@@ -126,9 +126,26 @@ class Evaluator(DistributedObject):
|
|
|
126
126
|
|
|
127
127
|
self.dataloader = self.dataset.getData(world_size)
|
|
128
128
|
|
|
129
|
+
groupsDest = [group for groups in self.dataset.groups_src.values() for group in groups]
|
|
130
|
+
|
|
131
|
+
missing_outputs = set(self.metrics.keys()) - set(groupsDest)
|
|
132
|
+
if missing_outputs:
|
|
133
|
+
raise EvaluatorError(
|
|
134
|
+
f"The following metric output groups are missing from 'groupsDest': {sorted(missing_outputs)}. ",
|
|
135
|
+
f"Available groups: {sorted(groupsDest)}"
|
|
136
|
+
)
|
|
137
|
+
|
|
138
|
+
target_groups = {target for targets in self.metrics.values() for target in targets}
|
|
139
|
+
missing_targets = target_groups - set(groupsDest)
|
|
140
|
+
if missing_targets:
|
|
141
|
+
raise EvaluatorError(
|
|
142
|
+
f"The following metric target groups are missing from 'groupsDest': {sorted(missing_targets)}. ",
|
|
143
|
+
f"Available groups: {sorted(groupsDest)}"
|
|
144
|
+
)
|
|
145
|
+
|
|
129
146
|
def run_process(self, world_size: int, global_rank: int, gpu: int, dataloaders: list[DataLoader]):
|
|
130
147
|
description = lambda measure : "Metric TRAIN : {} ".format(" | ".join("{}: {:.2f}".format(k, v) for k, v in measure.items()) if measure is not None else "")
|
|
131
|
-
with tqdm.tqdm(iterable = enumerate(dataloaders[0]), leave=
|
|
148
|
+
with tqdm.tqdm(iterable = enumerate(dataloaders[0]), leave=True, desc = description(None), total=len(dataloaders[0]), ncols=0) as batch_iter:
|
|
132
149
|
for _, data_dict in batch_iter:
|
|
133
150
|
batch_iter.set_description(description(self.update({k: (v[0], v[4]) for k,v in data_dict.items()}, self.statistics_train)))
|
|
134
151
|
outputs = synchronize_data(world_size, gpu, self.statistics_train.measures)
|
|
@@ -136,7 +153,7 @@ class Evaluator(DistributedObject):
|
|
|
136
153
|
self.statistics_train.write(outputs)
|
|
137
154
|
if len(dataloaders) == 2:
|
|
138
155
|
description = lambda measure : "Metric VALIDATION : {} ".format(" | ".join("{}: {:.2f}".format(k, v) for k, v in measure.items()) if measure is not None else "")
|
|
139
|
-
with tqdm.tqdm(iterable = enumerate(dataloaders[1]), leave=
|
|
156
|
+
with tqdm.tqdm(iterable = enumerate(dataloaders[1]), leave=True, desc = description(None), total=len(dataloaders[1]), ncols=0) as batch_iter:
|
|
140
157
|
for _, data_dict in batch_iter:
|
|
141
158
|
batch_iter.set_description(description(self.update({k: (v[0], v[4]) for k,v in data_dict.items()}, self.statistics_validation)))
|
|
142
159
|
outputs = synchronize_data(world_size, gpu, self.statistics_validation.measures)
|
konfai/main.py
CHANGED
|
@@ -3,6 +3,7 @@ import os
|
|
|
3
3
|
from torch.cuda import device_count
|
|
4
4
|
import torch.multiprocessing as mp
|
|
5
5
|
from konfai.utils.utils import setup, TensorBoard, Log
|
|
6
|
+
from konfai import KONFAI_NB_CORES
|
|
6
7
|
|
|
7
8
|
import sys
|
|
8
9
|
sys.path.insert(0, os.getcwd())
|
|
@@ -11,16 +12,15 @@ def main():
|
|
|
11
12
|
parser = argparse.ArgumentParser(description="KonfAI", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
|
|
12
13
|
try:
|
|
13
14
|
with setup(parser) as distributedObject:
|
|
14
|
-
with Log(distributedObject.name):
|
|
15
|
+
with Log(distributedObject.name, 0):
|
|
15
16
|
world_size = device_count()
|
|
16
17
|
if world_size == 0:
|
|
17
|
-
world_size =
|
|
18
|
+
world_size = int(KONFAI_NB_CORES())
|
|
18
19
|
distributedObject.setup(world_size)
|
|
19
20
|
with TensorBoard(distributedObject.name):
|
|
20
21
|
mp.spawn(distributedObject, nprocs=world_size)
|
|
21
|
-
except
|
|
22
|
-
print(
|
|
23
|
-
exit(1)
|
|
22
|
+
except KeyboardInterrupt:
|
|
23
|
+
print("\n[KonfAI] Manual interruption (Ctrl+C)")
|
|
24
24
|
|
|
25
25
|
|
|
26
26
|
def cluster():
|
|
@@ -47,6 +47,8 @@ def cluster():
|
|
|
47
47
|
executor.update_parameters(name=config["name"], mem_gb=config["memory"], gpus_per_node=n_gpu, tasks_per_node=n_gpu//distributedObject.size, cpus_per_task=config["num_workers"], nodes=config["num_nodes"], timeout_min=config["time_limit"])
|
|
48
48
|
with TensorBoard(distributedObject.name):
|
|
49
49
|
executor.submit(distributedObject)
|
|
50
|
+
except KeyboardInterrupt:
|
|
51
|
+
print("\n[KonfAI] Manual interruption (Ctrl+C)")
|
|
50
52
|
except Exception as e:
|
|
51
53
|
print(e)
|
|
52
54
|
exit(1)
|
|
@@ -94,7 +94,6 @@ class SpatialTransformer(torch.nn.Module):
|
|
|
94
94
|
new_locs[:, 1,1] = 1
|
|
95
95
|
new_locs[:, 0,2] = flow[:, 0]
|
|
96
96
|
new_locs[:, 1,2] = flow[:, 1]
|
|
97
|
-
print(new_locs)
|
|
98
97
|
return F.grid_sample(src, F.affine_grid(new_locs, src.size()), align_corners=True, mode="bilinear")
|
|
99
98
|
else:
|
|
100
99
|
new_locs = self.grid + flow
|
konfai/network/blocks.py
CHANGED
konfai/network/network.py
CHANGED
|
@@ -16,7 +16,7 @@ from enum import Enum
|
|
|
16
16
|
from konfai import KONFAI_ROOT
|
|
17
17
|
from konfai.metric.schedulers import Scheduler
|
|
18
18
|
from konfai.utils.config import config
|
|
19
|
-
from konfai.utils.utils import State, _getModule, getDevice, getGPUMemory
|
|
19
|
+
from konfai.utils.utils import State, _getModule, getDevice, getGPUMemory, MeasureError
|
|
20
20
|
from konfai.data.patching import Accumulator, ModelPatch
|
|
21
21
|
|
|
22
22
|
class NetState(Enum):
|
|
@@ -39,7 +39,6 @@ class OptimizerLoader():
|
|
|
39
39
|
self.name = name
|
|
40
40
|
|
|
41
41
|
def getOptimizer(self, key: str, parameter: Iterator[torch.nn.parameter.Parameter]) -> torch.optim.Optimizer:
|
|
42
|
-
torch.optim.AdamW
|
|
43
42
|
return config("{}.Model.{}.Optimizer".format(KONFAI_ROOT(), key))(getattr(importlib.import_module('torch.optim'), self.name))(parameter, config = None)
|
|
44
43
|
|
|
45
44
|
class SchedulerStep():
|
|
@@ -54,12 +53,12 @@ class LRSchedulersLoader():
|
|
|
54
53
|
def __init__(self, params: dict[str, SchedulerStep] = {"default:ReduceLROnPlateau" : SchedulerStep(0)}) -> None:
|
|
55
54
|
self.params = params
|
|
56
55
|
|
|
57
|
-
def
|
|
58
|
-
|
|
56
|
+
def getschedulers(self, key: str, optimizer: torch.optim.Optimizer) -> dict[torch.optim.lr_scheduler._LRScheduler, int]:
|
|
57
|
+
schedulers : dict[torch.optim.lr_scheduler._LRScheduler, int] = {}
|
|
59
58
|
for name, step in self.params.items():
|
|
60
59
|
if name:
|
|
61
|
-
|
|
62
|
-
return
|
|
60
|
+
schedulers[config("Trainer.Model.{}.Schedulers.{}".format(key, name))(getattr(importlib.import_module('torch.optim.lr_scheduler'), name))(optimizer, config = None)] = step.nb_step
|
|
61
|
+
return schedulers
|
|
63
62
|
|
|
64
63
|
class SchedulersLoader():
|
|
65
64
|
|
|
@@ -67,12 +66,12 @@ class SchedulersLoader():
|
|
|
67
66
|
def __init__(self, params: dict[str, SchedulerStep] = {"default:Constant" : SchedulerStep(0)}) -> None:
|
|
68
67
|
self.params = params
|
|
69
68
|
|
|
70
|
-
def
|
|
71
|
-
|
|
69
|
+
def getschedulers(self, key: str) -> dict[torch.optim.lr_scheduler._LRScheduler, int]:
|
|
70
|
+
schedulers : dict[Scheduler, int] = {}
|
|
72
71
|
for name, step in self.params.items():
|
|
73
72
|
if name:
|
|
74
|
-
|
|
75
|
-
return
|
|
73
|
+
schedulers[getattr(importlib.import_module("konfai.metric.schedulers"), name)(config = None, DL_args = key)] = step.nb_step
|
|
74
|
+
return schedulers
|
|
76
75
|
|
|
77
76
|
class CriterionsAttr():
|
|
78
77
|
|
|
@@ -98,7 +97,7 @@ class CriterionsLoader():
|
|
|
98
97
|
for module_classpath, criterionsAttr in self.criterionsLoader.items():
|
|
99
98
|
module, name = _getModule(module_classpath, "metric.measure")
|
|
100
99
|
criterionsAttr.isTorchCriterion = module.startswith("torch")
|
|
101
|
-
criterionsAttr.sheduler = criterionsAttr.l.
|
|
100
|
+
criterionsAttr.sheduler = criterionsAttr.l.getschedulers("{}.Model.{}.outputsCriterions.{}.targetsCriterions.{}.criterionsLoader.{}".format(KONFAI_ROOT(), model_classname, output_group, target_group, module_classpath))
|
|
102
101
|
criterions[config("{}.Model.{}.outputsCriterions.{}.targetsCriterions.{}.criterionsLoader.{}".format(KONFAI_ROOT(), model_classname, output_group, target_group, module_classpath))(getattr(importlib.import_module(module), name))(config = None)] = criterionsAttr
|
|
103
102
|
return criterions
|
|
104
103
|
|
|
@@ -154,10 +153,25 @@ class Measure():
|
|
|
154
153
|
self.outputsCriterions[output_group.replace(":", ".")] = targetCriterionsLoader.getTargetsCriterions(output_group, model_classname)
|
|
155
154
|
self._loss : dict[int, dict[str, Measure.Loss]] = {}
|
|
156
155
|
|
|
157
|
-
def init(self, model : torch.nn.Module) -> None:
|
|
156
|
+
def init(self, model : torch.nn.Module, group_dest: list[str]) -> None:
|
|
158
157
|
outputs_group_rename = {}
|
|
158
|
+
|
|
159
|
+
modules = []
|
|
160
|
+
for i,_ in model.named_modules():
|
|
161
|
+
modules.append(i)
|
|
162
|
+
|
|
159
163
|
for output_group in self.outputsCriterions.keys():
|
|
164
|
+
if output_group not in modules:
|
|
165
|
+
raise MeasureError(f"The output group '{output_group}' defined in 'outputsCriterions' does not correspond to any module in the model.",
|
|
166
|
+
f"Available modules: {modules}",
|
|
167
|
+
"Please check that the name matches exactly a submodule or output of your model architecture."
|
|
168
|
+
)
|
|
160
169
|
for target_group in self.outputsCriterions[output_group]:
|
|
170
|
+
if target_group not in group_dest:
|
|
171
|
+
raise MeasureError(
|
|
172
|
+
f"The target_group '{target_group}' defined in 'outputsCriterions.{output_group}.targetsCriterions' was not found in the available destination groups.",
|
|
173
|
+
"This target_group is expected for loss or metric computation, but was not loaded in 'group_dest'.",
|
|
174
|
+
f"Please make sure that the group '{target_group}' is defined in 'Dataset:groups_src:...:groups_dest:'{target_group}'' and correctly loaded from the dataset.")
|
|
161
175
|
for criterion in self.outputsCriterions[output_group][target_group]:
|
|
162
176
|
if not self.outputsCriterions[output_group][target_group][criterion].isTorchCriterion:
|
|
163
177
|
outputs_group_rename[output_group] = criterion.init(model, output_group, target_group)
|
|
@@ -703,10 +717,10 @@ class Network(ModuleArgsDict, ABC):
|
|
|
703
717
|
return in_channels, in_is_channel, out_channels, out_is_channel
|
|
704
718
|
|
|
705
719
|
@_function_network()
|
|
706
|
-
def init(self, autocast : bool, state : State, key: str) -> None:
|
|
720
|
+
def init(self, autocast : bool, state : State, group_dest: list[str], key: str) -> None:
|
|
707
721
|
if self.outputsCriterionsLoader:
|
|
708
722
|
self.measure = Measure(key, self.outputsCriterionsLoader)
|
|
709
|
-
self.measure.init(self)
|
|
723
|
+
self.measure.init(self, group_dest)
|
|
710
724
|
if state != State.PREDICTION:
|
|
711
725
|
self.scaler = torch.amp.GradScaler("cuda", enabled=autocast)
|
|
712
726
|
if self.optimizerLoader:
|
|
@@ -714,7 +728,7 @@ class Network(ModuleArgsDict, ABC):
|
|
|
714
728
|
self.optimizer.zero_grad()
|
|
715
729
|
|
|
716
730
|
if self.LRSchedulersLoader and self.optimizer:
|
|
717
|
-
self.schedulers = self.LRSchedulersLoader.
|
|
731
|
+
self.schedulers = self.LRSchedulersLoader.getschedulers(key, self.optimizer)
|
|
718
732
|
|
|
719
733
|
def initialized(self):
|
|
720
734
|
pass
|
|
@@ -880,7 +894,6 @@ class Network(ModuleArgsDict, ABC):
|
|
|
880
894
|
if scheduler:
|
|
881
895
|
if scheduler.__class__.__name__ == 'ReduceLROnPlateau':
|
|
882
896
|
if self.measure:
|
|
883
|
-
print(sum(self.measure.getLastValues(0).values()))
|
|
884
897
|
scheduler.step(sum(self.measure.getLastValues(0).values()))
|
|
885
898
|
else:
|
|
886
899
|
scheduler.step()
|
konfai/predictor.py
CHANGED
|
@@ -94,10 +94,10 @@ class OutDataset(Dataset, NeedDevice, ABC):
|
|
|
94
94
|
class OutSameAsGroupDataset(OutDataset):
|
|
95
95
|
|
|
96
96
|
@config("OutDataset")
|
|
97
|
-
def __init__(self, dataset_filename: str = "./Dataset:mha", group: str = "default", sameAsGroup: str = "default", pre_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, post_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, final_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, patchCombine: Union[str, None] = None,
|
|
97
|
+
def __init__(self, dataset_filename: str = "./Dataset:mha", group: str = "default", sameAsGroup: str = "default", pre_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, post_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, final_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, patchCombine: Union[str, None] = None, reduction: str = "mean", inverse_transform: bool = True) -> None:
|
|
98
98
|
super().__init__(dataset_filename, group, pre_transforms, post_transforms, final_transforms, patchCombine)
|
|
99
99
|
self.group_src, self.group_dest = sameAsGroup.split(":")
|
|
100
|
-
self.
|
|
100
|
+
self.reduction = reduction
|
|
101
101
|
self.inverse_transform = inverse_transform
|
|
102
102
|
|
|
103
103
|
def addLayer(self, index_dataset: int, index_augmentation: int, index_patch: int, layer: torch.Tensor, dataset: DatasetIter):
|
|
@@ -151,9 +151,9 @@ class OutSameAsGroupDataset(OutDataset):
|
|
|
151
151
|
self.output_layer_accumulator.pop(index)
|
|
152
152
|
dtype = result.dtype
|
|
153
153
|
|
|
154
|
-
if self.
|
|
154
|
+
if self.reduction == "mean":
|
|
155
155
|
result = torch.mean(result.float(), dim=0).to(dtype)
|
|
156
|
-
elif self.
|
|
156
|
+
elif self.reduction == "median":
|
|
157
157
|
result, _ = torch.median(result.float(), dim=0)
|
|
158
158
|
else:
|
|
159
159
|
raise NameError("Reduction method does not exist (mean, median)")
|
|
@@ -201,7 +201,7 @@ class OutDatasetLoader():
|
|
|
201
201
|
|
|
202
202
|
class _Predictor():
|
|
203
203
|
|
|
204
|
-
def __init__(self, world_size: int, global_rank: int, local_rank: int, predict_path: str, data_log: Union[list[str], None], outsDataset: dict[str, OutDataset], modelComposite: DDP, dataloader_prediction: DataLoader) -> None:
|
|
204
|
+
def __init__(self, world_size: int, global_rank: int, local_rank: int, autocast: bool, predict_path: str, data_log: Union[list[str], None], outsDataset: dict[str, OutDataset], modelComposite: DDP, dataloader_prediction: DataLoader) -> None:
|
|
205
205
|
self.world_size = world_size
|
|
206
206
|
self.global_rank = global_rank
|
|
207
207
|
self.local_rank = local_rank
|
|
@@ -209,7 +209,7 @@ class _Predictor():
|
|
|
209
209
|
self.modelComposite = modelComposite
|
|
210
210
|
self.dataloader_prediction = dataloader_prediction
|
|
211
211
|
self.outsDataset = outsDataset
|
|
212
|
-
|
|
212
|
+
self.autocast = autocast
|
|
213
213
|
|
|
214
214
|
self.it = 0
|
|
215
215
|
|
|
@@ -239,21 +239,22 @@ class _Predictor():
|
|
|
239
239
|
self.modelComposite.eval()
|
|
240
240
|
self.modelComposite.module.setState(NetState.PREDICTION)
|
|
241
241
|
desc = lambda : "Prediction : {}".format(description(self.modelComposite))
|
|
242
|
-
self.dataloader_prediction.dataset.load()
|
|
243
|
-
with tqdm.tqdm(iterable = enumerate(self.dataloader_prediction), leave=
|
|
242
|
+
self.dataloader_prediction.dataset.load("Prediction")
|
|
243
|
+
with tqdm.tqdm(iterable = enumerate(self.dataloader_prediction), leave=True, desc = desc(), total=len(self.dataloader_prediction), ncols=0) as batch_iter:
|
|
244
244
|
dist.barrier()
|
|
245
245
|
for it, data_dict in batch_iter:
|
|
246
|
-
|
|
247
|
-
|
|
248
|
-
self.
|
|
249
|
-
|
|
250
|
-
|
|
251
|
-
|
|
252
|
-
|
|
253
|
-
|
|
254
|
-
|
|
255
|
-
|
|
256
|
-
|
|
246
|
+
with torch.amp.autocast('cuda', enabled=self.autocast):
|
|
247
|
+
input = self.getInput(data_dict)
|
|
248
|
+
for name, output in self.modelComposite(input, list(self.outsDataset.keys())):
|
|
249
|
+
self._predict_log(data_dict)
|
|
250
|
+
outDataset = self.outsDataset[name]
|
|
251
|
+
for i, (index, patch_augmentation, patch_index) in enumerate([(int(index), int(patch_augmentation), int(patch_index)) for index, patch_augmentation, patch_index in zip(list(data_dict.values())[0][1], list(data_dict.values())[0][2], list(data_dict.values())[0][3])]):
|
|
252
|
+
outDataset.addLayer(index, patch_augmentation, patch_index, output[i].cpu(), self.dataset)
|
|
253
|
+
if outDataset.isDone(index):
|
|
254
|
+
outDataset.write(index, self.dataset.getDatasetFromIndex(list(data_dict.keys())[0], index).name.split("/")[-1], outDataset.getOutput(index, self.dataset))
|
|
255
|
+
|
|
256
|
+
batch_iter.set_description(desc())
|
|
257
|
+
self.it += 1
|
|
257
258
|
|
|
258
259
|
def _predict_log(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]):
|
|
259
260
|
measures = DistributedObject.getMeasure(self.world_size, self.global_rank, self.local_rank, {"" : self.modelComposite.module}, 1)
|
|
@@ -320,6 +321,7 @@ class Predictor(DistributedObject):
|
|
|
320
321
|
train_name: str = "name",
|
|
321
322
|
manual_seed : Union[int, None] = None,
|
|
322
323
|
gpu_checkpoints: Union[list[str], None] = None,
|
|
324
|
+
autocast : bool = False,
|
|
323
325
|
outsDataset: Union[dict[str, OutDatasetLoader], None] = {"default:Default" : OutDatasetLoader()},
|
|
324
326
|
images_log: list[str] = []) -> None:
|
|
325
327
|
if os.environ["KONFAI_CONFIG_MODE"] != "Done":
|
|
@@ -328,6 +330,7 @@ class Predictor(DistributedObject):
|
|
|
328
330
|
self.manual_seed = manual_seed
|
|
329
331
|
self.dataset = dataset
|
|
330
332
|
self.combine = combine
|
|
333
|
+
self.autocast = autocast
|
|
331
334
|
|
|
332
335
|
self.model = model.getModel(train=False)
|
|
333
336
|
self.it = 0
|
|
@@ -384,7 +387,7 @@ class Predictor(DistributedObject):
|
|
|
384
387
|
|
|
385
388
|
shutil.copyfile(CONFIG_FILE(), self.predict_path+"Prediction.yml")
|
|
386
389
|
|
|
387
|
-
self.model.init(autocast
|
|
390
|
+
self.model.init(self.autocast, State.PREDICTION, self.dataset.getGroupsDest())
|
|
388
391
|
self.model.init_outputsGroup()
|
|
389
392
|
self.model._compute_channels_trace(self.model, self.model.in_channels, None, self.gpu_checkpoints)
|
|
390
393
|
self.modelComposite = ModelComposite(self.model, len(MODEL().split(":")), self.combine)
|
|
@@ -402,7 +405,7 @@ class Predictor(DistributedObject):
|
|
|
402
405
|
def run_process(self, world_size: int, global_rank: int, local_rank: int, dataloaders: list[DataLoader]):
|
|
403
406
|
modelComposite = Network.to(self.modelComposite, local_rank*self.size)
|
|
404
407
|
modelComposite = DDP(modelComposite, static_graph=True) if torch.cuda.is_available() else CPU_Model(modelComposite)
|
|
405
|
-
with _Predictor(world_size, global_rank, local_rank, self.predict_path, self.images_log, self.outsDataset, modelComposite, *dataloaders) as p:
|
|
408
|
+
with _Predictor(world_size, global_rank, local_rank, self.autocast, self.predict_path, self.images_log, self.outsDataset, modelComposite, *dataloaders) as p:
|
|
406
409
|
p.run()
|
|
407
410
|
|
|
408
411
|
|
konfai/trainer.py
CHANGED
|
@@ -14,10 +14,9 @@ import torch.distributed as dist
|
|
|
14
14
|
from konfai import MODELS_DIRECTORY, CHECKPOINTS_DIRECTORY, STATISTICS_DIRECTORY, SETUPS_DIRECTORY, CONFIG_FILE, MODEL, DATE, KONFAI_STATE
|
|
15
15
|
from konfai.data.data_manager import DataTrain
|
|
16
16
|
from konfai.utils.config import config
|
|
17
|
-
from konfai.utils.utils import State, DataLog, DistributedObject, description
|
|
17
|
+
from konfai.utils.utils import State, DataLog, DistributedObject, description, TrainerError
|
|
18
18
|
from konfai.network.network import Network, ModelLoader, NetState, CPU_Model
|
|
19
19
|
|
|
20
|
-
|
|
21
20
|
class EarlyStoppingBase:
|
|
22
21
|
|
|
23
22
|
def __init__(self):
|
|
@@ -53,8 +52,8 @@ class EarlyStopping(EarlyStoppingBase):
|
|
|
53
52
|
return super().getScore(values)
|
|
54
53
|
for v in self.monitor:
|
|
55
54
|
if v not in values.keys():
|
|
56
|
-
raise
|
|
57
|
-
"
|
|
55
|
+
raise TrainerError(
|
|
56
|
+
"Metric '{}' specified in EarlyStopping.monitor not found in logged values. ",
|
|
58
57
|
"Available keys: {}. Please check your configuration.".format(v, list(values.keys())))
|
|
59
58
|
return sum([i for v, i in values.items() if v in self.monitor])
|
|
60
59
|
|
|
@@ -68,7 +67,7 @@ class EarlyStopping(EarlyStoppingBase):
|
|
|
68
67
|
elif self.mode == "max":
|
|
69
68
|
improvement = current_score - self.best_score
|
|
70
69
|
else:
|
|
71
|
-
raise
|
|
70
|
+
raise TrainerError("Mode must be 'min' or 'max'.")
|
|
72
71
|
|
|
73
72
|
if improvement > self.min_delta:
|
|
74
73
|
self.best_score = current_score
|
|
@@ -99,7 +98,7 @@ class _Trainer():
|
|
|
99
98
|
self.autocast = autocast
|
|
100
99
|
self.modelEMA = modelEMA
|
|
101
100
|
self.early_stopping = EarlyStoppingBase() if early_stopping is None else early_stopping
|
|
102
|
-
|
|
101
|
+
|
|
103
102
|
self.it_validation = it_validation
|
|
104
103
|
if self.it_validation is None:
|
|
105
104
|
self.it_validation = len(dataloader_training)
|
|
@@ -118,13 +117,14 @@ class _Trainer():
|
|
|
118
117
|
self.tb.close()
|
|
119
118
|
|
|
120
119
|
def run(self) -> None:
|
|
121
|
-
|
|
120
|
+
self.dataloader_training.dataset.load("Train")
|
|
121
|
+
self.dataloader_validation.dataset.load("Validation")
|
|
122
|
+
with tqdm.tqdm(iterable = range(self.epoch, self.epochs), leave=False, total=self.epochs, initial=self.epoch, desc="Progress") as epoch_tqdm:
|
|
122
123
|
for self.epoch in epoch_tqdm:
|
|
123
|
-
self.dataloader_training.dataset.load()
|
|
124
124
|
self.train()
|
|
125
125
|
if self.early_stopping.isStopped():
|
|
126
126
|
break
|
|
127
|
-
self.dataloader_training.dataset.resetAugmentation()
|
|
127
|
+
self.dataloader_training.dataset.resetAugmentation("Train")
|
|
128
128
|
|
|
129
129
|
def getInput(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int, str, bool]]) -> dict[tuple[str, bool], torch.Tensor]:
|
|
130
130
|
return {(k, v[5][0].item()) : v[0] for k, v in data_dict.items()}
|
|
@@ -137,7 +137,8 @@ class _Trainer():
|
|
|
137
137
|
self.modelEMA.module.setState(NetState.TRAIN)
|
|
138
138
|
|
|
139
139
|
desc = lambda : "Training : {}".format(description(self.model, self.modelEMA))
|
|
140
|
-
|
|
140
|
+
|
|
141
|
+
with tqdm.tqdm(iterable = enumerate(self.dataloader_training), desc = desc(), total=len(self.dataloader_training), leave=False, ncols=0) as batch_iter:
|
|
141
142
|
for _, data_dict in batch_iter:
|
|
142
143
|
with torch.amp.autocast('cuda', enabled=self.autocast):
|
|
143
144
|
input = self.getInput(data_dict)
|
|
@@ -171,8 +172,7 @@ class _Trainer():
|
|
|
171
172
|
|
|
172
173
|
desc = lambda : "Validation : {}".format(description(self.model, self.modelEMA))
|
|
173
174
|
data_dict = None
|
|
174
|
-
self.dataloader_validation
|
|
175
|
-
with tqdm.tqdm(iterable = enumerate(self.dataloader_validation), desc = desc(), total=len(self.dataloader_validation), leave=False, disable=self.global_rank != 0 and "KONFAI_CLUSTER" not in os.environ) as batch_iter:
|
|
175
|
+
with tqdm.tqdm(iterable = enumerate(self.dataloader_validation), desc = desc(), total=len(self.dataloader_validation), leave=False, ncols=0) as batch_iter:
|
|
176
176
|
for _, data_dict in batch_iter:
|
|
177
177
|
input = self.getInput(data_dict)
|
|
178
178
|
self.model(input)
|
|
@@ -180,7 +180,7 @@ class _Trainer():
|
|
|
180
180
|
self.modelEMA.module(input)
|
|
181
181
|
|
|
182
182
|
batch_iter.set_description(desc())
|
|
183
|
-
self.dataloader_validation.dataset.resetAugmentation()
|
|
183
|
+
self.dataloader_validation.dataset.resetAugmentation("Validation")
|
|
184
184
|
dist.barrier()
|
|
185
185
|
self.model.train()
|
|
186
186
|
self.model.module.setState(NetState.TRAIN)
|
|
@@ -360,7 +360,7 @@ class Trainer(DistributedObject):
|
|
|
360
360
|
os.makedirs(dir)
|
|
361
361
|
|
|
362
362
|
for name in sorted(os.listdir(path_checkpoint)):
|
|
363
|
-
checkpoint = torch.load(path_checkpoint+name, weights_only=False)
|
|
363
|
+
checkpoint = torch.load(path_checkpoint+name, weights_only=False, map_location='cpu')
|
|
364
364
|
self.model.load(checkpoint, init=False, ema=False)
|
|
365
365
|
|
|
366
366
|
torch.save(self.model, "{}Serialized/{}".format(path_model, name))
|
|
@@ -391,7 +391,7 @@ class Trainer(DistributedObject):
|
|
|
391
391
|
if state != State.TRAIN:
|
|
392
392
|
state_dict = self._load()
|
|
393
393
|
|
|
394
|
-
self.model.init(self.autocast, state)
|
|
394
|
+
self.model.init(self.autocast, state, self.dataset.getGroupsDest())
|
|
395
395
|
self.model.init_outputsGroup()
|
|
396
396
|
self.model._compute_channels_trace(self.model, self.model.in_channels, self.gradient_checkpoints, self.gpu_checkpoints)
|
|
397
397
|
self.model.load(state_dict, init=True, ema=False)
|
konfai/utils/config.py
CHANGED
|
@@ -6,14 +6,10 @@ from copy import deepcopy
|
|
|
6
6
|
from typing import Union, Literal, get_origin, get_args
|
|
7
7
|
import torch
|
|
8
8
|
from konfai import CONFIG_FILE
|
|
9
|
+
from konfai.utils.utils import ConfigError
|
|
9
10
|
|
|
10
11
|
yaml = ruamel.yaml.YAML()
|
|
11
12
|
|
|
12
|
-
class ConfigError(Exception):
|
|
13
|
-
|
|
14
|
-
def __init__(self, message : str = "The config only supports types : config(Object), int, str, bool, float, list[int], list[str], list[bool], list[float], dict[str, Object]") -> None:
|
|
15
|
-
self.message = message
|
|
16
|
-
super().__init__(self.message)
|
|
17
13
|
|
|
18
14
|
class Config():
|
|
19
15
|
|
|
@@ -187,9 +183,8 @@ def config(key : Union[str, None] = None):
|
|
|
187
183
|
default_value = param.default if param.default != inspect._empty else allowed_values[0]
|
|
188
184
|
value = config.getValue(param.name, f"default:{default_value}")
|
|
189
185
|
if value not in allowed_values:
|
|
190
|
-
raise
|
|
191
|
-
f"
|
|
192
|
-
f"Expected one of: {allowed_values}."
|
|
186
|
+
raise ConfigError(
|
|
187
|
+
f"Invalid value '{value}' for parameter '{param.name} expected one of: {allowed_values}."
|
|
193
188
|
)
|
|
194
189
|
kwargs[param.name] = value
|
|
195
190
|
continue
|
|
@@ -222,21 +217,26 @@ def config(key : Union[str, None] = None):
|
|
|
222
217
|
values = config.getValue(param.name, param.default)
|
|
223
218
|
kwargs[param.name] = values
|
|
224
219
|
else:
|
|
225
|
-
raise ConfigError()
|
|
220
|
+
raise ConfigError("Config: The config only supports types : config(Object), int, str, bool, float, list[int], list[str], list[bool], list[float], dict[str, Object]")
|
|
226
221
|
elif str(annotation).startswith("dict"):
|
|
227
222
|
if annotation.__args__[0] == str:
|
|
228
223
|
values = config.getValue(param.name, param.default)
|
|
229
224
|
if values is not None and annotation.__args__[1] not in [int, str, bool, float]:
|
|
230
|
-
|
|
225
|
+
try:
|
|
226
|
+
kwargs[param.name] = {value : annotation.__args__[1](config = filename, DL_args = key_tmp+"."+param.name+"."+value) for value in values}
|
|
227
|
+
except ValueError as e:
|
|
228
|
+
raise ValueError(e)
|
|
229
|
+
except Exception as e:
|
|
230
|
+
raise ConfigError("{} {}".format(values, e))
|
|
231
231
|
else:
|
|
232
232
|
kwargs[param.name] = values
|
|
233
233
|
else:
|
|
234
|
-
raise ConfigError()
|
|
234
|
+
raise ConfigError("Config: The config only supports types : config(Object), int, str, bool, float, list[int], list[str], list[bool], list[float], dict[str, Object]")
|
|
235
235
|
else:
|
|
236
236
|
try:
|
|
237
237
|
kwargs[param.name] = annotation(config = filename, DL_args = key_tmp)
|
|
238
238
|
except Exception as e:
|
|
239
|
-
raise
|
|
239
|
+
raise ConfigError("Failed to instantiate {} with type {}, error {} ".format(param.name, annotation.__name__, e))
|
|
240
240
|
|
|
241
241
|
if os.environ['KONFAI_CONFIG_VARIABLE'] == "True":
|
|
242
242
|
os.environ['KONFAI_CONFIG_VARIABLE'] = "False"
|
konfai/utils/dataset.py
CHANGED
|
@@ -348,6 +348,10 @@ class Dataset():
|
|
|
348
348
|
@abstractmethod
|
|
349
349
|
def getNames(self, group: str) -> list[str]:
|
|
350
350
|
pass
|
|
351
|
+
|
|
352
|
+
@abstractmethod
|
|
353
|
+
def getGroup(self) -> list[str]:
|
|
354
|
+
pass
|
|
351
355
|
|
|
352
356
|
@abstractmethod
|
|
353
357
|
def isExist(self, group: str, name: Union[str, None] = None) -> bool:
|
|
@@ -458,6 +462,9 @@ class Dataset():
|
|
|
458
462
|
names.extend(self.getNames("/".join(groups.split("/")[1:]), h5_group[group]))
|
|
459
463
|
return names
|
|
460
464
|
|
|
465
|
+
def getGroup(self):
|
|
466
|
+
return self.h5.keys()
|
|
467
|
+
|
|
461
468
|
def _getDataset(self, groups: str, name: str, h5_group: h5py.Group = None) -> h5py.Dataset:
|
|
462
469
|
if h5_group is None:
|
|
463
470
|
h5_group = self.h5
|
|
@@ -613,7 +620,10 @@ class Dataset():
|
|
|
613
620
|
|
|
614
621
|
def getNames(self, group: str) -> list[str]:
|
|
615
622
|
raise NotImplementedError()
|
|
616
|
-
|
|
623
|
+
|
|
624
|
+
def getGroup(self):
|
|
625
|
+
raise NotImplementedError()
|
|
626
|
+
|
|
617
627
|
def getInfos(self, group: str, name: str) -> tuple[list[int], Attribute]:
|
|
618
628
|
attributes = Attribute()
|
|
619
629
|
if os.path.exists("{}{}{}.{}".format(self.filename, group if group is not None else "", name, self.format)):
|
|
@@ -736,7 +746,7 @@ class Dataset():
|
|
|
736
746
|
subDirectories.extend(self._getSubDirectories("/".join(groups.split("/")[1:]), subDirectory))
|
|
737
747
|
return subDirectories
|
|
738
748
|
|
|
739
|
-
def getNames(self, groups: str, index: Union[list[int], None] = None
|
|
749
|
+
def getNames(self, groups: str, index: Union[list[int], None] = None) -> list[str]:
|
|
740
750
|
names = []
|
|
741
751
|
if self.is_directory:
|
|
742
752
|
for subDirectory in self._getSubDirectories(groups):
|
|
@@ -752,6 +762,21 @@ class Dataset():
|
|
|
752
762
|
names = file.getNames(groups)
|
|
753
763
|
return [name for i, name in enumerate(sorted(names)) if index is None or i in index]
|
|
754
764
|
|
|
765
|
+
def getGroup(self):
|
|
766
|
+
if self.is_directory:
|
|
767
|
+
groups = set()
|
|
768
|
+
for root, _, files in os.walk(self.filename):
|
|
769
|
+
for file in files:
|
|
770
|
+
path = os.path.relpath(os.path.join(root, file.split(".")[0]), self.filename)
|
|
771
|
+
parts = path.split("/")
|
|
772
|
+
if len(parts) >= 2:
|
|
773
|
+
del parts[-2]
|
|
774
|
+
groups.add("/".join(parts))
|
|
775
|
+
else:
|
|
776
|
+
with Dataset.File(self.filename, True, self.format) as file:
|
|
777
|
+
groups = file.getGroup()
|
|
778
|
+
return list(groups)
|
|
779
|
+
|
|
755
780
|
def getInfos(self, groups: str, name: str) -> tuple[list[int], Attribute]:
|
|
756
781
|
if self.is_directory:
|
|
757
782
|
for subDirectory in self._getSubDirectories(groups):
|