konfai 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of konfai might be problematic. Click here for more details.

konfai/predictor.py ADDED
@@ -0,0 +1,366 @@
1
+ from abc import ABC, abstractmethod
2
+ import builtins
3
+ import importlib
4
+ import shutil
5
+ import torch
6
+ import tqdm
7
+ import os
8
+
9
+ from KonfAI.konfai import MODELS_DIRECTORY, PREDICTIONS_DIRECTORY, CONFIG_FILE, MODEL, DEEP_LEARNING_API_ROOT
10
+ from KonfAI.konfai.utils.config import config
11
+ from KonfAI.konfai.utils.utils import State, get_patch_slices_from_nb_patch_per_dim, NeedDevice, _getModule, DistributedObject, DataLog, description
12
+ from KonfAI.konfai.utils.dataset import Dataset, Attribute
13
+ from KonfAI.konfai.data.dataset import DataPrediction, DatasetIter
14
+ from KonfAI.konfai.data.HDF5 import Accumulator, PathCombine
15
+ from KonfAI.konfai.network.network import ModelLoader, Network, NetState, CPU_Model
16
+ from KonfAI.konfai.data.transform import Transform, TransformLoader
17
+
18
+ from torch.utils.tensorboard.writer import SummaryWriter
19
+ from typing import Union
20
+ import numpy as np
21
+ import torch.distributed as dist
22
+ from torch.nn.parallel import DistributedDataParallel as DDP
23
+ from torch.utils.data import DataLoader
24
+ import importlib
25
+
26
class OutDataset(Dataset, NeedDevice, ABC):
    """Abstract base for an output dataset that collects predicted layers.

    The constructor stores three dictionaries of transform loaders
    (``pre``, ``post`` and ``final``) and an optional patch-combine class
    path; ``load`` turns them into instances.  Subclasses implement
    ``addLayer`` and ``getOutput``.
    """

    def __init__(self, filename: str, group: str, pre_transforms : dict[str, TransformLoader], post_transforms : dict[str, TransformLoader], final_transforms : dict[str, TransformLoader], patchCombine: Union[str, None]) -> None:
        """Store the configuration; nothing is instantiated here.

        ``filename`` must have the form ``"<path>:<format>"``: it is split on
        ``":"`` into exactly two parts which are forwarded to the parent
        constructor.
        """
        path, file_format = filename.split(":")
        super().__init__(path, file_format)
        self.group = group

        # Raw configuration (loaders / class path), resolved later by load().
        self._pre_transforms = pre_transforms
        self._post_transforms = post_transforms
        self._final_transforms = final_transforms
        self._patchCombine = patchCombine

        # Instantiated counterparts, filled by load().
        self.pre_transforms : list[Transform] = []
        self.post_transforms : list[Transform] = []
        self.final_transforms : list[Transform] = []
        self.patchCombine: PathCombine = None

        # Per-index bookkeeping used while predictions are accumulated.
        self.output_layer_accumulator: dict[int, dict[int, Accumulator]] = {}
        self.attributes: dict[int, dict[int, dict[int, Attribute]]] = {}
        self.names: dict[int, str] = {}
        self.nb_data_augmentation = 0

    def load(self, name_layer: str, datasets: list[Dataset]):
        """Instantiate the configured transforms and the patch combiner.

        Each created transform receives ``datasets`` through ``setDatasets``
        and is appended to the matching public list.
        """
        for stage in ("pre_transforms", "post_transforms", "final_transforms"):
            loaders = getattr(self, "_{}".format(stage))
            instances = getattr(self, stage)
            if loaders is None:
                continue
            for classpath, loader in loaders.items():
                instance = loader.getTransform(classpath, DL_args = "{}.outsDataset.{}.OutDataset.{}".format(DEEP_LEARNING_API_ROOT(), name_layer, stage))
                instance.setDatasets(datasets)
                instances.append(instance)

        if self._patchCombine is not None:
            module_name, class_name = _getModule(self._patchCombine, "HDF5")
            combine_cls = getattr(importlib.import_module(module_name), class_name)
            self.patchCombine = combine_cls(config = None, DL_args = "{}.outsDataset.{}.OutDataset".format(DEEP_LEARNING_API_ROOT(), name_layer))

    def setPatchConfig(self, patchSize: Union[list[int], None], overlap: Union[int, None], nb_data_augmentation: int) -> None:
        """Forward the patch configuration to the combiner and record the augmentation count.

        When either ``patchSize`` or ``overlap`` is ``None`` the combiner is
        discarded.
        """
        patch_config_given = patchSize is not None and overlap is not None
        if not patch_config_given:
            self.patchCombine = None
        elif self.patchCombine is not None:
            self.patchCombine.setPatchConfig(patchSize, overlap)
        self.nb_data_augmentation = nb_data_augmentation

    def setDevice(self, device: torch.device):
        """Set ``device`` on this object and on every instantiated transform."""
        super().setDevice(device)
        for stage in ("pre_transforms", "post_transforms", "final_transforms"):
            transforms = getattr(self, stage)
            if transforms is None:
                continue
            for transform in transforms:
                transform.setDevice(device)

    # NOTE(review): _Predictor.run calls addLayer with five positional
    # arguments (dataset index, augmentation index, patch index, tensor,
    # dataset); this abstract signature only declares four - confirm which
    # one is the intended contract.
    @abstractmethod
    def addLayer(self, index: int, index_patch: int, layer: torch.Tensor, dataset: DatasetIter):
        pass

    def isDone(self, index: int) -> bool:
        """Return True once every expected accumulator for ``index`` exists and is full."""
        accumulators = self.output_layer_accumulator[index]
        if len(accumulators) != self.nb_data_augmentation:
            return False
        return all(accumulator.isFull() for accumulator in accumulators.values())

    @abstractmethod
    def getOutput(self, index: int, dataset: DatasetIter) -> torch.Tensor:
        pass

    def write(self, index: int, name: str, layer: torch.Tensor):
        """Write ``layer`` under ``self.group`` and drop the attributes kept for ``index``."""
        array = layer.numpy()
        attribute = self.attributes[index][0][0]
        super().write(self.group, name, array, attribute)
        self.attributes.pop(index)
