konfai 1.2.5__py3-none-any.whl → 1.2.7__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of konfai might be problematic.
- konfai/data/augmentation.py +1 -1
- konfai/data/data_manager.py +1 -1
- konfai/data/patching.py +2 -2
- konfai/data/transform.py +91 -80
- konfai/predictor.py +9 -7
- {konfai-1.2.5.dist-info → konfai-1.2.7.dist-info}/METADATA +2 -1
- {konfai-1.2.5.dist-info → konfai-1.2.7.dist-info}/RECORD +11 -11
- {konfai-1.2.5.dist-info → konfai-1.2.7.dist-info}/WHEEL +0 -0
- {konfai-1.2.5.dist-info → konfai-1.2.7.dist-info}/entry_points.txt +0 -0
- {konfai-1.2.5.dist-info → konfai-1.2.7.dist-info}/licenses/LICENSE +0 -0
- {konfai-1.2.5.dist-info → konfai-1.2.7.dist-info}/top_level.txt +0 -0
konfai/data/augmentation.py
CHANGED
konfai/data/data_manager.py
CHANGED
@@ -515,7 +515,7 @@ class Data(ABC):
                 self.groups_src[group_src][group_dest].load(
                     group_src,
                     group_dest,
-
+                    list(self.datasets.values()),
                 )
                 model_have_input |= self.groups_src[group_src][group_dest].is_input
             if self.patch is not None:
konfai/data/patching.py
CHANGED
@@ -219,7 +219,7 @@ class Patch(ABC):
             )
             slices = [s] + list(self._patch_slices[a][index][1:])
             data_sliced = data[slices_pre + slices]
-            if data_sliced.shape[len(slices_pre)] < bottom + top + 1:
+            if extend_slice > 0 and data_sliced.shape[len(slices_pre)] < bottom + top + 1:
                 pad_bottom = 0
                 pad_top = 0
                 if self._patch_slices[a][index][0].start - bottom < 0:
@@ -335,7 +335,7 @@ class DatasetManager:
         self.data: list[torch.Tensor] = []

         for transform_function in transforms:
-            _shape = transform_function.transform_shape(_shape, cache_attribute)
+            _shape = transform_function.transform_shape(self.name, _shape, cache_attribute)

         self.patch = (
             DatasetPatch(
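In 1.2.7 the DatasetManager forwards its own name into every transform's transform_shape call, matching the signature change in konfai/data/transform.py below. A minimal sketch of the new call convention, using only the signatures visible in this diff (the helper name shape_after_transforms is hypothetical, not part of konfai):

def shape_after_transforms(name: str, shape: list[int], transforms, cache_attribute) -> list[int]:
    # As of 1.2.7, transform_shape takes the sample/group name as its first argument;
    # the shape is threaded through each transform in pipeline order.
    for transform_function in transforms:
        shape = transform_function.transform_shape(name, shape, cache_attribute)
    return shape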
konfai/data/transform.py
CHANGED
@@ -21,13 +21,19 @@ class Transform(NeedDevice, ABC):
     def set_datasets(self, datasets: list[Dataset]):
         self.datasets = datasets

-    def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
+    def transform_shape(self, name: str, shape: list[int], cache_attribute: Attribute) -> list[int]:
         return shape

     @abstractmethod
     def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
         pass

+class TransformInverse(Transform, ABC):
+
+    def __init__(self, inverse: bool) -> None:
+        super().__init__()
+        self.apply_inverse = inverse
+
     @abstractmethod
     def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
         pass
@@ -52,8 +58,9 @@ class Clip(Transform):
         max_value: float | str = 1024,
         save_clip_min: bool = False,
         save_clip_max: bool = False,
-        mask: str | None = None
+        mask: str | None = None,
     ) -> None:
+        super().__init__()
         if isinstance(min_value, float) and isinstance(max_value, float) and max_value <= min_value:
             raise ValueError(
                 f"[Clip] Invalid clipping range: max_value ({max_value}) must be greater than min_value ({min_value})"
@@ -72,7 +79,10 @@ class Clip(Transform):
                 mask, _ = dataset.read_data(self.mask, name)
                 break
         if mask is None and self.mask is not None:
-            raise ValueError(
+            raise ValueError(
+                f"Requested mask '{self.mask}' is not present in any dataset. "
+                "Check your dataset group names or configuration."
+            )
         if mask is None:
             tensor_masked = tensor
         else:
@@ -88,7 +98,10 @@ class Clip(Transform):
                 except (IndexError, ValueError):
                     raise ValueError(f"Invalid format for min_value: '{self.min_value}'. Expected 'percentile:<float>'")
             else:
-                raise TypeError(
+                raise TypeError(
+                    f"Unsupported string for min_value: '{self.min_value}'."
+                    "Must be a float, 'min', or 'percentile:<float>'."
+                )
         else:
             min_value = self.min_value

@@ -102,10 +115,13 @@ class Clip(Transform):
                 except (IndexError, ValueError):
                     raise ValueError(f"Invalid format for max_value: '{self.max_value}'. Expected 'percentile:<float>'")
             else:
-                raise TypeError(
+                raise TypeError(
+                    f"Unsupported string for max_value: '{self.max_value}'."
+                    " Must be a float, 'max', or 'percentile:<float>'."
+                )
         else:
             max_value = self.max_value
-
+
         tensor[torch.where(tensor < min_value)] = min_value
         tensor[torch.where(tensor > max_value)] = max_value
         if self.save_clip_min:
@@ -114,11 +130,8 @@ class Clip(Transform):
             cache_attribute["Max"] = max_value
         return tensor

-    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        return tensor
-

-class Normalize(Transform):
+class Normalize(TransformInverse):

     def __init__(
         self,
@@ -126,7 +139,9 @@ class Normalize(Transform):
         channels: list[int] | None = None,
         min_value: float = -1,
         max_value: float = 1,
+        inverse: bool = True
     ) -> None:
+        super().__init__(inverse)
         if max_value <= min_value:
             raise ValueError(
                 f"[Normalize] Invalid range: max_value ({max_value}) must be greater than min_value ({min_value})"
@@ -179,15 +194,17 @@ class Normalize(Transform):
         return (tensor - self.min_value) * (input_max - input_min) / (self.max_value - self.min_value) + input_min


-class Standardize(Transform):
+class Standardize(TransformInverse):

     def __init__(
         self,
         lazy: bool = False,
         mean: list[float] | None = None,
         std: list[float] | None = None,
-        mask: str | None = None
+        mask: str | None = None,
+        inverse: bool = True
     ) -> None:
+        super().__init__(inverse)
         self.lazy = lazy
         self.mean = mean
         self.std = std
@@ -201,19 +218,25 @@ class Standardize(Transform):
                 mask, _ = dataset.read_data(self.mask, name)
                 break
         if mask is None and self.mask is not None:
-            raise ValueError(
+            raise ValueError(
+                f"Requested mask '{self.mask}' is not present in any dataset."
+                " Check your dataset group names or configuration."
+            )
         if mask is None:
             tensor_masked = tensor
         else:
             tensor_masked = tensor[mask == 1]

         if "Mean" not in cache_attribute:
-            cache_attribute["Mean"] =
-
+            cache_attribute["Mean"] = (
+                torch.tensor([torch.mean(tensor_masked.type(torch.float32))])
+                if self.mean is None
+                else torch.tensor([self.mean])
+            )
+
         if "Std" not in cache_attribute:
             cache_attribute["Std"] = (
-                torch.tensor([torch.std(
-                    tensor_masked.type(torch.float32))])
+                torch.tensor([torch.std(tensor_masked.type(torch.float32))])
                 if self.std is None
                 else torch.tensor([self.std])
             )
@@ -233,9 +256,10 @@ class Standardize(Transform):
         return tensor * std + mean


-class TensorCast(Transform):
+class TensorCast(TransformInverse):

-    def __init__(self, dtype: str = "float32") -> None:
+    def __init__(self, dtype: str = "float32", inverse: bool = True) -> None:
+        super().__init__(inverse)
         self.dtype: torch.dtype = getattr(torch, dtype)

     def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
@@ -253,9 +277,10 @@ class TensorCast(Transform):
         return tensor.to(TensorCast.safe_dtype_cast(cache_attribute.pop("dtype")))


-class Padding(Transform):
+class Padding(TransformInverse):

-    def __init__(self, padding: list[int] = [0, 0, 0, 0, 0, 0], mode: str = "constant") -> None:
+    def __init__(self, padding: list[int] = [0, 0, 0, 0, 0, 0], mode: str = "constant", inverse: bool = True) -> None:
+        super().__init__(inverse)
         self.padding = padding
         self.mode = mode

@@ -275,7 +300,7 @@ class Padding(Transform):
         ).squeeze(0)
         return result

-    def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
+    def transform_shape(self, name: str, shape: list[int], cache_attribute: Attribute) -> list[int]:
         for dim in range(len(self.padding) // 2):
             shape[-dim - 1] += sum(self.padding[dim * 2 : dim * 2 + 2])
         return shape
@@ -290,9 +315,10 @@ class Padding(Transform):
         return result


-class Squeeze(Transform):
+class Squeeze(TransformInverse):

-    def __init__(self, dim: int) -> None:
+    def __init__(self, dim: int, inverse: bool = True) -> None:
+        super().__init__(inverse)
         self.dim = dim

     def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
@@ -302,10 +328,10 @@ class Squeeze(Transform):
         return tensor.unsqueeze(self.dim)


-class Resample(Transform, ABC):
+class Resample(TransformInverse, ABC):

-    def __init__(self) -> None:
-
+    def __init__(self, inverse: bool) -> None:
+        super().__init__(inverse)

     def _resample(self, tensor: torch.Tensor, size: list[int]) -> torch.Tensor:
         if tensor.dtype == torch.uint8:
@@ -326,7 +352,7 @@ class Resample(Transform, ABC):
         pass

     @abstractmethod
-    def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
+    def transform_shape(self, name: str, shape: list[int], cache_attribute: Attribute) -> list[int]:
         pass

     def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
@@ -338,7 +364,8 @@ class Resample(Transform, ABC):

 class ResampleToResolution(Resample):

-    def __init__(self, spacing: list[float] = [1.0, 1.0, 1.0]) -> None:
+    def __init__(self, spacing: list[float] = [1.0, 1.0, 1.0], inverse: bool = True) -> None:
+        super().__init__(inverse)
         self.spacing = torch.tensor([0 if s < 0 else s for s in spacing])

     def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
@@ -349,35 +376,37 @@ class ResampleToResolution(Resample):
         )
         if len(shape) != len(self.spacing):
             TransformError("Shape and spacing dimensions do not match: shape={shape}, spacing={self.spacing}")
-        image_spacing = cache_attribute.get_tensor("Spacing")
+        image_spacing = cache_attribute.get_tensor("Spacing")
         spacing = self.spacing

         for i, s in enumerate(self.spacing):
             if s == 0:
                 spacing[i] = image_spacing[i]
-        resize_factor = spacing /
+        resize_factor = spacing / image_spacing
         return [int(x) for x in (torch.tensor(shape) * 1 / resize_factor)]

     def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        image_spacing = cache_attribute.get_tensor("Spacing")
+        image_spacing = cache_attribute.get_tensor("Spacing")
         spacing = self.spacing
         for i, s in enumerate(self.spacing):
             if s == 0:
                 spacing[i] = image_spacing[i]
-        resize_factor = spacing / cache_attribute.get_tensor("Spacing")
-        cache_attribute["Spacing"] = spacing
+        resize_factor = spacing / cache_attribute.get_tensor("Spacing")
+        cache_attribute["Spacing"] = spacing
         cache_attribute["Size"] = np.asarray([int(x) for x in torch.tensor(tensor.shape[1:])])
         size = [int(x) for x in (torch.tensor(tensor.shape[1:]) * 1 / resize_factor)]
         cache_attribute["Size"] = np.asarray(size)
         return self._resample(tensor, size)


+
 class ResampleToShape(Resample):

-    def __init__(self, shape: list[float] = [100, 256, 256]) -> None:
+    def __init__(self, shape: list[float] = [100, 256, 256], inverse: bool = True) -> None:
+        super().__init__(inverse)
         self.shape = torch.tensor([0 if s < 0 else s for s in shape])

-    def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
+    def transform_shape(self, name: str, shape: list[int], cache_attribute: Attribute) -> list[int]:
         if "Spacing" not in cache_attribute:
             TransformError(
                 "Missing 'Spacing' in cache attributes, the data is likely not a valid image.",
@@ -407,12 +436,13 @@ class ResampleToShape(Resample):
         return self._resample(tensor, shape)


-class ResampleTransform(Transform):
+class ResampleTransform(TransformInverse):

-    def __init__(self, transforms: dict[str, bool]) -> None:
+    def __init__(self, transforms: dict[str, bool], inverse: bool = True) -> None:
+        super().__init__(inverse)
         self.transforms = transforms

-    def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
+    def transform_shape(self, name: str, shape: list[int], cache_attribute: Attribute) -> list[int]:
         return shape

     def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
@@ -485,6 +515,7 @@ class ResampleTransform(Transform):
 class Mask(Transform):

     def __init__(self, path: str = "./default.mha", value_outside: int = 0) -> None:
+        super().__init__()
         self.path = path
         self.value_outside = value_outside

@@ -501,13 +532,11 @@ class Mask(Transform):
             raise NameError(f"Mask : {self.path}/{name} not found")
         return torch.where(torch.tensor(mask) > 0, tensor, self.value_outside)

-    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        return tensor
-

 class Gradient(Transform):

     def __init__(self, per_dim: bool = False):
+        super().__init__()
         self.per_dim = per_dim

     @staticmethod
@@ -541,25 +570,21 @@ class Gradient(Transform):

         return result

-    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        return tensor
-
-
 class Argmax(Transform):

     def __init__(self, dim: int = 0) -> None:
+        super().__init__()
         self.dim = dim

     def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
         return torch.argmax(tensor, dim=self.dim).unsqueeze(self.dim)

-    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        return tensor


 class FlatLabel(Transform):

     def __init__(self, labels: list[int] | None = None) -> None:
+        super().__init__()
         self.labels = labels

     def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
@@ -571,44 +596,36 @@ class FlatLabel(Transform):
         data[torch.where(tensor > 0)] = 1
         return data

-    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        return tensor
-

 class Save(Transform):

     def __init__(self, dataset: str) -> None:
+        super().__init__()
         self.dataset = dataset

     def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
         return tensor

-    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        return tensor
-

 class Flatten(Transform):

     def __init__(self) -> None:
         super().__init__()

-    def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
+    def transform_shape(self, name: str, shape: list[int], cache_attribute: Attribute) -> list[int]:
         return [np.prod(np.asarray(shape))]

     def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
         return tensor.flatten()

-    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        return tensor

+class Permute(TransformInverse):

-class Permute(Transform):
-
-    def __init__(self, dims: str = "1|0|2") -> None:
-        super().__init__()
+    def __init__(self, dims: str = "1|0|2", inverse: bool = True) -> None:
+        super().__init__(inverse)
         self.dims = [0] + [int(d) + 1 for d in dims.split("|")]

-    def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
+    def transform_shape(self, name: str, shape: list[int], cache_attribute: Attribute) -> list[int]:
         return [shape[it - 1] for it in self.dims[1:]]

     def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
@@ -618,10 +635,10 @@ class Permute(Transform):
         return tensor.permute(tuple(np.argsort(self.dims)))


-class Flip(Transform):
+class Flip(TransformInverse):

-    def __init__(self, dims: str = "1|0|2") -> None:
-        super().__init__()
+    def __init__(self, dims: str = "1|0|2", inverse: bool = True) -> None:
+        super().__init__(inverse)

         self.dims = [int(d) + 1 for d in str(dims).split("|")]

@@ -632,9 +649,10 @@ class Flip(Transform):
         return tensor.flip(tuple(self.dims))


-class Canonical(Transform):
+class Canonical(TransformInverse):

-    def __init__(self) -> None:
+    def __init__(self, inverse: bool = True) -> None:
+        super().__init__(inverse)
         self.canonical_direction = torch.diag(torch.tensor([-1, -1, 1])).to(torch.double)

     def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
@@ -667,6 +685,7 @@ class Canonical(Transform):
 class HistogramMatching(Transform):

     def __init__(self, reference_group: str) -> None:
+        super().__init__()
         self.reference_group = reference_group

     def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
@@ -684,13 +703,11 @@ class HistogramMatching(Transform):
         result, _ = image_to_data(matcher.Execute(image, image_ref))
         return torch.tensor(result)

-    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        return tensor
-

 class SelectLabel(Transform):

     def __init__(self, labels: list[str]) -> None:
+        super().__init__()
         self.labels = [label[1:-1].split(",") for label in labels]

     def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
@@ -699,14 +716,11 @@ class SelectLabel(Transform):
             data[tensor == int(old_label)] = int(new_label)
         return data

-    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        return tensor

+class OneHot(TransformInverse):

-
-
-    def __init__(self, num_classes: int) -> None:
-        super().__init__()
+    def __init__(self, num_classes: int, inverse: bool = True) -> None:
+        super().__init__(inverse)
         self.num_classes = num_classes

     def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
@@ -739,7 +753,4 @@ class TotalSegmentator(Transform):
             torch.from_numpy(np.array(np.asanyarray(seg.dataobj), copy=True).astype(np.uint8, copy=False))
             .permute(2, 1, 0)
             .unsqueeze(0)
-        )
-
-    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        return tensor
+        )
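Taken together, these hunks split the transform API in two: transforms that can be undone now derive from the new TransformInverse base class and accept an inverse flag in their constructor (stored as apply_inverse), the transform_shape hook gains a leading name argument, and plain Transform subclasses simply call super().__init__() and drop their former no-op inverse methods. A minimal sketch of a custom invertible transform against the 1.2.7 API, assuming only the class and method signatures shown above; ScaleBy and its scaling logic are hypothetical and not part of konfai:

import torch

from konfai.data.transform import TransformInverse
from konfai.utils.dataset import Attribute


class ScaleBy(TransformInverse):

    def __init__(self, factor: float = 2.0, inverse: bool = True) -> None:
        super().__init__(inverse)  # TransformInverse stores the flag as self.apply_inverse
        self.factor = factor

    def transform_shape(self, name: str, shape: list[int], cache_attribute: Attribute) -> list[int]:
        # Intensity scaling does not change the spatial shape.
        return shape

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        return tensor * self.factor

    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        # Only called by the predictor when apply_inverse is True.
        return tensor / self.factor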
konfai/predictor.py
CHANGED
@@ -16,7 +16,7 @@ from torch.utils.tensorboard.writer import SummaryWriter
 from konfai import config_file, konfai_root, models_directory, path_to_models, predictions_directory
 from konfai.data.data_manager import DataPrediction, DatasetIter
 from konfai.data.patching import Accumulator, PathCombine
-from konfai.data.transform import Transform, TransformLoader
+from konfai.data.transform import Transform, TransformInverse, TransformLoader
 from konfai.network.network import CPUModel, ModelLoader, NetState, Network
 from konfai.utils.config import config
 from konfai.utils.dataset import Attribute, Dataset
@@ -226,11 +226,12 @@ class OutSameAsGroupDataset(OutputDataset):

         if self.inverse_transform:
             for transform in reversed(dataset.groups_src[self.group_src][self.group_dest].patch_transforms):
-
-
-
-
-
+                if isinstance(transform, TransformInverse) and transform.apply_inverse:
+                    layer = transform.inverse(
+                        self.names[index_dataset],
+                        layer,
+                        self.attributes[index_dataset][index_augmentation][index_patch],
+                    )
         self.output_layer_accumulator[index_dataset][index_augmentation].add_layer(index_patch, layer)

     def load(self, name_layer: str, datasets: list[Dataset], groups: dict[str, str]):
@@ -277,7 +278,8 @@ class OutSameAsGroupDataset(OutputDataset):

         if self.inverse_transform:
             for transform in reversed(dataset.groups_src[self.group_src][self.group_dest].transforms):
-
+                if isinstance(transform, TransformInverse) and transform.apply_inverse:
+                    result = transform.inverse(self.names[index], result, self.attributes[index][0][0])

         for transform in self.final_transforms:
             result = transform(self.names[index], result, self.attributes[index][0][0])
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: konfai
|
|
3
|
-
Version: 1.2.
|
|
3
|
+
Version: 1.2.7
|
|
4
4
|
Summary: Modular and configurable Deep Learning framework with YAML and PyTorch
|
|
5
5
|
Author-email: Valentin Boussot <boussot.v@gmail.com>
|
|
6
6
|
License-Expression: Apache-2.0
|
|
@@ -21,6 +21,7 @@ Requires-Dist: SimpleITK
|
|
|
21
21
|
Requires-Dist: lxml
|
|
22
22
|
Requires-Dist: h5py
|
|
23
23
|
Requires-Dist: pynvml
|
|
24
|
+
Requires-Dist: requests
|
|
24
25
|
Provides-Extra: vtk
|
|
25
26
|
Requires-Dist: vtk; extra == "vtk"
|
|
26
27
|
Provides-Extra: lpips
|
|
{konfai-1.2.5.dist-info → konfai-1.2.7.dist-info}/RECORD
CHANGED
@@ -1,13 +1,13 @@
 konfai/__init__.py,sha256=qjE9Rqxo1sMrkqGS8I5xlGQMZnjIfU-CGgSI5Wmbmbs,1231
 konfai/evaluator.py,sha256=xAKWUDvdSxqYRUsKqH6ieQF06LWa785aE4zLv4I3_i4,17850
 konfai/main.py,sha256=Fc4HcJEhPmgunj_f-QYyvQNvjHrKHSUv27Okgu6V5_A,3842
-konfai/predictor.py,sha256=
+konfai/predictor.py,sha256=fImktAvpTUZMUsf3v1fG95aLrXkUb9Mpbvnev3hA9Bc,34924
 konfai/trainer.py,sha256=g_TkPDUjToFGDGB7aaRZMn-fQllHV_I2GHFKUzDGF8o,27106
 konfai/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-konfai/data/augmentation.py,sha256=
-konfai/data/data_manager.py,sha256=
-konfai/data/patching.py,sha256=
-konfai/data/transform.py,sha256=
+konfai/data/augmentation.py,sha256=7jrWcpw61t3cCIdHtUlnekRO7JwXIM5Q6RIXY8Ya-xM,27796
+konfai/data/data_manager.py,sha256=tZ2DZHDW4UySCCzwEzR2WIL0fTp7lqAfqEbNPiEw5NE,31064
+konfai/data/patching.py,sha256=P0TcjR4qcUWpB_Uph0-dd8bMeNVJC_IGNK_jkxStglQ,16526
+konfai/data/transform.py,sha256=VVDaQOBdgvHhOdoJlo9fkOJ2ln8th_0u8go1I4Uy7eo,30332
 konfai/metric/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 konfai/metric/measure.py,sha256=0mOIZKTa2u0UECpoDSbdJUhttAw_e1BlsROQQpi1oBk,27804
 konfai/metric/schedulers.py,sha256=TpYMA24FMpxRnqfhMGb0i_Mm-bzT9kySbBgvkYk-6wM,1327
@@ -30,9 +30,9 @@ konfai/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 konfai/utils/config.py,sha256=a7t44CYMUT5oCDdjL94IswhCVfFbQ5FCgDWZktDDkc4,14347
 konfai/utils/dataset.py,sha256=Au22fcADKyDJMfS8Z9q8kEXLtKkoufJsH7Pwly6pALo,28288
 konfai/utils/utils.py,sha256=jCj3tZ8agQYceSY_tlVYp88UFPE5oUn6tXrqnZGrKiI,28410
-konfai-1.2.
-konfai-1.2.
-konfai-1.2.
-konfai-1.2.
-konfai-1.2.
-konfai-1.2.
+konfai-1.2.7.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+konfai-1.2.7.dist-info/METADATA,sha256=CVuc5KlcaickUPpttTyIdhdRHGm3h8bEWSeLidc5mUs,2475
+konfai-1.2.7.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+konfai-1.2.7.dist-info/entry_points.txt,sha256=fG82HRN5-g39ACSOCtij_I3N6EHxfYnMR0D7TI_8pW8,81
+konfai-1.2.7.dist-info/top_level.txt,sha256=xF470dkIlFoFqTZEOlRehKJr4WU_8OKGXrJqYm9vWKs,7
+konfai-1.2.7.dist-info/RECORD,,
{konfai-1.2.5.dist-info → konfai-1.2.7.dist-info}/WHEEL
File without changes
{konfai-1.2.5.dist-info → konfai-1.2.7.dist-info}/entry_points.txt
File without changes
{konfai-1.2.5.dist-info → konfai-1.2.7.dist-info}/licenses/LICENSE
File without changes
{konfai-1.2.5.dist-info → konfai-1.2.7.dist-info}/top_level.txt
File without changes