konfai-1.1.8-py3-none-any.whl → konfai-1.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of konfai might be problematic.
- konfai/__init__.py +59 -14
- konfai/data/augmentation.py +457 -286
- konfai/data/data_manager.py +533 -316
- konfai/data/patching.py +300 -183
- konfai/data/transform.py +408 -275
- konfai/evaluator.py +325 -68
- konfai/main.py +71 -22
- konfai/metric/measure.py +360 -244
- konfai/metric/schedulers.py +24 -13
- konfai/models/classification/convNeXt.py +187 -81
- konfai/models/classification/resnet.py +272 -58
- konfai/models/generation/cStyleGan.py +233 -59
- konfai/models/generation/ddpm.py +348 -121
- konfai/models/generation/diffusionGan.py +757 -358
- konfai/models/generation/gan.py +177 -53
- konfai/models/generation/vae.py +140 -40
- konfai/models/registration/registration.py +135 -52
- konfai/models/representation/representation.py +57 -23
- konfai/models/segmentation/NestedUNet.py +339 -68
- konfai/models/segmentation/UNet.py +140 -30
- konfai/network/blocks.py +331 -187
- konfai/network/network.py +795 -427
- konfai/predictor.py +644 -238
- konfai/trainer.py +509 -222
- konfai/utils/ITK.py +191 -106
- konfai/utils/config.py +152 -95
- konfai/utils/dataset.py +326 -455
- konfai/utils/utils.py +497 -249
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/METADATA +1 -3
- konfai-1.2.0.dist-info/RECORD +38 -0
- konfai/utils/registration.py +0 -199
- konfai-1.1.8.dist-info/RECORD +0 -39
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/WHEEL +0 -0
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/entry_points.txt +0 -0
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/licenses/LICENSE +0 -0
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/top_level.txt +0 -0
konfai/data/transform.py
CHANGED
New content of the file in 1.2.0:

import importlib
import tempfile
from abc import ABC, abstractmethod
from typing import Any

import numpy as np
import SimpleITK as sitk  # noqa: N813
import torch
import torch.nn.functional as F  # noqa: N812

from konfai.utils.config import config
from konfai.utils.dataset import Attribute, Dataset, data_to_image, image_to_data
from konfai.utils.utils import NeedDevice, TransformError, _affine_matrix, _resample_affine, get_module


class Transform(NeedDevice, ABC):

    def __init__(self) -> None:
        self.datasets: list[Dataset] = []

    def set_datasets(self, datasets: list[Dataset]):
        self.datasets = datasets

    def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
        return shape

    @abstractmethod
    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        pass

    @abstractmethod
    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        pass


class TransformLoader:

    @config()
    def __init__(self) -> None:
        pass

    def get_transform(self, classpath: str, konfai_args: str) -> Transform:
        module, name = get_module(classpath, "konfai.data.transform")
        return config(f"{konfai_args}.{classpath}")(getattr(importlib.import_module(module), name))(config=None)


class Clip(Transform):

    def __init__(
        self,
        min_value: float = -1024,
        max_value: float = 1024,
        save_clip_min: bool = False,
        save_clip_max: bool = False,
    ) -> None:
        if max_value <= min_value:
            raise ValueError(
                f"[Clip] Invalid clipping range: max_value ({max_value}) must be greater than min_value ({min_value})"
            )
        self.min_value = min_value
        self.max_value = max_value
        self.save_clip_min = save_clip_min
        self.save_clip_max = save_clip_max

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        tensor[torch.where(tensor < self.min_value)] = self.min_value
        tensor[torch.where(tensor > self.max_value)] = self.max_value
        if self.save_clip_min:
            cache_attribute["Min"] = self.min_value
        if self.save_clip_max:
            cache_attribute["Max"] = self.max_value
        return tensor

    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        return tensor


