konfai 1.1.4.tar.gz → 1.1.5.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release. This version of konfai might be problematic.

Files changed (47)
  1. {konfai-1.1.4 → konfai-1.1.5}/PKG-INFO +1 -1
  2. {konfai-1.1.4 → konfai-1.1.5}/konfai/data/data_manager.py +1 -1
  3. {konfai-1.1.4 → konfai-1.1.5}/konfai/data/patching.py +14 -17
  4. {konfai-1.1.4 → konfai-1.1.5}/konfai/metric/measure.py +37 -13
  5. konfai-1.1.5/konfai/models/segmentation/NestedUNet.py +115 -0
  6. {konfai-1.1.4 → konfai-1.1.5}/konfai/network/blocks.py +32 -13
  7. {konfai-1.1.4 → konfai-1.1.5}/konfai/network/network.py +4 -3
  8. {konfai-1.1.4 → konfai-1.1.5}/konfai/predictor.py +28 -4
  9. {konfai-1.1.4 → konfai-1.1.5}/konfai/trainer.py +2 -1
  10. {konfai-1.1.4 → konfai-1.1.5}/konfai/utils/utils.py +5 -0
  11. {konfai-1.1.4 → konfai-1.1.5}/konfai.egg-info/PKG-INFO +1 -1
  12. {konfai-1.1.4 → konfai-1.1.5}/pyproject.toml +1 -1
  13. konfai-1.1.4/konfai/models/segmentation/NestedUNet.py +0 -52
  14. {konfai-1.1.4 → konfai-1.1.5}/LICENSE +0 -0
  15. {konfai-1.1.4 → konfai-1.1.5}/README.md +0 -0
  16. {konfai-1.1.4 → konfai-1.1.5}/konfai/__init__.py +0 -0
  17. {konfai-1.1.4 → konfai-1.1.5}/konfai/data/__init__.py +0 -0
  18. {konfai-1.1.4 → konfai-1.1.5}/konfai/data/augmentation.py +0 -0
  19. {konfai-1.1.4 → konfai-1.1.5}/konfai/data/transform.py +0 -0
  20. {konfai-1.1.4 → konfai-1.1.5}/konfai/evaluator.py +0 -0
  21. {konfai-1.1.4 → konfai-1.1.5}/konfai/main.py +0 -0
  22. {konfai-1.1.4 → konfai-1.1.5}/konfai/metric/__init__.py +0 -0
  23. {konfai-1.1.4 → konfai-1.1.5}/konfai/metric/schedulers.py +0 -0
  24. {konfai-1.1.4 → konfai-1.1.5}/konfai/models/classification/convNeXt.py +0 -0
  25. {konfai-1.1.4 → konfai-1.1.5}/konfai/models/classification/resnet.py +0 -0
  26. {konfai-1.1.4 → konfai-1.1.5}/konfai/models/generation/cStyleGan.py +0 -0
  27. {konfai-1.1.4 → konfai-1.1.5}/konfai/models/generation/ddpm.py +0 -0
  28. {konfai-1.1.4 → konfai-1.1.5}/konfai/models/generation/diffusionGan.py +0 -0
  29. {konfai-1.1.4 → konfai-1.1.5}/konfai/models/generation/gan.py +0 -0
  30. {konfai-1.1.4 → konfai-1.1.5}/konfai/models/generation/vae.py +0 -0
  31. {konfai-1.1.4 → konfai-1.1.5}/konfai/models/registration/registration.py +0 -0
  32. {konfai-1.1.4 → konfai-1.1.5}/konfai/models/representation/representation.py +0 -0
  33. {konfai-1.1.4 → konfai-1.1.5}/konfai/models/segmentation/UNet.py +0 -0
  34. {konfai-1.1.4 → konfai-1.1.5}/konfai/network/__init__.py +0 -0
  35. {konfai-1.1.4 → konfai-1.1.5}/konfai/utils/ITK.py +0 -0
  36. {konfai-1.1.4 → konfai-1.1.5}/konfai/utils/__init__.py +0 -0
  37. {konfai-1.1.4 → konfai-1.1.5}/konfai/utils/config.py +0 -0
  38. {konfai-1.1.4 → konfai-1.1.5}/konfai/utils/dataset.py +0 -0
  39. {konfai-1.1.4 → konfai-1.1.5}/konfai/utils/registration.py +0 -0
  40. {konfai-1.1.4 → konfai-1.1.5}/konfai.egg-info/SOURCES.txt +0 -0
  41. {konfai-1.1.4 → konfai-1.1.5}/konfai.egg-info/dependency_links.txt +0 -0
  42. {konfai-1.1.4 → konfai-1.1.5}/konfai.egg-info/entry_points.txt +0 -0
  43. {konfai-1.1.4 → konfai-1.1.5}/konfai.egg-info/requires.txt +0 -0
  44. {konfai-1.1.4 → konfai-1.1.5}/konfai.egg-info/top_level.txt +0 -0
  45. {konfai-1.1.4 → konfai-1.1.5}/setup.cfg +0 -0
  46. {konfai-1.1.4 → konfai-1.1.5}/tests/test_config.py +0 -0
  47. {konfai-1.1.4 → konfai-1.1.5}/tests/test_dataset.py +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: konfai
-Version: 1.1.4
+Version: 1.1.5
 Summary: Modular and configurable Deep Learning framework with YAML and PyTorch
 Author-email: Valentin Boussot <boussot.v@gmail.com>
 License-Expression: Apache-2.0
@@ -129,7 +129,7 @@ class DatasetIter(data.Dataset):
         for group_dest in self.groups_src[group_src]:
             self.data[group_dest][index].unloadAugmentation()
             self.data[group_dest][index].resetAugmentation()
-            self.load(label + " Augmentation")
+        self.load(label + " Augmentation")

     def load(self, label: str):
         if self.use_cache:
@@ -125,23 +125,20 @@ class Accumulator():
         return len(self.patch_slices) == len([v for v in self._layer_accumulator if v is not None])

     def assemble(self) -> torch.Tensor:
-        if all([self._layer_accumulator[0].shape[i-len(self.patch_size)] == size for i, size in enumerate(self.patch_size)]):
-            N = 2 if self.batch else 1
-
-            result = torch.zeros((list(self._layer_accumulator[0].shape[:N])+list(max([[v.stop for v in patch] for patch in self.patch_slices]))), dtype=self._layer_accumulator[0].dtype).to(self._layer_accumulator[0].device)
-            for patch_slice, data in zip(self.patch_slices, self._layer_accumulator):
-                slices_dest = tuple([slice(result.shape[i]) for i in range(N)] + list(patch_slice))
-
-                for dim, s in enumerate(patch_slice):
-                    if s.stop-s.start == 1:
-                        data = data.unsqueeze(dim=dim+N)
-                if self.patchCombine is not None:
-                    result[slices_dest] += self.patchCombine(data)
-                else:
-                    result[slices_dest] = data
-            result = result[tuple([slice(None, None)]+[slice(0, s) for s in self.shape])]
-        else:
-            result = torch.cat(tuple(self._layer_accumulator), dim=0)
+        N = 2 if self.batch else 1
+        result = torch.zeros((list(self._layer_accumulator[0].shape[:N])+list(max([[v.stop for v in patch] for patch in self.patch_slices]))), dtype=self._layer_accumulator[0].dtype).to(self._layer_accumulator[0].device)
+        for patch_slice, data in zip(self.patch_slices, self._layer_accumulator):
+            slices_dest = tuple([slice(result.shape[i]) for i in range(N)] + list(patch_slice))
+
+            for dim, s in enumerate(patch_slice):
+                if s.stop-s.start == 1:
+                    data = data.unsqueeze(dim=dim+N)
+            if self.patchCombine is not None:
+                result[slices_dest] += self.patchCombine(data)
+            else:
+                result[slices_dest] = data
+        result = result[tuple([slice(None, None)]+[slice(0, s) for s in self.shape])]
+
         self._layer_accumulator.clear()
         return result
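Note: the new assemble() drops the old shape check and always stitches patches back into one tensor by their destination slices. A minimal standalone sketch of that indexing scheme, with hypothetical 1D patches and no KonfAI types (patchCombine omitted, so overlaps would overwrite rather than blend):

import torch

# Two patches of a length-8 signal, stored with their destination slices,
# the way Accumulator keeps patch_slices and _layer_accumulator side by side.
patch_slices = [(slice(0, 4),), (slice(4, 8),)]
patches = [torch.ones(1, 4), torch.full((1, 4), 2.0)]

N = 1  # leading dims to keep whole, like the batch/channel dims above
length = max(s[-1].stop for s in patch_slices)
result = torch.zeros(patches[0].shape[0], length)

for patch_slice, data in zip(patch_slices, patches):
    dest = tuple([slice(result.shape[i]) for i in range(N)] + list(patch_slice))
    result[dest] = data  # a patchCombine would turn this into a weighted +=

print(result)  # tensor([[1., 1., 1., 1., 2., 2., 2., 2.]])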
@@ -3,6 +3,9 @@ import importlib
 import numpy as np
 import torch

+from torchvision.models import inception_v3, Inception_V3_Weights
+from torchvision.transforms import functional as F
+from scipy import linalg

 import torch.nn.functional as F
 import os
@@ -401,21 +404,42 @@ class L1LossRepresentation(Criterion):
     def forward(self, output: torch.Tensor) -> torch.Tensor:
         return self.loss(output[0], output[1])+ self._variance(output[0]) + self._variance(output[1])

-"""import torch
-import torch.nn as nn
-from torchvision.models import inception_v3, Inception_V3_Weights
-from torchvision.transforms import functional as F
-from scipy import linalg
-import numpy as np
+class FocalLoss(Criterion):
+
+    def __init__(self, gamma: float = 2.0, alpha: list[float] = [0.5, 2.0, 0.5, 0.5, 1], reduction: str = "mean"):
+        super().__init__()
+        raw_alpha = torch.tensor(alpha, dtype=torch.float32)
+        self.alpha = raw_alpha / raw_alpha.sum() * len(raw_alpha)
+        self.gamma = gamma
+        self.reduction = reduction
+
+    def forward(self, output: torch.Tensor, targets: list[torch.Tensor]) -> torch.Tensor:
+        target = targets[0].long()
+
+        logpt = F.log_softmax(output, dim=1)
+        pt = torch.exp(logpt)
+
+        logpt = logpt.gather(1, target.unsqueeze(1))
+        pt = pt.gather(1, target.unsqueeze(1))
+
+        at = self.alpha[target].unsqueeze(1)
+        loss = -at * ((1 - pt) ** self.gamma) * logpt
+
+        if self.reduction == "mean":
+            return loss.mean()
+        elif self.reduction == "sum":
+            return loss.sum()
+        return loss

+
 class FID(Criterion):

-    class InceptionV3(nn.Module):
+    class InceptionV3(torch.nn.Module):

         def __init__(self) -> None:
             super().__init__()
             self.model = inception_v3(weights=Inception_V3_Weights.DEFAULT, transform_input=False)
-            self.model.fc = nn.Identity()
+            self.model.fc = torch.nn.Identity()
             self.model.eval()

         def forward(self, x: torch.Tensor) -> torch.Tensor:
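Note: the added FocalLoss is the standard Lin et al. formulation, loss = -alpha_t * (1 - p_t)^gamma * log(p_t), with the per-class alpha weights rescaled to mean 1 in __init__. A self-contained sketch of the same computation outside KonfAI's Criterion machinery (made-up batch and class count):

import torch
import torch.nn.functional as F

gamma = 2.0
alpha = torch.tensor([0.5, 2.0, 0.5, 0.5, 1.0])
alpha = alpha / alpha.sum() * len(alpha)  # rescale weights to mean 1, as in the diff

output = torch.randn(4, 5)          # (batch, nb_class) logits
target = torch.randint(0, 5, (4,))  # ground-truth class indices

logpt = F.log_softmax(output, dim=1).gather(1, target.unsqueeze(1))  # log p_t
pt = logpt.exp()                                                     # p_t
loss = (-alpha[target].unsqueeze(1) * (1 - pt) ** gamma * logpt).mean()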
@@ -456,9 +480,9 @@ class FID(Criterion):
         generated_features = FID.get_features(generated_images, self.inception_model)

         return FID.calculate_fid(real_features, generated_features)
-"""
+

-"""class MutualInformationLoss(torch.nn.Module):
+class MutualInformationLoss(torch.nn.Module):
     def __init__(self, num_bins: int = 23, sigma_ratio: float = 0.5, smooth_nr: float = 1e-7, smooth_dr: float = 1e-7) -> None:
         super().__init__()
         bin_centers = torch.linspace(0.0, 1.0, num_bins)
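Note: the re-enabled FID compares InceptionV3 features of real and generated images; calculate_fid itself is outside this hunk's context, but the standard Fréchet distance it refers to is, as a hedged numpy/scipy sketch (not necessarily byte-identical to the class's implementation):

import numpy as np
from scipy import linalg

def frechet_distance(real_features: np.ndarray, generated_features: np.ndarray) -> float:
    # Fit a Gaussian to each feature set; rows are samples, columns are features.
    mu1, sigma1 = real_features.mean(axis=0), np.cov(real_features, rowvar=False)
    mu2, sigma2 = generated_features.mean(axis=0), np.cov(generated_features, rowvar=False)
    covmean = linalg.sqrtm(sigma1 @ sigma2)
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # drop tiny imaginary parts from numerical error
    diff = mu1 - mu2
    return float(diff @ diff + np.trace(sigma1 + sigma2 - 2.0 * covmean))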
@@ -469,13 +493,13 @@ class FID(Criterion):
         self.smooth_nr = float(smooth_nr)
         self.smooth_dr = float(smooth_dr)

-    def parzen_windowing(self, pred: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    def parzen_windowing(self, pred: torch.Tensor, target: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         pred_weight, pred_probability = self.parzen_windowing_gaussian(pred)
         target_weight, target_probability = self.parzen_windowing_gaussian(target)
         return pred_weight, pred_probability, target_weight, target_probability


-    def parzen_windowing_gaussian(self, img: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    def parzen_windowing_gaussian(self, img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         img = torch.clamp(img, 0, 1)
         img = img.reshape(img.shape[0], -1, 1) # (batch, num_sample, 1)
         weight = torch.exp(-self.preterm.to(img) * (img - self.bin_centers.to(img)) ** 2) # (batch, num_sample, num_bin)
@@ -488,4 +512,4 @@ class FID(Criterion):
         pab = torch.bmm(wa.permute(0, 2, 1), wb.to(wa)).div(wa.shape[1]) # (batch, num_bins, num_bins)
         papb = torch.bmm(pa.permute(0, 2, 1), pb.to(pa)) # (batch, num_bins, num_bins)
         mi = torch.sum(pab * torch.log((pab + self.smooth_nr) / (papb + self.smooth_dr) + self.smooth_dr), dim=(1, 2)) # (batch)
-        return torch.mean(mi).neg() # average over the batch and channel ndims"""
+        return torch.mean(mi).neg() # average over the batch and channel ndims
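Note: these parzen-windowing lines implement a differentiable mutual-information estimate: every intensity is softly assigned to histogram bins with a Gaussian kernel, the joint histogram p_ab is the batched outer product of the two weight matrices, and MI is sum over bins of p_ab * log(p_ab / (p_a * p_b)). A self-contained sketch of that pipeline (illustrative constants, not the class's exact buffers):

import torch

num_bins = 23
sigma = 0.5 / (num_bins - 1)  # kernel width; the class derives it from sigma_ratio
bin_centers = torch.linspace(0.0, 1.0, num_bins)
preterm = 1.0 / (2.0 * sigma ** 2)

def soft_histogram(img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    # img: (batch, num_sample) intensities clamped to [0, 1]
    w = torch.exp(-preterm * (img.unsqueeze(-1) - bin_centers) ** 2)  # (batch, num_sample, num_bins)
    w = w / w.sum(dim=-1, keepdim=True)    # soft bin assignment per sample
    return w, w.mean(dim=1, keepdim=True)  # weights and marginal probabilities

wa, pa = soft_histogram(torch.rand(2, 100))
wb, pb = soft_histogram(torch.rand(2, 100))
pab = torch.bmm(wa.permute(0, 2, 1), wb) / wa.shape[1]  # joint histogram (batch, bins, bins)
papb = torch.bmm(pa.permute(0, 2, 1), pb)               # product of marginals
mi = torch.sum(pab * torch.log((pab + 1e-7) / (papb + 1e-7) + 1e-7), dim=(1, 2))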
@@ -0,0 +1,115 @@
+from konfai.network import network, blocks
+from typing import Union
+import torch
+from konfai.data.patching import ModelPatch
+
+
+class NestedUNetBlock(network.ModuleArgsDict):
+
+    def __init__(self, channels: list[int], nb_conv_per_stage: int, blockConfig: blocks.BlockConfig, downSampleMode: blocks.DownSampleMode, upSampleMode: blocks.UpSampleMode, attention : bool, block: type, dim: int, i : int = 0) -> None:
+        super().__init__()
+        if i > 0:
+            self.add_module(downSampleMode.name, blocks.downSample(in_channels=channels[0], out_channels=channels[1], downSampleMode=downSampleMode, dim=dim))
+
+        self.add_module("X_{}_{}".format(i, 0), block(in_channels=channels[1 if downSampleMode == blocks.DownSampleMode.CONV_STRIDE and i > 0 else 0], out_channels=channels[1], blockConfigs=[blockConfig]*nb_conv_per_stage, dim=dim), out_branch=["X_{}_{}".format(i, 0)])
+        if len(channels) > 2:
+            self.add_module("UNetBlock_{}".format(i+1), NestedUNetBlock(channels[1:], nb_conv_per_stage, blockConfig, downSampleMode, upSampleMode, attention, block, dim, i+1), in_branch=["X_{}_{}".format(i, 0)], out_branch=["X_{}_{}".format(i+1, j) for j in range(len(channels)-2)])
+            for j in range(len(channels)-2):
+                self.add_module("X_{}_{}_{}".format(i, j+1, upSampleMode.name), blocks.upSample(in_channels=channels[2], out_channels=channels[1], upSampleMode=upSampleMode, dim=dim), in_branch=["X_{}_{}".format(i+1, j)], out_branch=["X_{}_{}".format(i+1, j)])
+                self.add_module("SkipConnection_{}_{}".format(i, j+1), blocks.Concat(), in_branch=["X_{}_{}".format(i+1, j)]+["X_{}_{}".format(i, r) for r in range(j+1)], out_branch=["X_{}_{}".format(i, j+1)])
+                self.add_module("X_{}_{}".format(i, j+1), block(in_channels=(channels[1]*(j+1)+channels[2]) if upSampleMode != blocks.UpSampleMode.CONV_TRANSPOSE else channels[1]*(j+2), out_channels=channels[1], blockConfigs=[blockConfig]*nb_conv_per_stage, dim=dim), in_branch=["X_{}_{}".format(i, j+1)], out_branch=["X_{}_{}".format(i, j+1)])
+
+class NestedUNet(network.Network):
+
+    class NestedUNetHead(network.ModuleArgsDict):
+
+        def __init__(self, in_channels: int, nb_class: int, activation: str, dim: int) -> None:
+            super().__init__()
+            self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels = in_channels, out_channels = nb_class, kernel_size = 1, stride = 1, padding = 0))
+            if activation == "Softmax":
+                self.add_module("Softmax", torch.nn.Softmax(dim=1))
+                self.add_module("Argmax", blocks.ArgMax(dim=1))
+            elif activation == "Tanh":
+                self.add_module("Tanh", torch.nn.Tanh())
+
+    def __init__( self,
+                    optimizer : network.OptimizerLoader = network.OptimizerLoader(),
+                    schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
+                    outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
+                    patch : Union[ModelPatch, None] = None,
+                    dim : int = 3,
+                    channels: list[int]=[1, 64, 128, 256, 512, 1024],
+                    nb_class: int = 2,
+                    blockConfig: blocks.BlockConfig = blocks.BlockConfig(),
+                    nb_conv_per_stage: int = 2,
+                    downSampleMode: str = "MAXPOOL",
+                    upSampleMode: str = "CONV_TRANSPOSE",
+                    attention : bool = False,
+                    blockType: str = "Conv",
+                    activation: str = "Softmax") -> None:
+        super().__init__(in_channels = channels[0], optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, patch=patch, dim = dim)
+
+        self.add_module("UNetBlock_0", NestedUNetBlock(channels, nb_conv_per_stage, blockConfig, downSampleMode=blocks.DownSampleMode._member_map_[downSampleMode], upSampleMode=blocks.UpSampleMode._member_map_[upSampleMode], attention=attention, block = blocks.ConvBlock if blockType == "Conv" else blocks.ResBlock, dim=dim), out_branch=["X_0_{}".format(j+1) for j in range(len(channels)-2)])
+        for j in range(len(channels)-2):
+            self.add_module("Head_{}".format(j), NestedUNet.NestedUNetHead(in_channels=channels[1], nb_class=nb_class, activation=activation, dim=dim), in_branch=["X_0_{}".format(j+1)], out_branch=[-1])
+
+
+
+
+class ResNetEncoderLayer(network.ModuleArgsDict):
+
+    def __init__(self, in_channel: int, out_channel: int, nb_block: int, dim: int, downSampleMode : blocks.DownSampleMode):
+        super().__init__()
+        for i in range(nb_block):
+            if downSampleMode == blocks.DownSampleMode.MAXPOOL and i == 0:
+                self.add_module("DownSample", blocks.getTorchModule("MaxPool", dim)(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False))
+            self.add_module("ResBlock_{}".format(i), blocks.ResBlock(in_channel, out_channel, [blocks.BlockConfig(3, 2 if downSampleMode == blocks.DownSampleMode.CONV_STRIDE and i == 0 else 1, 1, False, "ReLU;True", blocks.NormMode.BATCH), blocks.BlockConfig(3, 1, 1, False, None, blocks.NormMode.BATCH)], dim=dim))
+            in_channel = out_channel
+
+def resNetEncoder(channels: list[int], layers: list[int], dim: int) -> list[torch.nn.Module]:
+    modules = []
+    modules.append(blocks.ConvBlock(channels[0], channels[1], [blocks.BlockConfig(7, 2, 3, False, "ReLU", blocks.NormMode.BATCH)], dim=dim))
+    for i, (in_channel, out_channel, layer) in enumerate(zip(channels[1:], channels[2:], layers)):
+        modules.append(ResNetEncoderLayer(in_channel, out_channel, layer, dim, blocks.DownSampleMode.MAXPOOL if i==0 else blocks.DownSampleMode.CONV_STRIDE))
+    return modules
+
+class NestedUNetBlock(network.ModuleArgsDict):
+
+    def __init__(self, encoder_channels: list[int], decoder_channels: list[int], encoders: list[torch.nn.Module], upSampleMode: blocks.UpSampleMode, dim: int, i : int = 0) -> None:
+        super().__init__()
+        self.add_module("X_{}_{}".format(i, 0), encoders[0], out_branch=["X_{}_{}".format(i, 0)])
+        if len(encoder_channels) > 2:
+            self.add_module("UNetBlock_{}".format(i+1), NestedUNetBlock(encoder_channels[1:], decoder_channels[1:], encoders[1:], upSampleMode, dim, i+1), in_branch=["X_{}_{}".format(i, 0)], out_branch=["X_{}_{}".format(i+1, j) for j in range(len(encoder_channels)-2)])
+            for j in range(len(encoder_channels)-2):
+                self.add_module("X_{}_{}_{}".format(i, j+1, upSampleMode.name), blocks.upSample(in_channels=encoder_channels[2], out_channels=decoder_channels[2] if j == len(encoder_channels)-3 else encoder_channels[1], upSampleMode=upSampleMode, dim=dim), in_branch=["X_{}_{}".format(i+1, j)], out_branch=["X_{}_{}".format(i+1, j)])
+                self.add_module("SkipConnection_{}_{}".format(i, j+1), blocks.Concat(), in_branch=["X_{}_{}".format(i+1, j)]+["X_{}_{}".format(i, r) for r in range(j+1)], out_branch=["X_{}_{}".format(i, j+1)])
+                self.add_module("X_{}_{}".format(i, j+1), blocks.ConvBlock(in_channels=encoder_channels[1]*j+((encoder_channels[1]+encoder_channels[2]) if upSampleMode != blocks.UpSampleMode.CONV_TRANSPOSE else (decoder_channels[1]+decoder_channels[2] if j == len(encoder_channels)-3 else encoder_channels[1]+encoder_channels[1])), out_channels=decoder_channels[2] if j == len(encoder_channels)-3 else encoder_channels[1], blockConfigs=[blocks.BlockConfig(3,1,1,False, "ReLU;True", blocks.NormMode.BATCH)]*2, dim=dim), in_branch=["X_{}_{}".format(i, j+1)], out_branch=["X_{}_{}".format(i, j+1)])
+
+class NestedUNetHead(network.ModuleArgsDict):
+
+    def __init__(self, in_channels: int, out_channels: int, nb_class: int, dim: int) -> None:
+        super().__init__()
+        self.add_module("Upsample", blocks.upSample(in_channels=in_channels, out_channels=out_channels, upSampleMode=blocks.UpSampleMode.UPSAMPLE_BILINEAR, dim=dim))
+        self.add_module("ConvBlock", blocks.ConvBlock(in_channels=in_channels, out_channels=out_channels, blockConfigs=[blocks.BlockConfig(3,1,1,False, "ReLU;True", blocks.NormMode.BATCH)]*2, dim=dim))
+        self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels = out_channels, out_channels = nb_class, kernel_size = 3, stride = 1, padding = 1))
+        if nb_class > 1:
+            self.add_module("Softmax", torch.nn.Softmax(dim=1))
+            self.add_module("Argmax", blocks.ArgMax(dim=1))
+        else:
+            self.add_module("Tanh", torch.nn.Tanh())
+
+class UNetpp(network.Network):
+
+
+    def __init__( self,
+                    optimizer : network.OptimizerLoader = network.OptimizerLoader(),
+                    schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
+                    outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
+                    patch : Union[ModelPatch, None] = None,
+                    encoder_channels: list[int] = [1,64,64,128,256,512],
+                    decoder_channels: list[int] = [256,128,64,32,16,1],
+                    layers: list[int] = [3,4,6,3],
+                    dim : int = 2) -> None:
+        super().__init__(in_channels = encoder_channels[0], optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, patch=patch, dim = dim)
+        self.add_module("Block_0", NestedUNetBlock(encoder_channels, decoder_channels[::-1], resNetEncoder(encoder_channels, layers, dim), blocks.UpSampleMode.UPSAMPLE_BILINEAR, dim=dim), out_branch=["X_0_{}".format(j+1) for j in range(len(encoder_channels)-2)])
+        self.add_module("Head", NestedUNetHead(in_channels=decoder_channels[-3], out_channels=decoder_channels[-2], nb_class=decoder_channels[-1], dim=dim), in_branch=["X_0_{}".format(len(encoder_channels)-2)], out_branch=[-1])
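Note: the node naming follows the UNet++ convention: X_i_j is the j-th decoder node at depth i, and each SkipConnection_i_j concatenates the upsampled deeper node with every earlier node at the same depth. A tiny sketch that just enumerates the dense-skip wiring the constructor builds (pure Python, no KonfAI imports, 4 depths for brevity):

# Which nodes feed each decoder node X_i_j, matching the in_branch lists above.
depth = 4
for i in range(depth):
    for j in range(1, depth - i):
        inputs = ["up(X_{}_{})".format(i + 1, j - 1)] + ["X_{}_{}".format(i, r) for r in range(j)]
        print("X_{}_{} <- concat({})".format(i, j, ", ".join(inputs)))
# X_0_1 <- concat(up(X_1_0), X_0_0)
# X_0_2 <- concat(up(X_1_1), X_0_0, X_0_1)
# X_0_3 <- concat(up(X_1_2), X_0_0, X_0_1, X_0_2)
# ...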
@@ -19,7 +19,7 @@ class NormMode(Enum):
     SYNCBATCH = 5
     INSTANCE_AFFINE = 6

-def getNorm(normMode: Enum, channels : int, dim: int) -> torch.nn.Module:
+def getNorm(normMode: Enum, channels : int, dim: int) -> Union[torch.nn.Module, None]:
     if normMode == NormMode.BATCH:
         return getTorchModule("BatchNorm", dim = dim)(channels, affine=True, track_running_stats=True)
     if normMode == NormMode.INSTANCE:
@@ -32,7 +32,7 @@ def getNorm(normMode: Enum, channels : int, dim: int) -> torch.nn.Module:
         return torch.nn.GroupNorm(num_groups=32, num_channels=channels)
     if normMode == NormMode.LAYER:
         return torch.nn.GroupNorm(num_groups=1, num_channels=channels)
-    return torch.nn.Identity()
+    return None

 class UpSampleMode(Enum):
     CONV_TRANSPOSE = 0,
@@ -59,9 +59,9 @@ class BlockConfig():
         self.stride = stride
         self.padding = padding
         self.activation = activation
-        self.normMode = normMode
-
-        if isinstance(normMode, str):
+        if normMode is None:
+            self.norm = None
+        elif isinstance(normMode, str):
             self.norm = NormMode._member_map_[normMode]
         elif isinstance(normMode, NormMode):
             self.norm = normMode
@@ -70,9 +70,13 @@ class BlockConfig():
         return getTorchModule("Conv", dim = dim)(in_channels = in_channels, out_channels = out_channels, kernel_size = self.kernel_size, stride = self.stride, padding = self.padding, bias=self.bias)

     def getNorm(self, channels : int, dim: int) -> torch.nn.Module:
+        if self.norm is None:
+            return None
         return getNorm(self.norm, channels, dim) if isinstance(self.norm, NormMode) else self.norm(channels)

     def getActivation(self) -> torch.nn.Module:
+        if self.activation is None:
+            return None
         if isinstance(self.activation, str):
             return getTorchModule(self.activation.split(";")[0])(*[ast.literal_eval(value) for value in self.activation.split(";")[1:]]) if self.activation != "None" else torch.nn.Identity()
         return self.activation()
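Note: with this change, a BlockConfig whose activation and normMode are both None now yields a bare convolution, since ConvBlock (next hunk) skips the Norm/Activation entries entirely instead of inserting torch.nn.Identity placeholders. A hedged sketch, assuming the positional argument order used elsewhere in this diff (kernel_size, stride, padding, bias, activation, normMode):

cfg = blocks.BlockConfig(3, 1, 1, False, None, None)
assert cfg.getNorm(channels=64, dim=2) is None
assert cfg.getActivation() is None
conv = cfg.getConv(in_channels=64, out_channels=64, dim=2)  # plain Conv2d, no Identity wrappers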
@@ -83,21 +87,35 @@ class ConvBlock(network.ModuleArgsDict):
         super().__init__()
         for i, blockConfig in enumerate(blockConfigs):
             self.add_module("Conv_{}".format(i), blockConfig.getConv(in_channels, out_channels, dim), alias=alias[0])
-            self.add_module("Norm_{}".format(i), blockConfig.getNorm(out_channels, dim), alias=alias[1])
-            self.add_module("Activation_{}".format(i), blockConfig.getActivation(), alias=alias[2])
+            norm = blockConfig.getNorm(out_channels, dim)
+            if norm is not None:
+                self.add_module("Norm_{}".format(i), norm, alias=alias[1])
+            activation = blockConfig.getActivation()
+            if activation is not None:
+                self.add_module("Activation_{}".format(i), activation, alias=alias[2])
             in_channels = out_channels

 class ResBlock(network.ModuleArgsDict):

-    def __init__(self, in_channels : int, out_channels : int, nb_conv: int, blockConfig : BlockConfig, dim : int, alias : list[list[str]]=[[], [], [], []]) -> None:
+    def __init__(self, in_channels : int, out_channels : int, blockConfigs : list[BlockConfig], dim : int, alias : list[list[str]]=[[], [], [], [], []]) -> None:
         super().__init__()
-        for i in range(nb_conv):
-            self.add_module("ConvBlock_{}".format(i), ConvBlock(in_channels, out_channels, nb_conv=1, blockConfig=blockConfig, dim=dim, alias=alias[:3]))
+        for i, blockConfig in enumerate(blockConfigs):
+            self.add_module("Conv_{}".format(i), blockConfig.getConv(in_channels, out_channels, dim), alias=alias[0])
+            norm = blockConfig.getNorm(out_channels, dim)
+            if norm is not None:
+                self.add_module("Norm_{}".format(i), norm, alias=alias[1])
+            activation = blockConfig.getActivation()
+            if activation is not None:
+                self.add_module("Activation_{}".format(i), activation, alias=alias[2])
+
             if in_channels != out_channels:
-                self.add_module("Conv_skip_{}".format(i), blockConfig.getConv(in_channels, out_channels, dim), alias=alias[3], in_branch=[1], out_branch=[1])
+                self.add_module("Conv_skip", getTorchModule("Conv", dim)(in_channels, out_channels, 1, blockConfig.stride, bias=blockConfig.bias), alias=alias[3], in_branch=[1], out_branch=[1])
+                self.add_module("Norm_skip", blockConfig.getNorm(out_channels, dim), alias=alias[4], in_branch=[1], out_branch=[1])
             in_channels = out_channels
-            self.add_module("Add_{}".format(i), Add(), in_branch=[0,1], out_branch=[0,1])
-
+
+        self.add_module("Add", Add(), in_branch=[0,1])
+        self.add_module("Norm_{}".format(i+1), torch.nn.ReLU(inplace=True))
+
 def downSample(in_channels: int, out_channels: int, downSampleMode: DownSampleMode, dim: int) -> torch.nn.Module:
     if downSampleMode == DownSampleMode.MAXPOOL:
         return getTorchModule("MaxPool", dim = dim)(2)
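Note: the reworked ResBlock now assembles a torchvision-style residual unit: a stack of conv/norm/activation triples driven by blockConfigs, a 1x1 strided projection plus norm on the skip branch when channel counts differ, and a final Add followed by ReLU. A plain-PyTorch equivalent as a hedged sketch (2D only, one hard-coded config):

import torch

class BasicResBlock(torch.nn.Module):

    def __init__(self, in_channels: int, out_channels: int, stride: int = 1) -> None:
        super().__init__()
        self.body = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels, out_channels, 3, stride, 1, bias=False),
            torch.nn.BatchNorm2d(out_channels),
            torch.nn.ReLU(inplace=True),
            torch.nn.Conv2d(out_channels, out_channels, 3, 1, 1, bias=False),
            torch.nn.BatchNorm2d(out_channels),  # second config: norm, no activation
        )
        self.skip: torch.nn.Module = torch.nn.Identity()
        if in_channels != out_channels or stride != 1:
            # 1x1 strided projection + norm, mirroring Conv_skip / Norm_skip above
            self.skip = torch.nn.Sequential(
                torch.nn.Conv2d(in_channels, out_channels, 1, stride, bias=False),
                torch.nn.BatchNorm2d(out_channels),
            )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return torch.relu(self.body(x) + self.skip(x))  # Add, then ReLU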
@@ -176,6 +194,7 @@ class Print(torch.nn.Module):
         super().__init__()

     def forward(self, input: torch.Tensor) -> torch.Tensor:
+        print(input.shape)
         return input

 class Write(torch.nn.Module):
@@ -133,9 +133,10 @@ class Measure():
         self._loss.clear()

     def add(self, weight: float, value: torch.Tensor) -> None:
-        self._loss.append(value if self.isLoss else value.detach())
-        self._values.append(value.item())
-        self._weight.append(weight)
+        if self.isLoss or value != 0:
+            self._loss.append(value if self.isLoss else value.detach())
+            self._values.append(value.item())
+            self._weight.append(weight)

     def getLastLoss(self) -> torch.Tensor:
         return self._loss[-1]*self._weight[-1] if len(self._loss) else torch.zeros((1), requires_grad = True)
@@ -8,7 +8,7 @@ import os

 from konfai import MODELS_DIRECTORY, PREDICTIONS_DIRECTORY, CONFIG_FILE, MODEL, KONFAI_ROOT
 from konfai.utils.config import config
-from konfai.utils.utils import State, get_patch_slices_from_nb_patch_per_dim, NeedDevice, _getModule, DistributedObject, DataLog, description
+from konfai.utils.utils import State, get_patch_slices_from_nb_patch_per_dim, NeedDevice, _getModule, DistributedObject, DataLog, description, PredictorError
 from konfai.utils.dataset import Dataset, Attribute
 from konfai.data.data_manager import DataPrediction, DatasetIter
 from konfai.data.patching import Accumulator, PathCombine
@@ -46,7 +46,7 @@ class OutDataset(Dataset, NeedDevice, ABC):
         self.names: dict[int, str] = {}
         self.nb_data_augmentation = 0

-    def load(self, name_layer: str, datasets: list[Dataset]):
+    def load(self, name_layer: str, datasets: list[Dataset], groups: dict[str, str]):
         transforms_type = ["pre_transforms", "post_transforms", "final_transforms"]
         for name, _transform_type, transform_type in [(k, getattr(self, "_{}".format(k)), getattr(self, k)) for k in transforms_type]:

@@ -137,9 +137,22 @@ class OutSameAsGroupDataset(OutDataset):
         if self.inverse_transform:
             for transform in reversed(dataset.groups_src[self.group_src][self.group_dest].post_transforms):
                 layer = transform.inverse(self.names[index_dataset], layer, self.attributes[index_dataset][index_augmentation][index_patch])
-
         self.output_layer_accumulator[index_dataset][index_augmentation].addLayer(index_patch, layer)

+
+    def load(self, name_layer: str, datasets: list[Dataset], groups: dict[str, str]):
+        super().load(name_layer, datasets, groups)
+
+        if self.group_src not in groups.keys():
+            raise PredictorError(
+                f"Source group '{self.group_src}' not found. Available groups: {list(groups.keys())}."
+            )
+
+        if self.group_dest not in groups[self.group_src]:
+            raise PredictorError(
+                f"Destination group '{self.group_dest}' not found. Available groups: {groups[self.group_src]}."
+            )
+
     def _getOutput(self, index: int, index_augmentation: int, dataset: DatasetIter) -> torch.Tensor:
         layer = self.output_layer_accumulator[index][index_augmentation].assemble()
         name = self.names[index]
@@ -403,9 +416,20 @@ class Predictor(DistributedObject):

         shutil.copyfile(CONFIG_FILE(), self.predict_path+"Prediction.yml")

+
         self.model.init(self.autocast, State.PREDICTION, self.dataset.getGroupsDest())
         self.model.init_outputsGroup()
         self.model._compute_channels_trace(self.model, self.model.in_channels, None, self.gpu_checkpoints)
+
+        modules = []
+        for i,_,_ in self.model.named_ModuleArgsDict():
+            modules.append(i)
+        for output_group in self.outsDataset.keys():
+            if output_group not in modules:
+                raise PredictorError("The output group '{}' defined in 'outputsCriterions' does not correspond to any module in the model.".format(output_group),
+                                     "Available modules: {}".format(modules),
+                                     "Please check that the name matches exactly a submodule or output of your model architecture."
+                                     )
         self.modelComposite = ModelComposite(self.model, len(MODEL().split(":")), self.combine)
         self.modelComposite.load(self._load())

@@ -415,7 +439,7 @@ class Predictor(DistributedObject):
         self.size = (len(self.gpu_checkpoints)+1 if self.gpu_checkpoints else 1)
         self.dataloader = self.dataset.getData(world_size//self.size)
         for name, outDataset in self.outsDataset.items():
-            outDataset.load(name.replace(".", ":"), list(self.dataset.datasets.values()))
+            outDataset.load(name.replace(".", ":"), list(self.dataset.datasets.values()), {src : dest for src, inner in self.dataset.groups_src.items() for dest in inner})


     def run_process(self, world_size: int, global_rank: int, local_rank: int, dataloaders: list[DataLoader]):
@@ -120,7 +120,8 @@ class _Trainer():

     def run(self) -> None:
         self.dataloader_training.dataset.load("Train")
-        self.dataloader_validation.dataset.load("Validation")
+        if self.dataloader_validation is not None:
+            self.dataloader_validation.dataset.load("Validation")
         with tqdm.tqdm(iterable = range(self.epoch, self.epochs), leave=False, total=self.epochs, initial=self.epoch, desc="Progress") as epoch_tqdm:
             for self.epoch in epoch_tqdm:
                 self.train()
@@ -586,6 +586,11 @@ class EvaluatorError(KonfAIError):
     def __init__(self, *message) -> None:
         super().__init__("Evaluator", message)

+class PredictorError(KonfAIError):
+
+    def __init__(self, *message) -> None:
+        super().__init__("Predictor", message)
+
 class TransformError(KonfAIError):

     def __init__(self, *message) -> None:
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: konfai
-Version: 1.1.4
+Version: 1.1.5
 Summary: Modular and configurable Deep Learning framework with YAML and PyTorch
 Author-email: Valentin Boussot <boussot.v@gmail.com>
 License-Expression: Apache-2.0
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

 [project]
 name = "konfai"
-version = "1.1.4"
+version = "1.1.5"
 description = "Modular and configurable Deep Learning framework with YAML and PyTorch"
 readme = "README.md"
 requires-python = ">=3.8"
@@ -1,52 +0,0 @@
-from typing import Union
-import torch
-
-from konfai.network import network, blocks
-from konfai.utils.config import config
-from konfai.data.patching import ModelPatch
-
-
-class NestedUNetBlock(network.ModuleArgsDict):
-
-    def __init__(self, channels: list[int], nb_conv_per_stage: int, blockConfig: blocks.BlockConfig, downSampleMode: blocks.DownSampleMode, upSampleMode: blocks.UpSampleMode, attention : bool, block: type, dim: int, i : int = 0) -> None:
-        super().__init__()
-        if i > 0:
-            self.add_module(downSampleMode.name, blocks.downSample(in_channels=channels[0], out_channels=channels[1], downSampleMode=downSampleMode, dim=dim))
-
-        self.add_module("X_{}_{}".format(i, 0), block(in_channels=channels[1 if downSampleMode == blocks.DownSampleMode.CONV_STRIDE and i > 0 else 0], out_channels=channels[1], blockConfigs=[blockConfig]*nb_conv_per_stage, dim=dim), out_branch=["X_{}_{}".format(i, 0)])
-        if len(channels) > 2:
-            self.add_module("UNetBlock_{}".format(i+1), NestedUNetBlock(channels[1:], nb_conv_per_stage, blockConfig, downSampleMode, upSampleMode, attention, block, dim, i+1), in_branch=["X_{}_{}".format(i, 0)], out_branch=["X_{}_{}".format(i+1, j) for j in range(len(channels)-2)])
-            for j in range(len(channels)-2):
-                self.add_module("X_{}_{}_{}".format(i, j+1, upSampleMode.name), blocks.upSample(in_channels=channels[2], out_channels=channels[1], upSampleMode=upSampleMode, dim=dim), in_branch=["X_{}_{}".format(i+1, j)], out_branch=["X_{}_{}".format(i+1, j)])
-                self.add_module("SkipConnection_{}_{}".format(i, j+1), blocks.Concat(), in_branch=["X_{}_{}".format(i+1, j)]+["X_{}_{}".format(i, r) for r in range(j+1)], out_branch=["X_{}_{}".format(i, j+1)])
-                self.add_module("X_{}_{}".format(i, j+1), block(in_channels=(channels[1]*(j+1)+channels[2]) if upSampleMode != blocks.UpSampleMode.CONV_TRANSPOSE else channels[1]*(j+2), out_channels=channels[1], blockConfigs=[blockConfig]*nb_conv_per_stage, dim=dim), in_branch=["X_{}_{}".format(i, j+1)], out_branch=["X_{}_{}".format(i, j+1)])
-
-class NestedUNet(network.Network):
-
-    class NestedUNetHead(network.ModuleArgsDict):
-
-        def __init__(self, in_channels: int, nb_class: int, dim: int) -> None:
-            super().__init__()
-            self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels = in_channels, out_channels = nb_class, kernel_size = 1, stride = 1, padding = 0))
-            self.add_module("Softmax", torch.nn.Softmax(dim=1))
-            self.add_module("Argmax", blocks.ArgMax(dim=1))
-
-    def __init__( self,
-                    optimizer : network.OptimizerLoader = network.OptimizerLoader(),
-                    schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
-                    outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
-                    patch : Union[ModelPatch, None] = None,
-                    dim : int = 3,
-                    channels: list[int]=[1, 64, 128, 256, 512, 1024],
-                    nb_class: int = 2,
-                    blockConfig: blocks.BlockConfig = blocks.BlockConfig(),
-                    nb_conv_per_stage: int = 2,
-                    downSampleMode: str = "MAXPOOL",
-                    upSampleMode: str = "CONV_TRANSPOSE",
-                    attention : bool = False,
-                    blockType: str = "Conv") -> None:
-        super().__init__(in_channels = channels[0], optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, patch=patch, dim = dim)
-
-        self.add_module("UNetBlock_0", NestedUNetBlock(channels, nb_conv_per_stage, blockConfig, downSampleMode=blocks.DownSampleMode._member_map_[downSampleMode], upSampleMode=blocks.UpSampleMode._member_map_[upSampleMode], attention=attention, block = blocks.ConvBlock if blockType == "Conv" else blocks.ResBlock, dim=dim), out_branch=["X_0_{}".format(j+1) for j in range(len(channels)-2)])
-        for j in range(len(channels)-2):
-            self.add_module("Head_{}".format(j), NestedUNet.NestedUNetHead(in_channels=channels[1], nb_class=nb_class, dim=dim), in_branch=["X_0_{}".format(j+1)], out_branch=[-1])