konfai-1.0.7-py3-none-any.whl → konfai-1.0.9-py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of konfai might be problematic; see the package's registry page for more details.

konfai/predictor.py CHANGED
@@ -6,12 +6,12 @@ import torch
6
6
  import tqdm
7
7
  import os
8
8
 
9
- from konfai import MODELS_DIRECTORY, PREDICTIONS_DIRECTORY, CONFIG_FILE, MODEL, DEEP_LEARNING_API_ROOT
9
+ from konfai import MODELS_DIRECTORY, PREDICTIONS_DIRECTORY, CONFIG_FILE, MODEL, KONFAI_ROOT
10
10
  from konfai.utils.config import config
11
11
  from konfai.utils.utils import State, get_patch_slices_from_nb_patch_per_dim, NeedDevice, _getModule, DistributedObject, DataLog, description
12
12
  from konfai.utils.dataset import Dataset, Attribute
13
- from konfai.data.dataset import DataPrediction, DatasetIter
14
- from konfai.data.HDF5 import Accumulator, PathCombine
13
+ from konfai.data.data_manager import DataPrediction, DatasetIter
14
+ from konfai.data.patching import Accumulator, PathCombine
15
15
  from konfai.network.network import ModelLoader, Network, NetState, CPU_Model
16
16
  from konfai.data.transform import Transform, TransformLoader
17
17
 
@@ -22,6 +22,8 @@ import torch.distributed as dist
22
22
  from torch.nn.parallel import DistributedDataParallel as DDP
23
23
  from torch.utils.data import DataLoader
24
24
  import importlib
25
+ import copy
26
+ from collections import defaultdict
25
27
 
26
28
  class OutDataset(Dataset, NeedDevice, ABC):
27
29
 
@@ -50,13 +52,13 @@ class OutDataset(Dataset, NeedDevice, ABC):
50
52
 
51
53
  if _transform_type is not None:
52
54
  for classpath, transform in _transform_type.items():
53
- transform = transform.getTransform(classpath, DL_args = "{}.outsDataset.{}.OutDataset.{}".format(DEEP_LEARNING_API_ROOT(), name_layer, name))
55
+ transform = transform.getTransform(classpath, DL_args = "{}.outsDataset.{}.OutDataset.{}".format(KONFAI_ROOT(), name_layer, name))
54
56
  transform.setDatasets(datasets)
55
57
  transform_type.append(transform)
56
58
 
57
59
  if self._patchCombine is not None:
58
- module, name = _getModule(self._patchCombine, "data.HDF5")
59
- self.patchCombine = getattr(importlib.import_module(module), name)(config = None, DL_args = "{}.outsDataset.{}.OutDataset".format(DEEP_LEARNING_API_ROOT(), name_layer))
60
+ module, name = _getModule(self._patchCombine, "data.patching")
61
+ self.patchCombine = getattr(importlib.import_module(module), name)(config = None, DL_args = "{}.outsDataset.{}.OutDataset".format(KONFAI_ROOT(), name_layer))
60
62
 
61
63
  def setPatchConfig(self, patchSize: Union[list[int], None], overlap: Union[int, None], nb_data_augmentation: int) -> None:
62
64
  if patchSize is not None and overlap is not None:
@@ -92,7 +94,7 @@ class OutDataset(Dataset, NeedDevice, ABC):
92
94
  class OutSameAsGroupDataset(OutDataset):
93
95
 
94
96
  @config("OutDataset")
95
- def __init__(self, dataset_filename: str = "Dataset:h5", group: str = "default", sameAsGroup: str = "default", pre_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, post_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, final_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, patchCombine: Union[str, None] = None, redution: str = "mean", inverse_transform: bool = True) -> None:
97
+ def __init__(self, dataset_filename: str = "./Dataset:mha", group: str = "default", sameAsGroup: str = "default", pre_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, post_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, final_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, patchCombine: Union[str, None] = None, redution: str = "mean", inverse_transform: bool = True) -> None:
96
98
  super().__init__(dataset_filename, group, pre_transforms, post_transforms, final_transforms, patchCombine)
