konfai 1.0.9__tar.gz → 1.1.0__tar.gz

This diff shows the contents of publicly available package versions released to one of the supported registries. It is provided for informational purposes only and reflects the changes between the package versions as they appear in their respective public registries.

Potentially problematic release.


This version of konfai might be problematic. Click here for more details.

Files changed (45)
  1. {konfai-1.0.9 → konfai-1.1.0}/PKG-INFO +1 -1
  2. {konfai-1.0.9 → konfai-1.1.0}/konfai/data/data_manager.py +19 -19
  3. {konfai-1.0.9 → konfai-1.1.0}/konfai/data/patching.py +1 -1
  4. {konfai-1.0.9 → konfai-1.1.0}/konfai/metric/measure.py +1 -1
  5. {konfai-1.0.9 → konfai-1.1.0}/konfai/trainer.py +124 -39
  6. {konfai-1.0.9 → konfai-1.1.0}/konfai/utils/config.py +1 -1
  7. {konfai-1.0.9 → konfai-1.1.0}/konfai.egg-info/PKG-INFO +1 -1
  8. {konfai-1.0.9 → konfai-1.1.0}/konfai.egg-info/SOURCES.txt +2 -1
  9. {konfai-1.0.9 → konfai-1.1.0}/pyproject.toml +5 -2
  10. konfai-1.1.0/tests/test_config.py +18 -0
  11. {konfai-1.0.9 → konfai-1.1.0}/LICENSE +0 -0
  12. {konfai-1.0.9 → konfai-1.1.0}/README.md +0 -0
  13. {konfai-1.0.9 → konfai-1.1.0}/konfai/__init__.py +0 -0
  14. {konfai-1.0.9 → konfai-1.1.0}/konfai/data/__init__.py +0 -0
  15. {konfai-1.0.9 → konfai-1.1.0}/konfai/data/augmentation.py +0 -0
  16. {konfai-1.0.9 → konfai-1.1.0}/konfai/data/transform.py +0 -0
  17. {konfai-1.0.9 → konfai-1.1.0}/konfai/evaluator.py +0 -0
  18. {konfai-1.0.9 → konfai-1.1.0}/konfai/main.py +0 -0
  19. {konfai-1.0.9 → konfai-1.1.0}/konfai/metric/__init__.py +0 -0
  20. {konfai-1.0.9 → konfai-1.1.0}/konfai/metric/schedulers.py +0 -0
  21. {konfai-1.0.9 → konfai-1.1.0}/konfai/models/classification/convNeXt.py +0 -0
  22. {konfai-1.0.9 → konfai-1.1.0}/konfai/models/classification/resnet.py +0 -0
  23. {konfai-1.0.9 → konfai-1.1.0}/konfai/models/generation/cStyleGan.py +0 -0
  24. {konfai-1.0.9 → konfai-1.1.0}/konfai/models/generation/ddpm.py +0 -0
  25. {konfai-1.0.9 → konfai-1.1.0}/konfai/models/generation/diffusionGan.py +0 -0
  26. {konfai-1.0.9 → konfai-1.1.0}/konfai/models/generation/gan.py +0 -0
  27. {konfai-1.0.9 → konfai-1.1.0}/konfai/models/generation/vae.py +0 -0
  28. {konfai-1.0.9 → konfai-1.1.0}/konfai/models/registration/registration.py +0 -0
  29. {konfai-1.0.9 → konfai-1.1.0}/konfai/models/representation/representation.py +0 -0
  30. {konfai-1.0.9 → konfai-1.1.0}/konfai/models/segmentation/NestedUNet.py +0 -0
  31. {konfai-1.0.9 → konfai-1.1.0}/konfai/models/segmentation/UNet.py +0 -0
  32. {konfai-1.0.9 → konfai-1.1.0}/konfai/network/__init__.py +0 -0
  33. {konfai-1.0.9 → konfai-1.1.0}/konfai/network/blocks.py +0 -0
  34. {konfai-1.0.9 → konfai-1.1.0}/konfai/network/network.py +0 -0
  35. {konfai-1.0.9 → konfai-1.1.0}/konfai/predictor.py +0 -0
  36. {konfai-1.0.9 → konfai-1.1.0}/konfai/utils/ITK.py +0 -0
  37. {konfai-1.0.9 → konfai-1.1.0}/konfai/utils/__init__.py +0 -0
  38. {konfai-1.0.9 → konfai-1.1.0}/konfai/utils/dataset.py +0 -0
  39. {konfai-1.0.9 → konfai-1.1.0}/konfai/utils/registration.py +0 -0
  40. {konfai-1.0.9 → konfai-1.1.0}/konfai/utils/utils.py +0 -0
  41. {konfai-1.0.9 → konfai-1.1.0}/konfai.egg-info/dependency_links.txt +0 -0
  42. {konfai-1.0.9 → konfai-1.1.0}/konfai.egg-info/entry_points.txt +0 -0
  43. {konfai-1.0.9 → konfai-1.1.0}/konfai.egg-info/requires.txt +0 -0
  44. {konfai-1.0.9 → konfai-1.1.0}/konfai.egg-info/top_level.txt +0 -0
  45. {konfai-1.0.9 → konfai-1.1.0}/setup.cfg +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: konfai
