konfai 1.0.9.tar.gz → 1.1.0.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of konfai might be problematic.
- {konfai-1.0.9 → konfai-1.1.0}/PKG-INFO +1 -1
- {konfai-1.0.9 → konfai-1.1.0}/konfai/data/data_manager.py +19 -19
- {konfai-1.0.9 → konfai-1.1.0}/konfai/data/patching.py +1 -1
- {konfai-1.0.9 → konfai-1.1.0}/konfai/metric/measure.py +1 -1
- {konfai-1.0.9 → konfai-1.1.0}/konfai/trainer.py +124 -39
- {konfai-1.0.9 → konfai-1.1.0}/konfai/utils/config.py +1 -1
- {konfai-1.0.9 → konfai-1.1.0}/konfai.egg-info/PKG-INFO +1 -1
- {konfai-1.0.9 → konfai-1.1.0}/konfai.egg-info/SOURCES.txt +2 -1
- {konfai-1.0.9 → konfai-1.1.0}/pyproject.toml +5 -2
- konfai-1.1.0/tests/test_config.py +18 -0
- {konfai-1.0.9 → konfai-1.1.0}/LICENSE +0 -0
- {konfai-1.0.9 → konfai-1.1.0}/README.md +0 -0
- {konfai-1.0.9 → konfai-1.1.0}/konfai/__init__.py +0 -0
- {konfai-1.0.9 → konfai-1.1.0}/konfai/data/__init__.py +0 -0
- {konfai-1.0.9 → konfai-1.1.0}/konfai/data/augmentation.py +0 -0
- {konfai-1.0.9 → konfai-1.1.0}/konfai/data/transform.py +0 -0
- {konfai-1.0.9 → konfai-1.1.0}/konfai/evaluator.py +0 -0
- {konfai-1.0.9 → konfai-1.1.0}/konfai/main.py +0 -0
- {konfai-1.0.9 → konfai-1.1.0}/konfai/metric/__init__.py +0 -0
- {konfai-1.0.9 → konfai-1.1.0}/konfai/metric/schedulers.py +0 -0
- {konfai-1.0.9 → konfai-1.1.0}/konfai/models/classification/convNeXt.py +0 -0
- {konfai-1.0.9 → konfai-1.1.0}/konfai/models/classification/resnet.py +0 -0
- {konfai-1.0.9 → konfai-1.1.0}/konfai/models/generation/cStyleGan.py +0 -0
- {konfai-1.0.9 → konfai-1.1.0}/konfai/models/generation/ddpm.py +0 -0
- {konfai-1.0.9 → konfai-1.1.0}/konfai/models/generation/diffusionGan.py +0 -0
- {konfai-1.0.9 → konfai-1.1.0}/konfai/models/generation/gan.py +0 -0
- {konfai-1.0.9 → konfai-1.1.0}/konfai/models/generation/vae.py +0 -0
- {konfai-1.0.9 → konfai-1.1.0}/konfai/models/registration/registration.py +0 -0
- {konfai-1.0.9 → konfai-1.1.0}/konfai/models/representation/representation.py +0 -0
- {konfai-1.0.9 → konfai-1.1.0}/konfai/models/segmentation/NestedUNet.py +0 -0
- {konfai-1.0.9 → konfai-1.1.0}/konfai/models/segmentation/UNet.py +0 -0
- {konfai-1.0.9 → konfai-1.1.0}/konfai/network/__init__.py +0 -0
- {konfai-1.0.9 → konfai-1.1.0}/konfai/network/blocks.py +0 -0
- {konfai-1.0.9 → konfai-1.1.0}/konfai/network/network.py +0 -0
- {konfai-1.0.9 → konfai-1.1.0}/konfai/predictor.py +0 -0
- {konfai-1.0.9 → konfai-1.1.0}/konfai/utils/ITK.py +0 -0
- {konfai-1.0.9 → konfai-1.1.0}/konfai/utils/__init__.py +0 -0
- {konfai-1.0.9 → konfai-1.1.0}/konfai/utils/dataset.py +0 -0
- {konfai-1.0.9 → konfai-1.1.0}/konfai/utils/registration.py +0 -0
- {konfai-1.0.9 → konfai-1.1.0}/konfai/utils/utils.py +0 -0
- {konfai-1.0.9 → konfai-1.1.0}/konfai.egg-info/dependency_links.txt +0 -0
- {konfai-1.0.9 → konfai-1.1.0}/konfai.egg-info/entry_points.txt +0 -0
- {konfai-1.0.9 → konfai-1.1.0}/konfai.egg-info/requires.txt +0 -0
- {konfai-1.0.9 → konfai-1.1.0}/konfai.egg-info/top_level.txt +0 -0
- {konfai-1.0.9 → konfai-1.1.0}/setup.cfg +0 -0
```diff
--- konfai-1.0.9/konfai/data/data_manager.py
+++ konfai-1.1.0/konfai/data/data_manager.py
@@ -257,14 +257,14 @@ class Data(ABC):
                  subset : Union[Subset, dict[str, Subset]],
                  num_workers : int,
                  batch_size : int,
-
+                 validation: Union[float, str, list[int], list[str]] = 1,
                  inlineAugmentations: bool = False,
                  dataAugmentationsList: dict[str, DataAugmentationsList]= {}) -> None:
         self.dataset_filenames = dataset_filenames
         self.subset = subset
         self.groups_src = groups_src
         self.patch = patch
-        self.
+        self.validation = validation
         self.dataAugmentationsList = dataAugmentationsList
         self.batch_size = batch_size
         self.dataSet_args = dict(groups_src=self.groups_src, inlineAugmentations=inlineAugmentations, dataAugmentationsList = list(self.dataAugmentationsList.values()), use_cache = use_cache, buffer_size=batch_size+1, patch_size=self.patch.patch_size if self.patch is not None else None, overlap=self.patch.overlap if self.patch is not None else None)
@@ -372,17 +372,17 @@ class Data(ABC):
 
         train_map = map
         validate_map = []
-        if isinstance(self.
-            if self.
-                train_map, validate_map = map[:int(math.floor(len(map)*self.
-        elif isinstance(self.
-            if ":" in self.
+        if isinstance(self.validation, float):
+            if self.validation < 1.0 and int(math.floor(len(map)*(1-self.validation))) > 0:
+                train_map, validate_map = map[:int(math.floor(len(map)*self.validation))], map[int(math.floor(len(map)*self.validation)):]
+        elif isinstance(self.validation, str):
+            if ":" in self.validation:
                 index = list(range(int(self.subset.split(":")[0]), int(self.subset.split(":")[1])))
                 train_map = [m for m in map if m[0] not in index]
                 validate_map = [m for m in map if m[0] in index]
-            elif os.path.exists(self.
+            elif os.path.exists(self.validation):
                 validation_names = []
-                with open(self.
+                with open(self.validation, "r") as f:
                     for name in f:
                         validation_names.append(name.strip())
                 index = [i for i, n in enumerate(subset_names) if n in validation_names]
@@ -390,13 +390,13 @@ class Data(ABC):
                 validate_map = [m for m in map if m[0] in index]
             else:
                 validate_map = train_map
-        elif isinstance(self.
-            if len(self.
-                if isinstance(self.
-                    train_map = [m for m in map if m[0] not in self.
-                    validate_map = [m for m in map if m[0] in self.
-                elif isinstance(self.
-                    index = [i for i, n in enumerate(subset_names) if n in self.
+        elif isinstance(self.validation, list):
+            if len(self.validation) > 0:
+                if isinstance(self.validation[0], int):
+                    train_map = [m for m in map if m[0] not in self.validation]
+                    validate_map = [m for m in map if m[0] in self.validation]
+                elif isinstance(self.validation[0], str):
+                    index = [i for i, n in enumerate(subset_names) if n in self.validation]
                     train_map = [m for m in map if m[0] not in index]
                     validate_map = [m for m in map if m[0] in index]
 
@@ -436,8 +436,8 @@ class DataTrain(Data):
                  subset : Union[TrainSubset, dict[str, TrainSubset]] = TrainSubset(),
                  num_workers : int = 4,
                  batch_size : int = 1,
-
-        super().__init__(dataset_filenames, groups_src, patch, use_cache, subset, num_workers, batch_size,
+                 validation : Union[float, str, list[int], list[str]] = 0.8) -> None:
+        super().__init__(dataset_filenames, groups_src, patch, use_cache, subset, num_workers, batch_size, validation, inlineAugmentations, augmentations if augmentations else {})
 
 class DataPrediction(Data):
 
@@ -463,4 +463,4 @@ class DataMetric(Data):
                  validation: Union[str, None] = None,
                  num_workers : int = 4) -> None:
 
-        super().__init__(dataset_filenames=dataset_filenames, groups_src=groups_src, patch=None, use_cache=False, subset=subset, num_workers=num_workers, batch_size=1,
+        super().__init__(dataset_filenames=dataset_filenames, groups_src=groups_src, patch=None, use_cache=False, subset=subset, num_workers=num_workers, batch_size=1, validation=1 if validation is None else validation)
```
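Taken together, the data_manager.py hunks introduce one new knob: a `validation` argument that accepts a float fraction, a "start:stop" index range, a path to a file of held-out sample names, or an explicit list of indices or names. The sketch below restates that dispatch in isolation so the accepted forms are easy to try; `split` and `names` are illustrative names, not the konfai API, and the range and file branches are simplified from the hunks above.

```python
import math, os
from typing import Union

def split(names: list, validation: Union[float, str, list]) -> tuple:
    # float < 1.0: the leading fraction trains, the remainder validates
    if isinstance(validation, float):
        cut = int(math.floor(len(names) * validation))
        return names[:cut], names[cut:]
    if isinstance(validation, str):
        if ":" in validation:  # "start:stop": hold out that index range
            lo, hi = (int(v) for v in validation.split(":"))
            held = set(range(lo, hi))
            return ([n for i, n in enumerate(names) if i not in held],
                    [n for i, n in enumerate(names) if i in held])
        if os.path.exists(validation):  # file with one held-out name per line
            with open(validation) as f:
                held = {line.strip() for line in f}
            return ([n for n in names if n not in held],
                    [n for n in names if n in held])
    if isinstance(validation, list) and len(validation) > 0:
        if isinstance(validation[0], int):  # list of held-out indices
            held = {names[i] for i in validation}
        else:                               # list of held-out names
            held = set(validation)
        return ([n for n in names if n not in held],
                [n for n in names if n in held])
    return names, []  # any other value (e.g. the default 1): nothing held out

print(split(["a", "b", "c", "d", "e"], 0.8))    # (['a', 'b', 'c', 'd'], ['e'])
print(split(["a", "b", "c", "d", "e"], "1:3"))  # (['a', 'd', 'e'], ['b', 'c'])
print(split(["a", "b", "c", "d", "e"], ["b"]))  # (['a', 'c', 'd', 'e'], ['b'])
```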
```diff
--- konfai-1.0.9/konfai/data/patching.py
+++ konfai-1.1.0/konfai/data/patching.py
@@ -188,7 +188,7 @@ class Patch(ABC):
         pad_bottom = 0
         pad_top = 0
         if self._patch_slices[a][index][0].start-bottom < 0:
-            pad_bottom = bottom-
+            pad_bottom = bottom-self._patch_slices[a][index][0].start
         if self._patch_slices[a][index][0].stop+top > data.shape[len(slices_pre)]:
             pad_top = self._patch_slices[a][index][0].stop+top-data.shape[len(slices_pre)]
         data_sliced = F.pad(data_sliced, [0 for _ in range((len(slices)-1)*2)]+[pad_bottom, pad_top], 'reflect')
```
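The patching.py hunk completes a previously truncated expression: when a patch slice starts closer to the array edge than the requested margin, the shortfall becomes reflect padding. As a quick orientation on the `F.pad` call that consumes `pad_bottom`/`pad_top`, a minimal demo of standard PyTorch semantics (not konfai code): the pad list runs from the last dimension backwards, which is why the konfai line appends `[pad_bottom, pad_top]` after zeros for the trailing axes.

```python
import torch
import torch.nn.functional as F

# Reflect-pad the last axis by 2 on the low side and 1 on the high side,
# mirroring how patching.py passes [pad_bottom, pad_top] for one axis.
data = torch.arange(5.0)[None, None]          # shape (1, 1, 5)
padded = F.pad(data, [2, 1], mode="reflect")
print(padded)  # tensor([[[2., 1., 0., 1., 2., 3., 4., 3.]]])
```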
```diff
--- konfai-1.0.9/konfai/metric/measure.py
+++ konfai-1.1.0/konfai/metric/measure.py
@@ -147,7 +147,7 @@ class Dice(Criterion):
     def forward(self, output: torch.Tensor, *targets : list[torch.Tensor]) -> torch.Tensor:
         target = targets[0]
         if output.shape[1] == 1:
-            output = F.one_hot(output.type(torch.int64), num_classes=torch.max(output).item()+1).permute(0, len(target.shape), *[i+1 for i in range(len(target.shape)-1)]).float()
+            output = F.one_hot(output.type(torch.int64), num_classes=int(torch.max(output).item()+1)).permute(0, len(target.shape), *[i+1 for i in range(len(target.shape)-1)]).float()
         target = F.one_hot(target.type(torch.int64), num_classes=output.shape[1]).permute(0, len(target.shape), *[i+1 for i in range(len(target.shape)-1)]).float().squeeze(2)
         return 1-torch.mean(self.dice_per_channel(output, target))
 
```
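The measure.py change is a type fix rather than a logic change: for a float tensor, `torch.max(output).item()` returns a Python float, and `torch.nn.functional.one_hot` rejects a float `num_classes`. A small reproduction of the failure the `int(...)` cast removes:

```python
import torch
import torch.nn.functional as F

x = torch.tensor([0.0, 1.0, 2.0])      # float-valued label map, as in Dice.forward
n = torch.max(x).item() + 1             # 3.0 -- a Python float, not an int
try:
    F.one_hot(x.type(torch.int64), num_classes=n)       # pre-1.1.0 expression
except TypeError as e:
    print("without the cast:", e)
y = F.one_hot(x.type(torch.int64), num_classes=int(n))  # the 1.1.0 fix
print(y.shape)                          # torch.Size([3, 3])
```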
```diff
--- konfai-1.0.9/konfai/trainer.py
+++ konfai-1.1.0/konfai/trainer.py
@@ -1,5 +1,4 @@
 import torch
-import torch.optim.adamw
 from torch.utils.data import DataLoader
 import tqdm
 import numpy as np
@@ -19,9 +18,72 @@ from konfai.utils.utils import State, DataLog, DistributedObject, description
 from konfai.network.network import Network, ModelLoader, NetState, CPU_Model
 
 
+class EarlyStoppingBase:
+
+    def __init__(self):
+        pass
+
+    def isStopped(self) -> bool:
+        return False
+
+    def getScore(self, values: dict[str, float]):
+        return sum([i for i in values.values()])
+
+    def __call__(self, current_score: float) -> bool:
+        return False
+
+class EarlyStopping(EarlyStoppingBase):
+
+    @config("EarlyStopping")
+    def __init__(self, monitor: Union[list[str], None] = [], patience=10, min_delta=0.0, mode="min"):
+        super().__init__()
+        self.monitor = [] if monitor is None else monitor
+        self.patience = patience
+        self.min_delta = min_delta
+        self.mode = mode
+        self.counter = 0
+        self.best_score = None
+        self.early_stop = False
+
+    def isStopped(self) -> bool:
+        return self.early_stop
+
+    def getScore(self, values: dict[str, float]):
+        if len(self.monitor) == 0:
+            return super().getScore(values)
+        for v in self.monitor:
+            if v not in values.keys():
+                raise ValueError(
+                    "[EarlyStopping] Metric '{}' specified in `monitor` not found in logged values. "
+                    "Available keys: {}. Please check your configuration.".format(v, list(values.keys())))
+        return sum([i for v, i in values.items() if v in self.monitor])
+
+    def __call__(self, current_score: float) -> bool:
+        if self.best_score is None:
+            self.best_score = current_score
+            return False
+
+        if self.mode == "min":
+            improvement = self.best_score - current_score
+        elif self.mode == "max":
+            improvement = current_score - self.best_score
+        else:
+            raise ValueError("Mode must be 'min' or 'max'.")
+
+        if improvement > self.min_delta:
+            self.best_score = current_score
+            self.counter = 0
+        else:
+            self.counter += 1
+
+        if self.counter >= self.patience:
+            self.early_stop = True
+
+        return self.early_stop
+
 class _Trainer():
 
-    def __init__(self, world_size: int, global_rank: int, local_rank: int, size: int, train_name: str, data_log: Union[list[str], None] , save_checkpoint_mode: str, epochs: int, epoch: int, autocast: bool, it_validation: Union[int, None], it: int, model: Union[DDP, CPU_Model], modelEMA: AveragedModel, dataloader_training: DataLoader, dataloader_validation: Union[DataLoader, None] = None) -> None:
+    def __init__(self, world_size: int, global_rank: int, local_rank: int, size: int, train_name: str, early_stopping: EarlyStopping, data_log: Union[list[str], None] , save_checkpoint_mode: str, epochs: int, epoch: int, autocast: bool, it_validation: Union[int, None], it: int, model: Union[DDP, CPU_Model], modelEMA: AveragedModel, dataloader_training: DataLoader, dataloader_validation: Union[DataLoader, None] = None) -> None:
         self.world_size = world_size
         self.global_rank = global_rank
         self.local_rank = local_rank
```
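The stopping rule added here is self-contained: remember the best score seen so far, count validation rounds whose improvement over it does not exceed `min_delta`, and stop once the count reaches `patience`. A standalone restatement with plain variables (the `EarlyStopping` class itself is built for konfai's `@config` machinery, so it is not instantiated directly here):

```python
# mode="min": lower is better, as for a validation loss.
best_score, counter = None, 0
patience, min_delta = 3, 0.01

for score in [0.50, 0.45, 0.449, 0.448, 0.447, 0.446]:
    if best_score is None or best_score - score > min_delta:
        best_score, counter = score, 0   # real improvement: reset the counter
    else:
        counter += 1                     # within min_delta: counts as stagnation
    if counter >= patience:
        print("early stop at score", score)
        break
```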
```diff
--- konfai-1.0.9/konfai/trainer.py
+++ konfai-1.1.0/konfai/trainer.py
@@ -36,7 +98,8 @@ class _Trainer():
         self.dataloader_validation = dataloader_validation
         self.autocast = autocast
         self.modelEMA = modelEMA
-
+        self.early_stopping = EarlyStoppingBase() if early_stopping is None else early_stopping
+
         self.it_validation = it_validation
         if self.it_validation is None:
             self.it_validation = len(dataloader_training)
@@ -59,6 +122,8 @@ class _Trainer():
         for self.epoch in epoch_tqdm:
             self.dataloader_training.dataset.load()
             self.train()
+            if self.early_stopping.isStopped():
+                break
             self.dataloader_training.dataset.resetAugmentation()
 
     def getInput(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int, str, bool]]) -> dict[tuple[str, bool], torch.Tensor]:
@@ -88,7 +153,11 @@ class _Trainer():
                 if self.dataloader_validation is not None:
                     loss = self._validate()
                     self.model.module.update_lr()
-                    self.
+                    score = self.early_stopping.getScore(loss)
+                    self.checkpoint_save(score)
+                    if self.early_stopping(score):
+                        break
+
 
             batch_iter.set_description(desc())
 
@@ -120,41 +189,53 @@ class _Trainer():
             return self._validation_log(data_dict)
 
     def checkpoint_save(self, loss: float) -> None:
-        if self.global_rank
-
-        last_loss = None
-        if os.path.exists(path) and os.listdir(path):
-            name = sorted(os.listdir(path))[-1]
-            state_dict = torch.load(path+name, weights_only=False)
-            last_loss = state_dict["loss"]
-            if self.save_checkpoint_mode == "BEST":
-                if last_loss >= loss:
-                    os.remove(path+name)
-
-        if self.save_checkpoint_mode != "BEST" or (last_loss is None or last_loss >= loss):
-            name = DATE()+".pt"
-            if not os.path.exists(path):
-                os.makedirs(path)
-
-            save_dict = {
-                "epoch": self.epoch,
-                "it": self.it,
-                "loss": loss,
-                "Model": self.model.module.state_dict()}
+        if self.global_rank != 0:
+            return
 
-
-
+        path = CHECKPOINTS_DIRECTORY()+self.train_name+"/"
+        os.makedirs(path, exist_ok=True)
+
+        name = DATE() + ".pt"
+        save_path = os.path.join(path, name)
 
-
-
+        save_dict = {
+            "epoch": self.epoch,
+            "it": self.it,
+            "loss": loss,
+            "Model": self.model.module.state_dict()
+        }
+
+        if self.modelEMA is not None:
+            save_dict["Model_EMA"] = self.modelEMA.module.state_dict()
+
+        save_dict.update({'{}_optimizer_state_dict'.format(name): network.optimizer.state_dict() for name, network in self.model.module.getNetworks().items() if network.optimizer is not None})
+        torch.save(save_dict, save_path)
+
+        if self.save_checkpoint_mode == "BEST":
+            all_checkpoints = sorted([
+                os.path.join(path, f)
+                for f in os.listdir(path) if f.endswith(".pt")
+            ])
+            best_ckpt = None
+            best_loss = float('inf')
+
+            for f in all_checkpoints:
+                d = torch.load(f, weights_only=False)
+                if d.get("loss", float("inf")) < best_loss:
+                    best_loss = d["loss"]
+                    best_ckpt = f
+
+            for f in all_checkpoints:
+                if f != best_ckpt and f != save_path:
+                    os.remove(f)
 
     @torch.no_grad()
-    def _log(self, type_log: str, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]) -> float:
+    def _log(self, type_log: str, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]) -> dict[str, float]:
         models: dict[str, Network] = {"" : self.model.module}
         if self.modelEMA is not None:
             models["_EMA"] = self.modelEMA.module
 
-        measures = DistributedObject.getMeasure(self.world_size, self.global_rank, self.local_rank*self.size+self.size-1, models, self.it_validation if type_log == "
+        measures = DistributedObject.getMeasure(self.world_size, self.global_rank, self.local_rank*self.size+self.size-1, models, self.it_validation if type_log == "Training" else len(self.dataloader_validation))
 
         if self.global_rank == 0:
             images_log = []
@@ -178,27 +259,29 @@ class _Trainer():
                 for name, layer, _ in model.get_layers([v.to(0) for k, v in self.getInput(data_dict).items() if k[1]], images_log):
                     self.data_log[name][0](self.tb, "{}/{}{}".format(type_log, name, label), layer[:self.data_log[name][1]].detach().cpu().numpy(), self.it)
 
-            if type_log == "
+            if type_log == "Training":
                 for name, network in self.model.module.getNetworks().items():
                     if network.optimizer is not None:
                         self.tb.add_scalar("{}/{}/Learning Rate".format(type_log, name), network.optimizer.param_groups[0]['lr'], self.it)
 
         if self.global_rank == 0:
-            loss =
+            loss = {}
             for name, network in self.model.module.getNetworks().items():
                 if network.measure is not None:
-                    loss.
-
+                    loss.update({k : v[1] for k, v in measures["{}{}".format(name, label)][0].items()})
+                    loss.update({k : v[1] for k, v in measures["{}{}".format(name, label)][1].items()})
+            return loss
         return None
 
     @torch.no_grad()
-    def _train_log(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]) -> float:
-        return self._log("
+    def _train_log(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]) -> dict[str, float]:
+        return self._log("Training", data_dict)
 
     @torch.no_grad()
-    def _validation_log(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]) -> float:
+    def _validation_log(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]) -> dict[str, float]:
         return self._log("Validation", data_dict)
 
+
 class Trainer(DistributedObject):
 
     @config("Trainer")
```
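`checkpoint_save` is also restructured: 1.1.0 always writes the new checkpoint first (now including optimizer state and, when present, the EMA weights) and only afterwards, in "BEST" mode, prunes every file except the best-loss checkpoint and the one just written. The pruning pass in isolation, with JSON files standing in for torch checkpoints so the sketch runs anywhere:

```python
import json, os, tempfile

def read_loss(path: str) -> float:
    with open(path) as f:
        return json.load(f).get("loss", float("inf"))

with tempfile.TemporaryDirectory() as path:
    # Three saved checkpoints; "c.pt" plays the role of the one just written.
    for name, loss in {"a.pt": 0.9, "b.pt": 0.4, "c.pt": 0.7}.items():
        with open(os.path.join(path, name), "w") as f:
            json.dump({"loss": loss}, f)
    save_path = os.path.join(path, "c.pt")

    all_checkpoints = sorted(os.path.join(path, f)
                             for f in os.listdir(path) if f.endswith(".pt"))
    best_ckpt = min(all_checkpoints, key=read_loss)
    for f in all_checkpoints:
        if f != best_ckpt and f != save_path:  # keep the best and the latest
            os.remove(f)

    print(sorted(os.listdir(path)))            # ['b.pt', 'c.pt']
```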
```diff
--- konfai-1.0.9/konfai/trainer.py
+++ konfai-1.1.0/konfai/trainer.py
@@ -214,6 +297,7 @@ class Trainer(DistributedObject):
                  gpu_checkpoints: Union[list[str], None] = None,
                  ema_decay : float = 0,
                  data_log: Union[list[str], None] = None,
+                 early_stopping: Union[EarlyStopping, None] = None,
                  save_checkpoint_mode: str= "BEST") -> None:
         if os.environ["KONFAI_CONFIG_MODE"] != "Done":
             exit(0)
@@ -223,6 +307,7 @@ class Trainer(DistributedObject):
         self.autocast = autocast
         self.epochs = epochs
         self.epoch = 0
+        self.early_stopping = early_stopping
         self.it = 0
         self.it_validation = it_validation
         self.model = model.getModel(train=True)
@@ -326,5 +411,5 @@ class Trainer(DistributedObject):
         model = DDP(model, static_graph=True) if torch.cuda.is_available() else CPU_Model(model)
         if self.modelEMA is not None:
             self.modelEMA.module = Network.to(self.modelEMA.module, local_rank)
-        with _Trainer(world_size, global_rank, local_rank, self.size, self.name, self.data_log, self.save_checkpoint_mode, self.epochs, self.epoch, self.autocast, self.it_validation, self.it, model, self.modelEMA, *dataloaders) as t:
+        with _Trainer(world_size, global_rank, local_rank, self.size, self.name, self.early_stopping, self.data_log, self.save_checkpoint_mode, self.epochs, self.epoch, self.autocast, self.it_validation, self.it, model, self.modelEMA, *dataloaders) as t:
             t.run()
```
```diff
--- konfai-1.0.9/konfai/utils/config.py
+++ konfai-1.1.0/konfai/utils/config.py
@@ -236,7 +236,7 @@ def config(key : Union[str, None] = None):
             try:
                 kwargs[param.name] = annotation(config = filename, DL_args = key_tmp)
             except Exception as e:
-                raise ValueError("[Config] Failed to instantiate {} with type {}".format(param.name, annotation.__name__))
+                raise ValueError("[Config] Failed to instantiate {} with type {}, error {} ".format(param.name, annotation.__name__, e))
 
         if os.environ['KONFAI_CONFIG_VARIABLE'] == "True":
             os.environ['KONFAI_CONFIG_VARIABLE'] = "False"
```
```diff
--- konfai-1.0.9/pyproject.toml
+++ konfai-1.1.0/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "konfai"
-version = "1.0.9"
+version = "1.1.0"
 description = "Modular and configurable Deep Learning framework with YAML and PyTorch"
 readme = "README.md"
 requires-python = ">=3.8"
@@ -41,4 +41,7 @@ konfai-cluster = "konfai.main:cluster"
 vtk = ["vtk"]
 lpips = ["lpips"]
 cluster = ["submitit"]
-plot = ["matplotlib"]
+plot = ["matplotlib"]
+
+[tool.poetry.dev-dependencies]
+pytest = "^8.0"
```
```diff
--- /dev/null
+++ konfai-1.1.0/tests/test_config.py
@@ -0,0 +1,18 @@
+import pytest
+from konfai.utils.config import config
+
+class Dummy:
+
+    @config("Test")
+    def __init__(self, a: int = 1, b: str = "ok"):
+        self.a = a
+        self.b = b
+
+def test_config_instantiation(monkeypatch):
+    import os
+    os.environ['KONFAI_CONFIG_FILE'] = "./tests/dummy_data/dummy_config.yml"
+    os.environ['KONFAI_CONFIG_PATH'] = "Test"
+
+    dummy = Dummy()
+    assert dummy.a == 42
+    assert dummy.b == "hello"
```
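The new test depends on a fixture, tests/dummy_data/dummy_config.yml, that is not part of this diff. From the two asserts, its `Test` section must set `a` to 42 and `b` to "hello"; the exact layout konfai's `@config` expects is an assumption here. A guess at the minimal fixture, written from Python so the assumed schema stays explicit:

```python
import pathlib

# Hypothetical fixture inferred from the asserts in test_config.py;
# the nesting under a top-level "Test" key is an assumption.
pathlib.Path("tests/dummy_data").mkdir(parents=True, exist_ok=True)
pathlib.Path("tests/dummy_data/dummy_config.yml").write_text(
    "Test:\n"
    "  a: 42\n"
    "  b: hello\n"
)
```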