konfai 1.1.8-py3-none-any.whl → 1.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of konfai might be problematic.

Files changed (36)
  1. konfai/__init__.py +59 -14
  2. konfai/data/augmentation.py +457 -286
  3. konfai/data/data_manager.py +533 -316
  4. konfai/data/patching.py +300 -183
  5. konfai/data/transform.py +408 -275
  6. konfai/evaluator.py +325 -68
  7. konfai/main.py +71 -22
  8. konfai/metric/measure.py +360 -244
  9. konfai/metric/schedulers.py +24 -13
  10. konfai/models/classification/convNeXt.py +187 -81
  11. konfai/models/classification/resnet.py +272 -58
  12. konfai/models/generation/cStyleGan.py +233 -59
  13. konfai/models/generation/ddpm.py +348 -121
  14. konfai/models/generation/diffusionGan.py +757 -358
  15. konfai/models/generation/gan.py +177 -53
  16. konfai/models/generation/vae.py +140 -40
  17. konfai/models/registration/registration.py +135 -52
  18. konfai/models/representation/representation.py +57 -23
  19. konfai/models/segmentation/NestedUNet.py +339 -68
  20. konfai/models/segmentation/UNet.py +140 -30
  21. konfai/network/blocks.py +331 -187
  22. konfai/network/network.py +795 -427
  23. konfai/predictor.py +644 -238
  24. konfai/trainer.py +509 -222
  25. konfai/utils/ITK.py +191 -106
  26. konfai/utils/config.py +152 -95
  27. konfai/utils/dataset.py +326 -455
  28. konfai/utils/utils.py +497 -249
  29. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/METADATA +1 -3
  30. konfai-1.2.0.dist-info/RECORD +38 -0
  31. konfai/utils/registration.py +0 -199
  32. konfai-1.1.8.dist-info/RECORD +0 -39
  33. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/WHEEL +0 -0
  34. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/entry_points.txt +0 -0
  35. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/licenses/LICENSE +0 -0
  36. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/top_level.txt +0 -0
konfai/data/transform.py CHANGED
@@ -1,258 +1,339 @@
 import importlib
-import torch
-import numpy as np
-import SimpleITK as sitk
+import tempfile
 from abc import ABC, abstractmethod
-import torch.nn.functional as F
-from typing import Any, Union
+from typing import Any
+
+import numpy as np
+import SimpleITK as sitk  # noqa: N813
+import torch
+import torch.nn.functional as F  # noqa: N812
 
-from konfai.utils.utils import _getModule, NeedDevice, _resample_affine, _affine_matrix, TransformError
-from konfai.utils.dataset import Dataset, Attribute, data_to_image, image_to_data
 from konfai.utils.config import config
+from konfai.utils.dataset import Attribute, Dataset, data_to_image, image_to_data
+from konfai.utils.utils import NeedDevice, TransformError, _affine_matrix, _resample_affine, get_module
+
 
 class Transform(NeedDevice, ABC):
-
+
     def __init__(self) -> None:
-        self.datasets : list[Dataset] = []
-
-    def setDatasets(self, datasets: list[Dataset]):
+        self.datasets: list[Dataset] = []
+
+    def set_datasets(self, datasets: list[Dataset]):
         self.datasets = datasets
 
-    def transformShape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
+    def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
         return shape
 
     @abstractmethod
-    def __call__(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
         pass
 
     @abstractmethod
-    def inverse(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
         pass
 
+
 class TransformLoader:
 
     @config()
     def __init__(self) -> None:
         pass
-
-    def getTransform(self, classpath : str, DL_args : str) -> Transform:
-        module, name = _getModule(classpath, "konfai.data.transform")
-        return config("{}.{}".format(DL_args, classpath))(getattr(importlib.import_module(module), name))(config = None)
+
+    def get_transform(self, classpath: str, konfai_args: str) -> Transform:
+        module, name = get_module(classpath, "konfai.data.transform")
+        return config(f"{konfai_args}.{classpath}")(getattr(importlib.import_module(module), name))(config=None)
+
 
 class Clip(Transform):
 
-    def __init__(self, min_value : float = -1024, max_value : float = 1024, saveClip_min: bool = False, saveClip_max: bool = False) -> None:
-        assert max_value > min_value
+    def __init__(
+        self,
+        min_value: float = -1024,
+        max_value: float = 1024,
+        save_clip_min: bool = False,
+        save_clip_max: bool = False,
+    ) -> None:
+        if max_value <= min_value:
+            raise ValueError(
+                f"[Clip] Invalid clipping range: max_value ({max_value}) must be greater than min_value ({min_value})"
+            )
         self.min_value = min_value
         self.max_value = max_value
-        self.saveClip_min = saveClip_min
-        self.saveClip_max = saveClip_max
+        self.save_clip_min = save_clip_min
+        self.save_clip_max = save_clip_max
 
-    def __call__(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        input[torch.where(input < self.min_value)] = self.min_value
-        input[torch.where(input > self.max_value)] = self.max_value
-        if self.saveClip_min:
+    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        tensor[torch.where(tensor < self.min_value)] = self.min_value
+        tensor[torch.where(tensor > self.max_value)] = self.max_value
+        if self.save_clip_min:
             cache_attribute["Min"] = self.min_value
-        if self.saveClip_max:
+        if self.save_clip_max:
             cache_attribute["Max"] = self.max_value
-        return input
+        return tensor
+
+    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        return tensor
 
-    def inverse(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        return input
 
 class Normalize(Transform):
 
-    def __init__(self, lazy : bool = False, channels: Union[list[int], None] = None, min_value : float = -1, max_value : float = 1) -> None:
-        assert max_value > min_value
+    def __init__(
+        self,
+        lazy: bool = False,
+        channels: list[int] | None = None,
+        min_value: float = -1,
+        max_value: float = 1,
+    ) -> None:
+        if max_value <= min_value:
+            raise ValueError(
+                f"[Normalize] Invalid range: max_value ({max_value}) must be greater than min_value ({min_value})"
+            )
         self.lazy = lazy
         self.min_value = min_value
         self.max_value = max_value
         self.channels = channels
 
-    def __call__(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
         if "Min" not in cache_attribute:
             if self.channels:
-                cache_attribute["Min"] = torch.min(input[self.channels])
+                cache_attribute["Min"] = torch.min(tensor[self.channels])
             else:
-                cache_attribute["Min"] = torch.min(input)
+                cache_attribute["Min"] = torch.min(tensor)
         if "Max" not in cache_attribute:
             if self.channels:
-                cache_attribute["Max"] = torch.max(input[self.channels])
+                cache_attribute["Max"] = torch.max(tensor[self.channels])
             else:
-                cache_attribute["Max"] = torch.max(input)
+                cache_attribute["Max"] = torch.max(tensor)
         if not self.lazy:
             input_min = float(cache_attribute["Min"])
             input_max = float(cache_attribute["Max"])
-            norm = input_max-input_min
-            assert norm != 0
-            if self.channels:
-                for channel in self.channels:
-                    input[channel] = (self.max_value-self.min_value)*(input[channel] - input_min) / norm + self.min_value
+            norm = input_max - input_min
+
+            if norm == 0:
+                print(f"[WARNING] Norm is zero for case '{name}': input is constant with value = {self.min_value}.")
+                if self.channels:
+                    for channel in self.channels:
+                        tensor[channel].fill_(self.min_value)
+                else:
+                    tensor.fill_(self.min_value)
             else:
-                input = (self.max_value-self.min_value)*(input - input_min) / norm + self.min_value
-        return input
-
-    def inverse(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+                if self.channels:
+                    for channel in self.channels:
+                        tensor[channel] = (self.max_value - self.min_value) * (
+                            tensor[channel] - input_min
+                        ) / norm + self.min_value
+                else:
+                    tensor = (self.max_value - self.min_value) * (tensor - input_min) / norm + self.min_value
+
+        return tensor
+
+    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
         if self.lazy:
-            return input
+            return tensor
         else:
            input_min = float(cache_attribute.pop("Min"))
            input_max = float(cache_attribute.pop("Max"))
-            return (input - self.min_value)*(input_max-input_min)/(self.max_value-self.min_value)+input_min
+            return (tensor - self.min_value) * (input_max - input_min) / (self.max_value - self.min_value) + input_min
+

 class Standardize(Transform):
 
-    def __init__(self, lazy : bool = False, mean: Union[list[float], None] = None, std: Union[list[float], None]= None) -> None:
+    def __init__(
+        self,
+        lazy: bool = False,
+        mean: list[float] | None = None,
+        std: list[float] | None = None,
+    ) -> None:
         self.lazy = lazy
         self.mean = mean
         self.std = std
 
-    def __call__(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
         if "Mean" not in cache_attribute:
-            cache_attribute["Mean"] = torch.mean(input.type(torch.float32), dim=[i + 1 for i in range(len(input.shape)-1)]) if self.mean is None else torch.tensor([self.mean])
+            cache_attribute["Mean"] = (
+                torch.mean(
+                    tensor.type(torch.float32),
+                    dim=[i + 1 for i in range(len(tensor.shape) - 1)],
+                )
+                if self.mean is None
+                else torch.tensor([self.mean])
+            )
         if "Std" not in cache_attribute:
-            cache_attribute["Std"] = torch.std(input.type(torch.float32), dim=[i + 1 for i in range(len(input.shape)-1)]) if self.std is None else torch.tensor([self.std])
+            cache_attribute["Std"] = (
+                torch.std(
+                    tensor.type(torch.float32),
+                    dim=[i + 1 for i in range(len(tensor.shape) - 1)],
+                )
+                if self.std is None
+                else torch.tensor([self.std])
+            )
 
         if self.lazy:
-            return input
+            return tensor
         else:
-            mean = cache_attribute.get_tensor("Mean").view(-1, *[1 for _ in range(len(input.shape)-1)])
-            std = cache_attribute.get_tensor("Std").view(-1, *[1 for _ in range(len(input.shape)-1)])
-            return (input - mean) / std
-
-    def inverse(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+            mean = cache_attribute.get_tensor("Mean").view(-1, *[1 for _ in range(len(tensor.shape) - 1)])
+            std = cache_attribute.get_tensor("Std").view(-1, *[1 for _ in range(len(tensor.shape) - 1)])
+            return (tensor - mean) / std
+
+    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
         if self.lazy:
-            return input
+            return tensor
         else:
             mean = float(cache_attribute.pop("Mean"))
             std = float(cache_attribute.pop("Std"))
-            return input * std + mean
-
+            return tensor * std + mean
+
+
 class TensorCast(Transform):
 
-    def __init__(self, dtype : str = "float32") -> None:
-        self.dtype : torch.dtype = getattr(torch, dtype)
+    def __init__(self, dtype: str = "float32") -> None:
+        self.dtype: torch.dtype = getattr(torch, dtype)
+
+    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        cache_attribute["dtype"] = str(tensor.dtype).replace("torch.", "")
+        return tensor.type(self.dtype)
+
+    @staticmethod
+    def safe_dtype_cast(dtype_str: str) -> torch.dtype:
+        try:
+            return getattr(torch, dtype_str)
+        except AttributeError:
+            raise ValueError(f"Unsupported dtype: {dtype_str}")
+
+    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        return tensor.to(TensorCast.safe_dtype_cast(cache_attribute.pop("dtype")))
 
-    def __call__(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        cache_attribute["dtype"] = input.dtype
-        return input.type(self.dtype)
-
-    def inverse(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        return input.to(eval(cache_attribute.pop("dtype")))
 
 class Padding(Transform):
 
-    def __init__(self, padding : list[int] = [0,0,0,0,0,0], mode : str = "constant") -> None:
+    def __init__(self, padding: list[int] = [0, 0, 0, 0, 0, 0], mode: str = "constant") -> None:
         self.padding = padding
         self.mode = mode
 
-    def __call__(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
         if "Origin" in cache_attribute and "Spacing" in cache_attribute and "Direction" in cache_attribute:
             origin = torch.tensor(cache_attribute.get_np_array("Origin"))
-            matrix = torch.tensor(cache_attribute.get_np_array("Direction").reshape((len(origin),len(origin))))
+            matrix = torch.tensor(cache_attribute.get_np_array("Direction").reshape((len(origin), len(origin))))
             origin = torch.matmul(origin, matrix)
-            for dim in range(len(self.padding)//2):
-                origin[-dim-1] -= self.padding[dim*2]* cache_attribute.get_np_array("Spacing")[-dim-1]
+            for dim in range(len(self.padding) // 2):
+                origin[-dim - 1] -= self.padding[dim * 2] * cache_attribute.get_np_array("Spacing")[-dim - 1]
             cache_attribute["Origin"] = torch.matmul(origin, torch.inverse(matrix))
-        result = F.pad(input.unsqueeze(0), tuple(self.padding), self.mode.split(":")[0], float(self.mode.split(":")[1]) if len(self.mode.split(":")) == 2 else 0).squeeze(0)
+        result = F.pad(
+            tensor.unsqueeze(0),
+            tuple(self.padding),
+            self.mode.split(":")[0],
+            float(self.mode.split(":")[1]) if len(self.mode.split(":")) == 2 else 0,
+        ).squeeze(0)
         return result
-
-    def transformShape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
-        for dim in range(len(self.padding)//2):
-            shape[-dim-1] += sum(self.padding[dim*2:dim*2+2])
+
+    def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
+        for dim in range(len(self.padding) // 2):
+            shape[-dim - 1] += sum(self.padding[dim * 2 : dim * 2 + 2])
         return shape
 
-    def inverse(self, name: str, input : torch.Tensor, cache_attribute: dict[str, torch.Tensor]) -> torch.Tensor:
+    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: dict[str, torch.Tensor]) -> torch.Tensor:
         if "Origin" in cache_attribute and "Spacing" in cache_attribute and "Direction" in cache_attribute:
             cache_attribute.pop("Origin")
-        slices = [slice(0, shape) for shape in input.shape]
-        for dim in range(len(self.padding)//2):
-            slices[-dim-1] = slice(self.padding[dim*2], input.shape[-dim-1]-self.padding[dim*2+1])
-        result = input[slices]
+        slices = [slice(0, shape) for shape in tensor.shape]
+        for dim in range(len(self.padding) // 2):
+            slices[-dim - 1] = slice(self.padding[dim * 2], tensor.shape[-dim - 1] - self.padding[dim * 2 + 1])
+        result = tensor[slices]
         return result
 
+
 class Squeeze(Transform):
 
     def __init__(self, dim: int) -> None:
         self.dim = dim
-
-    def __call__(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        return input.squeeze(self.dim)
 
-    def inverse(self, name: str, input : torch.Tensor, cache_attribute: dict[str, Any]) -> torch.Tensor:
-        return input.unsqueeze(self.dim)
+    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        return tensor.squeeze(self.dim)
+
+    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: dict[str, Any]) -> torch.Tensor:
+        return tensor.unsqueeze(self.dim)
+
 
 class Resample(Transform, ABC):
 
     def __init__(self) -> None:
         pass
 
-    def _resample(self, input: torch.Tensor, size: list[int]) -> torch.Tensor:
-        args = {}
-        if input.dtype == torch.uint8:
+    def _resample(self, tensor: torch.Tensor, size: list[int]) -> torch.Tensor:
+        if tensor.dtype == torch.uint8:
             mode = "nearest"
-        elif len(input.shape) < 4:
+        elif len(tensor.shape) < 4:
             mode = "bilinear"
         else:
             mode = "trilinear"
-        return F.interpolate(input.type(torch.float32).unsqueeze(0), size=tuple(size), mode=mode).squeeze(0).type(input.dtype).cpu()
+        return (
+            F.interpolate(tensor.type(torch.float32).unsqueeze(0), size=tuple(size), mode=mode)
+            .squeeze(0)
+            .type(tensor.dtype)
+            .cpu()
+        )
 
     @abstractmethod
-    def __call__(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
         pass
-
+
     @abstractmethod
-    def transformShape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
+    def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
         pass
-
-    def inverse(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        size_0 = cache_attribute.pop_np_array("Size")
+
+    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        cache_attribute.pop_np_array("Size")
         size_1 = cache_attribute.pop_np_array("Size")
         _ = cache_attribute.pop_np_array("Spacing")
-        return self._resample(input, [int(size) for size in size_1])
+        return self._resample(tensor, [int(size) for size in size_1])
+
 
 class ResampleToResolution(Resample):
 
-    def __init__(self, spacing : list[float] = [1., 1., 1.]) -> None:
+    def __init__(self, spacing: list[float] = [1.0, 1.0, 1.0]) -> None:
         self.spacing = torch.tensor([0 if s < 0 else s for s in spacing])
 
-    def transformShape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
+    def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
         if "Spacing" not in cache_attribute:
-            TransformError("Missing 'Spacing' in cache attributes, the data is likely not a valid image.",
-                "Make sure your input is a image (e.g., .nii, .mha) with proper metadata.")
+            TransformError(
+                "Missing 'Spacing' in cache attributes, the data is likely not a valid image.",
+                "Make sure your input is a image (e.g., .nii, .mha) with proper metadata.",
+            )
         if len(shape) != len(self.spacing):
             TransformError("Shape and spacing dimensions do not match: shape={shape}, spacing={self.spacing}")
         image_spacing = cache_attribute.get_tensor("Spacing").flip(0)
         spacing = self.spacing
-
+
         for i, s in enumerate(self.spacing):
             if s == 0:
                 spacing[i] = image_spacing[i]
-        resize_factor = spacing/cache_attribute.get_tensor("Spacing").flip(0)
-        return [int(x) for x in (torch.tensor(shape) * 1/resize_factor)]
+        resize_factor = spacing / cache_attribute.get_tensor("Spacing").flip(0)
+        return [int(x) for x in (torch.tensor(shape) * 1 / resize_factor)]
 
-    def __call__(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
         image_spacing = cache_attribute.get_tensor("Spacing").flip(0)
         spacing = self.spacing
         for i, s in enumerate(self.spacing):
             if s == 0:
                 spacing[i] = image_spacing[i]
-        resize_factor = spacing/cache_attribute.get_tensor("Spacing").flip(0)
+        resize_factor = spacing / cache_attribute.get_tensor("Spacing").flip(0)
         cache_attribute["Spacing"] = spacing.flip(0)
-        cache_attribute["Size"] = np.asarray([int(x) for x in torch.tensor(input.shape[1:])])
-        size = [int(x) for x in (torch.tensor(input.shape[1:]) * 1/resize_factor)]
+        cache_attribute["Size"] = np.asarray([int(x) for x in torch.tensor(tensor.shape[1:])])
+        size = [int(x) for x in (torch.tensor(tensor.shape[1:]) * 1 / resize_factor)]
         cache_attribute["Size"] = np.asarray(size)
-        return self._resample(input, size)
+        return self._resample(tensor, size)
+
 
 class ResampleToShape(Resample):
 
-    def __init__(self, shape : list[float] = [100,256,256]) -> None:
+    def __init__(self, shape: list[float] = [100, 256, 256]) -> None:
         self.shape = torch.tensor([0 if s < 0 else s for s in shape])
 
-    def transformShape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
+    def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
         if "Spacing" not in cache_attribute:
-            TransformError("Missing 'Spacing' in cache attributes, the data is likely not a valid image.",
-                "Make sure your input is a image (e.g., .nii, .mha) with proper metadata.")
+            TransformError(
+                "Missing 'Spacing' in cache attributes, the data is likely not a valid image.",
+                "Make sure your input is a image (e.g., .nii, .mha) with proper metadata.",
+            )
         if len(shape) != len(self.shape):
             TransformError("Shape and spacing dimensions do not match: shape={shape}, spacing={self.spacing}")
         new_shape = self.shape
@@ -260,304 +341,356 @@ class ResampleToShape(Resample):
             if s == 0:
                 new_shape[i] = shape[i]
         return new_shape
-
-    def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+
+    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
         shape = self.shape
-        image_shape = torch.tensor([int(x) for x in torch.tensor(input.shape[1:])])
+        image_shape = torch.tensor([int(x) for x in torch.tensor(tensor.shape[1:])])
         for i, s in enumerate(self.shape):
             if s == 0:
                 shape[i] = image_shape[i]
         if "Spacing" in cache_attribute:
-            cache_attribute["Spacing"] = torch.flip(image_shape/shape*torch.flip(cache_attribute.get_tensor("Spacing"), dims=[0]), dims=[0])
+            cache_attribute["Spacing"] = torch.flip(
+                image_shape / shape * torch.flip(cache_attribute.get_tensor("Spacing"), dims=[0]),
+                dims=[0],
+            )
             cache_attribute["Size"] = image_shape
         cache_attribute["Size"] = shape
-        return self._resample(input, shape)
+        return self._resample(tensor, shape)
+
 
 class ResampleTransform(Transform):
 
-    def __init__(self, transforms : dict[str, bool]) -> None:
+    def __init__(self, transforms: dict[str, bool]) -> None:
         self.transforms = transforms
-
-    def transformShape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
+
+    def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
         return shape
 
-    def __call__V1(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        transforms = []
-        image = data_to_image(input, cache_attribute)
-        for transform_group, invert in self.transforms.items():
-            transform = None
-            for dataset in self.datasets:
-                if dataset.isDatasetExist(transform_group, name):
-                    transform = dataset.readTransform(transform_group, name)
-                    break
-            if transform is None:
-                raise NameError("Tranform : {}/{} not found".format(transform_group, name))
-            if isinstance(transform, sitk.BSplineTransform):
-                if invert:
-                    transformToDisplacementFieldFilter = sitk.TransformToDisplacementFieldFilter()
-                    transformToDisplacementFieldFilter.SetReferenceImage(image)
-                    displacementField = transformToDisplacementFieldFilter.Execute(transform)
-                    iterativeInverseDisplacementFieldImageFilter = sitk.IterativeInverseDisplacementFieldImageFilter()
-                    iterativeInverseDisplacementFieldImageFilter.SetNumberOfIterations(20)
-                    inverseDisplacementField = iterativeInverseDisplacementFieldImageFilter.Execute(displacementField)
-                    transform = sitk.DisplacementFieldTransform(inverseDisplacementField)
-            else:
-                if invert:
-                    transform = transform.GetInverse()
-            transforms.append(transform)
-        result_transform = sitk.CompositeTransform(transforms)
-        result = torch.tensor(sitk.GetArrayFromImage(sitk.Resample(image, image, result_transform, sitk.sitkNearestNeighbor if input.dtype == torch.uint8 else sitk.sitkBSpline, 0 if input.dtype == torch.uint8 else -1024))).unsqueeze(0)
-        return result.type(torch.uint8) if input.dtype == torch.uint8 else result
-
-    def __call__(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        assert len(input.shape) == 4 , "input size should be 5 dim"
-        image = data_to_image(input, cache_attribute)
-
-        vectors = [torch.arange(0, s) for s in input.shape[1:]]
-        grids = torch.meshgrid(vectors, indexing='ij')
+    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        if len(tensor.shape) != 4:
+            raise NameError("Input size should be 5 dim")
+        image = data_to_image(tensor, cache_attribute)
+
+        vectors = [torch.arange(0, s) for s in tensor.shape[1:]]
+        grids = torch.meshgrid(vectors, indexing="ij")
         grid = torch.stack(grids)
         grid = torch.unsqueeze(grid, 0)
-
+
         transforms = []
         for transform_group, invert in self.transforms.items():
             transform = None
             for dataset in self.datasets:
-                if dataset.isDatasetExist(transform_group, name):
-                    transform = dataset.readTransform(transform_group, name)
+                if dataset.is_dataset_exist(transform_group, name):
+                    transform = dataset.read_transform(transform_group, name)
                     break
             if transform is None:
-                raise NameError("Tranform : {}/{} not found".format(transform_group, name))
+                raise NameError(f"Tranform : {transform_group}/{name} not found")
             if isinstance(transform, sitk.BSplineTransform):
                 if invert:
-                    transformToDisplacementFieldFilter = sitk.TransformToDisplacementFieldFilter()
-                    transformToDisplacementFieldFilter.SetReferenceImage(image)
-                    displacementField = transformToDisplacementFieldFilter.Execute(transform)
-                    iterativeInverseDisplacementFieldImageFilter = sitk.IterativeInverseDisplacementFieldImageFilter()
-                    iterativeInverseDisplacementFieldImageFilter.SetNumberOfIterations(20)
-                    inverseDisplacementField = iterativeInverseDisplacementFieldImageFilter.Execute(displacementField)
-                    transform = sitk.DisplacementFieldTransform(inverseDisplacementField)
+                    transform_to_displacement_field_filter = sitk.TransformToDisplacementFieldFilter()
+                    transform_to_displacement_field_filter.SetReferenceImage(image)
+                    displacement_field = transform_to_displacement_field_filter.Execute(transform)
+                    iterative_inverse_displacement_field_image_filter = (
+                        sitk.IterativeInverseDisplacementFieldImageFilter()
+                    )
+                    iterative_inverse_displacement_field_image_filter.SetNumberOfIterations(20)
+                    inverse_displacement_field = iterative_inverse_displacement_field_image_filter.Execute(
+                        displacement_field
+                    )
+                    transform = sitk.DisplacementFieldTransform(inverse_displacement_field)
             else:
                 if invert:
                     transform = transform.GetInverse()
             transforms.append(transform)
         result_transform = sitk.CompositeTransform(transforms)
-
-        transformToDisplacementFieldFilter = sitk.TransformToDisplacementFieldFilter()
-        transformToDisplacementFieldFilter.SetReferenceImage(image)
-        transformToDisplacementFieldFilter.SetNumberOfThreads(16)
-        new_locs = grid + torch.tensor(sitk.GetArrayFromImage(transformToDisplacementFieldFilter.Execute(result_transform))).unsqueeze(0).permute(0, 4, 1, 2, 3)
+
+        transform_to_displacement_field_filter = sitk.TransformToDisplacementFieldFilter()
+        transform_to_displacement_field_filter.SetReferenceImage(image)
+        transform_to_displacement_field_filter.SetNumberOfThreads(16)
+        new_locs = grid + torch.tensor(
+            sitk.GetArrayFromImage(transform_to_displacement_field_filter.Execute(result_transform))
+        ).unsqueeze(0).permute(0, 4, 1, 2, 3)
         shape = new_locs.shape[2:]
         for i in range(len(shape)):
             new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (shape[i] - 1) - 0.5)
         new_locs = new_locs.permute(0, 2, 3, 4, 1)
         new_locs = new_locs[..., [2, 1, 0]]
-        result = F.grid_sample(input.to(self.device).unsqueeze(0).float(), new_locs.to(self.device).float(), align_corners=True, padding_mode="border", mode="nearest" if input.dtype == torch.uint8 else "bilinear").squeeze(0).cpu()
-        return result.type(torch.uint8) if input.dtype == torch.uint8 else result
-
-    def inverse(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        # TODO
-        return input
+        result = (
+            F.grid_sample(
+                tensor.to(self.device).unsqueeze(0).float(),
+                new_locs.to(self.device).float(),
+                align_corners=True,
+                padding_mode="border",
+                mode="nearest" if tensor.dtype == torch.uint8 else "bilinear",
+            )
+            .squeeze(0)
+            .cpu()
+        )
+        return result.type(torch.uint8) if tensor.dtype == torch.uint8 else result
+
+    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        # TODO
+        return tensor
+
 
 class Mask(Transform):
 
-    def __init__(self, path : str = "./default.mha", value_outside: int = 0) -> None:
+    def __init__(self, path: str = "./default.mha", value_outside: int = 0) -> None:
         self.path = path
         self.value_outside = value_outside
-
-    def __call__(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+
+    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
         if self.path.endswith(".mha"):
             mask = torch.tensor(sitk.GetArrayFromImage(sitk.ReadImage(self.path))).unsqueeze(0)
         else:
             mask = None
             for dataset in self.datasets:
-                if dataset.isDatasetExist(self.path, name):
-                    mask, _ = dataset.readData(self.path, name)
+                if dataset.is_dataset_exist(self.path, name):
+                    mask, _ = dataset.read_data(self.path, name)
                    break
         if mask is None:
-            raise NameError("Mask : {}/{} not found".format(self.path, name))
-        return torch.where(torch.tensor(mask) > 0, input, self.value_outside)
+            raise NameError(f"Mask : {self.path}/{name} not found")
+        return torch.where(torch.tensor(mask) > 0, tensor, self.value_outside)
+
+    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        return tensor
+
 
-    def inverse(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        return input
-
 class Gradient(Transform):
 
     def __init__(self, per_dim: bool = False):
         self.per_dim = per_dim
-
+
     @staticmethod
-    def _image_gradient2D(image : torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
+    def _image_gradient_2d(image: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
         dx = image[:, 1:, :] - image[:, :-1, :]
         dy = image[:, :, 1:] - image[:, :, :-1]
-        return torch.nn.ConstantPad2d((0,0,0,1), 0)(dx), torch.nn.ConstantPad2d((0,1,0,0), 0)(dy)
+        return torch.nn.ConstantPad2d((0, 0, 0, 1), 0)(dx), torch.nn.ConstantPad2d((0, 1, 0, 0), 0)(dy)
 
     @staticmethod
-    def _image_gradient3D(image : torch.Tensor) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
+    def _image_gradient_3d(
+        image: torch.Tensor,
+    ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
         dx = image[:, 1:, :, :] - image[:, :-1, :, :]
         dy = image[:, :, 1:, :] - image[:, :, :-1, :]
         dz = image[:, :, :, 1:] - image[:, :, :, :-1]
-        return torch.nn.ConstantPad3d((0,0,0,0,0,1), 0)(dx), torch.nn.ConstantPad3d((0,0,0,1,0,0), 0)(dy), torch.nn.ConstantPad3d((0,1,0,0,0,0), 0)(dz)
-
-    def __call__(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        result = torch.stack(Gradient._image_gradient3D(input) if len(input.shape) == 4 else Gradient._image_gradient2D(input), dim=1).squeeze(0)
+        return (
+            torch.nn.ConstantPad3d((0, 0, 0, 0, 0, 1), 0)(dx),
+            torch.nn.ConstantPad3d((0, 0, 0, 1, 0, 0), 0)(dy),
+            torch.nn.ConstantPad3d((0, 1, 0, 0, 0, 0), 0)(dz),
+        )
+
+    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        result = torch.stack(
+            (Gradient._image_gradient_3d(tensor) if len(tensor.shape) == 4 else Gradient._image_gradient_2d(tensor)),
+            dim=1,
+        ).squeeze(0)
         if not self.per_dim:
-            result = torch.sigmoid(result*3)
+            result = torch.sigmoid(result * 3)
             result = result.norm(dim=0)
             result = torch.unsqueeze(result, 0)
-
+
         return result
 
-    def inverse(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        return input
+    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        return tensor
+
 
 class Argmax(Transform):
 
     def __init__(self, dim: int = 0) -> None:
         self.dim = dim
-
-    def __call__(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        return torch.argmax(input, dim=self.dim).unsqueeze(self.dim)
-
-    def inverse(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        return input
-
+
+    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        return torch.argmax(tensor, dim=self.dim).unsqueeze(self.dim)
+
+    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        return tensor
+
+
 class FlatLabel(Transform):
 
-    def __init__(self, labels: Union[list[int], None] = None) -> None:
+    def __init__(self, labels: list[int] | None = None) -> None:
         self.labels = labels
 
-    def __call__(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        data = torch.zeros_like(input)
+    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        data = torch.zeros_like(tensor)
         if self.labels:
             for label in self.labels:
-                data[torch.where(input == label)] = 1
+                data[torch.where(tensor == label)] = 1
         else:
-            data[torch.where(input > 0)] = 1
+            data[torch.where(tensor > 0)] = 1
         return data
 
-    def inverse(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        return input
+    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        return tensor
+
 
 class Save(Transform):
 
     def __init__(self, dataset: str) -> None:
         self.dataset = dataset
-
-    def __call__(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        return input
 
-    def inverse(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        return input
+    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        return tensor
+
+    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        return tensor
+
 
 class Flatten(Transform):
 
     def __init__(self) -> None:
         super().__init__()
 
-    def transformShape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
+    def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
         return [np.prod(np.asarray(shape))]
 
-    def __call__(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        return input.flatten()
-
-    def inverse(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        return input
+    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        return tensor.flatten()
+
+    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        return tensor
+
 
 class Permute(Transform):
 
     def __init__(self, dims: str = "1|0|2") -> None:
         super().__init__()
-        self.dims = [0]+[int(d)+1 for d in dims.split("|")]
-
-    def transformShape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
-        return [shape[it-1] for it in self.dims[1:]]
-
-    def __call__(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        return input.permute(tuple(self.dims))
-
-    def inverse(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        return input.permute(tuple(np.argsort(self.dims)))
-
+        self.dims = [0] + [int(d) + 1 for d in dims.split("|")]
+
+    def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
+        return [shape[it - 1] for it in self.dims[1:]]
+
+    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        return tensor.permute(tuple(self.dims))
+
+    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        return tensor.permute(tuple(np.argsort(self.dims)))
+
+
 class Flip(Transform):
 
     def __init__(self, dims: str = "1|0|2") -> None:
         super().__init__()
 
-        self.dims = [int(d)+1 for d in str(dims).split("|")]
+        self.dims = [int(d) + 1 for d in str(dims).split("|")]
+
+    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        return tensor.flip(tuple(self.dims))
+
+    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        return tensor.flip(tuple(self.dims))
 
-    def __call__(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        return input.flip(tuple(self.dims))
-
-    def inverse(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        return input.flip(tuple(self.dims))
 
 class Canonical(Transform):
 
     def __init__(self) -> None:
         self.canonical_direction = torch.diag(torch.tensor([-1, -1, 1])).to(torch.double)
 
-    def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
         spacing = cache_attribute.get_tensor("Spacing")
-        initial_matrix = cache_attribute.get_tensor("Direction").reshape(3,3).to(torch.double)
+        initial_matrix = cache_attribute.get_tensor("Direction").reshape(3, 3).to(torch.double)
         initial_origin = cache_attribute.get_tensor("Origin")
         cache_attribute["Direction"] = (self.canonical_direction).flatten()
         matrix = _affine_matrix(self.canonical_direction @ initial_matrix.inverse(), torch.tensor([0, 0, 0]))
-        center_voxel = torch.tensor([(input.shape[-i-1] - 1) * spacing[i] / 2 for i in range(3)], dtype=torch.double)
+        center_voxel = torch.tensor(
+            [(tensor.shape[-i - 1] - 1) * spacing[i] / 2 for i in range(3)],
+            dtype=torch.double,
+        )
         center_physical = initial_matrix @ center_voxel + initial_origin
         cache_attribute["Origin"] = center_physical - (self.canonical_direction @ center_voxel)
-        return _resample_affine(input, matrix.unsqueeze(0))
-
-    def inverse(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        return _resample_affine(tensor, matrix.unsqueeze(0))
+
+    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
         cache_attribute.pop("Direction")
         cache_attribute.pop("Origin")
-        matrix = _affine_matrix((self.canonical_direction @ cache_attribute.get_tensor("Direction").to(torch.double).reshape(3,3).inverse()).inverse(), torch.tensor([0, 0, 0]))
-        return _resample_affine(input, matrix.unsqueeze(0))
+        matrix = _affine_matrix(
+            (
+                self.canonical_direction
+                @ cache_attribute.get_tensor("Direction").to(torch.double).reshape(3, 3).inverse()
+            ).inverse(),
+            torch.tensor([0, 0, 0]),
+        )
+        return _resample_affine(tensor, matrix.unsqueeze(0))
+
 
 class HistogramMatching(Transform):
 
     def __init__(self, reference_group: str) -> None:
         self.reference_group = reference_group
 
-    def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        image = data_to_image(input, cache_attribute)
+    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        image = data_to_image(tensor, cache_attribute)
         image_ref = None
         for dataset in self.datasets:
-            if dataset.isDatasetExist(self.reference_group, name):
-                image_ref = dataset.readImage(self.reference_group, name)
+            if dataset.is_dataset_exist(self.reference_group, name):
+                image_ref = dataset.read_image(self.reference_group, name)
         if image_ref is None:
-            raise NameError("Image : {}/{} not found".format(self.reference_group, name))
+            raise NameError(f"Image : {self.reference_group}/{name} not found")
         matcher = sitk.HistogramMatchingImageFilter()
         matcher.SetNumberOfHistogramLevels(256)
         matcher.SetNumberOfMatchPoints(1)
         matcher.SetThresholdAtMeanIntensity(True)
         result, _ = image_to_data(matcher.Execute(image, image_ref))
         return torch.tensor(result)
-
-    def inverse(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        return input
+
+    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        return tensor
+
 
 class SelectLabel(Transform):
 
     def __init__(self, labels: list[str]) -> None:
-        self.labels = [l[1:-1].split(",") for l in labels]
-    def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        data = torch.zeros_like(input)
+        self.labels = [label[1:-1].split(",") for label in labels]
+
+    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        data = torch.zeros_like(tensor)
         for old_label, new_label in self.labels:
-            data[input == int(old_label)] = int(new_label)
+            data[tensor == int(old_label)] = int(new_label)
         return data
-
-    def inverse(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        return input
-
+
+    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        return tensor
+
+
 class OneHot(Transform):
-
+
     def __init__(self, num_classes: int) -> None:
         super().__init__()
         self.num_classes = num_classes
 
-    def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        result = F.one_hot(input.type(torch.int64), num_classes=self.num_classes).permute(0, len(input.shape), *[i+1 for i in range(len(input.shape)-1)]).float().squeeze(0)
+    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        result = (
+            F.one_hot(tensor.type(torch.int64), num_classes=self.num_classes)
+            .permute(0, len(tensor.shape), *[i + 1 for i in range(len(tensor.shape) - 1)])
+            .float()
+            .squeeze(0)
+        )
         return result
-
-    def inverse(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        return torch.argmax(input, dim=1).unsqueeze(1)
+
+    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        return torch.argmax(tensor, dim=1).unsqueeze(1)
+
+
+class TotalSegmentator(Transform):
+
+    def __init__(self, task: str = "total"):
+        super().__init__()
+        self.task = task
+
+    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        from totalsegmentator.python_api import totalsegmentator
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            image = data_to_image(tensor.numpy(), cache_attribute)
+            sitk.WriteImage(image, tmpdir + "/image.nii.gz")
+            seg = totalsegmentator(tmpdir + "/image.nii.gz", tmpdir, task=self.task, skip_saving=True, quiet=True)
+        return (
+            torch.from_numpy(np.array(np.asanyarray(seg.dataobj), copy=True).astype(np.uint8, copy=False))
+            .permute(2, 1, 0)
+            .unsqueeze(0)
+        )
+
+    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        return tensor
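
Migration note: beyond the reformatting, 1.2.0 renames the public camelCase API to snake_case (setDatasets → set_datasets, transformShape → transform_shape, getTransform → get_transform) and the `input` parameter to `tensor`, so downstream custom transforms must be updated. Below is a minimal sketch of what an updated subclass might look like against the 1.2.0 hooks visible in this diff; ShiftIntensity and its `shift` parameter are hypothetical, and Attribute is assumed only to support the dict-style access and pop() used throughout this file.

import torch

from konfai.data.transform import Transform
from konfai.utils.dataset import Attribute


class ShiftIntensity(Transform):
    # Hypothetical user-defined transform updated for the 1.2.0 API:
    # 1.1.8 subclasses overrode transformShape/__call__(name, input, ...);
    # 1.2.0 uses transform_shape/__call__(name, tensor, ...) plus inverse().

    def __init__(self, shift: float = 0.0) -> None:
        self.shift = shift

    def transform_shape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
        return shape  # an intensity shift leaves the geometry untouched

    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        cache_attribute["Shift"] = self.shift  # cached so inverse() can undo it
        return tensor + self.shift

    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
        return tensor - float(cache_attribute.pop("Shift"))

Pipelines that called transform.setDatasets(...) or transform.transformShape(...) directly need the same rename; the behavior of these hooks is otherwise unchanged.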