konfai 1.1.8-py3-none-any.whl → 1.2.0-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
- konfai/__init__.py +59 -14
- konfai/data/augmentation.py +457 -286
- konfai/data/data_manager.py +533 -316
- konfai/data/patching.py +300 -183
- konfai/data/transform.py +408 -275
- konfai/evaluator.py +325 -68
- konfai/main.py +71 -22
- konfai/metric/measure.py +360 -244
- konfai/metric/schedulers.py +24 -13
- konfai/models/classification/convNeXt.py +187 -81
- konfai/models/classification/resnet.py +272 -58
- konfai/models/generation/cStyleGan.py +233 -59
- konfai/models/generation/ddpm.py +348 -121
- konfai/models/generation/diffusionGan.py +757 -358
- konfai/models/generation/gan.py +177 -53
- konfai/models/generation/vae.py +140 -40
- konfai/models/registration/registration.py +135 -52
- konfai/models/representation/representation.py +57 -23
- konfai/models/segmentation/NestedUNet.py +339 -68
- konfai/models/segmentation/UNet.py +140 -30
- konfai/network/blocks.py +331 -187
- konfai/network/network.py +795 -427
- konfai/predictor.py +644 -238
- konfai/trainer.py +509 -222
- konfai/utils/ITK.py +191 -106
- konfai/utils/config.py +152 -95
- konfai/utils/dataset.py +326 -455
- konfai/utils/utils.py +497 -249
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/METADATA +1 -3
- konfai-1.2.0.dist-info/RECORD +38 -0
- konfai/utils/registration.py +0 -199
- konfai-1.1.8.dist-info/RECORD +0 -39
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/WHEEL +0 -0
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/entry_points.txt +0 -0
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/licenses/LICENSE +0 -0
- {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/top_level.txt +0 -0
konfai/data/data_manager.py
CHANGED
@@ -1,85 +1,115 @@
 import math
 import os
 import random
+import threading
+from abc import ABC, abstractmethod
+from collections.abc import Iterator, Mapping
+from concurrent.futures import ThreadPoolExecutor, as_completed
+from functools import partial
+from typing import cast
+
+import numpy as np
 import torch
-from torch.utils import data
 import tqdm
-import numpy as np
-from abc import ABC
-from torch.utils.data import DataLoader, Sampler
-from typing import Union, Iterator
-from concurrent.futures import ThreadPoolExecutor, as_completed
-import threading
 from torch.cuda import device_count
+from torch.utils import data
+from torch.utils.data import DataLoader, Sampler
 
-from konfai import
-from konfai.data.patching import DatasetPatch, DatasetManager
-from konfai.utils.config import config
-from konfai.utils.utils import memoryInfo, cpuInfo, memoryForecast, getMemory, State, SUPPORTED_EXTENSIONS, DatasetManagerError
-from konfai.utils.dataset import Dataset, Attribute
-from konfai.data.transform import TransformLoader, Transform
+from konfai import konfai_root, konfai_state
 from konfai.data.augmentation import DataAugmentationsList
+from konfai.data.patching import DatasetManager, DatasetPatch
+from konfai.data.transform import Transform, TransformLoader
+from konfai.utils.config import config
+from konfai.utils.dataset import Attribute, Dataset
+from konfai.utils.utils import (
+    SUPPORTED_EXTENSIONS,
+    DatasetManagerError,
+    State,
+    get_cpu_info,
+    get_memory,
+    get_memory_info,
+    memory_forecast,
+)
+
 
 class GroupTransform:
 
     @config()
-    def __init__(
+    def __init__(
+        self,
+        transforms: dict[str, TransformLoader] = {
+            "default:Normalize:Standardize:Unsqueeze:TensorCast:ResampleIsotropic:ResampleResize": TransformLoader()
+        },
+        patch_transforms: dict[str, TransformLoader] = {
+            "default:Normalize:Standardize:Unsqueeze:TensorCast:ResampleIsotropic:ResampleResize": TransformLoader()
+        },
+        is_input: bool = True,
+    ) -> None:
+        self._transforms = transforms
+        self._patch_transforms = patch_transforms
+        self.transforms: list[Transform] = []
+        self.patch_transforms: list[Transform] = []
+        self.is_input = is_input
+
+    def load(self, group_src: str, group_dest: str, datasets: list[Dataset]):
+        if self._transforms is not None:
+            for classpath, transform_loader in self._transforms.items():
+                transform = transform_loader.get_transform(
+                    classpath,
+                    konfai_args=f"{konfai_root()}.Dataset.groups_src.{group_src}.groups_dest.{group_dest}.transforms",
+                )
+                transform.set_datasets(datasets)
+                self.transforms.append(transform)
+
+        if self._patch_transforms is not None:
+            for classpath, transform_loader in self._patch_transforms.items():
+                transform = transform_loader.get_transform(
+                    classpath,
+                    konfai_args=f"{konfai_root()}.Dataset.groups_src.{group_src}"
+                    f".groups_dest.{group_dest}.patch_transforms",
+                )
+                transform.set_datasets(datasets)
+                self.patch_transforms.append(transform)
+
     def to(self, device: int):
-        for transform in self.
-            transform.
-        for transform in self.
-            transform.
+        for transform in self.transforms:
+            transform.to(device)
+        for transform in self.patch_transforms:
+            transform.to(device)
+
 
 class GroupTransformMetric(GroupTransform):
 
     @config()
-    def __init__(
+    def __init__(
+        self,
+        transforms: dict[str, TransformLoader] = {
+            "default:Normalize:Standardize:Unsqueeze:TensorCast:ResampleIsotropic:ResampleResize": TransformLoader()
+        },
+    ):
         super().__init__(transforms, None)
 
+
 class Group(dict[str, GroupTransform]):
 
     @config()
-    def __init__(
+    def __init__(
+        self,
+        groups_dest: dict[str, GroupTransform] = {"default:group_dest": GroupTransform()},
+    ):
         super().__init__(groups_dest)
 
+
 class GroupMetric(dict[str, GroupTransformMetric]):
 
     @config()
-    def __init__(
+    def __init__(
+        self,
+        groups_dest: dict[str, GroupTransformMetric] = {"default:group_dest": GroupTransformMetric()},
+    ):
         super().__init__(groups_dest)
 
+
 class CustomSampler(Sampler[int]):
 
     def __init__(self, size: int, shuffle: bool = False) -> None:
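Note: the rewritten import block above also records a helper rename in konfai.utils.utils, the camelCase functions imported by 1.1.8 disappear and snake_case equivalents are imported instead. The pairing below is inferred only from the removed and added imports in this hunk and their call sites later in the diff, not from a documented migration table; OLD_TO_NEW and port_line are illustrative names, not konfai API.

    # Apparent helper renames between konfai 1.1.8 and 1.2.0, inferred from this diff.
    OLD_TO_NEW = {
        "memoryInfo": "get_memory_info",
        "cpuInfo": "get_cpu_info",
        "memoryForecast": "memory_forecast",
        "getMemory": "get_memory",
    }

    def port_line(line: str) -> str:
        """Rewrite old helper names in a line of user code (illustration only)."""
        for old, new in OLD_TO_NEW.items():
            line = line.replace(old, new)
        return line

    print(port_line("print(memoryInfo(), cpuInfo())"))
    # -> print(get_memory_info(), get_cpu_info())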
@@ -87,346 +117,477 @@ class CustomSampler(Sampler[int]):
         self.shuffle = shuffle
 
     def __iter__(self) -> Iterator[int]:
-        return iter(torch.randperm(len(self)).tolist() if self.shuffle else list(range(len(self)))
+        return iter(torch.randperm(len(self)).tolist() if self.shuffle else list(range(len(self))))
 
     def __len__(self) -> int:
         return self.size
 
+
 class DatasetIter(data.Dataset):
 
-    def __init__(
+    def __init__(
+        self,
+        rank: int,
+        data: dict[str, list[DatasetManager]],
+        mapping: list[tuple[int, int, int]],
+        groups_src: Mapping[str, Group | GroupMetric],
+        inline_augmentations: bool,
+        data_augmentations_list: list[DataAugmentationsList],
+        patch_size: list[int] | None,
+        overlap: int | None,
+        buffer_size: int,
+        use_cache=True,
+    ) -> None:
         self.rank = rank
         self.data = data
-        self.
+        self.mapping = mapping
         self.patch_size = patch_size
         self.overlap = overlap
         self.groups_src = groups_src
-        self.
+        self.data_augmentations_list = data_augmentations_list
         self.use_cache = use_cache
         self.nb_dataset = len(data[list(data.keys())[0]])
         self.buffer_size = buffer_size
-        self._index_cache =
-        self.
-        self.inlineAugmentations = inlineAugmentations
+        self._index_cache: list[int] = []
+        self.inline_augmentations = inline_augmentations
 
-    def
+    def get_patch_config(self) -> tuple[list[int] | None, int | None]:
         return self.patch_size, self.overlap
+
     def to(self, device: int):
         for group_src in self.groups_src:
             for group_dest in self.groups_src[group_src]:
                 self.groups_src[group_src][group_dest].to(device)
-        self.
+        for data_augmentations in self.data_augmentations_list:
+            for data_augmentation in data_augmentations.data_augmentations:
+                data_augmentation.to(device)
 
-    def
+    def get_dataset_from_index(self, group_dest: str, index: int) -> DatasetManager:
         return self.data[group_dest][index]
-    def
-        if self.
+
+    def reset_augmentation(self, label):
+        if self.inline_augmentations and len(self.data_augmentations_list) > 0:
             for index in range(self.nb_dataset):
                 for group_src in self.groups_src:
                     for group_dest in self.groups_src[group_src]:
-                        self.data[group_dest][index].
-                        self.data[group_dest][index].
+                        self.data[group_dest][index].unload_augmentation()
+                        self.data[group_dest][index].reset_augmentation()
             self.load(label + " Augmentation")
 
     def load(self, label: str):
         if self.use_cache:
-            memory_init =
+            memory_init = get_memory()
 
-        indexs =
+        indexs = list(range(self.nb_dataset))
         if len(indexs) > 0:
             memory_lock = threading.Lock()
+
+            def desc():
+                return (
+                    f"Caching {label}: "
+                    f"{get_memory_info()} | "
+                    f"{memory_forecast(memory_init, 0, self.nb_dataset)} | "
+                    f"{get_cpu_info()}"
+                )
+
+            pbar = tqdm.tqdm(total=len(indexs), desc=desc(), leave=False)
 
             def process(index):
-                self.
+                self._load_data(index)
                 with memory_lock:
                     pbar.set_description(desc())
                     pbar.update(1)
+
+            cpu_count = os.cpu_count() or 1
+            with ThreadPoolExecutor(
+                max_workers=cpu_count // (device_count() if device_count() > 0 else 1)
+            ) as executor:
                 futures = [executor.submit(process, index) for index in indexs]
                 for _ in as_completed(futures):
                     pass
 
             pbar.close()
-    def
+
+    def _load_data(self, index):
         if index not in self._index_cache:
             self._index_cache.append(index)
             for group_src in self.groups_src:
                 for group_dest in self.groups_src[group_src]:
-                    self.
+                    self.load_data(group_src, group_dest, index)
 
-    def
-        self.data[group_dest][index].load(
+    def load_data(self, group_src: str, group_dest: str, index: int) -> None:
+        self.data[group_dest][index].load(
+            self.groups_src[group_src][group_dest].transforms,
+            self.data_augmentations_list,
+        )
 
-    def
+    def _unload_data(self, index: int) -> None:
         if index in self._index_cache:
             self._index_cache.remove(index)
             for group_src in self.groups_src:
                 for group_dest in self.groups_src[group_src]:
-                    self.
-    def
+                    self.unload_data(group_dest, index)
+
+    def unload_data(self, group_dest: str, index: int) -> None:
         return self.data[group_dest][index].unload()
 
     def __len__(self) -> int:
-        return len(self.
+        return len(self.mapping)
 
-    def __getitem__(self, index
+    def __getitem__(self, index: int) -> dict[str, tuple[torch.Tensor, int, int, int, str, bool]]:
         data = {}
-        x, a, p = self.
+        x, a, p = self.mapping[index]
         if x not in self._index_cache:
             if len(self._index_cache) >= self.buffer_size and not self.use_cache:
-                self.
-            self.
+                self._unload_data(self._index_cache[0])
+            self._load_data(x)
 
         for group_src in self.groups_src:
             for group_dest in self.groups_src[group_src]:
                 dataset = self.data[group_dest][x]
-                data["{}"
+                data[f"{group_dest}"] = (
+                    dataset.get_data(
+                        p,
+                        a,
+                        self.groups_src[group_src][group_dest].patch_transforms,
+                        self.groups_src[group_src][group_dest].is_input,
+                    ),
+                    x,
+                    a,
+                    p,
+                    dataset.name,
+                    self.groups_src[group_src][group_dest].is_input,
+                )
         return data
 
+
+class Subset:
+
+    def __init__(
+        self,
+        subset: str | list[int] | list[str] | None = None,
+        shuffle: bool = True,
+    ) -> None:
         self.subset = subset
         self.shuffle = shuffle
 
-    def
-        inter_name = set(names[0])
-        for n in names[1:]:
-            inter_name = inter_name.intersection(set(n))
-        names = sorted(list(inter_name))
+    def _get_index(self, subset: str | int, names: list[str]) -> list[int]:
         size = len(names)
         index = []
+        if isinstance(subset, int):
+            index.append(subset)
+        elif ":" in subset:
+            r = np.clip(
+                np.asarray([int(subset.split(":")[0]), int(subset.split(":")[1])]),
+                0,
+                size,
+            )
+            index = list(range(r[0], r[1]))
+        elif os.path.exists(subset):
+            train_names = []
+            with open(subset) as f:
+                for name in f:
+                    train_names.append(name.strip())
+            index = []
+            for i, name in enumerate(names):
+                if name in train_names:
+                    index.append(i)
+        elif subset.startswith("~") and os.path.exists(subset[1:]):
+            exclude_names = []
+            with open(subset[1:]) as f:
+                for name in f:
+                    exclude_names.append(name.strip())
+            index = []
+            for i, name in enumerate(names):
+                if name not in exclude_names:
+                    index.append(i)
+        return index
+
+    def __call__(self, names: list[str], infos: dict[str, tuple[list[int], Attribute]]) -> set[str]:
+        names = sorted(names)
+        size = len(names)
+
         if self.subset is None:
             index = list(range(0, size))
-        elif isinstance(self.subset, str):
-            if ":" in self.subset:
-                r = np.clip(np.asarray([int(self.subset.split(":")[0]), int(self.subset.split(":")[1])]), 0, size)
-                index = list(range(r[0], r[1]))
-            elif os.path.exists(self.subset):
-                train_names = []
-                with open(self.subset, "r") as f:
-                    for name in f:
-                        train_names.append(name.strip())
-                index = []
-                for i, name in enumerate(names):
-                    if name in train_names:
-                        index.append(i)
-            elif self.subset.startswith("~") and os.path.exists(self.subset[1:]):
-                exclude_names = []
-                with open(self.subset[1:], "r") as f:
-                    for name in f:
-                        exclude_names.append(name.strip())
-                index = []
-                for i, name in enumerate(names):
-                    if name not in exclude_names:
-                        index.append(i)
         elif isinstance(self.subset, list):
-                index.append(i)
+            index_set: set[int] = set()
+            for s in self.subset:
+                if len(index_set) == 0:
+                    index_set.update(set(self._get_index(s, names)))
+                else:
+                    index_set = index_set.intersection(set(self._get_index(s, names)))
+            index = list(index_set)
+            print(index)
+        else:
+            index = self._get_index(self.subset, names)
         if self.shuffle:
-            index = random.sample(index, len(index))
-        return
+            index = random.sample(index, len(index))  # nosec B311
+        return {names[i] for i in index}
+
     def __str__(self):
-        return "Subset : " + str(self.subset) + " shuffle : "+ str(self.shuffle)
+        return "Subset : " + str(self.subset) + " shuffle : " + str(self.shuffle)
+
+
 class TrainSubset(Subset):
 
     @config()
-    def __init__(
+    def __init__(
+        self,
+        subset: str | list[int] | list[str] | None = None,
+        shuffle: bool = True,
+    ) -> None:
         super().__init__(subset, shuffle)
 
+
 class PredictionSubset(Subset):
 
     @config()
-    def __init__(self, subset:
+    def __init__(self, subset: str | list[int] | list[str] | None = None) -> None:
         super().__init__(subset, False)
 
+
 class Data(ABC):
+
+    @abstractmethod
+    def __init__(
+        self,
+        dataset_filenames: list[str],
+        groups_src: Mapping[str, Group | GroupMetric],
+        patch: DatasetPatch | None,
+        use_cache: bool,
+        subset: Subset,
+        batch_size: int,
+        validation: float | str | list[int] | list[str] | None,
+        inline_augmentations: bool,
+        data_augmentations_list: dict[str, DataAugmentationsList],
+    ) -> None:
         self.dataset_filenames = dataset_filenames
         self.subset = subset
         self.groups_src = groups_src
         self.patch = patch
         self.validation = validation
-        self.
+        self.data_augmentations_list = data_augmentations_list
         self.batch_size = batch_size
-        self.
+
+        self.datasetIter = partial(
+            DatasetIter,
+            groups_src=self.groups_src,
+            inline_augmentations=inline_augmentations,
+            data_augmentations_list=list(self.data_augmentations_list.values()),
+            patch_size=self.patch.patch_size if self.patch is not None else None,
+            overlap=self.patch.overlap if self.patch is not None else None,
+            buffer_size=batch_size + 1,
+            use_cache=use_cache,
+        )
+        self.dataLoader_args = {
+            "num_workers": int(os.environ["KONFAI_WORKERS"]) if use_cache else 0,
+            "pin_memory": True,
+        }
+        self.data: list[list[dict[str, list[DatasetManager]]]] = []
+        self.mapping: list[list[list[tuple[int, int, int]]]] = []
         self.datasets: dict[str, Dataset] = {}
 
-    def
+    def _get_datasets(
+        self, names: list[str], dataset_name: dict[str, dict[str, list[str]]]
+    ) -> tuple[dict[str, list[DatasetManager]], list[tuple[int, int, int]]]:
         nb_dataset = len(names)
-        nb_patch
+        nb_patch: list[list[int]]
         data = {}
-        nb_augmentation = np.max(
+        mapping = []
+        nb_augmentation = np.max(
+            [
+                int(np.sum([data_augmentation.nb for data_augmentation in self.data_augmentations_list.values()]) + 1),
+                1,
+            ]
+        )
         for group_src in self.groups_src:
             for group_dest in self.groups_src[group_src]:
-                data[group_dest] = [
+                data[group_dest] = [
+                    DatasetManager(
+                        i,
+                        group_src,
+                        group_dest,
+                        name,
+                        self.datasets[
+                            [filename for filename, names in dataset_name[group_src].items() if name in names][0]
+                        ],
+                        patch=self.patch,
+                        transforms=self.groups_src[group_src][group_dest].transforms,
+                        data_augmentations_list=list(self.data_augmentations_list.values()),
+                    )
+                    for i, name in enumerate(names)
+                ]
+                nb_patch = [[dataset.get_size(a) for a in range(nb_augmentation)] for dataset in data[group_dest]]
 
         for x in range(nb_dataset):
             for y in range(nb_augmentation):
                 for z in range(nb_patch[x][y]):
-        return data,
+                    mapping.append((x, y, z))
+        return data, mapping
 
-    def
+    def get_groups_dest(self):
+        groups_dest = []
         for group_src in self.groups_src:
             for group_dest in self.groups_src[group_src]:
-        return
+                groups_dest.append(group_dest)
+        return groups_dest
+
+    @staticmethod
+    def _split(mapping: list[tuple[int, int, int]], world_size: int) -> list[list[tuple[int, int, int]]]:
+        if len(mapping) == 0:
             return [[] for _ in range(world_size)]
-        if
-            unique_index = np.unique(
-            offset = int(np.ceil(len(unique_index)/world_size))
+
+        mappings = []
+        if konfai_state() == str(State.PREDICTION) or konfai_state() == str(State.EVALUATION):
+            np_mapping = np.asarray(mapping)
+            unique_index = np.unique(np_mapping[:, 0])
+            offset = int(np.ceil(len(unique_index) / world_size))
             if offset == 0:
                 offset = 1
             for itr in range(0, len(unique_index), offset):
+                mappings.append(
+                    [
+                        tuple(v)
+                        for v in np_mapping[
+                            np.where(np.isin(np_mapping[:, 0], unique_index[itr : itr + offset]))[0],
+                            :,
+                        ]
+                    ]
+                )
         else:
-            offset = int(np.ceil(len(
+            offset = int(np.ceil(len(mapping) / world_size))
             if offset == 0:
                 offset = 1
-            for itr in range(0, len(
-        return
-    def
-        datasets: dict[str, list[
+            for itr in range(0, len(mapping), offset):
+                mappings.append(list(mapping[-offset:]) if itr + offset > len(mapping) else mapping[itr : itr + offset])
+        return mappings
+
+    def get_data(self, world_size: int) -> tuple[list[list[DataLoader]], list[str], list[str]]:
+        datasets: dict[str, list[tuple[str, bool]]] = {}
         if self.dataset_filenames is None or len(self.dataset_filenames) == 0:
             raise DatasetManagerError("No dataset filenames were provided")
         for dataset_filename in self.dataset_filenames:
             if dataset_filename is None:
-                raise DatasetManagerError(
-                    "
-                    "
+                raise DatasetManagerError(
+                    "Invalid dataset entry: 'None' received.",
+                    "Each dataset must be a valid path string (e.g., './Dataset/', './Dataset/:mha, "
+                    "'./Dataset/:a:mha', './Dataset/:i:mha').",
+                    "Please check your 'dataset_filenames' list for missing or null entries.",
                 )
             if len(dataset_filename.split(":")) == 1:
                 filename = dataset_filename
+                file_format = "mha"
                 append = True
             elif len(dataset_filename.split(":")) == 2:
-                filename,
+                filename, file_format = dataset_filename.split(":")
                 append = True
             else:
-                filename, flag,
+                filename, flag, file_format = dataset_filename.split(":")
                 append = flag == "a"
 
-            if
-                raise DatasetManagerError(
+            if file_format not in SUPPORTED_EXTENSIONS:
+                raise DatasetManagerError(
+                    f"Unsupported file format '{file_format}'.",
+                    f"Supported extensions are: {', '.join(SUPPORTED_EXTENSIONS)}",
+                )
+
+            dataset = Dataset(filename, file_format)
 
             self.datasets[filename] = dataset
             for group in self.groups_src:
-                if dataset.
+                if dataset.is_group_exist(group):
                     if group in datasets:
-                        datasets[group].append((filename, append))
+                        datasets[group].append((filename, append))
                    else:
                         datasets[group] = [(filename, append)]
+        model_have_input = False
         for group_src in self.groups_src:
             if group_src not in datasets:
+
                 raise DatasetManagerError(
                     f"Group source '{group_src}' not found in any dataset.",
                     f"Dataset filenames provided: {self.dataset_filenames}",
-                    "Available groups across all datasets:
-                    f"
+                    f"Available groups across all datasets: "
+                    f"{[f'{f} {d.get_group()}' for f, d in self.datasets.items()]}\n"
+                    f"Please check that an entry in the dataset with the name '{group_src}' exists.",
                 )
+
             for group_dest in self.groups_src[group_src]:
-                self.groups_src[group_src][group_dest].load(
+                self.groups_src[group_src][group_dest].load(
+                    group_src,
+                    group_dest,
+                    [self.datasets[filename] for filename, _ in datasets[group_src]],
+                )
+                model_have_input |= self.groups_src[group_src][group_dest].is_input
+        if self.patch is not None:
+            self.patch.init()
 
-        if not
+        if not model_have_input:
             raise DatasetManagerError(
-                "At least one group must be defined with '
+                "At least one group must be defined with 'is_input: true' to provide input to the network."
             )
 
-        for key,
+        for key, data_augmentations in self.data_augmentations_list.items():
+            data_augmentations.load(key, [self.datasets[filename] for filename, _ in datasets[group_src]])
 
-        names = set()
-        dataset_name
-        dataset_info
+        names: set[str] = set()
+        dataset_name: dict[str, dict[str, list[str]]] = {}
+        dataset_info: dict[str, dict[str, dict[str, tuple[list[int], Attribute]]]] = {}
         for group in self.groups_src:
+            names_by_group = set()
             if group not in dataset_name:
                 dataset_name[group] = {}
                 dataset_info[group] = {}
             for filename, _ in datasets[group]:
-                dataset_name[group][filename] = self.datasets[filename].
-                dataset_info[group][filename] = {
+                names_by_group.update(self.datasets[filename].get_names(group))
+                dataset_name[group][filename] = self.datasets[filename].get_names(group)
+                dataset_info[group][filename] = {
+                    name: self.datasets[filename].get_infos(group, name) for name in dataset_name[group][filename]
+                }
             if len(names) == 0:
-                names.update(
-            else:
-                names = names.intersection(
+                names.update(names_by_group)
+            else:
+                names = names.intersection(names_by_group)
         if len(names) == 0:
-                f"No data was found for groups {list(self.groups_src.keys())}: although each group contains data
+            raise DatasetManagerError(
+                f"No data was found for groups {list(self.groups_src.keys())}: although each group contains data "
+                "from a dataset, there are no common dataset names shared across all groups, the intersection is empty."
+            )
+
+        subset_names: set[str] = set()
         for group in dataset_name:
-            subset_names_bygroup = set()
+            subset_names_bygroup: set[str] = set()
             for filename, append in datasets[group]:
                 if append:
-                    subset_names_bygroup.update(
+                    subset_names_bygroup.update(
+                        self.subset(
+                            dataset_name[group][filename],
+                            dataset_info[group][filename],
+                        )
+                    )
                 else:
                     if len(subset_names_bygroup) == 0:
-                        subset_names_bygroup.update(
+                        subset_names_bygroup.update(
+                            self.subset(
+                                dataset_name[group][filename],
+                                dataset_info[group][filename],
+                            )
+                        )
                     else:
-                        subset_names_bygroup = subset_names_bygroup.intersection(
+                        subset_names_bygroup = subset_names_bygroup.intersection(
+                            self.subset(
+                                dataset_name[group][filename],
+                                dataset_info[group][filename],
+                            )
+                        )
             if len(subset_names) == 0:
                 subset_names.update(subset_names_bygroup)
-            else:
+            else:
                 subset_names = subset_names.intersection(subset_names_bygroup)
+
         if len(subset_names) == 0:
-            raise DatasetManagerError(
+            raise DatasetManagerError(
+                "All data entries were excluded by the subset filter.",
                 f"Dataset entries found: {', '.join(names)}",
                 f"Subset object applied: {self.subset}",
                 f"Subset requested : {', '.join(subset_names)}",
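Note: Subset._get_index above accepts four spellings for a subset entry: a plain integer index, a "start:stop" slice string, a path to a text file listing names to keep, and a "~path" form that excludes the names listed in that file. A minimal standalone sketch of the int and slice rules, using made-up sample names and skipping the file-based branches so it runs without any files on disk:

    import numpy as np

    def select(subset, names):
        """Mirror of the int and "start:stop" branches of Subset._get_index (sketch)."""
        size = len(names)
        if isinstance(subset, int):
            return [subset]                       # a single explicit index
        start, stop = (int(v) for v in subset.split(":"))
        lo, hi = np.clip([start, stop], 0, size)  # clamp the slice to the dataset size
        return list(range(lo, hi))

    names = ["patient01", "patient02", "patient03", "patient04"]
    print(select(2, names))       # [2]
    print(select("1:3", names))   # [1, 2]
    print(select("0:99", names))  # [0, 1, 2, 3]  (clamped to len(names))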
@@ -436,31 +597,29 @@ class Data(ABC):
                 "\tsubset: [0, 1] # explicit indices",
                 "\tsubset: 0:10 # slice notation",
                 "\tsubset: ./Validation.txt # external file",
-                "\tsubset: None # to disable filtering"
+                "\tsubset: None # to disable filtering",
             )
-        data, map = self._getDatasets(list(subset_names), dataset_name)
 
+        data, mapping = self._get_datasets(list(subset_names), dataset_name)
+
+        index = []
         if isinstance(self.validation, float) or isinstance(self.validation, int):
             if self.validation <= 0 or self.validation >= 1:
-                raise DatasetManagerError(
+                raise DatasetManagerError(
+                    "Validation must be a float between 0 and 1.",
+                    f"Received: {self.validation}",
+                    "Example: validation = 0.2 # for a 20% validation split",
+                )
+            index = [m[0] for m in mapping[int(math.floor(len(mapping) * (1 - self.validation))) :]]
         elif isinstance(self.validation, str):
             if ":" in self.validation:
-                index = list(range(int(self.
-                train_map = [m for m in map if m[0] not in index]
-                validate_map = [m for m in map if m[0] in index]
+                index = list(range(int(self.validation.split(":")[0]), int(self.validation.split(":")[1])))
             elif os.path.exists(self.validation):
                 validation_names = []
-                with open(self.validation
+                with open(self.validation) as f:
                     for name in f:
                         validation_names.append(name.strip())
                 index = [i for i, n in enumerate(subset_names) if n in validation_names]
-                train_map = [m for m in map if m[0] not in index]
-                validate_map = [m for m in map if m[0] in index]
             else:
                 raise DatasetManagerError(
                     f"Invalid string value for 'validation': '{self.validation}'",
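Note: for a float validation value, the code above keeps the tail of the (sample, augmentation, patch) mapping as validation indices via int(math.floor(len(mapping) * (1 - validation))). A toy run of that same expression, assuming one augmentation and one patch per sample, to show which sample indices a 0.2 split would select:

    import math

    mapping = [(x, 0, 0) for x in range(10)]  # 10 samples, 1 augmentation, 1 patch each
    validation = 0.2
    cut = int(math.floor(len(mapping) * (1 - validation)))
    validation_index = [m[0] for m in mapping[cut:]]
    train_index = [m[0] for m in mapping if m[0] not in validation_index]
    print(validation_index)  # [8, 9]
    print(train_index)       # [0, 1, 2, 3, 4, 5, 6, 7]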
@@ -470,94 +629,152 @@ class Data(ABC):
                     "\t• A float between 0 and 1 (e.g., 0.2)",
                     "\t• A list of sample names or indices",
                     "The provided value is neither a valid slice nor a readable file.",
-                    "Please fix your 'validation' setting in the configuration."
+                    "Please fix your 'validation' setting in the configuration.",
+                )
         elif isinstance(self.validation, list):
-            if
+            if isinstance(self.validation[0], int):
+                index = cast(list[int], self.validation)
+            elif isinstance(self.validation[0], str):
+                index = [i for i, n in enumerate(subset_names) if n in self.validation]
+            else:
+                raise DatasetManagerError(
+                    "Invalid list type for 'validation': elements of type "
+                    f"'{type(self.validation[0]).__name__}' are not supported.",
+                    "Supported list element types are:",
+                    "\t• int → list of indices (e.g., [0, 1, 2])",
+                    "\t• str → list of sample names (e.g., ['patient01', 'patient02'])",
+                    f"Received list: {self.validation}",
+                )
+        train_mapping = [m for m in mapping if m[0] not in index]
+        validate_mapping = [m for m in mapping if m[0] in index]
+
+        if len(train_mapping) == 0:
+            raise DatasetManagerError(
+                "No data left for training after applying the validation split.",
+                f"Dataset size: {len(mapping)}",
                 f"Validation setting: {self.validation}",
-                "Please reduce the validation size, increase the dataset, or disable validation."
+                "Please reduce the validation size, increase the dataset, or disable validation.",
             )
 
-        if self.validation is not None and len(
-            raise DatasetManagerError(
+        if self.validation is not None and len(validate_mapping) == 0:
+            raise DatasetManagerError(
+                "No data left for validation after applying the validation split.",
+                f"Dataset size: {len(mapping)}",
                 f"Validation setting: {self.validation}",
-                "Please increase the validation size, increase the dataset, or disable validation."
+                "Please increase the validation size, increase the dataset, or disable validation.",
             )
+
+        validation_names = [name for i, name in enumerate(subset_names) if i in index]
+        train_names = [name for name in subset_names if name not in validation_names]
+        train_mappings = Data._split(train_mapping, world_size)
+        validate_mappings = Data._split(validate_mapping, world_size)
+
+        for i, (train_mapping, validate_mapping) in enumerate(zip(train_mappings, validate_mappings)):
+            mappings = [train_mapping]
+            if len(validate_mapping):
+                mappings += [validate_mapping]
             self.data.append([])
-            self.
-            for
-                indexs = np.unique(np.asarray(
-                self.data[i].append({k:[v[it] for it in indexs] for k, v in data.items()})
+            self.mapping.append([])
+            for mapping_tmp in mappings:
+                indexs = np.unique(np.asarray(mapping_tmp)[:, 0])
+                self.data[i].append({k: [v[it] for it in indexs] for k, v in data.items()})
+                mapping_tmp_array = np.asarray(mapping_tmp)
                 for a, b in enumerate(indexs):
-                    self.
+                    mapping_tmp_array[np.where(np.asarray(mapping_tmp_array)[:, 0] == b), 0] = a
+                self.mapping[i].append([(a, b, c) for a, b, c in mapping_tmp_array])
+
+        data_loaders: list[list[DataLoader]] = []
+        for i, (datas, mappings) in enumerate(zip(self.data, self.mapping)):
+            data_loaders.append([])
+            for data, mapping in zip(datas, mappings):
+                data_loaders[i].append(
+                    DataLoader(
+                        dataset=self.datasetIter(
+                            rank=i,
+                            data=data,
+                            mapping=mapping,
+                        ),
+                        sampler=CustomSampler(len(mapping), self.subset.shuffle),
+                        batch_size=self.batch_size,
+                        **self.dataLoader_args,
+                    )
+                )
+        return data_loaders, train_names, validation_names
 
-        dataLoaders: list[list[DataLoader]] = []
-        for i, (datas, maps) in enumerate(zip(self.data, self.map)):
-            dataLoaders.append([])
-            for data, map in zip(datas, maps):
-                dataLoaders[i].append(DataLoader(dataset=DatasetIter(rank=i, data=data, map=map, **self.dataSet_args), sampler=CustomSampler(len(map), self.subset.shuffle), batch_size=self.batch_size,**self.dataLoader_args))
-        return dataLoaders
 
 class DataTrain(Data):
 
     @config("Dataset")
-    def __init__(
+    def __init__(
+        self,
+        dataset_filenames: list[str] = ["default:./Dataset"],
+        groups_src: dict[str, Group] = {"default:group_src": Group()},
+        augmentations: dict[str, DataAugmentationsList] | None = {"DataAugmentation_0": DataAugmentationsList()},
+        inline_augmentations: bool = False,
+        patch: DatasetPatch | None = DatasetPatch(),
+        use_cache: bool = True,
+        subset: TrainSubset = TrainSubset(),
+        batch_size: int = 1,
+        validation: float | str | list[int] | list[str] = 0.2,
+    ) -> None:
+        super().__init__(
+            dataset_filenames,
+            groups_src,
+            patch,
+            use_cache,
+            subset,
+            batch_size,
+            validation,
+            inline_augmentations,
+            augmentations if augmentations else {},
+        )
+
 
 class DataPrediction(Data):
 
     @config("Dataset")
-    def __init__(
+    def __init__(
+        self,
+        dataset_filenames: list[str] = ["default:./Dataset"],
+        groups_src: dict[str, Group] = {"default": Group()},
+        augmentations: dict[str, DataAugmentationsList] | None = {"DataAugmentation_0": DataAugmentationsList()},
+        patch: DatasetPatch | None = DatasetPatch(),
+        subset: PredictionSubset = PredictionSubset(),
+        batch_size: int = 1,
+    ) -> None:
+
+        super().__init__(
+            dataset_filenames=dataset_filenames,
+            groups_src=groups_src,
+            patch=patch,
+            use_cache=False,
+            subset=subset,
+            batch_size=batch_size,
+            validation=None,
+            inline_augmentations=False,
+            data_augmentations_list=augmentations if augmentations else {},
+        )
 
-        super().__init__(dataset_filenames, groups_src, patch, False, subset, batch_size, dataAugmentationsList=augmentations if augmentations else {})
 
 class DataMetric(Data):
 
     @config("Dataset")
-    def __init__(
+    def __init__(
+        self,
+        dataset_filenames: list[str] = ["default:./Dataset"],
+        groups_src: dict[str, GroupMetric] = {"default": GroupMetric()},
+        subset: PredictionSubset = PredictionSubset(),
+        validation: str | None = None,
+    ) -> None:
+
+        super().__init__(
+            dataset_filenames=dataset_filenames,
+            groups_src=groups_src,
+            patch=None,
+            use_cache=True,
+            subset=subset,
+            batch_size=1,
+            validation=validation,
+            data_augmentations_list={},
+            inline_augmentations=False,
+        )