97
99
  self.group_src, self.group_dest = sameAsGroup.split(":")
98
100
  self.redution = redution
@@ -152,7 +154,7 @@ class OutSameAsGroupDataset(OutDataset):
152
154
  if self.redution == "mean":
153
155
  result = torch.mean(result.float(), dim=0).to(dtype)
154
156
  elif self.redution == "median":
155
- result, _ = torch.median(result.float(), dim=0).to(dtype)
157
+ result, _ = torch.median(result.float(), dim=0)
156
158
  else:
157
159
  raise NameError("Reduction method does not exist (mean, median)")
158
160
  for transform in self.final_transforms:
@@ -199,19 +201,19 @@ class OutDatasetLoader():
199
201
 
200
202
  class _Predictor():
201
203
 
202
- def __init__(self, world_size: int, global_rank: int, local_rank: int, predict_path: str, data_log: Union[list[str], None], outsDataset: dict[str, OutDataset], model: DDP, dataloader_prediction: DataLoader) -> None:
204
+ def __init__(self, world_size: int, global_rank: int, local_rank: int, predict_path: str, data_log: Union[list[str], None], outsDataset: dict[str, OutDataset], modelComposite: DDP, dataloader_prediction: DataLoader) -> None:
203
205
  self.world_size = world_size
204
206
  self.global_rank = global_rank
205
207
  self.local_rank = local_rank
206
208
 
207
- self.model = model
209
+ self.modelComposite = modelComposite
208
210
  self.dataloader_prediction = dataloader_prediction
209
211
  self.outsDataset = outsDataset
210
212
 
211
213
 
212
214
  self.it = 0
213
215
 
214
- self.device = self.model.device
216
+ self.device = self.modelComposite.device
215
217
  self.dataset: DatasetIter = self.dataloader_prediction.dataset
216
218
  patch_size, overlap = self.dataset.getPatchConfig()
217
219
  for outDataset in self.outsDataset.values():
@@ -220,7 +222,7 @@ class _Predictor():
220
222
  if data_log is not None:
221
223
  for data in data_log:
222
224
  self.data_log[data.split("/")[0].replace(":", ".")] = (DataLog.__getitem__(data.split("/")[1]).value[0], int(data.split("/")[2]))
223
- self.tb = SummaryWriter(log_dir = predict_path+"Metric/") if len([network for network in self.model.module.getNetworks().values() if network.measure is not None]) or len(self.data_log) else None
225
+ self.tb = SummaryWriter(log_dir = predict_path+"Metric/") if len([network for network in self.modelComposite.module.getNetworks().values() if network.measure is not None]) or len(self.data_log) else None
224
226
 
225
227
  def __enter__(self):
226
228
  return self
@@ -234,15 +236,15 @@ class _Predictor():
234
236
 
235
237
  @torch.no_grad()
236
238
  def run(self):
237
- self.model.eval()
238
- self.model.module.setState(NetState.PREDICTION)
239
- desc = lambda : "Prediction : {}".format(description(self.model))
239
+ self.modelComposite.eval()
240
+ self.modelComposite.module.setState(NetState.PREDICTION)
241
+ desc = lambda : "Prediction : {}".format(description(self.modelComposite))
240
242
  self.dataloader_prediction.dataset.load()
241
- with tqdm.tqdm(iterable = enumerate(self.dataloader_prediction), leave=False, desc = desc(), total=len(self.dataloader_prediction), disable=self.global_rank != 0 and "DL_API_CLUSTER" not in os.environ) as batch_iter:
243
+ with tqdm.tqdm(iterable = enumerate(self.dataloader_prediction), leave=False, desc = desc(), total=len(self.dataloader_prediction), disable=self.global_rank != 0 and "KONFAI_CLUSTER" not in os.environ) as batch_iter:
242
244
  dist.barrier()
243
245
  for it, data_dict in batch_iter:
244
246
  input = self.getInput(data_dict)
245
- for name, output in self.model(input, list(self.outsDataset.keys())):
247
+ for name, output in self.modelComposite(input, list(self.outsDataset.keys())):
246
248
  self._predict_log(data_dict)
247
249
  outDataset = self.outsDataset[name]
248
250
  for i, (index, patch_augmentation, patch_index) in enumerate([(int(index), int(patch_augmentation), int(patch_index)) for index, patch_augmentation, patch_index in zip(list(data_dict.values())[0][1], list(data_dict.values())[0][2], list(data_dict.values())[0][3])]):
@@ -254,7 +256,7 @@ class _Predictor():
254
256
  self.it += 1
255
257
 
256
258
  def _predict_log(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]):
257
- measures = DistributedObject.getMeasure(self.world_size, self.global_rank, self.local_rank, {"" : self.model.module}, 1)
259
+ measures = DistributedObject.getMeasure(self.world_size, self.global_rank, self.local_rank, {"" : self.modelComposite.module}, 1)
258
260
 
259
261
  if self.global_rank == 0:
260
262
  images_log = []
@@ -265,30 +267,67 @@ class _Predictor():
265
267
  else:
266
268
  images_log.append(name.replace(":", "."))
267
269
 
268
- for name, network in self.model.module.getNetworks().items():
270
+ for name, network in self.modelComposite.module.getNetworks().items():
269
271
  if network.measure is not None:
270
272
  self.tb.add_scalars("Prediction/{}/Loss".format(name), {k : v[1] for k, v in measures["{}{}".format(name, "")][0].items()}, self.it)
271
273
  self.tb.add_scalars("Prediction/{}/Metric".format(name), {k : v[1] for k, v in measures["{}{}".format(name, "")][1].items()}, self.it)
272
274
  if len(images_log):
273
275
  for name, layer, _ in self.model.module.get_layers([v.to(0) for k, v in self.getInput(data_dict).items() if k[1]], images_log):
274
276
  self.data_log[name][0](self.tb, "Prediction/{}".format(name), layer[:self.data_log[name][1]].detach().cpu().numpy(), self.it)
277
+
278
+ class ModelComposite(Network):
279
+
280
+ def __init__(self, model: Network, nb_models: int, method: str):
281
+ super().__init__(model.in_channels, model.optimizer, model.schedulers, model.outputsCriterionsLoader, model.patch, model.nb_batch_per_step, model.init_type, model.init_gain, model.dim)
282
+ self.method = method
283
+ for i in range(nb_models):
284
+ self.add_module("Model_{}".format(i), copy.deepcopy(model), in_branch=[0], out_branch=["output_{}".format(i)])
285
+
286
+ def load(self, state_dicts : list[dict[str, dict[str, torch.Tensor]]]):
287
+ for i, state_dict in enumerate(state_dicts):
288
+ self["Model_{}".format(i)].load(state_dict, init=False)
289
+ self["Model_{}".format(i)].setName("{}_{}".format(self["Model_{}".format(i)].getName(), i))
290
+
291
+ def forward(self, data_dict: dict[tuple[str, bool], torch.Tensor], output_layers: list[str] = []) -> list[tuple[str, torch.Tensor]]:
292
+ result = {}
293
+ for name, module in self.items():
294
+ result[name] = module(data_dict, output_layers)
275
295
 
296
+ aggregated = defaultdict(list)
297
+ for module_outputs in result.values():
298
+ for key, tensor in module_outputs:
299
+ aggregated[key].append(tensor)
300
+
301
+ final_outputs = []
302
+ for key, tensors in aggregated.items():
303
+ stacked = torch.stack(tensors, dim=0)
304
+ if self.method == 'mean':
305
+ agg = torch.mean(stacked, dim=0)
306
+ elif self.method == 'median':
307
+ agg = torch.median(stacked, dim=0).values
308
+ final_outputs.append((key, agg))
309
+
310
+ return final_outputs
311
+
312
+
276
313
  class Predictor(DistributedObject):
277
314
 
278
315
  @config("Predictor")
