konfai-1.1.0-py3-none-any.whl → konfai-1.1.2-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of konfai might be problematic; see the package's page on its registry for more details.

konfai/evaluator.py CHANGED
@@ -9,7 +9,7 @@ import builtins
9
9
  import importlib
10
10
  from konfai import EVALUATIONS_DIRECTORY, PREDICTIONS_DIRECTORY, KONFAI_ROOT, CONFIG_FILE
11
11
  from konfai.utils.config import config
12
- from konfai.utils.utils import _getModule, DistributedObject, synchronize_data
12
+ from konfai.utils.utils import _getModule, DistributedObject, synchronize_data, EvaluatorError
13
13
  from konfai.data.data_manager import DataMetric
14
14
 
15
15
  class CriterionsAttr():
@@ -54,7 +54,8 @@ class Statistics():
54
54
  if name_dataset not in self.measures:
55
55
  self.measures[name_dataset] = {}
56
56
  self.measures[name_dataset][name] = value
57
-
57
+
58
+ @staticmethod
58
59
  def getStatistic(values: list[float]) -> dict[str, float]:
59
60
  return {"max": np.max(values), "min": np.min(values), "std": np.std(values), "25pc": np.percentile(values, 25), "50pc": np.percentile(values, 50), "75pc": np.percentile(values, 75), "mean": np.mean(values), "count": len(values)}
60
61
 
@@ -91,7 +92,6 @@ class Evaluator(DistributedObject):
91
92
  exit(0)
92
93
  super().__init__(train_name)
93
94
  self.metric_path = EVALUATIONS_DIRECTORY()+self.name+"/"
94
- self.predict_path = PREDICTIONS_DIRECTORY()+self.name+"/"
95
95
  self.metricsLoader = metrics
96
96
  self.dataset = dataset
97
97
  self.metrics = {k: v.getTargetsCriterions(k) for k, v in self.metricsLoader.items()}
@@ -102,10 +102,10 @@ class Evaluator(DistributedObject):
102
102
  result = {}
103
103
  for output_group in self.metrics:
104
104
  for target_group in self.metrics[output_group]:
105
- targets = [data_dict[group][0] for group in target_group.split("/") if group in data_dict]
105
+ targets = [data_dict[group][0].to(0) if torch.cuda.is_available() else data_dict[group][0] for group in target_group.split("/") if group in data_dict]
106
106
  name = data_dict[output_group][1][0]
107
107
  for metric in self.metrics[output_group][target_group]:
108
- result["{}:{}:{}".format(output_group, target_group, metric.__class__.__name__)] = metric(data_dict[output_group][0], *targets).item()
108
+ result["{}:{}:{}".format(output_group, target_group, metric.__class__.__name__)] = metric(data_dict[output_group][0].to(0) if torch.cuda.is_available() else data_dict[output_group][0], *targets).item()
109
109
  statistics.add(result, name)
110
110
  return result
111
111
 
@@ -126,9 +126,26 @@ class Evaluator(DistributedObject):
126
126
 
127
127
  self.dataloader = self.dataset.getData(world_size)
128
128
 
129
+ groupsDest = [group for groups in self.dataset.groups_src.values() for group in groups]
130
+
131
+ missing_outputs = set(self.metrics.keys()) - set(groupsDest)
132
+ if missing_outputs:
133
+ raise EvaluatorError(
134
+ f"The following metric output groups are missing from 'groupsDest': {sorted(missing_outputs)}. ",
135
+ f"Available groups: {sorted(groupsDest)}"
136
+ )
137
+
138
+ target_groups = {target for targets in self.metrics.values() for target in targets}
139
+ missing_targets = target_groups - set(groupsDest)
140
+ if missing_targets:
141
+ raise EvaluatorError(
142
+ f"The following metric target groups are missing from 'groupsDest': {sorted(missing_targets)}. ",
143
+ f"Available groups: {sorted(groupsDest)}"
144
+ )
145
+
129
146
  def run_process(self, world_size: int, global_rank: int, gpu: int, dataloaders: list[DataLoader]):
130
147
  description = lambda measure : "Metric TRAIN : {} ".format(" | ".join("{}: {:.2f}".format(k, v) for k, v in measure.items()) if measure is not None else "")
131
- with tqdm.tqdm(iterable = enumerate(dataloaders[0]), leave=False, desc = description(None), total=len(dataloaders[0])) as batch_iter:
148
+ with tqdm.tqdm(iterable = enumerate(dataloaders[0]), leave=True, desc = description(None), total=len(dataloaders[0]), ncols=0) as batch_iter:
132
149
  for _, data_dict in batch_iter:
133
150
  batch_iter.set_description(description(self.update({k: (v[0], v[4]) for k,v in data_dict.items()}, self.statistics_train)))
