konfai 1.1.7__py3-none-any.whl → 1.1.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release: this version of konfai might be problematic.
- konfai/__init__.py +59 -14
- konfai/data/augmentation.py +457 -286
- konfai/data/data_manager.py +509 -290
- konfai/data/patching.py +300 -183
- konfai/data/transform.py +384 -277
- konfai/evaluator.py +309 -68
- konfai/main.py +71 -22
- konfai/metric/measure.py +341 -222
- konfai/metric/schedulers.py +24 -13
- konfai/models/classification/convNeXt.py +187 -81
- konfai/models/classification/resnet.py +272 -58
- konfai/models/generation/cStyleGan.py +233 -59
- konfai/models/generation/ddpm.py +348 -121
- konfai/models/generation/diffusionGan.py +757 -358
- konfai/models/generation/gan.py +177 -53
- konfai/models/generation/vae.py +140 -40
- konfai/models/registration/registration.py +135 -52
- konfai/models/representation/representation.py +57 -23
- konfai/models/segmentation/NestedUNet.py +339 -68
- konfai/models/segmentation/UNet.py +140 -30
- konfai/network/blocks.py +331 -187
- konfai/network/network.py +781 -423
- konfai/predictor.py +645 -240
- konfai/trainer.py +527 -216
- konfai/utils/ITK.py +191 -106
- konfai/utils/config.py +152 -95
- konfai/utils/dataset.py +326 -455
- konfai/utils/utils.py +495 -249
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/METADATA +1 -3
- konfai-1.1.9.dist-info/RECORD +38 -0
- konfai/utils/registration.py +0 -199
- konfai-1.1.7.dist-info/RECORD +0 -39
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/WHEEL +0 -0
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/entry_points.txt +0 -0
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/licenses/LICENSE +0 -0
- {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/top_level.txt +0 -0
konfai/data/data_manager.py
CHANGED
@@ -1,85 +1,114 @@
 import math
 import os
 import random
+import threading
+from abc import ABC, abstractmethod
+from collections.abc import Iterator, Mapping
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from functools import partial
+
+import numpy as np
 import torch
-from torch.utils import data
 import tqdm
-import numpy as np
-from abc import ABC
-from torch.utils.data import DataLoader, Sampler
-from typing import Union, Iterator
-from concurrent.futures import ThreadPoolExecutor, as_completed
-import threading
 from torch.cuda import device_count
-
+from torch.utils import data
+from torch.utils.data import DataLoader, Sampler

-from konfai import
-from konfai.data.patching import DatasetPatch, DatasetManager
-from konfai.utils.config import config
-from konfai.utils.utils import memoryInfo, cpuInfo, memoryForecast, getMemory, State, SUPPORTED_EXTENSIONS, DatasetManagerError
-from konfai.utils.dataset import Dataset, Attribute
-from konfai.data.transform import TransformLoader, Transform
+from konfai import konfai_root, konfai_state
 from konfai.data.augmentation import DataAugmentationsList
+from konfai.data.patching import DatasetManager, DatasetPatch
+from konfai.data.transform import Transform, TransformLoader
+from konfai.utils.config import config
+from konfai.utils.dataset import Attribute, Dataset
+from konfai.utils.utils import (
+    SUPPORTED_EXTENSIONS,
+    DatasetManagerError,
+    State,
+    get_cpu_info,
+    get_memory,
+    get_memory_info,
+    memory_forecast,
+)
+

 class GroupTransform:

     @config()
-    def __init__(
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
-
+    def __init__(
+        self,
+        transforms: dict[str, TransformLoader] = {
+            "default:Normalize:Standardize:Unsqueeze:TensorCast:ResampleIsotropic:ResampleResize": TransformLoader()
+        },
+        patch_transforms: dict[str, TransformLoader] = {
+            "default:Normalize:Standardize:Unsqueeze:TensorCast:ResampleIsotropic:ResampleResize": TransformLoader()
+        },
+        is_input: bool = True,
+    ) -> None:
+        self._transforms = transforms
+        self._patch_transforms = patch_transforms
+        self.transforms: list[Transform] = []
+        self.patch_transforms: list[Transform] = []
+        self.is_input = is_input
+
+    def load(self, group_src: str, group_dest: str, datasets: list[Dataset]):
+        if self._transforms is not None:
+            for classpath, transform_loader in self._transforms.items():
+                transform = transform_loader.get_transform(
+                    classpath,
+                    konfai_args=f"{konfai_root()}.Dataset.groups_src.{group_src}.groups_dest.{group_dest}.transforms",
+                )
+                transform.set_datasets(datasets)
+                self.transforms.append(transform)
+
+        if self._patch_transforms is not None:
+            for classpath, transform_loader in self._patch_transforms.items():
+                transform = transform_loader.get_transform(
+                    classpath,
+                    konfai_args=f"{konfai_root()}.Dataset.groups_src.{group_src}"
+                    f".groups_dest.{group_dest}.patch_transforms",
+                )
+                transform.set_datasets(datasets)
+                self.patch_transforms.append(transform)
+
     def to(self, device: int):
-        for transform in self.
-            transform.
-        for transform in self.
-            transform.
+        for transform in self.transforms:
+            transform.to(device)
+        for transform in self.patch_transforms:
+            transform.to(device)
+

 class GroupTransformMetric(GroupTransform):

     @config()
-    def __init__(
+    def __init__(
+        self,
+        transforms: dict[str, TransformLoader] = {
+            "default:Normalize:Standardize:Unsqueeze:TensorCast:ResampleIsotropic:ResampleResize": TransformLoader()
+        },
+    ):
         super().__init__(transforms, None)

+
 class Group(dict[str, GroupTransform]):

     @config()
-    def __init__(
+    def __init__(
+        self,
+        groups_dest: dict[str, GroupTransform] = {"default:group_dest": GroupTransform()},
+    ):
         super().__init__(groups_dest)

+
 class GroupMetric(dict[str, GroupTransformMetric]):

     @config()
-    def __init__(
+    def __init__(
+        self,
+        groups_dest: dict[str, GroupTransformMetric] = {"default:group_dest": GroupTransformMetric()},
+    ):
         super().__init__(groups_dest)

+
 class CustomSampler(Sampler[int]):

     def __init__(self, size: int, shuffle: bool = False) -> None:
@@ -87,135 +116,178 @@ class CustomSampler(Sampler[int]):
         self.shuffle = shuffle

     def __iter__(self) -> Iterator[int]:
-        return iter(torch.randperm(len(self)).tolist() if self.shuffle else list(range(len(self)))
+        return iter(torch.randperm(len(self)).tolist() if self.shuffle else list(range(len(self))))

     def __len__(self) -> int:
         return self.size

+
 class DatasetIter(data.Dataset):

-    def __init__(
+    def __init__(
+        self,
+        rank: int,
+        data: dict[str, list[DatasetManager]],
+        mapping: list[tuple[int, int, int]],
+        groups_src: Mapping[str, Group | GroupMetric],
+        inline_augmentations: bool,
+        data_augmentations_list: list[DataAugmentationsList],
+        patch_size: list[int] | None,
+        overlap: int | None,
+        buffer_size: int,
+        use_cache=True,
+    ) -> None:
         self.rank = rank
         self.data = data
-        self.
+        self.mapping = mapping
         self.patch_size = patch_size
         self.overlap = overlap
         self.groups_src = groups_src
-        self.
+        self.data_augmentations_list = data_augmentations_list
         self.use_cache = use_cache
         self.nb_dataset = len(data[list(data.keys())[0]])
         self.buffer_size = buffer_size
-        self._index_cache =
-        self.
-        self.inlineAugmentations = inlineAugmentations
+        self._index_cache: list[int] = []
+        self.inline_augmentations = inline_augmentations

-    def
+    def get_patch_config(self) -> tuple[list[int] | None, int | None]:
         return self.patch_size, self.overlap
-
+
     def to(self, device: int):
         for group_src in self.groups_src:
             for group_dest in self.groups_src[group_src]:
                 self.groups_src[group_src][group_dest].to(device)
-        self.
+        for data_augmentations in self.data_augmentations_list:
+            for data_augmentation in data_augmentations.data_augmentations:
+                data_augmentation.to(device)

-    def
+    def get_dataset_from_index(self, group_dest: str, index: int) -> DatasetManager:
         return self.data[group_dest][index]
-
-    def
-        if self.
+
+    def reset_augmentation(self, label):
+        if self.inline_augmentations and len(self.data_augmentations_list) > 0:
             for index in range(self.nb_dataset):
                 for group_src in self.groups_src:
                     for group_dest in self.groups_src[group_src]:
-                        self.data[group_dest][index].
-                        self.data[group_dest][index].
+                        self.data[group_dest][index].unload_augmentation()
+                        self.data[group_dest][index].reset_augmentation()
             self.load(label + " Augmentation")

     def load(self, label: str):
         if self.use_cache:
-            memory_init =
+            memory_init = get_memory()

-            indexs =
+            indexs = list(range(self.nb_dataset))
             if len(indexs) > 0:
                 memory_lock = threading.Lock()
-
-
-
-
-
-
+
+                def desc():
+                    return (
+                        f"Caching {label}: "
+                        f"{get_memory_info()} | "
+                        f"{memory_forecast(memory_init, 0, self.nb_dataset)} | "
+                        f"{get_cpu_info()}"
+                    )
+
+                pbar = tqdm.tqdm(total=len(indexs), desc=desc(), leave=False)

                 def process(index):
-                    self.
+                    self._load_data(index)
                     with memory_lock:
                         pbar.set_description(desc())
                         pbar.update(1)
-
+
+                cpu_count = os.cpu_count() or 1
+                with ThreadPoolExecutor(
+                    max_workers=cpu_count // (device_count() if device_count() > 0 else 1)
+                ) as executor:
                     futures = [executor.submit(process, index) for index in indexs]
                     for _ in as_completed(futures):
                         pass

                 pbar.close()
-
-    def
+
+    def _load_data(self, index):
         if index not in self._index_cache:
             self._index_cache.append(index)
             for group_src in self.groups_src:
                 for group_dest in self.groups_src[group_src]:
-                    self.
+                    self.load_data(group_src, group_dest, index)

-    def
-        self.data[group_dest][index].load(
+    def load_data(self, group_src: str, group_dest: str, index: int) -> None:
+        self.data[group_dest][index].load(
+            self.groups_src[group_src][group_dest].transforms,
+            self.data_augmentations_list,
+        )

-    def
+    def _unload_data(self, index: int) -> None:
         if index in self._index_cache:
             self._index_cache.remove(index)
             for group_src in self.groups_src:
                 for group_dest in self.groups_src[group_src]:
-                    self.
-
-    def
+                    self.unload_data(group_dest, index)
+
+    def unload_data(self, group_dest: str, index: int) -> None:
         return self.data[group_dest][index].unload()

     def __len__(self) -> int:
-        return len(self.
+        return len(self.mapping)

-    def __getitem__(self, index
+    def __getitem__(self, index: int) -> dict[str, tuple[torch.Tensor, int, int, int, str, bool]]:
         data = {}
-        x, a, p = self.
+        x, a, p = self.mapping[index]
         if x not in self._index_cache:
             if len(self._index_cache) >= self.buffer_size and not self.use_cache:
-                self.
-            self.
+                self._unload_data(self._index_cache[0])
+            self._load_data(x)

         for group_src in self.groups_src:
             for group_dest in self.groups_src[group_src]:
                 dataset = self.data[group_dest][x]
-                data["{}"
+                data[f"{group_dest}"] = (
+                    dataset.get_data(
+                        p,
+                        a,
+                        self.groups_src[group_src][group_dest].patch_transforms,
+                        self.groups_src[group_src][group_dest].is_input,
+                    ),
+                    x,
+                    a,
+                    p,
+                    dataset.name,
+                    self.groups_src[group_src][group_dest].is_input,
+                )
         return data

-
-
-
+
+class Subset:
+
+    def __init__(
+        self,
+        subset: str | list[int] | list[str] | None = None,
+        shuffle: bool = True,
+    ) -> None:
         self.subset = subset
         self.shuffle = shuffle

-    def __call__(self, names: list[str], infos:
-
-
-            inter_name = inter_name.intersection(set(n))
-        names = sorted(list(inter_name))
-
+    def __call__(self, names: list[str], infos: dict[str, tuple[list[int], Attribute]]) -> set[str]:
+        names = sorted(names)
+
         size = len(names)
         index = []
         if self.subset is None:
             index = list(range(0, size))
         elif isinstance(self.subset, str):
             if ":" in self.subset:
-                r = np.clip(
+                r = np.clip(
+                    np.asarray([int(self.subset.split(":")[0]), int(self.subset.split(":")[1])]),
+                    0,
+                    size,
+                )
                 index = list(range(r[0], r[1]))
             elif os.path.exists(self.subset):
                 train_names = []
-                with open(self.subset
+                with open(self.subset) as f:
                     for name in f:
                         train_names.append(name.strip())
                 index = []
@@ -224,7 +296,7 @@ class Subset():
                         index.append(i)
             elif self.subset.startswith("~") and os.path.exists(self.subset[1:]):
                 exclude_names = []
-                with open(self.subset[1:]
+                with open(self.subset[1:]) as f:
                     for name in f:
                         exclude_names.append(name.strip())
                 index = []
@@ -233,200 +305,283 @@ class Subset():
                         index.append(i)

         elif isinstance(self.subset, list):
+            index = []
             if len(self.subset) > 0:
-
-                if
-                    index
-
-
-
-
-                for i, name in enumerate(names):
-                    if name in self.subset:
-                        index.append(i)
+                for s in self.subset:
+                    if isinstance(s, int):
+                        index.append(s)
+                    elif isinstance(s, str):
+                        for i, name in enumerate(names):
+                            if name in self.subset:
+                                index.append(i)
         if self.shuffle:
-            index = random.sample(index, len(index))
-        return
-
+            index = random.sample(index, len(index))  # nosec B311
+        return {names[i] for i in index}
+
     def __str__(self):
-        return "Subset : " + str(self.subset) + " shuffle : "+ str(self.shuffle)
-
+        return "Subset : " + str(self.subset) + " shuffle : " + str(self.shuffle)
+
+
 class TrainSubset(Subset):

     @config()
-    def __init__(
+    def __init__(
+        self,
+        subset: str | list[int] | list[str] | None = None,
+        shuffle: bool = True,
+    ) -> None:
         super().__init__(subset, shuffle)

+
 class PredictionSubset(Subset):

     @config()
-    def __init__(self, subset:
+    def __init__(self, subset: str | list[int] | list[str] | None = None) -> None:
         super().__init__(subset, False)

+
 class Data(ABC):
-
-
-
-
-
-
-
-
-
-
+
+    @abstractmethod
+    def __init__(
+        self,
+        dataset_filenames: list[str],
+        groups_src: Mapping[str, Group | GroupMetric],
+        patch: DatasetPatch | None,
+        use_cache: bool,
+        subset: Subset,
+        batch_size: int,
+        validation: float | str | list[int] | list[str] | None,
+        inline_augmentations: bool,
+        data_augmentations_list: dict[str, DataAugmentationsList],
+    ) -> None:
         self.dataset_filenames = dataset_filenames
         self.subset = subset
         self.groups_src = groups_src
         self.patch = patch
         self.validation = validation
-        self.
+        self.data_augmentations_list = data_augmentations_list
         self.batch_size = batch_size
-
-        self.
-
-
+
+        self.datasetIter = partial(
+            DatasetIter,
+            groups_src=self.groups_src,
+            inline_augmentations=inline_augmentations,
+            data_augmentations_list=list(self.data_augmentations_list.values()),
+            patch_size=self.patch.patch_size if self.patch is not None else None,
+            overlap=self.patch.overlap if self.patch is not None else None,
+            buffer_size=batch_size + 1,
+            use_cache=use_cache,
+        )
+        self.dataLoader_args = {
+            "num_workers": int(os.environ["KONFAI_WORKERS"]) if use_cache else 0,
+            "pin_memory": True,
+        }
+        self.data: list[list[dict[str, list[DatasetManager]]]] = []
+        self.mapping: list[list[list[tuple[int, int, int]]]] = []
         self.datasets: dict[str, Dataset] = {}

-    def
+    def _get_datasets(
+        self, names: list[str], dataset_name: dict[str, dict[str, list[str]]]
+    ) -> tuple[dict[str, list[DatasetManager]], list[tuple[int, int, int]]]:
         nb_dataset = len(names)
-        nb_patch
+        nb_patch: list[list[int]]
         data = {}
-
-        nb_augmentation = np.max(
+        mapping = []
+        nb_augmentation = np.max(
+            [
+                int(np.sum([data_augmentation.nb for data_augmentation in self.data_augmentations_list.values()]) + 1),
+                1,
+            ]
+        )
         for group_src in self.groups_src:
             for group_dest in self.groups_src[group_src]:
-                data[group_dest] = [
-
+                data[group_dest] = [
+                    DatasetManager(
+                        i,
+                        group_src,
+                        group_dest,
+                        name,
+                        self.datasets[
+                            [filename for filename, names in dataset_name[group_src].items() if name in names][0]
+                        ],
+                        patch=self.patch,
+                        transforms=self.groups_src[group_src][group_dest].transforms,
+                        data_augmentations_list=list(self.data_augmentations_list.values()),
+                    )
+                    for i, name in enumerate(names)
+                ]
+                nb_patch = [[dataset.get_size(a) for a in range(nb_augmentation)] for dataset in data[group_dest]]

         for x in range(nb_dataset):
             for y in range(nb_augmentation):
                 for z in range(nb_patch[x][y]):
-
-        return data,
+                    mapping.append((x, y, z))
+        return data, mapping

-    def
-
+    def get_groups_dest(self):
+        groups_dest = []
         for group_src in self.groups_src:
             for group_dest in self.groups_src[group_src]:
-
-        return
-
-
-
+                groups_dest.append(group_dest)
+        return groups_dest
+
+    @staticmethod
+    def _split(mapping: list[tuple[int, int, int]], world_size: int) -> list[list[tuple[int, int, int]]]:
+        if len(mapping) == 0:
             return [[] for _ in range(world_size)]
-
-
-        if
-
-            unique_index = np.unique(
-            offset = int(np.ceil(len(unique_index)/world_size))
+
+        mappings = []
+        if konfai_state() == str(State.PREDICTION) or konfai_state() == str(State.EVALUATION):
+            np_mapping = np.asarray(mapping)
+            unique_index = np.unique(np_mapping[:, 0])
+            offset = int(np.ceil(len(unique_index) / world_size))
             if offset == 0:
                 offset = 1
             for itr in range(0, len(unique_index), offset):
-
+                mappings.append(
+                    [
+                        tuple(v)
+                        for v in np_mapping[
+                            np.where(np.isin(np_mapping[:, 0], unique_index[itr : itr + offset]))[0],
+                            :,
+                        ]
+                    ]
+                )
         else:
-            offset = int(np.ceil(len(
+            offset = int(np.ceil(len(mapping) / world_size))
             if offset == 0:
                 offset = 1
-            for itr in range(0, len(
-
-        return
-
-    def
-        datasets: dict[str, list[
+            for itr in range(0, len(mapping), offset):
+                mappings.append(list(mapping[-offset:]) if itr + offset > len(mapping) else mapping[itr : itr + offset])
+        return mappings
+
+    def get_data(self, world_size: int) -> list[list[DataLoader]]:
+        datasets: dict[str, list[tuple[str, bool]]] = {}
         if self.dataset_filenames is None or len(self.dataset_filenames) == 0:
             raise DatasetManagerError("No dataset filenames were provided")
         for dataset_filename in self.dataset_filenames:
             if dataset_filename is None:
-                raise DatasetManagerError(
-                    "
-                    "
+                raise DatasetManagerError(
+                    "Invalid dataset entry: 'None' received.",
+                    "Each dataset must be a valid path string (e.g., './Dataset/', './Dataset/:mha, "
+                    "'./Dataset/:a:mha', './Dataset/:i:mha').",
+                    "Please check your 'dataset_filenames' list for missing or null entries.",
                 )
             if len(dataset_filename.split(":")) == 1:
                 filename = dataset_filename
-
+                file_format = "mha"
                 append = True
             elif len(dataset_filename.split(":")) == 2:
-                filename,
+                filename, file_format = dataset_filename.split(":")
                 append = True
             else:
-                filename, flag,
+                filename, flag, file_format = dataset_filename.split(":")
                 append = flag == "a"

-            if
-                raise DatasetManagerError(
-
-
-
+            if file_format not in SUPPORTED_EXTENSIONS:
+                raise DatasetManagerError(
+                    f"Unsupported file format '{file_format}'.",
+                    f"Supported extensions are: {', '.join(SUPPORTED_EXTENSIONS)}",
+                )
+
+            dataset = Dataset(filename, file_format)

             self.datasets[filename] = dataset
             for group in self.groups_src:
-                if dataset.
+                if dataset.is_group_exist(group):
                     if group in datasets:
-                        datasets[group].append((filename, append))
+                        datasets[group].append((filename, append))
                     else:
                         datasets[group] = [(filename, append)]
-
+        model_have_input = False
         for group_src in self.groups_src:
             if group_src not in datasets:
-
+
                 raise DatasetManagerError(
                     f"Group source '{group_src}' not found in any dataset.",
                     f"Dataset filenames provided: {self.dataset_filenames}",
-                    "Available groups across all datasets:
-                    f
+                    f"Available groups across all datasets: "
+                    "{[f'{f} {d.get_group()}' for f, d in self.datasets.items()]}"
+                    f"Please check that an entry in the dataset with the name '{group_src}.{format}' exists.",
                 )
-
+
             for group_dest in self.groups_src[group_src]:
-                self.groups_src[group_src][group_dest].load(
-
+                self.groups_src[group_src][group_dest].load(
+                    group_src,
+                    group_dest,
+                    [self.datasets[filename] for filename, _ in datasets[group_src]],
+                )
+                model_have_input |= self.groups_src[group_src][group_dest].is_input
+        if self.patch is not None:
+            self.patch.init()

-        if not
+        if not model_have_input:
             raise DatasetManagerError(
-                "At least one group must be defined with '
+                "At least one group must be defined with 'is_input: true' to provide input to the network."
             )

-        for key,
-
+        for key, data_augmentations in self.data_augmentations_list.items():
+            data_augmentations.load(key, [self.datasets[filename] for filename, _ in datasets[group_src]])

-        names = set()
-        dataset_name
-        dataset_info
+        names: set[str] = set()
+        dataset_name: dict[str, dict[str, list[str]]] = {}
+        dataset_info: dict[str, dict[str, dict[str, tuple[list[int], Attribute]]]] = {}
         for group in self.groups_src:
-
+            names_by_group = set()
             if group not in dataset_name:
                 dataset_name[group] = {}
                 dataset_info[group] = {}
             for filename, _ in datasets[group]:
-
-                dataset_name[group][filename] = self.datasets[filename].
-                dataset_info[group][filename] = {
+                names_by_group.update(self.datasets[filename].get_names(group))
+                dataset_name[group][filename] = self.datasets[filename].get_names(group)
+                dataset_info[group][filename] = {
+                    name: self.datasets[filename].get_infos(group, name) for name in dataset_name[group][filename]
+                }
             if len(names) == 0:
-                names.update(
-            else:
-                names = names.intersection(
+                names.update(names_by_group)
+            else:
+                names = names.intersection(names_by_group)
         if len(names) == 0:
-
-            f"No data was found for groups {list(self.groups_src.keys())}: although each group contains data
-
-
-
+            raise DatasetManagerError(
+                f"No data was found for groups {list(self.groups_src.keys())}: although each group contains data "
+                "from a dataset, there are no common dataset names shared across all groups, the intersection is empty."
+            )
+
+        subset_names: set[str] = set()
         for group in dataset_name:
-            subset_names_bygroup = set()
+            subset_names_bygroup: set[str] = set()
             for filename, append in datasets[group]:
                 if append:
-                    subset_names_bygroup.update(
+                    subset_names_bygroup.update(
+                        self.subset(
+                            dataset_name[group][filename],
+                            dataset_info[group][filename],
+                        )
+                    )
                 else:
                     if len(subset_names_bygroup) == 0:
-                        subset_names_bygroup.update(
+                        subset_names_bygroup.update(
+                            self.subset(
+                                dataset_name[group][filename],
+                                dataset_info[group][filename],
+                            )
+                        )
                     else:
-                        subset_names_bygroup = subset_names_bygroup.intersection(
+                        subset_names_bygroup = subset_names_bygroup.intersection(
+                            self.subset(
+                                dataset_name[group][filename],
+                                dataset_info[group][filename],
+                            )
+                        )
             if len(subset_names) == 0:
                 subset_names.update(subset_names_bygroup)
-            else:
+            else:
                 subset_names = subset_names.intersection(subset_names_bygroup)
+
         if len(subset_names) == 0:
-            raise DatasetManagerError(
+            raise DatasetManagerError(
+                "All data entries were excluded by the subset filter.",
                 f"Dataset entries found: {', '.join(names)}",
                 f"Subset object applied: {self.subset}",
                 f"Subset requested : {', '.join(subset_names)}",
@@ -436,31 +591,38 @@ class Data(ABC):
                 "\tsubset: [0, 1] # explicit indices",
                 "\tsubset: 0:10 # slice notation",
                 "\tsubset: ./Validation.txt # external file",
-                "\tsubset: None # to disable filtering"
+                "\tsubset: None # to disable filtering",
             )
-
-        data, map = self._getDatasets(list(subset_names), dataset_name)

-
-
+        data, mapping = self._get_datasets(list(subset_names), dataset_name)
+
+        train_mapping = mapping
+        validate_mapping = []
         if isinstance(self.validation, float) or isinstance(self.validation, int):
             if self.validation <= 0 or self.validation >= 1:
-                raise DatasetManagerError(
-
-
+                raise DatasetManagerError(
+                    "Validation must be a float between 0 and 1.",
+                    f"Received: {self.validation}",
+                    "Example: validation = 0.2 # for a 20% validation split",
+                )
+
+            train_mapping, validate_mapping = (
+                mapping[: int(math.floor(len(mapping) * (1 - self.validation)))],
+                mapping[int(math.floor(len(mapping) * (1 - self.validation))) :],
+            )
         elif isinstance(self.validation, str):
             if ":" in self.validation:
-                index = list(range(int(self.
-
-
+                index = list(range(int(self.validation.split(":")[0]), int(self.validation.split(":")[1])))
+                train_mapping = [m for m in mapping if m[0] not in index]
+                validate_mapping = [m for m in mapping if m[0] in index]
             elif os.path.exists(self.validation):
                 validation_names = []
-                with open(self.validation
+                with open(self.validation) as f:
                     for name in f:
                         validation_names.append(name.strip())
                 index = [i for i, n in enumerate(subset_names) if n in validation_names]
-
-
+                train_mapping = [m for m in mapping if m[0] not in index]
+                validate_mapping = [m for m in mapping if m[0] in index]
             else:
                 raise DatasetManagerError(
                     f"Invalid string value for 'validation': '{self.validation}'",
@@ -470,94 +632,151 @@ class Data(ABC):
                     "\t• A float between 0 and 1 (e.g., 0.2)",
                     "\t• A list of sample names or indices",
                     "The provided value is neither a valid slice nor a readable file.",
-                    "Please fix your 'validation' setting in the configuration."
-
-
+                    "Please fix your 'validation' setting in the configuration.",
+                )
+
         elif isinstance(self.validation, list):
             if len(self.validation) > 0:
                 if isinstance(self.validation[0], int):
-
-
+                    train_mapping = [m for m in mapping if m[0] not in self.validation]
+                    validate_mapping = [m for m in mapping if m[0] in self.validation]
                 elif isinstance(self.validation[0], str):
                     index = [i for i, n in enumerate(subset_names) if n in self.validation]
-
-
+                    train_mapping = [m for m in mapping if m[0] not in index]
+                    validate_mapping = [m for m in mapping if m[0] in index]
                 else:
-                    raise DatasetManagerError(
-
-
-
-
-                    )
-
-
-
+                    raise DatasetManagerError(
+                        "Invalid list type for 'validation': elements of type "
+                        f"'{type(self.validation[0]).__name__}' are not supported.",
+                        "Supported list element types are:",
+                        "\t• int → list of indices (e.g., [0, 1, 2])",
+                        "\t• str → list of sample names (e.g., ['patient01', 'patient02'])",
+                        f"Received list: {self.validation}",
+                    )
+        if len(train_mapping) == 0:
+            raise DatasetManagerError(
+                "No data left for training after applying the validation split.",
+                f"Dataset size: {len(mapping)}",
                 f"Validation setting: {self.validation}",
-                "Please reduce the validation size, increase the dataset, or disable validation."
+                "Please reduce the validation size, increase the dataset, or disable validation.",
            )

-        if self.validation is not None and len(
-            raise DatasetManagerError(
-
+        if self.validation is not None and len(validate_mapping) == 0:
+            raise DatasetManagerError(
+                "No data left for validation after applying the validation split.",
+                f"Dataset size: {len(mapping)}",
                 f"Validation setting: {self.validation}",
-                "Please increase the validation size, increase the dataset, or disable validation."
+                "Please increase the validation size, increase the dataset, or disable validation.",
             )
-
-
-
-        for i, (
-
-            if len(
-
+        train_mappings = Data._split(train_mapping, world_size)
+        validate_mappings = Data._split(validate_mapping, world_size)
+
+        for i, (train_mapping, validate_mapping) in enumerate(zip(train_mappings, validate_mappings)):
+            mappings = [train_mapping]
+            if len(validate_mapping):
+                mappings += [validate_mapping]
             self.data.append([])
-            self.
-            for
-                indexs = np.unique(np.asarray(
-                self.data[i].append({k:[v[it] for it in indexs] for k, v in data.items()})
-
+            self.mapping.append([])
+            for mapping_tmp in mappings:
+                indexs = np.unique(np.asarray(mapping_tmp)[:, 0])
+                self.data[i].append({k: [v[it] for it in indexs] for k, v in data.items()})
+                mapping_tmp_array = np.asarray(mapping_tmp)
                 for a, b in enumerate(indexs):
-
-                self.
+                    mapping_tmp_array[np.where(np.asarray(mapping_tmp_array)[:, 0] == b), 0] = a
+                self.mapping[i].append([(a, b, c) for a, b, c in mapping_tmp_array])
+
+        data_loaders: list[list[DataLoader]] = []
+        for i, (datas, mappings) in enumerate(zip(self.data, self.mapping)):
+            data_loaders.append([])
+            for data, mapping in zip(datas, mappings):
+                data_loaders[i].append(
+                    DataLoader(
+                        dataset=self.datasetIter(
+                            rank=i,
+                            data=data,
+                            mapping=mapping,
+                        ),
+                        sampler=CustomSampler(len(mapping), self.subset.shuffle),
+                        batch_size=self.batch_size,
+                        **self.dataLoader_args,
+                    )
+                )
+        return data_loaders

-        dataLoaders: list[list[DataLoader]] = []
-        for i, (datas, maps) in enumerate(zip(self.data, self.map)):
-            dataLoaders.append([])
-            for data, map in zip(datas, maps):
-                dataLoaders[i].append(DataLoader(dataset=DatasetIter(rank=i, data=data, map=map, **self.dataSet_args), sampler=CustomSampler(len(map), self.subset.shuffle), batch_size=self.batch_size,**self.dataLoader_args))
-        return dataLoaders

 class DataTrain(Data):

     @config("Dataset")
-    def __init__(
-
-
-
-
-
-
-
-
-
+    def __init__(
+        self,
+        dataset_filenames: list[str] = ["default:./Dataset"],
+        groups_src: dict[str, Group] = {"default:group_src": Group()},
+        augmentations: dict[str, DataAugmentationsList] | None = {"DataAugmentation_0": DataAugmentationsList()},
+        inline_augmentations: bool = False,
+        patch: DatasetPatch | None = DatasetPatch(),
+        use_cache: bool = True,
+        subset: TrainSubset = TrainSubset(),
+        batch_size: int = 1,
+        validation: float | str | list[int] | list[str] = 0.2,
+    ) -> None:
+        super().__init__(
+            dataset_filenames,
+            groups_src,
+            patch,
+            use_cache,
+            subset,
+            batch_size,
+            validation,
+            inline_augmentations,
+            augmentations if augmentations else {},
+        )
+

 class DataPrediction(Data):

     @config("Dataset")
-    def __init__(
-
-
-
-
-
+    def __init__(
+        self,
+        dataset_filenames: list[str] = ["default:./Dataset"],
+        groups_src: dict[str, Group] = {"default": Group()},
+        augmentations: dict[str, DataAugmentationsList] | None = {"DataAugmentation_0": DataAugmentationsList()},
+        patch: DatasetPatch | None = DatasetPatch(),
+        subset: PredictionSubset = PredictionSubset(),
+        batch_size: int = 1,
+    ) -> None:
+
+        super().__init__(
+            dataset_filenames=dataset_filenames,
+            groups_src=groups_src,
+            patch=patch,
+            use_cache=False,
+            subset=subset,
+            batch_size=batch_size,
+            validation=None,
+            inline_augmentations=False,
+            data_augmentations_list=augmentations if augmentations else {},
+        )

-        super().__init__(dataset_filenames, groups_src, patch, False, subset, batch_size, dataAugmentationsList=augmentations if augmentations else {})

 class DataMetric(Data):

     @config("Dataset")
-    def __init__(
-
-
-
-
-
+    def __init__(
+        self,
+        dataset_filenames: list[str] = ["default:./Dataset"],
+        groups_src: dict[str, GroupMetric] = {"default": GroupMetric()},
+        subset: PredictionSubset = PredictionSubset(),
+        validation: str | None = None,
+    ) -> None:
+
+        super().__init__(
+            dataset_filenames=dataset_filenames,
+            groups_src=groups_src,
+            patch=None,
+            use_cache=False,
+            subset=subset,
+            batch_size=1,
+            validation=validation,
+            data_augmentations_list={},
+            inline_augmentations=False,
+        )