konfai 1.1.6__tar.gz → 1.1.8__tar.gz

This diff represents the content of publicly available package versions that have been released to one of the supported registries. The information contained in this diff is provided for informational purposes only and reflects changes between package versions as they appear in their respective public registries.

Potentially problematic release.


This version of konfai has been flagged as potentially problematic; consult the package registry's advisory page for more details.

Files changed (46)
  1. {konfai-1.1.6 → konfai-1.1.8}/PKG-INFO +1 -1
  2. {konfai-1.1.6 → konfai-1.1.8}/konfai/data/augmentation.py +0 -1
  3. {konfai-1.1.6 → konfai-1.1.8}/konfai/data/data_manager.py +25 -39
  4. {konfai-1.1.6 → konfai-1.1.8}/konfai/data/patching.py +7 -19
  5. {konfai-1.1.6 → konfai-1.1.8}/konfai/data/transform.py +1 -2
  6. {konfai-1.1.6 → konfai-1.1.8}/konfai/metric/measure.py +85 -12
  7. {konfai-1.1.6 → konfai-1.1.8}/konfai/metric/schedulers.py +2 -16
  8. {konfai-1.1.6 → konfai-1.1.8}/konfai/models/classification/convNeXt.py +1 -2
  9. {konfai-1.1.6 → konfai-1.1.8}/konfai/models/classification/resnet.py +1 -2
  10. {konfai-1.1.6 → konfai-1.1.8}/konfai/models/generation/cStyleGan.py +1 -2
  11. {konfai-1.1.6 → konfai-1.1.8}/konfai/models/generation/ddpm.py +1 -2
  12. {konfai-1.1.6 → konfai-1.1.8}/konfai/models/generation/diffusionGan.py +10 -23
  13. {konfai-1.1.6 → konfai-1.1.8}/konfai/models/generation/gan.py +3 -6
  14. {konfai-1.1.6 → konfai-1.1.8}/konfai/models/generation/vae.py +0 -2
  15. {konfai-1.1.6 → konfai-1.1.8}/konfai/models/registration/registration.py +1 -2
  16. {konfai-1.1.6 → konfai-1.1.8}/konfai/models/representation/representation.py +1 -2
  17. {konfai-1.1.6 → konfai-1.1.8}/konfai/models/segmentation/NestedUNet.py +10 -9
  18. {konfai-1.1.6 → konfai-1.1.8}/konfai/models/segmentation/UNet.py +1 -1
  19. {konfai-1.1.6 → konfai-1.1.8}/konfai/network/blocks.py +8 -6
  20. {konfai-1.1.6 → konfai-1.1.8}/konfai/network/network.py +35 -49
  21. {konfai-1.1.6 → konfai-1.1.8}/konfai/predictor.py +39 -59
  22. {konfai-1.1.6 → konfai-1.1.8}/konfai/utils/utils.py +40 -2
  23. {konfai-1.1.6 → konfai-1.1.8}/konfai.egg-info/PKG-INFO +1 -1
  24. {konfai-1.1.6 → konfai-1.1.8}/pyproject.toml +1 -1
  25. {konfai-1.1.6 → konfai-1.1.8}/LICENSE +0 -0
  26. {konfai-1.1.6 → konfai-1.1.8}/README.md +0 -0
  27. {konfai-1.1.6 → konfai-1.1.8}/konfai/__init__.py +0 -0
  28. {konfai-1.1.6 → konfai-1.1.8}/konfai/data/__init__.py +0 -0
  29. {konfai-1.1.6 → konfai-1.1.8}/konfai/evaluator.py +0 -0
  30. {konfai-1.1.6 → konfai-1.1.8}/konfai/main.py +0 -0
  31. {konfai-1.1.6 → konfai-1.1.8}/konfai/metric/__init__.py +0 -0
  32. {konfai-1.1.6 → konfai-1.1.8}/konfai/network/__init__.py +0 -0
  33. {konfai-1.1.6 → konfai-1.1.8}/konfai/trainer.py +0 -0
  34. {konfai-1.1.6 → konfai-1.1.8}/konfai/utils/ITK.py +0 -0
  35. {konfai-1.1.6 → konfai-1.1.8}/konfai/utils/__init__.py +0 -0
  36. {konfai-1.1.6 → konfai-1.1.8}/konfai/utils/config.py +0 -0
  37. {konfai-1.1.6 → konfai-1.1.8}/konfai/utils/dataset.py +0 -0
  38. {konfai-1.1.6 → konfai-1.1.8}/konfai/utils/registration.py +0 -0
  39. {konfai-1.1.6 → konfai-1.1.8}/konfai.egg-info/SOURCES.txt +0 -0
  40. {konfai-1.1.6 → konfai-1.1.8}/konfai.egg-info/dependency_links.txt +0 -0
  41. {konfai-1.1.6 → konfai-1.1.8}/konfai.egg-info/entry_points.txt +0 -0
  42. {konfai-1.1.6 → konfai-1.1.8}/konfai.egg-info/requires.txt +0 -0
  43. {konfai-1.1.6 → konfai-1.1.8}/konfai.egg-info/top_level.txt +0 -0
  44. {konfai-1.1.6 → konfai-1.1.8}/setup.cfg +0 -0
  45. {konfai-1.1.6 → konfai-1.1.8}/tests/test_config.py +0 -0
  46. {konfai-1.1.6 → konfai-1.1.8}/tests/test_dataset.py +0 -0
