konfai 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of konfai might be problematic. Click here for more details.
- konfai/__init__.py +16 -0
- konfai/data/HDF5.py +326 -0
- konfai/data/__init__.py +0 -0
- konfai/data/augmentation.py +597 -0
- konfai/data/dataset.py +470 -0
- konfai/data/transform.py +536 -0
- konfai/evaluator.py +146 -0
- konfai/main.py +43 -0
- konfai/metric/__init__.py +0 -0
- konfai/metric/measure.py +488 -0
- konfai/metric/schedulers.py +49 -0
- konfai/models/classification/convNeXt.py +175 -0
- konfai/models/classification/resnet.py +116 -0
- konfai/models/generation/cStyleGan.py +137 -0
- konfai/models/generation/ddpm.py +218 -0
- konfai/models/generation/diffusionGan.py +557 -0
- konfai/models/generation/gan.py +134 -0
- konfai/models/generation/vae.py +72 -0
- konfai/models/registration/registration.py +136 -0
- konfai/models/representation/representation.py +57 -0
- konfai/models/segmentation/NestedUNet.py +53 -0
- konfai/models/segmentation/UNet.py +58 -0
- konfai/network/__init__.py +0 -0
- konfai/network/blocks.py +348 -0
- konfai/network/network.py +950 -0
- konfai/predictor.py +366 -0
- konfai/trainer.py +330 -0
- konfai/utils/ITK.py +269 -0
- konfai/utils/Registration.py +199 -0
- konfai/utils/__init__.py +0 -0
- konfai/utils/config.py +218 -0
- konfai/utils/dataset.py +764 -0
- konfai/utils/utils.py +493 -0
- konfai-1.0.0.dist-info/METADATA +68 -0
- konfai-1.0.0.dist-info/RECORD +39 -0
- konfai-1.0.0.dist-info/WHEEL +5 -0
- konfai-1.0.0.dist-info/entry_points.txt +3 -0
- konfai-1.0.0.dist-info/licenses/LICENSE +201 -0
- konfai-1.0.0.dist-info/top_level.txt +1 -0
konfai/trainer.py
ADDED
|
@@ -0,0 +1,330 @@
|
|
|
1
|
+
import torch
|
|
2
|
+
import torch.optim.adamw
|
|
3
|
+
from torch.utils.data import DataLoader
|
|
4
|
+
import tqdm
|
|
5
|
+
import numpy as np
|
|
6
|
+
import os
|
|
7
|
+
from typing import Union
|
|
8
|
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
9
|
+
|
|
10
|
+
import shutil
|
|
11
|
+
from torch.utils.tensorboard.writer import SummaryWriter
|
|
12
|
+
from torch.optim.swa_utils import AveragedModel
|
|
13
|
+
import torch.distributed as dist
|
|
14
|
+
|
|
15
|
+
from KonfAI.konfai import MODELS_DIRECTORY, CHECKPOINTS_DIRECTORY, STATISTICS_DIRECTORY, SETUPS_DIRECTORY, CONFIG_FILE, MODEL, DATE, DL_API_STATE
|
|
16
|
+
from KonfAI.konfai.data.dataset import DataTrain
|
|
17
|
+
from KonfAI.konfai.utils.config import config
|
|
18
|
+
from KonfAI.konfai.utils.utils import State, DataLog, DistributedObject, description
|
|
19
|
+
from KonfAI.konfai.network.network import Network, ModelLoader, NetState, CPU_Model
|
|
20
|
+
|
|
21
|
+
|
|
22
|
+
class _Trainer():
    """Per-rank training loop driven by :class:`Trainer`.

    Runs the epochs, logs and (optionally) validates every ``it_validation``
    iterations, keeps the optional EMA model up to date and writes
    checkpoints. Used as a context manager so the TensorBoard writer is
    always closed.
    """

    def __init__(self, world_size: int, global_rank: int, local_rank: int, size: int, train_name: str, data_log: Union[list[str], None] , save_checkpoint_mode: str, epochs: int, epoch: int, autocast: bool, it_validation: Union[int, None], it: int, model: Union[DDP, CPU_Model], modelEMA: AveragedModel, dataloader_training: DataLoader, dataloader_validation: Union[DataLoader, None] = None) -> None:
        """Store the training state and open the TensorBoard writer.

        Args:
            world_size: total number of processes.
            global_rank: rank of this process; only rank 0 writes logs and checkpoints.
            local_rank: rank of this process on its node.
            size: number of devices used by one model replica.
            train_name: run name, used as sub-directory for statistics and checkpoints.
            data_log: optional ``"<name>/<DataLog member>/<count>"`` entries selecting
                which tensors to log and how many samples of each.
            save_checkpoint_mode: ``"BEST"`` keeps only the best checkpoint; any other
                value keeps every checkpoint.
            epochs: total number of epochs.
            epoch: epoch to start (or resume) from.
            autocast: enable CUDA automatic mixed precision.
            it_validation: iterations between two log/validation steps; defaults to
                ``len(dataloader_training)`` (once per epoch).
            it: iteration counter to start (or resume) from.
            model: wrapped network (DDP or CPU_Model).
            modelEMA: optional exponential-moving-average copy of the network.
            dataloader_training: training loader.
            dataloader_validation: optional validation loader.
        """
        self.world_size = world_size
        self.global_rank = global_rank
        self.local_rank = local_rank
        self.size = size

        self.save_checkpoint_mode = save_checkpoint_mode
        self.train_name = train_name
        self.epochs = epochs
        self.epoch = epoch
        self.model = model
        self.dataloader_training = dataloader_training
        self.dataloader_validation = dataloader_validation
        self.autocast = autocast
        self.modelEMA = modelEMA

        self.it_validation = it_validation
        if self.it_validation is None:
            self.it_validation = len(dataloader_training)
        self.it = it
        self.tb = SummaryWriter(log_dir = STATISTICS_DIRECTORY()+self.train_name+"/")
        # name -> (logging callable, number of samples to log)
        self.data_log : dict[str, tuple[DataLog, int]] = {}
        if data_log is not None:
            for data in data_log:
                parts = data.split("/")
                self.data_log[parts[0].replace(":", ".")] = (DataLog.__getitem__(parts[1]).value[0], int(parts[2]))

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        # Close the TensorBoard writer even when training raised.
        if self.tb is not None:
            self.tb.close()

    def run(self) -> None:
        """Train from ``self.epoch`` up to ``self.epochs`` (exclusive).

        The loop variable is ``self.epoch`` itself so checkpoints record the
        current epoch. The training dataset is loaded before, and its
        augmentations reset after, each epoch.
        """
        with tqdm.tqdm(iterable = range(self.epoch, self.epochs), leave=False, total=self.epochs, initial=self.epoch, desc="Progress", disable=self.global_rank != 0) as epoch_tqdm:
            for self.epoch in epoch_tqdm:
                self.dataloader_training.dataset.load()
                self.train()
                self.dataloader_training.dataset.resetAugmentation()

    def getInput(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int, str, bool]]) -> dict[tuple[str, bool], torch.Tensor]:
        """Turn a collated batch into the network input.

        Keys become ``(name, flag)``, where ``flag`` is the first value of field 5
        of each entry; values are the tensors (field 0). ``_log`` keeps only
        entries whose flag is true when extracting layers (presumably the flag
        marks actual network inputs — confirm against the dataset).
        """
        return {(k, v[5][0].item()) : v[0] for k, v in data_dict.items()}

    def train(self) -> None:
        """Run one training epoch.

        Every ``it_validation`` iterations: log training measures, run a
        validation pass when a validation loader exists, step the learning-rate
        schedulers and save a checkpoint with the resulting loss.
        """
        self.model.train()
        self.model.module.setState(NetState.TRAIN)
        if self.modelEMA is not None:
            # The EMA copy only receives averaged weights, so it stays in eval mode.
            self.modelEMA.eval()
            self.modelEMA.module.setState(NetState.TRAIN)

        desc = lambda : "Training : {}".format(description(self.model, self.modelEMA))
        with tqdm.tqdm(iterable = enumerate(self.dataloader_training), desc = desc(), total=len(self.dataloader_training), leave=False, disable=self.global_rank != 0 and "DL_API_CLUSTER" not in os.environ) as batch_iter:
            for _, data_dict in batch_iter:
                with torch.amp.autocast('cuda', enabled=self.autocast):
                    input = self.getInput(data_dict)
                    self.model(input)
                    self.model.module.backward(self.model.module)
                    if self.modelEMA is not None:
                        self.modelEMA.update_parameters(self.model)
                        # Forward the EMA model too (presumably so its measures get logged).
                        self.modelEMA.module(input)
                    self.it += 1
                    if (self.it) % self.it_validation == 0:
                        loss = self._train_log(data_dict)

                        if self.dataloader_validation is not None:
                            loss = self._validate()
                        self.model.module.update_lr()
                        self.checkpoint_save(loss)

                batch_iter.set_description(desc())

    @torch.no_grad()
    def _validate(self) -> float:
        """Run one pass over the validation loader and return the logged loss.

        Models are switched to prediction state for the pass and back to
        training state afterwards. Returns what ``_validation_log`` returns
        (mean loss on rank 0, None on other ranks).
        """
        self.model.eval()
        self.model.module.setState(NetState.PREDICTION)
        if self.modelEMA is not None:
            self.modelEMA.module.setState(NetState.PREDICTION)

        desc = lambda : "Validation : {}".format(description(self.model, self.modelEMA))
        # Stays None when the validation loader yields no batch; _log handles it.
        data_dict = None
        self.dataloader_validation.dataset.load()
        with tqdm.tqdm(iterable = enumerate(self.dataloader_validation), desc = desc(), total=len(self.dataloader_validation), leave=False, disable=self.global_rank != 0 and "DL_API_CLUSTER" not in os.environ) as batch_iter:
            for _, data_dict in batch_iter:
                input = self.getInput(data_dict)
                self.model(input)
                if self.modelEMA is not None:
                    self.modelEMA.module(input)

                batch_iter.set_description(desc())
        self.dataloader_validation.dataset.resetAugmentation()
        # Synchronise all ranks before switching back to training.
        dist.barrier()
        self.model.train()
        self.model.module.setState(NetState.TRAIN)
        if self.modelEMA is not None:
            self.modelEMA.module.setState(NetState.TRAIN)
        return self._validation_log(data_dict)

    def checkpoint_save(self, loss: float) -> None:
        """Write a checkpoint on rank 0.

        In ``"BEST"`` mode the latest checkpoint is replaced only when its
        stored loss is greater than or equal to ``loss``; in any other mode a new
        checkpoint is always written. The file is named ``DATE() + ".pt"`` and
        holds epoch, iteration, loss, model (and EMA) weights and optimizer states.
        """
        if self.global_rank != 0:
            return
        path = CHECKPOINTS_DIRECTORY()+self.train_name+"/"
        last_loss = None
        if self.save_checkpoint_mode == "BEST" and os.path.exists(path) and os.listdir(path):
            name = sorted(os.listdir(path))[-1]
            # Only the stored loss is needed: load on CPU so the previous
            # checkpoint's tensors are not rebuilt on the training GPU.
            state_dict = torch.load(path+name, weights_only=False, map_location="cpu")
            last_loss = state_dict["loss"]
            if last_loss >= loss:
                os.remove(path+name)

        if self.save_checkpoint_mode != "BEST" or (last_loss is None or last_loss >= loss):
            name = DATE()+".pt"
            if not os.path.exists(path):
                os.makedirs(path)

            save_dict = {
                "epoch": self.epoch,
                "it": self.it,
                "loss": loss,
                "Model": self.model.module.state_dict()}

            if self.modelEMA is not None:
                save_dict.update({"Model_EMA" : self.modelEMA.module.state_dict()})

            save_dict.update({'{}_optimizer_state_dict'.format(network_name): network.optimizer.state_dict() for network_name, network in self.model.module.getNetworks().items() if network.optimizer is not None})
            torch.save(save_dict, path+name)

    @torch.no_grad()
    def _log(self, type_log: str, data_dict : Union[dict[str, tuple[torch.Tensor, int, int, int]], None]) -> float:
        """Log measures (and requested data/layers) to TensorBoard; return the mean loss.

        ``DistributedObject.getMeasure`` is called on every rank (presumably a
        collective gathering the measures — confirm). Only rank 0 writes to
        TensorBoard and returns a value: the mean over networks of the summed
        loss values of the main (non-EMA) model. Other ranks return None.

        Args:
            type_log: TensorBoard tag prefix (``"Trainning"`` or ``"Validation"``).
            data_dict: last batch, or None when no batch was processed; data/layer
                logging is skipped in that case.
        """
        models: dict[str, Network] = {"" : self.model.module}
        if self.modelEMA is not None:
            models["_EMA"] = self.modelEMA.module

        measures = DistributedObject.getMeasure(self.world_size, self.global_rank, self.local_rank*self.size+self.size-1, models, self.it_validation if type_log == "Trainning" else len(self.dataloader_validation))

        if self.global_rank != 0:
            return None

        images_log = []
        if len(self.data_log) and data_dict is not None:
            for name, data_type in self.data_log.items():
                if name in data_dict:
                    data_type[0](self.tb, "{}/{}".format(type_log ,name), data_dict[name][0][:self.data_log[name][1]].detach().cpu().numpy(), self.it)
                else:
                    # Not a dataset field: logged from the network's layers below.
                    images_log.append(name.replace(":", "."))

        for label, model in models.items():
            for name, network in model.getNetworks().items():
                if network.measure is not None:
                    network_measures = measures["{}{}".format(name, label)]
                    # Each entry is (weight, value).
                    self.tb.add_scalars("{}/{}/Loss/{}".format(type_log, name, label), {k : v[1] for k, v in network_measures[0].items()}, self.it)
                    self.tb.add_scalars("{}/{}/Loss_weight/{}".format(type_log, name, label), {k : v[0] for k, v in network_measures[0].items()}, self.it)

                    self.tb.add_scalars("{}/{}/Metric/{}".format(type_log, name, label), {k : v[1] for k, v in network_measures[1].items()}, self.it)
                    self.tb.add_scalars("{}/{}/Metric_weight/{}".format(type_log, name, label), {k : v[0] for k, v in network_measures[1].items()}, self.it)

            if len(images_log):
                for name, layer, _ in model.get_layers([v.to(0) for k, v in self.getInput(data_dict).items() if k[1]], images_log):
                    self.data_log[name][0](self.tb, "{}/{}{}".format(type_log, name, label), layer[:self.data_log[name][1]].detach().cpu().numpy(), self.it)

        if type_log == "Trainning":
            for name, network in self.model.module.getNetworks().items():
                if network.optimizer is not None:
                    self.tb.add_scalar("{}/{}/Learning Rate".format(type_log, name), network.optimizer.param_groups[0]['lr'], self.it)

        loss = []
        for name, network in self.model.module.getNetworks().items():
            if network.measure is not None:
                loss.append(sum([v[1] for v in measures["{}".format(name)][0].values()]))
        return np.mean(loss)

    @torch.no_grad()
    def _train_log(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]) -> float:
        """Log the training measures for ``data_dict``; see ``_log``."""
        return self._log("Trainning", data_dict)

    @torch.no_grad()
    def _validation_log(self, data_dict : Union[dict[str, tuple[torch.Tensor, int, int, int]], None]) -> float:
        """Log the validation measures for ``data_dict``; see ``_log``."""
        return self._log("Validation", data_dict)
|
|
201
|
+
|
|
202
|
+
class Trainer(DistributedObject):
    """Configurable training entry point.

    Builds the model and dataset from the configuration, prepares run
    directories and the optional EMA model in ``setup``, runs ``_Trainer`` on
    each rank in ``run_process`` and exports checkpoints to the models
    directory on exit.
    """

    @config("Trainer")
    def __init__( self,
                    model : ModelLoader = ModelLoader(),
                    dataset : DataTrain = DataTrain(),
                    train_name : str = "default:name",
                    manual_seed : Union[int, None] = None,
                    epochs: int = 100,
                    it_validation : Union[int, None] = None,
                    autocast : bool = False,
                    gradient_checkpoints: Union[list[str], None] = None,
                    gpu_checkpoints: Union[list[str], None] = None,
                    ema_decay : float = 0,
                    data_log: Union[list[str], None] = None,
                    save_checkpoint_mode: str= "BEST") -> None:
        """Store the training configuration.

        Exits the process (status 0) unless ``DEEP_LEANING_API_CONFIG_MODE`` is
        ``"Done"``.

        Args:
            model: loader producing the network (``getModel(train=True)``).
            dataset: training data description (``getData`` yields the loaders).
            train_name: run name, passed to ``DistributedObject`` (used as ``self.name``
                for run directories — presumably stored by the base class).
            manual_seed: stored as ``self.manual_seed``; not used in this class.
            epochs: total number of epochs.
            it_validation: iterations between two log/validation steps (None = one epoch).
            autocast: enable CUDA automatic mixed precision.
            gradient_checkpoints: forwarded to ``_compute_channels_trace``.
            gpu_checkpoints: forwarded to ``_compute_channels_trace``; each replica
                uses ``len(gpu_checkpoints) + 1`` devices.
            ema_decay: when > 0, an EMA copy is kept; see ``_avg_fn``.
            data_log: tensors to log, see ``_Trainer``.
            save_checkpoint_mode: ``"BEST"`` or any other value to keep all checkpoints.
        """
        if os.environ["DEEP_LEANING_API_CONFIG_MODE"] != "Done":
            raise SystemExit(0)
        super().__init__(train_name)
        self.manual_seed = manual_seed
        self.dataset = dataset
        self.autocast = autocast
        self.epochs = epochs
        self.epoch = 0
        self.it = 0
        self.it_validation = it_validation
        self.model = model.getModel(train=True)
        self.ema_decay = ema_decay
        self.modelEMA : Union[torch.optim.swa_utils.AveragedModel, None] = None
        self.data_log = data_log
        self.gradient_checkpoints = gradient_checkpoints
        self.gpu_checkpoints = gpu_checkpoints
        self.save_checkpoint_mode = save_checkpoint_mode
        self.config_namefile_src = CONFIG_FILE().replace(".yml", "")
        self.config_namefile = SETUPS_DIRECTORY()+self.name+"/"+self.config_namefile_src.split("/")[-1]+"_"+str(self.it)+".yml"
        self.size = (len(self.gpu_checkpoints)+1 if self.gpu_checkpoints else 1)

    def __exit__(self, type, value, traceback):
        super().__exit__(type, value, traceback)
        self._save()

    def _load(self) -> dict[str, dict[str, torch.Tensor]]:
        """Load the state to resume from.

        Sources, in order: a ``"https://<url>:<key>"`` value of ``MODEL()``
        (downloaded and wrapped as ``{key: state}``), an explicit ``MODEL()``
        path, or the latest file in this run's checkpoint directory. Restores
        ``self.epoch`` / ``self.it`` when present.

        Raises:
            Exception: the model cannot be downloaded or no checkpoint exists.
        """
        model_source = MODEL()
        if model_source.startswith("https://"):
            # Split on the LAST colon so the "https:" scheme stays in the URL.
            url, _, key = model_source.rpartition(":")
            try:
                state_dict = {key: torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)}
            except Exception as e:
                raise Exception("Model : {} does not exist !".format(MODEL())) from e
        else:
            if model_source != "":
                path = ""
                name = model_source
            else:
                path = CHECKPOINTS_DIRECTORY()+self.name+"/"
                checkpoints = sorted(os.listdir(path)) if os.path.isdir(path) else []
                # Empty name when there is no checkpoint: reported below.
                name = checkpoints[-1] if checkpoints else ""

            if name and os.path.exists(path+name):
                state_dict = torch.load(path+name, weights_only=False)
            else:
                raise Exception("Model : {} does not exist !".format(self.name))

        if "epoch" in state_dict:
            self.epoch = state_dict['epoch']
        if "it" in state_dict:
            self.it = state_dict['it']
        return state_dict

    def _save(self) -> None:
        """Export every checkpoint of this run to the models directory.

        For each checkpoint file the full module goes to ``Serialized/`` and its
        state dict to ``StateDict/``; with an EMA model, its weights are exported
        as ``<checkpoint>_EMA.pt``. The copied configuration file is then renamed
        with the current iteration appended.
        """
        path_checkpoint = CHECKPOINTS_DIRECTORY()+self.name+"/"
        path_model = MODELS_DIRECTORY()+self.name+"/"
        if os.path.exists(path_checkpoint) and os.listdir(path_checkpoint):
            for directory in [path_model, "{}Serialized/".format(path_model), "{}StateDict/".format(path_model)]:
                os.makedirs(directory, exist_ok=True)

            for name in sorted(os.listdir(path_checkpoint)):
                checkpoint = torch.load(path_checkpoint+name, weights_only=False)
                self.model.load(checkpoint, init=False, ema=False)

                torch.save(self.model, "{}Serialized/{}".format(path_model, name))
                torch.save({"Model" : self.model.state_dict()}, "{}StateDict/{}".format(path_model, name))

                if self.modelEMA is not None:
                    self.modelEMA.module.load(checkpoint, init=False, ema=True)
                    # Name the EMA export after its checkpoint: one traceable file
                    # per checkpoint instead of wall-clock names that may collide.
                    ema_name = os.path.splitext(name)[0]+"_EMA.pt"
                    torch.save(self.modelEMA.module, "{}Serialized/{}".format(path_model, ema_name))
                    torch.save({"Model_EMA" : self.modelEMA.module.state_dict()}, "{}StateDict/{}".format(path_model, ema_name))

        # setup() may have returned before the configuration file was copied.
        if os.path.exists(self.config_namefile):
            os.rename(self.config_namefile, self.config_namefile.replace(".yml", "")+"_"+str(self.it)+".yml")

    def _avg_fn(self, averaged_model_parameter, model_parameter, num_averaged):
        """EMA update used by ``AveragedModel``: ``ema_decay`` weights the current parameters."""
        return (1-self.ema_decay) * averaged_model_parameter + self.ema_decay * model_parameter

    def setup(self, world_size: int):
        """Prepare the run: directories, model state, EMA copy, config copy and loaders.

        Unless resuming, an existing run with the same name is deleted (after
        confirmation unless ``DL_API_OVERWRITE`` is ``"True"``); declining
        aborts setup.
        """
        state = State[DL_API_STATE()]
        if state != State.RESUME and os.path.exists(STATISTICS_DIRECTORY()+self.name+"/"):
            if os.environ["DL_API_OVERWRITE"] != "True":
                accept = input("The model {} already exists ! Do you want to overwrite it (yes,no) : ".format(self.name))
                if accept != "yes":
                    return
            for directory_path in [STATISTICS_DIRECTORY(), MODELS_DIRECTORY(), CHECKPOINTS_DIRECTORY(), SETUPS_DIRECTORY()]:
                if os.path.exists(directory_path+self.name+"/"):
                    shutil.rmtree(directory_path+self.name+"/")

        state_dict = {}
        if state != State.TRAIN:
            state_dict = self._load()

        self.model.init(self.autocast, state)
        self.model.init_outputsGroup()
        self.model._compute_channels_trace(self.model, self.model.in_channels, self.gradient_checkpoints, self.gpu_checkpoints)
        self.model.load(state_dict, init=True, ema=False)
        if self.ema_decay > 0:
            self.modelEMA = AveragedModel(self.model, avg_fn=self._avg_fn)
            if state_dict is not None:
                self.modelEMA.module.load(state_dict, init=False, ema=True)

        if not os.path.exists(SETUPS_DIRECTORY()+self.name+"/"):
            os.makedirs(SETUPS_DIRECTORY()+self.name+"/")
        shutil.copyfile(self.config_namefile_src+".yml", self.config_namefile)

        self.dataloader = self.dataset.getData(world_size//self.size)

    def run_process(self, world_size: int, global_rank: int, local_rank: int, dataloaders: list[DataLoader]):
        """Train on this rank: place and wrap the model, then run ``_Trainer``."""
        # Presumably each replica spans self.size devices starting at local_rank*self.size.
        first_device = local_rank*self.size
        model = Network.to(self.model, first_device)
        model = DDP(model, static_graph=True) if torch.cuda.is_available() else CPU_Model(model)
        if self.modelEMA is not None:
            # Same first device as the trained replica, so EMA forwards see inputs on the right device.
            self.modelEMA.module = Network.to(self.modelEMA.module, first_device)
        with _Trainer(world_size, global_rank, local_rank, self.size, self.name, self.data_log, self.save_checkpoint_mode, self.epochs, self.epoch, self.autocast, self.it_validation, self.it, model, self.modelEMA, *dataloaders) as t:
            t.run()
|
konfai/utils/ITK.py
ADDED
|
@@ -0,0 +1,269 @@
|
|
|
1
|
+
import SimpleITK as sitk
|
|
2
|
+
from typing import Union
|
|
3
|
+
import numpy as np
|
|
4
|
+
import torch
|
|
5
|
+
import scipy
|
|
6
|
+
import torch.nn.functional as F
|
|
7
|
+
from KonfAI.konfai.utils.utils import _resample
|
|
8
|
+
|
|
9
|
+
def _openTransform(transform_files: dict[Union[str, sitk.Transform], bool], image: sitk.Image= None) -> list[sitk.Transform]:
    """Read (or wrap) transforms and invert the ones flagged for inversion.

    Args:
        transform_files: maps either a path prefix (``"<prefix>.itk.txt"`` is read)
            or an existing ``sitk.Transform`` to a flag telling whether to invert it.
        image: reference grid on which non-linear (displacement-field or B-spline)
            transforms are inverted; required only when such a transform is inverted.

    Returns:
        Exactly one transform per entry, in iteration order, or a single identity
        ``Euler3DTransform`` when ``transform_files`` is empty.

    Raises:
        ValueError: a non-linear transform must be inverted but ``image`` is None.
    """
    # Linear transforms: re-typed to their concrete class and inverted analytically.
    linear_types = {
        "TranslationTransform": sitk.TranslationTransform,
        "Euler3DTransform": sitk.Euler3DTransform,
        "VersorRigid3DTransform": sitk.VersorRigid3DTransform,
        "AffineTransform": sitk.AffineTransform,
    }
    transforms: list[sitk.Transform] = []

    for transform_file, invert in transform_files.items():
        if isinstance(transform_file, str):
            transform = sitk.ReadTransform(transform_file+".itk.txt")
        else:
            transform = transform_file

        transform_type = transform.GetName()
        if transform_type in linear_types:
            cast = linear_types[transform_type]
            transform = cast(transform)
            if invert:
                transform = cast(transform.GetInverse())
        else:
            if transform_type != "DisplacementFieldTransform":
                # Every other transform type is treated as a B-spline.
                transform = sitk.BSplineTransform(transform)
            if invert:
                transform = _invertNonLinearTransform(transform, image)
        transforms.append(transform)

    if len(transforms) == 0:
        transforms.append(sitk.Euler3DTransform())
    return transforms


def _invertNonLinearTransform(transform: sitk.Transform, image: sitk.Image) -> sitk.DisplacementFieldTransform:
    """Numerically invert ``transform`` on the grid of ``image`` (20 fixed-point iterations)."""
    if image is None:
        raise ValueError("A reference image is required to invert a non-linear transform")
    to_field = sitk.TransformToDisplacementFieldFilter()
    to_field.SetReferenceImage(image)
    displacement_field = to_field.Execute(transform)
    inverter = sitk.IterativeInverseDisplacementFieldImageFilter()
    inverter.SetNumberOfIterations(20)
    return sitk.DisplacementFieldTransform(inverter.Execute(displacement_field))
|
|
57
|
+
|
|
58
|
+
def _openRigidTransform(transform_files: dict[Union[str, sitk.Transform], bool]) -> tuple[np.ndarray, np.ndarray]:
    """Collapse the chain of linear transforms in ``transform_files`` into one matrix and translation.

    Transforms exposing ``GetMatrix`` contribute their matrix, translation
    and centre; others (translation-only) contribute an identity matrix and
    their ``GetOffset``.

    Returns:
        ``(matrix, translation)`` of the accumulated transform.
    """
    accumulated_matrix = np.identity(3)
    accumulated_translation = np.array([0,0,0])

    for current in _openTransform(transform_files):
        if not hasattr(current, "GetMatrix"):
            offset = current.GetOffset()
            linear = np.eye(len(offset))
            shift = -np.asarray(offset, dtype=np.double)
            pivot = np.asarray([0]*len(offset), dtype=np.double)
        else:
            linear = np.linalg.inv(np.array(current.GetMatrix(), dtype=np.double).reshape((3,3)))
            shift = -np.asarray(current.GetTranslation(), dtype=np.double)
            pivot = np.asarray(current.GetCenter(), dtype=np.double)

        # Fold the centre of rotation into the translation, then accumulate.
        centred_shift = np.linalg.inv(linear).dot(linear.dot(shift - pivot) + pivot)
        accumulated_translation = np.linalg.inv(accumulated_matrix).dot(centred_shift) + accumulated_translation
        accumulated_matrix = linear.dot(accumulated_matrix)

    return np.linalg.inv(accumulated_matrix), -accumulated_translation
|
|
77
|
+
|
|
78
|
+
def composeTransform(transform_files : dict[Union[str, sitk.Transform], bool], image : sitk.Image = None) -> sitk.CompositeTransform:
    """Compose the transforms in ``transform_files`` into one ``sitk.CompositeTransform``.

    Args:
        transform_files: path prefix or ``sitk.Transform`` -> invert flag, as read by ``_openTransform``.
        image: reference grid used when non-linear transforms must be inverted.

    Returns:
        A composite of the opened transforms (identity when ``transform_files`` is empty).
    """
    transforms = _openTransform(transform_files, image)
    return sitk.CompositeTransform(transforms)
|
|
82
|
+
|
|
83
|
+
def flattenTransform(transform_files: dict[Union[str, sitk.Transform], bool]) -> sitk.AffineTransform:
    """Return the chain of linear transforms as a single 3D ``sitk.AffineTransform``."""
    linear_part, offset_part = _openRigidTransform(transform_files)
    flattened = sitk.AffineTransform(3)
    flattened.SetMatrix(linear_part.flatten())
    flattened.SetTranslation(offset_part)
    return flattened
|
|
89
|
+
|
|
90
|
+
def apply_to_image_RigidTransform(image: sitk.Image, transform_files: dict[Union[str, sitk.Transform], bool]) -> sitk.Image:
    """Apply a chain of linear transforms to ``image`` by rewriting its header only.

    Voxel values are copied unchanged. Direction and origin are recomputed
    from the inverse of the flattened transform; spacing is kept.
    """
    forward_matrix, forward_translation = _openRigidTransform(transform_files)
    inverse_matrix = np.linalg.inv(forward_matrix)
    inverse_translation = -forward_translation

    moved = sitk.GetImageFromArray(sitk.GetArrayFromImage(image))
    old_direction = np.array(image.GetDirection()).reshape((3,3))
    moved.SetDirection(inverse_matrix.dot(old_direction).flatten())
    moved.SetOrigin(inverse_matrix.dot(np.array(image.GetOrigin()) + inverse_translation))
    moved.SetSpacing(image.GetSpacing())
    return moved
|
|
100
|
+
|
|
101
|
+
def apply_to_data_Transform(data: np.ndarray, transform_files: dict[Union[str, sitk.Transform], bool]) -> np.ndarray:
    """Map an array of 3D points through the composed transform.

    Each row ``data[i, :]`` is one point. Its x and y coordinates are negated
    before and after ``TransformPoint``; presumably the points are given in a
    RAS-style frame and converted to/from ITK's LPS frame (verify with callers).

    Returns:
        A copy of ``data`` (same dtype, so integer inputs are truncated) holding
        the transformed points.
    """
    transform = composeTransform(transform_files)
    result = np.copy(data)

    def _LPS(point):
        # Negate x and y, keep z.
        return np.array([-point[0], -point[1], point[2]], dtype=np.double)

    for i in range(data.shape[0]):
        result[i, :] = _LPS(transform.TransformPoint(np.asarray(_LPS(data[i, :]), dtype=np.double)))
    return result
|
|
108
|
+
|
|
109
|
+
def resampleITK(image_reference : sitk.Image, image : sitk.Image, transform_files : dict[Union[str, sitk.Transform], bool], mask = False, defaultPixelValue: Union[float, None] = None, torch_resample : bool = False) -> sitk.Image:
    """Resample ``image`` onto the grid of ``image_reference`` through a transform chain.

    The composed transform maps reference-space points to ``image`` points,
    as ``sitk.Resample`` expects.

    Args:
        image_reference: image whose grid (size, origin, spacing, direction) the result takes.
        image: image to resample.
        transform_files: path prefix or ``sitk.Transform`` -> invert flag (see ``composeTransform``).
        mask: nearest-neighbour interpolation (labels) instead of B-spline / linear.
        defaultPixelValue: value outside ``image`` (SimpleITK path only); defaults to 0
            for masks and to the minimum of ``image`` otherwise.
        torch_resample: sample with ``torch.nn.functional.grid_sample`` (border padding)
            instead of ``sitk.Resample``.

    Returns:
        The resampled image, carrying ``image_reference``'s geometry.
    """
    transform = composeTransform(transform_files, image)
    if not torch_resample:
        if defaultPixelValue is None:
            defaultPixelValue = 0 if mask else float(np.min(sitk.GetArrayFromImage(image)))
        return sitk.Resample(image, image_reference, transform, sitk.sitkNearestNeighbor if mask else sitk.sitkBSpline, defaultPixelValue)

    data = torch.tensor(sitk.GetArrayFromImage(image))
    nearest = mask or data.dtype == torch.uint8
    dimension = image.GetDimension()

    # d(p) = T(p) - p sampled on the reference grid: physical units, components
    # in x, y, z order, array axes in z, y, x order (trailing component axis).
    to_field = sitk.TransformToDisplacementFieldFilter()
    to_field.SetReferenceImage(image_reference)
    to_field.SetNumberOfThreads(16)
    displacement = sitk.GetArrayFromImage(to_field.Execute(transform)).astype(np.float64)

    # Index of every reference voxel, components reordered to x, y, z.
    grid_shape = displacement.shape[:-1]
    reference_index = np.stack(np.meshgrid(*[np.arange(s, dtype=np.float64) for s in grid_shape], indexing="ij"), axis=-1)[..., ::-1]

    def _index_to_physical(img: sitk.Image) -> np.ndarray:
        # direction @ diag(spacing): maps a continuous index to a physical offset.
        return np.array(img.GetDirection(), dtype=np.float64).reshape(dimension, dimension) * np.array(img.GetSpacing(), dtype=np.float64)

    # Physical point of each reference voxel, moved by the transform, expressed as
    # a continuous index of ``image`` (x, y, z order, as grid_sample expects).
    points = reference_index @ _index_to_physical(image_reference).T + np.array(image_reference.GetOrigin(), dtype=np.float64) + displacement
    image_index = (points - np.array(image.GetOrigin(), dtype=np.float64)) @ np.linalg.inv(_index_to_physical(image)).T

    # align_corners=True: index 0 -> -1 and index size-1 -> +1 (guard 1-voxel axes).
    extent = np.maximum(np.array(image.GetSize(), dtype=np.float64) - 1, 1)
    grid = torch.from_numpy(np.ascontiguousarray(2 * image_index / extent - 1)).float().unsqueeze(0)

    result_data = F.grid_sample(data[None, None].float(), grid, align_corners=True, padding_mode="border", mode="nearest" if nearest else "bilinear")[0, 0]
    if nearest:
        result_data = result_data.to(data.dtype)
    result = sitk.GetImageFromArray(result_data.numpy())
    result.CopyInformation(image_reference)
    return result
|
|
132
|
+
|
|
133
|
+
def parameterMap_to_transform(path_src: str) -> Union[sitk.Transform, list[sitk.Transform]]:
    """Convert an elastix transform-parameter file into SimpleITK transform(s).

    Reads ``"<path_src>.0.txt"`` with ``sitk.ReadParameterFile`` and dispatches
    on its ``Transform`` entry.

    Returns:
        ``EulerTransform`` -> ``sitk.Euler3DTransform``, ``AffineTransform`` ->
        ``sitk.AffineTransform``, ``BSplineTransform`` -> ``sitk.BSplineTransform``;
        ``BSplineStackTransform`` / ``AffineLogStackTransform`` -> a list with one
        3D transform per sub-transform.

    Raises:
        NameError: the ``Transform`` type is not supported.
    """
    parameter_map = sitk.ReadParameterFile("{}.0.txt".format(path_src))
    transform_type = parameter_map["Transform"][0]

    def to_array(key: str) -> np.ndarray:
        # Parameter-file values are strings; convert them to a float array.
        return np.array([float(value) for value in parameter_map[key]])

    def bspline_grid() -> tuple[np.ndarray, np.ndarray]:
        # B-spline fixed parameters: grid size, origin, spacing, then the
        # direction matrix transposed.
        grid_size = to_array("GridSize")
        grid_direction = to_array("GridDirection").reshape((3,3)).T.flatten()
        fixed = np.concatenate([grid_size, to_array("GridOrigin"), to_array("GridSpacing"), grid_direction])
        return grid_size, fixed

    # Rigid/affine fixed parameters are the centre of rotation only.
    if transform_type == "EulerTransform":
        result = sitk.Euler3DTransform()
        parameters = to_array("TransformParameters")
        fixed_parameters = to_array("CenterOfRotationPoint")
    elif transform_type == "AffineTransform":
        result = sitk.AffineTransform(3)
        parameters = to_array("TransformParameters")
        fixed_parameters = to_array("CenterOfRotationPoint")
    elif transform_type == "BSplineStackTransform":
        parameters = to_array("TransformParameters")
        grid_size, fixed_parameters = bspline_grid()
        count = int(to_array("Size")[-1])
        # 3 coefficients per control point for each 3D sub-transform.
        chunk = int(np.prod(grid_size))*3
        results = []
        for i in range(count):
            sub_transform = sitk.BSplineTransform(3)
            sub_transform.SetFixedParameters(fixed_parameters)
            sub_transform.SetParameters(parameters[i*chunk:(i+1)*chunk])
            results.append(sub_transform)
        return results
    elif transform_type == "AffineLogStackTransform":
        parameters = to_array("TransformParameters")
        fixed_parameters = to_array("CenterOfRotationPoint")
        count = int(parameter_map["NumberOfSubTransforms"][0])
        chunk = 12
        results = []
        for i in range(count):
            sub_parameters = parameters[i*chunk:(i+1)*chunk]
            sub_transform = sitk.AffineTransform(3)
            sub_transform.SetFixedParameters(fixed_parameters)
            # First 9 values: matrix in log form, exponentiated; last 3: translation.
            sub_transform.SetParameters(np.concatenate([scipy.linalg.expm(sub_parameters[:9].reshape((3,3))).flatten(), sub_parameters[-3:]]))
            results.append(sub_transform)
        return results
    elif transform_type == "BSplineTransform":
        result = sitk.BSplineTransform(3)
        parameters = to_array("TransformParameters")
        _, fixed_parameters = bspline_grid()
    else:
        raise NameError("Transform {} doesn't exist".format(transform_type))
    result.SetFixedParameters(fixed_parameters)
    result.SetParameters(parameters)
    return result
|
|
192
|
+
|
|
193
|
+
def resampleIsotropic(image: sitk.Image, spacing : list[float] = [1., 1., 1.]) -> sitk.Image:
    """Resample ``image`` to voxel ``spacing`` (equal on every axis by default).

    The new size per axis is ``int(size * old_spacing / new_spacing)``.
    Origin and direction are kept; spacing is set to ``spacing``.
    """
    new_size = []
    for extent, old_step, new_step in zip(image.GetSize(), image.GetSpacing(), spacing):
        new_size.append(int(extent*(old_step/new_step)))

    voxels = torch.tensor(sitk.GetArrayFromImage(image)).unsqueeze(0)
    resampled = sitk.GetImageFromArray(_resample(voxels, new_size).squeeze(0).numpy())
    resampled.SetDirection(image.GetDirection())
    resampled.SetOrigin(image.GetOrigin())
    resampled.SetSpacing(spacing)
    return resampled
|
|
200
|
+
|
|
201
|
+
def resampleResize(image: sitk.Image, size : list[int] = [100,512,512]):
    """Resize ``image`` to ``size`` voxels, adjusting spacing to keep its physical extent.

    ``size`` is zipped with ``image.GetSize()`` / ``GetSpacing()``, i.e. in
    SimpleITK axis order. Origin and direction are kept.
    """
    voxels = torch.tensor(sitk.GetArrayFromImage(image)).unsqueeze(0)
    resized = sitk.GetImageFromArray(_resample(voxels, size).squeeze(0).numpy())
    resized.SetDirection(image.GetDirection())
    resized.SetOrigin(image.GetOrigin())
    new_spacing = [extent/target*step for extent, target, step in zip(image.GetSize(), size, image.GetSpacing())]
    resized.SetSpacing(new_spacing)
    return resized
|
|
207
|
+
|
|
208
|
+
def box_with_mask(mask: sitk.Image, label: list[int], dilatations: list[int]) -> np.ndarray:
    """Bounding box, in array axis order, of the voxels of ``mask`` whose value is in ``label``.

    Args:
        mask: label image.
        label: label values to include.
        dilatations: margin per array axis (reverse of SimpleITK's axis order), in the
            units of the spacing; converted to voxels by dividing by the spacing and
            rounding up.

    Returns:
        Integer array of shape ``(ndim, 2)`` with half-open ``[start, stop)`` bounds
        per array axis, clamped to the image, as consumed by ``crop_with_mask``.

    Raises:
        ValueError: none of ``label`` occurs in ``mask``.
    """
    data = sitk.GetArrayFromImage(mask)
    voxel_margins = [int(np.ceil(d/s)) for d, s in zip(dilatations, reversed(mask.GetSpacing()))]

    coordinates = np.where(np.isin(data, label))
    if len(coordinates) == 0 or coordinates[0].size == 0:
        raise ValueError("None of the labels {} is present in the mask".format(label))

    box = []
    for axis_coordinates, margin, extent in zip(coordinates, voxel_margins, data.shape):
        # stop is exclusive: +1 keeps the last labelled voxel inside the box.
        box.append([max(np.min(axis_coordinates)-margin, 0), min(np.max(axis_coordinates)+1+margin, extent)])
    return np.asarray(box)
|
|
219
|
+
|
|
220
|
+
def crop_with_mask(image: sitk.Image, box: np.ndarray) -> sitk.Image:
    """Crop ``image`` to ``box`` and shift its origin to the first kept voxel.

    ``box`` holds one ``[start, stop)`` row per array axis (array order, i.e.
    the reverse of SimpleITK's axis order). Spacing and direction are kept.
    """
    cropped = sitk.GetArrayFromImage(image)
    for axis, bounds in enumerate(box):
        cropped = np.delete(cropped, slice(bounds[1], cropped.shape[axis]), axis)
        cropped = np.delete(cropped, slice(0, bounds[0]), axis)

    spacing = np.asarray(image.GetSpacing())
    start_origin = np.asarray(image.GetOrigin())
    direction = np.asarray(image.GetDirection()).reshape((len(start_origin), len(start_origin)))
    # Express the origin in the image-aligned frame, add the per-axis shift
    # (array axis i is SimpleITK axis -i-1), then map it back.
    aligned = start_origin.dot(direction)
    for axis, bounds in enumerate(box):
        aligned[-axis-1] += bounds[0]*spacing[-axis-1]
    new_origin = aligned.dot(np.linalg.inv(direction))

    output = sitk.GetImageFromArray(cropped)
    output.SetOrigin(new_origin)
    output.SetSpacing(image.GetSpacing())
    output.SetDirection(image.GetDirection())
    return output
|
|
239
|
+
|
|
240
|
+
def formatMaskLabel(mask: sitk.Image, labels: list[tuple[int, int]]) -> sitk.Image:
    """Relabel ``mask``: every ``(old, new)`` pair maps value ``old`` to ``new``.

    Values not listed become 0; the output is uint8 with ``mask``'s geometry.
    """
    source = sitk.GetArrayFromImage(mask)
    relabelled = np.zeros_like(source, np.uint8)

    for old_value, new_value in labels:
        relabelled[source == old_value] = new_value

    output = sitk.GetImageFromArray(relabelled)
    output.CopyInformation(mask)
    return output
|
|
250
|
+
|
|
251
|
+
def getFlatLabel(mask: sitk.Image, labels: Union[None, list[int]] = None) -> sitk.Image:
    """Binarise ``mask`` into a uint8 image with ``mask``'s geometry.

    Voxels equal to any of ``labels`` become 1 (any value > 0 when ``labels``
    is None); all others become 0.
    """
    source = sitk.GetArrayFromImage(mask)
    if labels is None:
        selected = source > 0
    else:
        selected = np.zeros(source.shape, dtype=bool)
        for value in labels:
            selected |= source == value

    flat = np.zeros_like(source, np.uint8)
    flat[selected] = 1
    output = sitk.GetImageFromArray(flat)
    output.CopyInformation(mask)
    return output
|
|
262
|
+
|
|
263
|
+
def clipAndCast(image : sitk.Image, min: float, max: float, dtype: np.dtype) -> sitk.Image:
    """Clip ``image`` values to ``[min, max]`` and cast them to ``dtype``.

    Values above ``max`` are clamped first, then values below ``min``. The
    result keeps ``image``'s geometry.
    """
    values = sitk.GetArrayFromImage(image)
    too_high = values > max
    values[too_high] = max
    too_low = values < min
    values[too_low] = min

    output = sitk.GetImageFromArray(values.astype(dtype))
    output.CopyInformation(image)
    return output
|