konfai 1.1.8-py3-none-any.whl → 1.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of konfai has been flagged as potentially problematic.
- konfai/__init__.py +59 -14
- konfai/data/augmentation.py +457 -286
- konfai/data/data_manager.py +533 -316
- konfai/data/patching.py +300 -183
- konfai/data/transform.py +408 -275
- konfai/evaluator.py +325 -68
- konfai/main.py +71 -22
- konfai/metric/measure.py +360 -244
- konfai/metric/schedulers.py +24 -13
- konfai/models/classification/convNeXt.py +187 -81
- konfai/models/classification/resnet.py +272 -58
- konfai/models/generation/cStyleGan.py +233 -59
- konfai/models/generation/ddpm.py +348 -121
- konfai/models/generation/diffusionGan.py +757 -358
- konfai/models/generation/gan.py +177 -53
- konfai/models/generation/vae.py +140 -40
- konfai/models/registration/registration.py +135 -52
- konfai/models/representation/representation.py +57 -23
- konfai/models/segmentation/NestedUNet.py +339 -68
- konfai/models/segmentation/UNet.py +140 -30
- konfai/network/blocks.py +331 -187
- konfai/network/network.py +795 -427
- konfai/predictor.py +644 -238
- konfai/trainer.py +509 -222
- konfai/utils/ITK.py +191 -106
- konfai/utils/config.py +152 -95
- konfai/utils/dataset.py +326 -455
- konfai/utils/utils.py +497 -249
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/METADATA +1 -3
- konfai-1.2.0.dist-info/RECORD +38 -0
- konfai/utils/registration.py +0 -199
- konfai-1.1.8.dist-info/RECORD +0 -39
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/WHEEL +0 -0
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/entry_points.txt +0 -0
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/licenses/LICENSE +0 -0
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/top_level.txt +0 -0
konfai/network/network.py
CHANGED
@@ -1,117 +1,181 @@
-from functools import partial
 import importlib
 import inspect
 import os
-from typing import Iterable, Iterator, Callable
-from typing_extensions import Self
-import torch
 from abc import ABC
+from collections import OrderedDict
+from collections.abc import Callable, Iterable, Iterator, Sequence
+from enum import Enum
+from functools import partial
+from typing import Any, Self
+
 import numpy as np
+import torch
 from torch._jit_internal import _copy_to_script_wrapper
-from collections import OrderedDict
 from torch.utils.checkpoint import checkpoint
-from typing import Union
-from enum import Enum

-from konfai import
+from konfai import konfai_root
+from konfai.data.patching import Accumulator, ModelPatch
 from konfai.metric.schedulers import Scheduler
 from konfai.utils.config import config
-from konfai.utils.utils import
-
+from konfai.utils.utils import MeasureError, State, TrainerError, get_device, get_gpu_memory, get_module
+

 class NetState(Enum):
-    TRAIN = 0,
+    TRAIN = (0,)
     PREDICTION = 1

-
+
+class PatchIndexed:

     def __init__(self, patch: ModelPatch, index: int) -> None:
         self.patch = patch
         self.index = index

-    def
-        return len(self.patch.
+    def is_full(self) -> bool:
+        return len(self.patch.get_patch_slices(0)) == self.index
+
+
+class OptimizerLoader:

-class OptimizerLoader():
-
     @config("Optimizer")
     def __init__(self, name: str = "AdamW") -> None:
         self.name = name
-
-    def getOptimizer(self, key: str, parameter: Iterator[torch.nn.parameter.Parameter]) -> torch.optim.Optimizer:
-        return config("{}.Model.{}.Optimizer".format(KONFAI_ROOT(), key))(getattr(importlib.import_module('torch.optim'), self.name))(parameter, config = None)
-

-
-
+    def get_optimizer(self, key: str, parameter: Iterator[torch.nn.parameter.Parameter]) -> torch.optim.Optimizer:
+        return config(f"{konfai_root()}.Model.{key}.Optimizer")(
+            getattr(importlib.import_module("torch.optim"), self.name)
+        )(parameter, config=None)
+
+
+class LRSchedulersLoader:
+
     @config()
-    def __init__(self, nb_step
+    def __init__(self, nb_step: int = 0) -> None:
         self.nb_step = nb_step

-    def getschedulers(
+    def getschedulers(
+        self, key: str, scheduler_classname: str, optimizer: torch.optim.Optimizer
+    ) -> torch.optim.lr_scheduler._LRScheduler:
         for m in ["torch.optim.lr_scheduler", "konfai.metric.schedulers"]:
-            module, name =
+            module, name = get_module(scheduler_classname, m)
             if hasattr(importlib.import_module(module), name):
-                return config("{}.Model.{}.schedulers.{}"
-
+                return config(f"{konfai_root()}.Model.{key}.schedulers.{scheduler_classname}")(
+                    getattr(importlib.import_module(module), name)
+                )(optimizer, config=None)
+        raise TrainerError(
+            f"Unknown scheduler {scheduler_classname}, tried importing from: 'torch.optim.lr_scheduler' and "
+            "'konfai.metric.schedulers', but no valid match was found. "
+            "Check your YAML config or scheduler name spelling."
+        )
+
+
+class LossSchedulersLoader:

-class LossSchedulersLoader():
-
     @config()
-    def __init__(self, nb_step
+    def __init__(self, nb_step: int = 0) -> None:
         self.nb_step = nb_step

     def getschedulers(self, key: str, scheduler_classname: str) -> torch.optim.lr_scheduler._LRScheduler:
-        return config("{}.{}"
-
-
-
+        return config(f"{key}.{scheduler_classname}")(
+            getattr(importlib.import_module("konfai.metric.schedulers"), scheduler_classname)
+        )(config=None)
+
+
+class CriterionsAttr:
+
     @config()
-    def __init__(
+    def __init__(
+        self,
+        schedulers: dict[str, LossSchedulersLoader] = {"default:Constant": LossSchedulersLoader(0)},
+        is_loss: bool = True,
+        group: int = 0,
+        start: int = 0,
+        stop: int | None = None,
+        accumulation: bool = False,
+    ) -> None:
         self.schedulersLoader = schedulers
         self.isTorchCriterion = True
-        self.
-        self.
-        self.
+        self.is_loss = is_loss
+        self.start = start
+        self.stop = stop
         self.group = group
         self.accumulation = accumulation
         self.schedulers: dict[Scheduler, int] = {}
-
-class CriterionsLoader():

-    @config()
-    def __init__(self, criterionsLoader: dict[str, CriterionsAttr] = {"default:torch_nn_CrossEntropyLoss:Dice:NCC": CriterionsAttr()}) -> None:
-        self.criterionsLoader = criterionsLoader
-
-    def getCriterions(self, model_classname : str, output_group : str, target_group : str) -> dict[torch.nn.Module, CriterionsAttr]:
-        criterions = {}
-        for module_classpath, criterionsAttr in self.criterionsLoader.items():
-            module, name = _getModule(module_classpath, "konfai.metric.measure")
-            criterionsAttr.isTorchCriterion = module.startswith("torch")
-            for scheduler_classname, schedulers in criterionsAttr.schedulersLoader.items():
-                criterionsAttr.schedulers[schedulers.getschedulers("{}.Model.{}.outputsCriterions.{}.targetsCriterions.{}.criterionsLoader.{}.schedulers".format(KONFAI_ROOT(), model_classname, output_group, target_group, module_classpath), scheduler_classname)] = schedulers.nb_step
-            criterions[config("{}.Model.{}.outputsCriterions.{}.targetsCriterions.{}.criterionsLoader.{}".format(KONFAI_ROOT(), model_classname, output_group, target_group, module_classpath))(getattr(importlib.import_module(module), name))(config = None)] = criterionsAttr
-        return criterions

-class
+class CriterionsLoader:

     @config()
-    def __init__(
-        self
-
-
-
-
-
-
+    def __init__(
+        self,
+        criterions_loader: dict[str, CriterionsAttr] = {"default:torch:nn:CrossEntropyLoss:Dice:NCC": CriterionsAttr()},
+    ) -> None:
+        self.criterions_loader = criterions_loader
+
+    def get_criterions(
+        self, model_classname: str, output_group: str, target_group: str
+    ) -> dict[torch.nn.Module, CriterionsAttr]:
+        criterions = {}
+        for module_classpath, criterions_attr in self.criterions_loader.items():
+            module, name = get_module(module_classpath, "konfai.metric.measure")
+            criterions_attr.isTorchCriterion = module.startswith("torch")
+            for (
+                scheduler_classname,
+                schedulers,
+            ) in criterions_attr.schedulersLoader.items():
+                criterions_attr.schedulers[
+                    schedulers.getschedulers(
+                        f"{konfai_root()}.Model.{model_classname}.outputs_criterions.{output_group}"
+                        f".targets_criterions.{target_group}"
+                        f".criterions_loader.{module_classpath}.schedulers",
+                        scheduler_classname,
+                    )
+                ] = schedulers.nb_step
+            criterions[
+                config(
+                    f"{konfai_root()}.Model.{model_classname}.outputs_criterions."
+                    f"{output_group}.targets_criterions.{target_group}."
+                    f"criterions_loader.{module_classpath}"
+                )(getattr(importlib.import_module(module), name))(config=None)
+            ] = criterions_attr
+        return criterions

-class Measure():

-
+class TargetCriterionsLoader:

-
+    @config()
+    def __init__(
+        self,
+        targets_criterions: dict[str, CriterionsLoader] = {"default": CriterionsLoader()},
+    ) -> None:
+        self.targets_criterions = targets_criterions
+
+    def get_targets_criterions(
+        self, output_group: str, model_classname: str
+    ) -> dict[str, dict[torch.nn.Module, CriterionsAttr]]:
+        targets_criterions = {}
+        for target_group, criterions_loader in self.targets_criterions.items():
+            targets_criterions[target_group] = criterions_loader.get_criterions(
+                model_classname, output_group, target_group
+            )
+        return targets_criterions
+
+
+class Measure:
+
+    class Loss:
+
+        def __init__(
+            self,
+            name: str,
+            output_group: str,
+            target_group: str,
+            group: int,
+            is_loss: bool,
+            accumulation: bool,
+        ) -> None:
             self.name = name
-            self.
+            self.is_loss = is_loss
             self.accumulation = accumulation
             self.output_group = output_group
             self.target_group = target_group
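Most of this first hunk is a mechanical API migration: camelCase helpers (getOptimizer, getCriterions, KONFAI_ROOT()) become snake_case (get_optimizer, get_criterions, konfai_root()), and str.format config keys become f-strings. A minimal sketch, with placeholder values rather than real konfai calls, showing that the two key-building styles resolve to the same config path:

# Sketch only: `root` and `key` are placeholder strings, not konfai API calls.
def key_1_1_8(root: str, key: str) -> str:
    # konfai 1.1.8 style: str.format around KONFAI_ROOT()
    return "{}.Model.{}.Optimizer".format(root, key)

def key_1_2_0(root: str, key: str) -> str:
    # konfai 1.2.0 style: f-string around konfai_root()
    return f"{root}.Model.{key}.Optimizer"

assert key_1_1_8("Trainer", "UNet") == key_1_2_0("Trainer", "UNet") == "Trainer.Model.UNet.Optimizer"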
@@ -120,144 +184,236 @@ class Measure():
             self._loss: list[torch.Tensor] = []
             self._weight: list[float] = []
             self._values: list[float] = []
-
-        def
+
+        def reset_loss(self) -> None:
             self._loss.clear()

-        def add(self, weight: float, value: torch.Tensor) -> None:
-            if
-
-
-
+        def add(self, weight: float, value: torch.Tensor | tuple[torch.Tensor, float]) -> None:
+            if isinstance(value, tuple):
+                loss_value, true_value = value
+            else:
+                loss_value = value
+                true_value = value.item()
+
+            self._loss.append(loss_value if self.is_loss else loss_value.detach())
+            self._values.append(true_value)
+            self._weight.append(weight)
+
+        def get_last_loss(self) -> torch.Tensor:
+            return self._loss[-1] * self._weight[-1] if len(self._loss) else torch.zeros((1), requires_grad=True)

-        def
-            return
-
-
-
+        def get_loss(self) -> torch.Tensor:
+            return (
+                torch.stack([w * loss_value for w, loss_value in zip(self._weight, self._loss)], dim=0).mean(dim=0)
+                if len(self._loss)
+                else torch.zeros((1), requires_grad=True)
+            )

         def __len__(self) -> int:
             return len(self._loss)
-
-    def __init__(self, model_classname : str, outputsCriterions: dict[str, TargetCriterionsLoader]) -> None:
-        super().__init__()
-        self.outputsCriterions = {}
-        for output_group, targetCriterionsLoader in outputsCriterions.items():
-            self.outputsCriterions[output_group.replace(":", ".")] = targetCriterionsLoader.getTargetsCriterions(output_group, model_classname)
-        self._loss : dict[int, dict[str, Measure.Loss]] = {}

-    def
+    def __init__(
+        self,
+        model_classname: str,
+        outputs_criterions_loader: dict[str, TargetCriterionsLoader],
+    ) -> None:
+        super().__init__()
+        self.outputs_criterions: dict[str, dict[str, dict[torch.nn.Module, CriterionsAttr]]] = {}
+        for output_group, target_criterions_loader in outputs_criterions_loader.items():
+            self.outputs_criterions[output_group.replace(":", ".")] = target_criterions_loader.get_targets_criterions(
+                output_group, model_classname
+            )
+        self._loss: dict[int, dict[str, Measure.Loss]] = {}
+
+    def init(self, model: torch.nn.Module, group_dest: list[str]) -> None:
         outputs_group_rename = {}
-
+
         modules = []
-        for i,_,_ in model.
+        for i, _, _ in model.named_module_args_dict():
             modules.append(i)

-        for output_group in self.
-            if output_group not in modules:
-
+        for output_group in self.outputs_criterions.keys():
+            if output_group.replace(";accu;", "") not in modules:
+                raise MeasureError(
+                    f"The output group '{output_group}' defined in 'outputs_criterions' "
+                    "does not correspond to any module in the model.",
                     f"Available modules: {modules}",
-                    "Please check that the name matches exactly a submodule or output of your model architecture."
+                    "Please check that the name matches exactly a submodule or output of your model architecture.",
                 )
-            for target_group in self.
-                for target_group_tmp in target_group.split(";"):
+            for target_group in self.outputs_criterions[output_group]:
+                for target_group_tmp in target_group.split(";"):
                     if target_group_tmp not in group_dest:
                         raise MeasureError(
-                            f"The target_group
-                            "
-
-
-
+                            f"The target_group {target_group_tmp} defined in "
+                            "'outputs_criterions.{output_group}.targets_criterions'"
+                            " was not found in the available destination groups.",
+                            "This target_group is expected for loss or metric computation, "
+                            "but was not loaded in 'group_dest'.",
+                            f"Please make sure that the group {target_group_tmp} is defined in "
+                            "Dataset:groups_src:...:groups_dest: {target_group_tmp} "
+                            "and correctly loaded from the dataset.",
+                        )
+                for criterion in self.outputs_criterions[output_group][target_group]:
+                    if not self.outputs_criterions[output_group][target_group][criterion].isTorchCriterion:
                         outputs_group_rename[output_group] = criterion.init(model, output_group, target_group)

-
+        outputs_criterions_bak = self.outputs_criterions.copy()
         for old, new in outputs_group_rename.items():
-            self.
-            self.
-        for output_group in self.
-            for target_group in self.
-                for criterion,
-                    if
-                        self._loss[
-                    self._loss[
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+            self.outputs_criterions.pop(old)
+            self.outputs_criterions[new] = outputs_criterions_bak[old]
+        for output_group in self.outputs_criterions:
+            for target_group in self.outputs_criterions[output_group]:
+                for criterion, criterions_attr in self.outputs_criterions[output_group][target_group].items():
+                    if criterions_attr.group not in self._loss:
+                        self._loss[criterions_attr.group] = {}
+                    self._loss[criterions_attr.group][
+                        f"{output_group}:{target_group}:{criterion.__class__.__name__}"
+                    ] = Measure.Loss(
+                        criterion.__class__.__name__,
+                        output_group,
+                        target_group,
+                        criterions_attr.group,
+                        criterions_attr.is_loss,
+                        criterions_attr.accumulation,
+                    )
+
+    def update(
+        self,
+        output_group: str,
+        output: torch.Tensor,
+        data_dict: dict[str, torch.Tensor],
+        it: int,
+        nb_patch: int,
+        training: bool,
+    ) -> None:
+        for target_group in self.outputs_criterions[output_group]:
+            target = [
+                data_dict[group].to(output[0].device).detach()
+                for group in target_group.split(";")
+                if group in data_dict
+            ]
+
+            for criterion, criterions_attr in self.outputs_criterions[output_group][target_group].items():
+                if it >= criterions_attr.start and (criterions_attr.stop is None or it <= criterions_attr.stop):
+                    scheduler = self.update_scheduler(criterions_attr.schedulers, it)
+                    self._loss[criterions_attr.group][
+                        f"{output_group}:{target_group}:{criterion.__class__.__name__}"
+                    ].add(scheduler.get_value(), criterion(output, *target))
+                    if (
+                        training
+                        and len(
+                            np.unique(
+                                [
+                                    len(loss_value)
+                                    for loss_value in self._loss[criterions_attr.group].values()
+                                    if loss_value.accumulation and loss_value.is_loss
+                                ]
+                            )
+                        )
+                        == 1
+                    ):
+                        if criterions_attr.is_loss:
+                            loss = torch.zeros((1), requires_grad=True)
+                            for v in [
+                                loss_value
+                                for loss_value in self._loss[criterions_attr.group].values()
+                                if loss_value.accumulation and loss_value.is_loss
+                            ]:
+                                loss_value = v.get_last_loss()
+                                loss = loss.to(loss_value.device) + loss_value
+                            loss = loss / nb_patch
                             loss.backward()

-    def
+    def get_loss(self) -> list[torch.Tensor]:
         loss: dict[int, torch.Tensor] = {}
         for group in self._loss.keys():
-            loss[group] = torch.zeros((1), requires_grad
+            loss[group] = torch.zeros((1), requires_grad=True)
             for v in self._loss[group].values():
-                if v.
-
-            loss[v.group] = loss[v.group].to(
-        return loss.values()
+                if v.is_loss and not v.accumulation:
+                    loss_value = v.get_loss()
+                    loss[v.group] = loss[v.group].to(loss_value.device) + loss_value
+        return list(loss.values())

-    def
+    def reset_loss(self) -> None:
         for group in self._loss.keys():
             for v in self._loss[group].values():
-                v.
-
-    def
+                v.reset_loss()
+
+    def get_last_values(self, n: int = 1) -> dict[str, float]:
         result = {}
         for group in self._loss.keys():
-            result.update(
+            result.update(
+                {
+                    name: np.nanmean(value._values[-n:] if n > 0 else value._values)
+                    for name, value in self._loss[group].items()
+                    if n < 0 or len(value._values) >= n
+                }
+            )
         return result
-
-    def
+
+    def get_last_weights(self, n: int = 1) -> dict[str, float]:
         result = {}
         for group in self._loss.keys():
-            result.update(
+            result.update(
+                {
+                    name: np.nanmean(value._weight[-n:] if n > 0 else value._weight)
+                    for name, value in self._loss[group].items()
+                    if n < 0 or len(value._values) >= n
+                }
+            )
         return result

-    def
-        result =
+    def format_loss(self, is_loss: bool, n: int) -> dict[str, tuple[float, float]]:
+        result = {}
         for group in self._loss.keys():
             for name, loss in self._loss[group].items():
-                if loss.
-                    result[name] = (
+                if loss.is_loss == is_loss and len(loss._values) >= n:
+                    result[name] = (
+                        np.nanmean(loss._weight[-n:]),
+                        np.nanmean(loss._values[-n:]),
+                    )
         return result

     def update_scheduler(self, schedulers: dict[Scheduler, int], it: int) -> Scheduler:
         step = 0
-
-        for
-            if value is None or (it >= step
+        _scheduler = None
+        for _scheduler, value in schedulers.items():
+            if value is None or (it >= step and it < step + value):
                 break
             step += value
-        if
-
-
-
+        if _scheduler:
+            _scheduler.step(it - step)
+        if _scheduler is None:
+            raise NameError(
+                f"No scheduler found for iteration {it}. "
+                f"Available steps were: {list(schedulers.values())}. "
+                f"Check your configuration."
+            )
+        return _scheduler
+
+
 class ModuleArgsDict(torch.nn.Module, ABC):
-
+
     class ModuleArgs:

-        def __init__(
+        def __init__(
+            self,
+            in_branch: list[str],
+            out_branch: list[str],
+            pretrained: bool,
+            alias: list[str],
+            requires_grad: bool | None,
+            training: None | bool,
+        ) -> None:
             super().__init__()
-            self.alias= alias
+            self.alias = alias
             self.pretrained = pretrained
             self.in_branch = in_branch
             self.out_branch = out_branch
-            self.in_channels = None
-            self.in_is_channel = True
-            self.out_channels = None
-            self.out_is_channel = True
+            self.in_channels: int | None = None
+            self.in_is_channel: bool = True
+            self.out_channels: int | None = None
+            self.out_is_channel: bool = True
             self.requires_grad = requires_grad
             self.isCheckpoint = False
             self.isGPU_Checkpoint = False
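The rewritten update_scheduler above picks a scheduler by walking consecutive step windows: each scheduler owns `value` iterations, the first window containing `it` wins, and a None value means "the rest of training". A self-contained sketch of that selection logic, where Sched is a stand-in for konfai.metric.schedulers.Scheduler (not reproduced here):

class Sched:
    def __init__(self, name: str) -> None:
        self.name = name
    def step(self, local_it: int) -> None:
        # the chosen scheduler sees a window-local iteration count
        self.local_it = local_it

def pick_scheduler(schedulers: dict, it: int) -> Sched:
    step = 0
    chosen = None
    for chosen, value in schedulers.items():
        if value is None or (step <= it < step + value):
            break
        step += value
    if chosen is None:
        raise NameError(f"No scheduler found for iteration {it}.")
    chosen.step(it - step)
    return chosen

windows = {Sched("warmup"): 100, Sched("cosine"): None}
assert pick_scheduler(windows, 42).name == "warmup"    # inside the first 100-step window
assert pick_scheduler(windows, 150).name == "cosine"   # None window catches everything after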
@@ -267,51 +423,54 @@ class ModuleArgsDict(torch.nn.Module, ABC):

     def __init__(self) -> None:
         super().__init__()
-        self._modulesArgs
+        self._modulesArgs: dict[str, ModuleArgsDict.ModuleArgs] = {}
         self._training = NetState.TRAIN

-    def _addindent(self, s_: str,
-        s = s_.split(
+    def _addindent(self, s_: str, num_spaces: int):
+        s = s_.split("\n")
         if len(s) == 1:
             return s_
         first = s.pop(0)
-        s = [(
-
-        s = first + '\n' + s
-        return s
+        s = [(num_spaces * " ") + line for line in s]
+        return first + "\n" + "\n".join(s)

     def __repr__(self):
         extra_lines = []

         extra_repr = self.extra_repr()
         if extra_repr:
-            extra_lines = extra_repr.split(
+            extra_lines = extra_repr.split("\n")

         child_lines = []
-
+
+        def is_simple_branch(x):
+            return len(x) > 1 or x[0] != 0
+
         for key, module in self._modules.items():
             mod_str = repr(module)

             mod_str = self._addindent(mod_str, 2)
             desc = ""
-            if is_simple_branch(self._modulesArgs[key].in_branch) or is_simple_branch(
-
+            if is_simple_branch(self._modulesArgs[key].in_branch) or is_simple_branch(
+                self._modulesArgs[key].out_branch
+            ):
+                desc += f", {self._modulesArgs[key].in_branch}->{self._modulesArgs[key].out_branch}"
             if not self._modulesArgs[key].pretrained:
                 desc += ", pretrained=False"
             if self._modulesArgs[key].alias:
-                desc += ", alias={
-            desc += ", in_channels={
-            desc += ", in_is_channel={
-            desc += ", out_channels={
-            desc += ", out_is_channel={
-            desc += ", is_end={
-            desc += ", isInCheckpoint={
-            desc += ", isInGPU_Checkpoint={
-            desc += ", requires_grad={
-            desc += ", device={
-
-            child_lines.append("({}{}) {}"
-
+                desc += f", alias={self._modulesArgs[key].alias}"
+            desc += f", in_channels={self._modulesArgs[key].in_channels}"
+            desc += f", in_is_channel={self._modulesArgs[key].in_is_channel}"
+            desc += f", out_channels={self._modulesArgs[key].out_channels}"
+            desc += f", out_is_channel={self._modulesArgs[key].out_is_channel}"
+            desc += f", is_end={self._modulesArgs[key]._isEnd}"
+            desc += f", isInCheckpoint={self._modulesArgs[key].isCheckpoint}"
+            desc += f", isInGPU_Checkpoint={self._modulesArgs[key].isGPU_Checkpoint}"
+            desc += f", requires_grad={self._modulesArgs[key].requires_grad}"
+            desc += f", device={self._modulesArgs[key].gpu}"
+
+            child_lines.append(f"({key}{desc}) {mod_str}")
+
         lines = extra_lines + child_lines

         desc = ""
@@ -319,53 +478,71 @@ class ModuleArgsDict(torch.nn.Module, ABC):
         if len(extra_lines) == 1 and not child_lines:
             desc += extra_lines[0]
         else:
-            desc +=
+            desc += "\n  " + "\n  ".join(lines) + "\n"
+
+        return f"{self._get_name()}({desc})"

-        return "{}({})".format(self._get_name(), desc)
-
     def __getitem__(self, key: str) -> torch.nn.Module:
         module = self._modules[key]
-
-
+        if not module:
+            raise ValueError(f"Module '{key}' is None or missing in self._modules")
+        return module

     @_copy_to_script_wrapper
     def keys(self) -> Iterable[str]:
         return self._modules.keys()

     @_copy_to_script_wrapper
-    def items(self) -> Iterable[tuple[str,
+    def items(self) -> Iterable[tuple[str, torch.nn.Module | None]]:
         return self._modules.items()

     @_copy_to_script_wrapper
-    def values(self) -> Iterable[
+    def values(self) -> Iterable[torch.nn.Module | None]:
         return self._modules.values()

-    def add_module(
+    def add_module(
+        self,
+        name: str,
+        module: torch.nn.Module,
+        in_branch: Sequence[int | str] = [0],
+        out_branch: Sequence[int | str] = [0],
+        pretrained: bool = True,
+        alias: list[str] = [],
+        requires_grad: bool | None = None,
+        training: None | bool = None,
+    ) -> None:
         super().add_module(name, module)
-        self._modulesArgs[name] = ModuleArgsDict.ModuleArgs(
-
-
-
-
+        self._modulesArgs[name] = ModuleArgsDict.ModuleArgs(
+            [str(value) for value in in_branch],
+            [str(value) for value in out_branch],
+            pretrained,
+            alias,
+            requires_grad,
+            training,
+        )
+
+    def get_mapping(self):
+        results: dict[str, str] = {}
+        for name, module_args in self._modulesArgs.items():
             module = self[name]
             if isinstance(module, ModuleArgsDict):
-                if len(
-                    count =
+                if len(module_args.alias):
+                    count = dict.fromkeys(set(module.get_mapping().values()), 0)
                     if len(count):
-                        for k, v in module.
-                            alias_name =
+                        for k, v in module.get_mapping().items():
+                            alias_name = module_args.alias[count[v]]
                             if k == "":
-                                results.update({alias_name
+                                results.update({alias_name: name + "." + v})
                             else:
-                                results.update({alias_name+"."+k
-                            count[v]+=1
+                                results.update({alias_name + "." + k: name + "." + v})
+                            count[v] += 1
                     else:
-                        for alias in
-                            results.update({alias
+                        for alias in module_args.alias:
+                            results.update({alias: name})
                 else:
-                    results.update({k
+                    results.update({k: name + "." + v for k, v in module.get_mapping().items()})
             else:
-                for alias in
+                for alias in module_args.alias:
                     results[alias] = name
         return results

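One small behavioral detail in the add_module rewrite above: branch identifiers may arrive from configs as ints or strings, and are now stored uniformly as strings so later branch lookups use a single key type. A trivial sketch of that normalization (the function name is illustrative, not konfai API):

def normalize_branches(branches):
    # mirrors the [str(value) for value in in_branch] coercion in add_module
    return [str(value) for value in branches]

assert normalize_branches([0, 1]) == ["0", "1"]
assert normalize_branches(["skip", 2]) == ["skip", "2"]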
@@ -375,24 +552,24 @@ class ModuleArgsDict(torch.nn.Module, ABC):
         if isinstance(module, ModuleArgsDict):
             module.init(init_type, init_gain)
         elif isinstance(module, torch.nn.modules.conv._ConvNd) or isinstance(module, torch.nn.Linear):
-            if init_type ==
+            if init_type == "normal":
                 torch.nn.init.normal_(module.weight, 0.0, init_gain)
-            elif init_type ==
+            elif init_type == "xavier":
                 torch.nn.init.xavier_normal_(module.weight, gain=init_gain)
-            elif init_type ==
-                torch.nn.init.kaiming_normal_(module.weight, a=0, mode=
-            elif init_type ==
+            elif init_type == "kaiming":
+                torch.nn.init.kaiming_normal_(module.weight, a=0, mode="fan_in")
+            elif init_type == "orthogonal":
                 torch.nn.init.orthogonal_(module.weight, gain=init_gain)
             elif init_type == "trunc_normal":
                 torch.nn.init.trunc_normal_(module.weight, std=init_gain)
             else:
-                raise NotImplementedError(
+                raise NotImplementedError(f"Initialization method {init_type} is not implemented")
             if module.bias is not None:
                 torch.nn.init.constant_(module.bias, 0.0)

         elif isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
             if module.weight is not None:
-                torch.nn.init.normal_(module.weight, 0.0, std
+                torch.nn.init.normal_(module.weight, 0.0, std=init_gain)
             if module.bias is not None:
                 torch.nn.init.constant_(module.bias, 0.0)

@@ -400,8 +577,7 @@ class ModuleArgsDict(torch.nn.Module, ABC):
         for module in self._modules.values():
             ModuleArgsDict.init_func(module, init_type, init_gain)

-
-    def named_forward(self, *inputs: torch.Tensor) -> Iterator[tuple[str, torch.Tensor]]:
+    def named_forward(self, *inputs: torch.Tensor) -> Iterator[tuple[str, torch.Tensor]]:
         if len(inputs) > 0:
             branchs: dict[str, torch.Tensor] = {}
             for i, sinput in enumerate(inputs):
@@ -410,7 +586,10 @@ class ModuleArgsDict(torch.nn.Module, ABC):
             out = inputs[0]
             tmp = []
             for name, module in self.items():
-                if self._modulesArgs[name].training is None or (
+                if self._modulesArgs[name].training is None or (
+                    not (self._modulesArgs[name].training and self._training == NetState.PREDICTION)
+                    and not (not self._modulesArgs[name].training and self._training == NetState.TRAIN)
+                ):
                     requires_grad = self._modulesArgs[name].requires_grad
                     if requires_grad is not None and module:
                         module.requires_grad_(requires_grad)
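The gating condition added to named_forward above is easier to read as a truth table: a module flagged training=True is skipped in PREDICTION mode, one flagged training=False is skipped in TRAIN mode, and training=None always runs. A self-contained sketch of just that predicate:

from enum import Enum

class NetState(Enum):
    TRAIN = (0,)
    PREDICTION = 1

def module_runs(training, state: NetState) -> bool:
    # same boolean expression as the hunk above, isolated for clarity
    return training is None or (
        not (training and state == NetState.PREDICTION)
        and not (not training and state == NetState.TRAIN)
    )

assert module_runs(None, NetState.TRAIN) and module_runs(None, NetState.PREDICTION)
assert module_runs(True, NetState.TRAIN) and not module_runs(True, NetState.PREDICTION)
assert module_runs(False, NetState.PREDICTION) and not module_runs(False, NetState.TRAIN)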
@@ -418,64 +597,75 @@ class ModuleArgsDict(torch.nn.Module, ABC):
                     if ib not in branchs:
                         branchs[ib] = inputs[0]
                 for branchs_key in branchs.keys():
-                    if
+                    if (
+                        str(self._modulesArgs[name].gpu) != "cpu"
+                        and str(branchs[branchs_key].device) != "cuda:" + self._modulesArgs[name].gpu
+                    ):
                         branchs[branchs_key] = branchs[branchs_key].to(int(self._modulesArgs[name].gpu))
-
+
                 if self._modulesArgs[name].isCheckpoint:
-                    out = checkpoint(
+                    out = checkpoint(
+                        module,
+                        *[branchs[i] for i in self._modulesArgs[name].in_branch],
+                        use_reentrant=True,
+                    )
                     for ob in self._modulesArgs[name].out_branch:
                         branchs[ob] = out
                     yield name, out
                 else:
                     if isinstance(module, ModuleArgsDict):
-                        for k, out in module.named_forward(
+                        for k, out in module.named_forward(
+                            *[branchs[i] for i in self._modulesArgs[name].in_branch]
+                        ):
                             for ob in self._modulesArgs[name].out_branch:
                                 if ob in module._modulesArgs[k.split(".")[0].replace(";accu;", "")].out_branch:
                                     tmp.append(ob)
                                     branchs[ob] = out
-                            yield name+"."+k, out
+                            yield name + "." + k, out
                         for ob in self._modulesArgs[name].out_branch:
                             if ob not in tmp:
                                 branchs[ob] = out
                     elif isinstance(module, torch.nn.Module):
-                        out = module(*[branchs[i] for i in self._modulesArgs[name].in_branch])
+                        out = module(*[branchs[i] for i in self._modulesArgs[name].in_branch])
                         for ob in self._modulesArgs[name].out_branch:
                             branchs[ob] = out
                         yield name, out
         del branchs

     def forward(self, *input: torch.Tensor) -> torch.Tensor:
-
-        for
+        _v = input
+        for _, _v in self.named_forward(*input):
             pass
-        return
+        return _v

-    def named_parameters(
-
+    def named_parameters(
+        self, pretrained: bool = False, recurse=False
+    ) -> Iterator[tuple[str, torch.nn.parameter.Parameter]]:
+        for name, module_args in self._modulesArgs.items():
             module = self[name]
             if isinstance(module, ModuleArgsDict):
-                for k, v in module.named_parameters(pretrained
-                    yield name+"."+k, v
+                for k, v in module.named_parameters(pretrained=pretrained):
+                    yield name + "." + k, v
             elif isinstance(module, torch.nn.Module):
-                if not pretrained or not
-                    if
+                if not pretrained or not module_args.pretrained:
+                    if module_args.training is None or module_args.training:
                         for k, v in module.named_parameters():
-                            yield name+"."+k, v
+                            yield name + "." + k, v

     def parameters(self, pretrained: bool = False):
-        for _, v in self.named_parameters(pretrained
+        for _, v in self.named_parameters(pretrained=pretrained):
             yield v
-
-    def
+
+    def named_module_args_dict(self) -> Iterator[tuple[str, Self, ModuleArgs]]:
         for name, module in self._modules.items():
             yield name, module, self._modulesArgs[name]
             if isinstance(module, ModuleArgsDict):
-                for k, v, u in module.
-                    yield name+"."+k, v, u
+                for k, v, u in module.named_module_args_dict():
+                    yield name + "." + k, v, u

     def _requires_grad(self, keys: list[str]):
         keys = keys.copy()
-        for name, module, args in self.
+        for name, module, args in self.named_module_args_dict():
             requires_grad = args.requires_grad
             if requires_grad is not None:
                 module.requires_grad_(requires_grad)
@@ -484,162 +674,224 @@ class ModuleArgsDict(torch.nn.Module, ABC):
             if len(keys) == 0:
                 break

+
 class OutputsGroup(list):

     def __init__(self, measure: Measure) -> None:
         self.layers: dict[str, torch.Tensor] = {}
         self.measure = measure

-    def
+    def add_layer(self, name: str, layer: torch.Tensor):
         self.layers[name] = layer

-    def
+    def is_done(self):
         return len(self) == len(self.layers)
-
+
     def clear(self):
         self.layers.clear()

-

 class Network(ModuleArgsDict, ABC):

-    def _apply_network(
-
+    def _apply_network(
+        self,
+        name_function: Callable[[Self], str],
+        networks: list[str],
+        key: str,
+        function: Callable,
+        *args,
+        **kwargs,
+    ) -> dict[str, object]:
+        results: dict[str, object] = {}
         for module in self.values():
             if isinstance(module, Network):
                 if name_function(module) not in networks:
                     networks.append(name_function(module))
-                for k, v in module._apply_network(
-
+                for k, v in module._apply_network(
+                    name_function,
+                    networks,
+                    key + "." + name_function(module),
+                    function,
+                    *args,
+                    **kwargs,
+                ).items():
+                    results.update({name_function(self) + "." + k: v})
         if len([param.name for param in list(inspect.signature(function).parameters.values()) if param.name == "key"]):
             function = partial(function, key=key)

         results[name_function(self)] = function(self, *args, **kwargs)
         return results
-
-    def _function_network():
-        def _function_network_d(function
-            def new_function(self
-                return self._apply_network(
+
+    def _function_network():  # type: ignore[misc]
+        def _function_network_d(function: Callable):
+            def new_function(self: Self, *args, **kwargs) -> dict[str, object]:
+                return self._apply_network(
+                    lambda network: network.get_name(),
+                    [],
+                    self.get_name(),
+                    function,
+                    *args,
+                    **kwargs,
+                )
+
             return new_function
+
         return _function_network_d

-    def __init__(
-
-
-
-
-
-
-
-
-
+    def __init__(
+        self,
+        in_channels: int = 1,
+        optimizer: OptimizerLoader | None = None,
+        schedulers: dict[str, LRSchedulersLoader] | None = None,
+        outputs_criterions: dict[str, TargetCriterionsLoader] | None = None,
+        patch: ModelPatch | None = None,
+        nb_batch_per_step: int = 1,
+        init_type: str = "normal",
+        init_gain: float = 0.02,
+        dim: int = 3,
+    ) -> None:
         super().__init__()
         self.name = self.__class__.__name__
         self.in_channels = in_channels
-        self.optimizerLoader
-        self.optimizer
+        self.optimizerLoader = optimizer
+        self.optimizer: torch.optim.Optimizer | None = None

-        self.
-        self.schedulers
+        self.lr_schedulers_loader = schedulers
+        self.schedulers: dict[torch.optim.lr_scheduler._LRScheduler, int] = {}

-        self.
-        self.measure
+        self.outputs_criterions_loader = outputs_criterions
+        self.measure: Measure | None = None

         self.patch = patch

         self.nb_batch_per_step = nb_batch_per_step
-        self.init_type
-        self.init_gain
+        self.init_type = init_type
+        self.init_gain = init_gain
         self.dim = dim
         self._it = 0
         self._nb_lr_update = 0
-        self.outputsGroup
+        self.outputsGroup: list[OutputsGroup] = []

     @_function_network()
     def state_dict(self) -> dict[str, OrderedDict]:
-        destination = OrderedDict()
-
-        destination
+        destination: OrderedDict[str, Any] = OrderedDict()
+        local_metadata = {"version": self._version}
+        # destination["_metadata"] = OrderedDict({"": local_metadata})
         self._save_to_state_dict(destination, "", False)
         for name, module in self._modules.items():
             if module is not None:
-                if not isinstance(module, Network):
-                    module.state_dict(destination=destination, prefix="" + name +
+                if not isinstance(module, Network):
+                    module.state_dict(destination=destination, prefix="" + name + ".", keep_vars=False)
         for hook in self._state_dict_hooks.values():
             hook_result = hook(self, destination, "", local_metadata)
             if hook_result is not None:
                 destination = hook_result
         return destination
-
+
     def load_state_dict(self, state_dict: dict[str, torch.Tensor]):
         missing_keys: list[str] = []
         unexpected_keys: list[str] = []
         error_msgs: list[str] = []

-        metadata = getattr(state_dict,
+        metadata = getattr(state_dict, "_metadata", None)
         state_dict = state_dict.copy()
         if metadata is not None:
-            state_dict
+            state_dict["_metadata"] = metadata

-        def load(module: torch.nn.Module, prefix=
+        def load(module: torch.nn.Module, prefix=""):
             local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
             module._load_from_state_dict(
-                state_dict,
+                state_dict,
+                prefix,
+                local_metadata,
+                True,
+                missing_keys,
+                unexpected_keys,
+                error_msgs,
+            )
             for name, child in module._modules.items():
                 if child is not None:
                     if not isinstance(child, Network):
                         if isinstance(child, torch.nn.modules.conv._ConvNd) or isinstance(module, torch.nn.Linear):

                             current_size = child.weight.shape[0]
-                            last_size = state_dict[prefix + name+".weight"].shape[0]
+                            last_size = state_dict[prefix + name + ".weight"].shape[0]

                             if current_size != last_size:
-                                print(
+                                print(
+                                    f"Warning: The size of '{prefix + name}' has changed from {last_size}"
+                                    f" to {current_size}. Please check for potential impacts"
+                                )
                                 ModuleArgsDict.init_func(child, self.init_type, self.init_gain)

                                 with torch.no_grad():
-                                    child.weight[:last_size] = state_dict[prefix + name+".weight"]
+                                    child.weight[:last_size] = state_dict[prefix + name + ".weight"]
                                     if child.bias is not None:
-                                        child.bias[:last_size] = state_dict[prefix + name+".bias"]
+                                        child.bias[:last_size] = state_dict[prefix + name + ".bias"]
                                 return
-                    load(child, prefix + name +
+                    load(child, prefix + name + ".")

         load(self)
-        del load

         if len(unexpected_keys) > 0:
+            formatted_keys = ", ".join(f'"{k}"' for k in unexpected_keys)
             error_msgs.insert(
-                0,
-
+                0,
+                f"Unexpected key(s) in state_dict: {formatted_keys}.",
+            )
         if len(missing_keys) > 0:
+            formatted_keys = ", ".join(f'"{k}"' for k in missing_keys)
             error_msgs.insert(
-                0,
-
+                0,
+                f"Missing key(s) in state_dict: {formatted_keys}.",
+            )

         if len(error_msgs) > 0:
-
-
-
+            formatted_errors = "\n\t".join(error_msgs)
+            raise RuntimeError(
+                f"Error(s) in loading state_dict for {self.__class__.__name__}:\n\t{formatted_errors}",
+            )
+
     def apply(self, fn: Callable[[torch.nn.Module], None]) -> None:
         for module in self.children():
             if not isinstance(module, Network):
                 module.apply(fn)
         fn(self)
-
+
     @_function_network()
-    def load(
+    def load(
+        self,
+        state_dict: dict[str, dict[str, torch.Tensor] | int],
+        init: bool = True,
+        ema: bool = False,
+    ):
         if init:
-            self.apply(
-
+            self.apply(
+                partial(
+                    ModuleArgsDict.init_func,
+                    init_type=self.init_type,
+                    init_gain=self.init_gain,
+                )
+            )
+        name = "Model"
+        if ema:
+            if name + "_EMA" in state_dict:
+                name += "_EMA"
         if name in state_dict:
-
-
-
-
+            value = state_dict[name]
+            model_state_dict_tmp = {}
+            if isinstance(value, dict):
+                model_state_dict_tmp = {k.split(".")[-1]: v for k, v in value.items()}[self.get_name()]
+            modules_name = self.get_mapping()
+            model_state_dict: OrderedDict[str, torch.Tensor] = OrderedDict()
+
             for alias in model_state_dict_tmp.keys():
                 prefix = ".".join(alias.split(".")[:-1])
-                alias_list = [
+                alias_list = [
+                    (".".join(prefix.split(".")[: len(i.split("."))]), v)
+                    for i, v in modules_name.items()
+                    if prefix.startswith(i)
+                ]

                 if len(alias_list):
                     for a, b in alias_list:
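A behavior worth noting from the load rewrite above: with ema=True the loader prefers a "Model_EMA" entry in the checkpoint and silently falls back to the plain "Model" weights when no EMA copy was saved. A self-contained sketch of that key resolution, using stand-in dicts rather than real checkpoints:

def resolve_model_key(state_dict: dict, ema: bool) -> str:
    # mirrors the name = "Model" / "_EMA" fallback in Network.load
    name = "Model"
    if ema and name + "_EMA" in state_dict:
        name += "_EMA"
    return name

assert resolve_model_key({"Model": {}}, ema=True) == "Model"                      # graceful fallback
assert resolve_model_key({"Model": {}, "Model_EMA": {}}, ema=True) == "Model_EMA"
assert resolve_model_key({"Model": {}, "Model_EMA": {}}, ema=False) == "Model"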
@@ -648,19 +900,36 @@ class Network(ModuleArgsDict, ABC):
                 else:
                     model_state_dict[alias] = model_state_dict_tmp[alias]
             self.load_state_dict(model_state_dict)
-        if "{
-            self.optimizer.
-
-        self.
-        if "{
-
+        if f"{self.get_name()}_optimizer_state_dict" in state_dict and self.optimizer:
+            last_lr = self.optimizer.param_groups[0]["lr"]
+            self.optimizer.load_state_dict(state_dict[f"{self.get_name()}_optimizer_state_dict"])
+            self.optimizer.param_groups[0]["lr"] = last_lr
+        if f"{self.get_name()}_it" in state_dict:
+            _it = state_dict.get(f"{self.get_name()}_it")
+            if isinstance(_it, int):
+                self._it = _it
+        if f"{self.get_name()}_nb_lr_update" in state_dict:
+            _nb_lr_update = state_dict.get(f"{self.get_name()}_nb_lr_update")
+            if isinstance(_nb_lr_update, int):
+                self._nb_lr_update = _nb_lr_update
+
         for scheduler in self.schedulers:
             if scheduler.last_epoch == -1:
                 scheduler.last_epoch = self._nb_lr_update
         self.initialized()

-    def _compute_channels_trace(
-
+    def _compute_channels_trace(
+        self,
+        module: ModuleArgsDict,
+        in_channels: int,
+        gradient_checkpoints: list[str] | None,
+        gpu_checkpoints: list[str] | None,
+        name: str | None = None,
+        in_is_channel: bool = True,
+        out_channels: int | None = None,
+        out_is_channel: bool = True,
+    ) -> tuple[int, bool, int | None, bool]:
+
         for k1, v1 in module.items():
             if isinstance(v1, ModuleArgsDict):
                 for t in module._modulesArgs[k1].out_branch:
@@ -680,25 +949,34 @@ class Network(ModuleArgsDict, ABC):
             if hasattr(v, "in_features"):
                 if v.in_features:
                     in_channels = v.in_features
-            key = name+"."+k if name else k
+            key = name + "." + k if name else k

             if gradient_checkpoints:
                 if key in gradient_checkpoints:
                     module._modulesArgs[k].isCheckpoint = True
-
+
             if gpu_checkpoints:
                 if key in gpu_checkpoints:
                     module._modulesArgs[k].isGPU_Checkpoint = True
-
+
             module._modulesArgs[k].in_channels = in_channels
             module._modulesArgs[k].in_is_channel = in_is_channel
-
+
             if isinstance(v, ModuleArgsDict):
-                in_channels, in_is_channel, out_channels, out_is_channel = self._compute_channels_trace(
-
+                in_channels, in_is_channel, out_channels, out_is_channel = self._compute_channels_trace(
+                    v,
+                    in_channels,
+                    gradient_checkpoints,
+                    gpu_checkpoints,
+                    key,
+                    in_is_channel,
+                    out_channels,
+                    out_is_channel,
+                )
+
             if v.__class__.__name__ == "ToChannels":
                 out_is_channel = True
-
+
             if v.__class__.__name__ == "ToFeatures":
                 out_is_channel = False

@@ -712,26 +990,32 @@ class Network(ModuleArgsDict, ABC):
             module._modulesArgs[k].out_channels = out_channels
             module._modulesArgs[k].out_is_channel = out_is_channel

-            in_channels = out_channels
+            in_channels = out_channels if out_channels is not None else in_channels
             in_is_channel = out_is_channel
-
+
         return in_channels, in_is_channel, out_channels, out_is_channel

     @_function_network()
-    def init(self, autocast
-        if self.
-            self.measure = Measure(key, self.
+    def init(self, autocast: bool, state: State, group_dest: list[str], key: str) -> None:
+        if self.outputs_criterions_loader:
+            self.measure = Measure(key, self.outputs_criterions_loader)
             self.measure.init(self, group_dest)
+        if self.patch is not None:
+            self.patch.init(f"{konfai_root()}.Model.{key}.Patch")
         if state != State.PREDICTION:
             self.scaler = torch.amp.GradScaler("cuda", enabled=autocast)
             if self.optimizerLoader:
-                self.optimizer = self.optimizerLoader.
+                self.optimizer = self.optimizerLoader.get_optimizer(
+                    key, self.parameters(state == State.TRANSFER_LEARNING)
+                )
                 self.optimizer.zero_grad()

-            if self.
-                for schedulers_classname, schedulers in self.
-                    self.schedulers[schedulers.getschedulers(key, schedulers_classname, self.optimizer)] =
-
+            if self.lr_schedulers_loader and self.optimizer:
+                for schedulers_classname, schedulers in self.lr_schedulers_loader.items():
+                    self.schedulers[schedulers.getschedulers(key, schedulers_classname, self.optimizer)] = (
+                        schedulers.nb_step
+                    )
+
     def initialized(self):
         pass

@@ -740,80 +1024,101 @@ class Network(ModuleArgsDict, ABC):
            self.patch.load(inputs[0].shape[2:])
            accumulators: dict[str, Accumulator] = {}

-
+            patch_iterator = self.patch.disassemble(*inputs)
            buffer = []
-            for i, patch_input in enumerate(
-                for
-                    yield "
+            for i, patch_input in enumerate(patch_iterator):
+                for name, output_layer in super().named_forward(*patch_input):
+                    yield f";accu;{name}", output_layer
                    buffer.append((name.split(".")[0], output_layer))
                    if len(buffer) == 2:
                        if buffer[0][0] != buffer[1][0]:
                            if self._modulesArgs[buffer[0][0]]._isEnd:
                                if buffer[0][0] not in accumulators:
-                                    accumulators[buffer[0][0]] = Accumulator(
-
+                                    accumulators[buffer[0][0]] = Accumulator(
+                                        self.patch.get_patch_slices(),
+                                        self.patch.patch_size,
+                                        self.patch.patch_combine,
+                                    )
+                                accumulators[buffer[0][0]].add_layer(i, buffer[0][1])
                        buffer.pop(0)
                if self._modulesArgs[buffer[0][0]]._isEnd:
                    if buffer[0][0] not in accumulators:
-                        accumulators[buffer[0][0]] = Accumulator(
-
+                        accumulators[buffer[0][0]] = Accumulator(
+                            self.patch.get_patch_slices(),
+                            self.patch.patch_size,
+                            self.patch.patch_combine,
+                        )
+                    accumulators[buffer[0][0]].add_layer(i, buffer[0][1])
            for name, accumulator in accumulators.items():
                yield name, accumulator.assemble()
        else:
-            for
+            for name, output_layer in super().named_forward(*inputs):
                yield name, output_layer

-
-
+    def get_layers(
+        self, inputs: list[torch.Tensor], layers_name: list[str]
+    ) -> Iterator[tuple[str, torch.Tensor, PatchIndexed | None]]:
        layers_name = layers_name.copy()
-        output_layer_accumulator
-        output_layer_patch_indexed
+        output_layer_accumulator: dict[str, Accumulator] = {}
+        output_layer_patch_indexed: dict[str, PatchIndexed] = {}
        it = 0
        debug = "KONFAI_DEBUG" in os.environ
-        for
-            name =
+        for name_tmp, output_layer in self.named_forward(*inputs):
+            name = name_tmp.replace(";accu;", "")
            if debug:
                if "KONFAI_DEBUG_LAST_LAYER" in os.environ:
-                    os.environ["KONFAI_DEBUG_LAST_LAYER"] = "{
+                    os.environ["KONFAI_DEBUG_LAST_LAYER"] = (
+                        f"{os.environ['KONFAI_DEBUG_LAST_LAYER']}|{name}:"
+                        f"{get_gpu_memory(output_layer.device)}:"
+                        f"{str(output_layer.device).replace('cuda:', '')}"
+                    )
                else:
-                    os.environ["KONFAI_DEBUG_LAST_LAYER"] =
+                    os.environ["KONFAI_DEBUG_LAST_LAYER"] = (
+                        f"{name}:{get_gpu_memory(output_layer.device)}:{str(output_layer.device).replace('cuda:', '')}"
+                    )
            it += 1
-            if name in layers_name or
-                if ";accu;" in
+            if name in layers_name or name_tmp in layers_name:
+                if ";accu;" in name_tmp:
                    if name not in output_layer_patch_indexed:
-
+                        network_name = (
+                            name_tmp.split(".;accu;")[-2].split(".")[-1]
+                            if ".;accu;" in name_tmp
+                            else name_tmp.split(";accu;")[-2].split(".")[-1]
+                        )
                        module = self
                        network = None
-                        if
+                        if network_name == "":
                            network = module
                        else:
                            for n in name.split("."):
                                module = module[n]
-                                if isinstance(module, Network) and n ==
+                                if isinstance(module, Network) and n == network_name:
                                    network = module
                                    break
-
+
                        if network and network.patch:
-                            output_layer_patch_indexed[name] =
+                            output_layer_patch_indexed[name] = PatchIndexed(network.patch, 0)

                    if name not in output_layer_accumulator:
-                        output_layer_accumulator[name] = Accumulator(
-
-
-
+                        output_layer_accumulator[name] = Accumulator(
+                            output_layer_patch_indexed[name].patch.get_patch_slices(0),
+                            output_layer_patch_indexed[name].patch.patch_size,
+                            output_layer_patch_indexed[name].patch.patch_combine,
+                        )
+
+                    if name_tmp in layers_name:
+                        output_layer_accumulator[name].add_layer(output_layer_patch_indexed[name].index, output_layer)
                        output_layer_patch_indexed[name].index += 1
-                        if output_layer_accumulator[name].
+                        if output_layer_accumulator[name].is_full():
                            output_layer = output_layer_accumulator[name].assemble()
                            output_layer_accumulator.pop(name)
                            output_layer_patch_indexed.pop(name)
-                            layers_name.remove(
-                            yield
+                            layers_name.remove(name_tmp)
+                            yield name_tmp, output_layer, None

                if name in layers_name:
-                    if ";accu;" in
+                    if ";accu;" in name_tmp:
                        yield name, output_layer, output_layer_patch_indexed[name]
                        output_layer_patch_indexed[name].index += 1
-                        if output_layer_patch_indexed[name].
+                        if output_layer_patch_indexed[name].is_full():
                            output_layer_patch_indexed.pop(name)
                            layers_name.remove(name)
                    else:
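Both branches of the accumulator logic above follow the same split/process/stitch pattern: `ModelPatch.disassemble` yields per-patch inputs, an `Accumulator` collects each end-layer's per-patch outputs via `add_layer`, and `assemble` writes them back into a full-size tensor. A self-contained toy version of that pattern (local helper names, not the konfai.data.patching API):

    import torch

    def disassemble(x: torch.Tensor, size: int):
        # Cut the last dimension into fixed-size windows, remembering the slices.
        slices = [slice(i, i + size) for i in range(0, x.shape[-1], size)]
        return slices, [x[..., s] for s in slices]

    def assemble(slices, patches, full_shape):
        # Stitch processed patches back into their original positions.
        out = torch.zeros(full_shape)
        for s, p in zip(slices, patches):
            out[..., s] = p
        return out

    x = torch.arange(12.0).reshape(1, 1, 12)
    slices, patches = disassemble(x, 4)
    y = assemble(slices, [p * 2 for p in patches], x.shape)  # y == 2 * x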
@@ -823,62 +1128,85 @@ class Network(ModuleArgsDict, ABC):
            if not len(layers_name):
                break

-    def
-        metric_tmp = {
+    def init_outputs_group(self):
+        metric_tmp = {
+            network.measure: network.measure.outputs_criterions.keys()
+            for network in self.get_networks().values()
+            if network.measure
+        }
        for k, v in metric_tmp.items():
            for a in v:
                outputs_group = OutputsGroup(k)
                outputs_group.append(a)
-                for
-                    if ":" in
-                        outputs_group.append(
+                for targets_group in k.outputs_criterions[a].keys():
+                    if ":" in targets_group:
+                        outputs_group.append(targets_group.replace(":", "."))

                self.outputsGroup.append(outputs_group)

-    def forward(
+    def forward(
+        self,
+        data_dict: dict[tuple[str, bool], torch.Tensor],
+        output_layers: list[str] = [],
+    ) -> list[tuple[str, torch.Tensor]]:
        if not len(self.outputsGroup) and not len(output_layers):
            return []

-        self.
+        self.reset_loss()
        results = []
        measure_output_layers = set()
-        for
-
+        for _outputs_group in self.outputsGroup:
+            for name in _outputs_group:
                measure_output_layers.add(name)
-        for name, layer, patch_indexed in self.get_layers(
-
+        for name, layer, patch_indexed in self.get_layers(
+            [v for k, v in data_dict.items() if k[1]],
+            list(set(list(measure_output_layers) + output_layers)),
+        ):
+
            outputs_group = [outputs_group for outputs_group in self.outputsGroup if name in outputs_group]
            if len(outputs_group) > 0:
                if patch_indexed is None:
-                    targets = {k[0]
+                    targets = {k[0]: v for k, v in data_dict.items()}
                    nb = 1
                else:
-                    targets = {
-
+                    targets = {
+                        k[0]: patch_indexed.patch.get_data(v, patch_indexed.index, 0, False)
+                        for k, v in data_dict.items()
+                    }
+                    nb = patch_indexed.patch.get_size(0)

                for output_group in outputs_group:
-                    output_group.
-                    if output_group.
-                        targets.update(
-
+                    output_group.add_layer(name, layer)
+                    if output_group.is_done():
+                        targets.update(
+                            {k.replace(".", ":"): v for k, v in output_group.layers.items() if k != output_group[0]}
+                        )
+                        output_group.measure.update(
+                            output_group[0],
+                            output_group.layers[output_group[0]],
+                            targets,
+                            self._it,
+                            nb,
+                            self.training,
+                        )
                        output_group.clear()
            if name in output_layers:
                results.append((name, layer))
        return results

    @_function_network()
-    def
+    def reset_loss(self):
        if self.measure:
-            self.measure.
+            self.measure.reset_loss()

    @_function_network()
    def backward(self, model: Self):
        if self.measure:
            if self.scaler and self.optimizer:
-                model._requires_grad(list(self.measure.
-                for loss in self.measure.
+                model._requires_grad(list(self.measure.outputs_criterions.keys()))
+                for loss in self.measure.get_loss():
                    self.scaler.scale(loss / self.nb_batch_per_step).backward()
-
+
                if self._it % self.nb_batch_per_step == 0:
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
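The `backward` path above is plain gradient accumulation under AMP: every micro-batch loss is divided by `nb_batch_per_step` before `backward()`, and the scaler only steps the optimizer once per `nb_batch_per_step` iterations. The standard standalone form of that pattern, for reference:

    import torch

    model = torch.nn.Linear(8, 1)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
    scaler = torch.cuda.amp.GradScaler(enabled=torch.cuda.is_available())
    nb_batch_per_step = 4

    for it in range(1, 9):
        loss = model(torch.randn(2, 8)).mean()
        # Scale down so nb_batch_per_step micro-batches average into one step.
        scaler.scale(loss / nb_batch_per_step).backward()
        if it % nb_batch_per_step == 0:
            scaler.step(optimizer)
            scaler.update()
            optimizer.zero_grad()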
@@ -887,93 +1215,133 @@ class Network(ModuleArgsDict, ABC):

    @_function_network()
    def update_lr(self):
-        self._nb_lr_update+=1
+        self._nb_lr_update += 1
        step = 0
-
-        for
-            if value is None or (self._nb_lr_update >= step and self._nb_lr_update < step+value):
+        _scheduler = None
+        for _scheduler, value in self.schedulers.items():
+            if value is None or (self._nb_lr_update >= step and self._nb_lr_update < step + value):
                break
            step += value
-        if
-            if
+        if _scheduler:
+            if _scheduler.__class__.__name__ == "ReduceLROnPlateau":
                if self.measure:
-
+                    _scheduler.step(sum(self.measure.get_last_values(0).values()))
            else:
-
+                _scheduler.step()

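The class-name check in `update_lr` exists because `ReduceLROnPlateau` is the one stock scheduler whose `step()` expects a monitored value rather than a step count; here the network feeds it the sum of the last measured values. A sketch of the same dispatch with plain torch schedulers:

    import torch

    param = torch.nn.Parameter(torch.zeros(1))
    optimizer = torch.optim.SGD([param], lr=0.1)
    schedulers = [
        torch.optim.lr_scheduler.ReduceLROnPlateau(optimizer, mode="min", patience=2),
        torch.optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=10),
    ]

    validation_loss = 0.42  # placeholder for the measured loss
    for scheduler in schedulers:
        if scheduler.__class__.__name__ == "ReduceLROnPlateau":
            scheduler.step(validation_loss)  # plateau scheduling needs a metric
        else:
            scheduler.step()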
    @_function_network()
-    def
+    def get_networks(self) -> Self:
        return self

-
+    @staticmethod
+    def to(module: ModuleArgsDict, device: int):
        if "device" not in os.environ:
            os.environ["device"] = str(device)
        for k, v in module.items():
            if module._modulesArgs[k].gpu == "cpu":
                if module._modulesArgs[k].isGPU_Checkpoint:
-                    os.environ["device"] = str(int(os.environ["device"])+1)
-                    module._modulesArgs[k].gpu = str(
+                    os.environ["device"] = str(int(os.environ["device"]) + 1)
+                    module._modulesArgs[k].gpu = str(get_device(int(os.environ["device"])))
            if isinstance(v, ModuleArgsDict):
                v = Network.to(v, int(os.environ["device"]))
            else:
-                v = v.to(
+                v = v.to(get_device(int(os.environ["device"])))
        if isinstance(module, Network):
            if module.optimizer is not None:
                for state in module.optimizer.state.values():
                    for k, v in state.items():
                        if isinstance(v, torch.Tensor):
-                            state[k] = v.to(
+                            state[k] = v.to(get_device(int(os.environ["device"])))
        return module
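The tail of `to` covers a detail that is easy to miss when moving a model between devices: optimizer state (momentum buffers, Adam moments) lives outside the parameters and must be moved explicitly. The same pattern in isolation:

    import torch

    model = torch.nn.Linear(4, 4)
    optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
    model(torch.randn(1, 4)).sum().backward()
    optimizer.step()  # populates momentum buffers in optimizer.state

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model.to(device)
    for state in optimizer.state.values():
        for k, v in state.items():
            if isinstance(v, torch.Tensor):
                state[k] = v.to(device)  # keep buffers with their parameters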
-
-    def
+
+    def get_name(self) -> str:
        return self.name
-
-    def
+
+    def set_name(self, name: str) -> Self:
        self.name = name
        return self
-
-    def
+
+    def set_state(self, state: NetState):
        for module in self.modules():
            if isinstance(module, ModuleArgsDict):
                module._training = state

+
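`set_state` propagates the custom NetState flag the same way `nn.Module.train()` propagates `training`: `self.modules()` yields the module itself plus every descendant, so one call stamps the whole tree. A standalone sketch of the walk, using a local stand-in class rather than ModuleArgsDict:

    import torch
    from enum import Enum

    class NetState(Enum):
        TRAIN = 0
        PREDICTION = 1

    class Flagged(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self._training_state = NetState.TRAIN
            self.inner = torch.nn.Linear(2, 2)

        def set_state(self, state: NetState):
            # modules() iterates self and all submodules, like the loop above.
            for module in self.modules():
                if isinstance(module, Flagged):
                    module._training_state = state

    net = Flagged()
    net.set_state(NetState.PREDICTION)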
class MinimalModel(Network):
-
-    def __init__(
-
-
-
-
-
+
+    def __init__(
+        self,
+        model: Network,
+        optimizer: OptimizerLoader = OptimizerLoader(),
+        schedulers: dict[str, LRSchedulersLoader] = {"default:ReduceLROnPlateau": LRSchedulersLoader(0)},
+        outputs_criterions: dict[str, TargetCriterionsLoader] = {"default": TargetCriterionsLoader()},
+        patch: ModelPatch | None = None,
+        dim: int = 3,
+        nb_batch_per_step=1,
+        init_type="normal",
+        init_gain=0.02,
+    ):
+        super().__init__(
+            1,
+            optimizer,
+            schedulers,
+            outputs_criterions,
+            patch,
+            nb_batch_per_step,
+            init_type,
+            init_gain,
+            dim,
+        )
        self.add_module("Model", model)

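MinimalModel lets a bare architecture participate in the training machinery by registering it as a single child named "Model" while supplying default loaders for the optimizer, schedulers and criterions. Stripped of the KonfAI types, the wrapping itself is just child registration; `Wrapper` below is an illustrative stand-in, not the package class:

    import torch

    class Wrapper(torch.nn.Module):
        # Stand-in for the MinimalModel idea: one child registered as "Model".
        def __init__(self, model: torch.nn.Module):
            super().__init__()
            self.add_module("Model", model)

        def forward(self, x):
            return self.Model(x)

    wrapped = Wrapper(torch.nn.Linear(3, 2))
    out = wrapped(torch.randn(1, 3))  # delegates to the wrapped model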
-
+
+class ModelLoader:

    @config("Model")
-    def __init__(self, classpath
-        self.module, self.name =
-
-    def
-
-
-
-
+    def __init__(self, classpath: str = "default:segmentation.UNet.UNet") -> None:
+        self.module, self.name = get_module(classpath, "konfai.models")
+
+    def get_model(
+        self,
+        train: bool = True,
+        konfai_args: str | None = None,
+        konfai_without=[
+            "optimizer",
+            "schedulers",
+            "nb_batch_per_step",
+            "init_type",
+            "init_gain",
+        ],
+    ) -> Network:
+        if not konfai_args:
+            konfai_args = f"{konfai_root()}.Model"
+        konfai_args += "." + self.name
+        model = config(konfai_args)(getattr(importlib.import_module(self.module), self.name))(
+            config=None, konfai_without=konfai_without if not train else []
+        )
        if not isinstance(model, Network):
-            model = config(
-
+            model = config(konfai_args)(partial(MinimalModel, model))(
+                config=None, konfai_without=konfai_without + ["model"] if not train else []
+            )
+        model.set_name(self.name)
        return model
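`get_model` resolves its `classpath` string to a class with `importlib` before handing it to the `config` decorator; `get_module` splits the path, presumably anchoring relative paths under `konfai.models`. The resolution pattern on its own (`resolve` is a local illustration, not the konfai helper):

    import importlib

    def resolve(classpath: str, package: str):
        # "segmentation.UNet.UNet" under "konfai.models" resolves to
        # module "konfai.models.segmentation.UNet", attribute "UNet".
        module_path, _, attr = classpath.rpartition(".")
        full = f"{package}.{module_path}" if module_path else package
        return getattr(importlib.import_module(full), attr)

    # Same mechanics against the standard library, so the sketch runs anywhere:
    OrderedDict = resolve("OrderedDict", "collections")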
-
-
+
+
+class CPUModel:

    def __init__(self, model: Network) -> None:
        self.module = model
-        self.device = torch.device(
+        self.device = torch.device("cpu")

    def train(self):
        self.module.train()
-
-    def eval(self):
+
+    def eval(self):  # noqa: A003
        self.module.eval()
-
-    def __call__(
-
+
+    def __call__(
+        self,
+        data_dict: dict[tuple[str, bool], torch.Tensor],
+        output_layers: list[str] = [],
+    ) -> list[tuple[str, torch.Tensor]]:
+        return self.module(data_dict, output_layers)