konfai 1.1.1__py3-none-any.whl → 1.1.2__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of konfai might be problematic. See the registry's page for this release for more details.

konfai/__init__.py CHANGED
@@ -12,5 +12,5 @@ CONFIG_FILE = lambda : os.environ["KONFAI_CONFIG_FILE"]
12
12
  KONFAI_STATE = lambda : os.environ["KONFAI_STATE"]
13
13
  KONFAI_ROOT = lambda : os.environ["KONFAI_ROOT"]
14
14
  CUDA_VISIBLE_DEVICES = lambda : os.environ["CUDA_VISIBLE_DEVICES"]
15
-
15
+ KONFAI_NB_CORES = lambda : os.environ["KONFAI_NB_CORES"]
16
16
  DATE = lambda : datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
@@ -62,12 +62,25 @@ class GroupTransform:
62
62
  for transform in self.post_transforms:
63
63
  transform.setDevice(device)
64
64
 
65
+ class GroupTransformMetric(GroupTransform):
66
+
67
+ @config()
68
+ def __init__(self, pre_transforms : Union[dict[str, TransformLoader], list[Transform]] = {"default:Normalize:Standardize:Unsqueeze:TensorCast:ResampleIsotropic:ResampleResize": TransformLoader()},
69
+ post_transforms : Union[dict[str, TransformLoader], list[Transform]] = {"default:Normalize:Standardize:Unsqueeze:TensorCast:ResampleIsotropic:ResampleResize": TransformLoader()}):
70
+ super().__init__(pre_transforms, post_transforms)
71
+
65
72
  class Group(dict[str, GroupTransform]):
66
73
 
67
74
  @config()
68
75
  def __init__(self, groups_dest: dict[str, GroupTransform] = {"default:group_dest": GroupTransform()}):
69
76
  super().__init__(groups_dest)
70
77
 
78
+ class GroupMetric(dict[str, GroupTransformMetric]):
79
+
80
+ @config()
81
+ def __init__(self, groups_dest: dict[str, GroupTransformMetric] = {"default:group_dest": GroupTransformMetric()}):
82
+ super().__init__(groups_dest)
83
+
71
84
  class CustomSampler(Sampler[int]):
72
85
 
73
86
  def __init__(self, size: int, shuffle: bool = False) -> None:
@@ -109,32 +122,33 @@ class DatasetIter(data.Dataset):
109
122
  def getDatasetFromIndex(self, group_dest: str, index: int) -> DatasetManager:
110
123
  return self.data[group_dest][index]
111
124
 
112
- def resetAugmentation(self):
113
- if self.inlineAugmentations:
125
+ def resetAugmentation(self, label):
126
+ if self.inlineAugmentations and len(self.dataAugmentationsList) > 0:
114
127
  for index in range(self.nb_dataset):
115
- self._unloadData(index)
116
128
  for group_src in self.groups_src:
117
129
  for group_dest in self.groups_src[group_src]:
130
+ self.data[group_dest][index].unloadAugmentation()
118
131
  self.data[group_dest][index].resetAugmentation()
132
+ self.load(label + " Augmentation")
119
133
 
120
- def load(self):
134
+ def load(self, label: str):
121
135
  if self.use_cache:
122
136
  memory_init = getMemory()
123
137
 
124
- indexs = [index for index in range(self.nb_dataset) if index not in self._index_cache]
138
+ indexs = [index for index in range(self.nb_dataset)]
125
139
  if len(indexs) > 0:
126
140
  memory_lock = threading.Lock()
141
+ desc = lambda : "Caching "+ label +": {} | {} | {}".format(memoryInfo(), memoryForecast(memory_init, 0, self.nb_dataset), cpuInfo())
127
142
  pbar = tqdm.tqdm(
128
143
  total=len(indexs),
129
- desc="Caching : init | {} | {}".format(memoryForecast(memory_init, 0, self.nb_dataset), cpuInfo()),
130
- leave=False,
131
- disable=self.rank != 0 and "KONFAI_CLUSTER" not in os.environ
144
+ desc=desc(),
145
+ leave=False
132
146
  )
133
147
 
134
148
  def process(index):
135
149
  self._loadData(index)
136
150
  with memory_lock:
137
- pbar.set_description("Caching : {} | {} | {}".format(memoryInfo(), memoryForecast(memory_init, index, self.nb_dataset), cpuInfo()))
151
+ pbar.set_description(desc())
138
152
  pbar.update(1)
139
153
  with ThreadPoolExecutor(max_workers=os.cpu_count()//(device_count() if device_count() > 0 else 1)) as executor:
140
154
  futures = [executor.submit(process, index) for index in indexs]
@@ -257,10 +271,10 @@ class Data(ABC):
257
271
  groups_src : dict[str, Group],
258
272
  patch : Union[DatasetPatch, None],
259
273
  use_cache : bool,
260
- subset : Union[Subset, dict[str, Subset]],
274
+ subset : Subset,
261
275
  num_workers : int,
262
276
  batch_size : int,
263
- validation: Union[float, str, list[int], list[str]] = 1,
277
+ validation: Union[float, str, list[int], list[str], None] = None,
264
278
  inlineAugmentations: bool = False,
265
279
  dataAugmentationsList: dict[str, DataAugmentationsList]= {}) -> None:
266
280
  self.dataset_filenames = dataset_filenames
@@ -343,7 +357,7 @@ class Data(ABC):
343
357
  append = flag == "a"
344
358
 
345
359
  if format not in SUPPORTED_EXTENSIONS:
346
- raise DatasetManagerError(f"Unsupported file format '{format}'.",
360
+ raise DatasetManagerError(f"Unsupported file format '{format}'.",
347
361
  f"Supported extensions are: {', '.join(SUPPORTED_EXTENSIONS)}")
348
362
 
349
363
  dataset = Dataset(filename, format)
@@ -362,7 +376,7 @@ class Data(ABC):
362
376
  raise DatasetManagerError(
363
377
  f"Group source '{group_src}' not found in any dataset.",
364
378
  f"Dataset filenames provided: {self.dataset_filenames}",
365
- f"Available groups across all datasets: {sorted(list(datasets.keys()))}",
379
+ "Available groups across all datasets: {}".format(["{} {}".format(f, d.getGroup()) for f, d in self.datasets.items()]),
366
380
  f"Please check that an entry in the dataset with the name '{group_src}.{format}' exists."
367
381
  )
368
382
 
@@ -382,28 +396,38 @@ class Data(ABC):
382
396
  dataset_name : dict[str, dict[str, list[str]]] = {}
383
397
  dataset_info : dict[str, dict[str, dict[str, Attribute]]] = {}
384
398
  for group in self.groups_src:
399
+ namesByGroup = set()
385
400
  if group not in dataset_name:
386
401
  dataset_name[group] = {}
387
402
  dataset_info[group] = {}
388
403
  for filename, _ in datasets[group]:
389
- names.update(self.datasets[filename].getNames(group))
404
+ namesByGroup.update(self.datasets[filename].getNames(group))
390
405
  dataset_name[group][filename] = self.datasets[filename].getNames(group)
391
406
  dataset_info[group][filename] = {name: self.datasets[filename].getInfos(group, name) for name in dataset_name[group][filename]}
407
+ if len(names) == 0:
408
+ names.update(namesByGroup)
409
+ else:
410
+ names = names.intersection(namesByGroup)
411
+ if len(names) == 0:
412
+ raise DatasetManagerError(
413
+ f"No data was found for groups {list(self.groups_src.keys())}: although each group contains data from a dataset, there are no common dataset names shared across all groups, the intersection is empty."
414
+ )
415
+
392
416
  subset_names = set()
393
-
394
- if isinstance(self.subset, dict):
395
- for filename, subset in self.subset.items():
396
- subset_names.update(subset([dataset_name[group][filename] for group in dataset_name], [dataset_info[group][filename] for group in dataset_name]))
397
- else:
398
- for group in dataset_name:
399
- for filename, append in datasets[group]:
400
- if append:
401
- subset_names.update(self.subset([dataset_name[group][filename]], [dataset_info[group][filename]]))
417
+ for group in dataset_name:
418
+ subset_names_bygroup = set()
419
+ for filename, append in datasets[group]:
420
+ if append:
421
+ subset_names_bygroup.update(self.subset([dataset_name[group][filename]], [dataset_info[group][filename]]))
422
+ else:
423
+ if len(subset_names_bygroup) == 0:
424
+ subset_names_bygroup.update(self.subset([dataset_name[group][filename]], [dataset_info[group][filename]]))
402
425
  else:
403
- if len(subset_names) == 0:
404
- subset_names.update(self.subset([dataset_name[group][filename]], [dataset_info[group][filename]]))
405
- else:
406
- subset_names.intersection(self.subset([dataset_name[group][filename]], [dataset_info[group][filename]]))
426
+ subset_names_bygroup = subset_names_bygroup.intersection(self.subset([dataset_name[group][filename]], [dataset_info[group][filename]]))
427
+ if len(subset_names) == 0:
428
+ subset_names.update(subset_names_bygroup)
429
+ else:
430
+ subset_names = subset_names.intersection(subset_names_bygroup)
407
431
  if len(subset_names) == 0:
408
432
  raise DatasetManagerError("All data entries were excluded by the subset filter.",
409
433
  f"Dataset entries found: {', '.join(names)}",
@@ -424,7 +448,7 @@ class Data(ABC):
424
448
  validate_map = []
425
449
  if isinstance(self.validation, float) or isinstance(self.validation, int):
426
450
  if self.validation <= 0 or self.validation >= 1:
427
- raise DatasetManagerError("validation must be a float between 0 and 1.", f"Received: {self.validation}", "Example: validation = 0.2 # for a 20% validation split")
451
+ raise DatasetManagerError("Validation must be a float between 0 and 1.", f"Received: {self.validation}", "Example: validation = 0.2 # for a 20% validation split")
428
452
 
429
453
  train_map, validate_map = map[:int(math.floor(len(map)*(1-self.validation)))], map[int(math.floor(len(map)*(1-self.validation))):]
430
454
  elif isinstance(self.validation, str):
@@ -527,20 +551,19 @@ class DataPrediction(Data):
527
551
  groups_src : dict[str, Group] = {"default" : Group()},
528
552
  augmentations : Union[dict[str, DataAugmentationsList], None] = {"DataAugmentation_0" : DataAugmentationsList()},
529
553
  patch : Union[DatasetPatch, None] = DatasetPatch(),
530
- use_cache : bool = True,
531
554
  subset : Union[PredictionSubset, dict[str, PredictionSubset]] = PredictionSubset(),
532
555
  num_workers : int = 4,
533
556
  batch_size : int = 1) -> None:
534
557
 
535
- super().__init__(dataset_filenames, groups_src, patch, use_cache, subset, num_workers, batch_size, dataAugmentationsList=augmentations if augmentations else {})
558
+ super().__init__(dataset_filenames, groups_src, patch, False, subset, num_workers, batch_size, dataAugmentationsList=augmentations if augmentations else {})
536
559
 
537
560
  class DataMetric(Data):
538
561
 
539
562
  @config("Dataset")
540
563
  def __init__(self, dataset_filenames : list[str] = ["default:./Dataset"],
541
- groups_src : dict[str, Group] = {"default" : Group()},
564
+ groups_src : dict[str, GroupMetric] = {"default" : GroupMetric()},
542
565
  subset : Union[PredictionSubset, dict[str, PredictionSubset]] = PredictionSubset(),
543
566
  validation: Union[str, None] = None,
544
567
  num_workers : int = 4) -> None:
545
568
 
546
- super().__init__(dataset_filenames=dataset_filenames, groups_src=groups_src, patch=None, use_cache=False, subset=subset, num_workers=num_workers, batch_size=1, validation=1 if validation is None else validation)
569
+ super().__init__(dataset_filenames=dataset_filenames, groups_src=groups_src, patch=None, use_cache=False, subset=subset, num_workers=num_workers, batch_size=1, validation=validation)
konfai/data/patching.py CHANGED
@@ -15,7 +15,6 @@ from konfai.utils.dataset import Dataset, Attribute
15
15
  from konfai.data.transform import Transform, Save
16
16
  from konfai.data.augmentation import DataAugmentationsList
17
17
 
18
-
19
18
  class PathCombine(ABC):
20
19
 
21
20
  def __init__(self) -> None:
@@ -240,6 +239,7 @@ class DatasetManager():
240
239
  self.index = index
241
240
  self.dataset = dataset
242
241
  self.loaded = False
242
+ self.augmentationLoaded = False
243
243
  self.cache_attributes: list[Attribute] = []
244
244
  _shape, cache_attribute = self.dataset.getInfos(self.group_src, name)
245
245
  self.cache_attributes.append(cache_attribute)
@@ -258,6 +258,7 @@ class DatasetManager():
258
258
  self.cache_attributes_bak = copy.deepcopy(self.cache_attributes)
259
259
 
260
260
  def resetAugmentation(self):
261
+ self.cache_attributes[:] = self.cache_attributes[:1]
261
262
  i = 1
262
263
  for dataAugmentations in self.dataAugmentationsList:
263
264
  shape = []
@@ -272,15 +273,24 @@ class DatasetManager():
272
273
  self.cache_attributes.append(caches_attribute[it])
273
274
  self.patch.load(s, i)
274
275
  i+=1
275
-
276
+
276
277
  def load(self, pre_transform : list[Transform], dataAugmentationsList : list[DataAugmentationsList], device: torch.device) -> None:
277
- if self.loaded:
278
- return
278
+ if not self.loaded:
279
+ self._load(pre_transform)
280
+ if not self.augmentationLoaded:
281
+ self._loadAugmentation(dataAugmentationsList, device)
282
+
283
+ def _load(self, pre_transform : list[Transform]):
284
+ self.cache_attributes = copy.deepcopy(self.cache_attributes_bak)
279
285
  i = len(pre_transform)
280
286
  data = None
281
287
  for transformFunction in reversed(pre_transform):
282
288
  if isinstance(transformFunction, Save):
283
- filename, format = transformFunction.save.split(":")
289
+ if len(transformFunction.dataset.split(":")) > 1:
290
+ filename, format = transformFunction.dataset.split(":")
291
+ else:
292
+ filename = transformFunction.dataset.split(":")
293
+ format = "mha"
284
294
  dataset = Dataset(filename, format)
285
295
  if dataset.isDatasetExist(self.group_dest, self.name):
286
296
  data, attrib = dataset.readData(self.group_dest, self.name)
@@ -298,27 +308,40 @@ class DatasetManager():
298
308
  for transformFunction in pre_transform[i:]:
299
309
  data = transformFunction(self.name, data, self.cache_attributes[0])
300
310
  if isinstance(transformFunction, Save):
301
- filename, format = transformFunction.save.split(":")
311
+ if len(transformFunction.dataset.split(":")) > 1:
312
+ filename, format = transformFunction.dataset.split(":")
313
+ else:
314
+ filename = transformFunction.dataset.split(":")
315
+ format = "mha"
302
316
  dataset = Dataset(filename, format)
303
317
  dataset.write(self.group_dest, self.name, data.numpy(), self.cache_attributes[0])
304
318
  self.data : list[torch.Tensor] = list()
305
319
  self.data.append(data)
306
320
 
321
+ for i in range(len(self.cache_attributes)-1):
322
+ self.cache_attributes[i+1].update(self.cache_attributes[0])
323
+ self.loaded = True
324
+
325
+ def _loadAugmentation(self, dataAugmentationsList : list[DataAugmentationsList], device: torch.device) -> None:
307
326
  for dataAugmentations in dataAugmentationsList:
308
- a_data = [data.clone() for _ in range(dataAugmentations.nb)]
327
+ a_data = [self.data[0].clone() for _ in range(dataAugmentations.nb)]
309
328
  for dataAugmentation in dataAugmentations.dataAugmentations:
310
329
  a_data = dataAugmentation(self.index, a_data, device)
311
330
 
312
331
  for d in a_data:
313
332
  self.data.append(d)
314
- self.loaded = True
333
+ self.augmentationLoaded = True
315
334
 
316
335
  def unload(self) -> None:
317
- if hasattr(self, "data"):
318
- del self.data
319
- self.cache_attributes = copy.deepcopy(self.cache_attributes_bak)
336
+ self.data.clear()
320
337
  self.loaded = False
321
-
338
+ self.augmentationLoaded = False
339
+
340
+ def unloadAugmentation(self) -> None:
341
+ self.data[:] = self.data[:1]
342
+ self.augmentationLoaded = False
343
+
344
+
322
345
  def getData(self, index : int, a : int, post_transforms : list[Transform], isInput: bool) -> torch.Tensor:
323
346
  data = self.patch.getData(self.data[a], index, a, isInput)
324
347
  for transformFunction in post_transforms:
konfai/data/transform.py CHANGED
@@ -6,7 +6,7 @@ from abc import ABC, abstractmethod
6
6
  import torch.nn.functional as F
7
7
  from typing import Any, Union
8
8
 
9
- from konfai.utils.utils import _getModule, NeedDevice, _resample_affine, _affine_matrix
9
+ from konfai.utils.utils import _getModule, NeedDevice, _resample_affine, _affine_matrix, TransformError
10
10
  from konfai.utils.dataset import Dataset, Attribute, data_to_image, image_to_data
11
11
  from konfai.utils.config import config
12
12
 
@@ -211,39 +211,67 @@ class Resample(Transform, ABC):
211
211
  _ = cache_attribute.pop_np_array("Spacing")
212
212
  return self._resample(input, [int(size) for size in size_1])
213
213
 
214
- class ResampleIsotropic(Resample):
214
+ class ResampleToResolution(Resample):
215
+
216
+ def __init__(self, spacing : list[Union[float, None]] = [1., 1., 1.]) -> None:
217
+ self.spacing = torch.tensor([0 if s < 0 else s for s in spacing])
215
218
 
216
- def __init__(self, spacing : list[float] = [1., 1., 1.]) -> None:
217
- self.spacing = torch.tensor(spacing, dtype=torch.float64)
218
-
219
219
  def transformShape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
220
- assert "Spacing" in cache_attribute, "Error no spacing"
221
- resize_factor = self.spacing/cache_attribute.get_tensor("Spacing").flip(0)
222
- return [int(x) for x in (torch.tensor(shape) * 1/resize_factor)]
220
+ if "Spacing" not in cache_attribute:
221
+ TransformError("Missing 'Spacing' in cache attributes, the data is likely not a valid image.",
222
+ "Make sure your input is a image (e.g., .nii, .mha) with proper metadata.")
223
+ if len(shape) != len(self.spacing):
224
+ TransformError("Shape and spacing dimensions do not match: shape={shape}, spacing={self.spacing}")
225
+ image_spacing = cache_attribute.get_tensor("Spacing").flip(0)
226
+ spacing = self.spacing
227
+
228
+ for i, s in enumerate(self.spacing):
229
+ if s == 0:
230
+ spacing[i] = image_spacing[i]
231
+ resize_factor = spacing/cache_attribute.get_tensor("Spacing").flip(0)
232
+ return [int(x) for x in (torch.tensor(shape) * 1/resize_factor)]
223
233
 
224
234
  def __call__(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
225
- assert "Spacing" in cache_attribute, "Error no spacing"
226
- resize_factor = self.spacing/cache_attribute.get_tensor("Spacing").flip(0)
227
- cache_attribute["Spacing"] = self.spacing.flip(0)
235
+ image_spacing = cache_attribute.get_tensor("Spacing").flip(0)
236
+ spacing = self.spacing
237
+ for i, s in enumerate(self.spacing):
238
+ if s == 0:
239
+ spacing[i] = image_spacing[i]
240
+ resize_factor = spacing/cache_attribute.get_tensor("Spacing").flip(0)
241
+ cache_attribute["Spacing"] = spacing.flip(0)
228
242
  cache_attribute["Size"] = np.asarray([int(x) for x in torch.tensor(input.shape[1:])])
229
243
  size = [int(x) for x in (torch.tensor(input.shape[1:]) * 1/resize_factor)]
230
244
  cache_attribute["Size"] = np.asarray(size)
231
245
  return self._resample(input, size)
232
246
 
233
- class ResampleResize(Resample):
247
+ class ResampleToSize(Resample):
234
248
 
235
249
  def __init__(self, size : list[int] = [100,512,512]) -> None:
236
250
  self.size = size
237
251
 
238
252
  def transformShape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
239
- return self.size
253
+ if "Spacing" not in cache_attribute:
254
+ TransformError("Missing 'Spacing' in cache attributes, the data is likely not a valid image.",
255
+ "Make sure your input is a image (e.g., .nii, .mha) with proper metadata.")
256
+ if len(shape) != len(self.size):
257
+ TransformError("Shape and spacing dimensions do not match: shape={shape}, spacing={self.spacing}")
258
+ size = self.size
259
+ for i, s in enumerate(self.size):
260
+ if s == -1:
261
+ size[i] = shape[i]
262
+ return size
240
263
 
241
264
  def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
265
+ size = self.size
266
+ image_size = np.asarray([int(x) for x in torch.tensor(input.shape[1:])])
267
+ for i, s in enumerate(self.size):
268
+ if s is None:
269
+ size[i] = image_size[i]
242
270
  if "Spacing" in cache_attribute:
243
- cache_attribute["Spacing"] = torch.flip(torch.tensor(list(input.shape[1:]))/torch.tensor(self.size)*torch.flip(cache_attribute.get_tensor("Spacing"), dims=[0]), dims=[0])
244
- cache_attribute["Size"] = np.asarray([int(x) for x in torch.tensor(input.shape[1:])])
245
- cache_attribute["Size"] = self.size
246
- return self._resample(input, self.size)
271
+ cache_attribute["Spacing"] = torch.flip(torch.tensor(image_size)/torch.tensor(size)*torch.flip(cache_attribute.get_tensor("Spacing"), dims=[0]), dims=[0])
272
+ cache_attribute["Size"] = image_size
273
+ cache_attribute["Size"] = size
274
+ return self._resample(input, size)
247
275
 
248
276
  class ResampleTransform(Transform):
249
277
 
@@ -412,8 +440,8 @@ class FlatLabel(Transform):
412
440
 
413
441
  class Save(Transform):
414
442
 
415
- def __init__(self, save: str) -> None:
416
- self.save = save
443
+ def __init__(self, dataset: str) -> None:
444
+ self.dataset = dataset
417
445
 
418
446
  def __call__(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
419
447
  return input
@@ -528,7 +556,7 @@ class OneHot(Transform):
528
556
  self.num_classes = num_classes
529
557
 
530
558
  def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
531
- result = F.one_hot(input.type(torch.int64), num_classes=self.num_classes).permute(0, len(input.shape), *[i+1 for i in range(len(input.shape)-1)]).float().squeeze(2)
559
+ result = F.one_hot(input.type(torch.int64), num_classes=self.num_classes).permute(0, len(input.shape), *[i+1 for i in range(len(input.shape)-1)]).float().squeeze(0)
532
560
  return result
533
561
 
534
562
  def inverse(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
konfai/evaluator.py CHANGED
@@ -9,7 +9,7 @@ import builtins
9
9
  import importlib
10
10
  from konfai import EVALUATIONS_DIRECTORY, PREDICTIONS_DIRECTORY, KONFAI_ROOT, CONFIG_FILE
11
11
  from konfai.utils.config import config
12
- from konfai.utils.utils import _getModule, DistributedObject, synchronize_data
12
+ from konfai.utils.utils import _getModule, DistributedObject, synchronize_data, EvaluatorError
13
13
  from konfai.data.data_manager import DataMetric
14
14
 
15
15
  class CriterionsAttr():
@@ -54,7 +54,8 @@ class Statistics():
54
54
  if name_dataset not in self.measures:
55
55
  self.measures[name_dataset] = {}
56
56
  self.measures[name_dataset][name] = value
57
-
57
+
58
+ @staticmethod
58
59
  def getStatistic(values: list[float]) -> dict[str, float]:
59
60
  return {"max": np.max(values), "min": np.min(values), "std": np.std(values), "25pc": np.percentile(values, 25), "50pc": np.percentile(values, 50), "75pc": np.percentile(values, 75), "mean": np.mean(values), "count": len(values)}
60
61
 
@@ -91,7 +92,6 @@ class Evaluator(DistributedObject):
91
92
  exit(0)
92
93
  super().__init__(train_name)
93
94
  self.metric_path = EVALUATIONS_DIRECTORY()+self.name+"/"
94
- self.predict_path = PREDICTIONS_DIRECTORY()+self.name+"/"
95
95
  self.metricsLoader = metrics
96
96
  self.dataset = dataset
97
97
  self.metrics = {k: v.getTargetsCriterions(k) for k, v in self.metricsLoader.items()}
@@ -102,10 +102,10 @@ class Evaluator(DistributedObject):
102
102
  result = {}
103
103
  for output_group in self.metrics:
104
104
  for target_group in self.metrics[output_group]:
105
- targets = [data_dict[group][0] for group in target_group.split("/") if group in data_dict]
105
+ targets = [data_dict[group][0].to(0) if torch.cuda.is_available() else data_dict[group][0] for group in target_group.split("/") if group in data_dict]
106
106
  name = data_dict[output_group][1][0]
107
107
  for metric in self.metrics[output_group][target_group]:
108
- result["{}:{}:{}".format(output_group, target_group, metric.__class__.__name__)] = metric(data_dict[output_group][0], *targets).item()
108
+ result["{}:{}:{}".format(output_group, target_group, metric.__class__.__name__)] = metric(data_dict[output_group][0].to(0) if torch.cuda.is_available() else data_dict[output_group][0], *targets).item()
109
109
  statistics.add(result, name)
110
110
  return result
111
111
 
@@ -126,9 +126,26 @@ class Evaluator(DistributedObject):
126
126
 
127
127
  self.dataloader = self.dataset.getData(world_size)
128
128
 
129
+ groupsDest = [group for groups in self.dataset.groups_src.values() for group in groups]
130
+
131
+ missing_outputs = set(self.metrics.keys()) - set(groupsDest)
132
+ if missing_outputs:
133
+ raise EvaluatorError(
134
+ f"The following metric output groups are missing from 'groupsDest': {sorted(missing_outputs)}. ",
135
+ f"Available groups: {sorted(groupsDest)}"
136
+ )
137
+
138
+ target_groups = {target for targets in self.metrics.values() for target in targets}
139
+ missing_targets = target_groups - set(groupsDest)
140
+ if missing_targets:
141
+ raise EvaluatorError(
142
+ f"The following metric target groups are missing from 'groupsDest': {sorted(missing_targets)}. ",
143
+ f"Available groups: {sorted(groupsDest)}"
144
+ )
145
+
129
146
  def run_process(self, world_size: int, global_rank: int, gpu: int, dataloaders: list[DataLoader]):
130
147
  description = lambda measure : "Metric TRAIN : {} ".format(" | ".join("{}: {:.2f}".format(k, v) for k, v in measure.items()) if measure is not None else "")
131
- with tqdm.tqdm(iterable = enumerate(dataloaders[0]), leave=False, desc = description(None), total=len(dataloaders[0])) as batch_iter:
148
+ with tqdm.tqdm(iterable = enumerate(dataloaders[0]), leave=True, desc = description(None), total=len(dataloaders[0]), ncols=0) as batch_iter:
132
149
  for _, data_dict in batch_iter:
133
150
  batch_iter.set_description(description(self.update({k: (v[0], v[4]) for k,v in data_dict.items()}, self.statistics_train)))
134
151
  outputs = synchronize_data(world_size, gpu, self.statistics_train.measures)
@@ -136,7 +153,7 @@ class Evaluator(DistributedObject):
136
153
  self.statistics_train.write(outputs)
137
154
  if len(dataloaders) == 2:
138
155
  description = lambda measure : "Metric VALIDATION : {} ".format(" | ".join("{}: {:.2f}".format(k, v) for k, v in measure.items()) if measure is not None else "")
139
- with tqdm.tqdm(iterable = enumerate(dataloaders[1]), leave=False, desc = description(None), total=len(dataloaders[1])) as batch_iter:
156
+ with tqdm.tqdm(iterable = enumerate(dataloaders[1]), leave=True, desc = description(None), total=len(dataloaders[1]), ncols=0) as batch_iter:
140
157
  for _, data_dict in batch_iter:
141
158
  batch_iter.set_description(description(self.update({k: (v[0], v[4]) for k,v in data_dict.items()}, self.statistics_validation)))
142
159
  outputs = synchronize_data(world_size, gpu, self.statistics_validation.measures)
konfai/main.py CHANGED
@@ -3,6 +3,7 @@ import os
3
3
  from torch.cuda import device_count
4
4
  import torch.multiprocessing as mp
5
5
  from konfai.utils.utils import setup, TensorBoard, Log
6
+ from konfai import KONFAI_NB_CORES
6
7
 
7
8
  import sys
8
9
  sys.path.insert(0, os.getcwd())
@@ -11,10 +12,10 @@ def main():
11
12
  parser = argparse.ArgumentParser(description="KonfAI", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
12
13
  try:
13
14
  with setup(parser) as distributedObject:
14
- with Log(distributedObject.name):
15
+ with Log(distributedObject.name, 0):
15
16
  world_size = device_count()
16
17
  if world_size == 0:
17
- world_size = 1
18
+ world_size = int(KONFAI_NB_CORES())
18
19
  distributedObject.setup(world_size)
19
20
  with TensorBoard(distributedObject.name):
20
21
  mp.spawn(distributedObject, nprocs=world_size)
konfai/network/network.py CHANGED
@@ -39,7 +39,6 @@ class OptimizerLoader():
39
39
  self.name = name
40
40
 
41
41
  def getOptimizer(self, key: str, parameter: Iterator[torch.nn.parameter.Parameter]) -> torch.optim.Optimizer:
42
- torch.optim.AdamW
43
42
  return config("{}.Model.{}.Optimizer".format(KONFAI_ROOT(), key))(getattr(importlib.import_module('torch.optim'), self.name))(parameter, config = None)
44
43
 
45
44
  class SchedulerStep():
konfai/predictor.py CHANGED
@@ -94,10 +94,10 @@ class OutDataset(Dataset, NeedDevice, ABC):
94
94
  class OutSameAsGroupDataset(OutDataset):
95
95
 
96
96
  @config("OutDataset")
97
- def __init__(self, dataset_filename: str = "./Dataset:mha", group: str = "default", sameAsGroup: str = "default", pre_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, post_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, final_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, patchCombine: Union[str, None] = None, redution: str = "mean", inverse_transform: bool = True) -> None:
97
+ def __init__(self, dataset_filename: str = "./Dataset:mha", group: str = "default", sameAsGroup: str = "default", pre_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, post_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, final_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, patchCombine: Union[str, None] = None, reduction: str = "mean", inverse_transform: bool = True) -> None:
98
98
  super().__init__(dataset_filename, group, pre_transforms, post_transforms, final_transforms, patchCombine)
99
99
  self.group_src, self.group_dest = sameAsGroup.split(":")
100
- self.redution = redution
100
+ self.reduction = reduction
101
101
  self.inverse_transform = inverse_transform
102
102
 
103
103
  def addLayer(self, index_dataset: int, index_augmentation: int, index_patch: int, layer: torch.Tensor, dataset: DatasetIter):
@@ -151,9 +151,9 @@ class OutSameAsGroupDataset(OutDataset):
151
151
  self.output_layer_accumulator.pop(index)
152
152
  dtype = result.dtype
153
153
 
154
- if self.redution == "mean":
154
+ if self.reduction == "mean":
155
155
  result = torch.mean(result.float(), dim=0).to(dtype)
156
- elif self.redution == "median":
156
+ elif self.reduction == "median":
157
157
  result, _ = torch.median(result.float(), dim=0)
158
158
  else:
159
159
  raise NameError("Reduction method does not exist (mean, median)")
@@ -201,7 +201,7 @@ class OutDatasetLoader():
201
201
 
202
202
  class _Predictor():
203
203
 
204
- def __init__(self, world_size: int, global_rank: int, local_rank: int, predict_path: str, data_log: Union[list[str], None], outsDataset: dict[str, OutDataset], modelComposite: DDP, dataloader_prediction: DataLoader) -> None:
204
+ def __init__(self, world_size: int, global_rank: int, local_rank: int, autocast: bool, predict_path: str, data_log: Union[list[str], None], outsDataset: dict[str, OutDataset], modelComposite: DDP, dataloader_prediction: DataLoader) -> None:
205
205
  self.world_size = world_size
206
206
  self.global_rank = global_rank
207
207
  self.local_rank = local_rank
@@ -209,7 +209,7 @@ class _Predictor():
209
209
  self.modelComposite = modelComposite
210
210
  self.dataloader_prediction = dataloader_prediction
211
211
  self.outsDataset = outsDataset
212
-
212
+ self.autocast = autocast
213
213
 
214
214
  self.it = 0
215
215
 
@@ -239,21 +239,22 @@ class _Predictor():
239
239
  self.modelComposite.eval()
240
240
  self.modelComposite.module.setState(NetState.PREDICTION)
241
241
  desc = lambda : "Prediction : {}".format(description(self.modelComposite))
242
- self.dataloader_prediction.dataset.load()
243
- with tqdm.tqdm(iterable = enumerate(self.dataloader_prediction), leave=False, desc = desc(), total=len(self.dataloader_prediction), disable=self.global_rank != 0 and "KONFAI_CLUSTER" not in os.environ) as batch_iter:
242
+ self.dataloader_prediction.dataset.load("Prediction")
243
+ with tqdm.tqdm(iterable = enumerate(self.dataloader_prediction), leave=True, desc = desc(), total=len(self.dataloader_prediction), ncols=0) as batch_iter:
244
244
  dist.barrier()
245
245
  for it, data_dict in batch_iter:
246
- input = self.getInput(data_dict)
247
- for name, output in self.modelComposite(input, list(self.outsDataset.keys())):
248
- self._predict_log(data_dict)
249
- outDataset = self.outsDataset[name]
250
- for i, (index, patch_augmentation, patch_index) in enumerate([(int(index), int(patch_augmentation), int(patch_index)) for index, patch_augmentation, patch_index in zip(list(data_dict.values())[0][1], list(data_dict.values())[0][2], list(data_dict.values())[0][3])]):
251
- outDataset.addLayer(index, patch_augmentation, patch_index, output[i].cpu(), self.dataset)
252
- if outDataset.isDone(index):
253
- outDataset.write(index, self.dataset.getDatasetFromIndex(list(data_dict.keys())[0], index).name.split("/")[-1], outDataset.getOutput(index, self.dataset))
254
-
255
- batch_iter.set_description(desc())
256
- self.it += 1
246
+ with torch.amp.autocast('cuda', enabled=self.autocast):
247
+ input = self.getInput(data_dict)
248
+ for name, output in self.modelComposite(input, list(self.outsDataset.keys())):
249
+ self._predict_log(data_dict)
250
+ outDataset = self.outsDataset[name]
251
+ for i, (index, patch_augmentation, patch_index) in enumerate([(int(index), int(patch_augmentation), int(patch_index)) for index, patch_augmentation, patch_index in zip(list(data_dict.values())[0][1], list(data_dict.values())[0][2], list(data_dict.values())[0][3])]):
252
+ outDataset.addLayer(index, patch_augmentation, patch_index, output[i].cpu(), self.dataset)
253
+ if outDataset.isDone(index):
254
+ outDataset.write(index, self.dataset.getDatasetFromIndex(list(data_dict.keys())[0], index).name.split("/")[-1], outDataset.getOutput(index, self.dataset))
255
+
256
+ batch_iter.set_description(desc())
257
+ self.it += 1
257
258
 
258
259
  def _predict_log(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]):
259
260
  measures = DistributedObject.getMeasure(self.world_size, self.global_rank, self.local_rank, {"" : self.modelComposite.module}, 1)
@@ -320,6 +321,7 @@ class Predictor(DistributedObject):
320
321
  train_name: str = "name",
321
322
  manual_seed : Union[int, None] = None,
322
323
  gpu_checkpoints: Union[list[str], None] = None,
324
+ autocast : bool = False,
323
325
  outsDataset: Union[dict[str, OutDatasetLoader], None] = {"default:Default" : OutDatasetLoader()},
324
326
  images_log: list[str] = []) -> None:
325
327
  if os.environ["KONFAI_CONFIG_MODE"] != "Done":
@@ -328,6 +330,7 @@ class Predictor(DistributedObject):
328
330
  self.manual_seed = manual_seed
329
331
  self.dataset = dataset
330
332
  self.combine = combine
333
+ self.autocast = autocast
331
334
 
332
335
  self.model = model.getModel(train=False)
333
336
  self.it = 0
@@ -384,7 +387,7 @@ class Predictor(DistributedObject):
384
387
 
385
388
  shutil.copyfile(CONFIG_FILE(), self.predict_path+"Prediction.yml")
386
389
 
387
- self.model.init(autocast=False, state = State.PREDICTION)
390
+ self.model.init(self.autocast, State.PREDICTION, self.dataset.getGroupsDest())
388
391
  self.model.init_outputsGroup()
389
392
  self.model._compute_channels_trace(self.model, self.model.in_channels, None, self.gpu_checkpoints)
390
393
  self.modelComposite = ModelComposite(self.model, len(MODEL().split(":")), self.combine)
@@ -402,7 +405,7 @@ class Predictor(DistributedObject):
402
405
  def run_process(self, world_size: int, global_rank: int, local_rank: int, dataloaders: list[DataLoader]):
403
406
  modelComposite = Network.to(self.modelComposite, local_rank*self.size)
404
407
  modelComposite = DDP(modelComposite, static_graph=True) if torch.cuda.is_available() else CPU_Model(modelComposite)
405
- with _Predictor(world_size, global_rank, local_rank, self.predict_path, self.images_log, self.outsDataset, modelComposite, *dataloaders) as p:
408
+ with _Predictor(world_size, global_rank, local_rank, self.autocast, self.predict_path, self.images_log, self.outsDataset, modelComposite, *dataloaders) as p:
406
409
  p.run()
407
410
 
408
411
 
konfai/trainer.py CHANGED
@@ -17,7 +17,6 @@ from konfai.utils.config import config
17
17
  from konfai.utils.utils import State, DataLog, DistributedObject, description, TrainerError
18
18
  from konfai.network.network import Network, ModelLoader, NetState, CPU_Model
19
19
 
20
-
21
20
  class EarlyStoppingBase:
22
21
 
23
22
  def __init__(self):
@@ -99,7 +98,7 @@ class _Trainer():
99
98
  self.autocast = autocast
100
99
  self.modelEMA = modelEMA
101
100
  self.early_stopping = EarlyStoppingBase() if early_stopping is None else early_stopping
102
-
101
+
103
102
  self.it_validation = it_validation
104
103
  if self.it_validation is None:
105
104
  self.it_validation = len(dataloader_training)
@@ -118,13 +117,14 @@ class _Trainer():
118
117
  self.tb.close()
119
118
 
120
119
  def run(self) -> None:
121
- with tqdm.tqdm(iterable = range(self.epoch, self.epochs), leave=False, total=self.epochs, initial=self.epoch, desc="Progress", disable=self.global_rank != 0) as epoch_tqdm:
120
+ self.dataloader_training.dataset.load("Train")
121
+ self.dataloader_validation.dataset.load("Validation")
122
+ with tqdm.tqdm(iterable = range(self.epoch, self.epochs), leave=False, total=self.epochs, initial=self.epoch, desc="Progress") as epoch_tqdm:
122
123
  for self.epoch in epoch_tqdm:
123
- self.dataloader_training.dataset.load()
124
124
  self.train()
125
125
  if self.early_stopping.isStopped():
126
126
  break
127
- self.dataloader_training.dataset.resetAugmentation()
127
+ self.dataloader_training.dataset.resetAugmentation("Train")
128
128
 
129
129
  def getInput(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int, str, bool]]) -> dict[tuple[str, bool], torch.Tensor]:
130
130
  return {(k, v[5][0].item()) : v[0] for k, v in data_dict.items()}
@@ -137,7 +137,8 @@ class _Trainer():
137
137
  self.modelEMA.module.setState(NetState.TRAIN)
138
138
 
139
139
  desc = lambda : "Training : {}".format(description(self.model, self.modelEMA))
140
- with tqdm.tqdm(iterable = enumerate(self.dataloader_training), desc = desc(), total=len(self.dataloader_training), leave=False, disable=self.global_rank != 0 and "KONFAI_CLUSTER" not in os.environ) as batch_iter:
140
+
141
+ with tqdm.tqdm(iterable = enumerate(self.dataloader_training), desc = desc(), total=len(self.dataloader_training), leave=False, ncols=0) as batch_iter:
141
142
  for _, data_dict in batch_iter:
142
143
  with torch.amp.autocast('cuda', enabled=self.autocast):
143
144
  input = self.getInput(data_dict)
@@ -171,8 +172,7 @@ class _Trainer():
171
172
 
172
173
  desc = lambda : "Validation : {}".format(description(self.model, self.modelEMA))
173
174
  data_dict = None
174
- self.dataloader_validation.dataset.load()
175
- with tqdm.tqdm(iterable = enumerate(self.dataloader_validation), desc = desc(), total=len(self.dataloader_validation), leave=False, disable=self.global_rank != 0 and "KONFAI_CLUSTER" not in os.environ) as batch_iter:
175
+ with tqdm.tqdm(iterable = enumerate(self.dataloader_validation), desc = desc(), total=len(self.dataloader_validation), leave=False, ncols=0) as batch_iter:
176
176
  for _, data_dict in batch_iter:
177
177
  input = self.getInput(data_dict)
178
178
  self.model(input)
@@ -180,7 +180,7 @@ class _Trainer():
180
180
  self.modelEMA.module(input)
181
181
 
182
182
  batch_iter.set_description(desc())
183
- self.dataloader_validation.dataset.resetAugmentation()
183
+ self.dataloader_validation.dataset.resetAugmentation("Validation")
184
184
  dist.barrier()
185
185
  self.model.train()
186
186
  self.model.module.setState(NetState.TRAIN)
@@ -360,7 +360,7 @@ class Trainer(DistributedObject):
360
360
  os.makedirs(dir)
361
361
 
362
362
  for name in sorted(os.listdir(path_checkpoint)):
363
- checkpoint = torch.load(path_checkpoint+name, weights_only=False)
363
+ checkpoint = torch.load(path_checkpoint+name, weights_only=False, map_location='cpu')
364
364
  self.model.load(checkpoint, init=False, ema=False)
365
365
 
366
366
  torch.save(self.model, "{}Serialized/{}".format(path_model, name))
konfai/utils/dataset.py CHANGED
@@ -348,6 +348,10 @@ class Dataset():
348
348
  @abstractmethod
349
349
  def getNames(self, group: str) -> list[str]:
350
350
  pass
351
+
352
+ @abstractmethod
353
+ def getGroup(self) -> list[str]:
354
+ pass
351
355
 
352
356
  @abstractmethod
353
357
  def isExist(self, group: str, name: Union[str, None] = None) -> bool:
@@ -458,6 +462,9 @@ class Dataset():
458
462
  names.extend(self.getNames("/".join(groups.split("/")[1:]), h5_group[group]))
459
463
  return names
460
464
 
465
+ def getGroup(self):
466
+ return self.h5.keys()
467
+
461
468
  def _getDataset(self, groups: str, name: str, h5_group: h5py.Group = None) -> h5py.Dataset:
462
469
  if h5_group is None:
463
470
  h5_group = self.h5
@@ -613,7 +620,10 @@ class Dataset():
613
620
 
614
621
  def getNames(self, group: str) -> list[str]:
615
622
  raise NotImplementedError()
616
-
623
+
624
+ def getGroup(self):
625
+ raise NotImplementedError()
626
+
617
627
  def getInfos(self, group: str, name: str) -> tuple[list[int], Attribute]:
618
628
  attributes = Attribute()
619
629
  if os.path.exists("{}{}{}.{}".format(self.filename, group if group is not None else "", name, self.format)):
@@ -736,7 +746,7 @@ class Dataset():
736
746
  subDirectories.extend(self._getSubDirectories("/".join(groups.split("/")[1:]), subDirectory))
737
747
  return subDirectories
738
748
 
739
- def getNames(self, groups: str, index: Union[list[int], None] = None, subDirectory: str = "") -> list[str]:
749
+ def getNames(self, groups: str, index: Union[list[int], None] = None) -> list[str]:
740
750
  names = []
741
751
  if self.is_directory:
742
752
  for subDirectory in self._getSubDirectories(groups):
@@ -752,6 +762,21 @@ class Dataset():
752
762
  names = file.getNames(groups)
753
763
  return [name for i, name in enumerate(sorted(names)) if index is None or i in index]
754
764
 
765
+ def getGroup(self):
766
+ if self.is_directory:
767
+ groups = set()
768
+ for root, _, files in os.walk(self.filename):
769
+ for file in files:
770
+ path = os.path.relpath(os.path.join(root, file.split(".")[0]), self.filename)
771
+ parts = path.split("/")
772
+ if len(parts) >= 2:
773
+ del parts[-2]
774
+ groups.add("/".join(parts))
775
+ else:
776
+ with Dataset.File(self.filename, True, self.format) as file:
777
+ groups = file.getGroup()
778
+ return list(groups)
779
+
755
780
  def getInfos(self, groups: str, name: str) -> tuple[list[int], Attribute]:
756
781
  if self.is_directory:
757
782
  for subDirectory in self._getSubDirectories(groups):
konfai/utils/utils.py CHANGED
@@ -18,13 +18,15 @@ import random
18
18
  from torch.utils.data import DataLoader
19
19
  import torch.nn.functional as F
20
20
  import sys
21
+ import re
22
+
21
23
 
22
24
  def description(model, modelEMA = None, showMemory: bool = True) -> str:
23
25
  values_desc = lambda weights, values: " ".join(["{}({:.2f}) : {:.6f}".format(name.split(":")[-1], weight, value) for (name, value), weight in zip(values.items(), weights.values())])
24
26
  model_desc = lambda model : "("+" ".join(["{}({:.6f}) : {}".format(name, network.optimizer.param_groups[0]['lr'] if network.optimizer is not None else 0, values_desc(network.measure.getLastWeights(), network.measure.getLastValues())) for name, network in model.module.getNetworks().items() if network.measure is not None])+")"
25
27
  result = "Loss {}".format(model_desc(model))
26
28
  if modelEMA is not None:
27
- result += "Loss EMA {}".format(model_desc(modelEMA))
29
+ result += " Loss EMA {}".format(model_desc(modelEMA))
28
30
  result += " "+gpuInfo()
29
31
  if showMemory:
30
32
  result +=" | {}".format(memoryInfo())
@@ -237,7 +239,7 @@ class DataLog(Enum):
237
239
  AUDIO = lambda tb, name, layer, it : tb.add_audio(name, _logImageFormat(layer), it)
238
240
 
239
241
  class Log:
240
- def __init__(self, name: str) -> None:
242
+ def __init__(self, name: str, rank: int) -> None:
241
243
  if KONFAI_STATE() == "PREDICTION":
242
244
  path = PREDICTIONS_DIRECTORY()
243
245
  elif KONFAI_STATE() == "EVALUATION":
@@ -248,11 +250,12 @@ class Log:
248
250
  self.verbose = os.environ.get("KONFAI_VERBOSE", "True") == "True"
249
251
  self.log_path = os.path.join(path, name)
250
252
  os.makedirs(self.log_path, exist_ok=True)
251
-
252
- self.file = open(os.path.join(self.log_path, "log.txt"), "w", buffering=1)
253
+ self.rank = rank
254
+ self.file = open(os.path.join(self.log_path, "log_{}.txt".format(rank)), "w", buffering=1)
253
255
  self.stdout_bak = sys.stdout
254
256
  self.stderr_bak = sys.stderr
255
-
257
+ self._buffered_line = ""
258
+
256
259
  def __enter__(self):
257
260
  self.file.__enter__()
258
261
  sys.stdout = self
@@ -264,12 +267,26 @@ class Log:
264
267
  sys.stdout = self.stdout_bak
265
268
  sys.stderr = self.stderr_bak
266
269
 
267
- def write(self, msg):
270
+ def write(self, msg: str):
268
271
  if not msg:
269
272
  return
270
- self.file.write(msg)
271
- self.file.flush()
272
- if self.verbose:
273
+
274
+
275
+ ANSI_ESCAPE = re.compile(r'\x1b\[[0-9;]*[a-zA-Z]')
276
+ CARRIAGE_RETURN = re.compile(r'(?:\r|\x1b\[A).*')
277
+ msg_clean = ANSI_ESCAPE.sub('', msg)
278
+ if '\r' in msg_clean or '' in msg:
279
+ # On garde seulement le contenu après le dernier retour chariot
280
+ msg_clean = msg_clean.split('\r')[-1].strip()
281
+ self._buffered_line = msg_clean
282
+ else:
283
+ self._buffered_line = msg_clean.strip()
284
+
285
+ if self._buffered_line:
286
+ # Écrit dans le fichier
287
+ self.file.write(self._buffered_line + "\n")
288
+ self.file.flush()
289
+ if self.verbose and (self.rank == 0 or "KONFAI_CLUSTER" in os.environ):
273
290
  sys.__stdout__.write(msg)
274
291
  sys.__stdout__.flush()
275
292
 
@@ -352,7 +369,7 @@ class DistributedObject():
352
369
  return result
353
370
 
354
371
  def __call__(self, rank: Union[int, None] = None) -> None:
355
- with Log(self.name):
372
+ with Log(self.name, rank):
356
373
  world_size = len(self.dataloader)
357
374
  global_rank, local_rank = setupGPU(world_size, self.port, rank)
358
375
  if global_rank is None:
@@ -385,6 +402,7 @@ def setup(parser: argparse.ArgumentParser) -> DistributedObject:
385
402
  KONFAI_args.add_argument('-tb', action='store_true', help='Start TensorBoard')
386
403
  KONFAI_args.add_argument("-c", "--config", type=str, default="None", help="Configuration file location")
387
404
  KONFAI_args.add_argument("-g", "--gpu", type=str, default=os.environ["CUDA_VISIBLE_DEVICES"] if "CUDA_VISIBLE_DEVICES" in os.environ else "", help="List of GPU")
405
+ KONFAI_args.add_argument("-cpu", "--cpu", type=str, default="1" , help="List of GPU")
388
406
  KONFAI_args.add_argument('--num-workers', '--num_workers', default=4, type=int, help='No. of workers per DataLoader & GPU')
389
407
  KONFAI_args.add_argument("-models_dir", "--MODELS_DIRECTORY", type=str, default="./Models/", help="Models location")
390
408
  KONFAI_args.add_argument("-checkpoints_dir", "--CHECKPOINTS_DIRECTORY", type=str, default="./Checkpoints/", help="Checkpoints location")
@@ -401,6 +419,7 @@ def setup(parser: argparse.ArgumentParser) -> DistributedObject:
401
419
  config = vars(args)
402
420
 
403
421
  os.environ["CUDA_VISIBLE_DEVICES"] = config["gpu"]
422
+ os.environ["KONFAI_NB_CORES"] = config["cpu"]
404
423
  os.environ["KONFAI_MODELS_DIRECTORY"] = config["MODELS_DIRECTORY"]
405
424
  os.environ["KONFAI_CHECKPOINTS_DIRECTORY"] = config["CHECKPOINTS_DIRECTORY"]
406
425
  os.environ["KONFAI_PREDICTIONS_DIRECTORY"] = config["PREDICTIONS_DIRECTORY"]
@@ -465,6 +484,14 @@ def setupGPU(world_size: int, port: int, rank: Union[int, None] = None) -> tuple
465
484
  if torch.cuda.is_available():
466
485
  torch.cuda.empty_cache()
467
486
  dist.init_process_group("nccl", rank=global_rank, init_method="tcp://{}:{}".format(host_name, port), world_size=world_size)
487
+ else:
488
+ if not dist.is_initialized():
489
+ dist.init_process_group(
490
+ backend="gloo",
491
+ init_method=f"tcp://{host_name}:{port}",
492
+ rank=global_rank,
493
+ world_size=world_size
494
+ )
468
495
  return global_rank, local_rank
469
496
 
470
497
  import socket
@@ -524,8 +551,8 @@ SUPPORTED_EXTENSIONS = [
524
551
 
525
552
  class KonfAIError(Exception):
526
553
 
527
- def __init__(self, typeError: str, message: list[str]) -> None:
528
- super().__init__("\n[{}] {}".format(typeError, message[0])+"\n→\t".join(message[1:]))
554
+ def __init__(self, typeError: str, messages: list[str]) -> None:
555
+ super().__init__("\n[{}] {}".format(typeError, messages[0])+("\n" if len(messages) > 0 else "")+"\n→\t".join(messages[1:]))
529
556
 
530
557
 
531
558
  class ConfigError(KonfAIError):
@@ -553,3 +580,13 @@ class AugmentationError(KonfAIError):
553
580
 
554
581
  def __init__(self, *message) -> None:
555
582
  super().__init__("Augmentation", message)
583
+
584
+ class EvaluatorError(KonfAIError):
585
+
586
+ def __init__(self, *message) -> None:
587
+ super().__init__("Evaluator", message)
588
+
589
+ class TransformError(KonfAIError):
590
+
591
+ def __init__(self, *message) -> None:
592
+ super().__init__("Transform", message)
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: konfai
3
- Version: 1.1.1
3
+ Version: 1.1.2
4
4
  Summary: Modular and configurable Deep Learning framework with YAML and PyTorch
5
5
  Author-email: Valentin Boussot <boussot.v@gmail.com>
6
6
  License-Expression: Apache-2.0
@@ -1,13 +1,13 @@
1
- konfai/__init__.py,sha256=fZM9EjrMcdJZE-jQAsNu3ZcWfa9Ylbly2q4l9STfpBI,795
2
- konfai/evaluator.py,sha256=TsHQNN4DhP33Sal9-exbOouZAHe7gNEk7RsBaasgkEQ,7373
3
- konfai/main.py,sha256=OPAUWees1j6NNmggyI7lM2abZctCqk1OsvQuc9FpePY,2515
4
- konfai/predictor.py,sha256=FGn6_Nt4PCdxNFoIOt8nbBjrAMTwc55uZqy6OFXUjhY,22320
5
- konfai/trainer.py,sha256=22IKMiPeHNJMGcJdVq9BH-Hsopp6xF4QSYyPabz4LGs,19604
1
+ konfai/__init__.py,sha256=YXG-wpSEXWs6Jt3BDI77V4r89gEUNX-6lxW9btj5VYI,851
2
+ konfai/evaluator.py,sha256=rAhfdRemMjzC3VoaqyQKJR0SBekuLDiLT1nhblH8RQk,8293
3
+ konfai/main.py,sha256=rTTJl-biaX4CkLNxPtqwwsrybXSDxjWlU39TE2ImU5o,2574
4
+ konfai/predictor.py,sha256=e5V7awlIRDE1NkLTR_D6fUuDg93ZFGaGLTDR71KeqM8,22549
5
+ konfai/trainer.py,sha256=NQDzp4hUKYAcDdKmeBcAE9QNFW9KGW2qV-7q_PXeWi8,19509
6
6
  konfai/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
7
7
  konfai/data/augmentation.py,sha256=ASmKWBpykLBHDB_YeTgoBTlqvJ06v5OUG7n7ugPN6NU,31718
8
- konfai/data/data_manager.py,sha256=AfZpQaHEd4QmA8Sxd6SK2LK3Nv658wlZVbivETq4Y4k,28439
9
- konfai/data/patching.py,sha256=JFgb4qgS_gqlcnFV_Rm0JvCzKQleXNZyOWxqqM8H01M,14499
10
- konfai/data/transform.py,sha256=ttALCEvvqIcwcYJTp4KCnf87e2hzvkXlwrtksfpWRUU,25064
8
+ konfai/data/data_manager.py,sha256=9VpF4CTTAnqS1Hq0csbTE-XAocRwziFssITg-SZhA4A,29603
9
+ konfai/data/patching.py,sha256=OtGNs99jKQSYmuj8B3MNGcbupKOcX5PUpMCIe0pKnX0,15636
10
+ konfai/data/transform.py,sha256=dTZh2CgfxNitHW-HEkO0jQ9GG26-Wf-YlI-qoI2mqC0,26472
11
11
  konfai/metric/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
12
12
  konfai/metric/measure.py,sha256=K7s15wNJm1_Iaav9y7Oe_UhSnbLUiybhLTVumHZ8ZaY,22001
13
13
  konfai/metric/schedulers.py,sha256=UoSr1TW_hrus3DhvOEbDefxCSUGz7lJS_8vbz0GEye8,1370
@@ -24,16 +24,16 @@ konfai/models/segmentation/NestedUNet.py,sha256=6XGizAIc4bDL8vx4AHW8BBFjUvovRYcj
24
24
  konfai/models/segmentation/UNet.py,sha256=9-g63oNqaxSlGmrD1-IVzJ-kac9QphmM7i-3rEsH23I,4218
25
25
  konfai/network/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
26
26
  konfai/network/blocks.py,sha256=RaTI0Lrvq1V-GIFei-WTUB6wlg4LydZksAyJ8DMk40M,13502
27
- konfai/network/network.py,sha256=DTROmwwPJZQllJx32GwAixfy1E23F1mtYy1bpF23TjU,46538
27
+ konfai/network/network.py,sha256=SnwQUKFTZZU8zuix5sm0vW8s0G3h6Fa8pHoIcwC984Q,46512
28
28
  konfai/utils/ITK.py,sha256=OxTieDNNYHGkn7zxJsAG-6ecRG1VYMvn1dlBbBe1DOs,13955
29
29
  konfai/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
30
30
  konfai/utils/config.py,sha256=OsCUqgNaPX56nD_MpJ2WRpiQ7xxqHonfPYaS9VGtsl8,12413
31
- konfai/utils/dataset.py,sha256=RNULLVauhYiyeQkpVhOVq4VRwI4KRyoQA2VCmOh8VbE,35525
31
+ konfai/utils/dataset.py,sha256=6ZzevdhJ7e5zlXATAVwSh9O6acKXM7gYNxkMAa5DrmM,36351
32
32
  konfai/utils/registration.py,sha256=v1srEBOcgDnHrx0YtsK6bcj0yCMH7wNeaQ3wC7gEvOw,8898
33
- konfai/utils/utils.py,sha256=GPzQQ7kBmraHytsUdPOqwW0Qu-kAHGA-wwMd4obdVLE,22018
34
- konfai-1.1.1.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
35
- konfai-1.1.1.dist-info/METADATA,sha256=mMT6fvoAbwoMkChW6G_Y2vDhrtiMss86xezEnylzLrE,2515
36
- konfai-1.1.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
37
- konfai-1.1.1.dist-info/entry_points.txt,sha256=fG82HRN5-g39ACSOCtij_I3N6EHxfYnMR0D7TI_8pW8,81
38
- konfai-1.1.1.dist-info/top_level.txt,sha256=xF470dkIlFoFqTZEOlRehKJr4WU_8OKGXrJqYm9vWKs,7
39
- konfai-1.1.1.dist-info/RECORD,,
33
+ konfai/utils/utils.py,sha256=Laq8bGc5mGKFZlJIkHxa-BrC9uR2F7MTRuL4YFAIxQY,23439
34
+ konfai-1.1.2.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
35
+ konfai-1.1.2.dist-info/METADATA,sha256=Usg7hzW1xmlHswHZb3fC2Hgl2V1Jkq43J0T5j5TDEIw,2515
36
+ konfai-1.1.2.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
37
+ konfai-1.1.2.dist-info/entry_points.txt,sha256=fG82HRN5-g39ACSOCtij_I3N6EHxfYnMR0D7TI_8pW8,81
38
+ konfai-1.1.2.dist-info/top_level.txt,sha256=xF470dkIlFoFqTZEOlRehKJr4WU_8OKGXrJqYm9vWKs,7
39
+ konfai-1.1.2.dist-info/RECORD,,
File without changes