konfai 1.0.8__py3-none-any.whl → 1.1.0__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of konfai might be problematic; see the registry's advisory page for more details.

konfai/network/network.py CHANGED
@@ -13,11 +13,11 @@ from torch.utils.checkpoint import checkpoint
13
13
  from typing import Union
14
14
  from enum import Enum
15
15
 
16
- from konfai import DEEP_LEARNING_API_ROOT
16
+ from konfai import KONFAI_ROOT
17
17
  from konfai.metric.schedulers import Scheduler
18
18
  from konfai.utils.config import config
19
19
  from konfai.utils.utils import State, _getModule, getDevice, getGPUMemory
20
- from konfai.data.HDF5 import Accumulator, ModelPatch
20
+ from konfai.data.patching import Accumulator, ModelPatch
21
21
 
22
22
  class NetState(Enum):
23
23
  TRAIN = 0,
@@ -40,7 +40,7 @@ class OptimizerLoader():
40
40
 
41
41
  def getOptimizer(self, key: str, parameter: Iterator[torch.nn.parameter.Parameter]) -> torch.optim.Optimizer:
42
42
  torch.optim.AdamW
43
- return config("{}.Model.{}.Optimizer".format(DEEP_LEARNING_API_ROOT(), key))(getattr(importlib.import_module('torch.optim'), self.name))(parameter, config = None)
43
+ return config("{}.Model.{}.Optimizer".format(KONFAI_ROOT(), key))(getattr(importlib.import_module('torch.optim'), self.name))(parameter, config = None)
44
44
 
45
45
  class SchedulerStep():
46
46
 
@@ -98,8 +98,8 @@ class CriterionsLoader():
98
98
  for module_classpath, criterionsAttr in self.criterionsLoader.items():
99
99
  module, name = _getModule(module_classpath, "metric.measure")
100
100
  criterionsAttr.isTorchCriterion = module.startswith("torch")
101
- criterionsAttr.sheduler = criterionsAttr.l.getShedulers("{}.Model.{}.outputsCriterions.{}.targetsCriterions.{}.criterionsLoader.{}".format(DEEP_LEARNING_API_ROOT(), model_classname, output_group, target_group, module_classpath))
102
- criterions[config("{}.Model.{}.outputsCriterions.{}.targetsCriterions.{}.criterionsLoader.{}".format(DEEP_LEARNING_API_ROOT(), model_classname, output_group, target_group, module_classpath))(getattr(importlib.import_module(module), name))(config = None)] = criterionsAttr
101
+ criterionsAttr.sheduler = criterionsAttr.l.getShedulers("{}.Model.{}.outputsCriterions.{}.targetsCriterions.{}.criterionsLoader.{}".format(KONFAI_ROOT(), model_classname, output_group, target_group, module_classpath))
102
+ criterions[config("{}.Model.{}.outputsCriterions.{}.targetsCriterions.{}.criterionsLoader.{}".format(KONFAI_ROOT(), model_classname, output_group, target_group, module_classpath))(getattr(importlib.import_module(module), name))(config = None)] = criterionsAttr
103
103
  return criterions
104
104
 
105
105
  class TargetCriterionsLoader():
@@ -753,14 +753,14 @@ class Network(ModuleArgsDict, ABC):
753
753
  output_layer_accumulator : dict[str, Accumulator] = {}
754
754
  output_layer_patch_indexed : dict[str, Patch_Indexed] = {}
755
755
  it = 0
756
- debug = "DL_API_DEBUG" in os.environ
756
+ debug = "KONFAI_DEBUG" in os.environ
757
757
  for (nameTmp, output_layer) in self.named_forward(*inputs):
758
758
  name = nameTmp.replace(";accu;", "")
759
759
  if debug:
760
- if "DL_API_DEBUG_LAST_LAYER" in os.environ:
761
- os.environ["DL_API_DEBUG_LAST_LAYER"] = "{}|{}:{}:{}".format(os.environ["DL_API_DEBUG_LAST_LAYER"], name, getGPUMemory(output_layer.device), str(output_layer.device).replace("cuda:", ""))
760
+ if "KONFAI_DEBUG_LAST_LAYER" in os.environ:
761
+ os.environ["KONFAI_DEBUG_LAST_LAYER"] = "{}|{}:{}:{}".format(os.environ["KONFAI_DEBUG_LAST_LAYER"], name, getGPUMemory(output_layer.device), str(output_layer.device).replace("cuda:", ""))
762
762
  else:
763
- os.environ["DL_API_DEBUG_LAST_LAYER"] = "{}:{}:{}".format(name, getGPUMemory(output_layer.device), str(output_layer.device).replace("cuda:", ""))
763
+ os.environ["KONFAI_DEBUG_LAST_LAYER"] = "{}:{}:{}".format(name, getGPUMemory(output_layer.device), str(output_layer.device).replace("cuda:", ""))
764
764
  it += 1
765
765
  if name in layers_name or nameTmp in layers_name:
766
766
  if ";accu;" in nameTmp:
