konfai-1.1.8-py3-none-any.whl → konfai-1.2.0-py3-none-any.whl

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of konfai has been flagged as potentially problematic; see the registry's advisory page for more details.

Files changed (36)
  1. konfai/__init__.py +59 -14
  2. konfai/data/augmentation.py +457 -286
  3. konfai/data/data_manager.py +533 -316
  4. konfai/data/patching.py +300 -183
  5. konfai/data/transform.py +408 -275
  6. konfai/evaluator.py +325 -68
  7. konfai/main.py +71 -22
  8. konfai/metric/measure.py +360 -244
  9. konfai/metric/schedulers.py +24 -13
  10. konfai/models/classification/convNeXt.py +187 -81
  11. konfai/models/classification/resnet.py +272 -58
  12. konfai/models/generation/cStyleGan.py +233 -59
  13. konfai/models/generation/ddpm.py +348 -121
  14. konfai/models/generation/diffusionGan.py +757 -358
  15. konfai/models/generation/gan.py +177 -53
  16. konfai/models/generation/vae.py +140 -40
  17. konfai/models/registration/registration.py +135 -52
  18. konfai/models/representation/representation.py +57 -23
  19. konfai/models/segmentation/NestedUNet.py +339 -68
  20. konfai/models/segmentation/UNet.py +140 -30
  21. konfai/network/blocks.py +331 -187
  22. konfai/network/network.py +795 -427
  23. konfai/predictor.py +644 -238
  24. konfai/trainer.py +509 -222
  25. konfai/utils/ITK.py +191 -106
  26. konfai/utils/config.py +152 -95
  27. konfai/utils/dataset.py +326 -455
  28. konfai/utils/utils.py +497 -249
  29. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/METADATA +1 -3
  30. konfai-1.2.0.dist-info/RECORD +38 -0
  31. konfai/utils/registration.py +0 -199
  32. konfai-1.1.8.dist-info/RECORD +0 -39
  33. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/WHEEL +0 -0
  34. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/entry_points.txt +0 -0
  35. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/licenses/LICENSE +0 -0
  36. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/top_level.txt +0 -0
konfai/network/network.py CHANGED
@@ -1,117 +1,181 @@
1
- from functools import partial
2
1
  import importlib
3
2
  import inspect
4
3
  import os
5
- from typing import Iterable, Iterator, Callable
6
- from typing_extensions import Self
7
- import torch
8
4
  from abc import ABC
5
+ from collections import OrderedDict
6
+ from collections.abc import Callable, Iterable, Iterator, Sequence
7
+ from enum import Enum
8
+ from functools import partial
9
+ from typing import Any, Self
10
+
9
11
  import numpy as np
12
+ import torch
10
13
  from torch._jit_internal import _copy_to_script_wrapper
11
- from collections import OrderedDict
12
14
  from torch.utils.checkpoint import checkpoint
13
- from typing import Union
14
- from enum import Enum
15
15
 
16
- from konfai import KONFAI_ROOT
16
+ from konfai import konfai_root
17
+ from konfai.data.patching import Accumulator, ModelPatch
17
18
  from konfai.metric.schedulers import Scheduler
18
19
  from konfai.utils.config import config
19
- from konfai.utils.utils import State, _getModule, getDevice, getGPUMemory, MeasureError, TrainerError
20
- from konfai.data.patching import Accumulator, ModelPatch
20
+ from konfai.utils.utils import MeasureError, State, TrainerError, get_device, get_gpu_memory, get_module
21
+
21
22
 
22
23
  class NetState(Enum):
23
- TRAIN = 0,
24
+ TRAIN = (0,)
24
25
  PREDICTION = 1
25
26
 
26
- class Patch_Indexed():
27
+
28
+ class PatchIndexed:
27
29
 
28
30
  def __init__(self, patch: ModelPatch, index: int) -> None:
29
31
  self.patch = patch
30
32
  self.index = index
31
33
 
32
- def isFull(self) -> bool:
33
- return len(self.patch.getPatch_slices(0)) == self.index
34
+ def is_full(self) -> bool:
35
+ return len(self.patch.get_patch_slices(0)) == self.index
36
+
37
+
38
+ class OptimizerLoader:
34
39
 
35
- class OptimizerLoader():
36
-
37
40
  @config("Optimizer")
38
41
  def __init__(self, name: str = "AdamW") -> None:
39
42
  self.name = name
40
-
41
- def getOptimizer(self, key: str, parameter: Iterator[torch.nn.parameter.Parameter]) -> torch.optim.Optimizer:
42
- return config("{}.Model.{}.Optimizer".format(KONFAI_ROOT(), key))(getattr(importlib.import_module('torch.optim'), self.name))(parameter, config = None)
43
-
44
43
 
45
- class LRSchedulersLoader():
46
-
44
+ def get_optimizer(self, key: str, parameter: Iterator[torch.nn.parameter.Parameter]) -> torch.optim.Optimizer:
45
+ return config(f"{konfai_root()}.Model.{key}.Optimizer")(
46
+ getattr(importlib.import_module("torch.optim"), self.name)
47
+ )(parameter, config=None)
48
+
49
+
50
+ class LRSchedulersLoader:
51
+
47
52
  @config()
48
- def __init__(self, nb_step : int = 0) -> None:
53
+ def __init__(self, nb_step: int = 0) -> None:
49
54
  self.nb_step = nb_step
50
55
 
