konfai 1.0.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of konfai might be problematic.

konfai/data/dataset.py ADDED
@@ -0,0 +1,470 @@
+ import math
+ import os
+ import random
+ import torch
+ from torch.utils import data
+ import tqdm
+ import numpy as np
+ from abc import ABC
+ from torch.utils.data import DataLoader, Sampler
+ from typing import Union, Iterator
+ from concurrent.futures import ThreadPoolExecutor, as_completed
+ import threading
+ from torch.cuda import device_count
+
+ from KonfAI.konfai import DL_API_STATE, DEEP_LEARNING_API_ROOT
+ from KonfAI.konfai.data.HDF5 import DatasetPatch, DatasetManager
+ from KonfAI.konfai.utils.config import config
+ from KonfAI.konfai.utils.utils import memoryInfo, cpuInfo, memoryForecast, getMemory, State
+ from KonfAI.konfai.utils.dataset import Dataset, Attribute
+ from KonfAI.konfai.data.transform import TransformLoader, Transform
+ from KonfAI.konfai.data.augmentation import DataAugmentationsList
+
+ class GroupTransform:
+
+     @config()
+     def __init__(self, pre_transforms : Union[dict[str, TransformLoader], list[Transform]] = {"default:Normalize:Standardize:Unsqueeze:TensorCast:ResampleIsotropic:ResampleResize": TransformLoader()},
+                  post_transforms : Union[dict[str, TransformLoader], list[Transform]] = {"default:Normalize:Standardize:Unsqueeze:TensorCast:ResampleIsotropic:ResampleResize": TransformLoader()},
+                  isInput: bool = True) -> None:
+         self._pre_transforms = pre_transforms
+         self._post_transforms = post_transforms
+         self.pre_transforms : list[Transform] = []
+         self.post_transforms : list[Transform] = []
+         self.isInput = isInput
+
+     def load(self, group_src : str, group_dest : str, datasets: list[Dataset]):
+         if self._pre_transforms is not None:
+             if isinstance(self._pre_transforms, dict):
+                 for classpath, transform in self._pre_transforms.items():
+                     transform = transform.getTransform(classpath, DL_args = "{}.Dataset.groups_src.{}.groups_dest.{}.pre_transforms".format(DEEP_LEARNING_API_ROOT(), group_src, group_dest))
+                     transform.setDatasets(datasets)
+                     self.pre_transforms.append(transform)
+             else:
+                 for transform in self._pre_transforms:
+                     transform.setDatasets(datasets)
+                     self.pre_transforms.append(transform)
+
+         if self._post_transforms is not None:
+             if isinstance(self._post_transforms, dict):
+                 for classpath, transform in self._post_transforms.items():
+                     transform = transform.getTransform(classpath, DL_args = "{}.Dataset.groups_src.{}.groups_dest.{}.post_transforms".format(DEEP_LEARNING_API_ROOT(), group_src, group_dest))
+                     transform.setDatasets(datasets)
+                     self.post_transforms.append(transform)
+             else:
+                 for transform in self._post_transforms:
+                     transform.setDatasets(datasets)
+                     self.post_transforms.append(transform)
+
+     def to(self, device: int):
+         for transform in self.pre_transforms:
+             transform.setDevice(device)
+         for transform in self.post_transforms:
+             transform.setDevice(device)
+
+ class Group(dict[str, GroupTransform]):
+
+     @config()
+     def __init__(self, groups_dest: dict[str, GroupTransform] = {"default": GroupTransform()}):
+         super().__init__(groups_dest)
+
+ class CustomSampler(Sampler[int]):
+
+     def __init__(self, size: int, shuffle: bool = False) -> None:
+         self.size = size
+         self.shuffle = shuffle
+
+     def __iter__(self) -> Iterator[int]:
+         return iter(torch.randperm(len(self)).tolist() if self.shuffle else list(range(len(self))))
+
+     def __len__(self) -> int:
+         return self.size
+
+ class DatasetIter(data.Dataset):
+
+     def __init__(self, rank: int, data : dict[str, list[DatasetManager]], map: dict[int, tuple[int, int, int]], groups_src : dict[str, Group], inlineAugmentations: bool, dataAugmentationsList : list[DataAugmentationsList], patch_size: Union[list[int], None], overlap: Union[int, None], buffer_size: int, use_cache = True) -> None:
+         self.rank = rank
+         self.data = data
+         self.map = map
+         self.patch_size = patch_size
+         self.overlap = overlap
+         self.groups_src = groups_src
+         self.dataAugmentationsList = dataAugmentationsList
+         self.use_cache = use_cache
+         self.nb_dataset = len(data[list(data.keys())[0]])
+         self.buffer_size = buffer_size
+         self._index_cache = list()
+         self.device = None
+         self.inlineAugmentations = inlineAugmentations
+
+     def getPatchConfig(self) -> tuple[list[int], int]:
+         return self.patch_size, self.overlap
+
+     def to(self, device: int):
+         for group_src in self.groups_src:
+             for group_dest in self.groups_src[group_src]:
+                 self.groups_src[group_src][group_dest].to(device)
+         self.device = device
+
+     def getDatasetFromIndex(self, group_dest: str, index: int) -> DatasetManager:
+         return self.data[group_dest][index]
+
+     def resetAugmentation(self):
+         if self.inlineAugmentations:
+             for index in range(self.nb_dataset):
+                 self._unloadData(index)
+                 for group_src in self.groups_src:
+                     for group_dest in self.groups_src[group_src]:
+                         self.data[group_dest][index].resetAugmentation()
+
+     def load(self):
+         if self.use_cache:
+             memory_init = getMemory()
+
+             indexs = [index for index in range(self.nb_dataset) if index not in self._index_cache]
+             if len(indexs) > 0:
+                 memory_lock = threading.Lock()
+                 pbar = tqdm.tqdm(
+                     total=len(indexs),
+                     desc="Caching : init | {} | {}".format(memoryForecast(memory_init, 0, self.nb_dataset), cpuInfo()),
+                     leave=False,
+                     disable=self.rank != 0 and "DL_API_CLUSTER" not in os.environ
+                 )
+
+                 def process(index):
+                     self._loadData(index)
+                     with memory_lock:
+                         pbar.set_description("Caching : {} | {} | {}".format(memoryInfo(), memoryForecast(memory_init, index, self.nb_dataset), cpuInfo()))
+                         pbar.update(1)
+                 with ThreadPoolExecutor(max_workers=os.cpu_count()//device_count()) as executor:
+                     futures = [executor.submit(process, index) for index in indexs]
+                     for _ in as_completed(futures):
+                         pass
+
+                 pbar.close()
+
+     def _loadData(self, index):
+         if index not in self._index_cache:
+             self._index_cache.append(index)
+             for group_src in self.groups_src:
+                 for group_dest in self.groups_src[group_src]:
+                     self.loadData(group_src, group_dest, index)
+
+     def loadData(self, group_src: str, group_dest : str, index : int) -> None:
+         self.data[group_dest][index].load(self.groups_src[group_src][group_dest].pre_transforms, self.dataAugmentationsList, self.device)
+
+     def _unloadData(self, index : int) -> None:
+         if index in self._index_cache:
+             self._index_cache.remove(index)
+             for group_src in self.groups_src:
+                 for group_dest in self.groups_src[group_src]:
+                     self.unloadData(group_dest, index)
+
+     def unloadData(self, group_dest : str, index : int) -> None:
+         return self.data[group_dest][index].unload()
+
+     def __len__(self) -> int:
+         return len(self.map)
+
+     def __getitem__(self, index : int) -> dict[str, tuple[torch.Tensor, int, int, int, str, bool]]:
+         data = {}
+         x, a, p = self.map[index]
+         if x not in self._index_cache:
+             if len(self._index_cache) >= self.buffer_size and not self.use_cache:
+                 self._unloadData(self._index_cache[0])
+             self._loadData(x)
+
+         for group_src in self.groups_src:
+             for group_dest in self.groups_src[group_src]:
+                 dataset = self.data[group_dest][x]
+                 data["{}".format(group_dest)] = (dataset.getData(p, a, self.groups_src[group_src][group_dest].post_transforms, self.groups_src[group_src][group_dest].isInput), x, a, p, dataset.name, self.groups_src[group_src][group_dest].isInput)
+         return data
+
+ class Subset():
+
+     def __init__(self, subset: Union[str, list[int], list[str], None] = None, shuffle: bool = True, filter: str = None) -> None:
+         self.subset = subset
+         self.shuffle = shuffle
+         self.filter: list[tuple[str]] = [tuple(a.split(":")) for a in filter.split(";")] if filter is not None else None
+
+     def __call__(self, names: list[list[str]], infos: list[dict[str, tuple[np.ndarray, Attribute]]]) -> set[str]:
+         inter_name = set(names[0])
+         for n in names[1:]:
+             inter_name = inter_name.intersection(set(n))
+         names = sorted(list(inter_name))
+
+         names_filtred = []
+         if self.filter is not None:
+             for name in names:
+                 for info in infos:
+                     if all([info[name][1].isInfo(f[0], f[1]) for f in self.filter]):
+                         names_filtred.append(name)
+                         continue
+         else:
+             names_filtred = names
+         size = len(names_filtred)
+
+         index = []
+         if self.subset is None:
+             index = list(range(0, size))
+         elif isinstance(self.subset, str):
+             if ":" in self.subset:
+                 r = np.clip(np.asarray([int(self.subset.split(":")[0]), int(self.subset.split(":")[1])]), 0, size)
+                 index = list(range(r[0], r[1]))
+             elif os.path.exists(self.subset):
+                 validation_names = []
+                 with open(self.subset, "r") as f:
+                     for name in f:
+                         validation_names.append(name.strip())
+                 index = []
+                 for i, name in enumerate(names_filtred):
+                     if name in validation_names:
+                         index.append(i)
+
+         elif isinstance(self.subset, list):
+             if len(self.subset) > 0:
+                 if isinstance(self.subset[0], int):
+                     if len(self.subset) == 1:
+                         index = list(range(self.subset[0], min(size, self.subset[0]+1)))
+                     else:
+                         index = self.subset
+                 if isinstance(self.subset[0], str):
+                     index = []
+                     for i, name in enumerate(names_filtred):
+                         if name in self.subset:
+                             index.append(i)
+         if self.shuffle:
+             index = random.sample(index, len(index))
+         return set([names_filtred[i] for i in index])
+
+ class TrainSubset(Subset):
+
+     @config()
+     def __init__(self, subset: Union[str, list[int], list[str], None] = None, shuffle: bool = True, filter: str = None) -> None:
+         super().__init__(subset, shuffle, filter)
+
+ class PredictionSubset(Subset):
+
+     @config()
+     def __init__(self, subset: Union[str, list[int], list[str], None] = None, filter: str = None) -> None:
+         super().__init__(subset, False, filter)
+
+ class Data(ABC):
+
+     def __init__(self, dataset_filenames : list[str],
+                  groups_src : dict[str, Group],
+                  patch : Union[DatasetPatch, None],
+                  use_cache : bool,
+                  subset : Union[Subset, dict[str, Subset]],
+                  num_workers : int,
+                  batch_size : int,
+                  train_size: Union[float, str, list[int], list[str]] = 1,
+                  inlineAugmentations: bool = False,
+                  dataAugmentationsList: dict[str, DataAugmentationsList] = {}) -> None:
+         self.dataset_filenames = dataset_filenames
+         self.subset = subset
+         self.groups_src = groups_src
+         self.patch = patch
+         self.train_size = train_size
+         self.dataAugmentationsList = dataAugmentationsList
+         self.batch_size = batch_size
+         self.dataSet_args = dict(groups_src=self.groups_src, inlineAugmentations=inlineAugmentations, dataAugmentationsList = list(self.dataAugmentationsList.values()), use_cache = use_cache, buffer_size=batch_size+1, patch_size=self.patch.patch_size if self.patch is not None else None, overlap=self.patch.overlap if self.patch is not None else None)
+         self.dataLoader_args = dict(num_workers=num_workers if use_cache else 0, pin_memory=True)
+         self.data : list[list[dict[str, list[DatasetManager]]]] = []
+         self.map : list[list[list[tuple[int, int, int]]]] = []
+         self.datasets: dict[str, Dataset] = {}
+
+     def _getDatasets(self, names: list[str], dataset_name: dict[str, dict[str, list[str]]]) -> tuple[dict[str, list[Dataset]], list[tuple[int, int, int]]]:
+         nb_dataset = len(names)
+         nb_patch = None
+         data = {}
+         map = []
+         nb_augmentation = np.max([int(np.sum([data_augmentation.nb for data_augmentation in self.dataAugmentationsList.values()])+1), 1])
+         for group_src in self.groups_src:
+             for group_dest in self.groups_src[group_src]:
+                 data[group_dest] = [DatasetManager(i, group_src, group_dest, name, self.datasets[[filename for filename, names in dataset_name[group_src].items() if name in names][0]], patch = self.patch, pre_transforms = self.groups_src[group_src][group_dest].pre_transforms, dataAugmentationsList=list(self.dataAugmentationsList.values())) for i, name in enumerate(names)]
+                 nb_patch = [[dataset.getSize(a) for a in range(nb_augmentation)] for dataset in data[group_dest]]
+
+         for x in range(nb_dataset):
+             for y in range(nb_augmentation):
+                 for z in range(nb_patch[x][y]):
+                     map.append((x, y, z))
+         return data, map
+
+     def _split(map: list[tuple[int, int, int]], world_size: int) -> list[list[tuple[int, int, int]]]:
+         if len(map) == 0:
+             return [[] for _ in range(world_size)]
+
+         maps = []
+         if DL_API_STATE() == str(State.PREDICTION) or DL_API_STATE() == str(State.EVALUATION):
+             np_map = np.asarray(map)
+             unique_index = np.unique(np_map[:, 0])
+             offset = int(np.ceil(len(unique_index)/world_size))
+             if offset == 0:
+                 offset = 1
+             for itr in range(0, len(unique_index), offset):
+                 maps.append([tuple(v) for v in np_map[np.where(np.isin(np_map[:, 0], unique_index[itr:itr+offset]))[0], :]])
+         else:
+             offset = int(np.ceil(len(map)/world_size))
+             if offset == 0:
+                 offset = 1
+             for itr in range(0, len(map), offset):
+                 maps.append(list(map[-offset:]) if itr+offset > len(map) else map[itr:itr+offset])
+         return maps
+
+     def getData(self, world_size: int) -> list[list[DataLoader]]:
+         datasets: dict[str, list[tuple[str, bool]]] = {}
+         for dataset_filename in self.dataset_filenames:
+             if len(dataset_filename.split(":")) == 2:
+                 filename, format = dataset_filename.split(":")
+                 append = True
+             else:
+                 filename, flag, format = dataset_filename.split(":")
+                 append = flag == "a"
+
+             dataset = Dataset(filename, format)
+
+             self.datasets[filename] = dataset
+             for group in self.groups_src:
+                 if dataset.isGroupExist(group):
+                     if group in datasets:
+                         datasets[group].append((filename, append))
+                     else:
+                         datasets[group] = [(filename, append)]
+         for group_src in self.groups_src:
+             assert group_src in datasets, "Error group source {} not found".format(group_src)
+
+             for group_dest in self.groups_src[group_src]:
+                 self.groups_src[group_src][group_dest].load(group_src, group_dest, [self.datasets[filename] for filename, _ in datasets[group_src]])
+         for key, dataAugmentations in self.dataAugmentationsList.items():
+             dataAugmentations.load(key)
+
+         names : list[list[str]] = []
+         dataset_name : dict[str, dict[str, list[str]]] = {}
+         dataset_info : dict[str, dict[str, dict[str, Attribute]]] = {}
+         for group in self.groups_src:
+             if group not in dataset_name:
+                 dataset_name[group] = {}
+                 dataset_info[group] = {}
+             for filename, _ in datasets[group]:
+                 names.append(self.datasets[filename].getNames(group))
+                 dataset_name[group][filename] = self.datasets[filename].getNames(group)
+                 dataset_info[group][filename] = {name: self.datasets[filename].getInfos(group, name) for name in dataset_name[group][filename]}
+         subset_names = set()
+         if isinstance(self.subset, dict):
+             for filename, subset in self.subset.items():
+                 subset_names.update(subset([dataset_name[group][filename] for group in dataset_name], [dataset_info[group][filename] for group in dataset_name]))
+         else:
+             for group in dataset_name:
+                 for filename, append in datasets[group]:
+                     if append:
+                         subset_names.update(self.subset([dataset_name[group][filename]], [dataset_info[group][filename]]))
+                     else:
+                         if len(subset_names) == 0:
+                             subset_names.update(self.subset([dataset_name[group][filename]], [dataset_info[group][filename]]))
+                         else:
+                             subset_names.intersection_update(self.subset([dataset_name[group][filename]], [dataset_info[group][filename]]))
+         data, map = self._getDatasets(list(subset_names), dataset_name)
+
+         train_map = map
+         validate_map = []
+         if isinstance(self.train_size, float):
+             if self.train_size < 1.0 and int(math.floor(len(map)*(1-self.train_size))) > 0:
+                 train_map, validate_map = map[:int(math.floor(len(map)*self.train_size))], map[int(math.floor(len(map)*self.train_size)):]
+         elif isinstance(self.train_size, str):
+             if ":" in self.train_size:
+                 index = list(range(int(self.train_size.split(":")[0]), int(self.train_size.split(":")[1])))
+                 train_map = [m for m in map if m[0] not in index]
+                 validate_map = [m for m in map if m[0] in index]
+             elif os.path.exists(self.train_size):
+                 validation_names = []
+                 with open(self.train_size, "r") as f:
+                     for name in f:
+                         validation_names.append(name.strip())
+                 index = [i for i, n in enumerate(subset_names) if n in validation_names]
+                 train_map = [m for m in map if m[0] not in index]
+                 validate_map = [m for m in map if m[0] in index]
+             else:
+                 validate_map = train_map
+         elif isinstance(self.train_size, list):
+             if len(self.train_size) > 0:
+                 if isinstance(self.train_size[0], int):
+                     train_map = [m for m in map if m[0] not in self.train_size]
+                     validate_map = [m for m in map if m[0] in self.train_size]
+                 elif isinstance(self.train_size[0], str):
+                     index = [i for i, n in enumerate(subset_names) if n in self.train_size]
+                     train_map = [m for m in map if m[0] not in index]
+                     validate_map = [m for m in map if m[0] in index]
+
+         train_maps = Data._split(train_map, world_size)
+         validate_maps = Data._split(validate_map, world_size)
+
+         for i, (train_map, validate_map) in enumerate(zip(train_maps, validate_maps)):
+             maps = [train_map]
+             if len(validate_map):
+                 maps += [validate_map]
+             self.data.append([])
+             self.map.append([])
+             for map_tmp in maps:
+                 indexs = np.unique(np.asarray(map_tmp)[:, 0])
+                 self.data[i].append({k: [v[it] for it in indexs] for k, v in data.items()})
+                 map_tmp_array = np.asarray(map_tmp)
+                 for a, b in enumerate(indexs):
+                     map_tmp_array[np.where(np.asarray(map_tmp_array)[:, 0] == b), 0] = a
+                 self.map[i].append([(a, b, c) for a, b, c in map_tmp_array])
+
+         dataLoaders: list[list[DataLoader]] = []
+         for i, (datas, maps) in enumerate(zip(self.data, self.map)):
+             dataLoaders.append([])
+             for data, map in zip(datas, maps):
+                 dataLoaders[i].append(DataLoader(dataset=DatasetIter(rank=i, data=data, map=map, **self.dataSet_args), sampler=CustomSampler(len(map), self.subset.shuffle), batch_size=self.batch_size, **self.dataLoader_args))
+         return dataLoaders
+
+ class DataTrain(Data):
+
+     @config("Dataset")
+     def __init__(self, dataset_filenames : list[str] = ["default:Dataset.h5"],
+                  groups_src : dict[str, Group] = {"default" : Group()},
+                  augmentations : Union[dict[str, DataAugmentationsList], None] = {"DataAugmentation_0" : DataAugmentationsList()},
+                  inlineAugmentations: bool = False,
+                  patch : Union[DatasetPatch, None] = DatasetPatch(),
+                  use_cache : bool = True,
+                  subset : Union[TrainSubset, dict[str, TrainSubset]] = TrainSubset(),
+                  num_workers : int = 4,
+                  batch_size : int = 1,
+                  train_size : Union[float, str, list[int], list[str]] = 0.8) -> None:
+         super().__init__(dataset_filenames, groups_src, patch, use_cache, subset, num_workers, batch_size, train_size, inlineAugmentations, augmentations if augmentations else {})
+
+ class DataPrediction(Data):
+
+     @config("Dataset")
+     def __init__(self, dataset_filenames : list[str] = ["default:Dataset.h5"],
+                  groups_src : dict[str, Group] = {"default" : Group()},
+                  augmentations : Union[dict[str, DataAugmentationsList], None] = {"DataAugmentation_0" : DataAugmentationsList()},
+                  inlineAugmentations: bool = False,
+                  patch : Union[DatasetPatch, None] = DatasetPatch(),
+                  use_cache : bool = True,
+                  subset : Union[PredictionSubset, dict[str, PredictionSubset]] = PredictionSubset(),
+                  num_workers : int = 4,
+                  batch_size : int = 1) -> None:
+
+         super().__init__(dataset_filenames, groups_src, patch, use_cache, subset, num_workers, batch_size, inlineAugmentations=inlineAugmentations, dataAugmentationsList=augmentations if augmentations else {})
+
+ class DataMetric(Data):
+
+     @config("Dataset")
+     def __init__(self, dataset_filenames : list[str] = ["default:Dataset.h5"],
+                  groups_src : dict[str, Group] = {"default" : Group()},
+                  subset : Union[PredictionSubset, dict[str, PredictionSubset]] = PredictionSubset(),
+                  validation: Union[str, None] = None,
+                  num_workers : int = 4) -> None:
+
+         super().__init__(dataset_filenames=dataset_filenames, groups_src=groups_src, patch=None, use_cache=False, subset=subset, num_workers=num_workers, batch_size=1, train_size=1 if validation is None else validation)
+
+ class DataHyperparameter(Data):
+
+     @config("Dataset")
+     def __init__(self, dataset_filenames : list[str] = ["default:Dataset.h5"],
+                  groups_src : dict[str, Group] = {"default" : Group()},
+                  patch : Union[DatasetPatch, None] = DatasetPatch()) -> None:
+
+         super().__init__(dataset_filenames, groups_src, patch, False, PredictionSubset(), 0, False, 1)
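
For orientation, the sketch below shows how the subset-selection and sampling pieces in this file behave when used on their own. It is a minimal, hypothetical example: the import path konfai.data.dataset is assumed from the file location above (the module itself imports from KonfAI.konfai), and direct instantiation outside the @config machinery is assumed to work as the plain __init__ signatures suggest.

# Hypothetical usage sketch; not part of the package diff.
from konfai.data.dataset import Subset, CustomSampler

# Subset keeps only the names present in every dataset, sorts them,
# then applies the "start:stop" slice given as its subset argument.
names_per_dataset = [["case_a", "case_b", "case_c"], ["case_b", "case_c", "case_a"]]
selected = Subset(subset="0:2", shuffle=False)(names_per_dataset, [])
print(selected)  # {"case_a", "case_b"}

# CustomSampler yields dataset indices for the DataLoader, optionally shuffled.
sampler = CustomSampler(size=4, shuffle=True)
print(list(iter(sampler)))  # a permutation of [0, 1, 2, 3]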