konfai 1.1.1__py3-none-any.whl → 1.1.3__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of konfai might be problematic.

konfai/data/transform.py CHANGED
@@ -6,7 +6,7 @@ from abc import ABC, abstractmethod
 import torch.nn.functional as F
 from typing import Any, Union
 
-from konfai.utils.utils import _getModule, NeedDevice, _resample_affine, _affine_matrix
+from konfai.utils.utils import _getModule, NeedDevice, _resample_affine, _affine_matrix, TransformError
 from konfai.utils.dataset import Dataset, Attribute, data_to_image, image_to_data
 from konfai.utils.config import config
 
@@ -52,7 +52,7 @@ class Clip(Transform):
         input[torch.where(input < self.min_value)] = self.min_value
         input[torch.where(input > self.max_value)] = self.max_value
         if self.saveClip_min:
-            cache_attribute["Min"] = self .min_value
+            cache_attribute["Min"] = self.min_value
         if self.saveClip_max:
             cache_attribute["Max"] = self.max_value
         return input
@@ -211,39 +211,67 @@ class Resample(Transform, ABC):
         _ = cache_attribute.pop_np_array("Spacing")
         return self._resample(input, [int(size) for size in size_1])
 
-class ResampleIsotropic(Resample):
+class ResampleToResolution(Resample):
+
+    def __init__(self, spacing : list[Union[float, None]] = [1., 1., 1.]) -> None:
+        self.spacing = torch.tensor([0 if s < 0 else s for s in spacing])
 
-    def __init__(self, spacing : list[float] = [1., 1., 1.]) -> None:
-        self.spacing = torch.tensor(spacing, dtype=torch.float64)
-
     def transformShape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
-        assert "Spacing" in cache_attribute, "Error no spacing"
-        resize_factor = self.spacing/cache_attribute.get_tensor("Spacing").flip(0)
-        return [int(x) for x in (torch.tensor(shape) * 1/resize_factor)]
+        if "Spacing" not in cache_attribute:
+            TransformError("Missing 'Spacing' in cache attributes, the data is likely not a valid image.",
+                           "Make sure your input is a image (e.g., .nii, .mha) with proper metadata.")
+        if len(shape) != len(self.spacing):
+            TransformError("Shape and spacing dimensions do not match: shape={shape}, spacing={self.spacing}")
+        image_spacing = cache_attribute.get_tensor("Spacing").flip(0)
+        spacing = self.spacing
+
+        for i, s in enumerate(self.spacing):
+            if s == 0:
+                spacing[i] = image_spacing[i]
+        resize_factor = spacing/cache_attribute.get_tensor("Spacing").flip(0)
+        return [int(x) for x in (torch.tensor(shape) * 1/resize_factor)]
 
     def __call__(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        assert "Spacing" in cache_attribute, "Error no spacing"
-        resize_factor = self.spacing/cache_attribute.get_tensor("Spacing").flip(0)
-        cache_attribute["Spacing"] = self.spacing.flip(0)
+        image_spacing = cache_attribute.get_tensor("Spacing").flip(0)
+        spacing = self.spacing
+        for i, s in enumerate(self.spacing):
+            if s == 0:
+                spacing[i] = image_spacing[i]
+        resize_factor = spacing/cache_attribute.get_tensor("Spacing").flip(0)
+        cache_attribute["Spacing"] = spacing.flip(0)
         cache_attribute["Size"] = np.asarray([int(x) for x in torch.tensor(input.shape[1:])])
         size = [int(x) for x in (torch.tensor(input.shape[1:]) * 1/resize_factor)]
         cache_attribute["Size"] = np.asarray(size)
         return self._resample(input, size)
 
-class ResampleResize(Resample):
+class ResampleToSize(Resample):
 
     def __init__(self, size : list[int] = [100,512,512]) -> None:
         self.size = size
 
     def transformShape(self, shape: list[int], cache_attribute: Attribute) -> list[int]:
-        return self.size
+        if "Spacing" not in cache_attribute:
+            TransformError("Missing 'Spacing' in cache attributes, the data is likely not a valid image.",
+                           "Make sure your input is a image (e.g., .nii, .mha) with proper metadata.")
+        if len(shape) != len(self.size):
+            TransformError("Shape and spacing dimensions do not match: shape={shape}, spacing={self.spacing}")
+        size = self.size
+        for i, s in enumerate(self.size):
+            if s == -1:
+                size[i] = shape[i]
+        return size
 
     def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        size = self.size
+        image_size = np.asarray([int(x) for x in torch.tensor(input.shape[1:])])
+        for i, s in enumerate(self.size):
+            if s is None:
+                size[i] = image_size[i]
         if "Spacing" in cache_attribute:
-            cache_attribute["Spacing"] = torch.flip(torch.tensor(list(input.shape[1:]))/torch.tensor(self.size)*torch.flip(cache_attribute.get_tensor("Spacing"), dims=[0]), dims=[0])
-        cache_attribute["Size"] = np.asarray([int(x) for x in torch.tensor(input.shape[1:])])
-        cache_attribute["Size"] = self.size
-        return self._resample(input, self.size)
+            cache_attribute["Spacing"] = torch.flip(torch.tensor(image_size)/torch.tensor(size)*torch.flip(cache_attribute.get_tensor("Spacing"), dims=[0]), dims=[0])
+        cache_attribute["Size"] = image_size
+        cache_attribute["Size"] = size
+        return self._resample(input, size)
 
 class ResampleTransform(Transform):
 
@@ -412,8 +440,8 @@ class FlatLabel(Transform):
 
 class Save(Transform):
 
-    def __init__(self, save: str) -> None:
-        self.save = save
+    def __init__(self, dataset: str) -> None:
+        self.dataset = dataset
 
     def __call__(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
         return input
@@ -528,7 +556,7 @@ class OneHot(Transform):
         self.num_classes = num_classes
 
     def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
-        result = F.one_hot(input.type(torch.int64), num_classes=self.num_classes).permute(0, len(input.shape), *[i+1 for i in range(len(input.shape)-1)]).float().squeeze(2)
+        result = F.one_hot(input.type(torch.int64), num_classes=self.num_classes).permute(0, len(input.shape), *[i+1 for i in range(len(input.shape)-1)]).float().squeeze(0)
         return result
 
     def inverse(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
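For the `OneHot` change, `F.one_hot` appends the class dimension last and the permute moves it to position 1, so the only difference between the old and new lines is which singleton axis gets dropped. A shape walk-through on a hypothetical `(B=1, C=1, D, H, W)` label tensor:

    import torch
    import torch.nn.functional as F

    x = torch.randint(0, 3, (1, 1, 4, 5, 6))               # (B=1, C=1, D, H, W)
    one_hot = F.one_hot(x.to(torch.int64), num_classes=3)  # (1, 1, 4, 5, 6, 3): classes appended last
    permuted = one_hot.permute(0, x.dim(), *range(1, x.dim()))  # (1, 3, 1, 4, 5, 6): classes moved to dim 1

    print(permuted.squeeze(2).shape)  # torch.Size([1, 3, 4, 5, 6])  old behaviour: drop the channel axis
    print(permuted.squeeze(0).shape)  # torch.Size([3, 1, 4, 5, 6])  new behaviour: drop the batch axis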
konfai/evaluator.py CHANGED
@@ -9,7 +9,7 @@ import builtins
 import importlib
 from konfai import EVALUATIONS_DIRECTORY, PREDICTIONS_DIRECTORY, KONFAI_ROOT, CONFIG_FILE
 from konfai.utils.config import config
-from konfai.utils.utils import _getModule, DistributedObject, synchronize_data
+from konfai.utils.utils import _getModule, DistributedObject, synchronize_data, EvaluatorError
 from konfai.data.data_manager import DataMetric
 
 class CriterionsAttr():
@@ -54,7 +54,8 @@ class Statistics():
         if name_dataset not in self.measures:
             self.measures[name_dataset] = {}
         self.measures[name_dataset][name] = value
-
+
+    @staticmethod
     def getStatistic(values: list[float]) -> dict[str, float]:
         return {"max": np.max(values), "min": np.min(values), "std": np.std(values), "25pc": np.percentile(values, 25), "50pc": np.percentile(values, 50), "75pc": np.percentile(values, 75), "mean": np.mean(values), "count": len(values)}
 
@@ -91,7 +92,6 @@ class Evaluator(DistributedObject):
             exit(0)
         super().__init__(train_name)
         self.metric_path = EVALUATIONS_DIRECTORY()+self.name+"/"
-        self.predict_path = PREDICTIONS_DIRECTORY()+self.name+"/"
         self.metricsLoader = metrics
         self.dataset = dataset
         self.metrics = {k: v.getTargetsCriterions(k) for k, v in self.metricsLoader.items()}
@@ -102,10 +102,10 @@ class Evaluator(DistributedObject):
         result = {}
         for output_group in self.metrics:
             for target_group in self.metrics[output_group]:
-                targets = [data_dict[group][0] for group in target_group.split("/") if group in data_dict]
+                targets = [data_dict[group][0].to(0) if torch.cuda.is_available() else data_dict[group][0] for group in target_group.split("/") if group in data_dict]
                 name = data_dict[output_group][1][0]
                 for metric in self.metrics[output_group][target_group]:
-                    result["{}:{}:{}".format(output_group, target_group, metric.__class__.__name__)] = metric(data_dict[output_group][0], *targets).item()
+                    result["{}:{}:{}".format(output_group, target_group, metric.__class__.__name__)] = metric(data_dict[output_group][0].to(0) if torch.cuda.is_available() else data_dict[output_group][0], *targets).item()
                 statistics.add(result, name)
         return result
 
@@ -126,9 +126,26 @@ class Evaluator(DistributedObject):
 
         self.dataloader = self.dataset.getData(world_size)
 
+        groupsDest = [group for groups in self.dataset.groups_src.values() for group in groups]
+
+        missing_outputs = set(self.metrics.keys()) - set(groupsDest)
+        if missing_outputs:
+            raise EvaluatorError(
+                f"The following metric output groups are missing from 'groupsDest': {sorted(missing_outputs)}. ",
+                f"Available groups: {sorted(groupsDest)}"
+            )
+
+        target_groups = {target for targets in self.metrics.values() for target in targets}
+        missing_targets = target_groups - set(groupsDest)
+        if missing_targets:
+            raise EvaluatorError(
+                f"The following metric target groups are missing from 'groupsDest': {sorted(missing_targets)}. ",
+                f"Available groups: {sorted(groupsDest)}"
+            )
+
     def run_process(self, world_size: int, global_rank: int, gpu: int, dataloaders: list[DataLoader]):
         description = lambda measure : "Metric TRAIN : {} ".format(" | ".join("{}: {:.2f}".format(k, v) for k, v in measure.items()) if measure is not None else "")
-        with tqdm.tqdm(iterable = enumerate(dataloaders[0]), leave=False, desc = description(None), total=len(dataloaders[0])) as batch_iter:
+        with tqdm.tqdm(iterable = enumerate(dataloaders[0]), leave=True, desc = description(None), total=len(dataloaders[0]), ncols=0) as batch_iter:
             for _, data_dict in batch_iter:
                 batch_iter.set_description(description(self.update({k: (v[0], v[4]) for k,v in data_dict.items()}, self.statistics_train)))
         outputs = synchronize_data(world_size, gpu, self.statistics_train.measures)
@@ -136,7 +153,7 @@
             self.statistics_train.write(outputs)
         if len(dataloaders) == 2:
             description = lambda measure : "Metric VALIDATION : {} ".format(" | ".join("{}: {:.2f}".format(k, v) for k, v in measure.items()) if measure is not None else "")
-            with tqdm.tqdm(iterable = enumerate(dataloaders[1]), leave=False, desc = description(None), total=len(dataloaders[1])) as batch_iter:
+            with tqdm.tqdm(iterable = enumerate(dataloaders[1]), leave=True, desc = description(None), total=len(dataloaders[1]), ncols=0) as batch_iter:
                 for _, data_dict in batch_iter:
                     batch_iter.set_description(description(self.update({k: (v[0], v[4]) for k,v in data_dict.items()}, self.statistics_validation)))
             outputs = synchronize_data(world_size, gpu, self.statistics_validation.measures)
konfai/main.py CHANGED
@@ -3,25 +3,27 @@ import os
 from torch.cuda import device_count
 import torch.multiprocessing as mp
 from konfai.utils.utils import setup, TensorBoard, Log
+from konfai import KONFAI_NB_CORES
 
 import sys
 sys.path.insert(0, os.getcwd())
 
 def main():
+    import tracemalloc
+
     parser = argparse.ArgumentParser(description="KonfAI", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
     try:
         with setup(parser) as distributedObject:
-            with Log(distributedObject.name):
+            with Log(distributedObject.name, 0):
                 world_size = device_count()
                 if world_size == 0:
-                    world_size = 1
+                    world_size = int(KONFAI_NB_CORES())
                 distributedObject.setup(world_size)
                 with TensorBoard(distributedObject.name):
                     mp.spawn(distributedObject, nprocs=world_size)
     except KeyboardInterrupt:
         print("\n[KonfAI] Manual interruption (Ctrl+C)")
 
-
 def cluster():
     parser = argparse.ArgumentParser(description="KonfAI", formatter_class=argparse.ArgumentDefaultsHelpFormatter)
 
@@ -7,32 +7,33 @@ from konfai.data.patching import ModelPatch
 
 class UNetHead(network.ModuleArgsDict):
 
-    def __init__(self, in_channels: int, nb_class: int, dim: int) -> None:
+    def __init__(self, in_channels: int, nb_class: int, dim: int, level: int) -> None:
         super().__init__()
         self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels = in_channels, out_channels = nb_class, kernel_size = 1, stride = 1, padding = 0))
         self.add_module("Softmax", torch.nn.Softmax(dim=1))
         self.add_module("Argmax", blocks.ArgMax(dim=1))
+
 
 class UNetBlock(network.ModuleArgsDict):
 
-    def __init__(self, channels: list[int], nb_conv_per_stage: int, blockConfig: blocks.BlockConfig, downSampleMode: blocks.DownSampleMode, upSampleMode: blocks.UpSampleMode, attention : bool, block: type, nb_class: int, dim: int, i : int = 0, mri: bool = False) -> None:
+    def __init__(self, channels: list[int], nb_conv_per_stage: int, blockConfig: blocks.BlockConfig, downSampleMode: blocks.DownSampleMode, upSampleMode: blocks.UpSampleMode, attention : bool, block: type, nb_class: int, dim: int, i : int = 0) -> None:
         super().__init__()
         blockConfig_stride = blockConfig
         if i > 0:
             if downSampleMode != blocks.DownSampleMode.CONV_STRIDE:
                 self.add_module(downSampleMode.name, blocks.downSample(in_channels=channels[0], out_channels=channels[1], downSampleMode=downSampleMode, dim=dim))
             else:
-                blockConfig_stride = blocks.BlockConfig(blockConfig.kernel_size, (1,2,2) if mri and i > 4 else 2, blockConfig.padding, blockConfig.bias, blockConfig.activation, blockConfig.normMode)
+                blockConfig_stride = blocks.BlockConfig(blockConfig.kernel_size, 2, blockConfig.padding, blockConfig.bias, blockConfig.activation, blockConfig.normMode)
         self.add_module("DownConvBlock", block(in_channels=channels[0], out_channels=channels[1], blockConfigs=[blockConfig_stride]+[blockConfig]*(nb_conv_per_stage-1), dim=dim))
         if len(channels) > 2:
-            self.add_module("UNetBlock_{}".format(i+1), UNetBlock(channels[1:], nb_conv_per_stage, blockConfig, downSampleMode, upSampleMode, attention, block, nb_class, dim, i+1, mri=mri))
+            self.add_module("UNetBlock_{}".format(i+1), UNetBlock(channels[1:], nb_conv_per_stage, blockConfig, downSampleMode, upSampleMode, attention, block, nb_class, dim, i+1))
             self.add_module("UpConvBlock", block(in_channels=(channels[1]+channels[2]) if upSampleMode != blocks.UpSampleMode.CONV_TRANSPOSE else channels[1]*2, out_channels=channels[1], blockConfigs=[blockConfig]*nb_conv_per_stage, dim=dim))
             if nb_class > 0:
-                self.add_module("Head", UNetHead(channels[1], nb_class, dim), out_branch=[-1])
+                self.add_module("Head", UNetHead(channels[1], nb_class, dim, i), out_branch=[-1])
         if i > 0:
             if attention:
                 self.add_module("Attention", blocks.Attention(F_g=channels[1], F_l=channels[0], F_int=channels[0], dim=dim), in_branch=[1, 0], out_branch=[1])
-            self.add_module(upSampleMode.name, blocks.upSample(in_channels=channels[1], out_channels=channels[0], upSampleMode=upSampleMode, dim=dim, kernel_size=(1,2,2) if mri and i > 4 else 2, stride=(1,2,2) if mri and i > 4 else 2))
+            self.add_module(upSampleMode.name, blocks.upSample(in_channels=channels[1], out_channels=channels[0], upSampleMode=upSampleMode, dim=dim, kernel_size=2, stride=2))
             self.add_module("SkipConnection", blocks.Concat(), in_branch=[0, 1])
 
 class UNet(network.Network):
@@ -51,8 +52,6 @@ class UNet(network.Network):
                  downSampleMode: str = "MAXPOOL",
                  upSampleMode: str = "CONV_TRANSPOSE",
                  attention : bool = False,
-                 blockType: str = "Conv",
-                 mri: bool = False) -> None:
+                 blockType: str = "Conv") -> None:
         super().__init__(in_channels = channels[0], optimizer = optimizer, schedulers = schedulers, outputsCriterions = outputsCriterions, patch=patch, dim = dim)
-        self.add_module("UNetBlock_0", UNetBlock(channels, nb_conv_per_stage, blockConfig, downSampleMode=blocks.DownSampleMode._member_map_[downSampleMode], upSampleMode=blocks.UpSampleMode._member_map_[upSampleMode], attention=attention, block = blocks.ConvBlock if blockType == "Conv" else blocks.ResBlock, nb_class=nb_class, dim=dim, mri = mri))
-
+        self.add_module("UNetBlock_0", UNetBlock(channels, nb_conv_per_stage, blockConfig, downSampleMode=blocks.DownSampleMode._member_map_[downSampleMode], upSampleMode=blocks.UpSampleMode._member_map_[upSampleMode], attention=attention, block = blocks.ConvBlock if blockType == "Conv" else blocks.ResBlock, nb_class=nb_class, dim=dim))
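Dropping the `mri` flag makes down/up-sampling strides unconditionally isotropic; previously, levels deeper than 4 used an anisotropic `(1, 2, 2)` stride when `mri` was set. The before/after selection logic in isolation:

    # Stride selection before (with the mri flag) and after (isotropic only).
    def stride_before(i: int, mri: bool):
        return (1, 2, 2) if mri and i > 4 else 2

    def stride_after(i: int):
        return 2

    print(stride_before(5, mri=True))  # (1, 2, 2): preserved z-resolution at deep levels
    print(stride_after(5))             # 2: always isotropic as of 1.1.3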
konfai/network/network.py CHANGED
@@ -39,7 +39,6 @@ class OptimizerLoader():
         self.name = name
 
     def getOptimizer(self, key: str, parameter: Iterator[torch.nn.parameter.Parameter]) -> torch.optim.Optimizer:
-        torch.optim.AdamW
         return config("{}.Model.{}.Optimizer".format(KONFAI_ROOT(), key))(getattr(importlib.import_module('torch.optim'), self.name))(parameter, config = None)
 
 class SchedulerStep():
konfai/predictor.py CHANGED
@@ -91,13 +91,30 @@ class OutDataset(Dataset, NeedDevice, ABC):
         super().write(self.group, name, layer.numpy(), self.attributes[index][0][0])
         self.attributes.pop(index)
 
+class Reduction():
+
+    def __init__(self):
+        pass
+
+class ReductionMean():
+
+    def __init__(self):
+        pass
+
+class ReductionMedian():
+
+    def __init__(self):
+        pass
+
+
+
 class OutSameAsGroupDataset(OutDataset):
 
     @config("OutDataset")
-    def __init__(self, dataset_filename: str = "./Dataset:mha", group: str = "default", sameAsGroup: str = "default", pre_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, post_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, final_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, patchCombine: Union[str, None] = None, redution: str = "mean", inverse_transform: bool = True) -> None:
+    def __init__(self, dataset_filename: str = "./Dataset:mha", group: str = "default", sameAsGroup: str = "default", pre_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, post_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, final_transforms : dict[str, TransformLoader] = {"default:Normalize": TransformLoader()}, patchCombine: Union[str, None] = None, reduction: str = "mean", inverse_transform: bool = True) -> None:
         super().__init__(dataset_filename, group, pre_transforms, post_transforms, final_transforms, patchCombine)
         self.group_src, self.group_dest = sameAsGroup.split(":")
-        self.redution = redution
+        self.reduction = reduction
         self.inverse_transform = inverse_transform
 
     def addLayer(self, index_dataset: int, index_augmentation: int, index_patch: int, layer: torch.Tensor, dataset: DatasetIter):
@@ -151,9 +168,9 @@ class OutSameAsGroupDataset(OutDataset):
         self.output_layer_accumulator.pop(index)
         dtype = result.dtype
 
-        if self.redution == "mean":
+        if self.reduction == "mean":
             result = torch.mean(result.float(), dim=0).to(dtype)
-        elif self.redution == "median":
+        elif self.reduction == "median":
             result, _ = torch.median(result.float(), dim=0)
         else:
             raise NameError("Reduction method does not exist (mean, median)")
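The renamed `reduction` option collapses the stacked per-augmentation (or per-model) predictions along dim 0. Both supported modes on a toy stack:

    import torch

    # Three augmented predictions for the same volume, stacked on dim 0.
    result = torch.stack([torch.tensor([0.2, 0.9]),
                          torch.tensor([0.4, 0.8]),
                          torch.tensor([0.9, 0.7])])
    dtype = result.dtype

    mean = torch.mean(result.float(), dim=0).to(dtype)  # tensor([0.5000, 0.8000])
    median, _ = torch.median(result.float(), dim=0)     # tensor([0.4000, 0.8000])
    print(mean, median)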
@@ -201,7 +218,7 @@ class OutDatasetLoader():
 
 class _Predictor():
 
-    def __init__(self, world_size: int, global_rank: int, local_rank: int, predict_path: str, data_log: Union[list[str], None], outsDataset: dict[str, OutDataset], modelComposite: DDP, dataloader_prediction: DataLoader) -> None:
+    def __init__(self, world_size: int, global_rank: int, local_rank: int, autocast: bool, predict_path: str, data_log: Union[list[str], None], outsDataset: dict[str, OutDataset], modelComposite: DDP, dataloader_prediction: DataLoader) -> None:
         self.world_size = world_size
         self.global_rank = global_rank
         self.local_rank = local_rank
@@ -209,7 +226,7 @@ class _Predictor():
         self.modelComposite = modelComposite
         self.dataloader_prediction = dataloader_prediction
         self.outsDataset = outsDataset
-
+        self.autocast = autocast
 
         self.it = 0
 
@@ -239,21 +256,22 @@
         self.modelComposite.eval()
         self.modelComposite.module.setState(NetState.PREDICTION)
         desc = lambda : "Prediction : {}".format(description(self.modelComposite))
-        self.dataloader_prediction.dataset.load()
-        with tqdm.tqdm(iterable = enumerate(self.dataloader_prediction), leave=False, desc = desc(), total=len(self.dataloader_prediction), disable=self.global_rank != 0 and "KONFAI_CLUSTER" not in os.environ) as batch_iter:
+        self.dataloader_prediction.dataset.load("Prediction")
+        with tqdm.tqdm(iterable = enumerate(self.dataloader_prediction), leave=True, desc = desc(), total=len(self.dataloader_prediction), ncols=0) as batch_iter:
             dist.barrier()
             for it, data_dict in batch_iter:
-                input = self.getInput(data_dict)
-                for name, output in self.modelComposite(input, list(self.outsDataset.keys())):
-                    self._predict_log(data_dict)
-                    outDataset = self.outsDataset[name]
-                    for i, (index, patch_augmentation, patch_index) in enumerate([(int(index), int(patch_augmentation), int(patch_index)) for index, patch_augmentation, patch_index in zip(list(data_dict.values())[0][1], list(data_dict.values())[0][2], list(data_dict.values())[0][3])]):
-                        outDataset.addLayer(index, patch_augmentation, patch_index, output[i].cpu(), self.dataset)
-                        if outDataset.isDone(index):
-                            outDataset.write(index, self.dataset.getDatasetFromIndex(list(data_dict.keys())[0], index).name.split("/")[-1], outDataset.getOutput(index, self.dataset))
-
-                batch_iter.set_description(desc())
-                self.it += 1
+                with torch.amp.autocast('cuda', enabled=self.autocast):
+                    input = self.getInput(data_dict)
+                    for name, output in self.modelComposite(input, list(self.outsDataset.keys())):
+                        self._predict_log(data_dict)
+                        outDataset = self.outsDataset[name]
+                        for i, (index, patch_augmentation, patch_index) in enumerate([(int(index), int(patch_augmentation), int(patch_index)) for index, patch_augmentation, patch_index in zip(list(data_dict.values())[0][1], list(data_dict.values())[0][2], list(data_dict.values())[0][3])]):
+                            outDataset.addLayer(index, patch_augmentation, patch_index, output[i].cpu(), self.dataset)
+                            if outDataset.isDone(index):
+                                outDataset.write(index, self.dataset.getDatasetFromIndex(list(data_dict.keys())[0], index).name.split("/")[-1], outDataset.getOutput(index, self.dataset))
+
+                    batch_iter.set_description(desc())
+                    self.it += 1
 
     def _predict_log(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int]]):
         measures = DistributedObject.getMeasure(self.world_size, self.global_rank, self.local_rank, {"" : self.modelComposite.module}, 1)
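Prediction now honours the new `autocast` flag by wrapping the forward pass in `torch.amp.autocast`, the same pattern `_Trainer.train` already used; `enabled=False` makes the context a no-op, matching the previous hard-coded behaviour. A minimal standalone sketch:

    import torch

    model = torch.nn.Linear(8, 2)
    if torch.cuda.is_available():
        model = model.cuda()
    x = torch.randn(4, 8, device="cuda" if torch.cuda.is_available() else "cpu")

    use_autocast = True  # mirrors Predictor(autocast=True)
    with torch.no_grad():
        # enabled=False makes this a no-op; on a CPU-only machine the 'cuda'
        # autocast context warns and disables itself rather than failing.
        with torch.amp.autocast("cuda", enabled=use_autocast):
            y = model(x)
    print(y.dtype)  # torch.float16 under CUDA autocast, torch.float32 otherwise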
@@ -320,6 +338,7 @@ class Predictor(DistributedObject):
                  train_name: str = "name",
                  manual_seed : Union[int, None] = None,
                  gpu_checkpoints: Union[list[str], None] = None,
+                 autocast : bool = False,
                  outsDataset: Union[dict[str, OutDatasetLoader], None] = {"default:Default" : OutDatasetLoader()},
                  images_log: list[str] = []) -> None:
         if os.environ["KONFAI_CONFIG_MODE"] != "Done":
@@ -328,6 +347,7 @@
         self.manual_seed = manual_seed
         self.dataset = dataset
         self.combine = combine
+        self.autocast = autocast
 
         self.model = model.getModel(train=False)
         self.it = 0
@@ -384,7 +404,7 @@
 
         shutil.copyfile(CONFIG_FILE(), self.predict_path+"Prediction.yml")
 
-        self.model.init(autocast=False, state = State.PREDICTION)
+        self.model.init(self.autocast, State.PREDICTION, self.dataset.getGroupsDest())
         self.model.init_outputsGroup()
         self.model._compute_channels_trace(self.model, self.model.in_channels, None, self.gpu_checkpoints)
         self.modelComposite = ModelComposite(self.model, len(MODEL().split(":")), self.combine)
@@ -402,7 +422,7 @@
     def run_process(self, world_size: int, global_rank: int, local_rank: int, dataloaders: list[DataLoader]):
         modelComposite = Network.to(self.modelComposite, local_rank*self.size)
         modelComposite = DDP(modelComposite, static_graph=True) if torch.cuda.is_available() else CPU_Model(modelComposite)
-        with _Predictor(world_size, global_rank, local_rank, self.predict_path, self.images_log, self.outsDataset, modelComposite, *dataloaders) as p:
+        with _Predictor(world_size, global_rank, local_rank, self.autocast, self.predict_path, self.images_log, self.outsDataset, modelComposite, *dataloaders) as p:
             p.run()
 
 
konfai/trainer.py CHANGED
@@ -17,7 +17,6 @@ from konfai.utils.config import config
 from konfai.utils.utils import State, DataLog, DistributedObject, description, TrainerError
 from konfai.network.network import Network, ModelLoader, NetState, CPU_Model
 
-
 class EarlyStoppingBase:
 
     def __init__(self):
@@ -99,7 +98,7 @@ class _Trainer():
         self.autocast = autocast
         self.modelEMA = modelEMA
         self.early_stopping = EarlyStoppingBase() if early_stopping is None else early_stopping
-
+
         self.it_validation = it_validation
         if self.it_validation is None:
             self.it_validation = len(dataloader_training)
@@ -110,6 +109,8 @@
         for data in data_log:
             self.data_log[data.split("/")[0].replace(":", ".")] = (DataLog.__getitem__(data.split("/")[1]).value[0], int(data.split("/")[2]))
 
+
+
     def __enter__(self):
         return self
 
@@ -118,13 +119,14 @@
         self.tb.close()
 
     def run(self) -> None:
-        with tqdm.tqdm(iterable = range(self.epoch, self.epochs), leave=False, total=self.epochs, initial=self.epoch, desc="Progress", disable=self.global_rank != 0) as epoch_tqdm:
+        self.dataloader_training.dataset.load("Train")
+        self.dataloader_validation.dataset.load("Validation")
+        with tqdm.tqdm(iterable = range(self.epoch, self.epochs), leave=False, total=self.epochs, initial=self.epoch, desc="Progress") as epoch_tqdm:
             for self.epoch in epoch_tqdm:
-                self.dataloader_training.dataset.load()
                 self.train()
                 if self.early_stopping.isStopped():
                     break
-                self.dataloader_training.dataset.resetAugmentation()
+                self.dataloader_training.dataset.resetAugmentation("Train")
 
     def getInput(self, data_dict : dict[str, tuple[torch.Tensor, int, int, int, str, bool]]) -> dict[tuple[str, bool], torch.Tensor]:
         return {(k, v[5][0].item()) : v[0] for k, v in data_dict.items()}
@@ -137,7 +139,8 @@
             self.modelEMA.module.setState(NetState.TRAIN)
 
         desc = lambda : "Training : {}".format(description(self.model, self.modelEMA))
-        with tqdm.tqdm(iterable = enumerate(self.dataloader_training), desc = desc(), total=len(self.dataloader_training), leave=False, disable=self.global_rank != 0 and "KONFAI_CLUSTER" not in os.environ) as batch_iter:
+
+        with tqdm.tqdm(iterable = enumerate(self.dataloader_training), desc = desc(), total=len(self.dataloader_training), leave=False, ncols=0) as batch_iter:
             for _, data_dict in batch_iter:
                 with torch.amp.autocast('cuda', enabled=self.autocast):
                     input = self.getInput(data_dict)
@@ -171,8 +174,7 @@
 
         desc = lambda : "Validation : {}".format(description(self.model, self.modelEMA))
         data_dict = None
-        self.dataloader_validation.dataset.load()
-        with tqdm.tqdm(iterable = enumerate(self.dataloader_validation), desc = desc(), total=len(self.dataloader_validation), leave=False, disable=self.global_rank != 0 and "KONFAI_CLUSTER" not in os.environ) as batch_iter:
+        with tqdm.tqdm(iterable = enumerate(self.dataloader_validation), desc = desc(), total=len(self.dataloader_validation), leave=False, ncols=0) as batch_iter:
             for _, data_dict in batch_iter:
                 input = self.getInput(data_dict)
                 self.model(input)
@@ -180,7 +182,7 @@
                     self.modelEMA.module(input)
 
                 batch_iter.set_description(desc())
-        self.dataloader_validation.dataset.resetAugmentation()
+        self.dataloader_validation.dataset.resetAugmentation("Validation")
         dist.barrier()
         self.model.train()
         self.model.module.setState(NetState.TRAIN)
@@ -314,6 +316,18 @@ class Trainer(DistributedObject):
         self.ema_decay = ema_decay
         self.modelEMA : Union[torch.optim.swa_utils.AveragedModel, None] = None
         self.data_log = data_log
+
+        modules = []
+        for i,_ in self.model.named_modules():
+            modules.append(i)
+        for k in self.data_log:
+            tmp = k.split("/")[0].replace(":", ".")
+            if tmp not in self.dataset.getGroupsDest() and tmp not in modules:
+                raise TrainerError( f"Invalid key '{tmp}' in `data_log`.",
+                                    f"This key is neither a destination group from the dataset ({self.dataset.getGroupsDest()})",
+                                    f"nor a valid module name in the model ({modules}).",
+                                    "Please check your `data_log` configuration — it should reference either a model output or a dataset group.")
+
         self.gradient_checkpoints = gradient_checkpoints
         self.gpu_checkpoints = gpu_checkpoints
         self.save_checkpoint_mode = save_checkpoint_mode
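The new guard accepts a `data_log` key only if it names a dataset destination group or a registered module. `named_modules()` yields the root module under the empty string plus every nested child; the collection idiom on a toy model:

    import torch

    model = torch.nn.Sequential(torch.nn.Linear(4, 8), torch.nn.ReLU())
    modules = [name for name, _ in model.named_modules()]
    print(modules)  # ['', '0', '1']: the root module plus each registered child

    key = "decoder.head"  # hypothetical data_log key
    if key not in modules:
        raise ValueError(f"Invalid key '{key}' in data_log; valid module names: {modules}")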
@@ -360,7 +374,7 @@
             os.makedirs(dir)
 
         for name in sorted(os.listdir(path_checkpoint)):
-            checkpoint = torch.load(path_checkpoint+name, weights_only=False)
+            checkpoint = torch.load(path_checkpoint+name, weights_only=False, map_location='cpu')
             self.model.load(checkpoint, init=False, ema=False)
 
             torch.save(self.model, "{}Serialized/{}".format(path_model, name))
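`map_location='cpu'` makes the checkpoint tensors deserialize onto the CPU regardless of the device they were saved from, so re-serialization no longer requires the original GPU layout to be present:

    import torch

    torch.save({"weight": torch.nn.Linear(4, 2).weight}, "/tmp/ckpt.pt")

    # Without map_location, tensors restore to the device they were saved on
    # (and fail if it no longer exists); 'cpu' always works.
    checkpoint = torch.load("/tmp/ckpt.pt", weights_only=False, map_location="cpu")
    print(checkpoint["weight"].device)  # cpu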
konfai/utils/dataset.py CHANGED
@@ -348,6 +348,10 @@ class Dataset():
     @abstractmethod
     def getNames(self, group: str) -> list[str]:
         pass
+
+    @abstractmethod
+    def getGroup(self) -> list[str]:
+        pass
 
     @abstractmethod
     def isExist(self, group: str, name: Union[str, None] = None) -> bool:
@@ -458,6 +462,9 @@ class Dataset():
                 names.extend(self.getNames("/".join(groups.split("/")[1:]), h5_group[group]))
         return names
 
+    def getGroup(self):
+        return self.h5.keys()
+
     def _getDataset(self, groups: str, name: str, h5_group: h5py.Group = None) -> h5py.Dataset:
         if h5_group is None:
             h5_group = self.h5
@@ -613,7 +620,10 @@ class Dataset():
 
     def getNames(self, group: str) -> list[str]:
         raise NotImplementedError()
-
+
+    def getGroup(self):
+        raise NotImplementedError()
+
     def getInfos(self, group: str, name: str) -> tuple[list[int], Attribute]:
         attributes = Attribute()
         if os.path.exists("{}{}{}.{}".format(self.filename, group if group is not None else "", name, self.format)):
@@ -736,7 +746,7 @@ class Dataset():
             subDirectories.extend(self._getSubDirectories("/".join(groups.split("/")[1:]), subDirectory))
         return subDirectories
 
-    def getNames(self, groups: str, index: Union[list[int], None] = None, subDirectory: str = "") -> list[str]:
+    def getNames(self, groups: str, index: Union[list[int], None] = None) -> list[str]:
         names = []
         if self.is_directory:
             for subDirectory in self._getSubDirectories(groups):
@@ -752,6 +762,21 @@
                 names = file.getNames(groups)
         return [name for i, name in enumerate(sorted(names)) if index is None or i in index]
 
+    def getGroup(self):
+        if self.is_directory:
+            groups = set()
+            for root, _, files in os.walk(self.filename):
+                for file in files:
+                    path = os.path.relpath(os.path.join(root, file.split(".")[0]), self.filename)
+                    parts = path.split("/")
+                    if len(parts) >= 2:
+                        del parts[-2]
+                    groups.add("/".join(parts))
+        else:
+            with Dataset.File(self.filename, True, self.format) as file:
+                groups = file.getGroup()
+        return list(groups)
+
     def getInfos(self, groups: str, name: str) -> tuple[list[int], Attribute]:
         if self.is_directory:
             for subDirectory in self._getSubDirectories(groups):
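For directory-backed datasets, the new `getGroup` derives group names by taking each file's path relative to the dataset root and deleting the second-to-last component, which suggests a `<group>/<sample>/<file>` layout (an inference from the code, not documented here). The path manipulation in isolation, on a made-up path:

    import os

    root = "/data/Dataset"
    file_path = "/data/Dataset/ct/patient_001/image.mha"

    path = os.path.relpath(file_path.split(".")[0], root)  # 'ct/patient_001/image'
    parts = path.split("/")
    if len(parts) >= 2:
        del parts[-2]          # drop the per-sample directory -> ['ct', 'image']
    print("/".join(parts))     # 'ct/image'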