134
151
  outputs = synchronize_data(world_size, gpu, self.statistics_train.measures)
@@ -136,7 +153,7 @@ class Evaluator(DistributedObject):
136
153
  self.statistics_train.write(outputs)
137
154
  if len(dataloaders) == 2:
138
155
  description = lambda measure : "Metric VALIDATION : {} ".format(" | ".join("{}: {:.2f}".format(k, v) for k, v in measure.items()) if measure is not None else "")
139
- with tqdm.tqdm(iterable = enumerate(dataloaders[1]), leave=False, desc = description(None), total=len(dataloaders[1])) as batch_iter:
156
+ with tqdm.tqdm(iterable = enumerate(dataloaders[1]), leave=True, desc = description(None), total=len(dataloaders[1]), ncols=0) as batch_iter:
140
157
  for _, data_dict in batch_iter:
141
158
  batch_iter.set_description(description(self.update({k: (v[0], v[4]) for k,v in data_dict.items()}, self.statistics_validation)))
142
159
  outputs = synchronize_data(world_size, gpu, self.statistics_validation.measures)
konfai/main.py CHANGED
@@ -3,6 +3,7 @@ import os
3
3
  from torch.cuda import device_count
4
4
  import torch.multiprocessing as mp
5
5
  from konfai.utils.utils import setup, TensorBoard, Log
6
+ from konfai import KONFAI_NB_CORES
6
7
 
7
8
  import sys
8
9
  sys.path.insert(0, os.getcwd())
@@ -11,16 +12,15 @@ def main():
11
12
  parser = argparse.ArgumentParser(description="KonfAI", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
12
13
  try:
13
14
  with setup(parser) as distributedObject:
14
- with Log(distributedObject.name):
15
+ with Log(distributedObject.name, 0):
15
16
  world_size = device_count()
16
17
  if world_size == 0:
17
- world_size = 1
18
+ world_size = int(KONFAI_NB_CORES())
18
19
  distributedObject.setup(world_size)
19
20
  with TensorBoard(distributedObject.name):
20
21
  mp.spawn(distributedObject, nprocs=world_size)
21
- except Exception as e:
22
- print(e)
23
- exit(1)
22
+ except KeyboardInterrupt:
23
+ print("\n[KonfAI] Manual interruption (Ctrl+C)")
24
24
 
25
25
 
26
26
  def cluster():
@@ -47,6 +47,8 @@ def cluster():
47
47
  executor.update_parameters(name=config["name"], mem_gb=config["memory"], gpus_per_node=n_gpu, tasks_per_node=n_gpu//distributedObject.size, cpus_per_task=config["num_workers"], nodes=config["num_nodes"], timeout_min=config["time_limit"])
48
48
  with TensorBoard(distributedObject.name):
49
49
  executor.submit(distributedObject)
50
+ except KeyboardInterrupt:
51
+ print("\n[KonfAI] Manual interruption (Ctrl+C)")
50
52
  except Exception as e:
51
53
  print(e)
52
54
  exit(1)
@@ -94,7 +94,6 @@ class SpatialTransformer(torch.nn.Module):
94
94
  new_locs[:, 1,1] = 1
95
95
  new_locs[:, 0,2] = flow[:, 0]
96
96
  new_locs[:, 1,2] = flow[:, 1]
97
- print(new_locs)
98
97
  return F.grid_sample(src, F.affine_grid(new_locs, src.size()), align_corners=True, mode="bilinear")
99
98
  else:
100
99
  new_locs = self.grid + flow
konfai/network/blocks.py CHANGED
@@ -176,7 +176,6 @@ class Print(torch.nn.Module):
176
176
  super().__init__()
177
177
 
178
178
  def forward(self, input: torch.Tensor) -> torch.Tensor:
179
- print(input.shape)
180
179
  return input
181
180
 
182
181
  class Write(torch.nn.Module):
konfai/network/network.py CHANGED
@@ -16,7 +16,7 @@ from enum import Enum
16
16
  from konfai import KONFAI_ROOT
17
17
  from konfai.metric.schedulers import Scheduler
18
18
  from konfai.utils.config import config
19
- from konfai.utils.utils import State, _getModule, getDevice, getGPUMemory
19
+ from konfai.utils.utils import State, _getModule, getDevice, getGPUMemory, MeasureError
20
20
  from konfai.data.patching import Accumulator, ModelPatch
21
21
 
22
22
  class NetState(Enum):
@@ -39,7 +39,6 @@ class OptimizerLoader():
39
39
  self.name = name
40
40
 
41
41
  def getOptimizer(self, key: str, parameter: Iterator[torch.nn.parameter.Parameter]) -> torch.optim.Optimizer:
42
- torch.optim.AdamW
43
42
  return config("{}.Model.{}.Optimizer".format(KONFAI_ROOT(), key))(getattr(importlib.import_module('torch.optim'), self.name))(parameter, config = None)
44
43
 
45
44
  class SchedulerStep():
@@ -54,12 +53,12 @@ class LRSchedulersLoader():
54
53
  def __init__(self, params: dict[str, SchedulerStep] = {"default:ReduceLROnPlateau" : SchedulerStep(0)}) -> None:
55
54
  self.params = params
56
55
 
57
- def getShedulers(self, key: str, optimizer: torch.optim.Optimizer) -> dict[torch.optim.lr_scheduler._LRScheduler, int]:
58
- shedulers : dict[torch.optim.lr_scheduler._LRScheduler, int] = {}
56
+ def getschedulers(self, key: str, optimizer: torch.optim.Optimizer) -> dict[torch.optim.lr_scheduler._LRScheduler, int]:
57
+ schedulers : dict[torch.optim.lr_scheduler._LRScheduler, int] = {}
59
58
  for name, step in self.params.items():
60
59
  if name:
61
- shedulers[config("Trainer.Model.{}.Schedulers.{}".format(key, name))(getattr(importlib.import_module('torch.optim.lr_scheduler'), name))(optimizer, config = None)] = step.nb_step
62
- return shedulers
60
+ schedulers[config("Trainer.Model.{}.Schedulers.{}".format(key, name))(getattr(importlib.import_module('torch.optim.lr_scheduler'), name))(optimizer, config = None)] = step.nb_step
61
+ return schedulers
63
62
 
64
63
  class SchedulersLoader():
65
64
 
@@ -67,12 +66,12 @@ class SchedulersLoader():
67
66
  def __init__(self, params: dict[str, SchedulerStep] = {"default:Constant" : SchedulerStep(0)}) -> None:
68
67
  self.params = params
69
68
 
70
- def getShedulers(self, key: str) -> dict[torch.optim.lr_scheduler._LRScheduler, int]:
71
- shedulers : dict[Scheduler, int] = {}
69
+ def getschedulers(self, key: str) -> dict[torch.optim.lr_scheduler._LRScheduler, int]:
70
+ schedulers : dict[Scheduler, int] = {}
72
71
  for name, step in self.params.items():
73
72
  if name:
74
- shedulers[getattr(importlib.import_module("konfai.metric.schedulers"), name)(config = None, DL_args = key)] = step.nb_step
75
- return shedulers
73
+ schedulers[getattr(importlib.import_module("konfai.metric.schedulers"), name)(config = None, DL_args = key)] = step.nb_step
74
+ return schedulers
76
75
 
77
76
  class CriterionsAttr():
78
77
 
@@ -98,7 +97,7 @@ class CriterionsLoader():
98
97
  for module_classpath, criterionsAttr in self.criterionsLoader.items():
99
98
  module, name = _getModule(module_classpath, "metric.measure")
100
99
  criterionsAttr.isTorchCriterion = module.startswith("torch")
101
- criterionsAttr.sheduler = criterionsAttr.l.getShedulers("{}.Model.{}.outputsCriterions.{}.targetsCriterions.{}.criterionsLoader.{}".format(KONFAI_ROOT(), model_classname, output_group, target_group, module_classpath))
100
+ criterionsAttr.sheduler = criterionsAttr.l.getschedulers("{}.Model.{}.outputsCriterions.{}.targetsCriterions.{}.criterionsLoader.{}".format(KONFAI_ROOT(), model_classname, output_group, target_group, module_classpath))
102
101
  criterions[config("{}.Model.{}.outputsCriterions.{}.targetsCriterions.{}.criterionsLoader.{}".format(KONFAI_ROOT(), model_classname, output_group, target_group, module_classpath))(getattr(importlib.import_module(module), name))(config = None)] = criterionsAttr
103
102
  return criterions
104
103
 
@@ -154,10 +153,25 @@ class Measure():
154
153
  self.outputsCriterions[output_group.replace(":", ".")] = targetCriterionsLoader.getTargetsCriterions(output_group, model_classname)
155
154
  self._loss : dict[int, dict[str, Measure.Loss]] = {}
156
155
 
157
- def init(self, model : torch.nn.Module) -> None:
156
+ def init(self, model : torch.nn.Module, group_dest: list[str]) -> None:
158
157
  outputs_group_rename = {}
158
+
159
+ modules = []
160
+ for i,_ in model.named_modules():
161
+ modules.append(i)
162
+
159
163
  for output_group in self.outputsCriterions.keys():
164
+ if output_group not in modules:
165
+ raise MeasureError(f"The output group '{output_group}' defined in 'outputsCriterions' does not correspond to any module in the model.",
166
+ f"Available modules: {modules}",
167
+ "Please check that the name matches exactly a submodule or output of your model architecture."
168
+ )
160
169
  for target_group in self.outputsCriterions[output_group]:
170
+ if target_group not in group_dest:
171
+ raise MeasureError(
172
+ f"The target_group '{target_group}' defined in 'outputsCriterions.{output_group}.targetsCriterions' was not found in the available destination groups.",
173
+ "This target_group is expected for loss or metric computation, but was not loaded in 'group_dest'.",
174
+ f"Please make sure that the group '{target_group}' is defined in 'Dataset:groups_src:...:groups_dest:'{target_group}'' and correctly loaded from the dataset.")
161
175
  for criterion in self.outputsCriterions[output_group][target_group]:
162
176
  if not self.outputsCriterions[output_group][target_group][criterion].isTorchCriterion:
163
177
  outputs_group_rename[output_group] = criterion.init(model, output_group, target_group)
@@ -703,10 +717,10 @@ class Network(ModuleArgsDict, ABC):
703
717
  return in_channels, in_is_channel, out_channels, out_is_channel
704
718
 
705
719
  @_function_network()
706
- def init(self, autocast : bool, state : State, key: str) -> None:
720
+ def init(self, autocast : bool, state : State, group_dest: list[str], key: str) -> None:
707
721
  if self.outputsCriterionsLoader:
708
722
  self.measure = Measure(key, self.outputsCriterionsLoader)
709
- self.measure.init(self)
723
+ self.measure.init(self, group_dest)
710
724
  if state != State.PREDICTION:
711
725
  self.scaler = torch.amp.GradScaler("cuda", enabled=autocast)
712
726
  if self.optimizerLoader:
@@ -714,7 +728,7 @@ class Network(ModuleArgsDict, ABC):
714
728
  self.optimizer.zero_grad()
715
729
 
716
730
  if self.LRSchedulersLoader and self.optimizer:
717
- self.schedulers = self.LRSchedulersLoader.getShedulers(key, self.optimizer)
731
+ self.schedulers = self.LRSchedulersLoader.getschedulers(key, self.optimizer)
718
732
 
719
733
  def initialized(self):
720
734
  pass
@@ -880,7 +894,6 @@ class Network(ModuleArgsDict, ABC):
880
894
  if scheduler:
881
895
  if scheduler.__class__.__name__ == 'ReduceLROnPlateau':
882
896
  if self.measure:
883
- print(sum(self.measure.getLastValues(0).values()))
884
897
  scheduler.step(sum(self.measure.getLastValues(0).values()))
885
898
  else:
886
899
  scheduler.step()
konfai/predictor.py CHANGED
@@ -94,10 +94,10 @@ class OutDataset(Dataset, NeedDevice, ABC):
94
94
  class OutSameAsGroupDataset(OutDataset):
95
95
 
96
96
  @config("OutDataset")
97
- def __init__(self, dataset_filename: str = "./Dataset:mha", group: str = "default", sameAsGroup: str = "default", pre_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, post_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, final_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, patchCombine: Union[str, None] = None, redution: str = "mean", inverse_transform: bool = True) -> None:
97
+ def __init__(self, dataset_filename: str = "./Dataset:mha", group: str = "default", sameAsGroup: str = "default", pre_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, post_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, final_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, patchCombine: Union[str, None] = None, reduction: str = "mean", inverse_transform: bool = True) -> None:
98
98
  super().__init__(dataset_filename, group, pre_transforms, post_transforms, final_transforms, patchCombine)
99
99
  self.group_src, self.group_dest = sameAsGroup.split(":")
100
- self.redution = redution
100
+ self.reduction = reduction
101
101
  self.inverse_transform = inverse_transform
102
102
 
103
103
  def addLayer(self, index_dataset: int, index_augmentation: int, index_patch: int, layer: torch.Tensor, dataset: DatasetIter):
@@ -151,9 +151,9 @@ class OutSameAsGroupDataset(OutDataset):
151
151
  self.output_layer_accumulator.pop(index)
152
152
  dtype = result.dtype
153
153
 
154
- if self.redution == "mean":
154
+ if self.reduction == "mean":
155
155
  result = torch.mean(result.float(), dim=0).to(dtype)
156
- elif self.redution == "median":
156
+ elif self.reduction == "median":
157
157
  result, _ = torch.median(result.float(), dim=0)
158
158
  else:
159
159
  raise NameError("Reduction method does not exist (mean, median)")
@@ -201,7 +201,7 @@ class OutDatasetLoader():
201
201
 
202
202
  class _Predictor():
203
203
 
204
- def __init__(self, world_size: int, global_rank: int, local_rank: int, predict_path: str, data_log: Union[list[str], None], outsDataset: dict[str, OutDataset], modelComposite: DDP, dataloader_prediction: DataLoader) -> None:
204
+ def __init__(self, world_size: int, global_rank: int, local_rank: int, autocast: bool, predict_path: str, data_log: Union[list[str], None], outsDataset: dict[str, OutDataset], modelComposite: DDP, dataloader_prediction: DataLoader) -> None:
205
205
  self.world_size = world_size
206
206
  self.global_rank = global_rank
207
207
  self.local_rank = local_rank
@@ -209,7 +209,7 @@ class _Predictor():
209
209
  self.modelComposite = modelComposite
210
210
  self.dataloader_prediction = dataloader_prediction
211
211
  self.outsDataset = outsDataset
212
-
212
+ self.autocast = autocast
213
213
 
214
214
  self.it = 0
215
215
 
@@ -239,21 +239,22 @@ class _Predictor():
239
239
  self.modelComposite.eval()
240
240
  self.modelComposite.module.setState(NetState.PREDICTION)
241
241
  desc = lambda : "Prediction : {}".format(description(self.modelComposite))
242
- self.dataloader_prediction.dataset.load()
243
- with tqdm.tqdm(iterable = enumerate(self.dataloader_prediction), leave=False, desc = desc(), total=len(self.dataloader_prediction), disable=self.global_rank != 0 and "KONFAI_CLUSTER" not in os.environ) as batch_iter:
242
+ self.dataloader_prediction.dataset.load("Prediction")
243
+ with tqdm.tqdm(iterable = enumerate(self.dataloader_prediction), leave=True, desc = desc(), total=len(self.dataloader_prediction), ncols=0) as batch_iter:
244
244
  dist.barrier()
245
245
  for it, data_dict in batch_iter:
246
- input = self.getInput(data_dict)
247
- for name, output in self.modelComposite(input, list(self.outsDataset.keys())):
248
- self._predict_log(data_dict)
249
- outDataset = self.outsDataset[name]
250
- for i, (index, patch_augmentation, patch_index) in enumerate([(int(index), int(patch_augmentation), int(patch_index)) for index, patch_augmentation, patch_index in zip(list(data_dict.values())[0][1], list(data_dict.values())[0][2], list(data_dict.values())[0][3])]):
251
- outDataset.addLayer(index, patch_augmentation, patch_index, output[i].cpu(), self.dataset)
252
- if outDataset.isDone(index):
253
- outDataset.write(index, self.dataset.getDatasetFromIndex(list(data_dict.keys())[0], index).name.split("/")[-1], outDataset.getOutput(index, self.dataset))
254
-
255
- batch_iter.set_description(desc())
256
- self.it += 1
246
+ with torch.amp.autocast('cuda', enabled=self.autocast):
247
+ input = self.getInput(data_dict)
248
+ for name, output in self.modelComposite(input, list(self.outsDataset.keys())):
249
+ self._predict_log(data_dict)
250
+ outDataset = self.outsDataset[name]
251
+ for i, (index, patch_augmentation, patch_index) in enumerate([(int(index), int(patch_augmentation), int(patch_index)) for index, patch_augmentation, patch_index in zip(list(data_dict.values())[0][1], list(data_dict.values())[0][2], list(data_dict.values())[0][3])]):
252
+ outDataset.addLayer(index, patch_augmentation, patch_index, output[i].cpu(), self.dataset)
253
+ if outDataset.isDone(index):
254
+ outDataset.write(index, self.dataset.getDatasetFromIndex(list(data_dict.keys())[0], index).name.split("/")[-1], outDataset.getOutput(index, self.dataset))
255
+
256
+ batch_iter.set_description(desc())
257
+ self.it += 1
257
258
 
258
259
  def _predict_log(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]):
259
260
  measures = DistributedObject.getMeasure(self.world_size, self.global_rank, self.local_rank, {"" : self.modelComposite.module}, 1)
@@ -320,6 +321,7 @@ class Predictor(DistributedObject):
320
321
  train_name: str = "name",
321
322
  manual_seed : Union[int, None] = None,
322
323
  gpu_checkpoints: Union[list[str], None] = None,
324
+ autocast : bool = False,
323
325
  outsDataset: Union[dict[str, OutDatasetLoader], None] = {"default:Default" : OutDatasetLoader()},
324
326
  images_log: list[str] = []) -> None:
325
327
  if os.environ["KONFAI_CONFIG_MODE"] != "Done":
@@ -328,6 +330,7 @@ class Predictor(DistributedObject):
328
330
  self.manual_seed = manual_seed
329
331
  self.dataset = dataset
330
332
  self.combine = combine
333
+ self.autocast = autocast
331
334
 
332
335
  self.model = model.getModel(train=False)
333
336
  self.it = 0
@@ -384,7 +387,7 @@ class Predictor(DistributedObject):
384
387
 
385
388
  shutil.copyfile(CONFIG_FILE(), self.predict_path+"Prediction.yml")
386
389
 
387
- self.model.init(autocast=False, state = State.PREDICTION)
390
+ self.model.init(self.autocast, State.PREDICTION, self.dataset.getGroupsDest())
388
391
  self.model.init_outputsGroup()
389
392
  self.model._compute_channels_trace(self.model, self.model.in_channels, None, self.gpu_checkpoints)
390
393
  self.modelComposite = ModelComposite(self.model, len(MODEL().split(":")), self.combine)
@@ -402,7 +405,7 @@ class Predictor(DistributedObject):
402
405
  def run_process(self, world_size: int, global_rank: int, local_rank: int, dataloaders: list[DataLoader]):
403
406
  modelComposite = Network.to(self.modelComposite, local_rank*self.size)
404
407
  modelComposite = DDP(modelComposite, static_graph=True) if torch.cuda.is_available() else CPU_Model(modelComposite)
405
- with _Predictor(world_size, global_rank, local_rank, self.predict_path, self.images_log, self.outsDataset, modelComposite, *dataloaders) as p:
408
+ with _Predictor(world_size, global_rank, local_rank, self.autocast, self.predict_path, self.images_log, self.outsDataset, modelComposite, *dataloaders) as p:
406
409
  p.run()
407
410
 
408
411
 
konfai/trainer.py CHANGED
@@ -14,10 +14,9 @@ import torch.distributed as dist
14
14
  from konfai import MODELS_DIRECTORY, CHECKPOINTS_DIRECTORY, STATISTICS_DIRECTORY, SETUPS_DIRECTORY, CONFIG_FILE, MODEL, DATE, KONFAI_STATE
15
15
  from konfai.data.data_manager import DataTrain
16
16
  from konfai.utils.config import config
17
- from konfai.utils.utils import State, DataLog, DistributedObject, description
17
+ from konfai.utils.utils import State, DataLog, DistributedObject, description, TrainerError
18
18
  from konfai.network.network import Network, ModelLoader, NetState, CPU_Model
19
19
 
20
-
21
20
  class EarlyStoppingBase:
22
21
 
23
22
  def __init__(self):
@@ -53,8 +52,8 @@ class EarlyStopping(EarlyStoppingBase):
53
52
  return super().getScore(values)
54
53
  for v in self.monitor:
55
54
  if v not in values.keys():
56
- raise ValueError(
57
- "[EarlyStopping] Metric '{}' specified in `monitor` not found in logged values. "
55
+ raise TrainerError(
56
+ "Metric '{}' specified in EarlyStopping.monitor not found in logged values. ",
58
57
  "Available keys: {}. Please check your configuration.".format(v, list(values.keys())))
59
58
  return sum([i for v, i in values.items() if v in self.monitor])
60
59
 
@@ -68,7 +67,7 @@ class EarlyStopping(EarlyStoppingBase):
68
67
  elif self.mode == "max":
69
68
  improvement = current_score - self.best_score
70
69
  else:
71
- raise ValueError("Mode must be 'min' or 'max'.")
70
+ raise TrainerError("Mode must be 'min' or 'max'.")
72
71
 
73
72
  if improvement > self.min_delta:
74
73
  self.best_score = current_score
@@ -99,7 +98,7 @@ class _Trainer():
99
98
  self.autocast = autocast
100
99
  self.modelEMA = modelEMA
101
100
  self.early_stopping = EarlyStoppingBase() if early_stopping is None else early_stopping
102
-
101
+
103
102
  self.it_validation = it_validation
104
103
  if self.it_validation is None:
105
104
  self.it_validation = len(dataloader_training)
@@ -118,13 +117,14 @@ class _Trainer():
118
117
  self.tb.close()
119
118
 
120
119
  def run(self) -> None:
121
- with tqdm.tqdm(iterable = range(self.epoch, self.epochs), leave=False, total=self.epochs, initial=self.epoch, desc="Progress", disable=self.global_rank != 0) as epoch_tqdm:
120
+ self.dataloader_training.dataset.load("Train")
121
+ self.dataloader_validation.dataset.load("Validation")
122
+ with tqdm.tqdm(iterable = range(self.epoch, self.epochs), leave=False, total=self.epochs, initial=self.epoch, desc="Progress") as epoch_tqdm:
122
123
  for self.epoch in epoch_tqdm:
123
- self.dataloader_training.dataset.load()
124
124
  self.train()
125
125
  if self.early_stopping.isStopped():
126
126
  break
127
- self.dataloader_training.dataset.resetAugmentation()
127
+ self.dataloader_training.dataset.resetAugmentation("Train")
128
128
 
129
129
  def getInput(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int, str, bool]]) -> dict[tuple[str, bool], torch.Tensor]:
130
130
  return {(k, v[5][0].item()) : v[0] for k, v in data_dict.items()}
@@ -137,7 +137,8 @@ class _Trainer():
137
137
  self.modelEMA.module.setState(NetState.TRAIN)
138
138
 
139
139
  desc = lambda : "Training : {}".format(description(self.model, self.modelEMA))
140
- with tqdm.tqdm(iterable = enumerate(self.dataloader_training), desc = desc(), total=len(self.dataloader_training), leave=False, disable=self.global_rank != 0 and "KONFAI_CLUSTER" not in os.environ) as batch_iter:
140
+
141
+ with tqdm.tqdm(iterable = enumerate(self.dataloader_training), desc = desc(), total=len(self.dataloader_training), leave=False, ncols=0) as batch_iter:
141
142
  for _, data_dict in batch_iter:
142
143
  with torch.amp.autocast('cuda', enabled=self.autocast):
143
144
  input = self.getInput(data_dict)
@@ -171,8 +172,7 @@ class _Trainer():
171
172
 
172
173
  desc = lambda : "Validation : {}".format(description(self.model, self.modelEMA))
173
174
  data_dict = None
174
- self.dataloader_validation.dataset.load()
175
- with tqdm.tqdm(iterable = enumerate(self.dataloader_validation), desc = desc(), total=len(self.dataloader_validation), leave=False, disable=self.global_rank != 0 and "KONFAI_CLUSTER" not in os.environ) as batch_iter:
175
+ with tqdm.tqdm(iterable = enumerate(self.dataloader_validation), desc = desc(), total=len(self.dataloader_validation), leave=False, ncols=0) as batch_iter:
176
176
  for _, data_dict in batch_iter:
177
177
  input = self.getInput(data_dict)
178
178
  self.model(input)
@@ -180,7 +180,7 @@ class _Trainer():
180
180
  self.modelEMA.module(input)
181
181
 
182
182
  batch_iter.set_description(desc())
183
- self.dataloader_validation.dataset.resetAugmentation()
183
+ self.dataloader_validation.dataset.resetAugmentation("Validation")
184
184
  dist.barrier()
185
185
  self.model.train()
186
186
  self.model.module.setState(NetState.TRAIN)
@@ -360,7 +360,7 @@ class Trainer(DistributedObject):
360
360
  os.makedirs(dir)
361
361
 
362
362
  for name in sorted(os.listdir(path_checkpoint)):
363
- checkpoint = torch.load(path_checkpoint+name, weights_only=False)
363
+ checkpoint = torch.load(path_checkpoint+name, weights_only=False, map_location='cpu')
364
364
  self.model.load(checkpoint, init=False, ema=False)
365
365
 
366
366
  torch.save(self.model, "{}Serialized/{}".format(path_model, name))
@@ -391,7 +391,7 @@ class Trainer(DistributedObject):
391
391
  if state != State.TRAIN:
392
392
  state_dict = self._load()
393
393
 
394
- self.model.init(self.autocast, state)
394
+ self.model.init(self.autocast, state, self.dataset.getGroupsDest())
395
395
  self.model.init_outputsGroup()
396
396
  self.model._compute_channels_trace(self.model, self.model.in_channels, self.gradient_checkpoints, self.gpu_checkpoints)
397
397
  self.model.load(state_dict, init=True, ema=False)
konfai/utils/config.py CHANGED
@@ -6,14 +6,10 @@ from copy import deepcopy
6
6
  from typing import Union, Literal, get_origin, get_args
7
7
  import torch
8
8
  from konfai import CONFIG_FILE
9
+ from konfai.utils.utils import ConfigError
9
10
 
10
11
  yaml = ruamel.yaml.YAML()
11
12
 
12
- class ConfigError(Exception):
13
-
14
- def __init__(self, message : str = "The config only supports types : config(Object), int, str, bool, float, list[int], list[str], list[bool], list[float], dict[str, Object]") -> None:
15
- self.message = message
16
- super().__init__(self.message)
17
13
 
18
14
  class Config():
19
15
 
@@ -187,9 +183,8 @@ def config(key : Union[str, None] = None):
187
183
  default_value = param.default if param.default != inspect._empty else allowed_values[0]
188
184
  value = config.getValue(param.name, f"default:{default_value}")
189
185
  if value not in allowed_values:
190
- raise ValueError(
191
- f"[Config] Invalid value '{value}' for parameter '{param.name}'. "
192
- f"Expected one of: {allowed_values}."
186
+ raise ConfigError(
187
+ f"Invalid value '{value}' for parameter '{param.name} expected one of: {allowed_values}."
193
188
  )
194
189
  kwargs[param.name] = value
195
190
  continue
@@ -222,21 +217,26 @@ def config(key : Union[str, None] = None):
222
217
  values = config.getValue(param.name, param.default)
223
218
  kwargs[param.name] = values
224
219
  else:
225
- raise ConfigError()
220
+ raise ConfigError("Config: The config only supports types : config(Object), int, str, bool, float, list[int], list[str], list[bool], list[float], dict[str, Object]")
226
221
  elif str(annotation).startswith("dict"):
227
222
  if annotation.__args__[0] == str:
228
223
  values = config.getValue(param.name, param.default)
229
224
  if values is not None and annotation.__args__[1] not in [int, str, bool, float]:
230
- kwargs[param.name] = {value : annotation.__args__[1](config = filename, DL_args = key_tmp+"."+param.name+"."+value) for value in values}
225
+ try:
226
+ kwargs[param.name] = {value : annotation.__args__[1](config = filename, DL_args = key_tmp+"."+param.name+"."+value) for value in values}
227
+ except ValueError as e:
228
+ raise ValueError(e)
229
+ except Exception as e:
230
+ raise ConfigError("{} {}".format(values, e))
231
231
  else:
232
232
  kwargs[param.name] = values
233
233
  else:
234
- raise ConfigError()
234
+ raise ConfigError("Config: The config only supports types : config(Object), int, str, bool, float, list[int], list[str], list[bool], list[float], dict[str, Object]")
235
235
  else:
236
236
  try:
237
237
  kwargs[param.name] = annotation(config = filename, DL_args = key_tmp)
238
238
  except Exception as e:
239
- raise ValueError("[Config] Failed to instantiate {} with type {}, error {} ".format(param.name, annotation.__name__, e))
239
+ raise ConfigError("Failed to instantiate {} with type {}, error {} ".format(param.name, annotation.__name__, e))
240
240
 
241
241
  if os.environ['KONFAI_CONFIG_VARIABLE'] == "True":
242
242
  os.environ['KONFAI_CONFIG_VARIABLE'] = "False"
konfai/utils/dataset.py CHANGED
@@ -348,6 +348,10 @@ class Dataset():
348
348
  @abstractmethod
349
349
  def getNames(self, group: str) -> list[str]:
350
350
  pass
351
+
352
+ @abstractmethod
353
+ def getGroup(self) -> list[str]:
354
+ pass
351
355
 
352
356
  @abstractmethod
353
357
  def isExist(self, group: str, name: Union[str, None] = None) -> bool:
@@ -458,6 +462,9 @@ class Dataset():
458
462
  names.extend(self.getNames("/".join(groups.split("/")[1:]), h5_group[group]))
459
463
  return names
460
464
 
465
+ def getGroup(self):
466
+ return self.h5.keys()
467
+
461
468
  def _getDataset(self, groups: str, name: str, h5_group: h5py.Group = None) -> h5py.Dataset:
462
469
  if h5_group is None:
463
470
  h5_group = self.h5
@@ -613,7 +620,10 @@ class Dataset():
613
620
 
614
621
  def getNames(self, group: str) -> list[str]:
615
622
  raise NotImplementedError()
616
-
623
+
624
+ def getGroup(self):
625
+ raise NotImplementedError()
626
+
617
627
  def getInfos(self, group: str, name: str) -> tuple[list[int], Attribute]:
618
628
  attributes = Attribute()
619
629
  if os.path.exists("{}{}{}.{}".format(self.filename, group if group is not None else "", name, self.format)):
@@ -736,7 +746,7 @@ class Dataset():
736
746
  subDirectories.extend(self._getSubDirectories("/".join(groups.split("/")[1:]), subDirectory))
737
747
  return subDirectories
738
748
 
739
- def getNames(self, groups: str, index: Union[list[int], None] = None, subDirectory: str = "") -> list[str]:
749
+ def getNames(self, groups: str, index: Union[list[int], None] = None) -> list[str]:
740
750
  names = []
741
751
  if self.is_directory:
742
752
  for subDirectory in self._getSubDirectories(groups):
@@ -752,6 +762,21 @@ class Dataset():
752
762
  names = file.getNames(groups)
753
763
  return [name for i, name in enumerate(sorted(names)) if index is None or i in index]
754
764
 
765
+ def getGroup(self):
766
+ if self.is_directory:
767
+ groups = set()
768
+ for root, _, files in os.walk(self.filename):
769
+ for file in files:
770
+ path = os.path.relpath(os.path.join(root, file.split(".")[0]), self.filename)
771
+ parts = path.split("/")
772
+ if len(parts) >= 2:
773
+ del parts[-2]
774
+ groups.add("/".join(parts))
775
+ else:
776
+ with Dataset.File(self.filename, True, self.format) as file:
777
+ groups = file.getGroup()
778
+ return list(groups)
779
+
755
780
  def getInfos(self, groups: str, name: str) -> tuple[list[int], Attribute]:
756
781
  if self.is_directory:
757
782
  for subDirectory in self._getSubDirectories(groups):