konfai 1.1.8__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of konfai might be problematic. Click here for more details.
- konfai/__init__.py +59 -14
- konfai/data/augmentation.py +457 -286
- konfai/data/data_manager.py +533 -316
- konfai/data/patching.py +300 -183
- konfai/data/transform.py +408 -275
- konfai/evaluator.py +325 -68
- konfai/main.py +71 -22
- konfai/metric/measure.py +360 -244
- konfai/metric/schedulers.py +24 -13
- konfai/models/classification/convNeXt.py +187 -81
- konfai/models/classification/resnet.py +272 -58
- konfai/models/generation/cStyleGan.py +233 -59
- konfai/models/generation/ddpm.py +348 -121
- konfai/models/generation/diffusionGan.py +757 -358
- konfai/models/generation/gan.py +177 -53
- konfai/models/generation/vae.py +140 -40
- konfai/models/registration/registration.py +135 -52
- konfai/models/representation/representation.py +57 -23
- konfai/models/segmentation/NestedUNet.py +339 -68
- konfai/models/segmentation/UNet.py +140 -30
- konfai/network/blocks.py +331 -187
- konfai/network/network.py +795 -427
- konfai/predictor.py +644 -238
- konfai/trainer.py +509 -222
- konfai/utils/ITK.py +191 -106
- konfai/utils/config.py +152 -95
- konfai/utils/dataset.py +326 -455
- konfai/utils/utils.py +497 -249
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/METADATA +1 -3
- konfai-1.2.0.dist-info/RECORD +38 -0
- konfai/utils/registration.py +0 -199
- konfai-1.1.8.dist-info/RECORD +0 -39
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/WHEEL +0 -0
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/entry_points.txt +0 -0
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/licenses/LICENSE +0 -0
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/top_level.txt +0 -0
konfai/trainer.py
CHANGED
|
@@ -1,62 +1,87 @@
|
|
|
1
|
-
import torch
|
|
2
|
-
from torch.utils.data import DataLoader
|
|
3
|
-
import tqdm
|
|
4
|
-
import numpy as np
|
|
5
1
|
import os
|
|
6
|
-
from typing import Union
|
|
7
|
-
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
8
|
-
|
|
9
2
|
import shutil
|
|
10
|
-
|
|
11
|
-
|
|
3
|
+
|
|
4
|
+
import torch
|
|
12
5
|
import torch.distributed as dist
|
|
6
|
+
import tqdm
|
|
7
|
+
from torch.nn.parallel import DistributedDataParallel as DDP # noqa: N817
|
|
8
|
+
from torch.optim.swa_utils import AveragedModel
|
|
9
|
+
from torch.utils.data import DataLoader
|
|
10
|
+
from torch.utils.tensorboard.writer import SummaryWriter
|
|
13
11
|
|
|
14
|
-
from konfai import
|
|
12
|
+
from konfai import (
|
|
13
|
+
checkpoints_directory,
|
|
14
|
+
config_file,
|
|
15
|
+
current_date,
|
|
16
|
+
konfai_state,
|
|
17
|
+
path_to_models,
|
|
18
|
+
setups_directory,
|
|
19
|
+
statistics_directory,
|
|
20
|
+
)
|
|
15
21
|
from konfai.data.data_manager import DataTrain
|
|
22
|
+
from konfai.network.network import CPUModel, ModelLoader, NetState, Network
|
|
16
23
|
from konfai.utils.config import config
|
|
17
|
-
from konfai.utils.utils import
|
|
18
|
-
|
|
24
|
+
from konfai.utils.utils import DataLog, DistributedObject, State, TrainerError, description
|
|
25
|
+
|
|
19
26
|
|
|
20
27
|
class EarlyStoppingBase:
    """No-op early-stopping policy.

    It never requests a stop, and it scores a set of logged values by
    summing all of them. Subclasses override the methods below to implement
    a real stopping rule.
    """

    def __init__(self):
        pass

    def is_stopped(self) -> bool:
        """Return ``False``: this policy never stops training."""
        return False

    def get_score(self, values: dict[str, float]):
        """Return the sum of every value in *values* (``0`` when it is empty)."""
        # sum() consumes the dict view directly; no intermediate list is needed.
        return sum(values.values())

    def __call__(self, current_score: float) -> bool:
        """Ignore *current_score* and return ``False`` (never stop)."""
        return False
