konfai 1.1.7-py3-none-any.whl → 1.1.9-py3-none-any.whl

This diff shows the changes between two publicly released versions of the package, as they appear in their respective public registries. It is provided for informational purposes only.

Potentially problematic release: this version of konfai might be problematic.

Files changed (36)
  1. konfai/__init__.py +59 -14
  2. konfai/data/augmentation.py +457 -286
  3. konfai/data/data_manager.py +509 -290
  4. konfai/data/patching.py +300 -183
  5. konfai/data/transform.py +384 -277
  6. konfai/evaluator.py +309 -68
  7. konfai/main.py +71 -22
  8. konfai/metric/measure.py +341 -222
  9. konfai/metric/schedulers.py +24 -13
  10. konfai/models/classification/convNeXt.py +187 -81
  11. konfai/models/classification/resnet.py +272 -58
  12. konfai/models/generation/cStyleGan.py +233 -59
  13. konfai/models/generation/ddpm.py +348 -121
  14. konfai/models/generation/diffusionGan.py +757 -358
  15. konfai/models/generation/gan.py +177 -53
  16. konfai/models/generation/vae.py +140 -40
  17. konfai/models/registration/registration.py +135 -52
  18. konfai/models/representation/representation.py +57 -23
  19. konfai/models/segmentation/NestedUNet.py +339 -68
  20. konfai/models/segmentation/UNet.py +140 -30
  21. konfai/network/blocks.py +331 -187
  22. konfai/network/network.py +781 -423
  23. konfai/predictor.py +645 -240
  24. konfai/trainer.py +527 -216
  25. konfai/utils/ITK.py +191 -106
  26. konfai/utils/config.py +152 -95
  27. konfai/utils/dataset.py +326 -455
  28. konfai/utils/utils.py +495 -249
  29. {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/METADATA +1 -3
  30. konfai-1.1.9.dist-info/RECORD +38 -0
  31. konfai/utils/registration.py +0 -199
  32. konfai-1.1.7.dist-info/RECORD +0 -39
  33. {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/WHEEL +0 -0
  34. {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/entry_points.txt +0 -0
  35. {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/licenses/LICENSE +0 -0
  36. {konfai-1.1.7.dist-info → konfai-1.1.9.dist-info}/top_level.txt +0 -0
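The bulk of the diff below renames konfai's helpers from camelCase to snake_case (for example getMemory → get_memory, isInput → is_input, KONFAI_ROOT() → konfai_root()), alongside a switch from typing.Union annotations to the | syntax. As a hedged illustration only, downstream imports would migrate roughly as follows; the names are taken verbatim from the removed and added import lines in the diff, and the exact exports should be verified against the installed 1.1.9 wheel:

    # konfai 1.1.7 (removed import lines below): camelCase helpers
    from konfai import KONFAI_ROOT, KONFAI_STATE
    from konfai.utils.utils import cpuInfo, getMemory, memoryForecast, memoryInfo

    # konfai 1.1.9 (added import lines below): snake_case equivalents
    from konfai import konfai_root, konfai_state
    from konfai.utils.utils import get_cpu_info, get_memory, get_memory_info, memory_forecast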
@@ -1,85 +1,114 @@
  import math
  import os
  import random
+ import threading
+ from abc import ABC, abstractmethod
+ from collections.abc import Iterator, Mapping
+ from concurrent.futures import ThreadPoolExecutor, as_completed
+ from functools import partial
+
+ import numpy as np
  import torch
- from torch.utils import data
  import tqdm
- import numpy as np
- from abc import ABC
- from torch.utils.data import DataLoader, Sampler
- from typing import Union, Iterator
- from concurrent.futures import ThreadPoolExecutor, as_completed
- import threading
  from torch.cuda import device_count
- import SimpleITK as sitk
+ from torch.utils import data
+ from torch.utils.data import DataLoader, Sampler

- from konfai import KONFAI_STATE, KONFAI_ROOT
- from konfai.data.patching import DatasetPatch, DatasetManager
- from konfai.utils.config import config
- from konfai.utils.utils import memoryInfo, cpuInfo, memoryForecast, getMemory, State, SUPPORTED_EXTENSIONS, DatasetManagerError
- from konfai.utils.dataset import Dataset, Attribute
- from konfai.data.transform import TransformLoader, Transform
+ from konfai import konfai_root, konfai_state
  from konfai.data.augmentation import DataAugmentationsList
+ from konfai.data.patching import DatasetManager, DatasetPatch
+ from konfai.data.transform import Transform, TransformLoader
+ from konfai.utils.config import config
+ from konfai.utils.dataset import Attribute, Dataset
+ from konfai.utils.utils import (
+ SUPPORTED_EXTENSIONS,
+ DatasetManagerError,
+ State,
+ get_cpu_info,
+ get_memory,
+ get_memory_info,
+ memory_forecast,
+ )
+

  class GroupTransform:

  @config()
- def __init__(self, transforms : Union[dict[str, TransformLoader], list[Transform]] = {"default:Normalize:Standardize:Unsqueeze:TensorCast:ResampleIsotropic:ResampleResize": TransformLoader()},
- patch_transforms : Union[dict[str, TransformLoader], list[Transform]] = {"default:Normalize:Standardize:Unsqueeze:TensorCast:ResampleIsotropic:ResampleResize": TransformLoader()},
- isInput: bool = True) -> None:
- self._pre_transforms = transforms
- self._post_transforms = patch_transforms
- self.pre_transforms : list[Transform] = []
- self.post_transforms : list[Transform] = []
- self.isInput = isInput
-
- def load(self, group_src : str, group_dest : str, datasets: list[Dataset]):
- if self._pre_transforms is not None:
- if isinstance(self._pre_transforms, dict):
- for classpath, transform in self._pre_transforms.items():
- transform = transform.getTransform(classpath, DL_args = "{}.Dataset.groups_src.{}.groups_dest.{}.transforms".format(KONFAI_ROOT(), group_src, group_dest))
- transform.setDatasets(datasets)
- self.pre_transforms.append(transform)
- else:
- for transform in self._pre_transforms:
- transform.setDatasets(datasets)
- self.pre_transforms.append(transform)
-
- if self._post_transforms is not None:
- if isinstance(self._post_transforms, dict):
- for classpath, transform in self._post_transforms.items():
- transform = transform.getTransform(classpath, DL_args = "{}.Dataset.groups_src.{}.groups_dest.{}.patch_transforms".format(KONFAI_ROOT(), group_src, group_dest))
- transform.setDatasets(datasets)
- self.post_transforms.append(transform)
- else:
- for transform in self._post_transforms:
- transform.setDatasets(datasets)
- self.post_transforms.append(transform)
-
+ def __init__(
+ self,
+ transforms: dict[str, TransformLoader] = {
+ "default:Normalize:Standardize:Unsqueeze:TensorCast:ResampleIsotropic:ResampleResize": TransformLoader()
+ },
+ patch_transforms: dict[str, TransformLoader] = {
+ "default:Normalize:Standardize:Unsqueeze:TensorCast:ResampleIsotropic:ResampleResize": TransformLoader()
+ },
+ is_input: bool = True,
+ ) -> None:
+ self._transforms = transforms
+ self._patch_transforms = patch_transforms
+ self.transforms: list[Transform] = []
+ self.patch_transforms: list[Transform] = []
+ self.is_input = is_input
+
+ def load(self, group_src: str, group_dest: str, datasets: list[Dataset]):
+ if self._transforms is not None:
+ for classpath, transform_loader in self._transforms.items():
+ transform = transform_loader.get_transform(
+ classpath,
+ konfai_args=f"{konfai_root()}.Dataset.groups_src.{group_src}.groups_dest.{group_dest}.transforms",
+ )
+ transform.set_datasets(datasets)
+ self.transforms.append(transform)
+
+ if self._patch_transforms is not None:
+ for classpath, transform_loader in self._patch_transforms.items():
+ transform = transform_loader.get_transform(
+ classpath,
+ konfai_args=f"{konfai_root()}.Dataset.groups_src.{group_src}"
+ f".groups_dest.{group_dest}.patch_transforms",
+ )
+ transform.set_datasets(datasets)
+ self.patch_transforms.append(transform)
+
  def to(self, device: int):
- for transform in self.pre_transforms:
- transform.setDevice(device)
- for transform in self.post_transforms:
- transform.setDevice(device)
+ for transform in self.transforms:
+ transform.to(device)
+ for transform in self.patch_transforms:
+ transform.to(device)
+

  class GroupTransformMetric(GroupTransform):

  @config()
- def __init__(self, transforms : Union[dict[str, TransformLoader], list[Transform]] = {"default:Normalize:Standardize:Unsqueeze:TensorCast:ResampleIsotropic:ResampleResize": TransformLoader()}):
+ def __init__(
+ self,
+ transforms: dict[str, TransformLoader] = {
+ "default:Normalize:Standardize:Unsqueeze:TensorCast:ResampleIsotropic:ResampleResize": TransformLoader()
+ },
+ ):
  super().__init__(transforms, None)

+
  class Group(dict[str, GroupTransform]):

  @config()
- def __init__(self, groups_dest: dict[str, GroupTransform] = {"default:group_dest": GroupTransform()}):
+ def __init__(
+ self,
+ groups_dest: dict[str, GroupTransform] = {"default:group_dest": GroupTransform()},
+ ):
  super().__init__(groups_dest)

+
  class GroupMetric(dict[str, GroupTransformMetric]):

  @config()
- def __init__(self, groups_dest: dict[str, GroupTransformMetric] = {"default:group_dest": GroupTransformMetric()}):
+ def __init__(
+ self,
+ groups_dest: dict[str, GroupTransformMetric] = {"default:group_dest": GroupTransformMetric()},
+ ):
  super().__init__(groups_dest)

+
  class CustomSampler(Sampler[int]):

  def __init__(self, size: int, shuffle: bool = False) -> None:
@@ -87,135 +116,178 @@ class CustomSampler(Sampler[int]):
  self.shuffle = shuffle

  def __iter__(self) -> Iterator[int]:
- return iter(torch.randperm(len(self)).tolist() if self.shuffle else list(range(len(self))) )
+ return iter(torch.randperm(len(self)).tolist() if self.shuffle else list(range(len(self))))

  def __len__(self) -> int:
  return self.size

+
  class DatasetIter(data.Dataset):

- def __init__(self, rank: int, data : dict[str, list[DatasetManager]], map: dict[int, tuple[int, int, int]], groups_src : dict[str, Group], inlineAugmentations: bool, dataAugmentationsList : list[DataAugmentationsList], patch_size: Union[list[int], None], overlap: Union[int, None], buffer_size: int, use_cache = True) -> None:
+ def __init__(
+ self,
+ rank: int,
+ data: dict[str, list[DatasetManager]],
+ mapping: list[tuple[int, int, int]],
+ groups_src: Mapping[str, Group | GroupMetric],
+ inline_augmentations: bool,
+ data_augmentations_list: list[DataAugmentationsList],
+ patch_size: list[int] | None,
+ overlap: int | None,
+ buffer_size: int,
+ use_cache=True,
+ ) -> None:
  self.rank = rank
  self.data = data
- self.map = map
+ self.mapping = mapping
  self.patch_size = patch_size
  self.overlap = overlap
  self.groups_src = groups_src
- self.dataAugmentationsList = dataAugmentationsList
+ self.data_augmentations_list = data_augmentations_list
  self.use_cache = use_cache
  self.nb_dataset = len(data[list(data.keys())[0]])
  self.buffer_size = buffer_size
- self._index_cache = list()
- self.device = None
- self.inlineAugmentations = inlineAugmentations
+ self._index_cache: list[int] = []
+ self.inline_augmentations = inline_augmentations

- def getPatchConfig(self) -> tuple[list[int], int]:
+ def get_patch_config(self) -> tuple[list[int] | None, int | None]:
  return self.patch_size, self.overlap
-
+
  def to(self, device: int):
  for group_src in self.groups_src:
  for group_dest in self.groups_src[group_src]:
  self.groups_src[group_src][group_dest].to(device)
- self.device = device
+ for data_augmentations in self.data_augmentations_list:
+ for data_augmentation in data_augmentations.data_augmentations:
+ data_augmentation.to(device)

- def getDatasetFromIndex(self, group_dest: str, index: int) -> DatasetManager:
+ def get_dataset_from_index(self, group_dest: str, index: int) -> DatasetManager:
  return self.data[group_dest][index]
-
- def resetAugmentation(self, label):
- if self.inlineAugmentations and len(self.dataAugmentationsList) > 0:
+
+ def reset_augmentation(self, label):
+ if self.inline_augmentations and len(self.data_augmentations_list) > 0:
  for index in range(self.nb_dataset):
  for group_src in self.groups_src:
  for group_dest in self.groups_src[group_src]:
- self.data[group_dest][index].unloadAugmentation()
- self.data[group_dest][index].resetAugmentation()
+ self.data[group_dest][index].unload_augmentation()
+ self.data[group_dest][index].reset_augmentation()
  self.load(label + " Augmentation")

  def load(self, label: str):
  if self.use_cache:
- memory_init = getMemory()
+ memory_init = get_memory()

- indexs = [index for index in range(self.nb_dataset)]
+ indexs = list(range(self.nb_dataset))
  if len(indexs) > 0:
  memory_lock = threading.Lock()
- desc = lambda : "Caching "+ label +": {} | {} | {}".format(memoryInfo(), memoryForecast(memory_init, 0, self.nb_dataset), cpuInfo())
- pbar = tqdm.tqdm(
- total=len(indexs),
- desc=desc(),
- leave=False
- )
+
+ def desc():
+ return (
+ f"Caching {label}: "
+ f"{get_memory_info()} | "
+ f"{memory_forecast(memory_init, 0, self.nb_dataset)} | "
+ f"{get_cpu_info()}"
+ )
+
+ pbar = tqdm.tqdm(total=len(indexs), desc=desc(), leave=False)

  def process(index):
- self._loadData(index)
+ self._load_data(index)
  with memory_lock:
  pbar.set_description(desc())
  pbar.update(1)
- with ThreadPoolExecutor(max_workers=os.cpu_count()//(device_count() if device_count() > 0 else 1)) as executor:
+
+ cpu_count = os.cpu_count() or 1
+ with ThreadPoolExecutor(
+ max_workers=cpu_count // (device_count() if device_count() > 0 else 1)
+ ) as executor:
  futures = [executor.submit(process, index) for index in indexs]
  for _ in as_completed(futures):
  pass

  pbar.close()
-
- def _loadData(self, index):
+
+ def _load_data(self, index):
  if index not in self._index_cache:
  self._index_cache.append(index)
  for group_src in self.groups_src:
  for group_dest in self.groups_src[group_src]:
- self.loadData(group_src, group_dest, index)
+ self.load_data(group_src, group_dest, index)

- def loadData(self, group_src: str, group_dest : str, index : int) -> None:
- self.data[group_dest][index].load(self.groups_src[group_src][group_dest].pre_transforms, self.dataAugmentationsList, self.device)
+ def load_data(self, group_src: str, group_dest: str, index: int) -> None:
+ self.data[group_dest][index].load(
+ self.groups_src[group_src][group_dest].transforms,
+ self.data_augmentations_list,
+ )

- def _unloadData(self, index : int) -> None:
+ def _unload_data(self, index: int) -> None:
  if index in self._index_cache:
  self._index_cache.remove(index)
  for group_src in self.groups_src:
  for group_dest in self.groups_src[group_src]:
- self.unloadData(group_dest, index)
-
- def unloadData(self, group_dest : str, index : int) -> None:
+ self.unload_data(group_dest, index)
+
+ def unload_data(self, group_dest: str, index: int) -> None:
  return self.data[group_dest][index].unload()

  def __len__(self) -> int:
- return len(self.map)
+ return len(self.mapping)

- def __getitem__(self, index : int) -> dict[str, tuple[torch.Tensor, int, int, int, str, bool]]:
+ def __getitem__(self, index: int) -> dict[str, tuple[torch.Tensor, int, int, int, str, bool]]:
  data = {}
- x, a, p = self.map[index]
+ x, a, p = self.mapping[index]
  if x not in self._index_cache:
  if len(self._index_cache) >= self.buffer_size and not self.use_cache:
- self._unloadData(self._index_cache[0])
- self._loadData(x)
+ self._unload_data(self._index_cache[0])
+ self._load_data(x)

  for group_src in self.groups_src:
  for group_dest in self.groups_src[group_src]:
  dataset = self.data[group_dest][x]
- data["{}".format(group_dest)] = (dataset.getData(p, a, self.groups_src[group_src][group_dest].post_transforms, self.groups_src[group_src][group_dest].isInput), x, a, p, dataset.name, self.groups_src[group_src][group_dest].isInput)
+ data[f"{group_dest}"] = (
+ dataset.get_data(
+ p,
+ a,
+ self.groups_src[group_src][group_dest].patch_transforms,
+ self.groups_src[group_src][group_dest].is_input,
+ ),
+ x,
+ a,
+ p,
+ dataset.name,
+ self.groups_src[group_src][group_dest].is_input,
+ )
  return data

- class Subset():
-
- def __init__(self, subset: Union[str, list[int], list[str], None] = None, shuffle: bool = True) -> None:
+
+ class Subset:
+
+ def __init__(
+ self,
+ subset: str | list[int] | list[str] | None = None,
+ shuffle: bool = True,
+ ) -> None:
  self.subset = subset
  self.shuffle = shuffle

- def __call__(self, names: list[str], infos: list[dict[str, tuple[np.ndarray, Attribute]]]) -> set[str]:
- inter_name = set(names[0])
- for n in names[1:]:
- inter_name = inter_name.intersection(set(n))
- names = sorted(list(inter_name))
-
+ def __call__(self, names: list[str], infos: dict[str, tuple[list[int], Attribute]]) -> set[str]:
+ names = sorted(names)
+
  size = len(names)
  index = []
  if self.subset is None:
  index = list(range(0, size))
  elif isinstance(self.subset, str):
  if ":" in self.subset:
- r = np.clip(np.asarray([int(self.subset.split(":")[0]), int(self.subset.split(":")[1])]), 0, size)
+ r = np.clip(
+ np.asarray([int(self.subset.split(":")[0]), int(self.subset.split(":")[1])]),
+ 0,
+ size,
+ )
  index = list(range(r[0], r[1]))
  elif os.path.exists(self.subset):
  train_names = []
- with open(self.subset, "r") as f:
+ with open(self.subset) as f:
  for name in f:
  train_names.append(name.strip())
  index = []
@@ -224,7 +296,7 @@ class Subset():
  index.append(i)
  elif self.subset.startswith("~") and os.path.exists(self.subset[1:]):
  exclude_names = []
- with open(self.subset[1:], "r") as f:
+ with open(self.subset[1:]) as f:
  for name in f:
  exclude_names.append(name.strip())
  index = []
@@ -233,200 +305,283 @@ class Subset():
  index.append(i)

  elif isinstance(self.subset, list):
+ index = []
  if len(self.subset) > 0:
- if isinstance(self.subset[0], int):
- if len(self.subset) == 1:
- index = list(range(self.subset[0], min(size, self.subset[0]+1)))
- else:
- index = self.subset
- if isinstance(self.subset[0], str):
- index = []
- for i, name in enumerate(names):
- if name in self.subset:
- index.append(i)
+ for s in self.subset:
+ if isinstance(s, int):
+ index.append(s)
+ elif isinstance(s, str):
+ for i, name in enumerate(names):
+ if name in self.subset:
+ index.append(i)
  if self.shuffle:
- index = random.sample(index, len(index))
- return set([names[i] for i in index])
-
+ index = random.sample(index, len(index)) # nosec B311
+ return {names[i] for i in index}
+
  def __str__(self):
- return "Subset : " + str(self.subset) + " shuffle : "+ str(self.shuffle)
-
+ return "Subset : " + str(self.subset) + " shuffle : " + str(self.shuffle)
+
+
  class TrainSubset(Subset):

  @config()
- def __init__(self, subset: Union[str, list[int], list[str], None] = None, shuffle: bool = True) -> None:
+ def __init__(
+ self,
+ subset: str | list[int] | list[str] | None = None,
+ shuffle: bool = True,
+ ) -> None:
  super().__init__(subset, shuffle)

+
  class PredictionSubset(Subset):

  @config()
- def __init__(self, subset: Union[str, list[int], list[str], None] = None) -> None:
+ def __init__(self, subset: str | list[int] | list[str] | None = None) -> None:
  super().__init__(subset, False)

+
  class Data(ABC):
-
- def __init__(self, dataset_filenames : list[str],
- groups_src : dict[str, Group],
- patch : Union[DatasetPatch, None],
- use_cache : bool,
- subset : Subset,
- batch_size : int,
- validation: Union[float, str, list[int], list[str], None] = None,
- inlineAugmentations: bool = False,
- dataAugmentationsList: dict[str, DataAugmentationsList]= {}) -> None:
+
+ @abstractmethod
+ def __init__(
+ self,
+ dataset_filenames: list[str],
+ groups_src: Mapping[str, Group | GroupMetric],
+ patch: DatasetPatch | None,
+ use_cache: bool,
+ subset: Subset,
+ batch_size: int,
+ validation: float | str | list[int] | list[str] | None,
+ inline_augmentations: bool,
+ data_augmentations_list: dict[str, DataAugmentationsList],
+ ) -> None:
  self.dataset_filenames = dataset_filenames
  self.subset = subset
  self.groups_src = groups_src
  self.patch = patch
  self.validation = validation
- self.dataAugmentationsList = dataAugmentationsList
+ self.data_augmentations_list = data_augmentations_list
  self.batch_size = batch_size
- self.dataSet_args = dict(groups_src=self.groups_src, inlineAugmentations=inlineAugmentations, dataAugmentationsList = list(self.dataAugmentationsList.values()), use_cache = use_cache, buffer_size=batch_size+1, patch_size=self.patch.patch_size if self.patch is not None else None, overlap=self.patch.overlap if self.patch is not None else None)
- self.dataLoader_args = dict(num_workers=int(os.environ["KONFAI_WORKERS"]), pin_memory=True)
- self.data : list[list[dict[str, list[DatasetManager]]], dict[str, list[DatasetManager]]] = []
- self.map : list[list[list[tuple[int, int, int]]], list[tuple[int, int, int]]] = []
+
+ self.datasetIter = partial(
+ DatasetIter,
+ groups_src=self.groups_src,
+ inline_augmentations=inline_augmentations,
+ data_augmentations_list=list(self.data_augmentations_list.values()),
+ patch_size=self.patch.patch_size if self.patch is not None else None,
+ overlap=self.patch.overlap if self.patch is not None else None,
+ buffer_size=batch_size + 1,
+ use_cache=use_cache,
+ )
+ self.dataLoader_args = {
+ "num_workers": int(os.environ["KONFAI_WORKERS"]) if use_cache else 0,
+ "pin_memory": True,
+ }
+ self.data: list[list[dict[str, list[DatasetManager]]]] = []
+ self.mapping: list[list[list[tuple[int, int, int]]]] = []
  self.datasets: dict[str, Dataset] = {}

- def _getDatasets(self, names: list[str], dataset_name: dict[str, dict[str, list[str]]]) -> tuple[dict[str, list[Dataset]], list[tuple[int, int, int]]]:
+ def _get_datasets(
+ self, names: list[str], dataset_name: dict[str, dict[str, list[str]]]
+ ) -> tuple[dict[str, list[DatasetManager]], list[tuple[int, int, int]]]:
  nb_dataset = len(names)
- nb_patch = None
+ nb_patch: list[list[int]]
  data = {}
- map = []
- nb_augmentation = np.max([int(np.sum([data_augmentation.nb for data_augmentation in self.dataAugmentationsList.values()])+1), 1])
+ mapping = []
+ nb_augmentation = np.max(
+ [
+ int(np.sum([data_augmentation.nb for data_augmentation in self.data_augmentations_list.values()]) + 1),
+ 1,
+ ]
+ )
  for group_src in self.groups_src:
  for group_dest in self.groups_src[group_src]:
- data[group_dest] = [DatasetManager(i, group_src, group_dest, name, self.datasets[[filename for filename, names in dataset_name[group_src].items() if name in names][0]], patch = self.patch, pre_transforms = self.groups_src[group_src][group_dest].pre_transforms, dataAugmentationsList=list(self.dataAugmentationsList.values())) for i, name in enumerate(names)]
- nb_patch = [[dataset.getSize(a) for a in range(nb_augmentation)] for dataset in data[group_dest]]
+ data[group_dest] = [
+ DatasetManager(
+ i,
+ group_src,
+ group_dest,
+ name,
+ self.datasets[
+ [filename for filename, names in dataset_name[group_src].items() if name in names][0]
+ ],
+ patch=self.patch,
+ transforms=self.groups_src[group_src][group_dest].transforms,
+ data_augmentations_list=list(self.data_augmentations_list.values()),
+ )
+ for i, name in enumerate(names)
+ ]
+ nb_patch = [[dataset.get_size(a) for a in range(nb_augmentation)] for dataset in data[group_dest]]

  for x in range(nb_dataset):
  for y in range(nb_augmentation):
  for z in range(nb_patch[x][y]):
- map.append((x, y, z))
- return data, map
+ mapping.append((x, y, z))
+ return data, mapping

- def getGroupsDest(self):
- groupsDest = []
+ def get_groups_dest(self):
+ groups_dest = []
  for group_src in self.groups_src:
  for group_dest in self.groups_src[group_src]:
- groupsDest.append(group_dest)
- return groupsDest
-
- def _split(map: list[tuple[int, int, int]], world_size: int) -> list[list[tuple[int, int, int]]]:
- if len(map) == 0:
+ groups_dest.append(group_dest)
+ return groups_dest
+
+ @staticmethod
+ def _split(mapping: list[tuple[int, int, int]], world_size: int) -> list[list[tuple[int, int, int]]]:
+ if len(mapping) == 0:
  return [[] for _ in range(world_size)]
-
- maps = []
- if KONFAI_STATE() == str(State.PREDICTION) or KONFAI_STATE() == str(State.EVALUATION):
- np_map = np.asarray(map)
- unique_index = np.unique(np_map[:, 0])
- offset = int(np.ceil(len(unique_index)/world_size))
+
+ mappings = []
+ if konfai_state() == str(State.PREDICTION) or konfai_state() == str(State.EVALUATION):
+ np_mapping = np.asarray(mapping)
+ unique_index = np.unique(np_mapping[:, 0])
+ offset = int(np.ceil(len(unique_index) / world_size))
  if offset == 0:
  offset = 1
  for itr in range(0, len(unique_index), offset):
- maps.append([tuple(v) for v in np_map[np.where(np.isin(np_map[:, 0], unique_index[itr:itr+offset]))[0], :]])
+ mappings.append(
+ [
+ tuple(v)
+ for v in np_mapping[
+ np.where(np.isin(np_mapping[:, 0], unique_index[itr : itr + offset]))[0],
+ :,
+ ]
+ ]
+ )
  else:
- offset = int(np.ceil(len(map)/world_size))
+ offset = int(np.ceil(len(mapping) / world_size))
  if offset == 0:
  offset = 1
- for itr in range(0, len(map), offset):
- maps.append(list(map[-offset:]) if itr+offset > len(map) else map[itr:itr+offset])
- return maps
-
- def getData(self, world_size: int) -> list[list[DataLoader]]:
- datasets: dict[str, list[(str, bool)]] = {}
+ for itr in range(0, len(mapping), offset):
+ mappings.append(list(mapping[-offset:]) if itr + offset > len(mapping) else mapping[itr : itr + offset])
+ return mappings
+
+ def get_data(self, world_size: int) -> list[list[DataLoader]]:
+ datasets: dict[str, list[tuple[str, bool]]] = {}
  if self.dataset_filenames is None or len(self.dataset_filenames) == 0:
  raise DatasetManagerError("No dataset filenames were provided")
  for dataset_filename in self.dataset_filenames:
  if dataset_filename is None:
- raise DatasetManagerError("Invalid dataset entry: 'None' received.",
- "Each dataset must be a valid path string (e.g., './Dataset/', './Dataset/:mha, './Dataset/:a:mha', './Dataset/:i:mha').",
- "Please check your 'dataset_filenames' list for missing or null entries."
+ raise DatasetManagerError(
+ "Invalid dataset entry: 'None' received.",
+ "Each dataset must be a valid path string (e.g., './Dataset/', './Dataset/:mha, "
+ "'./Dataset/:a:mha', './Dataset/:i:mha').",
+ "Please check your 'dataset_filenames' list for missing or null entries.",
  )
  if len(dataset_filename.split(":")) == 1:
  filename = dataset_filename
- format = "mha"
+ file_format = "mha"
  append = True
  elif len(dataset_filename.split(":")) == 2:
- filename, format = dataset_filename.split(":")
+ filename, file_format = dataset_filename.split(":")
  append = True
  else:
- filename, flag, format = dataset_filename.split(":")
+ filename, flag, file_format = dataset_filename.split(":")
  append = flag == "a"

- if format not in SUPPORTED_EXTENSIONS:
- raise DatasetManagerError(f"Unsupported file format '{format}'.",
- f"Supported extensions are: {', '.join(SUPPORTED_EXTENSIONS)}")
-
- dataset = Dataset(filename, format)
+ if file_format not in SUPPORTED_EXTENSIONS:
+ raise DatasetManagerError(
+ f"Unsupported file format '{file_format}'.",
+ f"Supported extensions are: {', '.join(SUPPORTED_EXTENSIONS)}",
+ )
+
+ dataset = Dataset(filename, file_format)

  self.datasets[filename] = dataset
  for group in self.groups_src:
- if dataset.isGroupExist(group):
+ if dataset.is_group_exist(group):
  if group in datasets:
- datasets[group].append((filename, append))
+ datasets[group].append((filename, append))
  else:
  datasets[group] = [(filename, append)]
- modelHaveInput = False
+ model_have_input = False
  for group_src in self.groups_src:
  if group_src not in datasets:
-
+
  raise DatasetManagerError(
  f"Group source '{group_src}' not found in any dataset.",
  f"Dataset filenames provided: {self.dataset_filenames}",
- "Available groups across all datasets: {}".format(["{} {}".format(f, d.getGroup()) for f, d in self.datasets.items()]),
- f"Please check that an entry in the dataset with the name '{group_src}.{format}' exists."
+ f"Available groups across all datasets: "
+ "{[f'{f} {d.get_group()}' for f, d in self.datasets.items()]}"
+ f"Please check that an entry in the dataset with the name '{group_src}.{format}' exists.",
  )
-
+
  for group_dest in self.groups_src[group_src]:
- self.groups_src[group_src][group_dest].load(group_src, group_dest, [self.datasets[filename] for filename, _ in datasets[group_src]])
- modelHaveInput |= self.groups_src[group_src][group_dest].isInput
+ self.groups_src[group_src][group_dest].load(
+ group_src,
+ group_dest,
+ [self.datasets[filename] for filename, _ in datasets[group_src]],
+ )
+ model_have_input |= self.groups_src[group_src][group_dest].is_input
+ if self.patch is not None:
+ self.patch.init()

- if not modelHaveInput:
+ if not model_have_input:
  raise DatasetManagerError(
- "At least one group must be defined with 'isInput: true' to provide input to the network."
+ "At least one group must be defined with 'is_input: true' to provide input to the network."
  )

- for key, dataAugmentations in self.dataAugmentationsList.items():
- dataAugmentations.load(key, [self.datasets[filename] for filename, _ in datasets[group_src]])
+ for key, data_augmentations in self.data_augmentations_list.items():
+ data_augmentations.load(key, [self.datasets[filename] for filename, _ in datasets[group_src]])

- names = set()
- dataset_name : dict[str, dict[str, list[str]]] = {}
- dataset_info : dict[str, dict[str, dict[str, Attribute]]] = {}
+ names: set[str] = set()
+ dataset_name: dict[str, dict[str, list[str]]] = {}
+ dataset_info: dict[str, dict[str, dict[str, tuple[list[int], Attribute]]]] = {}
  for group in self.groups_src:
- namesByGroup = set()
+ names_by_group = set()
  if group not in dataset_name:
  dataset_name[group] = {}
  dataset_info[group] = {}
  for filename, _ in datasets[group]:
- namesByGroup.update(self.datasets[filename].getNames(group))
- dataset_name[group][filename] = self.datasets[filename].getNames(group)
- dataset_info[group][filename] = {name: self.datasets[filename].getInfos(group, name) for name in dataset_name[group][filename]}
+ names_by_group.update(self.datasets[filename].get_names(group))
+ dataset_name[group][filename] = self.datasets[filename].get_names(group)
+ dataset_info[group][filename] = {
+ name: self.datasets[filename].get_infos(group, name) for name in dataset_name[group][filename]
+ }
  if len(names) == 0:
- names.update(namesByGroup)
- else:
- names = names.intersection(namesByGroup)
+ names.update(names_by_group)
+ else:
+ names = names.intersection(names_by_group)
  if len(names) == 0:
- raise DatasetManagerError(
- f"No data was found for groups {list(self.groups_src.keys())}: although each group contains data from a dataset, there are no common dataset names shared across all groups, the intersection is empty."
- )
-
- subset_names = set()
+ raise DatasetManagerError(
+ f"No data was found for groups {list(self.groups_src.keys())}: although each group contains data "
+ "from a dataset, there are no common dataset names shared across all groups, the intersection is empty."
+ )
+
+ subset_names: set[str] = set()
  for group in dataset_name:
- subset_names_bygroup = set()
+ subset_names_bygroup: set[str] = set()
  for filename, append in datasets[group]:
  if append:
- subset_names_bygroup.update(self.subset([dataset_name[group][filename]], [dataset_info[group][filename]]))
+ subset_names_bygroup.update(
+ self.subset(
+ dataset_name[group][filename],
+ dataset_info[group][filename],
+ )
+ )
  else:
  if len(subset_names_bygroup) == 0:
- subset_names_bygroup.update(self.subset([dataset_name[group][filename]], [dataset_info[group][filename]]))
+ subset_names_bygroup.update(
+ self.subset(
+ dataset_name[group][filename],
+ dataset_info[group][filename],
+ )
+ )
  else:
- subset_names_bygroup = subset_names_bygroup.intersection(self.subset([dataset_name[group][filename]], [dataset_info[group][filename]]))
+ subset_names_bygroup = subset_names_bygroup.intersection(
+ self.subset(
+ dataset_name[group][filename],
+ dataset_info[group][filename],
+ )
+ )
  if len(subset_names) == 0:
  subset_names.update(subset_names_bygroup)
- else:
+ else:
  subset_names = subset_names.intersection(subset_names_bygroup)
+
  if len(subset_names) == 0:
- raise DatasetManagerError("All data entries were excluded by the subset filter.",
+ raise DatasetManagerError(
+ "All data entries were excluded by the subset filter.",
  f"Dataset entries found: {', '.join(names)}",
  f"Subset object applied: {self.subset}",
  f"Subset requested : {', '.join(subset_names)}",
@@ -436,31 +591,38 @@ class Data(ABC):
  "\tsubset: [0, 1] # explicit indices",
  "\tsubset: 0:10 # slice notation",
  "\tsubset: ./Validation.txt # external file",
- "\tsubset: None # to disable filtering"
+ "\tsubset: None # to disable filtering",
  )
-
- data, map = self._getDatasets(list(subset_names), dataset_name)

- train_map = map
- validate_map = []
+ data, mapping = self._get_datasets(list(subset_names), dataset_name)
+
+ train_mapping = mapping
+ validate_mapping = []
  if isinstance(self.validation, float) or isinstance(self.validation, int):
  if self.validation <= 0 or self.validation >= 1:
- raise DatasetManagerError("Validation must be a float between 0 and 1.", f"Received: {self.validation}", "Example: validation = 0.2 # for a 20% validation split")
-
- train_map, validate_map = map[:int(math.floor(len(map)*(1-self.validation)))], map[int(math.floor(len(map)*(1-self.validation))):]
+ raise DatasetManagerError(
+ "Validation must be a float between 0 and 1.",
+ f"Received: {self.validation}",
+ "Example: validation = 0.2 # for a 20% validation split",
+ )
+
+ train_mapping, validate_mapping = (
+ mapping[: int(math.floor(len(mapping) * (1 - self.validation)))],
+ mapping[int(math.floor(len(mapping) * (1 - self.validation))) :],
+ )
  elif isinstance(self.validation, str):
  if ":" in self.validation:
- index = list(range(int(self.subset.split(":")[0]), int(self.subset.split(":")[1])))
- train_map = [m for m in map if m[0] not in index]
- validate_map = [m for m in map if m[0] in index]
+ index = list(range(int(self.validation.split(":")[0]), int(self.validation.split(":")[1])))
+ train_mapping = [m for m in mapping if m[0] not in index]
+ validate_mapping = [m for m in mapping if m[0] in index]
  elif os.path.exists(self.validation):
  validation_names = []
- with open(self.validation, "r") as f:
+ with open(self.validation) as f:
  for name in f:
  validation_names.append(name.strip())
  index = [i for i, n in enumerate(subset_names) if n in validation_names]
- train_map = [m for m in map if m[0] not in index]
- validate_map = [m for m in map if m[0] in index]
+ train_mapping = [m for m in mapping if m[0] not in index]
+ validate_mapping = [m for m in mapping if m[0] in index]
  else:
  raise DatasetManagerError(
  f"Invalid string value for 'validation': '{self.validation}'",
@@ -470,94 +632,151 @@ class Data(ABC):
  "\t• A float between 0 and 1 (e.g., 0.2)",
  "\t• A list of sample names or indices",
  "The provided value is neither a valid slice nor a readable file.",
- "Please fix your 'validation' setting in the configuration."
- )
-
+ "Please fix your 'validation' setting in the configuration.",
+ )
+
  elif isinstance(self.validation, list):
  if len(self.validation) > 0:
  if isinstance(self.validation[0], int):
- train_map = [m for m in map if m[0] not in self.validation]
- validate_map = [m for m in map if m[0] in self.validation]
+ train_mapping = [m for m in mapping if m[0] not in self.validation]
+ validate_mapping = [m for m in mapping if m[0] in self.validation]
  elif isinstance(self.validation[0], str):
  index = [i for i, n in enumerate(subset_names) if n in self.validation]
- train_map = [m for m in map if m[0] not in index]
- validate_map = [m for m in map if m[0] in index]
+ train_mapping = [m for m in mapping if m[0] not in index]
+ validate_mapping = [m for m in mapping if m[0] in index]
  else:
- raise DatasetManagerError(f"Invalid list type for 'validation': elements of type '{type(self.validation[0]).__name__}' are not supported.",
- "Supported list element types are:",
- "\t• int → list of indices (e.g., [0, 1, 2])",
- "\t• str → list of sample names (e.g., ['patient01', 'patient02'])",
- f"Received list: {self.validation}"
- )
- if len(train_map) == 0:
- raise DatasetManagerError("No data left for training after applying the validation split.",
- f"Dataset size: {len(map)}",
+ raise DatasetManagerError(
+ "Invalid list type for 'validation': elements of type "
+ f"'{type(self.validation[0]).__name__}' are not supported.",
+ "Supported list element types are:",
+ "\t• int → list of indices (e.g., [0, 1, 2])",
+ "\t• str → list of sample names (e.g., ['patient01', 'patient02'])",
+ f"Received list: {self.validation}",
+ )
+ if len(train_mapping) == 0:
+ raise DatasetManagerError(
+ "No data left for training after applying the validation split.",
+ f"Dataset size: {len(mapping)}",
  f"Validation setting: {self.validation}",
- "Please reduce the validation size, increase the dataset, or disable validation."
+ "Please reduce the validation size, increase the dataset, or disable validation.",
  )

- if self.validation is not None and len(validate_map) == 0:
- raise DatasetManagerError("No data left for validation after applying the validation split.",
- f"Dataset size: {len(map)}",
+ if self.validation is not None and len(validate_mapping) == 0:
+ raise DatasetManagerError(
+ "No data left for validation after applying the validation split.",
+ f"Dataset size: {len(mapping)}",
  f"Validation setting: {self.validation}",
- "Please increase the validation size, increase the dataset, or disable validation."
+ "Please increase the validation size, increase the dataset, or disable validation.",
  )
- train_maps = Data._split(train_map, world_size)
- validate_maps = Data._split(validate_map, world_size)
-
- for i, (train_map, validate_map) in enumerate(zip(train_maps, validate_maps)):
- maps = [train_map]
- if len(validate_map):
- maps += [validate_map]
+ train_mappings = Data._split(train_mapping, world_size)
+ validate_mappings = Data._split(validate_mapping, world_size)
+
+ for i, (train_mapping, validate_mapping) in enumerate(zip(train_mappings, validate_mappings)):
+ mappings = [train_mapping]
+ if len(validate_mapping):
+ mappings += [validate_mapping]
  self.data.append([])
- self.map.append([])
- for map_tmp in maps:
- indexs = np.unique(np.asarray(map_tmp)[:, 0])
- self.data[i].append({k:[v[it] for it in indexs] for k, v in data.items()})
- map_tmp_array = np.asarray(map_tmp)
+ self.mapping.append([])
+ for mapping_tmp in mappings:
+ indexs = np.unique(np.asarray(mapping_tmp)[:, 0])
+ self.data[i].append({k: [v[it] for it in indexs] for k, v in data.items()})
+ mapping_tmp_array = np.asarray(mapping_tmp)
  for a, b in enumerate(indexs):
- map_tmp_array[np.where(np.asarray(map_tmp_array)[:, 0] == b), 0] = a
- self.map[i].append([(a,b,c) for a,b,c in map_tmp_array])
+ mapping_tmp_array[np.where(np.asarray(mapping_tmp_array)[:, 0] == b), 0] = a
+ self.mapping[i].append([(a, b, c) for a, b, c in mapping_tmp_array])
+
+ data_loaders: list[list[DataLoader]] = []
+ for i, (datas, mappings) in enumerate(zip(self.data, self.mapping)):
+ data_loaders.append([])
+ for data, mapping in zip(datas, mappings):
+ data_loaders[i].append(
+ DataLoader(
+ dataset=self.datasetIter(
+ rank=i,
+ data=data,
+ mapping=mapping,
+ ),
+ sampler=CustomSampler(len(mapping), self.subset.shuffle),
+ batch_size=self.batch_size,
+ **self.dataLoader_args,
+ )
+ )
+ return data_loaders

- dataLoaders: list[list[DataLoader]] = []
- for i, (datas, maps) in enumerate(zip(self.data, self.map)):
- dataLoaders.append([])
- for data, map in zip(datas, maps):
- dataLoaders[i].append(DataLoader(dataset=DatasetIter(rank=i, data=data, map=map, **self.dataSet_args), sampler=CustomSampler(len(map), self.subset.shuffle), batch_size=self.batch_size,**self.dataLoader_args))
- return dataLoaders

  class DataTrain(Data):

  @config("Dataset")
- def __init__(self, dataset_filenames : list[str] = ["default:./Dataset"],
- groups_src : dict[str, Group] = {"default:group_src" : Group()},
- augmentations : Union[dict[str, DataAugmentationsList], None] = {"DataAugmentation_0" : DataAugmentationsList()},
- inlineAugmentations: bool = False,
- patch : Union[DatasetPatch, None] = DatasetPatch(),
- use_cache : bool = True,
- subset : Union[TrainSubset, dict[str, TrainSubset]] = TrainSubset(),
- batch_size : int = 1,
- validation : Union[float, str, list[int], list[str]] = 0.2) -> None:
- super().__init__(dataset_filenames, groups_src, patch, use_cache, subset, batch_size, validation, inlineAugmentations, augmentations if augmentations else {})
+ def __init__(
+ self,
+ dataset_filenames: list[str] = ["default:./Dataset"],
+ groups_src: dict[str, Group] = {"default:group_src": Group()},
+ augmentations: dict[str, DataAugmentationsList] | None = {"DataAugmentation_0": DataAugmentationsList()},
+ inline_augmentations: bool = False,
+ patch: DatasetPatch | None = DatasetPatch(),
+ use_cache: bool = True,
+ subset: TrainSubset = TrainSubset(),
+ batch_size: int = 1,
+ validation: float | str | list[int] | list[str] = 0.2,
+ ) -> None:
+ super().__init__(
+ dataset_filenames,
+ groups_src,
+ patch,
+ use_cache,
+ subset,
+ batch_size,
+ validation,
+ inline_augmentations,
+ augmentations if augmentations else {},
+ )
+

  class DataPrediction(Data):

  @config("Dataset")
- def __init__(self, dataset_filenames : list[str] = ["default:./Dataset"],
- groups_src : dict[str, Group] = {"default" : Group()},
- augmentations : Union[dict[str, DataAugmentationsList], None] = {"DataAugmentation_0" : DataAugmentationsList()},
- patch : Union[DatasetPatch, None] = DatasetPatch(),
- subset : Union[PredictionSubset, dict[str, PredictionSubset]] = PredictionSubset(),
- batch_size : int = 1) -> None:
+ def __init__(
+ self,
+ dataset_filenames: list[str] = ["default:./Dataset"],
+ groups_src: dict[str, Group] = {"default": Group()},
+ augmentations: dict[str, DataAugmentationsList] | None = {"DataAugmentation_0": DataAugmentationsList()},
+ patch: DatasetPatch | None = DatasetPatch(),
+ subset: PredictionSubset = PredictionSubset(),
+ batch_size: int = 1,
+ ) -> None:
+
+ super().__init__(
+ dataset_filenames=dataset_filenames,
+ groups_src=groups_src,
+ patch=patch,
+ use_cache=False,
+ subset=subset,
+ batch_size=batch_size,
+ validation=None,
+ inline_augmentations=False,
+ data_augmentations_list=augmentations if augmentations else {},
+ )

- super().__init__(dataset_filenames, groups_src, patch, False, subset, batch_size, dataAugmentationsList=augmentations if augmentations else {})

  class DataMetric(Data):

  @config("Dataset")
- def __init__(self, dataset_filenames : list[str] = ["default:./Dataset"],
- groups_src : dict[str, GroupMetric] = {"default" : GroupMetric()},
- subset : Union[PredictionSubset, dict[str, PredictionSubset]] = PredictionSubset(),
- validation: Union[str, None] = None) -> None:
-
- super().__init__(dataset_filenames=dataset_filenames, groups_src=groups_src, patch=None, use_cache=False, subset=subset, batch_size=1, validation=validation)
+ def __init__(
+ self,
+ dataset_filenames: list[str] = ["default:./Dataset"],
+ groups_src: dict[str, GroupMetric] = {"default": GroupMetric()},
+ subset: PredictionSubset = PredictionSubset(),
+ validation: str | None = None,
+ ) -> None:
+
+ super().__init__(
+ dataset_filenames=dataset_filenames,
+ groups_src=groups_src,
+ patch=None,
+ use_cache=False,
+ subset=subset,
+ batch_size=1,
+ validation=validation,
+ data_augmentations_list={},
+ inline_augmentations=False,
+ )