51
- def getschedulers(self, key: str, scheduler_classname: str, optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler._LRScheduler:
56
+ def getschedulers(
57
+ self, key: str, scheduler_classname: str, optimizer: torch.optim.Optimizer
58
+ ) -> torch.optim.lr_scheduler._LRScheduler:
52
59
  for m in ["torch.optim.lr_scheduler", "konfai.metric.schedulers"]:
53
- module, name = _getModule(scheduler_classname, m)
60
+ module, name = get_module(scheduler_classname, m)
54
61
  if hasattr(importlib.import_module(module), name):
55
- return config("{}.Model.{}.schedulers.{}".format(KONFAI_ROOT(), key, scheduler_classname))(getattr(importlib.import_module(module), name))(optimizer, config = None)
56
- raise TrainerError("Unknown scheduler {}, tried importing from: 'torch.optim.lr_scheduler' and 'konfai.metric.schedulers', but no valid match was found. Check your YAML config or scheduler name spelling.".format(scheduler_classname))
62
+ return config(f"{konfai_root()}.Model.{key}.schedulers.{scheduler_classname}")(
63
+ getattr(importlib.import_module(module), name)
64
+ )(optimizer, config=None)
65
+ raise TrainerError(
66
+ f"Unknown scheduler {scheduler_classname}, tried importing from: 'torch.optim.lr_scheduler' and "
67
+ "'konfai.metric.schedulers', but no valid match was found. "
68
+ "Check your YAML config or scheduler name spelling."
69
+ )
70
+
71
+
72
+ class LossSchedulersLoader:
57
73
 
58
- class LossSchedulersLoader():
59
-
60
74
  @config()
61
- def __init__(self, nb_step : int = 0) -> None:
75
+ def __init__(self, nb_step: int = 0) -> None:
62
76
  self.nb_step = nb_step
63
77
 
64
78
  def getschedulers(self, key: str, scheduler_classname: str) -> torch.optim.lr_scheduler._LRScheduler:
65
- return config("{}.{}".format(key, scheduler_classname))(getattr(importlib.import_module('konfai.metric.schedulers'), scheduler_classname))(config = None)
66
-
67
- class CriterionsAttr():
68
-
79
+ return config(f"{key}.{scheduler_classname}")(
80
+ getattr(importlib.import_module("konfai.metric.schedulers"), scheduler_classname)
81
+ )(config=None)
82
+
83
+
84
+ class CriterionsAttr:
85
+
69
86
  @config()
70
- def __init__(self, schedulers: dict[str, LossSchedulersLoader] = {"default:Constant": LossSchedulersLoader(0)}, isLoss: bool = True, group: int = 0, stepStart: int = 0, stepStop: Union[int, None] = None, accumulation: bool = False) -> None:
87
+ def __init__(
88
+ self,
89
+ schedulers: dict[str, LossSchedulersLoader] = {"default:Constant": LossSchedulersLoader(0)},
90
+ is_loss: bool = True,
91
+ group: int = 0,
92
+ start: int = 0,
93
+ stop: int | None = None,
94
+ accumulation: bool = False,
95
+ ) -> None:
71
96
  self.schedulersLoader = schedulers
72
97
  self.isTorchCriterion = True
73
- self.isLoss = isLoss
74
- self.stepStart = stepStart
75
- self.stepStop = stepStop
98
+ self.is_loss = is_loss
99
+ self.start = start
100
+ self.stop = stop
76
101
  self.group = group
77
102
  self.accumulation = accumulation
78
103
  self.schedulers: dict[Scheduler, int] = {}
79
-
80
- class CriterionsLoader():
81
104
 
82
- @config()
83
- def __init__(self, criterionsLoader: dict[str, CriterionsAttr] = {"default:torch_nn_CrossEntropyLoss:Dice:NCC": CriterionsAttr()}) -> None:
84
- self.criterionsLoader = criterionsLoader
85
-
86
- def getCriterions(self, model_classname : str, output_group : str, target_group : str) -> dict[torch.nn.Module, CriterionsAttr]:
87
- criterions = {}
88
- for module_classpath, criterionsAttr in self.criterionsLoader.items():
89
- module, name = _getModule(module_classpath, "konfai.metric.measure")
90
- criterionsAttr.isTorchCriterion = module.startswith("torch")
91
- for scheduler_classname, schedulers in criterionsAttr.schedulersLoader.items():
92
- criterionsAttr.schedulers[schedulers.getschedulers("{}.Model.{}.outputsCriterions.{}.targetsCriterions.{}.criterionsLoader.{}.schedulers".format(KONFAI_ROOT(), model_classname, output_group, target_group, module_classpath), scheduler_classname)] = schedulers.nb_step
93
- criterions[config("{}.Model.{}.outputsCriterions.{}.targetsCriterions.{}.criterionsLoader.{}".format(KONFAI_ROOT(), model_classname, output_group, target_group, module_classpath))(getattr(importlib.import_module(module), name))(config = None)] = criterionsAttr
94
- return criterions
95
105
 
96
- class TargetCriterionsLoader():
106
+ class CriterionsLoader:
97
107
 
98
108
  @config()
99
- def __init__(self, targetsCriterions : dict[str, CriterionsLoader] = {"default" : CriterionsLoader()}) -> None:
100
- self.targetsCriterions = targetsCriterions
101
-
102
- def getTargetsCriterions(self, output_group : str, model_classname : str) -> dict[str, dict[torch.nn.Module, float]]:
103
- targetsCriterions = {}
104
- for target_group, criterionsLoader in self.targetsCriterions.items():
105
- targetsCriterions[target_group] = criterionsLoader.getCriterions(model_classname, output_group, target_group)
106
- return targetsCriterions
109
+ def __init__(
110
+ self,
111
+ criterions_loader: dict[str, CriterionsAttr] = {"default:torch:nn:CrossEntropyLoss:Dice:NCC": CriterionsAttr()},
112
+ ) -> None:
113
+ self.criterions_loader = criterions_loader
114
+
115
+ def get_criterions(
116
+ self, model_classname: str, output_group: str, target_group: str
117
+ ) -> dict[torch.nn.Module, CriterionsAttr]:
118
+ criterions = {}
119
+ for module_classpath, criterions_attr in self.criterions_loader.items():
120
+ module, name = get_module(module_classpath, "konfai.metric.measure")
121
+ criterions_attr.isTorchCriterion = module.startswith("torch")
122
+ for (
123
+ scheduler_classname,
124
+ schedulers,
125
+ ) in criterions_attr.schedulersLoader.items():
126
+ criterions_attr.schedulers[
127
+ schedulers.getschedulers(
128
+ f"{konfai_root()}.Model.{model_classname}.outputs_criterions.{output_group}"
129
+ f".targets_criterions.{target_group}"
130
+ f".criterions_loader.{module_classpath}.schedulers",
131
+ scheduler_classname,
132
+ )
133
+ ] = schedulers.nb_step
134
+ criterions[
135
+ config(
136
+ f"{konfai_root()}.Model.{model_classname}.outputs_criterions."
137
+ f"{output_group}.targets_criterions.{target_group}."
138
+ f"criterions_loader.{module_classpath}"
139
+ )(getattr(importlib.import_module(module), name))(config=None)
140
+ ] = criterions_attr
141
+ return criterions
107
142
 
108
- class Measure():
109
143
 
110
- class Loss():
144
+ class TargetCriterionsLoader:
111
145
 
112
- def __init__(self, name: str, output_group: str, target_group: str, group: int, isLoss: bool, accumulation: bool) -> None:
146
+ @config()
147
+ def __init__(
148
+ self,
149
+ targets_criterions: dict[str, CriterionsLoader] = {"default": CriterionsLoader()},
150
+ ) -> None:
151
+ self.targets_criterions = targets_criterions
152
+
153
+ def get_targets_criterions(
154
+ self, output_group: str, model_classname: str
155
+ ) -> dict[str, dict[torch.nn.Module, CriterionsAttr]]:
156
+ targets_criterions = {}
157
+ for target_group, criterions_loader in self.targets_criterions.items():
158
+ targets_criterions[target_group] = criterions_loader.get_criterions(
159
+ model_classname, output_group, target_group
160
+ )
161
+ return targets_criterions
162
+
163
+
164
+ class Measure:
165
+
166
+ class Loss:
167
+
168
+ def __init__(
169
+ self,
170
+ name: str,
171
+ output_group: str,
172
+ target_group: str,
173
+ group: int,
174
+ is_loss: bool,
175
+ accumulation: bool,
176
+ ) -> None:
113
177
  self.name = name
114
- self.isLoss = isLoss
178
+ self.is_loss = is_loss
115
179
  self.accumulation = accumulation
116
180
  self.output_group = output_group
117
181
  self.target_group = target_group
@@ -120,144 +184,236 @@ class Measure():
120
184
  self._loss: list[torch.Tensor] = []
121
185
  self._weight: list[float] = []
122
186
  self._values: list[float] = []
123
-
124
- def resetLoss(self) -> None:
187
+
188
+ def reset_loss(self) -> None:
125
189
  self._loss.clear()
126
190
 
127
- def add(self, weight: float, value: torch.Tensor) -> None:
128
- if self.isLoss or value != 0:
129
- self._loss.append(value if self.isLoss else value.detach())
130
- self._values.append(value.item())
131
- self._weight.append(weight)
191
+ def add(self, weight: float, value: torch.Tensor | tuple[torch.Tensor, float]) -> None:
192
+ if isinstance(value, tuple):
193
+ loss_value, true_value = value
194
+ else:
195
+ loss_value = value
196
+ true_value = value.item()
197
+
198
+ self._loss.append(loss_value if self.is_loss else loss_value.detach())
199
+ self._values.append(true_value)
200
+ self._weight.append(weight)
201
+
202
+ def get_last_loss(self) -> torch.Tensor:
203
+ return self._loss[-1] * self._weight[-1] if len(self._loss) else torch.zeros((1), requires_grad=True)
132
204
 
133
- def getLastLoss(self) -> torch.Tensor:
134
- return self._loss[-1]*self._weight[-1] if len(self._loss) else torch.zeros((1), requires_grad = True)
135
-
136
- def getLoss(self) -> torch.Tensor:
137
- return torch.stack([w*l for w, l in zip(self._weight, self._loss)], dim=0).mean(dim=0) if len(self._loss) else torch.zeros((1), requires_grad = True)
205
+ def get_loss(self) -> torch.Tensor:
206
+ return (
207
+ torch.stack([w * loss_value for w, loss_value in zip(self._weight, self._loss)], dim=0).mean(dim=0)
208
+ if len(self._loss)
209
+ else torch.zeros((1), requires_grad=True)
210
+ )
138
211
 
139
212
  def __len__(self) -> int:
140
213
  return len(self._loss)
141
-
142
- def __init__(self, model_classname : str, outputsCriterions: dict[str, TargetCriterionsLoader]) -> None:
143
- super().__init__()
144
- self.outputsCriterions = {}
145
- for output_group, targetCriterionsLoader in outputsCriterions.items():
146
- self.outputsCriterions[output_group.replace(":", ".")] = targetCriterionsLoader.getTargetsCriterions(output_group, model_classname)
147
- self._loss : dict[int, dict[str, Measure.Loss]] = {}
148
214
 
149
- def init(self, model : torch.nn.Module, group_dest: list[str]) -> None:
215
+ def __init__(
216
+ self,
217
+ model_classname: str,
218
+ outputs_criterions_loader: dict[str, TargetCriterionsLoader],
219
+ ) -> None:
220
+ super().__init__()
221
+ self.outputs_criterions: dict[str, dict[str, dict[torch.nn.Module, CriterionsAttr]]] = {}
222
+ for output_group, target_criterions_loader in outputs_criterions_loader.items():
223
+ self.outputs_criterions[output_group.replace(":", ".")] = target_criterions_loader.get_targets_criterions(
224
+ output_group, model_classname
225
+ )
226
+ self._loss: dict[int, dict[str, Measure.Loss]] = {}
227
+
228
+ def init(self, model: torch.nn.Module, group_dest: list[str]) -> None:
150
229
  outputs_group_rename = {}
151
-
230
+
152
231
  modules = []
153
- for i,_,_ in model.named_ModuleArgsDict():
232
+ for i, _, _ in model.named_module_args_dict():
154
233
  modules.append(i)
155
234
 
156
- for output_group in self.outputsCriterions.keys():
157
- if output_group not in modules:
158
- raise MeasureError(f"The output group '{output_group}' defined in 'outputsCriterions' does not correspond to any module in the model.",
235
+ for output_group in self.outputs_criterions.keys():
236
+ if output_group.replace(";accu;", "") not in modules:
237
+ raise MeasureError(
238
+ f"The output group '{output_group}' defined in 'outputs_criterions' "
239
+ "does not correspond to any module in the model.",
159
240
  f"Available modules: {modules}",
160
- "Please check that the name matches exactly a submodule or output of your model architecture."
241
+ "Please check that the name matches exactly a submodule or output of your model architecture.",
161
242
  )
162
- for target_group in self.outputsCriterions[output_group]:
163
- for target_group_tmp in target_group.split(";"):
243
+ for target_group in self.outputs_criterions[output_group]:
244
+ for target_group_tmp in target_group.split(";"):
164
245
  if target_group_tmp not in group_dest:
165
246
  raise MeasureError(
166
- f"The target_group '{target_group_tmp}' defined in 'outputsCriterions.{output_group}.targetsCriterions' was not found in the available destination groups.",
167
- "This target_group is expected for loss or metric computation, but was not loaded in 'group_dest'.",
168
- f"Please make sure that the group '{target_group_tmp}' is defined in 'Dataset:groups_src:...:groups_dest:'{target_group_tmp}'' and correctly loaded from the dataset.")
169
- for criterion in self.outputsCriterions[output_group][target_group]:
170
- if not self.outputsCriterions[output_group][target_group][criterion].isTorchCriterion:
247
+ f"The target_group {target_group_tmp} defined in "
248
+ "'outputs_criterions.{output_group}.targets_criterions'"
249
+ " was not found in the available destination groups.",
250
+ "This target_group is expected for loss or metric computation, "
251
+ "but was not loaded in 'group_dest'.",
252
+ f"Please make sure that the group {target_group_tmp} is defined in "
253
+ "Dataset:groups_src:...:groups_dest: {target_group_tmp} "
254
+ "and correctly loaded from the dataset.",
255
+ )
256
+ for criterion in self.outputs_criterions[output_group][target_group]:
257
+ if not self.outputs_criterions[output_group][target_group][criterion].isTorchCriterion:
171
258
  outputs_group_rename[output_group] = criterion.init(model, output_group, target_group)
172
259
 
173
- outputsCriterions_bak = self.outputsCriterions.copy()
260
+ outputs_criterions_bak = self.outputs_criterions.copy()
174
261
  for old, new in outputs_group_rename.items():
175
- self.outputsCriterions.pop(old)
176
- self.outputsCriterions[new] = outputsCriterions_bak[old]
177
- for output_group in self.outputsCriterions:
178
- for target_group in self.outputsCriterions[output_group]:
179
- for criterion, criterionsAttr in self.outputsCriterions[output_group][target_group].items():
180
- if criterionsAttr.group not in self._loss:
181
- self._loss[criterionsAttr.group] = {}
182
- self._loss[criterionsAttr.group]["{}:{}:{}".format(output_group, target_group, criterion.__class__.__name__)] = Measure.Loss(criterion.__class__.__name__, output_group, target_group, criterionsAttr.group, criterionsAttr.isLoss, criterionsAttr.accumulation)
183
-
184
- def update(self, output_group: str, output : torch.Tensor, data_dict: dict[str, torch.Tensor], it: int, nb_patch: int, training: bool) -> None:
185
- for target_group in self.outputsCriterions[output_group]:
186
- target = [data_dict[group].to(output[0].device).detach() for group in target_group.split(";") if group in data_dict]
187
-
188
- for criterion, criterionsAttr in self.outputsCriterions[output_group][target_group].items():
189
- if it >= criterionsAttr.stepStart and (criterionsAttr.stepStop is None or it <= criterionsAttr.stepStop):
190
- scheduler = self.update_scheduler(criterionsAttr.schedulers, it)
191
- self._loss[criterionsAttr.group]["{}:{}:{}".format(output_group, target_group, criterion.__class__.__name__)].add(scheduler.get_value(), criterion(output, *target))
192
- if training and len(np.unique([len(l) for l in self._loss[criterionsAttr.group].values() if l.accumulation and l.isLoss])) == 1:
193
- if criterionsAttr.isLoss:
194
- loss = torch.zeros((1), requires_grad = True)
195
- for v in [l for l in self._loss[criterionsAttr.group].values() if l.accumulation and l.isLoss]:
196
- l = v.getLastLoss()
197
- loss = loss.to(l.device)+l
198
- loss = loss/nb_patch
262
+ self.outputs_criterions.pop(old)
263
+ self.outputs_criterions[new] = outputs_criterions_bak[old]
264
+ for output_group in self.outputs_criterions:
265
+ for target_group in self.outputs_criterions[output_group]:
266
+ for criterion, criterions_attr in self.outputs_criterions[output_group][target_group].items():
267
+ if criterions_attr.group not in self._loss:
268
+ self._loss[criterions_attr.group] = {}
269
+ self._loss[criterions_attr.group][
270
+ f"{output_group}:{target_group}:{criterion.__class__.__name__}"
271
+ ] = Measure.Loss(
272
+ criterion.__class__.__name__,
273
+ output_group,
274
+ target_group,
275
+ criterions_attr.group,
276
+ criterions_attr.is_loss,
277
+ criterions_attr.accumulation,
278
+ )
279
+
280
+ def update(
281
+ self,
282
+ output_group: str,
283
+ output: torch.Tensor,
284
+ data_dict: dict[str, torch.Tensor],
285
+ it: int,
286
+ nb_patch: int,
287
+ training: bool,
288
+ ) -> None:
289
+ for target_group in self.outputs_criterions[output_group]:
290
+ target = [
291
+ data_dict[group].to(output[0].device).detach()
292
+ for group in target_group.split(";")
293
+ if group in data_dict
294
+ ]
295
+
296
+ for criterion, criterions_attr in self.outputs_criterions[output_group][target_group].items():
297
+ if it >= criterions_attr.start and (criterions_attr.stop is None or it <= criterions_attr.stop):
298
+ scheduler = self.update_scheduler(criterions_attr.schedulers, it)
299
+ self._loss[criterions_attr.group][
300
+ f"{output_group}:{target_group}:{criterion.__class__.__name__}"
301
+ ].add(scheduler.get_value(), criterion(output, *target))
302
+ if (
303
+ training
304
+ and len(
305
+ np.unique(
306
+ [
307
+ len(loss_value)
308
+ for loss_value in self._loss[criterions_attr.group].values()
309
+ if loss_value.accumulation and loss_value.is_loss
310
+ ]
311
+ )
312
+ )
313
+ == 1
314
+ ):
315
+ if criterions_attr.is_loss:
316
+ loss = torch.zeros((1), requires_grad=True)
317
+ for v in [
318
+ loss_value
319
+ for loss_value in self._loss[criterions_attr.group].values()
320
+ if loss_value.accumulation and loss_value.is_loss
321
+ ]:
322
+ loss_value = v.get_last_loss()
323
+ loss = loss.to(loss_value.device) + loss_value
324
+ loss = loss / nb_patch
199
325
  loss.backward()
200
326
 
201
- def getLoss(self) -> list[torch.Tensor]:
327
+ def get_loss(self) -> list[torch.Tensor]:
202
328
  loss: dict[int, torch.Tensor] = {}
203
329
  for group in self._loss.keys():
204
- loss[group] = torch.zeros((1), requires_grad = True)
330
+ loss[group] = torch.zeros((1), requires_grad=True)
205
331
  for v in self._loss[group].values():
206
- if v.isLoss and not v.accumulation:
207
- l = v.getLoss()
208
- loss[v.group] = loss[v.group].to(l.device)+l
209
- return loss.values()
332
+ if v.is_loss and not v.accumulation:
333
+ loss_value = v.get_loss()
334
+ loss[v.group] = loss[v.group].to(loss_value.device) + loss_value
335
+ return list(loss.values())
210
336
 
211
- def resetLoss(self) -> None:
337
+ def reset_loss(self) -> None:
212
338
  for group in self._loss.keys():
213
339
  for v in self._loss[group].values():
214
- v.resetLoss()
215
-
216
- def getLastValues(self, n: int = 1) -> dict[str, float]:
340
+ v.reset_loss()
341
+
342
+ def get_last_values(self, n: int = 1) -> dict[str, float]:
217
343
  result = {}
218
344
  for group in self._loss.keys():
219
- result.update({name : np.nanmean(value._values[-n:] if n > 0 else value._values) for name, value in self._loss[group].items() if n < 0 or len(value._values) >= n})
345
+ result.update(
346
+ {
347
+ name: np.nanmean(value._values[-n:] if n > 0 else value._values)
348
+ for name, value in self._loss[group].items()
349
+ if n < 0 or len(value._values) >= n
350
+ }
351
+ )
220
352
  return result
221
-
222
- def getLastWeights(self, n: int = 1) -> dict[str, float]:
353
+
354
+ def get_last_weights(self, n: int = 1) -> dict[str, float]:
223
355
  result = {}
224
356
  for group in self._loss.keys():
225
- result.update({name : np.nanmean(value._weight[-n:] if n > 0 else value._weight) for name, value in self._loss[group].items() if n < 0 or len(value._values) >= n})
357
+ result.update(
358
+ {
359
+ name: np.nanmean(value._weight[-n:] if n > 0 else value._weight)
360
+ for name, value in self._loss[group].items()
361
+ if n < 0 or len(value._values) >= n
362
+ }
363
+ )
226
364
  return result
227
365
 
228
- def format(self, isLoss: bool, n: int) -> dict[str, tuple[float, float]]:
229
- result = dict()
366
+ def format_loss(self, is_loss: bool, n: int) -> dict[str, tuple[float, float]]:
367
+ result = {}
230
368
  for group in self._loss.keys():
231
369
  for name, loss in self._loss[group].items():
232
- if loss.isLoss == isLoss and len(loss._values) >= n:
233
- result[name] = (np.nanmean(loss._weight[-n:]), np.nanmean(loss._values[-n:]))
370
+ if loss.is_loss == is_loss and len(loss._values) >= n:
371
+ result[name] = (
372
+ np.nanmean(loss._weight[-n:]),
373
+ np.nanmean(loss._values[-n:]),
374
+ )
234
375
  return result
235
376
 
236
377
  def update_scheduler(self, schedulers: dict[Scheduler, int], it: int) -> Scheduler:
237
378
  step = 0
238
- scheduler = None
239
- for scheduler, value in schedulers.items():
240
- if value is None or (it >= step and it < step+value):
379
+ _scheduler = None
380
+ for _scheduler, value in schedulers.items():
381
+ if value is None or (it >= step and it < step + value):
241
382
  break
242
383
  step += value
243
- if scheduler:
244
- scheduler.step(it-step)
245
- return scheduler
246
-
384
+ if _scheduler:
385
+ _scheduler.step(it - step)
386
+ if _scheduler is None:
387
+ raise NameError(
388
+ f"No scheduler found for iteration {it}. "
389
+ f"Available steps were: {list(schedulers.values())}. "
390
+ f"Check your configuration."
391
+ )
392
+ return _scheduler
393
+
394
+
247
395
  class ModuleArgsDict(torch.nn.Module, ABC):
248
-
396
+
249
397
  class ModuleArgs:
250
398
 
251
- def __init__(self, in_branch: list[str], out_branch: list[str], pretrained : bool, alias : list[str], requires_grad: Union[bool, None], training: Union[None, bool]) -> None:
399
+ def __init__(
400
+ self,
401
+ in_branch: list[str],
402
+ out_branch: list[str],
403
+ pretrained: bool,
404
+ alias: list[str],
405
+ requires_grad: bool | None,
406
+ training: None | bool,
407
+ ) -> None:
252
408
  super().__init__()
253
- self.alias= alias
409
+ self.alias = alias
254
410
  self.pretrained = pretrained
255
411
  self.in_branch = in_branch
256
412
  self.out_branch = out_branch
257
- self.in_channels = None
258
- self.in_is_channel = True
259
- self.out_channels = None
260
- self.out_is_channel = True
413
+ self.in_channels: int | None = None
414
+ self.in_is_channel: bool = True
415
+ self.out_channels: int | None = None
416
+ self.out_is_channel: bool = True
261
417
  self.requires_grad = requires_grad
262
418
  self.isCheckpoint = False
263
419
  self.isGPU_Checkpoint = False
@@ -267,51 +423,54 @@ class ModuleArgsDict(torch.nn.Module, ABC):
267
423
 
268
424
  def __init__(self) -> None:
269
425
  super().__init__()
270
- self._modulesArgs : dict[str, ModuleArgsDict.ModuleArgs] = dict()
426
+ self._modulesArgs: dict[str, ModuleArgsDict.ModuleArgs] = {}
271
427
  self._training = NetState.TRAIN
272
428
 
273
- def _addindent(self, s_: str, numSpaces : int):
274
- s = s_.split('\n')
429
+ def _addindent(self, s_: str, num_spaces: int):
430
+ s = s_.split("\n")
275
431
  if len(s) == 1:
276
432
  return s_
277
433
  first = s.pop(0)
278
- s = [(numSpaces * ' ') + line for line in s]
279
- s = '\n'.join(s)
280
- s = first + '\n' + s
281
- return s
434
+ s = [(num_spaces * " ") + line for line in s]
435
+ return first + "\n" + "\n".join(s)
282
436
 
283
437
  def __repr__(self):
284
438
  extra_lines = []
285
439
 
286
440
  extra_repr = self.extra_repr()
287
441
  if extra_repr:
288
- extra_lines = extra_repr.split('\n')
442
+ extra_lines = extra_repr.split("\n")
289
443
 
290
444
  child_lines = []
291
- is_simple_branch = lambda x : len(x) > 1 or x[0] != 0
445
+
446
+ def is_simple_branch(x):
447
+ return len(x) > 1 or x[0] != 0
448
+
292
449
  for key, module in self._modules.items():
293
450
  mod_str = repr(module)
294
451
 
295
452
  mod_str = self._addindent(mod_str, 2)
296
453
  desc = ""
297
- if is_simple_branch(self._modulesArgs[key].in_branch) or is_simple_branch(self._modulesArgs[key].out_branch):
298
- desc += ", {}->{}".format(self._modulesArgs[key].in_branch, self._modulesArgs[key].out_branch)
454
+ if is_simple_branch(self._modulesArgs[key].in_branch) or is_simple_branch(
455
+ self._modulesArgs[key].out_branch
456
+ ):
457
+ desc += f", {self._modulesArgs[key].in_branch}->{self._modulesArgs[key].out_branch}"
299
458
  if not self._modulesArgs[key].pretrained:
300
459
  desc += ", pretrained=False"
301
460
  if self._modulesArgs[key].alias:
302
- desc += ", alias={}".format(self._modulesArgs[key].alias)
303
- desc += ", in_channels={}".format(self._modulesArgs[key].in_channels)
304
- desc += ", in_is_channel={}".format(self._modulesArgs[key].in_is_channel)
305
- desc += ", out_channels={}".format(self._modulesArgs[key].out_channels)
306
- desc += ", out_is_channel={}".format(self._modulesArgs[key].out_is_channel)
307
- desc += ", is_end={}".format(self._modulesArgs[key]._isEnd)
308
- desc += ", isInCheckpoint={}".format(self._modulesArgs[key].isCheckpoint)
309
- desc += ", isInGPU_Checkpoint={}".format(self._modulesArgs[key].isGPU_Checkpoint)
310
- desc += ", requires_grad={}".format(self._modulesArgs[key].requires_grad)
311
- desc += ", device={}".format(self._modulesArgs[key].gpu)
312
-
313
- child_lines.append("({}{}) {}".format(key, desc, mod_str))
314
-
461
+ desc += f", alias={self._modulesArgs[key].alias}"
462
+ desc += f", in_channels={self._modulesArgs[key].in_channels}"
463
+ desc += f", in_is_channel={self._modulesArgs[key].in_is_channel}"
464
+ desc += f", out_channels={self._modulesArgs[key].out_channels}"
465
+ desc += f", out_is_channel={self._modulesArgs[key].out_is_channel}"
466
+ desc += f", is_end={self._modulesArgs[key]._isEnd}"
467
+ desc += f", isInCheckpoint={self._modulesArgs[key].isCheckpoint}"
468
+ desc += f", isInGPU_Checkpoint={self._modulesArgs[key].isGPU_Checkpoint}"
469
+ desc += f", requires_grad={self._modulesArgs[key].requires_grad}"
470
+ desc += f", device={self._modulesArgs[key].gpu}"
471
+
472
+ child_lines.append(f"({key}{desc}) {mod_str}")
473
+
315
474
  lines = extra_lines + child_lines
316
475
 
317
476
  desc = ""
@@ -319,53 +478,71 @@ class ModuleArgsDict(torch.nn.Module, ABC):
319
478
  if len(extra_lines) == 1 and not child_lines:
320
479
  desc += extra_lines[0]
321
480
  else:
322
- desc += '\n ' + '\n '.join(lines) + '\n'
481
+ desc += "\n " + "\n ".join(lines) + "\n"
482
+
483
+ return f"{self._get_name()}({desc})"
323
484
 
324
- return "{}({})".format(self._get_name(), desc)
325
-
326
485
  def __getitem__(self, key: str) -> torch.nn.Module:
327
486
  module = self._modules[key]
328
- assert module, "Error {} is None".format(key)
329
- return module
487
+ if not module:
488
+ raise ValueError(f"Module '{key}' is None or missing in self._modules")
489
+ return module
330
490
 
331
491
  @_copy_to_script_wrapper
332
492
  def keys(self) -> Iterable[str]:
333
493
  return self._modules.keys()
334
494
 
335
495
  @_copy_to_script_wrapper
336
- def items(self) -> Iterable[tuple[str, Union[torch.nn.Module, None]]]:
496
+ def items(self) -> Iterable[tuple[str, torch.nn.Module | None]]:
337
497
  return self._modules.items()
338
498
 
339
499
  @_copy_to_script_wrapper
340
- def values(self) -> Iterable[Union[torch.nn.Module, None]]:
500
+ def values(self) -> Iterable[torch.nn.Module | None]:
341
501
  return self._modules.values()
342
502
 
343
- def add_module(self, name: str, module : torch.nn.Module, in_branch: list[Union[int, str]] = [0], out_branch: list[Union[int, str]] = [0], pretrained : bool = True, alias : list[str] = [], requires_grad: Union[bool, None] = None, training: Union[None, bool] = None) -> None:
503
+ def add_module(
504
+ self,
505
+ name: str,
506
+ module: torch.nn.Module,
507
+ in_branch: Sequence[int | str] = [0],
508
+ out_branch: Sequence[int | str] = [0],
509
+ pretrained: bool = True,
510
+ alias: list[str] = [],
511
+ requires_grad: bool | None = None,
512
+ training: None | bool = None,
513
+ ) -> None:
344
514
  super().add_module(name, module)
345
- self._modulesArgs[name] = ModuleArgsDict.ModuleArgs([str(value) for value in in_branch], [str(value) for value in out_branch], pretrained, alias, requires_grad, training)
346
-
347
- def getMap(self):
348
- results : dict[str, str] = {}
349
- for name, moduleArgs in self._modulesArgs.items():
515
+ self._modulesArgs[name] = ModuleArgsDict.ModuleArgs(
516
+ [str(value) for value in in_branch],
517
+ [str(value) for value in out_branch],
518
+ pretrained,
519
+ alias,
520
+ requires_grad,
521
+ training,
522
+ )
523
+
524
+ def get_mapping(self):
525
+ results: dict[str, str] = {}
526
+ for name, module_args in self._modulesArgs.items():
350
527
  module = self[name]
351
528
  if isinstance(module, ModuleArgsDict):
352
- if len(moduleArgs.alias):
353
- count = {k : 0 for k in set(module.getMap().values())}
529
+ if len(module_args.alias):
530
+ count = dict.fromkeys(set(module.get_mapping().values()), 0)
354
531
  if len(count):
355
- for k, v in module.getMap().items():
356
- alias_name = moduleArgs.alias[count[v]]
532
+ for k, v in module.get_mapping().items():
533
+ alias_name = module_args.alias[count[v]]
357
534
  if k == "":
358
- results.update({alias_name : name+"."+v})
535
+ results.update({alias_name: name + "." + v})
359
536
  else:
360
- results.update({alias_name+"."+k : name+"."+v})
361
- count[v]+=1
537
+ results.update({alias_name + "." + k: name + "." + v})
538
+ count[v] += 1
362
539
  else:
363
- for alias in moduleArgs.alias:
364
- results.update({alias : name})
540
+ for alias in module_args.alias:
541
+ results.update({alias: name})
365
542
  else:
366
- results.update({k : name+"."+v for k, v in module.getMap().items()})
543
+ results.update({k: name + "." + v for k, v in module.get_mapping().items()})
367
544
  else:
368
- for alias in moduleArgs.alias:
545
+ for alias in module_args.alias:
369
546
  results[alias] = name
370
547
  return results
371
548
 
@@ -375,24 +552,24 @@ class ModuleArgsDict(torch.nn.Module, ABC):
375
552
  if isinstance(module, ModuleArgsDict):
376
553
  module.init(init_type, init_gain)
377
554
  elif isinstance(module, torch.nn.modules.conv._ConvNd) or isinstance(module, torch.nn.Linear):
378
- if init_type == 'normal':
555
+ if init_type == "normal":
379
556
  torch.nn.init.normal_(module.weight, 0.0, init_gain)
380
- elif init_type == 'xavier':
557
+ elif init_type == "xavier":
381
558
  torch.nn.init.xavier_normal_(module.weight, gain=init_gain)
382
- elif init_type == 'kaiming':
383
- torch.nn.init.kaiming_normal_(module.weight, a=0, mode='fan_in')
384
- elif init_type == 'orthogonal':
559
+ elif init_type == "kaiming":
560
+ torch.nn.init.kaiming_normal_(module.weight, a=0, mode="fan_in")
561
+ elif init_type == "orthogonal":
385
562
  torch.nn.init.orthogonal_(module.weight, gain=init_gain)
386
563
  elif init_type == "trunc_normal":
387
564
  torch.nn.init.trunc_normal_(module.weight, std=init_gain)
388
565
  else:
389
- raise NotImplementedError('Initialization method {} is not implemented'.format(init_type))
566
+ raise NotImplementedError(f"Initialization method {init_type} is not implemented")
390
567
  if module.bias is not None:
391
568
  torch.nn.init.constant_(module.bias, 0.0)
392
569
 
393
570
  elif isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
394
571
  if module.weight is not None:
395
- torch.nn.init.normal_(module.weight, 0.0, std = init_gain)
572
+ torch.nn.init.normal_(module.weight, 0.0, std=init_gain)
396
573
  if module.bias is not None:
397
574
  torch.nn.init.constant_(module.bias, 0.0)
398
575
 
@@ -400,8 +577,7 @@ class ModuleArgsDict(torch.nn.Module, ABC):
400
577
  for module in self._modules.values():
401
578
  ModuleArgsDict.init_func(module, init_type, init_gain)
402
579
 
403
-
404
- def named_forward(self, *inputs: torch.Tensor) -> Iterator[tuple[str, torch.Tensor]]:
580
+ def named_forward(self, *inputs: torch.Tensor) -> Iterator[tuple[str, torch.Tensor]]:
405
581
  if len(inputs) > 0:
406
582
  branchs: dict[str, torch.Tensor] = {}
407
583
  for i, sinput in enumerate(inputs):
@@ -410,7 +586,10 @@ class ModuleArgsDict(torch.nn.Module, ABC):
410
586
  out = inputs[0]
411
587
  tmp = []
412
588
  for name, module in self.items():
413
- if self._modulesArgs[name].training is None or (not (self._modulesArgs[name].training and self._training == NetState.PREDICTION) and not (not self._modulesArgs[name].training and self._training == NetState.TRAIN)):
589
+ if self._modulesArgs[name].training is None or (
590
+ not (self._modulesArgs[name].training and self._training == NetState.PREDICTION)
591
+ and not (not self._modulesArgs[name].training and self._training == NetState.TRAIN)
592
+ ):
414
593
  requires_grad = self._modulesArgs[name].requires_grad
415
594
  if requires_grad is not None and module:
416
595
  module.requires_grad_(requires_grad)
@@ -418,64 +597,75 @@ class ModuleArgsDict(torch.nn.Module, ABC):
418
597
  if ib not in branchs:
419
598
  branchs[ib] = inputs[0]
420
599
  for branchs_key in branchs.keys():
421
- if str(self._modulesArgs[name].gpu) != 'cpu' and str(branchs[branchs_key].device) != "cuda:"+self._modulesArgs[name].gpu:
600
+ if (
601
+ str(self._modulesArgs[name].gpu) != "cpu"
602
+ and str(branchs[branchs_key].device) != "cuda:" + self._modulesArgs[name].gpu
603
+ ):
422
604
  branchs[branchs_key] = branchs[branchs_key].to(int(self._modulesArgs[name].gpu))
423
-
605
+
424
606
  if self._modulesArgs[name].isCheckpoint:
425
- out = checkpoint(module, *[branchs[i] for i in self._modulesArgs[name].in_branch], use_reentrant=True)
607
+ out = checkpoint(
608
+ module,
609
+ *[branchs[i] for i in self._modulesArgs[name].in_branch],
610
+ use_reentrant=True,
611
+ )
426
612
  for ob in self._modulesArgs[name].out_branch:
427
613
  branchs[ob] = out
428
614
  yield name, out
429
615
  else:
430
616
  if isinstance(module, ModuleArgsDict):
431
- for k, out in module.named_forward(*[branchs[i] for i in self._modulesArgs[name].in_branch]):
617
+ for k, out in module.named_forward(
618
+ *[branchs[i] for i in self._modulesArgs[name].in_branch]
619
+ ):
432
620
  for ob in self._modulesArgs[name].out_branch:
433
621
  if ob in module._modulesArgs[k.split(".")[0].replace(";accu;", "")].out_branch:
434
622
  tmp.append(ob)
435
623
  branchs[ob] = out
436
- yield name+"."+k, out
624
+ yield name + "." + k, out
437
625
  for ob in self._modulesArgs[name].out_branch:
438
626
  if ob not in tmp:
439
627
  branchs[ob] = out
440
628
  elif isinstance(module, torch.nn.Module):
441
- out = module(*[branchs[i] for i in self._modulesArgs[name].in_branch])
629
+ out = module(*[branchs[i] for i in self._modulesArgs[name].in_branch])
442
630
  for ob in self._modulesArgs[name].out_branch:
443
631
  branchs[ob] = out
444
632
  yield name, out
445
633
  del branchs
446
634
 
447
635
  def forward(self, *input: torch.Tensor) -> torch.Tensor:
448
- v = input
449
- for k, v in self.named_forward(*input):
636
+ _v = input
637
+ for _, _v in self.named_forward(*input):
450
638
  pass
451
- return v
639
+ return _v
452
640
 
453
- def named_parameters(self, pretrained: bool = False, recurse=False) -> Iterator[tuple[str, torch.nn.parameter.Parameter]]:
454
- for name, moduleArgs in self._modulesArgs.items():
641
+ def named_parameters(
642
+ self, pretrained: bool = False, recurse=False
643
+ ) -> Iterator[tuple[str, torch.nn.parameter.Parameter]]:
644
+ for name, module_args in self._modulesArgs.items():
455
645
  module = self[name]
456
646
  if isinstance(module, ModuleArgsDict):
457
- for k, v in module.named_parameters(pretrained = pretrained):
458
- yield name+"."+k, v
647
+ for k, v in module.named_parameters(pretrained=pretrained):
648
+ yield name + "." + k, v
459
649
  elif isinstance(module, torch.nn.Module):
460
- if not pretrained or not moduleArgs.pretrained:
461
- if moduleArgs.training is None or moduleArgs.training:
650
+ if not pretrained or not module_args.pretrained:
651
+ if module_args.training is None or module_args.training:
462
652
  for k, v in module.named_parameters():
463
- yield name+"."+k, v
653
+ yield name + "." + k, v
464
654
 
465
655
  def parameters(self, pretrained: bool = False):
466
- for _, v in self.named_parameters(pretrained = pretrained):
656
+ for _, v in self.named_parameters(pretrained=pretrained):
467
657
  yield v
468
-
469
- def named_ModuleArgsDict(self) -> Iterator[tuple[str, Self, ModuleArgs]]:
658
+
659
+ def named_module_args_dict(self) -> Iterator[tuple[str, Self, ModuleArgs]]:
470
660
  for name, module in self._modules.items():
471
661
  yield name, module, self._modulesArgs[name]
472
662
  if isinstance(module, ModuleArgsDict):
473
- for k, v, u in module.named_ModuleArgsDict():
474
- yield name+"."+k, v, u
663
+ for k, v, u in module.named_module_args_dict():
664
+ yield name + "." + k, v, u
475
665
 
476
666
  def _requires_grad(self, keys: list[str]):
477
667
  keys = keys.copy()
478
- for name, module, args in self.named_ModuleArgsDict():
668
+ for name, module, args in self.named_module_args_dict():
479
669
  requires_grad = args.requires_grad
480
670
  if requires_grad is not None:
481
671
  module.requires_grad_(requires_grad)
@@ -484,162 +674,224 @@ class ModuleArgsDict(torch.nn.Module, ABC):
484
674
  if len(keys) == 0:
485
675
  break
486
676
 
677
+
487
678
  class OutputsGroup(list):
488
679
 
489
680
  def __init__(self, measure: Measure) -> None:
490
681
  self.layers: dict[str, torch.Tensor] = {}
491
682
  self.measure = measure
492
683
 
493
- def addLayer(self, name: str, layer: torch.Tensor):
684
+ def add_layer(self, name: str, layer: torch.Tensor):
494
685
  self.layers[name] = layer
495
686
 
496
- def isDone(self):
687
+ def is_done(self):
497
688
  return len(self) == len(self.layers)
498
-
689
+
499
690
  def clear(self):
500
691
  self.layers.clear()
501
692
 
502
-
503
693
 
504
694
  class Network(ModuleArgsDict, ABC):
505
695
 
506
- def _apply_network(self, name_function : Callable[[Self], str], networks: list[str], key: str, function: Callable, *args, **kwargs) -> dict[str, object]:
507
- results : dict[str, object] = {}
696
+ def _apply_network(
697
+ self,
698
+ name_function: Callable[[Self], str],
699
+ networks: list[str],
700
+ key: str,
701
+ function: Callable,
702
+ *args,
703
+ **kwargs,
704
+ ) -> dict[str, object]:
705
+ results: dict[str, object] = {}
508
706
  for module in self.values():
509
707
  if isinstance(module, Network):
510
708
  if name_function(module) not in networks:
511
709
  networks.append(name_function(module))
512
- for k, v in module._apply_network(name_function, networks, key+"."+name_function(module), function, *args, **kwargs).items():
513
- results.update({name_function(self)+"."+k : v})
710
+ for k, v in module._apply_network(
711
+ name_function,
712
+ networks,
713
+ key + "." + name_function(module),
714
+ function,
715
+ *args,
716
+ **kwargs,
717
+ ).items():
718
+ results.update({name_function(self) + "." + k: v})
514
719
  if len([param.name for param in list(inspect.signature(function).parameters.values()) if param.name == "key"]):
515
720
  function = partial(function, key=key)
516
721
 
517
722
  results[name_function(self)] = function(self, *args, **kwargs)
518
723
  return results
519
-
520
- def _function_network():
521
- def _function_network_d(function : Callable):
522
- def new_function(self : Self, *args, **kwargs) -> dict[str, object]:
523
- return self._apply_network(lambda network: network.getName(), [], self.getName(), function, *args, **kwargs)
724
+
725
+ def _function_network(): # type: ignore[misc]
726
+ def _function_network_d(function: Callable):
727
+ def new_function(self: Self, *args, **kwargs) -> dict[str, object]:
728
+ return self._apply_network(
729
+ lambda network: network.get_name(),
730
+ [],
731
+ self.get_name(),
732
+ function,
733
+ *args,
734
+ **kwargs,
735
+ )
736
+
524
737
  return new_function
738
+
525
739
  return _function_network_d
526
740
 
527
- def __init__( self,
528
- in_channels : int = 1,
529
- optimizer: Union[OptimizerLoader, None] = None,
530
- schedulers: Union[dict[str, LRSchedulersLoader], None] = None,
531
- outputsCriterions: Union[dict[str, TargetCriterionsLoader], None] = None,
532
- patch : Union[ModelPatch, None] = None,
533
- nb_batch_per_step : int = 1,
534
- init_type : str = "normal",
535
- init_gain : float = 0.02,
536
- dim : int = 3) -> None:
741
+ def __init__(
742
+ self,
743
+ in_channels: int = 1,
744
+ optimizer: OptimizerLoader | None = None,
745
+ schedulers: dict[str, LRSchedulersLoader] | None = None,
746
+ outputs_criterions: dict[str, TargetCriterionsLoader] | None = None,
747
+ patch: ModelPatch | None = None,
748
+ nb_batch_per_step: int = 1,
749
+ init_type: str = "normal",
750
+ init_gain: float = 0.02,
751
+ dim: int = 3,
752
+ ) -> None:
537
753
  super().__init__()
538
754
  self.name = self.__class__.__name__
539
755
  self.in_channels = in_channels
540
- self.optimizerLoader = optimizer
541
- self.optimizer : Union[torch.optim.Optimizer, None] = None
756
+ self.optimizerLoader = optimizer
757
+ self.optimizer: torch.optim.Optimizer | None = None
542
758
 
543
- self.LRSchedulersLoader = schedulers
544
- self.schedulers : Union[dict[torch.optim.lr_scheduler._LRScheduler, int], None] = {}
759
+ self.lr_schedulers_loader = schedulers
760
+ self.schedulers: dict[torch.optim.lr_scheduler._LRScheduler, int] = {}
545
761
 
546
- self.outputsCriterionsLoader = outputsCriterions
547
- self.measure : Union[Measure, None] = None
762
+ self.outputs_criterions_loader = outputs_criterions
763
+ self.measure: Measure | None = None
548
764
 
549
765
  self.patch = patch
550
766
 
551
767
  self.nb_batch_per_step = nb_batch_per_step
552
- self.init_type = init_type
553
- self.init_gain = init_gain
768
+ self.init_type = init_type
769
+ self.init_gain = init_gain
554
770
  self.dim = dim
555
771
  self._it = 0
556
772
  self._nb_lr_update = 0
557
- self.outputsGroup : list[OutputsGroup]= []
773
+ self.outputsGroup: list[OutputsGroup] = []
558
774
 
559
775
  @_function_network()
560
776
  def state_dict(self) -> dict[str, OrderedDict]:
561
- destination = OrderedDict()
562
- destination._metadata = OrderedDict()
563
- destination._metadata[""] = local_metadata = dict(version=self._version)
777
+ destination: OrderedDict[str, Any] = OrderedDict()
778
+ local_metadata = {"version": self._version}
779
+ # destination["_metadata"] = OrderedDict({"": local_metadata})
564
780
  self._save_to_state_dict(destination, "", False)
565
781
  for name, module in self._modules.items():
566
782
  if module is not None:
567
- if not isinstance(module, Network):
568
- module.state_dict(destination=destination, prefix="" + name + '.', keep_vars=False)
783
+ if not isinstance(module, Network):
784
+ module.state_dict(destination=destination, prefix="" + name + ".", keep_vars=False)
569
785
  for hook in self._state_dict_hooks.values():
570
786
  hook_result = hook(self, destination, "", local_metadata)
571
787
  if hook_result is not None:
572
788
  destination = hook_result
573
789
  return destination
574
-
790
+
575
791
  def load_state_dict(self, state_dict: dict[str, torch.Tensor]):
576
792
  missing_keys: list[str] = []
577
793
  unexpected_keys: list[str] = []
578
794
  error_msgs: list[str] = []
579
795
 
580
- metadata = getattr(state_dict, '_metadata', None)
796
+ metadata = getattr(state_dict, "_metadata", None)
581
797
  state_dict = state_dict.copy()
582
798
  if metadata is not None:
583
- state_dict._metadata = metadata
799
+ state_dict["_metadata"] = metadata
584
800
 
585
- def load(module: torch.nn.Module, prefix=''):
801
+ def load(module: torch.nn.Module, prefix=""):
586
802
  local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
587
803
  module._load_from_state_dict(
588
- state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
804
+ state_dict,
805
+ prefix,
806
+ local_metadata,
807
+ True,
808
+ missing_keys,
809
+ unexpected_keys,
810
+ error_msgs,
811
+ )
589
812
  for name, child in module._modules.items():
590
813
  if child is not None:
591
814
  if not isinstance(child, Network):
592
815
  if isinstance(child, torch.nn.modules.conv._ConvNd) or isinstance(module, torch.nn.Linear):
593
816
 
594
817
  current_size = child.weight.shape[0]
595
- last_size = state_dict[prefix + name+".weight"].shape[0]
818
+ last_size = state_dict[prefix + name + ".weight"].shape[0]
596
819
 
597
820
  if current_size != last_size:
598
- print("Warning: The size of '{}' has changed from {} to {}. Please check for potential impacts".format(prefix + name, last_size, current_size))
821
+ print(
822
+ f"Warning: The size of '{prefix + name}' has changed from {last_size}"
823
+ f" to {current_size}. Please check for potential impacts"
824
+ )
599
825
  ModuleArgsDict.init_func(child, self.init_type, self.init_gain)
600
826
 
601
827
  with torch.no_grad():
602
- child.weight[:last_size] = state_dict[prefix + name+".weight"]
828
+ child.weight[:last_size] = state_dict[prefix + name + ".weight"]
603
829
  if child.bias is not None:
604
- child.bias[:last_size] = state_dict[prefix + name+".bias"]
830
+ child.bias[:last_size] = state_dict[prefix + name + ".bias"]
605
831
  return
606
- load(child, prefix + name + '.')
832
+ load(child, prefix + name + ".")
607
833
 
608
834
  load(self)
609
- del load
610
835
 
611
836
  if len(unexpected_keys) > 0:
837
+ formatted_keys = ", ".join(f'"{k}"' for k in unexpected_keys)
612
838
  error_msgs.insert(
613
- 0, 'Unexpected key(s) in state_dict: {}. '.format(
614
- ', '.join('"{}"'.format(k) for k in unexpected_keys)))
839
+ 0,
840
+ f"Unexpected key(s) in state_dict: {formatted_keys}.",
841
+ )
615
842
  if len(missing_keys) > 0:
