konfai-1.2.4.tar.gz → konfai-1.2.6.tar.gz

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their respective public registries.

Potentially problematic release.


This version of konfai might be problematic; the original page linked to more details, but that link was not preserved in this copy.

Files changed (45)
  1. {konfai-1.2.4 → konfai-1.2.6}/PKG-INFO +2 -1
  2. {konfai-1.2.4 → konfai-1.2.6}/konfai/data/augmentation.py +1 -1
  3. {konfai-1.2.4 → konfai-1.2.6}/konfai/data/data_manager.py +1 -1
  4. {konfai-1.2.4 → konfai-1.2.6}/konfai/data/transform.py +134 -75
  5. {konfai-1.2.4 → konfai-1.2.6}/konfai/network/network.py +1 -2
  6. {konfai-1.2.4 → konfai-1.2.6}/konfai/predictor.py +12 -8
  7. {konfai-1.2.4 → konfai-1.2.6}/konfai.egg-info/PKG-INFO +2 -1
  8. {konfai-1.2.4 → konfai-1.2.6}/konfai.egg-info/requires.txt +1 -0
  9. {konfai-1.2.4 → konfai-1.2.6}/pyproject.toml +3 -2
  10. {konfai-1.2.4 → konfai-1.2.6}/LICENSE +0 -0
  11. {konfai-1.2.4 → konfai-1.2.6}/README.md +0 -0
  12. {konfai-1.2.4 → konfai-1.2.6}/konfai/__init__.py +0 -0
  13. {konfai-1.2.4 → konfai-1.2.6}/konfai/data/__init__.py +0 -0
  14. {konfai-1.2.4 → konfai-1.2.6}/konfai/data/patching.py +0 -0
  15. {konfai-1.2.4 → konfai-1.2.6}/konfai/evaluator.py +0 -0
  16. {konfai-1.2.4 → konfai-1.2.6}/konfai/main.py +0 -0
  17. {konfai-1.2.4 → konfai-1.2.6}/konfai/metric/__init__.py +0 -0
  18. {konfai-1.2.4 → konfai-1.2.6}/konfai/metric/measure.py +0 -0
  19. {konfai-1.2.4 → konfai-1.2.6}/konfai/metric/schedulers.py +0 -0
  20. {konfai-1.2.4 → konfai-1.2.6}/konfai/models/classification/convNeXt.py +0 -0
  21. {konfai-1.2.4 → konfai-1.2.6}/konfai/models/classification/resnet.py +0 -0
  22. {konfai-1.2.4 → konfai-1.2.6}/konfai/models/generation/cStyleGan.py +0 -0
  23. {konfai-1.2.4 → konfai-1.2.6}/konfai/models/generation/ddpm.py +0 -0
  24. {konfai-1.2.4 → konfai-1.2.6}/konfai/models/generation/diffusionGan.py +0 -0
  25. {konfai-1.2.4 → konfai-1.2.6}/konfai/models/generation/gan.py +0 -0
  26. {konfai-1.2.4 → konfai-1.2.6}/konfai/models/generation/vae.py +0 -0
  27. {konfai-1.2.4 → konfai-1.2.6}/konfai/models/registration/registration.py +0 -0
  28. {konfai-1.2.4 → konfai-1.2.6}/konfai/models/representation/representation.py +0 -0
  29. {konfai-1.2.4 → konfai-1.2.6}/konfai/models/segmentation/NestedUNet.py +0 -0
  30. {konfai-1.2.4 → konfai-1.2.6}/konfai/models/segmentation/UNet.py +0 -0
  31. {konfai-1.2.4 → konfai-1.2.6}/konfai/network/__init__.py +0 -0
  32. {konfai-1.2.4 → konfai-1.2.6}/konfai/network/blocks.py +0 -0
  33. {konfai-1.2.4 → konfai-1.2.6}/konfai/trainer.py +0 -0
  34. {konfai-1.2.4 → konfai-1.2.6}/konfai/utils/ITK.py +0 -0
  35. {konfai-1.2.4 → konfai-1.2.6}/konfai/utils/__init__.py +0 -0
  36. {konfai-1.2.4 → konfai-1.2.6}/konfai/utils/config.py +0 -0
  37. {konfai-1.2.4 → konfai-1.2.6}/konfai/utils/dataset.py +0 -0
  38. {konfai-1.2.4 → konfai-1.2.6}/konfai/utils/utils.py +0 -0
  39. {konfai-1.2.4 → konfai-1.2.6}/konfai.egg-info/SOURCES.txt +0 -0
  40. {konfai-1.2.4 → konfai-1.2.6}/konfai.egg-info/dependency_links.txt +0 -0
  41. {konfai-1.2.4 → konfai-1.2.6}/konfai.egg-info/entry_points.txt +0 -0
  42. {konfai-1.2.4 → konfai-1.2.6}/konfai.egg-info/top_level.txt +0 -0
  43. {konfai-1.2.4 → konfai-1.2.6}/setup.cfg +0 -0
  44. {konfai-1.2.4 → konfai-1.2.6}/tests/test_config.py +0 -0
  45. {konfai-1.2.4 → konfai-1.2.6}/tests/test_dataset.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: konfai
