konfai 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of konfai might be problematic. Click here for more details.

konfai/trainer.py ADDED
@@ -0,0 +1,330 @@
1
+ import torch
2
+ import torch.optim.adamw
3
+ from torch.utils.data import DataLoader
4
+ import tqdm
5
+ import numpy as np
6
+ import os
7
+ from typing import Union
8
+ from torch.nn.parallel import DistributedDataParallel as DDP
9
+
10
+ import shutil
11
+ from torch.utils.tensorboard.writer import SummaryWriter
12
+ from torch.optim.swa_utils import AveragedModel
13
+ import torch.distributed as dist
14
+
15
+ from KonfAI.konfai import MODELS_DIRECTORY, CHECKPOINTS_DIRECTORY, STATISTICS_DIRECTORY, SETUPS_DIRECTORY, CONFIG_FILE, MODEL, DATE, DL_API_STATE
16
+ from KonfAI.konfai.data.dataset import DataTrain
17
+ from KonfAI.konfai.utils.config import config
18
+ from KonfAI.konfai.utils.utils import State, DataLog, DistributedObject, description
19
+ from KonfAI.konfai.network.network import Network, ModelLoader, NetState, CPU_Model
20
+
21
+
22
class _Trainer():
    """Per-process training driver.

    Runs the epoch and batch loops for one rank: forward/backward passes,
    optional EMA tracking, periodic TensorBoard logging, validation and
    checkpointing. It is a context manager so that the TensorBoard writer is
    closed when the ``with`` block ends.
    """

    def __init__(self, world_size: int, global_rank: int, local_rank: int, size: int, train_name: str, data_log: Union[list[str], None] , save_checkpoint_mode: str, epochs: int, epoch: int, autocast: bool, it_validation: Union[int, None], it: int, model: Union[DDP, CPU_Model], modelEMA: AveragedModel, dataloader_training: DataLoader, dataloader_validation: Union[DataLoader, None] = None) -> None:
        """Store the run configuration and open the TensorBoard writer.

        Each ``data_log`` entry has the form ``"<name>/<DataLog member>/<count>"``.
        Colons in ``<name>`` are replaced by dots to form the lookup key.
        """
        self.world_size = world_size
        self.global_rank = global_rank
        self.local_rank = local_rank
        self.size = size

        self.save_checkpoint_mode = save_checkpoint_mode
        self.train_name = train_name
        self.epochs = epochs
        self.epoch = epoch
        self.model = model
        self.dataloader_training = dataloader_training
        self.dataloader_validation = dataloader_validation
        self.autocast = autocast
        self.modelEMA = modelEMA

        # Without an explicit period, log/validate once per pass over the training loader.
        self.it_validation = len(dataloader_training) if it_validation is None else it_validation
        self.it = it
        self.tb = SummaryWriter(log_dir = STATISTICS_DIRECTORY()+self.train_name+"/")
        self.data_log : dict[str, tuple[DataLog, int]] = {}
        for entry in data_log or []:
            # Split once instead of three times per entry.
            fields = entry.split("/")
            self.data_log[fields[0].replace(":", ".")] = (DataLog[fields[1]].value[0], int(fields[2]))

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        """Close the TensorBoard writer."""
        if self.tb is not None:
            self.tb.close()

    def run(self) -> None:
        """Train from the current epoch up to ``self.epochs``."""
        with tqdm.tqdm(iterable = range(self.epoch, self.epochs), leave=False, total=self.epochs, initial=self.epoch, desc="Progress", disable=self.global_rank != 0) as epoch_tqdm:
            # The loop variable is the attribute itself so that checkpoints record the live epoch.
            for self.epoch in epoch_tqdm:
                self.dataloader_training.dataset.load()
                self.train()
                self.dataloader_training.dataset.resetAugmentation()

    def getInput(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int, str, bool]]) -> dict[tuple[str, bool], torch.Tensor]:
        """Map each batch entry to ``(name, flag) -> tensor``.

        The flag is read from the first element of the entry's sixth field.
        """
        return {(k, v[5][0].item()) : v[0] for k, v in data_dict.items()}

    def train(self) -> None:
        """Run one pass over the training loader.

        Every ``self.it_validation`` iterations the training measures are
        logged, validation is run when a validation loader exists, the
        learning rate is updated and a checkpoint is written.
        """
        self.model.train()
        self.model.module.setState(NetState.TRAIN)
        if self.modelEMA is not None:
            self.modelEMA.eval()
            self.modelEMA.module.setState(NetState.TRAIN)

        def desc() -> str:
            return "Training : {}".format(description(self.model, self.modelEMA))

        hide_bar = self.global_rank != 0 and "DL_API_CLUSTER" not in os.environ
        with tqdm.tqdm(iterable = enumerate(self.dataloader_training), desc = desc(), total=len(self.dataloader_training), leave=False, disable=hide_bar) as batch_iter:
            for _, data_dict in batch_iter:
                with torch.amp.autocast('cuda', enabled=self.autocast):
                    input = self.getInput(data_dict)
                    self.model(input)
                    self.model.module.backward(self.model.module)
                    if self.modelEMA is not None:
                        self.modelEMA.update_parameters(self.model)
                        self.modelEMA.module(input)
                    self.it += 1
                    if (self.it) % self.it_validation == 0:
                        loss = self._train_log(data_dict)

                        # The validation loss, when available, is the one stored with the checkpoint.
                        if self.dataloader_validation is not None:
                            loss = self._validate()
                        self.model.module.update_lr()
                        self.checkpoint_save(loss)

                batch_iter.set_description(desc())

    @torch.no_grad()
    def _validate(self) -> Union[float, None]:
        """Run the validation loader, then restore the training state.

        Returns the value of ``_validation_log`` for the last validation batch
        (``None`` on ranks other than 0).
        """
        self.model.eval()
        self.model.module.setState(NetState.PREDICTION)
        if self.modelEMA is not None:
            self.modelEMA.module.setState(NetState.PREDICTION)

        def desc() -> str:
            return "Validation : {}".format(description(self.model, self.modelEMA))

        data_dict = None
        self.dataloader_validation.dataset.load()
        hide_bar = self.global_rank != 0 and "DL_API_CLUSTER" not in os.environ
        with tqdm.tqdm(iterable = enumerate(self.dataloader_validation), desc = desc(), total=len(self.dataloader_validation), leave=False, disable=hide_bar) as batch_iter:
            for _, data_dict in batch_iter:
                input = self.getInput(data_dict)
                self.model(input)
                if self.modelEMA is not None:
                    self.modelEMA.module(input)

                batch_iter.set_description(desc())
        self.dataloader_validation.dataset.resetAugmentation()
        # Synchronise ranks only when a process group exists; barrier() raises otherwise.
        if dist.is_available() and dist.is_initialized():
            dist.barrier()
        self.model.train()
        self.model.module.setState(NetState.TRAIN)
        if self.modelEMA is not None:
            self.modelEMA.module.setState(NetState.TRAIN)
        return self._validation_log(data_dict)

    def checkpoint_save(self, loss: float) -> None:
        """Write a checkpoint from rank 0.

        In ``"BEST"`` mode the newest existing checkpoint is replaced only when
        its stored loss is greater than or equal to ``loss``. In any other mode
        a checkpoint is always written.
        """
        if self.global_rank != 0:
            return

        path = CHECKPOINTS_DIRECTORY()+self.train_name+"/"
        best_only = self.save_checkpoint_mode == "BEST"
        last_loss = None
        if os.path.exists(path) and os.listdir(path):
            name = sorted(os.listdir(path))[-1]
            # Only the scalar "loss" is needed, so keep the tensors off the accelerator.
            last_loss = torch.load(path+name, weights_only=False, map_location="cpu")["loss"]
            if best_only and last_loss >= loss:
                os.remove(path+name)

        if best_only and last_loss is not None and last_loss < loss:
            return

        name = DATE()+".pt"
        os.makedirs(path, exist_ok=True)

        save_dict = {
            "epoch": self.epoch,
            "it": self.it,
            "loss": loss,
            "Model": self.model.module.state_dict()}

        if self.modelEMA is not None:
            save_dict.update({"Model_EMA" : self.modelEMA.module.state_dict()})

        save_dict.update({'{}_optimizer_state_dict'.format(net_name): network.optimizer.state_dict() for net_name, network in self.model.module.getNetworks().items() if network.optimizer is not None})
        torch.save(save_dict, path+name)

    @torch.no_grad()
    def _log(self, type_log: str, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]) -> Union[float, None]:
        """Gather measures and write them to TensorBoard from rank 0.

        ``type_log`` is the tag prefix (``"Trainning"`` or ``"Validation"``).
        Returns the mean over networks of their summed loss values on rank 0
        and ``None`` on every other rank.
        """
        models: dict[str, Network] = {"" : self.model.module}
        if self.modelEMA is not None:
            models["_EMA"] = self.modelEMA.module

        # getMeasure is called on every rank, not only on rank 0.
        measures = DistributedObject.getMeasure(self.world_size, self.global_rank, self.local_rank*self.size+self.size-1, models, self.it_validation if type_log == "Trainning" else len(self.dataloader_validation))

        if self.global_rank != 0:
            return None

        # Entries found in the batch are logged directly; the rest are requested from the networks' layers.
        images_log = []
        for name, (log_fn, count) in self.data_log.items():
            if name in data_dict:
                log_fn(self.tb, "{}/{}".format(type_log ,name), data_dict[name][0][:count].detach().cpu().numpy(), self.it)
            else:
                images_log.append(name.replace(":", "."))

        for label, model in models.items():
            for name, network in model.getNetworks().items():
                if network.measure is not None:
                    losses, metrics = measures["{}{}".format(name, label)][0], measures["{}{}".format(name, label)][1]
                    self.tb.add_scalars("{}/{}/Loss/{}".format(type_log, name, label), {k : v[1] for k, v in losses.items()}, self.it)
                    self.tb.add_scalars("{}/{}/Loss_weight/{}".format(type_log, name, label), {k : v[0] for k, v in losses.items()}, self.it)

                    self.tb.add_scalars("{}/{}/Metric/{}".format(type_log, name, label), {k : v[1] for k, v in metrics.items()}, self.it)
                    self.tb.add_scalars("{}/{}/Metric_weight/{}".format(type_log, name, label), {k : v[0] for k, v in metrics.items()}, self.it)

            if images_log:
                for name, layer, _ in model.get_layers([v.to(0) for k, v in self.getInput(data_dict).items() if k[1]], images_log):
                    self.data_log[name][0](self.tb, "{}/{}{}".format(type_log, name, label), layer[:self.data_log[name][1]].detach().cpu().numpy(), self.it)

        if type_log == "Trainning":
            for name, network in self.model.module.getNetworks().items():
                if network.optimizer is not None:
                    self.tb.add_scalar("{}/{}/Learning Rate".format(type_log, name), network.optimizer.param_groups[0]['lr'], self.it)

        loss = []
        for name, network in self.model.module.getNetworks().items():
            if network.measure is not None:
                loss.append(sum([v[1] for v in measures["{}".format(name)][0].values()]))
        return np.mean(loss)

    @torch.no_grad()
    def _train_log(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]) -> Union[float, None]:
        """Log under the ``"Trainning"`` tag (the spelling is part of the existing tag names)."""
        return self._log("Trainning", data_dict)

    @torch.no_grad()
    def _validation_log(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]) -> Union[float, None]:
        """Log under the ``"Validation"`` tag."""
        return self._log("Validation", data_dict)