class Normalize(Transform):

    def __init__(
        self,
        lazy: bool = False,
        channels: list[int] | None = None,
        min_value: float = -1,
        max_value: float = 1,
    ) -> None:
        if max_value <= min_value:
            raise ValueError(
                f"[Normalize] Invalid range: max_value ({max_value}) must be greater than min_value ({min_value})"
            )
        self.lazy = lazy
        self.min_value = min_value
        self.max_value = max_value
        self.channels = channels

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        if "Min" not in cache_attribute:
            if self.channels:
                cache_attribute["Min"] = torch.min(tensor[self.channels])
            else:
                cache_attribute["Min"] = torch.min(tensor)
        if "Max" not in cache_attribute:
            if self.channels:
                cache_attribute["Max"] = torch.max(tensor[self.channels])
            else:
                cache_attribute["Max"] = torch.max(tensor)
        if not self.lazy:
            input_min = float(cache_attribute["Min"])
            input_max = float(cache_attribute["Max"])
            norm = input_max - input_min

            if norm == 0:
                print(f"[WARNING] Norm is zero for case '{name}': input is constant with value = {self.min_value}.")
                if self.channels:
                    for channel in self.channels:
                        tensor[channel].fill_(self.min_value)
                else:
                    tensor.fill_(self.min_value)
            else:
                if self.channels:
                    for channel in self.channels:
                        tensor[channel] = (self.max_value - self.min_value) * (
                            tensor[channel] - input_min
                        ) / norm + self.min_value
                else:
                    tensor = (self.max_value - self.min_value) * (tensor - input_min) / norm + self.min_value

        return tensor

    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        if self.lazy:
            return tensor
        else:
            input_min = float(cache_attribute.pop("Min"))
            input_max = float(cache_attribute.pop("Max"))
            return (tensor - self.min_value) * (input_max - input_min) / (self.max_value - self.min_value) + input_min


class Standardize(Transform):

    def __init__(
        self,
        lazy: bool = False,
        mean: list[float] | None = None,
        std: list[float] | None = None,
    ) -> None:
        self.lazy = lazy
        self.mean = mean
        self.std = std

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        if "Mean" not in cache_attribute:
            cache_attribute["Mean"] = (
                torch.mean(
                    tensor.type(torch.float32),
                    dim=[i + 1 for i in range(len(tensor.shape) - 1)],
                )
                if self.mean is None
                else torch.tensor([self.mean])
            )
        if "Std" not in cache_attribute:
            cache_attribute["Std"] = (
                torch.std(
                    tensor.type(torch.float32),
                    dim=[i + 1 for i in range(len(tensor.shape) - 1)],
                )
                if self.std is None
                else torch.tensor([self.std])
            )

        if self.lazy:
            return tensor
        else:
            mean = cache_attribute.get_tensor("Mean").view(-1, *[1 for _ in range(len(tensor.shape) - 1)])
            std = cache_attribute.get_tensor("Std").view(-1, *[1 for _ in range(len(tensor.shape) - 1)])
            return (tensor - mean) / std

    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        if self.lazy:
            return tensor
        else:
            mean = float(cache_attribute.pop("Mean"))
            std = float(cache_attribute.pop("Std"))
            return tensor * std + mean


class TensorCast(Transform):

    def __init__(self, dtype: str = "float32") -> None:
        self.dtype: torch.dtype = getattr(torch, dtype)

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        cache_attribute["dtype"] = str(tensor.dtype).replace("torch.", "")
        return tensor.type(self.dtype)

    @staticmethod
    def safe_dtype_cast(dtype_str: str) -> torch.dtype:
        try:
            return getattr(torch, dtype_str)
        except AttributeError:
            raise ValueError(f"Unsupported dtype: {dtype_str}")

    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        return tensor.to(TensorCast.safe_dtype_cast(cache_attribute.pop("dtype")))