843
+ formatted_keys = ", ".join(f'"{k}"' for k in missing_keys)
616
844
  error_msgs.insert(
617
- 0, 'Missing key(s) in state_dict: {}. '.format(
618
- ', '.join('"{}"'.format(k) for k in missing_keys)))
845
+ 0,
846
+ f"Missing key(s) in state_dict: {formatted_keys}.",
847
+ )
619
848
 
620
849
  if len(error_msgs) > 0:
621
- raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
622
- self.__class__.__name__, "\n\t".join(error_msgs)))
623
-
850
+ formatted_errors = "\n\t".join(error_msgs)
851
+ raise RuntimeError(
852
+ f"Error(s) in loading state_dict for {self.__class__.__name__}:\n\t{formatted_errors}",
853
+ )
854
+
624
855
  def apply(self, fn: Callable[[torch.nn.Module], None]) -> None:
625
856
  for module in self.children():
626
857
  if not isinstance(module, Network):
627
858
  module.apply(fn)
628
859
  fn(self)
629
-
860
+
630
861
  @_function_network()
631
- def load(self, state_dict : dict[str, dict[str, torch.Tensor]], init: bool = True, ema : bool =False):
862
+ def load(
863
+ self,
864
+ state_dict: dict[str, dict[str, torch.Tensor] | int],
865
+ init: bool = True,
866
+ ema: bool = False,
867
+ ):
632
868
  if init:
633
- self.apply(partial(ModuleArgsDict.init_func, init_type=self.init_type, init_gain=self.init_gain))
634
- name = "Model" + ("_EMA" if ema else "")
869
+ self.apply(
870
+ partial(
871
+ ModuleArgsDict.init_func,
872
+ init_type=self.init_type,
873
+ init_gain=self.init_gain,
874
+ )
875
+ )
876
+ name = "Model"
877
+ if ema:
878
+ if name + "_EMA" in state_dict:
879
+ name += "_EMA"
635
880
  if name in state_dict:
636
- model_state_dict_tmp = {k.split(".")[-1] : v for k, v in state_dict[name].items()}[self.getName()]
637
- map = self.getMap()
638
- model_state_dict : OrderedDict[str, torch.Tensor] = OrderedDict()
639
-
881
+ value = state_dict[name]
882
+ model_state_dict_tmp = {}
883
+ if isinstance(value, dict):
884
+ model_state_dict_tmp = {k.split(".")[-1]: v for k, v in value.items()}[self.get_name()]
885
+ modules_name = self.get_mapping()
886
+ model_state_dict: OrderedDict[str, torch.Tensor] = OrderedDict()
887
+
640
888
  for alias in model_state_dict_tmp.keys():
641
889
  prefix = ".".join(alias.split(".")[:-1])
642
- alias_list = [(".".join(prefix.split(".")[:len(i.split("."))]), v) for i, v in map.items() if prefix.startswith(i)]
890
+ alias_list = [
891
+ (".".join(prefix.split(".")[: len(i.split("."))]), v)
892
+ for i, v in modules_name.items()
893
+ if prefix.startswith(i)
894
+ ]
643
895
 
644
896
  if len(alias_list):
645
897
  for a, b in alias_list:
@@ -648,19 +900,36 @@ class Network(ModuleArgsDict, ABC):
648
900
  else:
649
901
  model_state_dict[alias] = model_state_dict_tmp[alias]
650
902
  self.load_state_dict(model_state_dict)
651
- if "{}_optimizer_state_dict".format(self.getName()) in state_dict and self.optimizer:
652
- self.optimizer.load_state_dict(state_dict['{}_optimizer_state_dict'.format(self.getName())])
653
- if "{}_it".format(self.getName()) in state_dict:
654
- self._it = int(state_dict["{}_it".format(self.getName())])
655
- if "{}_nb_lr_update".format(self.getName()) in state_dict:
656
- self._nb_lr_update = int(state_dict["{}_nb_lr_update".format(self.getName())])
903
+ if f"{self.get_name()}_optimizer_state_dict" in state_dict and self.optimizer:
904
+ last_lr = self.optimizer.param_groups[0]["lr"]
905
+ self.optimizer.load_state_dict(state_dict[f"{self.get_name()}_optimizer_state_dict"])
906
+ self.optimizer.param_groups[0]["lr"] = last_lr
907
+ if f"{self.get_name()}_it" in state_dict:
908
+ _it = state_dict.get(f"{self.get_name()}_it")
909
+ if isinstance(_it, int):
910
+ self._it = _it
911
+ if f"{self.get_name()}_nb_lr_update" in state_dict:
912
+ _nb_lr_update = state_dict.get(f"{self.get_name()}_nb_lr_update")
913
+ if isinstance(_nb_lr_update, int):
914
+ self._nb_lr_update = _nb_lr_update
915
+
657
916
  for scheduler in self.schedulers:
658
917
  if scheduler.last_epoch == -1:
659
918
  scheduler.last_epoch = self._nb_lr_update
660
919
  self.initialized()
661
920
 
662
- def _compute_channels_trace(self, module : ModuleArgsDict, in_channels : int, gradient_checkpoints: Union[list[str], None], gpu_checkpoints: Union[list[str], None], name: Union[str, None] = None, in_is_channel = True, out_channels : Union[int, None] = None, out_is_channel = True) -> tuple[int, bool, int, bool]:
663
-
921
+ def _compute_channels_trace(
922
+ self,
923
+ module: ModuleArgsDict,
924
+ in_channels: int,
925
+ gradient_checkpoints: list[str] | None,
926
+ gpu_checkpoints: list[str] | None,
927
+ name: str | None = None,
928
+ in_is_channel: bool = True,
929
+ out_channels: int | None = None,
930
+ out_is_channel: bool = True,
931
+ ) -> tuple[int, bool, int | None, bool]:
932
+
664
933
  for k1, v1 in module.items():
665
934
  if isinstance(v1, ModuleArgsDict):
666
935
  for t in module._modulesArgs[k1].out_branch:
@@ -680,25 +949,34 @@ class Network(ModuleArgsDict, ABC):
680
949
  if hasattr(v, "in_features"):
681
950
  if v.in_features:
682
951
  in_channels = v.in_features
683
- key = name+"."+k if name else k
952
+ key = name + "." + k if name else k
684
953
 
685
954
  if gradient_checkpoints:
686
955
  if key in gradient_checkpoints:
687
956
  module._modulesArgs[k].isCheckpoint = True
688
-
957
+
689
958
  if gpu_checkpoints:
690
959
  if key in gpu_checkpoints:
691
960
  module._modulesArgs[k].isGPU_Checkpoint = True
692
-
961
+
693
962
  module._modulesArgs[k].in_channels = in_channels
694
963
  module._modulesArgs[k].in_is_channel = in_is_channel
695
-
964
+
696
965
  if isinstance(v, ModuleArgsDict):
697
- in_channels, in_is_channel, out_channels, out_is_channel = self._compute_channels_trace(v, in_channels, gradient_checkpoints, gpu_checkpoints, key, in_is_channel, out_channels, out_is_channel)
698
-
966
+ in_channels, in_is_channel, out_channels, out_is_channel = self._compute_channels_trace(
967
+ v,
968
+ in_channels,
969
+ gradient_checkpoints,
970
+ gpu_checkpoints,
971
+ key,
972
+ in_is_channel,
973
+ out_channels,
974
+ out_is_channel,
975
+ )
976
+
699
977
  if v.__class__.__name__ == "ToChannels":
700
978
  out_is_channel = True
701
-
979
+
702
980
  if v.__class__.__name__ == "ToFeatures":
703
981
  out_is_channel = False
704
982
 
@@ -712,26 +990,32 @@ class Network(ModuleArgsDict, ABC):
712
990
  module._modulesArgs[k].out_channels = out_channels
713
991
  module._modulesArgs[k].out_is_channel = out_is_channel
714
992
 
715
- in_channels = out_channels
993
+ in_channels = out_channels if out_channels is not None else in_channels
716
994
  in_is_channel = out_is_channel
717
-
995
+
718
996
  return in_channels, in_is_channel, out_channels, out_is_channel
719
997
 
720
998
  @_function_network()
721
- def init(self, autocast : bool, state : State, group_dest: list[str], key: str) -> None:
722
- if self.outputsCriterionsLoader:
723
- self.measure = Measure(key, self.outputsCriterionsLoader)
999
+ def init(self, autocast: bool, state: State, group_dest: list[str], key: str) -> None:
1000
+ if self.outputs_criterions_loader:
1001
+ self.measure = Measure(key, self.outputs_criterions_loader)
724
1002
  self.measure.init(self, group_dest)
1003
+ if self.patch is not None:
1004
+ self.patch.init(f"{konfai_root()}.Model.{key}.Patch")
725
1005
  if state != State.PREDICTION:
726
1006
  self.scaler = torch.amp.GradScaler("cuda", enabled=autocast)
727
1007
  if self.optimizerLoader:
728
- self.optimizer = self.optimizerLoader.getOptimizer(key, self.parameters(state == State.TRANSFER_LEARNING))
1008
+ self.optimizer = self.optimizerLoader.get_optimizer(
1009
+ key, self.parameters(state == State.TRANSFER_LEARNING)
1010
+ )
729
1011
  self.optimizer.zero_grad()
730
1012
 
