konfai 1.1.7__py3-none-any.whl → 1.1.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- konfai/__init__.py +59 -14
- konfai/data/augmentation.py +457 -286
- konfai/data/data_manager.py +509 -290
- konfai/data/patching.py +300 -183
- konfai/data/transform.py +384 -277
- konfai/evaluator.py +309 -68
- konfai/main.py +71 -22
- konfai/metric/measure.py +341 -222
- konfai/metric/schedulers.py +24 -13
- konfai/models/classification/convNeXt.py +187 -81
- konfai/models/classification/resnet.py +272 -58
- konfai/models/generation/cStyleGan.py +233 -59
- konfai/models/generation/ddpm.py +348 -121
- konfai/models/generation/diffusionGan.py +757 -358
- konfai/models/generation/gan.py +177 -53
- konfai/models/generation/vae.py +140 -40
- konfai/models/registration/registration.py +135 -52
- konfai/models/representation/representation.py +57 -23
- konfai/models/segmentation/NestedUNet.py +339 -68
- konfai/models/segmentation/UNet.py +140 -30
- konfai/network/blocks.py +331 -187
- konfai/network/network.py +781 -423
- konfai/predictor.py +645 -240
- konfai/trainer.py +527 -216
- konfai/utils/ITK.py +191 -106
- konfai/utils/config.py +152 -95
- konfai/utils/dataset.py +326 -455
- konfai/utils/utils.py +495 -249
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/METADATA +1 -3
- konfai-1.1.9.dist-info/RECORD +38 -0
- konfai/utils/registration.py +0 -199
- konfai-1.1.7.dist-info/RECORD +0 -39
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/WHEEL +0 -0
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/entry_points.txt +0 -0
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/licenses/LICENSE +0 -0
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/top_level.txt +0 -0
konfai/network/network.py
CHANGED
```diff
@@ -1,117 +1,181 @@
-from functools import partial
 import importlib
 import inspect
 import os
-from typing import Iterable, Iterator, Callable
-from typing_extensions import Self
-import torch
 from abc import ABC
+from collections import OrderedDict
+from collections.abc import Callable, Iterable, Iterator, Sequence
+from enum import Enum
+from functools import partial
+from typing import Any, Self
+
 import numpy as np
+import torch
 from torch._jit_internal import _copy_to_script_wrapper
-from collections import OrderedDict
 from torch.utils.checkpoint import checkpoint
-from typing import Union
-from enum import Enum
 
-from konfai import …
+from konfai import konfai_root
+from konfai.data.patching import Accumulator, ModelPatch
 from konfai.metric.schedulers import Scheduler
 from konfai.utils.config import config
-from konfai.utils.utils import …
-
+from konfai.utils.utils import MeasureError, State, TrainerError, get_device, get_gpu_memory, get_module
+
 
 class NetState(Enum):
-    TRAIN = 0,
+    TRAIN = (0,)
     PREDICTION = 1
 
-
+
+class PatchIndexed:
 
     def __init__(self, patch: ModelPatch, index: int) -> None:
         self.patch = patch
         self.index = index
 
-    def …
-        return len(self.patch.…
+    def is_full(self) -> bool:
+        return len(self.patch.get_patch_slices(0)) == self.index
+
+
+class OptimizerLoader:
 
-class OptimizerLoader():
-
     @config("Optimizer")
     def __init__(self, name: str = "AdamW") -> None:
         self.name = name
-
-    def getOptimizer(self, key: str, parameter: Iterator[torch.nn.parameter.Parameter]) -> torch.optim.Optimizer:
-        return config("{}.Model.{}.Optimizer".format(KONFAI_ROOT(), key))(getattr(importlib.import_module('torch.optim'), self.name))(parameter, config = None)
-
 
-
-
+    def get_optimizer(self, key: str, parameter: Iterator[torch.nn.parameter.Parameter]) -> torch.optim.Optimizer:
+        return config(f"{konfai_root()}.Model.{key}.Optimizer")(
+            getattr(importlib.import_module("torch.optim"), self.name)
+        )(parameter, config=None)
+
+
+class LRSchedulersLoader:
+
     @config()
-    def __init__(self, nb_step…
+    def __init__(self, nb_step: int = 0) -> None:
         self.nb_step = nb_step
 
-    def getschedulers(…
+    def getschedulers(
+        self, key: str, scheduler_classname: str, optimizer: torch.optim.Optimizer
+    ) -> torch.optim.lr_scheduler._LRScheduler:
         for m in ["torch.optim.lr_scheduler", "konfai.metric.schedulers"]:
-            module, name = …
+            module, name = get_module(scheduler_classname, m)
             if hasattr(importlib.import_module(module), name):
-                return config("{}.Model.{}.schedulers.{}"…
-
+                return config(f"{konfai_root()}.Model.{key}.schedulers.{scheduler_classname}")(
+                    getattr(importlib.import_module(module), name)
+                )(optimizer, config=None)
+        raise TrainerError(
+            f"Unknown scheduler {scheduler_classname}, tried importing from: 'torch.optim.lr_scheduler' and "
+            "'konfai.metric.schedulers', but no valid match was found. "
+            "Check your YAML config or scheduler name spelling."
+        )
+
+
+class LossSchedulersLoader:
 
-class LossSchedulersLoader():
-
     @config()
-    def __init__(self, nb_step…
+    def __init__(self, nb_step: int = 0) -> None:
         self.nb_step = nb_step
 
     def getschedulers(self, key: str, scheduler_classname: str) -> torch.optim.lr_scheduler._LRScheduler:
-        return config("{}.{}"…
-
-
-
+        return config(f"{key}.{scheduler_classname}")(
+            getattr(importlib.import_module("konfai.metric.schedulers"), scheduler_classname)
+        )(config=None)
+
+
+class CriterionsAttr:
+
     @config()
-    def __init__(…
+    def __init__(
+        self,
+        schedulers: dict[str, LossSchedulersLoader] = {"default:Constant": LossSchedulersLoader(0)},
+        is_loss: bool = True,
+        group: int = 0,
+        start: int = 0,
+        stop: int | None = None,
+        accumulation: bool = False,
+    ) -> None:
         self.schedulersLoader = schedulers
         self.isTorchCriterion = True
-        self.…
-        self.…
-        self.…
+        self.is_loss = is_loss
+        self.start = start
+        self.stop = stop
         self.group = group
         self.accumulation = accumulation
         self.schedulers: dict[Scheduler, int] = {}
-
-class CriterionsLoader():
-
-    @config()
-    def __init__(self, criterionsLoader: dict[str, CriterionsAttr] = {"default:torch_nn_CrossEntropyLoss:Dice:NCC": CriterionsAttr()}) -> None:
-        self.criterionsLoader = criterionsLoader
 
-    def getCriterions(self, model_classname : str, output_group : str, target_group : str) -> dict[torch.nn.Module, CriterionsAttr]:
-        criterions = {}
-        for module_classpath, criterionsAttr in self.criterionsLoader.items():
-            module, name = _getModule(module_classpath, "konfai.metric.measure")
-            criterionsAttr.isTorchCriterion = module.startswith("torch")
-            for scheduler_classname, schedulers in criterionsAttr.schedulersLoader.items():
-                criterionsAttr.schedulers[schedulers.getschedulers("{}.Model.{}.outputsCriterions.{}.targetsCriterions.{}.criterionsLoader.{}.schedulers".format(KONFAI_ROOT(), model_classname, output_group, target_group, module_classpath), scheduler_classname)] = schedulers.nb_step
-            criterions[config("{}.Model.{}.outputsCriterions.{}.targetsCriterions.{}.criterionsLoader.{}".format(KONFAI_ROOT(), model_classname, output_group, target_group, module_classpath))(getattr(importlib.import_module(module), name))(config = None)] = criterionsAttr
-        return criterions
 
-class …
+class CriterionsLoader:
 
     @config()
-    def __init__(
-        self…
+    def __init__(
+        self,
+        criterions_loader: dict[str, CriterionsAttr] = {"default:torch:nn:CrossEntropyLoss:Dice:NCC": CriterionsAttr()},
+    ) -> None:
+        self.criterions_loader = criterions_loader
+
+    def get_criterions(
+        self, model_classname: str, output_group: str, target_group: str
+    ) -> dict[torch.nn.Module, CriterionsAttr]:
+        criterions = {}
+        for module_classpath, criterions_attr in self.criterions_loader.items():
+            module, name = get_module(module_classpath, "konfai.metric.measure")
+            criterions_attr.isTorchCriterion = module.startswith("torch")
+            for (
+                scheduler_classname,
+                schedulers,
+            ) in criterions_attr.schedulersLoader.items():
+                criterions_attr.schedulers[
+                    schedulers.getschedulers(
+                        f"{konfai_root()}.Model.{model_classname}.outputs_criterions.{output_group}"
+                        f".targets_criterions.{target_group}"
+                        f".criterions_loader.{module_classpath}.schedulers",
+                        scheduler_classname,
+                    )
+                ] = schedulers.nb_step
+            criterions[
+                config(
+                    f"{konfai_root()}.Model.{model_classname}.outputs_criterions."
+                    f"{output_group}.targets_criterions.{target_group}."
+                    f"criterions_loader.{module_classpath}"
+                )(getattr(importlib.import_module(module), name))(config=None)
+            ] = criterions_attr
+        return criterions
 
-class Measure():
 
+class TargetCriterionsLoader:
 
+    @config()
+    def __init__(
+        self,
+        targets_criterions: dict[str, CriterionsLoader] = {"default": CriterionsLoader()},
+    ) -> None:
+        self.targets_criterions = targets_criterions
+
+    def get_targets_criterions(
+        self, output_group: str, model_classname: str
+    ) -> dict[str, dict[torch.nn.Module, CriterionsAttr]]:
+        targets_criterions = {}
+        for target_group, criterions_loader in self.targets_criterions.items():
+            targets_criterions[target_group] = criterions_loader.get_criterions(
+                model_classname, output_group, target_group
+            )
+        return targets_criterions
+
+
+class Measure:
+
+    class Loss:
+
+        def __init__(
+            self,
+            name: str,
+            output_group: str,
+            target_group: str,
+            group: int,
+            is_loss: bool,
+            accumulation: bool,
+        ) -> None:
             self.name = name
-            self.…
+            self.is_loss = is_loss
             self.accumulation = accumulation
             self.output_group = output_group
             self.target_group = target_group
```
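The loader classes in this hunk all resolve a class from a configuration string with `importlib.import_module` plus `getattr`. A minimal, self-contained sketch of that resolution pattern (the `resolve` helper and its fallback-module list are ours, mirroring `getschedulers`; they are not part of konfai's API):

```python
import importlib

def resolve(classname: str, candidate_modules: list[str]) -> type:
    # Probe each candidate module in order and return the first matching
    # attribute, the way getschedulers tries torch.optim.lr_scheduler
    # before konfai.metric.schedulers.
    for module_name in candidate_modules:
        module = importlib.import_module(module_name)
        if hasattr(module, classname):
            return getattr(module, classname)
    raise ValueError(f"Unknown class {classname!r} in {candidate_modules}")

adamw_cls = resolve("AdamW", ["torch.optim"])
scheduler_cls = resolve("CosineAnnealingLR", ["torch.optim.lr_scheduler"])
```

konfai additionally wraps the resolved class in its `config(...)` decorator so constructor arguments come from the YAML tree; the sketch only shows the name-to-class step.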
```diff
@@ -120,144 +184,231 @@ class Measure():
             self._loss: list[torch.Tensor] = []
             self._weight: list[float] = []
             self._values: list[float] = []
-
-        def …
+
+        def reset_loss(self) -> None:
             self._loss.clear()
 
         def add(self, weight: float, value: torch.Tensor) -> None:
-            if self.…
-                self._loss.append(value if self.…
+            if self.is_loss or value != 0:
+                self._loss.append(value if self.is_loss else value.detach())
                 self._values.append(value.item())
                 self._weight.append(weight)
 
-        def …
-            return self._loss[-1]*self._weight[-1] if len(self._loss) else torch.zeros((1), requires_grad…
-
-        def …
-            return …
+        def get_last_loss(self) -> torch.Tensor:
+            return self._loss[-1] * self._weight[-1] if len(self._loss) else torch.zeros((1), requires_grad=True)
+
+        def get_loss(self) -> torch.Tensor:
+            return (
+                torch.stack([w * loss_value for w, loss_value in zip(self._weight, self._loss)], dim=0).mean(dim=0)
+                if len(self._loss)
+                else torch.zeros((1), requires_grad=True)
+            )
 
         def __len__(self) -> int:
             return len(self._loss)
-
-    def __init__(self, model_classname : str, outputsCriterions: dict[str, TargetCriterionsLoader]) -> None:
-        super().__init__()
-        self.outputsCriterions = {}
-        for output_group, targetCriterionsLoader in outputsCriterions.items():
-            self.outputsCriterions[output_group.replace(":", ".")] = targetCriterionsLoader.getTargetsCriterions(output_group, model_classname)
-        self._loss : dict[int, dict[str, Measure.Loss]] = {}
 
-    def …
+    def __init__(
+        self,
+        model_classname: str,
+        outputs_criterions_loader: dict[str, TargetCriterionsLoader],
+    ) -> None:
+        super().__init__()
+        self.outputs_criterions: dict[str, dict[str, dict[torch.nn.Module, CriterionsAttr]]] = {}
+        for output_group, target_criterions_loader in outputs_criterions_loader.items():
+            self.outputs_criterions[output_group.replace(":", ".")] = target_criterions_loader.get_targets_criterions(
+                output_group, model_classname
+            )
+        self._loss: dict[int, dict[str, Measure.Loss]] = {}
+
+    def init(self, model: torch.nn.Module, group_dest: list[str]) -> None:
         outputs_group_rename = {}
-
+
         modules = []
-        for i,_,_ in model.…
+        for i, _, _ in model.named_module_args_dict():
             modules.append(i)
 
-        for output_group in self.…
-            if output_group not in modules:
-                …
+        for output_group in self.outputs_criterions.keys():
+            if output_group.replace(";accu;", "") not in modules:
+                raise MeasureError(
+                    f"The output group '{output_group}' defined in 'outputs_criterions' "
+                    "does not correspond to any module in the model.",
                     f"Available modules: {modules}",
-                    "Please check that the name matches exactly a submodule or output of your model architecture."
+                    "Please check that the name matches exactly a submodule or output of your model architecture.",
                 )
-            for target_group in self.…
-                for target_group_tmp in target_group.split(";"):
+            for target_group in self.outputs_criterions[output_group]:
+                for target_group_tmp in target_group.split(";"):
                     if target_group_tmp not in group_dest:
                         raise MeasureError(
-                            f"The target_group '{target_group_tmp}' defined in …
-                            "…
-                            …
+                            f"The target_group '{target_group_tmp}' defined in "
+                            "'outputs_criterions.{output_group}.targets_criterions'"
+                            " was not found in the available destination groups.",
+                            "This target_group is expected for loss or metric computation, "
+                            "but was not loaded in 'group_dest'.",
+                            f"Please make sure that the group '{target_group_tmp}' is defined in "
+                            "'Dataset:groups_src:...:groups_dest:'{target_group_tmp}'' "
+                            "and correctly loaded from the dataset.",
+                        )
+                for criterion in self.outputs_criterions[output_group][target_group]:
+                    if not self.outputs_criterions[output_group][target_group][criterion].isTorchCriterion:
                         outputs_group_rename[output_group] = criterion.init(model, output_group, target_group)
 
-
+        outputs_criterions_bak = self.outputs_criterions.copy()
         for old, new in outputs_group_rename.items():
-            self.…
-            self.…
-        for output_group in self.…
-            for target_group in self.…
-                for criterion, …
-                    if …
-                        self._loss[…
-                        self._loss[…
-                        …
+            self.outputs_criterions.pop(old)
+            self.outputs_criterions[new] = outputs_criterions_bak[old]
+        for output_group in self.outputs_criterions:
+            for target_group in self.outputs_criterions[output_group]:
+                for criterion, criterions_attr in self.outputs_criterions[output_group][target_group].items():
+                    if criterions_attr.group not in self._loss:
+                        self._loss[criterions_attr.group] = {}
+                    self._loss[criterions_attr.group][
+                        f"{output_group}:{target_group}:{criterion.__class__.__name__}"
+                    ] = Measure.Loss(
+                        criterion.__class__.__name__,
+                        output_group,
+                        target_group,
+                        criterions_attr.group,
+                        criterions_attr.is_loss,
+                        criterions_attr.accumulation,
+                    )
+
+    def update(
+        self,
+        output_group: str,
+        output: torch.Tensor,
+        data_dict: dict[str, torch.Tensor],
+        it: int,
+        nb_patch: int,
+        training: bool,
+    ) -> None:
+        for target_group in self.outputs_criterions[output_group]:
+            target = [
+                data_dict[group].to(output[0].device).detach()
+                for group in target_group.split(";")
+                if group in data_dict
+            ]
+
+            for criterion, criterions_attr in self.outputs_criterions[output_group][target_group].items():
+                if it >= criterions_attr.start and (criterions_attr.stop is None or it <= criterions_attr.stop):
+                    scheduler = self.update_scheduler(criterions_attr.schedulers, it)
+                    self._loss[criterions_attr.group][
+                        f"{output_group}:{target_group}:{criterion.__class__.__name__}"
+                    ].add(scheduler.get_value(), criterion(output, *target))
+                    if (
+                        training
+                        and len(
+                            np.unique(
+                                [
+                                    len(loss_value)
+                                    for loss_value in self._loss[criterions_attr.group].values()
+                                    if loss_value.accumulation and loss_value.is_loss
+                                ]
+                            )
+                        )
+                        == 1
+                    ):
+                        if criterions_attr.is_loss:
+                            loss = torch.zeros((1), requires_grad=True)
+                            for v in [
+                                loss_value
+                                for loss_value in self._loss[criterions_attr.group].values()
+                                if loss_value.accumulation and loss_value.is_loss
+                            ]:
+                                loss_value = v.get_last_loss()
+                                loss = loss.to(loss_value.device) + loss_value
+                            loss = loss / nb_patch
                             loss.backward()
 
-    def …
+    def get_loss(self) -> list[torch.Tensor]:
         loss: dict[int, torch.Tensor] = {}
         for group in self._loss.keys():
-            loss[group] = torch.zeros((1), requires_grad…
+            loss[group] = torch.zeros((1), requires_grad=True)
             for v in self._loss[group].values():
-                if v.…
-                    …
-                    loss[v.group] = loss[v.group].to(…
-        return loss.values()
+                if v.is_loss and not v.accumulation:
+                    loss_value = v.get_loss()
+                    loss[v.group] = loss[v.group].to(loss_value.device) + loss_value
+        return list(loss.values())
 
-    def …
+    def reset_loss(self) -> None:
         for group in self._loss.keys():
             for v in self._loss[group].values():
-                v.…
-
-    def …
+                v.reset_loss()
+
+    def get_last_values(self, n: int = 1) -> dict[str, float]:
         result = {}
         for group in self._loss.keys():
-            result.update(…
+            result.update(
+                {
+                    name: np.nanmean(value._values[-n:] if n > 0 else value._values)
+                    for name, value in self._loss[group].items()
+                    if n < 0 or len(value._values) >= n
+                }
+            )
         return result
-
-    def …
+
+    def get_last_weights(self, n: int = 1) -> dict[str, float]:
         result = {}
         for group in self._loss.keys():
-            result.update(…
+            result.update(
+                {
+                    name: np.nanmean(value._weight[-n:] if n > 0 else value._weight)
+                    for name, value in self._loss[group].items()
+                    if n < 0 or len(value._values) >= n
+                }
+            )
         return result
 
-    def …
-        result = …
+    def format_loss(self, is_loss: bool, n: int) -> dict[str, tuple[float, float]]:
+        result = {}
         for group in self._loss.keys():
             for name, loss in self._loss[group].items():
-                if loss.…
-                    result[name] = (…
+                if loss.is_loss == is_loss and len(loss._values) >= n:
+                    result[name] = (
+                        np.nanmean(loss._weight[-n:]),
+                        np.nanmean(loss._values[-n:]),
+                    )
         return result
 
     def update_scheduler(self, schedulers: dict[Scheduler, int], it: int) -> Scheduler:
         step = 0
-
-        for …
-            if value is None or (it >= step…
+        _scheduler = None
+        for _scheduler, value in schedulers.items():
+            if value is None or (it >= step and it < step + value):
                 break
             step += value
-        if …
-            …
-
-
+        if _scheduler:
+            _scheduler.step(it - step)
+        if _scheduler is None:
+            raise NameError(
+                f"No scheduler found for iteration {it}. "
+                f"Available steps were: {list(schedulers.values())}. "
+                f"Check your configuration."
+            )
+        return _scheduler
+
+
 class ModuleArgsDict(torch.nn.Module, ABC):
-
+
     class ModuleArgs:
 
-        def __init__(…
+        def __init__(
+            self,
+            in_branch: list[str],
+            out_branch: list[str],
+            pretrained: bool,
+            alias: list[str],
+            requires_grad: bool | None,
+            training: None | bool,
+        ) -> None:
             super().__init__()
-            self.alias= alias
+            self.alias = alias
             self.pretrained = pretrained
             self.in_branch = in_branch
             self.out_branch = out_branch
-            self.in_channels = None
-            self.in_is_channel = True
-            self.out_channels = None
-            self.out_is_channel = True
+            self.in_channels: int | None = None
+            self.in_is_channel: bool = True
+            self.out_channels: int | None = None
+            self.out_is_channel: bool = True
             self.requires_grad = requires_grad
             self.isCheckpoint = False
             self.isGPU_Checkpoint = False
```
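`update_scheduler` (end of the hunk above) selects the active scheduler by walking the `{scheduler: nb_step}` entries in insertion order, where each entry owns a window of iterations and `None` means "active from here on". A standalone sketch of the same window arithmetic, using plain strings instead of konfai's `Scheduler` objects (the `pick_active` name is ours):

```python
def pick_active(windows: dict[str, int | None], it: int) -> tuple[str, int]:
    # Each entry owns `value` iterations; None means "active forever from here".
    step = 0
    for name, value in windows.items():
        if value is None or step <= it < step + value:
            return name, it - step  # local step within the active window
        step += value
    raise NameError(f"No scheduler found for iteration {it}")

assert pick_active({"warmup": 10, "cosine": None}, 3) == ("warmup", 3)
assert pick_active({"warmup": 10, "cosine": None}, 12) == ("cosine", 2)
```

The real method additionally calls `scheduler.step(it - step)` on the winner, i.e. it re-bases the global iteration counter to the window-local step before stepping.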
```diff
@@ -267,51 +418,54 @@ class ModuleArgsDict(torch.nn.Module, ABC):
 
     def __init__(self) -> None:
         super().__init__()
-        self._modulesArgs…
+        self._modulesArgs: dict[str, ModuleArgsDict.ModuleArgs] = {}
         self._training = NetState.TRAIN
 
-    def _addindent(self, s_: str, …
-        s = s_.split(…
+    def _addindent(self, s_: str, num_spaces: int):
+        s = s_.split("\n")
         if len(s) == 1:
             return s_
         first = s.pop(0)
-        s = [(…
-        …
-        s = first + '\n' + s
-        return s
+        s = [(num_spaces * " ") + line for line in s]
+        return first + "\n" + "\n".join(s)
 
     def __repr__(self):
         extra_lines = []
 
         extra_repr = self.extra_repr()
         if extra_repr:
-            extra_lines = extra_repr.split(…
+            extra_lines = extra_repr.split("\n")
 
         child_lines = []
-
+
+        def is_simple_branch(x):
+            return len(x) > 1 or x[0] != 0
+
         for key, module in self._modules.items():
             mod_str = repr(module)
 
             mod_str = self._addindent(mod_str, 2)
             desc = ""
-            if is_simple_branch(self._modulesArgs[key].in_branch) or is_simple_branch(…
-                …
+            if is_simple_branch(self._modulesArgs[key].in_branch) or is_simple_branch(
+                self._modulesArgs[key].out_branch
+            ):
+                desc += f", {self._modulesArgs[key].in_branch}->{self._modulesArgs[key].out_branch}"
             if not self._modulesArgs[key].pretrained:
                 desc += ", pretrained=False"
             if self._modulesArgs[key].alias:
-                desc += ", alias={…
-            desc += ", in_channels={…
-            desc += ", in_is_channel={…
-            desc += ", out_channels={…
-            desc += ", out_is_channel={…
-            desc += ", is_end={…
-            desc += ", isInCheckpoint={…
-            desc += ", isInGPU_Checkpoint={…
-            desc += ", requires_grad={…
-            desc += ", device={…
-
-            child_lines.append("({}{}) {}"…
-
+                desc += f", alias={self._modulesArgs[key].alias}"
+            desc += f", in_channels={self._modulesArgs[key].in_channels}"
+            desc += f", in_is_channel={self._modulesArgs[key].in_is_channel}"
+            desc += f", out_channels={self._modulesArgs[key].out_channels}"
+            desc += f", out_is_channel={self._modulesArgs[key].out_is_channel}"
+            desc += f", is_end={self._modulesArgs[key]._isEnd}"
+            desc += f", isInCheckpoint={self._modulesArgs[key].isCheckpoint}"
+            desc += f", isInGPU_Checkpoint={self._modulesArgs[key].isGPU_Checkpoint}"
+            desc += f", requires_grad={self._modulesArgs[key].requires_grad}"
+            desc += f", device={self._modulesArgs[key].gpu}"
+
+            child_lines.append(f"({key}{desc}) {mod_str}")
+
         lines = extra_lines + child_lines
 
         desc = ""
```
```diff
@@ -319,53 +473,71 @@ class ModuleArgsDict(torch.nn.Module, ABC):
         if len(extra_lines) == 1 and not child_lines:
             desc += extra_lines[0]
         else:
-            desc += …
+            desc += "\n  " + "\n  ".join(lines) + "\n"
+
+        return f"{self._get_name()}({desc})"
 
-        return "{}({})".format(self._get_name(), desc)
-
     def __getitem__(self, key: str) -> torch.nn.Module:
         module = self._modules[key]
-        …
-        …
+        if not module:
+            raise ValueError(f"Module '{key}' is None or missing in self._modules")
+        return module
 
     @_copy_to_script_wrapper
     def keys(self) -> Iterable[str]:
         return self._modules.keys()
 
     @_copy_to_script_wrapper
-    def items(self) -> Iterable[tuple[str, …
+    def items(self) -> Iterable[tuple[str, torch.nn.Module | None]]:
         return self._modules.items()
 
     @_copy_to_script_wrapper
-    def values(self) -> Iterable[…
+    def values(self) -> Iterable[torch.nn.Module | None]:
         return self._modules.values()
 
-    def add_module(…
+    def add_module(
+        self,
+        name: str,
+        module: torch.nn.Module,
+        in_branch: Sequence[int | str] = [0],
+        out_branch: Sequence[int | str] = [0],
+        pretrained: bool = True,
+        alias: list[str] = [],
+        requires_grad: bool | None = None,
+        training: None | bool = None,
+    ) -> None:
         super().add_module(name, module)
-        self._modulesArgs[name] = ModuleArgsDict.ModuleArgs(…
-        …
-        …
-        …
+        self._modulesArgs[name] = ModuleArgsDict.ModuleArgs(
+            [str(value) for value in in_branch],
+            [str(value) for value in out_branch],
+            pretrained,
+            alias,
+            requires_grad,
+            training,
+        )
+
+    def get_mapping(self):
+        results: dict[str, str] = {}
+        for name, module_args in self._modulesArgs.items():
             module = self[name]
             if isinstance(module, ModuleArgsDict):
-                if len(…
-                    count = …
+                if len(module_args.alias):
+                    count = dict.fromkeys(set(module.get_mapping().values()), 0)
                     if len(count):
-                        for k, v in module.…
-                            alias_name = …
+                        for k, v in module.get_mapping().items():
+                            alias_name = module_args.alias[count[v]]
                             if k == "":
-                                results.update({alias_name…
+                                results.update({alias_name: name + "." + v})
                             else:
-                                results.update({alias_name+"."+k…
-                            count[v]+=1
+                                results.update({alias_name + "." + k: name + "." + v})
+                            count[v] += 1
                     else:
-                        for alias in …
-                            results.update({alias…
+                        for alias in module_args.alias:
+                            results.update({alias: name})
                 else:
-                    results.update({k…
+                    results.update({k: name + "." + v for k, v in module.get_mapping().items()})
             else:
-                for alias in …
+                for alias in module_args.alias:
                     results[alias] = name
         return results
 
```
```diff
@@ -375,24 +547,24 @@ class ModuleArgsDict(torch.nn.Module, ABC):
             if isinstance(module, ModuleArgsDict):
                 module.init(init_type, init_gain)
             elif isinstance(module, torch.nn.modules.conv._ConvNd) or isinstance(module, torch.nn.Linear):
-                if init_type == …
+                if init_type == "normal":
                     torch.nn.init.normal_(module.weight, 0.0, init_gain)
-                elif init_type == …
+                elif init_type == "xavier":
                     torch.nn.init.xavier_normal_(module.weight, gain=init_gain)
-                elif init_type == …
-                    torch.nn.init.kaiming_normal_(module.weight, a=0, mode=…
-                elif init_type == …
+                elif init_type == "kaiming":
+                    torch.nn.init.kaiming_normal_(module.weight, a=0, mode="fan_in")
+                elif init_type == "orthogonal":
                     torch.nn.init.orthogonal_(module.weight, gain=init_gain)
                 elif init_type == "trunc_normal":
                     torch.nn.init.trunc_normal_(module.weight, std=init_gain)
                 else:
-                    raise NotImplementedError(…
+                    raise NotImplementedError(f"Initialization method {init_type} is not implemented")
                 if module.bias is not None:
                     torch.nn.init.constant_(module.bias, 0.0)
 
             elif isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
                 if module.weight is not None:
-                    torch.nn.init.normal_(module.weight, 0.0, std…
+                    torch.nn.init.normal_(module.weight, 0.0, std=init_gain)
                 if module.bias is not None:
                     torch.nn.init.constant_(module.bias, 0.0)
 
```
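The `init_func` branch above dispatches on an `init_type` string over standard `torch.nn.init` calls. The same dispatch written as a lookup table, runnable on a toy layer (the `INITS` dict is our illustration, not konfai code):

```python
import torch

INITS = {
    "normal": lambda w, gain: torch.nn.init.normal_(w, 0.0, gain),
    "xavier": lambda w, gain: torch.nn.init.xavier_normal_(w, gain=gain),
    "kaiming": lambda w, gain: torch.nn.init.kaiming_normal_(w, a=0, mode="fan_in"),
    "orthogonal": lambda w, gain: torch.nn.init.orthogonal_(w, gain=gain),
    "trunc_normal": lambda w, gain: torch.nn.init.trunc_normal_(w, std=gain),
}

layer = torch.nn.Linear(8, 4)
INITS["xavier"](layer.weight, 0.02)       # same effect as the "xavier" branch
torch.nn.init.constant_(layer.bias, 0.0)  # bias handling shared by every branch
```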
```diff
@@ -400,8 +572,7 @@ class ModuleArgsDict(torch.nn.Module, ABC):
         for module in self._modules.values():
             ModuleArgsDict.init_func(module, init_type, init_gain)
 
-
-    def named_forward(self, *inputs: torch.Tensor) -> Iterator[tuple[str, torch.Tensor]]:
+    def named_forward(self, *inputs: torch.Tensor) -> Iterator[tuple[str, torch.Tensor]]:
         if len(inputs) > 0:
             branchs: dict[str, torch.Tensor] = {}
             for i, sinput in enumerate(inputs):
```
```diff
@@ -410,7 +581,10 @@ class ModuleArgsDict(torch.nn.Module, ABC):
             out = inputs[0]
             tmp = []
             for name, module in self.items():
-                if self._modulesArgs[name].training is None or (…
+                if self._modulesArgs[name].training is None or (
+                    not (self._modulesArgs[name].training and self._training == NetState.PREDICTION)
+                    and not (not self._modulesArgs[name].training and self._training == NetState.TRAIN)
+                ):
                     requires_grad = self._modulesArgs[name].requires_grad
                     if requires_grad is not None and module:
                         module.requires_grad_(requires_grad)
```
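The rewritten condition in this hunk gates each submodule on its per-module `training` flag versus the network state: `None` always runs, `True` runs only while training, `False` runs only at prediction. The same predicate isolated into a small checkable function (names and structure are ours):

```python
from enum import Enum

class NetState(Enum):
    TRAIN = 0
    PREDICTION = 1

def runs(training: bool | None, state: NetState) -> bool:
    # None = always run; True = train-only; False = prediction-only.
    return training is None or (
        not (training and state == NetState.PREDICTION)
        and not (not training and state == NetState.TRAIN)
    )

assert runs(None, NetState.TRAIN) and runs(True, NetState.TRAIN)
assert not runs(True, NetState.PREDICTION) and runs(False, NetState.PREDICTION)
```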
```diff
@@ -418,64 +592,75 @@ class ModuleArgsDict(torch.nn.Module, ABC):
                         if ib not in branchs:
                             branchs[ib] = inputs[0]
                     for branchs_key in branchs.keys():
-                        if …
+                        if (
+                            str(self._modulesArgs[name].gpu) != "cpu"
+                            and str(branchs[branchs_key].device) != "cuda:" + self._modulesArgs[name].gpu
+                        ):
                             branchs[branchs_key] = branchs[branchs_key].to(int(self._modulesArgs[name].gpu))
-
+
                     if self._modulesArgs[name].isCheckpoint:
-                        out = checkpoint(…
+                        out = checkpoint(
+                            module,
+                            *[branchs[i] for i in self._modulesArgs[name].in_branch],
+                            use_reentrant=True,
+                        )
                         for ob in self._modulesArgs[name].out_branch:
                             branchs[ob] = out
                         yield name, out
                     else:
                         if isinstance(module, ModuleArgsDict):
-                            for k, out in module.named_forward(…
+                            for k, out in module.named_forward(
+                                *[branchs[i] for i in self._modulesArgs[name].in_branch]
+                            ):
                                 for ob in self._modulesArgs[name].out_branch:
                                     if ob in module._modulesArgs[k.split(".")[0].replace(";accu;", "")].out_branch:
                                         tmp.append(ob)
                                         branchs[ob] = out
-                                yield name+"."+k, out
+                                yield name + "." + k, out
                             for ob in self._modulesArgs[name].out_branch:
                                 if ob not in tmp:
                                     branchs[ob] = out
                         elif isinstance(module, torch.nn.Module):
-                            out = module(*[branchs[i] for i in self._modulesArgs[name].in_branch])
+                            out = module(*[branchs[i] for i in self._modulesArgs[name].in_branch])
                             for ob in self._modulesArgs[name].out_branch:
                                 branchs[ob] = out
                             yield name, out
             del branchs
 
     def forward(self, *input: torch.Tensor) -> torch.Tensor:
-        …
-        for …
+        _v = input
+        for _, _v in self.named_forward(*input):
             pass
-        return …
+        return _v
 
-    def named_parameters(…
-        …
+    def named_parameters(
+        self, pretrained: bool = False, recurse=False
+    ) -> Iterator[tuple[str, torch.nn.parameter.Parameter]]:
+        for name, module_args in self._modulesArgs.items():
             module = self[name]
             if isinstance(module, ModuleArgsDict):
-                for k, v in module.named_parameters(pretrained…
-                    yield name+"."+k, v
+                for k, v in module.named_parameters(pretrained=pretrained):
+                    yield name + "." + k, v
             elif isinstance(module, torch.nn.Module):
-                if not pretrained or not …
-                    if …
+                if not pretrained or not module_args.pretrained:
+                    if module_args.training is None or module_args.training:
                         for k, v in module.named_parameters():
-                            yield name+"."+k, v
+                            yield name + "." + k, v
 
     def parameters(self, pretrained: bool = False):
-        for _, v in self.named_parameters(pretrained…
+        for _, v in self.named_parameters(pretrained=pretrained):
             yield v
-
-    def …
+
+    def named_module_args_dict(self) -> Iterator[tuple[str, Self, ModuleArgs]]:
         for name, module in self._modules.items():
             yield name, module, self._modulesArgs[name]
             if isinstance(module, ModuleArgsDict):
-                for k, v, u in module.…
-                    yield name+"."+k, v, u
+                for k, v, u in module.named_module_args_dict():
+                    yield name + "." + k, v, u
 
     def _requires_grad(self, keys: list[str]):
         keys = keys.copy()
-        for name, module, args in self.…
+        for name, module, args in self.named_module_args_dict():
             requires_grad = args.requires_grad
             if requires_grad is not None:
                 module.requires_grad_(requires_grad)
```
```diff
@@ -484,162 +669,221 @@ class ModuleArgsDict(torch.nn.Module, ABC):
             if len(keys) == 0:
                 break
 
+
 class OutputsGroup(list):
 
     def __init__(self, measure: Measure) -> None:
         self.layers: dict[str, torch.Tensor] = {}
         self.measure = measure
 
-    def …
+    def add_layer(self, name: str, layer: torch.Tensor):
         self.layers[name] = layer
 
-    def …
+    def is_done(self):
         return len(self) == len(self.layers)
-
+
     def clear(self):
         self.layers.clear()
 
-
 
 class Network(ModuleArgsDict, ABC):
 
-    def _apply_network(…
-        …
+    def _apply_network(
+        self,
+        name_function: Callable[[Self], str],
+        networks: list[str],
+        key: str,
+        function: Callable,
+        *args,
+        **kwargs,
+    ) -> dict[str, object]:
+        results: dict[str, object] = {}
         for module in self.values():
             if isinstance(module, Network):
                 if name_function(module) not in networks:
                     networks.append(name_function(module))
-                    for k, v in module._apply_network(…
-                        …
+                    for k, v in module._apply_network(
+                        name_function,
+                        networks,
+                        key + "." + name_function(module),
+                        function,
+                        *args,
+                        **kwargs,
+                    ).items():
+                        results.update({name_function(self) + "." + k: v})
         if len([param.name for param in list(inspect.signature(function).parameters.values()) if param.name == "key"]):
             function = partial(function, key=key)
 
         results[name_function(self)] = function(self, *args, **kwargs)
         return results
-
-    def _function_network():
-        def _function_network_d(function…
-            def new_function(self…
-                return self._apply_network(…
+
+    def _function_network():  # type: ignore[misc]
+        def _function_network_d(function: Callable):
+            def new_function(self: Self, *args, **kwargs) -> dict[str, object]:
+                return self._apply_network(
+                    lambda network: network.get_name(),
+                    [],
+                    self.get_name(),
+                    function,
+                    *args,
+                    **kwargs,
+                )
+
             return new_function
+
         return _function_network_d
 
-    def __init__(…
-        …
+    def __init__(
+        self,
+        in_channels: int = 1,
+        optimizer: OptimizerLoader | None = None,
+        schedulers: dict[str, LRSchedulersLoader] | None = None,
+        outputs_criterions: dict[str, TargetCriterionsLoader] | None = None,
+        patch: ModelPatch | None = None,
+        nb_batch_per_step: int = 1,
+        init_type: str = "normal",
+        init_gain: float = 0.02,
+        dim: int = 3,
+    ) -> None:
         super().__init__()
         self.name = self.__class__.__name__
         self.in_channels = in_channels
-        self.optimizerLoader…
-        self.optimizer…
+        self.optimizerLoader = optimizer
+        self.optimizer: torch.optim.Optimizer | None = None
 
-        self.…
-        self.schedulers…
+        self.lr_schedulers_loader = schedulers
+        self.schedulers: dict[torch.optim.lr_scheduler._LRScheduler, int] = {}
 
-        self.…
-        self.measure…
+        self.outputs_criterions_loader = outputs_criterions
+        self.measure: Measure | None = None
 
         self.patch = patch
 
         self.nb_batch_per_step = nb_batch_per_step
-        self.init_type…
-        self.init_gain…
+        self.init_type = init_type
+        self.init_gain = init_gain
         self.dim = dim
         self._it = 0
         self._nb_lr_update = 0
-        self.outputsGroup…
+        self.outputsGroup: list[OutputsGroup] = []
 
     @_function_network()
     def state_dict(self) -> dict[str, OrderedDict]:
-        destination = OrderedDict()
-        …
-        destination…
+        destination: OrderedDict[str, Any] = OrderedDict()
+        local_metadata = {"version": self._version}
+        # destination["_metadata"] = OrderedDict({"": local_metadata})
         self._save_to_state_dict(destination, "", False)
         for name, module in self._modules.items():
             if module is not None:
-                if not isinstance(module, Network):
-                    module.state_dict(destination=destination, prefix="" + name + …
+                if not isinstance(module, Network):
+                    module.state_dict(destination=destination, prefix="" + name + ".", keep_vars=False)
         for hook in self._state_dict_hooks.values():
             hook_result = hook(self, destination, "", local_metadata)
             if hook_result is not None:
                 destination = hook_result
         return destination
-
+
     def load_state_dict(self, state_dict: dict[str, torch.Tensor]):
         missing_keys: list[str] = []
         unexpected_keys: list[str] = []
         error_msgs: list[str] = []
 
-        metadata = getattr(state_dict, …
+        metadata = getattr(state_dict, "_metadata", None)
         state_dict = state_dict.copy()
         if metadata is not None:
-            state_dict…
+            state_dict["_metadata"] = metadata
 
-        def load(module: torch.nn.Module, prefix=…
+        def load(module: torch.nn.Module, prefix=""):
             local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
             module._load_from_state_dict(
-                state_dict, …
+                state_dict,
+                prefix,
+                local_metadata,
+                True,
+                missing_keys,
+                unexpected_keys,
+                error_msgs,
+            )
             for name, child in module._modules.items():
                 if child is not None:
                     if not isinstance(child, Network):
                         if isinstance(child, torch.nn.modules.conv._ConvNd) or isinstance(module, torch.nn.Linear):
 
                             current_size = child.weight.shape[0]
-                            last_size = state_dict[prefix + name+".weight"].shape[0]
+                            last_size = state_dict[prefix + name + ".weight"].shape[0]
 
                             if current_size != last_size:
-                                print(…
+                                print(
+                                    f"Warning: The size of '{prefix + name}' has changed from {last_size}"
+                                    f" to {current_size}. Please check for potential impacts"
+                                )
                                 ModuleArgsDict.init_func(child, self.init_type, self.init_gain)
 
                                 with torch.no_grad():
-                                    child.weight[:last_size] = state_dict[prefix + name+".weight"]
+                                    child.weight[:last_size] = state_dict[prefix + name + ".weight"]
                                     if child.bias is not None:
-                                        child.bias[:last_size] = state_dict[prefix + name+".bias"]
+                                        child.bias[:last_size] = state_dict[prefix + name + ".bias"]
                             return
-                    load(child, prefix + name + …
+                    load(child, prefix + name + ".")
 
         load(self)
-        del load
 
         if len(unexpected_keys) > 0:
+            formatted_keys = ", ".join(f'"{k}"' for k in unexpected_keys)
             error_msgs.insert(
-                0, …
+                0,
+                f"Unexpected key(s) in state_dict: {formatted_keys}.",
+            )
         if len(missing_keys) > 0:
+            formatted_keys = ", ".join(f'"{k}"' for k in missing_keys)
             error_msgs.insert(
-                0, …
+                0,
+                f"Missing key(s) in state_dict: {formatted_keys}.",
+            )
 
         if len(error_msgs) > 0:
-            …
-            …
-            …
+            formatted_errors = "\n\t".join(error_msgs)
+            raise RuntimeError(
+                f"Error(s) in loading state_dict for {self.__class__.__name__}:\n\t{formatted_errors}",
+            )
+
     def apply(self, fn: Callable[[torch.nn.Module], None]) -> None:
         for module in self.children():
             if not isinstance(module, Network):
                 module.apply(fn)
         fn(self)
-
+
     @_function_network()
-    def load(…
+    def load(
+        self,
+        state_dict: dict[str, dict[str, torch.Tensor] | int],
+        init: bool = True,
+        ema: bool = False,
+    ):
         if init:
-            self.apply(…
+            self.apply(
+                partial(
+                    ModuleArgsDict.init_func,
+                    init_type=self.init_type,
+                    init_gain=self.init_gain,
+                )
+            )
         name = "Model" + ("_EMA" if ema else "")
         if name in state_dict:
-            …
-            …
-            …
-            …
+            value = state_dict[name]
+            model_state_dict_tmp = {}
+            if isinstance(value, dict):
+                model_state_dict_tmp = {k.split(".")[-1]: v for k, v in value.items()}[self.get_name()]
+            modules_name = self.get_mapping()
+            model_state_dict: OrderedDict[str, torch.Tensor] = OrderedDict()
+
             for alias in model_state_dict_tmp.keys():
                 prefix = ".".join(alias.split(".")[:-1])
-                alias_list = […
+                alias_list = [
+                    (".".join(prefix.split(".")[: len(i.split("."))]), v)
+                    for i, v in modules_name.items()
+                    if prefix.startswith(i)
+                ]
 
                 if len(alias_list):
                     for a, b in alias_list:
```
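The `load_state_dict` rewrite above keeps a noteworthy behavior: when a conv or linear layer's output size differs from the checkpoint, it re-initializes the layer and copies only the first `last_size` rows of the saved weights. The core of that slice-copy on a toy layer (the checkpoint dict here is fabricated for illustration, not from a real konfai run):

```python
import torch

old = torch.nn.Linear(8, 4)     # shape stored in the checkpoint
new = torch.nn.Linear(8, 6)     # current, grown model
state = old.state_dict()        # stands in for the loaded file

last_size = state["weight"].shape[0]
with torch.no_grad():           # same slice-copy as in load_state_dict()
    new.weight[:last_size] = state["weight"]
    new.bias[:last_size] = state["bias"]
```

Rows beyond `last_size` keep the fresh `init_func` initialization, which is what the printed size-change warning refers to.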
```diff
@@ -648,19 +892,34 @@ class Network(ModuleArgsDict, ABC):
                 else:
                     model_state_dict[alias] = model_state_dict_tmp[alias]
             self.load_state_dict(model_state_dict)
-        if "{…
-            self.optimizer.load_state_dict(state_dict[…
-        if "{…
-            …
-            …
+        if f"{self.get_name()}_optimizer_state_dict" in state_dict and self.optimizer:
+            self.optimizer.load_state_dict(state_dict[f"{self.get_name()}_optimizer_state_dict"])
+        if f"{self.get_name()}_it" in state_dict:
+            _it = state_dict.get(f"{self.get_name()}_it")
+            if isinstance(_it, int):
+                self._it = _it
+        if f"{self.get_name()}_nb_lr_update" in state_dict:
+            _nb_lr_update = state_dict.get(f"{self.get_name()}_nb_lr_update")
+            if isinstance(_nb_lr_update, int):
+                self._nb_lr_update = _nb_lr_update
+
         for scheduler in self.schedulers:
             if scheduler.last_epoch == -1:
                 scheduler.last_epoch = self._nb_lr_update
         self.initialized()
 
-    def _compute_channels_trace(…
-        …
+    def _compute_channels_trace(
+        self,
+        module: ModuleArgsDict,
+        in_channels: int,
+        gradient_checkpoints: list[str] | None,
+        gpu_checkpoints: list[str] | None,
+        name: str | None = None,
+        in_is_channel: bool = True,
+        out_channels: int | None = None,
+        out_is_channel: bool = True,
+    ) -> tuple[int, bool, int | None, bool]:
+
         for k1, v1 in module.items():
             if isinstance(v1, ModuleArgsDict):
                 for t in module._modulesArgs[k1].out_branch:
```
```diff
@@ -680,25 +939,34 @@ class Network(ModuleArgsDict, ABC):
             if hasattr(v, "in_features"):
                 if v.in_features:
                     in_channels = v.in_features
-            key = name+"."+k if name else k
+            key = name + "." + k if name else k
 
             if gradient_checkpoints:
                 if key in gradient_checkpoints:
                     module._modulesArgs[k].isCheckpoint = True
-
+
             if gpu_checkpoints:
                 if key in gpu_checkpoints:
                     module._modulesArgs[k].isGPU_Checkpoint = True
-
+
             module._modulesArgs[k].in_channels = in_channels
             module._modulesArgs[k].in_is_channel = in_is_channel
-
+
             if isinstance(v, ModuleArgsDict):
-                in_channels, in_is_channel, out_channels, out_is_channel = self._compute_channels_trace(…
-                …
+                in_channels, in_is_channel, out_channels, out_is_channel = self._compute_channels_trace(
+                    v,
+                    in_channels,
+                    gradient_checkpoints,
+                    gpu_checkpoints,
+                    key,
+                    in_is_channel,
+                    out_channels,
+                    out_is_channel,
+                )
+
             if v.__class__.__name__ == "ToChannels":
                 out_is_channel = True
-
+
             if v.__class__.__name__ == "ToFeatures":
                 out_is_channel = False
 
```
```diff
@@ -712,26 +980,32 @@ class Network(ModuleArgsDict, ABC):
             module._modulesArgs[k].out_channels = out_channels
             module._modulesArgs[k].out_is_channel = out_is_channel
 
-            in_channels = out_channels
+            in_channels = out_channels if out_channels is not None else in_channels
             in_is_channel = out_is_channel
-
+
         return in_channels, in_is_channel, out_channels, out_is_channel
 
     @_function_network()
-    def init(self, autocast…
-        if self.…
-            self.measure = Measure(key, self.…
+    def init(self, autocast: bool, state: State, group_dest: list[str], key: str) -> None:
+        if self.outputs_criterions_loader:
+            self.measure = Measure(key, self.outputs_criterions_loader)
             self.measure.init(self, group_dest)
+        if self.patch is not None:
+            self.patch.init(f"{konfai_root()}.Model.{key}.Patch")
         if state != State.PREDICTION:
             self.scaler = torch.amp.GradScaler("cuda", enabled=autocast)
             if self.optimizerLoader:
-                self.optimizer = self.optimizerLoader.…
+                self.optimizer = self.optimizerLoader.get_optimizer(
+                    key, self.parameters(state == State.TRANSFER_LEARNING)
+                )
                 self.optimizer.zero_grad()
 
-            if self.…
-                for schedulers_classname, schedulers in self.…
-                    self.schedulers[schedulers.getschedulers(key, schedulers_classname, self.optimizer)] = …
-
+            if self.lr_schedulers_loader and self.optimizer:
+                for schedulers_classname, schedulers in self.lr_schedulers_loader.items():
+                    self.schedulers[schedulers.getschedulers(key, schedulers_classname, self.optimizer)] = (
+                        schedulers.nb_step
+                    )
+
     def initialized(self):
         pass
 
```
@@ -740,80 +1014,101 @@ class Network(ModuleArgsDict, ABC):
|
|
|
740
1014
|
self.patch.load(inputs[0].shape[2:])
|
|
741
1015
|
accumulators: dict[str, Accumulator] = {}
|
|
742
1016
|
|
|
743
|
-
|
|
1017
|
+
patch_iterator = self.patch.disassemble(*inputs)
|
|
744
1018
|
buffer = []
|
|
745
|
-
for i, patch_input in enumerate(
|
|
746
|
-
for
|
|
747
|
-
yield "
|
|
1019
|
+
for i, patch_input in enumerate(patch_iterator):
|
|
1020
|
+
for name, output_layer in super().named_forward(*patch_input):
|
|
1021
|
+
yield f";accu;{name}", output_layer
|
|
748
1022
|
buffer.append((name.split(".")[0], output_layer))
|
|
749
1023
|
if len(buffer) == 2:
|
|
750
1024
|
if buffer[0][0] != buffer[1][0]:
|
|
751
1025
|
if self._modulesArgs[buffer[0][0]]._isEnd:
|
|
752
1026
|
if buffer[0][0] not in accumulators:
|
|
753
|
-
accumulators[buffer[0][0]] = Accumulator(
|
|
754
|
-
|
|
1027
|
+
accumulators[buffer[0][0]] = Accumulator(
|
|
1028
|
+
self.patch.get_patch_slices(),
|
|
1029
|
+
self.patch.patch_size,
|
|
1030
|
+
self.patch.patch_combine,
|
|
1031
|
+
)
|
|
1032
|
+
accumulators[buffer[0][0]].add_layer(i, buffer[0][1])
|
|
755
1033
|
buffer.pop(0)
|
|
756
1034
|
if self._modulesArgs[buffer[0][0]]._isEnd:
|
|
757
1035
|
if buffer[0][0] not in accumulators:
|
|
758
|
-
accumulators[buffer[0][0]] = Accumulator(
|
|
759
|
-
|
|
1036
|
+
accumulators[buffer[0][0]] = Accumulator(
|
|
1037
|
+
self.patch.get_patch_slices(),
|
|
1038
|
+
self.patch.patch_size,
|
|
1039
|
+
self.patch.patch_combine,
|
|
1040
|
+
)
|
|
1041
|
+
accumulators[buffer[0][0]].add_layer(i, buffer[0][1])
|
|
760
1042
|
for name, accumulator in accumulators.items():
|
|
761
1043
|
yield name, accumulator.assemble()
|
|
762
1044
|
else:
|
|
763
|
-
for
|
|
1045
|
+
for name, output_layer in super().named_forward(*inputs):
|
|
764
1046
|
yield name, output_layer
 
-    def get_layers(
-
+    def get_layers(
+        self, inputs: list[torch.Tensor], layers_name: list[str]
+    ) -> Iterator[tuple[str, torch.Tensor, PatchIndexed | None]]:
         layers_name = layers_name.copy()
-        output_layer_accumulator
-        output_layer_patch_indexed
+        output_layer_accumulator: dict[str, Accumulator] = {}
+        output_layer_patch_indexed: dict[str, PatchIndexed] = {}
         it = 0
         debug = "KONFAI_DEBUG" in os.environ
-        for
-            name =
+        for name_tmp, output_layer in self.named_forward(*inputs):
+            name = name_tmp.replace(";accu;", "")
             if debug:
                 if "KONFAI_DEBUG_LAST_LAYER" in os.environ:
-                    os.environ["KONFAI_DEBUG_LAST_LAYER"] = "{
+                    os.environ["KONFAI_DEBUG_LAST_LAYER"] = (
+                        f"{os.environ['KONFAI_DEBUG_LAST_LAYER']}|{name}:"
+                        f"{get_gpu_memory(output_layer.device)}:"
+                        f"{str(output_layer.device).replace('cuda:', '')}"
+                    )
                 else:
-                    os.environ["KONFAI_DEBUG_LAST_LAYER"] =
+                    os.environ["KONFAI_DEBUG_LAST_LAYER"] = (
+                        f"{name}:{get_gpu_memory(output_layer.device)}:{str(output_layer.device).replace('cuda:', '')}"
+                    )
             it += 1
-            if name in layers_name or
-                if ";accu;" in
+            if name in layers_name or name_tmp in layers_name:
+                if ";accu;" in name_tmp:
                     if name not in output_layer_patch_indexed:
-
+                        network_name = (
+                            name_tmp.split(".;accu;")[-2].split(".")[-1]
+                            if ".;accu;" in name_tmp
+                            else name_tmp.split(";accu;")[-2].split(".")[-1]
+                        )
                         module = self
                         network = None
-                        if
+                        if network_name == "":
                             network = module
                         else:
                             for n in name.split("."):
                                 module = module[n]
-                                if isinstance(module, Network) and n ==
+                                if isinstance(module, Network) and n == network_name:
                                     network = module
                                     break
-
+
                         if network and network.patch:
-                            output_layer_patch_indexed[name] =
+                            output_layer_patch_indexed[name] = PatchIndexed(network.patch, 0)
 
                     if name not in output_layer_accumulator:
-                        output_layer_accumulator[name] = Accumulator(
-
-
-
+                        output_layer_accumulator[name] = Accumulator(
+                            output_layer_patch_indexed[name].patch.get_patch_slices(0),
+                            output_layer_patch_indexed[name].patch.patch_size,
+                            output_layer_patch_indexed[name].patch.patch_combine,
+                        )
+
+                    if name_tmp in layers_name:
+                        output_layer_accumulator[name].add_layer(output_layer_patch_indexed[name].index, output_layer)
                         output_layer_patch_indexed[name].index += 1
-                        if output_layer_accumulator[name].
+                        if output_layer_accumulator[name].is_full():
                             output_layer = output_layer_accumulator[name].assemble()
                             output_layer_accumulator.pop(name)
                             output_layer_patch_indexed.pop(name)
-                            layers_name.remove(
-                            yield
+                            layers_name.remove(name_tmp)
+                            yield name_tmp, output_layer, None
 
                 if name in layers_name:
-                    if ";accu;" in
+                    if ";accu;" in name_tmp:
                         yield name, output_layer, output_layer_patch_indexed[name]
                         output_layer_patch_indexed[name].index += 1
-                        if output_layer_patch_indexed[name].
+                        if output_layer_patch_indexed[name].is_full():
                             output_layer_patch_indexed.pop(name)
                             layers_name.remove(name)
                     else:
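
The KONFAI_DEBUG_LAST_LAYER breadcrumb in `get_layers` records the last layer that produced an output, together with GPU memory, so an out-of-memory crash can be attributed to a specific module. The same idea can be sketched with plain forward hooks; the hook wiring and `DEBUG_LAST_LAYER` name below are illustrative, not konfai's actual mechanism:

    import os
    import torch
    import torch.nn as nn

    def gpu_memory_mb(device: torch.device) -> float:
        # Report allocated CUDA memory; 0 on CPU so the sketch runs anywhere.
        if device.type != "cuda":
            return 0.0
        return torch.cuda.memory_allocated(device) / 1024**2

    def install_breadcrumbs(model: nn.Module):
        def hook(module, args, output):
            if isinstance(output, torch.Tensor):
                os.environ["DEBUG_LAST_LAYER"] = (
                    f"{module.__class__.__name__}:{gpu_memory_mb(output.device):.1f}MB"
                )
        for layer in model.modules():
            layer.register_forward_hook(hook)

    model = nn.Sequential(nn.Linear(8, 16), nn.ReLU(), nn.Linear(16, 4))
    install_breadcrumbs(model)
    model(torch.randn(2, 8))
    print(os.environ["DEBUG_LAST_LAYER"])  # e.g. "Linear:0.0MB" on CPU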
 
@@ -823,62 +1118,85 @@ class Network(ModuleArgsDict, ABC):
             if not len(layers_name):
                 break
 
-    def
-        metric_tmp = {
+    def init_outputs_group(self):
+        metric_tmp = {
+            network.measure: network.measure.outputs_criterions.keys()
+            for network in self.get_networks().values()
+            if network.measure
+        }
         for k, v in metric_tmp.items():
             for a in v:
                 outputs_group = OutputsGroup(k)
                 outputs_group.append(a)
-                for
-                    if ":" in
-                        outputs_group.append(
+                for targets_group in k.outputs_criterions[a].keys():
+                    if ":" in targets_group:
+                        outputs_group.append(targets_group.replace(":", "."))
 
                 self.outputsGroup.append(outputs_group)
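
`init_outputs_group` builds one group per measured output, pulling in any criterion target that references another layer; those references use ":" where module paths use ".", so the key is translated back. A simplified sketch of that grouping, with made-up key names and plain lists in place of the Measure and OutputsGroup classes:

    # Hypothetical criterion table: one output, one plain criterion, and one
    # criterion whose target is another layer (":" instead of "." in the key).
    outputs_criterions = {
        "head.seg": {"Dice": 1.0, "encoder:feat": 0.1},
    }

    groups: list[list[str]] = []
    for output_name, targets in outputs_criterions.items():
        group = [output_name]
        for target in targets:
            if ":" in target:                           # reference to another layer
                group.append(target.replace(":", "."))  # back to a module path
        groups.append(group)

    print(groups)  # [['head.seg', 'encoder.feat']]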
 
-    def forward(
+    def forward(
+        self,
+        data_dict: dict[tuple[str, bool], torch.Tensor],
+        output_layers: list[str] = [],
+    ) -> list[tuple[str, torch.Tensor]]:
         if not len(self.outputsGroup) and not len(output_layers):
             return []
 
-        self.
+        self.reset_loss()
         results = []
         measure_output_layers = set()
-        for
-
+        for _outputs_group in self.outputsGroup:
+            for name in _outputs_group:
                 measure_output_layers.add(name)
-        for name, layer, patch_indexed in self.get_layers(
-
+        for name, layer, patch_indexed in self.get_layers(
+            [v for k, v in data_dict.items() if k[1]],
+            list(set(list(measure_output_layers) + output_layers)),
+        ):
+
             outputs_group = [outputs_group for outputs_group in self.outputsGroup if name in outputs_group]
             if len(outputs_group) > 0:
                 if patch_indexed is None:
-                    targets = {k[0]
+                    targets = {k[0]: v for k, v in data_dict.items()}
                     nb = 1
                 else:
-                    targets = {
-
+                    targets = {
+                        k[0]: patch_indexed.patch.get_data(v, patch_indexed.index, 0, False)
+                        for k, v in data_dict.items()
+                    }
+                    nb = patch_indexed.patch.get_size(0)
 
                 for output_group in outputs_group:
-                    output_group.
-                    if output_group.
-                        targets.update(
-
+                    output_group.add_layer(name, layer)
+                    if output_group.is_done():
+                        targets.update(
+                            {k.replace(".", ":"): v for k, v in output_group.layers.items() if k != output_group[0]}
+                        )
+                        output_group.measure.update(
+                            output_group[0],
+                            output_group.layers[output_group[0]],
+                            targets,
+                            self._it,
+                            nb,
+                            self.training,
+                        )
                         output_group.clear()
             if name in output_layers:
                 results.append((name, layer))
         return results
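
The `forward` signature above keys `data_dict` by `(name, is_input)` tuples: only entries flagged True are fed to the network, while every entry remains available as a criterion target. A tiny sketch of that calling convention, with invented tensor names for illustration:

    import torch

    data_dict = {
        ("image", True): torch.randn(1, 1, 32, 32),             # network input
        ("mask", False): torch.randint(0, 2, (1, 1, 32, 32)),   # target only
    }

    inputs = [v for k, v in data_dict.items() if k[1]]   # flagged True
    targets = {k[0]: v for k, v in data_dict.items()}    # everything
    print(len(inputs), sorted(targets))                  # 1 ['image', 'mask']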
 
     @_function_network()
-    def
+    def reset_loss(self):
         if self.measure:
-            self.measure.
+            self.measure.reset_loss()
 
     @_function_network()
     def backward(self, model: Self):
         if self.measure:
             if self.scaler and self.optimizer:
-                model._requires_grad(list(self.measure.
-                for loss in self.measure.
+                model._requires_grad(list(self.measure.outputs_criterions.keys()))
+                for loss in self.measure.get_loss():
                     self.scaler.scale(loss / self.nb_batch_per_step).backward()
-
+
                 if self._it % self.nb_batch_per_step == 0:
                     self.scaler.step(self.optimizer)
                     self.scaler.update()
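
`backward` combines AMP loss scaling with gradient accumulation: each batch contributes loss / nb_batch_per_step, and the optimizer only steps every nb_batch_per_step batches. The same pattern in a minimal standalone form (stock torch, toy model):

    import torch

    model = torch.nn.Linear(4, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
    scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
    nb_batch_per_step = 4

    for it in range(1, 9):  # 8 micro-batches -> 2 optimizer steps
        x, y = torch.randn(2, 4), torch.randn(2, 1)
        loss = torch.nn.functional.mse_loss(model(x), y)
        # Divide so the accumulated gradient matches one large batch.
        scaler.scale(loss / nb_batch_per_step).backward()
        if it % nb_batch_per_step == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()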
 
@@ -887,93 +1205,133 @@ class Network(ModuleArgsDict, ABC):
 
     @_function_network()
     def update_lr(self):
-        self._nb_lr_update+=1
+        self._nb_lr_update += 1
         step = 0
-
-        for
-            if value is None or (self._nb_lr_update >= step and self._nb_lr_update < step+value):
+        _scheduler = None
+        for _scheduler, value in self.schedulers.items():
+            if value is None or (self._nb_lr_update >= step and self._nb_lr_update < step + value):
                 break
            step += value
-        if
-            if
+        if _scheduler:
+            if _scheduler.__class__.__name__ == "ReduceLROnPlateau":
                 if self.measure:
-
+                    _scheduler.step(sum(self.measure.get_last_values(0).values()))
             else:
-
+                _scheduler.step()
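
`update_lr` walks an ordered dict of {scheduler: step budget} and steps whichever scheduler owns the current iteration; a budget of None means "active from here on", and only ReduceLROnPlateau consumes a metric. A standalone sketch of the same selection logic with stock torch schedulers:

    import torch
    from torch.optim.lr_scheduler import LinearLR, ReduceLROnPlateau

    opt = torch.optim.SGD([torch.nn.Parameter(torch.zeros(1))], lr=0.1)
    # Warm up for 5 updates, then hand over to a plateau scheduler forever.
    schedulers = {LinearLR(opt, total_iters=5): 5, ReduceLROnPlateau(opt): None}

    def update_lr(nb_update: int, val_loss: float):
        step = 0
        for sched, budget in schedulers.items():
            if budget is None or step <= nb_update < step + budget:
                break
            step += budget
        if isinstance(sched, ReduceLROnPlateau):
            sched.step(val_loss)   # plateau scheduler needs a metric
        else:
            sched.step()

    for i in range(8):
        update_lr(i, val_loss=1.0)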
 
     @_function_network()
-    def
+    def get_networks(self) -> Self:
         return self
 
-
+    @staticmethod
+    def to(module: ModuleArgsDict, device: int):
         if "device" not in os.environ:
             os.environ["device"] = str(device)
         for k, v in module.items():
             if module._modulesArgs[k].gpu == "cpu":
                 if module._modulesArgs[k].isGPU_Checkpoint:
-                    os.environ["device"] = str(int(os.environ["device"])+1)
-                    module._modulesArgs[k].gpu = str(
+                    os.environ["device"] = str(int(os.environ["device"]) + 1)
+                    module._modulesArgs[k].gpu = str(get_device(int(os.environ["device"])))
             if isinstance(v, ModuleArgsDict):
                 v = Network.to(v, int(os.environ["device"]))
             else:
-                v = v.to(
+                v = v.to(get_device(int(os.environ["device"])))
         if isinstance(module, Network):
             if module.optimizer is not None:
                 for state in module.optimizer.state.values():
                     for k, v in state.items():
                         if isinstance(v, torch.Tensor):
-                            state[k] = v.to(
+                            state[k] = v.to(get_device(int(os.environ["device"])))
         return module
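
Note that `Network.to` also migrates optimizer state: torch optimizers keep per-parameter buffers (for example SGD momentum) that moving the model alone does not touch, so they must be moved explicitly. A standalone sketch of that step:

    import torch

    def move_optimizer_state(optimizer: torch.optim.Optimizer, device: torch.device):
        # Relocate every tensor-valued entry of the optimizer's state dict.
        for state in optimizer.state.values():
            for k, v in state.items():
                if isinstance(v, torch.Tensor):
                    state[k] = v.to(device)

    model = torch.nn.Linear(4, 4)
    opt = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    model(torch.randn(2, 4)).sum().backward()
    opt.step()  # momentum buffers now exist, on CPU

    device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
    model.to(device)
    move_optimizer_state(opt, device)  # keep buffers on the same device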
-
-    def
+
+    def get_name(self) -> str:
         return self.name
-
-    def
+
+    def set_name(self, name: str) -> Self:
         self.name = name
         return self
-
-    def
+
+    def set_state(self, state: NetState):
         for module in self.modules():
             if isinstance(module, ModuleArgsDict):
                 module._training = state
 
+
 class MinimalModel(Network):
-
-    def __init__(
-
-
-
-
+
+    def __init__(
+        self,
+        model: Network,
+        optimizer: OptimizerLoader = OptimizerLoader(),
+        schedulers: dict[str, LRSchedulersLoader] = {"default:ReduceLROnPlateau": LRSchedulersLoader(0)},
+        outputs_criterions: dict[str, TargetCriterionsLoader] = {"default": TargetCriterionsLoader()},
+        patch: ModelPatch | None = None,
+        dim: int = 3,
+        nb_batch_per_step=1,
+        init_type="normal",
+        init_gain=0.02,
+    ):
+        super().__init__(
+            1,
+            optimizer,
+            schedulers,
+            outputs_criterions,
+            patch,
+            nb_batch_per_step,
+            init_type,
+            init_gain,
+            dim,
+        )
         self.add_module("Model", model)
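
MinimalModel is an adapter: a backbone that does not itself provide the optimizer, scheduler, and criterion plumbing gets wrapped under a single "Model" submodule so it still participates in training. A generic sketch of that pattern, independent of konfai's classes (the wrapper name and lr parameter are invented for illustration):

    import torch
    import torch.nn as nn

    class TrainableWrapper(nn.Module):
        """Wrap any backbone and attach training plumbing around it."""

        def __init__(self, backbone: nn.Module, lr: float = 1e-3):
            super().__init__()
            self.add_module("Model", backbone)  # single named submodule
            self.optimizer = torch.optim.Adam(self.parameters(), lr=lr)

        def forward(self, x):
            return self.Model(x)

    wrapped = TrainableWrapper(nn.Linear(8, 2))
    print(sum(p.numel() for p in wrapped.parameters()))  # 18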
 
-
+
+class ModelLoader:
 
     @config("Model")
-    def __init__(self, classpath
-        self.module, self.name =
-
-    def
-
+    def __init__(self, classpath: str = "default:segmentation.UNet.UNet") -> None:
+        self.module, self.name = get_module(classpath, "konfai.models")
+
+    def get_model(
+        self,
+        train: bool = True,
+        konfai_args: str | None = None,
+        konfai_without=[
+            "optimizer",
+            "schedulers",
+            "nb_batch_per_step",
+            "init_type",
+            "init_gain",
+        ],
+    ) -> Network:
+        if not konfai_args:
+            konfai_args = f"{konfai_root()}.Model"
+        konfai_args += "." + self.name
+        model = config(konfai_args)(getattr(importlib.import_module(self.module), self.name))(
+            config=None, konfai_without=konfai_without if not train else []
+        )
         if not isinstance(model, Network):
-            model = config(
-
+            model = config(konfai_args)(partial(MinimalModel, model))(
+                config=None, konfai_without=konfai_without + ["model"] if not train else []
+            )
+        model.set_name(self.name)
         return model
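
`get_model` resolves a dotted classpath into a class at runtime via importlib, defaulting to the konfai.models package for relative paths. The resolution step in isolation; the `load_class` helper below is an illustrative sketch (torch.nn stands in as the default package), not konfai's `get_module`:

    import importlib

    def load_class(classpath: str, default_package: str):
        module_path, _, class_name = classpath.rpartition(".")
        if "." not in module_path:  # relative path: prepend the default package
            module_path = f"{default_package}.{module_path}" if module_path else default_package
        return getattr(importlib.import_module(module_path), class_name)

    linear = load_class("torch.nn.Linear", default_package="torch.nn")
    print(linear(4, 2))  # Linear(in_features=4, out_features=2, bias=True)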
-
-
+
+
+class CPUModel:
 
     def __init__(self, model: Network) -> None:
         self.module = model
-        self.device = torch.device(
+        self.device = torch.device("cpu")
 
     def train(self):
         self.module.train()
-
-    def eval(self):
+
+    def eval(self):  # noqa: A003
         self.module.eval()
-
-    def __call__(
-
+
+    def __call__(
+        self,
+        data_dict: dict[tuple[str, bool], torch.Tensor],
+        output_layers: list[str] = [],
+    ) -> list[tuple[str, torch.Tensor]]:
+        return self.module(data_dict, output_layers)
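
CPUModel mirrors just enough of the wrapped Network's interface (train/eval/__call__ plus a .device attribute) for caller code that expects a device-bearing model. The same duck-typing idea in isolation, with an invented wrapper name:

    import torch
    import torch.nn as nn

    class CPUWrapper:
        def __init__(self, module: nn.Module):
            self.module = module.to("cpu")
            self.device = torch.device("cpu")

        def train(self):
            self.module.train()

        def eval(self):
            self.module.eval()

        def __call__(self, *args, **kwargs):
            # Delegate the call so the wrapper is a drop-in stand-in.
            return self.module(*args, **kwargs)

    wrapped = CPUWrapper(nn.Linear(3, 1))
    print(wrapped(torch.randn(2, 3)).shape)  # torch.Size([2, 1])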