91
+
92
class OutSameAsGroupDataset(OutDataset):
    """Output dataset whose patch layout is copied from an existing input group.

    Predicted patches are accumulated per dataset index and per augmentation
    index, reassembled, optionally mapped back through the inverse of the
    source group's transforms, and finally reduced over the augmentations.
    """

    @config("OutDataset")
    def __init__(self, dataset_filename: str = "Dataset:h5", group: str = "default", sameAsGroup: str = "default", pre_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, post_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, final_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, patchCombine: Union[str, None] = None, redution: str = "mean", inverse_transform: bool = True) -> None:
        """Store the configuration.

        ``sameAsGroup`` is split on ``":"`` into ``group_src`` and
        ``group_dest``.  ``redution`` selects the reduction applied over
        augmentations in ``getOutput`` (``"mean"`` or ``"median"``).
        """
        super().__init__(dataset_filename, group, pre_transforms, post_transforms, final_transforms, patchCombine)
        # NOTE(review): the declared default "default" contains no ":" and
        # would fail this two-way unpacking unless the config decorator
        # supplies another value - confirm.
        self.group_src, self.group_dest = sameAsGroup.split(":")
        self.redution = redution
        self.inverse_transform = inverse_transform

    def addLayer(self, index_dataset: int, index_augmentation: int, index_patch: int, layer: torch.Tensor, dataset: DatasetIter):
        """Transform one predicted patch and add it to the matching accumulator.

        The accumulator and the per-patch attributes for
        ``(index_dataset, index_augmentation)`` are created on first use.
        """
        if index_dataset not in self.output_layer_accumulator or index_augmentation not in self.output_layer_accumulator[index_dataset]:
            input_dataset = dataset.getDatasetFromIndex(self.group_dest, index_dataset)
            if index_dataset not in self.output_layer_accumulator:
                self.output_layer_accumulator[index_dataset] = {}
                self.attributes[index_dataset] = {}
                self.names[index_dataset] = input_dataset.name
            self.attributes[index_dataset][index_augmentation] = {}

            patch_slices = input_dataset.patch.getPatch_slices(index_augmentation)
            self.output_layer_accumulator[index_dataset][index_augmentation] = Accumulator(patch_slices, input_dataset.patch.patch_size, self.patchCombine, batch=False)

            # One attribute object per patch, each built from the first cached attribute.
            for i in range(len(patch_slices)):
                self.attributes[index_dataset][index_augmentation][i] = Attribute(input_dataset.cache_attributes[0])

        name = self.names[index_dataset]
        attribute = self.attributes[index_dataset][index_augmentation][index_patch]

        for transform in self.pre_transforms:
            layer = transform(name, layer, attribute)

        if self.inverse_transform:
            # Undo the source group's post-transforms, last applied first.
            for transform in reversed(dataset.groups_src[self.group_src][self.group_dest].post_transforms):
                layer = transform.inverse(name, layer, attribute)

        self.output_layer_accumulator[index_dataset][index_augmentation].addLayer(index_patch, layer)

    def _getOutput(self, index: int, index_augmentation: int, dataset: DatasetIter) -> torch.Tensor:
        """Assemble the output for one augmentation and map it back through the inverses."""
        layer = self.output_layer_accumulator[index][index_augmentation].assemble()
        name = self.names[index]
        if index_augmentation > 0:
            # Augmentation index 0 is handled as "no augmentation"; indices
            # >= 1 are located inside the list of augmentation groups by
            # walking the cumulative ``nb`` counts.
            offset = 0
            augmentation_rank = index_augmentation-1
            for dataAugmentations in dataset.dataAugmentationsList:
                if offset <= augmentation_rank < offset+dataAugmentations.nb:
                    for dataAugmentation in reversed(dataAugmentations.dataAugmentations):
                        layer = dataAugmentation.inverse(index, augmentation_rank-offset, layer)
                    break
                offset += dataAugmentations.nb

        attribute = self.attributes[index][index_augmentation][0]
        for transform in self.post_transforms:
            layer = transform(name, layer, attribute)

        if self.inverse_transform:
            # Undo the source group's pre-transforms, last applied first.
            for transform in reversed(dataset.groups_src[self.group_src][self.group_dest].pre_transforms):
                layer = transform.inverse(name, layer, attribute)
        return layer

    def getOutput(self, index: int, dataset: DatasetIter) -> torch.Tensor:
        """Reduce the outputs of all augmentations of ``index`` into one tensor.

        The accumulators for ``index`` are removed.  The reduction runs in
        float and the result is cast back to the dtype of the stacked outputs.

        Raises:
            NameError: if ``redution`` is neither ``"mean"`` nor ``"median"``.
        """
        result = torch.stack([self._getOutput(index, index_augmentation, dataset) for index_augmentation in self.output_layer_accumulator[index].keys()], dim=0)
        name = self.names[index]
        self.output_layer_accumulator.pop(index)
        dtype = result.dtype

        if self.redution == "mean":
            result = torch.mean(result.float(), dim=0).to(dtype)
        elif self.redution == "median":
            # torch.median with ``dim`` returns a (values, indices) pair; the
            # cast has to be applied to the values, not to the pair.
            result = torch.median(result.float(), dim=0).values.to(dtype)
        else:
            raise NameError("Reduction method does not exist (mean, median)")
        for transform in self.final_transforms:
            result = transform(name, result, self.attributes[index][0][0])
        return result
