konfai 1.1.7__py3-none-any.whl → 1.1.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of konfai might be problematic. Click here for more details.
- konfai/__init__.py +59 -14
- konfai/data/augmentation.py +457 -286
- konfai/data/data_manager.py +509 -290
- konfai/data/patching.py +300 -183
- konfai/data/transform.py +384 -277
- konfai/evaluator.py +309 -68
- konfai/main.py +71 -22
- konfai/metric/measure.py +341 -222
- konfai/metric/schedulers.py +24 -13
- konfai/models/classification/convNeXt.py +187 -81
- konfai/models/classification/resnet.py +272 -58
- konfai/models/generation/cStyleGan.py +233 -59
- konfai/models/generation/ddpm.py +348 -121
- konfai/models/generation/diffusionGan.py +757 -358
- konfai/models/generation/gan.py +177 -53
- konfai/models/generation/vae.py +140 -40
- konfai/models/registration/registration.py +135 -52
- konfai/models/representation/representation.py +57 -23
- konfai/models/segmentation/NestedUNet.py +339 -68
- konfai/models/segmentation/UNet.py +140 -30
- konfai/network/blocks.py +331 -187
- konfai/network/network.py +781 -423
- konfai/predictor.py +645 -240
- konfai/trainer.py +527 -216
- konfai/utils/ITK.py +191 -106
- konfai/utils/config.py +152 -95
- konfai/utils/dataset.py +326 -455
- konfai/utils/utils.py +495 -249
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/METADATA +1 -3
- konfai-1.1.9.dist-info/RECORD +38 -0
- konfai/utils/registration.py +0 -199
- konfai-1.1.7.dist-info/RECORD +0 -39
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/WHEEL +0 -0
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/entry_points.txt +0 -0
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/licenses/LICENSE +0 -0
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/top_level.txt +0 -0
konfai/data/transform.py
CHANGED
|
@@ -1,565 +1,672 @@
|
|
|
1
1
|
import importlib
|
|
2
|
-
import torch
|
|
3
|
-
import numpy as np
|
|
4
|
-
import SimpleITK as sitk
|
|
5
2
|
from abc import ABC, abstractmethod
|
|
6
|
-
|
|
7
|
-
|
|
3
|
+
from typing import Any
|
|
4
|
+
|
|
5
|
+
import numpy as np
|
|
6
|
+
import SimpleITK as sitk # noqa: N813
|
|
7
|
+
import torch
|
|
8
|
+
import torch.nn.functional as F # noqa: N812
|
|
8
9
|
|
|
9
|
-
from konfai.utils.utils import _getModule, NeedDevice, _resample_affine, _affine_matrix, TransformError
|
|
10
|
-
from konfai.utils.dataset import Dataset, Attribute, data_to_image, image_to_data
|
|
11
10
|
from konfai.utils.config import config
|
|
11
|
+
from konfai.utils.dataset import Attribute, Dataset, data_to_image, image_to_data
|
|
12
|
+
from konfai.utils.utils import NeedDevice, TransformError, _affine_matrix, _resample_affine, get_module
|
|
13
|
+
|
|
12
14
|
|
|
13
15
|
class Transform(NeedDevice, ABC):
|
|
14
|
-
|
|
16
|
+
|
|
15
17
|
def __init__(self) -> None:
|
|
16
|
-
self.datasets
|
|
17
|
-
|
|
18
|
-
def
|
|
18
|
+
self.datasets: list[Dataset] = []
|
|
19
|
+
|
|
20
|
+
def set_datasets(self, datasets: list[Dataset]):
|
|
19
21
|
self.datasets = datasets
|
|
20
22
|
|
|
21
|
-
def
|
|
23
|
+
def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
|
|
22
24
|
return shape
|
|
23
25
|
|
|
24
26
|
@abstractmethod
|
|
25
|
-
def __call__(self, name: str,
|
|
27
|
+
def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
26
28
|
pass
|
|
27
29
|
|
|
28
30
|
@abstractmethod
|
|
29
|
-
def inverse(self, name: str,
|
|
31
|
+
def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
30
32
|
pass
|
|
31
33
|
|
|
34
|
+
|
|
32
35
|
class TransformLoader:
|
|
33
36
|
|
|
34
37
|
@config()
|
|
35
38
|
def __init__(self) -> None:
|
|
36
39
|
pass
|
|
37
|
-
|
|
38
|
-
def
|
|
39
|
-
module, name =
|
|
40
|
-
return config("{}.{}"
|
|
40
|
+
|
|
41
|
+
def get_transform(self, classpath: str, konfai_args: str) -> Transform:
|
|
42
|
+
module, name = get_module(classpath, "konfai.data.transform")
|
|
43
|
+
return config(f"{konfai_args}.{classpath}")(getattr(importlib.import_module(module), name))(config=None)
|
|
44
|
+
|
|
41
45
|
|
|
42
46
|
class Clip(Transform):
|
|
43
47
|
|
|
44
|
-
def __init__(
|
|
45
|
-
|
|
48
|
+
def __init__(
|
|
49
|
+
self,
|
|
50
|
+
min_value: float = -1024,
|
|
51
|
+
max_value: float = 1024,
|
|
52
|
+
save_clip_min: bool = False,
|
|
53
|
+
save_clip_max: bool = False,
|
|
54
|
+
) -> None:
|
|
55
|
+
if max_value <= min_value:
|
|
56
|
+
raise ValueError(
|
|
57
|
+
f"[Clip] Invalid clipping range: max_value ({max_value}) must be greater than min_value ({min_value})"
|
|
58
|
+
)
|
|
46
59
|
self.min_value = min_value
|
|
47
60
|
self.max_value = max_value
|
|
48
|
-
self.
|
|
49
|
-
self.
|
|
61
|
+
self.save_clip_min = save_clip_min
|
|
62
|
+
self.save_clip_max = save_clip_max
|
|
50
63
|
|
|
51
|
-
def __call__(self, name: str,
|
|
52
|
-
|
|
53
|
-
|
|
54
|
-
if self.
|
|
64
|
+
def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
65
|
+
tensor[torch.where(tensor < self.min_value)] = self.min_value
|
|
66
|
+
tensor[torch.where(tensor > self.max_value)] = self.max_value
|
|
67
|
+
if self.save_clip_min:
|
|
55
68
|
cache_attribute["Min"] = self.min_value
|
|
56
|
-
if self.
|
|
69
|
+
if self.save_clip_max:
|
|
57
70
|
cache_attribute["Max"] = self.max_value
|
|
58
|
-
return
|
|
71
|
+
return tensor
|
|
72
|
+
|
|
73
|
+
def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
74
|
+
return tensor
|
|
59
75
|
|
|
60
|
-
def inverse(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
61
|
-
return input
|
|
62
76
|
|
|
63
77
|
class Normalize(Transform):
|
|
64
78
|
|
|
65
|
-
def __init__(
|
|
66
|
-
|
|
79
|
+
def __init__(
|
|
80
|
+
self,
|
|
81
|
+
lazy: bool = False,
|
|
82
|
+
channels: list[int] | None = None,
|
|
83
|
+
min_value: float = -1,
|
|
84
|
+
max_value: float = 1,
|
|
85
|
+
) -> None:
|
|
86
|
+
if max_value <= min_value:
|
|
87
|
+
raise ValueError(
|
|
88
|
+
f"[Normalize] Invalid range: max_value ({max_value}) must be greater than min_value ({min_value})"
|
|
89
|
+
)
|
|
67
90
|
self.lazy = lazy
|
|
68
91
|
self.min_value = min_value
|
|
69
92
|
self.max_value = max_value
|
|
70
93
|
self.channels = channels
|
|
71
94
|
|
|
72
|
-
def __call__(self, name: str,
|
|
95
|
+
def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
73
96
|
if "Min" not in cache_attribute:
|
|
74
97
|
if self.channels:
|
|
75
|
-
cache_attribute["Min"] = torch.min(
|
|
98
|
+
cache_attribute["Min"] = torch.min(tensor[self.channels])
|
|
76
99
|
else:
|
|
77
|
-
cache_attribute["Min"] = torch.min(
|
|
100
|
+
cache_attribute["Min"] = torch.min(tensor)
|
|
78
101
|
if "Max" not in cache_attribute:
|
|
79
102
|
if self.channels:
|
|
80
|
-
cache_attribute["Max"] = torch.max(
|
|
103
|
+
cache_attribute["Max"] = torch.max(tensor[self.channels])
|
|
81
104
|
else:
|
|
82
|
-
cache_attribute["Max"] = torch.max(
|
|
105
|
+
cache_attribute["Max"] = torch.max(tensor)
|
|
83
106
|
if not self.lazy:
|
|
84
107
|
input_min = float(cache_attribute["Min"])
|
|
85
108
|
input_max = float(cache_attribute["Max"])
|
|
86
|
-
norm = input_max-input_min
|
|
87
|
-
|
|
88
|
-
if
|
|
89
|
-
for
|
|
90
|
-
|
|
109
|
+
norm = input_max - input_min
|
|
110
|
+
|
|
111
|
+
if norm == 0:
|
|
112
|
+
print(f"[WARNING] Norm is zero for case '{name}': input is constant with value = {self.min_value}.")
|
|
113
|
+
if self.channels:
|
|
114
|
+
for channel in self.channels:
|
|
115
|
+
tensor[channel].fill_(self.min_value)
|
|
116
|
+
else:
|
|
117
|
+
tensor.fill_(self.min_value)
|
|
91
118
|
else:
|
|
92
|
-
|
|
93
|
-
|
|
94
|
-
|
|
95
|
-
|
|
119
|
+
if self.channels:
|
|
120
|
+
for channel in self.channels:
|
|
121
|
+
tensor[channel] = (self.max_value - self.min_value) * (
|
|
122
|
+
tensor[channel] - input_min
|
|
123
|
+
) / norm + self.min_value
|
|
124
|
+
else:
|
|
125
|
+
tensor = (self.max_value - self.min_value) * (tensor - input_min) / norm + self.min_value
|
|
126
|
+
|
|
127
|
+
return tensor
|
|
128
|
+
|
|
129
|
+
def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
96
130
|
if self.lazy:
|
|
97
|
-
return
|
|
131
|
+
return tensor
|
|
98
132
|
else:
|
|
99
133
|
input_min = float(cache_attribute.pop("Min"))
|
|
100
134
|
input_max = float(cache_attribute.pop("Max"))
|
|
101
|
-
return (
|
|
135
|
+
return (tensor - self.min_value) * (input_max - input_min) / (self.max_value - self.min_value) + input_min
|
|
136
|
+
|
|
102
137
|
|
|
103
138
|
class Standardize(Transform):
|
|
104
139
|
|
|
105
|
-
def __init__(
|
|
140
|
+
def __init__(
|
|
141
|
+
self,
|
|
142
|
+
lazy: bool = False,
|
|
143
|
+
mean: list[float] | None = None,
|
|
144
|
+
std: list[float] | None = None,
|
|
145
|
+
) -> None:
|
|
106
146
|
self.lazy = lazy
|
|
107
147
|
self.mean = mean
|
|
108
148
|
self.std = std
|
|
109
149
|
|
|
110
|
-
def __call__(self, name: str,
|
|
150
|
+
def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
111
151
|
if "Mean" not in cache_attribute:
|
|
112
|
-
cache_attribute["Mean"] =
|
|
152
|
+
cache_attribute["Mean"] = (
|
|
153
|
+
torch.mean(
|
|
154
|
+
tensor.type(torch.float32),
|
|
155
|
+
dim=[i + 1 for i in range(len(tensor.shape) - 1)],
|
|
156
|
+
)
|
|
157
|
+
if self.mean is None
|
|
158
|
+
else torch.tensor([self.mean])
|
|
159
|
+
)
|
|
113
160
|
if "Std" not in cache_attribute:
|
|
114
|
-
cache_attribute["Std"] =
|
|
161
|
+
cache_attribute["Std"] = (
|
|
162
|
+
torch.std(
|
|
163
|
+
tensor.type(torch.float32),
|
|
164
|
+
dim=[i + 1 for i in range(len(tensor.shape) - 1)],
|
|
165
|
+
)
|
|
166
|
+
if self.std is None
|
|
167
|
+
else torch.tensor([self.std])
|
|
168
|
+
)
|
|
115
169
|
|
|
116
170
|
if self.lazy:
|
|
117
|
-
return
|
|
171
|
+
return tensor
|
|
118
172
|
else:
|
|
119
|
-
mean = cache_attribute.get_tensor("Mean").view(-1, *[1 for _ in range(len(
|
|
120
|
-
std = cache_attribute.get_tensor("Std").view(-1, *[1 for _ in range(len(
|
|
121
|
-
return (
|
|
122
|
-
|
|
123
|
-
def inverse(self, name: str,
|
|
173
|
+
mean = cache_attribute.get_tensor("Mean").view(-1, *[1 for _ in range(len(tensor.shape) - 1)])
|
|
174
|
+
std = cache_attribute.get_tensor("Std").view(-1, *[1 for _ in range(len(tensor.shape) - 1)])
|
|
175
|
+
return (tensor - mean) / std
|
|
176
|
+
|
|
177
|
+
def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
124
178
|
if self.lazy:
|
|
125
|
-
return
|
|
179
|
+
return tensor
|
|
126
180
|
else:
|
|
127
181
|
mean = float(cache_attribute.pop("Mean"))
|
|
128
182
|
std = float(cache_attribute.pop("Std"))
|
|
129
|
-
return
|
|
130
|
-
|
|
183
|
+
return tensor * std + mean
|
|
184
|
+
|
|
185
|
+
|
|
131
186
|
class TensorCast(Transform):
|
|
132
187
|
|
|
133
|
-
def __init__(self, dtype
|
|
134
|
-
self.dtype
|
|
188
|
+
def __init__(self, dtype: str = "float32") -> None:
|
|
189
|
+
self.dtype: torch.dtype = getattr(torch, dtype)
|
|
190
|
+
|
|
191
|
+
def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
192
|
+
cache_attribute["dtype"] = str(tensor.dtype).replace("torch.", "")
|
|
193
|
+
return tensor.type(self.dtype)
|
|
194
|
+
|
|
195
|
+
@staticmethod
|
|
196
|
+
def safe_dtype_cast(dtype_str: str) -> torch.dtype:
|
|
197
|
+
try:
|
|
198
|
+
return getattr(torch, dtype_str)
|
|
199
|
+
except AttributeError:
|
|
200
|
+
raise ValueError(f"Unsupported dtype: {dtype_str}")
|
|
201
|
+
|
|
202
|
+
def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
203
|
+
return tensor.to(TensorCast.safe_dtype_cast(cache_attribute.pop("dtype")))
|
|
135
204
|
|
|
136
|
-
def __call__(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
137
|
-
cache_attribute["dtype"] = input.dtype
|
|
138
|
-
return input.type(self.dtype)
|
|
139
|
-
|
|
140
|
-
def inverse(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
141
|
-
return input.to(eval(cache_attribute.pop("dtype")))
|
|
142
205
|
|
|
143
206
|
class Padding(Transform):
|
|
144
207
|
|
|
145
|
-
def __init__(self, padding
|
|
208
|
+
def __init__(self, padding: list[int] = [0, 0, 0, 0, 0, 0], mode: str = "constant") -> None:
|
|
146
209
|
self.padding = padding
|
|
147
210
|
self.mode = mode
|
|
148
211
|
|
|
149
|
-
def __call__(self, name: str,
|
|
212
|
+
def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
150
213
|
if "Origin" in cache_attribute and "Spacing" in cache_attribute and "Direction" in cache_attribute:
|
|
151
214
|
origin = torch.tensor(cache_attribute.get_np_array("Origin"))
|
|
152
|
-
matrix = torch.tensor(cache_attribute.get_np_array("Direction").reshape((len(origin),len(origin))))
|
|
215
|
+
matrix = torch.tensor(cache_attribute.get_np_array("Direction").reshape((len(origin), len(origin))))
|
|
153
216
|
origin = torch.matmul(origin, matrix)
|
|
154
|
-
for dim in range(len(self.padding)//2):
|
|
155
|
-
origin[-dim-1] -= self.padding[dim*2]* cache_attribute.get_np_array("Spacing")[-dim-1]
|
|
217
|
+
for dim in range(len(self.padding) // 2):
|
|
218
|
+
origin[-dim - 1] -= self.padding[dim * 2] * cache_attribute.get_np_array("Spacing")[-dim - 1]
|
|
156
219
|
cache_attribute["Origin"] = torch.matmul(origin, torch.inverse(matrix))
|
|
157
|
-
result = F.pad(
|
|
220
|
+
result = F.pad(
|
|
221
|
+
tensor.unsqueeze(0),
|
|
222
|
+
tuple(self.padding),
|
|
223
|
+
self.mode.split(":")[0],
|
|
224
|
+
float(self.mode.split(":")[1]) if len(self.mode.split(":")) == 2 else 0,
|
|
225
|
+
).squeeze(0)
|
|
158
226
|
return result
|
|
159
|
-
|
|
160
|
-
def
|
|
161
|
-
for dim in range(len(self.padding)//2):
|
|
162
|
-
shape[-dim-1] += sum(self.padding[dim*2:dim*2+2])
|
|
227
|
+
|
|
228
|
+
def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
|
|
229
|
+
for dim in range(len(self.padding) // 2):
|
|
230
|
+
shape[-dim - 1] += sum(self.padding[dim * 2 : dim * 2 + 2])
|
|
163
231
|
return shape
|
|
164
232
|
|
|
165
|
-
def inverse(self, name: str,
|
|
233
|
+
def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: dict[str, torch.Tensor]) -> torch.Tensor:
|
|
166
234
|
if "Origin" in cache_attribute and "Spacing" in cache_attribute and "Direction" in cache_attribute:
|
|
167
235
|
cache_attribute.pop("Origin")
|
|
168
|
-
slices = [slice(0, shape) for shape in
|
|
169
|
-
for dim in range(len(self.padding)//2):
|
|
170
|
-
slices[-dim-1] = slice(self.padding[dim*2],
|
|
171
|
-
result =
|
|
236
|
+
slices = [slice(0, shape) for shape in tensor.shape]
|
|
237
|
+
for dim in range(len(self.padding) // 2):
|
|
238
|
+
slices[-dim - 1] = slice(self.padding[dim * 2], tensor.shape[-dim - 1] - self.padding[dim * 2 + 1])
|
|
239
|
+
result = tensor[slices]
|
|
172
240
|
return result
|
|
173
241
|
|
|
242
|
+
|
|
174
243
|
class Squeeze(Transform):
|
|
175
244
|
|
|
176
245
|
def __init__(self, dim: int) -> None:
|
|
177
246
|
self.dim = dim
|
|
178
|
-
|
|
179
|
-
def __call__(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
180
|
-
return input.squeeze(self.dim)
|
|
181
247
|
|
|
182
|
-
def
|
|
183
|
-
return
|
|
248
|
+
def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
249
|
+
return tensor.squeeze(self.dim)
|
|
250
|
+
|
|
251
|
+
def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: dict[str, Any]) -> torch.Tensor:
|
|
252
|
+
return tensor.unsqueeze(self.dim)
|
|
253
|
+
|
|
184
254
|
|
|
185
255
|
class Resample(Transform, ABC):
|
|
186
256
|
|
|
187
257
|
def __init__(self) -> None:
|
|
188
258
|
pass
|
|
189
259
|
|
|
190
|
-
def _resample(self,
|
|
191
|
-
|
|
192
|
-
if input.dtype == torch.uint8:
|
|
260
|
+
def _resample(self, tensor: torch.Tensor, size: list[int]) -> torch.Tensor:
|
|
261
|
+
if tensor.dtype == torch.uint8:
|
|
193
262
|
mode = "nearest"
|
|
194
|
-
elif len(
|
|
263
|
+
elif len(tensor.shape) < 4:
|
|
195
264
|
mode = "bilinear"
|
|
196
265
|
else:
|
|
197
266
|
mode = "trilinear"
|
|
198
|
-
return
|
|
267
|
+
return (
|
|
268
|
+
F.interpolate(tensor.type(torch.float32).unsqueeze(0), size=tuple(size), mode=mode)
|
|
269
|
+
.squeeze(0)
|
|
270
|
+
.type(tensor.dtype)
|
|
271
|
+
.cpu()
|
|
272
|
+
)
|
|
199
273
|
|
|
200
274
|
@abstractmethod
|
|
201
|
-
def __call__(self, name: str,
|
|
275
|
+
def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
202
276
|
pass
|
|
203
|
-
|
|
277
|
+
|
|
204
278
|
@abstractmethod
|
|
205
|
-
def
|
|
279
|
+
def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
|
|
206
280
|
pass
|
|
207
|
-
|
|
208
|
-
def inverse(self, name: str,
|
|
209
|
-
|
|
281
|
+
|
|
282
|
+
def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
283
|
+
cache_attribute.pop_np_array("Size")
|
|
210
284
|
size_1 = cache_attribute.pop_np_array("Size")
|
|
211
285
|
_ = cache_attribute.pop_np_array("Spacing")
|
|
212
|
-
return self._resample(
|
|
286
|
+
return self._resample(tensor, [int(size) for size in size_1])
|
|
287
|
+
|
|
213
288
|
|
|
214
289
|
class ResampleToResolution(Resample):
|
|
215
290
|
|
|
216
|
-
def __init__(self, spacing
|
|
291
|
+
def __init__(self, spacing: list[float] = [1.0, 1.0, 1.0]) -> None:
|
|
217
292
|
self.spacing = torch.tensor([0 if s < 0 else s for s in spacing])
|
|
218
293
|
|
|
219
|
-
def
|
|
294
|
+
def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
|
|
220
295
|
if "Spacing" not in cache_attribute:
|
|
221
|
-
TransformError(
|
|
222
|
-
|
|
296
|
+
TransformError(
|
|
297
|
+
"Missing 'Spacing' in cache attributes, the data is likely not a valid image.",
|
|
298
|
+
"Make sure your input is a image (e.g., .nii, .mha) with proper metadata.",
|
|
299
|
+
)
|
|
223
300
|
if len(shape) != len(self.spacing):
|
|
224
301
|
TransformError("Shape and spacing dimensions do not match: shape={shape}, spacing={self.spacing}")
|
|
225
302
|
image_spacing = cache_attribute.get_tensor("Spacing").flip(0)
|
|
226
303
|
spacing = self.spacing
|
|
227
|
-
|
|
304
|
+
|
|
228
305
|
for i, s in enumerate(self.spacing):
|
|
229
306
|
if s == 0:
|
|
230
307
|
spacing[i] = image_spacing[i]
|
|
231
|
-
resize_factor = spacing/cache_attribute.get_tensor("Spacing").flip(0)
|
|
232
|
-
return [int(x) for x in (torch.tensor(shape) * 1/resize_factor)]
|
|
308
|
+
resize_factor = spacing / cache_attribute.get_tensor("Spacing").flip(0)
|
|
309
|
+
return [int(x) for x in (torch.tensor(shape) * 1 / resize_factor)]
|
|
233
310
|
|
|
234
|
-
def __call__(self, name: str,
|
|
311
|
+
def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
235
312
|
image_spacing = cache_attribute.get_tensor("Spacing").flip(0)
|
|
236
313
|
spacing = self.spacing
|
|
237
314
|
for i, s in enumerate(self.spacing):
|
|
238
315
|
if s == 0:
|
|
239
316
|
spacing[i] = image_spacing[i]
|
|
240
|
-
resize_factor = spacing/cache_attribute.get_tensor("Spacing").flip(0)
|
|
317
|
+
resize_factor = spacing / cache_attribute.get_tensor("Spacing").flip(0)
|
|
241
318
|
cache_attribute["Spacing"] = spacing.flip(0)
|
|
242
|
-
cache_attribute["Size"] = np.asarray([int(x) for x in torch.tensor(
|
|
243
|
-
size = [int(x) for x in (torch.tensor(
|
|
319
|
+
cache_attribute["Size"] = np.asarray([int(x) for x in torch.tensor(tensor.shape[1:])])
|
|
320
|
+
size = [int(x) for x in (torch.tensor(tensor.shape[1:]) * 1 / resize_factor)]
|
|
244
321
|
cache_attribute["Size"] = np.asarray(size)
|
|
245
|
-
return self._resample(
|
|
322
|
+
return self._resample(tensor, size)
|
|
323
|
+
|
|
246
324
|
|
|
247
325
|
class ResampleToShape(Resample):
|
|
248
326
|
|
|
249
|
-
def __init__(self, shape
|
|
327
|
+
def __init__(self, shape: list[float] = [100, 256, 256]) -> None:
|
|
250
328
|
self.shape = torch.tensor([0 if s < 0 else s for s in shape])
|
|
251
329
|
|
|
252
|
-
def
|
|
253
|
-
print(shape)
|
|
330
|
+
def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
|
|
254
331
|
if "Spacing" not in cache_attribute:
|
|
255
|
-
TransformError(
|
|
256
|
-
|
|
332
|
+
TransformError(
|
|
333
|
+
"Missing 'Spacing' in cache attributes, the data is likely not a valid image.",
|
|
334
|
+
"Make sure your input is a image (e.g., .nii, .mha) with proper metadata.",
|
|
335
|
+
)
|
|
257
336
|
if len(shape) != len(self.shape):
|
|
258
337
|
TransformError("Shape and spacing dimensions do not match: shape={shape}, spacing={self.spacing}")
|
|
259
338
|
new_shape = self.shape
|
|
260
339
|
for i, s in enumerate(self.shape):
|
|
261
340
|
if s == 0:
|
|
262
341
|
new_shape[i] = shape[i]
|
|
263
|
-
print(new_shape)
|
|
264
342
|
return new_shape
|
|
265
|
-
|
|
266
|
-
def __call__(self, name: str,
|
|
343
|
+
|
|
344
|
+
def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
267
345
|
shape = self.shape
|
|
268
|
-
image_shape =
|
|
346
|
+
image_shape = torch.tensor([int(x) for x in torch.tensor(tensor.shape[1:])])
|
|
269
347
|
for i, s in enumerate(self.shape):
|
|
270
348
|
if s == 0:
|
|
271
349
|
shape[i] = image_shape[i]
|
|
272
350
|
if "Spacing" in cache_attribute:
|
|
273
|
-
cache_attribute["Spacing"] = torch.flip(
|
|
351
|
+
cache_attribute["Spacing"] = torch.flip(
|
|
352
|
+
image_shape / shape * torch.flip(cache_attribute.get_tensor("Spacing"), dims=[0]),
|
|
353
|
+
dims=[0],
|
|
354
|
+
)
|
|
274
355
|
cache_attribute["Size"] = image_shape
|
|
275
356
|
cache_attribute["Size"] = shape
|
|
276
|
-
return self._resample(
|
|
357
|
+
return self._resample(tensor, shape)
|
|
358
|
+
|
|
277
359
|
|
|
278
360
|
class ResampleTransform(Transform):
|
|
279
361
|
|
|
280
|
-
def __init__(self, transforms
|
|
362
|
+
def __init__(self, transforms: dict[str, bool]) -> None:
|
|
281
363
|
self.transforms = transforms
|
|
282
|
-
|
|
283
|
-
def
|
|
364
|
+
|
|
365
|
+
def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
|
|
284
366
|
return shape
|
|
285
367
|
|
|
286
|
-
def
|
|
287
|
-
|
|
288
|
-
|
|
289
|
-
|
|
290
|
-
|
|
291
|
-
|
|
292
|
-
|
|
293
|
-
transform = dataset.readTransform(transform_group, name)
|
|
294
|
-
break
|
|
295
|
-
if transform is None:
|
|
296
|
-
raise NameError("Tranform : {}/{} not found".format(transform_group, name))
|
|
297
|
-
if isinstance(transform, sitk.BSplineTransform):
|
|
298
|
-
if invert:
|
|
299
|
-
transformToDisplacementFieldFilter = sitk.TransformToDisplacementFieldFilter()
|
|
300
|
-
transformToDisplacementFieldFilter.SetReferenceImage(image)
|
|
301
|
-
displacementField = transformToDisplacementFieldFilter.Execute(transform)
|
|
302
|
-
iterativeInverseDisplacementFieldImageFilter = sitk.IterativeInverseDisplacementFieldImageFilter()
|
|
303
|
-
iterativeInverseDisplacementFieldImageFilter.SetNumberOfIterations(20)
|
|
304
|
-
inverseDisplacementField = iterativeInverseDisplacementFieldImageFilter.Execute(displacementField)
|
|
305
|
-
transform = sitk.DisplacementFieldTransform(inverseDisplacementField)
|
|
306
|
-
else:
|
|
307
|
-
if invert:
|
|
308
|
-
transform = transform.GetInverse()
|
|
309
|
-
transforms.append(transform)
|
|
310
|
-
result_transform = sitk.CompositeTransform(transforms)
|
|
311
|
-
result = torch.tensor(sitk.GetArrayFromImage(sitk.Resample(image, image, result_transform, sitk.sitkNearestNeighbor if input.dtype == torch.uint8 else sitk.sitkBSpline, 0 if input.dtype == torch.uint8 else -1024))).unsqueeze(0)
|
|
312
|
-
return result.type(torch.uint8) if input.dtype == torch.uint8 else result
|
|
313
|
-
|
|
314
|
-
def __call__(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
315
|
-
assert len(input.shape) == 4 , "input size should be 5 dim"
|
|
316
|
-
image = data_to_image(input, cache_attribute)
|
|
317
|
-
|
|
318
|
-
vectors = [torch.arange(0, s) for s in input.shape[1:]]
|
|
319
|
-
grids = torch.meshgrid(vectors, indexing='ij')
|
|
368
|
+
def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
369
|
+
if len(tensor.shape) != 4:
|
|
370
|
+
raise NameError("Input size should be 5 dim")
|
|
371
|
+
image = data_to_image(tensor, cache_attribute)
|
|
372
|
+
|
|
373
|
+
vectors = [torch.arange(0, s) for s in tensor.shape[1:]]
|
|
374
|
+
grids = torch.meshgrid(vectors, indexing="ij")
|
|
320
375
|
grid = torch.stack(grids)
|
|
321
376
|
grid = torch.unsqueeze(grid, 0)
|
|
322
|
-
|
|
377
|
+
|
|
323
378
|
transforms = []
|
|
324
379
|
for transform_group, invert in self.transforms.items():
|
|
325
380
|
transform = None
|
|
326
381
|
for dataset in self.datasets:
|
|
327
|
-
if dataset.
|
|
328
|
-
transform = dataset.
|
|
382
|
+
if dataset.is_dataset_exist(transform_group, name):
|
|
383
|
+
transform = dataset.read_transform(transform_group, name)
|
|
329
384
|
break
|
|
330
385
|
if transform is None:
|
|
331
|
-
raise NameError("Tranform : {}/{} not found"
|
|
386
|
+
raise NameError(f"Tranform : {transform_group}/{name} not found")
|
|
332
387
|
if isinstance(transform, sitk.BSplineTransform):
|
|
333
388
|
if invert:
|
|
334
|
-
|
|
335
|
-
|
|
336
|
-
|
|
337
|
-
|
|
338
|
-
|
|
339
|
-
|
|
340
|
-
|
|
389
|
+
transform_to_displacement_field_filter = sitk.TransformToDisplacementFieldFilter()
|
|
390
|
+
transform_to_displacement_field_filter.SetReferenceImage(image)
|
|
391
|
+
displacement_field = transform_to_displacement_field_filter.Execute(transform)
|
|
392
|
+
iterative_inverse_displacement_field_image_filter = (
|
|
393
|
+
sitk.IterativeInverseDisplacementFieldImageFilter()
|
|
394
|
+
)
|
|
395
|
+
iterative_inverse_displacement_field_image_filter.SetNumberOfIterations(20)
|
|
396
|
+
inverse_displacement_field = iterative_inverse_displacement_field_image_filter.Execute(
|
|
397
|
+
displacement_field
|
|
398
|
+
)
|
|
399
|
+
transform = sitk.DisplacementFieldTransform(inverse_displacement_field)
|
|
341
400
|
else:
|
|
342
401
|
if invert:
|
|
343
402
|
transform = transform.GetInverse()
|
|
344
403
|
transforms.append(transform)
|
|
345
404
|
result_transform = sitk.CompositeTransform(transforms)
|
|
346
|
-
|
|
347
|
-
|
|
348
|
-
|
|
349
|
-
|
|
350
|
-
new_locs = grid + torch.tensor(
|
|
405
|
+
|
|
406
|
+
transform_to_displacement_field_filter = sitk.TransformToDisplacementFieldFilter()
|
|
407
|
+
transform_to_displacement_field_filter.SetReferenceImage(image)
|
|
408
|
+
transform_to_displacement_field_filter.SetNumberOfThreads(16)
|
|
409
|
+
new_locs = grid + torch.tensor(
|
|
410
|
+
sitk.GetArrayFromImage(transform_to_displacement_field_filter.Execute(result_transform))
|
|
411
|
+
).unsqueeze(0).permute(0, 4, 1, 2, 3)
|
|
351
412
|
shape = new_locs.shape[2:]
|
|
352
413
|
for i in range(len(shape)):
|
|
353
414
|
new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (shape[i] - 1) - 0.5)
|
|
354
415
|
new_locs = new_locs.permute(0, 2, 3, 4, 1)
|
|
355
416
|
new_locs = new_locs[..., [2, 1, 0]]
|
|
356
|
-
result =
|
|
357
|
-
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
417
|
+
result = (
|
|
418
|
+
F.grid_sample(
|
|
419
|
+
tensor.to(self.device).unsqueeze(0).float(),
|
|
420
|
+
new_locs.to(self.device).float(),
|
|
421
|
+
align_corners=True,
|
|
422
|
+
padding_mode="border",
|
|
423
|
+
mode="nearest" if tensor.dtype == torch.uint8 else "bilinear",
|
|
424
|
+
)
|
|
425
|
+
.squeeze(0)
|
|
426
|
+
.cpu()
|
|
427
|
+
)
|
|
428
|
+
return result.type(torch.uint8) if tensor.dtype == torch.uint8 else result
|
|
429
|
+
|
|
430
|
+
def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
431
|
+
# TODO
|
|
432
|
+
return tensor
|
|
433
|
+
|
|
362
434
|
|
|
363
435
|
class Mask(Transform):
|
|
364
436
|
|
|
365
|
-
def __init__(self, path
|
|
437
|
+
def __init__(self, path: str = "./default.mha", value_outside: int = 0) -> None:
|
|
366
438
|
self.path = path
|
|
367
439
|
self.value_outside = value_outside
|
|
368
|
-
|
|
369
|
-
def __call__(self, name: str,
|
|
440
|
+
|
|
441
|
+
def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
370
442
|
if self.path.endswith(".mha"):
|
|
371
443
|
mask = torch.tensor(sitk.GetArrayFromImage(sitk.ReadImage(self.path))).unsqueeze(0)
|
|
372
444
|
else:
|
|
373
445
|
mask = None
|
|
374
446
|
for dataset in self.datasets:
|
|
375
|
-
if dataset.
|
|
376
|
-
mask, _ = dataset.
|
|
447
|
+
if dataset.is_dataset_exist(self.path, name):
|
|
448
|
+
mask, _ = dataset.read_data(self.path, name)
|
|
377
449
|
break
|
|
378
450
|
if mask is None:
|
|
379
|
-
raise NameError("Mask : {}/{} not found"
|
|
380
|
-
return torch.where(torch.tensor(mask) > 0,
|
|
451
|
+
raise NameError(f"Mask : {self.path}/{name} not found")
|
|
452
|
+
return torch.where(torch.tensor(mask) > 0, tensor, self.value_outside)
|
|
453
|
+
|
|
454
|
+
def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
455
|
+
return tensor
|
|
456
|
+
|
|
381
457
|
|
|
382
|
-
def inverse(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
383
|
-
return input
|
|
384
|
-
|
|
385
458
|
class Gradient(Transform):
|
|
386
459
|
|
|
387
460
|
def __init__(self, per_dim: bool = False):
|
|
388
461
|
self.per_dim = per_dim
|
|
389
|
-
|
|
462
|
+
|
|
390
463
|
@staticmethod
|
|
391
|
-
def
|
|
464
|
+
def _image_gradient_2d(image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
|
392
465
|
dx = image[:, 1:, :] - image[:, :-1, :]
|
|
393
466
|
dy = image[:, :, 1:] - image[:, :, :-1]
|
|
394
|
-
return torch.nn.ConstantPad2d((0,0,0,1), 0)(dx), torch.nn.ConstantPad2d((0,1,0,0), 0)(dy)
|
|
467
|
+
return torch.nn.ConstantPad2d((0, 0, 0, 1), 0)(dx), torch.nn.ConstantPad2d((0, 1, 0, 0), 0)(dy)
|
|
395
468
|
|
|
396
469
|
@staticmethod
|
|
397
|
-
def
|
|
470
|
+
def _image_gradient_3d(
|
|
471
|
+
image: torch.Tensor,
|
|
472
|
+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
|
398
473
|
dx = image[:, 1:, :, :] - image[:, :-1, :, :]
|
|
399
474
|
dy = image[:, :, 1:, :] - image[:, :, :-1, :]
|
|
400
475
|
dz = image[:, :, :, 1:] - image[:, :, :, :-1]
|
|
401
|
-
return
|
|
402
|
-
|
|
403
|
-
|
|
404
|
-
|
|
476
|
+
return (
|
|
477
|
+
torch.nn.ConstantPad3d((0, 0, 0, 0, 0, 1), 0)(dx),
|
|
478
|
+
torch.nn.ConstantPad3d((0, 0, 0, 1, 0, 0), 0)(dy),
|
|
479
|
+
torch.nn.ConstantPad3d((0, 1, 0, 0, 0, 0), 0)(dz),
|
|
480
|
+
)
|
|
481
|
+
|
|
482
|
+
def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
483
|
+
result = torch.stack(
|
|
484
|
+
(Gradient._image_gradient_3d(tensor) if len(tensor.shape) == 4 else Gradient._image_gradient_2d(tensor)),
|
|
485
|
+
dim=1,
|
|
486
|
+
).squeeze(0)
|
|
405
487
|
if not self.per_dim:
|
|
406
|
-
result = torch.sigmoid(result*3)
|
|
488
|
+
result = torch.sigmoid(result * 3)
|
|
407
489
|
result = result.norm(dim=0)
|
|
408
490
|
result = torch.unsqueeze(result, 0)
|
|
409
|
-
|
|
491
|
+
|
|
410
492
|
return result
|
|
411
493
|
|
|
412
|
-
def inverse(self, name: str,
|
|
413
|
-
return
|
|
494
|
+
def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
495
|
+
return tensor
|
|
496
|
+
|
|
414
497
|
|
|
415
498
|
class Argmax(Transform):
    """Replace dimension ``dim`` with the index of its maximum, kept as size 1."""

    def __init__(self, dim: int = 0) -> None:
        self.dim = dim

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Return argmax over ``self.dim``, re-inserting that dim with size 1."""
        winners = torch.argmax(tensor, dim=self.dim)
        return winners.unsqueeze(self.dim)

    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Argmax is not invertible; the tensor is returned unchanged."""
        return tensor
|
|
508
|
+
|
|
509
|
+
|
|
426
510
|
class FlatLabel(Transform):
    """Binarize a label map: selected voxels become 1, everything else 0."""

    def __init__(self, labels: list[int] | None = None) -> None:
        # Labels mapped to 1. When None or empty, every strictly positive value is used.
        self.labels = labels

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Return a tensor shaped like ``tensor`` holding 1 on selected labels and 0 elsewhere."""
        flat = torch.zeros_like(tensor)
        if not self.labels:
            flat[tensor > 0] = 1
            return flat
        for label in self.labels:
            flat[tensor == label] = 1
        return flat

    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Binarization loses information; the tensor is returned unchanged."""
        return tensor
|
|
526
|
+
|
|
442
527
|
|
|
443
528
|
class Save(Transform):
    """Pass-through transform that records a target dataset name.

    Neither direction modifies the tensor. ``dataset`` is only stored here;
    any saving presumably happens elsewhere using it -- TODO confirm.
    """

    def __init__(self, dataset: str) -> None:
        self.dataset = dataset

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Return ``tensor`` unchanged."""
        return tensor

    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Return ``tensor`` unchanged."""
        return tensor
|
|
538
|
+
|
|
453
539
|
|
|
454
540
|
class Flatten(Transform):
    """Flatten a tensor into a single dimension."""

    def __init__(self) -> None:
        super().__init__()

    def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
        """Return the output shape: one dimension equal to the product of ``shape``.

        The product is converted to a built-in ``int`` to honour the
        ``list[int]`` annotation. ``np.prod`` alone returns a NumPy scalar,
        and for an empty shape it returns the float ``1.0``.
        """
        return [int(np.prod(np.asarray(shape)))]

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Return ``tensor`` flattened to rank 1 (every dim, including dim 0)."""
        # NOTE(review): transform_shape may receive a shape without the leading
        # dim (compare Permute.transform_shape); if so, the two disagree when
        # dim 0 has size > 1 -- TODO confirm.
        return tensor.flatten()

    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """The original shape is not tracked; the tensor is returned unchanged."""
        return tensor
|
|
553
|
+
|
|
467
554
|
|
|
468
555
|
class Permute(Transform):
    """Reorder the axes after dim 0; dim 0 always stays first.

    ``dims`` is a "|"-separated ordering of those axes, counted from 0
    (e.g. "1|0|2" swaps the first two axes after dim 0).
    """

    def __init__(self, dims: str = "1|0|2") -> None:
        super().__init__()
        trailing_order = [int(token) + 1 for token in dims.split("|")]
        self.dims = [0, *trailing_order]

    def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
        """Permute ``shape``, which is indexed here as excluding dim 0 (hence ``- 1``)."""
        return [shape[axis - 1] for axis in self.dims[1:]]

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Apply the configured axis order."""
        order = tuple(self.dims)
        return tensor.permute(order)

    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Undo the permutation; the argsort of a permutation is its inverse."""
        undo_order = tuple(np.argsort(self.dims))
        return tensor.permute(undo_order)
|
|
569
|
+
|
|
570
|
+
|
|
483
571
|
class Flip(Transform):
    """Reverse the tensor along the given axes after dim 0.

    ``dims`` is a "|"-separated list of axes counted from 0 after dim 0.
    """

    def __init__(self, dims: str = "1|0|2") -> None:
        super().__init__()
        # Shift by one so dim 0 is never addressed directly.
        self.dims = [int(token) + 1 for token in str(dims).split("|")]

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Flip ``tensor`` along the configured axes."""
        axes = tuple(self.dims)
        return tensor.flip(axes)

    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Flipping is its own inverse, so the same flip is applied again."""
        axes = tuple(self.dims)
        return tensor.flip(axes)
|
|
495
584
|
|
|
496
585
|
class Canonical(Transform):
    """Resample a volume onto the fixed direction matrix diag(-1, -1, 1).

    The forward pass rewrites "Direction" and "Origin" in ``cache_attribute``
    and keeps the volume's physical centre at the same world position.
    ``inverse`` pops those two entries and resamples back using the
    "Direction" that remains afterwards.
    """

    def __init__(self) -> None:
        self.canonical_direction = torch.diag(torch.tensor([-1, -1, 1])).to(torch.double)

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Rotate ``tensor`` into the canonical direction and update the geometry."""
        spacing = cache_attribute.get_tensor("Spacing")
        source_direction = cache_attribute.get_tensor("Direction").reshape(3, 3).to(torch.double)
        source_origin = cache_attribute.get_tensor("Origin")
        cache_attribute["Direction"] = (self.canonical_direction).flatten()
        rotation = self.canonical_direction @ source_direction.inverse()
        matrix = _affine_matrix(rotation, torch.tensor([0, 0, 0]))
        # Half physical extent per axis; axis 0 is read from the tensor's last dim.
        half_extent = torch.tensor(
            [(tensor.shape[-axis - 1] - 1) * spacing[axis] / 2 for axis in range(3)],
            dtype=torch.double,
        )
        centre_world = source_direction @ half_extent + source_origin
        # Pick the new origin so the centre maps to the same world point.
        cache_attribute["Origin"] = centre_world - (self.canonical_direction @ half_extent)
        return _resample_affine(tensor, matrix.unsqueeze(0))

    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Pop the geometry written by ``__call__`` and resample with the inverted rotation."""
        cache_attribute.pop("Direction")
        cache_attribute.pop("Origin")
        restored_direction = cache_attribute.get_tensor("Direction").to(torch.double).reshape(3, 3)
        forward_rotation = self.canonical_direction @ restored_direction.inverse()
        matrix = _affine_matrix(forward_rotation.inverse(), torch.tensor([0, 0, 0]))
        return _resample_affine(tensor, matrix.unsqueeze(0))
|
|
615
|
+
|
|
517
616
|
|
|
518
617
|
class HistogramMatching(Transform):
    """Match the intensity histogram of a case to its reference in another group.

    The reference image is looked up by the same ``name`` under
    ``reference_group`` in each of ``self.datasets``. If several datasets
    contain it, the last match is used.
    """

    def __init__(self, reference_group: str) -> None:
        self.reference_group = reference_group

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Return ``tensor`` after SimpleITK histogram matching against the reference.

        Raises:
            NameError: if no dataset holds ``reference_group/name``.
        """
        moving = data_to_image(tensor, cache_attribute)
        reference = None
        for candidate in self.datasets:
            if candidate.is_dataset_exist(self.reference_group, name):
                reference = candidate.read_image(self.reference_group, name)
        if reference is None:
            raise NameError(f"Image : {self.reference_group}/{name} not found")

        matcher = sitk.HistogramMatchingImageFilter()
        matcher.SetNumberOfHistogramLevels(256)
        matcher.SetNumberOfMatchPoints(1)
        matcher.SetThresholdAtMeanIntensity(True)
        matched = matcher.Execute(moving, reference)
        matched_data, _ = image_to_data(matched)
        return torch.tensor(matched_data)

    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Histogram matching is not inverted; the tensor is returned unchanged."""
        return tensor
|
|
639
|
+
|
|
540
640
|
|
|
541
641
|
class SelectLabel(Transform):
    """Remap chosen label values; any value not listed becomes 0."""

    def __init__(self, labels: list[str]) -> None:
        # Each entry has its first and last characters (presumably brackets)
        # stripped and is split on "," into an (old, new) pair.
        self.labels = [label[1:-1].split(",") for label in labels]

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Return a zero tensor shaped like ``tensor`` with each old label written as its new value."""
        remapped = torch.zeros_like(tensor)
        for source, target in self.labels:
            remapped[tensor == int(source)] = int(target)
        return remapped

    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """The remapping is lossy; the tensor is returned unchanged."""
        return tensor
|
|
654
|
+
|
|
655
|
+
|
|
554
656
|
class OneHot(Transform):
    """One-hot encode an integer label tensor into ``num_classes`` channels."""

    def __init__(self, num_classes: int) -> None:
        super().__init__()
        self.num_classes = num_classes

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Return a float one-hot tensor with the class axis placed at dim 1, then dim 0 squeezed."""
        rank = len(tensor.shape)
        encoded = F.one_hot(tensor.type(torch.int64), num_classes=self.num_classes)
        # one_hot appends the class axis last; move it right after dim 0.
        axis_order = (0, rank, *range(1, rank))
        return encoded.permute(*axis_order).float().squeeze(0)

    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        """Collapse class scores back to labels via argmax over dim 1, kept as size 1."""
        # NOTE(review): __call__ squeezes dim 0, so dim 1 here presumably assumes
        # a batched (B, C, ...) input -- TODO confirm against the caller.
        labels = torch.argmax(tensor, dim=1)
        return labels.unsqueeze(1)
|