konfai 1.1.8__py3-none-any.whl → 1.2.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- konfai/__init__.py +59 -14
- konfai/data/augmentation.py +457 -286
- konfai/data/data_manager.py +533 -316
- konfai/data/patching.py +300 -183
- konfai/data/transform.py +408 -275
- konfai/evaluator.py +325 -68
- konfai/main.py +71 -22
- konfai/metric/measure.py +360 -244
- konfai/metric/schedulers.py +24 -13
- konfai/models/classification/convNeXt.py +187 -81
- konfai/models/classification/resnet.py +272 -58
- konfai/models/generation/cStyleGan.py +233 -59
- konfai/models/generation/ddpm.py +348 -121
- konfai/models/generation/diffusionGan.py +757 -358
- konfai/models/generation/gan.py +177 -53
- konfai/models/generation/vae.py +140 -40
- konfai/models/registration/registration.py +135 -52
- konfai/models/representation/representation.py +57 -23
- konfai/models/segmentation/NestedUNet.py +339 -68
- konfai/models/segmentation/UNet.py +140 -30
- konfai/network/blocks.py +331 -187
- konfai/network/network.py +795 -427
- konfai/predictor.py +644 -238
- konfai/trainer.py +509 -222
- konfai/utils/ITK.py +191 -106
- konfai/utils/config.py +152 -95
- konfai/utils/dataset.py +326 -455
- konfai/utils/utils.py +497 -249
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/METADATA +1 -3
- konfai-1.2.0.dist-info/RECORD +38 -0
- konfai/utils/registration.py +0 -199
- konfai-1.1.8.dist-info/RECORD +0 -39
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/WHEEL +0 -0
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/entry_points.txt +0 -0
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/licenses/LICENSE +0 -0
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/top_level.txt +0 -0
konfai/utils/ITK.py
CHANGED
@@ -1,17 +1,20 @@
-import SimpleITK as sitk
-from typing import Union
 import numpy as np
-import torch
 import scipy
-import torch.nn.functional as F
+import SimpleITK as sitk  # noqa: N813
+import torch
+import torch.nn.functional as F  # noqa: N812
+
 from konfai.utils.utils import _resample
 
-
+
+def _open_transform(
+    transform_files: dict[str | sitk.Transform, bool], image: sitk.Image = None
+) -> list[sitk.Transform]:
     transforms: list[sitk.Transform] = []
 
     for transform_file, invert in transform_files.items():
         if isinstance(transform_file, str):
-            transform = sitk.ReadTransform(transform_file+".itk.txt")
+            transform = sitk.ReadTransform(transform_file + ".itk.txt")
         else:
             transform = transform_file
         if transform.GetName() == "TranslationTransform":
@@ -32,203 +35,282 @@ def _openTransform(transform_files: dict[Union[str, sitk.Transform], bool], imag
                 transform = sitk.AffineTransform(transform.GetInverse())
         elif transform.GetName() == "DisplacementFieldTransform":
             if invert:
-
-
-
-
-
-
-
+                transform_to_displacement_field_filter = sitk.TransformToDisplacementFieldFilter()
+                transform_to_displacement_field_filter.SetReferenceImage(image)
+                displacement_field = transform_to_displacement_field_filter.Execute(transform)
+                iterative_inverse_displacement_field_image_filter = sitk.IterativeInverseDisplacementFieldImageFilter()
+                iterative_inverse_displacement_field_image_filter.SetNumberOfIterations(20)
+                inverse_displacement_field = iterative_inverse_displacement_field_image_filter.Execute(
+                    displacement_field
+                )
+                transform = sitk.DisplacementFieldTransform(inverse_displacement_field)
             transforms.append(transform)
         else:
             transform = sitk.BSplineTransform(transform)
             if invert:
-
-
-
-
-
-
-
+                transform_to_displacement_field_filter = sitk.TransformToDisplacementFieldFilter()
+                transform_to_displacement_field_filter.SetReferenceImage(image)
+                displacement_field = transform_to_displacement_field_filter.Execute(transform)
+                iterative_inverse_displacement_field_image_filter = sitk.IterativeInverseDisplacementFieldImageFilter()
+                iterative_inverse_displacement_field_image_filter.SetNumberOfIterations(20)
+                inverse_displacement_field = iterative_inverse_displacement_field_image_filter.Execute(
+                    displacement_field
+                )
+                transform = sitk.DisplacementFieldTransform(inverse_displacement_field)
            transforms.append(transform)
     if len(transforms) == 0:
         transforms.append(sitk.Euler3DTransform())
     return transforms
 
-
-
+
+def _open_rigid_transform(transform_files: dict[str | sitk.Transform, bool]) -> tuple[np.ndarray, np.ndarray]:
+    transforms = _open_transform(transform_files)
     matrix_result = np.identity(3)
-    translation_result = np.array([0,0,0])
+    translation_result = np.array([0, 0, 0])
 
     for transform in transforms:
         if hasattr(transform, "GetMatrix"):
-            matrix = np.linalg.inv(np.array(transform.GetMatrix(), dtype=np.double).reshape((3,3)))
+            matrix = np.linalg.inv(np.array(transform.GetMatrix(), dtype=np.double).reshape((3, 3)))
             translation = -np.asarray(transform.GetTranslation(), dtype=np.double)
             center = np.asarray(transform.GetCenter(), dtype=np.double)
         else:
             matrix = np.eye(len(transform.GetOffset()))
             translation = -np.asarray(transform.GetOffset(), dtype=np.double)
-            center = np.asarray([0]*len(transform.GetOffset()), dtype=np.double)
-
-        translation_center = np.linalg.inv(matrix).dot(matrix.dot(translation-center)+center)
-        translation_result = np.linalg.inv(matrix_result).dot(translation_center)+translation_result
+            center = np.asarray([0] * len(transform.GetOffset()), dtype=np.double)
+
+        translation_center = np.linalg.inv(matrix).dot(matrix.dot(translation - center) + center)
+        translation_result = np.linalg.inv(matrix_result).dot(translation_center) + translation_result
         matrix_result = matrix.dot(matrix_result)
     return np.linalg.inv(matrix_result), -translation_result
 
-
-
+
+def compose_transform(
+    transform_files: dict[str | sitk.Transform, bool], image: sitk.Image = None
+) -> sitk.CompositeTransform:
+    transforms = _open_transform(transform_files, image)
     result = sitk.CompositeTransform(transforms)
     return result
 
-
-
+
+def flatten_transform(transform_files: dict[str | sitk.Transform, bool]) -> sitk.AffineTransform:
+    [matrix, translation] = _open_rigid_transform(transform_files)
     transform = sitk.AffineTransform(3)
     transform.SetMatrix(matrix.flatten())
     transform.SetTranslation(translation)
     return transform
 
-
-
+
+def apply_to_image_rigid_transform(image: sitk.Image, transform_files: dict[str | sitk.Transform, bool]) -> sitk.Image:
+    [matrix, translation] = _open_rigid_transform(transform_files)
     matrix = np.linalg.inv(matrix)
     translation = -translation
     data = sitk.GetArrayFromImage(image)
     result = sitk.GetImageFromArray(data)
-    result.SetDirection(matrix.dot(np.array(image.GetDirection()).reshape((3,3))).flatten())
-    result.SetOrigin(matrix.dot(np.array(image.GetOrigin())+translation))
+    result.SetDirection(matrix.dot(np.array(image.GetDirection()).reshape((3, 3))).flatten())
+    result.SetOrigin(matrix.dot(np.array(image.GetOrigin()) + translation))
     result.SetSpacing(image.GetSpacing())
     return result
 
-
-
+
+def apply_to_data_transform(data: np.ndarray, transform_files: dict[str | sitk.Transform, bool]) -> sitk.Image:
+    transforms = compose_transform(transform_files)
     result = np.copy(data)
-    _LPS = lambda matrix: np.array([-matrix[0], -matrix[1], matrix[2]], dtype=np.double)
+    # _LPS = lambda matrix: np.array([-matrix[0], -matrix[1], matrix[2]], dtype=np.double)
     for i in range(data.shape[0]):
-        result[i, :] =
+        result[i, :] = transforms.TransformPoint(np.asarray(data[i, :], dtype=np.double))
     return result
 
-
+
+def resample_itk(
+    image_reference: sitk.Image,
+    image: sitk.Image,
+    transform_files: dict[str | sitk.Transform, bool],
+    mask=False,
+    default_pixel_value: float | None = None,
+    torch_resample: bool = False,
+) -> sitk.Image:
     if torch_resample:
-
-        vectors = [torch.arange(0, s) for s in
-        grids = torch.meshgrid(vectors, indexing=
+        input_tensor = torch.tensor(sitk.GetArrayFromImage(image)).unsqueeze(0)
+        vectors = [torch.arange(0, s) for s in input_tensor.shape[1:]]
+        grids = torch.meshgrid(vectors, indexing="ij")
         grid = torch.stack(grids)
         grid = torch.unsqueeze(grid, 0)
-
-
-
-        new_locs = grid + torch.tensor(
+        transform_to_displacement_field_filter = sitk.TransformToDisplacementFieldFilter()
+        transform_to_displacement_field_filter.SetReferenceImage(image)
+        transform_to_displacement_field_filter.SetNumberOfThreads(16)
+        new_locs = grid + torch.tensor(
+            sitk.GetArrayFromImage(
+                transform_to_displacement_field_filter.Execute(compose_transform(transform_files, image))
+            )
+        ).unsqueeze(0).permute(0, 4, 1, 2, 3)
         shape = new_locs.shape[2:]
         for i in range(len(shape)):
             new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (shape[i] - 1) - 0.5)
         new_locs = new_locs.permute(0, 2, 3, 4, 1)
         new_locs = new_locs[..., [2, 1, 0]]
-        result_data = F.grid_sample(
-
+        result_data = F.grid_sample(
+            input_tensor.unsqueeze(0).float(),
+            new_locs.float(),
+            align_corners=True,
+            padding_mode="border",
+            mode="nearest" if input_tensor.dtype == torch.uint8 else "bilinear",
+        ).squeeze(0)
+        result_data = result_data.type(torch.uint8) if input_tensor.dtype == torch.uint8 else result_data
         result = sitk.GetImageFromArray(result_data.squeeze(0).numpy())
         result.CopyInformation(image_reference)
         return result
     else:
-        return sitk.Resample(
+        return sitk.Resample(
+            image,
+            image_reference,
+            compose_transform(transform_files, image),
+            sitk.sitkNearestNeighbor if mask else sitk.sitkBSpline,
+            (
+                default_pixel_value
+                if default_pixel_value is not None
+                else (0 if mask else int(np.min(sitk.GetArrayFromImage(image))))
+            ),
+        )
+
+
+def parametermap_to_transform(
+    path_src: str,
+) -> sitk.Transform | list[sitk.Transform]:
+    transform = sitk.ReadParameterFile(path_src)
+
+    def array_format(x):
+        return [float(i) for i in x]
 
-
-    transform = sitk.ReadParameterFile("{}.0.txt".format(path_src))
-    format = lambda x: np.array([float(i) for i in x])
+    dimension = int(transform["FixedImageDimension"][0])
 
     if transform["Transform"][0] == "EulerTransform":
-
-
-
+        if dimension == 2:
+            result = sitk.Euler2DTransform()
+        else:
+            result = sitk.Euler3DTransform()
+        parameters = array_format(transform["TransformParameters"])
+        fixed_parameters = array_format(transform["CenterOfRotationPoint"]) + [0]
+    elif transform["Transform"][0] == "TranslationTransform":
+        result = sitk.TranslationTransform(dimension)
+        parameters = array_format(transform["TransformParameters"])
+        fixed_parameters = []
     elif transform["Transform"][0] == "AffineTransform":
-        result = sitk.AffineTransform(
-        parameters =
-
+        result = sitk.AffineTransform(dimension)
+        parameters = array_format(transform["TransformParameters"])
+        fixed_parameters = array_format(transform["CenterOfRotationPoint"]) + [0]
     elif transform["Transform"][0] == "BSplineStackTransform":
-        parameters =
-
-
-
-
-
-
-
-
+        parameters = array_format(transform["TransformParameters"])
+        grid_size = array_format(transform["GridSize"])
+        grid_origin = array_format(transform["GridOrigin"])
+        grid_spacing = array_format(transform["GridSpacing"])
+        grid_direction = (
+            np.asarray(array_format(transform["GridDirection"])).reshape((dimension, dimension)).T.flatten()
+        )
+        fixed_parameters = np.concatenate([grid_size, grid_origin, grid_spacing, grid_direction])
+
+        nb = int(array_format(transform["Size"])[-1])
+        sub = int(np.prod(grid_size)) * dimension
         results = []
         for i in range(nb):
-            result = sitk.BSplineTransform(
-            sub_parameters = parameters[i*sub:(i+1)*sub]
-            result.SetFixedParameters(
+            result = sitk.BSplineTransform(dimension)
+            sub_parameters = np.asarray(parameters[i * sub : (i + 1) * sub])
+            result.SetFixedParameters(fixed_parameters)
             result.SetParameters(sub_parameters)
             results.append(result)
         return results
     elif transform["Transform"][0] == "AffineLogStackTransform":
-        parameters =
-
+        parameters = array_format(transform["TransformParameters"])
+        fixed_parameters = array_format(transform["CenterOfRotationPoint"]) + [0]
 
         nb = int(transform["NumberOfSubTransforms"][0])
-        sub =
+        sub = dimension * 4
         results = []
         for i in range(nb):
-            result = sitk.AffineTransform(
-            sub_parameters = parameters[i*sub:(i+1)*sub]
+            result = sitk.AffineTransform(dimension)
+            sub_parameters = np.asarray(parameters[i * sub : (i + 1) * sub])
 
-            result.SetFixedParameters(
-            result.SetParameters(
+            result.SetFixedParameters(fixed_parameters)
+            result.SetParameters(
+                np.concatenate(
+                    [
+                        scipy.linalg.expm(
+                            sub_parameters[: dimension * dimension].reshape((dimension, dimension))
+                        ).flatten(),
+                        sub_parameters[-dimension:],
+                    ]
+                )
+            )
             results.append(result)
         return results
     elif transform["Transform"][0] == "BSplineTransform":
-        result = sitk.BSplineTransform(
-
-        parameters =
-
-
-
-
-
+        result = sitk.BSplineTransform(dimension)
+
+        parameters = array_format(transform["TransformParameters"])
+        grid_size = array_format(transform["GridSize"])
+        grid_origin = array_format(transform["GridOrigin"])
+        grid_spacing = array_format(transform["GridSpacing"])
+        grid_direction = np.array(array_format(transform["GridDirection"])).reshape((dimension, dimension)).T.flatten()
+        fixed_parameters = np.concatenate([grid_size, grid_origin, grid_spacing, grid_direction])
     else:
-        raise NameError("Transform {} doesn't exist"
-    result.SetFixedParameters(
+        raise NameError(f"Transform {transform['Transform'][0]} doesn't exist")
+    result.SetFixedParameters(fixed_parameters)
     result.SetParameters(parameters)
     return result
 
-
-
-
+
+def resample_isotropic(image: sitk.Image, spacing: list[float] | None = None) -> sitk.Image:
+    spacing = spacing or [1.0, 1.0, 1.0]
+    resize_factor = [y / x for x, y in zip(spacing, image.GetSpacing())]
+    result = sitk.GetImageFromArray(
+        _resample(
+            torch.tensor(sitk.GetArrayFromImage(image)).unsqueeze(0),
+            [int(size * factor) for size, factor in zip(image.GetSize(), resize_factor)],
+        )
+        .squeeze(0)
+        .numpy()
+    )
     result.SetDirection(image.GetDirection())
     result.SetOrigin(image.GetOrigin())
     result.SetSpacing(spacing)
     return result
 
-
-
+
+def resample_resize(image: sitk.Image, size: list[int] | None = None):
+    size = size or [100, 512, 512]
+    result = sitk.GetImageFromArray(
+        _resample(torch.tensor(sitk.GetArrayFromImage(image)).unsqueeze(0), size).squeeze(0).numpy()
+    )
     result.SetDirection(image.GetDirection())
     result.SetOrigin(image.GetOrigin())
-    result.SetSpacing([x/y*z for x,y,z in zip(image.GetSize(), size, image.GetSpacing())])
+    result.SetSpacing([x / y * z for x, y, z in zip(image.GetSize(), size, image.GetSpacing())])
     return result
 
+
 def box_with_mask(mask: sitk.Image, label: list[int], dilatations: list[int]) -> np.ndarray:
 
-    dilatations = [int(np.ceil(d/s)) for d, s in zip(dilatations, reversed(mask.GetSpacing()))]
+    dilatations = [int(np.ceil(d / s)) for d, s in zip(dilatations, reversed(mask.GetSpacing()))]
 
     data = sitk.GetArrayFromImage(mask)
     border = np.where(np.isin(sitk.GetArrayFromImage(mask), label))
     box = []
     for w, dilatation, s in zip(border, dilatations, data.shape):
-        box.append([max(np.min(w)-dilatation, 0), min(np.max(w)+dilatation, s)])
+        box.append([max(np.min(w) - dilatation, 0), min(np.max(w) + dilatation, s)])
     box = np.asarray(box)
     return box
 
+
 def crop_with_mask(image: sitk.Image, box: np.ndarray) -> sitk.Image:
     data = sitk.GetArrayFromImage(image)
-
+
     for i, w in enumerate(box):
         data = np.delete(data, slice(w[1], data.shape[i]), i)
         data = np.delete(data, slice(0, w[0]), i)
-
+
     origin = np.asarray(image.GetOrigin())
     matrix = np.asarray(image.GetDirection()).reshape((len(origin), len(origin)))
     origin = origin.dot(matrix)
     for i, w in enumerate(box):
-        origin[-i-1] += w[0]*np.asarray(image.GetSpacing())[-i-1]
+        origin[-i - 1] += w[0] * np.asarray(image.GetSpacing())[-i - 1]
     origin = origin.dot(np.linalg.inv(matrix))
 
     result = sitk.GetImageFromArray(data)
@@ -237,7 +319,8 @@ def crop_with_mask(image: sitk.Image, box: np.ndarray) -> sitk.Image:
     result.SetDirection(image.GetDirection())
     return result
 
-
+
+def format_mask_label(mask: sitk.Image, labels: list[tuple[int, int]]) -> sitk.Image:
     data = sitk.GetArrayFromImage(mask)
     result_data = np.zeros_like(data, np.uint8)
 
@@ -248,7 +331,8 @@ def formatMaskLabel(mask: sitk.Image, labels: list[tuple[int, int]]) -> sitk.Ima
     result.CopyInformation(mask)
     return result
 
-
+
+def get_flat_label(mask: sitk.Image, labels: None | list[int] = None) -> sitk.Image:
     data = sitk.GetArrayFromImage(mask)
     result_data = np.zeros_like(data, np.uint8)
     if labels is not None:
@@ -257,13 +341,14 @@ def getFlatLabel(mask: sitk.Image, labels: Union[None, list[int]] = None) -> sit
     else:
         result_data[np.where(data > 0)] = 1
     result = sitk.GetImageFromArray(result_data)
-    result.CopyInformation(mask)
+    result.CopyInformation(mask)
     return result
 
-
+
+def clip_and_cast(image: sitk.Image, min_value: float, max_value: float, dtype: np.dtype) -> sitk.Image:
    data = sitk.GetArrayFromImage(image)
-    data[np.where(data >
-    data[np.where(data <
+    data[np.where(data > max_value)] = max_value
+    data[np.where(data < min_value)] = min_value
     result = sitk.GetImageFromArray(data.astype(dtype))
     result.CopyInformation(image)
     return result
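
The diff above renames the module's public helpers from camelCase (e.g. _openTransform, formatMaskLabel, getFlatLabel) to snake_case (compose_transform, resample_itk, resample_isotropic, clip_and_cast, ...) and reformats the file. The following is a minimal, hypothetical usage sketch of the renamed 1.2.0 API based only on the signatures visible in this diff; the file names are illustrative and not taken from konfai's documentation, and string keys in the transform dict are resolved as "<path>.itk.txt" by _open_transform.

import numpy as np
import SimpleITK as sitk

from konfai.utils.ITK import clip_and_cast, compose_transform, resample_isotropic, resample_itk

# Hypothetical inputs: a moving image and the reference grid to resample onto.
moving = sitk.ReadImage("moving.nii.gz")
reference = sitk.ReadImage("reference.nii.gz")

# {transform file or sitk.Transform: invert?} -> sitk.CompositeTransform
# ("transforms/affine" is read from "transforms/affine.itk.txt").
composite = compose_transform({"transforms/affine": False}, image=moving)
sitk.WriteTransform(composite, "composed.tfm")  # hypothetical output path

# Resample `moving` onto `reference` through the same transform dict;
# mask=True would switch to nearest-neighbour interpolation with a zero default value.
resampled = resample_itk(reference, moving, {"transforms/affine": False}, mask=False)

# Resample to the default 1 mm isotropic spacing, then clip intensities and cast.
iso = resample_isotropic(resampled)
clipped = clip_and_cast(iso, min_value=-1024, max_value=3071, dtype=np.int16)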