|
|
33
41
|
|
|
42
|
+
|
|
34
43
|
class EarlyStopping(EarlyStoppingBase):
|
|
44
|
+
"""
|
|
45
|
+
Implements early stopping logic with configurable patience and monitored metrics.
|
|
46
|
+
|
|
47
|
+
Attributes:
|
|
48
|
+
monitor (list[str]): Metrics to monitor.
|
|
49
|
+
patience (int): Number of checks with no improvement before stopping.
|
|
50
|
+
min_delta (float): Minimum change to qualify as improvement.
|
|
51
|
+
mode (str): "min" or "max" depending on optimization direction.
|
|
52
|
+
"""
|
|
35
53
|
|
|
36
54
|
@config("EarlyStopping")
|
|
37
|
-
def __init__(
|
|
55
|
+
def __init__(
|
|
56
|
+
self,
|
|
57
|
+
monitor: list[str] | None = None,
|
|
58
|
+
patience: int = 10,
|
|
59
|
+
min_delta: float = 0.0,
|
|
60
|
+
mode: str = "min",
|
|
61
|
+
):
|
|
38
62
|
super().__init__()
|
|
39
63
|
self.monitor = [] if monitor is None else monitor
|
|
40
64
|
self.patience = patience
|
|
41
65
|
self.min_delta = min_delta
|
|
42
66
|
self.mode = mode
|
|
43
67
|
self.counter = 0
|
|
44
|
-
self.best_score = None
|
|
68
|
+
self.best_score: float | None = None
|
|
45
69
|
self.early_stop = False
|
|
46
70
|
|
|
47
|
-
def
|
|
71
|
+
def is_stopped(self) -> bool:
|
|
48
72
|
return self.early_stop
|
|
49
73
|
|
|
50
|
-
def
|
|
74
|
+
def get_score(self, values: dict[str, float]):
    """Reduce the logged values to the single score watched by early stopping.

    Args:
        values: Mapping from a logged loss/metric name to its value.

    Returns:
        The sum of the values whose names appear in ``self.monitor``. When
        ``self.monitor`` is empty, the parent class's score is returned.

    Raises:
        TrainerError: If a name listed in ``self.monitor`` is not a key of
            *values*.
    """
    if len(self.monitor) == 0:
        return super().get_score(values)
    for v in self.monitor:
        if v not in values:
            # Report the missing metric by name and list the keys that were
            # actually logged, so the configuration error can be fixed from
            # the message alone.
            raise TrainerError(
                f"Metric '{v}' specified in EarlyStopping.monitor not found in logged values. ",
                f"Available keys: {list(values)}. Please check your configuration.",
            )
    return sum(i for v, i in values.items() if v in self.monitor)
|
|
59
|
-
|
|
84
|
+
|
|
60
85
|
def __call__(self, current_score: float) -> bool:
|
|
61
86
|
if self.best_score is None:
|
|
62
87
|
self.best_score = current_score
|
|
@@ -80,14 +105,47 @@ class EarlyStopping(EarlyStoppingBase):
|
|
|
80
105
|
|
|
81
106
|
return self.early_stop
|
|
82
107
|
|
|
83
|
-
class _Trainer():
|
|
84
108
|
|
|
85
|
-
|
|
86
|
-
|
|
109
|
+
class _Trainer:
|
|
110
|
+
"""
|
|
111
|
+
Internal class for managing the training loop in a distributed or standalone setting.
|
|
112
|
+
|
|
113
|
+
Handles:
|
|
114
|
+
- Epoch iteration with training and optional validation
|
|
115
|
+
- Mixed precision support (autocast)
|
|
116
|
+
- Exponential Moving Average (EMA) model tracking
|
|
117
|
+
- Early stopping
|
|
118
|
+
- Logging to TensorBoard
|
|
119
|
+
- Model checkpoint saving and selection (ALL or BEST)
|
|
120
|
+
|
|
121
|
+
This class is intended to be used via a context manager
|
|
122
|
+
(`with _Trainer(...) as trainer:`) inside the public `Trainer` class.
|
|
123
|
+
"""
|
|
124
|
+
|
|
125
|
+
def __init__(
|
|
126
|
+
self,
|
|
127
|
+
world_size: int,
|
|
128
|
+
global_rank: int,
|
|
129
|
+
local_rank: int,
|
|
130
|
+
size: int,
|
|
131
|
+
train_name: str,
|
|
132
|
+
early_stopping: EarlyStopping | None,
|
|
133
|
+
data_log: list[str] | None,
|
|
134
|
+
save_checkpoint_mode: str,
|
|
135
|
+
epochs: int,
|
|
136
|
+
epoch: int,
|
|
137
|
+
autocast: bool,
|
|
138
|
+
it_validation: int | None,
|
|
139
|
+
it: int,
|
|
140
|
+
model: DDP | CPUModel,
|
|
141
|
+
model_ema: AveragedModel,
|
|
142
|
+
dataloader_training: DataLoader,
|
|
143
|
+
dataloader_validation: DataLoader | None = None,
|
|
144
|
+
) -> None:
|
|
145
|
+
self.world_size = world_size
|
|
87
146
|
self.global_rank = global_rank
|
|
88
147
|
self.local_rank = local_rank
|
|
89
148
|
self.size = size
|
|
90
|
-
|
|
91
149
|
self.save_checkpoint_mode = save_checkpoint_mode
|
|
92
150
|
self.train_name = train_name
|
|
93
151
|
self.epochs = epochs
|
|
@@ -96,58 +154,97 @@ class _Trainer():
|
|
|
96
154
|
self.dataloader_training = dataloader_training
|
|
97
155
|
self.dataloader_validation = dataloader_validation
|
|
98
156
|
self.autocast = autocast
|
|
99
|
-
self.
|
|
100
|
-
self.early_stopping = EarlyStoppingBase() if early_stopping is None else early_stopping
|
|
101
|
-
|
|
102
|
-
self.it_validation = it_validation
|
|
103
|
-
if self.it_validation is None:
|
|
104
|
-
self.it_validation = len(dataloader_training)
|
|
157
|
+
self.model_ema = model_ema
|
|
158
|
+
self.early_stopping = EarlyStoppingBase() if early_stopping is None else early_stopping
|
|
159
|
+
|
|
160
|
+
self.it_validation = len(dataloader_training) if it_validation is None else it_validation
|
|
105
161
|
self.it = it
|
|
106
|
-
self.tb = SummaryWriter(log_dir
|
|
107
|
-
self.data_log
|
|
162
|
+
self.tb = SummaryWriter(log_dir=statistics_directory() + self.train_name + "/")
|
|
163
|
+
self.data_log: dict[str, tuple[DataLog, int]] = {}
|
|
108
164
|
if data_log is not None:
|
|
109
165
|
for data in data_log:
|
|
110
|
-
self.data_log[data.split("/")[0].replace(":", ".")] = (
|
|
111
|
-
|
|
166
|
+
self.data_log[data.split("/")[0].replace(":", ".")] = (
|
|
167
|
+
DataLog[data.split("/")[1]],
|
|
168
|
+
int(data.split("/")[2]),
|
|
169
|
+
)
|
|
170
|
+
|
|
112
171
|
def __enter__(self):
|
|
113
172
|
return self
|
|
114
|
-
|
|
115
|
-
def __exit__(self,
|
|
173
|
+
|
|
174
|
+
def __exit__(self, exc_type, value, traceback):
|
|
175
|
+
"""Closes the SummaryWriter if used."""
|
|
116
176
|
if self.tb is not None:
|
|
117
177
|
self.tb.close()
|
|
118
178
|
|
|
119
179
|
def run(self) -> None:
|
|
180
|
+
"""
|
|
181
|
+
Launches the training loop, performing one epoch at a time.
|
|
182
|
+
Triggers early stopping and resets data augmentations between epochs.
|
|
183
|
+
"""
|
|
120
184
|
self.dataloader_training.dataset.load("Train")
|
|
121
185
|
if self.dataloader_validation is not None:
|
|
122
186
|
self.dataloader_validation.dataset.load("Validation")
|
|
123
|
-
|
|
187
|
+
if State[konfai_state()] != State.TRAIN:
|
|
188
|
+
self._validate()
|
|
189
|
+
|
|
190
|
+
with tqdm.tqdm(
|
|
191
|
+
iterable=range(self.epoch, self.epochs),
|
|
192
|
+
leave=False,
|
|
193
|
+
total=self.epochs,
|
|
194
|
+
initial=self.epoch,
|
|
195
|
+
desc="Progress",
|
|
196
|
+
) as epoch_tqdm:
|
|
124
197
|
for self.epoch in epoch_tqdm:
|
|
125
198
|
self.train()
|
|
126
|
-
if self.early_stopping.
|
|
199
|
+
if self.early_stopping.is_stopped():
|
|
127
200
|
break
|
|
128
|
-
self.dataloader_training.dataset.
|
|
129
|
-
|
|
130
|
-
def getInput(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int, str, bool]]) -> dict[tuple[str, bool], torch.Tensor]:
|
|
131
|
-
return {(k, v[5][0].item()) : v[0] for k, v in data_dict.items()}
|
|
201
|
+
self.dataloader_training.dataset.reset_augmentation("Train")
|
|
132
202
|
|
|
133
|
-
def
|
|
134
|
-
self
|
|
135
|
-
|
|
136
|
-
|
|
137
|
-
|
|
138
|
-
|
|
203
|
+
def get_input(
|
|
204
|
+
self,
|
|
205
|
+
data_dict: dict[str, tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, list[str], torch.Tensor]],
|
|
206
|
+
) -> dict[tuple[str, bool], torch.Tensor]:
|
|
207
|
+
"""
|
|
208
|
+
Extracts input tensors from the data dict for model input.
|
|
139
209
|
|
|
140
|
-
|
|
210
|
+
Args:
|
|
211
|
+
data_dict (dict): Dictionary with data items structured as tuples.
|
|
141
212
|
|
|
142
|
-
|
|
213
|
+
Returns:
|
|
214
|
+
dict: Mapping from (group, bool_flag) to input tensors.
|
|
215
|
+
"""
|
|
216
|
+
return {(k, v[5][0].item()): v[0] for k, v in data_dict.items()}
|
|
217
|
+
|
|
218
|
+
def train(self) -> None:
|
|
219
|
+
"""
|
|
220
|
+
Performs a full training epoch with support for:
|
|
221
|
+
- mixed precision
|
|
222
|
+
- DDP / CPU training
|
|
223
|
+
- EMA updates
|
|
224
|
+
- loss logging and checkpoint saving
|
|
225
|
+
- validation at configurable iteration interval
|
|
226
|
+
"""
|
|
227
|
+
self.model.train()
|
|
228
|
+
self.model.module.set_state(NetState.TRAIN)
|
|
229
|
+
if self.model_ema is not None:
|
|
230
|
+
self.model_ema.eval()
|
|
231
|
+
self.model_ema.module.set_state(NetState.TRAIN)
|
|
232
|
+
|
|
233
|
+
with tqdm.tqdm(
|
|
234
|
+
iterable=enumerate(self.dataloader_training),
|
|
235
|
+
desc=f"Training : {description(self.model, self.model_ema)}",
|
|
236
|
+
total=len(self.dataloader_training),
|
|
237
|
+
leave=False,
|
|
238
|
+
ncols=0,
|
|
239
|
+
) as batch_iter:
|
|
143
240
|
for _, data_dict in batch_iter:
|
|
144
|
-
with torch.amp.autocast(
|
|
145
|
-
|
|
146
|
-
self.model(
|
|
241
|
+
with torch.amp.autocast("cuda", enabled=self.autocast):
|
|
242
|
+
input_data_dict = self.get_input(data_dict)
|
|
243
|
+
self.model(input_data_dict)
|
|
147
244
|
self.model.module.backward(self.model.module)
|
|
148
|
-
if self.
|
|
149
|
-
self.
|
|
150
|
-
self.
|
|
245
|
+
if self.model_ema is not None:
|
|
246
|
+
self.model_ema.update_parameters(self.model)
|
|
247
|
+
self.model_ema.module(input_data_dict)
|
|
151
248
|
self.it += 1
|
|
152
249
|
if (self.it) % self.it_validation == 0:
|
|
153
250
|
loss = self._train_log(data_dict)
|
|
@@ -155,76 +252,109 @@ class _Trainer():
|
|
|
155
252
|
if self.dataloader_validation is not None:
|
|
156
253
|
loss = self._validate()
|
|
157
254
|
self.model.module.update_lr()
|
|
158
|
-
score = self.early_stopping.
|
|
255
|
+
score = self.early_stopping.get_score(loss)
|
|
159
256
|
self.checkpoint_save(score)
|
|
160
257
|
if self.early_stopping(score):
|
|
161
258
|
break
|
|
162
|
-
|
|
163
259
|
|
|
164
|
-
batch_iter.set_description(
|
|
165
|
-
|
|
166
|
-
|
|
260
|
+
batch_iter.set_description(f"Training : {description(self.model, self.model_ema)}")
|
|
261
|
+
|
|
167
262
|
@torch.no_grad()
|
|
168
263
|
def _validate(self) -> float:
|
|
264
|
+
"""
|
|
265
|
+
Executes the validation phase, evaluates loss and metrics.
|
|
266
|
+
Updates model states and resets augmentation for validation set.
|
|
267
|
+
|
|
268
|
+
Returns:
|
|
269
|
+
float: Validation loss.
|
|
270
|
+
"""
|
|
271
|
+
if self.dataloader_validation is None:
|
|
272
|
+
return 0
|
|
169
273
|
self.model.eval()
|
|
170
|
-
self.model.module.
|
|
171
|
-
if self.
|
|
172
|
-
self.
|
|
274
|
+
self.model.module.set_state(NetState.PREDICTION)
|
|
275
|
+
if self.model_ema is not None:
|
|
276
|
+
self.model_ema.module.set_state(NetState.PREDICTION)
|
|
173
277
|
|
|
174
|
-
desc = lambda : "Validation : {}".format(description(self.model, self.modelEMA))
|
|
175
278
|
data_dict = None
|
|
176
|
-
with tqdm.tqdm(
|
|
279
|
+
with tqdm.tqdm(
|
|
280
|
+
iterable=enumerate(self.dataloader_validation),
|
|
281
|
+
desc=f"Validation : {description(self.model, self.model_ema)}",
|
|
282
|
+
total=len(self.dataloader_validation),
|
|
283
|
+
leave=False,
|
|
284
|
+
ncols=0,
|
|
285
|
+
) as batch_iter:
|
|
177
286
|
for _, data_dict in batch_iter:
|
|
178
|
-
|
|
179
|
-
self.model(
|
|
180
|
-
if self.
|
|
181
|
-
self.
|
|
287
|
+
input_data_dict = self.get_input(data_dict)
|
|
288
|
+
self.model(input_data_dict)
|
|
289
|
+
if self.model_ema is not None:
|
|
290
|
+
self.model_ema.module(input_data_dict)
|
|
182
291
|
|
|
183
|
-
batch_iter.set_description(
|
|
184
|
-
self.dataloader_validation.dataset.
|
|
292
|
+
batch_iter.set_description(f"Validation : {description(self.model, self.model_ema)}")
|
|
293
|
+
self.dataloader_validation.dataset.reset_augmentation("Validation")
|
|
185
294
|
dist.barrier()
|
|
186
295
|
self.model.train()
|
|
187
|
-
self.model.module.
|
|
188
|
-
if self.
|
|
189
|
-
self.
|
|
296
|
+
self.model.module.set_state(NetState.TRAIN)
|
|
297
|
+
if self.model_ema is not None:
|
|
298
|
+
self.model_ema.module.set_state(NetState.TRAIN)
|
|
190
299
|
return self._validation_log(data_dict)
|
|
191
300
|
|
|
192
301
|
def checkpoint_save(self, loss: float) -> None:
|
|
302
|
+
"""
|
|
303
|
+
Saves model and optimizer states. Keeps either all checkpoints or only the best one.
|
|
304
|
+
|
|
305
|
+
Args:
|
|
306
|
+
loss (float): Current loss used for best checkpoint selection.
|
|
307
|
+
"""
|
|
193
308
|
if self.global_rank != 0:
|
|
194
309
|
return
|
|
195
310
|
|
|
196
|
-
path =
|
|
311
|
+
path = checkpoints_directory() + self.train_name + "/"
|
|
197
312
|
os.makedirs(path, exist_ok=True)
|
|
198
313
|
|
|
199
|
-
name =
|
|
314
|
+
name = current_date() + ".pt"
|
|
200
315
|
save_path = os.path.join(path, name)
|
|
201
316
|
|
|
202
317
|
save_dict = {
|
|
203
|
-
|
|
204
|
-
|
|
205
|
-
|
|
206
|
-
|
|
318
|
+
"epoch": self.epoch,
|
|
319
|
+
"it": self.it,
|
|
320
|
+
"loss": loss,
|
|
321
|
+
"Model": self.model.module.state_dict(),
|
|
207
322
|
}
|
|
208
323
|
|
|
209
|
-
if self.
|
|
210
|
-
save_dict["Model_EMA"] = self.
|
|
211
|
-
|
|
212
|
-
save_dict.update(
|
|
213
|
-
|
|
214
|
-
|
|
215
|
-
|
|
324
|
+
if self.model_ema is not None:
|
|
325
|
+
save_dict["Model_EMA"] = self.model_ema.module.state_dict()
|
|
326
|
+
|
|
327
|
+
save_dict.update(
|
|
328
|
+
{
|
|
329
|
+
f"{name}_optimizer_state_dict": network.optimizer.state_dict()
|
|
330
|
+
for name, network in self.model.module.get_networks().items()
|
|
331
|
+
if network.optimizer is not None
|
|
332
|
+
}
|
|
333
|
+
)
|
|
334
|
+
save_dict.update(
|
|
335
|
+
{
|
|
336
|
+
f"{name}_it": network._it
|
|
337
|
+
for name, network in self.model.module.get_networks().items()
|
|
338
|
+
if network.optimizer is not None
|
|
339
|
+
}
|
|
340
|
+
)
|
|
341
|
+
save_dict.update(
|
|
342
|
+
{
|
|
343
|
+
f"{name}_nb_lr_update": network._nb_lr_update
|
|
344
|
+
for name, network in self.model.module.get_networks().items()
|
|
345
|
+
if network.optimizer is not None
|
|
346
|
+
}
|
|
347
|
+
)
|
|
348
|
+
|
|
216
349
|
torch.save(save_dict, save_path)
|
|
217
350
|
|
|
218
351
|
if self.save_checkpoint_mode == "BEST":
|
|
219
|
-
all_checkpoints = sorted([
|
|
220
|
-
os.path.join(path, f)
|
|
221
|
-
for f in os.listdir(path) if f.endswith(".pt")
|
|
222
|
-
])
|
|
352
|
+
all_checkpoints = sorted([os.path.join(path, f) for f in os.listdir(path) if f.endswith(".pt")])
|
|
223
353
|
best_ckpt = None
|
|
224
|
-
best_loss = float(
|
|
354
|
+
best_loss = float("inf")
|
|
225
355
|
|
|
226
356
|
for f in all_checkpoints:
|
|
227
|
-
d = torch.load(f, weights_only=
|
|
357
|
+
d = torch.load(f, weights_only=True)
|
|
228
358
|
if d.get("loss", float("inf")) < best_loss:
|
|
229
359
|
best_loss = d["loss"]
|
|
230
360
|
best_ckpt = f
|
|
@@ -234,79 +364,165 @@ class _Trainer():
|
|
|
234
364
|
os.remove(f)
|
|
235
365
|
|
|
236
366
|
@torch.no_grad()
|
|
237
|
-
def _log(
|
|
238
|
-
|
|
239
|
-
|
|
240
|
-
|
|
241
|
-
|
|
242
|
-
|
|
243
|
-
|
|
367
|
+
def _log(
|
|
368
|
+
self,
|
|
369
|
+
type_log: str,
|
|
370
|
+
data_dict: dict[str, tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, list[str], torch.Tensor]],
|
|
371
|
+
) -> dict[str, float] | None:
|
|
372
|
+
"""
|
|
373
|
+
Logs losses, metrics and optionally images to TensorBoard.
|
|
374
|
+
|
|
375
|
+
Args:
|
|
376
|
+
type_log (str): "Training" or "Validation".
|
|
377
|
+
data_dict (dict): Dictionary of data from current batch.
|
|
378
|
+
|
|
379
|
+
Returns:
|
|
380
|
+
dict[str, float] | None: Dictionary of aggregated losses and metrics if rank == 0.
|
|
381
|
+
"""
|
|
382
|
+
models: dict[str, Network] = {"": self.model.module}
|
|
383
|
+
if self.model_ema is not None:
|
|
384
|
+
models["_EMA"] = self.model_ema.module
|
|
385
|
+
|
|
386
|
+
measures = DistributedObject.get_measure(
|
|
387
|
+
self.world_size,
|
|
388
|
+
self.global_rank,
|
|
389
|
+
self.local_rank * self.size + self.size - 1,
|
|
390
|
+
models,
|
|
391
|
+
(
|
|
392
|
+
self.it_validation
|
|
393
|
+
if type_log == "Training" or self.dataloader_validation is None
|
|
394
|
+
else len(self.dataloader_validation)
|
|
395
|
+
),
|
|
396
|
+
)
|
|
397
|
+
|
|
244
398
|
if self.global_rank == 0:
|
|
245
399
|
images_log = []
|
|
246
|
-
if len(self.data_log):
|
|
400
|
+
if len(self.data_log):
|
|
247
401
|
for name, data_type in self.data_log.items():
|
|
248
402
|
if name in data_dict:
|
|
249
|
-
data_type[0](
|
|
403
|
+
data_type[0](
|
|
404
|
+
self.tb,
|
|
405
|
+
f"{type_log}/{name}",
|
|
406
|
+
data_dict[name][0][: self.data_log[name][1]].detach().cpu().numpy(),
|
|
407
|
+
self.it,
|
|
408
|
+
)
|
|
250
409
|
else:
|
|
251
410
|
images_log.append(name.replace(":", "."))
|
|
252
|
-
|
|
411
|
+
|
|
253
412
|
for label, model in models.items():
|
|
254
|
-
for name, network in model.
|
|
413
|
+
for name, network in model.get_networks().items():
|
|
255
414
|
if network.measure is not None:
|
|
256
|
-
self.tb.add_scalars(
|
|
257
|
-
|
|
415
|
+
self.tb.add_scalars(
|
|
416
|
+
f"{type_log}/{name}/Loss/{label}",
|
|
417
|
+
{k: v[1] for k, v in measures[f"{name}{label}"][0].items()},
|
|
418
|
+
self.it,
|
|
419
|
+
)
|
|
420
|
+
self.tb.add_scalars(
|
|
421
|
+
f"{type_log}/{name}/Loss_weight/{label}",
|
|
422
|
+
{k: v[0] for k, v in measures[f"{name}{label}"][0].items()},
|
|
423
|
+
self.it,
|
|
424
|
+
)
|
|
425
|
+
|
|
426
|
+
self.tb.add_scalars(
|
|
427
|
+
f"{type_log}/{name}/Metric/{label}",
|
|
428
|
+
{k: v[1] for k, v in measures[f"{name}{label}"][1].items()},
|
|
429
|
+
self.it,
|
|
430
|
+
)
|
|
431
|
+
self.tb.add_scalars(
|
|
432
|
+
f"{type_log}/{name}/Metric_weight/{label}",
|
|
433
|
+
{k: v[0] for k, v in measures[f"{name}{label}"][1].items()},
|
|
434
|
+
self.it,
|
|
435
|
+
)
|
|
258
436
|
|
|
259
|
-
self.tb.add_scalars("{}/{}/Metric/{}".format(type_log, name, label), {k : v[1] for k, v in measures["{}{}".format(name, label)][1].items()}, self.it)
|
|
260
|
-
self.tb.add_scalars("{}/{}/Metric_weight/{}".format(type_log, name, label), {k : v[0] for k, v in measures["{}{}".format(name, label)][1].items()}, self.it)
|
|
261
|
-
|
|
262
437
|
if len(images_log):
|
|
263
|
-
for name, layer, _ in model.get_layers(
|
|
264
|
-
|
|
265
|
-
|
|
438
|
+
for name, layer, _ in model.get_layers(
|
|
439
|
+
[v.to(0) for k, v in self.get_input(data_dict).items() if k[1]],
|
|
440
|
+
images_log,
|
|
441
|
+
):
|
|
442
|
+
self.data_log[name][0](
|
|
443
|
+
self.tb,
|
|
444
|
+
f"{type_log}/{name}{label}",
|
|
445
|
+
layer[: self.data_log[name][1]].detach().cpu().numpy(),
|
|
446
|
+
self.it,
|
|
447
|
+
)
|
|
448
|
+
|
|
266
449
|
if type_log == "Training":
|
|
267
|
-
for name, network in self.model.module.
|
|
450
|
+
for name, network in self.model.module.get_networks().items():
|
|
268
451
|
if network.optimizer is not None:
|
|
269
|
-
self.tb.add_scalar(
|
|
270
|
-
|
|
452
|
+
self.tb.add_scalar(
|
|
453
|
+
f"{type_log}/{name}/Learning Rate",
|
|
454
|
+
network.optimizer.param_groups[0]["lr"],
|
|
455
|
+
self.it,
|
|
456
|
+
)
|
|
457
|
+
|
|
271
458
|
if self.global_rank == 0:
|
|
272
459
|
loss = {}
|
|
273
|
-
for name, network in self.model.module.
|
|
460
|
+
for name, network in self.model.module.get_networks().items():
|
|
274
461
|
if network.measure is not None:
|
|
275
|
-
loss.update({k
|
|
276
|
-
loss.update({k
|
|
462
|
+
loss.update({k: v[1] for k, v in measures[f"{name}{label}"][0].items()})
|
|
463
|
+
loss.update({k: v[1] for k, v in measures[f"{name}{label}"][1].items()})
|
|
277
464
|
return loss
|
|
278
465
|
return None
|
|
279
|
-
|
|
466
|
+
|
|
280
467
|
@torch.no_grad()
|
|
281
|
-
def _train_log(self, data_dict
|
|
468
|
+
def _train_log(self, data_dict: dict[str, tuple[torch.Tensor, int, int, int]]) -> dict[str, float]:
|
|
469
|
+
"""Wrapper for _log during training."""
|
|
282
470
|
return self._log("Training", data_dict)
|
|
283
471
|
|
|
284
472
|
@torch.no_grad()
|
|
285
|
-
def _validation_log(self, data_dict
|
|
473
|
+
def _validation_log(self, data_dict: dict[str, tuple[torch.Tensor, int, int, int]]) -> dict[str, float]:
|
|
474
|
+
"""Wrapper for _log during validation."""
|
|
286
475
|
return self._log("Validation", data_dict)
|
|
287
476
|
|
|
288
|
-
|
|
477
|
+
|
|
289
478
|
class Trainer(DistributedObject):
|
|
479
|
+
"""
|
|
480
|
+
Public API for training a model using the KonfAI framework.
|
|
481
|
+
Wraps setup, checkpointing, resuming, logging, and launching distributed _Trainer.
|
|
482
|
+
|
|
483
|
+
Main responsibilities:
|
|
484
|
+
- Initialization from config (via @config)
|
|
485
|
+
- Model and EMA setup
|
|
486
|
+
- Checkpoint loading and saving
|
|
487
|
+
- Distributed setup and launch
|
|
488
|
+
|
|
489
|
+
Args:
|
|
490
|
+
model (ModelLoader): Loader for model architecture.
|
|
491
|
+
dataset (DataTrain): Training/validation dataset.
|
|
492
|
+
train_name (str): Training session name.
|
|
493
|
+
manual_seed (int | None): Random seed.
|
|
494
|
+
epochs (int): Number of epochs to run.
|
|
495
|
+
it_validation (int | None): Validation interval.
|
|
496
|
+
autocast (bool): Enable AMP training.
|
|
497
|
+
gradient_checkpoints (list[str] | None): Modules to use gradient checkpointing on.
|
|
498
|
+
gpu_checkpoints (list[str] | None): Modules to pin on specific GPUs.
|
|
499
|
+
ema_decay (float): EMA decay factor.
|
|
500
|
+
data_log (list[str] | None): Logging instructions.
|
|
501
|
+
early_stopping (EarlyStopping | None): Optional early stopping config.
|
|
502
|
+
save_checkpoint_mode (str): Either "BEST" or "ALL".
|
|
503
|
+
"""
|
|
290
504
|
|
|
291
505
|
@config("Trainer")
|
|
292
|
-
def __init__(
|
|
293
|
-
|
|
294
|
-
|
|
295
|
-
|
|
296
|
-
|
|
297
|
-
|
|
298
|
-
|
|
299
|
-
|
|
300
|
-
|
|
301
|
-
|
|
302
|
-
|
|
303
|
-
|
|
304
|
-
|
|
305
|
-
|
|
506
|
+
def __init__(
|
|
507
|
+
self,
|
|
508
|
+
model: ModelLoader = ModelLoader(),
|
|
509
|
+
dataset: DataTrain = DataTrain(),
|
|
510
|
+
train_name: str = "default:TRAIN_01",
|
|
511
|
+
manual_seed: int | None = None,
|
|
512
|
+
epochs: int = 100,
|
|
513
|
+
it_validation: int | None = None,
|
|
514
|
+
autocast: bool = False,
|
|
515
|
+
gradient_checkpoints: list[str] | None = None,
|
|
516
|
+
gpu_checkpoints: list[str] | None = None,
|
|
517
|
+
ema_decay: float = 0,
|
|
518
|
+
data_log: list[str] | None = None,
|
|
519
|
+
early_stopping: EarlyStopping | None = None,
|
|
520
|
+
save_checkpoint_mode: str = "BEST",
|
|
521
|
+
) -> None:
|
|
306
522
|
if os.environ["KONFAI_CONFIG_MODE"] != "Done":
|
|
307
523
|
exit(0)
|
|
308
524
|
super().__init__(train_name)
|
|
309
|
-
self.manual_seed = manual_seed
|
|
525
|
+
self.manual_seed = manual_seed
|
|
310
526
|
self.dataset = dataset
|
|
311
527
|
self.autocast = autocast
|
|
312
528
|
self.epochs = epochs
|
|
@@ -314,118 +530,189 @@ class Trainer(DistributedObject):
|
|
|
314
530
|
self.early_stopping = early_stopping
|
|
315
531
|
self.it = 0
|
|
316
532
|
self.it_validation = it_validation
|
|
317
|
-
self.model = model.
|
|
533
|
+
self.model = model.get_model(train=True)
|
|
318
534
|
self.ema_decay = ema_decay
|
|
319
|
-
self.
|
|
535
|
+
self.model_ema: torch.optim.swa_utils.AveragedModel | None = None
|
|
320
536
|
self.data_log = data_log
|
|
321
537
|
|
|
322
538
|
modules = []
|
|
323
|
-
for i,_ in self.model.named_modules():
|
|
539
|
+
for i, _ in self.model.named_modules():
|
|
324
540
|
modules.append(i)
|
|
325
|
-
|
|
326
|
-
|
|
327
|
-
|
|
328
|
-
|
|
329
|
-
|
|
330
|
-
|
|
331
|
-
|
|
332
|
-
|
|
541
|
+
if self.data_log is not None:
|
|
542
|
+
for k in self.data_log:
|
|
543
|
+
tmp = k.split("/")[0].replace(":", ".")
|
|
544
|
+
if tmp not in self.dataset.get_groups_dest() and tmp not in modules:
|
|
545
|
+
raise TrainerError(
|
|
546
|
+
f"Invalid key '{tmp}' in `data_log`.",
|
|
547
|
+
f"This key is neither a destination group from the dataset ({self.dataset.get_groups_dest()})",
|
|
548
|
+
f"nor a valid module name in the model ({modules}).",
|
|
549
|
+
"Please check your `data_log` configuration,"
|
|
550
|
+
" it should reference either a model output or a dataset group.",
|
|
551
|
+
)
|
|
552
|
+
|
|
333
553
|
self.gradient_checkpoints = gradient_checkpoints
|
|
334
554
|
self.gpu_checkpoints = gpu_checkpoints
|
|
335
555
|
self.save_checkpoint_mode = save_checkpoint_mode
|
|
336
|
-
self.config_namefile_src =
|
|
337
|
-
self.config_namefile =
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
341
|
-
|
|
556
|
+
self.config_namefile_src = config_file().replace(".yml", "")
|
|
557
|
+
self.config_namefile = (
|
|
558
|
+
setups_directory() + self.name + "/" + self.config_namefile_src.split("/")[-1] + "_" + str(self.it) + ".yml"
|
|
559
|
+
)
|
|
560
|
+
self.size = len(self.gpu_checkpoints) + 1 if self.gpu_checkpoints else 1
|
|
561
|
+
|
|
562
|
+
def __exit__(self, exc_type, value, traceback):
|
|
563
|
+
"""Exit training context and trigger save of model/checkpoints."""
|
|
564
|
+
super().__exit__(exc_type, value, traceback)
|
|
342
565
|
self._save()
|
|
343
566
|
|
|
344
567
|
def _load(self) -> dict[str, dict[str, torch.Tensor]]:
    """Load a previously saved checkpoint from a URL or from local disk.

    The source is chosen from ``path_to_models()``:

    - a string starting with ``https://`` is downloaded with
      ``torch.hub.load_state_dict_from_url`` and wrapped in a one-entry dict;
    - any other non-empty string is used as the checkpoint file path;
    - an empty string selects the last entry (in sorted order) of the
      checkpoint directory of this training session.

    If the loaded dict contains ``"epoch"`` / ``"it"``, they are copied to
    ``self.epoch`` / ``self.it``.

    Returns:
        dict: State dictionary loaded from the checkpoint.

    Raises:
        Exception: If the checkpoint cannot be downloaded or found.
    """
    if path_to_models().startswith("https://"):
        # Split on the LAST ':' only: the scheme itself contains one, so a
        # plain split(":") would yield "https" as the URL.
        # NOTE(review): assumes the "<url>:<key>" format — confirm.
        url, _, key = path_to_models().rpartition(":")
        try:
            state_dict = {
                key: torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)
            }
        except Exception as err:
            # Keep the original failure attached as the cause.
            raise Exception(f"Model : {path_to_models()} does not exist !") from err
    else:
        if path_to_models() != "":
            path = ""
            name = path_to_models()
        else:
            path = checkpoints_directory() + self.name + "/"
            # A missing or empty directory means there is nothing to resume
            # from; report it instead of failing later on an unbound name.
            candidates = sorted(os.listdir(path)) if os.path.isdir(path) else []
            if not candidates:
                raise Exception(f"Model : {self.name} does not exist !")
            name = candidates[-1]

        if os.path.exists(path + name):
            state_dict = torch.load(path + name, weights_only=True, map_location="cpu")
        else:
            raise Exception(f"Model : {self.name} does not exist !")

    if "epoch" in state_dict:
        self.epoch = state_dict["epoch"]
    if "it" in state_dict:
        self.it = state_dict["it"]
    return state_dict
|
|
369
602
|
|
|
370
603
|
def _save(self) -> None:
|
|
371
|
-
|
|
372
|
-
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
380
|
-
|
|
381
|
-
|
|
382
|
-
|
|
383
|
-
|
|
384
|
-
|
|
385
|
-
|
|
386
|
-
|
|
387
|
-
|
|
388
|
-
|
|
389
|
-
|
|
390
|
-
os.rename(self.config_namefile, self.config_namefile.replace(".yml", "")+"_"+str(self.it)+".yml")
|
|
391
|
-
|
|
392
|
-
def _avg_fn(self, averaged_model_parameter, model_parameter, num_averaged):
|
|
393
|
-
return (1-self.ema_decay) * averaged_model_parameter + self.ema_decay * model_parameter
|
|
394
|
-
|
|
604
|
+
"""
|
|
605
|
+
Serializes model and EMA model (if any) in both state_dict and full model formats.
|
|
606
|
+
Also saves optimizer states and YAML config snapshot.
|
|
607
|
+
"""
|
|
608
|
+
if os.path.exists(self.config_namefile):
|
|
609
|
+
os.rename(
|
|
610
|
+
self.config_namefile,
|
|
611
|
+
self.config_namefile.replace(".yml", "") + "_" + str(self.it) + ".yml",
|
|
612
|
+
)
|
|
613
|
+
|
|
614
|
+
def _avg_fn(self, averaged_model_parameter: float, model_parameter, num_averaged):
|
|
615
|
+
"""
|
|
616
|
+
EMA update rule used by AveragedModel.
|
|
617
|
+
|
|
618
|
+
Returns:
|
|
619
|
+
torch.Tensor: Blended parameter using decay factor.
|
|
620
|
+
"""
|
|
621
|
+
return (1 - self.ema_decay) * averaged_model_parameter + self.ema_decay * model_parameter
|
|
622
|
+
|
|
395
623
|
def setup(self, world_size: int):
|
|
396
|
-
|
|
397
|
-
|
|
624
|
+
"""
|
|
625
|
+
Initializes the training environment:
|
|
626
|
+
- Clears previous outputs (unless resuming)
|
|
627
|
+
- Initializes model and EMA
|
|
628
|
+
- Loads checkpoint (if resuming)
|
|
629
|
+
- Prepares dataloaders
|
|
630
|
+
|
|
631
|
+
Args:
|
|
632
|
+
world_size (int): Total number of distributed processes.
|
|
633
|
+
"""
|
|
634
|
+
state = State[konfai_state()]
|
|
635
|
+
print(checkpoints_directory() + self.name + "/")
|
|
636
|
+
if state != State.RESUME and os.path.exists(checkpoints_directory() + self.name + "/"):
|
|
398
637
|
if os.environ["KONFAI_OVERWRITE"] != "True":
|
|
399
|
-
accept = input("The model {} already exists ! Do you want to overwrite it (yes,no) : "
|
|
638
|
+
accept = input(f"The model {self.name} already exists ! Do you want to overwrite it (yes,no) : ")
|
|
400
639
|
if accept != "yes":
|
|
401
640
|
return
|
|
402
|
-
for directory_path in [
|
|
403
|
-
|
|
404
|
-
|
|
405
|
-
|
|
641
|
+
for directory_path in [
|
|
642
|
+
statistics_directory(),
|
|
643
|
+
checkpoints_directory(),
|
|
644
|
+
setups_directory(),
|
|
645
|
+
]:
|
|
646
|
+
if os.path.exists(directory_path + self.name + "/"):
|
|
647
|
+
shutil.rmtree(directory_path + self.name + "/")
|
|
648
|
+
|
|
406
649
|
state_dict = {}
|
|
407
650
|
if state != State.TRAIN:
|
|
408
651
|
state_dict = self._load()
|
|
409
652
|
|
|
410
|
-
self.model.init(self.autocast, state, self.dataset.
|
|
411
|
-
self.model.
|
|
412
|
-
self.model._compute_channels_trace(
|
|
653
|
+
self.model.init(self.autocast, state, self.dataset.get_groups_dest())
|
|
654
|
+
self.model.init_outputs_group()
|
|
655
|
+
self.model._compute_channels_trace(
|
|
656
|
+
self.model,
|
|
657
|
+
self.model.in_channels,
|
|
658
|
+
self.gradient_checkpoints,
|
|
659
|
+
self.gpu_checkpoints,
|
|
660
|
+
)
|
|
413
661
|
self.model.load(state_dict, init=True, ema=False)
|
|
414
662
|
if self.ema_decay > 0:
|
|
415
|
-
self.
|
|
663
|
+
self.model_ema = AveragedModel(self.model, avg_fn=self._avg_fn)
|
|
416
664
|
if state_dict is not None:
|
|
417
|
-
self.
|
|
418
|
-
|
|
419
|
-
if not os.path.exists(
|
|
420
|
-
os.makedirs(
|
|
421
|
-
shutil.copyfile(self.config_namefile_src+".yml", self.config_namefile)
|
|
422
|
-
|
|
423
|
-
self.dataloader = self.dataset.
|
|
424
|
-
|
|
425
|
-
|
|
426
|
-
|
|
427
|
-
|
|
428
|
-
|
|
429
|
-
|
|
430
|
-
|
|
431
|
-
|
|
665
|
+
self.model_ema.module.load(state_dict, init=False, ema=True)
|
|
666
|
+
|
|
667
|
+
if not os.path.exists(setups_directory() + self.name + "/"):
|
|
668
|
+
os.makedirs(setups_directory() + self.name + "/")
|
|
669
|
+
shutil.copyfile(self.config_namefile_src + ".yml", self.config_namefile)
|
|
670
|
+
|
|
671
|
+
self.dataloader, train_names, validation_names = self.dataset.get_data(world_size // self.size)
|
|
672
|
+
with open(setups_directory() + self.name + "/Train_" + str(self.it) + ".txt", "w") as f:
|
|
673
|
+
for name in train_names:
|
|
674
|
+
f.write(name + "\n")
|
|
675
|
+
with open(setups_directory() + self.name + "/Validation_" + str(self.it) + ".txt", "w") as f:
|
|
676
|
+
for name in validation_names:
|
|
677
|
+
f.write(name + "\n")
|
|
678
|
+
|
|
679
|
+
def run_process(
|
|
680
|
+
self,
|
|
681
|
+
world_size: int,
|
|
682
|
+
global_rank: int,
|
|
683
|
+
local_rank: int,
|
|
684
|
+
dataloaders: list[DataLoader],
|
|
685
|
+
):
|
|
686
|
+
"""
|
|
687
|
+
Launches the actual training process via internal `_Trainer` class.
|
|
688
|
+
Wraps model with DDP or CPU fallback, attaches EMA, and starts training.
|
|
689
|
+
|
|
690
|
+
Args:
|
|
691
|
+
world_size (int): Total number of distributed processes.
|
|
692
|
+
global_rank (int): Global rank of the current process.
|
|
693
|
+
local_rank (int): Local rank within the node.
|
|
694
|
+
dataloaders (list[DataLoader]): Training and validation dataloaders.
|
|
695
|
+
"""
|
|
696
|
+
model = Network.to(self.model, local_rank * self.size)
|
|
697
|
+
model = DDP(model, static_graph=True) if torch.cuda.is_available() else CPUModel(model)
|
|
698
|
+
if self.model_ema is not None:
|
|
699
|
+
self.model_ema.module = Network.to(self.model_ema.module, local_rank)
|
|
700
|
+
with _Trainer(
|
|
701
|
+
world_size,
|
|
702
|
+
global_rank,
|
|
703
|
+
local_rank,
|
|
704
|
+
self.size,
|
|
705
|
+
self.name,
|
|
706
|
+
self.early_stopping,
|
|
707
|
+
self.data_log,
|
|
708
|
+
self.save_checkpoint_mode,
|
|
709
|
+
self.epochs,
|
|
710
|
+
self.epoch,
|
|
711
|
+
self.autocast,
|
|
712
|
+
self.it_validation,
|
|
713
|
+
self.it,
|
|
714
|
+
model,
|
|
715
|
+
self.model_ema,
|
|
716
|
+
*dataloaders,
|
|
717
|
+
) as t:
|
|
718
|
+
t.run()
|