konfai 1.2.5__py3-none-any.whl → 1.2.6__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of konfai might be problematic. Click here for more details.
- konfai/data/augmentation.py +1 -1
- konfai/data/data_manager.py +1 -1
- konfai/data/transform.py +78 -68
- konfai/predictor.py +9 -7
- {konfai-1.2.5.dist-info → konfai-1.2.6.dist-info}/METADATA +2 -1
- {konfai-1.2.5.dist-info → konfai-1.2.6.dist-info}/RECORD +10 -10
- {konfai-1.2.5.dist-info → konfai-1.2.6.dist-info}/WHEEL +0 -0
- {konfai-1.2.5.dist-info → konfai-1.2.6.dist-info}/entry_points.txt +0 -0
- {konfai-1.2.5.dist-info → konfai-1.2.6.dist-info}/licenses/LICENSE +0 -0
- {konfai-1.2.5.dist-info → konfai-1.2.6.dist-info}/top_level.txt +0 -0
konfai/data/augmentation.py
CHANGED
konfai/data/data_manager.py
CHANGED
|
@@ -515,7 +515,7 @@ class Data(ABC):
|
|
|
515
515
|
self.groups_src[group_src][group_dest].load(
|
|
516
516
|
group_src,
|
|
517
517
|
group_dest,
|
|
518
|
-
|
|
518
|
+
list(self.datasets.values()),
|
|
519
519
|
)
|
|
520
520
|
model_have_input |= self.groups_src[group_src][group_dest].is_input
|
|
521
521
|
if self.patch is not None:
|
konfai/data/transform.py
CHANGED
|
@@ -28,6 +28,12 @@ class Transform(NeedDevice, ABC):
|
|
|
28
28
|
def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
29
29
|
pass
|
|
30
30
|
|
|
31
|
+
class TransformInverse(Transform, ABC):
|
|
32
|
+
|
|
33
|
+
def __init__(self, inverse: bool) -> None:
|
|
34
|
+
super().__init__()
|
|
35
|
+
self.apply_inverse = inverse
|
|
36
|
+
|
|
31
37
|
@abstractmethod
|
|
32
38
|
def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
33
39
|
pass
|
|
@@ -52,8 +58,9 @@ class Clip(Transform):
|
|
|
52
58
|
max_value: float | str = 1024,
|
|
53
59
|
save_clip_min: bool = False,
|
|
54
60
|
save_clip_max: bool = False,
|
|
55
|
-
mask: str | None = None
|
|
61
|
+
mask: str | None = None,
|
|
56
62
|
) -> None:
|
|
63
|
+
super().__init__()
|
|
57
64
|
if isinstance(min_value, float) and isinstance(max_value, float) and max_value <= min_value:
|
|
58
65
|
raise ValueError(
|
|
59
66
|
f"[Clip] Invalid clipping range: max_value ({max_value}) must be greater than min_value ({min_value})"
|
|
@@ -72,7 +79,10 @@ class Clip(Transform):
|
|
|
72
79
|
mask, _ = dataset.read_data(self.mask, name)
|
|
73
80
|
break
|
|
74
81
|
if mask is None and self.mask is not None:
|
|
75
|
-
raise ValueError(
|
|
82
|
+
raise ValueError(
|
|
83
|
+
f"Requested mask '{self.mask}' is not present in any dataset. "
|
|
84
|
+
"Check your dataset group names or configuration."
|
|
85
|
+
)
|
|
76
86
|
if mask is None:
|
|
77
87
|
tensor_masked = tensor
|
|
78
88
|
else:
|
|
@@ -88,7 +98,10 @@ class Clip(Transform):
|
|
|
88
98
|
except (IndexError, ValueError):
|
|
89
99
|
raise ValueError(f"Invalid format for min_value: '{self.min_value}'. Expected 'percentile:<float>'")
|
|
90
100
|
else:
|
|
91
|
-
raise TypeError(
|
|
101
|
+
raise TypeError(
|
|
102
|
+
f"Unsupported string for min_value: '{self.min_value}'."
|
|
103
|
+
"Must be a float, 'min', or 'percentile:<float>'."
|
|
104
|
+
)
|
|
92
105
|
else:
|
|
93
106
|
min_value = self.min_value
|
|
94
107
|
|
|
@@ -102,10 +115,13 @@ class Clip(Transform):
|
|
|
102
115
|
except (IndexError, ValueError):
|
|
103
116
|
raise ValueError(f"Invalid format for max_value: '{self.max_value}'. Expected 'percentile:<float>'")
|
|
104
117
|
else:
|
|
105
|
-
raise TypeError(
|
|
118
|
+
raise TypeError(
|
|
119
|
+
f"Unsupported string for max_value: '{self.max_value}'."
|
|
120
|
+
" Must be a float, 'max', or 'percentile:<float>'."
|
|
121
|
+
)
|
|
106
122
|
else:
|
|
107
123
|
max_value = self.max_value
|
|
108
|
-
|
|
124
|
+
|
|
109
125
|
tensor[torch.where(tensor < min_value)] = min_value
|
|
110
126
|
tensor[torch.where(tensor > max_value)] = max_value
|
|
111
127
|
if self.save_clip_min:
|
|
@@ -114,11 +130,8 @@ class Clip(Transform):
|
|
|
114
130
|
cache_attribute["Max"] = max_value
|
|
115
131
|
return tensor
|
|
116
132
|
|
|
117
|
-
def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
118
|
-
return tensor
|
|
119
|
-
|
|
120
133
|
|
|
121
|
-
class Normalize(Transform):
|
|
134
|
+
class Normalize(TransformInverse):
|
|
122
135
|
|
|
123
136
|
def __init__(
|
|
124
137
|
self,
|
|
@@ -126,7 +139,9 @@ class Normalize(Transform):
|
|
|
126
139
|
channels: list[int] | None = None,
|
|
127
140
|
min_value: float = -1,
|
|
128
141
|
max_value: float = 1,
|
|
142
|
+
inverse: bool = True
|
|
129
143
|
) -> None:
|
|
144
|
+
super().__init__(inverse)
|
|
130
145
|
if max_value <= min_value:
|
|
131
146
|
raise ValueError(
|
|
132
147
|
f"[Normalize] Invalid range: max_value ({max_value}) must be greater than min_value ({min_value})"
|
|
@@ -179,15 +194,17 @@ class Normalize(Transform):
|
|
|
179
194
|
return (tensor - self.min_value) * (input_max - input_min) / (self.max_value - self.min_value) + input_min
|
|
180
195
|
|
|
181
196
|
|
|
182
|
-
class Standardize(Transform):
|
|
197
|
+
class Standardize(TransformInverse):
|
|
183
198
|
|
|
184
199
|
def __init__(
|
|
185
200
|
self,
|
|
186
201
|
lazy: bool = False,
|
|
187
202
|
mean: list[float] | None = None,
|
|
188
203
|
std: list[float] | None = None,
|
|
189
|
-
mask: str | None = None
|
|
204
|
+
mask: str | None = None,
|
|
205
|
+
inverse: bool = True
|
|
190
206
|
) -> None:
|
|
207
|
+
super().__init__(inverse)
|
|
191
208
|
self.lazy = lazy
|
|
192
209
|
self.mean = mean
|
|
193
210
|
self.std = std
|
|
@@ -201,19 +218,25 @@ class Standardize(Transform):
|
|
|
201
218
|
mask, _ = dataset.read_data(self.mask, name)
|
|
202
219
|
break
|
|
203
220
|
if mask is None and self.mask is not None:
|
|
204
|
-
raise ValueError(
|
|
221
|
+
raise ValueError(
|
|
222
|
+
f"Requested mask '{self.mask}' is not present in any dataset."
|
|
223
|
+
" Check your dataset group names or configuration."
|
|
224
|
+
)
|
|
205
225
|
if mask is None:
|
|
206
226
|
tensor_masked = tensor
|
|
207
227
|
else:
|
|
208
228
|
tensor_masked = tensor[mask == 1]
|
|
209
229
|
|
|
210
230
|
if "Mean" not in cache_attribute:
|
|
211
|
-
cache_attribute["Mean"] =
|
|
212
|
-
|
|
231
|
+
cache_attribute["Mean"] = (
|
|
232
|
+
torch.tensor([torch.mean(tensor_masked.type(torch.float32))])
|
|
233
|
+
if self.mean is None
|
|
234
|
+
else torch.tensor([self.mean])
|
|
235
|
+
)
|
|
236
|
+
|
|
213
237
|
if "Std" not in cache_attribute:
|
|
214
238
|
cache_attribute["Std"] = (
|
|
215
|
-
torch.tensor([torch.std(
|
|
216
|
-
tensor_masked.type(torch.float32))])
|
|
239
|
+
torch.tensor([torch.std(tensor_masked.type(torch.float32))])
|
|
217
240
|
if self.std is None
|
|
218
241
|
else torch.tensor([self.std])
|
|
219
242
|
)
|
|
@@ -233,9 +256,10 @@ class Standardize(Transform):
|
|
|
233
256
|
return tensor * std + mean
|
|
234
257
|
|
|
235
258
|
|
|
236
|
-
class TensorCast(Transform):
|
|
259
|
+
class TensorCast(TransformInverse):
|
|
237
260
|
|
|
238
|
-
def __init__(self, dtype: str = "float32") -> None:
|
|
261
|
+
def __init__(self, dtype: str = "float32", inverse: bool = True) -> None:
|
|
262
|
+
super().__init__(inverse)
|
|
239
263
|
self.dtype: torch.dtype = getattr(torch, dtype)
|
|
240
264
|
|
|
241
265
|
def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
@@ -253,9 +277,10 @@ class TensorCast(Transform):
|
|
|
253
277
|
return tensor.to(TensorCast.safe_dtype_cast(cache_attribute.pop("dtype")))
|
|
254
278
|
|
|
255
279
|
|
|
256
|
-
class Padding(Transform):
|
|
280
|
+
class Padding(TransformInverse):
|
|
257
281
|
|
|
258
|
-
def __init__(self, padding: list[int] = [0, 0, 0, 0, 0, 0], mode: str = "constant") -> None:
|
|
282
|
+
def __init__(self, padding: list[int] = [0, 0, 0, 0, 0, 0], mode: str = "constant", inverse: bool = True) -> None:
|
|
283
|
+
super().__init__(inverse)
|
|
259
284
|
self.padding = padding
|
|
260
285
|
self.mode = mode
|
|
261
286
|
|
|
@@ -290,9 +315,10 @@ class Padding(Transform):
|
|
|
290
315
|
return result
|
|
291
316
|
|
|
292
317
|
|
|
293
|
-
class Squeeze(Transform):
|
|
318
|
+
class Squeeze(TransformInverse):
|
|
294
319
|
|
|
295
|
-
def __init__(self, dim: int) -> None:
|
|
320
|
+
def __init__(self, dim: int, inverse: bool = True) -> None:
|
|
321
|
+
super().__init__(inverse)
|
|
296
322
|
self.dim = dim
|
|
297
323
|
|
|
298
324
|
def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
@@ -302,10 +328,10 @@ class Squeeze(Transform):
|
|
|
302
328
|
return tensor.unsqueeze(self.dim)
|
|
303
329
|
|
|
304
330
|
|
|
305
|
-
class Resample(Transform, ABC):
|
|
331
|
+
class Resample(TransformInverse, ABC):
|
|
306
332
|
|
|
307
|
-
def __init__(self) -> None:
|
|
308
|
-
|
|
333
|
+
def __init__(self, inverse: bool) -> None:
|
|
334
|
+
super().__init__(inverse)
|
|
309
335
|
|
|
310
336
|
def _resample(self, tensor: torch.Tensor, size: list[int]) -> torch.Tensor:
|
|
311
337
|
if tensor.dtype == torch.uint8:
|
|
@@ -338,7 +364,8 @@ class Resample(Transform, ABC):
|
|
|
338
364
|
|
|
339
365
|
class ResampleToResolution(Resample):
|
|
340
366
|
|
|
341
|
-
def __init__(self, spacing: list[float] = [1.0, 1.0, 1.0]) -> None:
|
|
367
|
+
def __init__(self, spacing: list[float] = [1.0, 1.0, 1.0], inverse: bool = True) -> None:
|
|
368
|
+
super().__init__(inverse)
|
|
342
369
|
self.spacing = torch.tensor([0 if s < 0 else s for s in spacing])
|
|
343
370
|
|
|
344
371
|
def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
|
|
@@ -374,7 +401,8 @@ class ResampleToResolution(Resample):
|
|
|
374
401
|
|
|
375
402
|
class ResampleToShape(Resample):
|
|
376
403
|
|
|
377
|
-
def __init__(self, shape: list[float] = [100, 256, 256]) -> None:
|
|
404
|
+
def __init__(self, shape: list[float] = [100, 256, 256], inverse: bool = True) -> None:
|
|
405
|
+
super().__init__(inverse)
|
|
378
406
|
self.shape = torch.tensor([0 if s < 0 else s for s in shape])
|
|
379
407
|
|
|
380
408
|
def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
|
|
@@ -407,9 +435,10 @@ class ResampleToShape(Resample):
|
|
|
407
435
|
return self._resample(tensor, shape)
|
|
408
436
|
|
|
409
437
|
|
|
410
|
-
class ResampleTransform(Transform):
|
|
438
|
+
class ResampleTransform(TransformInverse):
|
|
411
439
|
|
|
412
|
-
def __init__(self, transforms: dict[str, bool]) -> None:
|
|
440
|
+
def __init__(self, transforms: dict[str, bool], inverse: bool = True) -> None:
|
|
441
|
+
super().__init__(inverse)
|
|
413
442
|
self.transforms = transforms
|
|
414
443
|
|
|
415
444
|
def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
|
|
@@ -485,6 +514,7 @@ class ResampleTransform(Transform):
|
|
|
485
514
|
class Mask(Transform):
|
|
486
515
|
|
|
487
516
|
def __init__(self, path: str = "./default.mha", value_outside: int = 0) -> None:
|
|
517
|
+
super().__init__()
|
|
488
518
|
self.path = path
|
|
489
519
|
self.value_outside = value_outside
|
|
490
520
|
|
|
@@ -501,13 +531,11 @@ class Mask(Transform):
|
|
|
501
531
|
raise NameError(f"Mask : {self.path}/{name} not found")
|
|
502
532
|
return torch.where(torch.tensor(mask) > 0, tensor, self.value_outside)
|
|
503
533
|
|
|
504
|
-
def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
505
|
-
return tensor
|
|
506
|
-
|
|
507
534
|
|
|
508
535
|
class Gradient(Transform):
|
|
509
536
|
|
|
510
537
|
def __init__(self, per_dim: bool = False):
|
|
538
|
+
super().__init__()
|
|
511
539
|
self.per_dim = per_dim
|
|
512
540
|
|
|
513
541
|
@staticmethod
|
|
@@ -541,25 +569,21 @@ class Gradient(Transform):
|
|
|
541
569
|
|
|
542
570
|
return result
|
|
543
571
|
|
|
544
|
-
def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
545
|
-
return tensor
|
|
546
|
-
|
|
547
|
-
|
|
548
572
|
class Argmax(Transform):
|
|
549
573
|
|
|
550
574
|
def __init__(self, dim: int = 0) -> None:
|
|
575
|
+
super().__init__()
|
|
551
576
|
self.dim = dim
|
|
552
577
|
|
|
553
578
|
def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
554
579
|
return torch.argmax(tensor, dim=self.dim).unsqueeze(self.dim)
|
|
555
580
|
|
|
556
|
-
def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
557
|
-
return tensor
|
|
558
581
|
|
|
559
582
|
|
|
560
583
|
class FlatLabel(Transform):
|
|
561
584
|
|
|
562
585
|
def __init__(self, labels: list[int] | None = None) -> None:
|
|
586
|
+
super().__init__()
|
|
563
587
|
self.labels = labels
|
|
564
588
|
|
|
565
589
|
def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
@@ -571,21 +595,16 @@ class FlatLabel(Transform):
|
|
|
571
595
|
data[torch.where(tensor > 0)] = 1
|
|
572
596
|
return data
|
|
573
597
|
|
|
574
|
-
def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
575
|
-
return tensor
|
|
576
|
-
|
|
577
598
|
|
|
578
599
|
class Save(Transform):
|
|
579
600
|
|
|
580
601
|
def __init__(self, dataset: str) -> None:
|
|
602
|
+
super().__init__()
|
|
581
603
|
self.dataset = dataset
|
|
582
604
|
|
|
583
605
|
def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
584
606
|
return tensor
|
|
585
607
|
|
|
586
|
-
def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
587
|
-
return tensor
|
|
588
|
-
|
|
589
608
|
|
|
590
609
|
class Flatten(Transform):
|
|
591
610
|
|
|
@@ -598,14 +617,11 @@ class Flatten(Transform):
|
|
|
598
617
|
def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
599
618
|
return tensor.flatten()
|
|
600
619
|
|
|
601
|
-
def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
602
|
-
return tensor
|
|
603
620
|
|
|
621
|
+
class Permute(TransformInverse):
|
|
604
622
|
|
|
605
|
-
|
|
606
|
-
|
|
607
|
-
def __init__(self, dims: str = "1|0|2") -> None:
|
|
608
|
-
super().__init__()
|
|
623
|
+
def __init__(self, dims: str = "1|0|2", inverse: bool = True) -> None:
|
|
624
|
+
super().__init__(inverse)
|
|
609
625
|
self.dims = [0] + [int(d) + 1 for d in dims.split("|")]
|
|
610
626
|
|
|
611
627
|
def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
|
|
@@ -618,10 +634,10 @@ class Permute(Transform):
|
|
|
618
634
|
return tensor.permute(tuple(np.argsort(self.dims)))
|
|
619
635
|
|
|
620
636
|
|
|
621
|
-
class Flip(Transform):
|
|
637
|
+
class Flip(TransformInverse):
|
|
622
638
|
|
|
623
|
-
def __init__(self, dims: str = "1|0|2") -> None:
|
|
624
|
-
super().__init__()
|
|
639
|
+
def __init__(self, dims: str = "1|0|2", inverse: bool = True) -> None:
|
|
640
|
+
super().__init__(inverse)
|
|
625
641
|
|
|
626
642
|
self.dims = [int(d) + 1 for d in str(dims).split("|")]
|
|
627
643
|
|
|
@@ -632,9 +648,10 @@ class Flip(Transform):
|
|
|
632
648
|
return tensor.flip(tuple(self.dims))
|
|
633
649
|
|
|
634
650
|
|
|
635
|
-
class Canonical(Transform):
|
|
651
|
+
class Canonical(TransformInverse):
|
|
636
652
|
|
|
637
|
-
def __init__(self) -> None:
|
|
653
|
+
def __init__(self, inverse: bool = True) -> None:
|
|
654
|
+
super().__init__(inverse)
|
|
638
655
|
self.canonical_direction = torch.diag(torch.tensor([-1, -1, 1])).to(torch.double)
|
|
639
656
|
|
|
640
657
|
def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
@@ -667,6 +684,7 @@ class Canonical(Transform):
|
|
|
667
684
|
class HistogramMatching(Transform):
|
|
668
685
|
|
|
669
686
|
def __init__(self, reference_group: str) -> None:
|
|
687
|
+
super().__init__()
|
|
670
688
|
self.reference_group = reference_group
|
|
671
689
|
|
|
672
690
|
def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
@@ -684,13 +702,11 @@ class HistogramMatching(Transform):
|
|
|
684
702
|
result, _ = image_to_data(matcher.Execute(image, image_ref))
|
|
685
703
|
return torch.tensor(result)
|
|
686
704
|
|
|
687
|
-
def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
688
|
-
return tensor
|
|
689
|
-
|
|
690
705
|
|
|
691
706
|
class SelectLabel(Transform):
|
|
692
707
|
|
|
693
708
|
def __init__(self, labels: list[str]) -> None:
|
|
709
|
+
super().__init__()
|
|
694
710
|
self.labels = [label[1:-1].split(",") for label in labels]
|
|
695
711
|
|
|
696
712
|
def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
@@ -699,14 +715,11 @@ class SelectLabel(Transform):
|
|
|
699
715
|
data[tensor == int(old_label)] = int(new_label)
|
|
700
716
|
return data
|
|
701
717
|
|
|
702
|
-
def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
703
|
-
return tensor
|
|
704
|
-
|
|
705
718
|
|
|
706
|
-
class OneHot(
|
|
719
|
+
class OneHot(TransformInverse):
|
|
707
720
|
|
|
708
|
-
def __init__(self, num_classes: int) -> None:
|
|
709
|
-
super().__init__()
|
|
721
|
+
def __init__(self, num_classes: int, inverse: bool = True) -> None:
|
|
722
|
+
super().__init__(inverse)
|
|
710
723
|
self.num_classes = num_classes
|
|
711
724
|
|
|
712
725
|
def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
@@ -739,7 +752,4 @@ class TotalSegmentator(Transform):
|
|
|
739
752
|
torch.from_numpy(np.array(np.asanyarray(seg.dataobj), copy=True).astype(np.uint8, copy=False))
|
|
740
753
|
.permute(2, 1, 0)
|
|
741
754
|
.unsqueeze(0)
|
|
742
|
-
)
|
|
743
|
-
|
|
744
|
-
def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
745
|
-
return tensor
|
|
755
|
+
)
|
konfai/predictor.py
CHANGED
|
@@ -16,7 +16,7 @@ from torch.utils.tensorboard.writer import SummaryWriter
|
|
|
16
16
|
from konfai import config_file, konfai_root, models_directory, path_to_models, predictions_directory
|
|
17
17
|
from konfai.data.data_manager import DataPrediction, DatasetIter
|
|
18
18
|
from konfai.data.patching import Accumulator, PathCombine
|
|
19
|
-
from konfai.data.transform import Transform, TransformLoader
|
|
19
|
+
from konfai.data.transform import Transform, TransformInverse, TransformLoader
|
|
20
20
|
from konfai.network.network import CPUModel, ModelLoader, NetState, Network
|
|
21
21
|
from konfai.utils.config import config
|
|
22
22
|
from konfai.utils.dataset import Attribute, Dataset
|
|
@@ -226,11 +226,12 @@ class OutSameAsGroupDataset(OutputDataset):
|
|
|
226
226
|
|
|
227
227
|
if self.inverse_transform:
|
|
228
228
|
for transform in reversed(dataset.groups_src[self.group_src][self.group_dest].patch_transforms):
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
229
|
+
if isinstance(transform, TransformInverse) and transform.apply_inverse:
|
|
230
|
+
layer = transform.inverse(
|
|
231
|
+
self.names[index_dataset],
|
|
232
|
+
layer,
|
|
233
|
+
self.attributes[index_dataset][index_augmentation][index_patch],
|
|
234
|
+
)
|
|
234
235
|
self.output_layer_accumulator[index_dataset][index_augmentation].add_layer(index_patch, layer)
|
|
235
236
|
|
|
236
237
|
def load(self, name_layer: str, datasets: list[Dataset], groups: dict[str, str]):
|
|
@@ -277,7 +278,8 @@ class OutSameAsGroupDataset(OutputDataset):
|
|
|
277
278
|
|
|
278
279
|
if self.inverse_transform:
|
|
279
280
|
for transform in reversed(dataset.groups_src[self.group_src][self.group_dest].transforms):
|
|
280
|
-
|
|
281
|
+
if isinstance(transform, TransformInverse) and transform.apply_inverse:
|
|
282
|
+
result = transform.inverse(self.names[index], result, self.attributes[index][0][0])
|
|
281
283
|
|
|
282
284
|
for transform in self.final_transforms:
|
|
283
285
|
result = transform(self.names[index], result, self.attributes[index][0][0])
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: konfai
|
|
3
|
-
Version: 1.2.5
|
|
3
|
+
Version: 1.2.6
|
|
4
4
|
Summary: Modular and configurable Deep Learning framework with YAML and PyTorch
|
|
5
5
|
Author-email: Valentin Boussot <boussot.v@gmail.com>
|
|
6
6
|
License-Expression: Apache-2.0
|
|
@@ -21,6 +21,7 @@ Requires-Dist: SimpleITK
|
|
|
21
21
|
Requires-Dist: lxml
|
|
22
22
|
Requires-Dist: h5py
|
|
23
23
|
Requires-Dist: pynvml
|
|
24
|
+
Requires-Dist: requests
|
|
24
25
|
Provides-Extra: vtk
|
|
25
26
|
Requires-Dist: vtk; extra == "vtk"
|
|
26
27
|
Provides-Extra: lpips
|
|
@@ -1,13 +1,13 @@
|
|
|
1
1
|
konfai/__init__.py,sha256=qjE9Rqxo1sMrkqGS8I5xlGQMZnjIfU-CGgSI5Wmbmbs,1231
|
|
2
2
|
konfai/evaluator.py,sha256=xAKWUDvdSxqYRUsKqH6ieQF06LWa785aE4zLv4I3_i4,17850
|
|
3
3
|
konfai/main.py,sha256=Fc4HcJEhPmgunj_f-QYyvQNvjHrKHSUv27Okgu6V5_A,3842
|
|
4
|
-
konfai/predictor.py,sha256=
|
|
4
|
+
konfai/predictor.py,sha256=fImktAvpTUZMUsf3v1fG95aLrXkUb9Mpbvnev3hA9Bc,34924
|
|
5
5
|
konfai/trainer.py,sha256=g_TkPDUjToFGDGB7aaRZMn-fQllHV_I2GHFKUzDGF8o,27106
|
|
6
6
|
konfai/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
7
|
-
konfai/data/augmentation.py,sha256=
|
|
8
|
-
konfai/data/data_manager.py,sha256=
|
|
7
|
+
konfai/data/augmentation.py,sha256=7jrWcpw61t3cCIdHtUlnekRO7JwXIM5Q6RIXY8Ya-xM,27796
|
|
8
|
+
konfai/data/data_manager.py,sha256=tZ2DZHDW4UySCCzwEzR2WIL0fTp7lqAfqEbNPiEw5NE,31064
|
|
9
9
|
konfai/data/patching.py,sha256=jS35OxnJagKNUnJu7TzuGZpVj9fP-6H4nc2OEYOGgt8,16494
|
|
10
|
-
konfai/data/transform.py,sha256=
|
|
10
|
+
konfai/data/transform.py,sha256=LXzvPTgLyqojVn_sRLwU_07FiI3fLSDINKRbynZBiHQ,30318
|
|
11
11
|
konfai/metric/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
12
12
|
konfai/metric/measure.py,sha256=0mOIZKTa2u0UECpoDSbdJUhttAw_e1BlsROQQpi1oBk,27804
|
|
13
13
|
konfai/metric/schedulers.py,sha256=TpYMA24FMpxRnqfhMGb0i_Mm-bzT9kySbBgvkYk-6wM,1327
|
|
@@ -30,9 +30,9 @@ konfai/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
|
30
30
|
konfai/utils/config.py,sha256=a7t44CYMUT5oCDdjL94IswhCVfFbQ5FCgDWZktDDkc4,14347
|
|
31
31
|
konfai/utils/dataset.py,sha256=Au22fcADKyDJMfS8Z9q8kEXLtKkoufJsH7Pwly6pALo,28288
|
|
32
32
|
konfai/utils/utils.py,sha256=jCj3tZ8agQYceSY_tlVYp88UFPE5oUn6tXrqnZGrKiI,28410
|
|
33
|
-
konfai-1.2.
|
|
34
|
-
konfai-1.2.
|
|
35
|
-
konfai-1.2.
|
|
36
|
-
konfai-1.2.
|
|
37
|
-
konfai-1.2.
|
|
38
|
-
konfai-1.2.
|
|
33
|
+
konfai-1.2.6.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
|
34
|
+
konfai-1.2.6.dist-info/METADATA,sha256=yDZpZnLUABNQK1alsMDB0m6dHjXxRyxACm1FSAB49aQ,2475
|
|
35
|
+
konfai-1.2.6.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
36
|
+
konfai-1.2.6.dist-info/entry_points.txt,sha256=fG82HRN5-g39ACSOCtij_I3N6EHxfYnMR0D7TI_8pW8,81
|
|
37
|
+
konfai-1.2.6.dist-info/top_level.txt,sha256=xF470dkIlFoFqTZEOlRehKJr4WU_8OKGXrJqYm9vWKs,7
|
|
38
|
+
konfai-1.2.6.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|