731
- if self.LRSchedulersLoader and self.optimizer:
732
- for schedulers_classname, schedulers in self.LRSchedulersLoader.items():
733
- self.schedulers[schedulers.getschedulers(key, schedulers_classname, self.optimizer)] = schedulers.nb_step
734
-
1013
+ if self.lr_schedulers_loader and self.optimizer:
1014
+ for schedulers_classname, schedulers in self.lr_schedulers_loader.items():
1015
+ self.schedulers[schedulers.getschedulers(key, schedulers_classname, self.optimizer)] = (
1016
+ schedulers.nb_step
1017
+ )
1018
+
735
1019
  def initialized(self):
736
1020
  pass
737
1021
 
@@ -740,80 +1024,101 @@ class Network(ModuleArgsDict, ABC):
740
1024
  self.patch.load(inputs[0].shape[2:])
741
1025
  accumulators: dict[str, Accumulator] = {}
742
1026
 
743
- patchIterator = self.patch.disassemble(*inputs)
1027
+ patch_iterator = self.patch.disassemble(*inputs)
744
1028
  buffer = []
745
- for i, patch_input in enumerate(patchIterator):
746
- for (name, output_layer) in super().named_forward(*patch_input):
747
- yield "{}{}".format(";accu;", name), output_layer
1029
+ for i, patch_input in enumerate(patch_iterator):
1030
+ for name, output_layer in super().named_forward(*patch_input):
1031
+ yield f";accu;{name}", output_layer
748
1032
  buffer.append((name.split(".")[0], output_layer))
749
1033
  if len(buffer) == 2:
750
1034
  if buffer[0][0] != buffer[1][0]:
751
1035
  if self._modulesArgs[buffer[0][0]]._isEnd:
752
1036
  if buffer[0][0] not in accumulators:
753
- accumulators[buffer[0][0]] = Accumulator(self.patch.getPatch_slices(), self.patch.patch_size, self.patch.patchCombine)
754
- accumulators[buffer[0][0]].addLayer(i, buffer[0][1])
1037
+ accumulators[buffer[0][0]] = Accumulator(
1038
+ self.patch.get_patch_slices(),
1039
+ self.patch.patch_size,
1040
+ self.patch.patch_combine,
1041
+ )
1042
+ accumulators[buffer[0][0]].add_layer(i, buffer[0][1])
755
1043
  buffer.pop(0)
756
1044
  if self._modulesArgs[buffer[0][0]]._isEnd:
757
1045
  if buffer[0][0] not in accumulators:
758
- accumulators[buffer[0][0]] = Accumulator(self.patch.getPatch_slices(), self.patch.patch_size, self.patch.patchCombine)
759
- accumulators[buffer[0][0]].addLayer(i, buffer[0][1])
1046
+ accumulators[buffer[0][0]] = Accumulator(
1047
+ self.patch.get_patch_slices(),
1048
+ self.patch.patch_size,
1049
+ self.patch.patch_combine,
1050
+ )
1051
+ accumulators[buffer[0][0]].add_layer(i, buffer[0][1])
760
1052
  for name, accumulator in accumulators.items():
761
1053
  yield name, accumulator.assemble()
762
1054
  else:
763
- for (name, output_layer) in super().named_forward(*inputs):
1055
+ for name, output_layer in super().named_forward(*inputs):
764
1056
  yield name, output_layer
765
1057
 
766
-
767
- def get_layers(self, inputs : list[torch.Tensor], layers_name: list[str]) -> Iterator[tuple[str, torch.Tensor, Union[Patch_Indexed, None]]]:
1058
+ def get_layers(
1059
+ self, inputs: list[torch.Tensor], layers_name: list[str]
1060
+ ) -> Iterator[tuple[str, torch.Tensor, PatchIndexed | None]]:
768
1061
  layers_name = layers_name.copy()
769
- output_layer_accumulator : dict[str, Accumulator] = {}
770
- output_layer_patch_indexed : dict[str, Patch_Indexed] = {}
1062
+ output_layer_accumulator: dict[str, Accumulator] = {}
1063
+ output_layer_patch_indexed: dict[str, PatchIndexed] = {}
771
1064
  it = 0
772
1065
  debug = "KONFAI_DEBUG" in os.environ
773
- for (nameTmp, output_layer) in self.named_forward(*inputs):
774
- name = nameTmp.replace(";accu;", "")
1066
+ for name_tmp, output_layer in self.named_forward(*inputs):
1067
+ name = name_tmp.replace(";accu;", "")
775
1068
  if debug:
776
1069
  if "KONFAI_DEBUG_LAST_LAYER" in os.environ:
777
- os.environ["KONFAI_DEBUG_LAST_LAYER"] = "{}|{}:{}:{}".format(os.environ["KONFAI_DEBUG_LAST_LAYER"], name, getGPUMemory(output_layer.device), str(output_layer.device).replace("cuda:", ""))
1070
+ os.environ["KONFAI_DEBUG_LAST_LAYER"] = f"{os.environ['KONFAI_DEBUG_LAST_LAYER']}|{name}:"
1071
+ f"{get_gpu_memory(output_layer.device)}:"
1072
+ f"{str(output_layer.device).replace('cuda:', '')}"
778
1073
  else:
779
- os.environ["KONFAI_DEBUG_LAST_LAYER"] = "{}:{}:{}".format(name, getGPUMemory(output_layer.device), str(output_layer.device).replace("cuda:", ""))
1074
+ os.environ["KONFAI_DEBUG_LAST_LAYER"] = (
1075
+ f"{name}:{get_gpu_memory(output_layer.device)}:{str(output_layer.device).replace('cuda:', '')}"
1076
+ )
780
1077
  it += 1
781
- if name in layers_name or nameTmp in layers_name:
782
- if ";accu;" in nameTmp:
1078
+ if name in layers_name or name_tmp in layers_name:
1079
+ if ";accu;" in name_tmp:
783
1080
  if name not in output_layer_patch_indexed:
784
- networkName = nameTmp.split(".;accu;")[-2].split(".")[-1] if ".;accu;" in nameTmp else nameTmp.split(";accu;")[-2].split(".")[-1]
1081
+ network_name = (
1082
+ name_tmp.split(".;accu;")[-2].split(".")[-1]
1083
+ if ".;accu;" in name_tmp
1084
+ else name_tmp.split(";accu;")[-2].split(".")[-1]
1085
+ )
785
1086
  module = self
786
1087
  network = None
787
- if networkName == "":
1088
+ if network_name == "":
788
1089
  network = module
789
1090
  else:
790
1091
  for n in name.split("."):
791
1092
  module = module[n]
792
- if isinstance(module, Network) and n == networkName:
1093
+ if isinstance(module, Network) and n == network_name:
793
1094
  network = module
794
1095
  break
795
-
1096
+
796
1097
  if network and network.patch:
797
- output_layer_patch_indexed[name] = Patch_Indexed(network.patch, 0)
1098
+ output_layer_patch_indexed[name] = PatchIndexed(network.patch, 0)
798
1099
 
799
1100
  if name not in output_layer_accumulator:
800
- output_layer_accumulator[name] = Accumulator(output_layer_patch_indexed[name].patch.getPatch_slices(0), output_layer_patch_indexed[name].patch.patch_size, output_layer_patch_indexed[name].patch.patchCombine)
801
-
802
- if nameTmp in layers_name:
803
- output_layer_accumulator[name].addLayer(output_layer_patch_indexed[name].index, output_layer)
1101
+ output_layer_accumulator[name] = Accumulator(
1102
+ output_layer_patch_indexed[name].patch.get_patch_slices(0),
1103
+ output_layer_patch_indexed[name].patch.patch_size,
1104
+ output_layer_patch_indexed[name].patch.patch_combine,
1105
+ )
1106
+
1107
+ if name_tmp in layers_name:
1108
+ output_layer_accumulator[name].add_layer(output_layer_patch_indexed[name].index, output_layer)
804
1109
  output_layer_patch_indexed[name].index += 1
805
- if output_layer_accumulator[name].isFull():
1110
+ if output_layer_accumulator[name].is_full():
806
1111
  output_layer = output_layer_accumulator[name].assemble()
807
1112
  output_layer_accumulator.pop(name)
808
1113
  output_layer_patch_indexed.pop(name)
809
- layers_name.remove(nameTmp)
810
- yield nameTmp, output_layer, None
1114
+ layers_name.remove(name_tmp)
1115
+ yield name_tmp, output_layer, None
811
1116
 
812
1117
  if name in layers_name:
813
- if ";accu;" in nameTmp:
1118
+ if ";accu;" in name_tmp:
814
1119
  yield name, output_layer, output_layer_patch_indexed[name]
815
1120
  output_layer_patch_indexed[name].index += 1
816
- if output_layer_patch_indexed[name].isFull():
1121
+ if output_layer_patch_indexed[name].is_full():
817
1122
  output_layer_patch_indexed.pop(name)
818
1123
  layers_name.remove(name)
819
1124
  else:
@@ -823,62 +1128,85 @@ class Network(ModuleArgsDict, ABC):
823
1128
  if not len(layers_name):
824
1129
  break
825
1130
 
826
- def init_outputsGroup(self):
827
- metric_tmp = {network.measure : network.measure.outputsCriterions.keys() for network in self.getNetworks().values() if network.measure}
1131
+ def init_outputs_group(self):
1132
+ metric_tmp = {
1133
+ network.measure: network.measure.outputs_criterions.keys()
1134
+ for network in self.get_networks().values()
1135
+ if network.measure
1136
+ }
828
1137
  for k, v in metric_tmp.items():
829
1138
  for a in v:
830
1139
  outputs_group = OutputsGroup(k)
831
1140
  outputs_group.append(a)
832
- for targetsGroup in k.outputsCriterions[a].keys():
833
- if ":" in targetsGroup:
834
- outputs_group.append(targetsGroup.replace(":", "."))
1141
+ for targets_group in k.outputs_criterions[a].keys():
1142
+ if ":" in targets_group:
1143
+ outputs_group.append(targets_group.replace(":", "."))
835
1144
 
836
1145
  self.outputsGroup.append(outputs_group)
837
1146
 
838
- def forward(self, data_dict: dict[tuple[str, bool], torch.Tensor], output_layers: list[str] = []) -> list[tuple[str, torch.Tensor]]:
1147
+ def forward(
1148
+ self,
1149
+ data_dict: dict[tuple[str, bool], torch.Tensor],
1150
+ output_layers: list[str] = [],
1151
+ ) -> list[tuple[str, torch.Tensor]]:
839
1152
  if not len(self.outputsGroup) and not len(output_layers):
840
1153
  return []
841
1154
 
842
- self.resetLoss()
1155
+ self.reset_loss()
843
1156
  results = []
844
1157
  measure_output_layers = set()
845
- for outputs_group in self.outputsGroup:
846
- for name in outputs_group:
1158
+ for _outputs_group in self.outputsGroup:
1159
+ for name in _outputs_group:
847
1160
  measure_output_layers.add(name)
