konfai 1.2.5.tar.gz → 1.2.7.tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.

This version of konfai might be problematic.

Files changed (45)
  1. {konfai-1.2.5 → konfai-1.2.7}/PKG-INFO +2 -1
  2. {konfai-1.2.5 → konfai-1.2.7}/konfai/data/augmentation.py +1 -1
  3. {konfai-1.2.5 → konfai-1.2.7}/konfai/data/data_manager.py +1 -1
  4. {konfai-1.2.5 → konfai-1.2.7}/konfai/data/patching.py +2 -2
  5. {konfai-1.2.5 → konfai-1.2.7}/konfai/data/transform.py +91 -80
  6. {konfai-1.2.5 → konfai-1.2.7}/konfai/predictor.py +9 -7
  7. {konfai-1.2.5 → konfai-1.2.7}/konfai.egg-info/PKG-INFO +2 -1
  8. {konfai-1.2.5 → konfai-1.2.7}/konfai.egg-info/requires.txt +1 -0
  9. {konfai-1.2.5 → konfai-1.2.7}/pyproject.toml +3 -2
  10. {konfai-1.2.5 → konfai-1.2.7}/LICENSE +0 -0
  11. {konfai-1.2.5 → konfai-1.2.7}/README.md +0 -0
  12. {konfai-1.2.5 → konfai-1.2.7}/konfai/__init__.py +0 -0
  13. {konfai-1.2.5 → konfai-1.2.7}/konfai/data/__init__.py +0 -0
  14. {konfai-1.2.5 → konfai-1.2.7}/konfai/evaluator.py +0 -0
  15. {konfai-1.2.5 → konfai-1.2.7}/konfai/main.py +0 -0
  16. {konfai-1.2.5 → konfai-1.2.7}/konfai/metric/__init__.py +0 -0
  17. {konfai-1.2.5 → konfai-1.2.7}/konfai/metric/measure.py +0 -0
  18. {konfai-1.2.5 → konfai-1.2.7}/konfai/metric/schedulers.py +0 -0
  19. {konfai-1.2.5 → konfai-1.2.7}/konfai/models/classification/convNeXt.py +0 -0
  20. {konfai-1.2.5 → konfai-1.2.7}/konfai/models/classification/resnet.py +0 -0
  21. {konfai-1.2.5 → konfai-1.2.7}/konfai/models/generation/cStyleGan.py +0 -0
  22. {konfai-1.2.5 → konfai-1.2.7}/konfai/models/generation/ddpm.py +0 -0
  23. {konfai-1.2.5 → konfai-1.2.7}/konfai/models/generation/diffusionGan.py +0 -0
  24. {konfai-1.2.5 → konfai-1.2.7}/konfai/models/generation/gan.py +0 -0
  25. {konfai-1.2.5 → konfai-1.2.7}/konfai/models/generation/vae.py +0 -0
  26. {konfai-1.2.5 → konfai-1.2.7}/konfai/models/registration/registration.py +0 -0
  27. {konfai-1.2.5 → konfai-1.2.7}/konfai/models/representation/representation.py +0 -0
  28. {konfai-1.2.5 → konfai-1.2.7}/konfai/models/segmentation/NestedUNet.py +0 -0
  29. {konfai-1.2.5 → konfai-1.2.7}/konfai/models/segmentation/UNet.py +0 -0
  30. {konfai-1.2.5 → konfai-1.2.7}/konfai/network/__init__.py +0 -0
  31. {konfai-1.2.5 → konfai-1.2.7}/konfai/network/blocks.py +0 -0
  32. {konfai-1.2.5 → konfai-1.2.7}/konfai/network/network.py +0 -0
  33. {konfai-1.2.5 → konfai-1.2.7}/konfai/trainer.py +0 -0
  34. {konfai-1.2.5 → konfai-1.2.7}/konfai/utils/ITK.py +0 -0
  35. {konfai-1.2.5 → konfai-1.2.7}/konfai/utils/__init__.py +0 -0
  36. {konfai-1.2.5 → konfai-1.2.7}/konfai/utils/config.py +0 -0
  37. {konfai-1.2.5 → konfai-1.2.7}/konfai/utils/dataset.py +0 -0
  38. {konfai-1.2.5 → konfai-1.2.7}/konfai/utils/utils.py +0 -0
  39. {konfai-1.2.5 → konfai-1.2.7}/konfai.egg-info/SOURCES.txt +0 -0
  40. {konfai-1.2.5 → konfai-1.2.7}/konfai.egg-info/dependency_links.txt +0 -0
  41. {konfai-1.2.5 → konfai-1.2.7}/konfai.egg-info/entry_points.txt +0 -0
  42. {konfai-1.2.5 → konfai-1.2.7}/konfai.egg-info/top_level.txt +0 -0
  43. {konfai-1.2.5 → konfai-1.2.7}/setup.cfg +0 -0
  44. {konfai-1.2.5 → konfai-1.2.7}/tests/test_config.py +0 -0
  45. {konfai-1.2.5 → konfai-1.2.7}/tests/test_dataset.py +0 -0
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: konfai
-Version: 1.2.5
+Version: 1.2.7
 Summary: Modular and configurable Deep Learning framework with YAML and PyTorch
 Author-email: Valentin Boussot <boussot.v@gmail.com>
 License-Expression: Apache-2.0
@@ -21,6 +21,7 @@ Requires-Dist: SimpleITK
 Requires-Dist: lxml
 Requires-Dist: h5py
 Requires-Dist: pynvml
+Requires-Dist: requests
 Provides-Extra: vtk
 Requires-Dist: vtk; extra == "vtk"
 Provides-Extra: lpips
@@ -218,7 +218,7 @@ class DataAugmentation(NeedDevice, ABC):
     @abstractmethod
     def _inverse(self, index: int, a: int, tensor: torch.Tensor) -> torch.Tensor:
         pass
-
+

class EulerTransform(DataAugmentation):

@@ -515,7 +515,7 @@ class Data(ABC):
                 self.groups_src[group_src][group_dest].load(
                     group_src,
                     group_dest,
-                    [self.datasets[filename] for filename, _ in datasets[group_src]],
+                    list(self.datasets.values()),
                 )
                 model_have_input |= self.groups_src[group_src][group_dest].is_input
         if self.patch is not None:
@@ -219,7 +219,7 @@ class Patch(ABC):
             )
             slices = [s] + list(self._patch_slices[a][index][1:])
             data_sliced = data[slices_pre + slices]
-            if data_sliced.shape[len(slices_pre)] < bottom + top + 1:
+            if extend_slice > 0 and data_sliced.shape[len(slices_pre)] < bottom + top + 1:
                 pad_bottom = 0
                 pad_top = 0
                 if self._patch_slices[a][index][0].start - bottom < 0:
@@ -335,7 +335,7 @@ class DatasetManager:
        self.data: list[torch.Tensor] = []

        for transform_function in transforms:
-           _shape = transform_function.transform_shape(_shape, cache_attribute)
+           _shape = transform_function.transform_shape(self.name, _shape, cache_attribute)

        self.patch = (
            DatasetPatch(
@@ -21,13 +21,19 @@ class Transform(NeedDevice, ABC):
     def set_datasets(self, datasets: list[Dataset]):
         self.datasets = datasets

-    def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
+    def transform_shape(self, name: str, shape: list[int], cache_attribute: Attribute) -> list[int]:
         return shape

     @abstractmethod
     def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
         pass

+class TransformInverse(Transform, ABC):
+
+    def __init__(self, inverse: bool) -> None:
+        super().__init__()
+        self.apply_inverse = inverse
+
     @abstractmethod
     def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
         pass
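
This is the central API change in transform.py: invertible transforms now derive from a dedicated TransformInverse base class that records a per-transform apply_inverse flag, while one-way transforms keep plain Transform and shed their identity inverse stubs. A minimal sketch of a custom subclass under the new contract (the AddOffset class and its offset logic are illustrative, not part of konfai):

```python
import torch

from konfai.data.transform import TransformInverse
from konfai.utils.dataset import Attribute


class AddOffset(TransformInverse):
    """Hypothetical invertible transform: shifts intensities by a constant."""

    def __init__(self, offset: float = 100.0, inverse: bool = True) -> None:
        super().__init__(inverse)  # stored as self.apply_inverse
        self.offset = offset

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        # Cache what is needed to undo the transform, as the built-ins do.
        cache_attribute["Offset"] = torch.tensor([self.offset])
        return tensor + self.offset

    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        # Only replayed by the predictor when apply_inverse is True.
        return tensor - cache_attribute.get_tensor("Offset")
```
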
@@ -52,8 +58,9 @@ class Clip(Transform):
         max_value: float | str = 1024,
         save_clip_min: bool = False,
         save_clip_max: bool = False,
-        mask: str | None = None
+        mask: str | None = None,
     ) -> None:
+        super().__init__()
         if isinstance(min_value, float) and isinstance(max_value, float) and max_value <= min_value:
             raise ValueError(
                 f"[Clip] Invalid clipping range: max_value ({max_value}) must be greater than min_value ({min_value})"
@@ -72,7 +79,10 @@ class Clip(Transform):
                     mask, _ = dataset.read_data(self.mask, name)
                     break
             if mask is None and self.mask is not None:
-                raise ValueError(f"Requested mask '{self.mask}' is not present in any dataset. Check your dataset group names or configuration.")
+                raise ValueError(
+                    f"Requested mask '{self.mask}' is not present in any dataset. "
+                    "Check your dataset group names or configuration."
+                )
             if mask is None:
                 tensor_masked = tensor
             else:
@@ -88,7 +98,10 @@ class Clip(Transform):
                 except (IndexError, ValueError):
                     raise ValueError(f"Invalid format for min_value: '{self.min_value}'. Expected 'percentile:<float>'")
             else:
-                raise TypeError(f"Unsupported string for min_value: '{self.min_value}'. Must be a float, 'min', or 'percentile:<float>'.")
+                raise TypeError(
+                    f"Unsupported string for min_value: '{self.min_value}'."
+                    " Must be a float, 'min', or 'percentile:<float>'."
+                )
         else:
             min_value = self.min_value

@@ -102,10 +115,13 @@ class Clip(Transform):
                 except (IndexError, ValueError):
                     raise ValueError(f"Invalid format for max_value: '{self.max_value}'. Expected 'percentile:<float>'")
             else:
-                raise TypeError(f"Unsupported string for max_value: '{self.max_value}'. Must be a float, 'max', or 'percentile:<float>'.")
+                raise TypeError(
+                    f"Unsupported string for max_value: '{self.max_value}'."
+                    " Must be a float, 'max', or 'percentile:<float>'."
+                )
         else:
             max_value = self.max_value
-
+
         tensor[torch.where(tensor < min_value)] = min_value
         tensor[torch.where(tensor > max_value)] = max_value
         if self.save_clip_min:
  if self.save_clip_min:
@@ -114,11 +130,8 @@ class Clip(Transform):
114
130
  cache_attribute["Max"] = max_value
115
131
  return tensor
116
132
 
117
- def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
118
- return tensor
119
-
120
133
 
121
- class Normalize(Transform):
134
+ class Normalize(TransformInverse):
122
135
 
123
136
  def __init__(
124
137
  self,
@@ -126,7 +139,9 @@ class Normalize(Transform):
         channels: list[int] | None = None,
         min_value: float = -1,
         max_value: float = 1,
+        inverse: bool = True
     ) -> None:
+        super().__init__(inverse)
         if max_value <= min_value:
             raise ValueError(
                 f"[Normalize] Invalid range: max_value ({max_value}) must be greater than min_value ({min_value})"
@@ -179,15 +194,17 @@ class Normalize(Transform):
         return (tensor - self.min_value) * (input_max - input_min) / (self.max_value - self.min_value) + input_min


-class Standardize(Transform):
+class Standardize(TransformInverse):

     def __init__(
         self,
         lazy: bool = False,
         mean: list[float] | None = None,
         std: list[float] | None = None,
-        mask: str | None = None
+        mask: str | None = None,
+        inverse: bool = True
     ) -> None:
+        super().__init__(inverse)
         self.lazy = lazy
         self.mean = mean
         self.std = std
@@ -201,19 +218,25 @@ class Standardize(Transform):
                     mask, _ = dataset.read_data(self.mask, name)
                     break
             if mask is None and self.mask is not None:
-                raise ValueError(f"Requested mask '{self.mask}' is not present in any dataset. Check your dataset group names or configuration.")
+                raise ValueError(
+                    f"Requested mask '{self.mask}' is not present in any dataset."
+                    " Check your dataset group names or configuration."
+                )
             if mask is None:
                 tensor_masked = tensor
             else:
                 tensor_masked = tensor[mask == 1]

             if "Mean" not in cache_attribute:
-                cache_attribute["Mean"] = torch.tensor([torch.mean(tensor_masked.type(torch.float32))]) if self.mean is None else torch.tensor([self.mean])
-
+                cache_attribute["Mean"] = (
+                    torch.tensor([torch.mean(tensor_masked.type(torch.float32))])
+                    if self.mean is None
+                    else torch.tensor([self.mean])
+                )
+
             if "Std" not in cache_attribute:
                 cache_attribute["Std"] = (
-                    torch.tensor([torch.std(
-                        tensor_masked.type(torch.float32))])
+                    torch.tensor([torch.std(tensor_masked.type(torch.float32))])
                     if self.std is None
                     else torch.tensor([self.std])
                 )
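
Behaviorally, Standardize is unchanged by the reformatting above: mean and std are computed once (optionally under a mask), cached under the "Mean"/"Std" attributes, and the inverse shown just below undoes the z-scoring with tensor * std + mean. A plain-torch sketch of that round trip, independent of the konfai classes:

```python
import torch

x = torch.tensor([1.0, 2.0, 3.0, 4.0])
mean = torch.mean(x.type(torch.float32))  # cached as "Mean"
std = torch.std(x.type(torch.float32))    # cached as "Std"

z = (x - mean) / std     # forward standardization
x_back = z * std + mean  # inverse, mirroring Standardize.inverse
assert torch.allclose(x, x_back)
```
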
@@ -233,9 +256,10 @@ class Standardize(Transform):
         return tensor * std + mean


-class TensorCast(Transform):
+class TensorCast(TransformInverse):

-    def __init__(self, dtype: str = "float32") -> None:
+    def __init__(self, dtype: str = "float32", inverse: bool = True) -> None:
+        super().__init__(inverse)
         self.dtype: torch.dtype = getattr(torch, dtype)

     def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
@@ -253,9 +277,10 @@ class TensorCast(Transform):
         return tensor.to(TensorCast.safe_dtype_cast(cache_attribute.pop("dtype")))


-class Padding(Transform):
+class Padding(TransformInverse):

-    def __init__(self, padding: list[int] = [0, 0, 0, 0, 0, 0], mode: str = "constant") -> None:
+    def __init__(self, padding: list[int] = [0, 0, 0, 0, 0, 0], mode: str = "constant", inverse: bool = True) -> None:
+        super().__init__(inverse)
         self.padding = padding
         self.mode = mode

@@ -275,7 +300,7 @@ class Padding(Transform):
         ).squeeze(0)
         return result

-    def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
+    def transform_shape(self, name: str, shape: list[int], cache_attribute: Attribute) -> list[int]:
         for dim in range(len(self.padding) // 2):
             shape[-dim - 1] += sum(self.padding[dim * 2 : dim * 2 + 2])
         return shape
@@ -290,9 +315,10 @@ class Padding(Transform):
         return result


-class Squeeze(Transform):
+class Squeeze(TransformInverse):

-    def __init__(self, dim: int) -> None:
+    def __init__(self, dim: int, inverse: bool = True) -> None:
+        super().__init__(inverse)
         self.dim = dim

     def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
@@ -302,10 +328,10 @@ class Squeeze(Transform):
         return tensor.unsqueeze(self.dim)


-class Resample(Transform, ABC):
+class Resample(TransformInverse, ABC):

-    def __init__(self) -> None:
-        pass
+    def __init__(self, inverse: bool) -> None:
+        super().__init__(inverse)

     def _resample(self, tensor: torch.Tensor, size: list[int]) -> torch.Tensor:
         if tensor.dtype == torch.uint8:
@@ -326,7 +352,7 @@ class Resample(Transform, ABC):
         pass

     @abstractmethod
-    def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
+    def transform_shape(self, name: str, shape: list[int], cache_attribute: Attribute) -> list[int]:
         pass

     def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
@@ -338,7 +364,8 @@ class Resample(Transform, ABC):

 class ResampleToResolution(Resample):

-    def __init__(self, spacing: list[float] = [1.0, 1.0, 1.0]) -> None:
+    def __init__(self, spacing: list[float] = [1.0, 1.0, 1.0], inverse: bool = True) -> None:
+        super().__init__(inverse)
         self.spacing = torch.tensor([0 if s < 0 else s for s in spacing])

     def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
@@ -349,35 +376,37 @@ class ResampleToResolution(Resample):
         )
         if len(shape) != len(self.spacing):
             TransformError("Shape and spacing dimensions do not match: shape={shape}, spacing={self.spacing}")
-        image_spacing = cache_attribute.get_tensor("Spacing").flip(0)
+        image_spacing = cache_attribute.get_tensor("Spacing")
         spacing = self.spacing

         for i, s in enumerate(self.spacing):
             if s == 0:
                 spacing[i] = image_spacing[i]
-        resize_factor = spacing / cache_attribute.get_tensor("Spacing").flip(0)
+        resize_factor = spacing / image_spacing
         return [int(x) for x in (torch.tensor(shape) * 1 / resize_factor)]

     def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        image_spacing = cache_attribute.get_tensor("Spacing").flip(0)
+        image_spacing = cache_attribute.get_tensor("Spacing")
         spacing = self.spacing
         for i, s in enumerate(self.spacing):
             if s == 0:
                 spacing[i] = image_spacing[i]
-        resize_factor = spacing / cache_attribute.get_tensor("Spacing").flip(0)
-        cache_attribute["Spacing"] = spacing.flip(0)
+        resize_factor = spacing / cache_attribute.get_tensor("Spacing")
+        cache_attribute["Spacing"] = spacing
         cache_attribute["Size"] = np.asarray([int(x) for x in torch.tensor(tensor.shape[1:])])
         size = [int(x) for x in (torch.tensor(tensor.shape[1:]) * 1 / resize_factor)]
         cache_attribute["Size"] = np.asarray(size)
         return self._resample(tensor, size)


+
 class ResampleToShape(Resample):

-    def __init__(self, shape: list[float] = [100, 256, 256]) -> None:
+    def __init__(self, shape: list[float] = [100, 256, 256], inverse: bool = True) -> None:
+        super().__init__(inverse)
         self.shape = torch.tensor([0 if s < 0 else s for s in shape])

-    def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
+    def transform_shape(self, name: str, shape: list[int], cache_attribute: Attribute) -> list[int]:
         if "Spacing" not in cache_attribute:
             TransformError(
                 "Missing 'Spacing' in cache attributes, the data is likely not a valid image.",
@@ -407,12 +436,13 @@ class ResampleToShape(Resample):
         return self._resample(tensor, shape)


-class ResampleTransform(Transform):
+class ResampleTransform(TransformInverse):

-    def __init__(self, transforms: dict[str, bool]) -> None:
+    def __init__(self, transforms: dict[str, bool], inverse: bool = True) -> None:
+        super().__init__(inverse)
         self.transforms = transforms

-    def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
+    def transform_shape(self, name: str, shape: list[int], cache_attribute: Attribute) -> list[int]:
         return shape

     def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
@@ -485,6 +515,7 @@ class ResampleTransform(Transform):
 class Mask(Transform):

     def __init__(self, path: str = "./default.mha", value_outside: int = 0) -> None:
+        super().__init__()
         self.path = path
         self.value_outside = value_outside

@@ -501,13 +532,11 @@ class Mask(Transform):
             raise NameError(f"Mask : {self.path}/{name} not found")
         return torch.where(torch.tensor(mask) > 0, tensor, self.value_outside)

-    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        return tensor
-

 class Gradient(Transform):

     def __init__(self, per_dim: bool = False):
+        super().__init__()
         self.per_dim = per_dim

     @staticmethod
@@ -541,25 +570,21 @@ class Gradient(Transform):

         return result

-    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        return tensor
-
-
 class Argmax(Transform):

     def __init__(self, dim: int = 0) -> None:
+        super().__init__()
         self.dim = dim

     def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
         return torch.argmax(tensor, dim=self.dim).unsqueeze(self.dim)

-    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        return tensor


 class FlatLabel(Transform):

     def __init__(self, labels: list[int] | None = None) -> None:
+        super().__init__()
         self.labels = labels

     def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
@@ -571,44 +596,36 @@ class FlatLabel(Transform):
             data[torch.where(tensor > 0)] = 1
         return data

-    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        return tensor
-

 class Save(Transform):

     def __init__(self, dataset: str) -> None:
+        super().__init__()
         self.dataset = dataset

     def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
         return tensor

-    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        return tensor
-

 class Flatten(Transform):

     def __init__(self) -> None:
         super().__init__()

-    def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
+    def transform_shape(self, name: str, shape: list[int], cache_attribute: Attribute) -> list[int]:
         return [np.prod(np.asarray(shape))]

     def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
         return tensor.flatten()

-    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        return tensor

+class Permute(TransformInverse):

-class Permute(Transform):
-
-    def __init__(self, dims: str = "1|0|2") -> None:
-        super().__init__()
+    def __init__(self, dims: str = "1|0|2", inverse: bool = True) -> None:
+        super().__init__(inverse)
         self.dims = [0] + [int(d) + 1 for d in dims.split("|")]

-    def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
+    def transform_shape(self, name: str, shape: list[int], cache_attribute: Attribute) -> list[int]:
         return [shape[it - 1] for it in self.dims[1:]]

     def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
@@ -618,10 +635,10 @@ class Permute(Transform):
         return tensor.permute(tuple(np.argsort(self.dims)))


-class Flip(Transform):
+class Flip(TransformInverse):

-    def __init__(self, dims: str = "1|0|2") -> None:
-        super().__init__()
+    def __init__(self, dims: str = "1|0|2", inverse: bool = True) -> None:
+        super().__init__(inverse)

         self.dims = [int(d) + 1 for d in str(dims).split("|")]

@@ -632,9 +649,10 @@ class Flip(Transform):
         return tensor.flip(tuple(self.dims))


-class Canonical(Transform):
+class Canonical(TransformInverse):

-    def __init__(self) -> None:
+    def __init__(self, inverse: bool = True) -> None:
+        super().__init__(inverse)
         self.canonical_direction = torch.diag(torch.tensor([-1, -1, 1])).to(torch.double)

     def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
@@ -667,6 +685,7 @@ class Canonical(Transform):
 class HistogramMatching(Transform):

     def __init__(self, reference_group: str) -> None:
+        super().__init__()
         self.reference_group = reference_group

     def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
@@ -684,13 +703,11 @@ class HistogramMatching(Transform):
         result, _ = image_to_data(matcher.Execute(image, image_ref))
         return torch.tensor(result)

-    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        return tensor
-

 class SelectLabel(Transform):

     def __init__(self, labels: list[str]) -> None:
+        super().__init__()
         self.labels = [label[1:-1].split(",") for label in labels]

     def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
@@ -699,14 +716,11 @@ class SelectLabel(Transform):
             data[tensor == int(old_label)] = int(new_label)
         return data

-    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        return tensor

+class OneHot(TransformInverse):

-class OneHot(Transform):
-
-    def __init__(self, num_classes: int) -> None:
-        super().__init__()
+    def __init__(self, num_classes: int, inverse: bool = True) -> None:
+        super().__init__(inverse)
         self.num_classes = num_classes

     def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
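
The OneHot/Argmax split matches how the pair composes: one-hot encoding is invertible (argmax recovers the labels), while Argmax, reworked earlier in this file, discards information, so its old identity inverse stub is simply removed. A plain-torch illustration:

```python
import torch
import torch.nn.functional as F

labels = torch.tensor([0, 2, 1])
one_hot = F.one_hot(labels, num_classes=3)
assert torch.equal(one_hot.argmax(dim=-1), labels)  # OneHot is invertible

logits = torch.tensor([[0.1, 0.7, 0.2]])
pred = logits.argmax(dim=-1)  # Argmax is one-way: the logits are gone
```
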
@@ -739,7 +753,4 @@ class TotalSegmentator(Transform):
             torch.from_numpy(np.array(np.asanyarray(seg.dataobj), copy=True).astype(np.uint8, copy=False))
             .permute(2, 1, 0)
             .unsqueeze(0)
-        )
-
-    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        return tensor
+        )
@@ -16,7 +16,7 @@ from torch.utils.tensorboard.writer import SummaryWriter
 from konfai import config_file, konfai_root, models_directory, path_to_models, predictions_directory
 from konfai.data.data_manager import DataPrediction, DatasetIter
 from konfai.data.patching import Accumulator, PathCombine
-from konfai.data.transform import Transform, TransformLoader
+from konfai.data.transform import Transform, TransformInverse, TransformLoader
 from konfai.network.network import CPUModel, ModelLoader, NetState, Network
 from konfai.utils.config import config
 from konfai.utils.dataset import Attribute, Dataset
@@ -226,11 +226,12 @@ class OutSameAsGroupDataset(OutputDataset):

         if self.inverse_transform:
             for transform in reversed(dataset.groups_src[self.group_src][self.group_dest].patch_transforms):
-                layer = transform.inverse(
-                    self.names[index_dataset],
-                    layer,
-                    self.attributes[index_dataset][index_augmentation][index_patch],
-                )
+                if isinstance(transform, TransformInverse) and transform.apply_inverse:
+                    layer = transform.inverse(
+                        self.names[index_dataset],
+                        layer,
+                        self.attributes[index_dataset][index_augmentation][index_patch],
+                    )
         self.output_layer_accumulator[index_dataset][index_augmentation].add_layer(index_patch, layer)

     def load(self, name_layer: str, datasets: list[Dataset], groups: dict[str, str]):
@@ -277,7 +278,8 @@ class OutSameAsGroupDataset(OutputDataset):

         if self.inverse_transform:
             for transform in reversed(dataset.groups_src[self.group_src][self.group_dest].transforms):
-                result = transform.inverse(self.names[index], result, self.attributes[index][0][0])
+                if isinstance(transform, TransformInverse) and transform.apply_inverse:
+                    result = transform.inverse(self.names[index], result, self.attributes[index][0][0])

         for transform in self.final_transforms:
             result = transform(self.names[index], result, self.attributes[index][0][0])
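
Net effect in predictor.py: instead of calling inverse unconditionally on every transform (which forced every transform to carry an identity stub), the output path now replays only transforms that subclass TransformInverse and were configured with inverse=True. Condensed, the new guard amounts to the following sketch, where transforms, name, result, and attribute stand in for the per-dataset state above:

```python
# Non-invertible or opted-out transforms are silently skipped
# rather than round-tripped through identity stubs.
for transform in reversed(transforms):
    if isinstance(transform, TransformInverse) and transform.apply_inverse:
        result = transform.inverse(name, result, attribute)
```
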
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: konfai
-Version: 1.2.5
+Version: 1.2.7
 Summary: Modular and configurable Deep Learning framework with YAML and PyTorch
 Author-email: Valentin Boussot <boussot.v@gmail.com>
 License-Expression: Apache-2.0
@@ -21,6 +21,7 @@ Requires-Dist: SimpleITK
 Requires-Dist: lxml
 Requires-Dist: h5py
 Requires-Dist: pynvml
+Requires-Dist: requests
 Provides-Extra: vtk
 Requires-Dist: vtk; extra == "vtk"
 Provides-Extra: lpips
@@ -8,6 +8,7 @@ SimpleITK
 lxml
 h5py
 pynvml
+requests

 [cluster]
 submitit
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"

 [project]
 name = "konfai"
-version = "1.2.5"
+version = "1.2.7"
 description = "Modular and configurable Deep Learning framework with YAML and PyTorch"
 readme = "README.md"
 requires-python = ">=3.8"
@@ -24,7 +24,8 @@ dependencies = [
     "SimpleITK",
     "lxml",
     "h5py",
-    "pynvml"
+    "pynvml",
+    "requests"
 ]

 [project.urls]