konfai 1.0.9__py3-none-any.whl → 1.1.1__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of konfai might be problematic. Click here for more details.
- konfai/data/augmentation.py +2 -2
- konfai/data/data_manager.py +111 -31
- konfai/data/patching.py +5 -2
- konfai/data/transform.py +0 -1
- konfai/main.py +4 -3
- konfai/metric/measure.py +1 -1
- konfai/models/registration/registration.py +0 -1
- konfai/network/blocks.py +0 -1
- konfai/network/network.py +29 -15
- konfai/trainer.py +126 -41
- konfai/utils/config.py +12 -12
- konfai/utils/utils.py +61 -14
- {konfai-1.0.9.dist-info → konfai-1.1.1.dist-info}/METADATA +1 -1
- {konfai-1.0.9.dist-info → konfai-1.1.1.dist-info}/RECORD +18 -18
- {konfai-1.0.9.dist-info → konfai-1.1.1.dist-info}/WHEEL +0 -0
- {konfai-1.0.9.dist-info → konfai-1.1.1.dist-info}/entry_points.txt +0 -0
- {konfai-1.0.9.dist-info → konfai-1.1.1.dist-info}/licenses/LICENSE +0 -0
- {konfai-1.0.9.dist-info → konfai-1.1.1.dist-info}/top_level.txt +0 -0
konfai/data/augmentation.py
CHANGED
|
@@ -8,7 +8,7 @@ from typing import Union
|
|
|
8
8
|
import os
|
|
9
9
|
from konfai import KONFAI_ROOT
|
|
10
10
|
from konfai.utils.config import config
|
|
11
|
-
from konfai.utils.utils import _getModule
|
|
11
|
+
from konfai.utils.utils import _getModule, AugmentationError
|
|
12
12
|
from konfai.utils.dataset import Attribute, data_to_image
|
|
13
13
|
|
|
14
14
|
|
|
@@ -222,7 +222,7 @@ class ColorTransform(DataAugmentation):
|
|
|
222
222
|
matrix = matrix[:, :3, :].mean(dim=1, keepdims=True).to(input.device)
|
|
223
223
|
result = result.float() * matrix[:, :, :3].sum(dim=2, keepdims=True) + matrix[:, :, 3:]
|
|
224
224
|
else:
|
|
225
|
-
raise
|
|
225
|
+
raise AugmentationError('Image must be RGB (3 channels) or L (1 channel)')
|
|
226
226
|
results.append(result.reshape(input.shape))
|
|
227
227
|
return results
|
|
228
228
|
|
konfai/data/data_manager.py
CHANGED
|
@@ -11,11 +11,12 @@ from typing import Union, Iterator
|
|
|
11
11
|
from concurrent.futures import ThreadPoolExecutor, as_completed
|
|
12
12
|
import threading
|
|
13
13
|
from torch.cuda import device_count
|
|
14
|
+
import SimpleITK as sitk
|
|
14
15
|
|
|
15
16
|
from konfai import KONFAI_STATE, KONFAI_ROOT
|
|
16
17
|
from konfai.data.patching import DatasetPatch, DatasetManager
|
|
17
18
|
from konfai.utils.config import config
|
|
18
|
-
from konfai.utils.utils import memoryInfo, cpuInfo, memoryForecast, getMemory, State
|
|
19
|
+
from konfai.utils.utils import memoryInfo, cpuInfo, memoryForecast, getMemory, State, SUPPORTED_EXTENSIONS, DatasetManagerError
|
|
19
20
|
from konfai.utils.dataset import Dataset, Attribute
|
|
20
21
|
from konfai.data.transform import TransformLoader, Transform
|
|
21
22
|
from konfai.data.augmentation import DataAugmentationsList
|
|
@@ -236,6 +237,8 @@ class Subset():
|
|
|
236
237
|
index = random.sample(index, len(index))
|
|
237
238
|
return set([names_filtred[i] for i in index])
|
|
238
239
|
|
|
240
|
+
def __str__(self):
|
|
241
|
+
return "Subset : " + str(self.subset) + " shuffle : "+ str(self.shuffle) + " filter : "+ str(self.filter)
|
|
239
242
|
class TrainSubset(Subset):
|
|
240
243
|
|
|
241
244
|
@config()
|
|
@@ -257,14 +260,14 @@ class Data(ABC):
|
|
|
257
260
|
subset : Union[Subset, dict[str, Subset]],
|
|
258
261
|
num_workers : int,
|
|
259
262
|
batch_size : int,
|
|
260
|
-
|
|
263
|
+
validation: Union[float, str, list[int], list[str]] = 1,
|
|
261
264
|
inlineAugmentations: bool = False,
|
|
262
265
|
dataAugmentationsList: dict[str, DataAugmentationsList]= {}) -> None:
|
|
263
266
|
self.dataset_filenames = dataset_filenames
|
|
264
267
|
self.subset = subset
|
|
265
268
|
self.groups_src = groups_src
|
|
266
269
|
self.patch = patch
|
|
267
|
-
self.
|
|
270
|
+
self.validation = validation
|
|
268
271
|
self.dataAugmentationsList = dataAugmentationsList
|
|
269
272
|
self.batch_size = batch_size
|
|
270
273
|
self.dataSet_args = dict(groups_src=self.groups_src, inlineAugmentations=inlineAugmentations, dataAugmentationsList = list(self.dataAugmentationsList.values()), use_cache = use_cache, buffer_size=batch_size+1, patch_size=self.patch.patch_size if self.patch is not None else None, overlap=self.patch.overlap if self.patch is not None else None)
|
|
@@ -290,6 +293,13 @@ class Data(ABC):
|
|
|
290
293
|
map.append((x, y, z))
|
|
291
294
|
return data, map
|
|
292
295
|
|
|
296
|
+
def getGroupsDest(self):
|
|
297
|
+
groupsDest = []
|
|
298
|
+
for group_src in self.groups_src:
|
|
299
|
+
for group_dest in self.groups_src[group_src]:
|
|
300
|
+
groupsDest.append(group_dest)
|
|
301
|
+
return groupsDest
|
|
302
|
+
|
|
293
303
|
def _split(map: list[tuple[int, int, int]], world_size: int) -> list[list[tuple[int, int, int]]]:
|
|
294
304
|
if len(map) == 0:
|
|
295
305
|
return [[] for _ in range(world_size)]
|
|
@@ -313,7 +323,14 @@ class Data(ABC):
|
|
|
313
323
|
|
|
314
324
|
def getData(self, world_size: int) -> list[list[DataLoader]]:
|
|
315
325
|
datasets: dict[str, list[(str, bool)]] = {}
|
|
326
|
+
if self.dataset_filenames is None or len(self.dataset_filenames) == 0:
|
|
327
|
+
raise DatasetManagerError("No dataset filenames were provided")
|
|
316
328
|
for dataset_filename in self.dataset_filenames:
|
|
329
|
+
if dataset_filename is None:
|
|
330
|
+
raise DatasetManagerError("Invalid dataset entry: 'None' received.",
|
|
331
|
+
"Each dataset must be a valid path string (e.g., './Dataset/', './Dataset/:mha, './Dataset/:a:mha', './Dataset/:i:mha').",
|
|
332
|
+
"Please check your 'dataset_filenames' list for missing or null entries."
|
|
333
|
+
)
|
|
317
334
|
if len(dataset_filename.split(":")) == 1:
|
|
318
335
|
filename = dataset_filename
|
|
319
336
|
format = "mha"
|
|
@@ -324,9 +341,13 @@ class Data(ABC):
|
|
|
324
341
|
else:
|
|
325
342
|
filename, flag, format = dataset_filename.split(":")
|
|
326
343
|
append = flag == "a"
|
|
327
|
-
|
|
328
|
-
|
|
344
|
+
|
|
345
|
+
if format not in SUPPORTED_EXTENSIONS:
|
|
346
|
+
raise DatasetManagerError(f"Unsupported file format '{format}'.",
|
|
347
|
+
f"Supported extensions are: {', '.join(SUPPORTED_EXTENSIONS)}")
|
|
329
348
|
|
|
349
|
+
dataset = Dataset(filename, format)
|
|
350
|
+
|
|
330
351
|
self.datasets[filename] = dataset
|
|
331
352
|
for group in self.groups_src:
|
|
332
353
|
if dataset.isGroupExist(group):
|
|
@@ -334,16 +355,30 @@ class Data(ABC):
|
|
|
334
355
|
datasets[group].append((filename, append))
|
|
335
356
|
else:
|
|
336
357
|
datasets[group] = [(filename, append)]
|
|
358
|
+
modelHaveInput = False
|
|
337
359
|
for group_src in self.groups_src:
|
|
338
360
|
if group_src not in datasets:
|
|
339
|
-
raise ValueError("[DatasetManager] Error: group source {} not found. Available groups: {}".format(group_src, list(datasets.keys())))
|
|
340
361
|
|
|
362
|
+
raise DatasetManagerError(
|
|
363
|
+
f"Group source '{group_src}' not found in any dataset.",
|
|
364
|
+
f"Dataset filenames provided: {self.dataset_filenames}",
|
|
365
|
+
f"Available groups across all datasets: {sorted(list(datasets.keys()))}",
|
|
366
|
+
f"Please check that an entry in the dataset with the name '{group_src}.{format}' exists."
|
|
367
|
+
)
|
|
368
|
+
|
|
341
369
|
for group_dest in self.groups_src[group_src]:
|
|
342
370
|
self.groups_src[group_src][group_dest].load(group_src, group_dest, [self.datasets[filename] for filename, _ in datasets[group_src]])
|
|
371
|
+
modelHaveInput |= self.groups_src[group_src][group_dest].isInput
|
|
372
|
+
|
|
373
|
+
if not modelHaveInput:
|
|
374
|
+
raise DatasetManagerError(
|
|
375
|
+
"At least one group must be defined with 'isInput: true' to provide input to the network."
|
|
376
|
+
)
|
|
377
|
+
|
|
343
378
|
for key, dataAugmentations in self.dataAugmentationsList.items():
|
|
344
379
|
dataAugmentations.load(key)
|
|
345
380
|
|
|
346
|
-
names
|
|
381
|
+
names = set()
|
|
347
382
|
dataset_name : dict[str, dict[str, list[str]]] = {}
|
|
348
383
|
dataset_info : dict[str, dict[str, dict[str, Attribute]]] = {}
|
|
349
384
|
for group in self.groups_src:
|
|
@@ -351,10 +386,11 @@ class Data(ABC):
|
|
|
351
386
|
dataset_name[group] = {}
|
|
352
387
|
dataset_info[group] = {}
|
|
353
388
|
for filename, _ in datasets[group]:
|
|
354
|
-
names.
|
|
389
|
+
names.update(self.datasets[filename].getNames(group))
|
|
355
390
|
dataset_name[group][filename] = self.datasets[filename].getNames(group)
|
|
356
391
|
dataset_info[group][filename] = {name: self.datasets[filename].getInfos(group, name) for name in dataset_name[group][filename]}
|
|
357
392
|
subset_names = set()
|
|
393
|
+
|
|
358
394
|
if isinstance(self.subset, dict):
|
|
359
395
|
for filename, subset in self.subset.items():
|
|
360
396
|
subset_names.update(subset([dataset_name[group][filename] for group in dataset_name], [dataset_info[group][filename] for group in dataset_name]))
|
|
@@ -368,41 +404,86 @@ class Data(ABC):
|
|
|
368
404
|
subset_names.update(self.subset([dataset_name[group][filename]], [dataset_info[group][filename]]))
|
|
369
405
|
else:
|
|
370
406
|
subset_names.intersection(self.subset([dataset_name[group][filename]], [dataset_info[group][filename]]))
|
|
407
|
+
if len(subset_names) == 0:
|
|
408
|
+
raise DatasetManagerError("All data entries were excluded by the subset filter.",
|
|
409
|
+
f"Dataset entries found: {', '.join(names)}",
|
|
410
|
+
f"Subset object applied: {self.subset}",
|
|
411
|
+
f"Subset requested : {', '.join(subset_names)}",
|
|
412
|
+
"None of the dataset entries matched the given subset.",
|
|
413
|
+
"Please check your 'subset' configuration — it may be too restrictive or incorrectly formatted.",
|
|
414
|
+
"Examples of valid subset formats:",
|
|
415
|
+
"\tsubset: [0, 1] # explicit indices",
|
|
416
|
+
"\tsubset: 0:10 # slice notation",
|
|
417
|
+
"\tsubset: ./Validation.txt # external file",
|
|
418
|
+
"\tsubset: None # to disable filtering"
|
|
419
|
+
)
|
|
420
|
+
|
|
371
421
|
data, map = self._getDatasets(list(subset_names), dataset_name)
|
|
372
|
-
|
|
422
|
+
|
|
373
423
|
train_map = map
|
|
374
424
|
validate_map = []
|
|
375
|
-
if isinstance(self.
|
|
376
|
-
if self.
|
|
377
|
-
|
|
378
|
-
|
|
379
|
-
|
|
425
|
+
if isinstance(self.validation, float) or isinstance(self.validation, int):
|
|
426
|
+
if self.validation <= 0 or self.validation >= 1:
|
|
427
|
+
raise DatasetManagerError("validation must be a float between 0 and 1.", f"→ Received: {self.validation}", "→ Example: validation = 0.2 # for a 20% validation split")
|
|
428
|
+
|
|
429
|
+
train_map, validate_map = map[:int(math.floor(len(map)*(1-self.validation)))], map[int(math.floor(len(map)*(1-self.validation))):]
|
|
430
|
+
elif isinstance(self.validation, str):
|
|
431
|
+
if ":" in self.validation:
|
|
380
432
|
index = list(range(int(self.subset.split(":")[0]), int(self.subset.split(":")[1])))
|
|
381
433
|
train_map = [m for m in map if m[0] not in index]
|
|
382
434
|
validate_map = [m for m in map if m[0] in index]
|
|
383
|
-
elif os.path.exists(self.
|
|
435
|
+
elif os.path.exists(self.validation):
|
|
384
436
|
validation_names = []
|
|
385
|
-
with open(self.
|
|
437
|
+
with open(self.validation, "r") as f:
|
|
386
438
|
for name in f:
|
|
387
439
|
validation_names.append(name.strip())
|
|
388
440
|
index = [i for i, n in enumerate(subset_names) if n in validation_names]
|
|
389
441
|
train_map = [m for m in map if m[0] not in index]
|
|
390
442
|
validate_map = [m for m in map if m[0] in index]
|
|
391
443
|
else:
|
|
392
|
-
|
|
393
|
-
|
|
394
|
-
|
|
395
|
-
|
|
396
|
-
|
|
397
|
-
|
|
398
|
-
|
|
399
|
-
|
|
444
|
+
raise DatasetManagerError(
|
|
445
|
+
f"Invalid string value for 'validation': '{self.validation}'",
|
|
446
|
+
"Expected one of the following formats:",
|
|
447
|
+
"\t• A slice string like '0:10'",
|
|
448
|
+
"\t• A path to a text file listing validation sample names (e.g., './val.txt')",
|
|
449
|
+
"\t• A float between 0 and 1 (e.g., 0.2)",
|
|
450
|
+
"\t• A list of sample names or indices",
|
|
451
|
+
"The provided value is neither a valid slice nor a readable file.",
|
|
452
|
+
"Please fix your 'validation' setting in the configuration."
|
|
453
|
+
)
|
|
454
|
+
|
|
455
|
+
elif isinstance(self.validation, list):
|
|
456
|
+
if len(self.validation) > 0:
|
|
457
|
+
if isinstance(self.validation[0], int):
|
|
458
|
+
train_map = [m for m in map if m[0] not in self.validation]
|
|
459
|
+
validate_map = [m for m in map if m[0] in self.validation]
|
|
460
|
+
elif isinstance(self.validation[0], str):
|
|
461
|
+
index = [i for i, n in enumerate(subset_names) if n in self.validation]
|
|
400
462
|
train_map = [m for m in map if m[0] not in index]
|
|
401
463
|
validate_map = [m for m in map if m[0] in index]
|
|
402
|
-
|
|
464
|
+
else:
|
|
465
|
+
raise DatasetManagerError(f"Invalid list type for 'validation': elements of type '{type(self.validation[0]).__name__}' are not supported.",
|
|
466
|
+
"Supported list element types are:",
|
|
467
|
+
"\t• int → list of indices (e.g., [0, 1, 2])",
|
|
468
|
+
"\t• str → list of sample names (e.g., ['patient01', 'patient02'])",
|
|
469
|
+
f"Received list: {self.validation}"
|
|
470
|
+
)
|
|
471
|
+
if len(train_map) == 0:
|
|
472
|
+
raise DatasetManagerError("No data left for training after applying the validation split.",
|
|
473
|
+
f"Dataset size: {len(map)}",
|
|
474
|
+
f"Validation setting: {self.validation}",
|
|
475
|
+
"Please reduce the validation size, increase the dataset, or disable validation."
|
|
476
|
+
)
|
|
477
|
+
|
|
478
|
+
if self.validation is not None and len(validate_map) == 0:
|
|
479
|
+
raise DatasetManagerError("No data left for validation after applying the validation split.",
|
|
480
|
+
f"Dataset size: {len(map)}",
|
|
481
|
+
f"Validation setting: {self.validation}",
|
|
482
|
+
"Please increase the validation size, increase the dataset, or disable validation."
|
|
483
|
+
)
|
|
403
484
|
train_maps = Data._split(train_map, world_size)
|
|
404
485
|
validate_maps = Data._split(validate_map, world_size)
|
|
405
|
-
|
|
486
|
+
|
|
406
487
|
for i, (train_map, validate_map) in enumerate(zip(train_maps, validate_maps)):
|
|
407
488
|
maps = [train_map]
|
|
408
489
|
if len(validate_map):
|
|
@@ -436,8 +517,8 @@ class DataTrain(Data):
|
|
|
436
517
|
subset : Union[TrainSubset, dict[str, TrainSubset]] = TrainSubset(),
|
|
437
518
|
num_workers : int = 4,
|
|
438
519
|
batch_size : int = 1,
|
|
439
|
-
|
|
440
|
-
super().__init__(dataset_filenames, groups_src, patch, use_cache, subset, num_workers, batch_size,
|
|
520
|
+
validation : Union[float, str, list[int], list[str]] = 0.2) -> None:
|
|
521
|
+
super().__init__(dataset_filenames, groups_src, patch, use_cache, subset, num_workers, batch_size, validation, inlineAugmentations, augmentations if augmentations else {})
|
|
441
522
|
|
|
442
523
|
class DataPrediction(Data):
|
|
443
524
|
|
|
@@ -445,14 +526,13 @@ class DataPrediction(Data):
|
|
|
445
526
|
def __init__(self, dataset_filenames : list[str] = ["default:./Dataset"],
|
|
446
527
|
groups_src : dict[str, Group] = {"default" : Group()},
|
|
447
528
|
augmentations : Union[dict[str, DataAugmentationsList], None] = {"DataAugmentation_0" : DataAugmentationsList()},
|
|
448
|
-
inlineAugmentations: bool = False,
|
|
449
529
|
patch : Union[DatasetPatch, None] = DatasetPatch(),
|
|
450
530
|
use_cache : bool = True,
|
|
451
531
|
subset : Union[PredictionSubset, dict[str, PredictionSubset]] = PredictionSubset(),
|
|
452
532
|
num_workers : int = 4,
|
|
453
533
|
batch_size : int = 1) -> None:
|
|
454
534
|
|
|
455
|
-
super().__init__(dataset_filenames, groups_src, patch, use_cache, subset, num_workers, batch_size,
|
|
535
|
+
super().__init__(dataset_filenames, groups_src, patch, use_cache, subset, num_workers, batch_size, dataAugmentationsList=augmentations if augmentations else {})
|
|
456
536
|
|
|
457
537
|
class DataMetric(Data):
|
|
458
538
|
|
|
@@ -463,4 +543,4 @@ class DataMetric(Data):
|
|
|
463
543
|
validation: Union[str, None] = None,
|
|
464
544
|
num_workers : int = 4) -> None:
|
|
465
545
|
|
|
466
|
-
super().__init__(dataset_filenames=dataset_filenames, groups_src=groups_src, patch=None, use_cache=False, subset=subset, num_workers=num_workers, batch_size=1,
|
|
546
|
+
super().__init__(dataset_filenames=dataset_filenames, groups_src=groups_src, patch=None, use_cache=False, subset=subset, num_workers=num_workers, batch_size=1, validation=1 if validation is None else validation)
|
konfai/data/patching.py
CHANGED
|
@@ -150,7 +150,10 @@ class Patch(ABC):
|
|
|
150
150
|
|
|
151
151
|
def __init__(self, patch_size: list[int], overlap: Union[int, None], path_mask: Union[str, None] = None, padValue: float = 0, extend_slice: int = 0) -> None:
|
|
152
152
|
self.patch_size = patch_size
|
|
153
|
-
self.overlap = overlap
|
|
153
|
+
self.overlap = overlap
|
|
154
|
+
if isinstance(self.overlap, int):
|
|
155
|
+
if self.overlap < 0:
|
|
156
|
+
self.overlap = None
|
|
154
157
|
self._patch_slices : dict[int, list[tuple[slice]]] = {}
|
|
155
158
|
self._nb_patch_per_dim: dict[int, list[tuple[int, bool]]] = {}
|
|
156
159
|
self.path_mask = path_mask
|
|
@@ -188,7 +191,7 @@ class Patch(ABC):
|
|
|
188
191
|
pad_bottom = 0
|
|
189
192
|
pad_top = 0
|
|
190
193
|
if self._patch_slices[a][index][0].start-bottom < 0:
|
|
191
|
-
pad_bottom = bottom-
|
|
194
|
+
pad_bottom = bottom-self._patch_slices[a][index][0].start
|
|
192
195
|
if self._patch_slices[a][index][0].stop+top > data.shape[len(slices_pre)]:
|
|
193
196
|
pad_top = self._patch_slices[a][index][0].stop+top-data.shape[len(slices_pre)]
|
|
194
197
|
data_sliced = F.pad(data_sliced, [0 for _ in range((len(slices)-1)*2)]+[pad_bottom, pad_top], 'reflect')
|
konfai/data/transform.py
CHANGED
|
@@ -529,7 +529,6 @@ class OneHot(Transform):
|
|
|
529
529
|
|
|
530
530
|
def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
|
531
531
|
result = F.one_hot(input.type(torch.int64), num_classes=self.num_classes).permute(0, len(input.shape), *[i+1 for i in range(len(input.shape)-1)]).float().squeeze(2)
|
|
532
|
-
print(result.shape)
|
|
533
532
|
return result
|
|
534
533
|
|
|
535
534
|
def inverse(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
|
konfai/main.py
CHANGED
|
@@ -18,9 +18,8 @@ def main():
|
|
|
18
18
|
distributedObject.setup(world_size)
|
|
19
19
|
with TensorBoard(distributedObject.name):
|
|
20
20
|
mp.spawn(distributedObject, nprocs=world_size)
|
|
21
|
-
except
|
|
22
|
-
print(
|
|
23
|
-
exit(1)
|
|
21
|
+
except KeyboardInterrupt:
|
|
22
|
+
print("\n[KonfAI] Manual interruption (Ctrl+C)")
|
|
24
23
|
|
|
25
24
|
|
|
26
25
|
def cluster():
|
|
@@ -47,6 +46,8 @@ def cluster():
|
|
|
47
46
|
executor.update_parameters(name=config["name"], mem_gb=config["memory"], gpus_per_node=n_gpu, tasks_per_node=n_gpu//distributedObject.size, cpus_per_task=config["num_workers"], nodes=config["num_nodes"], timeout_min=config["time_limit"])
|
|
48
47
|
with TensorBoard(distributedObject.name):
|
|
49
48
|
executor.submit(distributedObject)
|
|
49
|
+
except KeyboardInterrupt:
|
|
50
|
+
print("\n[KonfAI] Manual interruption (Ctrl+C)")
|
|
50
51
|
except Exception as e:
|
|
51
52
|
print(e)
|
|
52
53
|
exit(1)
|
konfai/metric/measure.py
CHANGED
|
@@ -147,7 +147,7 @@ class Dice(Criterion):
|
|
|
147
147
|
def forward(self, output: torch.Tensor, *targets : list[torch.Tensor]) -> torch.Tensor:
|
|
148
148
|
target = targets[0]
|
|
149
149
|
if output.shape[1] == 1:
|
|
150
|
-
output = F.one_hot(output.type(torch.int64), num_classes=torch.max(output).item()+1).permute(0, len(target.shape), *[i+1 for i in range(len(target.shape)-1)]).float()
|
|
150
|
+
output = F.one_hot(output.type(torch.int64), num_classes=int(torch.max(output).item()+1)).permute(0, len(target.shape), *[i+1 for i in range(len(target.shape)-1)]).float()
|
|
151
151
|
target = F.one_hot(target.type(torch.int64), num_classes=output.shape[1]).permute(0, len(target.shape), *[i+1 for i in range(len(target.shape)-1)]).float().squeeze(2)
|
|
152
152
|
return 1-torch.mean(self.dice_per_channel(output, target))
|
|
153
153
|
|
|
@@ -94,7 +94,6 @@ class SpatialTransformer(torch.nn.Module):
|
|
|
94
94
|
new_locs[:, 1,1] = 1
|
|
95
95
|
new_locs[:, 0,2] = flow[:, 0]
|
|
96
96
|
new_locs[:, 1,2] = flow[:, 1]
|
|
97
|
-
print(new_locs)
|
|
98
97
|
return F.grid_sample(src, F.affine_grid(new_locs, src.size()), align_corners=True, mode="bilinear")
|
|
99
98
|
else:
|
|
100
99
|
new_locs = self.grid + flow
|
konfai/network/blocks.py
CHANGED
konfai/network/network.py
CHANGED
|
@@ -16,7 +16,7 @@ from enum import Enum
|
|
|
16
16
|
from konfai import KONFAI_ROOT
|
|
17
17
|
from konfai.metric.schedulers import Scheduler
|
|
18
18
|
from konfai.utils.config import config
|
|
19
|
-
from konfai.utils.utils import State, _getModule, getDevice, getGPUMemory
|
|
19
|
+
from konfai.utils.utils import State, _getModule, getDevice, getGPUMemory, MeasureError
|
|
20
20
|
from konfai.data.patching import Accumulator, ModelPatch
|
|
21
21
|
|
|
22
22
|
class NetState(Enum):
|
|
@@ -54,12 +54,12 @@ class LRSchedulersLoader():
|
|
|
54
54
|
def __init__(self, params: dict[str, SchedulerStep] = {"default:ReduceLROnPlateau" : SchedulerStep(0)}) -> None:
|
|
55
55
|
self.params = params
|
|
56
56
|
|
|
57
|
-
def
|
|
58
|
-
|
|
57
|
+
def getschedulers(self, key: str, optimizer: torch.optim.Optimizer) -> dict[torch.optim.lr_scheduler._LRScheduler, int]:
|
|
58
|
+
schedulers : dict[torch.optim.lr_scheduler._LRScheduler, int] = {}
|
|
59
59
|
for name, step in self.params.items():
|
|
60
60
|
if name:
|
|
61
|
-
|
|
62
|
-
return
|
|
61
|
+
schedulers[config("Trainer.Model.{}.Schedulers.{}".format(key, name))(getattr(importlib.import_module('torch.optim.lr_scheduler'), name))(optimizer, config = None)] = step.nb_step
|
|
62
|
+
return schedulers
|
|
63
63
|
|
|
64
64
|
class SchedulersLoader():
|
|
65
65
|
|
|
@@ -67,12 +67,12 @@ class SchedulersLoader():
|
|
|
67
67
|
def __init__(self, params: dict[str, SchedulerStep] = {"default:Constant" : SchedulerStep(0)}) -> None:
|
|
68
68
|
self.params = params
|
|
69
69
|
|
|
70
|
-
def
|
|
71
|
-
|
|
70
|
+
def getschedulers(self, key: str) -> dict[torch.optim.lr_scheduler._LRScheduler, int]:
|
|
71
|
+
schedulers : dict[Scheduler, int] = {}
|
|
72
72
|
for name, step in self.params.items():
|
|
73
73
|
if name:
|
|
74
|
-
|
|
75
|
-
return
|
|
74
|
+
schedulers[getattr(importlib.import_module("konfai.metric.schedulers"), name)(config = None, DL_args = key)] = step.nb_step
|
|
75
|
+
return schedulers
|
|
76
76
|
|
|
77
77
|
class CriterionsAttr():
|
|
78
78
|
|
|
@@ -98,7 +98,7 @@ class CriterionsLoader():
|
|
|
98
98
|
for module_classpath, criterionsAttr in self.criterionsLoader.items():
|
|
99
99
|
module, name = _getModule(module_classpath, "metric.measure")
|
|
100
100
|
criterionsAttr.isTorchCriterion = module.startswith("torch")
|
|
101
|
-
criterionsAttr.sheduler = criterionsAttr.l.
|
|
101
|
+
criterionsAttr.sheduler = criterionsAttr.l.getschedulers("{}.Model.{}.outputsCriterions.{}.targetsCriterions.{}.criterionsLoader.{}".format(KONFAI_ROOT(), model_classname, output_group, target_group, module_classpath))
|
|
102
102
|
criterions[config("{}.Model.{}.outputsCriterions.{}.targetsCriterions.{}.criterionsLoader.{}".format(KONFAI_ROOT(), model_classname, output_group, target_group, module_classpath))(getattr(importlib.import_module(module), name))(config = None)] = criterionsAttr
|
|
103
103
|
return criterions
|
|
104
104
|
|
|
@@ -154,10 +154,25 @@ class Measure():
|
|
|
154
154
|
self.outputsCriterions[output_group.replace(":", ".")] = targetCriterionsLoader.getTargetsCriterions(output_group, model_classname)
|
|
155
155
|
self._loss : dict[int, dict[str, Measure.Loss]] = {}
|
|
156
156
|
|
|
157
|
-
def init(self, model : torch.nn.Module) -> None:
|
|
157
|
+
def init(self, model : torch.nn.Module, group_dest: list[str]) -> None:
|
|
158
158
|
outputs_group_rename = {}
|
|
159
|
+
|
|
160
|
+
modules = []
|
|
161
|
+
for i,_ in model.named_modules():
|
|
162
|
+
modules.append(i)
|
|
163
|
+
|
|
159
164
|
for output_group in self.outputsCriterions.keys():
|
|
165
|
+
if output_group not in modules:
|
|
166
|
+
raise MeasureError(f"The output group '{output_group}' defined in 'outputsCriterions' does not correspond to any module in the model.",
|
|
167
|
+
f"Available modules: {modules}",
|
|
168
|
+
"Please check that the name matches exactly a submodule or output of your model architecture."
|
|
169
|
+
)
|
|
160
170
|
for target_group in self.outputsCriterions[output_group]:
|
|
171
|
+
if target_group not in group_dest:
|
|
172
|
+
raise MeasureError(
|
|
173
|
+
f"The target_group '{target_group}' defined in 'outputsCriterions.{output_group}.targetsCriterions' was not found in the available destination groups.",
|
|
174
|
+
"This target_group is expected for loss or metric computation, but was not loaded in 'group_dest'.",
|
|
175
|
+
f"Please make sure that the group '{target_group}' is defined in 'Dataset:groups_src:...:groups_dest:'{target_group}'' and correctly loaded from the dataset.")
|
|
161
176
|
for criterion in self.outputsCriterions[output_group][target_group]:
|
|
162
177
|
if not self.outputsCriterions[output_group][target_group][criterion].isTorchCriterion:
|
|
163
178
|
outputs_group_rename[output_group] = criterion.init(model, output_group, target_group)
|
|
@@ -703,10 +718,10 @@ class Network(ModuleArgsDict, ABC):
|
|
|
703
718
|
return in_channels, in_is_channel, out_channels, out_is_channel
|
|
704
719
|
|
|
705
720
|
@_function_network()
|
|
706
|
-
def init(self, autocast : bool, state : State, key: str) -> None:
|
|
721
|
+
def init(self, autocast : bool, state : State, group_dest: list[str], key: str) -> None:
|
|
707
722
|
if self.outputsCriterionsLoader:
|
|
708
723
|
self.measure = Measure(key, self.outputsCriterionsLoader)
|
|
709
|
-
self.measure.init(self)
|
|
724
|
+
self.measure.init(self, group_dest)
|
|
710
725
|
if state != State.PREDICTION:
|
|
711
726
|
self.scaler = torch.amp.GradScaler("cuda", enabled=autocast)
|
|
712
727
|
if self.optimizerLoader:
|
|
@@ -714,7 +729,7 @@ class Network(ModuleArgsDict, ABC):
|
|
|
714
729
|
self.optimizer.zero_grad()
|
|
715
730
|
|
|
716
731
|
if self.LRSchedulersLoader and self.optimizer:
|
|
717
|
-
self.schedulers = self.LRSchedulersLoader.
|
|
732
|
+
self.schedulers = self.LRSchedulersLoader.getschedulers(key, self.optimizer)
|
|
718
733
|
|
|
719
734
|
def initialized(self):
|
|
720
735
|
pass
|
|
@@ -880,7 +895,6 @@ class Network(ModuleArgsDict, ABC):
|
|
|
880
895
|
if scheduler:
|
|
881
896
|
if scheduler.__class__.__name__ == 'ReduceLROnPlateau':
|
|
882
897
|
if self.measure:
|
|
883
|
-
print(sum(self.measure.getLastValues(0).values()))
|
|
884
898
|
scheduler.step(sum(self.measure.getLastValues(0).values()))
|
|
885
899
|
else:
|
|
886
900
|
scheduler.step()
|
konfai/trainer.py
CHANGED
|
@@ -1,5 +1,4 @@
|
|
|
1
1
|
import torch
|
|
2
|
-
import torch.optim.adamw
|
|
3
2
|
from torch.utils.data import DataLoader
|
|
4
3
|
import tqdm
|
|
5
4
|
import numpy as np
|
|
@@ -15,13 +14,76 @@ import torch.distributed as dist
|
|
|
15
14
|
from konfai import MODELS_DIRECTORY, CHECKPOINTS_DIRECTORY, STATISTICS_DIRECTORY, SETUPS_DIRECTORY, CONFIG_FILE, MODEL, DATE, KONFAI_STATE
|
|
16
15
|
from konfai.data.data_manager import DataTrain
|
|
17
16
|
from konfai.utils.config import config
|
|
18
|
-
from konfai.utils.utils import State, DataLog, DistributedObject, description
|
|
17
|
+
from konfai.utils.utils import State, DataLog, DistributedObject, description, TrainerError
|
|
19
18
|
from konfai.network.network import Network, ModelLoader, NetState, CPU_Model
|
|
20
19
|
|
|
21
20
|
|
|
21
|
+
class EarlyStoppingBase:
|
|
22
|
+
|
|
23
|
+
def __init__(self):
|
|
24
|
+
pass
|
|
25
|
+
|
|
26
|
+
def isStopped(self) -> bool:
|
|
27
|
+
return False
|
|
28
|
+
|
|
29
|
+
def getScore(self, values: dict[str, float]):
|
|
30
|
+
return sum([i for i in values.values()])
|
|
31
|
+
|
|
32
|
+
def __call__(self, current_score: float) -> bool:
|
|
33
|
+
return False
|
|
34
|
+
|
|
35
|
+
class EarlyStopping(EarlyStoppingBase):
|
|
36
|
+
|
|
37
|
+
@config("EarlyStopping")
|
|
38
|
+
def __init__(self, monitor: Union[list[str], None] = [], patience=10, min_delta=0.0, mode="min"):
|
|
39
|
+
super().__init__()
|
|
40
|
+
self.monitor = [] if monitor is None else monitor
|
|
41
|
+
self.patience = patience
|
|
42
|
+
self.min_delta = min_delta
|
|
43
|
+
self.mode = mode
|
|
44
|
+
self.counter = 0
|
|
45
|
+
self.best_score = None
|
|
46
|
+
self.early_stop = False
|
|
47
|
+
|
|
48
|
+
def isStopped(self) -> bool:
|
|
49
|
+
return self.early_stop
|
|
50
|
+
|
|
51
|
+
def getScore(self, values: dict[str, float]):
|
|
52
|
+
if len(self.monitor) == 0:
|
|
53
|
+
return super().getScore(values)
|
|
54
|
+
for v in self.monitor:
|
|
55
|
+
if v not in values.keys():
|
|
56
|
+
raise TrainerError(
|
|
57
|
+
"Metric '{}' specified in EarlyStopping.monitor not found in logged values. ",
|
|
58
|
+
"Available keys: {}. Please check your configuration.".format(v, list(values.keys())))
|
|
59
|
+
return sum([i for v, i in values.items() if v in self.monitor])
|
|
60
|
+
|
|
61
|
+
def __call__(self, current_score: float) -> bool:
|
|
62
|
+
if self.best_score is None:
|
|
63
|
+
self.best_score = current_score
|
|
64
|
+
return False
|
|
65
|
+
|
|
66
|
+
if self.mode == "min":
|
|
67
|
+
improvement = self.best_score - current_score
|
|
68
|
+
elif self.mode == "max":
|
|
69
|
+
improvement = current_score - self.best_score
|
|
70
|
+
else:
|
|
71
|
+
raise TrainerError("Mode must be 'min' or 'max'.")
|
|
72
|
+
|
|
73
|
+
if improvement > self.min_delta:
|
|
74
|
+
self.best_score = current_score
|
|
75
|
+
self.counter = 0
|
|
76
|
+
else:
|
|
77
|
+
self.counter += 1
|
|
78
|
+
|
|
79
|
+
if self.counter >= self.patience:
|
|
80
|
+
self.early_stop = True
|
|
81
|
+
|
|
82
|
+
return self.early_stop
|
|
83
|
+
|
|
22
84
|
class _Trainer():
|
|
23
85
|
|
|
24
|
-
def __init__(self, world_size: int, global_rank: int, local_rank: int, size: int, train_name: str, data_log: Union[list[str], None] , save_checkpoint_mode: str, epochs: int, epoch: int, autocast: bool, it_validation: Union[int, None], it: int, model: Union[DDP, CPU_Model], modelEMA: AveragedModel, dataloader_training: DataLoader, dataloader_validation: Union[DataLoader, None] = None) -> None:
|
|
86
|
+
def __init__(self, world_size: int, global_rank: int, local_rank: int, size: int, train_name: str, early_stopping: EarlyStopping, data_log: Union[list[str], None] , save_checkpoint_mode: str, epochs: int, epoch: int, autocast: bool, it_validation: Union[int, None], it: int, model: Union[DDP, CPU_Model], modelEMA: AveragedModel, dataloader_training: DataLoader, dataloader_validation: Union[DataLoader, None] = None) -> None:
|
|
25
87
|
self.world_size = world_size
|
|
26
88
|
self.global_rank = global_rank
|
|
27
89
|
self.local_rank = local_rank
|
|
@@ -36,7 +98,8 @@ class _Trainer():
|
|
|
36
98
|
self.dataloader_validation = dataloader_validation
|
|
37
99
|
self.autocast = autocast
|
|
38
100
|
self.modelEMA = modelEMA
|
|
39
|
-
|
|
101
|
+
self.early_stopping = EarlyStoppingBase() if early_stopping is None else early_stopping
|
|
102
|
+
|
|
40
103
|
self.it_validation = it_validation
|
|
41
104
|
if self.it_validation is None:
|
|
42
105
|
self.it_validation = len(dataloader_training)
|
|
@@ -59,6 +122,8 @@ class _Trainer():
|
|
|
59
122
|
for self.epoch in epoch_tqdm:
|
|
60
123
|
self.dataloader_training.dataset.load()
|
|
61
124
|
self.train()
|
|
125
|
+
if self.early_stopping.isStopped():
|
|
126
|
+
break
|
|
62
127
|
self.dataloader_training.dataset.resetAugmentation()
|
|
63
128
|
|
|
64
129
|
def getInput(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int, str, bool]]) -> dict[tuple[str, bool], torch.Tensor]:
|
|
@@ -88,7 +153,11 @@ class _Trainer():
|
|
|
88
153
|
if self.dataloader_validation is not None:
|
|
89
154
|
loss = self._validate()
|
|
90
155
|
self.model.module.update_lr()
|
|
91
|
-
self.
|
|
156
|
+
score = self.early_stopping.getScore(loss)
|
|
157
|
+
self.checkpoint_save(score)
|
|
158
|
+
if self.early_stopping(score):
|
|
159
|
+
break
|
|
160
|
+
|
|
92
161
|
|
|
93
162
|
batch_iter.set_description(desc())
|
|
94
163
|
|
|
@@ -120,41 +189,53 @@ class _Trainer():
|
|
|
120
189
|
return self._validation_log(data_dict)
|
|
121
190
|
|
|
122
191
|
def checkpoint_save(self, loss: float) -> None:
|
|
123
|
-
if self.global_rank
|
|
124
|
-
|
|
125
|
-
last_loss = None
|
|
126
|
-
if os.path.exists(path) and os.listdir(path):
|
|
127
|
-
name = sorted(os.listdir(path))[-1]
|
|
128
|
-
state_dict = torch.load(path+name, weights_only=False)
|
|
129
|
-
last_loss = state_dict["loss"]
|
|
130
|
-
if self.save_checkpoint_mode == "BEST":
|
|
131
|
-
if last_loss >= loss:
|
|
132
|
-
os.remove(path+name)
|
|
133
|
-
|
|
134
|
-
if self.save_checkpoint_mode != "BEST" or (last_loss is None or last_loss >= loss):
|
|
135
|
-
name = DATE()+".pt"
|
|
136
|
-
if not os.path.exists(path):
|
|
137
|
-
os.makedirs(path)
|
|
138
|
-
|
|
139
|
-
save_dict = {
|
|
140
|
-
"epoch": self.epoch,
|
|
141
|
-
"it": self.it,
|
|
142
|
-
"loss": loss,
|
|
143
|
-
"Model": self.model.module.state_dict()}
|
|
192
|
+
if self.global_rank != 0:
|
|
193
|
+
return
|
|
144
194
|
|
|
145
|
-
|
|
146
|
-
|
|
195
|
+
path = CHECKPOINTS_DIRECTORY()+self.train_name+"/"
|
|
196
|
+
os.makedirs(path, exist_ok=True)
|
|
197
|
+
|
|
198
|
+
name = DATE() + ".pt"
|
|
199
|
+
save_path = os.path.join(path, name)
|
|
147
200
|
|
|
148
|
-
|
|
149
|
-
|
|
201
|
+
save_dict = {
|
|
202
|
+
"epoch": self.epoch,
|
|
203
|
+
"it": self.it,
|
|
204
|
+
"loss": loss,
|
|
205
|
+
"Model": self.model.module.state_dict()
|
|
206
|
+
}
|
|
207
|
+
|
|
208
|
+
if self.modelEMA is not None:
|
|
209
|
+
save_dict["Model_EMA"] = self.modelEMA.module.state_dict()
|
|
210
|
+
|
|
211
|
+
save_dict.update({'{}_optimizer_state_dict'.format(name): network.optimizer.state_dict() for name, network in self.model.module.getNetworks().items() if network.optimizer is not None})
|
|
212
|
+
torch.save(save_dict, save_path)
|
|
213
|
+
|
|
214
|
+
if self.save_checkpoint_mode == "BEST":
|
|
215
|
+
all_checkpoints = sorted([
|
|
216
|
+
os.path.join(path, f)
|
|
217
|
+
for f in os.listdir(path) if f.endswith(".pt")
|
|
218
|
+
])
|
|
219
|
+
best_ckpt = None
|
|
220
|
+
best_loss = float('inf')
|
|
221
|
+
|
|
222
|
+
for f in all_checkpoints:
|
|
223
|
+
d = torch.load(f, weights_only=False)
|
|
224
|
+
if d.get("loss", float("inf")) < best_loss:
|
|
225
|
+
best_loss = d["loss"]
|
|
226
|
+
best_ckpt = f
|
|
227
|
+
|
|
228
|
+
for f in all_checkpoints:
|
|
229
|
+
if f != best_ckpt and f != save_path:
|
|
230
|
+
os.remove(f)
|
|
150
231
|
|
|
151
232
|
@torch.no_grad()
|
|
152
|
-
def _log(self, type_log: str, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]) -> float:
|
|
233
|
+
def _log(self, type_log: str, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]) -> dict[str, float]:
|
|
153
234
|
models: dict[str, Network] = {"" : self.model.module}
|
|
154
235
|
if self.modelEMA is not None:
|
|
155
236
|
models["_EMA"] = self.modelEMA.module
|
|
156
237
|
|
|
157
|
-
measures = DistributedObject.getMeasure(self.world_size, self.global_rank, self.local_rank*self.size+self.size-1, models, self.it_validation if type_log == "
|
|
238
|
+
measures = DistributedObject.getMeasure(self.world_size, self.global_rank, self.local_rank*self.size+self.size-1, models, self.it_validation if type_log == "Training" else len(self.dataloader_validation))
|
|
158
239
|
|
|
159
240
|
if self.global_rank == 0:
|
|
160
241
|
images_log = []
|
|
@@ -178,27 +259,29 @@ class _Trainer():
|
|
|
178
259
|
for name, layer, _ in model.get_layers([v.to(0) for k, v in self.getInput(data_dict).items() if k[1]], images_log):
|
|
179
260
|
self.data_log[name][0](self.tb, "{}/{}{}".format(type_log, name, label), layer[:self.data_log[name][1]].detach().cpu().numpy(), self.it)
|
|
180
261
|
|
|
181
|
-
if type_log == "
|
|
262
|
+
if type_log == "Training":
|
|
182
263
|
for name, network in self.model.module.getNetworks().items():
|
|
183
264
|
if network.optimizer is not None:
|
|
184
265
|
self.tb.add_scalar("{}/{}/Learning Rate".format(type_log, name), network.optimizer.param_groups[0]['lr'], self.it)
|
|
185
266
|
|
|
186
267
|
if self.global_rank == 0:
|
|
187
|
-
loss =
|
|
268
|
+
loss = {}
|
|
188
269
|
for name, network in self.model.module.getNetworks().items():
|
|
189
270
|
if network.measure is not None:
|
|
190
|
-
loss.
|
|
191
|
-
|
|
271
|
+
loss.update({k : v[1] for k, v in measures["{}{}".format(name, label)][0].items()})
|
|
272
|
+
loss.update({k : v[1] for k, v in measures["{}{}".format(name, label)][1].items()})
|
|
273
|
+
return loss
|
|
192
274
|
return None
|
|
193
275
|
|
|
194
276
|
@torch.no_grad()
|
|
195
|
-
def _train_log(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]) -> float:
|
|
196
|
-
return self._log("
|
|
277
|
+
def _train_log(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]) -> dict[str, float]:
|
|
278
|
+
return self._log("Training", data_dict)
|
|
197
279
|
|
|
198
280
|
@torch.no_grad()
|
|
199
|
-
def _validation_log(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]) -> float:
|
|
281
|
+
def _validation_log(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]) -> dict[str, float]:
|
|
200
282
|
return self._log("Validation", data_dict)
|
|
201
283
|
|
|
284
|
+
|
|
202
285
|
class Trainer(DistributedObject):
|
|
203
286
|
|
|
204
287
|
@config("Trainer")
|
|
@@ -214,6 +297,7 @@ class Trainer(DistributedObject):
|
|
|
214
297
|
gpu_checkpoints: Union[list[str], None] = None,
|
|
215
298
|
ema_decay : float = 0,
|
|
216
299
|
data_log: Union[list[str], None] = None,
|
|
300
|
+
early_stopping: Union[EarlyStopping, None] = None,
|
|
217
301
|
save_checkpoint_mode: str= "BEST") -> None:
|
|
218
302
|
if os.environ["KONFAI_CONFIG_MODE"] != "Done":
|
|
219
303
|
exit(0)
|
|
@@ -223,6 +307,7 @@ class Trainer(DistributedObject):
|
|
|
223
307
|
self.autocast = autocast
|
|
224
308
|
self.epochs = epochs
|
|
225
309
|
self.epoch = 0
|
|
310
|
+
self.early_stopping = early_stopping
|
|
226
311
|
self.it = 0
|
|
227
312
|
self.it_validation = it_validation
|
|
228
313
|
self.model = model.getModel(train=True)
|
|
@@ -306,7 +391,7 @@ class Trainer(DistributedObject):
|
|
|
306
391
|
if state != State.TRAIN:
|
|
307
392
|
state_dict = self._load()
|
|
308
393
|
|
|
309
|
-
self.model.init(self.autocast, state)
|
|
394
|
+
self.model.init(self.autocast, state, self.dataset.getGroupsDest())
|
|
310
395
|
self.model.init_outputsGroup()
|
|
311
396
|
self.model._compute_channels_trace(self.model, self.model.in_channels, self.gradient_checkpoints, self.gpu_checkpoints)
|
|
312
397
|
self.model.load(state_dict, init=True, ema=False)
|
|
@@ -326,5 +411,5 @@ class Trainer(DistributedObject):
|
|
|
326
411
|
model = DDP(model, static_graph=True) if torch.cuda.is_available() else CPU_Model(model)
|
|
327
412
|
if self.modelEMA is not None:
|
|
328
413
|
self.modelEMA.module = Network.to(self.modelEMA.module, local_rank)
|
|
329
|
-
with _Trainer(world_size, global_rank, local_rank, self.size, self.name, self.data_log, self.save_checkpoint_mode, self.epochs, self.epoch, self.autocast, self.it_validation, self.it, model, self.modelEMA, *dataloaders) as t:
|
|
414
|
+
with _Trainer(world_size, global_rank, local_rank, self.size, self.name, self.early_stopping, self.data_log, self.save_checkpoint_mode, self.epochs, self.epoch, self.autocast, self.it_validation, self.it, model, self.modelEMA, *dataloaders) as t:
|
|
330
415
|
t.run()
|
konfai/utils/config.py
CHANGED
|
@@ -6,14 +6,10 @@ from copy import deepcopy
|
|
|
6
6
|
from typing import Union, Literal, get_origin, get_args
|
|
7
7
|
import torch
|
|
8
8
|
from konfai import CONFIG_FILE
|
|
9
|
+
from konfai.utils.utils import ConfigError
|
|
9
10
|
|
|
10
11
|
yaml = ruamel.yaml.YAML()
|
|
11
12
|
|
|
12
|
-
class ConfigError(Exception):
|
|
13
|
-
|
|
14
|
-
def __init__(self, message : str = "The config only supports types : config(Object), int, str, bool, float, list[int], list[str], list[bool], list[float], dict[str, Object]") -> None:
|
|
15
|
-
self.message = message
|
|
16
|
-
super().__init__(self.message)
|
|
17
13
|
|
|
18
14
|
class Config():
|
|
19
15
|
|
|
@@ -187,9 +183,8 @@ def config(key : Union[str, None] = None):
|
|
|
187
183
|
default_value = param.default if param.default != inspect._empty else allowed_values[0]
|
|
188
184
|
value = config.getValue(param.name, f"default:{default_value}")
|
|
189
185
|
if value not in allowed_values:
|
|
190
|
-
raise
|
|
191
|
-
f"
|
|
192
|
-
f"Expected one of: {allowed_values}."
|
|
186
|
+
raise ConfigError(
|
|
187
|
+
f"Invalid value '{value}' for parameter '{param.name} expected one of: {allowed_values}."
|
|
193
188
|
)
|
|
194
189
|
kwargs[param.name] = value
|
|
195
190
|
continue
|
|
@@ -222,21 +217,26 @@ def config(key : Union[str, None] = None):
|
|
|
222
217
|
values = config.getValue(param.name, param.default)
|
|
223
218
|
kwargs[param.name] = values
|
|
224
219
|
else:
|
|
225
|
-
raise ConfigError()
|
|
220
|
+
raise ConfigError("Config: The config only supports types : config(Object), int, str, bool, float, list[int], list[str], list[bool], list[float], dict[str, Object]")
|
|
226
221
|
elif str(annotation).startswith("dict"):
|
|
227
222
|
if annotation.__args__[0] == str:
|
|
228
223
|
values = config.getValue(param.name, param.default)
|
|
229
224
|
if values is not None and annotation.__args__[1] not in [int, str, bool, float]:
|
|
230
|
-
|
|
225
|
+
try:
|
|
226
|
+
kwargs[param.name] = {value : annotation.__args__[1](config = filename, DL_args = key_tmp+"."+param.name+"."+value) for value in values}
|
|
227
|
+
except ValueError as e:
|
|
228
|
+
raise ValueError(e)
|
|
229
|
+
except Exception as e:
|
|
230
|
+
raise ConfigError("{} {}".format(values, e))
|
|
231
231
|
else:
|
|
232
232
|
kwargs[param.name] = values
|
|
233
233
|
else:
|
|
234
|
-
raise ConfigError()
|
|
234
|
+
raise ConfigError("Config: The config only supports types : config(Object), int, str, bool, float, list[int], list[str], list[bool], list[float], dict[str, Object]")
|
|
235
235
|
else:
|
|
236
236
|
try:
|
|
237
237
|
kwargs[param.name] = annotation(config = filename, DL_args = key_tmp)
|
|
238
238
|
except Exception as e:
|
|
239
|
-
raise
|
|
239
|
+
raise ConfigError("Failed to instantiate {} with type {}, error {} ".format(param.name, annotation.__name__, e))
|
|
240
240
|
|
|
241
241
|
if os.environ['KONFAI_CONFIG_VARIABLE'] == "True":
|
|
242
242
|
os.environ['KONFAI_CONFIG_VARIABLE'] = "False"
|
konfai/utils/utils.py
CHANGED
|
@@ -145,8 +145,12 @@ def get_patch_slices_from_nb_patch_per_dim(patch_size_tmp: list[int], nb_patch_p
|
|
|
145
145
|
|
|
146
146
|
def get_patch_slices_from_shape(patch_size: list[int], shape : list[int], overlap: Union[int, None]) -> tuple[list[tuple[slice]], list[tuple[int, bool]]]:
|
|
147
147
|
if len(shape) != len(patch_size):
|
|
148
|
-
|
|
149
|
-
|
|
148
|
+
raise DatasetManagerError(
|
|
149
|
+
f"Dimension mismatch: 'patch_size' has {len(patch_size)} dimensions, but 'shape' has {len(shape)}.",
|
|
150
|
+
f"patch_size: {patch_size}",
|
|
151
|
+
f"shape: {shape}",
|
|
152
|
+
"Both must have the same number of dimensions (e.g., 3D patch for 3D volume)."
|
|
153
|
+
)
|
|
150
154
|
patch_slices = []
|
|
151
155
|
nb_patch_per_dim = []
|
|
152
156
|
slices : list[list[slice]] = []
|
|
@@ -192,12 +196,8 @@ def _logImageFormat(input : np.ndarray):
|
|
|
192
196
|
|
|
193
197
|
if len(input.shape) == 3 and input.shape[0] != 1:
|
|
194
198
|
input = np.expand_dims(input, axis=0)
|
|
195
|
-
|
|
196
199
|
if len(input.shape) == 4:
|
|
197
200
|
input = input[:, input.shape[1]//2]
|
|
198
|
-
|
|
199
|
-
if input.dtype == np.uint8:
|
|
200
|
-
return input
|
|
201
201
|
|
|
202
202
|
input = input.astype(float)
|
|
203
203
|
b = -np.min(input)
|
|
@@ -325,7 +325,7 @@ class DistributedObject():
|
|
|
325
325
|
return self
|
|
326
326
|
|
|
327
327
|
def __exit__(self, type, value, traceback):
|
|
328
|
-
|
|
328
|
+
cleanup()
|
|
329
329
|
|
|
330
330
|
@abstractmethod
|
|
331
331
|
def run_process(self, world_size: int, global_rank: int, local_rank: int, dataloaders: list[DataLoader]):
|
|
@@ -370,11 +370,12 @@ class DistributedObject():
|
|
|
370
370
|
dataloaders = self.dataloader[global_rank]
|
|
371
371
|
if torch.cuda.is_available():
|
|
372
372
|
torch.cuda.set_device(local_rank)
|
|
373
|
-
|
|
374
|
-
|
|
375
|
-
|
|
376
|
-
|
|
377
|
-
|
|
373
|
+
try:
|
|
374
|
+
self.run_process(world_size, global_rank, local_rank, dataloaders)
|
|
375
|
+
finally:
|
|
376
|
+
cleanup()
|
|
377
|
+
if torch.cuda.is_available():
|
|
378
|
+
pynvml.nvmlShutdown()
|
|
378
379
|
|
|
379
380
|
def setup(parser: argparse.ArgumentParser) -> DistributedObject:
|
|
380
381
|
# KONFAI arguments
|
|
@@ -460,7 +461,7 @@ def setupGPU(world_size: int, port: int, rank: Union[int, None] = None) -> tuple
|
|
|
460
461
|
local_rank = rank
|
|
461
462
|
if global_rank >= world_size:
|
|
462
463
|
return None, None
|
|
463
|
-
print("tcp://{}:{}".format(host_name, port))
|
|
464
|
+
#print("tcp://{}:{}".format(host_name, port))
|
|
464
465
|
if torch.cuda.is_available():
|
|
465
466
|
torch.cuda.empty_cache()
|
|
466
467
|
dist.init_process_group("nccl", rank=global_rank, init_method="tcp://{}:{}".format(host_name, port), world_size=world_size)
|
|
@@ -476,7 +477,7 @@ def find_free_port():
|
|
|
476
477
|
return s.getsockname()[1]
|
|
477
478
|
|
|
478
479
|
def cleanup():
|
|
479
|
-
if
|
|
480
|
+
if dist.is_initialized():
|
|
480
481
|
dist.destroy_process_group()
|
|
481
482
|
|
|
482
483
|
def synchronize_data(world_size: int, gpu: int, data: any) -> list[Any]:
|
|
@@ -506,3 +507,49 @@ def _resample_affine(data: torch.Tensor, matrix: torch.Tensor):
|
|
|
506
507
|
else:
|
|
507
508
|
mode = "bilinear"
|
|
508
509
|
return F.grid_sample(data.unsqueeze(0).type(torch.float32), F.affine_grid(matrix[:, :-1,...].type(torch.float32), [1]+list(data.shape), align_corners=True), align_corners=True, mode=mode, padding_mode="reflection").squeeze(0).type(data.dtype)
|
|
510
|
+
|
|
511
|
+
|
|
512
|
+
SUPPORTED_EXTENSIONS = [
|
|
513
|
+
"mha", "mhd", # MetaImage
|
|
514
|
+
"nii", "nii.gz", # NIfTI
|
|
515
|
+
"nrrd", "nrrd.gz", # NRRD
|
|
516
|
+
"gipl", "gipl.gz", # GIPL
|
|
517
|
+
"hdr", "img", # Analyze
|
|
518
|
+
"dcm", # DICOM (si GDCM activé)
|
|
519
|
+
"tif", "tiff", # TIFF
|
|
520
|
+
"png", "jpg", "jpeg", "bmp", # 2D formats
|
|
521
|
+
"h5", "itk.txt", ".fcsv", ".xml", ".vtk", ".npy"
|
|
522
|
+
|
|
523
|
+
]
|
|
524
|
+
|
|
525
|
+
class KonfAIError(Exception):
|
|
526
|
+
|
|
527
|
+
def __init__(self, typeError: str, message: list[str]) -> None:
|
|
528
|
+
super().__init__("\n[{}] {}".format(typeError, message[0])+"\n→\t".join(message[1:]))
|
|
529
|
+
|
|
530
|
+
|
|
531
|
+
class ConfigError(KonfAIError):
|
|
532
|
+
|
|
533
|
+
def __init__(self, *message) -> None:
|
|
534
|
+
super().__init__("Config", message)
|
|
535
|
+
|
|
536
|
+
|
|
537
|
+
class DatasetManagerError(KonfAIError):
|
|
538
|
+
|
|
539
|
+
def __init__(self, *message) -> None:
|
|
540
|
+
super().__init__("DatasetManager", message)
|
|
541
|
+
|
|
542
|
+
class MeasureError(KonfAIError):
|
|
543
|
+
|
|
544
|
+
def __init__(self, *message) -> None:
|
|
545
|
+
super().__init__("Measure", message)
|
|
546
|
+
|
|
547
|
+
class TrainerError(KonfAIError):
|
|
548
|
+
|
|
549
|
+
def __init__(self, *message) -> None:
|
|
550
|
+
super().__init__("Trainer", message)
|
|
551
|
+
|
|
552
|
+
class AugmentationError(KonfAIError):
|
|
553
|
+
|
|
554
|
+
def __init__(self, *message) -> None:
|
|
555
|
+
super().__init__("Augmentation", message)
|
|
@@ -1,15 +1,15 @@
|
|
|
1
1
|
konfai/__init__.py,sha256=fZM9EjrMcdJZE-jQAsNu3ZcWfa9Ylbly2q4l9STfpBI,795
|
|
2
2
|
konfai/evaluator.py,sha256=TsHQNN4DhP33Sal9-exbOouZAHe7gNEk7RsBaasgkEQ,7373
|
|
3
|
-
konfai/main.py,sha256=
|
|
3
|
+
konfai/main.py,sha256=OPAUWees1j6NNmggyI7lM2abZctCqk1OsvQuc9FpePY,2515
|
|
4
4
|
konfai/predictor.py,sha256=FGn6_Nt4PCdxNFoIOt8nbBjrAMTwc55uZqy6OFXUjhY,22320
|
|
5
|
-
konfai/trainer.py,sha256=
|
|
5
|
+
konfai/trainer.py,sha256=22IKMiPeHNJMGcJdVq9BH-Hsopp6xF4QSYyPabz4LGs,19604
|
|
6
6
|
konfai/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
7
|
-
konfai/data/augmentation.py,sha256=
|
|
8
|
-
konfai/data/data_manager.py,sha256=
|
|
9
|
-
konfai/data/patching.py,sha256=
|
|
10
|
-
konfai/data/transform.py,sha256=
|
|
7
|
+
konfai/data/augmentation.py,sha256=ASmKWBpykLBHDB_YeTgoBTlqvJ06v5OUG7n7ugPN6NU,31718
|
|
8
|
+
konfai/data/data_manager.py,sha256=AfZpQaHEd4QmA8Sxd6SK2LK3Nv658wlZVbivETq4Y4k,28439
|
|
9
|
+
konfai/data/patching.py,sha256=JFgb4qgS_gqlcnFV_Rm0JvCzKQleXNZyOWxqqM8H01M,14499
|
|
10
|
+
konfai/data/transform.py,sha256=ttALCEvvqIcwcYJTp4KCnf87e2hzvkXlwrtksfpWRUU,25064
|
|
11
11
|
konfai/metric/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
12
|
-
konfai/metric/measure.py,sha256=
|
|
12
|
+
konfai/metric/measure.py,sha256=K7s15wNJm1_Iaav9y7Oe_UhSnbLUiybhLTVumHZ8ZaY,22001
|
|
13
13
|
konfai/metric/schedulers.py,sha256=UoSr1TW_hrus3DhvOEbDefxCSUGz7lJS_8vbz0GEye8,1370
|
|
14
14
|
konfai/models/classification/convNeXt.py,sha256=Phj1hO8TCItVhBXoFXQcIkA85DZxurGLyHIPWBfXJ0Y,9243
|
|
15
15
|
konfai/models/classification/resnet.py,sha256=t8KJEgGudWBpGJM9YS1lH_5eX0-wGbK7ZW5uZW214eM,7958
|
|
@@ -18,22 +18,22 @@ konfai/models/generation/ddpm.py,sha256=awvuRo-vk8M80N93NWF4i0-WWfaycBxSOmdYJNJv
|
|
|
18
18
|
konfai/models/generation/diffusionGan.py,sha256=KnJyV-tx4CiE_ag-5IXwiYLCuC2yFHX16k2CtASdecg,33199
|
|
19
19
|
konfai/models/generation/gan.py,sha256=-GoKxHm3W9NdD4U77UcJrG5TfOZ3NWFUZG663kt2XPo,7854
|
|
20
20
|
konfai/models/generation/vae.py,sha256=_3JYVT2ojZ0P98tYcD2ny7a-gWVUmnByLDhY7i-n_4g,4719
|
|
21
|
-
konfai/models/registration/registration.py,sha256=
|
|
21
|
+
konfai/models/registration/registration.py,sha256=18EiWt4RJIXLyFtqU-kHjV1sMnQRm9mxAA6_-2B1YqI,6313
|
|
22
22
|
konfai/models/representation/representation.py,sha256=RwQYoxtdph440-t_ZLelykl0hkUAD1zdspQaLkgxb-0,2677
|
|
23
23
|
konfai/models/segmentation/NestedUNet.py,sha256=6XGizAIc4bDL8vx4AHW8BBFjUvovRYcjdMBHsN4ViNo,4301
|
|
24
24
|
konfai/models/segmentation/UNet.py,sha256=9-g63oNqaxSlGmrD1-IVzJ-kac9QphmM7i-3rEsH23I,4218
|
|
25
25
|
konfai/network/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
26
|
-
konfai/network/blocks.py,sha256=
|
|
27
|
-
konfai/network/network.py,sha256=
|
|
26
|
+
konfai/network/blocks.py,sha256=RaTI0Lrvq1V-GIFei-WTUB6wlg4LydZksAyJ8DMk40M,13502
|
|
27
|
+
konfai/network/network.py,sha256=DTROmwwPJZQllJx32GwAixfy1E23F1mtYy1bpF23TjU,46538
|
|
28
28
|
konfai/utils/ITK.py,sha256=OxTieDNNYHGkn7zxJsAG-6ecRG1VYMvn1dlBbBe1DOs,13955
|
|
29
29
|
konfai/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
|
|
30
|
-
konfai/utils/config.py,sha256=
|
|
30
|
+
konfai/utils/config.py,sha256=OsCUqgNaPX56nD_MpJ2WRpiQ7xxqHonfPYaS9VGtsl8,12413
|
|
31
31
|
konfai/utils/dataset.py,sha256=RNULLVauhYiyeQkpVhOVq4VRwI4KRyoQA2VCmOh8VbE,35525
|
|
32
32
|
konfai/utils/registration.py,sha256=v1srEBOcgDnHrx0YtsK6bcj0yCMH7wNeaQ3wC7gEvOw,8898
|
|
33
|
-
konfai/utils/utils.py,sha256=
|
|
34
|
-
konfai-1.
|
|
35
|
-
konfai-1.
|
|
36
|
-
konfai-1.
|
|
37
|
-
konfai-1.
|
|
38
|
-
konfai-1.
|
|
39
|
-
konfai-1.
|
|
33
|
+
konfai/utils/utils.py,sha256=GPzQQ7kBmraHytsUdPOqwW0Qu-kAHGA-wwMd4obdVLE,22018
|
|
34
|
+
konfai-1.1.1.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
|
|
35
|
+
konfai-1.1.1.dist-info/METADATA,sha256=mMT6fvoAbwoMkChW6G_Y2vDhrtiMss86xezEnylzLrE,2515
|
|
36
|
+
konfai-1.1.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
|
|
37
|
+
konfai-1.1.1.dist-info/entry_points.txt,sha256=fG82HRN5-g39ACSOCtij_I3N6EHxfYnMR0D7TI_8pW8,81
|
|
38
|
+
konfai-1.1.1.dist-info/top_level.txt,sha256=xF470dkIlFoFqTZEOlRehKJr4WU_8OKGXrJqYm9vWKs,7
|
|
39
|
+
konfai-1.1.1.dist-info/RECORD,,
|
|
File without changes
|
|
File without changes
|
|
File without changes
|
|
File without changes
|