konfai 1.1.2.tar.gz → 1.1.3.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of konfai might be problematic.

Files changed (46)
  1. {konfai-1.1.2 → konfai-1.1.3}/PKG-INFO +1 -1
  2. {konfai-1.1.2 → konfai-1.1.3}/konfai/data/augmentation.py +41 -36
  3. {konfai-1.1.2 → konfai-1.1.3}/konfai/data/data_manager.py +2 -2
  4. {konfai-1.1.2 → konfai-1.1.3}/konfai/data/patching.py +2 -1
  5. {konfai-1.1.2 → konfai-1.1.3}/konfai/data/transform.py +1 -1
  6. {konfai-1.1.2 → konfai-1.1.3}/konfai/main.py +2 -1
  7. {konfai-1.1.2 → konfai-1.1.3}/konfai/models/segmentation/UNet.py +9 -10
  8. {konfai-1.1.2 → konfai-1.1.3}/konfai/predictor.py +17 -0
  9. {konfai-1.1.2 → konfai-1.1.3}/konfai/trainer.py +14 -0
  10. {konfai-1.1.2 → konfai-1.1.3}/konfai.egg-info/PKG-INFO +1 -1
  11. {konfai-1.1.2 → konfai-1.1.3}/pyproject.toml +1 -1
  12. {konfai-1.1.2 → konfai-1.1.3}/LICENSE +0 -0
  13. {konfai-1.1.2 → konfai-1.1.3}/README.md +0 -0
  14. {konfai-1.1.2 → konfai-1.1.3}/konfai/__init__.py +0 -0
  15. {konfai-1.1.2 → konfai-1.1.3}/konfai/data/__init__.py +0 -0
  16. {konfai-1.1.2 → konfai-1.1.3}/konfai/evaluator.py +0 -0
  17. {konfai-1.1.2 → konfai-1.1.3}/konfai/metric/__init__.py +0 -0
  18. {konfai-1.1.2 → konfai-1.1.3}/konfai/metric/measure.py +0 -0
  19. {konfai-1.1.2 → konfai-1.1.3}/konfai/metric/schedulers.py +0 -0
  20. {konfai-1.1.2 → konfai-1.1.3}/konfai/models/classification/convNeXt.py +0 -0
  21. {konfai-1.1.2 → konfai-1.1.3}/konfai/models/classification/resnet.py +0 -0
  22. {konfai-1.1.2 → konfai-1.1.3}/konfai/models/generation/cStyleGan.py +0 -0
  23. {konfai-1.1.2 → konfai-1.1.3}/konfai/models/generation/ddpm.py +0 -0
  24. {konfai-1.1.2 → konfai-1.1.3}/konfai/models/generation/diffusionGan.py +0 -0
  25. {konfai-1.1.2 → konfai-1.1.3}/konfai/models/generation/gan.py +0 -0
  26. {konfai-1.1.2 → konfai-1.1.3}/konfai/models/generation/vae.py +0 -0
  27. {konfai-1.1.2 → konfai-1.1.3}/konfai/models/registration/registration.py +0 -0
  28. {konfai-1.1.2 → konfai-1.1.3}/konfai/models/representation/representation.py +0 -0
  29. {konfai-1.1.2 → konfai-1.1.3}/konfai/models/segmentation/NestedUNet.py +0 -0
  30. {konfai-1.1.2 → konfai-1.1.3}/konfai/network/__init__.py +0 -0
  31. {konfai-1.1.2 → konfai-1.1.3}/konfai/network/blocks.py +0 -0
  32. {konfai-1.1.2 → konfai-1.1.3}/konfai/network/network.py +0 -0
  33. {konfai-1.1.2 → konfai-1.1.3}/konfai/utils/ITK.py +0 -0
  34. {konfai-1.1.2 → konfai-1.1.3}/konfai/utils/__init__.py +0 -0
  35. {konfai-1.1.2 → konfai-1.1.3}/konfai/utils/config.py +0 -0
  36. {konfai-1.1.2 → konfai-1.1.3}/konfai/utils/dataset.py +0 -0
  37. {konfai-1.1.2 → konfai-1.1.3}/konfai/utils/registration.py +0 -0
  38. {konfai-1.1.2 → konfai-1.1.3}/konfai/utils/utils.py +0 -0
  39. {konfai-1.1.2 → konfai-1.1.3}/konfai.egg-info/SOURCES.txt +0 -0
  40. {konfai-1.1.2 → konfai-1.1.3}/konfai.egg-info/dependency_links.txt +0 -0
  41. {konfai-1.1.2 → konfai-1.1.3}/konfai.egg-info/entry_points.txt +0 -0
  42. {konfai-1.1.2 → konfai-1.1.3}/konfai.egg-info/requires.txt +0 -0
  43. {konfai-1.1.2 → konfai-1.1.3}/konfai.egg-info/top_level.txt +0 -0
  44. {konfai-1.1.2 → konfai-1.1.3}/setup.cfg +0 -0
  45. {konfai-1.1.2 → konfai-1.1.3}/tests/test_config.py +0 -0
  46. {konfai-1.1.2 → konfai-1.1.3}/tests/test_dataset.py +0 -0
{konfai-1.1.2 → konfai-1.1.3}/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: konfai
-Version: 1.1.2
+Version: 1.1.3
 Summary: Modular and configurable Deep Learning framework with YAML and PyTorch
 Author-email: Valentin Boussot <boussot.v@gmail.com>
 License-Expression: Apache-2.0
{konfai-1.1.2 → konfai-1.1.3}/konfai/data/augmentation.py
@@ -9,8 +9,7 @@ import os
 from konfai import KONFAI_ROOT
 from konfai.utils.config import config
 from konfai.utils.utils import _getModule, AugmentationError
-from konfai.utils.dataset import Attribute, data_to_image
-
+from konfai.utils.dataset import Attribute, data_to_image, Dataset
 
 def _translate2DMatrix(t: torch.Tensor) -> torch.Tensor:
     return torch.cat((torch.cat((torch.eye(2), torch.tensor([[t[0]], [t[1]]])), dim=1), torch.Tensor([[0,0,1]])), dim=0)
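
For reference, `_translate2DMatrix` assembles a homogeneous 2D transform: a 2x2 identity block, the translation column, and a [0, 0, 1] bottom row. A quick standalone check of what it evaluates to (the input values are illustrative):

    import torch

    # Same construction as _translate2DMatrix above, for t = (3, 5):
    t = torch.tensor([3.0, 5.0])
    M = torch.cat((torch.cat((torch.eye(2), torch.tensor([[t[0]], [t[1]]])), dim=1),
                   torch.Tensor([[0, 0, 1]])), dim=0)
    # M = [[1., 0., 3.],
    #      [0., 1., 5.],
    #      [0., 0., 1.]]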
@@ -56,23 +55,29 @@ class DataAugmentationsList():
         self.dataAugmentations : list[DataAugmentation] = []
         self.dataAugmentationsLoader = dataAugmentations
 
-    def load(self, key: str):
+    def load(self, key: str, datasets: list[Dataset]):
         for augmentation, prob in self.dataAugmentationsLoader.items():
             module, name = _getModule(augmentation, "data.augmentation")
-            dataAugmentation: DataAugmentation = getattr(importlib.import_module(module), name)(config = None, DL_args="{}.Dataset.augmentations.{}.dataAugmentations".format(KONFAI_ROOT(), key))
+            dataAugmentation: DataAugmentation = config("{}.Dataset.augmentations.{}.dataAugmentations.{}".format(KONFAI_ROOT(), key, augmentation))(getattr(importlib.import_module(module), name))(config = None)
             dataAugmentation.load(prob.prob)
+            dataAugmentation.setDatasets(datasets)
             self.dataAugmentations.append(dataAugmentation)
-
+
 class DataAugmentation(ABC):
 
-    def __init__(self) -> None:
+    def __init__(self, groups: Union[list[str], None] = None) -> None:
         self.who_index: dict[int, list[int]] = {}
         self.shape_index: dict[int, list[list[int]]] = {}
         self._prob: float = 0
+        self.groups = groups
+        self.datasets : list[Dataset] = []
 
     def load(self, prob: float):
         self._prob = prob
 
+    def setDatasets(self, datasets: list[Dataset]):
+        self.datasets = datasets
+
     def state_init(self, index: Union[None, int], shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
         if index is not None:
             if index not in self.who_index:
@@ -93,14 +98,14 @@ class DataAugmentation(ABC):
     def _state_init(self, index : int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
         pass
 
-    def __call__(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
+    def __call__(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
         if len(self.who_index[index]) > 0:
-            for i, result in enumerate(self._compute(index, [inputs[i] for i in self.who_index[index]], device)):
+            for i, result in enumerate(self._compute(name, index, [inputs[i] for i in self.who_index[index]], device)):
                 inputs[self.who_index[index][i]] = result if device is None else result.cpu()
         return inputs
 
     @abstractmethod
-    def _compute(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
+    def _compute(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
         pass
 
     def inverse(self, index: int, a: int, input : torch.Tensor) -> torch.Tensor:
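
Concrete augmentations must now accept the extra `name` argument in `_compute` (it carries the sample's name, as seen in the patching.py hunk further down, where `dataAugmentation(self.name, self.index, ...)` is called). A minimal sketch of a custom subclass under the new signature; the class itself is hypothetical, not part of konfai:

    from typing import Union
    import torch
    from konfai.data.augmentation import DataAugmentation
    from konfai.utils.dataset import Attribute

    class IdentityAugmentation(DataAugmentation):  # hypothetical example

        def _state_init(self, index: int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
            return shapes  # shapes are unchanged by this no-op augmentation

        def _compute(self, name: str, index: int, inputs: list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
            # `name` identifies the sample being processed; a real augmentation
            # could use it together with self.datasets (set via setDatasets).
            return inputs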
@@ -118,7 +123,7 @@ class EulerTransform(DataAugmentation):
         super().__init__()
         self.matrix: dict[int, list[torch.Tensor]] = {}
 
-    def _compute(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
+    def _compute(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
         results = []
         for input, matrix in zip(inputs, self.matrix[index]):
             results.append(F.grid_sample(input.unsqueeze(0).type(torch.float32), F.affine_grid(matrix[:, :-1,...], [1]+list(input.shape), align_corners=True).to(input.device), align_corners=True, mode="bilinear", padding_mode="reflection").type(input.dtype).squeeze(0))
@@ -126,7 +131,7 @@ class EulerTransform(DataAugmentation):
 
     def _inverse(self, index: int, a: int, input : torch.Tensor) -> torch.Tensor:
         return F.grid_sample(input.unsqueeze(0).type(torch.float32), F.affine_grid(self.matrix[index][a].inverse()[:, :-1,...], [1]+list(input.shape), align_corners=True).to(input.device), align_corners=True, mode="bilinear", padding_mode="reflection").type(input.dtype).squeeze(0)
-
+
 class Translate(EulerTransform):
 
     @config("Translate")
@@ -194,7 +199,7 @@ class Flip(DataAugmentation):
         self.flip[index] = [dims[mask].tolist() for mask in prob]
         return shapes
 
-    def _compute(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
+    def _compute(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
         results = []
         for input, flip in zip(inputs, self.flip[index]):
             results.append(torch.flip(input, dims=flip))
@@ -207,11 +212,11 @@ class Flip(DataAugmentation):
 class ColorTransform(DataAugmentation):
 
     @config("ColorTransform")
-    def __init__(self) -> None:
-        super().__init__()
+    def __init__(self, groups: Union[list[str], None] = None) -> None:
+        super().__init__(groups)
         self.matrix: dict[int, list[torch.Tensor]] = {}
 
-    def _compute(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
+    def _compute(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
         results = []
         for input, matrix in zip(inputs, self.matrix[index]):
             result = input.reshape([*input.shape[:1], int(np.prod(input.shape[1:]))])
@@ -232,8 +237,8 @@ class ColorTransform(DataAugmentation):
 class Brightness(ColorTransform):
 
     @config("Brightness")
-    def __init__(self, b_std: float) -> None:
-        super().__init__()
+    def __init__(self, b_std: float, groups: Union[list[str], None] = None) -> None:
+        super().__init__(groups)
         self.b_std = b_std
 
     def _state_init(self, index : int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
@@ -244,8 +249,8 @@ class Brightness(ColorTransform):
 class Contrast(ColorTransform):
 
     @config("Contrast")
-    def __init__(self, c_std: float) -> None:
-        super().__init__()
+    def __init__(self, c_std: float, groups: Union[list[str], None] = None) -> None:
+        super().__init__(groups)
         self.c_std = c_std
 
     def _state_init(self, index : int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
@@ -256,8 +261,8 @@ class Contrast(ColorTransform):
 class LumaFlip(ColorTransform):
 
     @config("LumaFlip")
-    def __init__(self) -> None:
-        super().__init__()
+    def __init__(self, groups: Union[list[str], None] = None) -> None:
+        super().__init__(groups)
         self.v = torch.tensor([1, 1, 1, 0])/torch.sqrt(torch.tensor(3))
 
     def _state_init(self, index : int, shapes: list[list[int]], caches_attribute: list[Attribute]) -> list[list[int]]:
@@ -268,8 +273,8 @@ class LumaFlip(ColorTransform):
 class HUE(ColorTransform):
 
     @config("HUE")
-    def __init__(self, hue_max: float) -> None:
-        super().__init__()
+    def __init__(self, hue_max: float, groups: Union[list[str], None] = None) -> None:
+        super().__init__(groups)
         self.hue_max = hue_max
         self.v = torch.tensor([1, 1, 1])/torch.sqrt(torch.tensor(3))
 
@@ -281,8 +286,8 @@ class HUE(ColorTransform):
 class Saturation(ColorTransform):
 
     @config("Saturation")
-    def __init__(self, s_std: float) -> None:
-        super().__init__()
+    def __init__(self, s_std: float, groups: Union[list[str], None] = None) -> None:
+        super().__init__(groups)
         self.s_std = s_std
         self.v = torch.tensor([1, 1, 1, 0])/torch.sqrt(torch.tensor(3))
 
@@ -368,8 +373,8 @@ class Saturation(ColorTransform):
 class Noise(DataAugmentation):
 
     @config("Noise")
-    def __init__(self, n_std: float, noise_step: int=1000, beta_start: float = 1e-4, beta_end: float = 0.02) -> None:
-        super().__init__()
+    def __init__(self, n_std: float, noise_step: int=1000, beta_start: float = 1e-4, beta_end: float = 0.02, groups: Union[list[str], None] = None) -> None:
+        super().__init__(groups)
         self.n_std = n_std
         self.noise_step = noise_step
 
@@ -410,7 +415,7 @@ class Noise(DataAugmentation):
         self.ts[index] = [torch.randint(0, int(self.max_T), (1,)) for _ in shapes]
         return shapes
 
-    def _compute(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
+    def _compute(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
         results = []
         for input, t in zip(inputs, self.ts[index]):
             alpha_hat_t = self.alpha_hat[t].to(input.device).reshape(*[1 for _ in range(len(input.shape))])
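
The `alpha_hat` lookup suggests DDPM-style forward noising. A standalone sketch of that standard step, assuming `alpha_hat` is the cumulative product of `1 - beta` over a linear schedule (matching the `beta_start`/`beta_end`/`noise_step` defaults above); this is not konfai's exact code:

    import torch

    # x_t = sqrt(alpha_hat_t) * x_0 + sqrt(1 - alpha_hat_t) * eps, with eps ~ N(0, I)
    beta = torch.linspace(1e-4, 0.02, 1000)        # linear schedule over noise_step values
    alpha_hat = torch.cumprod(1.0 - beta, dim=0)   # cumulative product over the schedule
    x0 = torch.randn(1, 64, 64)                    # illustrative clean sample
    t = torch.randint(0, 1000, (1,))
    eps = torch.randn_like(x0)
    xt = torch.sqrt(alpha_hat[t]) * x0 + torch.sqrt(1 - alpha_hat[t]) * eps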
@@ -423,8 +428,8 @@ class Noise(DataAugmentation):
 class CutOUT(DataAugmentation):
 
     @config("CutOUT")
-    def __init__(self, c_prob: float, cutout_size: int, value: float) -> None:
-        super().__init__()
+    def __init__(self, c_prob: float, cutout_size: int, value: float, groups: Union[list[str], None] = None) -> None:
+        super().__init__(groups)
         self.c_prob = c_prob
         self.cutout_size = cutout_size
         self.centers: dict[int, list[torch.Tensor]] = {}
@@ -434,7 +439,7 @@ class CutOUT(DataAugmentation):
         self.centers[index] = [torch.rand((3) if len(shape) == 3 else (2)) for shape in shapes]
         return shapes
 
-    def _compute(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
+    def _compute(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
         results = []
         for input, center in zip(inputs, self.centers[index]):
             masks = []
@@ -513,7 +518,7 @@ class Elastix(DataAugmentation):
             print("Compute in progress : {:.2f} %".format((i+1)/len(shapes)*100))
         return shapes
 
-    def _compute(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
+    def _compute(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
         results = []
         for input, displacement_field in zip(inputs, self.displacement_fields[index]):
             results.append(F.grid_sample(input.type(torch.float32).unsqueeze(0), displacement_field.to(input.device), align_corners=True, mode="bilinear", padding_mode="border").type(input.dtype).squeeze(0))
@@ -546,7 +551,7 @@ class Permute(DataAugmentation):
             shapes[i] = [shapes[i][dim-1] for dim in permute[1:]]
         return shapes
 
-    def _compute(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
+    def _compute(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
         results = []
         for input, prob in zip(inputs, self.permute[index]):
             res = input
@@ -563,8 +568,8 @@ class Permute(DataAugmentation):
 class Mask(DataAugmentation):
 
     @config("Mask")
-    def __init__(self, mask: str, value: float) -> None:
-        super().__init__()
+    def __init__(self, mask: str, value: float, groups: Union[list[str], None] = None) -> None:
+        super().__init__(groups)
         if mask is not None:
             if os.path.exists(mask):
                 self.mask = torch.tensor(sitk.GetArrayFromImage(sitk.ReadImage(mask)))
@@ -577,7 +582,7 @@ class Mask(DataAugmentation):
         self.positions[index] = [torch.rand((3) if len(shape) == 3 else (2))*(torch.tensor([max(s1-s2, 0) for s1, s2 in zip(torch.tensor(shape), torch.tensor(self.mask.shape))])) for shape in shapes]
         return [self.mask.shape for _ in shapes]
 
-    def _compute(self, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
+    def _compute(self, name: str, index: int, inputs : list[torch.Tensor], device: Union[torch.device, None]) -> list[torch.Tensor]:
         results = []
         for input, position in zip(inputs, self.positions[index]):
             slices = [slice(None, None)]+[slice(int(s1), int(s1)+s2) for s1, s2 in zip(position, self.mask.shape)]
{konfai-1.1.2 → konfai-1.1.3}/konfai/data/data_manager.py
@@ -184,7 +184,7 @@ class DatasetIter(data.Dataset):
         data = {}
         x, a, p = self.map[index]
         if x not in self._index_cache:
-            if len(self._index_cache) >= self.buffer_size and not self.use_cache:
+            if x not in self._index_cache and len(self._index_cache) >= self.buffer_size and not self.use_cache:
                 self._unloadData(self._index_cache[0])
             self._loadData(x)
 
@@ -390,7 +390,7 @@ class Data(ABC):
         )
 
         for key, dataAugmentations in self.dataAugmentationsList.items():
-            dataAugmentations.load(key)
+            dataAugmentations.load(key, [self.datasets[filename] for filename, _ in datasets[group_src]])
 
         names = set()
         dataset_name : dict[str, dict[str, list[str]]] = {}
{konfai-1.1.2 → konfai-1.1.3}/konfai/data/patching.py
@@ -326,7 +326,8 @@ class DatasetManager():
         for dataAugmentations in dataAugmentationsList:
             a_data = [self.data[0].clone() for _ in range(dataAugmentations.nb)]
             for dataAugmentation in dataAugmentations.dataAugmentations:
-                a_data = dataAugmentation(self.index, a_data, device)
+                if dataAugmentation.groups is None or self.group_dest in dataAugmentation.groups:
+                    a_data = dataAugmentation(self.name, self.index, a_data, device)
 
             for d in a_data:
                 self.data.append(d)
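
This hunk is where the `groups` parameter threaded through the augmentation constructors takes effect: an augmentation is skipped unless `groups` is unset or contains the sample's destination group. A sketch of the intended use; the group names are illustrative, and direct construction is only for demonstration (konfai normally instantiates these from the YAML config via @config):

    from konfai.data.augmentation import Brightness, Noise

    brightness = Brightness(b_std=0.1, groups=["CT"])  # runs only when group_dest == "CT" (name assumed)
    noise = Noise(n_std=0.05)                          # groups=None -> applied to every group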
{konfai-1.1.2 → konfai-1.1.3}/konfai/data/transform.py
@@ -52,7 +52,7 @@ class Clip(Transform):
         input[torch.where(input < self.min_value)] = self.min_value
         input[torch.where(input > self.max_value)] = self.max_value
         if self.saveClip_min:
-            cache_attribute["Min"] = self .min_value
+            cache_attribute["Min"] = self.min_value
         if self.saveClip_max:
             cache_attribute["Max"] = self.max_value
         return input
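
For reference, the two masked assignments in `Clip` are equivalent to a single `torch.clamp`; a standalone illustration, not konfai code:

    import torch

    x = torch.tensor([-3.0, 0.5, 7.0])
    y = torch.clamp(x, min=-1.0, max=1.0)  # tensor([-1.0000, 0.5000, 1.0000])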
{konfai-1.1.2 → konfai-1.1.3}/konfai/main.py
@@ -9,6 +9,8 @@ import sys
 sys.path.insert(0, os.getcwd())
 
 def main():
+    import tracemalloc
+
     parser = argparse.ArgumentParser(description="KonfAI", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
     try:
         with setup(parser) as distributedObject:
@@ -22,7 +24,6 @@ def main():
     except KeyboardInterrupt:
         print("\n[KonfAI] Manual interruption (Ctrl+C)")
 
-
 def cluster():
     parser = argparse.ArgumentParser(description="KonfAI", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
 
{konfai-1.1.2 → konfai-1.1.3}/konfai/models/segmentation/UNet.py
@@ -7,32 +7,33 @@ from konfai.data.patching import ModelPatch
 
 class UNetHead(network.ModuleArgsDict):
 
-    def __init__(self, in_channels: int, nb_class: int, dim: int) -> None:
+    def __init__(self, in_channels: int, nb_class: int, dim: int, level: int) -> None:
         super().__init__()
         self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels = in_channels, out_channels = nb_class, kernel_size = 1, stride = 1, padding = 0))
         self.add_module("Softmax", torch.nn.Softmax(dim=1))
         self.add_module("Argmax", blocks.ArgMax(dim=1))
+
 
 class UNetBlock(network.ModuleArgsDict):
 
-    def __init__(self, channels: list[int], nb_conv_per_stage: int, blockConfig: blocks.BlockConfig, downSampleMode: blocks.DownSampleMode, upSampleMode: blocks.UpSampleMode, attention : bool, block: type, nb_class: int, dim: int, i : int = 0, mri: bool = False) -> None:
+    def __init__(self, channels: list[int], nb_conv_per_stage: int, blockConfig: blocks.BlockConfig, downSampleMode: blocks.DownSampleMode, upSampleMode: blocks.UpSampleMode, attention : bool, block: type, nb_class: int, dim: int, i : int = 0) -> None:
         super().__init__()
         blockConfig_stride = blockConfig
         if i > 0:
             if downSampleMode != blocks.DownSampleMode.CONV_STRIDE:
                 self.add_module(downSampleMode.name, blocks.downSample(in_channels=channels[0], out_channels=channels[1], downSampleMode=downSampleMode, dim=dim))
             else:
-                blockConfig_stride = blocks.BlockConfig(blockConfig.kernel_size, (1,2,2) if mri and i > 4 else 2, blockConfig.padding, blockConfig.bias, blockConfig.activation, blockConfig.normMode)
+                blockConfig_stride = blocks.BlockConfig(blockConfig.kernel_size, 2, blockConfig.padding, blockConfig.bias, blockConfig.activation, blockConfig.normMode)
         self.add_module("DownConvBlock", block(in_channels=channels[0], out_channels=channels[1], blockConfigs=[blockConfig_stride]+[blockConfig]*(nb_conv_per_stage-1), dim=dim))
         if len(channels) > 2:
-            self.add_module("UNetBlock_{}".format(i+1), UNetBlock(channels[1:], nb_conv_per_stage, blockConfig, downSampleMode, upSampleMode, attention, block, nb_class, dim, i+1, mri=mri))
+            self.add_module("UNetBlock_{}".format(i+1), UNetBlock(channels[1:], nb_conv_per_stage, blockConfig, downSampleMode, upSampleMode, attention, block, nb_class, dim, i+1))
             self.add_module("UpConvBlock", block(in_channels=(channels[1]+channels[2]) if upSampleMode != blocks.UpSampleMode.CONV_TRANSPOSE else channels[1]*2, out_channels=channels[1], blockConfigs=[blockConfig]*nb_conv_per_stage, dim=dim))
         if nb_class > 0:
-            self.add_module("Head", UNetHead(channels[1], nb_class, dim), out_branch=[-1])
+            self.add_module("Head", UNetHead(channels[1], nb_class, dim, i), out_branch=[-1])
         if i > 0:
             if attention:
                 self.add_module("Attention", blocks.Attention(F_g=channels[1], F_l=channels[0], F_int=channels[0], dim=dim), in_branch=[1, 0], out_branch=[1])
-            self.add_module(upSampleMode.name, blocks.upSample(in_channels=channels[1], out_channels=channels[0], upSampleMode=upSampleMode, dim=dim, kernel_size=(1,2,2) if mri and i > 4 else 2, stride=(1,2,2) if mri and i > 4 else 2))
+            self.add_module(upSampleMode.name, blocks.upSample(in_channels=channels[1], out_channels=channels[0], upSampleMode=upSampleMode, dim=dim, kernel_size=2, stride=2))
             self.add_module("SkipConnection", blocks.Concat(), in_branch=[0, 1])
 
 class UNet(network.Network):
@@ -51,8 +52,6 @@ class UNet(network.Network):
                  downSampleMode: str = "MAXPOOL",
                  upSampleMode: str = "CONV_TRANSPOSE",
                  attention : bool = False,
-                 blockType: str = "Conv",
-                 mri: bool = False) -> None:
+                 blockType: str = "Conv") -> None:
         super().__init__(in_channels = channels[0], optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, patch=patch, dim = dim)
-        self.add_module("UNetBlock_0", UNetBlock(channels, nb_conv_per_stage, blockConfig, downSampleMode=blocks.DownSampleMode._member_map_[downSampleMode], upSampleMode=blocks.UpSampleMode._member_map_[upSampleMode], attention=attention, block = blocks.ConvBlock if blockType == "Conv" else blocks.ResBlock, nb_class=nb_class, dim=dim, mri = mri))
-
+        self.add_module("UNetBlock_0", UNetBlock(channels, nb_conv_per_stage, blockConfig, downSampleMode=blocks.DownSampleMode._member_map_[downSampleMode], upSampleMode=blocks.UpSampleMode._member_map_[upSampleMode], attention=attention, block = blocks.ConvBlock if blockType == "Conv" else blocks.ResBlock, nb_class=nb_class, dim=dim))
{konfai-1.1.2 → konfai-1.1.3}/konfai/predictor.py
@@ -91,6 +91,23 @@ class OutDataset(Dataset, NeedDevice, ABC):
         super().write(self.group, name, layer.numpy(), self.attributes[index][0][0])
         self.attributes.pop(index)
 
+class Reduction():
+
+    def __init__(self):
+        pass
+
+class ReductionMean():
+
+    def __init__(self):
+        pass
+
+class ReductionMedian():
+
+    def __init__(self):
+        pass
+
+
+
 class OutSameAsGroupDataset(OutDataset):
 
     @config("OutDataset")
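
The new `Reduction`, `ReductionMean`, and `ReductionMedian` classes are empty stubs in this release. If they are intended to aggregate multiple predictions for the same output (e.g. from overlapping patches or repeated augmented passes), the usual mean/median reductions look like this; a sketch under that assumption, not konfai's implementation:

    import torch

    preds = torch.stack([torch.rand(2, 8, 8) for _ in range(5)])  # five predictions, shape (5, 2, 8, 8)
    mean_reduction = preds.mean(dim=0)             # element-wise average
    median_reduction = preds.median(dim=0).values  # element-wise median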
{konfai-1.1.2 → konfai-1.1.3}/konfai/trainer.py
@@ -109,6 +109,8 @@ class _Trainer():
         for data in data_log:
             self.data_log[data.split("/")[0].replace(":", ".")] = (DataLog.__getitem__(data.split("/")[1]).value[0], int(data.split("/")[2]))
 
+
+
     def __enter__(self):
         return self
 
@@ -314,6 +316,18 @@ class Trainer(DistributedObject):
         self.ema_decay = ema_decay
         self.modelEMA : Union[torch.optim.swa_utils.AveragedModel, None] = None
         self.data_log = data_log
+
+        modules = []
+        for i,_ in self.model.named_modules():
+            modules.append(i)
+        for k in self.data_log:
+            tmp = k.split("/")[0].replace(":", ".")
+            if tmp not in self.dataset.getGroupsDest() and tmp not in modules:
+                raise TrainerError( f"Invalid key '{tmp}' in `data_log`.",
+                                    f"This key is neither a destination group from the dataset ({self.dataset.getGroupsDest()})",
+                                    f"nor a valid module name in the model ({modules}).",
+                                    "Please check your `data_log` configuration — it should reference either a model output or a dataset group.")
+
         self.gradient_checkpoints = gradient_checkpoints
         self.gpu_checkpoints = gpu_checkpoints
         self.save_checkpoint_mode = save_checkpoint_mode
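
Given the parsing in `_Trainer` above (`split("/")` into a key, a `DataLog` member name, and an int) and this new check, a `data_log` entry follows a `<group-or-module>/<DataLog name>/<n>` pattern, where `:` in the first field stands in for `.` in module paths. A hypothetical example; the concrete names and the `IMAGE` member are assumptions, not taken from the diff:

    # Valid first fields are a dataset destination group or a model module name:
    data_log = [
        "image/IMAGE/4",             # dataset group (name assumed)
        "UNetBlock_0:Head/IMAGE/4",  # model module; ':' is mapped to '.' internally
    ]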
{konfai-1.1.2 → konfai-1.1.3}/konfai.egg-info/PKG-INFO
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: konfai
-Version: 1.1.2
+Version: 1.1.3
 Summary: Modular and configurable Deep Learning framework with YAML and PyTorch
 Author-email: Valentin Boussot <boussot.v@gmail.com>
 License-Expression: Apache-2.0
{konfai-1.1.2 → konfai-1.1.3}/pyproject.toml
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
 
 [project]
 name = "konfai"
-version = "1.1.2"
+version = "1.1.3"
 description = "Modular and configurable Deep Learning framework with YAML and PyTorch"
 readme = "README.md"
 requires-python = ">=3.8"