3
- Version: 1.0.9
3
+ Version: 1.1.0
4
4
  Summary: Modular and configurable Deep Learning framework with YAML and PyTorch
5
5
  Author-email: Valentin Boussot <boussot.v@gmail.com>
6
6
  License-Expression: Apache-2.0
@@ -257,14 +257,14 @@ class Data(ABC):
257
257
  subset : Union[Subset, dict[str, Subset]],
258
258
  num_workers : int,
259
259
  batch_size : int,
260
- train_size: Union[float, str, list[int], list[str]] = 1,
260
+ validation: Union[float, str, list[int], list[str]] = 1,
261
261
  inlineAugmentations: bool = False,
262
262
  dataAugmentationsList: dict[str, DataAugmentationsList]= {}) -> None:
263
263
  self.dataset_filenames = dataset_filenames
264
264
  self.subset = subset
265
265
  self.groups_src = groups_src
266
266
  self.patch = patch
267
- self.train_size = train_size
267
+ self.validation = validation
268
268
  self.dataAugmentationsList = dataAugmentationsList
269
269
  self.batch_size = batch_size
270
270
  self.dataSet_args = dict(groups_src=self.groups_src, inlineAugmentations=inlineAugmentations, dataAugmentationsList = list(self.dataAugmentationsList.values()), use_cache = use_cache, buffer_size=batch_size+1, patch_size=self.patch.patch_size if self.patch is not None else None, overlap=self.patch.overlap if self.patch is not None else None)
@@ -372,17 +372,17 @@ class Data(ABC):
372
372
 
373
373
  train_map = map
374
374
  validate_map = []
375
- if isinstance(self.train_size, float):
376
- if self.train_size < 1.0 and int(math.floor(len(map)*(1-self.train_size))) > 0:
377
- train_map, validate_map = map[:int(math.floor(len(map)*self.train_size))], map[int(math.floor(len(map)*self.train_size)):]
378
- elif isinstance(self.train_size, str):
379
- if ":" in self.train_size:
375
+ if isinstance(self.validation, float):
376
+ if self.validation < 1.0 and int(math.floor(len(map)*(1-self.validation))) > 0:
377
+ train_map, validate_map = map[:int(math.floor(len(map)*self.validation))], map[int(math.floor(len(map)*self.validation)):]
378
+ elif isinstance(self.validation, str):
379
+ if ":" in self.validation:
380
380
  index = list(range(int(self.subset.split(":")[0]), int(self.subset.split(":")[1])))
381
381
  train_map = [m for m in map if m[0] not in index]
382
382
  validate_map = [m for m in map if m[0] in index]
383
- elif os.path.exists(self.train_size):
383
+ elif os.path.exists(self.validation):
384
384
  validation_names = []
385
- with open(self.train_size, "r") as f:
385
+ with open(self.validation, "r") as f:
386
386
  for name in f:
387
387
  validation_names.append(name.strip())
388
388
  index = [i for i, n in enumerate(subset_names) if n in validation_names]
@@ -390,13 +390,13 @@ class Data(ABC):
390
390
  validate_map = [m for m in map if m[0] in index]
391
391
  else:
392
392
  validate_map = train_map
393
- elif isinstance(self.train_size, list):
394
- if len(self.train_size) > 0:
395
- if isinstance(self.train_size[0], int):
396
- train_map = [m for m in map if m[0] not in self.train_size]
397
- validate_map = [m for m in map if m[0] in self.train_size]
398
- elif isinstance(self.train_size[0], str):
399
- index = [i for i, n in enumerate(subset_names) if n in self.train_size]
393
+ elif isinstance(self.validation, list):
394
+ if len(self.validation) > 0:
395
+ if isinstance(self.validation[0], int):
396
+ train_map = [m for m in map if m[0] not in self.validation]
397
+ validate_map = [m for m in map if m[0] in self.validation]
398
+ elif isinstance(self.validation[0], str):
399
+ index = [i for i, n in enumerate(subset_names) if n in self.validation]
400
400
  train_map = [m for m in map if m[0] not in index]
401
401
  validate_map = [m for m in map if m[0] in index]
402
402
 
@@ -436,8 +436,8 @@ class DataTrain(Data):
436
436
  subset : Union[TrainSubset, dict[str, TrainSubset]] = TrainSubset(),
437
437
  num_workers : int = 4,
438
438
  batch_size : int = 1,
439
- train_size : Union[float, str, list[int], list[str]] = 0.8) -> None:
440
- super().__init__(dataset_filenames, groups_src, patch, use_cache, subset, num_workers, batch_size, train_size, inlineAugmentations, augmentations if augmentations else {})
439
+ validation : Union[float, str, list[int], list[str]] = 0.8) -> None:
440
+ super().__init__(dataset_filenames, groups_src, patch, use_cache, subset, num_workers, batch_size, validation, inlineAugmentations, augmentations if augmentations else {})
441
441
 
442
442
  class DataPrediction(Data):
443
443
 
@@ -463,4 +463,4 @@ class DataMetric(Data):
463
463
  validation: Union[str, None] = None,
464
464
  num_workers : int = 4) -> None:
465
465
 
466
- super().__init__(dataset_filenames=dataset_filenames, groups_src=groups_src, patch=None, use_cache=False, subset=subset, num_workers=num_workers, batch_size=1, train_size=1 if validation is None else validation)
466
+ super().__init__(dataset_filenames=dataset_filenames, groups_src=groups_src, patch=None, use_cache=False, subset=subset, num_workers=num_workers, batch_size=1, validation=1 if validation is None else validation)
@@ -188,7 +188,7 @@ class Patch(ABC):
188
188
  pad_bottom = 0
189
189
  pad_top = 0
190
190
  if self._patch_slices[a][index][0].start-bottom < 0:
191
- pad_bottom = bottom-s.start
191
+ pad_bottom = bottom-self._patch_slices[a][index][0].start
192
192
  if self._patch_slices[a][index][0].stop+top > data.shape[len(slices_pre)]:
193
193
  pad_top = self._patch_slices[a][index][0].stop+top-data.shape[len(slices_pre)]
194
194
  data_sliced = F.pad(data_sliced, [0 for _ in range((len(slices)-1)*2)]+[pad_bottom, pad_top], 'reflect')
@@ -147,7 +147,7 @@ class Dice(Criterion):
147
147
  def forward(self, output: torch.Tensor, *targets : list[torch.Tensor]) -> torch.Tensor:
148
148
  target = targets[0]
149
149
  if output.shape[1] == 1:
150
- output = F.one_hot(output.type(torch.int64), num_classes=torch.max(output).item()+1).permute(0, len(target.shape), *[i+1 for i in range(len(target.shape)-1)]).float()
150
+ output = F.one_hot(output.type(torch.int64), num_classes=int(torch.max(output).item()+1)).permute(0, len(target.shape), *[i+1 for i in range(len(target.shape)-1)]).float()
151
151
  target = F.one_hot(target.type(torch.int64), num_classes=output.shape[1]).permute(0, len(target.shape), *[i+1 for i in range(len(target.shape)-1)]).float().squeeze(2)
152
152
  return 1-torch.mean(self.dice_per_channel(output, target))
153
153
 
@@ -1,5 +1,4 @@
1
1
  import torch
2
- import torch.optim.adamw
3
2
  from torch.utils.data import DataLoader
4
3
  import tqdm
5
4
  import numpy as np
@@ -19,9 +18,72 @@ from konfai.utils.utils import State, DataLog, DistributedObject, description
19
18
  from konfai.network.network import Network, ModelLoader, NetState, CPU_Model
20
19
 
21
20
 
21
+ class EarlyStoppingBase:
22
+
23
+ def __init__(self):
24
+ pass
25
+
26
+ def isStopped(self) -> bool:
27
+ return False
28
+
29
+ def getScore(self, values: dict[str, float]):
30
+ return sum([i for i in values.values()])
31
+
32
+ def __call__(self, current_score: float) -> bool:
33
+ return False
34
+
35
+ class EarlyStopping(EarlyStoppingBase):
36
+
37
+ @config("EarlyStopping")
38
+ def __init__(self, monitor: Union[list[str], None] = [], patience=10, min_delta=0.0, mode="min"):
39
+ super().__init__()
40
+ self.monitor = [] if monitor is None else monitor
41
+ self.patience = patience
42
+ self.min_delta = min_delta
43
+ self.mode = mode
44
+ self.counter = 0
45
+ self.best_score = None
46
+ self.early_stop = False
47
+
48
+ def isStopped(self) -> bool:
49
+ return self.early_stop
50
+
51
+ def getScore(self, values: dict[str, float]):
52
+ if len(self.monitor) == 0:
53
+ return super().getScore(values)
54
+ for v in self.monitor:
55
+ if v not in values.keys():
56
+ raise ValueError(
57
+ "[EarlyStopping] Metric '{}' specified in `monitor` not found in logged values. "
58
+ "Available keys: {}. Please check your configuration.".format(v, list(values.keys())))
59
+ return sum([i for v, i in values.items() if v in self.monitor])
60
+
61
+ def __call__(self, current_score: float) -> bool:
62
+ if self.best_score is None:
63
+ self.best_score = current_score
64
+ return False
65
+
66
+ if self.mode == "min":
67
+ improvement = self.best_score - current_score
68
+ elif self.mode == "max":
69
+ improvement = current_score - self.best_score
70
+ else:
71
+ raise ValueError("Mode must be 'min' or 'max'.")
72
+
73
+ if improvement > self.min_delta:
74
+ self.best_score = current_score
75
+ self.counter = 0
76
+ else:
77
+ self.counter += 1
78
+
79
+ if self.counter >= self.patience:
80
+ self.early_stop = True
81
+
82
+ return self.early_stop
83
+
22
84
  class _Trainer():
23
85
 
24
- def __init__(self, world_size: int, global_rank: int, local_rank: int, size: int, train_name: str, data_log: Union[list[str], None] , save_checkpoint_mode: str, epochs: int, epoch: int, autocast: bool, it_validation: Union[int, None], it: int, model: Union[DDP, CPU_Model], modelEMA: AveragedModel, dataloader_training: DataLoader, dataloader_validation: Union[DataLoader, None] = None) -> None:
86
+ def __init__(self, world_size: int, global_rank: int, local_rank: int, size: int, train_name: str, early_stopping: EarlyStopping, data_log: Union[list[str], None] , save_checkpoint_mode: str, epochs: int, epoch: int, autocast: bool, it_validation: Union[int, None], it: int, model: Union[DDP, CPU_Model], modelEMA: AveragedModel, dataloader_training: DataLoader, dataloader_validation: Union[DataLoader, None] = None) -> None:
25
87
  self.world_size = world_size