@@ -1,6 +1,6 @@
1
1
  Metadata-Version: 2.4
2
2
  Name: konfai
3
- Version: 1.1.6
3
+ Version: 1.1.8
4
4
  Summary: Modular and configurable Deep Learning framework with YAML and PyTorch
5
5
  Author-email: Valentin Boussot <boussot.v@gmail.com>
6
6
  License-Expression: Apache-2.0
@@ -264,7 +264,6 @@ class LumaFlip(ColorTransform):
264
264
 
265
265
  class HUE(ColorTransform):
266
266
 
267
- @config("HUE")
268
267
  def __init__(self, hue_max: float, groups: Union[list[str], None] = None) -> None:
269
268
  super().__init__(groups)
270
269
  self.hue_max = hue_max
@@ -24,11 +24,11 @@ from konfai.data.augmentation import DataAugmentationsList
24
24
  class GroupTransform:
25
25
 
26
26
  @config()
27
- def __init__(self, pre_transforms : Union[dict[str, TransformLoader], list[Transform]] = {"default:Normalize:Standardize:Unsqueeze:TensorCast:ResampleIsotropic:ResampleResize": TransformLoader()},
28
- post_transforms : Union[dict[str, TransformLoader], list[Transform]] = {"default:Normalize:Standardize:Unsqueeze:TensorCast:ResampleIsotropic:ResampleResize": TransformLoader()},
27
+ def __init__(self, transforms : Union[dict[str, TransformLoader], list[Transform]] = {"default:Normalize:Standardize:Unsqueeze:TensorCast:ResampleIsotropic:ResampleResize": TransformLoader()},
28
+ patch_transforms : Union[dict[str, TransformLoader], list[Transform]] = {"default:Normalize:Standardize:Unsqueeze:TensorCast:ResampleIsotropic:ResampleResize": TransformLoader()},
29
29
  isInput: bool = True) -> None:
30
- self._pre_transforms = pre_transforms
31
- self._post_transforms = post_transforms
30
+ self._pre_transforms = transforms
31
+ self._post_transforms = patch_transforms
32
32
  self.pre_transforms : list[Transform] = []
33
33
  self.post_transforms : list[Transform] = []
34
34
  self.isInput = isInput
@@ -37,7 +37,7 @@ class GroupTransform:
37
37
  if self._pre_transforms is not None:
38
38
  if isinstance(self._pre_transforms, dict):
39
39
  for classpath, transform in self._pre_transforms.items():
40
- transform = transform.getTransform(classpath, DL_args = "{}.Dataset.groups_src.{}.groups_dest.{}.pre_transforms".format(KONFAI_ROOT(), group_src, group_dest))
40
+ transform = transform.getTransform(classpath, DL_args = "{}.Dataset.groups_src.{}.groups_dest.{}.transforms".format(KONFAI_ROOT(), group_src, group_dest))
41
41
  transform.setDatasets(datasets)
42
42
  self.pre_transforms.append(transform)
43
43
  else:
@@ -48,7 +48,7 @@ class GroupTransform:
48
48
  if self._post_transforms is not None:
49
49
  if isinstance(self._post_transforms, dict):
50
50
  for classpath, transform in self._post_transforms.items():
51
- transform = transform.getTransform(classpath, DL_args = "{}.Dataset.groups_src.{}.groups_dest.{}.post_transforms".format(KONFAI_ROOT(), group_src, group_dest))
51
+ transform = transform.getTransform(classpath, DL_args = "{}.Dataset.groups_src.{}.groups_dest.{}.patch_transforms".format(KONFAI_ROOT(), group_src, group_dest))
52
52
  transform.setDatasets(datasets)
53
53
  self.post_transforms.append(transform)
54
54
  else:
@@ -65,9 +65,8 @@ class GroupTransform:
65
65
  class GroupTransformMetric(GroupTransform):
66
66
 
67
67
  @config()
68
- def __init__(self, pre_transforms : Union[dict[str, TransformLoader], list[Transform]] = {"default:Normalize:Standardize:Unsqueeze:TensorCast:ResampleIsotropic:ResampleResize": TransformLoader()},
69
- post_transforms : Union[dict[str, TransformLoader], list[Transform]] = {"default:Normalize:Standardize:Unsqueeze:TensorCast:ResampleIsotropic:ResampleResize": TransformLoader()}):
70
- super().__init__(pre_transforms, post_transforms)
68
+ def __init__(self, transforms : Union[dict[str, TransformLoader], list[Transform]] = {"default:Normalize:Standardize:Unsqueeze:TensorCast:ResampleIsotropic:ResampleResize": TransformLoader()}):
69
+ super().__init__(transforms, None)
71
70
 
72
71
  class Group(dict[str, GroupTransform]):
73
72
 
@@ -196,10 +195,9 @@ class DatasetIter(data.Dataset):
196
195
 
197
196
  class Subset():
198
197
 
199
- def __init__(self, subset: Union[str, list[int], list[str], None] = None, shuffle: bool = True, filter: str = None) -> None:
198
+ def __init__(self, subset: Union[str, list[int], list[str], None] = None, shuffle: bool = True) -> None:
200
199
  self.subset = subset
201
200
  self.shuffle = shuffle
202
- self.filter: list[tuple[str]] = [tuple(a.split(":")) for a in filter.split(";")] if filter is not None else None
203
201
 
204
202
  def __call__(self, names: list[str], infos: list[dict[str, tuple[np.ndarray, Attribute]]]) -> set[str]:
205
203
  inter_name = set(names[0])
@@ -207,16 +205,7 @@ class Subset():
207
205
  inter_name = inter_name.intersection(set(n))
208
206
  names = sorted(list(inter_name))
209
207
 
210
- names_filtred = []
211
- if self.filter is not None:
212
- for name in names:
213
- for info in infos:
214
- if all([info[name][1].isInfo(f[0], f[1]) for f in self.filter]):
215
- names_filtred.append(name)
216
- continue
217
- else:
218
- names_filtred = names
219
- size = len(names_filtred)
208
+ size = len(names)
220
209
  index = []
221
210
  if self.subset is None:
222
211
  index = list(range(0, size))
@@ -230,7 +219,7 @@ class Subset():
230
219
  for name in f:
231
220
  train_names.append(name.strip())
232
221
  index = []
233
- for i, name in enumerate(names_filtred):
222
+ for i, name in enumerate(names):
234
223
  if name in train_names:
235
224
  index.append(i)
236
225
  elif self.subset.startswith("~") and os.path.exists(self.subset[1:]):
@@ -239,7 +228,7 @@ class Subset():
239
228
  for name in f:
240
229
  exclude_names.append(name.strip())
241
230
  index = []
242
- for i, name in enumerate(names_filtred):
231
+ for i, name in enumerate(names):
243
232
  if name not in exclude_names:
244
233
  index.append(i)
245
234
 
@@ -252,26 +241,27 @@ class Subset():
252
241
  index = self.subset
253
242
  if isinstance(self.subset[0], str):
254
243
  index = []
255
- for i, name in enumerate(names_filtred):
244
+ for i, name in enumerate(names):
256
245
  if name in self.subset:
257
246
  index.append(i)
258
247
  if self.shuffle:
259
248
  index = random.sample(index, len(index))
260
- return set([names_filtred[i] for i in index])
249
+ return set([names[i] for i in index])
261
250
 
262
251
  def __str__(self):
263
- return "Subset : " + str(self.subset) + " shuffle : "+ str(self.shuffle) + " filter : "+ str(self.filter)
252
+ return "Subset : " + str(self.subset) + " shuffle : "+ str(self.shuffle)
253
+
264
254
  class TrainSubset(Subset):
265
255
 
266
256
  @config()
267
- def __init__(self, subset: Union[str, list[int], list[str], None] = None, shuffle: bool = True, filter: str = None) -> None:
268
- super().__init__(subset, shuffle, filter)
257
+ def __init__(self, subset: Union[str, list[int], list[str], None] = None, shuffle: bool = True) -> None:
258
+ super().__init__(subset, shuffle)
269
259
 
270
260
  class PredictionSubset(Subset):
271
261
 
272
262
  @config()
273
- def __init__(self, subset: Union[str, list[int], list[str], None] = None, filter: str = None) -> None:
274
- super().__init__(subset, False, filter)
263
+ def __init__(self, subset: Union[str, list[int], list[str], None] = None) -> None:
264
+ super().__init__(subset, False)
275
265
 
276
266
  class Data(ABC):
277
267
 
@@ -280,7 +270,6 @@ class Data(ABC):
280
270
  patch : Union[DatasetPatch, None],
281
271
  use_cache : bool,
282
272
  subset : Subset,
283
- num_workers : int,
284
273
  batch_size : int,
285
274
  validation: Union[float, str, list[int], list[str], None] = None,
286
275
  inlineAugmentations: bool = False,
@@ -293,7 +282,7 @@ class Data(ABC):
293
282
  self.dataAugmentationsList = dataAugmentationsList
294
283
  self.batch_size = batch_size
295
284
  self.dataSet_args = dict(groups_src=self.groups_src, inlineAugmentations=inlineAugmentations, dataAugmentationsList = list(self.dataAugmentationsList.values()), use_cache = use_cache, buffer_size=batch_size+1, patch_size=self.patch.patch_size if self.patch is not None else None, overlap=self.patch.overlap if self.patch is not None else None)
296
- self.dataLoader_args = dict(num_workers=num_workers if use_cache else 0, pin_memory=True)
285
+ self.dataLoader_args = dict(num_workers=int(os.environ["KONFAI_WORKERS"]) if use_cache else 0, pin_memory=True)
297
286
  self.data : list[list[dict[str, list[DatasetManager]]], dict[str, list[DatasetManager]]] = []
298
287
  self.map : list[list[list[tuple[int, int, int]]], list[tuple[int, int, int]]] = []
299
288
  self.datasets: dict[str, Dataset] = {}
@@ -547,10 +536,9 @@ class DataTrain(Data):
547
536
  patch : Union[DatasetPatch, None] = DatasetPatch(),
548
537
  use_cache : bool = True,
549
538
  subset : Union[TrainSubset, dict[str, TrainSubset]] = TrainSubset(),
550
- num_workers : int = 4,
551
539
  batch_size : int = 1,
552
540
  validation : Union[float, str, list[int], list[str]] = 0.2) -> None:
553
- super().__init__(dataset_filenames, groups_src, patch, use_cache, subset, num_workers, batch_size, validation, inlineAugmentations, augmentations if augmentations else {})
541
+ super().__init__(dataset_filenames, groups_src, patch, use_cache, subset, batch_size, validation, inlineAugmentations, augmentations if augmentations else {})
554
542
 
555
543
  class DataPrediction(Data):
556
544
 
@@ -560,10 +548,9 @@ class DataPrediction(Data):
560
548
  augmentations : Union[dict[str, DataAugmentationsList], None] = {"DataAugmentation_0" : DataAugmentationsList()},
561
549
  patch : Union[DatasetPatch, None] = DatasetPatch(),
562
550
  subset : Union[PredictionSubset, dict[str, PredictionSubset]] = PredictionSubset(),
563
- num_workers : int = 4,
564
551
  batch_size : int = 1) -> None:
565
552
 
566
- super().__init__(dataset_filenames, groups_src, patch, False, subset, num_workers, batch_size, dataAugmentationsList=augmentations if augmentations else {})
553
+ super().__init__(dataset_filenames, groups_src, patch, False, subset, batch_size, dataAugmentationsList=augmentations if augmentations else {})
567
554
 
568
555
  class DataMetric(Data):
569
556
 
@@ -571,7 +558,6 @@ class DataMetric(Data):
571
558
  def __init__(self, dataset_filenames : list[str] = ["default:./Dataset"],
572
559
  groups_src : dict[str, GroupMetric] = {"default" : GroupMetric()},
573
560
  subset : Union[PredictionSubset, dict[str, PredictionSubset]] = PredictionSubset(),
574
- validation: Union[str, None] = None,
575
- num_workers : int = 4) -> None:
561
+ validation: Union[str, None] = None) -> None:
576
562
 