3
- Version: 1.2.4
3
+ Version: 1.2.6
4
4
  Summary: Modular and configurable Deep Learning framework with YAML and PyTorch
5
5
  Author-email: Valentin Boussot <boussot.v@gmail.com>
6
6
  License-Expression: Apache-2.0
@@ -21,6 +21,7 @@ Requires-Dist: SimpleITK
21
21
  Requires-Dist: lxml
22
22
  Requires-Dist: h5py
23
23
  Requires-Dist: pynvml
24
+ Requires-Dist: requests
24
25
  Provides-Extra: vtk
25
26
  Requires-Dist: vtk; extra == "vtk"
26
27
  Provides-Extra: lpips
@@ -218,7 +218,7 @@ class DataAugmentation(NeedDevice, ABC):
218
218
  @abstractmethod
219
219
  def _inverse(self, index: int, a: int, tensor: torch.Tensor) -> torch.Tensor:
220
220
  pass
221
-
221
+
222
222
 
223
223
  class EulerTransform(DataAugmentation):
224
224
 
@@ -515,7 +515,7 @@ class Data(ABC):
515
515
  self.groups_src[group_src][group_dest].load(
516
516
  group_src,
517
517
  group_dest,
518
- [self.datasets[filename] for filename, _ in datasets[group_src]],
518
+ list(self.datasets.values()),
519
519
  )
520
520
  model_have_input |= self.groups_src[group_src][group_dest].is_input
521
521
  if self.patch is not None:
@@ -28,6 +28,12 @@ class Transform(NeedDevice, ABC):
28
28
  def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
29
29
  pass
30
30
 
31
+ class TransformInverse(Transform, ABC):
32
+
33
+ def __init__(self, inverse: bool) -> None:
34
+ super().__init__()
35
+ self.apply_inverse = inverse
36
+
31
37
  @abstractmethod
32
38
  def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
33
39
  pass
@@ -48,12 +54,14 @@ class Clip(Transform):
48
54
 
49
55
  def __init__(
50
56
  self,
51
- min_value: float = -1024,
52
- max_value: float = 1024,
57
+ min_value: float | str = -1024,
58
+ max_value: float | str = 1024,
53
59
  save_clip_min: bool = False,
54
60
  save_clip_max: bool = False,
61
+ mask: str | None = None,
55
62
  ) -> None:
56
- if max_value <= min_value:
63
+ super().__init__()
64
+ if isinstance(min_value, float) and isinstance(max_value, float) and max_value <= min_value:
57
65
  raise ValueError(
58
66
  f"[Clip] Invalid clipping range: max_value ({max_value}) must be greater than min_value ({min_value})"
59
67
  )
@@ -61,21 +69,69 @@ class Clip(Transform):
61
69
  self.max_value = max_value
62
70
  self.save_clip_min = save_clip_min
63
71
  self.save_clip_max = save_clip_max
72
+ self.mask = mask
64
73
 
65
74
  def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
