konfai-1.1.7-py3-none-any.whl → konfai-1.1.9-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of konfai might be problematic.

Files changed (36)
  1. konfai/__init__.py +59 -14
  2. konfai/data/augmentation.py +457 -286
  3. konfai/data/data_manager.py +509 -290
  4. konfai/data/patching.py +300 -183
  5. konfai/data/transform.py +384 -277
  6. konfai/evaluator.py +309 -68
  7. konfai/main.py +71 -22
  8. konfai/metric/measure.py +341 -222
  9. konfai/metric/schedulers.py +24 -13
  10. konfai/models/classification/convNeXt.py +187 -81
  11. konfai/models/classification/resnet.py +272 -58
  12. konfai/models/generation/cStyleGan.py +233 -59
  13. konfai/models/generation/ddpm.py +348 -121
  14. konfai/models/generation/diffusionGan.py +757 -358
  15. konfai/models/generation/gan.py +177 -53
  16. konfai/models/generation/vae.py +140 -40
  17. konfai/models/registration/registration.py +135 -52
  18. konfai/models/representation/representation.py +57 -23
  19. konfai/models/segmentation/NestedUNet.py +339 -68
  20. konfai/models/segmentation/UNet.py +140 -30
  21. konfai/network/blocks.py +331 -187
  22. konfai/network/network.py +781 -423
  23. konfai/predictor.py +645 -240
  24. konfai/trainer.py +527 -216
  25. konfai/utils/ITK.py +191 -106
  26. konfai/utils/config.py +152 -95
  27. konfai/utils/dataset.py +326 -455
  28. konfai/utils/utils.py +495 -249
  29. {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/METADATA +1 -3
  30. konfai-1.1.9.dist-info/RECORD +38 -0
  31. konfai/utils/registration.py +0 -199
  32. konfai-1.1.7.dist-info/RECORD +0 -39
  33. {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/WHEEL +0 -0
  34. {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/entry_points.txt +0 -0
  35. {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/licenses/LICENSE +0 -0
  36. {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/top_level.txt +0 -0
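The bulk of this release is an API cleanup of konfai/trainer.py, shown in the diff below: camelCase methods such as isStopped, getScore, setState and getNetworks become snake_case (is_stopped, get_score, set_state, get_networks), typing.Union annotations become PEP 604 unions, and modelEMA becomes model_ema. As a minimal migration sketch for downstream call sites, assuming the EarlyStopping signature shown in the diff and that its @config("EarlyStopping") decorator still accepts direct keyword arguments (not verified here):

from konfai.trainer import EarlyStopping

# 1.1.7 call sites used stopper.getScore(...) and stopper.isStopped()
stopper = EarlyStopping(monitor=["loss"], patience=10, min_delta=0.0, mode="min")

score = stopper.get_score({"loss": 0.42})   # sums the monitored values
if stopper(score) or stopper.is_stopped():  # __call__ updates the patience counter
    print("early stopping triggered")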
konfai/trainer.py CHANGED
@@ -1,62 +1,88 @@
- import torch
- from torch.utils.data import DataLoader
- import tqdm
- import numpy as np
  import os
- from typing import Union
- from torch.nn.parallel import DistributedDataParallel as DDP
-
  import shutil
- from torch.utils.tensorboard.writer import SummaryWriter
- from torch.optim.swa_utils import AveragedModel
+
+ import torch
  import torch.distributed as dist
+ import tqdm
+ from torch.nn.parallel import DistributedDataParallel as DDP # noqa: N817
+ from torch.optim.swa_utils import AveragedModel
+ from torch.utils.data import DataLoader
+ from torch.utils.tensorboard.writer import SummaryWriter

- from konfai import MODELS_DIRECTORY, CHECKPOINTS_DIRECTORY, STATISTICS_DIRECTORY, SETUPS_DIRECTORY, CONFIG_FILE, MODEL, DATE, KONFAI_STATE
+ from konfai import (
+ checkpoints_directory,
+ config_file,
+ current_date,
+ konfai_state,
+ models_directory,
+ path_to_models,
+ setups_directory,
+ statistics_directory,
+ )
  from konfai.data.data_manager import DataTrain
+ from konfai.network.network import CPUModel, ModelLoader, NetState, Network
  from konfai.utils.config import config
- from konfai.utils.utils import State, DataLog, DistributedObject, description, TrainerError
- from konfai.network.network import Network, ModelLoader, NetState, CPU_Model
+ from konfai.utils.utils import DataLog, DistributedObject, State, TrainerError, description
+

  class EarlyStoppingBase:

  def __init__(self):
  pass

- def isStopped(self) -> bool:
+ def is_stopped(self) -> bool:
+
  return False

- def getScore(self, values: dict[str, float]):
- return sum([i for i in values.values()])
-
+ def get_score(self, values: dict[str, float]):
+ return sum(list(values.values()))
+
  def __call__(self, current_score: float) -> bool:
  return False

+
  class EarlyStopping(EarlyStoppingBase):
+ """
+ Implements early stopping logic with configurable patience and monitored metrics.
+
+ Attributes:
+ monitor (list[str]): Metrics to monitor.
+ patience (int): Number of checks with no improvement before stopping.
+ min_delta (float): Minimum change to qualify as improvement.
+ mode (str): "min" or "max" depending on optimization direction.
+ """

  @config("EarlyStopping")
- def __init__(self, monitor: Union[list[str], None] = [], patience: int=10, min_delta: float=0.0, mode: str="min"):
+ def __init__(
+ self,
+ monitor: list[str] | None = None,
+ patience: int = 10,
+ min_delta: float = 0.0,
+ mode: str = "min",
+ ):
  super().__init__()
  self.monitor = [] if monitor is None else monitor
  self.patience = patience
  self.min_delta = min_delta
  self.mode = mode
  self.counter = 0
- self.best_score = None
+ self.best_score: float | None = None
  self.early_stop = False

- def isStopped(self) -> bool:
+ def is_stopped(self) -> bool:
  return self.early_stop

- def getScore(self, values: dict[str, float]):
+ def get_score(self, values: dict[str, float]):
  if len(self.monitor) == 0:
- return super().getScore(values)
+ return super().get_score(values)
  for v in self.monitor:
  if v not in values.keys():
  raise TrainerError(
  "Metric '{}' specified in EarlyStopping.monitor not found in logged values. ",
- "Available keys: {}. Please check your configuration.".format(v, list(values.keys())))
+ f"Available keys: {v}. Please check your configuration.",
+ )
  return sum([i for v, i in values.items() if v in self.monitor])
-
+
  def __call__(self, current_score: float) -> bool:
  if self.best_score is None:
  self.best_score = current_score
@@ -80,10 +106,44 @@ class EarlyStopping(EarlyStoppingBase):

  return self.early_stop

- class _Trainer():

- def __init__(self, world_size: int, global_rank: int, local_rank: int, size: int, train_name: str, early_stopping: EarlyStopping, data_log: Union[list[str], None] , save_checkpoint_mode: str, epochs: int, epoch: int, autocast: bool, it_validation: Union[int, None], it: int, model: Union[DDP, CPU_Model], modelEMA: AveragedModel, dataloader_training: DataLoader, dataloader_validation: Union[DataLoader, None] = None) -> None:
- self.world_size = world_size
+ class _Trainer:
+ """
+ Internal class for managing the training loop in a distributed or standalone setting.
+
+ Handles:
+ - Epoch iteration with training and optional validation
+ - Mixed precision support (autocast)
+ - Exponential Moving Average (EMA) model tracking
+ - Early stopping
+ - Logging to TensorBoard
+ - Model checkpoint saving and selection (ALL or BEST)
+
+ This class is intended to be used via a context manager
+ (`with _Trainer(...) as trainer:`) inside the public `Trainer` class.
+ """
+
+ def __init__(
+ self,
+ world_size: int,
+ global_rank: int,
+ local_rank: int,
+ size: int,
+ train_name: str,
+ early_stopping: EarlyStopping | None,
+ data_log: list[str] | None,
+ save_checkpoint_mode: str,
+ epochs: int,
+ epoch: int,
+ autocast: bool,
+ it_validation: int | None,
+ it: int,
+ model: DDP | CPUModel,
+ model_ema: AveragedModel,
+ dataloader_training: DataLoader,
+ dataloader_validation: DataLoader | None = None,
+ ) -> None:
+ self.world_size = world_size
  self.global_rank = global_rank
  self.local_rank = local_rank
  self.size = size
@@ -96,58 +156,94 @@ class _Trainer():
  self.dataloader_training = dataloader_training
  self.dataloader_validation = dataloader_validation
  self.autocast = autocast
- self.modelEMA = modelEMA
- self.early_stopping = EarlyStoppingBase() if early_stopping is None else early_stopping
-
- self.it_validation = it_validation
- if self.it_validation is None:
- self.it_validation = len(dataloader_training)
+ self.model_ema = model_ema
+ self.early_stopping = EarlyStoppingBase() if early_stopping is None else early_stopping
+
+ self.it_validation = len(dataloader_training) if it_validation is None else it_validation
  self.it = it
- self.tb = SummaryWriter(log_dir = STATISTICS_DIRECTORY()+self.train_name+"/")
- self.data_log : dict[str, tuple[DataLog, int]] = {}
+ self.tb = SummaryWriter(log_dir=statistics_directory() + self.train_name + "/")
+ self.data_log: dict[str, tuple[DataLog, int]] = {}
  if data_log is not None:
  for data in data_log:
- self.data_log[data.split("/")[0].replace(":", ".")] = (DataLog.__getitem__(data.split("/")[1]).value[0], int(data.split("/")[2]))
-
+ self.data_log[data.split("/")[0].replace(":", ".")] = (
+ DataLog[data.split("/")[1]],
+ int(data.split("/")[2]),
+ )
+
  def __enter__(self):
  return self
-
- def __exit__(self, type, value, traceback):
+
+ def __exit__(self, exc_type, value, traceback):
+ """Closes the SummaryWriter if used."""
  if self.tb is not None:
  self.tb.close()

  def run(self) -> None:
+ """
+ Launches the training loop, performing one epoch at a time.
+ Triggers early stopping and resets data augmentations between epochs.
+ """
  self.dataloader_training.dataset.load("Train")
  if self.dataloader_validation is not None:
  self.dataloader_validation.dataset.load("Validation")
- with tqdm.tqdm(iterable = range(self.epoch, self.epochs), leave=False, total=self.epochs, initial=self.epoch, desc="Progress") as epoch_tqdm:
+ with tqdm.tqdm(
+ iterable=range(self.epoch, self.epochs),
+ leave=False,
+ total=self.epochs,
+ initial=self.epoch,
+ desc="Progress",
+ ) as epoch_tqdm:
  for self.epoch in epoch_tqdm:
  self.train()
- if self.early_stopping.isStopped():
+ if self.early_stopping.is_stopped():
  break
- self.dataloader_training.dataset.resetAugmentation("Train")
-
- def getInput(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int, str, bool]]) -> dict[tuple[str, bool], torch.Tensor]:
- return {(k, v[5][0].item()) : v[0] for k, v in data_dict.items()}
+ self.dataloader_training.dataset.reset_augmentation("Train")

- def train(self) -> None:
- self.model.train()
- self.model.module.setState(NetState.TRAIN)
- if self.modelEMA is not None:
- self.modelEMA.eval()
- self.modelEMA.module.setState(NetState.TRAIN)
+ def get_input(
+ self,
+ data_dict: dict[str, tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, list[str], torch.Tensor]],
+ ) -> dict[tuple[str, bool], torch.Tensor]:
+ """
+ Extracts input tensors from the data dict for model input.

- desc = lambda : "Training : {}".format(description(self.model, self.modelEMA))
+ Args:
+ data_dict (dict): Dictionary with data items structured as tuples.

- with tqdm.tqdm(iterable = enumerate(self.dataloader_training), desc = desc(), total=len(self.dataloader_training), leave=False, ncols=0) as batch_iter:
+ Returns:
+ dict: Mapping from (group, bool_flag) to input tensors.
+ """
+ return {(k, v[5][0].item()): v[0] for k, v in data_dict.items()}
+
+ def train(self) -> None:
+ """
+ Performs a full training epoch with support for:
+ - mixed precision
+ - DDP / CPU training
+ - EMA updates
+ - loss logging and checkpoint saving
+ - validation at configurable iteration interval
+ """
+ self.model.train()
+ self.model.module.set_state(NetState.TRAIN)
+ if self.model_ema is not None:
+ self.model_ema.eval()
+ self.model_ema.module.set_state(NetState.TRAIN)
+
+ with tqdm.tqdm(
+ iterable=enumerate(self.dataloader_training),
+ desc=f"Training : {description(self.model, self.model_ema)}",
+ total=len(self.dataloader_training),
+ leave=False,
+ ncols=0,
+ ) as batch_iter:
  for _, data_dict in batch_iter:
- with torch.amp.autocast('cuda', enabled=self.autocast):
- input = self.getInput(data_dict)
- self.model(input)
+ with torch.amp.autocast("cuda", enabled=self.autocast):
+ input_data_dict = self.get_input(data_dict)
+ self.model(input_data_dict)
  self.model.module.backward(self.model.module)
- if self.modelEMA is not None:
- self.modelEMA.update_parameters(self.model)
- self.modelEMA.module(input)
+ if self.model_ema is not None:
+ self.model_ema.update_parameters(self.model)
+ self.model_ema.module(input_data_dict)
  self.it += 1
  if (self.it) % self.it_validation == 0:
  loss = self._train_log(data_dict)
@@ -155,76 +251,109 @@ class _Trainer():
  if self.dataloader_validation is not None:
  loss = self._validate()
  self.model.module.update_lr()
- score = self.early_stopping.getScore(loss)
+ score = self.early_stopping.get_score(loss)
  self.checkpoint_save(score)
  if self.early_stopping(score):
  break
-

- batch_iter.set_description(desc())
-
-
+ batch_iter.set_description(f"Training : {description(self.model, self.model_ema)}")
+
  @torch.no_grad()
  def _validate(self) -> float:
+ """
+ Executes the validation phase, evaluates loss and metrics.
+ Updates model states and resets augmentation for validation set.
+
+ Returns:
+ float: Validation loss.
+ """
+ if self.dataloader_validation is None:
+ return 0
  self.model.eval()
- self.model.module.setState(NetState.PREDICTION)
- if self.modelEMA is not None:
- self.modelEMA.module.setState(NetState.PREDICTION)
+ self.model.module.set_state(NetState.PREDICTION)
+ if self.model_ema is not None:
+ self.model_ema.module.set_state(NetState.PREDICTION)

- desc = lambda : "Validation : {}".format(description(self.model, self.modelEMA))
  data_dict = None
- with tqdm.tqdm(iterable = enumerate(self.dataloader_validation), desc = desc(), total=len(self.dataloader_validation), leave=False, ncols=0) as batch_iter:
+ with tqdm.tqdm(
+ iterable=enumerate(self.dataloader_validation),
+ desc=f"Validation : {description(self.model, self.model_ema)}",
+ total=len(self.dataloader_validation),
+ leave=False,
+ ncols=0,
+ ) as batch_iter:
  for _, data_dict in batch_iter:
- input = self.getInput(data_dict)
- self.model(input)
- if self.modelEMA is not None:
- self.modelEMA.module(input)
+ input_data_dict = self.get_input(data_dict)
+ self.model(input_data_dict)
+ if self.model_ema is not None:
+ self.model_ema.module(input_data_dict)

- batch_iter.set_description(desc())
- self.dataloader_validation.dataset.resetAugmentation("Validation")
+ batch_iter.set_description(f"Validation : {description(self.model, self.model_ema)}")
+ self.dataloader_validation.dataset.reset_augmentation("Validation")
  dist.barrier()
  self.model.train()
- self.model.module.setState(NetState.TRAIN)
- if self.modelEMA is not None:
- self.modelEMA.module.setState(NetState.TRAIN)
+ self.model.module.set_state(NetState.TRAIN)
+ if self.model_ema is not None:
+ self.model_ema.module.set_state(NetState.TRAIN)
  return self._validation_log(data_dict)

  def checkpoint_save(self, loss: float) -> None:
+ """
+ Saves model and optimizer states. Keeps either all checkpoints or only the best one.
+
+ Args:
+ loss (float): Current loss used for best checkpoint selection.
+ """
  if self.global_rank != 0:
  return

- path = CHECKPOINTS_DIRECTORY()+self.train_name+"/"
+ path = checkpoints_directory() + self.train_name + "/"
  os.makedirs(path, exist_ok=True)

- name = DATE() + ".pt"
+ name = current_date() + ".pt"
  save_path = os.path.join(path, name)

  save_dict = {
- "epoch": self.epoch,
- "it": self.it,
- "loss": loss,
- "Model": self.model.module.state_dict()
+ "epoch": self.epoch,
+ "it": self.it,
+ "loss": loss,
+ "Model": self.model.module.state_dict(),
  }

- if self.modelEMA is not None:
- save_dict["Model_EMA"] = self.modelEMA.module.state_dict()
-
- save_dict.update({'{}_optimizer_state_dict'.format(name): network.optimizer.state_dict() for name, network in self.model.module.getNetworks().items() if network.optimizer is not None})
- save_dict.update({'{}_it'.format(name): network._it for name, network in self.model.module.getNetworks().items() if network.optimizer is not None})
- save_dict.update({'{}_nb_lr_update'.format(name): network._nb_lr_update for name, network in self.model.module.getNetworks().items() if network.optimizer is not None})
-
+ if self.model_ema is not None:
+ save_dict["Model_EMA"] = self.model_ema.module.state_dict()
+
+ save_dict.update(
+ {
+ f"{name}_optimizer_state_dict": network.optimizer.state_dict()
+ for name, network in self.model.module.get_networks().items()
+ if network.optimizer is not None
+ }
+ )
+ save_dict.update(
+ {
+ f"{name}_it": network._it
+ for name, network in self.model.module.get_networks().items()
+ if network.optimizer is not None
+ }
+ )
+ save_dict.update(
+ {
+ f"{name}_nb_lr_update": network._nb_lr_update
+ for name, network in self.model.module.get_networks().items()
+ if network.optimizer is not None
+ }
+ )
+
  torch.save(save_dict, save_path)

  if self.save_checkpoint_mode == "BEST":
- all_checkpoints = sorted([
- os.path.join(path, f)
- for f in os.listdir(path) if f.endswith(".pt")
- ])
+ all_checkpoints = sorted([os.path.join(path, f) for f in os.listdir(path) if f.endswith(".pt")])
  best_ckpt = None
- best_loss = float('inf')
+ best_loss = float("inf")

  for f in all_checkpoints:
- d = torch.load(f, weights_only=False)
+ d = torch.load(f, weights_only=True)
  if d.get("loss", float("inf")) < best_loss:
  best_loss = d["loss"]
  best_ckpt = f
@@ -234,79 +363,165 @@ class _Trainer():
  os.remove(f)

  @torch.no_grad()
- def _log(self, type_log: str, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]) -> dict[str, float]:
- models: dict[str, Network] = {"" : self.model.module}
- if self.modelEMA is not None:
- models["_EMA"] = self.modelEMA.module
-
- measures = DistributedObject.getMeasure(self.world_size, self.global_rank, self.local_rank*self.size+self.size-1, models, self.it_validation if type_log == "Training" else len(self.dataloader_validation))
-
+ def _log(
+ self,
+ type_log: str,
+ data_dict: dict[str, tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, list[str], torch.Tensor]],
+ ) -> dict[str, float] | None:
+ """
+ Logs losses, metrics and optionally images to TensorBoard.
+
+ Args:
+ type_log (str): "Training" or "Validation".
+ data_dict (dict): Dictionary of data from current batch.
+
+ Returns:
+ dict[str, float] | None: Dictionary of aggregated losses and metrics if rank == 0.
+ """
+ models: dict[str, Network] = {"": self.model.module}
+ if self.model_ema is not None:
+ models["_EMA"] = self.model_ema.module
+
+ measures = DistributedObject.get_measure(
+ self.world_size,
+ self.global_rank,
+ self.local_rank * self.size + self.size - 1,
+ models,
+ (
+ self.it_validation
+ if type_log == "Training" or self.dataloader_validation is None
+ else len(self.dataloader_validation)
+ ),
+ )
+
  if self.global_rank == 0:
  images_log = []
- if len(self.data_log):
+ if len(self.data_log):
  for name, data_type in self.data_log.items():
  if name in data_dict:
- data_type[0](self.tb, "{}/{}".format(type_log ,name), data_dict[name][0][:self.data_log[name][1]].detach().cpu().numpy(), self.it)
+ data_type[0](
+ self.tb,
+ f"{type_log}/{name}",
+ data_dict[name][0][: self.data_log[name][1]].detach().cpu().numpy(),
+ self.it,
+ )
  else:
  images_log.append(name.replace(":", "."))
-
+
  for label, model in models.items():
- for name, network in model.getNetworks().items():
+ for name, network in model.get_networks().items():
  if network.measure is not None:
- self.tb.add_scalars("{}/{}/Loss/{}".format(type_log, name, label), {k : v[1] for k, v in measures["{}{}".format(name, label)][0].items()}, self.it)
- self.tb.add_scalars("{}/{}/Loss_weight/{}".format(type_log, name, label), {k : v[0] for k, v in measures["{}{}".format(name, label)][0].items()}, self.it)
+ self.tb.add_scalars(
+ f"{type_log}/{name}/Loss/{label}",
+ {k: v[1] for k, v in measures[f"{name}{label}"][0].items()},
+ self.it,
+ )
+ self.tb.add_scalars(
+ f"{type_log}/{name}/Loss_weight/{label}",
+ {k: v[0] for k, v in measures[f"{name}{label}"][0].items()},
+ self.it,
+ )
+
+ self.tb.add_scalars(
+ f"{type_log}/{name}/Metric/{label}",
+ {k: v[1] for k, v in measures[f"{name}{label}"][1].items()},
+ self.it,
+ )
+ self.tb.add_scalars(
+ f"{type_log}/{name}/Metric_weight/{label}",
+ {k: v[0] for k, v in measures[f"{name}{label}"][1].items()},
+ self.it,
+ )

- self.tb.add_scalars("{}/{}/Metric/{}".format(type_log, name, label), {k : v[1] for k, v in measures["{}{}".format(name, label)][1].items()}, self.it)
- self.tb.add_scalars("{}/{}/Metric_weight/{}".format(type_log, name, label), {k : v[0] for k, v in measures["{}{}".format(name, label)][1].items()}, self.it)
-
  if len(images_log):
- for name, layer, _ in model.get_layers([v.to(0) for k, v in self.getInput(data_dict).items() if k[1]], images_log):
- self.data_log[name][0](self.tb, "{}/{}{}".format(type_log, name, label), layer[:self.data_log[name][1]].detach().cpu().numpy(), self.it)
-
+ for name, layer, _ in model.get_layers(
+ [v.to(0) for k, v in self.get_input(data_dict).items() if k[1]],
+ images_log,
+ ):
+ self.data_log[name][0](
+ self.tb,
+ f"{type_log}/{name}{label}",
+ layer[: self.data_log[name][1]].detach().cpu().numpy(),
+ self.it,
+ )
+
  if type_log == "Training":
- for name, network in self.model.module.getNetworks().items():
+ for name, network in self.model.module.get_networks().items():
  if network.optimizer is not None:
- self.tb.add_scalar("{}/{}/Learning Rate".format(type_log, name), network.optimizer.param_groups[0]['lr'], self.it)
-
+ self.tb.add_scalar(
+ f"{type_log}/{name}/Learning Rate",
+ network.optimizer.param_groups[0]["lr"],
+ self.it,
+ )
+
  if self.global_rank == 0:
  loss = {}
- for name, network in self.model.module.getNetworks().items():
+ for name, network in self.model.module.get_networks().items():
  if network.measure is not None:
- loss.update({k : v[1] for k, v in measures["{}{}".format(name, label)][0].items()})
- loss.update({k : v[1] for k, v in measures["{}{}".format(name, label)][1].items()})
+ loss.update({k: v[1] for k, v in measures[f"{name}{label}"][0].items()})
+ loss.update({k: v[1] for k, v in measures[f"{name}{label}"][1].items()})
  return loss
  return None
-
+
  @torch.no_grad()
- def _train_log(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]) -> dict[str, float]:
+ def _train_log(self, data_dict: dict[str, tuple[torch.Tensor, int, int, int]]) -> dict[str, float]:
+ """Wrapper for _log during training."""
  return self._log("Training", data_dict)

  @torch.no_grad()
- def _validation_log(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]) -> dict[str, float]:
+ def _validation_log(self, data_dict: dict[str, tuple[torch.Tensor, int, int, int]]) -> dict[str, float]:
+ """Wrapper for _log during validation."""
  return self._log("Validation", data_dict)

-
+
  class Trainer(DistributedObject):
+ """
+ Public API for training a model using the KonfAI framework.
+ Wraps setup, checkpointing, resuming, logging, and launching distributed _Trainer.
+
+ Main responsibilities:
+ - Initialization from config (via @config)
+ - Model and EMA setup
+ - Checkpoint loading and saving
+ - Distributed setup and launch
+
+ Args:
+ model (ModelLoader): Loader for model architecture.
+ dataset (DataTrain): Training/validation dataset.
+ train_name (str): Training session name.
+ manual_seed (int | None): Random seed.
+ epochs (int): Number of epochs to run.
+ it_validation (int | None): Validation interval.
+ autocast (bool): Enable AMP training.
+ gradient_checkpoints (list[str] | None): Modules to use gradient checkpointing on.
+ gpu_checkpoints (list[str] | None): Modules to pin on specific GPUs.
+ ema_decay (float): EMA decay factor.
+ data_log (list[str] | None): Logging instructions.
+ early_stopping (EarlyStopping | None): Optional early stopping config.
+ save_checkpoint_mode (str): Either "BEST" or "ALL".
+ """

  @config("Trainer")
- def __init__( self,
- model : ModelLoader = ModelLoader(),
- dataset : DataTrain = DataTrain(),
- train_name : str = "default:TRAIN_01",
- manual_seed : Union[int, None] = None,
- epochs: int = 100,
- it_validation : Union[int, None] = None,
- autocast : bool = False,
- gradient_checkpoints: Union[list[str], None] = None,
- gpu_checkpoints: Union[list[str], None] = None,
- ema_decay : float = 0,
- data_log: Union[list[str], None] = None,
- early_stopping: Union[EarlyStopping, None] = None,
- save_checkpoint_mode: str= "BEST") -> None:
+ def __init__(
+ self,
+ model: ModelLoader = ModelLoader(),
+ dataset: DataTrain = DataTrain(),
+ train_name: str = "default:TRAIN_01",
+ manual_seed: int | None = None,
+ epochs: int = 100,
+ it_validation: int | None = None,
+ autocast: bool = False,
+ gradient_checkpoints: list[str] | None = None,
+ gpu_checkpoints: list[str] | None = None,
+ ema_decay: float = 0,
+ data_log: list[str] | None = None,
+ early_stopping: EarlyStopping | None = None,
+ save_checkpoint_mode: str = "BEST",
+ ) -> None:
  if os.environ["KONFAI_CONFIG_MODE"] != "Done":
  exit(0)
  super().__init__(train_name)
- self.manual_seed = manual_seed
+ self.manual_seed = manual_seed
  self.dataset = dataset
  self.autocast = autocast
  self.epochs = epochs
@@ -314,118 +529,214 @@ class Trainer(DistributedObject):
  self.early_stopping = early_stopping
  self.it = 0
  self.it_validation = it_validation
- self.model = model.getModel(train=True)
+ self.model = model.get_model(train=True)
  self.ema_decay = ema_decay
- self.modelEMA : Union[torch.optim.swa_utils.AveragedModel, None] = None
+ self.model_ema: torch.optim.swa_utils.AveragedModel | None = None
  self.data_log = data_log

  modules = []
- for i,_ in self.model.named_modules():
+ for i, _ in self.model.named_modules():
  modules.append(i)
- for k in self.data_log:
- tmp = k.split("/")[0].replace(":", ".")
- if tmp not in self.dataset.getGroupsDest() and tmp not in modules:
- raise TrainerError( f"Invalid key '{tmp}' in `data_log`.",
- f"This key is neither a destination group from the dataset ({self.dataset.getGroupsDest()})",
- f"nor a valid module name in the model ({modules}).",
- "Please check your `data_log` configuration — it should reference either a model output or a dataset group.")
-
+ if self.data_log is not None:
+ for k in self.data_log:
+ tmp = k.split("/")[0].replace(":", ".")
+ if tmp not in self.dataset.get_groups_dest() and tmp not in modules:
+ raise TrainerError(
+ f"Invalid key '{tmp}' in `data_log`.",
+ f"This key is neither a destination group from the dataset ({self.dataset.get_groups_dest()})",
+ f"nor a valid module name in the model ({modules}).",
+ "Please check your `data_log` configuration,"
+ " it should reference either a model output or a dataset group.",
+ )
+
  self.gradient_checkpoints = gradient_checkpoints
  self.gpu_checkpoints = gpu_checkpoints
  self.save_checkpoint_mode = save_checkpoint_mode
- self.config_namefile_src = CONFIG_FILE().replace(".yml", "")
- self.config_namefile = SETUPS_DIRECTORY()+self.name+"/"+self.config_namefile_src.split("/")[-1]+"_"+str(self.it)+".yml"
- self.size = (len(self.gpu_checkpoints)+1 if self.gpu_checkpoints else 1)
-
- def __exit__(self, type, value, traceback):
- super().__exit__(type, value, traceback)
+ self.config_namefile_src = config_file().replace(".yml", "")
+ self.config_namefile = (
+ setups_directory() + self.name + "/" + self.config_namefile_src.split("/")[-1] + "_" + str(self.it) + ".yml"
+ )
+ self.size = len(self.gpu_checkpoints) + 1 if self.gpu_checkpoints else 1
+
+ def __exit__(self, exc_type, value, traceback):
+ """Exit training context and trigger save of model/checkpoints."""
+ super().__exit__(exc_type, value, traceback)
  self._save()

  def _load(self) -> dict[str, dict[str, torch.Tensor]]:
- if MODEL().startswith("https://"):
+ """
+ Loads a previously saved checkpoint from local disk or URL.
+
+ Returns:
+ dict: State dictionary loaded from checkpoint.
+ """
+ if path_to_models().startswith("https://"):
  try:
- state_dict = {MODEL().split(":")[1]: torch.hub.load_state_dict_from_url(url=MODEL().split(":")[0], map_location="cpu", check_hash=True)}
- except:
- raise Exception("Model : {} does not exist !".format(MODEL()))
+ state_dict = {
+ path_to_models().split(":")[1]: torch.hub.load_state_dict_from_url(
+ url=path_to_models().split(":")[0], map_location="cpu", check_hash=True
+ )
+ }
+ except Exception:
+ raise Exception(f"Model : {path_to_models()} does not exist !")
  else:
- if MODEL() != "":
+ if path_to_models() != "":
  path = ""
- name = MODEL()
+ name = path_to_models()
  else:
- path = CHECKPOINTS_DIRECTORY()+self.name+"/"
+ path = checkpoints_directory() + self.name + "/"
  if os.listdir(path):
  name = sorted(os.listdir(path))[-1]
-
- if os.path.exists(path+name):
- state_dict = torch.load(path+name, weights_only=False, map_location="cpu")
+
+ if os.path.exists(path + name):
+ state_dict = torch.load(path + name, weights_only=True, map_location="cpu")
  else:
- raise Exception("Model : {} does not exist !".format(self.name))
-
+ raise Exception(f"Model : {self.name} does not exist !")
+
  if "epoch" in state_dict:
- self.epoch = state_dict['epoch']
+ self.epoch = state_dict["epoch"]
  if "it" in state_dict:
- self.it = state_dict['it']
+ self.it = state_dict["it"]
  return state_dict

  def _save(self) -> None:
- path_checkpoint = CHECKPOINTS_DIRECTORY()+self.name+"/"
- path_model = MODELS_DIRECTORY()+self.name+"/"
+ """
+ Serializes model and EMA model (if any) in both state_dict and full model formats.
+ Also saves optimizer states and YAML config snapshot.
+ """
+ path_checkpoint = checkpoints_directory() + self.name + "/"
+ path_model = models_directory() + self.name + "/"
  if os.path.exists(path_checkpoint) and os.listdir(path_checkpoint):
- for dir in [path_model, "{}Serialized/".format(path_model), "{}StateDict/".format(path_model)]:
- if not os.path.exists(dir):
- os.makedirs(dir)
+ for directory in [
+ path_model,
+ f"{path_model}Serialized/",
+ f"{path_model}StateDict/",
+ ]:
+ if not os.path.exists(directory):
+ os.makedirs(directory)

  for name in sorted(os.listdir(path_checkpoint)):
- checkpoint = torch.load(path_checkpoint+name, weights_only=False, map_location='cpu')
+ checkpoint = torch.load(path_checkpoint + name, weights_only=True, map_location="cpu")
  self.model.load(checkpoint, init=False, ema=False)

- torch.save(self.model, "{}Serialized/{}".format(path_model, name))
- torch.save({"Model" : self.model.state_dict()}, "{}StateDict/{}".format(path_model, name))
-
- if self.modelEMA is not None:
- self.modelEMA.module.load(checkpoint, init=False, ema=True)
- torch.save(self.modelEMA.module, "{}Serialized/{}".format(path_model, DATE()+"_EMA.pt"))
- torch.save({"Model_EMA" : self.modelEMA.module.state_dict()}, "{}StateDict/{}".format(path_model, DATE()+"_EMA.pt"))
-
- os.rename(self.config_namefile, self.config_namefile.replace(".yml", "")+"_"+str(self.it)+".yml")
-
- def _avg_fn(self, averaged_model_parameter, model_parameter, num_averaged):
- return (1-self.ema_decay) * averaged_model_parameter + self.ema_decay * model_parameter
-
+ torch.save(self.model, f"{path_model}Serialized/{name}")
+ torch.save(
+ {"Model": self.model.state_dict()},
+ f"{path_model}StateDict/{name}",
+ )
+
+ if self.model_ema is not None:
+ self.model_ema.module.load(checkpoint, init=False, ema=True)
+ torch.save(
+ self.model_ema.module,
+ f"{path_model}Serialized/{current_date()}_EMA.pt",
+ )
+ torch.save(
+ {"Model_EMA": self.model_ema.module.state_dict()},
+ f"{path_model}StateDict/{current_date()}_EMA.pt",
+ )
+
+ os.rename(
+ self.config_namefile,
+ self.config_namefile.replace(".yml", "") + "_" + str(self.it) + ".yml",
+ )
+
+ def _avg_fn(self, averaged_model_parameter: float, model_parameter, num_averaged):
+ """
+ EMA update rule used by AveragedModel.
+
+ Returns:
+ torch.Tensor: Blended parameter using decay factor.
+ """
+ return (1 - self.ema_decay) * averaged_model_parameter + self.ema_decay * model_parameter
+
  def setup(self, world_size: int):
- state = State._member_map_[KONFAI_STATE()]
- if state != State.RESUME and os.path.exists(CHECKPOINTS_DIRECTORY()+self.name+"/"):
+ """
+ Initializes the training environment:
+ - Clears previous outputs (unless resuming)
+ - Initializes model and EMA
+ - Loads checkpoint (if resuming)
+ - Prepares dataloaders
+
+ Args:
+ world_size (int): Total number of distributed processes.
+ """
+ state = State[konfai_state()]
+ if state != State.RESUME and os.path.exists(checkpoints_directory() + self.name + "/"):
  if os.environ["KONFAI_OVERWRITE"] != "True":
- accept = input("The model {} already exists ! Do you want to overwrite it (yes,no) : ".format(self.name))
+ accept = input(f"The model {self.name} already exists ! Do you want to overwrite it (yes,no) : ")
  if accept != "yes":
  return
- for directory_path in [STATISTICS_DIRECTORY(), MODELS_DIRECTORY(), CHECKPOINTS_DIRECTORY(), SETUPS_DIRECTORY()]:
- if os.path.exists(directory_path+self.name+"/"):
- shutil.rmtree(directory_path+self.name+"/")
-
+ for directory_path in [
+ statistics_directory(),
+ models_directory(),
+ checkpoints_directory(),
+ setups_directory(),
+ ]:
+ if os.path.exists(directory_path + self.name + "/"):
+ shutil.rmtree(directory_path + self.name + "/")
+
  state_dict = {}
  if state != State.TRAIN:
  state_dict = self._load()

- self.model.init(self.autocast, state, self.dataset.getGroupsDest())
- self.model.init_outputsGroup()
- self.model._compute_channels_trace(self.model, self.model.in_channels, self.gradient_checkpoints, self.gpu_checkpoints)
+ self.model.init(self.autocast, state, self.dataset.get_groups_dest())
+ self.model.init_outputs_group()
+ self.model._compute_channels_trace(
+ self.model,
+ self.model.in_channels,
+ self.gradient_checkpoints,
+ self.gpu_checkpoints,
+ )
  self.model.load(state_dict, init=True, ema=False)
  if self.ema_decay > 0:
- self.modelEMA = AveragedModel(self.model, avg_fn=self._avg_fn)
+ self.model_ema = AveragedModel(self.model, avg_fn=self._avg_fn)
  if state_dict is not None:
- self.modelEMA.module.load(state_dict, init=False, ema=True)
-
- if not os.path.exists(SETUPS_DIRECTORY()+self.name+"/"):
- os.makedirs(SETUPS_DIRECTORY()+self.name+"/")
- shutil.copyfile(self.config_namefile_src+".yml", self.config_namefile)
-
- self.dataloader = self.dataset.getData(world_size//self.size)
-
- def run_process(self, world_size: int, global_rank: int, local_rank: int, dataloaders: list[DataLoader]):
- model = Network.to(self.model, local_rank*self.size)
- model = DDP(model, static_graph=True) if torch.cuda.is_available() else CPU_Model(model)
- if self.modelEMA is not None:
- self.modelEMA.module = Network.to(self.modelEMA.module, local_rank)
- with _Trainer(world_size, global_rank, local_rank, self.size, self.name, self.early_stopping, self.data_log, self.save_checkpoint_mode, self.epochs, self.epoch, self.autocast, self.it_validation, self.it, model, self.modelEMA, *dataloaders) as t:
- t.run()
+ self.model_ema.module.load(state_dict, init=False, ema=True)
+
+ if not os.path.exists(setups_directory() + self.name + "/"):
+ os.makedirs(setups_directory() + self.name + "/")
+ shutil.copyfile(self.config_namefile_src + ".yml", self.config_namefile)
+
+ self.dataloader = self.dataset.get_data(world_size // self.size)
+
+ def run_process(
+ self,
+ world_size: int,
+ global_rank: int,
+ local_rank: int,
+ dataloaders: list[DataLoader],
+ ):
+ """
+ Launches the actual training process via internal `_Trainer` class.
+ Wraps model with DDP or CPU fallback, attaches EMA, and starts training.
+
+ Args:
+ world_size (int): Total number of distributed processes.
+ global_rank (int): Global rank of the current process.
+ local_rank (int): Local rank within the node.
+ dataloaders (list[DataLoader]): Training and validation dataloaders.
+ """
+ model = Network.to(self.model, local_rank * self.size)
+ model = DDP(model, static_graph=True) if torch.cuda.is_available() else CPUModel(model)
+ if self.model_ema is not None:
+ self.model_ema.module = Network.to(self.model_ema.module, local_rank)
+ with _Trainer(
+ world_size,
+ global_rank,
+ local_rank,
+ self.size,
+ self.name,
+ self.early_stopping,
+ self.data_log,
+ self.save_checkpoint_mode,
+ self.epochs,
+ self.epoch,
+ self.autocast,
+ self.it_validation,
+ self.it,
+ model,
+ self.model_ema,
+ *dataloaders,
+ ) as t:
+ t.run()
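A behavioural change worth noting across the hunks above is that every torch.load call now passes weights_only=True instead of False. A small standalone sketch, not taken from konfai, of what that restriction means in practice, assuming PyTorch 2.4 or newer where the restricted unpickler and torch.serialization.add_safe_globals are available:

import torch

# Checkpoints that hold only tensors and plain containers, like the save_dict
# built in checkpoint_save above, load unchanged under the restriction.
ckpt = {"epoch": 3, "it": 1200, "loss": 0.42, "Model": {"w": torch.zeros(2, 2)}}
torch.save(ckpt, "checkpoint.pt")
restored = torch.load("checkpoint.pt", weights_only=True, map_location="cpu")
print(restored["epoch"], restored["Model"]["w"].shape)

# Checkpoints that pickle arbitrary Python objects would instead need an explicit
# allow-list via torch.serialization.add_safe_globals([...]) or the old
# weights_only=False behaviour.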