konfai 1.1.4.tar.gz → 1.1.6.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of konfai might be problematic.

Files changed (47)
  1. {konfai-1.1.4 → konfai-1.1.6}/PKG-INFO +1 -1
  2. {konfai-1.1.4 → konfai-1.1.6}/konfai/data/augmentation.py +1 -1
  3. {konfai-1.1.4 → konfai-1.1.6}/konfai/data/data_manager.py +2 -2
  4. {konfai-1.1.4 → konfai-1.1.6}/konfai/data/patching.py +14 -17
  5. {konfai-1.1.4 → konfai-1.1.6}/konfai/data/transform.py +21 -20
  6. {konfai-1.1.4 → konfai-1.1.6}/konfai/evaluator.py +1 -1
  7. {konfai-1.1.4 → konfai-1.1.6}/konfai/main.py +0 -2
  8. {konfai-1.1.4 → konfai-1.1.6}/konfai/metric/measure.py +37 -17
  9. {konfai-1.1.4 → konfai-1.1.6}/konfai/metric/schedulers.py +10 -1
  10. konfai-1.1.6/konfai/models/segmentation/NestedUNet.py +113 -0
  11. {konfai-1.1.4 → konfai-1.1.6}/konfai/models/segmentation/UNet.py +0 -1
  12. {konfai-1.1.4 → konfai-1.1.6}/konfai/network/blocks.py +32 -13
  13. {konfai-1.1.4 → konfai-1.1.6}/konfai/network/network.py +36 -14
  14. {konfai-1.1.4 → konfai-1.1.6}/konfai/predictor.py +49 -27
  15. {konfai-1.1.4 → konfai-1.1.6}/konfai/trainer.py +7 -5
  16. {konfai-1.1.4 → konfai-1.1.6}/konfai/utils/utils.py +6 -1
  17. {konfai-1.1.4 → konfai-1.1.6}/konfai.egg-info/PKG-INFO +1 -1
  18. {konfai-1.1.4 → konfai-1.1.6}/pyproject.toml +1 -1
  19. konfai-1.1.4/konfai/models/segmentation/NestedUNet.py +0 -52
  20. {konfai-1.1.4 → konfai-1.1.6}/LICENSE +0 -0
  21. {konfai-1.1.4 → konfai-1.1.6}/README.md +0 -0
  22. {konfai-1.1.4 → konfai-1.1.6}/konfai/__init__.py +0 -0
  23. {konfai-1.1.4 → konfai-1.1.6}/konfai/data/__init__.py +0 -0
  24. {konfai-1.1.4 → konfai-1.1.6}/konfai/metric/__init__.py +0 -0
  25. {konfai-1.1.4 → konfai-1.1.6}/konfai/models/classification/convNeXt.py +0 -0
  26. {konfai-1.1.4 → konfai-1.1.6}/konfai/models/classification/resnet.py +0 -0
  27. {konfai-1.1.4 → konfai-1.1.6}/konfai/models/generation/cStyleGan.py +0 -0
  28. {konfai-1.1.4 → konfai-1.1.6}/konfai/models/generation/ddpm.py +0 -0
  29. {konfai-1.1.4 → konfai-1.1.6}/konfai/models/generation/diffusionGan.py +0 -0
  30. {konfai-1.1.4 → konfai-1.1.6}/konfai/models/generation/gan.py +0 -0
  31. {konfai-1.1.4 → konfai-1.1.6}/konfai/models/generation/vae.py +0 -0
  32. {konfai-1.1.4 → konfai-1.1.6}/konfai/models/registration/registration.py +0 -0
  33. {konfai-1.1.4 → konfai-1.1.6}/konfai/models/representation/representation.py +0 -0
  34. {konfai-1.1.4 → konfai-1.1.6}/konfai/network/__init__.py +0 -0
  35. {konfai-1.1.4 → konfai-1.1.6}/konfai/utils/ITK.py +0 -0
  36. {konfai-1.1.4 → konfai-1.1.6}/konfai/utils/__init__.py +0 -0
  37. {konfai-1.1.4 → konfai-1.1.6}/konfai/utils/config.py +0 -0
  38. {konfai-1.1.4 → konfai-1.1.6}/konfai/utils/dataset.py +0 -0
  39. {konfai-1.1.4 → konfai-1.1.6}/konfai/utils/registration.py +0 -0
  40. {konfai-1.1.4 → konfai-1.1.6}/konfai.egg-info/SOURCES.txt +0 -0
  41. {konfai-1.1.4 → konfai-1.1.6}/konfai.egg-info/dependency_links.txt +0 -0
  42. {konfai-1.1.4 → konfai-1.1.6}/konfai.egg-info/entry_points.txt +0 -0
  43. {konfai-1.1.4 → konfai-1.1.6}/konfai.egg-info/requires.txt +0 -0
  44. {konfai-1.1.4 → konfai-1.1.6}/konfai.egg-info/top_level.txt +0 -0
  45. {konfai-1.1.4 → konfai-1.1.6}/setup.cfg +0 -0
  46. {konfai-1.1.4 → konfai-1.1.6}/tests/test_config.py +0 -0
  47. {konfai-1.1.4 → konfai-1.1.6}/tests/test_dataset.py +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: konfai
-Version: 1.1.4
+Version: 1.1.6
 Summary: Modular and configurable Deep Learning framework with YAML and PyTorch
 Author-email: Valentin Boussot <boussot.v@gmail.com>
 License-Expression: Apache-2.0
@@ -57,7 +57,7 @@ class DataAugmentationsList():
 
     def load(self, key: str, datasets: list[Dataset]):
         for augmentation, prob in self.dataAugmentationsLoader.items():
-            module, name = _getModule(augmentation, "data.augmentation")
+            module, name = _getModule(augmentation, "konfai.data.augmentation")
             dataAugmentation: DataAugmentation = config("{}.Dataset.augmentations.{}.dataAugmentations.{}".format(KONFAI_ROOT(), key, augmentation))(getattr(importlib.import_module(module), name))(config = None)
             dataAugmentation.load(prob.prob)
             dataAugmentation.setDatasets(datasets)
@@ -129,7 +129,7 @@ class DatasetIter(data.Dataset):
             for group_dest in self.groups_src[group_src]:
                 self.data[group_dest][index].unloadAugmentation()
                 self.data[group_dest][index].resetAugmentation()
-        self.load(label + " Augmentation")
+        self.load(label + " Augmentation")
 
     def load(self, label: str):
         if self.use_cache:
@@ -184,7 +184,7 @@ class DatasetIter(data.Dataset):
         data = {}
         x, a, p = self.map[index]
         if x not in self._index_cache:
-            if x not in self._index_cache and len(self._index_cache) >= self.buffer_size and not self.use_cache:
+            if len(self._index_cache) >= self.buffer_size and not self.use_cache:
                 self._unloadData(self._index_cache[0])
             self._loadData(x)
 
@@ -125,23 +125,20 @@ class Accumulator():
         return len(self.patch_slices) == len([v for v in self._layer_accumulator if v is not None])
 
     def assemble(self) -> torch.Tensor:
-        if all([self._layer_accumulator[0].shape[i-len(self.patch_size)] == size for i, size in enumerate(self.patch_size)]):
-            N = 2 if self.batch else 1
-
-            result = torch.zeros((list(self._layer_accumulator[0].shape[:N])+list(max([[v.stop for v in patch] for patch in self.patch_slices]))), dtype=self._layer_accumulator[0].dtype).to(self._layer_accumulator[0].device)
-            for patch_slice, data in zip(self.patch_slices, self._layer_accumulator):
-                slices_dest = tuple([slice(result.shape[i]) for i in range(N)] + list(patch_slice))
-
-                for dim, s in enumerate(patch_slice):
-                    if s.stop-s.start == 1:
-                        data = data.unsqueeze(dim=dim+N)
-                if self.patchCombine is not None:
-                    result[slices_dest] += self.patchCombine(data)
-                else:
-                    result[slices_dest] = data
-            result = result[tuple([slice(None, None)]+[slice(0, s) for s in self.shape])]
-        else:
-            result = torch.cat(tuple(self._layer_accumulator), dim=0)
+        N = 2 if self.batch else 1
+        result = torch.zeros((list(self._layer_accumulator[0].shape[:N])+list(max([[v.stop for v in patch] for patch in self.patch_slices]))), dtype=self._layer_accumulator[0].dtype).to(self._layer_accumulator[0].device)
+        for patch_slice, data in zip(self.patch_slices, self._layer_accumulator):
+            slices_dest = tuple([slice(result.shape[i]) for i in range(N)] + list(patch_slice))
+
+            for dim, s in enumerate(patch_slice):
+                if s.stop-s.start == 1:
+                    data = data.unsqueeze(dim=dim+N)
+            if self.patchCombine is not None:
+                result[slices_dest] += self.patchCombine(data)
+            else:
+                result[slices_dest] = data
+        result = result[tuple([slice(None, None)]+[slice(0, s) for s in self.shape])]
+
         self._layer_accumulator.clear()
         return result
 
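Reviewer note: assemble() now always stitches patches back onto a zero-initialized canvas; the old torch.cat fallback branch for patches that did not match patch_size is gone. A minimal standalone sketch of the stitching idea, with hypothetical 1-D patches and no batch dimension or patchCombine weighting:

    import torch

    patch_slices = [(slice(0, 2),), (slice(2, 4),)]   # two 1-D patches covering [0, 4)
    patches = [torch.ones(2), torch.ones(2) * 2]

    out = torch.zeros(max(s[0].stop for s in patch_slices))
    for slc, data in zip(patch_slices, patches):
        out[slc] = data                               # write each patch into its slot
    print(out)                                        # tensor([1., 1., 2., 2.])
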
@@ -36,7 +36,7 @@ class TransformLoader:
         pass
 
     def getTransform(self, classpath : str, DL_args : str) -> Transform:
-        module, name = _getModule(classpath, "data.transform")
+        module, name = _getModule(classpath, "konfai.data.transform")
         return config("{}.{}".format(DL_args, classpath))(getattr(importlib.import_module(module), name))(config = None)
 
 class Clip(Transform):
@@ -213,7 +213,7 @@ class Resample(Transform, ABC):
 
 class ResampleToResolution(Resample):
 
-    def __init__(self, spacing : list[Union[float, None]] = [1., 1., 1.]) -> None:
+    def __init__(self, spacing : list[float] = [1., 1., 1.]) -> None:
        self.spacing = torch.tensor([0 if s < 0 else s for s in spacing])
 
    def transformShape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
@@ -244,34 +244,35 @@ class ResampleToResolution(Resample):
         cache_attribute["Size"] = np.asarray(size)
         return self._resample(input, size)
 
-class ResampleToSize(Resample):
+class ResampleToShape(Resample):
 
-    def __init__(self, size : list[int] = [100,512,512]) -> None:
-        self.size = size
+    def __init__(self, shape : list[float] = [100,256,256]) -> None:
+        self.shape = torch.tensor([0 if s < 0 else s for s in shape])
 
     def transformShape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
         if "Spacing" not in cache_attribute:
             TransformError("Missing 'Spacing' in cache attributes, the data is likely not a valid image.",
                            "Make sure your input is a image (e.g., .nii, .mha) with proper metadata.")
-        if len(shape) != len(self.size):
+        if len(shape) != len(self.shape):
             TransformError("Shape and spacing dimensions do not match: shape={shape}, spacing={self.spacing}")
-        size = self.size
-        for i, s in enumerate(self.size):
-            if s == -1:
-                size[i] = shape[i]
-        return size
+        new_shape = self.shape
+        for i, s in enumerate(self.shape):
+            if s == 0:
+                new_shape[i] = shape[i]
+        print(new_shape)
+        return new_shape
 
     def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        size = self.size
-        image_size = np.asarray([int(x) for x in torch.tensor(input.shape[1:])])
-        for i, s in enumerate(self.size):
-            if s is None:
-                size[i] = image_size[i]
+        shape = self.shape
+        image_shape = torch.tensor([int(x) for x in torch.tensor(input.shape[1:])])
+        for i, s in enumerate(self.shape):
+            if s == 0:
+                shape[i] = image_shape[i]
         if "Spacing" in cache_attribute:
-            cache_attribute["Spacing"] = torch.flip(torch.tensor(image_size)/torch.tensor(size)*torch.flip(cache_attribute.get_tensor("Spacing"), dims=[0]), dims=[0])
-        cache_attribute["Size"] = image_size
-        cache_attribute["Size"] = size
-        return self._resample(input, size)
+            cache_attribute["Spacing"] = torch.flip(image_shape/shape*torch.flip(cache_attribute.get_tensor("Spacing"), dims=[0]), dims=[0])
+        cache_attribute["Size"] = image_shape
+        cache_attribute["Size"] = shape
+        return self._resample(input, shape)
 
 class ResampleTransform(Transform):
 
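Reviewer note: ResampleToSize is renamed to ResampleToShape, and the keep-this-axis sentinel is unified: previously -1 in transformShape and None in __call__, now a single 0 after negatives are clamped in __init__. Note the print(new_shape) debug call that ships in this release. A standalone sketch of the resolved behaviour (the helper name is illustrative, not KonfAI API):

    import torch

    def resolve_shape(target: list[float], image_shape: list[int]) -> list[int]:
        target = torch.tensor([0 if s < 0 else s for s in target])  # clamp as in __init__
        return [image_shape[i] if int(s) == 0 else int(s) for i, s in enumerate(target)]

    print(resolve_shape([100, -1, 256], [64, 128, 512]))  # [100, 128, 256]
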
@@ -27,7 +27,7 @@ class CriterionsLoader():
     def getCriterions(self, output_group : str, target_group : str) -> dict[torch.nn.Module, CriterionsAttr]:
         criterions = {}
         for module_classpath, criterionsAttr in self.criterionsLoader.items():
-            module, name = _getModule(module_classpath, "metric.measure")
+            module, name = _getModule(module_classpath, "konfai.metric.measure")
             criterions[config("{}.metrics.{}.targetsCriterions.{}.criterionsLoader.{}".format(KONFAI_ROOT(), output_group, target_group, module_classpath))(getattr(importlib.import_module(module), name))(config = None)] = criterionsAttr
         return criterions
 
@@ -9,8 +9,6 @@ import sys
 sys.path.insert(0, os.getcwd())
 
 def main():
-    import tracemalloc
-
     parser = argparse.ArgumentParser(description="KonfAI", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
     try:
         with setup(parser) as distributedObject:
@@ -3,6 +3,9 @@ import importlib
 import numpy as np
 import torch
 
+from torchvision.models import inception_v3, Inception_V3_Weights
+from torchvision.transforms import functional as F
+from scipy import linalg
 
 import torch.nn.functional as F
 import os
@@ -247,7 +250,7 @@ class PerceptualLoss(Criterion):
     def getLoss(self) -> dict[torch.nn.Module, float]:
         result: dict[torch.nn.Module, float] = {}
         for loss, l in self.losses.items():
-            module, name = _getModule(loss, "metric.measure")
+            module, name = _getModule(loss, "konfai.metric.measure")
             result[config(self.DL_args)(getattr(importlib.import_module(module), name))(config=None)] = l
         return result
 
@@ -401,21 +404,42 @@ class L1LossRepresentation(Criterion):
     def forward(self, output: torch.Tensor) -> torch.Tensor:
         return self.loss(output[0], output[1])+ self._variance(output[0]) + self._variance(output[1])
 
-"""import torch
-import torch.nn as nn
-from torchvision.models import inception_v3, Inception_V3_Weights
-from torchvision.transforms import functional as F
-from scipy import linalg
-import numpy as np
+class FocalLoss(Criterion):
+
+    def __init__(self, gamma: float = 2.0, alpha: list[float] = [0.5, 2.0, 0.5, 0.5, 1], reduction: str = "mean"):
+        super().__init__()
+        raw_alpha = torch.tensor(alpha, dtype=torch.float32)
+        self.alpha = raw_alpha / raw_alpha.sum() * len(raw_alpha)
+        self.gamma = gamma
+        self.reduction = reduction
+
+    def forward(self, output: torch.Tensor, targets: list[torch.Tensor]) -> torch.Tensor:
+        target = targets[0].long()
+
+        logpt = F.log_softmax(output, dim=1)
+        pt = torch.exp(logpt)
+
+        logpt = logpt.gather(1, target.unsqueeze(1))
+        pt = pt.gather(1, target.unsqueeze(1))
+
+        at = self.alpha[target].unsqueeze(1)
+        loss = -at * ((1 - pt) ** self.gamma) * logpt
+
+        if self.reduction == "mean":
+            return loss.mean()
+        elif self.reduction == "sum":
+            return loss.sum()
+        return loss
 
+
 class FID(Criterion):
 
-    class InceptionV3(nn.Module):
+    class InceptionV3(torch.nn.Module):
 
         def __init__(self) -> None:
             super().__init__()
             self.model = inception_v3(weights=Inception_V3_Weights.DEFAULT, transform_input=False)
-            self.model.fc = nn.Identity()
+            self.model.fc = torch.nn.Identity()
             self.model.eval()
 
         def forward(self, x: torch.Tensor) -> torch.Tensor:
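Reviewer note: FocalLoss implements FL(p_t) = -alpha_t * (1 - p_t)^gamma * log(p_t) (Lin et al., 2017), with the class weights normalized to mean 1. Since torch.nn.functional is imported as F after torchvision's functional (see the import hunk above), F.log_softmax here resolves to the torch version. A standalone check of the math, which reduces to cross-entropy at gamma = 0:

    import torch
    import torch.nn.functional as F

    def focal_loss(logits, target, alpha, gamma=2.0):
        logpt = F.log_softmax(logits, dim=1).gather(1, target.unsqueeze(1))
        pt = logpt.exp()
        at = alpha[target].unsqueeze(1)
        return (-at * (1 - pt) ** gamma * logpt).mean()

    logits, target = torch.randn(4, 3), torch.tensor([0, 2, 1, 1])
    alpha = torch.ones(3)
    assert torch.allclose(focal_loss(logits, target, alpha, gamma=0.0),
                          F.cross_entropy(logits, target))
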
@@ -434,13 +458,11 @@ class FID(Criterion):
         return features
 
     def calculate_fid(real_features: np.ndarray, generated_features: np.ndarray) -> float:
-        # Calculate mean and covariance statistics
         mu1 = np.mean(real_features, axis=0)
         sigma1 = np.cov(real_features, rowvar=False)
         mu2 = np.mean(generated_features, axis=0)
         sigma2 = np.cov(generated_features, rowvar=False)
 
-        # Calculate FID score
         diff = mu1 - mu2
         covmean, _ = linalg.sqrtm(sigma1.dot(sigma2), disp=False)
         if np.iscomplexobj(covmean):
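Reviewer note: dropping the two comments is cosmetic; calculate_fid still computes the Fréchet inception distance from the statistics above, FID = ||mu1 - mu2||^2 + Tr(Sigma1 + Sigma2 - 2 * (Sigma1 Sigma2)^(1/2)), with scipy.linalg.sqrtm supplying the matrix square root. The complex-part check that follows guards against numerical noise in sqrtm.
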
@@ -456,9 +478,8 @@ class FID(Criterion):
         generated_features = FID.get_features(generated_images, self.inception_model)
 
         return FID.calculate_fid(real_features, generated_features)
-"""
 
-"""class MutualInformationLoss(torch.nn.Module):
+class MutualInformationLoss(torch.nn.Module):
     def __init__(self, num_bins: int = 23, sigma_ratio: float = 0.5, smooth_nr: float = 1e-7, smooth_dr: float = 1e-7) -> None:
         super().__init__()
         bin_centers = torch.linspace(0.0, 1.0, num_bins)
@@ -469,13 +490,12 @@ class FID(Criterion):
         self.smooth_nr = float(smooth_nr)
         self.smooth_dr = float(smooth_dr)
 
-    def parzen_windowing(self, pred: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    def parzen_windowing(self, pred: torch.Tensor, target: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         pred_weight, pred_probability = self.parzen_windowing_gaussian(pred)
         target_weight, target_probability = self.parzen_windowing_gaussian(target)
         return pred_weight, pred_probability, target_weight, target_probability
 
-
-    def parzen_windowing_gaussian(self, img: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    def parzen_windowing_gaussian(self, img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         img = torch.clamp(img, 0, 1)
         img = img.reshape(img.shape[0], -1, 1)  # (batch, num_sample, 1)
         weight = torch.exp(-self.preterm.to(img) * (img - self.bin_centers.to(img)) ** 2)  # (batch, num_sample, num_bin)
@@ -488,4 +508,4 @@ class FID(Criterion):
         pab = torch.bmm(wa.permute(0, 2, 1), wb.to(wa)).div(wa.shape[1])  # (batch, num_bins, num_bins)
         papb = torch.bmm(pa.permute(0, 2, 1), pb.to(pa))  # (batch, num_bins, num_bins)
         mi = torch.sum(pab * torch.log((pab + self.smooth_nr) / (papb + self.smooth_dr) + self.smooth_dr), dim=(1, 2))  # (batch)
-        return torch.mean(mi).neg()  # average over the batch and channel ndims"""
+        return torch.mean(mi).neg()  # average over the batch and channel ndims
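Reviewer note: MutualInformationLoss (previously commented out, now active) is a soft-histogram estimate of mutual information: intensities clamped to [0, 1] get Gaussian Parzen-window weights over num_bins bin centers, giving a joint distribution pab and marginals pa, pb, and the loss is -I(A;B) with I(A;B) = sum over a,b of pab * log(pab / (pa * pb)), smoothed by smooth_nr/smooth_dr. Minimizing the loss therefore maximizes mutual information, the usual criterion for multi-modal image registration.
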
@@ -2,6 +2,7 @@ import torch
 import numpy as np
 from abc import abstractmethod
 from konfai.utils.config import config
+from functools import partial
 
 class Scheduler():
 
@@ -46,4 +47,12 @@ class CosineAnnealing(Scheduler):
         self.T_max = T_max
 
     def get_value(self):
-        return self.eta_min + (self.baseValue - self.eta_min) *(1 + np.cos(self.it * torch.pi / self.T_max)) / 2
+        return self.eta_min + (self.baseValue - self.eta_min) *(1 + np.cos(self.it * torch.pi / self.T_max)) / 2
+
+class Warmup(torch.optim.lr_scheduler.LambdaLR):
+
+    def warmup(warmup_steps: int, step: int) -> float:
+        return min(1.0, step / warmup_steps)
+
+    def __init__(self, optimizer: torch.optim.Optimizer, warmup_steps: int = 10, last_epoch=-1, verbose="deprecated"):
+        super().__init__(optimizer, partial(Warmup.warmup, warmup_steps), last_epoch, verbose)
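Reviewer note: Warmup is a thin LambdaLR that scales the base learning rate by min(1, step / warmup_steps). A torch-only sketch of the same schedule (mirroring the class rather than importing it):

    import torch
    from functools import partial

    def warmup(warmup_steps: int, step: int) -> float:
        return min(1.0, step / warmup_steps)

    opt = torch.optim.SGD(torch.nn.Linear(4, 2).parameters(), lr=0.1)
    sched = torch.optim.lr_scheduler.LambdaLR(opt, partial(warmup, 10))
    for _ in range(3):
        opt.step()
        sched.step()
    print(opt.param_groups[0]["lr"])  # ~0.03: 3/10 of the base LR
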
@@ -0,0 +1,113 @@
+from konfai.network import network, blocks
+from typing import Union
+import torch
+from konfai.data.patching import ModelPatch
+
+class NestedUNet(network.Network):
+
+    class NestedUNetBlock(network.ModuleArgsDict):
+
+        def __init__(self, channels: list[int], nb_conv_per_stage: int, blockConfig: blocks.BlockConfig, downSampleMode: blocks.DownSampleMode, upSampleMode: blocks.UpSampleMode, attention : bool, block: type, dim: int, i : int = 0) -> None:
+            super().__init__()
+            if i > 0:
+                self.add_module(downSampleMode.name, blocks.downSample(in_channels=channels[0], out_channels=channels[1], downSampleMode=downSampleMode, dim=dim))
+
+            self.add_module("X_{}_{}".format(i, 0), block(in_channels=channels[1 if downSampleMode == blocks.DownSampleMode.CONV_STRIDE and i > 0 else 0], out_channels=channels[1], blockConfigs=[blockConfig]*nb_conv_per_stage, dim=dim), out_branch=["X_{}_{}".format(i, 0)])
+            if len(channels) > 2:
+                self.add_module("UNetBlock_{}".format(i+1), NestedUNet.NestedUNetBlock(channels[1:], nb_conv_per_stage, blockConfig, downSampleMode, upSampleMode, attention, block, dim, i+1), in_branch=["X_{}_{}".format(i, 0)], out_branch=["X_{}_{}".format(i+1, j) for j in range(len(channels)-2)])
+                for j in range(len(channels)-2):
+                    self.add_module("X_{}_{}_{}".format(i, j+1, upSampleMode.name), blocks.upSample(in_channels=channels[2], out_channels=channels[1], upSampleMode=upSampleMode, dim=dim), in_branch=["X_{}_{}".format(i+1, j)], out_branch=["X_{}_{}".format(i+1, j)])
+                    self.add_module("SkipConnection_{}_{}".format(i, j+1), blocks.Concat(), in_branch=["X_{}_{}".format(i+1, j)]+["X_{}_{}".format(i, r) for r in range(j+1)], out_branch=["X_{}_{}".format(i, j+1)])
+                    self.add_module("X_{}_{}".format(i, j+1), block(in_channels=(channels[1]*(j+1)+channels[2]) if upSampleMode != blocks.UpSampleMode.CONV_TRANSPOSE else channels[1]*(j+2), out_channels=channels[1], blockConfigs=[blockConfig]*nb_conv_per_stage, dim=dim), in_branch=["X_{}_{}".format(i, j+1)], out_branch=["X_{}_{}".format(i, j+1)])
+
+
+    class NestedUNetHead(network.ModuleArgsDict):
+
+        def __init__(self, in_channels: int, nb_class: int, activation: str, dim: int) -> None:
+            super().__init__()
+            self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels = in_channels, out_channels = nb_class, kernel_size = 1, stride = 1, padding = 0))
+            if activation == "Softmax":
+                self.add_module("Softmax", torch.nn.Softmax(dim=1))
+                self.add_module("Argmax", blocks.ArgMax(dim=1))
+            elif activation == "Tanh":
+                self.add_module("Tanh", torch.nn.Tanh())
+
+    def __init__( self,
+                  optimizer : network.OptimizerLoader = network.OptimizerLoader(),
+                  schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
+                  outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
+                  patch : Union[ModelPatch, None] = None,
+                  dim : int = 3,
+                  channels: list[int]=[1, 64, 128, 256, 512, 1024],
+                  nb_class: int = 2,
+                  blockConfig: blocks.BlockConfig = blocks.BlockConfig(),
+                  nb_conv_per_stage: int = 2,
+                  downSampleMode: str = "MAXPOOL",
+                  upSampleMode: str = "CONV_TRANSPOSE",
+                  attention : bool = False,
+                  blockType: str = "Conv",
+                  activation: str = "Softmax") -> None:
+        super().__init__(in_channels = channels[0], optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, patch=patch, dim = dim)
+
+        self.add_module("UNetBlock_0", NestedUNet.NestedUNetBlock(channels, nb_conv_per_stage, blockConfig, downSampleMode=blocks.DownSampleMode._member_map_[downSampleMode], upSampleMode=blocks.UpSampleMode._member_map_[upSampleMode], attention=attention, block = blocks.ConvBlock if blockType == "Conv" else blocks.ResBlock, dim=dim), out_branch=["X_0_{}".format(j+1) for j in range(len(channels)-2)])
+        for j in range(len(channels)-2):
+            self.add_module("Head_{}".format(j), NestedUNet.NestedUNetHead(in_channels=channels[1], nb_class=nb_class, activation=activation, dim=dim), in_branch=["X_0_{}".format(j+1)], out_branch=[-1])
+
+
+class UNetpp(network.Network):
+
+    class ResNetEncoderLayer(network.ModuleArgsDict):
+
+        def __init__(self, in_channel: int, out_channel: int, nb_block: int, dim: int, downSampleMode : blocks.DownSampleMode):
+            super().__init__()
+            for i in range(nb_block):
+                if downSampleMode == blocks.DownSampleMode.MAXPOOL and i == 0:
+                    self.add_module("DownSample", blocks.getTorchModule("MaxPool", dim)(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False))
+                self.add_module("ResBlock_{}".format(i), blocks.ResBlock(in_channel, out_channel, [blocks.BlockConfig(3, 2 if downSampleMode == blocks.DownSampleMode.CONV_STRIDE and i == 0 else 1, 1, False, "ReLU;True", blocks.NormMode.BATCH), blocks.BlockConfig(3, 1, 1, False, None, blocks.NormMode.BATCH)], dim=dim))
+                in_channel = out_channel
+
+    def resNetEncoder(channels: list[int], layers: list[int], dim: int) -> list[torch.nn.Module]:
+        modules = []
+        modules.append(blocks.ConvBlock(channels[0], channels[1], [blocks.BlockConfig(7, 2, 3, False, "ReLU", blocks.NormMode.BATCH)], dim=dim))
+        for i, (in_channel, out_channel, layer) in enumerate(zip(channels[1:], channels[2:], layers)):
+            modules.append(UNetpp.ResNetEncoderLayer(in_channel, out_channel, layer, dim, blocks.DownSampleMode.MAXPOOL if i==0 else blocks.DownSampleMode.CONV_STRIDE))
+        return modules
+
+    class UNetPPBlock(network.ModuleArgsDict):
+
+        def __init__(self, encoder_channels: list[int], decoder_channels: list[int], encoders: list[torch.nn.Module], upSampleMode: blocks.UpSampleMode, dim: int, i : int = 0) -> None:
+            super().__init__()
+            self.add_module("X_{}_{}".format(i, 0), encoders[0], out_branch=["X_{}_{}".format(i, 0)])
+            if len(encoder_channels) > 2:
+                self.add_module("UNetBlock_{}".format(i+1), UNetpp.UNetPPBlock(encoder_channels[1:], decoder_channels[1:], encoders[1:], upSampleMode, dim, i+1), in_branch=["X_{}_{}".format(i, 0)], out_branch=["X_{}_{}".format(i+1, j) for j in range(len(encoder_channels)-2)])
+                for j in range(len(encoder_channels)-2):
+                    self.add_module("X_{}_{}_{}".format(i, j+1, upSampleMode.name), blocks.upSample(in_channels=encoder_channels[2], out_channels=decoder_channels[2] if j == len(encoder_channels)-3 else encoder_channels[1], upSampleMode=upSampleMode, dim=dim), in_branch=["X_{}_{}".format(i+1, j)], out_branch=["X_{}_{}".format(i+1, j)])
+                    self.add_module("SkipConnection_{}_{}".format(i, j+1), blocks.Concat(), in_branch=["X_{}_{}".format(i+1, j)]+["X_{}_{}".format(i, r) for r in range(j+1)], out_branch=["X_{}_{}".format(i, j+1)])
+                    self.add_module("X_{}_{}".format(i, j+1), blocks.ConvBlock(in_channels=encoder_channels[1]*j+((encoder_channels[1]+encoder_channels[2]) if upSampleMode != blocks.UpSampleMode.CONV_TRANSPOSE else (decoder_channels[1]+decoder_channels[2] if j == len(encoder_channels)-3 else encoder_channels[1]+encoder_channels[1])), out_channels=decoder_channels[2] if j == len(encoder_channels)-3 else encoder_channels[1], blockConfigs=[blocks.BlockConfig(3,1,1,False, "ReLU;True", blocks.NormMode.BATCH)]*2, dim=dim), in_branch=["X_{}_{}".format(i, j+1)], out_branch=["X_{}_{}".format(i, j+1)])
+
+    class UNetPPHead(network.ModuleArgsDict):
+
+        def __init__(self, in_channels: int, out_channels: int, nb_class: int, dim: int) -> None:
+            super().__init__()
+            self.add_module("Upsample", blocks.upSample(in_channels=in_channels, out_channels=out_channels, upSampleMode=blocks.UpSampleMode.UPSAMPLE_BILINEAR, dim=dim))
+            self.add_module("ConvBlock", blocks.ConvBlock(in_channels=in_channels, out_channels=out_channels, blockConfigs=[blocks.BlockConfig(3,1,1,False, "ReLU;True", blocks.NormMode.BATCH)]*2, dim=dim))
+            self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels = out_channels, out_channels = nb_class, kernel_size = 3, stride = 1, padding = 1))
+            if nb_class > 1:
+                self.add_module("Softmax", torch.nn.Softmax(dim=1))
+                self.add_module("Argmax", blocks.ArgMax(dim=1))
+            else:
+                self.add_module("Tanh", torch.nn.Tanh())
+
+    def __init__( self,
+                  optimizer : network.OptimizerLoader = network.OptimizerLoader(),
+                  schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
+                  outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
+                  patch : Union[ModelPatch, None] = None,
+                  encoder_channels: list[int] = [1,64,64,128,256,512],
+                  decoder_channels: list[int] = [256,128,64,32,16,1],
+                  layers: list[int] = [3,4,6,3],
+                  dim : int = 2) -> None:
+        super().__init__(in_channels = encoder_channels[0], optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, patch=patch, dim = dim)
+        self.add_module("Block_0", UNetpp.UNetPPBlock(encoder_channels, decoder_channels[::-1], UNetpp.resNetEncoder(encoder_channels, layers, dim), blocks.UpSampleMode.UPSAMPLE_BILINEAR, dim=dim), out_branch=["X_0_{}".format(j+1) for j in range(len(encoder_channels)-2)])
+        self.add_module("Head", UNetpp.UNetPPHead(in_channels=decoder_channels[-3], out_channels=decoder_channels[-2], nb_class=decoder_channels[-1], dim=dim), in_branch=["X_0_{}".format(len(encoder_channels)-2)], out_branch=[-1])
+
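Reviewer note: the rewritten NestedUNet nests its building blocks inside the Network subclass (the old module-level NestedUNetBlock file is deleted at the end of this diff), adds a configurable head activation (Softmax with ArgMax, or Tanh), and attaches one Head_j per decoder column X_0_{j+1}, i.e. the deep-supervision scheme of UNet++ (Zhou et al., 2018). The new UNetpp class adds a ResNet-style encoder variant of the same nested topology.
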
@@ -12,7 +12,6 @@ class UNetHead(network.ModuleArgsDict):
         self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels = in_channels, out_channels = nb_class, kernel_size = 1, stride = 1, padding = 0))
         self.add_module("Softmax", torch.nn.Softmax(dim=1))
         self.add_module("Argmax", blocks.ArgMax(dim=1))
-
 
 class UNetBlock(network.ModuleArgsDict):
 
@@ -19,7 +19,7 @@ class NormMode(Enum):
     SYNCBATCH = 5
     INSTANCE_AFFINE = 6
 
-def getNorm(normMode: Enum, channels : int, dim: int) -> torch.nn.Module:
+def getNorm(normMode: Enum, channels : int, dim: int) -> Union[torch.nn.Module, None]:
     if normMode == NormMode.BATCH:
         return getTorchModule("BatchNorm", dim = dim)(channels, affine=True, track_running_stats=True)
     if normMode == NormMode.INSTANCE:
@@ -32,7 +32,7 @@ def getNorm(normMode: Enum, channels : int, dim: int) -> torch.nn.Module:
         return torch.nn.GroupNorm(num_groups=32, num_channels=channels)
     if normMode == NormMode.LAYER:
         return torch.nn.GroupNorm(num_groups=1, num_channels=channels)
-    return torch.nn.Identity()
+    return None
 
 class UpSampleMode(Enum):
     CONV_TRANSPOSE = 0,
@@ -59,9 +59,9 @@ class BlockConfig():
         self.stride = stride
         self.padding = padding
         self.activation = activation
-        self.normMode = normMode
-
-        if isinstance(normMode, str):
+        if normMode is None:
+            self.norm = None
+        elif isinstance(normMode, str):
             self.norm = NormMode._member_map_[normMode]
         elif isinstance(normMode, NormMode):
             self.norm = normMode
@@ -70,9 +70,13 @@ class BlockConfig():
         return getTorchModule("Conv", dim = dim)(in_channels = in_channels, out_channels = out_channels, kernel_size = self.kernel_size, stride = self.stride, padding = self.padding, bias=self.bias)
 
     def getNorm(self, channels : int, dim: int) -> torch.nn.Module:
+        if self.norm is None:
+            return None
         return getNorm(self.norm, channels, dim) if isinstance(self.norm, NormMode) else self.norm(channels)
 
     def getActivation(self) -> torch.nn.Module:
+        if self.activation is None:
+            return None
         if isinstance(self.activation, str):
             return getTorchModule(self.activation.split(";")[0])(*[ast.literal_eval(value) for value in self.activation.split(";")[1:]]) if self.activation != "None" else torch.nn.Identity()
         return self.activation()
@@ -83,21 +87,35 @@ class ConvBlock(network.ModuleArgsDict):
         super().__init__()
         for i, blockConfig in enumerate(blockConfigs):
             self.add_module("Conv_{}".format(i), blockConfig.getConv(in_channels, out_channels, dim), alias=alias[0])
-            self.add_module("Norm_{}".format(i), blockConfig.getNorm(out_channels, dim), alias=alias[1])
-            self.add_module("Activation_{}".format(i), blockConfig.getActivation(), alias=alias[2])
+            norm = blockConfig.getNorm(out_channels, dim)
+            if norm is not None:
+                self.add_module("Norm_{}".format(i), norm, alias=alias[1])
+            activation = blockConfig.getActivation()
+            if activation is not None:
+                self.add_module("Activation_{}".format(i), activation, alias=alias[2])
             in_channels = out_channels
 
 class ResBlock(network.ModuleArgsDict):
 
-    def __init__(self, in_channels : int, out_channels : int, nb_conv: int, blockConfig : BlockConfig, dim : int, alias : list[list[str]]=[[], [], [], []]) -> None:
+    def __init__(self, in_channels : int, out_channels : int, blockConfigs : list[BlockConfig], dim : int, alias : list[list[str]]=[[], [], [], [], []]) -> None:
         super().__init__()
-        for i in range(nb_conv):
-            self.add_module("ConvBlock_{}".format(i), ConvBlock(in_channels, out_channels, nb_conv=1, blockConfig=blockConfig, dim=dim, alias=alias[:3]))
+        for i, blockConfig in enumerate(blockConfigs):
+            self.add_module("Conv_{}".format(i), blockConfig.getConv(in_channels, out_channels, dim), alias=alias[0])
+            norm = blockConfig.getNorm(out_channels, dim)
+            if norm is not None:
+                self.add_module("Norm_{}".format(i), norm, alias=alias[1])
+            activation = blockConfig.getActivation()
+            if activation is not None:
+                self.add_module("Activation_{}".format(i), activation, alias=alias[2])
+
             if in_channels != out_channels:
-                self.add_module("Conv_skip_{}".format(i), blockConfig.getConv(in_channels, out_channels, dim), alias=alias[3], in_branch=[1], out_branch=[1])
+                self.add_module("Conv_skip", getTorchModule("Conv", dim)(in_channels, out_channels, 1, blockConfig.stride, bias=blockConfig.bias), alias=alias[3], in_branch=[1], out_branch=[1])
+                self.add_module("Norm_skip", blockConfig.getNorm(out_channels, dim), alias=alias[4], in_branch=[1], out_branch=[1])
             in_channels = out_channels
-            self.add_module("Add_{}".format(i), Add(), in_branch=[0,1], out_branch=[0,1])
-
+
+        self.add_module("Add", Add(), in_branch=[0,1])
+        self.add_module("Norm_{}".format(i+1), torch.nn.ReLU(inplace=True))
+
 
 def downSample(in_channels: int, out_channels: int, downSampleMode: DownSampleMode, dim: int) -> torch.nn.Module:
     if downSampleMode == DownSampleMode.MAXPOOL:
         return getTorchModule("MaxPool", dim = dim)(2)
@@ -176,6 +194,7 @@ class Print(torch.nn.Module):
         super().__init__()
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
+        print(input.shape)
         return input
 
 class Write(torch.nn.Module):
@@ -16,7 +16,7 @@ from enum import Enum
 from konfai import KONFAI_ROOT
 from konfai.metric.schedulers import Scheduler
 from konfai.utils.config import config
-from konfai.utils.utils import State, _getModule, getDevice, getGPUMemory, MeasureError
+from konfai.utils.utils import State, _getModule, getDevice, getGPUMemory, MeasureError, TrainerError
 from konfai.data.patching import Accumulator, ModelPatch
 
 class NetState(Enum):
55
55
 
56
56
  def getschedulers(self, key: str, optimizer: torch.optim.Optimizer) -> dict[torch.optim.lr_scheduler._LRScheduler, int]:
57
57
  schedulers : dict[torch.optim.lr_scheduler._LRScheduler, int] = {}
58
- for name, step in self.params.items():
59
- if name:
60
- schedulers[config("Trainer.Model.{}.Schedulers.{}".format(key, name))(getattr(importlib.import_module('torch.optim.lr_scheduler'), name))(optimizer, config = None)] = step.nb_step
58
+ for nameTmp, step in self.params.items():
59
+ if nameTmp:
60
+ ok = False
61
+ for m in ["torch.optim.lr_scheduler", "konfai.metric.schedulers"]:
62
+ module, name = _getModule(nameTmp, m)
63
+ if hasattr(importlib.import_module(module), name):
64
+ schedulers[config("{}.Model.{}.Schedulers.{}".format(KONFAI_ROOT(), key, name))(getattr(importlib.import_module(module), name))(optimizer, config = None)] = step.nb_step
65
+ ok = True
66
+ if not ok:
67
+ raise TrainerError("Unknown scheduler {}, tried importing from: 'torch.optim.lr_scheduler' and 'konfai.metric.schedulers', but no valid match was found. Check your YAML config or scheduler name spelling.".format(nameTmp))
61
68
  return schedulers
62
69
 
63
70
  class SchedulersLoader():
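Reviewer note: scheduler names from the YAML config are now resolved against torch.optim.lr_scheduler and then konfai.metric.schedulers (making the new Warmup reachable), and a typo raises a TrainerError instead of a bare AttributeError. A simplified sketch of the lookup order (an approximation of the loader, not its exact code; the second import only matters if konfai is installed):

    import importlib

    def resolve_scheduler(name: str):
        for mod in ("torch.optim.lr_scheduler", "konfai.metric.schedulers"):
            m = importlib.import_module(mod)
            if hasattr(m, name):
                return getattr(m, name)
        raise ValueError("Unknown scheduler {}".format(name))

    print(resolve_scheduler("CosineAnnealingLR"))  # found in torch first
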
@@ -95,7 +102,7 @@ class CriterionsLoader():
     def getCriterions(self, model_classname : str, output_group : str, target_group : str) -> dict[torch.nn.Module, CriterionsAttr]:
         criterions = {}
         for module_classpath, criterionsAttr in self.criterionsLoader.items():
-            module, name = _getModule(module_classpath, "metric.measure")
+            module, name = _getModule(module_classpath, "konfai.metric.measure")
             criterionsAttr.isTorchCriterion = module.startswith("torch")
             criterionsAttr.sheduler = criterionsAttr.l.getschedulers("{}.Model.{}.outputsCriterions.{}.targetsCriterions.{}.criterionsLoader.{}".format(KONFAI_ROOT(), model_classname, output_group, target_group, module_classpath))
             criterions[config("{}.Model.{}.outputsCriterions.{}.targetsCriterions.{}.criterionsLoader.{}".format(KONFAI_ROOT(), model_classname, output_group, target_group, module_classpath))(getattr(importlib.import_module(module), name))(config = None)] = criterionsAttr
@@ -133,9 +140,10 @@ class Measure():
         self._loss.clear()
 
     def add(self, weight: float, value: torch.Tensor) -> None:
-        self._loss.append(value if self.isLoss else value.detach())
-        self._values.append(value.item())
-        self._weight.append(weight)
+        if self.isLoss or value != 0:
+            self._loss.append(value if self.isLoss else value.detach())
+            self._values.append(value.item())
+            self._weight.append(weight)
 
     def getLastLoss(self) -> torch.Tensor:
         return self._loss[-1]*self._weight[-1] if len(self._loss) else torch.zeros((1), requires_grad = True)
@@ -560,6 +568,7 @@ class Network(ModuleArgsDict, ABC):
         self.init_gain = init_gain
         self.dim = dim
         self._it = 0
+        self._nb_lr_update = 0
         self.outputsGroup : list[OutputsGroup]= []
 
     @_function_network()
@@ -654,9 +663,15 @@ class Network(ModuleArgsDict, ABC):
             else:
                 model_state_dict[alias] = model_state_dict_tmp[alias]
         self.load_state_dict(model_state_dict)
-
-        if "{}_optimizer_state_dict".format(name) in state_dict and self.optimizer:
-            self.optimizer.load_state_dict(state_dict['{}_optimizer_state_dict'.format(name)])
+        if "{}_optimizer_state_dict".format(self.getName()) in state_dict and self.optimizer:
+            self.optimizer.load_state_dict(state_dict['{}_optimizer_state_dict'.format(self.getName())])
+        if "{}_it".format(self.getName()) in state_dict:
+            self._it = int(state_dict["{}_it".format(self.getName())])
+        if "{}_nb_lr_update".format(self.getName()) in state_dict:
+            self._nb_lr_update = int(state_dict["{}_nb_lr_update".format(self.getName())])
+        if self.schedulers:
+            for scheduler in self.schedulers:
+                scheduler.last_epoch = self._nb_lr_update
         self.initialized()
 
     def _compute_channels_trace(self, module : ModuleArgsDict, in_channels : int, gradient_checkpoints: Union[list[str], None], gpu_checkpoints: Union[list[str], None], name: Union[str, None] = None, in_is_channel = True, out_channels : Union[int, None] = None, out_is_channel = True) -> tuple[int, bool, int, bool]:
@@ -885,11 +900,12 @@
 
     @_function_network()
     def update_lr(self):
+        self._nb_lr_update+=1
         step = 0
         scheduler = None
         if self.schedulers:
             for scheduler, value in self.schedulers.items():
-                if value is None or (self._it >= step and self._it < step+value):
+                if value is None or (self._nb_lr_update >= step and self._nb_lr_update < step+value):
                     break
                 step += value
         if scheduler:
@@ -897,7 +913,7 @@
             if self.measure:
                 scheduler.step(sum(self.measure.getLastValues(0).values()))
             else:
-                scheduler.step()
+                scheduler.step()
 
     @_function_network()
     def getNetworks(self) -> Self:
@@ -915,6 +931,12 @@
                 v = Network.to(v, int(os.environ["device"]))
             else:
                 v = v.to(getDevice(int(os.environ["device"])))
+        if isinstance(module, Network):
+            if module.optimizer is not None:
+                for state in module.optimizer.state.values():
+                    for k, v in state.items():
+                        if isinstance(v, torch.Tensor):
+                            state[k] = v.to(getDevice(int(os.environ["device"])))
         return module
 
     def getName(self) -> str:
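Reviewer note: Network.to now also migrates optimizer state tensors (e.g. Adam's exp_avg buffers). Moving a model with .to() after its optimizer already holds state leaves that state on the old device, which fails at the next step; together with the map_location="cpu" change in trainer.py below, this makes checkpoints restorable across devices. A standalone illustration of the migration loop:

    import torch

    model = torch.nn.Linear(4, 2)
    opt = torch.optim.Adam(model.parameters())
    model(torch.randn(1, 4)).sum().backward()
    opt.step()                                   # populates opt.state on the CPU

    device = "cuda" if torch.cuda.is_available() else "cpu"
    model.to(device)
    for state in opt.state.values():             # same loop as in Network.to
        for k, v in state.items():
            if isinstance(v, torch.Tensor):
                state[k] = v.to(device)
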
@@ -943,7 +965,7 @@
 
     @config("Model")
     def __init__(self, classpath : str = "default:segmentation.UNet.UNet") -> None:
-        self.module, self.name = _getModule(classpath, "models")
+        self.module, self.name = _getModule(classpath, "konfai.models")
 
     def getModel(self, train : bool = True, DL_args: Union[str, None] = None, DL_without=["optimizer", "schedulers", "nb_batch_per_step", "init_type", "init_gain"]) -> Network:
         if not DL_args:
@@ -8,7 +8,7 @@ import os
 
 from konfai import MODELS_DIRECTORY, PREDICTIONS_DIRECTORY, CONFIG_FILE, MODEL, KONFAI_ROOT
 from konfai.utils.config import config
-from konfai.utils.utils import State, get_patch_slices_from_nb_patch_per_dim, NeedDevice, _getModule, DistributedObject, DataLog, description
+from konfai.utils.utils import State, get_patch_slices_from_nb_patch_per_dim, NeedDevice, _getModule, DistributedObject, DataLog, description, PredictorError
 from konfai.utils.dataset import Dataset, Attribute
 from konfai.data.data_manager import DataPrediction, DatasetIter
 from konfai.data.patching import Accumulator, PathCombine
@@ -46,7 +46,7 @@ class OutDataset(Dataset, NeedDevice, ABC):
         self.names: dict[int, str] = {}
         self.nb_data_augmentation = 0
 
-    def load(self, name_layer: str, datasets: list[Dataset]):
+    def load(self, name_layer: str, datasets: list[Dataset], groups: dict[str, str]):
         transforms_type = ["pre_transforms", "post_transforms", "final_transforms"]
         for name, _transform_type, transform_type in [(k, getattr(self, "_{}".format(k)), getattr(self, k)) for k in transforms_type]:
 
@@ -57,7 +57,7 @@ class OutDataset(Dataset, NeedDevice, ABC):
                 transform_type.append(transform)
 
         if self._patchCombine is not None:
-            module, name = _getModule(self._patchCombine, "data.patching")
+            module, name = _getModule(self._patchCombine, "konfai.data.patching")
             self.patchCombine = getattr(importlib.import_module(module), name)(config = None, DL_args = "{}.outsDataset.{}.OutDataset".format(KONFAI_ROOT(), name_layer))
 
     def setPatchConfig(self, patchSize: Union[list[int], None], overlap: Union[int, None], nb_data_augmentation: int) -> None:
@@ -96,17 +96,21 @@ class Reduction():
     def __init__(self):
         pass
 
-class ReductionMean():
+class Mean(Reduction):
 
     def __init__(self):
         pass
-
-class ReductionMedian():
+
+    def __call__(self, result: torch.Tensor) -> torch.Tensor:
+        return torch.mean(result.float(), dim=0)
+
+class Median(Reduction):
 
     def __init__(self):
         pass
 
-
+    def __call__(self, result: torch.Tensor) -> torch.Tensor:
+        return torch.median(result.float(), dim=0).values
 
 class OutSameAsGroupDataset(OutDataset):
 
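Reviewer note: the empty ReductionMean/ReductionMedian stubs become callable Mean/Median subclasses of Reduction, loadable by classpath (see the load() override further down). A quick standalone check of the two reductions over stacked predictions:

    import torch

    results = torch.stack([torch.tensor([1., 2.]),
                           torch.tensor([3., 2.]),
                           torch.tensor([5., 8.])])
    print(torch.mean(results, dim=0))            # tensor([3., 4.])
    print(torch.median(results, dim=0).values)   # tensor([3., 2.])
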
@@ -114,7 +118,7 @@ class OutSameAsGroupDataset(OutDataset):
     def __init__(self, dataset_filename: str = "./Dataset:mha", group: str = "default", sameAsGroup: str = "default", pre_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, post_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, final_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, patchCombine: Union[str, None] = None, reduction: str = "mean", inverse_transform: bool = True) -> None:
         super().__init__(dataset_filename, group, pre_transforms, post_transforms, final_transforms, patchCombine)
         self.group_src, self.group_dest = sameAsGroup.split(":")
-        self.reduction = reduction
+        self.reduction_classpath = reduction
         self.inverse_transform = inverse_transform
 
     def addLayer(self, index_dataset: int, index_augmentation: int, index_patch: int, layer: torch.Tensor, dataset: DatasetIter):
@@ -137,9 +141,24 @@ class OutSameAsGroupDataset(OutDataset):
         if self.inverse_transform:
             for transform in reversed(dataset.groups_src[self.group_src][self.group_dest].post_transforms):
                 layer = transform.inverse(self.names[index_dataset], layer, self.attributes[index_dataset][index_augmentation][index_patch])
-
         self.output_layer_accumulator[index_dataset][index_augmentation].addLayer(index_patch, layer)
 
+
+    def load(self, name_layer: str, datasets: list[Dataset], groups: dict[str, str]):
+        super().load(name_layer, datasets, groups)
+        module, name = _getModule(self.reduction_classpath, "konfai.predictor")
+        self.reduction = config("{}.outsDataset.{}.OutDataset.{}".format(KONFAI_ROOT(), name_layer, self.reduction_classpath))(getattr(importlib.import_module(module), name))(config = None)
+
+        if self.group_src not in groups.keys():
+            raise PredictorError(
+                f"Source group '{self.group_src}' not found. Available groups: {list(groups.keys())}."
+            )
+
+        if self.group_dest not in groups[self.group_src]:
+            raise PredictorError(
+                f"Destination group '{self.group_dest}' not found. Available groups: {groups[self.group_src]}."
+            )
+
     def _getOutput(self, index: int, index_augmentation: int, dataset: DatasetIter) -> torch.Tensor:
         layer = self.output_layer_accumulator[index][index_augmentation].assemble()
         name = self.names[index]
@@ -166,14 +185,8 @@ class OutSameAsGroupDataset(OutDataset):
         result = torch.cat([self._getOutput(index, index_augmentation, dataset).unsqueeze(0) for index_augmentation in self.output_layer_accumulator[index].keys()], dim=0)
         name = self.names[index]
         self.output_layer_accumulator.pop(index)
-        dtype = result.dtype
+        result = self.reduction(result.float()).to(result.dtype)
 
-        if self.reduction == "mean":
-            result = torch.mean(result.float(), dim=0).to(dtype)
-        elif self.reduction == "median":
-            result, _ = torch.median(result.float(), dim=0)
-        else:
-            raise NameError("Reduction method does not exist (mean, median)")
         for transform in self.final_transforms:
             result = transform(name, result, self.attributes[index][0][0])
         return result
@@ -294,9 +307,9 @@ class _Predictor():
 
 class ModelComposite(Network):
 
-    def __init__(self, model: Network, nb_models: int, method: str):
+    def __init__(self, model: Network, nb_models: int, combine: Reduction):
         super().__init__(model.in_channels, model.optimizer, model.schedulers, model.outputsCriterionsLoader, model.patch, model.nb_batch_per_step, model.init_type, model.init_gain, model.dim)
-        self.method = method
+        self.combine = combine
         for i in range(nb_models):
             self.add_module("Model_{}".format(i), copy.deepcopy(model), in_branch=[0], out_branch=["output_{}".format(i)])
 
@@ -317,12 +330,7 @@ class ModelComposite(Network):
 
         final_outputs = []
         for key, tensors in aggregated.items():
-            stacked = torch.stack(tensors, dim=0)
-            if self.method == 'mean':
-                agg = torch.mean(stacked, dim=0)
-            elif self.method == 'median':
-                agg = torch.median(stacked, dim=0).values
-            final_outputs.append((key, agg))
+            final_outputs.append((key, self.combine(torch.stack(tensors, dim=0))))
 
         return final_outputs
 
@@ -345,7 +353,7 @@ class Predictor(DistributedObject):
         super().__init__(train_name)
         self.manual_seed = manual_seed
         self.dataset = dataset
-        self.combine = combine
+        self.combine_classpath = combine
         self.autocast = autocast
 
         self.model = model.getModel(train=False)
@@ -403,10 +411,24 @@
 
         shutil.copyfile(CONFIG_FILE(), self.predict_path+"Prediction.yml")
 
+
         self.model.init(self.autocast, State.PREDICTION, self.dataset.getGroupsDest())
         self.model.init_outputsGroup()
         self.model._compute_channels_trace(self.model, self.model.in_channels, None, self.gpu_checkpoints)
-        self.modelComposite = ModelComposite(self.model, len(MODEL().split(":")), self.combine)
+
+        modules = []
+        for i,_,_ in self.model.named_ModuleArgsDict():
+            modules.append(i)
+        for output_group in self.outsDataset.keys():
+            if output_group not in modules:
+                raise PredictorError("The output group '{}' defined in 'outputsCriterions' does not correspond to any module in the model.".format(output_group),
+                                     "Available modules: {}".format(modules),
+                                     "Please check that the name matches exactly a submodule or output of your model architecture."
+                )
+        module, name = _getModule(self.combine_classpath, "konfai.predictor")
+        combine = config("{}.{}".format(KONFAI_ROOT(), self.combine_classpath))(getattr(importlib.import_module(module), name))(config = None)
+
+        self.modelComposite = ModelComposite(self.model, len(MODEL().split(":")), combine)
         self.modelComposite.load(self._load())
 
         if len(list(self.outsDataset.keys())) == 0 and len([network for network in self.modelComposite.getNetworks().values() if network.measure is not None]) == 0:
@@ -415,7 +437,7 @@
         self.size = (len(self.gpu_checkpoints)+1 if self.gpu_checkpoints else 1)
         self.dataloader = self.dataset.getData(world_size//self.size)
         for name, outDataset in self.outsDataset.items():
-            outDataset.load(name.replace(".", ":"), list(self.dataset.datasets.values()))
+            outDataset.load(name.replace(".", ":"), list(self.dataset.datasets.values()), {src : dest for src, inner in self.dataset.groups_src.items() for dest in inner})
 
 
     def run_process(self, world_size: int, global_rank: int, local_rank: int, dataloaders: list[DataLoader]):
@@ -108,9 +108,7 @@ class _Trainer():
         if data_log is not None:
             for data in data_log:
                 self.data_log[data.split("/")[0].replace(":", ".")] = (DataLog.__getitem__(data.split("/")[1]).value[0], int(data.split("/")[2]))
-
-
-
+
     def __enter__(self):
         return self
 
@@ -120,7 +118,8 @@
 
     def run(self) -> None:
         self.dataloader_training.dataset.load("Train")
-        self.dataloader_validation.dataset.load("Validation")
+        if self.dataloader_validation is not None:
+            self.dataloader_validation.dataset.load("Validation")
         with tqdm.tqdm(iterable = range(self.epoch, self.epochs), leave=False, total=self.epochs, initial=self.epoch, desc="Progress") as epoch_tqdm:
             for self.epoch in epoch_tqdm:
                 self.train()
@@ -211,6 +210,9 @@
             save_dict["Model_EMA"] = self.modelEMA.module.state_dict()
 
         save_dict.update({'{}_optimizer_state_dict'.format(name): network.optimizer.state_dict() for name, network in self.model.module.getNetworks().items() if network.optimizer is not None})
+        save_dict.update({'{}_it'.format(name): network._it for name, network in self.model.module.getNetworks().items() if network.optimizer is not None})
+        save_dict.update({'{}_nb_lr_update'.format(name): network._nb_lr_update for name, network in self.model.module.getNetworks().items() if network.optimizer is not None})
+
         torch.save(save_dict, save_path)
 
         if self.save_checkpoint_mode == "BEST":
@@ -355,7 +357,7 @@ class Trainer(DistributedObject):
             name = sorted(os.listdir(path))[-1]
 
         if os.path.exists(path+name):
-            state_dict = torch.load(path+name, weights_only=False)
+            state_dict = torch.load(path+name, weights_only=False, map_location="cpu")
         else:
             raise Exception("Model : {} does not exist !".format(self.name))
 
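Reviewer note: map_location="cpu" makes checkpoint restore device-agnostic, e.g. a state dict saved on cuda:1 can now be loaded on a single-GPU or CPU-only machine; tensors are moved back to the right device afterwards by Network.to, including the optimizer state (see the network.py hunk above). Minimal form of the call, with a hypothetical path:

    import torch
    state_dict = torch.load("checkpoint.pt", weights_only=False, map_location="cpu")
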
@@ -37,7 +37,7 @@ def _getModule(classpath : str, type : str) -> tuple[str, str]:
         module = ".".join(classpath.split(":")[:-1])
         name = classpath.split(":")[-1]
     else:
-        module = "konfai."+type+("." if len(classpath.split(".")) > 2 else "")+".".join(classpath.split(".")[:-1])
+        module = type+("." if len(classpath.split(".")) > 2 else "")+".".join(classpath.split(".")[:-1])
         name = classpath.split(".")[-1]
     return module, name
 
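Reviewer note: _getModule no longer hardcodes the "konfai." prefix; every call site now passes the fully qualified fallback package instead (hence all the "data.transform" → "konfai.data.transform" changes throughout this diff), which also allows non-konfai fallback packages. A standalone sketch of the two resolution branches (the helper name is illustrative):

    def get_module(classpath: str, fallback: str) -> tuple[str, str]:
        if ":" in classpath:                       # explicit "package.module:Class"
            return ".".join(classpath.split(":")[:-1]), classpath.split(":")[-1]
        parts = classpath.split(".")
        module = fallback + ("." if len(parts) > 2 else "") + ".".join(parts[:-1])
        return module, parts[-1]

    print(get_module("Clip", "konfai.data.transform"))
    # ('konfai.data.transform', 'Clip')
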
@@ -586,6 +586,11 @@ class EvaluatorError(KonfAIError):
 
     def __init__(self, *message) -> None:
         super().__init__("Evaluator", message)
 
+class PredictorError(KonfAIError):
+
+    def __init__(self, *message) -> None:
+        super().__init__("Predictor", message)
+
 class TransformError(KonfAIError):
 
     def __init__(self, *message) -> None:
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: konfai
-Version: 1.1.4
+Version: 1.1.6
 Summary: Modular and configurable Deep Learning framework with YAML and PyTorch
 Author-email: Valentin Boussot <boussot.v@gmail.com>
 License-Expression: Apache-2.0
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "konfai"
-version = "1.1.4"
+version = "1.1.6"
 description = "Modular and configurable Deep Learning framework with YAML and PyTorch"
 readme = "README.md"
 requires-python = ">=3.8"
@@ -1,52 +0,0 @@
-from typing import Union
-import torch
-
-from konfai.network import network, blocks
-from konfai.utils.config import config
-from konfai.data.patching import ModelPatch
-
-
-class NestedUNetBlock(network.ModuleArgsDict):
-
-    def __init__(self, channels: list[int], nb_conv_per_stage: int, blockConfig: blocks.BlockConfig, downSampleMode: blocks.DownSampleMode, upSampleMode: blocks.UpSampleMode, attention : bool, block: type, dim: int, i : int = 0) -> None:
-        super().__init__()
-        if i > 0:
-            self.add_module(downSampleMode.name, blocks.downSample(in_channels=channels[0], out_channels=channels[1], downSampleMode=downSampleMode, dim=dim))
-
-        self.add_module("X_{}_{}".format(i, 0), block(in_channels=channels[1 if downSampleMode == blocks.DownSampleMode.CONV_STRIDE and i > 0 else 0], out_channels=channels[1], blockConfigs=[blockConfig]*nb_conv_per_stage, dim=dim), out_branch=["X_{}_{}".format(i, 0)])
-        if len(channels) > 2:
-            self.add_module("UNetBlock_{}".format(i+1), NestedUNetBlock(channels[1:], nb_conv_per_stage, blockConfig, downSampleMode, upSampleMode, attention, block, dim, i+1), in_branch=["X_{}_{}".format(i, 0)], out_branch=["X_{}_{}".format(i+1, j) for j in range(len(channels)-2)])
-            for j in range(len(channels)-2):
-                self.add_module("X_{}_{}_{}".format(i, j+1, upSampleMode.name), blocks.upSample(in_channels=channels[2], out_channels=channels[1], upSampleMode=upSampleMode, dim=dim), in_branch=["X_{}_{}".format(i+1, j)], out_branch=["X_{}_{}".format(i+1, j)])
-                self.add_module("SkipConnection_{}_{}".format(i, j+1), blocks.Concat(), in_branch=["X_{}_{}".format(i+1, j)]+["X_{}_{}".format(i, r) for r in range(j+1)], out_branch=["X_{}_{}".format(i, j+1)])
-                self.add_module("X_{}_{}".format(i, j+1), block(in_channels=(channels[1]*(j+1)+channels[2]) if upSampleMode != blocks.UpSampleMode.CONV_TRANSPOSE else channels[1]*(j+2), out_channels=channels[1], blockConfigs=[blockConfig]*nb_conv_per_stage, dim=dim), in_branch=["X_{}_{}".format(i, j+1)], out_branch=["X_{}_{}".format(i, j+1)])
-
-class NestedUNet(network.Network):
-
-    class NestedUNetHead(network.ModuleArgsDict):
-
-        def __init__(self, in_channels: int, nb_class: int, dim: int) -> None:
-            super().__init__()
-            self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels = in_channels, out_channels = nb_class, kernel_size = 1, stride = 1, padding = 0))
-            self.add_module("Softmax", torch.nn.Softmax(dim=1))
-            self.add_module("Argmax", blocks.ArgMax(dim=1))
-
-    def __init__( self,
-                  optimizer : network.OptimizerLoader = network.OptimizerLoader(),
-                  schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
-                  outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
-                  patch : Union[ModelPatch, None] = None,
-                  dim : int = 3,
-                  channels: list[int]=[1, 64, 128, 256, 512, 1024],
-                  nb_class: int = 2,
-                  blockConfig: blocks.BlockConfig = blocks.BlockConfig(),
-                  nb_conv_per_stage: int = 2,
-                  downSampleMode: str = "MAXPOOL",
-                  upSampleMode: str = "CONV_TRANSPOSE",
-                  attention : bool = False,
-                  blockType: str = "Conv") -> None:
-        super().__init__(in_channels = channels[0], optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, patch=patch, dim = dim)
-
-        self.add_module("UNetBlock_0", NestedUNetBlock(channels, nb_conv_per_stage, blockConfig, downSampleMode=blocks.DownSampleMode._member_map_[downSampleMode], upSampleMode=blocks.UpSampleMode._member_map_[upSampleMode], attention=attention, block = blocks.ConvBlock if blockType == "Conv" else blocks.ResBlock, dim=dim), out_branch=["X_0_{}".format(j+1) for j in range(len(channels)-2)])
-        for j in range(len(channels)-2):
-            self.add_module("Head_{}".format(j), NestedUNet.NestedUNetHead(in_channels=channels[1], nb_class=nb_class, dim=dim), in_branch=["X_0_{}".format(j+1)], out_branch=[-1])