66
- tensor[torch.where(tensor < self.min_value)] = self.min_value
67
- tensor[torch.where(tensor > self.max_value)] = self.max_value
75
+ mask = None
76
+ if self.mask is not None:
77
+ for dataset in self.datasets:
78
+ if dataset.is_dataset_exist(self.mask, name):
79
+ mask, _ = dataset.read_data(self.mask, name)
80
+ break
81
+ if mask is None and self.mask is not None:
82
+ raise ValueError(
83
+ f"Requested mask '{self.mask}' is not present in any dataset. "
84
+ "Check your dataset group names or configuration."
85
+ )
86
+ if mask is None:
87
+ tensor_masked = tensor
88
+ else:
89
+ tensor_masked = tensor[mask == 1]
90
+
91
+ if isinstance(self.min_value, str):
92
+ if self.min_value == "min":
93
+ min_value = torch.min(tensor_masked)
94
+ elif self.min_value.startswith("percentile:"):
95
+ try:
96
+ percentile = float(self.min_value.split(":")[1])
97
+ min_value = np.percentile(tensor_masked, percentile)
98
+ except (IndexError, ValueError):
99
+ raise ValueError(f"Invalid format for min_value: '{self.min_value}'. Expected 'percentile:<float>'")
100
+ else:
101
+ raise TypeError(
102
+ f"Unsupported string for min_value: '{self.min_value}'."
103
+ "Must be a float, 'min', or 'percentile:<float>'."
104
+ )
105
+ else:
106
+ min_value = self.min_value
107
+
108
+ if isinstance(self.max_value, str):
109
+ if self.max_value == "max":
110
+ max_value = torch.max(tensor_masked)
111
+ elif self.max_value.startswith("percentile:"):
112
+ try:
113
+ percentile = float(self.max_value.split(":")[1])
114
+ max_value = np.percentile(tensor_masked, percentile)
115
+ except (IndexError, ValueError):
116
+ raise ValueError(f"Invalid format for max_value: '{self.max_value}'. Expected 'percentile:<float>'")
117
+ else:
118
+ raise TypeError(
119
+ f"Unsupported string for max_value: '{self.max_value}'."
120
+ " Must be a float, 'max', or 'percentile:<float>'."
121
+ )
122
+ else:
123
+ max_value = self.max_value
124
+
125
+ tensor[torch.where(tensor < min_value)] = min_value
126
+ tensor[torch.where(tensor > max_value)] = max_value
68
127
  if self.save_clip_min:
69
- cache_attribute["Min"] = self.min_value
128
+ cache_attribute["Min"] = min_value
70
129
  if self.save_clip_max:
71
- cache_attribute["Max"] = self.max_value
130
+ cache_attribute["Max"] = max_value
72
131
  return tensor
73
132
 
74
- def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
75
- return tensor
76
133
 
77
-
78
- class Normalize(Transform):
134
+ class Normalize(TransformInverse):
79
135
 
80
136
  def __init__(
81
137
  self,
@@ -83,7 +139,9 @@ class Normalize(Transform):
83
139
  channels: list[int] | None = None,
84
140
  min_value: float = -1,
85
141
  max_value: float = 1,
142
+ inverse: bool = True
86
143
  ) -> None:
144
+ super().__init__(inverse)
87
145
  if max_value <= min_value:
88
146
  raise ValueError(
89
147
  f"[Normalize] Invalid range: max_value ({max_value}) must be greater than min_value ({min_value})"
@@ -136,43 +194,57 @@ class Normalize(Transform):
136
194
  return (tensor - self.min_value) * (input_max - input_min) / (self.max_value - self.min_value) + input_min
137
195
 
138
196
 
139
- class Standardize(Transform):
197
+ class Standardize(TransformInverse):
140
198
 
141
199
  def __init__(
142
200
  self,
143
201
  lazy: bool = False,
144
202
  mean: list[float] | None = None,
145
203
  std: list[float] | None = None,
204
+ mask: str | None = None,
205
+ inverse: bool = True
146
206
  ) -> None:
207
+ super().__init__(inverse)
147
208
  self.lazy = lazy
148
209
  self.mean = mean
149
210
  self.std = std
211
+ self.mask = mask
150
212
 
151
213
  def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
214
+ mask = None
215
+ if self.mask is not None:
216
+ for dataset in self.datasets:
217
+ if dataset.is_dataset_exist(self.mask, name):
218
+ mask, _ = dataset.read_data(self.mask, name)
219
+ break
220
+ if mask is None and self.mask is not None:
221
+ raise ValueError(
222
+ f"Requested mask '{self.mask}' is not present in any dataset."
223
+ " Check your dataset group names or configuration."
224
+ )
225
+ if mask is None:
226
+ tensor_masked = tensor
227
+ else:
228
+ tensor_masked = tensor[mask == 1]
229
+
152
230
  if "Mean" not in cache_attribute:
153
231
  cache_attribute["Mean"] = (
154
- torch.mean(
155
- tensor.type(torch.float32),
156
- dim=[i + 1 for i in range(len(tensor.shape) - 1)],
157
- )
232
+ torch.tensor([torch.mean(tensor_masked.type(torch.float32))])
158
233
  if self.mean is None
159
234
  else torch.tensor([self.mean])
160
235
  )
236
+
161
237
  if "Std" not in cache_attribute:
162
238
  cache_attribute["Std"] = (
163
- torch.std(
164
- tensor.type(torch.float32),
165
- dim=[i + 1 for i in range(len(tensor.shape) - 1)],
166
- )
239
+ torch.tensor([torch.std(tensor_masked.type(torch.float32))])
167
240
  if self.std is None
168
241
  else torch.tensor([self.std])
169
242
  )
170
-
171
243
  if self.lazy:
172
244
  return tensor
173
245
  else:
174
- mean = cache_attribute.get_tensor("Mean").view(-1, *[1 for _ in range(len(tensor.shape) - 1)])
175
- std = cache_attribute.get_tensor("Std").view(-1, *[1 for _ in range(len(tensor.shape) - 1)])
246
+ mean = cache_attribute.get_tensor("Mean")
247
+ std = cache_attribute.get_tensor("Std")
176
248
  return (tensor - mean) / std
177
249
 
178
250
  def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
@@ -184,9 +256,10 @@ class Standardize(Transform):
184
256
  return tensor * std + mean
185
257
 
186
258
 
187
- class TensorCast(Transform):
259
+ class TensorCast(TransformInverse):
188
260
 
189
- def __init__(self, dtype: str = "float32") -> None:
261
+ def __init__(self, dtype: str = "float32", inverse: bool = True) -> None:
262
+ super().__init__(inverse)
190
263
  self.dtype: torch.dtype = getattr(torch, dtype)
191
264
 
192
265
  def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
@@ -204,9 +277,10 @@ class TensorCast(Transform):
204
277
  return tensor.to(TensorCast.safe_dtype_cast(cache_attribute.pop("dtype")))
205
278
 
206
279
 
207
- class Padding(Transform):
280
+ class Padding(TransformInverse):
208
281
 
209
- def __init__(self, padding: list[int] = [0, 0, 0, 0, 0, 0], mode: str = "constant") -> None:
282
+ def __init__(self, padding: list[int] = [0, 0, 0, 0, 0, 0], mode: str = "constant", inverse: bool = True) -> None:
283
+ super().__init__(inverse)
210
284
  self.padding = padding
211
285
  self.mode = mode
212
286
 
@@ -241,9 +315,10 @@ class Padding(Transform):
241
315
  return result
242
316
 
243
317
 
244
- class Squeeze(Transform):
318
+ class Squeeze(TransformInverse):
245
319
 
246
- def __init__(self, dim: int) -> None:
320
+ def __init__(self, dim: int, inverse: bool = True) -> None:
321
+ super().__init__(inverse)
247
322
  self.dim = dim
248
323
 
249
324
  def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
@@ -253,10 +328,10 @@ class Squeeze(Transform):
253
328
  return tensor.unsqueeze(self.dim)
254
329
 
255
330
 
256
- class Resample(Transform, ABC):
331
+ class Resample(TransformInverse, ABC):
257
332
 
258
- def __init__(self) -> None:
259
- pass
333
+ def __init__(self, inverse: bool) -> None:
334
+ super().__init__(inverse)
260
335
 
261
336
  def _resample(self, tensor: torch.Tensor, size: list[int]) -> torch.Tensor:
262
337
  if tensor.dtype == torch.uint8:
@@ -289,7 +364,8 @@ class Resample(Transform, ABC):
289
364
 
290
365
  class ResampleToResolution(Resample):
291
366
 
292
- def __init__(self, spacing: list[float] = [1.0, 1.0, 1.0]) -> None:
367
+ def __init__(self, spacing: list[float] = [1.0, 1.0, 1.0], inverse: bool = True) -> None:
368
+ super().__init__(inverse)
293
369
  self.spacing = torch.tensor([0 if s < 0 else s for s in spacing])
294
370
 
295
371
  def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
@@ -325,7 +401,8 @@ class ResampleToResolution(Resample):
325
401
 
326
402
  class ResampleToShape(Resample):
327
403
 
328
- def __init__(self, shape: list[float] = [100, 256, 256]) -> None:
404
+ def __init__(self, shape: list[float] = [100, 256, 256], inverse: bool = True) -> None:
405
+ super().__init__(inverse)
329
406
  self.shape = torch.tensor([0 if s < 0 else s for s in shape])
330
407
 
331
408
  def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
@@ -358,9 +435,10 @@ class ResampleToShape(Resample):
358
435
  return self._resample(tensor, shape)
359
436
 
360
437
 
361
- class ResampleTransform(Transform):
438
+ class ResampleTransform(TransformInverse):
362
439
 
363
- def __init__(self, transforms: dict[str, bool]) -> None:
440
+ def __init__(self, transforms: dict[str, bool], inverse: bool = True) -> None:
441
+ super().__init__(inverse)
364
442
  self.transforms = transforms
365
443
 
366
444
  def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
@@ -436,6 +514,7 @@ class ResampleTransform(Transform):
436
514
  class Mask(Transform):
437
515
 
438
516
  def __init__(self, path: str = "./default.mha", value_outside: int = 0) -> None:
517
+ super().__init__()
439
518
  self.path = path
440
519
  self.value_outside = value_outside
441
520
 
@@ -452,13 +531,11 @@ class Mask(Transform):
452
531
  raise NameError(f"Mask : {self.path}/{name} not found")
453
532
  return torch.where(torch.tensor(mask) > 0, tensor, self.value_outside)
454
533
 
455
- def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
456
- return tensor
457
-
458
534
 
459
535
  class Gradient(Transform):
460
536
 
461
537
  def __init__(self, per_dim: bool = False):
538
+ super().__init__()
462
539
  self.per_dim = per_dim
463
540
 
464
541
  @staticmethod
@@ -492,25 +569,21 @@ class Gradient(Transform):
492
569
 
493
570
  return result
494
571
 
495
- def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
496
- return tensor
497
-
498
-
499
572
  class Argmax(Transform):
500
573
 
501
574
  def __init__(self, dim: int = 0) -> None:
575
+ super().__init__()
502
576
  self.dim = dim
503
577
 
504
578
  def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
505
579
  return torch.argmax(tensor, dim=self.dim).unsqueeze(self.dim)
506
580
 
507
- def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
508
- return tensor
509
581
 
510
582
 
511
583
  class FlatLabel(Transform):
512
584
 
513
585
  def __init__(self, labels: list[int] | None = None) -> None:
586
+ super().__init__()
514
587
  self.labels = labels
515
588
 
516
589
  def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
@@ -522,21 +595,16 @@ class FlatLabel(Transform):
522
595
  data[torch.where(tensor > 0)] = 1
523
596
  return data
524
597
 
525
- def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
526
- return tensor
527
-
528
598
 
529
599
  class Save(Transform):
530
600
 
531
601
  def __init__(self, dataset: str) -> None:
602
+ super().__init__()
532
603
  self.dataset = dataset
533
604
 
534
605
  def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
535
606
  return tensor
536
607
 
537
- def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
538
- return tensor
539
-
540
608
 
541
609
  class Flatten(Transform):
542
610
 
@@ -549,14 +617,11 @@ class Flatten(Transform):
549
617
  def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
550
618
  return tensor.flatten()
551
619
 
552
- def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
553
- return tensor
554
620
 
621
+ class Permute(TransformInverse):
555
622
 
556
- class Permute(Transform):
557
-
558
- def __init__(self, dims: str = "1|0|2") -> None:
559
- super().__init__()
623
+ def __init__(self, dims: str = "1|0|2", inverse: bool = True) -> None:
624
+ super().__init__(inverse)
560
625
  self.dims = [0] + [int(d) + 1 for d in dims.split("|")]
561
626
 
562
627
  def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
@@ -569,10 +634,10 @@ class Permute(Transform):
569
634
  return tensor.permute(tuple(np.argsort(self.dims)))
570
635
 
571
636
 
572
- class Flip(Transform):
637
+ class Flip(TransformInverse):
573
638
 
574
- def __init__(self, dims: str = "1|0|2") -> None:
575
- super().__init__()
639
+ def __init__(self, dims: str = "1|0|2", inverse: bool = True) -> None:
640
+ super().__init__(inverse)
576
641
 
577
642
  self.dims = [int(d) + 1 for d in str(dims).split("|")]
578
643
 
@@ -583,9 +648,10 @@ class Flip(Transform):
583
648
  return tensor.flip(tuple(self.dims))
584
649
 
585
650
 
586
- class Canonical(Transform):
651
+ class Canonical(TransformInverse):
587
652
 
588
- def __init__(self) -> None:
653
+ def __init__(self, inverse: bool = True) -> None:
654
+ super().__init__(inverse)
589
655
  self.canonical_direction = torch.diag(torch.tensor([-1, -1, 1])).to(torch.double)
590
656
 
591
657
  def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
@@ -618,6 +684,7 @@ class Canonical(Transform):
618
684
  class HistogramMatching(Transform):
619
685
 
620
686
  def __init__(self, reference_group: str) -> None:
687
+ super().__init__()
621
688
  self.reference_group = reference_group
622
689
 
623
690
  def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
@@ -635,13 +702,11 @@ class HistogramMatching(Transform):
635
702
  result, _ = image_to_data(matcher.Execute(image, image_ref))
636
703
  return torch.tensor(result)
637
704
 
638
- def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
639
- return tensor
640
-
641
705
 
642
706
  class SelectLabel(Transform):
643
707
 
644
708
  def __init__(self, labels: list[str]) -> None:
709
+ super().__init__()
645
710
  self.labels = [label[1:-1].split(",") for label in labels]
646
711
 
647
712
  def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
@@ -650,14 +715,11 @@ class SelectLabel(Transform):
650
715
  data[tensor == int(old_label)] = int(new_label)
651
716
  return data
652
717
 
653
- def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
654
- return tensor
655
718
 
719
+ class OneHot(TransformInverse):
656
720
 
657
- class OneHot(Transform):
658
-
659
- def __init__(self, num_classes: int) -> None:
660
- super().__init__()
721
+ def __init__(self, num_classes: int, inverse: bool = True) -> None:
722
+ super().__init__(inverse)
661
723
  self.num_classes = num_classes
662
724
 
663
725
  def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
@@ -690,7 +752,4 @@ class TotalSegmentator(Transform):
690
752
  torch.from_numpy(np.array(np.asanyarray(seg.dataobj), copy=True).astype(np.uint8, copy=False))
691
753
  .permute(2, 1, 0)
692
754
  .unsqueeze(0)
693
- )
694
-
695
- def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
696
- return tensor
755
+ )
@@ -6,8 +6,7 @@ from collections import OrderedDict
6
6
  from collections.abc import Callable, Iterable, Iterator, Sequence
