konfai 1.1.8__py3-none-any.whl → 1.2.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.


Files changed (36)
  1. konfai/__init__.py +59 -14
  2. konfai/data/augmentation.py +457 -286
  3. konfai/data/data_manager.py +533 -316
  4. konfai/data/patching.py +300 -183
  5. konfai/data/transform.py +408 -275
  6. konfai/evaluator.py +325 -68
  7. konfai/main.py +71 -22
  8. konfai/metric/measure.py +360 -244
  9. konfai/metric/schedulers.py +24 -13
  10. konfai/models/classification/convNeXt.py +187 -81
  11. konfai/models/classification/resnet.py +272 -58
  12. konfai/models/generation/cStyleGan.py +233 -59
  13. konfai/models/generation/ddpm.py +348 -121
  14. konfai/models/generation/diffusionGan.py +757 -358
  15. konfai/models/generation/gan.py +177 -53
  16. konfai/models/generation/vae.py +140 -40
  17. konfai/models/registration/registration.py +135 -52
  18. konfai/models/representation/representation.py +57 -23
  19. konfai/models/segmentation/NestedUNet.py +339 -68
  20. konfai/models/segmentation/UNet.py +140 -30
  21. konfai/network/blocks.py +331 -187
  22. konfai/network/network.py +795 -427
  23. konfai/predictor.py +644 -238
  24. konfai/trainer.py +509 -222
  25. konfai/utils/ITK.py +191 -106
  26. konfai/utils/config.py +152 -95
  27. konfai/utils/dataset.py +326 -455
  28. konfai/utils/utils.py +497 -249
  29. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/METADATA +1 -3
  30. konfai-1.2.0.dist-info/RECORD +38 -0
  31. konfai/utils/registration.py +0 -199
  32. konfai-1.1.8.dist-info/RECORD +0 -39
  33. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/WHEEL +0 -0
  34. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/entry_points.txt +0 -0
  35. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/licenses/LICENSE +0 -0
  36. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/top_level.txt +0 -0
konfai/utils/ITK.py CHANGED
@@ -1,17 +1,20 @@
- import SimpleITK as sitk
- from typing import Union
  import numpy as np
- import torch
  import scipy
- import torch.nn.functional as F
+ import SimpleITK as sitk  # noqa: N813
+ import torch
+ import torch.nn.functional as F  # noqa: N812
+
  from konfai.utils.utils import _resample

- def _openTransform(transform_files: dict[Union[str, sitk.Transform], bool], image: sitk.Image= None) -> list[sitk.Transform]:
+
+ def _open_transform(
+     transform_files: dict[str | sitk.Transform, bool], image: sitk.Image = None
+ ) -> list[sitk.Transform]:
      transforms: list[sitk.Transform] = []

      for transform_file, invert in transform_files.items():
          if isinstance(transform_file, str):
-             transform = sitk.ReadTransform(transform_file+".itk.txt")
+             transform = sitk.ReadTransform(transform_file + ".itk.txt")
          else:
              transform = transform_file
          if transform.GetName() == "TranslationTransform":
@@ -32,203 +35,282 @@ def _openTransform(transform_files: dict[Union[str, sitk.Transform], bool], imag
                  transform = sitk.AffineTransform(transform.GetInverse())
          elif transform.GetName() == "DisplacementFieldTransform":
              if invert:
-                 transformToDisplacementFieldFilter = sitk.TransformToDisplacementFieldFilter()
-                 transformToDisplacementFieldFilter.SetReferenceImage(image)
-                 displacementField = transformToDisplacementFieldFilter.Execute(transform)
-                 iterativeInverseDisplacementFieldImageFilter = sitk.IterativeInverseDisplacementFieldImageFilter()
-                 iterativeInverseDisplacementFieldImageFilter.SetNumberOfIterations(20)
-                 inverseDisplacementField = iterativeInverseDisplacementFieldImageFilter.Execute(displacementField)
-                 transform = sitk.DisplacementFieldTransform(inverseDisplacementField)
+                 transform_to_displacement_field_filter = sitk.TransformToDisplacementFieldFilter()
+                 transform_to_displacement_field_filter.SetReferenceImage(image)
+                 displacement_field = transform_to_displacement_field_filter.Execute(transform)
+                 iterative_inverse_displacement_field_image_filter = sitk.IterativeInverseDisplacementFieldImageFilter()
+                 iterative_inverse_displacement_field_image_filter.SetNumberOfIterations(20)
+                 inverse_displacement_field = iterative_inverse_displacement_field_image_filter.Execute(
+                     displacement_field
+                 )
+                 transform = sitk.DisplacementFieldTransform(inverse_displacement_field)
              transforms.append(transform)
          else:
              transform = sitk.BSplineTransform(transform)
              if invert:
-                 transformToDisplacementFieldFilter = sitk.TransformToDisplacementFieldFilter()
-                 transformToDisplacementFieldFilter.SetReferenceImage(image)
-                 displacementField = transformToDisplacementFieldFilter.Execute(transform)
-                 iterativeInverseDisplacementFieldImageFilter = sitk.IterativeInverseDisplacementFieldImageFilter()
-                 iterativeInverseDisplacementFieldImageFilter.SetNumberOfIterations(20)
-                 inverseDisplacementField = iterativeInverseDisplacementFieldImageFilter.Execute(displacementField)
-                 transform = sitk.DisplacementFieldTransform(inverseDisplacementField)
+                 transform_to_displacement_field_filter = sitk.TransformToDisplacementFieldFilter()
+                 transform_to_displacement_field_filter.SetReferenceImage(image)
+                 displacement_field = transform_to_displacement_field_filter.Execute(transform)
+                 iterative_inverse_displacement_field_image_filter = sitk.IterativeInverseDisplacementFieldImageFilter()
+                 iterative_inverse_displacement_field_image_filter.SetNumberOfIterations(20)
+                 inverse_displacement_field = iterative_inverse_displacement_field_image_filter.Execute(
+                     displacement_field
+                 )
+                 transform = sitk.DisplacementFieldTransform(inverse_displacement_field)
              transforms.append(transform)
      if len(transforms) == 0:
          transforms.append(sitk.Euler3DTransform())
      return transforms

- def _openRigidTransform(transform_files: dict[Union[str, sitk.Transform], bool]) -> tuple[np.ndarray, np.ndarray]:
-     transforms = _openTransform(transform_files)
+
+ def _open_rigid_transform(transform_files: dict[str | sitk.Transform, bool]) -> tuple[np.ndarray, np.ndarray]:
+     transforms = _open_transform(transform_files)
      matrix_result = np.identity(3)
-     translation_result = np.array([0,0,0])
+     translation_result = np.array([0, 0, 0])

      for transform in transforms:
          if hasattr(transform, "GetMatrix"):
-             matrix = np.linalg.inv(np.array(transform.GetMatrix(), dtype=np.double).reshape((3,3)))
+             matrix = np.linalg.inv(np.array(transform.GetMatrix(), dtype=np.double).reshape((3, 3)))
              translation = -np.asarray(transform.GetTranslation(), dtype=np.double)
              center = np.asarray(transform.GetCenter(), dtype=np.double)
          else:
              matrix = np.eye(len(transform.GetOffset()))
              translation = -np.asarray(transform.GetOffset(), dtype=np.double)
-             center = np.asarray([0]*len(transform.GetOffset()), dtype=np.double)
-
-         translation_center = np.linalg.inv(matrix).dot(matrix.dot(translation-center)+center)
-         translation_result = np.linalg.inv(matrix_result).dot(translation_center)+translation_result
+             center = np.asarray([0] * len(transform.GetOffset()), dtype=np.double)
+
+         translation_center = np.linalg.inv(matrix).dot(matrix.dot(translation - center) + center)
+         translation_result = np.linalg.inv(matrix_result).dot(translation_center) + translation_result
          matrix_result = matrix.dot(matrix_result)
      return np.linalg.inv(matrix_result), -translation_result

- def composeTransform(transform_files : dict[Union[str, sitk.Transform], bool], image : sitk.Image = None) -> None:#sitk.CompositeTransform:
-     transforms = _openTransform(transform_files, image)
+
+ def compose_transform(
+     transform_files: dict[str | sitk.Transform, bool], image: sitk.Image = None
+ ) -> sitk.CompositeTransform:
+     transforms = _open_transform(transform_files, image)
      result = sitk.CompositeTransform(transforms)
      return result

- def flattenTransform(transform_files: dict[Union[str, sitk.Transform], bool]) -> sitk.AffineTransform:
-     [matrix, translation] = _openRigidTransform(transform_files)
+
+ def flatten_transform(transform_files: dict[str | sitk.Transform, bool]) -> sitk.AffineTransform:
+     [matrix, translation] = _open_rigid_transform(transform_files)
      transform = sitk.AffineTransform(3)
      transform.SetMatrix(matrix.flatten())
      transform.SetTranslation(translation)
      return transform

- def apply_to_image_RigidTransform(image: sitk.Image, transform_files: dict[Union[str, sitk.Transform], bool]) -> sitk.Image:
-     [matrix, translation] = _openRigidTransform(transform_files)
+
+ def apply_to_image_rigid_transform(image: sitk.Image, transform_files: dict[str | sitk.Transform, bool]) -> sitk.Image:
+     [matrix, translation] = _open_rigid_transform(transform_files)
      matrix = np.linalg.inv(matrix)
      translation = -translation
      data = sitk.GetArrayFromImage(image)
      result = sitk.GetImageFromArray(data)
-     result.SetDirection(matrix.dot(np.array(image.GetDirection()).reshape((3,3))).flatten())
-     result.SetOrigin(matrix.dot(np.array(image.GetOrigin())+translation))
+     result.SetDirection(matrix.dot(np.array(image.GetDirection()).reshape((3, 3))).flatten())
+     result.SetOrigin(matrix.dot(np.array(image.GetOrigin()) + translation))
      result.SetSpacing(image.GetSpacing())
      return result

- def apply_to_data_Transform(data: np.ndarray, transform_files: dict[Union[str, sitk.Transform], bool]) -> sitk.Image:
-     transforms = composeTransform(transform_files)
+
+ def apply_to_data_transform(data: np.ndarray, transform_files: dict[str | sitk.Transform, bool]) -> sitk.Image:
+     transforms = compose_transform(transform_files)
      result = np.copy(data)
-     _LPS = lambda matrix: np.array([-matrix[0], -matrix[1], matrix[2]], dtype=np.double)
+     # _LPS = lambda matrix: np.array([-matrix[0], -matrix[1], matrix[2]], dtype=np.double)
      for i in range(data.shape[0]):
-         result[i, :] = _LPS(transforms.TransformPoint(np.asarray(_LPS(data[i, :]), dtype=np.double)))
+         result[i, :] = transforms.TransformPoint(np.asarray(data[i, :], dtype=np.double))
      return result

- def resampleITK(image_reference : sitk.Image, image : sitk.Image, transform_files : dict[Union[str, sitk.Transform], bool], mask = False, defaultPixelValue: Union[float, None] = None, torch_resample : bool = False) -> sitk.Image:
+
+ def resample_itk(
+     image_reference: sitk.Image,
+     image: sitk.Image,
+     transform_files: dict[str | sitk.Transform, bool],
+     mask=False,
+     default_pixel_value: float | None = None,
+     torch_resample: bool = False,
+ ) -> sitk.Image:
      if torch_resample:
-         input = torch.tensor(sitk.GetArrayFromImage(image)).unsqueeze(0)
-         vectors = [torch.arange(0, s) for s in input.shape[1:]]
-         grids = torch.meshgrid(vectors, indexing='ij')
+         input_tensor = torch.tensor(sitk.GetArrayFromImage(image)).unsqueeze(0)
+         vectors = [torch.arange(0, s) for s in input_tensor.shape[1:]]
+         grids = torch.meshgrid(vectors, indexing="ij")
          grid = torch.stack(grids)
          grid = torch.unsqueeze(grid, 0)
-         transformToDisplacementFieldFilter = sitk.TransformToDisplacementFieldFilter()
-         transformToDisplacementFieldFilter.SetReferenceImage(image)
-         transformToDisplacementFieldFilter.SetNumberOfThreads(16)
-         new_locs = grid + torch.tensor(sitk.GetArrayFromImage(transformToDisplacementFieldFilter.Execute(composeTransform(transform_files, image)))).unsqueeze(0).permute(0, 4, 1, 2, 3)
+         transform_to_displacement_field_filter = sitk.TransformToDisplacementFieldFilter()
+         transform_to_displacement_field_filter.SetReferenceImage(image)
+         transform_to_displacement_field_filter.SetNumberOfThreads(16)
+         new_locs = grid + torch.tensor(
+             sitk.GetArrayFromImage(
+                 transform_to_displacement_field_filter.Execute(compose_transform(transform_files, image))
+             )
+         ).unsqueeze(0).permute(0, 4, 1, 2, 3)
          shape = new_locs.shape[2:]
         for i in range(len(shape)):
              new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (shape[i] - 1) - 0.5)
         new_locs = new_locs.permute(0, 2, 3, 4, 1)
         new_locs = new_locs[..., [2, 1, 0]]
-         result_data = F.grid_sample(input.unsqueeze(0).float(), new_locs.float(), align_corners=True, padding_mode="border", mode="nearest" if input.dtype == torch.uint8 else "bilinear").squeeze(0)
-         result_data = result_data.type(torch.uint8) if input.dtype == torch.uint8 else result_data
+         result_data = F.grid_sample(
+             input_tensor.unsqueeze(0).float(),
+             new_locs.float(),
+             align_corners=True,
+             padding_mode="border",
+             mode="nearest" if input_tensor.dtype == torch.uint8 else "bilinear",
+         ).squeeze(0)
+         result_data = result_data.type(torch.uint8) if input_tensor.dtype == torch.uint8 else result_data
         result = sitk.GetImageFromArray(result_data.squeeze(0).numpy())
         result.CopyInformation(image_reference)
         return result
     else:
-         return sitk.Resample(image, image_reference, composeTransform(transform_files, image), sitk.sitkNearestNeighbor if mask else sitk.sitkBSpline, (defaultPixelValue if defaultPixelValue is not None else (0 if mask else int(np.min(sitk.GetArrayFromImage(image))))))
+         return sitk.Resample(
+             image,
+             image_reference,
+             compose_transform(transform_files, image),
+             sitk.sitkNearestNeighbor if mask else sitk.sitkBSpline,
+             (
+                 default_pixel_value
+                 if default_pixel_value is not None
+                 else (0 if mask else int(np.min(sitk.GetArrayFromImage(image))))
+             ),
+         )
+
+
+ def parametermap_to_transform(
+     path_src: str,
+ ) -> sitk.Transform | list[sitk.Transform]:
+     transform = sitk.ReadParameterFile(path_src)
+
+     def array_format(x):
+         return [float(i) for i in x]

- def parameterMap_to_transform(path_src: str) -> Union[sitk.Transform, list[sitk.Transform]]:
-     transform = sitk.ReadParameterFile("{}.0.txt".format(path_src))
-     format = lambda x: np.array([float(i) for i in x])
+     dimension = int(transform["FixedImageDimension"][0])

     if transform["Transform"][0] == "EulerTransform":
-         result = sitk.Euler3DTransform()
-         parameters = format(transform["TransformParameters"])
-         fixedParameters = format(transform["CenterOfRotationPoint"])+[0]
+         if dimension == 2:
+             result = sitk.Euler2DTransform()
+         else:
+             result = sitk.Euler3DTransform()
+         parameters = array_format(transform["TransformParameters"])
+         fixed_parameters = array_format(transform["CenterOfRotationPoint"]) + [0]
+     elif transform["Transform"][0] == "TranslationTransform":
+         result = sitk.TranslationTransform(dimension)
+         parameters = array_format(transform["TransformParameters"])
+         fixed_parameters = []
     elif transform["Transform"][0] == "AffineTransform":
-         result = sitk.AffineTransform(3)
-         parameters = format(transform["TransformParameters"])
-         fixedParameters = format(transform["CenterOfRotationPoint"])+[0]
+         result = sitk.AffineTransform(dimension)
+         parameters = array_format(transform["TransformParameters"])
+         fixed_parameters = array_format(transform["CenterOfRotationPoint"]) + [0]
     elif transform["Transform"][0] == "BSplineStackTransform":
-         parameters = format(transform["TransformParameters"])
-         GridSize = format(transform["GridSize"])
-         GridOrigin = format(transform["GridOrigin"])
-         GridSpacing = format(transform["GridSpacing"])
-         GridDirection = format(transform["GridDirection"]).reshape((3,3)).T.flatten()
-         fixedParameters = np.concatenate([GridSize, GridOrigin, GridSpacing, GridDirection])
-
-         nb = int(format(transform["Size"])[-1])
-         sub = int(np.prod(GridSize))*3
+         parameters = array_format(transform["TransformParameters"])
+         grid_size = array_format(transform["GridSize"])
+         grid_origin = array_format(transform["GridOrigin"])
+         grid_spacing = array_format(transform["GridSpacing"])
+         grid_direction = (
+             np.asarray(array_format(transform["GridDirection"])).reshape((dimension, dimension)).T.flatten()
+         )
+         fixed_parameters = np.concatenate([grid_size, grid_origin, grid_spacing, grid_direction])
+
+         nb = int(array_format(transform["Size"])[-1])
+         sub = int(np.prod(grid_size)) * dimension
         results = []
         for i in range(nb):
-             result = sitk.BSplineTransform(3)
-             sub_parameters = parameters[i*sub:(i+1)*sub]
-             result.SetFixedParameters(fixedParameters)
+             result = sitk.BSplineTransform(dimension)
+             sub_parameters = np.asarray(parameters[i * sub : (i + 1) * sub])
+             result.SetFixedParameters(fixed_parameters)
             result.SetParameters(sub_parameters)
             results.append(result)
         return results
     elif transform["Transform"][0] == "AffineLogStackTransform":
-         parameters = format(transform["TransformParameters"])
-         fixedParameters = format(transform["CenterOfRotationPoint"])+[0]
+         parameters = array_format(transform["TransformParameters"])
+         fixed_parameters = array_format(transform["CenterOfRotationPoint"]) + [0]

         nb = int(transform["NumberOfSubTransforms"][0])
-         sub = 12
+         sub = dimension * 4
         results = []
         for i in range(nb):
-             result = sitk.AffineTransform(3)
-             sub_parameters = parameters[i*sub:(i+1)*sub]
+             result = sitk.AffineTransform(dimension)
+             sub_parameters = np.asarray(parameters[i * sub : (i + 1) * sub])

-             result.SetFixedParameters(fixedParameters)
-             result.SetParameters(np.concatenate([scipy.linalg.expm(sub_parameters[:9].reshape((3,3))).flatten(), sub_parameters[-3:]]))
+             result.SetFixedParameters(fixed_parameters)
+             result.SetParameters(
+                 np.concatenate(
+                     [
+                         scipy.linalg.expm(
+                             sub_parameters[: dimension * dimension].reshape((dimension, dimension))
+                         ).flatten(),
+                         sub_parameters[-dimension:],
+                     ]
+                 )
+             )
             results.append(result)
         return results
     elif transform["Transform"][0] == "BSplineTransform":
-         result = sitk.BSplineTransform(3)
-
-         parameters = format(transform["TransformParameters"])
-         GridSize = format(transform["GridSize"])
-         GridOrigin = format(transform["GridOrigin"])
-         GridSpacing = format(transform["GridSpacing"])
-         GridDirection = np.array(format(transform["GridDirection"])).reshape((3,3)).T.flatten()
-         fixedParameters = np.concatenate([GridSize, GridOrigin, GridSpacing, GridDirection])
+         result = sitk.BSplineTransform(dimension)
+
+         parameters = array_format(transform["TransformParameters"])
+         grid_size = array_format(transform["GridSize"])
+         grid_origin = array_format(transform["GridOrigin"])
+         grid_spacing = array_format(transform["GridSpacing"])
+         grid_direction = np.array(array_format(transform["GridDirection"])).reshape((dimension, dimension)).T.flatten()
+         fixed_parameters = np.concatenate([grid_size, grid_origin, grid_spacing, grid_direction])
     else:
-         raise NameError("Transform {} doesn't exist".format(transform["Transform"][0]))
-     result.SetFixedParameters(fixedParameters)
+         raise NameError(f"Transform {transform['Transform'][0]} doesn't exist")
+     result.SetFixedParameters(fixed_parameters)
     result.SetParameters(parameters)
     return result

- def resampleIsotropic(image: sitk.Image, spacing : list[float] = [1., 1., 1.]) -> sitk.Image:
-     resize_factor = [y/x for x,y in zip(spacing, image.GetSpacing())]
-     result = sitk.GetImageFromArray(_resample(torch.tensor(sitk.GetArrayFromImage(image)).unsqueeze(0), [int(size*factor) for size, factor in zip(image.GetSize(), resize_factor)]).squeeze(0).numpy())
+
+ def resample_isotropic(image: sitk.Image, spacing: list[float] | None = None) -> sitk.Image:
+     spacing = spacing or [1.0, 1.0, 1.0]
+     resize_factor = [y / x for x, y in zip(spacing, image.GetSpacing())]
+     result = sitk.GetImageFromArray(
+         _resample(
+             torch.tensor(sitk.GetArrayFromImage(image)).unsqueeze(0),
+             [int(size * factor) for size, factor in zip(image.GetSize(), resize_factor)],
+         )
+         .squeeze(0)
+         .numpy()
+     )
     result.SetDirection(image.GetDirection())
     result.SetOrigin(image.GetOrigin())
     result.SetSpacing(spacing)
     return result

- def resampleResize(image: sitk.Image, size : list[int] = [100,512,512]):
-     result = sitk.GetImageFromArray(_resample(torch.tensor(sitk.GetArrayFromImage(image)).unsqueeze(0), size).squeeze(0).numpy())
+
+ def resample_resize(image: sitk.Image, size: list[int] | None = None):
+     size = size or [100, 512, 512]
+     result = sitk.GetImageFromArray(
+         _resample(torch.tensor(sitk.GetArrayFromImage(image)).unsqueeze(0), size).squeeze(0).numpy()
+     )
     result.SetDirection(image.GetDirection())
     result.SetOrigin(image.GetOrigin())
-     result.SetSpacing([x/y*z for x,y,z in zip(image.GetSize(), size, image.GetSpacing())])
+     result.SetSpacing([x / y * z for x, y, z in zip(image.GetSize(), size, image.GetSpacing())])
     return result

+
  def box_with_mask(mask: sitk.Image, label: list[int], dilatations: list[int]) -> np.ndarray:

-     dilatations = [int(np.ceil(d/s)) for d, s in zip(dilatations, reversed(mask.GetSpacing()))]
+     dilatations = [int(np.ceil(d / s)) for d, s in zip(dilatations, reversed(mask.GetSpacing()))]

     data = sitk.GetArrayFromImage(mask)
     border = np.where(np.isin(sitk.GetArrayFromImage(mask), label))
     box = []
     for w, dilatation, s in zip(border, dilatations, data.shape):
-         box.append([max(np.min(w)-dilatation, 0), min(np.max(w)+dilatation, s)])
+         box.append([max(np.min(w) - dilatation, 0), min(np.max(w) + dilatation, s)])
     box = np.asarray(box)
     return box

+
  def crop_with_mask(image: sitk.Image, box: np.ndarray) -> sitk.Image:
     data = sitk.GetArrayFromImage(image)
-
+
     for i, w in enumerate(box):
         data = np.delete(data, slice(w[1], data.shape[i]), i)
         data = np.delete(data, slice(0, w[0]), i)
-
+
     origin = np.asarray(image.GetOrigin())
     matrix = np.asarray(image.GetDirection()).reshape((len(origin), len(origin)))
     origin = origin.dot(matrix)
     for i, w in enumerate(box):
-         origin[-i-1] += w[0]*np.asarray(image.GetSpacing())[-i-1]
+         origin[-i - 1] += w[0] * np.asarray(image.GetSpacing())[-i - 1]
     origin = origin.dot(np.linalg.inv(matrix))

     result = sitk.GetImageFromArray(data)
@@ -237,7 +319,8 @@ def crop_with_mask(image: sitk.Image, box: np.ndarray) -> sitk.Image:
     result.SetDirection(image.GetDirection())
     return result

- def formatMaskLabel(mask: sitk.Image, labels: list[tuple[int, int]]) -> sitk.Image:
+
+ def format_mask_label(mask: sitk.Image, labels: list[tuple[int, int]]) -> sitk.Image:
     data = sitk.GetArrayFromImage(mask)
     result_data = np.zeros_like(data, np.uint8)

@@ -248,7 +331,8 @@ def formatMaskLabel(mask: sitk.Image, labels: list[tuple[int, int]]) -> sitk.Ima
     result.CopyInformation(mask)
     return result

- def getFlatLabel(mask: sitk.Image, labels: Union[None, list[int]] = None) -> sitk.Image:
+
+ def get_flat_label(mask: sitk.Image, labels: None | list[int] = None) -> sitk.Image:
     data = sitk.GetArrayFromImage(mask)
     result_data = np.zeros_like(data, np.uint8)
     if labels is not None:
@@ -257,13 +341,14 @@ def getFlatLabel(mask: sitk.Image, labels: Union[None, list[int]] = None) -> sit
     else:
         result_data[np.where(data > 0)] = 1
     result = sitk.GetImageFromArray(result_data)
-     result.CopyInformation(mask)
+     result.CopyInformation(mask)
     return result

- def clipAndCast(image : sitk.Image, min: float, max: float, dtype: np.dtype) -> sitk.Image:
+
+ def clip_and_cast(image: sitk.Image, min_value: float, max_value: float, dtype: np.dtype) -> sitk.Image:
     data = sitk.GetArrayFromImage(image)
-     data[np.where(data > max)] = max
-     data[np.where(data < min)] = min
+     data[np.where(data > max_value)] = max_value
+     data[np.where(data < min_value)] = min_value
     result = sitk.GetImageFromArray(data.astype(dtype))
     result.CopyInformation(image)
     return result
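
For orientation, a minimal usage sketch of the renamed helpers as they appear in the 1.2.0 diff above (compose_transform, flatten_transform, resample_itk, resample_isotropic). The function names and signatures are taken from the diff; the image file names and the "affine" transform prefix are hypothetical placeholders, and flatten_transform assumes the listed transforms are rigid or affine.

import SimpleITK as sitk

from konfai.utils.ITK import compose_transform, flatten_transform, resample_isotropic, resample_itk

# Hypothetical inputs: a moving image to warp and a reference image defining the target grid.
moving = sitk.ReadImage("moving.nii.gz")
reference = sitk.ReadImage("reference.nii.gz")

# Keys are sitk.Transform objects or file prefixes ("affine" -> "affine.itk.txt",
# the suffix is appended by _open_transform); values say whether to invert that transform.
transform_files = {"affine": False}

composite = compose_transform(transform_files, moving)      # sitk.CompositeTransform of all listed transforms
rigid = flatten_transform(transform_files)                  # collapsed into a single sitk.AffineTransform
warped = resample_itk(reference, moving, transform_files)   # SimpleITK resampling path (torch_resample=False)
isotropic = resample_isotropic(warped, [1.0, 1.0, 1.0])     # resample to 1 mm isotropic spacing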