konfai 1.0.7__py3-none-any.whl → 1.0.9__py3-none-any.whl
This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.
Potentially problematic release.
This version of konfai might be problematic. Click here for more details.
- konfai/__init__.py +10 -10
- konfai/data/augmentation.py +4 -4
- konfai/data/{dataset.py → data_manager.py} +21 -25
- konfai/data/{HDF5.py → patching.py} +6 -6
- konfai/data/transform.py +3 -3
- konfai/evaluator.py +6 -6
- konfai/main.py +29 -22
- konfai/metric/measure.py +7 -4
- konfai/models/classification/convNeXt.py +1 -1
- konfai/models/classification/resnet.py +1 -1
- konfai/models/generation/cStyleGan.py +1 -1
- konfai/models/generation/ddpm.py +1 -1
- konfai/models/generation/diffusionGan.py +1 -1
- konfai/models/generation/gan.py +1 -1
- konfai/models/segmentation/NestedUNet.py +1 -1
- konfai/models/segmentation/UNet.py +1 -1
- konfai/network/network.py +18 -23
- konfai/predictor.py +90 -48
- konfai/trainer.py +9 -9
- konfai/utils/config.py +52 -19
- konfai/utils/dataset.py +1 -1
- konfai/utils/utils.py +74 -59
- {konfai-1.0.7.dist-info → konfai-1.0.9.dist-info}/METADATA +5 -3
- konfai-1.0.9.dist-info/RECORD +39 -0
- konfai-1.0.7.dist-info/RECORD +0 -39
- {konfai-1.0.7.dist-info → konfai-1.0.9.dist-info}/WHEEL +0 -0
- {konfai-1.0.7.dist-info → konfai-1.0.9.dist-info}/entry_points.txt +0 -0
- {konfai-1.0.7.dist-info → konfai-1.0.9.dist-info}/licenses/LICENSE +0 -0
- {konfai-1.0.7.dist-info → konfai-1.0.9.dist-info}/top_level.txt +0 -0
konfai/predictor.py
CHANGED
|
@@ -6,12 +6,12 @@ import torch
|
|
|
6
6
|
import tqdm
|
|
7
7
|
import os
|
|
8
8
|
|
|
9
|
-
from konfai import MODELS_DIRECTORY, PREDICTIONS_DIRECTORY, CONFIG_FILE, MODEL,
|
|
9
|
+
from konfai import MODELS_DIRECTORY, PREDICTIONS_DIRECTORY, CONFIG_FILE, MODEL, KONFAI_ROOT
|
|
10
10
|
from konfai.utils.config import config
|
|
11
11
|
from konfai.utils.utils import State, get_patch_slices_from_nb_patch_per_dim, NeedDevice, _getModule, DistributedObject, DataLog, description
|
|
12
12
|
from konfai.utils.dataset import Dataset, Attribute
|
|
13
|
-
from konfai.data.
|
|
14
|
-
from konfai.data.
|
|
13
|
+
from konfai.data.data_manager import DataPrediction, DatasetIter
|
|
14
|
+
from konfai.data.patching import Accumulator, PathCombine
|
|
15
15
|
from konfai.network.network import ModelLoader, Network, NetState, CPU_Model
|
|
16
16
|
from konfai.data.transform import Transform, TransformLoader
|
|
17
17
|
|
|
@@ -22,6 +22,8 @@ import torch.distributed as dist
|
|
|
22
22
|
from torch.nn.parallel import DistributedDataParallel as DDP
|
|
23
23
|
from torch.utils.data import DataLoader
|
|
24
24
|
import importlib
|
|
25
|
+
import copy
|
|
26
|
+
from collections import defaultdict
|
|
25
27
|
|
|
26
28
|
class OutDataset(Dataset, NeedDevice, ABC):
|
|
27
29
|
|
|
@@ -50,13 +52,13 @@ class OutDataset(Dataset, NeedDevice, ABC):
|
|
|
50
52
|
|
|
51
53
|
if _transform_type is not None:
|
|
52
54
|
for classpath, transform in _transform_type.items():
|
|
53
|
-
transform = transform.getTransform(classpath, DL_args = "{}.outsDataset.{}.OutDataset.{}".format(
|
|
55
|
+
transform = transform.getTransform(classpath, DL_args = "{}.outsDataset.{}.OutDataset.{}".format(KONFAI_ROOT(), name_layer, name))
|
|
54
56
|
transform.setDatasets(datasets)
|
|
55
57
|
transform_type.append(transform)
|
|
56
58
|
|
|
57
59
|
if self._patchCombine is not None:
|
|
58
|
-
module, name = _getModule(self._patchCombine, "data.
|
|
59
|
-
self.patchCombine = getattr(importlib.import_module(module), name)(config = None, DL_args = "{}.outsDataset.{}.OutDataset".format(
|
|
60
|
+
module, name = _getModule(self._patchCombine, "data.patching")
|
|
61
|
+
self.patchCombine = getattr(importlib.import_module(module), name)(config = None, DL_args = "{}.outsDataset.{}.OutDataset".format(KONFAI_ROOT(), name_layer))
|
|
60
62
|
|
|
61
63
|
def setPatchConfig(self, patchSize: Union[list[int], None], overlap: Union[int, None], nb_data_augmentation: int) -> None:
|
|
62
64
|
if patchSize is not None and overlap is not None:
|
|
@@ -92,7 +94,7 @@ class OutDataset(Dataset, NeedDevice, ABC):
|
|
|
92
94
|
class OutSameAsGroupDataset(OutDataset):
|
|
93
95
|
|
|
94
96
|
@config("OutDataset")
|
|
95
|
-
def __init__(self, dataset_filename: str = "Dataset:
|
|
97
|
+
def __init__(self, dataset_filename: str = "./Dataset:mha", group: str = "default", sameAsGroup: str = "default", pre_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, post_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, final_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, patchCombine: Union[str, None] = None, redution: str = "mean", inverse_transform: bool = True) -> None:
|
|
96
98
|
super().__init__(dataset_filename, group, pre_transforms, post_transforms, final_transforms, patchCombine)
|
|
97
99
|
self.group_src, self.group_dest = sameAsGroup.split(":")
|
|
98
100
|
self.redution = redution
|
|
@@ -152,7 +154,7 @@ class OutSameAsGroupDataset(OutDataset):
|
|
|
152
154
|
if self.redution == "mean":
|
|
153
155
|
result = torch.mean(result.float(), dim=0).to(dtype)
|
|
154
156
|
elif self.redution == "median":
|
|
155
|
-
result, _ = torch.median(result.float(), dim=0)
|
|
157
|
+
result, _ = torch.median(result.float(), dim=0)
|
|
156
158
|
else:
|
|
157
159
|
raise NameError("Reduction method does not exist (mean, median)")
|
|
158
160
|
for transform in self.final_transforms:
|
|
@@ -199,19 +201,19 @@ class OutDatasetLoader():
|
|
|
199
201
|
|
|
200
202
|
class _Predictor():
|
|
201
203
|
|
|
202
|
-
def __init__(self, world_size: int, global_rank: int, local_rank: int, predict_path: str, data_log: Union[list[str], None], outsDataset: dict[str, OutDataset],
|
|
204
|
+
def __init__(self, world_size: int, global_rank: int, local_rank: int, predict_path: str, data_log: Union[list[str], None], outsDataset: dict[str, OutDataset], modelComposite: DDP, dataloader_prediction: DataLoader) -> None:
|
|
203
205
|
self.world_size = world_size
|
|
204
206
|
self.global_rank = global_rank
|
|
205
207
|
self.local_rank = local_rank
|
|
206
208
|
|
|
207
|
-
self.
|
|
209
|
+
self.modelComposite = modelComposite
|
|
208
210
|
self.dataloader_prediction = dataloader_prediction
|
|
209
211
|
self.outsDataset = outsDataset
|
|
210
212
|
|
|
211
213
|
|
|
212
214
|
self.it = 0
|
|
213
215
|
|
|
214
|
-
self.device = self.
|
|
216
|
+
self.device = self.modelComposite.device
|
|
215
217
|
self.dataset: DatasetIter = self.dataloader_prediction.dataset
|
|
216
218
|
patch_size, overlap = self.dataset.getPatchConfig()
|
|
217
219
|
for outDataset in self.outsDataset.values():
|
|
@@ -220,7 +222,7 @@ class _Predictor():
|
|
|
220
222
|
if data_log is not None:
|
|
221
223
|
for data in data_log:
|
|
222
224
|
self.data_log[data.split("/")[0].replace(":", ".")] = (DataLog.__getitem__(data.split("/")[1]).value[0], int(data.split("/")[2]))
|
|
223
|
-
self.tb = SummaryWriter(log_dir = predict_path+"Metric/") if len([network for network in self.
|
|
225
|
+
self.tb = SummaryWriter(log_dir = predict_path+"Metric/") if len([network for network in self.modelComposite.module.getNetworks().values() if network.measure is not None]) or len(self.data_log) else None
|
|
224
226
|
|
|
225
227
|
def __enter__(self):
|
|
226
228
|
return self
|
|
@@ -234,15 +236,15 @@ class _Predictor():
|
|
|
234
236
|
|
|
235
237
|
@torch.no_grad()
|
|
236
238
|
def run(self):
|
|
237
|
-
self.
|
|
238
|
-
self.
|
|
239
|
-
desc = lambda : "Prediction : {}".format(description(self.
|
|
239
|
+
self.modelComposite.eval()
|
|
240
|
+
self.modelComposite.module.setState(NetState.PREDICTION)
|
|
241
|
+
desc = lambda : "Prediction : {}".format(description(self.modelComposite))
|
|
240
242
|
self.dataloader_prediction.dataset.load()
|
|
241
|
-
with tqdm.tqdm(iterable = enumerate(self.dataloader_prediction), leave=False, desc = desc(), total=len(self.dataloader_prediction), disable=self.global_rank != 0 and "
|
|
243
|
+
with tqdm.tqdm(iterable = enumerate(self.dataloader_prediction), leave=False, desc = desc(), total=len(self.dataloader_prediction), disable=self.global_rank != 0 and "KONFAI_CLUSTER" not in os.environ) as batch_iter:
|
|
242
244
|
dist.barrier()
|
|
243
245
|
for it, data_dict in batch_iter:
|
|
244
246
|
input = self.getInput(data_dict)
|
|
245
|
-
for name, output in self.
|
|
247
|
+
for name, output in self.modelComposite(input, list(self.outsDataset.keys())):
|
|
246
248
|
self._predict_log(data_dict)
|
|
247
249
|
outDataset = self.outsDataset[name]
|
|
248
250
|
for i, (index, patch_augmentation, patch_index) in enumerate([(int(index), int(patch_augmentation), int(patch_index)) for index, patch_augmentation, patch_index in zip(list(data_dict.values())[0][1], list(data_dict.values())[0][2], list(data_dict.values())[0][3])]):
|
|
@@ -254,7 +256,7 @@ class _Predictor():
|
|
|
254
256
|
self.it += 1
|
|
255
257
|
|
|
256
258
|
def _predict_log(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]):
|
|
257
|
-
measures = DistributedObject.getMeasure(self.world_size, self.global_rank, self.local_rank, {"" : self.
|
|
259
|
+
measures = DistributedObject.getMeasure(self.world_size, self.global_rank, self.local_rank, {"" : self.modelComposite.module}, 1)
|
|
258
260
|
|
|
259
261
|
if self.global_rank == 0:
|
|
260
262
|
images_log = []
|
|
@@ -265,30 +267,67 @@ class _Predictor():
|
|
|
265
267
|
else:
|
|
266
268
|
images_log.append(name.replace(":", "."))
|
|
267
269
|
|
|
268
|
-
for name, network in self.
|
|
270
|
+
for name, network in self.modelComposite.module.getNetworks().items():
|
|
269
271
|
if network.measure is not None:
|
|
270
272
|
self.tb.add_scalars("Prediction/{}/Loss".format(name), {k : v[1] for k, v in measures["{}{}".format(name, "")][0].items()}, self.it)
|
|
271
273
|
self.tb.add_scalars("Prediction/{}/Metric".format(name), {k : v[1] for k, v in measures["{}{}".format(name, "")][1].items()}, self.it)
|
|
272
274
|
if len(images_log):
|
|
273
275
|
for name, layer, _ in self.model.module.get_layers([v.to(0) for k, v in self.getInput(data_dict).items() if k[1]], images_log):
|
|
274
276
|
self.data_log[name][0](self.tb, "Prediction/{}".format(name), layer[:self.data_log[name][1]].detach().cpu().numpy(), self.it)
|
|
277
|
+
|
|
278
|
+
class ModelComposite(Network):
    """Ensemble container that runs several copies of one network and merges their outputs.

    ``nb_models`` deep copies of ``model`` are registered under the names
    ``Model_0`` ... ``Model_{nb_models - 1}``. ``forward`` feeds the same input to
    every registered member and reduces the tensors that share an output name
    with ``method`` (``"mean"`` or ``"median"``).
    """

    def __init__(self, model: Network, nb_models: int, method: str):
        """Build the container from a template network.

        Args:
            model: Template network; its construction parameters are forwarded to
                the parent constructor and it is deep-copied once per member.
            nb_models: Number of members to register.
            method: Reduction applied in ``forward`` (``"mean"`` or ``"median"``).
        """
        super().__init__(
            model.in_channels,
            model.optimizer,
            model.schedulers,
            model.outputsCriterionsLoader,
            model.patch,
            model.nb_batch_per_step,
            model.init_type,
            model.init_gain,
            model.dim,
        )
        self.method = method
        for i in range(nb_models):
            # Each member is an independent deep copy so that distinct weights can
            # be loaded into it afterwards.
            # NOTE(review): in_branch/out_branch semantics are defined by
            # Network.add_module, which is not visible here.
            self.add_module("Model_{}".format(i), copy.deepcopy(model), in_branch=[0], out_branch=["output_{}".format(i)])

    def load(self, state_dicts: list[dict[str, dict[str, torch.Tensor]]]):
        """Load the i-th state dict into member ``Model_i`` and suffix its name with ``_i``.

        Args:
            state_dicts: One state dict per member, in member order.
        """
        for i, state_dict in enumerate(state_dicts):
            member = self["Model_{}".format(i)]
            member.load(state_dict, init=False)
            member.setName("{}_{}".format(member.getName(), i))

    def forward(self, data_dict: dict[tuple[str, bool], torch.Tensor], output_layers: Union[list[str], None] = None) -> list[tuple[str, torch.Tensor]]:
        """Run every member on ``data_dict`` and reduce same-named outputs.

        Args:
            data_dict: Input passed unchanged to each member.
            output_layers: Names of the outputs requested from each member;
                ``None`` (the default) is treated as an empty list.

        Returns:
            A list of ``(output_name, reduced_tensor)`` pairs, ordered by the first
            appearance of each output name.

        Raises:
            NameError: If an output has to be reduced and ``self.method`` is
                neither ``"mean"`` nor ``"median"``.
        """
        # A fresh list per call avoids sharing one mutable default between calls.
        if output_layers is None:
            output_layers = []

        # Group the tensors produced by the members by output name.
        aggregated = defaultdict(list)
        for _, module in self.items():
            for key, tensor in module(data_dict, output_layers):
                aggregated[key].append(tensor)

        final_outputs = []
        for key, tensors in aggregated.items():
            # New leading dimension = one slice per member.
            stacked = torch.stack(tensors, dim=0)
            if self.method == 'mean':
                agg = torch.mean(stacked, dim=0)
            elif self.method == 'median':
                # torch.median returns the lower of the two middle values when the
                # number of members is even.
                agg = torch.median(stacked, dim=0).values
            else:
                # Previously this fell through and failed on an unbound local.
                raise NameError("Reduction method does not exist (mean, median)")
            final_outputs.append((key, agg))

        return final_outputs
|
|
311
|
+
|
|
312
|
+
|
|
276
313
|
class Predictor(DistributedObject):
|
|
277
314
|
|
|
278
315
|
@config("Predictor")
|
|
279
316
|
def __init__(self,
|
|
280
317
|
model: ModelLoader = ModelLoader(),
|
|
281
318
|
dataset: DataPrediction = DataPrediction(),
|
|
319
|
+
combine: str = "mean",
|
|
282
320
|
train_name: str = "name",
|
|
283
321
|
manual_seed : Union[int, None] = None,
|
|
284
322
|
gpu_checkpoints: Union[list[str], None] = None,
|
|
285
323
|
outsDataset: Union[dict[str, OutDatasetLoader], None] = {"default:Default" : OutDatasetLoader()},
|
|
286
324
|
images_log: list[str] = []) -> None:
|
|
287
|
-
if os.environ["
|
|
325
|
+
if os.environ["KONFAI_CONFIG_MODE"] != "Done":
|
|
288
326
|
exit(0)
|
|
289
327
|
super().__init__(train_name)
|
|
290
328
|
self.manual_seed = manual_seed
|
|
291
329
|
self.dataset = dataset
|
|
330
|
+
self.combine = combine
|
|
292
331
|
|
|
293
332
|
self.model = model.getModel(train=False)
|
|
294
333
|
self.it = 0
|
|
@@ -305,34 +344,37 @@ class Predictor(DistributedObject):
|
|
|
305
344
|
|
|
306
345
|
self.gpu_checkpoints = gpu_checkpoints
|
|
307
346
|
|
|
308
|
-
def _load(self) -> dict[str, dict[str, torch.Tensor]]:
|
|
309
|
-
|
|
310
|
-
|
|
311
|
-
|
|
312
|
-
|
|
313
|
-
|
|
314
|
-
|
|
315
|
-
|
|
316
|
-
|
|
317
|
-
name = MODEL()
|
|
347
|
+
def _load(self) -> list[dict[str, dict[str, torch.Tensor]]]:
    """Load one state dict per model reference found in ``MODEL()``.

    ``MODEL()`` is split on ``":"`` and each entry is resolved as follows:

    * entry starts with ``"https://"``: downloaded with
      ``torch.hub.load_state_dict_from_url`` onto the CPU;
    * entry is non-empty: used as-is as the checkpoint path;
    * entry is empty: resolved from ``MODELS_DIRECTORY()`` and ``self.name``.
      If ``self.name`` ends with ``".pt"`` that file is looked up in the
      ``StateDict/`` folder next to it, otherwise the last entry (in sorted
      order) of ``<MODELS_DIRECTORY><self.name>/StateDict/`` is used.

    Returns:
        The loaded state dicts, in the order of the entries of ``MODEL()``.

    Raises:
        Exception: If a download fails or a resolved checkpoint file does not exist.
    """
    # NOTE(review): splitting on ":" also splits the "https://" scheme, so no entry
    # can ever start with "https://" and the download branch below looks
    # unreachable. It is left untouched because setup() counts the models with
    # the same split; both places must be changed together.
    model_paths = MODEL().split(":")
    state_dicts = []
    for model_path in model_paths:
        if model_path.startswith("https://"):
            try:
                state_dicts.append(torch.hub.load_state_dict_from_url(url=model_path, map_location="cpu", check_hash=True))
            except Exception as err:
                # Keep the original message but preserve the real cause; a bare
                # except would also have trapped KeyboardInterrupt/SystemExit.
                raise Exception("Model : {} does not exist !".format(model_path)) from err
        else:
            if model_path != "":
                path = ""
                name = model_path
            else:
                if self.name.endswith(".pt"):
                    path = MODELS_DIRECTORY()+"/".join(self.name.split("/")[:-1])+"/StateDict/"
                    name = self.name.split("/")[-1]
                else:
                    path = MODELS_DIRECTORY()+self.name+"/StateDict/"
                    # Last file name in sorted order.
                    name = sorted(os.listdir(path))[-1]
            if os.path.exists(path+name):
                # SECURITY: weights_only=False unpickles arbitrary objects; only
                # load checkpoints that come from a trusted source.
                state_dicts.append(torch.load(path+name, weights_only=False))
            else:
                raise Exception("Model : {} does not exist !".format(path+name))
    return state_dicts
|
|
330
372
|
|
|
331
373
|
def setup(self, world_size: int):
|
|
332
374
|
for dataset_filename in self.datasets_filename:
|
|
333
375
|
path = self.predict_path +dataset_filename
|
|
334
376
|
if os.path.exists(path):
|
|
335
|
-
if os.environ["
|
|
377
|
+
if os.environ["KONFAI_OVERWRITE"] != "True":
|
|
336
378
|
accept = builtins.input("The prediction {} already exists ! Do you want to overwrite it (yes,no) : ".format(path))
|
|
337
379
|
if accept != "yes":
|
|
338
380
|
return
|
|
@@ -345,11 +387,11 @@ class Predictor(DistributedObject):
|
|
|
345
387
|
self.model.init(autocast=False, state = State.PREDICTION)
|
|
346
388
|
self.model.init_outputsGroup()
|
|
347
389
|
self.model._compute_channels_trace(self.model, self.model.in_channels, None, self.gpu_checkpoints)
|
|
348
|
-
self.
|
|
349
|
-
|
|
350
|
-
if len(list(self.outsDataset.keys())) == 0 and len([network for network in self.model.getNetworks().values() if network.measure is not None]) == 0:
|
|
351
|
-
exit(0)
|
|
390
|
+
self.modelComposite = ModelComposite(self.model, len(MODEL().split(":")), self.combine)
|
|
391
|
+
self.modelComposite.load(self._load())
|
|
352
392
|
|
|
393
|
+
if len(list(self.outsDataset.keys())) == 0 and len([network for network in self.modelComposite.getNetworks().values() if network.measure is not None]) == 0:
|
|
394
|
+
exit(0)
|
|
353
395
|
|
|
354
396
|
self.size = (len(self.gpu_checkpoints)+1 if self.gpu_checkpoints else 1)
|
|
355
397
|
self.dataloader = self.dataset.getData(world_size//self.size)
|
|
@@ -358,9 +400,9 @@ class Predictor(DistributedObject):
|
|
|
358
400
|
|
|
359
401
|
|
|
360
402
|
def run_process(self, world_size: int, global_rank: int, local_rank: int, dataloaders: list[DataLoader]):
    """Run prediction for one process.

    The composite model is moved with ``Network.to`` to index
    ``local_rank * self.size``, wrapped in ``DDP`` when CUDA is available (in
    ``CPU_Model`` otherwise), and handed to a ``_Predictor`` that is run inside
    a ``with`` block.

    Args:
        world_size: Forwarded to ``_Predictor``.
        global_rank: Forwarded to ``_Predictor``.
        local_rank: Forwarded to ``_Predictor``; also used to compute the target index.
        dataloaders: Unpacked as the trailing positional arguments of ``_Predictor``.
    """
    target_index = local_rank*self.size
    placed_model = Network.to(self.modelComposite, target_index)
    if torch.cuda.is_available():
        wrapped_model = DDP(placed_model, static_graph=True)
    else:
        wrapped_model = CPU_Model(placed_model)
    with _Predictor(world_size, global_rank, local_rank, self.predict_path, self.images_log, self.outsDataset, wrapped_model, *dataloaders) as predictor:
        predictor.run()
|
|
365
407
|
|
|
366
408
|
|
konfai/trainer.py
CHANGED
|
@@ -12,8 +12,8 @@ from torch.utils.tensorboard.writer import SummaryWriter
|
|
|
12
12
|
from torch.optim.swa_utils import AveragedModel
|
|
13
13
|
import torch.distributed as dist
|
|
14
14
|
|
|
15
|
-
from konfai import MODELS_DIRECTORY, CHECKPOINTS_DIRECTORY, STATISTICS_DIRECTORY, SETUPS_DIRECTORY, CONFIG_FILE, MODEL, DATE,
|
|
16
|
-
from konfai.data.
|
|
15
|
+
from konfai import MODELS_DIRECTORY, CHECKPOINTS_DIRECTORY, STATISTICS_DIRECTORY, SETUPS_DIRECTORY, CONFIG_FILE, MODEL, DATE, KONFAI_STATE
|
|
16
|
+
from konfai.data.data_manager import DataTrain
|
|
17
17
|
from konfai.utils.config import config
|
|
18
18
|
from konfai.utils.utils import State, DataLog, DistributedObject, description
|
|
19
19
|
from konfai.network.network import Network, ModelLoader, NetState, CPU_Model
|
|
@@ -72,7 +72,7 @@ class _Trainer():
|
|
|
72
72
|
self.modelEMA.module.setState(NetState.TRAIN)
|
|
73
73
|
|
|
74
74
|
desc = lambda : "Training : {}".format(description(self.model, self.modelEMA))
|
|
75
|
-
with tqdm.tqdm(iterable = enumerate(self.dataloader_training), desc = desc(), total=len(self.dataloader_training), leave=False, disable=self.global_rank != 0 and "
|
|
75
|
+
with tqdm.tqdm(iterable = enumerate(self.dataloader_training), desc = desc(), total=len(self.dataloader_training), leave=False, disable=self.global_rank != 0 and "KONFAI_CLUSTER" not in os.environ) as batch_iter:
|
|
76
76
|
for _, data_dict in batch_iter:
|
|
77
77
|
with torch.amp.autocast('cuda', enabled=self.autocast):
|
|
78
78
|
input = self.getInput(data_dict)
|
|
@@ -103,7 +103,7 @@ class _Trainer():
|
|
|
103
103
|
desc = lambda : "Validation : {}".format(description(self.model, self.modelEMA))
|
|
104
104
|
data_dict = None
|
|
105
105
|
self.dataloader_validation.dataset.load()
|
|
106
|
-
with tqdm.tqdm(iterable = enumerate(self.dataloader_validation), desc = desc(), total=len(self.dataloader_validation), leave=False, disable=self.global_rank != 0 and "
|
|
106
|
+
with tqdm.tqdm(iterable = enumerate(self.dataloader_validation), desc = desc(), total=len(self.dataloader_validation), leave=False, disable=self.global_rank != 0 and "KONFAI_CLUSTER" not in os.environ) as batch_iter:
|
|
107
107
|
for _, data_dict in batch_iter:
|
|
108
108
|
input = self.getInput(data_dict)
|
|
109
109
|
self.model(input)
|
|
@@ -205,7 +205,7 @@ class Trainer(DistributedObject):
|
|
|
205
205
|
def __init__( self,
|
|
206
206
|
model : ModelLoader = ModelLoader(),
|
|
207
207
|
dataset : DataTrain = DataTrain(),
|
|
208
|
-
train_name : str = "default:
|
|
208
|
+
train_name : str = "default:TRAIN_01",
|
|
209
209
|
manual_seed : Union[int, None] = None,
|
|
210
210
|
epochs: int = 100,
|
|
211
211
|
it_validation : Union[int, None] = None,
|
|
@@ -215,7 +215,7 @@ class Trainer(DistributedObject):
|
|
|
215
215
|
ema_decay : float = 0,
|
|
216
216
|
data_log: Union[list[str], None] = None,
|
|
217
217
|
save_checkpoint_mode: str= "BEST") -> None:
|
|
218
|
-
if os.environ["
|
|
218
|
+
if os.environ["KONFAI_CONFIG_MODE"] != "Done":
|
|
219
219
|
exit(0)
|
|
220
220
|
super().__init__(train_name)
|
|
221
221
|
self.manual_seed = manual_seed
|
|
@@ -292,9 +292,9 @@ class Trainer(DistributedObject):
|
|
|
292
292
|
return (1-self.ema_decay) * averaged_model_parameter + self.ema_decay * model_parameter
|
|
293
293
|
|
|
294
294
|
def setup(self, world_size: int):
|
|
295
|
-
state = State._member_map_[
|
|
296
|
-
if state != State.RESUME and os.path.exists(
|
|
297
|
-
if os.environ["
|
|
295
|
+
state = State._member_map_[KONFAI_STATE()]
|
|
296
|
+
if state != State.RESUME and os.path.exists(CHECKPOINTS_DIRECTORY()+self.name+"/"):
|
|
297
|
+
if os.environ["KONFAI_OVERWRITE"] != "True":
|
|
298
298
|
accept = input("The model {} already exists ! Do you want to overwrite it (yes,no) : ".format(self.name))
|
|
299
299
|
if accept != "yes":
|
|
300
300
|
return
|
konfai/utils/config.py
CHANGED
|
@@ -3,9 +3,8 @@ import ruamel.yaml
|
|
|
3
3
|
import inspect
|
|
4
4
|
import collections
|
|
5
5
|
from copy import deepcopy
|
|
6
|
-
from typing import Union
|
|
6
|
+
from typing import Union, Literal, get_origin, get_args
|
|
7
7
|
import torch
|
|
8
|
-
|
|
9
8
|
from konfai import CONFIG_FILE
|
|
10
9
|
|
|
11
10
|
yaml = ruamel.yaml.YAML()
|
|
@@ -26,7 +25,7 @@ class Config():
|
|
|
26
25
|
if not os.path.exists(self.filename):
|
|
27
26
|
result = input("Create a new config file ? [no,yes,interactive] : ")
|
|
28
27
|
if result in ["yes", "interactive"]:
|
|
29
|
-
os.environ["
|
|
28
|
+
os.environ["KONFAI_CONFIG_MODE"] = "interactive" if result == "interactive" else "default"
|
|
30
29
|
else:
|
|
31
30
|
exit(0)
|
|
32
31
|
with open(self.filename, "w") as f:
|
|
@@ -69,7 +68,7 @@ class Config():
|
|
|
69
68
|
|
|
70
69
|
def __exit__(self, type, value, traceback) -> None:
|
|
71
70
|
self.yml.close()
|
|
72
|
-
if os.environ["
|
|
71
|
+
if os.environ["KONFAI_CONFIG_MODE"] == "remove":
|
|
73
72
|
if os.path.exists(CONFIG_FILE()):
|
|
74
73
|
os.remove(CONFIG_FILE())
|
|
75
74
|
return
|
|
@@ -87,22 +86,22 @@ class Config():
|
|
|
87
86
|
except:
|
|
88
87
|
result = input("\nKeep a default configuration file ? (yes,no) : ")
|
|
89
88
|
if result == "yes":
|
|
90
|
-
os.environ["
|
|
89
|
+
os.environ["KONFAI_CONFIG_MODE"] = "default"
|
|
91
90
|
else:
|
|
92
|
-
os.environ["
|
|
91
|
+
os.environ["KONFAI_CONFIG_MODE"] = "remove"
|
|
93
92
|
exit(0)
|
|
94
93
|
return default.split(":")[1] if len(default.split(":")) > 1 else default
|
|
95
94
|
|
|
96
95
|
@staticmethod
|
|
97
96
|
def _getInputDefault(name : str, default : Union[str, None], isList : bool = False) -> Union[list[Union[str, None]], str, None]:
|
|
98
97
|
if isinstance(default, str) and (default == "default" or (len(default.split(":")) > 1 and default.split(":")[0] == "default")):
|
|
99
|
-
if os.environ["
|
|
98
|
+
if os.environ["KONFAI_CONFIG_MODE"] == "interactive":
|
|
100
99
|
if isList:
|
|
101
100
|
list_tmp = []
|
|
102
101
|
key_tmp = "OK"
|
|
103
|
-
while key_tmp != "!" and os.environ["
|
|
102
|
+
while (key_tmp != "!" and key_tmp != " ") and os.environ["KONFAI_CONFIG_MODE"] == "interactive":
|
|
104
103
|
key_tmp = Config._getInput(name, default)
|
|
105
|
-
if key_tmp != "!":
|
|
104
|
+
if (key_tmp != "!" and key_tmp != " "):
|
|
106
105
|
if key_tmp == "":
|
|
107
106
|
key_tmp = default.split(":")[1] if len(default.split(":")) > 1 else default
|
|
108
107
|
list_tmp.append(key_tmp)
|
|
@@ -134,6 +133,7 @@ class Config():
|
|
|
134
133
|
list_tmp = []
|
|
135
134
|
for key in value_config:
|
|
136
135
|
list_tmp.extend(Config._getInputDefault(name, key, isList=True))
|
|
136
|
+
|
|
137
137
|
|
|
138
138
|
value = list_tmp
|
|
139
139
|
value_config = list_tmp
|
|
@@ -155,7 +155,7 @@ class Config():
|
|
|
155
155
|
dict_value[key] = value_tmp
|
|
156
156
|
value = dict_value
|
|
157
157
|
if isinstance(self.config, str):
|
|
158
|
-
os.environ['
|
|
158
|
+
os.environ['KONFAI_CONFIG_VARIABLE'] = "True"
|
|
159
159
|
return None
|
|
160
160
|
|
|
161
161
|
self.config[name] = value_config if value_config is not None else "None"
|
|
@@ -169,17 +169,30 @@ def config(key : Union[str, None] = None):
|
|
|
169
169
|
if "config" in kwargs:
|
|
170
170
|
filename = kwargs["config"]
|
|
171
171
|
if filename == None:
|
|
172
|
-
filename = os.environ['
|
|
172
|
+
filename = os.environ['KONFAI_CONFIG_FILE']
|
|
173
173
|
else:
|
|
174
|
-
os.environ['
|
|
174
|
+
os.environ['KONFAI_CONFIG_FILE'] = filename
|
|
175
175
|
key_tmp = kwargs["DL_args"]+("."+key if key is not None else "") if "DL_args" in kwargs else key
|
|
176
176
|
without = kwargs["DL_without"] if "DL_without" in kwargs else []
|
|
177
|
-
os.environ['
|
|
177
|
+
os.environ['KONFAI_CONFIG_PATH'] = key_tmp
|
|
178
178
|
with Config(filename, key_tmp) as config:
|
|
179
|
-
os.environ['
|
|
179
|
+
os.environ['KONFAI_CONFIG_VARIABLE'] = "False"
|
|
180
180
|
kwargs = {}
|
|
181
181
|
for param in list(inspect.signature(function).parameters.values())[len(args):]:
|
|
182
|
+
|
|
182
183
|
annotation = param.annotation
|
|
184
|
+
# --- support Literal ---
|
|
185
|
+
if get_origin(annotation) is Literal:
|
|
186
|
+
allowed_values = get_args(annotation)
|
|
187
|
+
default_value = param.default if param.default != inspect._empty else allowed_values[0]
|
|
188
|
+
value = config.getValue(param.name, f"default:{default_value}")
|
|
189
|
+
if value not in allowed_values:
|
|
190
|
+
raise ValueError(
|
|
191
|
+
f"[Config] Invalid value '{value}' for parameter '{param.name}'. "
|
|
192
|
+
f"Expected one of: {allowed_values}."
|
|
193
|
+
)
|
|
194
|
+
kwargs[param.name] = value
|
|
195
|
+
continue
|
|
183
196
|
if str(annotation).startswith("typing.Union") or str(annotation).startswith("typing.Optional"):
|
|
184
197
|
for i in annotation.__args__:
|
|
185
198
|
annotation = i
|
|
@@ -188,8 +201,24 @@ def config(key : Union[str, None] = None):
|
|
|
188
201
|
continue
|
|
189
202
|
if not annotation == inspect._empty:
|
|
190
203
|
if annotation not in [int, str, bool, float, torch.Tensor]:
|
|
191
|
-
if str(annotation).startswith("list") or str(annotation).startswith("tuple") or str(annotation).startswith("typing.Tuple"):
|
|
192
|
-
|
|
204
|
+
if str(annotation).startswith("list") or str(annotation).startswith("tuple") or str(annotation).startswith("typing.Tuple") or str(annotation).startswith("typing.List"):
|
|
205
|
+
elem_type = annotation.__args__[0]
|
|
206
|
+
values = config.getValue(param.name, param.default)
|
|
207
|
+
if getattr(elem_type, '__origin__', None) is Union:
|
|
208
|
+
valid_types = elem_type.__args__
|
|
209
|
+
result = []
|
|
210
|
+
for v in values:
|
|
211
|
+
for t in valid_types:
|
|
212
|
+
try:
|
|
213
|
+
if t == torch.Tensor and not isinstance(v, torch.Tensor):
|
|
214
|
+
v = torch.tensor(v)
|
|
215
|
+
result.append(t(v) if t != torch.Tensor else v)
|
|
216
|
+
break
|
|
217
|
+
except Exception:
|
|
218
|
+
continue
|
|
219
|
+
kwargs[param.name] = result
|
|
220
|
+
|
|
221
|
+
elif annotation.__args__[0] in [int, str, bool, float]:
|
|
193
222
|
values = config.getValue(param.name, param.default)
|
|
194
223
|
kwargs[param.name] = values
|
|
195
224
|
else:
|
|
@@ -204,9 +233,13 @@ def config(key : Union[str, None] = None):
|
|
|
204
233
|
else:
|
|
205
234
|
raise ConfigError()
|
|
206
235
|
else:
|
|
207
|
-
|
|
208
|
-
|
|
209
|
-
|
|
236
|
+
try:
|
|
237
|
+
kwargs[param.name] = annotation(config = filename, DL_args = key_tmp)
|
|
238
|
+
except Exception as e:
|
|
239
|
+
raise ValueError("[Config] Failed to instantiate {} with type {}".format(param.name, annotation.__name__))
|
|
240
|
+
|
|
241
|
+
if os.environ['KONFAI_CONFIG_VARIABLE'] == "True":
|
|
242
|
+
os.environ['KONFAI_CONFIG_VARIABLE'] = "False"
|
|
210
243
|
kwargs[param.name] = None
|
|
211
244
|
else:
|
|
212
245
|
kwargs[param.name] = config.getValue(param.name, param.default)
|
konfai/utils/dataset.py
CHANGED
|
@@ -750,7 +750,7 @@ class Dataset():
|
|
|
750
750
|
else:
|
|
751
751
|
with Dataset.File(self.filename, True, self.format) as file:
|
|
752
752
|
names = file.getNames(groups)
|
|
753
|
-
return [name for i, name in enumerate(names) if index is None or i in index]
|
|
753
|
+
return [name for i, name in enumerate(sorted(names)) if index is None or i in index]
|
|
754
754
|
|
|
755
755
|
def getInfos(self, groups: str, name: str) -> tuple[list[int], Attribute]:
|
|
756
756
|
if self.is_directory:
|