26
88
  self.global_rank = global_rank
27
89
  self.local_rank = local_rank
@@ -36,7 +98,8 @@ class _Trainer():
36
98
  self.dataloader_validation = dataloader_validation
37
99
  self.autocast = autocast
38
100
  self.modelEMA = modelEMA
39
-
101
+ self.early_stopping = EarlyStoppingBase() if early_stopping is None else early_stopping
102
+
40
103
  self.it_validation = it_validation
41
104
  if self.it_validation is None:
42
105
  self.it_validation = len(dataloader_training)
@@ -59,6 +122,8 @@ class _Trainer():
59
122
  for self.epoch in epoch_tqdm:
60
123
  self.dataloader_training.dataset.load()
61
124
  self.train()
125
+ if self.early_stopping.isStopped():
126
+ break
62
127
  self.dataloader_training.dataset.resetAugmentation()
63
128
 
64
129
  def getInput(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int, str, bool]]) -> dict[tuple[str, bool], torch.Tensor]:
@@ -88,7 +153,11 @@ class _Trainer():
88
153
  if self.dataloader_validation is not None:
89
154
  loss = self._validate()
90
155
  self.model.module.update_lr()
91
- self.checkpoint_save(loss)
156
+ score = self.early_stopping.getScore(loss)
157
+ self.checkpoint_save(score)
158
+ if self.early_stopping(score):
159
+ break
160
+
92
161
 
93
162
  batch_iter.set_description(desc())
94
163
 
@@ -120,41 +189,53 @@ class _Trainer():
120
189
  return self._validation_log(data_dict)
121
190
 
122
191
  def checkpoint_save(self, loss: float) -> None:
123
- if self.global_rank == 0:
124
- path = CHECKPOINTS_DIRECTORY()+self.train_name+"/"
125
- last_loss = None
126
- if os.path.exists(path) and os.listdir(path):
127
- name = sorted(os.listdir(path))[-1]
128
- state_dict = torch.load(path+name, weights_only=False)
129
- last_loss = state_dict["loss"]
130
- if self.save_checkpoint_mode == "BEST":
131
- if last_loss >= loss:
132
- os.remove(path+name)
133
-
134
- if self.save_checkpoint_mode != "BEST" or (last_loss is None or last_loss >= loss):
135
- name = DATE()+".pt"
136
- if not os.path.exists(path):
137
- os.makedirs(path)
138
-
139
- save_dict = {
140
- "epoch": self.epoch,
141
- "it": self.it,
142
- "loss": loss,
143
- "Model": self.model.module.state_dict()}
192
+ if self.global_rank != 0:
193
+ return
144
194
 
145
- if self.modelEMA is not None:
146
- save_dict.update({"Model_EMA" : self.modelEMA.module.state_dict()})
195
+ path = CHECKPOINTS_DIRECTORY()+self.train_name+"/"
196
+ os.makedirs(path, exist_ok=True)
197
+
198
+ name = DATE() + ".pt"
199
+ save_path = os.path.join(path, name)
147
200
 
148
- save_dict.update({'{}_optimizer_state_dict'.format(name): network.optimizer.state_dict() for name, network in self.model.module.getNetworks().items() if network.optimizer is not None})
149
- torch.save(save_dict, path+name)
201
+ save_dict = {
202
+ "epoch": self.epoch,
203
+ "it": self.it,
204
+ "loss": loss,
205
+ "Model": self.model.module.state_dict()
206
+ }
207
+
208
+ if self.modelEMA is not None:
209
+ save_dict["Model_EMA"] = self.modelEMA.module.state_dict()
210
+
211
+ save_dict.update({'{}_optimizer_state_dict'.format(name): network.optimizer.state_dict() for name, network in self.model.module.getNetworks().items() if network.optimizer is not None})
212
+ torch.save(save_dict, save_path)
213
+
214
+ if self.save_checkpoint_mode == "BEST":
215
+ all_checkpoints = sorted([
216
+ os.path.join(path, f)
217
+ for f in os.listdir(path) if f.endswith(".pt")
218
+ ])
219
+ best_ckpt = None
220
+ best_loss = float('inf')
221
+
222
+ for f in all_checkpoints:
223
+ d = torch.load(f, weights_only=False)
224
+ if d.get("loss", float("inf")) < best_loss:
225
+ best_loss = d["loss"]
226
+ best_ckpt = f
227
+
228
+ for f in all_checkpoints:
229
+ if f != best_ckpt and f != save_path:
230
+ os.remove(f)
150
231
 
151
232
  @torch.no_grad()
152
- def _log(self, type_log: str, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]) -> float:
233
+ def _log(self, type_log: str, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]) -> dict[str, float]:
153
234
  models: dict[str, Network] = {"" : self.model.module}
154
235
  if self.modelEMA is not None:
155
236
  models["_EMA"] = self.modelEMA.module
156
237
 
157
- measures = DistributedObject.getMeasure(self.world_size, self.global_rank, self.local_rank*self.size+self.size-1, models, self.it_validation if type_log == "Trainning" else len(self.dataloader_validation))
238
+ measures = DistributedObject.getMeasure(self.world_size, self.global_rank, self.local_rank*self.size+self.size-1, models, self.it_validation if type_log == "Training" else len(self.dataloader_validation))
158
239
 
159
240
  if self.global_rank == 0:
160
241
  images_log = []
@@ -178,27 +259,29 @@ class _Trainer():
178
259
  for name, layer, _ in model.get_layers([v.to(0) for k, v in self.getInput(data_dict).items() if k[1]], images_log):
179
260
  self.data_log[name][0](self.tb, "{}/{}{}".format(type_log, name, label), layer[:self.data_log[name][1]].detach().cpu().numpy(), self.it)
180
261
 
181
- if type_log == "Trainning":
262
+ if type_log == "Training":
182
263
  for name, network in self.model.module.getNetworks().items():
183
264
  if network.optimizer is not None:
184
265
  self.tb.add_scalar("{}/{}/Learning Rate".format(type_log, name), network.optimizer.param_groups[0]['lr'], self.it)
185
266
 
186
267
  if self.global_rank == 0:
187
- loss = []
268
+ loss = {}
188
269
  for name, network in self.model.module.getNetworks().items():
189
270
  if network.measure is not None:
190
- loss.append(sum([v[1] for v in measures["{}".format(name)][0].values()]))
191
- return np.mean(loss)
271
+ loss.update({k : v[1] for k, v in measures["{}{}".format(name, label)][0].items()})
272
+ loss.update({k : v[1] for k, v in measures["{}{}".format(name, label)][1].items()})
273
+ return loss
192
274
  return None
193
275
 
194
276
  @torch.no_grad()
195
- def _train_log(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]) -> float:
196
- return self._log("Trainning", data_dict)
277
+ def _train_log(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]) -> dict[str, float]:
278
+ return self._log("Training", data_dict)
197
279
 
198
280
  @torch.no_grad()
199
- def _validation_log(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]) -> float:
281
+ def _validation_log(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]) -> dict[str, float]:
200
282
  return self._log("Validation", data_dict)
201
283
 
284
+
202
285
  class Trainer(DistributedObject):
203
286
 
204
287
  @config("Trainer")
@@ -214,6 +297,7 @@ class Trainer(DistributedObject):
214
297
  gpu_checkpoints: Union[list[str], None] = None,
215
298
  ema_decay : float = 0,
216
299
  data_log: Union[list[str], None] = None,
300
+ early_stopping: Union[EarlyStopping, None] = None,
217
301
  save_checkpoint_mode: str= "BEST") -> None:
218
302
  if os.environ["KONFAI_CONFIG_MODE"] != "Done":
219
303
  exit(0)
@@ -223,6 +307,7 @@ class Trainer(DistributedObject):
223
307
  self.autocast = autocast
224
308
  self.epochs = epochs
225
309
  self.epoch = 0
310
+ self.early_stopping = early_stopping
226
311
  self.it = 0
227
312
  self.it_validation = it_validation
228
313
  self.model = model.getModel(train=True)
@@ -326,5 +411,5 @@ class Trainer(DistributedObject):
326
411
  model = DDP(model, static_graph=True) if torch.cuda.is_available() else CPU_Model(model)
327
412
  if self.modelEMA is not None:
328
413
  self.modelEMA.module = Network.to(self.modelEMA.module, local_rank)
329
- with _Trainer(world_size, global_rank, local_rank, self.size, self.name, self.data_log, self.save_checkpoint_mode, self.epochs, self.epoch, self.autocast, self.it_validation, self.it, model, self.modelEMA, *dataloaders) as t:
414
+ with _Trainer(world_size, global_rank, local_rank, self.size, self.name, self.early_stopping, self.data_log, self.save_checkpoint_mode, self.epochs, self.epoch, self.autocast, self.it_validation, self.it, model, self.modelEMA, *dataloaders) as t:
330
415
  t.run()
@@ -236,7 +236,7 @@ def config(key : Union[str, None] = None):
236
236
  try:
237
237
  kwargs[param.name] = annotation(config = filename, DL_args = key_tmp)
238
238
  except Exception as e:
239
- raise ValueError("[Config] Failed to instantiate {} with type {}".format(param.name, annotation.__name__))
239
+ raise ValueError("[Config] Failed to instantiate {} with type {}, error {} ".format(param.name, annotation.__name__, e))
240
240
 
241
241
  if os.environ['KONFAI_CONFIG_VARIABLE'] == "True":
242
242
  os.environ['KONFAI_CONFIG_VARIABLE'] = "False"
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: konfai
3
- Version: 1.0.9
3
+ Version: 1.1.0
4
4
  Summary: Modular and configurable Deep Learning framework with YAML and PyTorch
5
5
  Author-email: Valentin Boussot <boussot.v@gmail.com>
6
6
  License-Expression: Apache-2.0
@@ -39,4 +39,5 @@ konfai/utils/__init__.py
39
39
  konfai/utils/config.py
40
40
  konfai/utils/dataset.py
41
41
  konfai/utils/registration.py
42
- konfai/utils/utils.py
42
+ konfai/utils/utils.py
43
+ tests/test_config.py
@@ -4,7 +4,7 @@ build-backend = "setuptools.build_meta"
4
4
 
5
5
  [project]
6
6
  name = "konfai"
7
- version = "1.0.9"
7
+ version = "1.1.0"
8
8
  description = "Modular and configurable Deep Learning framework with YAML and PyTorch"
9
9
  readme = "README.md"
10
10
  requires-python = ">=3.8"
@@ -41,4 +41,7 @@ konfai-cluster = "konfai.main:cluster"
41
41
  vtk = ["vtk"]
42
42
  lpips = ["lpips"]
43
43
  cluster = ["submitit"]
44
- plot = ["matplotlib"]
44
+ plot = ["matplotlib"]
45
+
46
+ [tool.poetry.dev-dependencies]
47
+ pytest = "^8.0"
@@ -0,0 +1,18 @@
1
+ import pytest
2
+ from konfai.utils.config import config
3
+
4
+ class Dummy:
5
+
6
+ @config("Test")
7
+ def __init__(self, a: int = 1, b: str = "ok"):
8
+ self.a = a
9
+ self.b = b
10
+
11
+ def test_config_instantiation(monkeypatch):
12
+ import os
13
+ os.environ['KONFAI_CONFIG_FILE'] = "./tests/dummy_data/dummy_config.yml"
14
+ os.environ['KONFAI_CONFIG_PATH'] = "Test"
15
+
16
+ dummy = Dummy()
17
+ assert dummy.a == 42
18
+ assert dummy.b == "hello"
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes
File without changes