konfai 1.0.8__py3-none-any.whl → 1.1.0__py3-none-any.whl
This diff shows the content of publicly available package versions as released to a supported registry. It is provided for informational purposes only and reflects the changes between the two versions as they appear in their public registry.
Potentially problematic release.
- konfai/__init__.py +10 -10
- konfai/data/augmentation.py +4 -4
- konfai/data/{dataset.py → data_manager.py} +39 -43
- konfai/data/{HDF5.py → patching.py} +7 -7
- konfai/data/transform.py +3 -3
- konfai/evaluator.py +6 -6
- konfai/main.py +29 -22
- konfai/metric/measure.py +7 -4
- konfai/models/classification/convNeXt.py +1 -1
- konfai/models/classification/resnet.py +1 -1
- konfai/models/generation/cStyleGan.py +1 -1
- konfai/models/generation/ddpm.py +1 -1
- konfai/models/generation/diffusionGan.py +1 -1
- konfai/models/generation/gan.py +1 -1
- konfai/models/segmentation/NestedUNet.py +1 -1
- konfai/models/segmentation/UNet.py +1 -1
- konfai/network/network.py +12 -12
- konfai/predictor.py +10 -10
- konfai/trainer.py +133 -48
- konfai/utils/config.py +52 -19
- konfai/utils/dataset.py +1 -1
- konfai/utils/utils.py +74 -59
- {konfai-1.0.8.dist-info → konfai-1.1.0.dist-info}/METADATA +1 -1
- konfai-1.1.0.dist-info/RECORD +39 -0
- konfai-1.0.8.dist-info/RECORD +0 -39
- {konfai-1.0.8.dist-info → konfai-1.1.0.dist-info}/WHEEL +0 -0
- {konfai-1.0.8.dist-info → konfai-1.1.0.dist-info}/entry_points.txt +0 -0
- {konfai-1.0.8.dist-info → konfai-1.1.0.dist-info}/licenses/LICENSE +0 -0
- {konfai-1.0.8.dist-info → konfai-1.1.0.dist-info}/top_level.txt +0 -0
konfai/__init__.py
CHANGED
@@ -1,16 +1,16 @@
 import os
 import datetime
 
-MODELS_DIRECTORY = lambda : os.environ["…
-CHECKPOINTS_DIRECTORY =lambda : os.environ["…
-MODEL = lambda : os.environ["…
-PREDICTIONS_DIRECTORY =lambda : os.environ["…
-EVALUATIONS_DIRECTORY =lambda : os.environ["…
-STATISTICS_DIRECTORY = lambda : os.environ["…
-SETUPS_DIRECTORY = lambda : os.environ["…
-CONFIG_FILE = lambda : os.environ["…
-…
-…
+MODELS_DIRECTORY = lambda : os.environ["KONFAI_MODELS_DIRECTORY"]
+CHECKPOINTS_DIRECTORY =lambda : os.environ["KONFAI_CHECKPOINTS_DIRECTORY"]
+MODEL = lambda : os.environ["KONFAI_MODEL"]
+PREDICTIONS_DIRECTORY =lambda : os.environ["KONFAI_PREDICTIONS_DIRECTORY"]
+EVALUATIONS_DIRECTORY =lambda : os.environ["KONFAI_EVALUATIONS_DIRECTORY"]
+STATISTICS_DIRECTORY = lambda : os.environ["KONFAI_STATISTICS_DIRECTORY"]
+SETUPS_DIRECTORY = lambda : os.environ["KONFAI_SETUPS_DIRECTORY"]
+CONFIG_FILE = lambda : os.environ["KONFAI_CONFIG_FILE"]
+KONFAI_STATE = lambda : os.environ["KONFAI_STATE"]
+KONFAI_ROOT = lambda : os.environ["KONFAI_ROOT"]
 CUDA_VISIBLE_DEVICES = lambda : os.environ["CUDA_VISIBLE_DEVICES"]
 
 DATE = lambda : datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
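These accessors are deliberately thin lambdas over os.environ: they re-read the variable on every call, so paths configured after import are picked up. A minimal sketch of the pattern (the values below are hypothetical):

    import os

    # Each accessor reads its environment variable lazily, at call time.
    KONFAI_ROOT = lambda: os.environ["KONFAI_ROOT"]

    os.environ["KONFAI_ROOT"] = "Experiment_A"
    print(KONFAI_ROOT())  # Experiment_A

    os.environ["KONFAI_ROOT"] = "Experiment_B"
    print(KONFAI_ROOT())  # Experiment_B, re-read on every call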
konfai/data/augmentation.py
CHANGED
@@ -6,7 +6,7 @@ import SimpleITK as sitk
 import torch.nn.functional as F
 from typing import Union
 import os
-from konfai import …
+from konfai import KONFAI_ROOT
 from konfai.utils.config import config
 from konfai.utils.utils import _getModule
 from konfai.utils.dataset import Attribute, data_to_image
@@ -51,7 +51,7 @@ class Prob():
 class DataAugmentationsList():
 
     @config()
-    def __init__(self, nb : int = 10, dataAugmentations: dict[str, Prob] = {"default:…
+    def __init__(self, nb : int = 10, dataAugmentations: dict[str, Prob] = {"default:Flip" : Prob(1)}) -> None:
         self.nb = nb
         self.dataAugmentations : list[DataAugmentation] = []
         self.dataAugmentationsLoader = dataAugmentations
@@ -59,7 +59,7 @@ class DataAugmentationsList():
     def load(self, key: str):
         for augmentation, prob in self.dataAugmentationsLoader.items():
             module, name = _getModule(augmentation, "data.augmentation")
-            dataAugmentation: DataAugmentation = getattr(importlib.import_module(module), name)(config = None, DL_args="{}.Dataset.augmentations.{}.dataAugmentations".format(…
+            dataAugmentation: DataAugmentation = getattr(importlib.import_module(module), name)(config = None, DL_args="{}.Dataset.augmentations.{}.dataAugmentations".format(KONFAI_ROOT(), key))
             dataAugmentation.load(prob.prob)
             self.dataAugmentations.append(dataAugmentation)
 
@@ -525,7 +525,7 @@ class Elastix(DataAugmentation):
 class Permute(DataAugmentation):
 
     @config("Permute")
-    def __init__(self, prob_permute: Union[list[float], None] = [0.…
+    def __init__(self, prob_permute: Union[list[float], None] = [0.5 ,0.5]) -> None:
         super().__init__()
         self._permute_dims = torch.tensor([[0, 2, 1, 3], [0, 3, 1, 2]])
         self.prob_permute = prob_permute
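DataAugmentationsList maps augmentation names to application probabilities, as in the new default {"default:Flip" : Prob(1)}. A minimal sketch of probability-gated dispatch in that spirit; the names and the string rewriting below are stand-ins, not the real KonfAI classes:

    import random

    # Hypothetical stand-in for the name -> probability mapping above.
    augmentations = {"Flip": 1.0, "Permute": 0.5}

    def apply(sample: str) -> str:
        for name, prob in augmentations.items():
            if random.random() < prob:         # gate each transform on its probability
                sample = f"{name}({sample})"   # placeholder for the real transform
        return sample

    print(apply("image"))  # e.g. "Permute(Flip(image))"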
konfai/data/{dataset.py → data_manager.py}
RENAMED
@@ -12,8 +12,8 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
 import threading
 from torch.cuda import device_count
 
-from konfai import …
-from konfai.data.…
+from konfai import KONFAI_STATE, KONFAI_ROOT
+from konfai.data.patching import DatasetPatch, DatasetManager
 from konfai.utils.config import config
 from konfai.utils.utils import memoryInfo, cpuInfo, memoryForecast, getMemory, State
 from konfai.utils.dataset import Dataset, Attribute
@@ -36,7 +36,7 @@ class GroupTransform:
         if self._pre_transforms is not None:
             if isinstance(self._pre_transforms, dict):
                 for classpath, transform in self._pre_transforms.items():
-                    transform = transform.getTransform(classpath, DL_args = "{}.Dataset.groups_src.{}.groups_dest.{}.pre_transforms".format(…
+                    transform = transform.getTransform(classpath, DL_args = "{}.Dataset.groups_src.{}.groups_dest.{}.pre_transforms".format(KONFAI_ROOT(), group_src, group_dest))
                     transform.setDatasets(datasets)
                     self.pre_transforms.append(transform)
             else:
@@ -47,7 +47,7 @@ class GroupTransform:
         if self._post_transforms is not None:
             if isinstance(self._post_transforms, dict):
                 for classpath, transform in self._post_transforms.items():
-                    transform = transform.getTransform(classpath, DL_args = "{}.Dataset.groups_src.{}.groups_dest.{}.post_transforms".format(…
+                    transform = transform.getTransform(classpath, DL_args = "{}.Dataset.groups_src.{}.groups_dest.{}.post_transforms".format(KONFAI_ROOT(), group_src, group_dest))
                     transform.setDatasets(datasets)
                     self.post_transforms.append(transform)
             else:
@@ -64,7 +64,7 @@ class GroupTransform:
 class Group(dict[str, GroupTransform]):
 
     @config()
-    def __init__(self, groups_dest: dict[str, GroupTransform] = {"default": GroupTransform()}):
+    def __init__(self, groups_dest: dict[str, GroupTransform] = {"default:group_dest": GroupTransform()}):
         super().__init__(groups_dest)
 
 class CustomSampler(Sampler[int]):
@@ -127,7 +127,7 @@ class DatasetIter(data.Dataset):
             total=len(indexs),
             desc="Caching : init | {} | {}".format(memoryForecast(memory_init, 0, self.nb_dataset), cpuInfo()),
             leave=False,
-            disable=self.rank != 0 and "…
+            disable=self.rank != 0 and "KONFAI_CLUSTER" not in os.environ
         )
 
         def process(index):
@@ -135,7 +135,7 @@ class DatasetIter(data.Dataset):
             with memory_lock:
                 pbar.set_description("Caching : {} | {} | {}".format(memoryInfo(), memoryForecast(memory_init, index, self.nb_dataset), cpuInfo()))
                 pbar.update(1)
-        with ThreadPoolExecutor(max_workers=os.cpu_count()//device_count()) as executor:
+        with ThreadPoolExecutor(max_workers=os.cpu_count()//(device_count() if device_count() > 0 else 1)) as executor:
             futures = [executor.submit(process, index) for index in indexs]
             for _ in as_completed(futures):
                 pass
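The ThreadPoolExecutor change guards the worker count against CPU-only machines, where torch.cuda.device_count() returns 0 and the old integer division raised ZeroDivisionError. The guard in isolation:

    import os
    from torch.cuda import device_count

    # device_count() == 0 on CPU-only machines; fall back to 1 so the
    # integer division never divides by zero.
    max_workers = os.cpu_count() // (device_count() if device_count() > 0 else 1)
    print(max_workers)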
@@ -257,14 +257,14 @@ class Data(ABC):
                  subset : Union[Subset, dict[str, Subset]],
                  num_workers : int,
                  batch_size : int,
-                 …
+                 validation: Union[float, str, list[int], list[str]] = 1,
                  inlineAugmentations: bool = False,
                  dataAugmentationsList: dict[str, DataAugmentationsList]= {}) -> None:
         self.dataset_filenames = dataset_filenames
         self.subset = subset
         self.groups_src = groups_src
         self.patch = patch
-        self.…
+        self.validation = validation
         self.dataAugmentationsList = dataAugmentationsList
         self.batch_size = batch_size
         self.dataSet_args = dict(groups_src=self.groups_src, inlineAugmentations=inlineAugmentations, dataAugmentationsList = list(self.dataAugmentationsList.values()), use_cache = use_cache, buffer_size=batch_size+1, patch_size=self.patch.patch_size if self.patch is not None else None, overlap=self.patch.overlap if self.patch is not None else None)
@@ -295,7 +295,7 @@ class Data(ABC):
             return [[] for _ in range(world_size)]
 
         maps = []
-        if …
+        if KONFAI_STATE() == str(State.PREDICTION) or KONFAI_STATE() == str(State.EVALUATION):
             np_map = np.asarray(map)
             unique_index = np.unique(np_map[:, 0])
             offset = int(np.ceil(len(unique_index)/world_size))
@@ -314,7 +314,11 @@ class Data(ABC):
     def getData(self, world_size: int) -> list[list[DataLoader]]:
         datasets: dict[str, list[(str, bool)]] = {}
         for dataset_filename in self.dataset_filenames:
-            if len(dataset_filename.split(":")) == …
+            if len(dataset_filename.split(":")) == 1:
+                filename = dataset_filename
+                format = "mha"
+                append = True
+            elif len(dataset_filename.split(":")) == 2:
                 filename, format = dataset_filename.split(":")
                 append = True
             else:
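The added first branch makes the format optional in dataset_filenames entries, with "mha" as the default. A small sketch of the two accepted spellings (values are illustrative):

    # "name" and "name:format" are both accepted; a bare name defaults to "mha".
    for dataset_filename in ["Dataset", "Dataset:h5"]:
        parts = dataset_filename.split(":")
        if len(parts) == 1:
            filename, format, append = parts[0], "mha", True
        elif len(parts) == 2:
            (filename, format), append = parts, True
        print(filename, format, append)
    # Dataset mha True
    # Dataset h5 True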
@@ -331,8 +335,9 @@ class Data(ABC):
             else:
                 datasets[group] = [(filename, append)]
         for group_src in self.groups_src:
-            …
-            …
+            if group_src not in datasets:
+                raise ValueError("[DatasetManager] Error: group source {} not found. Available groups: {}".format(group_src, list(datasets.keys())))
+
             for group_dest in self.groups_src[group_src]:
                 self.groups_src[group_src][group_dest].load(group_src, group_dest, [self.datasets[filename] for filename, _ in datasets[group_src]])
         for key, dataAugmentations in self.dataAugmentationsList.items():
@@ -367,17 +372,17 @@ class Data(ABC):
 
         train_map = map
         validate_map = []
-        if isinstance(self.…
-            if self.…
-                train_map, validate_map = map[:int(math.floor(len(map)*self.…
-        elif isinstance(self.…
-            if ":" in self.…
+        if isinstance(self.validation, float):
+            if self.validation < 1.0 and int(math.floor(len(map)*(1-self.validation))) > 0:
+                train_map, validate_map = map[:int(math.floor(len(map)*self.validation))], map[int(math.floor(len(map)*self.validation)):]
+        elif isinstance(self.validation, str):
+            if ":" in self.validation:
                 index = list(range(int(self.subset.split(":")[0]), int(self.subset.split(":")[1])))
                 train_map = [m for m in map if m[0] not in index]
                 validate_map = [m for m in map if m[0] in index]
-            elif os.path.exists(self.…
+            elif os.path.exists(self.validation):
                 validation_names = []
-                with open(self.…
+                with open(self.validation, "r") as f:
                     for name in f:
                         validation_names.append(name.strip())
                 index = [i for i, n in enumerate(subset_names) if n in validation_names]
@@ -385,13 +390,13 @@ class Data(ABC):
                 validate_map = [m for m in map if m[0] in index]
             else:
                 validate_map = train_map
-        elif isinstance(self.…
-            if len(self.…
-                if isinstance(self.…
-                    train_map = [m for m in map if m[0] not in self.…
-                    validate_map = [m for m in map if m[0] in self.…
-                elif isinstance(self.…
-                    index = [i for i, n in enumerate(subset_names) if n in self.…
+        elif isinstance(self.validation, list):
+            if len(self.validation) > 0:
+                if isinstance(self.validation[0], int):
+                    train_map = [m for m in map if m[0] not in self.validation]
+                    validate_map = [m for m in map if m[0] in self.validation]
+                elif isinstance(self.validation[0], str):
+                    index = [i for i, n in enumerate(subset_names) if n in self.validation]
                     train_map = [m for m in map if m[0] not in index]
                     validate_map = [m for m in map if m[0] in index]
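The renamed validation parameter drives the train/validation split: a float keeps that fraction for training, a "start:stop" string or a file of case names selects explicit entries, and a list of indices or names does the same inline. The float case in isolation:

    import math

    # With validation = 0.8, the first 80% of the map trains, the rest validates.
    map = list(range(10))
    validation = 0.8
    split = int(math.floor(len(map) * validation))
    train_map, validate_map = map[:split], map[split:]
    print(train_map, validate_map)  # [0, ..., 7] [8, 9]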
@@ -422,8 +427,8 @@ class Data(ABC):
 class DataTrain(Data):
 
     @config("Dataset")
-    def __init__(self, dataset_filenames : list[str] = ["default…
-                 groups_src : dict[str, Group] = {"default" : Group()},
+    def __init__(self, dataset_filenames : list[str] = ["default:./Dataset"],
+                 groups_src : dict[str, Group] = {"default:group_src" : Group()},
                  augmentations : Union[dict[str, DataAugmentationsList], None] = {"DataAugmentation_0" : DataAugmentationsList()},
                  inlineAugmentations: bool = False,
                  patch : Union[DatasetPatch, None] = DatasetPatch(),
@@ -431,13 +436,13 @@ class DataTrain(Data):
                  subset : Union[TrainSubset, dict[str, TrainSubset]] = TrainSubset(),
                  num_workers : int = 4,
                  batch_size : int = 1,
-                 …
-        super().__init__(dataset_filenames, groups_src, patch, use_cache, subset, num_workers, batch_size, …
+                 validation : Union[float, str, list[int], list[str]] = 0.8) -> None:
+        super().__init__(dataset_filenames, groups_src, patch, use_cache, subset, num_workers, batch_size, validation, inlineAugmentations, augmentations if augmentations else {})
 
 class DataPrediction(Data):
 
     @config("Dataset")
-    def __init__(self, dataset_filenames : list[str] = ["default…
+    def __init__(self, dataset_filenames : list[str] = ["default:./Dataset"],
                  groups_src : dict[str, Group] = {"default" : Group()},
                  augmentations : Union[dict[str, DataAugmentationsList], None] = {"DataAugmentation_0" : DataAugmentationsList()},
                  inlineAugmentations: bool = False,
@@ -452,19 +457,10 @@ class DataPrediction(Data):
 class DataMetric(Data):
 
     @config("Dataset")
-    def __init__(self, dataset_filenames : list[str] = ["default…
+    def __init__(self, dataset_filenames : list[str] = ["default:./Dataset"],
                  groups_src : dict[str, Group] = {"default" : Group()},
                  subset : Union[PredictionSubset, dict[str, PredictionSubset]] = PredictionSubset(),
                  validation: Union[str, None] = None,
                  num_workers : int = 4) -> None:
 
-        super().__init__(dataset_filenames=dataset_filenames, groups_src=groups_src, patch=None, use_cache=False, subset=subset, num_workers=num_workers, batch_size=1, …
-
-class DataHyperparameter(Data):
-
-    @config("Dataset")
-    def __init__(self, dataset_filenames : list[str] = ["default:Dataset.h5"],
-                 groups_src : dict[str, Group] = {"default" : Group()},
-                 patch : Union[DatasetPatch, None] = DatasetPatch()) -> None:
-
-        super().__init__(dataset_filenames, groups_src, patch, False, PredictionSubset(), 0, False, 1)
+        super().__init__(dataset_filenames=dataset_filenames, groups_src=groups_src, patch=None, use_cache=False, subset=subset, num_workers=num_workers, batch_size=1, validation=1 if validation is None else validation)
konfai/data/{HDF5.py → patching.py}
RENAMED
@@ -37,13 +37,13 @@ class PathCombine(ABC):
     3D :
     AAA+AAB+ABA+ABB+BAA+BAB+BBA+BBB
 
-    AAC+ABC+BAC+BBC
-    ACA+ACB+BCA+BCB
     CAA+CAB+CBA+CBB
+    ACA+ACB+BCA+BCB
+    AAC+ABC+BAC+BBC
 
-    ACC+BCC
-    CAC+CBC
     CCA+CCB
+    CAC+CBC
+    ACC+BCC
 
     """
     def setPatchConfig(self, patch_size: list[int], overlap: int):
@@ -188,7 +188,7 @@ class Patch(ABC):
         pad_bottom = 0
         pad_top = 0
         if self._patch_slices[a][index][0].start-bottom < 0:
-            pad_bottom = bottom-…
+            pad_bottom = bottom-self._patch_slices[a][index][0].start
         if self._patch_slices[a][index][0].stop+top > data.shape[len(slices_pre)]:
             pad_top = self._patch_slices[a][index][0].stop+top-data.shape[len(slices_pre)]
         data_sliced = F.pad(data_sliced, [0 for _ in range((len(slices)-1)*2)]+[pad_bottom, pad_top], 'reflect')
@@ -214,13 +214,13 @@ class Patch(ABC):
 class DatasetPatch(Patch):
 
     @config("Patch")
-    def __init__(self, patch_size : list[int] = [128,…
+    def __init__(self, patch_size : list[int] = [128, 128, 128], overlap : Union[int, None] = None, mask: Union[str, None] = None, padValue: float = 0, extend_slice: int = 0) -> None:
         super().__init__(patch_size, overlap, mask, padValue, extend_slice)
 
 class ModelPatch(Patch):
 
     @config("Patch")
-    def __init__(self, patch_size : list[int] = [128,…
+    def __init__(self, patch_size : list[int] = [128, 128, 128], overlap : Union[int, None] = None, patchCombine: Union[str, None] = None, mask: Union[str, None] = None, padValue: float = 0, extend_slice: int = 0) -> None:
         super().__init__(patch_size, overlap, mask, padValue, extend_slice)
         self.patchCombine = patchCombine
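The pad_bottom fix derives the padding from the patch's start index, so a patch whose extension crosses the lower bound is reflect-padded by exactly the missing amount. A standalone sketch of that boundary case (sizes are illustrative):

    import torch
    import torch.nn.functional as F

    start, bottom = 2, 5                     # patch start, rows requested below it
    pad_bottom = bottom - start if start - bottom < 0 else 0   # 3 missing rows
    data = torch.arange(10.).reshape(1, 1, 10)
    padded = F.pad(data, [pad_bottom, 0], "reflect")           # reflect-pad the lower edge
    print(pad_bottom, padded.shape)          # 3 torch.Size([1, 1, 13])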
konfai/data/transform.py
CHANGED
@@ -130,7 +130,7 @@ class Standardize(Transform):
 
 class TensorCast(Transform):
 
-    def __init__(self, dtype : str = "…
+    def __init__(self, dtype : str = "float32") -> None:
         self.dtype : torch.dtype = getattr(torch, dtype)
 
     def __call__(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
@@ -142,7 +142,7 @@ class TensorCast(Transform):
 
 class Padding(Transform):
 
-    def __init__(self, padding : list[int] = [0,0,0,0,0,0], mode : str = "…
+    def __init__(self, padding : list[int] = [0,0,0,0,0,0], mode : str = "constant") -> None:
         self.padding = padding
         self.mode = mode
 
@@ -332,7 +332,7 @@ class ResampleTransform(Transform):
 
 class Mask(Transform):
 
-    def __init__(self, path : str = "default…
+    def __init__(self, path : str = "./default.mha", value_outside: int = 0) -> None:
         self.path = path
         self.value_outside = value_outside
 
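TensorCast resolves its dtype string with getattr on the torch module, which is why the default must now be a complete dtype name such as "float32". The resolution step by itself:

    import torch

    dtype = getattr(torch, "float32")        # "float32" -> torch.float32
    x = torch.tensor([1, 2, 3]).to(dtype)
    print(x.dtype)                           # torch.float32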
konfai/evaluator.py
CHANGED
@@ -7,10 +7,10 @@ import json
 import shutil
 import builtins
 import importlib
-from konfai import EVALUATIONS_DIRECTORY, PREDICTIONS_DIRECTORY, …
+from konfai import EVALUATIONS_DIRECTORY, PREDICTIONS_DIRECTORY, KONFAI_ROOT, CONFIG_FILE
 from konfai.utils.config import config
 from konfai.utils.utils import _getModule, DistributedObject, synchronize_data
-from konfai.data.…
+from konfai.data.data_manager import DataMetric
 
 class CriterionsAttr():
 
@@ -28,7 +28,7 @@ class CriterionsLoader():
         criterions = {}
         for module_classpath, criterionsAttr in self.criterionsLoader.items():
             module, name = _getModule(module_classpath, "metric.measure")
-            criterions[config("{}.metrics.{}.targetsCriterions.{}.criterionsLoader.{}".format(…
+            criterions[config("{}.metrics.{}.targetsCriterions.{}.criterionsLoader.{}".format(KONFAI_ROOT(), output_group, target_group, module_classpath))(getattr(importlib.import_module(module), name))(config = None)] = criterionsAttr
         return criterions
 
 class TargetCriterionsLoader():
@@ -86,8 +86,8 @@ class Statistics():
 class Evaluator(DistributedObject):
 
     @config("Evaluator")
-    def __init__(self, train_name: str = "default:…
-        if os.environ["…
+    def __init__(self, train_name: str = "default:TRAIN_01", metrics: dict[str, TargetCriterionsLoader] = {"default": TargetCriterionsLoader()}, dataset : DataMetric = DataMetric(),) -> None:
+        if os.environ["KONFAI_CONFIG_MODE"] != "Done":
             exit(0)
         super().__init__(train_name)
         self.metric_path = EVALUATIONS_DIRECTORY()+self.name+"/"
@@ -111,7 +111,7 @@ class Evaluator(DistributedObject):
 
     def setup(self, world_size: int):
         if os.path.exists(self.metric_path):
-            if os.environ["…
+            if os.environ["KONFAI_OVERWRITE"] != "True":
                 accept = builtins.input("The metric {} already exists ! Do you want to overwrite it (yes,no) : ".format(self.name))
                 if accept != "yes":
                     return
konfai/main.py
CHANGED
@@ -2,21 +2,25 @@ import argparse
 import os
 from torch.cuda import device_count
 import torch.multiprocessing as mp
-from konfai.utils.utils import …
+from konfai.utils.utils import setup, TensorBoard, Log
 
 import sys
 sys.path.insert(0, os.getcwd())
 
 def main():
     parser = argparse.ArgumentParser(description="KonfAI", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-    …
-    with …
-        …
-        …
-        world_size…
-        …
-        …
-        …
+    try:
+        with setup(parser) as distributedObject:
+            with Log(distributedObject.name):
+                world_size = device_count()
+                if world_size == 0:
+                    world_size = 1
+                distributedObject.setup(world_size)
+                with TensorBoard(distributedObject.name):
+                    mp.spawn(distributedObject, nprocs=world_size)
+    except Exception as e:
+        print(e)
+        exit(1)
 
 
 def cluster():
@@ -29,17 +33,20 @@ def cluster():
     cluster_args.add_argument('--memory', type=int, default=16, help='Amount of memory per node')
     cluster_args.add_argument('--time-limit', '--time_limit', type=int, default=1440, help='Job time limit in minute')
     cluster_args.add_argument('--resubmit', action='store_true', help='Automatically resubmit job just before timout')
+    try:
+        with setup(parser) as distributedObject:
+            args = parser.parse_args()
+            config = vars(args)
+            os.environ["KONFAI_OVERWRITE"] = "True"
+            os.environ["KONFAI_CLUSTER"] = "True"
 
-        …
-        …
-        …
-        …
-        …
-        …
-        …
-        …
-        executor.update_parameters(name=config["name"], mem_gb=config["memory"], gpus_per_node=n_gpu, tasks_per_node=n_gpu//distributedObject.size, cpus_per_task=config["num_workers"], nodes=config["num_nodes"], timeout_min=config["time_limit"])
-        with TensorBoard(distributedObject.name):
-            executor.submit(distributedObject)
+            n_gpu = len(config["gpu"].split(","))
+            distributedObject.setup(n_gpu*int(config["num_nodes"]))
+            import submitit
+            executor = submitit.AutoExecutor(folder="./Cluster/")
+            executor.update_parameters(name=config["name"], mem_gb=config["memory"], gpus_per_node=n_gpu, tasks_per_node=n_gpu//distributedObject.size, cpus_per_task=config["num_workers"], nodes=config["num_nodes"], timeout_min=config["time_limit"])
+            with TensorBoard(distributedObject.name):
+                executor.submit(distributedObject)
+    except Exception as e:
+        print(e)
+        exit(1)
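main() now falls back to one spawned process when no GPU is visible instead of passing nprocs=0 to torch.multiprocessing. A minimal sketch of the spawn pattern (the worker is hypothetical; KonfAI spawns its distributedObject):

    import torch.multiprocessing as mp
    from torch.cuda import device_count

    def work(rank):
        # mp.spawn passes each process its rank as the first argument.
        print(f"worker {rank} started")

    if __name__ == "__main__":
        world_size = device_count() if device_count() > 0 else 1  # CPU-only fallback
        mp.spawn(work, nprocs=world_size)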
konfai/metric/measure.py
CHANGED
@@ -17,7 +17,7 @@ from abc import abstractmethod
 
 from konfai.utils.config import config
 from konfai.utils.utils import _getModule
-from konfai.data.…
+from konfai.data.patching import ModelPatch
 from konfai.network.blocks import LatentDistribution
 from konfai.network.network import ModelLoader, Network
 
@@ -92,7 +92,7 @@ class PSNR(MaskedLoss):
         return psnr
 
     def __init__(self, dynamic_range: Union[float, None] = None) -> None:
-        dynamic_range = dynamic_range if dynamic_range else 1024+…
+        dynamic_range = dynamic_range if dynamic_range else 1024+3071
        super().__init__(partial(PSNR._loss, dynamic_range), False)
 
 class SSIM(MaskedLoss):
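The completed PSNR default, 1024+3071 = 4095, matches CT intensities in [-1024, 3071] HU. A sketch of how the dynamic range enters the score:

    import math

    # PSNR for a given MSE; dynamic_range defaults to the CT range above.
    def psnr(mse: float, dynamic_range: float = 1024 + 3071) -> float:
        return 10 * math.log10(dynamic_range ** 2 / mse)

    print(round(psnr(100.0), 2))  # 52.25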
@@ -143,8 +143,11 @@ class Dice(Criterion):
         target = self.flatten(target)
         return (2.*(input * target).sum() + self.smooth)/(input.sum() + target.sum() + self.smooth)
 
+
     def forward(self, output: torch.Tensor, *targets : list[torch.Tensor]) -> torch.Tensor:
         target = targets[0]
+        if output.shape[1] == 1:
+            output = F.one_hot(output.type(torch.int64), num_classes=int(torch.max(output).item()+1)).permute(0, len(target.shape), *[i+1 for i in range(len(target.shape)-1)]).float()
         target = F.one_hot(target.type(torch.int64), num_classes=output.shape[1]).permute(0, len(target.shape), *[i+1 for i in range(len(target.shape)-1)]).float().squeeze(2)
         return 1-torch.mean(self.dice_per_channel(output, target))
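The new branch lets Dice accept a single-channel label map by one-hot encoding it before the per-channel Dice, mirroring how the target is handled. The shape effect of F.one_hot in isolation:

    import torch
    import torch.nn.functional as F

    labels = torch.tensor([[[0, 1, 2, 1]]])          # (N=1, C=1, W=4) label map
    num_classes = int(torch.max(labels).item() + 1)  # 3 classes present
    one_hot = F.one_hot(labels.type(torch.int64), num_classes=num_classes)
    print(one_hot.shape)   # torch.Size([1, 1, 4, 3]); the class axis is appended last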
@@ -239,7 +242,7 @@ class PerceptualLoss(Criterion):
     @config(None)
     def __init__(self, losses: dict[str, float] = {"Gram": 1, "torch_nn_L1Loss": 1}) -> None:
         self.losses = losses
-        self.DL_args = os.environ['…
+        self.DL_args = os.environ['KONFAI_CONFIG_PATH'] if "KONFAI_CONFIG_PATH" in os.environ else ""
 
     def getLoss(self) -> dict[torch.nn.Module, float]:
         result: dict[torch.nn.Module, float] = {}
@@ -252,7 +255,7 @@ class PerceptualLoss(Criterion):
         super().__init__()
         self.path_model = path_model
         if self.path_model not in modelsRegister:
-            self.model = modelLoader.getModel(train=False, DL_args=os.environ['…
+            self.model = modelLoader.getModel(train=False, DL_args=os.environ['KONFAI_CONFIG_PATH'].split("PerceptualLoss")[0]+"PerceptualLoss.Model", DL_without=["optimizer", "schedulers", "nb_batch_per_step", "init_type", "init_gain", "outputsCriterions", "drop_p"])
         if path_model.startswith("https"):
             state_dict = torch.hub.load_state_dict_from_url(path_model)
             state_dict = {"Model": {self.model.getName() : state_dict["model"]}}
konfai/models/classification/convNeXt.py
CHANGED
@@ -2,7 +2,7 @@ import torch
 import torch.nn.functional as F
 from konfai.network import network, blocks
 from konfai.utils.config import config
-from konfai.data.…
+from konfai.data.patching import ModelPatch
 
 """
 "convnext_tiny_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth", depths=[3, 3, 9, 3], dims=[96, 192, 384, 768]
konfai/models/classification/resnet.py
CHANGED
@@ -3,7 +3,7 @@ from typing import Type
 import torch
 from konfai.network import network, blocks
 from konfai.utils.config import config
-from konfai.data.…
+from konfai.data.patching import ModelPatch
 
 """
 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', dim = 2, in_channels = 3, depths=[2, 2, 2, 2], widths = [64, 64, 128, 256, 512], num_classes=1000, useBottleneck=False
konfai/models/generation/cStyleGan.py
CHANGED
@@ -3,7 +3,7 @@ import torch
 
 from konfai.network import network, blocks
 from konfai.utils.config import config
-from konfai.data.…
+from konfai.data.patching import ModelPatch
 
 class MappingNetwork(network.ModuleArgsDict):
     def __init__(self, z_dim: int, c_dim: int, w_dim: int, num_layers: int, embed_features: int, layer_features: int):
konfai/models/generation/ddpm.py
CHANGED
@@ -8,7 +8,7 @@ import numpy as np
 
 from konfai.network import network, blocks
 from konfai.utils.config import config
-from konfai.data.…
+from konfai.data.patching import ModelPatch
 from konfai.utils.utils import gpuInfo
 from konfai.metric.measure import Criterion
 
konfai/models/generation/diffusionGan.py
CHANGED
@@ -5,7 +5,7 @@ import numpy as np
 
 from konfai.network import network, blocks
 from konfai.utils.config import config
-from konfai.data.…
+from konfai.data.patching import ModelPatch, Attribute
 from konfai.data import augmentation
 from konfai.models.segmentation import UNet, NestedUNet
 from konfai.models.generation.ddpm import DDPM
konfai/models/generation/gan.py
CHANGED