konfai 1.1.9__py3-none-any.whl → 1.2.1__py3-none-any.whl

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release: this version of konfai might be problematic.

konfai/data/data_manager.py CHANGED
@@ -6,6 +6,7 @@ from abc import ABC, abstractmethod
 from collections.abc import Iterator, Mapping
 from concurrent.futures import ThreadPoolExecutor, as_completed
 from functools import partial
+from typing import cast

 import numpy as np
 import torch
@@ -270,50 +271,55 @@ class Subset:
         self.subset = subset
         self.shuffle = shuffle

+    def _get_index(self, subset: str | int, names: list[str]) -> list[int]:
+        size = len(names)
+        index = []
+        if isinstance(subset, int):
+            index.append(subset)
+        elif ":" in subset:
+            r = np.clip(
+                np.asarray([int(subset.split(":")[0]), int(subset.split(":")[1])]),
+                0,
+                size,
+            )
+            index = list(range(r[0], r[1]))
+        elif os.path.exists(subset):
+            train_names = []
+            with open(subset) as f:
+                for name in f:
+                    train_names.append(name.strip())
+            index = []
+            for i, name in enumerate(names):
+                if name in train_names:
+                    index.append(i)
+        elif subset.startswith("~") and os.path.exists(subset[1:]):
+            exclude_names = []
+            with open(subset[1:]) as f:
+                for name in f:
+                    exclude_names.append(name.strip())
+            index = []
+            for i, name in enumerate(names):
+                if name not in exclude_names:
+                    index.append(i)
+        return index
+
     def __call__(self, names: list[str], infos: dict[str, tuple[list[int], Attribute]]) -> set[str]:
         names = sorted(names)
-
         size = len(names)
-        index = []
+
         if self.subset is None:
             index = list(range(0, size))
-        elif isinstance(self.subset, str):
-            if ":" in self.subset:
-                r = np.clip(
-                    np.asarray([int(self.subset.split(":")[0]), int(self.subset.split(":")[1])]),
-                    0,
-                    size,
-                )
-                index = list(range(r[0], r[1]))
-            elif os.path.exists(self.subset):
-                train_names = []
-                with open(self.subset) as f:
-                    for name in f:
-                        train_names.append(name.strip())
-                index = []
-                for i, name in enumerate(names):
-                    if name in train_names:
-                        index.append(i)
-            elif self.subset.startswith("~") and os.path.exists(self.subset[1:]):
-                exclude_names = []
-                with open(self.subset[1:]) as f:
-                    for name in f:
-                        exclude_names.append(name.strip())
-                index = []
-                for i, name in enumerate(names):
-                    if name not in exclude_names:
-                        index.append(i)
-
         elif isinstance(self.subset, list):
-            index = []
-            if len(self.subset) > 0:
-                for s in self.subset:
-                    if isinstance(s, int):
-                        index.append(s)
-                    elif isinstance(s, str):
-                        for i, name in enumerate(names):
-                            if name in self.subset:
-                                index.append(i)
+            index_set: set[int] = set()
+            for s in self.subset:
+                if len(index_set) == 0:
+                    index_set.update(set(self._get_index(s, names)))
+                else:
+                    index_set = index_set.intersection(set(self._get_index(s, names)))
+            index = list(index_set)
+            print(index)
+        else:
+            index = self._get_index(self.subset, names)
         if self.shuffle:
             index = random.sample(index, len(index))  # nosec B311
         return {names[i] for i in index}
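
Note: the refactor above folds all four accepted `subset` forms into `_get_index` (an int index, an "a:b" slice, a file of names to keep, a "~file" of names to exclude) and makes a list of such values select the intersection of its entries. A minimal standalone sketch of that intersection rule (names and selections hypothetical):

    def intersect_selections(selections: list[list[int]]) -> list[int]:
        # Mirrors the list branch above: seed with the first non-empty
        # selection, then intersect with each subsequent one.
        index_set: set[int] = set()
        for sel in selections:
            index_set = set(sel) if not index_set else index_set & set(sel)
        return sorted(index_set)

    print(intersect_selections([[1, 2, 3], [2, 3, 4]]))  # [2, 3]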
@@ -456,7 +462,7 @@ class Data(ABC):
             mappings.append(list(mapping[-offset:]) if itr + offset > len(mapping) else mapping[itr : itr + offset])
         return mappings

-    def get_data(self, world_size: int) -> list[list[DataLoader]]:
+    def get_data(self, world_size: int) -> tuple[list[list[DataLoader]], list[str], list[str]]:
         datasets: dict[str, list[tuple[str, bool]]] = {}
         if self.dataset_filenames is None or len(self.dataset_filenames) == 0:
             raise DatasetManagerError("No dataset filenames were provided")
@@ -502,8 +508,8 @@ class Data(ABC):
                 f"Group source '{group_src}' not found in any dataset.",
                 f"Dataset filenames provided: {self.dataset_filenames}",
                 f"Available groups across all datasets: "
-                "{[f'{f} {d.get_group()}' for f, d in self.datasets.items()]}"
-                f"Please check that an entry in the dataset with the name '{group_src}.{format}' exists.",
+                f"{[f'{f} {d.get_group()}' for f, d in self.datasets.items()]}\n"
+                f"Please check that an entry in the dataset with the name '{group_src}' exists.",
             )

         for group_dest in self.groups_src[group_src]:
@@ -596,8 +602,7 @@ class Data(ABC):

         data, mapping = self._get_datasets(list(subset_names), dataset_name)

-        train_mapping = mapping
-        validate_mapping = []
+        index = []
         if isinstance(self.validation, float) or isinstance(self.validation, int):
             if self.validation <= 0 or self.validation >= 1:
                 raise DatasetManagerError(
@@ -605,24 +610,16 @@ class Data(ABC):
                     f"Received: {self.validation}",
                     "Example: validation = 0.2 # for a 20% validation split",
                 )
-
-            train_mapping, validate_mapping = (
-                mapping[: int(math.floor(len(mapping) * (1 - self.validation)))],
-                mapping[int(math.floor(len(mapping) * (1 - self.validation))) :],
-            )
+            index = [m[0] for m in mapping[int(math.floor(len(mapping) * (1 - self.validation))) :]]
         elif isinstance(self.validation, str):
             if ":" in self.validation:
                 index = list(range(int(self.validation.split(":")[0]), int(self.validation.split(":")[1])))
-                train_mapping = [m for m in mapping if m[0] not in index]
-                validate_mapping = [m for m in mapping if m[0] in index]
             elif os.path.exists(self.validation):
                 validation_names = []
                 with open(self.validation) as f:
                     for name in f:
                         validation_names.append(name.strip())
                 index = [i for i, n in enumerate(subset_names) if n in validation_names]
-                train_mapping = [m for m in mapping if m[0] not in index]
-                validate_mapping = [m for m in mapping if m[0] in index]
             else:
                 raise DatasetManagerError(
                     f"Invalid string value for 'validation': '{self.validation}'",
@@ -634,25 +631,23 @@ class Data(ABC):
                     "The provided value is neither a valid slice nor a readable file.",
                     "Please fix your 'validation' setting in the configuration.",
                 )
-
         elif isinstance(self.validation, list):
-            if len(self.validation) > 0:
-                if isinstance(self.validation[0], int):
-                    train_mapping = [m for m in mapping if m[0] not in self.validation]
-                    validate_mapping = [m for m in mapping if m[0] in self.validation]
-                elif isinstance(self.validation[0], str):
-                    index = [i for i, n in enumerate(subset_names) if n in self.validation]
-                    train_mapping = [m for m in mapping if m[0] not in index]
-                    validate_mapping = [m for m in mapping if m[0] in index]
-                else:
-                    raise DatasetManagerError(
-                        "Invalid list type for 'validation': elements of type "
-                        f"'{type(self.validation[0]).__name__}' are not supported.",
-                        "Supported list element types are:",
-                        "\t• int → list of indices (e.g., [0, 1, 2])",
-                        "\t• str → list of sample names (e.g., ['patient01', 'patient02'])",
-                        f"Received list: {self.validation}",
-                    )
+            if isinstance(self.validation[0], int):
+                index = cast(list[int], self.validation)
+            elif isinstance(self.validation[0], str):
+                index = [i for i, n in enumerate(subset_names) if n in self.validation]
+            else:
+                raise DatasetManagerError(
+                    "Invalid list type for 'validation': elements of type "
+                    f"'{type(self.validation[0]).__name__}' are not supported.",
+                    "Supported list element types are:",
+                    "\t• int → list of indices (e.g., [0, 1, 2])",
+                    "\t• str → list of sample names (e.g., ['patient01', 'patient02'])",
+                    f"Received list: {self.validation}",
+                )
+        train_mapping = [m for m in mapping if m[0] not in index]
+        validate_mapping = [m for m in mapping if m[0] in index]
+
         if len(train_mapping) == 0:
             raise DatasetManagerError(
                 "No data left for training after applying the validation split.",
@@ -668,6 +663,9 @@ class Data(ABC):
                 f"Validation setting: {self.validation}",
                 "Please increase the validation size, increase the dataset, or disable validation.",
             )
+
+        validation_names = [name for i, name in enumerate(subset_names) if i in index]
+        train_names = [name for name in subset_names if name not in validation_names]
         train_mappings = Data._split(train_mapping, world_size)
         validate_mappings = Data._split(validate_mapping, world_size)

@@ -701,7 +699,7 @@ class Data(ABC):
                     **self.dataLoader_args,
                 )
             )
-        return data_loaders
+        return data_loaders, train_names, validation_names


 class DataTrain(Data):
@@ -773,7 +771,7 @@ class DataMetric(Data):
             dataset_filenames=dataset_filenames,
             groups_src=groups_src,
             patch=None,
-            use_cache=False,
+            use_cache=True,
             subset=subset,
             batch_size=1,
             validation=validation,
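
Note: `get_data` now returns the resolved train and validation sample names alongside the loaders, so callers must unpack three values. A sketch, assuming `dataset` is an already-configured `Data` subclass:

    # world_size=1: single-process loaders; the name lists mirror the split.
    dataloaders, train_names, validation_names = dataset.get_data(world_size=1)
    print(len(train_names), len(validation_names))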
konfai/data/transform.py CHANGED
@@ -1,4 +1,5 @@
 import importlib
+import tempfile
 from abc import ABC, abstractmethod
 from typing import Any

@@ -670,3 +671,26 @@ class OneHot(Transform):

     def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
         return torch.argmax(tensor, dim=1).unsqueeze(1)
+
+
+class TotalSegmentator(Transform):
+
+    def __init__(self, task: str = "total"):
+        super().__init__()
+        self.task = task
+
+    def __call__(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        from totalsegmentator.python_api import totalsegmentator
+
+        with tempfile.TemporaryDirectory() as tmpdir:
+            image = data_to_image(tensor.numpy(), cache_attribute)
+            sitk.WriteImage(image, tmpdir + "/image.nii.gz")
+            seg = totalsegmentator(tmpdir + "/image.nii.gz", tmpdir, task=self.task, skip_saving=True, quiet=True)
+            return (
+                torch.from_numpy(np.array(np.asanyarray(seg.dataobj), copy=True).astype(np.uint8, copy=False))
+                .permute(2, 1, 0)
+                .unsqueeze(0)
+            )
+
+    def inverse(self, name: str, tensor: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
+        return tensor
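
Note: the new `TotalSegmentator` transform writes the input volume to a temporary NIfTI file and runs the `totalsegmentator` Python API on it, returning a uint8 label map. A hedged usage sketch (tensor shape and attribute contents are assumptions; the package must be installed):

    # ct_tensor: CPU tensor of shape (1, D, H, W); attr: an Attribute carrying
    # the spatial metadata that data_to_image needs to rebuild the image.
    transform = TotalSegmentator(task="total")
    labels = transform("ct_volume", ct_tensor, attr)  # label map, shape (1, D, H, W)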
konfai/evaluator.py CHANGED
@@ -152,14 +152,14 @@ class Statistics:
         - mean and count
         """
         return {
-            "max": np.max(values),
-            "min": np.min(values),
-            "std": np.std(values),
-            "25pc": np.percentile(values, 25),
-            "50pc": np.percentile(values, 50),
-            "75pc": np.percentile(values, 75),
-            "mean": np.mean(values),
-            "count": len(values),
+            "max": np.nanmax(values) if np.any(~np.isnan(values)) else np.nan,
+            "min": np.nanmin(values) if np.any(~np.isnan(values)) else np.nan,
+            "std": np.nanstd(values) if np.any(~np.isnan(values)) else np.nan,
+            "25pc": np.nanpercentile(values, 25) if np.any(~np.isnan(values)) else np.nan,
+            "50pc": np.nanpercentile(values, 50) if np.any(~np.isnan(values)) else np.nan,
+            "75pc": np.nanpercentile(values, 75) if np.any(~np.isnan(values)) else np.nan,
+            "mean": np.nanmean(values) if np.any(~np.isnan(values)) else np.nan,
+            "count": np.count_nonzero(~np.isnan(values)) if np.any(~np.isnan(values)) else np.nan,
         }

     def write(self, outputs: list[dict[str, dict[str, Any]]]) -> None:
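
Note: the statistics are now NaN-aware, which matters because the per-label Dice below emits NaN for labels absent from a target. A small numpy demonstration of the difference:

    import numpy as np

    values = np.array([0.9, np.nan, 0.7])
    print(np.mean(values))                      # nan: one missing entry poisons the plain mean
    print(np.nanmean(values))                   # 0.8: NaNs ignored, as in the new code
    print(np.count_nonzero(~np.isnan(values)))  # 2: the new "count"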
@@ -261,10 +261,26 @@ class Evaluator(DistributedObject):
             ]
             name = data_dict[output_group][1][0]
             for metric in self.metrics[output_group][target_group]:
-                result[f"{output_group}:{target_group}:{metric.__class__.__name__}"] = metric(
+                loss = metric(
                     (data_dict[output_group][0].to(0) if torch.cuda.is_available() else data_dict[output_group][0]),
                     *targets,
-                ).item()
+                )
+                if isinstance(loss, tuple):
+                    true_loss = loss[1]
+                else:
+                    true_loss = loss.item()
+
+                if isinstance(true_loss, dict):
+                    loss = 0
+                    c = 0
+                    for k, v in true_loss.items():
+                        result[f"{output_group}:{target_group}:{metric.__class__.__name__}:{k}"] = v
+                        if not np.isnan(v):
+                            loss += v
+                            c += 1
+                    result[f"{output_group}:{target_group}:{metric.__class__.__name__}"] = loss / c
+                else:
+                    result[f"{output_group}:{target_group}:{metric.__class__.__name__}"] = true_loss
         statistics.add(result, name)
         return result

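Note: metrics may now return a `(loss, value)` tuple whose value can be a per-key dict; each key is recorded separately and the headline number is a NaN-aware average. The dict branch in isolation (values hypothetical):

    import numpy as np

    def aggregate(per_key: dict[int, float]) -> float:
        # Average only the non-NaN entries, as the loop above does.
        vals = [v for v in per_key.values() if not np.isnan(v)]
        return sum(vals) / len(vals)

    print(aggregate({1: 0.9, 2: float("nan"), 3: 0.7}))  # 0.8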
@@ -304,7 +320,7 @@ class Evaluator(DistributedObject):
             f"{self.metric_path}{metric_namefile_src}.yml",
         )

-        self.dataloader = self.dataset.get_data(world_size)
+        self.dataloader, _, _ = self.dataset.get_data(world_size)

         groups_dest = [group for groups in self.dataset.groups_src.values() for group in groups]

konfai/metric/measure.py CHANGED
@@ -46,7 +46,7 @@ class MaskedLoss(Criterion):
         self.loss = loss
         self.mode_image_masked = mode_image_masked

-    def get_mask(self, *targets: torch.Tensor) -> torch.Tensor:
+    def get_mask(self, targets: list[torch.Tensor]) -> torch.Tensor | None:
         result = None
         if len(targets) > 0:
             result = targets[0]
@@ -54,29 +54,35 @@ class MaskedLoss(Criterion):
             result = result * mask
         return result

-    def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
+    def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> tuple[torch.Tensor, float]:
         loss = torch.tensor(0, dtype=torch.float32).to(output.device)
-        mask = self.get_mask(targets[1:])
-        for batch in range(output.shape[0]):
-            if mask is not None:
+        true_loss = 0
+        true_nb = 0
+        mask = self.get_mask(list(targets[1:]))
+        if mask is not None:
+            for batch in range(output.shape[0]):
                 if self.mode_image_masked:
-                    for i in torch.unique(mask[batch]):
-                        if i != 0:
-                            loss += self.loss(
-                                output[batch, ...] * torch.where(mask == i, 1, 0),
-                                targets[0][batch, ...] * torch.where(mask == i, 1, 0),
-                            )
+                    loss_b = self.loss(
+                        output[batch, ...] * torch.where(mask == 1, 1, 0),
+                        targets[0][batch, ...] * torch.where(mask == 1, 1, 0),
+                    )
                 else:
-                    for i in torch.unique(mask[batch]):
-                        if i != 0:
-                            index = mask[batch, ...] == i
-                            loss += self.loss(
-                                torch.masked_select(output[batch, ...], index),
-                                torch.masked_select(targets[0][batch, ...], index),
-                            )
-            else:
-                loss += self.loss(output[batch, ...], targets[0][batch, ...])
-        return loss / output.shape[0]
+                    index = mask[batch, ...] == 1
+                    loss_b = self.loss(
+                        torch.masked_select(output[batch, ...], index),
+                        torch.masked_select(targets[0][batch, ...], index),
+                    )
+
+                loss += loss_b
+                if torch.any(mask[batch] == 1):
+                    true_loss += loss_b.item()
+                    true_nb += 1
+        else:
+            loss_tmp = self.loss(output, targets[0])
+            loss += loss_tmp
+            true_loss += loss_tmp.item()
+            true_nb += 1
+        return loss / output.shape[0], np.nan if true_nb == 0 else true_loss / true_nb


 class MSE(MaskedLoss):
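
Note: `MaskedLoss.forward` now returns a pair: a differentiable tensor for backprop and a plain float for reporting, which is NaN when a mask was supplied but matched nothing. A sketch, assuming `MSE()` is constructible with defaults (its constructor arguments are not shown in this diff):

    import torch

    output = torch.rand(2, 1, 8, 8, requires_grad=True)
    target = torch.rand(2, 1, 8, 8)
    loss, logged = MSE()(output, target)  # no mask target: unmasked branch runs
    loss.backward()                       # `loss` drives training; `logged` is for display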
@@ -137,7 +143,7 @@ class LPIPS(MaskedLoss):

     @staticmethod
     def preprocessing(tensor: torch.Tensor) -> torch.Tensor:
-        return tensor.repeat((1, 3, 1, 1))
+        return tensor.repeat((1, 3, 1, 1)).to(0)

     @staticmethod
     def _loss(loss_fn_alex, x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
@@ -159,14 +165,14 @@ class LPIPS(MaskedLoss):
     def __init__(self, model: str = "alex") -> None:
         import lpips

-        super().__init__(partial(LPIPS._loss, lpips.LPIPS(net=model)), True)
+        super().__init__(partial(LPIPS._loss, lpips.LPIPS(net=model).to(0)), True)


 class Dice(Criterion):

-    def __init__(self, smooth: float = 1e-6) -> None:
+    def __init__(self, labels: list[int] | None = None) -> None:
         super().__init__()
-        self.smooth = smooth
+        self.labels = labels

     def flatten(self, tensor: torch.Tensor) -> torch.Tensor:
         return tensor.permute((1, 0) + tuple(range(2, tensor.dim()))).contiguous().view(tensor.size(1), -1)
@@ -174,26 +180,26 @@ class Dice(Criterion):
     def dice_per_channel(self, tensor: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
         tensor = self.flatten(tensor)
         target = self.flatten(target)
-        return (2.0 * (tensor * target).sum() + self.smooth) / (tensor.sum() + target.sum() + self.smooth)
+        return (2.0 * (tensor * target).sum() + 1e-6) / (tensor.sum() + target.sum() + 1e-6)

     def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
         target = F.interpolate(targets[0], output.shape[2:], mode="nearest")
-        if output.shape[1] == 1:
-            output = (
-                F.one_hot(
-                    output.type(torch.int64),
-                    num_classes=int(torch.max(output).item() + 1),
-                )
-                .permute(0, len(target.shape), *[i + 1 for i in range(len(target.shape) - 1)])
-                .float()
-            )
-        target = (
-            F.one_hot(target.type(torch.int64), num_classes=output.shape[1])
-            .permute(0, len(target.shape), *[i + 1 for i in range(len(target.shape) - 1)])
-            .float()
-            .squeeze(2)
-        )
-        return 1 - torch.mean(self.dice_per_channel(output, target))
+        result = {}
+        loss = torch.tensor(0, dtype=torch.float32).to(output.device)
+        labels = self.labels if self.labels is not None else torch.unique(target)
+        for label in labels:
+            tp = target == label
+            if tp.any().item():
+                if output.shape[1] > 1:
+                    pp = output[:, label].unsqueeze(1)
+                else:
+                    pp = output == label
+                loss_tmp = self.dice_per_channel(pp.float(), tp.float())
+                loss += loss_tmp
+                result[label] = loss_tmp.item()
+            else:
+                result[label] = np.nan
+        return 1 - loss / len(labels), result


 class GradientImages(Criterion):
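
Note: `Dice` now scores each label independently and returns `(1 - mean_dice, per_label)`, with NaN for labels missing from the target. A toy check, assuming the class is importable and callable as-is (shapes chosen for brevity):

    import torch

    pred = torch.tensor([[[0.0, 1.0, 1.0, 2.0]]])  # (N=1, C=1, L=4) label map
    target = pred.clone()                          # perfect prediction

    loss, per_label = Dice(labels=[1, 2])(pred, target)
    # per_label ≈ {1: 1.0, 2: 1.0}; loss ≈ 0.0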
@@ -621,23 +627,14 @@ class MutualInformationLoss(torch.nn.Module):
 class IMPACTSynth(Criterion):  # Feature-Oriented Comparison for Unpaired Synthesis

     def __init__(
-        self,
-        model_name: str,
-        shape: list[int] = [0, 0],
-        in_channels: int = 1,
-        losses: dict[str, list[float]] = {"Gram": [1], "torch_nn_L1Loss": [1]},
+        self, model_name: str, shape: list[int] = [0, 0], in_channels: int = 1, weights: list[float] = [1]
     ) -> None:
         super().__init__()
         if model_name is None:
             return
         self.in_channels = in_channels
-        self.losses: dict[torch.nn.Module, list[float]] = {}
-        for loss, weights in losses.items():
-            module, name = get_module(loss, "konfai.metric.measure")
-            self.losses[
-                config(os.environ["KONFAI_CONFIG_PATH"])(getattr(importlib.import_module(module), name))(config=None)
-            ] = weights
-
+        self.loss = torch.nn.L1Loss()
+        self.weights = weights
         self.model_path = download_url(
             model_name,
             "https://huggingface.co/VBoussot/impact-torchscript-models/resolve/main/",
@@ -648,19 +645,17 @@ class IMPACTSynth(Criterion):  # Feature-Oriented Comparison for Unpaired Synthesis
         self.modules_loss: dict[str, dict[torch.nn.Module, float]] = {}

         try:
-            dummy_input = torch.zeros((1, self.in_channels, *(self.shape if self.shape else [224] * self.dim)))
+            dummy_input = torch.zeros((1, self.in_channels, *(self.shape if self.shape else [224] * self.dim))).to(0)
             out = self.model(dummy_input)
             if not isinstance(out, (list, tuple)):
                 raise TypeError(f"Expected model output to be a list or tuple, but got {type(out)}.")
-            if len(self.weight) != len(out):
-                raise ValueError(
-                    f"Mismatch between number of weights ({len(self.weight)}) and model outputs ({len(out)})."
-                )
+            if len(weights) != len(out):
+                raise ValueError(f"Mismatch between number of weights ({len(weights)}) and model outputs ({len(out)}).")
         except Exception as e:
             msg = (
                 f"[Model Sanity Check Failed]\n"
                 f"Input shape attempted: {dummy_input.shape}\n"
-                f"Expected output length: {len(self.weight)}\n"
+                f"Expected output length: {len(weights)}\n"
                 f"Error: {type(e).__name__}: {e}"
             )
             raise RuntimeError(msg) from e
@@ -682,12 +677,14 @@ class IMPACTSynth(Criterion):  # Feature-Oriented Comparison for Unpaired Synthesis
         output = self.preprocessing(output)
         targets = [self.preprocessing(target) for target in targets]
         self.model.to(output.device)
-        for zipped_output in zip(self.model(output), *[self.model(target) for target in targets]):
-            output_features = zipped_output[0]
+        for zipped_output in zip(self.weights, self.model(output), *[self.model(target) for target in targets]):
+            weight = zipped_output[0]
+            if weight == 0:
+                continue
+            output_feature = zipped_output[1]
             targets_features = zipped_output[1:]
-            for target_features, (loss_function, weights) in zip(targets_features, self.losses.items()):
-                for output_feature, target_feature, weight in zip(output_features, target_features, weights):
-                    loss += weight * loss_function(output_feature, target_feature)
+            for target_feature in targets_features:
+                loss += weight * self.loss(output_feature, target_feature)
         return loss

     def forward(self, output: torch.Tensor, *targets: torch.Tensor) -> torch.Tensor:
@@ -696,7 +693,7 @@ class IMPACTSynth(Criterion):  # Feature-Oriented Comparison for Unpaired Synthesis
         loss = torch.zeros((1), requires_grad=True).to(output.device, non_blocking=False).type(torch.float32)
         if len(output.shape) == 5 and self.dim == 2:
             for i in range(output.shape[2]):
-                loss = loss + self._compute(output[:, :, i, ...], [t[:, :, i, ...] for t in targets])
+                loss += self._compute(output[:, :, i, ...], [t[:, :, i, ...] for t in targets])
             loss /= output.shape[2]
         else:
             loss = self._compute(output, list(targets))
konfai/network/network.py CHANGED
@@ -6,7 +6,8 @@ from collections import OrderedDict
 from collections.abc import Callable, Iterable, Iterator, Sequence
 from enum import Enum
 from functools import partial
-from typing import Any, Self
+from typing import Any
+from typing_extensions import Self

 import numpy as np
 import torch
@@ -188,11 +189,16 @@ class Measure:
     def reset_loss(self) -> None:
         self._loss.clear()

-    def add(self, weight: float, value: torch.Tensor) -> None:
-        if self.is_loss or value != 0:
-            self._loss.append(value if self.is_loss else value.detach())
-            self._values.append(value.item())
-            self._weight.append(weight)
+    def add(self, weight: float, value: torch.Tensor | tuple[torch.Tensor, float]) -> None:
+        if isinstance(value, tuple):
+            loss_value, true_value = value
+        else:
+            loss_value = value
+            true_value = value.item()
+
+        self._loss.append(loss_value if self.is_loss else loss_value.detach())
+        self._values.append(true_value)
+        self._weight.append(weight)

     def get_last_loss(self) -> torch.Tensor:
         return self._loss[-1] * self._weight[-1] if len(self._loss) else torch.zeros((1), requires_grad=True)
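
Note: `Measure.add` now accepts either a bare tensor or the `(tensor, float)` pairs emitted by the reworked criteria, separating the backprop value from the logged value (sketch; `measure` is a configured `Measure` and `loss_tensor` a hypothetical loss):

    measure.add(1.0, loss_tensor)                  # classic form: logs loss_tensor.item()
    measure.add(1.0, (loss_tensor, float("nan")))  # new form: backprops loss_tensor, logs NaN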
@@ -239,13 +245,13 @@ class Measure:
         for target_group_tmp in target_group.split(";"):
             if target_group_tmp not in group_dest:
                 raise MeasureError(
-                    f"The target_group '{target_group_tmp}' defined in "
+                    f"The target_group {target_group_tmp} defined in "
                     "'outputs_criterions.{output_group}.targets_criterions'"
                     " was not found in the available destination groups.",
                     "This target_group is expected for loss or metric computation, "
                     "but was not loaded in 'group_dest'.",
-                    f"Please make sure that the group '{target_group_tmp}' is defined in "
-                    "'Dataset:groups_src:...:groups_dest:'{target_group_tmp}'' "
+                    f"Please make sure that the group {target_group_tmp} is defined in "
+                    "Dataset:groups_src:...:groups_dest: {target_group_tmp} "
                     "and correctly loaded from the dataset.",
                 )
         for criterion in self.outputs_criterions[output_group][target_group]:
@@ -868,7 +874,10 @@ class Network(ModuleArgsDict, ABC):
                     init_gain=self.init_gain,
                 )
             )
-        name = "Model" + ("_EMA" if ema else "")
+        name = "Model"
+        if ema:
+            if name + "_EMA" in state_dict:
+                name += "_EMA"
         if name in state_dict:
             value = state_dict[name]
             model_state_dict_tmp = {}
@@ -893,7 +902,9 @@ class Network(ModuleArgsDict, ABC):
                 model_state_dict[alias] = model_state_dict_tmp[alias]
         self.load_state_dict(model_state_dict)
         if f"{self.get_name()}_optimizer_state_dict" in state_dict and self.optimizer:
+            last_lr = self.optimizer.param_groups[0]["lr"]
             self.optimizer.load_state_dict(state_dict[f"{self.get_name()}_optimizer_state_dict"])
+            self.optimizer.param_groups[0]["lr"] = last_lr
         if f"{self.get_name()}_it" in state_dict:
             _it = state_dict.get(f"{self.get_name()}_it")
             if isinstance(_it, int):
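
Note: restoring an optimizer state dict also restores the checkpoint's learning rate; the two added lines pin the currently configured rate back afterwards. A minimal plain-PyTorch reproduction of the behavior being worked around:

    import torch

    param = torch.nn.Parameter(torch.zeros(1))
    saved = torch.optim.SGD([param], lr=0.01).state_dict()  # checkpoint trained at lr=0.01

    opt = torch.optim.SGD([param], lr=0.001)                # new run configured lower
    last_lr = opt.param_groups[0]["lr"]
    opt.load_state_dict(saved)                              # would silently reset lr to 0.01
    opt.param_groups[0]["lr"] = last_lr                     # keep the configured 0.001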
konfai/predictor.py CHANGED
@@ -715,7 +715,7 @@ class Predictor(DistributedObject):
         path = models_directory() + self.name + "/StateDict/"
         name = sorted(os.listdir(path))[-1]
         if os.path.exists(path + name):
-            state_dicts.append(torch.load(path + name, weights_only=True))
+            state_dicts.append(torch.load(path + name, weights_only=False))
         else:
             raise Exception(f"Model : {path + name} does not exist !")
         return state_dicts
@@ -793,7 +793,7 @@ class Predictor(DistributedObject):
             exit(0)

         self.size = len(self.gpu_checkpoints) + 1 if self.gpu_checkpoints else 1
-        self.dataloader = self.dataset.get_data(world_size // self.size)
+        self.dataloader, _, _ = self.dataset.get_data(world_size // self.size)
         for name, output_dataset in self.outputs_dataset.items():
             output_dataset.load(
                 name.replace(".", ":"),
konfai/trainer.py CHANGED
@@ -14,7 +14,6 @@ from konfai import (
     config_file,
     current_date,
     konfai_state,
-    models_directory,
     path_to_models,
     setups_directory,
     statistics_directory,
@@ -147,7 +146,6 @@ class _Trainer:
         self.global_rank = global_rank
         self.local_rank = local_rank
         self.size = size
-
         self.save_checkpoint_mode = save_checkpoint_mode
         self.train_name = train_name
         self.epochs = epochs
@@ -186,6 +184,9 @@ class _Trainer:
         self.dataloader_training.dataset.load("Train")
         if self.dataloader_validation is not None:
             self.dataloader_validation.dataset.load("Validation")
+            if State[konfai_state()] != State.TRAIN:
+                self._validate()
+
         with tqdm.tqdm(
             iterable=range(self.epoch, self.epochs),
             leave=False,
@@ -604,38 +605,7 @@ class Trainer(DistributedObject):
        Serializes model and EMA model (if any) in both state_dict and full model formats.
        Also saves optimizer states and YAML config snapshot.
        """
-        path_checkpoint = checkpoints_directory() + self.name + "/"
-        path_model = models_directory() + self.name + "/"
-        if os.path.exists(path_checkpoint) and os.listdir(path_checkpoint):
-            for directory in [
-                path_model,
-                f"{path_model}Serialized/",
-                f"{path_model}StateDict/",
-            ]:
-                if not os.path.exists(directory):
-                    os.makedirs(directory)
-
-            for name in sorted(os.listdir(path_checkpoint)):
-                checkpoint = torch.load(path_checkpoint + name, weights_only=True, map_location="cpu")
-                self.model.load(checkpoint, init=False, ema=False)
-
-                torch.save(self.model, f"{path_model}Serialized/{name}")
-                torch.save(
-                    {"Model": self.model.state_dict()},
-                    f"{path_model}StateDict/{name}",
-                )
-
-                if self.model_ema is not None:
-                    self.model_ema.module.load(checkpoint, init=False, ema=True)
-                    torch.save(
-                        self.model_ema.module,
-                        f"{path_model}Serialized/{current_date()}_EMA.pt",
-                    )
-                    torch.save(
-                        {"Model_EMA": self.model_ema.module.state_dict()},
-                        f"{path_model}StateDict/{current_date()}_EMA.pt",
-                    )
-
+        if os.path.exists(self.config_namefile):
             os.rename(
                 self.config_namefile,
                 self.config_namefile.replace(".yml", "") + "_" + str(self.it) + ".yml",
@@ -662,6 +632,7 @@ class Trainer(DistributedObject):
            world_size (int): Total number of distributed processes.
        """
        state = State[konfai_state()]
+        print(checkpoints_directory() + self.name + "/")
        if state != State.RESUME and os.path.exists(checkpoints_directory() + self.name + "/"):
            if os.environ["KONFAI_OVERWRITE"] != "True":
                accept = input(f"The model {self.name} already exists ! Do you want to overwrite it (yes,no) : ")
@@ -669,7 +640,6 @@ class Trainer(DistributedObject):
                 return
         for directory_path in [
             statistics_directory(),
-            models_directory(),
             checkpoints_directory(),
             setups_directory(),
         ]:
@@ -698,7 +668,13 @@ class Trainer(DistributedObject):
             os.makedirs(setups_directory() + self.name + "/")
         shutil.copyfile(self.config_namefile_src + ".yml", self.config_namefile)

-        self.dataloader = self.dataset.get_data(world_size // self.size)
+        self.dataloader, train_names, validation_names = self.dataset.get_data(world_size // self.size)
+        with open(setups_directory() + self.name + "/Train_" + str(self.it) + ".txt", "w") as f:
+            for name in train_names:
+                f.write(name + "\n")
+        with open(setups_directory() + self.name + "/Validation_" + str(self.it) + ".txt", "w") as f:
+            for name in validation_names:
+                f.write(name + "\n")

     def run_process(
         self,
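
Note: the trainer now persists the resolved splits as plain name-per-line files, which round-trips with the file handling in `Subset._get_index` above (a saved validation list can be reused as a `subset` path, or excluded with a `~` prefix). A sketch with a hypothetical setups path:

    with open("Setups/MyModel/Validation_0.txt") as f:
        validation_names = [line.strip() for line in f]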
konfai/utils/utils.py CHANGED
@@ -41,7 +41,7 @@ def description(model, model_ema=None, show_memory: bool = True, train: bool = T
     "("
     + " ".join(
         [
-            f"{name}({network.optimizer.param_groups[0]['lr']:.6f} if network.optimizer else 0) : "
+            f"{name}({(network.optimizer.param_groups[0]['lr'] if network.optimizer else 0):.6f}) : "
             + " ".join(
                 f"{k.split(':')[-1]}({w:.2f}) : {v:.6f}"
                 for (k, v), w in zip(
@@ -58,11 +58,11 @@ def description(model, model_ema=None, show_memory: bool = True, train: bool = T
     model_loss_desc = loss_desc(model)
     result = ""
     if len(model_loss_desc) > 2:
-        f"Loss {model_loss_desc} "
+        result += f"Loss {model_loss_desc} "
     if model_ema is not None:
         model_ema_loss_desc = loss_desc(model_ema)
         if len(model_ema_loss_desc) > 2:
-            result += f"Loss EMA {loss_desc(model_ema_loss_desc)} "
+            result += f"Loss EMA {model_ema_loss_desc} "
     result += gpu_info()
     if show_memory:
         result += f" | {get_memory_info()}"
@@ -765,6 +765,8 @@ def download_url(model_name: str, url: str) -> str:
     if not isinstance(locations, list) or not locations:
         raise ImportError("No valid submodule_search_locations found")
     base_path = Path(locations[0]) / "metric" / "models"
+    os.makedirs(base_path, exist_ok=True)
+
     subdirs = Path(model_name).parent
     model_dir = base_path / subdirs
     model_dir.mkdir(exist_ok=True)
konfai-1.1.9.dist-info/METADATA → konfai-1.2.1.dist-info/METADATA CHANGED
@@ -1,6 +1,6 @@
 Metadata-Version: 2.4
 Name: konfai
-Version: 1.1.9
+Version: 1.2.1
 Summary: Modular and configurable Deep Learning framework with YAML and PyTorch
 Author-email: Valentin Boussot <boussot.v@gmail.com>
 License-Expression: Apache-2.0
konfai-1.1.9.dist-info/RECORD → konfai-1.2.1.dist-info/RECORD CHANGED
@@ -1,15 +1,15 @@
 konfai/__init__.py,sha256=qjE9Rqxo1sMrkqGS8I5xlGQMZnjIfU-CGgSI5Wmbmbs,1231
-konfai/evaluator.py,sha256=jsRzVSFjK-V1rZVK9kmN0Gh5-F2JhJrJv291UGNm8CM,16736
+konfai/evaluator.py,sha256=xAKWUDvdSxqYRUsKqH6ieQF06LWa785aE4zLv4I3_i4,17850
 konfai/main.py,sha256=Fc4HcJEhPmgunj_f-QYyvQNvjHrKHSUv27Okgu6V5_A,3842
-konfai/predictor.py,sha256=k9S-AH-wGvmr4YQF2IczJ2Nb9_aTZwNd9f6iu4s9v78,34591
-konfai/trainer.py,sha256=wNHgDh0LtxTi0-aWCkT90hjjpJFaX_zRWyA5esVrsLY,28072
+konfai/predictor.py,sha256=-ZcHrFnP7fOUZ4SK4DpNYbir7iScWQnSOZfmeSLtg1I,34598
+konfai/trainer.py,sha256=4mc-8r-FxtX_EAn2su8qd-BLLQ0D0So8hh5rmuf6Hqs,27163
 konfai/data/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 konfai/data/augmentation.py,sha256=vcJE7mosvUkwbpbTN_lGP0S1uJrJYGjlLrt9VnDdJYY,27792
-konfai/data/data_manager.py,sha256=FOxpmL56-Cgqsms2rdPk1NML5OiuuIQW49-G0j3O2Os,31564
+konfai/data/data_manager.py,sha256=7-ruYS2HJMvLzTq5425loOf2OCpAm5MspdYoO74ICsw,31180
 konfai/data/patching.py,sha256=jS35OxnJagKNUnJu7TzuGZpVj9fP-6H4nc2OEYOGgt8,16494
-konfai/data/transform.py,sha256=MmA1vgXAj1V-e-8RNa1XAaHbWI5i85NoVyCp7yZ2-kg,26816
+konfai/data/transform.py,sha256=YCldsqTTBFFCqc_VdvyuNVs2kmV56CxQBN5XhEoPxho,27745
 konfai/metric/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
-konfai/metric/measure.py,sha256=7CHpLWCuuNNk3cYvGiNql6nt_pq6PA5SHIVNiSpeolk,28144
+konfai/metric/measure.py,sha256=0mOIZKTa2u0UECpoDSbdJUhttAw_e1BlsROQQpi1oBk,27804
 konfai/metric/schedulers.py,sha256=TpYMA24FMpxRnqfhMGb0i_Mm-bzT9kySbBgvkYk-6wM,1327
 konfai/models/classification/convNeXt.py,sha256=Ha9QYd1-JEYUwL7zgNNAWeuJLbUT7LCNUkLcsdySAHM,10060
 konfai/models/classification/resnet.py,sha256=4-l7HtpP_OqApDT8XjTH94fXIuiSzz833SUWXP2oFJo,10813
@@ -24,15 +24,15 @@ konfai/models/segmentation/NestedUNet.py,sha256=W4uauwF0HY8ybi49iYiTlKLdJEyD7SaC
 konfai/models/segmentation/UNet.py,sha256=Pu_LiQdO4Mrzyn0HRE6rwxUjHGH4OG-JpzWB_U1K46g,5602
 konfai/network/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 konfai/network/blocks.py,sha256=l70_oOcz5Hmol2xmxruG0kke_2SVgO3rXYXVTMSdAS8,15645
-konfai/network/network.py,sha256=FPv3i3pnyThFMd1KgMk3wNz41btE5_Kby95dUwG7PsM,54458
+konfai/network/network.py,sha256=0G3goIopB_r4PHj4ohOkbov344mVYxZjq5PH57QETmc,54829
 konfai/utils/ITK.py,sha256=HVed4Z96X1jTaWrrQNdoBMqOtVK9InAPlDBJu-5uv3g,15476
 konfai/utils/__init__.py,sha256=47DEQpj8HBSa-_TImW-5JCeuQeRkm5NMpJWZG3hSuFU,0
 konfai/utils/config.py,sha256=a7t44CYMUT5oCDdjL94IswhCVfFbQ5FCgDWZktDDkc4,14347
 konfai/utils/dataset.py,sha256=Au22fcADKyDJMfS8Z9q8kEXLtKkoufJsH7Pwly6pALo,28288
-konfai/utils/utils.py,sha256=AtPWHJh_RdAGK2m9Cv3BXvUZTpspVHRTPXXvRfn5dZg,28366
-konfai-1.1.9.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
-konfai-1.1.9.dist-info/METADATA,sha256=6PaWz831mjixz5kTN_-wBtDNGVBk0SjwayD4AGyUC5o,2451
-konfai-1.1.9.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
-konfai-1.1.9.dist-info/entry_points.txt,sha256=fG82HRN5-g39ACSOCtij_I3N6EHxfYnMR0D7TI_8pW8,81
-konfai-1.1.9.dist-info/top_level.txt,sha256=xF470dkIlFoFqTZEOlRehKJr4WU_8OKGXrJqYm9vWKs,7
-konfai-1.1.9.dist-info/RECORD,,
+konfai/utils/utils.py,sha256=jCj3tZ8agQYceSY_tlVYp88UFPE5oUn6tXrqnZGrKiI,28410
+konfai-1.2.1.dist-info/licenses/LICENSE,sha256=xx0jnfkXJvxRnG63LTGOxlggYnIysveWIZ6H3PNdCrQ,11357
+konfai-1.2.1.dist-info/METADATA,sha256=UJzi2uPo_JMW4VeEWYm3IB0ZdEaIHbIrdDj4zNNw3WQ,2451
+konfai-1.2.1.dist-info/WHEEL,sha256=_zCd3N1l69ArxyTb8rzEoP9TpbYXkqRFSNOD5OuxnTs,91
+konfai-1.2.1.dist-info/entry_points.txt,sha256=fG82HRN5-g39ACSOCtij_I3N6EHxfYnMR0D7TI_8pW8,81
+konfai-1.2.1.dist-info/top_level.txt,sha256=xF470dkIlFoFqTZEOlRehKJr4WU_8OKGXrJqYm9vWKs,7
+konfai-1.2.1.dist-info/RECORD,,