konfai-1.1.7-py3-none-any.whl → konfai-1.1.9-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- konfai/__init__.py +59 -14
- konfai/data/augmentation.py +457 -286
- konfai/data/data_manager.py +509 -290
- konfai/data/patching.py +300 -183
- konfai/data/transform.py +384 -277
- konfai/evaluator.py +309 -68
- konfai/main.py +71 -22
- konfai/metric/measure.py +341 -222
- konfai/metric/schedulers.py +24 -13
- konfai/models/classification/convNeXt.py +187 -81
- konfai/models/classification/resnet.py +272 -58
- konfai/models/generation/cStyleGan.py +233 -59
- konfai/models/generation/ddpm.py +348 -121
- konfai/models/generation/diffusionGan.py +757 -358
- konfai/models/generation/gan.py +177 -53
- konfai/models/generation/vae.py +140 -40
- konfai/models/registration/registration.py +135 -52
- konfai/models/representation/representation.py +57 -23
- konfai/models/segmentation/NestedUNet.py +339 -68
- konfai/models/segmentation/UNet.py +140 -30
- konfai/network/blocks.py +331 -187
- konfai/network/network.py +781 -423
- konfai/predictor.py +645 -240
- konfai/trainer.py +527 -216
- konfai/utils/ITK.py +191 -106
- konfai/utils/config.py +152 -95
- konfai/utils/dataset.py +326 -455
- konfai/utils/utils.py +495 -249
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/METADATA +1 -3
- konfai-1.1.9.dist-info/RECORD +38 -0
- konfai/utils/registration.py +0 -199
- konfai-1.1.7.dist-info/RECORD +0 -39
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/WHEEL +0 -0
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/entry_points.txt +0 -0
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/licenses/LICENSE +0 -0
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/top_level.txt +0 -0
konfai/trainer.py
CHANGED
(Long lines on the removed side are truncated by the diff viewer.)

@@ -1,62 +1,88 @@
-import torch
-from torch.utils.data import DataLoader
-import tqdm
-import numpy as np
 import os
-from typing import Union
-from torch.nn.parallel import DistributedDataParallel as DDP
-
 import shutil
-
-
+
+import torch
 import torch.distributed as dist
+import tqdm
+from torch.nn.parallel import DistributedDataParallel as DDP  # noqa: N817
+from torch.optim.swa_utils import AveragedModel
+from torch.utils.data import DataLoader
+from torch.utils.tensorboard.writer import SummaryWriter
 
-from konfai import
+from konfai import (
+    checkpoints_directory,
+    config_file,
+    current_date,
+    konfai_state,
+    models_directory,
+    path_to_models,
+    setups_directory,
+    statistics_directory,
+)
 from konfai.data.data_manager import DataTrain
+from konfai.network.network import CPUModel, ModelLoader, NetState, Network
 from konfai.utils.config import config
-from konfai.utils.utils import
-
+from konfai.utils.utils import DataLog, DistributedObject, State, TrainerError, description
+
 
 class EarlyStoppingBase:
 
     def __init__(self):
         pass
 
-    def
+    def is_stopped(self) -> bool:
+
         return False
 
-    def
-        return sum(
-
+    def get_score(self, values: dict[str, float]):
+        return sum(list(values.values()))
+
     def __call__(self, current_score: float) -> bool:
         return False
 
+
 class EarlyStopping(EarlyStoppingBase):
+    """
+    Implements early stopping logic with configurable patience and monitored metrics.
+
+    Attributes:
+        monitor (list[str]): Metrics to monitor.
+        patience (int): Number of checks with no improvement before stopping.
+        min_delta (float): Minimum change to qualify as improvement.
+        mode (str): "min" or "max" depending on optimization direction.
+    """
 
     @config("EarlyStopping")
-    def __init__(
+    def __init__(
+        self,
+        monitor: list[str] | None = None,
+        patience: int = 10,
+        min_delta: float = 0.0,
+        mode: str = "min",
+    ):
         super().__init__()
         self.monitor = [] if monitor is None else monitor
         self.patience = patience
         self.min_delta = min_delta
         self.mode = mode
         self.counter = 0
-        self.best_score = None
+        self.best_score: float | None = None
         self.early_stop = False
 
-    def
+    def is_stopped(self) -> bool:
         return self.early_stop
 
-    def
+    def get_score(self, values: dict[str, float]):
         if len(self.monitor) == 0:
-            return super().
+            return super().get_score(values)
         for v in self.monitor:
             if v not in values.keys():
                 raise TrainerError(
                     "Metric '{}' specified in EarlyStopping.monitor not found in logged values. ",
-                    "Available keys: {}. Please check your configuration."
+                    f"Available keys: {v}. Please check your configuration.",
+                )
         return sum([i for v, i in values.items() if v in self.monitor])
-
+
     def __call__(self, current_score: float) -> bool:
         if self.best_score is None:
             self.best_score = current_score
@@ -80,10 +106,44 @@ class EarlyStopping(EarlyStoppingBase):
 
         return self.early_stop
 
-class _Trainer():
 
-
-
+class _Trainer:
+    """
+    Internal class for managing the training loop in a distributed or standalone setting.
+
+    Handles:
+    - Epoch iteration with training and optional validation
+    - Mixed precision support (autocast)
+    - Exponential Moving Average (EMA) model tracking
+    - Early stopping
+    - Logging to TensorBoard
+    - Model checkpoint saving and selection (ALL or BEST)
+
+    This class is intended to be used via a context manager
+    (`with _Trainer(...) as trainer:`) inside the public `Trainer` class.
+    """
+
+    def __init__(
+        self,
+        world_size: int,
+        global_rank: int,
+        local_rank: int,
+        size: int,
+        train_name: str,
+        early_stopping: EarlyStopping | None,
+        data_log: list[str] | None,
+        save_checkpoint_mode: str,
+        epochs: int,
+        epoch: int,
+        autocast: bool,
+        it_validation: int | None,
+        it: int,
+        model: DDP | CPUModel,
+        model_ema: AveragedModel,
+        dataloader_training: DataLoader,
+        dataloader_validation: DataLoader | None = None,
+    ) -> None:
+        self.world_size = world_size
         self.global_rank = global_rank
         self.local_rank = local_rank
         self.size = size
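The rewritten `EarlyStopping` swaps the old camelCase methods for `is_stopped`/`get_score` and sums only the monitored keys. A minimal sketch of the scoring behaviour, assuming the `@config("EarlyStopping")` decorator still permits direct keyword construction (the metric names and values below are made up):

```python
es = EarlyStopping(monitor=["loss_a", "loss_b"], patience=5, min_delta=0.01, mode="min")

# get_score sums only the monitored keys; other logged values are ignored.
score = es.get_score({"loss_a": 0.12, "loss_b": 0.03, "lr": 1e-4})  # -> 0.15

# With no monitor list, the base class sums every logged value instead.
base = EarlyStoppingBase().get_score({"loss_a": 0.12, "lr": 1e-4})  # -> 0.1201

# __call__ then applies the patience / min_delta rule to successive scores;
# its full body sits in the unchanged lines elided between these hunks.
stopped = es(score)
```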
@@ -96,58 +156,94 @@ class _Trainer():
         self.dataloader_training = dataloader_training
         self.dataloader_validation = dataloader_validation
         self.autocast = autocast
-        self.
-        self.early_stopping = EarlyStoppingBase() if early_stopping is None else early_stopping
-
-        self.it_validation = it_validation
-        if self.it_validation is None:
-            self.it_validation = len(dataloader_training)
+        self.model_ema = model_ema
+        self.early_stopping = EarlyStoppingBase() if early_stopping is None else early_stopping
+
+        self.it_validation = len(dataloader_training) if it_validation is None else it_validation
         self.it = it
-        self.tb = SummaryWriter(log_dir
-        self.data_log
+        self.tb = SummaryWriter(log_dir=statistics_directory() + self.train_name + "/")
+        self.data_log: dict[str, tuple[DataLog, int]] = {}
         if data_log is not None:
             for data in data_log:
-                self.data_log[data.split("/")[0].replace(":", ".")] = (
-
+                self.data_log[data.split("/")[0].replace(":", ".")] = (
+                    DataLog[data.split("/")[1]],
+                    int(data.split("/")[2]),
+                )
+
     def __enter__(self):
         return self
-
-    def __exit__(self,
+
+    def __exit__(self, exc_type, value, traceback):
+        """Closes the SummaryWriter if used."""
         if self.tb is not None:
             self.tb.close()
 
     def run(self) -> None:
+        """
+        Launches the training loop, performing one epoch at a time.
+        Triggers early stopping and resets data augmentations between epochs.
+        """
         self.dataloader_training.dataset.load("Train")
         if self.dataloader_validation is not None:
             self.dataloader_validation.dataset.load("Validation")
-        with tqdm.tqdm(
+        with tqdm.tqdm(
+            iterable=range(self.epoch, self.epochs),
+            leave=False,
+            total=self.epochs,
+            initial=self.epoch,
+            desc="Progress",
+        ) as epoch_tqdm:
             for self.epoch in epoch_tqdm:
                 self.train()
-                if self.early_stopping.
+                if self.early_stopping.is_stopped():
                     break
-                self.dataloader_training.dataset.
-
-    def getInput(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int, str, bool]]) -> dict[tuple[str, bool], torch.Tensor]:
-        return {(k, v[5][0].item()) : v[0] for k, v in data_dict.items()}
+                self.dataloader_training.dataset.reset_augmentation("Train")
 
-    def
-        self
-
-
-
-
+    def get_input(
+        self,
+        data_dict: dict[str, tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, list[str], torch.Tensor]],
+    ) -> dict[tuple[str, bool], torch.Tensor]:
+        """
+        Extracts input tensors from the data dict for model input.
 
-
+        Args:
+            data_dict (dict): Dictionary with data items structured as tuples.
 
-
+        Returns:
+            dict: Mapping from (group, bool_flag) to input tensors.
+        """
+        return {(k, v[5][0].item()): v[0] for k, v in data_dict.items()}
+
+    def train(self) -> None:
+        """
+        Performs a full training epoch with support for:
+        - mixed precision
+        - DDP / CPU training
+        - EMA updates
+        - loss logging and checkpoint saving
+        - validation at configurable iteration interval
+        """
+        self.model.train()
+        self.model.module.set_state(NetState.TRAIN)
+        if self.model_ema is not None:
+            self.model_ema.eval()
+            self.model_ema.module.set_state(NetState.TRAIN)
+
+        with tqdm.tqdm(
+            iterable=enumerate(self.dataloader_training),
+            desc=f"Training : {description(self.model, self.model_ema)}",
+            total=len(self.dataloader_training),
+            leave=False,
+            ncols=0,
+        ) as batch_iter:
             for _, data_dict in batch_iter:
-                with torch.amp.autocast(
-
-                    self.model(
+                with torch.amp.autocast("cuda", enabled=self.autocast):
+                    input_data_dict = self.get_input(data_dict)
+                    self.model(input_data_dict)
                 self.model.module.backward(self.model.module)
-                if self.
-                    self.
-                    self.
+                if self.model_ema is not None:
+                    self.model_ema.update_parameters(self.model)
+                    self.model_ema.module(input_data_dict)
                 self.it += 1
                 if (self.it) % self.it_validation == 0:
                     loss = self._train_log(data_dict)
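The new `_Trainer.__init__` parses each `data_log` entry as a three-field `"<name>/<DataLog member>/<count>"` string. A short sketch of that parsing; only the shape comes from the code above, while the entry itself and the `"IMAGE"` member name are hypothetical:

```python
entry = "Dataset:Image/IMAGE/4"  # hypothetical data_log entry

name = entry.split("/")[0].replace(":", ".")  # "Dataset.Image" -- keys use "." internally
kind = entry.split("/")[1]                    # "IMAGE", resolved via DataLog[kind]
count = int(entry.split("/")[2])              # 4 -> how many samples to log

# _Trainer then stores: self.data_log[name] = (DataLog[kind], count)
```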
@@ -155,76 +251,109 @@ class _Trainer():
                     if self.dataloader_validation is not None:
                         loss = self._validate()
                     self.model.module.update_lr()
-                    score = self.early_stopping.
+                    score = self.early_stopping.get_score(loss)
                     self.checkpoint_save(score)
                     if self.early_stopping(score):
                         break
-
 
-                batch_iter.set_description(
-
-
+                batch_iter.set_description(f"Training : {description(self.model, self.model_ema)}")
+
     @torch.no_grad()
     def _validate(self) -> float:
+        """
+        Executes the validation phase, evaluates loss and metrics.
+        Updates model states and resets augmentation for validation set.
+
+        Returns:
+            float: Validation loss.
+        """
+        if self.dataloader_validation is None:
+            return 0
         self.model.eval()
-        self.model.module.
-        if self.
-            self.
+        self.model.module.set_state(NetState.PREDICTION)
+        if self.model_ema is not None:
+            self.model_ema.module.set_state(NetState.PREDICTION)
 
-        desc = lambda : "Validation : {}".format(description(self.model, self.modelEMA))
         data_dict = None
-        with tqdm.tqdm(
+        with tqdm.tqdm(
+            iterable=enumerate(self.dataloader_validation),
+            desc=f"Validation : {description(self.model, self.model_ema)}",
+            total=len(self.dataloader_validation),
+            leave=False,
+            ncols=0,
+        ) as batch_iter:
             for _, data_dict in batch_iter:
-
-                self.model(
-                if self.
-                    self.
+                input_data_dict = self.get_input(data_dict)
+                self.model(input_data_dict)
+                if self.model_ema is not None:
+                    self.model_ema.module(input_data_dict)
 
-                batch_iter.set_description(
-        self.dataloader_validation.dataset.
+                batch_iter.set_description(f"Validation : {description(self.model, self.model_ema)}")
+        self.dataloader_validation.dataset.reset_augmentation("Validation")
         dist.barrier()
         self.model.train()
-        self.model.module.
-        if self.
-            self.
+        self.model.module.set_state(NetState.TRAIN)
+        if self.model_ema is not None:
+            self.model_ema.module.set_state(NetState.TRAIN)
         return self._validation_log(data_dict)
 
     def checkpoint_save(self, loss: float) -> None:
+        """
+        Saves model and optimizer states. Keeps either all checkpoints or only the best one.
+
+        Args:
+            loss (float): Current loss used for best checkpoint selection.
+        """
         if self.global_rank != 0:
             return
 
-        path =
+        path = checkpoints_directory() + self.train_name + "/"
         os.makedirs(path, exist_ok=True)
 
-        name =
+        name = current_date() + ".pt"
         save_path = os.path.join(path, name)
 
         save_dict = {
-
-
-
-
+            "epoch": self.epoch,
+            "it": self.it,
+            "loss": loss,
+            "Model": self.model.module.state_dict(),
         }
 
-        if self.
-            save_dict["Model_EMA"] = self.
-
-        save_dict.update(
-
-
-
+        if self.model_ema is not None:
+            save_dict["Model_EMA"] = self.model_ema.module.state_dict()
+
+        save_dict.update(
+            {
+                f"{name}_optimizer_state_dict": network.optimizer.state_dict()
+                for name, network in self.model.module.get_networks().items()
+                if network.optimizer is not None
+            }
+        )
+        save_dict.update(
+            {
+                f"{name}_it": network._it
+                for name, network in self.model.module.get_networks().items()
+                if network.optimizer is not None
+            }
+        )
+        save_dict.update(
+            {
+                f"{name}_nb_lr_update": network._nb_lr_update
+                for name, network in self.model.module.get_networks().items()
+                if network.optimizer is not None
+            }
+        )
+
         torch.save(save_dict, save_path)
 
         if self.save_checkpoint_mode == "BEST":
-            all_checkpoints = sorted([
-                os.path.join(path, f)
-                for f in os.listdir(path) if f.endswith(".pt")
-            ])
+            all_checkpoints = sorted([os.path.join(path, f) for f in os.listdir(path) if f.endswith(".pt")])
             best_ckpt = None
-            best_loss = float(
+            best_loss = float("inf")
 
             for f in all_checkpoints:
-                d = torch.load(f, weights_only=
+                d = torch.load(f, weights_only=True)
                 if d.get("loss", float("inf")) < best_loss:
                     best_loss = d["loss"]
                     best_ckpt = f
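In `"BEST"` mode, `checkpoint_save` reloads every saved checkpoint and keeps the one with the lowest recorded `"loss"`; the `os.remove(f)` context at the top of the next hunk suggests the remaining files are then deleted. A self-contained sketch of that selection, under that assumption:

```python
import os

import torch


def keep_best_checkpoint(path: str) -> None:
    # Each .pt file stores a "loss" entry alongside the weights (see save_dict above).
    checkpoints = sorted(os.path.join(path, f) for f in os.listdir(path) if f.endswith(".pt"))
    best_ckpt, best_loss = None, float("inf")
    for f in checkpoints:
        loss = torch.load(f, weights_only=True).get("loss", float("inf"))
        if loss < best_loss:
            best_loss, best_ckpt = loss, f
    # Assumed from the elided lines: drop everything except the best checkpoint.
    for f in checkpoints:
        if f != best_ckpt:
            os.remove(f)
```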
@@ -234,79 +363,165 @@ class _Trainer():
                 os.remove(f)
 
     @torch.no_grad()
-    def _log(
-
-
-
-
-
-
+    def _log(
+        self,
+        type_log: str,
+        data_dict: dict[str, tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, list[str], torch.Tensor]],
+    ) -> dict[str, float] | None:
+        """
+        Logs losses, metrics and optionally images to TensorBoard.
+
+        Args:
+            type_log (str): "Training" or "Validation".
+            data_dict (dict): Dictionary of data from current batch.
+
+        Returns:
+            dict[str, float] | None: Dictionary of aggregated losses and metrics if rank == 0.
+        """
+        models: dict[str, Network] = {"": self.model.module}
+        if self.model_ema is not None:
+            models["_EMA"] = self.model_ema.module
+
+        measures = DistributedObject.get_measure(
+            self.world_size,
+            self.global_rank,
+            self.local_rank * self.size + self.size - 1,
+            models,
+            (
+                self.it_validation
+                if type_log == "Training" or self.dataloader_validation is None
+                else len(self.dataloader_validation)
+            ),
+        )
+
         if self.global_rank == 0:
             images_log = []
-            if len(self.data_log):
+            if len(self.data_log):
                 for name, data_type in self.data_log.items():
                     if name in data_dict:
-                        data_type[0](
+                        data_type[0](
+                            self.tb,
+                            f"{type_log}/{name}",
+                            data_dict[name][0][: self.data_log[name][1]].detach().cpu().numpy(),
+                            self.it,
+                        )
                     else:
                         images_log.append(name.replace(":", "."))
-
+
             for label, model in models.items():
-                for name, network in model.
+                for name, network in model.get_networks().items():
                     if network.measure is not None:
-                        self.tb.add_scalars(
-
+                        self.tb.add_scalars(
+                            f"{type_log}/{name}/Loss/{label}",
+                            {k: v[1] for k, v in measures[f"{name}{label}"][0].items()},
+                            self.it,
+                        )
+                        self.tb.add_scalars(
+                            f"{type_log}/{name}/Loss_weight/{label}",
+                            {k: v[0] for k, v in measures[f"{name}{label}"][0].items()},
+                            self.it,
+                        )
+
+                        self.tb.add_scalars(
+                            f"{type_log}/{name}/Metric/{label}",
+                            {k: v[1] for k, v in measures[f"{name}{label}"][1].items()},
+                            self.it,
+                        )
+                        self.tb.add_scalars(
+                            f"{type_log}/{name}/Metric_weight/{label}",
+                            {k: v[0] for k, v in measures[f"{name}{label}"][1].items()},
+                            self.it,
+                        )
 
-                        self.tb.add_scalars("{}/{}/Metric/{}".format(type_log, name, label), {k : v[1] for k, v in measures["{}{}".format(name, label)][1].items()}, self.it)
-                        self.tb.add_scalars("{}/{}/Metric_weight/{}".format(type_log, name, label), {k : v[0] for k, v in measures["{}{}".format(name, label)][1].items()}, self.it)
-
                 if len(images_log):
-                    for name, layer, _ in model.get_layers(
-
-
+                    for name, layer, _ in model.get_layers(
+                        [v.to(0) for k, v in self.get_input(data_dict).items() if k[1]],
+                        images_log,
+                    ):
+                        self.data_log[name][0](
+                            self.tb,
+                            f"{type_log}/{name}{label}",
+                            layer[: self.data_log[name][1]].detach().cpu().numpy(),
+                            self.it,
+                        )
+
             if type_log == "Training":
-                for name, network in self.model.module.
+                for name, network in self.model.module.get_networks().items():
                     if network.optimizer is not None:
-                        self.tb.add_scalar(
-
+                        self.tb.add_scalar(
+                            f"{type_log}/{name}/Learning Rate",
+                            network.optimizer.param_groups[0]["lr"],
+                            self.it,
+                        )
+
         if self.global_rank == 0:
             loss = {}
-            for name, network in self.model.module.
+            for name, network in self.model.module.get_networks().items():
                 if network.measure is not None:
-                    loss.update({k
-                    loss.update({k
+                    loss.update({k: v[1] for k, v in measures[f"{name}{label}"][0].items()})
+                    loss.update({k: v[1] for k, v in measures[f"{name}{label}"][1].items()})
             return loss
         return None
-
+
     @torch.no_grad()
-    def _train_log(self, data_dict
+    def _train_log(self, data_dict: dict[str, tuple[torch.Tensor, int, int, int]]) -> dict[str, float]:
+        """Wrapper for _log during training."""
         return self._log("Training", data_dict)
 
     @torch.no_grad()
-    def _validation_log(self, data_dict
+    def _validation_log(self, data_dict: dict[str, tuple[torch.Tensor, int, int, int]]) -> dict[str, float]:
+        """Wrapper for _log during validation."""
         return self._log("Validation", data_dict)
 
-
+
 class Trainer(DistributedObject):
+    """
+    Public API for training a model using the KonfAI framework.
+    Wraps setup, checkpointing, resuming, logging, and launching distributed _Trainer.
+
+    Main responsibilities:
+    - Initialization from config (via @config)
+    - Model and EMA setup
+    - Checkpoint loading and saving
+    - Distributed setup and launch
+
+    Args:
+        model (ModelLoader): Loader for model architecture.
+        dataset (DataTrain): Training/validation dataset.
+        train_name (str): Training session name.
+        manual_seed (int | None): Random seed.
+        epochs (int): Number of epochs to run.
+        it_validation (int | None): Validation interval.
+        autocast (bool): Enable AMP training.
+        gradient_checkpoints (list[str] | None): Modules to use gradient checkpointing on.
+        gpu_checkpoints (list[str] | None): Modules to pin on specific GPUs.
+        ema_decay (float): EMA decay factor.
+        data_log (list[str] | None): Logging instructions.
+        early_stopping (EarlyStopping | None): Optional early stopping config.
+        save_checkpoint_mode (str): Either "BEST" or "ALL".
+    """
 
     @config("Trainer")
-    def __init__(
-
-
-
-
-
-
-
-
-
-
-
-
-
+    def __init__(
+        self,
+        model: ModelLoader = ModelLoader(),
+        dataset: DataTrain = DataTrain(),
+        train_name: str = "default:TRAIN_01",
+        manual_seed: int | None = None,
+        epochs: int = 100,
+        it_validation: int | None = None,
+        autocast: bool = False,
+        gradient_checkpoints: list[str] | None = None,
+        gpu_checkpoints: list[str] | None = None,
+        ema_decay: float = 0,
+        data_log: list[str] | None = None,
+        early_stopping: EarlyStopping | None = None,
+        save_checkpoint_mode: str = "BEST",
+    ) -> None:
         if os.environ["KONFAI_CONFIG_MODE"] != "Done":
             exit(0)
         super().__init__(train_name)
-        self.manual_seed = manual_seed
+        self.manual_seed = manual_seed
         self.dataset = dataset
         self.autocast = autocast
         self.epochs = epochs
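The `v[0]`/`v[1]` indexing in `_log` implies a nested structure for `measures`: per network, a `(losses, metrics)` pair whose entries map a key to a `(weight, value)` tuple. A tiny illustration of that assumed shape (the keys and numbers are invented):

```python
# Assumed from the indexing above, not from a documented API:
# measures["<network><label>"] == (losses, metrics), each {key: (weight, value)}
measures = {"net": ({"mse": (1.0, 0.42)}, {"dice": (0.5, 0.87)})}

losses, metrics = measures["net"]
values = {k: v[1] for k, v in losses.items()}   # what add_scalars logs under .../Loss
weights = {k: v[0] for k, v in losses.items()}  # what it logs under .../Loss_weight
```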
@@ -314,118 +529,214 @@ class Trainer(DistributedObject):
         self.early_stopping = early_stopping
         self.it = 0
         self.it_validation = it_validation
-        self.model = model.
+        self.model = model.get_model(train=True)
         self.ema_decay = ema_decay
-        self.
+        self.model_ema: torch.optim.swa_utils.AveragedModel | None = None
         self.data_log = data_log
 
         modules = []
-        for i,_ in self.model.named_modules():
+        for i, _ in self.model.named_modules():
             modules.append(i)
-
-
-
-
-
-
-
-
+        if self.data_log is not None:
+            for k in self.data_log:
+                tmp = k.split("/")[0].replace(":", ".")
+                if tmp not in self.dataset.get_groups_dest() and tmp not in modules:
+                    raise TrainerError(
+                        f"Invalid key '{tmp}' in `data_log`.",
+                        f"This key is neither a destination group from the dataset ({self.dataset.get_groups_dest()})",
+                        f"nor a valid module name in the model ({modules}).",
+                        "Please check your `data_log` configuration,"
+                        " it should reference either a model output or a dataset group.",
+                    )
+
         self.gradient_checkpoints = gradient_checkpoints
         self.gpu_checkpoints = gpu_checkpoints
         self.save_checkpoint_mode = save_checkpoint_mode
-        self.config_namefile_src =
-        self.config_namefile =
-
-
-
-
+        self.config_namefile_src = config_file().replace(".yml", "")
+        self.config_namefile = (
+            setups_directory() + self.name + "/" + self.config_namefile_src.split("/")[-1] + "_" + str(self.it) + ".yml"
+        )
+        self.size = len(self.gpu_checkpoints) + 1 if self.gpu_checkpoints else 1
+
+    def __exit__(self, exc_type, value, traceback):
+        """Exit training context and trigger save of model/checkpoints."""
+        super().__exit__(exc_type, value, traceback)
         self._save()
 
     def _load(self) -> dict[str, dict[str, torch.Tensor]]:
-
+        """
+        Loads a previously saved checkpoint from local disk or URL.
+
+        Returns:
+            dict: State dictionary loaded from checkpoint.
+        """
+        if path_to_models().startswith("https://"):
             try:
-                state_dict = {
-
-
+                state_dict = {
+                    path_to_models().split(":")[1]: torch.hub.load_state_dict_from_url(
+                        url=path_to_models().split(":")[0], map_location="cpu", check_hash=True
+                    )
+                }
+            except Exception:
+                raise Exception(f"Model : {path_to_models()} does not exist !")
         else:
-            if
+            if path_to_models() != "":
                 path = ""
-                name =
+                name = path_to_models()
             else:
-                path =
+                path = checkpoints_directory() + self.name + "/"
                 if os.listdir(path):
                     name = sorted(os.listdir(path))[-1]
-
-            if os.path.exists(path+name):
-                state_dict = torch.load(path+name, weights_only=
+
+            if os.path.exists(path + name):
+                state_dict = torch.load(path + name, weights_only=True, map_location="cpu")
             else:
-                raise Exception("Model : {} does not exist !"
-
+                raise Exception(f"Model : {self.name} does not exist !")
+
         if "epoch" in state_dict:
-            self.epoch = state_dict[
+            self.epoch = state_dict["epoch"]
         if "it" in state_dict:
-            self.it = state_dict[
+            self.it = state_dict["it"]
         return state_dict
 
     def _save(self) -> None:
-
-
+        """
+        Serializes model and EMA model (if any) in both state_dict and full model formats.
+        Also saves optimizer states and YAML config snapshot.
+        """
+        path_checkpoint = checkpoints_directory() + self.name + "/"
+        path_model = models_directory() + self.name + "/"
         if os.path.exists(path_checkpoint) and os.listdir(path_checkpoint):
-            for
-
-
+            for directory in [
+                path_model,
+                f"{path_model}Serialized/",
+                f"{path_model}StateDict/",
+            ]:
+                if not os.path.exists(directory):
+                    os.makedirs(directory)
 
             for name in sorted(os.listdir(path_checkpoint)):
-                checkpoint = torch.load(path_checkpoint+name, weights_only=
+                checkpoint = torch.load(path_checkpoint + name, weights_only=True, map_location="cpu")
                 self.model.load(checkpoint, init=False, ema=False)
 
-                torch.save(self.model, "{}Serialized/{}"
-                torch.save(
-
-
-
-
-
-
-
-
-
-
-
+                torch.save(self.model, f"{path_model}Serialized/{name}")
+                torch.save(
+                    {"Model": self.model.state_dict()},
+                    f"{path_model}StateDict/{name}",
+                )
+
+                if self.model_ema is not None:
+                    self.model_ema.module.load(checkpoint, init=False, ema=True)
+                    torch.save(
+                        self.model_ema.module,
+                        f"{path_model}Serialized/{current_date()}_EMA.pt",
+                    )
+                    torch.save(
+                        {"Model_EMA": self.model_ema.module.state_dict()},
+                        f"{path_model}StateDict/{current_date()}_EMA.pt",
+                    )
+
+        os.rename(
+            self.config_namefile,
+            self.config_namefile.replace(".yml", "") + "_" + str(self.it) + ".yml",
+        )
+
+    def _avg_fn(self, averaged_model_parameter: float, model_parameter, num_averaged):
+        """
+        EMA update rule used by AveragedModel.
+
+        Returns:
+            torch.Tensor: Blended parameter using decay factor.
+        """
+        return (1 - self.ema_decay) * averaged_model_parameter + self.ema_decay * model_parameter
+
     def setup(self, world_size: int):
-
-
+        """
+        Initializes the training environment:
+        - Clears previous outputs (unless resuming)
+        - Initializes model and EMA
+        - Loads checkpoint (if resuming)
+        - Prepares dataloaders
+
+        Args:
+            world_size (int): Total number of distributed processes.
+        """
+        state = State[konfai_state()]
+        if state != State.RESUME and os.path.exists(checkpoints_directory() + self.name + "/"):
             if os.environ["KONFAI_OVERWRITE"] != "True":
-                accept = input("The model {} already exists ! Do you want to overwrite it (yes,no) : "
+                accept = input(f"The model {self.name} already exists ! Do you want to overwrite it (yes,no) : ")
                 if accept != "yes":
                     return
-            for directory_path in [
-
-
-
+            for directory_path in [
+                statistics_directory(),
+                models_directory(),
+                checkpoints_directory(),
+                setups_directory(),
+            ]:
+                if os.path.exists(directory_path + self.name + "/"):
+                    shutil.rmtree(directory_path + self.name + "/")
+
         state_dict = {}
         if state != State.TRAIN:
            state_dict = self._load()
 
-        self.model.init(self.autocast, state, self.dataset.
-        self.model.
-        self.model._compute_channels_trace(
+        self.model.init(self.autocast, state, self.dataset.get_groups_dest())
+        self.model.init_outputs_group()
+        self.model._compute_channels_trace(
+            self.model,
+            self.model.in_channels,
+            self.gradient_checkpoints,
+            self.gpu_checkpoints,
+        )
         self.model.load(state_dict, init=True, ema=False)
         if self.ema_decay > 0:
-            self.
+            self.model_ema = AveragedModel(self.model, avg_fn=self._avg_fn)
             if state_dict is not None:
-                self.
-
-            if not os.path.exists(
-                os.makedirs(
-        shutil.copyfile(self.config_namefile_src+".yml", self.config_namefile)
-
-        self.dataloader = self.dataset.
-
-    def run_process(
-
-
-
-
-
-
+                self.model_ema.module.load(state_dict, init=False, ema=True)
+
+        if not os.path.exists(setups_directory() + self.name + "/"):
+            os.makedirs(setups_directory() + self.name + "/")
+        shutil.copyfile(self.config_namefile_src + ".yml", self.config_namefile)
+
+        self.dataloader = self.dataset.get_data(world_size // self.size)
+
+    def run_process(
+        self,
+        world_size: int,
+        global_rank: int,
+        local_rank: int,
+        dataloaders: list[DataLoader],
+    ):
+        """
+        Launches the actual training process via internal `_Trainer` class.
+        Wraps model with DDP or CPU fallback, attaches EMA, and starts training.
+
+        Args:
+            world_size (int): Total number of distributed processes.
+            global_rank (int): Global rank of the current process.
+            local_rank (int): Local rank within the node.
+            dataloaders (list[DataLoader]): Training and validation dataloaders.
+        """
+        model = Network.to(self.model, local_rank * self.size)
+        model = DDP(model, static_graph=True) if torch.cuda.is_available() else CPUModel(model)
+        if self.model_ema is not None:
+            self.model_ema.module = Network.to(self.model_ema.module, local_rank)
+        with _Trainer(
+            world_size,
+            global_rank,
+            local_rank,
+            self.size,
+            self.name,
+            self.early_stopping,
+            self.data_log,
+            self.save_checkpoint_mode,
+            self.epochs,
+            self.epoch,
+            self.autocast,
+            self.it_validation,
+            self.it,
+            model,
+            self.model_ema,
+            *dataloaders,
+        ) as t:
+            t.run()