577
- super().__init__(dataset_filenames=dataset_filenames, groups_src=groups_src, patch=None, use_cache=False, subset=subset, num_workers=num_workers, batch_size=1, validation=validation)
563
+ super().__init__(dataset_filenames=dataset_filenames, groups_src=groups_src, patch=None, use_cache=False, subset=subset, batch_size=1, validation=validation)
@@ -81,7 +81,6 @@ class PathCombine(ABC):
81
81
 
82
82
  class Mean(PathCombine):
83
83
 
84
- @config("Mean")
85
84
  def __init__(self) -> None:
86
85
  super().__init__()
87
86
 
@@ -90,7 +89,6 @@ class Mean(PathCombine):
90
89
 
91
90
  class Cosinus(PathCombine):
92
91
 
93
- @config("Cosinus")
94
92
  def __init__(self) -> None:
95
93
  super().__init__()
96
94
 
@@ -144,7 +142,7 @@ class Accumulator():
144
142
 
145
143
  class Patch(ABC):
146
144
 
147
- def __init__(self, patch_size: list[int], overlap: Union[int, None], path_mask: Union[str, None] = None, padValue: float = 0, extend_slice: int = 0) -> None:
145
+ def __init__(self, patch_size: list[int], overlap: Union[int, None], padValue: float = 0, extend_slice: int = 0) -> None:
148
146
  self.patch_size = patch_size
149
147
  self.overlap = overlap
150
148
  if isinstance(self.overlap, int):
@@ -152,15 +150,8 @@ class Patch(ABC):
152
150
  self.overlap = None