@@ -918,12 +918,12 @@ class Network(ModuleArgsDict, ABC):
918
918
  class ModelLoader():
919
919
 
920
920
  @config("Model")
921
- def __init__(self, classpath : str = "default:segmentation.UNet") -> None:
922
- self.module, self.name = _getModule(classpath.split(".")[-1] if len(classpath.split(".")) > 1 else classpath, ".".join(classpath.split(".")[:-1]) if len(classpath.split(".")) > 1 else "")
921
+ def __init__(self, classpath : str = "default:segmentation.UNet.UNet") -> None:
922
+ self.module, self.name = _getModule(classpath, "models")
923
923
 
924
924
  def getModel(self, train : bool = True, DL_args: Union[str, None] = None, DL_without=["optimizer", "schedulers", "nb_batch_per_step", "init_type", "init_gain"]) -> Network:
925
925
  if not DL_args:
926
- DL_args="{}.Model".format(DEEP_LEARNING_API_ROOT())
926
+ DL_args="{}.Model".format(KONFAI_ROOT())
927
927
  model = partial(getattr(importlib.import_module(self.module), self.name), config = None, DL_args=DL_args)
928
928
  if not train:
929
929
  model = partial(model, DL_without = DL_without)
konfai/predictor.py CHANGED
@@ -6,12 +6,12 @@ import torch
6
6
  import tqdm
7
7
  import os
8
8
 
9
- from konfai import MODELS_DIRECTORY, PREDICTIONS_DIRECTORY, CONFIG_FILE, MODEL, DEEP_LEARNING_API_ROOT
9
+ from konfai import MODELS_DIRECTORY, PREDICTIONS_DIRECTORY, CONFIG_FILE, MODEL, KONFAI_ROOT
10
10
  from konfai.utils.config import config
11
11
  from konfai.utils.utils import State, get_patch_slices_from_nb_patch_per_dim, NeedDevice, _getModule, DistributedObject, DataLog, description
12
12
  from konfai.utils.dataset import Dataset, Attribute
13
- from konfai.data.dataset import DataPrediction, DatasetIter
14
- from konfai.data.HDF5 import Accumulator, PathCombine
13
+ from konfai.data.data_manager import DataPrediction, DatasetIter
14
+ from konfai.data.patching import Accumulator, PathCombine
15
15
  from konfai.network.network import ModelLoader, Network, NetState, CPU_Model
16
16
  from konfai.data.transform import Transform, TransformLoader
17
17
 
@@ -52,13 +52,13 @@ class OutDataset(Dataset, NeedDevice, ABC):
52
52
 
53
53
  if _transform_type is not None:
54
54
  for classpath, transform in _transform_type.items():
55
- transform = transform.getTransform(classpath, DL_args = "{}.outsDataset.{}.OutDataset.{}".format(DEEP_LEARNING_API_ROOT(), name_layer, name))
55
+ transform = transform.getTransform(classpath, DL_args = "{}.outsDataset.{}.OutDataset.{}".format(KONFAI_ROOT(), name_layer, name))
56
56
  transform.setDatasets(datasets)
57
57
  transform_type.append(transform)
58
58
 
59
59
  if self._patchCombine is not None:
60
- module, name = _getModule(self._patchCombine, "data.HDF5")
61
- self.patchCombine = getattr(importlib.import_module(module), name)(config = None, DL_args = "{}.outsDataset.{}.OutDataset".format(DEEP_LEARNING_API_ROOT(), name_layer))
60
+ module, name = _getModule(self._patchCombine, "data.patching")
61
+ self.patchCombine = getattr(importlib.import_module(module), name)(config = None, DL_args = "{}.outsDataset.{}.OutDataset".format(KONFAI_ROOT(), name_layer))
62
62
 
63
63
  def setPatchConfig(self, patchSize: Union[list[int], None], overlap: Union[int, None], nb_data_augmentation: int) -> None:
64
64
  if patchSize is not None and overlap is not None:
@@ -94,7 +94,7 @@ class OutDataset(Dataset, NeedDevice, ABC):
94
94
  class OutSameAsGroupDataset(OutDataset):
95
95
 
96
96
  @config("OutDataset")
97
- def __init__(self, dataset_filename: str = "Dataset:h5", group: str = "default", sameAsGroup: str = "default", pre_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, post_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, final_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, patchCombine: Union[str, None] = None, redution: str = "mean", inverse_transform: bool = True) -> None:
97
+ def __init__(self, dataset_filename: str = "./Dataset:mha", group: str = "default", sameAsGroup: str = "default", pre_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, post_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, final_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, patchCombine: Union[str, None] = None, redution: str = "mean", inverse_transform: bool = True) -> None:
98
98
  super().__init__(dataset_filename, group, pre_transforms, post_transforms, final_transforms, patchCombine)
99
99
  self.group_src, self.group_dest = sameAsGroup.split(":")
100
100
  self.redution = redution
@@ -240,7 +240,7 @@ class _Predictor():
240
240
  self.modelComposite.module.setState(NetState.PREDICTION)