7
7
  from enum import Enum
8
8
  from functools import partial
9
- from typing import Any
10
- from typing_extensions import Self
9
+ from typing import Any, Self
11
10
 
12
11
  import numpy as np
13
12
  import torch
@@ -16,7 +16,7 @@ from torch.utils.tensorboard.writer import SummaryWriter
16
16
  from konfai import config_file, konfai_root, models_directory, path_to_models, predictions_directory
17
17
  from konfai.data.data_manager import DataPrediction, DatasetIter
18
18
  from konfai.data.patching import Accumulator, PathCombine
19
- from konfai.data.transform import Transform, TransformLoader
19
+ from konfai.data.transform import Transform, TransformInverse, TransformLoader
20
20
  from konfai.network.network import CPUModel, ModelLoader, NetState, Network
21
21
  from konfai.utils.config import config
22
22
  from konfai.utils.dataset import Attribute, Dataset
@@ -226,11 +226,12 @@ class OutSameAsGroupDataset(OutputDataset):
226
226
 
227
227
  if self.inverse_transform:
228
228
  for transform in reversed(dataset.groups_src[self.group_src][self.group_dest].patch_transforms):
229
- layer = transform.inverse(
230
- self.names[index_dataset],
231
- layer,
232
- self.attributes[index_dataset][index_augmentation][index_patch],
233
- )
229
+ if isinstance(transform, TransformInverse) and transform.apply_inverse:
230
+ layer = transform.inverse(
231
+ self.names[index_dataset],
232
+ layer,
233
+ self.attributes[index_dataset][index_augmentation][index_patch],
234
+ )
234
235
  self.output_layer_accumulator[index_dataset][index_augmentation].add_layer(index_patch, layer)
235
236
 
236
237
  def load(self, name_layer: str, datasets: list[Dataset], groups: dict[str, str]):
@@ -277,7 +278,8 @@ class OutSameAsGroupDataset(OutputDataset):
277
278
 
278
279
  if self.inverse_transform:
279
280
  for transform in reversed(dataset.groups_src[self.group_src][self.group_dest].transforms):
280
- result = transform.inverse(self.names[index], result, self.attributes[index][0][0])
281
+ if isinstance(transform, TransformInverse) and transform.apply_inverse:
282
+ result = transform.inverse(self.names[index], result, self.attributes[index][0][0])
281
283
 
282
284
  for transform in self.final_transforms:
283
285
  result = transform(self.names[index], result, self.attributes[index][0][0])
@@ -715,7 +717,9 @@ class Predictor(DistributedObject):
715
717
  path = models_directory() + self.name + "/StateDict/"
716
718
  name = sorted(os.listdir(path))[-1]
717
719
  if os.path.exists(path + name):
718
- state_dicts.append(torch.load(path + name, map_location=torch.device('cpu'), weights_only=False))
720
+ state_dicts.append(
721
+ torch.load(path + name, map_location=torch.device("cpu"), weights_only=False) # nosec B614
722
+ ) # nosec B614
719
723
  else:
