konfai-1.1.8-py3-none-any.whl → konfai-1.2.0-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of konfai might be problematic.

Files changed (36)
  1. konfai/__init__.py +59 -14
  2. konfai/data/augmentation.py +457 -286
  3. konfai/data/data_manager.py +533 -316
  4. konfai/data/patching.py +300 -183
  5. konfai/data/transform.py +408 -275
  6. konfai/evaluator.py +325 -68
  7. konfai/main.py +71 -22
  8. konfai/metric/measure.py +360 -244
  9. konfai/metric/schedulers.py +24 -13
  10. konfai/models/classification/convNeXt.py +187 -81
  11. konfai/models/classification/resnet.py +272 -58
  12. konfai/models/generation/cStyleGan.py +233 -59
  13. konfai/models/generation/ddpm.py +348 -121
  14. konfai/models/generation/diffusionGan.py +757 -358
  15. konfai/models/generation/gan.py +177 -53
  16. konfai/models/generation/vae.py +140 -40
  17. konfai/models/registration/registration.py +135 -52
  18. konfai/models/representation/representation.py +57 -23
  19. konfai/models/segmentation/NestedUNet.py +339 -68
  20. konfai/models/segmentation/UNet.py +140 -30
  21. konfai/network/blocks.py +331 -187
  22. konfai/network/network.py +795 -427
  23. konfai/predictor.py +644 -238
  24. konfai/trainer.py +509 -222
  25. konfai/utils/ITK.py +191 -106
  26. konfai/utils/config.py +152 -95
  27. konfai/utils/dataset.py +326 -455
  28. konfai/utils/utils.py +497 -249
  29. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/METADATA +1 -3
  30. konfai-1.2.0.dist-info/RECORD +38 -0
  31. konfai/utils/registration.py +0 -199
  32. konfai-1.1.8.dist-info/RECORD +0 -39
  33. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/WHEEL +0 -0
  34. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/entry_points.txt +0 -0
  35. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/licenses/LICENSE +0 -0
  36. {konfai-1.1.8.dist-info → konfai-1.2.0.dist-info}/top_level.txt +0 -0
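
The unified diff that follows (apparently konfai/data/data_manager.py, the +533 -316 entry above) is dominated by a camelCase → snake_case API rename. As an orientation aid, here is a non-exhaustive sketch of the renames visible in the hunks below; the mapping name is illustrative and not part of konfai itself.

    # Illustrative mapping of 1.1.8 names to their 1.2.0 counterparts, taken only
    # from the hunks shown below; not an official or exhaustive migration table.
    RENAMES_1_1_8_TO_1_2_0 = {
        "KONFAI_ROOT": "konfai_root",
        "KONFAI_STATE": "konfai_state",
        "GroupTransform.isInput": "GroupTransform.is_input",
        "TransformLoader.getTransform": "TransformLoader.get_transform",
        "Transform.setDatasets": "Transform.set_datasets",
        "Transform.setDevice": "Transform.to",
        "DatasetIter.getPatchConfig": "DatasetIter.get_patch_config",
        "DatasetManager.getSize": "DatasetManager.get_size",
        "Dataset.isGroupExist": "Dataset.is_group_exist",
        "Dataset.getNames": "Dataset.get_names",
        "Dataset.getInfos": "Dataset.get_infos",
        "Data.getData": "Data.get_data",  # now returns (loaders, train_names, validation_names)
    }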
@@ -1,85 +1,115 @@
1
1
  import math
2
2
  import os
3
3
  import random
4
+ import threading
5
+ from abc import ABC, abstractmethod
6
+ from collections.abc import Iterator, Mapping
7
+ from concurrent.futures import ThreadPoolExecutor, as_completed
8
+ from functools import partial
9
+ from typing import cast
10
+
11
+ import numpy as np
4
12
  import torch
5
- from torch.utils import data
6
13
  import tqdm
7
- import numpy as np
8
- from abc import ABC
9
- from torch.utils.data import DataLoader, Sampler
10
- from typing import Union, Iterator
11
- from concurrent.futures import ThreadPoolExecutor, as_completed
12
- import threading
13
14
  from torch.cuda import device_count
14
- import SimpleITK as sitk
15
+ from torch.utils import data
16
+ from torch.utils.data import DataLoader, Sampler
15
17
 
16
- from konfai import KONFAI_STATE, KONFAI_ROOT
17
- from konfai.data.patching import DatasetPatch, DatasetManager
18
- from konfai.utils.config import config
19
- from konfai.utils.utils import memoryInfo, cpuInfo, memoryForecast, getMemory, State, SUPPORTED_EXTENSIONS, DatasetManagerError
20
- from konfai.utils.dataset import Dataset, Attribute
21
- from konfai.data.transform import TransformLoader, Transform
18
+ from konfai import konfai_root, konfai_state
22
19
  from konfai.data.augmentation import DataAugmentationsList
20
+ from konfai.data.patching import DatasetManager, DatasetPatch
21
+ from konfai.data.transform import Transform, TransformLoader
22
+ from konfai.utils.config import config
23
+ from konfai.utils.dataset import Attribute, Dataset
24
+ from konfai.utils.utils import (
25
+ SUPPORTED_EXTENSIONS,
26
+ DatasetManagerError,
27
+ State,
28
+ get_cpu_info,
29
+ get_memory,
30
+ get_memory_info,
31
+ memory_forecast,
32
+ )
33
+
23
34
 
24
35
  class GroupTransform:
25
36
 
26
37
  @config()
27
- def __init__(self, transforms : Union[dict[str, TransformLoader], list[Transform]] = {"default:Normalize:Standardize:Unsqueeze:TensorCast:ResampleIsotropic:ResampleResize": TransformLoader()},
28
- patch_transforms : Union[dict[str, TransformLoader], list[Transform]] = {"default:Normalize:Standardize:Unsqueeze:TensorCast:ResampleIsotropic:ResampleResize": TransformLoader()},
29
- isInput: bool = True) -> None:
30
- self._pre_transforms = transforms
31
- self._post_transforms = patch_transforms
32
- self.pre_transforms : list[Transform] = []
33
- self.post_transforms : list[Transform] = []
34
- self.isInput = isInput
35
-
36
- def load(self, group_src : str, group_dest : str, datasets: list[Dataset]):
37
- if self._pre_transforms is not None:
38
- if isinstance(self._pre_transforms, dict):
39
- for classpath, transform in self._pre_transforms.items():
40
- transform = transform.getTransform(classpath, DL_args = "{}.Dataset.groups_src.{}.groups_dest.{}.transforms".format(KONFAI_ROOT(), group_src, group_dest))
41
- transform.setDatasets(datasets)
42
- self.pre_transforms.append(transform)
43
- else:
44
- for transform in self._pre_transforms:
45
- transform.setDatasets(datasets)
46
- self.pre_transforms.append(transform)
47
-
48
- if self._post_transforms is not None:
49
- if isinstance(self._post_transforms, dict):
50
- for classpath, transform in self._post_transforms.items():
51
- transform = transform.getTransform(classpath, DL_args = "{}.Dataset.groups_src.{}.groups_dest.{}.patch_transforms".format(KONFAI_ROOT(), group_src, group_dest))
52
- transform.setDatasets(datasets)
53
- self.post_transforms.append(transform)
54
- else:
55
- for transform in self._post_transforms:
56
- transform.setDatasets(datasets)
57
- self.post_transforms.append(transform)
58
-
38
+ def __init__(
39
+ self,
40
+ transforms: dict[str, TransformLoader] = {
41
+ "default:Normalize:Standardize:Unsqueeze:TensorCast:ResampleIsotropic:ResampleResize": TransformLoader()
42
+ },
43
+ patch_transforms: dict[str, TransformLoader] = {
44
+ "default:Normalize:Standardize:Unsqueeze:TensorCast:ResampleIsotropic:ResampleResize": TransformLoader()
45
+ },
46
+ is_input: bool = True,
47
+ ) -> None:
48
+ self._transforms = transforms
49
+ self._patch_transforms = patch_transforms
50
+ self.transforms: list[Transform] = []
51
+ self.patch_transforms: list[Transform] = []
52
+ self.is_input = is_input
53
+
54
+ def load(self, group_src: str, group_dest: str, datasets: list[Dataset]):
55
+ if self._transforms is not None:
56
+ for classpath, transform_loader in self._transforms.items():
57
+ transform = transform_loader.get_transform(
58
+ classpath,
59
+ konfai_args=f"{konfai_root()}.Dataset.groups_src.{group_src}.groups_dest.{group_dest}.transforms",
60
+ )
61
+ transform.set_datasets(datasets)
62
+ self.transforms.append(transform)
63
+
64
+ if self._patch_transforms is not None:
65
+ for classpath, transform_loader in self._patch_transforms.items():
66
+ transform = transform_loader.get_transform(
67
+ classpath,
68
+ konfai_args=f"{konfai_root()}.Dataset.groups_src.{group_src}"
69
+ f".groups_dest.{group_dest}.patch_transforms",
70
+ )
71
+ transform.set_datasets(datasets)
72
+ self.patch_transforms.append(transform)
73
+
59
74
  def to(self, device: int):
60
- for transform in self.pre_transforms:
61
- transform.setDevice(device)
62
- for transform in self.post_transforms:
63
- transform.setDevice(device)
75
+ for transform in self.transforms:
76
+ transform.to(device)
77
+ for transform in self.patch_transforms:
78
+ transform.to(device)
79
+
64
80
 
65
81
  class GroupTransformMetric(GroupTransform):
66
82
 
67
83
  @config()
68
- def __init__(self, transforms : Union[dict[str, TransformLoader], list[Transform]] = {"default:Normalize:Standardize:Unsqueeze:TensorCast:ResampleIsotropic:ResampleResize": TransformLoader()}):
84
+ def __init__(
85
+ self,
86
+ transforms: dict[str, TransformLoader] = {
87
+ "default:Normalize:Standardize:Unsqueeze:TensorCast:ResampleIsotropic:ResampleResize": TransformLoader()
88
+ },
89
+ ):
69
90
  super().__init__(transforms, None)
70
91
 
92
+
71
93
  class Group(dict[str, GroupTransform]):
72
94
 
73
95
  @config()
74
- def __init__(self, groups_dest: dict[str, GroupTransform] = {"default:group_dest": GroupTransform()}):
96
+ def __init__(
97
+ self,
98
+ groups_dest: dict[str, GroupTransform] = {"default:group_dest": GroupTransform()},
99
+ ):
75
100
  super().__init__(groups_dest)
76
101
 
102
+
77
103
  class GroupMetric(dict[str, GroupTransformMetric]):
78
104
 
79
105
  @config()
80
- def __init__(self, groups_dest: dict[str, GroupTransformMetric] = {"default:group_dest": GroupTransformMetric()}):
106
+ def __init__(
107
+ self,
108
+ groups_dest: dict[str, GroupTransformMetric] = {"default:group_dest": GroupTransformMetric()},
109
+ ):
81
110
  super().__init__(groups_dest)
82
111
 
112
+
83
113
  class CustomSampler(Sampler[int]):
84
114
 
85
115
  def __init__(self, size: int, shuffle: bool = False) -> None:
@@ -87,346 +117,477 @@ class CustomSampler(Sampler[int]):
87
117
  self.shuffle = shuffle
88
118
 
89
119
  def __iter__(self) -> Iterator[int]:
90
- return iter(torch.randperm(len(self)).tolist() if self.shuffle else list(range(len(self))) )
120
+ return iter(torch.randperm(len(self)).tolist() if self.shuffle else list(range(len(self))))
91
121
 
92
122
  def __len__(self) -> int:
93
123
  return self.size
94
124
 
125
+
95
126
  class DatasetIter(data.Dataset):
96
127
 
97
- def __init__(self, rank: int, data : dict[str, list[DatasetManager]], map: dict[int, tuple[int, int, int]], groups_src : dict[str, Group], inlineAugmentations: bool, dataAugmentationsList : list[DataAugmentationsList], patch_size: Union[list[int], None], overlap: Union[int, None], buffer_size: int, use_cache = True) -> None:
128
+ def __init__(
129
+ self,
130
+ rank: int,
131
+ data: dict[str, list[DatasetManager]],
132
+ mapping: list[tuple[int, int, int]],
133
+ groups_src: Mapping[str, Group | GroupMetric],
134
+ inline_augmentations: bool,
135
+ data_augmentations_list: list[DataAugmentationsList],
136
+ patch_size: list[int] | None,
137
+ overlap: int | None,
138
+ buffer_size: int,
139
+ use_cache=True,
140
+ ) -> None:
98
141
  self.rank = rank
99
142
  self.data = data
100
- self.map = map
143
+ self.mapping = mapping
101
144
  self.patch_size = patch_size
102
145
  self.overlap = overlap
103
146
  self.groups_src = groups_src
104
- self.dataAugmentationsList = dataAugmentationsList
147
+ self.data_augmentations_list = data_augmentations_list
105
148
  self.use_cache = use_cache
106
149
  self.nb_dataset = len(data[list(data.keys())[0]])
107
150
  self.buffer_size = buffer_size
108
- self._index_cache = list()
109
- self.device = None
110
- self.inlineAugmentations = inlineAugmentations
151
+ self._index_cache: list[int] = []
152
+ self.inline_augmentations = inline_augmentations
111
153
 
112
- def getPatchConfig(self) -> tuple[list[int], int]:
154
+ def get_patch_config(self) -> tuple[list[int] | None, int | None]:
113
155
  return self.patch_size, self.overlap
114
-
156
+
115
157
  def to(self, device: int):
116
158
  for group_src in self.groups_src:
117
159
  for group_dest in self.groups_src[group_src]:
118
160
  self.groups_src[group_src][group_dest].to(device)
119
- self.device = device
161
+ for data_augmentations in self.data_augmentations_list:
162
+ for data_augmentation in data_augmentations.data_augmentations:
163
+ data_augmentation.to(device)
120
164
 
121
- def getDatasetFromIndex(self, group_dest: str, index: int) -> DatasetManager:
165
+ def get_dataset_from_index(self, group_dest: str, index: int) -> DatasetManager:
122
166
  return self.data[group_dest][index]
123
-
124
- def resetAugmentation(self, label):
125
- if self.inlineAugmentations and len(self.dataAugmentationsList) > 0:
167
+
168
+ def reset_augmentation(self, label):
169
+ if self.inline_augmentations and len(self.data_augmentations_list) > 0:
126
170
  for index in range(self.nb_dataset):
127
171
  for group_src in self.groups_src:
128
172
  for group_dest in self.groups_src[group_src]:
129
- self.data[group_dest][index].unloadAugmentation()
130
- self.data[group_dest][index].resetAugmentation()
173
+ self.data[group_dest][index].unload_augmentation()
174
+ self.data[group_dest][index].reset_augmentation()
131
175
  self.load(label + " Augmentation")
132
176
 
133
177
  def load(self, label: str):
134
178
  if self.use_cache:
135
- memory_init = getMemory()
179
+ memory_init = get_memory()
136
180
 
137
- indexs = [index for index in range(self.nb_dataset)]
181
+ indexs = list(range(self.nb_dataset))
138
182
  if len(indexs) > 0:
139
183
  memory_lock = threading.Lock()
140
- desc = lambda : "Caching "+ label +": {} | {} | {}".format(memoryInfo(), memoryForecast(memory_init, 0, self.nb_dataset), cpuInfo())
141
- pbar = tqdm.tqdm(
142
- total=len(indexs),
143
- desc=desc(),
144
- leave=False
145
- )
184
+
185
+ def desc():
186
+ return (
187
+ f"Caching {label}: "
188
+ f"{get_memory_info()} | "
189
+ f"{memory_forecast(memory_init, 0, self.nb_dataset)} | "
190
+ f"{get_cpu_info()}"
191
+ )
192
+
193
+ pbar = tqdm.tqdm(total=len(indexs), desc=desc(), leave=False)
146
194
 
147
195
  def process(index):
148
- self._loadData(index)
196
+ self._load_data(index)
149
197
  with memory_lock:
150
198
  pbar.set_description(desc())
151
199
  pbar.update(1)
152
- with ThreadPoolExecutor(max_workers=os.cpu_count()//(device_count() if device_count() > 0 else 1)) as executor:
200
+
201
+ cpu_count = os.cpu_count() or 1
202
+ with ThreadPoolExecutor(
203
+ max_workers=cpu_count // (device_count() if device_count() > 0 else 1)
204
+ ) as executor:
153
205
  futures = [executor.submit(process, index) for index in indexs]
154
206
  for _ in as_completed(futures):
155
207
  pass
156
208
 
157
209
  pbar.close()
158
-
159
- def _loadData(self, index):
210
+
211
+ def _load_data(self, index):
160
212
  if index not in self._index_cache:
161
213
  self._index_cache.append(index)
162
214
  for group_src in self.groups_src:
163
215
  for group_dest in self.groups_src[group_src]:
164
- self.loadData(group_src, group_dest, index)
216
+ self.load_data(group_src, group_dest, index)
165
217
 
166
- def loadData(self, group_src: str, group_dest : str, index : int) -> None:
167
- self.data[group_dest][index].load(self.groups_src[group_src][group_dest].pre_transforms, self.dataAugmentationsList, self.device)
218
+ def load_data(self, group_src: str, group_dest: str, index: int) -> None:
219
+ self.data[group_dest][index].load(
220
+ self.groups_src[group_src][group_dest].transforms,
221
+ self.data_augmentations_list,
222
+ )
168
223
 
169
- def _unloadData(self, index : int) -> None:
224
+ def _unload_data(self, index: int) -> None:
170
225
  if index in self._index_cache:
171
226
  self._index_cache.remove(index)
172
227
  for group_src in self.groups_src:
173
228
  for group_dest in self.groups_src[group_src]:
174
- self.unloadData(group_dest, index)
175
-
176
- def unloadData(self, group_dest : str, index : int) -> None:
229
+ self.unload_data(group_dest, index)
230
+
231
+ def unload_data(self, group_dest: str, index: int) -> None:
177
232
  return self.data[group_dest][index].unload()
178
233
 
179
234
  def __len__(self) -> int:
180
- return len(self.map)
235
+ return len(self.mapping)
181
236
 
182
- def __getitem__(self, index : int) -> dict[str, tuple[torch.Tensor, int, int, int, str, bool]]:
237
+ def __getitem__(self, index: int) -> dict[str, tuple[torch.Tensor, int, int, int, str, bool]]:
183
238
  data = {}
184
- x, a, p = self.map[index]
239
+ x, a, p = self.mapping[index]
185
240
  if x not in self._index_cache:
186
241
  if len(self._index_cache) >= self.buffer_size and not self.use_cache:
187
- self._unloadData(self._index_cache[0])
188
- self._loadData(x)
242
+ self._unload_data(self._index_cache[0])
243
+ self._load_data(x)
189
244
 
190
245
  for group_src in self.groups_src:
191
246
  for group_dest in self.groups_src[group_src]:
192
247
  dataset = self.data[group_dest][x]
193
- data["{}".format(group_dest)] = (dataset.getData(p, a, self.groups_src[group_src][group_dest].post_transforms, self.groups_src[group_src][group_dest].isInput), x, a, p, dataset.name, self.groups_src[group_src][group_dest].isInput)
248
+ data[f"{group_dest}"] = (
249
+ dataset.get_data(
250
+ p,
251
+ a,
252
+ self.groups_src[group_src][group_dest].patch_transforms,
253
+ self.groups_src[group_src][group_dest].is_input,
254
+ ),
255
+ x,
256
+ a,
257
+ p,
258
+ dataset.name,
259
+ self.groups_src[group_src][group_dest].is_input,
260
+ )
194
261
  return data
195
262
 
196
- class Subset():
197
-
198
- def __init__(self, subset: Union[str, list[int], list[str], None] = None, shuffle: bool = True) -> None:
263
+
264
+ class Subset:
265
+
266
+ def __init__(
267
+ self,
268
+ subset: str | list[int] | list[str] | None = None,
269
+ shuffle: bool = True,
270
+ ) -> None:
199
271
  self.subset = subset
200
272
  self.shuffle = shuffle
201
273
 
202
- def __call__(self, names: list[str], infos: list[dict[str, tuple[np.ndarray, Attribute]]]) -> set[str]:
203
- inter_name = set(names[0])
204
- for n in names[1:]:
205
- inter_name = inter_name.intersection(set(n))
206
- names = sorted(list(inter_name))
207
-
274
+ def _get_index(self, subset: str | int, names: list[str]) -> list[int]:
208
275
  size = len(names)
209
276
  index = []
277
+ if isinstance(subset, int):
278
+ index.append(subset)
279
+ elif ":" in subset:
280
+ r = np.clip(
281
+ np.asarray([int(subset.split(":")[0]), int(subset.split(":")[1])]),
282
+ 0,
283
+ size,
284
+ )
285
+ index = list(range(r[0], r[1]))
286
+ elif os.path.exists(subset):
287
+ train_names = []
288
+ with open(subset) as f:
289
+ for name in f:
290
+ train_names.append(name.strip())
291
+ index = []
292
+ for i, name in enumerate(names):
293
+ if name in train_names:
294
+ index.append(i)
295
+ elif subset.startswith("~") and os.path.exists(subset[1:]):
296
+ exclude_names = []
297
+ with open(subset[1:]) as f:
298
+ for name in f:
299
+ exclude_names.append(name.strip())
300
+ index = []
301
+ for i, name in enumerate(names):
302
+ if name not in exclude_names:
303
+ index.append(i)
304
+ return index
305
+
306
+ def __call__(self, names: list[str], infos: dict[str, tuple[list[int], Attribute]]) -> set[str]:
307
+ names = sorted(names)
308
+ size = len(names)
309
+
210
310
  if self.subset is None:
211
311
  index = list(range(0, size))
212
- elif isinstance(self.subset, str):
213
- if ":" in self.subset:
214
- r = np.clip(np.asarray([int(self.subset.split(":")[0]), int(self.subset.split(":")[1])]), 0, size)
215
- index = list(range(r[0], r[1]))
216
- elif os.path.exists(self.subset):
217
- train_names = []
218
- with open(self.subset, "r") as f:
219
- for name in f:
220
- train_names.append(name.strip())
221
- index = []
222
- for i, name in enumerate(names):
223
- if name in train_names:
224
- index.append(i)
225
- elif self.subset.startswith("~") and os.path.exists(self.subset[1:]):
226
- exclude_names = []
227
- with open(self.subset[1:], "r") as f:
228
- for name in f:
229
- exclude_names.append(name.strip())
230
- index = []
231
- for i, name in enumerate(names):
232
- if name not in exclude_names:
233
- index.append(i)
234
-
235
312
  elif isinstance(self.subset, list):
236
- if len(self.subset) > 0:
237
- if isinstance(self.subset[0], int):
238
- if len(self.subset) == 1:
239
- index = list(range(self.subset[0], min(size, self.subset[0]+1)))
240
- else:
241
- index = self.subset
242
- if isinstance(self.subset[0], str):
243
- index = []
244
- for i, name in enumerate(names):
245
- if name in self.subset:
246
- index.append(i)
313
+ index_set: set[int] = set()
314
+ for s in self.subset:
315
+ if len(index_set) == 0:
316
+ index_set.update(set(self._get_index(s, names)))
317
+ else:
318
+ index_set = index_set.intersection(set(self._get_index(s, names)))
319
+ index = list(index_set)
320
+ print(index)
321
+ else:
322
+ index = self._get_index(self.subset, names)
247
323
  if self.shuffle:
248
- index = random.sample(index, len(index))
249
- return set([names[i] for i in index])
250
-
324
+ index = random.sample(index, len(index)) # nosec B311
325
+ return {names[i] for i in index}
326
+
251
327
  def __str__(self):
252
- return "Subset : " + str(self.subset) + " shuffle : "+ str(self.shuffle)
253
-
328
+ return "Subset : " + str(self.subset) + " shuffle : " + str(self.shuffle)
329
+
330
+
254
331
  class TrainSubset(Subset):
255
332
 
256
333
  @config()
257
- def __init__(self, subset: Union[str, list[int], list[str], None] = None, shuffle: bool = True) -> None:
334
+ def __init__(
335
+ self,
336
+ subset: str | list[int] | list[str] | None = None,
337
+ shuffle: bool = True,
338
+ ) -> None:
258
339
  super().__init__(subset, shuffle)
259
340
 
341
+
260
342
  class PredictionSubset(Subset):
261
343
 
262
344
  @config()
263
- def __init__(self, subset: Union[str, list[int], list[str], None] = None) -> None:
345
+ def __init__(self, subset: str | list[int] | list[str] | None = None) -> None:
264
346
  super().__init__(subset, False)
265
347
 
348
+
266
349
  class Data(ABC):
267
-
268
- def __init__(self, dataset_filenames : list[str],
269
- groups_src : dict[str, Group],
270
- patch : Union[DatasetPatch, None],
271
- use_cache : bool,
272
- subset : Subset,
273
- batch_size : int,
274
- validation: Union[float, str, list[int], list[str], None] = None,
275
- inlineAugmentations: bool = False,
276
- dataAugmentationsList: dict[str, DataAugmentationsList]= {}) -> None:
350
+
351
+ @abstractmethod
352
+ def __init__(
353
+ self,
354
+ dataset_filenames: list[str],
355
+ groups_src: Mapping[str, Group | GroupMetric],
356
+ patch: DatasetPatch | None,
357
+ use_cache: bool,
358
+ subset: Subset,
359
+ batch_size: int,
360
+ validation: float | str | list[int] | list[str] | None,
361
+ inline_augmentations: bool,
362
+ data_augmentations_list: dict[str, DataAugmentationsList],
363
+ ) -> None:
277
364
  self.dataset_filenames = dataset_filenames
278
365
  self.subset = subset
279
366
  self.groups_src = groups_src
280
367
  self.patch = patch
281
368
  self.validation = validation
282
- self.dataAugmentationsList = dataAugmentationsList
369
+ self.data_augmentations_list = data_augmentations_list
283
370
  self.batch_size = batch_size
284
- self.dataSet_args = dict(groups_src=self.groups_src, inlineAugmentations=inlineAugmentations, dataAugmentationsList = list(self.dataAugmentationsList.values()), use_cache = use_cache, buffer_size=batch_size+1, patch_size=self.patch.patch_size if self.patch is not None else None, overlap=self.patch.overlap if self.patch is not None else None)
285
- self.dataLoader_args = dict(num_workers=int(os.environ["KONFAI_WORKERS"]) if use_cache else 0, pin_memory=True)
286
- self.data : list[list[dict[str, list[DatasetManager]]], dict[str, list[DatasetManager]]] = []
287
- self.map : list[list[list[tuple[int, int, int]]], list[tuple[int, int, int]]] = []
371
+
372
+ self.datasetIter = partial(
373
+ DatasetIter,
374
+ groups_src=self.groups_src,
375
+ inline_augmentations=inline_augmentations,
376
+ data_augmentations_list=list(self.data_augmentations_list.values()),
377
+ patch_size=self.patch.patch_size if self.patch is not None else None,
378
+ overlap=self.patch.overlap if self.patch is not None else None,
379
+ buffer_size=batch_size + 1,
380
+ use_cache=use_cache,
381
+ )
382
+ self.dataLoader_args = {
383
+ "num_workers": int(os.environ["KONFAI_WORKERS"]) if use_cache else 0,
384
+ "pin_memory": True,
385
+ }
386
+ self.data: list[list[dict[str, list[DatasetManager]]]] = []
387
+ self.mapping: list[list[list[tuple[int, int, int]]]] = []
288
388
  self.datasets: dict[str, Dataset] = {}
289
389
 
290
- def _getDatasets(self, names: list[str], dataset_name: dict[str, dict[str, list[str]]]) -> tuple[dict[str, list[Dataset]], list[tuple[int, int, int]]]:
390
+ def _get_datasets(
391
+ self, names: list[str], dataset_name: dict[str, dict[str, list[str]]]
392
+ ) -> tuple[dict[str, list[DatasetManager]], list[tuple[int, int, int]]]:
291
393
  nb_dataset = len(names)
292
- nb_patch = None
394
+ nb_patch: list[list[int]]
293
395
  data = {}
294
- map = []
295
- nb_augmentation = np.max([int(np.sum([data_augmentation.nb for data_augmentation in self.dataAugmentationsList.values()])+1), 1])
396
+ mapping = []
397
+ nb_augmentation = np.max(
398
+ [
399
+ int(np.sum([data_augmentation.nb for data_augmentation in self.data_augmentations_list.values()]) + 1),
400
+ 1,
401
+ ]
402
+ )
296
403
  for group_src in self.groups_src:
297
404
  for group_dest in self.groups_src[group_src]:
298
- data[group_dest] = [DatasetManager(i, group_src, group_dest, name, self.datasets[[filename for filename, names in dataset_name[group_src].items() if name in names][0]], patch = self.patch, pre_transforms = self.groups_src[group_src][group_dest].pre_transforms, dataAugmentationsList=list(self.dataAugmentationsList.values())) for i, name in enumerate(names)]
299
- nb_patch = [[dataset.getSize(a) for a in range(nb_augmentation)] for dataset in data[group_dest]]
405
+ data[group_dest] = [
406
+ DatasetManager(
407
+ i,
408
+ group_src,
409
+ group_dest,
410
+ name,
411
+ self.datasets[
412
+ [filename for filename, names in dataset_name[group_src].items() if name in names][0]
413
+ ],
414
+ patch=self.patch,
415
+ transforms=self.groups_src[group_src][group_dest].transforms,
416
+ data_augmentations_list=list(self.data_augmentations_list.values()),
417
+ )
418
+ for i, name in enumerate(names)
419
+ ]
420
+ nb_patch = [[dataset.get_size(a) for a in range(nb_augmentation)] for dataset in data[group_dest]]
300
421
 
301
422
  for x in range(nb_dataset):
302
423
  for y in range(nb_augmentation):
303
424
  for z in range(nb_patch[x][y]):
304
- map.append((x, y, z))
305
- return data, map
425
+ mapping.append((x, y, z))
426
+ return data, mapping
306
427
 
307
- def getGroupsDest(self):
308
- groupsDest = []
428
+ def get_groups_dest(self):
429
+ groups_dest = []
309
430
  for group_src in self.groups_src:
310
431
  for group_dest in self.groups_src[group_src]:
311
- groupsDest.append(group_dest)
312
- return groupsDest
313
-
314
- def _split(map: list[tuple[int, int, int]], world_size: int) -> list[list[tuple[int, int, int]]]:
315
- if len(map) == 0:
432
+ groups_dest.append(group_dest)
433
+ return groups_dest
434
+
435
+ @staticmethod
436
+ def _split(mapping: list[tuple[int, int, int]], world_size: int) -> list[list[tuple[int, int, int]]]:
437
+ if len(mapping) == 0:
316
438
  return [[] for _ in range(world_size)]
317
-
318
- maps = []
319
- if KONFAI_STATE() == str(State.PREDICTION) or KONFAI_STATE() == str(State.EVALUATION):
320
- np_map = np.asarray(map)
321
- unique_index = np.unique(np_map[:, 0])
322
- offset = int(np.ceil(len(unique_index)/world_size))
439
+
440
+ mappings = []
441
+ if konfai_state() == str(State.PREDICTION) or konfai_state() == str(State.EVALUATION):
442
+ np_mapping = np.asarray(mapping)
443
+ unique_index = np.unique(np_mapping[:, 0])
444
+ offset = int(np.ceil(len(unique_index) / world_size))
323
445
  if offset == 0:
324
446
  offset = 1
325
447
  for itr in range(0, len(unique_index), offset):
326
- maps.append([tuple(v) for v in np_map[np.where(np.isin(np_map[:, 0], unique_index[itr:itr+offset]))[0], :]])
448
+ mappings.append(
449
+ [
450
+ tuple(v)
451
+ for v in np_mapping[
452
+ np.where(np.isin(np_mapping[:, 0], unique_index[itr : itr + offset]))[0],
453
+ :,
454
+ ]
455
+ ]
456
+ )
327
457
  else:
328
- offset = int(np.ceil(len(map)/world_size))
458
+ offset = int(np.ceil(len(mapping) / world_size))
329
459
  if offset == 0:
330
460
  offset = 1
331
- for itr in range(0, len(map), offset):
332
- maps.append(list(map[-offset:]) if itr+offset > len(map) else map[itr:itr+offset])
333
- return maps
334
-
335
- def getData(self, world_size: int) -> list[list[DataLoader]]:
336
- datasets: dict[str, list[(str, bool)]] = {}
461
+ for itr in range(0, len(mapping), offset):
462
+ mappings.append(list(mapping[-offset:]) if itr + offset > len(mapping) else mapping[itr : itr + offset])
463
+ return mappings
464
+
465
+ def get_data(self, world_size: int) -> tuple[list[list[DataLoader]], list[str], list[str]]:
466
+ datasets: dict[str, list[tuple[str, bool]]] = {}
337
467
  if self.dataset_filenames is None or len(self.dataset_filenames) == 0:
338
468
  raise DatasetManagerError("No dataset filenames were provided")
339
469
  for dataset_filename in self.dataset_filenames:
340
470
  if dataset_filename is None:
341
- raise DatasetManagerError("Invalid dataset entry: 'None' received.",
342
- "Each dataset must be a valid path string (e.g., './Dataset/', './Dataset/:mha, './Dataset/:a:mha', './Dataset/:i:mha').",
343
- "Please check your 'dataset_filenames' list for missing or null entries."
471
+ raise DatasetManagerError(
472
+ "Invalid dataset entry: 'None' received.",
473
+ "Each dataset must be a valid path string (e.g., './Dataset/', './Dataset/:mha, "
474
+ "'./Dataset/:a:mha', './Dataset/:i:mha').",
475
+ "Please check your 'dataset_filenames' list for missing or null entries.",
344
476
  )
345
477
  if len(dataset_filename.split(":")) == 1:
346
478
  filename = dataset_filename
347
- format = "mha"
479
+ file_format = "mha"
348
480
  append = True
349
481
  elif len(dataset_filename.split(":")) == 2:
350
- filename, format = dataset_filename.split(":")
482
+ filename, file_format = dataset_filename.split(":")
351
483
  append = True
352
484
  else:
353
- filename, flag, format = dataset_filename.split(":")
485
+ filename, flag, file_format = dataset_filename.split(":")
354
486
  append = flag == "a"
355
487
 
356
- if format not in SUPPORTED_EXTENSIONS:
357
- raise DatasetManagerError(f"Unsupported file format '{format}'.",
358
- f"Supported extensions are: {', '.join(SUPPORTED_EXTENSIONS)}")
359
-
360
- dataset = Dataset(filename, format)
488
+ if file_format not in SUPPORTED_EXTENSIONS:
489
+ raise DatasetManagerError(
490
+ f"Unsupported file format '{file_format}'.",
491
+ f"Supported extensions are: {', '.join(SUPPORTED_EXTENSIONS)}",
492
+ )
493
+
494
+ dataset = Dataset(filename, file_format)
361
495
 
362
496
  self.datasets[filename] = dataset
363
497
  for group in self.groups_src:
364
- if dataset.isGroupExist(group):
498
+ if dataset.is_group_exist(group):
365
499
  if group in datasets:
366
- datasets[group].append((filename, append))
500
+ datasets[group].append((filename, append))
367
501
  else:
368
502
  datasets[group] = [(filename, append)]
369
- modelHaveInput = False
503
+ model_have_input = False
370
504
  for group_src in self.groups_src:
371
505
  if group_src not in datasets:
372
-
506
+
373
507
  raise DatasetManagerError(
374
508
  f"Group source '{group_src}' not found in any dataset.",
375
509
  f"Dataset filenames provided: {self.dataset_filenames}",
376
- "Available groups across all datasets: {}".format(["{} {}".format(f, d.getGroup()) for f, d in self.datasets.items()]),
377
- f"Please check that an entry in the dataset with the name '{group_src}.{format}' exists."
510
+ f"Available groups across all datasets: "
511
+ f"{[f'{f} {d.get_group()}' for f, d in self.datasets.items()]}\n"
512
+ f"Please check that an entry in the dataset with the name '{group_src}' exists.",
378
513
  )
379
-
514
+
380
515
  for group_dest in self.groups_src[group_src]:
381
- self.groups_src[group_src][group_dest].load(group_src, group_dest, [self.datasets[filename] for filename, _ in datasets[group_src]])
382
- modelHaveInput |= self.groups_src[group_src][group_dest].isInput
516
+ self.groups_src[group_src][group_dest].load(
517
+ group_src,
518
+ group_dest,
519
+ [self.datasets[filename] for filename, _ in datasets[group_src]],
520
+ )
521
+ model_have_input |= self.groups_src[group_src][group_dest].is_input
522
+ if self.patch is not None:
523
+ self.patch.init()
383
524
 
384
- if not modelHaveInput:
525
+ if not model_have_input:
385
526
  raise DatasetManagerError(
386
- "At least one group must be defined with 'isInput: true' to provide input to the network."
527
+ "At least one group must be defined with 'is_input: true' to provide input to the network."
387
528
  )
388
529
 
389
- for key, dataAugmentations in self.dataAugmentationsList.items():
390
- dataAugmentations.load(key, [self.datasets[filename] for filename, _ in datasets[group_src]])
530
+ for key, data_augmentations in self.data_augmentations_list.items():
531
+ data_augmentations.load(key, [self.datasets[filename] for filename, _ in datasets[group_src]])
391
532
 
392
- names = set()
393
- dataset_name : dict[str, dict[str, list[str]]] = {}
394
- dataset_info : dict[str, dict[str, dict[str, Attribute]]] = {}
533
+ names: set[str] = set()
534
+ dataset_name: dict[str, dict[str, list[str]]] = {}
535
+ dataset_info: dict[str, dict[str, dict[str, tuple[list[int], Attribute]]]] = {}
395
536
  for group in self.groups_src:
396
- namesByGroup = set()
537
+ names_by_group = set()
397
538
  if group not in dataset_name:
398
539
  dataset_name[group] = {}
399
540
  dataset_info[group] = {}
400
541
  for filename, _ in datasets[group]:
401
- namesByGroup.update(self.datasets[filename].getNames(group))
402
- dataset_name[group][filename] = self.datasets[filename].getNames(group)
403
- dataset_info[group][filename] = {name: self.datasets[filename].getInfos(group, name) for name in dataset_name[group][filename]}
542
+ names_by_group.update(self.datasets[filename].get_names(group))
543
+ dataset_name[group][filename] = self.datasets[filename].get_names(group)
544
+ dataset_info[group][filename] = {
545
+ name: self.datasets[filename].get_infos(group, name) for name in dataset_name[group][filename]
546
+ }
404
547
  if len(names) == 0:
405
- names.update(namesByGroup)
406
- else:
407
- names = names.intersection(namesByGroup)
548
+ names.update(names_by_group)
549
+ else:
550
+ names = names.intersection(names_by_group)
408
551
  if len(names) == 0:
409
- raise DatasetManagerError(
410
- f"No data was found for groups {list(self.groups_src.keys())}: although each group contains data from a dataset, there are no common dataset names shared across all groups, the intersection is empty."
411
- )
412
-
413
- subset_names = set()
552
+ raise DatasetManagerError(
553
+ f"No data was found for groups {list(self.groups_src.keys())}: although each group contains data "
554
+ "from a dataset, there are no common dataset names shared across all groups, the intersection is empty."
555
+ )
556
+
557
+ subset_names: set[str] = set()
414
558
  for group in dataset_name:
415
- subset_names_bygroup = set()
559
+ subset_names_bygroup: set[str] = set()
416
560
  for filename, append in datasets[group]:
417
561
  if append:
418
- subset_names_bygroup.update(self.subset([dataset_name[group][filename]], [dataset_info[group][filename]]))
562
+ subset_names_bygroup.update(
563
+ self.subset(
564
+ dataset_name[group][filename],
565
+ dataset_info[group][filename],
566
+ )
567
+ )
419
568
  else:
420
569
  if len(subset_names_bygroup) == 0:
421
- subset_names_bygroup.update(self.subset([dataset_name[group][filename]], [dataset_info[group][filename]]))
570
+ subset_names_bygroup.update(
571
+ self.subset(
572
+ dataset_name[group][filename],
573
+ dataset_info[group][filename],
574
+ )
575
+ )
422
576
  else:
423
- subset_names_bygroup = subset_names_bygroup.intersection(self.subset([dataset_name[group][filename]], [dataset_info[group][filename]]))
577
+ subset_names_bygroup = subset_names_bygroup.intersection(
578
+ self.subset(
579
+ dataset_name[group][filename],
580
+ dataset_info[group][filename],
581
+ )
582
+ )
424
583
  if len(subset_names) == 0:
425
584
  subset_names.update(subset_names_bygroup)
426
- else:
585
+ else:
427
586
  subset_names = subset_names.intersection(subset_names_bygroup)
587
+
428
588
  if len(subset_names) == 0:
429
- raise DatasetManagerError("All data entries were excluded by the subset filter.",
589
+ raise DatasetManagerError(
590
+ "All data entries were excluded by the subset filter.",
430
591
  f"Dataset entries found: {', '.join(names)}",
431
592
  f"Subset object applied: {self.subset}",
432
593
  f"Subset requested : {', '.join(subset_names)}",
@@ -436,31 +597,29 @@ class Data(ABC):
436
597
  "\tsubset: [0, 1] # explicit indices",
437
598
  "\tsubset: 0:10 # slice notation",
438
599
  "\tsubset: ./Validation.txt # external file",
439
- "\tsubset: None # to disable filtering"
600
+ "\tsubset: None # to disable filtering",
440
601
  )
441
-
442
- data, map = self._getDatasets(list(subset_names), dataset_name)
443
602
 
444
- train_map = map
445
- validate_map = []
603
+ data, mapping = self._get_datasets(list(subset_names), dataset_name)
604
+
605
+ index = []
446
606
  if isinstance(self.validation, float) or isinstance(self.validation, int):
447
607
  if self.validation <= 0 or self.validation >= 1:
448
- raise DatasetManagerError("Validation must be a float between 0 and 1.", f"Received: {self.validation}", "Example: validation = 0.2 # for a 20% validation split")
449
-
450
- train_map, validate_map = map[:int(math.floor(len(map)*(1-self.validation)))], map[int(math.floor(len(map)*(1-self.validation))):]
608
+ raise DatasetManagerError(
609
+ "Validation must be a float between 0 and 1.",
610
+ f"Received: {self.validation}",
611
+ "Example: validation = 0.2 # for a 20% validation split",
612
+ )
613
+ index = [m[0] for m in mapping[int(math.floor(len(mapping) * (1 - self.validation))) :]]
451
614
  elif isinstance(self.validation, str):
452
615
  if ":" in self.validation:
453
- index = list(range(int(self.subset.split(":")[0]), int(self.subset.split(":")[1])))
454
- train_map = [m for m in map if m[0] not in index]
455
- validate_map = [m for m in map if m[0] in index]
616
+ index = list(range(int(self.validation.split(":")[0]), int(self.validation.split(":")[1])))
456
617
  elif os.path.exists(self.validation):
457
618
  validation_names = []
458
- with open(self.validation, "r") as f:
619
+ with open(self.validation) as f:
459
620
  for name in f:
460
621
  validation_names.append(name.strip())
461
622
  index = [i for i, n in enumerate(subset_names) if n in validation_names]
462
- train_map = [m for m in map if m[0] not in index]
463
- validate_map = [m for m in map if m[0] in index]
464
623
  else:
465
624
  raise DatasetManagerError(
466
625
  f"Invalid string value for 'validation': '{self.validation}'",
@@ -470,94 +629,152 @@ class Data(ABC):
470
629
  "\t• A float between 0 and 1 (e.g., 0.2)",
471
630
  "\t• A list of sample names or indices",
472
631
  "The provided value is neither a valid slice nor a readable file.",
473
- "Please fix your 'validation' setting in the configuration."
474
- )
475
-
632
+ "Please fix your 'validation' setting in the configuration.",
633
+ )
476
634
  elif isinstance(self.validation, list):
477
- if len(self.validation) > 0:
478
- if isinstance(self.validation[0], int):
479
- train_map = [m for m in map if m[0] not in self.validation]
480
- validate_map = [m for m in map if m[0] in self.validation]
481
- elif isinstance(self.validation[0], str):
482
- index = [i for i, n in enumerate(subset_names) if n in self.validation]
483
- train_map = [m for m in map if m[0] not in index]
484
- validate_map = [m for m in map if m[0] in index]
485
- else:
486
- raise DatasetManagerError(f"Invalid list type for 'validation': elements of type '{type(self.validation[0]).__name__}' are not supported.",
487
- "Supported list element types are:",
488
- "\t• int → list of indices (e.g., [0, 1, 2])",
489
- "\t• str → list of sample names (e.g., ['patient01', 'patient02'])",
490
- f"Received list: {self.validation}"
491
- )
492
- if len(train_map) == 0:
493
- raise DatasetManagerError("No data left for training after applying the validation split.",
494
- f"Dataset size: {len(map)}",
635
+ if isinstance(self.validation[0], int):
636
+ index = cast(list[int], self.validation)
637
+ elif isinstance(self.validation[0], str):
638
+ index = [i for i, n in enumerate(subset_names) if n in self.validation]
639
+ else:
640
+ raise DatasetManagerError(
641
+ "Invalid list type for 'validation': elements of type "
642
+ f"'{type(self.validation[0]).__name__}' are not supported.",
643
+ "Supported list element types are:",
644
+ "\t• int → list of indices (e.g., [0, 1, 2])",
645
+ "\t• str → list of sample names (e.g., ['patient01', 'patient02'])",
646
+ f"Received list: {self.validation}",
647
+ )
648
+ train_mapping = [m for m in mapping if m[0] not in index]
649
+ validate_mapping = [m for m in mapping if m[0] in index]
650
+
651
+ if len(train_mapping) == 0:
652
+ raise DatasetManagerError(
653
+ "No data left for training after applying the validation split.",
654
+ f"Dataset size: {len(mapping)}",
495
655
  f"Validation setting: {self.validation}",
496
- "Please reduce the validation size, increase the dataset, or disable validation."
656
+ "Please reduce the validation size, increase the dataset, or disable validation.",
497
657
  )
498
658
 
499
- if self.validation is not None and len(validate_map) == 0:
500
- raise DatasetManagerError("No data left for validation after applying the validation split.",
501
- f"Dataset size: {len(map)}",
659
+ if self.validation is not None and len(validate_mapping) == 0:
660
+ raise DatasetManagerError(
661
+ "No data left for validation after applying the validation split.",
662
+ f"Dataset size: {len(mapping)}",
502
663
  f"Validation setting: {self.validation}",
503
- "Please increase the validation size, increase the dataset, or disable validation."
664
+ "Please increase the validation size, increase the dataset, or disable validation.",
504
665
  )
505
- train_maps = Data._split(train_map, world_size)
506
- validate_maps = Data._split(validate_map, world_size)
507
-
508
- for i, (train_map, validate_map) in enumerate(zip(train_maps, validate_maps)):
509
- maps = [train_map]
510
- if len(validate_map):
511
- maps += [validate_map]
666
+
667
+ validation_names = [name for i, name in enumerate(subset_names) if i in index]
668
+ train_names = [name for name in subset_names if name not in validation_names]
669
+ train_mappings = Data._split(train_mapping, world_size)
670
+ validate_mappings = Data._split(validate_mapping, world_size)
671
+
672
+ for i, (train_mapping, validate_mapping) in enumerate(zip(train_mappings, validate_mappings)):
673
+ mappings = [train_mapping]
674
+ if len(validate_mapping):
675
+ mappings += [validate_mapping]
512
676
  self.data.append([])
513
- self.map.append([])
514
- for map_tmp in maps:
515
- indexs = np.unique(np.asarray(map_tmp)[:, 0])
516
- self.data[i].append({k:[v[it] for it in indexs] for k, v in data.items()})
517
- map_tmp_array = np.asarray(map_tmp)
677
+ self.mapping.append([])
678
+ for mapping_tmp in mappings:
679
+ indexs = np.unique(np.asarray(mapping_tmp)[:, 0])
680
+ self.data[i].append({k: [v[it] for it in indexs] for k, v in data.items()})
681
+ mapping_tmp_array = np.asarray(mapping_tmp)
518
682
  for a, b in enumerate(indexs):
519
- map_tmp_array[np.where(np.asarray(map_tmp_array)[:, 0] == b), 0] = a
520
- self.map[i].append([(a,b,c) for a,b,c in map_tmp_array])
683
+ mapping_tmp_array[np.where(np.asarray(mapping_tmp_array)[:, 0] == b), 0] = a
684
+ self.mapping[i].append([(a, b, c) for a, b, c in mapping_tmp_array])
685
+
686
+ data_loaders: list[list[DataLoader]] = []
687
+ for i, (datas, mappings) in enumerate(zip(self.data, self.mapping)):
688
+ data_loaders.append([])
689
+ for data, mapping in zip(datas, mappings):
690
+ data_loaders[i].append(
691
+ DataLoader(
692
+ dataset=self.datasetIter(
693
+ rank=i,
694
+ data=data,
695
+ mapping=mapping,
696
+ ),
697
+ sampler=CustomSampler(len(mapping), self.subset.shuffle),
698
+ batch_size=self.batch_size,
699
+ **self.dataLoader_args,
700
+ )
701
+ )
702
+ return data_loaders, train_names, validation_names
521
703
 
522
- dataLoaders: list[list[DataLoader]] = []
523
- for i, (datas, maps) in enumerate(zip(self.data, self.map)):
524
- dataLoaders.append([])
525
- for data, map in zip(datas, maps):
526
- dataLoaders[i].append(DataLoader(dataset=DatasetIter(rank=i, data=data, map=map, **self.dataSet_args), sampler=CustomSampler(len(map), self.subset.shuffle), batch_size=self.batch_size,**self.dataLoader_args))
527
- return dataLoaders
528
704
 
529
705
  class DataTrain(Data):
530
706
 
531
707
  @config("Dataset")
532
- def __init__(self, dataset_filenames : list[str] = ["default:./Dataset"],
533
- groups_src : dict[str, Group] = {"default:group_src" : Group()},
534
- augmentations : Union[dict[str, DataAugmentationsList], None] = {"DataAugmentation_0" : DataAugmentationsList()},
535
- inlineAugmentations: bool = False,
536
- patch : Union[DatasetPatch, None] = DatasetPatch(),
537
- use_cache : bool = True,
538
- subset : Union[TrainSubset, dict[str, TrainSubset]] = TrainSubset(),
539
- batch_size : int = 1,
540
- validation : Union[float, str, list[int], list[str]] = 0.2) -> None:
541
- super().__init__(dataset_filenames, groups_src, patch, use_cache, subset, batch_size, validation, inlineAugmentations, augmentations if augmentations else {})
708
+ def __init__(
709
+ self,
710
+ dataset_filenames: list[str] = ["default:./Dataset"],
711
+ groups_src: dict[str, Group] = {"default:group_src": Group()},
712
+ augmentations: dict[str, DataAugmentationsList] | None = {"DataAugmentation_0": DataAugmentationsList()},
713
+ inline_augmentations: bool = False,
714
+ patch: DatasetPatch | None = DatasetPatch(),
715
+ use_cache: bool = True,
716
+ subset: TrainSubset = TrainSubset(),
717
+ batch_size: int = 1,
718
+ validation: float | str | list[int] | list[str] = 0.2,
719
+ ) -> None:
720
+ super().__init__(
721
+ dataset_filenames,
722
+ groups_src,
723
+ patch,
724
+ use_cache,
725
+ subset,
726
+ batch_size,
727
+ validation,
728
+ inline_augmentations,
729
+ augmentations if augmentations else {},
730
+ )
731
+
542
732
 
543
733
  class DataPrediction(Data):
544
734
 
545
735
  @config("Dataset")
546
- def __init__(self, dataset_filenames : list[str] = ["default:./Dataset"],
547
- groups_src : dict[str, Group] = {"default" : Group()},
548
- augmentations : Union[dict[str, DataAugmentationsList], None] = {"DataAugmentation_0" : DataAugmentationsList()},
549
- patch : Union[DatasetPatch, None] = DatasetPatch(),
550
- subset : Union[PredictionSubset, dict[str, PredictionSubset]] = PredictionSubset(),
551
- batch_size : int = 1) -> None:
736
+ def __init__(
737
+ self,
738
+ dataset_filenames: list[str] = ["default:./Dataset"],
739
+ groups_src: dict[str, Group] = {"default": Group()},
740
+ augmentations: dict[str, DataAugmentationsList] | None = {"DataAugmentation_0": DataAugmentationsList()},
741
+ patch: DatasetPatch | None = DatasetPatch(),
742
+ subset: PredictionSubset = PredictionSubset(),
743
+ batch_size: int = 1,
744
+ ) -> None:
745
+
746
+ super().__init__(
747
+ dataset_filenames=dataset_filenames,
748
+ groups_src=groups_src,
749
+ patch=patch,
750
+ use_cache=False,
751
+ subset=subset,
752
+ batch_size=batch_size,
753
+ validation=None,
754
+ inline_augmentations=False,
755
+ data_augmentations_list=augmentations if augmentations else {},
756
+ )
552
757
 
553
- super().__init__(dataset_filenames, groups_src, patch, False, subset, batch_size, dataAugmentationsList=augmentations if augmentations else {})
554
758
 
555
759
  class DataMetric(Data):
556
760
 
557
761
  @config("Dataset")
558
- def __init__(self, dataset_filenames : list[str] = ["default:./Dataset"],
559
- groups_src : dict[str, GroupMetric] = {"default" : GroupMetric()},
560
- subset : Union[PredictionSubset, dict[str, PredictionSubset]] = PredictionSubset(),
561
- validation: Union[str, None] = None) -> None:
562
-
563
- super().__init__(dataset_filenames=dataset_filenames, groups_src=groups_src, patch=None, use_cache=False, subset=subset, batch_size=1, validation=validation)
762
+ def __init__(
763
+ self,
764
+ dataset_filenames: list[str] = ["default:./Dataset"],
765
+ groups_src: dict[str, GroupMetric] = {"default": GroupMetric()},
766
+ subset: PredictionSubset = PredictionSubset(),
767
+ validation: str | None = None,
768
+ ) -> None:
769
+
770
+ super().__init__(
771
+ dataset_filenames=dataset_filenames,
772
+ groups_src=groups_src,
773
+ patch=None,
774
+ use_cache=True,
775
+ subset=subset,
776
+ batch_size=1,
777
+ validation=validation,
778
+ data_augmentations_list={},
779
+ inline_augmentations=False,
780
+ )
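
For readers following the reworked subset handling above: in 1.2.0 a subset entry can be an integer index, a "start:end" slice, a path to a file of names to keep, or "~path" pointing to a file of names to exclude (see Subset._get_index in the diff). Below is a minimal, self-contained sketch of those selection rules, assuming they behave as written above; select_indices is a hypothetical helper for illustration, not part of konfai.

    import os

    # Simplified re-statement of the selection rules in Subset._get_index (1.2.0);
    # a sketch for illustration, not the konfai implementation itself.
    def select_indices(subset: int | str, names: list[str]) -> list[int]:
        size = len(names)
        if isinstance(subset, int):          # single explicit index
            return [subset]
        if ":" in subset:                    # "start:end" slice, clipped to the dataset size
            start, end = (min(max(int(v), 0), size) for v in subset.split(":"))
            return list(range(start, end))
        if os.path.exists(subset):           # file listing names to keep
            with open(subset) as f:
                keep = {line.strip() for line in f}
            return [i for i, name in enumerate(names) if name in keep]
        if subset.startswith("~") and os.path.exists(subset[1:]):  # file listing names to exclude
            with open(subset[1:]) as f:
                drop = {line.strip() for line in f}
            return [i for i, name in enumerate(names) if name not in drop]
        return []                            # nothing matched

    # Example: keep only the first three entries of a name list.
    # select_indices("0:3", ["patient01", "patient02", "patient03", "patient04"]) -> [0, 1, 2]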