241
241
  desc = lambda : "Prediction : {}".format(description(self.modelComposite))
242
242
  self.dataloader_prediction.dataset.load()
243
- with tqdm.tqdm(iterable = enumerate(self.dataloader_prediction), leave=False, desc = desc(), total=len(self.dataloader_prediction), disable=self.global_rank != 0 and "DL_API_CLUSTER" not in os.environ) as batch_iter:
243
+ with tqdm.tqdm(iterable = enumerate(self.dataloader_prediction), leave=False, desc = desc(), total=len(self.dataloader_prediction), disable=self.global_rank != 0 and "KONFAI_CLUSTER" not in os.environ) as batch_iter:
244
244
  dist.barrier()
245
245
  for it, data_dict in batch_iter:
246
246
  input = self.getInput(data_dict)
@@ -322,7 +322,7 @@ class Predictor(DistributedObject):
322
322
  gpu_checkpoints: Union[list[str], None] = None,
323
323
  outsDataset: Union[dict[str, OutDatasetLoader], None] = {"default:Default" : OutDatasetLoader()},
324
324
  images_log: list[str] = []) -> None:
325
- if os.environ["DEEP_LEANING_API_CONFIG_MODE"] != "Done":
325
+ if os.environ["KONFAI_CONFIG_MODE"] != "Done":
326
326
  exit(0)
327
327
  super().__init__(train_name)
328
328
  self.manual_seed = manual_seed
@@ -374,7 +374,7 @@ class Predictor(DistributedObject):
374
374
  for dataset_filename in self.datasets_filename:
375
375
  path = self.predict_path +dataset_filename
376
376
  if os.path.exists(path):
377
- if os.environ["DL_API_OVERWRITE"] != "True":
377
+ if os.environ["KONFAI_OVERWRITE"] != "True":
378
378
  accept = builtins.input("The prediction {} already exists ! Do you want to overwrite it (yes,no) : ".format(path))
379
379
  if accept != "yes":
380
380
  return
konfai/trainer.py CHANGED
@@ -1,5 +1,4 @@
1
1
  import torch
2
- import torch.optim.adamw
3
2
  from torch.utils.data import DataLoader
4
3
  import tqdm
5
4
  import numpy as np
@@ -12,16 +11,79 @@ from torch.utils.tensorboard.writer import SummaryWriter
12
11
  from torch.optim.swa_utils import AveragedModel
13
12
  import torch.distributed as dist
14
13
 
15
- from konfai import MODELS_DIRECTORY, CHECKPOINTS_DIRECTORY, STATISTICS_DIRECTORY, SETUPS_DIRECTORY, CONFIG_FILE, MODEL, DATE, DL_API_STATE
16
- from konfai.data.dataset import DataTrain
14
+ from konfai import MODELS_DIRECTORY, CHECKPOINTS_DIRECTORY, STATISTICS_DIRECTORY, SETUPS_DIRECTORY, CONFIG_FILE, MODEL, DATE, KONFAI_STATE
15
+ from konfai.data.data_manager import DataTrain
17
16
  from konfai.utils.config import config
18
17
  from konfai.utils.utils import State, DataLog, DistributedObject, description
19
18
  from konfai.network.network import Network, ModelLoader, NetState, CPU_Model
20
19
 
21
20
 
21
class EarlyStoppingBase:
    """No-op early-stopping policy used when no ``EarlyStopping`` is configured.

    It never requests a stop, and its score is the plain sum of every logged
    value.
    """

    def __init__(self):
        pass

    def isStopped(self) -> bool:
        """Return whether training should stop (always ``False`` here)."""
        return False

    def getScore(self, values: Union[dict[str, float], None]) -> Union[float, None]:
        """Aggregate logged metric values into a single score.

        Args:
            values: Mapping of metric name to value, as produced by the
                trainer's validation log. ``None`` is accepted because that log
                returns ``None`` on every process other than global rank 0.

        Returns:
            The sum of all values (``0`` for an empty mapping), or ``None`` when
            ``values`` is ``None``.
        """
        if values is None:
            return None
        return sum(values.values())

    def __call__(self, current_score: Union[float, None]) -> bool:
        """Record a score and report whether to stop (always ``False`` here)."""
        return False
34
+
35
class EarlyStopping(EarlyStoppingBase):
    """Early-stopping policy driven by a monitored validation score.

    Configured under the ``EarlyStopping`` config key. Each call with a new
    score compares it to the best score seen so far. A change of strictly more
    than ``min_delta`` in the direction given by ``mode`` resets the patience
    counter; anything else increments it. Once the counter reaches
    ``patience``, the policy latches into the stopped state.

    NOTE(review): the trainer's validation log only yields values on global
    rank 0 (other ranks receive ``None``), so only rank 0 can decide to stop.
    Under multi-process training the other ranks keep iterating. Confirm the
    stop decision is synchronized across ranks.
    NOTE(review): ``checkpoint_save`` in ``BEST`` mode keeps the lowest score,
    which disagrees with ``mode="max"``. Confirm the intended behavior.
    """

    @config("EarlyStopping")
    def __init__(self, monitor: Union[list[str], None] = [], patience=10, min_delta=0.0, mode="min"):
        """Create the policy.

        Args:
            monitor: Names of logged metrics whose values are summed into the
                score. ``None`` or an empty list means "sum every value". The
                list default is kept for config compatibility. It is copied,
                never mutated or aliased.
            patience: Consecutive non-improving calls tolerated before stopping.
            min_delta: Change that must be exceeded to count as an improvement.
            mode: ``"min"`` when lower scores are better, ``"max"`` otherwise.

        Raises:
            ValueError: If ``mode`` is neither ``"min"`` nor ``"max"``.
        """
        super().__init__()
        # Fail at construction instead of on the second scored call.
        if mode not in ("min", "max"):
            raise ValueError("Mode must be 'min' or 'max'.")
        self.monitor = [] if monitor is None else list(monitor)
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.counter = 0
        self.best_score = None
        self.early_stop = False

    def isStopped(self) -> bool:
        """Return ``True`` once patience has been exhausted."""
        return self.early_stop

    def getScore(self, values: Union[dict[str, float], None]) -> Union[float, None]:
        """Sum the monitored metric values into a single score.

        Args:
            values: Mapping of metric name to value, or ``None`` on processes
                that did not compute the validation log.

        Returns:
            The score, or ``None`` when ``values`` is ``None``.

        Raises:
            ValueError: If a name listed in ``monitor`` is absent from ``values``.
        """
        if values is None:
            return None
        if len(self.monitor) == 0:
            return super().getScore(values)
        for v in self.monitor:
            if v not in values.keys():
                raise ValueError(
                    "[EarlyStopping] Metric '{}' specified in `monitor` not found in logged values. "
                    "Available keys: {}. Please check your configuration.".format(v, list(values.keys())))
        return sum(value for key, value in values.items() if key in self.monitor)

    def __call__(self, current_score: Union[float, None]) -> bool:
        """Record ``current_score`` and return whether training should stop.

        A ``None`` score (no metrics on this process) leaves the state untouched.
        """
        if current_score is None:
            return self.early_stop

        if self.best_score is None:
            self.best_score = current_score
            return False

        if self.mode == "min":
            improvement = self.best_score - current_score
        elif self.mode == "max":
            improvement = current_score - self.best_score
        else:
            raise ValueError("Mode must be 'min' or 'max'.")

        if improvement > self.min_delta:
            self.best_score = current_score
            self.counter = 0
        else:
            self.counter += 1

        if self.counter >= self.patience:
            self.early_stop = True

        return self.early_stop
83
+
22
84
  class _Trainer():
23
85
 
24
- def __init__(self, world_size: int, global_rank: int, local_rank: int, size: int, train_name: str, data_log: Union[list[str], None] , save_checkpoint_mode: str, epochs: int, epoch: int, autocast: bool, it_validation: Union[int, None], it: int, model: Union[DDP, CPU_Model], modelEMA: AveragedModel, dataloader_training: DataLoader, dataloader_validation: Union[DataLoader, None] = None) -> None:
86
+ def __init__(self, world_size: int, global_rank: int, local_rank: int, size: int, train_name: str, early_stopping: EarlyStopping, data_log: Union[list[str], None] , save_checkpoint_mode: str, epochs: int, epoch: int, autocast: bool, it_validation: Union[int, None], it: int, model: Union[DDP, CPU_Model], modelEMA: AveragedModel, dataloader_training: DataLoader, dataloader_validation: Union[DataLoader, None] = None) -> None:
25
87
  self.world_size = world_size
26
88
  self.global_rank = global_rank
27
89
  self.local_rank = local_rank
@@ -36,7 +98,8 @@ class _Trainer():
36
98
  self.dataloader_validation = dataloader_validation
37
99
  self.autocast = autocast
38
100
  self.modelEMA = modelEMA
39
-
101
+ self.early_stopping = EarlyStoppingBase() if early_stopping is None else early_stopping
102
+
40
103
  self.it_validation = it_validation
41
104
  if self.it_validation is None:
42
105
  self.it_validation = len(dataloader_training)
@@ -59,6 +122,8 @@ class _Trainer():
59
122
  for self.epoch in epoch_tqdm:
60
123
  self.dataloader_training.dataset.load()
61
124
  self.train()
125
+ if self.early_stopping.isStopped():
126
+ break
62
127
  self.dataloader_training.dataset.resetAugmentation()
63
128
 
64
129
  def getInput(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int, str, bool]]) -> dict[tuple[str, bool], torch.Tensor]:
@@ -72,7 +137,7 @@ class _Trainer():
72
137
  self.modelEMA.module.setState(NetState.TRAIN)
73
138
 
74
139
  desc = lambda : "Training : {}".format(description(self.model, self.modelEMA))
75
- with tqdm.tqdm(iterable = enumerate(self.dataloader_training), desc = desc(), total=len(self.dataloader_training), leave=False, disable=self.global_rank != 0 and "DL_API_CLUSTER" not in os.environ) as batch_iter:
140
+ with tqdm.tqdm(iterable = enumerate(self.dataloader_training), desc = desc(), total=len(self.dataloader_training), leave=False, disable=self.global_rank != 0 and "KONFAI_CLUSTER" not in os.environ) as batch_iter:
76
141
  for _, data_dict in batch_iter:
77
142
  with torch.amp.autocast('cuda', enabled=self.autocast):
78
143
  input = self.getInput(data_dict)
@@ -88,7 +153,11 @@ class _Trainer():
88
153
  if self.dataloader_validation is not None:
89
154
  loss = self._validate()
90
155
  self.model.module.update_lr()
91
- self.checkpoint_save(loss)
156
+ score = self.early_stopping.getScore(loss)
157
+ self.checkpoint_save(score)
158
+ if self.early_stopping(score):
159
+ break
160
+
92
161
 
93
162
  batch_iter.set_description(desc())
94
163
 
@@ -103,7 +172,7 @@ class _Trainer():
103
172
  desc = lambda : "Validation : {}".format(description(self.model, self.modelEMA))
104
173
  data_dict = None
105
174
  self.dataloader_validation.dataset.load()
106
- with tqdm.tqdm(iterable = enumerate(self.dataloader_validation), desc = desc(), total=len(self.dataloader_validation), leave=False, disable=self.global_rank != 0 and "DL_API_CLUSTER" not in os.environ) as batch_iter:
175
+ with tqdm.tqdm(iterable = enumerate(self.dataloader_validation), desc = desc(), total=len(self.dataloader_validation), leave=False, disable=self.global_rank != 0 and "KONFAI_CLUSTER" not in os.environ) as batch_iter:
107
176
  for _, data_dict in batch_iter:
108
177
  input = self.getInput(data_dict)
109
178
  self.model(input)
@@ -120,41 +189,53 @@ class _Trainer():
120
189
  return self._validation_log(data_dict)
121
190
 
122
191
  def checkpoint_save(self, loss: float) -> None:
123
- if self.global_rank == 0:
124
- path = CHECKPOINTS_DIRECTORY()+self.train_name+"/"
125
- last_loss = None
126
- if os.path.exists(path) and os.listdir(path):
127
- name = sorted(os.listdir(path))[-1]
128
- state_dict = torch.load(path+name, weights_only=False)
129
- last_loss = state_dict["loss"]
130
- if self.save_checkpoint_mode == "BEST":
131
- if last_loss >= loss:
132
- os.remove(path+name)
133
-
134
- if self.save_checkpoint_mode != "BEST" or (last_loss is None or last_loss >= loss):
135
- name = DATE()+".pt"
136
- if not os.path.exists(path):
137
- os.makedirs(path)
138
-
139
- save_dict = {
140
- "epoch": self.epoch,
141
- "it": self.it,
142
- "loss": loss,
143
- "Model": self.model.module.state_dict()}
192
+ if self.global_rank != 0:
193
+ return
144
194
 
145
- if self.modelEMA is not None:
146
- save_dict.update({"Model_EMA" : self.modelEMA.module.state_dict()})
195
+ path = CHECKPOINTS_DIRECTORY()+self.train_name+"/"
196
+ os.makedirs(path, exist_ok=True)
197
+
198
+ name = DATE() + ".pt"
199
+ save_path = os.path.join(path, name)
147
200
 
148
- save_dict.update({'{}_optimizer_state_dict'.format(name): network.optimizer.state_dict() for name, network in self.model.module.getNetworks().items() if network.optimizer is not None})
149
- torch.save(save_dict, path+name)
201
+ save_dict = {
202
+ "epoch": self.epoch,
203
+ "it": self.it,
204
+ "loss": loss,
205
+ "Model": self.model.module.state_dict()
206
+ }
207
+
208
+ if self.modelEMA is not None:
209
+ save_dict["Model_EMA"] = self.modelEMA.module.state_dict()
210
+
211
+ save_dict.update({'{}_optimizer_state_dict'.format(name): network.optimizer.state_dict() for name, network in self.model.module.getNetworks().items() if network.optimizer is not None})
212
+ torch.save(save_dict, save_path)
213
+
214
+ if self.save_checkpoint_mode == "BEST":
215
+ all_checkpoints = sorted([
216
+ os.path.join(path, f)
217
+ for f in os.listdir(path) if f.endswith(".pt")
218
+ ])
219
+ best_ckpt = None
220
+ best_loss = float('inf')
221
+
222
+ for f in all_checkpoints:
223
+ d = torch.load(f, weights_only=False)
224
+ if d.get("loss", float("inf")) < best_loss:
225
+ best_loss = d["loss"]
226
+ best_ckpt = f
227
+
228
+ for f in all_checkpoints:
229
+ if f != best_ckpt and f != save_path:
230
+ os.remove(f)
150
231
 
151
232
  @torch.no_grad()
152
- def _log(self, type_log: str, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]) -> float:
233
+ def _log(self, type_log: str, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]) -> dict[str, float]:
153
234
  models: dict[str, Network] = {"" : self.model.module}
154
235
  if self.modelEMA is not None:
155
236
  models["_EMA"] = self.modelEMA.module
156
237
 
157
- measures = DistributedObject.getMeasure(self.world_size, self.global_rank, self.local_rank*self.size+self.size-1, models, self.it_validation if type_log == "Trainning" else len(self.dataloader_validation))
238
+ measures = DistributedObject.getMeasure(self.world_size, self.global_rank, self.local_rank*self.size+self.size-1, models, self.it_validation if type_log == "Training" else len(self.dataloader_validation))
158
239
 
159
240
  if self.global_rank == 0:
160
241
  images_log = []
@@ -178,34 +259,36 @@ class _Trainer():
178
259
  for name, layer, _ in model.get_layers([v.to(0) for k, v in self.getInput(data_dict).items() if k[1]], images_log):
179
260
  self.data_log[name][0](self.tb, "{}/{}{}".format(type_log, name, label), layer[:self.data_log[name][1]].detach().cpu().numpy(), self.it)
180
261
 
181
- if type_log == "Trainning":
262
+ if type_log == "Training":
182
263
  for name, network in self.model.module.getNetworks().items():
183
264
  if network.optimizer is not None:
184
265
  self.tb.add_scalar("{}/{}/Learning Rate".format(type_log, name), network.optimizer.param_groups[0]['lr'], self.it)
185
266
 
186
267
  if self.global_rank == 0:
187
- loss = []
268
+ loss = {}
188
269
  for name, network in self.model.module.getNetworks().items():
189
270
  if network.measure is not None:
190
- loss.append(sum([v[1] for v in measures["{}".format(name)][0].values()]))
191
- return np.mean(loss)
271
+ loss.update({k : v[1] for k, v in measures["{}{}".format(name, label)][0].items()})
272
+ loss.update({k : v[1] for k, v in measures["{}{}".format(name, label)][1].items()})
273
+ return loss
192
274
  return None
193
275
 
194
276
  @torch.no_grad()
195
- def _train_log(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]) -> float:
196
- return self._log("Trainning", data_dict)
277
+ def _train_log(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]) -> dict[str, float]:
278
+ return self._log("Training", data_dict)
197
279
 
198
280
  @torch.no_grad()
199
- def _validation_log(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]) -> float:
281
+ def _validation_log(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]) -> dict[str, float]:
200
282
  return self._log("Validation", data_dict)
201
283
 
284
+
202
285
  class Trainer(DistributedObject):
203
286
 
204
287
  @config("Trainer")
205
288
  def __init__( self,
206
289
  model : ModelLoader = ModelLoader(),
207
290
  dataset : DataTrain = DataTrain(),
208
- train_name : str = "default:name",
291
+ train_name : str = "default:TRAIN_01",
209
292
  manual_seed : Union[int, None] = None,
210
293
  epochs: int = 100,
211
294
  it_validation : Union[int, None] = None,
@@ -214,8 +297,9 @@ class Trainer(DistributedObject):
214
297
  gpu_checkpoints: Union[list[str], None] = None,
215
298
  ema_decay : float = 0,
216
299
  data_log: Union[list[str], None] = None,
300
+ early_stopping: Union[EarlyStopping, None] = None,
217
301
  save_checkpoint_mode: str= "BEST") -> None:
218
- if os.environ["DEEP_LEANING_API_CONFIG_MODE"] != "Done":
302
+ if os.environ["KONFAI_CONFIG_MODE"] != "Done":
219
303
  exit(0)
220
304
  super().__init__(train_name)
221
305
  self.manual_seed = manual_seed
@@ -223,6 +307,7 @@ class Trainer(DistributedObject):
223
307
  self.autocast = autocast
224
308
  self.epochs = epochs
225
309
  self.epoch = 0
310
+ self.early_stopping = early_stopping
226
311
  self.it = 0
227
312
  self.it_validation = it_validation
228
313
  self.model = model.getModel(train=True)
@@ -292,9 +377,9 @@ class Trainer(DistributedObject):
292
377
  return (1-self.ema_decay) * averaged_model_parameter + self.ema_decay * model_parameter
293
378
 
294
379
  def setup(self, world_size: int):
295
- state = State._member_map_[DL_API_STATE()]
296
- if state != State.RESUME and os.path.exists(STATISTICS_DIRECTORY()+self.name+"/"):
297
- if os.environ["DL_API_OVERWRITE"] != "True":
380
+ state = State._member_map_[KONFAI_STATE()]
381
+ if state != State.RESUME and os.path.exists(CHECKPOINTS_DIRECTORY()+self.name+"/"):
382
+ if os.environ["KONFAI_OVERWRITE"] != "True":
298
383
  accept = input("The model {} already exists ! Do you want to overwrite it (yes,no) : ".format(self.name))
299
384
  if accept != "yes":
300
385
  return
@@ -326,5 +411,5 @@ class Trainer(DistributedObject):
326
411
  model = DDP(model, static_graph=True) if torch.cuda.is_available() else CPU_Model(model)
327
412
  if self.modelEMA is not None:
328
413
  self.modelEMA.module = Network.to(self.modelEMA.module, local_rank)
329
- with _Trainer(world_size, global_rank, local_rank, self.size, self.name, self.data_log, self.save_checkpoint_mode, self.epochs, self.epoch, self.autocast, self.it_validation, self.it, model, self.modelEMA, *dataloaders) as t:
414
+ with _Trainer(world_size, global_rank, local_rank, self.size, self.name, self.early_stopping, self.data_log, self.save_checkpoint_mode, self.epochs, self.epoch, self.autocast, self.it_validation, self.it, model, self.modelEMA, *dataloaders) as t:
330
415
  t.run()
konfai/utils/config.py CHANGED
@@ -3,9 +3,8 @@ import ruamel.yaml
3
3
  import inspect
4
4
  import collections
5
5
  from copy import deepcopy
6
- from typing import Union
6
+ from typing import Union, Literal, get_origin, get_args
7
7
  import torch
8
-
9
8
  from konfai import CONFIG_FILE
10
9
 
11
10
  yaml = ruamel.yaml.YAML()
@@ -26,7 +25,7 @@ class Config():
26
25
  if not os.path.exists(self.filename):
27
26
  result = input("Create a new config file ? [no,yes,interactive] : ")
28
27
  if result in ["yes", "interactive"]:
29
- os.environ["DEEP_LEANING_API_CONFIG_MODE"] = "interactive" if result == "interactive" else "default"
28
+ os.environ["KONFAI_CONFIG_MODE"] = "interactive" if result == "interactive" else "default"
30
29
  else:
31
30
  exit(0)
32
31
  with open(self.filename, "w") as f:
@@ -69,7 +68,7 @@ class Config():
69
68
 
70
69
  def __exit__(self, type, value, traceback) -> None:
71
70
  self.yml.close()
72
- if os.environ["DEEP_LEANING_API_CONFIG_MODE"] == "remove":
71
+ if os.environ["KONFAI_CONFIG_MODE"] == "remove":
73
72
  if os.path.exists(CONFIG_FILE()):
74
73
  os.remove(CONFIG_FILE())
75
74
  return
@@ -87,22 +86,22 @@ class Config():
87
86
  except:
88
87
  result = input("\nKeep a default configuration file ? (yes,no) : ")
89
88
  if result == "yes":
90
- os.environ["DEEP_LEANING_API_CONFIG_MODE"] = "default"
89
+ os.environ["KONFAI_CONFIG_MODE"] = "default"
91
90
  else:
92
- os.environ["DEEP_LEANING_API_CONFIG_MODE"] = "remove"
91
+ os.environ["KONFAI_CONFIG_MODE"] = "remove"
93
92
  exit(0)
94
93
  return default.split(":")[1] if len(default.split(":")) > 1 else default
95
94
 
96
95
  @staticmethod
97
96
  def _getInputDefault(name : str, default : Union[str, None], isList : bool = False) -> Union[list[Union[str, None]], str, None]:
98
97
  if isinstance(default, str) and (default == "default" or (len(default.split(":")) > 1 and default.split(":")[0] == "default")):
99
- if os.environ["DEEP_LEANING_API_CONFIG_MODE"] == "interactive":
98
+ if os.environ["KONFAI_CONFIG_MODE"] == "interactive":
100
99
  if isList:
101
100
  list_tmp = []
102
101
  key_tmp = "OK"
103
- while key_tmp != "!" and os.environ["DEEP_LEANING_API_CONFIG_MODE"] == "interactive":
102
+ while (key_tmp != "!" and key_tmp != " ") and os.environ["KONFAI_CONFIG_MODE"] == "interactive":
104
103
  key_tmp = Config._getInput(name, default)
105
- if key_tmp != "!":
104
+ if (key_tmp != "!" and key_tmp != " "):
106
105
  if key_tmp == "":
107
106
  key_tmp = default.split(":")[1] if len(default.split(":")) > 1 else default
108
107
  list_tmp.append(key_tmp)
@@ -134,6 +133,7 @@ class Config():
134
133
  list_tmp = []
135
134
  for key in value_config:
136
135
  list_tmp.extend(Config._getInputDefault(name, key, isList=True))
136
+
137
137
 
138
138
  value = list_tmp
139
139
  value_config = list_tmp
@@ -155,7 +155,7 @@ class Config():
155
155
  dict_value[key] = value_tmp
156
156
  value = dict_value
157
157
  if isinstance(self.config, str):
158
- os.environ['DEEP_LEARNING_API_CONFIG_VARIABLE'] = "True"
158
+ os.environ['KONFAI_CONFIG_VARIABLE'] = "True"
159
159
  return None
160
160
 
161
161
  self.config[name] = value_config if value_config is not None else "None"