161
+
162
class OutLayerDataset(OutDataset):
    """Output dataset that accumulates the patches of one network layer.

    Patch slices are derived from the spatial shape of the first layer
    received for an index and from the patch grid of the first group of the
    prediction dataset.
    """

    # NOTE(review): the default "Dataset.h5" contains no ":" although the
    # parent constructor splits the filename on ":" into two parts - confirm
    # the intended default.
    @config("OutDataset")
    def __init__(self, dataset_filename: str = "Dataset.h5", group: str = "default", overlap : Union[list[int], None] = None, pre_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, post_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, final_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, patchCombine: Union[str, None] = None) -> None:
        """Store the configuration, including the ``overlap`` used to build patch slices."""
        super().__init__(dataset_filename, group, pre_transforms, post_transforms, final_transforms, patchCombine)
        self.overlap = overlap

    # NOTE(review): _Predictor.run passes five positional arguments to
    # addLayer while this override accepts four; the Accumulator call below
    # also has one argument fewer than the one in OutSameAsGroupDataset, and
    # the pre-transforms are called without the name argument used elsewhere
    # in this file.  Confirm against Accumulator / Transform before relying
    # on this class.
    def addLayer(self, index: int, index_patch: int, layer: torch.Tensor, dataset: DatasetIter):
        """Apply the pre-transforms to ``layer`` and add it as patch ``index_patch`` of ``index``."""
        if index not in self.output_layer_accumulator:
            first_group = list(dataset.groups.keys())[0]
            spatial_shape = list(layer.shape[2:])
            nb_patch_per_dim = dataset.getDatasetFromIndex(first_group, index).patch.nb_patch_per_dim
            slices = get_patch_slices_from_nb_patch_per_dim(spatial_shape, nb_patch_per_dim, self.overlap)
            self.output_layer_accumulator[index] = Accumulator(slices, self.patchCombine, batch=False)
            self.attributes[index] = Attribute()
            self.names[index] = dataset.getDatasetFromIndex(first_group, index).name

        for transform in self.pre_transforms:
            layer = transform(layer, self.attributes[index])
        self.output_layer_accumulator[index].addLayer(index_patch, layer)

    def getOutput(self, index: int, dataset: DatasetIter) -> torch.Tensor:
        """Assemble the accumulated patches, apply the post-transforms and drop the accumulator."""
        assembled = self.output_layer_accumulator[index].assemble()
        sample_name = self.names[index]
        for transform in self.post_transforms:
            assembled = transform(sample_name, assembled, self.attributes[index])

        self.output_layer_accumulator.pop(index)
        return assembled
190
+
191
class OutDatasetLoader():
    """Configurable factory that builds an output dataset from a class name."""

    @config("OutDataset")
    def __init__(self, name_class: str = "OutSameAsGroupDataset") -> None:
        """Remember the name of the class to instantiate."""
        self.name_class = name_class

    def getOutDataset(self, layer_name: str) -> OutDataset:
        """Look up ``name_class`` in ``KonfAI.konfai.predictor`` and instantiate it for ``layer_name``."""
        predictor_module = importlib.import_module("KonfAI.konfai.predictor")
        out_dataset_cls = getattr(predictor_module, self.name_class)
        return out_dataset_cls(config = None, DL_args = "Predictor.outsDataset.{}".format(layer_name))
199
+
200
class _Predictor():
    """Context manager that runs the prediction loop for one process.

    It feeds every batch of the prediction dataloader to the model, hands
    each output patch to the matching output dataset and, on global rank 0,
    logs measures and images to TensorBoard when a writer was created.
    """

    def __init__(self, world_size: int, global_rank: int, local_rank: int, predict_path: str, data_log: Union[list[str], None], outsDataset: dict[str, OutDataset], model: DDP, dataloader_prediction: DataLoader) -> None:
        """Store the collaborators, configure the output datasets and open the writer.

        Each ``data_log`` entry is split on ``"/"`` into three fields: a
        name (with ``":"`` replaced by ``"."``), a ``DataLog`` member name
        and an integer.
        """
        self.world_size = world_size
        self.global_rank = global_rank
        self.local_rank = local_rank

        self.model = model
        self.dataloader_prediction = dataloader_prediction
        self.outsDataset = outsDataset

        self.it = 0

        self.device = self.model.device
        self.dataset: DatasetIter = self.dataloader_prediction.dataset
        patch_size, overlap = self.dataset.getPatchConfig()
        for outDataset in self.outsDataset.values():
            # Dimensions of size 1 are dropped from the patch size; the
            # augmentation count is the sum of every ``nb`` plus one, floored at 1.
            kept_sizes = [size for size in patch_size if size > 1]
            nb_augmentation = np.max([int(np.sum([data_augmentation.nb for data_augmentation in self.dataset.dataAugmentationsList])+1), 1])
            outDataset.setPatchConfig(kept_sizes, overlap, nb_augmentation)

        self.data_log : dict[str, tuple[DataLog, int]] = {}
        if data_log is not None:
            for data in data_log:
                fields = data.split("/")
                log_function = DataLog.__getitem__(fields[1]).value[0]
                self.data_log[fields[0].replace(":", ".")] = (log_function, int(fields[2]))

        # A writer is only needed when a network has a measure or something is logged.
        has_measure = any(network.measure is not None for network in self.model.module.getNetworks().values())
        if has_measure or len(self.data_log):
            self.tb = SummaryWriter(log_dir = predict_path+"Metric/")
        else:
            self.tb = None

    def __enter__(self):
        return self

    def __exit__(self, type, value, traceback):
        """Close the TensorBoard writer if one was opened."""
        if self.tb:
            self.tb.close()

    def getInput(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int, str, bool]]) -> dict[tuple[str, bool], torch.Tensor]:
        """Map every ``data_dict`` entry to ``(key, flag) -> tensor``.

        The flag is ``value[5][0].item()`` and the tensor is ``value[0]``.
        """
        inputs = {}
        for key, value in data_dict.items():
            inputs[(key, value[5][0].item())] = value[0]
        return inputs

    @torch.no_grad()
    def run(self):
        """Predict every batch and write each output as soon as it is complete."""
        self.model.eval()
        self.model.module.setState(NetState.PREDICTION)

        def desc():
            return "Prediction : {}".format(description(self.model))

        self.dataloader_prediction.dataset.load()
        with tqdm.tqdm(iterable = enumerate(self.dataloader_prediction), leave=False, desc = desc(), total=len(self.dataloader_prediction), disable=self.global_rank != 0 and "DL_API_CLUSTER" not in os.environ) as batch_iter:
            dist.barrier()
            for _, data_dict in batch_iter:
                model_input = self.getInput(data_dict)
                for name, output in self.model(model_input, list(self.outsDataset.keys())):
                    self._predict_log(data_dict)
                    outDataset = self.outsDataset[name]

                    # Positions 1, 2 and 3 of the first entry hold, per batch
                    # element, the dataset index, the augmentation index and
                    # the patch index.
                    first_entry = list(data_dict.values())[0]
                    locations = [(int(index), int(patch_augmentation), int(patch_index)) for index, patch_augmentation, patch_index in zip(first_entry[1], first_entry[2], first_entry[3])]
                    for i, (index, patch_augmentation, patch_index) in enumerate(locations):
                        outDataset.addLayer(index, patch_augmentation, patch_index, output[i].cpu(), self.dataset)
                        if outDataset.isDone(index):
                            first_group = list(data_dict.keys())[0]
                            sample_name = self.dataset.getDatasetFromIndex(first_group, index).name.split("/")[-1]
                            prediction = outDataset.getOutput(index, self.dataset)
                            outDataset.write(index, sample_name, prediction)

                batch_iter.set_description(desc())
                self.it += 1

    def _predict_log(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]):
        """Gather the measures and, on global rank 0, write them and the requested data to TensorBoard."""
        measures = DistributedObject.getMeasure(self.world_size, self.global_rank, self.local_rank, {"" : self.model.module}, 1)

        if self.global_rank != 0:
            return

        # Names that are not inputs are collected and fetched from the model below.
        images_log = []
        if len(self.data_log):
            for name, (log_function, nb_images) in self.data_log.items():
                if name in data_dict:
                    log_function(self.tb, "Prediction/{}".format(name), data_dict[name][0][:nb_images].detach().cpu().numpy(), self.it)
                else:
                    images_log.append(name.replace(":", "."))

        for name, network in self.model.module.getNetworks().items():
            if network.measure is None:
                continue
            measure_key = "{}{}".format(name, "")
            self.tb.add_scalars("Prediction/{}/Loss".format(name), {k : v[1] for k, v in measures[measure_key][0].items()}, self.it)
            self.tb.add_scalars("Prediction/{}/Metric".format(name), {k : v[1] for k, v in measures[measure_key][1].items()}, self.it)

        if len(images_log):
            flagged_inputs = [v.to(0) for k, v in self.getInput(data_dict).items() if k[1]]
            for name, layer, _ in self.model.module.get_layers(flagged_inputs, images_log):
                log_function, nb_images = self.data_log[name][0], self.data_log[name][1]
                log_function(self.tb, "Prediction/{}".format(name), layer[:nb_images].detach().cpu().numpy(), self.it)
275
+
276
class Predictor(DistributedObject):
    """Configurable entry point that prepares and launches a prediction.

    ``setup`` prepares the output directories, the model and the data;
    ``run_process`` wraps the model for the current process and runs the
    prediction loop through ``_Predictor``.
    """

    @config("Predictor")
    def __init__(self,
                    model: ModelLoader = ModelLoader(),
                    dataset: DataPrediction = DataPrediction(),
                    train_name: str = "name",
                    manual_seed : Union[int, None] = None,
                    gpu_checkpoints: Union[list[str], None] = None,
                    outsDataset: Union[dict[str, OutDatasetLoader], None] = {"default:Default" : OutDatasetLoader()},
                    images_log: list[str] = []) -> None:
        """Build the model and the output datasets from the configuration.

        The process exits with status 0 when the environment variable
        ``DEEP_LEANING_API_CONFIG_MODE`` is not ``"Done"``.
        """
        if os.environ["DEEP_LEANING_API_CONFIG_MODE"] != "Done":
            exit(0)
        super().__init__(train_name)
        self.manual_seed = manual_seed
        self.dataset = dataset

        self.model = model.getModel(train=False)
        self.it = 0
        self.outsDatasetLoader = outsDataset if outsDataset else {}
        # Keys use "." in place of ":"; the loaders receive the original name.
        self.outsDataset = {name.replace(":", ".") : value.getOutDataset(name) for name, value in self.outsDatasetLoader.items()}

        self.datasets_filename = []
        self.predict_path = PREDICTIONS_DIRECTORY()+self.name+"/"
        self.images_log = images_log
        for outDataset in self.outsDataset.values():
            # Remember the bare filename, then prefix it with the prediction directory.
            self.datasets_filename.append(outDataset.filename)
            outDataset.filename = "{}{}".format(self.predict_path, outDataset.filename)

        self.gpu_checkpoints = gpu_checkpoints

    def _load(self) -> dict[str, dict[str, torch.Tensor]]:
        """Load the state dict selected by ``MODEL()`` or by ``self.name``.

        * ``MODEL()`` starting with ``"https://"`` is read as ``"<url>:<key>"``
          and downloaded; the result is returned under ``<key>``.
        * any other non-empty ``MODEL()`` is used as a file path.
        * otherwise the file is looked up in the ``StateDict`` directory of
          ``self.name``; when ``self.name`` does not end with ``".pt"`` the
          last entry of the sorted directory listing is used.

        Raises:
            Exception: if the download fails or the file does not exist.
        """
        model_ref = MODEL()
        if model_ref.startswith("https://"):
            # Split on the LAST ":" only: the URL itself contains ":" right
            # after the scheme, so a plain split would yield "https" as URL.
            url, key = model_ref.rsplit(":", 1)
            try:
                state_dict = {key: torch.hub.load_state_dict_from_url(url=url, map_location="cpu", check_hash=True)}
            except Exception as error:
                raise Exception("Model : {} does not exist !".format(model_ref)) from error
        else:
            if model_ref != "":
                path = ""
                name = model_ref
            else:
                if self.name.endswith(".pt"):
                    path = MODELS_DIRECTORY()+"/".join(self.name.split("/")[:-1])+"/StateDict/"
                    name = self.name.split("/")[-1]
                else:
                    path = MODELS_DIRECTORY()+self.name+"/StateDict/"
                    name = sorted(os.listdir(path))[-1]
            if os.path.exists(path+name):
                # SECURITY: weights_only=False unpickles arbitrary objects;
                # only load checkpoints that come from a trusted source.
                state_dict = torch.load(path+name, weights_only=False)
            else:
                raise Exception("Model : {} does not exist !".format(path+name))
        return state_dict

    def setup(self, world_size: int):
        """Prepare output directories, the model and the dataloaders.

        When an output path already exists and ``DL_API_OVERWRITE`` is not
        ``"True"``, the user is asked for confirmation; any answer other than
        ``"yes"`` makes this method return immediately.  The process exits
        with status 0 when there is neither an output dataset nor a network
        with a measure.
        """
        for dataset_filename in self.datasets_filename:
            path = self.predict_path +dataset_filename
            if os.path.exists(path):
                if os.environ["DL_API_OVERWRITE"] != "True":
                    accept = builtins.input("The prediction {} already exists ! Do you want to overwrite it (yes,no) : ".format(path))
                    if accept != "yes":
                        return

            if not os.path.exists(path):
                os.makedirs(path)

        # The loop above creates nothing when no output dataset is
        # configured, so make sure the copy target directory exists.
        os.makedirs(self.predict_path, exist_ok=True)
        shutil.copyfile(CONFIG_FILE(), self.predict_path+"Prediction.yml")

        self.model.init(autocast=False, state = State.PREDICTION)
        self.model.init_outputsGroup()
        self.model._compute_channels_trace(self.model, self.model.in_channels, None, self.gpu_checkpoints)
        self.model.load(self._load(), init=False)

        if len(list(self.outsDataset.keys())) == 0 and len([network for network in self.model.getNetworks().values() if network.measure is not None]) == 0:
            exit(0)

        # One more than the number of GPU checkpoints, or 1 when none are configured.
        self.size = (len(self.gpu_checkpoints)+1 if self.gpu_checkpoints else 1)
        self.dataloader = self.dataset.getData(world_size//self.size)
        for name, outDataset in self.outsDataset.items():
            outDataset.load(name.replace(".", ":"), list(self.dataset.datasets.values()))

    def run_process(self, world_size: int, global_rank: int, local_rank: int, dataloaders: list[DataLoader]):
        """Place and wrap the model for this process, then run the prediction loop.

        The model is wrapped in ``DistributedDataParallel`` when CUDA is
        available and in ``CPU_Model`` otherwise.
        """
        model = Network.to(self.model, local_rank*self.size)
        model = DDP(model, static_graph=True) if torch.cuda.is_available() else CPU_Model(model)
        with _Predictor(world_size, global_rank, local_rank, self.predict_path, self.images_log, self.outsDataset, model, *dataloaders) as p:
            p.run()
365
+
366
+