konfai 1.1.1-py3-none-any.whl → 1.1.2-py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
- konfai/__init__.py +1 -1
- konfai/data/data_manager.py +55 -32
- konfai/data/patching.py +35 -12
- konfai/data/transform.py +48 -20
- konfai/evaluator.py +24 -7
- konfai/main.py +3 -2
- konfai/network/network.py +0 -1
- konfai/predictor.py +24 -21
- konfai/trainer.py +10 -10
- konfai/utils/dataset.py +27 -2
- konfai/utils/utils.py +49 -12
- {konfai-1.1.1.dist-info → konfai-1.1.2.dist-info}/METADATA +1 -1
- {konfai-1.1.1.dist-info → konfai-1.1.2.dist-info}/RECORD +17 -17
- {konfai-1.1.1.dist-info → konfai-1.1.2.dist-info}/WHEEL +0 -0
- {konfai-1.1.1.dist-info → konfai-1.1.2.dist-info}/entry_points.txt +0 -0
- {konfai-1.1.1.dist-info → konfai-1.1.2.dist-info}/licenses/LICENSE +0 -0
- {konfai-1.1.1.dist-info → konfai-1.1.2.dist-info}/top_level.txt +0 -0
konfai/__init__.py
CHANGED
@@ -12,5 +12,5 @@ CONFIG_FILE = lambda : os.environ["KONFAI_CONFIG_FILE"]
 KONFAI_STATE = lambda : os.environ["KONFAI_STATE"]
 KONFAI_ROOT = lambda : os.environ["KONFAI_ROOT"]
 CUDA_VISIBLE_DEVICES = lambda : os.environ["CUDA_VISIBLE_DEVICES"]
-
+KONFAI_NB_CORES = lambda : os.environ["KONFAI_NB_CORES"]
 DATE = lambda : datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
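Note: like the neighboring accessors, the added KONFAI_NB_CORES lambda reads the environment variable at call time rather than at import time, so a value set later by the CLI still takes effect. A minimal standalone sketch (the "4" is a hypothetical value):

import os

KONFAI_NB_CORES = lambda : os.environ["KONFAI_NB_CORES"]

os.environ["KONFAI_NB_CORES"] = "4"  # set after import, e.g. by setup() parsing --cpu
print(int(KONFAI_NB_CORES()))        # -> 4, resolved only when the lambda is called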
konfai/data/data_manager.py
CHANGED
@@ -62,12 +62,25 @@ class GroupTransform:
         for transform in self.post_transforms:
             transform.setDevice(device)
 
+class GroupTransformMetric(GroupTransform):
+
+    @config()
+    def __init__(self, pre_transforms : Union[dict[str, TransformLoader], list[Transform]] = {"default:Normalize:Standardize:Unsqueeze:TensorCast:ResampleIsotropic:ResampleResize": TransformLoader()},
+                       post_transforms : Union[dict[str, TransformLoader], list[Transform]] = {"default:Normalize:Standardize:Unsqueeze:TensorCast:ResampleIsotropic:ResampleResize": TransformLoader()}):
+        super().__init__(pre_transforms, post_transforms)
+
 class Group(dict[str, GroupTransform]):
 
     @config()
     def __init__(self, groups_dest: dict[str, GroupTransform] = {"default:group_dest": GroupTransform()}):
         super().__init__(groups_dest)
 
+class GroupMetric(dict[str, GroupTransformMetric]):
+
+    @config()
+    def __init__(self, groups_dest: dict[str, GroupTransformMetric] = {"default:group_dest": GroupTransformMetric()}):
+        super().__init__(groups_dest)
+
 class CustomSampler(Sampler[int]):
 
     def __init__(self, size: int, shuffle: bool = False) -> None:
@@ -109,32 +122,33 @@ class DatasetIter(data.Dataset):
     def getDatasetFromIndex(self, group_dest: str, index: int) -> DatasetManager:
         return self.data[group_dest][index]
 
-    def resetAugmentation(self):
-        if self.inlineAugmentations:
+    def resetAugmentation(self, label):
+        if self.inlineAugmentations and len(self.dataAugmentationsList) > 0:
             for index in range(self.nb_dataset):
-                self._unloadData(index)
                 for group_src in self.groups_src:
                     for group_dest in self.groups_src[group_src]:
+                        self.data[group_dest][index].unloadAugmentation()
                         self.data[group_dest][index].resetAugmentation()
+            self.load(label + " Augmentation")
 
-    def load(self):
+    def load(self, label: str):
         if self.use_cache:
             memory_init = getMemory()
 
-            indexs = [index for index in range(self.nb_dataset)
+            indexs = [index for index in range(self.nb_dataset)]
             if len(indexs) > 0:
                 memory_lock = threading.Lock()
+                desc = lambda : "Caching "+ label +": {} | {} | {}".format(memoryInfo(), memoryForecast(memory_init, 0, self.nb_dataset), cpuInfo())
                 pbar = tqdm.tqdm(
                     total=len(indexs),
-                    desc=
-                    leave=False
-                    disable=self.rank != 0 and "KONFAI_CLUSTER" not in os.environ
+                    desc=desc(),
+                    leave=False
                 )
 
                 def process(index):
                     self._loadData(index)
                     with memory_lock:
-                        pbar.set_description(
+                        pbar.set_description(desc())
                         pbar.update(1)
                 with ThreadPoolExecutor(max_workers=os.cpu_count()//(device_count() if device_count() > 0 else 1)) as executor:
                     futures = [executor.submit(process, index) for index in indexs]
@@ -257,10 +271,10 @@ class Data(ABC):
                  groups_src : dict[str, Group],
                  patch : Union[DatasetPatch, None],
                  use_cache : bool,
-                 subset :
+                 subset : Subset,
                  num_workers : int,
                  batch_size : int,
-                 validation: Union[float, str, list[int], list[str]] =
+                 validation: Union[float, str, list[int], list[str], None] = None,
                  inlineAugmentations: bool = False,
                  dataAugmentationsList: dict[str, DataAugmentationsList]= {}) -> None:
         self.dataset_filenames = dataset_filenames
@@ -343,7 +357,7 @@ class Data(ABC):
         append = flag == "a"
 
         if format not in SUPPORTED_EXTENSIONS:
-
+            raise DatasetManagerError(f"Unsupported file format '{format}'.",
                 f"Supported extensions are: {', '.join(SUPPORTED_EXTENSIONS)}")
 
         dataset = Dataset(filename, format)
@@ -362,7 +376,7 @@ class Data(ABC):
             raise DatasetManagerError(
                 f"Group source '{group_src}' not found in any dataset.",
                 f"Dataset filenames provided: {self.dataset_filenames}",
-
+                "Available groups across all datasets: {}".format(["{} {}".format(f, d.getGroup()) for f, d in self.datasets.items()]),
                 f"Please check that an entry in the dataset with the name '{group_src}.{format}' exists."
             )
 
@@ -382,28 +396,38 @@ class Data(ABC):
         dataset_name : dict[str, dict[str, list[str]]] = {}
         dataset_info : dict[str, dict[str, dict[str, Attribute]]] = {}
         for group in self.groups_src:
+            namesByGroup = set()
             if group not in dataset_name:
                 dataset_name[group] = {}
                 dataset_info[group] = {}
             for filename, _ in datasets[group]:
-
+                namesByGroup.update(self.datasets[filename].getNames(group))
                 dataset_name[group][filename] = self.datasets[filename].getNames(group)
                 dataset_info[group][filename] = {name: self.datasets[filename].getInfos(group, name) for name in dataset_name[group][filename]}
+            if len(names) == 0:
+                names.update(namesByGroup)
+            else:
+                names = names.intersection(namesByGroup)
+            if len(names) == 0:
+                raise DatasetManagerError(
+                    f"No data was found for groups {list(self.groups_src.keys())}: although each group contains data from a dataset, there are no common dataset names shared across all groups, the intersection is empty."
+                )
+
         subset_names = set()
-
-
-        for filename,
-
-
-
-
-
-            subset_names.update(self.subset([dataset_name[group][filename]], [dataset_info[group][filename]]))
+        for group in dataset_name:
+            subset_names_bygroup = set()
+            for filename, append in datasets[group]:
+                if append:
+                    subset_names_bygroup.update(self.subset([dataset_name[group][filename]], [dataset_info[group][filename]]))
+                else:
+                    if len(subset_names_bygroup) == 0:
+                        subset_names_bygroup.update(self.subset([dataset_name[group][filename]], [dataset_info[group][filename]]))
                     else:
-
-
-
-
+                        subset_names_bygroup = subset_names_bygroup.intersection(self.subset([dataset_name[group][filename]], [dataset_info[group][filename]]))
+            if len(subset_names) == 0:
+                subset_names.update(subset_names_bygroup)
+            else:
+                subset_names = subset_names.intersection(subset_names_bygroup)
         if len(subset_names) == 0:
             raise DatasetManagerError("All data entries were excluded by the subset filter.",
                 f"Dataset entries found: {', '.join(names)}",
@@ -424,7 +448,7 @@ class Data(ABC):
         validate_map = []
         if isinstance(self.validation, float) or isinstance(self.validation, int):
             if self.validation <= 0 or self.validation >= 1:
-                raise DatasetManagerError("
+                raise DatasetManagerError("Validation must be a float between 0 and 1.", f"Received: {self.validation}", "Example: validation = 0.2 # for a 20% validation split")
 
             train_map, validate_map = map[:int(math.floor(len(map)*(1-self.validation)))], map[int(math.floor(len(map)*(1-self.validation))):]
         elif isinstance(self.validation, str):
@@ -527,20 +551,19 @@ class DataPrediction(Data):
                  groups_src : dict[str, Group] = {"default" : Group()},
                  augmentations : Union[dict[str, DataAugmentationsList], None] = {"DataAugmentation_0" : DataAugmentationsList()},
                  patch : Union[DatasetPatch, None] = DatasetPatch(),
-                 use_cache : bool = True,
                  subset : Union[PredictionSubset, dict[str, PredictionSubset]] = PredictionSubset(),
                  num_workers : int = 4,
                  batch_size : int = 1) -> None:
 
-        super().__init__(dataset_filenames, groups_src, patch,
+        super().__init__(dataset_filenames, groups_src, patch, False, subset, num_workers, batch_size, dataAugmentationsList=augmentations if augmentations else {})
 
 class DataMetric(Data):
 
     @config("Dataset")
     def __init__(self, dataset_filenames : list[str] = ["default:./Dataset"],
-                 groups_src : dict[str,
+                 groups_src : dict[str, GroupMetric] = {"default" : GroupMetric()},
                  subset : Union[PredictionSubset, dict[str, PredictionSubset]] = PredictionSubset(),
                  validation: Union[str, None] = None,
                  num_workers : int = 4) -> None:
 
-        super().__init__(dataset_filenames=dataset_filenames, groups_src=groups_src, patch=None, use_cache=False, subset=subset, num_workers=num_workers, batch_size=1, validation=
+        super().__init__(dataset_filenames=dataset_filenames, groups_src=groups_src, patch=None, use_cache=False, subset=subset, num_workers=num_workers, batch_size=1, validation=validation)
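Note: the reworked name resolution keeps only entries present in every source group, raising as soon as the running intersection becomes empty. A standalone sketch of the same set logic with hypothetical group contents:

# names per group, as dataset.getNames(group) might return them (hypothetical values)
names_by_group = {
    "image": {"case_01", "case_02", "case_03"},
    "mask":  {"case_02", "case_03", "case_04"},
}

names: set[str] = set()
for group, group_names in names_by_group.items():
    names = set(group_names) if len(names) == 0 else names.intersection(group_names)
    if len(names) == 0:
        raise RuntimeError("No common dataset names shared across all groups.")
print(sorted(names))  # -> ['case_02', 'case_03']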
konfai/data/patching.py
CHANGED
@@ -15,7 +15,6 @@ from konfai.utils.dataset import Dataset, Attribute
 from konfai.data.transform import Transform, Save
 from konfai.data.augmentation import DataAugmentationsList
 
-
 class PathCombine(ABC):
 
     def __init__(self) -> None:
@@ -240,6 +239,7 @@ class DatasetManager():
         self.index = index
         self.dataset = dataset
         self.loaded = False
+        self.augmentationLoaded = False
         self.cache_attributes: list[Attribute] = []
         _shape, cache_attribute = self.dataset.getInfos(self.group_src, name)
         self.cache_attributes.append(cache_attribute)
@@ -258,6 +258,7 @@ class DatasetManager():
         self.cache_attributes_bak = copy.deepcopy(self.cache_attributes)
 
     def resetAugmentation(self):
+        self.cache_attributes[:] = self.cache_attributes[:1]
         i = 1
         for dataAugmentations in self.dataAugmentationsList:
             shape = []
@@ -272,15 +273,24 @@ class DatasetManager():
                 self.cache_attributes.append(caches_attribute[it])
                 self.patch.load(s, i)
                 i+=1
-
+
     def load(self, pre_transform : list[Transform], dataAugmentationsList : list[DataAugmentationsList], device: torch.device) -> None:
-        if self.loaded:
-
+        if not self.loaded:
+            self._load(pre_transform)
+        if not self.augmentationLoaded:
+            self._loadAugmentation(dataAugmentationsList, device)
+
+    def _load(self, pre_transform : list[Transform]):
+        self.cache_attributes = copy.deepcopy(self.cache_attributes_bak)
         i = len(pre_transform)
        data = None
         for transformFunction in reversed(pre_transform):
             if isinstance(transformFunction, Save):
-
+                if len(transformFunction.dataset.split(":")) > 1:
+                    filename, format = transformFunction.dataset.split(":")
+                else:
+                    filename = transformFunction.dataset.split(":")
+                    format = "mha"
                 dataset = Dataset(filename, format)
                 if dataset.isDatasetExist(self.group_dest, self.name):
                     data, attrib = dataset.readData(self.group_dest, self.name)
@@ -298,27 +308,40 @@ class DatasetManager():
         for transformFunction in pre_transform[i:]:
             data = transformFunction(self.name, data, self.cache_attributes[0])
             if isinstance(transformFunction, Save):
-
+                if len(transformFunction.dataset.split(":")) > 1:
+                    filename, format = transformFunction.dataset.split(":")
+                else:
+                    filename = transformFunction.dataset.split(":")
+                    format = "mha"
                 dataset = Dataset(filename, format)
                 dataset.write(self.group_dest, self.name, data.numpy(), self.cache_attributes[0])
         self.data : list[torch.Tensor] = list()
         self.data.append(data)
 
+        for i in range(len(self.cache_attributes)-1):
+            self.cache_attributes[i+1].update(self.cache_attributes[0])
+        self.loaded = True
+
+    def _loadAugmentation(self, dataAugmentationsList : list[DataAugmentationsList], device: torch.device) -> None:
         for dataAugmentations in dataAugmentationsList:
-            a_data = [data.clone() for _ in range(dataAugmentations.nb)]
+            a_data = [self.data[0].clone() for _ in range(dataAugmentations.nb)]
             for dataAugmentation in dataAugmentations.dataAugmentations:
                 a_data = dataAugmentation(self.index, a_data, device)
 
             for d in a_data:
                 self.data.append(d)
-        self.
+        self.augmentationLoaded = True
 
     def unload(self) -> None:
-
-        del self.data
-        self.cache_attributes = copy.deepcopy(self.cache_attributes_bak)
+        self.data.clear()
         self.loaded = False
-
+        self.augmentationLoaded = False
+
+    def unloadAugmentation(self) -> None:
+        self.data[:] = self.data[:1]
+        self.augmentationLoaded = False
+
+
     def getData(self, index : int, a : int, post_transforms : list[Transform], isInput: bool) -> torch.Tensor:
         data = self.patch.getData(self.data[a], index, a, isInput)
         for transformFunction in post_transforms:
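Note: both rewritten load paths resolve the Save transform's target from a "filename:format" string, falling back to "mha" when no format is given. A minimal sketch of that convention as a standalone helper (assuming the no-format branch is meant to yield the bare filename):

def split_dataset_spec(spec: str) -> tuple[str, str]:
    # "path/to/Dataset:h5" -> ("path/to/Dataset", "h5"); no ":" defaults to "mha"
    parts = spec.split(":")
    return (parts[0], parts[1]) if len(parts) > 1 else (parts[0], "mha")

print(split_dataset_spec("./Dataset:h5"))  # ('./Dataset', 'h5')
print(split_dataset_spec("./Dataset"))     # ('./Dataset', 'mha')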
konfai/data/transform.py
CHANGED
@@ -6,7 +6,7 @@ from abc import ABC, abstractmethod
 import torch.nn.functional as F
 from typing import Any, Union
 
-from konfai.utils.utils import _getModule, NeedDevice, _resample_affine, _affine_matrix
+from konfai.utils.utils import _getModule, NeedDevice, _resample_affine, _affine_matrix, TransformError
 from konfai.utils.dataset import Dataset, Attribute, data_to_image, image_to_data
 from konfai.utils.config import config
 
@@ -211,39 +211,67 @@ class Resample(Transform, ABC):
             _ = cache_attribute.pop_np_array("Spacing")
         return self._resample(input, [int(size) for size in size_1])
 
-class
+class ResampleToResolution(Resample):
+
+    def __init__(self, spacing : list[Union[float, None]] = [1., 1., 1.]) -> None:
+        self.spacing = torch.tensor([0 if s < 0 else s for s in spacing])
 
-    def __init__(self, spacing : list[float] = [1., 1., 1.]) -> None:
-        self.spacing = torch.tensor(spacing, dtype=torch.float64)
-
     def transformShape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
-
-
-
+        if "Spacing" not in cache_attribute:
+            TransformError("Missing 'Spacing' in cache attributes, the data is likely not a valid image.",
+                           "Make sure your input is a image (e.g., .nii, .mha) with proper metadata.")
+        if len(shape) != len(self.spacing):
+            TransformError("Shape and spacing dimensions do not match: shape={shape}, spacing={self.spacing}")
+        image_spacing = cache_attribute.get_tensor("Spacing").flip(0)
+        spacing = self.spacing
+
+        for i, s in enumerate(self.spacing):
+            if s == 0:
+                spacing[i] = image_spacing[i]
+        resize_factor = spacing/cache_attribute.get_tensor("Spacing").flip(0)
+        return [int(x) for x in (torch.tensor(shape) * 1/resize_factor)]
 
     def __call__(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-
-
-
+        image_spacing = cache_attribute.get_tensor("Spacing").flip(0)
+        spacing = self.spacing
+        for i, s in enumerate(self.spacing):
+            if s == 0:
+                spacing[i] = image_spacing[i]
+        resize_factor = spacing/cache_attribute.get_tensor("Spacing").flip(0)
+        cache_attribute["Spacing"] = spacing.flip(0)
         cache_attribute["Size"] = np.asarray([int(x) for x in torch.tensor(input.shape[1:])])
         size = [int(x) for x in (torch.tensor(input.shape[1:]) * 1/resize_factor)]
         cache_attribute["Size"] = np.asarray(size)
         return self._resample(input, size)
 
-class
+class ResampleToSize(Resample):
 
     def __init__(self, size : list[int] = [100,512,512]) -> None:
         self.size = size
 
     def transformShape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
-
+        if "Spacing" not in cache_attribute:
+            TransformError("Missing 'Spacing' in cache attributes, the data is likely not a valid image.",
+                           "Make sure your input is a image (e.g., .nii, .mha) with proper metadata.")
+        if len(shape) != len(self.size):
+            TransformError("Shape and spacing dimensions do not match: shape={shape}, spacing={self.spacing}")
+        size = self.size
+        for i, s in enumerate(self.size):
+            if s == -1:
+                size[i] = shape[i]
+        return size
 
     def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        size = self.size
+        image_size = np.asarray([int(x) for x in torch.tensor(input.shape[1:])])
+        for i, s in enumerate(self.size):
+            if s is None:
+                size[i] = image_size[i]
         if "Spacing" in cache_attribute:
-            cache_attribute["Spacing"] = torch.flip(torch.tensor(
-        cache_attribute["Size"] =
-        cache_attribute["Size"] =
-        return self._resample(input,
+            cache_attribute["Spacing"] = torch.flip(torch.tensor(image_size)/torch.tensor(size)*torch.flip(cache_attribute.get_tensor("Spacing"), dims=[0]), dims=[0])
+        cache_attribute["Size"] = image_size
+        cache_attribute["Size"] = size
+        return self._resample(input, size)
 
 class ResampleTransform(Transform):
 
@@ -412,8 +440,8 @@ class FlatLabel(Transform):
 
 class Save(Transform):
 
-    def __init__(self,
-        self.
+    def __init__(self, dataset: str) -> None:
+        self.dataset = dataset
 
     def __call__(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
         return input
@@ -528,7 +556,7 @@ class OneHot(Transform):
         self.num_classes = num_classes
 
     def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        result = F.one_hot(input.type(torch.int64), num_classes=self.num_classes).permute(0, len(input.shape), *[i+1 for i in range(len(input.shape)-1)]).float().squeeze(
+        result = F.one_hot(input.type(torch.int64), num_classes=self.num_classes).permute(0, len(input.shape), *[i+1 for i in range(len(input.shape)-1)]).float().squeeze(0)
         return result
 
     def inverse(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
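Note: in both new resample transforms the resize factor is the ratio of target spacing to current spacing, with axes flipped because image spacing is stored x,y,z while tensor shapes run z,y,x; a sentinel value (0, -1 or None depending on the transform) means "keep the image's value on this axis". A standalone numeric sketch of the shape computation with hypothetical spacings:

import torch

spacing_current = torch.tensor([2.0, 1.0, 1.0])  # z, y, x after the flip
spacing_target  = torch.tensor([1.0, 1.0, 1.0])
shape = torch.tensor([100, 512, 512])

resize_factor = spacing_target / spacing_current
print([int(x) for x in (shape * 1 / resize_factor)])
# -> [200, 512, 512]: halving the z spacing doubles the z extent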
konfai/evaluator.py
CHANGED
@@ -9,7 +9,7 @@ import builtins
 import importlib
 from konfai import EVALUATIONS_DIRECTORY, PREDICTIONS_DIRECTORY, KONFAI_ROOT, CONFIG_FILE
 from konfai.utils.config import config
-from konfai.utils.utils import _getModule, DistributedObject, synchronize_data
+from konfai.utils.utils import _getModule, DistributedObject, synchronize_data, EvaluatorError
 from konfai.data.data_manager import DataMetric
 
 class CriterionsAttr():
@@ -54,7 +54,8 @@ class Statistics():
         if name_dataset not in self.measures:
             self.measures[name_dataset] = {}
         self.measures[name_dataset][name] = value
-
+
+    @staticmethod
     def getStatistic(values: list[float]) -> dict[str, float]:
         return {"max": np.max(values), "min": np.min(values), "std": np.std(values), "25pc": np.percentile(values, 25), "50pc": np.percentile(values, 50), "75pc": np.percentile(values, 75), "mean": np.mean(values), "count": len(values)}
 
@@ -91,7 +92,6 @@ class Evaluator(DistributedObject):
             exit(0)
         super().__init__(train_name)
         self.metric_path = EVALUATIONS_DIRECTORY()+self.name+"/"
-        self.predict_path = PREDICTIONS_DIRECTORY()+self.name+"/"
         self.metricsLoader = metrics
         self.dataset = dataset
         self.metrics = {k: v.getTargetsCriterions(k) for k, v in self.metricsLoader.items()}
@@ -102,10 +102,10 @@ class Evaluator(DistributedObject):
         result = {}
         for output_group in self.metrics:
             for target_group in self.metrics[output_group]:
-                targets = [data_dict[group][0] for group in target_group.split("/") if group in data_dict]
+                targets = [data_dict[group][0].to(0) if torch.cuda.is_available() else data_dict[group][0] for group in target_group.split("/") if group in data_dict]
                 name = data_dict[output_group][1][0]
                 for metric in self.metrics[output_group][target_group]:
-                    result["{}:{}:{}".format(output_group, target_group, metric.__class__.__name__)] = metric(data_dict[output_group][0], *targets).item()
+                    result["{}:{}:{}".format(output_group, target_group, metric.__class__.__name__)] = metric(data_dict[output_group][0].to(0) if torch.cuda.is_available() else data_dict[output_group][0], *targets).item()
                 statistics.add(result, name)
         return result
 
@@ -126,9 +126,26 @@ class Evaluator(DistributedObject):
 
         self.dataloader = self.dataset.getData(world_size)
 
+        groupsDest = [group for groups in self.dataset.groups_src.values() for group in groups]
+
+        missing_outputs = set(self.metrics.keys()) - set(groupsDest)
+        if missing_outputs:
+            raise EvaluatorError(
+                f"The following metric output groups are missing from 'groupsDest': {sorted(missing_outputs)}. ",
+                f"Available groups: {sorted(groupsDest)}"
+            )
+
+        target_groups = {target for targets in self.metrics.values() for target in targets}
+        missing_targets = target_groups - set(groupsDest)
+        if missing_targets:
+            raise EvaluatorError(
+                f"The following metric target groups are missing from 'groupsDest': {sorted(missing_targets)}. ",
+                f"Available groups: {sorted(groupsDest)}"
+            )
+
     def run_process(self, world_size: int, global_rank: int, gpu: int, dataloaders: list[DataLoader]):
         description = lambda measure : "Metric TRAIN : {} ".format(" | ".join("{}: {:.2f}".format(k, v) for k, v in measure.items()) if measure is not None else "")
-        with tqdm.tqdm(iterable = enumerate(dataloaders[0]), leave=
+        with tqdm.tqdm(iterable = enumerate(dataloaders[0]), leave=True, desc = description(None), total=len(dataloaders[0]), ncols=0) as batch_iter:
             for _, data_dict in batch_iter:
                 batch_iter.set_description(description(self.update({k: (v[0], v[4]) for k,v in data_dict.items()}, self.statistics_train)))
         outputs = synchronize_data(world_size, gpu, self.statistics_train.measures)
@@ -136,7 +153,7 @@ class Evaluator(DistributedObject):
             self.statistics_train.write(outputs)
         if len(dataloaders) == 2:
             description = lambda measure : "Metric VALIDATION : {} ".format(" | ".join("{}: {:.2f}".format(k, v) for k, v in measure.items()) if measure is not None else "")
-            with tqdm.tqdm(iterable = enumerate(dataloaders[1]), leave=
+            with tqdm.tqdm(iterable = enumerate(dataloaders[1]), leave=True, desc = description(None), total=len(dataloaders[1]), ncols=0) as batch_iter:
                 for _, data_dict in batch_iter:
                     batch_iter.set_description(description(self.update({k: (v[0], v[4]) for k,v in data_dict.items()}, self.statistics_validation)))
             outputs = synchronize_data(world_size, gpu, self.statistics_validation.measures)
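Note: the new setup checks fail fast when a metric references an output or target group the dataset never produces. A standalone sketch of the same set arithmetic with hypothetical group names:

# metric configuration: output group -> target group -> criteria (hypothetical)
metrics = {"prediction": {"ground_truth": ["Dice"]}}
groups_dest = ["prediction", "mask"]  # groups the dataset actually provides

missing_outputs = set(metrics.keys()) - set(groups_dest)
target_groups = {t for targets in metrics.values() for t in targets}
missing_targets = target_groups - set(groups_dest)
print(sorted(missing_outputs))  # [] -> every metric output is produced
print(sorted(missing_targets))  # ['ground_truth'] -> this would raise EvaluatorError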
konfai/main.py
CHANGED
@@ -3,6 +3,7 @@ import os
 from torch.cuda import device_count
 import torch.multiprocessing as mp
 from konfai.utils.utils import setup, TensorBoard, Log
+from konfai import KONFAI_NB_CORES
 
 import sys
 sys.path.insert(0, os.getcwd())
@@ -11,10 +12,10 @@ def main():
     parser = argparse.ArgumentParser(description="KonfAI", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
     try:
         with setup(parser) as distributedObject:
-            with Log(distributedObject.name):
+            with Log(distributedObject.name, 0):
                 world_size = device_count()
                 if world_size == 0:
-                    world_size =
+                    world_size = int(KONFAI_NB_CORES())
                 distributedObject.setup(world_size)
                 with TensorBoard(distributedObject.name):
                     mp.spawn(distributedObject, nprocs=world_size)
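Note: with no visible CUDA device, the spawn count now comes from KONFAI_NB_CORES, which setup() fills from the new --cpu flag. A minimal sketch of the fallback (setting the variable by hand where the CLI normally would):

import os
from torch.cuda import device_count

os.environ.setdefault("KONFAI_NB_CORES", "2")  # normally written by setup()
world_size = device_count()
if world_size == 0:
    world_size = int(os.environ["KONFAI_NB_CORES"])  # one process per CPU core
print(world_size)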
konfai/network/network.py
CHANGED
@@ -39,7 +39,6 @@ class OptimizerLoader():
         self.name = name
 
     def getOptimizer(self, key: str, parameter: Iterator[torch.nn.parameter.Parameter]) -> torch.optim.Optimizer:
-        torch.optim.AdamW
         return config("{}.Model.{}.Optimizer".format(KONFAI_ROOT(), key))(getattr(importlib.import_module('torch.optim'), self.name))(parameter, config = None)
 
 class SchedulerStep():
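Note: the deleted torch.optim.AdamW line was a stray no-op expression; the optimizer is actually resolved by name through importlib on the line below. A standalone sketch of that lookup pattern (hypothetical name and learning rate):

import importlib
import torch

name = "AdamW"  # as stored in OptimizerLoader.name
optimizer_cls = getattr(importlib.import_module("torch.optim"), name)
optimizer = optimizer_cls(torch.nn.Linear(2, 2).parameters(), lr=1e-3)
print(type(optimizer).__name__)  # AdamW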
konfai/predictor.py
CHANGED
@@ -94,10 +94,10 @@ class OutDataset(Dataset, NeedDevice, ABC):
 class OutSameAsGroupDataset(OutDataset):
 
     @config("OutDataset")
-    def __init__(self, dataset_filename: str = "./Dataset:mha", group: str = "default", sameAsGroup: str = "default", pre_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, post_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, final_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, patchCombine: Union[str, None] = None,
+    def __init__(self, dataset_filename: str = "./Dataset:mha", group: str = "default", sameAsGroup: str = "default", pre_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, post_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, final_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, patchCombine: Union[str, None] = None, reduction: str = "mean", inverse_transform: bool = True) -> None:
         super().__init__(dataset_filename, group, pre_transforms, post_transforms, final_transforms, patchCombine)
         self.group_src, self.group_dest = sameAsGroup.split(":")
-        self.
+        self.reduction = reduction
         self.inverse_transform = inverse_transform
 
     def addLayer(self, index_dataset: int, index_augmentation: int, index_patch: int, layer: torch.Tensor, dataset: DatasetIter):
@@ -151,9 +151,9 @@ class OutSameAsGroupDataset(OutDataset):
         self.output_layer_accumulator.pop(index)
         dtype = result.dtype
 
-        if self.
+        if self.reduction == "mean":
             result = torch.mean(result.float(), dim=0).to(dtype)
-        elif self.
+        elif self.reduction == "median":
             result, _ = torch.median(result.float(), dim=0)
         else:
             raise NameError("Reduction method does not exist (mean, median)")
@@ -201,7 +201,7 @@ class OutDatasetLoader():
 
 class _Predictor():
 
-    def __init__(self, world_size: int, global_rank: int, local_rank: int, predict_path: str, data_log: Union[list[str], None], outsDataset: dict[str, OutDataset], modelComposite: DDP, dataloader_prediction: DataLoader) -> None:
+    def __init__(self, world_size: int, global_rank: int, local_rank: int, autocast: bool, predict_path: str, data_log: Union[list[str], None], outsDataset: dict[str, OutDataset], modelComposite: DDP, dataloader_prediction: DataLoader) -> None:
         self.world_size = world_size
         self.global_rank = global_rank
         self.local_rank = local_rank
@@ -209,7 +209,7 @@ class _Predictor():
         self.modelComposite = modelComposite
         self.dataloader_prediction = dataloader_prediction
         self.outsDataset = outsDataset
-
+        self.autocast = autocast
 
         self.it = 0
 
@@ -239,21 +239,22 @@ class _Predictor():
         self.modelComposite.eval()
         self.modelComposite.module.setState(NetState.PREDICTION)
         desc = lambda : "Prediction : {}".format(description(self.modelComposite))
-        self.dataloader_prediction.dataset.load()
-        with tqdm.tqdm(iterable = enumerate(self.dataloader_prediction), leave=
+        self.dataloader_prediction.dataset.load("Prediction")
+        with tqdm.tqdm(iterable = enumerate(self.dataloader_prediction), leave=True, desc = desc(), total=len(self.dataloader_prediction), ncols=0) as batch_iter:
             dist.barrier()
             for it, data_dict in batch_iter:
-
-
-                self.
-
-
-
-
-
-
-
-
+                with torch.amp.autocast('cuda', enabled=self.autocast):
+                    input = self.getInput(data_dict)
+                    for name, output in self.modelComposite(input, list(self.outsDataset.keys())):
+                        self._predict_log(data_dict)
+                        outDataset = self.outsDataset[name]
+                        for i, (index, patch_augmentation, patch_index) in enumerate([(int(index), int(patch_augmentation), int(patch_index)) for index, patch_augmentation, patch_index in zip(list(data_dict.values())[0][1], list(data_dict.values())[0][2], list(data_dict.values())[0][3])]):
+                            outDataset.addLayer(index, patch_augmentation, patch_index, output[i].cpu(), self.dataset)
+                            if outDataset.isDone(index):
+                                outDataset.write(index, self.dataset.getDatasetFromIndex(list(data_dict.keys())[0], index).name.split("/")[-1], outDataset.getOutput(index, self.dataset))
+
+                batch_iter.set_description(desc())
+                self.it += 1
 
     def _predict_log(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]):
         measures = DistributedObject.getMeasure(self.world_size, self.global_rank, self.local_rank, {"" : self.modelComposite.module}, 1)
@@ -320,6 +321,7 @@ class Predictor(DistributedObject):
                  train_name: str = "name",
                  manual_seed : Union[int, None] = None,
                  gpu_checkpoints: Union[list[str], None] = None,
+                 autocast : bool = False,
                  outsDataset: Union[dict[str, OutDatasetLoader], None] = {"default:Default" : OutDatasetLoader()},
                  images_log: list[str] = []) -> None:
         if os.environ["KONFAI_CONFIG_MODE"] != "Done":
@@ -328,6 +330,7 @@ class Predictor(DistributedObject):
         self.manual_seed = manual_seed
         self.dataset = dataset
         self.combine = combine
+        self.autocast = autocast
 
         self.model = model.getModel(train=False)
         self.it = 0
@@ -384,7 +387,7 @@ class Predictor(DistributedObject):
 
         shutil.copyfile(CONFIG_FILE(), self.predict_path+"Prediction.yml")
 
-        self.model.init(autocast
+        self.model.init(self.autocast, State.PREDICTION, self.dataset.getGroupsDest())
         self.model.init_outputsGroup()
         self.model._compute_channels_trace(self.model, self.model.in_channels, None, self.gpu_checkpoints)
         self.modelComposite = ModelComposite(self.model, len(MODEL().split(":")), self.combine)
@@ -402,7 +405,7 @@ class Predictor(DistributedObject):
     def run_process(self, world_size: int, global_rank: int, local_rank: int, dataloaders: list[DataLoader]):
         modelComposite = Network.to(self.modelComposite, local_rank*self.size)
         modelComposite = DDP(modelComposite, static_graph=True) if torch.cuda.is_available() else CPU_Model(modelComposite)
-        with _Predictor(world_size, global_rank, local_rank, self.predict_path, self.images_log, self.outsDataset, modelComposite, *dataloaders) as p:
+        with _Predictor(world_size, global_rank, local_rank, self.autocast, self.predict_path, self.images_log, self.outsDataset, modelComposite, *dataloaders) as p:
             p.run()
 
 
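Note: the restored prediction loop now wraps the forward pass in torch.amp.autocast, gated by the new autocast flag so the same code path serves both modes. A minimal standalone sketch (hypothetical toy model):

import torch

model = torch.nn.Linear(8, 2)
x = torch.randn(4, 8)
use_autocast = False  # mirrors Predictor(autocast=...)

# enabled=False turns the context into a no-op, so no separate branch is needed
with torch.amp.autocast('cuda', enabled=use_autocast):
    with torch.no_grad():
        y = model(x)
print(y.dtype)  # float32 here; reduced precision inside an enabled CUDA region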
konfai/trainer.py
CHANGED
@@ -17,7 +17,6 @@ from konfai.utils.config import config
 from konfai.utils.utils import State, DataLog, DistributedObject, description, TrainerError
 from konfai.network.network import Network, ModelLoader, NetState, CPU_Model
 
-
 class EarlyStoppingBase:
 
     def __init__(self):
@@ -99,7 +98,7 @@ class _Trainer():
         self.autocast = autocast
         self.modelEMA = modelEMA
         self.early_stopping = EarlyStoppingBase() if early_stopping is None else early_stopping
-
+
         self.it_validation = it_validation
         if self.it_validation is None:
             self.it_validation = len(dataloader_training)
@@ -118,13 +117,14 @@ class _Trainer():
         self.tb.close()
 
     def run(self) -> None:
-
+        self.dataloader_training.dataset.load("Train")
+        self.dataloader_validation.dataset.load("Validation")
+        with tqdm.tqdm(iterable = range(self.epoch, self.epochs), leave=False, total=self.epochs, initial=self.epoch, desc="Progress") as epoch_tqdm:
             for self.epoch in epoch_tqdm:
-                self.dataloader_training.dataset.load()
                 self.train()
                 if self.early_stopping.isStopped():
                     break
-                self.dataloader_training.dataset.resetAugmentation()
+                self.dataloader_training.dataset.resetAugmentation("Train")
 
     def getInput(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int, str, bool]]) -> dict[tuple[str, bool], torch.Tensor]:
         return {(k, v[5][0].item()) : v[0] for k, v in data_dict.items()}
@@ -137,7 +137,8 @@ class _Trainer():
             self.modelEMA.module.setState(NetState.TRAIN)
 
         desc = lambda : "Training : {}".format(description(self.model, self.modelEMA))
-
+
+        with tqdm.tqdm(iterable = enumerate(self.dataloader_training), desc = desc(), total=len(self.dataloader_training), leave=False, ncols=0) as batch_iter:
             for _, data_dict in batch_iter:
                 with torch.amp.autocast('cuda', enabled=self.autocast):
                     input = self.getInput(data_dict)
@@ -171,8 +172,7 @@ class _Trainer():
 
         desc = lambda : "Validation : {}".format(description(self.model, self.modelEMA))
         data_dict = None
-        self.dataloader_validation
-        with tqdm.tqdm(iterable = enumerate(self.dataloader_validation), desc = desc(), total=len(self.dataloader_validation), leave=False, disable=self.global_rank != 0 and "KONFAI_CLUSTER" not in os.environ) as batch_iter:
+        with tqdm.tqdm(iterable = enumerate(self.dataloader_validation), desc = desc(), total=len(self.dataloader_validation), leave=False, ncols=0) as batch_iter:
             for _, data_dict in batch_iter:
                 input = self.getInput(data_dict)
                 self.model(input)
@@ -180,7 +180,7 @@ class _Trainer():
                     self.modelEMA.module(input)
 
                 batch_iter.set_description(desc())
-        self.dataloader_validation.dataset.resetAugmentation()
+        self.dataloader_validation.dataset.resetAugmentation("Validation")
         dist.barrier()
         self.model.train()
         self.model.module.setState(NetState.TRAIN)
@@ -360,7 +360,7 @@ class Trainer(DistributedObject):
             os.makedirs(dir)
 
         for name in sorted(os.listdir(path_checkpoint)):
-            checkpoint = torch.load(path_checkpoint+name, weights_only=False)
+            checkpoint = torch.load(path_checkpoint+name, weights_only=False, map_location='cpu')
             self.model.load(checkpoint, init=False, ema=False)
 
         torch.save(self.model, "{}Serialized/{}".format(path_model, name))
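Note: map_location='cpu' makes checkpoint serialization work on machines that lack the GPU the checkpoint was saved from. A minimal standalone sketch with a throwaway file:

import torch

torch.save({"w": torch.randn(3, 3)}, "/tmp/checkpoint.pt")  # stand-in checkpoint
# map_location remaps any CUDA-saved tensors onto the CPU before loading
checkpoint = torch.load("/tmp/checkpoint.pt", weights_only=False, map_location="cpu")
print(checkpoint["w"].device)  # cpu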
konfai/utils/dataset.py
CHANGED
@@ -348,6 +348,10 @@ class Dataset():
     @abstractmethod
     def getNames(self, group: str) -> list[str]:
         pass
+
+    @abstractmethod
+    def getGroup(self) -> list[str]:
+        pass
 
     @abstractmethod
     def isExist(self, group: str, name: Union[str, None] = None) -> bool:
@@ -458,6 +462,9 @@ class Dataset():
                 names.extend(self.getNames("/".join(groups.split("/")[1:]), h5_group[group]))
         return names
 
+    def getGroup(self):
+        return self.h5.keys()
+
     def _getDataset(self, groups: str, name: str, h5_group: h5py.Group = None) -> h5py.Dataset:
         if h5_group is None:
             h5_group = self.h5
@@ -613,7 +620,10 @@ class Dataset():
 
     def getNames(self, group: str) -> list[str]:
         raise NotImplementedError()
-
+
+    def getGroup(self):
+        raise NotImplementedError()
+
     def getInfos(self, group: str, name: str) -> tuple[list[int], Attribute]:
         attributes = Attribute()
         if os.path.exists("{}{}{}.{}".format(self.filename, group if group is not None else "", name, self.format)):
@@ -736,7 +746,7 @@ class Dataset():
                 subDirectories.extend(self._getSubDirectories("/".join(groups.split("/")[1:]), subDirectory))
         return subDirectories
 
-    def getNames(self, groups: str, index: Union[list[int], None] = None
+    def getNames(self, groups: str, index: Union[list[int], None] = None) -> list[str]:
         names = []
         if self.is_directory:
             for subDirectory in self._getSubDirectories(groups):
@@ -752,6 +762,21 @@ class Dataset():
                 names = file.getNames(groups)
         return [name for i, name in enumerate(sorted(names)) if index is None or i in index]
 
+    def getGroup(self):
+        if self.is_directory:
+            groups = set()
+            for root, _, files in os.walk(self.filename):
+                for file in files:
+                    path = os.path.relpath(os.path.join(root, file.split(".")[0]), self.filename)
+                    parts = path.split("/")
+                    if len(parts) >= 2:
+                        del parts[-2]
+                        groups.add("/".join(parts))
+        else:
+            with Dataset.File(self.filename, True, self.format) as file:
+                groups = file.getGroup()
+        return list(groups)
+
     def getInfos(self, groups: str, name: str) -> tuple[list[int], Attribute]:
         if self.is_directory:
             for subDirectory in self._getSubDirectories(groups):
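Note: for directory-backed datasets, getGroup appears to assume a <dataset>/<name>/<group>.<ext> layout: deleting the second-to-last path component drops the entry name and leaves the group path. A standalone sketch of that path manipulation under that assumed layout:

# hypothetical layout: ./Dataset/case_01/image.mha -> group "image"
paths = ["case_01/image", "case_01/mask", "case_02/image"]

groups = set()
for path in paths:
    parts = path.split("/")
    if len(parts) >= 2:
        del parts[-2]               # drop the entry-name component
        groups.add("/".join(parts))
print(sorted(groups))               # ['image', 'mask']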
konfai/utils/utils.py
CHANGED
@@ -18,13 +18,15 @@ import random
 from torch.utils.data import DataLoader
 import torch.nn.functional as F
 import sys
+import re
+
 
 def description(model, modelEMA = None, showMemory: bool = True) -> str:
     values_desc = lambda weights, values: " ".join(["{}({:.2f}) : {:.6f}".format(name.split(":")[-1], weight, value) for (name, value), weight in zip(values.items(), weights.values())])
     model_desc = lambda model : "("+" ".join(["{}({:.6f}) : {}".format(name, network.optimizer.param_groups[0]['lr'] if network.optimizer is not None else 0, values_desc(network.measure.getLastWeights(), network.measure.getLastValues())) for name, network in model.module.getNetworks().items() if network.measure is not None])+")"
     result = "Loss {}".format(model_desc(model))
     if modelEMA is not None:
-        result += "Loss EMA {}".format(model_desc(modelEMA))
+        result += " Loss EMA {}".format(model_desc(modelEMA))
     result += " "+gpuInfo()
     if showMemory:
         result +=" | {}".format(memoryInfo())
@@ -237,7 +239,7 @@ class DataLog(Enum):
     AUDIO = lambda tb, name, layer, it : tb.add_audio(name, _logImageFormat(layer), it)
 
 class Log:
-    def __init__(self, name: str) -> None:
+    def __init__(self, name: str, rank: int) -> None:
         if KONFAI_STATE() == "PREDICTION":
             path = PREDICTIONS_DIRECTORY()
         elif KONFAI_STATE() == "EVALUATION":
@@ -248,11 +250,12 @@ class Log:
         self.verbose = os.environ.get("KONFAI_VERBOSE", "True") == "True"
         self.log_path = os.path.join(path, name)
         os.makedirs(self.log_path, exist_ok=True)
-
-        self.file = open(os.path.join(self.log_path, "
+        self.rank = rank
+        self.file = open(os.path.join(self.log_path, "log_{}.txt".format(rank)), "w", buffering=1)
         self.stdout_bak = sys.stdout
         self.stderr_bak = sys.stderr
-
+        self._buffered_line = ""
+
     def __enter__(self):
         self.file.__enter__()
         sys.stdout = self
@@ -264,12 +267,26 @@ class Log:
         sys.stdout = self.stdout_bak
         sys.stderr = self.stderr_bak
 
-    def write(self, msg):
+    def write(self, msg: str):
         if not msg:
             return
-
-
-
+
+
+        ANSI_ESCAPE = re.compile(r'\x1b\[[0-9;]*[a-zA-Z]')
+        CARRIAGE_RETURN = re.compile(r'(?:\r|\x1b\[A).*')
+        msg_clean = ANSI_ESCAPE.sub('', msg)
+        if '\r' in msg_clean or '[A' in msg:
+            # keep only the content after the last carriage return
+            msg_clean = msg_clean.split('\r')[-1].strip()
+            self._buffered_line = msg_clean
+        else:
+            self._buffered_line = msg_clean.strip()
+
+        if self._buffered_line:
+            # write to the log file
+            self.file.write(self._buffered_line + "\n")
+            self.file.flush()
+        if self.verbose and (self.rank == 0 or "KONFAI_CLUSTER" in os.environ):
             sys.__stdout__.write(msg)
             sys.__stdout__.flush()
 
@@ -352,7 +369,7 @@ class DistributedObject():
         return result
 
     def __call__(self, rank: Union[int, None] = None) -> None:
-        with Log(self.name):
+        with Log(self.name, rank):
             world_size = len(self.dataloader)
             global_rank, local_rank = setupGPU(world_size, self.port, rank)
             if global_rank is None:
@@ -385,6 +402,7 @@ def setup(parser: argparse.ArgumentParser) -> DistributedObject:
     KONFAI_args.add_argument('-tb', action='store_true', help='Start TensorBoard')
     KONFAI_args.add_argument("-c", "--config", type=str, default="None", help="Configuration file location")
     KONFAI_args.add_argument("-g", "--gpu", type=str, default=os.environ["CUDA_VISIBLE_DEVICES"] if "CUDA_VISIBLE_DEVICES" in os.environ else "", help="List of GPU")
+    KONFAI_args.add_argument("-cpu", "--cpu", type=str, default="1" , help="List of GPU")
     KONFAI_args.add_argument('--num-workers', '--num_workers', default=4, type=int, help='No. of workers per DataLoader & GPU')
     KONFAI_args.add_argument("-models_dir", "--MODELS_DIRECTORY", type=str, default="./Models/", help="Models location")
     KONFAI_args.add_argument("-checkpoints_dir", "--CHECKPOINTS_DIRECTORY", type=str, default="./Checkpoints/", help="Checkpoints location")
@@ -401,6 +419,7 @@ def setup(parser: argparse.ArgumentParser) -> DistributedObject:
     config = vars(args)
 
     os.environ["CUDA_VISIBLE_DEVICES"] = config["gpu"]
+    os.environ["KONFAI_NB_CORES"] = config["cpu"]
     os.environ["KONFAI_MODELS_DIRECTORY"] = config["MODELS_DIRECTORY"]
     os.environ["KONFAI_CHECKPOINTS_DIRECTORY"] = config["CHECKPOINTS_DIRECTORY"]
     os.environ["KONFAI_PREDICTIONS_DIRECTORY"] = config["PREDICTIONS_DIRECTORY"]
@@ -465,6 +484,14 @@ def setupGPU(world_size: int, port: int, rank: Union[int, None] = None) -> tuple
     if torch.cuda.is_available():
         torch.cuda.empty_cache()
         dist.init_process_group("nccl", rank=global_rank, init_method="tcp://{}:{}".format(host_name, port), world_size=world_size)
+    else:
+        if not dist.is_initialized():
+            dist.init_process_group(
+                backend="gloo",
+                init_method=f"tcp://{host_name}:{port}",
+                rank=global_rank,
+                world_size=world_size
+            )
     return global_rank, local_rank
 
 import socket
@@ -524,8 +551,8 @@ SUPPORTED_EXTENSIONS = [
 
 class KonfAIError(Exception):
 
-    def __init__(self, typeError: str,
-        super().__init__("\n[{}] {}".format(typeError,
+    def __init__(self, typeError: str, messages: list[str]) -> None:
+        super().__init__("\n[{}] {}".format(typeError, messages[0])+("\n" if len(messages) > 0 else "")+"\n→\t".join(messages[1:]))
 
 
 class ConfigError(KonfAIError):
@@ -553,3 +580,13 @@ class AugmentationError(KonfAIError):
 
     def __init__(self, *message) -> None:
         super().__init__("Augmentation", message)
+
+class EvaluatorError(KonfAIError):
+
+    def __init__(self, *message) -> None:
+        super().__init__("Evaluator", message)
+
+class TransformError(KonfAIError):
+
+    def __init__(self, *message) -> None:
+        super().__init__("Transform", message)
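Note: Log.write now strips ANSI escape codes and collapses carriage-return redraws so tqdm progress bars land in the per-rank log file as single clean lines. A standalone sketch of the cleaning step on a hypothetical tqdm message:

import re

ANSI_ESCAPE = re.compile(r'\x1b\[[0-9;]*[a-zA-Z]')

msg = "\r\x1b[32mCaching Train: 12%\x1b[0m"   # colored redraw as tqdm emits it
msg_clean = ANSI_ESCAPE.sub('', msg)
if '\r' in msg_clean:
    msg_clean = msg_clean.split('\r')[-1].strip()  # keep text after last redraw
print(msg_clean)  # Caching Train: 12%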
{konfai-1.1.1.dist-info → konfai-1.1.2.dist-info}/RECORD
CHANGED
@@ -1,13 +1,13 @@
-konfai/__init__.py,sha256=
-konfai/evaluator.py,sha256=
-konfai/main.py,sha256=
-konfai/predictor.py,sha256=
-konfai/trainer.py,sha256=
+konfai/__init__.py,sha256=YXG-wpSEXWs6Jt3BDI77V4r89gEUNX-6lxW9btj5VYI,851
+konfai/evaluator.py,sha256=rAhfdRemMjzC3VoaqyQKJR0SBekuLDiLT1nhblH8RQk,8293
+konfai/main.py,sha256=rTTJl-biaX4CkLNxPtqwwsrybXSDxjWlU39TE2ImU5o,2574
+konfai/predictor.py,sha256=e5V7awlIRDE1NkLTR_D6fUuDg93ZFGaGLTDR71KeqM8,22549
+konfai/trainer.py,sha256=NQDzp4hUKYAcDdKmeBcAE9QNFW9KGW2qV-7q_PXeWi8,19509
 konfai/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 konfai/data/augmentation.py,sha256=ASmKWBpykLBHDB_YeTgoBTlqvJ06v5OUG7n7ugPN6NU,31718
-konfai/data/data_manager.py,sha256=
-konfai/data/patching.py,sha256=
-konfai/data/transform.py,sha256=
+konfai/data/data_manager.py,sha256=9VpF4CTTAnqS1Hq0csbTE-XAocRwziFssITg-SZhA4A,29603
+konfai/data/patching.py,sha256=OtGNs99jKQSYmuj8B3MNGcbupKOcX5PUpMCIe0pKnX0,15636
+konfai/data/transform.py,sha256=dTZh2CgfxNitHW-HEkO0jQ9GG26-Wf-YlI-qoI2mqC0,26472
 konfai/metric/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 konfai/metric/measure.py,sha256=K7s15wNJm1_Iaav9y7Oe_UhSnbLUiybhLTVumHZ8ZaY,22001
 konfai/metric/schedulers.py,sha256=UoSr1TW_hrus3DhvOEbDefxCSUGz7lJS_8vbz0GEye8,1370
@@ -24,16 +24,16 @@ konfai/models/segmentation/NestedUNet.py,sha256=6XGizAIc4bDL8vx4AHW8BBFjUvovRYcj
 konfai/models/segmentation/UNet.py,sha256=9-g63oNqaxSlGmrD1-IVzJ-kac9QphmM7i-3rEsH23I,4218
 konfai/network/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 konfai/network/blocks.py,sha256=RaTI0Lrvq1V-GIFei-WTUB6wlg4LydZksAyJ8DMk40M,13502
-konfai/network/network.py,sha256=
+konfai/network/network.py,sha256=SnwQUKFTZZU8zuix5sm0vW8s0G3h6Fa8pHoIcwC984Q,46512
 konfai/utils/ITK.py,sha256=OxTieDNNYHGkn7zxJsAG-6ecRG1VYMvn1dlBbBe1DOs,13955
 konfai/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 konfai/utils/config.py,sha256=OsCUqgNaPX56nD_MpJ2WRpiQ7xxqHonfPYaS9VGtsl8,12413
-konfai/utils/dataset.py,sha256=
+konfai/utils/dataset.py,sha256=6ZzevdhJ7e5zlXATAVwSh9O6acKXM7gYNxkMAa5DrmM,36351
 konfai/utils/registration.py,sha256=v1srEBOcgDnHrx0YtsK6bcj0yCMH7wNeaQ3wC7gEvOw,8898
-konfai/utils/utils.py,sha256=
-konfai-1.1.
-konfai-1.1.
-konfai-1.1.
-konfai-1.1.
-konfai-1.1.
-konfai-1.1.
+konfai/utils/utils.py,sha256=Laq8bGc5mGKFZlJIkHxa-BrC9uR2F7MTRuL4YFAIxQY,23439
+konfai-1.1.2.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+konfai-1.1.2.dist-info/METADATA,sha256=Usg7hzW1xmlHswHZb3fC2Hgl2V1Jkq43J0T5j5TDEIw,2515
+konfai-1.1.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+konfai-1.1.2.dist-info/entry_points.txt,sha256=fG82HRN5-g39ACSOCtij_I3N6EHxfYnMR0D7TI_8pW8,81
+konfai-1.1.2.dist-info/top_level.txt,sha256=xF470dkIlFoFqTZEOlRehKJr4WU_8OKGXrJqYm9vWKs,7
+konfai-1.1.2.dist-info/RECORD,,

{konfai-1.1.1.dist-info → konfai-1.1.2.dist-info}/WHEEL: file without changes
{konfai-1.1.1.dist-info → konfai-1.1.2.dist-info}/entry_points.txt: file without changes
{konfai-1.1.1.dist-info → konfai-1.1.2.dist-info}/licenses/LICENSE: file without changes
{konfai-1.1.1.dist-info → konfai-1.1.2.dist-info}/top_level.txt: file without changes