konfai 1.1.7-py3-none-any.whl → 1.1.9-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of konfai might be problematic.

Files changed (36)
  1. konfai/__init__.py +59 -14
  2. konfai/data/augmentation.py +457 -286
  3. konfai/data/data_manager.py +509 -290
  4. konfai/data/patching.py +300 -183
  5. konfai/data/transform.py +384 -277
  6. konfai/evaluator.py +309 -68
  7. konfai/main.py +71 -22
  8. konfai/metric/measure.py +341 -222
  9. konfai/metric/schedulers.py +24 -13
  10. konfai/models/classification/convNeXt.py +187 -81
  11. konfai/models/classification/resnet.py +272 -58
  12. konfai/models/generation/cStyleGan.py +233 -59
  13. konfai/models/generation/ddpm.py +348 -121
  14. konfai/models/generation/diffusionGan.py +757 -358
  15. konfai/models/generation/gan.py +177 -53
  16. konfai/models/generation/vae.py +140 -40
  17. konfai/models/registration/registration.py +135 -52
  18. konfai/models/representation/representation.py +57 -23
  19. konfai/models/segmentation/NestedUNet.py +339 -68
  20. konfai/models/segmentation/UNet.py +140 -30
  21. konfai/network/blocks.py +331 -187
  22. konfai/network/network.py +781 -423
  23. konfai/predictor.py +645 -240
  24. konfai/trainer.py +527 -216
  25. konfai/utils/ITK.py +191 -106
  26. konfai/utils/config.py +152 -95
  27. konfai/utils/dataset.py +326 -455
  28. konfai/utils/utils.py +495 -249
  29. {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/METADATA +1 -3
  30. konfai-1.1.9.dist-info/RECORD +38 -0
  31. konfai/utils/registration.py +0 -199
  32. konfai-1.1.7.dist-info/RECORD +0 -39
  33. {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/WHEEL +0 -0
  34. {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/entry_points.txt +0 -0
  35. {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/licenses/LICENSE +0 -0
  36. {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/top_level.txt +0 -0
konfai/data/transform.py CHANGED
@@ -1,565 +1,672 @@
 import importlib
-import torch
-import numpy as np
-import SimpleITK as sitk
 from abc import ABC, abstractmethod
-import torch.nn.functional as F
-from typing import Any, Union
+from typing import Any
+
+import numpy as np
+import SimpleITK as sitk  # noqa: N813
+import torch
+import torch.nn.functional as F  # noqa: N812
 
-from konfai.utils.utils import _getModule, NeedDevice, _resample_affine, _affine_matrix, TransformError
-from konfai.utils.dataset import Dataset, Attribute, data_to_image, image_to_data
 from konfai.utils.config import config
+from konfai.utils.dataset import Attribute, Dataset, data_to_image, image_to_data
+from konfai.utils.utils import NeedDevice, TransformError, _affine_matrix, _resample_affine, get_module
+
 
 class Transform(NeedDevice, ABC):
-
+
     def __init__(self) -> None:
-        self.datasets : list[Dataset] = []
-
-    def setDatasets(self, datasets: list[Dataset]):
+        self.datasets: list[Dataset] = []
+
+    def set_datasets(self, datasets: list[Dataset]):
         self.datasets = datasets
 
-    def transformShape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
+    def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
         return shape
 
     @abstractmethod
-    def __call__(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
         pass
 
     @abstractmethod
-    def inverse(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        pass
 
+
 class TransformLoader:
 
     @config()
     def __init__(self) -> None:
         pass
-
-    def getTransform(self, classpath : str, DL_args : str) -> Transform:
-        module, name = _getModule(classpath, "konfai.data.transform")
-        return config("{}.{}".format(DL_args, classpath))(getattr(importlib.import_module(module), name))(config = None)
+
+    def get_transform(self, classpath: str, konfai_args: str) -> Transform:
+        module, name = get_module(classpath, "konfai.data.transform")
+        return config(f"{konfai_args}.{classpath}")(getattr(importlib.import_module(module), name))(config=None)
+
 
 class Clip(Transform):
 
-    def __init__(self, min_value : float = -1024, max_value : float = 1024, saveClip_min: bool = False, saveClip_max: bool = False) -> None:
-        assert max_value > min_value
+    def __init__(
+        self,
+        min_value: float = -1024,
+        max_value: float = 1024,
+        save_clip_min: bool = False,
+        save_clip_max: bool = False,
+    ) -> None:
+        if max_value <= min_value:
+            raise ValueError(
+                f"[Clip] Invalid clipping range: max_value ({max_value}) must be greater than min_value ({min_value})"
+            )
         self.min_value = min_value
         self.max_value = max_value
-        self.saveClip_min = saveClip_min
-        self.saveClip_max = saveClip_max
+        self.save_clip_min = save_clip_min
+        self.save_clip_max = save_clip_max
 
-    def __call__(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        input[torch.where(input < self.min_value)] = self.min_value
-        input[torch.where(input > self.max_value)] = self.max_value
-        if self.saveClip_min:
+    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        tensor[torch.where(tensor < self.min_value)] = self.min_value
+        tensor[torch.where(tensor > self.max_value)] = self.max_value
+        if self.save_clip_min:
             cache_attribute["Min"] = self.min_value
-        if self.saveClip_max:
+        if self.save_clip_max:
             cache_attribute["Max"] = self.max_value
-        return input
+        return tensor
+
+    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        return tensor
 
-    def inverse(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        return input
 
 class Normalize(Transform):
 
-    def __init__(self, lazy : bool = False, channels: Union[list[int], None] = None, min_value : float = -1, max_value : float = 1) -> None:
-        assert max_value > min_value
+    def __init__(
+        self,
+        lazy: bool = False,
+        channels: list[int] | None = None,
+        min_value: float = -1,
+        max_value: float = 1,
+    ) -> None:
+        if max_value <= min_value:
+            raise ValueError(
+                f"[Normalize] Invalid range: max_value ({max_value}) must be greater than min_value ({min_value})"
+            )
         self.lazy = lazy
         self.min_value = min_value
         self.max_value = max_value
         self.channels = channels
 
-    def __call__(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
         if "Min" not in cache_attribute:
             if self.channels:
-                cache_attribute["Min"] = torch.min(input[self.channels])
+                cache_attribute["Min"] = torch.min(tensor[self.channels])
             else:
-                cache_attribute["Min"] = torch.min(input)
+                cache_attribute["Min"] = torch.min(tensor)
         if "Max" not in cache_attribute:
             if self.channels:
-                cache_attribute["Max"] = torch.max(input[self.channels])
+                cache_attribute["Max"] = torch.max(tensor[self.channels])
             else:
-                cache_attribute["Max"] = torch.max(input)
+                cache_attribute["Max"] = torch.max(tensor)
         if not self.lazy:
             input_min = float(cache_attribute["Min"])
             input_max = float(cache_attribute["Max"])
-            norm = input_max-input_min
-            assert norm != 0
-            if self.channels:
-                for channel in self.channels:
-                    input[channel] = (self.max_value-self.min_value)*(input[channel] - input_min) / norm + self.min_value
+            norm = input_max - input_min
+
+            if norm == 0:
+                print(f"[WARNING] Norm is zero for case '{name}': input is constant with value = {self.min_value}.")
+                if self.channels:
+                    for channel in self.channels:
+                        tensor[channel].fill_(self.min_value)
+                else:
+                    tensor.fill_(self.min_value)
             else:
-                input = (self.max_value-self.min_value)*(input - input_min) / norm + self.min_value
-        return input
-
-    def inverse(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+                if self.channels:
+                    for channel in self.channels:
+                        tensor[channel] = (self.max_value - self.min_value) * (
+                            tensor[channel] - input_min
+                        ) / norm + self.min_value
+                else:
+                    tensor = (self.max_value - self.min_value) * (tensor - input_min) / norm + self.min_value
+
+        return tensor
+
+    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        if self.lazy:
-            return input
+            return tensor
         else:
             input_min = float(cache_attribute.pop("Min"))
             input_max = float(cache_attribute.pop("Max"))
-            return (input - self.min_value)*(input_max-input_min)/(self.max_value-self.min_value)+input_min
+            return (tensor - self.min_value) * (input_max - input_min) / (self.max_value - self.min_value) + input_min
+
 
 class Standardize(Transform):
 
-    def __init__(self, lazy : bool = False, mean: Union[list[float], None] = None, std: Union[list[float], None]= None) -> None:
+    def __init__(
+        self,
+        lazy: bool = False,
+        mean: list[float] | None = None,
+        std: list[float] | None = None,
+    ) -> None:
         self.lazy = lazy
         self.mean = mean
         self.std = std
 
-    def __call__(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
         if "Mean" not in cache_attribute:
-            cache_attribute["Mean"] = torch.mean(input.type(torch.float32), dim=[i + 1 for i in range(len(input.shape)-1)]) if self.mean is None else torch.tensor([self.mean])
+            cache_attribute["Mean"] = (
+                torch.mean(
+                    tensor.type(torch.float32),
+                    dim=[i + 1 for i in range(len(tensor.shape) - 1)],
+                )
+                if self.mean is None
+                else torch.tensor([self.mean])
+            )
         if "Std" not in cache_attribute:
-            cache_attribute["Std"] = torch.std(input.type(torch.float32), dim=[i + 1 for i in range(len(input.shape)-1)]) if self.std is None else torch.tensor([self.std])
+            cache_attribute["Std"] = (
+                torch.std(
+                    tensor.type(torch.float32),
+                    dim=[i + 1 for i in range(len(tensor.shape) - 1)],
+                )
+                if self.std is None
+                else torch.tensor([self.std])
+            )
 
         if self.lazy:
-            return input
+            return tensor
         else:
-            mean = cache_attribute.get_tensor("Mean").view(-1, *[1 for _ in range(len(input.shape)-1)])
-            std = cache_attribute.get_tensor("Std").view(-1, *[1 for _ in range(len(input.shape)-1)])
-            return (input - mean) / std
-
-    def inverse(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+            mean = cache_attribute.get_tensor("Mean").view(-1, *[1 for _ in range(len(tensor.shape) - 1)])
+            std = cache_attribute.get_tensor("Std").view(-1, *[1 for _ in range(len(tensor.shape) - 1)])
+            return (tensor - mean) / std
+
+    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
         if self.lazy:
-            return input
+            return tensor
         else:
             mean = float(cache_attribute.pop("Mean"))
             std = float(cache_attribute.pop("Std"))
-            return input * std + mean
-
+            return tensor * std + mean
+
+
 class TensorCast(Transform):
 
-    def __init__(self, dtype : str = "float32") -> None:
-        self.dtype : torch.dtype = getattr(torch, dtype)
+    def __init__(self, dtype: str = "float32") -> None:
+        self.dtype: torch.dtype = getattr(torch, dtype)
+
+    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        cache_attribute["dtype"] = str(tensor.dtype).replace("torch.", "")
+        return tensor.type(self.dtype)
+
+    @staticmethod
+    def safe_dtype_cast(dtype_str: str) -> torch.dtype:
+        try:
+            return getattr(torch, dtype_str)
+        except AttributeError:
+            raise ValueError(f"Unsupported dtype: {dtype_str}")
+
+    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        return tensor.to(TensorCast.safe_dtype_cast(cache_attribute.pop("dtype")))
 
-    def __call__(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        cache_attribute["dtype"] = input.dtype
-        return input.type(self.dtype)
-
-    def inverse(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        return input.to(eval(cache_attribute.pop("dtype")))
 
 class Padding(Transform):
 
-    def __init__(self, padding : list[int] = [0,0,0,0,0,0], mode : str = "constant") -> None:
+    def __init__(self, padding: list[int] = [0, 0, 0, 0, 0, 0], mode: str = "constant") -> None:
         self.padding = padding
         self.mode = mode
 
-    def __call__(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
         if "Origin" in cache_attribute and "Spacing" in cache_attribute and "Direction" in cache_attribute:
             origin = torch.tensor(cache_attribute.get_np_array("Origin"))
-            matrix = torch.tensor(cache_attribute.get_np_array("Direction").reshape((len(origin),len(origin))))
+            matrix = torch.tensor(cache_attribute.get_np_array("Direction").reshape((len(origin), len(origin))))
             origin = torch.matmul(origin, matrix)
-            for dim in range(len(self.padding)//2):
-                origin[-dim-1] -= self.padding[dim*2]* cache_attribute.get_np_array("Spacing")[-dim-1]
+            for dim in range(len(self.padding) // 2):
+                origin[-dim - 1] -= self.padding[dim * 2] * cache_attribute.get_np_array("Spacing")[-dim - 1]
             cache_attribute["Origin"] = torch.matmul(origin, torch.inverse(matrix))
-        result = F.pad(input.unsqueeze(0), tuple(self.padding), self.mode.split(":")[0], float(self.mode.split(":")[1]) if len(self.mode.split(":")) == 2 else 0).squeeze(0)
+        result = F.pad(
+            tensor.unsqueeze(0),
+            tuple(self.padding),
+            self.mode.split(":")[0],
+            float(self.mode.split(":")[1]) if len(self.mode.split(":")) == 2 else 0,
+        ).squeeze(0)
         return result
-
-    def transformShape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
-        for dim in range(len(self.padding)//2):
-            shape[-dim-1] += sum(self.padding[dim*2:dim*2+2])
+
+    def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
+        for dim in range(len(self.padding) // 2):
+            shape[-dim - 1] += sum(self.padding[dim * 2 : dim * 2 + 2])
         return shape
 
-    def inverse(self, name: str, input : torch.Tensor, cache_attribute: dict[str, torch.Tensor]) -> torch.Tensor:
+    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: dict[str, torch.Tensor]) -> torch.Tensor:
         if "Origin" in cache_attribute and "Spacing" in cache_attribute and "Direction" in cache_attribute:
             cache_attribute.pop("Origin")
-        slices = [slice(0, shape) for shape in input.shape]
-        for dim in range(len(self.padding)//2):
-            slices[-dim-1] = slice(self.padding[dim*2], input.shape[-dim-1]-self.padding[dim*2+1])
-        result = input[slices]
+        slices = [slice(0, shape) for shape in tensor.shape]
+        for dim in range(len(self.padding) // 2):
+            slices[-dim - 1] = slice(self.padding[dim * 2], tensor.shape[-dim - 1] - self.padding[dim * 2 + 1])
+        result = tensor[slices]
         return result
 
+
 class Squeeze(Transform):
 
     def __init__(self, dim: int) -> None:
         self.dim = dim
-
-    def __call__(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        return input.squeeze(self.dim)
 
-    def inverse(self, name: str, input : torch.Tensor, cache_attribute: dict[str, Any]) -> torch.Tensor:
-        return input.unsqueeze(self.dim)
+    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        return tensor.squeeze(self.dim)
+
+    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: dict[str, Any]) -> torch.Tensor:
+        return tensor.unsqueeze(self.dim)
+
 
 class Resample(Transform, ABC):
 
     def __init__(self) -> None:
         pass
 
-    def _resample(self, input: torch.Tensor, size: list[int]) -> torch.Tensor:
-        args = {}
-        if input.dtype == torch.uint8:
+    def _resample(self, tensor: torch.Tensor, size: list[int]) -> torch.Tensor:
+        if tensor.dtype == torch.uint8:
             mode = "nearest"
-        elif len(input.shape) < 4:
+        elif len(tensor.shape) < 4:
             mode = "bilinear"
         else:
             mode = "trilinear"
-        return F.interpolate(input.type(torch.float32).unsqueeze(0), size=tuple(size), mode=mode).squeeze(0).type(input.dtype).cpu()
+        return (
+            F.interpolate(tensor.type(torch.float32).unsqueeze(0), size=tuple(size), mode=mode)
+            .squeeze(0)
+            .type(tensor.dtype)
+            .cpu()
+        )
 
     @abstractmethod
-    def __call__(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
         pass
-
+
     @abstractmethod
-    def transformShape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
+    def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
         pass
-
-    def inverse(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        size_0 = cache_attribute.pop_np_array("Size")
+
+    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        cache_attribute.pop_np_array("Size")
         size_1 = cache_attribute.pop_np_array("Size")
         _ = cache_attribute.pop_np_array("Spacing")
-        return self._resample(input, [int(size) for size in size_1])
+        return self._resample(tensor, [int(size) for size in size_1])
+
 
 class ResampleToResolution(Resample):
 
-    def __init__(self, spacing : list[float] = [1., 1., 1.]) -> None:
+    def __init__(self, spacing: list[float] = [1.0, 1.0, 1.0]) -> None:
         self.spacing = torch.tensor([0 if s < 0 else s for s in spacing])
 
-    def transformShape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
+    def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
         if "Spacing" not in cache_attribute:
-            TransformError("Missing 'Spacing' in cache attributes, the data is likely not a valid image.",
-                "Make sure your input is a image (e.g., .nii, .mha) with proper metadata.")
+            TransformError(
+                "Missing 'Spacing' in cache attributes, the data is likely not a valid image.",
+                "Make sure your input is a image (e.g., .nii, .mha) with proper metadata.",
+            )
         if len(shape) != len(self.spacing):
             TransformError("Shape and spacing dimensions do not match: shape={shape}, spacing={self.spacing}")
         image_spacing = cache_attribute.get_tensor("Spacing").flip(0)
         spacing = self.spacing
-
+
         for i, s in enumerate(self.spacing):
             if s == 0:
                 spacing[i] = image_spacing[i]
-        resize_factor = spacing/cache_attribute.get_tensor("Spacing").flip(0)
-        return [int(x) for x in (torch.tensor(shape) * 1/resize_factor)]
+        resize_factor = spacing / cache_attribute.get_tensor("Spacing").flip(0)
+        return [int(x) for x in (torch.tensor(shape) * 1 / resize_factor)]
 
-    def __call__(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
         image_spacing = cache_attribute.get_tensor("Spacing").flip(0)
         spacing = self.spacing
         for i, s in enumerate(self.spacing):
             if s == 0:
                 spacing[i] = image_spacing[i]
-        resize_factor = spacing/cache_attribute.get_tensor("Spacing").flip(0)
+        resize_factor = spacing / cache_attribute.get_tensor("Spacing").flip(0)
         cache_attribute["Spacing"] = spacing.flip(0)
-        cache_attribute["Size"] = np.asarray([int(x) for x in torch.tensor(input.shape[1:])])
-        size = [int(x) for x in (torch.tensor(input.shape[1:]) * 1/resize_factor)]
+        cache_attribute["Size"] = np.asarray([int(x) for x in torch.tensor(tensor.shape[1:])])
+        size = [int(x) for x in (torch.tensor(tensor.shape[1:]) * 1 / resize_factor)]
         cache_attribute["Size"] = np.asarray(size)
-        return self._resample(input, size)
+        return self._resample(tensor, size)
+
 
 class ResampleToShape(Resample):
 
-    def __init__(self, shape : list[float] = [100,256,256]) -> None:
+    def __init__(self, shape: list[float] = [100, 256, 256]) -> None:
         self.shape = torch.tensor([0 if s < 0 else s for s in shape])
 
-    def transformShape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
-        print(shape)
+    def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
         if "Spacing" not in cache_attribute:
-            TransformError("Missing 'Spacing' in cache attributes, the data is likely not a valid image.",
-                "Make sure your input is a image (e.g., .nii, .mha) with proper metadata.")
+            TransformError(
+                "Missing 'Spacing' in cache attributes, the data is likely not a valid image.",
+                "Make sure your input is a image (e.g., .nii, .mha) with proper metadata.",
+            )
         if len(shape) != len(self.shape):
             TransformError("Shape and spacing dimensions do not match: shape={shape}, spacing={self.spacing}")
         new_shape = self.shape
         for i, s in enumerate(self.shape):
             if s == 0:
                 new_shape[i] = shape[i]
-        print(new_shape)
         return new_shape
-
-    def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+
+    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
         shape = self.shape
-        image_shape = torch.tensor([int(x) for x in torch.tensor(input.shape[1:])])
+        image_shape = torch.tensor([int(x) for x in torch.tensor(tensor.shape[1:])])
         for i, s in enumerate(self.shape):
             if s == 0:
                 shape[i] = image_shape[i]
         if "Spacing" in cache_attribute:
-            cache_attribute["Spacing"] = torch.flip(image_shape/shape*torch.flip(cache_attribute.get_tensor("Spacing"), dims=[0]), dims=[0])
+            cache_attribute["Spacing"] = torch.flip(
+                image_shape / shape * torch.flip(cache_attribute.get_tensor("Spacing"), dims=[0]),
+                dims=[0],
+            )
             cache_attribute["Size"] = image_shape
         cache_attribute["Size"] = shape
-        return self._resample(input, shape)
+        return self._resample(tensor, shape)
+
 
 class ResampleTransform(Transform):
 
-    def __init__(self, transforms : dict[str, bool]) -> None:
+    def __init__(self, transforms: dict[str, bool]) -> None:
         self.transforms = transforms
-
-    def transformShape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
+
+    def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
         return shape
 
-    def __call__V1(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        transforms = []
-        image = data_to_image(input, cache_attribute)
-        for transform_group, invert in self.transforms.items():
-            transform = None
-            for dataset in self.datasets:
-                if dataset.isDatasetExist(transform_group, name):
-                    transform = dataset.readTransform(transform_group, name)
-                    break
-            if transform is None:
-                raise NameError("Tranform : {}/{} not found".format(transform_group, name))
-            if isinstance(transform, sitk.BSplineTransform):
-                if invert:
-                    transformToDisplacementFieldFilter = sitk.TransformToDisplacementFieldFilter()
-                    transformToDisplacementFieldFilter.SetReferenceImage(image)
-                    displacementField = transformToDisplacementFieldFilter.Execute(transform)
-                    iterativeInverseDisplacementFieldImageFilter = sitk.IterativeInverseDisplacementFieldImageFilter()
-                    iterativeInverseDisplacementFieldImageFilter.SetNumberOfIterations(20)
-                    inverseDisplacementField = iterativeInverseDisplacementFieldImageFilter.Execute(displacementField)
-                    transform = sitk.DisplacementFieldTransform(inverseDisplacementField)
-            else:
-                if invert:
-                    transform = transform.GetInverse()
-            transforms.append(transform)
-        result_transform = sitk.CompositeTransform(transforms)
-        result = torch.tensor(sitk.GetArrayFromImage(sitk.Resample(image, image, result_transform, sitk.sitkNearestNeighbor if input.dtype == torch.uint8 else sitk.sitkBSpline, 0 if input.dtype == torch.uint8 else -1024))).unsqueeze(0)
-        return result.type(torch.uint8) if input.dtype == torch.uint8 else result
-
-    def __call__(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        assert len(input.shape) == 4 , "input size should be 5 dim"
-        image = data_to_image(input, cache_attribute)
-
-        vectors = [torch.arange(0, s) for s in input.shape[1:]]
-        grids = torch.meshgrid(vectors, indexing='ij')
+    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        if len(tensor.shape) != 4:
+            raise NameError("Input size should be 5 dim")
+        image = data_to_image(tensor, cache_attribute)
+
+        vectors = [torch.arange(0, s) for s in tensor.shape[1:]]
+        grids = torch.meshgrid(vectors, indexing="ij")
         grid = torch.stack(grids)
         grid = torch.unsqueeze(grid, 0)
-
+
         transforms = []
         for transform_group, invert in self.transforms.items():
             transform = None
             for dataset in self.datasets:
-                if dataset.isDatasetExist(transform_group, name):
-                    transform = dataset.readTransform(transform_group, name)
+                if dataset.is_dataset_exist(transform_group, name):
+                    transform = dataset.read_transform(transform_group, name)
                     break
             if transform is None:
-                raise NameError("Tranform : {}/{} not found".format(transform_group, name))
+                raise NameError(f"Tranform : {transform_group}/{name} not found")
             if isinstance(transform, sitk.BSplineTransform):
                 if invert:
-                    transformToDisplacementFieldFilter = sitk.TransformToDisplacementFieldFilter()
-                    transformToDisplacementFieldFilter.SetReferenceImage(image)
-                    displacementField = transformToDisplacementFieldFilter.Execute(transform)
-                    iterativeInverseDisplacementFieldImageFilter = sitk.IterativeInverseDisplacementFieldImageFilter()
-                    iterativeInverseDisplacementFieldImageFilter.SetNumberOfIterations(20)
-                    inverseDisplacementField = iterativeInverseDisplacementFieldImageFilter.Execute(displacementField)
-                    transform = sitk.DisplacementFieldTransform(inverseDisplacementField)
+                    transform_to_displacement_field_filter = sitk.TransformToDisplacementFieldFilter()
+                    transform_to_displacement_field_filter.SetReferenceImage(image)
+                    displacement_field = transform_to_displacement_field_filter.Execute(transform)
+                    iterative_inverse_displacement_field_image_filter = (
+                        sitk.IterativeInverseDisplacementFieldImageFilter()
+                    )
+                    iterative_inverse_displacement_field_image_filter.SetNumberOfIterations(20)
+                    inverse_displacement_field = iterative_inverse_displacement_field_image_filter.Execute(
+                        displacement_field
+                    )
+                    transform = sitk.DisplacementFieldTransform(inverse_displacement_field)
             else:
                 if invert:
                     transform = transform.GetInverse()
             transforms.append(transform)
         result_transform = sitk.CompositeTransform(transforms)
-
-        transformToDisplacementFieldFilter = sitk.TransformToDisplacementFieldFilter()
-        transformToDisplacementFieldFilter.SetReferenceImage(image)
-        transformToDisplacementFieldFilter.SetNumberOfThreads(16)
-        new_locs = grid + torch.tensor(sitk.GetArrayFromImage(transformToDisplacementFieldFilter.Execute(result_transform))).unsqueeze(0).permute(0, 4, 1, 2, 3)
+
+        transform_to_displacement_field_filter = sitk.TransformToDisplacementFieldFilter()
+        transform_to_displacement_field_filter.SetReferenceImage(image)
+        transform_to_displacement_field_filter.SetNumberOfThreads(16)
+        new_locs = grid + torch.tensor(
+            sitk.GetArrayFromImage(transform_to_displacement_field_filter.Execute(result_transform))
+        ).unsqueeze(0).permute(0, 4, 1, 2, 3)
         shape = new_locs.shape[2:]
         for i in range(len(shape)):
             new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (shape[i] - 1) - 0.5)
         new_locs = new_locs.permute(0, 2, 3, 4, 1)
         new_locs = new_locs[..., [2, 1, 0]]
-        result = F.grid_sample(input.to(self.device).unsqueeze(0).float(), new_locs.to(self.device).float(), align_corners=True, padding_mode="border", mode="nearest" if input.dtype == torch.uint8 else "bilinear").squeeze(0).cpu()
-        return result.type(torch.uint8) if input.dtype == torch.uint8 else result
-
-    def inverse(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        # TODO
-        return input
+        result = (
+            F.grid_sample(
+                tensor.to(self.device).unsqueeze(0).float(),
+                new_locs.to(self.device).float(),
+                align_corners=True,
+                padding_mode="border",
+                mode="nearest" if tensor.dtype == torch.uint8 else "bilinear",
+            )
+            .squeeze(0)
+            .cpu()
+        )
+        return result.type(torch.uint8) if tensor.dtype == torch.uint8 else result
+
+    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        # TODO
+        return tensor
+
 
 class Mask(Transform):
 
-    def __init__(self, path : str = "./default.mha", value_outside: int = 0) -> None:
+    def __init__(self, path: str = "./default.mha", value_outside: int = 0) -> None:
         self.path = path
         self.value_outside = value_outside
-
-    def __call__(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+
+    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
         if self.path.endswith(".mha"):
             mask = torch.tensor(sitk.GetArrayFromImage(sitk.ReadImage(self.path))).unsqueeze(0)
         else:
             mask = None
             for dataset in self.datasets:
-                if dataset.isDatasetExist(self.path, name):
-                    mask, _ = dataset.readData(self.path, name)
+                if dataset.is_dataset_exist(self.path, name):
+                    mask, _ = dataset.read_data(self.path, name)
                     break
             if mask is None:
-                raise NameError("Mask : {}/{} not found".format(self.path, name))
-        return torch.where(torch.tensor(mask) > 0, input, self.value_outside)
+                raise NameError(f"Mask : {self.path}/{name} not found")
+        return torch.where(torch.tensor(mask) > 0, tensor, self.value_outside)
+
+    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        return tensor
+
 
-    def inverse(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        return input
-
 class Gradient(Transform):
 
     def __init__(self, per_dim: bool = False):
         self.per_dim = per_dim
-
+
     @staticmethod
-    def _image_gradient2D(image : torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+    def _image_gradient_2d(image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         dx = image[:, 1:, :] - image[:, :-1, :]
         dy = image[:, :, 1:] - image[:, :, :-1]
-        return torch.nn.ConstantPad2d((0,0,0,1), 0)(dx), torch.nn.ConstantPad2d((0,1,0,0), 0)(dy)
+        return torch.nn.ConstantPad2d((0, 0, 0, 1), 0)(dx), torch.nn.ConstantPad2d((0, 1, 0, 0), 0)(dy)
 
     @staticmethod
-    def _image_gradient3D(image : torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    def _image_gradient_3d(
+        image: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         dx = image[:, 1:, :, :] - image[:, :-1, :, :]
         dy = image[:, :, 1:, :] - image[:, :, :-1, :]
         dz = image[:, :, :, 1:] - image[:, :, :, :-1]
-        return torch.nn.ConstantPad3d((0,0,0,0,0,1), 0)(dx), torch.nn.ConstantPad3d((0,0,0,1,0,0), 0)(dy), torch.nn.ConstantPad3d((0,1,0,0,0,0), 0)(dz)
-
-    def __call__(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        result = torch.stack(Gradient._image_gradient3D(input) if len(input.shape) == 4 else Gradient._image_gradient2D(input), dim=1).squeeze(0)
+        return (
+            torch.nn.ConstantPad3d((0, 0, 0, 0, 0, 1), 0)(dx),
+            torch.nn.ConstantPad3d((0, 0, 0, 1, 0, 0), 0)(dy),
+            torch.nn.ConstantPad3d((0, 1, 0, 0, 0, 0), 0)(dz),
+        )
+
+    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        result = torch.stack(
+            (Gradient._image_gradient_3d(tensor) if len(tensor.shape) == 4 else Gradient._image_gradient_2d(tensor)),
+            dim=1,
+        ).squeeze(0)
         if not self.per_dim:
-            result = torch.sigmoid(result*3)
+            result = torch.sigmoid(result * 3)
             result = result.norm(dim=0)
             result = torch.unsqueeze(result, 0)
-
+
         return result
 
-    def inverse(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        return input
+    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        return tensor
+
 
 class Argmax(Transform):
 
     def __init__(self, dim: int = 0) -> None:
         self.dim = dim
-
-    def __call__(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        return torch.argmax(input, dim=self.dim).unsqueeze(self.dim)
-
-    def inverse(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        return input
-
+
+    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        return torch.argmax(tensor, dim=self.dim).unsqueeze(self.dim)
+
+    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        return tensor
+
+
 class FlatLabel(Transform):
 
-    def __init__(self, labels: Union[list[int], None] = None) -> None:
+    def __init__(self, labels: list[int] | None = None) -> None:
         self.labels = labels
 
-    def __call__(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        data = torch.zeros_like(input)
+    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        data = torch.zeros_like(tensor)
         if self.labels:
             for label in self.labels:
-                data[torch.where(input == label)] = 1
+                data[torch.where(tensor == label)] = 1
         else:
-            data[torch.where(input > 0)] = 1
+            data[torch.where(tensor > 0)] = 1
         return data
 
-    def inverse(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        return input
+    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        return tensor
+
 
 class Save(Transform):
 
     def __init__(self, dataset: str) -> None:
         self.dataset = dataset
-
-    def __call__(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        return input
 
-    def inverse(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        return input
+    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        return tensor
+
+    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        return tensor
+
 
 class Flatten(Transform):
 
     def __init__(self) -> None:
         super().__init__()
 
-    def transformShape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
+    def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
         return [np.prod(np.asarray(shape))]
 
-    def __call__(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        return input.flatten()
-
-    def inverse(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        return input
+    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        return tensor.flatten()
+
+    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        return tensor
+
 
 class Permute(Transform):
 
     def __init__(self, dims: str = "1|0|2") -> None:
         super().__init__()
-        self.dims = [0]+[int(d)+1 for d in dims.split("|")]
-
-    def transformShape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
-        return [shape[it-1] for it in self.dims[1:]]
-
-    def __call__(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        return input.permute(tuple(self.dims))
-
-    def inverse(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        return input.permute(tuple(np.argsort(self.dims)))
-
+        self.dims = [0] + [int(d) + 1 for d in dims.split("|")]
+
+    def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
+        return [shape[it - 1] for it in self.dims[1:]]
+
+    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        return tensor.permute(tuple(self.dims))
+
+    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        return tensor.permute(tuple(np.argsort(self.dims)))
+
+
 class Flip(Transform):
 
     def __init__(self, dims: str = "1|0|2") -> None:
         super().__init__()
 
-        self.dims = [int(d)+1 for d in str(dims).split("|")]
+        self.dims = [int(d) + 1 for d in str(dims).split("|")]
+
+    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        return tensor.flip(tuple(self.dims))
+
+    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        return tensor.flip(tuple(self.dims))
 
-    def __call__(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        return input.flip(tuple(self.dims))
-
-    def inverse(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        return input.flip(tuple(self.dims))
 
 class Canonical(Transform):
 
     def __init__(self) -> None:
         self.canonical_direction = torch.diag(torch.tensor([-1, -1, 1])).to(torch.double)
 
-    def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
         spacing = cache_attribute.get_tensor("Spacing")
-        initial_matrix = cache_attribute.get_tensor("Direction").reshape(3,3).to(torch.double)
+        initial_matrix = cache_attribute.get_tensor("Direction").reshape(3, 3).to(torch.double)
         initial_origin = cache_attribute.get_tensor("Origin")
         cache_attribute["Direction"] = (self.canonical_direction).flatten()
         matrix = _affine_matrix(self.canonical_direction @ initial_matrix.inverse(), torch.tensor([0, 0, 0]))
-        center_voxel = torch.tensor([(input.shape[-i-1] - 1) * spacing[i] / 2 for i in range(3)], dtype=torch.double)
+        center_voxel = torch.tensor(
+            [(tensor.shape[-i - 1] - 1) * spacing[i] / 2 for i in range(3)],
+            dtype=torch.double,
+        )
         center_physical = initial_matrix @ center_voxel + initial_origin
         cache_attribute["Origin"] = center_physical - (self.canonical_direction @ center_voxel)
-        return _resample_affine(input, matrix.unsqueeze(0))
-
-    def inverse(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        return _resample_affine(tensor, matrix.unsqueeze(0))
+
+    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
         cache_attribute.pop("Direction")
         cache_attribute.pop("Origin")
-        matrix = _affine_matrix((self.canonical_direction @ cache_attribute.get_tensor("Direction").to(torch.double).reshape(3,3).inverse()).inverse(), torch.tensor([0, 0, 0]))
-        return _resample_affine(input, matrix.unsqueeze(0))
+        matrix = _affine_matrix(
+            (
+                self.canonical_direction
+                @ cache_attribute.get_tensor("Direction").to(torch.double).reshape(3, 3).inverse()
+            ).inverse(),
+            torch.tensor([0, 0, 0]),
+        )
+        return _resample_affine(tensor, matrix.unsqueeze(0))
+
 
 class HistogramMatching(Transform):
 
     def __init__(self, reference_group: str) -> None:
         self.reference_group = reference_group
 
-    def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        image = data_to_image(input, cache_attribute)
+    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        image = data_to_image(tensor, cache_attribute)
         image_ref = None
         for dataset in self.datasets:
-            if dataset.isDatasetExist(self.reference_group, name):
-                image_ref = dataset.readImage(self.reference_group, name)
+            if dataset.is_dataset_exist(self.reference_group, name):
+                image_ref = dataset.read_image(self.reference_group, name)
         if image_ref is None:
-            raise NameError("Image : {}/{} not found".format(self.reference_group, name))
+            raise NameError(f"Image : {self.reference_group}/{name} not found")
         matcher = sitk.HistogramMatchingImageFilter()
         matcher.SetNumberOfHistogramLevels(256)
         matcher.SetNumberOfMatchPoints(1)
        matcher.SetThresholdAtMeanIntensity(True)
         result, _ = image_to_data(matcher.Execute(image, image_ref))
         return torch.tensor(result)
-
-    def inverse(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        return input
+
+    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        return tensor
+
 
 class SelectLabel(Transform):
 
     def __init__(self, labels: list[str]) -> None:
-        self.labels = [l[1:-1].split(",") for l in labels]
-    def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        data = torch.zeros_like(input)
+        self.labels = [label[1:-1].split(",") for label in labels]
+
+    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        data = torch.zeros_like(tensor)
         for old_label, new_label in self.labels:
-            data[input == int(old_label)] = int(new_label)
+            data[tensor == int(old_label)] = int(new_label)
         return data
-
-    def inverse(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        return input
-
+
+    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        return tensor
+
+
 class OneHot(Transform):
-
+
     def __init__(self, num_classes: int) -> None:
         super().__init__()
         self.num_classes = num_classes
 
-    def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        result = F.one_hot(input.type(torch.int64), num_classes=self.num_classes).permute(0, len(input.shape), *[i+1 for i in range(len(input.shape)-1)]).float().squeeze(0)
+    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        result = (
+            F.one_hot(tensor.type(torch.int64), num_classes=self.num_classes)
+            .permute(0, len(tensor.shape), *[i + 1 for i in range(len(tensor.shape) - 1)])
+            .float()
+            .squeeze(0)
+        )
         return result
-
-    def inverse(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        return torch.argmax(input, dim=1).unsqueeze(1)
+
+    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        return torch.argmax(tensor, dim=1).unsqueeze(1)
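
Most of the transform.py changes above are mechanical: the public API moves from camelCase to snake_case (`setDatasets` → `set_datasets`, `transformShape` → `transform_shape`, `TransformLoader.getTransform` → `get_transform`), `Union[..., None]` annotations become `... | None`, asserts become explicit `ValueError`s, and the `input` parameter is renamed `tensor`. A minimal sketch of what downstream code might look like against 1.1.9; the `volume` and `attrs` names are hypothetical placeholders, and a plain dict stands in for the real `Attribute` cache only to keep the sketch self-contained:

    import torch

    from konfai.data.transform import Clip, Padding

    # Hypothetical single-channel 3D volume and an empty attribute cache.
    volume = torch.randn(1, 32, 64, 64)
    attrs: dict = {}

    clip = Clip(min_value=-1024, max_value=1024, save_clip_min=True)    # 1.1.7 spelling: saveClip_min=True
    pad = Padding(padding=[2, 2, 2, 2, 2, 2], mode="constant")

    clipped = clip("case_001", volume.clone(), attrs)                   # __call__ signature is unchanged
    padded_shape = pad.transform_shape(list(volume.shape[1:]), attrs)   # 1.1.7 spelling: pad.transformShape(...)

Code written against 1.1.7 that still calls the old camelCase names will raise `AttributeError` after upgrading, since the old names are removed rather than aliased in this release.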