848
- for name, layer, patch_indexed in self.get_layers([v for k, v in data_dict.items() if k[1]], list(set(list(measure_output_layers)+output_layers))):
849
-
1161
+ for name, layer, patch_indexed in self.get_layers(
1162
+ [v for k, v in data_dict.items() if k[1]],
1163
+ list(set(list(measure_output_layers) + output_layers)),
1164
+ ):
1165
+
850
1166
  outputs_group = [outputs_group for outputs_group in self.outputsGroup if name in outputs_group]
851
1167
  if len(outputs_group) > 0:
852
1168
  if patch_indexed is None:
853
- targets = {k[0] : v for k, v in data_dict.items()}
1169
+ targets = {k[0]: v for k, v in data_dict.items()}
854
1170
  nb = 1
855
1171
  else:
856
- targets = {k[0] : patch_indexed.patch.getData(v, patch_indexed.index, 0, False) for k, v in data_dict.items()}
857
- nb = patch_indexed.patch.getSize(0)
1172
+ targets = {
1173
+ k[0]: patch_indexed.patch.get_data(v, patch_indexed.index, 0, False)
1174
+ for k, v in data_dict.items()
1175
+ }
1176
+ nb = patch_indexed.patch.get_size(0)
858
1177
 
859
1178
  for output_group in outputs_group:
860
- output_group.addLayer(name, layer)
861
- if output_group.isDone():
862
- targets.update({k.replace(".", ":"): v for k, v in output_group.layers.items() if k != output_group[0]})
863
- output_group.measure.update(output_group[0], output_group.layers[output_group[0]], targets, self._it, nb, self.training)
1179
+ output_group.add_layer(name, layer)
1180
+ if output_group.is_done():
1181
+ targets.update(
1182
+ {k.replace(".", ":"): v for k, v in output_group.layers.items() if k != output_group[0]}
1183
+ )
1184
+ output_group.measure.update(
1185
+ output_group[0],
1186
+ output_group.layers[output_group[0]],
1187
+ targets,
1188
+ self._it,
1189
+ nb,
1190
+ self.training,
1191
+ )
864
1192
  output_group.clear()
865
1193
  if name in output_layers:
866
1194
  results.append((name, layer))
867
1195
  return results
868
1196
 
869
1197
  @_function_network()
870
- def resetLoss(self):
1198
+ def reset_loss(self):
871
1199
  if self.measure:
872
- self.measure.resetLoss()
1200
+ self.measure.reset_loss()
873
1201
 
874
1202
  @_function_network()
875
1203
  def backward(self, model: Self):
876
1204
  if self.measure:
877
1205
  if self.scaler and self.optimizer:
878
- model._requires_grad(list(self.measure.outputsCriterions.keys()))
879
- for loss in self.measure.getLoss():
1206
+ model._requires_grad(list(self.measure.outputs_criterions.keys()))
1207
+ for loss in self.measure.get_loss():
880
1208
  self.scaler.scale(loss / self.nb_batch_per_step).backward()
881
-
1209
+
882
1210
  if self._it % self.nb_batch_per_step == 0:
883
1211
  self.scaler.step(self.optimizer)
884
1212
  self.scaler.update()
@@ -887,93 +1215,133 @@ class Network(ModuleArgsDict, ABC):
887
1215
 
888
1216
  @_function_network()
889
1217
  def update_lr(self):
890
- self._nb_lr_update+=1
1218
+ self._nb_lr_update += 1
891
1219
  step = 0
892
- scheduler = None
893
- for scheduler, value in self.schedulers.items():
894
- if value is None or (self._nb_lr_update >= step and self._nb_lr_update < step+value):
1220
+ _scheduler = None
1221
+ for _scheduler, value in self.schedulers.items():
1222
+ if value is None or (self._nb_lr_update >= step and self._nb_lr_update < step + value):
895
1223
  break
896
1224
  step += value
897
- if scheduler:
898
- if scheduler.__class__.__name__ == 'ReduceLROnPlateau':
1225
+ if _scheduler:
1226
+ if _scheduler.__class__.__name__ == "ReduceLROnPlateau":
899
1227
  if self.measure:
900
- scheduler.step(sum(self.measure.getLastValues(0).values()))
1228
+ _scheduler.step(sum(self.measure.get_last_values(0).values()))
901
1229
  else:
902
- scheduler.step()
1230
+ _scheduler.step()
903
1231
 
904
1232
  @_function_network()
905
- def getNetworks(self) -> Self:
1233
+ def get_networks(self) -> Self:
906
1234
  return self
907
1235
 
908
- def to(module : ModuleArgsDict, device: int):
1236
+ @staticmethod
1237
+ def to(module: ModuleArgsDict, device: int):
909
1238
  if "device" not in os.environ:
910
1239
  os.environ["device"] = str(device)
911
1240
  for k, v in module.items():
912
1241
  if module._modulesArgs[k].gpu == "cpu":
913
1242
  if module._modulesArgs[k].isGPU_Checkpoint:
914
- os.environ["device"] = str(int(os.environ["device"])+1)
915
- module._modulesArgs[k].gpu = str(getDevice(int(os.environ["device"])))
1243
+ os.environ["device"] = str(int(os.environ["device"]) + 1)
1244
+ module._modulesArgs[k].gpu = str(get_device(int(os.environ["device"])))
916
1245
  if isinstance(v, ModuleArgsDict):
917
1246
  v = Network.to(v, int(os.environ["device"]))
918
1247
  else:
919
- v = v.to(getDevice(int(os.environ["device"])))
1248
+ v = v.to(get_device(int(os.environ["device"])))
920
1249
  if isinstance(module, Network):
921
1250
  if module.optimizer is not None:
922
1251
  for state in module.optimizer.state.values():
923
1252
  for k, v in state.items():
924
1253
  if isinstance(v, torch.Tensor):
925
- state[k] = v.to(getDevice(int(os.environ["device"])))
1254
+ state[k] = v.to(get_device(int(os.environ["device"])))
926
1255
  return module
927
-
928
- def getName(self) -> str:
1256
+
1257
+ def get_name(self) -> str:
929
1258
  return self.name
930
-
931
- def setName(self, name: str) -> Self:
1259
+
1260
+ def set_name(self, name: str) -> Self:
932
1261
  self.name = name
933
1262
  return self
934
-
935
- def setState(self, state: NetState):
1263
+
1264
+ def set_state(self, state: NetState):
936
1265
  for module in self.modules():
937
1266
  if isinstance(module, ModuleArgsDict):
938
1267
  module._training = state
939
1268
 
1269
+
940
1270
  class MinimalModel(Network):
941
-
942
- def __init__(self, model: Network, optimizer : OptimizerLoader = OptimizerLoader(),
943
- schedulers : LRSchedulersLoader = LRSchedulersLoader(),
944
- outputsCriterions: dict[str, TargetCriterionsLoader] = {"default" : TargetCriterionsLoader()},
945
- patch : Union[ModelPatch, None] = None,
946
- dim : int = 3, nb_batch_per_step = 1, init_type = "normal", init_gain = 0.02):
947
- super().__init__(1, optimizer, schedulers, outputsCriterions, patch, nb_batch_per_step, init_type, init_gain, dim)
1271
+
1272
+ def __init__(
1273
+ self,
1274
+ model: Network,
1275
+ optimizer: OptimizerLoader = OptimizerLoader(),
1276
+ schedulers: dict[str, LRSchedulersLoader] = {"default:ReduceLROnPlateau": LRSchedulersLoader(0)},
1277
+ outputs_criterions: dict[str, TargetCriterionsLoader] = {"default": TargetCriterionsLoader()},
1278
+ patch: ModelPatch | None = None,
1279
+ dim: int = 3,
1280
+ nb_batch_per_step=1,
1281
+ init_type="normal",
1282
+ init_gain=0.02,
1283
+ ):
1284
+ super().__init__(
1285
+ 1,
1286
+ optimizer,
1287
+ schedulers,
1288
+ outputs_criterions,
1289
+ patch,
1290
+ nb_batch_per_step,
1291
+ init_type,
1292
+ init_gain,
1293
+ dim,
1294
+ )
948
1295
  self.add_module("Model", model)
949
1296
 
950
- class ModelLoader():
1297
+
1298
+ class ModelLoader:
951
1299
 
952
1300
  @config("Model")
953
- def __init__(self, classpath : str = "default:segmentation.UNet.UNet") -> None:
954
- self.module, self.name = _getModule(classpath, "konfai.models")
955
-
956
- def getModel(self, train : bool = True, DL_args: Union[str, None] = None, DL_without=["optimizer", "schedulers", "nb_batch_per_step", "init_type", "init_gain"]) -> Network:
957
- if not DL_args:
958
- DL_args="{}.Model".format(KONFAI_ROOT())
959
- DL_args += "."+self.name
960
- model = config(DL_args)(getattr(importlib.import_module(self.module), self.name))(config = None, DL_without = DL_without if not train else [])
1301
+ def __init__(self, classpath: str = "default:segmentation.UNet.UNet") -> None:
1302
+ self.module, self.name = get_module(classpath, "konfai.models")
1303
+
1304
+ def get_model(
1305
+ self,
1306
+ train: bool = True,
1307
+ konfai_args: str | None = None,
1308
+ konfai_without=[
1309
+ "optimizer",
1310
+ "schedulers",
1311
+ "nb_batch_per_step",
1312
+ "init_type",
1313
+ "init_gain",
1314
+ ],
1315
+ ) -> Network:
1316
+ if not konfai_args:
1317
+ konfai_args = f"{konfai_root()}.Model"
1318
+ konfai_args += "." + self.name
1319
+ model = config(konfai_args)(getattr(importlib.import_module(self.module), self.name))(
1320
+ config=None, konfai_without=konfai_without if not train else []
1321
+ )
961
1322
  if not isinstance(model, Network):
962
- model = config(DL_args)(partial(MinimalModel, model))(config = None, DL_without = DL_without+["model"] if not train else [])
963
- model.setName(self.name)
1323
+ model = config(konfai_args)(partial(MinimalModel, model))(
1324
+ config=None, konfai_without=konfai_without + ["model"] if not train else []
1325
+ )
1326
+ model.set_name(self.name)
964
1327
  return model
965
-
966
- class CPU_Model():
1328
+
1329
+
1330
+ class CPUModel:
967
1331
 
968
1332
  def __init__(self, model: Network) -> None:
969
1333
  self.module = model
970
- self.device = torch.device('cpu')
1334
+ self.device = torch.device("cpu")
971
1335
 
972
1336
  def train(self):
973
1337
  self.module.train()
974
-
975
- def eval(self):
1338
+
1339
+ def eval(self): # noqa: A003
976
1340
  self.module.eval()
977
-
978
- def __call__(self, data_dict: dict[tuple[str, bool], torch.Tensor], output_layers: list[str] = []) -> list[tuple[str, torch.Tensor]]:
979
- return self.module(data_dict, output_layers)
1341
+
1342
+ def __call__(
1343
+ self,
1344
+ data_dict: dict[tuple[str, bool], torch.Tensor],
1345
+ output_layers: list[str] = [],
1346
+ ) -> list[tuple[str, torch.Tensor]]:
1347
+ return self.module(data_dict, output_layers)