konfai 1.1.7__py3-none-any.whl → 1.1.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of konfai might be problematic. Click here for more details.

Files changed (36) hide show
  1. konfai/__init__.py +59 -14
  2. konfai/data/augmentation.py +457 -286
  3. konfai/data/data_manager.py +509 -290
  4. konfai/data/patching.py +300 -183
  5. konfai/data/transform.py +384 -277
  6. konfai/evaluator.py +309 -68
  7. konfai/main.py +71 -22
  8. konfai/metric/measure.py +341 -222
  9. konfai/metric/schedulers.py +24 -13
  10. konfai/models/classification/convNeXt.py +187 -81
  11. konfai/models/classification/resnet.py +272 -58
  12. konfai/models/generation/cStyleGan.py +233 -59
  13. konfai/models/generation/ddpm.py +348 -121
  14. konfai/models/generation/diffusionGan.py +757 -358
  15. konfai/models/generation/gan.py +177 -53
  16. konfai/models/generation/vae.py +140 -40
  17. konfai/models/registration/registration.py +135 -52
  18. konfai/models/representation/representation.py +57 -23
  19. konfai/models/segmentation/NestedUNet.py +339 -68
  20. konfai/models/segmentation/UNet.py +140 -30
  21. konfai/network/blocks.py +331 -187
  22. konfai/network/network.py +781 -423
  23. konfai/predictor.py +645 -240
  24. konfai/trainer.py +527 -216
  25. konfai/utils/ITK.py +191 -106
  26. konfai/utils/config.py +152 -95
  27. konfai/utils/dataset.py +326 -455
  28. konfai/utils/utils.py +495 -249
  29. {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/METADATA +1 -3
  30. konfai-1.1.9.dist-info/RECORD +38 -0
  31. konfai/utils/registration.py +0 -199
  32. konfai-1.1.7.dist-info/RECORD +0 -39
  33. {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/WHEEL +0 -0
  34. {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/entry_points.txt +0 -0
  35. {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/licenses/LICENSE +0 -0
  36. {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/top_level.txt +0 -0
konfai/network/network.py CHANGED
@@ -1,117 +1,181 @@
1
- from functools import partial
2
1
  import importlib
3
2
  import inspect
4
3
  import os
5
- from typing import Iterable, Iterator, Callable
6
- from typing_extensions import Self
7
- import torch
8
4
  from abc import ABC
5
+ from collections import OrderedDict
6
+ from collections.abc import Callable, Iterable, Iterator, Sequence
7
+ from enum import Enum
8
+ from functools import partial
9
+ from typing import Any, Self
10
+
9
11
  import numpy as np
12
+ import torch
10
13
  from torch._jit_internal import _copy_to_script_wrapper
11
- from collections import OrderedDict
12
14
  from torch.utils.checkpoint import checkpoint
13
- from typing import Union
14
- from enum import Enum
15
15
 
16
- from konfai import KONFAI_ROOT
16
+ from konfai import konfai_root
17
+ from konfai.data.patching import Accumulator, ModelPatch
17
18
  from konfai.metric.schedulers import Scheduler
18
19
  from konfai.utils.config import config
19
- from konfai.utils.utils import State, _getModule, getDevice, getGPUMemory, MeasureError, TrainerError
20
- from konfai.data.patching import Accumulator, ModelPatch
20
+ from konfai.utils.utils import MeasureError, State, TrainerError, get_device, get_gpu_memory, get_module
21
+
21
22
 
22
23
  class NetState(Enum):
23
- TRAIN = 0,
24
+ TRAIN = (0,)
24
25
  PREDICTION = 1
25
26
 
26
- class Patch_Indexed():
27
+
28
+ class PatchIndexed:
27
29
 
28
30
  def __init__(self, patch: ModelPatch, index: int) -> None:
29
31
  self.patch = patch
30
32
  self.index = index
31
33
 
32
- def isFull(self) -> bool:
33
- return len(self.patch.getPatch_slices(0)) == self.index
34
+ def is_full(self) -> bool:
35
+ return len(self.patch.get_patch_slices(0)) == self.index
36
+
37
+
38
+ class OptimizerLoader:
34
39
 
35
- class OptimizerLoader():
36
-
37
40
  @config("Optimizer")
38
41
  def __init__(self, name: str = "AdamW") -> None:
39
42
  self.name = name
40
-
41
- def getOptimizer(self, key: str, parameter: Iterator[torch.nn.parameter.Parameter]) -> torch.optim.Optimizer:
42
- return config("{}.Model.{}.Optimizer".format(KONFAI_ROOT(), key))(getattr(importlib.import_module('torch.optim'), self.name))(parameter, config = None)
43
-
44
43
 
45
- class LRSchedulersLoader():
46
-
44
+ def get_optimizer(self, key: str, parameter: Iterator[torch.nn.parameter.Parameter]) -> torch.optim.Optimizer:
45
+ return config(f"{konfai_root()}.Model.{key}.Optimizer")(
46
+ getattr(importlib.import_module("torch.optim"), self.name)
47
+ )(parameter, config=None)
48
+
49
+
50
+ class LRSchedulersLoader:
51
+
47
52
  @config()
48
- def __init__(self, nb_step : int = 0) -> None:
53
+ def __init__(self, nb_step: int = 0) -> None:
49
54
  self.nb_step = nb_step
50
55
 
51
- def getschedulers(self, key: str, scheduler_classname: str, optimizer: torch.optim.Optimizer) -> torch.optim.lr_scheduler._LRScheduler:
56
+ def getschedulers(
57
+ self, key: str, scheduler_classname: str, optimizer: torch.optim.Optimizer
58
+ ) -> torch.optim.lr_scheduler._LRScheduler:
52
59
  for m in ["torch.optim.lr_scheduler", "konfai.metric.schedulers"]:
53
- module, name = _getModule(scheduler_classname, m)
60
+ module, name = get_module(scheduler_classname, m)
54
61
  if hasattr(importlib.import_module(module), name):
55
- return config("{}.Model.{}.schedulers.{}".format(KONFAI_ROOT(), key, scheduler_classname))(getattr(importlib.import_module(module), name))(optimizer, config = None)
56
- raise TrainerError("Unknown scheduler {}, tried importing from: 'torch.optim.lr_scheduler' and 'konfai.metric.schedulers', but no valid match was found. Check your YAML config or scheduler name spelling.".format(scheduler_classname))
62
+ return config(f"{konfai_root()}.Model.{key}.schedulers.{scheduler_classname}")(
63
+ getattr(importlib.import_module(module), name)
64
+ )(optimizer, config=None)
65
+ raise TrainerError(
66
+ f"Unknown scheduler {scheduler_classname}, tried importing from: 'torch.optim.lr_scheduler' and "
67
+ "'konfai.metric.schedulers', but no valid match was found. "
68
+ "Check your YAML config or scheduler name spelling."
69
+ )
70
+
71
+
72
+ class LossSchedulersLoader:
57
73
 
58
- class LossSchedulersLoader():
59
-
60
74
  @config()
61
- def __init__(self, nb_step : int = 0) -> None:
75
+ def __init__(self, nb_step: int = 0) -> None:
62
76
  self.nb_step = nb_step
63
77
 
64
78
  def getschedulers(self, key: str, scheduler_classname: str) -> torch.optim.lr_scheduler._LRScheduler:
65
- return config("{}.{}".format(key, scheduler_classname))(getattr(importlib.import_module('konfai.metric.schedulers'), scheduler_classname))(config = None)
66
-
67
- class CriterionsAttr():
68
-
79
+ return config(f"{key}.{scheduler_classname}")(
80
+ getattr(importlib.import_module("konfai.metric.schedulers"), scheduler_classname)
81
+ )(config=None)
82
+
83
+
84
+ class CriterionsAttr:
85
+
69
86
  @config()
70
- def __init__(self, schedulers: dict[str, LossSchedulersLoader] = {"default:Constant": LossSchedulersLoader(0)}, isLoss: bool = True, group: int = 0, stepStart: int = 0, stepStop: Union[int, None] = None, accumulation: bool = False) -> None:
87
+ def __init__(
88
+ self,
89
+ schedulers: dict[str, LossSchedulersLoader] = {"default:Constant": LossSchedulersLoader(0)},
90
+ is_loss: bool = True,
91
+ group: int = 0,
92
+ start: int = 0,
93
+ stop: int | None = None,
94
+ accumulation: bool = False,
95
+ ) -> None:
71
96
  self.schedulersLoader = schedulers
72
97
  self.isTorchCriterion = True
73
- self.isLoss = isLoss
74
- self.stepStart = stepStart
75
- self.stepStop = stepStop
98
+ self.is_loss = is_loss
99
+ self.start = start
100
+ self.stop = stop
76
101
  self.group = group
77
102
  self.accumulation = accumulation
78
103
  self.schedulers: dict[Scheduler, int] = {}
79
-
80
- class CriterionsLoader():
81
-
82
- @config()
83
- def __init__(self, criterionsLoader: dict[str, CriterionsAttr] = {"default:torch_nn_CrossEntropyLoss:Dice:NCC": CriterionsAttr()}) -> None:
84
- self.criterionsLoader = criterionsLoader
85
104
 
86
- def getCriterions(self, model_classname : str, output_group : str, target_group : str) -> dict[torch.nn.Module, CriterionsAttr]:
87
- criterions = {}
88
- for module_classpath, criterionsAttr in self.criterionsLoader.items():
89
- module, name = _getModule(module_classpath, "konfai.metric.measure")
90
- criterionsAttr.isTorchCriterion = module.startswith("torch")
91
- for scheduler_classname, schedulers in criterionsAttr.schedulersLoader.items():
92
- criterionsAttr.schedulers[schedulers.getschedulers("{}.Model.{}.outputsCriterions.{}.targetsCriterions.{}.criterionsLoader.{}.schedulers".format(KONFAI_ROOT(), model_classname, output_group, target_group, module_classpath), scheduler_classname)] = schedulers.nb_step
93
- criterions[config("{}.Model.{}.outputsCriterions.{}.targetsCriterions.{}.criterionsLoader.{}".format(KONFAI_ROOT(), model_classname, output_group, target_group, module_classpath))(getattr(importlib.import_module(module), name))(config = None)] = criterionsAttr
94
- return criterions
95
105
 
96
- class TargetCriterionsLoader():
106
+ class CriterionsLoader:
97
107
 
98
108
  @config()
99
- def __init__(self, targetsCriterions : dict[str, CriterionsLoader] = {"default" : CriterionsLoader()}) -> None:
100
- self.targetsCriterions = targetsCriterions
101
-
102
- def getTargetsCriterions(self, output_group : str, model_classname : str) -> dict[str, dict[torch.nn.Module, float]]:
103
- targetsCriterions = {}
104
- for target_group, criterionsLoader in self.targetsCriterions.items():
105
- targetsCriterions[target_group] = criterionsLoader.getCriterions(model_classname, output_group, target_group)
106
- return targetsCriterions
109
+ def __init__(
110
+ self,
111
+ criterions_loader: dict[str, CriterionsAttr] = {"default:torch:nn:CrossEntropyLoss:Dice:NCC": CriterionsAttr()},
112
+ ) -> None:
113
+ self.criterions_loader = criterions_loader
114
+
115
+ def get_criterions(
116
+ self, model_classname: str, output_group: str, target_group: str
117
+ ) -> dict[torch.nn.Module, CriterionsAttr]:
118
+ criterions = {}
119
+ for module_classpath, criterions_attr in self.criterions_loader.items():
120
+ module, name = get_module(module_classpath, "konfai.metric.measure")
121
+ criterions_attr.isTorchCriterion = module.startswith("torch")
122
+ for (
123
+ scheduler_classname,
124
+ schedulers,
125
+ ) in criterions_attr.schedulersLoader.items():
126
+ criterions_attr.schedulers[
127
+ schedulers.getschedulers(
128
+ f"{konfai_root()}.Model.{model_classname}.outputs_criterions.{output_group}"
129
+ f".targets_criterions.{target_group}"
130
+ f".criterions_loader.{module_classpath}.schedulers",
131
+ scheduler_classname,
132
+ )
133
+ ] = schedulers.nb_step
134
+ criterions[
135
+ config(
136
+ f"{konfai_root()}.Model.{model_classname}.outputs_criterions."
137
+ f"{output_group}.targets_criterions.{target_group}."
138
+ f"criterions_loader.{module_classpath}"
139
+ )(getattr(importlib.import_module(module), name))(config=None)
140
+ ] = criterions_attr
141
+ return criterions
107
142
 
108
- class Measure():
109
143
 
110
- class Loss():
144
+ class TargetCriterionsLoader:
111
145
 
112
- def __init__(self, name: str, output_group: str, target_group: str, group: int, isLoss: bool, accumulation: bool) -> None:
146
+ @config()
147
+ def __init__(
148
+ self,
149
+ targets_criterions: dict[str, CriterionsLoader] = {"default": CriterionsLoader()},
150
+ ) -> None:
151
+ self.targets_criterions = targets_criterions
152
+
153
+ def get_targets_criterions(
154
+ self, output_group: str, model_classname: str
155
+ ) -> dict[str, dict[torch.nn.Module, CriterionsAttr]]:
156
+ targets_criterions = {}
157
+ for target_group, criterions_loader in self.targets_criterions.items():
158
+ targets_criterions[target_group] = criterions_loader.get_criterions(
159
+ model_classname, output_group, target_group
160
+ )
161
+ return targets_criterions
162
+
163
+
164
+ class Measure:
165
+
166
+ class Loss:
167
+
168
+ def __init__(
169
+ self,
170
+ name: str,
171
+ output_group: str,
172
+ target_group: str,
173
+ group: int,
174
+ is_loss: bool,
175
+ accumulation: bool,
176
+ ) -> None:
113
177
  self.name = name
114
- self.isLoss = isLoss
178
+ self.is_loss = is_loss
115
179
  self.accumulation = accumulation
116
180
  self.output_group = output_group
117
181
  self.target_group = target_group
@@ -120,144 +184,231 @@ class Measure():
120
184
  self._loss: list[torch.Tensor] = []
121
185
  self._weight: list[float] = []
122
186
  self._values: list[float] = []
123
-
124
- def resetLoss(self) -> None:
187
+
188
+ def reset_loss(self) -> None:
125
189
  self._loss.clear()
126
190
 
127
191
  def add(self, weight: float, value: torch.Tensor) -> None:
128
- if self.isLoss or value != 0:
129
- self._loss.append(value if self.isLoss else value.detach())
192
+ if self.is_loss or value != 0:
193
+ self._loss.append(value if self.is_loss else value.detach())
130
194
  self._values.append(value.item())
131
195
  self._weight.append(weight)
132
196
 
133
- def getLastLoss(self) -> torch.Tensor:
134
- return self._loss[-1]*self._weight[-1] if len(self._loss) else torch.zeros((1), requires_grad = True)
135
-
136
- def getLoss(self) -> torch.Tensor:
137
- return torch.stack([w*l for w, l in zip(self._weight, self._loss)], dim=0).mean(dim=0) if len(self._loss) else torch.zeros((1), requires_grad = True)
197
+ def get_last_loss(self) -> torch.Tensor:
198
+ return self._loss[-1] * self._weight[-1] if len(self._loss) else torch.zeros((1), requires_grad=True)
199
+
200
+ def get_loss(self) -> torch.Tensor:
201
+ return (
202
+ torch.stack([w * loss_value for w, loss_value in zip(self._weight, self._loss)], dim=0).mean(dim=0)
203
+ if len(self._loss)
204
+ else torch.zeros((1), requires_grad=True)
205
+ )
138
206
 
139
207
  def __len__(self) -> int:
140
208
  return len(self._loss)
141
-
142
- def __init__(self, model_classname : str, outputsCriterions: dict[str, TargetCriterionsLoader]) -> None:
143
- super().__init__()
144
- self.outputsCriterions = {}
145
- for output_group, targetCriterionsLoader in outputsCriterions.items():
146
- self.outputsCriterions[output_group.replace(":", ".")] = targetCriterionsLoader.getTargetsCriterions(output_group, model_classname)
147
- self._loss : dict[int, dict[str, Measure.Loss]] = {}
148
209
 
149
- def init(self, model : torch.nn.Module, group_dest: list[str]) -> None:
210
+ def __init__(
211
+ self,
212
+ model_classname: str,
213
+ outputs_criterions_loader: dict[str, TargetCriterionsLoader],
214
+ ) -> None:
215
+ super().__init__()
216
+ self.outputs_criterions: dict[str, dict[str, dict[torch.nn.Module, CriterionsAttr]]] = {}
217
+ for output_group, target_criterions_loader in outputs_criterions_loader.items():
218
+ self.outputs_criterions[output_group.replace(":", ".")] = target_criterions_loader.get_targets_criterions(
219
+ output_group, model_classname
220
+ )
221
+ self._loss: dict[int, dict[str, Measure.Loss]] = {}
222
+
223
+ def init(self, model: torch.nn.Module, group_dest: list[str]) -> None:
150
224
  outputs_group_rename = {}
151
-
225
+
152
226
  modules = []
153
- for i,_,_ in model.named_ModuleArgsDict():
227
+ for i, _, _ in model.named_module_args_dict():
154
228
  modules.append(i)
155
229
 
156
- for output_group in self.outputsCriterions.keys():
157
- if output_group not in modules:
158
- raise MeasureError(f"The output group '{output_group}' defined in 'outputsCriterions' does not correspond to any module in the model.",
230
+ for output_group in self.outputs_criterions.keys():
231
+ if output_group.replace(";accu;", "") not in modules:
232
+ raise MeasureError(
233
+ f"The output group '{output_group}' defined in 'outputs_criterions' "
234
+ "does not correspond to any module in the model.",
159
235
  f"Available modules: {modules}",
160
- "Please check that the name matches exactly a submodule or output of your model architecture."
236
+ "Please check that the name matches exactly a submodule or output of your model architecture.",
161
237
  )
162
- for target_group in self.outputsCriterions[output_group]:
163
- for target_group_tmp in target_group.split(";"):
238
+ for target_group in self.outputs_criterions[output_group]:
239
+ for target_group_tmp in target_group.split(";"):
164
240
  if target_group_tmp not in group_dest:
165
241
  raise MeasureError(
166
- f"The target_group '{target_group_tmp}' defined in 'outputsCriterions.{output_group}.targetsCriterions' was not found in the available destination groups.",
167
- "This target_group is expected for loss or metric computation, but was not loaded in 'group_dest'.",
168
- f"Please make sure that the group '{target_group_tmp}' is defined in 'Dataset:groups_src:...:groups_dest:'{target_group_tmp}'' and correctly loaded from the dataset.")
169
- for criterion in self.outputsCriterions[output_group][target_group]:
170
- if not self.outputsCriterions[output_group][target_group][criterion].isTorchCriterion:
242
+ f"The target_group '{target_group_tmp}' defined in "
243
+ "'outputs_criterions.{output_group}.targets_criterions'"
244
+ " was not found in the available destination groups.",
245
+ "This target_group is expected for loss or metric computation, "
246
+ "but was not loaded in 'group_dest'.",
247
+ f"Please make sure that the group '{target_group_tmp}' is defined in "
248
+ "'Dataset:groups_src:...:groups_dest:'{target_group_tmp}'' "
249
+ "and correctly loaded from the dataset.",
250
+ )
251
+ for criterion in self.outputs_criterions[output_group][target_group]:
252
+ if not self.outputs_criterions[output_group][target_group][criterion].isTorchCriterion:
171
253
  outputs_group_rename[output_group] = criterion.init(model, output_group, target_group)
172
254
 
173
- outputsCriterions_bak = self.outputsCriterions.copy()
255
+ outputs_criterions_bak = self.outputs_criterions.copy()
174
256
  for old, new in outputs_group_rename.items():
175
- self.outputsCriterions.pop(old)
176
- self.outputsCriterions[new] = outputsCriterions_bak[old]
177
- for output_group in self.outputsCriterions:
178
- for target_group in self.outputsCriterions[output_group]:
179
- for criterion, criterionsAttr in self.outputsCriterions[output_group][target_group].items():
180
- if criterionsAttr.group not in self._loss:
181
- self._loss[criterionsAttr.group] = {}
182
- self._loss[criterionsAttr.group]["{}:{}:{}".format(output_group, target_group, criterion.__class__.__name__)] = Measure.Loss(criterion.__class__.__name__, output_group, target_group, criterionsAttr.group, criterionsAttr.isLoss, criterionsAttr.accumulation)
183
-
184
- def update(self, output_group: str, output : torch.Tensor, data_dict: dict[str, torch.Tensor], it: int, nb_patch: int, training: bool) -> None:
185
- for target_group in self.outputsCriterions[output_group]:
186
- target = [data_dict[group].to(output[0].device).detach() for group in target_group.split(";") if group in data_dict]
187
-
188
- for criterion, criterionsAttr in self.outputsCriterions[output_group][target_group].items():
189
- if it >= criterionsAttr.stepStart and (criterionsAttr.stepStop is None or it <= criterionsAttr.stepStop):
190
- scheduler = self.update_scheduler(criterionsAttr.schedulers, it)
191
- self._loss[criterionsAttr.group]["{}:{}:{}".format(output_group, target_group, criterion.__class__.__name__)].add(scheduler.get_value(), criterion(output, *target))
192
- if training and len(np.unique([len(l) for l in self._loss[criterionsAttr.group].values() if l.accumulation and l.isLoss])) == 1:
193
- if criterionsAttr.isLoss:
194
- loss = torch.zeros((1), requires_grad = True)
195
- for v in [l for l in self._loss[criterionsAttr.group].values() if l.accumulation and l.isLoss]:
196
- l = v.getLastLoss()
197
- loss = loss.to(l.device)+l
198
- loss = loss/nb_patch
257
+ self.outputs_criterions.pop(old)
258
+ self.outputs_criterions[new] = outputs_criterions_bak[old]
259
+ for output_group in self.outputs_criterions:
260
+ for target_group in self.outputs_criterions[output_group]:
261
+ for criterion, criterions_attr in self.outputs_criterions[output_group][target_group].items():
262
+ if criterions_attr.group not in self._loss:
263
+ self._loss[criterions_attr.group] = {}
264
+ self._loss[criterions_attr.group][
265
+ f"{output_group}:{target_group}:{criterion.__class__.__name__}"
266
+ ] = Measure.Loss(
267
+ criterion.__class__.__name__,
268
+ output_group,
269
+ target_group,
270
+ criterions_attr.group,
271
+ criterions_attr.is_loss,
272
+ criterions_attr.accumulation,
273
+ )
274
+
275
+ def update(
276
+ self,
277
+ output_group: str,
278
+ output: torch.Tensor,
279
+ data_dict: dict[str, torch.Tensor],
280
+ it: int,
281
+ nb_patch: int,
282
+ training: bool,
283
+ ) -> None:
284
+ for target_group in self.outputs_criterions[output_group]:
285
+ target = [
286
+ data_dict[group].to(output[0].device).detach()
287
+ for group in target_group.split(";")
288
+ if group in data_dict
289
+ ]
290
+
291
+ for criterion, criterions_attr in self.outputs_criterions[output_group][target_group].items():
292
+ if it >= criterions_attr.start and (criterions_attr.stop is None or it <= criterions_attr.stop):
293
+ scheduler = self.update_scheduler(criterions_attr.schedulers, it)
294
+ self._loss[criterions_attr.group][
295
+ f"{output_group}:{target_group}:{criterion.__class__.__name__}"
296
+ ].add(scheduler.get_value(), criterion(output, *target))
297
+ if (
298
+ training
299
+ and len(
300
+ np.unique(
301
+ [
302
+ len(loss_value)
303
+ for loss_value in self._loss[criterions_attr.group].values()
304
+ if loss_value.accumulation and loss_value.is_loss
305
+ ]
306
+ )
307
+ )
308
+ == 1
309
+ ):
310
+ if criterions_attr.is_loss:
311
+ loss = torch.zeros((1), requires_grad=True)
312
+ for v in [
313
+ loss_value
314
+ for loss_value in self._loss[criterions_attr.group].values()
315
+ if loss_value.accumulation and loss_value.is_loss
316
+ ]:
317
+ loss_value = v.get_last_loss()
318
+ loss = loss.to(loss_value.device) + loss_value
319
+ loss = loss / nb_patch
199
320
  loss.backward()
200
321
 
201
- def getLoss(self) -> list[torch.Tensor]:
322
+ def get_loss(self) -> list[torch.Tensor]:
202
323
  loss: dict[int, torch.Tensor] = {}
203
324
  for group in self._loss.keys():
204
- loss[group] = torch.zeros((1), requires_grad = True)
325
+ loss[group] = torch.zeros((1), requires_grad=True)
205
326
  for v in self._loss[group].values():
206
- if v.isLoss and not v.accumulation:
207
- l = v.getLoss()
208
- loss[v.group] = loss[v.group].to(l.device)+l
209
- return loss.values()
327
+ if v.is_loss and not v.accumulation:
328
+ loss_value = v.get_loss()
329
+ loss[v.group] = loss[v.group].to(loss_value.device) + loss_value
330
+ return list(loss.values())
210
331
 
211
- def resetLoss(self) -> None:
332
+ def reset_loss(self) -> None:
212
333
  for group in self._loss.keys():
213
334
  for v in self._loss[group].values():
214
- v.resetLoss()
215
-
216
- def getLastValues(self, n: int = 1) -> dict[str, float]:
335
+ v.reset_loss()
336
+
337
+ def get_last_values(self, n: int = 1) -> dict[str, float]:
217
338
  result = {}
218
339
  for group in self._loss.keys():
219
- result.update({name : np.nanmean(value._values[-n:] if n > 0 else value._values) for name, value in self._loss[group].items() if n < 0 or len(value._values) >= n})
340
+ result.update(
341
+ {
342
+ name: np.nanmean(value._values[-n:] if n > 0 else value._values)
343
+ for name, value in self._loss[group].items()
344
+ if n < 0 or len(value._values) >= n
345
+ }
346
+ )
220
347
  return result
221
-
222
- def getLastWeights(self, n: int = 1) -> dict[str, float]:
348
+
349
+ def get_last_weights(self, n: int = 1) -> dict[str, float]:
223
350
  result = {}
224
351
  for group in self._loss.keys():
225
- result.update({name : np.nanmean(value._weight[-n:] if n > 0 else value._weight) for name, value in self._loss[group].items() if n < 0 or len(value._values) >= n})
352
+ result.update(
353
+ {
354
+ name: np.nanmean(value._weight[-n:] if n > 0 else value._weight)
355
+ for name, value in self._loss[group].items()
356
+ if n < 0 or len(value._values) >= n
357
+ }
358
+ )
226
359
  return result
227
360
 
228
- def format(self, isLoss: bool, n: int) -> dict[str, tuple[float, float]]:
229
- result = dict()
361
+ def format_loss(self, is_loss: bool, n: int) -> dict[str, tuple[float, float]]:
362
+ result = {}
230
363
  for group in self._loss.keys():
231
364
  for name, loss in self._loss[group].items():
232
- if loss.isLoss == isLoss and len(loss._values) >= n:
233
- result[name] = (np.nanmean(loss._weight[-n:]), np.nanmean(loss._values[-n:]))
365
+ if loss.is_loss == is_loss and len(loss._values) >= n:
366
+ result[name] = (
367
+ np.nanmean(loss._weight[-n:]),
368
+ np.nanmean(loss._values[-n:]),
369
+ )
234
370
  return result
235
371
 
236
372
  def update_scheduler(self, schedulers: dict[Scheduler, int], it: int) -> Scheduler:
237
373
  step = 0
238
- scheduler = None
239
- for scheduler, value in schedulers.items():
240
- if value is None or (it >= step and it < step+value):
374
+ _scheduler = None
375
+ for _scheduler, value in schedulers.items():
376
+ if value is None or (it >= step and it < step + value):
241
377
  break
242
378
  step += value
243
- if scheduler:
244
- scheduler.step(it-step)
245
- return scheduler
246
-
379
+ if _scheduler:
380
+ _scheduler.step(it - step)
381
+ if _scheduler is None:
382
+ raise NameError(
383
+ f"No scheduler found for iteration {it}. "
384
+ f"Available steps were: {list(schedulers.values())}. "
385
+ f"Check your configuration."
386
+ )
387
+ return _scheduler
388
+
389
+
247
390
  class ModuleArgsDict(torch.nn.Module, ABC):
248
-
391
+
249
392
  class ModuleArgs:
250
393
 
251
- def __init__(self, in_branch: list[str], out_branch: list[str], pretrained : bool, alias : list[str], requires_grad: Union[bool, None], training: Union[None, bool]) -> None:
394
+ def __init__(
395
+ self,
396
+ in_branch: list[str],
397
+ out_branch: list[str],
398
+ pretrained: bool,
399
+ alias: list[str],
400
+ requires_grad: bool | None,
401
+ training: None | bool,
402
+ ) -> None:
252
403
  super().__init__()
253
- self.alias= alias
404
+ self.alias = alias
254
405
  self.pretrained = pretrained
255
406
  self.in_branch = in_branch
256
407
  self.out_branch = out_branch
257
- self.in_channels = None
258
- self.in_is_channel = True
259
- self.out_channels = None
260
- self.out_is_channel = True
408
+ self.in_channels: int | None = None
409
+ self.in_is_channel: bool = True
410
+ self.out_channels: int | None = None
411
+ self.out_is_channel: bool = True
261
412
  self.requires_grad = requires_grad
262
413
  self.isCheckpoint = False
263
414
  self.isGPU_Checkpoint = False
@@ -267,51 +418,54 @@ class ModuleArgsDict(torch.nn.Module, ABC):
267
418
 
268
419
  def __init__(self) -> None:
269
420
  super().__init__()
270
- self._modulesArgs : dict[str, ModuleArgsDict.ModuleArgs] = dict()
421
+ self._modulesArgs: dict[str, ModuleArgsDict.ModuleArgs] = {}
271
422
  self._training = NetState.TRAIN
272
423
 
273
- def _addindent(self, s_: str, numSpaces : int):
274
- s = s_.split('\n')
424
+ def _addindent(self, s_: str, num_spaces: int):
425
+ s = s_.split("\n")
275
426
  if len(s) == 1:
276
427
  return s_
277
428
  first = s.pop(0)
278
- s = [(numSpaces * ' ') + line for line in s]
279
- s = '\n'.join(s)
280
- s = first + '\n' + s
281
- return s
429
+ s = [(num_spaces * " ") + line for line in s]
430
+ return first + "\n" + "\n".join(s)
282
431
 
283
432
  def __repr__(self):
284
433
  extra_lines = []
285
434
 
286
435
  extra_repr = self.extra_repr()
287
436
  if extra_repr:
288
- extra_lines = extra_repr.split('\n')
437
+ extra_lines = extra_repr.split("\n")
289
438
 
290
439
  child_lines = []
291
- is_simple_branch = lambda x : len(x) > 1 or x[0] != 0
440
+
441
+ def is_simple_branch(x):
442
+ return len(x) > 1 or x[0] != 0
443
+
292
444
  for key, module in self._modules.items():
293
445
  mod_str = repr(module)
294
446
 
295
447
  mod_str = self._addindent(mod_str, 2)
296
448
  desc = ""
297
- if is_simple_branch(self._modulesArgs[key].in_branch) or is_simple_branch(self._modulesArgs[key].out_branch):
298
- desc += ", {}->{}".format(self._modulesArgs[key].in_branch, self._modulesArgs[key].out_branch)
449
+ if is_simple_branch(self._modulesArgs[key].in_branch) or is_simple_branch(
450
+ self._modulesArgs[key].out_branch
451
+ ):
452
+ desc += f", {self._modulesArgs[key].in_branch}->{self._modulesArgs[key].out_branch}"
299
453
  if not self._modulesArgs[key].pretrained:
300
454
  desc += ", pretrained=False"
301
455
  if self._modulesArgs[key].alias:
302
- desc += ", alias={}".format(self._modulesArgs[key].alias)
303
- desc += ", in_channels={}".format(self._modulesArgs[key].in_channels)
304
- desc += ", in_is_channel={}".format(self._modulesArgs[key].in_is_channel)
305
- desc += ", out_channels={}".format(self._modulesArgs[key].out_channels)
306
- desc += ", out_is_channel={}".format(self._modulesArgs[key].out_is_channel)
307
- desc += ", is_end={}".format(self._modulesArgs[key]._isEnd)
308
- desc += ", isInCheckpoint={}".format(self._modulesArgs[key].isCheckpoint)
309
- desc += ", isInGPU_Checkpoint={}".format(self._modulesArgs[key].isGPU_Checkpoint)
310
- desc += ", requires_grad={}".format(self._modulesArgs[key].requires_grad)
311
- desc += ", device={}".format(self._modulesArgs[key].gpu)
312
-
313
- child_lines.append("({}{}) {}".format(key, desc, mod_str))
314
-
456
+ desc += f", alias={self._modulesArgs[key].alias}"
457
+ desc += f", in_channels={self._modulesArgs[key].in_channels}"
458
+ desc += f", in_is_channel={self._modulesArgs[key].in_is_channel}"
459
+ desc += f", out_channels={self._modulesArgs[key].out_channels}"
460
+ desc += f", out_is_channel={self._modulesArgs[key].out_is_channel}"
461
+ desc += f", is_end={self._modulesArgs[key]._isEnd}"
462
+ desc += f", isInCheckpoint={self._modulesArgs[key].isCheckpoint}"
463
+ desc += f", isInGPU_Checkpoint={self._modulesArgs[key].isGPU_Checkpoint}"
464
+ desc += f", requires_grad={self._modulesArgs[key].requires_grad}"
465
+ desc += f", device={self._modulesArgs[key].gpu}"
466
+
467
+ child_lines.append(f"({key}{desc}) {mod_str}")
468
+
315
469
  lines = extra_lines + child_lines
316
470
 
317
471
  desc = ""
@@ -319,53 +473,71 @@ class ModuleArgsDict(torch.nn.Module, ABC):
319
473
  if len(extra_lines) == 1 and not child_lines:
320
474
  desc += extra_lines[0]
321
475
  else:
322
- desc += '\n ' + '\n '.join(lines) + '\n'
476
+ desc += "\n " + "\n ".join(lines) + "\n"
477
+
478
+ return f"{self._get_name()}({desc})"
323
479
 
324
- return "{}({})".format(self._get_name(), desc)
325
-
326
480
  def __getitem__(self, key: str) -> torch.nn.Module:
327
481
  module = self._modules[key]
328
- assert module, "Error {} is None".format(key)
329
- return module
482
+ if not module:
483
+ raise ValueError(f"Module '{key}' is None or missing in self._modules")
484
+ return module
330
485
 
331
486
  @_copy_to_script_wrapper
332
487
  def keys(self) -> Iterable[str]:
333
488
  return self._modules.keys()
334
489
 
335
490
  @_copy_to_script_wrapper
336
- def items(self) -> Iterable[tuple[str, Union[torch.nn.Module, None]]]:
491
+ def items(self) -> Iterable[tuple[str, torch.nn.Module | None]]:
337
492
  return self._modules.items()
338
493
 
339
494
  @_copy_to_script_wrapper
340
- def values(self) -> Iterable[Union[torch.nn.Module, None]]:
495
+ def values(self) -> Iterable[torch.nn.Module | None]:
341
496
  return self._modules.values()
342
497
 
343
- def add_module(self, name: str, module : torch.nn.Module, in_branch: list[Union[int, str]] = [0], out_branch: list[Union[int, str]] = [0], pretrained : bool = True, alias : list[str] = [], requires_grad: Union[bool, None] = None, training: Union[None, bool] = None) -> None:
498
+ def add_module(
499
+ self,
500
+ name: str,
501
+ module: torch.nn.Module,
502
+ in_branch: Sequence[int | str] = [0],
503
+ out_branch: Sequence[int | str] = [0],
504
+ pretrained: bool = True,
505
+ alias: list[str] = [],
506
+ requires_grad: bool | None = None,
507
+ training: None | bool = None,
508
+ ) -> None:
344
509
  super().add_module(name, module)
345
- self._modulesArgs[name] = ModuleArgsDict.ModuleArgs([str(value) for value in in_branch], [str(value) for value in out_branch], pretrained, alias, requires_grad, training)
346
-
347
- def getMap(self):
348
- results : dict[str, str] = {}
349
- for name, moduleArgs in self._modulesArgs.items():
510
+ self._modulesArgs[name] = ModuleArgsDict.ModuleArgs(
511
+ [str(value) for value in in_branch],
512
+ [str(value) for value in out_branch],
513
+ pretrained,
514
+ alias,
515
+ requires_grad,
516
+ training,
517
+ )
518
+
519
+ def get_mapping(self):
520
+ results: dict[str, str] = {}
521
+ for name, module_args in self._modulesArgs.items():
350
522
  module = self[name]
351
523
  if isinstance(module, ModuleArgsDict):
352
- if len(moduleArgs.alias):
353
- count = {k : 0 for k in set(module.getMap().values())}
524
+ if len(module_args.alias):
525
+ count = dict.fromkeys(set(module.get_mapping().values()), 0)
354
526
  if len(count):
355
- for k, v in module.getMap().items():
356
- alias_name = moduleArgs.alias[count[v]]
527
+ for k, v in module.get_mapping().items():
528
+ alias_name = module_args.alias[count[v]]
357
529
  if k == "":
358
- results.update({alias_name : name+"."+v})
530
+ results.update({alias_name: name + "." + v})
359
531
  else:
360
- results.update({alias_name+"."+k : name+"."+v})
361
- count[v]+=1
532
+ results.update({alias_name + "." + k: name + "." + v})
533
+ count[v] += 1
362
534
  else:
363
- for alias in moduleArgs.alias:
364
- results.update({alias : name})
535
+ for alias in module_args.alias:
536
+ results.update({alias: name})
365
537
  else:
366
- results.update({k : name+"."+v for k, v in module.getMap().items()})
538
+ results.update({k: name + "." + v for k, v in module.get_mapping().items()})
367
539
  else:
368
- for alias in moduleArgs.alias:
540
+ for alias in module_args.alias:
369
541
  results[alias] = name
370
542
  return results
371
543
 
@@ -375,24 +547,24 @@ class ModuleArgsDict(torch.nn.Module, ABC):
375
547
  if isinstance(module, ModuleArgsDict):
376
548
  module.init(init_type, init_gain)
377
549
  elif isinstance(module, torch.nn.modules.conv._ConvNd) or isinstance(module, torch.nn.Linear):
378
- if init_type == 'normal':
550
+ if init_type == "normal":
379
551
  torch.nn.init.normal_(module.weight, 0.0, init_gain)
380
- elif init_type == 'xavier':
552
+ elif init_type == "xavier":
381
553
  torch.nn.init.xavier_normal_(module.weight, gain=init_gain)
382
- elif init_type == 'kaiming':
383
- torch.nn.init.kaiming_normal_(module.weight, a=0, mode='fan_in')
384
- elif init_type == 'orthogonal':
554
+ elif init_type == "kaiming":
555
+ torch.nn.init.kaiming_normal_(module.weight, a=0, mode="fan_in")
556
+ elif init_type == "orthogonal":
385
557
  torch.nn.init.orthogonal_(module.weight, gain=init_gain)
386
558
  elif init_type == "trunc_normal":
387
559
  torch.nn.init.trunc_normal_(module.weight, std=init_gain)
388
560
  else:
389
- raise NotImplementedError('Initialization method {} is not implemented'.format(init_type))
561
+ raise NotImplementedError(f"Initialization method {init_type} is not implemented")
390
562
  if module.bias is not None:
391
563
  torch.nn.init.constant_(module.bias, 0.0)
392
564
 
393
565
  elif isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
394
566
  if module.weight is not None:
395
- torch.nn.init.normal_(module.weight, 0.0, std = init_gain)
567
+ torch.nn.init.normal_(module.weight, 0.0, std=init_gain)
396
568
  if module.bias is not None:
397
569
  torch.nn.init.constant_(module.bias, 0.0)
398
570
 
@@ -400,8 +572,7 @@ class ModuleArgsDict(torch.nn.Module, ABC):
400
572
  for module in self._modules.values():
401
573
  ModuleArgsDict.init_func(module, init_type, init_gain)
402
574
 
403
-
404
- def named_forward(self, *inputs: torch.Tensor) -> Iterator[tuple[str, torch.Tensor]]:
575
+ def named_forward(self, *inputs: torch.Tensor) -> Iterator[tuple[str, torch.Tensor]]:
405
576
  if len(inputs) > 0:
406
577
  branchs: dict[str, torch.Tensor] = {}
407
578
  for i, sinput in enumerate(inputs):
@@ -410,7 +581,10 @@ class ModuleArgsDict(torch.nn.Module, ABC):
410
581
  out = inputs[0]
411
582
  tmp = []
412
583
  for name, module in self.items():
413
- if self._modulesArgs[name].training is None or (not (self._modulesArgs[name].training and self._training == NetState.PREDICTION) and not (not self._modulesArgs[name].training and self._training == NetState.TRAIN)):
584
+ if self._modulesArgs[name].training is None or (
585
+ not (self._modulesArgs[name].training and self._training == NetState.PREDICTION)
586
+ and not (not self._modulesArgs[name].training and self._training == NetState.TRAIN)
587
+ ):
414
588
  requires_grad = self._modulesArgs[name].requires_grad
415
589
  if requires_grad is not None and module:
416
590
  module.requires_grad_(requires_grad)
@@ -418,64 +592,75 @@ class ModuleArgsDict(torch.nn.Module, ABC):
418
592
  if ib not in branchs:
419
593
  branchs[ib] = inputs[0]
420
594
  for branchs_key in branchs.keys():
421
- if str(self._modulesArgs[name].gpu) != 'cpu' and str(branchs[branchs_key].device) != "cuda:"+self._modulesArgs[name].gpu:
595
+ if (
596
+ str(self._modulesArgs[name].gpu) != "cpu"
597
+ and str(branchs[branchs_key].device) != "cuda:" + self._modulesArgs[name].gpu
598
+ ):
422
599
  branchs[branchs_key] = branchs[branchs_key].to(int(self._modulesArgs[name].gpu))
423
-
600
+
424
601
  if self._modulesArgs[name].isCheckpoint:
425
- out = checkpoint(module, *[branchs[i] for i in self._modulesArgs[name].in_branch], use_reentrant=True)
602
+ out = checkpoint(
603
+ module,
604
+ *[branchs[i] for i in self._modulesArgs[name].in_branch],
605
+ use_reentrant=True,
606
+ )
426
607
  for ob in self._modulesArgs[name].out_branch:
427
608
  branchs[ob] = out
428
609
  yield name, out
429
610
  else:
430
611
  if isinstance(module, ModuleArgsDict):
431
- for k, out in module.named_forward(*[branchs[i] for i in self._modulesArgs[name].in_branch]):
612
+ for k, out in module.named_forward(
613
+ *[branchs[i] for i in self._modulesArgs[name].in_branch]
614
+ ):
432
615
  for ob in self._modulesArgs[name].out_branch:
433
616
  if ob in module._modulesArgs[k.split(".")[0].replace(";accu;", "")].out_branch:
434
617
  tmp.append(ob)
435
618
  branchs[ob] = out
436
- yield name+"."+k, out
619
+ yield name + "." + k, out
437
620
  for ob in self._modulesArgs[name].out_branch:
438
621
  if ob not in tmp:
439
622
  branchs[ob] = out
440
623
  elif isinstance(module, torch.nn.Module):
441
- out = module(*[branchs[i] for i in self._modulesArgs[name].in_branch])
624
+ out = module(*[branchs[i] for i in self._modulesArgs[name].in_branch])
442
625
  for ob in self._modulesArgs[name].out_branch:
443
626
  branchs[ob] = out
444
627
  yield name, out
445
628
  del branchs
446
629
 
447
630
  def forward(self, *input: torch.Tensor) -> torch.Tensor:
448
- v = input
449
- for k, v in self.named_forward(*input):
631
+ _v = input
632
+ for _, _v in self.named_forward(*input):
450
633
  pass
451
- return v
634
+ return _v
452
635
 
453
- def named_parameters(self, pretrained: bool = False, recurse=False) -> Iterator[tuple[str, torch.nn.parameter.Parameter]]:
454
- for name, moduleArgs in self._modulesArgs.items():
636
+ def named_parameters(
637
+ self, pretrained: bool = False, recurse=False
638
+ ) -> Iterator[tuple[str, torch.nn.parameter.Parameter]]:
639
+ for name, module_args in self._modulesArgs.items():
455
640
  module = self[name]
456
641
  if isinstance(module, ModuleArgsDict):
457
- for k, v in module.named_parameters(pretrained = pretrained):
458
- yield name+"."+k, v
642
+ for k, v in module.named_parameters(pretrained=pretrained):
643
+ yield name + "." + k, v
459
644
  elif isinstance(module, torch.nn.Module):
460
- if not pretrained or not moduleArgs.pretrained:
461
- if moduleArgs.training is None or moduleArgs.training:
645
+ if not pretrained or not module_args.pretrained:
646
+ if module_args.training is None or module_args.training:
462
647
  for k, v in module.named_parameters():
463
- yield name+"."+k, v
648
+ yield name + "." + k, v
464
649
 
465
650
  def parameters(self, pretrained: bool = False):
466
- for _, v in self.named_parameters(pretrained = pretrained):
651
+ for _, v in self.named_parameters(pretrained=pretrained):
467
652
  yield v
468
-
469
- def named_ModuleArgsDict(self) -> Iterator[tuple[str, Self, ModuleArgs]]:
653
+
654
+ def named_module_args_dict(self) -> Iterator[tuple[str, Self, ModuleArgs]]:
470
655
  for name, module in self._modules.items():
471
656
  yield name, module, self._modulesArgs[name]
472
657
  if isinstance(module, ModuleArgsDict):
473
- for k, v, u in module.named_ModuleArgsDict():
474
- yield name+"."+k, v, u
658
+ for k, v, u in module.named_module_args_dict():
659
+ yield name + "." + k, v, u
475
660
 
476
661
  def _requires_grad(self, keys: list[str]):
477
662
  keys = keys.copy()
478
- for name, module, args in self.named_ModuleArgsDict():
663
+ for name, module, args in self.named_module_args_dict():
479
664
  requires_grad = args.requires_grad
480
665
  if requires_grad is not None:
481
666
  module.requires_grad_(requires_grad)
@@ -484,162 +669,221 @@ class ModuleArgsDict(torch.nn.Module, ABC):
484
669
  if len(keys) == 0:
485
670
  break
486
671
 
672
+
487
673
  class OutputsGroup(list):
488
674
 
489
675
  def __init__(self, measure: Measure) -> None:
490
676
  self.layers: dict[str, torch.Tensor] = {}
491
677
  self.measure = measure
492
678
 
493
- def addLayer(self, name: str, layer: torch.Tensor):
679
+ def add_layer(self, name: str, layer: torch.Tensor):
494
680
  self.layers[name] = layer
495
681
 
496
- def isDone(self):
682
+ def is_done(self):
497
683
  return len(self) == len(self.layers)
498
-
684
+
499
685
  def clear(self):
500
686
  self.layers.clear()
501
687
 
502
-
503
688
 
504
689
  class Network(ModuleArgsDict, ABC):
505
690
 
506
- def _apply_network(self, name_function : Callable[[Self], str], networks: list[str], key: str, function: Callable, *args, **kwargs) -> dict[str, object]:
507
- results : dict[str, object] = {}
691
+ def _apply_network(
692
+ self,
693
+ name_function: Callable[[Self], str],
694
+ networks: list[str],
695
+ key: str,
696
+ function: Callable,
697
+ *args,
698
+ **kwargs,
699
+ ) -> dict[str, object]:
700
+ results: dict[str, object] = {}
508
701
  for module in self.values():
509
702
  if isinstance(module, Network):
510
703
  if name_function(module) not in networks:
511
704
  networks.append(name_function(module))
512
- for k, v in module._apply_network(name_function, networks, key+"."+name_function(module), function, *args, **kwargs).items():
513
- results.update({name_function(self)+"."+k : v})
705
+ for k, v in module._apply_network(
706
+ name_function,
707
+ networks,
708
+ key + "." + name_function(module),
709
+ function,
710
+ *args,
711
+ **kwargs,
712
+ ).items():
713
+ results.update({name_function(self) + "." + k: v})
514
714
  if len([param.name for param in list(inspect.signature(function).parameters.values()) if param.name == "key"]):
515
715
  function = partial(function, key=key)
516
716
 
517
717
  results[name_function(self)] = function(self, *args, **kwargs)
518
718
  return results
519
-
520
- def _function_network():
521
- def _function_network_d(function : Callable):
522
- def new_function(self : Self, *args, **kwargs) -> dict[str, object]:
523
- return self._apply_network(lambda network: network.getName(), [], self.getName(), function, *args, **kwargs)
719
+
720
+ def _function_network(): # type: ignore[misc]
721
+ def _function_network_d(function: Callable):
722
+ def new_function(self: Self, *args, **kwargs) -> dict[str, object]:
723
+ return self._apply_network(
724
+ lambda network: network.get_name(),
725
+ [],
726
+ self.get_name(),
727
+ function,
728
+ *args,
729
+ **kwargs,
730
+ )
731
+
524
732
  return new_function
733
+
525
734
  return _function_network_d
526
735
 
527
- def __init__( self,
528
- in_channels : int = 1,
529
- optimizer: Union[OptimizerLoader, None] = None,
530
- schedulers: Union[dict[str, LRSchedulersLoader], None] = None,
531
- outputsCriterions: Union[dict[str, TargetCriterionsLoader], None] = None,
532
- patch : Union[ModelPatch, None] = None,
533
- nb_batch_per_step : int = 1,
534
- init_type : str = "normal",
535
- init_gain : float = 0.02,
536
- dim : int = 3) -> None:
736
+ def __init__(
737
+ self,
738
+ in_channels: int = 1,
739
+ optimizer: OptimizerLoader | None = None,
740
+ schedulers: dict[str, LRSchedulersLoader] | None = None,
741
+ outputs_criterions: dict[str, TargetCriterionsLoader] | None = None,
742
+ patch: ModelPatch | None = None,
743
+ nb_batch_per_step: int = 1,
744
+ init_type: str = "normal",
745
+ init_gain: float = 0.02,
746
+ dim: int = 3,
747
+ ) -> None:
537
748
  super().__init__()
538
749
  self.name = self.__class__.__name__
539
750
  self.in_channels = in_channels
540
- self.optimizerLoader = optimizer
541
- self.optimizer : Union[torch.optim.Optimizer, None] = None
751
+ self.optimizerLoader = optimizer
752
+ self.optimizer: torch.optim.Optimizer | None = None
542
753
 
543
- self.LRSchedulersLoader = schedulers
544
- self.schedulers : Union[dict[torch.optim.lr_scheduler._LRScheduler, int], None] = {}
754
+ self.lr_schedulers_loader = schedulers
755
+ self.schedulers: dict[torch.optim.lr_scheduler._LRScheduler, int] = {}
545
756
 
546
- self.outputsCriterionsLoader = outputsCriterions
547
- self.measure : Union[Measure, None] = None
757
+ self.outputs_criterions_loader = outputs_criterions
758
+ self.measure: Measure | None = None
548
759
 
549
760
  self.patch = patch
550
761
 
551
762
  self.nb_batch_per_step = nb_batch_per_step
552
- self.init_type = init_type
553
- self.init_gain = init_gain
763
+ self.init_type = init_type
764
+ self.init_gain = init_gain
554
765
  self.dim = dim
555
766
  self._it = 0
556
767
  self._nb_lr_update = 0
557
- self.outputsGroup : list[OutputsGroup]= []
768
+ self.outputsGroup: list[OutputsGroup] = []
558
769
 
559
770
  @_function_network()
560
771
  def state_dict(self) -> dict[str, OrderedDict]:
561
- destination = OrderedDict()
562
- destination._metadata = OrderedDict()
563
- destination._metadata[""] = local_metadata = dict(version=self._version)
772
+ destination: OrderedDict[str, Any] = OrderedDict()
773
+ local_metadata = {"version": self._version}
774
+ # destination["_metadata"] = OrderedDict({"": local_metadata})
564
775
  self._save_to_state_dict(destination, "", False)
565
776
  for name, module in self._modules.items():
566
777
  if module is not None:
567
- if not isinstance(module, Network):
568
- module.state_dict(destination=destination, prefix="" + name + '.', keep_vars=False)
778
+ if not isinstance(module, Network):
779
+ module.state_dict(destination=destination, prefix="" + name + ".", keep_vars=False)
569
780
  for hook in self._state_dict_hooks.values():
570
781
  hook_result = hook(self, destination, "", local_metadata)
571
782
  if hook_result is not None:
572
783
  destination = hook_result
573
784
  return destination
574
-
785
+
575
786
  def load_state_dict(self, state_dict: dict[str, torch.Tensor]):
576
787
  missing_keys: list[str] = []
577
788
  unexpected_keys: list[str] = []
578
789
  error_msgs: list[str] = []
579
790
 
580
- metadata = getattr(state_dict, '_metadata', None)
791
+ metadata = getattr(state_dict, "_metadata", None)
581
792
  state_dict = state_dict.copy()
582
793
  if metadata is not None:
583
- state_dict._metadata = metadata
794
+ state_dict["_metadata"] = metadata
584
795
 
585
- def load(module: torch.nn.Module, prefix=''):
796
+ def load(module: torch.nn.Module, prefix=""):
586
797
  local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
587
798
  module._load_from_state_dict(
588
- state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
799
+ state_dict,
800
+ prefix,
801
+ local_metadata,
802
+ True,
803
+ missing_keys,
804
+ unexpected_keys,
805
+ error_msgs,
806
+ )
589
807
  for name, child in module._modules.items():
590
808
  if child is not None:
591
809
  if not isinstance(child, Network):
592
810
  if isinstance(child, torch.nn.modules.conv._ConvNd) or isinstance(module, torch.nn.Linear):
593
811
 
594
812
  current_size = child.weight.shape[0]
595
- last_size = state_dict[prefix + name+".weight"].shape[0]
813
+ last_size = state_dict[prefix + name + ".weight"].shape[0]
596
814
 
597
815
  if current_size != last_size:
598
- print("Warning: The size of '{}' has changed from {} to {}. Please check for potential impacts".format(prefix + name, last_size, current_size))
816
+ print(
817
+ f"Warning: The size of '{prefix + name}' has changed from {last_size}"
818
+ f" to {current_size}. Please check for potential impacts"
819
+ )
599
820
  ModuleArgsDict.init_func(child, self.init_type, self.init_gain)
600
821
 
601
822
  with torch.no_grad():
602
- child.weight[:last_size] = state_dict[prefix + name+".weight"]
823
+ child.weight[:last_size] = state_dict[prefix + name + ".weight"]
603
824
  if child.bias is not None:
604
- child.bias[:last_size] = state_dict[prefix + name+".bias"]
825
+ child.bias[:last_size] = state_dict[prefix + name + ".bias"]
605
826
  return
606
- load(child, prefix + name + '.')
827
+ load(child, prefix + name + ".")
607
828
 
608
829
  load(self)
609
- del load
610
830
 
611
831
  if len(unexpected_keys) > 0:
832
+ formatted_keys = ", ".join(f'"{k}"' for k in unexpected_keys)
612
833
  error_msgs.insert(
613
- 0, 'Unexpected key(s) in state_dict: {}. '.format(
614
- ', '.join('"{}"'.format(k) for k in unexpected_keys)))
834
+ 0,
835
+ f"Unexpected key(s) in state_dict: {formatted_keys}.",
836
+ )
615
837
  if len(missing_keys) > 0:
838
+ formatted_keys = ", ".join(f'"{k}"' for k in missing_keys)
616
839
  error_msgs.insert(
617
- 0, 'Missing key(s) in state_dict: {}. '.format(
618
- ', '.join('"{}"'.format(k) for k in missing_keys)))
840
+ 0,
841
+ f"Missing key(s) in state_dict: {formatted_keys}.",
842
+ )
619
843
 
620
844
  if len(error_msgs) > 0:
621
- raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
622
- self.__class__.__name__, "\n\t".join(error_msgs)))
623
-
845
+ formatted_errors = "\n\t".join(error_msgs)
846
+ raise RuntimeError(
847
+ f"Error(s) in loading state_dict for {self.__class__.__name__}:\n\t{formatted_errors}",
848
+ )
849
+
624
850
  def apply(self, fn: Callable[[torch.nn.Module], None]) -> None:
625
851
  for module in self.children():
626
852
  if not isinstance(module, Network):
627
853
  module.apply(fn)
628
854
  fn(self)
629
-
855
+
630
856
  @_function_network()
631
- def load(self, state_dict : dict[str, dict[str, torch.Tensor]], init: bool = True, ema : bool =False):
857
+ def load(
858
+ self,
859
+ state_dict: dict[str, dict[str, torch.Tensor] | int],
860
+ init: bool = True,
861
+ ema: bool = False,
862
+ ):
632
863
  if init:
633
- self.apply(partial(ModuleArgsDict.init_func, init_type=self.init_type, init_gain=self.init_gain))
864
+ self.apply(
865
+ partial(
866
+ ModuleArgsDict.init_func,
867
+ init_type=self.init_type,
868
+ init_gain=self.init_gain,
869
+ )
870
+ )
634
871
  name = "Model" + ("_EMA" if ema else "")
635
872
  if name in state_dict:
636
- model_state_dict_tmp = {k.split(".")[-1] : v for k, v in state_dict[name].items()}[self.getName()]
637
- map = self.getMap()
638
- model_state_dict : OrderedDict[str, torch.Tensor] = OrderedDict()
639
-
873
+ value = state_dict[name]
874
+ model_state_dict_tmp = {}
875
+ if isinstance(value, dict):
876
+ model_state_dict_tmp = {k.split(".")[-1]: v for k, v in value.items()}[self.get_name()]
877
+ modules_name = self.get_mapping()
878
+ model_state_dict: OrderedDict[str, torch.Tensor] = OrderedDict()
879
+
640
880
  for alias in model_state_dict_tmp.keys():
641
881
  prefix = ".".join(alias.split(".")[:-1])
642
- alias_list = [(".".join(prefix.split(".")[:len(i.split("."))]), v) for i, v in map.items() if prefix.startswith(i)]
882
+ alias_list = [
883
+ (".".join(prefix.split(".")[: len(i.split("."))]), v)
884
+ for i, v in modules_name.items()
885
+ if prefix.startswith(i)
886
+ ]
643
887
 
644
888
  if len(alias_list):
645
889
  for a, b in alias_list:
@@ -648,19 +892,34 @@ class Network(ModuleArgsDict, ABC):
648
892
  else:
649
893
  model_state_dict[alias] = model_state_dict_tmp[alias]
650
894
  self.load_state_dict(model_state_dict)
651
- if "{}_optimizer_state_dict".format(self.getName()) in state_dict and self.optimizer:
652
- self.optimizer.load_state_dict(state_dict['{}_optimizer_state_dict'.format(self.getName())])
653
- if "{}_it".format(self.getName()) in state_dict:
654
- self._it = int(state_dict["{}_it".format(self.getName())])
655
- if "{}_nb_lr_update".format(self.getName()) in state_dict:
656
- self._nb_lr_update = int(state_dict["{}_nb_lr_update".format(self.getName())])
895
+ if f"{self.get_name()}_optimizer_state_dict" in state_dict and self.optimizer:
896
+ self.optimizer.load_state_dict(state_dict[f"{self.get_name()}_optimizer_state_dict"])
897
+ if f"{self.get_name()}_it" in state_dict:
898
+ _it = state_dict.get(f"{self.get_name()}_it")
899
+ if isinstance(_it, int):
900
+ self._it = _it
901
+ if f"{self.get_name()}_nb_lr_update" in state_dict:
902
+ _nb_lr_update = state_dict.get(f"{self.get_name()}_nb_lr_update")
903
+ if isinstance(_nb_lr_update, int):
904
+ self._nb_lr_update = _nb_lr_update
905
+
657
906
  for scheduler in self.schedulers:
658
907
  if scheduler.last_epoch == -1:
659
908
  scheduler.last_epoch = self._nb_lr_update
660
909
  self.initialized()
661
910
 
662
- def _compute_channels_trace(self, module : ModuleArgsDict, in_channels : int, gradient_checkpoints: Union[list[str], None], gpu_checkpoints: Union[list[str], None], name: Union[str, None] = None, in_is_channel = True, out_channels : Union[int, None] = None, out_is_channel = True) -> tuple[int, bool, int, bool]:
663
-
911
+ def _compute_channels_trace(
912
+ self,
913
+ module: ModuleArgsDict,
914
+ in_channels: int,
915
+ gradient_checkpoints: list[str] | None,
916
+ gpu_checkpoints: list[str] | None,
917
+ name: str | None = None,
918
+ in_is_channel: bool = True,
919
+ out_channels: int | None = None,
920
+ out_is_channel: bool = True,
921
+ ) -> tuple[int, bool, int | None, bool]:
922
+
664
923
  for k1, v1 in module.items():
665
924
  if isinstance(v1, ModuleArgsDict):
666
925
  for t in module._modulesArgs[k1].out_branch:
@@ -680,25 +939,34 @@ class Network(ModuleArgsDict, ABC):
680
939
  if hasattr(v, "in_features"):
681
940
  if v.in_features:
682
941
  in_channels = v.in_features
683
- key = name+"."+k if name else k
942
+ key = name + "." + k if name else k
684
943
 
685
944
  if gradient_checkpoints:
686
945
  if key in gradient_checkpoints:
687
946
  module._modulesArgs[k].isCheckpoint = True
688
-
947
+
689
948
  if gpu_checkpoints:
690
949
  if key in gpu_checkpoints:
691
950
  module._modulesArgs[k].isGPU_Checkpoint = True
692
-
951
+
693
952
  module._modulesArgs[k].in_channels = in_channels
694
953
  module._modulesArgs[k].in_is_channel = in_is_channel
695
-
954
+
696
955
  if isinstance(v, ModuleArgsDict):
697
- in_channels, in_is_channel, out_channels, out_is_channel = self._compute_channels_trace(v, in_channels, gradient_checkpoints, gpu_checkpoints, key, in_is_channel, out_channels, out_is_channel)
698
-
956
+ in_channels, in_is_channel, out_channels, out_is_channel = self._compute_channels_trace(
957
+ v,
958
+ in_channels,
959
+ gradient_checkpoints,
960
+ gpu_checkpoints,
961
+ key,
962
+ in_is_channel,
963
+ out_channels,
964
+ out_is_channel,
965
+ )
966
+
699
967
  if v.__class__.__name__ == "ToChannels":
700
968
  out_is_channel = True
701
-
969
+
702
970
  if v.__class__.__name__ == "ToFeatures":
703
971
  out_is_channel = False
704
972
 
@@ -712,26 +980,32 @@ class Network(ModuleArgsDict, ABC):
712
980
  module._modulesArgs[k].out_channels = out_channels
713
981
  module._modulesArgs[k].out_is_channel = out_is_channel
714
982
 
715
- in_channels = out_channels
983
+ in_channels = out_channels if out_channels is not None else in_channels
716
984
  in_is_channel = out_is_channel
717
-
985
+
718
986
  return in_channels, in_is_channel, out_channels, out_is_channel
719
987
 
720
988
  @_function_network()
721
- def init(self, autocast : bool, state : State, group_dest: list[str], key: str) -> None:
722
- if self.outputsCriterionsLoader:
723
- self.measure = Measure(key, self.outputsCriterionsLoader)
989
+ def init(self, autocast: bool, state: State, group_dest: list[str], key: str) -> None:
990
+ if self.outputs_criterions_loader:
991
+ self.measure = Measure(key, self.outputs_criterions_loader)
724
992
  self.measure.init(self, group_dest)
993
+ if self.patch is not None:
994
+ self.patch.init(f"{konfai_root()}.Model.{key}.Patch")
725
995
  if state != State.PREDICTION:
726
996
  self.scaler = torch.amp.GradScaler("cuda", enabled=autocast)
727
997
  if self.optimizerLoader:
728
- self.optimizer = self.optimizerLoader.getOptimizer(key, self.parameters(state == State.TRANSFER_LEARNING))
998
+ self.optimizer = self.optimizerLoader.get_optimizer(
999
+ key, self.parameters(state == State.TRANSFER_LEARNING)
1000
+ )
729
1001
  self.optimizer.zero_grad()
730
1002
 
731
- if self.LRSchedulersLoader and self.optimizer:
732
- for schedulers_classname, schedulers in self.LRSchedulersLoader.items():
733
- self.schedulers[schedulers.getschedulers(key, schedulers_classname, self.optimizer)] = schedulers.nb_step
734
-
1003
+ if self.lr_schedulers_loader and self.optimizer:
1004
+ for schedulers_classname, schedulers in self.lr_schedulers_loader.items():
1005
+ self.schedulers[schedulers.getschedulers(key, schedulers_classname, self.optimizer)] = (
1006
+ schedulers.nb_step
1007
+ )
1008
+
735
1009
  def initialized(self):
736
1010
  pass
737
1011
 
@@ -740,80 +1014,101 @@ class Network(ModuleArgsDict, ABC):
740
1014
  self.patch.load(inputs[0].shape[2:])
741
1015
  accumulators: dict[str, Accumulator] = {}
742
1016
 
743
- patchIterator = self.patch.disassemble(*inputs)
1017
+ patch_iterator = self.patch.disassemble(*inputs)
744
1018
  buffer = []
745
- for i, patch_input in enumerate(patchIterator):
746
- for (name, output_layer) in super().named_forward(*patch_input):
747
- yield "{}{}".format(";accu;", name), output_layer
1019
+ for i, patch_input in enumerate(patch_iterator):
1020
+ for name, output_layer in super().named_forward(*patch_input):
1021
+ yield f";accu;{name}", output_layer
748
1022
  buffer.append((name.split(".")[0], output_layer))
749
1023
  if len(buffer) == 2:
750
1024
  if buffer[0][0] != buffer[1][0]:
751
1025
  if self._modulesArgs[buffer[0][0]]._isEnd:
752
1026
  if buffer[0][0] not in accumulators:
753
- accumulators[buffer[0][0]] = Accumulator(self.patch.getPatch_slices(), self.patch.patch_size, self.patch.patchCombine)
754
- accumulators[buffer[0][0]].addLayer(i, buffer[0][1])
1027
+ accumulators[buffer[0][0]] = Accumulator(
1028
+ self.patch.get_patch_slices(),
1029
+ self.patch.patch_size,
1030
+ self.patch.patch_combine,
1031
+ )
1032
+ accumulators[buffer[0][0]].add_layer(i, buffer[0][1])
755
1033
  buffer.pop(0)
756
1034
  if self._modulesArgs[buffer[0][0]]._isEnd:
757
1035
  if buffer[0][0] not in accumulators:
758
- accumulators[buffer[0][0]] = Accumulator(self.patch.getPatch_slices(), self.patch.patch_size, self.patch.patchCombine)
759
- accumulators[buffer[0][0]].addLayer(i, buffer[0][1])
1036
+ accumulators[buffer[0][0]] = Accumulator(
1037
+ self.patch.get_patch_slices(),
1038
+ self.patch.patch_size,
1039
+ self.patch.patch_combine,
1040
+ )
1041
+ accumulators[buffer[0][0]].add_layer(i, buffer[0][1])
760
1042
  for name, accumulator in accumulators.items():
761
1043
  yield name, accumulator.assemble()
762
1044
  else:
763
- for (name, output_layer) in super().named_forward(*inputs):
1045
+ for name, output_layer in super().named_forward(*inputs):
764
1046
  yield name, output_layer
765
1047
 
766
-
767
- def get_layers(self, inputs : list[torch.Tensor], layers_name: list[str]) -> Iterator[tuple[str, torch.Tensor, Union[Patch_Indexed, None]]]:
1048
+ def get_layers(
1049
+ self, inputs: list[torch.Tensor], layers_name: list[str]
1050
+ ) -> Iterator[tuple[str, torch.Tensor, PatchIndexed | None]]:
768
1051
  layers_name = layers_name.copy()
769
- output_layer_accumulator : dict[str, Accumulator] = {}
770
- output_layer_patch_indexed : dict[str, Patch_Indexed] = {}
1052
+ output_layer_accumulator: dict[str, Accumulator] = {}
1053
+ output_layer_patch_indexed: dict[str, PatchIndexed] = {}
771
1054
  it = 0
772
1055
  debug = "KONFAI_DEBUG" in os.environ
773
- for (nameTmp, output_layer) in self.named_forward(*inputs):
774
- name = nameTmp.replace(";accu;", "")
1056
+ for name_tmp, output_layer in self.named_forward(*inputs):
1057
+ name = name_tmp.replace(";accu;", "")
775
1058
  if debug:
776
1059
  if "KONFAI_DEBUG_LAST_LAYER" in os.environ:
777
- os.environ["KONFAI_DEBUG_LAST_LAYER"] = "{}|{}:{}:{}".format(os.environ["KONFAI_DEBUG_LAST_LAYER"], name, getGPUMemory(output_layer.device), str(output_layer.device).replace("cuda:", ""))
1060
+ os.environ["KONFAI_DEBUG_LAST_LAYER"] = f"{os.environ['KONFAI_DEBUG_LAST_LAYER']}|{name}:"
1061
+ f"{get_gpu_memory(output_layer.device)}:"
1062
+ f"{str(output_layer.device).replace('cuda:', '')}"
778
1063
  else:
779
- os.environ["KONFAI_DEBUG_LAST_LAYER"] = "{}:{}:{}".format(name, getGPUMemory(output_layer.device), str(output_layer.device).replace("cuda:", ""))
1064
+ os.environ["KONFAI_DEBUG_LAST_LAYER"] = (
1065
+ f"{name}:{get_gpu_memory(output_layer.device)}:{str(output_layer.device).replace('cuda:', '')}"
1066
+ )
780
1067
  it += 1
781
- if name in layers_name or nameTmp in layers_name:
782
- if ";accu;" in nameTmp:
1068
+ if name in layers_name or name_tmp in layers_name:
1069
+ if ";accu;" in name_tmp:
783
1070
  if name not in output_layer_patch_indexed:
784
- networkName = nameTmp.split(".;accu;")[-2].split(".")[-1] if ".;accu;" in nameTmp else nameTmp.split(";accu;")[-2].split(".")[-1]
1071
+ network_name = (
1072
+ name_tmp.split(".;accu;")[-2].split(".")[-1]
1073
+ if ".;accu;" in name_tmp
1074
+ else name_tmp.split(";accu;")[-2].split(".")[-1]
1075
+ )
785
1076
  module = self
786
1077
  network = None
787
- if networkName == "":
1078
+ if network_name == "":
788
1079
  network = module
789
1080
  else:
790
1081
  for n in name.split("."):
791
1082
  module = module[n]
792
- if isinstance(module, Network) and n == networkName:
1083
+ if isinstance(module, Network) and n == network_name:
793
1084
  network = module
794
1085
  break
795
-
1086
+
796
1087
  if network and network.patch:
797
- output_layer_patch_indexed[name] = Patch_Indexed(network.patch, 0)
1088
+ output_layer_patch_indexed[name] = PatchIndexed(network.patch, 0)
798
1089
 
799
1090
  if name not in output_layer_accumulator:
800
- output_layer_accumulator[name] = Accumulator(output_layer_patch_indexed[name].patch.getPatch_slices(0), output_layer_patch_indexed[name].patch.patch_size, output_layer_patch_indexed[name].patch.patchCombine)
801
-
802
- if nameTmp in layers_name:
803
- output_layer_accumulator[name].addLayer(output_layer_patch_indexed[name].index, output_layer)
1091
+ output_layer_accumulator[name] = Accumulator(
1092
+ output_layer_patch_indexed[name].patch.get_patch_slices(0),
1093
+ output_layer_patch_indexed[name].patch.patch_size,
1094
+ output_layer_patch_indexed[name].patch.patch_combine,
1095
+ )
1096
+
1097
+ if name_tmp in layers_name:
1098
+ output_layer_accumulator[name].add_layer(output_layer_patch_indexed[name].index, output_layer)
804
1099
  output_layer_patch_indexed[name].index += 1
805
- if output_layer_accumulator[name].isFull():
1100
+ if output_layer_accumulator[name].is_full():
806
1101
  output_layer = output_layer_accumulator[name].assemble()
807
1102
  output_layer_accumulator.pop(name)
808
1103
  output_layer_patch_indexed.pop(name)
809
- layers_name.remove(nameTmp)
810
- yield nameTmp, output_layer, None
1104
+ layers_name.remove(name_tmp)
1105
+ yield name_tmp, output_layer, None
811
1106
 
812
1107
  if name in layers_name:
813
- if ";accu;" in nameTmp:
1108
+ if ";accu;" in name_tmp:
814
1109
  yield name, output_layer, output_layer_patch_indexed[name]
815
1110
  output_layer_patch_indexed[name].index += 1
816
- if output_layer_patch_indexed[name].isFull():
1111
+ if output_layer_patch_indexed[name].is_full():
817
1112
  output_layer_patch_indexed.pop(name)
818
1113
  layers_name.remove(name)
819
1114
  else:
@@ -823,62 +1118,85 @@ class Network(ModuleArgsDict, ABC):
823
1118
  if not len(layers_name):
824
1119
  break
825
1120
 
826
- def init_outputsGroup(self):
827
- metric_tmp = {network.measure : network.measure.outputsCriterions.keys() for network in self.getNetworks().values() if network.measure}
1121
+ def init_outputs_group(self):
1122
+ metric_tmp = {
1123
+ network.measure: network.measure.outputs_criterions.keys()
1124
+ for network in self.get_networks().values()
1125
+ if network.measure
1126
+ }
828
1127
  for k, v in metric_tmp.items():
829
1128
  for a in v:
830
1129
  outputs_group = OutputsGroup(k)
831
1130
  outputs_group.append(a)
832
- for targetsGroup in k.outputsCriterions[a].keys():
833
- if ":" in targetsGroup:
834
- outputs_group.append(targetsGroup.replace(":", "."))
1131
+ for targets_group in k.outputs_criterions[a].keys():
1132
+ if ":" in targets_group:
1133
+ outputs_group.append(targets_group.replace(":", "."))
835
1134
 
836
1135
  self.outputsGroup.append(outputs_group)
837
1136
 
838
- def forward(self, data_dict: dict[tuple[str, bool], torch.Tensor], output_layers: list[str] = []) -> list[tuple[str, torch.Tensor]]:
1137
+ def forward(
1138
+ self,
1139
+ data_dict: dict[tuple[str, bool], torch.Tensor],
1140
+ output_layers: list[str] = [],
1141
+ ) -> list[tuple[str, torch.Tensor]]:
839
1142
  if not len(self.outputsGroup) and not len(output_layers):
840
1143
  return []
841
1144
 
842
- self.resetLoss()
1145
+ self.reset_loss()
843
1146
  results = []
844
1147
  measure_output_layers = set()
845
- for outputs_group in self.outputsGroup:
846
- for name in outputs_group:
1148
+ for _outputs_group in self.outputsGroup:
1149
+ for name in _outputs_group:
847
1150
  measure_output_layers.add(name)
848
- for name, layer, patch_indexed in self.get_layers([v for k, v in data_dict.items() if k[1]], list(set(list(measure_output_layers)+output_layers))):
849
-
1151
+ for name, layer, patch_indexed in self.get_layers(
1152
+ [v for k, v in data_dict.items() if k[1]],
1153
+ list(set(list(measure_output_layers) + output_layers)),
1154
+ ):
1155
+
850
1156
  outputs_group = [outputs_group for outputs_group in self.outputsGroup if name in outputs_group]
851
1157
  if len(outputs_group) > 0:
852
1158
  if patch_indexed is None:
853
- targets = {k[0] : v for k, v in data_dict.items()}
1159
+ targets = {k[0]: v for k, v in data_dict.items()}
854
1160
  nb = 1
855
1161
  else:
856
- targets = {k[0] : patch_indexed.patch.getData(v, patch_indexed.index, 0, False) for k, v in data_dict.items()}
857
- nb = patch_indexed.patch.getSize(0)
1162
+ targets = {
1163
+ k[0]: patch_indexed.patch.get_data(v, patch_indexed.index, 0, False)
1164
+ for k, v in data_dict.items()
1165
+ }
1166
+ nb = patch_indexed.patch.get_size(0)
858
1167
 
859
1168
  for output_group in outputs_group:
860
- output_group.addLayer(name, layer)
861
- if output_group.isDone():
862
- targets.update({k.replace(".", ":"): v for k, v in output_group.layers.items() if k != output_group[0]})
863
- output_group.measure.update(output_group[0], output_group.layers[output_group[0]], targets, self._it, nb, self.training)
1169
+ output_group.add_layer(name, layer)
1170
+ if output_group.is_done():
1171
+ targets.update(
1172
+ {k.replace(".", ":"): v for k, v in output_group.layers.items() if k != output_group[0]}
1173
+ )
1174
+ output_group.measure.update(
1175
+ output_group[0],
1176
+ output_group.layers[output_group[0]],
1177
+ targets,
1178
+ self._it,
1179
+ nb,
1180
+ self.training,
1181
+ )
864
1182
  output_group.clear()
865
1183
  if name in output_layers:
866
1184
  results.append((name, layer))
867
1185
  return results
868
1186
 
869
1187
  @_function_network()
870
- def resetLoss(self):
1188
+ def reset_loss(self):
871
1189
  if self.measure:
872
- self.measure.resetLoss()
1190
+ self.measure.reset_loss()
873
1191
 
874
1192
  @_function_network()
875
1193
  def backward(self, model: Self):
876
1194
  if self.measure:
877
1195
  if self.scaler and self.optimizer:
878
- model._requires_grad(list(self.measure.outputsCriterions.keys()))
879
- for loss in self.measure.getLoss():
1196
+ model._requires_grad(list(self.measure.outputs_criterions.keys()))
1197
+ for loss in self.measure.get_loss():
880
1198
  self.scaler.scale(loss / self.nb_batch_per_step).backward()
881
-
1199
+
882
1200
  if self._it % self.nb_batch_per_step == 0:
883
1201
  self.scaler.step(self.optimizer)
884
1202
  self.scaler.update()
@@ -887,93 +1205,133 @@ class Network(ModuleArgsDict, ABC):
887
1205
 
888
1206
  @_function_network()
889
1207
  def update_lr(self):
890
- self._nb_lr_update+=1
1208
+ self._nb_lr_update += 1
891
1209
  step = 0
892
- scheduler = None
893
- for scheduler, value in self.schedulers.items():
894
- if value is None or (self._nb_lr_update >= step and self._nb_lr_update < step+value):
1210
+ _scheduler = None
1211
+ for _scheduler, value in self.schedulers.items():
1212
+ if value is None or (self._nb_lr_update >= step and self._nb_lr_update < step + value):
895
1213
  break
896
1214
  step += value
897
- if scheduler:
898
- if scheduler.__class__.__name__ == 'ReduceLROnPlateau':
1215
+ if _scheduler:
1216
+ if _scheduler.__class__.__name__ == "ReduceLROnPlateau":
899
1217
  if self.measure:
900
- scheduler.step(sum(self.measure.getLastValues(0).values()))
1218
+ _scheduler.step(sum(self.measure.get_last_values(0).values()))
901
1219
  else:
902
- scheduler.step()
1220
+ _scheduler.step()
903
1221
 
904
1222
  @_function_network()
905
- def getNetworks(self) -> Self:
1223
+ def get_networks(self) -> Self:
906
1224
  return self
907
1225
 
908
- def to(module : ModuleArgsDict, device: int):
1226
+ @staticmethod
1227
+ def to(module: ModuleArgsDict, device: int):
909
1228
  if "device" not in os.environ:
910
1229
  os.environ["device"] = str(device)
911
1230
  for k, v in module.items():
912
1231
  if module._modulesArgs[k].gpu == "cpu":
913
1232
  if module._modulesArgs[k].isGPU_Checkpoint:
914
- os.environ["device"] = str(int(os.environ["device"])+1)
915
- module._modulesArgs[k].gpu = str(getDevice(int(os.environ["device"])))
1233
+ os.environ["device"] = str(int(os.environ["device"]) + 1)
1234
+ module._modulesArgs[k].gpu = str(get_device(int(os.environ["device"])))
916
1235
  if isinstance(v, ModuleArgsDict):
917
1236
  v = Network.to(v, int(os.environ["device"]))
918
1237
  else:
919
- v = v.to(getDevice(int(os.environ["device"])))
1238
+ v = v.to(get_device(int(os.environ["device"])))
920
1239
  if isinstance(module, Network):
921
1240
  if module.optimizer is not None:
922
1241
  for state in module.optimizer.state.values():
923
1242
  for k, v in state.items():
924
1243
  if isinstance(v, torch.Tensor):
925
- state[k] = v.to(getDevice(int(os.environ["device"])))
1244
+ state[k] = v.to(get_device(int(os.environ["device"])))
926
1245
  return module
927
-
928
- def getName(self) -> str:
1246
+
1247
+ def get_name(self) -> str:
929
1248
  return self.name
930
-
931
- def setName(self, name: str) -> Self:
1249
+
1250
+ def set_name(self, name: str) -> Self:
932
1251
  self.name = name
933
1252
  return self
934
-
935
- def setState(self, state: NetState):
1253
+
1254
+ def set_state(self, state: NetState):
936
1255
  for module in self.modules():
937
1256
  if isinstance(module, ModuleArgsDict):
938
1257
  module._training = state
939
1258
 
1259
+
940
1260
  class MinimalModel(Network):
941
-
942
- def __init__(self, model: Network, optimizer : OptimizerLoader = OptimizerLoader(),
943
- schedulers : LRSchedulersLoader = LRSchedulersLoader(),
944
- outputsCriterions: dict[str, TargetCriterionsLoader] = {"default" : TargetCriterionsLoader()},
945
- patch : Union[ModelPatch, None] = None,
946
- dim : int = 3, nb_batch_per_step = 1, init_type = "normal", init_gain = 0.02):
947
- super().__init__(1, optimizer, schedulers, outputsCriterions, patch, nb_batch_per_step, init_type, init_gain, dim)
1261
+
1262
+ def __init__(
1263
+ self,
1264
+ model: Network,
1265
+ optimizer: OptimizerLoader = OptimizerLoader(),
1266
+ schedulers: dict[str, LRSchedulersLoader] = {"default:ReduceLROnPlateau": LRSchedulersLoader(0)},
1267
+ outputs_criterions: dict[str, TargetCriterionsLoader] = {"default": TargetCriterionsLoader()},
1268
+ patch: ModelPatch | None = None,
1269
+ dim: int = 3,
1270
+ nb_batch_per_step=1,
1271
+ init_type="normal",
1272
+ init_gain=0.02,
1273
+ ):
1274
+ super().__init__(
1275
+ 1,
1276
+ optimizer,
1277
+ schedulers,
1278
+ outputs_criterions,
1279
+ patch,
1280
+ nb_batch_per_step,
1281
+ init_type,
1282
+ init_gain,
1283
+ dim,
1284
+ )
948
1285
  self.add_module("Model", model)
949
1286
 
950
- class ModelLoader():
1287
+
1288
+ class ModelLoader:
951
1289
 
952
1290
  @config("Model")
953
- def __init__(self, classpath : str = "default:segmentation.UNet.UNet") -> None:
954
- self.module, self.name = _getModule(classpath, "konfai.models")
955
-
956
- def getModel(self, train : bool = True, DL_args: Union[str, None] = None, DL_without=["optimizer", "schedulers", "nb_batch_per_step", "init_type", "init_gain"]) -> Network:
957
- if not DL_args:
958
- DL_args="{}.Model".format(KONFAI_ROOT())
959
- DL_args += "."+self.name
960
- model = config(DL_args)(getattr(importlib.import_module(self.module), self.name))(config = None, DL_without = DL_without if not train else [])
1291
+ def __init__(self, classpath: str = "default:segmentation.UNet.UNet") -> None:
1292
+ self.module, self.name = get_module(classpath, "konfai.models")
1293
+
1294
+ def get_model(
1295
+ self,
1296
+ train: bool = True,
1297
+ konfai_args: str | None = None,
1298
+ konfai_without=[
1299
+ "optimizer",
1300
+ "schedulers",
1301
+ "nb_batch_per_step",
1302
+ "init_type",
1303
+ "init_gain",
1304
+ ],
1305
+ ) -> Network:
1306
+ if not konfai_args:
1307
+ konfai_args = f"{konfai_root()}.Model"
1308
+ konfai_args += "." + self.name
1309
+ model = config(konfai_args)(getattr(importlib.import_module(self.module), self.name))(
1310
+ config=None, konfai_without=konfai_without if not train else []
1311
+ )
961
1312
  if not isinstance(model, Network):
962
- model = config(DL_args)(partial(MinimalModel, model))(config = None, DL_without = DL_without+["model"] if not train else [])
963
- model.setName(self.name)
1313
+ model = config(konfai_args)(partial(MinimalModel, model))(
1314
+ config=None, konfai_without=konfai_without + ["model"] if not train else []
1315
+ )
1316
+ model.set_name(self.name)
964
1317
  return model
965
-
966
- class CPU_Model():
1318
+
1319
+
1320
+ class CPUModel:
967
1321
 
968
1322
  def __init__(self, model: Network) -> None:
969
1323
  self.module = model
970
- self.device = torch.device('cpu')
1324
+ self.device = torch.device("cpu")
971
1325
 
972
1326
  def train(self):
973
1327
  self.module.train()
974
-
975
- def eval(self):
1328
+
1329
+ def eval(self): # noqa: A003
976
1330
  self.module.eval()
977
-
978
- def __call__(self, data_dict: dict[tuple[str, bool], torch.Tensor], output_layers: list[str] = []) -> list[tuple[str, torch.Tensor]]:
979
- return self.module(data_dict, output_layers)
1331
+
1332
+ def __call__(
1333
+ self,
1334
+ data_dict: dict[tuple[str, bool], torch.Tensor],
1335
+ output_layers: list[str] = [],
1336
+ ) -> list[tuple[str, torch.Tensor]]:
1337
+ return self.module(data_dict, output_layers)