class Padding(Transform):

    def __init__(self, padding: list[int] = [0, 0, 0, 0, 0, 0], mode: str = "constant") -> None:
        self.padding = padding
        self.mode = mode

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        if "Origin" in cache_attribute and "Spacing" in cache_attribute and "Direction" in cache_attribute:
            origin = torch.tensor(cache_attribute.get_np_array("Origin"))
            matrix = torch.tensor(cache_attribute.get_np_array("Direction").reshape((len(origin), len(origin))))
            origin = torch.matmul(origin, matrix)
            for dim in range(len(self.padding) // 2):
                origin[-dim - 1] -= self.padding[dim * 2] * cache_attribute.get_np_array("Spacing")[-dim - 1]
            cache_attribute["Origin"] = torch.matmul(origin, torch.inverse(matrix))
        result = F.pad(
            tensor.unsqueeze(0),
            tuple(self.padding),
            self.mode.split(":")[0],
            float(self.mode.split(":")[1]) if len(self.mode.split(":")) == 2 else 0,
        ).squeeze(0)
        return result

    def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
        for dim in range(len(self.padding) // 2):
            shape[-dim - 1] += sum(self.padding[dim * 2 : dim * 2 + 2])
        return shape

    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: dict[str, torch.Tensor]) -> torch.Tensor:
        if "Origin" in cache_attribute and "Spacing" in cache_attribute and "Direction" in cache_attribute:
            cache_attribute.pop("Origin")
        slices = [slice(0, shape) for shape in tensor.shape]
        for dim in range(len(self.padding) // 2):
            slices[-dim - 1] = slice(self.padding[dim * 2], tensor.shape[-dim - 1] - self.padding[dim * 2 + 1])
        result = tensor[slices]
        return result


class Squeeze(Transform):

    def __init__(self, dim: int) -> None:
        self.dim = dim

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        return tensor.squeeze(self.dim)

    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: dict[str, Any]) -> torch.Tensor:
        return tensor.unsqueeze(self.dim)


class Resample(Transform, ABC):

    def __init__(self) -> None:
        pass

    def _resample(self, tensor: torch.Tensor, size: list[int]) -> torch.Tensor:
        if tensor.dtype == torch.uint8:
            mode = "nearest"
        elif len(tensor.shape) < 4:
            mode = "bilinear"
        else:
            mode = "trilinear"
        return (
            F.interpolate(tensor.type(torch.float32).unsqueeze(0), size=tuple(size), mode=mode)
            .squeeze(0)
            .type(tensor.dtype)
            .cpu()
        )

    @abstractmethod
    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        pass

    @abstractmethod
    def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
        pass

    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        cache_attribute.pop_np_array("Size")
        size_1 = cache_attribute.pop_np_array("Size")
        _ = cache_attribute.pop_np_array("Spacing")
        return self._resample(tensor, [int(size) for size in size_1])


class ResampleToResolution(Resample):

    def __init__(self, spacing: list[float] = [1.0, 1.0, 1.0]) -> None:
        self.spacing = torch.tensor([0 if s < 0 else s for s in spacing])

    def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
        if "Spacing" not in cache_attribute:
            TransformError(
                "Missing 'Spacing' in cache attributes, the data is likely not a valid image.",
                "Make sure your input is a image (e.g., .nii, .mha) with proper metadata.",
            )
        if len(shape) != len(self.spacing):
            TransformError("Shape and spacing dimensions do not match: shape={shape}, spacing={self.spacing}")
        image_spacing = cache_attribute.get_tensor("Spacing").flip(0)
        spacing = self.spacing

        for i, s in enumerate(self.spacing):
            if s == 0:
                spacing[i] = image_spacing[i]
        resize_factor = spacing / cache_attribute.get_tensor("Spacing").flip(0)
        return [int(x) for x in (torch.tensor(shape) * 1 / resize_factor)]

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        image_spacing = cache_attribute.get_tensor("Spacing").flip(0)
        spacing = self.spacing
        for i, s in enumerate(self.spacing):
            if s == 0:
                spacing[i] = image_spacing[i]
        resize_factor = spacing / cache_attribute.get_tensor("Spacing").flip(0)
        cache_attribute["Spacing"] = spacing.flip(0)
        cache_attribute["Size"] = np.asarray([int(x) for x in torch.tensor(tensor.shape[1:])])
        size = [int(x) for x in (torch.tensor(tensor.shape[1:]) * 1 / resize_factor)]
        cache_attribute["Size"] = np.asarray(size)
        return self._resample(tensor, size)


class ResampleToShape(Resample):

    def __init__(self, shape: list[float] = [100, 256, 256]) -> None:
        self.shape = torch.tensor([0 if s < 0 else s for s in shape])

    def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
        if "Spacing" not in cache_attribute:
            TransformError(
                "Missing 'Spacing' in cache attributes, the data is likely not a valid image.",
                "Make sure your input is a image (e.g., .nii, .mha) with proper metadata.",
            )
        if len(shape) != len(self.shape):
            TransformError("Shape and spacing dimensions do not match: shape={shape}, spacing={self.spacing}")
        new_shape = self.shape
        for i, s in enumerate(self.shape):
            if s == 0:
                new_shape[i] = shape[i]
        return new_shape

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        shape = self.shape
        image_shape = torch.tensor([int(x) for x in torch.tensor(tensor.shape[1:])])
        for i, s in enumerate(self.shape):
            if s == 0:
                shape[i] = image_shape[i]
        if "Spacing" in cache_attribute:
            cache_attribute["Spacing"] = torch.flip(
                image_shape / shape * torch.flip(cache_attribute.get_tensor("Spacing"), dims=[0]),
                dims=[0],
            )
        cache_attribute["Size"] = image_shape
        cache_attribute["Size"] = shape
        return self._resample(tensor, shape)


class ResampleTransform(Transform):

    def __init__(self, transforms: dict[str, bool]) -> None:
        self.transforms = transforms

    def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
        return shape

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        if len(tensor.shape) != 4:
            raise NameError("Input size should be 5 dim")
        image = data_to_image(tensor, cache_attribute)

        vectors = [torch.arange(0, s) for s in tensor.shape[1:]]
        grids = torch.meshgrid(vectors, indexing="ij")
        grid = torch.stack(grids)
        grid = torch.unsqueeze(grid, 0)

        transforms = []
        for transform_group, invert in self.transforms.items():
            transform = None
            for dataset in self.datasets:
                if dataset.is_dataset_exist(transform_group, name):
                    transform = dataset.read_transform(transform_group, name)
                    break
            if transform is None:
                raise NameError(f"Tranform : {transform_group}/{name} not found")
            if isinstance(transform, sitk.BSplineTransform):
                if invert:
                    transform_to_displacement_field_filter = sitk.TransformToDisplacementFieldFilter()
                    transform_to_displacement_field_filter.SetReferenceImage(image)
                    displacement_field = transform_to_displacement_field_filter.Execute(transform)
                    iterative_inverse_displacement_field_image_filter = (
                        sitk.IterativeInverseDisplacementFieldImageFilter()
                    )
                    iterative_inverse_displacement_field_image_filter.SetNumberOfIterations(20)
                    inverse_displacement_field = iterative_inverse_displacement_field_image_filter.Execute(
                        displacement_field
                    )
                    transform = sitk.DisplacementFieldTransform(inverse_displacement_field)
            else:
                if invert:
                    transform = transform.GetInverse()
            transforms.append(transform)
        result_transform = sitk.CompositeTransform(transforms)

        transform_to_displacement_field_filter = sitk.TransformToDisplacementFieldFilter()
        transform_to_displacement_field_filter.SetReferenceImage(image)
        transform_to_displacement_field_filter.SetNumberOfThreads(16)
        new_locs = grid + torch.tensor(
            sitk.GetArrayFromImage(transform_to_displacement_field_filter.Execute(result_transform))
        ).unsqueeze(0).permute(0, 4, 1, 2, 3)
        shape = new_locs.shape[2:]
        for i in range(len(shape)):
            new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (shape[i] - 1) - 0.5)
        new_locs = new_locs.permute(0, 2, 3, 4, 1)
        new_locs = new_locs[..., [2, 1, 0]]
        result = (
            F.grid_sample(
                tensor.to(self.device).unsqueeze(0).float(),
                new_locs.to(self.device).float(),
                align_corners=True,
                padding_mode="border",
                mode="nearest" if tensor.dtype == torch.uint8 else "bilinear",
            )
            .squeeze(0)
            .cpu()
        )
        return result.type(torch.uint8) if tensor.dtype == torch.uint8 else result

    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        # TODO
        return tensor


class Mask(Transform):

    def __init__(self, path: str = "./default.mha", value_outside: int = 0) -> None:
        self.path = path
        self.value_outside = value_outside

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        if self.path.endswith(".mha"):
            mask = torch.tensor(sitk.GetArrayFromImage(sitk.ReadImage(self.path))).unsqueeze(0)
        else:
            mask = None
            for dataset in self.datasets:
                if dataset.is_dataset_exist(self.path, name):
                    mask, _ = dataset.read_data(self.path, name)
                    break
        if mask is None:
            raise NameError(f"Mask : {self.path}/{name} not found")
        return torch.where(torch.tensor(mask) > 0, tensor, self.value_outside)

    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        return tensor


class Gradient(Transform):

    def __init__(self, per_dim: bool = False):
        self.per_dim = per_dim

    @staticmethod
    def _image_gradient_2d(image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        dx = image[:, 1:, :] - image[:, :-1, :]
        dy = image[:, :, 1:] - image[:, :, :-1]
        return torch.nn.ConstantPad2d((0, 0, 0, 1), 0)(dx), torch.nn.ConstantPad2d((0, 1, 0, 0), 0)(dy)

    @staticmethod
    def _image_gradient_3d(
        image: torch.Tensor,
    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        dx = image[:, 1:, :, :] - image[:, :-1, :, :]
        dy = image[:, :, 1:, :] - image[:, :, :-1, :]
        dz = image[:, :, :, 1:] - image[:, :, :, :-1]
        return (
            torch.nn.ConstantPad3d((0, 0, 0, 0, 0, 1), 0)(dx),
            torch.nn.ConstantPad3d((0, 0, 0, 1, 0, 0), 0)(dy),
            torch.nn.ConstantPad3d((0, 1, 0, 0, 0, 0), 0)(dz),
        )

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        result = torch.stack(
            (Gradient._image_gradient_3d(tensor) if len(tensor.shape) == 4 else Gradient._image_gradient_2d(tensor)),
            dim=1,
        ).squeeze(0)
        if not self.per_dim:
            result = torch.sigmoid(result * 3)
            result = result.norm(dim=0)
            result = torch.unsqueeze(result, 0)

        return result

    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        return tensor


class Argmax(Transform):

    def __init__(self, dim: int = 0) -> None:
        self.dim = dim

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        return torch.argmax(tensor, dim=self.dim).unsqueeze(self.dim)

    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        return tensor


class FlatLabel(Transform):

    def __init__(self, labels: list[int] | None = None) -> None:
        self.labels = labels

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        data = torch.zeros_like(tensor)
        if self.labels:
            for label in self.labels:
                data[torch.where(tensor == label)] = 1
        else:
            data[torch.where(tensor > 0)] = 1
        return data

    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        return tensor


class Save(Transform):

    def __init__(self, dataset: str) -> None:
        self.dataset = dataset

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        return tensor

    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        return tensor


class Flatten(Transform):

    def __init__(self) -> None:
        super().__init__()

    def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
        return [np.prod(np.asarray(shape))]

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        return tensor.flatten()

    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        return tensor


class Permute(Transform):

    def __init__(self, dims: str = "1|0|2") -> None:
        super().__init__()
        self.dims = [0] + [int(d) + 1 for d in dims.split("|")]

    def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
        return [shape[it - 1] for it in self.dims[1:]]

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        return tensor.permute(tuple(self.dims))

    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        return tensor.permute(tuple(np.argsort(self.dims)))


class Flip(Transform):

    def __init__(self, dims: str = "1|0|2") -> None:
        super().__init__()

        self.dims = [int(d) + 1 for d in str(dims).split("|")]

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        return tensor.flip(tuple(self.dims))

    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        return tensor.flip(tuple(self.dims))


class Canonical(Transform):

    def __init__(self) -> None:
        self.canonical_direction = torch.diag(torch.tensor([-1, -1, 1])).to(torch.double)

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        spacing = cache_attribute.get_tensor("Spacing")
        initial_matrix = cache_attribute.get_tensor("Direction").reshape(3, 3).to(torch.double)
        initial_origin = cache_attribute.get_tensor("Origin")
        cache_attribute["Direction"] = (self.canonical_direction).flatten()
        matrix = _affine_matrix(self.canonical_direction @ initial_matrix.inverse(), torch.tensor([0, 0, 0]))
        center_voxel = torch.tensor(
            [(tensor.shape[-i - 1] - 1) * spacing[i] / 2 for i in range(3)],
            dtype=torch.double,
        )
        center_physical = initial_matrix @ center_voxel + initial_origin
        cache_attribute["Origin"] = center_physical - (self.canonical_direction @ center_voxel)
        return _resample_affine(tensor, matrix.unsqueeze(0))

    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        cache_attribute.pop("Direction")
        cache_attribute.pop("Origin")
        matrix = _affine_matrix(
            (
                self.canonical_direction
                @ cache_attribute.get_tensor("Direction").to(torch.double).reshape(3, 3).inverse()
            ).inverse(),
            torch.tensor([0, 0, 0]),
        )
        return _resample_affine(tensor, matrix.unsqueeze(0))


class HistogramMatching(Transform):

    def __init__(self, reference_group: str) -> None:
        self.reference_group = reference_group

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        image = data_to_image(tensor, cache_attribute)
        image_ref = None
        for dataset in self.datasets:
            if dataset.is_dataset_exist(self.reference_group, name):
                image_ref = dataset.read_image(self.reference_group, name)
        if image_ref is None:
            raise NameError(f"Image : {self.reference_group}/{name} not found")
        matcher = sitk.HistogramMatchingImageFilter()
        matcher.SetNumberOfHistogramLevels(256)
        matcher.SetNumberOfMatchPoints(1)
        matcher.SetThresholdAtMeanIntensity(True)
        result, _ = image_to_data(matcher.Execute(image, image_ref))
        return torch.tensor(result)

    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        return tensor


class SelectLabel(Transform):

    def __init__(self, labels: list[str]) -> None:
        self.labels = [label[1:-1].split(",") for label in labels]

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        data = torch.zeros_like(tensor)
        for old_label, new_label in self.labels:
            data[tensor == int(old_label)] = int(new_label)
        return data

    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        return tensor


class OneHot(Transform):

    def __init__(self, num_classes: int) -> None:
        super().__init__()
        self.num_classes = num_classes

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        result = (
            F.one_hot(tensor.type(torch.int64), num_classes=self.num_classes)
            .permute(0, len(tensor.shape), *[i + 1 for i in range(len(tensor.shape) - 1)])
            .float()
            .squeeze(0)
        )
        return result

    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        return torch.argmax(tensor, dim=1).unsqueeze(1)


class TotalSegmentator(Transform):

    def __init__(self, task: str = "total"):
        super().__init__()
        self.task = task

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        from totalsegmentator.python_api import totalsegmentator

        with tempfile.TemporaryDirectory() as tmpdir:
            image = data_to_image(tensor.numpy(), cache_attribute)
            sitk.WriteImage(image, tmpdir + "/image.nii.gz")
            seg = totalsegmentator(tmpdir + "/image.nii.gz", tmpdir, task=self.task, skip_saving=True, quiet=True)
            return (
                torch.from_numpy(np.array(np.asanyarray(seg.dataobj), copy=True).astype(np.uint8, copy=False))
                .permute(2, 1, 0)
                .unsqueeze(0)
            )

    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        return tensor