konfai 1.0.0__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of konfai might be problematic. Click here for more details.
- konfai/__init__.py +16 -0
- konfai/data/HDF5.py +326 -0
- konfai/data/__init__.py +0 -0
- konfai/data/augmentation.py +597 -0
- konfai/data/dataset.py +470 -0
- konfai/data/transform.py +536 -0
- konfai/evaluator.py +146 -0
- konfai/main.py +43 -0
- konfai/metric/__init__.py +0 -0
- konfai/metric/measure.py +488 -0
- konfai/metric/schedulers.py +49 -0
- konfai/models/classification/convNeXt.py +175 -0
- konfai/models/classification/resnet.py +116 -0
- konfai/models/generation/cStyleGan.py +137 -0
- konfai/models/generation/ddpm.py +218 -0
- konfai/models/generation/diffusionGan.py +557 -0
- konfai/models/generation/gan.py +134 -0
- konfai/models/generation/vae.py +72 -0
- konfai/models/registration/registration.py +136 -0
- konfai/models/representation/representation.py +57 -0
- konfai/models/segmentation/NestedUNet.py +53 -0
- konfai/models/segmentation/UNet.py +58 -0
- konfai/network/__init__.py +0 -0
- konfai/network/blocks.py +348 -0
- konfai/network/network.py +950 -0
- konfai/predictor.py +366 -0
- konfai/trainer.py +330 -0
- konfai/utils/ITK.py +269 -0
- konfai/utils/Registration.py +199 -0
- konfai/utils/__init__.py +0 -0
- konfai/utils/config.py +218 -0
- konfai/utils/dataset.py +764 -0
- konfai/utils/utils.py +493 -0
- konfai-1.0.0.dist-info/METADATA +68 -0
- konfai-1.0.0.dist-info/RECORD +39 -0
- konfai-1.0.0.dist-info/WHEEL +5 -0
- konfai-1.0.0.dist-info/entry_points.txt +3 -0
- konfai-1.0.0.dist-info/licenses/LICENSE +201 -0
- konfai-1.0.0.dist-info/top_level.txt +1 -0
konfai/predictor.py
ADDED
|
@@ -0,0 +1,366 @@
|
|
|
1
|
+
from abc import ABC, abstractmethod
|
|
2
|
+
import builtins
|
|
3
|
+
import importlib
|
|
4
|
+
import shutil
|
|
5
|
+
import torch
|
|
6
|
+
import tqdm
|
|
7
|
+
import os
|
|
8
|
+
|
|
9
|
+
from KonfAI.konfai import MODELS_DIRECTORY, PREDICTIONS_DIRECTORY, CONFIG_FILE, MODEL, DEEP_LEARNING_API_ROOT
|
|
10
|
+
from KonfAI.konfai.utils.config import config
|
|
11
|
+
from KonfAI.konfai.utils.utils import State, get_patch_slices_from_nb_patch_per_dim, NeedDevice, _getModule, DistributedObject, DataLog, description
|
|
12
|
+
from KonfAI.konfai.utils.dataset import Dataset, Attribute
|
|
13
|
+
from KonfAI.konfai.data.dataset import DataPrediction, DatasetIter
|
|
14
|
+
from KonfAI.konfai.data.HDF5 import Accumulator, PathCombine
|
|
15
|
+
from KonfAI.konfai.network.network import ModelLoader, Network, NetState, CPU_Model
|
|
16
|
+
from KonfAI.konfai.data.transform import Transform, TransformLoader
|
|
17
|
+
|
|
18
|
+
from torch.utils.tensorboard.writer import SummaryWriter
|
|
19
|
+
from typing import Union
|
|
20
|
+
import numpy as np
|
|
21
|
+
import torch.distributed as dist
|
|
22
|
+
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
23
|
+
from torch.utils.data import DataLoader
|
|
24
|
+
import importlib
|
|
25
|
+
|
|
26
|
+
class OutDataset(Dataset, NeedDevice, ABC):
    """Base class for a prediction output written to a ``Dataset`` file.

    Subclasses receive per-patch network outputs in ``addLayer``, reassemble
    them with ``Accumulator`` objects and build the final tensor in
    ``getOutput``; ``write`` then stores that tensor under ``self.group``.

    Transforms are configured as ``{classpath: TransformLoader}`` mappings and
    are only instantiated in ``load``, once the owning layer name and the
    source datasets are known.
    """

    def __init__(self, filename: str, group: str, pre_transforms : dict[str, TransformLoader], post_transforms : dict[str, TransformLoader], final_transforms : dict[str, TransformLoader], patchCombine: Union[str, None]) -> None:
        """Store the configuration; transforms are instantiated later by ``load``.

        Args:
            filename: Output location written as ``"<name>:<format>"``
                (e.g. ``"Dataset:h5"``). Only the last ``:`` is treated as the
                separator, so names that themselves contain ``:`` (such as a
                Windows drive letter) are accepted.
            group: Group under which ``write`` stores the outputs.
            pre_transforms: Transform loaders keyed by classpath (may be None).
            post_transforms: Transform loaders keyed by classpath (may be None).
            final_transforms: Transform loaders keyed by classpath (may be None).
            patchCombine: Optional classpath of the patch-combination class,
                resolved by ``load``.

        Raises:
            ValueError: If ``filename`` contains no ``:`` separator.
        """
        name, separator, file_format = filename.rpartition(":")
        if not separator:
            raise ValueError("Invalid output dataset filename '{}': expected '<name>:<format>' (e.g. 'Dataset:h5')".format(filename))
        super().__init__(name, file_format)
        self.group = group

        # Raw configuration; turned into instances by ``load``.
        self._pre_transforms = pre_transforms
        self._post_transforms = post_transforms
        self._final_transforms = final_transforms
        self._patchCombine = patchCombine

        # Instantiated transforms and patch combiner, filled by ``load``.
        self.pre_transforms : list[Transform] = []
        self.post_transforms : list[Transform] = []
        self.final_transforms : list[Transform] = []
        self.patchCombine: Union[PathCombine, None] = None

        # Per-case bookkeeping, keyed by dataset index and then (as annotated)
        # by augmentation index and patch index.
        self.output_layer_accumulator: dict[int, dict[int, Accumulator]] = {}
        self.attributes: dict[int, dict[int, dict[int, Attribute]]] = {}
        self.names: dict[int, str] = {}
        self.nb_data_augmentation = 0

    def load(self, name_layer: str, datasets: list[Dataset]):
        """Instantiate the configured transforms and the optional patch combiner.

        Args:
            name_layer: Output layer name, used to build the ``DL_args`` config path.
            datasets: Datasets handed to every transform via ``setDatasets``.
        """
        for kind in ("pre_transforms", "post_transforms", "final_transforms"):
            loaders = getattr(self, "_{}".format(kind))
            instances = getattr(self, kind)
            if loaders is None:
                continue
            for classpath, loader in loaders.items():
                transform = loader.getTransform(classpath, DL_args = "{}.outsDataset.{}.OutDataset.{}".format(DEEP_LEARNING_API_ROOT(), name_layer, kind))
                transform.setDatasets(datasets)
                instances.append(transform)

        if self._patchCombine is not None:
            module, class_name = _getModule(self._patchCombine, "HDF5")
            self.patchCombine = getattr(importlib.import_module(module), class_name)(config = None, DL_args = "{}.outsDataset.{}.OutDataset".format(DEEP_LEARNING_API_ROOT(), name_layer))

    def setPatchConfig(self, patchSize: Union[list[int], None], overlap: Union[int, None], nb_data_augmentation: int) -> None:
        """Configure patch recombination and the expected number of augmentations.

        When both ``patchSize`` and ``overlap`` are given they are forwarded to
        the patch combiner (if one was loaded); otherwise patch combination is
        disabled. ``nb_data_augmentation`` is the number of per-augmentation
        accumulators ``isDone`` waits for.
        """
        if patchSize is not None and overlap is not None:
            if self.patchCombine is not None:
                self.patchCombine.setPatchConfig(patchSize, overlap)
        else:
            self.patchCombine = None
        self.nb_data_augmentation = nb_data_augmentation

    def setDevice(self, device: torch.device):
        """Set the device on this object and on every instantiated transform."""
        super().setDevice(device)
        for transforms in (self.pre_transforms, self.post_transforms, self.final_transforms):
            if transforms is not None:
                for transform in transforms:
                    transform.setDevice(device)

    # NOTE(review): the concrete subclasses in this module do not share this
    # signature (OutSameAsGroupDataset also takes an augmentation index) —
    # confirm which one is the intended contract.
    @abstractmethod
    def addLayer(self, index: int, index_patch: int, layer: torch.Tensor, dataset: DatasetIter):
        """Accumulate the prediction for one patch of case ``index``."""
        pass

    def isDone(self, index: int) -> bool:
        """Return True when all expected augmentations of case ``index`` are fully accumulated."""
        accumulators = self.output_layer_accumulator[index]
        return len(accumulators) == self.nb_data_augmentation and all(acc.isFull() for acc in accumulators.values())

    @abstractmethod
    def getOutput(self, index: int, dataset: DatasetIter) -> torch.Tensor:
        """Build the final output tensor of case ``index``."""
        pass

    def write(self, index: int, name: str, layer: torch.Tensor):
        """Write ``layer`` as ``name`` under ``self.group`` and drop the case's attributes.

        The attributes recorded for augmentation 0 / patch 0 are written with
        the data. ``Tensor.numpy`` is used, so ``layer`` must live on the CPU.
        """
        super().write(self.group, name, layer.numpy(), self.attributes[index][0][0])
        self.attributes.pop(index)
|
|
91
|
+
|
|
92
|
+
class OutSameAsGroupDataset(OutDataset):
    """Output dataset laid out like an input group of the prediction dataset.

    Each (case, augmentation) pair gets an ``Accumulator`` built from the patch
    layout of the ``group_dest`` input. Predictions can be passed through the
    inverse of that input group's transforms and of the data augmentations,
    and the augmented predictions are finally reduced with ``redution``
    (``"mean"`` or ``"median"``).
    """

    @config("OutDataset")
    def __init__(self, dataset_filename: str = "Dataset:h5", group: str = "default", sameAsGroup: str = "default", pre_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, post_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, final_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, patchCombine: Union[str, None] = None, redution: str = "mean", inverse_transform: bool = True) -> None:
        """
        Args:
            dataset_filename: ``"<name>:<format>"`` of the output file.
            group: Group under which outputs are written.
            sameAsGroup: ``"<group_src>:<group_dest>"`` naming the input group
                whose patch layout and transforms this output mirrors.
            pre_transforms / post_transforms / final_transforms: Transform loaders keyed by classpath.
            patchCombine: Optional classpath of the patch-combination class.
            redution: Reduction over augmentations, ``"mean"`` or ``"median"``.
            inverse_transform: If True, undo the input group's transforms.

        Raises:
            ValueError: If ``sameAsGroup`` is not of the form ``"<src>:<dest>"``.
        """
        super().__init__(dataset_filename, group, pre_transforms, post_transforms, final_transforms, patchCombine)
        groups = sameAsGroup.split(":")
        if len(groups) != 2:
            raise ValueError("sameAsGroup must have the form '<group_src>:<group_dest>', got '{}'".format(sameAsGroup))
        self.group_src, self.group_dest = groups
        self.redution = redution
        self.inverse_transform = inverse_transform

    def addLayer(self, index_dataset: int, index_augmentation: int, index_patch: int, layer: torch.Tensor, dataset: DatasetIter):
        """Accumulate the prediction of one patch of one (case, augmentation) pair.

        The first time a pair is seen, an ``Accumulator`` and one ``Attribute``
        per patch are created from the ``group_dest`` input. The patch then goes
        through ``pre_transforms`` and, when ``inverse_transform`` is set, the
        inverses of the input group's ``post_transforms`` in reverse order.
        """
        accumulators = self.output_layer_accumulator
        if index_dataset not in accumulators or index_augmentation not in accumulators[index_dataset]:
            input_dataset = dataset.getDatasetFromIndex(self.group_dest, index_dataset)
            if index_dataset not in accumulators:
                accumulators[index_dataset] = {}
                self.attributes[index_dataset] = {}
                self.names[index_dataset] = input_dataset.name
            self.attributes[index_dataset][index_augmentation] = {}

            patch_slices = input_dataset.patch.getPatch_slices(index_augmentation)
            accumulators[index_dataset][index_augmentation] = Accumulator(patch_slices, input_dataset.patch.patch_size, self.patchCombine, batch=False)

            # One Attribute per patch, each built from the input's first cached
            # attributes (presumably so per-patch transforms don't share state — confirm).
            for patch_number in range(len(patch_slices)):
                self.attributes[index_dataset][index_augmentation][patch_number] = Attribute(input_dataset.cache_attributes[0])

        name = self.names[index_dataset]
        patch_attributes = self.attributes[index_dataset][index_augmentation][index_patch]
        for transform in self.pre_transforms:
            layer = transform(name, layer, patch_attributes)

        if self.inverse_transform:
            for transform in reversed(dataset.groups_src[self.group_src][self.group_dest].post_transforms):
                layer = transform.inverse(name, layer, patch_attributes)

        accumulators[index_dataset][index_augmentation].addLayer(index_patch, layer)

    def _getOutput(self, index: int, index_augmentation: int, dataset: DatasetIter) -> torch.Tensor:
        """Assemble one augmentation of case ``index`` and map it back to input space.

        Augmentation 0 is the non-augmented input. For ``index_augmentation > 0``,
        the matching entry of ``dataset.dataAugmentationsList`` (each spanning
        ``nb`` consecutive indices) is located and its augmentations are
        inverted in reverse order. ``post_transforms`` are then applied, followed
        (if ``inverse_transform``) by the inverses of the input group's
        ``pre_transforms`` in reverse order.
        """
        layer = self.output_layer_accumulator[index][index_augmentation].assemble()
        name = self.names[index]
        if index_augmentation > 0:
            augmented_index = index_augmentation - 1
            offset = 0
            for dataAugmentations in dataset.dataAugmentationsList:
                if offset <= augmented_index < offset + dataAugmentations.nb:
                    for dataAugmentation in reversed(dataAugmentations.dataAugmentations):
                        layer = dataAugmentation.inverse(index, augmented_index - offset, layer)
                    break
                offset += dataAugmentations.nb

        attributes = self.attributes[index][index_augmentation][0]
        for transform in self.post_transforms:
            layer = transform(name, layer, attributes)

        if self.inverse_transform:
            for transform in reversed(dataset.groups_src[self.group_src][self.group_dest].pre_transforms):
                layer = transform.inverse(name, layer, attributes)
        return layer

    def getOutput(self, index: int, dataset: DatasetIter) -> torch.Tensor:
        """Assemble, reduce and finalize the prediction of case ``index``.

        Every augmentation is produced by ``_getOutput``; the results are
        stacked along a new leading dimension and reduced with ``redution``.
        ``final_transforms`` are then applied with the attributes of
        augmentation 0 / patch 0. The case's accumulators are released; its
        attributes are kept for ``write``.

        Raises:
            NameError: If ``redution`` is neither ``"mean"`` nor ``"median"``.
        """
        predictions = torch.stack([self._getOutput(index, index_augmentation, dataset) for index_augmentation in self.output_layer_accumulator[index].keys()], dim=0)
        name = self.names[index]
        self.output_layer_accumulator.pop(index)
        dtype = predictions.dtype

        if self.redution == "mean":
            result = torch.mean(predictions.float(), dim=0).to(dtype)
        elif self.redution == "median":
            # With ``dim`` given, torch.median returns a (values, indices) named
            # tuple; keep the values. For an even count this is the lower of
            # the two middle values, not their average.
            result = torch.median(predictions.float(), dim=0).values.to(dtype)
        else:
            raise NameError("Reduction method does not exist (mean, median)")
        for transform in self.final_transforms:
            result = transform(name, result, self.attributes[index][0][0])
        return result
|
|
161
|
+
|
|
162
|
+
class OutLayerDataset(OutDataset):
    """Output dataset that reassembles a network layer's patches on the layer's own grid.

    Patch slices are derived from the layer's shape, from the number of patches
    per dimension of the first input group, and from ``overlap``.

    NOTE(review): this class does not follow the bookkeeping layout that
    ``OutDataset.isDone``/``OutDataset.write`` and ``_Predictor.run`` use in
    this module. It stores one ``Accumulator`` and one ``Attribute`` per case
    instead of nested per-augmentation/per-patch dicts. Its ``addLayer`` takes
    no augmentation index, although ``_Predictor.run`` passes one. It also
    builds ``Accumulator`` without the ``patch_size`` argument that
    ``OutSameAsGroupDataset`` passes. Confirm the intended API before relying
    on it.
    """

    @config("OutDataset")
    def __init__(self, dataset_filename: str = "Dataset:h5", group: str = "default", overlap : Union[list[int], None] = None, pre_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, post_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, final_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, patchCombine: Union[str, None] = None) -> None:
        """
        Args:
            dataset_filename: ``"<name>:<format>"`` of the output file. The
                default uses the same ``:`` form as ``OutSameAsGroupDataset``.
            group: Group under which outputs are written.
            overlap: Overlap forwarded to ``get_patch_slices_from_nb_patch_per_dim``.
            pre_transforms / post_transforms / final_transforms: Transform loaders keyed by classpath.
            patchCombine: Optional classpath of the patch-combination class.
        """
        super().__init__(dataset_filename, group, pre_transforms, post_transforms, final_transforms, patchCombine)
        self.overlap = overlap

    def addLayer(self, index: int, index_patch: int, layer: torch.Tensor, dataset: DatasetIter):
        """Accumulate patch ``index_patch`` of case ``index`` after ``pre_transforms``."""
        if index not in self.output_layer_accumulator:
            group = list(dataset.groups.keys())[0]
            input_dataset = dataset.getDatasetFromIndex(group, index)
            # Only dims 2+ of ``layer`` are used as the spatial shape —
            # assumes a (N, C, *spatial) layout; TODO confirm.
            patch_slices = get_patch_slices_from_nb_patch_per_dim(list(layer.shape[2:]), input_dataset.patch.nb_patch_per_dim, self.overlap)
            self.output_layer_accumulator[index] = Accumulator(patch_slices, self.patchCombine, batch=False)
            self.attributes[index] = Attribute()
            self.names[index] = input_dataset.name

        # Same (name, tensor, attributes) call convention as the post_transforms below.
        for transform in self.pre_transforms:
            layer = transform(self.names[index], layer, self.attributes[index])
        self.output_layer_accumulator[index].addLayer(index_patch, layer)

    def getOutput(self, index: int, dataset: DatasetIter) -> torch.Tensor:
        """Assemble case ``index``, apply ``post_transforms`` and release its accumulator."""
        layer = self.output_layer_accumulator[index].assemble()
        name = self.names[index]
        for transform in self.post_transforms:
            layer = transform(name, layer, self.attributes[index])

        self.output_layer_accumulator.pop(index)
        return layer
|
|
190
|
+
|
|
191
|
+
class OutDatasetLoader():
    """Config-driven factory that builds the ``OutDataset`` subclass named by ``name_class``."""

    @config("OutDataset")
    def __init__(self, name_class: str = "OutSameAsGroupDataset") -> None:
        """
        Args:
            name_class: Name of an ``OutDataset`` subclass defined in this module.
        """
        self.name_class = name_class

    def getOutDataset(self, layer_name: str) -> OutDataset:
        """Instantiate ``name_class`` with its configuration read from ``Predictor.outsDataset.<layer_name>``.

        Raises:
            AttributeError: If this module defines no class named ``name_class``.
        """
        import sys

        # Look the class up in this very module instead of re-importing it by a
        # hard-coded dotted path; that path only matches one package layout and
        # could load a second copy of the module.
        current_module = sys.modules[__name__]
        return getattr(current_module, self.name_class)(config = None, DL_args = "Predictor.outsDataset.{}".format(layer_name))
|
|
199
|
+
|
|
200
|
+
class _Predictor():
    """Per-process inference loop that feeds network outputs to ``OutDataset`` writers.

    Used as a context manager. The TensorBoard ``SummaryWriter`` is created only
    when a network has a measure or ``data_log`` is non-empty, and it is closed
    on exit.
    """

    def __init__(self, world_size: int, global_rank: int, local_rank: int, predict_path: str, data_log: Union[list[str], None], outsDataset: dict[str, OutDataset], model: DDP, dataloader_prediction: DataLoader) -> None:
        """
        Args:
            world_size: Number of distributed processes.
            global_rank: Global rank of this process (rank 0 does the logging).
            local_rank: Local rank of this process.
            predict_path: Prediction directory; TensorBoard logs go to ``<predict_path>Metric/``.
            data_log: Optional ``"<name>/<DataLog member>/<count>"`` entries;
                ``:`` in ``<name>`` is replaced by ``.``.
            outsDataset: Output writers keyed by output layer name.
            model: Wrapped network (``DDP`` or ``CPU_Model``) exposing ``.module`` and ``.device``.
            dataloader_prediction: Loader whose ``dataset`` is a ``DatasetIter``.
        """
        self.world_size = world_size
        self.global_rank = global_rank
        self.local_rank = local_rank

        self.model = model
        self.dataloader_prediction = dataloader_prediction
        self.outsDataset = outsDataset

        self.it = 0

        self.device = self.model.device
        self.dataset: DatasetIter = self.dataloader_prediction.dataset
        patch_size, overlap = self.dataset.getPatchConfig()
        # One prediction per augmented view plus the non-augmented input (never fewer than 1).
        nb_data_augmentation = np.max([int(np.sum([data_augmentation.nb for data_augmentation in self.dataset.dataAugmentationsList])+1), 1])
        for outDataset in self.outsDataset.values():
            # Singleton dimensions are dropped from the patch size; when the
            # dataset is not patched (patch_size is None) there is nothing to filter.
            out_patch_size = [size for size in patch_size if size > 1] if patch_size is not None else None
            outDataset.setPatchConfig(out_patch_size, overlap, nb_data_augmentation)

        self.data_log : dict[str, tuple[DataLog, int]] = {}
        if data_log is not None:
            for data in data_log:
                fields = data.split("/")
                self.data_log[fields[0].replace(":", ".")] = (DataLog.__getitem__(fields[1]).value[0], int(fields[2]))

        has_measure = any(network.measure is not None for network in self.model.module.getNetworks().values())
        self.tb = SummaryWriter(log_dir = predict_path+"Metric/") if has_measure or len(self.data_log) else None

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        # Exceptions are not suppressed (implicit None return).
        if self.tb:
            self.tb.close()

    def getInput(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int, str, bool]]) -> dict[tuple[str, bool], torch.Tensor]:
        """Map each group to its tensor, keyed by ``(group, flag)``.

        The flag is element 5 of the group's entry, taken from the first sample of the batch.
        """
        return {(k, v[5][0].item()) : v[0] for k, v in data_dict.items()}

    @torch.no_grad()
    def run(self):
        """Run the model over the prediction loader and write every finished case."""
        self.model.eval()
        self.model.module.setState(NetState.PREDICTION)

        def desc() -> str:
            return "Prediction : {}".format(description(self.model))

        self.dataloader_prediction.dataset.load()
        with tqdm.tqdm(iterable = enumerate(self.dataloader_prediction), leave=False, desc = desc(), total=len(self.dataloader_prediction), disable=self.global_rank != 0 and "DL_API_CLUSTER" not in os.environ) as batch_iter:
            dist.barrier()
            for _, data_dict in batch_iter:
                model_input = self.getInput(data_dict)

                # The (case index, augmentation index, patch index) of every
                # sample in the batch, read from the first group's entry.
                first_group = next(iter(data_dict))
                first_entry = data_dict[first_group]
                positions = [(int(index), int(patch_augmentation), int(patch_index)) for index, patch_augmentation, patch_index in zip(first_entry[1], first_entry[2], first_entry[3])]

                for name, output in self.model(model_input, list(self.outsDataset.keys())):
                    # NOTE(review): logging runs once per output layer with the
                    # same ``self.it``; confirm this is intended.
                    self._predict_log(data_dict)
                    outDataset = self.outsDataset[name]
                    for i, (index, patch_augmentation, patch_index) in enumerate(positions):
                        outDataset.addLayer(index, patch_augmentation, patch_index, output[i].cpu(), self.dataset)
                        if outDataset.isDone(index):
                            outDataset.write(index, self.dataset.getDatasetFromIndex(first_group, index).name.split("/")[-1], outDataset.getOutput(index, self.dataset))

                batch_iter.set_description(desc())
                self.it += 1

    def _predict_log(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]):
        """Collect measures on every rank; on rank 0, write them and the selected data to TensorBoard."""
        measures = DistributedObject.getMeasure(self.world_size, self.global_rank, self.local_rank, {"" : self.model.module}, 1)

        if self.global_rank != 0:
            return

        # Entries not present in the batch are treated as network layer names.
        images_log = []
        for name, (log_function, count) in self.data_log.items():
            if name in data_dict:
                log_function(self.tb, "Prediction/{}".format(name), data_dict[name][0][:count].detach().cpu().numpy(), self.it)
            else:
                images_log.append(name.replace(":", "."))

        for name, network in self.model.module.getNetworks().items():
            if network.measure is not None:
                network_measures = measures["{}{}".format(name, "")]
                self.tb.add_scalars("Prediction/{}/Loss".format(name), {k : v[1] for k, v in network_measures[0].items()}, self.it)
                self.tb.add_scalars("Prediction/{}/Metric".format(name), {k : v[1] for k, v in network_measures[1].items()}, self.it)

        if images_log:
            # Move the inputs to the model's device rather than hard-coded GPU 0,
            # so this also works when running on the CPU.
            inputs = [v.to(self.device) for k, v in self.getInput(data_dict).items() if k[1]]
            for name, layer, _ in self.model.module.get_layers(inputs, images_log):
                self.data_log[name][0](self.tb, "Prediction/{}".format(name), layer[:self.data_log[name][1]].detach().cpu().numpy(), self.it)
|
|
275
|
+
|
|
276
|
+
class Predictor(DistributedObject):
    """Configuration-driven inference entry point.

    Loads the trained network, prepares one ``OutDataset`` per configured output
    layer under ``PREDICTIONS_DIRECTORY()<name>/`` and runs ``_Predictor`` in
    every distributed process.
    """

    @config("Predictor")
    def __init__(self,
                    model: ModelLoader = ModelLoader(),
                    dataset: DataPrediction = DataPrediction(),
                    train_name: str = "name",
                    manual_seed : Union[int, None] = None,
                    gpu_checkpoints: Union[list[str], None] = None,
                    outsDataset: Union[dict[str, OutDatasetLoader], None] = {"default:Default" : OutDatasetLoader()},
                    images_log: list[str] = []) -> None:
        """
        Args:
            model: Loader providing the network (built with ``train=False``).
            dataset: Prediction data configuration.
            train_name: Run name given to ``DistributedObject``. ``self.name`` is
                used to locate checkpoints and the prediction directory
                (presumably derived from this value — confirm).
            manual_seed: Stored as ``self.manual_seed``.
            gpu_checkpoints: Forwarded to ``_compute_channels_trace``; each entry
                adds one device to the per-process device count ``self.size``.
            outsDataset: Output loaders keyed by layer name; ``:`` becomes ``.``
                in the keys of ``self.outsDataset``.
            images_log: ``data_log`` entries forwarded to ``_Predictor``.
        """
        # NOTE(review): presumably any value other than "Done" marks the config
        # generation pass, during which nothing else should run — confirm.
        if os.environ["DEEP_LEANING_API_CONFIG_MODE"] != "Done":
            exit(0)
        super().__init__(train_name)
        self.manual_seed = manual_seed
        self.dataset = dataset

        self.model = model.getModel(train=False)
        self.it = 0
        self.outsDatasetLoader = outsDataset if outsDataset else {}
        self.outsDataset = {name.replace(":", ".") : value.getOutDataset(name) for name, value in self.outsDatasetLoader.items()}

        # Keep the bare filenames for the overwrite check in ``setup``; the
        # writers themselves get the prediction directory as prefix.
        self.datasets_filename = []
        self.predict_path = PREDICTIONS_DIRECTORY()+self.name+"/"
        self.images_log = images_log
        for outDataset in self.outsDataset.values():
            self.datasets_filename.append(outDataset.filename)
            outDataset.filename = "{}{}".format(self.predict_path, outDataset.filename)

        self.gpu_checkpoints = gpu_checkpoints

    def _load(self) -> dict[str, dict[str, torch.Tensor]]:
        """Load the state dict to predict with.

        Sources, in order:
        * ``MODEL()`` of the form ``"https://<url>:<key>"`` — downloaded (hash
          checked) and returned as ``{key: state_dict}``;
        * a non-empty ``MODEL()`` — loaded as a local file path;
        * ``self.name`` ending in ``.pt`` — that file in its run's ``StateDict/``;
        * otherwise the lexicographically last file in
          ``MODELS_DIRECTORY()<name>/StateDict/``.

        Raises:
            Exception: If the model cannot be found or downloaded.
        """
        model = MODEL()
        if model.startswith("https://"):
            # Split on the LAST ":" so that the scheme separator (and any port)
            # stay inside the URL.
            url, _, key = model.rpartition(":")
            if not key or "/" in key:
                raise Exception("Model : {} does not exist ! Expected the form 'https://<url>:<key>'".format(model))
            try:
                return {key: torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)}
            except Exception as error:
                raise Exception("Model : {} does not exist !".format(model)) from error

        if model != "":
            path = ""
            name = model
        elif self.name.endswith(".pt"):
            path = MODELS_DIRECTORY()+"/".join(self.name.split("/")[:-1])+"/StateDict/"
            name = self.name.split("/")[-1]
        else:
            path = MODELS_DIRECTORY()+self.name+"/StateDict/"
            checkpoints = sorted(os.listdir(path)) if os.path.isdir(path) else []
            if not checkpoints:
                raise Exception("Model : {} does not exist !".format(path))
            name = checkpoints[-1]

        if not os.path.exists(path+name):
            raise Exception("Model : {} does not exist !".format(path+name))
        # NOTE: weights_only=False unpickles arbitrary objects; only load
        # checkpoints that come from a trusted source.
        return torch.load(path+name, weights_only=False)

    def setup(self, world_size: int):
        """Prepare output directories, the model, the dataloader and the output datasets.

        Asks before overwriting an existing prediction unless the
        ``DL_API_OVERWRITE`` environment variable is ``"True"``. Exits when
        there is nothing to produce (no output dataset and no measure).
        """
        for dataset_filename in self.datasets_filename:
            path = self.predict_path + dataset_filename
            if os.path.exists(path):
                if os.environ.get("DL_API_OVERWRITE") != "True":
                    accept = builtins.input("The prediction {} already exists ! Do you want to overwrite it (yes,no) : ".format(path))
                    if accept != "yes":
                        # NOTE(review): returning leaves ``self.size`` and
                        # ``self.dataloader`` unset; confirm the caller stops here.
                        return
            else:
                os.makedirs(path)

        # The prediction directory may not exist yet when no output dataset is
        # configured (measure-only prediction).
        os.makedirs(self.predict_path, exist_ok=True)
        shutil.copyfile(CONFIG_FILE(), self.predict_path+"Prediction.yml")

        self.model.init(autocast=False, state = State.PREDICTION)
        self.model.init_outputsGroup()
        self.model._compute_channels_trace(self.model, self.model.in_channels, None, self.gpu_checkpoints)
        self.model.load(self._load(), init=False)

        has_measure = any(network.measure is not None for network in self.model.getNetworks().values())
        if not self.outsDataset and not has_measure:
            exit(0)

        # Devices used per process: one plus one per gpu checkpoint.
        self.size = (len(self.gpu_checkpoints)+1 if self.gpu_checkpoints else 1)
        self.dataloader = self.dataset.getData(world_size//self.size)
        for name, outDataset in self.outsDataset.items():
            outDataset.load(name.replace(".", ":"), list(self.dataset.datasets.values()))

    def run_process(self, world_size: int, global_rank: int, local_rank: int, dataloaders: list[DataLoader]):
        """Place the model on this process's first device, wrap it and run ``_Predictor``."""
        model = Network.to(self.model, local_rank*self.size)
        model = DDP(model, static_graph=True) if torch.cuda.is_available() else CPU_Model(model)
        with _Predictor(world_size, global_rank, local_rank, self.predict_path, self.images_log, self.outsDataset, model, *dataloaders) as p:
            p.run()
|
|
365
|
+
|
|
366
|
+
|