konfai 1.2.4.tar.gz → 1.2.6.tar.gz
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of konfai has been flagged as potentially problematic; see the registry's release page for more details.
- {konfai-1.2.4 → konfai-1.2.6}/PKG-INFO +2 -1
- {konfai-1.2.4 → konfai-1.2.6}/konfai/data/augmentation.py +1 -1
- {konfai-1.2.4 → konfai-1.2.6}/konfai/data/data_manager.py +1 -1
- {konfai-1.2.4 → konfai-1.2.6}/konfai/data/transform.py +134 -75
- {konfai-1.2.4 → konfai-1.2.6}/konfai/network/network.py +1 -2
- {konfai-1.2.4 → konfai-1.2.6}/konfai/predictor.py +12 -8
- {konfai-1.2.4 → konfai-1.2.6}/konfai.egg-info/PKG-INFO +2 -1
- {konfai-1.2.4 → konfai-1.2.6}/konfai.egg-info/requires.txt +1 -0
- {konfai-1.2.4 → konfai-1.2.6}/pyproject.toml +3 -2
- {konfai-1.2.4 → konfai-1.2.6}/LICENSE +0 -0
- {konfai-1.2.4 → konfai-1.2.6}/README.md +0 -0
- {konfai-1.2.4 → konfai-1.2.6}/konfai/__init__.py +0 -0
- {konfai-1.2.4 → konfai-1.2.6}/konfai/data/__init__.py +0 -0
- {konfai-1.2.4 → konfai-1.2.6}/konfai/data/patching.py +0 -0
- {konfai-1.2.4 → konfai-1.2.6}/konfai/evaluator.py +0 -0
- {konfai-1.2.4 → konfai-1.2.6}/konfai/main.py +0 -0
- {konfai-1.2.4 → konfai-1.2.6}/konfai/metric/__init__.py +0 -0
- {konfai-1.2.4 → konfai-1.2.6}/konfai/metric/measure.py +0 -0
- {konfai-1.2.4 → konfai-1.2.6}/konfai/metric/schedulers.py +0 -0
- {konfai-1.2.4 → konfai-1.2.6}/konfai/models/classification/convNeXt.py +0 -0
- {konfai-1.2.4 → konfai-1.2.6}/konfai/models/classification/resnet.py +0 -0
- {konfai-1.2.4 → konfai-1.2.6}/konfai/models/generation/cStyleGan.py +0 -0
- {konfai-1.2.4 → konfai-1.2.6}/konfai/models/generation/ddpm.py +0 -0
- {konfai-1.2.4 → konfai-1.2.6}/konfai/models/generation/diffusionGan.py +0 -0
- {konfai-1.2.4 → konfai-1.2.6}/konfai/models/generation/gan.py +0 -0
- {konfai-1.2.4 → konfai-1.2.6}/konfai/models/generation/vae.py +0 -0
- {konfai-1.2.4 → konfai-1.2.6}/konfai/models/registration/registration.py +0 -0
- {konfai-1.2.4 → konfai-1.2.6}/konfai/models/representation/representation.py +0 -0
- {konfai-1.2.4 → konfai-1.2.6}/konfai/models/segmentation/NestedUNet.py +0 -0
- {konfai-1.2.4 → konfai-1.2.6}/konfai/models/segmentation/UNet.py +0 -0
- {konfai-1.2.4 → konfai-1.2.6}/konfai/network/__init__.py +0 -0
- {konfai-1.2.4 → konfai-1.2.6}/konfai/network/blocks.py +0 -0
- {konfai-1.2.4 → konfai-1.2.6}/konfai/trainer.py +0 -0
- {konfai-1.2.4 → konfai-1.2.6}/konfai/utils/ITK.py +0 -0
- {konfai-1.2.4 → konfai-1.2.6}/konfai/utils/__init__.py +0 -0
- {konfai-1.2.4 → konfai-1.2.6}/konfai/utils/config.py +0 -0
- {konfai-1.2.4 → konfai-1.2.6}/konfai/utils/dataset.py +0 -0
- {konfai-1.2.4 → konfai-1.2.6}/konfai/utils/utils.py +0 -0
- {konfai-1.2.4 → konfai-1.2.6}/konfai.egg-info/SOURCES.txt +0 -0
- {konfai-1.2.4 → konfai-1.2.6}/konfai.egg-info/dependency_links.txt +0 -0
- {konfai-1.2.4 → konfai-1.2.6}/konfai.egg-info/entry_points.txt +0 -0
- {konfai-1.2.4 → konfai-1.2.6}/konfai.egg-info/top_level.txt +0 -0
- {konfai-1.2.4 → konfai-1.2.6}/setup.cfg +0 -0
- {konfai-1.2.4 → konfai-1.2.6}/tests/test_config.py +0 -0
- {konfai-1.2.4 → konfai-1.2.6}/tests/test_dataset.py +0 -0
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: konfai
|
|
3
|
-
Version: 1.2.4
|
|
3
|
+
Version: 1.2.6
|
|
4
4
|
Summary: Modular and configurable Deep Learning framework with YAML and PyTorch
|
|
5
5
|
Author-email: Valentin Boussot <boussot.v@gmail.com>
|
|
6
6
|
License-Expression: Apache-2.0
|
|
@@ -21,6 +21,7 @@ Requires-Dist: SimpleITK
|
|
|
21
21
|
Requires-Dist: lxml
|
|
22
22
|
Requires-Dist: h5py
|
|
23
23
|
Requires-Dist: pynvml
|
|
24
|
+
Requires-Dist: requests
|
|
24
25
|
Provides-Extra: vtk
|
|
25
26
|
Requires-Dist: vtk; extra == "vtk"
|
|
26
27
|
Provides-Extra: lpips
|
|
@@ -515,7 +515,7 @@ class Data(ABC):
|
|
|
515
515
|
self.groups_src[group_src][group_dest].load(
|
|
516
516
|
group_src,
|
|
517
517
|
group_dest,
|
|
518
|
-
|
|
518
|
+
list(self.datasets.values()),
|
|
519
519
|
)
|
|
520
520
|
model_have_input |= self.groups_src[group_src][group_dest].is_input
|
|
521
521
|
if self.patch is not None:
|
|
@@ -28,6 +28,12 @@ class Transform(NeedDevice, ABC):
|
|
|
28
28
|
def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
29
29
|
pass
|
|
30
30
|
|
|
31
|
+
class TransformInverse(Transform, ABC):
|
|
32
|
+
|
|
33
|
+
def __init__(self, inverse: bool) -> None:
|
|
34
|
+
super().__init__()
|
|
35
|
+
self.apply_inverse = inverse
|
|
36
|
+
|
|
31
37
|
@abstractmethod
|
|
32
38
|
def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
33
39
|
pass
|
|
@@ -48,12 +54,14 @@ class Clip(Transform):
|
|
|
48
54
|
|
|
49
55
|
def __init__(
|
|
50
56
|
self,
|
|
51
|
-
min_value: float = -1024,
|
|
52
|
-
max_value: float = 1024,
|
|
57
|
+
min_value: float | str = -1024,
|
|
58
|
+
max_value: float | str = 1024,
|
|
53
59
|
save_clip_min: bool = False,
|
|
54
60
|
save_clip_max: bool = False,
|
|
61
|
+
mask: str | None = None,
|
|
55
62
|
) -> None:
|
|
56
|
-
|
|
63
|
+
super().__init__()
|
|
64
|
+
if isinstance(min_value, float) and isinstance(max_value, float) and max_value <= min_value:
|
|
57
65
|
raise ValueError(
|
|
58
66
|
f"[Clip] Invalid clipping range: max_value ({max_value}) must be greater than min_value ({min_value})"
|
|
59
67
|
)
|
|
@@ -61,21 +69,69 @@ class Clip(Transform):
|
|
|
61
69
|
self.max_value = max_value
|
|
62
70
|
self.save_clip_min = save_clip_min
|
|
63
71
|
self.save_clip_max = save_clip_max
|
|
72
|
+
self.mask = mask
|
|
64
73
|
|
|
65
74
|
def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
66
|
-
|
|
67
|
-
|
|
75
|
+
mask = None
|
|
76
|
+
if self.mask is not None:
|
|
77
|
+
for dataset in self.datasets:
|
|
78
|
+
if dataset.is_dataset_exist(self.mask, name):
|
|
79
|
+
mask, _ = dataset.read_data(self.mask, name)
|
|
80
|
+
break
|
|
81
|
+
if mask is None and self.mask is not None:
|
|
82
|
+
raise ValueError(
|
|
83
|
+
f"Requested mask '{self.mask}' is not present in any dataset. "
|
|
84
|
+
"Check your dataset group names or configuration."
|
|
85
|
+
)
|
|
86
|
+
if mask is None:
|
|
87
|
+
tensor_masked = tensor
|
|
88
|
+
else:
|
|
89
|
+
tensor_masked = tensor[mask == 1]
|
|
90
|
+
|
|
91
|
+
if isinstance(self.min_value, str):
|
|
92
|
+
if self.min_value == "min":
|
|
93
|
+
min_value = torch.min(tensor_masked)
|
|
94
|
+
elif self.min_value.startswith("percentile:"):
|
|
95
|
+
try:
|
|
96
|
+
percentile = float(self.min_value.split(":")[1])
|
|
97
|
+
min_value = np.percentile(tensor_masked, percentile)
|
|
98
|
+
except (IndexError, ValueError):
|
|
99
|
+
raise ValueError(f"Invalid format for min_value: '{self.min_value}'. Expected 'percentile:<float>'")
|
|
100
|
+
else:
|
|
101
|
+
raise TypeError(
|
|
102
|
+
f"Unsupported string for min_value: '{self.min_value}'."
|
|
103
|
+
"Must be a float, 'min', or 'percentile:<float>'."
|
|
104
|
+
)
|
|
105
|
+
else:
|
|
106
|
+
min_value = self.min_value
|
|
107
|
+
|
|
108
|
+
if isinstance(self.max_value, str):
|
|
109
|
+
if self.max_value == "max":
|
|
110
|
+
max_value = torch.max(tensor_masked)
|
|
111
|
+
elif self.max_value.startswith("percentile:"):
|
|
112
|
+
try:
|
|
113
|
+
percentile = float(self.max_value.split(":")[1])
|
|
114
|
+
max_value = np.percentile(tensor_masked, percentile)
|
|
115
|
+
except (IndexError, ValueError):
|
|
116
|
+
raise ValueError(f"Invalid format for max_value: '{self.max_value}'. Expected 'percentile:<float>'")
|
|
117
|
+
else:
|
|
118
|
+
raise TypeError(
|
|
119
|
+
f"Unsupported string for max_value: '{self.max_value}'."
|
|
120
|
+
" Must be a float, 'max', or 'percentile:<float>'."
|
|
121
|
+
)
|
|
122
|
+
else:
|
|
123
|
+
max_value = self.max_value
|
|
124
|
+
|
|
125
|
+
tensor[torch.where(tensor < min_value)] = min_value
|
|
126
|
+
tensor[torch.where(tensor > max_value)] = max_value
|
|
68
127
|
if self.save_clip_min:
|
|
69
|
-
cache_attribute["Min"] =
|
|
128
|
+
cache_attribute["Min"] = min_value
|
|
70
129
|
if self.save_clip_max:
|
|
71
|
-
cache_attribute["Max"] =
|
|
130
|
+
cache_attribute["Max"] = max_value
|
|
72
131
|
return tensor
|
|
73
132
|
|
|
74
|
-
def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
75
|
-
return tensor
|
|
76
133
|
|
|
77
|
-
|
|
78
|
-
class Normalize(Transform):
|
|
134
|
+
class Normalize(TransformInverse):
|
|
79
135
|
|
|
80
136
|
def __init__(
|
|
81
137
|
self,
|
|
@@ -83,7 +139,9 @@ class Normalize(Transform):
|
|
|
83
139
|
channels: list[int] | None = None,
|
|
84
140
|
min_value: float = -1,
|
|
85
141
|
max_value: float = 1,
|
|
142
|
+
inverse: bool = True
|
|
86
143
|
) -> None:
|
|
144
|
+
super().__init__(inverse)
|
|
87
145
|
if max_value <= min_value:
|
|
88
146
|
raise ValueError(
|
|
89
147
|
f"[Normalize] Invalid range: max_value ({max_value}) must be greater than min_value ({min_value})"
|
|
@@ -136,43 +194,57 @@ class Normalize(Transform):
|
|
|
136
194
|
return (tensor - self.min_value) * (input_max - input_min) / (self.max_value - self.min_value) + input_min
|
|
137
195
|
|
|
138
196
|
|
|
139
|
-
class Standardize(
|
|
197
|
+
class Standardize(TransformInverse):
|
|
140
198
|
|
|
141
199
|
def __init__(
|
|
142
200
|
self,
|
|
143
201
|
lazy: bool = False,
|
|
144
202
|
mean: list[float] | None = None,
|
|
145
203
|
std: list[float] | None = None,
|
|
204
|
+
mask: str | None = None,
|
|
205
|
+
inverse: bool = True
|
|
146
206
|
) -> None:
|
|
207
|
+
super().__init__(inverse)
|
|
147
208
|
self.lazy = lazy
|
|
148
209
|
self.mean = mean
|
|
149
210
|
self.std = std
|
|
211
|
+
self.mask = mask
|
|
150
212
|
|
|
151
213
|
def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
214
|
+
mask = None
|
|
215
|
+
if self.mask is not None:
|
|
216
|
+
for dataset in self.datasets:
|
|
217
|
+
if dataset.is_dataset_exist(self.mask, name):
|
|
218
|
+
mask, _ = dataset.read_data(self.mask, name)
|
|
219
|
+
break
|
|
220
|
+
if mask is None and self.mask is not None:
|
|
221
|
+
raise ValueError(
|
|
222
|
+
f"Requested mask '{self.mask}' is not present in any dataset."
|
|
223
|
+
" Check your dataset group names or configuration."
|
|
224
|
+
)
|
|
225
|
+
if mask is None:
|
|
226
|
+
tensor_masked = tensor
|
|
227
|
+
else:
|
|
228
|
+
tensor_masked = tensor[mask == 1]
|
|
229
|
+
|
|
152
230
|
if "Mean" not in cache_attribute:
|
|
153
231
|
cache_attribute["Mean"] = (
|
|
154
|
-
torch.mean(
|
|
155
|
-
tensor.type(torch.float32),
|
|
156
|
-
dim=[i + 1 for i in range(len(tensor.shape) - 1)],
|
|
157
|
-
)
|
|
232
|
+
torch.tensor([torch.mean(tensor_masked.type(torch.float32))])
|
|
158
233
|
if self.mean is None
|
|
159
234
|
else torch.tensor([self.mean])
|
|
160
235
|
)
|
|
236
|
+
|
|
161
237
|
if "Std" not in cache_attribute:
|
|
162
238
|
cache_attribute["Std"] = (
|
|
163
|
-
torch.std(
|
|
164
|
-
tensor.type(torch.float32),
|
|
165
|
-
dim=[i + 1 for i in range(len(tensor.shape) - 1)],
|
|
166
|
-
)
|
|
239
|
+
torch.tensor([torch.std(tensor_masked.type(torch.float32))])
|
|
167
240
|
if self.std is None
|
|
168
241
|
else torch.tensor([self.std])
|
|
169
242
|
)
|
|
170
|
-
|
|
171
243
|
if self.lazy:
|
|
172
244
|
return tensor
|
|
173
245
|
else:
|
|
174
|
-
mean = cache_attribute.get_tensor("Mean")
|
|
175
|
-
std = cache_attribute.get_tensor("Std")
|
|
246
|
+
mean = cache_attribute.get_tensor("Mean")
|
|
247
|
+
std = cache_attribute.get_tensor("Std")
|
|
176
248
|
return (tensor - mean) / std
|
|
177
249
|
|
|
178
250
|
def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
@@ -184,9 +256,10 @@ class Standardize(Transform):
|
|
|
184
256
|
return tensor * std + mean
|
|
185
257
|
|
|
186
258
|
|
|
187
|
-
class TensorCast(
|
|
259
|
+
class TensorCast(TransformInverse):
|
|
188
260
|
|
|
189
|
-
def __init__(self, dtype: str = "float32") -> None:
|
|
261
|
+
def __init__(self, dtype: str = "float32", inverse: bool = True) -> None:
|
|
262
|
+
super().__init__(inverse)
|
|
190
263
|
self.dtype: torch.dtype = getattr(torch, dtype)
|
|
191
264
|
|
|
192
265
|
def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
@@ -204,9 +277,10 @@ class TensorCast(Transform):
|
|
|
204
277
|
return tensor.to(TensorCast.safe_dtype_cast(cache_attribute.pop("dtype")))
|
|
205
278
|
|
|
206
279
|
|
|
207
|
-
class Padding(
|
|
280
|
+
class Padding(TransformInverse):
|
|
208
281
|
|
|
209
|
-
def __init__(self, padding: list[int] = [0, 0, 0, 0, 0, 0], mode: str = "constant") -> None:
|
|
282
|
+
def __init__(self, padding: list[int] = [0, 0, 0, 0, 0, 0], mode: str = "constant", inverse: bool = True) -> None:
|
|
283
|
+
super().__init__(inverse)
|
|
210
284
|
self.padding = padding
|
|
211
285
|
self.mode = mode
|
|
212
286
|
|
|
@@ -241,9 +315,10 @@ class Padding(Transform):
|
|
|
241
315
|
return result
|
|
242
316
|
|
|
243
317
|
|
|
244
|
-
class Squeeze(
|
|
318
|
+
class Squeeze(TransformInverse):
|
|
245
319
|
|
|
246
|
-
def __init__(self, dim: int) -> None:
|
|
320
|
+
def __init__(self, dim: int, inverse: bool = True) -> None:
|
|
321
|
+
super().__init__(inverse)
|
|
247
322
|
self.dim = dim
|
|
248
323
|
|
|
249
324
|
def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
@@ -253,10 +328,10 @@ class Squeeze(Transform):
|
|
|
253
328
|
return tensor.unsqueeze(self.dim)
|
|
254
329
|
|
|
255
330
|
|
|
256
|
-
class Resample(
|
|
331
|
+
class Resample(TransformInverse, ABC):
|
|
257
332
|
|
|
258
|
-
def __init__(self) -> None:
|
|
259
|
-
|
|
333
|
+
def __init__(self, inverse: bool) -> None:
|
|
334
|
+
super().__init__(inverse)
|
|
260
335
|
|
|
261
336
|
def _resample(self, tensor: torch.Tensor, size: list[int]) -> torch.Tensor:
|
|
262
337
|
if tensor.dtype == torch.uint8:
|
|
@@ -289,7 +364,8 @@ class Resample(Transform, ABC):
|
|
|
289
364
|
|
|
290
365
|
class ResampleToResolution(Resample):
|
|
291
366
|
|
|
292
|
-
def __init__(self, spacing: list[float] = [1.0, 1.0, 1.0]) -> None:
|
|
367
|
+
def __init__(self, spacing: list[float] = [1.0, 1.0, 1.0], inverse: bool = True) -> None:
|
|
368
|
+
super().__init__(inverse)
|
|
293
369
|
self.spacing = torch.tensor([0 if s < 0 else s for s in spacing])
|
|
294
370
|
|
|
295
371
|
def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
|
|
@@ -325,7 +401,8 @@ class ResampleToResolution(Resample):
|
|
|
325
401
|
|
|
326
402
|
class ResampleToShape(Resample):
|
|
327
403
|
|
|
328
|
-
def __init__(self, shape: list[float] = [100, 256, 256]) -> None:
|
|
404
|
+
def __init__(self, shape: list[float] = [100, 256, 256], inverse: bool = True) -> None:
|
|
405
|
+
super().__init__(inverse)
|
|
329
406
|
self.shape = torch.tensor([0 if s < 0 else s for s in shape])
|
|
330
407
|
|
|
331
408
|
def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
|
|
@@ -358,9 +435,10 @@ class ResampleToShape(Resample):
|
|
|
358
435
|
return self._resample(tensor, shape)
|
|
359
436
|
|
|
360
437
|
|
|
361
|
-
class ResampleTransform(
|
|
438
|
+
class ResampleTransform(TransformInverse):
|
|
362
439
|
|
|
363
|
-
def __init__(self, transforms: dict[str, bool]) -> None:
|
|
440
|
+
def __init__(self, transforms: dict[str, bool], inverse: bool = True) -> None:
|
|
441
|
+
super().__init__(inverse)
|
|
364
442
|
self.transforms = transforms
|
|
365
443
|
|
|
366
444
|
def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
|
|
@@ -436,6 +514,7 @@ class ResampleTransform(Transform):
|
|
|
436
514
|
class Mask(Transform):
|
|
437
515
|
|
|
438
516
|
def __init__(self, path: str = "./default.mha", value_outside: int = 0) -> None:
|
|
517
|
+
super().__init__()
|
|
439
518
|
self.path = path
|
|
440
519
|
self.value_outside = value_outside
|
|
441
520
|
|
|
@@ -452,13 +531,11 @@ class Mask(Transform):
|
|
|
452
531
|
raise NameError(f"Mask : {self.path}/{name} not found")
|
|
453
532
|
return torch.where(torch.tensor(mask) > 0, tensor, self.value_outside)
|
|
454
533
|
|
|
455
|
-
def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
456
|
-
return tensor
|
|
457
|
-
|
|
458
534
|
|
|
459
535
|
class Gradient(Transform):
|
|
460
536
|
|
|
461
537
|
def __init__(self, per_dim: bool = False):
|
|
538
|
+
super().__init__()
|
|
462
539
|
self.per_dim = per_dim
|
|
463
540
|
|
|
464
541
|
@staticmethod
|
|
@@ -492,25 +569,21 @@ class Gradient(Transform):
|
|
|
492
569
|
|
|
493
570
|
return result
|
|
494
571
|
|
|
495
|
-
def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
496
|
-
return tensor
|
|
497
|
-
|
|
498
|
-
|
|
499
572
|
class Argmax(Transform):
|
|
500
573
|
|
|
501
574
|
def __init__(self, dim: int = 0) -> None:
|
|
575
|
+
super().__init__()
|
|
502
576
|
self.dim = dim
|
|
503
577
|
|
|
504
578
|
def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
505
579
|
return torch.argmax(tensor, dim=self.dim).unsqueeze(self.dim)
|
|
506
580
|
|
|
507
|
-
def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
508
|
-
return tensor
|
|
509
581
|
|
|
510
582
|
|
|
511
583
|
class FlatLabel(Transform):
|
|
512
584
|
|
|
513
585
|
def __init__(self, labels: list[int] | None = None) -> None:
|
|
586
|
+
super().__init__()
|
|
514
587
|
self.labels = labels
|
|
515
588
|
|
|
516
589
|
def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
@@ -522,21 +595,16 @@ class FlatLabel(Transform):
|
|
|
522
595
|
data[torch.where(tensor > 0)] = 1
|
|
523
596
|
return data
|
|
524
597
|
|
|
525
|
-
def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
526
|
-
return tensor
|
|
527
|
-
|
|
528
598
|
|
|
529
599
|
class Save(Transform):
|
|
530
600
|
|
|
531
601
|
def __init__(self, dataset: str) -> None:
|
|
602
|
+
super().__init__()
|
|
532
603
|
self.dataset = dataset
|
|
533
604
|
|
|
534
605
|
def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
535
606
|
return tensor
|
|
536
607
|
|
|
537
|
-
def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
538
|
-
return tensor
|
|
539
|
-
|
|
540
608
|
|
|
541
609
|
class Flatten(Transform):
|
|
542
610
|
|
|
@@ -549,14 +617,11 @@ class Flatten(Transform):
|
|
|
549
617
|
def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
550
618
|
return tensor.flatten()
|
|
551
619
|
|
|
552
|
-
def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
553
|
-
return tensor
|
|
554
620
|
|
|
621
|
+
class Permute(TransformInverse):
|
|
555
622
|
|
|
556
|
-
|
|
557
|
-
|
|
558
|
-
def __init__(self, dims: str = "1|0|2") -> None:
|
|
559
|
-
super().__init__()
|
|
623
|
+
def __init__(self, dims: str = "1|0|2", inverse: bool = True) -> None:
|
|
624
|
+
super().__init__(inverse)
|
|
560
625
|
self.dims = [0] + [int(d) + 1 for d in dims.split("|")]
|
|
561
626
|
|
|
562
627
|
def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
|
|
@@ -569,10 +634,10 @@ class Permute(Transform):
|
|
|
569
634
|
return tensor.permute(tuple(np.argsort(self.dims)))
|
|
570
635
|
|
|
571
636
|
|
|
572
|
-
class Flip(
|
|
637
|
+
class Flip(TransformInverse):
|
|
573
638
|
|
|
574
|
-
def __init__(self, dims: str = "1|0|2") -> None:
|
|
575
|
-
super().__init__()
|
|
639
|
+
def __init__(self, dims: str = "1|0|2", inverse: bool = True) -> None:
|
|
640
|
+
super().__init__(inverse)
|
|
576
641
|
|
|
577
642
|
self.dims = [int(d) + 1 for d in str(dims).split("|")]
|
|
578
643
|
|
|
@@ -583,9 +648,10 @@ class Flip(Transform):
|
|
|
583
648
|
return tensor.flip(tuple(self.dims))
|
|
584
649
|
|
|
585
650
|
|
|
586
|
-
class Canonical(
|
|
651
|
+
class Canonical(TransformInverse):
|
|
587
652
|
|
|
588
|
-
def __init__(self) -> None:
|
|
653
|
+
def __init__(self, inverse: bool = True) -> None:
|
|
654
|
+
super().__init__(inverse)
|
|
589
655
|
self.canonical_direction = torch.diag(torch.tensor([-1, -1, 1])).to(torch.double)
|
|
590
656
|
|
|
591
657
|
def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
@@ -618,6 +684,7 @@ class Canonical(Transform):
|
|
|
618
684
|
class HistogramMatching(Transform):
|
|
619
685
|
|
|
620
686
|
def __init__(self, reference_group: str) -> None:
|
|
687
|
+
super().__init__()
|
|
621
688
|
self.reference_group = reference_group
|
|
622
689
|
|
|
623
690
|
def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
@@ -635,13 +702,11 @@ class HistogramMatching(Transform):
|
|
|
635
702
|
result, _ = image_to_data(matcher.Execute(image, image_ref))
|
|
636
703
|
return torch.tensor(result)
|
|
637
704
|
|
|
638
|
-
def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
639
|
-
return tensor
|
|
640
|
-
|
|
641
705
|
|
|
642
706
|
class SelectLabel(Transform):
|
|
643
707
|
|
|
644
708
|
def __init__(self, labels: list[str]) -> None:
|
|
709
|
+
super().__init__()
|
|
645
710
|
self.labels = [label[1:-1].split(",") for label in labels]
|
|
646
711
|
|
|
647
712
|
def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
@@ -650,14 +715,11 @@ class SelectLabel(Transform):
|
|
|
650
715
|
data[tensor == int(old_label)] = int(new_label)
|
|
651
716
|
return data
|
|
652
717
|
|
|
653
|
-
def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
654
|
-
return tensor
|
|
655
718
|
|
|
719
|
+
class OneHot(TransformInverse):
|
|
656
720
|
|
|
657
|
-
|
|
658
|
-
|
|
659
|
-
def __init__(self, num_classes: int) -> None:
|
|
660
|
-
super().__init__()
|
|
721
|
+
def __init__(self, num_classes: int, inverse: bool = True) -> None:
|
|
722
|
+
super().__init__(inverse)
|
|
661
723
|
self.num_classes = num_classes
|
|
662
724
|
|
|
663
725
|
def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
@@ -690,7 +752,4 @@ class TotalSegmentator(Transform):
|
|
|
690
752
|
torch.from_numpy(np.array(np.asanyarray(seg.dataobj), copy=True).astype(np.uint8, copy=False))
|
|
691
753
|
.permute(2, 1, 0)
|
|
692
754
|
.unsqueeze(0)
|
|
693
|
-
)
|
|
694
|
-
|
|
695
|
-
def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
696
|
-
return tensor
|
|
755
|
+
)
|
|
@@ -6,8 +6,7 @@ from collections import OrderedDict
|
|
|
6
6
|
from collections.abc import Callable, Iterable, Iterator, Sequence
|
|
7
7
|
from enum import Enum
|
|
8
8
|
from functools import partial
|
|
9
|
-
from typing import Any
|
|
10
|
-
from typing_extensions import Self
|
|
9
|
+
from typing import Any, Self
|
|
11
10
|
|
|
12
11
|
import numpy as np
|
|
13
12
|
import torch
|
|
@@ -16,7 +16,7 @@ from torch.utils.tensorboard.writer import SummaryWriter
|
|
|
16
16
|
from konfai import config_file, konfai_root, models_directory, path_to_models, predictions_directory
|
|
17
17
|
from konfai.data.data_manager import DataPrediction, DatasetIter
|
|
18
18
|
from konfai.data.patching import Accumulator, PathCombine
|
|
19
|
-
from konfai.data.transform import Transform, TransformLoader
|
|
19
|
+
from konfai.data.transform import Transform, TransformInverse, TransformLoader
|
|
20
20
|
from konfai.network.network import CPUModel, ModelLoader, NetState, Network
|
|
21
21
|
from konfai.utils.config import config
|
|
22
22
|
from konfai.utils.dataset import Attribute, Dataset
|
|
@@ -226,11 +226,12 @@ class OutSameAsGroupDataset(OutputDataset):
|
|
|
226
226
|
|
|
227
227
|
if self.inverse_transform:
|
|
228
228
|
for transform in reversed(dataset.groups_src[self.group_src][self.group_dest].patch_transforms):
|
|
229
|
-
|
|
230
|
-
|
|
231
|
-
|
|
232
|
-
|
|
233
|
-
|
|
229
|
+
if isinstance(transform, TransformInverse) and transform.apply_inverse:
|
|
230
|
+
layer = transform.inverse(
|
|
231
|
+
self.names[index_dataset],
|
|
232
|
+
layer,
|
|
233
|
+
self.attributes[index_dataset][index_augmentation][index_patch],
|
|
234
|
+
)
|
|
234
235
|
self.output_layer_accumulator[index_dataset][index_augmentation].add_layer(index_patch, layer)
|
|
235
236
|
|
|
236
237
|
def load(self, name_layer: str, datasets: list[Dataset], groups: dict[str, str]):
|
|
@@ -277,7 +278,8 @@ class OutSameAsGroupDataset(OutputDataset):
|
|
|
277
278
|
|
|
278
279
|
if self.inverse_transform:
|
|
279
280
|
for transform in reversed(dataset.groups_src[self.group_src][self.group_dest].transforms):
|
|
280
|
-
|
|
281
|
+
if isinstance(transform, TransformInverse) and transform.apply_inverse:
|
|
282
|
+
result = transform.inverse(self.names[index], result, self.attributes[index][0][0])
|
|
281
283
|
|
|
282
284
|
for transform in self.final_transforms:
|
|
283
285
|
result = transform(self.names[index], result, self.attributes[index][0][0])
|
|
@@ -715,7 +717,9 @@ class Predictor(DistributedObject):
|
|
|
715
717
|
path = models_directory() + self.name + "/StateDict/"
|
|
716
718
|
name = sorted(os.listdir(path))[-1]
|
|
717
719
|
if os.path.exists(path + name):
|
|
718
|
-
state_dicts.append(
|
|
720
|
+
state_dicts.append(
|
|
721
|
+
torch.load(path + name, map_location=torch.device("cpu"), weights_only=False) # nosec B614
|
|
722
|
+
) # nosec B614
|
|
719
723
|
else:
|
|
720
724
|
raise Exception(f"Model : {path + name} does not exist !")
|
|
721
725
|
return state_dicts
|
|
@@ -1,6 +1,6 @@
|
|
|
1
1
|
Metadata-Version: 2.4
|
|
2
2
|
Name: konfai
|
|
3
|
-
Version: 1.2.4
|
|
3
|
+
Version: 1.2.6
|
|
4
4
|
Summary: Modular and configurable Deep Learning framework with YAML and PyTorch
|
|
5
5
|
Author-email: Valentin Boussot <boussot.v@gmail.com>
|
|
6
6
|
License-Expression: Apache-2.0
|
|
@@ -21,6 +21,7 @@ Requires-Dist: SimpleITK
|
|
|
21
21
|
Requires-Dist: lxml
|
|
22
22
|
Requires-Dist: h5py
|
|
23
23
|
Requires-Dist: pynvml
|
|
24
|
+
Requires-Dist: requests
|
|
24
25
|
Provides-Extra: vtk
|
|
25
26
|
Requires-Dist: vtk; extra == "vtk"
|
|
26
27
|
Provides-Extra: lpips
|
|
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
|
|
|
4
4
|
|
|
5
5
|
[project]
|
|
6
6
|
name = "konfai"
|
|
7
|
-
version = "1.2.4"
|
|
7
|
+
version = "1.2.6"
|
|
8
8
|
description = "Modular and configurable Deep Learning framework with YAML and PyTorch"
|
|
9
9
|
readme = "README.md"
|
|
10
10
|
requires-python = ">=3.8"
|
|
@@ -24,7 +24,8 @@ dependencies = [
|
|
|
24
24
|
"SimpleITK",
|
|
25
25
|
"lxml",
|
|
26
26
|
"h5py",
|
|
27
|
-
"pynvml"
|
|
27
|
+
"pynvml",
|
|
28
|
+
"requests"
|
|
28
29
|
]
|
|
29
30
|
|
|
30
31
|
[project.urls]
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|