konfai 1.1.8__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


Files changed (36)
  1. konfai/__init__.py +59 -14
  2. konfai/data/augmentation.py +457 -286
  3. konfai/data/data_manager.py +533 -316
  4. konfai/data/patching.py +300 -183
  5. konfai/data/transform.py +408 -275
  6. konfai/evaluator.py +325 -68
  7. konfai/main.py +71 -22
  8. konfai/metric/measure.py +360 -244
  9. konfai/metric/schedulers.py +24 -13
  10. konfai/models/classification/convNeXt.py +187 -81
  11. konfai/models/classification/resnet.py +272 -58
  12. konfai/models/generation/cStyleGan.py +233 -59
  13. konfai/models/generation/ddpm.py +348 -121
  14. konfai/models/generation/diffusionGan.py +757 -358
  15. konfai/models/generation/gan.py +177 -53
  16. konfai/models/generation/vae.py +140 -40
  17. konfai/models/registration/registration.py +135 -52
  18. konfai/models/representation/representation.py +57 -23
  19. konfai/models/segmentation/NestedUNet.py +339 -68
  20. konfai/models/segmentation/UNet.py +140 -30
  21. konfai/network/blocks.py +331 -187
  22. konfai/network/network.py +795 -427
  23. konfai/predictor.py +644 -238
  24. konfai/trainer.py +509 -222
  25. konfai/utils/ITK.py +191 -106
  26. konfai/utils/config.py +152 -95
  27. konfai/utils/dataset.py +326 -455
  28. konfai/utils/utils.py +497 -249
  29. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/METADATA +1 -3
  30. konfai-1.2.0.dist-info/RECORD +38 -0
  31. konfai/utils/registration.py +0 -199
  32. konfai-1.1.8.dist-info/RECORD +0 -39
  33. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/WHEEL +0 -0
  34. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/entry_points.txt +0 -0
  35. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/licenses/LICENSE +0 -0
  36. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/top_level.txt +0 -0
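The listing above only gives per-file line counts. As a minimal sketch (not the registry's own tooling), the snippet below rebuilds a comparable unified diff for one file straight from the two wheels using only the Python standard library; the local wheel filenames are assumptions and should be adjusted to wherever the wheels were downloaded (for example via pip download konfai==1.1.8 --no-deps).

# Sketch: diff one module between two downloaded konfai wheels (stdlib only).
import difflib
import zipfile

OLD_WHEEL = "konfai-1.1.8-py3-none-any.whl"  # assumed local path
NEW_WHEEL = "konfai-1.2.0-py3-none-any.whl"  # assumed local path
MEMBER = "konfai/trainer.py"                 # the file diffed below

def read_member(wheel_path: str, member: str) -> list[str]:
    # A wheel is a zip archive, so the file can be read without installing it.
    with zipfile.ZipFile(wheel_path) as wheel:
        return wheel.read(member).decode("utf-8").splitlines(keepends=True)

old_lines = read_member(OLD_WHEEL, MEMBER)
new_lines = read_member(NEW_WHEEL, MEMBER)

# Print a unified diff similar to the per-file listing shown on this page.
for line in difflib.unified_diff(old_lines, new_lines, fromfile=f"1.1.8/{MEMBER}", tofile=f"1.2.0/{MEMBER}"):
    print(line, end="")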
konfai/trainer.py CHANGED
@@ -1,62 +1,87 @@
- import torch
- from torch.utils.data import DataLoader
- import tqdm
- import numpy as np
  import os
- from typing import Union
- from torch.nn.parallel import DistributedDataParallel as DDP
-
  import shutil
- from torch.utils.tensorboard.writer import SummaryWriter
- from torch.optim.swa_utils import AveragedModel
+
+ import torch
  import torch.distributed as dist
+ import tqdm
+ from torch.nn.parallel import DistributedDataParallel as DDP # noqa: N817
+ from torch.optim.swa_utils import AveragedModel
+ from torch.utils.data import DataLoader
+ from torch.utils.tensorboard.writer import SummaryWriter

- from konfai import MODELS_DIRECTORY, CHECKPOINTS_DIRECTORY, STATISTICS_DIRECTORY, SETUPS_DIRECTORY, CONFIG_FILE, MODEL, DATE, KONFAI_STATE
+ from konfai import (
+     checkpoints_directory,
+     config_file,
+     current_date,
+     konfai_state,
+     path_to_models,
+     setups_directory,
+     statistics_directory,
+ )
  from konfai.data.data_manager import DataTrain
+ from konfai.network.network import CPUModel, ModelLoader, NetState, Network
  from konfai.utils.config import config
- from konfai.utils.utils import State, DataLog, DistributedObject, description, TrainerError
- from konfai.network.network import Network, ModelLoader, NetState, CPU_Model
+ from konfai.utils.utils import DataLog, DistributedObject, State, TrainerError, description
+

  class EarlyStoppingBase:

      def __init__(self):
          pass

-     def isStopped(self) -> bool:
+     def is_stopped(self) -> bool:
+
          return False

-     def getScore(self, values: dict[str, float]):
-         return sum([i for i in values.values()])
-
+     def get_score(self, values: dict[str, float]):
+         return sum(list(values.values()))
+
      def __call__(self, current_score: float) -> bool:
          return False

+
  class EarlyStopping(EarlyStoppingBase):
+     """
+     Implements early stopping logic with configurable patience and monitored metrics.
+
+     Attributes:
+         monitor (list[str]): Metrics to monitor.
+         patience (int): Number of checks with no improvement before stopping.
+         min_delta (float): Minimum change to qualify as improvement.
+         mode (str): "min" or "max" depending on optimization direction.
+     """

      @config("EarlyStopping")
-     def __init__(self, monitor: Union[list[str], None] = [], patience: int=10, min_delta: float=0.0, mode: str="min"):
+     def __init__(
+         self,
+         monitor: list[str] | None = None,
+         patience: int = 10,
+         min_delta: float = 0.0,
+         mode: str = "min",
+     ):
          super().__init__()
          self.monitor = [] if monitor is None else monitor
          self.patience = patience
          self.min_delta = min_delta
          self.mode = mode
          self.counter = 0
-         self.best_score = None
+         self.best_score: float | None = None
          self.early_stop = False

-     def isStopped(self) -> bool:
+     def is_stopped(self) -> bool:
          return self.early_stop

-     def getScore(self, values: dict[str, float]):
+     def get_score(self, values: dict[str, float]):
          if len(self.monitor) == 0:
-             return super().getScore(values)
+             return super().get_score(values)
          for v in self.monitor:
              if v not in values.keys():
                  raise TrainerError(
                      "Metric '{}' specified in EarlyStopping.monitor not found in logged values. ",
-                     "Available keys: {}. Please check your configuration.".format(v, list(values.keys())))
+                     f"Available keys: {v}. Please check your configuration.",
+                 )
          return sum([i for v, i in values.items() if v in self.monitor])
-
+
      def __call__(self, current_score: float) -> bool:
          if self.best_score is None:
              self.best_score = current_score
@@ -80,14 +105,47 @@ class EarlyStopping(EarlyStoppingBase):

          return self.early_stop

- class _Trainer():

-     def __init__(self, world_size: int, global_rank: int, local_rank: int, size: int, train_name: str, early_stopping: EarlyStopping, data_log: Union[list[str], None] , save_checkpoint_mode: str, epochs: int, epoch: int, autocast: bool, it_validation: Union[int, None], it: int, model: Union[DDP, CPU_Model], modelEMA: AveragedModel, dataloader_training: DataLoader, dataloader_validation: Union[DataLoader, None] = None) -> None:
-         self.world_size = world_size
+ class _Trainer:
+     """
+     Internal class for managing the training loop in a distributed or standalone setting.
+
+     Handles:
+     - Epoch iteration with training and optional validation
+     - Mixed precision support (autocast)
+     - Exponential Moving Average (EMA) model tracking
+     - Early stopping
+     - Logging to TensorBoard
+     - Model checkpoint saving and selection (ALL or BEST)
+
+     This class is intended to be used via a context manager
+     (`with _Trainer(...) as trainer:`) inside the public `Trainer` class.
+     """
+
+     def __init__(
+         self,
+         world_size: int,
+         global_rank: int,
+         local_rank: int,
+         size: int,
+         train_name: str,
+         early_stopping: EarlyStopping | None,
+         data_log: list[str] | None,
+         save_checkpoint_mode: str,
+         epochs: int,
+         epoch: int,
+         autocast: bool,
+         it_validation: int | None,
+         it: int,
+         model: DDP | CPUModel,
+         model_ema: AveragedModel,
+         dataloader_training: DataLoader,
+         dataloader_validation: DataLoader | None = None,
+     ) -> None:
+         self.world_size = world_size
          self.global_rank = global_rank
          self.local_rank = local_rank
          self.size = size
-
          self.save_checkpoint_mode = save_checkpoint_mode
          self.train_name = train_name
          self.epochs = epochs
@@ -96,58 +154,97 @@ class _Trainer():
          self.dataloader_training = dataloader_training
          self.dataloader_validation = dataloader_validation
          self.autocast = autocast
-         self.modelEMA = modelEMA
-         self.early_stopping = EarlyStoppingBase() if early_stopping is None else early_stopping
-
-         self.it_validation = it_validation
-         if self.it_validation is None:
-             self.it_validation = len(dataloader_training)
+         self.model_ema = model_ema
+         self.early_stopping = EarlyStoppingBase() if early_stopping is None else early_stopping
+
+         self.it_validation = len(dataloader_training) if it_validation is None else it_validation
          self.it = it
-         self.tb = SummaryWriter(log_dir = STATISTICS_DIRECTORY()+self.train_name+"/")
-         self.data_log : dict[str, tuple[DataLog, int]] = {}
+         self.tb = SummaryWriter(log_dir=statistics_directory() + self.train_name + "/")
+         self.data_log: dict[str, tuple[DataLog, int]] = {}
          if data_log is not None:
              for data in data_log:
-                 self.data_log[data.split("/")[0].replace(":", ".")] = (DataLog.__getitem__(data.split("/")[1]).value[0], int(data.split("/")[2]))
-
+                 self.data_log[data.split("/")[0].replace(":", ".")] = (
+                     DataLog[data.split("/")[1]],
+                     int(data.split("/")[2]),
+                 )
+
      def __enter__(self):
          return self
-
-     def __exit__(self, type, value, traceback):
+
+     def __exit__(self, exc_type, value, traceback):
+         """Closes the SummaryWriter if used."""
          if self.tb is not None:
              self.tb.close()

      def run(self) -> None:
+         """
+         Launches the training loop, performing one epoch at a time.
+         Triggers early stopping and resets data augmentations between epochs.
+         """
          self.dataloader_training.dataset.load("Train")
          if self.dataloader_validation is not None:
              self.dataloader_validation.dataset.load("Validation")
-         with tqdm.tqdm(iterable = range(self.epoch, self.epochs), leave=False, total=self.epochs, initial=self.epoch, desc="Progress") as epoch_tqdm:
+         if State[konfai_state()] != State.TRAIN:
+             self._validate()
+
+         with tqdm.tqdm(
+             iterable=range(self.epoch, self.epochs),
+             leave=False,
+             total=self.epochs,
+             initial=self.epoch,
+             desc="Progress",
+         ) as epoch_tqdm:
              for self.epoch in epoch_tqdm:
                  self.train()
-                 if self.early_stopping.isStopped():
+                 if self.early_stopping.is_stopped():
                      break
-                 self.dataloader_training.dataset.resetAugmentation("Train")
-
-     def getInput(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int, str, bool]]) -> dict[tuple[str, bool], torch.Tensor]:
-         return {(k, v[5][0].item()) : v[0] for k, v in data_dict.items()}
+                 self.dataloader_training.dataset.reset_augmentation("Train")

-     def train(self) -> None:
-         self.model.train()
-         self.model.module.setState(NetState.TRAIN)
-         if self.modelEMA is not None:
-             self.modelEMA.eval()
-             self.modelEMA.module.setState(NetState.TRAIN)
+     def get_input(
+         self,
+         data_dict: dict[str, tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, list[str], torch.Tensor]],
+     ) -> dict[tuple[str, bool], torch.Tensor]:
+         """
+         Extracts input tensors from the data dict for model input.

-         desc = lambda : "Training : {}".format(description(self.model, self.modelEMA))
+         Args:
+             data_dict (dict): Dictionary with data items structured as tuples.

-         with tqdm.tqdm(iterable = enumerate(self.dataloader_training), desc = desc(), total=len(self.dataloader_training), leave=False, ncols=0) as batch_iter:
+         Returns:
+             dict: Mapping from (group, bool_flag) to input tensors.
+         """
+         return {(k, v[5][0].item()): v[0] for k, v in data_dict.items()}
+
+     def train(self) -> None:
+         """
+         Performs a full training epoch with support for:
+         - mixed precision
+         - DDP / CPU training
+         - EMA updates
+         - loss logging and checkpoint saving
+         - validation at configurable iteration interval
+         """
+         self.model.train()
+         self.model.module.set_state(NetState.TRAIN)
+         if self.model_ema is not None:
+             self.model_ema.eval()
+             self.model_ema.module.set_state(NetState.TRAIN)
+
+         with tqdm.tqdm(
+             iterable=enumerate(self.dataloader_training),
+             desc=f"Training : {description(self.model, self.model_ema)}",
+             total=len(self.dataloader_training),
+             leave=False,
+             ncols=0,
+         ) as batch_iter:
              for _, data_dict in batch_iter:
-                 with torch.amp.autocast('cuda', enabled=self.autocast):
-                     input = self.getInput(data_dict)
-                     self.model(input)
+                 with torch.amp.autocast("cuda", enabled=self.autocast):
+                     input_data_dict = self.get_input(data_dict)
+                     self.model(input_data_dict)
                      self.model.module.backward(self.model.module)
-                 if self.modelEMA is not None:
-                     self.modelEMA.update_parameters(self.model)
-                     self.modelEMA.module(input)
+                 if self.model_ema is not None:
+                     self.model_ema.update_parameters(self.model)
+                     self.model_ema.module(input_data_dict)
                  self.it += 1
                  if (self.it) % self.it_validation == 0:
                      loss = self._train_log(data_dict)
@@ -155,76 +252,109 @@ class _Trainer():
                      if self.dataloader_validation is not None:
                          loss = self._validate()
                      self.model.module.update_lr()
-                     score = self.early_stopping.getScore(loss)
+                     score = self.early_stopping.get_score(loss)
                      self.checkpoint_save(score)
                      if self.early_stopping(score):
                          break
-

-                 batch_iter.set_description(desc())
-
-
+                 batch_iter.set_description(f"Training : {description(self.model, self.model_ema)}")
+
      @torch.no_grad()
      def _validate(self) -> float:
+         """
+         Executes the validation phase, evaluates loss and metrics.
+         Updates model states and resets augmentation for validation set.
+
+         Returns:
+             float: Validation loss.
+         """
+         if self.dataloader_validation is None:
+             return 0
          self.model.eval()
-         self.model.module.setState(NetState.PREDICTION)
-         if self.modelEMA is not None:
-             self.modelEMA.module.setState(NetState.PREDICTION)
+         self.model.module.set_state(NetState.PREDICTION)
+         if self.model_ema is not None:
+             self.model_ema.module.set_state(NetState.PREDICTION)

-         desc = lambda : "Validation : {}".format(description(self.model, self.modelEMA))
          data_dict = None
-         with tqdm.tqdm(iterable = enumerate(self.dataloader_validation), desc = desc(), total=len(self.dataloader_validation), leave=False, ncols=0) as batch_iter:
+         with tqdm.tqdm(
+             iterable=enumerate(self.dataloader_validation),
+             desc=f"Validation : {description(self.model, self.model_ema)}",
+             total=len(self.dataloader_validation),
+             leave=False,
+             ncols=0,
+         ) as batch_iter:
              for _, data_dict in batch_iter:
-                 input = self.getInput(data_dict)
-                 self.model(input)
-                 if self.modelEMA is not None:
-                     self.modelEMA.module(input)
+                 input_data_dict = self.get_input(data_dict)
+                 self.model(input_data_dict)
+                 if self.model_ema is not None:
+                     self.model_ema.module(input_data_dict)

-                 batch_iter.set_description(desc())
-         self.dataloader_validation.dataset.resetAugmentation("Validation")
+                 batch_iter.set_description(f"Validation : {description(self.model, self.model_ema)}")
+         self.dataloader_validation.dataset.reset_augmentation("Validation")
          dist.barrier()
          self.model.train()
-         self.model.module.setState(NetState.TRAIN)
-         if self.modelEMA is not None:
-             self.modelEMA.module.setState(NetState.TRAIN)
+         self.model.module.set_state(NetState.TRAIN)
+         if self.model_ema is not None:
+             self.model_ema.module.set_state(NetState.TRAIN)
          return self._validation_log(data_dict)

      def checkpoint_save(self, loss: float) -> None:
+         """
+         Saves model and optimizer states. Keeps either all checkpoints or only the best one.
+
+         Args:
+             loss (float): Current loss used for best checkpoint selection.
+         """
          if self.global_rank != 0:
              return

-         path = CHECKPOINTS_DIRECTORY()+self.train_name+"/"
+         path = checkpoints_directory() + self.train_name + "/"
          os.makedirs(path, exist_ok=True)

-         name = DATE() + ".pt"
+         name = current_date() + ".pt"
          save_path = os.path.join(path, name)

          save_dict = {
-             "epoch": self.epoch,
-             "it": self.it,
-             "loss": loss,
-             "Model": self.model.module.state_dict()
+             "epoch": self.epoch,
+             "it": self.it,
+             "loss": loss,
+             "Model": self.model.module.state_dict(),
          }

-         if self.modelEMA is not None:
-             save_dict["Model_EMA"] = self.modelEMA.module.state_dict()
-
-         save_dict.update({'{}_optimizer_state_dict'.format(name): network.optimizer.state_dict() for name, network in self.model.module.getNetworks().items() if network.optimizer is not None})
-         save_dict.update({'{}_it'.format(name): network._it for name, network in self.model.module.getNetworks().items() if network.optimizer is not None})
-         save_dict.update({'{}_nb_lr_update'.format(name): network._nb_lr_update for name, network in self.model.module.getNetworks().items() if network.optimizer is not None})
-
+         if self.model_ema is not None:
+             save_dict["Model_EMA"] = self.model_ema.module.state_dict()
+
+         save_dict.update(
+             {
+                 f"{name}_optimizer_state_dict": network.optimizer.state_dict()
+                 for name, network in self.model.module.get_networks().items()
+                 if network.optimizer is not None
+             }
+         )
+         save_dict.update(
+             {
+                 f"{name}_it": network._it
+                 for name, network in self.model.module.get_networks().items()
+                 if network.optimizer is not None
+             }
+         )
+         save_dict.update(
+             {
+                 f"{name}_nb_lr_update": network._nb_lr_update
+                 for name, network in self.model.module.get_networks().items()
+                 if network.optimizer is not None
+             }
+         )
+
          torch.save(save_dict, save_path)

          if self.save_checkpoint_mode == "BEST":
-             all_checkpoints = sorted([
-                 os.path.join(path, f)
-                 for f in os.listdir(path) if f.endswith(".pt")
-             ])
+             all_checkpoints = sorted([os.path.join(path, f) for f in os.listdir(path) if f.endswith(".pt")])
              best_ckpt = None
-             best_loss = float('inf')
+             best_loss = float("inf")

              for f in all_checkpoints:
-                 d = torch.load(f, weights_only=False)
+                 d = torch.load(f, weights_only=True)
                  if d.get("loss", float("inf")) < best_loss:
                      best_loss = d["loss"]
                      best_ckpt = f
@@ -234,79 +364,165 @@ class _Trainer():
                          os.remove(f)

      @torch.no_grad()
-     def _log(self, type_log: str, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]) -> dict[str, float]:
-         models: dict[str, Network] = {"" : self.model.module}
-         if self.modelEMA is not None:
-             models["_EMA"] = self.modelEMA.module
-
-         measures = DistributedObject.getMeasure(self.world_size, self.global_rank, self.local_rank*self.size+self.size-1, models, self.it_validation if type_log == "Training" else len(self.dataloader_validation))
-
+     def _log(
+         self,
+         type_log: str,
+         data_dict: dict[str, tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, list[str], torch.Tensor]],
+     ) -> dict[str, float] | None:
+         """
+         Logs losses, metrics and optionally images to TensorBoard.
+
+         Args:
+             type_log (str): "Training" or "Validation".
+             data_dict (dict): Dictionary of data from current batch.
+
+         Returns:
+             dict[str, float] | None: Dictionary of aggregated losses and metrics if rank == 0.
+         """
+         models: dict[str, Network] = {"": self.model.module}
+         if self.model_ema is not None:
+             models["_EMA"] = self.model_ema.module
+
+         measures = DistributedObject.get_measure(
+             self.world_size,
+             self.global_rank,
+             self.local_rank * self.size + self.size - 1,
+             models,
+             (
+                 self.it_validation
+                 if type_log == "Training" or self.dataloader_validation is None
+                 else len(self.dataloader_validation)
+             ),
+         )
+
          if self.global_rank == 0:
              images_log = []
-             if len(self.data_log):
+             if len(self.data_log):
                  for name, data_type in self.data_log.items():
                      if name in data_dict:
-                         data_type[0](self.tb, "{}/{}".format(type_log ,name), data_dict[name][0][:self.data_log[name][1]].detach().cpu().numpy(), self.it)
+                         data_type[0](
+                             self.tb,
+                             f"{type_log}/{name}",
+                             data_dict[name][0][: self.data_log[name][1]].detach().cpu().numpy(),
+                             self.it,
+                         )
                      else:
                          images_log.append(name.replace(":", "."))
-
+
              for label, model in models.items():
-                 for name, network in model.getNetworks().items():
+                 for name, network in model.get_networks().items():
                      if network.measure is not None:
-                         self.tb.add_scalars("{}/{}/Loss/{}".format(type_log, name, label), {k : v[1] for k, v in measures["{}{}".format(name, label)][0].items()}, self.it)
-                         self.tb.add_scalars("{}/{}/Loss_weight/{}".format(type_log, name, label), {k : v[0] for k, v in measures["{}{}".format(name, label)][0].items()}, self.it)
+                         self.tb.add_scalars(
+                             f"{type_log}/{name}/Loss/{label}",
+                             {k: v[1] for k, v in measures[f"{name}{label}"][0].items()},
+                             self.it,
+                         )
+                         self.tb.add_scalars(
+                             f"{type_log}/{name}/Loss_weight/{label}",
+                             {k: v[0] for k, v in measures[f"{name}{label}"][0].items()},
+                             self.it,
+                         )
+
+                         self.tb.add_scalars(
+                             f"{type_log}/{name}/Metric/{label}",
+                             {k: v[1] for k, v in measures[f"{name}{label}"][1].items()},
+                             self.it,
+                         )
+                         self.tb.add_scalars(
+                             f"{type_log}/{name}/Metric_weight/{label}",
+                             {k: v[0] for k, v in measures[f"{name}{label}"][1].items()},
+                             self.it,
+                         )

-                         self.tb.add_scalars("{}/{}/Metric/{}".format(type_log, name, label), {k : v[1] for k, v in measures["{}{}".format(name, label)][1].items()}, self.it)
-                         self.tb.add_scalars("{}/{}/Metric_weight/{}".format(type_log, name, label), {k : v[0] for k, v in measures["{}{}".format(name, label)][1].items()}, self.it)
-
                  if len(images_log):
-                     for name, layer, _ in model.get_layers([v.to(0) for k, v in self.getInput(data_dict).items() if k[1]], images_log):
-                         self.data_log[name][0](self.tb, "{}/{}{}".format(type_log, name, label), layer[:self.data_log[name][1]].detach().cpu().numpy(), self.it)
-
+                     for name, layer, _ in model.get_layers(
+                         [v.to(0) for k, v in self.get_input(data_dict).items() if k[1]],
+                         images_log,
+                     ):
+                         self.data_log[name][0](
+                             self.tb,
+                             f"{type_log}/{name}{label}",
+                             layer[: self.data_log[name][1]].detach().cpu().numpy(),
+                             self.it,
+                         )
+
          if type_log == "Training":
-             for name, network in self.model.module.getNetworks().items():
+             for name, network in self.model.module.get_networks().items():
                  if network.optimizer is not None:
-                     self.tb.add_scalar("{}/{}/Learning Rate".format(type_log, name), network.optimizer.param_groups[0]['lr'], self.it)
-
+                     self.tb.add_scalar(
+                         f"{type_log}/{name}/Learning Rate",
+                         network.optimizer.param_groups[0]["lr"],
+                         self.it,
+                     )
+
          if self.global_rank == 0:
              loss = {}
-             for name, network in self.model.module.getNetworks().items():
+             for name, network in self.model.module.get_networks().items():
                  if network.measure is not None:
-                     loss.update({k : v[1] for k, v in measures["{}{}".format(name, label)][0].items()})
-                     loss.update({k : v[1] for k, v in measures["{}{}".format(name, label)][1].items()})
+                     loss.update({k: v[1] for k, v in measures[f"{name}{label}"][0].items()})
+                     loss.update({k: v[1] for k, v in measures[f"{name}{label}"][1].items()})
              return loss
          return None
-
+
      @torch.no_grad()
-     def _train_log(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]) -> dict[str, float]:
+     def _train_log(self, data_dict: dict[str, tuple[torch.Tensor, int, int, int]]) -> dict[str, float]:
+         """Wrapper for _log during training."""
          return self._log("Training", data_dict)

      @torch.no_grad()
-     def _validation_log(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]) -> dict[str, float]:
+     def _validation_log(self, data_dict: dict[str, tuple[torch.Tensor, int, int, int]]) -> dict[str, float]:
+         """Wrapper for _log during validation."""
          return self._log("Validation", data_dict)

-
+
  class Trainer(DistributedObject):
+     """
+     Public API for training a model using the KonfAI framework.
+     Wraps setup, checkpointing, resuming, logging, and launching distributed _Trainer.
+
+     Main responsibilities:
+     - Initialization from config (via @config)
+     - Model and EMA setup
+     - Checkpoint loading and saving
+     - Distributed setup and launch
+
+     Args:
+         model (ModelLoader): Loader for model architecture.
+         dataset (DataTrain): Training/validation dataset.
+         train_name (str): Training session name.
+         manual_seed (int | None): Random seed.
+         epochs (int): Number of epochs to run.
+         it_validation (int | None): Validation interval.
+         autocast (bool): Enable AMP training.
+         gradient_checkpoints (list[str] | None): Modules to use gradient checkpointing on.
+         gpu_checkpoints (list[str] | None): Modules to pin on specific GPUs.
+         ema_decay (float): EMA decay factor.
+         data_log (list[str] | None): Logging instructions.
+         early_stopping (EarlyStopping | None): Optional early stopping config.
+         save_checkpoint_mode (str): Either "BEST" or "ALL".
+     """

      @config("Trainer")
-     def __init__( self,
-                   model : ModelLoader = ModelLoader(),
-                   dataset : DataTrain = DataTrain(),
-                   train_name : str = "default:TRAIN_01",
-                   manual_seed : Union[int, None] = None,
-                   epochs: int = 100,
-                   it_validation : Union[int, None] = None,
-                   autocast : bool = False,
-                   gradient_checkpoints: Union[list[str], None] = None,
-                   gpu_checkpoints: Union[list[str], None] = None,
-                   ema_decay : float = 0,
-                   data_log: Union[list[str], None] = None,
-                   early_stopping: Union[EarlyStopping, None] = None,
-                   save_checkpoint_mode: str= "BEST") -> None:
+     def __init__(
+         self,
+         model: ModelLoader = ModelLoader(),
+         dataset: DataTrain = DataTrain(),
+         train_name: str = "default:TRAIN_01",
+         manual_seed: int | None = None,
+         epochs: int = 100,
+         it_validation: int | None = None,
+         autocast: bool = False,
+         gradient_checkpoints: list[str] | None = None,
+         gpu_checkpoints: list[str] | None = None,
+         ema_decay: float = 0,
+         data_log: list[str] | None = None,
+         early_stopping: EarlyStopping | None = None,
+         save_checkpoint_mode: str = "BEST",
+     ) -> None:
          if os.environ["KONFAI_CONFIG_MODE"] != "Done":
              exit(0)
          super().__init__(train_name)
-         self.manual_seed = manual_seed
+         self.manual_seed = manual_seed
          self.dataset = dataset
          self.autocast = autocast
          self.epochs = epochs
@@ -314,118 +530,189 @@ class Trainer(DistributedObject):
          self.early_stopping = early_stopping
          self.it = 0
          self.it_validation = it_validation
-         self.model = model.getModel(train=True)
+         self.model = model.get_model(train=True)
          self.ema_decay = ema_decay
-         self.modelEMA : Union[torch.optim.swa_utils.AveragedModel, None] = None
+         self.model_ema: torch.optim.swa_utils.AveragedModel | None = None
          self.data_log = data_log

          modules = []
-         for i,_ in self.model.named_modules():
+         for i, _ in self.model.named_modules():
              modules.append(i)
-         for k in self.data_log:
-             tmp = k.split("/")[0].replace(":", ".")
-             if tmp not in self.dataset.getGroupsDest() and tmp not in modules:
-                 raise TrainerError( f"Invalid key '{tmp}' in `data_log`.",
-                     f"This key is neither a destination group from the dataset ({self.dataset.getGroupsDest()})",
-                     f"nor a valid module name in the model ({modules}).",
-                     "Please check your `data_log` configuration — it should reference either a model output or a dataset group.")
-
+         if self.data_log is not None:
+             for k in self.data_log:
+                 tmp = k.split("/")[0].replace(":", ".")
+                 if tmp not in self.dataset.get_groups_dest() and tmp not in modules:
+                     raise TrainerError(
+                         f"Invalid key '{tmp}' in `data_log`.",
+                         f"This key is neither a destination group from the dataset ({self.dataset.get_groups_dest()})",
+                         f"nor a valid module name in the model ({modules}).",
+                         "Please check your `data_log` configuration,"
+                         " it should reference either a model output or a dataset group.",
+                     )
+
          self.gradient_checkpoints = gradient_checkpoints
          self.gpu_checkpoints = gpu_checkpoints
          self.save_checkpoint_mode = save_checkpoint_mode
-         self.config_namefile_src = CONFIG_FILE().replace(".yml", "")
-         self.config_namefile = SETUPS_DIRECTORY()+self.name+"/"+self.config_namefile_src.split("/")[-1]+"_"+str(self.it)+".yml"
-         self.size = (len(self.gpu_checkpoints)+1 if self.gpu_checkpoints else 1)
-
-     def __exit__(self, type, value, traceback):
-         super().__exit__(type, value, traceback)
+         self.config_namefile_src = config_file().replace(".yml", "")
+         self.config_namefile = (
+             setups_directory() + self.name + "/" + self.config_namefile_src.split("/")[-1] + "_" + str(self.it) + ".yml"
+         )
+         self.size = len(self.gpu_checkpoints) + 1 if self.gpu_checkpoints else 1
+
+     def __exit__(self, exc_type, value, traceback):
+         """Exit training context and trigger save of model/checkpoints."""
+         super().__exit__(exc_type, value, traceback)
          self._save()

      def _load(self) -> dict[str, dict[str, torch.Tensor]]:
-         if MODEL().startswith("https://"):
+         """
+         Loads a previously saved checkpoint from local disk or URL.
+
+         Returns:
+             dict: State dictionary loaded from checkpoint.
+         """
+         if path_to_models().startswith("https://"):
              try:
-                 state_dict = {MODEL().split(":")[1]: torch.hub.load_state_dict_from_url(url=MODEL().split(":")[0], map_location="cpu", check_hash=True)}
-             except:
-                 raise Exception("Model : {} does not exist !".format(MODEL()))
+                 state_dict = {
+                     path_to_models().split(":")[1]: torch.hub.load_state_dict_from_url(
+                         url=path_to_models().split(":")[0], map_location="cpu", check_hash=True
+                     )
+                 }
+             except Exception:
+                 raise Exception(f"Model : {path_to_models()} does not exist !")
          else:
-             if MODEL() != "":
+             if path_to_models() != "":
                  path = ""
-                 name = MODEL()
+                 name = path_to_models()
              else:
-                 path = CHECKPOINTS_DIRECTORY()+self.name+"/"
+                 path = checkpoints_directory() + self.name + "/"
                  if os.listdir(path):
                      name = sorted(os.listdir(path))[-1]
-
-             if os.path.exists(path+name):
-                 state_dict = torch.load(path+name, weights_only=False, map_location="cpu")
+
+             if os.path.exists(path + name):
+                 state_dict = torch.load(path + name, weights_only=True, map_location="cpu")
              else:
-                 raise Exception("Model : {} does not exist !".format(self.name))
-
+                 raise Exception(f"Model : {self.name} does not exist !")
+
          if "epoch" in state_dict:
-             self.epoch = state_dict['epoch']
+             self.epoch = state_dict["epoch"]
          if "it" in state_dict:
-             self.it = state_dict['it']
+             self.it = state_dict["it"]
          return state_dict

      def _save(self) -> None:
-         path_checkpoint = CHECKPOINTS_DIRECTORY()+self.name+"/"
-         path_model = MODELS_DIRECTORY()+self.name+"/"
-         if os.path.exists(path_checkpoint) and os.listdir(path_checkpoint):
-             for dir in [path_model, "{}Serialized/".format(path_model), "{}StateDict/".format(path_model)]:
-                 if not os.path.exists(dir):
-                     os.makedirs(dir)
-
-             for name in sorted(os.listdir(path_checkpoint)):
-                 checkpoint = torch.load(path_checkpoint+name, weights_only=False, map_location='cpu')
-                 self.model.load(checkpoint, init=False, ema=False)
-
-                 torch.save(self.model, "{}Serialized/{}".format(path_model, name))
-                 torch.save({"Model" : self.model.state_dict()}, "{}StateDict/{}".format(path_model, name))
-
-                 if self.modelEMA is not None:
-                     self.modelEMA.module.load(checkpoint, init=False, ema=True)
-                     torch.save(self.modelEMA.module, "{}Serialized/{}".format(path_model, DATE()+"_EMA.pt"))
-                     torch.save({"Model_EMA" : self.modelEMA.module.state_dict()}, "{}StateDict/{}".format(path_model, DATE()+"_EMA.pt"))
-
-         os.rename(self.config_namefile, self.config_namefile.replace(".yml", "")+"_"+str(self.it)+".yml")
-
-     def _avg_fn(self, averaged_model_parameter, model_parameter, num_averaged):
-         return (1-self.ema_decay) * averaged_model_parameter + self.ema_decay * model_parameter
-
+         """
+         Serializes model and EMA model (if any) in both state_dict and full model formats.
+         Also saves optimizer states and YAML config snapshot.
+         """
+         if os.path.exists(self.config_namefile):
+             os.rename(
+                 self.config_namefile,
+                 self.config_namefile.replace(".yml", "") + "_" + str(self.it) + ".yml",
+             )
+
+     def _avg_fn(self, averaged_model_parameter: float, model_parameter, num_averaged):
+         """
+         EMA update rule used by AveragedModel.
+
+         Returns:
+             torch.Tensor: Blended parameter using decay factor.
+         """
+         return (1 - self.ema_decay) * averaged_model_parameter + self.ema_decay * model_parameter
+
      def setup(self, world_size: int):
-         state = State._member_map_[KONFAI_STATE()]
-         if state != State.RESUME and os.path.exists(CHECKPOINTS_DIRECTORY()+self.name+"/"):
+         """
+         Initializes the training environment:
+         - Clears previous outputs (unless resuming)
+         - Initializes model and EMA
+         - Loads checkpoint (if resuming)
+         - Prepares dataloaders
+
+         Args:
+             world_size (int): Total number of distributed processes.
+         """
+         state = State[konfai_state()]
+         print(checkpoints_directory() + self.name + "/")
+         if state != State.RESUME and os.path.exists(checkpoints_directory() + self.name + "/"):
              if os.environ["KONFAI_OVERWRITE"] != "True":
-                 accept = input("The model {} already exists ! Do you want to overwrite it (yes,no) : ".format(self.name))
+                 accept = input(f"The model {self.name} already exists ! Do you want to overwrite it (yes,no) : ")
                  if accept != "yes":
                      return
-             for directory_path in [STATISTICS_DIRECTORY(), MODELS_DIRECTORY(), CHECKPOINTS_DIRECTORY(), SETUPS_DIRECTORY()]:
-                 if os.path.exists(directory_path+self.name+"/"):
-                     shutil.rmtree(directory_path+self.name+"/")
-
+             for directory_path in [
+                 statistics_directory(),
+                 checkpoints_directory(),
+                 setups_directory(),
+             ]:
+                 if os.path.exists(directory_path + self.name + "/"):
+                     shutil.rmtree(directory_path + self.name + "/")
+
          state_dict = {}
          if state != State.TRAIN:
              state_dict = self._load()

-         self.model.init(self.autocast, state, self.dataset.getGroupsDest())
-         self.model.init_outputsGroup()
-         self.model._compute_channels_trace(self.model, self.model.in_channels, self.gradient_checkpoints, self.gpu_checkpoints)
+         self.model.init(self.autocast, state, self.dataset.get_groups_dest())
+         self.model.init_outputs_group()
+         self.model._compute_channels_trace(
+             self.model,
+             self.model.in_channels,
+             self.gradient_checkpoints,
+             self.gpu_checkpoints,
+         )
          self.model.load(state_dict, init=True, ema=False)
          if self.ema_decay > 0:
-             self.modelEMA = AveragedModel(self.model, avg_fn=self._avg_fn)
+             self.model_ema = AveragedModel(self.model, avg_fn=self._avg_fn)
              if state_dict is not None:
-                 self.modelEMA.module.load(state_dict, init=False, ema=True)
-
-         if not os.path.exists(SETUPS_DIRECTORY()+self.name+"/"):
-             os.makedirs(SETUPS_DIRECTORY()+self.name+"/")
-         shutil.copyfile(self.config_namefile_src+".yml", self.config_namefile)
-
-         self.dataloader = self.dataset.getData(world_size//self.size)
-
-     def run_process(self, world_size: int, global_rank: int, local_rank: int, dataloaders: list[DataLoader]):
-         model = Network.to(self.model, local_rank*self.size)
-         model = DDP(model, static_graph=True) if torch.cuda.is_available() else CPU_Model(model)
-         if self.modelEMA is not None:
-             self.modelEMA.module = Network.to(self.modelEMA.module, local_rank)
-         with _Trainer(world_size, global_rank, local_rank, self.size, self.name, self.early_stopping, self.data_log, self.save_checkpoint_mode, self.epochs, self.epoch, self.autocast, self.it_validation, self.it, model, self.modelEMA, *dataloaders) as t:
-             t.run()
+                 self.model_ema.module.load(state_dict, init=False, ema=True)
+
+         if not os.path.exists(setups_directory() + self.name + "/"):
+             os.makedirs(setups_directory() + self.name + "/")
+         shutil.copyfile(self.config_namefile_src + ".yml", self.config_namefile)
+
+         self.dataloader, train_names, validation_names = self.dataset.get_data(world_size // self.size)
+         with open(setups_directory() + self.name + "/Train_" + str(self.it) + ".txt", "w") as f:
+             for name in train_names:
+                 f.write(name + "\n")
+         with open(setups_directory() + self.name + "/Validation_" + str(self.it) + ".txt", "w") as f:
+             for name in validation_names:
+                 f.write(name + "\n")
+
+     def run_process(
+         self,
+         world_size: int,
+         global_rank: int,
+         local_rank: int,
+         dataloaders: list[DataLoader],
+     ):
+         """
+         Launches the actual training process via internal `_Trainer` class.
+         Wraps model with DDP or CPU fallback, attaches EMA, and starts training.
+
+         Args:
+             world_size (int): Total number of distributed processes.
+             global_rank (int): Global rank of the current process.
+             local_rank (int): Local rank within the node.
+             dataloaders (list[DataLoader]): Training and validation dataloaders.
+         """
+         model = Network.to(self.model, local_rank * self.size)
+         model = DDP(model, static_graph=True) if torch.cuda.is_available() else CPUModel(model)
+         if self.model_ema is not None:
+             self.model_ema.module = Network.to(self.model_ema.module, local_rank)
+         with _Trainer(
+             world_size,
+             global_rank,
+             local_rank,
+             self.size,
+             self.name,
+             self.early_stopping,
+             self.data_log,
+             self.save_checkpoint_mode,
+             self.epochs,
+             self.epoch,
+             self.autocast,
+             self.it_validation,
+             self.it,
+             model,
+             self.model_ema,
+             *dataloaders,
+         ) as t:
+             t.run()
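The dominant change in konfai/trainer.py between 1.1.8 and 1.2.0 is a camelCase → snake_case rename of the public hooks (isStopped/getScore become is_stopped/get_score, modelEMA becomes model_ema, CPU_Model becomes CPUModel), plus stricter checkpoint loading (torch.load with weights_only=True). As a hedged sketch of what that means for downstream code that subclasses EarlyStopping: the subclass and the "Dice" metric key below are hypothetical, only the renamed methods come from the diff above.

# Hypothetical subclass, illustrating only the 1.1.8 -> 1.2.0 rename shown above.
from konfai.trainer import EarlyStopping

class DiceEarlyStopping(EarlyStopping):

    # konfai 1.1.8 called isStopped() / getScore(); in 1.2.0 the trainer calls
    # is_stopped() / get_score(), so overrides must use the new names.
    def get_score(self, values: dict[str, float]) -> float:
        # "Dice" is a made-up metric key used purely for illustration; fall back
        # to the default sum of logged values when it is absent.
        return values.get("Dice", super().get_score(values))

The EMA bookkeeping itself is unchanged in substance: _avg_fn still blends parameters as (1 - ema_decay) * averaged + ema_decay * current; only the attribute holding the averaged model was renamed from modelEMA to model_ema.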