konfai-1.0.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of konfai has been flagged as possibly problematic; see the registry page for details.
- konfai/__init__.py +16 -0
- konfai/data/HDF5.py +326 -0
- konfai/data/__init__.py +0 -0
- konfai/data/augmentation.py +597 -0
- konfai/data/dataset.py +470 -0
- konfai/data/transform.py +536 -0
- konfai/evaluator.py +146 -0
- konfai/main.py +43 -0
- konfai/metric/__init__.py +0 -0
- konfai/metric/measure.py +488 -0
- konfai/metric/schedulers.py +49 -0
- konfai/models/classification/convNeXt.py +175 -0
- konfai/models/classification/resnet.py +116 -0
- konfai/models/generation/cStyleGan.py +137 -0
- konfai/models/generation/ddpm.py +218 -0
- konfai/models/generation/diffusionGan.py +557 -0
- konfai/models/generation/gan.py +134 -0
- konfai/models/generation/vae.py +72 -0
- konfai/models/registration/registration.py +136 -0
- konfai/models/representation/representation.py +57 -0
- konfai/models/segmentation/NestedUNet.py +53 -0
- konfai/models/segmentation/UNet.py +58 -0
- konfai/network/__init__.py +0 -0
- konfai/network/blocks.py +348 -0
- konfai/network/network.py +950 -0
- konfai/predictor.py +366 -0
- konfai/trainer.py +330 -0
- konfai/utils/ITK.py +269 -0
- konfai/utils/Registration.py +199 -0
- konfai/utils/__init__.py +0 -0
- konfai/utils/config.py +218 -0
- konfai/utils/dataset.py +764 -0
- konfai/utils/utils.py +493 -0
- konfai-1.0.0.dist-info/METADATA +68 -0
- konfai-1.0.0.dist-info/RECORD +39 -0
- konfai-1.0.0.dist-info/WHEEL +5 -0
- konfai-1.0.0.dist-info/entry_points.txt +3 -0
- konfai-1.0.0.dist-info/licenses/LICENSE +201 -0
- konfai-1.0.0.dist-info/top_level.txt +1 -0
konfai/data/dataset.py
ADDED
@@ -0,0 +1,470 @@
import math
import os
import random
import torch
from torch.utils import data
import tqdm
import numpy as np
from abc import ABC
from torch.utils.data import DataLoader, Sampler
from typing import Union, Iterator
from concurrent.futures import ThreadPoolExecutor, as_completed
import threading
from torch.cuda import device_count

from KonfAI.konfai import DL_API_STATE, DEEP_LEARNING_API_ROOT
from KonfAI.konfai.data.HDF5 import DatasetPatch, DatasetManager
from KonfAI.konfai.utils.config import config
from KonfAI.konfai.utils.utils import memoryInfo, cpuInfo, memoryForecast, getMemory, State
from KonfAI.konfai.utils.dataset import Dataset, Attribute
from KonfAI.konfai.data.transform import TransformLoader, Transform
from KonfAI.konfai.data.augmentation import DataAugmentationsList

class GroupTransform:

    @config()
    def __init__(self, pre_transforms : Union[dict[str, TransformLoader], list[Transform]] = {"default:Normalize:Standardize:Unsqueeze:TensorCast:ResampleIsotropic:ResampleResize": TransformLoader()},
                       post_transforms : Union[dict[str, TransformLoader], list[Transform]] = {"default:Normalize:Standardize:Unsqueeze:TensorCast:ResampleIsotropic:ResampleResize": TransformLoader()},
                       isInput: bool = True) -> None:
        self._pre_transforms = pre_transforms
        self._post_transforms = post_transforms
        self.pre_transforms : list[Transform] = []
        self.post_transforms : list[Transform] = []
        self.isInput = isInput

    def load(self, group_src : str, group_dest : str, datasets: list[Dataset]):
        if self._pre_transforms is not None:
            if isinstance(self._pre_transforms, dict):
                for classpath, transform in self._pre_transforms.items():
                    transform = transform.getTransform(classpath, DL_args = "{}.Dataset.groups_src.{}.groups_dest.{}.pre_transforms".format(DEEP_LEARNING_API_ROOT(), group_src, group_dest))
                    transform.setDatasets(datasets)
                    self.pre_transforms.append(transform)
            else:
                for transform in self._pre_transforms:
                    transform.setDatasets(datasets)
                    self.pre_transforms.append(transform)

        if self._post_transforms is not None:
            if isinstance(self._post_transforms, dict):
                for classpath, transform in self._post_transforms.items():
                    transform = transform.getTransform(classpath, DL_args = "{}.Dataset.groups_src.{}.groups_dest.{}.post_transforms".format(DEEP_LEARNING_API_ROOT(), group_src, group_dest))
                    transform.setDatasets(datasets)
                    self.post_transforms.append(transform)
            else:
                for transform in self._post_transforms:
                    transform.setDatasets(datasets)
                    self.post_transforms.append(transform)

    def to(self, device: int):
        for transform in self.pre_transforms:
            transform.setDevice(device)
        for transform in self.post_transforms:
            transform.setDevice(device)

class Group(dict[str, GroupTransform]):

    @config()
    def __init__(self, groups_dest: dict[str, GroupTransform] = {"default": GroupTransform()}):
        super().__init__(groups_dest)

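Note: `GroupTransform` keeps the raw configuration (`_pre_transforms`, `_post_transforms`) separate from the resolved `Transform` lists that `load()` populates, and `Group` is a plain `dict` keyed by destination group name. The sketch below is illustrative only, not part of the package; it assumes the module imports cleanly as `konfai.data.dataset` and that the `@config()` decorator still allows direct keyword construction.

# Illustrative sketch, not part of konfai/data/dataset.py.
from konfai.data.dataset import Group, GroupTransform  # assumed import path

group = Group({"default": GroupTransform(pre_transforms=[], post_transforms=[], isInput=True)})
for group_dest, group_transform in group.items():
    # Before load() runs, the resolved transform lists are still empty.
    print(group_dest, group_transform.isInput, group_transform.pre_transforms)
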
class CustomSampler(Sampler[int]):

    def __init__(self, size: int, shuffle: bool = False) -> None:
        self.size = size
        self.shuffle = shuffle

    def __iter__(self) -> Iterator[int]:
        return iter(torch.randperm(len(self)).tolist() if self.shuffle else list(range(len(self))) )

    def __len__(self) -> int:
        return self.size

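`CustomSampler` yields dataset indices either in order or as a fresh random permutation on every epoch when `shuffle` is set; `Data.getData` below attaches one to each `DataLoader`. A minimal illustration (not package code), assuming the same import path:

# Illustrative only.
from konfai.data.dataset import CustomSampler  # assumed import path

print(list(CustomSampler(5)))                # [0, 1, 2, 3, 4]
print(list(CustomSampler(5, shuffle=True)))  # a random permutation, re-drawn on each __iter__ call
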
class DatasetIter(data.Dataset):

    def __init__(self, rank: int, data : dict[str, list[DatasetManager]], map: dict[int, tuple[int, int, int]], groups_src : dict[str, Group], inlineAugmentations: bool, dataAugmentationsList : list[DataAugmentationsList], patch_size: Union[list[int], None], overlap: Union[int, None], buffer_size: int, use_cache = True) -> None:
        self.rank = rank
        self.data = data
        self.map = map
        self.patch_size = patch_size
        self.overlap = overlap
        self.groups_src = groups_src
        self.dataAugmentationsList = dataAugmentationsList
        self.use_cache = use_cache
        self.nb_dataset = len(data[list(data.keys())[0]])
        self.buffer_size = buffer_size
        self._index_cache = list()
        self.device = None
        self.inlineAugmentations = inlineAugmentations

    def getPatchConfig(self) -> tuple[list[int], int]:
        return self.patch_size, self.overlap

    def to(self, device: int):
        for group_src in self.groups_src:
            for group_dest in self.groups_src[group_src]:
                self.groups_src[group_src][group_dest].to(device)
        self.device = device

    def getDatasetFromIndex(self, group_dest: str, index: int) -> DatasetManager:
        return self.data[group_dest][index]

    def resetAugmentation(self):
        if self.inlineAugmentations:
            for index in range(self.nb_dataset):
                self._unloadData(index)
                for group_src in self.groups_src:
                    for group_dest in self.groups_src[group_src]:
                        self.data[group_dest][index].resetAugmentation()

    def load(self):
        if self.use_cache:
            memory_init = getMemory()

            indexs = [index for index in range(self.nb_dataset) if index not in self._index_cache]
            if len(indexs) > 0:
                memory_lock = threading.Lock()
                pbar = tqdm.tqdm(
                    total=len(indexs),
                    desc="Caching : init | {} | {}".format(memoryForecast(memory_init, 0, self.nb_dataset), cpuInfo()),
                    leave=False,
                    disable=self.rank != 0 and "DL_API_CLUSTER" not in os.environ
                )

                def process(index):
                    self._loadData(index)
                    with memory_lock:
                        pbar.set_description("Caching : {} | {} | {}".format(memoryInfo(), memoryForecast(memory_init, index, self.nb_dataset), cpuInfo()))
                        pbar.update(1)
                with ThreadPoolExecutor(max_workers=os.cpu_count()//device_count()) as executor:
                    futures = [executor.submit(process, index) for index in indexs]
                    for _ in as_completed(futures):
                        pass

                pbar.close()

    def _loadData(self, index):
        if index not in self._index_cache:
            self._index_cache.append(index)
            for group_src in self.groups_src:
                for group_dest in self.groups_src[group_src]:
                    self.loadData(group_src, group_dest, index)

    def loadData(self, group_src: str, group_dest : str, index : int) -> None:
        self.data[group_dest][index].load(self.groups_src[group_src][group_dest].pre_transforms, self.dataAugmentationsList, self.device)

    def _unloadData(self, index : int) -> None:
        if index in self._index_cache:
            self._index_cache.remove(index)
        for group_src in self.groups_src:
            for group_dest in self.groups_src[group_src]:
                self.unloadData(group_dest, index)

    def unloadData(self, group_dest : str, index : int) -> None:
        return self.data[group_dest][index].unload()

    def __len__(self) -> int:
        return len(self.map)

    def __getitem__(self, index : int) -> dict[str, tuple[torch.Tensor, int, int, int, str, bool]]:
        data = {}
        x, a, p = self.map[index]
        if x not in self._index_cache:
            if len(self._index_cache) >= self.buffer_size and not self.use_cache:
                self._unloadData(self._index_cache[0])
            self._loadData(x)

        for group_src in self.groups_src:
            for group_dest in self.groups_src[group_src]:
                dataset = self.data[group_dest][x]
                data["{}".format(group_dest)] = (dataset.getData(p, a, self.groups_src[group_src][group_dest].post_transforms, self.groups_src[group_src][group_dest].isInput), x, a, p, dataset.name, self.groups_src[group_src][group_dest].isInput)
        return data

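Each `DatasetIter` item is a dict keyed by destination group, and each value bundles the tensor with its provenance: the `(x, a, p)` triple from `self.map` (dataset index, augmentation index, patch index), the dataset name, and the `isInput` flag. A hedged sketch of unpacking one collated batch, assuming a single `default` group and PyTorch's default collation (`train_loader` is a hypothetical loader produced by `Data.getData`):

# Illustrative only; train_loader is a hypothetical DataLoader built over a DatasetIter.
for batch in train_loader:
    tensor, x, a, p, name, is_input = batch["default"]
    # tensor: collated patches; x, a, p: per-sample dataset/augmentation/patch indices.
    break
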
class Subset():

    def __init__(self, subset: Union[str, list[int], list[str], None] = None, shuffle: bool = True, filter: str = None) -> None:
        self.subset = subset
        self.shuffle = shuffle
        self.filter: list[tuple[str]] = [tuple(a.split(":")) for a in filter.split(";")] if filter is not None else None

    def __call__(self, names: list[str], infos: list[dict[str, tuple[np.ndarray, Attribute]]]) -> set[str]:
        inter_name = set(names[0])
        for n in names[1:]:
            inter_name = inter_name.intersection(set(n))
        names = sorted(list(inter_name))

        names_filtred = []
        if self.filter is not None:
            for name in names:
                for info in infos:
                    if all([info[name][1].isInfo(f[0], f[1]) for f in self.filter]):
                        names_filtred.append(name)
                        continue
        else:
            names_filtred = names
        size = len(names_filtred)

        index = []
        if self.subset is None:
            index = list(range(0, size))
        elif isinstance(self.subset, str):
            if ":" in self.subset:
                r = np.clip(np.asarray([int(self.subset.split(":")[0]), int(self.subset.split(":")[1])]), 0, size)
                index = list(range(r[0], r[1]))
            elif os.path.exists(self.subset):
                validation_names = []
                with open(self.subset, "r") as f:
                    for name in f:
                        validation_names.append(name.strip())
                index = []
                for i, name in enumerate(names_filtred):
                    if name in validation_names:
                        index.append(i)

        elif isinstance(self.subset, list):
            if len(self.subset) > 0:
                if isinstance(self.subset[0], int):
                    if len(self.subset) == 1:
                        index = list(range(self.subset[0], min(size, self.subset[0]+1)))
                    else:
                        index = self.subset
                if isinstance(self.subset[0], str):
                    index = []
                    for i, name in enumerate(names_filtred):
                        if name in self.subset:
                            index.append(i)
        if self.shuffle:
            index = random.sample(index, len(index))
        return set([names_filtred[i] for i in index])

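A `Subset` selection may be `None` (keep every common name), a `"start:stop"` range string, a path to a text file listing names, a list of integer indices, or a list of names; `filter` is parsed as `;`-separated `key:value` pairs that are checked against each entry's `Attribute` through `isInfo`. A few hedged construction examples (the file name, case identifiers and attribute keys below are hypothetical, not taken from the package):

# Illustrative only.
from konfai.data.dataset import Subset  # assumed import path

Subset("0:10")                        # the first ten names after sorting and filtering
Subset(["case_001", "case_007"])      # explicit names (hypothetical identifiers)
Subset("validation_names.txt")        # one name per line (hypothetical file)
Subset(filter="Modality:CT;Side:L")   # hypothetical attribute key:value pairs
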
class TrainSubset(Subset):

    @config()
    def __init__(self, subset: Union[str, list[int], list[str], None] = None, shuffle: bool = True, filter: str = None) -> None:
        super().__init__(subset, shuffle, filter)

class PredictionSubset(Subset):

    @config()
    def __init__(self, subset: Union[str, list[int], list[str], None] = None, filter: str = None) -> None:
        super().__init__(subset, False, filter)

class Data(ABC):

    def __init__(self, dataset_filenames : list[str],
                       groups_src : dict[str, Group],
                       patch : Union[DatasetPatch, None],
                       use_cache : bool,
                       subset : Union[Subset, dict[str, Subset]],
                       num_workers : int,
                       batch_size : int,
                       train_size: Union[float, str, list[int], list[str]] = 1,
                       inlineAugmentations: bool = False,
                       dataAugmentationsList: dict[str, DataAugmentationsList]= {}) -> None:
        self.dataset_filenames = dataset_filenames
        self.subset = subset
        self.groups_src = groups_src
        self.patch = patch
        self.train_size = train_size
        self.dataAugmentationsList = dataAugmentationsList
        self.batch_size = batch_size
        self.dataSet_args = dict(groups_src=self.groups_src, inlineAugmentations=inlineAugmentations, dataAugmentationsList = list(self.dataAugmentationsList.values()), use_cache = use_cache, buffer_size=batch_size+1, patch_size=self.patch.patch_size if self.patch is not None else None, overlap=self.patch.overlap if self.patch is not None else None)
        self.dataLoader_args = dict(num_workers=num_workers if use_cache else 0, pin_memory=True)
        self.data : list[list[dict[str, list[DatasetManager]]], dict[str, list[DatasetManager]]] = []
        self.map : list[list[list[tuple[int, int, int]]], list[tuple[int, int, int]]] = []
        self.datasets: dict[str, Dataset] = {}

    def _getDatasets(self, names: list[str], dataset_name: dict[str, dict[str, list[str]]]) -> tuple[dict[str, list[Dataset]], list[tuple[int, int, int]]]:
        nb_dataset = len(names)
        nb_patch = None
        data = {}
        map = []
        nb_augmentation = np.max([int(np.sum([data_augmentation.nb for data_augmentation in self.dataAugmentationsList.values()])+1), 1])
        for group_src in self.groups_src:
            for group_dest in self.groups_src[group_src]:
                data[group_dest] = [DatasetManager(i, group_src, group_dest, name, self.datasets[[filename for filename, names in dataset_name[group_src].items() if name in names][0]], patch = self.patch, pre_transforms = self.groups_src[group_src][group_dest].pre_transforms, dataAugmentationsList=list(self.dataAugmentationsList.values())) for i, name in enumerate(names)]
                nb_patch = [[dataset.getSize(a) for a in range(nb_augmentation)] for dataset in data[group_dest]]

        for x in range(nb_dataset):
            for y in range(nb_augmentation):
                for z in range(nb_patch[x][y]):
                    map.append((x, y, z))
        return data, map

    def _split(map: list[tuple[int, int, int]], world_size: int) -> list[list[tuple[int, int, int]]]:
        if len(map) == 0:
            return [[] for _ in range(world_size)]

        maps = []
        if DL_API_STATE() == str(State.PREDICTION) or DL_API_STATE() == str(State.EVALUATION):
            np_map = np.asarray(map)
            unique_index = np.unique(np_map[:, 0])
            offset = int(np.ceil(len(unique_index)/world_size))
            if offset == 0:
                offset = 1
            for itr in range(0, len(unique_index), offset):
                maps.append([tuple(v) for v in np_map[np.where(np.isin(np_map[:, 0], unique_index[itr:itr+offset]))[0], :]])
        else:
            offset = int(np.ceil(len(map)/world_size))
            if offset == 0:
                offset = 1
            for itr in range(0, len(map), offset):
                maps.append(list(map[-offset:]) if itr+offset > len(map) else map[itr:itr+offset])
        return maps

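Outside of prediction and evaluation, `_split` gives each rank a contiguous slice of about `len(map) / world_size` entries, and the final slice is taken as `map[-offset:]`, so every rank receives the same number of items even if the last one overlaps its neighbour. A standalone sketch of that arithmetic (not package code):

# Standalone illustration of the training branch of Data._split; not package code.
import math

def split_even(items: list, world_size: int) -> list[list]:
    if not items:
        return [[] for _ in range(world_size)]
    offset = max(1, int(math.ceil(len(items) / world_size)))
    return [items[-offset:] if itr + offset > len(items) else items[itr:itr + offset]
            for itr in range(0, len(items), offset)]

print(split_even(list(range(10)), 3))  # [[0, 1, 2, 3], [4, 5, 6, 7], [6, 7, 8, 9]]
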
    def getData(self, world_size: int) -> list[list[DataLoader]]:
        datasets: dict[str, list[(str, bool)]] = {}
        for dataset_filename in self.dataset_filenames:
            if len(dataset_filename.split(":")) == 2:
                filename, format = dataset_filename.split(":")
                append = True
            else:
                filename, flag, format = dataset_filename.split(":")
                append = flag == "a"

            dataset = Dataset(filename, format)

            self.datasets[filename] = dataset
            for group in self.groups_src:
                if dataset.isGroupExist(group):
                    if group in datasets:
                        datasets[group].append((filename, append))
                    else:
                        datasets[group] = [(filename, append)]
        for group_src in self.groups_src:
            assert group_src in datasets, "Error group source {} not found".format(group_src)

            for group_dest in self.groups_src[group_src]:
                self.groups_src[group_src][group_dest].load(group_src, group_dest, [self.datasets[filename] for filename, _ in datasets[group_src]])
        for key, dataAugmentations in self.dataAugmentationsList.items():
            dataAugmentations.load(key)

        names : list[list[str]] = []
        dataset_name : dict[str, dict[str, list[str]]] = {}
        dataset_info : dict[str, dict[str, dict[str, Attribute]]] = {}
        for group in self.groups_src:
            if group not in dataset_name:
                dataset_name[group] = {}
                dataset_info[group] = {}
            for filename, _ in datasets[group]:
                names.append(self.datasets[filename].getNames(group))
                dataset_name[group][filename] = self.datasets[filename].getNames(group)
                dataset_info[group][filename] = {name: self.datasets[filename].getInfos(group, name) for name in dataset_name[group][filename]}
        subset_names = set()
        if isinstance(self.subset, dict):
            for filename, subset in self.subset.items():
                subset_names.update(subset([dataset_name[group][filename] for group in dataset_name], [dataset_info[group][filename] for group in dataset_name]))
        else:
            for group in dataset_name:
                for filename, append in datasets[group]:
                    if append:
                        subset_names.update(self.subset([dataset_name[group][filename]], [dataset_info[group][filename]]))
                    else:
                        if len(subset_names) == 0:
                            subset_names.update(self.subset([dataset_name[group][filename]], [dataset_info[group][filename]]))
                        else:
                            subset_names.intersection(self.subset([dataset_name[group][filename]], [dataset_info[group][filename]]))
        data, map = self._getDatasets(list(subset_names), dataset_name)

        train_map = map
        validate_map = []
        if isinstance(self.train_size, float):
            if self.train_size < 1.0 and int(math.floor(len(map)*(1-self.train_size))) > 0:
                train_map, validate_map = map[:int(math.floor(len(map)*self.train_size))], map[int(math.floor(len(map)*self.train_size)):]
        elif isinstance(self.train_size, str):
            if ":" in self.train_size:
                index = list(range(int(self.subset.split(":")[0]), int(self.subset.split(":")[1])))
                train_map = [m for m in map if m[0] not in index]
                validate_map = [m for m in map if m[0] in index]
            elif os.path.exists(self.train_size):
                validation_names = []
                with open(self.train_size, "r") as f:
                    for name in f:
                        validation_names.append(name.strip())
                index = [i for i, n in enumerate(subset_names) if n in validation_names]
                train_map = [m for m in map if m[0] not in index]
                validate_map = [m for m in map if m[0] in index]
            else:
                validate_map = train_map
        elif isinstance(self.train_size, list):
            if len(self.train_size) > 0:
                if isinstance(self.train_size[0], int):
                    train_map = [m for m in map if m[0] not in self.train_size]
                    validate_map = [m for m in map if m[0] in self.train_size]
                elif isinstance(self.train_size[0], str):
                    index = [i for i, n in enumerate(subset_names) if n in self.train_size]
                    train_map = [m for m in map if m[0] not in index]
                    validate_map = [m for m in map if m[0] in index]

        train_maps = Data._split(train_map, world_size)
        validate_maps = Data._split(validate_map, world_size)

        for i, (train_map, validate_map) in enumerate(zip(train_maps, validate_maps)):
            maps = [train_map]
            if len(validate_map):
                maps += [validate_map]
            self.data.append([])
            self.map.append([])
            for map_tmp in maps:
                indexs = np.unique(np.asarray(map_tmp)[:, 0])
                self.data[i].append({k:[v[it] for it in indexs] for k, v in data.items()})
                map_tmp_array = np.asarray(map_tmp)
                for a, b in enumerate(indexs):
                    map_tmp_array[np.where(np.asarray(map_tmp_array)[:, 0] == b), 0] = a
                self.map[i].append([(a,b,c) for a,b,c in map_tmp_array])

        dataLoaders: list[list[DataLoader]] = []
        for i, (datas, maps) in enumerate(zip(self.data, self.map)):
            dataLoaders.append([])
            for data, map in zip(datas, maps):
                dataLoaders[i].append(DataLoader(dataset=DatasetIter(rank=i, data=data, map=map, **self.dataSet_args), sampler=CustomSampler(len(map), self.subset.shuffle), batch_size=self.batch_size,**self.dataLoader_args))
        return dataLoaders

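In `getData` above, each `dataset_filenames` entry carries either two or three `:`-separated fields; with three fields the middle flag decides whether that file's names are appended to the subset (`"a"`) or intersected with it. A minimal sketch mirroring that parsing (the entries are hypothetical):

# Mirrors the entry parsing at the top of Data.getData; illustrative only.
for entry in ["default:Dataset.h5", "extra:a:Dataset.h5"]:  # hypothetical entries
    parts = entry.split(":")
    if len(parts) == 2:
        filename, fmt = parts
        append = True
    else:
        filename, flag, fmt = parts
        append = flag == "a"
    print(filename, fmt, append)
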
class DataTrain(Data):

    @config("Dataset")
    def __init__(self, dataset_filenames : list[str] = ["default:Dataset.h5"],
                       groups_src : dict[str, Group] = {"default" : Group()},
                       augmentations : Union[dict[str, DataAugmentationsList], None] = {"DataAugmentation_0" : DataAugmentationsList()},
                       inlineAugmentations: bool = False,
                       patch : Union[DatasetPatch, None] = DatasetPatch(),
                       use_cache : bool = True,
                       subset : Union[TrainSubset, dict[str, TrainSubset]] = TrainSubset(),
                       num_workers : int = 4,
                       batch_size : int = 1,
                       train_size : Union[float, str, list[int], list[str]] = 0.8) -> None:
        super().__init__(dataset_filenames, groups_src, patch, use_cache, subset, num_workers, batch_size, train_size, inlineAugmentations, augmentations if augmentations else {})

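`DataTrain` is the training entry point: it forwards its arguments to `Data`, and with `train_size=0.8` the `getData` call returns, per rank, a training `DataLoader` followed by a validation `DataLoader` when the validation split is non-empty. A hedged usage sketch, assuming direct construction is possible despite the `@config("Dataset")` decorator and that the default `DatasetPatch`/`DataAugmentationsList` objects build cleanly:

# Illustrative only; argument values are hypothetical.
from konfai.data.dataset import DataTrain, TrainSubset  # assumed import path

data_train = DataTrain(dataset_filenames=["default:Dataset.h5"],
                       subset=TrainSubset(),
                       batch_size=2,
                       train_size=0.8)
loaders = data_train.getData(world_size=1)
# loaders[0][0]: training DataLoader; loaders[0][1]: validation DataLoader (if the split is non-empty).
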
class DataPrediction(Data):

    @config("Dataset")
    def __init__(self, dataset_filenames : list[str] = ["default:Dataset.h5"],
                       groups_src : dict[str, Group] = {"default" : Group()},
                       augmentations : Union[dict[str, DataAugmentationsList], None] = {"DataAugmentation_0" : DataAugmentationsList()},
                       inlineAugmentations: bool = False,
                       patch : Union[DatasetPatch, None] = DatasetPatch(),
                       use_cache : bool = True,
                       subset : Union[PredictionSubset, dict[str, PredictionSubset]] = PredictionSubset(),
                       num_workers : int = 4,
                       batch_size : int = 1) -> None:

        super().__init__(dataset_filenames, groups_src, patch, use_cache, subset, num_workers, batch_size, inlineAugmentations=inlineAugmentations, dataAugmentationsList=augmentations if augmentations else {})

class DataMetric(Data):

    @config("Dataset")
    def __init__(self, dataset_filenames : list[str] = ["default:Dataset.h5"],
                       groups_src : dict[str, Group] = {"default" : Group()},
                       subset : Union[PredictionSubset, dict[str, PredictionSubset]] = PredictionSubset(),
                       validation: Union[str, None] = None,
                       num_workers : int = 4) -> None:

        super().__init__(dataset_filenames=dataset_filenames, groups_src=groups_src, patch=None, use_cache=False, subset=subset, num_workers=num_workers, batch_size=1, train_size=1 if validation is None else validation)

class DataHyperparameter(Data):

    @config("Dataset")
    def __init__(self, dataset_filenames : list[str] = ["default:Dataset.h5"],
                       groups_src : dict[str, Group] = {"default" : Group()},
                       patch : Union[DatasetPatch, None] = DatasetPatch()) -> None:

        super().__init__(dataset_filenames, groups_src, patch, False, PredictionSubset(), 0, False, 1)