720
724
  raise Exception(f"Model : {path + name} does not exist !")
721
725
  return state_dicts
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: konfai
3
- Version: 1.2.4
3
+ Version: 1.2.6
4
4
  Summary: Modular and configurable Deep Learning framework with YAML and PyTorch
5
5
  Author-email: Valentin Boussot <boussot.v@gmail.com>
6
6
  License-Expression: Apache-2.0
@@ -21,6 +21,7 @@ Requires-Dist: SimpleITK
21
21
  Requires-Dist: lxml
22
22
  Requires-Dist: h5py
23
23
  Requires-Dist: pynvml
24
+ Requires-Dist: requests
24
25
  Provides-Extra: vtk
25
26
  Requires-Dist: vtk; extra == "vtk"
26
27
  Provides-Extra: lpips
@@ -8,6 +8,7 @@ SimpleITK
8
8
  lxml
9
9
  h5py
10
10
  pynvml
11
+ requests
11
12
 
12
13
  [cluster]
13
14
  submitit
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "konfai"
7
- version = "1.2.4"
7
+ version = "1.2.6"
8
8
  description = "Modular and configurable Deep Learning framework with YAML and PyTorch"
9
9
  readme = "README.md"
10
10
  requires-python = ">=3.8"
@@ -24,7 +24,8 @@ dependencies = [
24
24
  "SimpleITK",
25
25
  "lxml",
26
26
  "h5py",
27
- "pynvml"
27
+ "pynvml",
28
+ "requests"
28
29
  ]
29
30
 
30
31
  [project.urls]
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes