konfai 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of konfai might be problematic. Click here for more details.

@@ -0,0 +1,950 @@
1
+ from functools import partial
2
+ import importlib
3
+ import inspect
4
+ import os
5
+ from typing import Iterable, Iterator, Callable
6
+ from typing_extensions import Self
7
+ import torch
8
+ from abc import ABC
9
+ import numpy as np
10
+ from torch._jit_internal import _copy_to_script_wrapper
11
+ from collections import OrderedDict
12
+ from torch.utils.checkpoint import checkpoint
13
+ from typing import Union
14
+ from enum import Enum
15
+
16
+ from KonfAI.konfai import DEEP_LEARNING_API_ROOT
17
+ from KonfAI.konfai.metric.schedulers import Scheduler
18
+ from KonfAI.konfai.utils.config import config
19
+ from KonfAI.konfai.utils.utils import State, _getModule, getDevice, getGPUMemory
20
+ from KonfAI.konfai.data.HDF5 import Accumulator, ModelPatch
21
+
22
class NetState(Enum):
    """Execution mode of a network.

    ``ModuleArgsDict.named_forward`` compares ``_training`` against these
    members: children flagged ``training=True`` are skipped in PREDICTION and
    children flagged ``training=False`` are skipped in TRAIN.
    """

    # The original ``TRAIN = 0,`` had a trailing comma, which made the member's
    # value the tuple ``(0,)`` instead of the int ``0``.
    TRAIN = 0
    PREDICTION = 1
25
+
26
class Patch_Indexed():
    """Pair a ``ModelPatch`` with a running patch counter."""

    def __init__(self, patch: ModelPatch, index: int) -> None:
        self.patch, self.index = patch, index

    def isFull(self) -> bool:
        """Return True once ``index`` equals the number of slices of patch group 0."""
        slice_count = len(self.patch.getPatch_slices(0))
        return self.index == slice_count
34
+
35
class OptimizerLoader():
    """Create a ``torch.optim`` optimizer chosen by class name."""

    @config("Optimizer")
    def __init__(self, name: str = "AdamW") -> None:
        # Name of a class exposed by ``torch.optim`` (e.g. "AdamW", "SGD").
        self.name = name

    def getOptimizer(self, key: str, parameter: Iterator[torch.nn.parameter.Parameter]) -> torch.optim.Optimizer:
        """Instantiate the configured optimizer for the parameters of network ``key``.

        The class is looked up in ``torch.optim`` by ``self.name`` and wrapped with
        ``config`` for the section ``<DEEP_LEARNING_API_ROOT()>.Model.<key>.Optimizer``
        before being called with ``parameter``.
        """
        # (A stray no-op expression ``torch.optim.AdamW`` was removed here.)
        optimizer_class = getattr(importlib.import_module('torch.optim'), self.name)
        section = "{}.Model.{}.Optimizer".format(DEEP_LEARNING_API_ROOT(), key)
        return config(section)(optimizer_class)(parameter, config = None)
44
+
45
class SchedulerStep():
    """Configurable step count attached to one scheduler entry."""

    @config(None)
    def __init__(self, nb_step : int = 0) -> None:
        # Becomes the value associated with the scheduler in the dicts returned
        # by the schedulers loaders' ``getShedulers``.
        self.nb_step = nb_step
50
+
51
class LRSchedulersLoader():
    """Build ``torch.optim.lr_scheduler`` objects from a name -> SchedulerStep mapping."""

    @config("Schedulers")
    def __init__(self, params: dict[str, SchedulerStep] = {"default:ReduceLROnPlateau" : SchedulerStep(0)}) -> None:
        self.params = params

    def getShedulers(self, key: str, optimizer: torch.optim.Optimizer) -> dict[torch.optim.lr_scheduler._LRScheduler, int]:
        """Instantiate every named scheduler for ``optimizer`` of network ``key``.

        Entries with an empty name are ignored. Each scheduler class comes from
        ``torch.optim.lr_scheduler`` and is wrapped with ``config`` for the
        section ``Trainer.Model.<key>.Schedulers.<name>``. The result maps each
        scheduler to its ``nb_step``.
        """
        lr_scheduler_module = importlib.import_module('torch.optim.lr_scheduler')
        built : dict[torch.optim.lr_scheduler._LRScheduler, int] = {}
        for scheduler_name, scheduler_step in self.params.items():
            if not scheduler_name:
                continue
            scheduler_class = getattr(lr_scheduler_module, scheduler_name)
            section = "Trainer.Model.{}.Schedulers.{}".format(key, scheduler_name)
            built[config(section)(scheduler_class)(optimizer, config = None)] = scheduler_step.nb_step
        return built
63
+
64
class SchedulersLoader():
    """Build KonfAI weight schedulers (``KonfAI.konfai.metric.schedulers``) by name."""

    @config("Schedulers")
    def __init__(self, params: dict[str, SchedulerStep] = {"default:Constant" : SchedulerStep(0)}) -> None:
        self.params = params

    def getShedulers(self, key: str) -> dict[Scheduler, int]:
        """Instantiate every named scheduler, passing ``key`` as ``DL_args``.

        Entries with an empty name are ignored. Returns a mapping from scheduler
        to its ``nb_step``. The annotation previously claimed
        ``torch.optim.lr_scheduler._LRScheduler`` keys, but the objects come
        from ``KonfAI.konfai.metric.schedulers``.
        """
        schedulers_module = importlib.import_module("KonfAI.konfai.metric.schedulers")
        shedulers : dict[Scheduler, int] = {}
        for name, step in self.params.items():
            if name:
                shedulers[getattr(schedulers_module, name)(config = None, DL_args = key)] = step.nb_step
        return shedulers
76
+
77
class CriterionsAttr():
    """Options attached to one criterion.

    Holds the weight-scheduler loader ``l``, a loss/metric flag, the loss
    group, the active iteration window and an accumulation flag.
    ``Measure.update`` evaluates the criterion while
    ``stepStart <= it <= stepStop`` (no upper bound when ``stepStop`` is None).
    """

    @config()
    def __init__(self, l: SchedulersLoader = SchedulersLoader(), isLoss: bool = True, group: int = 0, stepStart: int = 0, stepStop: Union[int, None] = None, accumulation: bool = False) -> None:
        # Values coming from the configuration.
        self.l = l
        self.isLoss = isLoss
        self.group = group
        self.stepStart = stepStart
        self.stepStop = stepStop
        self.accumulation = accumulation
        # Filled in later by CriterionsLoader.getCriterions.
        self.isTorchCriterion = True
        self.sheduler = None
89
+
90
class CriterionsLoader():
    """Instantiate the criteria configured for one (output group, target group) pair."""

    @config()
    def __init__(self, criterionsLoader: dict[str, CriterionsAttr] = {"default:torch_nn_CrossEntropyLoss:Dice:NCC": CriterionsAttr()}) -> None:
        self.criterionsLoader = criterionsLoader

    def getCriterions(self, model_classname : str, output_group : str, target_group : str) -> dict[torch.nn.Module, CriterionsAttr]:
        """Return ``{criterion instance: CriterionsAttr}`` for every configured class path.

        Side effects on each ``CriterionsAttr``: ``isTorchCriterion`` is set from
        the resolved module name, and ``sheduler`` receives the schedulers built
        for this criterion's configuration section.
        """
        result = {}
        for module_classpath, attributes in self.criterionsLoader.items():
            module, class_name = _getModule(module_classpath, "measure")
            attributes.isTorchCriterion = module.startswith("torch")
            # The same configuration section names both the schedulers and the criterion.
            section = "{}.Model.{}.outputsCriterions.{}.targetsCriterions.{}.criterionsLoader.{}".format(DEEP_LEARNING_API_ROOT(), model_classname, output_group, target_group, module_classpath)
            attributes.sheduler = attributes.l.getShedulers(section)
            criterion_class = getattr(importlib.import_module(module), class_name)
            result[config(section)(criterion_class)(config = None)] = attributes
        return result
104
+
105
class TargetCriterionsLoader():
    """Group ``CriterionsLoader`` objects by target group for one output group."""

    @config()
    def __init__(self, targetsCriterions : dict[str, CriterionsLoader] = {"default" : CriterionsLoader()}) -> None:
        self.targetsCriterions = targetsCriterions

    def getTargetsCriterions(self, output_group : str, model_classname : str) -> dict[str, dict[torch.nn.Module, CriterionsAttr]]:
        """Return ``{target_group: {criterion: CriterionsAttr}}`` for ``output_group``.

        The annotation previously declared ``float`` values, but
        ``CriterionsLoader.getCriterions`` maps criteria to ``CriterionsAttr``.
        """
        targetsCriterions = {}
        for target_group, criterionsLoader in self.targetsCriterions.items():
            targetsCriterions[target_group] = criterionsLoader.getCriterions(model_classname, output_group, target_group)
        return targetsCriterions
116
+
117
class Measure():
    """Evaluate, weight, record and (for accumulated losses) back-propagate criteria.

    ``outputsCriterions`` maps output group -> target group ->
    ``{criterion module: CriterionsAttr}``. Recorded values are bucketed by
    ``CriterionsAttr.group`` in ``self._loss`` under the key
    ``"<output_group>:<target_group>:<criterion class name>"``.
    """

    class Loss():
        """History of one criterion: pending tensors plus full weight/value history."""

        def __init__(self, name: str, output_group: str, target_group: str, group: int, isLoss: bool, accumulation: bool) -> None:
            self.name = name
            self.isLoss = isLoss
            self.accumulation = accumulation
            self.output_group = output_group
            self.target_group = target_group
            self.group = group

            # Tensors recorded since the last resetLoss() (detached for metrics).
            self._loss: list[torch.Tensor] = []
            # Complete history (never cleared); used for reporting.
            self._weight: list[float] = []
            self._values: list[float] = []

        def resetLoss(self) -> None:
            """Drop pending tensors; the weight/value history is kept."""
            self._loss.clear()

        def add(self, weight: float, value: torch.Tensor) -> None:
            """Record ``value`` (a scalar tensor) with scheduler ``weight``."""
            # Metrics are detached so they do not keep the autograd graph alive.
            self._loss.append(value if self.isLoss else value.detach())
            self._values.append(value.item())
            self._weight.append(weight)

        def getLastLoss(self) -> torch.Tensor:
            """Weighted most recent pending value, or a zero tensor if none is pending."""
            return self._loss[-1]*self._weight[-1] if len(self._loss) else torch.zeros((1), requires_grad = True)

        def getLoss(self) -> torch.Tensor:
            """Mean of the weighted values recorded since the last ``resetLoss``."""
            if not self._loss:
                return torch.zeros((1), requires_grad = True)
            # ``_weight`` is never cleared, so the pending losses correspond to the
            # *last* len(_loss) weights. Zipping from the start of the history
            # would apply weights recorded before the previous reset.
            weights = self._weight[-len(self._loss):]
            return torch.stack([w*l for w, l in zip(weights, self._loss)], dim=0).mean(dim=0)

        def __len__(self) -> int:
            return len(self._loss)

    def __init__(self, model_classname : str, outputsCriterions: "dict[str, TargetCriterionsLoader]") -> None:
        super().__init__()
        self.outputsCriterions = {}
        for output_group, targetCriterionsLoader in outputsCriterions.items():
            # Configuration keys use ":" where model layer paths use ".".
            self.outputsCriterions[output_group.replace(":", ".")] = targetCriterionsLoader.getTargetsCriterions(output_group, model_classname)
        self._loss : dict[int, dict[str, Measure.Loss]] = {}

    def init(self, model : torch.nn.Module) -> None:
        """Initialise non-torch criteria against ``model`` and create the Loss records.

        A non-torch criterion's ``init`` returns the output group it should be
        attached to; the criteria of the original group are moved there.
        """
        outputs_group_rename = {}
        for output_group in self.outputsCriterions.keys():
            for target_group in self.outputsCriterions[output_group]:
                for criterion in self.outputsCriterions[output_group][target_group]:
                    if not self.outputsCriterions[output_group][target_group][criterion].isTorchCriterion:
                        outputs_group_rename[output_group] = criterion.init(model, output_group, target_group)

        outputsCriterions_bak = self.outputsCriterions.copy()
        for old, new in outputs_group_rename.items():
            self.outputsCriterions.pop(old)
            self.outputsCriterions[new] = outputsCriterions_bak[old]
        for output_group in self.outputsCriterions:
            for target_group in self.outputsCriterions[output_group]:
                for criterion, criterionsAttr in self.outputsCriterions[output_group][target_group].items():
                    group_losses = self._loss.setdefault(criterionsAttr.group, {})
                    group_losses["{}:{}:{}".format(output_group, target_group, criterion.__class__.__name__)] = Measure.Loss(criterion.__class__.__name__, output_group, target_group, criterionsAttr.group, criterionsAttr.isLoss, criterionsAttr.accumulation)

    def update(self, output_group: str, output : torch.Tensor, data_dict: dict[str, torch.Tensor], it: int, nb_patch: int, training: bool) -> None:
        """Evaluate every active criterion of ``output_group`` at iteration ``it``.

        Targets are looked up in ``data_dict`` by the "/"-separated names of each
        target group. When training, once every accumulated loss of a group
        holds the same number of values, their last weighted values are summed,
        divided by ``nb_patch`` and back-propagated.
        """
        for target_group in self.outputsCriterions[output_group]:
            target = [data_dict[group].to(output[0].device).detach() for group in target_group.split("/") if group in data_dict]

            for criterion, criterionsAttr in self.outputsCriterions[output_group][target_group].items():
                if it >= criterionsAttr.stepStart and (criterionsAttr.stepStop is None or it <= criterionsAttr.stepStop):
                    scheduler = self.update_scheduler(criterionsAttr.sheduler, it)
                    self._loss[criterionsAttr.group]["{}:{}:{}".format(output_group, target_group, criterion.__class__.__name__)].add(scheduler.get_value(), criterion(output, *target))
                    # Only adding an *accumulated* loss may trigger the backward pass.
                    # Previously a non-accumulated loss evaluated after an
                    # accumulated one re-ran backward() on the same, already
                    # freed graph.
                    if training and criterionsAttr.isLoss and criterionsAttr.accumulation:
                        pending = [l for l in self._loss[criterionsAttr.group].values() if l.accumulation and l.isLoss]
                        if len({len(l) for l in pending}) == 1:
                            loss = torch.zeros((1), requires_grad = True)
                            for v in pending:
                                last = v.getLastLoss()
                                loss = loss.to(last.device)+last
                            loss = loss/nb_patch
                            loss.backward()

    def getLoss(self) -> list[torch.Tensor]:
        """Return one summed non-accumulated loss per group, in group insertion order."""
        loss: dict[int, torch.Tensor] = {}
        for group in self._loss.keys():
            loss[group] = torch.zeros((1), requires_grad = True)
            for v in self._loss[group].values():
                if v.isLoss and not v.accumulation:
                    l = v.getLoss()
                    loss[v.group] = loss[v.group].to(l.device)+l
        # Return a real list, as annotated (was a dict_values view).
        return list(loss.values())

    def resetLoss(self) -> None:
        """Drop pending tensors of every recorded criterion."""
        for group in self._loss.keys():
            for v in self._loss[group].values():
                v.resetLoss()

    def getLastValues(self, n: int = 1) -> dict[str, float]:
        """nan-mean of the last ``n`` values per criterion (all values when ``n <= 0``)."""
        result = {}
        for group in self._loss.keys():
            result.update({name : np.nanmean(value._values[-n:] if n > 0 else value._values) for name, value in self._loss[group].items() if n < 0 or len(value._values) >= n})
        return result

    def getLastWeights(self, n: int = 1) -> dict[str, float]:
        """nan-mean of the last ``n`` scheduler weights per criterion (all when ``n <= 0``)."""
        result = {}
        for group in self._loss.keys():
            result.update({name : np.nanmean(value._weight[-n:] if n > 0 else value._weight) for name, value in self._loss[group].items() if n < 0 or len(value._values) >= n})
        return result

    def format(self, isLoss: bool, n: int) -> dict[str, tuple[float, float]]:
        """Map each loss (or metric) with at least ``n`` values to (mean weight, mean value)."""
        result = dict()
        for group in self._loss.keys():
            for name, loss in self._loss[group].items():
                if loss.isLoss == isLoss and len(loss._values) >= n:
                    result[name] = (np.nanmean(loss._weight[-n:]), np.nanmean(loss._values[-n:]))
        return result

    def update_scheduler(self, schedulers: "dict[Scheduler, int]", it: int) -> "Scheduler":
        """Select the scheduler active at iteration ``it`` and step it.

        ``schedulers`` maps each scheduler to the number of iterations it stays
        active, in order. If no window contains ``it``, the last scheduler is
        used. It is stepped with ``it`` minus the iterations already consumed.
        Returns None when ``schedulers`` is empty.
        """
        step = 0
        scheduler = None
        for scheduler, value in schedulers.items():
            if value is None or (it >= step and it < step+value):
                break
            step += value
        if scheduler:
            scheduler.step(it-step)
        return scheduler
238
+
239
class ModuleArgsDict(torch.nn.Module, ABC):
    """Ordered container of sub-modules wired together through named branches.

    Each child added with :meth:`add_module` receives a :class:`ModuleArgs`
    record. The record lists the branches the child reads (``in_branch``) and
    writes (``out_branch``) and the alias names used by :meth:`getMap`. It also
    holds the flags consumed by :meth:`named_forward` (``requires_grad``,
    ``training``, ``isCheckpoint``, ``gpu``).
    """

    class ModuleArgs:
        """Wiring and bookkeeping metadata for one child module."""

        def __init__(self, in_branch: list[str], out_branch: list[str], pretrained : bool, alias : list[str], requires_grad: Union[bool, None], training: Union[None, bool]) -> None:
            super().__init__()
            self.alias= alias
            self.pretrained = pretrained
            self.in_branch = in_branch
            self.out_branch = out_branch
            # Channel information, filled in by Network._compute_channels_trace.
            self.in_channels = None
            self.in_is_channel = True
            self.out_channels = None
            self.out_is_channel = True
            self.requires_grad = requires_grad
            # Checkpoint flags, also set by Network._compute_channels_trace.
            self.isCheckpoint = False
            self.isGPU_Checkpoint = False
            # CUDA device index as a string, or "cpu" (see named_forward).
            self.gpu = "cpu"
            # None: always run; True: skipped in PREDICTION; False: skipped in TRAIN.
            self.training = training
            # Marks a child that writes one of its parent's output branches.
            self._isEnd = False

    def __init__(self) -> None:
        super().__init__()
        self._modulesArgs : dict[str, ModuleArgsDict.ModuleArgs] = dict()
        self._training = NetState.TRAIN

    def _addindent(self, s_: str, numSpaces : int):
        """Indent every line of ``s_`` except the first by ``numSpaces`` spaces."""
        s = s_.split('\n')
        if len(s) == 1:
            return s_
        first = s.pop(0)
        s = [(numSpaces * ' ') + line for line in s]
        s = '\n'.join(s)
        s = first + '\n' + s
        return s

    def __repr__(self):
        """Like ``torch.nn.Module.__repr__`` but also lists each child's ModuleArgs."""
        extra_lines = []

        extra_repr = self.extra_repr()
        if extra_repr:
            extra_lines = extra_repr.split('\n')

        def is_non_default_branch(branch: list[str]) -> bool:
            # Branch names are stored as strings (see add_module), so the default
            # wiring is ["0"]. The old comparison with the int 0 was always True.
            return len(branch) != 1 or branch[0] != "0"

        child_lines = []
        for key, module in self._modules.items():
            mod_str = repr(module)

            mod_str = self._addindent(mod_str, 2)
            args = self._modulesArgs[key]
            desc = ""
            if is_non_default_branch(args.in_branch) or is_non_default_branch(args.out_branch):
                desc += ", {}->{}".format(args.in_branch, args.out_branch)
            if not args.pretrained:
                desc += ", pretrained=False"
            if args.alias:
                desc += ", alias={}".format(args.alias)
            desc += ", in_channels={}".format(args.in_channels)
            desc += ", in_is_channel={}".format(args.in_is_channel)
            desc += ", out_channels={}".format(args.out_channels)
            desc += ", out_is_channel={}".format(args.out_is_channel)
            desc += ", is_end={}".format(args._isEnd)
            desc += ", isInCheckpoint={}".format(args.isCheckpoint)
            desc += ", isInGPU_Checkpoint={}".format(args.isGPU_Checkpoint)
            desc += ", requires_grad={}".format(args.requires_grad)
            desc += ", device={}".format(args.gpu)

            child_lines.append("({}{}) {}".format(key, desc, mod_str))

        lines = extra_lines + child_lines

        desc = ""
        if lines:
            if len(extra_lines) == 1 and not child_lines:
                desc += extra_lines[0]
            else:
                desc += '\n  ' + '\n  '.join(lines) + '\n'

        return "{}({})".format(self._get_name(), desc)

    def __getitem__(self, key: str) -> torch.nn.Module:
        """Return the child ``key``; raises AssertionError if it is None."""
        module = self._modules[key]
        # Test identity, not truthiness: containers such as an empty
        # ``torch.nn.Sequential`` define ``__len__`` and are falsy.
        assert module is not None, "Error {} is None".format(key)
        return module

    @_copy_to_script_wrapper
    def keys(self) -> Iterable[str]:
        return self._modules.keys()

    @_copy_to_script_wrapper
    def items(self) -> Iterable[tuple[str, Union[torch.nn.Module, None]]]:
        return self._modules.items()

    @_copy_to_script_wrapper
    def values(self) -> Iterable[Union[torch.nn.Module, None]]:
        return self._modules.values()

    def add_module(self, name: str, module : torch.nn.Module, in_branch: list[Union[int, str]] = [0], out_branch: list[Union[int, str]] = [0], pretrained : bool = True, alias : list[str] = [], requires_grad: Union[bool, None] = None, training: Union[None, bool] = None) -> None:
        """Register ``module`` under ``name`` together with its wiring metadata.

        Branch identifiers are normalised to strings. ``alias`` is copied so the
        shared default list (or a caller's list) is never stored and aliased
        between children.
        """
        super().add_module(name, module)
        self._modulesArgs[name] = ModuleArgsDict.ModuleArgs([str(value) for value in in_branch], [str(value) for value in out_branch], pretrained, list(alias), requires_grad, training)

    def getMap(self):
        """Return ``{alias path: module path}`` for this container and its nested containers."""
        results : dict[str, str] = {}
        for name, moduleArgs in self._modulesArgs.items():
            module = self[name]
            if isinstance(module, ModuleArgsDict):
                sub_map = module.getMap()
                if len(moduleArgs.alias):
                    # One alias per occurrence of each nested target path.
                    count = {k : 0 for k in set(sub_map.values())}
                    if len(count):
                        for k, v in sub_map.items():
                            alias_name = moduleArgs.alias[count[v]]
                            if k == "":
                                results.update({alias_name : name+"."+v})
                            else:
                                results.update({alias_name+"."+k : name+"."+v})
                            count[v]+=1
                    else:
                        for alias in moduleArgs.alias:
                            results.update({alias : name})
                else:
                    results.update({k : name+"."+v for k, v in sub_map.items()})
            else:
                for alias in moduleArgs.alias:
                    results[alias] = name
        return results

    @staticmethod
    def init_func(module: torch.nn.Module, init_type: str, init_gain: float):
        """Initialise conv/linear/batch-norm weights of ``module`` (Network instances are skipped).

        Raises:
            NotImplementedError: for an unknown ``init_type`` on a conv/linear layer.
        """
        if not isinstance(module, Network):
            if isinstance(module, ModuleArgsDict):
                module.init(init_type, init_gain)
            elif isinstance(module, torch.nn.modules.conv._ConvNd) or isinstance(module, torch.nn.Linear):
                if init_type == 'normal':
                    torch.nn.init.normal_(module.weight, 0.0, init_gain)
                elif init_type == 'xavier':
                    torch.nn.init.xavier_normal_(module.weight, gain=init_gain)
                elif init_type == 'kaiming':
                    torch.nn.init.kaiming_normal_(module.weight, a=0, mode='fan_in')
                elif init_type == 'orthogonal':
                    torch.nn.init.orthogonal_(module.weight, gain=init_gain)
                elif init_type == "trunc_normal":
                    torch.nn.init.trunc_normal_(module.weight, std=init_gain)
                else:
                    raise NotImplementedError('Initialization method {} is not implemented'.format(init_type))
                if module.bias is not None:
                    torch.nn.init.constant_(module.bias, 0.0)

            elif isinstance(module, torch.nn.modules.batchnorm._BatchNorm):
                if module.weight is not None:
                    torch.nn.init.normal_(module.weight, 0.0, std = init_gain)
                if module.bias is not None:
                    torch.nn.init.constant_(module.bias, 0.0)

    def init(self, init_type: str, init_gain: float):
        """Apply :meth:`init_func` to every direct child."""
        for module in self._modules.values():
            ModuleArgsDict.init_func(module, init_type, init_gain)

    def named_forward(self, *inputs: torch.Tensor) -> Iterator[tuple[str, torch.Tensor]]:
        """Run the children in order, yielding ``(name, output)`` after each one.

        Branch ``"i"`` initially holds ``inputs[i]``. A child reads its
        ``in_branch`` tensors (a missing branch defaults to ``inputs[0]``) and
        writes its output to every ``out_branch``. Nested containers yield their
        own layers with a ``"<name>."`` prefix.
        """
        if len(inputs) > 0:
            branchs: dict[str, torch.Tensor] = {}
            for i, sinput in enumerate(inputs):
                branchs[str(i)] = sinput

            out = inputs[0]
            tmp = []
            for name, module in self.items():
                if self._modulesArgs[name].training is None or (not (self._modulesArgs[name].training and self._training == NetState.PREDICTION) and not (not self._modulesArgs[name].training and self._training == NetState.TRAIN)):
                    requires_grad = self._modulesArgs[name].requires_grad
                    if requires_grad is not None and module:
                        module.requires_grad_(requires_grad)
                    for ib in self._modulesArgs[name].in_branch:
                        if ib not in branchs:
                            branchs[ib] = inputs[0]
                    # Move every branch to the child's GPU when one is assigned.
                    for branchs_key in branchs.keys():
                        if str(self._modulesArgs[name].gpu) != 'cpu' and str(branchs[branchs_key].device) != "cuda:"+self._modulesArgs[name].gpu:
                            branchs[branchs_key] = branchs[branchs_key].to(int(self._modulesArgs[name].gpu))

                    if self._modulesArgs[name].isCheckpoint:
                        out = checkpoint(module, *[branchs[i] for i in self._modulesArgs[name].in_branch], use_reentrant=True)
                        for ob in self._modulesArgs[name].out_branch:
                            branchs[ob] = out
                        yield name, out
                    else:
                        if isinstance(module, ModuleArgsDict):
                            for k, out in module.named_forward(*[branchs[i] for i in self._modulesArgs[name].in_branch]):
                                # Route the nested child's output to the out branches it declares
                                # (";accu;" is a prefix added by patch-wise Network.named_forward).
                                for ob in self._modulesArgs[name].out_branch:
                                    if ob in module._modulesArgs[k.split(".")[0].replace(";accu;", "")].out_branch:
                                        tmp.append(ob)
                                        branchs[ob] = out
                                yield name+"."+k, out
                            # Out branches not written explicitly receive the last output.
                            for ob in self._modulesArgs[name].out_branch:
                                if ob not in tmp:
                                    branchs[ob] = out
                        elif isinstance(module, torch.nn.Module):
                            out = module(*[branchs[i] for i in self._modulesArgs[name].in_branch])
                            for ob in self._modulesArgs[name].out_branch:
                                branchs[ob] = out
                            yield name, out
            del branchs

    def forward(self, *input: torch.Tensor) -> torch.Tensor:
        """Return the last output produced by :meth:`named_forward` (``input`` if none)."""
        v = input
        for k, v in self.named_forward(*input):
            pass
        return v

    def named_parameters(self, pretrained: bool = False, recurse=False) -> Iterator[tuple[str, torch.nn.parameter.Parameter]]:
        """Yield trainable parameters.

        With ``pretrained=True``, leaf children marked ``pretrained`` are skipped.
        Leaf children flagged ``training=False`` are always skipped.
        """
        for name, moduleArgs in self._modulesArgs.items():
            module = self[name]
            if isinstance(module, ModuleArgsDict):
                for k, v in module.named_parameters(pretrained = pretrained):
                    yield name+"."+k, v
            elif isinstance(module, torch.nn.Module):
                if not pretrained or not moduleArgs.pretrained:
                    if moduleArgs.training is None or moduleArgs.training:
                        for k, v in module.named_parameters():
                            yield name+"."+k, v

    def parameters(self, pretrained: bool = False):
        for _, v in self.named_parameters(pretrained = pretrained):
            yield v

    def named_ModuleArgsDict(self) -> "Iterator[tuple[str, Self, ModuleArgsDict.ModuleArgs]]":
        """Yield ``(dotted name, child, ModuleArgs)`` for all descendants, depth-first."""
        for name, module in self._modules.items():
            yield name, module, self._modulesArgs[name]
            if isinstance(module, ModuleArgsDict):
                for k, v, u in module.named_ModuleArgsDict():
                    yield name+"."+k, v, u

    def _requires_grad(self, keys: list[str]):
        """Apply each descendant's configured ``requires_grad``, stopping once all ``keys`` are seen."""
        keys = keys.copy()
        for name, module, args in self.named_ModuleArgsDict():
            requires_grad = args.requires_grad
            if requires_grad is not None:
                module.requires_grad_(requires_grad)
            if name in keys:
                keys.remove(name)
                if len(keys) == 0:
                    break
478
+
479
class OutputsGroup(list):
    """List of layer names tracked together, plus the tensors captured for them."""

    def __init__(self, measure: Measure) -> None:
        self.measure = measure
        self.layers: dict[str, torch.Tensor] = {}

    def addLayer(self, name: str, layer: torch.Tensor):
        """Store (or replace) the tensor captured for ``name``."""
        self.layers.update({name: layer})

    def isDone(self):
        """True when as many tensors were captured as there are names in the list."""
        return len(self.layers) == len(self)

    def clear(self):
        # Shadows list.clear: only the captured tensors are dropped; the list
        # of names itself is left intact.
        self.layers.clear()
493
+
494
+
495
+
496
+ class Network(ModuleArgsDict, ABC):
497
+
498
+ def _apply_network(self, name_function : Callable[[Self], str], networks: list[str], key: str, function: Callable, *args, **kwargs) -> dict[str, object]:
499
+ results : dict[str, object] = {}
500
+ for module in self.values():
501
+ if isinstance(module, Network):
502
+ if name_function(module) not in networks:
503
+ networks.append(name_function(module))
504
+ for k, v in module._apply_network(name_function, networks, key+"."+name_function(module), function, *args, **kwargs).items():
505
+ results.update({name_function(self)+"."+k : v})
506
+ if len([param.name for param in list(inspect.signature(function).parameters.values()) if param.name == "key"]):
507
+ function = partial(function, key=key)
508
+
509
+ results[name_function(self)] = function(self, *args, **kwargs)
510
+ return results
511
+
512
+ def _function_network(t : bool = False):
513
+ def _function_network_d(function : Callable):
514
+ def new_function(self : Self, *args, **kwargs) -> dict[str, object]:
515
+ return self._apply_network(lambda network: network._getName() if t else network.getName(), [], self.getName(), function, *args, **kwargs)
516
+ return new_function
517
+ return _function_network_d
518
+
519
+ def __init__( self,
520
+ in_channels : int = 1,
521
+ optimizer: Union[OptimizerLoader, None] = None,
522
+ schedulers: Union[LRSchedulersLoader, None] = None,
523
+ outputsCriterions: Union[dict[str, TargetCriterionsLoader], None] = None,
524
+ patch : Union[ModelPatch, None] = None,
525
+ nb_batch_per_step : int = 1,
526
+ init_type : str = "normal",
527
+ init_gain : float = 0.02,
528
+ dim : int = 3) -> None:
529
+ super().__init__()
530
+ self.name = self.__class__.__name__
531
+ self.in_channels = in_channels
532
+ self.optimizerLoader = optimizer
533
+ self.optimizer : Union[torch.optim.Optimizer, None] = None
534
+
535
+ self.LRSchedulersLoader = schedulers
536
+ self.schedulers : Union[dict[torch.optim.lr_scheduler._LRScheduler, int], None] = None
537
+
538
+ self.outputsCriterionsLoader = outputsCriterions
539
+ self.measure : Union[Measure, None] = None
540
+
541
+ self.patch = patch
542
+
543
+ self.nb_batch_per_step = nb_batch_per_step
544
+ self.init_type = init_type
545
+ self.init_gain = init_gain
546
+ self.dim = dim
547
+ self._it = 0
548
+ self.outputsGroup : list[OutputsGroup]= []
549
+
550
+ @_function_network(True)
551
+ def state_dict(self) -> dict[str, OrderedDict]:
552
+ destination = OrderedDict()
553
+ destination._metadata = OrderedDict()
554
+ destination._metadata[""] = local_metadata = dict(version=self._version)
555
+ self._save_to_state_dict(destination, "", False)
556
+ for name, module in self._modules.items():
557
+ if module is not None:
558
+ if not isinstance(module, Network):
559
+ module.state_dict(destination=destination, prefix="" + name + '.', keep_vars=False)
560
+ for hook in self._state_dict_hooks.values():
561
+ hook_result = hook(self, destination, "", local_metadata)
562
+ if hook_result is not None:
563
+ destination = hook_result
564
+ return destination
565
+
566
+ def load_state_dict(self, state_dict: dict[str, torch.Tensor]):
567
+ missing_keys: list[str] = []
568
+ unexpected_keys: list[str] = []
569
+ error_msgs: list[str] = []
570
+
571
+ metadata = getattr(state_dict, '_metadata', None)
572
+ state_dict = state_dict.copy()
573
+ if metadata is not None:
574
+ state_dict._metadata = metadata
575
+
576
+ def load(module: torch.nn.Module, prefix=''):
577
+ local_metadata = {} if metadata is None else metadata.get(prefix[:-1], {})
578
+ module._load_from_state_dict(
579
+ state_dict, prefix, local_metadata, True, missing_keys, unexpected_keys, error_msgs)
580
+ for name, child in module._modules.items():
581
+ if child is not None:
582
+ if not isinstance(child, Network):
583
+ if isinstance(child, torch.nn.modules.conv._ConvNd) or isinstance(module, torch.nn.Linear):
584
+
585
+ current_size = child.weight.shape[0]
586
+ last_size = state_dict[prefix + name+".weight"].shape[0]
587
+
588
+ if current_size != last_size:
589
+ print("Warning: The size of '{}' has changed from {} to {}. Please check for potential impacts".format(prefix + name, last_size, current_size))
590
+ ModuleArgsDict.init_func(child, self.init_type, self.init_gain)
591
+
592
+ with torch.no_grad():
593
+ child.weight[:last_size] = state_dict[prefix + name+".weight"]
594
+ if child.bias is not None:
595
+ child.bias[:last_size] = state_dict[prefix + name+".bias"]
596
+ return
597
+ load(child, prefix + name + '.')
598
+
599
+ load(self)
600
+ del load
601
+
602
+ if len(unexpected_keys) > 0:
603
+ error_msgs.insert(
604
+ 0, 'Unexpected key(s) in state_dict: {}. '.format(
605
+ ', '.join('"{}"'.format(k) for k in unexpected_keys)))
606
+ if len(missing_keys) > 0:
607
+ error_msgs.insert(
608
+ 0, 'Missing key(s) in state_dict: {}. '.format(
609
+ ', '.join('"{}"'.format(k) for k in missing_keys)))
610
+
611
+ if len(error_msgs) > 0:
612
+ raise RuntimeError('Error(s) in loading state_dict for {}:\n\t{}'.format(
613
+ self.__class__.__name__, "\n\t".join(error_msgs)))
614
+
615
+ def apply(self, fn: Callable[[torch.nn.Module], None]) -> None:
616
+ for module in self.children():
617
+ if not isinstance(module, Network):
618
+ module.apply(fn)
619
+ fn(self)
620
+
621
+ @_function_network(True)
622
+ def load(self, state_dict : dict[str, dict[str, torch.Tensor]], init: bool = True, ema : bool =False):
623
+ if init:
624
+ self.apply(partial(ModuleArgsDict.init_func, init_type=self.init_type, init_gain=self.init_gain))
625
+ name = "Model" + ("_EMA" if ema else "")
626
+ if name in state_dict:
627
+ model_state_dict_tmp = {k.split(".")[-1] : v for k, v in state_dict[name].items()}[self._getName()]
628
+ map = self.getMap()
629
+ model_state_dict : OrderedDict[str, torch.Tensor] = OrderedDict()
630
+
631
+ for alias in model_state_dict_tmp.keys():
632
+ prefix = ".".join(alias.split(".")[:-1])
633
+ alias_list = [(".".join(prefix.split(".")[:len(i.split("."))]), v) for i, v in map.items() if prefix.startswith(i)]
634
+
635
+ if len(alias_list):
636
+ for a, b in alias_list:
637
+ model_state_dict[alias.replace(a, b)] = model_state_dict_tmp[alias]
638
+ break
639
+ else:
640
+ model_state_dict[alias] = model_state_dict_tmp[alias]
641
+ self.load_state_dict(model_state_dict)
642
+
643
+ if "{}_optimizer_state_dict".format(name) in state_dict and self.optimizer:
644
+ self.optimizer.load_state_dict(state_dict['{}_optimizer_state_dict'.format(name)])
645
+ self.initialized()
646
+
647
+ def _compute_channels_trace(self, module : ModuleArgsDict, in_channels : int, gradient_checkpoints: Union[list[str], None], gpu_checkpoints: Union[list[str], None], name: Union[str, None] = None, in_is_channel = True, out_channels : Union[int, None] = None, out_is_channel = True) -> tuple[int, bool, int, bool]:
648
+
649
+ for k1, v1 in module.items():
650
+ if isinstance(v1, ModuleArgsDict):
651
+ for t in module._modulesArgs[k1].out_branch:
652
+ last = None
653
+ for k2, _ in v1.items():
654
+ if t in v1._modulesArgs[k2].out_branch:
655
+ last = k2
656
+ if last is not None:
657
+ v1._modulesArgs[last]._isEnd = True
658
+ else:
659
+ v1._modulesArgs[k2]._isEnd = True
660
+
661
+ for k, v in module.items():
662
+ if hasattr(v, "in_channels"):
663
+ if v.in_channels:
664
+ in_channels = v.in_channels
665
+ if hasattr(v, "in_features"):
666
+ if v.in_features:
667
+ in_channels = v.in_features
668
+ key = name+"."+k if name else k
669
+
670
+ if gradient_checkpoints:
671
+ if key in gradient_checkpoints:
672
+ module._modulesArgs[k].isCheckpoint = True
673
+
674
+ if gpu_checkpoints:
675
+ if key in gpu_checkpoints:
676
+ module._modulesArgs[k].isGPU_Checkpoint = True
677
+
678
+ module._modulesArgs[k].in_channels = in_channels
679
+ module._modulesArgs[k].in_is_channel = in_is_channel
680
+
681
+ if isinstance(v, ModuleArgsDict):
682
+ in_channels, in_is_channel, out_channels, out_is_channel = self._compute_channels_trace(v, in_channels, gradient_checkpoints, gpu_checkpoints, key, in_is_channel, out_channels, out_is_channel)
683
+
684
+ if v.__class__.__name__ == "ToChannels":
685
+ out_is_channel = True
686
+
687
+ if v.__class__.__name__ == "ToFeatures":
688
+ out_is_channel = False
689
+
690
+ if hasattr(v, "out_channels"):
691
+ if v.out_channels:
692
+ out_channels = v.out_channels
693
+ if hasattr(v, "out_features"):
694
+ if v.out_features:
695
+ out_channels = v.out_features
696
+
697
+ module._modulesArgs[k].out_channels = out_channels
698
+ module._modulesArgs[k].out_is_channel = out_is_channel
699
+
700
+ in_channels = out_channels
701
+ in_is_channel = out_is_channel
702
+
703
+ return in_channels, in_is_channel, out_channels, out_is_channel
704
+
705
+ @_function_network()
706
+ def init(self, autocast : bool, state : State, key: str) -> None:
707
+ if self.outputsCriterionsLoader:
708
+ self.measure = Measure(key, self.outputsCriterionsLoader)
709
+ self.measure.init(self)
710
+ if state != State.PREDICTION:
711
+ self.scaler = torch.amp.GradScaler("cuda", enabled=autocast)
712
+ if self.optimizerLoader:
713
+ self.optimizer = self.optimizerLoader.getOptimizer(key, self.parameters(state == State.TRANSFER_LEARNING))
714
+ self.optimizer.zero_grad()
715
+
716
+ if self.LRSchedulersLoader and self.optimizer:
717
+ self.schedulers = self.LRSchedulersLoader.getShedulers(key, self.optimizer)
718
+
719
+ def initialized(self):
720
+ pass
721
+
722
    def named_forward(self, *inputs: torch.Tensor) -> Iterator[tuple[str, torch.Tensor]]:
        """Yield ``(layer_name, output)`` pairs, optionally processing the inputs patch by patch.

        Without ``self.patch`` this is ``ModuleArgsDict.named_forward``.

        With a patch, the inputs are split by ``self.patch.disassemble``. Every
        layer output of every patch is yielded with a ``";accu;"`` name prefix.
        The outputs of top-level children flagged ``_isEnd`` are collected into
        one ``Accumulator`` per child. After all patches, each accumulator's
        ``assemble()`` result is yielded under the child's name.
        """
        if self.patch:
            # Prepare patching for inputs[0].shape[2:] (presumably the spatial
            # dims after batch and channel; confirm).
            self.patch.load(inputs[0].shape[2:])
            accumulators: dict[str, Accumulator] = {}

            patchIterator = self.patch.disassemble(*inputs)
            # Two-entry window of (top-level child name, output). When the next
            # output belongs to a different child, the previous entry was that
            # child's last output.
            buffer = []
            for i, patch_input in enumerate(patchIterator):
                for (name, output_layer) in super().named_forward(*patch_input):
                    yield "{}{}".format(";accu;", name), output_layer
                    buffer.append((name.split(".")[0], output_layer))
                    if len(buffer) == 2:
                        if buffer[0][0] != buffer[1][0]:
                            if self._modulesArgs[buffer[0][0]]._isEnd:
                                if buffer[0][0] not in accumulators:
                                    accumulators[buffer[0][0]] = Accumulator(self.patch.getPatch_slices(), self.patch.patch_size, self.patch.patchCombine)
                                accumulators[buffer[0][0]].addLayer(i, buffer[0][1])
                        buffer.pop(0)
                # The patch's final output has no successor within the patch; flush it.
                # NOTE(review): ``buffer`` is not cleared here. On the next patch this
                # entry is compared with that patch's first output; if their child
                # names differ, it is added again under the *next* index ``i``. That is
                # harmless only if ``Accumulator.addLayer`` overwrites by index (confirm).
                # The original indentation of this statement was lost in the source
                # dump; it is assumed to run once per patch.
                if self._modulesArgs[buffer[0][0]]._isEnd:
                    if buffer[0][0] not in accumulators:
                        accumulators[buffer[0][0]] = Accumulator(self.patch.getPatch_slices(), self.patch.patch_size, self.patch.patchCombine)
                    accumulators[buffer[0][0]].addLayer(i, buffer[0][1])
            for name, accumulator in accumulators.items():
                yield name, accumulator.assemble()
        else:
            for (name, output_layer) in super().named_forward(*inputs):
                yield name, output_layer
749
+
750
+
751
def get_layers(self, inputs : list[torch.Tensor], layers_name: list[str]) -> Iterator[tuple[str, torch.Tensor, Union[Patch_Indexed, None]]]:
    """Run ``named_forward`` on ``inputs`` and yield the outputs whose names are requested.

    ``named_forward`` marks outputs produced patch by patch with ``;accu;``. Inside the loop,
    ``nameTmp`` is the raw name and ``name`` is the name with that marker removed.

    * Requesting the raw (marked) name accumulates the per-patch outputs in an ``Accumulator``
      and yields ``(nameTmp, assembled_tensor, None)`` once the accumulator reports full.
    * Requesting the stripped name of a marked output yields each per-patch output as
      ``(name, tensor, Patch_Indexed)``. The ``Patch_Indexed.index`` is advanced after each
      yield, and the name is finished when ``isFull()`` is true.
    * Requesting an unmarked output yields ``(name, tensor, None)`` once.

    Finished names are removed from a working copy of ``layers_name``. Iteration stops as
    soon as that copy is empty.

    Args:
        inputs: tensors forwarded positionally to ``self.named_forward``.
        layers_name: names to collect. The list is copied, so the caller's list is untouched.
    """
    layers_name = layers_name.copy()
    # Per-layer accumulators and patch cursors for outputs that are produced patch by patch.
    output_layer_accumulator : dict[str, Accumulator] = {}
    output_layer_patch_indexed : dict[str, Patch_Indexed] = {}
    # NOTE(review): `it` is incremented below but never read.
    it = 0
    debug = "DL_API_DEBUG" in os.environ
    for (nameTmp, output_layer) in self.named_forward(*inputs):
        name = nameTmp.replace(";accu;", "")
        if debug:
            # Record "name:gpu_memory:device" for this layer, appended with "|" to the existing trace.
            if "DL_API_DEBUG_LAST_LAYER" in os.environ:
                os.environ["DL_API_DEBUG_LAST_LAYER"] = "{}|{}:{}:{}".format(os.environ["DL_API_DEBUG_LAST_LAYER"], name, getGPUMemory(output_layer.device), str(output_layer.device).replace("cuda:", ""))
            else:
                os.environ["DL_API_DEBUG_LAST_LAYER"] = "{}:{}:{}".format(name, getGPUMemory(output_layer.device), str(output_layer.device).replace("cuda:", ""))
        it += 1
        if name in layers_name or nameTmp in layers_name:
            if ";accu;" in nameTmp:
                if name not in output_layer_patch_indexed:
                    # Locate the patch-based network that produced this output. Its name is the last
                    # dotted segment before the final ";accu;" marker; an empty segment means `self`.
                    networkName = nameTmp.split(".;accu;")[-2].split(".")[-1] if ".;accu;" in nameTmp else nameTmp.split(";accu;")[-2].split(".")[-1]
                    module = self
                    network = None
                    if networkName == "":
                        network = module
                    else:
                        for n in name.split("."):
                            module = module[n]
                            if isinstance(module, Network) and n == networkName:
                                network = module
                                break

                    # NOTE(review): if no matching Network with a patch is found, no cursor is created
                    # and the output_layer_patch_indexed[name] lookups below raise KeyError.
                    if network and network.patch:
                        output_layer_patch_indexed[name] = Patch_Indexed(network.patch, 0)

                if name not in output_layer_accumulator:
                    output_layer_accumulator[name] = Accumulator(output_layer_patch_indexed[name].patch.getPatch_slices(0), output_layer_patch_indexed[name].patch.patch_size, output_layer_patch_indexed[name].patch.patchCombine)

                # Marked name requested: collect per-patch outputs and yield the assembled tensor when complete.
                if nameTmp in layers_name:
                    output_layer_accumulator[name].addLayer(output_layer_patch_indexed[name].index, output_layer)
                    output_layer_patch_indexed[name].index += 1
                    if output_layer_accumulator[name].isFull():
                        output_layer = output_layer_accumulator[name].assemble()
                        output_layer_accumulator.pop(name)
                        output_layer_patch_indexed.pop(name)
                        layers_name.remove(nameTmp)
                        yield nameTmp, output_layer, None

            if name in layers_name:
                if ";accu;" in nameTmp:
                    # Stripped name of a patch output: yield each patch with its cursor.
                    # NOTE(review): if the marked name was requested too, the same cursor is advanced
                    # twice per patch, and after the pop above this lookup raises KeyError.
                    yield name, output_layer, output_layer_patch_indexed[name]
                    output_layer_patch_indexed[name].index += 1
                    if output_layer_patch_indexed[name].isFull():
                        output_layer_patch_indexed.pop(name)
                        layers_name.remove(name)
                else:
                    layers_name.remove(name)
                    yield name, output_layer, None

        # Every requested layer has been produced; stop running the forward pass.
        if not len(layers_name):
            break
809
+
810
+
811
+
812
def init_outputsGroup(self):
    """Build one ``OutputsGroup`` per measured output and append it to ``self.outputsGroup``.

    Each group is tied to its network's measure. It starts with the output name, followed by
    every criterion target name containing ``":"``, converted to dotted form (``":"`` -> ``"."``).
    Networks without a measure are ignored, and a measure shared by several networks is only
    processed once.
    """
    criteria_by_measure = {
        network.measure: network.measure.outputsCriterions.keys()
        for network in self.getNetworks().values()
        if network.measure
    }
    for measure, output_names in criteria_by_measure.items():
        for output_name in output_names:
            group = OutputsGroup(measure)
            group.append(output_name)
            layer_targets = (
                target.replace(":", ".")
                for target in measure.outputsCriterions[output_name].keys()
                if ":" in target
            )
            for layer_target in layer_targets:
                group.append(layer_target)
            self.outputsGroup.append(group)
823
+
824
def forward(self, data_dict: dict[tuple[str, bool], torch.Tensor], output_layers: Union[list[str], None] = None) -> list[tuple[str, torch.Tensor]]:
    """Run the model, feed the measures, and return the requested layer outputs.

    ``data_dict`` is keyed by ``(name, flag)`` tuples. Entries whose flag is true are passed to
    the network as inputs, and every entry (keyed by ``name``) is offered to the measures as a
    target. Each produced layer that belongs to an ``OutputsGroup`` in ``self.outputsGroup`` is
    added to that group. Once a group reports ``isDone()``, its measure is updated and the group
    is cleared.

    Args:
        data_dict: tensors keyed by ``(name, flag)``.
        output_layers: names of layers to return. ``None`` (the default) means none. A fresh
            list is used on every call instead of a shared mutable default.

    Returns:
        ``(name, tensor)`` pairs, in production order, for every layer whose name is in
        ``output_layers``. Returns ``[]`` without running anything when there are neither
        output groups nor requested layers.
    """
    requested = [] if output_layers is None else output_layers
    if len(self.outputsGroup) == 0 and len(requested) == 0:
        return []

    self.resetLoss()
    measured_layers = {name for group in self.outputsGroup for name in group}
    layers_to_fetch = list(measured_layers.union(requested))
    network_inputs = [value for key, value in data_dict.items() if key[1]]

    results = []
    for name, layer, patch_indexed in self.get_layers(network_inputs, layers_to_fetch):
        matching_groups = [group for group in self.outputsGroup if name in group]
        if matching_groups:
            if patch_indexed is None:
                targets = {key[0]: value for key, value in data_dict.items()}
                nb = 1
            else:
                # NOTE(review): presumably extracts the current patch of each tensor — TODO confirm getData semantics.
                targets = {key[0]: patch_indexed.patch.getData(value, patch_indexed.index, 0, False) for key, value in data_dict.items()}
                nb = patch_indexed.patch.getSize(0)

            for group in matching_groups:
                group.addLayer(name, layer)
                if group.isDone():
                    # The group's other layers become extra targets, keyed back in ":" form.
                    targets.update({k.replace(".", ":"): v for k, v in group.layers.items() if k != group[0]})
                    group.measure.update(group[0], group.layers[group[0]], targets, self._it, nb, self.training)
                    group.clear()
        if name in requested:
            results.append((name, layer))
    return results
854
+
855
@_function_network()
def resetLoss(self):
    """Reset the losses held by this network's measure; does nothing when there is no measure."""
    measure = self.measure
    if not measure:
        return
    measure.resetLoss()
859
+
860
@_function_network()
def backward(self, model: Self):
    """Back-propagate this network's losses and periodically apply an optimizer step.

    Nothing happens without a measure. The gradient work additionally requires both
    ``self.scaler`` and ``self.optimizer``. Each loss is divided by ``nb_batch_per_step`` before
    its scaled backward pass, so the gradients accumulated over that many calls sum to an
    average. Whenever ``self._it`` is a multiple of ``nb_batch_per_step``, the optimizer is
    stepped through the GradScaler, the scaler is updated and the gradients are cleared.
    ``self._it`` is then incremented (when a measure exists).

    Args:
        model: model whose ``_requires_grad`` is called with the names of this network's
            measured outputs before the backward passes.
    """
    if self.measure:
        if self.scaler and self.optimizer:
            model._requires_grad(list(self.measure.outputsCriterions.keys()))
            for loss in self.measure.getLoss():
                self.scaler.scale(loss / self.nb_batch_per_step).backward()
            # NOTE(review): if self._it starts at 0, the first step fires after a single call
            # instead of after nb_batch_per_step calls — TODO confirm this is intended.
            if self._it % self.nb_batch_per_step == 0:
                self.scaler.step(self.optimizer)
                self.scaler.update()
                self.optimizer.zero_grad(set_to_none=True)
        self._it += 1
872
+
873
@_function_network()
def update_lr(self):
    """Step the learning-rate scheduler that is active for the current ``self._it``.

    ``self.schedulers`` maps each scheduler to the number of ``_it`` units it stays active
    (``None`` means open-ended). The windows are laid end to end in mapping order. If no window
    contains ``self._it``, the last scheduler in the mapping is used.

    A ``ReduceLROnPlateau`` scheduler is stepped with the sum of the measure's last values, and
    skipped when there is no measure. Every other scheduler is stepped without arguments.
    """
    scheduler = None
    window_start = 0
    if self.schedulers:
        for scheduler, duration in self.schedulers.items():
            if duration is None or window_start <= self._it < window_start + duration:
                break
            window_start += duration
    if not scheduler:
        return
    if scheduler.__class__.__name__ == 'ReduceLROnPlateau':
        # The plateau metric is computed once and passed to step(); no stdout output.
        if self.measure:
            plateau_metric = sum(self.measure.getLastValues(0).values())
            scheduler.step(plateau_metric)
    else:
        scheduler.step()
889
+
890
@_function_network()
def getNetworks(self) -> Self:
    """Return this network.

    The ``_function_network`` decorator presumably gathers the results over sub-networks —
    callers in this file use the result through ``.values()``.
    """
    network = self
    return network
893
+
894
def to(module : ModuleArgsDict, device: int):
    """Assign devices to the children of ``module`` and move them there.

    Declared without ``self``: it is called as ``Network.to(child, index)`` for nested
    ``ModuleArgsDict`` children. The running device index lives in ``os.environ["device"]``. It
    is seeded from ``device`` only when that variable is not already set, so later calls
    continue from the stored index and ignore ``device``.

    For every child whose ``_modulesArgs[k].gpu`` is still ``"cpu"``:

    * a child flagged ``isGPU_Checkpoint`` first advances the stored index by one, so it and the
      following children use the next device;
    * the child's ``gpu`` is set to ``str(getDevice(index))``;
    * nested ``ModuleArgsDict`` children are processed recursively, and any other child is moved
      with ``.to(getDevice(index))``.

    Returns:
        ``module`` itself.
    """
    if "device" not in os.environ:
        os.environ["device"] = str(device)
    for k, v in module.items():
        if module._modulesArgs[k].gpu == "cpu":
            if module._modulesArgs[k].isGPU_Checkpoint:
                os.environ["device"] = str(int(os.environ["device"])+1)
            module._modulesArgs[k].gpu = str(getDevice(int(os.environ["device"])))
            # NOTE(review): rebinding `v` does not update `module`; this relies on the recursion and
            # .to() acting in place (true for torch.nn.Module, presumably the type of every child).
            if isinstance(v, ModuleArgsDict):
                v = Network.to(v, int(os.environ["device"]))
            else:
                v = v.to(getDevice(int(os.environ["device"])))
    return module
907
+
908
def getName(self) -> str:
    """Return the name of this object's class."""
    cls = self.__class__
    return cls.__name__
910
+
911
+ def setName(self, name: str) -> Self:
912
+ self.name = name
913
+ return self
914
+
915
def _getName(self) -> str:
    """Return the ``name`` attribute (as stored by ``setName``)."""
    stored_name = self.name
    return stored_name
917
+
918
def setState(self, state: NetState):
    """Set ``_training`` to ``state`` on every ``ModuleArgsDict`` yielded by ``self.modules()``."""
    module_dicts = (candidate for candidate in self.modules() if isinstance(candidate, ModuleArgsDict))
    for module_dict in module_dicts:
        module_dict._training = state
922
+
923
class ModelLoader():
    """Resolve a model class from a dotted classpath and instantiate it."""

    @config("Model")
    def __init__(self, classpath : str = "default:segmentation.UNet") -> None:
        """Split ``classpath`` into a module path and a class name and resolve them with ``_getModule``.

        ``"pkg.sub.Class"`` is split at the last dot into ``("pkg.sub", "Class")``. A name without
        a dot gets an empty module path.
        """
        module_path, separator, class_name = classpath.rpartition(".")
        if not separator:
            module_path, class_name = "", classpath
        self.module, self.name = _getModule(class_name, module_path)

    def getModel(self, train : bool = True, DL_args: Union[str, None] = None, DL_without: Union[list[str], None] = None) -> Network:
        """Import ``self.module`` and instantiate ``self.name`` with ``config=None`` and ``DL_args``.

        Args:
            train: when false, ``DL_without`` is also passed to the constructor.
            DL_args: configuration path. Falsy values are replaced by
                ``"<DEEP_LEARNING_API_ROOT()>.Model"``.
            DL_without: names passed as ``DL_without`` when ``train`` is false. ``None`` (the
                default) selects a fresh list of the optimizer, scheduler, batch-step and
                initialisation settings, instead of a shared mutable default.

        Returns:
            The constructed model.
        """
        if DL_without is None:
            DL_without = ["optimizer", "schedulers", "nb_batch_per_step", "init_type", "init_gain"]
        if not DL_args:
            DL_args = "{}.Model".format(DEEP_LEARNING_API_ROOT())
        model = partial(getattr(importlib.import_module(self.module), self.name), config = None, DL_args=DL_args)
        if not train:
            model = partial(model, DL_without = DL_without)
        return model()
936
+
937
class CPU_Model():
    """Wrapper pairing a Network with ``torch.device('cpu')`` and forwarding calls to it."""

    def __init__(self, model: Network) -> None:
        self.module = model
        self.device = torch.device('cpu')

    def train(self):
        """Put the wrapped network in training mode."""
        self.module.train()

    def eval(self):
        """Put the wrapped network in evaluation mode."""
        self.module.eval()

    def __call__(self, data_dict: dict[tuple[str, bool], torch.Tensor], output_layers: Union[list[str], None] = None) -> list[tuple[str, torch.Tensor]]:
        """Call the wrapped network with ``data_dict`` and the requested output layer names.

        ``output_layers`` defaults to a fresh empty list on every call instead of a shared
        mutable default.
        """
        if output_layers is None:
            output_layers = []
        return self.module(data_dict, output_layers)