201
+
202
class Trainer(DistributedObject):
    """Configurable entry point that prepares a training run and launches it.

    ``setup`` prepares the output directories, model, optional EMA copy and
    data loaders. ``run_process`` executes the per-rank loop through
    ``_Trainer``. ``__exit__`` exports the saved checkpoints as models.
    """

    @config("Trainer")
    def __init__(   self,
                    model : ModelLoader = ModelLoader(),
                    dataset : DataTrain = DataTrain(),
                    train_name : str = "default:name",
                    manual_seed : Union[int, None] = None,
                    epochs: int = 100,
                    it_validation : Union[int, None] = None,
                    autocast : bool = False,
                    gradient_checkpoints: Union[list[str], None] = None,
                    gpu_checkpoints: Union[list[str], None] = None,
                    ema_decay : float = 0,
                    data_log: Union[list[str], None] = None,
                    save_checkpoint_mode: str= "BEST") -> None:
        # The process exits unless the configuration stage has been marked as "Done".
        if os.environ["DEEP_LEANING_API_CONFIG_MODE"] != "Done":
            exit(0)
        super().__init__(train_name)
        self.manual_seed = manual_seed
        self.dataset = dataset
        self.autocast = autocast
        self.epochs = epochs
        self.epoch = 0
        self.it = 0
        self.it_validation = it_validation
        self.model = model.getModel(train=True)
        self.ema_decay = ema_decay
        self.modelEMA : Union[torch.optim.swa_utils.AveragedModel, None] = None
        self.data_log = data_log
        self.gradient_checkpoints = gradient_checkpoints
        self.gpu_checkpoints = gpu_checkpoints
        self.save_checkpoint_mode = save_checkpoint_mode
        self.config_namefile_src = CONFIG_FILE().replace(".yml", "")
        self.config_namefile = SETUPS_DIRECTORY()+self.name+"/"+self.config_namefile_src.split("/")[-1]+"_"+str(self.it)+".yml"
        # One slot per model replica, plus one per entry in gpu_checkpoints.
        self.size = (len(self.gpu_checkpoints)+1 if self.gpu_checkpoints else 1)

    def __exit__(self, type, value, traceback):
        """Run the parent exit hook, then export the checkpoints."""
        super().__exit__(type, value, traceback)
        self._save()

    def _load(self) -> dict[str, dict[str, torch.Tensor]]:
        """Load the state dict selected by ``MODEL()``.

        ``MODEL()`` may be an ``https://<url>:<name>`` reference, a local file
        path, or empty. When empty, the newest checkpoint of this run is used.
        ``self.epoch`` and ``self.it`` are restored when the state dict holds them.

        Raises:
            Exception: if the referenced model cannot be found or downloaded.
        """
        model_ref = MODEL()
        if model_ref.startswith("https://"):
            # Split on the LAST colon: the first one belongs to the "https://" scheme.
            url, _, key = model_ref.rpartition(":")
            if not url.startswith("https://") or not key:
                raise Exception("Model : {} does not exist !".format(model_ref))
            try:
                state_dict = {key: torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)}
            except Exception as error:
                raise Exception("Model : {} does not exist !".format(model_ref)) from error
        else:
            if model_ref != "":
                path = ""
                name = model_ref
            else:
                path = CHECKPOINTS_DIRECTORY()+self.name+"/"
                # A missing or empty checkpoint directory gets the same error as a missing file.
                checkpoints = sorted(os.listdir(path)) if os.path.isdir(path) else []
                if not checkpoints:
                    raise Exception("Model : {} does not exist !".format(self.name))
                name = checkpoints[-1]

            if not os.path.exists(path+name):
                raise Exception("Model : {} does not exist !".format(self.name))
            state_dict = torch.load(path+name, weights_only=False)

        if "epoch" in state_dict:
            self.epoch = state_dict['epoch']
        if "it" in state_dict:
            self.it = state_dict['it']
        return state_dict

    def _save(self) -> None:
        """Export every checkpoint as a serialized model and as a state dict.

        Nothing is done when the checkpoint directory is missing or empty.
        The copied configuration file is finally renamed with the iteration
        count appended.
        """
        path_checkpoint = CHECKPOINTS_DIRECTORY()+self.name+"/"
        path_model = MODELS_DIRECTORY()+self.name+"/"
        if not (os.path.exists(path_checkpoint) and os.listdir(path_checkpoint)):
            return

        for directory in [path_model, "{}Serialized/".format(path_model), "{}StateDict/".format(path_model)]:
            os.makedirs(directory, exist_ok=True)

        for name in sorted(os.listdir(path_checkpoint)):
            checkpoint = torch.load(path_checkpoint+name, weights_only=False)
            self.model.load(checkpoint, init=False, ema=False)

            torch.save(self.model, "{}Serialized/{}".format(path_model, name))
            torch.save({"Model" : self.model.state_dict()}, "{}StateDict/{}".format(path_model, name))

            if self.modelEMA is not None:
                self.modelEMA.module.load(checkpoint, init=False, ema=True)
                # NOTE(review): EMA exports are named from DATE(), not from the checkpoint name — confirm intended.
                torch.save(self.modelEMA.module, "{}Serialized/{}".format(path_model, DATE()+"_EMA.pt"))
                torch.save({"Model_EMA" : self.modelEMA.module.state_dict()}, "{}StateDict/{}".format(path_model, DATE()+"_EMA.pt"))

        os.rename(self.config_namefile, self.config_namefile.replace(".yml", "")+"_"+str(self.it)+".yml")

    def _avg_fn(self, averaged_model_parameter, model_parameter, num_averaged):
        """Averaging rule passed to ``AveragedModel``.

        ``ema_decay`` weights the current parameter and ``1 - ema_decay``
        weights the running average. ``num_averaged`` is unused.
        """
        return (1-self.ema_decay) * averaged_model_parameter + self.ema_decay * model_parameter

    def setup(self, world_size: int):
        """Prepare directories, model, optional EMA model and data loaders.

        When the run already exists and the state is not ``RESUME``, its
        directories are removed. Unless ``DL_API_OVERWRITE`` is ``"True"`` the
        user is asked first, and any answer other than ``yes`` aborts the setup.
        """
        state = State._member_map_[DL_API_STATE()]
        if state != State.RESUME and os.path.exists(STATISTICS_DIRECTORY()+self.name+"/"):
            if os.environ["DL_API_OVERWRITE"] != "True":
                accept = input("The model {} already exists ! Do you want to overwrite it (yes,no) : ".format(self.name))
                if accept != "yes":
                    return
            for directory_path in [STATISTICS_DIRECTORY(), MODELS_DIRECTORY(), CHECKPOINTS_DIRECTORY(), SETUPS_DIRECTORY()]:
                if os.path.exists(directory_path+self.name+"/"):
                    shutil.rmtree(directory_path+self.name+"/")

        state_dict = {}
        if state != State.TRAIN:
            state_dict = self._load()

        self.model.init(self.autocast, state)
        self.model.init_outputsGroup()
        self.model._compute_channels_trace(self.model, self.model.in_channels, self.gradient_checkpoints, self.gpu_checkpoints)
        self.model.load(state_dict, init=True, ema=False)
        if self.ema_decay > 0:
            self.modelEMA = AveragedModel(self.model, avg_fn=self._avg_fn)
            if state_dict is not None:
                self.modelEMA.module.load(state_dict, init=False, ema=True)

        os.makedirs(SETUPS_DIRECTORY()+self.name+"/", exist_ok=True)
        shutil.copyfile(self.config_namefile_src+".yml", self.config_namefile)

        self.dataloader = self.dataset.getData(world_size//self.size)

    def run_process(self, world_size: int, global_rank: int, local_rank: int, dataloaders: list[DataLoader]):
        """Move the models to this rank's device and run the training loop."""
        model = Network.to(self.model, local_rank*self.size)
        model = DDP(model, static_graph=True) if torch.cuda.is_available() else CPU_Model(model)
        if self.modelEMA is not None:
            # NOTE(review): the EMA copy is placed at index local_rank, the model at local_rank*self.size — confirm intended.
            self.modelEMA.module = Network.to(self.modelEMA.module, local_rank)
        with _Trainer(world_size, global_rank, local_rank, self.size, self.name, self.data_log, self.save_checkpoint_mode, self.epochs, self.epoch, self.autocast, self.it_validation, self.it, model, self.modelEMA, *dataloaders) as t:
            t.run()
konfai/utils/ITK.py ADDED
@@ -0,0 +1,269 @@
1
+ import SimpleITK as sitk
2
+ from typing import Union
3
+ import numpy as np
4
+ import torch
5
+ import scipy
6
+ import torch.nn.functional as F
7
+ from KonfAI.konfai.utils.utils import _resample
8
+
9
def _openTransform(transform_files: dict[Union[str, sitk.Transform], bool], image: sitk.Image= None) -> list[sitk.Transform]:
    """Open the given transforms and invert the ones flagged for inversion.

    Args:
        transform_files: maps a transform, or a path without its ``.itk.txt``
            suffix, to a flag telling whether it must be inverted.
        image: reference image used to sample a displacement field when a
            non-linear transform has to be inverted.

    Returns:
        One transform per entry, in iteration order. Each transform is added
        exactly once. When there is no entry, the list holds a single default
        ``Euler3DTransform``.
    """
    def _inverted_displacement_field(transform: sitk.Transform) -> sitk.Transform:
        # Non-linear transforms are inverted numerically on the reference image grid.
        to_field = sitk.TransformToDisplacementFieldFilter()
        to_field.SetReferenceImage(image)
        field = to_field.Execute(transform)
        inverter = sitk.IterativeInverseDisplacementFieldImageFilter()
        inverter.SetNumberOfIterations(20)
        return sitk.DisplacementFieldTransform(inverter.Execute(field))

    # These types are rebuilt as their concrete class and inverted with GetInverse().
    linear_types = {
        "TranslationTransform": sitk.TranslationTransform,
        "Euler3DTransform": sitk.Euler3DTransform,
        "VersorRigid3DTransform": sitk.VersorRigid3DTransform,
        "AffineTransform": sitk.AffineTransform,
    }

    transforms: list[sitk.Transform] = []
    for transform_file, invert in transform_files.items():
        if isinstance(transform_file, str):
            transform = sitk.ReadTransform(transform_file+".itk.txt")
        else:
            transform = transform_file

        name = transform.GetName()
        if name in linear_types:
            concrete = linear_types[name]
            transform = concrete(transform)
            if invert:
                transform = concrete(transform.GetInverse())
        elif name == "DisplacementFieldTransform":
            if invert:
                transform = _inverted_displacement_field(transform)
        else:
            # Any other transform is treated as a B-spline.
            transform = sitk.BSplineTransform(transform)
            if invert:
                transform = _inverted_displacement_field(transform)
        transforms.append(transform)

    if not transforms:
        transforms.append(sitk.Euler3DTransform())
    return transforms
57
+
58
def _openRigidTransform(transform_files: dict[Union[str, sitk.Transform], bool]) -> tuple[np.ndarray, np.ndarray]:
    """Fold a sequence of matrix/translation transforms into one matrix and one translation.

    The transforms are opened with ``_openTransform`` (no reference image is
    passed). Returns ``(matrix, translation)`` as computed by the accumulation
    below.
    """
    transforms = _openTransform(transform_files)
    matrix_result = np.identity(3)
    translation_result = np.array([0,0,0])

    for transform in transforms:
        if hasattr(transform, "GetMatrix"):
            # Transform exposing a matrix: work with its inverse matrix and negated translation.
            matrix = np.linalg.inv(np.array(transform.GetMatrix(), dtype=np.double).reshape((3,3)))
            translation = -np.asarray(transform.GetTranslation(), dtype=np.double)
            center = np.asarray(transform.GetCenter(), dtype=np.double)
        else:
            # No matrix available: identity matrix, negated offset, zero center.
            matrix = np.eye(len(transform.GetOffset()))
            translation = -np.asarray(transform.GetOffset(), dtype=np.double)
            center = np.asarray([0]*len(transform.GetOffset()), dtype=np.double)

        # NOTE(review): presumably re-expresses the translation relative to the rotation center — confirm the derivation.
        translation_center = np.linalg.inv(matrix).dot(matrix.dot(translation-center)+center)
        translation_result = np.linalg.inv(matrix_result).dot(translation_center)+translation_result
        matrix_result = matrix.dot(matrix_result)
    # The accumulated matrix is inverted and the accumulated translation negated before returning.
    return np.linalg.inv(matrix_result), -translation_result
77
+
78
def composeTransform(transform_files : dict[Union[str, sitk.Transform], bool], image : sitk.Image = None) -> sitk.CompositeTransform:
    """Open the given transforms and wrap them in a single composite transform.

    Args:
        transform_files: maps a transform, or a path without its ``.itk.txt``
            suffix, to its inversion flag.
        image: reference image forwarded to ``_openTransform``.

    Returns:
        A ``sitk.CompositeTransform`` built from the opened transforms.
    """
    return sitk.CompositeTransform(_openTransform(transform_files, image))
82
+
83
def flattenTransform(transform_files: dict[Union[str, sitk.Transform], bool]) -> sitk.AffineTransform:
    """Collapse the given transforms into one 3-D affine transform.

    The matrix and translation come from ``_openRigidTransform``.
    """
    matrix, translation = _openRigidTransform(transform_files)
    affine = sitk.AffineTransform(3)
    affine.SetMatrix(matrix.flatten())
    affine.SetTranslation(translation)
    return affine
89
+
90
def apply_to_image_RigidTransform(image: sitk.Image, transform_files: dict[Union[str, sitk.Transform], bool]) -> sitk.Image:
    """Apply the given transforms to an image by editing its geometry only.

    The voxel data is copied unchanged. The direction and origin are updated
    with the inverse of the flattened matrix and the negated translation.
    The spacing is kept.
    """
    flat_matrix, flat_translation = _openRigidTransform(transform_files)
    inverse_matrix = np.linalg.inv(flat_matrix)
    shift = -flat_translation

    moved = sitk.GetImageFromArray(sitk.GetArrayFromImage(image))
    direction = np.array(image.GetDirection()).reshape((3,3))
    moved.SetDirection(inverse_matrix.dot(direction).flatten())
    moved.SetOrigin(inverse_matrix.dot(np.array(image.GetOrigin())+shift))
    moved.SetSpacing(image.GetSpacing())
    return moved
100
+
101
def apply_to_data_Transform(data: np.ndarray, transform_files: dict[Union[str, sitk.Transform], bool]) -> np.ndarray:
    """Transform an array of 3-D points with the composed transforms.

    Args:
        data: array with one point per row (three coordinates are read from each row).
        transform_files: maps a transform, or a path without its ``.itk.txt``
            suffix, to its inversion flag.

    Returns:
        A copy of ``data`` (same dtype) whose rows hold the transformed points.
    """
    def _flip_xy(point) -> np.ndarray:
        # Negate the first two coordinates before and after the transform.
        # Presumably an RAS <-> LPS convention change — confirm.
        return np.array([-point[0], -point[1], point[2]], dtype=np.double)

    transforms = composeTransform(transform_files)
    result = np.copy(data)
    for i, point in enumerate(data):
        result[i, :] = _flip_xy(transforms.TransformPoint(np.asarray(_flip_xy(point), dtype=np.double)))
    return result
108
+
109
def resampleITK(image_reference : sitk.Image, image : sitk.Image, transform_files : dict[Union[str, sitk.Transform], bool], mask = False, defaultPixelValue: Union[float, None] = None, torch_resample : bool = False) -> sitk.Image:
    """Resample ``image`` through the composed transforms.

    With ``torch_resample`` the sampling is done by ``torch.nn.functional.grid_sample``
    and the geometry of ``image_reference`` is copied onto the result. ``mask`` and
    ``defaultPixelValue`` are not used on that path. Otherwise ``sitk.Resample`` is
    used, with nearest-neighbour interpolation for masks and B-spline interpolation
    for other images.
    """
    if torch_resample:
        input = torch.tensor(sitk.GetArrayFromImage(image)).unsqueeze(0)
        # Grid of voxel indices, one channel per array axis.
        vectors = [torch.arange(0, s) for s in input.shape[1:]]
        grids = torch.meshgrid(vectors, indexing='ij')
        grid = torch.stack(grids)
        grid = torch.unsqueeze(grid, 0)
        transformToDisplacementFieldFilter = sitk.TransformToDisplacementFieldFilter()
        transformToDisplacementFieldFilter.SetReferenceImage(image)
        transformToDisplacementFieldFilter.SetNumberOfThreads(16)
        # NOTE(review): the displacement field is added to voxel indices as-is; presumably assumes matching units and axis order — confirm.
        new_locs = grid + torch.tensor(sitk.GetArrayFromImage(transformToDisplacementFieldFilter.Execute(composeTransform(transform_files, image)))).unsqueeze(0).permute(0, 4, 1, 2, 3)
        shape = new_locs.shape[2:]
        # Rescale every axis from [0, size-1] to [-1, 1], the range expected by grid_sample.
        for i in range(len(shape)):
            new_locs[:, i, ...] = 2 * (new_locs[:, i, ...] / (shape[i] - 1) - 0.5)
        new_locs = new_locs.permute(0, 2, 3, 4, 1)
        # Reverse the coordinate order of the last axis.
        new_locs = new_locs[..., [2, 1, 0]]
        # uint8 inputs are sampled with "nearest" and cast back to uint8; everything else is interpolated.
        result_data = F.grid_sample(input.unsqueeze(0).float(), new_locs.float(), align_corners=True, padding_mode="border", mode="nearest" if input.dtype == torch.uint8 else "bilinear").squeeze(0)
        result_data = result_data.type(torch.uint8) if input.dtype == torch.uint8 else result_data
        result = sitk.GetImageFromArray(result_data.squeeze(0).numpy())
        result.CopyInformation(image_reference)
        return result
    else:
        # Default fill value: 0 for masks, otherwise the image minimum cast to int.
        return sitk.Resample(image, image_reference, composeTransform(transform_files, image), sitk.sitkNearestNeighbor if mask else sitk.sitkBSpline, (defaultPixelValue if defaultPixelValue is not None else (0 if mask else int(np.min(sitk.GetArrayFromImage(image))))))
132
+
133
def parameterMap_to_transform(path_src: str) -> Union[sitk.Transform, list[sitk.Transform]]:
    """Convert the parameter file ``<path_src>.0.txt`` into SimpleITK transform(s).

    Returns a single transform for ``EulerTransform``, ``AffineTransform`` and
    ``BSplineTransform``, and a list of transforms for the two stack types.

    Raises:
        NameError: if the transform type is not one of the handled names.
    """
    parameter_map = sitk.ReadParameterFile("{}.0.txt".format(path_src))

    def as_floats(values) -> np.ndarray:
        return np.array([float(value) for value in values])

    def bspline_grid() -> tuple[np.ndarray, np.ndarray]:
        # Fixed parameters are grid size, origin, spacing and the transposed direction matrix.
        grid_size = as_floats(parameter_map["GridSize"])
        grid_origin = as_floats(parameter_map["GridOrigin"])
        grid_spacing = as_floats(parameter_map["GridSpacing"])
        grid_direction = as_floats(parameter_map["GridDirection"]).reshape((3,3)).T.flatten()
        return grid_size, np.concatenate([grid_size, grid_origin, grid_spacing, grid_direction])

    kind = parameter_map["Transform"][0]

    if kind == "BSplineStackTransform":
        parameters = as_floats(parameter_map["TransformParameters"])
        grid_size, fixedParameters = bspline_grid()

        count = int(as_floats(parameter_map["Size"])[-1])
        chunk = int(np.prod(grid_size))*3
        results = []
        for index in range(count):
            bspline = sitk.BSplineTransform(3)
            bspline.SetFixedParameters(fixedParameters)
            bspline.SetParameters(parameters[index*chunk:(index+1)*chunk])
            results.append(bspline)
        return results

    if kind == "AffineLogStackTransform":
        parameters = as_floats(parameter_map["TransformParameters"])
        # Adding [0] to a NumPy array is an element-wise addition: the values are unchanged.
        fixedParameters = as_floats(parameter_map["CenterOfRotationPoint"])+[0]

        count = int(parameter_map["NumberOfSubTransforms"][0])
        chunk = 12
        results = []
        for index in range(count):
            affine = sitk.AffineTransform(3)
            sub_parameters = parameters[index*chunk:(index+1)*chunk]

            affine.SetFixedParameters(fixedParameters)
            # The first nine values are a matrix logarithm; the last three are kept as-is.
            matrix = scipy.linalg.expm(sub_parameters[:9].reshape((3,3))).flatten()
            affine.SetParameters(np.concatenate([matrix, sub_parameters[-3:]]))
            results.append(affine)
        return results

    if kind == "EulerTransform" or kind == "AffineTransform":
        result = sitk.Euler3DTransform() if kind == "EulerTransform" else sitk.AffineTransform(3)
        parameters = as_floats(parameter_map["TransformParameters"])
        # Adding [0] to a NumPy array is an element-wise addition: the values are unchanged.
        fixedParameters = as_floats(parameter_map["CenterOfRotationPoint"])+[0]
    elif kind == "BSplineTransform":
        result = sitk.BSplineTransform(3)

        parameters = as_floats(parameter_map["TransformParameters"])
        _, fixedParameters = bspline_grid()
    else:
        raise NameError("Transform {} doesn't exist".format(kind))

    result.SetFixedParameters(fixedParameters)
    result.SetParameters(parameters)
    return result
192
+
193
def resampleIsotropic(image: sitk.Image, spacing : list[float] = [1., 1., 1.]) -> sitk.Image:
    """Resample ``image`` to the requested voxel spacing.

    The direction and origin are kept and the spacing is set to ``spacing``.
    """
    # Ratio current/target spacing per axis, in the order returned by GetSpacing().
    resize_factor = [current/target for target, current in zip(spacing, image.GetSpacing())]
    voxels = torch.tensor(sitk.GetArrayFromImage(image)).unsqueeze(0)
    # NOTE(review): sizes follow GetSize() order while the tensor comes from GetArrayFromImage — confirm what _resample expects.
    target_size = [int(size*factor) for size, factor in zip(image.GetSize(), resize_factor)]

    resampled = sitk.GetImageFromArray(_resample(voxels, target_size).squeeze(0).numpy())
    resampled.SetDirection(image.GetDirection())
    resampled.SetOrigin(image.GetOrigin())
    resampled.SetSpacing(spacing)
    return resampled
200
+
201
def resampleResize(image: sitk.Image, size : list[int] = [100,512,512]):
    """Resample ``image`` to a fixed ``size``.

    The direction and origin are kept. Each spacing value is scaled by the
    ratio of the old size to the new size.
    """
    voxels = torch.tensor(sitk.GetArrayFromImage(image)).unsqueeze(0)
    resized = sitk.GetImageFromArray(_resample(voxels, size).squeeze(0).numpy())
    resized.SetDirection(image.GetDirection())
    resized.SetOrigin(image.GetOrigin())
    # NOTE(review): GetSize()/GetSpacing() are zipped with `size` position by position — confirm the axis order of `size`.
    new_spacing = [old_size/new_size*old_spacing for old_size, new_size, old_spacing in zip(image.GetSize(), size, image.GetSpacing())]
    resized.SetSpacing(new_spacing)
    return resized
207
+
208
def box_with_mask(mask: sitk.Image, label: list[int], dilatations: list[int]) -> np.ndarray:
    """Bounding box of the voxels whose value is in ``label``, enlarged by a margin.

    ``dilatations`` is divided by the reversed image spacing and rounded up to
    get a margin in voxels. Returns one ``[start, stop]`` row per array axis,
    with ``start`` clamped to 0 and ``stop`` clamped to the axis length.
    """
    margins = [int(np.ceil(distance/step)) for distance, step in zip(dilatations, reversed(mask.GetSpacing()))]

    data = sitk.GetArrayFromImage(mask)
    hits = np.where(np.isin(data, label))
    # NOTE(review): stop is max index + margin; used as an exclusive bound it drops the last labelled index when the margin is 0 — confirm.
    box = [
        [max(np.min(indices)-margin, 0), min(np.max(indices)+margin, length)]
        for indices, margin, length in zip(hits, margins, data.shape)
    ]
    return np.asarray(box)
219
+
220
def crop_with_mask(image: sitk.Image, box: np.ndarray) -> sitk.Image:
    """Crop ``image`` to ``box`` and shift its origin accordingly.

    ``box`` holds one ``[start, stop]`` row per array axis. The spacing and
    direction are kept.
    """
    cropped = sitk.GetArrayFromImage(image)
    for axis, bounds in enumerate(box):
        # Drop the tail first so the head indices are still valid afterwards.
        cropped = np.delete(cropped, slice(bounds[1], cropped.shape[axis]), axis)
        cropped = np.delete(cropped, slice(0, bounds[0]), axis)

    origin = np.asarray(image.GetOrigin())
    dimension = len(origin)
    direction = np.asarray(image.GetDirection()).reshape((dimension, dimension))
    # Move the origin along the image axes, then map it back.
    origin = origin.dot(direction)
    for axis, bounds in enumerate(box):
        # Array axes are in reverse order of the image axes, hence the negative index.
        origin[-axis-1] += bounds[0]*np.asarray(image.GetSpacing())[-axis-1]
    origin = origin.dot(np.linalg.inv(direction))

    result = sitk.GetImageFromArray(cropped)
    result.SetOrigin(origin)
    result.SetSpacing(image.GetSpacing())
    result.SetDirection(image.GetDirection())
    return result
239
+
240
def formatMaskLabel(mask: sitk.Image, labels: list[tuple[int, int]]) -> sitk.Image:
    """Remap mask labels; every voxel not listed becomes 0.

    ``labels`` holds ``(old, new)`` pairs. The result is ``uint8`` and carries
    the geometry of ``mask``.
    """
    source = sitk.GetArrayFromImage(mask)
    remapped = np.zeros_like(source, np.uint8)

    for old_value, new_value in labels:
        remapped[source == old_value] = new_value

    result = sitk.GetImageFromArray(remapped)
    result.CopyInformation(mask)
    return result
250
+
251
def getFlatLabel(mask: sitk.Image, labels: Union[None, list[int]] = None) -> sitk.Image:
    """Binarise a mask.

    A voxel becomes 1 when its value is in ``labels``. When ``labels`` is
    ``None``, every strictly positive voxel becomes 1. The result is ``uint8``
    and carries the geometry of ``mask``.
    """
    source = sitk.GetArrayFromImage(mask)
    flat = np.zeros_like(source, np.uint8)
    if labels is None:
        flat[source > 0] = 1
    else:
        for wanted in labels:
            flat[source == wanted] = 1

    result = sitk.GetImageFromArray(flat)
    result.CopyInformation(mask)
    return result
262
+
263
def clipAndCast(image : sitk.Image, min: float, max: float, dtype: np.dtype) -> sitk.Image:
    """Clamp the voxel values to ``[min, max]``, then cast them to ``dtype``.

    The clamping is done in the image's own dtype before the cast. The result
    carries the geometry of ``image``.
    """
    voxels = sitk.GetArrayFromImage(image)
    # Upper bound first, then lower bound, both written in place.
    voxels[voxels > max] = max
    voxels[voxels < min] = min

    clipped = sitk.GetImageFromArray(voxels.astype(dtype))
    clipped.CopyInformation(image)
    return clipped