153
151
  self._patch_slices : dict[int, list[tuple[slice]]] = {}
154
152
  self._nb_patch_per_dim: dict[int, list[tuple[int, bool]]] = {}
155
- self.path_mask = path_mask
156
- self.mask = None
157
153
  self.padValue = padValue
158
154
  self.extend_slice = extend_slice
159
- if self.path_mask is not None:
160
- if os.path.exists(self.path_mask):
161
- self.mask = torch.tensor(sitk.GetArrayFromImage(sitk.ReadImage(self.path_mask)))
162
- else:
163
- raise NameError('Mask file not found')
164
155
 
165
156
  def load(self, shape : dict[int, list[int]], a: int = 0) -> None:
166
157
  self._patch_slices[a], self._nb_patch_per_dim[a] = get_patch_slices_from_shape(self.patch_size, shape, self.overlap)
@@ -199,9 +190,6 @@ class Patch(ABC):
199
190
  padding.append(p)
200
191
 
201
192
  data_sliced = F.pad(data_sliced, tuple(padding), "constant", 0 if data_sliced.dtype == torch.uint8 and self.padValue < 0 else self.padValue)
202
- if self.mask is not None:
203
- outside = torch.ones_like(data_sliced)*(0 if data_sliced.dtype == torch.uint8 and self.padValue < 0 else self.padValue)
204
- data_sliced = torch.where(self.mask == 0, outside, data_sliced)
205
193
 
206
194
  for d in [i for i, v in enumerate(reversed(self.patch_size)) if v == 1]:
207
195
  data_sliced = torch.squeeze(data_sliced, dim = len(data_sliced.shape)-d-1)
@@ -213,14 +201,14 @@ class Patch(ABC):
213
201
  class DatasetPatch(Patch):
214
202
 
215
203
  @config("Patch")
216
- def __init__(self, patch_size : list[int] = [128, 128, 128], overlap : Union[int, None] = None, mask: Union[str, None] = None, padValue: float = 0, extend_slice: int = 0) -> None:
217
- super().__init__(patch_size, overlap, mask, padValue, extend_slice)
204
+ def __init__(self, patch_size : list[int] = [128, 128, 128], overlap : Union[int, None] = None, padValue: float = 0, extend_slice: int = 0) -> None:
205
+ super().__init__(patch_size, overlap, padValue, extend_slice)
218
206
 
219
207
  class ModelPatch(Patch):
220
208
 
221
209
  @config("Patch")
222
- def __init__(self, patch_size : list[int] = [128, 128, 128], overlap : Union[int, None] = None, patchCombine: Union[str, None] = None, mask: Union[str, None] = None, padValue: float = 0, extend_slice: int = 0) -> None:
223
- super().__init__(patch_size, overlap, mask, padValue, extend_slice)
210
+ def __init__(self, patch_size : list[int] = [128, 128, 128], overlap : Union[int, None] = None, patchCombine: Union[str, None] = None, padValue: float = 0, extend_slice: int = 0) -> None:
211
+ super().__init__(patch_size, overlap, padValue, extend_slice)
224
212
  self.patchCombine = patchCombine
225
213
 
226
214
  def disassemble(self, *dataList: torch.Tensor) -> Iterator[list[torch.Tensor]]:
@@ -247,7 +235,7 @@ class DatasetManager():
247
235
  for transformFunction in pre_transforms:
248
236
  _shape = transformFunction.transformShape(_shape, cache_attribute)
249
237
 
250
- self.patch = DatasetPatch(patch_size=patch.patch_size, overlap=patch.overlap, mask=patch.path_mask, padValue=patch.padValue, extend_slice=patch.extend_slice) if patch else DatasetPatch(_shape)
238
+ self.patch = DatasetPatch(patch_size=patch.patch_size, overlap=patch.overlap, padValue=patch.padValue, extend_slice=patch.extend_slice) if patch else DatasetPatch(_shape)
251
239
  self.patch.load(_shape, 0)
252
240
  self.shape = _shape
253
241
  self.dataAugmentationsList = dataAugmentationsList
@@ -314,7 +302,7 @@ class DatasetManager():
314
302
  dataset.write(self.group_dest, self.name, data.numpy(), self.cache_attributes[0])
315
303
  self.data : list[torch.Tensor] = list()
316
304
  self.data.append(data)
317
-
305
+
318
306
  for i in range(len(self.cache_attributes)-1):
319
307
  self.cache_attributes[i+1].update(self.cache_attributes[0])
320
308
  self.loaded = True
@@ -259,7 +259,6 @@ class ResampleToShape(Resample):
259
259
  for i, s in enumerate(self.shape):
260
260
  if s == 0:
261
261
  new_shape[i] = shape[i]
262
- print(new_shape)
263
262
  return new_shape
264
263
 
