konfai-1.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of konfai might be problematic.
- konfai/__init__.py +16 -0
- konfai/data/HDF5.py +326 -0
- konfai/data/__init__.py +0 -0
- konfai/data/augmentation.py +597 -0
- konfai/data/dataset.py +470 -0
- konfai/data/transform.py +536 -0
- konfai/evaluator.py +146 -0
- konfai/main.py +43 -0
- konfai/metric/__init__.py +0 -0
- konfai/metric/measure.py +488 -0
- konfai/metric/schedulers.py +49 -0
- konfai/models/classification/convNeXt.py +175 -0
- konfai/models/classification/resnet.py +116 -0
- konfai/models/generation/cStyleGan.py +137 -0
- konfai/models/generation/ddpm.py +218 -0
- konfai/models/generation/diffusionGan.py +557 -0
- konfai/models/generation/gan.py +134 -0
- konfai/models/generation/vae.py +72 -0
- konfai/models/registration/registration.py +136 -0
- konfai/models/representation/representation.py +57 -0
- konfai/models/segmentation/NestedUNet.py +53 -0
- konfai/models/segmentation/UNet.py +58 -0
- konfai/network/__init__.py +0 -0
- konfai/network/blocks.py +348 -0
- konfai/network/network.py +950 -0
- konfai/predictor.py +366 -0
- konfai/trainer.py +330 -0
- konfai/utils/ITK.py +269 -0
- konfai/utils/Registration.py +199 -0
- konfai/utils/__init__.py +0 -0
- konfai/utils/config.py +218 -0
- konfai/utils/dataset.py +764 -0
- konfai/utils/utils.py +493 -0
- konfai-1.0.0.dist-info/METADATA +68 -0
- konfai-1.0.0.dist-info/RECORD +39 -0
- konfai-1.0.0.dist-info/WHEEL +5 -0
- konfai-1.0.0.dist-info/entry_points.txt +3 -0
- konfai-1.0.0.dist-info/licenses/LICENSE +201 -0
- konfai-1.0.0.dist-info/top_level.txt +1 -0
konfai/network/network.py
@@ -0,0 +1,950 @@
from functools import partial
import importlib
import inspect
import os
from typing import Iterable, Iterator, Callable
from typing_extensions import Self
import torch
from abc import ABC
import numpy as np
from torch._jit_internal import _copy_to_script_wrapper
from collections import OrderedDict
from torch.utils.checkpoint import checkpoint
from typing import Union
from enum import Enum

from KonfAI.konfai import DEEP_LEARNING_API_ROOT
from KonfAI.konfai.metric.schedulers import Scheduler
from KonfAI.konfai.utils.config import config
from KonfAI.konfai.utils.utils import State, _getModule, getDevice, getGPUMemory
from KonfAI.konfai.data.HDF5 import Accumulator, ModelPatch

class NetState(Enum):
    TRAIN = 0
    PREDICTION = 1

class Patch_Indexed():

    def __init__(self, patch: ModelPatch, index: int) -> None:
        self.patch = patch
        self.index = index

    def isFull(self) -> bool:
        return len(self.patch.getPatch_slices(0)) == self.index

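Patch_Indexed pairs a ModelPatch with a cursor over its patch slices; isFull reports that every patch of sample 0 has been produced. As a rough standalone illustration of the sliding-window processing this supports (a sketch in plain PyTorch; the shapes and the unfold/concatenate round-trip are illustrative, not KonfAI's Accumulator API):

import torch

x = torch.randn(1, 1, 8, 8)                      # one-channel 8x8 image
patches = x.unfold(2, 4, 4).unfold(3, 4, 4)      # 2x2 grid of non-overlapping 4x4 patches
outputs = []
for i in range(2):
    for j in range(2):
        patch = patches[:, :, i, j]              # process one patch at a time
        outputs.append(patch * 2)                # stand-in for a network forward pass
row0 = torch.cat(outputs[0:2], dim=3)            # reassemble patches to the input shape
row1 = torch.cat(outputs[2:4], dim=3)
assembled = torch.cat([row0, row1], dim=2)
assert assembled.shape == x.shape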
class OptimizerLoader():

    @config("Optimizer")
    def __init__(self, name: str = "AdamW") -> None:
        self.name = name

    def getOptimizer(self, key: str, parameter: Iterator[torch.nn.parameter.Parameter]) -> torch.optim.Optimizer:
        # Resolve the optimizer class by name from torch.optim (e.g. torch.optim.AdamW),
        # then let the config decorator fill its keyword arguments from the configuration.
        return config("{}.Model.{}.Optimizer".format(DEEP_LEARNING_API_ROOT(), key))(getattr(importlib.import_module('torch.optim'), self.name))(parameter, config = None)

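Stripped of the config plumbing, OptimizerLoader's name-based lookup amounts to the following (a minimal runnable sketch; the model and learning rate are placeholders):

import importlib
import torch

model = torch.nn.Linear(4, 2)
name = "AdamW"                                            # any class name from torch.optim
optimizer_cls = getattr(importlib.import_module("torch.optim"), name)
optimizer = optimizer_cls(model.parameters(), lr=1e-3)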
class SchedulerStep():

    @config(None)
    def __init__(self, nb_step : int = 0) -> None:
        self.nb_step = nb_step

class LRSchedulersLoader():

    @config("Schedulers")
    def __init__(self, params: dict[str, SchedulerStep] = {"default:ReduceLROnPlateau" : SchedulerStep(0)}) -> None:
        self.params = params

    def getShedulers(self, key: str, optimizer: torch.optim.Optimizer) -> dict[torch.optim.lr_scheduler._LRScheduler, int]:
        shedulers : dict[torch.optim.lr_scheduler._LRScheduler, int] = {}
        for name, step in self.params.items():
            if name:
                shedulers[config("Trainer.Model.{}.Schedulers.{}".format(key, name))(getattr(importlib.import_module('torch.optim.lr_scheduler'), name))(optimizer, config = None)] = step.nb_step
        return shedulers

class SchedulersLoader():

    @config("Schedulers")
    def __init__(self, params: dict[str, SchedulerStep] = {"default:Constant" : SchedulerStep(0)}) -> None:
        self.params = params

    def getShedulers(self, key: str) -> dict[Scheduler, int]:
        shedulers : dict[Scheduler, int] = {}
        for name, step in self.params.items():
            if name:
                shedulers[getattr(importlib.import_module("KonfAI.konfai.metric.schedulers"), name)(config = None, DL_args = key)] = step.nb_step
        return shedulers

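Both loaders map each scheduler to the number of steps it stays active; Measure.update_scheduler and Network.update_lr (below) walk these as consecutive step windows. A standalone sketch of the windowing with plain torch schedulers (the step counts are arbitrary):

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
schedulers = {                                            # scheduler -> number of active steps
    torch.optim.lr_scheduler.LinearLR(optimizer): 100,
    torch.optim.lr_scheduler.ExponentialLR(optimizer, gamma=0.99): 400,
}

def active_scheduler(schedulers, it):
    start = 0
    for scheduler, length in schedulers.items():
        if it < start + length:
            return scheduler                              # it falls in this window
        start += length
    return scheduler                                      # past the last window: keep the final one

print(active_scheduler(schedulers, 250).__class__.__name__)  # ExponentialLR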
class CriterionsAttr():

    @config()
    def __init__(self, l: SchedulersLoader = SchedulersLoader(), isLoss: bool = True, group: int = 0, stepStart: int = 0, stepStop: Union[int, None] = None, accumulation: bool = False) -> None:
        self.l = l
        self.isTorchCriterion = True
        self.isLoss = isLoss
        self.stepStart = stepStart
        self.stepStop = stepStop
        self.group = group
        self.accumulation = accumulation
        self.sheduler = None

class CriterionsLoader():

    @config()
    def __init__(self, criterionsLoader: dict[str, CriterionsAttr] = {"default:torch_nn_CrossEntropyLoss:Dice:NCC": CriterionsAttr()}) -> None:
        self.criterionsLoader = criterionsLoader

    def getCriterions(self, model_classname : str, output_group : str, target_group : str) -> dict[torch.nn.Module, CriterionsAttr]:
        criterions = {}
        for module_classpath, criterionsAttr in self.criterionsLoader.items():
            module, name = _getModule(module_classpath, "measure")
            criterionsAttr.isTorchCriterion = module.startswith("torch")
            criterionsAttr.sheduler = criterionsAttr.l.getShedulers("{}.Model.{}.outputsCriterions.{}.targetsCriterions.{}.criterionsLoader.{}".format(DEEP_LEARNING_API_ROOT(), model_classname, output_group, target_group, module_classpath))
            criterions[config("{}.Model.{}.outputsCriterions.{}.targetsCriterions.{}.criterionsLoader.{}".format(DEEP_LEARNING_API_ROOT(), model_classname, output_group, target_group, module_classpath))(getattr(importlib.import_module(module), name))(config = None)] = criterionsAttr
        return criterions

class TargetCriterionsLoader():

    @config()
    def __init__(self, targetsCriterions : dict[str, CriterionsLoader] = {"default" : CriterionsLoader()}) -> None:
        self.targetsCriterions = targetsCriterions

    def getTargetsCriterions(self, output_group : str, model_classname : str) -> dict[str, dict[torch.nn.Module, CriterionsAttr]]:
        targetsCriterions = {}
        for target_group, criterionsLoader in self.targetsCriterions.items():
            targetsCriterions[target_group] = criterionsLoader.getCriterions(model_classname, output_group, target_group)
        return targetsCriterions

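The loaders above build a two-level mapping, output group -> target group -> {criterion: attributes}, that Measure consumes. How such a nested table is evaluated, reduced to plain torch (group names and weights are invented for illustration):

import torch

criterions = {
    "seg": {                                      # output group
        "mask": {                                 # target group
            torch.nn.CrossEntropyLoss(): 1.0,     # criterion -> weight
            torch.nn.L1Loss(): 0.1,
        }
    }
}

output = torch.randn(2, 3, 8, 8, requires_grad=True)
target = torch.randn(2, 3, 8, 8).softmax(dim=1)   # class probabilities per pixel
loss = sum(w * crit(output, target) for crit, w in criterions["seg"]["mask"].items())
loss.backward()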
class Measure():

    class Loss():

        def __init__(self, name: str, output_group: str, target_group: str, group: int, isLoss: bool, accumulation: bool) -> None:
            self.name = name
            self.isLoss = isLoss
            self.accumulation = accumulation
            self.output_group = output_group
            self.target_group = target_group
            self.group = group

            self._loss: list[torch.Tensor] = []
            self._weight: list[float] = []
            self._values: list[float] = []

        def resetLoss(self) -> None:
            self._loss.clear()

        def add(self, weight: float, value: torch.Tensor) -> None:
            self._loss.append(value if self.isLoss else value.detach())
            self._values.append(value.item())
            self._weight.append(weight)

        def getLastLoss(self) -> torch.Tensor:
            return self._loss[-1]*self._weight[-1] if len(self._loss) else torch.zeros((1), requires_grad = True)

        def getLoss(self) -> torch.Tensor:
            return torch.stack([w*l for w, l in zip(self._weight, self._loss)], dim=0).mean(dim=0) if len(self._loss) else torch.zeros((1), requires_grad = True)

        def __len__(self) -> int:
            return len(self._loss)

    def __init__(self, model_classname : str, outputsCriterions: dict[str, TargetCriterionsLoader]) -> None:
        super().__init__()
        self.outputsCriterions = {}
        for output_group, targetCriterionsLoader in outputsCriterions.items():
            self.outputsCriterions[output_group.replace(":", ".")] = targetCriterionsLoader.getTargetsCriterions(output_group, model_classname)
        self._loss : dict[int, dict[str, Measure.Loss]] = {}

    def init(self, model : torch.nn.Module) -> None:
        outputs_group_rename = {}
        for output_group in self.outputsCriterions.keys():
            for target_group in self.outputsCriterions[output_group]:
                for criterion in self.outputsCriterions[output_group][target_group]:
                    if not self.outputsCriterions[output_group][target_group][criterion].isTorchCriterion:
                        outputs_group_rename[output_group] = criterion.init(model, output_group, target_group)

        outputsCriterions_bak = self.outputsCriterions.copy()
        for old, new in outputs_group_rename.items():
            self.outputsCriterions.pop(old)
            self.outputsCriterions[new] = outputsCriterions_bak[old]
        for output_group in self.outputsCriterions:
            for target_group in self.outputsCriterions[output_group]:
                for criterion, criterionsAttr in self.outputsCriterions[output_group][target_group].items():
                    if criterionsAttr.group not in self._loss:
                        self._loss[criterionsAttr.group] = {}
                    self._loss[criterionsAttr.group]["{}:{}:{}".format(output_group, target_group, criterion.__class__.__name__)] = Measure.Loss(criterion.__class__.__name__, output_group, target_group, criterionsAttr.group, criterionsAttr.isLoss, criterionsAttr.accumulation)

    def update(self, output_group: str, output : torch.Tensor, data_dict: dict[str, torch.Tensor], it: int, nb_patch: int, training: bool) -> None:
        for target_group in self.outputsCriterions[output_group]:
            target = [data_dict[group].to(output[0].device).detach() for group in target_group.split("/") if group in data_dict]

            for criterion, criterionsAttr in self.outputsCriterions[output_group][target_group].items():
                if it >= criterionsAttr.stepStart and (criterionsAttr.stepStop is None or it <= criterionsAttr.stepStop):
                    scheduler = self.update_scheduler(criterionsAttr.sheduler, it)
                    self._loss[criterionsAttr.group]["{}:{}:{}".format(output_group, target_group, criterion.__class__.__name__)].add(scheduler.get_value(), criterion(output, *target))
                    if training and len(np.unique([len(l) for l in self._loss[criterionsAttr.group].values() if l.accumulation and l.isLoss])) == 1:
                        if criterionsAttr.isLoss:
                            loss = torch.zeros((1), requires_grad = True)
                            for v in [l for l in self._loss[criterionsAttr.group].values() if l.accumulation and l.isLoss]:
                                l = v.getLastLoss()
                                loss = loss.to(l.device)+l
                            loss = loss/nb_patch
                            loss.backward()

    def getLoss(self) -> list[torch.Tensor]:
        loss: dict[int, torch.Tensor] = {}
        for group in self._loss.keys():
            loss[group] = torch.zeros((1), requires_grad = True)
            for v in self._loss[group].values():
                if v.isLoss and not v.accumulation:
                    l = v.getLoss()
                    loss[v.group] = loss[v.group].to(l.device)+l
        return list(loss.values())

    def resetLoss(self) -> None:
        for group in self._loss.keys():
            for v in self._loss[group].values():
                v.resetLoss()

    def getLastValues(self, n: int = 1) -> dict[str, float]:
        result = {}
        for group in self._loss.keys():
            result.update({name : np.nanmean(value._values[-n:] if n > 0 else value._values) for name, value in self._loss[group].items() if n < 0 or len(value._values) >= n})
        return result

    def getLastWeights(self, n: int = 1) -> dict[str, float]:
        result = {}
        for group in self._loss.keys():
            result.update({name : np.nanmean(value._weight[-n:] if n > 0 else value._weight) for name, value in self._loss[group].items() if n < 0 or len(value._values) >= n})
        return result

    def format(self, isLoss: bool, n: int) -> dict[str, tuple[float, float]]:
        result = dict()
        for group in self._loss.keys():
            for name, loss in self._loss[group].items():
                if loss.isLoss == isLoss and len(loss._values) >= n:
                    result[name] = (np.nanmean(loss._weight[-n:]), np.nanmean(loss._values[-n:]))
        return result

    def update_scheduler(self, schedulers: dict[Scheduler, int], it: int) -> Scheduler:
        step = 0
        scheduler = None
        for scheduler, value in schedulers.items():
            if value is None or (it >= step and it < step+value):
                break
            step += value
        if scheduler:
            scheduler.step(it-step)
        return scheduler

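Measure.Loss keeps parallel lists of weights and loss tensors and reduces them with a weighted stack-and-mean; that reduction in isolation (a minimal sketch with made-up weights):

import torch

losses = [torch.tensor(0.8, requires_grad=True),
          torch.tensor(0.2, requires_grad=True)]
weights = [1.0, 0.5]
total = torch.stack([w * l for w, l in zip(weights, losses)], dim=0).mean(dim=0)
total.backward()
print(total.item())                               # (1.0*0.8 + 0.5*0.2) / 2 = 0.45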
class ModuleArgsDict(torch.nn.Module, ABC):

    class ModuleArgs:

        def __init__(self, in_branch: list[str], out_branch: list[str], pretrained : bool, alias : list[str], requires_grad: Union[bool, None], training: Union[None, bool]) -> None:
            super().__init__()
            self.alias = alias
            self.pretrained = pretrained
            self.in_branch = in_branch
            self.out_branch = out_branch
            self.in_channels = None
            self.in_is_channel = True
            self.out_channels = None
            self.out_is_channel = True
            self.requires_grad = requires_grad
            self.isCheckpoint = False
            self.isGPU_Checkpoint = False
            self.gpu = "cpu"
            self.training = training
            self._isEnd = False

    def __init__(self) -> None:
        super().__init__()
        self._modulesArgs : dict[str, ModuleArgsDict.ModuleArgs] = dict()
        self._training = NetState.TRAIN

    def _addindent(self, s_: str, numSpaces : int):
        s = s_.split('\n')
        if len(s) == 1:
            return s_
        first = s.pop(0)
        s = [(numSpaces * ' ') + line for line in s]
        s = '\n'.join(s)
        s = first + '\n' + s
        return s

    def __repr__(self):
        extra_lines = []

        extra_repr = self.extra_repr()
        if extra_repr:
            extra_lines = extra_repr.split('\n')

        child_lines = []
        is_simple_branch = lambda x : len(x) > 1 or x[0] != "0"
        for key, module in self._modules.items():
            mod_str = repr(module)

            mod_str = self._addindent(mod_str, 2)
            desc = ""
            if is_simple_branch(self._modulesArgs[key].in_branch) or is_simple_branch(self._modulesArgs[key].out_branch):
                desc += ", {}->{}".format(self._modulesArgs[key].in_branch, self._modulesArgs[key].out_branch)
            if not self._modulesArgs[key].pretrained:
                desc += ", pretrained=False"
            if self._modulesArgs[key].alias:
                desc += ", alias={}".format(self._modulesArgs[key].alias)
            desc += ", in_channels={}".format(self._modulesArgs[key].in_channels)
            desc += ", in_is_channel={}".format(self._modulesArgs[key].in_is_channel)
            desc += ", out_channels={}".format(self._modulesArgs[key].out_channels)
            desc += ", out_is_channel={}".format(self._modulesArgs[key].out_is_channel)
            desc += ", is_end={}".format(self._modulesArgs[key]._isEnd)
            desc += ", isInCheckpoint={}".format(self._modulesArgs[key].isCheckpoint)
            desc += ", isInGPU_Checkpoint={}".format(self._modulesArgs[key].isGPU_Checkpoint)
            desc += ", requires_grad={}".format(self._modulesArgs[key].requires_grad)
            desc += ", device={}".format(self._modulesArgs[key].gpu)

            child_lines.append("({}{}) {}".format(key, desc, mod_str))

        lines = extra_lines + child_lines

        desc = ""
        if lines:
            if len(extra_lines) == 1 and not child_lines:
                desc += extra_lines[0]
            else:
                desc += '\n ' + '\n '.join(lines) + '\n'

        return "{}({})".format(self._get_name(), desc)

    def __getitem__(self, key: str) -> torch.nn.Module:
        module = self._modules[key]
        assert module, "Error {} is None".format(key)
        return module

    @_copy_to_script_wrapper
    def keys(self) -> Iterable[str]:
        return self._modules.keys()

    @_copy_to_script_wrapper
    def items(self) -> Iterable[tuple[str, Union[torch.nn.Module, None]]]:
        return self._modules.items()

    @_copy_to_script_wrapper
    def values(self) -> Iterable[Union[torch.nn.Module, None]]:
        return self._modules.values()

    def add_module(self, name: str, module : torch.nn.Module, in_branch: list[Union[int, str]] = [0], out_branch: list[Union[int, str]] = [0], pretrained : bool = True, alias : list[str] = [], requires_grad: Union[bool, None] = None, training: Union[None, bool] = None) -> None:
        super().add_module(name, module)
        self._modulesArgs[name] = ModuleArgsDict.ModuleArgs([str(value) for value in in_branch], [str(value) for value in out_branch], pretrained, alias, requires_grad, training)

    def getMap(self):
        results : dict[str, str] = {}
        for name, moduleArgs in self._modulesArgs.items():
            module = self[name]
            if isinstance(module, ModuleArgsDict):
                if len(moduleArgs.alias):
                    count = {k : 0 for k in set(module.getMap().values())}
                    if len(count):
                        for k, v in module.getMap().items():
                            alias_name = moduleArgs.alias[count[v]]
                            if k == "":
                                results.update({alias_name : name+"."+v})
                            else:
                                results.update({alias_name+"."+k : name+"."+v})
                            count[v] += 1
                    else:
                        for alias in moduleArgs.alias:
                            results.update({alias : name})
                else:
                    results.update({k : name+"."+v for k, v in module.getMap().items()})
            else:
                for alias in moduleArgs.alias:
                    results[alias] = name
        return results

    @staticmethod
    def init_func(module: torch.nn.Module, init_type: str, init_gain: float):
        if not isinstance(module, Network):
            if isinstance(module, ModuleArgsDict):
                module.init(init_type, init_gain)
            elif isinstance(module, torch.nn.modules.conv._ConvNd) or isinstance(module, torch.nn.Linear):
                if init_type == 'normal':
                    torch.nn.init.normal_(module.weight, 0.0, init_gain)
                elif init_type == 'xavier':
                    torch.nn.init.xavier_normal_(module.weight, gain=init_gain)
                elif init_type == 'kaiming':
                    torch.nn.init.kaiming_normal_(module.weight, a=0, mode='fan_in')
                elif init_type == 'orthogonal':
                    torch.nn.init.orthogonal_(module.weight, gain=init_gain)
                elif init_type == "trunc_normal":
                    torch.nn.init.trunc_normal_(module.weight, std=init_gain)
                else:
                    raise NotImplementedError('Initialization method {} is not implemented'.format(init_type))
                if module.bias is not None:
                    torch.nn.init.constant_(module.bias, 0.0)

            elif isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
                if module.weight is not None:
                    torch.nn.init.normal_(module.weight, 0.0, std = init_gain)
                if module.bias is not None:
                    torch.nn.init.constant_(module.bias, 0.0)

    def init(self, init_type: str, init_gain: float):
        for module in self._modules.values():
            ModuleArgsDict.init_func(module, init_type, init_gain)

    def named_forward(self, *inputs: torch.Tensor) -> Iterator[tuple[str, torch.Tensor]]:
        if len(inputs) > 0:
            branchs: dict[str, torch.Tensor] = {}
            for i, sinput in enumerate(inputs):
                branchs[str(i)] = sinput

            out = inputs[0]
            tmp = []
            for name, module in self.items():
                if self._modulesArgs[name].training is None or (not (self._modulesArgs[name].training and self._training == NetState.PREDICTION) and not (not self._modulesArgs[name].training and self._training == NetState.TRAIN)):
                    requires_grad = self._modulesArgs[name].requires_grad
                    if requires_grad is not None and module:
                        module.requires_grad_(requires_grad)
                    for ib in self._modulesArgs[name].in_branch:
                        if ib not in branchs:
                            branchs[ib] = inputs[0]
                    for branchs_key in branchs.keys():
                        if str(self._modulesArgs[name].gpu) != 'cpu' and str(branchs[branchs_key].device) != "cuda:"+self._modulesArgs[name].gpu:
                            branchs[branchs_key] = branchs[branchs_key].to(int(self._modulesArgs[name].gpu))

                    if self._modulesArgs[name].isCheckpoint:
                        out = checkpoint(module, *[branchs[i] for i in self._modulesArgs[name].in_branch], use_reentrant=True)
                        for ob in self._modulesArgs[name].out_branch:
                            branchs[ob] = out
                        yield name, out
                    else:
                        if isinstance(module, ModuleArgsDict):
                            for k, out in module.named_forward(*[branchs[i] for i in self._modulesArgs[name].in_branch]):
                                for ob in self._modulesArgs[name].out_branch:
                                    if ob in module._modulesArgs[k.split(".")[0].replace(";accu;", "")].out_branch:
                                        tmp.append(ob)
                                        branchs[ob] = out
                                yield name+"."+k, out
                            for ob in self._modulesArgs[name].out_branch:
                                if ob not in tmp:
                                    branchs[ob] = out
                        elif isinstance(module, torch.nn.Module):
                            out = module(*[branchs[i] for i in self._modulesArgs[name].in_branch])
                            for ob in self._modulesArgs[name].out_branch:
                                branchs[ob] = out
                            yield name, out
            del branchs

    def forward(self, *input: torch.Tensor) -> torch.Tensor:
        v = input
        for k, v in self.named_forward(*input):
            pass
        return v

    def named_parameters(self, pretrained: bool = False, recurse=False) -> Iterator[tuple[str, torch.nn.parameter.Parameter]]:
        for name, moduleArgs in self._modulesArgs.items():
            module = self[name]
            if isinstance(module, ModuleArgsDict):
                for k, v in module.named_parameters(pretrained = pretrained):
                    yield name+"."+k, v
            elif isinstance(module, torch.nn.Module):
                if not pretrained or not moduleArgs.pretrained:
                    if moduleArgs.training is None or moduleArgs.training:
                        for k, v in module.named_parameters():
                            yield name+"."+k, v

    def parameters(self, pretrained: bool = False):
        for _, v in self.named_parameters(pretrained = pretrained):
            yield v

    def named_ModuleArgsDict(self) -> Iterator[tuple[str, Self, ModuleArgs]]:
        for name, module in self._modules.items():
            yield name, module, self._modulesArgs[name]
            if isinstance(module, ModuleArgsDict):
                for k, v, u in module.named_ModuleArgsDict():
                    yield name+"."+k, v, u

    def _requires_grad(self, keys: list[str]):
        keys = keys.copy()
        for name, module, args in self.named_ModuleArgsDict():
            requires_grad = args.requires_grad
            if requires_grad is not None:
                module.requires_grad_(requires_grad)
            if name in keys:
                keys.remove(name)
            if len(keys) == 0:
                break

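ModuleArgsDict executes its children in insertion order, reading each module's inputs from named branches and writing its outputs back to them, with optional per-module gradient checkpointing and device placement. A minimal usage sketch, assuming the class is importable as konfai.network.network.ModuleArgsDict (the submodule names and shapes are invented):

import torch
from konfai.network.network import ModuleArgsDict   # assumed import path

class TinyGraph(ModuleArgsDict):
    def __init__(self) -> None:
        super().__init__()
        # branch "0" is the network input; "feat" is an intermediate branch
        self.add_module("conv", torch.nn.Conv1d(1, 4, 3, padding=1), in_branch=[0], out_branch=["feat"])
        self.add_module("act", torch.nn.ReLU(), in_branch=["feat"], out_branch=["feat"])
        self.add_module("head", torch.nn.Conv1d(4, 1, 1), in_branch=["feat"], out_branch=[0])

out = TinyGraph()(torch.randn(2, 1, 16))            # forward drains named_forward, returns the last output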
class OutputsGroup(list):

    def __init__(self, measure: Measure) -> None:
        super().__init__()
        self.layers: dict[str, torch.Tensor] = {}
        self.measure = measure

    def addLayer(self, name: str, layer: torch.Tensor):
        self.layers[name] = layer

    def isDone(self):
        return len(self) == len(self.layers)

    def clear(self):
        self.layers.clear()

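OutputsGroup is a list of the layer names one measure consumes, plus a side dict of the tensors collected so far; isDone fires once every name has arrived, the measure is updated, and the buffer is cleared. The collect-then-fire pattern in isolation (plain Python; the layer names are invented):

import torch

needed = ["decoder.out", "encoder.latent"]       # names the criterion consumes
collected: dict[str, torch.Tensor] = {}

def on_layer(name: str, tensor: torch.Tensor) -> None:
    if name in needed:
        collected[name] = tensor
    if len(collected) == len(needed):            # mirrors OutputsGroup.isDone()
        print("all layers present, ready to score")
        collected.clear()                        # mirrors OutputsGroup.clear()

on_layer("decoder.out", torch.randn(1))
on_layer("encoder.latent", torch.randn(1))       # prints the ready message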
class Network(ModuleArgsDict, ABC):

    def _apply_network(self, name_function : Callable[[Self], str], networks: list[str], key: str, function: Callable, *args, **kwargs) -> dict[str, object]:
        results : dict[str, object] = {}
        for module in self.values():
            if isinstance(module, Network):
                if name_function(module) not in networks:
                    networks.append(name_function(module))
                    for k, v in module._apply_network(name_function, networks, key+"."+name_function(module), function, *args, **kwargs).items():
                        results.update({name_function(self)+"."+k : v})
        if len([param.name for param in list(inspect.signature(function).parameters.values()) if param.name == "key"]):
            function = partial(function, key=key)

        results[name_function(self)] = function(self, *args, **kwargs)
        return results

    def _function_network(t : bool = False):
        def _function_network_d(function : Callable):
            def new_function(self : Self, *args, **kwargs) -> dict[str, object]:
                return self._apply_network(lambda network: network._getName() if t else network.getName(), [], self.getName(), function, *args, **kwargs)
            return new_function
        return _function_network_d

    def __init__(self,
                 in_channels : int = 1,
                 optimizer: Union[OptimizerLoader, None] = None,
                 schedulers: Union[LRSchedulersLoader, None] = None,
                 outputsCriterions: Union[dict[str, TargetCriterionsLoader], None] = None,
                 patch : Union[ModelPatch, None] = None,
                 nb_batch_per_step : int = 1,
                 init_type : str = "normal",
                 init_gain : float = 0.02,
                 dim : int = 3) -> None:
        super().__init__()
        self.name = self.__class__.__name__
        self.in_channels = in_channels
        self.optimizerLoader = optimizer
        self.optimizer : Union[torch.optim.Optimizer, None] = None

        self.LRSchedulersLoader = schedulers
        self.schedulers : Union[dict[torch.optim.lr_scheduler._LRScheduler, int], None] = None

        self.outputsCriterionsLoader = outputsCriterions
        self.measure : Union[Measure, None] = None

        self.patch = patch

        self.nb_batch_per_step = nb_batch_per_step
        self.init_type = init_type
        self.init_gain = init_gain
        self.dim = dim
        self._it = 0
        self.scaler : Union[torch.amp.GradScaler, None] = None  # set by init() outside of prediction
        self.outputsGroup : list[OutputsGroup] = []

    @_function_network(True)
    def state_dict(self) -> dict[str, OrderedDict]:
        destination = OrderedDict()
        destination._metadata = OrderedDict()
        destination._metadata[""] = local_metadata = dict(version=self._version)
        self._save_to_state_dict(destination, "", False)
        for name, module in self._modules.items():
            if module is not None:
                if not isinstance(module, Network):
                    module.state_dict(destination=destination, prefix="" + name + '.', keep_vars=False)
        for hook in self._state_dict_hooks.values():
            hook_result = hook(self, destination, "", local_metadata)
            if hook_result is not None:
                destination = hook_result
        return destination

    def load_state_dict(self, state_dict: dict[str, torch.Tensor]):
        missing_keys: list[str] = []
        unexpected_keys: list[str] = []
        error_msgs: list[str] = []

        metadata = getattr(state_dict, '_metadata', None)
        state_dict = state_dict.copy()
        if metadata is not None:
            state_dict._metadata = metadata

        def load(module: torch.nn.Module, prefix=''):
            local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
            module._load_from_state_dict(
                state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
            for name, child in module._modules.items():
                if child is not None:
                    if not isinstance(child, Network):
                        if isinstance(child, torch.nn.modules.conv._ConvNd) or isinstance(child, torch.nn.Linear):

                            current_size = child.weight.shape[0]
                            last_size = state_dict[prefix + name+".weight"].shape[0]

                            if current_size != last_size:
                                print("Warning: The size of '{}' has changed from {} to {}. Please check for potential impacts".format(prefix + name, last_size, current_size))
                                ModuleArgsDict.init_func(child, self.init_type, self.init_gain)

                                with torch.no_grad():
                                    child.weight[:last_size] = state_dict[prefix + name+".weight"]
                                    if child.bias is not None:
                                        child.bias[:last_size] = state_dict[prefix + name+".bias"]
                                return
                        load(child, prefix + name + '.')

        load(self)
        del load

        if len(unexpected_keys) > 0:
            error_msgs.insert(
                0, 'Unexpected key(s) in state_dict: {}. '.format(
                    ', '.join('"{}"'.format(k) for k in unexpected_keys)))
        if len(missing_keys) > 0:
            error_msgs.insert(
                0, 'Missing key(s) in state_dict: {}. '.format(
                    ', '.join('"{}"'.format(k) for k in missing_keys)))

        if len(error_msgs) > 0:
            raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
                self.__class__.__name__, "\n\t".join(error_msgs)))

    def apply(self, fn: Callable[[torch.nn.Module], None]) -> None:
        for module in self.children():
            if not isinstance(module, Network):
                module.apply(fn)
        fn(self)

    @_function_network(True)
    def load(self, state_dict : dict[str, dict[str, torch.Tensor]], init: bool = True, ema: bool = False):
        if init:
            self.apply(partial(ModuleArgsDict.init_func, init_type=self.init_type, init_gain=self.init_gain))
        name = "Model" + ("_EMA" if ema else "")
        if name in state_dict:
            model_state_dict_tmp = {k.split(".")[-1] : v for k, v in state_dict[name].items()}[self._getName()]
            map = self.getMap()
            model_state_dict : OrderedDict[str, torch.Tensor] = OrderedDict()

            for alias in model_state_dict_tmp.keys():
                prefix = ".".join(alias.split(".")[:-1])
                alias_list = [(".".join(prefix.split(".")[:len(i.split("."))]), v) for i, v in map.items() if prefix.startswith(i)]

                if len(alias_list):
                    for a, b in alias_list:
                        model_state_dict[alias.replace(a, b)] = model_state_dict_tmp[alias]
                        break
                else:
                    model_state_dict[alias] = model_state_dict_tmp[alias]
            self.load_state_dict(model_state_dict)

        if "{}_optimizer_state_dict".format(name) in state_dict and self.optimizer:
            self.optimizer.load_state_dict(state_dict['{}_optimizer_state_dict'.format(name)])
        self.initialized()

    def _compute_channels_trace(self, module : ModuleArgsDict, in_channels : int, gradient_checkpoints: Union[list[str], None], gpu_checkpoints: Union[list[str], None], name: Union[str, None] = None, in_is_channel = True, out_channels : Union[int, None] = None, out_is_channel = True) -> tuple[int, bool, int, bool]:

        for k1, v1 in module.items():
            if isinstance(v1, ModuleArgsDict):
                for t in module._modulesArgs[k1].out_branch:
                    last = None
                    for k2, _ in v1.items():
                        if t in v1._modulesArgs[k2].out_branch:
                            last = k2
                    if last is not None:
                        v1._modulesArgs[last]._isEnd = True
                    else:
                        v1._modulesArgs[k2]._isEnd = True

        for k, v in module.items():
            if hasattr(v, "in_channels"):
                if v.in_channels:
                    in_channels = v.in_channels
            if hasattr(v, "in_features"):
                if v.in_features:
                    in_channels = v.in_features
            key = name+"."+k if name else k

            if gradient_checkpoints:
                if key in gradient_checkpoints:
                    module._modulesArgs[k].isCheckpoint = True

            if gpu_checkpoints:
                if key in gpu_checkpoints:
                    module._modulesArgs[k].isGPU_Checkpoint = True

            module._modulesArgs[k].in_channels = in_channels
            module._modulesArgs[k].in_is_channel = in_is_channel

            if isinstance(v, ModuleArgsDict):
                in_channels, in_is_channel, out_channels, out_is_channel = self._compute_channels_trace(v, in_channels, gradient_checkpoints, gpu_checkpoints, key, in_is_channel, out_channels, out_is_channel)

            if v.__class__.__name__ == "ToChannels":
                out_is_channel = True

            if v.__class__.__name__ == "ToFeatures":
                out_is_channel = False

            if hasattr(v, "out_channels"):
                if v.out_channels:
                    out_channels = v.out_channels
            if hasattr(v, "out_features"):
                if v.out_features:
                    out_channels = v.out_features

            module._modulesArgs[k].out_channels = out_channels
            module._modulesArgs[k].out_is_channel = out_is_channel

            in_channels = out_channels
            in_is_channel = out_is_channel

        return in_channels, in_is_channel, out_channels, out_is_channel

    @_function_network()
    def init(self, autocast : bool, state : State, key: str) -> None:
        if self.outputsCriterionsLoader:
            self.measure = Measure(key, self.outputsCriterionsLoader)
            self.measure.init(self)
        if state != State.PREDICTION:
            self.scaler = torch.amp.GradScaler("cuda", enabled=autocast)
            if self.optimizerLoader:
                self.optimizer = self.optimizerLoader.getOptimizer(key, self.parameters(state == State.TRANSFER_LEARNING))
                self.optimizer.zero_grad()

            if self.LRSchedulersLoader and self.optimizer:
                self.schedulers = self.LRSchedulersLoader.getShedulers(key, self.optimizer)

    def initialized(self):
        pass

    def named_forward(self, *inputs: torch.Tensor) -> Iterator[tuple[str, torch.Tensor]]:
        if self.patch:
            self.patch.load(inputs[0].shape[2:])
            accumulators: dict[str, Accumulator] = {}

            patchIterator = self.patch.disassemble(*inputs)
            buffer = []
            for i, patch_input in enumerate(patchIterator):
                for (name, output_layer) in super().named_forward(*patch_input):
                    yield "{}{}".format(";accu;", name), output_layer
                    buffer.append((name.split(".")[0], output_layer))
                    if len(buffer) == 2:
                        if buffer[0][0] != buffer[1][0]:
                            if self._modulesArgs[buffer[0][0]]._isEnd:
                                if buffer[0][0] not in accumulators:
                                    accumulators[buffer[0][0]] = Accumulator(self.patch.getPatch_slices(), self.patch.patch_size, self.patch.patchCombine)
                                accumulators[buffer[0][0]].addLayer(i, buffer[0][1])
                        buffer.pop(0)
            if self._modulesArgs[buffer[0][0]]._isEnd:
                if buffer[0][0] not in accumulators:
                    accumulators[buffer[0][0]] = Accumulator(self.patch.getPatch_slices(), self.patch.patch_size, self.patch.patchCombine)
                accumulators[buffer[0][0]].addLayer(i, buffer[0][1])
            for name, accumulator in accumulators.items():
                yield name, accumulator.assemble()
        else:
            for (name, output_layer) in super().named_forward(*inputs):
                yield name, output_layer

    def get_layers(self, inputs : list[torch.Tensor], layers_name: list[str]) -> Iterator[tuple[str, torch.Tensor, Union[Patch_Indexed, None]]]:
        layers_name = layers_name.copy()
        output_layer_accumulator : dict[str, Accumulator] = {}
        output_layer_patch_indexed : dict[str, Patch_Indexed] = {}
        it = 0
        debug = "DL_API_DEBUG" in os.environ
        for (nameTmp, output_layer) in self.named_forward(*inputs):
            name = nameTmp.replace(";accu;", "")
            if debug:
                if "DL_API_DEBUG_LAST_LAYER" in os.environ:
                    os.environ["DL_API_DEBUG_LAST_LAYER"] = "{}|{}:{}:{}".format(os.environ["DL_API_DEBUG_LAST_LAYER"], name, getGPUMemory(output_layer.device), str(output_layer.device).replace("cuda:", ""))
                else:
                    os.environ["DL_API_DEBUG_LAST_LAYER"] = "{}:{}:{}".format(name, getGPUMemory(output_layer.device), str(output_layer.device).replace("cuda:", ""))
            it += 1
            if name in layers_name or nameTmp in layers_name:
                if ";accu;" in nameTmp:
                    if name not in output_layer_patch_indexed:
                        networkName = nameTmp.split(".;accu;")[-2].split(".")[-1] if ".;accu;" in nameTmp else nameTmp.split(";accu;")[-2].split(".")[-1]
                        module = self
                        network = None
                        if networkName == "":
                            network = module
                        else:
                            for n in name.split("."):
                                module = module[n]
                                if isinstance(module, Network) and n == networkName:
                                    network = module
                                    break

                        if network and network.patch:
                            output_layer_patch_indexed[name] = Patch_Indexed(network.patch, 0)

                    if name not in output_layer_accumulator:
                        output_layer_accumulator[name] = Accumulator(output_layer_patch_indexed[name].patch.getPatch_slices(0), output_layer_patch_indexed[name].patch.patch_size, output_layer_patch_indexed[name].patch.patchCombine)

                    if nameTmp in layers_name:
                        output_layer_accumulator[name].addLayer(output_layer_patch_indexed[name].index, output_layer)
                        output_layer_patch_indexed[name].index += 1
                        if output_layer_accumulator[name].isFull():
                            output_layer = output_layer_accumulator[name].assemble()
                            output_layer_accumulator.pop(name)
                            output_layer_patch_indexed.pop(name)
                            layers_name.remove(nameTmp)
                            yield nameTmp, output_layer, None

                if name in layers_name:
                    if ";accu;" in nameTmp:
                        yield name, output_layer, output_layer_patch_indexed[name]
                        output_layer_patch_indexed[name].index += 1
                        if output_layer_patch_indexed[name].isFull():
                            output_layer_patch_indexed.pop(name)
                            layers_name.remove(name)
                    else:
                        layers_name.remove(name)
                        yield name, output_layer, None

            if not len(layers_name):
                break

    def init_outputsGroup(self):
        metric_tmp = {network.measure : network.measure.outputsCriterions.keys() for network in self.getNetworks().values() if network.measure}
        for k, v in metric_tmp.items():
            for a in v:
                outputs_group = OutputsGroup(k)
                outputs_group.append(a)
                for targetsGroup in k.outputsCriterions[a].keys():
                    if ":" in targetsGroup:
                        outputs_group.append(targetsGroup.replace(":", "."))

                self.outputsGroup.append(outputs_group)

    def forward(self, data_dict: dict[tuple[str, bool], torch.Tensor], output_layers: list[str] = []) -> list[tuple[str, torch.Tensor]]:
        if not len(self.outputsGroup) and not len(output_layers):
            return []

        self.resetLoss()
        results = []
        measure_output_layers = set()
        for outputs_group in self.outputsGroup:
            for name in outputs_group:
                measure_output_layers.add(name)
        for name, layer, patch_indexed in self.get_layers([v for k, v in data_dict.items() if k[1]], list(set(list(measure_output_layers)+output_layers))):

            outputs_group = [outputs_group for outputs_group in self.outputsGroup if name in outputs_group]
            if len(outputs_group) > 0:
                if patch_indexed is None:
                    targets = {k[0] : v for k, v in data_dict.items()}
                    nb = 1
                else:
                    targets = {k[0] : patch_indexed.patch.getData(v, patch_indexed.index, 0, False) for k, v in data_dict.items()}
                    nb = patch_indexed.patch.getSize(0)

                for output_group in outputs_group:
                    output_group.addLayer(name, layer)
                    if output_group.isDone():
                        targets.update({k.replace(".", ":"): v for k, v in output_group.layers.items() if k != output_group[0]})
                        output_group.measure.update(output_group[0], output_group.layers[output_group[0]], targets, self._it, nb, self.training)
                        output_group.clear()
            if name in output_layers:
                results.append((name, layer))
        return results

    @_function_network()
    def resetLoss(self):
        if self.measure:
            self.measure.resetLoss()

    @_function_network()
    def backward(self, model: Self):
        if self.measure:
            if self.scaler and self.optimizer:
                model._requires_grad(list(self.measure.outputsCriterions.keys()))
                for loss in self.measure.getLoss():
                    self.scaler.scale(loss / self.nb_batch_per_step).backward()
                if self._it % self.nb_batch_per_step == 0:
                    self.scaler.step(self.optimizer)
                    self.scaler.update()
                    self.optimizer.zero_grad(set_to_none=True)
        self._it += 1

    @_function_network()
    def update_lr(self):
        step = 0
        scheduler = None
        if self.schedulers:
            for scheduler, value in self.schedulers.items():
                if value is None or (self._it >= step and self._it < step+value):
                    break
                step += value
        if scheduler:
            if scheduler.__class__.__name__ == 'ReduceLROnPlateau':
                if self.measure:
                    print(sum(self.measure.getLastValues(0).values()))
                    scheduler.step(sum(self.measure.getLastValues(0).values()))
            else:
                scheduler.step()

    @_function_network()
    def getNetworks(self) -> Self:
        return self

    def to(module : ModuleArgsDict, device: int):
        if "device" not in os.environ:
            os.environ["device"] = str(device)
        for k, v in module.items():
            if module._modulesArgs[k].gpu == "cpu":
                if module._modulesArgs[k].isGPU_Checkpoint:
                    os.environ["device"] = str(int(os.environ["device"])+1)
                module._modulesArgs[k].gpu = str(getDevice(int(os.environ["device"])))
                if isinstance(v, ModuleArgsDict):
                    v = Network.to(v, int(os.environ["device"]))
                else:
                    v = v.to(getDevice(int(os.environ["device"])))
        return module

    def getName(self) -> str:
        return self.__class__.__name__

    def setName(self, name: str) -> Self:
        self.name = name
        return self

    def _getName(self) -> str:
        return self.name

    def setState(self, state: NetState):
        for module in self.modules():
            if isinstance(module, ModuleArgsDict):
                module._training = state

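Network.backward combines AMP loss scaling with gradient accumulation over nb_batch_per_step mini-batches. The same pattern in isolation (a standard PyTorch sketch; the model, data, and step count are placeholders, and GradScaler degrades to a pass-through when CUDA is unavailable):

import torch

model = torch.nn.Linear(4, 2)
optimizer = torch.optim.AdamW(model.parameters(), lr=1e-3)
scaler = torch.amp.GradScaler("cuda", enabled=torch.cuda.is_available())
nb_batch_per_step = 2

for it in range(1, 5):
    loss = model(torch.randn(8, 4)).square().mean()
    scaler.scale(loss / nb_batch_per_step).backward()  # accumulate scaled gradients
    if it % nb_batch_per_step == 0:
        scaler.step(optimizer)                         # unscale, then optimizer step
        scaler.update()
        optimizer.zero_grad(set_to_none=True)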
class ModelLoader():

    @config("Model")
    def __init__(self, classpath : str = "default:segmentation.UNet") -> None:
        self.module, self.name = _getModule(classpath.split(".")[-1] if len(classpath.split(".")) > 1 else classpath, ".".join(classpath.split(".")[:-1]) if len(classpath.split(".")) > 1 else "")

    def getModel(self, train : bool = True, DL_args: Union[str, None] = None, DL_without=["optimizer", "schedulers", "nb_batch_per_step", "init_type", "init_gain"]) -> Network:
        if not DL_args:
            DL_args = "{}.Model".format(DEEP_LEARNING_API_ROOT())
        model = partial(getattr(importlib.import_module(self.module), self.name), config = None, DL_args=DL_args)
        if not train:
            model = partial(model, DL_without = DL_without)
        return model()

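ModelLoader turns a dotted classpath from the configuration into a class object. The core mechanism, shown standalone with a torch classpath (the KonfAI-specific config and DL_args plumbing is omitted):

import importlib

classpath = "torch.nn.Conv2d"
module_path, class_name = classpath.rsplit(".", 1)
cls = getattr(importlib.import_module(module_path), class_name)
layer = cls(in_channels=1, out_channels=8, kernel_size=3)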
class CPU_Model():

    def __init__(self, model: Network) -> None:
        self.module = model
        self.device = torch.device('cpu')

    def train(self):
        self.module.train()

    def eval(self):
        self.module.eval()

    def __call__(self, data_dict: dict[tuple[str, bool], torch.Tensor], output_layers: list[str] = []) -> list[tuple[str, torch.Tensor]]:
        return self.module(data_dict, output_layers)
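CPU_Model gives a plain model the same surface as torch.nn.DataParallel (a .module attribute, train/eval, and a callable wrapper), so trainer code can handle CPU and GPU runs uniformly. A usage sketch, assuming net is a Network built elsewhere (the data-group name, shape, and layer name are invented):

import torch

# net: a konfai Network, e.g. from ModelLoader("segmentation.UNet").getModel(train=False)
wrapped = CPU_Model(net)
wrapped.eval()
results = wrapped({("image", True): torch.randn(1, 1, 32, 32, 32)}, output_layers=["head"])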