279
316
  def __init__(self,
280
317
  model: ModelLoader = ModelLoader(),
281
318
  dataset: DataPrediction = DataPrediction(),
319
+ combine: str = "mean",
282
320
  train_name: str = "name",
283
321
  manual_seed : Union[int, None] = None,
284
322
  gpu_checkpoints: Union[list[str], None] = None,
285
323
  outsDataset: Union[dict[str, OutDatasetLoader], None] = {"default:Default" : OutDatasetLoader()},
286
324
  images_log: list[str] = []) -> None:
287
- if os.environ["DEEP_LEANING_API_CONFIG_MODE"] != "Done":
325
+ if os.environ["KONFAI_CONFIG_MODE"] != "Done":
288
326
  exit(0)
289
327
  super().__init__(train_name)
290
328
  self.manual_seed = manual_seed
291
329
  self.dataset = dataset
330
+ self.combine = combine
292
331
 
293
332
  self.model = model.getModel(train=False)
294
333
  self.it = 0
@@ -305,34 +344,37 @@ class Predictor(DistributedObject):
305
344
 
306
345
  self.gpu_checkpoints = gpu_checkpoints
307
346
 
308
- def _load(self) -> dict[str, dict[str, torch.Tensor]]:
309
- if MODEL().startswith("https://"):
310
- try:
311
- state_dict = {MODEL().split(":")[1]: torch.hub.load_state_dict_from_url(url=MODEL().split(":")[0], map_location="cpu", check_hash=True)}
312
- except:
313
- raise Exception("Model : {} does not exist !".format(MODEL()))
314
- else:
315
- if MODEL() != "":
316
- path = ""
317
- name = MODEL()
347
+ def _load(self) -> list[dict[str, dict[str, torch.Tensor]]]:
348
+ model_paths = MODEL().split(":")
349
+ state_dicts = []
350
+ for model_path in model_paths:
351
+ if model_path.startswith("https://"):
352
+ try:
353
+ state_dicts.append(torch.hub.load_state_dict_from_url(url=model_path, map_location="cpu", check_hash=True))
354
+ except:
355
+ raise Exception("Model : {} does not exist !".format(model_path))
318
356
  else:
319
- if self.name.endswith(".pt"):
320
- path = MODELS_DIRECTORY()+"/".join(self.name.split("/")[:-1])+"/StateDict/"
321
- name = self.name.split("/")[-1]
357
+ if model_path != "":
358
+ path = ""
359
+ name = model_path
322
360
  else:
323
- path = MODELS_DIRECTORY()+self.name+"/StateDict/"
324
- name = sorted(os.listdir(path))[-1]
325
- if os.path.exists(path+name):
326
- state_dict = torch.load(path+name, weights_only=False)
327
- else:
328
- raise Exception("Model : {} does not exist !".format(path+name))
329
- return state_dict
361
+ if self.name.endswith(".pt"):
362
+ path = MODELS_DIRECTORY()+"/".join(self.name.split("/")[:-1])+"/StateDict/"
363
+ name = self.name.split("/")[-1]
364
+ else:
365
+ path = MODELS_DIRECTORY()+self.name+"/StateDict/"
366
+ name = sorted(os.listdir(path))[-1]
367
+ if os.path.exists(path+name):
368
+ state_dicts.append(torch.load(path+name, weights_only=False))
369
+ else:
370
+ raise Exception("Model : {} does not exist !".format(path+name))
371
+ return state_dicts
330
372
 
331
373
  def setup(self, world_size: int):
332
374
  for dataset_filename in self.datasets_filename:
333
375
  path = self.predict_path +dataset_filename
334
376
  if os.path.exists(path):
335
- if os.environ["DL_API_OVERWRITE"] != "True":
377
+ if os.environ["KONFAI_OVERWRITE"] != "True":
336
378
  accept = builtins.input("The prediction {} already exists ! Do you want to overwrite it (yes,no) : ".format(path))
337
379
  if accept != "yes":
338
380
  return
@@ -345,11 +387,11 @@ class Predictor(DistributedObject):
345
387
  self.model.init(autocast=False, state = State.PREDICTION)
346
388
  self.model.init_outputsGroup()
347
389
  self.model._compute_channels_trace(self.model, self.model.in_channels, None, self.gpu_checkpoints)
348
- self.model.load(self._load(), init=False)
349
-
350
- if len(list(self.outsDataset.keys())) == 0 and len([network for network in self.model.getNetworks().values() if network.measure is not None]) == 0:
351
- exit(0)
390
+ self.modelComposite = ModelComposite(self.model, len(MODEL().split(":")), self.combine)
391
+ self.modelComposite.load(self._load())
352
392
 
393
+ if len(list(self.outsDataset.keys())) == 0 and len([network for network in self.modelComposite.getNetworks().values() if network.measure is not None]) == 0:
394
+ exit(0)
353
395
 
354
396
  self.size = (len(self.gpu_checkpoints)+1 if self.gpu_checkpoints else 1)
355
397
  self.dataloader = self.dataset.getData(world_size//self.size)
@@ -358,9 +400,9 @@ class Predictor(DistributedObject):
358
400
 
359
401
 
360
402
  def run_process(self, world_size: int, global_rank: int, local_rank: int, dataloaders: list[DataLoader]):
361
- model = Network.to(self.model, local_rank*self.size)
362
- model = DDP(model, static_graph=True) if torch.cuda.is_available() else CPU_Model(model)
363
- with _Predictor(world_size, global_rank, local_rank, self.predict_path, self.images_log, self.outsDataset, model, *dataloaders) as p:
403
+ modelComposite = Network.to(self.modelComposite, local_rank*self.size)
404
+ modelComposite = DDP(modelComposite, static_graph=True) if torch.cuda.is_available() else CPU_Model(modelComposite)
405
+ with _Predictor(world_size, global_rank, local_rank, self.predict_path, self.images_log, self.outsDataset, modelComposite, *dataloaders) as p:
364
406
  p.run()
365
407
 
366
408
 
konfai/trainer.py CHANGED
@@ -12,8 +12,8 @@ from torch.utils.tensorboard.writer import SummaryWriter
12
12
  from torch.optim.swa_utils import AveragedModel
13
13
  import torch.distributed as dist
14
14
 
15
- from konfai import MODELS_DIRECTORY, CHECKPOINTS_DIRECTORY, STATISTICS_DIRECTORY, SETUPS_DIRECTORY, CONFIG_FILE, MODEL, DATE, DL_API_STATE
16
- from konfai.data.dataset import DataTrain
15
+ from konfai import MODELS_DIRECTORY, CHECKPOINTS_DIRECTORY, STATISTICS_DIRECTORY, SETUPS_DIRECTORY, CONFIG_FILE, MODEL, DATE, KONFAI_STATE
16
+ from konfai.data.data_manager import DataTrain
17
17
  from konfai.utils.config import config
18
18
  from konfai.utils.utils import State, DataLog, DistributedObject, description
19
19
  from konfai.network.network import Network, ModelLoader, NetState, CPU_Model
@@ -72,7 +72,7 @@ class _Trainer():
72
72
  self.modelEMA.module.setState(NetState.TRAIN)
73
73
 
74
74
  desc = lambda : "Training : {}".format(description(self.model, self.modelEMA))
75
- with tqdm.tqdm(iterable = enumerate(self.dataloader_training), desc = desc(), total=len(self.dataloader_training), leave=False, disable=self.global_rank != 0 and "DL_API_CLUSTER" not in os.environ) as batch_iter:
75
+ with tqdm.tqdm(iterable = enumerate(self.dataloader_training), desc = desc(), total=len(self.dataloader_training), leave=False, disable=self.global_rank != 0 and "KONFAI_CLUSTER" not in os.environ) as batch_iter:
76
76
  for _, data_dict in batch_iter:
77
77
  with torch.amp.autocast('cuda', enabled=self.autocast):
78
78
  input = self.getInput(data_dict)
@@ -103,7 +103,7 @@ class _Trainer():
103
103
  desc = lambda : "Validation : {}".format(description(self.model, self.modelEMA))
104
104
  data_dict = None
105
105
  self.dataloader_validation.dataset.load()
106
- with tqdm.tqdm(iterable = enumerate(self.dataloader_validation), desc = desc(), total=len(self.dataloader_validation), leave=False, disable=self.global_rank != 0 and "DL_API_CLUSTER" not in os.environ) as batch_iter:
106
+ with tqdm.tqdm(iterable = enumerate(self.dataloader_validation), desc = desc(), total=len(self.dataloader_validation), leave=False, disable=self.global_rank != 0 and "KONFAI_CLUSTER" not in os.environ) as batch_iter:
107
107
  for _, data_dict in batch_iter:
108
108
  input = self.getInput(data_dict)
109
109
  self.model(input)
@@ -205,7 +205,7 @@ class Trainer(DistributedObject):
205
205
  def __init__( self,
206
206
  model : ModelLoader = ModelLoader(),
207
207
  dataset : DataTrain = DataTrain(),
208
- train_name : str = "default:name",
208
+ train_name : str = "default:TRAIN_01",
209
209
  manual_seed : Union[int, None] = None,
210
210
  epochs: int = 100,
211
211
  it_validation : Union[int, None] = None,
@@ -215,7 +215,7 @@ class Trainer(DistributedObject):
215
215
  ema_decay : float = 0,
216
216
  data_log: Union[list[str], None] = None,
217
217
  save_checkpoint_mode: str= "BEST") -> None:
218
- if os.environ["DEEP_LEANING_API_CONFIG_MODE"] != "Done":
218
+ if os.environ["KONFAI_CONFIG_MODE"] != "Done":
219
219
  exit(0)
220
220
  super().__init__(train_name)
221
221
  self.manual_seed = manual_seed
@@ -292,9 +292,9 @@ class Trainer(DistributedObject):
292
292
  return (1-self.ema_decay) * averaged_model_parameter + self.ema_decay * model_parameter
293
293
 
294
294
  def setup(self, world_size: int):
295
- state = State._member_map_[DL_API_STATE()]
296
- if state != State.RESUME and os.path.exists(STATISTICS_DIRECTORY()+self.name+"/"):
297
- if os.environ["DL_API_OVERWRITE"] != "True":
295
+ state = State._member_map_[KONFAI_STATE()]
296
+ if state != State.RESUME and os.path.exists(CHECKPOINTS_DIRECTORY()+self.name+"/"):
297
+ if os.environ["KONFAI_OVERWRITE"] != "True":
298
298
  accept = input("The model {} already exists ! Do you want to overwrite it (yes,no) : ".format(self.name))
299
299
  if accept != "yes":
300
300
  return
konfai/utils/config.py CHANGED
@@ -3,9 +3,8 @@ import ruamel.yaml
3
3
  import inspect
4
4
  import collections
5
5
  from copy import deepcopy
6
- from typing import Union
6
+ from typing import Union, Literal, get_origin, get_args
7
7
  import torch
8
-
9
8
  from konfai import CONFIG_FILE
10
9
 
11
10
  yaml = ruamel.yaml.YAML()
@@ -26,7 +25,7 @@ class Config():
26
25
  if not os.path.exists(self.filename):
27
26
  result = input("Create a new config file ? [no,yes,interactive] : ")
28
27
  if result in ["yes", "interactive"]:
29
- os.environ["DEEP_LEANING_API_CONFIG_MODE"] = "interactive" if result == "interactive" else "default"
28
+ os.environ["KONFAI_CONFIG_MODE"] = "interactive" if result == "interactive" else "default"
30
29
  else:
31
30
  exit(0)
32
31
  with open(self.filename, "w") as f:
@@ -69,7 +68,7 @@ class Config():
69
68
 
70
69
  def __exit__(self, type, value, traceback) -> None:
71
70
  self.yml.close()
72
- if os.environ["DEEP_LEANING_API_CONFIG_MODE"] == "remove":
71
+ if os.environ["KONFAI_CONFIG_MODE"] == "remove":
73
72
  if os.path.exists(CONFIG_FILE()):
74
73
  os.remove(CONFIG_FILE())
75
74
  return
@@ -87,22 +86,22 @@ class Config():
87
86
  except:
88
87
  result = input("\nKeep a default configuration file ? (yes,no) : ")
89
88
  if result == "yes":
90
- os.environ["DEEP_LEANING_API_CONFIG_MODE"] = "default"
89
+ os.environ["KONFAI_CONFIG_MODE"] = "default"
91
90
  else:
92
- os.environ["DEEP_LEANING_API_CONFIG_MODE"] = "remove"
91
+ os.environ["KONFAI_CONFIG_MODE"] = "remove"
93
92
  exit(0)
94
93
  return default.split(":")[1] if len(default.split(":")) > 1 else default
95
94
 
96
95
  @staticmethod
97
96
  def _getInputDefault(name : str, default : Union[str, None], isList : bool = False) -> Union[list[Union[str, None]], str, None]:
98
97
  if isinstance(default, str) and (default == "default" or (len(default.split(":")) > 1 and default.split(":")[0] == "default")):
99
- if os.environ["DEEP_LEANING_API_CONFIG_MODE"] == "interactive":
98
+ if os.environ["KONFAI_CONFIG_MODE"] == "interactive":
100
99
  if isList:
101
100
  list_tmp = []
102
101
  key_tmp = "OK"
103
- while key_tmp != "!" and os.environ["DEEP_LEANING_API_CONFIG_MODE"] == "interactive":
102
+ while (key_tmp != "!" and key_tmp != " ") and os.environ["KONFAI_CONFIG_MODE"] == "interactive":
104
103
  key_tmp = Config._getInput(name, default)
105
- if key_tmp != "!":
104
+ if (key_tmp != "!" and key_tmp != " "):
106
105
  if key_tmp == "":
107
106
  key_tmp = default.split(":")[1] if len(default.split(":")) > 1 else default
108
107
  list_tmp.append(key_tmp)
@@ -134,6 +133,7 @@ class Config():
134
133
  list_tmp = []
135
134
  for key in value_config:
136
135
  list_tmp.extend(Config._getInputDefault(name, key, isList=True))
136
+
137
137
 
138
138
  value = list_tmp
139
139
  value_config = list_tmp
@@ -155,7 +155,7 @@ class Config():
155
155
  dict_value[key] = value_tmp
156
156
  value = dict_value
157
157
  if isinstance(self.config, str):
158
- os.environ['DEEP_LEARNING_API_CONFIG_VARIABLE'] = "True"
158
+ os.environ['KONFAI_CONFIG_VARIABLE'] = "True"
159
159
  return None
160
160
 
161
161
  self.config[name] = value_config if value_config is not None else "None"
@@ -169,17 +169,30 @@ def config(key : Union[str, None] = None):
169
169
  if "config" in kwargs:
170
170
  filename = kwargs["config"]
171
171
  if filename == None:
172
- filename = os.environ['DEEP_LEARNING_API_CONFIG_FILE']
172
+ filename = os.environ['KONFAI_CONFIG_FILE']
173
173
  else:
174
- os.environ['DEEP_LEARNING_API_CONFIG_FILE'] = filename
174
+ os.environ['KONFAI_CONFIG_FILE'] = filename
175
175
  key_tmp = kwargs["DL_args"]+("."+key if key is not None else "") if "DL_args" in kwargs else key
176
176
  without = kwargs["DL_without"] if "DL_without" in kwargs else []
177
- os.environ['DEEP_LEARNING_API_CONFIG_PATH'] = key_tmp
177
+ os.environ['KONFAI_CONFIG_PATH'] = key_tmp
178
178
  with Config(filename, key_tmp) as config:
179
- os.environ['DEEP_LEARNING_API_CONFIG_VARIABLE'] = "False"
179
+ os.environ['KONFAI_CONFIG_VARIABLE'] = "False"
180
180
  kwargs = {}
181
181
  for param in list(inspect.signature(function).parameters.values())[len(args):]:
182
+
182
183
  annotation = param.annotation
184
+ # --- support Literal ---
185
+ if get_origin(annotation) is Literal:
186
+ allowed_values = get_args(annotation)
187
+ default_value = param.default if param.default != inspect._empty else allowed_values[0]
188
+ value = config.getValue(param.name, f"default:{default_value}")
189
+ if value not in allowed_values:
190
+ raise ValueError(
191
+ f"[Config] Invalid value '{value}' for parameter '{param.name}'. "
192
+ f"Expected one of: {allowed_values}."
193
+ )
194
+ kwargs[param.name] = value
195
+ continue
183
196
  if str(annotation).startswith("typing.Union") or str(annotation).startswith("typing.Optional"):
184
197
  for i in annotation.__args__:
185
198
  annotation = i
@@ -188,8 +201,24 @@ def config(key : Union[str, None] = None):
188
201
  continue
189
202
  if not annotation == inspect._empty:
190
203
  if annotation not in [int, str, bool, float, torch.Tensor]:
191
- if str(annotation).startswith("list") or str(annotation).startswith("tuple") or str(annotation).startswith("typing.Tuple"):
192
- if annotation.__args__[0] in [int, str, bool, float]:
204
+ if str(annotation).startswith("list") or str(annotation).startswith("tuple") or str(annotation).startswith("typing.Tuple") or str(annotation).startswith("typing.List"):
205
+ elem_type = annotation.__args__[0]
206
+ values = config.getValue(param.name, param.default)
207
+ if getattr(elem_type, '__origin__', None) is Union:
208
+ valid_types = elem_type.__args__
209
+ result = []
210
+ for v in values:
211
+ for t in valid_types:
212
+ try:
213
+ if t == torch.Tensor and not isinstance(v, torch.Tensor):
214
+ v = torch.tensor(v)
215
+ result.append(t(v) if t != torch.Tensor else v)
216
+ break
217
+ except Exception:
218
+ continue
219
+ kwargs[param.name] = result
220
+
221
+ elif annotation.__args__[0] in [int, str, bool, float]:
193
222
  values = config.getValue(param.name, param.default)
194
223
  kwargs[param.name] = values
195
224
  else:
@@ -204,9 +233,13 @@ def config(key : Union[str, None] = None):
204
233
  else:
205
234
  raise ConfigError()
206
235
  else:
207
- kwargs[param.name] = annotation(config = filename, DL_args = key_tmp)
208
- if os.environ['DEEP_LEARNING_API_CONFIG_VARIABLE'] == "True":
209
- os.environ['DEEP_LEARNING_API_CONFIG_VARIABLE'] = "False"
236
+ try:
237
+ kwargs[param.name] = annotation(config = filename, DL_args = key_tmp)
238
+ except Exception as e:
239
+ raise ValueError("[Config] Failed to instantiate {} with type {}".format(param.name, annotation.__name__))
240
+
241
+ if os.environ['KONFAI_CONFIG_VARIABLE'] == "True":
242
+ os.environ['KONFAI_CONFIG_VARIABLE'] = "False"
210
243
  kwargs[param.name] = None
211
244
  else:
212
245
  kwargs[param.name] = config.getValue(param.name, param.default)
konfai/utils/dataset.py CHANGED
@@ -750,7 +750,7 @@ class Dataset():
750
750
  else:
751
751
  with Dataset.File(self.filename, True, self.format) as file:
752
752
  names = file.getNames(groups)
753
- return [name for i, name in enumerate(names) if index is None or i in index]
753
+ return [name for i, name in enumerate(sorted(names)) if index is None or i in index]
754
754
 
755
755
  def getInfos(self, groups: str, name: str) -> tuple[list[int], Attribute]:
756
756
  if self.is_directory: