konfai 1.0.7__py3-none-any.whl → 1.0.9__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of konfai might be problematic.

konfai/__init__.py CHANGED
@@ -1,16 +1,16 @@
 import os
 import datetime

-MODELS_DIRECTORY = lambda : os.environ["DL_API_MODELS_DIRECTORY"]
-CHECKPOINTS_DIRECTORY =lambda : os.environ["DL_API_CHECKPOINTS_DIRECTORY"]
-MODEL = lambda : os.environ["DL_API_MODEL"]
-PREDICTIONS_DIRECTORY =lambda : os.environ["DL_API_PREDICTIONS_DIRECTORY"]
-EVALUATIONS_DIRECTORY =lambda : os.environ["DL_API_EVALUATIONS_DIRECTORY"]
-STATISTICS_DIRECTORY = lambda : os.environ["DL_API_STATISTICS_DIRECTORY"]
-SETUPS_DIRECTORY = lambda : os.environ["DL_API_SETUPS_DIRECTORY"]
-CONFIG_FILE = lambda : os.environ["DEEP_LEARNING_API_CONFIG_FILE"]
-DL_API_STATE = lambda : os.environ["DL_API_STATE"]
-DEEP_LEARNING_API_ROOT = lambda : os.environ["DEEP_LEARNING_API_ROOT"]
+MODELS_DIRECTORY = lambda : os.environ["KONFAI_MODELS_DIRECTORY"]
+CHECKPOINTS_DIRECTORY =lambda : os.environ["KONFAI_CHECKPOINTS_DIRECTORY"]
+MODEL = lambda : os.environ["KONFAI_MODEL"]
+PREDICTIONS_DIRECTORY =lambda : os.environ["KONFAI_PREDICTIONS_DIRECTORY"]
+EVALUATIONS_DIRECTORY =lambda : os.environ["KONFAI_EVALUATIONS_DIRECTORY"]
+STATISTICS_DIRECTORY = lambda : os.environ["KONFAI_STATISTICS_DIRECTORY"]
+SETUPS_DIRECTORY = lambda : os.environ["KONFAI_SETUPS_DIRECTORY"]
+CONFIG_FILE = lambda : os.environ["KONFAI_CONFIG_FILE"]
+KONFAI_STATE = lambda : os.environ["KONFAI_STATE"]
+KONFAI_ROOT = lambda : os.environ["KONFAI_ROOT"]
 CUDA_VISIBLE_DEVICES = lambda : os.environ["CUDA_VISIBLE_DEVICES"]

 DATE = lambda : datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
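
Note: every DL_API_* / DEEP_LEARNING_API_* environment variable becomes KONFAI_*-prefixed in 1.0.9, and each accessor reads os.environ lazily, so existing setups must export the new names before calling them. A minimal sketch (directory paths are placeholder values, not defaults shipped by the package):

    import os

    os.environ["KONFAI_MODELS_DIRECTORY"] = "./Models"            # hypothetical path
    os.environ["KONFAI_CHECKPOINTS_DIRECTORY"] = "./Checkpoints"  # hypothetical path

    from konfai import MODELS_DIRECTORY
    print(MODELS_DIRECTORY())  # -> ./Models; raises KeyError if the variable is unset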
@@ -6,7 +6,7 @@ import SimpleITK as sitk
 import torch.nn.functional as F
 from typing import Union
 import os
-from konfai import DEEP_LEARNING_API_ROOT
+from konfai import KONFAI_ROOT
 from konfai.utils.config import config
 from konfai.utils.utils import _getModule
 from konfai.utils.dataset import Attribute, data_to_image
@@ -51,7 +51,7 @@ class Prob():
 class DataAugmentationsList():

     @config()
-    def __init__(self, nb : int = 10, dataAugmentations: dict[str, Prob] = {"default:RandomElastixTransform" : Prob(1)}) -> None:
+    def __init__(self, nb : int = 10, dataAugmentations: dict[str, Prob] = {"default:Flip" : Prob(1)}) -> None:
         self.nb = nb
         self.dataAugmentations : list[DataAugmentation] = []
         self.dataAugmentationsLoader = dataAugmentations
@@ -59,7 +59,7 @@ class DataAugmentationsList():
     def load(self, key: str):
         for augmentation, prob in self.dataAugmentationsLoader.items():
             module, name = _getModule(augmentation, "data.augmentation")
-            dataAugmentation: DataAugmentation = getattr(importlib.import_module(module), name)(config = None, DL_args="{}.Dataset.augmentations.{}.dataAugmentations".format(DEEP_LEARNING_API_ROOT(), key))
+            dataAugmentation: DataAugmentation = getattr(importlib.import_module(module), name)(config = None, DL_args="{}.Dataset.augmentations.{}.dataAugmentations".format(KONFAI_ROOT(), key))
             dataAugmentation.load(prob.prob)
             self.dataAugmentations.append(dataAugmentation)

@@ -525,7 +525,7 @@ class Elastix(DataAugmentation):
 class Permute(DataAugmentation):

     @config("Permute")
-    def __init__(self, prob_permute: Union[list[float], None] = [0.33 ,0.33]) -> None:
+    def __init__(self, prob_permute: Union[list[float], None] = [0.5 ,0.5]) -> None:
         super().__init__()
         self._permute_dims = torch.tensor([[0, 2, 1, 3], [0, 3, 1, 2]])
         self.prob_permute = prob_permute
@@ -12,8 +12,8 @@ from concurrent.futures import ThreadPoolExecutor, as_completed
 import threading
 from torch.cuda import device_count

-from konfai import DL_API_STATE, DEEP_LEARNING_API_ROOT
-from konfai.data.HDF5 import DatasetPatch, DatasetManager
+from konfai import KONFAI_STATE, KONFAI_ROOT
+from konfai.data.patching import DatasetPatch, DatasetManager
 from konfai.utils.config import config
 from konfai.utils.utils import memoryInfo, cpuInfo, memoryForecast, getMemory, State
 from konfai.utils.dataset import Dataset, Attribute
@@ -36,7 +36,7 @@ class GroupTransform:
         if self._pre_transforms is not None:
             if isinstance(self._pre_transforms, dict):
                 for classpath, transform in self._pre_transforms.items():
-                    transform = transform.getTransform(classpath, DL_args = "{}.Dataset.groups_src.{}.groups_dest.{}.pre_transforms".format(DEEP_LEARNING_API_ROOT(), group_src, group_dest))
+                    transform = transform.getTransform(classpath, DL_args = "{}.Dataset.groups_src.{}.groups_dest.{}.pre_transforms".format(KONFAI_ROOT(), group_src, group_dest))
                     transform.setDatasets(datasets)
                     self.pre_transforms.append(transform)
             else:
@@ -47,7 +47,7 @@ class GroupTransform:
         if self._post_transforms is not None:
             if isinstance(self._post_transforms, dict):
                 for classpath, transform in self._post_transforms.items():
-                    transform = transform.getTransform(classpath, DL_args = "{}.Dataset.groups_src.{}.groups_dest.{}.post_transforms".format(DEEP_LEARNING_API_ROOT(), group_src, group_dest))
+                    transform = transform.getTransform(classpath, DL_args = "{}.Dataset.groups_src.{}.groups_dest.{}.post_transforms".format(KONFAI_ROOT(), group_src, group_dest))
                     transform.setDatasets(datasets)
                     self.post_transforms.append(transform)
             else:
@@ -64,7 +64,7 @@ class GroupTransform:
 class Group(dict[str, GroupTransform]):

     @config()
-    def __init__(self, groups_dest: dict[str, GroupTransform] = {"default": GroupTransform()}):
+    def __init__(self, groups_dest: dict[str, GroupTransform] = {"default:group_dest": GroupTransform()}):
         super().__init__(groups_dest)

 class CustomSampler(Sampler[int]):
@@ -127,7 +127,7 @@ class DatasetIter(data.Dataset):
            total=len(indexs),
            desc="Caching : init | {} | {}".format(memoryForecast(memory_init, 0, self.nb_dataset), cpuInfo()),
            leave=False,
-           disable=self.rank != 0 and "DL_API_CLUSTER" not in os.environ
+           disable=self.rank != 0 and "KONFAI_CLUSTER" not in os.environ
        )

        def process(index):
@@ -135,7 +135,7 @@ class DatasetIter(data.Dataset):
            with memory_lock:
                pbar.set_description("Caching : {} | {} | {}".format(memoryInfo(), memoryForecast(memory_init, index, self.nb_dataset), cpuInfo()))
                pbar.update(1)
-       with ThreadPoolExecutor(max_workers=os.cpu_count()//device_count()) as executor:
+       with ThreadPoolExecutor(max_workers=os.cpu_count()//(device_count() if device_count() > 0 else 1)) as executor:
            futures = [executor.submit(process, index) for index in indexs]
            for _ in as_completed(futures):
                pass
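
This guard matters on CPU-only hosts: torch.cuda.device_count() returns 0 there, so the 1.0.7 expression divided by zero before caching even started. A standalone illustration:

    import os
    from torch.cuda import device_count

    # 1.0.7: os.cpu_count() // device_count()  -> ZeroDivisionError without a GPU
    # 1.0.9 falls back to a divisor of 1:
    workers = os.cpu_count() // (device_count() if device_count() > 0 else 1)
    print(workers)  # equals os.cpu_count() on machines with no CUDA device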
@@ -295,7 +295,7 @@ class Data(ABC):
            return [[] for _ in range(world_size)]

        maps = []
-       if DL_API_STATE() == str(State.PREDICTION) or DL_API_STATE() == str(State.EVALUATION):
+       if KONFAI_STATE() == str(State.PREDICTION) or KONFAI_STATE() == str(State.EVALUATION):
            np_map = np.asarray(map)
            unique_index = np.unique(np_map[:, 0])
            offset = int(np.ceil(len(unique_index)/world_size))
@@ -314,7 +314,11 @@ class Data(ABC):
    def getData(self, world_size: int) -> list[list[DataLoader]]:
        datasets: dict[str, list[(str, bool)]] = {}
        for dataset_filename in self.dataset_filenames:
-           if len(dataset_filename.split(":")) == 2:
+           if len(dataset_filename.split(":")) == 1:
+               filename = dataset_filename
+               format = "mha"
+               append = True
+           elif len(dataset_filename.split(":")) == 2:
                filename, format = dataset_filename.split(":")
                append = True
            else:
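
A dataset_filenames entry may now omit the format suffix; a bare path defaults to the "mha" format. A sketch of the visible branches (the final else branch is truncated in this diff):

    parts = "./Dataset".split(":")
    if len(parts) == 1:      # new in 1.0.9: bare path, format defaults to "mha"
        filename, format, append = parts[0], "mha", True
    elif len(parts) == 2:    # "path:format" form, unchanged from 1.0.7
        filename, format = parts
        append = True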
@@ -331,8 +335,9 @@ class Data(ABC):
            else:
                datasets[group] = [(filename, append)]
        for group_src in self.groups_src:
-           assert group_src in datasets, "Error group source {} not found".format(group_src)
-
+           if group_src not in datasets:
+               raise ValueError("[DatasetManager] Error: group source {} not found. Available groups: {}".format(group_src, list(datasets.keys())))
+
            for group_dest in self.groups_src[group_src]:
                self.groups_src[group_src][group_dest].load(group_src, group_dest, [self.datasets[filename] for filename, _ in datasets[group_src]])
        for key, dataAugmentations in self.dataAugmentationsList.items():
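
Swapping the assert for an explicit ValueError keeps the check alive under python -O (which strips assert statements) and makes the message list the groups that do exist. A minimal reproduction with hypothetical group names:

    datasets = {"image": [("./Dataset", True)]}
    group_src = "label"
    if group_src not in datasets:
        raise ValueError("[DatasetManager] Error: group source {} not found. Available groups: {}".format(group_src, list(datasets.keys())))
    # ValueError: [DatasetManager] Error: group source label not found. Available groups: ['image']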
@@ -422,8 +427,8 @@ class Data(ABC):
 class DataTrain(Data):

     @config("Dataset")
-    def __init__(self, dataset_filenames : list[str] = ["default:Dataset.h5"],
-                groups_src : dict[str, Group] = {"default" : Group()},
+    def __init__(self, dataset_filenames : list[str] = ["default:./Dataset"],
+                groups_src : dict[str, Group] = {"default:group_src" : Group()},
                 augmentations : Union[dict[str, DataAugmentationsList], None] = {"DataAugmentation_0" : DataAugmentationsList()},
                 inlineAugmentations: bool = False,
                 patch : Union[DatasetPatch, None] = DatasetPatch(),
@@ -437,7 +442,7 @@ class DataTrain(Data):
 class DataPrediction(Data):

     @config("Dataset")
-    def __init__(self, dataset_filenames : list[str] = ["default:Dataset.h5"],
+    def __init__(self, dataset_filenames : list[str] = ["default:./Dataset"],
                 groups_src : dict[str, Group] = {"default" : Group()},
                 augmentations : Union[dict[str, DataAugmentationsList], None] = {"DataAugmentation_0" : DataAugmentationsList()},
                 inlineAugmentations: bool = False,
@@ -452,19 +457,10 @@ class DataPrediction(Data):
 class DataMetric(Data):

     @config("Dataset")
-    def __init__(self, dataset_filenames : list[str] = ["default:Dataset.h5"],
+    def __init__(self, dataset_filenames : list[str] = ["default:./Dataset"],
                 groups_src : dict[str, Group] = {"default" : Group()},
                 subset : Union[PredictionSubset, dict[str, PredictionSubset]] = PredictionSubset(),
                 validation: Union[str, None] = None,
                 num_workers : int = 4) -> None:

-        super().__init__(dataset_filenames=dataset_filenames, groups_src=groups_src, patch=None, use_cache=False, subset=subset, num_workers=num_workers, batch_size=1, train_size=1 if validation is None else validation)
-
-class DataHyperparameter(Data):
-
-    @config("Dataset")
-    def __init__(self, dataset_filenames : list[str] = ["default:Dataset.h5"],
-                groups_src : dict[str, Group] = {"default" : Group()},
-                patch : Union[DatasetPatch, None] = DatasetPatch()) -> None:
-
-        super().__init__(dataset_filenames, groups_src, patch, False, PredictionSubset(), 0, False, 1)
+        super().__init__(dataset_filenames=dataset_filenames, groups_src=groups_src, patch=None, use_cache=False, subset=subset, num_workers=num_workers, batch_size=1, train_size=1 if validation is None else validation)
@@ -37,13 +37,13 @@ class PathCombine(ABC):
     3D :
         AAA+AAB+ABA+ABB+BAA+BAB+BBA+BBB

-        AAC+ABC+BAC+BBC
-        ACA+ACB+BCA+BCB
         CAA+CAB+CBA+CBB
+        ACA+ACB+BCA+BCB
+        AAC+ABC+BAC+BBC

-        ACC+BCC
-        CAC+CBC
         CCA+CCB
+        CAC+CBC
+        ACC+BCC

     """
     def setPatchConfig(self, patch_size: list[int], overlap: int):
@@ -214,13 +214,13 @@ class Patch(ABC):
 class DatasetPatch(Patch):

     @config("Patch")
-    def __init__(self, patch_size : list[int] = [128, 256, 256], overlap : Union[int, None] = None, mask: Union[str, None] = None, padValue: float = 0, extend_slice: int = 0) -> None:
+    def __init__(self, patch_size : list[int] = [128, 128, 128], overlap : Union[int, None] = None, mask: Union[str, None] = None, padValue: float = 0, extend_slice: int = 0) -> None:
         super().__init__(patch_size, overlap, mask, padValue, extend_slice)

 class ModelPatch(Patch):

     @config("Patch")
-    def __init__(self, patch_size : list[int] = [128, 256, 256], overlap : Union[int, None] = None, patchCombine: Union[str, None] = None, mask: Union[str, None] = None, padValue: float = 0, extend_slice: int = 0) -> None:
+    def __init__(self, patch_size : list[int] = [128, 128, 128], overlap : Union[int, None] = None, patchCombine: Union[str, None] = None, mask: Union[str, None] = None, padValue: float = 0, extend_slice: int = 0) -> None:
         super().__init__(patch_size, overlap, mask, padValue, extend_slice)
         self.patchCombine = patchCombine

konfai/data/transform.py CHANGED
@@ -130,7 +130,7 @@ class Standardize(Transform):

 class TensorCast(Transform):

-    def __init__(self, dtype : str = "default:float32,int64,int16") -> None:
+    def __init__(self, dtype : str = "float32") -> None:
         self.dtype : torch.dtype = getattr(torch, dtype)

     def __call__(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
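
The old placeholder default was not a valid torch dtype name, so getattr(torch, dtype) would raise AttributeError whenever it was used verbatim; the new default resolves cleanly:

    import torch

    dtype = getattr(torch, "float32")  # -> torch.float32
    # getattr(torch, "default:float32,int64,int16") raises AttributeError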
@@ -142,7 +142,7 @@ class TensorCast(Transform):

 class Padding(Transform):

-    def __init__(self, padding : list[int] = [0,0,0,0,0,0], mode : str = "default:constant,reflect,replicate,circular") -> None:
+    def __init__(self, padding : list[int] = [0,0,0,0,0,0], mode : str = "constant") -> None:
         self.padding = padding
         self.mode = mode

@@ -332,7 +332,7 @@ class ResampleTransform(Transform):

 class Mask(Transform):

-    def __init__(self, path : str = "default:./default.mha", value_outside: int = 0) -> None:
+    def __init__(self, path : str = "./default.mha", value_outside: int = 0) -> None:
         self.path = path
         self.value_outside = value_outside

konfai/evaluator.py CHANGED
@@ -7,10 +7,10 @@ import json
 import shutil
 import builtins
 import importlib
-from konfai import EVALUATIONS_DIRECTORY, PREDICTIONS_DIRECTORY, DEEP_LEARNING_API_ROOT, CONFIG_FILE
+from konfai import EVALUATIONS_DIRECTORY, PREDICTIONS_DIRECTORY, KONFAI_ROOT, CONFIG_FILE
 from konfai.utils.config import config
 from konfai.utils.utils import _getModule, DistributedObject, synchronize_data
-from konfai.data.dataset import DataMetric
+from konfai.data.data_manager import DataMetric

 class CriterionsAttr():

@@ -28,7 +28,7 @@ class CriterionsLoader():
        criterions = {}
        for module_classpath, criterionsAttr in self.criterionsLoader.items():
            module, name = _getModule(module_classpath, "metric.measure")
-           criterions[config("{}.metrics.{}.targetsCriterions.{}.criterionsLoader.{}".format(DEEP_LEARNING_API_ROOT(), output_group, target_group, module_classpath))(getattr(importlib.import_module(module), name))(config = None)] = criterionsAttr
+           criterions[config("{}.metrics.{}.targetsCriterions.{}.criterionsLoader.{}".format(KONFAI_ROOT(), output_group, target_group, module_classpath))(getattr(importlib.import_module(module), name))(config = None)] = criterionsAttr
        return criterions

 class TargetCriterionsLoader():
@@ -86,8 +86,8 @@ class Statistics():
 class Evaluator(DistributedObject):

     @config("Evaluator")
-    def __init__(self, train_name: str = "default:name", metrics: dict[str, TargetCriterionsLoader] = {"default": TargetCriterionsLoader()}, dataset : DataMetric = DataMetric(),) -> None:
-        if os.environ["DEEP_LEANING_API_CONFIG_MODE"] != "Done":
+    def __init__(self, train_name: str = "default:TRAIN_01", metrics: dict[str, TargetCriterionsLoader] = {"default": TargetCriterionsLoader()}, dataset : DataMetric = DataMetric(),) -> None:
+        if os.environ["KONFAI_CONFIG_MODE"] != "Done":
             exit(0)
         super().__init__(train_name)
         self.metric_path = EVALUATIONS_DIRECTORY()+self.name+"/"
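
Note: besides the prefix change, the KONFAI_CONFIG_MODE rename also fixes a typo in the old variable name, which was spelled DEEP_LEANING_API_CONFIG_MODE (missing the R of LEARNING).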
@@ -111,7 +111,7 @@ class Evaluator(DistributedObject):

     def setup(self, world_size: int):
         if os.path.exists(self.metric_path):
-            if os.environ["DL_API_OVERWRITE"] != "True":
+            if os.environ["KONFAI_OVERWRITE"] != "True":
                 accept = builtins.input("The metric {} already exists ! Do you want to overwrite it (yes,no) : ".format(self.name))
                 if accept != "yes":
                     return
konfai/main.py CHANGED
@@ -2,21 +2,25 @@ import argparse
 import os
 from torch.cuda import device_count
 import torch.multiprocessing as mp
-from konfai.utils.utils import setupAPI, TensorBoard, Log
+from konfai.utils.utils import setup, TensorBoard, Log

 import sys
 sys.path.insert(0, os.getcwd())

 def main():
     parser = argparse.ArgumentParser(description="KonfAI", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
-    with setupAPI(parser) as distributedObject:
-        with Log(distributedObject.name):
-            world_size = device_count()
-            if world_size == 0:
-                world_size = 1
-            distributedObject.setup(world_size)
-            with TensorBoard(distributedObject.name):
-                mp.spawn(distributedObject, nprocs=world_size)
+    try:
+        with setup(parser) as distributedObject:
+            with Log(distributedObject.name):
+                world_size = device_count()
+                if world_size == 0:
+                    world_size = 1
+                distributedObject.setup(world_size)
+                with TensorBoard(distributedObject.name):
+                    mp.spawn(distributedObject, nprocs=world_size)
+    except Exception as e:
+        print(e)
+        exit(1)


 def cluster():
@@ -29,17 +33,20 @@ def cluster():
     cluster_args.add_argument('--memory', type=int, default=16, help='Amount of memory per node')
     cluster_args.add_argument('--time-limit', '--time_limit', type=int, default=1440, help='Job time limit in minute')
     cluster_args.add_argument('--resubmit', action='store_true', help='Automatically resubmit job just before timout')
+    try:
+        with setup(parser) as distributedObject:
+            args = parser.parse_args()
+            config = vars(args)
+            os.environ["KONFAI_OVERWRITE"] = "True"
+            os.environ["KONFAI_CLUSTER"] = "True"

-    with setupAPI(parser) as distributedObject:
-        args = parser.parse_args()
-        config = vars(args)
-        os.environ["DL_API_OVERWRITE"] = "True"
-        os.environ["DL_API_CLUSTER"] = "True"
-
-        n_gpu = len(config["gpu"].split(","))
-        distributedObject.setup(n_gpu*int(config["num_nodes"]))
-        import submitit
-        executor = submitit.AutoExecutor(folder="./Cluster/")
-        executor.update_parameters(name=config["name"], mem_gb=config["memory"], gpus_per_node=n_gpu, tasks_per_node=n_gpu//distributedObject.size, cpus_per_task=config["num_workers"], nodes=config["num_nodes"], timeout_min=config["time_limit"])
-        with TensorBoard(distributedObject.name):
-            executor.submit(distributedObject)
+            n_gpu = len(config["gpu"].split(","))
+            distributedObject.setup(n_gpu*int(config["num_nodes"]))
+            import submitit
+            executor = submitit.AutoExecutor(folder="./Cluster/")
+            executor.update_parameters(name=config["name"], mem_gb=config["memory"], gpus_per_node=n_gpu, tasks_per_node=n_gpu//distributedObject.size, cpus_per_task=config["num_workers"], nodes=config["num_nodes"], timeout_min=config["time_limit"])
+            with TensorBoard(distributedObject.name):
+                executor.submit(distributedObject)
+    except Exception as e:
+        print(e)
+        exit(1)
konfai/metric/measure.py CHANGED
@@ -17,7 +17,7 @@ from abc import abstractmethod

 from konfai.utils.config import config
 from konfai.utils.utils import _getModule
-from konfai.data.HDF5 import ModelPatch
+from konfai.data.patching import ModelPatch
 from konfai.network.blocks import LatentDistribution
 from konfai.network.network import ModelLoader, Network

@@ -92,7 +92,7 @@ class PSNR(MaskedLoss):
        return psnr

    def __init__(self, dynamic_range: Union[float, None] = None) -> None:
-       dynamic_range = dynamic_range if dynamic_range else 1024+3000
+       dynamic_range = dynamic_range if dynamic_range else 1024+3071
        super().__init__(partial(PSNR._loss, dynamic_range), False)

 class SSIM(MaskedLoss):
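
A plausible reading of the new default (an inference, not stated in the diff): CT intensities conventionally span -1024 to 3071 HU, so the fallback now covers that full window:

    dynamic_range = 1024 + 3071  # 4095, vs. 4024 with the old 1024+3000 default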
@@ -143,8 +143,11 @@ class Dice(Criterion):
        target = self.flatten(target)
        return (2.*(input * target).sum() + self.smooth)/(input.sum() + target.sum() + self.smooth)

+
    def forward(self, output: torch.Tensor, *targets : list[torch.Tensor]) -> torch.Tensor:
        target = targets[0]
+       if output.shape[1] == 1:
+           output = F.one_hot(output.type(torch.int64), num_classes=torch.max(output).item()+1).permute(0, len(target.shape), *[i+1 for i in range(len(target.shape)-1)]).float()
        target = F.one_hot(target.type(torch.int64), num_classes=output.shape[1]).permute(0, len(target.shape), *[i+1 for i in range(len(target.shape)-1)]).float().squeeze(2)
        return 1-torch.mean(self.dice_per_channel(output, target))

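Standalone sketch of what the new branch does (simplified, not the library call path): when the model emits a single-channel label map of shape (B, 1, ...) instead of per-class scores, 1.0.9 one-hot encodes it so dice_per_channel receives matching (B, C, ...) tensors:

    import torch
    import torch.nn.functional as F

    output = torch.tensor([[[0, 1, 2, 1]]])            # (B=1, C=1, W=4) label map
    one_hot = F.one_hot(output.long(), num_classes=3)  # -> (1, 1, 4, 3)
    one_hot = one_hot.permute(0, 3, 1, 2).float()      # channels-first: (1, 3, 1, 4)
    print(one_hot.shape)                               # torch.Size([1, 3, 1, 4])
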
@@ -239,7 +242,7 @@ class PerceptualLoss(Criterion):
    @config(None)
    def __init__(self, losses: dict[str, float] = {"Gram": 1, "torch_nn_L1Loss": 1}) -> None:
        self.losses = losses
-       self.DL_args = os.environ['DEEP_LEARNING_API_CONFIG_PATH'] if "DEEP_LEARNING_API_CONFIG_PATH" in os.environ else ""
+       self.DL_args = os.environ['KONFAI_CONFIG_PATH'] if "KONFAI_CONFIG_PATH" in os.environ else ""

    def getLoss(self) -> dict[torch.nn.Module, float]:
        result: dict[torch.nn.Module, float] = {}
@@ -252,7 +255,7 @@ class PerceptualLoss(Criterion):
        super().__init__()
        self.path_model = path_model
        if self.path_model not in modelsRegister:
-           self.model = modelLoader.getModel(train=False, DL_args=os.environ['DEEP_LEARNING_API_CONFIG_PATH'].split("PerceptualLoss")[0]+"PerceptualLoss.Model", DL_without=["optimizer", "schedulers", "nb_batch_per_step", "init_type", "init_gain", "outputsCriterions", "drop_p"])
+           self.model = modelLoader.getModel(train=False, DL_args=os.environ['KONFAI_CONFIG_PATH'].split("PerceptualLoss")[0]+"PerceptualLoss.Model", DL_without=["optimizer", "schedulers", "nb_batch_per_step", "init_type", "init_gain", "outputsCriterions", "drop_p"])
        if path_model.startswith("https"):
            state_dict = torch.hub.load_state_dict_from_url(path_model)
            state_dict = {"Model": {self.model.getName() : state_dict["model"]}}
@@ -2,7 +2,7 @@ import torch
 import torch.nn.functional as F
 from konfai.network import network, blocks
 from konfai.utils.config import config
-from konfai.data.HDF5 import ModelPatch
+from konfai.data.patching import ModelPatch

 """
 "convnext_tiny_1k": "https://dl.fbaipublicfiles.com/convnext/convnext_tiny_1k_224_ema.pth", depths=[3, 3, 9, 3], dims=[96, 192, 384, 768]
@@ -3,7 +3,7 @@ from typing import Type
 import torch
 from konfai.network import network, blocks
 from konfai.utils.config import config
-from konfai.data.HDF5 import ModelPatch
+from konfai.data.patching import ModelPatch

 """
 'resnet18': 'https://download.pytorch.org/models/resnet18-5c106cde.pth', dim = 2, in_channels = 3, depths=[2, 2, 2, 2], widths = [64, 64, 128, 256, 512], num_classes=1000, useBottleneck=False
@@ -3,7 +3,7 @@ import torch

 from konfai.network import network, blocks
 from konfai.utils.config import config
-from konfai.data.HDF5 import ModelPatch
+from konfai.data.patching import ModelPatch

 class MappingNetwork(network.ModuleArgsDict):
     def __init__(self, z_dim: int, c_dim: int, w_dim: int, num_layers: int, embed_features: int, layer_features: int):
@@ -8,7 +8,7 @@ import numpy as np

 from konfai.network import network, blocks
 from konfai.utils.config import config
-from konfai.data.HDF5 import ModelPatch
+from konfai.data.patching import ModelPatch
 from konfai.utils.utils import gpuInfo
 from konfai.metric.measure import Criterion

@@ -5,7 +5,7 @@ import numpy as np

 from konfai.network import network, blocks
 from konfai.utils.config import config
-from konfai.data.HDF5 import ModelPatch, Attribute
+from konfai.data.patching import ModelPatch, Attribute
 from konfai.data import augmentation
 from konfai.models.segmentation import UNet, NestedUNet
 from konfai.models.generation.ddpm import DDPM
@@ -3,7 +3,7 @@ import torch

 from konfai.network import network, blocks
 from konfai.utils.config import config
-from konfai.data.HDF5 import ModelPatch
+from konfai.data.patching import ModelPatch

 class Discriminator(network.Network):

@@ -3,7 +3,7 @@ import torch

 from konfai.network import network, blocks
 from konfai.utils.config import config
-from konfai.data.HDF5 import ModelPatch
+from konfai.data.patching import ModelPatch


 class NestedUNetBlock(network.ModuleArgsDict):
@@ -3,7 +3,7 @@ from typing import Union

 from konfai.network import network, blocks
 from konfai.utils.config import config
-from konfai.data.HDF5 import ModelPatch
+from konfai.data.patching import ModelPatch

 class UNetHead(network.ModuleArgsDict):

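Every model file swaps the same import: the konfai.data.HDF5 module is gone, and ModelPatch (along with Accumulator, Attribute, DatasetPatch and DatasetManager, per the other hunks) now lives in konfai.data.patching. Downstream code needs the matching one-line change:

    # 1.0.7:
    # from konfai.data.HDF5 import ModelPatch
    # 1.0.9:
    from konfai.data.patching import ModelPatch
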
konfai/network/network.py CHANGED
@@ -13,11 +13,11 @@ from torch.utils.checkpoint import checkpoint
 from typing import Union
 from enum import Enum

-from konfai import DEEP_LEARNING_API_ROOT
+from konfai import KONFAI_ROOT
 from konfai.metric.schedulers import Scheduler
 from konfai.utils.config import config
 from konfai.utils.utils import State, _getModule, getDevice, getGPUMemory
-from konfai.data.HDF5 import Accumulator, ModelPatch
+from konfai.data.patching import Accumulator, ModelPatch

 class NetState(Enum):
     TRAIN = 0,
@@ -40,7 +40,7 @@ class OptimizerLoader():

    def getOptimizer(self, key: str, parameter: Iterator[torch.nn.parameter.Parameter]) -> torch.optim.Optimizer:
        torch.optim.AdamW
-       return config("{}.Model.{}.Optimizer".format(DEEP_LEARNING_API_ROOT(), key))(getattr(importlib.import_module('torch.optim'), self.name))(parameter, config = None)
+       return config("{}.Model.{}.Optimizer".format(KONFAI_ROOT(), key))(getattr(importlib.import_module('torch.optim'), self.name))(parameter, config = None)

 class SchedulerStep():

@@ -98,8 +98,8 @@ class CriterionsLoader():
        for module_classpath, criterionsAttr in self.criterionsLoader.items():
            module, name = _getModule(module_classpath, "metric.measure")
            criterionsAttr.isTorchCriterion = module.startswith("torch")
-           criterionsAttr.sheduler = criterionsAttr.l.getShedulers("{}.Model.{}.outputsCriterions.{}.targetsCriterions.{}.criterionsLoader.{}".format(DEEP_LEARNING_API_ROOT(), model_classname, output_group, target_group, module_classpath))
-           criterions[config("{}.Model.{}.outputsCriterions.{}.targetsCriterions.{}.criterionsLoader.{}".format(DEEP_LEARNING_API_ROOT(), model_classname, output_group, target_group, module_classpath))(getattr(importlib.import_module(module), name))(config = None)] = criterionsAttr
+           criterionsAttr.sheduler = criterionsAttr.l.getShedulers("{}.Model.{}.outputsCriterions.{}.targetsCriterions.{}.criterionsLoader.{}".format(KONFAI_ROOT(), model_classname, output_group, target_group, module_classpath))
+           criterions[config("{}.Model.{}.outputsCriterions.{}.targetsCriterions.{}.criterionsLoader.{}".format(KONFAI_ROOT(), model_classname, output_group, target_group, module_classpath))(getattr(importlib.import_module(module), name))(config = None)] = criterionsAttr
        return criterions

 class TargetCriterionsLoader():
@@ -509,10 +509,10 @@ class Network(ModuleArgsDict, ABC):
            results[name_function(self)] = function(self, *args, **kwargs)
        return results

-   def _function_network(t : bool = False):
+   def _function_network():
        def _function_network_d(function : Callable):
            def new_function(self : Self, *args, **kwargs) -> dict[str, object]:
-               return self._apply_network(lambda network: network._getName() if t else network.getName(), [], self.getName(), function, *args, **kwargs)
+               return self._apply_network(lambda network: network.getName(), [], self.getName(), function, *args, **kwargs)
            return new_function
        return _function_network_d

@@ -547,7 +547,7 @@ class Network(ModuleArgsDict, ABC):
        self._it = 0
        self.outputsGroup : list[OutputsGroup]= []

-   @_function_network(True)
+   @_function_network()
    def state_dict(self) -> dict[str, OrderedDict]:
        destination = OrderedDict()
        destination._metadata = OrderedDict()
@@ -618,13 +618,13 @@ class Network(ModuleArgsDict, ABC):
            module.apply(fn)
        fn(self)

-   @_function_network(True)
+   @_function_network()
    def load(self, state_dict : dict[str, dict[str, torch.Tensor]], init: bool = True, ema : bool =False):
        if init:
            self.apply(partial(ModuleArgsDict.init_func, init_type=self.init_type, init_gain=self.init_gain))
        name = "Model" + ("_EMA" if ema else "")
        if name in state_dict:
-           model_state_dict_tmp = {k.split(".")[-1] : v for k, v in state_dict[name].items()}[self._getName()]
+           model_state_dict_tmp = {k.split(".")[-1] : v for k, v in state_dict[name].items()}[self.getName()]
            map = self.getMap()
            model_state_dict : OrderedDict[str, torch.Tensor] = OrderedDict()

@@ -753,14 +753,14 @@ class Network(ModuleArgsDict, ABC):
        output_layer_accumulator : dict[str, Accumulator] = {}
        output_layer_patch_indexed : dict[str, Patch_Indexed] = {}
        it = 0
-       debug = "DL_API_DEBUG" in os.environ
+       debug = "KONFAI_DEBUG" in os.environ
        for (nameTmp, output_layer) in self.named_forward(*inputs):
            name = nameTmp.replace(";accu;", "")
            if debug:
-               if "DL_API_DEBUG_LAST_LAYER" in os.environ:
-                   os.environ["DL_API_DEBUG_LAST_LAYER"] = "{}|{}:{}:{}".format(os.environ["DL_API_DEBUG_LAST_LAYER"], name, getGPUMemory(output_layer.device), str(output_layer.device).replace("cuda:", ""))
+               if "KONFAI_DEBUG_LAST_LAYER" in os.environ:
+                   os.environ["KONFAI_DEBUG_LAST_LAYER"] = "{}|{}:{}:{}".format(os.environ["KONFAI_DEBUG_LAST_LAYER"], name, getGPUMemory(output_layer.device), str(output_layer.device).replace("cuda:", ""))
                else:
-                   os.environ["DL_API_DEBUG_LAST_LAYER"] = "{}:{}:{}".format(name, getGPUMemory(output_layer.device), str(output_layer.device).replace("cuda:", ""))
+                   os.environ["KONFAI_DEBUG_LAST_LAYER"] = "{}:{}:{}".format(name, getGPUMemory(output_layer.device), str(output_layer.device).replace("cuda:", ""))
            it += 1
            if name in layers_name or nameTmp in layers_name:
                if ";accu;" in nameTmp:
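
Layer tracing is now toggled by the presence of KONFAI_DEBUG (formerly DL_API_DEBUG); the value itself is never inspected. A hypothetical session:

    import os

    os.environ["KONFAI_DEBUG"] = "1"  # any value works; only presence is checked
    # ... run a forward pass ...
    # each visited layer is appended to KONFAI_DEBUG_LAST_LAYER as
    # "name:gpu_memory:device", "|"-separated:
    print(os.environ.get("KONFAI_DEBUG_LAST_LAYER"))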
@@ -806,8 +806,6 @@ class Network(ModuleArgsDict, ABC):

            if not len(layers_name):
                break
-
-

    def init_outputsGroup(self):
        metric_tmp = {network.measure : network.measure.outputsCriterions.keys() for network in self.getNetworks().values() if network.measure}
@@ -906,15 +904,12 @@ class Network(ModuleArgsDict, ABC):
        return module

    def getName(self) -> str:
-       return self.__class__.__name__
+       return self.name

    def setName(self, name: str) -> Self:
        self.name = name
        return self

-   def _getName(self) -> str:
-       return self.name
-
    def setState(self, state: NetState):
        for module in self.modules():
            if isinstance(module, ModuleArgsDict):
@@ -923,12 +918,12 @@ class Network(ModuleArgsDict, ABC):
 class ModelLoader():

     @config("Model")
-    def __init__(self, classpath : str = "default:segmentation.UNet") -> None:
-        self.module, self.name = _getModule(classpath.split(".")[-1] if len(classpath.split(".")) > 1 else classpath, ".".join(classpath.split(".")[:-1]) if len(classpath.split(".")) > 1 else "")
+    def __init__(self, classpath : str = "default:segmentation.UNet.UNet") -> None:
+        self.module, self.name = _getModule(classpath, "models")

     def getModel(self, train : bool = True, DL_args: Union[str, None] = None, DL_without=["optimizer", "schedulers", "nb_batch_per_step", "init_type", "init_gain"]) -> Network:
         if not DL_args:
-            DL_args="{}.Model".format(DEEP_LEARNING_API_ROOT())
+            DL_args="{}.Model".format(KONFAI_ROOT())
         model = partial(getattr(importlib.import_module(self.module), self.name), config = None, DL_args=DL_args)
         if not train:
             model = partial(model, DL_without = DL_without)
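
The default classpath now names the class explicitly (module path plus class name), and _getModule resolves it relative to the built-in models package, so configs that used the old shorthand need updating. A hypothetical sketch of the new resolution (module path inferred from the default, not verified against the installed package):

    # "segmentation.UNet.UNet" resolved under konfai.models:
    # module "konfai.models.segmentation.UNet", class "UNet"
    import importlib
    module = importlib.import_module("konfai.models.segmentation.UNet")
    UNet = getattr(module, "UNet")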