konfai 1.1.3-py3-none-any.whl → 1.1.5-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

This version of konfai has been flagged as potentially problematic.

konfai/data/augmentation.py CHANGED

@@ -134,7 +134,6 @@ class EulerTransform(DataAugmentation):
 
 class Translate(EulerTransform):
 
-    @config("Translate")
     def __init__(self, t_min: float = -10, t_max = 10, is_int: bool = False):
         super().__init__()
         self.t_min = t_min
@@ -152,7 +151,6 @@ class Translate(EulerTransform):
 
 class Rotate(EulerTransform):
 
-    @config("Rotate")
     def __init__(self, a_min: float = 0, a_max: float = 360, is_quarter: bool = False):
         super().__init__()
         self.a_min = a_min
@@ -174,7 +172,6 @@ class Rotate(EulerTransform):
 
 class Scale(EulerTransform):
 
-    @config("Scale")
     def __init__(self, s_std: float = 0.2):
         super().__init__()
         self.s_std = s_std
@@ -187,7 +184,6 @@ class Scale(EulerTransform):
 
 class Flip(DataAugmentation):
 
-    @config("Flip")
     def __init__(self, f_prob: Union[list[float], None] = [0.33, 0.33 ,0.33]) -> None:
         super().__init__()
         self.f_prob = f_prob
@@ -211,7 +207,6 @@ class Flip(DataAugmentation):
 
 class ColorTransform(DataAugmentation):
 
-    @config("ColorTransform")
     def __init__(self, groups: Union[list[str], None] = None) -> None:
         super().__init__(groups)
         self.matrix: dict[int, list[torch.Tensor]] = {}
@@ -236,7 +231,6 @@ class ColorTransform(DataAugmentation):
 
 class Brightness(ColorTransform):
 
-    @config("Brightness")
     def __init__(self, b_std: float, groups: Union[list[str], None] = None) -> None:
         super().__init__(groups)
         self.b_std = b_std
@@ -248,7 +242,6 @@ class Brightness(ColorTransform):
 
 class Contrast(ColorTransform):
 
-    @config("Contrast")
     def __init__(self, c_std: float, groups: Union[list[str], None] = None) -> None:
         super().__init__(groups)
         self.c_std = c_std
@@ -260,7 +253,6 @@ class Contrast(ColorTransform):
 
 class LumaFlip(ColorTransform):
 
-    @config("LumaFlip")
     def __init__(self, groups: Union[list[str], None] = None) -> None:
         super().__init__(groups)
         self.v = torch.tensor([1, 1, 1, 0])/torch.sqrt(torch.tensor(3))
@@ -285,7 +277,6 @@ class HUE(ColorTransform):
 
 class Saturation(ColorTransform):
 
-    @config("Saturation")
     def __init__(self, s_std: float, groups: Union[list[str], None] = None) -> None:
         super().__init__(groups)
         self.s_std = s_std
@@ -372,7 +363,6 @@ class Saturation(ColorTransform):
 
 class Noise(DataAugmentation):
 
-    @config("Noise")
     def __init__(self, n_std: float, noise_step: int=1000, beta_start: float = 1e-4, beta_end: float = 0.02, groups: Union[list[str], None] = None) -> None:
         super().__init__(groups)
         self.n_std = n_std
@@ -427,7 +417,6 @@ class Noise(DataAugmentation):
 
 class CutOUT(DataAugmentation):
 
-    @config("CutOUT")
     def __init__(self, c_prob: float, cutout_size: int, value: float, groups: Union[list[str], None] = None) -> None:
         super().__init__(groups)
         self.c_prob = c_prob
@@ -458,7 +447,6 @@ class CutOUT(DataAugmentation):
 
 class Elastix(DataAugmentation):
 
-    @config("Elastix")
     def __init__(self, grid_spacing: int = 16, max_displacement: int = 16) -> None:
         super().__init__()
         self.grid_spacing = grid_spacing
@@ -529,7 +517,6 @@ class Elastix(DataAugmentation):
 
 class Permute(DataAugmentation):
 
-    @config("Permute")
     def __init__(self, prob_permute: Union[list[float], None] = [0.5 ,0.5]) -> None:
         super().__init__()
         self._permute_dims = torch.tensor([[0, 2, 1, 3], [0, 3, 1, 2]])
@@ -567,7 +554,6 @@ class Permute(DataAugmentation):
 
 class Mask(DataAugmentation):
 
-    @config("Mask")
     def __init__(self, mask: str, value: float, groups: Union[list[str], None] = None) -> None:
         super().__init__(groups)
         if mask is not None:
konfai/data/data_manager.py CHANGED

@@ -129,7 +129,7 @@ class DatasetIter(data.Dataset):
         for group_dest in self.groups_src[group_src]:
             self.data[group_dest][index].unloadAugmentation()
             self.data[group_dest][index].resetAugmentation()
-        self.load(label + " Augmentation")
+        self.load(label + " Augmentation")
 
     def load(self, label: str):
         if self.use_cache:
@@ -217,7 +217,6 @@ class Subset():
         else:
             names_filtred = names
         size = len(names_filtred)
-
        index = []
        if self.subset is None:
            index = list(range(0, size))
@@ -226,15 +225,24 @@
            r = np.clip(np.asarray([int(self.subset.split(":")[0]), int(self.subset.split(":")[1])]), 0, size)
            index = list(range(r[0], r[1]))
        elif os.path.exists(self.subset):
-            validation_names = []
+            train_names = []
            with open(self.subset, "r") as f:
                for name in f:
-                    validation_names.append(name.strip())
+                    train_names.append(name.strip())
            index = []
            for i, name in enumerate(names_filtred):
-                if name in validation_names:
+                if name in train_names:
                    index.append(i)
-
+        elif self.subset.startswith("~") and os.path.exists(self.subset[1:]):
+            exclude_names = []
+            with open(self.subset[1:], "r") as f:
+                for name in f:
+                    exclude_names.append(name.strip())
+            index = []
+            for i, name in enumerate(names_filtred):
+                if name not in exclude_names:
+                    index.append(i)
+
        elif isinstance(self.subset, list):
            if len(self.subset) > 0:
                if isinstance(self.subset[0], int):
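
Note on the `Subset` hunk above: 1.1.5 renames `validation_names` to `train_names` and adds a negated form, where a `subset` value beginning with `~` names a file of entries to exclude rather than keep. A minimal sketch of the selection rule (the function name, file names, and sample data are illustrative, not from the package):

import os

def select_indices(subset: str, names: list[str]) -> list[int]:
    if os.path.exists(subset):
        # keep only the names listed in the file
        with open(subset) as f:
            keep = {line.strip() for line in f}
        return [i for i, n in enumerate(names) if n in keep]
    if subset.startswith("~") and os.path.exists(subset[1:]):
        # inverted form: drop the names listed in the file
        with open(subset[1:]) as f:
            drop = {line.strip() for line in f}
        return [i for i, n in enumerate(names) if n not in drop]
    return list(range(len(names)))

So a config value like `train.txt` keeps only the listed cases, while `~exclude.txt` keeps everything except them.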
konfai/data/patching.py CHANGED
@@ -125,23 +125,20 @@ class Accumulator():
         return len(self.patch_slices) == len([v for v in self._layer_accumulator if v is not None])
 
     def assemble(self) -> torch.Tensor:
-        if all([self._layer_accumulator[0].shape[i-len(self.patch_size)] == size for i, size in enumerate(self.patch_size)]):
-            N = 2 if self.batch else 1
-
-            result = torch.zeros((list(self._layer_accumulator[0].shape[:N])+list(max([[v.stop for v in patch] for patch in self.patch_slices]))), dtype=self._layer_accumulator[0].dtype).to(self._layer_accumulator[0].device)
-            for patch_slice, data in zip(self.patch_slices, self._layer_accumulator):
-                slices_dest = tuple([slice(result.shape[i]) for i in range(N)] + list(patch_slice))
-
-                for dim, s in enumerate(patch_slice):
-                    if s.stop-s.start == 1:
-                        data = data.unsqueeze(dim=dim+N)
-                if self.patchCombine is not None:
-                    result[slices_dest] += self.patchCombine(data)
-                else:
-                    result[slices_dest] = data
-            result = result[tuple([slice(None, None)]+[slice(0, s) for s in self.shape])]
-        else:
-            result = torch.cat(tuple(self._layer_accumulator), dim=0)
+        N = 2 if self.batch else 1
+        result = torch.zeros((list(self._layer_accumulator[0].shape[:N])+list(max([[v.stop for v in patch] for patch in self.patch_slices]))), dtype=self._layer_accumulator[0].dtype).to(self._layer_accumulator[0].device)
+        for patch_slice, data in zip(self.patch_slices, self._layer_accumulator):
+            slices_dest = tuple([slice(result.shape[i]) for i in range(N)] + list(patch_slice))
+
+            for dim, s in enumerate(patch_slice):
+                if s.stop-s.start == 1:
+                    data = data.unsqueeze(dim=dim+N)
+            if self.patchCombine is not None:
+                result[slices_dest] += self.patchCombine(data)
+            else:
+                result[slices_dest] = data
+        result = result[tuple([slice(None, None)]+[slice(0, s) for s in self.shape])]
+
         self._layer_accumulator.clear()
         return result
 
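Note on the `Accumulator.assemble` hunk: the patch-size guard and its `torch.cat` fallback are gone, so assembly now always pastes each patch into a zero-initialized tensor at its slice and crops to `self.shape`. A self-contained sketch of that paste-and-accumulate idea (shapes and slices invented for illustration):

import torch

# Two 1D "patches" covering overlapping slices of a length-10 signal.
patch_slices = [(slice(0, 6),), (slice(4, 10),)]
patches = [torch.ones(6), torch.ones(6)]

result = torch.zeros(max(s[0].stop for s in patch_slices))
for slc, data in zip(patch_slices, patches):
    result[slc] += data  # overlap accumulates, as when patchCombine is set

print(result)  # indices 4..5 sum to 2.0 where the patches overlap
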
konfai/evaluator.py CHANGED
@@ -102,7 +102,7 @@ class Evaluator(DistributedObject):
         result = {}
         for output_group in self.metrics:
             for target_group in self.metrics[output_group]:
-                targets = [data_dict[group][0].to(0) if torch.cuda.is_available() else data_dict[group][0] for group in target_group.split("/") if group in data_dict]
+                targets = [data_dict[group][0].to(0) if torch.cuda.is_available() else data_dict[group][0] for group in target_group.split(";") if group in data_dict]
                 name = data_dict[output_group][1][0]
                 for metric in self.metrics[output_group][target_group]:
                     result["{}:{}:{}".format(output_group, target_group, metric.__class__.__name__)] = metric(data_dict[output_group][0].to(0) if torch.cuda.is_available() else data_dict[output_group][0], *targets).item()
@@ -135,8 +135,11 @@ class Evaluator(DistributedObject):
             f"Available groups: {sorted(groupsDest)}"
         )
 
-        target_groups = {target for targets in self.metrics.values() for target in targets}
-        missing_targets = target_groups - set(groupsDest)
+        target_groups = []
+        for i in {target for targets in self.metrics.values() for target in targets}:
+            for u in i.split(";"):
+                target_groups.append(u)
+        missing_targets = set(target_groups) - set(groupsDest)
         if missing_targets:
             raise EvaluatorError(
                 f"The following metric target groups are missing from 'groupsDest': {sorted(missing_targets)}. ",
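
Note: the separator for multi-target metric groups changes from `/` to `;` here and in `Measure` (network.py below), so one output can be scored against several destination groups at once. A sketch of the convention with invented group names:

target_group = "mask;weight"          # hypothetical metric target entry
data_dict = {"mask": "mask-tensor", "weight": "weight-tensor"}

targets = [data_dict[g] for g in target_group.split(";") if g in data_dict]
print(targets)  # both entries are passed to the metric, in order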
konfai/metric/measure.py CHANGED
@@ -3,6 +3,9 @@ import importlib
 import numpy as np
 import torch
 
+from torchvision.models import inception_v3, Inception_V3_Weights
+from torchvision.transforms import functional as F
+from scipy import linalg
 
 import torch.nn.functional as F
 import os
@@ -401,21 +404,42 @@ class L1LossRepresentation(Criterion):
     def forward(self, output: torch.Tensor) -> torch.Tensor:
         return self.loss(output[0], output[1])+ self._variance(output[0]) + self._variance(output[1])
 
-"""import torch
-import torch.nn as nn
-from torchvision.models import inception_v3, Inception_V3_Weights
-from torchvision.transforms import functional as F
-from scipy import linalg
-import numpy as np
+class FocalLoss(Criterion):
+
+    def __init__(self, gamma: float = 2.0, alpha: list[float] = [0.5, 2.0, 0.5, 0.5, 1], reduction: str = "mean"):
+        super().__init__()
+        raw_alpha = torch.tensor(alpha, dtype=torch.float32)
+        self.alpha = raw_alpha / raw_alpha.sum() * len(raw_alpha)
+        self.gamma = gamma
+        self.reduction = reduction
+
+    def forward(self, output: torch.Tensor, targets: list[torch.Tensor]) -> torch.Tensor:
+        target = targets[0].long()
+
+        logpt = F.log_softmax(output, dim=1)
+        pt = torch.exp(logpt)
+
+        logpt = logpt.gather(1, target.unsqueeze(1))
+        pt = pt.gather(1, target.unsqueeze(1))
+
+        at = self.alpha[target].unsqueeze(1)
+        loss = -at * ((1 - pt) ** self.gamma) * logpt
+
+        if self.reduction == "mean":
+            return loss.mean()
+        elif self.reduction == "sum":
+            return loss.sum()
+        return loss
 
+
 class FID(Criterion):
 
-    class InceptionV3(nn.Module):
+    class InceptionV3(torch.nn.Module):
 
         def __init__(self) -> None:
             super().__init__()
             self.model = inception_v3(weights=Inception_V3_Weights.DEFAULT, transform_input=False)
-            self.model.fc = nn.Identity()
+            self.model.fc = torch.nn.Identity()
             self.model.eval()
 
         def forward(self, x: torch.Tensor) -> torch.Tensor:
@@ -456,9 +480,9 @@ class FID(Criterion):
         generated_features = FID.get_features(generated_images, self.inception_model)
 
         return FID.calculate_fid(real_features, generated_features)
-"""
+
 
-"""class MutualInformationLoss(torch.nn.Module):
+class MutualInformationLoss(torch.nn.Module):
     def __init__(self, num_bins: int = 23, sigma_ratio: float = 0.5, smooth_nr: float = 1e-7, smooth_dr: float = 1e-7) -> None:
         super().__init__()
         bin_centers = torch.linspace(0.0, 1.0, num_bins)
@@ -469,13 +493,13 @@ class FID(Criterion):
         self.smooth_nr = float(smooth_nr)
         self.smooth_dr = float(smooth_dr)
 
-    def parzen_windowing(self, pred: torch.Tensor, target: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
+    def parzen_windowing(self, pred: torch.Tensor, target: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
         pred_weight, pred_probability = self.parzen_windowing_gaussian(pred)
         target_weight, target_probability = self.parzen_windowing_gaussian(target)
         return pred_weight, pred_probability, target_weight, target_probability
 
 
-    def parzen_windowing_gaussian(self, img: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]:
+    def parzen_windowing_gaussian(self, img: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         img = torch.clamp(img, 0, 1)
         img = img.reshape(img.shape[0], -1, 1) # (batch, num_sample, 1)
         weight = torch.exp(-self.preterm.to(img) * (img - self.bin_centers.to(img)) ** 2) # (batch, num_sample, num_bin)
@@ -488,4 +512,4 @@ class FID(Criterion):
         pab = torch.bmm(wa.permute(0, 2, 1), wb.to(wa)).div(wa.shape[1]) # (batch, num_bins, num_bins)
         papb = torch.bmm(pa.permute(0, 2, 1), pb.to(pa)) # (batch, num_bins, num_bins)
         mi = torch.sum(pab * torch.log((pab + self.smooth_nr) / (papb + self.smooth_dr) + self.smooth_dr), dim=(1, 2)) # (batch)
-        return torch.mean(mi).neg() # average over the batch and channel ndims"""
+        return torch.mean(mi).neg() # average over the batch and channel ndims
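
The new `FocalLoss` is the standard class-weighted focal loss. A self-contained sketch of the same computation outside konfai's `Criterion` hierarchy, mainly to show the expected shapes (the helper function is hypothetical):

import torch
import torch.nn.functional as F

def focal_loss(logits: torch.Tensor, target: torch.Tensor,
               alpha: torch.Tensor, gamma: float = 2.0) -> torch.Tensor:
    # logits: (B, C, ...); target: (B, ...) holding class indices in [0, C)
    logpt = F.log_softmax(logits, dim=1)
    pt = logpt.exp()
    logpt = logpt.gather(1, target.unsqueeze(1))
    pt = pt.gather(1, target.unsqueeze(1))
    at = alpha[target].unsqueeze(1)          # per-class weight, broadcast per voxel
    return (-at * (1 - pt) ** gamma * logpt).mean()

alpha = torch.tensor([0.5, 2.0, 0.5, 0.5, 1.0])
alpha = alpha / alpha.sum() * len(alpha)     # same normalisation as the diff
logits = torch.randn(2, 5, 8, 8)
target = torch.randint(0, 5, (2, 8, 8))
print(focal_loss(logits, target, alpha))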
konfai/models/segmentation/NestedUNet.py CHANGED

@@ -1,8 +1,6 @@
+from konfai.network import network, blocks
 from typing import Union
 import torch
-
-from konfai.network import network, blocks
-from konfai.utils.config import config
 from konfai.data.patching import ModelPatch
 
 
@@ -25,13 +23,15 @@ class NestedUNet(network.Network):
 
     class NestedUNetHead(network.ModuleArgsDict):
 
-        def __init__(self, in_channels: int, nb_class: int, dim: int) -> None:
+        def __init__(self, in_channels: int, nb_class: int, activation: str, dim: int) -> None:
             super().__init__()
             self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels = in_channels, out_channels = nb_class, kernel_size = 1, stride = 1, padding = 0))
-            self.add_module("Softmax", torch.nn.Softmax(dim=1))
-            self.add_module("Argmax", blocks.ArgMax(dim=1))
-
-    @config("NestedUNet")
+            if activation == "Softmax":
+                self.add_module("Softmax", torch.nn.Softmax(dim=1))
+                self.add_module("Argmax", blocks.ArgMax(dim=1))
+            elif activation == "Tanh":
+                self.add_module("Tanh", torch.nn.Tanh())
+
     def __init__( self,
                   optimizer : network.OptimizerLoader = network.OptimizerLoader(),
                   schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
@@ -45,9 +45,71 @@ class NestedUNet(network.Network):
                   downSampleMode: str = "MAXPOOL",
                   upSampleMode: str = "CONV_TRANSPOSE",
                   attention : bool = False,
-                  blockType: str = "Conv") -> None:
+                  blockType: str = "Conv",
+                  activation: str = "Softmax") -> None:
         super().__init__(in_channels = channels[0], optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, patch=patch, dim = dim)
 
         self.add_module("UNetBlock_0", NestedUNetBlock(channels, nb_conv_per_stage, blockConfig, downSampleMode=blocks.DownSampleMode._member_map_[downSampleMode], upSampleMode=blocks.UpSampleMode._member_map_[upSampleMode], attention=attention, block = blocks.ConvBlock if blockType == "Conv" else blocks.ResBlock, dim=dim), out_branch=["X_0_{}".format(j+1) for j in range(len(channels)-2)])
         for j in range(len(channels)-2):
-            self.add_module("Head_{}".format(j), NestedUNet.NestedUNetHead(in_channels=channels[1], nb_class=nb_class, dim=dim), in_branch=["X_0_{}".format(j+1)], out_branch=[-1])
+            self.add_module("Head_{}".format(j), NestedUNet.NestedUNetHead(in_channels=channels[1], nb_class=nb_class, activation=activation, dim=dim), in_branch=["X_0_{}".format(j+1)], out_branch=[-1])
+
+
+
+
+class ResNetEncoderLayer(network.ModuleArgsDict):
+
+    def __init__(self, in_channel: int, out_channel: int, nb_block: int, dim: int, downSampleMode : blocks.DownSampleMode):
+        super().__init__()
+        for i in range(nb_block):
+            if downSampleMode == blocks.DownSampleMode.MAXPOOL and i == 0:
+                self.add_module("DownSample", blocks.getTorchModule("MaxPool", dim)(kernel_size=3, stride=2, padding=1, dilation=1, ceil_mode=False))
+            self.add_module("ResBlock_{}".format(i), blocks.ResBlock(in_channel, out_channel, [blocks.BlockConfig(3, 2 if downSampleMode == blocks.DownSampleMode.CONV_STRIDE and i == 0 else 1, 1, False, "ReLU;True", blocks.NormMode.BATCH), blocks.BlockConfig(3, 1, 1, False, None, blocks.NormMode.BATCH)], dim=dim))
+            in_channel = out_channel
+
+def resNetEncoder(channels: list[int], layers: list[int], dim: int) -> list[torch.nn.Module]:
+    modules = []
+    modules.append(blocks.ConvBlock(channels[0], channels[1], [blocks.BlockConfig(7, 2, 3, False, "ReLU", blocks.NormMode.BATCH)], dim=dim))
+    for i, (in_channel, out_channel, layer) in enumerate(zip(channels[1:], channels[2:], layers)):
+        modules.append(ResNetEncoderLayer(in_channel, out_channel, layer, dim, blocks.DownSampleMode.MAXPOOL if i==0 else blocks.DownSampleMode.CONV_STRIDE))
+    return modules
+
+class NestedUNetBlock(network.ModuleArgsDict):
+
+    def __init__(self, encoder_channels: list[int], decoder_channels: list[int], encoders: list[torch.nn.Module], upSampleMode: blocks.UpSampleMode, dim: int, i : int = 0) -> None:
+        super().__init__()
+        self.add_module("X_{}_{}".format(i, 0), encoders[0], out_branch=["X_{}_{}".format(i, 0)])
+        if len(encoder_channels) > 2:
+            self.add_module("UNetBlock_{}".format(i+1), NestedUNetBlock(encoder_channels[1:], decoder_channels[1:], encoders[1:], upSampleMode, dim, i+1), in_branch=["X_{}_{}".format(i, 0)], out_branch=["X_{}_{}".format(i+1, j) for j in range(len(encoder_channels)-2)])
+            for j in range(len(encoder_channels)-2):
+                self.add_module("X_{}_{}_{}".format(i, j+1, upSampleMode.name), blocks.upSample(in_channels=encoder_channels[2], out_channels=decoder_channels[2] if j == len(encoder_channels)-3 else encoder_channels[1], upSampleMode=upSampleMode, dim=dim), in_branch=["X_{}_{}".format(i+1, j)], out_branch=["X_{}_{}".format(i+1, j)])
+                self.add_module("SkipConnection_{}_{}".format(i, j+1), blocks.Concat(), in_branch=["X_{}_{}".format(i+1, j)]+["X_{}_{}".format(i, r) for r in range(j+1)], out_branch=["X_{}_{}".format(i, j+1)])
+                self.add_module("X_{}_{}".format(i, j+1), blocks.ConvBlock(in_channels=encoder_channels[1]*j+((encoder_channels[1]+encoder_channels[2]) if upSampleMode != blocks.UpSampleMode.CONV_TRANSPOSE else (decoder_channels[1]+decoder_channels[2] if j == len(encoder_channels)-3 else encoder_channels[1]+encoder_channels[1])), out_channels=decoder_channels[2] if j == len(encoder_channels)-3 else encoder_channels[1], blockConfigs=[blocks.BlockConfig(3,1,1,False, "ReLU;True", blocks.NormMode.BATCH)]*2, dim=dim), in_branch=["X_{}_{}".format(i, j+1)], out_branch=["X_{}_{}".format(i, j+1)])
+
+class NestedUNetHead(network.ModuleArgsDict):
+
+    def __init__(self, in_channels: int, out_channels: int, nb_class: int, dim: int) -> None:
+        super().__init__()
+        self.add_module("Upsample", blocks.upSample(in_channels=in_channels, out_channels=out_channels, upSampleMode=blocks.UpSampleMode.UPSAMPLE_BILINEAR, dim=dim))
+        self.add_module("ConvBlock", blocks.ConvBlock(in_channels=in_channels, out_channels=out_channels, blockConfigs=[blocks.BlockConfig(3,1,1,False, "ReLU;True", blocks.NormMode.BATCH)]*2, dim=dim))
+        self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels = out_channels, out_channels = nb_class, kernel_size = 3, stride = 1, padding = 1))
+        if nb_class > 1:
+            self.add_module("Softmax", torch.nn.Softmax(dim=1))
+            self.add_module("Argmax", blocks.ArgMax(dim=1))
+        else:
+            self.add_module("Tanh", torch.nn.Tanh())
+
+class UNetpp(network.Network):
+
+
+    def __init__( self,
+                  optimizer : network.OptimizerLoader = network.OptimizerLoader(),
+                  schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
+                  outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
+                  patch : Union[ModelPatch, None] = None,
+                  encoder_channels: list[int] = [1,64,64,128,256,512],
+                  decoder_channels: list[int] = [256,128,64,32,16,1],
+                  layers: list[int] = [3,4,6,3],
+                  dim : int = 2) -> None:
+        super().__init__(in_channels = encoder_channels[0], optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, patch=patch, dim = dim)
+        self.add_module("Block_0", NestedUNetBlock(encoder_channels, decoder_channels[::-1], resNetEncoder(encoder_channels, layers, dim), blocks.UpSampleMode.UPSAMPLE_BILINEAR, dim=dim), out_branch=["X_0_{}".format(j+1) for j in range(len(encoder_channels)-2)])
+        self.add_module("Head", NestedUNetHead(in_channels=decoder_channels[-3], out_channels=decoder_channels[-2], nb_class=decoder_channels[-1], dim=dim), in_branch=["X_0_{}".format(len(encoder_channels)-2)], out_branch=[-1])
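
The new `UNetpp` builds its encoder from ResNet-style stages: the default `layers = [3, 4, 6, 3]` matches ResNet-34 stage depths, zipped against consecutive channel pairs after the stem convolution. A quick check of how those defaults pair up (values copied from the signature above):

encoder_channels = [1, 64, 64, 128, 256, 512]
layers = [3, 4, 6, 3]

# channels[0] -> channels[1] is the stem conv; the zip mirrors resNetEncoder.
for i, (in_ch, out_ch, n_blocks) in enumerate(zip(encoder_channels[1:], encoder_channels[2:], layers)):
    mode = "MAXPOOL" if i == 0 else "CONV_STRIDE"
    print(f"stage {i}: {in_ch} -> {out_ch}, {n_blocks} ResBlocks, downsample={mode}")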
konfai/models/segmentation/UNet.py CHANGED

@@ -38,7 +38,6 @@ class UNetBlock(network.ModuleArgsDict):
 
 class UNet(network.Network):
 
-    @config("UNet")
     def __init__( self,
                   optimizer : network.OptimizerLoader = network.OptimizerLoader(),
                   schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
konfai/network/blocks.py CHANGED
@@ -19,7 +19,7 @@ class NormMode(Enum):
     SYNCBATCH = 5
     INSTANCE_AFFINE = 6
 
-def getNorm(normMode: Enum, channels : int, dim: int) -> torch.nn.Module:
+def getNorm(normMode: Enum, channels : int, dim: int) -> Union[torch.nn.Module, None]:
     if normMode == NormMode.BATCH:
         return getTorchModule("BatchNorm", dim = dim)(channels, affine=True, track_running_stats=True)
     if normMode == NormMode.INSTANCE:
@@ -32,7 +32,7 @@ def getNorm(normMode: Enum, channels : int, dim: int) -> torch.nn.Module:
         return torch.nn.GroupNorm(num_groups=32, num_channels=channels)
     if normMode == NormMode.LAYER:
         return torch.nn.GroupNorm(num_groups=1, num_channels=channels)
-    return torch.nn.Identity()
+    return None
 
 class UpSampleMode(Enum):
     CONV_TRANSPOSE = 0,
@@ -59,9 +59,9 @@ class BlockConfig():
         self.stride = stride
         self.padding = padding
         self.activation = activation
-        self.normMode = normMode
-
-        if isinstance(normMode, str):
+        if normMode is None:
+            self.norm = None
+        elif isinstance(normMode, str):
             self.norm = NormMode._member_map_[normMode]
         elif isinstance(normMode, NormMode):
             self.norm = normMode
@@ -70,9 +70,13 @@ class BlockConfig():
         return getTorchModule("Conv", dim = dim)(in_channels = in_channels, out_channels = out_channels, kernel_size = self.kernel_size, stride = self.stride, padding = self.padding, bias=self.bias)
 
     def getNorm(self, channels : int, dim: int) -> torch.nn.Module:
+        if self.norm is None:
+            return None
         return getNorm(self.norm, channels, dim) if isinstance(self.norm, NormMode) else self.norm(channels)
 
     def getActivation(self) -> torch.nn.Module:
+        if self.activation is None:
+            return None
         if isinstance(self.activation, str):
             return getTorchModule(self.activation.split(";")[0])(*[ast.literal_eval(value) for value in self.activation.split(";")[1:]]) if self.activation != "None" else torch.nn.Identity()
         return self.activation()
@@ -83,21 +87,35 @@ class ConvBlock(network.ModuleArgsDict):
         super().__init__()
         for i, blockConfig in enumerate(blockConfigs):
             self.add_module("Conv_{}".format(i), blockConfig.getConv(in_channels, out_channels, dim), alias=alias[0])
-            self.add_module("Norm_{}".format(i), blockConfig.getNorm(out_channels, dim), alias=alias[1])
-            self.add_module("Activation_{}".format(i), blockConfig.getActivation(), alias=alias[2])
+            norm = blockConfig.getNorm(out_channels, dim)
+            if norm is not None:
+                self.add_module("Norm_{}".format(i), norm, alias=alias[1])
+            activation = blockConfig.getActivation()
+            if activation is not None:
+                self.add_module("Activation_{}".format(i), activation, alias=alias[2])
             in_channels = out_channels
 
 class ResBlock(network.ModuleArgsDict):
 
-    def __init__(self, in_channels : int, out_channels : int, nb_conv: int, blockConfig : BlockConfig, dim : int, alias : list[list[str]]=[[], [], [], []]) -> None:
+    def __init__(self, in_channels : int, out_channels : int, blockConfigs : list[BlockConfig], dim : int, alias : list[list[str]]=[[], [], [], [], []]) -> None:
         super().__init__()
-        for i in range(nb_conv):
-            self.add_module("ConvBlock_{}".format(i), ConvBlock(in_channels, out_channels, nb_conv=1, blockConfig=blockConfig, dim=dim, alias=alias[:3]))
+        for i, blockConfig in enumerate(blockConfigs):
+            self.add_module("Conv_{}".format(i), blockConfig.getConv(in_channels, out_channels, dim), alias=alias[0])
+            norm = blockConfig.getNorm(out_channels, dim)
+            if norm is not None:
+                self.add_module("Norm_{}".format(i), norm, alias=alias[1])
+            activation = blockConfig.getActivation()
+            if activation is not None:
+                self.add_module("Activation_{}".format(i), activation, alias=alias[2])
+
         if in_channels != out_channels:
-            self.add_module("Conv_skip_{}".format(i), blockConfig.getConv(in_channels, out_channels, dim), alias=alias[3], in_branch=[1], out_branch=[1])
+            self.add_module("Conv_skip", getTorchModule("Conv", dim)(in_channels, out_channels, 1, blockConfig.stride, bias=blockConfig.bias), alias=alias[3], in_branch=[1], out_branch=[1])
+            self.add_module("Norm_skip", blockConfig.getNorm(out_channels, dim), alias=alias[4], in_branch=[1], out_branch=[1])
             in_channels = out_channels
-            self.add_module("Add_{}".format(i), Add(), in_branch=[0,1], out_branch=[0,1])
-
+
+        self.add_module("Add", Add(), in_branch=[0,1])
+        self.add_module("Norm_{}".format(i+1), torch.nn.ReLU(inplace=True))
+
 def downSample(in_channels: int, out_channels: int, downSampleMode: DownSampleMode, dim: int) -> torch.nn.Module:
     if downSampleMode == DownSampleMode.MAXPOOL:
         return getTorchModule("MaxPool", dim = dim)(2)
@@ -176,6 +194,7 @@ class Print(torch.nn.Module):
         super().__init__()
 
     def forward(self, input: torch.Tensor) -> torch.Tensor:
+        print(input.shape)
         return input
 
 class Write(torch.nn.Module):
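
With `normMode=None` or a `None` activation, `BlockConfig` now produces no module at all instead of a `torch.nn.Identity()` placeholder, so no-op layers no longer appear in module names or checkpoints. A sketch of the skip-if-None pattern using plain torch (the helper function is invented):

import torch

def build_conv_block(in_ch: int, out_ch: int, norm=None, activation=None) -> torch.nn.Sequential:
    # None entries are simply skipped, mirroring the 1.1.5 ConvBlock behaviour.
    layers: list[torch.nn.Module] = [torch.nn.Conv2d(in_ch, out_ch, 3, padding=1)]
    if norm is not None:
        layers.append(norm(out_ch))
    if activation is not None:
        layers.append(activation())
    return torch.nn.Sequential(*layers)

print(build_conv_block(8, 16))                                       # conv only
print(build_conv_block(8, 16, torch.nn.BatchNorm2d, torch.nn.ReLU))  # conv + norm + act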
konfai/network/network.py CHANGED
@@ -133,9 +133,10 @@ class Measure():
         self._loss.clear()
 
     def add(self, weight: float, value: torch.Tensor) -> None:
-        self._loss.append(value if self.isLoss else value.detach())
-        self._values.append(value.item())
-        self._weight.append(weight)
+        if self.isLoss or value != 0:
+            self._loss.append(value if self.isLoss else value.detach())
+            self._values.append(value.item())
+            self._weight.append(weight)
 
     def getLastLoss(self) -> torch.Tensor:
         return self._loss[-1]*self._weight[-1] if len(self._loss) else torch.zeros((1), requires_grad = True)
@@ -157,7 +158,7 @@ class Measure():
         outputs_group_rename = {}
 
         modules = []
-        for i,_ in model.named_modules():
+        for i,_,_ in model.named_ModuleArgsDict():
             modules.append(i)
 
         for output_group in self.outputsCriterions.keys():
@@ -167,11 +168,12 @@ class Measure():
                     "Please check that the name matches exactly a submodule or output of your model architecture."
                 )
             for target_group in self.outputsCriterions[output_group]:
-                if target_group not in group_dest:
-                    raise MeasureError(
-                        f"The target_group '{target_group}' defined in 'outputsCriterions.{output_group}.targetsCriterions' was not found in the available destination groups.",
-                        "This target_group is expected for loss or metric computation, but was not loaded in 'group_dest'.",
-                        f"Please make sure that the group '{target_group}' is defined in 'Dataset:groups_src:...:groups_dest:'{target_group}'' and correctly loaded from the dataset.")
+                for target_group_tmp in target_group.split(";"):
+                    if target_group_tmp not in group_dest:
+                        raise MeasureError(
+                            f"The target_group '{target_group_tmp}' defined in 'outputsCriterions.{output_group}.targetsCriterions' was not found in the available destination groups.",
+                            "This target_group is expected for loss or metric computation, but was not loaded in 'group_dest'.",
+                            f"Please make sure that the group '{target_group_tmp}' is defined in 'Dataset:groups_src:...:groups_dest:'{target_group_tmp}'' and correctly loaded from the dataset.")
                 for criterion in self.outputsCriterions[output_group][target_group]:
                     if not self.outputsCriterions[output_group][target_group][criterion].isTorchCriterion:
                         outputs_group_rename[output_group] = criterion.init(model, output_group, target_group)
@@ -189,7 +191,7 @@ class Measure():
 
     def update(self, output_group: str, output : torch.Tensor, data_dict: dict[str, torch.Tensor], it: int, nb_patch: int, training: bool) -> None:
         for target_group in self.outputsCriterions[output_group]:
-            target = [data_dict[group].to(output[0].device).detach() for group in target_group.split("/") if group in data_dict]
+            target = [data_dict[group].to(output[0].device).detach() for group in target_group.split(";") if group in data_dict]
 
             for criterion, criterionsAttr in self.outputsCriterions[output_group][target_group].items():
                 if it >= criterionsAttr.stepStart and (criterionsAttr.stepStop is None or it <= criterionsAttr.stepStop):
@@ -928,6 +930,16 @@ class Network(ModuleArgsDict, ABC):
             if isinstance(module, ModuleArgsDict):
                 module._training = state
 
+class MinimalModel(Network):
+
+    def __init__(self, model: Network, optimizer : OptimizerLoader = OptimizerLoader(),
+                 schedulers : LRSchedulersLoader = LRSchedulersLoader(),
+                 outputsCriterions: dict[str, TargetCriterionsLoader] = {"default" : TargetCriterionsLoader()},
+                 patch : Union[ModelPatch, None] = None,
+                 dim : int = 3, nb_batch_per_step = 1, init_type = "normal", init_gain = 0.02):
+        super().__init__(1, optimizer, schedulers, outputsCriterions, patch, nb_batch_per_step, init_type, init_gain, dim)
+        self.add_module("Model", model)
+
 class ModelLoader():
 
     @config("Model")
@@ -937,11 +949,13 @@ class ModelLoader():
     def getModel(self, train : bool = True, DL_args: Union[str, None] = None, DL_without=["optimizer", "schedulers", "nb_batch_per_step", "init_type", "init_gain"]) -> Network:
         if not DL_args:
             DL_args="{}.Model".format(KONFAI_ROOT())
-        model = partial(getattr(importlib.import_module(self.module), self.name), config = None, DL_args=DL_args)
-        if not train:
-            model = partial(model, DL_without = DL_without)
-        return model()
-
+        DL_args += "."+self.name
+        model = config(DL_args)(getattr(importlib.import_module(self.module), self.name))(config = None, DL_without = DL_without if not train else [])
+        if not isinstance(model, Network):
+            model = config(DL_args)(partial(MinimalModel, model))(config = None, DL_without = DL_without+["model"] if not train else [])
+        model.setName(self.name)
+        return model
+
 class CPU_Model():
 
     def __init__(self, model: Network) -> None:
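
`MinimalModel` lets the rewritten `ModelLoader.getModel` adopt a class that is not already a `Network` by wrapping it as a single "Model" submodule. A stripped-down sketch of that wrapping idea with plain `torch.nn.Module` (the wrapper class here is hypothetical, not konfai's):

import torch

class MinimalWrapper(torch.nn.Module):
    # Adopt an arbitrary module under the name "Model" so a framework can
    # address it uniformly, as MinimalModel does for konfai Networks.
    def __init__(self, model: torch.nn.Module) -> None:
        super().__init__()
        self.add_module("Model", model)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.Model(x)

wrapped = MinimalWrapper(torch.nn.Linear(4, 2))
print(wrapped(torch.randn(1, 4)).shape)  # torch.Size([1, 2])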
konfai/predictor.py CHANGED
@@ -8,7 +8,7 @@ import os
 
 from konfai import MODELS_DIRECTORY, PREDICTIONS_DIRECTORY, CONFIG_FILE, MODEL, KONFAI_ROOT
 from konfai.utils.config import config
-from konfai.utils.utils import State, get_patch_slices_from_nb_patch_per_dim, NeedDevice, _getModule, DistributedObject, DataLog, description
+from konfai.utils.utils import State, get_patch_slices_from_nb_patch_per_dim, NeedDevice, _getModule, DistributedObject, DataLog, description, PredictorError
 from konfai.utils.dataset import Dataset, Attribute
 from konfai.data.data_manager import DataPrediction, DatasetIter
 from konfai.data.patching import Accumulator, PathCombine
@@ -46,7 +46,7 @@ class OutDataset(Dataset, NeedDevice, ABC):
         self.names: dict[int, str] = {}
         self.nb_data_augmentation = 0
 
-    def load(self, name_layer: str, datasets: list[Dataset]):
+    def load(self, name_layer: str, datasets: list[Dataset], groups: dict[str, str]):
         transforms_type = ["pre_transforms", "post_transforms", "final_transforms"]
         for name, _transform_type, transform_type in [(k, getattr(self, "_{}".format(k)), getattr(self, k)) for k in transforms_type]:
 
@@ -137,9 +137,22 @@ class OutSameAsGroupDataset(OutDataset):
         if self.inverse_transform:
             for transform in reversed(dataset.groups_src[self.group_src][self.group_dest].post_transforms):
                 layer = transform.inverse(self.names[index_dataset], layer, self.attributes[index_dataset][index_augmentation][index_patch])
-
         self.output_layer_accumulator[index_dataset][index_augmentation].addLayer(index_patch, layer)
 
+
+    def load(self, name_layer: str, datasets: list[Dataset], groups: dict[str, str]):
+        super().load(name_layer, datasets, groups)
+
+        if self.group_src not in groups.keys():
+            raise PredictorError(
+                f"Source group '{self.group_src}' not found. Available groups: {list(groups.keys())}."
+            )
+
+        if self.group_dest not in groups[self.group_src]:
+            raise PredictorError(
+                f"Destination group '{self.group_dest}' not found. Available groups: {groups[self.group_src]}."
+            )
+
     def _getOutput(self, index: int, index_augmentation: int, dataset: DatasetIter) -> torch.Tensor:
         layer = self.output_layer_accumulator[index][index_augmentation].assemble()
         name = self.names[index]
@@ -258,7 +271,6 @@ class _Predictor():
         desc = lambda : "Prediction : {}".format(description(self.modelComposite))
         self.dataloader_prediction.dataset.load("Prediction")
         with tqdm.tqdm(iterable = enumerate(self.dataloader_prediction), leave=True, desc = desc(), total=len(self.dataloader_prediction), ncols=0) as batch_iter:
-            dist.barrier()
             for it, data_dict in batch_iter:
                 with torch.amp.autocast('cuda', enabled=self.autocast):
                     input = self.getInput(data_dict)
@@ -272,7 +284,7 @@ class _Predictor():
 
                 batch_iter.set_description(desc())
                 self.it += 1
-
+
     def _predict_log(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]):
         measures = DistributedObject.getMeasure(self.world_size, self.global_rank, self.local_rank, {"" : self.modelComposite.module}, 1)
 
@@ -404,9 +416,20 @@ class Predictor(DistributedObject):
 
         shutil.copyfile(CONFIG_FILE(), self.predict_path+"Prediction.yml")
 
+
         self.model.init(self.autocast, State.PREDICTION, self.dataset.getGroupsDest())
         self.model.init_outputsGroup()
         self.model._compute_channels_trace(self.model, self.model.in_channels, None, self.gpu_checkpoints)
+
+        modules = []
+        for i,_,_ in self.model.named_ModuleArgsDict():
+            modules.append(i)
+        for output_group in self.outsDataset.keys():
+            if output_group not in modules:
+                raise PredictorError("The output group '{}' defined in 'outputsCriterions' does not correspond to any module in the model.".format(output_group),
+                                     "Available modules: {}".format(modules),
+                                     "Please check that the name matches exactly a submodule or output of your model architecture."
+                                     )
         self.modelComposite = ModelComposite(self.model, len(MODEL().split(":")), self.combine)
         self.modelComposite.load(self._load())
 
@@ -416,7 +439,7 @@ class Predictor(DistributedObject):
         self.size = (len(self.gpu_checkpoints)+1 if self.gpu_checkpoints else 1)
         self.dataloader = self.dataset.getData(world_size//self.size)
         for name, outDataset in self.outsDataset.items():
-            outDataset.load(name.replace(".", ":"), list(self.dataset.datasets.values()))
+            outDataset.load(name.replace(".", ":"), list(self.dataset.datasets.values()), {src : dest for src, inner in self.dataset.groups_src.items() for dest in inner})
 
 
     def run_process(self, world_size: int, global_rank: int, local_rank: int, dataloaders: list[DataLoader]):
konfai/trainer.py CHANGED
@@ -34,7 +34,7 @@ class EarlyStoppingBase:
 class EarlyStopping(EarlyStoppingBase):
 
     @config("EarlyStopping")
-    def __init__(self, monitor: Union[list[str], None] = [], patience=10, min_delta=0.0, mode="min"):
+    def __init__(self, monitor: Union[list[str], None] = [], patience: int=10, min_delta: float=0.0, mode: str="min"):
         super().__init__()
         self.monitor = [] if monitor is None else monitor
         self.patience = patience
@@ -120,7 +120,8 @@ class _Trainer():
 
     def run(self) -> None:
         self.dataloader_training.dataset.load("Train")
-        self.dataloader_validation.dataset.load("Validation")
+        if self.dataloader_validation is not None:
+            self.dataloader_validation.dataset.load("Validation")
         with tqdm.tqdm(iterable = range(self.epoch, self.epochs), leave=False, total=self.epochs, initial=self.epoch, desc="Progress") as epoch_tqdm:
             for self.epoch in epoch_tqdm:
                 self.train()
konfai/utils/config.py CHANGED
@@ -3,7 +3,7 @@ import ruamel.yaml
 import inspect
 import collections
 from copy import deepcopy
-from typing import Union, Literal, get_origin, get_args
+from typing import Union, Literal, get_origin, get_args, Any
 import torch
 from konfai import CONFIG_FILE
 from konfai.utils.utils import ConfigError
@@ -175,8 +175,11 @@ def config(key : Union[str, None] = None):
             os.environ['KONFAI_CONFIG_VARIABLE'] = "False"
             kwargs = {}
             for param in list(inspect.signature(function).parameters.values())[len(args):]:
-
+                if param.name in without:
+                    continue
+
                 annotation = param.annotation
+
                 # --- support Literal ---
                 if get_origin(annotation) is Literal:
                     allowed_values = get_args(annotation)
@@ -192,11 +195,10 @@ def config(key : Union[str, None] = None):
                     for i in annotation.__args__:
                         annotation = i
                         break
-                if param.name in without:
-                    continue
+
                 if not annotation == inspect._empty:
                     if annotation not in [int, str, bool, float, torch.Tensor]:
-                        if str(annotation).startswith("list") or str(annotation).startswith("tuple") or str(annotation).startswith("typing.Tuple") or str(annotation).startswith("typing.List"):
+                        if str(annotation).startswith("list") or str(annotation).startswith("tuple") or str(annotation).startswith("typing.Tuple") or str(annotation).startswith("typing.List") or str(annotation).startswith("typing.Sequence"):
                             elem_type = annotation.__args__[0]
                             values = config.getValue(param.name, param.default)
                             if getattr(elem_type, '__origin__', None) is Union:
@@ -221,7 +223,7 @@ def config(key : Union[str, None] = None):
                         elif str(annotation).startswith("dict"):
                             if annotation.__args__[0] == str:
                                 values = config.getValue(param.name, param.default)
-                                if values is not None and annotation.__args__[1] not in [int, str, bool, float]:
+                                if values is not None and annotation.__args__[1] not in [int, str, bool, float, Any]:
                                     try:
                                         kwargs[param.name] = {value : annotation.__args__[1](config = filename, DL_args = key_tmp+"."+param.name+"."+value) for value in values}
                                     except ValueError as e:
@@ -229,6 +231,7 @@ def config(key : Union[str, None] = None):
                                     except Exception as e:
                                         raise ConfigError("{} {}".format(values, e))
                                 else:
+
                                     kwargs[param.name] = values
                         else:
                             raise ConfigError("Config: The config only supports types : config(Object), int, str, bool, float, list[int], list[str], list[bool], list[float], dict[str, Object]")
@@ -236,7 +239,7 @@ def config(key : Union[str, None] = None):
                     try:
                         kwargs[param.name] = annotation(config = filename, DL_args = key_tmp)
                     except Exception as e:
-                        raise ConfigError("Failed to instantiate {} with type {}, error {} ".format(param.name, annotation.__name__, e))
+                        raise ConfigError("Failed to instantiate {} with type {}, error {} ".format(param.name, annotation, e))
 
                 if os.environ['KONFAI_CONFIG_VARIABLE'] == "True":
                     os.environ['KONFAI_CONFIG_VARIABLE'] = "False"
konfai/utils/utils.py CHANGED
@@ -21,7 +21,7 @@ import sys
 import re
 
 
-def description(model, modelEMA = None, showMemory: bool = True) -> str:
+def description(model, modelEMA = None, showMemory: bool = True, train: bool = True) -> str:
     values_desc = lambda weights, values: " ".join(["{}({:.2f}) : {:.6f}".format(name.split(":")[-1], weight, value) for (name, value), weight in zip(values.items(), weights.values())])
     model_desc = lambda model : "("+" ".join(["{}({:.6f}) : {}".format(name, network.optimizer.param_groups[0]['lr'] if network.optimizer is not None else 0, values_desc(network.measure.getLastWeights(), network.measure.getLastValues())) for name, network in model.module.getNetworks().items() if network.measure is not None])+")"
     result = "Loss {}".format(model_desc(model))
@@ -33,9 +33,9 @@ def description(model, modelEMA = None, showMemory: bool = True) -> str:
     return result
 
 def _getModule(classpath : str, type : str) -> tuple[str, str]:
-    if len(classpath.split("_")) > 1:
-        module = ".".join(classpath.split("_")[:-1])
-        name = classpath.split("_")[-1]
+    if len(classpath.split(":")) > 1:
+        module = ".".join(classpath.split(":")[:-1])
+        name = classpath.split(":")[-1]
     else:
         module = "konfai."+type+("." if len(classpath.split(".")) > 2 else "")+".".join(classpath.split(".")[:-1])
         name = classpath.split(".")[-1]
@@ -586,6 +586,11 @@ class EvaluatorError(KonfAIError):
     def __init__(self, *message) -> None:
         super().__init__("Evaluator", message)
 
+class PredictorError(KonfAIError):
+
+    def __init__(self, *message) -> None:
+        super().__init__("Predictor", message)
+
 class TransformError(KonfAIError):
 
     def __init__(self, *message) -> None:
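
`_getModule` now splits explicit class paths on `:` instead of `_`, so module paths that themselves contain underscores (for example `data_manager`) are no longer cut apart. A sketch of the parsing rule (the fallback prefix is simplified relative to the real function):

def split_classpath(classpath: str) -> tuple[str, str]:
    if len(classpath.split(":")) > 1:
        # explicit "package.module:ClassName" form, new in 1.1.5
        module = ".".join(classpath.split(":")[:-1])
        name = classpath.split(":")[-1]
    else:
        # fall back to konfai's own namespace (simplified here)
        module = "konfai.models." + ".".join(classpath.split(".")[:-1])
        name = classpath.split(".")[-1]
    return module, name

print(split_classpath("my_project.data_manager:MyNet"))  # ('my_project.data_manager', 'MyNet')
print(split_classpath("segmentation.UNet"))              # ('konfai.models.segmentation', 'UNet')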
konfai-1.1.3.dist-info/METADATA → konfai-1.1.5.dist-info/METADATA CHANGED

@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: konfai
-Version: 1.1.3
+Version: 1.1.5
 Summary: Modular and configurable Deep Learning framework with YAML and PyTorch
 Author-email: Valentin Boussot <boussot.v@gmail.com>
 License-Expression: Apache-2.0
konfai-1.1.3.dist-info/RECORD → konfai-1.1.5.dist-info/RECORD CHANGED

@@ -1,15 +1,15 @@
 konfai/__init__.py,sha256=YXG-wpSEXWs6Jt3BDI77V4r89gEUNX-6lxW9btj5VYI,851
-konfai/evaluator.py,sha256=rAhfdRemMjzC3VoaqyQKJR0SBekuLDiLT1nhblH8RQk,8293
+konfai/evaluator.py,sha256=G8tXI-g8yTxBZXmESpRW7B5UEI4SMXq47fsGksU0G0M,8394
 konfai/main.py,sha256=xhCIs3VKXxNCgyyN616K-VANEFCsZZRR-l8LIMHlmPw,2597
-konfai/predictor.py,sha256=SgHs_gylsyVtQ41DBRBRmcpHIMg2ghLbxELbQBxjjFY,22747
-konfai/trainer.py,sha256=GCi8oRgTlbZxwKfmJlS2S4VU1yj1pCl6QeAWNwdSFPs,20223
+konfai/predictor.py,sha256=bNkNBa_9pvbAKt8xrmAdfUnnPIerG0_NC1zOCCvSdGk,23973
+konfai/trainer.py,sha256=mxzBDHMQR7bvpoW8WHpfAzzqFnxIApkDx5raf7GwTxw,20295
 konfai/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-konfai/data/augmentation.py,sha256=kp4HyMSGflbuHbOsNiXBxszLFSuOZ4zRRQT0cHNIVUI,32620
-konfai/data/data_manager.py,sha256=oG73iHT5qV52UTWQPC3MKjsjaQHJUOAPF1WA-XeTmUE,29700
-konfai/data/patching.py,sha256=oq43r3JlS9hw-pnzpjT3epvXJxEKqiQEtFVZPjY9Jic,15749
+konfai/data/augmentation.py,sha256=ruvS7ykFveOZBe2Mv_TOU5a8AzUgH7gfTipwqYgRM7k,32293
+konfai/data/data_manager.py,sha256=8Te3R39K64ucy3_mWH7AbIurNrmFRwRUkUBY0OBsss4,30107
+konfai/data/patching.py,sha256=zs3T4yTV8_iCFrqO21bo6GhwTywoTxIL31IAi-jiJDQ,15478
 konfai/data/transform.py,sha256=ZZoxft0V8t-DaKjF3OsIFezz41B8_xwPqBvYCYwF0dk,26471
 konfai/metric/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-konfai/metric/measure.py,sha256=K7s15wNJm1_Iaav9y7Oe_UhSnbLUiybhLTVumHZ8ZaY,22001
+konfai/metric/measure.py,sha256=PwNbR9Jb07gNmTtvL93MhM57LzUWrrZ-78bhCWcGSP0,22881
 konfai/metric/schedulers.py,sha256=UoSr1TW_hrus3DhvOEbDefxCSUGz7lJS_8vbz0GEye8,1370
 konfai/models/classification/convNeXt.py,sha256=Phj1hO8TCItVhBXoFXQcIkA85DZxurGLyHIPWBfXJ0Y,9243
 konfai/models/classification/resnet.py,sha256=t8KJEgGudWBpGJM9YS1lH_5eX0-wGbK7ZW5uZW214eM,7958
@@ -20,20 +20,20 @@ konfai/models/generation/gan.py,sha256=-GoKxHm3W9NdD4U77UcJrG5TfOZ3NWFUZG663kt2X
 konfai/models/generation/vae.py,sha256=_3JYVT2ojZ0P98tYcD2ny7a-gWVUmnByLDhY7i-n_4g,4719
 konfai/models/registration/registration.py,sha256=18EiWt4RJIXLyFtqU-kHjV1sMnQRm9mxAA6_-2B1YqI,6313
 konfai/models/representation/representation.py,sha256=RwQYoxtdph440-t_ZLelykl0hkUAD1zdspQaLkgxb-0,2677
-konfai/models/segmentation/NestedUNet.py,sha256=6XGizAIc4bDL8vx4AHW8BBFjUvovRYcjdMBHsN4ViNo,4301
-konfai/models/segmentation/UNet.py,sha256=JhKyGuaFXY7thJcbQ2mWGuIBtNdUhh5FoacpWDZ0N1k,4065
+konfai/models/segmentation/NestedUNet.py,sha256=kF8uXmWdg4MOLTwzY4Zqq1FgWdwzT1I8mLAqi-XnuH8,9958
+konfai/models/segmentation/UNet.py,sha256=Uei51-_0Le24ZJs9tbuL6hQn66kigkfbaLSMlH38els,4045
 konfai/network/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-konfai/network/blocks.py,sha256=RaTI0Lrvq1V-GIFei-WTUB6wlg4LydZksAyJ8DMk40M,13502
-konfai/network/network.py,sha256=SnwQUKFTZZU8zuix5sm0vW8s0G3h6Fa8pHoIcwC984Q,46512
+konfai/network/blocks.py,sha256=P7mEuuE1B0HRSgXRCeLxZvzYGKLskHD33hskpdkjIPs,14370
+konfai/network/network.py,sha256=6DBxTehY6McTaEHl6QzVfIZf2pe69KFqktYkefe3sj8,47521
 konfai/utils/ITK.py,sha256=OxTieDNNYHGkn7zxJsAG-6ecRG1VYMvn1dlBbBe1DOs,13955
 konfai/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-konfai/utils/config.py,sha256=OsCUqgNaPX56nD_MpJ2WRpiQ7xxqHonfPYaS9VGtsl8,12413
+konfai/utils/config.py,sha256=f7o83ix5_oNbr2pki-Czqr-yHi-8n92ZL64nlo0XGwA,12514
 konfai/utils/dataset.py,sha256=6ZzevdhJ7e5zlXATAVwSh9O6acKXM7gYNxkMAa5DrmM,36351
 konfai/utils/registration.py,sha256=v1srEBOcgDnHrx0YtsK6bcj0yCMH7wNeaQ3wC7gEvOw,8898
-konfai/utils/utils.py,sha256=Laq8bGc5mGKFZlJIkHxa-BrC9uR2F7MTRuL4YFAIxQY,23439
-konfai-1.1.3.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-konfai-1.1.3.dist-info/METADATA,sha256=Fj6snnVsKrMZjmdZZkp4u1YoonXYd_VWQV_B3hxd7qQ,2515
-konfai-1.1.3.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-konfai-1.1.3.dist-info/entry_points.txt,sha256=fG82HRN5-g39ACSOCtij_I3N6EHxfYnMR0D7TI_8pW8,81
-konfai-1.1.3.dist-info/top_level.txt,sha256=xF470dkIlFoFqTZEOlRehKJr4WU_8OKGXrJqYm9vWKs,7
-konfai-1.1.3.dist-info/RECORD,,
+konfai/utils/utils.py,sha256=EVikQCUnZy7Va-53E4plX9gY9aFancOoHd2NkT4Uy1k,23585
+konfai-1.1.5.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+konfai-1.1.5.dist-info/METADATA,sha256=DmiQixYnw4ZAb6J1A9C_itL76GGLH3GNeQAa7XqjcGY,2515
+konfai-1.1.5.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+konfai-1.1.5.dist-info/entry_points.txt,sha256=fG82HRN5-g39ACSOCtij_I3N6EHxfYnMR0D7TI_8pW8,81
+konfai-1.1.5.dist-info/top_level.txt,sha256=xF470dkIlFoFqTZEOlRehKJr4WU_8OKGXrJqYm9vWKs,7
+konfai-1.1.5.dist-info/RECORD,,