konfai 1.1.0__py3-none-any.whl → 1.1.2__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of konfai might be problematic. Click here for more details.
- konfai/__init__.py +1 -1
- konfai/data/augmentation.py +2 -2
- konfai/data/data_manager.py +145 -42
- konfai/data/patching.py +39 -13
- konfai/data/transform.py +48 -21
- konfai/evaluator.py +24 -7
- konfai/main.py +7 -5
- konfai/models/registration/registration.py +0 -1
- konfai/network/blocks.py +0 -1
- konfai/network/network.py +29 -16
- konfai/predictor.py +24 -21
- konfai/trainer.py +15 -15
- konfai/utils/config.py +12 -12
- konfai/utils/dataset.py +27 -2
- konfai/utils/utils.py +108 -24
- {konfai-1.1.0.dist-info → konfai-1.1.2.dist-info}/METADATA +1 -1
- {konfai-1.1.0.dist-info → konfai-1.1.2.dist-info}/RECORD +21 -21
- {konfai-1.1.0.dist-info → konfai-1.1.2.dist-info}/WHEEL +0 -0
- {konfai-1.1.0.dist-info → konfai-1.1.2.dist-info}/entry_points.txt +0 -0
- {konfai-1.1.0.dist-info → konfai-1.1.2.dist-info}/licenses/LICENSE +0 -0
- {konfai-1.1.0.dist-info → konfai-1.1.2.dist-info}/top_level.txt +0 -0
konfai/__init__.py
CHANGED
|
@@ -12,5 +12,5 @@ CONFIG_FILE = lambda : os.environ["KONFAI_CONFIG_FILE"]
|
|
|
12
12
|
KONFAI_STATE = lambda : os.environ["KONFAI_STATE"]
|
|
13
13
|
KONFAI_ROOT = lambda : os.environ["KONFAI_ROOT"]
|
|
14
14
|
CUDA_VISIBLE_DEVICES = lambda : os.environ["CUDA_VISIBLE_DEVICES"]
|
|
15
|
-
|
|
15
|
+
KONFAI_NB_CORES = lambda : os.environ["KONFAI_NB_CORES"]
|
|
16
16
|
DATE = lambda : datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
|
konfai/data/augmentation.py
CHANGED
|
@@ -8,7 +8,7 @@ from typing import Union
|
|
|
8
8
|
import os
|
|
9
9
|
from konfai import KONFAI_ROOT
|
|
10
10
|
from konfai.utils.config import config
|
|
11
|
-
from konfai.utils.utils import _getModule
|
|
11
|
+
from konfai.utils.utils import _getModule, AugmentationError
|
|
12
12
|
from konfai.utils.dataset import Attribute, data_to_image
|
|
13
13
|
|
|
14
14
|
|
|
@@ -222,7 +222,7 @@ class ColorTransform(DataAugmentation):
|
|
|
222
222
|
matrix = matrix[:, :3, :].mean(dim=1, keepdims=True).to(input.device)
|
|
223
223
|
result = result.float() * matrix[:, :, :3].sum(dim=2, keepdims=True) + matrix[:, :, 3:]
|
|
224
224
|
else:
|
|
225
|
-
raise
|
|
225
|
+
raise AugmentationError('Image must be RGB (3 channels) or L (1 channel)')
|
|
226
226
|
results.append(result.reshape(input.shape))
|
|
227
227
|
return results
|
|
228
228
|
|
konfai/data/data_manager.py
CHANGED
|
@@ -11,11 +11,12 @@ from typing import Union, Iterator
|
|
|
11
11
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
12
12
|
import threading
|
|
13
13
|
from torch.cuda import device_count
|
|
14
|
+
import SimpleITK as sitk
|
|
14
15
|
|
|
15
16
|
from konfai import KONFAI_STATE, KONFAI_ROOT
|
|
16
17
|
from konfai.data.patching import DatasetPatch, DatasetManager
|
|
17
18
|
from konfai.utils.config import config
|
|
18
|
-
from konfai.utils.utils import memoryInfo, cpuInfo, memoryForecast, getMemory, State
|
|
19
|
+
from konfai.utils.utils import memoryInfo, cpuInfo, memoryForecast, getMemory, State, SUPPORTED_EXTENSIONS, DatasetManagerError
|
|
19
20
|
from konfai.utils.dataset import Dataset, Attribute
|
|
20
21
|
from konfai.data.transform import TransformLoader, Transform
|
|
21
22
|
from konfai.data.augmentation import DataAugmentationsList
|
|
@@ -61,12 +62,25 @@ class GroupTransform:
|
|
|
61
62
|
for transform in self.post_transforms:
|
|
62
63
|
transform.setDevice(device)
|
|
63
64
|
|
|
65
|
+
class GroupTransformMetric(GroupTransform):
|
|
66
|
+
|
|
67
|
+
@config()
|
|
68
|
+
def __init__(self, pre_transforms : Union[dict[str, TransformLoader], list[Transform]] = {"default:Normalize:Standardize:Unsqueeze:TensorCast:ResampleIsotropic:ResampleResize": TransformLoader()},
|
|
69
|
+
post_transforms : Union[dict[str, TransformLoader], list[Transform]] = {"default:Normalize:Standardize:Unsqueeze:TensorCast:ResampleIsotropic:ResampleResize": TransformLoader()}):
|
|
70
|
+
super().__init__(pre_transforms, post_transforms)
|
|
71
|
+
|
|
64
72
|
class Group(dict[str, GroupTransform]):
|
|
65
73
|
|
|
66
74
|
@config()
|
|
67
75
|
def __init__(self, groups_dest: dict[str, GroupTransform] = {"default:group_dest": GroupTransform()}):
|
|
68
76
|
super().__init__(groups_dest)
|
|
69
77
|
|
|
78
|
+
class GroupMetric(dict[str, GroupTransformMetric]):
|
|
79
|
+
|
|
80
|
+
@config()
|
|
81
|
+
def __init__(self, groups_dest: dict[str, GroupTransformMetric] = {"default:group_dest": GroupTransformMetric()}):
|
|
82
|
+
super().__init__(groups_dest)
|
|
83
|
+
|
|
70
84
|
class CustomSampler(Sampler[int]):
|
|
71
85
|
|
|
72
86
|
def __init__(self, size: int, shuffle: bool = False) -> None:
|
|
@@ -108,32 +122,33 @@ class DatasetIter(data.Dataset):
|
|
|
108
122
|
def getDatasetFromIndex(self, group_dest: str, index: int) -> DatasetManager:
|
|
109
123
|
return self.data[group_dest][index]
|
|
110
124
|
|
|
111
|
-
def resetAugmentation(self):
|
|
112
|
-
if self.inlineAugmentations:
|
|
125
|
+
def resetAugmentation(self, label):
|
|
126
|
+
if self.inlineAugmentations and len(self.dataAugmentationsList) > 0:
|
|
113
127
|
for index in range(self.nb_dataset):
|
|
114
|
-
self._unloadData(index)
|
|
115
128
|
for group_src in self.groups_src:
|
|
116
129
|
for group_dest in self.groups_src[group_src]:
|
|
130
|
+
self.data[group_dest][index].unloadAugmentation()
|
|
117
131
|
self.data[group_dest][index].resetAugmentation()
|
|
132
|
+
self.load(label + " Augmentation")
|
|
118
133
|
|
|
119
|
-
def load(self):
|
|
134
|
+
def load(self, label: str):
|
|
120
135
|
if self.use_cache:
|
|
121
136
|
memory_init = getMemory()
|
|
122
137
|
|
|
123
|
-
indexs = [index for index in range(self.nb_dataset)
|
|
138
|
+
indexs = [index for index in range(self.nb_dataset)]
|
|
124
139
|
if len(indexs) > 0:
|
|
125
140
|
memory_lock = threading.Lock()
|
|
141
|
+
desc = lambda : "Caching "+ label +": {} | {} | {}".format(memoryInfo(), memoryForecast(memory_init, 0, self.nb_dataset), cpuInfo())
|
|
126
142
|
pbar = tqdm.tqdm(
|
|
127
143
|
total=len(indexs),
|
|
128
|
-
desc=
|
|
129
|
-
leave=False
|
|
130
|
-
disable=self.rank != 0 and "KONFAI_CLUSTER" not in os.environ
|
|
144
|
+
desc=desc(),
|
|
145
|
+
leave=False
|
|
131
146
|
)
|
|
132
147
|
|
|
133
148
|
def process(index):
|
|
134
149
|
self._loadData(index)
|
|
135
150
|
with memory_lock:
|
|
136
|
-
pbar.set_description(
|
|
151
|
+
pbar.set_description(desc())
|
|
137
152
|
pbar.update(1)
|
|
138
153
|
with ThreadPoolExecutor(max_workers=os.cpu_count()//(device_count() if device_count() > 0 else 1)) as executor:
|
|
139
154
|
futures = [executor.submit(process, index) for index in indexs]
|
|
@@ -236,6 +251,8 @@ class Subset():
|
|
|
236
251
|
index = random.sample(index, len(index))
|
|
237
252
|
return set([names_filtred[i] for i in index])
|
|
238
253
|
|
|
254
|
+
def __str__(self):
|
|
255
|
+
return "Subset : " + str(self.subset) + " shuffle : "+ str(self.shuffle) + " filter : "+ str(self.filter)
|
|
239
256
|
class TrainSubset(Subset):
|
|
240
257
|
|
|
241
258
|
@config()
|
|
@@ -254,10 +271,10 @@ class Data(ABC):
|
|
|
254
271
|
groups_src : dict[str, Group],
|
|
255
272
|
patch : Union[DatasetPatch, None],
|
|
256
273
|
use_cache : bool,
|
|
257
|
-
subset :
|
|
274
|
+
subset : Subset,
|
|
258
275
|
num_workers : int,
|
|
259
276
|
batch_size : int,
|
|
260
|
-
validation: Union[float, str, list[int], list[str]] =
|
|
277
|
+
validation: Union[float, str, list[int], list[str], None] = None,
|
|
261
278
|
inlineAugmentations: bool = False,
|
|
262
279
|
dataAugmentationsList: dict[str, DataAugmentationsList]= {}) -> None:
|
|
263
280
|
self.dataset_filenames = dataset_filenames
|
|
@@ -290,6 +307,13 @@ class Data(ABC):
|
|
|
290
307
|
map.append((x, y, z))
|
|
291
308
|
return data, map
|
|
292
309
|
|
|
310
|
+
def getGroupsDest(self):
|
|
311
|
+
groupsDest = []
|
|
312
|
+
for group_src in self.groups_src:
|
|
313
|
+
for group_dest in self.groups_src[group_src]:
|
|
314
|
+
groupsDest.append(group_dest)
|
|
315
|
+
return groupsDest
|
|
316
|
+
|
|
293
317
|
def _split(map: list[tuple[int, int, int]], world_size: int) -> list[list[tuple[int, int, int]]]:
|
|
294
318
|
if len(map) == 0:
|
|
295
319
|
return [[] for _ in range(world_size)]
|
|
@@ -313,7 +337,14 @@ class Data(ABC):
|
|
|
313
337
|
|
|
314
338
|
def getData(self, world_size: int) -> list[list[DataLoader]]:
|
|
315
339
|
datasets: dict[str, list[(str, bool)]] = {}
|
|
340
|
+
if self.dataset_filenames is None or len(self.dataset_filenames) == 0:
|
|
341
|
+
raise DatasetManagerError("No dataset filenames were provided")
|
|
316
342
|
for dataset_filename in self.dataset_filenames:
|
|
343
|
+
if dataset_filename is None:
|
|
344
|
+
raise DatasetManagerError("Invalid dataset entry: 'None' received.",
|
|
345
|
+
"Each dataset must be a valid path string (e.g., './Dataset/', './Dataset/:mha, './Dataset/:a:mha', './Dataset/:i:mha').",
|
|
346
|
+
"Please check your 'dataset_filenames' list for missing or null entries."
|
|
347
|
+
)
|
|
317
348
|
if len(dataset_filename.split(":")) == 1:
|
|
318
349
|
filename = dataset_filename
|
|
319
350
|
format = "mha"
|
|
@@ -324,9 +355,13 @@ class Data(ABC):
|
|
|
324
355
|
else:
|
|
325
356
|
filename, flag, format = dataset_filename.split(":")
|
|
326
357
|
append = flag == "a"
|
|
327
|
-
|
|
328
|
-
|
|
358
|
+
|
|
359
|
+
if format not in SUPPORTED_EXTENSIONS:
|
|
360
|
+
raise DatasetManagerError(f"Unsupported file format '{format}'.",
|
|
361
|
+
f"Supported extensions are: {', '.join(SUPPORTED_EXTENSIONS)}")
|
|
329
362
|
|
|
363
|
+
dataset = Dataset(filename, format)
|
|
364
|
+
|
|
330
365
|
self.datasets[filename] = dataset
|
|
331
366
|
for group in self.groups_src:
|
|
332
367
|
if dataset.isGroupExist(group):
|
|
@@ -334,47 +369,88 @@ class Data(ABC):
|
|
|
334
369
|
datasets[group].append((filename, append))
|
|
335
370
|
else:
|
|
336
371
|
datasets[group] = [(filename, append)]
|
|
372
|
+
modelHaveInput = False
|
|
337
373
|
for group_src in self.groups_src:
|
|
338
374
|
if group_src not in datasets:
|
|
339
|
-
raise ValueError("[DatasetManager] Error: group source {} not found. Available groups: {}".format(group_src, list(datasets.keys())))
|
|
340
375
|
|
|
376
|
+
raise DatasetManagerError(
|
|
377
|
+
f"Group source '{group_src}' not found in any dataset.",
|
|
378
|
+
f"Dataset filenames provided: {self.dataset_filenames}",
|
|
379
|
+
"Available groups across all datasets: {}".format(["{} {}".format(f, d.getGroup()) for f, d in self.datasets.items()]),
|
|
380
|
+
f"Please check that an entry in the dataset with the name '{group_src}.{format}' exists."
|
|
381
|
+
)
|
|
382
|
+
|
|
341
383
|
for group_dest in self.groups_src[group_src]:
|
|
342
384
|
self.groups_src[group_src][group_dest].load(group_src, group_dest, [self.datasets[filename] for filename, _ in datasets[group_src]])
|
|
385
|
+
modelHaveInput |= self.groups_src[group_src][group_dest].isInput
|
|
386
|
+
|
|
387
|
+
if not modelHaveInput:
|
|
388
|
+
raise DatasetManagerError(
|
|
389
|
+
"At least one group must be defined with 'isInput: true' to provide input to the network."
|
|
390
|
+
)
|
|
391
|
+
|
|
343
392
|
for key, dataAugmentations in self.dataAugmentationsList.items():
|
|
344
393
|
dataAugmentations.load(key)
|
|
345
394
|
|
|
346
|
-
names
|
|
395
|
+
names = set()
|
|
347
396
|
dataset_name : dict[str, dict[str, list[str]]] = {}
|
|
348
397
|
dataset_info : dict[str, dict[str, dict[str, Attribute]]] = {}
|
|
349
398
|
for group in self.groups_src:
|
|
399
|
+
namesByGroup = set()
|
|
350
400
|
if group not in dataset_name:
|
|
351
401
|
dataset_name[group] = {}
|
|
352
402
|
dataset_info[group] = {}
|
|
353
403
|
for filename, _ in datasets[group]:
|
|
354
|
-
|
|
404
|
+
namesByGroup.update(self.datasets[filename].getNames(group))
|
|
355
405
|
dataset_name[group][filename] = self.datasets[filename].getNames(group)
|
|
356
406
|
dataset_info[group][filename] = {name: self.datasets[filename].getInfos(group, name) for name in dataset_name[group][filename]}
|
|
407
|
+
if len(names) == 0:
|
|
408
|
+
names.update(namesByGroup)
|
|
409
|
+
else:
|
|
410
|
+
names = names.intersection(namesByGroup)
|
|
411
|
+
if len(names) == 0:
|
|
412
|
+
raise DatasetManagerError(
|
|
413
|
+
f"No data was found for groups {list(self.groups_src.keys())}: although each group contains data from a dataset, there are no common dataset names shared across all groups, the intersection is empty."
|
|
414
|
+
)
|
|
415
|
+
|
|
357
416
|
subset_names = set()
|
|
358
|
-
|
|
359
|
-
|
|
360
|
-
|
|
361
|
-
|
|
362
|
-
|
|
363
|
-
|
|
364
|
-
if
|
|
365
|
-
|
|
417
|
+
for group in dataset_name:
|
|
418
|
+
subset_names_bygroup = set()
|
|
419
|
+
for filename, append in datasets[group]:
|
|
420
|
+
if append:
|
|
421
|
+
subset_names_bygroup.update(self.subset([dataset_name[group][filename]], [dataset_info[group][filename]]))
|
|
422
|
+
else:
|
|
423
|
+
if len(subset_names_bygroup) == 0:
|
|
424
|
+
subset_names_bygroup.update(self.subset([dataset_name[group][filename]], [dataset_info[group][filename]]))
|
|
366
425
|
else:
|
|
367
|
-
|
|
368
|
-
|
|
369
|
-
|
|
370
|
-
|
|
426
|
+
subset_names_bygroup = subset_names_bygroup.intersection(self.subset([dataset_name[group][filename]], [dataset_info[group][filename]]))
|
|
427
|
+
if len(subset_names) == 0:
|
|
428
|
+
subset_names.update(subset_names_bygroup)
|
|
429
|
+
else:
|
|
430
|
+
subset_names = subset_names.intersection(subset_names_bygroup)
|
|
431
|
+
if len(subset_names) == 0:
|
|
432
|
+
raise DatasetManagerError("All data entries were excluded by the subset filter.",
|
|
433
|
+
f"Dataset entries found: {', '.join(names)}",
|
|
434
|
+
f"Subset object applied: {self.subset}",
|
|
435
|
+
f"Subset requested : {', '.join(subset_names)}",
|
|
436
|
+
"None of the dataset entries matched the given subset.",
|
|
437
|
+
"Please check your 'subset' configuration — it may be too restrictive or incorrectly formatted.",
|
|
438
|
+
"Examples of valid subset formats:",
|
|
439
|
+
"\tsubset: [0, 1] # explicit indices",
|
|
440
|
+
"\tsubset: 0:10 # slice notation",
|
|
441
|
+
"\tsubset: ./Validation.txt # external file",
|
|
442
|
+
"\tsubset: None # to disable filtering"
|
|
443
|
+
)
|
|
444
|
+
|
|
371
445
|
data, map = self._getDatasets(list(subset_names), dataset_name)
|
|
372
|
-
|
|
446
|
+
|
|
373
447
|
train_map = map
|
|
374
448
|
validate_map = []
|
|
375
|
-
if isinstance(self.validation, float):
|
|
376
|
-
if self.validation
|
|
377
|
-
|
|
449
|
+
if isinstance(self.validation, float) or isinstance(self.validation, int):
|
|
450
|
+
if self.validation <= 0 or self.validation >= 1:
|
|
451
|
+
raise DatasetManagerError("Validation must be a float between 0 and 1.", f"Received: {self.validation}", "Example: validation = 0.2 # for a 20% validation split")
|
|
452
|
+
|
|
453
|
+
train_map, validate_map = map[:int(math.floor(len(map)*(1-self.validation)))], map[int(math.floor(len(map)*(1-self.validation))):]
|
|
378
454
|
elif isinstance(self.validation, str):
|
|
379
455
|
if ":" in self.validation:
|
|
380
456
|
index = list(range(int(self.subset.split(":")[0]), int(self.subset.split(":")[1])))
|
|
@@ -389,7 +465,17 @@ class Data(ABC):
|
|
|
389
465
|
train_map = [m for m in map if m[0] not in index]
|
|
390
466
|
validate_map = [m for m in map if m[0] in index]
|
|
391
467
|
else:
|
|
392
|
-
|
|
468
|
+
raise DatasetManagerError(
|
|
469
|
+
f"Invalid string value for 'validation': '{self.validation}'",
|
|
470
|
+
"Expected one of the following formats:",
|
|
471
|
+
"\t• A slice string like '0:10'",
|
|
472
|
+
"\t• A path to a text file listing validation sample names (e.g., './val.txt')",
|
|
473
|
+
"\t• A float between 0 and 1 (e.g., 0.2)",
|
|
474
|
+
"\t• A list of sample names or indices",
|
|
475
|
+
"The provided value is neither a valid slice nor a readable file.",
|
|
476
|
+
"Please fix your 'validation' setting in the configuration."
|
|
477
|
+
)
|
|
478
|
+
|
|
393
479
|
elif isinstance(self.validation, list):
|
|
394
480
|
if len(self.validation) > 0:
|
|
395
481
|
if isinstance(self.validation[0], int):
|
|
@@ -399,10 +485,29 @@ class Data(ABC):
|
|
|
399
485
|
index = [i for i, n in enumerate(subset_names) if n in self.validation]
|
|
400
486
|
train_map = [m for m in map if m[0] not in index]
|
|
401
487
|
validate_map = [m for m in map if m[0] in index]
|
|
402
|
-
|
|
488
|
+
else:
|
|
489
|
+
raise DatasetManagerError(f"Invalid list type for 'validation': elements of type '{type(self.validation[0]).__name__}' are not supported.",
|
|
490
|
+
"Supported list element types are:",
|
|
491
|
+
"\t• int → list of indices (e.g., [0, 1, 2])",
|
|
492
|
+
"\t• str → list of sample names (e.g., ['patient01', 'patient02'])",
|
|
493
|
+
f"Received list: {self.validation}"
|
|
494
|
+
)
|
|
495
|
+
if len(train_map) == 0:
|
|
496
|
+
raise DatasetManagerError("No data left for training after applying the validation split.",
|
|
497
|
+
f"Dataset size: {len(map)}",
|
|
498
|
+
f"Validation setting: {self.validation}",
|
|
499
|
+
"Please reduce the validation size, increase the dataset, or disable validation."
|
|
500
|
+
)
|
|
501
|
+
|
|
502
|
+
if self.validation is not None and len(validate_map) == 0:
|
|
503
|
+
raise DatasetManagerError("No data left for validation after applying the validation split.",
|
|
504
|
+
f"Dataset size: {len(map)}",
|
|
505
|
+
f"Validation setting: {self.validation}",
|
|
506
|
+
"Please increase the validation size, increase the dataset, or disable validation."
|
|
507
|
+
)
|
|
403
508
|
train_maps = Data._split(train_map, world_size)
|
|
404
509
|
validate_maps = Data._split(validate_map, world_size)
|
|
405
|
-
|
|
510
|
+
|
|
406
511
|
for i, (train_map, validate_map) in enumerate(zip(train_maps, validate_maps)):
|
|
407
512
|
maps = [train_map]
|
|
408
513
|
if len(validate_map):
|
|
@@ -436,7 +541,7 @@ class DataTrain(Data):
|
|
|
436
541
|
subset : Union[TrainSubset, dict[str, TrainSubset]] = TrainSubset(),
|
|
437
542
|
num_workers : int = 4,
|
|
438
543
|
batch_size : int = 1,
|
|
439
|
-
validation : Union[float, str, list[int], list[str]] = 0.
|
|
544
|
+
validation : Union[float, str, list[int], list[str]] = 0.2) -> None:
|
|
440
545
|
super().__init__(dataset_filenames, groups_src, patch, use_cache, subset, num_workers, batch_size, validation, inlineAugmentations, augmentations if augmentations else {})
|
|
441
546
|
|
|
442
547
|
class DataPrediction(Data):
|
|
@@ -445,22 +550,20 @@ class DataPrediction(Data):
|
|
|
445
550
|
def __init__(self, dataset_filenames : list[str] = ["default:./Dataset"],
|
|
446
551
|
groups_src : dict[str, Group] = {"default" : Group()},
|
|
447
552
|
augmentations : Union[dict[str, DataAugmentationsList], None] = {"DataAugmentation_0" : DataAugmentationsList()},
|
|
448
|
-
inlineAugmentations: bool = False,
|
|
449
553
|
patch : Union[DatasetPatch, None] = DatasetPatch(),
|
|
450
|
-
use_cache : bool = True,
|
|
451
554
|
subset : Union[PredictionSubset, dict[str, PredictionSubset]] = PredictionSubset(),
|
|
452
555
|
num_workers : int = 4,
|
|
453
556
|
batch_size : int = 1) -> None:
|
|
454
557
|
|
|
455
|
-
super().__init__(dataset_filenames, groups_src, patch,
|
|
558
|
+
super().__init__(dataset_filenames, groups_src, patch, False, subset, num_workers, batch_size, dataAugmentationsList=augmentations if augmentations else {})
|
|
456
559
|
|
|
457
560
|
class DataMetric(Data):
|
|
458
561
|
|
|
459
562
|
@config("Dataset")
|
|
460
563
|
def __init__(self, dataset_filenames : list[str] = ["default:./Dataset"],
|
|
461
|
-
groups_src : dict[str,
|
|
564
|
+
groups_src : dict[str, GroupMetric] = {"default" : GroupMetric()},
|
|
462
565
|
subset : Union[PredictionSubset, dict[str, PredictionSubset]] = PredictionSubset(),
|
|
463
566
|
validation: Union[str, None] = None,
|
|
464
567
|
num_workers : int = 4) -> None:
|
|
465
568
|
|
|
466
|
-
super().__init__(dataset_filenames=dataset_filenames, groups_src=groups_src, patch=None, use_cache=False, subset=subset, num_workers=num_workers, batch_size=1, validation=
|
|
569
|
+
super().__init__(dataset_filenames=dataset_filenames, groups_src=groups_src, patch=None, use_cache=False, subset=subset, num_workers=num_workers, batch_size=1, validation=validation)
|
konfai/data/patching.py
CHANGED
|
@@ -15,7 +15,6 @@ from konfai.utils.dataset import Dataset, Attribute
|
|
|
15
15
|
from konfai.data.transform import Transform, Save
|
|
16
16
|
from konfai.data.augmentation import DataAugmentationsList
|
|
17
17
|
|
|
18
|
-
|
|
19
18
|
class PathCombine(ABC):
|
|
20
19
|
|
|
21
20
|
def __init__(self) -> None:
|
|
@@ -150,7 +149,10 @@ class Patch(ABC):
|
|
|
150
149
|
|
|
151
150
|
def __init__(self, patch_size: list[int], overlap: Union[int, None], path_mask: Union[str, None] = None, padValue: float = 0, extend_slice: int = 0) -> None:
|
|
152
151
|
self.patch_size = patch_size
|
|
153
|
-
self.overlap = overlap
|
|
152
|
+
self.overlap = overlap
|
|
153
|
+
if isinstance(self.overlap, int):
|
|
154
|
+
if self.overlap < 0:
|
|
155
|
+
self.overlap = None
|
|
154
156
|
self._patch_slices : dict[int, list[tuple[slice]]] = {}
|
|
155
157
|
self._nb_patch_per_dim: dict[int, list[tuple[int, bool]]] = {}
|
|
156
158
|
self.path_mask = path_mask
|
|
@@ -237,6 +239,7 @@ class DatasetManager():
|
|
|
237
239
|
self.index = index
|
|
238
240
|
self.dataset = dataset
|
|
239
241
|
self.loaded = False
|
|
242
|
+
self.augmentationLoaded = False
|
|
240
243
|
self.cache_attributes: list[Attribute] = []
|
|
241
244
|
_shape, cache_attribute = self.dataset.getInfos(self.group_src, name)
|
|
242
245
|
self.cache_attributes.append(cache_attribute)
|
|
@@ -255,6 +258,7 @@ class DatasetManager():
|
|
|
255
258
|
self.cache_attributes_bak = copy.deepcopy(self.cache_attributes)
|
|
256
259
|
|
|
257
260
|
def resetAugmentation(self):
|
|
261
|
+
self.cache_attributes[:] = self.cache_attributes[:1]
|
|
258
262
|
i = 1
|
|
259
263
|
for dataAugmentations in self.dataAugmentationsList:
|
|
260
264
|
shape = []
|
|
@@ -269,15 +273,24 @@ class DatasetManager():
|
|
|
269
273
|
self.cache_attributes.append(caches_attribute[it])
|
|
270
274
|
self.patch.load(s, i)
|
|
271
275
|
i+=1
|
|
272
|
-
|
|
276
|
+
|
|
273
277
|
def load(self, pre_transform : list[Transform], dataAugmentationsList : list[DataAugmentationsList], device: torch.device) -> None:
|
|
274
|
-
if self.loaded:
|
|
275
|
-
|
|
278
|
+
if not self.loaded:
|
|
279
|
+
self._load(pre_transform)
|
|
280
|
+
if not self.augmentationLoaded:
|
|
281
|
+
self._loadAugmentation(dataAugmentationsList, device)
|
|
282
|
+
|
|
283
|
+
def _load(self, pre_transform : list[Transform]):
|
|
284
|
+
self.cache_attributes = copy.deepcopy(self.cache_attributes_bak)
|
|
276
285
|
i = len(pre_transform)
|
|
277
286
|
data = None
|
|
278
287
|
for transformFunction in reversed(pre_transform):
|
|
279
288
|
if isinstance(transformFunction, Save):
|
|
280
|
-
|
|
289
|
+
if len(transformFunction.dataset.split(":")) > 1:
|
|
290
|
+
filename, format = transformFunction.dataset.split(":")
|
|
291
|
+
else:
|
|
292
|
+
filename = transformFunction.dataset.split(":")
|
|
293
|
+
format = "mha"
|
|
281
294
|
dataset = Dataset(filename, format)
|
|
282
295
|
if dataset.isDatasetExist(self.group_dest, self.name):
|
|
283
296
|
data, attrib = dataset.readData(self.group_dest, self.name)
|
|
@@ -295,27 +308,40 @@ class DatasetManager():
|
|
|
295
308
|
for transformFunction in pre_transform[i:]:
|
|
296
309
|
data = transformFunction(self.name, data, self.cache_attributes[0])
|
|
297
310
|
if isinstance(transformFunction, Save):
|
|
298
|
-
|
|
311
|
+
if len(transformFunction.dataset.split(":")) > 1:
|
|
312
|
+
filename, format = transformFunction.dataset.split(":")
|
|
313
|
+
else:
|
|
314
|
+
filename = transformFunction.dataset.split(":")
|
|
315
|
+
format = "mha"
|
|
299
316
|
dataset = Dataset(filename, format)
|
|
300
317
|
dataset.write(self.group_dest, self.name, data.numpy(), self.cache_attributes[0])
|
|
301
318
|
self.data : list[torch.Tensor] = list()
|
|
302
319
|
self.data.append(data)
|
|
303
320
|
|
|
321
|
+
for i in range(len(self.cache_attributes)-1):
|
|
322
|
+
self.cache_attributes[i+1].update(self.cache_attributes[0])
|
|
323
|
+
self.loaded = True
|
|
324
|
+
|
|
325
|
+
def _loadAugmentation(self, dataAugmentationsList : list[DataAugmentationsList], device: torch.device) -> None:
|
|
304
326
|
for dataAugmentations in dataAugmentationsList:
|
|
305
|
-
a_data = [data.clone() for _ in range(dataAugmentations.nb)]
|
|
327
|
+
a_data = [self.data[0].clone() for _ in range(dataAugmentations.nb)]
|
|
306
328
|
for dataAugmentation in dataAugmentations.dataAugmentations:
|
|
307
329
|
a_data = dataAugmentation(self.index, a_data, device)
|
|
308
330
|
|
|
309
331
|
for d in a_data:
|
|
310
332
|
self.data.append(d)
|
|
311
|
-
self.
|
|
333
|
+
self.augmentationLoaded = True
|
|
312
334
|
|
|
313
335
|
def unload(self) -> None:
|
|
314
|
-
|
|
315
|
-
del self.data
|
|
316
|
-
self.cache_attributes = copy.deepcopy(self.cache_attributes_bak)
|
|
336
|
+
self.data.clear()
|
|
317
337
|
self.loaded = False
|
|
318
|
-
|
|
338
|
+
self.augmentationLoaded = False
|
|
339
|
+
|
|
340
|
+
def unloadAugmentation(self) -> None:
|
|
341
|
+
self.data[:] = self.data[:1]
|
|
342
|
+
self.augmentationLoaded = False
|
|
343
|
+
|
|
344
|
+
|
|
319
345
|
def getData(self, index : int, a : int, post_transforms : list[Transform], isInput: bool) -> torch.Tensor:
|
|
320
346
|
data = self.patch.getData(self.data[a], index, a, isInput)
|
|
321
347
|
for transformFunction in post_transforms:
|
konfai/data/transform.py
CHANGED
|
@@ -6,7 +6,7 @@ from abc import ABC, abstractmethod
|
|
|
6
6
|
import torch.nn.functional as F
|
|
7
7
|
from typing import Any, Union
|
|
8
8
|
|
|
9
|
-
from konfai.utils.utils import _getModule, NeedDevice, _resample_affine, _affine_matrix
|
|
9
|
+
from konfai.utils.utils import _getModule, NeedDevice, _resample_affine, _affine_matrix, TransformError
|
|
10
10
|
from konfai.utils.dataset import Dataset, Attribute, data_to_image, image_to_data
|
|
11
11
|
from konfai.utils.config import config
|
|
12
12
|
|
|
@@ -211,39 +211,67 @@ class Resample(Transform, ABC):
|
|
|
211
211
|
_ = cache_attribute.pop_np_array("Spacing")
|
|
212
212
|
return self._resample(input, [int(size) for size in size_1])
|
|
213
213
|
|
|
214
|
-
class
|
|
214
|
+
class ResampleToResolution(Resample):
|
|
215
|
+
|
|
216
|
+
def __init__(self, spacing : list[Union[float, None]] = [1., 1., 1.]) -> None:
|
|
217
|
+
self.spacing = torch.tensor([0 if s < 0 else s for s in spacing])
|
|
215
218
|
|
|
216
|
-
def __init__(self, spacing : list[float] = [1., 1., 1.]) -> None:
|
|
217
|
-
self.spacing = torch.tensor(spacing, dtype=torch.float64)
|
|
218
|
-
|
|
219
219
|
def transformShape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
|
|
220
|
-
|
|
221
|
-
|
|
222
|
-
|
|
220
|
+
if "Spacing" not in cache_attribute:
|
|
221
|
+
TransformError("Missing 'Spacing' in cache attributes, the data is likely not a valid image.",
|
|
222
|
+
"Make sure your input is a image (e.g., .nii, .mha) with proper metadata.")
|
|
223
|
+
if len(shape) != len(self.spacing):
|
|
224
|
+
TransformError("Shape and spacing dimensions do not match: shape={shape}, spacing={self.spacing}")
|
|
225
|
+
image_spacing = cache_attribute.get_tensor("Spacing").flip(0)
|
|
226
|
+
spacing = self.spacing
|
|
227
|
+
|
|
228
|
+
for i, s in enumerate(self.spacing):
|
|
229
|
+
if s == 0:
|
|
230
|
+
spacing[i] = image_spacing[i]
|
|
231
|
+
resize_factor = spacing/cache_attribute.get_tensor("Spacing").flip(0)
|
|
232
|
+
return [int(x) for x in (torch.tensor(shape) * 1/resize_factor)]
|
|
223
233
|
|
|
224
234
|
def __call__(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
225
|
-
|
|
226
|
-
|
|
227
|
-
|
|
235
|
+
image_spacing = cache_attribute.get_tensor("Spacing").flip(0)
|
|
236
|
+
spacing = self.spacing
|
|
237
|
+
for i, s in enumerate(self.spacing):
|
|
238
|
+
if s == 0:
|
|
239
|
+
spacing[i] = image_spacing[i]
|
|
240
|
+
resize_factor = spacing/cache_attribute.get_tensor("Spacing").flip(0)
|
|
241
|
+
cache_attribute["Spacing"] = spacing.flip(0)
|
|
228
242
|
cache_attribute["Size"] = np.asarray([int(x) for x in torch.tensor(input.shape[1:])])
|
|
229
243
|
size = [int(x) for x in (torch.tensor(input.shape[1:]) * 1/resize_factor)]
|
|
230
244
|
cache_attribute["Size"] = np.asarray(size)
|
|
231
245
|
return self._resample(input, size)
|
|
232
246
|
|
|
233
|
-
class
|
|
247
|
+
class ResampleToSize(Resample):
|
|
234
248
|
|
|
235
249
|
def __init__(self, size : list[int] = [100,512,512]) -> None:
|
|
236
250
|
self.size = size
|
|
237
251
|
|
|
238
252
|
def transformShape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
|
|
239
|
-
|
|
253
|
+
if "Spacing" not in cache_attribute:
|
|
254
|
+
TransformError("Missing 'Spacing' in cache attributes, the data is likely not a valid image.",
|
|
255
|
+
"Make sure your input is a image (e.g., .nii, .mha) with proper metadata.")
|
|
256
|
+
if len(shape) != len(self.size):
|
|
257
|
+
TransformError("Shape and spacing dimensions do not match: shape={shape}, spacing={self.spacing}")
|
|
258
|
+
size = self.size
|
|
259
|
+
for i, s in enumerate(self.size):
|
|
260
|
+
if s == -1:
|
|
261
|
+
size[i] = shape[i]
|
|
262
|
+
return size
|
|
240
263
|
|
|
241
264
|
def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
265
|
+
size = self.size
|
|
266
|
+
image_size = np.asarray([int(x) for x in torch.tensor(input.shape[1:])])
|
|
267
|
+
for i, s in enumerate(self.size):
|
|
268
|
+
if s is None:
|
|
269
|
+
size[i] = image_size[i]
|
|
242
270
|
if "Spacing" in cache_attribute:
|
|
243
|
-
cache_attribute["Spacing"] = torch.flip(torch.tensor(
|
|
244
|
-
cache_attribute["Size"] =
|
|
245
|
-
cache_attribute["Size"] =
|
|
246
|
-
return self._resample(input,
|
|
271
|
+
cache_attribute["Spacing"] = torch.flip(torch.tensor(image_size)/torch.tensor(size)*torch.flip(cache_attribute.get_tensor("Spacing"), dims=[0]), dims=[0])
|
|
272
|
+
cache_attribute["Size"] = image_size
|
|
273
|
+
cache_attribute["Size"] = size
|
|
274
|
+
return self._resample(input, size)
|
|
247
275
|
|
|
248
276
|
class ResampleTransform(Transform):
|
|
249
277
|
|
|
@@ -412,8 +440,8 @@ class FlatLabel(Transform):
|
|
|
412
440
|
|
|
413
441
|
class Save(Transform):
|
|
414
442
|
|
|
415
|
-
def __init__(self,
|
|
416
|
-
self.
|
|
443
|
+
def __init__(self, dataset: str) -> None:
|
|
444
|
+
self.dataset = dataset
|
|
417
445
|
|
|
418
446
|
def __call__(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
419
447
|
return input
|
|
@@ -528,8 +556,7 @@ class OneHot(Transform):
|
|
|
528
556
|
self.num_classes = num_classes
|
|
529
557
|
|
|
530
558
|
def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
531
|
-
result = F.one_hot(input.type(torch.int64), num_classes=self.num_classes).permute(0, len(input.shape), *[i+1 for i in range(len(input.shape)-1)]).float().squeeze(
|
|
532
|
-
print(result.shape)
|
|
559
|
+
result = F.one_hot(input.type(torch.int64), num_classes=self.num_classes).permute(0, len(input.shape), *[i+1 for i in range(len(input.shape)-1)]).float().squeeze(0)
|
|
533
560
|
return result
|
|
534
561
|
|
|
535
562
|
def inverse(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|