265
264
  def __call__(self, name: str, input: torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
@@ -411,7 +410,7 @@ class Gradient(Transform):
411
410
  def inverse(self, name: str, input : torch.Tensor, cache_attribute: Attribute) -> torch.Tensor:
412
411
  return input
413
412
 
414
- class ArgMax(Transform):
413
+ class Argmax(Transform):
415
414
 
416
415
  def __init__(self, dim: int = 0) -> None:
417
416
  self.dim = dim
@@ -9,7 +9,6 @@ from scipy import linalg
9
9
 
10
10
  import torch.nn.functional as F
11
11
  import os
12
- import tqdm
13
12
 
14
13
  from typing import Callable, Union
15
14
  from functools import partial
@@ -19,10 +18,11 @@ import copy
19
18
  from abc import abstractmethod
20
19
 
21
20
  from konfai.utils.config import config
22
- from konfai.utils.utils import _getModule
21
+ from konfai.utils.utils import _getModule, download_url
23
22
  from konfai.data.patching import ModelPatch
24
23
  from konfai.network.blocks import LatentDistribution
25
24
  from konfai.network.network import ModelLoader, Network
25
+ from tqdm import tqdm
26
26
 
27
27
  modelsRegister = {}
28
28
 
@@ -122,7 +122,7 @@ class LPIPS(MaskedLoss):
122
122
 
123
123
  patchIterator = datasetPatch.disassemble(LPIPS.normalize(x), LPIPS.normalize(y))
124
124
  loss = 0
125
- with tqdm.tqdm(iterable = enumerate(patchIterator), leave=False, total=datasetPatch.getSize(0)) as batch_iter:
125
+ with tqdm(iterable = enumerate(patchIterator), leave=False, total=datasetPatch.getSize(0)) as batch_iter:
126
126
  for i, patch_input in batch_iter:
127
127
  real, fake = LPIPS.preprocessing(patch_input[0]), LPIPS.preprocessing(patch_input[1])
128
128
  loss += loss_fn_alex(real, fake).item()
@@ -146,9 +146,11 @@ class Dice(Criterion):
146
146
  target = self.flatten(target)
147
147
  return (2.*(input * target).sum() + self.smooth)/(input.sum() + target.sum() + self.smooth)
148
148
 
149
-
150
149
  def forward(self, output: torch.Tensor, *targets : list[torch.Tensor]) -> torch.Tensor:
151
- target = targets[0]
150
+ target = F.interpolate(
151
+ targets[0],
152
+ output.shape[2:],
153
+ mode="nearest")
152
154
  if output.shape[1] == 1:
153
155
  output = F.one_hot(output.type(torch.int64), num_classes=int(torch.max(output).item()+1)).permute(0, len(target.shape), *[i+1 for i in range(len(target.shape)-1)]).float()
154
156
  target = F.one_hot(target.type(torch.int64), num_classes=output.shape[1]).permute(0, len(target.shape), *[i+1 for i in range(len(target.shape)-1)]).float().squeeze(2)
@@ -242,7 +244,7 @@ class PerceptualLoss(Criterion):
242
244
 
243
245
  class Module():
244
246
 
245
- @config(None)
247
+ @config()
246
248
  def __init__(self, losses: dict[str, float] = {"Gram": 1, "torch_nn_L1Loss": 1}) -> None:
247
249
  self.losses = losses
248
250
  self.DL_args = os.environ['KONFAI_CONFIG_PATH'] if "KONFAI_CONFIG_PATH" in os.environ else ""
@@ -413,16 +415,19 @@ class FocalLoss(Criterion):
413
415
  self.gamma = gamma
414
416
  self.reduction = reduction
415
417
 
416
- def forward(self, output: torch.Tensor, targets: list[torch.Tensor]) -> torch.Tensor:
417
- target = targets[0].long()
418
+ def forward(self, output: torch.Tensor, *targets: list[torch.Tensor]) -> torch.Tensor:
419
+ target = F.interpolate(
420
+ targets[0],
421
+ output.shape[2:],
422
+ mode="nearest").long()
418
423
 
419
424
  logpt = F.log_softmax(output, dim=1)
420
425
  pt = torch.exp(logpt)
421
426
 
422
- logpt = logpt.gather(1, target.unsqueeze(1))
423
- pt = pt.gather(1, target.unsqueeze(1))
427
+ logpt = logpt.gather(1, target)
428
+ pt = pt.gather(1, target)
424
429
 
425
- at = self.alpha[target].unsqueeze(1)
430
+ at = self.alpha.to(target.device)[target].unsqueeze(1)
426
431
  loss = -at * ((1 - pt) ** self.gamma) * logpt
427
432
 
428
433
  if self.reduction == "mean":
@@ -508,4 +513,72 @@ class MutualInformationLoss(torch.nn.Module):
508
513
  pab = torch.bmm(wa.permute(0, 2, 1), wb.to(wa)).div(wa.shape[1]) # (batch, num_bins, num_bins)
509
514
  papb = torch.bmm(pa.permute(0, 2, 1), pb.to(pa)) # (batch, num_bins, num_bins)
510
515
  mi = torch.sum(pab * torch.log((pab + self.smooth_nr) / (papb + self.smooth_dr) + self.smooth_dr), dim=(1, 2)) # (batch)
511
- return torch.mean(mi).neg() # average over the batch and channel ndims
516
+ return torch.mean(mi).neg() # average over the batch and channel ndims
517
+
518
+ class FOCUS(Criterion): #Feature-Oriented Comparison for Unpaired Synthesis
519
+
520
+ def __init__(self, model_name: str, shape: list[int] = [0,0], in_channels: int = 1, losses: dict[str, list[float]] = {"Gram": 1, "torch_nn_L1Loss": 1}) -> None:
521
+ super().__init__()
522
+ if model_name is None:
523
+ return
524
+ self.shape = shape
525
+ self.in_channels = in_channels
526
+ self.losses: dict[torch.nn.Module, list[float]] = {}
527
+ for loss, weights in losses.items():
528
+ module, name = _getModule(loss, "konfai.metric.measure")
529
+ self.losses[config(os.environ['KONFAI_CONFIG_PATH'])(getattr(importlib.import_module(module), name))(config=None)] = weights
530
+
531
+ self.model_path = download_url(model_name, "https://huggingface.co/VBoussot/impact-torchscript-models/resolve/main/")
532
+ self.model: torch.nn.Module = torch.jit.load(self.model_path)
533
+ self.dim = len(self.shape)
534
+ self.shape = self.shape if all([s > 0 for s in self.shape]) else None
535
+ self.modules_loss: dict[str, dict[torch.nn.Module, float]] = {}
536
+
537
+ try:
538
+ dummy_input = torch.zeros((1, self.in_channels, *(self.shape if self.shape else [224] * self.dim)))
539
+ out = self.model(dummy_input)
540
+ if not isinstance(out, (list, tuple)):
541
+ raise TypeError(f"Expected model output to be a list or tuple, but got {type(out)}.")
542
+ if len(self.weight) != len(out):
543
+ raise ValueError(f"Mismatch between number of weights ({len(self.weight)}) and model outputs ({len(out)}).")
544
+ except Exception as e:
545
+ msg = (
546
+ f"[Model Sanity Check Failed]\n"
547
+ f"Input shape attempted: {dummy_input.shape}\n"
548
+ f"Expected output length: {len(self.weight)}\n"
549
+ f"Error: {type(e).__name__}: {e}"
550
+ )
551
+ raise RuntimeError(msg) from e
552
+ self.model = None
553
+
554
+ def preprocessing(self, input: torch.Tensor) -> torch.Tensor:
555
+ if self.shape is not None and not all([input.shape[-i-1] == size for i, size in enumerate(reversed(self.shape[2:]))]):
556
+ input = torch.nn.functional.interpolate(input, mode=self.mode, size=tuple(self.shape), align_corners=False).type(torch.float32)
557
+ if input.shape[1] != self.in_channels:
558
+ input = input.repeat(tuple([1,3] + [1 for _ in range(self.dim)]))
559
+ return input
560
+
561
+ def _compute(self, output: torch.Tensor, targets: list[torch.Tensor]) -> torch.Tensor:
562
+ loss = torch.zeros((1), requires_grad = True).to(output.device, non_blocking=False).type(torch.float32)
563
+ output = self.preprocessing(output)
564
+ targets = [self.preprocessing(target) for target in targets]
565
+ self.model.to(output.device)
566
+ for zipped_output in zip(self.model(output), *[self.model(target) for target in targets]):
567
+ output_features = zipped_output[0]
568
+ targets_features = zipped_output[1:]
569
+ for target_features, (loss_function, weights) in zip(targets_features, self.losses.items()):
570
+ for output_feature, target_feature, weight in zip(output_features ,target_features, weights):
571
+ loss += weights*loss_function(output_feature, target_feature)
572
+ return loss
573
+
574
+ def forward(self, output: torch.Tensor, *targets : list[torch.Tensor]) -> torch.Tensor:
575
+ if self.model is None:
576
+ self.model = torch.jit.load(self.model_path)
577
+ loss = torch.zeros((1), requires_grad = True).to(output.device, non_blocking=False).type(torch.float32)
578
+ if len(output.shape) == 5 and self.dim == 2:
579
+ for i in range(output.shape[2]):
580
+ loss = loss + self._compute(output[:, :, i, ...], [t[:, :, i, ...] for t in targets])
581
+ loss/=output.shape[2]
582
+ else:
583
+ loss = self._compute(output, targets)
584
+ return loss.to(output)
@@ -6,7 +6,6 @@ from functools import partial
6
6
 
7
7
  class Scheduler():
8
8
 
9
- @config("Scheduler")
10
9
  def __init__(self, start_value: float) -> None:
11
10
  self.baseValue = float(start_value)
12
11
  self.it = 0
@@ -20,7 +19,6 @@ class Scheduler():
20
19
 
21
20
  class Constant(Scheduler):
22
21
 
23
- @config("Constant")
24
22
  def __init__(self, value: float = 1):
25
23
  super().__init__(value)
26
24
 
@@ -29,19 +27,7 @@ class Constant(Scheduler):
29
27
 
30
28
  class CosineAnnealing(Scheduler):
31
29
 
32
- @config("CosineAnnealing")
33
- def __init__(self, start_value: float, eta_min: float = 0.00001, T_max: int = 100):
34
- super().__init__(start_value)
35
- self.eta_min = eta_min
36
- self.T_max = T_max
37
-
38
- def get_value(self):
39
- return self.eta_min + (self.baseValue - self.eta_min) *(1 + np.cos(self.it * torch.pi / self.T_max)) / 2
40
-
41
- class CosineAnnealing(Scheduler):
42
-
43
- @config("CosineAnnealing")
44
- def __init__(self, start_value: float, eta_min: float = 0.00001, T_max: int = 100):
30
+ def __init__(self, start_value: float = 1, eta_min: float = 0.00001, T_max: int = 100):
45
31
  super().__init__(start_value)
46
32
  self.eta_min = eta_min
47
33
  self.T_max = T_max
@@ -52,7 +38,7 @@ class CosineAnnealing(Scheduler):
52
38
  class Warmup(torch.optim.lr_scheduler.LambdaLR):
53
39
 
54
40
  def warmup(warmup_steps: int, step: int) -> float:
55
- return min(1.0, step / warmup_steps)
41
+ return min(1.0, (step+1) / (warmup_steps+1))
56
42
 
57
43
  def __init__(self, optimizer: torch.optim.Optimizer, warmup_steps: int = 10, last_epoch=-1, verbose="deprecated"):
58
44
  super().__init__(optimizer, partial(Warmup.warmup, warmup_steps), last_epoch, verbose)
@@ -157,10 +157,9 @@ class Head(network.ModuleArgsDict):
157
157
 
158
158
  class ConvNeXt(network.Network):
159
159
 
160
- @config("ConvNeXt")
161
160
  def __init__( self,
162
161
  optimizer : network.OptimizerLoader = network.OptimizerLoader(),
163
- schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
162
+ schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
164
163
  outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
165
164
  patch : ModelPatch = ModelPatch(),
166
165
  dim : int = 3,
@@ -98,10 +98,9 @@ class Head(network.ModuleArgsDict):
98
98
 
99
99
  class ResNet(network.Network):
100
100
 
101
- @config("ResNet")
102
101
  def __init__( self,
103
102
  optimizer : network.OptimizerLoader = network.OptimizerLoader(),
104
- schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
103
+ schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
105
104
  outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
106
105
  patch : ModelPatch = ModelPatch(),
107
106
  dim : int = 3,
@@ -118,10 +118,9 @@ class Generator(network.Network):
118
118
  self.add_module("Conv", blocks.getTorchModule("Conv", dim)(in_channels = in_channels, out_channels = out_channels, kernel_size = 3, stride = 1, padding = 1))
119
119
  self.add_module("Tanh", torch.nn.Tanh())
120
120
 
121
- @config("Generator")
122
121
  def __init__(self,
123
122
  optimizer : network.OptimizerLoader = network.OptimizerLoader(),
124
- schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
123
+ schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
125
124
  patch : ModelPatch = ModelPatch(),
126
125
  outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
127
126
  channels: list[int]=[1, 64, 128, 256, 512, 1024],
@@ -175,10 +175,9 @@ class DDPM(network.Network):
175
175
  def forward(self, input: torch.Tensor, eta: torch.Tensor, eta_hat: torch.Tensor, t: int) -> torch.Tensor:
176
176
  return torch.concat((self.v(input, eta, t), self.v(input, eta_hat, t)), dim=1)
177
177
 
178
- @config("DDPM")
179
178
  def __init__( self,
180
179
  optimizer : network.OptimizerLoader = network.OptimizerLoader(),
181
- schedulers : network.LRSchedulersLoader = network.LRSchedulersLoader(),
180
+ schedulers: dict[str, network.LRSchedulersLoader] = {"default:ReduceLROnPlateau": network.LRSchedulersLoader(0)},
182
181
  outputsCriterions: dict[str, network.TargetCriterionsLoader] = {"default" : network.TargetCriterionsLoader()},
183
182
  patch : Union[ModelPatch, None] = None,
184
183
  train_noise_step: int = 1000,