@@ -169,17 +169,30 @@ def config(key : Union[str, None] = None):
169
169
  if "config" in kwargs:
170
170
  filename = kwargs["config"]
171
171
  if filename == None:
172
- filename = os.environ['DEEP_LEARNING_API_CONFIG_FILE']
172
+ filename = os.environ['KONFAI_CONFIG_FILE']
173
173
  else:
174
- os.environ['DEEP_LEARNING_API_CONFIG_FILE'] = filename
174
+ os.environ['KONFAI_CONFIG_FILE'] = filename
175
175
  key_tmp = kwargs["DL_args"]+("."+key if key is not None else "") if "DL_args" in kwargs else key
176
176
  without = kwargs["DL_without"] if "DL_without" in kwargs else []
177
- os.environ['DEEP_LEARNING_API_CONFIG_PATH'] = key_tmp
177
+ os.environ['KONFAI_CONFIG_PATH'] = key_tmp
178
178
  with Config(filename, key_tmp) as config:
179
- os.environ['DEEP_LEARNING_API_CONFIG_VARIABLE'] = "False"
179
+ os.environ['KONFAI_CONFIG_VARIABLE'] = "False"
180
180
  kwargs = {}
181
181
  for param in list(inspect.signature(function).parameters.values())[len(args):]:
182
+
182
183
  annotation = param.annotation
184
+ # --- support Literal ---
185
+ if get_origin(annotation) is Literal:
186
+ allowed_values = get_args(annotation)
187
+ default_value = param.default if param.default != inspect._empty else allowed_values[0]
188
+ value = config.getValue(param.name, f"default:{default_value}")
189
+ if value not in allowed_values:
190
+ raise ValueError(
191
+ f"[Config] Invalid value '{value}' for parameter '{param.name}'. "
192
+ f"Expected one of: {allowed_values}."
193
+ )
194
+ kwargs[param.name] = value
195
+ continue
183
196
  if str(annotation).startswith("typing.Union") or str(annotation).startswith("typing.Optional"):
184
197
  for i in annotation.__args__:
185
198
  annotation = i
@@ -188,8 +201,24 @@ def config(key : Union[str, None] = None):
188
201
  continue
189
202
  if not annotation == inspect._empty:
190
203
  if annotation not in [int, str, bool, float, torch.Tensor]:
191
- if str(annotation).startswith("list") or str(annotation).startswith("tuple") or str(annotation).startswith("typing.Tuple"):
192
- if annotation.__args__[0] in [int, str, bool, float]:
204
+ if str(annotation).startswith("list") or str(annotation).startswith("tuple") or str(annotation).startswith("typing.Tuple") or str(annotation).startswith("typing.List"):
205
+ elem_type = annotation.__args__[0]
206
+ values = config.getValue(param.name, param.default)
207
+ if getattr(elem_type, '__origin__', None) is Union:
208
+ valid_types = elem_type.__args__
209
+ result = []
210
+ for v in values:
211
+ for t in valid_types:
212
+ try:
213
+ if t == torch.Tensor and not isinstance(v, torch.Tensor):
214
+ v = torch.tensor(v)
215
+ result.append(t(v) if t != torch.Tensor else v)
216
+ break
217
+ except Exception:
218
+ continue
219
+ kwargs[param.name] = result
220
+
221
+ elif annotation.__args__[0] in [int, str, bool, float]:
193
222
  values = config.getValue(param.name, param.default)
194
223
  kwargs[param.name] = values
195
224
  else:
@@ -204,9 +233,13 @@ def config(key : Union[str, None] = None):
204
233
  else:
205
234
  raise ConfigError()
206
235
  else:
207
- kwargs[param.name] = annotation(config = filename, DL_args = key_tmp)
208
- if os.environ['DEEP_LEARNING_API_CONFIG_VARIABLE'] == "True":
209
- os.environ['DEEP_LEARNING_API_CONFIG_VARIABLE'] = "False"
236
+ try:
237
+ kwargs[param.name] = annotation(config = filename, DL_args = key_tmp)
238
+ except Exception as e:
239
+ raise ValueError("[Config] Failed to instantiate {} with type {}, error {} ".format(param.name, annotation.__name__, e))
240
+
241
+ if os.environ['KONFAI_CONFIG_VARIABLE'] == "True":
242
+ os.environ['KONFAI_CONFIG_VARIABLE'] = "False"
210
243
  kwargs[param.name] = None
211
244
  else:
212
245
  kwargs[param.name] = config.getValue(param.name, param.default)
konfai/utils/dataset.py CHANGED
@@ -750,7 +750,7 @@ class Dataset():
750
750
  else:
751
751
  with Dataset.File(self.filename, True, self.format) as file:
752
752
  names = file.getNames(groups)
753
- return [name for i, name in enumerate(names) if index is None or i in index]
753
+ return [name for i, name in enumerate(sorted(names)) if index is None or i in index]
754
754
 
755
755
  def getInfos(self, groups: str, name: str) -> tuple[list[int], Attribute]:
756
756
  if self.is_directory: